Reputation: 55
Given some simple PyTorch model, how do I compute the Fisher Metric?
Here's a trivial (practically useless) model that uses a single linear layer to solve the matrix equation Ax = b, where A is a 3x3 matrix and x and b are 3x1 column vectors. Given A and b, what's x? The specific problem isn't important.
import torch
import torch.nn as nn
import torch.optim as optim

class Net(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(Net, self).__init__()
        self.linear = nn.Linear(input_dim, output_dim, bias=False)

    def forward(self, x):
        out = self.linear(x)
        return out

# Define the training data
A = torch.tensor([[1., 2., 3.],
                  [4., 5., 6.],
                  [7., 8., 9.]])
b = torch.tensor([[52.],
                  [124.],
                  [196.]])

# Define the model and the optimizer
model = Net(input_dim=9, output_dim=3)
optimizer = optim.Adam(model.parameters(), lr=0.01)

# Train the model
for epoch in range(2000):
    optimizer.zero_grad()
    y_pred = model(A.reshape(9))
    print(A @ y_pred[:3])
    loss = nn.MSELoss(reduction='sum')(A @ y_pred.view((3, 1)), b)
    loss.backward()
    optimizer.step()

# Evaluate the model
with torch.no_grad():
    y_pred = model(A.reshape(9))
    print("Solution:\n", y_pred)
From this, I would like to calculate the Fisher Metric of the model. I am trying to use the NNGeometry package, which requires a dataloader, so I create another uselessly trivial snippet with a single batch containing the trained matrix A:
from torch.utils.data import Dataset, DataLoader

class TrivialDataset(Dataset):
    def __init__(self):
        self.data = torch.tensor([[1., 2., 3.],
                                  [4., 5., 6.],
                                  [7., 8., 9.]]).reshape(1, 9)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)

# Create the DataLoader
dataset = TrivialDataset()
loader = DataLoader(dataset, batch_size=1)
Finally I try to generate the FIM:
from nngeometry.metrics import FIM
from nngeometry.object import PMatDense
fisher_metric = FIM(model, loader, n_output=1, variant='regression', representation=PMatDense, device='cpu')
but get an error:
RuntimeError: shape '[9, 1]' is invalid for input of size 3
I can see that the problem comes from a view operation that had to be made, but surely NNGeometry can handle models where the number of input dimensions is bigger than the number of outputs (as in classification, for example)?
Can I circumvent this? Is there a good alternative to NNGeometry?
Upvotes: 0
Views: 872
Reputation: 196
The NNGeometry library expects the model to output a tensor of shape (batch_size, n_output), but your model outputs a tensor of shape (n_output,). A change to the forward pass is therefore needed to produce the shape that NNGeometry expects.
class Net(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(Net, self).__init__()
        self.linear = nn.Linear(input_dim, output_dim, bias=False)

    def forward(self, x):
        out = self.linear(x)
        # Change from your forward pass: return shape (batch_size, n_output)
        return out.view(-1, self.linear.out_features)

# Re-create the model and optimizer so they use the new forward pass
model = Net(input_dim=9, output_dim=3)
optimizer = optim.Adam(model.parameters(), lr=0.01)

# Train the model
for epoch in range(2000):
    optimizer.zero_grad()
    y_pred = model(A.view(1, -1))      # input is now a (1, 9) batch, output is (1, 3)
    print(A @ y_pred.view(3, 1))       # reshape to (3, 1) before multiplying by A
    loss = nn.MSELoss(reduction='sum')(A @ y_pred.view((3, 1)), b)
    loss.backward()
    optimizer.step()
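With the reshaped output, the FIM call from your question should then go through. Here is a minimal sketch of that last step, reusing the TrivialDataset loader from your question; note that n_output=3 is my assumption, since the model now outputs 3 values per sample (your original call used n_output=1):

from nngeometry.metrics import FIM
from nngeometry.object import PMatDense

# Assumption: n_output matches the 3-dimensional model output
fisher_metric = FIM(model, loader, n_output=3, variant='regression',
                    representation=PMatDense, device='cpu')

# The dense representation should be a 9x9 matrix, one row/column per weight
print(fisher_metric.get_dense_tensor().shape)

If this still errors, double-check that the loader yields batches of shape (batch_size, 9), so the model's view to (batch_size, 3) works.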
Upvotes: 0