I have 3 tensors:
X with shape (1, c, h, w), assume (1, 20, 40, 50)
Fx with shape (num, w, N), assume (1000, 50, 10)
Fy with shape (num, N, h), assume (1000, 10, 40)
What I want to do is Fy * (X * Fx), where * means matmul.
X * Fx has shape (num, c, h, N), assume (1000, 20, 40, 10).
Fy * (X * Fx) has shape (num, c, N, N), assume (1000, 20, 10, 10).
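Written out element by element (the index letters below are chosen here for illustration, they are not from the question), the target product is

$$
R_{n,j,i,k} = \sum_{a=1}^{h} \sum_{b=1}^{w} (F_y)_{n,i,a}\, X_{1,j,a,b}\, (F_x)_{n,b,k},
$$

where $n$ runs over the num samples, $j$ over the c channels, and $i, k$ over $1, \dots, N$. For a fixed $n$ and $j$ this is just the 2-D product $F_y[n]\, X[1, j]\, F_x[n]$ of an $N \times h$, an $h \times w$ and a $w \times N$ matrix, so the computation itself never needs X or Fx copied once per sample or per channel.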
I am using tf.tile and tf.expand_dims to do it, but I think it uses a lot of memory (tile copies the data, right?) and is slow.
I'm trying to find a better way that is faster and uses less memory.
import tensorflow as tf

# X: (1, c, h, w)
# Fx: (num, w, N)
# Fy: (num, N, h)
c = tf.shape(X)[1]  # number of channels, read from X
X = tf.tile(X, [tf.shape(Fx)[0], 1, 1, 1]) # (num, c, h, w)
Fx_ex = tf.expand_dims(Fx, axis=1) # (num, 1, w, N)
Fx_ex = tf.tile(Fx_ex, [1, c, 1, 1]) # (num, c, w, N)
tmp = tf.matmul(X, Fx_ex) # (num, c, h, N)
Fy_ex = tf.expand_dims(Fy, axis=1) # (num, 1, N, h)
Fy_ex = tf.tile(Fy_ex, [1, c, 1, 1]) # (num, c, N, h)
res = tf.matmul(Fy_ex, tmp) # (num, c, N, N)
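On the memory question: tiling does materialize every copy. The sketch below uses NumPy as a stand-in for TensorFlow (an assumption for illustration; tf.tile likewise returns a new full-size tensor). np.tile stores num real copies of X, np.broadcast_to only builds a zero-stride view, and np.matmul broadcasts leading batch dimensions, so inserting size-1 axes with None replaces all three tiles. Sizes follow the question, with num reduced to 100 so it runs quickly.

```python
import numpy as np

num, c, h, w, N = 100, 20, 40, 50, 10  # question's sizes, num reduced to 100
rng = np.random.default_rng(0)
X = rng.random((1, c, h, w))    # (1, c, h, w)
Fx = rng.random((num, w, N))    # (num, w, N)
Fy = rng.random((num, N, h))    # (num, N, h)

# Tile-based version (mirrors the question): every tile is a real copy
X_tiled = np.tile(X, (num, 1, 1, 1))            # (num, c, h, w)
Fx_tiled = np.tile(Fx[:, None], (1, c, 1, 1))   # (num, c, w, N)
Fy_tiled = np.tile(Fy[:, None], (1, c, 1, 1))   # (num, c, N, h)
res_tile = Fy_tiled @ (X_tiled @ Fx_tiled)      # (num, c, N, N)

# Broadcasting version: size-1 axes are stretched by matmul, nothing is tiled
res_bcast = Fy[:, None] @ (X @ Fx[:, None])     # (num, c, N, N)

print(X_tiled.nbytes // X.nbytes)                    # 100 copies of X
print(np.broadcast_to(X, X_tiled.shape).strides[0])  # 0: a view, no copy
print(np.allclose(res_tile, res_bcast))              # True
```

Newer TensorFlow releases document batch-dimension broadcasting for tf.matmul as well; the tile in the question suggests r1.0 did not, so check the version you are on before relying on it.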
A case for the mythical einsum, I guess:
>>> import numpy as np
>>> X = np.random.rand(1, 20, 40, 50)
>>> Fx = np.random.rand(100, 50, 10)
>>> Fy = np.random.rand(100, 10, 40)
>>> np.einsum('nMh,uchw,nwN->ncMN', Fy, X, Fx).shape
(100, 20, 10, 10)
It should work almost the same in tf as in numpy (using uppercase indices isn't allowed in some tf versions, I saw). Admittedly, though, this exceeds a regex in unreadability if you've never seen the notation before.
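To decode the subscripts in 'nMh,uchw,nwN->ncMN': n is the sample axis shared by Fy and Fx, u is the size-1 leading axis of X, and h and w are the axes contracted by the two matmuls. Any letter missing from the output (u, h, w) is summed over. A minimal NumPy check, with small sizes chosen here so the reference loop is quick, comparing the one-line einsum against two plain 2-D matmuls per sample and channel:

```python
import numpy as np

rng = np.random.default_rng(1)
num, c, h, w, N = 7, 3, 6, 5, 4   # small illustrative sizes
X = rng.random((1, c, h, w))
Fx = rng.random((num, w, N))
Fy = rng.random((num, N, h))

# Letters absent from the output (u, h, w) are summed over
res = np.einsum('nMh,uchw,nwN->ncMN', Fy, X, Fx)

# Reference: Fy[i] @ X[0, j] @ Fx[i] for every sample i and channel j
ref = np.empty_like(res)
for i in range(num):
    for j in range(c):
        ref[i, j] = Fy[i] @ X[0, j] @ Fx[i]

print(res.shape)              # (7, 3, 4, 4)
print(np.allclose(res, ref))  # True
```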
For others who may be interested:
I think @phg's answer may work, but in my case num, h and w are dynamic, i.e. None.
So tf.einsum in TensorFlow r1.0 raises an error, since there is more than one None dimension in one tensor.
Fortunately, there is an issue and a pull request that seem to handle the case of more than one None dimension.
It needs a build from source (master branch); I will report the result after I rebuild TensorFlow.
BTW, tf.einsum only accepts lowercase subscripts.
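Because the subscript letters are arbitrary labels, the uppercase M and N in the answer's equation can simply be renamed. A sketch of one such renaming (M -> m, N -> k, chosen here for illustration), checked in NumPy against the original spelling; passing the lowercase string to tf.einsum is the assumed TensorFlow usage and is not exercised here:

```python
import numpy as np

rng = np.random.default_rng(2)
X = rng.random((1, 3, 6, 5))   # (1, c, h, w)
Fx = rng.random((7, 5, 4))     # (num, w, N)
Fy = rng.random((7, 4, 6))     # (num, N, h)

upper = np.einsum('nMh,uchw,nwN->ncMN', Fy, X, Fx)
# Same contraction using lowercase-only subscripts: M -> m, N -> k
lower = np.einsum('nmh,uchw,nwk->ncmk', Fy, X, Fx)

print(np.allclose(upper, lower))  # True
```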
Report:
Yes, the newest version of TensorFlow (master branch) accepts dynamic shapes in tf.einsum, and switching to tf.einsum gives a huge speed improvement. Really awesome.
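Speed with a three-operand einsum depends heavily on the order in which the operands are contracted. On the NumPy side (this is NumPy's API, not tf.einsum's), np.einsum by default contracts all operands in one loop unless optimize is set, and np.einsum_path reports the pairwise order it would choose plus naive versus optimized cost estimates. A minimal sketch with the answer's sizes:

```python
import numpy as np

rng = np.random.default_rng(3)
X = rng.random((1, 20, 40, 50))
Fx = rng.random((100, 50, 10))
Fy = rng.random((100, 10, 40))
eq = 'nMh,uchw,nwN->ncMN'

# Ask NumPy for a pairwise contraction order (greedy search)
path, report = np.einsum_path(eq, Fy, X, Fx, optimize='greedy')
print(path)    # ['einsum_path', ...] followed by the pairwise contractions
print(report)  # human-readable naive vs optimized cost summary

# Reuse the precomputed path for the actual contraction
res = np.einsum(eq, Fy, X, Fx, optimize=path)
print(res.shape)  # (100, 20, 10, 10)
```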