KoPa
KoPa

Reputation: 1

Should I normalize outputs while training a source recovery task?

I am trying to use transformer encoders (i.e. no cross-attention) to implement an encoder-decoder network for source recovery (refer Transformer-based dimensionality reduction). I am using [1,7040] sized arrays (audio files) as my input, which I am not putting through a positional encoding. I have very low MSELoss (accumulated testing 5e-3 and training 8e-2 over all batches), but my reconstructed output is nothing like the original.

Original (normalized) , Unnormalized output, Normalized output

My models:

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        # Ensure that the model dimension (d_model) is divisible by the number of heads
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        
        # Initialize dimensions
        self.d_model = d_model # Model's dimension
        self.num_heads = num_heads # Number of attention heads
        self.d_k = d_model // num_heads # Dimension of each head's key, query, and value
        
        # Linear layers for transforming inputs
        self.W_q = nn.Linear(d_model, d_model).to(device) # Query transformation
        self.W_k = nn.Linear(d_model, d_model).to(device) # Key transformation
        self.W_v = nn.Linear(d_model, d_model).to(device) # Value transformation
        self.W_o = nn.Linear(d_model, d_model).to(device) # Output transformation
        
    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        # Calculate attention scores
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.d_k))
        
        # Apply mask if provided (useful for preventing attention to certain parts like padding)
        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask == 0, -1e9)
        
        # Softmax is applied to obtain attention probabilities
        attn_probs = torch.softmax(attn_scores, dim=-1).to(device)
        
        # Multiply by values to obtain the final output
        output = torch.matmul(attn_probs, V)
        return output
        
    def split_heads(self, x):
        # Reshape the input to have num_heads for multi-head attention
        batch_size, seq_length, d_model = x.size()
        return x.view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2)
        
    def combine_heads(self, x):
        # Combine the multiple heads back to original shape
        batch_size, _, seq_length, d_k = x.size()
        return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model)
        
    def forward(self, Q, K, V, mask=None):
        # Apply linear transformations and split heads
        Q = self.split_heads(self.W_q(Q))
        K = self.split_heads(self.W_k(K))
        V = self.split_heads(self.W_v(V))
        
        # Perform scaled dot-product attention
        attn_output = self.scaled_dot_product_attention(Q, K, V, mask)
        
        # Combine heads and apply output transformation
        output = self.W_o(self.combine_heads(attn_output))
        return output
    
class PositionWiseFeedForward(nn.Module):
    def __init__(self, d_model,layer_size):
        super(PositionWiseFeedForward, self).__init__()
        assert layer_size[0]==d_model, "first layer size should go from input size to defined size"
        self.layersize=layer_size
        self.relu=nn.ReLU().to(device)

    def module(self):
        layerdict=dict()
        for i in range(len(self.layersize)-1):
            layer=nn.Linear(self.layersize[i],self.layersize[i+1])
            layerdict[f"layer{i}"]=layer
        self.ffn=OrderedDict(layerdict)
        return self.ffn
    
    def forward(self, x):
        ffn=self.module()
        ffn=nn.Sequential(ffn).to(device)
        return self.relu(ffn(x))

# %%
class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, layer_size, dropout):
        super(EncoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionWiseFeedForward(d_model,layer_size)
        self.norm1 = nn.LayerNorm(d_model).to(device)
        self.norm2 = nn.LayerNorm(layer_size[-1]).to(device) #for second residual connection
        self.dropout = nn.Dropout(dropout).to(device)
        
    def forward(self, x, mask=None):
        attn_output = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_output))
        ff_output = self.norm2(self.feed_forward(x))
        #ff_output = self.norm2(ff_output + self.dropout(ff_output)) '''because this residual connection just doesn't make sense, at least so far'''
        return ff_output

The "Channel Layer" you see in the training loop:

