Reputation: 301
I am migrating some code from PyTorch to TensorFlow, and in the function that calculates the loss I have the following line that I need to port to TensorFlow.
state_action_values = net(t_states_features).gather(1, actions_v.unsqueeze(-1)).squeeze(-1)
I found tf.gather
and tf.gather_nd
and I am not sure which is more suitable or how it should be used. Also, the alternative to unsqueeze is probably tf.expand_dims.
In an attempt to get a clearer view of the line's result, I split it into multiple parts with print statements.
print("net result")
state_action_values = net(t_states_features)
print("gather result")
state_action_values = state_action_values.gather(1, actions_v.unsqueeze(-1))
print("last squeeze")
state_action_values = state_action_values.squeeze(-1)
net result
tensor([[ 45.6878, -14.9495, 59.3737],
[ 33.5737, -10.4617, 39.0078],
[ 67.7197, -22.8818, 85.7977],
[ 94.7701, -33.2053, 120.5519],
[ nan, nan, nan],
[ 84.7324, -29.2101, 108.0821],
[ 67.7193, -22.7702, 86.9558],
[113.6835, -38.7149, 142.6167],
[ 61.9260, -20.1968, 79.8010],
[ 51.6152, -17.7391, 66.0719],
[ 73.6565, -21.5699, 98.9463],
[ 84.0761, -26.5016, 107.6888],
[ 60.9459, -20.1257, 76.4105],
[103.2883, -35.4035, 130.4503],
[ 37.1156, -13.5180, 47.1067],
[ nan, nan, nan],
[ 55.6286, -18.5239, 71.9837],
[ 55.3858, -18.7892, 71.1197],
[ 50.2419, -17.2959, 66.7059],
[ 82.5715, -30.0302, 108.4984],
[ -0.8662, -1.1861, 1.6033],
[112.4620, -38.6416, 142.4556],
[ 57.8702, -19.8080, 74.7656],
[ 45.8418, -15.7436, 57.3367],
[ 81.6596, -27.5002, 104.6002],
[ 57.1507, -21.8001, 67.7933],
[ 35.0414, -11.8199, 47.6573],
[ 67.7085, -23.1017, 85.4623],
[ 40.6284, -12.4578, 58.9603],
[ 68.6394, -23.1481, 87.0832],
[ 27.0549, -8.6635, 34.0150],
[ 25.4071, -8.5511, 34.0285],
[ 62.9161, -22.1693, 78.7965],
[ 85.4505, -28.1487, 108.6252],
[ 67.6665, -23.2376, 85.7117],
[ 60.7806, -20.2784, 77.1022],
[ 66.5209, -21.5674, 88.5561],
[ 61.6637, -20.9891, 72.3873],
[ 45.1634, -15.4678, 61.4886],
[ 66.8119, -23.1250, 85.6189],
[ nan, nan, nan],
[ 67.8166, -24.8342, 84.6706],
[ 86.2114, -29.5941, 107.8025],
[ 66.2716, -23.3309, 83.9700],
[101.2122, -35.3554, 127.4772],
[ 61.0749, -19.4720, 78.5588],
[ 50.4058, -16.1262, 63.1010],
[ 27.7543, -9.3767, 35.7448],
[ 67.7810, -23.4962, 83.6030],
[ 35.0103, -11.7238, 44.7983],
[ 55.7402, -19.0223, 70.3627],
[ 67.9733, -22.0783, 85.1893],
[ 60.5253, -20.3157, 79.7312],
[ 67.2404, -21.5205, 81.4499],
[ 57.9502, -20.7747, 70.9109],
[ 87.6536, -31.4256, 112.6491],
[ 90.3668, -30.7755, 116.6192],
[ 59.0660, -19.6988, 75.0723],
[ 50.0969, -17.4135, 62.6556],
[ 28.8703, -9.0950, 34.5749],
[ 68.4053, -22.0715, 88.2302],
[ 69.1397, -21.4236, 84.7833],
[ 23.8506, -8.1834, 30.8318],
[ 58.4296, -20.2432, 73.8116],
[ 87.5317, -29.0606, 110.0389],
[ nan, nan, nan],
[ 88.6387, -30.6154, 112.4239],
[ 51.6089, -16.1073, 66.2757],
[ 94.3989, -32.1473, 119.0358],
[ 82.7449, -30.7778, 102.8537],
[ 74.3067, -26.6585, 98.2536],
[ 77.0881, -26.5706, 98.3553],
[ 28.5688, -9.2949, 41.1165],
[ 86.1560, -26.9364, 107.0244],
[ 41.8914, -16.9703, 57.3840],
[ 88.8886, -29.7008, 108.2697],
[ 61.1243, -20.7566, 77.2257],
[ 85.1174, -28.7558, 107.3853],
[ 81.7256, -27.9047, 104.5006],
[ 51.2663, -16.5880, 67.1428],
[ 46.9150, -12.7457, 61.3240],
[ 36.1758, -12.9769, 47.7178],
[ 85.5846, -29.4141, 107.9649],
[ 59.9424, -20.8349, 75.3359],
[ 62.6516, -22.1235, 81.6903],
[104.7664, -34.5876, 129.9478],
[ 64.4671, -23.3980, 83.9093],
[ 69.6928, -23.6567, 89.6024],
[ 60.4407, -19.6136, 75.9350],
[ 33.4921, -10.3434, 44.9537],
[ 57.9112, -19.4174, 74.3050],
[ 24.8262, -9.3637, 30.1057],
[ 85.3776, -28.9097, 110.1310],
[ 63.8175, -22.3843, 81.0308],
[ 34.6040, -12.3217, 46.0356],
[ 88.3740, -29.5049, 110.2897],
[ 66.8196, -22.5860, 85.5386],
[ 58.9767, -22.0601, 78.7086],
[ 83.2090, -26.3499, 113.5105],
[ 54.8450, -17.7980, 68.1161],
[ nan, nan, nan],
[ 85.0846, -29.2494, 107.6780],
[ 76.9251, -26.2295, 98.4755],
[ 98.2907, -32.8878, 124.9192],
[ 91.1387, -30.8262, 115.3978],
[ 73.1062, -24.9450, 90.0967],
[ 27.6564, -8.6114, 35.4470],
[ 71.8508, -25.1529, 95.5165],
[ 69.7275, -20.1357, 86.9620],
[ 67.0907, -21.9245, 84.8853],
[ 77.3163, -25.5980, 92.7700],
[ 63.0082, -21.0345, 78.7311],
[ 68.0553, -22.4280, 84.8031],
[ 5.8148, -2.3171, 8.0620],
[103.3399, -35.1769, 130.7801],
[ 54.8769, -18.6822, 70.4657],
[ 58.4446, -18.9764, 75.5509],
[ 91.0071, -31.2706, 112.6401],
[ 84.6577, -29.2644, 104.6046],
[ 45.4887, -15.8309, 59.0498],
[ 56.3384, -18.9264, 78.8834],
[ 63.5109, -21.3169, 81.5144],
[ 79.4635, -29.8681, 100.5056],
[ 27.6559, -10.0517, 35.6012],
[ 76.3909, -24.1689, 93.6133],
[ 34.3802, -11.5272, 45.8650],
[ 60.3553, -20.1693, 76.5371],
[ 56.0590, -18.6468, 69.8981]], grad_fn=<AddmmBackward0>)
gather result
tensor([[ 59.3737],
[ 67.7197],
[ 94.7701],
[ nan],
[ 67.7193],
[ 66.0719],
[ 98.9463],
[ 47.1067],
[ nan],
[ 55.6286],
[ 66.7059],
[ 1.6033],
[ 74.7656],
[ 81.6596],
[ 35.0414],
[ 40.6284],
[ 68.6394],
[ 34.0150],
[ 34.0285],
[ 78.7965],
[ 67.6665],
[ 72.3873],
[ 85.6189],
[ nan],
[ -9.3767],
[ 70.3627],
[ 67.2404],
[ 50.0969],
[ 34.5749],
[ 88.2302],
[ -8.1834],
[ 73.8116],
[ nan],
[ 98.2536],
[ 98.3553],
[ 28.5688],
[ 77.2257],
[ 67.1428],
[ 47.7178],
[ 59.9424],
[ 75.9350],
[ 30.1057],
[ 85.3776],
[ 63.8175],
[ 46.0356],
[ nan],
[ 76.9251],
[ 35.4470],
[ 95.5165],
[ 86.9620],
[ 78.7311],
[ 5.8148],
[ 70.4657],
[ 58.4446],
[ 91.0071],
[ 45.4887],
[ 63.5109],
[ 79.4635],
[ 76.3909],
[ 34.3802],
[-18.6468]], grad_fn=<GatherBackward0>)
last squeeze
tensor([ 59.3737, -10.4617, 67.7197, 94.7701, nan, -29.2101, 67.7193,
-38.7149, -20.1968, 66.0719, 98.9463, 107.6888, -20.1257, -35.4035,
47.1067, nan, 55.6286, -18.7892, 66.7059, -30.0302, 1.6033,
112.4620, 74.7656, -15.7436, 81.6596, -21.8001, 35.0414, -23.1017,
40.6284, 68.6394, 34.0150, 34.0285, 78.7965, -28.1487, 67.6665,
-20.2784, -21.5674, 72.3873, -15.4678, 85.6189, nan, -24.8342,
-29.5941, -23.3309, 101.2122, -19.4720, -16.1262, -9.3767, -23.4962,
-11.7238, 70.3627, -22.0783, -20.3157, 67.2404, -20.7747, 112.6491,
-30.7755, -19.6988, 50.0969, 34.5749, 88.2302, -21.4236, -8.1834,
73.8116, 110.0389, nan, 112.4239, -16.1073, -32.1473, -30.7778,
98.2536, 98.3553, 28.5688, 107.0244, -16.9703, -29.7008, 77.2257,
-28.7558, -27.9047, 67.1428, -12.7457, 47.7178, -29.4141, 59.9424,
-22.1235, 129.9478, -23.3980, -23.6567, 75.9350, -10.3434, -19.4174,
30.1057, 85.3776, 63.8175, 46.0356, -29.5049, -22.5860, -22.0601,
113.5105, -17.7980, nan, -29.2494, 76.9251, -32.8878, 115.3978,
-24.9450, 35.4470, 95.5165, 86.9620, -21.9245, -25.5980, 78.7311,
-22.4280, 5.8148, 103.3399, 70.4657, 58.4446, 91.0071, 104.6046,
45.4887, -18.9264, 63.5109, 79.4635, -10.0517, 76.3909, 34.3802,
-20.1693, -18.6468], grad_fn=<SqueezeBackward1>)
Edit 1: print of actions_v
tensor([2, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 2, 0, 2, 0, 1,
2, 0, 2, 1, 1, 0, 2, 1, 0, 0, 2, 1, 1, 1, 0, 1, 0, 1, 1, 2, 1, 1, 2, 1,
0, 2, 1, 2, 0, 2, 2, 0, 0, 1, 2, 0, 1, 2, 0, 0, 1, 1, 2, 0, 0, 2, 0, 0,
1, 1, 2, 0, 1, 0, 2, 0, 1, 0, 0, 0, 0, 1, 2, 0, 2, 0, 1, 1, 2, 1, 2, 2,
2, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 2, 1, 1, 0, 1, 0, 1, 2,
2, 1, 0, 2, 0, 0, 2, 1])
Upvotes: 2
Views: 394
Reputation: 136
tf.gather_nd takes indices that address every dimension of the input tensor (one coordinate per dimension), and outputs a tensor of the values at those indices (which is what you want).
tf.gather will output slices (you can give the indices whatever shape you want; the output tensor will just be a bunch of slices structured according to the shape of the indices), which is not what you want.
So you should first make the indices match the dimensions of the initial matrix:
indices = tf.transpose(tf.stack((tf.range(tf.shape(state_action_values)[0]),actions_v)))
And then apply tf.gather_nd:
state_action_values = tf.gather_nd(state_action_values,indices)
Upvotes: 1