I am attempting to use BERT Multilingual from TensorFlow Hub (https://tfhub.dev/google/bert_multi_cased_L-12_H-768_A-12/1) as a layer in a Keras model. Training the model without using a distribution strategy works fine. However, when trying to utilize a Google Cloud TPU through a distribution strategy, training the model fails with the following error:
ValueError: Variable (<tf.Variable 'bert/embeddings/word_embeddings:0' shape=(119547, 768) dtype=float32>) was not created in the distribution strategy scope of (<tensorflow.python.distribute.tpu_strategy.TPUStrategy object at 0x7fc7b01d52e8>). It is most likely due to not all layers or the model or optimizer being created outside the distribution strategy scope. Try to make sure your code looks similar to the following.
with strategy.scope():
    model=_create_model()
    model.compile(...)
Here is my code for building and training the model:
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub

def bert_model():
    in_id = tf.keras.layers.Input(shape=(MAX_SEQ_LENGTH,), name="input_ids", dtype=np.int32)
    in_mask = tf.keras.layers.Input(shape=(MAX_SEQ_LENGTH,), name="input_masks", dtype=np.int32)
    in_segment = tf.keras.layers.Input(shape=(MAX_SEQ_LENGTH,), name="segment_ids", dtype=np.int32)
    bert_inputs = {"input_ids": in_id, "input_mask": in_mask, "segment_ids": in_segment}
    bert_output = hub.KerasLayer(BERT_MODEL_HUB, trainable=True, signature="tokens", output_key="pooled_output")(bert_inputs)
    dense = tf.keras.layers.Dense(256, input_shape=(768,), activation='relu')(bert_output)
    pred = tf.keras.layers.Dense(len(unique_labels), activation='sigmoid')(dense)
    return tf.keras.models.Model(inputs=bert_inputs, outputs=pred)

resolver = tf.distribute.cluster_resolver.TPUClusterResolver(TPU_ADDRESS)
tf.config.experimental_connect_to_host(resolver.master())
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.experimental.TPUStrategy(resolver)

with strategy.scope():
    model = bert_model()
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE),
        loss=tf.keras.losses.binary_crossentropy,
        metrics=["accuracy"]
    )
I am using Python 3.5.3 and TensorFlow v2.0.0-rc2-26-g64c3d38 with GPU support.
This is likely a bug in TF 2.0. Ian also asked about it at https://github.com/tensorflow/hub/issues/469, where workarounds (upgrading) are being discussed.
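To confirm which variables trigger the error before upgrading, one option is to ask the strategy directly whether each variable was created under its scope. The sketch below assumes the public `tf.distribute.StrategyExtended.variable_created_in_scope` method and assumes `model.weights` yields `tf.Variable` objects, as in the TF 2.x `tf.keras` used in the question. The helper name `variables_outside_scope` is hypothetical, and the demo uses a CPU-only `MirroredStrategy` purely as a stand-in for `TPUStrategy` so it can run without a TPU.

```python
import tensorflow as tf


def variables_outside_scope(strategy, variables):
    """Return names of variables that were NOT created under `strategy`.

    Hypothetical helper: relies on
    tf.distribute.StrategyExtended.variable_created_in_scope, which reports
    whether a variable was created while this strategy's scope was active.
    """
    return [
        v.name
        for v in variables
        if not strategy.extended.variable_created_in_scope(v)
    ]


if __name__ == "__main__":
    # Stand-in for TPUStrategy (assumption) so the sketch runs on a CPU.
    strategy = tf.distribute.MirroredStrategy(["/cpu:0"])

    # Created without an active strategy scope: should be reported.
    outside = tf.Variable(0.0, name="created_outside")
    with strategy.scope():
        # Created inside the scope: should not be reported.
        inside = tf.Variable(0.0, name="created_inside")

    print(variables_outside_scope(strategy, [inside, outside]))
```

In the question's setting, one could call `variables_outside_scope(strategy, model.weights)` right after `model = bert_model()` inside the `with strategy.scope():` block; names such as `bert/embeddings/word_embeddings:0` in the result would indicate variables the Hub layer created outside the TPU strategy despite the surrounding scope.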