smith

Reputation: 186

Tensorflow Session.Run() Tensor object is not callable

I have an RNN model trained on the PTB example with TensorFlow's ptb_word_lm. Below is the code I use to print a few samples to test the trained model. When I run it, I get TypeError: 'Tensor' object is not callable on the line probs, state = sess.run([mtest.output_probs(), mtest._final_state], feed_dict=feed_dict)

What exactly causes this error?

Here is the code:

import numpy as np
import os
import tensorflow as tf
from ptb_word_lm import *
from tensorflow.models.rnn.ptb import reader
from tensorflow.python.platform import gfile

data_path = "/home/usr/simple-examples/data/"
raw_data = reader.ptb_raw_data(data_path)
train_data, valid_data, test_data, vocabulary = raw_data

test_path = os.path.join(data_path, "ptb.test.txt")
word_to_id = reader._build_vocab(test_path)


eval_config = get_config()
eval_config.batch_size = 1
eval_config.num_steps = 1

sess = tf.Session()

initializer = tf.random_uniform_initializer(-eval_config.init_scale,
                                            eval_config.init_scale)
test_input = PTBInput(config=eval_config, data=test_data, name="TestInput")
with tf.variable_scope("model", reuse=None, initializer=initializer):
    mtest = PTBModel(is_training=False, config=eval_config, input_=test_input)

sess.run(tf.initialize_all_variables())

saver = tf.train.import_meta_graph('/home/usr/models/medium/model.ckpt-50979.meta')

ckpt = tf.train.get_checkpoint_state('/home/usr/models/medium/')
if ckpt and gfile.Exists(ckpt.model_checkpoint_path):
    msg = 'Reading model parameters from %s' % ckpt.model_checkpoint_path
    print(msg)
    saver.restore(sess, ckpt.model_checkpoint_path)

def pick_from_weight(weight, pows=1.0):
    weight = weight**pows
    t = np.cumsum(weight)
    s = np.sum(weight)
    return int(np.searchsorted(t, np.random.rand(1) * s))

while True:
    number_of_sentences = 10
    sentence_cnt = 0
    text = '\n'
    end_of_sentence_char = word_to_id['<eos>']
    input_char = np.array([[end_of_sentence_char]])
    state = sess.run(mtest.initial_state)
    for attr in  mtest.__dict__:
        print attr
    print 'all attributes above'
    while sentence_cnt < number_of_sentences:
        feed_dict = {mtest._input: input_char,
                     mtest.initial_state: state}

        probs, state = sess.run([mtest.output_probs(), mtest._final_state], feed_dict=feed_dict)

        print 'after state'
        sampled_char = pick_from_weight(probs[0])
        print sampled_char
        if sampled_char == end_of_sentence_char:
            text += '.\n'
            sentence_cnt += 1
        else:
            text += ' ' + id_to_word[sampled_char]
        input_char = np.array([[sampled_char]])
    print(text)
    raw_input('press any key to continue ...')

Upvotes: 4

Views: 5548

Answers (1)

sunside

Reputation: 8259

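Looking at the referenced code on GitHub, I cannot find output_probs, so maybe the versions differ. However, since mtest.initial_state is a @property, I assume that mtest.output_probs is one as well. If so, accessing mtest.output_probs already returns the Tensor, and the trailing () then tries to call that Tensor, which raises the TypeError. That is, try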

probs, state = sess.run([mtest.output_probs, mtest.final_state], feed_dict=feed_dict)

instead, i.e. without using the parentheses.

Also, mtest._final_state is an internal attribute and shouldn't be used directly. You probably want to use mtest.final_state instead.
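
The mechanism can be reproduced without TensorFlow. What follows is only a minimal sketch, assuming output_probs really is defined as a @property as guessed above; FakeTensor and Model are hypothetical stand-ins invented for illustration, not TensorFlow or PTBModel classes.

```python
class FakeTensor(object):
    """Hypothetical stand-in for a tensor: a plain object without __call__."""

    def __init__(self, name):
        self.name = name


class Model(object):
    """Hypothetical stand-in for the model class, exposing a property."""

    def __init__(self):
        self._output_probs = FakeTensor("probs")

    @property
    def output_probs(self):
        # Attribute access alone runs this getter and returns the object.
        return self._output_probs


m = Model()

# Correct: no parentheses, the property hands back the tensor-like object.
tensor = m.output_probs

# Incorrect: the property is evaluated first, then the result is called.
try:
    m.output_probs()
    error_message = None
except TypeError as exc:
    error_message = str(exc)

print(error_message)  # 'FakeTensor' object is not callable
```

If output_probs turned out to be a regular method rather than a property, the parentheses would be required, so it is worth checking the definition in the version of ptb_word_lm that is actually being imported.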

Upvotes: 2
