WestCoastProjects

Reputation: 63062

How can I debug into a custom metrics function in Keras/TensorFlow (in PyCharm)?

I have created a simple ConfusionMatrix custom metric and am running into a problem with tensor conversion. Fixing it would be much faster if I could set a breakpoint, but PyCharm is not respecting my breakpoints.

Here is the code:

import tensorflow as tf
from collections import defaultdict

def multiConfusion(expectsIn, actsIn):
    expects = tf.keras.backend.eval(expectsIn)  # Error occurs here
    acts = tf.keras.backend.eval(actsIn)
    classes = sorted(set(expects).union(set(acts)))

    # mx[expected][actual] -> count
    mx = defaultdict(lambda: defaultdict(int))
    for e, a in zip(expects, acts):
        mx[e][a] += 1
    hdr = "Exp/Act" + ''.join([f'\t\t{lab}' for lab in classes])
    # One row per expected class, one column per actual class
    ll = '\n'.join([
        '\t\t\t' + '\t\t'.join([str(mx[e][a]) for a in classes])
        for e in classes
    ])
    mat = f"{hdr}\n{ll}"
    print(mat)
    return mat

def confusionMat(x, y, num_classes=NClasses):
    if num_classes == 2:
        cmat = binaryConfusion(x, y)
    else:
        cmat = multiConfusion(x, y)  # Breakpoint set here but gets skipped
    print(repr(cmat))
    return cmat

model.compile(loss="categorical_crossentropy", optimizer=opt, 
              metrics=[confusionMat,"accuracy"])
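The counting and formatting inside multiConfusion are plain Python, so one option is to move them into a helper that takes ordinary label sequences and can be stepped through in PyCharm with no TensorFlow involved. This is a minimal sketch: confusion_counts and format_confusion are hypothetical names, and the labels are assumed to already be integer class indices (for one-hot targets, for example the result of an argmax).

```python
from collections import defaultdict


def confusion_counts(expected, actual):
    """Count (expected, actual) label pairs; labels are plain hashable values."""
    counts = defaultdict(lambda: defaultdict(int))
    for e, a in zip(expected, actual):
        counts[e][a] += 1
    classes = sorted(set(expected) | set(actual))
    return classes, counts


def format_confusion(classes, counts):
    """Render rows = expected labels, columns = actual labels, tab-separated."""
    header = "Exp/Act\t" + "\t".join(str(c) for c in classes)
    rows = [
        f"{e}\t" + "\t".join(str(counts[e][a]) for a in classes)
        for e in classes
    ]
    return "\n".join([header] + rows)


# Hypothetical labels, e.g. from argmax over one-hot targets and predictions.
classes, counts = confusion_counts([0, 1, 2, 2], [0, 2, 2, 1])
print(format_confusion(classes, counts))
```

Iterating over `classes` for the columns keeps every row aligned with the header, including classes that were never predicted for a given expected label.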


Some things I tried in order to re-enable the debugger:

tf.config.threading.set_inter_op_parallelism_threads(1)

vgg_model.run_eagerly = True

But the breakpoint is still not respected. Any thoughts?

Update: I have tested, updated, and expanded the code above a fair bit, and it does generate a confusion matrix properly.


Debugging into the code works fine when it is called directly with a constant TensorFlow tensor.

But nothing I do makes the debugger activate when the code is invoked through the metrics machinery during model training (via model.fit()).
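The difference between the direct call and the model.fit() call can be seen with a small probe. Called directly, a function receives eager tensors; wrapped in tf.function, which is roughly what tf.keras does to its training step unless eager execution is requested, its Python body runs only while the graph is being traced. A sketch for illustration; probe_metric is a hypothetical name.

```python
import tensorflow as tf

observed = []  # one entry per time the Python body actually executes


def probe_metric(y_true, y_pred):
    # True when called eagerly, False while tf.function is tracing a graph.
    observed.append(tf.executing_eagerly())
    return tf.reduce_mean(tf.cast(tf.equal(y_true, y_pred), tf.float32))


y_true = tf.constant([0, 1, 2, 2])
y_pred = tf.constant([0, 2, 2, 1])

probe_metric(y_true, y_pred)  # direct call, like testing with constant tensors

graph_metric = tf.function(probe_metric)  # graph-compiled, as in model.fit()
graph_metric(y_true, y_pred)  # first call traces: Python body runs once
graph_metric(y_true, y_pred)  # same signature: graph reused, body not re-run

print(observed)  # expected: [True, False]
```

Because later calls execute the traced graph rather than the Python source, there is no per-batch Python line for a breakpoint to stop on.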

Upvotes: 1

Views: 726

Answers (1)

Laplace Ricky

Reputation: 1687

For your breakpoints to be respected, add this line at the beginning:

tf.config.run_functions_eagerly(True)

This makes TensorFlow run almost everything eagerly, except for tf.data transformation pipelines.
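A minimal end-to-end sketch of this suggestion, assuming TF 2.x with the bundled tf.keras; behavior may differ between Keras versions. The toy model, random data, and debuggable_metric are hypothetical. With eager execution forced, .numpy() works inside the metric, and a breakpoint placed there should be hit on each batch.

```python
import numpy as np
import tensorflow as tf

# Global switch: tf.function-decorated code (including the Keras train step)
# runs as plain Python instead of a traced graph. Set it before model.fit().
tf.config.run_functions_eagerly(True)

seen_eager = []  # records tf.executing_eagerly() for each metric call


def debuggable_metric(y_true, y_pred):
    seen_eager.append(tf.executing_eagerly())
    # Eager tensors support .numpy(), so a breakpoint on this line can
    # inspect concrete values for the current batch.
    true_labels = tf.argmax(y_true, axis=-1).numpy()
    pred_labels = tf.argmax(y_pred, axis=-1).numpy()
    # Keras averages metric values, so return a numeric tensor, not a string.
    return tf.constant(np.mean(true_labels == pred_labels), dtype=tf.float32)


# Hypothetical toy data: 32 samples, 4 features, 3 one-hot classes.
num_classes = 3
rng = np.random.default_rng(0)
x = rng.random((32, 4)).astype("float32")
y = tf.keras.utils.to_categorical(rng.integers(0, num_classes, size=32), num_classes)

model = tf.keras.Sequential([tf.keras.layers.Dense(num_classes, activation="softmax")])
model.compile(loss="categorical_crossentropy", optimizer="adam",
              metrics=[debuggable_metric, "accuracy"])
history = model.fit(x, y, epochs=1, batch_size=8, verbose=0)
print(history.history["debuggable_metric"])
```

Since the switch is global and slows training, it is usually turned back off with tf.config.run_functions_eagerly(False) once debugging is done.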

Upvotes: 2
