Reputation: 3324
I'm trying to fine tune resnet50 in half precision mode without success. It seems there are parts of model which are not compatible with float16
. Here is my code:
dtype='float16'
K.set_floatx(dtype)
K.set_epsilon(1e-4)
model = Sequential()
model.add(ResNet50(weights='imagenet', include_top=False, pooling='avg'))
and I get this error:
Traceback (most recent call last):
File "train_resnet.py", line 40, in <module>
model.add(ResNet50(weights='imagenet', include_top=False, pooling='avg'))
File "/usr/local/lib/python3.6/dist-packages/keras/applications/__init__.py", line 28, in wrapper
return base_fun(*args, **kwargs)
File "/usr/local/lib/python3.6/dist-packages/keras/applications/resnet50.py", line 11, in ResNet50
return resnet50.ResNet50(*args, **kwargs)
File "/usr/local/lib/python3.6/dist-packages/keras_applications/resnet50.py", line 231, in ResNet50
x = layers.BatchNormalization(axis=bn_axis, name='bn_conv1')(x)
File "/usr/local/lib/python3.6/dist-packages/keras/engine/base_layer.py", line 457, in __call__
output = self.call(inputs, **kwargs)
File "/usr/local/lib/python3.6/dist-packages/keras/layers/normalization.py", line 185, in call
epsilon=self.epsilon)
File "/usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py", line 1864, in normalize_batch_in_training
epsilon=epsilon)
File "/usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py", line 1839, in _fused_normalize_batch_in_training
data_format=tf_data_format)
File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/nn_impl.py", line 1329, in fused_batch_norm
name=name)
File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/gen_nn_ops.py", line 4488, in fused_batch_norm_v2
name=name)
File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/op_def_library.py", line 626, in _apply_op_helper
param_name=input_name)
File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/op_def_library.py", line 60, in _SatisfiesTypeConstraint
", ".join(dtypes.as_dtype(x).name for x in allowed_list)))
TypeError: Value passed to parameter 'scale' has DataType float16 not in list of allowed values: float32
Upvotes: 0
Views: 288
Reputation: 3324
This was a reported bug and upgrading to Keras==2.2.5
solved the issue.
Upvotes: 1