droosteh

Reputation: 51

Issues defining a gensim word2vec based custom sklearn transformer

I've just started learning to code in Python and am trying to use it for text classification. To be able to build an sklearn pipeline (and also to learn how to define and use classes), I'd like to write a custom word2vec transformer. Example code to illustrate my issue (doc is a list of lists of tokens):

from gensim.models.word2vec import Word2Vec
from sklearn.base import BaseEstimator, TransformerMixin

class GensimVectorizer(BaseEstimator, TransformerMixin):
    def __init__(self, n_dim = 64, n_context = 2):
        self.n_dim = n_dim
        self.n_context = n_context
    def fit(self, X):
        self.model = Word2Vec(X, size = self.n_dim, 
                              window = self.n_context, 
                              min_count = 1, sample = 1e-3, workers = 4)
        return self
    def transform(self, X):
        return self.model.wv[X]

vect = GensimVectorizer()
vect.fit(doc)
vect.transform('word')

This is giving me different results compared to the code below. What am I doing wrong?

w2v_model = Word2Vec(doc, size = 64, window = 2, min_count = 1, sample = 1e-3, workers = 4)
w2v_model.wv['word']

Upvotes: 3

Views: 1032

Answers (2)

Venkatachalam

Reputation: 16966

You can also use gensim's own sklearn wrapper for the word2vec model (the gensim.sklearn_api module, available in gensim versions before 4.0), as follows.

>>> from gensim.test.utils import common_texts
>>> from gensim.sklearn_api import W2VTransformer
>>>
>>> # Create a model to represent each word by a 10 dimensional vector.
>>> model = W2VTransformer(size=10, min_count=1, seed=1)
>>>
>>> # What is the vector representation of the word 'graph'?
>>> wordvecs = model.fit(common_texts).transform(['graph', 'system'])
>>> assert wordvecs.shape == (2, 10)

Read more here

Upvotes: 1

droosteh

Reputation: 51

Found the answer here. The 'issue' is caused by randomization when training word2vec. Solved by setting a seed parameter and limiting the model to a single worker thread.

Upvotes: 2
