Alexander

Reputation: 533

How to set the input of a keras subclass model in tensorflow?

I've created a keras subclass model using tensorflow. Snippets are shown below.

from tensorflow.keras import Model

class SubModel(Model):
    def call(self, inputs):
        print(inputs)

model = SubModel()
model.fit(data, labels, ...)

When I call fit, the model picks up the inputs and input_shape by itself. What I want to do is pass the inputs to the model myself, just like the functional API does.

inputs = tf.keras.Input(shape=(100,))
model = tf.keras.Model(inputs=inputs, outputs=outputs)

Upvotes: 6

Views: 7282

Answers (3)

reexyyl

Reputation: 51

I ended up giving up on keras.Model subclassing. It was too tricky and I was getting errors about input shape.

I wanted to be able to use .fit() directly on my custom class model objects. An easy way I found to do this was to implement the special __getattr__ method (more info in the official Python docs).

With this, it is possible to use any method of keras.Model on our custom objects. The class implementation I use:

from tensorflow.keras import Input, layers, Model

class SubModel():
    def __init__(self):
        self.model = self.get_model()

    def get_model(self):
        # here we use the usual Keras functional API
        x = Input(shape=(24, 24, 3))
        y = layers.Conv2D(28, 3, strides=1)(x)
        return Model(inputs=[x], outputs=[y])

    def __getattr__(self, name):
        """
        Forward any attribute/method lookup that fails on this object to self.model.
        Thus, any method of keras.Model can be used transparently on a SubModel object.
        """
        return getattr(self.model, name)


if __name__ == '__main__':
    submodel = SubModel()
    submodel.fit(data, labels, ...)  # delegates to submodel.model.fit()

Explanation: __getattr__ is a special Python method called when the default attribute access fails, i.e. when name doesn't belong to self. In these cases, we try to access the attribute name of self.model.
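
To see the delegation without TensorFlow, here is a minimal sketch in plain Python. Engine and Wrapper are hypothetical stand-ins (Engine plays the role of the wrapped keras.Model); the guard on "model" is an addition of mine to avoid infinite recursion if self.model has not been set yet.

```python
class Engine:
    """Hypothetical stand-in for the wrapped keras.Model."""

    def fit(self, data, labels):
        return f"fit on {len(data)} samples"


class Wrapper:
    def __init__(self):
        self.model = Engine()
        self.name = "wrapper"

    def __getattr__(self, attr):
        # Only reached when normal lookup on the Wrapper instance fails.
        if attr == "model":
            # Guard: self.model is not set yet, so do not recurse into getattr.
            raise AttributeError(attr)
        return getattr(self.model, attr)


wrapper = Wrapper()
result = wrapper.fit([1, 2, 3], [0, 1, 0])  # not found on Wrapper, forwarded to Engine.fit
own_name = wrapper.name                     # found on Wrapper, __getattr__ is never called
```

Attributes that exist on neither object still raise AttributeError, because the inner getattr call fails in the usual way.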

Upvotes: 5

Daniel

Reputation: 432

If you want to be able to specify the input shape before calling model.fit, you can use model.build. It takes one positional parameter: input_shape.

As an aside (in case other people hit this issue): you also need to call build before you can call model.summary, and sometimes when using Dense layers, to avoid errors such as:

ValueError: The last dimension of the inputs to "Dense" should be defined. Found "None".

As an example:

from tensorflow import keras

class MyModel(keras.Model):
    def __init__(self, input_shape):
        super().__init__()
        # Example layers that would throw an error if we didn't call build
        self.convT1 = keras.layers.Conv2DTranspose(filters=1, kernel_size=10)
        self.dense = keras.layers.Dense(10)

        # metrics is passed as a list
        self.compile(optimizer='Adam', loss='mse', metrics=['acc'])

        # ! Call build and pass the input_shape
        self.build(input_shape)
        self.summary()  # Because we can now! (would fail without self.build)

    def call(self, inputs):
        # Forward pass (assumed; not in the original snippet): a subclassed model must define call()
        return self.dense(self.convT1(inputs))

model = MyModel(input_shape=(1, 1, 10, 10))

You could also call model.build after initialization instead of self.build.

Upvotes: -1

gdaras

Reputation: 10119

Something like this?

model_ = SubModel()
inputs = tf.keras.Input(shape=(100,))
outputs = model_(inputs)
model = tf.keras.Model(inputs=inputs, outputs=outputs)
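
For completeness, a self-contained sketch of the same idea. The single Dense(10) layer inside SubModel is only an assumption for illustration, since the question does not show the real layers; the point is that call() must return a tensor so the subclassed model can be wired into a functional tf.keras.Model.

```python
import tensorflow as tf


class SubModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        # Hypothetical layer; replace with the real architecture.
        self.dense = tf.keras.layers.Dense(10)

    def call(self, inputs):
        # Must return the output tensor so it can be used as `outputs` below.
        return self.dense(inputs)


model_ = SubModel()
inputs = tf.keras.Input(shape=(100,))  # symbolic input with an explicit shape
outputs = model_(inputs)               # calling the subclassed model on it builds its weights
model = tf.keras.Model(inputs=inputs, outputs=outputs)
```

The resulting model has a fixed input shape of (None, 100), so model.summary() and model.fit() should work without the shape having to be inferred from the data first.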

Upvotes: 5
