Bas Smit


How to convert string tensor (obtained from .tfrecords) to float tensor?

The input to my network comes from files containing int32s. They are stored as .tfrecords as follows:

  import tensorflow as tf  # TensorFlow 1.x API

  writer = tf.python_io.TFRecordWriter(output_file)
  with tf.gfile.FastGFile(file_path, 'rb') as f:
    data = f.read()  # raw file bytes: 4 bytes per int32 value
    example = tf.train.Example(features=tf.train.Features(feature={
      'data': tf.train.Feature(bytes_list=tf.train.BytesList(value=[data])) }))
    writer.write(example.SerializeToString())
  writer.close()  # flush and close the .tfrecords file
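Because the file is read with `f.read()`, the `data` feature holds the file's raw bytes, four per int32 value, not a textual representation of the numbers. A minimal sketch of what that byte string looks like, assuming (hypothetically) that the source file is plain little-endian int32 with no header, produced here with NumPy's `tofile`; `example.bin` is a placeholder name:

```python
import os
import tempfile

import numpy as np

# Hypothetical source data: six int32 values (assumption: no header, little-endian).
values = np.array([1, -2, 3, 400, 5, 6], dtype='<i4')

path = os.path.join(tempfile.mkdtemp(), 'example.bin')  # placeholder file name
values.tofile(path)

# Mirrors the f.read() call above: the result is a plain bytes object.
with open(path, 'rb') as f:
    data = f.read()

print(type(data).__name__, len(data))  # bytes 24 (each int32 takes 4 bytes)
```

This byte string is exactly what ends up as the single element of the `BytesList`, which is why it comes back as a scalar `tf.string` tensor when parsed.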

I then read the tfrecords file like so:

with tf.name_scope(self.name):
  filename_queue = tf.train.string_input_producer([path])
  reader = tf.TFRecordReader()

  _, serialized_example = reader.read(filename_queue)
  features = tf.parse_single_example(
      serialized_example,
      features={ 'data': tf.FixedLenFeature([], tf.string) })

  data = features['data']

After reading the .tfrecords file, I have string tensors like this:

Tensor("X/ParseSingleExample/ParseSingleExample:0", shape=(), dtype=string)

I would like to first convert this to int32s, as that is what the initial data represents. After that I need to end up with a tensor of floats. Can someone point me in the right direction?

PS: I'm new to TensorFlow; please let me know if I can provide more useful information.


Answers (1)

Sharky

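This should help:

data = features['data']
decoded = tf.decode_raw(data, tf.int32)  # 1-D int32 tensor; bytes read as little-endian by default

This outputs a 1-D tensor of dtype tf.int32. You can then reshape it and cast it to tf.float32:

decoded = tf.reshape(decoded, shape)  # shape: the known dimensions of your data; element count must match
decoded = tf.cast(decoded, tf.float32)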
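`tf.decode_raw` reinterprets the byte string as a flat vector (little-endian unless its `little_endian` argument is set to False), so `tf.reshape` needs a shape whose element count matches the number of decoded values. The same decode, reshape, cast sequence can be reproduced in NumPy to check values outside the graph. A sketch, where `decode_like_tf` is a hypothetical helper name and the `(2, 3)` shape is an assumed example rather than something stored in the record:

```python
import numpy as np


def decode_like_tf(data: bytes, shape) -> np.ndarray:
    """NumPy analogue of tf.decode_raw(data, tf.int32) -> tf.reshape -> tf.cast(tf.float32)."""
    if len(data) % 4 != 0:
        raise ValueError('byte length %d is not a multiple of 4' % len(data))
    ints = np.frombuffer(data, dtype='<i4')  # little-endian int32, like decode_raw's default
    ints = ints.reshape(shape)               # raises if the element count does not match
    return ints.astype(np.float32)           # same role as tf.cast(..., tf.float32)


raw = np.array([1, -2, 3, 400, 5, 6], dtype='<i4').tobytes()
print(decode_like_tf(raw, (2, 3)))
```

If the original files were written big-endian, the dtype would need to be `'>i4'` here and `little_endian=False` on the TensorFlow side; otherwise the values decode to garbage without any error.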

If you want to inspect the contents of a .tfrecords file outside a tf.Session:

import numpy as np
import tensorflow as tf  # TensorFlow 1.x API

for str_rec in tf.python_io.tf_record_iterator('file.tfrecords'):
    example = tf.train.Example()
    example.ParseFromString(str_rec)
    data_str = example.features.feature['data'].bytes_list.value[0]
    decoded = np.frombuffer(data_str, dtype=np.int32)  # raw bytes -> int32 array

To verify the contents of a tensor, you can inject a print node into the graph, as explained in this answer:

# Add print operation
decoded = tf.Print(decoded, [decoded])
