Lukasz Tracewski

Reputation: 11377

How to shape TFRecordDataset to meet Model API?

I am building a model based on this code for noise suppression. My problem with the vanilla implementation is that it loads all data at once, which is not the best idea when the training data gets really large; my input file, denoted in the linked code as training.h5, is over 30 GB.

I decided to use the tf.data interface instead, which should allow me to work with large data sets. My problem is that I don't know how to shape a TFRecordDataset so that it meets what the Model API requires.

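If you check model.fit(x_train, [y_train, vad_train], ...) in the linked code, it essentially requires the following shapes:

x_train: [nb_sequences, window, 42]
y_train: [nb_sequences, window, 22]
vad_train: [nb_sequences, window, 1]

The window is typically fixed (in the code: 2000), so the only variable is nb_sequences, which depends on how large your data set is. However, with tf.data we don't supply x and y separately, but only x (see the Model API docs).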
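To make the [nb_sequences, window, features] layout concrete, here is a minimal NumPy sketch of how an in-memory array would be cut into windows. The random array stands in for the real training data, and the small n_rows and window values are illustrative assumptions, not the values from the linked code.

```python
import numpy as np

window = 4    # illustrative; the linked code uses 2000
n_rows = 10   # illustrative; stands in for the real number of rows
data = np.random.rand(n_rows, 65).astype("float32")

# Drop the trailing rows that do not fill a complete window.
nb_sequences = n_rows // window
data = data[: nb_sequences * window]

# Split the columns (42 inputs, 22 targets, 1 VAD flag), then add the window axis.
x_train = data[:, :42].reshape(nb_sequences, window, 42)
y_train = data[:, 42:64].reshape(nb_sequences, window, 22)
vad_train = data[:, 64:65].reshape(nb_sequences, window, 1)

print(x_train.shape, y_train.shape, vad_train.shape)
# (2, 4, 42) (2, 4, 22) (2, 4, 1)
```

This is the in-memory equivalent of what the tf.data pipeline has to produce lazily.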

Saving tfrecord to file

In an effort to make the code reproducible, I created the input file with the following code:

import tensorflow as tf

# data: array of shape [10000, 65]; columns 0-41 are X, 42-63 are y, 64 is vad
writer = tf.io.TFRecordWriter(path='example.tfrecord')
for record in data:
    feature = {}
    feature['X'] = tf.train.Feature(float_list=tf.train.FloatList(value=record[:42]))
    feature['y'] = tf.train.Feature(float_list=tf.train.FloatList(value=record[42:64]))
    feature['vad'] = tf.train.Feature(float_list=tf.train.FloatList(value=[record[64]]))
    # One tf.train.Example per row, serialized and appended to the file
    example = tf.train.Example(features=tf.train.Features(feature=feature))
    serialized = example.SerializeToString()
    writer.write(serialized)
writer.close()

Here data is the training data, with shape [10000, 65]. My example.tfrecord is available here. It is 3 MB; the real file would be 30 GB+.

You might notice that in the linked code the NumPy array has shape [x, 87], while mine is [x, 65]. That's OK: the remaining columns are not used anywhere.
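One way to sanity-check such a file is to write a small synthetic one and decode its first record. The sketch below assumes TensorFlow 2.x with eager execution; the 100-row random array and the temporary path are stand-ins for the real data and file, not part of the original setup.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Synthetic stand-in for the real training data (assumption: 100 rows, 65 columns).
data = np.random.rand(100, 65).astype("float32")
path = os.path.join(tempfile.mkdtemp(), "example.tfrecord")

# Same feature layout as in the question: X (42), y (22), vad (1).
with tf.io.TFRecordWriter(path) as writer:
    for record in data:
        feature = {
            "X": tf.train.Feature(float_list=tf.train.FloatList(value=record[:42])),
            "y": tf.train.Feature(float_list=tf.train.FloatList(value=record[42:64])),
            "vad": tf.train.Feature(float_list=tf.train.FloatList(value=[record[64]])),
        }
        example = tf.train.Example(features=tf.train.Features(feature=feature))
        writer.write(example.SerializeToString())

# Read the first serialized record back and decode it without a parsing schema.
raw = next(iter(tf.data.TFRecordDataset(path)))
parsed = tf.train.Example.FromString(raw.numpy())
lengths = {k: len(v.float_list.value) for k, v in parsed.features.feature.items()}
print(lengths)
```

If the writer worked as intended, lengths maps X to 42, y to 22, and vad to 1, which matches the FixedLenFeature sizes used when parsing.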

Loading the dataset with tf.data.TFRecordDataset

I would like to use tf.data to load the data on demand with some prefetching, since there's no need to keep it all in memory. My attempt:

import datetime
import numpy as np
import h5py
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import GRU
from tensorflow.keras import regularizers
from tensorflow.keras.constraints import Constraint
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras import backend as K
from tensorflow.keras.layers import concatenate

def load_dataset(path):
    def _parse_function(example_proto):
        keys_to_features = {
            'X': tf.io.FixedLenFeature([42], tf.float32),
            'y': tf.io.FixedLenFeature([22], tf.float32),
            'vad': tf.io.FixedLenFeature([1], tf.float32)
        }
        features = tf.io.parse_single_example(example_proto, keys_to_features)
        return (features['X'], (features['y'], features['vad']))

    dataset = tf.data.TFRecordDataset(path).map(_parse_function)
    return dataset


def my_crossentropy(y_true, y_pred):
    return K.mean(2 * K.abs(y_true - 0.5) * K.binary_crossentropy(y_pred, y_true), axis=-1)


def mymask(y_true):
    return K.minimum(y_true + 1., 1.)


def msse(y_true, y_pred):
    return K.mean(mymask(y_true) * K.square(K.sqrt(y_pred) - K.sqrt(y_true)), axis=-1)


def mycost(y_true, y_pred):
    return K.mean(mymask(y_true) * (10 * K.square(K.square(K.sqrt(y_pred) - K.sqrt(y_true))) + K.square(
        K.sqrt(y_pred) - K.sqrt(y_true)) + 0.01 * K.binary_crossentropy(y_pred, y_true)), axis=-1)


def my_accuracy(y_true, y_pred):
    return K.mean(2 * K.abs(y_true - 0.5) * K.cast(K.equal(y_true, K.round(y_pred)), K.floatx()), axis=-1)


class WeightClip(Constraint):
    '''Clips the weights incident to each hidden unit to be inside a range
    '''

    def __init__(self, c=2.0):
        self.c = c

    def __call__(self, p):
        return K.clip(p, -self.c, self.c)

    def get_config(self):
        return {'name': self.__class__.__name__,
                'c': self.c}

def build_model():
    reg = 0.000001
    constraint = WeightClip(0.499)
    main_input = Input(shape=(None, 42), name='main_input')
    tmp = Dense(24, activation='tanh', name='input_dense', kernel_constraint=constraint, bias_constraint=constraint)(
        main_input)
    vad_gru = GRU(24, activation='tanh', recurrent_activation='sigmoid', return_sequences=True, name='vad_gru',
                  kernel_regularizer=regularizers.l2(reg), recurrent_regularizer=regularizers.l2(reg),
                  kernel_constraint=constraint, recurrent_constraint=constraint, bias_constraint=constraint)(tmp)
    vad_output = Dense(1, activation='sigmoid', name='vad_output', kernel_constraint=constraint,
                       bias_constraint=constraint)(vad_gru)
    noise_input = concatenate([tmp, vad_gru, main_input])
    noise_gru = GRU(48, activation='relu', recurrent_activation='sigmoid', return_sequences=True, name='noise_gru',
                    kernel_regularizer=regularizers.l2(reg), recurrent_regularizer=regularizers.l2(reg),
                    kernel_constraint=constraint, recurrent_constraint=constraint, bias_constraint=constraint)(noise_input)
    denoise_input = concatenate([vad_gru, noise_gru, main_input])

    denoise_gru = GRU(96, activation='tanh', recurrent_activation='sigmoid', return_sequences=True, name='denoise_gru',
                      kernel_regularizer=regularizers.l2(reg), recurrent_regularizer=regularizers.l2(reg),
                      kernel_constraint=constraint, recurrent_constraint=constraint, bias_constraint=constraint)(
        denoise_input)

    denoise_output = Dense(22, activation='sigmoid', name='denoise_output', kernel_constraint=constraint,
                           bias_constraint=constraint)(denoise_gru)

    model = Model(inputs=main_input, outputs=[denoise_output, vad_output])

    model.compile(loss=[mycost, my_crossentropy],
                  metrics=[msse],
                  optimizer='adam', loss_weights=[10, 0.5])
    return model

model = build_model()
dataset = load_dataset('example.tfrecord')

My dataset now has the following shape:

<MapDataset shapes: ((42,), ((22,), (1,))), types: (tf.float32, (tf.float32, tf.float32))>

which I thought was what the Model API expects (spoiler: it isn't).

model.fit(dataset.batch(10))

gives the following error:

ValueError: Error when checking input: expected main_input to have 3 dimensions, but got array with shape (None, 42)

That makes sense: I don't have the window dimension here. At the same time, it seems the dataset does not have the shape expected by Model(inputs=main_input, outputs=[denoise_output, vad_output]).

How should I modify load_dataset so that it matches what the Model API expects for tf.data?

Upvotes: 1

Views: 531

Answers (1)

sebastian-sz

Reputation: 1488

Given that your model has 1 input and 2 outputs, your tf.data.Dataset should have two entries:
1) Input array of shape (window, 42)
2) Tuple of two arrays, of shapes (window, 22) and (window, 1)

EDIT: Updated answer - you already return a two-element tuple

I just noticed that your dataset already has these two entries (as described above) and the only thing that differs is the shape.
The only operation you need to perform is to batch your data twice:
first to restore the window dimension, second to pass a batch to the model.

window_size = 1
batch_size = 10
dataset = load_dataset('example.tfrecord')
model.fit(dataset.batch(window_size).batch(batch_size))

And that should work.
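To see what the double batching does to the shapes, here is a minimal sketch on synthetic data. The in-memory dataset built with from_tensor_slices is a stand-in for load_dataset('example.tfrecord'), the sizes are illustrative, and drop_remainder=True is an addition not in the snippet above: it discards a final incomplete window so that every window has the same length. TensorFlow 2.x with eager execution is assumed.

```python
import numpy as np
import tensorflow as tf

window_size = 5   # illustrative; the question mentions 2000
batch_size = 2    # illustrative
n_rows = 23       # deliberately not a multiple of window_size

# Stand-in with the same per-record structure as load_dataset():
# (X, (y, vad)) with shapes (42,), ((22,), (1,)).
x = np.random.rand(n_rows, 42).astype("float32")
y = np.random.rand(n_rows, 22).astype("float32")
vad = np.random.rand(n_rows, 1).astype("float32")
dataset = tf.data.Dataset.from_tensor_slices((x, (y, vad)))

# First batch: group consecutive records into non-overlapping windows.
windows = dataset.batch(window_size, drop_remainder=True)
# Second batch: group windows into training batches.
batches = windows.batch(batch_size)

shapes = [
    (tuple(xb.shape.as_list()), tuple(yb.shape.as_list()), tuple(vb.shape.as_list()))
    for xb, (yb, vb) in batches
]
for s in shapes:
    print(s)
# ((2, 5, 42), (2, 5, 22), (2, 5, 1))
# ((2, 5, 42), (2, 5, 22), (2, 5, 1))
```

The 23 rows yield 4 full windows (the last 3 rows are dropped), which become 2 batches of 2 windows each. Each element then has the (batch, window, features) layout that main_input with shape (None, 42) accepts.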

Below is the old answer, where I wrongly assumed you were returning a three-element tuple:

Assuming that you start from a three-element tuple of shapes (42,), (22,) and (1,), the same result can be achieved with the same batching operations, followed by a custom_reshape function that returns a two-element tuple:

window_size = 1
batch_size = 10
dataset = load_dataset('example.tfrecord')
dataset = dataset.batch(window_size).batch(batch_size)

# Change output format
def custom_reshape(x, y, vad):
    return x, (y, vad)

dataset = dataset.map(custom_reshape)

In short, given this dataset shape, you could just call model.fit(dataset.batch(window_size).batch(10).map(custom_reshape))
and it should work too.

Best of luck. And sorry again for the fuss.

Upvotes: 2
