Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added new class 'BaseSklearnWrapper' #1383

Merged
43 changes: 43 additions & 0 deletions gensim/sklearn_integration/base_sklearn_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2011 Radim Rehurek <radimrehurek@seznam.cz>
# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html
#
"""
Scikit-learn interface for gensim, for easy use of gensim with scikit-learn.
Follows scikit-learn API conventions.
"""
from abc import ABCMeta, abstractmethod


# The base class is created by calling ABCMeta directly: a class-body
# ``__metaclass__`` attribute is only honoured by Python 2, whereas inheriting
# from an ABCMeta-built class enforces the abstract methods on Python 2 and 3.
class BaseSklearnWrapper(ABCMeta('_AbstractBase', (object,), {})):
    """
    Abstract base class for the gensim scikit-learn wrappers.

    Declares the methods every wrapper has to provide (`get_params`,
    `set_params`, `fit`, `transform`, `partial_fit`). A subclass that leaves
    any of them unimplemented cannot be instantiated.
    """

    @abstractmethod
    def get_params(self, deep=True):
        """
        Return the parameters of the wrapped model.

        Abstract: every concrete wrapper must override this method.
        NOTE(review): `deep` presumably mirrors the flag of scikit-learn's
        `get_params` -- confirm against the concrete wrappers.
        """
        pass

    @abstractmethod
    def set_params(self, **parameters):
        """
        Set all parameters.

        Each keyword argument is stored as an attribute of the same name on
        this instance. Although declared abstract, this implementation is
        usable: a subclass overrides the method and may delegate here through
        `super()`.

        Returns the instance itself.
        """
        for parameter, value in parameters.items():
            setattr(self, parameter, value)
        return self

    @abstractmethod
    def fit(self, X, y=None):
        """
        Fit the model to `X`.

        Abstract: every concrete wrapper must override this method.
        NOTE(review): `y` is presumably accepted only for scikit-learn
        signature compatibility -- confirm against the concrete wrappers.
        """
        pass

    @abstractmethod
    def transform(self, docs, minimum_probability=None):
        """
        Transform `docs` with the fitted model.

        Abstract: every concrete wrapper must override this method.
        NOTE(review): the meaning of `minimum_probability` is defined by the
        concrete wrappers; it is not visible from this class.
        """
        pass

    @abstractmethod
    def partial_fit(self, X):
        """
        Update the model with `X`.

        Abstract: every concrete wrapper must override this method.
        NOTE(review): presumably incremental training on an already fitted
        model -- confirm against the concrete wrappers.
        """
        pass
9 changes: 4 additions & 5 deletions gensim/sklearn_integration/sklearn_wrapper_gensim_ldamodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@

from gensim import models
from gensim import matutils
from gensim.sklearn_integration import base_sklearn_wrapper
from scipy import sparse
from sklearn.base import TransformerMixin, BaseEstimator


class SklearnWrapperLdaModel(models.LdaModel, TransformerMixin, BaseEstimator):
class SklearnWrapperLdaModel(models.LdaModel, base_sklearn_wrapper.BaseSklearnWrapper, TransformerMixin, BaseEstimator):
"""
Base LDA module
"""
Expand Down Expand Up @@ -70,11 +71,9 @@ def set_params(self, **parameters):
"""
Set all parameters.
"""
for parameter, value in parameters.items():
self.parameter = value
return self
super(SklearnWrapperLdaModel, self).set_params(**parameters)

def fit(self, X, y=None):
def fit(self, X, y=None):
"""
For fitting corpus into the class object.
Calls gensim.model.LdaModel:
Expand Down
10 changes: 5 additions & 5 deletions gensim/sklearn_integration/sklearn_wrapper_gensim_lsimodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@

from gensim import models
from gensim import matutils
from gensim.sklearn_integration import base_sklearn_wrapper
from scipy import sparse
from sklearn.base import TransformerMixin, BaseEstimator

class SklearnWrapperLsiModel(models.LsiModel, TransformerMixin, BaseEstimator):

class SklearnWrapperLsiModel(models.LsiModel, base_sklearn_wrapper.BaseSklearnWrapper, TransformerMixin, BaseEstimator):
"""
Base LSI module
"""
Expand Down Expand Up @@ -51,11 +53,9 @@ def set_params(self, **parameters):
"""
Set all parameters.
"""
for parameter, value in parameters.items():
self.parameter = value
return self
super(SklearnWrapperLsiModel, self).set_params(**parameters)

def fit(self, X, y=None):
def fit(self, X, y=None):
"""
For fitting corpus into the class object.
Calls gensim.model.LsiModel:
Expand Down
27 changes: 27 additions & 0 deletions gensim/test/test_sklearn_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,19 @@ def testPipeline(self):
score = text_lda.score(corpus, data.target)
self.assertGreater(score, 0.40)

def testSetGetParams(self):
    # Values handed to set_params() must be reported back unchanged by
    # get_params().

    # a single parameter on its own
    self.model.set_params(num_topics=3)
    self.assertEqual(self.model.get_params()["num_topics"], 3)

    # several parameters in one call
    expected_params = {"eval_every": 20, "decay": 0.7}
    self.model.set_params(**expected_params)
    reported = self.model.get_params()
    for name, expected in expected_params.items():
        self.assertEqual(reported[name], expected)


class TestSklearnLSIWrapper(unittest.TestCase):
def setUp(self):
Expand Down Expand Up @@ -165,5 +178,19 @@ def testPipeline(self):
score = text_lda.score(corpus, data.target)
self.assertGreater(score, 0.50)

def testSetGetParams(self):
    # Values handed to set_params() must be reported back unchanged by
    # get_params().

    # a single parameter on its own
    self.model.set_params(num_topics=3)
    self.assertEqual(self.model.get_params()["num_topics"], 3)

    # several parameters in one call
    expected_params = {"chunksize": 10000, "decay": 0.9}
    self.model.set_params(**expected_params)
    reported = self.model.get_params()
    for name, expected in expected_params.items():
        self.assertEqual(reported[name], expected)


if __name__ == '__main__':
unittest.main()