mtngld

Reputation: 577

TensorFlow Dataset API: how to order list_files?

I am using the Dataset API list_files in order to get a list of files in a source directory and target directory, something like:

source_path = '/tmp/data/source/*.ext1'
target_path = '/tmp/data/target/*.ext2'
source_dataset = tf.data.Dataset.list_files(source_path)
target_dataset = tf.data.Dataset.list_files(target_path)
dataset = tf.data.Dataset.zip((source_dataset, target_dataset))

The source and target directories contain the same sequential filenames with different extensions (e.g., source 0001.ext1 <-> target 0001.ext2).

But since list_files does not return files in any particular order, the zipped dataset contains mismatches between the source and the target.

How can I solve this within the new dataset API?

Upvotes: 4

Views: 2698

Answers (2)

ForceBru

Reputation: 44888

I had the same issue and I solved it by sorting the file paths first.

My files are named like in OP's case:

input image       -> corresponding output
data/mband/01.tif -> data/gt_mband/01.tif
data/mband/02.tif -> data/gt_mband/02.tif

The code looks like this:

from pathlib import Path
import tensorflow as tf

DATA_PATH = Path("data")

# Sort the PATHS
img_paths = sorted(map(str, (DATA_PATH / 'mband').glob('*.tif')))
mask_paths = sorted(map(str, (DATA_PATH / 'gt_mband').glob('*.tif')))

# These are tensors of PATHS
# Paths are strings, so order will be preserved
img_paths = tf.data.Dataset.from_tensor_slices(img_paths)
mask_paths = tf.data.Dataset.from_tensor_slices(mask_paths)

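# Load the actual images
def parse_image(image_path: tf.Tensor) -> tf.Tensor:
    # Placeholder loader: returns the raw file bytes. Add the decoding step
    # your format needs (tf.io.decode_image does not handle TIFF).
    return tf.io.read_file(image_path)

# map() keeps element order, so imgs and masks stay aligned
imgs = img_paths.map(parse_image)
masks = mask_paths.map(parse_image)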
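
Sorting alone keeps the pairs aligned only if both directories hold exactly matching files. As a minimal sketch (the helper name `paired_paths` and its parameters are hypothetical, not part of TensorFlow), the two sorted listings can be checked against each other by filename stem before any dataset is built:

```python
from pathlib import Path
from typing import List, Tuple


def paired_paths(source_dir: str, target_dir: str,
                 source_ext: str, target_ext: str) -> Tuple[List[str], List[str]]:
    """Return source and target paths sorted by stem, verified to pair up."""
    # Sort by stem (filename without extension) so both lists share one order
    sources = sorted(Path(source_dir).glob(f"*{source_ext}"), key=lambda p: p.stem)
    targets = sorted(Path(target_dir).glob(f"*{target_ext}"), key=lambda p: p.stem)

    source_stems = [p.stem for p in sources]
    target_stems = [p.stem for p in targets]
    if source_stems != target_stems:
        # Report the stems that appear on one side only
        unmatched = sorted(set(source_stems) ^ set(target_stems))
        raise ValueError(f"Unpaired files (by stem): {unmatched}")

    return [str(p) for p in sources], [str(p) for p in targets]
```

The two returned lists can then be passed together, for example as `tf.data.Dataset.from_tensor_slices((sources, targets))`, which yields one (source, target) pair per element in list order.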

Upvotes: 0

jinhui chen

Reputation: 21

By default, list_files returns filenames in a non-deterministic, randomly shuffled order. Pass shuffle=False (or a seed) to get results in a deterministic order, and use the same setting for both the source and the target listing so that the zipped pairs stay aligned.

source_dataset = tf.data.Dataset.list_files(source_path, shuffle=False)
target_dataset = tf.data.Dataset.list_files(target_path, shuffle=False)

or

val = 5
source_dataset = tf.data.Dataset.list_files(source_path, seed=val)
target_dataset = tf.data.Dataset.list_files(target_path, seed=val)

Upvotes: 2
