Nitin

Reputation: 7667

Keras Metric for bucketizing a regression output

How do I define a custom Keras metric that computes accuracy after bucketizing a regression output, like so:

y_true = [12.5, 45.5]
y_predicted = [14.5, 29]
splits = [-float("inf"), 10, 20, 30, float("inf")]

"""
Splits to Classes translation =>
Class 0: -inf to 9
Class 1: 10 to 19
Class 2: 20 to 29
Class 3: 30 to inf
"""
# using the above translation, 
y_true_classes = [1, 3]
y_predicted_classes = [1, 2]

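# K.equal alone gives element-wise booleans; average them to get the accuracy
accuracy = K.mean(K.cast(K.equal(y_true_classes, y_predicted_classes), K.floatx()))  # => 0.5 here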

return accuracy

Upvotes: 0

Views: 103

Answers (1)

Gerges

Reputation: 6519

Here is one idea for how you might go about implementing this (although probably not the best one).

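import tensorflow as tf
from keras import backend as K  # or: from tensorflow.keras import backend as K


def convert_to_classes(vals, splits):
    # Class index = number of splits a value has reached
    out = tf.zeros_like(vals, dtype=tf.int32)

    for split in splits:
        # >= puts a value equal to a split in the upper class (10 -> class 1),
        # matching the translation table in the question
        out = tf.where(vals >= split, out + 1, out)

    return out


def my_acc(splits):
    # splits: inner boundaries only, e.g. (10, 20, 30)
    def custom_acc(y_true, y_pred):
        y_true = convert_to_classes(y_true, splits)
        y_pred = convert_to_classes(y_pred, splits)

        # K.equal yields booleans; cast to float before averaging
        return K.mean(K.cast(K.equal(y_true, y_pred), K.floatx()))
    return custom_acc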

The function convert_to_classes converts the floats to bucket indices, assuming the outer bounds are always -inf and +inf.

The closure my_acc lets you fix the splits (without the +-inf bounds) when the model is compiled, so they are added statically to the graph, and then returns a metric function with the (y_true, y_pred) signature that Keras expects.
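As a sanity check on the bucket mapping, the same logic can be sketched in plain Python with no TensorFlow involved. This is only a reference sketch: the names to_class and bucket_accuracy are made up for illustration, and it assumes a value equal to a split belongs to the upper bucket (so 10 falls in class 1, as in the question's table).

```python
from bisect import bisect_right


def to_class(value, splits):
    # Count the splits that are <= value; that count is the class index.
    # splits must be sorted and must not include the +-inf outer bounds.
    return bisect_right(splits, value)


def bucket_accuracy(y_true, y_pred, splits):
    # Fraction of pairs whose true and predicted values share a bucket.
    splits = sorted(splits)
    hits = [to_class(t, splits) == to_class(p, splits)
            for t, p in zip(y_true, y_pred)]
    return sum(hits) / len(hits)


splits = [10, 20, 30]
print([to_class(v, splits) for v in [12.5, 45.5]])        # [1, 3]
print([to_class(v, splits) for v in [14.5, 29]])          # [1, 2]
print(bucket_accuracy([12.5, 45.5], [14.5, 29], splits))  # 0.5
```

Comparing a few hand-picked values against this reference, especially values that sit exactly on a split, is a cheap way to confirm the graph version buckets them the way you intend.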

Testing with TensorFlow (1.x, using a session):

y_true = tf.constant([12.5, 45.5])
y_pred = tf.constant([14.5, 29])

with tf.Session() as sess:
    print(sess.run(my_acc((10, 20, 30))(y_true, y_pred)))

gives the expected 0.5 accuracy.

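And a quick test with Keras:

import numpy as np
from keras.models import Sequential
from keras.layers import Dense

# Random regression data, scaled so the targets spread across all buckets
x = np.random.randn(100, 10)*100
y = np.random.randn(100)*100

model = Sequential([Dense(20, activation='relu'),
                    Dense(1, activation=None)])
model.compile(optimizer='Adam',
              loss='mse',
              metrics=[my_acc(splits=(10, 20, 30))])

model.fit(x, y, batch_size=32, epochs=10)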

This gives the following output, where the metric is reported under the name of the inner function in the closure, custom_acc:

Epoch 1/10
100/100 [==============================] - 0s 2ms/step - loss: 10242.2591 - custom_acc: 0.4300
Epoch 2/10
100/100 [==============================] - 0s 53us/step - loss: 10101.9658 - custom_acc: 0.4200
Epoch 3/10
100/100 [==============================] - 0s 53us/step - loss: 10011.4662 - custom_acc: 0.4300
Epoch 4/10
100/100 [==============================] - 0s 51us/step - loss: 9899.7181 - custom_acc: 0.4300
Epoch 5/10
100/100 [==============================] - 0s 50us/step - loss: 9815.1607 - custom_acc: 0.4200
Epoch 6/10
100/100 [==============================] - 0s 74us/step - loss: 9736.5554 - custom_acc: 0.4300
Epoch 7/10
100/100 [==============================] - 0s 50us/step - loss: 9667.0845 - custom_acc: 0.4400
Epoch 8/10
100/100 [==============================] - 0s 58us/step - loss: 9589.5439 - custom_acc: 0.4400
Epoch 9/10
100/100 [==============================] - 0s 61us/step - loss: 9511.8003 - custom_acc: 0.4400
Epoch 10/10
100/100 [==============================] - 0s 51us/step - loss: 9443.9730 - custom_acc: 0.4400
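If the label custom_acc is too generic, one option is to rename the inner function before returning it. This sketch assumes Keras takes the label for a function metric from the function's __name__ attribute, which is consistent with the log above but worth checking against your Keras version. The helpers with_name and splits_label are hypothetical, and the metric body is a placeholder because only the name matters here.

```python
def with_name(metric_fn, name):
    # Relabel a metric function; Keras is assumed to read __name__ for the log label.
    metric_fn.__name__ = name
    return metric_fn


def splits_label(splits):
    # Build a readable label from the splits, e.g. (10, 20, 30) -> "bucket_acc_10_20_30"
    return "bucket_acc_" + "_".join(str(s) for s in splits)


def custom_acc(y_true, y_pred):
    # Placeholder standing in for the real metric body
    return 0.0


metric = with_name(custom_acc, splits_label((10, 20, 30)))
print(metric.__name__)  # bucket_acc_10_20_30
```

Inside my_acc, the same effect would come from setting custom_acc.__name__ just before the return statement, which also keeps two metrics with different splits from sharing one label.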

Upvotes: 1
