user8916969
user8916969

Reputation: 1

I am implementing transformers from scratch in pytorch and getting some error in addition in positional encoding part in output layer

I am implementing transformer in pytorch and getting an error when the Positional encoding is applied in the decoder layer that is in the op_positional_encoding = self.positional_encoding(op_embed) part: RuntimeError: The size of tensor a (19) must match the size of tensor b (20) at non-singleton dimension 1 in the part return x + self.pe[:x.shape[1]]

I am attaching all my codes: ** Data Generation code :**

en_tokenizer = spacy.load("en_core_web_sm")
fr_tokenizer = spacy.load("fr_core_news_sm")

def data_process(sentence_en,sentence_fr):
  data = []
  for (en_sen,fr_sen) in zip(sentence_en,sentence_fr):
    data.append((en_sen,fr_sen))
  return data
train_data = data_process(train_en,train_fr)
valid_data = data_process(valid_en,valid_fr)

UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = 0, 1, 2, 3
special_symbols = ['<unk>', '<pad>', '<bos> ', ' <eos>']

def yield_tokens_en(data_iter):
    for (en_batch, fr_batch) in data_iter:
        for en_sent in en_batch:
            my_string = ' '.join(en_sent)
            en_tokens = my_string.split()  # Split English sentence into tokens
            yield en_tokens

def yield_tokens_fr(data_iter):
    for (en_batch, fr_batch) in data_iter:
        for fr_sent in fr_batch:
            my_string = ' '.join(fr_sent)
            fr_tokens = my_string.split()  # Split English sentence into tokens
            yield fr_tokens
en_vocab = build_vocab_from_iterator(yield_tokens_en(train_iter),
                                  min_freq=1,
                                  specials=special_symbols,
                                  special_first=True)
fr_vocab = build_vocab_from_iterator(yield_tokens_fr(train_iter),
                                  min_freq=1,
                                  specials=special_symbols,
                                  special_first=True)
en_vocab.set_default_index(UNK_IDX)
fr_vocab.set_default_index(UNK_IDX)

def pad_sequences_to_length(batch, max_length=20, padding_value=0):
    padded_sequences = []
    for seq in batch:
        if len(seq) < max_length:
            # Pad sequence to max_length
            padded_seq = torch.cat([seq, torch.full((max_length - len(seq),), padding_value)])
        else:
            # Truncate sequence if longer than max_length
            padded_seq = seq[:max_length]
        padded_sequences.append(padded_seq)
    return torch.stack(padded_sequences)

def collate_fn(data_batch):
  en_batch = []
  fr_batch = []
  for (en_item,fr_item) in data_batch:
    en_ids = torch.tensor([en_vocab[token] for token in en_item],dtype = torch.long)
    fr_ids = torch.tensor([fr_vocab[token] for token in fr_item],dtype = torch.long)
    en_batch.append(torch.cat([torch.tensor([en_vocab['<bos>']]), en_ids ,torch.tensor([en_vocab['<eos>']])], dim=0))
    fr_batch.append(torch.cat([torch.tensor([fr_vocab['<bos>']]), fr_ids ,torch.tensor([fr_vocab['<eos>']])], dim=0))
  en_batch = pad_sequences_to_length(en_batch,padding_value = PAD_IDX)
  fr_batch = pad_sequences_to_length(fr_batch,padding_value = PAD_IDX)
  return en_batch,fr_batch

train_data_token_iter = DataLoader(train_data, batch_size=32,
                        shuffle=True, collate_fn=collate_fn)
valid_data_token_iter = DataLoader(valid_data, batch_size=32,
                        shuffle=True, collate_fn=collate_fn)

**positional coding part : **

class PositionalEncoding(nn.Module):
    def __init__(self, embed_dim, max_seq_length):
        super(PositionalEncoding, self).__init__()

        pe = torch.zeros(max_seq_length, embed_dim)
        position = torch.arange(0, max_seq_length).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embed_dim, 2).float() * -(math.log(10000.0) / embed_dim))

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        print(x.size())
        print(self.pe.size())
        return x + self.pe[:x.shape[1]]

**MultiHead Attention Part : **

