Peyman

Reputation: 4209

TensorFlow 2: How to use scatter function?

I'm struggling to understand the scatter functions in TensorFlow. For example, I want to use tf.compat.v1.scatter_sub to subtract from the second row (index 1) of a variable, as follows:

import tensorflow as tf

a = tf.Variable(tf.random.uniform(shape=[2]))
b = tf.Variable(tf.random.uniform(shape=[3, 2]))

Where a is:

<tf.Variable 'Variable:0' shape=(2,) dtype=float32, numpy=array([0.62174237, 0.7431344 ], dtype=float32)>

and b is:

<tf.Variable 'Variable:0' shape=(3, 2) dtype=float32, numpy=
array([[0.01709783, 0.72348535],
       [0.48500955, 0.7092271 ],
       [0.62199426, 0.26062095]], dtype=float32)>

I want to subtract a from the second row of b so that I can achieve this:

array([[ 0.01709783,  0.72348535],
       [-0.13673282, -0.0339073 ],
       [ 0.62199426,  0.26062095]], dtype=float32)

I thought that tf.compat.v1.scatter_sub(b, [1], a) would work, but it failed. Transposing a didn't help either. The full error is:

InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-37-90775abdb544> in <module>()
      9 print("------------------------")
     10 
---> 11 tf.compat.v1.scatter_sub(b, [1], a)

/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/state_ops.py in scatter_sub(ref, indices, updates, use_locking, name)
    535   return ref._lazy_read(gen_resource_variable_ops.resource_scatter_sub(  # pylint: disable=protected-access
    536       ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
--> 537       name=name))
    538 
    539 

/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/gen_resource_variable_ops.py in resource_scatter_sub(resource, indices, updates, name)
   1077     try:
   1078       return resource_scatter_sub_eager_fallback(
-> 1079           resource, indices, updates, name=name, ctx=_ctx)
   1080     except _core._SymbolicException:
   1081       pass  # Add nodes to the TensorFlow graph.

/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/gen_resource_variable_ops.py in resource_scatter_sub_eager_fallback(resource, indices, updates, name, ctx)
   1095   _attrs = ("dtype", _attr_dtype, "Tindices", _attr_Tindices)
   1096   _result = _execute.execute(b"ResourceScatterSub", 0, inputs=_inputs_flat,
-> 1097                              attrs=_attrs, ctx=ctx, name=name)
   1098   _result = None
   1099   return _result

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     58     ctx.ensure_initialized()
     59     tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
---> 60                                         inputs, attrs, num_outputs)
     61   except core._NotOkStatusException as e:
     62     if name is not None:

InvalidArgumentError: Must have updates.shape = indices.shape + params.shape[1:] or updates.shape = [], got updates.shape [2], indices.shape [1], params.shape [3,2] [Op:ResourceScatterSub]

What is the correct way to use this function?

Upvotes: 0

Views: 1617

Answers (2)

Peyman

Reputation: 4209

I got it. The problem is that updates (a here) is supposed to hold one update per index, i.e. a list of updates, but I passed the vector a itself rather than a list containing only a.
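To make the shape rule from the error message concrete, here is a small pure-Python sketch (no TensorFlow needed) that checks a proposed updates shape against updates.shape = indices.shape + params.shape[1:]. The helper names expected_updates_shape and is_valid_updates_shape are hypothetical, written only for illustration.

```python
def expected_updates_shape(indices_shape, params_shape):
    """Shape `updates` must have, per the error message:
    updates.shape = indices.shape + params.shape[1:]."""
    return list(indices_shape) + list(params_shape[1:])


def is_valid_updates_shape(updates_shape, indices_shape, params_shape):
    # The error message also accepts a scalar update (shape []).
    return list(updates_shape) == [] or list(updates_shape) == expected_updates_shape(
        indices_shape, params_shape
    )


params_shape = [3, 2]  # shape of b
indices_shape = [1]    # shape of the index list [1]

print(expected_updates_shape(indices_shape, params_shape))          # [1, 2]
print(is_valid_updates_shape([2], indices_shape, params_shape))     # False: a itself
print(is_valid_updates_shape([1, 2], indices_shape, params_shape))  # True: a wrapped in a list
```

With indices [1], the rule asks for shape [1, 2], which is why passing a with shape [2] is rejected.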

So I need to add a leading dimension to a: it is currently [0.62174237, 0.7431344] (shape [2]) and should become [[0.62174237, 0.7431344]] (shape [1, 2]). tf.expand_dims does exactly this.
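A quick way to confirm the shape change, assuming TensorFlow 2.x in eager mode. tf.expand_dims is what this answer uses; indexing with tf.newaxis and tf.reshape are shown only as equivalent alternatives.

```python
import tensorflow as tf

a = tf.constant([0.62174237, 0.7431344])  # shape (2,)

# Three equivalent ways to turn shape (2,) into shape (1, 2):
a1 = tf.expand_dims(a, axis=0)  # insert a new leading axis
a2 = a[tf.newaxis, :]           # NumPy-style indexing with tf.newaxis
a3 = tf.reshape(a, [1, -1])     # explicit reshape; -1 infers the length

print(a.shape, a1.shape, a2.shape, a3.shape)
```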

So the solution is:

tf.compat.v1.scatter_sub(b, [1], tf.expand_dims(a, axis=0))
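A self-contained version of the fix, assuming TensorFlow 2.x with eager execution. It uses the fixed values printed in the question instead of tf.random.uniform so the result can be checked against the expected output. The second call sketches an alternative that follows from the same shape rule: with a scalar index, indices.shape is [], so updates needs shape [2], which a already has.

```python
import tensorflow as tf

# Fixed values copied from the question (instead of tf.random.uniform).
b0 = tf.constant([[0.01709783, 0.72348535],
                  [0.48500955, 0.7092271],
                  [0.62199426, 0.26062095]])
a = tf.Variable([0.62174237, 0.7431344])
b = tf.Variable(b0)

# indices has shape [1], so updates must have shape [1] + b.shape[1:] = [1, 2].
tf.compat.v1.scatter_sub(b, [1], tf.expand_dims(a, axis=0))
print(b.numpy())  # row 1 is now approx. [-0.13673282, -0.0339073]

# Alternative: a scalar index has shape [], so updates must have shape
# [] + b.shape[1:] = [2], which is exactly the shape of a.
c = tf.Variable(b0)
tf.compat.v1.scatter_sub(c, 1, a)
print(c.numpy())  # same result as b
```

Both calls modify the variable in place; rows 0 and 2 are left unchanged.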

Upvotes: 1

user14518353

Reputation:

Have you had a look through the documentation here? https://www.tensorflow.org/api_docs/python/tf/scatter_nd

Args:

  • indices: A Tensor. Must be one of the following types: int32, int64. Index tensor.
  • updates: A Tensor. Updates to scatter into output.
  • shape: A Tensor. Must have the same type as indices. 1-D. The shape of the resulting tensor.
  • name: A name for the operation (optional).
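Since the documentation linked above is for tf.scatter_nd rather than tf.compat.v1.scatter_sub, here is a sketch of how that op (and the related tf.tensor_scatter_nd_sub) could produce the same result, assuming TensorFlow 2.x. Unlike scatter_sub, these return a new tensor instead of modifying a variable in place, and each row of indices is a full index into b, so row 1 is addressed as [[1]].

```python
import tensorflow as tf

a = tf.constant([0.62174237, 0.7431344])
b = tf.constant([[0.01709783, 0.72348535],
                 [0.48500955, 0.7092271],
                 [0.62199426, 0.26062095]])

indices = tf.constant([[1]])         # shape [1, 1]: one update, addressing row 1
updates = tf.expand_dims(a, axis=0)  # shape [1, 2]: one full row per index

# Option 1: tf.scatter_nd builds a zero tensor of the given shape and writes
# `updates` into it; subtracting it from b leaves the other rows unchanged.
# tf.shape(b) is int32, matching the dtype of `indices` as the docs require.
delta = tf.scatter_nd(indices, updates, shape=tf.shape(b))
result1 = b - delta

# Option 2: tf.tensor_scatter_nd_sub does the same in one call and returns
# a new tensor (b itself is not modified).
result2 = tf.tensor_scatter_nd_sub(b, indices, updates)

print(result1.numpy())
print(result2.numpy())
```

If b is a tf.Variable, as in the question, the new value can be written back with b.assign(...).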

Upvotes: 0
