Ali

Reputation: 1

Export PyTorch model to ONNX

I am working on training and exporting a CRNN model for an Automatic License Plate Recognition (ALPR) task using PyTorch. My model includes a ctc_decode function that performs post-processing after the logits are generated in the forward pass. This decoding is used during inference but not during training.
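For context, during training I skip the decoding and compute the CTC loss on the raw logits, while at inference the decode runs inside forward. Roughly like this (names match the code below; images, targets, input_lengths and target_lengths are placeholders for my data pipeline):

# Training (sketch): raw logits only, decoding skipped
logits = model(images, apply_post_processing=False)           # [B, T, C]
log_probs = F.log_softmax(logits, dim=-1).permute(1, 0, 2)    # CTCLoss expects [T, B, C]
criterion = nn.CTCLoss(blank=model.blank_token)
loss = criterion(log_probs, targets, input_lengths, target_lengths)

# Inference (sketch): decoding runs inside forward
decoded = model(images)  # apply_post_processing defaults to True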

Here is my model:

import torch.onnx
import torch
import torch.nn as nn
import torch.nn.functional as F


class CRNNCTC(nn.Module):
    def __init__(self, image_height, image_width, num_classes, blank_token=34):

        super(CRNNCTC, self).__init__()

        self.blank_token = blank_token

        # Convolutional layers
        self.conv1 = nn.Conv2d(1, 64, kernel_size=(3, 3), padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.pool1 = nn.MaxPool2d((2, 2))

        self.conv2 = nn.Conv2d(64, 128, kernel_size=(3, 3), padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.pool2 = nn.MaxPool2d((2, 2))

        self.conv3 = nn.Conv2d(128, 256, kernel_size=(3, 3), padding=1)
        self.bn3 = nn.BatchNorm2d(256)
        self.pool3 = nn.MaxPool2d((2, 1))  # Pooling over time dimension

        # Calculate the feature size after convolutions and pooling
        self.flatten_size = (image_height // 8) * 256

        # Dense and dropout layers
        self.fc = nn.Linear(self.flatten_size, 128)
        self.dropout = nn.Dropout(0.2)

        # Recurrent layers (Bidirectional LSTM)
        self.lstm = nn.LSTM(128, 128, bidirectional=True,
                            batch_first=True, dropout=0.25)

        # Output layer (CTC)
        # *2 for bidirectional LSTM
        self.output = nn.Linear(128 * 2, num_classes)

    def forward(self, x, apply_post_processing=True):

        # print(f"Shape of Input: {x.shape}")
        # Convolutional layers
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.pool1(x)
        # print(f"Shape after CONV1: {x.shape}")

        x = F.relu(self.bn2(self.conv2(x)))
        x = self.pool2(x)
        # print(f"Shape after CONV2: {x.shape}")

        x = F.relu(self.bn3(self.conv3(x)))
        x = self.pool3(x)
        # print(f"Shape after CONV3: {x.shape}")

        # Reshape to prepare for LSTM: [B, C, H, W] -> [B, W, C*H]
        x = x.permute(0, 3, 1, 2)
        batch_size = x.size(0)
        x = x.view(batch_size, x.size(1), -1)
        # print(f"Shape before LSTM: {x.shape}")

        # Dense and dropout layers
        x = F.relu(self.fc(x))
        x = self.dropout(x)
        # print(f"Shape after FC: {x.shape}")

        # LSTM layers
        x, (hn, cn) = self.lstm(x)
        # print(f"Shape after LSTM: {x.shape}")

        # Output layer
        x = self.output(x)
        # print(f"Shape of Output: {x.shape}")

        # Post-processing: greedy CTC decoding (inference only)
        if apply_post_processing:
            # For inference: return the decoded output
            return self.ctc_decode(F.log_softmax(x, dim=-1))
        else:
            # For training: return the raw logits
            return x


    def ctc_decode(self, log_probs, target_length=8, pad_value=-1):
        """
        Greedy CTC decoding: collapse repeated characters and remove the
        blank token, then pad/truncate to a fixed length.

        Args:
            log_probs (torch.Tensor): Log probabilities from the output layer,
                shape [batch_size, time_steps, num_classes].
            target_length (int): Expected number of plate characters; each
                decoded sequence is padded/truncated to 2 * target_length
                entries (one probability and one character code per position).
            pad_value: Fill value used for padding.

        Returns:
            torch.Tensor of shape [batch_size, 2 * target_length].
        """
        my_idel_outputs = []
        for batch in log_probs:
            # Most probable class at each time step
            predicted_labels = torch.argmax(batch, dim=-1)  # [time_steps]
            pred_probs = batch.exp().detach().cpu()

            my_idel_output = []
            prev_char = self.blank_token  # initialize with the blank token
            for t, char in enumerate(predicted_labels):
                # Collapse repeated characters and ignore blank tokens
                if char != prev_char and char != self.blank_token:
                    my_idel_output.append(pred_probs[t, char].item())
                    # encode2actual maps class indices to character codes
                    # (defined elsewhere in my code)
                    my_idel_output.append(encode2actual[char.item()])
                prev_char = char

            # Dataset-specific adjustment of the third character code
            if len(my_idel_output) > 5:
                my_idel_output[5] = my_idel_output[5] - 9

            # Pad or truncate to exactly 2 * target_length entries
            fixed_length = 2 * target_length
            if len(my_idel_output) > fixed_length:
                my_idel_output = my_idel_output[:fixed_length]
            else:
                my_idel_output.extend(
                    [pad_value] * (fixed_length - len(my_idel_output)))
            my_idel_outputs.append(torch.tensor(
                my_idel_output, dtype=torch.float32))
        return torch.stack(my_idel_outputs)
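
# NOTE: ctc_decode builds Python lists, calls .item() and .cpu(), and loops
# over tensors in Python. The ONNX tracer only records tensor operations, so
# none of this can become part of the exported graph; the values seen during
# tracing would simply be baked in as constants.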


# Example usage (IMG_HEIGHT, IMG_WIDTH and unique_values are defined elsewhere):
model = CRNNCTC(IMG_HEIGHT, IMG_WIDTH,
                num_classes=len(unique_values) + 1)  # unique characters + CTC blank

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
print(device)


# Ensure the model is in eval mode
model.eval()

# Create a random input tensor with the correct dimensions
onnx_device = "cpu"
onnx_file_path = 'crnn_best_model.onnx'
torch_input = torch.randn(1, 1, 50, 200).to(
    onnx_device)  # Shape [1, channels, height, width]
model.to(onnx_device)

# Export the model to ONNX
torch.onnx.export(
    model,  # The model to export
    torch_input,  # Example input tensor
    onnx_file_path,  # Output ONNX model file
    input_names=["Input"],  # Name of the input
    output_names=["Scores_Labels"],  # Name of the output
    opset_version=12,  # Specify the ONNX opset version
    verbose=True  # Print the model structure
)
print(f"Model exported to {onnx_file_path}")

I want to export this model to ONNX for inference in a C++ environment, but the ctc_decode part is causing issues: torch.onnx.export traces forward with the default apply_post_processing=True, and ONNX does not support custom Python functions like ctc_decode within the forward pass. As far as I can tell, I need to either export the model without the decoding step and reimplement it in C++, or rewrite ctc_decode using only operations that ONNX can trace.
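
For the first option, I imagine wrapping the network so that only the traceable part is exported and the log probabilities are decoded on the C++ side. A rough, untested sketch of what I mean (CRNNExportWrapper is just a name I made up):

class CRNNExportWrapper(nn.Module):
    """Export only the traceable network; decoding happens outside the graph."""

    def __init__(self, crnn):
        super().__init__()
        self.crnn = crnn

    def forward(self, x):
        # Skip the Python-level ctc_decode entirely
        logits = self.crnn(x, apply_post_processing=False)
        return F.log_softmax(logits, dim=-1)  # [batch, time, classes]


wrapper = CRNNExportWrapper(model).eval()
torch.onnx.export(
    wrapper,
    torch_input,
    'crnn_logits.onnx',
    input_names=["Input"],
    output_names=["LogProbs"],
    opset_version=12,
)

Is this wrapper approach the right way to go, or is it practical to rewrite ctc_decode with pure tensor operations so that it can stay inside the exported graph?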

Upvotes: 0

Views: 121

Answers (0)
