I want to have a scalar input for my keras model, but couldn't find a good way to do it.
I can specify an input like this: a = tf.keras.Input(shape=(), name="a"). However, Keras automatically adds a batch dimension to a, as the example below shows:
import tensorflow as tf
a = tf.keras.Input(shape=(), name="a")
print(a.shape) # the output is (None,)
I just want a to be a scalar (i.e. to have shape () instead of (None,)). How can I do it?
I have found a workaround:
a = tf.keras.Input(shape=(), name="a")
a_scalar = tf.squeeze(a, axis=0)
print(a_scalar.shape) # the output is ()
But this is just way too ugly and stupid.
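If what is actually needed is one scalar per example (the usual situation, since Keras models are fed batches), an alternative to squeezing is to keep the batch axis and broadcast the scalar where it is used. The sketch below assumes TensorFlow 2.x; the second input x, its width of 3, and the sample values are hypothetical, added only for illustration. It uses only Keras layers, because under Keras 3 plain TensorFlow ops such as tf.squeeze can no longer be applied directly to symbolic Keras inputs.

```python
import numpy as np
import tensorflow as tf

# Hypothetical feature input with 3 features per example.
x = tf.keras.Input(shape=(3,), name="x")
# One scalar per example: Keras keeps the batch axis, so the symbolic shape is (None,).
a = tf.keras.Input(shape=(), name="a")

# Reshape a from (batch,) to (batch, 1) so Multiply can broadcast it across x's features.
a_col = tf.keras.layers.Reshape((1,))(a)
scaled = tf.keras.layers.Multiply()([x, a_col])

model = tf.keras.Model(inputs=[x, a], outputs=scaled)

# A batch of 2 examples, each with its own scalar.
x_val = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32)
a_val = np.array([2.0, 10.0], dtype=np.float32)

result = np.asarray(model([x_val, a_val]))
print(result)
```

Each row of x_val is multiplied by its own entry of a_val, so the scalar behaves like a per-example scalar even though its symbolic shape is (None,).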
You can pass batch_shape instead of shape:
tf.keras.Input(batch_shape=(None,))
Result:
<KerasTensor: shape=(None,)
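To make the effect of batch_shape concrete, here is a small sketch assuming TensorFlow 2.x (the input names are made up for illustration). batch_shape lists the full shape including the batch axis, so batch_shape=(None,) describes the same symbolic shape that shape=() produces, while a concrete value fixes the batch size.

```python
import tensorflow as tf

# batch_shape spells out the batch axis explicitly (None = any batch size).
a_batch = tf.keras.Input(batch_shape=(None,), name="a_batch")
# shape=() lets Keras prepend the batch axis itself.
a_shape = tf.keras.Input(shape=(), name="a_shape")
print(tuple(a_batch.shape), tuple(a_shape.shape))

# A fixed batch size can be given the same way, e.g. a batch of exactly one scalar.
a_one = tf.keras.Input(batch_shape=(1,), name="a_one")
print(tuple(a_one.shape))
```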