
Reputation: 5198

Python 3.11 decorator to convert a function into a dataclass for scikit-learn BaseEstimator TransformerMixin including type annotations


Question / Goal

How can I write a decorator that takes in a function fn and creates a dataclass where each argument / keyword-argument is a field and the docstring is copied over for better intellisense support? I don't want to see **kwargs:Any; I want to know what the variables are.

In the MWE below there are the following things:



from dataclasses import dataclass, field, fields, _FIELD
from typing import get_type_hints, Optional, NamedTuple, Callable, Dict, Any, List
import inspect
from inspect import Signature, Parameter
from functools import wraps
import numpy as np


def get_func_params(
    fn: Callable, 
    drop_self: Optional[bool] = True,
    drop_before: Optional[int] = 0,
    drop_idxs: Optional[List[int]] = None,
    drop_names: Optional[List[str]] = None,
    drop_after: Optional[int] = None,
) -> Dict[str, Any]:
    """Map each parameter name of ``fn`` to its default value.

    Parameters without a default map to ``inspect.Parameter.empty``.

    Parameters
    ----------
    fn : Callable
        The callable whose signature is inspected.
    drop_self : bool, default=True
        Remove a parameter literally named ``self`` before indexing.
    drop_before : int, default=0
        Drop parameters whose positional index is below this value.
    drop_idxs : list of int, optional
        Positional indices to drop (counted after ``self`` is removed).
    drop_names : list of str, optional
        Parameter names to drop.
    drop_after : int, optional
        Drop parameters whose positional index is at or above this value.
        ``None`` means no upper bound.

    Returns
    -------
    dict
        ``{name: default}`` for the surviving parameters, in signature order.
    """
    defaults = {
        name: param.default
        for name, param in inspect.signature(fn).parameters.items()
    }
    if drop_self and 'self' in defaults:
        del defaults['self']

    # ``None`` is accepted for every optional filter (the annotations say
    # Optional), so normalise before using them in comparisons.
    start = drop_before or 0
    skip_idxs = set(drop_idxs or ())
    skip_names = set(drop_names or ())

    return {
        name: default
        for i, (name, default) in enumerate(defaults.items())
        # keep only the half-open window  drop_before <= i < drop_after
        if start <= i
        and (drop_after is None or i < drop_after)
        and i not in skip_idxs
        and name not in skip_names
    }

Dummy Functions / Classes

class SomeModuleType(NamedTuple):
    """Dummy record type; used as the annotation of ``mwe_func``'s ``a_thing`` argument."""
    a: int  # integer field
    b: str  # string field

def mwe_func(data:np.ndarray, a_bool:bool=False, a_thing:Optional[SomeModuleType]=None) -> np.ndarray:
    """Dummy function for the MWE; returns ``data`` unchanged.

    Parameters
    ----------
    data : np.ndarray
        A numpy array of data

    a_bool : bool, default=False
        A boolean

    a_thing : Optional[SomeModuleType]
        A thing

    Returns
    -------
    np.ndarray
        The same ``data`` object that was passed in.
    """
    # ...
    return data

class Class2Subclass:
    def expected_method(self, a:int=0):


class mwe_class:

No intelli-sense (that is co-pilot suggesting things)

enter image description here enter image description here


I am working with scanpy and sklearn. Currently I have the following:

import scanpy as sp, anndata as ad, numpy as np, pandas as pd
from sklearn.base import BaseEstimator, TransformerMixin
from dataclasses import dataclass
import inspect
from functools import wraps

from typing import List, Any, Optional, Callable, Union, Tuple, Iterable, Set, TypeAlias, Type, Dict

def filter_kwargs_for_func(fn: Callable, **kwargs: Optional[dict]):
    """Return the subset of ``kwargs`` whose names appear in ``fn``'s signature.

    The order of the surviving entries follows the order in which the
    keyword arguments were supplied.
    """
    accepted = inspect.signature(fn).parameters
    kept = {}
    for name, value in kwargs.items():
        if name in accepted:
            kept[name] = value
    return kept

