A. Piro

Reputation: 785

Multi Label Classification : How to learn threshold values?

I have a deep CNN that works just fine for multi class classification. I would like to "upgrade" the challenge and train it on a multi label classification problem.

To do so, I replaced my softmax with a sigmoid and tried to train my network to minimize:

tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=y_, logits=y_pred))

But I end up with weird predictions:

Prediction for Im1 : [ 0.59275776  0.08751075  0.37567005  0.1636796   0.42361438  0.08701646  0.38991812  0.54468459  0.34593087  0.82790571]
Prediction for Im1 : [ 0.52609032  0.07885984  0.45780018  0.04995904  0.32828355  0.07349177  0.35400775  0.36479294  0.30002621  0.84438241]
Prediction for Im1 : [ 0.58714485  0.03258472  0.3349618   0.03199361  0.54665488  0.02271551  0.43719986  0.54638696  0.20344526  0.88144571]

So I thought I would try to make my network learn a threshold for each class, to determine whether or not the sample belongs to the class.

So I added this to my code :

initial = tf.truncated_normal([numberOfClasses], stddev=0.1)
W_thresh = tf.Variable(initial)

y_predict_thresh = int(y_predict > W_thresh)

But I have an error :

TypeError: int() argument must be a string or a number, not 'Tensor'.

Does anybody have an idea to help me move forward? (How can I avoid this error? Is the fact that my dataset is really unbalanced causing these kinds of "constant" predictions? Any other suggestions for multi label classification?)

Thank you

EDIT: I just realized that doing thresholding might not be really cool for backpropagation :/
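For what it's worth, one way around the backpropagation problem is to keep the thresholds out of training altogether and choose them afterwards on held-out data. Below is a minimal sketch of that idea in plain Python, not something taken from this thread. It assumes you already have sigmoid outputs and 0/1 labels for a validation set, and that F1 is the metric you care about. The helper names (f1_score, pick_thresholds, apply_thresholds) and the toy numbers are made up for illustration.

```python
# Hypothetical helpers: choose one decision threshold per class on held-out
# data, after training, so the non-differentiable comparison stays out of the loss.

def f1_score(y_true, y_hat):
    """F1 for a single class, given 0/1 labels and 0/1 predictions."""
    tp = sum(1 for t, p in zip(y_true, y_hat) if t == 1 and p == 1)
    fp = sum(1 for t, p in zip(y_true, y_hat) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_hat) if t == 1 and p == 0)
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0


def pick_thresholds(probs, labels, candidates):
    """Return, for each class, the candidate threshold with the best F1.

    probs  : rows of sigmoid outputs, shape [n_samples][n_classes]
    labels : rows of 0/1 ground truth, same shape
    """
    n_classes = len(probs[0])
    thresholds = []
    for c in range(n_classes):
        col_probs = [row[c] for row in probs]
        col_true = [row[c] for row in labels]
        # max() returns the first best candidate, so ties go to the lowest
        # threshold when candidates are listed in ascending order.
        best = max(
            candidates,
            key=lambda t: f1_score(col_true, [int(p > t) for p in col_probs]),
        )
        thresholds.append(best)
    return thresholds


def apply_thresholds(probs, thresholds):
    """Turn sigmoid outputs into 0/1 predictions, one threshold per class."""
    return [[int(p > t) for p, t in zip(row, thresholds)] for row in probs]


# Toy validation data (made up for illustration): 4 samples, 2 classes.
val_probs = [[0.59, 0.08], [0.52, 0.07], [0.31, 0.03], [0.21, 0.09]]
val_labels = [[1, 0], [1, 1], [0, 0], [0, 1]]
candidates = [i / 20 for i in range(1, 20)]  # 0.05, 0.10, ..., 0.95

thresholds = pick_thresholds(val_probs, val_labels, candidates)
predictions = apply_thresholds(val_probs, thresholds)
```

With an unbalanced dataset, a rare class can end up with a threshold well below 0.5. That fits predictions like the ones above, where some outputs never get anywhere near 0.5.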

Upvotes: 1

Views: 1754

Answers (1)

gionni

Reputation: 1303

Don't know if you still need it, but you can use the TensorFlow conversion functions tf.to_int32 or tf.to_int64. Before evaluation, your expression is just a Tensor object as far as Python is concerned, so the built-in int() cannot cast it.

This does what you need:

with tf.Session() as sess:
    check = sess.run([tf.to_int64(W1 > W2)])

Upvotes: 3
