Reputation: 81
TensorFlow recently released ragged tensors: https://www.tensorflow.org/guide/ragged_tensors
But there isn't any documentation on how to save ragged data as a TFRecord and restore it using the tf.data API.
Upvotes: 5
Views: 4786
Reputation: 31
Yes, it's possible.
https://www.tensorflow.org/api_docs/python/tf/io/RaggedFeature
When you create the tf.train.Example, create one feature holding the flat values and another holding the row splits. For this example, assume they are named values and row_splits respectively.
Then use parse_single_example like this to get the ragged tensor:
tf.io.parse_single_example(serialized_example, features={
    'rt': tf.io.RaggedFeature(
        dtype=tf.float32,  # Use the corresponding type for the values.
        value_key='values',
        # partitions must be a list or tuple of partition specs.
        partitions=[tf.io.RaggedFeature.RowSplits('row_splits')],
        row_splits_dtype=tf.int64,
    ),
})
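To make the round trip concrete, here is a minimal sketch of both the writing and the parsing side. The helper name serialize_ragged and the sample data are made up for illustration, and it assumes a TensorFlow 2.x version that ships tf.io.RaggedFeature.

```python
import tensorflow as tf

# Hypothetical helper (not part of TensorFlow): packs a 2-D ragged tensor
# into a tf.train.Example with two flat features, 'values' and 'row_splits'.
def serialize_ragged(rt):
    values = tf.train.Feature(
        float_list=tf.train.FloatList(value=rt.flat_values.numpy().tolist()))
    row_splits = tf.train.Feature(
        int64_list=tf.train.Int64List(value=rt.row_splits.numpy().tolist()))
    example = tf.train.Example(features=tf.train.Features(
        feature={'values': values, 'row_splits': row_splits}))
    return example.SerializeToString()

# Sample data, including an empty row.
rt = tf.ragged.constant([[1.0, 2.0, 3.0], [], [4.0, 5.0]])
serialized_example = serialize_ragged(rt)

parsed = tf.io.parse_single_example(serialized_example, features={
    'rt': tf.io.RaggedFeature(
        dtype=tf.float32,
        value_key='values',
        partitions=[tf.io.RaggedFeature.RowSplits('row_splits')],
        row_splits_dtype=tf.int64,
    ),
})
restored = parsed['rt']  # a tf.RaggedTensor with the original rows
```

The parsed result is a dict keyed by the feature name given in the spec ('rt' here), not by the keys stored in the Example.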
Upvotes: 0
Reputation: 1152
A ragged tensor needs two arrays: values and something defining how the values should be split into rows (e.g. row_splits, row_lengths, ... see the docs). My take is to store these two arrays as two features in a tf.Example and create the ragged tensor when loading the files.
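To illustrate how the two row encodings relate, here is a plain-Python sketch with no TensorFlow involved. The function names lengths_to_splits and split_rows are made up for this illustration.

```python
from itertools import accumulate

def lengths_to_splits(row_lengths):
    # row_splits starts at 0 and then holds the running end offset of each row.
    return [0] + list(accumulate(row_lengths))

def split_rows(values, row_splits):
    # Row i is the slice values[row_splits[i]:row_splits[i + 1]].
    return [values[start:end]
            for start, end in zip(row_splits, row_splits[1:])]

values = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
row_lengths = [3, 2, 5]
row_splits = lengths_to_splits(row_lengths)
rows = split_rows(values, row_splits)
```

Either encoding carries the same information, so whichever is stored alongside the values is enough to rebuild the rows.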
For example:
import tensorflow as tf


def serialize_example(vals, lens):
    # Store the flat values and the row lengths as two int64 features.
    vals = tf.train.Feature(int64_list=tf.train.Int64List(value=vals))
    lens = tf.train.Feature(int64_list=tf.train.Int64List(value=lens))
    example = tf.train.Example(features=tf.train.Features(
        feature={'vals': vals, 'lens': lens})
    )
    return example.SerializeToString()


def parse_example(raw_example):
    example = tf.io.parse_single_example(raw_example, {
        'vals': tf.io.VarLenFeature(dtype=tf.int64),
        'lens': tf.io.VarLenFeature(dtype=tf.int64)
    })
    # VarLenFeature yields SparseTensors; .values gives the flat 1-D data.
    return tf.RaggedTensor.from_row_lengths(
        example['vals'].values, row_lengths=example['lens'].values
    )


ex1 = serialize_example([1,2,3,4,5,6,7,8,9,10], [3,2,5])
print(parse_example(ex1))  # <tf.RaggedTensor [[1, 2, 3], [4, 5], [6, 7, 8, 9, 10]]>
ex2 = serialize_example([1,2,3,4,5,6,7,8], [2,2,4])
print(parse_example(ex2))  # <tf.RaggedTensor [[1, 2], [3, 4], [5, 6, 7, 8]]>
When creating a dataset from TFRecord files, one would apply parse_example as a transformation by passing it to the Dataset.map() function.
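A minimal sketch of that last step might look like the following. It repeats the two helpers so it can run on its own, writes to a temporary file whose name is arbitrary, and assumes a TensorFlow 2.x version in which Dataset.map() may return ragged tensors.

```python
import os
import tempfile

import tensorflow as tf

def serialize_example(vals, lens):
    vals = tf.train.Feature(int64_list=tf.train.Int64List(value=vals))
    lens = tf.train.Feature(int64_list=tf.train.Int64List(value=lens))
    example = tf.train.Example(features=tf.train.Features(
        feature={'vals': vals, 'lens': lens}))
    return example.SerializeToString()

def parse_example(raw_example):
    example = tf.io.parse_single_example(raw_example, {
        'vals': tf.io.VarLenFeature(dtype=tf.int64),
        'lens': tf.io.VarLenFeature(dtype=tf.int64),
    })
    return tf.RaggedTensor.from_row_lengths(
        example['vals'].values, row_lengths=example['lens'].values)

# Write two serialized examples to a TFRecord file (illustrative path).
path = os.path.join(tempfile.mkdtemp(), 'ragged.tfrecord')
with tf.io.TFRecordWriter(path) as writer:
    writer.write(serialize_example([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], [3, 2, 5]))
    writer.write(serialize_example([1, 2, 3, 4, 5, 6, 7, 8], [2, 2, 4]))

# Read the file back and parse every record into a ragged tensor.
dataset = tf.data.TFRecordDataset(path).map(parse_example)
results = [rt.to_list() for rt in dataset]
```

Each element of the dataset is one ragged tensor, so records with different row counts or row lengths can sit in the same file.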
Upvotes: 2
Reputation: 11
Unfortunately, at the time of writing there was no RaggedFeature or equivalent (tf.io.RaggedFeature has since been added). Without it, your best bet is probably to convert to sparse (via to_sparse()) and encode your data as a SparseFeature. After decoding, you can convert back to ragged via the from_sparse() builder.
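A rough sketch of that sparse round trip could look like this. The feature keys 'row', 'col' and 'val' are arbitrary names chosen for illustration, and the dense shape of 3 by 3 is hard-coded because tf.io.SparseFeature needs a static size; real data would need a known upper bound on both dimensions.

```python
import tensorflow as tf

rt = tf.ragged.constant([[1, 2, 3], [4], [5, 6]])
st = rt.to_sparse()  # 2-D SparseTensor with dense_shape [3, 3]

# Store one index feature per dimension plus the values.
indices = st.indices.numpy()
example = tf.train.Example(features=tf.train.Features(feature={
    'row': tf.train.Feature(
        int64_list=tf.train.Int64List(value=indices[:, 0].tolist())),
    'col': tf.train.Feature(
        int64_list=tf.train.Int64List(value=indices[:, 1].tolist())),
    'val': tf.train.Feature(
        int64_list=tf.train.Int64List(value=st.values.numpy().tolist())),
}))

parsed = tf.io.parse_single_example(example.SerializeToString(), {
    'st': tf.io.SparseFeature(
        index_key=['row', 'col'],  # one key per sparse dimension
        value_key='val',
        dtype=tf.int64,
        size=[3, 3],  # assumed static dense shape for this sample
    ),
})

# from_sparse() expects a ragged-right sparse tensor, which holds here
# because the data originally came from a ragged tensor.
restored = tf.RaggedTensor.from_sparse(parsed['st'])
```

Compared with storing values plus row splits, this needs two index entries per value and a fixed dense shape, so it is mostly useful when a ragged feature spec is not available.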
Upvotes: 1