Giuseppe Angora

Reputation: 853

Python and TensorFlow: what is the most efficient way to call a function that uses TensorFlow operations?

I'm developing a simple function using tensorflow:

import tensorflow as tf

def xcross(T, S):
    sum_spectra_sq = tf.reduce_sum(tf.square(S), 1)    # shape (batch,)
    sum_template_sq = tf.reduce_sum(tf.square(T), 0)   # shape (Nz,)
    norm = tf.sqrt(tf.reshape(sum_spectra_sq, (-1, 1)) * tf.reshape(sum_template_sq, (1, -1)))
    xcorr = tf.matmul(S, T) / norm
    maxidxs = tf.math.argmax(xcorr, axis=1)            # shape (batch,)
    return xcorr, maxidxs
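For clarity, what the graph computes is the cosine similarity between each spectrum (each row of S) and each template column, followed by a per-row argmax. A NumPy sketch of the same math (the shapes (batch, N) for S and (N, Nz) for T are assumptions based on the comments above):

```python
import numpy as np

def xcross_np(T, S):
    # Row-wise squared norms of the spectra: shape (batch,)
    sum_spectra_sq = np.sum(np.square(S), axis=1)
    # Column-wise squared norms of the templates: shape (Nz,)
    sum_template_sq = np.sum(np.square(T), axis=0)
    # Outer product of the two norm vectors: shape (batch, Nz)
    norm = np.sqrt(sum_spectra_sq[:, None] * sum_template_sq[None, :])
    # Normalized cross-correlation and the best-matching template per spectrum
    xcorr = (S @ T) / norm
    maxidxs = np.argmax(xcorr, axis=1)
    return xcorr, maxidxs
```

Because of the normalization, every entry of xcorr lies in [-1, 1], which is a quick sanity check on the TensorFlow output.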

In main, I call the function like this:

def main():
    ...
    with tf.Session() as session:
        for nb in range(n_batch):
            ...
            S = data[start:end]
            xcorr, maxidxs = xcross(T, S)
            x = xcorr.eval(session=session)
            ii = maxidxs.eval(session=session)
            ...

As you can see, the xcross function works on a batch of data. With this configuration I get a memory error: CUDA_ERROR_OUT_OF_MEMORY. But if I move the tf.Session into the function (removing it from main, of course):

def xcross(T, S):
    sum_spectra_sq = tf.reduce_sum(tf.square(S), 1)    # shape (batch,)
    sum_template_sq = tf.reduce_sum(tf.square(T), 0)   # shape (Nz,)
    norm = tf.sqrt(tf.reshape(sum_spectra_sq, (-1, 1)) * tf.reshape(sum_template_sq, (1, -1)))
    xcorr = tf.matmul(S, T) / norm
    maxidxs = tf.math.argmax(xcorr, axis=1)            # shape (batch,)
    with tf.Session() as session:
        _xcorr, _maxidxs = session.run([xcorr, maxidxs])
    return _xcorr, _maxidxs

the code runs without errors. But now, on every iteration of the for loop, a new session is opened on the GPU, with a lot of printed text and, I think, a loss of computational efficiency.

So, what is the most efficient way to call a function that uses TensorFlow operations, like the one I'm implementing, N times?

Solution
Thanks to Frederik Bode, I found the solution:

def make_xcorr_tf():
    S = tf.placeholder(tf_float_type, name='spectra')
    T = tf.placeholder(tf_float_type, name='template')
    ...
    return xcorr, maxidxs

def main():
    ...
    xcorr_graph, maxidxs_graph = make_xcorr_tf()
    ...
    xcorr, maxidxs = session.run(
        [xcorr_graph, maxidxs_graph],
        feed_dict={'spectra:0': data_partial, 'template:0': template_partial})
    ...
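For anyone reproducing this on a current TF 2.x install, the same build-once/run-many pattern still works through tf.compat.v1. A self-contained sketch (the shapes, dtype, and sample data here are assumptions for illustration, not from the original code):

```python
import numpy as np
import tensorflow as tf

# TF 2.x runs eagerly by default; switch back to graph mode for sessions.
tf.compat.v1.disable_eager_execution()

def make_xcorr_tf():
    # Placeholders are the graph's inputs; the graph itself is built only once.
    S = tf.compat.v1.placeholder(tf.float64, shape=(None, None), name='spectra')
    T = tf.compat.v1.placeholder(tf.float64, shape=(None, None), name='template')
    sum_spectra_sq = tf.reduce_sum(tf.square(S), 1)
    sum_template_sq = tf.reduce_sum(tf.square(T), 0)
    norm = tf.sqrt(tf.reshape(sum_spectra_sq, (-1, 1)) * tf.reshape(sum_template_sq, (1, -1)))
    xcorr = tf.matmul(S, T) / norm
    maxidxs = tf.math.argmax(xcorr, axis=1)
    return xcorr, maxidxs

xcorr_graph, maxidxs_graph = make_xcorr_tf()

with tf.compat.v1.Session() as session:
    data = np.array([[1.0, 0.0], [0.0, 1.0]])
    template = np.array([[1.0, 0.0], [0.0, 2.0]])
    # session.run can be called many times; only the feeds change per batch.
    xc, idx = session.run(
        [xcorr_graph, maxidxs_graph],
        feed_dict={'spectra:0': data, 'template:0': template})
```

Feeding by tensor name ('spectra:0') matches the solution above; feeding the placeholder objects directly works equally well.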

Upvotes: 1

Views: 38

Answers (2)

Dan Moldovan
Dan Moldovan

Reputation: 975

You may want to have a look at tf.function (tutorial, another tutorial). In TF 2, it handles most of the boilerplate that you had to write with tf.Session and placeholders. It also helps with moving constructs like the for loop inside the TF graph for greater efficiency, and makes it easier to compile your graph with XLA.
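A minimal sketch of the question's cross-correlation under tf.function (the body is adapted from the question; the function name and float64 dtype are illustrative choices). The graph is traced on the first call and reused on every subsequent call with matching input signatures, so there is no per-iteration rebuild:

```python
import tensorflow as tf

@tf.function
def xcross_fn(T, S):
    # Same math as the TF 1.x graph, but no Session or placeholders needed.
    sum_spectra_sq = tf.reduce_sum(tf.square(S), axis=1)     # shape (batch,)
    sum_template_sq = tf.reduce_sum(tf.square(T), axis=0)    # shape (Nz,)
    norm = tf.sqrt(tf.reshape(sum_spectra_sq, (-1, 1)) * tf.reshape(sum_template_sq, (1, -1)))
    xcorr = tf.matmul(S, T) / norm
    maxidxs = tf.math.argmax(xcorr, axis=1)                  # shape (batch,)
    return xcorr, maxidxs
```

Calling it looks like a plain Python function: `xcorr, maxidxs = xcross_fn(template_tensor, batch_tensor)`, with the results available via `.numpy()`.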

Upvotes: 0

Frederik Bode
Frederik Bode

Reputation: 2744

You're creating a new graph on every iteration (each call to xcross adds new ops). You should redefine xcross so that it takes tf.placeholders as inputs, and build it once outside the loop, even outside the with tf.Session() as sess: block. Then you can run the graph with:

def xcross():
    T = tf.placeholder(name="T", ...)
    S = tf.placeholder(name="S", ...)
    ...
    return xcorr, maxidxs

xcorr_graph, maxidxs_graph = xcross()

with tf.Session() as sess:
    for ...:
        sess.run([xcorr_graph, maxidxs_graph], feed_dict={"S:0": S, "T:0": T})

Note that my declaration of feed_dict might be wrong - it has been a while since I last used it, but this should set you on your way.

Upvotes: 1
