I'm struggling to understand the scatter functions in TensorFlow. For example, I want to use tf.compat.v1.scatter_sub to subtract from the second row of a variable, as follows:
import tensorflow as tf
a = tf.Variable(tf.random.uniform(shape=[2]))
b = tf.Variable(tf.random.uniform(shape=[3, 2]))
Where a is:
<tf.Variable 'Variable:0' shape=(2,) dtype=float32, numpy=array([0.62174237, 0.7431344 ], dtype=float32)>
and b is:
<tf.Variable 'Variable:0' shape=(3, 2) dtype=float32, numpy=
array([[0.01709783, 0.72348535],
[0.48500955, 0.7092271 ],
[0.62199426, 0.26062095]], dtype=float32)>
I want to subtract a from the second row of b so that I get this:
array([[0.01709783, 0.72348535],
[-0.13673282, -0.0339073],
[0.62199426, 0.26062095]], dtype=float32)
I thought that tf.compat.v1.scatter_sub(b, [1], a) should work, but it failed. I also tried transposing a, but that failed too. The full error is:
InvalidArgumentError Traceback (most recent call last)
<ipython-input-37-90775abdb544> in <module>()
9 print("------------------------")
10
---> 11 tf.compat.v1.scatter_sub(b, [1], a)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/state_ops.py in scatter_sub(ref, indices, updates, use_locking, name)
535 return ref._lazy_read(gen_resource_variable_ops.resource_scatter_sub( # pylint: disable=protected-access
536 ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
--> 537 name=name))
538
539
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/gen_resource_variable_ops.py in resource_scatter_sub(resource, indices, updates, name)
1077 try:
1078 return resource_scatter_sub_eager_fallback(
-> 1079 resource, indices, updates, name=name, ctx=_ctx)
1080 except _core._SymbolicException:
1081 pass # Add nodes to the TensorFlow graph.
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/gen_resource_variable_ops.py in resource_scatter_sub_eager_fallback(resource, indices, updates, name, ctx)
1095 _attrs = ("dtype", _attr_dtype, "Tindices", _attr_Tindices)
1096 _result = _execute.execute(b"ResourceScatterSub", 0, inputs=_inputs_flat,
-> 1097 attrs=_attrs, ctx=ctx, name=name)
1098 _result = None
1099 return _result
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
58 ctx.ensure_initialized()
59 tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
---> 60 inputs, attrs, num_outputs)
61 except core._NotOkStatusException as e:
62 if name is not None:
InvalidArgumentError: Must have updates.shape = indices.shape + params.shape[1:] or updates.shape = [], got updates.shape [2], indices.shape [1], params.shape [3,2] [Op:ResourceScatterSub]
What is the correct way to use this function?
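The error message itself states the shape contract: updates.shape must equal indices.shape + params.shape[1:], or updates must be a scalar. The pure-Python sketch below mirrors that rule so it can be checked by hand. The helper names `expected_updates_shape` and `scatter_updates_shape_ok` are hypothetical (they are not TensorFlow functions) and only model the check quoted in the error.

```python
def expected_updates_shape(indices_shape, params_shape):
    """Hypothetical helper: the updates shape the error message asks for."""
    # indices.shape + params.shape[1:]  (tuple concatenation)
    return tuple(indices_shape) + tuple(params_shape[1:])


def scatter_updates_shape_ok(updates_shape, indices_shape, params_shape):
    """Hypothetical helper mirroring the rule quoted in the error message."""
    updates_shape = tuple(updates_shape)
    # A scalar update (shape []) is also accepted by the rule.
    return updates_shape == () or updates_shape == expected_updates_shape(
        indices_shape, params_shape
    )


params_shape = [3, 2]  # shape of b
indices_shape = [1]    # shape of the index list [1]
print(expected_updates_shape(indices_shape, params_shape))            # (1, 2)
print(scatter_updates_shape_ok([2], indices_shape, params_shape))     # False: a as-is
print(scatter_updates_shape_ok([1, 2], indices_shape, params_shape))  # True: a with a leading axis
```

With indices of shape [1] and b of shape [3, 2], the required updates shape is [1, 2], while a has shape [2]; transposing a 1-D vector leaves its shape unchanged, which is consistent with that attempt failing too.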
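I got it. The problem is that updates (here, a) is supposed to be a list of updates, one per entry in indices, but I passed the vector a itself rather than a list containing only a. That is what the error says: updates.shape must be indices.shape + params.shape[1:], i.e. [1] + [2] = [1, 2], but a has shape [2].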
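To make the "list of updates" idea concrete, here is a rough pure-Python model of scatter_sub on a 2-D params: for each position i, updates[i] is subtracted from row indices[i]. This is an illustrative sketch, not TensorFlow's implementation; the function name `scatter_sub_reference` is hypothetical, it only handles 1-D indices, and it returns a copy instead of modifying a variable in place.

```python
def scatter_sub_reference(params, indices, updates):
    """Hypothetical pure-Python model of scatter_sub on a list of rows.

    Subtracts updates[i] from params[indices[i]] for every i, so `updates`
    needs exactly one row per entry in `indices`.
    """
    if len(updates) != len(indices):
        raise ValueError("need one update row per index")
    result = [list(row) for row in params]  # copy; TF would update the variable in place
    for idx, upd in zip(indices, updates):
        result[idx] = [p - u for p, u in zip(result[idx], upd)]
    return result


b = [[0.01709783, 0.72348535],
     [0.48500955, 0.7092271],
     [0.62199426, 0.26062095]]
a = [0.62174237, 0.7431344]

# One index and one update row: row 1 becomes roughly [-0.1367, -0.0339].
print(scatter_sub_reference(b, [1], [a]))
# scatter_sub_reference(b, [1], a) raises ValueError: a has 2 entries, indices has 1.
```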
So I need to add a leading dimension to a: it is currently [0.62174237, 0.7431344] and should become [[0.62174237, 0.7431344]], which tf.expand_dims can do.
So the solution is:
tf.compat.v1.scatter_sub(b, [1], tf.expand_dims(a, axis=0))
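A self-contained version of that fix, as a sketch assuming TensorFlow 2.x in eager mode. The fixed values are copied from the question in place of tf.random.uniform so the result is reproducible.

```python
import tensorflow as tf

# Fixed values from the question instead of tf.random.uniform.
a = tf.Variable([0.62174237, 0.7431344])
b = tf.Variable([[0.01709783, 0.72348535],
                 [0.48500955, 0.7092271],
                 [0.62199426, 0.26062095]])

# updates must have shape indices.shape + b.shape[1:] == [1] + [2] == [1, 2],
# so add a leading axis to a: shape [2] -> [1, 2].
updates = tf.expand_dims(a, axis=0)
print(updates.shape)  # (1, 2)

# Subtract updates[0] from row 1 of b; the variable b is modified in place.
tf.compat.v1.scatter_sub(b, [1], updates)
print(b.numpy())
```

Indexing with a[tf.newaxis, :] or tf.reshape(a, [1, -1]) would produce the same [1, 2] shape; tf.expand_dims just makes the intent explicit.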
Have you looked through the documentation here? https://www.tensorflow.org/api_docs/python/tf/scatter_nd
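The scatter_nd family linked there takes indices of shape [N, index_depth] rather than a flat list of row numbers. A hedged sketch, assuming TensorFlow 2.x where tf.tensor_scatter_nd_sub is available: it operates on a plain tensor and returns a new tensor instead of updating a variable, so row 1 is addressed as [[1]] and the update again needs a leading axis.

```python
import tensorflow as tf

b = tf.constant([[0.01709783, 0.72348535],
                 [0.48500955, 0.7092271],
                 [0.62199426, 0.26062095]])
a = tf.constant([0.62174237, 0.7431344])

# indices has shape [1, 1]: one index with depth 1, i.e. a whole row of b.
indices = tf.constant([[1]])
# updates has shape indices.shape[:-1] + b.shape[1:] == [1] + [2] == [1, 2].
updates = a[tf.newaxis, :]

# Out-of-place: b is left unchanged and a new tensor is returned.
result = tf.tensor_scatter_nd_sub(b, indices, updates)
print(result.numpy())
```

If b has to stay a tf.Variable, the returned tensor can be written back with b.assign(...); the tf.compat.v1.scatter_sub approach above updates the variable directly instead.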