I'm working on a university machine learning project and I need to save an "agent" (an object) that contains complex internal state I want to reuse later. I'm using pickle, but unfortunately it fails with this error: AttributeError: Can't pickle local object 'constant_fn.<locals>.func'
This is a piece of my code:
from finrl.agents.stablebaselines3.models import DRLAgent
import pickle
import os
if os.path.isfile("./filename_pi.obj"):
    print("-FILE FOUND-")
    file_pi = open('filename_pi.obj', 'rb')
    trained_a2c = pickle.load(file_pi)
    file_pi.close()
else:
    print("-FILE NOT FOUND-")
    #A2C
    print("Training A2C model")
    agent = DRLAgent(env=env_train)
    model_a2c = agent.get_model("a2c")
    trained_a2c = agent.train_model(model=model_a2c, tb_log_name="a2c", total_timesteps=50000)
    file_pi = open('filename_pi.obj', 'wb')
    pickle.dump(trained_a2c, file_pi)
    file_pi.close()
From similar questions, I understand that the problem is some object that is not defined at global (module) level. However, I can't modify anything inside .get_model and .train_model, because they are methods of a third-party library that I can't touch. Is there anything I can do? Maybe I shouldn't pickle trained_a2c at all? Or would you recommend a different approach?
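The "local object" in the error is a function defined inside another function. The standard pickle module stores functions by reference (module name plus qualified name), and a nested function cannot be looked up that way, so any object holding one as an attribute fails to pickle. Below is a minimal sketch of that behavior, independent of FinRL. Here constant_fn is a hypothetical stand-in that only mirrors the name in the traceback, not the library's actual code, and the dict stands in for a trained model.

```python
import pickle


def constant_fn(val):
    # Hypothetical stand-in, named after the function in the traceback:
    # it returns a function that is defined *inside* another function.
    def func(_progress):
        return val

    return func


def try_pickle(obj):
    """Return None if obj can be pickled, otherwise the exception pickle raised."""
    try:
        pickle.dumps(obj)
    except (AttributeError, pickle.PicklingError) as exc:
        # AttributeError on most Python 3 versions; newer versions may raise PicklingError.
        return exc
    return None


# A plain dict standing in for a trained model that keeps such a function
# as an attribute (for example a learning-rate schedule).
fake_model = {"learning_rate": 0.0007, "lr_schedule": constant_fn(0.0007)}

print(try_pickle(fake_model))               # an error naming constant_fn.<locals>.func
print(try_pickle({"weights": [1, 2, 3]}))   # None: plain data pickles fine
```

Since the nested function is created inside the library, the fix has to happen at the saving step rather than in your own code.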
If you look at the source code of the library, you can see how stored models are loaded, and adapt that to your own needs.
Models from stable-baselines3 can be loaded with modeltype.load(filename), where modeltype is a model class from the library, such as A2C.
Also, make sure to save a trained model with the save() method provided by stable-baselines3, so that it is stored in a form that load() can restore. Plain pickle is not a substitute here: your error shows that the trained model holds a locally defined function, which the standard pickle module cannot serialize.
from finrl.agents.stablebaselines3.models import DRLAgent
from stable_baselines3 import A2C
filename = "my_a2c_model"  # the .zip extension can be omitted when using load()
# load a trained model from file; pass the environment so that training can continue
model = A2C.load(filename, env=env_train)
# train the model again
agent = DRLAgent(env=env_train)
trained_a2c = agent.train_model(model=model, tb_log_name="a2c", total_timesteps=50000)
# save the new model with the library's own save() method
trained_a2c.save("my_new_model")  # saved as my_new_model.zip
More information can be found here:
Check this:
from finrl.agents.stablebaselines3.models import DRLAgent
import pickle
import os
if os.path.isfile("./filename_pi.obj"):
    print("-FILE FOUND-")
    file_pi = open('filename_pi.obj', 'rb')
    trained_a2c = pickle.load(file_pi)
    file_pi.close()
else:
    print("-FILE NOT FOUND-")
    #A2C
    print("Training A2C model")
    agent = DRLAgent(env=env_train)
    model_a2c = agent.get_model("a2c")
    trained_a2c = agent.train_model(model=model_a2c, tb_log_name="a2c", total_timesteps=50000)
    file_pi = open('filename_pi.obj', 'wb')
    pickle.dump(trained_a2c, file_pi)
    file_pi.close()
And this for better design:
from finrl.agents.stablebaselines3.models import DRLAgent
import pickle
import os
def train_a2c():
    #A2C
    print("Training A2C model")
    agent = DRLAgent(env=env_train)
    model_a2c = agent.get_model("a2c")
    trained_a2c = agent.train_model(model=model_a2c, tb_log_name="a2c", total_timesteps=50000)
    return trained_a2c
if os.path.isfile("./trained_a2c.obj"):
    print("-FILE FOUND-")
    file_pi = open('trained_a2c.obj', 'rb')
    trained_a2c = pickle.load(file_pi)
    file_pi.close()
else:
    print("-FILE NOT FOUND-")
    trained_a2c = train_a2c()
    file_pi = open('trained_a2c.obj', 'wb')
    pickle.dump(trained_a2c, file_pi)
    file_pi.close()
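Both versions above still end in pickle.dump on the trained model, which is the call that raises the error in the question. The train-or-load structure itself can be kept, though. Below is a sketch that takes the save and load steps as parameters, so the library's own save()/load() can be plugged in instead of pickle. load_or_train is a hypothetical helper name, not part of FinRL or stable-baselines3, and the commented wiring at the end is an untested assumption based on the A2C.load()/save() calls mentioned earlier in this thread.

```python
import os
from typing import Callable, TypeVar

T = TypeVar("T")


def load_or_train(
    path: str,
    train: Callable[[], T],
    load: Callable[[str], T],
    save: Callable[[T, str], None],
) -> T:
    """Load a model from `path` if that file exists; otherwise train, save and return it.

    `load` and `save` are passed in, so a library's own persistence
    methods can be used instead of pickle.
    """
    if os.path.isfile(path):
        print("-FILE FOUND-")
        return load(path)
    print("-FILE NOT FOUND-")
    model = train()
    save(model, path)
    return model


# Possible wiring with stable-baselines3 (not run here; assumes `env_train`
# and DRLAgent exist as in the question):
#
# from stable_baselines3 import A2C
# agent = DRLAgent(env=env_train)
# trained_a2c = load_or_train(
#     "trained_a2c.zip",  # assumed: save() keeps an explicit .zip extension as-is
#     train=lambda: agent.train_model(
#         model=agent.get_model("a2c"), tb_log_name="a2c", total_timesteps=50000
#     ),
#     load=lambda p: A2C.load(p, env=env_train),
#     save=lambda model, p: model.save(p),
# )
```

Injecting the persistence functions keeps the caching logic identical to the code above while moving the serialization choice to the caller, which is where the library-specific format belongs.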