JPC

Python AST NodeTransformer: Return multiple nodes

Is there an easy way to return multiple nodes to replace a single node when using the ast.NodeTransformer? For example, say I want to rewrite all expressions of the form

f(g()) to _x1 = g(); f(_x1)

It would be quite easy to do this if visit_Expr could return multiple nodes rather than a single node. I couldn't get that to work, though, so I assume this is not the way to do it. Any suggestions would be much appreciated.

[Update] I now have a working version that accumulates the new and old nodes in a list and assigns them to the body of the enclosing scope node (e.g. a For, While or Module node). This is definitely a hacky way of doing it, and I suspect there is a better way. I'll keep the question open in case someone knows of one.

[Final update] Looking at the docs for NodeTransformer, it's actually entirely possible to return a list of nodes if the node is part of a collection of statements.
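The following is a minimal sketch that checks this behaviour. It assumes Python 3.9+ for `ast.unparse`, and the class name `DuplicateExprStatements` is made up for illustration. A list returned from `visit_Expr` gets spliced into For and While bodies as well as the module body, so there is no need to reassign `body` on the enclosing node by hand:

```python
import ast
import copy


class DuplicateExprStatements(ast.NodeTransformer):
    """Toy transformer (hypothetical): emit every expression statement twice."""

    def visit_Expr(self, node):
        self.generic_visit(node)
        # Expr is a statement, so it always sits in a statement list
        # (Module.body, For.body, While.body, ...); a returned list is
        # spliced into whichever list held the original node.
        return [node, copy.deepcopy(node)]


source = """
for i in range(2):
    log(i)
while cond():
    log(0)
"""
tree = ast.fix_missing_locations(DuplicateExprStatements().visit(ast.parse(source)))
print(ast.unparse(tree))
```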

Answers (1)

Martijn Pieters

For statement nodes, you are allowed to return a list of new nodes. This lets you replace a single statement with multiple. Quoting the documentation:

For nodes that were part of a collection of statements (that applies to all statement nodes), the visitor may also return a list of nodes rather than just a single node.

For your expression, there is a single top-level Expr() expression statement node:

>>> import ast
>>> ast.dump(ast.parse('f(g())'))
"Module(body=[Expr(value=Call(func=Name(id='f', ctx=Load()), args=[Call(func=Name(id='g', ctx=Load()), args=[], keywords=[])], keywords=[]))])"

So all you have to do is supply a visit_Expr handler that returns a list of two statement nodes. The first is an assignment (an Assign() statement node) of the call to g(). The second is a call to f that passes in the new variable name (another Expr() statement node).
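Here is a concrete, runnable version of that idea, offered as a sketch under several assumptions:

* It needs Python 3.9+ for `ast.unparse`.
* It only matches statements shaped exactly like `f(g())`: one positional argument that is itself a call, and no keywords.
* It generates the temporary names `_x1`, `_x2`, ... from a counter.
* The class name `HoistInnerCall` is hypothetical.

Unlike the stateful approach below, it simply inspects `node.value` directly inside `visit_Expr`:

```python
import ast
import itertools


class HoistInnerCall(ast.NodeTransformer):
    """Rewrite statements shaped like f(g()) into _xN = g(); f(_xN)."""

    def __init__(self):
        super().__init__()
        self._names = itertools.count(1)  # source of fresh temporary names

    def visit_Expr(self, node):
        self.generic_visit(node)  # still visit nested nodes first
        outer = node.value
        if (isinstance(outer, ast.Call)
                and len(outer.args) == 1
                and not outer.keywords
                and isinstance(outer.args[0], ast.Call)):
            tempvar = f"_x{next(self._names)}"
            assign = ast.Assign(
                targets=[ast.Name(id=tempvar, ctx=ast.Store())],
                value=outer.args[0],  # the inner g() call
            )
            call = ast.Expr(value=ast.Call(
                func=outer.func,  # the outer f
                args=[ast.Name(id=tempvar, ctx=ast.Load())],
                keywords=[],
            ))
            return [assign, call]  # spliced into the enclosing body
        return node  # no replacement takes place


source = "f(g())\n\ndef h():\n    f(g())\n"
tree = HoistInnerCall().visit(ast.parse(source))
tree = ast.fix_missing_locations(tree)
print(ast.unparse(tree))
```

The new Assign, Call and Name nodes carry no line numbers. For that reason, `ast.fix_missing_locations()` is needed before the tree can be passed to `compile()`.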

I'd keep state on the transformer subclass. Set a flag when you enter and leave an expression-statement context, and track the stack of calls under that context. When control returns to visit_Expr, return your replacement nodes:

import ast

class CallHoister(ast.NodeTransformer):  # hypothetical name
    _expr_statement = False

    def visit_Expr(self, node):
        self._expr_statement = True
        self.generic_visit(node)
        self._expr_statement = False

        if <specific state on self matches expectations>:
            tempvar = <generated_new_name>
            # new nodes have no line numbers; run ast.fix_missing_locations() before compile()
            return [
                ast.Assign(targets=[ast.Name(tempvar, ast.Store())], value=<inner_call>),
                ast.Expr(ast.Call(
                    <outer_function_name_expr>,
                    args=[ast.Name(tempvar, ast.Load())],
                    keywords=[]))
            ]
        else:
            # no replacement takes place
            return node

The NodeTransformer then uses that list of elements to replace the original Expr() node.
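Roughly speaking, `generic_visit` rebuilds every list-valued field (such as `body`) from the visitor results:

* `None` drops the item.
* A list is spliced in.
* A single node replaces the old one.

The function below is a simplified sketch of that behaviour, written for illustration only. `splice` is not part of the `ast` module, and this is not CPython's actual source. It also hints at why the trick only works for statements. A list returned for a node stored in a single-node field, such as `Expr.value`, is simply stored there rather than spliced.

```python
import ast


def splice(old_values, visit):
    """Hypothetical helper: mimic how NodeTransformer.generic_visit rebuilds
    a list field such as Module.body (simplified, for illustration)."""
    new_values = []
    for value in old_values:
        if isinstance(value, ast.AST):
            value = visit(value)
            if value is None:  # visitor asked for removal
                continue
            if not isinstance(value, ast.AST):  # a list of replacement nodes
                new_values.extend(value)  # splice its items in place
                continue
        new_values.append(value)
    old_values[:] = new_values  # mutate the original list in place
    return old_values


stmts = ast.parse("a\nb\nc").body
# Expand the statement `b` into two statements and keep the rest unchanged.
splice(stmts, lambda n: [n, ast.Expr(value=ast.Name(id="b2", ctx=ast.Load()))]
       if n.value.id == "b" else n)
print([s.value.id for s in stmts])  # ['a', 'b', 'b2', 'c']
```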

Note that you still have to call self.generic_visit(node) to ensure that the nested nodes are handled; the relevant visit_* methods will then be called for those nested nodes. This means that inside a visit_Call method you can check the self._expr_statement flag, test whether there is a nested call, and store enough context on self for visit_Expr() to build its return value.
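The flag-based variant described above might look like the following sketch, which assumes Python 3.9+ for `ast.unparse`. The class name, the attribute names `_top_value` and `_found`, and the `_x<n>` naming scheme are all hypothetical. `visit_Call` only records a match when it is looking at the statement's own top-level call, so something like `x + f(g())` is left alone:

```python
import ast
import itertools


class HoistWithState(ast.NodeTransformer):
    """Flag-based sketch: visit_Call records context, visit_Expr acts on it."""

    def __init__(self):
        super().__init__()
        self._expr_statement = False  # True while inside an Expr statement
        self._top_value = None  # the Expr's value node (hypothetical state)
        self._found = None  # (outer func, inner call) recorded by visit_Call
        self._names = itertools.count(1)

    def visit_Expr(self, node):
        self._expr_statement, self._top_value, self._found = True, node.value, None
        self.generic_visit(node)  # triggers visit_Call for nested calls
        self._expr_statement = False
        found, self._found = self._found, None
        if found is None:
            return node  # no replacement takes place
        func, inner = found
        tempvar = f"_x{next(self._names)}"
        return [
            ast.Assign(targets=[ast.Name(id=tempvar, ctx=ast.Store())], value=inner),
            ast.Expr(value=ast.Call(func=func,
                                    args=[ast.Name(id=tempvar, ctx=ast.Load())],
                                    keywords=[])),
        ]

    def visit_Call(self, node):
        # Only the statement's own top-level call qualifies, and only if its
        # single positional argument is itself a call.
        if (self._expr_statement and node is self._top_value
                and len(node.args) == 1 and not node.keywords
                and isinstance(node.args[0], ast.Call)):
            self._found = (node.func, node.args[0])
        return self.generic_visit(node)  # keep visiting nested calls


tree = ast.fix_missing_locations(HoistWithState().visit(ast.parse("f(g())")))
print(ast.unparse(tree))  # prints two lines: "_x1 = g()" then "f(_x1)"
```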
