Reputation: 2087
I have a tensor xx with shape:
>>> xx.shape
TensorShape([32, 32, 256])
How can I add a leading None dimension to get:
>>> xx.shape
TensorShape([None, 32, 32, 256])
I have seen many answers here, but they all relate to TF 1.x.
What is the straightforward way to do this in TF 2.0?
Upvotes: 10
Views: 11349
Reputation: 11
I do not think it is possible to simply add a "None" dimension.
However, assuming you are trying to prepend a variable-size batch dimension to your tensor, you can use tf.repeat to copy your tensor along a new leading axis, taking the batch size from tf.shape() of another tensor.
import tensorflow as tf
y = tf.keras.layers.Input(shape=(32, 32, 3))  # Shape: [None, 32, 32, 3]
...
batch_size = tf.shape(y)[0]  # Scalar tensor; its value is only known at run time
xx = tf.ones(shape=(32, 32, 256))  # Shape: [32, 32, 256]
xx = tf.expand_dims(xx, 0)  # Shape: [1, 32, 32, 256]
xx = tf.repeat(xx, repeats=batch_size, axis=0)  # Shape: [None, 32, 32, 256]
This will likely be more useful than hard-coding the first dimension to None, since what you probably want is to copy the tensor along that first dimension according to the batch size.
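A self-contained sketch of the same idea, assuming the goal is a graph in which the batch size is only known at run time. Here a tf.function with an input_signature stands in for the Keras Input layer; the names template and tile_to_batch are hypothetical. The None in the signature leaves the leading dimension unknown while tracing, and tf.repeat copies the template once per example when the function is called.

```python
import tensorflow as tf

# Hypothetical per-example tensor with the shape from the question.
template = tf.ones(shape=(32, 32, 256))

# The None in the signature makes the batch dimension unknown at trace time.
@tf.function(input_signature=[tf.TensorSpec(shape=[None, 32, 32, 3], dtype=tf.float32)])
def tile_to_batch(images):
    # tf.shape() gives the runtime shape, so this is a scalar int32 tensor.
    batch_size = tf.shape(images)[0]
    xx = tf.expand_dims(template, 0)                  # [1, 32, 32, 256]
    return tf.repeat(xx, repeats=batch_size, axis=0)  # [batch_size, 32, 32, 256]

out = tile_to_batch(tf.zeros([4, 32, 32, 3]))
print(out.shape)
```

Because the signature fixes only the trailing dimensions, the same function accepts any batch size without being retraced.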
Upvotes: 1
Reputation: 1905
In TF2 you can use tf.expand_dims:
>>> xx = tf.expand_dims(xx, 0)
>>> xx.shape
TensorShape([1, 32, 32, 256])
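Note that tf.expand_dims inserts a concrete axis of size 1. A None dimension only shows up in a tensor's static shape when that shape is symbolic, for example while a tf.function is being traced. A minimal sketch, assuming a tf.function with an input_signature is an acceptable context; identity_batch and traced_shapes are hypothetical names used only for illustration.

```python
import tensorflow as tf

xx = tf.ones((32, 32, 256))

# Eager mode: expand_dims adds a leading axis of size 1.
batched = tf.expand_dims(xx, 0)
print(batched.shape)

traced_shapes = []

# A None in the signature leaves the leading dimension unknown at trace time.
@tf.function(input_signature=[tf.TensorSpec(shape=[None, 32, 32, 256], dtype=tf.float32)])
def identity_batch(batch):
    # Python side effect: runs during tracing, when the shape is symbolic.
    traced_shapes.append(batch.shape)
    return tf.identity(batch)

result = identity_batch(batched)
print(traced_shapes[0])
```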
Upvotes: 2
Reputation:
You can use either None or NumPy's np.newaxis to create the new dimension.
General tip: you can use None in place of np.newaxis; they are in fact the same object.
The code below demonstrates both options.
# Colab-only magic for older Colab runtimes; it is not valid plain Python:
# %tensorflow_version 2.x
import tensorflow as tf
print(tf.__version__)
# TensorFlow and tf.keras
from tensorflow import keras
# Helper libraries
import numpy as np
#### Import the Fashion MNIST dataset
fashion_mnist = keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
# Original dimension
print(train_images.shape)
# Add a dimension using None
train_images1 = train_images[None, :, :, :]
print(train_images1.shape)
# Add a dimension using np.newaxis
train_images2 = train_images[np.newaxis, :, :, :]
print(train_images2.shape)
# np.newaxis and None are the same object
print(np.newaxis is None)
The output of the above code is:
2.1.0
(60000, 28, 28)
(1, 60000, 28, 28)
(1, 60000, 28, 28)
True
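Since the question is about a tf.Tensor rather than a NumPy array, it may help to note that the same indexing also works directly on tensors. A minimal sketch, assuming TF 2.x eager mode; tf.newaxis is TensorFlow's alias for None, and the Ellipsis (...) keeps all remaining axes.

```python
import numpy as np
import tensorflow as tf

xx = tf.ones((32, 32, 256))

a = xx[None, :, :, :]    # index with None
b = xx[tf.newaxis, ...]  # tf.newaxis is None; ... keeps the remaining axes
c = xx[np.newaxis]       # NumPy's alias works as well

print(a.shape, b.shape, c.shape)
print(tf.newaxis is None, np.newaxis is None)
```

As with tf.expand_dims, the new axis has size 1 in eager mode, not None.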
Upvotes: 2