Reputation: 23
I defined an RNN "by hand", composed of multiple linear layers with pruned connections.
To keep track of the hidden states, I have a variable `next_hidden_states` in which I save the hidden states at time t, to reuse them at time t+1. This variable has shape `(batch_size, N)`.
During training/evaluation, I would like to be able to run the model both on batched inputs (to train the agent) and on unbatched inputs (to run an episode in the environment). This is usually possible with standard PyTorch modules, since the batch dimension is implicit...
I thought about passing `next_hidden_states` as both an input and an output of the network, but that is quite inelegant.
Edit
Here is a minimal version of my code
import numpy as np
import torch
import torch.nn.utils.prune as prune
import torch.nn as nn
class BrainRNN(nn.Module):
    """Hand-built recurrent network with three layers of 4 neurons each.

    Hidden states persist on the module between calls, so each ``forward``
    advances the network by one time step. Like standard modules, inputs may
    be unbatched, shape ``(4,)``, or batched, shape ``(*batch, 4)``.

    Args:
        activation: callable applied after the input and hidden layers.
        batch_size: ``reset_hidden_states`` repeats the random initial state
            to ``(batch_size, n_neurons)`` when ``> 0``; otherwise it keeps a
            single ``(n_neurons,)`` vector.
    """

    def __init__(self, activation=torch.sigmoid, batch_size=8):
        super().__init__()
        self.n_neurons = 3 * 4  # 3 layers x 4 neurons
        self.activation = activation
        self.batch_size = batch_size
        self.reset_hidden_states()

        # Input layer
        self.input_layer = nn.Linear(4, 4)

        # Forward hidden layers; the mask removes the diagonal (fictive connections)
        self.hidden_layers = nn.ModuleList([])
        new_layer = nn.Linear(4, 4)
        mask = np.ones((4, 4)) - np.eye(4)
        prune.custom_from_mask(new_layer, name='weight', mask=torch.tensor(mask.T))
        self.hidden_layers.append(new_layer)

        # Backward (recurrent) weights: recurrent_layers[i](hidden_states) maps
        # layer j > i back to layer i. No bias on backward connections.
        self.recurrent_layers = nn.ModuleList([])
        new_layer = nn.Linear(self.n_neurons, 4, bias=False)
        mask = np.zeros((self.n_neurons, 4))
        mask[1, 0] = 1  # keep only: stored state unit 1 -> input-layer unit 0
        prune.custom_from_mask(new_layer, name='weight', mask=torch.tensor(mask.T))
        self.recurrent_layers.append(new_layer)

        # Output layer
        self.output_layer = nn.Linear(4, 4)

    def forward(self, x):
        """Advance the network by one time step.

        Args:
            x: input of shape ``(4,)`` (unbatched) or ``(*batch, 4)``.

        Returns:
            Output of shape ``(*batch, 4)``, or ``(4,)`` for unbatched input.
            As a side effect ``self.hidden_states`` becomes the states produced
            at this step, shape ``(*batch, n_neurons)``.

        Raises:
            ValueError: if the stored hidden states cannot be matched to the
                batch shape of ``x`` (e.g. switching from a batch of 8 to an
                unbatched input mid-sequence without ``reset_hidden_states()``).
        """
        batch_shape = x.shape[:-1]
        previous = self._match_hidden_states(batch_shape)
        # zeros, not empty: units 8-11 are never written below, yet the
        # recurrent layer reads every column, and 0 * uninitialised NaN is NaN.
        next_hidden_states = x.new_zeros((*batch_shape, self.n_neurons))

        # Input layer (also receives the recurrent connections)
        x = self.activation(self.input_layer(x) + self.recurrent_layers[0](previous))
        next_hidden_states[..., 0:4] = x

        # Hidden layer
        x = self.activation(self.hidden_layers[0](x))
        next_hidden_states[..., 4:8] = x

        # Output layer: no activation nor recurrent/skip connection
        x = self.output_layer(x)

        self.hidden_states = next_hidden_states
        self._fresh = False
        return x

    def _match_hidden_states(self, batch_shape):
        """Return the stored hidden states viewed as ``(*batch_shape, n_neurons)``."""
        target = (*batch_shape, self.n_neurons)
        try:
            # Exact match, or broadcastable states such as one (n_neurons,)
            # vector shared by the whole batch.
            return self.hidden_states.expand(target)
        except RuntimeError:
            pass
        if self._fresh and self._initial_hidden_state is not None:
            # No step taken since the reset: every row is the same initial
            # vector, so it can be broadcast to whatever batch shape arrives.
            return self._initial_hidden_state.expand(target)
        raise ValueError(
            f'stored hidden_states of shape {tuple(self.hidden_states.shape)} '
            f'cannot be used for an input with batch shape {tuple(batch_shape)}; '
            'call reset_hidden_states() before changing the batch shape'
        )

    def reset_hidden_states(self, hidden_states=None):
        """Start a new sequence by replacing the stored hidden states.

        Args:
            hidden_states: optional explicit states, shape ``(n_neurons,)`` or
                ``(*batch, n_neurons)``. When omitted, a vector is drawn from
                N(0, 1) and, if ``batch_size > 0``, repeated to
                ``(batch_size, n_neurons)`` (same state for every batch row).
        """
        if hidden_states is None:
            initial = nn.init.normal_(torch.empty(self.n_neurons), std=1)
            hidden_states = initial.repeat(self.batch_size, 1) if self.batch_size > 0 else initial
        else:
            hidden_states = torch.as_tensor(hidden_states)
            initial = hidden_states if hidden_states.dim() == 1 else None
        # Remembered so the first step after a reset may use any batch shape.
        self._initial_hidden_state = initial
        self._fresh = True
        self.hidden_states = hidden_states
