I am trying to use logistic regression to classify messages as 'spam' or 'ham'. I am using the data set from http://archive.ics.uci.edu/ml/datasets/SMS+Spam+Collection. I found that TF-IDF is a good way to extract features from text, so I used TfidfVectorizer from scikit-learn. Here is my code:
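import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
# In scikit-learn >= 0.18 this lives in sklearn.model_selection instead
from sklearn.cross_validation import train_test_split

# Column 0 is the label ('spam'/'ham'), column 1 is the message text
msg_df = pd.read_csv('data/sms', delimiter='\t', header=None)
X_train_data, X_test_data, y_train_data, y_test_data = train_test_split(msg_df[1], msg_df[0])

# Learn the vocabulary and IDF weights on the training set only
sms_vectorizer = TfidfVectorizer()
X_train_vector = sms_vectorizer.fit_transform(X_train_data)
X_test_vector = sms_vectorizer.transform(X_test_data)

classifier = LogisticRegression()
classifier.fit(X_train_vector, y_train_data)
sms_predictions = classifier.predict(X_test_vector)
print sms_predictions
for i, prediction in enumerate(sms_predictions[:5]):
    print 'Prediction: %s. Message: %s' % (prediction, X_test_data[i])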
The loop at the end raises this error:
KeyError Traceback (most recent call last)
<ipython-input-19-b5f57158f320> in <module>()
4 print sms_predictions
5 for i, prediction in enumerate(sms_predictions[:5]):
----> 6 print 'Prediction: %s. Message: %s' % (prediction, X_test_data[i])
/usr/local/lib/python2.7/dist-packages/pandas/core/series.pyc in __getitem__(self, key)
555 def __getitem__(self, key):
556 try:
--> 557 result = self.index.get_value(self, key)
558
559 if not np.isscalar(result):
/usr/local/lib/python2.7/dist-packages/pandas/core/index.pyc in get_value(self, series, key)
1788
1789 try:
-> 1790 return self._engine.get_value(s, k)
1791 except KeyError as e1:
1792 if len(self) > 0 and self.inferred_type in ['integer','boolean']:
/usr/local/lib/python2.7/dist-packages/pandas/index.so in pandas.index.IndexEngine.get_value (pandas/index.c:3204)()
/usr/local/lib/python2.7/dist-packages/pandas/index.so in pandas.index.IndexEngine.get_value (pandas/index.c:2903)()
/usr/local/lib/python2.7/dist-packages/pandas/index.so in pandas.index.IndexEngine.get_loc (pandas/index.c:3843)()
/usr/local/lib/python2.7/dist-packages/pandas/hashtable.so in pandas.hashtable.Int64HashTable.get_item (pandas/hashtable.c:6525)()
/usr/local/lib/python2.7/dist-packages/pandas/hashtable.so in pandas.hashtable.Int64HashTable.get_item (pandas/hashtable.c:6463)()
KeyError: 0
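Start by printing only the position and the prediction, leaving out the X_test_data[i] lookup that raises the error:
print sms_predictions
for i, prediction in enumerate(sms_predictions[:5]):
    print i, prediction
    #print 'Prediction: %s. Message: %s' % (prediction, X_test_data[i])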
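A likely cause, sketched under assumptions: train_test_split shuffles the rows but keeps each message's original index label, and X_test_data[i] on a Series with an integer index looks up the label i, not the i-th position. If the row labelled 0 ended up in the training set, the lookup raises KeyError: 0 (and for labels that do exist, it silently returns a different message than the i-th one). The Python 3 sketch below replaces the real split with a hand-built Series whose values, labels and predictions are hypothetical, and shows two ways to access by position instead: .iloc[i], or zipping the predictions with the messages.

```python
import pandas as pd

# Hypothetical stand-in for X_test_data as returned by train_test_split:
# the rows are shuffled, but each keeps its ORIGINAL index label.
X_test_data = pd.Series(["msg f", "msg c", "msg b"], index=[5, 2, 1])
# Hypothetical predictions, in the same order as the test rows.
sms_predictions = ["spam", "ham", "ham"]

# On an integer index, X_test_data[0] means "the row labelled 0", which is absent.
try:
    X_test_data[0]
except KeyError as err:
    print("KeyError:", err)

# Label 1 exists, but it is the third row, not the second one.
print("label 1:", X_test_data[1], "| position 1:", X_test_data.iloc[1])

# Fix 1: positional access with .iloc
for i, prediction in enumerate(sms_predictions[:5]):
    print("Prediction: %s. Message: %s" % (prediction, X_test_data.iloc[i]))

# Fix 2: iterate over predictions and messages together (a Series yields its values)
for prediction, message in zip(sms_predictions[:5], X_test_data):
    print("Prediction: %s. Message: %s" % (prediction, message))
```

The zip form avoids indexing altogether, so it keeps working however the split reorders or relabels the rows.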