petrux

Reputation: 1793

Scanning over different dimensions of tensors in theano

I'm taking my first steps with Theano and I cannot figure out how to solve this problem, which may actually be very easy.

I have a 3 * 4 * 2 tensor, like the following:

[1 1] | [2 2] | [3 3]
[1 1] | [2 2] | [3 3]
[0 0] | [2 2] | [3 3]
[9 9] | [0 0] | [3 3]

So I have N=3 sequences, each of length L=4, whose elements are vectors of dimension d=2. Actually, the sequences can have different lengths, but I could pad them with [0 0] vectors, as shown above.

What I want to do is scan along the first axis of the tensor and, for each sequence, sum all the vectors up to the first [0 0] vector -- that's why I added the [9 9] at the end of the first tensor slice, to check the sum's exit condition [1]. I should end up with [[2 2], [6 6], [12 12]]. I have tried many ways to solve this problem, which seems to me just a nested looping problem... but I always got some weird errors [2].

Thanks,
Giulio

--
[1]: the actual problem is the training of a recurrent neural network for NLP purposes, with N the dimension of the batch, L the max length of a sentence in the batch and d the dimension of the representation of each word. I omitted the problem so that I could focus on the simplest coding aspect.
[2] I omit the history of my failures, maybe I could add them later.

Upvotes: 3

Views: 664

Answers (1)

Daniel Renshaw

Reputation: 34187

If your sequences are always zero padded then you can just sum along the axis of interest since the padding regions will not change the sum. However, if the padding regions may contain non-zero values there are two approaches.

  1. Use scan. This is slow and should be avoided if possible. In fact it can be avoided because,
  2. Create a binary mask and multiply out the padding region.

Here's some code that illustrates these three approaches. For the two approaches that allow for non-zero padding regions (v2 and v3) the computation needs an additional input: a vector giving the lengths of the sequences within the batch.

import numpy
import theano
import theano.tensor as tt


def v1():
    # NOTE: [9, 9] element changed to [0, 0] 
    # since zero padding must be used for
    # this method
    x_data = [[[1, 1], [1, 1], [0, 0], [0, 0]],
              [[2, 2], [2, 2], [2, 2], [0, 0]],
              [[3, 3], [3, 3], [3, 3], [3, 3]]]
    x = tt.tensor3()
    x.tag.test_value = x_data
    y = x.sum(axis=1)
    f = theano.function([x], outputs=y)
    print(f(x_data))


def v2_step(i_t, s_tm1, x, l):
    in_sequence = tt.lt(i_t, l).dimshuffle(0, 'x')
    s_t = s_tm1 + tt.switch(in_sequence, x[i_t], 0)
    return s_t


def v2():
    x_data = [[[1, 1], [1, 1], [0, 0], [9, 9]],
              [[2, 2], [2, 2], [2, 2], [0, 0]],
              [[3, 3], [3, 3], [3, 3], [3, 3]]]
    l_data = [2, 3, 4]
    x = tt.tensor3()
    x.tag.test_value = x_data
    l = tt.lvector()
    l.tag.test_value = l_data
    # Must dimshuffle first because scan can only iterate over first (0'th) axis.
    x_hat = x.dimshuffle(1, 0, 2)
    y, _ = theano.scan(v2_step, sequences=[tt.arange(x_hat.shape[0])],
                       outputs_info=[tt.zeros_like(x_hat[0])],
                       non_sequences=[x_hat, l], strict=True)
    f = theano.function([x, l], outputs=y[-1])
    print(f(x_data, l_data))


def v3():
    x_data = [[[1, 1], [1, 1], [0, 0], [9, 9]],
              [[2, 2], [2, 2], [2, 2], [0, 0]],
              [[3, 3], [3, 3], [3, 3], [3, 3]]]
    l_data = [2, 3, 4]
    x = tt.tensor3()
    x.tag.test_value = x_data
    l = tt.lvector()
    l.tag.test_value = l_data
    indexes = tt.arange(x.shape[1]).dimshuffle('x', 0)
    mask = tt.lt(indexes, l.dimshuffle(0, 'x')).dimshuffle(0, 1, 'x')
    y = (mask * x).sum(axis=1)
    f = theano.function([x, l], outputs=y)
    print(f(x_data, l_data))


def main():
    theano.config.compute_test_value = 'raise'
    v1()
    v2()
    v3()


main()
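
To make the shapes in the mask approach (v3) easier to follow, here is a rough NumPy sketch of the same broadcasting steps. It is not Theano code and only mirrors the idea; the names `lengths`, `positions` and `masked_sum` are mine, chosen for illustration.

```python
import numpy as np

# Same data as v3: the [9, 9] entry sits in the padding region of sequence 0.
x = np.array([[[1, 1], [1, 1], [0, 0], [9, 9]],
              [[2, 2], [2, 2], [2, 2], [0, 0]],
              [[3, 3], [3, 3], [3, 3], [3, 3]]], dtype=float)
lengths = np.array([2, 3, 4])

# positions has shape (1, L); lengths[:, None] has shape (N, 1).
positions = np.arange(x.shape[1])[None, :]
# Comparing them broadcasts to (N, L): True inside each sequence, False in padding.
mask = positions < lengths[:, None]

# A trailing axis lets the mask broadcast over the d feature dimensions.
masked_sum = (mask[:, :, None] * x).sum(axis=1)  # shape (N, d)
print(masked_sum)
```

The `[None, :]` and `[:, None]` indexing plays the role of `dimshuffle('x', 0)` and `dimshuffle(0, 'x')` in the Theano version.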

In general, if your step function is dependent on the output of a previous step then you need to use scan.

If every step/iteration could, in principle, be executed concurrently (i.e. they don't rely on each other at all) then there is often a much more efficient way to do this without using scan.
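
As an illustration of that last point, even the "stop at the first [0 0] vector" rule from the question can probably be written without a loop. The following is a NumPy sketch under my own assumptions, not Theano code: it treats any all-zero vector as padding and derives the lengths from the data instead of taking them as an input. The names `is_pad`, `before_first_pad` and `derived_lengths` are made up for this example.

```python
import numpy as np

x = np.array([[[1, 1], [1, 1], [0, 0], [9, 9]],
              [[2, 2], [2, 2], [2, 2], [0, 0]],
              [[3, 3], [3, 3], [3, 3], [3, 3]]], dtype=float)

# Assumption: a time step counts as padding when every component is zero.
is_pad = (x == 0).all(axis=2)                      # shape (N, L)

# The running count of padding steps stays at 0 until the first padding
# vector, so this marks exactly the steps before the first [0, 0].
before_first_pad = np.cumsum(is_pad, axis=1) == 0  # shape (N, L)

derived_lengths = before_first_pad.sum(axis=1)     # one length per sequence
total = (before_first_pad[:, :, None] * x).sum(axis=1)
print(derived_lengths)
print(total)
```

The cumulative sum replaces the explicit exit condition, so everything after the first zero vector (including the [9 9]) is dropped without iterating step by step.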

Upvotes: 4
