Marc Felix

Reputation: 421

Tensorflow splits batches for no reason

I am having an issue where I input a batch of shape [128, 3] into a TensorFlow model (together with many other inputs), and when I define the output of the model to just be tf.shape(input), the output is:

[32, 3, 32, 3, 32, 3, 32, 3]

I already asked a similar question once, with a working code example, but nobody replied. In case anyone wants to see it, this is the link: Link. So this time I will not add any code, and hope that someone can answer the question:

How can this happen at all, no matter what the model is? Why does TensorFlow just change the shape of the input?

This behavior only occurs for batch sizes larger than 32; otherwise the output is correctly [batch_size, 3]. When I just return the input from the model and print np.shape(output), it is again [128, 3].

I am totally confused as to how anything like this can happen at all. I hope someone can explain it to me.

Upvotes: 1

Views: 33

Answers (1)

geometrikal

Reputation: 3294

Don't specify batch_size in the input layer

e.g. use

feature_1 = layers.Input(shape=(2,), dtype=tf.float32)

This will give -1 (shown as None in Keras) for the batch dimension, which means the model can accept batches of any size.
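A minimal sketch of this, using a trivial identity model and the question's feature dimension of 3 (the model name and layout here are illustrative, not from the original post):

```python
import tensorflow as tf
from tensorflow.keras import layers

# No batch_size argument on the Input layer, so the batch
# dimension is left unspecified.
feature_1 = layers.Input(shape=(3,), dtype=tf.float32)

# Trivial identity model just to inspect the resulting input shape.
model = tf.keras.Model(inputs=feature_1, outputs=feature_1)

# The batch dimension shows up as None, i.e. any batch size is accepted.
print(model.input_shape)  # (None, 3)
```

With the batch dimension left as None, passing 128 samples at once will not force the model to reinterpret them in fixed-size chunks.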

Upvotes: 1
