codedturtle

Reputation: 65

Preserving unknown batch dimension for custom static tensors in Tensorflow

Some notes: I'm using tensorflow 2.3.0, python 3.8.2, and numpy 1.18.5 (not sure if that one matters though)

I'm writing a custom layer that stores a non-trainable tensor N of shape (a, b) internally, where a, b are known values (this tensor is created during init). When called on an input tensor, it flattens the input tensor, flattens its stored tensor, and concatenates the two together. Unfortunately, I can't seem to figure out how to preserve the unknown batch dimension during this concatenation. Here's minimal code:

import tensorflow as tf
from tensorflow.keras.layers import Layer, Flatten

class CustomLayer(Layer):
   def __init__(self, N):                     # N is a tensor of shape (a, b), where a, b > 1
      super(CustomLayer, self).__init__()
      self.N = self.add_weight(name="N", shape=N.shape, trainable=False, initializer=lambda *args, **kwargs: N)

      # correct me if I'm wrong in using this initializer approach, but for some reason, when I
      # just do self.N = N, this variable would disappear when I saved and loaded the model

   def build(self, input_shape):
      pass                                    # my reasoning is that all the necessary stuff is handled in init

   def call(self, input_tensor):
      input_flattened = Flatten()(input_tensor)
      N_flattened = Flatten()(self.N)
      return tf.concat((input_flattened, N_flattened), axis=-1)

The first problem I noticed was that Flatten()(self.N) would return a tensor with the same shape (a, b) as the original self.N, and as a result, the returned value would have a shape of (a, num_input_tensor_values+b). My reasoning for this was that the first dimension, a, was treated as the batch size. I modified the call function:

   def call(self, input_tensor):
      input_flattened = Flatten()(input_tensor)
      N = tf.expand_dims(self.N, axis=0)       # N would now be shape (1, a, b)
      N_flattened = Flatten()(N)
      return tf.concat((input_flattened, N_flattened), axis=-1)

This returns a tensor with shape (1, num_input_vals + a*b), which is great, but now the batch dimension is permanently 1. I realized this when I started training a model with this layer and it would only work for a batch size of 1. It's also really apparent in the model summary: if I put this layer after an input and add some other layers afterwards, the first dimension of the output tensors goes None, 1, 1, 1, 1.... Is there a way to store this internal tensor and use it in call while preserving the variable batch size? (For example, with a batch size of 4, a copy of the same flattened N would be concatenated onto the end of each of the 4 flattened input tensors.)
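For reference, the batch-dimension behavior of Flatten can be checked in isolation (a standalone sketch; the 3x4 shape is arbitrary):

```python
import tensorflow as tf
from tensorflow.keras.layers import Flatten

# Flatten always treats axis 0 as the batch axis and leaves it alone,
# which is why Flatten()(self.N) comes back with its shape (a, b) unchanged.
n = tf.ones((3, 4))                      # stands in for N with a=3, b=4
print(Flatten()(n).shape)                # (3, 4): a was treated as the batch size

n_expanded = tf.expand_dims(n, axis=0)   # shape (1, 3, 4)
print(Flatten()(n_expanded).shape)       # (1, 12): flattened, but batch stuck at 1
```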

Upvotes: 5

Views: 1577

Answers (1)

Andrea Angeli

Reputation: 745

You need as many flattened N vectors as there are samples in your input, because you are concatenating one to every sample. Think of it as pairing up rows and concatenating them: if you have only one N vector, only one pair can be formed. To solve this, use tf.tile() to repeat N as many times as there are samples in your batch.
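If tf.tile() is unfamiliar: it repeats a tensor along each axis according to a multiples argument, one entry per axis (a quick standalone sketch):

```python
import tensorflow as tf

# tf.tile(t, multiples) repeats t multiples[i] times along axis i.
row = tf.constant([[1, 2, 3]])           # shape (1, 3), like one flattened N
tiled = tf.tile(row, [4, 1])             # repeat 4x along axis 0, keep axis 1
print(tiled.shape)                       # (4, 3): one copy of the row per sample
```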

Example:

   def call(self, input_tensor):
      input_flattened = Flatten()(input_tensor) # input_flattened shape: (None, ..)
      N = tf.expand_dims(self.N, axis=0)        # N shape: (1, a, b)
      N_flattened = Flatten()(N)                # N_flattened shape: (1, a*b)
      N_tiled = tf.tile(N_flattened, [tf.shape(input_tensor)[0], 1]) # repeat along the batch dim once per sample; leave the second dim alone
      return tf.concat((input_flattened, N_tiled), axis=-1)
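Putting it together, the fixed layer can be sanity-checked against a couple of batch sizes (standalone sketch; the a=2, b=3 and 5-feature input values are arbitrary):

```python
import tensorflow as tf
from tensorflow.keras.layers import Layer, Flatten

class CustomLayer(Layer):
    def __init__(self, N):
        super(CustomLayer, self).__init__()
        self.N = self.add_weight(name="N", shape=N.shape, trainable=False,
                                 initializer=lambda *args, **kwargs: N)

    def call(self, input_tensor):
        input_flattened = Flatten()(input_tensor)
        N_flattened = Flatten()(tf.expand_dims(self.N, axis=0))        # (1, a*b)
        N_tiled = tf.tile(N_flattened, [tf.shape(input_tensor)[0], 1]) # (batch, a*b)
        return tf.concat((input_flattened, N_tiled), axis=-1)

layer = CustomLayer(tf.ones((2, 3)))      # a=2, b=3, so a*b = 6
for batch in (1, 4):
    out = layer(tf.ones((batch, 5)))      # 5 input features per sample
    print(out.shape)                      # (batch, 5 + 6) = (batch, 11)
```

Because tf.shape(input_tensor)[0] is read at call time rather than trace time, the same layer now works for any batch size, and the symbolic batch dimension stays None in the model summary.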

Upvotes: 4
