Reputation: 853
I'm developing a simple function using TensorFlow:
def xcross(T, S):
    sum_spectra_sq = tf.reduce_sum(tf.square(S), 1)   # shape (batch,)
    sum_template_sq = tf.reduce_sum(tf.square(T), 0)  # shape (Nz,)
    norm = tf.sqrt(tf.reshape(sum_spectra_sq, (-1, 1)) * tf.reshape(sum_template_sq, (1, -1)))
    xcorr = tf.matmul(S, T, transpose_a=False, transpose_b=False) / norm
    maxidxs = tf.math.argmax(xcorr, axis=1)  # shape (batch,)
    return xcorr, maxidxs
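For reference, the same computation written in plain NumPy (a sketch with assumed shapes: S is (batch, n), T is (n, Nz)); each entry of xcorr is the cosine similarity between one spectrum and one template column, so values lie in [-1, 1]:

```python
import numpy as np

def xcross_np(T, S):
    # S: (batch, n) spectra, T: (n, Nz) templates
    sum_spectra_sq = np.sum(S**2, axis=1)    # (batch,)
    sum_template_sq = np.sum(T**2, axis=0)   # (Nz,)
    # outer product of the two norms, shape (batch, Nz)
    norm = np.sqrt(sum_spectra_sq[:, None] * sum_template_sq[None, :])
    xcorr = (S @ T) / norm                   # normalized cross-correlation
    maxidxs = np.argmax(xcorr, axis=1)       # best template per spectrum
    return xcorr, maxidxs

rng = np.random.default_rng(0)
S = rng.normal(size=(4, 8))
T = rng.normal(size=(8, 3))
xcorr, maxidxs = xcross_np(T, S)
```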
In the main, I'd like to call this function:
def main():
    ...
    with tf.Session() as session:
        for nb in range(n_batch):
            ...
            S = data[start:end]
            xcorr, maxidxs = xcross(T, S)
            x = xcorr.eval(session=session)
            ii = maxidxs.eval(session=session)
            ...
As you can see, the xcross function works on a batch of data.
With this configuration, I get a memory error: CUDA_ERROR_OUT_OF_MEMORY.
But if I move the tf.Session into the function (and remove it from the main):
def xcross(T, S):
    sum_spectra_sq = tf.reduce_sum(tf.square(S), 1)   # shape (batch,)
    sum_template_sq = tf.reduce_sum(tf.square(T), 0)  # shape (Nz,)
    norm = tf.sqrt(tf.reshape(sum_spectra_sq, (-1, 1)) * tf.reshape(sum_template_sq, (1, -1)))
    xcorr = tf.matmul(S, T, transpose_a=False, transpose_b=False) / norm
    maxidxs = tf.math.argmax(xcorr, axis=1)  # shape (batch,)
    with tf.Session() as session:
        _xcorr, _maxidxs = session.run([xcorr, maxidxs])
    return _xcorr, _maxidxs
the code works without errors. But on each iteration of the for loop the GPU is re-initialized, with a lot of printed text and, I think, a loss of computational efficiency.
So, what is the most efficient way to call N times a function built from TensorFlow operations, like the one I'm implementing?
Solution
Thanks to Frederik Bode, I found the solution:
def xcross():
    S = tf.placeholder(tf_float_type, name='spectra')
    T = tf.placeholder(tf_float_type, name='template')
    ...
    return xcorr, maxidxs
def main():
    ...
    xcorr_graph, maxidxs_graph = xcross()
    ...
    xcorr, maxidxs = session.run(
        [xcorr_graph, maxidxs_graph],
        feed_dict={'spectra:0': data_partial, 'template:0': template_partial})
    ...
Upvotes: 1
Views: 38
Reputation: 975
You may want to have a look at tf.function (tutorial, another tutorial). In TF 2 it handles most of the boilerplate that you previously needed with tf.Session and placeholders. It also helps move constructs like the for loop inside the TF graph for greater efficiency, and makes it easier to compile your graph with XLA.
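A minimal sketch of that idea, assuming TF 2.x: the original xcross body wrapped in tf.function, so the graph is traced once and then reused across batches without rebuilding it:

```python
import tensorflow as tf

@tf.function
def xcross(T, S):
    # Same normalized cross-correlation as in the question,
    # but traced into a reusable graph by tf.function.
    sum_spectra_sq = tf.reduce_sum(tf.square(S), 1)   # shape (batch,)
    sum_template_sq = tf.reduce_sum(tf.square(T), 0)  # shape (Nz,)
    norm = tf.sqrt(tf.reshape(sum_spectra_sq, (-1, 1))
                   * tf.reshape(sum_template_sq, (1, -1)))
    xcorr = tf.matmul(S, T) / norm
    maxidxs = tf.math.argmax(xcorr, axis=1)           # shape (batch,)
    return xcorr, maxidxs

# Example shapes are assumptions for illustration only.
S = tf.random.normal((4, 8))
T = tf.random.normal((8, 3))
xcorr, maxidxs = xcross(T, S)
```

Calling xcross in a loop then reuses the traced graph instead of rebuilding ops on every iteration.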
Upvotes: 0
Reputation: 2744
You're creating a new graph on each iteration (i.e., each call to xcross). You should redefine xcross so that it takes a tf.placeholder as input, and define it outside the loop, even outside the with tf.Session() as sess: block. Then you can call the graph with:
def xcross():
    T = tf.placeholder(name="T", ...)
    S = tf.placeholder(name="S", ...)
    ...
    return xcorr, maxidxs

xcorr_graph, maxidxs_graph = xcross()

with tf.Session() as sess:
    for ...:
        sess.run([xcorr_graph, maxidxs_graph],
                 feed_dict={"S:0": S, "T:0": T})
Note that my declaration of feed_dict might be wrong; it has been a while since I last used it, but this should set you on your way.
Upvotes: 1