class Chan_Model(object):
    """Define MIMO channel model."""
    def __init__(self, name):

        self.name = name

    def __call__(self, _input, std=0.1):
        
        _input = torch.permute(_input,(0,2,1))

        batch_size = _input.shape[0]
        _shape = list(_input.shape)
        #assert (_shape[2]*_shape[3]) % 2 == 0, "number of transmitted symbols must be an integer."

        # reshape layer and normalize the average power of each dim in x into 0.5
        x = _input.view((batch_size, _shape[1], _shape[2]//2, 2)).to(device)
        #print(x.shape)
        x_norm = torch.norm(x, dim=2,keepdim=True)

        x_real = x_norm[:, :, :, 0]
        x_imag = x_norm[:, :, :, 1]
        x_complex = torch.complex(real=x_real, imag=x_imag)

        # channel h
        h = torch.normal(size=(batch_size, _shape[1], 1, 2),mean=0.0,std=1.0).to(device)
        h = (torch.sqrt(torch.tensor(1./2.)) + torch.sqrt(torch.tensor(1./2.))*h) / torch.sqrt(torch.tensor(2.))
        h_real = h[:, :, :, 0]
        h_imag = h[:, :, :, 1]
        h_complex =torch.complex(real=h_real, imag=h_imag)

        # noise n
        n=torch.normal(size=tuple(x.shape),mean=0.0,std=std).to(device)

        n_real = n[:, :, :, 0]
        n_imag = n[:, :, :, 1]
        n_complex = torch.complex(real=n_real, imag=n_imag)

        # receive y
        y_complex = torch.mul(h_complex, x_complex) + n_complex

        # estimate x_hat with perfect CSI
        x_hat_complex = torch.div(y_complex, h_complex)

        # convert complex to real
        x_hat_real = torch.unsqueeze(torch.real(x_hat_complex), dim=-1)
        x_hat_imag = torch.unsqueeze(torch.imag(x_hat_complex), dim=-1)
        x_hat = torch.cat([x_hat_real, x_hat_imag], dim=-1)

        _output = torch.reshape(x_hat, shape=_input.shape)
        #print(_output.shape)
        _output = _output.permute(0,2,1)

        return _output

I got the suggestion of normalizing outputs (before calculating loss in the training loop). I did, but my output is still very different.
-1) Is my training loop correct? Are my input and output normalizations placed correctly, i.e should I normalize the entire dataset before training, or is it okay to call F.normalize in the training loop for every batch?

-2) Is my loss calculation correct - should I divide the accumulated loss by the total number of batches (then my train loss would be 8e-2/(number of batches) )?

-3) Training usually tapers in around 25 epochs, should I do something differently?

-4) Should I add positional encodings?

-5) Could there be an issue with test-train splits?

for epoch in range(epochs):
        t1=time.time()
        train_loss=0
        val_loss=0
        for x in train_loader:
            x=x.to(device)
            x=nn.functional.normalize(x,dim=2)
            optimizer.zero_grad()

            SE_Encoder1_out=SE_Encoder1(x).squeeze()
            TF_Encoder1_out = TF_Encoder1(SE_Encoder1_out)
            TF_Encoder2_out = TF_Encoder2(TF_Encoder1_out)
            Channel_out=L_channelmodel(TF_Encoder2_out)
            TF_Decoder1_out = TF_Decoder1(Channel_out)
            TF_Decoder2_out = TF_Decoder2(TF_Decoder1_out)
            TF_Decoder2_out=TF_Decoder2_out.unsqueeze(-1)
            TF_Decoder2_out=TransConv_BN(TF_Decoder2_out).squeeze(-1)
            y=nn.functional.normalize(TF_Decoder2_out,dim=2)
            
            if TF_Decoder2_out.shape!=x.shape:
              break
            
            loss = criterion(y, x)
            loss.backward()
            optimizer.step()
            train_loss+=loss.item()
        
        t2=time.time()
        train_loss_list.append(train_loss)
        print(f"epoch {epoch} | train loss: {train_loss} | time:{t2-t1}")

Upvotes: 0

Views: 26

Answers (0)

Related Questions