Pave

Reputation: 45

Custom Loss Function with Principal Component Angle Calculation in PyTorch Not Differentiable?

Approach: I wrote a custom loss function in PyTorch that measures the angular difference between the original (input) and reconstructed images based on their first principal component axes. It computes the angle of the first principal component of the reconstructed image and compares it to the known rotation angle of the input image.

Problem: During training I get RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn, which points to a differentiation issue. Additionally, checking the loss tensor with print(loss.requires_grad) returns False.

Question: Is the approach I'm using to compute the angle inherently non-differentiable, or is there an issue in my implementation that breaks differentiability?

import torch
import torch.nn as nn


class AngularDifferenceLoss(nn.Module):
    def __init__(self):
        super().__init__()

    @staticmethod
    def compute_principal_component_angle(image):
        # Ensure the image is a 2D tensor [H, W]
        if image.dim() != 2:
            raise ValueError("Image must be a 2D tensor")
        
        # Find non-zero indices
        non_zero_indices = image.nonzero(as_tuple=False).float()
        
        # Handle empty image case
        if non_zero_indices.size(0) == 0:
            return torch.tensor(0.0, device=image.device)  # Return 0 degree angle
        
        # Compute center of mass
        com = non_zero_indices.mean(dim=0)
        
        # Center indices around the center of mass
        centered_indices = non_zero_indices - com
        
        # Compute covariance matrix
        covariance_matrix = centered_indices.t().mm(centered_indices) / centered_indices.size(0)
        
        # Eigen decomposition
        _, eigenvectors = torch.linalg.eigh(covariance_matrix)
        
        # Principal axis (eigenvector corresponding to the largest eigenvalue)
        principal_axis = eigenvectors[:, -1]
        
        # Compute the angle
        angle_radians = torch.atan2(principal_axis[1], principal_axis[0])
        angle_degrees = torch.rad2deg(angle_radians) % 360

        # Return the angle as a 1D tensor of shape [1]
        angle_degrees = angle_degrees.unsqueeze(0)

        return angle_degrees

    def angular_difference(self, angle1, angle2):
        # Ensure angles are 1D tensors
        angle1 = angle1.flatten()
        angle2 = angle2.flatten()

        diff = torch.abs(angle1 - angle2)
        diff = torch.min(diff, 360 - diff)
        return diff.mean()  # Return the mean difference as the loss
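
For context, this is roughly how I call the loss during training (simplified; model, optimizer, loader, images, and target_angles are placeholders for my actual autoencoder, optimizer, data loader, input batch, and known rotation angles):

criterion = AngularDifferenceLoss()

for images, target_angles in loader:
    reconstructed = model(images)  # autoencoder output, assumed shape [B, H, W]

    # Principal-component angle for each reconstructed image in the batch
    angles = torch.stack([
        AngularDifferenceLoss.compute_principal_component_angle(img)
        for img in reconstructed
    ])

    loss = criterion.angular_difference(angles, target_angles)
    print(loss.requires_grad)  # prints False

    optimizer.zero_grad()
    loss.backward()  # RuntimeError: element 0 of tensors does not require grad ...
    optimizer.step()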


I checked for places where tensors might disconnect from the computational graph, but the loss still does not require grad. Any insights into alternative functions, or suggestions for keeping the loss function differentiable throughout, would be greatly appreciated.
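
For reference, one alternative I am considering (untested) would replace the hard nonzero() indices with intensity-weighted pixel coordinates, so that every step stays connected to the image tensor. This is only a sketch; compute_principal_component_angle_soft below would stand in for the static method above:

import torch

def compute_principal_component_angle_soft(image, eps=1e-8):
    # Coordinate grid for every pixel, matching the (row, col) order of nonzero()
    h, w = image.shape
    ys, xs = torch.meshgrid(
        torch.arange(h, device=image.device, dtype=image.dtype),
        torch.arange(w, device=image.device, dtype=image.dtype),
        indexing="ij",
    )
    coords = torch.stack([ys.flatten(), xs.flatten()], dim=1)  # [H*W, 2]

    # Pixel intensities act as soft weights instead of a hard nonzero mask
    weights = image.flatten()
    total = weights.sum() + eps

    # Weighted center of mass and centered coordinates
    com = (weights.unsqueeze(1) * coords).sum(dim=0) / total
    centered = coords - com

    # Weighted covariance matrix
    cov = (weights.unsqueeze(1) * centered).t().mm(centered) / total

    # Eigenvector of the largest eigenvalue is the principal axis
    _, eigenvectors = torch.linalg.eigh(cov)
    principal_axis = eigenvectors[:, -1]

    angle_radians = torch.atan2(principal_axis[1], principal_axis[0])
    return torch.rad2deg(angle_radians) % 360

My (unverified) reasoning is that the weighted mean, weighted covariance, eigh, and atan2 all have gradients with respect to the pixel values, whereas nonzero() produces integer indices that autograd treats as constants, but I have not confirmed that this resolves the error.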

Upvotes: 0

Views: 43

Answers (0)
