I have built a graph convolution model and want to implement a custom unsupervised loss function like the one shown in the picture below:
where y_v is the learned embedding of a node v, and rand denotes a random sampling operation over all the nodes.
I am new to PyTorch, so I'm not sure whether this is already implemented, or where to get started with writing a custom loss function. Any help will be sincerely appreciated.
Here's a general template that I use for implementing custom loss functions in PyTorch, with details in the comments. You can use it as a starting point:
```python
import torch
from torch import nn
# from typing import Final  # only needed if you uncomment the `Final` example below


# Create a class which inherits from `nn.modules.loss._Loss`
class MyCustomLoss(nn.modules.loss._Loss):  # pylint: disable=protected-access
    # If the loss function has parameters, implement the constructor.
    # In your case, it seems that there are no parameters.
    def __init__(
        self,
        # Add the parameters, if present. Example:
        # margin: float,
    ) -> None:
        super().__init__()
        # An example of storing a parameter:
        # self.margin: Final[float] = margin

    # Implement the forward method, specifying the inputs and outputs of your
    # loss function, and the calculation logic.
    def forward(
        self,
        # Add your inputs here
    ) -> torch.Tensor:
        # Implement the loss function logic, as shown in the formula you provided
        raise NotImplementedError
```
Note: I didn't fully implement this loss function as I wasn't familiar with the specific details of your use case.
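Since the formula from the picture isn't available here, the following is only a sketch of how `forward` could be filled in, assuming the loss is a negative-sampling objective of the kind used by GraphSAGE: connected nodes u, v should get a high dot product y_u · y_v, while randomly sampled nodes ("rand") should get a low one. The class name `NegativeSamplingLoss`, the `num_negatives` parameter, and passing the graph as an `edge_index` tensor of shape `(2, num_edges)` are all hypothetical choices; adapt them to the actual formula.

```python
import torch
import torch.nn.functional as F
from torch import nn


class NegativeSamplingLoss(nn.modules.loss._Loss):  # pylint: disable=protected-access
    """Hypothetical unsupervised graph loss (an assumption, not the formula
    from the question):

        L = mean over edges (u, v) of
            -log(sigmoid(y_u . y_v)) - sum_k log(sigmoid(-y_u . y_rand_k))
    """

    def __init__(self, num_negatives: int = 1) -> None:
        super().__init__()
        self.num_negatives = num_negatives  # random nodes drawn per positive pair

    def forward(self, embeddings: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        # embeddings: (num_nodes, dim); edge_index: (2, num_edges) of node ids
        src, dst = edge_index
        y_u, y_v = embeddings[src], embeddings[dst]

        # Positive term: connected nodes should have similar embeddings.
        pos_score = (y_u * y_v).sum(dim=-1)                   # (num_edges,)
        pos_loss = -F.logsigmoid(pos_score)

        # "rand": sample nodes uniformly at random over all nodes.
        neg_idx = torch.randint(
            0, embeddings.size(0), (src.size(0), self.num_negatives),
            device=embeddings.device,
        )
        y_neg = embeddings[neg_idx]                            # (num_edges, K, dim)
        neg_score = (y_u.unsqueeze(1) * y_neg).sum(dim=-1)     # (num_edges, K)
        neg_loss = -F.logsigmoid(-neg_score).sum(dim=-1)       # (num_edges,)

        return (pos_loss + neg_loss).mean()


# Example usage with random embeddings and a 3-edge toy graph (hypothetical data)
embeddings = torch.randn(5, 8, requires_grad=True)
edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])
loss = NegativeSamplingLoss(num_negatives=2)(embeddings, edge_index)
loss.backward()
```

`F.logsigmoid` is used instead of `torch.log(torch.sigmoid(...))` because it is numerically stable for large scores. Uniform sampling can occasionally pick a true neighbor (or the node itself) as a "random" node; this sketch accepts that simplification.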
After implementing the loss function, you can use it just like PyTorch's built-in loss functions in your training loop. A simple example would be:
```python
# `model`, `optimizer`, and the training inputs are assumed to be defined already
# Instantiate the custom loss function
criterion = MyCustomLoss(parameters_if_present)
# Forward pass
model_outputs = model(train_input_or_inputs)
# Calculate loss
loss = criterion(model_outputs)
# Back-propagation and optimization
optimizer.zero_grad()
loss.backward()
optimizer.step()
```
Note that this example may vary slightly depending on your deep learning task.
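To see the training step end to end, here is a minimal runnable sketch under several assumptions: a tiny 4-node graph, a single dense graph-convolution layer (`TinyGCN`, computing A_norm · X · W) standing in for your model, and a compact negative-sampling loss (`EdgeNegSamplingLoss`). All of these names, the graph, and the loss itself are hypothetical illustrations, not the formula from the question.

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)

# Toy graph (hypothetical): 4 nodes, undirected edges stored in both directions.
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
                           [1, 0, 2, 1, 3, 2]])
num_nodes, in_dim, out_dim = 4, 3, 2
features = torch.randn(num_nodes, in_dim)

# Dense symmetric-normalized adjacency with self-loops: D^-1/2 (A + I) D^-1/2
adj = torch.eye(num_nodes)
adj[edge_index[0], edge_index[1]] = 1.0
deg_inv_sqrt = adj.sum(dim=1).pow(-0.5)
adj_norm = deg_inv_sqrt[:, None] * adj * deg_inv_sqrt[None, :]


class TinyGCN(nn.Module):
    """One graph-convolution layer, H = A_norm X W (stand-in for the real model)."""

    def __init__(self, in_dim: int, out_dim: int) -> None:
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim)

    def forward(self, x: torch.Tensor, a: torch.Tensor) -> torch.Tensor:
        return a @ self.linear(x)


class EdgeNegSamplingLoss(nn.modules.loss._Loss):  # pylint: disable=protected-access
    """Compact negative-sampling loss; the graph is passed to the constructor."""

    def __init__(self, edge_index: torch.Tensor) -> None:
        super().__init__()
        self.edge_index = edge_index

    def forward(self, y: torch.Tensor) -> torch.Tensor:
        src, dst = self.edge_index
        rand = torch.randint(0, y.size(0), src.shape, device=y.device)  # random nodes
        pos = F.logsigmoid((y[src] * y[dst]).sum(dim=-1))
        neg = F.logsigmoid(-(y[src] * y[rand]).sum(dim=-1))
        return -(pos + neg).mean()


model = TinyGCN(in_dim, out_dim)
criterion = EdgeNegSamplingLoss(edge_index)  # the graph plays "parameters_if_present"
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

losses = []
for epoch in range(50):
    model_outputs = model(features, adj_norm)  # forward pass
    loss = criterion(model_outputs)            # calculate loss
    optimizer.zero_grad()                      # back-propagation and optimization
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
```

Storing the graph in the loss constructor is one way to keep the call site as `criterion(model_outputs)`; passing `edge_index` to `forward` instead works just as well.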