Lord225

Optimizing Assembly Code with Z3 Solver While Handling Unknown Initial CPU State

I'm working on a simple assembly optimizer, and I'd like to use Z3 to find shorter sequences of instructions with the same overall effect on the CPU state. The idea is to take an existing sequence of assembly instructions and replace it with an equivalent but shorter sequence, while preserving the semantics for all possible initial states. However, I've run into an issue with Z3 and don't know how to work around it.

So far, I've done a few experiments on a simpler model. For example, I tried to generate a sequence of operations that produces a specific CPU state.

Here is a working example that sets an 8-bit accumulator using the shortest sequence of available instructions. In this case, I assumed a simple architecture and only the subset of instructions needed for this problem: shift acc left by one, add an immediate to acc, and invert acc (bitwise NOT).

It works by creating a Z3 array of opcodes (plus one immediate BitVec per step) and using z3.Select to simulate how each opcode changes the value of acc.

import z3

# Opcode encodings for the toy instruction set
OP_ADD = 0    # addi imm : acc = acc + imm
OP_SHIFT = 1  # add acc  : acc = acc << 1
OP_NEG = 2    # nand acc : acc = ~acc (bitwise NOT)

BIT_WIDTH = 8


def find_number_shifts(target: int, max_operations: int = 9):
    # Try lengths 1, 2, ...; the first satisfiable length is the shortest program
    for num_operations in range(1, max_operations + 1):
        s = z3.Solver()
        # Opcode of each step, stored in a Z3 array indexed by step number
        operations = z3.Array('operations', z3.IntSort(), z3.IntSort())
        # Accumulator value before/after each step
        intermediate = [z3.BitVec(f'intermediate_{k}', BIT_WIDTH)
                        for k in range(num_operations + 1)]
        # Assume acc starts at 0 (the program begins with `lda 0`)
        s.add(intermediate[0] == 0)

        imms = []
        for k in range(num_operations):
            op = z3.Select(operations, k)           # opcode at step k
            imm = z3.BitVec(f'imm_{k}', BIT_WIDTH)  # immediate for addi
            imms.append(imm)
            s.add(imm >= -8, imm <= 7)              # signed comparison: imm in [-8, 7]
            cur, nxt = intermediate[k], intermediate[k + 1]
            # op has a single value, so exactly one transition applies
            s.add(z3.Or(
                z3.And(op == OP_ADD, nxt == cur + imm),
                z3.And(op == OP_SHIFT, nxt == cur << 1),
                z3.And(op == OP_NEG, nxt == ~cur),
            ))

        # The last intermediate value must match the target
        s.add(intermediate[-1] == target)

        if s.check() == z3.sat:
            model = s.model()
            buffer = ['lda 0']
            for k in range(num_operations):
                op = model.eval(z3.Select(operations, k), model_completion=True).as_long()
                if op == OP_ADD:
                    # as_signed_long() maps e.g. 255 back to -1
                    imm = model.eval(imms[k], model_completion=True).as_signed_long()
                    buffer.append(f"addi {imm}")
                elif op == OP_SHIFT:
                    buffer.append("add acc")   # acc + acc == acc << 1
                elif op == OP_NEG:
                    buffer.append("nand acc")  # acc NAND acc == ~acc
            print(f"// Target: {target}, Operations: {num_operations}")
            print("\n".join(buffer))
            return buffer

    print(f"Target {target} not reachable with up to {max_operations} operations")
    return None

It can generate sequences for every 8-bit target value. For example, for 17 it generates:

lda 0    // set acc to 0
addi 6
add acc  // shift left
addi 5
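A cheap way to sanity-check the solver's output is to replay it concretely in plain Python. This is a minimal sketch, assuming the toy meaning used above (`lda 0` clears acc, `add acc` doubles it, `nand acc` inverts it, everything masked to 8 bits); the function name `replay_toy` is hypothetical.

```python
def replay_toy(program, bit_width=8):
    """Concretely execute the toy instruction set used by find_number_shifts."""
    mask = (1 << bit_width) - 1
    acc = 0
    for line in program:
        instr = line.split("//")[0].strip()  # drop trailing comments
        if instr == "lda 0":
            acc = 0                          # toy meaning: clear acc
        elif instr.startswith("addi"):
            acc = (acc + int(instr.split()[1])) & mask
        elif instr == "add acc":
            acc = (acc + acc) & mask         # acc + acc == shift left by one
        elif instr == "nand acc":
            acc = ~acc & mask                # acc NAND acc == bitwise NOT
        else:
            raise ValueError(f"Unknown instruction: {instr}")
    return acc


print(replay_toy(["lda 0", "addi 6", "add acc", "addi 5"]))  # 17
```

Step by step: 0 + 6 = 6, 6 << 1 = 12, 12 + 5 = 17.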

I also tried a more advanced version where I asked it to set two registers, etc.

For this optimizer, I assume a simple, imaginary soft CPU with no jumps, where every instruction uses the accumulator as one operand. For a first draft, we can even assume there are no memory reads or writes, to simplify the problem further. The available instructions are:

add r  // adds r to acc
addi i // adds imm to acc
nand r // nand r with acc
sta r  // stores acc to r
lda r  // loads r to acc
ld r   // memory load (for optimizer sets acc to unknown value)
st r   // memory store (for optimizer - this operation has to be preserved)
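To have a ground truth for checking candidate sequences, it can help to keep a plain-Python reference interpreter of the register subset (no `ld`/`st`, following the first-draft simplification). A minimal sketch, assuming an 8-bit accumulator and registers, with the register file as a dict keyed by the operand text (so both `sta 0` and `sta r1` work); the name `run` is hypothetical.

```python
from typing import Dict, List, Tuple

MASK = 0xFF  # assumption: 8-bit accumulator and registers


def run(program: List[str], acc: int, regs: Dict[str, int]) -> Tuple[int, Dict[str, int]]:
    """Reference interpreter for the register-only subset (no ld/st)."""
    regs = dict(regs)  # don't mutate the caller's register file
    for line in program:
        name, operand = line.split()
        if name == "addi":
            acc = (acc + int(operand)) & MASK
        elif name == "add":
            acc = (acc + regs[operand]) & MASK
        elif name == "nand":
            acc = ~(acc & regs[operand]) & MASK
        elif name == "sta":
            regs[operand] = acc
        elif name == "lda":
            acc = regs[operand]
        else:
            raise ValueError(f"Unsupported instruction: {line}")
    return acc, regs


print(run(["sta 0", "sta 0"], 5, {"0": 0}))  # (5, {'0': 5})
```

Comparing `run(original, ...)` and `run(candidate, ...)` over many starting states is only testing, not a proof, but it catches wrong solver output quickly.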

Now, given a sequence and an unknown initial CPU state, I want Z3 to find a shorter sequence that leaves the CPU in the same final state.
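Written as a formula, this is an exists-forall (program synthesis) query. With p the original block, p' a candidate of length k, and s ranging over all initial CPU states (accumulator and registers), the solver has to find

$$\exists\, p'.\; |p'| = k \;\wedge\; \forall s.\; \mathrm{run}(p', s) = \mathrm{run}(p, s)$$

The quantifier order is the important part: p' is chosen once, before s. A constant that is merely left unconstrained in a Z3 assertion is treated as existentially quantified, i.e. as something the solver gets to pick, so the initial state has to be bound by an explicit universal quantifier.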

For example, if I have a block of three instructions

addi 1
addi 1
addi 1

and ask for a sequence of length 1, it should generate:

addi 3

Other examples:

sta 0
sta 0  // the second sta is redundant, so the solver should be able to prune it

addi 1
addi 1 // these two adds can be merged into one
st r1
addi 1 // but this one cannot, because merging it would change the value stored by st
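One conservative way to respect the `st` constraint, sketched here as an assumption rather than the author's design, is to treat every `st` (and `ld`, after which acc is unknown) as a barrier and hand only the barrier-free segments between them to the solver. It is stricter than necessary, since the solver can never move work across a store, but each segment then only needs to preserve the CPU state at its own end. The helper name `split_at_barriers` is hypothetical.

```python
from typing import List, Sequence


def split_at_barriers(program: List[str],
                      barriers: Sequence[str] = ("st", "ld")) -> List[List[str]]:
    """Split straight-line code into segments; each barrier instruction stands alone."""
    segments: List[List[str]] = []
    current: List[str] = []
    for line in program:
        opcode = line.split()[0]
        if opcode in barriers:
            if current:
                segments.append(current)  # optimizable segment before the barrier
            segments.append([line])       # the barrier itself is kept verbatim
            current = []
        else:
            current.append(line)
    if current:
        segments.append(current)
    return segments


print(split_at_barriers(["addi 1", "addi 1", "st r1", "addi 1"]))
# [['addi 1', 'addi 1'], ['st r1'], ['addi 1']]
```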

The issue I'm facing is that I can't figure out how to tell Z3 that the initial state of the CPU is unknown, or more precisely, that it must not assume any specific value. Previously, I simply assumed that acc is zero at the first tick:

s.add(intermediate[0] == 0)

If I leave it unconstrained, Z3 picks whatever initial accumulator value suits it (usually the target value itself). I'm also not sure whether this optimizer can be built at all with the approach I have in mind.

If there’s a better tool for such work in Python, I would be more than happy to hear about it.

Update #1

Based on a suggestion in the comments, I tinkered with z3.ForAll for a while and now have a toy example. Here is the code:

import z3

# Define operation codes.
OP_ADDI = 0
OP_NAND = 1
BIT_WIDTH = 8

def simulate_cpu_array(instructions, init_state, s):
    n = len(instructions)
    states = [z3.BitVec(f'orig_acc_{i}', BIT_WIDTH) for i in range(n + 1)]
    s.add(states[0] == init_state)
    
    for i, instr in enumerate(instructions):
        if instr.startswith("addi"):
            imm = int(instr.split()[1])
            s.add(states[i + 1] == (states[i] + imm))
        elif instr.startswith("nand"):
            s.add(states[i + 1] == ~states[i])
        else:
            raise ValueError(f"Unknown instruction: {instr}")
    
    return states

def simulate_optimized_cpu(ops, args, init_state, s, target_length):
    states = [z3.BitVec(f'opt_acc_{i}', BIT_WIDTH) for i in range(target_length + 1)]
    s.add(states[0] == init_state)
    
    for i in range(target_length):
        op = ops[i]
        arg = args[i]
        acc = states[i]
        next_acc = states[i + 1]
        
        s.add(z3.Or(
            z3.And(op == OP_ADDI, next_acc == acc + arg, arg <= 7, arg >= -8),
            z3.And(op == OP_NAND, next_acc == ~acc, arg == 0)
        ))
    
    return states

def optimize(instructions, target_length):
    s = z3.Solver()

    acc_init = z3.BitVec('acc_init', BIT_WIDTH)
    original_states = simulate_cpu_array(instructions, acc_init, s)
    original_final = original_states[-1]

    ops = [z3.Int(f'op_{i}') for i in range(target_length)]
    args = [z3.BitVec(f'arg_{i}', BIT_WIDTH) for i in range(target_length)]
    
    optimized_states = simulate_optimized_cpu(ops, args, acc_init, s, target_length)
    optimized_final = optimized_states[-1]

    s.add(z3.ForAll(acc_init, optimized_final == original_final))

    for constraint in s.assertions():
        print(constraint)

    if s.check() == z3.sat:
        model = s.model()
        optimized_instructions = []
        for i in range(target_length):
            op_val = model[ops[i]].as_long()
            if op_val == OP_ADDI:
                # Convert the 8-bit pattern back to a signed immediate (e.g. 255 -> -1)
                imm_val = model[args[i]].as_signed_long()
                optimized_instructions.append(f"addi {imm_val}")
            elif op_val == OP_NAND:
                optimized_instructions.append("nand")
            else:
                raise ValueError(f"Unknown op: {op_val}")
        return optimized_instructions
    else:
        return None

if __name__ == "__main__":
    instructions = ["addi 2", "addi 2", "addi 5"]
    print("\nOriginal Instructions:", instructions)
    optimized = optimize(instructions, target_length=1)
    print("Optimized Instructions:", optimized)

    instructions = ["addi 1", "addi 1", "addi 1"]
    print("\nOriginal Instructions:", instructions)
    optimized = optimize(instructions, target_length=1)
    print("Optimized Instructions:", optimized)

Overall, it optimizes some cases correctly, but it produces incorrect results for others.

Original Instructions: ['addi 1', 'addi 1', 'addi 1']
Optimized Instructions: ['addi 3'] # Correct, yay!

But in this case, for some reason, it returns a wrong result. The merged instruction would be addi 9, but that is illegal because immediates are constrained to the range [-8, 7], so the solver should report that no length-1 sequence exists:

Original Instructions: ['addi 2', 'addi 2', 'addi 5']
Optimized Instructions: ['nand'] # incorrect

And finally, sometimes it correctly proves that the optimization is impossible.

Here are the constraints the script produces for both cases:

Original Instructions: ['addi 2', 'addi 2', 'addi 5']
orig_acc_0 == acc_init
orig_acc_1 == orig_acc_0 + 2
orig_acc_2 == orig_acc_1 + 2
orig_acc_3 == orig_acc_2 + 5
opt_acc_0 == acc_init
Or(And(op_0 == 0,
       opt_acc_1 == opt_acc_0 + arg_0,
       arg_0 <= 7,
       arg_0 >= 248),
   And(op_0 == 1, opt_acc_1 == ~opt_acc_0, arg_0 == 0))
ForAll(acc_init, opt_acc_1 == orig_acc_3)
Optimized Instructions: ['nand']

Original Instructions: ['addi 1', 'addi 1', 'addi 1']
orig_acc_0 == acc_init
orig_acc_1 == orig_acc_0 + 1
orig_acc_2 == orig_acc_1 + 1
orig_acc_3 == orig_acc_2 + 1
opt_acc_0 == acc_init
Or(And(op_0 == 0,
       opt_acc_1 == opt_acc_0 + arg_0,
       arg_0 <= 7,
       arg_0 >= 248),
   And(op_0 == 1, opt_acc_1 == ~opt_acc_0, arg_0 == 0))
ForAll(acc_init, opt_acc_1 == orig_acc_3)
Optimized Instructions: ['addi 3']

If I manage to get this working, I will try to extend it to use an incremental solver and then a more complex CPU model.
