tumultous_rooster

Reputation: 12550

Return indices of rows in a CountVectorizer matrix that have non-zero entries for a particular feature in scikit-learn

I have been scouring the documentation for Python's sklearn package.

I have created a CountVectorizer object and fitted and transformed it on my corpus.

I'm looking for a function that can return the indices of all rows that have non-zero entries, for some particular column.

So if the rows of my CountVectorizer matrix are music reviews and the columns are features (for example, there is a column for counts of the word "lyrics"), is there a function in scikit-learn that returns the indices of the music reviews that contain this word?

I looked at the inverse_transform(X) function, but it doesn't do this.

I suspect I'm not the first person to be interested in this functionality.

Does such a feature exist in sklearn? If not, has anyone who needed something similar come up with a good way to implement it?

Thanks in advance.

UPDATE:

My best solution so far iterates over the columns (in my case, I have 100 features):

# Assumes X has been converted to CSC format (X = X.tocsc()); in CSR
# format these slices would be taken per row, not per column.
for i in range(100):
    print(X.indices[X.indptr[i]:X.indptr[i+1]])

But this looks wasteful: it is iterative, the range must be hard-coded, and it returns empty arrays for columns that have no non-zero entries.

Upvotes: 3

Views: 1966

Answers (1)

David

Reputation: 9405

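I don't see a function in the documentation that does exactly this either, but this should do the trick for you:

def lookUpWord(vec, dtm, word):
    # vocabulary_ maps each term to its column index in the document-term matrix
    i = vec.vocabulary_[word]
    # Row indices of the documents with a non-zero count in that column
    return dtm[:, i].nonzero()[0]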

Here's a trivial example:

>>> from sklearn.feature_extraction.text import CountVectorizer
>>> 
>>> corpus = [
...     'This is the first document.',
...     'This is the second second document.',
...     'And the third one.',
...     'Is this the first document?'
...     ]
>>> 
>>> X = CountVectorizer()
>>> Y = X.fit_transform(corpus)
>>> lookUpWord(X,Y,'first')
array([0, 3], dtype=int32)
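
If you want the row indices for every feature at once rather than for a single word, one option is to convert the matrix to CSC format, where the non-zero row indices are already stored grouped by column. The following is only a minimal sketch: it assumes the input is a SciPy sparse matrix like the one fit_transform returns, and the names rows_per_column and toy are made up for illustration (the toy matrix stands in for real CountVectorizer output).

```python
import numpy as np
from scipy.sparse import csr_matrix


def rows_per_column(dtm):
    """Map each column index to the row indices holding a non-zero entry."""
    csc = dtm.tocsc(copy=True)  # CSC stores row indices grouped by column
    csc.eliminate_zeros()       # drop explicitly stored zeros, if any
    csc.sort_indices()          # ascending row indices within each column
    return {
        j: csc.indices[csc.indptr[j]:csc.indptr[j + 1]]
        for j in range(csc.shape[1])             # range comes from the matrix shape
        if csc.indptr[j + 1] > csc.indptr[j]     # skip all-zero columns
    }


# Hypothetical 4x3 count matrix standing in for CountVectorizer output
toy = csr_matrix(np.array([[1, 0, 0],
                           [0, 0, 2],
                           [0, 0, 0],
                           [3, 0, 1]]))
result = rows_per_column(toy)
```

This takes the column count from the matrix shape instead of hard-coding it and leaves out columns with no non-zero entries, which were the two complaints in the question's update. To go from a column index back to a word, you would still need the vectorizer's vocabulary.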

Upvotes: 2
