Reputation: 424
I am using TensorFlow 2.0 and trying to compute gradients for backpropagation through a simple feedforward neural network. Here's what my model looks like:
def __init__(self, input_size, output_size):
    """Build the Keras model and attach its optimizer and loss.

    The network is an input of width ``input_size``, one 30-unit ReLU
    dense layer, and a dense output layer of width ``output_size`` with
    no activation.

    Args:
        input_size: Number of features in each input sample.
        output_size: Number of units in the output layer.
    """
    net_input = tf.keras.Input(shape=(input_size,))
    hidden = tf.keras.layers.Dense(30, activation='relu')(net_input)
    net_output = tf.keras.layers.Dense(output_size)(hidden)

    self.model = tf.keras.Model(inputs=net_input, outputs=net_output)
    # Training uses Adam (learning rate 0.001) with a Huber loss.
    self.optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
    self.loss_function = tf.keras.losses.Huber()
The forward pass through this network is fine, but when I use GradientTape to train the model, it is at least 10x slower than PyTorch. Training function:
def learn_modified_x(self, inputs, targets, actions):
    """Run one optimisation step on a batch.

    The forward pass and the loss are recorded on a gradient tape, the
    loss is differentiated with respect to the model's trainable weights,
    and the optimizer applies the resulting gradients.

    Args:
        inputs: Batch fed to ``self.model``.
        targets: Values the selected predictions are compared against.
        actions: Passed to ``gather_single_along_axis`` together with the
            model output.
    """
    with tf.GradientTape() as recorder:
        all_outputs = self.model(inputs)
        # Presumably picks one output per sample according to `actions`;
        # the helper is defined outside this snippet -- TODO confirm.
        chosen = gather_single_along_axis(all_outputs, actions)
        batch_loss = self.loss_function(targets, chosen)

    weights = self.model.trainable_weights
    gradients = recorder.gradient(batch_loss, weights)
    self.optimizer.apply_gradients(zip(gradients, weights))
I tried commenting out lines to find what is actually causing the problem, and discovered that tape.gradient is a significant contributor to the slowdown.
Any idea?
PyTorch implementation
def __init__(self, input_size, nb_action):
    """Create the two linear layers of the network.

    Args:
        input_size: Width of the input fed to the first linear layer.
        nb_action: Width of the output produced by the second linear layer.
    """
    super(Network, self).__init__()
    self.input_size, self.nb_action = input_size, nb_action

    hidden_units = 30
    self.fc1 = nn.Linear(input_size, hidden_units)
    self.fc2 = nn.Linear(hidden_units, nb_action)
def forward(self, state):
    """Compute the network output for ``state``.

    ``state`` is passed through ``fc1``, a ReLU, and then ``fc2``; the
    output of ``fc2`` is returned without any further activation.
    """
    hidden = self.fc1(state)
    activated = F.relu(hidden)
    return self.fc2(activated)
def learn(self, batch_state, batch_next_state, batch_reward, batch_action):
    """Take one optimizer step on a batch of transitions.

    The model output for ``batch_state`` is indexed by ``batch_action``
    and regressed, with a smooth L1 loss, onto
    ``batch_reward + gamma * max(model(batch_next_state))``. The
    next-state term is detached, so no gradient flows through it.

    Args:
        batch_state: Batch of inputs for the current step.
        batch_next_state: Batch of inputs for the following step.
        batch_reward: One reward per sample.
        batch_action: One integer index per sample, used as the gather
            index along dimension 1 of the model output.
    """
    outputs = self.model(batch_state).gather(1, batch_action.unsqueeze(1)).squeeze(1)
    next_outputs = self.model(batch_next_state).detach().max(1)[0]
    target = self.gamma * next_outputs + batch_reward
    td_loss = F.smooth_l1_loss(outputs, target)

    self.optimizer.zero_grad()
    # `retain_variables` was removed from Tensor.backward(); `retain_graph`
    # is the replacement keyword with the same meaning.
    td_loss.backward(retain_graph=True)
    self.optimizer.step()
Upvotes: 0
Views: 111
Reputation: 11
def __init__(self,...):
...
# Wrap the model's forward pass in tf.function so it is traced into a
# TensorFlow graph instead of being executed eagerly on every call.
self.model.call = tf.function(self.model.call)
...
You need to use tf.function
to wrap your model's call function.
Upvotes: 1