Ne-al

Reputation: 11

ActorCritic - model cannot be saved because the input shapes have not been set

I want to save an Actor-Critic model, but I get the error shown below.

import os
import tensorflow as tf
from tensorflow.keras.layers import Flatten, Dense, LSTM, BatchNormalization
from tensorflow.keras import Model


class ActorCritic(Model):
    def __init__(self, action_size, state_size):
        super(ActorCritic, self).__init__()
        self.lstm1 = LSTM(16, return_sequences=True, input_shape=state_size)
        self.lstm2 = LSTM(8, return_sequences=True)
        self.flatten = Flatten()
        self.policy = Dense(action_size, activation='linear')
        self.value = Dense(1, activation='linear')

    def call(self, x):
        x = self.lstm1(x)
        x = self.lstm2(x)
        x = self.flatten(x)
        policy = self.policy(x)
        value = self.value(x)

        return policy, value


class A3CAgent():

    def __init__(self):
        self.state_size = (9, 23)
        self.action_size = 2
        self.save_path = os.path.join(os.getcwd(), 'model')

        self.global_model = ActorCritic(self.action_size, self.state_size)
        self.global_model.build((None, *self.state_size))

        self.global_model.save(self.save_path)


if __name__ == "__main__":
    global_agent = A3CAgent()

Output:

Traceback (most recent call last):
ValueError: Model <__main__.ActorCritic object at 0x000001933D7F1E10> cannot be saved because the input shapes have not been set.
Usually, input shapes are automatically determined from calling `.fit()` or `.predict()`.
To manually set the shapes, call `model.build(input_shape)`.

I call self.global_model.build((None, *self.state_size)) before saving, but the error still occurs.

How should I call model.build(input_shape), or how else can I resolve this?

Upvotes: 1

Views: 339

Answers (2)

user11530462


Calling compute_output_shape after build, so that the model is completely built, worked, as shown below.

self.global_model.build((None, *self.state_size))
self.global_model.compute_output_shape(input_shape=(None, 9, 23))  # added this line

I only added the last line; the rest of your code is unchanged. Please check the gist here.
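As the error message in the question notes, input shapes are usually determined when the model is actually called. A minimal sketch of that route, assuming TensorFlow 2.x: run one forward pass on a dummy batch before saving. The dummy_state tensor and the dropped input_shape argument on the first LSTM are illustrative choices, not part of the original code, and whether save() then succeeds may depend on the TensorFlow/Keras version.

```python
import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import Dense, Flatten, LSTM


class ActorCritic(Model):
    def __init__(self, action_size):
        super().__init__()
        # input_shape is omitted here (assumption: a subclassed model
        # infers it from the first call, so the argument is not needed)
        self.lstm1 = LSTM(16, return_sequences=True)
        self.lstm2 = LSTM(8, return_sequences=True)
        self.flatten = Flatten()
        self.policy = Dense(action_size, activation='linear')
        self.value = Dense(1, activation='linear')

    def call(self, x):
        x = self.lstm1(x)
        x = self.lstm2(x)
        x = self.flatten(x)
        return self.policy(x), self.value(x)


state_size = (9, 23)
model = ActorCritic(action_size=2)

# Hypothetical dummy batch: a single all-zero state, used only so that
# the model sees a concrete input and creates its weights
dummy_state = tf.zeros((1, *state_size))
policy, value = model(dummy_state)

print(policy.shape, value.shape)
```

After this forward pass the model has concrete input and output shapes, which is the state the error message asks for before save() is called.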

Upvotes: 0

elbe

Reputation: 1508

It seems to me that the problem is not with the build method, which is fine, but that the save method of the Model class does not work for this custom (subclassed) model. Instead, you can use the TensorFlow checkpoint format, which lets you save the model together with other objects such as the optimizer (at a given iteration). You can find more information here: https://www.tensorflow.org/api_docs/python/tf/train/Checkpoint, https://www.tensorflow.org/guide/checkpoint

You can modify your code like this (remove self.global_model.save(self.save_path)):

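# Assumes the imports and the ActorCritic and A3CAgent classes from the question
# are defined above, with the save() call removed from A3CAgent.__init__.
if __name__ == "__main__":
    save_path = os.path.join(os.getcwd(), 'model')
    print(save_path)

    global_agent = A3CAgent()
    global_model = global_agent.global_model

    # Create the checkpoint object that tracks the model's variables
    checkpoint = tf.train.Checkpoint(model=global_model)

    # Restore the latest checkpoint if one exists; on the first run
    # latest_checkpoint() returns None and there is nothing to restore
    latest = tf.train.latest_checkpoint(save_path)
    if latest is not None:
        checkpoint.restore(latest).assert_consumed()

    # Train the model here and save every xxx iterations
    checkpoint.save(os.path.join(save_path, 'ckpt'))

    # Reload the latest saved checkpoint
    checkpoint.restore(tf.train.latest_checkpoint(save_path)).assert_consumed()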
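The save-and-restore cycle can be tried in isolation before wiring it into the agent. The following is a minimal sketch, assuming TensorFlow 2.x in eager mode; it tracks a plain tf.Variable instead of the ActorCritic model, and the save_and_restore_roundtrip helper, the step variable, and the temporary directory are hypothetical names used only for illustration.

```python
import os
import tempfile

import tensorflow as tf


def save_and_restore_roundtrip(ckpt_dir):
    """Restore the latest checkpoint if any, 'train', save, then reload."""
    step = tf.Variable(0, dtype=tf.int64)
    checkpoint = tf.train.Checkpoint(step=step)

    # latest_checkpoint() returns None when the directory has no checkpoint
    latest = tf.train.latest_checkpoint(ckpt_dir)
    if latest is not None:
        checkpoint.restore(latest).assert_consumed()

    # Stand-in for training: advance the counter by five iterations
    step.assign_add(5)
    checkpoint.save(os.path.join(ckpt_dir, 'ckpt'))

    # Overwrite the value, then check that restoring brings it back
    step.assign(-1)
    checkpoint.restore(tf.train.latest_checkpoint(ckpt_dir)).assert_consumed()
    return int(step.numpy())


ckpt_dir = tempfile.mkdtemp()
first = save_and_restore_roundtrip(ckpt_dir)   # nothing to restore yet
second = save_and_restore_roundtrip(ckpt_dir)  # resumes from the first run
print(first, second)
```

Guarding the initial restore matters because calling assert_consumed() when no checkpoint exists raises an error rather than silently continuing.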

Upvotes: 1
