Reputation: 3
I have two matrices: A with size 160 x 250 and B with size 3200 x 250.
I want to get the set intersection of each row of A with each row of B, giving a 160 x 3200 matrix. (Each set has 250 elements.)
Any ideas how to implement this?
I think it should use torch.eq, but I'm not sure how to arrange the dimensions. For example:
result = torch.sum(torch.eq(A[0], B), dim=1) gives me a 3200-element vector comparing just the 0th row of A against every row of B. I want this for all 160 rows of A.
Upvotes: 0
Views: 71
Reputation: 40668
Assuming you want to check equality between the 160 x 3200
possible pairs of 250-feature vectors, you can do so with an indexing trick:
>>> (A[:, None] == B[None]).all(-1)  # shape (160, 3200); A[None] == B[:, None] would give (3200, 160)
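
To make the broadcasting concrete, here is a minimal sketch on small stand-in tensors. The sizes 4, 7 and 6, the random values, and the planted duplicate row are assumptions for illustration only, not from the question. Indexing A with [:, None] and B with [None] puts A's rows on the first axis, and the same comparison can then be reduced either to per-pair match counts (as in the question's torch.sum expression) or to a per-pair all-equal flag:

```python
import torch

torch.manual_seed(0)
# Small stand-ins for the 160 x 250 and 3200 x 250 matrices in the question.
A = torch.randint(0, 5, (4, 6))
B = torch.randint(0, 5, (7, 6))
B[2] = A[1]  # plant one identical row so the all-equal check has a hit

# Broadcast (4, 1, 6) against (1, 7, 6) -> (4, 7, 6) elementwise comparison.
eq = A[:, None, :] == B[None, :, :]

match_counts = eq.sum(dim=-1)  # (4, 7): number of positions where the two rows agree
rows_equal = eq.all(dim=-1)    # (4, 7): True only where the two rows are identical

# Row 0 reproduces the single-row expression from the question.
assert torch.equal(match_counts[0], torch.sum(torch.eq(A[0], B), dim=1))
```

Note that both reductions compare values position by position; they do not compute an order-independent set intersection. At the original sizes the intermediate comparison tensor has 160 x 3200 x 250 = 128 million boolean elements, so looping over chunks of A's rows may be preferable if memory is tight.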
Upvotes: 2