Reputation: 511
Is it possible to set an __index method for torch classes? I have tried to implement a simple dataset
class as outlined in the Deep Learning with Torch tutorial (ipynb here):
trainset = {
    inputs = {0, 1, 1, 0},
    targets = {1, 1, 1, 0}
}
index = function(t, i)
    return {t.inputs[i], t.targets[i]}
end
setmetatable(trainset, {
    __index = index
})
This allows you to write trainset[1], which returns {0, 1}.
However, implementing this as a torch class does not work:
local torch = require("torch")
do
    Dataset = torch.class("Dataset")
    function Dataset:__init(i, t)
        self.inputs = i
        self.targets = t
    end
    function Dataset.__index(t, v)
        print("inside index")
        return {
            rawget(t, "inputs")[v],
            rawget(t, "targets")[v]
        }
    end
end
Dataset({0, 1, 1, 0}, {1, 1, 1, 0}) -- fails
It seems that, upon object creation, __index() is called and fails because inputs and targets have not yet been created. If rawget is not used, it causes a stack overflow instead.
My understanding of Lua is limited, but I'm surprised to see __index() being called during object creation; I think there is something going on behind the scenes that I don't fully understand.
Upvotes: 0
Views: 153
Reputation: 511
Torch classes all implement their own __index, which looks for __index__ in the metatable; __index__ is the hook meant for overloading.
From the docs:
If one wants to provide index or newindex in the metaclass, these operators must follow a particular scheme:
index must either return a value and true or return false only. In the first case, it means index was able to handle the given argument (for e.g., the type was correct). The second case means it was not able to do anything, so __index in the root metatable can then try to see if the metaclass contains the required value.
This means that, for the example, the __index__ (not __index!) method must check whether type(v) == "number" and, if it is not, return false so that __index can look for the value in the object metatable.
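To make the "return a value and true, or return false only" scheme concrete, here is a minimal plain-Lua sketch that imitates it without torch. It is only an illustration of the protocol described in the docs, not torch's actual implementation; make_class and Dataset.new are hypothetical names introduced for this example.

```lua
-- Plain-Lua imitation of the dispatch scheme (no torch required).
-- `make_class` and `Dataset.new` are hypothetical helpers, not torch API.
local function make_class()
  local class = {}
  -- Root __index: give the class's __index__ hook the first chance,
  -- then fall back to the class table (methods and other fields).
  class.__index = function(obj, key)
    local hook = rawget(class, "__index__")
    if hook then
      local value, handled = hook(obj, key)
      if handled then
        return value
      end
    end
    return rawget(class, key)
  end
  return class
end

local Dataset = make_class()

function Dataset.new(inputs, targets)
  return setmetatable({inputs = inputs, targets = targets}, Dataset)
end

function Dataset:size()
  return #self.inputs
end

-- The hook only claims numeric keys; everything else returns false.
function Dataset.__index__(t, key)
  if type(key) == "number" then
    return {t.inputs[key], t.targets[key]}, true
  end
  return false
end

local dset = Dataset.new({0, 1, 1, 0}, {1, 1, 1, 0})
local pair = dset[1]     -- numeric key, handled by __index__
print(pair[1], pair[2])
print(dset:size())       -- "size" falls through to the class table
```

Because the hook returns false for the string key "size", the method lookup still succeeds, which is the behaviour the quoted scheme is designed to preserve.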
local torch = require("torch")
do
    Dataset = torch.class("Dataset")
    function Dataset:__init(i, t)
        self.inputs = i
        self.targets = t
    end
    function Dataset.__index__(t, v)
        if type(v) == "number" then
            -- Numeric key: handle it here and signal success with true.
            local tbl = {
                t.inputs[v],
                t.targets[v]
            }
            return tbl, true
        else
            -- Any other key: let the root __index look it up.
            return false
        end
    end
end
local dset = Dataset({0, 1, 1, 0}, {1, 1, 1, 0})
local pair = dset[1] --> {0, 1}
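As for why the version in the question fails during construction: a likely explanation is that method lookup (such as finding the constructor) also goes through __index, because the method is not stored on the object itself. The plain-Lua sketch below reproduces both symptoms under that assumption. Broken and Recursive are hypothetical stand-ins for a torch class, and init stands in for __init.

```lua
-- Plain-Lua sketch (no torch); `Broken` and `Recursive` are hypothetical.
local Broken = {}
function Broken:init(inputs)
  self.inputs = inputs
end
-- Every missing key lands here, including the method name "init".
Broken.__index = function(t, key)
  return rawget(t, "inputs")[key] -- "inputs" is still nil at this point
end

local obj = setmetatable({}, Broken)
local ok = pcall(function() return obj:init({0, 1, 1, 0}) end)
print(ok) -- false: indexing nil while looking up "init"

local Recursive = {}
-- Without rawget, reading t.inputs triggers __index again, and again.
Recursive.__index = function(t, key)
  return t.inputs[key]
end

local obj2 = setmetatable({}, Recursive)
local ok2 = pcall(function() return obj2[1] end)
print(ok2) -- false: the lookup recurses until the stack overflows
```

Using __index__ avoids both problems, since the root __index stays in place and still resolves methods.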
Upvotes: 1