Reputation: 217
I have the following for loop in my code. It is slowing down the whole execution.
for q in range(batchSize):
    temp = torch.where((composition_matrix == pred[q]).all(dim=1))[0]
    if len(temp) == 0:
        output[q] = 0
    else:
        output[q] = int(temp[0])
Here, composition_matrix is a [14000, 2]-dimensional PyTorch tensor with only positive integers as cell values. pred is a [batchSize, 2]-dimensional torch tensor, and output is a [batchSize]-dimensional torch tensor (as in the example below).
This for loop slows my code down a lot, and I have been unable to find an equivalent broadcasting solution for it.
Does a broadcasting solution exist that eliminates this for loop?
I shall be grateful for any help.
A minimal reproducible example is:
import torch
composition_matrix=torch.randint(3, 10, (14000,2))
batchSize=64
pred=torch.randint(3, 10, (batchSize,2))
output=torch.zeros([batchSize])
for q in range(batchSize):
    temp = torch.where((composition_matrix == pred[q]).all(dim=1))[0]
    if len(temp) == 0:
        output[q] = 0
    else:
        output[q] = int(temp[0])
Upvotes: 2
Views: 1915
Reputation: 4181
To keep it simple, you first need to understand what the operation is essentially doing. You've got two tensors: tensor A of shape (14000, 2) and tensor B of shape (64, 2). The operation you want to do is:
For each row B[i] in B, compare B[i] (of shape (2,)) with A (of shape (14000, 2)). If B[i] occurs within A, set output[i] to the index of its first occurrence; otherwise set output[i] = 0.
This can actually be done in two lines of code (maybe even one line):
comp = (composition_matrix[:, None, :] == pred).all(dim=-1)
output = torch.argmax(comp.float(), dim=0)
The first line creates comp, the broadcasted comparison of composition_matrix and pred, a boolean tensor of shape (14000, 64).
The second line needs to find the "index of the first match". This can be done quite simply with argmax: it returns the index of the first "1" or, if all the values are "0", the first index, i.e., 0, which is the same fallback value the loop assigns when there is no match.
(Note that torch does not support argmax for "bool" tensors, and so comp needed to be cast to another data type.)
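As a small illustration of the two lines above, here is a sketch on hand-made tensors (A and B are stand-ins for composition_matrix and pred, not values from the question). It also shows one possible way to tell "no match" apart from "first match at index 0", since both produce 0 with argmax alone; the -1 sentinel is an assumption for illustration, not something the question asks for.

```python
import torch

# Hypothetical toy data: A plays the role of composition_matrix, B of pred.
A = torch.tensor([[3, 4], [1, 2], [1, 2]])
B = torch.tensor([[1, 2], [5, 6], [3, 4]])

# comp[i, j] is True when row A[i] equals row B[j]; shape (3, 3).
comp = (A[:, None, :] == B).all(dim=-1)

# Index of the first match per row of B (0 when there is no match at all).
first = torch.argmax(comp.int(), dim=0)

# True where a row of B occurs somewhere in A.
found = comp.any(dim=0)

# Optional: mark missing rows with -1 instead of the ambiguous 0.
first_or_minus1 = torch.where(found, first, torch.full_like(first, -1))

print(first.tolist())            # [1, 0, 0]
print(found.tolist())            # [True, False, True]
print(first_or_minus1.tolist())  # [1, -1, 0]
```

If the loop's behaviour (0 for missing rows) is what you need, first alone is enough; the found mask only matters when index 0 is also a legitimate match.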
Upvotes: 2
Reputation: 2002
Sorry for the short and probably over-simplified example. I fear a bigger one would be much more difficult to visualize. But I hope this suits your purpose. My solution may seem a little complicated but it's fully vectorized and includes no explicit loops. Here's what I would do:
import torch
torch.manual_seed(0)
batchSize = 8
pred = torch.randint(0, 10, (batchSize, 2))
output = torch.zeros((batchSize, 2))
composition_matrix = torch.randint(0, 10, (14, 2))
# compare all vectors in composition_matrix to all vectors in pred
comparisons = (composition_matrix.unsqueeze(0) == pred.unsqueeze(1))
comparisons = comparisons.all(2)
# form an index array the shape of the comparisons array
comparison_idxs = torch.arange(comparisons.shape[1])
comparison_idxs = comparison_idxs.repeat(batchSize).reshape(*comparisons.shape)
# multiply the comparisons array by the index array
where_result = (comparison_idxs*comparisons)
# replace invalid zeros with the maximal value in each sample
batch_idxs = torch.arange(comparisons.shape[0])
batch_idxs = batch_idxs.repeat(comparisons.shape[1])
batch_idxs = batch_idxs.reshape(comparisons.shape[1], comparisons.shape[0]).T
maxima = where_result.max(1).values[batch_idxs]
maxima_vector = maxima[(1-comparisons.int()).bool()]
where_result[(1-comparisons.int()).bool()] = maxima_vector
vectorized_output = where_result.min(1)[0]
output = torch.zeros([batchSize])
for q in range(batchSize):
    temp = torch.where((composition_matrix == pred[q]).all(dim=1))[0]
    if len(temp) == 0:
        output[q] = 0
    else:
        output[q] = int(temp[0])
output:
composition_matrix =
tensor([[6, 8],
        [4, 3],
        [6, 9],
        [1, 4],
        [4, 1],
        [9, 9],
        [9, 0],
        [1, 2],
        [3, 0],
        [5, 5],
        [2, 9],
        [1, 8],
        [8, 3],
        [6, 9]])
pred =
tensor([[4, 9],
        [3, 0],
        [3, 9],
        [7, 3],
        [7, 3],
        [1, 6],
        [6, 9],
        [8, 6]])
output =
tensor([0., 8., 0., 0., 0., 0., 2., 0.])
vectorized_output =
tensor([0, 8, 0, 0, 0, 0, 2, 0])
Some timing results:
torch.manual_seed(0)
batchSize = 8
pred = torch.randint(0, 10, (batchSize, 2))
composition_matrix = torch.randint(0, 10, (14000, 2))
print('timing the vectorized_solution:')
%timeit -n 1000 vectorized_solution(composition_matrix, pred,)
print('timing the loop_solution:')
%timeit -n 1000 loop_solution(composition_matrix, pred,)
output:
timing the vectorized_solution:
1000 loops, best of 5: 137 µs per loop
timing the loop_solution:
1000 loops, best of 5: 1.89 ms per loop
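The timing snippet calls vectorized_solution and loop_solution, whose definitions are not shown above. Here is a sketch of what they might look like, assuming they simply wrap the two approaches. The signatures are guesses based on the calls, and the vectorized body is a shortened variant (masking with torch.where instead of the multiply-and-replace steps), not the exact code from this answer, so timings may differ from the numbers quoted.

```python
import torch


def loop_solution(composition_matrix, pred):
    # Reference behaviour: index of the first matching row, or 0 if none.
    output = torch.zeros(pred.shape[0])
    for q in range(pred.shape[0]):
        temp = torch.where((composition_matrix == pred[q]).all(dim=1))[0]
        output[q] = int(temp[0]) if len(temp) > 0 else 0
    return output


def vectorized_solution(composition_matrix, pred):
    # comparisons[b, i] is True when pred[b] equals composition_matrix[i].
    comparisons = (composition_matrix.unsqueeze(0) == pred.unsqueeze(1)).all(2)
    n = comparisons.shape[1]
    # Column indices, broadcast to the shape of comparisons.
    idxs = torch.arange(n).expand_as(comparisons)
    # Non-matching positions get the out-of-range value n, so min() skips them.
    masked = torch.where(comparisons, idxs, torch.full_like(idxs, n))
    first = masked.min(dim=1).values
    # Rows with no match at all still hold n; map them to 0 like the loop does.
    return torch.where(first == n, torch.zeros_like(first), first)


torch.manual_seed(0)
pred = torch.randint(0, 10, (8, 2))
composition_matrix = torch.randint(0, 10, (14000, 2))

# The loop returns floats and the vectorized version returns integers,
# so cast before comparing.
same = torch.equal(
    vectorized_solution(composition_matrix, pred),
    loop_solution(composition_matrix, pred).long(),
)
print(same)
```

Checking that both functions agree before timing them is worthwhile, because a fast version that silently disagrees with the loop on the no-match case would be easy to miss.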
Upvotes: 1