Reputation: 1025
I'm trying to run an RNN beam search on a tf.keras.Model in a vectorized way to have it work completely on GPU. However, despite having everything as a tf.function, as vectorized as I can make it, it runs at exactly the same speed with or without a GPU. Attached is a minimal example with a fake model. In reality, for n=32, k=32, steps=128, which is what I would want to work with, this takes 20 s (per n=32 samples) to decode, both on CPU and on GPU!
I must be missing something. When I train the model, a training iteration (128 steps) with batch size 512 takes 100 ms on GPU, and a training iteration with batch size 32 takes 1 s on CPU. The GPU isn't saturated at batch size 512. I get that I have overhead from doing the steps individually and from doing a blocking operation per step, but in terms of computation my overhead is negligible compared to the rest of the model.
I also get that using a tf.keras.Model in this way is probably not ideal, but is there another way to wire output tensors via a function back to the input tensors, and in particular also to rewire the states?
Full working example: https://gist.github.com/meowcat/e3eaa4b8543a7c8444f4a74a9074b9ae
import tensorflow as tf
from tensorflow import math as tm  # tm is used below as an alias for tf.math

# model_decode (the model running a single RNN step) and embed_y_to_x
# are defined in the full example linked above.

@tf.function
def decode_beam(states_init, scores_init, y_init, steps, k, n):
    states = states_init
    scores = scores_init
    xstep = embed_y_to_x(y_init)
    # Keep the results in TensorArrays
    y_chain = tf.TensorArray(dtype="int32", size=steps)
    sequences_chain = tf.TensorArray(dtype="int32", size=steps)
    scores_chain = tf.TensorArray(dtype="float32", size=steps)
    for i in range(steps):
        # model_decode is the trained model with 3.5 million trainable params.
        # Run a single step of the RNN model.
        y, states = model_decode([xstep, states])
        # Add the scores of this step to the previous scores
        # (I left out the sequence end killer for this demo)
        scores_y = tf.expand_dims(tf.reshape(scores, y.shape[:-1]), 2) + tm.log(y)
        # Reshape into (n, k*tokens) and find the best k sequences to continue
        # for each of the n candidates
        scores_y = tf.reshape(scores_y, [n, -1])
        top_k = tm.top_k(scores_y, k, sorted=False)
        # Transform the indices. I was using tf.unravel_index, but
        # tf.debugging.set_log_device_placement(True) indicated that it would be
        # placed on the CPU, thus I rewrote it.
        top_k_index = tf.reshape(
            top_k[1] + tf.reshape(tf.range(n), (-1, 1)) * scores_y.shape[1], [-1])
        ysequence = top_k_index // y.shape[2]
        ymax = top_k_index % y.shape[2]
        # This gives us two (n*k,) tensors with the parent sequence (ysequence)
        # and the chosen character (ymax) per sequence.
        # For continuation, pick the states, and "return" the scores
        states = tf.gather(states, ysequence)
        scores = tf.reshape(top_k[0], [-1])
        # Write the results into the TensorArrays,
        # and embed for the next step
        xstep = embed_y_to_x(ymax)
        y_chain = y_chain.write(i, ymax)
        sequences_chain = sequences_chain.write(i, ysequence)
        scores_chain = scores_chain.write(i, scores)
    # Done: stack up the results and return them
    sequences_final = sequences_chain.stack()
    y_final = y_chain.stack()
    scores_final = scores_chain.stack()
    return sequences_final, y_final, scores_final
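A minimal sketch of how such per-decode timings can be measured (this is not part of the gist; the input tensors are the same ones passed to decode_beam, and the first call is excluded because it includes graph tracing):

import time

decode_beam(states_init, scores_init, y_init, steps, k, n)  # warm-up call: triggers tracing
start = time.time()
sequences, y, scores = decode_beam(states_init, scores_init, y_init, steps, k, n)
_ = sequences.numpy()  # pull the result to the host so the timing includes all GPU work
print("decode time: %.2f s" % (time.time() - start))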
Upvotes: 1
Views: 1215
Reputation: 1025
There was a lot going on here. I will comment on it because it might help others to resolve TensorFlow performance issues.
Note this very useful answer (the only one on the web) that shows how to profile arbitrary TensorFlow 2 code, rather than Keras training:
https://stackoverflow.com/a/56698035/1259675
logdir = "log"
writer = tf.summary.create_file_writer(logdir)
tf.summary.trace_on(graph=True, profiler=True)
# run any @tf.function decorated functions here
sequences, y, scores = decode_beam_steps(
    y_init, states_init, scores_init,
    steps=steps, k=k, n=n, pad_mask=pad_mask)
with writer.as_default():
    tf.summary.trace_export(name="model_trace", step=0, profiler_outdir=logdir)
tf.summary.trace_off()
Note that an old Chromium version is needed to look at the profiling results, since at the time of writing (2020-04-17) this fails in current Chrome/Chromium.
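As an aside (an assumption on my part, not something used in this answer): on newer TensorFlow versions (2.2+), the standalone profiler API can be used instead and the result viewed through TensorBoard's profile plugin:

import tensorflow as tf

tf.profiler.experimental.start("log")
sequences, y, scores = decode_beam_steps(
    y_init, states_init, scores_init,
    steps=steps, k=k, n=n, pad_mask=pad_mask)
tf.profiler.experimental.stop()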
The graph was made a bit lighter, but not significantly faster, by using unroll=True in the LSTM cells used by the model (not shown here): since only one step is run per call, the symbolic loop only adds clutter. This significantly slashes the time needed for the first call of the function above, when AutoGraph builds the graph; note that this time is enormous (see below). With unroll=False (the default) the graph builds in 300 seconds, with unroll=True in 100 seconds. The runtime performance itself stays the same (15-20 s/iteration for n=32, k=32).
Setting implementation=1 made it slightly slower, so I stayed with the default of implementation=2.
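Both options are plain constructor arguments of the LSTM layer; a minimal sketch of a single-step decoder cell (the unit count is made up, the real model is in the gist):

import tensorflow as tf

lstm = tf.keras.layers.LSTM(
    512,                   # hypothetical unit count; the real model differs
    return_sequences=True,
    return_state=True,
    unroll=True,           # cheaper tracing when each call only runs one step
    implementation=2)      # the default; implementation=1 was slightly slower here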
Finally, I used an explicit tf.while_loop instead of relying on AutoGraph to convert the for i in range(steps) loop. I had this loop both in the inlined version shown above and in a modularized one:
for i in range(steps):
    ystep, states = model_decode([xstep, states])
    ymax, ysequence, states, scores = model_beam_step(
        ystep, states, scores, k, n, pad_mask)
    xstep = model_rtox(ymax)
    y_chain = y_chain.write(i, ymax)
    sequences_chain = sequences_chain.write(i, ysequence)
    scores_chain = scores_chain.write(i, scores)
where model_beam_step does all the beam search math. Unsurprisingly, both versions performed equally badly, and in particular both took ~100/300 seconds on the first run, while AutoGraph traced the graph. Furthermore, tracing the graph with the profiler produces a 30-50 MB file that won't easily load in TensorBoard and more or less crashes it; the profile showed dozens of parallel GPU streams with a single operation each.
Substituting this with a tf.while_loop slashed the setup time to zero (back_prop=False makes only a very small difference) and produced a nice 500 kB graph that can easily be inspected in TensorBoard and profiled in a useful way, with 4 GPU streams.
beam_steps_cond = lambda i, y_, seq_, sc_, xstep, states, scores: i < steps

def decode_beam_steps_body(i, y_, seq_, sc_, xstep, states, scores):
    y, states = model_decode([xstep, states])
    ymax, ysequence, states, scores = model_beam_step(
        y, states, scores, k, n, pad_mask)
    xstep = model_rtox(ymax)
    y_ = y_.write(i, ymax)
    seq_ = seq_.write(i, ysequence)
    sc_ = sc_.write(i, scores)
    i = i + 1
    return i, y_, seq_, sc_, xstep, states, scores

_, y_chain, sequences_chain, scores_chain, _, _, _ = \
    tf.while_loop(
        cond=beam_steps_cond,
        body=decode_beam_steps_body,
        loop_vars=[i, y_chain, sequences_chain, scores_chain,
                   xstep, states, scores],
        back_prop=False
    )
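The loop variables are initialized before the tf.while_loop call (this part is not shown above); a rough sketch, assuming the same inputs as in the question and that model_rtox provides the initial embedding:

i = tf.constant(0)
y_chain = tf.TensorArray(dtype="int32", size=steps)
sequences_chain = tf.TensorArray(dtype="int32", size=steps)
scores_chain = tf.TensorArray(dtype="float32", size=steps)
xstep = model_rtox(y_init)  # assumed: the start-token embedding (embed_y_to_x in the question)
states = states_init
scores = scores_init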
Being able to look at the profile in a meaningful way showed me that the real issue was an output postprocessing function that ran on the CPU. I hadn't suspected it because it had run fast earlier, but I overlooked that a beam search modification I made produces far more than k sequences per candidate, which massively slows the postprocessing down. Thus, it wiped out every benefit I could gain from running the decoding step efficiently on the GPU. Without this postprocessing, the GPU runs at >2 iterations/sec. Refactoring the postprocessing (which is extremely fast when done right) into TensorFlow resolved the issue.
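For completeness, a generic sketch (not the actual postprocessing code) of how such CPU fallbacks can be spotted, by enabling device placement logging around the suspect function:

tf.debugging.set_log_device_placement(True)

@tf.function
def postprocess(sequences, y, scores):
    # ... the pure-TensorFlow postprocessing would go here ...
    return sequences, y, scores

_ = postprocess(sequences, y, scores)
# Any op logged on /device:CPU:0 is a candidate bottleneck.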
Upvotes: 1