Doug Tischer

Reputation: 13

How do I force a PyTorch tensor to always satisfy some (possibly arbitrary) property?

I would like to subclass torch.Tensor to make tensors that will always satisfy some user-defined property. For example, I might want my subclassed tensor to represent a categorical probability distribution, so I always want the last dim to sum to one.

I can define my subclass as:

import torch

class ValidatedArray(torch.Tensor):
    def __init__(self, array: torch.Tensor):
        # Validate once at construction.
        self.validate_array()

    def __setitem__(self, key, value):
        # Validate again after any item assignment.
        super().__setitem__(key, value)
        self.validate_array()

    def validate_array(self):
        assert torch.allclose(self.sum(-1), torch.ones(1)), \
            'The last dim represents a categorical distribution. It must sum to one.'

This catches the most likely cases when the tensor might violate my property: at instantiation and while trying to set certain values.

Validation works at instantiation:

>>> array = torch.ones(3,4)
>>> va1 = ValidatedArray(array)
AssertionError: The last dim represents a categorical distribution. It must sum to one.

Validation works when trying to set an invalid value:

>>> array = torch.nn.functional.softmax(torch.ones(3,4), -1)
>>> va1 = ValidatedArray(array)
>>> va1[0] = 1
AssertionError: The last dim represents a categorical distribution. It must sum to one.

Validation "fails" in these cases. I can make a ValidatedArray that wouldn't pass validation.

>>> array = torch.nn.functional.softmax(torch.ones(3,4), -1)
>>> va1 = ValidatedArray(array)
>>> va2 = va1 + 2
>>> va2
ValidatedArray([[2.2500, 2.2500, 2.2500, 2.2500],
                [2.2500, 2.2500, 2.2500, 2.2500],
                [2.2500, 2.2500, 2.2500, 2.2500]])
>>> array = torch.nn.functional.softmax(torch.ones(3,4), -1)
>>> va1 = ValidatedArray(array)
>>> va1.fill_(2.)
>>> va1
ValidatedArray([[2., 2., 2., 2.],
                [2., 2., 2., 2.],
                [2., 2., 2., 2.]])

Is there a way I can ensure an instance of ValidatedArray will always pass my validation method? Is there some inherited method from torch.Tensor that runs anytime the underlying data is changed? I imagine I could extend that method with my validation.

Note: I don't want to make a container object with setter and getter methods around a tensor attribute. I would like my subclass to otherwise be a drop-in replacement for a normal torch.Tensor so I can use all the normal torch operations.

Upvotes: 1

Views: 513

Answers (3)

Hello Worlds

Reputation: 41

The accepted answer given by @jamied157 looks incorrect to me as written, but SO still won't let low-reputation users like me comment, so I'll rewrite it here the way I think it was intended:

class ValidatedArray:
    def __init__(self, array: torch.Tensor):
        # Validate before storing the tensor.
        self._validate_array(array)
        self._array = array

    def _validate_array(self, array: torch.Tensor):
        assert torch.allclose(array.sum(-1), torch.ones(1)), \
            'The last dim represents a categorical distribution. It must sum to one.'

    @property
    def array(self):
        # Re-validate on every access.
        self._validate_array(self._array)
        return self._array
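
A quick usage sketch of the class above (mine, not part of the original answer), assuming torch is imported: construction and every .array access re-run the check, but mutating the wrapped tensor directly is only caught on the next access.

>>> probs = torch.nn.functional.softmax(torch.ones(3, 4), -1)
>>> va = ValidatedArray(probs)
>>> va.array[0] = 1.    # mutates the wrapped tensor; no check fires here
>>> va.array            # the next access re-validates and fails
AssertionError: The last dim represents a categorical distribution. It must sum to one.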

Upvotes: 0

Hossein

Reputation: 26004

My Python is a bit rusty, but you should be able to achieve your goal by implementing/overriding the relevant methods, operators, and magic methods. For example, you can solve the va1 + 2 issue by simply overriding __add__:

def __add__(self, val):
    # Wrap the result so the constructor re-runs validation.
    return ValidatedArray(super().__add__(val))
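
With that override added to the subclass from the question (this transcript is mine, not the answerer's), the earlier silent failure becomes an error:

>>> array = torch.nn.functional.softmax(torch.ones(3, 4), -1)
>>> va1 = ValidatedArray(array)
>>> va2 = va1 + 2    # result rows no longer sum to one
AssertionError: The last dim represents a categorical distribution. It must sum to one.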

You should be able to generalize this to other methods and operators as well, though not to all of them (in-place ops like fill_ need their own handling), at least not without putting some thought into the efficiency cost.

I also agree with @jamied157's remarks: at the very least, a stand-alone class that uses a Tensor internally gives you a much easier time interfacing with and implementing what you want.

Upvotes: 0

jamied157

Reputation: 61

I think I can see what you're trying to do, and it makes sense as something that could be useful: you get a nice guarantee that every time you interact with the tensor it obeys some property. I'd still probably advise against doing something like this, though, for these reasons:

  • You should be careful subclassing classes from other packages: it's difficult to know which methods the maintainers have added to their class, some of them might be important, and you don't want to accidentally override one of them. Unless a package maintainer explicitly says you can subclass, I would avoid it.
  • Doing what you're asking would require either working out which operations preserve the property you're interested in and somehow restricting the class to only those operations, or making the tensor perform the check after every operation. The latter could be costly, especially if you're in the middle of a big neural-net computation and need to shift data around.

If I were in your position, I'd probably write a validate function like the one you have:

def validate_array(tensor: torch.Tensor) -> bool:
    # do check here and return True/False
    ...

and then, every time the property matters, add an assert:

assert validate_array(tensor), "tensor should obey property"
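
For the categorical property from the question, that check might look like this (my sketch, reusing the allclose test from the question's class):

import torch

def validate_array(tensor: torch.Tensor) -> bool:
    # True when every slice along the last dim sums to one.
    return torch.allclose(tensor.sum(-1), torch.ones(1))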

I know this isn't a super exciting way to do it, but I've not come across a use case that needs anything more heavyweight.

One nice thing you could do with the container class is

class ValidatedArray:
    def __init__(self, array: torch.Tensor):
        self.validate_array()
        self._array = array
        
        
    def _validate_array(self):
        assert torch.allclose(self.sum(-1), torch.ones(1)), f'The last dim represents a categorical distribution. It must sum to one.' 

    @property
    def array(self):
        self._validate_array()
        return self._array 
    

That way, every time the array is accessed, you perform the check. For the reasons above, though, I probably wouldn't do this. (Also note how I've removed the subclass ;))

Upvotes: 1
