I am trying to write an image captioning model: I use a CNN to extract the image features and then connect them to BERT and an MLP to generate a long paragraph. However, after training, my model's results are really bad. In my MLP part I want to add an LSTM to the different features and then concatenate them, but I don't know where I should add it. Can anyone give me some advice? Thank you!
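Roughly, this is the kind of thing I am picturing when I say "add an LSTM to the features and then concatenate" (just a sketch with placeholder names like image_feats / text_feats, not my real code, and I am not sure this is the right place for the LSTM):

import tensorflow as tf

lstm = tf.keras.layers.LSTM(256, return_sequences=True)
proj = tf.keras.layers.Dense(256)

def fuse(image_feats, text_feats):
    # image_feats: (batch, regions, cnn_dim) from the CNN encoder
    # text_feats:  (batch, seq_len, embed_dim) from BERT
    text_seq = lstm(text_feats)                    # run the LSTM over one feature stream
    img_seq = proj(image_feats)                    # project the other to the same width
    return tf.concat([img_seq, text_seq], axis=1)  # concatenate along the sequence axis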
My BERT decoder part:
import tensorflow as tf
from tensorflow import keras
from transformers import TFBertModel

class bert_decoder(tf.keras.layers.Layer):
    def __init__(self, vocab_size, embed_dim, max_len):
        super().__init__()
        # frozen BERT used to embed the caption tokens
        self.bert_model = TFBertModel.from_pretrained("bert-base-cased", trainable=False)
        self.ffn_layer1 = keras.layers.Dense(1000)
        # EMBEDDING_DIM is a global constant; it has to equal embed_dim for the reshape below
        self.ffn_layer2 = keras.layers.Dense(EMBEDDING_DIM * (max_len - 1))
        self.flatten = keras.layers.Flatten()
        self.reshape = keras.layers.Reshape((max_len - 1, embed_dim))

    def call(self, input_ids):
        bert_out = self.bert_model(input_ids)[0]      # last hidden state: (batch, seq_len, 768)
        bert_out_flatten = self.flatten(bert_out)     # (batch, seq_len * 768)
        bert_out = self.ffn_layer1(bert_out_flatten)
        bert_out = self.ffn_layer2(bert_out)
        bert_out = self.reshape(bert_out)             # (batch, max_len - 1, embed_dim)
        return bert_out
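This is how I sanity-check it in isolation (a sketch: here I use a plain Hugging Face BertTokenizer and made-up MAX_LENGTH / EMBEDDING_DIM values, which are not necessarily what my real pipeline uses):

from transformers import BertTokenizer

MAX_LENGTH = 50        # example value
EMBEDDING_DIM = 512    # example value; must match the embed_dim passed to the layer

hf_tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
ids = hf_tokenizer(
    ["a man riding a horse on the beach"],
    padding="max_length", truncation=True,
    max_length=MAX_LENGTH - 1, return_tensors="tf",
)["input_ids"]

embed_layer = bert_decoder(hf_tokenizer.vocab_size, EMBEDDING_DIM, MAX_LENGTH)
emb = embed_layer(ids)
print(emb.shape)   # (1, MAX_LENGTH - 1, EMBEDDING_DIM)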
My MLP part (the Transformer decoder layer):
class TransformerDecoderLayer(tf.keras.layers.Layer):
    def __init__(self, embed_dim, units, num_heads):
        super().__init__()
        # BERT-based embedding of the caption tokens (tokenizer and MAX_LENGTH are globals)
        self.embedding = bert_decoder(
            tokenizer.vocabulary_size(), embed_dim, MAX_LENGTH)
        self.attention_1 = tf.keras.layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim, dropout=0.1
        )
        self.attention_2 = tf.keras.layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim, dropout=0.1
        )
        self.layernorm_1 = tf.keras.layers.LayerNormalization()
        self.layernorm_2 = tf.keras.layers.LayerNormalization()
        self.layernorm_3 = tf.keras.layers.LayerNormalization()
        self.ffn_layer_1 = tf.keras.layers.Dense(units, activation="relu")
        self.ffn_layer_2 = tf.keras.layers.Dense(embed_dim)
        self.out = tf.keras.layers.Dense(tokenizer.vocabulary_size(), activation="softmax")
        self.dropout_1 = tf.keras.layers.Dropout(0.3)
        self.dropout_2 = tf.keras.layers.Dropout(0.5)

    def call(self, input_ids, encoder_output, training, mask=None):
        embeddings = self.embedding(input_ids)
        combined_mask = None
        padding_mask = None
        if mask is not None:
            # combine the causal mask (no attending to future tokens) with the padding mask
            causal_mask = self.get_causal_attention_mask(embeddings)
            padding_mask = tf.cast(mask[:, :, tf.newaxis], dtype=tf.int32)
            combined_mask = tf.cast(mask[:, tf.newaxis, :], dtype=tf.int32)
            combined_mask = tf.minimum(combined_mask, causal_mask)
        # masked self-attention over the caption embeddings
        attn_output_1 = self.attention_1(
            query=embeddings,
            value=embeddings,
            key=embeddings,
            attention_mask=combined_mask,
            training=training
        )
        out_1 = self.layernorm_1(embeddings + attn_output_1)
        # cross-attention: caption positions attend to the CNN image features
        attn_output_2 = self.attention_2(
            query=out_1,
            value=encoder_output,
            key=encoder_output,
            attention_mask=padding_mask,
            training=training
        )
        out_2 = self.layernorm_2(out_1 + attn_output_2)
        # position-wise feed-forward block with residual connection
        ffn_out = self.ffn_layer_1(out_2)
        ffn_out = self.dropout_1(ffn_out, training=training)
        ffn_out = self.ffn_layer_2(ffn_out)
        ffn_out = self.layernorm_3(ffn_out + out_2)
        ffn_out = self.dropout_2(ffn_out, training=training)
        # per-position distribution over the vocabulary
        preds = self.out(ffn_out)
        return preds

    def get_causal_attention_mask(self, inputs):
        # lower-triangular (batch, seq_len, seq_len) mask for autoregressive decoding
        input_shape = tf.shape(inputs)
        batch_size, sequence_length = input_shape[0], input_shape[1]
        i = tf.range(sequence_length)[:, tf.newaxis]
        j = tf.range(sequence_length)
        mask = tf.cast(i >= j, dtype="int32")
        mask = tf.reshape(mask, (1, input_shape[1], input_shape[1]))
        mult = tf.concat(
            [tf.expand_dims(batch_size, -1), tf.constant([1, 1], dtype=tf.int32)],
            axis=0
        )
        return tf.tile(mask, mult)
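And this is a quick shape check I run on the whole decoder layer, with random tensors standing in for the CNN output (a sketch; the 8×EMBEDDING_DIM "image features" and the all-ones mask are made up, and it assumes the tokenizer, MAX_LENGTH and EMBEDDING_DIM globals used inside the classes are already defined):

decoder_layer = TransformerDecoderLayer(embed_dim=EMBEDDING_DIM, units=512, num_heads=2)

token_ids = tf.ones((1, MAX_LENGTH - 1), dtype=tf.int32)   # dummy caption token ids
image_feats = tf.random.normal((1, 8, EMBEDDING_DIM))      # dummy CNN encoder output
pad_mask = tf.ones((1, MAX_LENGTH - 1), dtype=tf.int32)    # all caption positions valid

preds = decoder_layer(token_ids, image_feats, training=False, mask=pad_mask)
print(preds.shape)   # (1, MAX_LENGTH - 1, vocab_size)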