Edvard-D

Reputation: 384

How to use distributed training with a custom loss using Tensorflow?

I have a transformer model I'd like to train distributed across several workers on the Google Cloud AI Platform, using Actor-Critic RL for training. My data is broken up into individual files by date and uploaded to Cloud Storage. Since I'm using Actor-Critic RL, I have a custom loss function that calculates and applies the gradient. All the distributed-training examples I've come across use model.fit, which I won't be able to use, and I haven't found any information on using a custom loss instead.

Along with distributing training across several machines, I'd like to know how to properly distribute it across several CPU cores as well. From my understanding, model.fit normally takes care of this.

Here's the custom loss function; I believe it's currently the equivalent of a batch size of 1:

def learn(self, state_value_starting: tf.Tensor, probabilities: tf.Tensor, state_new: tf.Tensor,
            reward: tf.Tensor, is_done: tf.Tensor):
    with tf.GradientTape() as tape:
        state_value_starting = tf.squeeze(state_value_starting)
        state_value_new, _ = self.call(state_new)
        state_value_new = tf.squeeze(state_value_new)

        # Log-probability of the action that was actually taken
        action_probabilities = tfp.distributions.Categorical(probs=probabilities)
        log_probability = action_probabilities.log_prob(self._last_action)

        # TD error: bootstrapped return minus the starting state value
        delta = reward + (self._discount_factor * state_value_new * (1 - int(is_done))) - state_value_starting
        actor_loss = -log_probability * delta
        critic_loss = delta ** 2
        total_loss = actor_loss + critic_loss

    gradient = tape.gradient(total_loss, self.trainable_variables)
    self.optimizer.apply_gradients(zip(gradient, self.trainable_variables))

Upvotes: 0

Views: 609

Answers (1)

Grandesty

Reputation: 86

The TensorFlow Models repository provides a worked solution in model_lib_v2.py.

See the function train_loop; the custom training loop it constructs makes use of

strategy = tf.compat.v2.distribute.get_strategy()  # L501
with strategy.scope():
    # ... build the model and run the training steps ...

The custom loss is applied in the function eager_train_step.
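For the question's setup, a minimal sketch of the same idea might look like the following. It assumes a hypothetical ActorCritic class wrapping the learn() method shown in the question and a hypothetical make_dataset() helper that yields (state, probabilities, state_new, reward, is_done) transitions read from the per-date files; neither name comes from the question or from model_lib_v2.py.

import tensorflow as tf

# MultiWorkerMirroredStrategy discovers the other workers through the
# TF_CONFIG environment variable, which AI Platform sets on each worker.
strategy = tf.distribute.MultiWorkerMirroredStrategy()

with strategy.scope():
    # Variables (model weights and the optimizer's slots) must be created
    # inside the strategy scope so they are mirrored across replicas.
    model = ActorCritic()  # hypothetical wrapper around the learn() method above

dataset = make_dataset()  # hypothetical: yields (state, probabilities, state_new, reward, is_done)
distributed_dataset = strategy.experimental_distribute_dataset(dataset)

def train_step(state, probabilities, state_new, reward, is_done):
    state_value, _ = model(state)
    # learn() computes the actor-critic loss and calls apply_gradients, as in
    # the question; under the strategy the gradients are aggregated across replicas.
    model.learn(state_value, probabilities, state_new, reward, is_done)

for batch in distributed_dataset:
    # Each replica receives its shard of the batch and runs train_step on it.
    strategy.run(train_step, args=batch)

As for multiple CPU cores on a single machine: ordinary TensorFlow ops already spread work across cores through the intra-op thread pool, and tf.data's num_parallel_calls and prefetch cover the input pipeline, so that part mostly comes without a distribution strategy.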

Upvotes: 1
