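My code is:

```python
import tensorflow as tf  # TensorFlow 1.x API

# X_train, y_train, batch_size and init are defined earlier.
with tf.Session() as sess:
    init.run()
    epoch = 1
    iteration = 1
    print("Checkpoint 1")
    X_batch, y_batch = tf.train.batch([X_train, y_train], batch_size=batch_size)
    print("Checkpoint 2")
    X = X_batch.eval()  # execution hangs here
    y = y_batch.eval()
    print("Checkpoint 3")
```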
The problem is that execution gets stuck on the last two `eval()` lines without producing any output (only checkpoints 1 and 2 are printed). I've searched on Google, and converting a `tensorflow.python.framework.ops.Tensor` to a `numpy.ndarray` seems to be a trivial operation.
I also tried passing the session explicitly, in case it makes a difference:

```python
X = X_batch.eval(session=sess)
y = y_batch.eval(session=sess)
```
Edit: I tried using an `InteractiveSession` and the problem remains:

```python
sess = tf.InteractiveSession()
X_batch, y_batch = tf.train.batch([X_train, y_train], batch_size=batch_size)
print(type(X_batch))
print(type(y_batch))
print(type(X_batch.eval()))  # hangs here as well
sess.close()
```
You need to start the queue runner hidden inside `tf.train.batch`, for example by using a `tf.train.Coordinator` (have a look at this or this for more insight). From the method's docstring (emphasis mine):
> This function is implemented using a queue. A `QueueRunner` for the queue is added to the current `Graph`'s `QUEUE_RUNNER` collection.

and

> The returned operation is a dequeue operation and will throw `tf.errors.OutOfRangeError` if the input queue is exhausted. If this operation is feeding another input queue, its queue runner will catch this exception, however, if this operation is used in your main thread you are responsible for catching this yourself.
Since you never started the queue runner, nothing is ever enqueued, so the thread calling `eval()` blocks forever, waiting for an enqueue operation that never runs.
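The hang can be reproduced without TensorFlow. The sketch below is only an analogy built on Python's standard `queue` and `threading` modules, not TensorFlow's API: `queue.Queue` stands in for the FIFO queue behind `tf.train.batch`, and a plain producer thread stands in for its `QueueRunner`. The helper `dequeue_with_timeout` is a hypothetical name introduced here; its timeout turns the blocked dequeue into a `None` result instead of an endless wait.

```python
import queue
import threading

# Analogy only: q stands in for the queue built by tf.train.batch,
# and the producer thread stands in for its QueueRunner.
q = queue.Queue(maxsize=4)

def dequeue_with_timeout(timeout):
    """Return the next item, or None if nothing arrives within `timeout` seconds."""
    try:
        return q.get(timeout=timeout)
    except queue.Empty:
        return None

# 1) No producer is running, so get() can only wait (like X_batch.eval()).
blocked_result = dequeue_with_timeout(0.2)

# 2) Start a producer thread, the analogue of tf.train.start_queue_runners().
def producer(n_items):
    for i in range(n_items):
        q.put(i)  # blocks while the queue is full, like an enqueue op

t = threading.Thread(target=producer, args=(3,), daemon=True)
t.start()
items = [dequeue_with_timeout(1.0) for _ in range(3)]
t.join()

print(blocked_result)  # None
print(items)           # [0, 1, 2]
```

Without the timeout, the first `get()` would wait forever, which is exactly what `X_batch.eval()` does in the question.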
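```python
import tensorflow as tf  # TensorFlow 1.x API

# X_train, y_train, batch_size and init are defined as in the question.
# Build the batching queue before opening the session.
X_batch, y_batch = tf.train.batch([X_train, y_train], batch_size=batch_size)
with tf.Session() as sess:
    sess.run(init)
    # Start the QueueRunner threads that tf.train.batch registered in the graph.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        # Note: batching fixed in-memory tensors never exhausts the input,
        # so this loop runs until you break out or request a stop.
        while not coord.should_stop():
            X, y = sess.run([X_batch, y_batch])
    except Exception as e:
        # e.g. tf.errors.OutOfRangeError once a finite input is exhausted
        coord.request_stop(e)
    finally:
        coord.request_stop()
        # Wait for the queue-runner threads to finish.
        coord.join(threads)
```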
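The `Coordinator` loop above follows a general stop-and-join pattern. As a rough, TensorFlow-free analogy (a sketch under assumptions, not how `tf.train.Coordinator` is implemented), a `threading.Event` plays the coordinator, and a hypothetical sentinel `_END` plays the role of `tf.errors.OutOfRangeError` when a finite input runs out. `_runner` and `consume_all` are also hypothetical names introduced only for this illustration.

```python
import queue
import threading

_END = object()  # hypothetical sentinel: stand-in for OutOfRangeError

def _runner(data, q, stop):
    """Producer thread, standing in for a QueueRunner."""
    for item in [*data, _END]:
        # Retry the put so the thread notices a stop request while the queue is full.
        while True:
            if stop.is_set():
                return
            try:
                q.put(item, timeout=0.05)
                break
            except queue.Full:
                pass

def consume_all(data, maxsize=2):
    """Main thread: dequeue until the input is exhausted, then stop and join."""
    q = queue.Queue(maxsize=maxsize)
    stop = threading.Event()  # stands in for tf.train.Coordinator
    threads = [threading.Thread(target=_runner, args=(data, q, stop))]
    for t in threads:
        t.start()  # like tf.train.start_queue_runners(coord=coord)
    received = []
    try:
        while not stop.is_set():  # like coord.should_stop()
            item = q.get()
            if item is _END:  # like catching tf.errors.OutOfRangeError
                break
            received.append(item)
    finally:
        stop.set()  # like coord.request_stop()
        for t in threads:
            t.join()  # like coord.join(threads)
    return received

print(consume_all(range(5)))  # [0, 1, 2, 3, 4]
```

As in the TensorFlow version, the `finally` block requests the stop and joins the producer threads whether the loop ends by exhaustion or by an unexpected exception.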