The problem is the following. I have a categorical prediction task with a vocabulary size of 25K. On one of the inputs (input vocab 10K, output dim, i.e. embedding size, 50), I want to introduce a trainable weight matrix for a matrix multiplication between the input embedding (shape (1, 50)) and the weights (shape (50, 128)), with no bias. The resulting score vector is then an input to a prediction task, along with other features.
The crux is that, I think, if I simply add it in, the trainable weight matrix varies for each input. I want this weight matrix to be common across all inputs.
I should clarify: by input here I mean training examples. So every example would learn its own example-specific embedding, which is then multiplied by a shared weight matrix.
After every so many epochs, I intend to do a batch update to learn these common weights (or use other target variables to do multi-output prediction).
LSTM? Is that something I should look into here?
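To make the setup concrete, here is a minimal NumPy sketch of the computation being described, assuming one category index per training example. In Keras terms this would roughly correspond to an Embedding(10_000, 50) layer followed by Dense(128, use_bias=False): the embedding rows differ per example, but the Dense kernel is a single (50, 128) matrix applied to every example. The random values and the example indices are placeholders.

```python
import numpy as np

rng = np.random.default_rng(0)
vocab_size, emb_dim, out_dim = 10_000, 50, 128

E = rng.normal(size=(vocab_size, emb_dim))  # embedding table: one row per category
W = rng.normal(size=(emb_dim, out_dim))     # ONE weight matrix, shared by all examples

ids = np.array([3, 42, 9_999, 3])           # placeholder: one category index per example
emb = E[ids]                                # (4, 50): example-specific embeddings
scores = emb @ W                            # (4, 128): every row is multiplied by the same W

# Row i is just that example's embedding times the shared W
assert np.allclose(scores[1], E[42] @ W)
print(scores.shape)  # (4, 128)
```

Since W is one array, any gradient step on it changes the scores of all examples at once, which is exactly the "common across all inputs" behaviour asked for.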
If I understand the problem correctly, you can reuse layers, or even whole models, inside another model.
Here is an example with a Dense layer. Let's say you have 10 inputs:
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.models import Model
X = 32  # feature size of each input (placeholder, use your own)
# defining 10 inputs in a list, each with shape (X,)
inputs = [Input(shape=(X,), name='input_{}'.format(k)) for k in range(10)]
# defining one Dense layer; calling it on every input reuses the same weights
D = Dense(64, name='one_layer_to_rule_them_all')
nets = [D(inp) for inp in inputs]
model = Model(inputs=inputs, outputs=nets)
model.compile(optimizer='adam', loss='categorical_crossentropy')
This code will not work if the inputs have different shapes: the first call to D builds its weights for that input shape. In this example the outputs are simply the list nets, but of course you can concatenate them, stack them, or combine them however you want.
If you already have a trainable model, you can use it in place of D in the same way:
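As a rough illustration of what the shared Dense layer above computes, here is a NumPy sketch with placeholder sizes (X = 32, batch of 5). One kernel and one bias serve all 10 inputs, so the parameter count does not grow with the number of inputs; the last lines show the concatenate and stack options mentioned above.

```python
import numpy as np

rng = np.random.default_rng(0)
X, units, n_inputs, batch = 32, 64, 10, 5   # placeholder sizes

# The one shared layer: a single kernel and a single bias
kernel = rng.normal(size=(X, units))
bias = np.zeros(units)

def shared_dense(x):
    # Same parameters no matter which input calls it
    return x @ kernel + bias

inputs = [rng.normal(size=(batch, X)) for _ in range(n_inputs)]
nets = [shared_dense(inp) for inp in inputs]   # 10 outputs, each (batch, 64)

merged = np.concatenate(nets, axis=-1)         # concatenate: (batch, 640)
stacked = np.stack(nets, axis=1)               # or stack: (batch, 10, 64)
n_params = kernel.size + bias.size             # 32*64 + 64, independent of n_inputs
print(merged.shape, stacked.shape, n_params)   # (5, 640) (5, 10, 64) 2112
```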
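import tensorflow as tf
from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.models import Model
X = 32  # feature size of each input (placeholder, use your own)
# special_model stands for any trainable model; this small one is only a placeholder
inner_inp = Input(shape=(X,))
special_model = Model(inner_inp, Dense(64)(inner_inp), name='special_model')
# defining 10 inputs in a list, each with shape (X,)
inputs = [Input(shape=(X,), name='input_{}'.format(k)) for k in range(10)]
# calling the same model object on every input shares its weights
nets = [special_model(inp) for inp in inputs]
model = Model(inputs=inputs, outputs=nets)
model.compile(optimizer='adam', loss='categorical_crossentropy')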
Because the same special_model object is called on every input, its weights are shared among all of them.
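Sharing also matters during training: every input passes through the same weights, so backpropagation adds up the gradient contributions from all of them into that one set of weights. The NumPy check below uses a linear layer with a squared-error loss as a hypothetical stand-in for special_model and compares the summed analytic gradient with central finite differences.

```python
import numpy as np

rng = np.random.default_rng(0)
X, units, n_inputs, batch = 4, 3, 10, 2
W = rng.normal(size=(X, units))                 # shared weights (stand-in for special_model)
inputs = [rng.normal(size=(batch, X)) for _ in range(n_inputs)]
targets = [rng.normal(size=(batch, units)) for _ in range(n_inputs)]

def total_loss(W):
    # Sum of squared errors over all inputs, each one using the same W
    return sum(np.sum((x @ W - t) ** 2) for x, t in zip(inputs, targets))

# Analytic gradient: the per-input contributions are simply added up
grad = sum(2 * x.T @ (x @ W - t) for x, t in zip(inputs, targets))

# Check against central finite differences
eps = 1e-6
num = np.zeros_like(W)
for i in range(X):
    for j in range(units):
        d = np.zeros_like(W)
        d[i, j] = eps
        num[i, j] = (total_loss(W + d) - total_loss(W - d)) / (2 * eps)

print(np.allclose(grad, num, atol=1e-4))  # True
```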
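Layers apply the same weights to every example in the batch. An Embedding layer is the partial exception: its weight matrix is still shared, but each example only looks up, and sends gradients to, the rows for its own token indices.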
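A small NumPy sketch of the Embedding case, with a made-up vocabulary of 10 and an all-ones upstream gradient chosen only for illustration: the lookup is row indexing into one shared table, and the backward pass is a scatter-add that touches only the rows that were actually used.

```python
import numpy as np

vocab_size, emb_dim = 10, 3
rng = np.random.default_rng(0)
E = rng.normal(size=(vocab_size, emb_dim))  # the single, shared embedding table

ids = np.array([2, 5, 2, 7])                # one token index per example in the batch
looked_up = E[ids]                          # (4, 3): each example gets its own row

# Assume the upstream gradient w.r.t. every looked-up vector is all ones
upstream = np.ones_like(looked_up)
grad_E = np.zeros_like(E)
np.add.at(grad_E, ids, upstream)            # scatter-add: only used rows receive gradient

touched = np.flatnonzero(grad_E.any(axis=1))
print(touched)      # [2 5 7]
print(grad_E[2])    # [2. 2. 2.]  (row 2 was used by two examples)
```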
Take as an example a very simple network:
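from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.models import Model
inp = Input(shape=(4,))
h1 = Dense(2, activation='relu', use_bias=False)(inp)  # kernel of shape (4, 2), no bias
out = Dense(1)(h1)
model = Model(inp, out)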
This is a simple network with one input layer, one hidden layer and an output layer. Take the hidden layer as an example: its weight matrix has shape (4, 2). At each iteration the input data, a matrix of shape (batch_size, 4), is multiplied by this same weight matrix (the feed-forward phase), so every sample in the batch goes through the same weights and h1 has shape (batch_size, 2). The output has shape (batch_size, 1), and the loss is computed over the whole batch. Because all the batch samples passed through the same weights in the forward phase, all of them contribute to the gradients in backprop, and therefore to the weight update.
When dealing with text, the problem is often specified as predicting a label from a sequence of words. This is modelled as an input of shape (batch_size, sequence_length, 1), where each entry is a word index. Let's take a very basic example:
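The point about backprop can be checked numerically. The NumPy sketch below implements the toy network's forward pass and the gradient of a batch-mean squared error with respect to the hidden kernel W1; the random weights, inputs and targets are placeholders, and the mean-squared-error loss is an assumption chosen for illustration. It shows that the batch gradient is exactly the average of the per-sample gradients, so every sample influences the single shared W1.

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size = 8

# Toy network: Dense(2, relu, no bias) -> Dense(1) with bias
W1 = rng.normal(size=(4, 2))
W2 = rng.normal(size=(2, 1))
b2 = np.zeros(1)

x = rng.normal(size=(batch_size, 4))   # one row per sample
y = rng.normal(size=(batch_size, 1))   # placeholder targets

def forward(x, W1):
    z1 = x @ W1                        # (batch, 2): the same W1 for every row
    h1 = np.maximum(z1, 0.0)           # ReLU
    out = h1 @ W2 + b2                 # (batch, 1)
    return z1, h1, out

def grad_W1(x, y, W1):
    """Gradient of the batch-mean squared error with respect to W1."""
    z1, h1, out = forward(x, W1)
    d_out = 2.0 * (out - y) / len(x)   # d(mean squared error)/d(out)
    d_h1 = d_out @ W2.T                # (batch, 2)
    d_z1 = d_h1 * (z1 > 0)             # ReLU derivative
    return x.T @ d_z1                  # (4, 2): sums the contributions of all rows

g_batch = grad_W1(x, y, W1)
g_mean = np.mean([grad_W1(x[i:i + 1], y[i:i + 1], W1) for i in range(batch_size)], axis=0)
print(g_batch.shape, np.allclose(g_batch, g_mean))  # (4, 2) True
```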
from tensorflow.keras.layers import Input, Embedding, Reshape, Dense, LSTM
from tensorflow.keras.models import Model
sequence_length = 80
emb_vec_size = 100
vocab_size = 10_000
def make_model():
    # one word index per time step: (batch, 80, 1)
    inp = Input(shape=(sequence_length, 1))
    # look up a vector per index: (batch, 80, 1, 100)
    emb = Embedding(vocab_size, emb_vec_size)(inp)
    # drop the extra axis: (batch, 80, 100)
    emb = Reshape((sequence_length, emb_vec_size))(emb)
    # the same (100, 64) kernel is applied at every time step: (batch, 80, 64)
    h1 = Dense(64)(emb)
    recurrent = LSTM(32)(h1)      # (batch, 32)
    output = Dense(1)(recurrent)  # (batch, 1)
    model = Model(inp, output)
    model.compile('adam', 'mse')
    return model
model = make_model()
model.summary()
You can copy and paste this into Colab and look at the summary.
What this example is doing is: