While migrating to TF 2.0, I'm trying to use the tf.keras
approach for everything.
In standard TF, I can wrap code in a with tf.device(...) block
to control which device the ops are placed on.
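Here is a minimal sketch of that mechanism under TF 2.x eager execution. The matrix values are arbitrary. Ops created inside the block are requested on the named device, and an eager tensor's .device attribute reports where the result was actually computed:

```python
import tensorflow as tf

# Request that both ops below run on the CPU.
with tf.device("/CPU:0"):
    a = tf.constant([[1.0, 2.0], [3.0, 4.0]])
    b = tf.matmul(a, a)

# In eager mode the result records the device it was computed on,
# e.g. '/job:localhost/replica:0/task:0/device:CPU:0'.
print(b.device)
print(b.numpy())  # values: [[7., 10.], [15., 22.]]
```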
For example, I might have a model like this:
    model = tf.keras.Sequential([tf.keras.layers.Input(...),
                                 tf.keras.layers.Embedding(...),
                                 tf.keras.layers.LSTM(...),
                                 ...])
Assuming I want everything up to and including the Embedding layer on the CPU,
and everything after it on the GPU, how would I go about that?
(This is just an example; the layers could have nothing to do with embeddings.)
If the solution involves subclassing tf.keras.Model,
that is OK too; I don't mind not using Sequential.
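A possible subclassing approach follows as a sketch. The class name SplitDeviceModel, the layer sizes, and the Dense head are hypothetical and exist only to make the example runnable. The tf.device scopes sit inside call(), so they wrap the layer calls every time the model runs (or is traced), not only while it is being constructed. If no GPU is visible, the second half falls back to the CPU so the sketch still executes:

```python
import tensorflow as tf

# Use the GPU if one is visible; otherwise fall back to the CPU so this runs anywhere.
GPU_DEVICE = "/GPU:0" if tf.config.list_physical_devices("GPU") else "/CPU:0"
CPU_DEVICE = "/CPU:0"


class SplitDeviceModel(tf.keras.Model):
    """Hypothetical model: embedding requested on the CPU, the rest on the GPU."""

    def __init__(self, vocab_size, embed_dim, lstm_units, num_classes):
        super().__init__()
        self.embedding = tf.keras.layers.Embedding(vocab_size, embed_dim)
        self.lstm = tf.keras.layers.LSTM(lstm_units)
        self.head = tf.keras.layers.Dense(num_classes)

    def call(self, inputs):
        # Ops up to and including the embedding are requested on the CPU.
        with tf.device(CPU_DEVICE):
            x = self.embedding(inputs)
        # Everything after the embedding is requested on the GPU (if present).
        with tf.device(GPU_DEVICE):
            x = self.lstm(x)
            return self.head(x)


# Hypothetical sizes, for illustration only.
model = SplitDeviceModel(vocab_size=1000, embed_dim=16, lstm_units=32, num_classes=4)
tokens = tf.random.uniform((8, 20), maxval=1000, dtype=tf.int32)
logits = model(tokens)
print(logits.shape)  # (8, 4)
```

Keeping the scopes in call() also means that wrapping the model in tf.function, or calling fit(), traces the same device requests.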
You can use the Keras functional API and put each group of layers in its own tf.device block:
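    import tensorflow as tf
    inputs = tf.keras.layers.Input(...)
    # Up to and including the embedding: CPU
    with tf.device("/CPU:0"):
        x = tf.keras.layers.Embedding(...)(inputs)
    # Everything after the embedding: GPU
    with tf.device("/GPU:0"):
        outputs = tf.keras.layers.LSTM(...)(x)
    model = tf.keras.Model(inputs=inputs, outputs=outputs)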
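The sketch below fills in hypothetical sizes (VOCAB_SIZE, EMBED_DIM, LSTM_UNITS, SEQ_LEN are illustrative, not from the question) so the functional version runs end to end. It also turns on tf.debugging.set_log_device_placement(True), which prints the device each op lands on. I'm not certain that scopes entered only while a functional model is being built are honored when it later executes in every TF/Keras version, so it is worth checking that log, or moving the scopes into call() of a subclassed model:

```python
import tensorflow as tf

# Print the device every op is placed on (goes to the log/stderr).
tf.debugging.set_log_device_placement(True)

# Hypothetical sizes, for illustration only.
VOCAB_SIZE, EMBED_DIM, LSTM_UNITS, SEQ_LEN = 1000, 16, 32, 20
# Fall back to the CPU if no GPU is visible so the sketch still runs.
GPU_DEVICE = "/GPU:0" if tf.config.list_physical_devices("GPU") else "/CPU:0"

inputs = tf.keras.Input(shape=(SEQ_LEN,), dtype="int32")
with tf.device("/CPU:0"):
    x = tf.keras.layers.Embedding(VOCAB_SIZE, EMBED_DIM)(inputs)
with tf.device(GPU_DEVICE):
    outputs = tf.keras.layers.LSTM(LSTM_UNITS)(x)
model = tf.keras.Model(inputs=inputs, outputs=outputs)

# Run one batch so the placement log shows where each op executed.
tokens = tf.random.uniform((8, SEQ_LEN), maxval=VOCAB_SIZE, dtype=tf.int32)
features = model(tokens)
print(features.shape)  # (8, 32)
```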