citrusvanilla

Reputation: 59

How to duplicate input tensors conditional on a tensor attribute ("oversampling") in a TensorFlow queue?

I am porting the TensorFlow CIFAR-10 tutorial files for my own purposes, and have run into an interesting problem that I cannot easily conceptualize due to the graph-and-session architecture of TensorFlow.

The issue is that my input dataset is highly imbalanced, and as such I need to "oversample" (and augment) certain images in the input pipeline conditional on their labels. In a normal Python environment I could set up a simple control-flow statement of the form if label, then duplicate, but I am not able to write the same syntax in TensorFlow because the control-flow operation exists outside of a running session, and label in this case does not return a value.
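To make the failure mode concrete, here is a minimal sketch of the eager-style logic that cannot work at graph-construction time (decode_example, duplicate, and DOG_CLASS are hypothetical names, not part of my pipeline):

# This branches on a plain Python value, but at graph-construction
# time `label` is a symbolic tf.Tensor with no concrete value, so
# the `if` below has nothing to inspect.
image, label = decode_example(record)
if label == DOG_CLASS:                    # no value to branch on here
    examples = duplicate(image, factor=4)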

My question is: what is the easiest method for oversampling a tensor inside a TensorFlow queue?

I know that I could simply duplicate the data of interest prior to the input operation, but this obviously removes any storage savings gained by oversampling at runtime.

What I want to do is evaluate a tensor's label (in the CIFAR-10 case, by checking the 1D image.label attribute), duplicate that tensor by a fixed factor (say, 4x if the label is "dog"), and send all of the tensors down to the batching operation. My original approach was to attempt the duplication step after the Reader operation and before the batching operation, but this too lives outside of a running session. I was thinking of utilizing TF's while_loop control-flow statement, but I'm not sure that function has the ability to do anything other than modify the input tensor. What do you think?


Update #1

Basically, I attempted to create a py_func() that took in the flattened image bytes and the label bytes, vertically stacked the same image bytes N times depending on the value of the label, and returned the result as an (N x image_bytes) tensor (py_func() auto-converts the input tensors to NumPy arrays and back). I then attempted to create an input queue from the variable-height tensor, whose shape reports as (?, image_bytes), and to instantiate a reader to rip off image_bytes-sized records. It turns out that you can't build queues off of unknown data sizes, so this approach is not working for me. That makes sense in hindsight, but I still can't conceptualize a method for identifying a record in a queue and repeating that record a specific number of times.
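For reference, here is roughly what that attempt looked like (oversample_record and FACTORS are my own hypothetical names; image_flat and label are assumed to be the flattened image bytes and label coming off the reader):

import numpy as np
import tensorflow as tf

# Hypothetical reconstruction of the py_func() attempt: stack the
# flattened image bytes N times, where N depends on the label.
def oversample_record(image_bytes, label):
    n = FACTORS[int(label)]              # e.g. FACTORS = [1, 2, 4, 10]
    return np.tile(image_bytes, (n, 1))  # shape (N, image_bytes)

stacked = tf.py_func(oversample_record, [image_flat, label], tf.uint8)
# py_func outputs carry no static shape; even after calling
# stacked.set_shape([None, image_bytes]) the leading dimension stays
# unknown, which is what blocks building a fixed-size queue from it.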


Update #2

Well, after 48 hours I finally figured out a workaround, thanks to this SO thread that I was able to dig up. The solution outlined in that thread assumes only two classes of data, though, so the tf.cond() function suffices: oversample one class if pred is True, and the other if pred is False. To get an n-way conditional, I attempted to use the tf.case() function, which resulted in ValueError: Cannot infer Tensor's rank. It turns out that tf.case() does not retain shape properties, so graph construction fails, because any batching op at the end of the input pipeline must take a shapes argument or take tensors of fully defined shape, per this note in the documentation:

N.B.: You must ensure that either (i) the shapes argument is passed, or (ii) all of the tensors in tensors must have fully-defined shapes. ValueError will be raised if neither of these conditions holds.
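For illustration, the tf.case() attempt looked roughly like this (a sketch using the preds and callables defined in the solution below):

# Sketch of the failed n-way attempt: tf.case() drops static shapes.
images, label = tf.case({pred0: f0, pred1: f1, pred2: f2, pred3: f3},
                        default=lambda: [images, label],
                        exclusive=True)
# images.get_shape() now reports an unknown rank, so the
# tf.train.batch() call at the end of the pipeline raises ValueError
# unless an explicit shapes= argument is supplied.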

Further digging shows that this is a known issue with tf.case() that has yet to be resolved as of December 2016. Just one of the many control-flow head-scratchers in TensorFlow. Anyway, my stripped-down solution to the n-way oversampling issue is as follows:

# Initiate a queue of "raw" input data with an embedded QueueRunner.
# (string_input_producer expects a list of filename strings.)
queue = tf.train.string_input_producer([rawdata_filename])

# Instantiate Reader Op to read examples from files in the filename queue.
reader = tf.FixedLengthRecordReader(record_bytes)

# Pull off one instance, decode and cast image and label to 3D, 1D Tensors.
key, value = reader.read(queue)
image_raw, label_raw = decode(value)
image = tf.cast(image_raw, dtype) #3D tensor
label = tf.cast(label_raw, dtype) #1D tensor

# Assume your oversampling factors per class are fixed
# and you have 4 classes.
OVERSAMPLE_FACTOR = [1,2,4,10]

# Now we need to reshape the input image tensor to 4D, where the first
# dimension indexes images within a batch of oversampled tensors.
images = tf.expand_dims(image, 0)  # now (1, height, width, channels) in 4D

# Set up your predicates, which are 1D boolean tensors.
# Note you will have to squash the boolean tensors to 0-dimension.
# This seems illogical to me, but it is what it is.
pred0 = tf.reshape(tf.equal(label, tf.convert_to_tensor([0])), []) #0D tf.bool
pred1 = tf.reshape(tf.equal(label, tf.convert_to_tensor([1])), []) #0D tf.bool
pred2 = tf.reshape(tf.equal(label, tf.convert_to_tensor([2])), []) #0D tf.bool
pred3 = tf.reshape(tf.equal(label, tf.convert_to_tensor([3])), []) #0D tf.bool

# Build your callables (functions) that vertically stack the input
# image and label tensors X times, depending on the accompanying
# oversample factor. (tf.concat(0, values) is the pre-TF-1.0 argument
# order; in TF 1.0+ it is tf.concat(values, 0).)
def f0(): return (tf.concat(0, [images] * OVERSAMPLE_FACTOR[0]),
                  tf.concat(0, [label] * OVERSAMPLE_FACTOR[0]))
def f1(): return (tf.concat(0, [images] * OVERSAMPLE_FACTOR[1]),
                  tf.concat(0, [label] * OVERSAMPLE_FACTOR[1]))
def f2(): return (tf.concat(0, [images] * OVERSAMPLE_FACTOR[2]),
                  tf.concat(0, [label] * OVERSAMPLE_FACTOR[2]))
def f3(): return (tf.concat(0, [images] * OVERSAMPLE_FACTOR[3]),
                  tf.concat(0, [label] * OVERSAMPLE_FACTOR[3]))

# Here we have N conditionals, one for each class. They are mutually
# exclusive, but because tf.case() misbehaves (see above), we chain
# tf.cond() calls instead, and every conditional gets evaluated.
[images, label] = tf.cond(pred0, f0, lambda: [images,label])
[images, label] = tf.cond(pred1, f1, lambda: [images,label])
[images, label] = tf.cond(pred2, f2, lambda: [images,label])
[images, label] = tf.cond(pred3, f3, lambda: [images,label])

# Pass the 4D batch of oversampled tensors to a batching op at the end
# of the input data queue.  The batching op must be set up to accept
# batches of tensors (4D) as opposed to individual tensors (in our case, 3D).
images, label_batch = tf.train.batch([images, label],
                                     batch_size=batch_size,
                                     num_threads=num_threads,
                                     capacity=capacity,
                                     enqueue_many=True)  # accept batches
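For completeness, a standard driver loop for this kind of queue-based pipeline might look like the following (my own sketch; it assumes the graph above has been built, with images and label_batch being the batching op's outputs):

# Start the queue runners and pull one oversampled batch.
with tf.Session() as sess:
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        image_batch_np, label_batch_np = sess.run([images, label_batch])
    finally:
        coord.request_stop()
        coord.join(threads)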

Upvotes: 1

Views: 898

Answers (1)

citrusvanilla

Reputation: 59

The solution to my problem is a workaround, and is outlined in 'Update #2' of the original question.

Upvotes: 1
