Piper8x7b

Reputation: 1

How can I get this Mixture of Experts model working in TensorFlow?

I have two tensors. Tensor 1 has shape (10, None, 16, 16, 64) and tensor 2 has shape (None, 10), with None being the batch size, of course.

The first tensor is a collection of logits from 10 different models: the 10 indexes the set of logits, None is the batch size, and (16, 16, 64) is the output shape of the corresponding model. The second tensor is a single set of logits from one smaller model: None is the batch size, and the 10 values represent how heavily each of the 10 sets of logits from the first tensor should be weighted.

I want to multiply the first tensor by the second so that the output shape is (10, None, 16, 16, 64) and each set of logits along the first axis is weighted by the corresponding logit from the second tensor.

I will then sum the multiplied tensor over the first axis to get the output of one block of my MoE model.
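In other words, what I'm after has the semantics of this toy sketch (batch size fixed at 4 instead of None just for illustration):

```python
import tensorflow as tf

# toy stand-ins: batch size fixed at 4 instead of None
tensor1 = tf.random.normal((10, 4, 16, 16, 64))  # 10 experts' logits
tensor2 = tf.random.normal((4, 10))              # per-example expert weights

# 'e' = expert axis, 'b' = batch axis, 'hwc' = the (16, 16, 64) output;
# each expert's output is scaled by that example's weight for that expert
weighted = tf.einsum('ebhwc,be->ebhwc', tensor1, tensor2)  # (10, 4, 16, 16, 64)
combined = tf.reduce_sum(weighted, axis=0)                 # (4, 16, 16, 64)
```
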

Here is how all of this will be implemented (by the way, MOPE just stands for Mixture Of Pre-trained Experts):

def CreateMOPEBlock(x, blocks):
    ExpertLogits = []

    for i in range(len(blocks)):
        blocks[i].trainable = False
        ExpertLogits.append(blocks[i](x))

    GatingInput = x
    GatingConv1 = tf.keras.layers.Conv2D(16, (3, 3), padding='same')(GatingInput)
    GatingLayerNorm1 = tf.keras.layers.LayerNormalization()(GatingConv1)
    GatingLeakyReLU1 = tf.keras.layers.LeakyReLU()(GatingLayerNorm1)
    GatingConv2 = tf.keras.layers.Conv2D(32, (3, 3), padding='same')(GatingLeakyReLU1)
    GatingLayerNorm2 = tf.keras.layers.LayerNormalization()(GatingConv2)
    GatingLeakyReLU2 = tf.keras.layers.LeakyReLU()(GatingLayerNorm2)
    GatingConv3 = tf.keras.layers.Conv2D(64, (3, 3), padding='same')(GatingLeakyReLU2)
    GatingLayerNorm3 = tf.keras.layers.LayerNormalization()(GatingConv3)
    GatingLeakyReLU3 = tf.keras.layers.LeakyReLU()(GatingLayerNorm3)
    GatingFlatten = tf.keras.layers.Flatten()(GatingLeakyReLU3)
    GatingLogits = tf.keras.layers.Dense(num_classes, activation='softmax')(GatingFlatten)

    logits1 = ExpertLogits # list of 10 tensors, each shaped (None, 16, 16, 64)
    logits2 = GatingLogits # shape: (None, 10)

    # Do some fancy math here
    multiplied_logits = ?

    return tf.keras.layers.add(multiplied_logits)

MOPEInput = tf.keras.layers.Input(shape=(32, 32, 3))

# Block1, Block2, and Block3 are each a list of 10 Keras Sequential models
MOPEBlock1 = CreateMOPEBlock(MOPEInput, Block1)
MOPEBlock2 = CreateMOPEBlock(MOPEBlock1, Block2)
MOPEBlock3 = CreateMOPEBlock(MOPEBlock2, Block3)

MOPEFlatten = tf.keras.layers.Flatten()(MOPEBlock3)

MOPEX = tf.keras.layers.Dense(1024, activation='relu')(MOPEFlatten)
MOPEX = tf.keras.layers.BatchNormalization()(MOPEX)
MOPEX = tf.keras.layers.Dropout(0.33)(MOPEX)

MOPEX = tf.keras.layers.Dense(1024, activation='relu')(MOPEX)
MOPEX = tf.keras.layers.BatchNormalization()(MOPEX)
MOPEX = tf.keras.layers.Dropout(0.33)(MOPEX)

MOPEOutput = tf.keras.layers.Dense(num_classes, activation='softmax')(MOPEX)

MOPEModel = tf.keras.Model(MOPEInput, MOPEOutput)

I have tried solving this myself, multiple times! I have also tried asking multiple large language models, ranging from Mixtral-8x7b (it's in my name, hehe) to GPT-4. Their answers look something like this:

import tensorflow as tf

# suppose these are your tensors
tensor1 = tf.placeholder(tf.float32, shape=(10, None, 16, 16, 64))
tensor2 = tf.placeholder(tf.float32, shape=(None, 10))

# Reshaping tensor2 (None, 10) -> (None, 10, 1, 1, 1)
tensor2_expanded = tf.expand_dims(tf.expand_dims(tf.expand_dims(tensor2, axis=-1), axis=-1), axis=-1)

# Permuting axes of tensor1 (10, None, 16, 16, 64) -> (None, 10, 16, 16, 64)
tensor1_permuted = tf.transpose(tensor1, perm=[1, 0, 2, 3, 4])

# Multiplying tensors
result = tf.multiply(tensor1_permuted, tensor2_expanded)

# Finally, permuting axes of the result back (None, 10, 16, 16, 64) -> (10, None, 16, 16, 64)
result = tf.transpose(result, perm=[1, 0, 2, 3, 4])
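For what it's worth, the transpose-and-broadcast idea in that snippet does seem to run once tf.placeholder (which is TF1-only and gone in TF2) is replaced with real tensors, so the shape logic itself may not be the problem. A toy TF2 version with batch size 4:

```python
import tensorflow as tf

# toy stand-ins for the placeholders; batch size fixed at 4
t1 = tf.random.normal((10, 4, 16, 16, 64))  # expert logits
t2 = tf.random.normal((4, 10))              # gating weights

# (4, 10) -> (10, 4) -> (10, 4, 1, 1, 1) so it broadcasts against t1
w = tf.reshape(tf.transpose(t2), (10, 4, 1, 1, 1))
result = t1 * w  # (10, 4, 16, 16, 64), each expert weighted per example
```
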

None of the solutions from these models worked, even after extensive debugging with each of them. What should I do?

Upvotes: 0

Views: 122

Answers (0)
