Assume that we have two equally sized tensors of shape `[batch_size, 1]`. For each index in the batch dimension, we want to choose randomly between the two tensors. My solution was to create an `indices` tensor of size `batch_size` containing random `0` or `1` indices, and use it to `index_select` from the concatenation of the two tensors. However, to do so I had to `view` the concatenated tensor, and the solution ended up being quite "ugly":
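```python
import torch

bs = 8
a = torch.zeros(bs, 1)
print("a size", a.size())
b = torch.ones(bs, 1)
c = torch.cat([a, b], dim=-1)  # [bs, 2]: column 0 comes from a, column 1 from b
print(c)
print("c size", c.size())
# create bs random 0s and 1s (which tensor to take for each row)
indices = torch.randint(0, 2, [bs])
print("idxs size", indices.size())
print("idxs", indices)
# use `indices` to slice the `cat`ted tensor; after flattening, row i's two
# entries sit at positions 2*i and 2*i + 1, so each index is offset by 2*i
# (without the offset, every lookup would land in row 0)
d = c.view(1, -1).index_select(-1, indices + 2 * torch.arange(bs)).view(-1, 1)
print("d size", d.size())
print(d)
```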
I am wondering whether there is a prettier and, more importantly, more efficient solution.
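The flattened layout that the `view` call relies on can be sketched as follows. This is a minimal illustration, not code from the original post; it uses hypothetical distinct values for `a` and `b` (instead of zeros and ones) so that a wrong lookup is actually visible, and the names `choice`, `flat_index` and `expected` are made up for the example.

```python
import torch

# Sketch (assumption: a and b hold distinct values so mistakes are visible).
bs = 4
a = torch.arange(bs, dtype=torch.float32).unsqueeze(1)        # [bs, 1]: 0, 1, 2, 3
b = torch.arange(bs, dtype=torch.float32).unsqueeze(1) + 100  # [bs, 1]: 100..103
c = torch.cat([a, b], dim=-1)  # [bs, 2]
flat = c.view(-1)              # row-major: a0, b0, a1, b1, ...

choice = torch.randint(0, 2, (bs,))         # 0 -> take a[i], 1 -> take b[i]
flat_index = choice + 2 * torch.arange(bs)  # row i occupies positions 2*i and 2*i + 1
picked = flat[flat_index].view(-1, 1)       # [bs, 1]

# Reference result: the same per-row choice made directly on a and b.
expected = torch.where(choice.unsqueeze(1).bool(), b, a)
print(flat)
print(picked.squeeze(1))
```

With plain 0/1 indices and no `2 * torch.arange(bs)` offset, `flat[choice]` would only ever return `a0` or `b0`; with zeros and ones as data this mistake produces output that looks correct.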
Posting two answers that I got over at the PyTorch forums:
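```python
import torch

bs = 8
a = torch.zeros(bs, 1)
b = torch.ones(bs, 1)
c = torch.cat([a, b], dim=-1)  # [bs, 2]
choices_flat = c.view(-1)      # [2 * bs]
# sampling with replacement:
# index = torch.randint(choices_flat.numel(), (bs,))
# or, sampling without replacement (replace = False):
index = torch.randperm(choices_flat.numel())[:bs]
# note: this draws bs values from all 2 * bs entries, not one entry per row
select = choices_flat[index]   # [bs]
print(select)
```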
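The difference between the commented-out `torch.randint` line and the `torch.randperm` line can be checked with a small sketch (illustrative only; `n` stands in for `choices_flat.numel()` and the seed is arbitrary): `randint` may return the same position more than once, while taking the first `bs` entries of a random permutation always yields `bs` distinct positions, which requires `bs <= n`.

```python
import torch

torch.manual_seed(0)  # arbitrary seed, only for reproducibility
n = 16                # number of candidate entries (2 * bs above)
bs = 8

with_repl = torch.randint(n, (bs,))    # positions may repeat
without_repl = torch.randperm(n)[:bs]  # bs distinct positions in [0, n)

print(with_repl)
print(without_repl)
```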
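```python
import torch

bs = 8
a = torch.zeros(bs, 1)
print("a size", a.size())
b = torch.ones(bs, 1)
# concatenating along dim 0 gives [2 * bs, 1]; pick bs random rows (with replacement)
idx = torch.randint(2 * bs, (bs,))
d = torch.cat([a, b])[idx]  # [bs, 1]
print(d)
```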
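Both answers above sample from all `2 * bs` values, so they do not guarantee that row `i` of the result comes from `a[i]` or `b[i]`. If that per-row pairing matters, a minimal sketch (not from the forum thread; two standard PyTorch approaches) is either a random boolean mask with `torch.where`, which avoids the concatenation entirely, or `Tensor.gather` on the `[bs, 2]` concatenation, which picks one column per row without any `view`.

```python
import torch

bs = 8
a = torch.zeros(bs, 1)
b = torch.ones(bs, 1)

# Option 1: random boolean mask per row; no concatenation or view needed.
mask = torch.rand(bs, 1) < 0.5     # True -> take b[i], False -> take a[i]
d_where = torch.where(mask, b, a)  # [bs, 1]

# Option 2: keep the [bs, 2] concatenation and gather one column per row.
c = torch.cat([a, b], dim=-1)           # [bs, 2]
indices = torch.randint(0, 2, (bs, 1))  # column index for each row
d_gather = c.gather(1, indices)         # [bs, 1]; d_gather[i, 0] == c[i, indices[i, 0]]

print(d_where.squeeze(1))
print(d_gather.squeeze(1))
```

`torch.where` is probably the more readable choice for exactly two candidates; `gather` generalizes to concatenating `k` tensors along the last dimension and drawing indices with `torch.randint(0, k, (bs, 1))`.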