Reputation: 797
I am trying to use a keras callbck to make prediction at the end of the batch as follows :
from keras.layers import Dense
from keras.models import Sequential
from keras.callbacks import Callback
from keras import backend as K
import tensorflow as tf
import numpy as np
class CollectOutputAndTarget(Callback):
def __init__(self):
super(CollectOutputAndTarget, self).__init__()
self.targets = [] # collect y_true batches
self.inputs = [] # collect y_true batches
self.outputs = [] # collect y_pred batches
self.preds = []
# the shape of these 2 variables will change according to batch shape
# to handle the "last batch", specify `validate_shape=False`
self.var_y_true = tf.Variable(0., validate_shape=False)
self.var_input = tf.Variable(0., validate_shape=False)
self.var_y_pred = tf.Variable(0., validate_shape=False)
def on_batch_end(self, batch, logs=None):
# evaluate the variables and save them into lists
self.targets.append(K.eval(self.var_y_true))
batch_inp = K.eval(self.var_input)
self.inputs.append(batch_inp)
self.outputs.append(K.eval(self.var_y_pred))
current_pred = self.model.predict(batch_inp)
self.preds.append(current_pred)
# build a simple model
K.clear_session()
# have to compile first for model.targets and model.outputs to be prepared
model = Sequential([Dense(5, input_shape=(2,)), Dense(2)])
model.compile(loss='mse', optimizer='adam')
# initialize the variables and the `tf.assign` ops
cbk = CollectOutputAndTarget()
fetches = [tf.assign(cbk.var_y_true, model.targets[0], validate_shape=False),
tf.assign(cbk.var_input, model.inputs[0], validate_shape=False),
tf.assign(cbk.var_y_pred, model.outputs[0], validate_shape=False)]
model._function_kwargs = {'fetches': fetches} # use `model._function_kwargs` if using `Model` instead of `Sequential`
# fit the model and check results
X = np.arange(10).reshape((5, 2))
Y = X*2
model.fit(X, Y, epochs=1, batch_size=3, callbacks=[cbk], shuffle=False)
And I am getting the following error :
InvalidArgumentError Traceback (most recent call last)
<ipython-input-114-adfad08009ad> in <module>
3 Y = X*2
4
----> 5 model.fit(X, Y, epochs=1, batch_size=3, callbacks=[cbk], shuffle=False)
/usr/local/lib/python3.6/dist-packages/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, **kwargs)
1037 initial_epoch=initial_epoch,
1038 steps_per_epoch=steps_per_epoch,
-> 1039 validation_steps=validation_steps)
1040
1041 def evaluate(self, x=None, y=None,
/usr/local/lib/python3.6/dist-packages/keras/engine/training_arrays.py in fit_loop(model, f, ins, out_labels, batch_size, epochs, verbose, callbacks, val_f, val_ins, shuffle, callback_metrics, initial_epoch, steps_per_epoch, validation_steps)
202 batch_logs[l] = o
203
--> 204 callbacks.on_batch_end(batch_index, batch_logs)
205 if callback_model.stop_training:
206 break
/usr/local/lib/python3.6/dist-packages/keras/callbacks.py in on_batch_end(self, batch, logs)
113 t_before_callbacks = time.time()
114 for callback in self.callbacks:
--> 115 callback.on_batch_end(batch, logs)
116 self._delta_ts_batch_end.append(time.time() - t_before_callbacks)
117 delta_t_median = np.median(self._delta_ts_batch_end)
<ipython-input-111-65feb418f9ec> in on_batch_end(self, batch, logs)
19 self.inputs.append(batch_inp)
20 self.outputs.append(K.eval(self.var_y_pred))
---> 21 current_pred = self.model.predict(batch_inp)
22 self.preds.append(current_pred)
/usr/local/lib/python3.6/dist-packages/keras/engine/training.py in predict_on_batch(self, x)
1272 ins = x
1273 self._make_predict_function()
-> 1274 outputs = self.predict_function(ins)
1275 return unpack_singleton(outputs)
1276
/usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py in __call__(self, inputs)
2713 return self._legacy_call(inputs)
2714
-> 2715 return self._call(inputs)
2716 else:
2717 if py_any(is_tensor(x) for x in inputs):
/usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py in _call(self, inputs)
2673 fetched = self._callable_fn(*array_vals, run_metadata=self.run_metadata)
2674 else:
-> 2675 fetched = self._callable_fn(*array_vals)
2676 return fetched[:len(self.outputs)]
2677
/usr/local/lib/python3.6/dist-packages/tensorflow/python/client/session.py in __call__(self, *args, **kwargs)
1456 ret = tf_session.TF_SessionRunCallable(self._session._session,
1457 self._handle, args,
-> 1458 run_metadata_ptr)
1459 if run_metadata:
1460 proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)
InvalidArgumentError: You must feed a value for placeholder tensor 'dense_2_target' with dtype float and shape [?,?]
[[{{node dense_2_target}}]]
I am able to get the model output at the end of the batch from the variable self.var_y_pred
which is assigned to model.outputs[0]
.
However from my understanding this prediction is done before the backpropagation at the current step. And my objective is to be able to make prediction on the current batch using the model version which weights were already updated with current batch training.
How can I achieve this ?
Upvotes: 1
Views: 2222
Reputation: 86650
The answer is "you can't".
The objects model.inputs
and model.outputs
are lists of "tensors", not data. Tensors are void graph representations.
The only way to get a batch prediction is calling model.predict_on_batch(input_data_as_numpy)
or similar methods. This means making the model predict the same thing twice in your case. A terrible performance drawback.
For using predicted batches during training you need to switch to using eager mode on and make a custom training loop:
Upvotes: 3