Reputation: 11
I am trying to create multi-head attention using TensorFlow.js. When I try to train the model, an error keeps popping up saying that the gradient shape is inconsistent with the input shape.
Reproducible example of the bug:
// assuming Node.js here; in a browser, tf comes from the tfjs script tag instead
const tf = require("@tensorflow/tfjs");
const EMBEDDING_DIMENSION = 16;
const NUMBER_OF_HEADS = 2;
const BATCH_SIZE = 8;
const SEQUENCES_LENGTH = 4;
const EPOCHS = 10;
tf.setBackend("cpu");
tf.enableDebugMode();
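// scaled dot-product attention: softmax(Q * K^T / sqrt(dk)) * V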
function dotProductAttention(query, key, value) {
return tf.tidy(() => {
const matMul = query.matMul(key, false, true);
const dk = tf.scalar(query.shape[query.shape.length - 1], 'float32');
const scaled = matMul.div(tf.sqrt(dk));
const attentionWeights = scaled.softmax(-1);
return attentionWeights.matMul(value);
})
}
function multiHeadAttention(input, numberOfHeads, qkvWeights, outputWeights) {
return tf.tidy(() => {
const batchSize = input.shape[0];
const sequenceLength = input.shape[1];
const embeddingDimension = qkvWeights.shape[0];
const headDimension = embeddingDimension / numberOfHeads;
// calculate all queries, keys, and values with one big matrix multiplication.
const QKV = input.matMul(qkvWeights);
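// QKV shape: [batchSize, sequenceLength, 3 * embeddingDimension]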
// reshape QKV to include the number of heads
const QKVReshaped = QKV
.reshape([batchSize, sequenceLength, numberOfHeads, 3 * headDimension])
.transpose([0, 2, 1, 3]);
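// QKVReshaped shape: [batchSize, numberOfHeads, sequenceLength, 3 * headDimension]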
// split the combined QKV into three chunks: the query, key, and value matrices
const [query, key, value] = QKVReshaped.split(3, -1);
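// query, key, and value each: [batchSize, numberOfHeads, sequenceLength, headDimension]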
// calculate the attention of query, key, and value
const attention = dotProductAttention(query, key, value);
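// attention shape: [batchSize, numberOfHeads, sequenceLength, headDimension]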
// concatenate the attention values
const concatenated = attention
.transpose([0, 2, 1, 3])
.reshape([batchSize, sequenceLength, embeddingDimension]);
// matrix multiply with output weights
const output = concatenated.matMul(outputWeights);
// finally, return it: [batchSize, sequenceLength, embeddingDimension]
return output;
})
}
function initializeQKVWeights(embeddingDimension) {
return tf.variable(tf.randomNormal([embeddingDimension, 3 * embeddingDimension]));
}
function initializeOutputWeights(embeddingDimension) {
return tf.variable(tf.randomNormal([embeddingDimension, embeddingDimension]));
}
// query, key, and value weights combined for better performance.
const QKVWeights = initializeQKVWeights(EMBEDDING_DIMENSION);
// final output projection weights of the multi-head attention
const outputWeights = initializeOutputWeights(EMBEDDING_DIMENSION);
// the predict function
const f = (x) => multiHeadAttention(x, NUMBER_OF_HEADS, QKVWeights, outputWeights);
// mean squared error
const loss = (predict, real) => predict.sub(real).square().mean();
// the optimizer
const optimizer = tf.train.sgd(0.01);
// randomly generated placeholder training data
const xTrain = tf.randomNormal([BATCH_SIZE, SEQUENCES_LENGTH, EMBEDDING_DIMENSION]);
const yTrain = tf.randomNormal([BATCH_SIZE, SEQUENCES_LENGTH, EMBEDDING_DIMENSION]);
// training the model
for (let i = 0; i < EPOCHS; i++) {
tf.tidy(() => {
// note: minimize(f, returnCost, varList) - the variable list is the third argument
const lossValue = optimizer.minimize(() => loss(f(xTrain), yTrain), true, [QKVWeights, outputWeights]);
// Logging loss every epoch
lossValue.data().then(lossValueArray => {
console.log(`Epoch ${i + 1}, Loss: ${lossValueArray[0]}`);
});
});
}
f(xTrain).array().then(pred => {
console.log("Predictions:");
console.log(pred);
xTrain.dispose();
});
// Disposing tensors
yTrain.dispose();
At first, I thought the minimize function was using stray tensors for backpropagation, so I freed most of the intermediate tensors (the ones created during the operations) and explicitly specified which variables to update (i.e. the weights). However, the error persisted. Next, I suspected that the shapes of my tensors were wrong, so I manually worked out the shapes (one by one) with a pencil and enabled debug mode to verify them, and everything looked perfectly fine. To further confirm that the operations were correct, I ran only the forward propagation, which worked marvellously.
So, where does this error come from? It only emerges when I try to train the model, or more specifically when I plug the loss function into the optimizer's minimize function, and I don't know why it throws. I believe there is either a fundamental flaw in the optimizer function or I am bad at creating multi-head attention. If you know anything about this, please kindly share it.
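In case it helps narrow things down: the forward pass alone runs fine, and as far as I understand minimize computes its gradients through tf.variableGrads, so I would expect something like the snippet below to hit the same error without involving the optimizer at all.
// forward pass only - this works and gives the expected shape
const forwardOut = f(xTrain);
console.log(forwardOut.shape); // [BATCH_SIZE, SEQUENCES_LENGTH, EMBEDDING_DIMENSION] = [8, 4, 16]
forwardOut.dispose();
// gradients only, no optimizer - presumably this throws the same gradient shape error
const {value, grads} = tf.variableGrads(() => loss(f(xTrain), yTrain), [QKVWeights, outputWeights]);
console.log("loss:", value.dataSync()[0]);
Object.keys(grads).forEach((name) => {
console.log(name, grads[name].shape);
grads[name].dispose();
});
value.dispose();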
Upvotes: 1
Views: 45