I am trying to use transformer encoders (i.e. no cross-attention) to implement an encoder-decoder network for source recovery (see Transformer-based dimensionality reduction). My inputs are audio files, each a [1, 7040] array, which I do not pass through a positional encoding. My MSE loss is very low (roughly 5e-3 accumulated over all test batches and 8e-2 accumulated over all training batches), but the reconstructed output looks nothing like the original.
[Plots: original (normalized) | unnormalized output | normalized output]
My models:
import torch
import torch.nn as nn
from collections import OrderedDict

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # assumed device setup, used by all modules below

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        # Ensure that the model dimension (d_model) is divisible by the number of heads
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        # Initialize dimensions
        self.d_model = d_model           # Model's dimension
        self.num_heads = num_heads       # Number of attention heads
        self.d_k = d_model // num_heads  # Dimension of each head's key, query, and value
        # Linear layers for transforming inputs
        self.W_q = nn.Linear(d_model, d_model).to(device)  # Query transformation
        self.W_k = nn.Linear(d_model, d_model).to(device)  # Key transformation
        self.W_v = nn.Linear(d_model, d_model).to(device)  # Value transformation
        self.W_o = nn.Linear(d_model, d_model).to(device)  # Output transformation

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        # Calculate attention scores
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.d_k))
        # Apply mask if provided (useful for preventing attention to certain parts like padding)
        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask == 0, -1e9)
        # Softmax is applied to obtain attention probabilities
        attn_probs = torch.softmax(attn_scores, dim=-1).to(device)
        # Multiply by values to obtain the final output
        output = torch.matmul(attn_probs, V)
        return output

    def split_heads(self, x):
        # Reshape the input to have num_heads for multi-head attention
        batch_size, seq_length, d_model = x.size()
        return x.view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2)

    def combine_heads(self, x):
        # Combine the multiple heads back to original shape
        batch_size, _, seq_length, d_k = x.size()
        return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model)

    def forward(self, Q, K, V, mask=None):
        # Apply linear transformations and split heads
        Q = self.split_heads(self.W_q(Q))
        K = self.split_heads(self.W_k(K))
        V = self.split_heads(self.W_v(V))
        # Perform scaled dot-product attention
        attn_output = self.scaled_dot_product_attention(Q, K, V, mask)
        # Combine heads and apply output transformation
        output = self.W_o(self.combine_heads(attn_output))
        return output
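For reference, a minimal shape check of this block (the sizes below are illustrative placeholders, not my actual dimensions):

# Shape sanity check with made-up sizes: self-attention keeps (batch, seq_len, d_model)
mha = MultiHeadAttention(d_model=128, num_heads=8)
dummy = torch.randn(4, 16, 128, device=device)   # (batch, seq_len, d_model)
out = mha(dummy, dummy, dummy)                   # Q = K = V for self-attention
print(out.shape)                                 # torch.Size([4, 16, 128])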
class PositionWiseFeedForward(nn.Module):
    def __init__(self, d_model, layer_size):
        super(PositionWiseFeedForward, self).__init__()
        assert layer_size[0] == d_model, "first layer size should go from input size to defined size"
        self.layersize = layer_size
        self.relu = nn.ReLU().to(device)

    def module(self):
        # Build the stack of linear layers described by layer_size
        layerdict = dict()
        for i in range(len(self.layersize) - 1):
            layer = nn.Linear(self.layersize[i], self.layersize[i + 1])
            layerdict[f"layer{i}"] = layer
        self.ffn = OrderedDict(layerdict)
        return self.ffn

    def forward(self, x):
        # The Sequential is assembled here, inside forward
        ffn = self.module()
        ffn = nn.Sequential(ffn).to(device)
        return self.relu(ffn(x))
class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, layer_size, dropout):
        super(EncoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionWiseFeedForward(d_model, layer_size)
        self.norm1 = nn.LayerNorm(d_model).to(device)
        self.norm2 = nn.LayerNorm(layer_size[-1]).to(device)  # for second residual connection
        self.dropout = nn.Dropout(dropout).to(device)

    def forward(self, x, mask=None):
        attn_output = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_output))
        ff_output = self.norm2(self.feed_forward(x))
        # ff_output = self.norm2(ff_output + self.dropout(ff_output))  # because this residual connection just doesn't make sense, at least so far
        return ff_output
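Again just as an illustration (placeholder sizes): a single encoder layer whose feed-forward changes the last dimension to layer_size[-1], which is why the second residual connection is dropped:

# Illustrative only: the output's last dim follows layer_size[-1]
enc = EncoderLayer(d_model=128, num_heads=8, layer_size=[128, 64], dropout=0.1)
dummy = torch.randn(4, 16, 128, device=device)   # (batch, seq_len, d_model)
print(enc(dummy).shape)                          # torch.Size([4, 16, 64])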
The "Channel Layer" you see in the training loop:
class Chan_Model(object):
    """Define MIMO channel model."""
    def __init__(self, name):
        self.name = name

    def __call__(self, _input, std=0.1):
        _input = torch.permute(_input, (0, 2, 1))
        batch_size = _input.shape[0]
        _shape = list(_input.shape)
        #assert (_shape[2]*_shape[3]) % 2 == 0, "number of transmitted symbols must be an integer."

        # reshape layer and normalize the average power of each dim in x into 0.5
        x = _input.view((batch_size, _shape[1], _shape[2] // 2, 2)).to(device)
        x_norm = torch.norm(x, dim=2, keepdim=True)
        x_real = x_norm[:, :, :, 0]
        x_imag = x_norm[:, :, :, 1]
        x_complex = torch.complex(real=x_real, imag=x_imag)

        # channel h
        h = torch.normal(size=(batch_size, _shape[1], 1, 2), mean=0.0, std=1.0).to(device)
        h = (torch.sqrt(torch.tensor(1. / 2.)) + torch.sqrt(torch.tensor(1. / 2.)) * h) / torch.sqrt(torch.tensor(2.))
        h_real = h[:, :, :, 0]
        h_imag = h[:, :, :, 1]
        h_complex = torch.complex(real=h_real, imag=h_imag)

        # noise n
        n = torch.normal(size=tuple(x.shape), mean=0.0, std=std).to(device)
        n_real = n[:, :, :, 0]
        n_imag = n[:, :, :, 1]
        n_complex = torch.complex(real=n_real, imag=n_imag)

        # receive y
        y_complex = torch.mul(h_complex, x_complex) + n_complex

        # estimate x_hat with perfect CSI
        x_hat_complex = torch.div(y_complex, h_complex)

        # convert complex to real
        x_hat_real = torch.unsqueeze(torch.real(x_hat_complex), dim=-1)
        x_hat_imag = torch.unsqueeze(torch.imag(x_hat_complex), dim=-1)
        x_hat = torch.cat([x_hat_real, x_hat_imag], dim=-1)

        _output = torch.reshape(x_hat, shape=_input.shape)
        _output = _output.permute(0, 2, 1)
        return _output
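A quick shape check of the channel with placeholder sizes (dim 1 of the input must be even, since it becomes _shape[2] after the internal permute and is split into real/imag pairs):

# Placeholder sizes; output shape matches the input shape
channel = Chan_Model(name="channel")
dummy = torch.randn(4, 64, 16, device=device)
print(channel(dummy).shape)                      # torch.Size([4, 64, 16])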
I got the suggestion of normalizing the outputs (before calculating the loss in the training loop). I did that, but the output is still very different from the original.
1) Is my training loop correct? Are my input and output normalizations placed correctly, i.e. should I normalize the entire dataset before training, or is it fine to call F.normalize on every batch inside the training loop?
2) Is my loss calculation correct? Should I divide the accumulated loss by the total number of batches (so my train loss would be 8e-2 / number of batches)? See the snippet after the training loop.
3) Training loss usually plateaus after around 25 epochs; should I do something differently?
4) Should I add positional encodings (e.g. the sinusoidal encoding sketched after this list)?
5) Could there be an issue with the train-test split?
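For question 4, this is the kind of positional encoding I mean, a minimal sketch of the standard sinusoidal encoding (max_len is just a placeholder; it assumes an even d_model):

import math

class PositionalEncoding(nn.Module):
    """Standard sinusoidal positional encoding, added to the input sequence."""
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe.unsqueeze(0))    # shape (1, max_len, d_model)

    def forward(self, x):                              # x: (batch, seq_len, d_model)
        return x + self.pe[:, :x.size(1)]

My training loop: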
for epoch in range(epochs):
    t1 = time.time()
    train_loss = 0
    val_loss = 0
    for x in train_loader:
        x = x.to(device)
        x = nn.functional.normalize(x, dim=2)
        optimizer.zero_grad()
        SE_Encoder1_out = SE_Encoder1(x).squeeze()
        TF_Encoder1_out = TF_Encoder1(SE_Encoder1_out)
        TF_Encoder2_out = TF_Encoder2(TF_Encoder1_out)
        Channel_out = L_channelmodel(TF_Encoder2_out)
        TF_Decoder1_out = TF_Decoder1(Channel_out)
        TF_Decoder2_out = TF_Decoder2(TF_Decoder1_out)
        TF_Decoder2_out = TF_Decoder2_out.unsqueeze(-1)
        TF_Decoder2_out = TransConv_BN(TF_Decoder2_out).squeeze(-1)
        y = nn.functional.normalize(TF_Decoder2_out, dim=2)
        if TF_Decoder2_out.shape != x.shape:
            break
        loss = criterion(y, x)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    t2 = time.time()
    train_loss_list.append(train_loss)
    print(f"epoch {epoch} | train loss: {train_loss} | time: {t2 - t1}")