Approach: I wrote a custom PyTorch loss function that measures the angular difference between the original (input) and reconstructed images based on their first principal component axes. It computes the angle of the first principal component of the reconstructed image and compares it to the known rotation angle of the input image.
Problem: During training I get RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn, which indicates an issue with differentiation. Additionally, when I check whether the loss tensor requires gradients with print(loss.requires_grad), it returns False.
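The failing step looks roughly like this (a simplified sketch; recon_img stands for the 2D reconstructed image and true_angle for the known rotation angle of the input, both placeholder names, not my actual training code):

# Simplified sketch of the failing step; recon_img and true_angle are placeholders.
criterion = AngularDifferenceLoss()
pred_angle = criterion.compute_principal_component_angle(recon_img)
loss = criterion.angular_difference(pred_angle, true_angle)
print(loss.requires_grad)  # -> False
loss.backward()            # -> RuntimeError: element 0 of tensors does not require grad ...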
Question: Could the approach I'm using to compute the angle itself be non-differentiable, or is there an issue in my implementation that makes the code non-differentiable?
import torch
import torch.nn as nn

class AngularDifferenceLoss(nn.Module):
    def __init__(self):
        super().__init__()

    @staticmethod
    def compute_principal_component_angle(image):
        # Ensure the image is a 2D tensor [H, W]
        if image.dim() != 2:
            raise ValueError("Image must be a 2D tensor")
        # Find non-zero indices
        non_zero_indices = image.nonzero(as_tuple=False).float()
        # Handle empty image case
        if non_zero_indices.size(0) == 0:
            return torch.tensor(0.0, device=image.device)  # Return 0 degree angle
        # Compute center of mass
        com = non_zero_indices.mean(dim=0)
        # Center indices around the center of mass
        centered_indices = non_zero_indices - com
        # Compute covariance matrix
        covariance_matrix = centered_indices.t().mm(centered_indices) / centered_indices.size(0)
        # Eigen decomposition
        _, eigenvectors = torch.linalg.eigh(covariance_matrix)
        # Principal axis (eigenvector corresponding to the largest eigenvalue)
        principal_axis = eigenvectors[:, -1]
        # Compute the angle
        angle_radians = torch.atan2(principal_axis[1], principal_axis[0])
        angle_degrees = torch.rad2deg(angle_radians) % 360
        # Ensure angle_degrees is a 1D tensor that supports gradients
        angle_degrees = angle_degrees.unsqueeze(0)
        return angle_degrees

    def angular_difference(self, angle1, angle2):
        # Ensure angles are 1D tensors
        angle1 = angle1.flatten()
        angle2 = angle2.flatten()
        # Wrap-around difference: the smaller of |a1 - a2| and 360 - |a1 - a2|
        diff = torch.abs(angle1 - angle2)
        diff = torch.min(diff, 360 - diff)
        return diff.mean()  # Return the mean difference as the loss
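The forward pass that ties these helpers together is omitted above; inside the class it is essentially the following (a minimal sketch, the argument names reconstructed and target_angle are placeholders):

    # Minimal sketch of the forward pass (not shown above): compute the angle of
    # the reconstruction's first principal component and compare it to the known
    # rotation angle of the input.
    def forward(self, reconstructed, target_angle):
        pred_angle = self.compute_principal_component_angle(reconstructed)
        return self.angular_difference(pred_angle, target_angle)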
I have checked for potential issues with the tensors disconnecting from the computational graph, but the loss-does-not-require-grad problem persists. Any insights into which operations might be responsible, or suggestions on keeping the loss function differentiable throughout, would be greatly appreciated.
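For reference, this is roughly how I went looking for the disconnect: printing requires_grad and grad_fn on the intermediate tensors inside compute_principal_component_angle (a debugging sketch with a dummy input, not part of the loss itself):

# Debugging sketch: trace where requires_grad stops being True inside
# compute_principal_component_angle, using a dummy 2D input.
img = torch.rand(28, 28, requires_grad=True)
idx = img.nonzero(as_tuple=False).float()
print(idx.requires_grad, idx.grad_fn)   # False, None -- the index tensor carries no graph
com = idx.mean(dim=0)
centered = idx - com
cov = centered.t().mm(centered) / centered.size(0)
print(cov.requires_grad, cov.grad_fn)   # still False, None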