Manal

Reputation: 13

Tensorflow dataset batching for complex data

I tried to follow the example in this link:

https://www.tensorflow.org/programmers_guide/datasets

but I am totally lost about how to run the session. My understanding is that the first argument to sess.run is the operations to run, and that feed_dict maps the placeholders to the actual data (the batches of the training or test dataset).
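
For example, this is how I picture a minimal session run (a toy graph, not my actual one):

import tensorflow as tf

a = tf.placeholder(tf.float64, shape=[])
b = tf.placeholder(tf.float64, shape=[])
total = a + b

with tf.Session() as sess:
    # First argument: the ops to evaluate; feed_dict: values for the placeholders.
    print(sess.run(total, feed_dict={a: 1.0, b: 2.0}))  # 3.0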

So, here is my code:

batch_size = 100
handle_mix = tf.placeholder(tf.float64, shape=[])
handle_src0 = tf.placeholder(tf.float64, shape=[])
handle_src1 = tf.placeholder(tf.float64, shape=[])
handle_src2 = tf.placeholder(tf.float64, shape=[])
handle_src3 = tf.placeholder(tf.float64, shape=[])

I create the dataset from mp4 tracks and stems, reading the mixture and source magnitudes, and pad them so they are suitable for batching:

dataset = tf.data.Dataset.from_tensor_slices(
    {"x_mixed": padded_lbl, "y_src0": padded_src[0], "y_src1": padded_src[1],
     "y_src2": padded_src[2], "y_src3": padded_src[3]})
dataset = dataset.shuffle(1000).repeat().batch(batch_size)
iterator = tf.data.Iterator.from_structure(dataset.output_types,
                                           dataset.output_shapes)

From the example, I should do:

next_element = iterator.get_next()

training_init_op = iterator.make_initializer(dataset)
for _ in range(20):
    # Initialize an iterator over the training dataset.
    sess.run(training_init_op)
    for _ in range(100):
        sess.run(next_element)

However, I have loss, summary, and optimiser operations and need to feed the data in batches, following another example, like this:

l, _, summary = sess.run([loss_fn, optimizer, summary_op],
                         feed_dict={handle_mix: batch_mix, handle_src0: batch_src0,
                                    handle_src1: batch_src1, handle_src2: batch_src2,
                                    handle_src3: batch_src3})

So I thought of something like:

batch_mix, batch_src0, batch_src1, batch_src2, batch_src3 = data.train.next_batch(batch_size)

or maybe a separate run to fetch the batches first and then run the optimisation as above, such as:

batch_mix, batch_src0, batch_src1, batch_src2, batch_src3 = sess.run(next_element)
l, _, summary = sess.run([loss_fn, optimizer, summary_op],
                         feed_dict={handle_mix: batch_mix, handle_src0: batch_src0,
                                    handle_src1: batch_src1, handle_src2: batch_src2,
                                    handle_src3: batch_src3})

This last attempt returned the string keys of the dict created in tf.data.Dataset.from_tensor_slices ("x_mixed", "y_src0", etc.) rather than the batch values, so feeding them failed to cast to the tf.float64 placeholders in the session.

Can you please let me know how to create this dataset correctly (there might be an error in the from_tensor_slices structure in the first place), and then how to batch it?

Thank you very much.

Upvotes: 1

Views: 264

Answers (1)

xdurch0

Reputation: 10474

The issue is that you packed your data into a dict when creating the dataset from tensor slices. This will result in iterator.get_next() returning each batch as a dict as well. If we do something like

d = {"a": 1, "b": 2}
k1, k2 = d

we get k1 == "a" and k2 == "b" (or the other way around, since dict keys are unordered). That is, your attempt at unpacking the result of sess.run(next_element) just gives you the dict keys, whereas you are interested in the dict values (the tensors). This should work instead:

next_element = iterator.get_next()
x_mixed = next_element["x_mixed"]
y_src0 = next_element["y_src0"]
...

If you then build your model on the variables x_mixed etc., it should work fine. Note that with the tf.data API you don't need placeholders at all! TensorFlow sees that your model output requires, e.g., x_mixed, which is obtained from iterator.get_next(), so it will simply execute that op whenever you sess.run() your loss function, optimizer, etc. If you're more comfortable with placeholders you can of course keep using them; just remember to unpack the dict properly. This should be about right:

batch_dict = sess.run(next_element)
l, _, summary = sess.run([loss_fn, optimizer, summary_op], feed_dict={handle_mix: batch_dict["x_mixed"], ... })
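
For completeness, here is a minimal sketch of the placeholder-free variant, reusing iterator and training_init_op from the question (build_loss is a stand-in for whatever graph-building code you have, not a real function):

next_element = iterator.get_next()
x_mixed = next_element["x_mixed"]
y_src0 = next_element["y_src0"]
# ... likewise for y_src1, y_src2, y_src3 ...

# Build the graph directly on the iterator's output tensors.
loss_fn = build_loss(x_mixed, y_src0)  # hypothetical; use your own model/loss
optimizer = tf.train.AdamOptimizer().minimize(loss_fn)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(training_init_op)
    for _ in range(100):
        # No feed_dict: each run pulls a fresh batch through get_next().
        l, _ = sess.run([loss_fn, optimizer])

Since the iterator feeds the graph directly, the batches never take a round trip through Python, which also tends to be faster than the fetch-then-feed approach.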

Upvotes: 2