class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim=512, n_heads=8):
        super(MultiHeadAttention, self).__init__()
        assert embed_dim % n_heads == 0, "Embedding dimension must be divisible by the number of heads"
        # Initialize dimensions
        self.embed_dim = embed_dim
        self.n_heads = n_heads
        self.single_head_dim = embed_dim // n_heads

        # Linear layers for transforming inputs
        self.W_q = nn.Linear(embed_dim, embed_dim)
        self.W_k = nn.Linear(embed_dim, embed_dim)
        self.W_v = nn.Linear(embed_dim, embed_dim)
        self.W_o = nn.Linear(embed_dim, embed_dim)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.single_head_dim)

        # Apply mask if provided (useful for preventing attention to certain parts like padding)
        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask == 0, -1e9)

        attn_probs = torch.softmax(attn_scores, dim=-1)

        # Multiply by values to obtain the final output
        output = torch.matmul(attn_probs, V)
        return output

    def split_heads(self, x):
        # Reshape the input to have num_heads for multi-head attention
        batch_size, seq_length, embed_dim = x.size()
        return x.view(batch_size, seq_length, self.n_heads, self.single_head_dim).transpose(1, 2)

    def combine_heads(self, x):
        # Combine the multiple heads back to original shape
        batch_size, _, seq_length, single_head_dim = x.size()
        return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.embed_dim)

    def forward(self, Q, K, V, mask=None):
        # Apply linear transformations and split heads
        Q = self.split_heads(self.W_q(Q))
        K = self.split_heads(self.W_k(K))
        V = self.split_heads(self.W_v(V))

        # Perform scaled dot-product attention
        attn_output = self.scaled_dot_product_attention(Q, K, V, mask)

        # Combine heads and apply output transformation
        output = self.W_o(self.combine_heads(attn_output))
        return output

Encoder and Decoder Block

class EncoderBlock(nn.Module):
  def __init__(self,embed_dim,n_heads=8,expansion_factor=4):
    super(EncoderBlock,self).__init__()
    self.attention = MultiHeadAttention(embed_dim,n_heads)
    self.feed_forward = FeedForwardNetwork(embed_dim,expansion_factor*embed_dim)
    self.norm_1 = nn.LayerNorm(embed_dim)
    self.norm_2 = nn.LayerNorm(embed_dim)
    self.dropout = nn.Dropout(0.2)

  def forward(self,x,mask):
    attention_output = self.attention(x,x,x,mask)
    x = self.norm_1(x + self.dropout(attention_output))
    ff_output = self.feed_forward(x)
    x = self.norm_2(x + self.dropout(ff_output))
    return x

class DecoderBlock(nn.Module):
  def __init__(self,embed_dim,n_heads=8,expansion_factor=4):
    super(DecoderBlock,self).__init__()
    self.masked_attention = MultiHeadAttention(embed_dim,n_heads)
    self.norm1 = nn.LayerNorm(embed_dim)
    self.attention = MultiHeadAttention(embed_dim,n_heads)
    self.norm2 = nn.LayerNorm(embed_dim)
    self.feed_forward = FeedForwardNetwork(embed_dim,expansion_factor*embed_dim)
    self.norm3 = nn.LayerNorm(embed_dim)
    self.dropout = nn.Dropout(0.2)

  def forward(self,x,encoder_output,source_mask,target_mask):
    masked_output = self.masked_attention(x,x,x,target_mask)
    x = self.norm1(x + self.dropout(masked_output))
    attention_output = self.attention(x,encoder_output,encoder_output,source_mask)
    x = self.norm2(x + self.dropout(attention_output))
    ff_output = self.feed_forward(x)
    x = self.norm3(x + self.dropout(ff_output))
    return x

Transformer Block

