Reputation: 173
I am implementing a series of classes in Equinox to enable taking derivatives with respect to the class parameters. Most of the time, the user will be instantiating class `A` and using the `fn` function to generate some data, the details of which are unimportant. However, in cases where we are interested in gradients, it is beneficial to represent `param_c` in terms of a sigmoid function to ensure that it remains clamped in the range (0, 1). I don't want the user to notice a difference in how the class behaves if they do this. As such, I implement another class, `A_sigmoid`, that has `param_c` as a property, and use `A_abstract` to ensure that both classes inherit the `fn` method, which will use `param_c` in its logic. While I could simply have the user instantiate an `A_sigmoid` object with a `_param_c_sigmoid` instead of `param_c`, I don't want to force the user to make this distinction. Rather, I want them to pass in the same `kwargs` dictionary no matter the class and have the conversion happen behind the scenes. I also want it to be possible, when making a new `A`, to simply pass an optional flag that directs the program to use the sigmoid version of the code. To do so, I implemented the following MWE:
import equinox as eqx
import jax
import jax.numpy as jnp

class A_abstract(eqx.Module):
    param_a: jax.Array
    param_b: jax.Array
    param_c: eqx.AbstractVar[jax.Array]

    def fn(self, *args, **kwargs):
        pass

class A_sigmoid(A_abstract):
    _param_c_sigmoid: jax.Array

    @property
    def param_c(self):
        # Sigmoid keeps param_c in (0, 1)
        return 1 / (1 + jnp.exp(-self._param_c_sigmoid))

class A(A_abstract):
    param_c: jax.Array

    def __new__(cls, **kwargs):
        sigmoid_flag = kwargs.pop('use_sigmoid_c', False)
        if sigmoid_flag == True:
            param_c = kwargs.pop('param_c')
            # Logit, i.e. the inverse of the sigmoid above
            _param_c_sigmoid = jnp.log(param_c / (1 - param_c))
            kwargs['_param_c_sigmoid'] = _param_c_sigmoid
            instance = A_sigmoid.__new__(A_sigmoid)
            instance.__init__(**kwargs)
            print(type(instance))
            return instance
        else:
            return super(A, cls).__new__(cls)

classA = A(param_a=1., param_b=2., param_c=0.5, use_sigmoid_c=True)
print(type(classA))
The code correctly says that `instance` has type `A_sigmoid` when `print` is called in the `__new__` method. However, when I print `type(classA)`, it is of type `A` and has no attribute `param_c`, though it does have a value for `_param_c_sigmoid`. Why is this the case? Am I missing something in my use of `__new__` that is causing this error? While I know that in principle a factory would be the best way to do this, there are other classes, `B`, `C`, etc., that don't need a sigmoid implementation and that I would like to behave exactly the same way as `A` so that they can be easily swapped. Thus, I don't want some custom method to instantiate `A` that would be different from calling the default constructor on the other classes.
I am running this in a Jupyter notebook with the following package versions:
Python : 3.12.4
IPython : 8.30.0
ipykernel : 6.29.5
jupyter_client : 8.6.3
jupyter_core : 5.7.2
Upvotes: 2
Views: 54
Reputation: 86443
If you were using a normal class, what you did is perfectly reasonable:
class A_abstract:
    pass

class A_sigmoid(A_abstract):
    pass

class A(A_abstract):
    def __new__(cls, flag, **kwds):
        if flag:
            instance = A_sigmoid.__new__(A_sigmoid)
        else:
            instance = super().__new__(cls)
        instance.__init__(**kwds)
        return instance

print(type(A(True)))  # <class '__main__.A_sigmoid'>
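The plain-class version depends on a rule of Python's default construction: `type.__call__` first calls `cls.__new__`, and it then calls `__init__` automatically only if the returned object is an instance of `cls`. An object of a different class is handed back untouched, so it has to be initialised by hand. Here is a minimal sketch of that rule; the names `Base`, `Other`, and `Router` are hypothetical and not part of the question:

```python
class Base:
    def __init__(self, flag=False):
        # Count how many times __init__ has run on this object.
        self.init_calls = getattr(self, "init_calls", 0) + 1


class Other(Base):
    pass


class Router(Base):
    def __new__(cls, flag=False):
        if flag:
            # Not an instance of Router, so Python skips __init__ for it.
            return Other.__new__(Other)
        return super().__new__(cls)


routed = Router(True)
direct = Router(False)

print(type(routed).__name__, hasattr(routed, "init_calls"))  # Other False
print(type(direct).__name__, direct.init_calls)              # Router 1
```

The redirected object keeps its own type (`Other`) and never sees `__init__`, while the ordinary path gets exactly one automatic `__init__` call.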
However, `eqx.Module` includes a bunch of metaclass logic that overrides how `__new__` works, and this seems to collide with the `__new__` overrides that you're making. Notice that here the only difference is that `A_abstract` inherits from `eqx.Module`, and the result is `A` rather than `A_sigmoid`:
import equinox as eqx

class A_abstract(eqx.Module):
    pass

class A_sigmoid(A_abstract):
    pass

class A(A_abstract):
    def __new__(cls, flag, **kwds):
        if flag:
            instance = A_sigmoid.__new__(A_sigmoid)
        else:
            instance = super().__new__(cls)
        instance.__init__(**kwds)
        return instance

print(type(A(True)))  # <class '__main__.A'>
I dug in for a few minutes to try to find the exact cause of this change, but wasn't able to pin it down.
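To show that a metaclass is at least capable of producing this symptom, here is a plain-Python sketch. It is not taken from Equinox's source and I am not claiming this is what Equinox does. `RelabelMeta` is a hypothetical metaclass whose `__call__` runs after `__new__` has returned and resets the object's `__class__` to the class that was called:

```python
class RelabelMeta(type):
    """Hypothetical metaclass (an assumption for illustration, not Equinox's)."""

    def __call__(cls, *args, **kwargs):
        # Normal construction: runs cls.__new__ and possibly __init__.
        obj = super().__call__(*args, **kwargs)
        # Post-processing step that forces the result to be a `cls`.
        object.__setattr__(obj, "__class__", cls)
        return obj


class Base(metaclass=RelabelMeta):
    pass


class Sigmoid(Base):
    @property
    def param_c(self):
        return 0.5


class Plain(Base):
    def __new__(cls, flag=False):
        if flag:
            # Redirect construction to a sibling class.
            return Sigmoid.__new__(Sigmoid)
        return super().__new__(cls)


obj = Plain(True)
print(type(obj).__name__)       # Plain
print(hasattr(obj, "param_c"))  # False
```

Inside `__new__` the object really is a `Sigmoid`, but by the time the caller receives it the metaclass has relabelled it, so the `param_c` property defined on `Sigmoid` is no longer reachable. That matches the symptom in the question, but whether Equinox does anything similar would need to be checked against its source.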
If you're trying to do metaprogramming during class construction, you'll have to modify it to work within the construction-time metaprogramming that Equinox is already doing.
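One way to avoid fighting `__new__` is to keep a single class and make the choice inside `__init__`: store the unconstrained value plus a flag, and expose `param_c` as a property. The sketch below is a plain-Python stand-in written with `math` so that it runs without Equinox. The attribute name `_param_c_raw` is my own invention, and porting this to an `eqx.Module` (field declarations, how the flag is treated under differentiation) is an assumption you would need to verify:

```python
import math


class A:
    def __init__(self, param_a, param_b, param_c, use_sigmoid_c=False):
        self.param_a = param_a
        self.param_b = param_b
        self.use_sigmoid_c = use_sigmoid_c
        if use_sigmoid_c:
            # Store the logit so that the sigmoid recovers param_c in (0, 1).
            self._param_c_raw = math.log(param_c / (1.0 - param_c))
        else:
            self._param_c_raw = param_c

    @property
    def param_c(self):
        if self.use_sigmoid_c:
            return 1.0 / (1.0 + math.exp(-self._param_c_raw))
        return self._param_c_raw


plain = A(param_a=1.0, param_b=2.0, param_c=0.25)
clamped = A(param_a=1.0, param_b=2.0, param_c=0.25, use_sigmoid_c=True)
print(type(clamped).__name__)  # A
```

Both objects are built with the same keyword arguments through the default constructor, both are of type `A`, and both report the same `param_c` up to floating-point rounding; only the stored representation differs.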
Upvotes: 0