Harshit

Reputation: 1217

How to fit multidimensional output using scikit-learn?

I am trying to fit one-vs-all classification output as training data. Each row of the output sums to 1.

One possible approach is to go through every row, find the column with the highest value, and build the training labels from that.

For example, y = [[0.2, 0.8, 0], [0, 1, 0], [0, 0.3, 0.7]] can be reduced to y = [b, b, c], where a, b and c are the classes corresponding to columns 0, 1 and 2 respectively.

Is there a function in scikit-learn which helps to achieve such transformations?

Upvotes: 2

Views: 1061

Answers (1)

Martin Böschen

Reputation: 1769

This code does what you want:

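import numpy as np

y = np.array([[0.2, 0.8, 0], [0, 1, 0], [0, 0.3, 0.7]])

def transform(y, labels):
    # Index of the highest-valued column in each row
    idx = y.argmax(axis=1)
    # Map each column index to the label supplied by the caller
    f = np.vectorize(lambda i: labels[i])
    return f(idx)

y = transform(y, 'abc')  # -> ['b', 'b', 'c']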

EDIT: Following the comment by alko, I made it more general by letting the user supply the labels to the transform function.
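Since the question asks for something built into scikit-learn: a minimal sketch using `sklearn.preprocessing.LabelBinarizer`, assuming the columns of y are in the same (sorted) order as the class labels. When a `LabelBinarizer` is fitted on more than two classes, its `inverse_transform` picks the column with the greatest value in each row, which is the same argmax reduction as above. The labels 'a', 'b', 'c' are just the ones from the question's example.

```python
import numpy as np
from sklearn.preprocessing import LabelBinarizer

# Per-row class scores; each row sums to 1, as in the question
y = np.array([[0.2, 0.8, 0], [0, 1, 0], [0, 0.3, 0.7]])

# Assumption: column j of y corresponds to the j-th label in sorted order,
# because LabelBinarizer stores classes_ sorted.
lb = LabelBinarizer()
lb.fit(['a', 'b', 'c'])

# For a multiclass fit, inverse_transform returns, for each row,
# the class whose column has the greatest value.
labels = lb.inverse_transform(y)
print(labels)  # ['b' 'b' 'c']
```

If the columns are not in sorted label order, the plain NumPy version above, which takes the labels in column order, is the safer choice.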

Upvotes: 1
