Reputation: 656
I have one tensor A of dimension [a,b,c,d], another tensor B of dimension [b,b,d,e], and C, a list of [a] integers from 0 to b-1. I need to produce the tensor D of dimension [a,b,c,e] given by

D[i,j,k,l] = sum_{m=0..d-1} A[i,C[i],k,m] * B[C[i],j,m,l]

b is small enough (3 or 5, usually) that I don't mind doing this in b separate operations -- but I can't afford the waste of going to something that takes b^2 memory or time, when this operation clearly should be linear in b. It seems like it will be some combination of pointwise multiplies (with broadcasting?) and tensor contractions (a matrix multiply across the shared m dimension), but I can't pin it down.

If someone can really convince me that this isn't possible in O(b) flops with TensorFlow's provided operations, then okay, but then I'd want an O(b^2) version for sure.
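For concreteness, here's the definition spelled out as a naive loop (a NumPy sketch; the shapes and names are just illustrative):

import numpy as np

a, b, c, d, e = 4, 3, 5, 6, 7
A = np.random.rand(a, b, c, d)
B = np.random.rand(b, b, d, e)
C = np.random.randint(0, b, size=a)

D = np.zeros((a, b, c, e))
for i in range(a):
    for j in range(b):
        # Each D[i, j] is a [c, e] matrix product over the shared m axis.
        D[i, j] = A[i, C[i]] @ B[C[i], j]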
Update: It's looking like the appropriately modified A tensors can be built individually using tf.gather_nd; if this can then be paired up with B somehow, maybe? Unfortunately my experiments so far led to finding a bug in tf.gather_nd itself, which has slowed things down.
Upvotes: 1
Views: 564
Reputation: 656
I figured out how to accomplish this, reasonably efficiently. First build a modified version of B with tf.gather, putting the appropriate parts in the first index:

B2 = tf.gather(B, C)
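Since C has length a and indexes the first axis of B, this leaves B2 with shape [a,b,d,e] and B2[i] = B[C[i]]. A toy sketch to check the shapes (the numbers here are made up):

import tensorflow as tf

B = tf.zeros([3, 3, 6, 7])     # [b, b, d, e] with b=3, d=6, e=7
C = tf.constant([2, 0, 1, 1])  # a = 4 selections, each in 0..b-1
B2 = tf.gather(B, C)           # shape [4, 3, 6, 7] = [a, b, d, e]
# B2[i] == B[C[i]] for each i.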
Then pull out just the relevant parts of the A tensor using tf.gather_nd. We're going to pull out a bunch of pairs of indices of the form [0,C[0]], [1,C[1]], [2,C[2]], and so on, so first we need to build the index tensor.
a = tf.shape(A)[0]
# Pair each row index i with its chosen slot C[i]; shape [a, 2].
A2_indices = tf.stack([tf.range(a), C], axis=1)
A2 = tf.gather_nd(A, A2_indices)
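To see what that index tensor looks like, a tiny hypothetical case:

C = tf.constant([2, 0, 1])
A2_indices = tf.stack([tf.range(3), C], axis=1)
# A2_indices == [[0, 2], [1, 0], [2, 1]]: row i selects the [c, d] slice A[i, C[i]].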
producing A2 with shape [a,c,d]. Now we need to multiply A2 and B2 appropriately. It's a tensor contraction in the m indices (2 and 3, respectively) but a pointwise multiplication in the i index (0 in both). This means that, sadly, the resulting operation is neither a pure tensor contraction nor a pure pointwise multiplication! One option would be computing the full tensor product, contracting only over m, and then taking the diagonal (tf.diag_part) over the two i indices -- but this would waste a lot of computation building the rest of a matrix that we don't need. Instead, we can think of this as a batched matrix multiplication: this used to be called tf.batch_matmul, but now it's just tf.matmul. This has the caveat, though, that besides the 2 matrix dimensions in each input tensor, the rest all have to be pointwise (batch) dimensions. B2 fails this criterion, because it has the additional j index. But we can "wrap that in" with the l output dimension, and then remove it later. This means first calling tf.transpose to put j and l next to each other, then tf.reshape to merge them into one j*l output dimension, then doing tf.matmul, then another tf.reshape and tf.transpose to return to the original form. So:
a, b, d, e = B2.get_shape().as_list()
c = A2.get_shape().as_list()[1]
B2_trans = tf.transpose(B2, perm=[0, 2, 1, 3])           # [a, d, b, e]
B2_jl = tf.reshape(B2_trans, [a, d, b * e])              # merge j and l
product_jl = tf.matmul(A2, B2_jl)                        # [a, c, b*e]
product_trans = tf.reshape(product_jl, [a, c, b, e])     # split j and l again
result = tf.transpose(product_trans, perm=[0, 2, 1, 3])  # [a, b, c, e]
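As an aside: if your TensorFlow version provides tf.einsum, the same batched contraction can be written in one call, with the transpose/reshape bookkeeping handled for you (a sketch, equivalent to the pipeline above):

result = tf.einsum('ikm,ijml->ijkl', A2, B2)
# i is the shared batch index, m is summed over; result has shape [a, b, c, e].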
Which finishes it up! Of course, in practice it may well be that B is only needed in this one instance, in which case it could start out already in the "compressed" state, saving a transpose (and a cheap reshape); or if A2 is going to be flattened or transposed anyway, that could save a transpose too. But overall everything is pretty minimal complexity. :)
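For anyone who wants to double-check the index gymnastics, here's a quick NumPy sanity check of the same scheme (shapes are arbitrary):

import numpy as np

a, b, c, d, e = 4, 3, 5, 6, 7
A = np.random.rand(a, b, c, d)
B = np.random.rand(b, b, d, e)
C = np.random.randint(0, b, size=a)

A2 = A[np.arange(a), C]                  # [a, c, d], A2[i] = A[i, C[i]]
B2 = B[C]                                # [a, b, d, e], B2[i] = B[C[i]]
D = np.einsum('ikm,ijml->ijkl', A2, B2)  # [a, b, c, e]

# Compare one entry against the definition.
i, j = 1, 2
assert np.allclose(D[i, j], A[i, C[i]] @ B[C[i], j])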
Upvotes: 1