DSblizzard

Reputation: 4149

Z3: how to improve performance?

In the following Python code:

from itertools import product
from z3 import *

def get_sp(v0, v1):
    res = sum([v0[i] * v1[i] for i in range(len(v0))])
    return res

def get_is_mod_partition(numbers):
    n = len(numbers)
    mod_power = 2**n
    for signs in product((1, -1), repeat = len(numbers)):
        if get_sp(numbers, signs) % mod_power == 0:
            return 1
    return 0

def check_sat(numbers):
    n = len(numbers)
    s = Solver()
    signs = [Int("s" + str(i)) for i in range(n)]
    for i in range(n):
        s.add(Or(signs[i] == -1, signs[i] == 1))
    s.add(get_sp(numbers, signs) % 2**n == 0)
    print(s.check())

l = [1509, 1245, 425, 2684, 3630, 435, 875, 2286, 1886, 1205, 518, 1372]
check_sat(l)
get_is_mod_partition(l)

check_sat takes 22 seconds, while get_is_mod_partition takes 24 milliseconds. I didn't expect such results from a "high-performance theorem prover". Is there a way to drastically improve performance?

Upvotes: 4

Views: 2989

Answers (1)

alias

Reputation: 30497

Following Patrick's advice, you can code it as follows:

from z3 import *

def get_sp(v0, v1):
    res = sum([If(v1[i], v0[i], -v0[i]) for i in range(len(v0))])
    return res

def check_sat(numbers):
    n = len(numbers)
    s = Solver()
    signs = [Bool("s" + str(i)) for i in range(n)]
    s.add(get_sp(numbers, signs) % 2**n == 0)
    print(s.check())
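    m = s.model()
    mod_power = 2 ** n
    # Print the chosen signs on one line (Python 3: end=" " suppresses the newline)
    print("(", end=" ")
    for (num, sgn) in zip(numbers, signs):
        if is_true(m[sgn]):
            print("+ %d" % num, end=" ")
        else:
            print("- %d" % num, end=" ")
    print(") %% %d == 0" % mod_power)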

l = [1509, 1245, 425, 2684, 3630, 435, 875, 2286, 1886, 1205, 518, 1372]
check_sat(l)

Runs in about 0.14 seconds on my machine, and prints:

sat
( - 1509 - 1245 - 425 + 2684 + 3630 + 435 - 875 + 2286 - 1886 - 1205 - 518 - 1372 ) % 4096 == 0
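As a quick sanity check, the printed assignment can be verified in plain Python without Z3. This is only a sketch: the `signs` list is transcribed by hand from the output line above, and the names `signs` and `total` are mine, not part of the original code.

```python
numbers = [1509, 1245, 425, 2684, 3630, 435, 875, 2286, 1886, 1205, 518, 1372]

# Signs copied from the printed model above: -1 for "-", +1 for "+"
signs = [-1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, -1]

mod_power = 2 ** len(numbers)  # 4096 for 12 numbers
total = sum(s * x for s, x in zip(signs, numbers))

print(total, total % mod_power)  # prints: 0 0
```

For this particular model the signed sum is exactly 0, so it is trivially divisible by 4096.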

However, as Patrick commented, it is not clear why this version is significantly faster than the original. I wanted to do some benchmarking, and did so in Haskell, as I'm more familiar with that language and its Z3 bindings:

import Data.SBV
import Criterion.Main

ls :: [SInteger]
ls = [1509, 1245, 425, 2684, 3630, 435, 875, 2286, 1886, 1205, 518, 1372]

original = do bs <- mapM (const free_) ls
              let inside b = constrain $ b .== 1 ||| b .== -1
              mapM_ inside bs
              return $ sum [b*l | (b, l) <- zip bs ls] `sMod` (2^length ls) .== 0

boolOnly = do bs <- mapM (const free_) ls
              return $ sum [ite b l (-l) | (b, l) <- zip bs ls] `sMod` (2^length ls) .== 0

main = defaultMain [ bgroup "problem" [ bench "orig" $ nfIO (sat original)
                                      , bench "bool" $ nfIO (sat boolOnly)
                                      ]
                   ]

And indeed, the bool-only version is about 8 times faster:

benchmarking problem/orig
time                 810.1 ms   (763.4 ms .. 854.7 ms)
                     0.999 R²   (NaN R² .. 1.000 R²)
mean                 808.4 ms   (802.2 ms .. 813.6 ms)
std dev              8.189 ms   (0.0 s .. 8.949 ms)
variance introduced by outliers: 19% (moderately inflated)

benchmarking problem/bool
time                 108.2 ms   (104.4 ms .. 113.5 ms)
                     0.997 R²   (0.992 R² .. 1.000 R²)
mean                 109.3 ms   (107.3 ms .. 111.5 ms)
std dev              3.408 ms   (2.261 ms .. 4.843 ms)

Two observations:

  • Haskell bindings are way faster than those of Python! About an order of magnitude
  • The bool-only version is about another order of magnitude faster compared to integers

For the former, it might be interesting to find out why Python bindings are performing so poorly; or simply switch to Haskell :-)

Further analysis, and a cute trick

It appears the issue is with the call to mod. In the Haskell translation, the system internally gives names to all intermediate expressions, which seems to make z3 go fast. The Python bindings, however, translate expressions more wholesale, and inspection of the generated code (you can see it by looking at s.sexpr()) reveals that they do not name internal expressions. When mod is involved, I'm guessing the solver's heuristics fail to recognize the essential linearity of the problem and end up spending a lot of time.

To improve the time, you can do the following simple trick. The original says:

s.add(get_sp(numbers, signs) % 2**n == 0)

Instead, make sure the sum gets an explicit name. That is, replace the above line by:

ssum = Int("ssum")
s.add (ssum == get_sp(numbers, signs))
s.add (ssum % 2**n == 0)

You'll see that this also makes the Python version run much faster.

I would still prefer the boolean translation, but this one provides a nice rule of thumb: try to name intermediate expressions, which gives the solver nice choice points as guided by the semantics of the formula, as opposed to one bulky output. As I mentioned, the Haskell bindings don't suffer from this, as they internally convert all formulas to simple three-operand SSA form, which allows the solver to apply its heuristics more easily.

Upvotes: 9
