Tarquinnn

Setting __index for Torch Classes

Is it possible to set an __index method for Torch classes? I have tried to implement a simple dataset class as outlined in the Deep Learning with Torch tutorial (ipynb here):

trainset = {
    inputs = {0, 1, 1, 0},
    targets = {1, 1, 1, 0}
}

index = function(t, i)
    return {t.inputs[i], t.targets[i]}
end

setmetatable(trainset, {
    __index = index
})

This allows you to write trainset[1], which returns {0, 1}.

However, implementing this as a Torch class does not work:

local torch = require("torch")

do 
    Dataset = torch.class("Dataset")

    function Dataset:__init(i, t)
        self.inputs = i
        self.targets = t
    end

    function Dataset.__index(t, v)
        print("inside index")
        return {
            rawget(t, "inputs")[v],
            rawget(t, "targets")[v]
        }
    end
end

Dataset({0, 1, 1, 0}, {1, 1, 1, 0}) -- fails

It seems that, upon object creation, __index() is called and fails because inputs and targets have not been created yet. If rawget is not used, t.inputs itself triggers __index again, which causes a stack overflow.
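The behaviour can be reproduced without Torch in a minimal plain-Lua sketch. It assumes, as a simplification, that a Torch-style constructor creates an empty table, attaches the class metatable and then reads self.__init; the helper new_instance is hypothetical and is not Torch's actual code. Because __init lives on the class rather than on the instance, that read goes through the custom __index before inputs and targets exist.

```lua
-- Hypothetical, simplified constructor (NOT Torch's real implementation):
-- attach the class metatable, then look up __init on the new object.
local function new_instance(class, ...)
    local self = setmetatable({}, class)
    local init = self.__init -- key is missing on the instance, so __index runs
    if init then init(self, ...) end
    return self
end

local Dataset = {}

function Dataset.__init(self, i, t)
    self.inputs = i
    self.targets = t
end

local requested = {} -- records every key that reached __index

-- Custom __index in the style of the question: it assumes inputs/targets exist.
function Dataset.__index(t, v)
    requested[#requested + 1] = v
    return { rawget(t, "inputs")[v], rawget(t, "targets")[v] }
end

local ok, err = pcall(new_instance, Dataset, {0, 1, 1, 0}, {1, 1, 1, 0})
print(ok)           --> false (indexing the nil returned by rawget fails)
print(requested[1]) --> __init
```

Replacing the two rawget calls with t.inputs and t.targets would make __index call itself again for the missing key, which is the stack overflow mentioned above.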

My understanding of Lua is limited, but I'm surprised to see __index() being called during object creation: I think there's stuff going on behind the scenes I don't fully understand.

Answers (1)

Tarquinnn

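Every Torch class already has an __index metamethod supplied by Torch. It first looks for an __index__ function in the class metatable; __index__ is the hook intended for overloading indexing. Defining Dataset.__index directly replaces Torch's handler, so every lookup of a key missing from the object, such as the lookup of __init when the object is constructed, goes through your function. That is why it runs during object creation.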

From the docs:

If one wants to provide __index__ or __newindex__ in the metaclass, these operators must follow a particular scheme:

__index__ must either return a value and true or return false only. In the first case, it means __index__ was able to handle the given argument (for e.g., the type was correct). The second case means it was not able to do anything, so __index in the root metatable can then try to see if the metaclass contains the required value.
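The scheme quoted above can be modeled in plain Lua to show the control flow. This is an illustrative sketch under assumptions, not Torch's source: the hypothetical dispatch_index stands in for the __index that Torch installs on each class, calling __index__ first and falling back to a lookup in the metatable when __index__ returns false. The size method is also hypothetical, added only to show that ordinary methods still resolve.

```lua
-- Illustrative model of the documented __index__ protocol (not Torch's code).
local Dataset = {}

-- Stand-in for the __index that Torch installs on every class metatable.
local function dispatch_index(t, k)
    local mt = getmetatable(t)
    local hook = rawget(mt, "__index__")
    if hook then
        local value, handled = hook(t, k)
        if handled then
            return value -- __index__ returned (value, true)
        end
    end
    return mt[k] -- __index__ returned false: look the key up in the metatable
end
Dataset.__index = dispatch_index

function Dataset.__init(self, i, t)
    self.inputs = i
    self.targets = t
end

-- Hypothetical method, only here to show that method lookup still works.
function Dataset.size(self)
    return #self.inputs
end

-- Handle numeric keys only; everything else falls through to the metatable.
function Dataset.__index__(t, v)
    if type(v) == "number" then
        return { t.inputs[v], t.targets[v] }, true
    end
    return false
end

local dset = setmetatable({}, Dataset)
dset:__init({0, 1, 1, 0}, {1, 1, 1, 0}) -- "__init" is a string, so it falls through

local sample = dset[1]
print(sample[1], sample[2]) --> 0	1
print(dset:size())          --> 4
```

Inside __index__, t.inputs and t.targets are raw fields of the object by that point, so reading them does not re-enter __index and there is no recursion.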

For this example, that means the method to define is __index__ (not __index!). It must check whether type(v) == "number" and, if not, return false so that Torch's __index can look the key up in the class metatable (this is how methods such as __init are still found).

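local torch = require("torch")

do
    Dataset = torch.class("Dataset")

    function Dataset:__init(i, t)
        self.inputs = i
        self.targets = t
    end

    -- Torch's built-in __index calls __index__ first: return (value, true)
    -- for keys handled here, or false to fall back to the class metatable
    -- (which is how methods such as __init are still found)
    function Dataset.__index__(t, v)
        if type(v) == "number" then
            local tbl = {
                t.inputs[v],
                t.targets[v]
            }
            return tbl, true
        else
            return false
        end
    end
end

local dset = Dataset({0, 1, 1, 0}, {1, 1, 1, 0})
local sample = dset[1] -- {0, 1}
print(sample[1], sample[2]) --> 0	1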