# Name the instance `model`, not `nn`: rebinding `nn` would shadow the
# `torch.nn` module that BrainRNN itself looks up at call time (e.g. `nn.init`
# in reset_hidden_states), so later resets or new instances would break.
model = BrainRNN()
model(torch.zeros(8, 4))  # batched step: input (8, 4)
model.reset_hidden_states()  # begin a new sequence before changing the batch shape
model(torch.zeros(4))  # unbatched step: input (4,)
where there are 3 layers of 4 nodes each, with a recurrent connection between the hidden layer and the input layer, and some pruned connections.
The aim is, given `nn = BrainRNN(...)`, to be able to evaluate both `nn(torch.zeros((B, 4)))` and `nn(torch.zeros(4))`.
Ideally, I would like to reproduce the behavior of standard `nn.Module`s, but I don't really know how to do so while also saving the hidden states...
Upvotes: 0
Views: 107
Reputation: 5095
Below is a simple RNN that accepts data as `(sequence_length, n_features)` or `(batch_size, sequence_length, n_features)`. It steps through the entire sequence and returns the outputs and hidden states for each step (it also stores them as attributes which you can access). Is this the sort of functionality you were after? There is no pruning, but you could add it as in your original code.
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
class SimpleRNN(nn.Module):
    """Minimal RNN that unrolls a whole input sequence inside ``forward``.

    Accepts ``(sequence_length, n_features)`` or
    ``(batch_size, sequence_length, n_features)`` input and returns the
    per-step outputs and hidden states, which are also stored as
    ``self.outputs`` and ``self.hidden_states``.

    Args:
        input_size: number of features per time step.
        hidden_size: size of the hidden state.
        output_size: size of each per-step output.
        activation: name of a function in ``torch.nn.functional``, applied to
            both the hidden-state update and the output.
        batch_first: only ``True`` is supported.

    Raises:
        NotImplementedError: if ``batch_first`` is false.
    """

    def __init__(self, input_size, hidden_size=4, output_size=2, activation='tanh', batch_first=True):
        super().__init__()
        # Only batch_first=True is supported (as per OP's test data). Raise
        # instead of assert so the check is not stripped by `python -O`.
        if not batch_first:
            raise NotImplementedError('This model assumes batch_first=True for simplicity')
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.activation_fn = getattr(torch.nn.functional, activation)
        self.Wxh = nn.Linear(self.input_size, self.hidden_size)
        self.Whh = nn.Linear(self.hidden_size, self.hidden_size)
        self.Why = nn.Linear(self.hidden_size, output_size)

    def forward(self, x):
        """Run the RNN over every time step of ``x``.

        Args:
            x: ``(sequence_length, n_features)`` or
                ``(batch_size, sequence_length, n_features)``.

        Returns:
            ``(outputs, hidden_states)`` shaped
            ``(batch_size, sequence_length, output_size / hidden_size)``, with
            the batch dimension dropped again for 2-D input.

        Raises:
            ValueError: if ``x`` is not 2-D/3-D or its feature size differs
                from ``input_size``.
        """
        unbatched = x.ndim == 2
        if unbatched:
            # 2-D means (sequence_length, n_features): prepend a batch dim of 1.
            print('X.ndim is 2 | Assuming X.shape is (sequence_length, n_features)')
            x = x.unsqueeze(dim=0)
        elif x.ndim == 3:
            print('X.ndim is 3 | Assuming X.shape is (batch_size, sequence_length, n_features)')
        else:
            raise ValueError(f'Expected a 2-D or 3-D input, got shape {tuple(x.shape)}')

        batch_size, sequence_len, n_features = x.shape
        if n_features != self.input_size:
            raise ValueError(f'Expected input features size of {self.input_size}')

        # Zero initial hidden state; new_zeros follows x's dtype and device.
        hidden_state = x.new_zeros((batch_size, self.hidden_size))
        hidden_states, outputs = [], []
        for frame_idx in range(sequence_len):
            frame = x[:, frame_idx, :]  # (batch_size, n_features) for this step
            hidden_state = self.activation_fn(self.Wxh(frame) + self.Whh(hidden_state))
            outputs.append(self.activation_fn(self.Why(hidden_state)))
            hidden_states.append(hidden_state)

        # Stack into (batch_size, sequence_length, output_size/hidden_size),
        # also kept as attributes.
        if sequence_len:
            self.outputs = torch.stack(outputs, dim=1)
            self.hidden_states = torch.stack(hidden_states, dim=1)
        else:
            # torch.stack rejects an empty list; return empty sequences instead.
            self.outputs = x.new_zeros((batch_size, 0, self.Why.out_features))
            self.hidden_states = x.new_zeros((batch_size, 0, self.hidden_size))

        # Drop the batch dim that was added for 2-D input.
        if unbatched:
            self.outputs, self.hidden_states = self.outputs[0], self.hidden_states[0]
        return self.outputs, self.hidden_states
Test the shapes:
# Exercise both accepted input layouts; the hidden states keep the input's
# leading dims and end with hidden_size:
#   (sequence_length=12, n_features=4)                -> (12, hidden_size)
#   (batch_size=32, sequence_length=12, n_features=4) -> (32, 12, hidden_size)
for input_shape in ((12, 4), (32, 12, 4)):
    x = torch.rand(*input_shape)
    outputs, hidden_states = SimpleRNN(input_size=4)(x)
    print(hidden_states.shape)
X.ndim is 2 | Assuming X.shape is (sequence_length, n_features)
torch.Size([12, 4])
X.ndim is 3 | Assuming X.shape is (batch_size, sequence_length, n_features)
torch.Size([32, 12, 4])
The RNN is untested and is meant to illustrate how you can do the recurrence inside the class and store the hidden states.
Upvotes: 0