Sphinx

Reputation: 424

TensorFlow autodiff slower than PyTorch's counterpart

I am using TensorFlow 2.0 and trying to evaluate gradients for backpropagation through a simple feedforward neural network. Here's what my model looks like:

def __init__(self, input_size, output_size):
    inputs = tf.keras.Input(shape=(input_size,))
    hidden_layer1 = tf.keras.layers.Dense(30, activation='relu')(inputs)
    outputs = tf.keras.layers.Dense(output_size)(hidden_layer1)
    self.model = tf.keras.Model(inputs=inputs, outputs=outputs)

    self.optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
    self.loss_function = tf.keras.losses.Huber()

The forward pass through this network is fine, but when I use gradient tape to train the model, it is at least 10x slower than the equivalent PyTorch code. Training function:

def learn_modified_x(self, inputs, targets, actions):
    with tf.GradientTape() as tape:
        predictions = self.model(inputs)
        predictions_for_action = gather_single_along_axis(predictions, actions)
        loss = self.loss_function(targets, predictions_for_action)

    grads = tape.gradient(loss, self.model.trainable_weights)
    self.optimizer.apply_gradients(zip(grads, self.model.trainable_weights))
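
gather_single_along_axis is not defined in the question; judging from the PyTorch .gather(1, batch_action...) call below, it presumably selects the prediction for the action taken in each row. A minimal sketch under that assumption:

    # Hypothetical stand-in for the undefined helper: selects
    # predictions[i, actions[i]] per row, mirroring torch.gather(dim=1) below.
    def gather_single_along_axis(predictions, actions):
        # predictions: (batch, n_actions) float tensor; actions: (batch,) int indices
        return tf.gather(predictions, actions, axis=1, batch_dims=1)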

I tried commenting out lines to isolate the problem, and found that the tape.gradient call is the main contributor to the slowdown.
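
A quick way to confirm that is to time the training step in isolation. A sketch with made-up shapes and a hypothetical agent object holding the model above:

    import time
    import numpy as np

    # Hypothetical batch shapes, just for timing; not the real data.
    inputs = np.random.rand(64, 4).astype(np.float32)
    targets = np.random.rand(64).astype(np.float32)
    actions = np.random.randint(0, 2, size=64)

    start = time.perf_counter()
    for _ in range(100):
        agent.learn_modified_x(inputs, targets, actions)
    print('seconds per step:', (time.perf_counter() - start) / 100)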

Any ideas?

PyTorch implementation

import torch.nn as nn
import torch.nn.functional as F

class Network(nn.Module):
    def __init__(self, input_size, nb_action):
        super(Network, self).__init__()
        self.input_size = input_size
        self.nb_action = nb_action
        self.fc1 = nn.Linear(input_size, 30)
        self.fc2 = nn.Linear(30, nb_action)

    def forward(self, state):
        x = F.relu(self.fc1(state))
        q_values = self.fc2(x)
        return q_values

    # learn() references self.model, self.gamma and self.optimizer, which are
    # defined on the surrounding agent class (not shown in the question).
    def learn(self, batch_state, batch_next_state, batch_reward, batch_action):
        outputs = self.model(batch_state).gather(1, batch_action.unsqueeze(1)).squeeze(1)
        next_outputs = self.model(batch_next_state).detach().max(1)[0]
        target = self.gamma * next_outputs + batch_reward
        td_loss = F.smooth_l1_loss(outputs, target)
        self.optimizer.zero_grad()
        td_loss.backward(retain_graph=True)  # retain_variables was renamed to retain_graph
        self.optimizer.step()

Upvotes: 0

Views: 111

Answers (1)

王博洋

Reputation: 11

def __init__(self, ...):
    ...
    self.model.call = tf.function(self.model.call)
    ...

You need to use tf.function to wrap your model's call function, so TensorFlow compiles it into a graph instead of executing it op by op in eager mode.
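
Wrapping the model's call helps, but the bigger win usually comes from compiling the entire training step, so the tape, the gradient computation, and the optimizer update all run as one graph. A minimal sketch (not from the answer, and assuming a standard Keras model like the one in the question):

    import tensorflow as tf

    @tf.function  # traces the Python function once, then reuses the compiled graph
    def train_step(model, optimizer, loss_fn, inputs, targets):
        with tf.GradientTape() as tape:
            predictions = model(inputs, training=True)
            loss = loss_fn(targets, predictions)
        grads = tape.gradient(loss, model.trainable_weights)
        optimizer.apply_gradients(zip(grads, model.trainable_weights))
        return loss

The first call is slow because of tracing; later calls with the same input shapes and dtypes reuse the graph, which is typically where the 10x gap closes. In eager mode, tape.gradient carries per-op Python dispatch overhead that PyTorch's C++ autograd engine largely avoids, which is why the difference shows up mostly in the backward pass.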

Upvotes: 1
