From dd84a5a514700365d9aa4a1ea988107372515f33 Mon Sep 17 00:00:00 2001 From: Tommaso Teofili Date: Wed, 29 Apr 2020 13:29:48 +0200 Subject: [PATCH] Allow multiple vectors with the same id in ANN (#1126) --- .../ann/ApproximateNearestNeighborEval.java | 82 ++++++++++--------- .../ann/ApproximateNearestNeighborSearch.java | 27 +++--- .../java/io/anserini/ann/IndexVectors.java | 54 ++++++------ .../SimpleNearestNeighborSearcherTest.java | 4 +- src/test/resources/mini-word-vectors.txt | 3 +- 5 files changed, 94 insertions(+), 76 deletions(-) diff --git a/src/main/java/io/anserini/ann/ApproximateNearestNeighborEval.java b/src/main/java/io/anserini/ann/ApproximateNearestNeighborEval.java index ccc72ff00a..96b500ef82 100644 --- a/src/main/java/io/anserini/ann/ApproximateNearestNeighborEval.java +++ b/src/main/java/io/anserini/ann/ApproximateNearestNeighborEval.java @@ -46,6 +46,7 @@ import java.util.Collection; import java.util.HashSet; import java.util.LinkedList; +import java.util.List; import java.util.Map; import java.util.Set; import java.util.TreeSet; @@ -131,7 +132,7 @@ public static void main(String[] args) throws Exception { System.out.println(String.format("Loading model %s", indexArgs.input)); - Map wordVectors = IndexVectors.readGloVe(indexArgs.input); + Map> wordVectors = IndexVectors.readGloVe(indexArgs.input); Path indexDir = indexArgs.path; if (!Files.exists(indexDir)) { @@ -159,39 +160,41 @@ public static void main(String[] args) throws Exception { if (wordVectors.containsKey(word)) { Set truth = nearestVector(wordVectors, word, indexArgs.topN); try { - float[] vector = wordVectors.get(word); - StringBuilder sb = new StringBuilder(); - for (double fv : vector) { - if (sb.length() > 0) { - sb.append(' '); + List vectors = wordVectors.get(word); + for (float[] vector : vectors) { + StringBuilder sb = new StringBuilder(); + for (double fv : vector) { + if (sb.length() > 0) { + sb.append(' '); + } + sb.append(fv); } - sb.append(fv); - } - String fvString = sb.toString(); + String fvString = sb.toString(); - CommonTermsQuery simQuery = new CommonTermsQuery(SHOULD, SHOULD, indexArgs.cutoff); - if (indexArgs.msm > 0) { - simQuery.setLowFreqMinimumNumberShouldMatch(indexArgs.msm); - } - for (String token : AnalyzerUtils.analyze(vectorAnalyzer, fvString)) { - simQuery.add(new Term(IndexVectors.FIELD_VECTOR, token)); - } + CommonTermsQuery simQuery = new CommonTermsQuery(SHOULD, SHOULD, indexArgs.cutoff); + if (indexArgs.msm > 0) { + simQuery.setLowFreqMinimumNumberShouldMatch(indexArgs.msm); + } + for (String token : AnalyzerUtils.analyze(vectorAnalyzer, fvString)) { + simQuery.add(new Term(IndexVectors.FIELD_VECTOR, token)); + } - long start = System.currentTimeMillis(); - TopScoreDocCollector results = TopScoreDocCollector.create(indexArgs.depth, Integer.MAX_VALUE); - searcher.search(simQuery, results); - time += System.currentTimeMillis() - start; + long start = System.currentTimeMillis(); + TopScoreDocCollector results = TopScoreDocCollector.create(indexArgs.depth, Integer.MAX_VALUE); + searcher.search(simQuery, results); + time += System.currentTimeMillis() - start; - Set observations = new HashSet<>(); - for (ScoreDoc sd : results.topDocs().scoreDocs) { - Document document = reader.document(sd.doc); - String wordValue = document.get(IndexVectors.FIELD_ID); - observations.add(wordValue); + Set observations = new HashSet<>(); + for (ScoreDoc sd : results.topDocs().scoreDocs) { + Document document = reader.document(sd.doc); + String wordValue = document.get(IndexVectors.FIELD_ID); + observations.add(wordValue); + } + double intersection = Sets.intersection(truth, observations).size(); + double localRecall = intersection / (double) truth.size(); + recall += localRecall; + queryCount++; } - double intersection = Sets.intersection(truth, observations).size(); - double localRecall = intersection / (double) truth.size(); - recall += localRecall; - queryCount++; } catch (IOException e) { System.err.println("search for '" + word + "' failed " + e.getLocalizedMessage()); } @@ -218,18 +221,21 @@ public static void main(String[] args) throws Exception { * @param topN the number of similar word vectors to output * @return the {@code topN} similar words of the input word */ - private static Set nearestVector(Map vectors, String word, int topN) { + private static Set nearestVector(Map> vectors, String word, int topN) { Set intermediate = new TreeSet<>(); - float[] input = vectors.get(word); + List inputs = vectors.get(word); String separateToken = "__"; - for (Map.Entry entry : vectors.entrySet()) { - float sim = 0; - float[] value = entry.getValue(); - for (int i = 0; i < value.length; i++) { - sim += value[i] * input[i]; + for (Map.Entry> entry : vectors.entrySet()) { + for (float[] value : entry.getValue()) { + for (float[] input : inputs) { + float sim = 0; + for (int i = 0; i < value.length; i++) { + sim += value[i] * input[i]; + } + // store the words, sorted by decreasing distance using natural order (in the $dist__$word format) + intermediate.add((1 - sim) + separateToken + entry.getKey()); + } } - // store the words, sorted by decreasing distance using natural order (in the $dist__$word format) - intermediate.add((1 - sim) + separateToken + entry.getKey()); } Set result = new HashSet<>(); int i = 0; diff --git a/src/main/java/io/anserini/ann/ApproximateNearestNeighborSearch.java b/src/main/java/io/anserini/ann/ApproximateNearestNeighborSearch.java index e41377967e..1aa33cddde 100644 --- a/src/main/java/io/anserini/ann/ApproximateNearestNeighborSearch.java +++ b/src/main/java/io/anserini/ann/ApproximateNearestNeighborSearch.java @@ -43,6 +43,7 @@ import java.nio.file.Path; import java.util.Collection; import java.util.LinkedList; +import java.util.List; import java.util.Map; import static org.apache.lucene.search.BooleanClause.Occur.SHOULD; @@ -140,32 +141,34 @@ public static void main(String[] args) throws Exception { searcher.setSimilarity(new ClassicSimilarity()); } - Collection vectors = new LinkedList<>(); + Collection vectorStrings = new LinkedList<>(); if (indexArgs.stored) { TopDocs topDocs = searcher.search(new TermQuery(new Term(IndexVectors.FIELD_ID, indexArgs.word)), indexArgs.depth); for (ScoreDoc scoreDoc : topDocs.scoreDocs) { - vectors.add(reader.document(scoreDoc.doc).get(IndexVectors.FIELD_VECTOR)); + vectorStrings.add(reader.document(scoreDoc.doc).get(IndexVectors.FIELD_VECTOR)); } } else { System.out.println(String.format("Loading model %s", indexArgs.input)); - Map wordVectors = IndexVectors.readGloVe(indexArgs.input); + Map> wordVectors = IndexVectors.readGloVe(indexArgs.input); if (wordVectors.containsKey(indexArgs.word)) { - float[] vector = wordVectors.get(indexArgs.word); - StringBuilder sb = new StringBuilder(); - for (double fv : vector) { - if (sb.length() > 0) { - sb.append(' '); + List vectors = wordVectors.get(indexArgs.word); + for (float[] vector : vectors) { + StringBuilder sb = new StringBuilder(); + for (double fv : vector) { + if (sb.length() > 0) { + sb.append(' '); + } + sb.append(fv); } - sb.append(fv); + String vectorString = sb.toString(); + vectorStrings.add(vectorString); } - String vectorString = sb.toString(); - vectors.add(vectorString); } } - for (String vectorString : vectors) { + for (String vectorString : vectorStrings) { float msm = indexArgs.msm; float cutoff = indexArgs.cutoff; CommonTermsQuery simQuery = new CommonTermsQuery(SHOULD, SHOULD, cutoff); diff --git a/src/main/java/io/anserini/ann/IndexVectors.java b/src/main/java/io/anserini/ann/IndexVectors.java index 13834805a7..e46b6a4177 100644 --- a/src/main/java/io/anserini/ann/IndexVectors.java +++ b/src/main/java/io/anserini/ann/IndexVectors.java @@ -44,6 +44,8 @@ import java.nio.file.Files; import java.nio.file.Path; import java.util.HashMap; +import java.util.LinkedList; +import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; @@ -116,7 +118,7 @@ public static void main(String[] args) throws Exception { final long start = System.nanoTime(); System.out.println(String.format("Loading model %s", indexArgs.input)); - Map vectors = readGloVe(indexArgs.input); + Map> vectors = readGloVe(indexArgs.input); Path indexDir = indexArgs.path; if (!Files.exists(indexDir)) { @@ -134,27 +136,27 @@ public static void main(String[] args) throws Exception { IndexWriter indexWriter = new IndexWriter(d, conf); final AtomicInteger cnt = new AtomicInteger(); - for (Map.Entry entry : vectors.entrySet()) { - Document doc = new Document(); - - doc.add(new StringField(FIELD_ID, entry.getKey(), Field.Store.YES)); - float[] vector = entry.getValue(); - StringBuilder sb = new StringBuilder(); - for (double fv : vector) { - if (sb.length() > 0) { - sb.append(' '); + for (Map.Entry> entry : vectors.entrySet()) { + for (float[] vector: entry.getValue()) { + Document doc = new Document(); + doc.add(new StringField(FIELD_ID, entry.getKey(), Field.Store.YES)); + StringBuilder sb = new StringBuilder(); + for (double fv : vector) { + if (sb.length() > 0) { + sb.append(' '); + } + sb.append(fv); } - sb.append(fv); - } - doc.add(new TextField(FIELD_VECTOR, sb.toString(), indexArgs.stored ? Field.Store.YES : Field.Store.NO)); - try { - indexWriter.addDocument(doc); - int cur = cnt.incrementAndGet(); - if (cur % 100000 == 0) { - System.out.println(String.format("%s docs added", cnt)); + doc.add(new TextField(FIELD_VECTOR, sb.toString(), indexArgs.stored ? Field.Store.YES : Field.Store.NO)); + try { + indexWriter.addDocument(doc); + int cur = cnt.incrementAndGet(); + if (cur % 100000 == 0) { + System.out.println(String.format("%s docs added", cnt)); + } + } catch (IOException e) { + System.err.println("Error while indexing: " + e.getLocalizedMessage()); } - } catch (IOException e) { - System.err.println("Error while indexing: " + e.getLocalizedMessage()); } } @@ -171,8 +173,8 @@ public static void main(String[] args) throws Exception { DurationFormatUtils.formatDuration(durationMillis, "HH:mm:ss"))); } - static Map readGloVe(File input) throws IOException { - Map vectors = new HashMap<>(); + static Map> readGloVe(File input) throws IOException { + Map> vectors = new HashMap<>(); for (String line : IOUtils.readLines(new FileReader(input))) { String[] s = line.split("\\s+"); if (s.length > 2) { @@ -188,7 +190,13 @@ static Map readGloVe(File input) throws IOException { for (int i = 0; i < vector.length; i++) { vector[i] = vector[i] / norm; } - vectors.put(key, vector); + if (vectors.containsKey(key)) { + List floats = new LinkedList<>(vectors.get(key)); + floats.add(vector); + vectors.put(key, floats); + } else { + vectors.put(key, List.of(vector)); + } } } return vectors; diff --git a/src/test/java/io/anserini/search/SimpleNearestNeighborSearcherTest.java b/src/test/java/io/anserini/search/SimpleNearestNeighborSearcherTest.java index 6494b18c66..d17f084cac 100644 --- a/src/test/java/io/anserini/search/SimpleNearestNeighborSearcherTest.java +++ b/src/test/java/io/anserini/search/SimpleNearestNeighborSearcherTest.java @@ -40,7 +40,7 @@ public void testSearchingLL() throws Exception { SimpleNearestNeighborSearcher simpleNearestNeighborSearcher = new SimpleNearestNeighborSearcher(idxPath, "lexlsh"); SimpleNearestNeighborSearcher.Result[] results = simpleNearestNeighborSearcher.search("text", 2); assertNotNull(results); - assertEquals(1, results.length); + assertEquals(2, results.length); } @Test @@ -62,6 +62,6 @@ public void testMultiSearchingLL() throws Exception { SimpleNearestNeighborSearcher.Result[][] results = simpleNearestNeighborSearcher.multisearch("text", 2, 2); assertNotNull(results); assertEquals(1, results.length); - assertEquals(1, results[0].length); + assertEquals(2, results[0].length); } } \ No newline at end of file diff --git a/src/test/resources/mini-word-vectors.txt b/src/test/resources/mini-word-vectors.txt index 0b13c3f14c..cec4c9739a 100644 --- a/src/test/resources/mini-word-vectors.txt +++ b/src/test/resources/mini-word-vectors.txt @@ -1,3 +1,4 @@ simple 0.3 0.2 0.2 0.9 foo 0.1 0.2 0.4 0.4 -text 0.2 0.2 0.1 0.9 \ No newline at end of file +text 0.2 0.2 0.1 0.9 +simple 0.3 0.2 0.1 0.9 \ No newline at end of file