Mike

Reputation: 806

Automatically refactor python lambdas to named functions

I am working on the purescript-python project and there are several core libraries that make extensive use of lambdas. Because of the way the code is compiled, the location of the lambdas winds up getting obscured, which results in bizarre console messages whenever an error occurs.
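To illustrate the kind of obscured location meant here, a minimal sketch in plain Python (not taken from purescript-python; `named`, `anon` and `frame_names` are hypothetical names used only for this example). It compares the frame names a traceback reports for a lambda and for an equivalent named function:

```python
import traceback


def named(a):
    # Named inner function: its name appears in tracebacks.
    def _named_inner(b):
        return a / b
    return _named_inner


# Curried lambda: every frame it creates is reported as '<lambda>'.
anon = lambda a: lambda b: a / b


def frame_names(func):
    """Call func(1)(0) and return the function names found in the traceback."""
    try:
        func(1)(0)  # division by zero happens in the innermost function
    except ZeroDivisionError as exc:
        return [frame.name for frame in traceback.extract_tb(exc.__traceback__)]


print(frame_names(anon)[-1])   # name of the frame that raised, lambda version
print(frame_names(named)[-1])  # name of the frame that raised, named version
```

With the lambda version the failing frame carries no useful name, which is what the refactoring is meant to fix.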

I would like to refactor these libraries to use lambdas as little as possible. So, for example, if there is something like:

def foo(a):
  return lambda b: lambda c: lambda d: lambda e: a + b + c + d + e

it would be nice to generate

def foo(a):
  def _foo_internal_anon_1(b):
    def _foo_internal_anon_2(c):
      def _foo_internal_anon_3(d):
        def _foo_internal_anon_4(e):
          return a + b + c + d + e
        return _foo_internal_anon_4
      return _foo_internal_anon_3
    return _foo_internal_anon_2
  return _foo_internal_anon_1

Is there a way to do this automatically, e.g. with pylint or a VS Code or PyCharm plugin, or does it have to be done by hand?

Upvotes: 7

Views: 376

Answers (2)

a_guest

Reputation: 36299

You can use a custom ast.NodeTransformer to convert Lambda nodes that are returned from a function (i.e. that sit inside a Return node at the end of the function body) into full function definitions. The transformed AST can then be unparsed with the help of the unparse.py tool from the CPython repo (starting with Python 3.9 you can also use ast.unparse). This makes it possible to transform a whole script, not just single functions.
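As a quick illustration of the node shape the transformer looks for, here is a minimal sketch that parses a small function and inspects the last statement of its body (the sample source and the variable names are made up for this example):

```python
import ast

# Parse a tiny function whose last statement returns a lambda.
tree = ast.parse('def foo(a):\n    return lambda b: a + b\n')

func_def = tree.body[0]          # the FunctionDef node for foo
return_stmt = func_def.body[-1]  # the last statement in foo's body

# Collect the node type names: function -> return statement -> returned value.
kinds = (
    type(func_def).__name__,
    type(return_stmt).__name__,
    type(return_stmt.value).__name__,
)
print(kinds)  # ('FunctionDef', 'Return', 'Lambda')
```

A Return whose value is a Lambda is exactly the pattern that gets rewritten into a nested def.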

This is the node transformer:

import ast
from contextlib import contextmanager


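@contextmanager
def resetattr(obj, name, value):
    """Temporarily set obj.name to value, restoring the old value on exit."""
    old_value = getattr(obj, name)
    setattr(obj, name, value)
    try:
        yield
    finally:
        # Restore even if the body of the with-block raises.
        setattr(obj, name, old_value)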


class ConvertLambda(ast.NodeTransformer):
    def __init__(self):
        super().__init__()
        self.base_name = None
        self.n = 0

    def visit_FunctionDef(self, node):
        if isinstance(node.body[-1], ast.Return) and isinstance(node.body[-1].value, ast.Lambda):
            lambda_node = node.body[-1].value
            with resetattr(self, 'base_name', self.base_name or node.name):
                with resetattr(self, 'n', self.n+1):
                    func_name = f'_{self.base_name}_internal_anon_{self.n}'
                    func_def = ast.FunctionDef(
                        name=func_name,
                        args=lambda_node.args,
                        body=[ast.Return(value=lambda_node.body)],
                        decorator_list=[],
                        returns=None,
                    )
                    self.visit(func_def)
            node.body.insert(-1, func_def)
            # Return the new function by name; ctx=ast.Load() keeps the tree compilable.
            node.body[-1].value = ast.Name(id=func_name, ctx=ast.Load())
        return node

It can be used together with the Unparser class as follows (or alternatively ast.unparse for Python 3.9+):

import ast

from unparse import Unparser  # unparse.py from the CPython repo must be importable

def convert_func_def(text):
    tree = ast.parse(text)
    tree = ast.fix_missing_locations(ConvertLambda().visit(tree))
    Unparser(tree)

By default this prints the result to sys.stdout, but Unparser accepts any file-like object: Unparser(tree, file=...).
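For Python 3.9+, the ast.unparse route mentioned above could look roughly like the following. This is a self-contained sketch, so it carries its own condensed transformer (`LambdaToDef` and `convert_source` are hypothetical names, and the class is a simplified stand-in for ConvertLambda, not a replacement for it). It also executes the generated source to check that the behaviour is unchanged:

```python
import ast

SOURCE = '''
def foo(a):
    return lambda b: lambda c: a + b + c
'''


class LambdaToDef(ast.NodeTransformer):
    """Condensed stand-in for ConvertLambda (assumption: one lambda chain per function)."""

    def __init__(self):
        super().__init__()
        self.base_name = None
        self.n = 0

    def visit_FunctionDef(self, node):
        last = node.body[-1]
        if isinstance(last, ast.Return) and isinstance(last.value, ast.Lambda):
            outermost = self.base_name is None
            if outermost:
                self.base_name = node.name
            self.n += 1
            name = f'_{self.base_name}_internal_anon_{self.n}'
            inner = ast.FunctionDef(
                name=name,
                args=last.value.args,
                body=[ast.Return(value=last.value.body)],
                decorator_list=[],
                returns=None,
            )
            self.visit(inner)  # convert lambdas nested inside this one
            node.body.insert(-1, inner)
            last.value = ast.Name(id=name, ctx=ast.Load())
            if outermost:
                # Reset the state for the next top-level function.
                self.base_name, self.n = None, 0
        return node


def convert_source(text):
    """Return the transformed source as a string (requires Python 3.9+)."""
    tree = ast.fix_missing_locations(LambdaToDef().visit(ast.parse(text)))
    return ast.unparse(tree)


converted = convert_source(SOURCE)
print(converted)

# Sanity check: the generated code still computes the same result.
namespace = {}
exec(converted, namespace)
assert namespace['foo'](1)(2)(3) == 6
```

Returning a string instead of printing makes it easy to write the result back to a file or to compare it against the original in a test.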

This is the result obtained for the example function:

def foo(a):

    def _foo_internal_anon_1(b):

        def _foo_internal_anon_2(c):

            def _foo_internal_anon_3(d):

                def _foo_internal_anon_4(e):
                    return ((((a + b) + c) + d) + e)
                return _foo_internal_anon_4
            return _foo_internal_anon_3
        return _foo_internal_anon_2
    return _foo_internal_anon_1

The output contains some additional blank lines and redundant parentheses around the additions; this can be customized by modifying the Unparser class.

Upvotes: 6

a_guest

Reputation: 36299

This is not completely automatic, but you could use re.sub to replace the lambdas inside a function. You could then turn that into a macro for your favorite IDE that lets you highlight some text and run the transformation on it.

import functools
import re
import textwrap

def convert_func_def(text):
    """Assumes a multiple of 4 spaces as indentation."""
    func_name = re.search('def (.+?)(?=[(])', text).group(1)
    n_lambda = text.count('lambda')
    for i in range(n_lambda):
        text = re.sub(
            '^( +)return lambda (.+?): (.+?)$',
            functools.partial(replace, func_name=f'_{func_name}_internal_anon_{i+1}'),
            text,
            count=1,
            flags=re.MULTILINE,
        )
    return text

def replace(match, *, func_name):
    indent, args, body = match.groups()
    template = textwrap.dedent(f'''
        def {func_name}({args}):
            return {body}
        return {func_name}
    ''').strip()
    return textwrap.indent(template, indent)

Applied to your example function:

print(convert_func_def('''def foo(a):
    return lambda b: lambda c: lambda d: lambda e: a + b + c + d + e'''))

this is the output:

def foo(a):
    def _foo_internal_anon_1(b):
        def _foo_internal_anon_2(c):
            def _foo_internal_anon_3(d):
                def _foo_internal_anon_4(e):
                    return a + b + c + d + e
                return _foo_internal_anon_4
            return _foo_internal_anon_3
        return _foo_internal_anon_2
    return _foo_internal_anon_1

Upvotes: -1
