Reputation: 13
I would like to subclass torch.Tensor to make tensors that will always satisfy some user-defined property. For example, I might want my subclassed tensor to represent a categorical probability distribution, so I always want the last dim to sum to one.
I can define my subclass as:
class ValidatedArray(torch.Tensor):
    def __init__(self, array: torch.Tensor):
        self.validate_array()

    def __setitem__(self, key, value):
        super().__setitem__(key, value)
        self.validate_array()

    def validate_array(self):
        assert torch.allclose(self.sum(-1), torch.ones(1)), f'The last dim represents a categorical distribution. It must sum to one.'
This catches the most likely cases when the tensor might violate my property: at instantiation and while trying to set certain values.
Validation works at instantiation:
>>> array = torch.ones(3,4)
>>> va1 = ValidatedArray(array)
AssertionError: The last dim represents a categorical distribution. It must sum to one.
Validation works when trying to set an invalid value:
>>> array = torch.nn.functional.softmax(torch.ones(3,4), -1)
>>> va1 = ValidatedArray(array)
>>> va1[0] = 1
AssertionError: The last dim represents a categorical distribution. It must sum to one.
Validation "fails" in the following cases, where I can end up with a ValidatedArray that wouldn't pass validation:
>>> array = torch.nn.functional.softmax(torch.ones(3,4), -1)
>>> va1 = ValidatedArray(array)
>>> va2 = va1 + 2
>>> va2
ValidatedArray([[2.2500, 2.2500, 2.2500, 2.2500],
                [2.2500, 2.2500, 2.2500, 2.2500],
                [2.2500, 2.2500, 2.2500, 2.2500]])
>>> array = torch.nn.functional.softmax(torch.ones(3,4), -1)
>>> va1 = ValidatedArray(array)
>>> va1.fill_(2.)
>>> va1
ValidatedArray([[2., 2., 2., 2.],
                [2., 2., 2., 2.],
                [2., 2., 2., 2.]])
Is there a way I can ensure an instance of ValidatedArray will always pass my validation method? Is there some inherited method from torch.Tensor that runs any time the underlying data is changed? I imagine I could extend that method with my validation.
Note: I don't want to make a container object with setter and getter methods for a tensor attribute. I would like my subclass to otherwise be a drop-in replacement for a normal torch.Tensor, so I can use all the normal torch operations.
Upvotes: 1
Views: 513
Reputation: 41
Currently, the accepted answer by @jamied157 seems [obviously] incorrect to me. But somehow SO still hasn't figured out how to allow comments from idiots like me, so I'll rewrite the accepted answer the way I think it was intended:
import torch

class ValidatedArray:
    def __init__(self, array: torch.Tensor):
        # Validate the incoming tensor before storing it
        self._validate_array(array)
        self._array = array

    def _validate_array(self, array: torch.Tensor):
        assert torch.allclose(array.sum(-1), torch.ones(1)), 'The last dim represents a categorical distribution. It must sum to one.'

    @property
    def array(self):
        # Re-check on every access, in case the tensor was mutated in place
        self._validate_array(self._array)
        return self._array
Upvotes: 0
Reputation: 26004
My Python's a bit rusty, but you should be able to achieve your goal by implementing/overriding the required methods/operators/magic methods. For example, you could solve the va1 + 2 issue by simply overriding __add__:
def __add__(self, val):
    # Re-wrapping the result runs the validation in __init__ again
    return ValidatedArray(super().__add__(val))
You should be able to generalize this to other methods and operators as well, but not to all of them (e.g. fill_, as it's not implemented directly inside Tensor), at least not without giving real thought to efficiency.
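Rather than overriding each operator by hand, one way to generalize this is the `__torch_function__` hook, which PyTorch calls for (almost) every torch function and Tensor method invoked on a subclass instance. The sketch below is an assumption-laden illustration, not a complete solution: it assumes PyTorch 2.x because it uses the private context manager `torch._C.DisableTorchFunctionSubclass`, which may change between releases, and the rule for spotting in-place ops (a trailing underscore, plus the hypothetical `_INPLACE_DUNDERS` set) is my own heuristic. It also swaps in a different policy than the `__add__` override: out-of-place results come back as plain tensors, while in-place mutations of a `ValidatedArray` are re-validated and rolled back if they break the property.

```python
import torch

# Heuristic (assumption): in-place dunders that lack the trailing-underscore convention.
_INPLACE_DUNDERS = {"__setitem__", "__iadd__", "__isub__", "__imul__", "__itruediv__"}


class ValidatedArray(torch.Tensor):
    """Tensor whose last dim must always sum to one (a categorical distribution)."""

    def __new__(cls, data: torch.Tensor):
        # as_subclass reuses data's storage but gives the result this Python class.
        obj = data.as_subclass(cls)
        obj.validate_array()
        return obj

    def validate_array(self):
        # self.sum(-1) is routed through __torch_function__ below and returns a plain Tensor.
        assert torch.allclose(self.sum(-1), torch.ones(1)), (
            "The last dim represents a categorical distribution. It must sum to one."
        )

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if not all(issubclass(cls, t) for t in types):
            return NotImplemented
        name = getattr(func, "__name__", "")
        inplace = (name.endswith("_") and not name.startswith("__")) or name in _INPLACE_DUNDERS
        target = args[0] if args and isinstance(args[0], cls) else None

        if target is None or not inplace:
            # Out-of-place op: run it with subclass dispatch disabled, so the result
            # comes back as a plain torch.Tensor that makes no promise about sums.
            with torch._C.DisableTorchFunctionSubclass():
                return func(*args, **kwargs)

        # In-place op on a ValidatedArray: snapshot, mutate, re-check, roll back on failure.
        with torch._C.DisableTorchFunctionSubclass():
            backup = target.clone()
            result = func(*args, **kwargs)
        try:
            target.validate_array()
        except AssertionError:
            with torch._C.DisableTorchFunctionSubclass():
                target.copy_(backup)
            raise
        return result
```

With this policy, `va1 + 2` from the question comes back as a plain `torch.Tensor`, while `va1.fill_(2.)` and `va1[0] = 1` raise `AssertionError` and leave `va1` unchanged. It is not airtight: views such as `va1[0]` or `va1.view(-1)` share storage but are plain tensors, so mutating them bypasses the check, as do `.data` and `out=` arguments; and every in-place op pays for a `clone()`.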
I also agree with @Jamied157's remarks, and believe that, at the very least, building a stand-alone class that uses Tensor internally will make it much easier to interface with and implement what you want.
Upvotes: 0
Reputation: 61
I think I can see what you're trying to do, and it makes sense as something that could be useful: you get a nice guarantee that every time you interact with the tensor, it obeys some property. I'd probably advise against doing something like this, though, for these reasons:
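If I were in your position, I'd probably write a validate function like the one you have:
def validate_array(tensor: torch.Tensor) -> bool:
    # do the check here, e.g. the last dim must sum to one
    return torch.allclose(tensor.sum(-1), torch.ones(1))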
and then, every time this became important, I'd add an assert:
assert validate_array(tensor), "tensor should obey property"
I know this isn't a super exciting way to do this, but I've not come across a use case that needs something more heavyweight.
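As a sketch of where such an assert might go, here is a hypothetical consumer, `sample_categories`, that only needs the property at the point where it draws samples with `torch.multinomial`. The function name and the use of sampling are illustrative assumptions; `validate_array` is the check from the question, written to return a bool so it can be asserted.

```python
import torch


def validate_array(tensor: torch.Tensor) -> bool:
    # True when every slice along the last dim sums to one.
    return torch.allclose(tensor.sum(-1), torch.ones(1))


def sample_categories(probs: torch.Tensor) -> torch.Tensor:
    """Hypothetical consumer: draws one category index per row of `probs`."""
    # Check the property only where it actually matters.
    assert validate_array(probs), "tensor should obey property"
    return torch.multinomial(probs, num_samples=1)
```

The tensor stays an ordinary `torch.Tensor` everywhere else, so all torch operations remain available; the cost is that a violation is reported where it is used rather than where it was introduced.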
One nice thing you could do with a container class is:
class ValidatedArray:
    def __init__(self, array: torch.Tensor):
        # Store first, so _validate_array can see the tensor
        self._array = array
        self._validate_array()

    def _validate_array(self):
        assert torch.allclose(self._array.sum(-1), torch.ones(1)), 'The last dim represents a categorical distribution. It must sum to one.'

    @property
    def array(self):
        self._validate_array()
        return self._array
That way, every time the array is accessed, you perform the check. However, for the reasons above, I probably wouldn't do this. (Also note how I've removed the subclass ;))
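To make the trade-off of this design concrete, here is a small self-contained sketch of the container in use (the softmax input is just an illustrative assumption). Because `array` hands back the wrapped tensor itself, an in-place call such as `fill_` on it is not blocked; the property only detects the corruption on the next access.

```python
import torch


class ValidatedArray:
    """Container that re-checks the wrapped tensor on every access."""

    def __init__(self, array: torch.Tensor):
        self._array = array
        self._validate_array()

    def _validate_array(self):
        assert torch.allclose(self._array.sum(-1), torch.ones(1)), (
            "The last dim represents a categorical distribution. It must sum to one."
        )

    @property
    def array(self) -> torch.Tensor:
        self._validate_array()
        return self._array


probs = torch.softmax(torch.ones(3, 4), -1)
va = ValidatedArray(probs)
# The check passes here, then fill_ mutates the shared tensor; nothing stops this call.
va.array.fill_(2.)
# The next access through the property raises AssertionError:
# va.array
```

So the container gives a "checked on read" guarantee rather than "always valid", which is usually enough to localize a bug but does not prevent it.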
Upvotes: 1