Reputation: 4511
I have a list of label names, which I enumerated to create a dictionary:
my_list = [b'airplane',
b'automobile',
b'bird',
b'cat',
b'deer',
b'dog',
b'frog',
b'horse',
b'ship',
b'truck']
label_dict = dict(enumerate(my_list))
{0: b'airplane',
1: b'automobile',
2: b'bird',
3: b'cat',
4: b'deer',
5: b'dog',
6: b'frog',
7: b'horse',
8: b'ship',
9: b'truck'}
Now I'm trying to cleanly map/apply the dict values to my target, which is in one-hot-encoded form.
y_test[0]
array([ 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.])
y_test[0].map(label_dict) should return:
'cat'
I was playing around with
(lambda key,value: value for y_test[0] == 1)
but couldn't come up with anything concrete.
Thank you.
Upvotes: 2
Views: 1933
Reputation: 294258
We can use a dot product to reverse one-hot encoding, if it really is ONE-hot.
Let's start with factorizing your list
f, u = pd.factorize(my_list)
Now, if you have an array from which you'd like to get back your strings:
a = np.array([0, 0, 0, 1, 0, 0, 0, 0, 0, 0])
Then use dot
a.dot(u)
'cat'
Now assume
y_test = np.array([
[0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
[0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 1, 0]
])
Then
y_test.dot(u)
array(['cat', 'automobile', 'ship'], dtype=object)
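As a rough sketch of why the dot product works here: with an object-dtype label array, each label is multiplied by its 0/1 flag (Python string repetition) and the pieces are concatenated. The snippet assumes plain str labels and casts the flags to int first, because a float one-hot array like the one in the question would otherwise fail on float * str. The names labels, decoded and spelled_out are illustrative only.

```python
import numpy as np

# Assumption: labels are plain str, stored as an object array so that
# Python's int * str and str + str are used inside the dot product.
labels = np.array(['airplane', 'automobile', 'bird', 'cat', 'deer',
                   'dog', 'frog', 'horse', 'ship', 'truck'], dtype=object)

# Float one-hot rows, as in the question
y_test = np.array([
    [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
    [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
    [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
])

# Cast to int before the dot product: 1.0 * 'cat' is a TypeError
decoded = y_test.astype(int).dot(labels)

# The same computation spelled out in plain Python
spelled_out = [''.join(int(flag) * label for flag, label in zip(row, labels))
               for row in y_test]
```

Both decoded and spelled_out should hold the same three labels, one per row.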
If it isn't one-hot but instead multi-hot, you could join with commas
y_test = np.array([
[0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
[0, 1, 0, 0, 0, 0, 0, 0, 0, 1],
[0, 0, 1, 0, 0, 0, 0, 0, 1, 0]
])
[', '.join(u[y.astype(bool)]) for y in y_test]
['cat', 'automobile, truck', 'bird, ship']
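The multi-hot case can be wrapped in a small helper. This is only a sketch: decode_multi_hot is a hypothetical name, the labels are assumed to be plain str, and any positive entry is treated as "on" so that float arrays work too.

```python
import numpy as np

labels = np.array(['airplane', 'automobile', 'bird', 'cat', 'deer',
                   'dog', 'frog', 'horse', 'ship', 'truck'])

def decode_multi_hot(rows, labels, sep=', '):
    """Join the labels of all active positions in each row (hypothetical helper)."""
    rows = np.asarray(rows)
    # row > 0 builds a boolean mask that selects the active labels
    return [sep.join(labels[row > 0]) for row in rows]

y_multi = np.array([
    [0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
    [0, 1, 0, 0, 0, 0, 0, 0, 0, 1],
    [0, 0, 1, 0, 0, 0, 0, 0, 1, 0],
])

result = decode_multi_hot(y_multi, labels)
```

A row with no active position would simply yield an empty string.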
Upvotes: 3
Reputation: 221564
Since we are working with a one-hot encoded array, argmax could be used to get the index of the single 1 in each row. Thus, using the list as input -
[my_list[i] for i in y_test.argmax(1)]
Or with np.take to get array output -
np.take(my_list,y_test.argmax(1))
To work with the dict, assuming sequential keys 0, 1, ..., we could have -
np.take(list(label_dict.values()), y_test.argmax(1))
If the keys are not necessarily sequential but are sorted -
np.take(list(label_dict.values()), np.searchsorted(list(label_dict.keys()), y_test.argmax(1)))
Sample run -
In [79]: my_list
Out[79]:
['airplane',
'automobile',
'bird',
'cat',
'deer',
'dog',
'frog',
'horse',
'ship',
'truck']
In [80]: y_test
Out[80]:
array([[ 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
[ 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
[ 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.]])
In [81]: [my_list[i] for i in y_test.argmax(1)]
Out[81]: ['cat', 'automobile', 'ship']
In [82]: np.take(my_list,y_test.argmax(1))
Out[82]:
array(['cat', 'automobile', 'ship'],
dtype='|S10')
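For the exact setup in the question (bytes labels in label_dict, a single float one-hot row, and 'cat' wanted as a str), a direct dict lookup on the argmax index is one possible sketch. decode_one_hot is a hypothetical helper name, and decoding the bytes as ASCII is an assumption about the labels.

```python
import numpy as np

my_list = [b'airplane', b'automobile', b'bird', b'cat', b'deer',
           b'dog', b'frog', b'horse', b'ship', b'truck']
label_dict = dict(enumerate(my_list))

def decode_one_hot(y, label_dict):
    """Map one-hot rows to str labels via the dict (hypothetical helper)."""
    y = np.atleast_2d(y)          # accept a single row or a 2D batch
    indices = y.argmax(axis=1)    # position of the 1 in each row
    # Assumption: labels are ASCII bytes, decoded to str for readability
    return [label_dict[int(i)].decode('ascii') for i in indices]

single = decode_one_hot(np.array([0., 0., 0., 1., 0., 0., 0., 0., 0., 0.]),
                        label_dict)
batch = decode_one_hot(np.eye(10)[[3, 1, 8]], label_dict)
```

Because the lookup goes through label_dict itself, this does not depend on the order in which the dict yields its values.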
Upvotes: 5