abracadabra

Reputation: 381

Model size too big with my attention model implementation?

I am implementing Minh-Thang Luong's attention model to build an English-to-Chinese translator, and the model I trained has an abnormally big size (980 MB). Minh-Thang Luong's original paper
model structure

These are the model parameters:

state size: 120
source language vocabulary size: 400000
source language word embedding size: 400000*50
target language vocabulary size: 20000
target language word embedding size: 20000*300

This is my model implementation in TensorFlow:

import tensorflow as tf

src_vocab_size=400000
src_w2v_dim=50
tgt_vocab_size=20000
tgt_w2v_dim=300
state_size=120

with tf.variable_scope('net_encode'):
    ph_src_embedding = tf.placeholder(dtype=tf.float32,shape=[src_vocab_size,src_w2v_dim],name='src_vocab_embedding_placeholder')
    #src_word_emb = tf.Variable(initial_value=ph_src_embedding,dtype=tf.float32,trainable=False, name='src_vocab_embedding_variable')

    encoder_X_ix = tf.placeholder(shape=(None, None), dtype=tf.int32)
    encoder_X_len = tf.placeholder(shape=[None], dtype=tf.int32)
    encoder_timestep = tf.shape(encoder_X_ix)[1]
    encoder_X = tf.nn.embedding_lookup(ph_src_embedding, encoder_X_ix)
    batchsize = tf.shape(encoder_X_ix)[0]

    encoder_Y_ix = tf.placeholder(shape=[None, None],dtype=tf.int32)
    encoder_Y_onehot = tf.one_hot(encoder_Y_ix, src_vocab_size)

    enc_cell = tf.nn.rnn_cell.LSTMCell(state_size)
    enc_initstate = enc_cell.zero_state(batchsize,dtype=tf.float32)
    enc_outputs, enc_final_states = tf.nn.dynamic_rnn(enc_cell,encoder_X,encoder_X_len,enc_initstate)
    enc_pred = tf.layers.dense(enc_outputs, units=src_vocab_size)
    encoder_loss = tf.losses.softmax_cross_entropy(encoder_Y_onehot,enc_pred)
    encoder_trainop = tf.train.AdamOptimizer(0.001).minimize(encoder_loss)

with tf.variable_scope('net_decode'):
    ph_tgt_embedding = tf.placeholder(dtype=tf.float32, shape=[tgt_vocab_size, tgt_w2v_dim],
                                      name='tgt_vocab_embedding_placeholder')
    #tgt_word_emb = tf.Variable(initial_value=ph_tgt_embedding, dtype=tf.float32, trainable=False, name='tgt_vocab_embedding_variable')
    decoder_X_ix = tf.placeholder(shape=(None, None), dtype=tf.int32)
    decoder_timestep = tf.shape(decoder_X_ix)[1]
    decoder_X_len = tf.placeholder(shape=[None], dtype=tf.int32)
    decoder_X = tf.nn.embedding_lookup(ph_tgt_embedding, decoder_X_ix)

    decoder_Y_ix = tf.placeholder(shape=[None, None],dtype=tf.int32)
    decoder_Y_onehot = tf.one_hot(decoder_Y_ix, tgt_vocab_size)

    dec_cell = tf.nn.rnn_cell.LSTMCell(state_size)
    dec_outputs, dec_final_state = tf.nn.dynamic_rnn(dec_cell,decoder_X,decoder_X_len,enc_final_states)

    tile_enc = tf.tile(tf.expand_dims(enc_outputs,1),[1,decoder_timestep,1,1]) # [batchsize,decoder_len,encoder_len,state_size]
    tile_dec = tf.tile(tf.expand_dims(dec_outputs, 2), [1, 1, encoder_timestep, 1]) # [batchsize,decoder_len,encoder_len,state_size]
    enc_dec_cat = tf.concat([tile_enc,tile_dec],-1) # [batchsize,decoder_len,encoder_len,state_size*2]
    weights = tf.nn.softmax(tf.layers.dense(enc_dec_cat,units=1),axis=-2) # [batchsize,decoder_len,encoder_len,1]
    weighted_enc = tf.tile(weights, [1, 1, 1, state_size])*tf.tile(tf.expand_dims(enc_outputs,1),[1,decoder_timestep,1,1]) # [batchsize,decoder_len,encoder_len,state_size]
    attention = tf.reduce_sum(weighted_enc,axis=2,keepdims=False) # [batchsize,decoder_len,state_size]
    dec_attention_cat = tf.concat([dec_outputs,attention],axis=-1) # [batchsize,decoder_len,state_size*2]
    dec_pred = tf.layers.dense(dec_attention_cat,units=tgt_vocab_size) # [batchsize,decoder_len,tgt_vocab_size]
    pred_ix = tf.argmax(dec_pred,axis=-1) # [batchsize,decoder_len]
    decoder_loss = tf.losses.softmax_cross_entropy(decoder_Y_onehot,dec_pred)
    total_loss = encoder_loss + decoder_loss
    decoder_trainop = tf.train.AdamOptimizer(0.001).minimize(total_loss)

_l0 = tf.summary.scalar('decoder_loss',decoder_loss)
_l1 = tf.summary.scalar('encoder_loss',encoder_loss)
log_all = tf.summary.merge_all()
log_path = './logs'  # TensorBoard log directory (placeholder path, not defined in the original snippet)
writer = tf.summary.FileWriter(log_path, graph=tf.get_default_graph())
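For completeness, here is a minimal smoke-test sketch of how this graph could be fed and run, assuming random toy data and made-up shapes (in practice the two embedding matrices would come from pretrained word vectors, and the target indices would be real next-token ids):

import numpy as np

# toy batch: 2 sentences of length 10 on each side (made-up numbers for a smoke test)
src_emb = np.random.rand(src_vocab_size, src_w2v_dim).astype(np.float32)
tgt_emb = np.random.rand(tgt_vocab_size, tgt_w2v_dim).astype(np.float32)
src_ix = np.random.randint(0, src_vocab_size, size=(2, 10))
tgt_ix = np.random.randint(0, tgt_vocab_size, size=(2, 10))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    _, loss = sess.run(
        [decoder_trainop, total_loss],
        feed_dict={ph_src_embedding: src_emb,
                   encoder_X_ix: src_ix, encoder_X_len: [10, 10], encoder_Y_ix: src_ix,
                   ph_tgt_embedding: tgt_emb,
                   decoder_X_ix: tgt_ix, decoder_X_len: [10, 10], decoder_Y_ix: tgt_ix})
    print('total loss:', loss)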

This is a rundown of the model parameter sizes that I can think of so far:

encoder cell
= (50*120 + 120*120 + 120) * 4
= (src_lang_embedding_size*state_size + state_size*state_size + state_size) * (forget gate, remember gate, new state, output gate)
= (kernel_size_for_input + kernel_size_for_previous_state + bias) * (forget gate, remember gate, new state, output gate)
= 82080 floats

encoder dense layer
= 120*400000
= state_size*src_lang_vocabulary_size
= 48000000 floats

decoder cell
= (300*120 + 120*120 + 120) * 4
= (target_lang_embedding_size*state_size + state_size*state_size + state_size) * (forget gate, remember gate, new state, output gate)
= (kernel_size_for_input + kernel_size_for_previous_state + bias) * (forget gate, remember gate, new state, output gate)
= 202080 floats

dense layer that computes attention weights
= (120+120) * 1
= (encoder_output_size + decoder_output_size) * (1 unit)
= 240 floats

decoder dense layer
= (120+120) * 20000
= (attention_vector_size + decoder_output_size) * target_lang_vocabulary_size
= 4800000 floats

Summing them all up gives about 212 MB, but the actual model size is 980 MB. So where am I wrong?
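As a sanity check on this arithmetic, here is a small standalone helper script (hypothetical, not part of the model code) that reproduces the 212 MB figure, assuming 4 bytes per float32 parameter and ignoring the small dense-layer biases:

# parameter-count sanity check (standalone helper)
state_size = 120
src_vocab_size = 400000
src_w2v_dim = 50
tgt_vocab_size = 20000
tgt_w2v_dim = 300

encoder_cell = (src_w2v_dim*state_size + state_size*state_size + state_size) * 4  # 82080
encoder_dense = state_size * src_vocab_size                                       # 48000000
decoder_cell = (tgt_w2v_dim*state_size + state_size*state_size + state_size) * 4  # 202080
attention_dense = (state_size + state_size) * 1                                   # 240
decoder_dense = (state_size + state_size) * tgt_vocab_size                        # 4800000

total_floats = encoder_cell + encoder_dense + decoder_cell + attention_dense + decoder_dense
print(total_floats, 'floats =', total_floats * 4 / 1e6, 'MB')  # ~53.08M floats, ~212 MB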

Upvotes: 0

Views: 226

Answers (1)

Jindřich

Reputation: 11220

You are only computing the number of trainable parameters; those are not the only numbers you need to accommodate in GPU memory.

You are using the Adam optimizer, so you need to store gradients for all your parameters and momentum estimates for all of them. This means you need to store each parameter 3 times, which gives you 636 MB.
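As a quick worked version of that estimate (assuming float32 values and the ~53M-parameter count from the breakdown in the question):

# rough size estimate for parameters plus optimizer state (hypothetical helper)
params = 53_084_400                # trainable floats from the breakdown above
raw_mb = params * 4 / 1e6          # ~212 MB of raw float32 parameters
with_optimizer_mb = raw_mb * 3     # parameters + gradients + momentum, roughly 3x
print(raw_mb, with_optimizer_mb)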

Then, you need to accommodate all the intermediate states of the network for the forward and the backward pass.

Let's say you have a batch size of b and source and target lengths of l; then you have (at least, I might have forgotten something):

  • b×l×50 source embeddings,
  • b×l×300 target embeddings,
  • b×l×5×120 encoder states,
  • b×l×400000 encoder logits,
  • b×l×5×300 decoder states,
  • b×l×120 intermediate attention states,
  • b×l×20000 output logits.

In total, this is roughly 421970×b×l floats that you need to store for your forward and backward pass.
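To put that number into megabytes, a small back-of-the-envelope sketch (the batch size and length here are made-up examples, and the per-token count is the approximate total above):

# activation-memory back-of-the-envelope (hypothetical batch size and length)
b, l = 32, 50                       # example batch size and sequence length
floats_per_token = 421970           # approximate per-token total from the list above
activation_mb = b * l * floats_per_token * 4 / 1e6   # float32 -> bytes -> MB
print(activation_mb, 'MB')          # ~2700 MB for this example batch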

Btw., a source vocabulary of 400k is a tremendously large number; I don't believe most of those words are frequent enough to learn anything meaningful about them. You should use pre-processing (e.g., SentencePiece) that reduces your vocabulary to a reasonable size.
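For example, a minimal SentencePiece sketch (the file names and vocabulary size are placeholders, and the keyword-argument API assumes a recent sentencepiece version):

import sentencepiece as spm

# train a subword model on the raw source corpus (placeholder file names and size)
spm.SentencePieceTrainer.train(
    input='src_corpus.txt',      # one sentence per line
    model_prefix='src_sp',
    vocab_size=16000)            # far smaller than the 400k word-level vocabulary

sp = spm.SentencePieceProcessor(model_file='src_sp.model')
ids = sp.encode('this is a test sentence', out_type=int)  # subword ids for the encoder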

Upvotes: 1
