Alexey Romanov

Reputation: 170745

Recursive strategies with additional parameters in Hypothesis

Using recursive, I can generate simple ASTs, e.g.

from hypothesis import *
from hypothesis.strategies import *

def trees():
    base = integers(min_value=1, max_value=10).map(lambda n: 'x' + str(n))

    @composite
    def extend(draw, children):
        op = draw(sampled_from(['+', '-', '*', '/']))
        return (op, draw(children), draw(children))

    # Pass extend (not draw): recursive() calls it with the strategy for child trees
    return recursive(base, extend)

Now I want to change it so I can generate boolean operations in addition to the arithmetical ones. My initial idea is to add a parameter to trees:

def trees(tpe):
    base = integers(min_value=1, max_value=10).map(lambda n: 'x' + str(n) + ': ' + tpe)

    @composite
    def extend(draw, children):
        if tpe == 'bool':
            op = draw(sampled_from(['&&', '||']))
            return (op, draw(children), draw(children))
        elif tpe == 'num':
            op = draw(sampled_from(['+', '-', '*', '/']))
            return (op, draw(children), draw(children))

    # Pass extend (not draw): recursive() calls it with the strategy for child trees
    return recursive(base, extend)

OK so far. But how do I mix them? That is, I also want comparison operators and the ternary operator, which would require "calling children with a different parameter", so to speak.

The trees need to be well-typed: if the operation is '||' or '&&', both arguments need to be boolean, arguments to '+' or '<' need to be numbers, etc. If I only had two types, I could just use filter (given a type_of function):

if op in ('&&', '||'):
    bool_trees = children.filter(lambda x: type_of(x) == 'bool')
    return (op, draw(bool_trees), draw(bool_trees))

but in the real case it wouldn't be acceptable.

Does recursive support this? Or is there another way? Obviously, I can directly define trees recursively, but that runs into the standard problems.

Upvotes: 1

Views: 298

Answers (2)

Zac Hatfield-Dodds

Reputation: 3003

You can simply describe trees where the operator is drawn from either set of operations - in this case trivially by sampling from ['&&', '||', '+', '-', '*', '/'].

def trees():
    return recursive(
        integers(min_value=1, max_value=10).map('x{}'.format),
        lambda node: tuples(sampled_from('&& || + - * /'.split()), node, node)
    )

But of course that won't be well-typed (except perhaps by rare coincidence). I think the best option for well-typed ASTs is:

  1. For each type, define a strategy for trees which evaluate to that type. The base case is simply (a strategy for) a value of that type.
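  2. The extension is to pre-calculate the possible combinations of types and operations that would generate a value of this type, using mutual recursion via st.deferred. That would look something like...

from hypothesis.strategies import booleans, deferred, integers, one_of, sampled_from, tuples

# The two strategies refer to each other; deferred() only evaluates its lambda
# on first use, so integer_strat may be defined after bool_strat.
bool_strat = deferred(
    lambda: one_of(
        booleans(),
        tuples(sampled_from(["and", "or"]), bool_strat, bool_strat),
        # "..." stands for the remaining comparison operators
        tuples(sampled_from(["==", "!=", "<", ...]), integer_strat, integer_strat),
    )
)
integer_strat = deferred(
    lambda: one_of(
        integers(),
        tuples(sampled_from("+ - * /".split()), integer_strat, integer_strat),
    )
)
any_type_ast = bool_strat | integer_strat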

And it will work as if by magic :D

(on the other hand, this is a fair bit more complex - if your workaround is working for you, don't feel obliged to do this instead!)
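As a hedged extension of the deferred sketch, the ternary operator from the question fits the same pattern: it appears in each type's strategy with a bool_strat condition and two branches of that type. The operator name "?:", the concrete comparison list and the type_of checker below are assumptions for illustration. Health checks are suppressed only as a precaution, since mutually recursive strategies like these can occasionally produce very large trees.

```python
from hypothesis import HealthCheck, given, settings
from hypothesis.strategies import (
    booleans, deferred, integers, just, one_of, sampled_from, tuples,
)

# Assumed encoding: leaves are Python bools/ints, inner nodes are (op, arg, ...).
# "?:" is a hypothetical name for the ternary operator.
COMPARISONS = ["==", "!=", "<", "<=", ">", ">="]

bool_strat = deferred(
    lambda: one_of(
        booleans(),
        tuples(sampled_from(["and", "or"]), bool_strat, bool_strat),
        tuples(sampled_from(COMPARISONS), integer_strat, integer_strat),
        # Ternary: bool condition, two branches of the result type (bool here)
        tuples(just("?:"), bool_strat, bool_strat, bool_strat),
    )
)
integer_strat = deferred(
    lambda: one_of(
        integers(),
        tuples(sampled_from("+ - * /".split()), integer_strat, integer_strat),
        # Ternary: bool condition, two integer branches
        tuples(just("?:"), bool_strat, integer_strat, integer_strat),
    )
)


def type_of(tree):
    """Return 'bool' or 'num' for a well-typed tree; raise TypeError otherwise."""
    if isinstance(tree, bool):  # must come first: bool is a subclass of int
        return "bool"
    if isinstance(tree, int):
        return "num"
    op, *args = tree
    arg_types = [type_of(a) for a in args]
    if op in ("and", "or") and arg_types == ["bool", "bool"]:
        return "bool"
    if op in COMPARISONS and arg_types == ["num", "num"]:
        return "bool"
    if op in ("+", "-", "*", "/") and arg_types == ["num", "num"]:
        return "num"
    if op == "?:" and arg_types[0] == "bool" and arg_types[1] == arg_types[2]:
        return arg_types[1]
    raise TypeError(f"ill-typed node: {tree!r}")


@settings(max_examples=50, deadline=None, suppress_health_check=list(HealthCheck))
@given(one_of(bool_strat, integer_strat))
def test_generated_trees_are_well_typed(tree):
    assert type_of(tree) in ("bool", "num")
```

Run it with pytest or call test_generated_trees_are_well_typed() directly. Because type_of raises on any ill-typed node, a mis-wired branch (say, an integer_strat where a bool_strat belongs) makes the test fail.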

If you're seeing problematic blowups in size - which should be very rare, as the engine has had a lot of work since that article was written - there's honestly not much to do about it. Threading a depth limit through the whole thing and decrementing it each step does work as a last resort, but it's not nice to work with.
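A minimal sketch of that last resort, assuming the same tuple encoding as the deferred example: each strategy factory takes a depth budget, passes depth - 1 to its children, and falls back to plain leaves at zero. The names bool_trees, int_trees and depth_of are illustrative only, not part of Hypothesis.

```python
from hypothesis import given
from hypothesis.strategies import booleans, integers, one_of, sampled_from, tuples

COMPARISONS = ["==", "!=", "<", "<=", ">", ">="]  # assumed operator set


def bool_trees(depth):
    """Well-typed boolean trees with at most `depth` levels of operators."""
    if depth == 0:
        return booleans()  # budget exhausted: leaves only
    # Build each sub-strategy once per level and reuse it for both children.
    sub_bool, sub_int = bool_trees(depth - 1), int_trees(depth - 1)
    return one_of(
        booleans(),
        tuples(sampled_from(["and", "or"]), sub_bool, sub_bool),
        tuples(sampled_from(COMPARISONS), sub_int, sub_int),
    )


def int_trees(depth):
    """Well-typed integer trees with at most `depth` levels of operators."""
    if depth == 0:
        return integers()
    sub_int = int_trees(depth - 1)
    return one_of(
        integers(),
        tuples(sampled_from("+ - * /".split()), sub_int, sub_int),
    )


def depth_of(tree):
    # Leaves have depth 0; each operator node adds one level.
    if not isinstance(tree, tuple):
        return 0
    return 1 + max(depth_of(arg) for arg in tree[1:])


@given(bool_trees(4) | int_trees(4))
def test_depth_is_bounded(tree):
    assert depth_of(tree) <= 4
```

The awkward part is visible here: every factory has to accept and forward depth. For a tree built with recursive() instead of deferred(), its max_leaves argument (default 100) already bounds the size without any of this plumbing.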

Upvotes: 2

Alexey Romanov

Reputation: 170745

The solution I used for now is to adapt the generated trees so that, e.g., if a num tree is generated when the operation needs a bool, I also draw a comparison operator op and a constant const and return (op, tree, const) or (op, const, tree):

def make_bool(tree, draw):
    if type_of(tree) == 'bool':
        return tree
    elif type_of(tree) == 'num':
        # Wrap a numeric tree in a comparison against a random constant
        op = draw(sampled_from(comparison_ops))
        const = draw(integers())
        side = draw(booleans())  # which side of the comparison the tree goes on
        return (op, tree, const) if side else (op, const, tree)

# in def extend (bool_ops, comparison_ops, type_of and make_num are not shown):
if tpe == 'bool':
    op = draw(sampled_from(bool_ops + comparison_ops))
    if op in bool_ops:
        return (op, make_bool(draw(children), draw), make_bool(draw(children), draw))
    else:
        return (op, make_num(draw(children), draw), make_num(draw(children), draw))

Unfortunately, it's specific to ASTs and will mean specific kinds of trees are generated more often. So I'd still be happy to see better alternatives.
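A self-contained sketch of how these fragments might fit together into something runnable. Everything the answer does not show is an assumption: the operator lists, a type_of that also checks argument types, and a hypothetical make_num that turns a boolean tree into the condition of a numeric ternary '?:'.

```python
from hypothesis import given
from hypothesis.strategies import booleans, composite, integers, recursive, sampled_from

# Hypothetical operator lists; the answer does not show its own.
BOOL_OPS = ['&&', '||']
NUM_OPS = ['+', '-', '*', '/']
COMPARISON_OPS = ['<', '<=', '==']


def type_of(tree):
    """Type of a tree, also asserting that every operator's arguments are well-typed."""
    if isinstance(tree, int):      # raw constants added by make_bool / make_num
        return 'num'
    if isinstance(tree, str):      # leaves look like 'x3: bool'
        return tree.rsplit(': ', 1)[1]
    op, *args = tree
    types = [type_of(a) for a in args]
    if op == '?:':                 # hypothetical numeric ternary used by make_num
        assert types == ['bool', 'num', 'num'], tree
        return 'num'
    expected = 'bool' if op in BOOL_OPS else 'num'
    assert types == [expected, expected], tree
    return 'num' if op in NUM_OPS else 'bool'


def make_bool(tree, draw):
    if type_of(tree) == 'bool':
        return tree
    op = draw(sampled_from(COMPARISON_OPS))
    const = draw(integers())
    return (op, tree, const) if draw(booleans()) else (op, const, tree)


def make_num(tree, draw):
    # Hypothetical counterpart of make_bool: a bool tree becomes a ternary's condition.
    if type_of(tree) == 'num':
        return tree
    return ('?:', tree, draw(integers()), draw(integers()))


def trees(tpe):
    base = integers(min_value=1, max_value=10).map(lambda n: 'x' + str(n) + ': ' + tpe)

    @composite
    def extend(draw, children):
        if tpe == 'bool':
            op = draw(sampled_from(BOOL_OPS + COMPARISON_OPS))
            fix = make_bool if op in BOOL_OPS else make_num
        else:
            op = draw(sampled_from(NUM_OPS))
            fix = make_num
        return (op, fix(draw(children), draw), fix(draw(children), draw))

    return recursive(base, extend)


@given(trees('bool'))
def test_bool_trees_are_well_typed(tree):
    assert type_of(tree) == 'bool'
```

Note how the skew mentioned above shows up: every comparison node's arguments are ternaries over bool subtrees, since the children strategy of trees('bool') only ever yields bool trees.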

Upvotes: 2
