Maverick Meerkat

Reputation: 6424

Combining gradients from different "networks" in TensorFlow2

I'm trying to combine a few "networks" into one final loss function. I'm wondering if what I'm doing is "legal"; as of now I can't seem to make this work. I'm using TensorFlow Probability:

The main problem is here:

# Get gradients of the loss wrt the weights.
gradients = tape.gradient(loss, [m_phis.trainable_weights, m_mus.trainable_weights, m_sigmas.trainable_weights])

# Update the weights of our linear layer.
optimizer.apply_gradients(zip(gradients, [m_phis.trainable_weights, m_mus.trainable_weights, m_sigmas.trainable_weights]))

This gives me None gradients and then throws on apply_gradients:

AttributeError: 'list' object has no attribute 'device'

Full code:

import numpy as np
import tensorflow as tf
from tensorflow import keras
import tensorflow_probability as tfp

univariate_gmm = tfp.distributions.MixtureSameFamily(
        mixture_distribution=tfp.distributions.Categorical(probs=phis_true),
        components_distribution=tfp.distributions.Normal(loc=mus_true, scale=sigmas_true)
    )
x = univariate_gmm.sample(n_samples, seed=random_seed).numpy()
dataset = tf.data.Dataset.from_tensor_slices(x) 
dataset = dataset.shuffle(buffer_size=1024).batch(64)  

m_phis = keras.layers.Dense(2, activation=tf.nn.softmax)
m_mus = keras.layers.Dense(2)
m_sigmas = keras.layers.Dense(2, activation=tf.nn.softplus)

def neg_log_likelihood(y, phis, mus, sigmas):
    a = tfp.distributions.Normal(loc=mus[0], scale=sigmas[0]).prob(y)
    b = tfp.distributions.Normal(loc=mus[1], scale=sigmas[1]).prob(y)
    c = np.log(phis[0]*a + phis[1]*b)
    return tf.reduce_sum(-c, axis=-1)

# Use the negative log-likelihood as the loss function.
loss_fn = neg_log_likelihood

# Instantiate an optimizer.
optimizer = tf.keras.optimizers.SGD(learning_rate=1e-3)

# Iterate over the batches of the dataset.
for step, y in enumerate(dataset):
    
    yy = np.expand_dims(y, axis=1)

    # Open a GradientTape.
    with tf.GradientTape() as tape:
        
        # Forward pass.
        phis = m_phis(yy)
        mus = m_mus(yy)
        sigmas = m_sigmas(yy)

        # Loss value for this batch.
        loss = loss_fn(yy, phis, mus, sigmas)

    # Get gradients of the loss wrt the weights.
    gradients = tape.gradient(loss, [m_phis.trainable_weights, m_mus.trainable_weights, m_sigmas.trainable_weights])

    # Update the weights of our linear layer.
    optimizer.apply_gradients(zip(gradients, [m_phis.trainable_weights, m_mus.trainable_weights, m_sigmas.trainable_weights]))

    # Logging.
    if step % 100 == 0:
        print("Step:", step, "Loss:", float(loss))

Upvotes: 0

Views: 382

Answers (1)

André

Reputation: 1068

There are two separate problems to take into account.

1. Gradients are None:

Typically this happens if non-TensorFlow operations are executed in code watched by the GradientTape. Concretely, this concerns the np.log computation in your neg_log_likelihood function. If you replace np.log with tf.math.log, the gradients should compute. It is a good habit to avoid numpy inside your "internal" TensorFlow components, since numpy calls break the gradient chain in exactly this way; for most numpy operations there is a good TensorFlow substitute, as sketched below.
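
For example, here is a minimal rewrite of the question's neg_log_likelihood, changing only the np.log call:

def neg_log_likelihood(y, phis, mus, sigmas):
    a = tfp.distributions.Normal(loc=mus[0], scale=sigmas[0]).prob(y)
    b = tfp.distributions.Normal(loc=mus[1], scale=sigmas[1]).prob(y)
    # tf.math.log keeps the computation on the tape, so gradients can flow;
    # np.log silently detaches it and yields None gradients.
    c = tf.math.log(phis[0]*a + phis[1]*b)
    return tf.reduce_sum(-c, axis=-1)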

2. apply_gradients for multiple trainables:

This comes down to the input format that apply_gradients expects. You have two options:

First option: Call apply_gradients three times, each time with a different set of trainables

optimizer.apply_gradients(zip(m_phis_gradients, m_phis.trainable_weights))
optimizer.apply_gradients(zip(m_mus_gradients, m_mus.trainable_weights))
optimizer.apply_gradients(zip(m_sigmas_gradients, m_sigmas.trainable_weights))
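
Each variable is updated exactly once either way; the only practical difference is that every apply_gradients call increments the optimizer's iteration counter, which can matter if you use a learning-rate schedule.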

The alternative is to build a single list of tuples, as indicated in the TensorFlow documentation (quote: "grads_and_vars: List of (gradient, variable) pairs."). Note that apply_gradients needs a flat list of (gradient, variable) pairs, so the per-layer lists have to be concatenated rather than nested. This would mean calling something like

optimizer.apply_gradients(
    list(zip(m_phis_gradients, m_phis.trainable_weights))
    + list(zip(m_mus_gradients, m_mus.trainable_weights))
    + list(zip(m_sigmas_gradients, m_sigmas.trainable_weights))
)

Both options require you to split the gradients. You can either do that by computing all gradients at once and indexing the result (gradients[0], ...), or you can compute the gradients separately per layer. Note that the latter requires persistent=True in your GradientTape, because the tape is queried more than once.

    # [...]
    # Open a GradientTape.
    with tf.GradientTape(persistent=True) as tape:
        # Forward pass.
        phis = m_phis(yy)
        mus = m_mus(yy)
        sigmas = m_sigmas(yy)

        # Loss value for this batch.
        loss = loss_fn(yy, phis, mus, sigmas)

    # Get gradients of the loss wrt the weights.
    m_phis_gradients = tape.gradient(loss, m_phis.trainable_weights)
    m_mus_gradients = tape.gradient(loss, m_mus.trainable_weights)
    m_sigmas_gradients = tape.gradient(loss, m_sigmas.trainable_weights)

    # Update the weights of all three layers.
    optimizer.apply_gradients(
        list(zip(m_phis_gradients, m_phis.trainable_weights))
        + list(zip(m_mus_gradients, m_mus.trainable_weights))
        + list(zip(m_sigmas_gradients, m_sigmas.trainable_weights))
    )
    # [...]
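
A third option, as a minimal sketch using the same three layers as above: tape.gradient also accepts a single flat list of variables, which avoids both splitting the gradients and the need for persistent=True:

# Inside the training loop, after the GradientTape block:
variables = (m_phis.trainable_weights
             + m_mus.trainable_weights
             + m_sigmas.trainable_weights)

# A single tape.gradient call returns one gradient per variable,
# so the tape is only queried once and persistent=True is not needed.
gradients = tape.gradient(loss, variables)

# zip yields the flat list of (gradient, variable) pairs
# that apply_gradients expects.
optimizer.apply_gradients(zip(gradients, variables))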

Upvotes: 1
