Ian

Reputation: 316

Compiling model using BERT as a hub.KerasLayer fails in TPUStrategy scope

I am attempting to use BERT Multilingual from TensorFlow Hub (https://tfhub.dev/google/bert_multi_cased_L-12_H-768_A-12/1) as a layer in a Keras model. Training the model without using a distribution strategy works fine. However, when trying to utilize a Google Cloud TPU through a distribution strategy, training the model fails with the following error:

ValueError: Variable (<tf.Variable 'bert/embeddings/word_embeddings:0' shape=(119547, 768) dtype=float32>) was not created in the distribution strategy scope of (<tensorflow.python.distribute.tpu_strategy.TPUStrategy object at 0x7fc7b01d52e8>). It is most likely due to not all layers or the model or optimizer being created outside the distribution strategy scope. Try to make sure your code looks similar to the following.
with strategy.scope():
  model=_create_model()
  model.compile(...)

Here is my code for building and training the model:

import numpy as np
import tensorflow as tf
import tensorflow_hub as hub

def bert_model():
    # Integer input tensors expected by the BERT module's "tokens" signature.
    in_id = tf.keras.layers.Input(shape=(MAX_SEQ_LENGTH,), name="input_ids", dtype=np.int32)
    in_mask = tf.keras.layers.Input(shape=(MAX_SEQ_LENGTH,), name="input_masks", dtype=np.int32)
    in_segment = tf.keras.layers.Input(shape=(MAX_SEQ_LENGTH,), name="segment_ids", dtype=np.int32)
    bert_inputs = {"input_ids": in_id, "input_mask": in_mask, "segment_ids": in_segment}

    # Wrap the Hub module as a Keras layer and take the pooled [CLS] output.
    bert_output = hub.KerasLayer(BERT_MODEL_HUB, trainable=True, signature="tokens", output_key="pooled_output")(bert_inputs)

    dense = tf.keras.layers.Dense(256, input_shape=(768,), activation='relu')(bert_output)
    pred = tf.keras.layers.Dense(len(unique_labels), activation='sigmoid')(dense)

    return tf.keras.models.Model(inputs=bert_inputs, outputs=pred)

# Connect to the Cloud TPU and create the distribution strategy.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(TPU_ADDRESS)
tf.config.experimental_connect_to_host(resolver.master())
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.experimental.TPUStrategy(resolver)

# Build and compile the model inside the strategy scope, as the error message suggests.
with strategy.scope():
    model = bert_model()

    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE),
        loss=tf.keras.losses.binary_crossentropy,
        metrics=["accuracy"]
    )

I am using Python 3.5.3 and TensorFlow v2.0.0-rc2-26-g64c3d38 with GPU support.

Upvotes: 2

Views: 1459

Answers (1)

arnoegw

Reputation: 1238

This is likely a bug in TF 2.0. Ian has also raised it, and workarounds (upgrading) are being discussed, at https://github.com/tensorflow/hub/issues/469
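
For reference, here is a minimal sketch of what an upgraded setup could look like. The TF version (2.3+), the TF2-style SavedModel handle https://tfhub.dev/tensorflow/bert_multi_cased_L-12_H-768_A-12/3, and its dict-based inputs and outputs are assumptions for illustration, not details confirmed in this answer; the essential point is that the hub.KerasLayer and the model are both created inside strategy.scope().

import tensorflow as tf
import tensorflow_hub as hub

MAX_SEQ_LENGTH = 128          # hypothetical sequence length
TPU_ADDRESS = "grpc://..."    # placeholder for your TPU address

resolver = tf.distribute.cluster_resolver.TPUClusterResolver(TPU_ADDRESS)
tf.config.experimental_connect_to_cluster(resolver)  # newer replacement for experimental_connect_to_host
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.TPUStrategy(resolver)        # non-experimental in TF >= 2.3

with strategy.scope():
    # Assumed TF2 SavedModel: takes a dict of int32 tensors, returns a dict of outputs.
    encoder = hub.KerasLayer(
        "https://tfhub.dev/tensorflow/bert_multi_cased_L-12_H-768_A-12/3",
        trainable=True)

    inputs = {
        "input_word_ids": tf.keras.layers.Input(shape=(MAX_SEQ_LENGTH,), dtype=tf.int32, name="input_word_ids"),
        "input_mask": tf.keras.layers.Input(shape=(MAX_SEQ_LENGTH,), dtype=tf.int32, name="input_mask"),
        "input_type_ids": tf.keras.layers.Input(shape=(MAX_SEQ_LENGTH,), dtype=tf.int32, name="input_type_ids"),
    }
    pooled = encoder(inputs)["pooled_output"]
    dense = tf.keras.layers.Dense(256, activation="relu")(pooled)
    pred = tf.keras.layers.Dense(1, activation="sigmoid")(dense)  # single label, for illustration only

    model = tf.keras.Model(inputs=inputs, outputs=pred)
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
        loss=tf.keras.losses.binary_crossentropy,
        metrics=["accuracy"])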

Upvotes: 1
