zer0fool

Reputation: 81

How do you use Ragged Tensors with tf.data and TFRecords?

TensorFlow recently released Ragged Tensors: https://www.tensorflow.org/guide/ragged_tensors

But there isn't any documentation on how to save ragged data as a TFRecord and restore it using the tf.data API.

Upvotes: 5

Views: 4786

Answers (3)

Yes, it's possible, using tf.io.RaggedFeature: https://www.tensorflow.org/api_docs/python/tf/io/RaggedFeature When you create the tf.train.Example, store the flat values in one feature and the row_splits in another. In the example below, these features are assumed to be named values and row_splits, respectively.

Then call tf.io.parse_single_example like this to get the ragged tensor.

parsed = tf.io.parse_single_example(serialized_example, features={
  'rt': tf.io.RaggedFeature(
      dtype=tf.float32,  # Use the corresponding type for the values.
      value_key='values',
      # partitions must be a list or tuple of partition objects.
      partitions=[tf.io.RaggedFeature.RowSplits('row_splits')],
      row_splits_dtype=tf.int64,
  ),
})
rt = parsed['rt']  # The reconstructed tf.RaggedTensor.
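
For a fuller picture, here is a minimal end-to-end sketch of the approach above, assuming TensorFlow 2.x with eager execution and a release that includes tf.io.RaggedFeature. The sample data and the names rt, example, and restored are illustrative and not part of the original answer.

```python
import tensorflow as tf

# Illustrative ragged data: three rows of different lengths.
rt = tf.ragged.constant([[1.0, 2.0, 3.0], [4.0], [5.0, 6.0]])

# Store the flat values and the row_splits as two separate features.
example = tf.train.Example(features=tf.train.Features(feature={
    'values': tf.train.Feature(
        float_list=tf.train.FloatList(value=rt.flat_values.numpy().tolist())),
    'row_splits': tf.train.Feature(
        int64_list=tf.train.Int64List(value=rt.row_splits.numpy().tolist())),
}))
serialized_example = example.SerializeToString()

# Parse the two features back into a single ragged tensor.
parsed = tf.io.parse_single_example(serialized_example, features={
    'rt': tf.io.RaggedFeature(
        dtype=tf.float32,
        value_key='values',
        partitions=[tf.io.RaggedFeature.RowSplits('row_splits')],
        row_splits_dtype=tf.int64),
})
restored = parsed['rt']
print(restored)
```

Note that the row_splits feature has to be stored as an int64 list, since tf.train.Example has no int32 list type.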

Upvotes: 0

SheepPerplexed

Reputation: 1152

A ragged tensor is defined by two arrays: the flat values, and a partition describing how those values are split into rows (e.g. row_splits or row_lengths; see the docs). My take is to store these two arrays as two features in a tf.Example and create the ragged tensor when loading the files.

For example:

import tensorflow as tf

def serialize_example(vals, lens):
  # Store the flat values and the row lengths as two int64 features.
  vals = tf.train.Feature(int64_list=tf.train.Int64List(value=vals))
  lens = tf.train.Feature(int64_list=tf.train.Int64List(value=lens))
  example = tf.train.Example(features=tf.train.Features(
      feature={'vals': vals, 'lens': lens})
  )
  return example.SerializeToString()

def parse_example(raw_example):
  # Both features are variable-length, so they are parsed as sparse tensors.
  example = tf.io.parse_single_example(raw_example, {
      'vals': tf.io.VarLenFeature(dtype=tf.int64),
      'lens': tf.io.VarLenFeature(dtype=tf.int64)
  })
  # Rebuild the ragged tensor from the flat values and the row lengths.
  return tf.RaggedTensor.from_row_lengths(
      example['vals'].values, row_lengths=example['lens'].values
  )

ex1 = serialize_example([1,2,3,4,5,6,7,8,9,10], [3,2,5])
print(parse_example(ex1))  # <tf.RaggedTensor [[1, 2, 3], [4, 5], [6, 7, 8, 9, 10]]>
ex2 = serialize_example([1,2,3,4,5,6,7,8], [2,2,4])
print(parse_example(ex2))  # <tf.RaggedTensor [[1, 2], [3, 4], [5, 6, 7, 8]]>

When creating a dataset from TFRecord files, apply parse_example as a transformation by passing it to Dataset.map().
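
As a rough sketch of that last step, the snippet below writes two examples to a TFRecord file and reads them back through tf.data. It assumes a TensorFlow 2.x release in which tf.data can carry ragged tensors; the temporary file path and the sample values are made up for illustration, and the two helper functions are repeated so the snippet runs on its own.

```python
import os
import tempfile

import tensorflow as tf

def serialize_example(vals, lens):
  # Flat values and row lengths stored as two int64 features.
  feature = {
      'vals': tf.train.Feature(int64_list=tf.train.Int64List(value=vals)),
      'lens': tf.train.Feature(int64_list=tf.train.Int64List(value=lens)),
  }
  example = tf.train.Example(features=tf.train.Features(feature=feature))
  return example.SerializeToString()

def parse_example(raw_example):
  example = tf.io.parse_single_example(raw_example, {
      'vals': tf.io.VarLenFeature(dtype=tf.int64),
      'lens': tf.io.VarLenFeature(dtype=tf.int64),
  })
  return tf.RaggedTensor.from_row_lengths(
      example['vals'].values, row_lengths=example['lens'].values)

# Hypothetical output location, used only for this sketch.
path = os.path.join(tempfile.mkdtemp(), 'ragged.tfrecord')
with tf.io.TFRecordWriter(path) as writer:
  writer.write(serialize_example([1, 2, 3, 4, 5], [3, 2]))
  writer.write(serialize_example([6, 7, 8], [1, 2]))

# Each dataset element is one ragged tensor.
dataset = tf.data.TFRecordDataset(path).map(parse_example)
rows = [rt.to_list() for rt in dataset]
print(rows)
```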

Upvotes: 2

Nick

Reputation: 11

Unfortunately, at the time this was written there was no RaggedFeature or equivalent (tf.io.RaggedFeature has since been added to TensorFlow). Without it, your best bet is probably to convert to sparse (via to_sparse()) and encode your data as a SparseFeature. After decoding, you can convert back to ragged via the from_sparse() factory method.
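
The conversion itself might look like the sketch below. It only shows the in-memory round trip between ragged and sparse, assuming TensorFlow 2.x with eager execution; encoding the sparse tensor into a tf.Example is left out, and the sample data is illustrative.

```python
import tensorflow as tf

# Illustrative ragged data.
rt = tf.ragged.constant([[1, 2, 3], [4], [5, 6]])

# Ragged -> sparse: dense_shape is [number of rows, longest row length].
sp = rt.to_sparse()

# Sparse -> ragged: requires the "ragged-right" layout that to_sparse() produces.
restored = tf.RaggedTensor.from_sparse(sp)
print(restored)
```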

Upvotes: 1
