Rick Vink

Reputation: 333

Attention does not seem to be applied in PyTorch's TransformerEncoderLayer and MultiheadAttention

Changing the input at one position does not affect the outputs at the other positions of my transformer encoder. I made an isolated test in PyTorch:

# Imports needed to run this snippet
import numpy as np
import torch
import torch.nn as nn

# My encoder layer
encoder_layer = nn.TransformerEncoderLayer(d_model=8, nhead=2)
# Turn off dropout
encoder_layer.eval()
# Random input
src = torch.rand(2, 10, 8)
# Predict the output
out_0 = encoder_layer(src)
# Change the values at one of the positions (position 3 in this case)
src[:,3,:] += 1
# Predict once again the output
out_1 = encoder_layer(src)
# Check at which positions the outcomes are different between the two cases
# I summed in the embedding space direction
print(np.sum(np.abs(out_0.detach().numpy()),axis=-1) - np.sum(np.abs(out_1.detach().numpy()),axis=-1))

Output:

[[ 0.          0.          0.         -0.15470695  0.          0.
   0.          0.          0.          0.        ]
 [ 0.          0.          0.         -0.27988768  0.          0.
   0.          0.          0.          0.        ]]

However, this does work in TensorFlow:

# Imports needed to run this snippet
import numpy as np

# My encoder layer
# (TransformerBlock is assumed to be the custom Keras encoder block from the
#  Keras "Text classification with Transformer" example: embed_dim, num_heads, ff_dim)
encoder_layer = TransformerBlock(8, 2, 8)
# Random input
src = np.random.randn(2, 10, 8)
# Predict the output
out_0 = encoder_layer(src, training=False)
# Change the values at one of the positions (position 3 in this case)
src[:,3,:] += 1
# Predict once again the output
out_1 = encoder_layer(src, training=False)
# Check at which positions the outcomes are different between the two cases
# I summed in the embedding space direction
print(np.sum(np.abs(out_0), axis=-1))

Output:

[[6.4196725 6.775745  6.946576  7.26213   6.473065  5.520765  6.201167
  7.1266503 6.3147016 6.614853 ]
 [5.565378  7.030789  6.768366  6.6065626 6.7277775 7.480627  6.6785836
  6.4560523 6.4248576 6.6436586]]

My question is: why aren't the values at all the other positions affected by changing the input at one position in PyTorch?

Upvotes: 0

Views: 634

Answers (1)

KonstantinosKokos

Reputation: 3453

From the documentation:

  • batch_first – If True, then the input and output tensors are provided as (batch, seq, feature). Default: False.

In other words, your input is interpreted as a batch of 10 sequences, each of length 2 with 8-dimensional features. What you are doing is adding 1 to every feature of every time step of sample #4 in the batch, which, unsurprisingly, only alters the outputs of that specific sample.
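
A minimal sketch of a fix (my own illustration, not part of the original answer), assuming PyTorch 1.9+ where TransformerEncoderLayer accepts batch_first; on older versions you can instead pass the layer a transposed (seq, batch, feature) tensor:

import torch
import torch.nn as nn

# Construct the layer so it interprets the input as (batch, seq, feature),
# matching the (2, 10, 8) tensor from the question
encoder_layer = nn.TransformerEncoderLayer(d_model=8, nhead=2, batch_first=True)
encoder_layer.eval()

src = torch.rand(2, 10, 8)
with torch.no_grad():
    out_0 = encoder_layer(src)
    # Perturb position 3 of every sequence
    src[:, 3, :] += 1
    out_1 = encoder_layer(src)

# Every position now differs, because self-attention mixes position 3
# into all other positions of the same sequence
print((out_0 - out_1).abs().sum(dim=-1))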

Upvotes: 1
