Reputation: 1
I want to use a new data representation instead of float for fine-tuning/testing a model (e.g., a DNN) in PyTorch. The basic arithmetic operations (add/sub/multiply/divide) in my data type are different from floating point. Is it possible to implement these basic operations and force all of PyTorch's functions (e.g., torch.add(), torch.sum(), torch.nn.Linear(), conv2d, etc.) to use this new implementation for basic arithmetic? If so, could you please guide me on how to do this?
Otherwise, I think it would take a lot of time and effort to first find out which functions my model calls (which I don't know how to do) and then replace them one by one. This becomes complicated for a large model.
Thank you!
I found this page from PyTorch on how to extend PyTorch, but it doesn't seem comprehensive enough to answer my question: https://pytorch.org/docs/stable/notes/extending.html
Upvotes: 0
Views: 271
Reputation: 64
To test or fine-tune a PyTorch model using a data type whose basic arithmetic differs from floating point, you need to define the custom operations yourself and make sure they behave as expected during tensor operations, including backpropagation. The building block for a single custom operation is torch.autograd.Function, which lets you specify both the forward computation and its gradient.
Here's a simple example that implements modulo-10 arithmetic as a custom op:
import torch
class Modulo10Tensor(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        output = input % 10
        ctx.save_for_backward(input)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        # d(x % 10)/dx = 1 almost everywhere, so pass the gradient straight through
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        return grad_input
# Note: PyTorch has no global registration mechanism for a new data type;
# the custom op is invoked explicitly via Modulo10Tensor.apply
# Example usage
if __name__ == "__main__":
    model = torch.nn.Linear(10, 1)

    # Input size must match the layer's in_features (10)
    input_data = torch.randn(10, requires_grad=True)

    # Forward pass, routing the layer's output through the custom op
    output = Modulo10Tensor.apply(model(input_data))
    print(output)  # output of the linear layer, reduced modulo 10

    # Backward pass to populate gradients
    loss = output.sum()
    loss.backward()

    # Check gradients
    print(input_data.grad)
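To actually intercept every torch function (torch.add, torch.sum, the ops nn.Linear calls internally, etc.) rather than applying the custom op by hand, the "Extending PyTorch" notes linked in the question describe the `__torch_function__` protocol: a torch.Tensor subclass that defines `__torch_function__` gets a chance to handle every torch call involving its instances. Here is a minimal sketch (the `Mod10Tensor` name and the `unwrap` helper are my own; a real implementation would need more care, e.g. tensors nested inside list arguments are not handled here):

```python
import torch

class Mod10Tensor(torch.Tensor):
    """Tensor subclass that reduces the result of every torch op modulo 10."""

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}

        # Unwrap our subclass to plain tensors so the underlying op
        # runs without re-entering this hook.
        def unwrap(x):
            return x.as_subclass(torch.Tensor) if isinstance(x, Mod10Tensor) else x

        result = func(*[unwrap(a) for a in args],
                      **{k: unwrap(v) for k, v in kwargs.items()})

        # Apply the custom arithmetic rule and re-wrap the result.
        if isinstance(result, torch.Tensor):
            result = (result % 10).as_subclass(cls)
        return result

x = torch.tensor([7.0, 8.0]).as_subclass(Mod10Tensor)
y = torch.tensor([5.0, 6.0]).as_subclass(Mod10Tensor)

# torch.add is intercepted automatically: (7+5) % 10 = 2, (8+6) % 10 = 4
print(torch.add(x, y))
```

Because the hook fires for any torch call that sees a Mod10Tensor, a model's internal ops are rerouted as well, as long as its inputs (and, if needed, its parameters) are converted to the subclass, which avoids hunting down and replacing individual function calls one by one.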
Good Luck
Upvotes: 0