Reputation: 1
I am reproducing the paper TableFormer: Table Structure Understanding with Transformers. It predicts a table image's HTML structure using a ResNet-18 backbone, a 2-layer transformer encoder, and a 4-layer transformer decoder. After tokenizing and padding the target HTML structure to a fixed length, I started training the model. But after a few epochs of training, the predicted tags are all pad tokens, as shown below, where 23 is the index of <pad> in the predefined token set. By the way, setting ignore_index=23 in the cross-entropy loss makes the predictions all 27 or 19, which are the indices of <td> and </td>.
tensor([[23, 23, ..., 23, 23]], device='cuda:0')
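For reference, mapping these indices back through PREDEFINED_SET (defined in the tokenization code below) confirms which tokens the model is collapsing to; this is just a quick sanity check on my vocabulary, assuming PREDEFINED_SET is in scope:

print(PREDEFINED_SET[23])  # "<pad>"  -> what the model predicts everywhere without ignore_index
print(PREDEFINED_SET[27])  # "<td>"   -> one of the two tokens it predicts with ignore_index=23
print(PREDEFINED_SET[19])  # "</td>"  -> the other one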
The model code, tokenization code, and training_step code are below. Model code (everything before the cell bbox decoder part is for structure prediction):
import math
import torch
from torch import nn, Tensor
import torchvision
from ..datas.utils import PREDEFINED_SET
from positional_encodings.torch_encodings import PositionalEncoding1D, PositionalEncoding2D, Summer
class FeatureExtractor(nn.Module):
    def __init__(self, encoded_image_size=28, device='cuda'):
        super(FeatureExtractor, self).__init__()
        self.enc_image_size = encoded_image_size
        resnet = torchvision.models.resnet18(pretrained=True)  # pretrained ImageNet ResNet-18
        # Remove linear and pool layers (since we're not doing classification)
        modules = list(resnet.children())[:-2]
        self.resnet = nn.Sequential(*modules).to(device)
        # Resize feature map to a fixed size to allow input images of variable size
        self.adaptive_pool = nn.AdaptiveAvgPool2d((encoded_image_size, encoded_image_size))

    def forward(self, images):
        """
        args:
            images: a tensor of dimensions (batch_size, 3, image_size, image_size)
        return: encoded images
        """
        images = images.float()  # TODO remove this after adding normalization in transforms
        out = self.resnet(images)
        out = self.adaptive_pool(out)  # (batch_size, 512, enc_image_size, enc_image_size)
        out = out.permute(0, 2, 3, 1)  # (batch_size, enc_image_size, enc_image_size, 512)
        return out
class MLP(nn.Module):
    def __init__(self, in_features=512, out_features=4, hidden_dim=512):
        super().__init__()
        op_list = []
        op_list.append(nn.Linear(in_features, hidden_dim))
        op_list.append(nn.ReLU())
        op_list.append(nn.Linear(hidden_dim, hidden_dim))
        op_list.append(nn.ReLU())
        op_list.append(nn.Linear(hidden_dim, out_features))
        op_list.append(nn.ReLU())
        self.mlp = nn.Sequential(*op_list)

    def forward(self, x):
        return self.mlp(x)
class TableFormer(nn.Module):
    def __init__(
        self,
        d_model=512,
        enc_nhead=4,
        enc_dim_feedforward=1024,
        enc_num_layers=2,
        dec_nhead=4,
        dec_dim_feedforward=1024,
        dec_num_layers=4,
        max_len=512,
        device='cuda'
    ):
        super().__init__()
        self.feature_extractor = FeatureExtractor(device=device)
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=enc_nhead,
                                                   dim_feedforward=enc_dim_feedforward, dropout=0.5)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=enc_num_layers)
        decoder_layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=dec_nhead,
                                                   dim_feedforward=dec_dim_feedforward, dropout=0.5)
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=dec_num_layers)
        self.linear = nn.Linear(in_features=d_model, out_features=len(PREDEFINED_SET))
        self.bbox_decoder_linear = nn.Linear(in_features=d_model, out_features=d_model)
        self.softmax = nn.Softmax(dim=2)  # TODO (l, b, len(vocab))
        self.sigmoid = nn.Sigmoid()
        self.attention = nn.MultiheadAttention(embed_dim=d_model, num_heads=1, dropout=0.5)
        self.mlp = MLP(in_features=d_model, out_features=4, hidden_dim=512)
        self.cls = nn.Linear(in_features=d_model, out_features=2)
        self.src_pe = Summer(PositionalEncoding2D(d_model))
        self.target_pe = Summer(PositionalEncoding1D(d_model))
        self.src_embed = nn.Embedding(num_embeddings=len(PREDEFINED_SET), embedding_dim=d_model)
        self.tgt_embed = nn.Embedding(num_embeddings=len(PREDEFINED_SET), embedding_dim=d_model)
        self.tgt_mask = nn.Transformer.generate_square_subsequent_mask(max_len + 1)
        self.pad_idx = PREDEFINED_SET.index("<pad>")
        self.td_a_idx = PREDEFINED_SET.index("<td>")
        self.td_b_idx = PREDEFINED_SET.index("<td")
        self.start_idx = PREDEFINED_SET.index("<start>")
        self.end_idx = PREDEFINED_SET.index("<end>")
        self.backbone_params = list(self.feature_extractor.parameters())
        self.structure_decoder_params = list(self.encoder.parameters()) + list(self.decoder.parameters()) + \
            list(self.linear.parameters()) + list(self.softmax.parameters())
        self.bbox_decoder_params = list(self.bbox_decoder_linear.parameters()) + list(self.attention.parameters()) + \
            list(self.mlp.parameters()) + list(self.cls.parameters()) + list(self.sigmoid.parameters())

    def forward(self, img, y_tags, y_bboxes, y_classes):
        # backbone part
        x_f = self.feature_extractor(img)  # x_f: (b, h, w, c)
        x_f = x_f.permute(1, 2, 0, 3)  # (b, h, w, c) => (h, w, b, c)
        x_pe = self.src_pe(x_f)
        x_pe = x_pe.view(x_pe.shape[0] * x_pe.shape[1], -1, x_pe.shape[3])  # (l, b, c) (784, 1, 512)
        # structure decoder part
        enc_out = self.encoder(x_pe)  # (28*28, 1, 512)
        if y_tags is None:
            # inference: greedy decoding, one token at a time, until <end>
            y_src = torch.tensor([self.start_idx]).unsqueeze(1).to(enc_out.device)
            terminal = False
            while not terminal:
                padding_mask = self.get_padding_mask(y_src[-1])
                dec_out = self.decoder(
                    tgt=self.target_pe(self.tgt_embed(y_src)),
                    memory=enc_out,
                    tgt_mask=None,
                    tgt_key_padding_mask=padding_mask.permute(1, 0),
                    memory_mask=None)
                dec_out_idx = torch.argmax(self.softmax(self.linear(dec_out)), dim=2)
                if dec_out_idx[-1] == self.end_idx:
                    terminal = True
                    break
                y_src = torch.cat([y_src, dec_out_idx[-1].unsqueeze(0)], dim=0)
        else:
            # training: teacher forcing with the target sequence shifted right
            y_src = y_tags[:, :-1].permute(1, 0)  # (seq_len, batch_size)
            # y_src = y_src[(y_src != self.pad_idx) & (y_src != self.end_idx)].unsqueeze(1)
            # self.tgt_mask = nn.Transformer.generate_square_subsequent_mask(y_src.size(0))
            padding_mask = self.get_padding_mask(y_src)
            tgt_pe = self.target_pe(self.tgt_embed(y_src))
            dec_out = self.decoder(
                tgt=tgt_pe,
                memory=enc_out,
                tgt_mask=self.tgt_mask.to(enc_out.device),
                tgt_key_padding_mask=padding_mask.permute(1, 0),
                memory_mask=None)
        dec_out_ = self.get_tds(y_src, dec_out)
        pred_tags = self.softmax(self.linear(dec_out))  # (seq_len, batch_size, len(vocab))
        # print(torch.argmax(pred_tags, dim=-1))
        # cell bbox decoder part
        linear_out = self.bbox_decoder_linear(enc_out)
        attn_out, _ = self.attention(dec_out_, linear_out, enc_out)  # q, k, v
        attn_out = self.sigmoid(attn_out)
        pred_boxes, pred_clses = self.mlp(attn_out), self.cls(attn_out)
        return enc_out, pred_tags, pred_boxes, pred_clses

    def get_tds(self, src, dec_out):
        # gather decoder outputs at "<td>" / "<td" positions for the bbox decoder
        mask_a = src == self.td_a_idx
        mask_b = src == self.td_b_idx
        mask = mask_a + mask_b
        for i in range(dec_out.shape[1]):
            if i == 0:
                dec_out_ = dec_out[:-1, 0:1, :][mask[1:, 0]]
            else:
                dec_out_ = torch.cat([dec_out_, dec_out[:-1, i:i + 1, :][mask[1:, i]]], dim=1)
        return dec_out_

    def get_padding_mask(self, src):
        # True at <pad> positions; same shape as src: (seq_len, batch_size)
        padding_mask = torch.zeros(src.shape, dtype=torch.bool).to(src.device)
        padding_mask[src == self.pad_idx] = True
        return padding_mask
Tokenization code:
from typing import List, Union
import torch
PREDEFINED_SET = [" colspan=\"10\"", " colspan=\"2\"", " colspan=\"3\"", " colspan=\"4\"", " colspan=\"5\"",
" colspan=\"6\"", " colspan=\"7\"", " colspan=\"8\"", " colspan=\"9\"", " rowspan=\"10\"",
" rowspan=\"2\"", " rowspan=\"3\"", " rowspan=\"4\"", " rowspan=\"5\"", " rowspan=\"6\"",
" rowspan=\"7\"", " rowspan=\"8\"", " rowspan=\"9\"", "</tbody>", "</td>", "</thead>",
"</tr>", "<end>", "<pad>", "<start>", "<tbody>", "<td", "<td>", "<thead>", "<tr>", "<unk>", ">"]
class TagConverter:
    def __init__(self, char_set=PREDEFINED_SET, max_len=512):
        # tokens = ['[sos]', '[eos]']
        char_set = char_set  # + tokens
        self.char_dict = dict()
        self.char_dict.update({v: idx for idx, v in enumerate(char_set)})
        self.max_len = max_len

    def encode(self, src: List):
        assert isinstance(src, List)
        # add <start>, <end>, <pad>
        src = src + ["<end>"]
        while len(src) < self.max_len + 1:
            src.append("<pad>")
        src = ["<start>"] + src
        return self._encode(src)  # output sequence length is max_len + 2

    def _encode(self, input: Union[str, List]):
        if isinstance(input, List):
            encoded = []
            for i in input:
                encoded.append(self.char_dict[i])
            return torch.tensor(encoded, dtype=torch.long)
        return torch.tensor(self.char_dict[input], dtype=torch.long)

    def decode(self, encoded):
        tags = ""
        for i in encoded:
            tags += PREDEFINED_SET[i]
        return tags

def parse_tags(tags: List) -> List:
    idxs = []
    for tag in tags:
        assert tag in PREDEFINED_SET, f"{tag} not in predefined sets"
        idxs.append(PREDEFINED_SET.index(tag))
    return idxs
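To double-check the encoding itself, this is the round trip I expect (a small sanity check, assuming the code above is in scope):

conv = TagConverter()

encoded = conv.encode(["<tr>", "<td>", "</td>", "</tr>"])
print(encoded.shape)  # torch.Size([514]) -> <start> + 4 tags + <end> + 508 <pad>
print(encoded[:7])    # tensor([24, 29, 27, 19, 21, 22, 23])
print(conv.decode(encoded[:6].tolist()))  # <start><tr><td></td></tr><end>

print(PREDEFINED_SET.index("<pad>"),  # 23 -- the index filling my predictions
      PREDEFINED_SET.index("<td>"),   # 27
      PREDEFINED_SET.index("</td>"))  # 19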
training_step code:
def training_step(self, batch, batch_idx):
    backbone_opt, structure_decoder_opt, bbox_decoder_opt = self.optimizers()
    image, tags, boxes, classes = batch  # tags already tokenized
    # tags: (1, 514), pred_tags: (dec_input_len, 1, 32)
    # tags: <start> ... <end> <pad> ...
    enc_out, pred_tags, pred_boxes, pred_clses = self.model(image, tags, boxes, classes)
    print(torch.argmax(pred_tags, dim=2).permute(1, 0))
    # structure decoder loss: cross-entropy
    l_s = F.cross_entropy(pred_tags.permute(1, 2, 0),  # TODO adjust the class weights
                          tags[:, 1: len(pred_tags) + 1],
                          # ignore_index=self.model.pad_idx,
                          # weight=self.get_CEloss_weights(tags)
                          )  # (b, vocab, seq_len); vocab = len([<start>, <end>, ...])
    structure_decoder_opt.zero_grad()
    self.manual_backward(l_s, retain_graph=True)
    structure_decoder_opt.step()
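For completeness, this is my understanding of the loss shapes here and of what ignore_index changes, as a standalone toy check (random values standing in for my real model outputs):

import torch
import torch.nn.functional as F

vocab, seq_len, pad_idx = 32, 513, 23
pred = torch.randn(1, vocab, seq_len)  # (batch, classes, seq_len), like pred_tags.permute(1, 2, 0)
target = torch.full((1, seq_len), pad_idx, dtype=torch.long)  # mostly <pad>, like my padded tags
target[0, :5] = torch.tensor([29, 27, 19, 21, 22])  # a few real tags: <tr> <td> </td> </tr> <end>

# The unweighted loss averages over all 513 positions, so the ~508 <pad>
# positions dominate; ignore_index drops them from the average entirely.
print(F.cross_entropy(pred, target))
print(F.cross_entropy(pred, target, ignore_index=pad_idx))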
Thanks a lot for helping.
Did I encode the targets, compute the loss, and feed the inputs correctly, or did I make any mistake in my code that could cause the predicted tags to be all pad tokens?
Upvotes: 0
Views: 21