bicarlsen
bicarlsen

Reputation: 1371

Undefined output shape of custom Keras layer

I am writing a custom Keras layer that flattens all except the last dimension of the input. However, when feeding the output of the layer into the next layer an error occurs because the output shape of the layer is None in all dimensions.

class FlattenLayers( Layer ):
    """
    Takes a nD tensor flattening the middle dimensions, ignoring the edge dimensions.
    [ n, x, y, z ] -> [ n, xy, z ]
    """

    def __init__( self, **kwargs ):
        super( FlattenLayers, self ).__init__( **kwargs )


    def build( self, input_shape ):
        super( FlattenLayers, self ).build( input_shape )


    def call( self, inputs ):
        input_shape = tf.shape( inputs )

        flat = tf.reshape(
            inputs,
            tf.stack( [ 
                -1, 
                K.prod( input_shape[ 1 : -1 ] ),
                input_shape[ -1 ]
            ] )
        )

        return flat


    def compute_output_shape( self, input_shape ):
        if not all( input_shape[ 1: ] ):
            raise ValueError( 'The shape of the input to "Flatten" '
                             'is not fully defined '
                             '(got ' + str( input_shape[ 1: ] ) + '). '
                             'Make sure to pass a complete "input_shape" '
                             'or "batch_input_shape" argument to the first '
                             'layer in your model.' )

        output_shape = ( 
            input_shape[ 0 ], 
            np.prod( input_shape[ 1 : -1 ] ), 
            input_shape[ -1 ] 
        )

        return output_shape

For example, when a Dense layer follows I receive the error ValueError: The last dimension of the inputs to Dense should be defined. Found None.

Upvotes: 1

Views: 2892

Answers (1)

Vlad
Vlad

Reputation: 8585

Why do you have tf.stack() in new shape? You want to flatten all dimensions except the last one; this is how you could do it:

import tensorflow as tf
from tensorflow.keras.layers import Layer
import numpy as np

class FlattenLayer(Layer):

    def __init__( self, **kwargs):
        super(FlattenLayer, self).__init__(**kwargs)

    def build( self, input_shape ):
        super(FlattenLayer, self).build(input_shape)

    def call( self, inputs):
        new_shape = self.compute_output_shape(tf.shape(inputs))
        return tf.reshape(inputs, new_shape)

    def compute_output_shape(self, input_shape):
        new_shape = (input_shape[0]*input_shape[1]*input_shape[2],
                     input_shape[3])
        return new_shape

Testing with a single data point (tf.__version__=='1.13.1'):

inputs = tf.keras.layers.Input(shape=(10, 10, 1))    
res = tf.keras.layers.Conv2D(filters=3, kernel_size=2)(inputs)
res = FlattenLayer()(res)
model = tf.keras.models.Model(inputs=inputs, outputs=res)

x_data = np.random.normal(size=(1, 10, 10, 1))
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    evaled = model.outputs[0].eval({model.inputs[0]:x_data})
    print(evaled.shape) # (81, 3)

Upvotes: 1

Related Questions