svetlov.vsevolod

Reputation: 191

Loading large numpy matrix as partitioned variable in tensorflow graph

Imagine that I have large pretrained embeddings that I can load as a numpy array, for example with shape [3000000, 200]. This matrix is larger than 2 GB, so with this code:

import numpy as np
import tensorflow as tf

data = np.zeros(shape=(3000000, 200))
variable = tf.get_variable(
    "weights",
    [3000000, 200],
    initializer=tf.constant_initializer(data))

session = tf.Session()
session.run(tf.global_variables_initializer())

I get this error: ValueError: Cannot create a tensor proto whose content is larger than 2GB.
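A quick size estimate shows why the limit is hit no matter which float type is used. This is a sketch that assumes the "2GB" in the message means 2**31 bytes; the names ROWS, COLS, PROTO_LIMIT and matrix_bytes are illustrative, not TensorFlow API.

```python
# Rough size estimate for the embedding matrix from the question.
# Assumption: the "2GB" limit in the error means 2**31 bytes.
ROWS, COLS = 3_000_000, 200
PROTO_LIMIT = 2 ** 31  # 2,147,483,648 bytes


def matrix_bytes(rows, cols, itemsize):
    """Bytes needed to store a dense rows x cols matrix."""
    return rows * cols * itemsize


float64_bytes = matrix_bytes(ROWS, COLS, 8)  # np.zeros defaults to float64
float32_bytes = matrix_bytes(ROWS, COLS, 4)  # tf.get_variable defaults to float32

for name, size in [("float64", float64_bytes), ("float32", float32_bytes)]:
    print(f"{name}: {size:,} bytes, over limit: {size >= PROTO_LIMIT}")
```

Even at float32 the matrix is about 2.4 GB, so casting alone does not help; the data has to reach the variable without being embedded in the graph as a constant.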

I could load it with tf.assign and a placeholder, but for other reasons I want to use a partitioned version of these embedding weights. The assign-and-placeholder approach is ruled out because partitioned variables do not support the assign op: NotImplementedError: assign() has not been implemented for PartitionedVariable.

Is it possible to do such a thing?
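Conceptually, a variable partitioned along axis 0 holds the matrix as consecutive row blocks, so loading it means feeding each block the matching slice of the numpy array. The pure-Python sketch below shows only that bookkeeping; the helper name row_slices and the 4-way equal split are hypothetical, and no TensorFlow call is involved.

```python
def row_slices(part_rows):
    """Return (start, end) row ranges for consecutive partitions.

    part_rows: number of rows in each partition, in partition order
    (hypothetical input; with TensorFlow it would come from each
    part's shape).
    """
    slices = []
    offset = 0
    for rows in part_rows:
        slices.append((offset, offset + rows))
        offset += rows
    return slices


# Hypothetical example: 3,000,000 rows split into 4 equal parts.
parts = [750_000] * 4
ranges = row_slices(parts)
print(ranges)
# Every row of the source matrix should land in exactly one part.
assert ranges[-1][1] == 3_000_000
```

Each data[start:end] would then be fed to the corresponding part, so no single feed or constant has to hold the whole matrix.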

Upvotes: 2

Views: 805

Answers (2)

svetlov.vsevolod

Reputation: 191

SOLUTION

This is ugly, but it works:

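import tensorflow as tf


def init_partitioned(session, var_name, data):
    # Parts of a variable partitioned along axis 0 are named
    # "<var_name>/part_0:0", "<var_name>/part_1:0", ...; get_collection
    # filters variable names with re.match against this pattern.
    partitioned_var = tf.get_collection(
        tf.GraphKeys.GLOBAL_VARIABLES, scope=var_name + r"/part_\d+:0")
    print("For {} found {} parts".format(var_name, len(partitioned_var)))

    dtype = partitioned_var[0].dtype
    part_shape = partitioned_var[0].get_shape().as_list()
    part_shape[0] = None  # parts can differ in row count

    # Feed each part its own row slice through one placeholder, so the
    # matrix is never embedded in the graph as a constant.
    init = tf.placeholder(dtype, part_shape)
    offset = 0
    for part in partitioned_var:
        init_op = tf.assign(part, init)
        num_rows_in_part = int(part.get_shape()[0])
        session.run(init_op, feed_dict={init: data[offset:offset + num_rows_in_part]})
        offset += num_rows_in_part

    # Sanity check: every row of data went into some part.
    assert offset == len(data), "row count mismatch"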
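Because tf.get_collection filters by scope with re.match, the pattern only has to match the start of each variable name. The sketch below uses Python's re module on some hypothetical variable names to show what the var_name + r"/part_\d+:0" pattern selects. The helper matching_parts is illustrative, and its re.escape call is an extra safeguard for names containing regex metacharacters, not something the answer above uses.

```python
import re


def matching_parts(var_name, names):
    """Names that re.match accepts for the partition pattern."""
    # re.escape is an added safeguard, not part of the original answer.
    pattern = re.compile(re.escape(var_name) + r"/part_\d+:0")
    return [n for n in names if pattern.match(n)]


# Hypothetical variable names as they might appear in GLOBAL_VARIABLES.
names = [
    "weights/part_0:0",
    "weights/part_1:0",
    "weights/part_0/Adam:0",   # slot-style name: no ":0" right after the digits
    "weights_bias/part_0:0",   # different variable sharing a prefix
    "scope/weights/part_0:0",  # re.match anchors at the start of the name
]
print(matching_parts("weights", names))
```

Only the two real parts match. If the variable was created inside a variable_scope, var_name has to include that scope prefix for anything to be found.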

Upvotes: 4

Tianjin Gu

Reputation: 784

Try:

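import numpy as np
import tensorflow as tf

# float32 matches the placeholder dtype and avoids a 4.8 GB float64 array
data = np.zeros(shape=(3000000, 200), dtype=np.float32)

# The initial value comes from a placeholder, so the data is fed at
# initialization time instead of being embedded in the graph.
# Note: this creates an ordinary (non-partitioned) tf.Variable.
ph = tf.placeholder(tf.float32, shape=(3000000, 200))
variable = tf.Variable(ph)

session = tf.Session()
session.run(tf.global_variables_initializer(), feed_dict={ph: data})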

Upvotes: 0