def filter_kwargs_for_class(cls: Callable, **kwargs: Optional[dict]):
    """Return the subset of ``kwargs`` whose names appear in ``cls.__init__``'s signature.

    The order of the surviving entries follows the order in which the
    keyword arguments were supplied.
    """
    init_params = inspect.signature(cls.__init__).parameters
    kept = {}
    for name, value in kwargs.items():
        if name in init_params:
            kept[name] = value
    return kept

def wrangle_kwargs_for_func(
    fn: Callable, 
    defaults: Optional[dict]=None,
    **kwargs: Any,
) -> dict:
    """Merge ``defaults`` with ``kwargs`` and keep only what ``fn`` accepts.

    Parameters
    ----------
    fn : Callable
        The function the resulting keyword arguments are meant for.
    defaults : dict, optional
        Baseline values; never mutated.
    **kwargs
        Overrides; these win over ``defaults`` on name clashes.

    Returns
    -------
    dict
        The merged mapping restricted to names in ``fn``'s signature.
    """
    # copy defaults so the caller's dict is left untouched
    params = dict(defaults or {})
    # update with kwargs of our function
    params.update(kwargs)
    # filter for only the params that other function accepts; done directly on
    # the dict so a key that happens to be called ``fn`` cannot collide with a
    # positional argument when re-splatted through ``**``
    accepted = inspect.signature(fn).parameters
    return {name: value for name, value in params.items() if name in accepted}

def wrangle_kwargs_for_class(
    cls: Callable, 
    defaults: Optional[dict]=None,
    **kwargs: Any,
) -> dict:
    """Merge ``defaults`` with ``kwargs`` and keep only what ``cls.__init__`` accepts.

    Parameters
    ----------
    cls : Callable
        The class whose ``__init__`` signature decides which names survive.
    defaults : dict, optional
        Baseline values; never mutated.
    **kwargs
        Overrides; these win over ``defaults`` on name clashes.

    Returns
    -------
    dict
        The merged mapping restricted to names in ``cls.__init__``'s signature.
    """
    # copy defaults so the caller's dict is left untouched
    params = dict(defaults or {})
    # update with kwargs of our class
    params.update(kwargs)
    # filter for only the params that other class accepts; done directly on
    # the dict so a key that happens to be called ``cls`` cannot collide with a
    # positional argument when re-splatted through ``**``
    accepted = inspect.signature(cls.__init__).parameters
    return {name: value for name, value in params.items() if name in accepted}

def get_func_params(
    fn: Callable, 
    drop_self: Optional[bool] = True,
    drop_before: Optional[int] = 0,
    drop_idxs: Optional[List[int]] = None,
    drop_names: Optional[List[str]] = None,
    drop_after: Optional[int] = None,
) -> Dict[str, Any]:
    """Map each parameter name of ``fn`` to its default value.

    Parameters without a default map to ``inspect.Parameter.empty``.

    Parameters
    ----------
    fn : Callable
        The callable whose signature is inspected.
    drop_self : bool, default=True
        Remove a parameter literally named ``self`` before indexing.
    drop_before : int, default=0
        Drop parameters whose positional index is below this value.
    drop_idxs : list of int, optional
        Positional indices to drop (counted after ``self`` is removed).
    drop_names : list of str, optional
        Parameter names to drop.
    drop_after : int, optional
        Drop parameters whose positional index is at or above this value.
        ``None`` means no upper bound.

    Returns
    -------
    dict
        ``{name: default}`` for the surviving parameters, in signature order.
    """
    defaults = {
        name: param.default
        for name, param in inspect.signature(fn).parameters.items()
    }
    if drop_self and 'self' in defaults:
        del defaults['self']

    # ``None`` is accepted for every optional filter (the annotations say
    # Optional), so normalise before using them in comparisons.
    start = drop_before or 0
    skip_idxs = set(drop_idxs or ())
    skip_names = set(drop_names or ())

    return {
        name: default
        for i, (name, default) in enumerate(defaults.items())
        # keep only the half-open window  drop_before <= i < drop_after
        if start <= i
        and (drop_after is None or i < drop_after)
        and i not in skip_idxs
        and name not in skip_names
    }

