NoviceProg

Reputation: 825

Filter Tensorflow dataset by id

ISSUE

I'm trying to filter a TensorFlow 2.4 dataset based on a NumPy array containing the indices of the elements I want to keep. The dataset has 1000 images of shape (28,28,1).

TOY EXAMPLE CODE

import numpy as np
import tensorflow as tf

m_X_ds = tf.data.Dataset.from_tensor_slices(list(range(1, 21))).shuffle(10, reshuffle_each_iteration=False)
arr = np.array([3, 4, 5])
m_X_ds = tf.gather(m_X_ds, arr)  # This is the offending code

ERROR MESSAGE

ValueError: Attempt to convert a value (<ShuffleDataset shapes: (), types: tf.int32>) with an unsupported type (<class 'tensorflow.python.data.ops.dataset_ops.ShuffleDataset'>) to a Tensor.

RESEARCH SO FAR

I found this and this, but they are specific to their own use cases, whereas I'm looking for a more general way to subset (i.e., based on an externally derived array of indices).

I'm very new to TensorFlow datasets and to date have found the learning curve quite steep. Hope to get some help. Thanks in advance!

Upvotes: 2

Views: 2680

Answers (1)

sebastian-sz

Reputation: 1488

TL;DR

I recommend Option C, described below.

Full answer

A tf.data.Dataset is designed to stream its elements, so the whole dataset never has to be loaded into memory at once. Because of that, a dataset is not a tensor, and tf.gather (which expects a tensor) cannot index into it directly. There are three options you can go with.
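A quick check makes the difference concrete. This is a minimal sketch assuming TF 2.x in eager mode, with the question's toy values as a plain unshuffled dataset: the dataset is not a tf.Tensor, so gathering from it fails, while the same values held in a real tensor gather fine.

```python
import numpy as np
import tensorflow as tf

arr = np.array([3, 4, 5])

# A tf.data pipeline is a lazy, iterable description of the data, not a tensor.
ds = tf.data.Dataset.from_tensor_slices(list(range(1, 21)))
print(isinstance(ds, tf.Tensor))  # False

# tf.gather needs its input as a tensor, and a Dataset cannot be converted...
try:
    tf.gather(ds, arr)
except (TypeError, ValueError) as err:
    print(type(err).__name__, "raised")

# ...while the same values held in an actual tensor can be gathered directly.
values = tf.constant(list(range(1, 21)))
picked = tf.gather(values, arr)
print(picked.numpy())  # [4 5 6]
```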

Option A: Load ds into memory + tf.gather

If you want to use tf.gather, you have to load the entire dataset into memory first and then select the subset:

m_X_ds = list(m_X_ds)  # Load every element into memory as a list of tensors.
m_X_ds = tf.gather(m_X_ds, arr)  # Gather as usual; the list is converted to a tensor.
print(m_X_ds)
# Example result (depends on the shuffle): <tf.Tensor: shape=(3,), dtype=int32, numpy=array([8, 6, 2], dtype=int32)>

Note that this is not always possible, especially with datasets that are too large to fit in memory.
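For image data like the question's, the same idea looks like the sketch below. It assumes a small synthetic stand-in (10 images of shape (28, 28, 1), with every pixel of image i set to i so the selection is easy to check) rather than the real 1000-image dataset; tf.stack turns the list of elements into one tensor before gathering. For scale, assuming float32 pixels, 1000 images of this size take about 3.1 MB (1000 * 28 * 28 * 4 bytes), so Option A should be practical for a dataset of that size.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for the question's data: 10 images of shape (28, 28, 1),
# where every pixel of image i equals i.
images = np.stack([np.full((28, 28, 1), i, dtype=np.float32) for i in range(10)])
ds = tf.data.Dataset.from_tensor_slices(images)

arr = np.array([3, 4, 5])

# Materialise every element, stack them into one tensor, then gather by index.
all_images = tf.stack(list(ds))      # shape (10, 28, 28, 1)
subset = tf.gather(all_images, arr)  # shape (3, 28, 28, 1)

print(subset.shape)                # (3, 28, 28, 1)
print(subset[:, 0, 0, 0].numpy())  # [3. 4. 5.]
```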

Option B: Iterate over the dataset, and filter undesired samples

You could also iterate over the dataset and keep only the samples whose index is in arr. This is possible by combining Dataset.filter with tf.py_function:

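m_X_ds = m_X_ds.enumerate()  # Create (index, value) pairs in the dataset.

# Create the filter function. Inside tf.py_function it runs eagerly,
# so idx is an eager tensor and idx.numpy() gives a plain integer.
def filter_fn(idx, value):
    return idx.numpy() in arr

# A plain Python membership test does not work in graph mode, which is how
# Dataset.filter traces its predicate, so we wrap it with tf.py_function
# to execute it eagerly.
def py_function_filter(idx, value):
    keep = tf.py_function(filter_fn, (idx, value), tf.bool)
    keep.set_shape([])  # filter expects a scalar boolean predicate.
    return keep

# Filter the dataset as usual:
filtered_ds = m_X_ds.filter(py_function_filter)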

Option C: combine option B with tf.lookup.StaticHashTable

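Option B works, but you can expect a performance hit: every element has to leave the graph, be converted from a graph tensor to an eager tensor so the Python function can run, and then be converted back. tf.py_function also acquires the Python GIL, which limits parallel execution. It is useful, but it comes at a cost.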

Instead, we can use a lookup in which the desired indices return True and all other indices return False. As a Python dict, it would look like this:

my_table = {3: True, 4: True, 5: True}

We cannot use a tensor as a dictionary key (symbolic tensors are not hashable), but we can declare a TensorFlow hash table, tf.lookup.StaticHashTable, to check for "good" indices.
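In isolation, such a table behaves like the dict above, except that it returns a default value instead of raising for missing keys. A minimal sketch, assuming int64 keys to match the indices that enumerate() produces, and probing a handful of candidate indices at once:

```python
import numpy as np
import tensorflow as tf

arr = np.array([3, 4, 5])

# Keys are the wanted indices (int64, like enumerate()'s indices); each maps to 1.
keys = tf.constant(arr, dtype=tf.int64)
table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(keys, tf.ones_like(keys)),
    default_value=0)  # Any key not in the table returns 0.

# Look up a batch of candidate indices at once.
probe = tf.constant([2, 3, 4, 5, 6], dtype=tf.int64)
hits = table.lookup(probe)
print(hits.numpy())                             # [0 1 1 1 0]
print(tf.cast(hits, tf.bool).numpy().tolist())  # [False, True, True, True, False]
```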

m_X_ds = m_X_ds.enumerate()  # Skip this line if you already ran it for Option B.

# enumerate() produces int64 indices, so the keys must be int64 as well.
keys_tensor = tf.constant(arr, dtype=tf.int64)
vals_tensor = tf.ones_like(keys_tensor)  # Ones will be cast to True.

table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(keys_tensor, vals_tensor),
    default_value=0)  # If the index is not in the table, return 0.


def hash_table_filter(index, value):
    table_value = table.lookup(index)  # 1 if index in arr, else 0.
    index_in_arr = tf.cast(table_value, tf.bool)  # 1 -> True, 0 -> False
    return index_in_arr

filtered_ds = m_X_ds.filter(hash_table_filter)
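A swapped-in alternative, not one of the options above: when arr is small, a predicate built only from TensorFlow ops (tf.equal plus tf.reduce_any) also stays in graph mode without a hash table. The sketch assumes the unshuffled toy data so the result is easy to check. Each call compares the index against every entry of arr, so its cost grows with len(arr), whereas a hash-table lookup stays roughly constant; for large index arrays, Option C remains the better fit.

```python
import numpy as np
import tensorflow as tf

arr = np.array([3, 4, 5])
# Cast once to int64 so it matches the index dtype produced by enumerate().
wanted = tf.constant(arr, dtype=tf.int64)

ds = tf.data.Dataset.from_tensor_slices(list(range(1, 21))).enumerate()

def membership_filter(index, value):
    # True if index equals any wanted index; pure TF ops, so no py_function.
    return tf.reduce_any(tf.equal(index, wanted))

kept = [int(v) for _, v in ds.filter(membership_filter)]
print(kept)  # [4, 5, 6]
```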

Whichever of Option B or C you use, all that is left is to drop the index from the filtered dataset. A simple map with a lambda function does that:

final_ds = filtered_ds.map(lambda idx, value: value)
for entry in final_ds:
    print(entry)

# Example output (the values depend on the shuffle):
# tf.Tensor(7, shape=(), dtype=int32)
# tf.Tensor(13, shape=(), dtype=int32)
# tf.Tensor(6, shape=(), dtype=int32)
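One way to sanity-check Option C is to compare it with Option A on the question's toy dataset. This sketch adds a fixed seed (an assumption, not in the question) and relies on reshuffle_each_iteration=False so both passes see the same element order. Keep in mind that Option A returns elements in the order of arr, while Option C returns them in dataset order; the two lists match here only because arr is sorted.

```python
import numpy as np
import tensorflow as tf

# The question's toy dataset; a fixed seed plus reshuffle_each_iteration=False
# keeps the element order identical across the two passes below.
m_X_ds = tf.data.Dataset.from_tensor_slices(list(range(1, 21))).shuffle(
    10, seed=42, reshuffle_each_iteration=False)
arr = np.array([3, 4, 5])

# Option A: materialise and gather (result follows the order of arr).
option_a = tf.gather(tf.stack(list(m_X_ds)), arr).numpy().tolist()

# Option C: hash-table filter on the enumerated dataset, then drop the index.
keys = tf.constant(arr, dtype=tf.int64)
table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(keys, tf.ones_like(keys)),
    default_value=0)
option_c_ds = (m_X_ds.enumerate()
               .filter(lambda idx, value: tf.cast(table.lookup(idx), tf.bool))
               .map(lambda idx, value: value))
option_c = [int(v) for v in option_c_ds]

# Option C yields elements in dataset order; arr is sorted, so the lists match.
print(option_a == option_c)  # True
```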

Best of luck.

Upvotes: 8
