Vedanshu
Vedanshu

Reputation: 2296

Move axis in tensorflow

I have two tensors. The main tensor is as follows:

array([[[ 298, 1217,  298, 1217],
        [ 298, 1217,  298, 1217],
        [ 298, 1217,  298, 1217],
        [ 298, 1217,  298, 1217],
        [ 298, 1217,  298, 1217],
        [ 298, 1217,  298, 1217],
        [ 298, 1217,  298, 1217],
        [ 298, 1217,  298, 1217],
        [ 298, 1217,  298, 1217],
        [ 298, 1217,  298, 1217]],

       [[ 450,  607,  493,  662],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0]],

       [[ 950, 1277, 1028, 1335],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0]],

       [[   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0]],

       [[   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0]],

       [[   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0]],

       [[   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0]],

       [[   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0]],

       [[   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0]]], dtype=int32)

I want to rearrange the slices of this tensor along its first axis according to the following index tensor:

array([0, 2, 5], dtype=int32)

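Each entry of the above tensor is the new position, along the first axis, for the corresponding slice of the main tensor. Slice 0 stays at index 0, slice 1 moves to index 2, and slice 2 moves to index 5. All other positions should be filled with zeros.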

The final tensor should look like this:

array([[[ 298, 1217,  298, 1217],
        [ 298, 1217,  298, 1217],
        [ 298, 1217,  298, 1217],
        [ 298, 1217,  298, 1217],
        [ 298, 1217,  298, 1217],
        [ 298, 1217,  298, 1217],
        [ 298, 1217,  298, 1217],
        [ 298, 1217,  298, 1217],
        [ 298, 1217,  298, 1217],
        [ 298, 1217,  298, 1217]],

       [[   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0]],

       [[ 450,  607,  493,  662],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0]],

       [[   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0]],

       [[   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0]],

       [[ 950, 1277, 1028, 1335],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0]],

       [[   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0]],

       [[   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0]],

       [[   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0]]], dtype=int32)

Upvotes: 1

Views: 629

Answers (1)

Safwan
Safwan

Reputation: 3428

You can use TensorFlow's scatter function `tf.scatter_nd` to achieve this.

Define your input tensor:

input = tf.constant([[[ 298, 1217,  298, 1217],
    [ 298, 1217,  298, 1217],
    [ 298, 1217,  298, 1217],
    [ 298, 1217,  298, 1217],
    [ 298, 1217,  298, 1217],
    [ 298, 1217,  298, 1217],
    [ 298, 1217,  298, 1217],
    [ 298, 1217,  298, 1217],
    [ 298, 1217,  298, 1217],
    [ 298, 1217,  298, 1217]],

   [[ 450,  607,  493,  662],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0]],

   [[ 950, 1277, 1028, 1335],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0]],

   [[   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0]],

   [[   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0]],

   [[   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0]],

   [[   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0]],

   [[   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0]],

   [[   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0]]])

Since only the first 3 slices along the zeroth dimension are being moved, slice them into a new tensor. The number of slices taken here must equal the number of target indices:

sliced_input = tf.slice(input, [0, 0, 0], [3, -1, -1])

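Define your target indices as an N×1 tensor (one index per row), so that each index addresses a whole slice along the zeroth dimension: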

indices = tf.constant([[0], [2], [5]])

Define the shape of your target output, which here is the same as the input shape:

shape = tf.shape(input)

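Now use the scatter function to get your output. `tf.scatter_nd` starts from a zero tensor of the given shape and writes each slice of `sliced_input` at the position given by the corresponding row of `indices`. Positions that receive no slice stay zero: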

output = tf.scatter_nd(indices, sliced_input, shape)

output:

array([[[ 298, 1217,  298, 1217],
    [ 298, 1217,  298, 1217],
    [ 298, 1217,  298, 1217],
    [ 298, 1217,  298, 1217],
    [ 298, 1217,  298, 1217],
    [ 298, 1217,  298, 1217],
    [ 298, 1217,  298, 1217],
    [ 298, 1217,  298, 1217],
    [ 298, 1217,  298, 1217],
    [ 298, 1217,  298, 1217]],

   [[   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0]],

   [[ 450,  607,  493,  662],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0]],

   [[   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0]],

   [[   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0]],

   [[ 950, 1277, 1028, 1335],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0]],

   [[   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0]],

   [[   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0]],

   [[   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0]]], dtype=int32)

Upvotes: 1

Related Questions