kbaker

Cannot upload combined torch.nn.Module and torch.jit model to MLflow due to missing __getstate__ method

I have a neural network model that has been provided to me as a torch.jit._script.RecursiveScriptModule, saved using torch.jit.script(torch_model).save(), so I don't have access to the underlying code for the model itself.
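For context, here is a minimal sketch of that save/load round trip. It uses torch.nn.Linear(17, 2) as a hypothetical stand-in for the provided model and an in-memory buffer instead of a file. It shows that torch.jit.load returns a RecursiveScriptModule, which is still a torch.nn.Module subclass. The Python source is gone, but the `.code` property prints the compiled TorchScript for forward.

```python
import io

import torch

# Hypothetical stand-in for the provided model; the real architecture is unknown.
provided = torch.nn.Linear(17, 2)

# Same idea as torch.jit.script(torch_model).save(path), but written to memory.
# torch.jit.save is used here because it accepts file-like objects.
buffer = io.BytesIO()
torch.jit.save(torch.jit.script(provided), buffer)
buffer.seek(0)

loaded = torch.jit.load(buffer)
print(type(loaded))                          # the RecursiveScriptModule class
print(isinstance(loaded, torch.nn.Module))   # a ScriptModule is still an nn.Module
print(loaded.code)                           # compiled TorchScript for forward
```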

I want to perform some additional calculations on this model's predictions and return the combined results as the final inference output. I plan to do this with a torch.nn.Module wrapper so that I can upload it to MLflow and register it for use in a formalised deployment pipeline.

Creating this combined model in code is not an issue, but uploading it with the final mlflow.pytorch.log_model line fails with the following error:

RuntimeError: Tried to serialize object __torch__.JitModel which does not have a __getstate__ method defined!

The code below shows a simplified example of this process:

import mlflow
import mlflow.pytorch
import torch


class JitModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(17, 2)

    def forward(self, x):
        return self.linear(x)


class InferenceModel(torch.nn.Module):
    def __init__(self, jit_model, model_name: str):
        super().__init__()
        self.model_name = model_name
        self.model = jit_model

    def forward(self, x):
        outputs = self.model(x.float())

        total_outputs = self.compute_additional_outputs(outputs)

        return total_outputs

    def compute_additional_outputs(self, outputs):
        add_outputs = torch.concat([outputs, outputs * 10], dim=-1)
        return add_outputs

if __name__ == "__main__":
    base_model = JitModel()
    jit_model = torch.jit.script(base_model)
    jit_model.save("test.pt")

    print(jit_model.__getstate__)

    jit_model = torch.jit.load("test.pt")
    print(type(jit_model))

    inference_model = InferenceModel(jit_model=jit_model, model_name="test")

    preds = inference_model(torch.rand(1, 17))
    print(preds.shape)
    mlflow.pytorch.log_model(inference_model, "model")
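The same error should be reproducible without MLflow. This assumes that mlflow.pytorch.log_model serialises a plain torch.nn.Module by pickling it with torch.save, and only uses TorchScript's own serialiser when the top-level object is a ScriptModule. In this minimal sketch, torch.nn.Linear(17, 2) is a hypothetical stand-in for the provided model:

```python
import io

import torch


class InferenceModel(torch.nn.Module):
    # Same wrapper as in the question, with the extra calculation inlined.
    def __init__(self, jit_model, model_name: str):
        super().__init__()
        self.model_name = model_name
        self.model = jit_model

    def forward(self, x):
        outputs = self.model(x.float())
        return torch.concat([outputs, outputs * 10], dim=-1)


# Hypothetical stand-in for the provided TorchScript model.
jit_model = torch.jit.script(torch.nn.Linear(17, 2))
wrapper = InferenceModel(jit_model, model_name="test")

# Eager inference through the wrapper works...
preds = wrapper(torch.rand(1, 17))
print(preds.shape)

# ...but pickling the wrapper (which is what torch.save does) reaches the
# scripted submodule and fails.
error = None
try:
    torch.save(wrapper, io.BytesIO())
except RuntimeError as exc:
    error = exc
print(error)
```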

However, the print statement suggests that the jit_model.__getstate__ method is defined:

print(jit_model.__getstate__)
>>
<bound method Module.__getstate__ of RecursiveScriptModule(
  original_name=JitModel
  (linear): RecursiveScriptModule(original_name=Linear)
)>

Given that I've previously been able to upload both a jit model and a plain torch.nn.Module model to MLflow on their own, I expected this combined model to work too. How can I resolve this error?
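For comparison, the two cases that worked on their own probably take different serialisation paths. Assume log_model uses torch.jit.save for a top-level ScriptModule and pickle-based torch.save otherwise. In that case each path succeeds in isolation, as this sketch with hypothetical stand-in models shows. Under the same assumption, the combined model falls between the two: its top level is a plain nn.Module, so it is pickled, but the ScriptModule inside it cannot be.

```python
import io

import torch

plain = torch.nn.Linear(17, 2)                        # plain nn.Module (stand-in)
scripted = torch.jit.script(torch.nn.Linear(17, 2))   # ScriptModule (stand-in)

# Path 1: pickle-based torch.save / torch.load handles a plain nn.Module.
plain_buf = io.BytesIO()
torch.save(plain, plain_buf)
plain_buf.seek(0)
# weights_only=False is needed to unpickle a whole module object.
plain_reloaded = torch.load(plain_buf, weights_only=False)

# Path 2: TorchScript's own serialiser handles a ScriptModule.
script_buf = io.BytesIO()
torch.jit.save(scripted, script_buf)
script_buf.seek(0)
script_reloaded = torch.jit.load(script_buf)

x = torch.rand(1, 17)
print(torch.allclose(plain(x), plain_reloaded(x)))
print(torch.allclose(scripted(x), script_reloaded(x)))
```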

As an alternative, I also tried converting the whole model to TorchScript with torch.jit.script(), but it incorrectly inferred some of the arguments I was passing to functions as tuples of tensors instead of plain tensors, so it would not compile.
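TorchScript assumes unannotated function arguments are Tensor and otherwise relies on type annotations, so explicitly annotating the wrapper's methods is one way to pin down what each argument is. Below is a minimal sketch of scripting the question's wrapper that way. It assumes the extra calculation only needs single-Tensor inputs and outputs, uses a scripted torch.nn.Linear(17, 2) as a hypothetical stand-in for the provided model, and swaps torch.concat for torch.cat as a conservative choice. Whether this avoids the tuple-inference problem in the real, larger model is untested.

```python
import io

import torch


class InferenceModel(torch.nn.Module):
    def __init__(self, jit_model: torch.jit.ScriptModule, model_name: str):
        super().__init__()
        self.model_name = model_name
        self.model = jit_model

    # Explicit annotations tell TorchScript these are single Tensors.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        outputs = self.model(x.float())
        return self.compute_additional_outputs(outputs)

    def compute_additional_outputs(self, outputs: torch.Tensor) -> torch.Tensor:
        return torch.cat([outputs, outputs * 10], dim=-1)


# Hypothetical stand-in for the provided TorchScript model.
jit_model = torch.jit.script(torch.nn.Linear(17, 2))

# Script the whole wrapper; the already-scripted submodule is kept as a submodule.
scripted_wrapper = torch.jit.script(InferenceModel(jit_model, model_name="test"))

# The result serialises with TorchScript's own format rather than pickle.
buffer = io.BytesIO()
torch.jit.save(scripted_wrapper, buffer)
buffer.seek(0)
reloaded = torch.jit.load(buffer)
print(reloaded(torch.rand(1, 17)).shape)
```

If this compiles for the real model, the result is a top-level ScriptModule. It would then presumably go through the same serialisation path as the jit model that uploaded successfully on its own.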

I am using the following versions with Python 3.11.8:

MLflow version: 2.16.2
torch version: 2.4.0
