user23

Reputation: 9

Keras model.fit very, very slow compared to model.predict

I am working on a binary image classification task with large inputs (257 × 44101 ≈ 1.1 × 10^7 features per image). I have a small convolutional network defined in Keras with a TensorFlow backend that can, in principle, classify a batch of 500 such images in under a second using model.predict. Training the model is very slow by comparison: using model.fit on just a single image takes around 20 minutes per epoch. Is this sort of disparity to be expected? Are there any simple improvements?

Python code to reproduce this is below:

import tensorflow as tf
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dropout, Dense

model = Sequential()
model.add(Conv2D(filters=32, kernel_size=[4, 150], activation='relu', input_shape=(257, 44101, 1)))
model.add(MaxPooling2D(pool_size=(4, 500)))
model.add(Flatten())  # Flatten must be instantiated, not passed as a class
model.add(Dropout(0.3))
model.add(Dense(50))
model.add(Dense(2, activation='softmax'))
model.compile(optimizer='adam', loss=tf.keras.losses.SparseCategoricalCrossentropy())

# 100 dummy one-channel images, all labelled class 1
ex_tensor = tf.ones([100, 257, 44101, 1])
ex_labels = tf.ones([100, 1])

model(ex_tensor)  # Fast: forward pass only
# Very slow. steps_per_epoch=1 runs one batch, i.e. 32 images with fit's default batch_size.
model.fit(ex_tensor, ex_labels, epochs=1, steps_per_epoch=1)

Upvotes: 0

Views: 3340

Answers (1)

Hannah Morgan

Reputation: 153

Yes, that is to be expected.

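Predicting with a trained network is a single forward pass: a sequence of convolutions and matrix multiplications. Each training step does considerably more. It runs the same forward pass, then a backward pass (backpropagation) that computes the gradient of the loss with respect to every weight, and then an optimizer update; Adam additionally keeps two running averages per weight. The backward pass needs the intermediate activations from the forward pass, so they must stay in memory, and it produces a gradient tensor of the same shape for each of them. With inputs this large, those activations are huge, so training costs far more time and memory than prediction, and it takes many such steps before the model learns to distinguish one picture from another.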
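To put a rough number on the memory side for the model in the question: during training, a layer's forward-pass output has to be kept for the backward pass, which also computes a gradient of the same shape. The sketch below only does arithmetic on the shapes from the question's code (Keras Conv2D defaults to stride 1 and `padding='valid'`); the helper names are hypothetical, and 4 bytes per value assumes float32.

```python
def conv2d_valid_output(height, width, kernel_h, kernel_w):
    """Spatial output size of a stride-1 Conv2D with padding='valid' (the Keras default)."""
    return height - kernel_h + 1, width - kernel_w + 1


def activation_bytes(height, width, channels, bytes_per_value=4):
    """Memory for one image's activation tensor; 4 bytes per value assumes float32."""
    return height * width * channels * bytes_per_value


# Shapes from the question: 257 x 44101 x 1 input, 32 filters with a 4 x 150 kernel.
out_h, out_w = conv2d_valid_output(257, 44101, kernel_h=4, kernel_w=150)
conv_bytes = activation_bytes(out_h, out_w, channels=32)

print(f"Conv2D output per image: {out_h} x {out_w} x 32")  # 254 x 43952 x 32
print(f"{conv_bytes / 1e9:.2f} GB per image")               # 1.43 GB per image
print(f"{32 * conv_bytes / 1e9:.1f} GB per batch of 32")    # 45.7 GB per batch of 32
```

The gradient with respect to this output has the same shape, so the memory for this one layer roughly doubles during training. This is arithmetic on the shapes only, not a measurement of what TensorFlow actually allocates.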

You should never train on just one image: the model will simply memorize it. Train on a training set drawn from your data and reserve the rest (say 20%) for testing or validating the quality of your model.
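A minimal sketch of such a split, assuming the images are referenced by (hypothetical) file names so that the full arrays need not all be in memory at once. Keras' `model.fit` also accepts `validation_split=0.2`, but it takes the last 20% of the samples before any shuffling, so shuffling first, as here, avoids a validation set made of a single class when the data is sorted by label. The function name `train_val_split` is illustrative, not a library API.

```python
import random


def train_val_split(samples, labels, val_fraction=0.2, seed=0):
    """Shuffle sample indices, then hold out val_fraction of them for validation."""
    if len(samples) != len(labels):
        raise ValueError("samples and labels must have the same length")
    indices = list(range(len(samples)))
    random.Random(seed).shuffle(indices)  # seeded so the split is reproducible
    n_val = int(len(indices) * val_fraction)
    val_idx, train_idx = indices[:n_val], indices[n_val:]
    train = ([samples[i] for i in train_idx], [labels[i] for i in train_idx])
    val = ([samples[i] for i in val_idx], [labels[i] for i in val_idx])
    return train, val


# Hypothetical stand-ins for image file names and their 0/1 labels.
files = [f"img_{i:03d}.npy" for i in range(100)]
labels = [i % 2 for i in range(100)]
(train_x, train_y), (val_x, val_y) = train_val_split(files, labels)
print(len(train_x), len(val_x))  # 80 20
```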

Some ways to speed things up:

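  1. Decrease the number of features by scaling down your images. Roughly 11 million features per image (257 × 44,101) is A LOT of features.
  2. Decrease the complexity of your model or change its hyperparameters. For example, fewer convolutional filters shrink the largest activation tensor proportionally.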
  3. If those aren't acceptable, use a more powerful computer.
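As one way to do point 1, the sketch below averages non-overlapping blocks of each image with NumPy before the data ever reaches the model, so every epoch works on the smaller arrays. The function name is hypothetical, and the factor of 10 along the 44,101-sample axis is an arbitrary illustration: whether that much detail can be discarded without hurting accuracy has to be checked on your data, and the model's kernel and pooling widths along that axis would presumably need to shrink by a similar factor.

```python
import numpy as np


def block_mean_downsample(image, factor_h, factor_w):
    """Downsample a 2-D array by averaging non-overlapping factor_h x factor_w blocks.

    Rows/columns that do not fill a whole block are cropped off the end.
    """
    h = image.shape[0] // factor_h
    w = image.shape[1] // factor_w
    cropped = image[: h * factor_h, : w * factor_w]
    # Split each axis into (blocks, within-block) and average the within-block axes.
    return cropped.reshape(h, factor_h, w, factor_w).mean(axis=(1, 3))


# Same spatial size as the question's inputs; random values stand in for real data.
image = np.random.default_rng(0).random((257, 44101), dtype=np.float32)
small = block_mean_downsample(image, factor_h=1, factor_w=10)
print(image.size, small.shape)  # 11333957 (257, 4410)
```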

Upvotes: 3
