JamesVR

Return a different class based on an optional flag in the arguments without factory

I am implementing a series of classes in Equinox so that I can take derivatives with respect to the class parameters. Most of the time, the user will instantiate class A and call its fn method to generate some data (the details are unimportant here). However, when we are interested in gradients, it is beneficial to represent param_c through a sigmoid function so that it stays within the range (0, 1). I don't want the user to notice any difference in how the class behaves when this is done. I therefore implemented a second class, A_sigmoid, that exposes param_c as a property, and used A_abstract so that both classes inherit the fn method, which uses param_c in its logic. I could simply have the user instantiate an A_sigmoid object with _param_c_sigmoid instead of param_c, but I don't want to force the user to make this distinction. Instead, I want them to pass in the same kwargs dictionary regardless of the class and have the conversion happen behind the scenes. I also wanted the constructor of A to accept an optional flag that tells the program to use the sigmoid version of the code. To do so, I implemented the following MWE:

import equinox as eqx
import jax
import jax.numpy as jnp

class A_abstract(eqx.Module):
    param_a: jax.Array
    param_b: jax.Array
    param_c: eqx.AbstractVar[jax.Array]

    def fn(self, *args, **kwargs):
        pass  # uses self.param_c; details omitted

class A_sigmoid(A_abstract):
    _param_c_sigmoid: jax.Array

    @property
    def param_c(self):
        return 1 / (1 + jnp.exp(-self._param_c_sigmoid))

class A(A_abstract):
    param_c: jax.Array

    def __new__(cls, **kwargs):
        sigmoid_flag = kwargs.pop('use_sigmoid_c', False)
        if sigmoid_flag:
            # Store the logit of param_c instead of param_c itself
            param_c = kwargs.pop('param_c')
            _param_c_sigmoid = jnp.log(param_c / (1 - param_c))
            kwargs['_param_c_sigmoid'] = _param_c_sigmoid
            # Build and initialize an A_sigmoid by hand instead of an A
            instance = A_sigmoid.__new__(A_sigmoid)
            instance.__init__(**kwargs)
            print(type(instance))
            return instance
        else:
            return super(A, cls).__new__(cls)

classA = A(param_a=1., param_b=2., param_c=0.5, use_sigmoid_c=True)
print(type(classA))

Inside __new__, the print call correctly reports that instance has type A_sigmoid. However, when I print type(classA), it is of type A and has no attribute param_c, although it does have a value for _param_c_sigmoid. Why is this the case? Am I missing something in my use of __new__ that causes this? I know that, in principle, a factory would be the best way to do this, but there are other classes (B, C, etc.) that don't need a sigmoid implementation, and I would like A to behave exactly like them so that they can be swapped easily. Thus, I don't want a custom method for instantiating A that differs from calling the default constructor on the other classes.

I am running this on a Jupyter notebook with the following package versions:

Python           : 3.12.4
IPython          : 8.30.0
ipykernel        : 6.29.5
jupyter_client   : 8.6.3
jupyter_core     : 5.7.2
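For context, the reparameterization that A_sigmoid relies on can be checked on its own: the stored value _param_c_sigmoid is the logit of param_c, and the sigmoid maps it back into (0, 1). A minimal plain-Python sketch, using math instead of jax.numpy and hypothetical helper names logit and sigmoid (in JAX, jax.nn.sigmoid provides the same forward mapping):

```python
import math


def logit(p: float) -> float:
    """Inverse of the sigmoid: maps the open interval (0, 1) onto the real line."""
    return math.log(p / (1 - p))


def sigmoid(x: float) -> float:
    """Maps any real number into the open interval (0, 1)."""
    return 1 / (1 + math.exp(-x))


raw = logit(0.5)      # what A_sigmoid would store as _param_c_sigmoid
print(raw)            # 0.0
print(sigmoid(raw))   # 0.5, the value exposed as param_c
for x in (-10.0, 0.0, 10.0):
    print(sigmoid(x))  # stays inside (0, 1) for these inputs
```

Because the raw value may be any real number, gradient updates on it cannot push param_c outside (0, 1), although in floating point the sigmoid saturates for very large inputs.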

Answers (1)

jakevdp

If you were using a normal class, what you did is perfectly reasonable:

class A_abstract:
  pass

class A_sigmoid(A_abstract):
  pass

class A(A_abstract):
  def __new__(cls, flag, **kwds):
    if flag:
      instance = A_sigmoid.__new__(A_sigmoid)
    else:
      instance = super().__new__(cls)
    instance.__init__(**kwds)
    return instance

print(type(A(True))) # <class '__main__.A_sigmoid'>
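The reason __init__ has to be called by hand in the A_sigmoid branch is Python's default construction protocol: type.__call__ runs __init__ automatically only when __new__ returns an instance of the class being called. A small standalone check of that rule (Chooser and Other are hypothetical names):

```python
calls = []  # records which __init__ methods actually run


class Other:
    def __init__(self, **kwds):
        calls.append("Other.__init__")
        self.kwds = kwds


class Chooser:
    def __new__(cls, flag, **kwds):
        if flag:
            # Not a Chooser, so type.__call__ will skip __init__ for it;
            # initialize it manually instead.
            instance = Other.__new__(Other)
            instance.__init__(**kwds)
            return instance
        return super().__new__(cls)

    def __init__(self, flag, **kwds):
        # Only reached automatically when __new__ returned a Chooser.
        calls.append("Chooser.__init__")
        self.kwds = kwds


print(type(Chooser(True, x=1)).__name__)   # Other
print(calls)                               # ['Other.__init__']
print(type(Chooser(False, x=1)).__name__)  # Chooser
print(calls)                               # ['Other.__init__', 'Chooser.__init__']
```

In the example above, the flag=False branch also calls instance.__init__ by hand, so __init__ runs twice there; that is harmless only because object.__init__ does nothing.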

However, eqx.Module uses a metaclass that customizes how instances are constructed, and that logic appears to collide with your __new__ override. In the example below, the only difference is that A_abstract inherits from eqx.Module, yet the result is A rather than A_sigmoid:

import equinox as eqx

class A_abstract(eqx.Module):
  pass

class A_sigmoid(A_abstract):
  pass

class A(A_abstract):
  def __new__(cls, flag, **kwds):
    if flag:
      instance = A_sigmoid.__new__(A_sigmoid)
    else:
      instance = super().__new__(cls)
    instance.__init__(**kwds)
    return instance

print(type(A(True))) # <class '__main__.A'>

I dug in for a few minutes to try to find the exact cause of this change, but wasn't able to pin it down.
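One mechanism that would produce exactly this symptom is a metaclass whose __call__ reassigns the new object's __class__ to the class that was called, after __new__ has returned. In recent Equinox versions the module metaclass's __call__ appears to end with such a reassignment, but treat that as an assumption and check the source of your installed version. The toy metaclass below (ResettingMeta, Base, Sigmoid, Plain and raw_c are all hypothetical names) isolates just that step in plain Python:

```python
class ResettingMeta(type):
    """Toy metaclass: construct normally, then force __class__ back to cls."""

    def __call__(cls, *args, **kwargs):
        instance = super().__call__(*args, **kwargs)
        # Overwrite whatever class __new__ chose with the class being called.
        object.__setattr__(instance, "__class__", cls)
        return instance


class Base(metaclass=ResettingMeta):
    pass


class Sigmoid(Base):
    @property
    def param_c(self):
        return 0.5


class Plain(Base):
    def __new__(cls, flag):
        if flag:
            instance = Sigmoid.__new__(Sigmoid)
            instance.raw_c = 0.0  # stands in for _param_c_sigmoid
            return instance
        return super().__new__(cls)


obj = Plain(True)
print(type(obj).__name__)       # Plain, not Sigmoid
print(hasattr(obj, "param_c"))  # False: Sigmoid's property is no longer found
print(obj.raw_c)                # 0.0: attributes stored on the instance survive
```

This matches both halves of what the question reports: the object ends up typed as A, the param_c property defined on A_sigmoid disappears, and _param_c_sigmoid, which lives on the instance itself, is still there.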

If you're trying to do metaprogramming during class construction, you'll have to modify it to work within the construction-time metaprogramming that equinox is already doing.
