beisner

Reputation: 93

How to vectorize indexing and computation when indexed tensors have different dimensions?

I'm trying to vectorize the following for-loop in PyTorch. I'd be happy with just vectorizing the inner for-loop, but doing the whole batch would also be awesome.

import torch

# B: the batch size
# N: the number of training examples
# dim: the dimension of each feature vector
# K: the number of discrete labels. Each vector has a single label
# delta: margin for hinge loss

batch_data = torch.tensor(...)  # Tensor of shape [B x N x dim]
batch_labels = torch.tensor(...)  # Tensor of shape [B x N x 1], each element is one of K labels (ints)

batch_losses = []  # Ultimately should be [B x 1]
batch_centroids = []  # Ultimately should be [B x K_i x dim]
for i in range(B):
    data = batch_data[i]  # [N x dim]
    labels = batch_labels[i].squeeze(-1)  # [N]
    centroids = []  # Keep track of the means for each class.
    classes = torch.unique(labels)  # Get the unique labels for the classes.

    # NOTE: The number of classes K for each item in the batch might actually
    # be different. This may complicate batch-level operations.

    total_loss = 0

    # For each class independently. This is the part I want to vectorize.
    for cl in classes:
        # Take the subset of training examples with that label.
        subset = data[labels == cl]

        # Find the centroid of that subset.
        centroid = subset.mean(dim=0)
        centroids.append(centroid)

        # Get the distance between each point in the subset and the centroid.
        dists = subset - centroid
        norm = torch.linalg.norm(dists, dim=1)

        # The loss is the mean of the hinge loss across the subset.
        margin = norm - delta
        hinge = torch.clamp(margin, min=0.0) ** 2

        total_loss += hinge.mean()

    # Keep track of everything. If it's too hard to keep track of centroids, that's also OK.
    loss = total_loss.mean()
    batch_losses.append(loss)
    batch_centroids.append(centroids)

I've been scratching my head over how to deal with the irregularly sized tensors. The number of classes K_i differs between batch elements, and so does the size of each class subset.

Upvotes: 7

Views: 960

Answers (2)

VF1

Reputation: 1652

It turns out it is actually possible to vectorize across ragged arrays. I'll use NumPy, but the code should translate directly to PyTorch. The key technique is to:

  1. Sort by ragged array membership
  2. Perform an accumulation
  3. Find boundary indices, compute adjacent differences
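The three steps can be seen on a toy 1-D example. This is a sketch with made-up values and labels; it computes per-label sums of a ragged grouping without a Python loop. Note that it finds the run boundaries with np.searchsorted, which is a different (but equivalent) way of getting the boundary array than the np.diff / np.repeat trick used in the answer's function.

```python
import numpy as np

# Made-up example data: six values belonging to three groups of different sizes.
values = np.array([5.0, 1.0, 2.0, 7.0, 3.0, 4.0])
labels = np.array([2, 0, 0, 2, 1, 0])

# 1. Sort by group membership so each group becomes a contiguous run.
order = np.argsort(labels, kind="stable")
sorted_labels = labels[order]
sorted_values = values[order]

# 2. Accumulate: prefix sums with a leading 0, so csum[p] = sum of the first p sorted values.
csum = np.concatenate(([0.0], np.cumsum(sorted_values)))

# 3. Find where each label's run starts (plus the end), then take adjacent differences.
bounds = np.searchsorted(sorted_labels, np.arange(labels.max() + 2))
group_sums = np.diff(csum[bounds])  # sums for labels 0, 1, 2: 7.0, 3.0, 12.0
group_counts = np.diff(bounds)      # sizes for labels 0, 1, 2: 3, 1, 2
```

Dividing group_sums by group_counts gives per-group means, which is exactly what the centroid computation needs.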

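For a single (non-batch) input of an n x d matrix X and an n-length array label of non-negative integer labels, the following returns the k x d centroids (with k = label.max() + 1; labels that never occur get a zero centroid) and the n-length squared distances from each point to its own centroid: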

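import numpy as np

def inverse_permutation(p):
    """Return inv such that inv[p[i]] = i."""
    inv = np.empty_like(p)
    inv[p] = np.arange(len(p))
    return inv

def vcentroids(X, label):
    """
    Vectorized version of centroids.
    """
    # order points by cluster label
    ix = np.argsort(label)
    label = label[ix]
    Xz = X[ix]

    # compute pos where pos[i]:pos[i+1] is span of cluster i
    d = np.diff(label, prepend=0)  # jump in label value at each sorted position (0 = no change)
    pos = np.flatnonzero(d)  # indices where labels change
    pos = np.repeat(pos, d[pos])  # repeat for 0-length clusters (skipped labels)
    pos = np.append(np.insert(pos, 0, 0), len(X))

    # prefix sums with a leading zero row, so Xsums[p] = sum of the first p sorted rows
    Xz = np.concatenate((np.zeros_like(Xz[0:1]), Xz), axis=0)
    Xsums = np.cumsum(Xz, axis=0)
    Xsums = np.diff(Xsums[pos], axis=0)  # per-cluster sums
    counts = np.diff(pos)
    c = Xsums / np.maximum(counts, 1)[:, np.newaxis]  # empty clusters get a zero centroid

    # broadcast each centroid back to its points, then undo the sort
    repeated_centroids = np.repeat(c, counts, axis=0)
    aligned_centroids = repeated_centroids[inverse_permutation(ix)]
    dist = np.sum((X - aligned_centroids) ** 2, axis=1)  # squared Euclidean distances

    return c, dist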
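Since the answer says the code should translate directly to PyTorch, here is a sketch of one possible port of the same sort / cumsum / boundary-difference approach. The function name vcentroids_torch is hypothetical, and it assumes X is a float tensor of shape [n, d] and label an int64 tensor of non-negative labels. torch.repeat_interleave plays the role of np.repeat, and torch.nonzero the role of np.flatnonzero.

```python
import torch

def vcentroids_torch(X, label):
    """Hypothetical PyTorch port of vcentroids.

    X: float tensor [n, d]; label: int64 tensor [n] of non-negative labels.
    Returns (centroids [k, d], squared distances [n]) with k = label.max() + 1.
    """
    n = X.shape[0]

    # 1. Sort points by label so each cluster is a contiguous run.
    ix = torch.argsort(label)
    sorted_label = label[ix]
    Xz = X[ix]

    # 2. Boundaries: pos[i]:pos[i+1] is the span of cluster i (empty spans for skipped labels).
    d = torch.diff(sorted_label, prepend=sorted_label.new_zeros(1))
    pos = torch.nonzero(d).squeeze(1)
    pos = torch.repeat_interleave(pos, d[pos])
    pos = torch.cat([pos.new_zeros(1), pos, pos.new_tensor([n])])

    # 3. Prefix sums with a leading zero row, then adjacent differences at the boundaries.
    Xz = torch.cat([X.new_zeros(1, X.shape[1]), Xz], dim=0)
    Xsums = torch.diff(torch.cumsum(Xz, dim=0)[pos], dim=0)  # [k, d]
    counts = torch.diff(pos)  # [k]
    c = Xsums / counts.clamp(min=1).unsqueeze(1).to(X.dtype)

    # Broadcast centroids back to the sorted points, then undo the sort.
    repeated = torch.repeat_interleave(c, counts, dim=0)
    inv = torch.empty_like(ix)
    inv[ix] = torch.arange(n, device=ix.device)
    dist = ((X - repeated[inv]) ** 2).sum(dim=1)
    return c, dist
```

Because every operation here is a tensor op, the same function should also work on GPU tensors and with autograd, although that has not been benchmarked here.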

Batching requires little special handling. For an input B x n x d array batch_X, with B x n batch labels batch_labels, create unique labels for each batch element:

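batch_k = batch_labels.max(axis=1) + 1  # number of label slots each element needs
batch_k[1:] = batch_k[:-1].copy()  # shift right by one element
batch_k[0] = 0
base = np.cumsum(batch_k)  # starting label offset of each element
batch_labels += base[:, np.newaxis]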

So now each batch element has its own contiguous range of labels: the n labels of the first batch element lie in [0, k0), where k0 is that element's maximum label plus one, those of the second lie in [k0, k0 + k1), where k1 is the second element's maximum label plus one, and so on.
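For instance, with three batch elements whose (made-up) labels need 3, 2 and 4 label slots, the offsets work out as follows. This sketch computes base with a shifted cumulative sum instead of the in-place shift, which gives the same offsets without mutating batch_k:

```python
import numpy as np

# Made-up labels for B = 3 batch elements with n = 4 points each.
batch_labels = np.array([[0, 2, 1, 2],
                         [1, 0, 1, 1],
                         [0, 0, 3, 1]])

batch_k = batch_labels.max(axis=1) + 1                  # label slots per element: 3, 2, 4
base = np.concatenate(([0], np.cumsum(batch_k)[:-1]))   # start offsets: 0, 3, 5
offset_labels = batch_labels + base[:, np.newaxis]
# element 0 now uses labels in [0, 3), element 1 in [3, 5), element 2 in [5, 9)
```

Label 2 is missing from element 1's raw labels and label 2 from element 2's; those slots simply become empty clusters, which the repeat-for-0-length-clusters step in vcentroids already handles.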

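Then just flatten the B x n x d input to a B*n x d matrix (and the offset labels to a B*n vector) and call the same vectorized function. Your loss function can be derived from the final distances (take the square root first, since dist holds squared distances) using the same position-array-based reduction technique.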
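A sketch of that last step, assuming you already have the flattened squared distances and offset labels. The helper name batch_hinge_loss is hypothetical, and for brevity it does the per-class and per-element reductions with np.bincount rather than the cumsum/position-array trick; both are segment sums. Following the question's loop, each element's loss is the sum over its present classes of the mean squared hinge.

```python
import numpy as np

def batch_hinge_loss(sq_dist, flat_labels, batch_k, delta):
    """Hypothetical helper: per-element loss from flattened vcentroids outputs.

    sq_dist:     [B*n] squared distance of each point to its own centroid.
    flat_labels: [B*n] offset labels (element b uses a contiguous label range).
    batch_k:     [B] number of label slots per element (max label + 1).
    """
    total_k = int(batch_k.sum())
    # The question's hinge uses the Euclidean norm, so undo the squaring first.
    hinge = np.clip(np.sqrt(sq_dist) - delta, 0.0, None) ** 2

    # Per-class sums and counts (np.bincount used here in place of the cumsum trick).
    class_sum = np.bincount(flat_labels, weights=hinge, minlength=total_k)
    class_count = np.bincount(flat_labels, minlength=total_k)
    class_mean = np.where(class_count > 0, class_sum / np.maximum(class_count, 1), 0.0)

    # Each element owns a contiguous block of batch_k[b] classes; sum their means.
    class_to_batch = np.repeat(np.arange(len(batch_k)), batch_k)
    return np.bincount(class_to_batch, weights=class_mean, minlength=len(batch_k))

# Tiny made-up example: B = 2 elements, n = 3 points each.
sq_dist = np.array([4.0, 0.0, 9.0, 1.0, 16.0, 0.0])
flat_labels = np.array([0, 1, 0, 2, 2, 4])
batch_k = np.array([2, 3])
losses = batch_hinge_loss(sq_dist, flat_labels, batch_k, delta=1.0)  # [2.5, 4.5]
```

Empty label slots (count 0) contribute nothing, which matches iterating only over torch.unique(labels) in the original loop.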

For a detailed explanation of how the vectorization works, see my blog post.

Upvotes: 4

jbencook

Reputation: 536

You can vectorize the whole thing if you use a one-hot encoding for your classes and a pairwise distance trick for your norms:

import torch

B = 32
N = 1000
dim = 50
K = 25

batch_data = torch.randn((B, N, dim))
batch_labels = torch.randint(0, K, size=(B, N))
batch_one_hot = torch.nn.functional.one_hot(batch_labels, num_classes=K)  # [B, N, K]

# Per-class sums divided by per-class counts -> [B, K, dim]
centroids = torch.matmul(
    batch_one_hot.transpose(1, 2).to(batch_data.dtype),
    batch_data
) / batch_one_hot.sum(1)[..., None]

# Distance from every point to every class centroid -> [B, N, K]
norms = torch.linalg.norm(batch_data[:, :, None] - centroids[:, None], dim=-1)

# Compute the rest of your loss
# ...
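One way to fill in the "rest of your loss" step, as a sketch: pick each point's distance to its own class centroid with gather, apply the hinge, and average per class with the one-hot matrix. The sizes, the margin delta and the synthetic labels below are assumptions, chosen so that every class is present in every batch element (so the plain division is safe).

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, N, dim, K = 4, 60, 8, 5
delta = 0.5  # assumed margin

batch_data = torch.randn(B, N, dim)
batch_labels = (torch.arange(N) % K).repeat(B, 1)  # [B, N], every class present

one_hot = F.one_hot(batch_labels, num_classes=K).to(batch_data.dtype)  # [B, N, K]
counts = one_hot.sum(1)  # [B, K]
centroids = one_hot.transpose(1, 2) @ batch_data / counts[..., None]  # [B, K, dim]
norms = torch.linalg.norm(batch_data[:, :, None] - centroids[:, None], dim=-1)  # [B, N, K]

# Keep only each point's distance to its own class centroid.
own = norms.gather(-1, batch_labels[..., None]).squeeze(-1)  # [B, N]
hinge = torch.clamp(own - delta, min=0.0) ** 2  # [B, N]

# Mean hinge per class, then sum over classes (as in the question's loop).
class_mean = torch.einsum("bnk,bn->bk", one_hot, hinge) / counts.clamp(min=1)
batch_losses = class_mean.sum(1)  # [B]
```

The clamp on counts keeps absent classes at zero instead of NaN, so the last two lines also work when some classes are missing, as long as the centroids themselves are computed safely.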

A couple things to watch out for:

  1. You'll get a divide by zero (NaN centroids) for any batch element that is missing a class. You can handle this by first computing the class sums (with matmul) and counts (summing the one-hot tensor along axis 1) separately, then masking the sums where count == 0 and dividing the rest by their class counts.
  2. If you have a large number of classes, this will cause memory problems: the one-hot tensor (B x N x K) gets too big, and the intermediate difference tensor used for the norms (B x N x K x dim) is larger still. In that case, the answer from @VF1 probably makes more sense.
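Caveat 1 can be handled roughly as described there: compute the sums and counts separately and only divide where the count is non-zero. A sketch with a hypothetical helper name safe_centroids; giving absent classes a zero centroid is an assumption (any placeholder works as long as those classes are excluded from the loss).

```python
import torch
import torch.nn.functional as F

def safe_centroids(batch_data, batch_labels, K):
    """Hypothetical helper: class centroids that stay finite for missing classes.

    batch_data: [B, N, dim] float; batch_labels: [B, N] int64 in [0, K).
    Returns (centroids [B, K, dim], counts [B, K]).
    """
    one_hot = F.one_hot(batch_labels, num_classes=K).to(batch_data.dtype)  # [B, N, K]
    sums = one_hot.transpose(1, 2) @ batch_data  # [B, K, dim] per-class sums
    counts = one_hot.sum(1)  # [B, K] per-class counts
    present = counts > 0
    # Divide only where the class exists; absent classes get a zero centroid.
    centroids = torch.where(
        present[..., None],
        sums / counts.clamp(min=1)[..., None],
        torch.zeros_like(sums),
    )
    return centroids, counts

# Usage: element 0 has no class 1, element 1 has no class 2.
data = torch.randn(2, 6, 3)
labels = torch.tensor([[0, 0, 2, 2, 2, 0], [1, 1, 1, 0, 0, 1]])
c, counts = safe_centroids(data, labels, K=3)
```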

Upvotes: 0
