I am trying to convert a TensorFlow model to PyTorch, but I am stuck on the multi-head attention layer because of the tensor dimensions.
The input tensors to the attention layer are 4-D, but PyTorch's nn.MultiheadAttention only accepts 2-D (unbatched) or 3-D (batched) tensors.
The shapes of the tensors are as follows:
tmp_page_embed: torch.Size([64, 16, 1, 256])
offset_embed: torch.Size([64, 16, 100, 256])
The original TensorFlow source code is as follows:
tmp_page_embed = tf.reshape(page_embed, shape=(-1, self.sequence_length, 1, self.page_embed_size))
offset_embed = tf.reshape(offset_embed, shape=(-1, self.sequence_length, self.offset_embed_size // self.page_embed_size, self.page_embed_size))
offset_embed = tf.reshape(self.mha(tmp_page_embed, offset_embed, training=training), shape=(-1, self.sequence_length, self.page_embed_size))
This is my attempt to convert it to PyTorch:
tmp_offset_embed, _ = self.mha(tmp_page_embed, offset_embed, offset_embed)
offset_embed = tmp_offset_embed.reshape(-1, self.sequence_length, self.page_embed_size)
The error message is as follows:
AssertionError: query should be unbatched 2D or batched 3D tensor but received 4-D query tensor
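For reference, I can reproduce the same assertion with a standalone snippet that only uses the shapes above (embed_dim=256 matches my page_embed_size; num_heads=8 is just a placeholder, not a value from my model):

```python
import torch
import torch.nn as nn

# Standalone reproduction of the assertion above.
mha = nn.MultiheadAttention(embed_dim=256, num_heads=8, batch_first=True)

q = torch.randn(64, 16, 1, 256)     # like tmp_page_embed: 4-D query
kv = torch.randn(64, 16, 100, 256)  # like offset_embed: 4-D key/value

out, _ = mha(q, kv, kv)  # AssertionError: query should be unbatched 2D or batched 3D tensor ...
```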
How should I fix this error? Is there a way to convert the 4-D tensors to 3-D so that nn.MultiheadAttention accepts them?
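One idea I had, sketched below, is to fold the batch and sequence dimensions into a single batch dimension, so the query becomes (B*S, 1, E) and the key/value becomes (B*S, 100, E), run the 3-D tensors through nn.MultiheadAttention, and reshape the result back to (B, S, E). The num_heads=8 and batch_first=True settings are my own assumptions; they are not in the original code.

```python
import torch
import torch.nn as nn

batch_size, sequence_length = 64, 16
page_embed_size = 256
num_chunks = 100  # offset_embed_size // page_embed_size

# num_heads=8 is an assumption; any value that divides page_embed_size would do.
mha = nn.MultiheadAttention(embed_dim=page_embed_size, num_heads=8, batch_first=True)

tmp_page_embed = torch.randn(batch_size, sequence_length, 1, page_embed_size)
offset_embed = torch.randn(batch_size, sequence_length, num_chunks, page_embed_size)

# Fold (batch, sequence) into one batch dimension: (B*S, L, E)
q = tmp_page_embed.reshape(-1, 1, page_embed_size)          # (1024, 1, 256)
kv = offset_embed.reshape(-1, num_chunks, page_embed_size)  # (1024, 100, 256)

attn_out, _ = mha(q, kv, kv)                                # (1024, 1, 256)

# Reshape back to (batch, sequence, embed), matching the final tf.reshape
offset_embed = attn_out.reshape(-1, sequence_length, page_embed_size)  # (64, 16, 256)
```

I am not sure whether collapsing the dimensions like this reproduces exactly what the Keras layer does with rank-4 inputs, which is part of why I am asking.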