Yuval Atzmon

Reputation: 5945

TensorFlow: Is it possible to store TFRecord sequence examples as float16?

Is it possible to store a sequence example in TensorFlow as float16 instead of the regular float32?

We can live with 16-bit precision, and it would reduce the size of the data files we use, saving us ~200 GB.
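
(For a rough sense of the savings: raw float16 bytes are exactly half the size of float32 bytes. A quick NumPy sanity check, with a made-up array size:)

import numpy as np

# made-up example: one million random values
values = np.random.rand(1000000)
print(len(values.astype(np.float32).tobytes()))  # 4000000 bytes
print(len(values.astype(np.float16).tobytes()))  # 2000000 bytes, half the size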

Upvotes: 1

Views: 1705

Answers (1)

pedosb

Reputation: 111

I think the snippet below does just that: it serializes the float16 array as raw bytes in a BytesList feature, then decodes the bytes back into float16 with tf.decode_raw when reading.

import tensorflow as tf
import numpy as np

# generate the data
data_np = np.array(np.random.rand(10), dtype=np.float16)

with tf.python_io.TFRecordWriter('/tmp/data.tfrecord') as writer:
    # encode the data in a dictionary of features
    data = {'raw': tf.train.Feature(
        # the feature has a type ByteList
        bytes_list=tf.train.BytesList(
            # encode the data into bytes
            value=[data_np.tobytes()]))}
    # create an Example message from the features
    example = tf.train.Example(features=tf.train.Features(feature=data))
    # write the example to a TFRecord file
    writer.write(example.SerializeToString())

def _parse_tfrecord(example_proto):
    # describe how the TFRecord example will be interpreted
    features = {'raw': tf.FixedLenFeature((), tf.string)}
    # parse the example (dict of features) from the TFRecord
    parsed_features = tf.parse_single_example(example_proto, features)
    # decode the bytes as float16 array
    return tf.decode_raw(parsed_features['raw'], tf.float16)

def tfrecord_input_fn():
    # read the dataset
    dataset = tf.data.TFRecordDataset('/tmp/data.tfrecord')
    # parse each example of the dataset
    dataset = dataset.map(_parse_tfrecord)
    iterator = dataset.make_one_shot_iterator()

    return iterator.get_next()

# get an iterator over the TFRecord
it = tfrecord_input_fn()
# make a session and evaluate the tensor
sess = tf.Session()
recovered_data = sess.run(it)
print(recovered_data == data_np)

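Since the question specifically mentions sequence examples, the same raw-bytes trick also works with tf.train.SequenceExample. Below is a minimal sketch along the same lines (the feature-list name raw_seq, the file path and the toy shapes are my own choices, not something from the question):

import tensorflow as tf
import numpy as np

# toy sequence: 5 timesteps of 3 float16 values each (shapes are made up)
seq_np = np.array(np.random.rand(5, 3), dtype=np.float16)

with tf.python_io.TFRecordWriter('/tmp/seq_data.tfrecord') as writer:
    # one bytes feature per timestep, holding that timestep's raw float16 bytes
    frames = [tf.train.Feature(
        bytes_list=tf.train.BytesList(value=[row.tobytes()]))
        for row in seq_np]
    # wrap the per-timestep features in a feature list named 'raw_seq'
    feature_lists = tf.train.FeatureLists(
        feature_list={'raw_seq': tf.train.FeatureList(feature=frames)})
    # build and write the SequenceExample
    seq_example = tf.train.SequenceExample(feature_lists=feature_lists)
    writer.write(seq_example.SerializeToString())

def _parse_sequence(example_proto):
    # each timestep is a single bytes string holding float16 data
    sequence_features = {'raw_seq': tf.FixedLenSequenceFeature((), tf.string)}
    _, parsed_seq = tf.parse_single_sequence_example(
        example_proto, sequence_features=sequence_features)
    # decode each timestep's bytes back into float16 -> shape [timesteps, values]
    return tf.decode_raw(parsed_seq['raw_seq'], tf.float16)

dataset = tf.data.TFRecordDataset('/tmp/seq_data.tfrecord').map(_parse_sequence)
it = dataset.make_one_shot_iterator().get_next()
with tf.Session() as sess:
    print(sess.run(it) == seq_np)

The only difference from the snippet above is that each timestep's float16 values go into their own bytes feature inside a FeatureList, and tf.parse_single_sequence_example plus tf.decode_raw recover a [timesteps, values] float16 tensor on read.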
Upvotes: 1
