Reputation: 1706
I am using transformers' TFBertForSequenceClassification.from_pretrained (with 'bert-base-multilingual-uncased') and Keras to build my model.
import tensorflow as tf
from transformers import TFBertForSequenceClassification

# loss
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
# metric
metric = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')
# optimizer (learning_rate and epsilon are set earlier in the script)
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate, epsilon=epsilon)

# create and compile the Keras model in the context of strategy.scope
model = TFBertForSequenceClassification.from_pretrained(pretrained_weights,
                                                        num_labels=num_labels,
                                                        cache_dir=pretrained_model_dir)
model._name = 'tf_bert_classification'

# compile Keras model
model.compile(optimizer=optimizer,
              loss=loss,
              metrics=[metric])
I am using SST-2 data, which are tokenized and then fed to the model for training. The data have the following shape:
shape: (32,)
dict structure
dim: 3
[input_ids / attention_mask / token_type_ids ]
[(32, 128) / (32, 128) / (32, 128) ]
[ndarray / ndarray / ndarray ]
and here is an example:
({'input_ids': <tf.Tensor: shape=(32, 128), dtype=int32, numpy=
array([[ 101, 21270, 94696, ..., 0, 0, 0],
[ 101, 143, 45100, ..., 0, 0, 0],
[ 101, 24220, 102, ..., 0, 0, 0],
...,
[ 101, 11008, 10346, ..., 0, 0, 0],
[ 101, 43062, 15648, ..., 0, 0, 0],
[ 101, 13178, 18418, ..., 0, 0, 0]], dtype=int32)>, 'attention_mask': ....
As we can see, we have input_ids with shape (32, 128), where 32 is the batch size and 128 is the maximum sequence length (the max for BERT is 512). We also have attention_mask and token_type_ids with the same structure.
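For context, here is a minimal sketch of how such a dataset can be produced with the tokenizer (using the transformers 3.x padding/truncation API; the sentences and labels below are made-up placeholders, not my actual SST-2 pipeline):

import tensorflow as tf
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-uncased')

sentences = ['a masterpiece of film making', 'flat and uninspired']  # placeholders
labels = [1, 0]

encodings = tokenizer(sentences,
                      max_length=128,        # max sequence length shown above
                      padding='max_length',  # pad every example to 128 tokens
                      truncation=True,
                      return_tensors='tf')   # input_ids / attention_mask / token_type_ids

train_dataset = tf.data.Dataset.from_tensor_slices(
    (dict(encodings), labels)).batch(32)     # batch size 32, as shown above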
I am able to train the model and to evaluate it using model.evaluate(test_dataset). All good.
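The training and evaluation calls are the standard Keras ones (the epoch count here is illustrative):

# Train on the tokenized dataset, then evaluate on the held-out split;
# `test_dataset` has the same (features_dict, labels) structure.
history = model.fit(train_dataset, epochs=2)
loss_and_accuracy = model.evaluate(test_dataset)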
The issue I am having is that when I serve the model on GCP, it requires data in a different input shape and structure! I see the same thing if I run the CLI on the saved model:
saved_model_cli show --dir $MODEL_LOCAL --tag_set serve --signature_def serving_default
The given SavedModel SignatureDef contains the following input(s):
inputs['input_ids'] tensor_info:
dtype: DT_INT32
shape: (-1, 5)
name: serving_default_input_ids:0
The given SavedModel SignatureDef contains the following output(s):
outputs['output_1'] tensor_info:
dtype: DT_FLOAT
shape: (-1, 2)
name: StatefulPartitionedCall:0
Method name is: tensorflow/serving/predict
As we can see, we only need to give input_ids and not attention_mask and token_type_ids, and the shape is different. While the batch size not being defined (-1) is expected, the maximum length is 5 instead of 128! It was working 2 months ago, so I probably introduced something that created this issue.
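The signature can also be inspected from Python (a small sketch; saved_model_dir is an assumed placeholder for the export path):

import tensorflow as tf

# Load the exported SavedModel and print its serving signature
# (`saved_model_dir` is an assumed placeholder path).
loaded = tf.saved_model.load(saved_model_dir)
serving_fn = loaded.signatures['serving_default']
print(serving_fn.structured_input_signature)  # only input_ids, shape (None, 5)
print(serving_fn.structured_outputs)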
I tried a few versions of TensorFlow (2.2.0 and 2.3.0) and transformers (2.8.0, 2.9.0 and 3.0.2). I cannot see the Keras model's input and output shapes (they are None):
model.inputs
model.outputs
model.summary()
Model: "tf_bert_classification"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
bert (TFBertMainLayer) multiple 167356416
_________________________________________________________________
dropout_37 (Dropout) multiple 0
_________________________________________________________________
classifier (Dense) multiple 1538
=================================================================
Total params: 167,357,954
Trainable params: 167,357,954
Non-trainable params: 0
Any idea what could explain that the saved model requires a different input than the one used for training? I could use the Keras functional API and define the input shape, but I am pretty sure this code was working before.
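For reference, here is a sketch of the functional-API workaround I mean (the 128 length and input names are taken from my training setup; I have not verified that this fixes the export):

import tensorflow as tf
from transformers import TFBertForSequenceClassification

# Wrap the transformers model so the Keras inputs (and hence the exported
# signature) have an explicit (None, 128) shape instead of "multiple".
bert = TFBertForSequenceClassification.from_pretrained(
    'bert-base-multilingual-uncased', num_labels=2)

input_ids = tf.keras.Input(shape=(128,), dtype=tf.int32, name='input_ids')
attention_mask = tf.keras.Input(shape=(128,), dtype=tf.int32, name='attention_mask')
token_type_ids = tf.keras.Input(shape=(128,), dtype=tf.int32, name='token_type_ids')

logits = bert({'input_ids': input_ids,
               'attention_mask': attention_mask,
               'token_type_ids': token_type_ids})[0]

wrapped_model = tf.keras.Model(
    inputs=[input_ids, attention_mask, token_type_ids], outputs=logits)
wrapped_model.summary()  # input/output shapes are now defined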
Upvotes: 1
Views: 369
Reputation: 2410
I have seen such behavior when the model was instantiated from a pretrained one, then weights were loaded, and only then was it saved in fully-fledged Keras format. When I loaded the latter afterwards, it was not able to issue correct predictions because its signatures had become garbage: attention_mask disappeared, seq_length changed, and dummy None inputs appeared out of nowhere. So try saving your model in Keras format right after fitting (without the intermediate loading from weights), if that's your case.
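A minimal sketch of that flow (the dataset name, epoch count, and export path are assumptions):

# Suggested flow: fit, then save the full model immediately, with no
# save_weights()/load_weights() round-trip in between.
model.fit(train_dataset, epochs=2)
model.save('exported_model/1', save_format='tf')  # full TF SavedModel

If the exported signature still shows the dummy (None, 5) input, you can also pass an explicit signature when saving (not part of the flow above, just a sketch assuming a max length of 128):

import tensorflow as tf

MAX_LEN = 128  # assumption: the max sequence length used during training

@tf.function(input_signature=[{
    'input_ids': tf.TensorSpec((None, MAX_LEN), tf.int32, name='input_ids'),
    'attention_mask': tf.TensorSpec((None, MAX_LEN), tf.int32, name='attention_mask'),
    'token_type_ids': tf.TensorSpec((None, MAX_LEN), tf.int32, name='token_type_ids'),
}])
def serving_fn(inputs):
    # return the logits under an explicit output name
    return {'logits': model(inputs)[0]}

tf.saved_model.save(model, 'exported_model/1',
                    signatures={'serving_default': serving_fn})

saved_model_cli should then report all three inputs with shape (-1, 128).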
Upvotes: 0