elyase

Reputation: 40973

groupby aggregate mean in pytorch

I have a 2D tensor:

samples = torch.Tensor([
    [0.1, 0.1],    #-> group / class 1
    [0.2, 0.2],    #-> group / class 2
    [0.4, 0.4],    #-> group / class 2
    [0.0, 0.0]     #-> group / class 0
])

and a label for each sample corresponding to a class:

labels = torch.LongTensor([1, 2, 2, 0])

so len(samples) == len(labels). Now I want to calculate the mean for each class / label. Because there are 3 classes (0, 1 and 2), the final tensor should have shape [n_classes, samples.shape[1]], so the expected result is:

result == torch.Tensor([
    [0.1, 0.1],
    [0.3, 0.3], # -> mean of [0.2, 0.2] and [0.4, 0.4]
    [0.0, 0.0]
])

Question: How can this be done in pure pytorch (i.e. no numpy so that I can autograd) and ideally without for loops?

Upvotes: 10

Views: 6350

Answers (4)

Christian

Reputation: 305

For 3D Tensors:

For those who are interested: I expanded @yhenon's answer to the case where labels is a 2D tensor and samples is a 3D tensor. This is useful if you want to execute this operation in batches (as I do). But it comes with a caveat (see the end).

samples = torch.Tensor([[
    [0.1, 0.1],    #-> group / class 1
    [0.2, 0.2],    #-> group / class 2
    [0.4, 0.4],    #-> group / class 2
    [0.0, 0.0]     #-> group / class 0
], [
    [0.5, 0.5],    #-> group / class 0
    [0.2, 0.2],    #-> group / class 1
    [0.4, 0.4],    #-> group / class 2
    [0.1, 0.1]     #-> group / class 3
]])

labels = torch.LongTensor([[1, 2, 2, 0], [0, 1, 2, 3]])

# build a selection matrix per batch element: M[b, c, n] = 1 iff sample n has class c
M = torch.zeros(labels.shape[0], labels.max()+1, labels.shape[1])
M[torch.arange(len(labels))[:,None], labels, torch.arange(labels.size(1))] = 1
# turn the rows into averaging weights (rows of absent classes stay all-zero)
M = torch.nn.functional.normalize(M, p=1, dim=-1)
result = M@samples

Output:

>>> result
tensor([[[0.0000, 0.0000],
         [0.1000, 0.1000],
         [0.3000, 0.3000],
         [0.0000, 0.0000]],

        [[0.5000, 0.5000],
         [0.2000, 0.2000],
         [0.4000, 0.4000],
         [0.1000, 0.1000]]])

Be careful: result[0] now has a length of 4 (instead of 3 in @yhenon's answer), because labels[1] contains a 3; its last row contains only 0s. If you don't mind 0s in the rows of classes that are missing from a batch element, you can use this code as-is and deal with the 0s later.
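If you need to know which rows correspond to classes that actually occur, here is a minimal sketch (my addition, not part of the original answer; it relies on normalize leaving all-zero rows all-zero):

present = M.sum(dim=-1) > 0   # shape (batch, n_classes); True where the class occurs
# here: tensor([[ True,  True,  True, False],
#               [ True,  True,  True,  True]])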

Upvotes: 0

Minseok

Reputation: 41

As the previous solutions do not work for the case of sparse groups (i.e., when not all groups occur in the data), I made one :)

from typing import Tuple

def groupby_mean(value: torch.Tensor, labels: torch.LongTensor) -> Tuple[torch.Tensor, torch.LongTensor]:
    """Group-wise average for (sparse) grouped tensors

    Args:
        value (torch.Tensor): values to average (# samples, latent dimension)
        labels (torch.LongTensor): labels for embedding parameters (# samples,)

    Returns: 
        result (torch.Tensor): (# unique labels, latent dimension)
        new_labels (torch.LongTensor): (# unique labels,)

    Examples:
        >>> samples = torch.Tensor([
                             [0.15, 0.15, 0.15],    #-> group / class 1
                             [0.2, 0.2, 0.2],    #-> group / class 5
                             [0.4, 0.4, 0.4],    #-> group / class 5
                             [0.0, 0.0, 0.0]     #-> group / class 0
                      ])
        >>> labels = torch.LongTensor([1, 5, 5, 0])
        >>> result, new_labels = groupby_mean(samples, labels)

        >>> result
        tensor([[0.0000, 0.0000, 0.0000],
            [0.1500, 0.1500, 0.1500],
            [0.3000, 0.3000, 0.3000]])

        >>> new_labels
        tensor([0, 1, 5])
    """
    uniques = labels.unique().tolist()
    labels = labels.tolist()

    # remap the (possibly sparse) labels to contiguous indices 0..n_unique-1, and back
    key_val = {key: val for key, val in zip(uniques, range(len(uniques)))}
    val_key = {val: key for key, val in zip(uniques, range(len(uniques)))}

    labels = torch.LongTensor(list(map(key_val.get, labels)))

    # expand labels to the shape of value so scatter_add_ can index every column
    labels = labels.view(labels.size(0), 1).expand(-1, value.size(1))

    unique_labels, labels_count = labels.unique(dim=0, return_counts=True)
    # sum the rows of value per group, then divide by the group sizes
    result = torch.zeros_like(unique_labels, dtype=torch.float).scatter_add_(0, labels, value)
    result = result / labels_count.float().unsqueeze(1)
    # map the contiguous indices back to the original label values
    new_labels = torch.LongTensor(list(map(val_key.get, unique_labels[:, 0].tolist())))
    return result, new_labels
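For contrast, a minimal sketch (my illustration, not part of the other answers' code) of what the dense matrix approach does with the sparse labels from the docstring:

labels = torch.LongTensor([1, 5, 5, 0])
M = torch.zeros(labels.max() + 1, len(labels))   # 6 rows, although only 3 classes occur
M[labels, torch.arange(len(labels))] = 1
# rows 2, 3 and 4 of M stay all-zero, so the resulting means would
# contain all-zero rows for the absent classes 2, 3 and 4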

Upvotes: 2

yhenon

Reputation: 4291

All you need to do is form an m x n matrix (m = num classes, n = num samples) that selects the appropriate samples and weights them so the mean comes out right. Then you can perform a matrix multiplication between your newly formed matrix and the samples matrix.

Given your labels, your matrix should be (each row corresponds to a class, each column to a sample, and the entries are the averaging weights):

[[0.0000, 0.0000, 0.0000, 1.0000],
 [1.0000, 0.0000, 0.0000, 0.0000],
 [0.0000, 0.5000, 0.5000, 0.0000]]

Which you can form as follows:

M = torch.zeros(labels.max()+1, len(samples))
M[labels, torch.arange(len(samples))] = 1          # one-hot: row = class, column = sample
M = torch.nn.functional.normalize(M, p=1, dim=1)   # each row sums to 1 -> averaging weights
torch.mm(M, samples)

Output:

tensor([[0.0000, 0.0000],
        [0.1000, 0.1000],
        [0.3000, 0.3000]])

Note that the output means are correctly sorted in class order.

Why does M[labels, torch.arange(len(samples))] = 1 work?

This is performing a broadcast operation between labels and torch.arange(len(samples)). Essentially, we are generating a 2D index for every element in labels: the first index specifies which of the m classes it belongs to, and the second simply specifies its position (from 0 to N-1). Another way would be to explicitly generate all the 2D indices:

twoD_indices = []
for count, label in enumerate(labels):
    twoD_indices.append((label, count))
for i, j in twoD_indices:  # then set each selected entry to 1
    M[i, j] = 1

Upvotes: 9

elyase

Reputation: 40973

Reposting here an answer from @ptrblck_de in the PyTorch forums:

# expand labels to the shape of samples so scatter_add_ can index every column
labels = labels.view(labels.size(0), 1).expand(-1, samples.size(1))

unique_labels, labels_count = labels.unique(dim=0, return_counts=True)

# sum the rows of samples per label, then divide by the group sizes
res = torch.zeros_like(unique_labels, dtype=torch.float).scatter_add_(0, labels, samples)
res = res / labels_count.float().unsqueeze(1)
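With the question's samples and labels, a quick check gives the expected result:

>>> res
tensor([[0.0000, 0.0000],
        [0.1000, 0.1000],
        [0.3000, 0.3000]])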

Upvotes: 4
