leenremm

How can I add the decode_batch_predictions() method into the Keras Captcha OCR model?

The current Keras Captcha OCR model returns CTC-encoded output (per-timestep character probabilities), which must be decoded after inference.

To decode this, one needs to run a decoding utility function after inference as a separate step.

preds = prediction_model.predict(batch_images)
pred_texts = decode_batch_predictions(preds)

The decoding utility function uses keras.backend.ctc_decode, which in turn uses either a greedy or a beam search decoder.

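import numpy as np
import tensorflow as tf
from tensorflow import keras

# A utility function to decode the output of the network.
# max_length (longest label length) and num_to_char (index -> character
# StringLookup layer) are defined earlier in the Keras Captcha OCR tutorial.
def decode_batch_predictions(pred):
    # Every sample uses all timesteps of the output sequence
    input_len = np.ones(pred.shape[0]) * pred.shape[1]
    # Use greedy search. For complex tasks, you can use beam search
    results = keras.backend.ctc_decode(pred, input_length=input_len, greedy=True)[0][0][
        :, :max_length
    ]
    # Iterate over the results and get back the text
    output_text = []
    for res in results:
        res = tf.strings.reduce_join(num_to_char(res)).numpy().decode("utf-8")
        output_text.append(res)
    return output_text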
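To see what greedy decoding actually does, here is a minimal NumPy sketch of it: take the most likely class at each timestep, collapse consecutive repeats, then drop the blank label. It assumes, as tf.nn.ctc_greedy_decoder does, that the blank is the last class index. greedy_ctc_decode is a hypothetical helper for illustration, not a drop-in replacement for keras.backend.ctc_decode.

```python
import numpy as np


def greedy_ctc_decode(probs, max_length=None):
    """Greedy CTC decoding of a batch of per-timestep class probabilities.

    probs: array of shape (batch, timesteps, num_classes).
    Assumption: the blank label is the last class index.
    """
    blank = probs.shape[-1] - 1
    best_path = probs.argmax(axis=-1)  # most likely class at each timestep
    decoded = []
    for path in best_path:
        labels = []
        prev = None
        for k in path:
            # Emit a label only if it differs from the previous timestep
            # (collapse repeats) and is not the blank.
            if k != prev and k != blank:
                labels.append(int(k))
            prev = k
        if max_length is not None:
            labels = labels[:max_length]  # same truncation as [:, :max_length]
        decoded.append(labels)
    return decoded
```

Note that a blank between two identical labels keeps both of them, which is how CTC represents doubled characters such as "ee".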

I would like to train a Captcha OCR model using Keras that returns the CTC-decoded result as its output, without requiring an additional decoding step after inference.

How would I achieve this?

Upvotes: 5

Answers (2)

leenremm

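The most robust way to achieve this is to add the decoding step as a layer in the model definition. The helper below returns a Lambda layer that runs tf.keras.backend.ctc_decode on the network output and pads the result with -1 to a fixed width: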

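import tensorflow as tf

def CTCDecoder():
  def decoder(y_pred):
    # y_pred: (batch, timesteps, num_classes) softmax output of the network
    input_shape = tf.keras.backend.shape(y_pred)
    # Every sample uses all timesteps
    input_length = tf.ones(shape=input_shape[0]) * tf.keras.backend.cast(
        input_shape[1], 'float32')
    # Greedy decoding (the default); the dense result is padded with -1
    # only up to the longest decoded sequence in the batch
    unpadded = tf.keras.backend.ctc_decode(y_pred, input_length)[0][0]
    unpadded_shape = tf.keras.backend.shape(unpadded)
    # Pad with -1 up to the number of timesteps so the output width is fixed
    padded = tf.pad(unpadded,
                    paddings=[[0, 0], [0, input_shape[1] - unpadded_shape[1]]],
                    constant_values=-1)
    return padded

  return tf.keras.layers.Lambda(decoder, name='decode')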
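The tf.pad call is there because ctc_decode returns a dense tensor padded with -1 only up to the longest decoded sequence in the current batch, so its width varies from batch to batch. Padding up to the number of timesteps gives the layer a predictable output width. A NumPy sketch of the same padding step (pad_to_timesteps is a hypothetical helper for illustration):

```python
import numpy as np


def pad_to_timesteps(dense_decoded, timesteps, pad_value=-1):
    """Right-pad a dense batch of decoded labels to `timesteps` columns.

    dense_decoded: int array of shape (batch, width), already padded with -1
    up to the longest decoded sequence in the batch.
    """
    width = dense_decoded.shape[1]
    if width > timesteps:
        raise ValueError("decoded width exceeds the number of timesteps")
    # Same paddings layout as the tf.pad call: nothing on the batch axis,
    # (timesteps - width) extra columns on the right of the label axis.
    return np.pad(dense_decoded, [[0, 0], [0, timesteps - width]],
                  constant_values=pad_value)
```

Since greedy CTC decoding emits at most one label per timestep, the decoded width can never exceed the number of timesteps, so the guard should not fire in practice.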

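Then define the prediction model as follows. Here inputs is the image input and model.output is the per-timestep softmax output of the network; in the Keras tutorial these correspond to model.get_layer(name="image").input and model.get_layer(name="dense2").output: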

prediction_model = keras.models.Model(inputs=inputs, outputs=CTCDecoder()(model.output))
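The decode layer outputs integer label indices padded with -1, not strings, so turning them into text is still a small lookup. A minimal sketch, assuming a hypothetical vocab list that maps class index to character (in the Keras tutorial the num_to_char StringLookup layer plays this role, and its index layout may differ):

```python
def indices_to_texts(padded, vocab):
    """Convert padded label indices into strings.

    padded: rows of class indices where -1 marks padding.
    vocab: hypothetical list mapping class index -> character.
    """
    texts = []
    for row in padded:
        # Skip the -1 padding and look up each remaining index
        texts.append("".join(vocab[i] for i in row if i >= 0))
    return texts
```

This works on plain lists as well as on the NumPy array returned by prediction_model.predict.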

Credit goes to tulasiram58827.

This implementation supports exporting to TFLite, but only as float32. Quantized (int8) TFLite export still fails with an error; this is an open issue with the TF team.

Upvotes: 2

dataista

Your question can be interpreted in two ways. One is: I want a neural network that solves a problem where the CTC decoding step is already part of what the network learned. The other is that you want a Model class that does the CTC decoding internally, without calling an external helper function.

I don't know the answer to the first question, and I cannot even tell whether it's feasible. In any case, it sounds like a difficult theoretical problem; if you have no luck here, you might want to try posting it on datascience.stackexchange.com, which is a more theory-oriented community.

Now, if what you are trying to solve is the second, engineering version of the problem, that's something I can help you with. The solution for that problem is the following:

You need to subclass keras.models.Model with a class that has the method you want. I went over the tutorial in the link you posted and came up with the following class:

class ModifiedModel(keras.models.Model):

    # A utility function to decode the output of the network.
    # As in the tutorial, np, tf, keras, max_length and num_to_char
    # must already be defined in the enclosing module.
    def decode_batch_predictions(self, pred):
        input_len = np.ones(pred.shape[0]) * pred.shape[1]
        # Use greedy search. For complex tasks, you can use beam search
        results = keras.backend.ctc_decode(pred, input_length=input_len, greedy=True)[0][0][
            :, :max_length
        ]
        # Iterate over the results and get back the text
        output_text = []
        for res in results:
            res = tf.strings.reduce_join(num_to_char(res)).numpy().decode("utf-8")
            output_text.append(res)
        return output_text

    # Run inference and decode the predictions in a single call
    def predict_texts(self, batch_images):
        preds = self.predict(batch_images)
        return self.decode_batch_predictions(preds)

You can give the class any name you want; ModifiedModel is just for illustration. With this class defined, you would replace the line

# Get the prediction model by extracting layers till the output layer
prediction_model = keras.models.Model(
    model.get_layer(name="image").input, model.get_layer(name="dense2").output
)

with

prediction_model = ModifiedModel(
    model.get_layer(name="image").input, model.get_layer(name="dense2").output
)

And then you can replace the lines

preds = prediction_model.predict(batch_images)
pred_texts = decode_batch_predictions(preds)

with

pred_texts = prediction_model.predict_texts(batch_images)

Upvotes: 1
