I am working on a conditional computation framework using MXNet. Assume that we have N samples in our minibatch. I need to execute the following kind of operations in my computational graph (in pseudocode):
x = graph.Variable("x")
y = graph.DoSomeTranformations(x)
# The following operation generates an N x k matrix: k responses for each sample.
z = graph.DoDecision(y)
for i in range(k):
    argmax_sample_indices_for_i = graph.ArgMaxIndices(z, i)
    y_selected_samples = graph.TakeSelectedSample(y, argmax_sample_indices_for_i)
    result = graph.DoSomeTransformations(y_selected_samples)
What I want to achieve is the following. After I obtain y, I apply a decision function (for example, a D-to-k fully connected layer, where D is the data dimension) and obtain k activations for each sample in my minibatch of size N. Then I want to dynamically split the minibatch into k parts (k is a small number such as 2 or 3), based on the column index of the maximum activation of each sample. My hypothetical "graph.ArgMaxIndices" function does exactly that: given z, an N x k matrix, and a column index i, it returns the indices of the samples whose maximum activation lies in column i. (Note that I am looking for any function or combination of functions that gives a result equivalent to "graph.ArgMaxIndices", not necessarily a single function.) Finally, for each i, I select the samples whose maximum activation is in column i and apply specific transformations to them.
To the best of my knowledge, MXNet currently does not support this kind of conditional computation in its symbolic networks. Therefore, I build separate symbolic graphs after each decision and have to write my own bookkeeping code (conditional graph structures for each minibatch split), which leads to 1) very complex and cumbersome code that is hard to maintain and develop, and 2) degraded runtime performance during training and evaluation.
My question is: can I do the above using the symbolic operators of TensorFlow? Does it allow one to select subsets of the minibatch based on a criterion? Is there a function, or a series of functions, equivalent to "graph.ArgMaxIndices" in the pseudocode above? (Given an N x k matrix and a column index i, it returns the indices of the rows whose maximum activation is in column i.)
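For reference, the intended semantics of the hypothetical "graph.ArgMaxIndices" can be written eagerly in NumPy. This is only an illustration on made-up data: argmax_indices and z below are hypothetical names, not part of MXNet or TensorFlow, and this is not a symbolic graph operation.

```python
import numpy as np

def argmax_indices(z, i):
    """Hypothetical helper: row indices of the N x k matrix z whose largest activation is in column i."""
    winners = np.argmax(z, axis=1)        # winning column per row (NumPy picks the first one on ties)
    return np.flatnonzero(winners == i)   # rows whose winning column is i

# Made-up activations for N = 4 samples and k = 3 branches.
z = np.array([[0.1, 0.9, 0.0],
              [0.7, 0.2, 0.1],
              [0.3, 0.3, 0.4],
              [0.0, 0.8, 0.2]])

for i in range(z.shape[1]):
    print(i, argmax_indices(z, i).tolist())
# 0 [1]
# 1 [0, 3]
# 2 [2]
```

Every row lands in exactly one of the k index sets, so the k subsets together form a partition of the minibatch.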
You can do that in TensorFlow.
The best way I see is to use a mask and call tf.boolean_mask k times, where the i-th mask is given by tf.equal(tf.argmax(z, axis=-1), i).
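import tensorflow as tf

x = graph.Variable("x")
y = graph.DoSomeTranformations(x)
# The following operation generates an N x k matrix: k responses for each sample.
z = graph.DoDecision(y)
# Column index of the maximum activation for each sample (shape [N], int64).
max_indices = tf.argmax(z, axis=-1)
results = []
for i in range(k):
    # Boolean mask of shape [N]: True for the samples whose argmax is column i.
    # The tensor goes first so that the Python int i is converted to its int64 dtype.
    mask_for_i = tf.equal(max_indices, i)
    y_selected_samples = tf.boolean_mask(y, mask=mask_for_i)
    results.append(graph.DoSomeTransformations(y_selected_samples))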
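A self-contained sketch of the same idea, assuming TensorFlow 2.x with eager execution. The sizes N, D, k, the random y and the random weight matrix W are made-up stand-ins for the question's graph.DoSomeTranformations and graph.DoDecision. It also uses tf.where(mask), which returns the row indices themselves (the direct analogue of "graph.ArgMaxIndices"), and, as an alternative not taken from the answer above, tf.dynamic_partition / tf.dynamic_stitch to split the minibatch into k parts in one call and put the pieces back in the original sample order.

```python
import tensorflow as tf

# Illustrative sizes (assumptions): minibatch N, feature dimension D, k branches.
N, D, k = 8, 4, 3

# Stand-in for the question's transformed features y (random data).
y = tf.random.normal([N, D])
# Stand-in for graph.DoDecision: a D -> k linear layer with a random weight matrix.
W = tf.random.normal([D, k])
z = tf.matmul(y, W)  # N x k decision activations

# Winning column per sample; int32 because tf.dynamic_partition requires int32 partitions.
max_indices = tf.argmax(z, axis=-1, output_type=tf.int32)

selected_indices = []  # the "ArgMaxIndices" result for each branch
branch_inputs = []     # rows of y routed to each branch
for i in range(k):
    mask = tf.equal(max_indices, i)                   # bool, shape [N]
    rows = tf.reshape(tf.where(mask), [-1])           # int64 indices of the True rows
    selected_indices.append(rows)
    branch_inputs.append(tf.boolean_mask(y, mask))    # same rows as tf.gather(y, rows)

# Alternative: split the whole minibatch in one op, then restore the original order.
parts = tf.dynamic_partition(y, max_indices, num_partitions=k)
positions = tf.dynamic_partition(tf.range(N), max_indices, num_partitions=k)
restored = tf.dynamic_stitch(positions, parts)        # equals y when parts are unchanged

for i in range(k):
    print(i, selected_indices[i].numpy(), branch_inputs[i].shape)
```

In a real model, each part would be passed through its branch-specific transformation before tf.dynamic_stitch, so the per-sample outputs come back in minibatch order. Note that tf.argmax does not guarantee which index is returned on ties, so a sample with tied activations may be routed to either branch.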