Arun

Reputation: 2478

TensorFlow 2.0 'build' function

I was reading about creating neural networks in TensorFlow 2.0 with the 'GradientTape' API and came across the following code:

import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Reshape(target_shape=(28 * 28,), input_shape=(28, 28)),
    tf.keras.layers.Dense(100, activation='relu'),
    tf.keras.layers.Dense(100, activation='relu'),
    tf.keras.layers.Dense(10),
])

model.build()
optimizer = tf.keras.optimizers.Adam()

In this code, what is the purpose of 'model.build()'? Does it compile the neural network?

The rest of the code is:

compute_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

compute_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()


def train_one_step(model, optimizer, x, y):
  with tf.GradientTape() as tape:
    logits = model(x)
    loss = compute_loss(y, logits)

  grads = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(grads, model.trainable_variables))

  compute_accuracy(y, logits)
  return loss


@tf.function
def train(model, optimizer):
  train_ds = mnist_dataset()  # helper defined elsewhere (not shown in the question); yields (x, y) batches
  step = 0
  loss = 0.0
  accuracy = 0.0
  for x, y in train_ds:
    step += 1
    loss = train_one_step(model, optimizer, x, y)
    if step % 10 == 0:
      tf.print('Step', step, ': loss', loss, '; accuracy', compute_accuracy.result())
  return step, loss, accuracy

step, loss, accuracy = train(model, optimizer)
print('Final step', step, ': loss', loss, '; accuracy', compute_accuracy.result())

Upvotes: 2

Views: 3079

Answers (1)

Djib2011

Reputation: 7432

The Keras documentation calls this the "delayed-build pattern": you can create a model without specifying its input shape.
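A minimal sketch of the delayed-build pattern, assuming TensorFlow 2.x with tf.keras: a Sequential model created without an input shape has no weights yet, and besides an explicit build() call, the first call on real data also builds it, inferring the input shape from the batch. The all-zero batch of shape (4, 500) is just an illustrative input.

```python
import tensorflow as tf

# No input shape is given, so Keras cannot create any weights yet.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(32),
    tf.keras.layers.Dense(32),
])
print(model.built)  # False

# The first call on data builds the model; the input shape (None, 500)
# is inferred from the batch (batch size 4 here is arbitrary).
outputs = model(tf.zeros((4, 500)))
print(model.built)         # True
print(len(model.weights))  # 4 (a kernel and a bias for each Dense layer)
print(outputs.shape)       # (4, 32)
```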

For example

from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense

model = Sequential()
model.add(Dense(32))
model.add(Dense(32))
model.build((None, 500))

is equivalent to

model = Sequential()
model.add(Dense(32, input_shape=(500,)))
model.add(Dense(32)) 
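To check the equivalence, here is a sketch (assuming TensorFlow 2.x; the names delayed and eager are only illustrative) that builds both variants and compares the shapes of their weights. In (None, 500), None stands for any batch size and 500 for the number of input features.

```python
import tensorflow as tf

# Delayed build: the input shape is supplied later through build().
delayed = tf.keras.Sequential()
delayed.add(tf.keras.layers.Dense(32))
delayed.add(tf.keras.layers.Dense(32))
delayed.build((None, 500))

# Up-front build: the input shape is given to the first layer.
eager = tf.keras.Sequential()
eager.add(tf.keras.layers.Dense(32, input_shape=(500,)))
eager.add(tf.keras.layers.Dense(32))

# Both models end up with identical weight shapes (kernel, bias per layer).
delayed_shapes = [tuple(w.shape) for w in delayed.weights]
eager_shapes = [tuple(w.shape) for w in eager.weights]
print(delayed_shapes)                  # [(500, 32), (32,), (32, 32), (32,)]
print(delayed_shapes == eager_shapes)  # True
```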

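In the second case, you need to know the input shape before defining the model's architecture. model.build() lets you define a model (i.e. its architecture) first and build it (i.e. create and initialize its weights) later. Building is not the same as compiling: model.build() only creates the weights, while model.compile() configures the optimizer, loss, and metrics used by fit() and evaluate(), which a custom GradientTape loop like the one in the question does not need. Also, since the question's first layer already specifies input_shape=(28, 28), that model is built as soon as it is created, so the extra model.build() call there adds nothing.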

Example taken from here.
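Applied to the question's own model, a sketch (assuming TensorFlow 2.x; the all-zero batch and the labels 3 and 7 are made-up inputs) shows that the model already has its weights without any build() call, because the first layer declares input_shape, and that a GradientTape training step runs without model.compile(). Because Dense biases start at zero, an all-zero input gives all-zero logits, so the first loss is ln(10), about 2.3026.

```python
import tensorflow as tf

# The model from the question; the first layer already declares input_shape.
model = tf.keras.Sequential([
    tf.keras.layers.Reshape(target_shape=(28 * 28,), input_shape=(28, 28)),
    tf.keras.layers.Dense(100, activation='relu'),
    tf.keras.layers.Dense(100, activation='relu'),
    tf.keras.layers.Dense(10),
])
# The input shape is known, so the weights exist without calling build().
print(model.built)           # True
print(model.count_params())  # 89610 = (784*100 + 100) + (100*100 + 100) + (100*10 + 10)

# One GradientTape training step; model.compile() is never called.
compute_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.Adam()
x = tf.zeros((2, 28, 28))  # made-up batch of two blank 28x28 images
y = tf.constant([3, 7])    # made-up integer labels
with tf.GradientTape() as tape:
    logits = model(x)
    loss = compute_loss(y, logits)
grads = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(grads, model.trainable_variables))

# Zero input and zero initial biases give zero logits, i.e. a uniform softmax.
print(float(loss))  # about 2.3026, i.e. ln(10)
```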

Upvotes: 5
