I have a neural network model that has been provided to me as a torch.jit._script.RecursiveScriptModule, saved using torch.jit.script(torch_model).save(), so I don't have access to the underlying code for the model itself.
I want to use the results of this model to perform some additional calculations on the predictions and return the combined results as part of the final inference. I plan to do this using a torch.nn.Module wrapper object so I can upload it to MLflow and register it for use as part of a formalised deployment pipeline.
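The registration step I have in mind is just the standard log-and-register call, roughly as below (a minimal sketch: the registered model name is a placeholder, and inference_model refers to the wrapper shown in the example further down):

import mlflow

with mlflow.start_run():
    # Log the wrapper and register it in the model registry in one call.
    mlflow.pytorch.log_model(
        inference_model,
        "model",
        registered_model_name="combined_inference_model",
    )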
Creating this combined model in code is not an issue, but when trying to upload the model using the final mlflow.pytorch.log_model line, it fails with the following error:
RuntimeError: Tried to serialize object __torch__.JitModel which does not have a __getstate__ method defined!
The code below shows a simplified example of this process:
import mlflow
import torch


class JitModel(torch.nn.Module):
    """Stand-in for the model that was provided to me as a saved TorchScript module."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(17, 2)

    def forward(self, x):
        return self.linear(x)


class InferenceModel(torch.nn.Module):
    """Wrapper that combines the jit model's predictions with additional outputs."""

    def __init__(self, jit_model, model_name: str):
        super().__init__()
        self.model_name = model_name
        self.model = jit_model

    def forward(self, x):
        outputs = self.model(x.float())
        total_outputs = self.compute_additional_outputs(outputs)
        return total_outputs

    def compute_additional_outputs(self, outputs):
        # Placeholder for the real post-processing of the predictions.
        add_outputs = torch.concat([outputs, outputs * 10], dim=-1)
        return add_outputs


if __name__ == "__main__":
    # Script and save the base model, then reload it the way it is provided to me.
    base_model = JitModel()
    jit_model = torch.jit.script(base_model)
    jit_model.save("test.pt")
    print(jit_model.__getstate__)

    jit_model = torch.jit.load("test.pt")
    print(type(jit_model))

    # Wrap the loaded ScriptModule and check that inference works.
    inference_model = InferenceModel(jit_model=jit_model, model_name="test")
    preds = inference_model(torch.rand(1, 17))
    print(preds.shape)

    # This is the line that raises the RuntimeError.
    mlflow.pytorch.log_model(inference_model, "model")
However, the print statement suggests that the jit_model.__getstate__ method is defined:
print(jit_model.__getstate__)
>>
<bound method Module.__getstate__ of RecursiveScriptModule(
original_name=JitModel
(linear): RecursiveScriptModule(original_name=Linear)
)>
Given that I've been able to independently upload both a jit model and a torch.nn.Module model directly to MLflow before, I was expecting this combined solution to work. How can I resolve this error?
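For reference, the standalone uploads that have worked for me before look roughly like this (a minimal sketch; the model definitions and artifact paths are only placeholders):

import mlflow
import torch

# A plain eager module and a scripted version of it; mlflow.pytorch.log_model
# has accepted both kinds of model for me in the past.
plain_model = torch.nn.Linear(17, 2)
scripted_model = torch.jit.script(plain_model)

with mlflow.start_run():
    mlflow.pytorch.log_model(plain_model, "plain_model")
    mlflow.pytorch.log_model(scripted_model, "scripted_model")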
As an alternative, I also tried converting the whole model I'm working with to the jit format using jit.script(), but it was incorrectly identifying some of the parameters I was passing to functions as Tuples of Tensors instead of just Tensors, and therefore it wouldn't compile correctly.
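In terms of the simplified example above, that alternative attempt is essentially the following (sketch only; the artifact name is illustrative, and on my real model this is the step where torch.jit.script mis-infers some Tensor parameters as Tuples of Tensors and fails to compile):

# Alternative attempt: script the combined wrapper and log the scripted module.
scripted_inference = torch.jit.script(inference_model)
mlflow.pytorch.log_model(scripted_inference, "scripted_model")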
I am using the following versions with Python 3.11.8:
MLflow version: 2.16.2
torch version: 2.4.0