Reputation: 1371
I am writing a custom Keras layer that flattens all dimensions of the input except the last one (the batch dimension is kept). However, when I feed the output of this layer into the next layer, an error occurs because the output shape of the layer is `None` in all dimensions.
class FlattenLayers( Layer ):
    """
    Takes a nD tensor flattening the middle dimensions, ignoring the edge dimensions.

    [ n, x, y, z ] -> [ n, xy, z ]

    The statically known dimensions are propagated to the output tensor so
    that downstream layers which need a defined last dimension (e.g. Dense)
    can be built on top of this layer.
    """

    def __init__( self, **kwargs ):
        super( FlattenLayers, self ).__init__( **kwargs )

    def build( self, input_shape ):
        # No weights to create; only mark the layer as built.
        super( FlattenLayers, self ).build( input_shape )

    def call( self, inputs ):
        """
        Reshape `inputs` to [ batch, prod( middle dims ), last dim ].

        Dimensions that are known when the graph is built are passed to the
        reshape as plain Python ints; only unknown dimensions fall back to
        the runtime shape. Reshaping with a fully dynamic shape tensor would
        discard the static shape information.
        """
        static_shape = K.int_shape( inputs )  # tuple of int / None entries
        dynamic_shape = tf.shape( inputs )

        middle_dims = static_shape[ 1 : -1 ]
        middle_known = all( dim is not None for dim in middle_dims )
        if middle_known:
            # np.prod of an empty sequence is 1.0, hence the int() cast.
            middle_size = int( np.prod( middle_dims ) )
        else:
            middle_size = K.prod( dynamic_shape[ 1 : -1 ] )

        last_known = static_shape[ -1 ] is not None
        last_size = static_shape[ -1 ] if last_known else dynamic_shape[ -1 ]

        flat = tf.reshape( inputs, [ -1, middle_size, last_size ] )

        # Re-attach whatever is statically known, in case shape inference
        # could not recover it from the (possibly mixed) reshape target.
        flat.set_shape( [
            static_shape[ 0 ],
            middle_size if middle_known else None,
            static_shape[ -1 ]
        ] )
        return flat

    def compute_output_shape( self, input_shape ):
        """
        Return ( batch, prod( middle dims ), last dim ) for `input_shape`.

        Raises:
            ValueError: if any non-batch dimension of `input_shape` is
                undefined (or zero).
        """
        # Normalise tuples / lists / TensorShape to a list of int-or-None.
        dims = tf.TensorShape( input_shape ).as_list()
        if not all( dims[ 1: ] ):
            raise ValueError( 'The shape of the input to "Flatten" '
                              'is not fully defined '
                              '(got ' + str( tuple( dims[ 1: ] ) ) + '). '
                              'Make sure to pass a complete "input_shape" '
                              'or "batch_input_shape" argument to the first '
                              'layer in your model.' )
        output_shape = (
            dims[ 0 ],
            int( np.prod( dims[ 1 : -1 ] ) ),  # plain int, not a NumPy scalar
            dims[ -1 ]
        )
        return output_shape
For example, when a Dense layer follows, I receive the error: `ValueError: The last dimension of the inputs to Dense should be defined. Found None.`
Upvotes: 1
Views: 2892
Reputation: 8585
Why do you have `tf.stack()` in the new shape? You want to flatten all dimensions except the last one; this is how you could do it:
import tensorflow as tf
from tensorflow.keras.layers import Layer
import numpy as np
class FlattenLayer(Layer):
    """Flatten every dimension except the last one.

    [n, x, y, z] -> [n*x*y, z]

    Works for inputs of any rank >= 1 and keeps the last dimension
    statically defined whenever it is known at graph-construction time.
    """

    def __init__(self, **kwargs):
        super(FlattenLayer, self).__init__(**kwargs)

    def build(self, input_shape):
        # No weights to create; only mark the layer as built.
        super(FlattenLayer, self).build(input_shape)

    def call(self, inputs):
        """Reshape `inputs` to `[-1, last_dim]`."""
        last_dim = None
        if inputs.shape.ndims is not None:
            last_dim = inputs.shape.as_list()[-1]
        if last_dim is None:
            # Only fall back to the runtime shape when the static one is
            # unknown; a dynamic value here erases the static last dimension.
            last_dim = tf.shape(inputs)[-1]
        return tf.reshape(inputs, [-1, last_dim])

    def compute_output_shape(self, input_shape):
        """Return `(prod(leading dims), last_dim)` for a static `input_shape`.

        The first entry is `None` when any leading dimension (typically the
        batch size) is unknown.
        """
        dims = tf.TensorShape(input_shape).as_list()
        leading = dims[:-1]
        if any(dim is None for dim in leading):
            flat_size = None
        else:
            flat_size = int(np.prod(leading))
        return (flat_size, dims[-1])
Testing with a single data point (`tf.__version__ == '1.13.1'`):
# Small model: a 10x10x1 input goes through a 2x2 convolution with 3 filters,
# then through the custom flattening layer.
model_input = tf.keras.layers.Input(shape=(10, 10, 1))
conv_out = tf.keras.layers.Conv2D(filters=3, kernel_size=2)(model_input)
flattened = FlattenLayer()(conv_out)
model = tf.keras.models.Model(inputs=model_input, outputs=flattened)

# One random sample with the same layout as the model input.
sample = np.random.normal(size=(1, 10, 10, 1))

with tf.Session() as sess:
    # The convolution weights must be initialised before evaluation.
    sess.run(tf.global_variables_initializer())
    feed = {model.inputs[0]: sample}
    result = model.outputs[0].eval(feed)
    print(result.shape)  # (81, 3)
Upvotes: 1