Reputation: 5198
How can I write a decorator that takes in a function
fn
and creates a dataclass where each argument / keyword-argument is a field and the docstring is copied over for better intellisense support. I don't want to see**kwargs:Any
I want to know what the variables are.
In the MWE below there are the following things:
get_func_params
: utility function for getting default parameters from a functionSomeModuleType
: a dummy type to check if intelli-sense is showing what kind of types hints are being copied-overmwe_func
: a dummy function that has a docstring and some type annotations. The decorator func_to_class
should copy over the docstring and function annotations for the new class constructor. Ideally all the arguments are fields in of the new dataclass.Class2Subclass
: a dummy class for the decorated class to subclass.mwe_class
: the dummy class that we are decoratingfunc_to_class
: the decorator functionfrom dataclasses import dataclass, field, fields, _FIELD
from typing import get_type_hints, Optional, NamedTuple, Callable, Dict, Any, List
import inspect
from inspect import Signature, Parameter
from functools import wraps
import numpy as np
def get_func_params(
fn: Callable,
drop_self: Optional[bool] = True,
drop_before: Optional[int] = 0,
drop_idxs: Optional[List[int]] = list(),
drop_names: Optional[List[str]] = list(),
drop_after: Optional[int] = None,
) -> Dict[str, Any]:
params = inspect.signature(fn).parameters
params = {k: v.default for k, v in params.items()}
if drop_self and 'self' in params:
params.pop('self')
params = {
n: p for i, (n, p) in enumerate(params.items())
if (
# is before <= i < after
(drop_before <= i or (drop_after is not None and i < drop_after))
# i not in drop_idxs and n not in drop_names
and (i not in drop_idxs and n not in drop_names)
)
}
return params
class SomeModuleType(NamedTuple):
a: int
b: str
def mwe_func(data:np.ndarray, a_bool:bool=False, a_thing:Optional[SomeModuleType]=None) -> np.ndarray:
'''
Parameters
----------
data : np.ndarray
A numpy array of data
a_bool : bool, default=False
A boolean
a_thing : Optinoal[SomeModuleType]
A thing
'''
# ...
return data
class Class2Subclass:
def expected_method(self, a:int=0):
pass
pass
@func_to_class(mwe_func)
class mwe_class:
pass
No intelli-sense (that is co-pilot suggesting things)
I am working with scanpy and sklearn. Currently I have the following:
import scanpy as sp, anndata as ad, numpy as np, pandas as pd
from sklearn.base import BaseEstimator, TransformerMixin
from dataclasses import dataclass
import inspect
from functools import wraps
from typing import List, Any, Optional, Callable, Union, Tuple, Iterable, Set, TypeAlias, Type, Dict
def filter_kwargs_for_func(fn: Callable, **kwargs:Optional[dict]):
params = inspect.signature(fn).parameters
return {k:v for k,v in kwargs.items() if k in params}
def filter_kwargs_for_class(cls: Callable, **kwargs:Optional[dict]):
params = inspect.signature(cls.__init__).parameters
return {k:v for k,v in kwargs.items() if k in params}
def wrangle_kwargs_for_func(
fn: Callable,
defaults: Optional[dict]=None,
**kwargs:Optional[dict]
) -> dict:
# copy defaults
params = (defaults or {}).copy()
# update with kwargs of our function
params.update(kwargs or {})
# filter for only the params that other function accepts
params = filter_kwargs_for_func(fn, **params)
return params
def wrangle_kwargs_for_class(
cls: Callable,
defaults: Optional[dict]=None,
**kwargs:Optional[dict]
) -> dict:
# copy defaults
params = (defaults or {}).copy()
# update with kwargs of our class
params.update(kwargs or {})
# filter for only the params that other class accepts
params = filter_kwargs_for_class(cls, **params)
return params
def get_func_params(
fn: Callable,
drop_self: Optional[bool] = True,
drop_before: Optional[int] = 0,
drop_idxs: Optional[List[int]] = list(),
drop_names: Optional[List[str]] = list(),
drop_after: Optional[int] = None,
) -> Dict[str, Any]:
params = inspect.signature(fn).parameters
params = {k: v.default for k, v in params.items()}
if drop_self and 'self' in params:
params.pop('self')
params = {
n: p for i, (n, p) in enumerate(params.items())
if (
# is before <= i < after
(drop_before <= i or (drop_after is not None and i < drop_after))
# i not in drop_idxs and n not in drop_names
and (i not in drop_idxs and n not in drop_names)
)
}
return params
@dataclass
class MyPipeline:
...
def preprocess_data(self, min_genes: int = 200, min_cells: int = 3):
sc.pp.filter_cells(self.data, min_genes=min_genes)
sc.pp.filter_genes(self.data, min_cells=min_cells)
self.data.raw = self.data
sc.pp.normalize_total(self.data, target_sum=1e4)
sc.pp.log1p(self.data)
sc.pp.highly_variable_genes(self.data, min_mean=0.0125, max_mean=3, min_disp=0.5)
self.data = self.data[:, self.data.var.highly_variable]
sc.pp.scale(self.data, max_value=10)
sc.tl.pca(self.data, svd_solver='arpack')
sc.pp.neighbors(self.data, n_neighbors=10, n_pcs=40)
sc.tl.umap(self.data)
Where the focus here is on MyPipeline
. Right now it isn't very flexible because very few keyword arguments are exposed and in the event some functions share the same keyword argument which function it belongs to.
Initially all I wanted was a way to to specify something like
@fn_kwargs(sc.pp.filter_cells)
class FilterCellKWArgs:
pass
...
def pipeline(filter_cells_kwargs:FilterCellKWArgs, ...):
...
...
and then have intellisense show me (or anyone else) what args / keyword arguments these functions have available, what their defaults are and maybe even the docstring of the original function. That doesn't seem too tenable.
So now I am think it be useful to wrap functions like sc.pp.filter_cells
, sc.pp.highly_variable_genes
and sc.pp.scale
as sklearn
operators e.g. BaseEstimator
s / TransformerMixin
/ etc.
As the arguments / keyword arguments would be set on construction and then an sklearn Pipeline
can handle the rest. So I am looking for something like this
@scop(sc.pp.filter_cells)
@dataclass
class FilterCells:
pass
which should be functionally equivalent to
@dataclass
class FilterCells(BaseEstimator, TransformerMixin):
# NOTE: these are the defaults for sc.pp.filter_cells
# you can get them from inspect.signature(sc.pp.filter_cells)
# data: ad.Anndata
min_counts: Optional[int] = None
min_genes: Optional[int] = None
max_counts: Optional[int] = None
max_genes: Optional[int] = None
inplace: bool = True
copy: bool = False
def fit(self, X: ad.AnnData, y=None):
# NOTE: this is a dummy method
# as we don't need to fit anything, just call the wrapped
# function sc.pp.filtered_cells
pass
def transform(self, X):
Y = sc.pp.filter_cells(
X, min_counts=self.min_counts, min_genes=self.min_genes,
max_counts=self.max_counts, max_genes=self.max_genes,
inplace=self.inplace, copy=self.copy
)
return X if self.inplace else Y
def fit_transform(self, X, y=None):
return self.fit(X, y).transform(X)
And I have tried quite a few things (see below).
How can I write a decorator that in one of these functions (or any function really) and creates a dataclass where each argument / keyword-argument is a field and the docstring is copied over for better intellisense support. I don't want to see **kwargs:Any
I want to know what the variables are.
from dataclasses import dataclass, field, fields, _FIELD
from typing import get_type_hints
from inspect import Signature, Parameter
def scop(fn):
params = get_func_params(fn, drop_self=False)
params = {k: v for k, v in params.items() if v is not inspect.Parameter.empty}
def class_decorator(cls):
cls = dataclass(cls) # Ensure cls is a dataclass
# Add fields from fn to cls
for name, default in params.items():
if name not in get_type_hints(cls):
field_obj = field(default=default)
setattr(cls, name, field_obj)
cls.__annotations__[name] = type(default)
# Update __init__ method to include new fields
def __init__(self, **kwargs):
for name, value in kwargs.items():
setattr(self, name, value)
cls.__init__ = __init__
# Update __init__ method signature
sig = inspect.signature(fn)
parameters = [
Parameter(name, Parameter.KEYWORD_ONLY, default=default)
for name, default in params.items()
]
cls.__init__.__signature__ = sig.replace(parameters=parameters)
# Add methods to cls
def fit(self, X, y=None):
return self
def transform(self, X):
kwargs = {f.name: getattr(self, f.name) for f in fields(self)}
fn(X, **kwargs)
return X
def fit_transform(self, X, y=None):
return self.fit(X, y).transform(X)
cls.fit = fit
cls.transform = transform
cls.fit_transform = fit_transform
# Update docstrings
cls.__doc__ = fn.__doc__
cls.fit.__doc__ = fn.__doc__
cls.transform.__doc__ = fn.__doc__
cls.fit_transform.__doc__ = fn.__doc__
return cls
return class_decorator
which does basically work:
@dataclass
@scop(sc.pp.filter_cells)
class FilterCells:
pass
fc = FilterCells(min_cells=3)
print(fc.min_cells)
# 3
But. I still can not see the docstring / args / keyword-args when I am typing FilterCells(...)
and now fit
, fit_transform
just show Any
rather than (X, y=None)
I also loose the BaseEstimator
repr
so there is that...
def scop(fn):
params = get_func_params(fn, drop_self=False)
def class_decorator(cls):
class Wrapper(cls, BaseEstimator, TransformerMixin):
fn_params = {k: v for k, v in params.items() if v is not inspect.Parameter.empty}
def __init__(self, **kwargs):
self.params = {**self.fn_params, **kwargs}
super().__init__()
def fit(self, X, y=None):
return self
def transform(self, X):
fn(X, **self.params)
return X
def fit_transform(self, X, y=None):
return self.fit(X, y).transform(X)
Wrapper.__name__ = cls.__name__
Wrapper.__doc__ = fn.__doc__
Wrapper.__annotations__ = {**cls.__annotations__, **params}
return Wrapper
return class_decorator
that can be used like
@scop(sc.pp.filter_cells)
@dataclass
class FilterCells:
pass
fc = FilterCells(min_genes=200)
fc
Of note:
BaseEstimator
repr
... (stop..class_decorator..Wrapper()fc.min_genes
results in AttributeError
, so we lose access to dataclass fields.def scop(fn):
params = get_func_params(fn, drop_self=False)
class Wrapper(BaseEstimator, TransformerMixin):
@wraps(fn)
def __init__(self, *args, **kwargs) -> None:
super().__init__()
print('ARGS', args)
print('KWARGS', kwargs)
print('SIGNATURE', inspect.signature(sc.pp.filter_cells))
print('PARAMS', params)
@wraps(fn)
def fit(self, X, y=None):
return self
@wraps(fn)
def transform(self, X):
fn(self.data, X)
return self.data
@wraps(fn)
def fit_transform(self, X, y=None):
return self.fit(X, y).transform(X)
Wrapper.__init__.__doc__ = fn.__doc__
Wrapper.__init__.__annotations__ = fn.__annotations__
# Wrapper = inspect.signature(sc.pp.filter_cells)
Wrapper.fit.__doc__ = fn.__doc__
Wrapper.fit.__annotations__ = fn.__annotations__
Wrapper.transform.__doc__ = fn.__doc__
Wrapper.transform.__annotations__ = fn.__annotations__
Wrapper.fit_transform.__doc__ = fn.__doc__
Wrapper.fit_transform.__annotations__ = fn.__annotations__
Wrapper.__name__ = fn.__name__
Wrapper.__doc__ = fn.__doc__
Wrapper.__annotations__ = fn.__annotations__
# methods = {
# '__init__': Wrapper.__init__,
# 'fit': Wrapper.fit,
# 'transform': Wrapper.transform,
# 'fit_transform': Wrapper.fit_transform,
# }
# Wrapper = type(fn.__name__, (BaseEstimator, TransformerMixin), methods)
return Wrapper
but notice that it prints ARGS (<class '__main__.FilterCells'>,)
as this decorator gets called over the class not on class initialization.
Upvotes: 0
Views: 205