Derk

Reputation: 1395

Tensorflow prefetch_to_device

I am trying out the new TensorFlow function tf.contrib.data.prefetch_to_device.

Here is a simple code example:

import time

import numpy as np
import tensorflow as tf

model = build_network()

N=1000

def gen():
    while True:
        batch = np.random.rand(N, 48, 48, 3)
        # Do some heavy calculation
        yield batch

dataset = tf.data.Dataset.from_generator(gen, tf.float32)
dataset = dataset.apply(tf.contrib.data.prefetch_to_device('/gpu:0'))

iterator = dataset.make_one_shot_iterator()
x = iterator.get_next()

output = model(x)

g = gen()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(100):
        if i == 50:
            start = time.time()
        result = sess.run(output)
        #result = model.predict(next(g))
    end = time.time()
print('\nAverage time of one forward pass: {}\n'.format((end-start)/50))
print('Done')

This gives the error:

InvalidArgumentError (see above for traceback): Cannot assign a device for operation 'IteratorGetDevice': Could not satisfy explicit device specification '/device:GPU:0' because no supported kernel for GPU devices is available.
Colocation Debug Info:
Colocation group had the following types and devices:
IteratorToStringHandle: CPU
IteratorGetDevice: CPU
OneShotIterator: CPU

Colocation members and user-requested devices:
  OneShotIterator (OneShotIterator)
  IteratorGetDevice (IteratorGetDevice) /device:GPU:0
  IteratorToStringHandle (IteratorToStringHandle)

Registered kernels:
  device='CPU'

[[Node: IteratorGetDevice = IteratorGetDevice_device="/device:GPU:0"]]

Is this new function not usable in combination with from_generator, or is something else wrong?

Upvotes: 2

Views: 1944

Answers (1)

mrry

Reputation: 126154

This is a bug in the TensorFlow 1.8rc0 release candidate. Thanks for bringing it to our attention!

It is now fixed in the master branch and will be picked up in the next nightly build. I have also filed a cherry-pick to the 1.8 release branch, and it should be included in the next release candidate (and final release) for TensorFlow 1.8.
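
Until a fixed build is installed, one possible stopgap is to fall back to ordinary host-side prefetching on the affected release. The following is only a minimal sketch: the helper names is_affected_release and add_prefetch are hypothetical, and it assumes the version string has the form "1.8.0-rc0" as reported by tf.__version__. It flags only rc0, since that is the only release named above.

```python
def is_affected_release(version):
    """Return True for the 1.8 rc0 release candidate discussed above."""
    # Assumption: the version string looks like "1.8.0-rc0"
    # (the form reported by tf.__version__).
    return version.strip().lower() == "1.8.0-rc0"


def add_prefetch(dataset, version, device="/gpu:0"):
    """Hypothetical helper: attach prefetching as the last pipeline step."""
    if is_affected_release(version):
        # Fallback: buffer one element on the host. The batch is copied
        # to the device where the model runs when it is consumed.
        return dataset.prefetch(1)
    # Imported lazily so the version check itself does not need TensorFlow.
    import tensorflow as tf
    return dataset.apply(tf.contrib.data.prefetch_to_device(device))
```

With the code from the question, this would be used as dataset = add_prefetch(dataset, tf.__version__) in place of the direct call to prefetch_to_device. The fallback does not stage data on the GPU ahead of time, so it may be slower than the fixed prefetch_to_device.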

Upvotes: 2
