lmsasu

Reputation: 7583

How do I export a darts TCNModel to ONNX?

I have a TCNModel that is trained on a time series:

from darts.models import TCNModel

model_air = TCNModel(
        input_chunk_length=13,
        output_chunk_length=12,
        n_epochs=3,
        dropout=0.1,
        dilation_base=2,
        weight_norm=True,
        kernel_size=5,
        num_filters=3,
        random_state=0,
        save_checkpoints=True,
        model_name=model_name,
        force_reset=True,
    )

According to the darts documentation, the inner model_air.model, of type darts.models.forecasting.tcn_model._TCNModule (derived from PLPastCovariatesModule), can be serialized to ONNX via a to_onnx call. However, every variant I have tried fails at export:

model_air.model.to_onnx('model_air.onnx')

ValueError: Could not export to ONNX since neither input_sample nor model.example_input_array attribute is set.

import torch

dummy_input = torch.randn(1, 13, 1)
model_air.model.to_onnx('model_air.onnx', input_sample=dummy_input)

ValueError: not enough values to unpack (expected 2, got 1)

dummy_input = (torch.randn(1, 13, 1), None)
model_air.model.to_onnx('model_air.onnx', input_sample=dummy_input)

TypeError: _TCNModule.forward() takes 2 positional arguments but 3 were given

What input value should I give to model_air.model?
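
A note on what the last two errors seem to indicate. They are consistent with forward() taking a single argument that it unpacks into two values, while the exporter spreads a tuple input_sample into separate positional arguments (torch.onnx.export treats a tuple of args this way). The sketch below reproduces that call pattern in plain Python, without darts or torch. FakeTCNModule and fake_export are hypothetical stand-ins written for illustration, not real APIs, and the unpacking inside forward() is an assumption inferred from the error messages.

```python
# Toy stand-ins only: no darts or torch involved.
# FakeTCNModule and fake_export are hypothetical names for illustration.

class FakeTCNModule:
    def forward(self, x_in):
        # Assumed contract, inferred from the errors in the question:
        # one argument that unpacks into (input window, static covariates).
        x, _ = x_in
        return x


def fake_export(module, input_sample):
    # Imitates how torch.onnx.export handles its args: a tuple is spread
    # into positional arguments, anything else is passed as one argument.
    if isinstance(input_sample, tuple):
        return module.forward(*input_sample)
    return module.forward(input_sample)


def outcome(input_sample):
    """Return the name of the exception raised for this sample, or 'ok'."""
    try:
        fake_export(FakeTCNModule(), input_sample)
    except (TypeError, ValueError) as exc:
        return type(exc).__name__
    return "ok"


# Nested list standing in for a tensor of shape (1, 13, 1).
window = [[[0.0] for _ in range(13)]]

print(outcome(window))              # bare input: unpacking fails
print(outcome((window, None)))      # pair is spread into two arguments
print(outcome(((window, None),)))   # pair wrapped once more: one argument
```

If this reading is right, the candidate to try with the real model is input_sample=((torch.randn(1, 13, 1), None),), that is, the (input, static covariates) pair wrapped in one more tuple. This is untested here, and whether the exporter accepts the None entry may depend on the torch version.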

Upvotes: 0

Views: 35

Answers (0)
