Reputation: 40481
I'm trying to implement a calculation, but I can't figure out how to vectorize my code and avoid loops.
Let me explain: I have a matrix M[N,C] of either 0 or 1, another matrix Y[N,1] containing values in [0, C-1] (my classes), and another matrix ds[N,M], which is my dataset.
My output matrix grad is of size [M,C] and should be calculated as follows (I'll explain for grad[:,0]; the same logic applies to any other column):
For each row (sample) in ds: if Y[that sample] != 0 (the current column of the output matrix) and M[that sample, 0] > 0, then grad[:,0] += ds[that sample]. If Y[that sample] == 0, then grad[:,0] -= (ds[that sample] * <num of non-zeros in M[that sample,:]>).
Here is my iterative approach:
for i in range(M.size(dim=1)):
    for j in range(ds.size(dim=0)):
        if y[j] == i:
            grad[:,i] = grad[:,i] - (ds[j,:].T * sum(M[j,:]))
        else:
            if M[j,i] > 0:
                grad[:,i] = grad[:,i] + ds[j,:].T
Upvotes: 1
Views: 60
Reputation: 2569
Since you are dealing with three dimensions n, m, and c (in lowercase to avoid ambiguity with the matrix M), it can be useful to change the shape of all your tensors to (n, m, c) by replicating their values over the missing dimension (e.g. M of shape (n, c) becomes (n, m, c)).
However, you can skip the explicit replication and use broadcasting, so it is sufficient to unsqueeze the missing dimension (e.g. M of shape (n, c) becomes (n, 1, c)).
Given these considerations, your code can be vectorized as follows:
cond = y.unsqueeze(2) == torch.arange(M.size(dim=1)).unsqueeze(0)       # (n, 1, c): True where y[j] == i
pos = ds.unsqueeze(2) * M.unsqueeze(1) * ~cond                          # add ds[j] where M[j,i] > 0 and y[j] != i
neg = ds.unsqueeze(2) * M.sum(dim=1, keepdim=True).unsqueeze(2) * cond  # subtract ds[j] * <non-zeros in M[j,:]> where y[j] == i
grad += (pos - neg).sum(dim=0)                                          # accumulate the contributions of all n samples
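As a side note (an optional variation, not required for correctness): pos and neg materialize (n, m, c) intermediates, so for large inputs the same sums can be written as two matrix products instead. This sketch relies on M containing only 0s and 1s, as in your question:
# Memory-leaner equivalent (assumes M is binary)
cond2 = y == torch.arange(M.size(dim=1)).unsqueeze(0)            # (n, c): True where y[j] == i
pos2 = ds.T @ (M * ~cond2).to(ds.dtype)                          # (m, c): sums ds[j] over samples with M[j,i] > 0 and y[j] != i
neg2 = ds.T @ (M.sum(dim=1, keepdim=True) * cond2).to(ds.dtype)  # (m, c): sums ds[j] * <non-zeros in M[j,:]> over samples with y[j] == i
grad += pos2 - neg2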
Here is a small test to check the validity of the solution:
import torch

n, m, c = 11, 5, 7
y = torch.randint(c, size=(n, 1))
ds = torch.rand(n, m)
M = torch.randint(2, size=(n, c))
grad = torch.rand(m, c)

def slow_grad(y, ds, M, grad):
    for i in range(M.size(dim=1)):
        for j in range(ds.size(dim=0)):
            if y[j] == i:
                grad[:,i] = grad[:,i] - (ds[j,:].T * sum(M[j,:]))
            else:
                if M[j,i] > 0:
                    grad[:,i] = grad[:,i] + ds[j,:].T
    return grad

def fast_grad(y, ds, M, grad):
    cond = y.unsqueeze(2) == torch.arange(M.size(dim=1)).unsqueeze(0)
    pos = ds.unsqueeze(2) * M.unsqueeze(1) * ~cond
    neg = ds.unsqueeze(2) * M.sum(dim=1, keepdim=True).unsqueeze(2) * cond
    grad += (pos - neg).sum(dim=0)
    return grad

# Both functions modify grad in place, so each gets its own copy
# (otherwise the assert would compare a tensor with itself).
# allclose is used because the two implementations sum the
# floating-point terms in different orders.
assert torch.allclose(slow_grad(y, ds, M, grad.clone()),
                      fast_grad(y, ds, M, grad.clone()))
Feel free to test on other cases as well!
Upvotes: 1