Alex Meiburg

Reputation: 656

How to accomplish a certain multiplication in tensorflow

I have one tensor A of dimension [a,b,c,d], another B of dimension [b,b,d,e], and C, a list of [a] integers ranging from 0 to b-1. I need to produce the tensor D of dimension [a,b,c,e] given by

D[i,j,k,l] = sum over m = 0..d-1 of A[i,C[i],k,m] * B[C[i],j,m,l]
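For concreteness, here is the same operation as a direct (deliberately naive) NumPy loop; the sizes are just placeholders:

import numpy as np

# Placeholder sizes, purely for illustration.
a, b, c, d, e = 4, 3, 5, 6, 7
A = np.random.rand(a, b, c, d)
B = np.random.rand(b, b, d, e)
C = np.random.randint(0, b, size=a)

# Direct translation of the definition above.
D = np.zeros((a, b, c, e))
for i in range(a):
    for j in range(b):
        for k in range(c):
            for l in range(e):
                D[i, j, k, l] = sum(A[i, C[i], k, m] * B[C[i], j, m, l]
                                    for m in range(d))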

b is small enough (typically 3 or 5) that I don't mind doing this in b separate operations, but I can't afford the waste of going to something that takes b^2 memory or time, when this operation clearly should be linear in b. This seems like it should be some combination of pointwise multiplies (with broadcasting?) and tensor contractions (a matrix multiply over the shared m dimension), but I can't pin it down.

If someone can really convince me that this isn't possible in O(b) flops with TensorFlow's provided operations, then fine, but in that case I'd want an O(b^2) solution for sure.

Update: It looks like the appropriately modified A tensors can be built individually using tf.gather_nd; if this can then be paired up with B somehow, maybe? Unfortunately, my experiments so far led to finding a bug in tf.gather_nd itself, which has slowed things down.

Upvotes: 1

Views: 564

Answers (1)

Alex Meiburg

Reputation: 656

I figured out how to accomplish this reasonably efficiently. First, build a modified version of B with tf.gather, putting the appropriate parts in the first index:

# B2[i] == B[C[i]], so B2 has shape [a, b, d, e].
B2 = tf.gather(B, C)

Then pull out just the relevant parts of the A tensor using tf.gather_nd. We're going to pull out a bunch of index pairs of the form [0, C[0]], [1, C[1]], [2, C[2]], and so on, so first we need to build the index tensor.

a = tf.shape(A)[0]
# Stack along axis 1 so the indices tensor has shape [a, 2], one pair [i, C[i]] per row.
# (C must be an int32 tensor here, to match the dtype of tf.range.)
A2_indices = tf.stack([tf.range(a), C], axis=1)
# A2[i] == A[i, C[i]], giving A2 shape [a, c, d].
A2 = tf.gather_nd(A, A2_indices)

producing A2 with shape [a,c,d]. Now we need to multiply A2 and B2 appropriately. It's a tensor contraction in the m index (axis 2 in both) but a pointwise multiplication in the i index (axis 0 in both). That means the result is, sadly, neither a plain tensor contraction nor a plain pointwise multiplication!

One option would be to compute the full tensor product, contracting only over m, and then take the diagonal over the two i indices -- but that would waste a lot of computation building the rest of a matrix we don't need. Instead, we can think of this as a batched matrix multiplication: this used to be called tf.batch_matmul, but now it's just tf.matmul. The caveat is that, besides the two matrix dimensions in each input tensor, all the remaining dimensions have to be pointwise (batch) dimensions. B2 fails this criterion, because it has the additional j index. But we can "wrap" j into the l output dimension, and then unwrap it later.

That means first calling tf.transpose to put j and l next to each other, then tf.reshape to merge them into a single output dimension of size b*e, then doing tf.matmul, then another tf.reshape and tf.transpose to return to the original form. So

a, b, d, e = B2.get_shape().as_list()
c = A2.get_shape().as_list()[1]
# Move m (the d axis) in front of j, so that j and l end up adjacent: [a, d, b, e].
B2_trans = tf.transpose(B2, perm=[0, 2, 1, 3])
# Merge j and l into one axis: [a, d, b*e].
B2_jl = tf.reshape(B2_trans, [a, d, b * e])
# Batched matmul over the shared m axis: [a, c, d] x [a, d, b*e] -> [a, c, b*e].
product_jl = tf.matmul(A2, B2_jl)
# Split j and l back apart: [a, c, b, e].
product_trans = tf.reshape(product_jl, [a, c, b, e])
# Swap to the desired index order: [a, b, c, e].
result = tf.transpose(product_trans, perm=[0, 2, 1, 3])
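As a quick sanity check, the whole pipeline can be run end to end and compared against the definition. This is a TF 2.x eager-mode sketch with made-up placeholder sizes, not part of the pipeline itself:

import numpy as np
import tensorflow as tf

a, b, c, d, e = 4, 3, 5, 6, 7   # placeholder sizes
A = tf.constant(np.random.rand(a, b, c, d))
B = tf.constant(np.random.rand(b, b, d, e))
C = tf.constant(np.random.randint(0, b, size=a), dtype=tf.int32)

# The pipeline from above, condensed.
B2 = tf.gather(B, C)
A2 = tf.gather_nd(A, tf.stack([tf.range(a), C], axis=1))
B2_jl = tf.reshape(tf.transpose(B2, [0, 2, 1, 3]), [a, d, b * e])
result = tf.transpose(tf.reshape(tf.matmul(A2, B2_jl), [a, c, b, e]),
                      [0, 2, 1, 3])
print(result.shape)  # (4, 3, 5, 7), i.e. [a, b, c, e]

# Spot-check one entry against the definition.
i, j, k, l = 1, 2, 3, 4
ci = int(C[i])
direct = sum(float(A[i, ci, k, m] * B[ci, j, m, l]) for m in range(d))
assert abs(float(result[i, j, k, l]) - direct) < 1e-10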

That finishes it up! Of course, in practice it may well be that B is only needed in this one instance, in which case B can start out already in the "compressed" state, saving a transpose (and a cheap reshape); or if A2 is going to be flattened or transposed anyway, that could save a transpose too. But overall, everything is pretty minimal complexity. :)
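As an aside: on newer TensorFlow versions, once A2 and B2 are built, tf.einsum can express the remaining multiplication in one call, which may be easier to read (a sketch -- I haven't benchmarked it against the matmul route):

# Equivalent to the transpose/reshape/matmul/reshape/transpose sequence above.
result = tf.einsum('ikm,ijml->ijkl', A2, B2)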

Upvotes: 1
