dtr43

Reputation: 155

Fusing the features of two classes and feeding the fused features back to them

The main class uses two classes: the first class is called first, and then the second class is called. I want to use a module that receives the features from these two classes, does some calculations, and finally passes its outcome back to each of the two classes.

The idea that comes to mind is to pass the features of the first class into the second class and then apply the module inside the second class. The problem with this approach is that the module's outcome cannot then be passed back into the first class.

For example, consider these two classes and the module class:

import torch
import torch.nn as nn
import torch.nn.functional as F

class first(nn.Module):
    def __init__(self, in_planes=128, out_planes=64, kernel_size=3, stride=1, padding=0):
        super(first, self).__init__()
        # groups must divide both in_planes and out_planes; groups=in_planes (128)
        # does not divide out_planes (64) and raises a ValueError
        self.conv_s = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=False, groups=out_planes)
        self.bn_s = nn.BatchNorm2d(out_planes)
        self.relu_s = nn.ReLU()

    def forward(self, x):
        x = self.conv_s(x)
        y1 = self.bn_s(x)      # y1: batch-normalized features
        x = self.relu_s(y1)    # activate the normalized features
        return x

class second(nn.Module):
    def __init__(self, in_planes=128, out_planes=64, kernel_size=3, stride=1, padding=0):
        super(second, self).__init__()
        # same groups constraint as in first
        self.conv_s = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=False, groups=out_planes)
        self.bn_s = nn.BatchNorm2d(out_planes)
        self.relu_s = nn.ReLU()

    def forward(self, x):
        x = self.conv_s(x)
        y2 = self.bn_s(x)      # y2: batch-normalized features
        x = self.relu_s(y2)    # activate the normalized features
        return x
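For reference, an nn.Module's forward can return more than one tensor, so y1 does not have to stay local to forward. A minimal sketch, assuming it is acceptable to change the return value to a tuple; the class name FirstWithY is hypothetical, and groups is omitted to keep the example small:

```python
import torch
import torch.nn as nn


class FirstWithY(nn.Module):
    """Hypothetical variant of `first` that also exposes the BatchNorm output."""

    def __init__(self, in_planes=128, out_planes=64, kernel_size=3, stride=1, padding=0):
        super().__init__()
        self.conv_s = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size,
                                stride=stride, padding=padding, bias=False)
        self.bn_s = nn.BatchNorm2d(out_planes)
        self.relu_s = nn.ReLU()

    def forward(self, x):
        y1 = self.bn_s(self.conv_s(x))  # intermediate feature the question calls y1
        x = self.relu_s(y1)             # activated output, as before
        return x, y1                    # return both so the caller can use y1


net = FirstWithY()
out, y1 = net(torch.randn(2, 128, 16, 16))
print(out.shape, y1.shape)  # torch.Size([2, 64, 14, 14]) torch.Size([2, 64, 14, 14])
```

With kernel_size=3 and padding=0 the 16x16 input shrinks to 14x14. Since nn.ReLU is not in-place by default, y1 keeps its pre-activation values.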

The Module class:

class module(nn.Module):
    def __init__(self):
        super(module, self).__init__()
        self.conv1h = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.bn1h   = nn.BatchNorm2d(64)
        self.conv2h = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.bn2h   = nn.BatchNorm2d(64)
        self.conv3h = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.bn3h   = nn.BatchNorm2d(64)
        self.conv4h = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.bn4h   = nn.BatchNorm2d(64)

        self.conv1v = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.bn1v   = nn.BatchNorm2d(64)
        self.conv2v = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.bn2v   = nn.BatchNorm2d(64)
        self.conv3v = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.bn3v   = nn.BatchNorm2d(64)
        self.conv4v = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.bn4v   = nn.BatchNorm2d(64)

    def forward(self, left, down):
        # resize `down` to the spatial size of `left` if they differ
        if down.size()[2:] != left.size()[2:]:
            down = F.interpolate(down, size=left.size()[2:], mode='bilinear')
        out1h = F.relu(self.bn1h(self.conv1h(left )), inplace=True)
        out2h = F.relu(self.bn2h(self.conv2h(out1h)), inplace=True)
        out1v = F.relu(self.bn1v(self.conv1v(down )), inplace=True)
        out2v = F.relu(self.bn2v(self.conv2v(out1v)), inplace=True)
        fuse  = out2h*out2v   # element-wise fusion of the two branches
        out3h = F.relu(self.bn3h(self.conv3h(fuse )), inplace=True)+out1h
        out4h = F.relu(self.bn4h(self.conv4h(out3h)), inplace=True)
        out3v = F.relu(self.bn3v(self.conv3v(fuse )), inplace=True)+out1v
        out4v = F.relu(self.bn4v(self.conv4v(out3v)), inplace=True)
        return out4h, out4v
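module.forward first resizes down to left's height and width, and later combines the two branches by element-wise multiplication (fuse = out2h*out2v). The sketch below isolates just those two operations on random tensors, skipping the conv/BN layers in between; the helper name fuse_features is hypothetical, and align_corners=False is passed explicitly (it matches PyTorch's default):

```python
import torch
import torch.nn.functional as F


def fuse_features(left, down):
    """Resize `down` to `left`'s spatial size, then fuse by element-wise product."""
    if down.size()[2:] != left.size()[2:]:
        down = F.interpolate(down, size=left.size()[2:], mode='bilinear', align_corners=False)
    return left * down  # channel counts must already match (64 in the question)


left = torch.randn(1, 64, 32, 32)
down = torch.randn(1, 64, 16, 16)  # coarser map, e.g. from a deeper stage
fused = fuse_features(left, down)
print(fused.shape)  # torch.Size([1, 64, 32, 32])
```

Note that both outputs of module end up at left's spatial size, because down is resized before any fusion happens.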

The order of the classes in the main class is as follows:

class Main(nn.Module):
    def __init__(self):
        super(Main, self).__init__()
        self.first = first()    # pass the required arguments here
        self.second = second()  # pass the required arguments here
        self.features = feature_extractor()  # defined elsewhere; not shown in the question

    def forward(self, x):
        # self.features produces two feature maps with 128 convolutional channels each
        x1, x2 = self.features(x)
        x1 = self.first(x1)
        x2 = self.second(x2)
        return x1, x2

My question is: how can I pass the output of the module back into the first class? To be clearer: after passing the y1 and y2 variables of the first and second classes into the module class, how can I multiply one of the module's outputs with y1 inside the first class and the other output with y2 inside the second class? If I integrate the module into the second class, I can pass y1 into the second class and the module, but I cannot pass the module's output back to the first class to multiply it with y1.

Update: I want the module class to receive y1 and y2 from the first and second class respectively, but I have no idea how to integrate the module class into the code. The image below illustrates the idea:


Answers (1)

Abhinav Mathur

Reputation: 8186

Since the program you've added is pretty convoluted, I've created simpler classes that demonstrate the principle that solves the problem.

class first:
    # this class gives y1 in your example
    def __init__(self, x = 0):
        self.x = x
    def get_y1(self):
        # do any computations if needed
        print(f"x (y1) has current value {self.x}")
        return self.x
    def calc(self, y):
        self.x += y
        print(f"x (y1) updated to {self.x}")

class second:
    # this class gives y2 in your example
    def __init__(self, x = 0):
        self.x = x
    def get_y2(self):
        # do any computations if needed
        print(f"x (y2) has current value {self.x}")
        return self.x
    def calc(self, y):
        self.x *= y
        print(f"x (y2) updated to {self.x}")

class module:
    # this class takes y1 and y2 for computation and returns results
    def __init__(self):
        pass
    def calc(self, x, y):
        return x+1,y+1

class main:
    def __init__(self):
        self.first = first(x = 5)
        self.second = second(x = 3)
        self.module = module()
    
    def calc(self):
        y1 = self.first.get_y1()
        y2 = self.second.get_y2()
        result1, result2 = self.module.calc(y1, y2)
        self.first.calc(result1)
        self.second.calc(result2)        

obj = main()
obj.calc()

'''
Output for this:
x (y1) has current value 5
x (y2) has current value 3
x (y1) updated to 11
x (y2) updated to 12
'''

This does essentially what you're asking for: the main class holds objects for first, second and module; a computation in module uses y1 and y2 from first and second, and the returned values are then used to update y1 and y2 themselves.
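The same orchestration pattern carries over to the PyTorch setting of the question: instead of module living inside one of the branches, Main calls both branches, hands y1 and y2 to the fusion step, and feeds the results back. This is a sketch under stated assumptions, not the answerer's code: Branch (with hypothetical encode/decode methods splitting conv+BN from ReLU) and Fusion (a shortened stand-in for the question's module) are illustrative names; padding=1 and no groups are used so spatial sizes are preserved; and feature_extractor is omitted, so the two 128-channel maps are passed in directly.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Branch(nn.Module):
    """Hypothetical two-stage version of `first` / `second`."""

    def __init__(self, in_planes=128, out_planes=64):
        super().__init__()
        self.conv_s = nn.Conv2d(in_planes, out_planes, kernel_size=3, padding=1, bias=False)
        self.bn_s = nn.BatchNorm2d(out_planes)
        self.relu_s = nn.ReLU()

    def encode(self, x):
        return self.bn_s(self.conv_s(x))  # y1 / y2 in the question

    def decode(self, y):
        return self.relu_s(y)


class Fusion(nn.Module):
    """Shortened stand-in for the question's `module`: one output per branch."""

    def __init__(self, channels=64):
        super().__init__()
        self.conv_h = nn.Conv2d(channels, channels, 3, padding=1)
        self.conv_v = nn.Conv2d(channels, channels, 3, padding=1)
        self.out_h = nn.Conv2d(channels, channels, 3, padding=1)
        self.out_v = nn.Conv2d(channels, channels, 3, padding=1)

    def forward(self, left, down):
        if down.size()[2:] != left.size()[2:]:
            down = F.interpolate(down, size=left.size()[2:], mode='bilinear', align_corners=False)
        fuse = F.relu(self.conv_h(left)) * F.relu(self.conv_v(down))
        return F.relu(self.out_h(fuse)), F.relu(self.out_v(fuse))


class Main(nn.Module):
    def __init__(self):
        super().__init__()
        self.first = Branch()
        self.second = Branch()
        self.module = Fusion(64)

    def forward(self, x1, x2):
        y1 = self.first.encode(x1)
        y2 = self.second.encode(x2)
        a, b = self.module(y1, y2)  # both outputs come back at y1's spatial size
        if b.size()[2:] != y2.size()[2:]:
            b = F.interpolate(b, size=y2.size()[2:], mode='bilinear', align_corners=False)
        x1 = self.first.decode(y1 * a)   # module output multiplied with y1
        x2 = self.second.decode(y2 * b)  # module output multiplied with y2
        return x1, x2


model = Main()
out1, out2 = model(torch.randn(2, 128, 32, 32), torch.randn(2, 128, 16, 16))
print(out1.shape, out2.shape)  # torch.Size([2, 64, 32, 32]) torch.Size([2, 64, 16, 16])
```

Calling encode and decode directly is ordinary Python; the branch parameters are still registered and trained because the Branch instances are attributes of Main. One side effect is that forward hooks on the branches do not fire, since their forward is never called.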

