Reputation: 22617
Goal
I have a tf.data.Dataset
where some of the examples are too long (the size of the 0 axis is too big). I'd like to split these overly long examples into several examples, where each is a chunk of the original example. If a particular example is not divisible by the desired chunk size, I'd like to truncate the remainder.
As an example, if a numpy view of the original dataset looked like this (5 elements):
>>> print(list(dataset.as_numpy_iterator()))
[array([25], dtype=int32),
array([ 6, 91], dtype=int32),
array([15, 30, 96], dtype=int32),
array([14, 45, 27, 72], dtype=int32),
array([ 7, 75, 89, 47, 66], dtype=int32)]
and the desired chunk size is 2, I expect as a result a new dataset as follows (7 elements):
>>> new_dataset = chunk_dataset(dataset, chunk_size=2)
>>> print(list(new_dataset.as_numpy_iterator()))
[array([25], dtype=int32),
array([ 6, 91], dtype=int32),
array([15, 30], dtype=int32),
array([14, 45], dtype=int32),
array([27, 72], dtype=int32),
array([ 7, 75], dtype=int32),
array([89, 47], dtype=int32)]
Problem
I am unable to write a chunking function that works with a tf.data.Dataset
where all operations run in graph mode (as opposed to eager execution). Depending on the exact chunking function I tried, I am running into different errors.
Please note that I do know how to achieve this outside of graph mode, for instance in numpy or with tf eager execution. I'd like to write it as a tf.data.Dataset
operation for efficient preprocessing of my examples.
Code
See also this Colab notebook to reproduce my problem.
import tensorflow as tf
import numpy as np
from typing import List, Callable
"""## Code for chunking"""
def chunk_tensor_v1(input_tensor: tf.Tensor,
                    chunk_size: int) -> List[tf.Tensor]:
    """Peel fixed-size chunks off the front of ``input_tensor``.

    While at least ``chunk_size`` entries remain along axis 0, the first
    ``chunk_size`` entries are appended to the result and removed from the
    input. A shorter tail is dropped, so an input shorter than
    ``chunk_size`` yields an empty list.

    NOTE: the loop condition depends on ``tf.shape``, i.e. on a tensor
    value. The post reports that using this function through
    ``chunk_dataset`` raises ``InaccessibleTensorError`` (a tensor created
    in the while body is captured by the ``Dataset.map`` graph).

    Args:
        input_tensor: Tensor to split along axis 0.
        chunk_size: Number of entries per chunk.

    Returns:
        List of chunks, each with ``chunk_size`` entries along axis 0.
    """
    # Python list that accumulates the slices produced inside the loop.
    tensor_chunks = [] # type: List[tf.Tensor]
    while tf.shape(input_tensor)[0] >= chunk_size:
        chunk = input_tensor[:chunk_size]
        tensor_chunks.append(chunk)
        # Continue with whatever is left after the chunk just taken.
        input_tensor = input_tensor[chunk_size:]
    return tensor_chunks
