Hongtao Yang

Reputation: 411

How to pass scalar inputs to a tf.keras model?

I want my Keras model to take a scalar input, but I couldn't find a good way to do it.

I can declare an input like this: a = tf.keras.Input(shape=(), name="a"). However, Keras automatically adds a batch dimension to a, as the example below shows:

import tensorflow as tf

a = tf.keras.Input(shape=(), name="a")
print(a.shape)  # the output is (None,)

I just want a to be a scalar (i.e., to have shape () instead of (None,)). How can I do that?

Update

I have found a workaround:

a = tf.keras.Input(shape=(), name="a")
# Drop the batch axis; this only runs when the batch size is exactly 1
a_scalar = tf.squeeze(a, axis=0)
print(a_scalar.shape)  # the output is ()

But this is just way too ugly and stupid.

Upvotes: 2

Views: 1491

Answers (1)

MateusR

Reputation: 101

tf.keras.Input(batch_shape=(None,))

Result:

<KerasTensor: shape=(None,)
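
For what it's worth, the leading None is the batch axis, so shape=() already means "one scalar per sample". The tf.squeeze(a, axis=0) workaround from the question only runs when the batch size is exactly 1. Below is a minimal sketch of one alternative, assuming a per-sample scalar is acceptable: keep the (None,) input and broadcast it against the other features inside the model. ScaleByScalar, the second input x, and the feature width of 3 are hypothetical names and values picked for illustration; they are not part of the question.

```python
import numpy as np
import tensorflow as tf


class ScaleByScalar(tf.keras.layers.Layer):
    """Hypothetical layer: multiplies each sample's features by that sample's scalar."""

    def call(self, inputs):
        x, a = inputs                     # x: (batch, 3), a: (batch,)
        return x * tf.expand_dims(a, -1)  # a becomes (batch, 1) and broadcasts over features


x = tf.keras.Input(shape=(3,), name="x")  # assumed feature input, not from the question
a = tf.keras.Input(shape=(), name="a")    # one scalar per sample, shape (None,)
model = tf.keras.Model(inputs=[x, a], outputs=ScaleByScalar()([x, a]))

x_val = np.ones((2, 3), dtype="float32")
a_val = np.array([2.0, 3.0], dtype="float32")  # batch of two scalars
out = np.asarray(model([x_val, a_val]))
print(out.shape)  # (2, 3)
```

Each row of x_val is scaled by its own entry of a_val, so the model works for any batch size and never needs to squeeze the batch axis.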

Upvotes: 2
