Yannick

Inverse_transform method (LabelEncoder)

Below is code I found online for building a simple neural network. Everything works fine. I encoded the y labels, and these are the predictions I get:

2 0 1 2 1 2 2 0 2 1 0 0 0 1 1 1 1 1 1 1 2 1 2 1 0 1 0 1 0 2

Now I need to convert these predictions back to the original Iris classes (Iris-setosa, Iris-versicolor, Iris-virginica). I need to use the inverse_transform method for this, but how? Can you help out?
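
For reference, LabelEncoder assigns integer codes in sorted order of the class names, so the codes in the predictions above are simply indexes into the fitted encoder's classes_ array. A minimal sketch of that mapping, using a short made-up list of labels in the CSV's "Iris-..." format (the list and the code_to_name dict are illustrative assumptions, not part of the original code):

```python
from sklearn.preprocessing import LabelEncoder

# Hypothetical sample of labels as they appear in the CSV's Class column
labels = ["Iris-versicolor", "Iris-setosa", "Iris-virginica", "Iris-setosa"]

le = LabelEncoder()
codes = le.fit_transform(labels)

# classes_ stores the sorted unique labels; each label's code is its index here
print(le.classes_)   # ['Iris-setosa' 'Iris-versicolor' 'Iris-virginica']
print(codes)         # [1 0 2 0]

# Explicit code -> name lookup, handy for reading raw predictions by eye
code_to_name = dict(enumerate(le.classes_))

# inverse_transform maps integer codes back to the original strings
decoded = le.inverse_transform(codes)
print(decoded.tolist())  # ['Iris-versicolor', 'Iris-setosa', 'Iris-virginica', 'Iris-setosa']
```

With this ordering, a prediction of 0 means Iris-setosa, 1 means Iris-versicolor and 2 means Iris-virginica.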

    import pandas as pd
    from sklearn import preprocessing
    from sklearn.model_selection import train_test_split
    from sklearn.neural_network import MLPClassifier
    from sklearn.metrics import classification_report, confusion_matrix 
    
    
    # Location of dataset
    url = "https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data"
    
    # Assign column names to the dataset
    names = ['sepal-length', 'sepal-width', 'petal-length', 'petal-width', 'Class']
    
    # Read the dataset into a pandas DataFrame
    irisdata = pd.read_csv(url, names=names)  
    
    irisdata.head()
    #head_tableau=irisdata.head()
    #print(head_tableau)
    
    # Assign the first four (feature) columns to the X variable
    X = irisdata.iloc[:, 0:4]
    
    # Assign the fifth column (the string Class labels) to the y variable
    y = irisdata.select_dtypes(include=[object])  
    
    y.head()
    #afficher_y=y.head()
    #print(afficher_y)
    
    y.Class.unique()
    #affiche=y.Class.unique()
    #print(affiche)
    
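    # Encode the class strings as integers; because y has a single column,
    # le ends up fitted on the Class labels and can decode them later
    le = preprocessing.LabelEncoder()
    
    y = y.apply(le.fit_transform)  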
    
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.20)
    
    mlp = MLPClassifier(hidden_layer_sizes=(10, 10, 10), max_iter=1000)  
    mlp.fit(X_train, y_train.values.ravel())
    
    predictions = mlp.predict(X_test)
    print(predictions)
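
A note on the encoding step: y.apply(le.fit_transform) calls fit_transform once per column, so the encoder is left fitted on the last column it saw. Here y has only the Class column, so le remembers the class names, which is what makes inverse_transform possible afterwards. A small sketch of this behaviour, using a hypothetical three-row stand-in for y rather than the real dataset, plus the more direct alternative of encoding the Series itself:

```python
import pandas as pd
from sklearn.preprocessing import LabelEncoder

# Hypothetical three-row stand-in for the one-column frame that
# irisdata.select_dtypes(include=[object]) returns in the question
y = pd.DataFrame({"Class": ["Iris-setosa", "Iris-virginica", "Iris-versicolor"]})

le = LabelEncoder()
# DataFrame.apply calls le.fit_transform on each column in turn; with only
# the "Class" column, le is left fitted on the class names
y_encoded = y.apply(le.fit_transform)
print(y_encoded["Class"].tolist())   # [0, 2, 1]

# Because le kept its fit, it can decode the integer codes again
print(le.inverse_transform(y_encoded["Class"]).tolist())

# A more direct alternative: encode the Series itself (returns a 1-D array,
# so no .values.ravel() is needed before mlp.fit)
le_direct = LabelEncoder()
codes = le_direct.fit_transform(y["Class"])
print(codes.tolist())   # [0, 2, 1]
```

If y ever had more than one object column, apply would leave le fitted only on the last one, which is one reason to prefer encoding the Class Series directly.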

Answers (1)

mbatchkarov

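You are on the right track: call inverse_transform on the same fitted LabelEncoder instance (le) that encoded the labels, and pass it the integer predictions: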

    In [7]: le.inverse_transform(predictions[:5])
    Out[7]: 
    array(['Iris-virginica', 'Iris-setosa', 'Iris-setosa', 'Iris-versicolor',
           'Iris-virginica'], dtype=object)
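
For a self-contained version of the whole round trip, here is a sketch that assumes scikit-learn's bundled load_iris in place of the UCI download (so it runs offline), prefixes "Iris-" to its target names to mimic the CSV labels, and adds random_state for reproducibility; none of these choices come from the question. It decodes both the predictions and the true test labels with the same encoder, and passes le.classes_ as target_names to classification_report, which the question imports but never uses:

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier
from sklearn.preprocessing import LabelEncoder

# Assumption: the bundled Iris data stands in for the UCI CSV. Its target
# names are "setosa" etc., so "Iris-" is prefixed to match the CSV labels.
iris = load_iris()
X = iris.data
class_names = np.array(["Iris-" + name for name in iris.target_names])
y_text = class_names[iris.target]

le = LabelEncoder()
y = le.fit_transform(y_text)   # integer codes 0, 1, 2

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.20, random_state=0)

mlp = MLPClassifier(hidden_layer_sizes=(10, 10, 10), max_iter=1000,
                    random_state=0)
mlp.fit(X_train, y_train)
predictions = mlp.predict(X_test)

# Decode with the SAME fitted encoder that produced the integer labels
predicted_names = le.inverse_transform(predictions)
actual_names = le.inverse_transform(y_test)
for pred, actual in list(zip(predicted_names, actual_names))[:5]:
    print(f"predicted={pred:<16} actual={actual}")

# classification_report can label its rows with the original class names
print(classification_report(y_test, predictions,
                            target_names=list(le.classes_)))
```

A freshly created LabelEncoder() does not know this mapping until it is fitted, so keep the fitted le alongside the model if you need to decode predictions later.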