class Transformer(nn.Module):
  def __init__(self,input_vocab_size,output_vocab_size,device,embed_dim=512,n_heads=8,expansion_factor=4,max_sentence_len=50,num_layers=1):
    super(Transformer,self).__init__()
    self.max_len = max_sentence_len
    # self.tokenizer_en = en_tokenizer
    # self.tokenizer_fr = fr_tokenizer
    self.input_embedding = nn.Embedding(input_vocab_size,embed_dim)
    self.positional_encoding = PositionalEncoding(embed_dim,max_sentence_len)
    self.output_embedding = nn.Embedding(output_vocab_size,embed_dim)
    self.encoder_layer = nn.ModuleList([EncoderBlock(embed_dim,n_heads,expansion_factor)])
    self.decoder_layer = nn.ModuleList([DecoderBlock(embed_dim,n_heads,expansion_factor)])
    self.fc = nn.Linear(embed_dim,output_vocab_size)
    self.dropout = nn.Dropout(0.2)
  def generate_mask(self,source,target):
    source_mask = (source!=0).unsqueeze(1).unsqueeze(2)
    target_mask = (target!=0).unsqueeze(1).unsqueeze(3)
    seq_length = target.size(1)
    nopeak_mask = (1 - torch.triu(torch.ones(1, seq_length, seq_length,device = device) ,diagonal=1)).bool()
    target_mask = target_mask&nopeak_mask
    return source_mask,target_mask
  def forward(self,source,target):
    source_mask,target_mask = self.generate_mask(source,target)
    ip_embed = self.input_embedding(source)
    ip_positional_encoding = self.positional_encoding(ip_embed)
    ip_embedded = self.dropout(ip_positional_encoding)
    encoder_output = ip_embedded
    for enc_layer in self.encoder_layer:
      encoder_output = enc_layer(encoder_output,source_mask)
    op_embed = self.output_embedding(target)
    op_positional_encoding = self.positional_encoding(op_embed)
    op_embedded = self.dropout(op_positional_encoding)
    decoder_output = op_embedded
    for dec_layer in self.decoder_layer:
      decoder_output = dec_layer(decoder_output,encoder_output,source_mask,target_mask)
    output = self.fc(decoder_output)
    return output

Training Script :

def train(model,data_iterator,optimizer,criterion,clip,target_vocab_size):

  model.train()

  epoch_loss = 0
  epoch_total = 0
  epoch_corrects = 0

  for _,(source,target) in enumerate(data_iterator):
    # print(source)
    # print(f"Source Shape:{source.size()}")
    print(type(source))
    source,target = source.to(device),target.to(device)
    print(f"Target before : {target.size()}")
    optimizer.zero_grad()

    output = model(source,target[:,:-1])
    print(f"Output before : {output.size()}")
    output_reshape = output.contiguous().view(-1, output.shape[-1])
    target = target[:, 1:].contiguous().view(-1)
    print(f"Output after : {output.size()}")
    print(f"Target after : {target.size()}")
    loss = criterion(output_reshape, target)
    _, predicted = torch.max(output, 1)
    corrects = torch.sum(predicted == target).item()
    total = target.size(0)
    accuracy = corrects / total
    loss.backward()

    torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
    optimizer.step()

    epoch_loss+=loss.item()
    epoch_corrects+=corrects
    epoch_total+=total
  return epoch_loss/len(data_iterator),epoch_corrects/epoch_total


source_vocab_size = len(en_vocab)
target_vocab_size = len(fr_vocab)
embed_dim = 256
n_heads = 8
num_layers = 6
expansion_factor = 4
max_sentence_length = 20
transformer_model = Transformer(source_vocab_size,target_vocab_size,device,embed_dim,n_heads,expansion_factor,max_sentence_length)
# transformer_model = Transformer(embed_dim,n_heads,expansion_factor,source_vocab_size,target_vocab_size,max_sentence_length,num_layers)
count_parameters(transformer_model)

N_EPOCHS = 5
CLIP = 1
loss_f = nn.CrossEntropyLoss(ignore_index=PAD_IDX)
optimizer = optim.Adam(transformer_model.parameters(),lr = 0.0001,betas = (0.9,0.98),eps = 1e-9)
best_valid_loss = float('inf')

for epoch in range(N_EPOCHS):
  start_time = time.time()
  train_loss,train_accuracy = train(transformer_model,train_data_token_iter,optimizer,loss_f,CLIP,target_vocab_size)
  end_time = time.time()
  epoch_min,epoch_sec = epoch_time(start_time,end_time)

  print(f"Epoch {epoch+1:02} completed | Time : {epoch_min}m and {epoch_sec}s")

also the size of target is torch.Size([32, 20]) and x.size : torch.Size([32, 19, 256]) pe.size : torch.Size([1, 20, 256])

when I changed the output as output = model(source, target) then another error : Expected input batch_size (640) to match target batch_size (608).

Upvotes: 0

Views: 81

Answers (0)

Related Questions