Mathias Müller

Reputation: 22617

Split examples of a tensorflow tf.data dataset in graph execution mode

Goal

I have a tf.data.Dataset where some of the examples are too long (the size along axis 0 is too large). I'd like to split these overly long examples into several examples, each a chunk of the original. If an example's length is not divisible by the desired chunk size, I'd like to truncate the remainder.

As an example, if a numpy view of the original dataset looked like this (5 elements):

>>> print(list(dataset.as_numpy_iterator()))
[array([25], dtype=int32),
 array([ 6, 91], dtype=int32),
 array([15, 30, 96], dtype=int32),
 array([14, 45, 27, 72], dtype=int32),
 array([ 7, 75, 89, 47, 66], dtype=int32)]

and the desired chunk size is 2, I expect the result to be a new dataset as follows (7 elements):

>>> new_dataset = chunk_dataset(dataset, chunk_size=2)
>>> print(list(new_dataset.as_numpy_iterator()))
[array([25], dtype=int32),
 array([ 6, 91], dtype=int32),
 array([15, 30], dtype=int32),
 array([14, 45], dtype=int32),
 array([27, 72], dtype=int32),
 array([ 7, 75], dtype=int32),
 array([89, 47], dtype=int32)]

Problem

I am unable to write a chunking function that works with a tf.data.Dataset whose operations all run in graph mode (as opposed to eager execution). Depending on the exact chunking function I try, I run into different errors.

Please note that I do know how to achieve this outside of graph mode, for instance in numpy or with tf eager execution. I'd like to write it as a tf.data.Dataset operation for efficient preprocessing of my examples.
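
For reference, here is roughly what I mean, outside of graph mode (a plain numpy sketch; chunk_array is just an illustrative helper, not code I actually use):

import numpy as np

def chunk_array(a: np.ndarray, chunk_size: int) -> list:
    # Short arrays pass through unchanged; longer arrays are cut into
    # contiguous chunks and the leftover elements are dropped.
    if len(a) <= chunk_size:
        return [a]
    usable = (len(a) // chunk_size) * chunk_size
    return list(a[:usable].reshape(-1, chunk_size))

>>> chunk_array(np.array([7, 75, 89, 47, 66], dtype=np.int32), chunk_size=2)
[array([ 7, 75], dtype=int32), array([89, 47], dtype=int32)]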

Code

See also this Colab notebook to reproduce my problem.

import tensorflow as tf
import numpy as np

from typing import List, Callable

"""## Code for chunking"""

def chunk_tensor_v1(input_tensor: tf.Tensor,
                    chunk_size: int) -> List[tf.Tensor]:

    tensor_chunks = []  # type: List[tf.Tensor]

    while tf.shape(input_tensor)[0] >= chunk_size:
        chunk = input_tensor[:chunk_size]
        tensor_chunks.append(chunk)
        input_tensor = input_tensor[chunk_size:]

    return tensor_chunks

def chunk_tensor_v2(input_tensor: tf.Tensor,
                    chunk_size: int) -> List[tf.Tensor]:

    frames = input_tensor.shape[0]

    if frames > chunk_size:
        remainder = frames % chunk_size
    else:
        remainder = 0

    if remainder != 0:
        input_tensor = input_tensor[:-remainder]

    num_splits = max(frames // chunk_size, 1)

    return tf.split(input_tensor, num_splits, axis=0)

def chunk_example(example: tf.Tensor,
                  chunk_size: int,
                  chunking_function: Callable) -> tf.data.Dataset:

    tensor_chunks = chunking_function(example, chunk_size=chunk_size)

    return tf.data.Dataset.from_tensor_slices(tensor_chunks)

def chunk_dataset(dataset: tf.data.Dataset, chunk_size: int, chunking_function: Callable) -> tf.data.Dataset:

    dataset = dataset.map(lambda example: chunk_example(example=example, chunk_size=chunk_size, chunking_function=chunking_function))
    dataset = dataset.interleave(lambda x: x, cycle_length=1, num_parallel_calls=tf.data.AUTOTUNE)

    return dataset

"""## Code to create a dummy dataset"""

def create_dataset_with_single_example(size: int):
  t = tf.random.uniform((size,), minval=0, maxval=100, dtype=tf.dtypes.int32)
  d = tf.data.Dataset.from_tensors(t)

  return d

def create_dataset(num_examples: int) -> tf.data.Dataset:
  examples = [create_dataset_with_single_example(n + 1) for n in range(num_examples)]

  dataset = tf.data.Dataset.from_tensor_slices(examples)
  dataset = dataset.interleave(lambda x: x, cycle_length=1, num_parallel_calls=tf.data.AUTOTUNE)

  return dataset

"""## Testing the chunking code with the dummy dataset"""

num_examples = 5

dataset = create_dataset(num_examples)

print(list(dataset.as_numpy_iterator()))

chunk_dataset(dataset, chunk_size=2, chunking_function=chunk_tensor_v1)

chunk_dataset(dataset, chunk_size=2, chunking_function=chunk_tensor_v2)

Errors

Using chunk_tensor_v1 leads to

InaccessibleTensorError: tf.Graph captured an external symbolic tensor. The symbolic tensor <tf.Tensor 'while/strided_slice:0' shape=(None,) dtype=int32> is captured by FuncGraph(name=Dataset_map_lambda, id=140570786598224), but it is defined at FuncGraph(name=while_body_485049, id=140570787725264). A tf.Graph is not allowed to capture symoblic tensors from another graph. Use return values, explicit Python locals or TensorFlow collections to access it. Please see https://www.tensorflow.org/guide/function#all_outputs_of_a_tffunction_must_be_return_values for more information.

As far as I can tell, the Python while loop is converted into a tf.while_loop (note the while_body FuncGraph in the message), so the chunk tensors are created inside the loop body's graph and cannot be collected in a Python list that is read outside the loop.

and chunk_tensor_v2 leads to

TypeError: '>' not supported between instances of 'NoneType' and 'int'
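
My understanding of this second error: inside dataset.map the static shape of each example is unknown, so input_tensor.shape[0] is None at trace time; the runtime length is only available as a tensor via tf.shape(input_tensor)[0]. A quick sketch to see this:

def show_shapes(x):
    print(x.shape[0])      # None -- static length unknown when the map function is traced
    print(tf.shape(x)[0])  # symbolic scalar tensor -- length only known at runtime
    return x

dataset.map(show_shapes)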

If someone knows how to simplify my problem further, I am glad to edit the question.

Upvotes: 3

Views: 900

Answers (1)

AloneTogether

Reputation: 26708

A bit tricky but definitely possible! You could try something like this:

Core part of the code (it can probably be simplified). The idea: filter the dataset into examples that are already short enough (dataset1) and examples that need chunking (dataset2), chunk the long ones with a tf.while_loop that writes into a TensorArray, flatten the chunks back into individual examples with flat_map, and concatenate the two datasets:

dataset1 = dataset.filter(lambda x: tf.less_equal(tf.shape(x)[0], chunk_size))
dataset2 = dataset.filter(lambda x: tf.greater(tf.shape(x)[0], chunk_size))

def body(i, m, n):
  n = n.write(n.size(), m[i:i+chunk_size])
  return tf.add(i, chunk_size), m, n

def split_data(data, chunk_size):
    length = tf.shape(data)[0]
    x = data[:(length // chunk_size) * chunk_size]
    ta = tf.TensorArray(dtype=tf.int32, size=0, dynamic_size=True)
    i0 = tf.constant(0)
    c = lambda i, m, n: tf.less(i, tf.shape(x)[0] - 1)
    _, _, out = tf.while_loop(c, body, loop_vars=[i0, x, ta])
    return out.stack()

dataset2 = dataset2.map(lambda x: split_data(x, chunk_size))
dataset2 = dataset2.flat_map(tf.data.Dataset.from_tensor_slices)
dataset = dataset1.concatenate(dataset2)

Whole code:

import tensorflow as tf
tf.random.set_seed(456)

def create_dataset_with_single_example(size: int):
  t = tf.random.uniform((size,), minval=0, maxval=100, dtype=tf.dtypes.int32)
  d = tf.data.Dataset.from_tensors(t)

  return d

def create_dataset(num_examples: int) -> tf.data.Dataset:
  examples = [create_dataset_with_single_example(n + 1) for n in range(num_examples)]

  dataset = tf.data.Dataset.from_tensor_slices(examples)
  dataset = dataset.interleave(lambda x: x, cycle_length=1, num_parallel_calls=tf.data.AUTOTUNE)

  return dataset

num_examples = 5
chunk_size = 2
dataset = create_dataset(num_examples)
print('Before --> \n')
for d in dataset:
  print(d)

dataset1 = dataset.filter(lambda x: tf.less_equal(tf.shape(x)[0], chunk_size))
dataset2 = dataset.filter(lambda x: tf.greater(tf.shape(x)[0], chunk_size))

def body(i, m, n):
  n = n.write(n.size(), m[i:i+chunk_size])
  return tf.add(i, chunk_size), m, n

def split_data(data, chunk_size):
    length = tf.shape(data)[0]
    x = data[:(length // chunk_size) * chunk_size]
    ta = tf.TensorArray(dtype=tf.int32, size=0, dynamic_size=True)
    i0 = tf.constant(0)
    c = lambda i, m, n: tf.less(i, tf.shape(x)[0] - 1)
    _, _, out = tf.while_loop(c, body, loop_vars=[i0, x, ta])
    return out.stack()

dataset2 = dataset2.map(lambda x: split_data(x, chunk_size))
dataset2 = dataset2.flat_map(tf.data.Dataset.from_tensor_slices)
dataset = dataset1.concatenate(dataset2)

print('\nAfter --> \n')
for d in dataset:
  print(d)

chunk_size = 2:

Before --> 

tf.Tensor([44], shape=(1,), dtype=int32)
tf.Tensor([23 10], shape=(2,), dtype=int32)
tf.Tensor([41 86  2], shape=(3,), dtype=int32)
tf.Tensor([54 78 20 93], shape=(4,), dtype=int32)
tf.Tensor([51 87 96 84 31], shape=(5,), dtype=int32)

After --> 

tf.Tensor([44], shape=(1,), dtype=int32)
tf.Tensor([23 10], shape=(2,), dtype=int32)
tf.Tensor([41 86], shape=(2,), dtype=int32)
tf.Tensor([54 78], shape=(2,), dtype=int32)
tf.Tensor([20 93], shape=(2,), dtype=int32)
tf.Tensor([51 87], shape=(2,), dtype=int32)
tf.Tensor([96 84], shape=(2,), dtype=int32)

chunk_size = 3:

Before --> 

tf.Tensor([44], shape=(1,), dtype=int32)
tf.Tensor([23 10], shape=(2,), dtype=int32)
tf.Tensor([41 86  2], shape=(3,), dtype=int32)
tf.Tensor([54 78 20 93], shape=(4,), dtype=int32)
tf.Tensor([51 87 96 84 31], shape=(5,), dtype=int32)

After --> 

tf.Tensor([44], shape=(1,), dtype=int32)
tf.Tensor([23 10], shape=(2,), dtype=int32)
tf.Tensor([41 86  2], shape=(3,), dtype=int32)
tf.Tensor([54 78 20], shape=(3,), dtype=int32)
tf.Tensor([51 87 96], shape=(3,), dtype=int32)

chunk_size = 4:

Before --> 

tf.Tensor([44], shape=(1,), dtype=int32)
tf.Tensor([23 10], shape=(2,), dtype=int32)
tf.Tensor([41 86  2], shape=(3,), dtype=int32)
tf.Tensor([54 78 20 93], shape=(4,), dtype=int32)
tf.Tensor([51 87 96 84 31], shape=(5,), dtype=int32)

After --> 

tf.Tensor([44], shape=(1,), dtype=int32)
tf.Tensor([23 10], shape=(2,), dtype=int32)
tf.Tensor([41 86  2], shape=(3,), dtype=int32)
tf.Tensor([54 78 20 93], shape=(4,), dtype=int32)
tf.Tensor([51 87 96 84], shape=(4,), dtype=int32)
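
As for simplifying: since every chunk is contiguous and has length chunk_size, the tf.while_loop and TensorArray could probably be replaced by a single tf.reshape (a minimal sketch, operating on the same dataset2 as above; the -1 dimension is resolved dynamically even in graph mode):

def split_data(data, chunk_size):
    # Truncate to a whole number of chunks, then view the result as
    # rows of length chunk_size; -1 is resolved from the runtime length.
    usable = (tf.shape(data)[0] // chunk_size) * chunk_size
    return tf.reshape(data[:usable], (-1, chunk_size))

dataset2 = dataset2.map(lambda x: split_data(x, chunk_size))
dataset2 = dataset2.flat_map(tf.data.Dataset.from_tensor_slices)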

Upvotes: 4
