Reputation: 93
I'm trying to vectorize the following for-loop in PyTorch. I'd be happy with just vectorizing the inner for-loop, but doing the whole batch would also be awesome.
# B: the batch size
# N: the number of training examples
# dim: the dimension of each feature vector
# K: the number of discrete labels. each vector has a single label
# delta: margin for hinge loss
batch_data = torch.tensor(...) # Tensor of shape [B x N x dim]
batch_labels = torch.tensor(...) # Tensor of shape [B x N x 1], each element is one of K labels (ints)
batch_losses = [] # Ultimately should be [B x 1]
batch_centroids = [] # Ultimately should be [B x K_i x dim]
for i in range(B):
    data = batch_data[i] # [N x dim]
    labels = batch_labels[i].squeeze(-1) # [N]
    centroids = [] # Keep track of the means for each class.
    classes = torch.unique(labels) # Get the unique labels for the classes.
    # NOTE: The number of classes K for each item in the batch might actually
    # be different. This may complicate batch-level operations.
    total_loss = 0
    # For each class independently. This is the part I want to vectorize.
    for cl in classes:
        # Take the subset of training examples with that label.
        subset = data[torch.where(labels == cl)]
        # Find the centroid of that subset.
        centroid = subset.mean(dim=0)
        centroids.append(centroid)
        # Get the distance between each point in the subset and the centroid.
        dists = subset - centroid
        norm = torch.linalg.norm(dists, dim=1)
        # The loss is the mean of the hinge loss across the subset.
        margin = norm - delta
        hinge = torch.clamp(margin, min=0.0) ** 2
        total_loss += hinge.mean()
    # Keep track of everything. If it's too hard to keep track of centroids, that's also OK.
    loss = total_loss.mean()
    batch_losses.append(loss)
    batch_centroids.append(centroids)
I've been scratching my head on how to deal with the irregularly sized tensors. The number of classes K_i differs between batch elements, and the size of each subset is different too.
Upvotes: 7
Views: 960
Reputation: 1652
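It turns out it actually is possible to vectorize across ragged arrays. I'll use numpy, but the code should be directly translatable to torch. The key technique is to sort the points by label so that each cluster occupies a contiguous span, then reduce over the spans by differencing a cumulative sum at the span boundaries.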
For a single (non-batch) input of an n x d matrix X and an n-length array label, the following returns the k x d centroids and the n-length distances to the respective centroids:
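import numpy as np

def inverse_permutation(p):
    """Return q such that q[p[i]] = i (assumed helper; not defined in the original answer)."""
    q = np.empty_like(p)
    q[p] = np.arange(len(p))
    return q

def vcentroids(X, label):
    """
    Vectorized version of centroids.
    """
    # order points by cluster label
    ix = np.argsort(label)
    label = label[ix]
    Xz = X[ix]
    # compute pos where pos[i]:pos[i+1] is span of cluster i
    d = np.diff(label, prepend=0) # label increments; nonzero where labels change
    pos = np.flatnonzero(d) # indices where labels change
    pos = np.repeat(pos, d[pos]) # repeat for 0-length clusters
    pos = np.append(np.insert(pos, 0, 0), len(X))
    # prepend a zero row so Xsums[p] is the sum of the first p sorted rows
    Xz = np.concatenate((np.zeros_like(Xz[0:1]), Xz), axis=0)
    Xsums = np.cumsum(Xz, axis=0)
    Xsums = np.diff(Xsums[pos], axis=0) # per-cluster sums
    counts = np.diff(pos) # per-cluster sizes
    c = Xsums / np.maximum(counts, 1)[:, np.newaxis]
    # broadcast each centroid back to its points, then undo the sort
    repeated_centroids = np.repeat(c, counts, axis=0)
    aligned_centroids = repeated_centroids[inverse_permutation(ix)]
    dist = np.sum((X - aligned_centroids) ** 2, axis=1) # squared Euclidean distance
    return c, dist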
Batching requires little special handling. For an input B x n x d array batch_X, with B x n batch labels batch_labels, create unique labels for each batch element:
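batch_k = batch_labels.max(axis=1) + 1 # number of label slots per batch element
# exclusive prefix sum: offset of the first label of each batch element
base = np.concatenate(([0], np.cumsum(batch_k)[:-1]))
batch_labels = batch_labels + base[:, np.newaxis]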
So now each batch element has a unique contiguous range of labels. I.e., the first batch element will have n labels in some range [0, k0) where k0 = batch_k[0], the second will have range [k0, k0 + k1) where k1 = batch_k[1], etc.
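As a small worked example of the offsetting step, here is a sketch on a toy batch. The label values and the names `global_labels` and `flat_labels` are made up for illustration and are not from the original answer.

```python
import numpy as np

# Hypothetical toy batch: B = 2 batch elements, n = 4 points each.
batch_labels = np.array([[0, 1, 1, 2],
                         [1, 0, 0, 1]])

# Number of label slots per batch element (max label + 1).
batch_k = batch_labels.max(axis=1) + 1

# Exclusive prefix sum: where each batch element's label range starts.
base = np.concatenate(([0], np.cumsum(batch_k)[:-1]))

# Shift every row into its own contiguous label range.
global_labels = batch_labels + base[:, np.newaxis]

# Flattened labels, ready to pair with a (B * n) x d data matrix.
flat_labels = global_labels.reshape(-1)
print(global_labels)
```

Here the first row keeps its labels in [0, 3) and the second row is moved to [3, 5), so no label is shared between batch elements.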
Then just flatten the B x n x d input to B*n x d and call the same vectorized method. Your loss function is derivable from the final distances using the same position-array-based reduction technique.
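One way the loss reduction might look is sketched below. It assumes the hinge is applied to the Euclidean distance, as in the question, so a squared distance would need a square root first. The function name `per_class_hinge_loss` is hypothetical, and it finds the span boundaries with `np.searchsorted` rather than the diff/repeat construction used above, which is a substitution made for brevity.

```python
import numpy as np

def per_class_hinge_loss(dist, label, delta):
    """Mean squared hinge loss per label via sort + cumsum + diff.

    dist:  (n,) Euclidean distance of each point to its own centroid
    label: (n,) non-negative integer labels (assumed non-empty)
    delta: hinge margin
    Returns an array of length label.max() + 1; labels with no points get 0.
    """
    hinge = np.maximum(dist - delta, 0.0) ** 2

    # Order the per-point losses by label so each label is one contiguous span.
    ix = np.argsort(label, kind="stable")
    sorted_label = label[ix]
    sorted_hinge = hinge[ix]

    # pos[j]:pos[j+1] is the span of label j in the sorted order.
    k = int(sorted_label.max()) + 1
    pos = np.searchsorted(sorted_label, np.arange(k + 1), side="left")

    # Prefix sums with a leading zero, differenced at the span boundaries.
    csum = np.concatenate(([0.0], np.cumsum(sorted_hinge)))
    sums = np.diff(csum[pos])
    counts = np.diff(pos)
    return sums / np.maximum(counts, 1)

# Toy inputs: label 1 has no points.
dist = np.array([0.5, 2.0, 3.0, 1.0])
label = np.array([0, 2, 2, 0])
losses = per_class_hinge_loss(dist, label, delta=1.0)
print(losses)
```

With batch-offset labels, summing the entries that fall in each batch element's label range would give that element's total loss.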
For a detailed explanation of how the vectorization works, see my blog post.
Upvotes: 4
Reputation: 536
You can vectorize the whole thing if you use a one-hot encoding for your classes and a pairwise distance trick for your norms:
import torch
B = 32
N = 1000
dim = 50
K = 25
batch_data = torch.randn((B, N, dim))
batch_labels = torch.randint(0, K, size=(B, N))
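batch_one_hot = torch.nn.functional.one_hot(batch_labels) # [B x N x K]
# per-class sums divided by per-class counts -> centroids of shape [B x K x dim]
centroids = torch.matmul(
    batch_one_hot.transpose(-1, 1).type(batch_data.dtype),
    batch_data
) / batch_one_hot.sum(1)[..., None]
# distance from every point to every centroid -> [B x N x K]
norms = torch.linalg.norm(batch_data[:, :, None] - centroids[:, None], dim=-1)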
# Compute the rest of your loss
# ...
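A possible way to finish the loss is sketched below, under a few assumptions that are not in the original answer: labels are int64 with shape (B, N), the number of classes is passed explicitly, and classes absent from a batch element contribute zero. The function name `batched_centroid_hinge_loss` is hypothetical. Instead of the full point-to-every-centroid norm, it looks up each point's own centroid with a second matmul, which keeps the intermediate at (B, N, dim).

```python
import torch
import torch.nn.functional as F

def batched_centroid_hinge_loss(batch_data, batch_labels, num_classes, delta):
    """batch_data: (B, N, dim) float; batch_labels: (B, N) int64 in [0, num_classes)."""
    one_hot = F.one_hot(batch_labels, num_classes).to(batch_data.dtype)  # (B, N, K)
    counts = one_hot.sum(dim=1)                                          # (B, K)

    # Per-class sums divided by counts; the clamp avoids 0/0 for absent classes.
    centroids = one_hot.transpose(1, 2) @ batch_data                     # (B, K, dim)
    centroids = centroids / counts.clamp(min=1).unsqueeze(-1)

    # Distance from each point to the centroid of its own class only.
    own_centroid = one_hot @ centroids                                   # (B, N, dim)
    dist = torch.linalg.norm(batch_data - own_centroid, dim=-1)          # (B, N)

    hinge = torch.clamp(dist - delta, min=0.0) ** 2                      # (B, N)

    # Mean hinge per class, then summed over classes as in the question's loop.
    per_class_sum = (one_hot * hinge.unsqueeze(-1)).sum(dim=1)           # (B, K)
    per_class_mean = per_class_sum / counts.clamp(min=1)
    return per_class_mean.sum(dim=1), centroids

torch.manual_seed(0)
x = torch.randn(3, 20, 4)
y = torch.randint(0, 5, (3, 20))
loss, cents = batched_centroid_hinge_loss(x, y, num_classes=5, delta=1.0)
print(loss.shape, cents.shape)
```

Passing `num_classes` explicitly keeps the class axis the same size for every batch element, and the clamp on the counts means an absent class yields a zero centroid and zero loss rather than NaN.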
A couple things to watch out for:
Upvotes: 0