Reputation: 18693
Because observations do not reveal the entire state, I need to do reinforcement learning with a recurrent neural network so that the network has some memory of what has happened in the past. For simplicity, let's assume we use an LSTM.
Now the built-in PyTorch LSTM requires you to feed it an input of shape (Time, MiniBatch, InputDim), and it outputs a tensor of shape (Time, MiniBatch, OutputDim).
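For concreteness, here is a minimal sketch of that shape contract (T, B, D_in, D_out are placeholder sizes I made up for illustration):

import torch
import torch.nn as nn

T, B, D_in, D_out = 10, 4, 8, 16           # placeholder sizes
lstm = nn.LSTM(input_size=D_in, hidden_size=D_out)
x = torch.randn(T, B, D_in)                # (Time, MiniBatch, InputDim)
out, (h, c) = lstm(x)
print(out.shape)                           # torch.Size([10, 4, 16]) = (Time, MiniBatch, OutputDim)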
In reinforcement learning, however, to know the input at time t+1 I need to know the output at time t, because I am taking actions in an environment.
So is it possible to use the built-in PyTorch LSTM to do BPTT in a reinforcement learning setting? And if so, how could I do it?
Upvotes: 3
Views: 2131
Reputation: 961
I spent a lot of time getting things running (the model learning :)) and want to share my findings. Check out my code with a working LSTM RL training loop: https://github.com/svenkroll/simple_RL-LSTM. There are also some higher-level details I try to give back to the community there.
Upvotes: 0
Reputation: 1434
Maybe you can feed your input sequence to your LSTM in a loop, something like this:
# initial hidden and cell state: (num_layers, batch_size, hidden_size)
h = torch.zeros(num_layers, batch_size, hidden_size)
c = torch.zeros(num_layers, batch_size, hidden_size)
for t in range(T):
    input = ...  # one timestep, shape (1, batch_size, input_size)
    _, (h, c) = lstm(input, (h, c))
At every timestep you can then use (h, c) and the input to compute the action, for instance. As long as you do not break the computational graph, you can backpropagate through the whole loop, since the tensors keep the full autograd history.
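To make that concrete, here is a minimal REINFORCE-style sketch of backpropagating through such a loop. The policy head, reward, and observations are placeholders of my own, not a real environment API:

import torch
import torch.nn as nn

T, input_size, hidden_size, n_actions = 5, 8, 16, 4    # placeholder sizes
lstm = nn.LSTM(input_size, hidden_size)
policy = nn.Linear(hidden_size, n_actions)             # hypothetical action head
optimizer = torch.optim.Adam(list(lstm.parameters()) + list(policy.parameters()))

h = torch.zeros(1, 1, hidden_size)                     # (num_layers, batch, hidden_size)
c = torch.zeros(1, 1, hidden_size)
obs = torch.randn(1, 1, input_size)                    # stand-in for the first observation
loss = 0.0

for t in range(T):
    _, (h, c) = lstm(obs, (h, c))
    dist = torch.distributions.Categorical(logits=policy(h[-1]))
    action = dist.sample()
    # in a real setting: obs, reward = env.step(action)
    reward = torch.randn(())                           # placeholder reward
    loss = loss - dist.log_prob(action) * reward       # REINFORCE-style term
    obs = torch.randn(1, 1, input_size)                # placeholder next observation

optimizer.zero_grad()
loss.backward()                                        # BPTT through the whole unrolled loop
optimizer.step()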
Upvotes: 1