user15110545

Reputation: 107

TensorFlow multiplication along axis

I want to multiply a tensor by a scalar at only one index along a given axis, like this:

a = tf.ones([2,2,3])
b = tf.constant(7)
c = ...  # multiply a[:,:,1] by b, leave everything else unchanged

so that c[...,0] and c[...,2] have ones but c[...,1] has sevens:

print(c.shape)
> (2, 2, 3)

print(a[...,0])  # same output for a[...,1] and a[...,2]
> tf.Tensor(
[[1. 1.]
 [1. 1.]], shape=(2, 2), dtype=float32)

print(c[...,0])
> tf.Tensor(
[[1. 1.]
 [1. 1.]], shape=(2, 2), dtype=float32)

print(c[...,1])
> tf.Tensor(
[[7. 7.]
 [7. 7.]], shape=(2, 2), dtype=float32)

print(c[...,2])
> tf.Tensor(
[[1. 1.]
 [1. 1.]], shape=(2, 2), dtype=float32)
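One way to get exactly this c is to split a along the last axis, scale only the middle slice, and concatenate the pieces again. This is a minimal sketch, assuming b is a scalar and the target index on the last axis is 1; the variable names k and b_f are illustrative, not from the original. Note that tf.constant(7) is an int32 tensor, and TensorFlow does not promote it to float32 automatically here, so it is cast before multiplying the float32 tensor a.

```python
import tensorflow as tf

a = tf.ones([2, 2, 3])
b = tf.constant(7)

k = 1                       # index along the last axis to scale (illustrative)
b_f = tf.cast(b, a.dtype)   # int32 -> float32 so it can multiply a

# Slicing with k:k + 1 keeps the last axis (size 1), so the three pieces
# can be rejoined with tf.concat along axis -1.
c = tf.concat([a[..., :k], a[..., k:k + 1] * b_f, a[..., k + 1:]], axis=-1)

print(c.shape)     # (2, 2, 3)
print(c[..., 1])   # all sevens; c[..., 0] and c[..., 2] stay ones
```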

Upvotes: 1

Views: 326

Answers (1)

AloneTogether

Reputation: 26708

I'm not quite sure what result you're expecting, but if I understood you correctly, you could do something like this:

import tensorflow as tf

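# Example (4, 4, 3) tensor: slice 0 is ones, slice 1 is threes, slices 2-3 are zeros.
a = tf.concat([tf.ones([1, 4, 3], dtype=tf.float32),
               tf.ones([1, 4, 3], dtype=tf.float32) * 3,
               tf.zeros([2, 4, 3], dtype=tf.float32)], axis=0)
b = tf.constant(7, dtype=tf.float32)

# Take slice 1 along axis 0, keeping that axis, so the shape is (1, 4, 3).
tensor = tf.slice(a,
                  begin=[1, 0, 0],
                  size=[1, 4, 3])
c = tensor * b
# Write the scaled slice back into position 1 along axis 0.
result = tf.tensor_scatter_nd_update(a, [[1]], c)
print('a -->', a, '\n')
print('c -->', c, '\n')
print('result -->', result)

Output: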
a --> tf.Tensor(
[[[1. 1. 1.]
  [1. 1. 1.]
  [1. 1. 1.]
  [1. 1. 1.]]

 [[3. 3. 3.]
  [3. 3. 3.]
  [3. 3. 3.]
  [3. 3. 3.]]

 [[0. 0. 0.]
  [0. 0. 0.]
  [0. 0. 0.]
  [0. 0. 0.]]

 [[0. 0. 0.]
  [0. 0. 0.]
  [0. 0. 0.]
  [0. 0. 0.]]], shape=(4, 4, 3), dtype=float32) 

c --> tf.Tensor(
[[[21. 21. 21.]
  [21. 21. 21.]
  [21. 21. 21.]
  [21. 21. 21.]]], shape=(1, 4, 3), dtype=float32) 

result --> tf.Tensor(
[[[ 1.  1.  1.]
  [ 1.  1.  1.]
  [ 1.  1.  1.]
  [ 1.  1.  1.]]

 [[21. 21. 21.]
  [21. 21. 21.]
  [21. 21. 21.]
  [21. 21. 21.]]

 [[ 0.  0.  0.]
  [ 0.  0.  0.]
  [ 0.  0.  0.]
  [ 0.  0.  0.]]

 [[ 0.  0.  0.]
  [ 0.  0.  0.]
  [ 0.  0.  0.]
  [ 0.  0.  0.]]], shape=(4, 4, 3), dtype=float32)
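The slice, scale, and scatter steps above can be wrapped in a small helper. This is a sketch assuming eager TensorFlow 2.x; scale_slice_axis0 is a hypothetical name, and it only scales slices along axis 0, because tf.tensor_scatter_nd_update with one index per update replaces whole slices of the leading dimension.

```python
import tensorflow as tf

def scale_slice_axis0(a, index, factor):
    """Return a copy of `a` where the slice a[index] (along axis 0) is multiplied by `factor`."""
    factor = tf.cast(factor, a.dtype)
    # a[index:index + 1] keeps the leading axis, giving shape (1,) + a.shape[1:],
    # which is the update shape expected for indices of shape (1, 1).
    updated = a[index:index + 1] * factor
    return tf.tensor_scatter_nd_update(a, [[index]], updated)

a = tf.concat([tf.ones([1, 4, 3]),
               tf.ones([1, 4, 3]) * 3,
               tf.zeros([2, 4, 3])], axis=0)
result = scale_slice_axis0(a, 1, 7)
print(result[1])  # every entry is 21.0; the other slices are unchanged
```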

Update: slicing along the last axis may not select what you think it does:

import tensorflow as tf

a = tf.concat([tf.ones([1, 2, 3], dtype=tf.float32), 
               tf.zeros([1, 2, 3], dtype=tf.float32)], axis=0)

b = tf.ones([2, 2, 3], dtype=tf.float32)

print(a.shape, a[...,0])
print(b.shape, b[...,0])

Output:
(2, 2, 3) tf.Tensor(
[[1. 1.]
 [0. 0.]], shape=(2, 2), dtype=float32)
(2, 2, 3) tf.Tensor(
[[1. 1.]
 [1. 1.]], shape=(2, 2), dtype=float32)

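As you can see, with a (2, 2, 3) tensor holding mixed values, slicing along the last axis does not just return ones as you might expect: a[...,0] takes element 0 of the last axis at every position of the first two axes, so it reflects the values along those axes.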
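In TensorFlow, as in NumPy, `...` expands to as many `:` as needed, so `a[..., 0]` is the same as `a[:, :, 0]`. The sketch below checks that on the mixed tensor above, then scales only index 1 of the last axis by broadcasting a per-channel vector. The tf.one_hot mask is an illustrative technique chosen here, not part of the original answer.

```python
import tensorflow as tf

# Same mixed tensor as above: slice 0 along axis 0 is ones, slice 1 is zeros.
a = tf.concat([tf.ones([1, 2, 3]), tf.zeros([1, 2, 3])], axis=0)

# `...` fills in the leading axes, so these two slices are identical.
same = tf.reduce_all(a[..., 0] == a[:, :, 0])

# Per-channel scale vector [1, 7, 1]: tf.one_hot(1, depth=3) is [0, 1, 0].
scale = 1.0 + (7.0 - 1.0) * tf.one_hot(1, depth=3)
c = a * scale  # shape (2, 2, 3) * (3,) broadcasts over the last axis

print(bool(same))
print(c[..., 1])
```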

Upvotes: 1
