samje

Reputation: 23

How to keep track of hidden states for different input shapes

I defined an RNN "by hand", composed of multiple linear layers with pruned connections.

To keep track of the hidden states, I have a variable next_hidden_states in which I save the hidden states at time t, to re-use them at time t+1. This variable is of size (batch_size, N).

During training/evaluation, I would like to be able to evaluate the model on inputs with a batch dimension (to train the agent) or without one (to run an episode in the environment). This is usually possible with standard PyTorch modules, as the batch dimension is implicit...
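
For instance, a plain nn.Linear already behaves like this: the same layer accepts a batched or an unbatched input.

import torch
import torch.nn as nn

layer = nn.Linear(4, 4)
print(layer(torch.zeros(8, 4)).shape)  # torch.Size([8, 4]) -- batched input
print(layer(torch.zeros(4)).shape)     # torch.Size([4])    -- single sample, batch dim omitted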

I thought about passing next_hidden_states as an argument to the network and returning it as an output, but that is quite inelegant.
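
Concretely, that option would look roughly like the sketch below (the class and names are made up for illustration): every caller has to thread the state through the forward call.

import torch
import torch.nn as nn

class StatelessCell(nn.Module):
    # Hypothetical minimal cell: the hidden state is not stored on the module,
    # it is passed in by the caller and handed back after each step.
    def __init__(self, n_in=4, n_hidden=12):
        super().__init__()
        self.input_layer = nn.Linear(n_in, n_hidden)
        self.recurrent_layer = nn.Linear(n_hidden, n_hidden, bias=False)

    def forward(self, x, hidden_states):
        next_hidden_states = torch.sigmoid(self.input_layer(x) + self.recurrent_layer(hidden_states))
        return next_hidden_states

cell = StatelessCell()
h = cell(torch.zeros(4), torch.zeros(12))           # unbatched call
h_batch = cell(torch.zeros(8, 4), torch.zeros(8, 12))  # batched call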

Edit

Here is a minimal version of my code

import numpy as np

import torch
import torch.nn.utils.prune as prune
import torch.nn as nn


class BrainRNN(nn.Module):
    def __init__(self, activation=torch.sigmoid, batch_size=8):
        super(BrainRNN, self).__init__()
        self.n_neurons = 3*4
        self.activation = activation
        self.batch_size = batch_size
        self.reset_hidden_states()

        # Create the input layer
        self.input_layer = nn.Linear(4, 4)

        # Create forward hidden layers
        self.hidden_layers = nn.ModuleList([])
        new_layer = nn.Linear(4,4)
        mask = np.ones((4,4))-np.eye(4)
        prune.custom_from_mask(new_layer, name='weight', mask=torch.tensor(mask.T)) # delete fictive connections
        self.hidden_layers.append(new_layer)

        # Create the backward weights
        self.recurrent_layers = nn.ModuleList([]) # recurrent_layers[i](hidden_states) = layer j>i to i

        new_layer = nn.Linear(self.n_neurons, 4, bias=False) # no bias for backward connection
        mask = np.zeros((12,4))
        mask[1,0] = 1
        prune.custom_from_mask(new_layer, name='weight', mask=torch.tensor(mask.T)) # delete fictive connections
        self.recurrent_layers.append(new_layer)

        # Create the output layer
        self.output_layer = nn.Linear(4,4)

    def forward(self, x):
        next_hidden_states = torch.empty(x.shape[0], self.n_neurons) if x.dim() > 1 else torch.empty(self.n_neurons)
        skips = [] # list of current states for skip connections

        # Input layer
        x = self.activation(self.input_layer(x) + self.recurrent_layers[0](self.hidden_states))
        next_hidden_states[...,[0,1,2,3]] = x

        # Hidden layers
        x = self.hidden_layers[0](x)
        x = self.activation(x)
        next_hidden_states[...,[4,5,6,7]] = x

        # Output layer
        x = self.output_layer(x) # no activation nor recurrent/skip connection for the last one
        
        self.hidden_states = next_hidden_states

        return x

    def reset_hidden_states(self, hidden_states=None):
        if self.batch_size > 0:
            self.hidden_states = nn.init.normal_(torch.empty(self.n_neurons), std=1).repeat(self.batch_size,1) # same hidden states for all batches
        else:
            self.hidden_states = nn.init.normal_(torch.empty(self.n_neurons), std=1)

net = BrainRNN()
net(torch.zeros(8,4)) # works well
net(torch.zeros(4)) # shape issue at next_hidden_states[...,[0,1,2,3]] = x

Here there are 3 layers of 4 nodes each, with a recurrent connection from the hidden layer back to the input layer, and some pruned connections.

The aim is to be able, given net = BrainRNN(...), to evaluate net(torch.zeros((B,4))) as well as net(torch.zeros(4)). Ideally, I would like to reproduce the behavior of standard nn.Modules, but I don't really know how to do so while saving the states...

Upvotes: 0

Views: 107

Answers (1)

MuhammedYunus

Reputation: 5095

Below is a simple RNN that accepts data as (sequence_length, n_features) or (batch_size, sequence_length, n_features). It steps through the entire sequence and returns the outputs and hidden states for each step (it also stores them as attributes which you can access). Is this the sort of functionality you were after? There's no pruning, but you could add that in as in your original code.

import torch
from torch import nn

class SimpleRNN(nn.Module):
    def __init__(self, input_size, hidden_size=4, output_size=2, activation='tanh', batch_first=True):
        super().__init__()
        
        #Only support batch_first=True (as per OP's test data)
        assert batch_first, 'This model assumes batch_first=True for simplicity'

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.activation_fn = getattr(torch.nn.functional, activation)
        
        self.Wxh = nn.Linear(self.input_size, self.hidden_size)
        self.Whh = nn.Linear(self.hidden_size, self.hidden_size)
        self.Why = nn.Linear(self.hidden_size, output_size)
        
    def forward(self, x):
        x = x.clone()
        
        x_ndim_orig = x.ndim
        
        #If it's 2D, assume that means (sequence_length, n_features,)
        # and prepend batch
        if x.ndim == 2:
            print('X.ndim is 2 | Assuming X.shape is (sequence_length, n_features)')
            x = x.unsqueeze(dim=0)
        elif x.ndim == 3:
            print('X.ndim is 3 | Assuming X.shape is (batch_size, sequence_length, n_features)')
        
        #Record the hidden state and y at each step for input x
        hidden_states = []
        outputs = []
        
        batch_size, sequence_len, n_features = x.shape
        assert self.input_size == n_features, f'Expected input features size of {self.input_size}'
        
        #Initialise hidden_state to 0, and step through the sequence recurrently
        hidden_state = torch.zeros(batch_size, self.hidden_size)
        for frame_idx in range(sequence_len):
            frame = x[:, frame_idx, :] #(batch, n_features) for this timestep
            
            hidden_state = self.activation_fn(
                self.Wxh(frame) + self.Whh(hidden_state)
            )
            output = self.activation_fn(self.Why(hidden_state))
            
            #Record the hidden state and y for this frame
            hidden_states.append(hidden_state)
            outputs.append(output)
        
        #Stack into (batch_size, sequence_length, output_size/hidden_size)
        # Available as attributes
        self.outputs = torch.stack(outputs, dim=1)
        self.hidden_states = torch.stack(hidden_states, dim=1)
        
        #Optionally drop the batch dim that we added
        if x_ndim_orig == 2:
            self.outputs, self.hidden_states = self.outputs[0], self.hidden_states[0]
        
        return self.outputs, self.hidden_states

Test the shapes:

#Input:  (sequence_length=12, n_features=4)
#Output: (sequence_length=12, hidden_size)
x = torch.rand(12, 4)
outputs, hidden_states = SimpleRNN(input_size=4)(x)
print(hidden_states.shape)

#Input: (batch_size=32, sequence_length=12, n_features=4)
#Output: (batch_size=32, sequence_length=12, hidden_size)
x = torch.rand(32, 12, 4)
outputs, hidden_states = SimpleRNN(input_size=4)(x)
print(hidden_states.shape)

This prints:

X.ndim is 2 | Assuming X.shape is (sequence_length, n_features)
torch.Size([12, 4])

X.ndim is 3 | Assuming X.shape is (batch_size, sequence_length, n_features)
torch.Size([32, 12, 4])
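
Since the forward pass also stores the stacked outputs and hidden states as attributes, you can read them back after a call, for example:

model = SimpleRNN(input_size=4)
outputs, hidden_states = model(torch.rand(32, 12, 4))

# The same tensors are kept on the module after the call
print(model.outputs.shape)        # torch.Size([32, 12, 2]) -- output_size defaults to 2
print(model.hidden_states.shape)  # torch.Size([32, 12, 4]) -- hidden_size defaults to 4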

The RNN is untested, and is meant to illustrate how you can do the recurrence inside the class & store the hidden states.
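
If you do want the pruned connectivity from your original code, one option (a sketch, untested during training) is to apply prune.custom_from_mask to the hidden-to-hidden weights after building the model, with a mask shaped like model.Whh.weight. The mask below is just an example, analogous to the ones in the question.

import torch
import torch.nn.utils.prune as prune

model = SimpleRNN(input_size=4, hidden_size=4)

# Example mask: drop the diagonal hidden-to-hidden connections
# (shape must match model.Whh.weight, here (4, 4))
mask = torch.ones(4, 4) - torch.eye(4)
prune.custom_from_mask(model.Whh, name='weight', mask=mask)

outputs, hidden_states = model(torch.rand(12, 4))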

Upvotes: 0
