user358829

Reputation: 761

TensorFlow custom op -- how do I read from and write to Tensors?

I'm writing a custom TensorFlow op by following the tutorial, and I'm having trouble understanding how to read from and write to Tensors.

Let's say I have a Tensor in my OpKernel that I get from const Tensor& values_tensor = context->input(0); (where context is an OpKernelContext*).

If that Tensor has shape, say, [2, 10, 20], how can I index into it (e.g. auto x = values_tensor[1, 4, 12], etc.)?

Equivalently, if I have

Tensor *output_tensor = NULL;
OP_REQUIRES_OK(context, context->allocate_output(
  0,
  {batch_size, value_len - window_size, window_size},
  &output_tensor
));

how can I assign to output_tensor, like output_tensor[1, 2, 3] = 11, etc.?

Sorry for the dumb question, but the docs are really tripping me up here, and the examples in the TensorFlow kernel code for built-in ops obscure this to the point that I get very confused :)

Thank you!

Upvotes: 3

Views: 417

Answers (1)

mrry

Reputation: 126154

The easiest way to read from and write to tensorflow::Tensor objects is to convert them to an Eigen tensor using the tensorflow::Tensor::tensor<T, NDIMS>() method. Note that you have to specify the (C++) element type of the tensor as the template parameter T, and its rank as NDIMS.

For example, to read a particular value from a DT_FLOAT tensor:

const Tensor& values_tensor = context->input(0);
auto x = values_tensor.tensor<float, 3>()(1, 4, 12);

To write a particular value to a DT_FLOAT tensor:

Tensor* output_tensor = ...;
output_tensor->tensor<float, 3>()(1, 2, 3) = 11.0f;

There are also convenience methods for accessing a scalar, vector, or matrix: scalar<T>(), vec<T>(), and matrix<T>().
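If it helps to see what an index like (1, 4, 12) resolves to, here is a rough, TensorFlow-free sketch. It relies on TensorFlow tensors being stored in row-major order, so the last index varies fastest. Tensor3 is a made-up stand-in written only for this illustration; it is not a TensorFlow or Eigen type, and real kernels should keep using tensor<T, NDIMS>() as above.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a rank-3 float tensor (NOT a TensorFlow type).
// It only illustrates row-major indexing over a flat buffer.
struct Tensor3 {
  std::size_t d0, d1, d2;   // shape, e.g. [2, 10, 20]
  std::vector<float> data;  // flat storage, zero-initialized

  Tensor3(std::size_t a, std::size_t b, std::size_t c)
      : d0(a), d1(b), d2(c), data(a * b * c, 0.0f) {}

  // Row-major offset: the last index (k) varies fastest.
  std::size_t offset(std::size_t i, std::size_t j, std::size_t k) const {
    assert(i < d0 && j < d1 && k < d2);  // bounds check in debug builds
    return (i * d1 + j) * d2 + k;
  }

  // Writable element access, mirroring the t(i, j, k) call syntax.
  float& operator()(std::size_t i, std::size_t j, std::size_t k) {
    return data[offset(i, j, k)];
  }

  // Read-only element access for const tensors.
  float operator()(std::size_t i, std::size_t j, std::size_t k) const {
    return data[offset(i, j, k)];
  }
};
```

With a shape of [2, 10, 20], Tensor3 t(2, 10, 20); t(1, 2, 3) = 11.0f; writes into the flat buffer at offset (1 * 10 + 2) * 20 + 3.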

Upvotes: 1
