pscndr

Reputation: 3

AttributeError: Can't pickle local object

I'm working on a university machine learning project and need to save an "agent" (an object with complex internal state that I need later on). I'm using pickle, but it fails with this error: AttributeError: Can't pickle local object 'constant_fn.<locals>.func'

This is the relevant piece of my code:


from finrl.agents.stablebaselines3.models import DRLAgent
import pickle
import os

if os.path.isfile("./filename_pi.obj"):
    print("-FILE FOUND-")
    with open('filename_pi.obj', 'rb') as file_pi:
        trained_a2c = pickle.load(file_pi)
else:
    print("-FILE NOT FOUND-")
    # A2C
    print("Training A2C model")
    agent = DRLAgent(env=env_train)
    model_a2c = agent.get_model("a2c")
    trained_a2c = agent.train_model(model=model_a2c, tb_log_name="a2c", total_timesteps=50000)
    with open('filename_pi.obj', 'wb') as file_pi:
        pickle.dump(trained_a2c, file_pi)  # the "Can't pickle" error comes from this call

From reading about similar problems, I understand that the cause is something that is not defined at global (module) level. However, I cannot modify anything inside .get_model and .train_model, because they are methods of a library that I did not write and cannot touch. Is there anything I can do? Maybe I shouldn't be passing "trained_a2c" to pickle? Or would you recommend a different approach?

Upvotes: 0

Views: 1306

Answers (2)

Jack O'Neill

Reputation: 1081

If you look at the source code of the library, you can see how stored models are loaded and adapt that to your own needs.

Models from stable-baselines3 can be loaded with modeltype.load(filename), where modeltype is a model class from the library, such as A2C.

Also, make sure to use the save() method provided by stable_baselines3 to save a trained model, so that it is stored properly. I'm not sure that plain pickle achieves the same.

from finrl.agents.stablebaselines3.models import DRLAgent
from stable_baselines3 import A2C

filename = "my_a2c_model" # don't have to include .zip extension, if using load()

# loading a trained model from file; env is passed so that training can continue
model = A2C.load(filename, env=env_train)

# train the model again
agent = DRLAgent(env=env_train)
trained_a2c = agent.train_model(model=model, tb_log_name="a2c", total_timesteps=50000)

# saving the new model with the provided save() method from the library:
trained_a2c.save("my_new_model") # will be saved to my_new_model.zip
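
As for why plain pickle fails here: pickle stores functions by their importable name, and a function defined inside another function has no such name. Below is a minimal sketch that reproduces the same kind of error with the standard library only. The constant_fn here is a stand-in written to match the name in the error message, not the library's actual code, and the value 0.0007 is just a placeholder.

```python
import pickle


def constant_fn(val):
    # Stand-in helper (an assumption, not the library's code): it returns
    # a closure. `func` only exists inside constant_fn, so it has no
    # module-level name that pickle could look up later.
    def func(_):
        return val
    return func


schedule = constant_fn(0.0007)  # placeholder value

try:
    pickle.dumps(schedule)
    closure_pickled = True
except (AttributeError, pickle.PicklingError) as exc:
    # The exception type and exact wording can differ between Python versions.
    closure_pickled = False
    print("pickle failed:", exc)
```

Any object that keeps a reference to such a function fails to pickle in the same way, which would explain why dumping the whole trained model does not work.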

more information can be found here:

Upvotes: 0

user17716493

Reputation:

Check this:

from finrl.agents.stablebaselines3.models import DRLAgent
import pickle
import os

if os.path.isfile("./filename_pi.obj"):
    print("-FILE FOUND-")
    with open('filename_pi.obj', 'rb') as file_pi:
        trained_a2c = pickle.load(file_pi)
else:
    print("-FILE NOT FOUND-")
    # A2C
    print("Training A2C model")
    agent = DRLAgent(env=env_train)
    model_a2c = agent.get_model("a2c")
    trained_a2c = agent.train_model(model=model_a2c, tb_log_name="a2c", total_timesteps=50000)
    with open('filename_pi.obj', 'wb') as file_pi:
        pickle.dump(trained_a2c, file_pi)

And this for better design:

from finrl.agents.stablebaselines3.models import DRLAgent
import pickle
import os

def train_a2c():
    #A2C
    print("Training A2C model")
    agent = DRLAgent(env=env_train)
    model_a2c = agent.get_model("a2c")
    trained_a2c = agent.train_model(model=model_a2c, tb_log_name="a2c", total_timesteps=50000)
    return trained_a2c

if os.path.isfile("./trained_a2c.obj"):
    print("-FILE FOUND-")
    with open('trained_a2c.obj', 'rb') as file_pi:
        trained_a2c = pickle.load(file_pi)
else:
    print("-FILE NOT FOUND-")
    trained_a2c = train_a2c()
    with open('trained_a2c.obj', 'wb') as file_pi:
        pickle.dump(trained_a2c, file_pi)
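
Note that splitting the training into a function probably does not change what pickle sees: the object returned by train_model would still hold the local function named in the error. A minimal sketch of the same load-or-train flow, with the persistence step passed in as callables so that a library's own save/load can be plugged in instead of pickle. The helper load_or_train and the fake_* functions are hypothetical names used only for illustration; the stand-ins let the sketch run without any RL library installed.

```python
import os
import tempfile


def load_or_train(path, load_fn, train_fn, save_fn):
    """Load a model from `path` if the file exists, else train and save one."""
    if os.path.isfile(path):
        print("-FILE FOUND-")
        return load_fn(path)
    print("-FILE NOT FOUND-")
    model = train_fn()
    save_fn(model, path)
    return model


# Hypothetical stand-ins for the real training, saving and loading steps.
def fake_train():
    return {"weights": [1, 2, 3]}


def fake_save(model, path):
    with open(path, "w") as fh:
        fh.write(",".join(str(w) for w in model["weights"]))


def fake_load(path):
    with open(path) as fh:
        return {"weights": [int(w) for w in fh.read().split(",")]}


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "trained_a2c.txt")
    first = load_or_train(path, fake_load, fake_train, fake_save)   # trains and saves
    second = load_or_train(path, fake_load, fake_train, fake_save)  # loads from file
```

With stable-baselines3 one would presumably pass A2C.load as load_fn, the train_a2c function above as train_fn, and lambda model, path: model.save(path) as save_fn, with a path ending in .zip. This is an assumption about how the pieces fit together, not something tested against FinRL.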

Upvotes: 0
