greyBag

Reputation: 11

How can I get the indices of the data used in every batch?

I need to save the indices of the data that are used in every mini-batch.

For example if my data is:

x = np.array([[1.1], [2.2], [3.3], [4.4]])

and the first mini-batch is [1.1] and [3.3], then I want to store 0 and 2 (since [1.1] is the 0th observation and [3.3] is the 2nd observation).
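
For concreteness, here is a minimal NumPy sketch of the bookkeeping I have in mind (the values are just the example above):

import numpy as np

x = np.array([[1.1], [2.2], [3.3], [4.4]])
batch_indices = [0, 2]   # what I want to record for this mini-batch
print(x[batch_indices])  # rows 0 and 2 of x, i.e. the mini-batch [[1.1], [3.3]]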

I am using TensorFlow in eager execution with the Keras Sequential API.

As far as I can tell from reading the source code, this information is not stored anywhere, so I was unable to do this with a callback.

I am currently solving my problem by creating an object that stores the indices.

import numpy as np

class IndexIterator(object):
    def __init__(self, n, n_epochs, batch_size, shuffle=True):
        data_ix = np.arange(n)
        if shuffle:
            np.random.shuffle(data_ix)

        # Split the (shuffled) indices into ceil(n / batch_size) mini-batches.
        self.ix_batches = np.array_split(data_ix, int(np.ceil(n / batch_size)))
        self.batch_indices = []

    def generate_arrays(self, x, y):
        batch_ixs = np.arange(len(self.ix_batches))
        while 1:
            # Reshuffle the batch order at the start of each pass over the data.
            np.random.shuffle(batch_ixs)
            for batch in batch_ixs:
                # Record which observations go into this mini-batch.
                self.batch_indices.append(self.ix_batches[batch])
                yield (x[self.ix_batches[batch], :], y[self.ix_batches[batch], :])

data_gen = IndexIterator(n=32, n_epochs=100, batch_size=16)
dnn.fit_generator(data_gen.generate_arrays(x, y), 
                  steps_per_epoch=2, 
                  epochs=100)
# This is what I am looking for
print(data_gen.batch_indices)

Is there no way to do this using a tensorflow callback?

Upvotes: 1

Views: 3018

Answers (1)

zephyrus

Reputation: 1266

Not sure if this will be more efficient than your solution, but it is certainly more general.

If your training data has n examples, you can create a secondary Dataset that contains just the indices 0, ..., n-1 and zip it with the "real" dataset.

I.e.:

real_data = tf.data.Dataset ... 
indices = tf.data.Dataset.from_tensor_slices(tf.range(data_set_length))
total_dataset = tf.data.Dataset.zip((real_data, indices))

# Perform optional pre-processing ops.

iterator = total_dataset.make_one_shot_iterator()

# Next line yields `(original_data_element, index)`
item_and_index_tuple = iterator.get_next() 
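
Since you are running in eager execution, here is a minimal sketch of consuming the zipped dataset without the one-shot iterator (this assumes TF 2.x-style eager iteration and a manual training step in place of fit_generator; the toy x, y, and batch size are just for illustration):

import numpy as np
import tensorflow as tf

x = np.array([[1.1], [2.2], [3.3], [4.4]], dtype=np.float32)
y = np.array([[0.0], [1.0], [0.0], [1.0]], dtype=np.float32)

real_data = tf.data.Dataset.from_tensor_slices((x, y))
indices = tf.data.Dataset.from_tensor_slices(tf.range(len(x)))
total_dataset = tf.data.Dataset.zip((real_data, indices)).shuffle(len(x)).batch(2)

batch_indices = []
for (batch_x, batch_y), batch_ix in total_dataset:
    batch_indices.append(batch_ix.numpy())  # record which rows went into this batch
    # ... run your training step on (batch_x, batch_y) here ...

print(batch_indices)  # e.g. [array([0, 2], dtype=int32), array([1, 3], dtype=int32)]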


Upvotes: 2
