Reputation: 2556
I tried computing AUC (area under ROC) grouped by the field id. Given the following data:
# Within each key-value pair:
#   key is "id"
#   value is a list of (score, label)
data = sc.parallelize([
    ('id1', [(0.5, 1.0), (0.6, 0.0), (0.7, 1.0), (0.8, 0.0)]),
    ('id2', [(0.5, 1.0), (0.6, 0.0), (0.7, 1.0), (0.8, 0.0)])
])
The BinaryClassificationMetrics class can calculate the AUC given an RDD of (score, label) pairs.
I want to compute the AUC per key (i.e., id1, id2). But how do I "map" a class over an RDD by key?
I tried wrapping BinaryClassificationMetrics in a function:
from pyspark.mllib.evaluation import BinaryClassificationMetrics

def auc(scoreAndLabels):
    return BinaryClassificationMetrics(scoreAndLabels).areaUnderROC
Then I mapped the wrapper function over each value:
data.groupByKey()\
    .mapValues(auc)
But inside mapValues() the list of (score, label) pairs is actually of type ResultIterable, while BinaryClassificationMetrics expects an RDD.
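A ResultIterable is just a local Python iterable that lives inside a task, where no SparkContext is available to build an RDD. So one workaround that needs no third-party modules is to compute the ROC AUC in plain Python inside mapValues(). The sketch below (the name pairwise_roc_auc is hypothetical) uses the probabilistic definition of ROC AUC: the fraction of (positive, negative) pairs in which the positive has the higher score, with ties counted as half. It assumes labels are exactly 1.0 or 0.0.

```python
from itertools import product
from typing import Iterable, Tuple


def pairwise_roc_auc(score_and_labels: Iterable[Tuple[float, float]]) -> float:
    """ROC AUC as P(random positive outscores random negative).

    Accepts any iterable of (score, label) pairs, e.g. a ResultIterable
    from groupByKey(). Ties between a positive and a negative count as 0.5.
    Returns NaN when the group contains only one class.
    """
    pairs = list(score_and_labels)  # materialize the iterable once
    pos = [s for s, label in pairs if label == 1.0]
    neg = [s for s, label in pairs if label == 0.0]
    if not pos or not neg:
        return float("nan")
    wins = 0.0
    for p, n in product(pos, neg):
        if p > n:
            wins += 1.0
        elif p == n:
            wins += 0.5
    return wins / (len(pos) * len(neg))


sample = [(0.5, 1.0), (0.6, 0.0), (0.7, 1.0), (0.8, 0.0)]
print(pairwise_roc_auc(sample))  # 0.25
```

With the sample data, whose values are already lists, this could be applied as data.mapValues(pairwise_roc_auc); for flat (id, (score, label)) rows, groupByKey().mapValues(pairwise_roc_auc) would pass it the ResultIterable directly. It compares every positive with every negative, so it costs O(P·N) per key and only suits small groups.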
Is there any way to convert the ResultIterable to an RDD so that the auc function can be applied? Or is there another workaround for computing AUC per group (without importing third-party modules such as scikit-learn)?
Upvotes: 4
Views: 6437
Reputation: 2995
Here's a way to get the AUC per key without using sklearn:
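from pyspark.mllib.evaluation import BinaryClassificationMetrics

keys = data.map(lambda x: x[0]).distinct().collect()
rslt = {}
for k in keys:
    # Keep this key's records and flatten them into (score, label) pairs;
    # k=k binds the current key into the lambda
    scoreAndLabels = data.filter(lambda x, k=k: x[0] == k).flatMap(lambda x: x[1])
    rslt[k] = BinaryClassificationMetrics(scoreAndLabels).areaUnderROC
print(rslt)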
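Python closures look up free variables such as a loop variable when they are called, not when they are created. In the loop above every AUC is computed within the same iteration, so each filter sees the right k; but if the filtered RDDs were built first and evaluated later, they would all filter on the last key. Binding the current value as a default argument (lambda x, k=k: ...) removes that dependency. A plain-Python illustration, independent of Spark (the names late and bound are just for this example):

```python
keys = ["id1", "id2"]

# Late binding: each lambda looks up k when called, after the loop has finished
late = [lambda key: key == k for k in keys]

# Default-argument binding: each lambda stores the k of its own iteration
bound = [lambda key, k=k: key == k for k in keys]

print([f("id1") for f in late])   # [False, False]
print([f("id1") for f in bound])  # [True, False]
```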
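Note: this solution requires that the number of keys be small enough to collect() into driver memory. If you have so many keys that you can't collect() them, don't use this approach. It also runs separate Spark jobs for every key, each of which re-scans data, so it slows down as the number of keys grows.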
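When there are many keys, an alternative that avoids collecting them is to compute the AUC of every group in a single pass with mapValues() and a plain-Python function. The sketch below (rank_roc_auc is a hypothetical name, not part of any library) uses the Mann-Whitney U statistic, which is mathematically equivalent to the trapezoidal area under the ROC curve with tied scores counted as half. It assumes labels are exactly 1.0 or 0.0 and that each key's pairs fit in the memory of one task.

```python
from typing import Iterable, Tuple


def rank_roc_auc(score_and_labels: Iterable[Tuple[float, float]]) -> float:
    """ROC AUC from the Mann-Whitney U statistic in O(n log n).

    Assumes labels are exactly 1.0 (positive) or 0.0 (negative). Tied scores
    share their average rank, so a positive/negative tie counts as half a win.
    Returns NaN when a group has only one class.
    """
    pairs = sorted(score_and_labels, key=lambda p: p[0])  # ascending by score
    n = len(pairs)
    n_pos = sum(1 for _, label in pairs if label == 1.0)
    n_neg = n - n_pos
    if n_pos == 0 or n_neg == 0:
        return float("nan")

    rank_sum_pos = 0.0
    i = 0
    while i < n:
        # Find the run pairs[i:j] of equal scores
        j = i
        while j < n and pairs[j][0] == pairs[i][0]:
            j += 1
        avg_rank = (i + 1 + j) / 2.0  # average of the 1-based ranks i+1 .. j
        pos_in_run = sum(1 for _, label in pairs[i:j] if label == 1.0)
        rank_sum_pos += avg_rank * pos_in_run
        i = j

    # U = number of (positive, negative) pairs won by the positive, ties as 0.5
    u = rank_sum_pos - n_pos * (n_pos + 1) / 2.0
    return u / (n_pos * n_neg)
```

In PySpark this could be used as data.mapValues(rank_roc_auc).collectAsMap(), which returns a dict of id to AUC without a separate job per key. Because sorted() accepts any iterable, the same function also works on the ResultIterable produced by groupByKey().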
Upvotes: 0
Reputation: 379
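Instead of using BinaryClassificationMetrics, you can use scikit-learn's roc_auc_score and map it over each RDD element's value to get the AUC per key. Note that sklearn.metrics.auc(x, y) is a generic trapezoidal integral over the points it is given, so calling it directly on scores and labels does not give the ROC AUC.
from sklearn.metrics import roc_auc_score
data = sc.parallelize([
    ('id1', [(0.5, 1.0), (0.6, 0.0), (0.7, 1.0), (0.8, 0.0)]),
    ('id2', [(0.5, 1.0), (0.6, 0.0), (0.7, 1.0), (0.8, 0.0)])])
def key_auc(pairs):
    # Unzip the (score, label) pairs; roc_auc_score expects (y_true, y_score)
    scores, labels = zip(*pairs)
    return float(roc_auc_score(labels, scores))
result_aucs = data.map(lambda x: (x[0] + '_auc', key_auc(x[1])))
result_aucs.collect()
Out [1]: [('id1_auc', 0.25), ('id2_auc', 0.25)]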
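sklearn.metrics.auc(x, y) only applies the trapezoidal rule to the points (x, y); it yields the ROC AUC when x and y are the false- and true-positive rates of a ROC curve. Unzipping the (score, label) pairs straight into it, e.g. auc(*zip(*pairs)), integrates the labels over the scores instead. The plain-Python sketch below (helper names trapezoid and roc_points are hypothetical, written only for illustration) reproduces both numbers for the sample data: roughly 0.15 for the direct trapezoid and 0.25 for the area under the actual ROC curve.

```python
def trapezoid(xs, ys):
    """Trapezoidal area under the polyline through (xs[i], ys[i])."""
    return sum((x1 - x0) * (y0 + y1) / 2.0
               for x0, x1, y0, y1 in zip(xs, xs[1:], ys, ys[1:]))


def roc_points(score_and_labels):
    """(FPR, TPR) points from lowering the threshold through each distinct score."""
    pairs = sorted(score_and_labels, key=lambda p: p[0], reverse=True)
    n_pos = sum(1 for _, label in pairs if label == 1.0)
    n_neg = len(pairs) - n_pos
    if n_pos == 0 or n_neg == 0:
        raise ValueError("need both positive and negative labels")
    fpr, tpr = [0.0], [0.0]
    tp = fp = 0
    for i, (score, label) in enumerate(pairs):
        if label == 1.0:
            tp += 1
        else:
            fp += 1
        # Emit a point only after the last pair sharing this score (ties move together)
        if i == len(pairs) - 1 or pairs[i + 1][0] != score:
            fpr.append(fp / n_neg)
            tpr.append(tp / n_pos)
    return fpr, tpr


sample = [(0.5, 1.0), (0.6, 0.0), (0.7, 1.0), (0.8, 0.0)]
scores, labels = zip(*sample)
print(trapezoid(scores, labels))       # about 0.15: labels integrated over scores
print(trapezoid(*roc_points(sample)))  # 0.25: area under the ROC curve
```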
Upvotes: 5