Reputation: 3
In the following code, what does torch.cat really do? I know it concatenates the batch contained in the sample, but why do we have to do that, and what does "concatenate" really mean?
# memory is just a list of events
def sample(self, batch_size):
    samples = zip(*random.sample(self.memory, batch_size))
    return map(lambda x: Variable(torch.cat(x,0)))
Upvotes: 0
Views: 392
Reputation: 24814
As the name suggests, torch.cat concatenates a sequence of tensors along a specified dimension.
The example from the documentation tells you everything you need to know:
import torch

x = torch.randn(2, 3)  # shape (2, 3)
catted = torch.cat((x, x, x), dim=0)  # shape (6, 3), i.e. three copies of x stacked on top of each other
Remember that the concatenated tensors must have the same size in every dimension except the one along which you are concatenating.
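To make "concatenate" concrete, here is a toy sketch in plain Python (no PyTorch) of what concatenation means for 2-D data stored as nested lists. cat2d and shape2d are hypothetical helpers written for illustration; this is not how PyTorch implements torch.cat, it only mirrors the shape rule above.

```python
def shape2d(m):
    """Return (rows, cols) of a rectangular nested list."""
    return (len(m), len(m[0]) if m else 0)


def cat2d(mats, dim=0):
    """Toy concatenation of 2-D nested lists along dim 0 (rows) or dim 1 (columns)."""
    if dim not in (0, 1):
        raise ValueError("dim must be 0 or 1 for 2-D data")
    shapes = [shape2d(m) for m in mats]
    other = 1 - dim  # the dimension that must match across all inputs
    if len({s[other] for s in shapes}) != 1:
        raise ValueError(f"sizes must match except in dim {dim}: {shapes}")
    if dim == 0:
        # put the rows of every input one after another
        return [list(row) for m in mats for row in m]
    # dim == 1: glue row i of every input side by side
    return [sum((list(m[i]) for m in mats), []) for i in range(shapes[0][0])]


x = [[1, 2, 3], [4, 5, 6]]               # shape (2, 3)
print(shape2d(cat2d([x, x, x], dim=0)))  # (6, 3)
print(shape2d(cat2d([x, x, x], dim=1)))  # (2, 9)

y = [[7, 8, 9]]                          # shape (1, 3)
print(cat2d([x, y], dim=0))              # [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
# cat2d([x, y], dim=1) raises ValueError: 2 rows vs 1 row
```

Only the concatenation dimension grows; every other dimension has to agree, which is exactly the restriction torch.cat enforces.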
In your code, though, it doesn't do anything and isn't even valid: map is missing its second argument (the iterable to apply the function to), so calling it raises a TypeError; see here.
Assuming you meant this mapping instead:
map(lambda x: Variable(torch.cat(x,0)), samples)
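For each field of the stored events (zip(*...) transposes the sampled events, so every x is the tuple of that field's tensors across the batch), it would create a new tensor whose dimension 0 is the sum of the inputs' dimension-0 sizes, i.e. [batch_size, x_dim_1, x_dim_2, ...] if every stored tensor has a leading dimension of size 1, provided all tensors have the same size in every dimension except 0. That is why the concatenation is needed: it turns batch_size separately stored events into one batched tensor per field, which can then be processed in a single forward pass.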
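A small sketch of what the zip(*...) plus torch.cat combination does, assuming each stored event is a (state, reward) tuple of tensors with a leading dimension of size 1. The memory contents, field layout, and sizes are made up for illustration:

```python
import random

import torch

# Hypothetical replay memory: 10 events, each a (state, reward) tuple.
# Every tensor keeps a leading batch dimension of size 1.
memory = [(torch.randn(1, 4), torch.tensor([[float(i)]])) for i in range(10)]

batch_size = 3
events = random.sample(memory, batch_size)  # list of 3 (state, reward) tuples

# zip(*events) transposes the batch: one tuple of all states, one of all rewards
states, rewards = zip(*events)

state_batch = torch.cat(states, dim=0)    # shape (3, 4)
reward_batch = torch.cat(rewards, dim=0)  # shape (3, 1)
print(state_batch.shape, reward_batch.shape)
# torch.Size([3, 4]) torch.Size([3, 1])

# If the stored tensors had no leading dimension of 1 (e.g. shape (4,)),
# torch.stack would create the batch dimension instead of torch.cat.
flat_states = [torch.randn(4) for _ in range(batch_size)]
print(torch.stack(flat_states, dim=0).shape)  # torch.Size([3, 4])
```

torch.cat joins along an existing dimension, while torch.stack inserts a new one; which one you need depends on whether the stored tensors already carry that leading dimension of 1.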
Still, it is a pretty convoluted example and definitely shouldn't be written like that (torch.autograd.Variable is deprecated, see here); this should be enough:
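import random
import torch

# random.sample always returns a list; this assumes every entry in
# self.memory is a single tensor with the same size in all dims except 0
def sample(self, batch_size):
    return torch.cat(random.sample(self.memory, batch_size), dim=0)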
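If the memory instead stores tuples of tensors, as the zip(*...) in the question suggests, the Variable-free equivalent concatenates each field separately. This is a minimal sketch; the ReplayMemory class, its push method, and the (state, action, reward) layout are hypothetical:

```python
import random

import torch


class ReplayMemory:
    """Hypothetical minimal replay buffer storing (state, action, reward) tuples."""

    def __init__(self):
        self.memory = []  # list of events, each a tuple of tensors

    def push(self, *event):
        self.memory.append(event)

    def sample(self, batch_size):
        events = random.sample(self.memory, batch_size)
        # one batched tensor per field, concatenated along dim 0; no Variable needed
        return tuple(torch.cat(field, dim=0) for field in zip(*events))


mem = ReplayMemory()
for i in range(8):
    # each tensor has a leading dimension of 1 so torch.cat builds the batch dim
    mem.push(torch.randn(1, 4), torch.tensor([[i]]), torch.tensor([[0.5 * i]]))

states, actions, rewards = mem.sample(5)
print(states.shape, actions.shape, rewards.shape)
# torch.Size([5, 4]) torch.Size([5, 1]) torch.Size([5, 1])
```

Returning a tuple rather than a lazy map object also means the caller can unpack and reuse the batches directly.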
Upvotes: 1