diff --git a/docs/notebooks/topic_coherence_tutorial.ipynb b/docs/notebooks/topic_coherence_tutorial.ipynb index 33c57e728b..c484733aa2 100644 --- a/docs/notebooks/topic_coherence_tutorial.ipynb +++ b/docs/notebooks/topic_coherence_tutorial.ipynb @@ -25,7 +25,7 @@ "cell_type": "code", "execution_count": 1, "metadata": { - "collapsed": false + "collapsed": true }, "outputs": [ { @@ -59,9 +59,28 @@ "\n", "from gensim.models import CoherenceModel, LdaModel, HdpModel\n", "from gensim.models.wrappers import LdaVowpalWabbit, LdaMallet\n", - "from gensim.corpora import Dictionary\n", - "\n", - "warnings.filterwarnings('ignore') # To ignore all warnings that arise here to enhance clarity" + "from gensim.corpora.dictionary import Dictionary\n", + "from numpy import array" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Set up logging" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "logging.basicConfig(level=logging.DEBUG)\n", + "logger = logging.getLogger(__name__)\n", + "logger.setLevel(logging.DEBUG)" ] }, { @@ -99,11 +118,18 @@ }, { "cell_type": "code", - "execution_count": 3, - "metadata": { - "collapsed": false - }, - "outputs": [], + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:gensim.corpora.dictionary:adding document #0 to Dictionary(0 unique tokens: [])\n", + "INFO:gensim.corpora.dictionary:built Dictionary(12 unique tokens: ['graph', 'system', 'user', 'minors', 'human']...) 
from 9 documents (total 29 corpus positions)\n" + ] + } + ], "source": [ "dictionary = Dictionary(texts)\n", "corpus = [dictionary.doc2bow(text) for text in texts]" @@ -125,16 +151,73 @@ }, { "cell_type": "code", - "execution_count": 4, - "metadata": { - "collapsed": false - }, - "outputs": [], + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:gensim.models.ldamodel:using symmetric alpha at 0.5\n", + "INFO:gensim.models.ldamodel:using symmetric eta at 0.08333333333333333\n", + "INFO:gensim.models.ldamodel:using serial LDA version on this node\n", + "INFO:gensim.models.ldamodel:running online LDA training, 2 topics, 1 passes over the supplied corpus of 9 documents, updating model once every 9 documents, evaluating perplexity and coherence every 9 documents, iterating 50x with a convergence threshold of 0.001000\n", + "WARNING:gensim.models.ldamodel:too few updates, training might not converge; consider increasing the number of passes or iterations to improve accuracy\n", + "DEBUG:gensim.models.ldamodel:bound: at document #0\n", + "INFO:gensim.models.ldamodel:-3.297 per-word bound, 9.8 perplexity estimate based on a held-out corpus of 9 documents with 29 words\n", + "INFO:gensim.models.ldamodel:-15.478 coherence estimate based on a held-out corpus of 9 documents with 29 words\n", + "INFO:gensim.models.ldamodel:PROGRESS: pass 0, at document #9/9\n", + "DEBUG:gensim.models.ldamodel:performing inference on a chunk of 9 documents\n", + "DEBUG:gensim.models.ldamodel:6/9 documents converged within 50 iterations\n", + "DEBUG:gensim.models.ldamodel:updating topics\n", + "INFO:gensim.models.ldamodel:topic #0 (0.500): 0.169*\"system\" + 0.120*\"human\" + 0.098*\"eps\" + 0.094*\"interface\" + 0.091*\"computer\" + 0.071*\"user\" + 0.067*\"graph\" + 0.062*\"trees\" + 0.062*\"time\" + 0.061*\"minors\"\n", + "INFO:gensim.models.ldamodel:topic #1 (0.500): 0.129*\"trees\" + 0.125*\"graph\" + 0.121*\"user\" 
+ 0.098*\"response\" + 0.086*\"survey\" + 0.084*\"minors\" + 0.083*\"time\" + 0.080*\"system\" + 0.057*\"computer\" + 0.055*\"interface\"\n", + "INFO:gensim.models.ldamodel:topic diff=0.420882, rho=1.000000\n" + ] + } + ], + "source": [ + "goodLdaModel = LdaModel(corpus=corpus, id2word=dictionary, iterations=50, num_topics=2)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:gensim.models.ldamodel:using symmetric alpha at 0.5\n", + "INFO:gensim.models.ldamodel:using symmetric eta at 0.08333333333333333\n", + "INFO:gensim.models.ldamodel:using serial LDA version on this node\n", + "INFO:gensim.models.ldamodel:running online LDA training, 2 topics, 1 passes over the supplied corpus of 9 documents, updating model once every 9 documents, evaluating perplexity and coherence every 9 documents, iterating 1x with a convergence threshold of 0.001000\n", + "WARNING:gensim.models.ldamodel:too few updates, training might not converge; consider increasing the number of passes or iterations to improve accuracy\n", + "DEBUG:gensim.models.ldamodel:bound: at document #0\n", + "INFO:gensim.models.ldamodel:-3.317 per-word bound, 10.0 perplexity estimate based on a held-out corpus of 9 documents with 29 words\n", + "INFO:gensim.models.ldamodel:-16.260 coherence estimate based on a held-out corpus of 9 documents with 29 words\n", + "INFO:gensim.models.ldamodel:PROGRESS: pass 0, at document #9/9\n", + "DEBUG:gensim.models.ldamodel:performing inference on a chunk of 9 documents\n", + "DEBUG:gensim.models.ldamodel:0/9 documents converged within 1 iterations\n", + "DEBUG:gensim.models.ldamodel:updating topics\n", + "INFO:gensim.models.ldamodel:topic #0 (0.500): 0.138*\"system\" + 0.111*\"graph\" + 0.098*\"trees\" + 0.082*\"user\" + 0.081*\"survey\" + 0.075*\"interface\" + 0.073*\"time\" + 0.072*\"minors\" + 0.072*\"human\" + 0.069*\"eps\"\n", + 
"INFO:gensim.models.ldamodel:topic #1 (0.500): 0.114*\"user\" + 0.105*\"system\" + 0.097*\"trees\" + 0.086*\"response\" + 0.083*\"graph\" + 0.079*\"computer\" + 0.077*\"eps\" + 0.075*\"human\" + 0.074*\"minors\" + 0.073*\"time\"\n", + "INFO:gensim.models.ldamodel:topic diff=0.255033, rho=1.000000\n" + ] + } + ], "source": [ - "goodLdaModel = LdaModel(corpus=corpus, id2word=dictionary, iterations=50, num_topics=2)\n", "badLdaModel = LdaModel(corpus=corpus, id2word=dictionary, iterations=1, num_topics=2)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now, the LdaModel in gensim also logs the evaluation parameters `perplexity` and `coherence` according to `eval_every`. These parameter values can be used to monitor the LDA training and evaluate how the topics are improving during training. " + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -144,13 +227,23 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 7, "metadata": { - "collapsed": false + "collapsed": true + }, + "outputs": [], + "source": [ + "goodcm = CoherenceModel(model=goodLdaModel, corpus=corpus, dictionary=dictionary, coherence='u_mass')" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "collapsed": true }, "outputs": [], "source": [ - "goodcm = CoherenceModel(model=goodLdaModel, corpus=corpus, dictionary=dictionary, coherence='u_mass')\n", "badcm = CoherenceModel(model=badLdaModel, corpus=corpus, dictionary=dictionary, coherence='u_mass')" ] }, @@ -170,10 +263,8 @@ }, { "cell_type": "code", - "execution_count": 6, - "metadata": { - "collapsed": false - }, + "execution_count": 16, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -218,10 +309,19 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 17, "metadata": { - "collapsed": false + "collapsed": true }, + "outputs": [], + "source": [ + "pyLDAvis.enable_notebook()" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, 
"outputs": [ { "data": { @@ -284,10 +384,8 @@ }, { "cell_type": "code", - "execution_count": 8, - "metadata": { - "collapsed": false - }, + "execution_count": 19, + "metadata": {}, "outputs": [ { "data": { @@ -339,7 +437,24 @@ ] }, "metadata": {}, - "output_type": "display_data" + "output_type": "execute_result" + } + ], + "source": [ + "pyLDAvis.gensim.prepare(badLdaModel, corpus, dictionary)" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "-14.0842451581\n" + ] } ], "source": [ @@ -350,10 +465,8 @@ }, { "cell_type": "code", - "execution_count": 9, - "metadata": { - "collapsed": false - }, + "execution_count": 21, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -397,10 +510,8 @@ }, { "cell_type": "code", - "execution_count": 11, - "metadata": { - "collapsed": false - }, + "execution_count": 27, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -423,10 +534,25 @@ }, { "cell_type": "code", - "execution_count": 12, - "metadata": { - "collapsed": false - }, + "execution_count": 28, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.552164532134\n" + ] + } + ], + "source": [ + "print goodcm.get_coherence()" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -486,10 +612,8 @@ }, { "cell_type": "code", - "execution_count": 15, - "metadata": { - "collapsed": false - }, + "execution_count": 8, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -511,7 +635,7 @@ "cell_type": "code", "execution_count": 16, "metadata": { - "collapsed": false + "collapsed": true }, "outputs": [], "source": [ @@ -521,10 +645,20 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 21, "metadata": { - "collapsed": false + "collapsed": true }, + "outputs": [], + "source": [ + "cm1 = CoherenceModel(model=model1, texts=texts, 
coherence='c_v')\n", + "cm2 = CoherenceModel(model=model2, texts=texts, coherence='c_v')" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -563,10 +697,23 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 8, "metadata": { - "collapsed": false + "collapsed": true }, + "outputs": [], + "source": [ + "# To get the topic words from the model\n", + "topics = []\n", + "for topic_id, topic in hm.show_topics(num_topics=10, formatted=False):\n", + " topic = [word for word, _ in topic]\n", + " topics.append(topic)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, "outputs": [ { "data": { @@ -613,10 +760,20 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 10, "metadata": { - "collapsed": false + "collapsed": true }, + "outputs": [], + "source": [ + "# Initialize CoherenceModel using `topics` parameter\n", + "cm = CoherenceModel(topics=topics, corpus=corpus, dictionary=dictionary, coherence='u_mass')" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, "outputs": [ { "data": { @@ -652,23 +809,23 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 2", + "display_name": "Python 3", "language": "python", - "name": "python2" + "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", - "version": 2 + "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", - "pygments_lexer": "ipython2", - "version": "2.7.13" + "pygments_lexer": "ipython3", + "version": "3.4.3" } }, "nbformat": 4, - "nbformat_minor": 0 + "nbformat_minor": 1 } diff --git a/gensim/models/ldamodel.py b/gensim/models/ldamodel.py index 9f41334d47..0422865d4e 100755 --- a/gensim/models/ldamodel.py +++ b/gensim/models/ldamodel.py @@ -35,6 +35,8 @@ import numbers from random import sample import os +import gensim +import copy from gensim import interfaces, 
utils, matutils from gensim.matutils import dirichlet_expectation @@ -195,7 +197,9 @@ def __init__(self, corpus=None, num_topics=100, id2word=None, alpha='symmetric', eta=None, decay=0.5, offset=1.0, eval_every=10, iterations=50, gamma_threshold=0.001, minimum_probability=0.01, random_state=None, ns_conf={}, - minimum_phi_value=0.01, per_word_topics=False): + minimum_phi_value=0.01, per_word_topics=False, coherence='u_mass', + coherence_texts=None, coherence_window_size=None, coherence_topn=10, + log_diff=False, diff_distance="jaccard"): """ If given, start training from the iterable `corpus` straight away. If not given, the model is left untrained (presumably because you want to call `update()` manually). @@ -228,7 +232,7 @@ def __init__(self, corpus=None, num_topics=100, id2word=None, Turn on `distributed` to force distributed computing (see the `web tutorial `_ on how to set up a cluster of machines for gensim). - Calculate and log perplexity estimate from the latest mini-batch every + Calculate and log perplexity and coherence estimate from the latest mini-batch every `eval_every` model updates (setting this to 1 slows down training ~2x; default is 10 for better performance). Set to None to disable perplexity estimation. @@ -239,6 +243,34 @@ def __init__(self, corpus=None, num_topics=100, id2word=None, `random_state` can be a np.random.RandomState object or the seed for one + `coherence` : Coherence measure to be used for logging coherence. Supported values are: + 'u_mass' + 'c_v' + 'c_uci' also popularly known as c_pmi + 'c_npmi' + For 'c_v', 'c_uci' and 'c_npmi' texts should be provided. + + `coherence_texts` : Tokenized texts. Needed when eval_every is not None and coherence is logged for sliding_window_based measures c_v, c_uci, c_npmi. 
eg:: + texts = [['system', 'human', 'system', 'eps'], + ['user', 'response', 'time'], + ['trees'], + ['graph', 'trees'], + ['graph', 'minors', 'trees'], + ['graph', 'minors', 'survey']] + + `coherence_window_size` : Size of the window to be used for coherence measures that use a boolean sliding window as their + probability estimator. For 'u_mass' this doesn't matter. + If left as None, the default window sizes are used, which are: + 'c_v' : 110 + 'c_uci' : 10 + 'c_npmi' : 10 + + `coherence_topn` : Integer number of top words to be extracted from each topic for coherence logging. + + `log_diff` : Set to True to log the topic difference between consecutive epochs. + + `diff_distance` : Distance measure to use when `log_diff` is enabled. + Example: >>> lda = LdaModel(corpus, num_topics=100) # train model @@ -281,6 +313,14 @@ def __init__(self, corpus=None, num_topics=100, id2word=None, self.minimum_phi_value = minimum_phi_value self.per_word_topics = per_word_topics + self.coherence = coherence + self.coherence_texts = coherence_texts + self.coherence_window_size = coherence_window_size + self.coherence_topn = coherence_topn + + self.log_diff = log_diff + self.diff_distance = diff_distance + self.alpha, self.optimize_alpha = self.init_dir_prior(alpha, 'alpha') assert self.alpha.shape == (self.num_topics,), "Invalid alpha shape. Got shape %s, but expected (%d, )" % (str(self.alpha.shape), self.num_topics) @@ -529,9 +569,33 @@ def log_perplexity(self, chunk, total_docs=None): (perwordbound, np.exp2(-perwordbound), len(chunk), corpus_words)) return perwordbound + def log_coherence(self, chunk, texts, coherence, window_size, topn): + """ + Log coherence using the `chunk` of documents as evaluation corpus. 
+ """ + cm = gensim.models.CoherenceModel(model=self, corpus=chunk, texts=texts, dictionary=self.id2word, coherence=coherence, window_size=window_size, topn=topn) + coherence = cm.get_coherence() + corpus_words = sum(cnt for document in chunk for _, cnt in document) + logger.info("%.3f coherence estimate based on a held-out corpus of %i documents with %i words", coherence, len(chunk), corpus_words) + return coherence + + def log_epoch_diff(self, epoch, other_model, distance): + """ + Log topic diff between consecutive epochs + """ + diff_matrix, annotation = self.diff(other_model, distance) + diff_diagonal = np.diagonal(diff_matrix) + prev_epoch = epoch - 1 + if epoch == 0: + prev_epoch = "initial random model" + logger.info("Topic difference between %s and %s epoch: %s", prev_epoch, epoch, diff_diagonal) + return diff_diagonal + def update(self, corpus, chunksize=None, decay=None, offset=None, passes=None, update_every=None, eval_every=None, iterations=None, - gamma_threshold=None, chunks_as_numpy=False): + gamma_threshold=None, chunks_as_numpy=False, coherence=None, + coherence_texts=None, coherence_window_size=None, coherence_topn=None, + log_diff=None, diff_distance=None): """ Train the model with new documents, by EM-iterating over `corpus` until the topics converge (or until the maximum number of allowed iterations @@ -578,6 +642,18 @@ def update(self, corpus, chunksize=None, decay=None, offset=None, iterations = self.iterations if gamma_threshold is None: gamma_threshold = self.gamma_threshold + if coherence is None: + coherence = self.coherence + if coherence_texts is None: + coherence_texts = self.coherence_texts + if coherence_window_size is None: + coherence_window_size = self.coherence_window_size + if coherence_topn is None: + coherence_topn = self.coherence_topn + if log_diff is None: + log_diff = self.log_diff + if diff_distance is None: + diff_distance = self.diff_distance try: lencorpus = len(corpus) @@ -609,8 +685,8 @@ def update(self, corpus, 
chunksize=None, decay=None, offset=None, logger.info( "running %s LDA training, %s topics, %i passes over " "the supplied corpus of %i documents, updating model once " - "every %i documents, evaluating perplexity every %i documents, " - "iterating %ix with a convergence threshold of %f", + "every %i documents, evaluating perplexity and coherence " + "every %i documents, iterating %ix with a convergence threshold of %f", updatetype, self.num_topics, passes, lencorpus, updateafter, evalafter, iterations, gamma_threshold) @@ -640,6 +716,15 @@ def rho(): if eval_every and ((reallen == lencorpus) or ((chunk_no + 1) % (eval_every * self.numworkers) == 0)): self.log_perplexity(chunk, total_docs=lencorpus) + # texts input is needed for sliding window based coherence measures (c_v, c_uci, c_npmi) + if coherence_texts is not None: + init = (chunk_no + 1) * chunksize - chunksize + end = init + chunksize + # texts subarray corresponding to current chunk + texts_chunk = coherence_texts[init:end] + else: + texts_chunk = None + self.log_coherence(chunk, texts_chunk, coherence, coherence_window_size, coherence_topn) if self.dispatcher: # add the chunk to dispatcher's job queue, so workers can munch on it @@ -674,6 +759,15 @@ def rho(): other = LdaState(self.eta, self.state.sstats.shape) dirty = False # endfor single corpus iteration + + # log diff between consecutive epochs + if log_diff: + if pass_ == 0: + # save randomly initialized model for diff with first pass + previous = copy.deepcopy(self) + self.log_epoch_diff(pass_, previous, diff_distance) + previous = copy.deepcopy(self) + if reallen != lencorpus: raise RuntimeError("input corpus size changed during training (don't use generators as input)")