Mr. Fegur
Mr. Fegur

Reputation: 797

Reading multiple feature vectors from one TFRecord example in Tensorflow

I know how to store one feature per example inside a tfrecord file and then read it by using something like this:

import tensorflow as tf
import numpy as np
import os


# This is used to parse an example from tfrecords
def parse(serialized_example):
  """Decode one serialized tf.train.Example into a (feat, label) pair.

  Both entries are stored as raw byte strings under the keys 'feat' and
  'label'; 'feat' is decoded as float64 values and 'label' as int64 values.
  """
  schema = {
    "label": tf.FixedLenFeature([], tf.string, default_value=""),
    "feat": tf.FixedLenFeature([], tf.string, default_value=""),
  }
  parsed = tf.parse_single_example(serialized_example, features=schema)

  # Reinterpret the raw bytes with the dtypes they are decoded as here.
  return (
    tf.decode_raw(parsed['feat'], tf.float64),
    tf.decode_raw(parsed['label'], tf.int64),
  )


################# Generate data

# Write `numdata` examples to data.tfrecords in the current directory. Each
# example stores one 2-element feature vector and one scalar label, both
# serialized as raw bytes.
cwd = os.getcwd()
numdata = 10
with tf.python_io.TFRecordWriter(os.path.join(cwd, 'data.tfrecords')) as writer:
    for i in range(numdata):
        feat = np.random.randn(2)
        # Pin the dtype: parse() decodes the label bytes as int64, but NumPy's
        # default integer width is platform dependent.
        label = np.array(np.random.randint(0, 9), dtype=np.int64)

        featb = feat.tobytes()
        labelb = label.tobytes()
        # (The stray `import pudb.b` debugger breakpoint was removed here.)
        example = tf.train.Example(features=tf.train.Features(
            feature={
                'feat': tf.train.Feature(bytes_list=tf.train.BytesList(value=[featb])),
                'label': tf.train.Feature(bytes_list=tf.train.BytesList(value=[labelb])),
            }))
        writer.write(example.SerializeToString())

        print('wrote f {}, l {}'.format(feat, label))

print('Done writing! Start reading and printing data')

################# Read data

# Build the input pipeline: read the records, parse each one with parse(),
# and group them into batches of up to 100 examples.
filename = ['data.tfrecords']
dataset = tf.data.TFRecordDataset(filename).map(parse)
dataset = dataset.batch(100)
iterator = dataset.make_initializable_iterator()
feat, label = iterator.get_next()

with tf.Session() as sess:
    sess.run(iterator.initializer)
    try:
        while True:
            example = sess.run((feat, label))
            # Function-call form works on Python 2 and 3 alike; the original
            # `print example` statement is a SyntaxError on Python 3.
            print(example)
    except tf.errors.OutOfRangeError:
        # Raised once the dataset is exhausted; this is the normal loop exit.
        pass

What do I do in the case where each example has multiple feature vectors and labels in it? For example, in the above code, suppose feat were stored as a 2D array. I still want to do the same thing as before, which is to train a DNN with one feature per label, but each example in the tfrecords file has multiple features and multiple labels. This should be simple, but I'm having trouble unpacking multiple features in TensorFlow using tfrecords.

Upvotes: 1

Views: 2393

Answers (1)

rlys
rlys

Reputation: 480

Firstly, note that np.ndarray.tobytes() does not store the array's shape: a multi-dimensional array is serialized as the flat byte sequence of its elements (row-major order by default), i.e.

# Demonstrates that tobytes() yields the same bytes for an (N, 2) array and
# for its flattened (N*2,) version, i.e. the shape is not serialized.
N = 5  # example row count; any positive integer works
feat = np.random.randn(N, 2)
reshaped = np.reshape(feat, (N*2,))
feat.tobytes() == reshaped.tobytes()   ## True

So, if you have an N×2 array that's saved as bytes in TFRecord format, you have to reshape it back to (N, 2) after parsing.

If you do that, you can unbatch the elements of a tf.data.Dataset so that each iteration gives you one feature and one label. Your code should be as follows:

# This is used to parse an example from tfrecords
def parse(serialized_example):
  """Parse one serialized example that packs several feature vectors and labels.

  Args:
    serialized_example: a serialized tf.train.Example whose 'feat' and
      'label' entries are raw byte strings.

  Returns:
    A (feat, label) pair: feat is a float64 tensor of shape (N, 2) and label
    an int64 tensor of shape (N, ), where N is the number of feature vectors
    stored in the example.
  """
  features = tf.parse_single_example(
    serialized_example,
    features ={
      "label": tf.FixedLenFeature([], tf.string, default_value=""),
      "feat": tf.FixedLenFeature([], tf.string, default_value="")
    })

  feat = tf.decode_raw(features['feat'], tf.float64)    # flat array of shape (N*2, )
  # The bytes carry no shape information, so restore the 2-D layout here.
  # -1 lets TensorFlow infer N from the data instead of relying on a global N.
  feat = tf.reshape(feat, (-1, 2))                      # array of shape (N, 2)
  label = tf.decode_raw(features['label'], tf.int64)    # array of shape (N, )

  return feat, label


################# Generate data

# Write `numdata` examples to data.tfrecords in the current directory. Each
# example packs N two-element feature vectors and N labels as raw bytes.
cwd = os.getcwd()
numdata = 10
N = 5  # feature vectors (and labels) per example; was previously undefined
with tf.python_io.TFRecordWriter(os.path.join(cwd, 'data.tfrecords')) as writer:
    for i in range(numdata):
        feat = np.random.randn(N, 2)
        # Pin the dtype: parse() decodes the label bytes as int64, but NumPy's
        # default integer width is platform dependent.
        label = np.array(np.random.randint(0, 9, N), dtype=np.int64)

        # tobytes() drops the (N, 2) shape; parse() reshapes after decoding.
        featb = feat.tobytes()
        labelb = label.tobytes()
        example = tf.train.Example(features=tf.train.Features(
            feature={
                'feat': tf.train.Feature(bytes_list=tf.train.BytesList(value=[featb])),
                'label': tf.train.Feature(bytes_list=tf.train.BytesList(value=[labelb])),
            }))
        writer.write(example.SerializeToString())

        print('wrote f {}, l {}'.format(feat, label))

print('Done writing! Start reading and printing data')

################# Read data

filename = ['data.tfrecords']
# Parse every record with parse(), then unbatch so that each dataset element
# is a single (feature, label) pair rather than a whole record's worth.
dataset = tf.data.TFRecordDataset(filename)
dataset = dataset.map(parse)
dataset = dataset.apply(tf.contrib.data.unbatch())
... etc

Upvotes: 2

Related Questions