I am trying to implement an FCN in PyTorch with a two-branch structure. The code so far looks like this:
```python
import torch
import torch.nn as nn

class SNet(nn.Module):
    def __init__(self):
        super(SNet, self).__init__()
        self.enc_a = encoder(...)
        self.dec_a = decoder(...)
        self.enc_b = encoder(...)
        self.dec_b = decoder(...)

    def forward(self, x1, x2):
        x1 = self.enc_a(x1)
        x2 = self.enc_b(x2)
        x2 = self.dec_b(x2)
        # note: for (N, C, H, W) feature maps, dim=1 concatenates along channels
        x1 = self.dec_a(torch.cat((x1, x2), dim=-1))
        return x1, x2
```
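Because a PyTorch `forward` is plain Python, the two-input, two-output wiring is simply written out. Below is a runnable sketch of the same graph. The `encoder`/`decoder` functions are hypothetical single-conv placeholders for the elided `encoder(...)`/`decoder(...)`, the channel counts are arbitrary, and the concatenation uses `dim=1` (channels) so that the shapes line up.

```python
import torch
import torch.nn as nn


def encoder(in_ch: int, out_ch: int) -> nn.Module:
    # Hypothetical encoder: a 3x3 conv + ReLU that keeps H and W unchanged.
    return nn.Sequential(nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1), nn.ReLU())


def decoder(in_ch: int, out_ch: int) -> nn.Module:
    # Hypothetical decoder: a 3x3 conv mapping features to output channels.
    return nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)


class SNet(nn.Module):
    def __init__(self, in_ch: int = 3, feat: int = 16, out_ch: int = 1):
        super().__init__()
        self.enc_a = encoder(in_ch, feat)
        self.enc_b = encoder(in_ch, feat)
        self.dec_b = decoder(feat, in_ch)           # branch b: encode then decode x2
        self.dec_a = decoder(feat + in_ch, out_ch)  # sees enc_a features + dec_b output

    def forward(self, x1: torch.Tensor, x2: torch.Tensor):
        f1 = self.enc_a(x1)
        r2 = self.dec_b(self.enc_b(x2))
        # Concatenate along the channel axis (dim=1) of (N, C, H, W) tensors.
        y1 = self.dec_a(torch.cat((f1, r2), dim=1))
        return y1, r2


model = SNet()
x1 = torch.randn(2, 3, 32, 32)
x2 = torch.randn(2, 3, 32, 32)
y1, y2 = model(x1, x2)
print(y1.shape, y2.shape)  # torch.Size([2, 1, 32, 32]) torch.Size([2, 3, 32, 32])
```

With `dim=-1` these tensors could not be concatenated at all, since their channel counts (16 and 3) differ.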
In Keras it is relatively easy to do this using the functional API. However, I could not find any concrete example or tutorial showing how to do this in PyTorch.
1. How can I discard `dec_a` (the decoder part of the autoencoder branch) after training?
2. Will the `loss` be the sum (optionally weighted) of the losses from both branches?
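One common way to train both branches together is to add their losses into a single scalar and call `backward()` once. The sketch below assumes MSE losses, a random hypothetical target for branch a, and treats branch b's output as a reconstruction of `x2`. The weights `w_a`/`w_b` and the compact single-conv `SNet` are illustrative assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SNet(nn.Module):
    # Compact hypothetical version of the two-branch net (single-conv encoders/decoders).
    def __init__(self, c: int = 3, feat: int = 8):
        super().__init__()
        self.enc_a = nn.Conv2d(c, feat, 3, padding=1)
        self.enc_b = nn.Conv2d(c, feat, 3, padding=1)
        self.dec_b = nn.Conv2d(feat, c, 3, padding=1)
        self.dec_a = nn.Conv2d(feat + c, 1, 3, padding=1)

    def forward(self, x1, x2):
        r2 = self.dec_b(torch.relu(self.enc_b(x2)))
        y1 = self.dec_a(torch.cat((torch.relu(self.enc_a(x1)), r2), dim=1))
        return y1, r2


torch.manual_seed(0)
model = SNet()
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
x1, x2 = torch.randn(4, 3, 16, 16), torch.randn(4, 3, 16, 16)
target1 = torch.randn(4, 1, 16, 16)  # hypothetical target for branch a
w_a, w_b = 1.0, 0.5                  # hypothetical loss weights

for step in range(5):
    y1, y2 = model(x1, x2)
    loss_a = F.mse_loss(y1, target1)
    loss_b = F.mse_loss(y2, x2)      # reconstruction loss for branch b (assumption)
    loss = w_a * loss_a + w_b * loss_b  # single scalar: one backward covers both branches
    opt.zero_grad()
    loss.backward()
    opt.step()
```

Since `dec_b`'s output feeds `dec_a`, `loss_a` also sends gradients into `enc_b` and `dec_b`. The weights control how strongly each objective shapes the shared path.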
You can also give your model separate training and inference behaviour by branching on `self.training` inside `forward`:
```python
import torch
import torch.nn as nn

class SNet(nn.Module):
    def __init__(self):
        super(SNet, self).__init__()
        self.enc_a = encoder(...)
        self.dec_a = decoder(...)
        self.enc_b = encoder(...)
        self.dec_b = decoder(...)
        # No need to define self.training: nn.Module already provides it
        # (True after construction), and model.train() / model.eval() toggle it.

    def forward(self, x1, x2):
        if self.training:
            x1 = self.enc_a(x1)
            x2 = self.enc_b(x2)
            x2 = self.dec_b(x2)
            x1 = self.dec_a(torch.cat((x1, x2), dim=-1))
            return x1, x2
        else:
            # inference: dec_a is not used
            x1 = self.enc_a(x1)
            x2 = self.enc_b(x2)
            x2 = self.dec_b(x2)
            return x2
```
These blocks are examples and may not do exactly what you want, because I think your block chart and your code define the training and inference operations slightly differently. In any case, they show how you can use some modules only in training mode. Then switch modes with `model.train()` and `model.eval()`, which set `training` on the model and recursively on all of its submodules.
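A minimal runnable sketch of this pattern, with hypothetical single-conv layers standing in for `encoder(...)`/`decoder(...)` and `dim=1` channel concatenation. Unlike the block above, its inference path also skips `enc_a`, whose output is unused there. It then discards `dec_a` after training with `del`, which is one possible approach rather than the only one.

```python
import torch
import torch.nn as nn


class SNet(nn.Module):
    # Hypothetical single-conv layers stand in for encoder(...) / decoder(...).
    def __init__(self, c: int = 3, feat: int = 8):
        super().__init__()
        self.enc_a = nn.Conv2d(c, feat, 3, padding=1)
        self.dec_a = nn.Conv2d(feat + c, 1, 3, padding=1)
        self.enc_b = nn.Conv2d(c, feat, 3, padding=1)
        self.dec_b = nn.Conv2d(feat, c, 3, padding=1)

    def forward(self, x1, x2):
        x2 = self.dec_b(self.enc_b(x2))
        if self.training:  # flag maintained by nn.Module, toggled by train()/eval()
            x1 = self.dec_a(torch.cat((self.enc_a(x1), x2), dim=1))
            return x1, x2
        return x2  # inference path: enc_a and dec_a are skipped


torch.manual_seed(0)
model = SNet()
x1, x2 = torch.randn(2, 3, 8, 8), torch.randn(2, 3, 8, 8)

model.train()  # sets training=True on the model and every submodule
out_train = model(x1, x2)  # tuple (x1, x2)

model.eval()  # sets training=False recursively (this also switches Dropout/BatchNorm)
with torch.no_grad():
    out_eval = model(x1, x2)  # only x2; dec_a is never called

# Discard dec_a after training: removes the submodule and its state_dict entries.
del model.dec_a
with torch.no_grad():
    out_after = model(x1, x2)  # eval path still works without dec_a
```

Setting `model.training = False` by hand changes only the top-level flag. `eval()` also propagates it to submodules such as Dropout or BatchNorm inside the encoders.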