Reputation: 65
Some notes: I'm using TensorFlow 2.3.0, Python 3.8.2, and NumPy 1.18.5 (not sure if that last one matters, though).
I'm writing a custom layer that stores a non-trainable tensor N of shape (a, b) internally, where a, b are known values (this tensor is created during init). When called on an input tensor, it flattens the input tensor, flattens its stored tensor, and concatenates the two together. Unfortunately, I can't seem to figure out how to preserve the unknown batch dimension during this concatenation. Here's minimal code:
import tensorflow as tf
from tensorflow.keras.layers import Layer, Flatten

class CustomLayer(Layer):
    def __init__(self, N):  # N is a tensor of shape (a, b), where a, b > 1
        super(CustomLayer, self).__init__()
        self.N = self.add_weight(name="N", shape=N.shape, trainable=False,
                                 initializer=lambda *args, **kwargs: N)
        # correct me if I'm wrong in using this initializer approach, but for some reason,
        # when I just do self.N = N, this variable would disappear when I saved and loaded
        # the model

    def build(self, input_shape):
        pass  # my reasoning is that all the necessary stuff is handled in init

    def call(self, input_tensor):
        input_flattened = Flatten()(input_tensor)
        N_flattened = Flatten()(self.N)
        return tf.concat((input_flattened, N_flattened), axis=-1)
The first problem I noticed was that Flatten()(self.N) would return a tensor with the same shape (a, b) as the original self.N, and as a result, the returned value would have a shape of (a, num_input_tensor_values + b). My reasoning for this was that the first dimension, a, was treated as the batch size. I modified the call function:
def call(self, input_tensor):
    input_flattened = Flatten()(input_tensor)
    N = tf.expand_dims(self.N, axis=0)  # N would now be shape (1, a, b)
    N_flattened = Flatten()(N)
    return tf.concat((input_flattened, N_flattened), axis=-1)
This would return a tensor with shape (1, num_input_vals + a*b), which is great, but now the batch dimension is permanently 1. I realized this when I started training a model with this layer and it would only work with a batch size of 1. This is also really apparent in the model summary - if I put this layer after an input and add some other layers afterwards, the first dimension of the output tensors goes None, 1, 1, 1, 1...
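A quick eager-mode check (with illustrative values a=3, b=4, not my real dimensions) shows both behaviours:

import tensorflow as tf
from tensorflow.keras.layers import Flatten

N = tf.reshape(tf.range(12.0), (3, 4))        # stand-in for the stored tensor, a=3, b=4
print(Flatten()(N).shape)                     # (3, 4)  - first dim treated as the batch
print(Flatten()(tf.expand_dims(N, 0)).shape)  # (1, 12) - batch dim now fixed at 1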
Is there a way to store this internal tensor and use it in call while preserving the variable batch size? (For example, with a batch size of 4, a copy of the same flattened N would be concatenated onto the end of each of the 4 flattened input tensors.)
Upvotes: 5
Views: 1577
Reputation: 745
You have to have as many flattened N vectors as you have samples in your input, because you are concatenating to every sample. Think of it like pairing up rows and concatenating them: if you have only one N vector, then only one pair can be concatenated. To solve this, you should use tf.tile() to repeat N as many times as there are samples in your batch.
Example:
def call(self, input_tensor):
    input_flattened = Flatten()(input_tensor)  # input_flattened shape: (None, ..)
    N = tf.expand_dims(self.N, axis=0)         # N shape: (1, a, b)
    N_flattened = Flatten()(N)                 # N_flattened shape: (1, a*b)
    # repeat along the first dim as many times as there are samples, leave the second dim alone
    N_tiled = tf.tile(N_flattened, [tf.shape(input_tensor)[0], 1])
    return tf.concat((input_flattened, N_tiled), axis=-1)
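For reference, a minimal sketch of the full layer with this call (the 3x4 N and the (5, 2) input shape below are illustrative assumptions, not values from the question), showing that the same layer now works for any batch size:

import tensorflow as tf
from tensorflow.keras.layers import Layer, Flatten

class CustomLayer(Layer):
    def __init__(self, N):
        super(CustomLayer, self).__init__()
        self.N = self.add_weight(name="N", shape=N.shape, trainable=False,
                                 initializer=lambda *args, **kwargs: N)

    def call(self, input_tensor):
        input_flattened = Flatten()(input_tensor)            # (None, num_input_vals)
        N_flattened = Flatten()(tf.expand_dims(self.N, 0))   # (1, a*b)
        N_tiled = tf.tile(N_flattened, [tf.shape(input_tensor)[0], 1])
        return tf.concat((input_flattened, N_tiled), axis=-1)

N = tf.reshape(tf.range(12.0), (3, 4))   # a=3, b=4, chosen only for illustration
layer = CustomLayer(N)
print(layer(tf.ones((1, 5, 2))).shape)   # (1, 22) -> 5*2 input values + 3*4 tiled values
print(layer(tf.ones((4, 5, 2))).shape)   # (4, 22) -> the same layer handles a batch of 4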
Upvotes: 4