class MyPipeline:
    """Fixed-recipe scanpy preprocessing pipeline (illustrative)."""

    def preprocess_data(self, min_genes: int = 200, min_cells: int = 3):
        """Filter, normalise, select variable genes, scale, PCA and build neighbours.

        Only ``min_genes`` and ``min_cells`` are exposed; every other setting
        is hard-coded below.

        NOTE(review): the data argument was missing from every call in the
        posted snippet; ``self.adata`` (an AnnData held by the instance) is
        assumed here -- confirm against the real class.
        NOTE(review): the ``sc.tl.pca`` call and the highly-variable-gene
        subsetting are reconstructed from fused fragments
        (``svd_solver='arpack'`` and ``=[:,]``) -- confirm.
        """
        sc.pp.filter_cells(self.adata, min_genes=min_genes)
        sc.pp.filter_genes(self.adata, min_cells=min_cells)
        sc.pp.normalize_total(self.adata, target_sum=1e4)
        sc.pp.highly_variable_genes(self.adata, min_mean=0.0125, max_mean=3, min_disp=0.5)
        self.adata = self.adata[:, self.adata.var.highly_variable]
        sc.pp.scale(self.adata, max_value=10)
        sc.tl.pca(self.adata, svd_solver='arpack')
        sc.pp.neighbors(self.adata, n_neighbors=10, n_pcs=40)

The focus here is on MyPipeline. Right now it isn't very flexible because very few keyword arguments are exposed, and in the event that some functions share the same keyword argument it is unclear which function it belongs to.

Initially all I wanted was a way to specify something like

class FilterCellKWArgs:

def pipeline(filter_cells_kwargs:FilterCellKWArgs, ...):

and then have intellisense show me (or anyone else) what args / keyword arguments these functions have available, what their defaults are and maybe even the docstring of the original function. That doesn't seem too tenable.

So now I think it would be useful to wrap functions like sc.pp.filter_cells, sc.pp.highly_variable_genes and sc.pp.scale as sklearn operators, e.g. BaseEstimator / TransformerMixin / etc.

As the arguments / keyword arguments would be set on construction and then an sklearn Pipeline can handle the rest. So I am looking for something like this

class FilterCells:

which should be functionally equivalent to

@dataclass
class FilterCells(BaseEstimator, TransformerMixin):
    """Hand-written sklearn-style wrapper around ``sc.pp.filter_cells``.

    Each keyword argument of the wrapped function is a field set at
    construction time; ``transform`` forwards them to the function.

    NOTE(review): the ``@dataclass`` decorator is assumed (decorator lines
    were missing from the posted snippet); without it the annotated
    attributes would not become constructor arguments.
    """
    # NOTE: these are the defaults for sc.pp.filter_cells
    # you can get them from inspect.signature(sc.pp.filter_cells)
    # data: ad.AnnData
    min_counts: Optional[int] = None
    min_genes:  Optional[int] = None
    max_counts: Optional[int] = None
    max_genes:  Optional[int] = None
    inplace: bool = True
    copy: bool = False

    def fit(self, X: ad.AnnData, y=None):
        """No-op fit; returns ``self`` so calls can be chained."""
        # NOTE: this is a dummy method
        # as we don't need to fit anything, just call the wrapped
        # function sc.pp.filter_cells
        return self

    def transform(self, X):
        """Apply ``sc.pp.filter_cells`` to ``X`` with the stored settings.

        Returns ``X`` when ``inplace`` is true, otherwise whatever the
        wrapped function returned.
        """
        Y = sc.pp.filter_cells(
            X, min_counts=self.min_counts, min_genes=self.min_genes,
            max_counts=self.max_counts, max_genes=self.max_genes,
            inplace=self.inplace, copy=self.copy
        )
        return X if self.inplace else Y

    def fit_transform(self, X, y=None):
        """Equivalent to ``fit(X, y)`` followed by ``transform(X)``."""
        return self.fit(X, y).transform(X)

