Reputation: 13116
I am running into ValueError: Tensor("conv2d_1/kernel:0", ...) must be from the same graph as Tensor("IteratorGetNext:0", ...). I am trying to reuse a Keras model with the Estimator class.
I tried enclosing everything possible inside
g = tf.Graph()
with g.as_default():
import tensorflow as tf

g = tf.Graph()
with g.as_default():
    MODEL = get_keras_model(...)

    def model_fn(mode, features, labels, params):
        logits = MODEL(features)
        ...

    def parser(record):
        ...

    def get_dataset_inp_fn(filenames, epochs=20):
        def dataset_input_fn():
            dataset = tf.contrib.data.TFRecordDataset(filenames)
            dataset = dataset.map(parser)
            ...

    with tf.Session(graph=g) as sess:
        est = tf.estimator.Estimator(
            model_fn,
            model_dir=None,
            config=None,
            params={"optimizer": "AdamOptimizer",
                    "opt_params": {}}
        )
        est.train(get_dataset_inp_fn(["mydata.tfrecords"], epochs=20))
but that did not help.
Is there a way to list all graphs that have been created up to the current point?
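A minimal sketch of one possible approach, under assumptions: Python's gc module can enumerate the objects it tracks, so filtering gc.get_objects() by type lists every graph instance that is still alive (graphs that were already garbage-collected will not show up). The Graph class below is a hypothetical stand-in so the example runs without TensorFlow; with TensorFlow installed you would pass tf.Graph instead.

```python
import gc


class Graph:
    """Hypothetical stand-in for tf.Graph so the sketch runs without TensorFlow."""

    def __init__(self, name):
        self.name = name


def list_live_instances(cls):
    """Return every GC-tracked object whose actual type is cls or a subclass of it."""
    # issubclass(type(obj), ...) avoids touching obj.__class__, which some
    # proxy objects override or fail on.
    return [obj for obj in gc.get_objects() if issubclass(type(obj), cls)]


g = Graph("g")
other = Graph("created elsewhere")

for graph in list_live_instances(Graph):
    # id() is the same address shown in a default repr such as <... at 0x7f...>
    print(graph.name, hex(id(graph)))
```

Comparing the printed addresses with hex(id(g)) shows which graph is the unexpected one.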
Upvotes: 1
Views: 702
Reputation: 13116
The function that checks the graphs and raises the error (I wish it reported the graph addresses as well) calls the following function to do the check:
from tensorflow.python.framework.ops import _get_graph_from_inputs
_get_graph_from_inputs([x])
In this case, the graph that Keras created is identical to graph g, but the one created by get_dataset_inp_fn is different from g.
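Since the stock message does not include graph addresses, a small helper can group the inputs by the graph they belong to and print the addresses itself. This is a sketch with hypothetical stand-in Graph and Tensor classes (and hypothetical helper names describe_graphs and assert_same_graph); real TensorFlow tensors expose a .graph attribute, so the same grouping should apply to them directly.

```python
from collections import defaultdict


class Graph:
    """Hypothetical stand-in for tf.Graph."""


class Tensor:
    """Hypothetical stand-in for a tensor: a name plus the graph it lives in."""

    def __init__(self, name, graph):
        self.name = name
        self.graph = graph


def describe_graphs(tensors):
    """Map each distinct graph's address (hex id) to the names of its tensors."""
    groups = defaultdict(list)
    for t in tensors:
        groups[hex(id(t.graph))].append(t.name)
    return dict(groups)


def assert_same_graph(tensors):
    """Raise a ValueError that lists which graph each tensor belongs to."""
    groups = describe_graphs(tensors)
    if len(groups) > 1:
        details = "; ".join(f"{addr}: {names}" for addr, names in groups.items())
        raise ValueError(f"tensors come from {len(groups)} different graphs: {details}")


g, input_graph = Graph(), Graph()
kernel = Tensor("conv2d_1/kernel:0", g)
batch = Tensor("IteratorGetNext:0", input_graph)

try:
    assert_same_graph([kernel, batch])
except ValueError as err:
    print(err)
```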
Upvotes: 0
Reputation: 57893
Here's a general debugging technique: put import pdb; pdb.set_trace() into the tf.Graph constructor, then use bt to figure out who is creating the Graph. My first guess would be that Keras does not use the default graph and creates its own. You can run inspect.getsourcefile(tf.Graph) to find where the file defining Graph is located locally.
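A variant of the same idea that avoids editing the installed source: temporarily wrap the constructor so every new instance records the stack that created it. This swaps pdb's interactive bt for traceback.extract_stack(); the trace_constructor helper, the Graph class and build_model are hypothetical stand-ins. Calling trace_constructor(tf.Graph, log) before building the Keras model and the Estimator should reveal which call creates the extra graph, but that is an untested assumption.

```python
import functools
import traceback


def trace_constructor(cls, log):
    """Wrap cls.__init__ so each new instance appends its creation stack to log.

    Returns the original __init__ so the patch can be undone afterwards.
    """
    original_init = cls.__init__

    @functools.wraps(original_init)
    def traced_init(self, *args, **kwargs):
        # Drop the last frame (this wrapper) so the stack ends at the caller.
        log.append(traceback.extract_stack()[:-1])
        original_init(self, *args, **kwargs)

    cls.__init__ = traced_init
    return original_init


class Graph:
    """Hypothetical stand-in for tf.Graph."""

    def __init__(self):
        self.ops = []


def build_model():
    # Hypothetical library code that quietly creates its own graph.
    return Graph()


creation_log = []
original = trace_constructor(Graph, creation_log)
build_model()
Graph.__init__ = original  # undo the patch

for stack in creation_log:
    # Show only the innermost caller frames, like the top of a bt listing.
    print("".join(traceback.format_list(stack[-2:])))
```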
Upvotes: 1