diff --git a/docs/notebooks/soft_cosine_tutorial.ipynb b/docs/notebooks/soft_cosine_tutorial.ipynb new file mode 100644 index 0000000000..e5d11dcd3f --- /dev/null +++ b/docs/notebooks/soft_cosine_tutorial.ipynb @@ -0,0 +1,593 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Finding similar documents with Word2Vec and Soft Cosine Measure \n", + "\n", + "Soft Cosine Measure (SCM) is a promising new tool in machine learning that allows us to retrieve the documents most relevant to a query. In **part 1**, we will show how you can compute SCM between two documents using `softcossim`. In **part 2**, we will use `SoftCosineSimilarity` to retrieve documents most similar to a query and compare the performance against other similarity measures.\n", + "\n", + "First, however, we go through the basics of what Soft Cosine Measure is.\n", + "\n", + "## Soft Cosine Measure basics\n", + "\n", + "Soft Cosine Measure (SCM) is a method that allows us to assess the similarity between two documents in a meaningful way, even when they have no words in common. It uses a measure of similarity between words, which can be derived [2] using [word2vec][] [3] vector embeddings of words. It has been shown to outperform many of the state-of-the-art methods in the semantic text similarity task in the context of community question answering [2].\n", + "\n", + "[word2vec]: https://radimrehurek.com/gensim/models/word2vec.html\n", + "\n", + "SCM is illustrated below for two very similar sentences. The sentences have no words in common, but by modeling synonymy, SCM is able to accurately measure the similarity between the two sentences. The method also uses the bag-of-words vector representation of the documents (simply put, the words' frequencies in the documents). The intuition behind the method is that we compute standard cosine similarity assuming that the document vectors are expressed in a non-orthogonal basis, where the angle between two basis vectors is derived from the angle between the word2vec embeddings of the corresponding words.\n", + "\n", + "![Soft Cosine Measure](soft_cosine_tutorial.png)\n", + "\n", + "This method was perhaps first introduced in the article “Soft Similarity and Soft Cosine Measure: Similarity of Features in Vector Space Model” by Grigori Sidorov, Alexander Gelbukh, Helena Gomez-Adorno, and David Pinto ([link to PDF](http://www.scielo.org.mx/pdf/cys/v18n3/v18n3a7.pdf)).\n", + "\n", + "In this tutorial, we will learn how to use Gensim's SCM functionality, which consists of the `softcossim` function for one-off computation, and the `SoftCosineSimilarity` class for corpus-based similarity queries.\n", + "\n", + "> **Note**:\n", + ">\n", + "> If you use this software, please consider citing [1] and [2].\n", + ">\n", + "\n", + "## Running this notebook\n", + "You can download this [Jupyter notebook](http://jupyter.org/), and run it on your own computer, provided you have installed the `gensim`, `jupyter`, `sklearn`, `pyemd`, and `wmd` Python packages.\n", + "\n", + "The notebook was run on an Ubuntu machine with an Intel Core i7-6700HQ CPU 3.10GHz (4 cores) and 16 GB memory. Assuming all resources required by the notebook have already been downloaded, running the entire notebook on this machine takes about 30 minutes."
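To make that intuition concrete, here is a minimal NumPy sketch of the formula that `softcossim` evaluates: the inner product `x^T S y`, normalized by `sqrt(x^T S x)` and `sqrt(y^T S y)`, where `S` holds word-to-word similarities. The toy matrix and vectors below are made up for illustration and are not part of the tutorial's code.

```python
import numpy as np

# Toy 3-term vocabulary: terms 0 and 1 are near-synonyms, term 2 is unrelated.
S = np.array([
    [1.0, 0.8, 0.0],
    [0.8, 1.0, 0.0],
    [0.0, 0.0, 1.0],
])

x = np.array([1.0, 0.0, 0.0])  # document containing only term 0
y = np.array([0.0, 1.0, 0.0])  # document containing only term 1

def soft_cosine(x, y, S):
    # Cosine similarity in a non-orthogonal basis whose inner product is given by S.
    return x.dot(S).dot(y) / np.sqrt(x.dot(S).dot(x) * y.dot(S).dot(y))

print(soft_cosine(x, y, S))  # 0.8, even though x and y share no terms; plain cosine would give 0.0
```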
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "# Initialize logging.\n", + "import logging\n", + "logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Part 1: Computing the Soft Cosine Measure\n", + "\n", + "To use SCM, we need some word embeddings first of all. You could train a [word2vec][] (see tutorial [here](http://rare-technologies.com/word2vec-tutorial/)) model on some corpus, but we will use pre-trained word2vec embeddings.\n", + "\n", + "[word2vec]: https://radimrehurek.com/gensim/models/word2vec.html\n", + "\n", + "Let's create some sentences to compare." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "sentence_obama = 'Obama speaks to the media in Illinois'.lower().split()\n", + "sentence_president = 'The president greets the press in Chicago'.lower().split()\n", + "sentence_orange = 'Oranges are my favorite fruit'.lower().split()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The first two sentences have very similar content, and as such the SCM should be large. Before we compute the SCM, we want to remove stopwords (\"the\", \"to\", etc.), as these do not contribute a lot to the information in the sentences." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[nltk_data] Downloading package stopwords to /home/witiko/nltk_data...\n", + "[nltk_data] Package stopwords is already up-to-date!\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2018-02-05 10:47:42,975 : INFO : built Dictionary(11 unique tokens: ['president', 'fruit', 'greets', 'obama', 'illinois']...) from 3 documents (total 11 corpus positions)\n" + ] + } + ], + "source": [ + "# Import and download stopwords from NLTK.\n", + "from nltk.corpus import stopwords\n", + "from nltk import download\n", + "download('stopwords') # Download stopwords list.\n", + "\n", + "# Remove stopwords.\n", + "stop_words = stopwords.words('english')\n", + "sentence_obama = [w for w in sentence_obama if w not in stop_words]\n", + "sentence_president = [w for w in sentence_president if w not in stop_words]\n", + "sentence_orange = [w for w in sentence_orange if w not in stop_words]\n", + "\n", + "# Prepare a dictionary and a corpus.\n", + "from gensim import corpora\n", + "documents = [sentence_obama, sentence_president, sentence_orange]\n", + "dictionary = corpora.Dictionary(documents)\n", + "corpus = [dictionary.doc2bow(document) for document in documents]\n", + "\n", + "# Convert the sentences into bag-of-words vectors.\n", + "sentence_obama = dictionary.doc2bow(sentence_obama)\n", + "sentence_president = dictionary.doc2bow(sentence_president)\n", + "sentence_orange = dictionary.doc2bow(sentence_orange)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now, as we mentioned earlier, we will be using some downloaded pre-trained embeddings. Note that the embeddings we have chosen here require a lot of memory. We will use the embeddings to construct a term similarity matrix that will be used by the `softcossim` function." 
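If you would rather experiment with other pre-trained embeddings, the gensim-data downloader can list everything that `api.load` accepts. A small optional sketch follows; the exact catalogue depends on the installed gensim-data release.

```python
import gensim.downloader as api

# List the names of all pre-trained models offered by gensim-data;
# any of them can be passed to api.load() in place of "glove-wiki-gigaword-50" used below.
available_models = sorted(api.info()["models"])
print(available_models)
```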
+ ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2018-02-06 16:14:29,104 : INFO : constructed a term similarity matrix with 91.735537 % nonzero elements\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 21.2 s, sys: 224 ms, total: 21.4 s\n", + "Wall time: 21.8 s\n" + ] + } + ], + "source": [ + "%%time\n", + "import gensim.downloader as api\n", + "\n", + "w2v_model = api.load(\"glove-wiki-gigaword-50\")\n", + "similarity_matrix = w2v_model.similarity_matrix(dictionary)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "So let's compute SCM using the `softcossim` function." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "similarity = 0.5789\n" + ] + } + ], + "source": [ + "from gensim.matutils import softcossim\n", + "\n", + "similarity = softcossim(sentence_obama, sentence_president, similarity_matrix)\n", + "print('similarity = %.4f' % similarity)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's try the same thing with two completely unrelated sentences. Notice that the similarity is smaller." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "similarity = 0.1439\n" + ] + } + ], + "source": [ + "similarity = softcossim(sentence_obama, sentence_orange, similarity_matrix)\n", + "print('similarity = %.4f' % similarity)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Part 2: Similarity queries using `SoftCosineSimilarity`\n", + "You can use SCM to get the most similar documents to a query, using the SoftCosineSimilarity class. Its interface is similar to what is described in the [Similarity Queries](https://radimrehurek.com/gensim/tut3.html) Gensim tutorial.\n", + "\n", + "### Qatar Living unannotated dataset\n", + "Contestants solving the community question answering task in the [SemEval 2016][semeval16] and [2017][semeval17] competitions had an unannotated dataset of 189,941 questions and 1,894,456 comments from the [Qatar Living][ql] discussion forums. 
As our first step, we will use the same dataset to build a corpus.\n", + "\n", + "[semeval16]: http://alt.qcri.org/semeval2016/task3/\n", + "[semeval17]: http://alt.qcri.org/semeval2017/task3/\n", + "[ql]: http://www.qatarliving.com/forum" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[nltk_data] Downloading package stopwords to /home/witiko/nltk_data...\n", + "[nltk_data] Package stopwords is already up-to-date!\n", + "Number of documents: 2274338\n", + "CPU times: user 1min 59s, sys: 6.06 s, total: 2min 5s\n", + "Wall time: 2min 22s\n" + ] + } + ], + "source": [ + "%%time\n", + "from itertools import chain\n", + "import json\n", + "from re import sub\n", + "from os.path import isfile\n", + "\n", + "import gensim.downloader as api\n", + "from gensim.utils import simple_preprocess\n", + "from nltk.corpus import stopwords\n", + "from nltk import download\n", + "\n", + "\n", + "download(\"stopwords\") # Download stopwords list.\n", + "stopwords = set(stopwords.words(\"english\"))\n", + "\n", + "def preprocess(doc):\n", + " doc = sub(r'<img[^<>]+(>|$)', \" image_token \", doc)\n", + " doc = sub(r'<[^<>]+(>|$)', \" \", doc)\n", + " doc = sub(r'\\[img_assist[^]]*?\\]', \" \", doc)\n", + " doc = sub(r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+', \" url_token \", doc)\n", + " return [token for token in simple_preprocess(doc, min_len=0, max_len=float(\"inf\")) if token not in stopwords]\n", + "\n", + "corpus = list(chain(*[\n", + " chain(\n", + " [preprocess(thread[\"RelQuestion\"][\"RelQSubject\"]), preprocess(thread[\"RelQuestion\"][\"RelQBody\"])],\n", + " [preprocess(relcomment[\"RelCText\"]) for relcomment in thread[\"RelComments\"]])\n", + " for thread in api.load(\"semeval-2016-2017-task3-subtaskA-unannotated\")]))\n", + "\n", + "print(\"Number of documents: %d\" % len(corpus))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Using the corpus we have just built, we will now construct a [dictionary][], a [TF-IDF model][tfidf], a [word2vec model][word2vec], and a term similarity matrix.\n", + "\n", + "[dictionary]: https://radimrehurek.com/gensim/corpora/dictionary.html\n", + "[tfidf]: https://radimrehurek.com/gensim/models/tfidfmodel.html\n", + "[word2vec]: https://radimrehurek.com/gensim/models/word2vec.html" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2018-02-05 10:52:53,477 : INFO : built Dictionary(462807 unique tokens: ['reclarify', 'depeneded', 'autralia', 'cloudnight', 'openmoko']...)
from 2274338 documents (total 40096354 corpus positions)\n", + "2018-02-05 10:56:50,633 : INFO : training on a 200481770 raw words (192577574 effective words) took 224.3s, 858402 effective words/s\n", + "2018-02-05 11:13:14,895 : INFO : constructed a term similarity matrix with 0.003564 % nonzero elements\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Number of unique words: 462807\n", + "CPU times: user 1h 2min 21s, sys: 12min 56s, total: 1h 15min 17s\n", + "Wall time: 21min 27s\n" + ] + } + ], + "source": [ + "%%time\n", + "from gensim.corpora import Dictionary\n", + "from gensim.models import TfidfModel\n", + "from gensim.models import Word2Vec\n", + "from multiprocessing import cpu_count\n", + "\n", + "dictionary = Dictionary(corpus)\n", + "tfidf = TfidfModel(dictionary=dictionary)\n", + "w2v_model = Word2Vec(corpus, workers=cpu_count(), min_count=5, size=300, seed=12345)\n", + "similarity_matrix = w2v_model.wv.similarity_matrix(dictionary, tfidf, nonzero_limit=100)\n", + "\n", + "print(\"Number of unique words: %d\" % len(dictionary))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Evaluation\n", + "Next, we will load the validation and test datasets that were used by the SemEval 2016 and 2017 contestants. The datasets contain 208 original questions posted by the forum members. For each question, there is a list of 10 threads with a human annotation denoting whether or not the thread is relevant to the original question. Our task will be to order the threads so that relevant threads rank above irrelevant threads." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "datasets = api.load(\"semeval-2016-2017-task3-subtaskBC\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Finally, we will perform an evaluation to compare three unsupervised similarity measures – the Soft Cosine Measure, two different implementations of the [Word Mover's Distance][wmd], and standard cosine similarity. 
We will use the [Mean Average Precision (MAP)][map] as an evaluation measure and 10-fold cross-validation to get an estimate of the variance of MAP for each similarity measure.\n", + "\n", + "[wmd]: http://vene.ro/blog/word-movers-distance-in-python.html\n", + "[map]: https://medium.com/@pds.bangalore/mean-average-precision-abd77d0b9a7e" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "from math import isnan\n", + "from time import time\n", + "\n", + "from gensim.similarities import MatrixSimilarity, WmdSimilarity, SoftCosineSimilarity\n", + "import numpy as np\n", + "from sklearn.model_selection import KFold\n", + "from wmd import WMD\n", + "\n", + "def produce_test_data(dataset):\n", + " for orgquestion in datasets[dataset]:\n", + " query = preprocess(orgquestion[\"OrgQSubject\"]) + preprocess(orgquestion[\"OrgQBody\"])\n", + " documents = [\n", + " preprocess(thread[\"RelQuestion\"][\"RelQSubject\"]) + preprocess(thread[\"RelQuestion\"][\"RelQBody\"])\n", + " for thread in orgquestion[\"Threads\"]]\n", + " relevance = [\n", + " thread[\"RelQuestion\"][\"RELQ_RELEVANCE2ORGQ\"] in (\"PerfectMatch\", \"Relevant\")\n", + " for thread in orgquestion[\"Threads\"]]\n", + " yield query, documents, relevance\n", + "\n", + "def cossim(query, documents):\n", + " # Compute cosine similarity between the query and the documents.\n", + " query = tfidf[dictionary.doc2bow(query)]\n", + " index = MatrixSimilarity(\n", + " tfidf[[dictionary.doc2bow(document) for document in documents]],\n", + " num_features=len(dictionary))\n", + " similarities = index[query]\n", + " return similarities\n", + "\n", + "def softcossim(query, documents):\n", + " # Compute Soft Cosine Measure between the query and the documents.\n", + " query = tfidf[dictionary.doc2bow(query)]\n", + " index = SoftCosineSimilarity(\n", + " tfidf[[dictionary.doc2bow(document) for document in documents]],\n", + " similarity_matrix)\n", + " similarities = index[query]\n", + " return similarities\n", + "\n", + "def wmd_gensim(query, documents):\n", + " # Compute Word Mover's Distance as implemented in PyEMD by William Mayner\n", + " # between the query and the documents.\n", + " index = WmdSimilarity(documents, w2v_model)\n", + " similarities = index[query]\n", + " return similarities\n", + "\n", + "def wmd_relax(query, documents):\n", + " # Compute Word Mover's Distance as implemented in WMD by Source{d}\n", + " # between the query and the documents.\n", + " words = [word for word in set(chain(query, *documents)) if word in w2v_model.wv]\n", + " indices, words = zip(*sorted((\n", + " (index, word) for (index, _), word in zip(dictionary.doc2bow(words), words))))\n", + " query = dict(tfidf[dictionary.doc2bow(query)])\n", + " query = [\n", + " (new_index, query[dict_index])\n", + " for new_index, dict_index in enumerate(indices)\n", + " if dict_index in query]\n", + " documents = [dict(tfidf[dictionary.doc2bow(document)]) for document in documents]\n", + " documents = [[\n", + " (new_index, document[dict_index])\n", + " for new_index, dict_index in enumerate(indices)\n", + " if dict_index in document] for document in documents]\n", + " embeddings = np.array([w2v_model.wv[word] for word in words], dtype=np.float32)\n", + " nbow = dict(((index, list(chain([None], zip(*document)))) for index, document in enumerate(documents)))\n", + " nbow[\"query\"] = (None, *zip(*query))\n", + " distances = WMD(embeddings, nbow, vocabulary_min=1).nearest_neighbors(\"query\")\n", + " similarities = 
[-distance for _, distance in sorted(distances)]\n", + " return similarities\n", + "\n", + "strategies = {\n", + " \"cossim\" : cossim,\n", + " \"softcossim\": softcossim,\n", + " \"wmd-gensim\": wmd_gensim,\n", + " \"wmd-relax\": wmd_relax}\n", + "\n", + "def evaluate(split, strategy):\n", + " # Perform a single round of evaluation.\n", + " results = []\n", + " start_time = time()\n", + " for query, documents, relevance in split:\n", + " similarities = strategies[strategy](query, documents)\n", + " assert len(similarities) == len(documents)\n", + " precision = [\n", + " (num_correct + 1) / (num_total + 1) for num_correct, num_total in enumerate(\n", + " num_total for num_total, (_, relevant) in enumerate(\n", + " sorted(zip(similarities, relevance), reverse=True)) if relevant)]\n", + " average_precision = np.mean(precision) if precision else 0.0\n", + " results.append(average_precision)\n", + " return (np.mean(results) * 100, time() - start_time)\n", + "\n", + "def crossvalidate(args):\n", + " # Perform a cross-validation.\n", + " dataset, strategy = args\n", + " test_data = np.array(list(produce_test_data(dataset)))\n", + " kf = KFold(n_splits=10)\n", + " samples = []\n", + " for _, test_index in kf.split(test_data):\n", + " samples.append(evaluate(test_data[test_index], strategy))\n", + " return (np.mean(samples, axis=0), np.std(samples, axis=0))" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 1.49 s, sys: 1.28 s, total: 2.77 s\n", + "Wall time: 1min 42s\n" + ] + } + ], + "source": [ + "%%time\n", + "from multiprocessing import Pool\n", + "\n", + "args_list = [\n", + " (dataset, technique)\n", + " for dataset in (\"2016-test\", \"2017-test\")\n", + " for technique in (\"softcossim\", \"wmd-gensim\", \"wmd-relax\", \"cossim\")]\n", + "with Pool() as pool:\n", + " results = pool.map(crossvalidate, args_list)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The table below shows the pointwise estimates of means and standard deviations for MAP scores and elapsed times. Baselines and winners for each year are displayed in bold. We can see that the Soft Cosine Measure gives a strong performance on both the 2016 and the 2017 datasets."
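To make the MAP numbers easier to interpret, here is the average-precision computation from `evaluate` applied to a toy ranking of four threads; the similarities and relevance labels are illustrative only.

```python
import numpy as np

# Four threads with their similarity to the query and gold relevance labels.
similarities = [0.9, 0.7, 0.4, 0.1]
relevance = [True, False, True, False]

# Rank threads by decreasing similarity and take precision at each relevant position,
# mirroring the list comprehension inside evaluate() above.
ranked = sorted(zip(similarities, relevance), reverse=True)
precisions = [
    (num_correct + 1) / (rank + 1)
    for num_correct, rank in enumerate(
        rank for rank, (_, relevant) in enumerate(ranked) if relevant)]
print(np.mean(precisions))  # (1/1 + 2/3) / 2 ≈ 0.83; MAP averages this value over all queries
```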
+ ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/markdown": [ + "\n", + "\n", + "Dataset | Strategy | MAP score | Elapsed time (sec)\n", + ":---|:---|:---|---:\n", + "2016-test|softcossim|77.29 ±10.35|0.20 ±0.06\n", + "2016-test|**Winner (UH-PRHLT-primary)**|76.70 ±0.00|\n", + "2016-test|cossim|76.45 ±10.40|0.48 ±0.07\n", + "2016-test|wmd-gensim|76.07 ±11.52|8.36 ±2.05\n", + "2016-test|**Baseline 1 (IR)**|74.75 ±0.00|\n", + "2016-test|wmd-relax|73.01 ±10.33|0.97 ±0.16\n", + "2016-test|**Baseline 2 (random)**|46.98 ±0.00|\n", + "\n", + "\n", + "Dataset | Strategy | MAP score | Elapsed time (sec)\n", + ":---|:---|:---|---:\n", + "2017-test|**Winner (SimBow-primary)**|47.22 ±0.00|\n", + "2017-test|softcossim|46.06 ±18.00|0.15 ±0.03\n", + "2017-test|cossim|44.38 ±14.71|0.43 ±0.07\n", + "2017-test|wmd-gensim|44.20 ±16.02|9.78 ±1.80\n", + "2017-test|**Baseline 1 (IR)**|41.85 ±0.00|\n", + "2017-test|wmd-relax|41.24 ±14.87|1.00 ±0.26\n", + "2017-test|**Baseline 2 (random)**|29.81 ±0.00|" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from IPython.display import display, Markdown\n", + "\n", + "output = []\n", + "baselines = [\n", + " ((\"2016-test\", \"**Winner (UH-PRHLT-primary)**\"), ((76.70, 0), (0, 0))),\n", + " ((\"2016-test\", \"**Baseline 1 (IR)**\"), ((74.75, 0), (0, 0))),\n", + " ((\"2016-test\", \"**Baseline 2 (random)**\"), ((46.98, 0), (0, 0))),\n", + " ((\"2017-test\", \"**Winner (SimBow-primary)**\"), ((47.22, 0), (0, 0))),\n", + " ((\"2017-test\", \"**Baseline 1 (IR)**\"), ((41.85, 0), (0, 0))),\n", + " ((\"2017-test\", \"**Baseline 2 (random)**\"), ((29.81, 0), (0, 0)))]\n", + "table_header = [\"Dataset | Strategy | MAP score | Elapsed time (sec)\", \":---|:---|:---|---:\"]\n", + "for row, ((dataset, technique), ((mean_map_score, mean_duration), (std_map_score, std_duration))) \\\n", + " in enumerate(sorted(chain(zip(args_list, results), baselines), key=lambda x: (x[0][0], -x[1][0][0]))):\n", + " if row % (len(strategies) + 3) == 0:\n", + " output.extend(chain([\"\\n\"], table_header))\n", + " map_score = \"%.02f ±%.02f\" % (mean_map_score, std_map_score)\n", + " duration = \"%.02f ±%.02f\" % (mean_duration, std_duration) if mean_duration else \"\"\n", + " output.append(\"%s|%s|%s|%s\" % (dataset, technique, map_score, duration))\n", + "\n", + "display(Markdown('\\n'.join(output)))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## References\n", + "\n", + "1. Grigori Sidorov et al. *Soft Similarity and Soft Cosine Measure: Similarity of Features in Vector Space Model*, 2014. ([link to PDF](http://www.scielo.org.mx/pdf/cys/v18n3/v18n3a7.pdf))\n", + "2. Delphine Charlet and Geraldine Damnati. *SimBow at SemEval-2017 Task 3: Soft-Cosine Semantic Similarity between Questions for Community Question Answering*, 2017. ([link to PDF](http://www.aclweb.org/anthology/S17-2051))\n", + "3. Tomas Mikolov et al. *Efficient Estimation of Word Representations in Vector Space*, 2013.
([link to PDF](https://arxiv.org/pdf/1301.3781.pdf))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.5.3" + } + }, + "nbformat": 4, + "nbformat_minor": 1 +} diff --git a/docs/notebooks/soft_cosine_tutorial.png b/docs/notebooks/soft_cosine_tutorial.png new file mode 100644 index 0000000000..6a6dede60e Binary files /dev/null and b/docs/notebooks/soft_cosine_tutorial.png differ diff --git a/gensim/matutils.py b/gensim/matutils.py index 570b025182..676bca9a63 100644 --- a/gensim/matutils.py +++ b/gensim/matutils.py @@ -9,6 +9,7 @@ from __future__ import with_statement +from itertools import chain import logging import math @@ -755,6 +756,77 @@ def cossim(vec1, vec2): return result +def softcossim(vec1, vec2, similarity_matrix): + """Get Soft Cosine Measure between two vectors given a term similarity matrix. + + Return Soft Cosine Measure between two sparse vectors given a sparse term similarity matrix + in the :class:`scipy.sparse.csc_matrix` format. The similarity is a number between <-1.0, 1.0>, + higher is more similar. + + Parameters + ---------- + vec1 : list of (int, float) + A query vector in the BoW format. + vec2 : list of (int, float) + A document vector in the BoW format. + similarity_matrix : {:class:`scipy.sparse.csc_matrix`, :class:`scipy.sparse.csr_matrix`} + A term similarity matrix, typically produced by + :meth:`~gensim.models.keyedvectors.WordEmbeddingsKeyedVectors.similarity_matrix`. + + Returns + ------- + `similarity_matrix.dtype` + The Soft Cosine Measure between `vec1` and `vec2`. + + Raises + ------ + ValueError + When the term similarity matrix is in an unknown format. + + See Also + -------- + :meth:`gensim.models.keyedvectors.WordEmbeddingsKeyedVectors.similarity_matrix` + A term similarity matrix produced from term embeddings. + :class:`gensim.similarities.docsim.SoftCosineSimilarity` + A class for performing corpus-based similarity queries with Soft Cosine Measure. + + References + ---------- + Soft Cosine Measure was perhaps first defined by [sidorovetal14]_. + + .. [sidorovetal14] Grigori Sidorov et al., "Soft Similarity and Soft Cosine Measure: Similarity + of Features in Vector Space Model", 2014, http://www.cys.cic.ipn.mx/ojs/index.php/CyS/article/view/2043/1921. + + """ + if not isinstance(similarity_matrix, scipy.sparse.csc_matrix): + if isinstance(similarity_matrix, scipy.sparse.csr_matrix): + similarity_matrix = similarity_matrix.T + else: + raise ValueError('unknown similarity matrix format') + + if not vec1 or not vec2: + return 0.0 + + vec1 = dict(vec1) + vec2 = dict(vec2) + word_indices = sorted(set(chain(vec1, vec2))) + dtype = similarity_matrix.dtype + vec1 = np.array([vec1[i] if i in vec1 else 0 for i in word_indices], dtype=dtype) + vec2 = np.array([vec2[i] if i in vec2 else 0 for i in word_indices], dtype=dtype) + dense_matrix = similarity_matrix[[[i] for i in word_indices], word_indices].todense() + vec1len = vec1.T.dot(dense_matrix).dot(vec1)[0, 0] + vec2len = vec2.T.dot(dense_matrix).dot(vec2)[0, 0] + + assert \ + vec1len > 0.0 and vec2len > 0.0, \ + u"sparse documents must not contain any explicit zero entries and the similarity matrix S " \ + u"must satisfy x^T * S * x > 0 for any nonzero bag-of-words vector x." 
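+    # Soft Cosine Measure is (vec1^T * S * vec2) / (sqrt(vec1^T * S * vec1) * sqrt(vec2^T * S * vec2)), where S is the term similarity matrix restricted to the words of the two documents.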
+ + result = vec1.T.dot(dense_matrix).dot(vec2)[0, 0] + result /= math.sqrt(vec1len) * math.sqrt(vec2len) # rescale by vector lengths + return np.clip(result, -1.0, 1.0) + + def isbow(vec): """Checks if vector passed is in BoW format. diff --git a/gensim/models/keyedvectors.py b/gensim/models/keyedvectors.py index 677b9ca002..c17505cf06 100644 --- a/gensim/models/keyedvectors.py +++ b/gensim/models/keyedvectors.py @@ -80,7 +80,7 @@ from gensim.corpora.dictionary import Dictionary from six import string_types, integer_types from six.moves import xrange, zip -from scipy import stats +from scipy import sparse, stats from gensim.utils import deprecated from gensim.models.utils_any2vec import _save_word2vec_format, _load_word2vec_format, _compute_ngrams @@ -191,8 +191,8 @@ def rank(self, entity1, entity2): class WordEmbeddingsKeyedVectors(BaseKeyedVectors): - """Class containing common methods for operations over word vectors. - """ + """Class containing common methods for operations over word vectors.""" + def __init__(self, vector_size): super(WordEmbeddingsKeyedVectors, self).__init__(vector_size=vector_size) self.vectors_norm = None @@ -432,6 +432,113 @@ def similar_by_vector(self, vector, topn=10, restrict_vocab=None): """ return self.most_similar(positive=[vector], topn=topn, restrict_vocab=restrict_vocab) + def similarity_matrix(self, dictionary, tfidf=None, threshold=0.0, exponent=2.0, nonzero_limit=100, dtype=REAL): + """Constructs a term similarity matrix for computing Soft Cosine Measure. + + Constructs a sparse term similarity matrix in the :class:`scipy.sparse.csc_matrix` format for computing + Soft Cosine Measure between documents. + + Parameters + ---------- + dictionary : :class:`~gensim.corpora.dictionary.Dictionary` + A dictionary that specifies a mapping between words and the indices of rows and columns + of the resulting term similarity matrix. + tfidf : :class:`gensim.models.tfidfmodel.TfidfModel`, optional + A model that specifies the relative importance of the terms in the dictionary. The rows + of the term similarity matrix will be built in a decreasing order of importance of terms, + or in the order of term identifiers if None. + threshold : float, optional + Only pairs of words whose embeddings are more similar than `threshold` are considered + when building the sparse term similarity matrix. + exponent : float, optional + The exponent applied to the similarity between two word embeddings when building the term similarity matrix. + nonzero_limit : int, optional + The maximum number of non-zero elements outside the diagonal in a single row or column + of the term similarity matrix. Setting `nonzero_limit` to a constant ensures that the + time complexity of computing the Soft Cosine Measure will be linear in the document + length rather than quadratic. + dtype : numpy.dtype, optional + Data-type of the term similarity matrix. + + Returns + ------- + :class:`scipy.sparse.csc_matrix` + Term similarity matrix. + + See Also + -------- + :func:`gensim.matutils.softcossim` + The Soft Cosine Measure. + :class:`gensim.similarities.docsim.SoftCosineSimilarity` + A class for performing corpus-based similarity queries with Soft Cosine Measure. + + + Notes + ----- + The constructed matrix corresponds to the matrix Mrel defined in section 2.1 of + `Delphine Charlet and Geraldine Damnati, "SimBow at SemEval-2017 Task 3: Soft-Cosine Semantic Similarity + between Questions for Community Question Answering", 2017 + `__.
+ + """ + logger.info("constructing a term similarity matrix") + matrix_order = len(dictionary) + matrix_nonzero = [1] * matrix_order + matrix = sparse.identity(matrix_order, dtype=dtype, format="dok") + num_skipped = 0 + # Decide the order of rows. + if tfidf is None: + word_indices = range(matrix_order) + else: + assert max(tfidf.idfs) < matrix_order + word_indices = [ + index for index, _ in sorted(tfidf.idfs.items(), key=lambda x: x[1], reverse=True) + ] + + # Traverse rows. + for row_number, w1_index in enumerate(word_indices): + if row_number % 1000 == 0: + logger.info( + "PROGRESS: at %.02f%% rows (%d / %d, %d skipped, %.06f%% density)", + 100.0 * (row_number + 1) / matrix_order, row_number + 1, matrix_order, + num_skipped, 100.0 * matrix.getnnz() / matrix_order**2) + w1 = dictionary[w1_index] + if w1 not in self.vocab: + num_skipped += 1 + continue # A word from the dictionary is not present in the word2vec model. + # Traverse upper triangle columns. + if matrix_order <= nonzero_limit + 1: # Traverse all columns. + columns = ( + (w2_index, self.similarity(w1, dictionary[w2_index])) + for w2_index in range(w1_index + 1, matrix_order) + if w1_index != w2_index and dictionary[w2_index] in self.vocab) + else: # Traverse only columns corresponding to the embeddings closest to w1. + num_nonzero = matrix_nonzero[w1_index] - 1 + columns = ( + (dictionary.token2id[w2], similarity) + for _, (w2, similarity) + in zip( + range(nonzero_limit - num_nonzero), + self.most_similar(positive=[w1], topn=nonzero_limit - num_nonzero) + ) + if w2 in dictionary.token2id + ) + columns = sorted(columns, key=lambda x: x[0]) + + for w2_index, similarity in columns: + # Ensure that we don't exceed `nonzero_limit` by mirroring the upper triangle. + if similarity > threshold and matrix_nonzero[w2_index] <= nonzero_limit: + element = similarity**exponent + matrix[w1_index, w2_index] = element + matrix_nonzero[w1_index] += 1 + matrix[w2_index, w1_index] = element + matrix_nonzero[w2_index] += 1 + logger.info( + "constructed a term similarity matrix with %0.6f %% nonzero elements", + 100.0 * matrix.getnnz() / matrix_order**2 + ) + return matrix.tocsc() + def wmdistance(self, document1, document2): """ Compute the Word Mover's Distance between two documents. When using this diff --git a/gensim/similarities/__init__.py b/gensim/similarities/__init__.py index 48915d89c9..52cbad43e7 100644 --- a/gensim/similarities/__init__.py +++ b/gensim/similarities/__init__.py @@ -3,4 +3,4 @@ """ # bring classes directly into package namespace, to save some typing -from .docsim import Similarity, MatrixSimilarity, SparseMatrixSimilarity, WmdSimilarity # noqa:F401 +from .docsim import Similarity, MatrixSimilarity, SparseMatrixSimilarity, SoftCosineSimilarity, WmdSimilarity # noqa:F401 diff --git a/gensim/similarities/docsim.py b/gensim/similarities/docsim.py index 6016e32f49..b315304771 100755 --- a/gensim/similarities/docsim.py +++ b/gensim/similarities/docsim.py @@ -563,6 +563,108 @@ def __str__(self): return "%s<%i docs, %i features>" % (self.__class__.__name__, len(self), self.index.shape[1]) +class SoftCosineSimilarity(interfaces.SimilarityABC): + """Document similarity (like MatrixSimilarity) that uses Soft Cosine Measure as a similarity measure.""" + + def __init__(self, corpus, similarity_matrix, num_best=None, chunksize=256): + """ + + Parameters + ---------- + corpus: iterable of list of (int, float) + A list of documents in the BoW format. 
+ similarity_matrix : :class:`scipy.sparse.csc_matrix` + A term similarity matrix, typically produced by + :meth:`~gensim.models.keyedvectors.WordEmbeddingsKeyedVectors.similarity_matrix`. + num_best : int, optional + The number of results to retrieve for a query, if None - return similarities with all elements from corpus. + chunksize: int, optional + Size of one corpus chunk. + + + See Also + -------- + :meth:`gensim.models.keyedvectors.WordEmbeddingsKeyedVectors.similarity_matrix` + A term similarity matrix produced from term embeddings. + :func:`gensim.matutils.softcossim` + The Soft Cosine Measure. + + Examples + -------- + >>> from gensim.corpora import Dictionary + >>> import gensim.downloader as api + >>> from gensim.models import Word2Vec + >>> from gensim.similarities import SoftCosineSimilarity + >>> from gensim.utils import simple_preprocess + >>> + >>> # Prepare the model + >>> corpus = api.load("text8") + >>> model = Word2Vec(corpus, workers=3, size=100) + >>> dictionary = Dictionary(corpus) + >>> bow_corpus = [dictionary.doc2bow(document) for document in corpus] + >>> similarity_matrix = model.wv.similarity_matrix(dictionary) + >>> index = SoftCosineSimilarity(bow_corpus, similarity_matrix, num_best=10) + >>> + >>> # Make a query. + >>> query = 'Yummy! Great view of the Bellagio Fountain show.' + >>> # calculate similarity between query and each doc from bow_corpus + >>> sims = index[dictionary.doc2bow(simple_preprocess(query))] + + See `Tutorial Notebook + `_ + for more examples. + + """ + self.corpus = corpus + self.similarity_matrix = similarity_matrix + self.num_best = num_best + self.chunksize = chunksize + + # Normalization of features is undesirable, since soft cosine similarity requires special + # normalization using the similarity matrix. Therefore, we would just be normalizing twice, + # increasing the numerical error. + self.normalize = False + + # index is simply an array from 0 to size of corpus. + self.index = numpy.arange(len(corpus)) + + def __len__(self): + return len(self.corpus) + + def get_similarities(self, query): + """ + **Do not use this function directly; use the self[query] syntax instead.** + """ + if isinstance(query, numpy.ndarray): + # Convert document indexes to actual documents. + query = [self.corpus[i] for i in query] + + if not query or not isinstance(query[0], list): + query = [query] + + n_queries = len(query) + result = [] + for qidx in range(n_queries): + # Compute similarity for each query. + qresult = [matutils.softcossim(document, query[qidx], self.similarity_matrix) + for document in self.corpus] + qresult = numpy.array(qresult) + + # Append single query result to list of all results. + result.append(qresult) + + if len(result) == 1: + # Only one query. + result = result[0] + else: + result = numpy.array(result) + + return result + + def __str__(self): + return "%s<%i docs, %i features>" % (self.__class__.__name__, len(self), self.similarity_matrix.shape[0]) + + class WmdSimilarity(interfaces.SimilarityABC): """ Document similarity (like MatrixSimilarity) that uses the negative of WMD @@ -605,7 +707,7 @@ def __init__(self, corpus, w2v_model, num_best=None, normalize_w2v_and_replace=T self.normalize = False # index is simply an array from 0 to size of corpus. - self.index = numpy.array(range(len(corpus))) + self.index = numpy.arange(len(corpus)) if normalize_w2v_and_replace: # Normalize vectors in word2vec class to length 1. @@ -622,7 +724,7 @@ def get_similarities(self, query): # Convert document indexes to actual documents. 
query = [self.corpus[i] for i in query] - if not isinstance(query[0], list): + if not query or not isinstance(query[0], list): query = [query] n_queries = len(query) diff --git a/gensim/test/test_keyedvectors.py b/gensim/test/test_keyedvectors.py index 178dfcc91e..f013efa439 100644 --- a/gensim/test/test_keyedvectors.py +++ b/gensim/test/test_keyedvectors.py @@ -14,6 +14,7 @@ import numpy as np +from gensim.corpora import Dictionary from gensim.models import KeyedVectors as EuclideanKeyedVectors from gensim.test.utils import datapath @@ -26,6 +27,33 @@ def setUp(self): self.vectors = EuclideanKeyedVectors.load_word2vec_format( datapath('euclidean_vectors.bin'), binary=True, datatype=np.float64) + def test_similarity_matrix(self): + """Test similarity_matrix returns expected results.""" + + corpus = [["government", "denied", "holiday"], ["holiday", "slowing", "hollingworth"]] + dictionary = Dictionary(corpus) + corpus = [dictionary.doc2bow(document) for document in corpus] + + # checking symmetry and the existence of ones on the diagonal + similarity_matrix = self.vectors.similarity_matrix(dictionary).todense() + self.assertTrue((similarity_matrix.T == similarity_matrix).all()) + self.assertTrue((np.diag(similarity_matrix) == 1.0).all()) + + # checking that thresholding works as expected + similarity_matrix = self.vectors.similarity_matrix(dictionary, threshold=0.45).todense() + self.assertEqual(18, np.sum(similarity_matrix == 0)) + + # checking that exponent works as expected + similarity_matrix = self.vectors.similarity_matrix(dictionary, exponent=1.0).todense() + self.assertAlmostEqual(9.5788956, np.sum(similarity_matrix)) + + # checking that nonzero_limit works as expected + similarity_matrix = self.vectors.similarity_matrix(dictionary, nonzero_limit=4).todense() + self.assertEqual(4, np.sum(similarity_matrix == 0)) + + similarity_matrix = self.vectors.similarity_matrix(dictionary, nonzero_limit=3).todense() + self.assertEqual(20, np.sum(similarity_matrix == 0)) + def test_most_similar(self): """Test most_similar returns expected results.""" expected = [ diff --git a/gensim/test/test_similarities.py b/gensim/test/test_similarities.py index 58cbf2f734..d2d438db74 100644 --- a/gensim/test/test_similarities.py +++ b/gensim/test/test_similarities.py @@ -16,6 +16,7 @@ import numpy import scipy +from gensim.corpora import Dictionary from gensim.models import word2vec from gensim.models import doc2vec from gensim.models import KeyedVectors @@ -39,6 +40,10 @@ class _TestSimilarityABC(object): Base class for SparseMatrixSimilarity and MatrixSimilarity unit tests.
""" + def factoryMethod(self): + """Creates a SimilarityABC instance.""" + return self.cls(corpus, num_features=len(dictionary)) + def testFull(self, num_best=None, shardsize=100): if self.cls == similarities.Similarity: index = self.cls(None, corpus, num_features=len(dictionary), shardsize=shardsize) @@ -103,6 +108,15 @@ def test_scipy2scipy_clipped(self): self.assertTrue(scipy.sparse.issparse(matrix_scipy_clipped)) self.assertTrue([matutils.scipy2sparse(x) for x in matrix_scipy_clipped], [expected] * 3) + def testEmptyQuery(self): + index = self.factoryMethod() + query = [] + try: + sims = index[query] + self.assertTrue(sims is not None) + except IndexError: + self.assertTrue(False) + def testChunking(self): if self.cls == similarities.Similarity: index = self.cls(None, corpus, num_features=len(dictionary), shardsize=5) @@ -155,12 +169,7 @@ def testPersistency(self): return fname = get_tmpfile('gensim_similarities.tst.pkl') - if self.cls == similarities.Similarity: - index = self.cls(None, corpus, num_features=len(dictionary), shardsize=5) - elif self.cls == similarities.WmdSimilarity: - index = self.cls(texts, self.w2v_model) - else: - index = self.cls(corpus, num_features=len(dictionary)) + index = self.factoryMethod() index.save(fname) index2 = self.cls.load(fname) if self.cls == similarities.Similarity: @@ -180,12 +189,7 @@ def testPersistencyCompressed(self): return fname = get_tmpfile('gensim_similarities.tst.pkl.gz') - if self.cls == similarities.Similarity: - index = self.cls(None, corpus, num_features=len(dictionary), shardsize=5) - elif self.cls == similarities.WmdSimilarity: - index = self.cls(texts, self.w2v_model) - else: - index = self.cls(corpus, num_features=len(dictionary)) + index = self.factoryMethod() index.save(fname) index2 = self.cls.load(fname) if self.cls == similarities.Similarity: @@ -205,12 +209,7 @@ def testLarge(self): return fname = get_tmpfile('gensim_similarities.tst.pkl') - if self.cls == similarities.Similarity: - index = self.cls(None, corpus, num_features=len(dictionary), shardsize=5) - elif self.cls == similarities.WmdSimilarity: - index = self.cls(texts, self.w2v_model) - else: - index = self.cls(corpus, num_features=len(dictionary)) + index = self.factoryMethod() # store all arrays separately index.save(fname, sep_limit=0) @@ -232,12 +231,7 @@ def testLargeCompressed(self): return fname = get_tmpfile('gensim_similarities.tst.pkl.gz') - if self.cls == similarities.Similarity: - index = self.cls(None, corpus, num_features=len(dictionary), shardsize=5) - elif self.cls == similarities.WmdSimilarity: - index = self.cls(texts, self.w2v_model) - else: - index = self.cls(corpus, num_features=len(dictionary)) + index = self.factoryMethod() # store all arrays separately index.save(fname, sep_limit=0) @@ -259,12 +253,7 @@ def testMmap(self): return fname = get_tmpfile('gensim_similarities.tst.pkl') - if self.cls == similarities.Similarity: - index = self.cls(None, corpus, num_features=len(dictionary), shardsize=5) - elif self.cls == similarities.WmdSimilarity: - index = self.cls(texts, self.w2v_model) - else: - index = self.cls(corpus, num_features=len(dictionary)) + index = self.factoryMethod() # store all arrays separately index.save(fname, sep_limit=0) @@ -287,12 +276,7 @@ def testMmapCompressed(self): return fname = get_tmpfile('gensim_similarities.tst.pkl.gz') - if self.cls == similarities.Similarity: - index = self.cls(None, corpus, num_features=len(dictionary), shardsize=5) - elif self.cls == similarities.WmdSimilarity: - index = self.cls(texts, 
self.w2v_model) - else: - index = self.cls(corpus, num_features=len(dictionary)) + index = self.factoryMethod() # store all arrays separately index.save(fname, sep_limit=0) @@ -310,6 +294,10 @@ def setUp(self): self.cls = similarities.WmdSimilarity self.w2v_model = Word2Vec(texts, min_count=1) + def factoryMethod(self): + # Override factoryMethod. + return self.cls(texts, self.w2v_model) + def testFull(self, num_best=None): # Override testFull. @@ -382,6 +370,80 @@ def testIter(self): self.assertTrue(numpy.alltrue(sims <= 1.0)) +class TestSoftCosineSimilarity(unittest.TestCase, _TestSimilarityABC): + def setUp(self): + self.cls = similarities.SoftCosineSimilarity + self.dictionary = Dictionary(texts) + self.corpus = [dictionary.doc2bow(document) for document in texts] + similarity_matrix = scipy.sparse.identity(12, format="lil") + similarity_matrix[dictionary.token2id["user"], dictionary.token2id["human"]] = 0.5 + similarity_matrix[dictionary.token2id["human"], dictionary.token2id["user"]] = 0.5 + self.similarity_matrix = similarity_matrix.tocsc() + + def factoryMethod(self): + # Override factoryMethod. + return self.cls(self.corpus, self.similarity_matrix) + + def testFull(self, num_best=None): + # Override testFull. + + index = self.cls(self.corpus, self.similarity_matrix, num_best=num_best) + query = self.dictionary.doc2bow(texts[0]) + sims = index[query] + + if num_best is not None: + # Sparse array. + for i, sim in sims: + self.assertTrue(numpy.alltrue(sim <= 1.0)) + self.assertTrue(numpy.alltrue(sim >= 0.0)) + else: + self.assertTrue(sims[0] == 1.0) # Similarity of a document with itself is 1.0. + self.assertTrue(numpy.alltrue(sims[1:] >= 0.0)) + self.assertTrue(numpy.alltrue(sims[1:] < 1.0)) + expected = 2.1889350195476758 + self.assertAlmostEqual(expected, numpy.sum(sims)) + + def testNonIncreasing(self): + """ Check that similarities are non-increasing when `num_best` is not `None`.""" + # NOTE: this could be implemented for other similarities as well (i.e. in _TestSimilarityABC). + + index = self.cls(self.corpus, self.similarity_matrix, num_best=5) + query = self.dictionary.doc2bow(texts[0]) + sims = index[query] + sims2 = numpy.asarray(sims)[:, 1] # Just the similarities themselves. + + # The difference of adjacent elements should be negative. + cond = sum(numpy.diff(sims2) < 0) == len(sims2) - 1 + self.assertTrue(cond) + + def testChunking(self): + # Override testChunking. + + index = self.cls(self.corpus, self.similarity_matrix) + query = [self.dictionary.doc2bow(document) for document in texts[:3]] + sims = index[query] + + for i in range(3): + self.assertTrue(numpy.alltrue(sims[i, i] == 1.0)) # Similarity of a document with itself is 1.0. + + # test the same thing but with num_best + index.num_best = 5 + sims = index[query] + for i, chunk in enumerate(sims): + expected = i + self.assertEquals(expected, chunk[0][0]) + expected = 1.0 + self.assertEquals(expected, chunk[0][1]) + + def testIter(self): + # Override testIter. + + index = self.cls(self.corpus, self.similarity_matrix) + for sims in index: + self.assertTrue(numpy.alltrue(sims >= 0.0)) + self.assertTrue(numpy.alltrue(sims <= 1.0)) + + class TestSparseMatrixSimilarity(unittest.TestCase, _TestSimilarityABC): def setUp(self): self.cls = similarities.SparseMatrixSimilarity @@ -419,6 +481,10 @@ class TestSimilarity(unittest.TestCase, _TestSimilarityABC): def setUp(self): self.cls = similarities.Similarity + def factoryMethod(self): + # Override factoryMethod. 
+ return self.cls(None, corpus, num_features=len(dictionary), shardsize=5) + def testSharding(self): for num_best in [None, 0, 1, 9, 1000]: for shardsize in [1, 2, 9, 1000]: diff --git a/gensim/test/test_similarity_metrics.py b/gensim/test/test_similarity_metrics.py index a9420000fe..3e6c59b509 100644 --- a/gensim/test/test_similarity_metrics.py +++ b/gensim/test/test_similarity_metrics.py @@ -13,7 +13,7 @@ import unittest from gensim import matutils -from scipy.sparse import csr_matrix +from scipy.sparse import csr_matrix, csc_matrix import numpy as np import math from gensim.corpora.mmcorpus import MmCorpus @@ -81,7 +81,6 @@ def setUp(self): self.model = self.class_(common_corpus, id2word=common_dictionary, num_topics=2, passes=100) def test_inputs(self): - # checking empty inputs vec_1 = [] vec_2 = [] @@ -104,7 +103,6 @@ def test_inputs(self): self.assertEqual(expected, result) def test_distributions(self): - # checking different length bag of words as inputs vec_1 = [(2, 0.1), (3, 0.4), (4, 0.1), (5, 0.1), (1, 0.1), (7, 0.2)] vec_2 = [(1, 0.1), (3, 0.8), (4, 0.1)] @@ -175,7 +173,6 @@ def test_inputs(self): self.assertEqual(expected, result) def test_distributions(self): - # checking bag of words as inputs vec_1 = [(2, 0.1), (3, 0.4), (4, 0.1), (5, 0.1), (1, 0.1), (7, 0.2)] vec_2 = [(1, 0.1), (3, 0.8), (4, 0.1)] @@ -215,14 +212,12 @@ def test_distributions(self): class TestJaccard(unittest.TestCase): def test_inputs(self): - # all empty inputs will give a divide by zero exception vec_1 = [] vec_2 = [] self.assertRaises(ZeroDivisionError, matutils.jaccard, vec_1, vec_2) def test_distributions(self): - # checking bag of words as inputs vec_1 = [(2, 1), (3, 4), (4, 1), (5, 1), (1, 1), (7, 2)] vec_2 = [(1, 1), (3, 8), (4, 1)] @@ -245,6 +240,36 @@ def test_distributions(self): self.assertAlmostEqual(expected, result) +class TestSoftCosineSimilarity(unittest.TestCase): + def test_inputs(self): + # checking empty inputs + vec_1 = [] + vec_2 = [] + similarity_matrix = csc_matrix((0, 0)) + result = matutils.softcossim(vec_1, vec_2, similarity_matrix) + expected = 0.0 + self.assertEqual(expected, result) + + # checking CSR term similarity matrix format + similarity_matrix = csr_matrix((0, 0)) + result = matutils.softcossim(vec_1, vec_2, similarity_matrix) + expected = 0.0 + self.assertEqual(expected, result) + + # checking unknown term similarity matrix format + with self.assertRaises(ValueError): + matutils.softcossim(vec_1, vec_2, np.matrix([])) + + def test_distributions(self): + # checking bag of words as inputs + vec_1 = [(0, 1.0), (2, 1.0)] # hello world + vec_2 = [(1, 1.0), (2, 1.0)] # hi world + similarity_matrix = csc_matrix([[1, 0.5, 0], [0.5, 1, 0], [0, 0, 1]]) + result = matutils.softcossim(vec_1, vec_2, similarity_matrix) + expected = 0.75 + self.assertAlmostEqual(expected, result) + + if __name__ == '__main__': logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.DEBUG) unittest.main()
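For reviewers who want to exercise the new API end to end, here is a minimal sketch on a toy corpus. The variable names are illustrative and the tiny Word2Vec model only exists to produce some embeddings; any pre-trained `KeyedVectors` would work as well.

```python
from gensim.corpora import Dictionary
from gensim.matutils import softcossim
from gensim.models import Word2Vec
from gensim.similarities import SoftCosineSimilarity

# Toy corpus; in practice, use a real corpus and pre-trained embeddings.
texts = [
    ["obama", "speaks", "media", "illinois"],
    ["president", "greets", "press", "chicago"],
    ["oranges", "favorite", "fruit"],
]
dictionary = Dictionary(texts)
bow_corpus = [dictionary.doc2bow(text) for text in texts]

# Train tiny embeddings for illustration and build the sparse term similarity matrix.
model = Word2Vec(texts, min_count=1, size=20, seed=1)
similarity_matrix = model.wv.similarity_matrix(dictionary)

# One-off Soft Cosine Measure between two documents.
print(softcossim(bow_corpus[0], bow_corpus[1], similarity_matrix))

# Corpus-based similarity queries against an index.
index = SoftCosineSimilarity(bow_corpus, similarity_matrix, num_best=3)
print(index[dictionary.doc2bow(["president", "speaks", "media"])])
```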