kowser66

Reputation: 175

How to use pytorch multi-head attention for classification task?

I have a dataset where x has shape (10000, 102, 300), i.e. (samples, sequence-length, embedding-dimension), and y has shape (10000,) and holds my binary labels. I want to use multi-head attention in PyTorch. I saw the PyTorch documentation here, but there is no explanation of how to use it. How can I use multi-head attention for classification on my dataset?

Upvotes: 2

Views: 1323

Answers (1)

Mohammad Ahmed

Reputation: 1624

I will write some simple code for classification that should work fine. If you want the implementation details, this is the same as the encoder layer in a Transformer, except that at the end you need a global-average-pooling layer and a linear (dense) layer for classification:

# embed_dim (300) must be divisible by num_of_heads
attention_layer = nn.MultiheadAttention(embed_dim=300, num_heads=num_of_heads, dropout=0.1)
attn_output, _ = attention_layer(x, x, x)             # self-attention
ffn_output = point_wise_feed_forward_network(attn_output)
normalized = layer_norm(x + ffn_output)               # residual connection + LayerNorm
pooled = normalized.mean(dim=1)                       # global average pooling over the sequence
logits = nn.Linear(300, num_of_classes)(pooled)
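Putting the steps above together, here is a minimal runnable sketch as a single module. The class name `AttentionClassifier`, the choice of `num_heads=6`, the hidden size 512 in the feed-forward network, and using only one encoder block are my assumptions, not prescribed by the question:

```python
import torch
import torch.nn as nn

class AttentionClassifier(nn.Module):
    def __init__(self, embed_dim=300, num_heads=6, num_classes=2, dropout=0.1):
        super().__init__()
        # embed_dim must be divisible by num_heads (300 % 6 == 0)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads,
                                          dropout=dropout, batch_first=True)
        self.ffn = nn.Sequential(              # point-wise feed-forward network
            nn.Linear(embed_dim, 512),
            nn.ReLU(),
            nn.Linear(512, embed_dim),
        )
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):                      # x: (batch, seq_len, embed_dim)
        attn_out, _ = self.attn(x, x, x)       # self-attention
        x = self.norm1(x + attn_out)           # residual + LayerNorm
        x = self.norm2(x + self.ffn(x))        # residual + LayerNorm
        x = x.mean(dim=1)                      # global average pooling over sequence
        return self.head(x)                    # (batch, num_classes)

model = AttentionClassifier()
logits = model(torch.randn(8, 102, 300))       # 8 samples shaped like your data
print(logits.shape)                            # torch.Size([8, 2])
```

For your binary labels of shape (10000,), you can train this with `nn.CrossEntropyLoss` on the two-class logits (or set `num_classes=1` and use `nn.BCEWithLogitsLoss`). Note that `batch_first=True` requires PyTorch 1.9 or newer.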

Upvotes: 4

Related Questions