I want to multiply only one index along a given axis by a scalar, like this:
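import tensorflow as tf
a = tf.ones([2, 2, 3])
b = tf.constant(7, dtype=tf.float32)  # float32 to match a; tf.constant(7) alone is int32
c = ...  # multiply only a[:, :, 1] by b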
so that c[..., 0] and c[..., 2] contain ones but c[..., 1] contains sevens:
print(c.shape)
> (2, 2, 3)
print(a[..., 0])  # output is the same for a[..., 1] and a[..., 2]
> tf.Tensor(
[[1. 1.]
[1. 1.]], shape=(2, 2), dtype=float32)
print(c[...,0])
>tf.Tensor(
[[1. 1.]
[1. 1.]], shape=(2, 2), dtype=float32)
print(c[...,1])
>tf.Tensor(
[[7. 7.]
[7. 7.]], shape=(2, 2), dtype=float32)
print(c[...,2])
>tf.Tensor(
[[1. 1.]
[1. 1.]], shape=(2, 2), dtype=float32)
I'm not quite sure what result you're expecting, but if I understood you correctly, you could do something like this:
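import tensorflow as tf
# Shape (4, 4, 3): a[0] is all ones, a[1] is all threes, a[2:] is all zeros
a = tf.concat([tf.ones([1, 4, 3], dtype=tf.float32),
               tf.ones([1, 4, 3], dtype=tf.float32) * 3,
               tf.zeros([2, 4, 3], dtype=tf.float32)], axis=0)
b = tf.constant(7, dtype=tf.float32)
# Take the block at index 1 along the first axis, shape (1, 4, 3)
tensor = tf.slice(a,
                  begin=[1, 0, 0],
                  size=[1, 4, 3])
c = tensor * b
# Write the scaled block back into a copy of a at index 1 of the first axis
result = tf.tensor_scatter_nd_update(a, [[1]], c)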
print('a -->', a, '\n')
print('c -->', c, '\n')
print('result -->', result)
a --> tf.Tensor(
[[[1. 1. 1.]
[1. 1. 1.]
[1. 1. 1.]
[1. 1. 1.]]
[[3. 3. 3.]
[3. 3. 3.]
[3. 3. 3.]
[3. 3. 3.]]
[[0. 0. 0.]
[0. 0. 0.]
[0. 0. 0.]
[0. 0. 0.]]
[[0. 0. 0.]
[0. 0. 0.]
[0. 0. 0.]
[0. 0. 0.]]], shape=(4, 4, 3), dtype=float32)
c --> tf.Tensor(
[[[21. 21. 21.]
[21. 21. 21.]
[21. 21. 21.]
[21. 21. 21.]]], shape=(1, 4, 3), dtype=float32)
result --> tf.Tensor(
[[[ 1. 1. 1.]
[ 1. 1. 1.]
[ 1. 1. 1.]
[ 1. 1. 1.]]
[[21. 21. 21.]
[21. 21. 21.]
[21. 21. 21.]
[21. 21. 21.]]
[[ 0. 0. 0.]
[ 0. 0. 0.]
[ 0. 0. 0.]
[ 0. 0. 0.]]
[[ 0. 0. 0.]
[ 0. 0. 0.]
[ 0. 0. 0.]
[ 0. 0. 0.]]], shape=(4, 4, 3), dtype=float32)
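The code above scales index 1 along the first axis. If the target is index 1 along the last axis, as in the question, the same tf.tensor_scatter_nd_update idea can be adapted. A minimal sketch, assuming a float32 input of shape (2, 2, 3): because tf.tensor_scatter_nd_update indexes the leading dimensions of its input, move the last axis to the front with tf.transpose, do the update, then transpose back. The names a_t, updates and result_t are illustrative.

```python
import tensorflow as tf

a = tf.ones([2, 2, 3], dtype=tf.float32)
b = tf.constant(7, dtype=tf.float32)

# tf.tensor_scatter_nd_update indexes the leading dimensions, so move the
# last axis to the front: shape (2, 2, 3) -> (3, 2, 2).
a_t = tf.transpose(a, perm=[2, 0, 1])

# Scale channel 1 only. The updates shape must equal
# indices.shape[:-1] + a_t.shape[1:] = (1,) + (2, 2) = (1, 2, 2).
updates = a_t[1:2] * b
result_t = tf.tensor_scatter_nd_update(a_t, [[1]], updates)

# Move the channel axis back to the end: (3, 2, 2) -> (2, 2, 3).
result = tf.transpose(result_t, perm=[1, 2, 0])

print(result.shape)    # (2, 2, 3)
print(result[..., 1])  # all sevens
print(result[..., 0])  # all ones
```

Like the original answer, this returns a new tensor rather than modifying a in place.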
Update: slicing along the last axis does not select what you might think:
import tensorflow as tf
a = tf.concat([tf.ones([1, 2, 3], dtype=tf.float32),
tf.zeros([1, 2, 3], dtype=tf.float32)], axis=0)
b = tf.ones([2, 2, 3], dtype=tf.float32)
print(a.shape, a[...,0])
print(b.shape, b[...,0])
(2, 2, 3) tf.Tensor(
[[1. 1.]
[0. 0.]], shape=(2, 2), dtype=float32)
(2, 2, 3) tf.Tensor(
[[1. 1.]
[1. 1.]], shape=(2, 2), dtype=float32)
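As you can see, with a (2, 2, 3) tensor containing mixed values, slicing along the last axis with a[..., 0] does not return only ones, as you might expect: each row of the (2, 2) result comes from a different index along the first axis, so it mixes the ones from a[0] with the zeros from a[1].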
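Since a[..., 1] picks index 1 of the last axis at every (i, j) position, a simpler way to scale only that slice, sketched here as an alternative rather than the approach above, is to multiply by a per-channel vector that broadcasts along the last axis. This assumes a and b are both float32 (tf.constant(7) on its own is int32, and TensorFlow will not implicitly cast it to float32); mask and scale are illustrative names.

```python
import tensorflow as tf

a = tf.ones([2, 2, 3], dtype=tf.float32)
b = tf.constant(7, dtype=tf.float32)

# One-hot vector marking the channel to scale: [0., 1., 0.] (float32 by default).
mask = tf.one_hot(1, depth=3)

# Per-channel multiplier: b at index 1, 1 everywhere else -> [1., 7., 1.].
scale = 1.0 + mask * (b - 1.0)

# Broadcasting (2, 2, 3) * (3,) applies scale[k] to every a[i, j, k].
c = a * scale

print(c.shape)    # (2, 2, 3)
print(c[..., 1])  # all sevens
print(c[..., 0])  # all ones
```

Because the multiplier is an ordinary tensor, this works for any leading shape and does not need to write into a copy with tf.tensor_scatter_nd_update.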