Reputation: 21
I have built a Doc2Vec model and am trying to get the vectors for all 176 points in my testing set. With the code below I can only see one vector at a time. I want to be able to do "clean_corpus[404:]" to get the entire test set, but when I try that it still outputs only a single vector.
model.save("d2v.model")
print("Model Saved")
from gensim.models.doc2vec import Doc2Vec
model = Doc2Vec.load("d2v.model")
#to find the vector of a document which is not in training data
test_data = clean_corpus[404]
v1 = model.infer_vector(test_data)
print("V1_infer", v1)
Is there a way to easily iterate over the model to get and save all 176 vectors?
Upvotes: 1
Views: 273
Reputation: 54173
Because .infer_vector() takes a single text (a list of words), you need to call it multiple times in a loop if you want to infer separate vectors for many different documents.
Another option is to include all the documents of interest, including your test set, in the Doc2Vec model's training data. Then you can simply request the learned-during-training vector for any document via the unique tag you supplied during training.
Whether this is an acceptable practice depends on other unstated aspects of your project goals. Doc2Vec
is an unsupervised algorithm, so in some cases it can be appropriate to use all available text to improve its training. (It doesn't necessarily cause the same problems as contaminating the training of a supervised classifier with the same already-labeled examples you'll be testing it against.)
Upvotes: 1
Reputation: 122052
The simplest way (though not the cheapest) is to iterate through the test set and run each document through the .infer_vector() function.
from gensim.test.utils import common_texts
from gensim.models.doc2vec import Doc2Vec, TaggedDocument
documents = [TaggedDocument(doc, [i]) for i, doc in enumerate(common_texts)]
model = Doc2Vec(documents, vector_size=5, window=2, min_count=1, workers=4)
sentence = "This is a system response".split() # sentence is of list(str) type.
vector = model.infer_vector(sentence)
print(vector)
And for multiple sentences:
import numpy as np
sentences = [
    "This is a system response".split(),
    "That is a hello world thing".split(),
]
vectors = np.array([model.infer_vector(s) for s in sentences])
But looking at the code (https://github.com/RaRe-Technologies/gensim/blob/62669aef21ae8047c3105d89f0032df81e73b4fa/gensim/models/doc2vec.py), there is a .dv attribute (short for doc vectors; called .docvecs in older gensim versions) that you can use to retrieve the vectors learned during training. E.g.
from gensim.test.utils import common_texts
from gensim.models.doc2vec import Doc2Vec, TaggedDocument
documents = [TaggedDocument(doc, [i]) for i, doc in enumerate(common_texts)]
model = Doc2Vec(documents, vector_size=5, window=2, min_count=1, workers=4)
print('No. of docs:', len(documents))
print('No. of doc vectors:', len(model.docvecs))
[out]:
No. of docs: 9
No. of doc vectors: 9
And if we add 2 more sentences:
from gensim.test.utils import common_texts
from gensim.models.doc2vec import Doc2Vec, TaggedDocument
documents = [TaggedDocument(doc, [i]) for i, doc in enumerate(common_texts)]
model = Doc2Vec(documents, vector_size=5, window=2, min_count=1, workers=4)
original_len = len(documents)
print('No. of original docs:', original_len)
print('No. of original doc vectors:', len(model.docvecs))
print('-----')
# Adding the new docs.
sentences = [
    "This is a system response".split(),
    "That is a hello world thing".split(),
]
documents += [TaggedDocument(doc, [i]) for i, doc in enumerate(sentences, start=len(documents))]
model = Doc2Vec(documents, vector_size=5, window=2, min_count=1, workers=4)
print('No. of docs:', len(documents))
print('No. of doc vectors:', len(model.docvecs))
print('-----')
for i, s in enumerate(sentences):
    print(i + original_len)
    print(s)
    print(model.docvecs[i + original_len])
    print('-----')
[out]:
No. of original docs: 9
No. of original doc vectors: 9
-----
No. of docs: 11
No. of doc vectors: 11
-----
9
['This', 'is', 'a', 'system', 'response']
[-0.00675494 -0.09459886 -0.05916259 0.02931841 0.07335921]
-----
10
['That', 'is', 'a', 'hello', 'world', 'thing']
[ 0.06789951 0.07246465 0.00149267 -0.09202603 0.08346568]
Upvotes: 0