mattdns

Reputation: 894

Passing bool to feed dict

Here is an example of applying batch normalization to a 1-D input vector. Batch normalization is computed over 100 training examples, xTr. Later, I want to test on just one example, xTe.

import tensorflow as tf
import numpy as np
from tensorflow.contrib.layers import layers

if __name__ == "__main__":
    bn = layers.batch_norm
    nFeats = 3 
    nObs = 100 
    xTr = np.random.rand(nObs,nFeats) # Train
    xTe = np.random.rand(1,nFeats) # Test
    bnTrain = tf.placeholder(tf.bool) 
    X = tf.placeholder(tf.float32,[None,nFeats])
    # The second positional argument of batch_norm is decay, not a feature count, so nFeats is not passed.
    Y = bn(X,is_training=bnTrain) # want to be able to change is_training via a feed_dict.
    init_op = tf.initialize_all_variables()
    with tf.Session() as sess:
        sess.run(init_op)
        yTr_ = Y.eval(feed_dict={X:xTr,bnTrain:True})
        yTe_ = Y.eval(feed_dict={X:xTe,bnTrain:False})

But I can't pass a tf.Tensor to a function that expects a plain Python bool. What is the best way to change this bool during a session?

Upvotes: 2

Views: 1006

Answers (1)

mrry

Reputation: 126184

The current implementation of the tf.contrib.layers.batch_norm() function is designed to accept a tf.Tensor as the is_training argument, although this doesn't appear to be documented. According to the revision history, this support was added in the TensorFlow 0.10 release. If you are using an older version, please try upgrading to the latest release (currently 0.12), and your existing code should work. Among other improvements, the latest release contains a fused implementation of batch normalization that should improve performance significantly.
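To illustrate what the is_training flag switches between, here is a rough NumPy-only sketch. It is not TensorFlow's implementation: the BatchNorm1D class, its decay and epsilon defaults, and the variable names are all hypothetical and chosen only for illustration. In training mode it normalizes with the statistics of the current batch and updates moving averages; otherwise it normalizes with the stored moving averages.

```python
import numpy as np


class BatchNorm1D:
    """Toy batch normalization over the batch axis (hypothetical helper, not TensorFlow code)."""

    def __init__(self, n_feats, decay=0.9, epsilon=1e-3):
        self.decay = decay
        self.epsilon = epsilon
        # Moving statistics used when is_training is False.
        self.moving_mean = np.zeros(n_feats)
        self.moving_var = np.ones(n_feats)

    def __call__(self, x, is_training):
        if is_training:
            # Normalize with the statistics of the current batch...
            mean = x.mean(axis=0)
            var = x.var(axis=0)
            # ...and fold them into the moving averages.
            self.moving_mean = self.decay * self.moving_mean + (1 - self.decay) * mean
            self.moving_var = self.decay * self.moving_var + (1 - self.decay) * var
        else:
            # Normalize with the stored moving averages; nothing is updated.
            mean, var = self.moving_mean, self.moving_var
        return (x - mean) / np.sqrt(var + self.epsilon)


rng = np.random.RandomState(0)
x_tr = rng.rand(100, 3)  # training batch, like xTr in the question
x_te = rng.rand(1, 3)    # single test example, like xTe in the question

bn = BatchNorm1D(n_feats=3)
y_tr = bn(x_tr, is_training=True)   # uses batch statistics
y_te = bn(x_te, is_training=False)  # uses moving averages
```

This also shows why the flag matters for the single test example: in training mode a batch of one is normalized by its own mean, so every output would be zero.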

Upvotes: 2
