stu176

In Torch/Lua, can I split/concat tensors as they flow through a network?

I'm a novice with Lua/Torch. I have an existing model that includes a max pooling layer. I want to take the input into that layer and split it into chunks, feeding each chunk into a new max pooling layer.

I have written a stand-alone Lua script that can split a tensor into two chunks and forward the two chunks into a network with two max-pooling layers.
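For concreteness, a stand-alone script of this kind might look roughly like the hypothetical sketch below (not the asker's actual code). The 6 x 8 x 8 input, the 4 + 2 split along the channel dimension, and the 2x2 pooling parameters are assumptions chosen for illustration; the split uses Tensor:narrow, which returns views into the input without copying.

```lua
require 'nn'

-- Hypothetical stand-alone version: split a C x H x W tensor along the
-- channel dimension and pool each chunk with its own max-pooling layer.
local input = torch.randn(6, 8, 8)   -- assumed shape: 6 channels of 8x8 maps
local firstCount = 4                 -- assumed size of the first chunk

-- narrow(dim, index, size) returns a view of the selected channels.
local first  = input:narrow(1, 1, firstCount)
local second = input:narrow(1, firstCount + 1, input:size(1) - firstCount)

-- One 2x2, stride-2 max-pooling layer per chunk.
local pool1 = nn.SpatialMaxPooling(2, 2, 2, 2)
local pool2 = nn.SpatialMaxPooling(2, 2, 2, 2)

local out1 = pool1:forward(first)    -- 4 x 4 x 4
local out2 = pool2:forward(second)   -- 2 x 4 x 4
print(out1:size(), out2:size())
```

The two pooling layers live outside any model here, which is exactly why this does not carry over directly to an existing network.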

But when I try to integrate that back into the existing model, I can't figure out how to modify the data "mid-flow", as it were, to perform the split. I've read the docs and can't find any function or example architecture that splits a tensor into two partway through the network and forwards each part separately.

Any ideas? Thanks!

Answers (1)

Chen

You can define a layer yourself. If your layer's input is one-dimensional, it could look like this:

require 'nn'

-- Splits a 1-D input into two tensors: the first `firstCount` elements
-- and the remaining `size - firstCount` elements.
local CSplit, parent = torch.class('nn.CSplit', 'nn.Module')

function CSplit:__init(firstCount)
    parent.__init(self)
    self.firstCount = firstCount
end

function CSplit:updateOutput(input)
    local inputSize = input:size(1)
    local firstCount = self.firstCount
    local secondCount = inputSize - firstCount
    -- narrow() returns views into input; clone() copies each view into
    -- a new tensor of the same type as the input.
    local first = input:narrow(1, 1, firstCount):clone()
    local second = input:narrow(1, firstCount + 1, secondCount):clone()
    self.output = {first, second}
    return self.output
end

function CSplit:updateGradInput(input, gradOutput)
    local firstCount = self.firstCount
    local secondCount = input:size(1) - firstCount
    -- Allocate a fresh buffer: torch.Tensor(input) would share input's
    -- storage, so writing the gradient into it would overwrite the input.
    self.gradInput = input.new():resizeAs(input)
    -- Copy each chunk's gradient back into its slice of gradInput.
    self.gradInput:narrow(1, 1, firstCount):copy(gradOutput[1])
    self.gradInput:narrow(1, firstCount + 1, secondCount):copy(gradOutput[2])
    return self.gradInput
end

To use it, pass the size of the first chunk to the constructor:

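testNet = nn.CSplit(4)            -- first chunk gets 4 elements, second gets the rest
input = torch.randn(10)
output = testNet:forward(input)   -- a table: {4-element tensor, 6-element tensor}
print(input)
print(output[1])
print(output[2])
-- gradOutput must be a table with one gradient tensor per chunk
testNet:backward(input, {torch.randn(4), torch.randn(6)})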

You can see runnable iTorch notebook code here.
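If you would rather not write a custom module, a similar split can probably be assembled from nn's built-in modules. The sketch below is one way to do it under illustrative assumptions (a 6 x 8 x 8 channels x height x width input, split 4 + 2 along the channel dimension, 2x2 max pooling). nn.ConcatTable runs two nn.Narrow modules on the same input, nn.ParallelTable feeds each chunk to its own nn.SpatialMaxPooling, and nn.JoinTable(1) concatenates the pooled chunks again, so the split happens inside the model and backward() flows through it.

```lua
require 'nn'

-- Assumed shapes for illustration: 6 input channels, split 4 + 2.
local nChannels, firstCount = 6, 4

local model = nn.Sequential()

-- ConcatTable feeds the same input to every branch and outputs a table.
local split = nn.ConcatTable()
split:add(nn.Narrow(1, 1, firstCount))                          -- channels 1..4
split:add(nn.Narrow(1, firstCount + 1, nChannels - firstCount)) -- channels 5..6
model:add(split)

-- ParallelTable applies its i-th module to the i-th element of the table.
local pools = nn.ParallelTable()
pools:add(nn.SpatialMaxPooling(2, 2, 2, 2))
pools:add(nn.SpatialMaxPooling(2, 2, 2, 2))
model:add(pools)

-- JoinTable(1) concatenates the pooled chunks back along the channel dim.
model:add(nn.JoinTable(1))

local input = torch.randn(nChannels, 8, 8)
local output = model:forward(input)                           -- 6 x 4 x 4
local gradInput = model:backward(input, output:clone():fill(1))
print(output:size(), gradInput:size())
```

Because every piece is a standard module, this block should be insertable into an existing nn.Sequential in place of a single max-pooling layer; drop the nn.JoinTable if the following layers expect the chunks separately.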
