passion

Reputation: 1020

Python scikit-learn: get documents per topic in LDA

I am running LDA on text data, following the example here. My question is:
How can I know which documents correspond to which topic? In other words, which documents are talking about topic 1, for example?

Here are my steps:

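from sklearn.feature_extraction.text import CountVectorizer
from sklearn.decomposition import LatentDirichletAllocation

n_features = 1000
n_topics = 8
n_top_words = 20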

I read my text file line by line:

with open('dataset.txt', 'r') as data_file:
    # one document per line, surrounding whitespace stripped
    mydata = [line.strip() for line in data_file]
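Because LDA results are reported by row index, it helps to keep a mapping from each row back to the line of the file it came from. A minimal sketch, assuming you want to skip blank lines (the code above keeps them as empty documents); `read_documents` is a hypothetical helper name:

```python
def read_documents(path):
    """Return (line_numbers, docs): docs[i] came from 1-based line line_numbers[i]."""
    line_numbers, docs = [], []
    with open(path, "r") as data_file:
        for line_no, line in enumerate(data_file, start=1):
            text = line.strip()
            if text:  # assumption: blank lines are not treated as documents
                line_numbers.append(line_no)
                docs.append(text)
    return line_numbers, docs
```

Row i of the later document-term matrix then corresponds to docs[i], which sits on file line line_numbers[i].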

A function to print the top words of each topic:

def print_top_words(model, feature_names, n_top_words):
    # components_ holds one row of word weights per topic
    for topic_idx, topic in enumerate(model.components_):
        print("Topic #%d:" % topic_idx)
        # indices of the n_top_words largest weights, largest first
        print(" ".join([feature_names[i]
                        for i in topic.argsort()[:-n_top_words - 1:-1]]))
    print()
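The slice topic.argsort()[:-n_top_words - 1:-1] is what selects the top words: argsort returns indices ordered by ascending weight, and the negative-step slice walks that array backwards, taking the last n_top_words indices with the highest weight first. A small check with made-up weights (top_indices is a hypothetical helper):

```python
import numpy as np

def top_indices(weights, n):
    # argsort is ascending; stepping by -1 from the end yields the n largest, largest first
    return weights.argsort()[:-n - 1:-1]

weights = np.array([0.1, 0.5, 0.3, 0.9])
print(top_indices(weights, 2))  # prints [3 1]
```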

Vectorizing the data (word counts per document):

tf_vectorizer = CountVectorizer(max_df=0.95, min_df=2, token_pattern='\\b\\w{2,}\\w+\\b',
                                max_features=n_features,
                                stop_words='english')
tf = tf_vectorizer.fit_transform(mydata)
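The custom token_pattern keeps only tokens of three or more word characters (\w{2,} followed by \w+ needs at least 3). CountVectorizer lowercases text before tokenizing by default, so running the same pattern through Python's re module on lowercased text gives a rough preview of the raw tokens. This sketch does not apply stop_words, min_df, max_df or max_features, and the sample sentence is made up:

```python
import re

# Same pattern as passed to CountVectorizer above, written as a raw string
TOKEN_PATTERN = r"\b\w{2,}\w+\b"

text = "An LDA on solar road lamps, v2 of it".lower()
tokens = re.findall(TOKEN_PATTERN, text)
print(tokens)  # prints ['lda', 'solar', 'road', 'lamps']
```

Two-character words such as "an", "on" and "v2" are dropped by the pattern itself, before any stop-word filtering.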

Initializing the LDA:

# n_topics was renamed n_components in scikit-learn 0.19 (the old name was later removed)
lda = LatentDirichletAllocation(n_components=3, max_iter=5,
                                learning_method='online',
                                learning_offset=50.,
                                random_state=0)

Fitting LDA on the tf data:

lda.fit(tf)
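After fitting, lda.components_ has one row per topic and one column per vocabulary term, which is what print_top_words iterates over, while each row of tf is one input document in its original order. A toy-corpus sketch (the four documents are made up, and the default CountVectorizer settings are used instead of the ones above):

```python
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.decomposition import LatentDirichletAllocation

# Made-up corpus; row i of the count matrix corresponds to docs[i]
docs = [
    "solar lamp battery power",
    "solar battery energy beacon",
    "skin hair cosmetic extract",
    "cosmetic skin aging extract",
]
vectorizer = CountVectorizer()
tf = vectorizer.fit_transform(docs)

lda = LatentDirichletAllocation(n_components=2, random_state=0)
lda.fit(tf)

print(tf.shape)               # (number of documents, number of vocabulary terms)
print(lda.components_.shape)  # (number of topics, number of vocabulary terms)
```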

Printing the results with the function above:

print("\nTopics in LDA model:")
tf_feature_names = tf_vectorizer.get_feature_names_out()  # get_feature_names() on scikit-learn < 1.0

print_top_words(lda, tf_feature_names, n_top_words)

The output is:

Topics in LDA model:
Topic #0:
solar road body lamp power battery energy beacon
Topic #1:
skin cosmetic hair extract dermatological aging production active
Topic #2:
cosmetic oil water agent block emulsion ingredients mixture

Upvotes: 8

Views: 7096

Answers (2)

Marcel

Reputation: 2794

You need to transform the data with the fitted model:

doc_topic = lda.transform(tf)

and then list each document together with its highest-scoring topic:

for n in range(doc_topic.shape[0]):
    topic_most_pr = doc_topic[n].argmax()
    print("doc: {} topic: {}\n".format(n,topic_most_pr))
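To answer the question the other way round (which documents belong to topic 1?), the same argmax can be inverted into a topic-to-documents mapping. A minimal NumPy sketch, assuming doc_topic is the array returned by lda.transform(tf); the small matrix below is made up and documents_per_topic is a hypothetical helper:

```python
import numpy as np

def documents_per_topic(doc_topic):
    """Map each topic index to the indices of documents whose dominant topic it is."""
    dominant = doc_topic.argmax(axis=1)  # highest-scoring topic for every document
    n_topics = doc_topic.shape[1]
    return {k: np.flatnonzero(dominant == k).tolist() for k in range(n_topics)}

# Made-up document-topic distribution: 4 documents, 3 topics
doc_topic = np.array([
    [0.80, 0.10, 0.10],
    [0.20, 0.70, 0.10],
    [0.10, 0.15, 0.75],
    [0.05, 0.60, 0.35],
])
groups = documents_per_topic(doc_topic)
print(groups[1])  # prints [1, 3]
```

With the real data, mydata[i] gives the text of each document index listed for a topic.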

Upvotes: 19

Ryan Stout

Reputation: 1028

http://scikit-learn.org/stable/modules/generated/sklearn.decomposition.LatentDirichletAllocation.html#sklearn.decomposition.LatentDirichletAllocation.transform

The transform method takes a document-word matrix X as input and returns the document-topic distribution for X.

So if you call transform on the document-word matrix of your documents, you can then look for the documents that have a high enough (for your purposes) fraction of words from your topic of interest.
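A sketch of that idea: rank documents by their weight on one topic and keep those above a cutoff. The 0.5 threshold and the helper name documents_for_topic are arbitrary choices for illustration, and doc_topic is assumed to be the output of lda.transform(tf) (the matrix below is made up):

```python
import numpy as np

def documents_for_topic(doc_topic, topic, threshold=0.5):
    """Return (doc_index, weight) pairs with weight >= threshold, highest weight first."""
    weights = doc_topic[:, topic]
    order = np.argsort(weights)[::-1]  # document indices by descending weight
    return [(int(i), float(weights[i])) for i in order if weights[i] >= threshold]

# Made-up distribution: 4 documents, 3 topics
doc_topic = np.array([
    [0.80, 0.10, 0.10],
    [0.20, 0.70, 0.10],
    [0.10, 0.15, 0.75],
    [0.05, 0.60, 0.35],
])
for i, w in documents_for_topic(doc_topic, topic=1):
    print(i, w)
```

Unlike the argmax approach, a document can pass the cutoff for several topics, or for none.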

Upvotes: 0
