Hanhan Li

Reputation: 465

Why do we need to worry about the batch dimension when specifying a model in Tensorflow?

It seems a bit cumbersome to take into account the batch dimension for every layer in a neural network. Why don't we have some functionality in Tensorflow that can just set the batch size for an entire model?

Upvotes: 4

Views: 3569

Answers (2)

Adrivg

Reputation: 11

I think what Hanhan Li wants to say concerns custom layers where you need to do some reshaping, like this one:

Suppose that the input shape to the model is (batch_size, frames) and you want to implement a preprocessing layer that removes some frames based on a condition. The layer to do this could be something like this:

import tensorflow as tf

def process_file(data):
    # Collect the surviving frames in a dynamically sized TensorArray.
    new_frames = tf.TensorArray(data.dtype, size=0, dynamic_size=True)
    i = 0
    for frame in data:
        # Keep only frames that contain at least one non-zero value.
        if tf.math.count_nonzero(frame) != 0:
            # TensorArray.write returns the updated array and must be reassigned.
            new_frames = new_frames.write(i, frame)
            i += 1
    return new_frames.stack()


class Preprocessing(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super(Preprocessing, self).__init__(**kwargs)

    def call(self, inputs, training=True, **kwargs):
        return process_file(inputs)

The problem is that once you fit the model in batches, the process_file() inside Preprocessing() receives the whole batch, so the function breaks: because of the added batch dimension, for frame in data now iterates over the examples in the batch instead of over frames.
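The mismatch can be illustrated without TensorFlow at all. In the NumPy sketch below (shapes are hypothetical: a batch of 3 examples with 4 frames each), iterating the batched array walks the first axis, so each loop step yields an entire example rather than a single frame:

```python
import numpy as np

# Hypothetical batch: 3 examples, 4 frames each -> shape (3, 4).
batch = np.arange(12).reshape(3, 4)

# Iterating walks the FIRST axis: each item is one whole example
# of shape (4,), not one frame at a time.
rows = [row.shape for row in batch]
print(rows)  # [(4,), (4,), (4,)]
```

This is exactly what happens to process_file(): with the batch dimension in front, the per-frame loop becomes a per-example loop.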

So I think the question is: can you add something to the layer subclass so that the batch dimension is not passed to the preprocessing function, and the batches are stacked back together afterwards?

I'm sorry, I don't have the answer to that. I'm struggling with the same issue -.-

Upvotes: 1

nessuno

Reputation: 27052

In TensorFlow you do not have to take the batch size into account.

The MNIST Tutorial explains how TensorFlow handles batches of any size.

Quoting the tutorial:

x = tf.placeholder(tf.float32, shape=[None, 784])
y_ = tf.placeholder(tf.float32, shape=[None, 10])

The input images x will consist of a 2d tensor of floating point numbers. Here we assign it a shape of [None, 784], where 784 is the dimensionality of a single flattened MNIST image, and None indicates that the first dimension, corresponding to the batch size, can be of any size.
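The reason a None batch dimension works can be sketched without TensorFlow: a weight matrix of shape (784, 10) applies to any batch size, because matrix multiplication only constrains the trailing dimension. The shapes below mirror the tutorial; the variable names are illustrative:

```python
import numpy as np

# Weights map a flattened 784-pixel image to 10 class scores.
W = np.zeros((784, 10))

# The same matmul works for ANY leading (batch) dimension,
# which is why the graph can declare it as None.
for batch_size in (1, 32, 100):
    x = np.zeros((batch_size, 784))
    y = x @ W
    assert y.shape == (batch_size, 10)
print("ok")  # ok
```

None is simply a promise that nothing in the graph depends on that first dimension, so the same model runs on single examples and full batches alike.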

Upvotes: 3
