I am trying to build label-dependent convolutional filters in Keras/TensorFlow, so the convolutional filter(s) differ for each example in the batch, depending on that example's label.
# function used for tf.map_fn
def single_conv(tupl):
    x, kernel = tupl
    return tf.nn.conv2d(x, kernel, strides=(1, 1, 1, 1), padding='SAME')
# first dimension is None (batch size)
input_img = tf.keras.layers.Input(shape=(28,28,1), dtype=tf.float32)
label = tf.keras.layers.Input(shape=(10,), dtype=tf.float32)
# the network is learning a mapping for the label
label_encoded = tf.keras.layers.Dense(9, activation='relu')(label)
# turn mapping into conv filter
kernels = tf.keras.layers.Reshape((3,3,1,1))(label_encoded)
# class dependent filter(s)
conditional_conv = tf.map_fn(single_conv, (tf.expand_dims(input_img, 1), kernels), fn_output_signature=tf.float32)
When I run this snippet, the last line raises TypeError: 'Tensor' object cannot be interpreted as an integer. Because that line calls tf.map_fn, I checked its documentation: tf.map_fn raises a TypeError if fn (single_conv here) is not callable, or if the structure of fn's output does not match fn_output_signature (https://www.tensorflow.org/api_docs/python/tf/map_fn#raises).
However, I'm still not sure why this is happening. As far as I can tell, neither of those two causes applies here.
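One way to narrow this down is to run the same tf.map_fn call on concrete (eager) tensors instead of Keras Input placeholders. In the sketch below, random tensors stand in for the images and the label-derived kernels (an assumption for illustration only). If it runs, single_conv and fn_output_signature are consistent with each other, which suggests the problem lies in applying tf.map_fn directly to symbolic Keras inputs rather than in either documented cause of the TypeError.

```python
import tensorflow as tf

def single_conv(tupl):
    # x: one image with a leading batch axis of 1, shape (1, 28, 28, 1)
    # kernel: that example's own filter, shape (3, 3, 1, 1)
    x, kernel = tupl
    return tf.nn.conv2d(x, kernel, strides=(1, 1, 1, 1), padding='SAME')

batch = 4
# random stand-ins for the real inputs (assumption for illustration)
images = tf.random.normal((batch, 28, 28, 1))
kernels = tf.random.normal((batch, 3, 3, 1, 1))

# eager tensors: map_fn iterates over the first axis of both elements together
out = tf.map_fn(single_conv, (tf.expand_dims(images, 1), kernels),
                fn_output_signature=tf.float32)
print(out.shape)  # (4, 1, 28, 28, 1)

# each slice should match a plain conv2d using that example's own kernel
reference = tf.nn.conv2d(images[0:1], kernels[0], strides=(1, 1, 1, 1), padding='SAME')
print(float(tf.reduce_max(tf.abs(out[0] - reference))))
```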
What worked for me was creating a custom layer and calling tf.map_fn inside its call method.
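class custom_layer(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super(custom_layer, self).__init__(**kwargs)

    def call(self, inputs):
        # inputs is the (images, kernels) tuple; map_fn applies single_conv to each example
        return tf.map_fn(single_conv, inputs, fn_output_signature=tf.float32)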
In your case, apply the layer to the tuple of expanded images and kernels:
conditional_conv = custom_layer()((tf.expand_dims(input_img, 1), kernels))
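A fuller sketch of the same approach, assuming TF 2.x with tf.keras. The layer name LabelConditionalConv is hypothetical; unlike the answer's version, it does the expand_dims, the map_fn and a final tf.squeeze all inside call, so no raw tf op is applied directly to a symbolic Keras input. The squeeze is an addition that removes the singleton axis introduced by expand_dims, giving an output of shape (batch, 28, 28, 1). Random images and one-hot labels stand in for real data.

```python
import tensorflow as tf

def single_conv(tupl):
    # x: (1, 28, 28, 1), kernel: (3, 3, 1, 1)
    x, kernel = tupl
    return tf.nn.conv2d(x, kernel, strides=(1, 1, 1, 1), padding='SAME')

class LabelConditionalConv(tf.keras.layers.Layer):
    """Hypothetical layer: convolves each image with its own kernel."""

    def call(self, inputs):
        images, kernels = inputs  # (batch, 28, 28, 1), (batch, 3, 3, 1, 1)
        out = tf.map_fn(single_conv, (tf.expand_dims(images, 1), kernels),
                        fn_output_signature=tf.float32)
        # remove the singleton axis added by expand_dims -> (batch, 28, 28, 1)
        return tf.squeeze(out, axis=1)

input_img = tf.keras.layers.Input(shape=(28, 28, 1), dtype=tf.float32)
label = tf.keras.layers.Input(shape=(10,), dtype=tf.float32)
label_encoded = tf.keras.layers.Dense(9, activation='relu')(label)
kernels = tf.keras.layers.Reshape((3, 3, 1, 1))(label_encoded)
conditional_conv = LabelConditionalConv()((input_img, kernels))
model = tf.keras.Model(inputs=[input_img, label], outputs=conditional_conv)

# random images and one-hot labels as stand-ins for real data
images = tf.random.uniform((4, 28, 28, 1))
labels = tf.one_hot([0, 3, 5, 9], depth=10)
out = model([images, labels])
print(out.shape)  # (4, 28, 28, 1)
```

Keeping every raw tf.* call inside call() also sidesteps the fact that newer Keras releases (Keras 3) refuse TensorFlow functions applied directly to symbolic KerasTensors, so the tf.expand_dims outside the layer in the line above may fail there.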