And I have tried quite a few things (see below).

Specific Question

How can I write a decorator that takes in one of these functions (or any function really) and creates a dataclass where each argument / keyword-argument is a field and the docstring is copied over for better intellisense support? I don't want to see **kwargs:Any; I want to know what the variables are.


Current Attempt

from dataclasses import dataclass, field, fields, _FIELD
from typing import get_type_hints
from inspect import Signature, Parameter

def scop(fn):
    """Build a class decorator that turns a class into a dataclass wrapping ``fn``.

    Every parameter of ``fn`` that has a default becomes a keyword-only
    dataclass field on the decorated class, and ``fit`` / ``transform`` /
    ``fit_transform`` methods are attached. ``transform`` calls
    ``fn(X, **stored_values)`` and returns ``X``.

    The fields are injected *before* ``dataclass()`` runs. Fields attached
    afterwards are never registered, so they would be missing from
    ``fields()``, the generated ``__init__`` signature and the repr.
    Because the generated ``__init__`` is used, unknown keyword arguments
    now raise ``TypeError`` instead of being silently set as attributes.

    Parameters
    ----------
    fn : Callable
        Function to wrap; its first positional argument receives ``X``.

    Returns
    -------
    Callable
        A decorator to apply to a (possibly empty) class.
    """
    import copy  # local import: only needed for mutable defaults

    params = get_func_params(fn, drop_self=False)
    params = {k: v for k, v in params.items() if v is not inspect.Parameter.empty}

    def class_decorator(cls):
        # Only the annotations declared on ``cls`` itself; fields inherited
        # from a dataclass base are looked up separately.
        annotations = dict(inspect.get_annotations(cls))
        inherited = getattr(cls, '__dataclass_fields__', {})

        # Add fields from fn to cls (fields the class already declares win)
        for name, default in params.items():
            if name in annotations or name in inherited:
                continue
            annotations[name] = type(default)
            if type(default).__hash__ is None:
                # dataclasses reject unhashable (mutable) defaults; hand each
                # instance its own shallow copy instead
                spec = field(default_factory=lambda d=default: copy.copy(d), kw_only=True)
            else:
                spec = field(default=default, kw_only=True)
            setattr(cls, name, spec)
        cls.__annotations__ = annotations

        # Generates __init__ / __repr__ / __eq__ with the injected fields
        cls = dataclass(cls)

        # Add methods to cls
        def fit(self, X, y=None):
            return self

        def transform(self, X):
            # forward only what fn accepts, not unrelated fields on the class
            kwargs = {name: getattr(self, name) for name in params}
            fn(X, **kwargs)
            return X

        def fit_transform(self, X, y=None):
            return self.fit(X, y).transform(X)

        cls.fit = fit
        cls.transform = transform
        cls.fit_transform = fit_transform

        # Update docstrings
        cls.__doc__ = fn.__doc__
        cls.fit.__doc__ = fn.__doc__
        cls.transform.__doc__ = fn.__doc__
        cls.fit_transform.__doc__ = fn.__doc__

        return cls

    return class_decorator

which does basically work:

class FilterCells:

fc = FilterCells(min_cells=3)
# 3

But I still cannot see the docstring / args / keyword-args when I am typing FilterCells(...), and now fit and fit_transform just show Any rather than (X, y=None).

I also lose the BaseEstimator repr, so there is that...

Another Notable Attempt

