iamzhangrl

Predictions of my Transformer model during training consist entirely of pad tokens

I am reproducing the paper "TableFormer: Table Structure Understanding with Transformers". It predicts a table image's HTML structure using a ResNet-18 backbone, a 2-layer Transformer encoder, and a 4-layer Transformer decoder. After tokenizing the target HTML structure and padding it to a fixed length, I started training the model. But after a few epochs the predicted tags are all pad tokens, as shown below, where 23 is the index of <pad> in the predefined token set. By the way, setting ignore_index=23 in the cross-entropy loss instead makes the predictions all 27 or 19, which correspond to <td> and </td>.

tensor([[23, 23, ..., 23, 23]], device='cuda:0')
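
To show how dominant the pad class is in the targets, here is a minimal check (just a sketch, assuming tags is one tokenized target of shape (1, 514) and that the pad index is 23, as in PREDEFINED_SET below):

pad_idx = 23  # PREDEFINED_SET.index("<pad>")
pad_fraction = (tags == pad_idx).float().mean().item()
print(pad_fraction)  # close to 1.0 for a short table padded out to length 514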

The model code, tokenization code, and training_step code are below. Model code (everything before the cell bbox decoder part is for structure prediction):

import math
import torch
from torch import nn, Tensor
import torchvision
from ..datas.utils import PREDEFINED_SET
from positional_encodings.torch_encodings import PositionalEncoding1D, PositionalEncoding2D, Summer


class FeatureExtractor(nn.Module):
    def __init__(self, encoded_image_size=28, device='cuda'):
        super(FeatureExtractor, self).__init__()
        self.enc_image_size = encoded_image_size

        resnet = torchvision.models.resnet18(pretrained=True)  # pretrained ImageNet ResNet-18

        # Remove linear and pool layers (since we're not doing classification)
        modules = list(resnet.children())[:-2]
        self.resnet = nn.Sequential(*modules).to(device)

        # Resize image to fixed size to allow input images of variable size
        self.adaptive_pool = nn.AdaptiveAvgPool2d((encoded_image_size, encoded_image_size))

    def forward(self, images):
        """
        args:
            images: images, a tensor of dimensions (batch_size, 3, image_size, image_size)
        return: encoded images
        """
        images = images.float()  # TODO remove this after adding normalization to the transforms
        out = self.resnet(images)
        out = self.adaptive_pool(out)  # (batch_size, 512, enc_image_size, enc_image_size)
        out = out.permute(0, 2, 3, 1)  #  (batch_size, enc_image_size, enc_image_size, 512)
        return out


class MLP(nn.Module):
    def __init__(self, in_features=512, out_features=4, hidden_dim=512):
        super().__init__()
        op_list = []
        op_list.append(nn.Linear(in_features, hidden_dim))
        op_list.append(nn.ReLU())
        op_list.append(nn.Linear(hidden_dim, hidden_dim))
        op_list.append(nn.ReLU())
        op_list.append(nn.Linear(hidden_dim, out_features))
        op_list.append(nn.ReLU())
        self.mlp = nn.Sequential(*op_list)

    def forward(self, x):
        return self.mlp(x)


class TableFormer(nn.Module):
    def __init__(
            self,
            d_model=512,
            enc_nhead=4,
            enc_dim_feedforward=1024,
            enc_num_layers=2,
            dec_nhead=4,
            dec_dim_feedforward=1024,
            dec_num_layers=4,
            max_len=512,
            device='cuda'
    ):
        super().__init__()
        self.feature_extractor = FeatureExtractor(device=device)

        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=enc_nhead,
                                                   dim_feedforward=enc_dim_feedforward, dropout=0.5)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=enc_num_layers)

        decoder_layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=dec_nhead,
                                                   dim_feedforward=dec_dim_feedforward, dropout=0.5)
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=dec_num_layers)
        self.linear = nn.Linear(in_features=d_model, out_features=len(PREDEFINED_SET))
        self.bbox_decoder_linear = nn.Linear(in_features=d_model, out_features=d_model)
        self.softmax = nn.Softmax(dim=2)  # TODO (l, b, len(vocab))
        self.sigmoid = nn.Sigmoid()

        self.attention = nn.MultiheadAttention(embed_dim=d_model, num_heads=1, dropout=0.5)

        self.mlp = MLP(in_features=d_model, out_features=4, hidden_dim=512)
        self.cls = nn.Linear(in_features=d_model, out_features=2)

        self.src_pe = Summer(PositionalEncoding2D(d_model))
        self.target_pe = Summer(PositionalEncoding1D(d_model))

        self.src_embed = nn.Embedding(num_embeddings=len(PREDEFINED_SET), embedding_dim=d_model)
        self.tgt_embed = nn.Embedding(num_embeddings=len(PREDEFINED_SET), embedding_dim=d_model)

        self.tgt_mask = nn.Transformer.generate_square_subsequent_mask(max_len + 1)

        self.pad_idx = PREDEFINED_SET.index("<pad>")
        self.td_a_idx = PREDEFINED_SET.index("<td>")
        self.td_b_idx = PREDEFINED_SET.index("<td")
        self.start_idx = PREDEFINED_SET.index("<start>")
        self.end_idx = PREDEFINED_SET.index("<end>")

        self.backbone_params = list(self.feature_extractor.parameters())
        self.structure_decoder_params = list(self.encoder.parameters()) + list(self.decoder.parameters()) + \
                                          list(self.linear.parameters()) + list(self.softmax.parameters())
        self.bbox_decoder_params = list(self.bbox_decoder_linear.parameters()) + list(self.attention.parameters()) + \
                                     list(self.mlp.parameters()) + list(self.cls.parameters()) + list(self.sigmoid.parameters())

    def forward(self, img, y_tags, y_bboxes, y_classes):
        # backbone part
        x_f = self.feature_extractor(img)  # x_f : (b, h, w, c)
        x_f = x_f.permute(1, 2, 0, 3)  # (b, h, w, c) => (h, w, b, c)
        x_pe = self.src_pe(x_f)
        x_pe = x_pe.view(x_pe.shape[0] * x_pe.shape[1], -1, x_pe.shape[3])  # (l, b, c) (784, 1, 512)
        # structure decoder part
        enc_out = self.encoder(x_pe)  # (28*28, 1, 512)
        if y_tags is None:
            y_src = torch.tensor([self.start_idx]).unsqueeze(1).to(enc_out.device)
            terminal = False
            while not terminal:
                padding_mask = self.get_padding_mask(y_src[-1])
                dec_out = self.decoder(
                    tgt=self.target_pe(self.tgt_embed(y_src)),
                    memory=enc_out,
                    tgt_mask=None,
                    tgt_key_padding_mask=padding_mask.permute(1, 0),
                    memory_mask=None)
                dec_out_idx = torch.argmax(self.softmax(self.linear(dec_out)), dim=2)
                if dec_out_idx[-1] == self.end_idx:
                    terminal = True
                    break
                y_src = torch.cat([y_src, dec_out_idx[-1].unsqueeze(0)], dim=0)


        else:
            y_src = y_tags[:, :-1].permute(1, 0) # (seq_len, batch_size)
            # y_src = y_src[(y_src != self.pad_idx) & (y_src != self.end_idx)].unsqueeze(1)  # 
            # self.tgt_mask = nn.Transformer.generate_square_subsequent_mask(y_src.size(0))
            padding_mask = self.get_padding_mask(y_src)
            tgt_pe = self.target_pe(self.tgt_embed(y_src))
            dec_out = self.decoder(
                tgt=tgt_pe,
                memory=enc_out,
                tgt_mask=self.tgt_mask.to(enc_out.device),
                tgt_key_padding_mask=padding_mask.permute(1, 0),
                memory_mask=None)
            dec_out_ = self.get_tds(y_src, dec_out)
            pred_tags = self.softmax(self.linear(dec_out))  # (seq_len, batch_size, len(vocab))  
            # print(torch.argmax(pred_tags, dim=-1))
            # cell bbox decoder part
            linear_out = self.bbox_decoder_linear(enc_out)
            attn_out, _ = self.attention(dec_out_, linear_out, enc_out)  # q, k, v,
            attn_out = self.sigmoid(attn_out)
            pred_boxes, pred_clses = self.mlp(attn_out), self.cls(attn_out)
            return enc_out, pred_tags, pred_boxes, pred_clses

    def get_tds(self, src, dec_out):
        mask_a = src == self.td_a_idx
        mask_b = src == self.td_b_idx
        mask = mask_a + mask_b
        for i in range(dec_out.shape[1]):
            if i == 0:
                dec_out_ = dec_out[:-1, 0:1, :][mask[1:, 0]]
            else:
                dec_out_ = torch.cat([dec_out_, dec_out[:-1, i:i + 1, :][mask[1:, i]]], dim=1)
        return dec_out_

    def get_padding_mask(self, src):
        padding_mask = torch.zeros(src.shape, dtype=torch.bool).to(src.device)
        padding_mask[src == self.pad_idx] = True
        return padding_mask
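
For reference, these are the tensor shapes I expect at each stage during training (assuming batch_size=1, the default 28x28 encoded feature map, max_len=512, and the 32-token vocabulary below):

# x_f after permute(1, 2, 0, 3):        (28, 28, 1, 512)
# x_pe after view:                      (784, 1, 512)
# enc_out:                              (784, 1, 512)
# y_src = y_tags[:, :-1].permute(1, 0): (513, 1)
# dec_out:                              (513, 1, 512)
# pred_tags:                            (513, 1, 32)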

Tokenization code:

from typing import List, Union
import torch

PREDEFINED_SET = [" colspan=\"10\"", " colspan=\"2\"", " colspan=\"3\"", " colspan=\"4\"", " colspan=\"5\"",
                  " colspan=\"6\"", " colspan=\"7\"", " colspan=\"8\"", " colspan=\"9\"", " rowspan=\"10\"",
                  " rowspan=\"2\"", " rowspan=\"3\"", " rowspan=\"4\"", " rowspan=\"5\"", " rowspan=\"6\"",
                  " rowspan=\"7\"", " rowspan=\"8\"", " rowspan=\"9\"", "</tbody>", "</td>", "</thead>",
                  "</tr>", "<end>", "<pad>", "<start>", "<tbody>", "<td", "<td>", "<thead>", "<tr>", "<unk>", ">"]


class TagConverter:
    def __init__(self, char_set=PREDEFINED_SET, max_len=512):
        # tokens = ['[sos]', '[eos]']
        char_set = char_set  # + tokens
        self.char_dict = dict()
        self.char_dict.update({v: idx for idx, v in enumerate(char_set)})
        self.max_len = max_len

    def encode(self, src: List):
        assert isinstance(src, List)
        # add <sos>, <eos>, <pad>
        src = src + ["<end>"]
        while len(src) < self.max_len + 1:
            src.append("<pad>")
        src = ["<start>"] + src
        return self._encode(src)  # output sequence length is max_len + 2

    def _encode(self, input: Union[str, List]):
        if isinstance(input, List):
            encoded = []
            for i in input:
                encoded.append(self.char_dict[i])
            return torch.tensor(encoded, dtype=torch.long)
        return torch.tensor(self.char_dict[input], dtype=torch.long)

    def decode(self, encoded):
        tags = ""
        for i in encoded:
            tags += PREDEFINED_SET[i]
        return tags


def parse_tags(tags: List) -> List:
    idxs = []
    for tag in tags:
        assert tag in PREDEFINED_SET, f"{tag} not in predefined sets"
        idxs.append(PREDEFINED_SET.index(tag))
    return idxs
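
For reference, this is how the converter behaves on a short tag list (with the default max_len=512, every encoded target has length 514: <start>, the tags, <end>, then padding):

converter = TagConverter()
tags = ["<thead>", "<tr>", "<td>", "</td>", "</tr>", "</thead>"]
encoded = converter.encode(tags)
print(encoded.shape)  # torch.Size([514])
print(encoded[:9])    # tensor([24, 28, 29, 27, 19, 21, 20, 22, 23]) -> <start> ... <end> <pad>
print(converter.decode(encoded[:8].tolist()))  # <start><thead><tr><td></td></tr></thead><end>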

training_step code:


    def training_step(self, batch, batch_idx):
        backbone_opt, structure_decoder_opt, bbox_decoder_opt = self.optimizers()
    
        image, tags, boxes, classes = batch  # tags tokenized
        # tags:(1, 514), pred_tags:(dec_input_len,1,32)
        # tags: <start>...<end><pad>...
        enc_out, pred_tags, pred_boxes, pred_clses = self.model(image, tags, boxes, classes)
        print(torch.argmax(pred_tags, dim=2).permute(1, 0))
        # structure decode loss : cross-entropy
        l_s = F.cross_entropy(pred_tags.permute(1, 2, 0),  # TODO adjust class weights
                              tags[:, 1: len(pred_tags)+1],
                              # ignore_index=self.model.pad_idx,
                              # weight=self.get_CEloss_weights(tags)
                              )  # (b, vocab-1) : vocab : len([<start>,<end>,...])
        structure_decoder_opt.zero_grad()
        self.manual_backward(l_s, retain_graph=True)
        structure_decoder_opt.step()
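
The commented-out get_CEloss_weights is meant to re-weight the classes in the cross-entropy loss. A minimal hypothetical version (not my actual implementation, just the idea) that only down-weights the <pad> class could look like:

    def get_CEloss_weights(self, tags, pad_weight=0.01):
        # hypothetical sketch: uniform weight for every class except <pad>
        # assumes PREDEFINED_SET is imported here the same way as in the model file
        weights = torch.ones(len(PREDEFINED_SET), device=tags.device)
        weights[self.model.pad_idx] = pad_weight
        return weights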

Thanks a lot for helping.

Did I encode the target, calculate the loss, and feed the decoder input correctly, or did I make some other mistake in my code that could cause the predicted tags to be all pad tokens?
