Reputation: 545
Given a TensorFlow dataset:
Train_dataset =,Train_Image_Labels))
Train_dataset =
I would like to stratify my batches to deal with class imbalance. I found tf.contrib.training.stratified_sample and thought I could use it in the following way:
Train_dataset_iter = Train_dataset.make_one_shot_iterator()
Train_dataset_Image_Batch,Train_dataset_Label_Batch = Train_dataset_iter.get_next()
Train_Stratified_Images,Train_Stratified_Labels = tf.contrib.training.stratified_sample(Train_dataset_Image_Batch,Train_dataset_Label_Batch,[1/Classes]*Classes,Batch_Size)
But it gives the following error, and I'm not sure this approach would let me keep the performance benefits of the TensorFlow dataset, since I might then have to pass Train_Stratified_Images and Train_Stratified_Labels via feed_dict.
  File "/xxx/xxx/anaconda3/lib/python3.6/site-packages/tensorflow/contrib/training/python/training/", line 192, in stratified_sample
    with ops.name_scope(name, 'stratified_sample', list(tensors) + [labels]):
  File "/xxx/xxx/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/", line 459, in __iter__
    "Tensor objects are only iterable when eager execution is "
TypeError: Tensor objects are only iterable when eager execution is enabled. To iterate over this tensor use tf.map_fn.
What would be the "best practice" way of using dataset with stratified batches?
Upvotes: 4
Views: 4687
Reputation: 2108
Below is a simple example demonstrating the use of sample_from_datasets (thanks @Agade for the idea).
import math
import tensorflow as tf
import numpy as np

def print_dataset(name, dataset):
    # Iterating over a dataset like this requires eager execution.
    elems = np.array([v.numpy() for v in dataset])
    print("Dataset {} contains {} elements :".format(name, len(elems)))
    print(elems)

def combine_datasets_balanced(dataset_smaller, size_smaller, dataset_bigger, size_bigger, batch_size):
    # Repeat the smaller dataset so that the 2 datasets are about the same size.
    ds_smaller_repeated = dataset_smaller.repeat(count=int(math.ceil(size_bigger / size_smaller)))
    # Each element of the resulting dataset is randomly drawn (without replacement)
    # from the smaller dataset with proba 0.5 or from the bigger one with proba 0.5.
    balanced_dataset = tf.data.experimental.sample_from_datasets([ds_smaller_repeated, dataset_bigger], weights=[0.5, 0.5])
    balanced_dataset = balanced_dataset.take(2 * size_bigger).batch(batch_size)
    return balanced_dataset

N, M = 3, 10
even = tf.data.Dataset.range(0, 2 * N, 2).repeat(count=int(math.ceil(M / N)))
odd = tf.data.Dataset.range(1, 2 * M, 2)
even_odd = combine_datasets_balanced(even, N, odd, M, 2)

print_dataset("even", even)
print_dataset("odd", odd)
print_dataset("even_odd", even_odd)
Output :
Dataset even contains 12 elements : # 12 = 4 x N (because of .repeat)
[0 2 4 0 2 4 0 2 4 0 2 4]
Dataset odd contains 10 elements :
[ 1  3  5  7  9 11 13 15 17 19]
Dataset even_odd contains 10 elements : # 10 = 2 x M / 2 (2xM because of .take(2 * M) and /2 because of .batch(2))
[[ 0  2]
 [ 1  4]
 [ 0  2]
 [ 3  4]
 [ 0  2]
 [ 4  0]
 [ 5  2]
 [ 7  4]
 [ 0  9]
 [ 2 11]]
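The same idea can be extended from two datasets to one stream per class, which is closer to the stratified batches asked about in the question. The following is only a minimal sketch, assuming TensorFlow 2.x with eager execution, integer class labels, and at least one example of every class. The helper names stratified_batches and _is_class and the toy data are hypothetical and not part of the original posts. Recent TensorFlow releases also expose the same sampling function as tf.data.Dataset.sample_from_datasets.

```python
import numpy as np
import tensorflow as tf


def _is_class(c):
    # Hypothetical helper: builds a predicate that keeps only examples of class c.
    return lambda x, y: tf.equal(y, c)


def stratified_batches(features, labels, num_classes, batch_size, seed=None):
    """Hypothetical helper: batches in which every class is equally likely."""
    dataset = tf.data.Dataset.from_tensor_slices((features, labels))
    # One endless stream per class; repeat() lets a rare class be oversampled
    # instead of running out.
    per_class = [dataset.filter(_is_class(c)).repeat() for c in range(num_classes)]
    # Each element is drawn from class c with probability 1 / num_classes.
    weights = [1.0 / num_classes] * num_classes
    mixed = tf.data.experimental.sample_from_datasets(
        per_class, weights=weights, seed=seed
    )
    # The result is infinite: bound it with .take(n) before iterating.
    return mixed.batch(batch_size)


# Toy imbalanced data (an assumption for illustration): 90 examples of
# class 0 and 10 examples of class 1.
features = np.arange(100)
labels = np.array([0] * 90 + [1] * 10)

batches = stratified_batches(features, labels, num_classes=2, batch_size=32, seed=0)
for x, y in batches.take(3):
    print(y.numpy())
```

Because everything stays inside the tf.data pipeline, the batches can be consumed directly by the training loop without going through feed_dict. Within each class the examples come out in their original order here, so a per-class shuffle before repeat() may be worth adding in practice.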
Upvotes: 1