Reputation: 1905
I have two mini-batches of sequences:
a = C.sequence.input_variable((10))
b = C.sequence.input_variable((10))
Both a and b have variable-length sequences.
I want to match them, where matching is defined as applying a score (e.g. a dot product) between the token at each time step of a and the token at every time step of b.
How can I do this?
Upvotes: 0
Views: 137
Reputation: 2050
I have mostly answered this on GitHub, but to be consistent with SO rules I am including a response here. For something simple like a dot product, you can take advantage of the fact that it factorizes nicely, so the following code works:
import cntk as C
axisa = C.Axis.new_unique_dynamic_axis('a')
axisb = C.Axis.new_unique_dynamic_axis('b')
a = C.sequence.input_variable(1, sequence_axis=axisa)
b = C.sequence.input_variable(1, sequence_axis=axisb)
c = C.sequence.broadcast_as(C.sequence.reduce_sum(a), b) * b
c.eval({a: [[1, 2, 3],[4, 5]], b: [[6, 7], [8]]})
[array([[ 36.],
[ 42.]], dtype=float32), array([[ 72.]], dtype=float32)]
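To make the factorization concrete, here is a minimal plain-Python sketch (no CNTK involved) of what the snippet above appears to compute: for each element of b, the sum of its products with every element of a, which equals the sum of a times that element of b. The helper names `naive_sum_of_products` and `factorized` are hypothetical and used only for illustration.

```python
def naive_sum_of_products(a, b):
    """For each b_j, sum a_i * b_j over every a_i (explicit double loop)."""
    return [sum(a_i * b_j for a_i in a) for b_j in b]


def factorized(a, b):
    """Same result via sum_i(a_i * b_j) == (sum_i a_i) * b_j."""
    total = sum(a)  # plays the role of reduce_sum(a), broadcast to every step of b
    return [total * b_j for b_j in b]


# Same toy mini-batch as in the CNTK example: two sequence pairs.
batch_a = [[1, 2, 3], [4, 5]]
batch_b = [[6, 7], [8]]

naive = [naive_sum_of_products(a, b) for a, b in zip(batch_a, batch_b)]
fast = [factorized(a, b) for a, b in zip(batch_a, batch_b)]
print(fast)  # [[36, 42], [72]]
```

Note that this gives one number per step of b (the products summed over a), not a separate score for every (a, b) pair; the general recipe is needed when the individual pairwise scores matter.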
In the general case you need the following steps:
static_b, mask = C.sequence.unpack(b, neutral_value).outputs
scores = your_score(a, static_b)
The first line converts the b sequence into a static tensor with one more axis than b. Because of packing, some elements of this tensor will be invalid, and those are indicated by the mask. The neutral_value is placed as a dummy value in the static_b tensor wherever data was missing. Depending on your score, you might be able to choose the neutral_value so that it does not affect the final score (e.g. if your score is a dot product, 0 would be a good choice; if it involves a softmax, -infinity or something close to that would be a good choice). The second line then has access to each element of a and to all the elements of b, which lie along the first axis of static_b. For a dot product, static_b is a matrix and one element of a is a vector, so a matrix-vector multiplication results in a sequence whose elements are all the inner products between the corresponding element of a and all elements of b.
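The unpack-then-score idea can be emulated in plain Python to see what the shapes and the mask look like. This is only a sketch under stated assumptions: `unpack` and `dot_scores` below are hypothetical stand-ins written for illustration, not the CNTK functions, and nested lists stand in for tensors.

```python
def unpack(batch, neutral_value):
    """Pad every sequence of vectors to the longest length in the batch.

    Returns the padded data plus a 0/1 mask marking which steps are real.
    """
    max_len = max(len(seq) for seq in batch)
    dim = len(batch[0][0])
    padded, mask = [], []
    for seq in batch:
        pad = max_len - len(seq)
        padded.append([list(v) for v in seq] + [[neutral_value] * dim for _ in range(pad)])
        mask.append([1] * len(seq) + [0] * pad)
    return padded, mask


def dot_scores(a_seq, static_b):
    """For each step of a, a matrix-vector product with the padded b matrix."""
    return [[sum(x * y for x, y in zip(b_vec, a_vec)) for b_vec in static_b]
            for a_vec in a_seq]


# Two sequence pairs with 2-dimensional tokens and different lengths.
batch_a = [[[1, 0], [0, 1], [1, 1]], [[2, 3]]]
batch_b = [[[1, 2], [3, 4]], [[5, 6]]]

static_b, mask = unpack(batch_b, 0)  # 0 is neutral for a dot product
scores = [dot_scores(a, sb) for a, sb in zip(batch_a, static_b)]
print(scores)  # [[[1, 3], [2, 4], [3, 7]], [[28, 0]]]
print(mask)    # [[1, 1], [1, 0]]
```

In the second pair, the padded step of b contributes a score of 0, and the mask entry of 0 marks it as invalid, so later steps can ignore it.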
Upvotes: 1