user3708902

Can't save output of scikit-learn's DecisionTreeClassifier to a CSV

I have the following code, which is meant to take in some training and testing data for scikit-learn's DecisionTreeClassifier. It works until I try to save the output of the .predict method to a CSV. The code so far is:

import numpy as np
import pandas as pd
from sklearn import tree

with open('data/training.csv', 'r') as f:

    df = pd.read_csv(f, index_col=None)

df['Num_Labels'] = df.Label.map(lambda x: '-1' if x == 's' else '1')  # Convert labels to the strings '-1' ('s') or '1' (anything else).

Train_values = df.iloc[:, 1:31].values
Train_labels = df.iloc[:, 33:34].values
# print Train_values.values
# print type(Train_values.values)

with open('data/test.csv', 'r') as f2:

    df2 = pd.read_csv(f2, index_col=None)

Test_values = df2.iloc[:, 1:31].values

# #----------------------------------------------------------------------------------------------

X = Train_values
Y = Train_labels

clf = tree.DecisionTreeClassifier()
clf = clf.fit(X, Y)

Pred = clf.predict(Test_values)

#print Pred
#print type(Pred[:1])
np.savetxt('Output.csv', Pred, delimiter =' ')

And the terminal output is the following:

/usr/bin/python2.7 /home/amit/PycharmProjects/HB/Read.py
Traceback (most recent call last):
  File "/home/amit/PycharmProjects/HB/Read.py", line 38, in <module>
    np.savetxt('Output.csv', Pred, delimiter =' ')
  File "/usr/lib/python2.7/dist-packages/numpy/lib/npyio.py", line 1073, in savetxt
    fh.write(asbytes(format % tuple(row) + newline))
TypeError: float argument required, not str
['1' '-1' '-1' ..., '1' '1' '1']
<type 'numpy.ndarray'>

Process finished with exit code 1
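A minimal reproduction of the same failure without the training data, assuming Python 3 and a recent NumPy (the message text differs from the Python 2.7 traceback above, but the exception is still a TypeError). The array pred is a hypothetical stand-in for Pred holding the same kind of values:

```python
import io

import numpy as np

# Hypothetical stand-in for Pred: label values stored as strings, not numbers.
pred = np.array(['1', '-1', '-1', '1'])
print(pred.dtype.kind)  # 'U' (unicode strings) under Python 3

buf = io.StringIO()
error = None
try:
    # With the default fmt='%.18e', every element must be a number.
    np.savetxt(buf, pred, delimiter=' ')
except TypeError as exc:
    error = exc
    print('savetxt failed:', exc)
```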


Answers (1)

DrV

Most likely something is wrong with Pred itself. The relevant loop in savetxt (numpy/lib/npyio.py) is fairly simple:

for row in X:
    fh.write(asbytes(format % tuple(row) + newline))

This reads X (the input array, Pred in this case) row by row. The format string (format) is built from the fmt argument (default '%.18e'), repeated once per element in a row (i.e. once per column; a 1-D array such as Pred is treated as a single column) and joined with the delimiter. The error message says that row contains something other than a number, and it looks as if the elements are short text strings instead of floats.
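The row-by-row logic can be mimicked in plain Python to see exactly where the TypeError comes from. This is a simplified sketch, not NumPy's actual source; the helper name format_rows is hypothetical, and it assumes NumPy's default fmt='%.18e':

```python
import numpy as np


def format_rows(X, fmt='%.18e', delimiter=' '):
    """Hypothetical helper: format each row the way savetxt does (simplified)."""
    X = np.asarray(X)
    if X.ndim == 1:
        # A 1-D array is written as a single column, one element per line.
        X = X.reshape(-1, 1)
    # One placeholder per column, joined by the delimiter, e.g. '%.18e %.18e'.
    row_format = delimiter.join([fmt] * X.shape[1])
    # '%'-formatting a string element with '%e' or '%f' raises TypeError.
    return [row_format % tuple(row) for row in X]


print(format_rows(np.array([[1, 2], [3, 4]]), fmt='%d'))  # ['1 2', '3 4']
try:
    format_rows(np.array(['1', '-1']))
except TypeError as exc:
    print('TypeError:', exc)
```

With fmt='%s' the same string array formats without error, which is why the dtype, not savetxt itself, is the thing to look at.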

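My guess is that Pred is an ndarray; it would be odd if it weren't, and your printed type confirms it. However, it may be an array of strings instead of an array of floats or other numbers! The quotes around the printed values ['1' '-1' '-1' ...] point that way.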

You can track this down much as you already tried, but by printing the dtype:

print Pred.dtype

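If it prints something like S2 or S3, you have an array of strings. In that case, check the data types of X and Y. A DecisionTreeClassifier returns predictions taken from the labels it was trained on, so if Y holds strings, Pred will too. In your code, the lambda that builds Num_Labels produces the strings '-1' and '1' rather than numbers.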
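A sketch of checking the dtype and two ways to get numeric output, assuming Python 3 and NumPy (under Python 2 the string kind is 'S' rather than 'U'). The raw label values are hypothetical: only 's' is known from the question, and 'b' stands in for "anything else". np.where replaces the question's Series.map call only to keep the example free of pandas:

```python
import io

import numpy as np

# Hypothetical raw labels as they might appear in the Label column.
raw_labels = np.array(['s', 'b', 's', 'b'])

# Mapping to the strings '-1'/'1' (as the question's lambda does) gives a string dtype.
str_labels = np.where(raw_labels == 's', '-1', '1')
print(str_labels.dtype.kind)  # 'U': strings, not numbers

# Option 1: build numeric labels at the source, e.g. lambda x: -1 if x == 's' else 1.
int_labels = np.where(raw_labels == 's', -1, 1)
print(int_labels.dtype.kind)  # 'i': integers

# Option 2: convert an existing string array to integers before saving.
pred = str_labels.astype(int)

buf = io.StringIO()
np.savetxt(buf, pred, fmt='%d')  # integer format, one value per line
print(buf.getvalue())
```

Option 1 is the cleaner fix: with integer labels the classifier's predict returns integers, and savetxt works directly. If you do want to keep the string labels, np.savetxt('Output.csv', Pred, fmt='%s') also writes them as they are.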