def scop(fn):
    """Build a class decorator that subclasses the target into an sklearn-style wrapper of ``fn``.

    The returned decorator creates ``Wrapper(cls, BaseEstimator,
    TransformerMixin)``. Its constructor accepts arbitrary keyword
    arguments, merges them over ``fn``'s defaults and stores the result in
    ``self.params``; ``transform`` calls ``fn(X, **self.params)`` and
    returns ``X``.
    """
    params = get_func_params(fn, drop_self=False)
    # ``__annotations__`` must map names to types, so take them from fn
    # rather than reusing the default values held in ``params``
    fn_hints = getattr(fn, '__annotations__', {})

    def class_decorator(cls):
        class Wrapper(cls, BaseEstimator, TransformerMixin):
            # defaults of fn, excluding parameters that have none
            fn_params = {k: v for k, v in params.items() if v is not inspect.Parameter.empty}

            def __init__(self, **kwargs):
                self.params = {**self.fn_params, **kwargs}

            def fit(self, X, y=None):
                return self

            def transform(self, X):
                fn(X, **self.params)
                return X

            def fit_transform(self, X, y=None):
                return self.fit(X, y).transform(X)

        Wrapper.__name__ = cls.__name__
        Wrapper.__qualname__ = cls.__qualname__
        Wrapper.__doc__ = fn.__doc__
        Wrapper.__annotations__ = {
            **getattr(cls, '__annotations__', {}),
            **{name: fn_hints[name] for name in params if name in fn_hints},
        }

        return Wrapper

    return class_decorator

that can be used like

class FilterCells:

fc = FilterCells(min_genes=200)


Of note:

Original-ish Attempt

def scop(fn):
    """Earliest attempt: return an sklearn-style ``Wrapper`` class for ``fn`` directly.

    Unlike a decorator factory, this returns the ``Wrapper`` class itself
    rather than a function that takes a class.

    NOTE(review): per the accompanying text, applying ``@scop(...)`` to a
    class therefore calls ``Wrapper(<decorated class>)``, i.e. the class
    arrives in ``args`` of ``Wrapper.__init__`` at decoration time.
    NOTE(review): several lines below lost tokens when the snippet was
    posted (``fn(, X)``, ``return, y)...``) and do not parse as written;
    the missing pieces cannot be recovered from what is visible here.
    """
    # name -> default for every parameter of fn (including ``self`` if present)
    params = get_func_params(fn, drop_self=False)
    class Wrapper(BaseEstimator, TransformerMixin):
        def __init__(self, *args, **kwargs) -> None:
            # debugging output only; nothing is stored on the instance
            print('ARGS', args)
            print('KWARGS', kwargs)
            # NOTE(review): inspects sc.pp.filter_cells specifically, not ``fn``
            print('SIGNATURE', inspect.signature(sc.pp.filter_cells))
            print('PARAMS', params)

        def fit(self, X, y=None):
            return self

        def transform(self, X):
            # NOTE(review): first argument is missing in the posted snippet
            fn(, X)

        def fit_transform(self, X, y=None):
            # NOTE(review): garbled in the posted snippet; presumably chained fit(...).transform(X)
            return, y).transform(X)

    # copy fn's documentation and annotations onto the wrapper's methods
    Wrapper.__init__.__doc__ = fn.__doc__
    Wrapper.__init__.__annotations__ = fn.__annotations__
    # NOTE(review): the next comment line is garbled in the posted snippet
    # Wrapper = inspect.signature(sc.pp.filter_cells) = fn.__doc__ = fn.__annotations__

    Wrapper.transform.__doc__ = fn.__doc__
    Wrapper.transform.__annotations__ = fn.__annotations__

    Wrapper.fit_transform.__doc__ = fn.__doc__
    Wrapper.fit_transform.__annotations__ = fn.__annotations__
    # make the class itself present as fn
    Wrapper.__name__ = fn.__name__
    Wrapper.__doc__ = fn.__doc__
    Wrapper.__annotations__ = fn.__annotations__

    # methods = {
    #     '__init__': Wrapper.__init__,
    #     'fit':,
    #     'transform': Wrapper.transform,
    #     'fit_transform': Wrapper.fit_transform,
    # }
    # Wrapper = type(fn.__name__, (BaseEstimator, TransformerMixin), methods)
    return Wrapper

But notice that it prints ARGS (<class '__main__.FilterCells'>,) because this decorator gets called on the class itself, not on class initialization.

Upvotes: 0

Views: 205

Answers (0)

Related Questions