Sentient

Keras Converting Functional Model to Model Subclass

I'm trying to rewrite a Keras functional model using model subclassing, but calling summary() on the new subclassed model does not list any layers or parameters.

For reference, here is the functional model and its output.

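# assuming tf.keras; DoubleConv3D and ConcatConv3D are custom blocks (not shown)
from tensorflow.keras.layers import Input, Conv3D, MaxPooling3D
from tensorflow.keras.models import Model

filters = 32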

# placeholder for inputs
inputs = Input(shape=[16, 16, 16, 12])  

# L-hand side of UNet
conv1 = DoubleConv3D(filters*1)(inputs)
pool1 = MaxPooling3D()(conv1)
...

# middle bottleneck
conv5 = DoubleConv3D(filters*5)(pool4)

# R-hand side of UNet
rsdc6 = ConcatConv3D(filters*4)(conv5, conv4)
conv6 = DoubleConv3D(filters*4)(rsdc6)
...

# sigmoid activation
outputs = Conv3D(1, (1, 1, 1), activation='sigmoid')(conv9)

model = Model(inputs=[inputs], outputs=[outputs])
model.summary()
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_29 (InputLayer)           (None, 16, 16, 16, 1 0                                            
__________________________________________________________________________________________________
conv3d_111 (Conv3D)             (None, 16, 16, 16, 3 10400       input_29[0][0]                   
__________________________________________________________________________________________________
...
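The first Conv3D row can be checked by hand. A Conv3D layer with a k×k×k kernel has k³ · C_in · filters weights plus one bias per filter. Assuming DoubleConv3D uses 3×3×3 kernels with biases (its definition isn't shown), C_in = 12 and filters*1 = 32 give exactly the 10400 reported for conv3d_111, which is consistent with the truncated input shape being (None, 16, 16, 16, 12). The helper name conv3d_param_count is hypothetical.

```python
def conv3d_param_count(kernel_size: int, in_channels: int, filters: int, use_bias: bool = True) -> int:
    """Hypothetical helper: trainable parameters of a Conv3D layer with a cubic kernel."""
    weights = kernel_size ** 3 * in_channels * filters  # one k*k*k*C_in kernel per filter
    biases = filters if use_bias else 0                  # one bias per output filter
    return weights + biases


# Assumed 3x3x3 kernel, 12 input channels, filters * 1 = 32 output channels.
print(conv3d_param_count(3, 12, 32))  # 10400
```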

And the Model Subclass looks like:

class UNet3D(Model):
    def __init__(self, **kwargs):
        super(UNet3D, self).__init__(name="UNet3D", **kwargs)        
        self.filters = 32

    def __call__(self, inputs):

        # L-hand side of UNet
        conv1 = DoubleConv3D(self.filters*1)(inputs)
        pool1 = MaxPooling3D()(conv1)
        ...

        # middle bottleneck
        conv5 = DoubleConv3D(self.filters*5)(pool4)

        # R-hand side of UNet
        rsdc6 = ConcatConv3D(self.filters*4)(conv5, conv4)
        conv6 = DoubleConv3D(self.filters*4)(rsdc6)
        ...

        # sigmoid activation
        outputs = Conv3D(1, (1, 1, 1), activation='sigmoid')(conv9)
        return outputs

unet3d = UNet3D()
unet3d.build(Input(shape=[None, None, None, 1]))
unet3d.summary()

However, instead of outputting the layers and number of parameters, the summary gives

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
Total params: 0
Trainable params: 0
Non-trainable params: 0
_________________________________________________________________

Initially, I believed the problem was that build was not called before summary, so I tried both calling build() explicitly and adding an InputLayer before the first convolution layer, as suggested in this related answer. Neither fixed the summary of the subclassed model.
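For comparison, here is a minimal sketch of the conventional subclassing pattern that build() and summary() expect: layers are created once in __init__ and assigned to attributes, and the forward pass goes in call() rather than __call__. The subclass above instead creates its layers inside an overridden __call__, so no layer is ever registered on the model, which is consistent with the empty summary. TinyUNet3D and its two-level architecture are hypothetical stand-ins for the full UNet (DoubleConv3D and ConcatConv3D aren't shown), and the sketch assumes TensorFlow 2.x with tf.keras.

```python
import tensorflow as tf
from tensorflow.keras import layers


class TinyUNet3D(tf.keras.Model):
    """Hypothetical two-level stand-in for the UNet, written in the subclassing style."""

    def __init__(self, filters=32, **kwargs):
        super().__init__(name="TinyUNet3D", **kwargs)
        # Layers are created once and stored as attributes, so Keras tracks their weights.
        self.enc = layers.Conv3D(filters, 3, padding="same", activation="relu")
        self.pool = layers.MaxPooling3D()
        self.mid = layers.Conv3D(filters * 2, 3, padding="same", activation="relu")
        self.up = layers.UpSampling3D()
        self.concat = layers.Concatenate()
        self.dec = layers.Conv3D(filters, 3, padding="same", activation="relu")
        self.head = layers.Conv3D(1, 1, activation="sigmoid")

    def call(self, inputs):
        # The forward pass lives in call(); Keras' own __call__ wraps it.
        skip = self.enc(inputs)
        x = self.mid(self.pool(skip))
        x = self.concat([self.up(x), skip])
        return self.head(self.dec(x))


model = TinyUNet3D()
# Running one dummy batch creates the weights. In tf.keras 2.x,
# model.build(input_shape=(None, 16, 16, 16, 12)) also works: build() takes a shape,
# not an Input tensor.
model(tf.zeros((1, 16, 16, 16, 12)))
model.summary()
```

The trade-off is that summary() for a subclassed model typically cannot show per-layer output shapes or connectivity the way a functional model does, since the graph only exists inside call().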

Answers (1)

Sentient

I found a solution to this Model subclassing problem by looking at the following example; credit goes to the author of that repo.

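One way to convert a Keras functional model into a Model subclass is to write a method that replicates the functional model's construction, i.e. Model(inputs=[inputs], outputs=[outputs]), and call it from __init__. Here that method is _build: it wires the layers onto an Input and then calls Model.__init__ with the resulting inputs and outputs, so Keras treats the instance as a functional model whose layers are tracked and listed by summary().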

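from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Conv3D, MaxPooling3D
# DoubleConv3D and ConcatConv3D are the custom blocks from the question (not shown).

class UNet3D(Model):
    def __init__(self, **kwargs):

        # Initialize model parameters.
        self.filters = 32
        ...

        # Initialize model (Model.__init__ is called inside _build).
        self._build(**kwargs)

    def _forward(self, inputs):
        # Wires the UNet onto symbolic inputs, like the functional API.
        # Deliberately not named __call__: overriding __call__ would replace
        # Keras' own __call__, which fit() and predict() go through.

        # L-hand side of UNet
        conv1 = DoubleConv3D(self.filters*1)(inputs)
        pool1 = MaxPooling3D()(conv1)
        ...

        # middle bottleneck
        conv5 = DoubleConv3D(self.filters*5)(pool4)

        # R-hand side of UNet
        rsdc6 = ConcatConv3D(self.filters*4)(conv5, conv4)
        conv6 = DoubleConv3D(self.filters*4)(rsdc6)
        ...

        # sigmoid activation
        outputs = Conv3D(1, (1, 1, 1), activation='sigmoid')(conv9)
        return outputs

    def _build(self, **kwargs):
        """
        Replicates Model(inputs=[inputs], outputs=[outputs]) of functional model.
        """
        # Use shape=[None, None, None, 12] if the spatial size is unknown;
        # Conv3D still needs the channel count (12 here) to be fixed.
        inputs  = Input(shape=[16, 16, 16, 12])
        outputs = self._forward(inputs)
        # Passing inputs/outputs makes Keras build this instance as a functional model.
        super(UNet3D, self).__init__(name="UNet3D", inputs=inputs, outputs=outputs, **kwargs)

unet3d = UNet3D()
unet3d.summary()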
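To check the _build approach end to end, here is a scaled-down, self-contained version of it. ShallowUNet3D and its layers are hypothetical stand-ins for the real UNet blocks, and the sketch assumes TensorFlow 2.x with tf.keras. Two deliberate differences from a literal copy: the graph-building method is named _forward instead of __call__, because fit() and predict() call the model through Keras' own __call__ with extra arguments such as training; and filters is passed as an argument rather than set as an attribute before Model.__init__ runs.

```python
import numpy as np
from tensorflow.keras import Input, Model, layers


class ShallowUNet3D(Model):
    """Hypothetical two-level UNet built with the functional-inside-subclass pattern."""

    def __init__(self, filters=32, **kwargs):
        # Model.__init__ is called inside _build, once the graph exists.
        self._build(filters, **kwargs)

    def _forward(self, inputs, filters):
        # Wires fresh layers onto symbolic inputs, exactly as in the functional API.
        skip = layers.Conv3D(filters, 3, padding="same", activation="relu")(inputs)
        x = layers.MaxPooling3D()(skip)
        x = layers.Conv3D(filters * 2, 3, padding="same", activation="relu")(x)
        x = layers.Concatenate()([layers.UpSampling3D()(x), skip])
        x = layers.Conv3D(filters, 3, padding="same", activation="relu")(x)
        return layers.Conv3D(1, 1, activation="sigmoid")(x)

    def _build(self, filters, **kwargs):
        inputs = Input(shape=(16, 16, 16, 12))
        outputs = self._forward(inputs, filters)
        # Passing inputs/outputs makes Keras treat this instance as a functional model.
        super().__init__(name="ShallowUNet3D", inputs=inputs, outputs=outputs, **kwargs)


unet = ShallowUNet3D()
unet.summary()

# fit() goes through Keras' own __call__, which this class does not override.
unet.compile(optimizer="adam", loss="binary_crossentropy")
x = np.zeros((2, 16, 16, 16, 12), dtype="float32")
y = np.zeros((2, 16, 16, 16, 1), dtype="float32")
unet.fit(x, y, epochs=1, verbose=0)
```

Because the instance ends up as a functional model, summary() here shows per-layer output shapes and connections, just like the original functional version.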
