Reputation: 27594
I am working to understand Erik Linder-Norén's implementation of the Categorical GAN model, and am confused by the generator in that model:
def build_generator(self):
    model = Sequential()
    # ...some lines removed...
    model.add(Dense(np.prod(self.img_shape), activation='tanh'))
    model.add(Reshape(self.img_shape))
    model.summary()

    noise = Input(shape=(self.latent_dim,))
    label = Input(shape=(1,), dtype='int32')
    label_embedding = Flatten()(Embedding(self.num_classes, self.latent_dim)(label))
    model_input = multiply([noise, label_embedding])
    img = model(model_input)
    return Model([noise, label], img)
My question is: How does the Embedding() layer work here?
I know that noise is a vector of length 100 and label is an integer, but I don't understand what the label_embedding object contains or how it functions here.
I tried printing the shape of label_embedding to figure out what's going on in that Embedding() line, but that returns (?,?).
If anyone could help me understand how the Embedding() lines here work, I'd be very grateful for their assistance!
Upvotes: 0
Views: 2834
Reputation: 6044
From the documentation, https://keras.io/layers/embeddings/#embedding,
Turns positive integers (indexes) into dense vectors of fixed size. eg. [[4], [20]] -> [[0.25, 0.1], [0.6, -0.2]]
In the GAN model, the input integer (0-9) is converted to a vector of length 100. With this short code snippet, we can feed some test input to check the output shape of the Embedding layer.
from keras.layers import Input, Embedding
from keras.models import Model
import numpy as np
latent_dim = 100
num_classes = 10
label = Input(shape=(1,), dtype='int32')
label_embedding = Embedding(num_classes, latent_dim)(label)
mod = Model(label, label_embedding)
test_input = np.zeros((1))
print(f'output shape is {mod.predict(test_input).shape}')
mod.summary()
output shape is (1, 1, 100)
In the model summary, the output shape of the embedding layer is (None, 1, 100), where None stands for the batch size. This agrees with the output of predict, which is (1, 1, 100) for a batch of one.
embedding_1 (Embedding) (None, 1, 100) 1000
One additional point, in the output shape (1,1,100), the leftmost 1 is the batch size, the middle 1 is the input length. In this case, we provided an input of length 1.
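To make the shapes above concrete, here is a minimal NumPy sketch that treats the embedding as a plain lookup table. It is an illustration of the idea, not the Keras implementation: the table values are arbitrary random numbers standing in for the layer's trainable weights, and the names embedding_table, looked_up and flattened are made up for this example. The reshape plays the role of the Flatten() call in the question's code.

```python
import numpy as np

latent_dim = 100   # length of each embedding vector (same as the noise length)
num_classes = 10   # one row per label, 0..9

rng = np.random.default_rng(0)
# Stand-in for the layer's trainable weights: one row of length 100 per label.
# 10 * 100 = 1000 values, matching the parameter count in the summary.
embedding_table = rng.uniform(-0.05, 0.05, size=(num_classes, latent_dim))

labels = np.array([[3], [7]])                    # batch of 2 labels, shape (2, 1)
looked_up = embedding_table[labels]              # row lookup, shape (2, 1, 100)
flattened = looked_up.reshape(len(labels), -1)   # like Flatten(), shape (2, 100)

print(looked_up.shape)
print(flattened.shape)
```

After flattening, each sample's label vector has the same shape as its noise vector, which is what allows the element-wise multiply in the generator.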
Upvotes: 2
Reputation: 399
It is worth keeping in mind why an embedding is used here at all: the alternative is to concatenate the noise with the conditioning class label, which may cause the generator to ignore the noise values completely and generate highly similar samples within each class (or even just one sample per class).
Upvotes: 3
Reputation: 2682
The embedding stores the per-label state. If I read the code correctly, each label corresponds to a digit; i.e. there is one embedding vector for each of the digits 0, 1, ... 9 that captures how to generate it.
This code takes some random noise and multiplies it element-wise by this per-label state. The result should be a vector that leads the generator to display the digit corresponding to the label (i.e. 0..9).
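The multiplication described above can be sketched in NumPy as follows. This is only an illustration under stated assumptions: embedding_table holds random values in place of learned weights, and generator_input is a hypothetical helper name that does not appear in the original code. It shows that the same noise vector yields a different generator input for each label.

```python
import numpy as np

latent_dim, num_classes = 100, 10
rng = np.random.default_rng(0)

# Random stand-in for the learned per-label vectors (one row per digit).
embedding_table = rng.normal(size=(num_classes, latent_dim))

def generator_input(noise, labels):
    """Multiply each noise vector element-wise by its label's vector."""
    label_vectors = embedding_table[labels.reshape(-1)]  # (batch, latent_dim)
    return noise * label_vectors

noise = rng.normal(size=(1, latent_dim))          # one noise sample
input_for_3 = generator_input(noise, np.array([[3]]))
input_for_7 = generator_input(noise, np.array([[7]]))

print(input_for_3.shape)                          # same shape as the noise
print(np.allclose(input_for_3, input_for_7))      # do the two inputs coincide?
```

Because the label vector scales every noise component, the label and the noise are both present in every element of the generator's input, rather than occupying separate slots as they would after concatenation.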
Upvotes: 1