alpatters

Reputation: 31

Serializing a tensor and writing to tfrecord from within a graph

I would like to write TensorFlow Example records with a TFRecordWriter from inside an AutoGraph-generated graph.

The documentation for TensorFlow 2.0 states the following:

The simplest way to handle non-scalar features is to use tf.serialize_tensor to convert tensors to binary-strings. Strings are scalars in tensorflow.

However, tf.io.serialize_tensor returns a scalar string tensor, not a Python bytes object. Creating an Example proto requires a list of bytes, not a tensor.

How do I write a tf.train.Example to a TFRecord file from inside a graph?

Code to reproduce:

%tensorflow_version 2.x
import tensorflow as tf

@tf.function
def example_write():
  writer = tf.io.TFRecordWriter("test.tfr")
  x = tf.constant([[0, 1], [2, 3]])
  x = tf.io.serialize_tensor(x)
  feature = {
      "data": tf.train.Feature(
        bytes_list=tf.train.BytesList(value=[x]))
  }
  ex = tf.train.Example(features=tf.train.Features(
      feature=feature))
  writer.write(ex.SerializeToString())

example_write()

and the resulting error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-6-df8a97eb17c9> in <module>()
     12   writer.write(ex.SerializeToString())
     13 
---> 14 example_write()

8 frames
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/func_graph.py in wrapper(*args, **kwargs)
    966           except Exception as e:  # pylint:disable=broad-except
    967             if hasattr(e, "ag_error_metadata"):
--> 968               raise e.ag_error_metadata.to_exception(e)
    969             else:
    970               raise

TypeError: in user code:

    <ipython-input-6-df8a97eb17c9>:6 example_write  *
        feature = {

    TypeError: <tf.Tensor 'SerializeTensor:0' shape=() dtype=string> has type Tensor, but expected one of: bytes

Upvotes: 3

Views: 984

Answers (1)

Anatoly Alekseev

Reputation: 2400

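It's pretty straightforward: use x = tf.io.serialize_tensor(x).numpy(). Note that .numpy() is only available in eager mode, so this works when the writing code runs outside @tf.function; inside a traced function the tensor is symbolic and has no .numpy() method.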
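
A minimal sketch of the .numpy() approach, assuming the write happens in eager mode (no @tf.function around it). The helper names write_example and read_example and the temporary file path are made up for illustration; the feature key "data" and the tensor come from the question. The read-back step is only there to check the round trip.

```python
import os
import tempfile

import tensorflow as tf


def write_example(path, tensor):
    """Serialize `tensor` eagerly and write it as a single tf.train.Example."""
    # In eager mode serialize_tensor returns an EagerTensor, so .numpy()
    # gives the plain Python bytes that BytesList expects.
    serialized = tf.io.serialize_tensor(tensor).numpy()
    feature = {
        "data": tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[serialized]))
    }
    example = tf.train.Example(features=tf.train.Features(feature=feature))
    # The writer is a context manager, so the file is closed on exit.
    with tf.io.TFRecordWriter(path) as writer:
        writer.write(example.SerializeToString())


def read_example(path):
    """Read the first record back and restore the original int32 tensor."""
    raw = next(iter(tf.data.TFRecordDataset(path)))
    parsed = tf.io.parse_single_example(
        raw, {"data": tf.io.FixedLenFeature([], tf.string)})
    return tf.io.parse_tensor(parsed["data"], out_type=tf.int32)


# Hypothetical output location; any writable path works.
path = os.path.join(tempfile.mkdtemp(), "test.tfr")
x = tf.constant([[0, 1], [2, 3]])
write_example(path, x)
restored = read_example(path)
```

If the write really has to be triggered from inside a graph, one option is to wrap a Python function like write_example in tf.py_function, which runs its body eagerly; that is a workaround, not a native graph op, and it is not shown here.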

Upvotes: -2
