Doru

Reputation: 1

Multi-head attention for 4-D tensors in PyTorch

I am trying to port TensorFlow code to PyTorch, but I am having trouble with multi-head attention because of the tensor dimensions.

The input tensors to the multi-head attention (MHA) layer are 4-D, but PyTorch's nn.MultiheadAttention does not accept 4-D tensors as input.

The tensor shapes are:

tmp_page_embed:  torch.Size([64, 16, 1, 256])  
offset_embed:  torch.Size([64, 16, 100, 256]) 

The original TensorFlow code is:

tmp_page_embed = tf.reshape(page_embed, shape=(-1, self.sequence_length, 1, self.page_embed_size))
offset_embed = tf.reshape(offset_embed, shape=(-1, self.sequence_length, self.offset_embed_size // self.page_embed_size, self.page_embed_size)) 
offset_embed = tf.reshape(self.mha(tmp_page_embed, offset_embed, training=training), shape=(-1, self.sequence_length, self.page_embed_size))
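For reference, the two tf.reshape calls that build the 4-D inputs map one-to-one onto torch.Tensor.reshape, which also accepts -1 for an inferred dimension. This is a minimal sketch, assuming hypothetical 3-D inputs page_embed of shape (64, 16, 256) and offset_embed of shape (64, 16, 25600), i.e. offset_embed_size = 100 * page_embed_size; these sizes are chosen only to reproduce the shapes printed above.

```python
import torch

# Hypothetical sizes chosen to reproduce the shapes printed in the question.
batch_size = 64
sequence_length = 16
page_embed_size = 256
offset_embed_size = 100 * page_embed_size  # assumption: 25600

# Hypothetical inputs; reshape only cares about the total number of elements.
page_embed = torch.randn(batch_size, sequence_length, page_embed_size)
offset_embed = torch.randn(batch_size, sequence_length, offset_embed_size)

# tf.reshape(page_embed, shape=(-1, sequence_length, 1, page_embed_size))
tmp_page_embed = page_embed.reshape(-1, sequence_length, 1, page_embed_size)

# tf.reshape(offset_embed, shape=(-1, sequence_length,
#            offset_embed_size // page_embed_size, page_embed_size))
offset_embed = offset_embed.reshape(
    -1, sequence_length, offset_embed_size // page_embed_size, page_embed_size
)

print(tmp_page_embed.shape)  # torch.Size([64, 16, 1, 256])
print(offset_embed.shape)    # torch.Size([64, 16, 100, 256])
```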

This is my attempt to convert it to PyTorch:

tmp_offset_embed, _ = self.mha(tmp_page_embed, offset_embed, offset_embed)
offset_embed = tmp_offset_embed.reshape(-1, self.sequence_length, self.page_embed_size)

The error message is:

AssertionError: query should be unbatched 2D or batched 3D tensor but received 4-D query tensor     
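The assertion comes from the input-rank check in torch.nn.MultiheadAttention: the query must be 2-D with shape (L, E) when unbatched, or 3-D when batched, where the batched layout is (L, N, E) by default or (N, L, E) with batch_first=True. A minimal reproduction, assuming embed_dim=256 and a hypothetical num_heads=8 (the original head count is not shown). The 3-D call at the end uses only the first sequence step, purely to show that the rank requirement is the only problem.

```python
import torch
import torch.nn as nn

embed_dim = 256
num_heads = 8  # hypothetical; the question does not show the head count

mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

# 4-D tensors with the shapes from the question
tmp_page_embed = torch.randn(64, 16, 1, embed_dim)
offset_embed = torch.randn(64, 16, 100, embed_dim)

try:
    mha(tmp_page_embed, offset_embed, offset_embed)
    raised = False
except AssertionError as err:
    raised = True
    print(err)  # the rank-check message quoted in the question

# The same layer accepts 3-D (N, L, E) inputs; here only sequence step 0 is used.
out, weights = mha(tmp_page_embed[:, 0], offset_embed[:, 0], offset_embed[:, 0])
print(out.shape)      # torch.Size([64, 1, 256])
print(weights.shape)  # torch.Size([64, 1, 100]), averaged over heads
```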

How can I fix this error? Is there a way to convert the 4-D tensors to 3-D?
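One possible approach, sketched under assumptions: nn.MultiheadAttention supports exactly one batch axis, so you can merge the extra leading axes into the batch or length axis with reshape, run attention, and reshape the result back. Which merge is correct depends on how self.mha was constructed in Keras, which the question does not show. With the Keras default attention_axes=None, attention runs over axes 1 and 2 jointly, so each query position attends to all 16 * 100 offset embeddings of its sample. With attention_axes=(2,), each of the 64 * 16 page embeddings attends only to its own 100 offsets, which is the same as folding the batch and sequence axes into one batch axis. The sketch below shows both. The class name OffsetAttention, the joint flag and num_heads=8 are hypothetical. The Keras training argument has no counterpart in the call; in PyTorch, dropout follows module.train() / module.eval().

```python
import torch
import torch.nn as nn


class OffsetAttention(nn.Module):
    """Hypothetical wrapper around nn.MultiheadAttention for 4-D inputs."""

    def __init__(self, page_embed_size=256, num_heads=8, sequence_length=16):
        super().__init__()
        self.sequence_length = sequence_length
        self.page_embed_size = page_embed_size
        # batch_first=True: inputs and outputs are (batch, length, embed_dim)
        self.mha = nn.MultiheadAttention(page_embed_size, num_heads, batch_first=True)

    def forward(self, tmp_page_embed, offset_embed, joint=False):
        # tmp_page_embed: (B, S, Q, E) with Q == 1; offset_embed: (B, S, O, E)
        B, S, O, E = offset_embed.shape
        if joint:
            # Like Keras attention_axes=None: axes 1 and 2 are flattened, so every
            # query position attends to all S * O offset embeddings of its sample.
            query = tmp_page_embed.reshape(B, -1, E)      # (B, S * Q, E)
            key = offset_embed.reshape(B, S * O, E)       # (B, S * O, E)
        else:
            # Like Keras attention_axes=(2,): fold batch and sequence into one
            # batch axis, so each page embedding attends only to its own O offsets.
            query = tmp_page_embed.reshape(B * S, -1, E)  # (B * S, Q, E)
            key = offset_embed.reshape(B * S, O, E)       # (B * S, O, E)
        # key doubles as value, as in the question's attempt
        attn_out, _ = self.mha(query, key, key)
        # With Q == 1 both branches yield B * S vectors of size E; restore (B, S, E).
        return attn_out.reshape(-1, self.sequence_length, self.page_embed_size)


layer = OffsetAttention()
tmp_page_embed = torch.randn(64, 16, 1, 256)
offset_embed = torch.randn(64, 16, 100, 256)

per_step = layer(tmp_page_embed, offset_embed)           # attends over the 100 offsets
joint = layer(tmp_page_embed, offset_embed, joint=True)  # attends over all 16 * 100 keys
print(per_step.shape, joint.shape)  # torch.Size([64, 16, 256]) torch.Size([64, 16, 256])
```

One more porting detail: Keras MultiHeadAttention(num_heads, key_dim) uses key_dim as the per-head size, whereas nn.MultiheadAttention splits embed_dim across num_heads (head size embed_dim // num_heads). If you want the same projection sizes as the original model, choose num_heads so that num_heads * key_dim == 256.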

Upvotes: 0

Views: 82

Answers (0)
