Reputation: 1919
Is there an easy way to return multiple nodes to replace a single node when using the ast.NodeTransformer? For example, say I want to rewrite all expressions of the form
f(g())
to _x1 = g(); f(_x1)
It would be quite easy to do this if visit_Expr could return multiple nodes rather than a single one. I couldn't get that to work, though, so I assume this is not the way to do it. Any suggestions would be much appreciated.
[Update]
As an update, I have a working version of this that accumulates the new and old nodes in a list and assigns them to the body of the enclosing scope node (e.g. a For, While, or Module node). This is definitely a hacky way of doing it, and I suspect there is a better way out there. I'll keep this around in case someone knows of that way.
[final update]
Looking at the docs for NodeTransformer, it's actually entirely possible to return a list of nodes if the node is part of a collection of statements.
Upvotes: 5
Views: 2048
Reputation: 1122132
For statement nodes, you are allowed to return a list of new nodes. This lets you replace a single statement with multiple. Quoting the documentation:
For nodes that were part of a collection of statements (that applies to all statement nodes), the visitor may also return a list of nodes rather than just a single node.
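As a minimal sketch of that rule, the transformer below replaces a single statement with two. The class name PassExpander and the replacement statements are made up for illustration; only the list return value matters here.

```python
import ast

class PassExpander(ast.NodeTransformer):
    """Hypothetical example: replace each `pass` statement with two statements."""

    def visit_Pass(self, node):
        # `pass` always sits in a list of statements, so a list may be returned.
        first = ast.parse("x = 1").body[0]
        second = ast.parse("y = 2").body[0]
        return [first, second]

tree = ast.parse("pass")
tree = ast.fix_missing_locations(PassExpander().visit(tree))

# Compile and run the rewritten module to show both statements are present.
namespace = {}
exec(compile(tree, "<ast>", "exec"), namespace)
```

After the visit, the module body holds two Assign nodes where the single Pass node used to be.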
For your expression, there is a single top-level Expr() expression statement node:
>>> import ast
>>> ast.dump(ast.parse('f(g())'))
"Module(body=[Expr(value=Call(func=Name(id='f', ctx=Load()), args=[Call(func=Name(id='g', ctx=Load()), args=[], keywords=[])], keywords=[]))])"
So all you have to do is supply a visit_Expr handler that returns a list of two statement nodes: first an assignment (an Assign() statement node) of the call to g(), then a call to f passing in the new variable name (another Expr() statement node).
I'd keep state in the transformer subclass that sets a flag when you enter and leave an expression statement context, and keeps track of the stack of calls under that context. When control returns to visit_Expr, you return your new setup:
def __init__(self):
    self._expr_statement = False

def visit_Expr(self, node):
    # flag the expression statement context while the children are visited
    self._expr_statement = True
    self.generic_visit(node)
    self._expr_statement = False
    if <specific state on self matches expectations>:
        tempvar = <generated_new_name>
        return [
            # tempvar = <inner call>
            Assign([Name(tempvar, Store())], <inner_call>),
            # <outer function>(tempvar)
            Expr(Call(
                <outer_function_name_expr>,
                args=[Name(tempvar, Load())],
                keywords=[]))
        ]
    else:
        # no replacement takes place
        return node
The NodeTransformer then uses that list of elements to replace the previous Expr() node.
Note that you still have to call self.generic_visit(node) to ensure that nested nodes are handled; further visit_* methods will then be called for those nested nodes. This means that inside a visit_Call method you can check the self._expr_statement flag, test whether there is a nested call, and then store enough context on self for the visit_Expr() method to build its return value.
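For the simple f(g()) case, here is a runnable sketch of the whole transformation. Note that it is a simplified variant: rather than tracking flags on self, it inspects the Expr node directly inside visit_Expr, and it only handles a call with exactly one positional argument that is itself a call. The class name NestedCallLifter and the _x1, _x2, ... naming scheme are assumptions for illustration; ast.unparse needs Python 3.9 or newer.

```python
import ast
import itertools

class NestedCallLifter(ast.NodeTransformer):
    """Rewrite statements `f(g())` as `_x1 = g()` followed by `f(_x1)`."""

    def __init__(self):
        # Hypothetical naming scheme for the temporaries: _x1, _x2, ...
        self._counter = itertools.count(1)

    def visit_Expr(self, node):
        self.generic_visit(node)  # keep transforming anything nested deeper
        call = node.value
        if (
            isinstance(call, ast.Call)
            and len(call.args) == 1
            and not call.keywords
            and isinstance(call.args[0], ast.Call)
        ):
            tempvar = f"_x{next(self._counter)}"
            assign = ast.Assign(
                targets=[ast.Name(id=tempvar, ctx=ast.Store())],
                value=call.args[0],
            )
            outer = ast.Expr(
                value=ast.Call(
                    func=call.func,
                    args=[ast.Name(id=tempvar, ctx=ast.Load())],
                    keywords=[],
                )
            )
            # Reuse the original statement's position for both new statements.
            return [ast.copy_location(assign, node), ast.copy_location(outer, node)]
        # No replacement takes place.
        return node

tree = ast.parse("f(g())")
tree = ast.fix_missing_locations(NestedCallLifter().visit(tree))
result = ast.unparse(tree)  # Python 3.9+
```

Calling ast.fix_missing_locations before compiling matters because the freshly built Name and Call nodes carry no line information of their own. The flag-based design described above becomes worthwhile once the pattern to match spans more than one level of nesting.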
Upvotes: 5