Reputation: 465
It seems a bit cumbersome to take the batch dimension into account for every layer in a neural network. Why don't we have some functionality in TensorFlow that can just set the batch size for an entire model?
Upvotes: 4
Views: 3569
Reputation: 11
I think what Hanhin Li wants to say concerns custom layers where you need to do some reshaping, like this one:
Suppose that the input shape to the model is (batch_size, frames) and you want to implement a preprocessing layer that removes some frames based on a condition. The layer to do this could look something like this:
import tensorflow as tf

def process_file(data):
    # Collect the frames to keep in a dynamically sized TensorArray.
    new_frames = tf.TensorArray(data.dtype, size=0, dynamic_size=True)
    i = 0
    for frame in data:
        # Keep only frames that are not entirely zero.
        if tf.math.count_nonzero(frame) != 0:
            # write() returns a new TensorArray; it must be reassigned.
            new_frames = new_frames.write(i, frame)
            i += 1
    return new_frames.stack()
class Preprocessing(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super(Preprocessing, self).__init__(**kwargs)

    def call(self, inputs, training=True, **kwargs):
        return process_file(inputs)
The problem is that once you fit the model with batches, the process_file() called from Preprocessing() receives the batched tensor, so the whole function is messed up: for frame in data now iterates over the batches instead of over the frames (because of the added batch dimension).
So I think the question is: can you add something to the layer subclass so that the batch dimension is not passed to the preprocessing function, and the results are stacked back into batches afterwards?
I'm sorry, I don't have the answer to that. I'm struggling with the same issue -.-
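That said, one direction that might work (a minimal, untested sketch on my side, assuming TF 2.3+ where tf.map_fn accepts a fn_output_signature for ragged outputs) is to map process_file over the batch dimension yourself, so the function only ever sees a single example:

import tensorflow as tf

class Preprocessing(tf.keras.layers.Layer):
    def call(self, inputs, **kwargs):
        # inputs has shape (batch_size, frames); map process_file over
        # the first (batch) dimension so it only sees one example at a
        # time. Each example may keep a different number of frames, so
        # the per-example results are stacked into a RaggedTensor.
        return tf.map_fn(
            process_file,
            inputs,
            fn_output_signature=tf.RaggedTensorSpec(
                shape=[None], dtype=inputs.dtype))

The RaggedTensorSpec is needed because each example can end up with a different number of surviving frames, which also means downstream layers would have to accept a ragged input.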
Upvotes: 1
Reputation: 27052
In TensorFlow you do not have to take the batch size into account.
The MNIST tutorial explains how TensorFlow handles batches of any size.
Quoting the tutorial:
x = tf.placeholder(tf.float32, shape=[None, 784])
y_ = tf.placeholder(tf.float32, shape=[None, 10])
The input images x will consist of a 2d tensor of floating point numbers. Here we assign it a shape of [None, 784], where 784 is the dimensionality of a single flattened MNIST image, and None indicates that the first dimension, corresponding to the batch size, can be of any size.
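To make the quoted point concrete, here is a minimal sketch (using the TF1-style graph API; under TensorFlow 2.x the same code runs via tf.compat.v1) that feeds two different batch sizes through the same placeholder:

import numpy as np
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

# The first dimension is None, so any batch size is accepted.
x = tf.placeholder(tf.float32, shape=[None, 784])
W = tf.Variable(tf.zeros([784, 10]))
y = tf.matmul(x, W)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(y, feed_dict={x: np.zeros((32, 784))}).shape)   # (32, 10)
    print(sess.run(y, feed_dict={x: np.zeros((100, 784))}).shape)  # (100, 10)

The graph is built once, and both feeds succeed without any per-batch changes, because the batch dimension is left unspecified.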
Upvotes: 3