Reputation: 41
I'm having problems using the tensorflow_model_optimization library. I am developing code to prune an already trained neural network: I import the weights from an h5 file and then use tensorflow_model_optimization to prune the network. I get this error when I call the fit method:
InvalidArgumentError Traceback (most recent call last)
<ipython-input-26-ce9759e4dd53> in <module>
----> 1 model_for_pruning.fit_generator(base_treinamento, steps_per_epoch = 6000 /64, epochs = 5, validation_data = base_teste, validation_steps = 30, callbacks=callbacks)
~\anaconda3\envs\supernova\lib\site-packages\tensorflow\python\keras\engine\training.py in fit_generator(self, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, validation_freq, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)
1859 use_multiprocessing=use_multiprocessing,
1860 shuffle=shuffle,
-> 1861 initial_epoch=initial_epoch)
1862
1863 def evaluate_generator(self,
~\anaconda3\envs\supernova\lib\site-packages\tensorflow\python\keras\engine\training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)
1098 _r=1):
1099 callbacks.on_train_batch_begin(step)
-> 1100 tmp_logs = self.train_function(iterator)
1101 if data_handler.should_sync:
1102 context.async_wait()
~\anaconda3\envs\supernova\lib\site-packages\tensorflow\python\eager\def_function.py in __call__(self, *args, **kwds)
826 tracing_count = self.experimental_get_tracing_count()
827 with trace.Trace(self._name) as tm:
--> 828 result = self._call(*args, **kwds)
829 compiler = "xla" if self._experimental_compile else "nonXla"
830 new_tracing_count = self.experimental_get_tracing_count()
~\anaconda3\envs\supernova\lib\site-packages\tensorflow\python\eager\def_function.py in _call(self, *args, **kwds)
853 # In this case we have created variables on the first call, so we run the
854 # defunned version which is guaranteed to never create variables.
--> 855 return self._stateless_fn(*args, **kwds) # pylint: disable=not-callable
856 elif self._stateful_fn is not None:
857 # Release the lock early so that multiple threads can perform the call
~\anaconda3\envs\supernova\lib\site-packages\tensorflow\python\eager\function.py in __call__(self, *args, **kwargs)
2941 filtered_flat_args) = self._maybe_define_function(args, kwargs)
2942 return graph_function._call_flat(
-> 2943 filtered_flat_args, captured_inputs=graph_function.captured_inputs) # pylint: disable=protected-access
2944
2945 @property
~\anaconda3\envs\supernova\lib\site-packages\tensorflow\python\eager\function.py in _call_flat(self, args, captured_inputs, cancellation_manager)
1917 # No tape is watching; skip to running the function.
1918 return self._build_call_outputs(self._inference_function.call(
-> 1919 ctx, args, cancellation_manager=cancellation_manager))
1920 forward_backward = self._select_forward_and_backward_functions(
1921 args,
~\anaconda3\envs\supernova\lib\site-packages\tensorflow\python\eager\function.py in call(self, ctx, args, cancellation_manager)
558 inputs=args,
559 attrs=attrs,
--> 560 ctx=ctx)
561 else:
562 outputs = execute.execute_with_cancellation(
~\anaconda3\envs\supernova\lib\site-packages\tensorflow\python\eager\execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
58 ctx.ensure_initialized()
59 tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
---> 60 inputs, attrs, num_outputs)
61 except core._NotOkStatusException as e:
62 if name is not None:
InvalidArgumentError: Conv2DCustomBackpropFilterOp only supports NHWC.
[[node gradient_tape/sequential_3/prune_low_magnitude_conv2d_18/Conv2D/Conv2DBackpropFilter (defined at <ipython-input-24-fc85f8818d30>:1) ]] [Op:__inference_train_function_10084]
Errors may have originated from an input operation.
Input Source operations connected to node gradient_tape/sequential_3/prune_low_magnitude_conv2d_18/Conv2D/Conv2DBackpropFilter:
sequential_3/prune_low_magnitude_activation_18/Relu (defined at C:\Users\Pichau\anaconda3\envs\supernova\lib\site-packages\tensorflow_model_optimization\python\core\sparsity\keras\pruning_wrapper.py:270)
Function call stack:
train_function
My code:
from keras.models import model_from_json
from keras.preprocessing.image import ImageDataGenerator
import numpy as np

# Rebuild the architecture from JSON and load the trained weights
arquivo = open('model.json', 'r')
estrutura_rede = arquivo.read()
arquivo.close()
model = model_from_json(estrutura_rede)
model.load_weights('model.h5')

gerador_treinamento = ImageDataGenerator(rescale=None)
base_treinamento = gerador_treinamento.flow_from_directory(
    'data/train', target_size=(51, 51), batch_size=64, class_mode='binary')
gerador_teste = ImageDataGenerator(rescale=None)
base_teste = gerador_teste.flow_from_directory(
    'data/test', target_size=(51, 51), batch_size=64, class_mode='binary')

import tempfile
import tensorflow as tf
import tensorflow_model_optimization as tfmot

prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

batch_size = 128
epochs = 2
validation_split = 0.1

# len(generator) counts batches, not images; .samples is the image count
num_images = base_treinamento.samples * (1 - validation_split)
end_step = np.ceil(num_images / batch_size).astype(np.int32) * epochs

# Wrap the trained model with the pruning wrapper; the schedule values are
# the usual tutorial defaults and reuse the end_step computed above
pruning_params = {
    'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(
        initial_sparsity=0.50, final_sparsity=0.80,
        begin_step=0, end_step=end_step)
}
model_for_pruning = prune_low_magnitude(model, **pruning_params)

model_for_pruning.compile(optimizer='adam',
                          loss='binary_crossentropy',
                          metrics=['accuracy'])
model_for_pruning.summary()

logdir = tempfile.mkdtemp()
callbacks = [
    tfmot.sparsity.keras.UpdatePruningStep(),
    tfmot.sparsity.keras.PruningSummaries(log_dir=logdir),
]

model_for_pruning.fit_generator(base_treinamento, steps_per_epoch=6000 // 64,
                                epochs=5, validation_data=base_teste,
                                validation_steps=30, callbacks=callbacks)
python: 3.6.12
tensorflow: 2.2.0
tensorflow-model-optimization: 0.5.0
Can someone help me?
Upvotes: 0
Views: 1760
Reputation: 4209
Changing the runtime from CPU to GPU worked for me. The CPU implementation of the Conv2D backprop kernels only supports the NHWC layout, while the GPU kernels also accept NCHW, so if you are running it on CPU consider switching to GPU.
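To confirm whether a GPU is actually visible to TensorFlow before training, a quick check using the standard TF 2.x API:

import tensorflow as tf

# An empty list means training will fall back to the CPU kernels,
# which only support NHWC
print(tf.config.list_physical_devices('GPU'))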
Upvotes: 1
Reputation: 61
This does not seem like a pruning issue, but rather a problem with the data format of the model. The error indicates at least one Conv2D layer is using NCHW, while the CPU kernel only supports NHWC (data_format="channels_last").
Could you share the code for the model?
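In the meantime, you can check yourself by rebuilding the model and printing each layer's data_format. The string replace at the end is a blunt, hypothetical workaround; it assumes the model's input shape does not also need to be transposed:

from keras.models import model_from_json

with open('model.json', 'r') as arquivo:
    config = arquivo.read()
model = model_from_json(config)

# Print the layout each layer was serialized with
for layer in model.layers:
    if hasattr(layer, 'data_format'):
        print(layer.name, layer.data_format)

# If any layer reports 'channels_first', rewriting the serialized config
# before rebuilding forces NHWC; Conv2D kernel weights have the same shape
# in both layouts, so load_weights still works
model = model_from_json(config.replace('"channels_first"', '"channels_last"'))
model.load_weights('model.h5')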
Upvotes: 0