def chunk_tensor_v2(input_tensor: tf.Tensor,
                    chunk_size: int) -> List[tf.Tensor]:
    """Split ``input_tensor`` into equal chunks using its static shape.

    The remainder that does not fill a whole chunk is cut off the end, and
    the rest is handed to ``tf.split``. An input that is not longer than
    ``chunk_size`` is returned as a single piece.

    NOTE: this relies on ``input_tensor.shape[0]`` being a Python int. The
    post reports a ``TypeError`` when it is ``None``.

    Args:
        input_tensor: Tensor to split along axis 0.
        chunk_size: Number of entries per chunk.

    Returns:
        List of tensors produced by ``tf.split``.
    """
    num_frames = input_tensor.shape[0]
    # Only inputs longer than one chunk can have a tail to discard.
    leftover = num_frames % chunk_size if num_frames > chunk_size else 0
    trimmed = input_tensor[:-leftover] if leftover != 0 else input_tensor
    # Always request at least one piece so short inputs survive intact.
    piece_count = max(num_frames // chunk_size, 1)
    return tf.split(trimmed, piece_count, axis=0)
def chunk_example(example: tf.Tensor,
                  chunk_size: int,
                  chunking_function: Callable) -> tf.data.Dataset:
    """Turn one example into a dataset of its chunks.

    Args:
        example: Tensor to be chunked.
        chunk_size: Chunk length forwarded to ``chunking_function``.
        chunking_function: Callable invoked as
            ``chunking_function(example, chunk_size=chunk_size)``.

    Returns:
        Dataset built with ``from_tensor_slices`` from the callable's result.
    """
    return tf.data.Dataset.from_tensor_slices(
        chunking_function(example, chunk_size=chunk_size))
def chunk_dataset(dataset: tf.data.Dataset, chunk_size: int, chunking_function: Callable) -> tf.data.Dataset:
    """Chunk every example of ``dataset`` and flatten the result.

    Args:
        dataset: Source dataset of examples.
        chunk_size: Chunk length passed on to ``chunk_example``.
        chunking_function: Chunking callable passed on to ``chunk_example``.

    Returns:
        Dataset whose elements are the chunks of all examples.
    """
    # Step 1: each example becomes its own nested dataset of chunks.
    per_example_chunks = dataset.map(
        lambda example: chunk_example(example, chunk_size, chunking_function))
    # Step 2: flatten the nested datasets, reading one of them at a time.
    return per_example_chunks.interleave(
        lambda x: x,
        cycle_length=1,
        num_parallel_calls=tf.data.AUTOTUNE,
    )
"""## Code to create a dummy dataset"""
def create_dataset_with_single_example(size: int):
    """Build a one-element dataset holding ``size`` random int32 values.

    The values are drawn with ``tf.random.uniform`` using ``minval=0`` and
    ``maxval=100``.
    """
    values = tf.random.uniform(shape=(size,), minval=0, maxval=100, dtype=tf.int32)
    return tf.data.Dataset.from_tensors(values)
def create_dataset(num_examples: int) -> tf.data.Dataset:
    """Create a dataset of ``num_examples`` examples with lengths 1..num_examples.

    One single-example dataset is built per length; the nested datasets are
    then flattened into one.
    """
    single_example_datasets = [
        create_dataset_with_single_example(length)
        for length in range(1, num_examples + 1)
    ]
    nested = tf.data.Dataset.from_tensor_slices(single_example_datasets)
    # Flatten the dataset of datasets, reading one nested dataset at a time.
    return nested.interleave(
        lambda x: x,
        cycle_length=1,
        num_parallel_calls=tf.data.AUTOTUNE,
    )
"""## Testing the chunking code with the dummy dataset"""
# Build five examples (lengths 1 to 5) and print them as numpy arrays.
num_examples = 5
dataset = create_dataset(num_examples)
print(list(dataset.as_numpy_iterator()))
# Reproduction: both calls fail; the resulting errors are quoted in the
# "Errors" section of this post (v1: InaccessibleTensorError, v2: TypeError).
chunk_dataset(dataset, chunk_size=2, chunking_function=chunk_tensor_v1)
chunk_dataset(dataset, chunk_size=2, chunking_function=chunk_tensor_v2)
Errors
Using chunk_tensor_v1
leads to
InaccessibleTensorError: tf.Graph captured an external symbolic tensor. The symbolic tensor <tf.Tensor 'while/strided_slice:0' shape=(None,) dtype=int32> is captured by FuncGraph(name=Dataset_map_lambda, id=140570786598224), but it is defined at FuncGraph(name=while_body_485049, id=140570787725264). A tf.Graph is not allowed to capture symoblic tensors from another graph. Use return values, explicit Python locals or TensorFlow collections to access it. Please see https://www.tensorflow.org/guide/function#all_outputs_of_a_tffunction_must_be_return_values for more information.
and chunk_tensor_v2
leads to
TypeError: '>' not supported between instances of 'NoneType' and 'int'
If someone knows how to further simplify my problem I am glad to edit the question.
Upvotes: 3
Views: 900
Reputation: 26708
A bit tricky but definitely possible! You could try something like this:
Core part of the code (can probably be simplified):
# Examples that already fit into a single chunk (length <= chunk_size).
dataset1 = dataset.filter(lambda x: tf.less_equal(tf.shape(x)[0], chunk_size))
# Examples longer than chunk_size; only these are split.
dataset2 = dataset.filter(lambda x: tf.greater(tf.shape(x)[0], chunk_size))
def body(i, m, n):
    """Loop body with loop variables ``(i, m, n)``.

    Writes the slice ``m[i:i+chunk_size]`` to the next free slot of the
    TensorArray ``n`` and advances ``i`` by ``chunk_size``. ``chunk_size``
    is read from the enclosing (module) scope.
    """
    next_slot = n.size()
    current_chunk = m[i:i + chunk_size]
    n = n.write(next_slot, current_chunk)
    next_offset = tf.add(i, chunk_size)
    return next_offset, m, n
def split_data(data, chunk_size):
    """Split ``data`` along axis 0 into consecutive chunks of ``chunk_size``.

    A trailing remainder shorter than ``chunk_size`` is dropped.

    Args:
        data: Tensor with at least one dimension.
        chunk_size: Positive chunk length along axis 0.

    Returns:
        Tensor with the dtype of ``data`` and shape
        ``(num_chunks, chunk_size, *trailing_dims)``, where
        ``num_chunks = len(data) // chunk_size``.
    """
    length = tf.shape(data)[0]
    num_chunks = length // chunk_size
    # Keep only the prefix that divides evenly into whole chunks.
    x = data[:num_chunks * chunk_size]
    # A single reshape works for any dtype and any trailing dimensions,
    # depends only on the chunk_size argument, and keeps every whole chunk
    # (including the last one when chunk_size == 1).
    target_shape = tf.concat(
        [[num_chunks, chunk_size], tf.shape(data)[1:]], axis=0)
    return tf.reshape(x, target_shape)
# Turn every long example into one tensor holding all of its chunks.
dataset2 = dataset2.map(lambda x: split_data(x, chunk_size))
# Unstack along axis 0 so that each chunk becomes its own dataset element.
dataset2 = dataset2.flat_map(tf.data.Dataset.from_tensor_slices)
# NOTE: all short examples are emitted first, followed by all chunks, so the
# original ordering is only preserved if short examples precede long ones.
dataset = dataset1.concatenate(dataset2)
Whole code:
import tensorflow as tf
# Fix the global random seed before the random example values are generated.
tf.random.set_seed(456)
def create_dataset_with_single_example(size: int):
    """Build a one-element dataset holding ``size`` random int32 values.

    The values are drawn with ``tf.random.uniform`` using ``minval=0`` and
    ``maxval=100``.
    """
    values = tf.random.uniform(shape=(size,), minval=0, maxval=100, dtype=tf.int32)
    return tf.data.Dataset.from_tensors(values)
def create_dataset(num_examples: int) -> tf.data.Dataset:
    """Create a dataset of ``num_examples`` examples with lengths 1..num_examples.

    One single-example dataset is built per length; the nested datasets are
    then flattened into one.
    """
    single_example_datasets = [
        create_dataset_with_single_example(length)
        for length in range(1, num_examples + 1)
    ]
    nested = tf.data.Dataset.from_tensor_slices(single_example_datasets)
    # Flatten the dataset of datasets, reading one nested dataset at a time.
    return nested.interleave(
        lambda x: x,
        cycle_length=1,
        num_parallel_calls=tf.data.AUTOTUNE,
    )
# Demo parameters: five examples (lengths 1 to 5), split into chunks of two.
num_examples = 5
chunk_size = 2
dataset = create_dataset(num_examples)
# Show the unchunked examples.
print('Before --> \n')
for d in dataset:
    print(d)
# Examples that already fit into a single chunk (length <= chunk_size).
dataset1 = dataset.filter(lambda x: tf.less_equal(tf.shape(x)[0], chunk_size))
# Examples longer than chunk_size; only these are split.
dataset2 = dataset.filter(lambda x: tf.greater(tf.shape(x)[0], chunk_size))
def body(i, m, n):
    """Loop body with loop variables ``(i, m, n)``.

    Writes the slice ``m[i:i+chunk_size]`` to the next free slot of the
    TensorArray ``n`` and advances ``i`` by ``chunk_size``. ``chunk_size``
    is read from the enclosing (module) scope.
    """
    next_slot = n.size()
    current_chunk = m[i:i + chunk_size]
    n = n.write(next_slot, current_chunk)
    next_offset = tf.add(i, chunk_size)
    return next_offset, m, n
def split_data(data, chunk_size):
    """Split ``data`` along axis 0 into consecutive chunks of ``chunk_size``.

    A trailing remainder shorter than ``chunk_size`` is dropped.

    Args:
        data: Tensor with at least one dimension.
        chunk_size: Positive chunk length along axis 0.

    Returns:
        Tensor with the dtype of ``data`` and shape
        ``(num_chunks, chunk_size, *trailing_dims)``, where
        ``num_chunks = len(data) // chunk_size``.
    """
    length = tf.shape(data)[0]
    num_chunks = length // chunk_size
    # Keep only the prefix that divides evenly into whole chunks.
    x = data[:num_chunks * chunk_size]
    # A single reshape works for any dtype and any trailing dimensions,
    # depends only on the chunk_size argument, and keeps every whole chunk
    # (including the last one when chunk_size == 1).
    target_shape = tf.concat(
        [[num_chunks, chunk_size], tf.shape(data)[1:]], axis=0)
    return tf.reshape(x, target_shape)
# Turn every long example into one tensor holding all of its chunks.
dataset2 = dataset2.map(lambda x: split_data(x, chunk_size))
# Unstack along axis 0 so that each chunk becomes its own dataset element.
dataset2 = dataset2.flat_map(tf.data.Dataset.from_tensor_slices)
# NOTE: all short examples are emitted first, followed by all chunks, so the
# original ordering is only preserved if short examples precede long ones.
dataset = dataset1.concatenate(dataset2)
# Show the chunked result.
print('\nAfter --> \n')
for d in dataset:
    print(d)
Before -->
tf.Tensor([44], shape=(1,), dtype=int32)
tf.Tensor([23 10], shape=(2,), dtype=int32)
tf.Tensor([41 86 2], shape=(3,), dtype=int32)
tf.Tensor([54 78 20 93], shape=(4,), dtype=int32)
tf.Tensor([51 87 96 84 31], shape=(5,), dtype=int32)
After -->
tf.Tensor([44], shape=(1,), dtype=int32)
tf.Tensor([23 10], shape=(2,), dtype=int32)
tf.Tensor([41 86], shape=(2,), dtype=int32)
tf.Tensor([54 78], shape=(2,), dtype=int32)
tf.Tensor([20 93], shape=(2,), dtype=int32)
tf.Tensor([51 87], shape=(2,), dtype=int32)
tf.Tensor([96 84], shape=(2,), dtype=int32)
With chunk_size = 3:
Before -->
tf.Tensor([44], shape=(1,), dtype=int32)
tf.Tensor([23 10], shape=(2,), dtype=int32)
tf.Tensor([41 86 2], shape=(3,), dtype=int32)
tf.Tensor([54 78 20 93], shape=(4,), dtype=int32)
tf.Tensor([51 87 96 84 31], shape=(5,), dtype=int32)
After -->
tf.Tensor([44], shape=(1,), dtype=int32)
tf.Tensor([23 10], shape=(2,), dtype=int32)
tf.Tensor([41 86 2], shape=(3,), dtype=int32)
tf.Tensor([54 78 20], shape=(3,), dtype=int32)
tf.Tensor([51 87 96], shape=(3,), dtype=int32)
With chunk_size = 4:
Before -->
tf.Tensor([44], shape=(1,), dtype=int32)
tf.Tensor([23 10], shape=(2,), dtype=int32)
tf.Tensor([41 86 2], shape=(3,), dtype=int32)
tf.Tensor([54 78 20 93], shape=(4,), dtype=int32)
tf.Tensor([51 87 96 84 31], shape=(5,), dtype=int32)
After -->
tf.Tensor([44], shape=(1,), dtype=int32)
tf.Tensor([23 10], shape=(2,), dtype=int32)
tf.Tensor([41 86 2], shape=(3,), dtype=int32)
tf.Tensor([54 78 20 93], shape=(4,), dtype=int32)
tf.Tensor([51 87 96 84], shape=(4,), dtype=int32)
Upvotes: 4