Skip to content

Commit

Permalink
Allow multiple vectors with the same id in ANN (#1126)
Browse files Browse the repository at this point in the history
  • Loading branch information
tteofili committed Apr 29, 2020
1 parent 105ad9c commit dd84a5a
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 76 deletions.
82 changes: 44 additions & 38 deletions src/main/java/io/anserini/ann/ApproximateNearestNeighborEval.java
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import java.util.Collection;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;
Expand Down Expand Up @@ -131,7 +132,7 @@ public static void main(String[] args) throws Exception {

System.out.println(String.format("Loading model %s", indexArgs.input));

Map<String, float[]> wordVectors = IndexVectors.readGloVe(indexArgs.input);
Map<String, List<float[]>> wordVectors = IndexVectors.readGloVe(indexArgs.input);

Path indexDir = indexArgs.path;
if (!Files.exists(indexDir)) {
Expand Down Expand Up @@ -159,39 +160,41 @@ public static void main(String[] args) throws Exception {
if (wordVectors.containsKey(word)) {
Set<String> truth = nearestVector(wordVectors, word, indexArgs.topN);
try {
float[] vector = wordVectors.get(word);
StringBuilder sb = new StringBuilder();
for (double fv : vector) {
if (sb.length() > 0) {
sb.append(' ');
List<float[]> vectors = wordVectors.get(word);
for (float[] vector : vectors) {
StringBuilder sb = new StringBuilder();
for (double fv : vector) {
if (sb.length() > 0) {
sb.append(' ');
}
sb.append(fv);
}
sb.append(fv);
}
String fvString = sb.toString();
String fvString = sb.toString();

CommonTermsQuery simQuery = new CommonTermsQuery(SHOULD, SHOULD, indexArgs.cutoff);
if (indexArgs.msm > 0) {
simQuery.setLowFreqMinimumNumberShouldMatch(indexArgs.msm);
}
for (String token : AnalyzerUtils.analyze(vectorAnalyzer, fvString)) {
simQuery.add(new Term(IndexVectors.FIELD_VECTOR, token));
}
CommonTermsQuery simQuery = new CommonTermsQuery(SHOULD, SHOULD, indexArgs.cutoff);
if (indexArgs.msm > 0) {
simQuery.setLowFreqMinimumNumberShouldMatch(indexArgs.msm);
}
for (String token : AnalyzerUtils.analyze(vectorAnalyzer, fvString)) {
simQuery.add(new Term(IndexVectors.FIELD_VECTOR, token));
}

long start = System.currentTimeMillis();
TopScoreDocCollector results = TopScoreDocCollector.create(indexArgs.depth, Integer.MAX_VALUE);
searcher.search(simQuery, results);
time += System.currentTimeMillis() - start;
long start = System.currentTimeMillis();
TopScoreDocCollector results = TopScoreDocCollector.create(indexArgs.depth, Integer.MAX_VALUE);
searcher.search(simQuery, results);
time += System.currentTimeMillis() - start;

Set<String> observations = new HashSet<>();
for (ScoreDoc sd : results.topDocs().scoreDocs) {
Document document = reader.document(sd.doc);
String wordValue = document.get(IndexVectors.FIELD_ID);
observations.add(wordValue);
Set<String> observations = new HashSet<>();
for (ScoreDoc sd : results.topDocs().scoreDocs) {
Document document = reader.document(sd.doc);
String wordValue = document.get(IndexVectors.FIELD_ID);
observations.add(wordValue);
}
double intersection = Sets.intersection(truth, observations).size();
double localRecall = intersection / (double) truth.size();
recall += localRecall;
queryCount++;
}
double intersection = Sets.intersection(truth, observations).size();
double localRecall = intersection / (double) truth.size();
recall += localRecall;
queryCount++;
} catch (IOException e) {
System.err.println("search for '" + word + "' failed " + e.getLocalizedMessage());
}
Expand All @@ -218,18 +221,21 @@ public static void main(String[] args) throws Exception {
* @param topN the number of similar word vectors to output
* @return the {@code topN} similar words of the input word
*/
private static Set<String> nearestVector(Map<String, float[]> vectors, String word, int topN) {
private static Set<String> nearestVector(Map<String, List<float[]>> vectors, String word, int topN) {
Set<String> intermediate = new TreeSet<>();
float[] input = vectors.get(word);
List<float[]> inputs = vectors.get(word);
String separateToken = "__";
for (Map.Entry<String, float[]> entry : vectors.entrySet()) {
float sim = 0;
float[] value = entry.getValue();
for (int i = 0; i < value.length; i++) {
sim += value[i] * input[i];
for (Map.Entry<String, List<float[]>> entry : vectors.entrySet()) {
for (float[] value : entry.getValue()) {
for (float[] input : inputs) {
float sim = 0;
for (int i = 0; i < value.length; i++) {
sim += value[i] * input[i];
}
// store the words, sorted by decreasing distance using natural order (in the $dist__$word format)
intermediate.add((1 - sim) + separateToken + entry.getKey());
}
}
// store the words, sorted by decreasing distance using natural order (in the $dist__$word format)
intermediate.add((1 - sim) + separateToken + entry.getKey());
}
Set<String> result = new HashSet<>();
int i = 0;
Expand Down
27 changes: 15 additions & 12 deletions src/main/java/io/anserini/ann/ApproximateNearestNeighborSearch.java
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
import java.nio.file.Path;
import java.util.Collection;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;

import static org.apache.lucene.search.BooleanClause.Occur.SHOULD;
Expand Down Expand Up @@ -140,32 +141,34 @@ public static void main(String[] args) throws Exception {
searcher.setSimilarity(new ClassicSimilarity());
}

Collection<String> vectors = new LinkedList<>();
Collection<String> vectorStrings = new LinkedList<>();
if (indexArgs.stored) {
TopDocs topDocs = searcher.search(new TermQuery(new Term(IndexVectors.FIELD_ID, indexArgs.word)), indexArgs.depth);
for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
vectors.add(reader.document(scoreDoc.doc).get(IndexVectors.FIELD_VECTOR));
vectorStrings.add(reader.document(scoreDoc.doc).get(IndexVectors.FIELD_VECTOR));
}
} else {
System.out.println(String.format("Loading model %s", indexArgs.input));

Map<String, float[]> wordVectors = IndexVectors.readGloVe(indexArgs.input);
Map<String, List<float[]>> wordVectors = IndexVectors.readGloVe(indexArgs.input);

if (wordVectors.containsKey(indexArgs.word)) {
float[] vector = wordVectors.get(indexArgs.word);
StringBuilder sb = new StringBuilder();
for (double fv : vector) {
if (sb.length() > 0) {
sb.append(' ');
List<float[]> vectors = wordVectors.get(indexArgs.word);
for (float[] vector : vectors) {
StringBuilder sb = new StringBuilder();
for (double fv : vector) {
if (sb.length() > 0) {
sb.append(' ');
}
sb.append(fv);
}
sb.append(fv);
String vectorString = sb.toString();
vectorStrings.add(vectorString);
}
String vectorString = sb.toString();
vectors.add(vectorString);
}
}

for (String vectorString : vectors) {
for (String vectorString : vectorStrings) {
float msm = indexArgs.msm;
float cutoff = indexArgs.cutoff;
CommonTermsQuery simQuery = new CommonTermsQuery(SHOULD, SHOULD, cutoff);
Expand Down
54 changes: 31 additions & 23 deletions src/main/java/io/anserini/ann/IndexVectors.java
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
Expand Down Expand Up @@ -116,7 +118,7 @@ public static void main(String[] args) throws Exception {
final long start = System.nanoTime();
System.out.println(String.format("Loading model %s", indexArgs.input));

Map<String, float[]> vectors = readGloVe(indexArgs.input);
Map<String, List<float[]>> vectors = readGloVe(indexArgs.input);

Path indexDir = indexArgs.path;
if (!Files.exists(indexDir)) {
Expand All @@ -134,27 +136,27 @@ public static void main(String[] args) throws Exception {
IndexWriter indexWriter = new IndexWriter(d, conf);
final AtomicInteger cnt = new AtomicInteger();

for (Map.Entry<String, float[]> entry : vectors.entrySet()) {
Document doc = new Document();

doc.add(new StringField(FIELD_ID, entry.getKey(), Field.Store.YES));
float[] vector = entry.getValue();
StringBuilder sb = new StringBuilder();
for (double fv : vector) {
if (sb.length() > 0) {
sb.append(' ');
for (Map.Entry<String, List<float[]>> entry : vectors.entrySet()) {
for (float[] vector: entry.getValue()) {
Document doc = new Document();
doc.add(new StringField(FIELD_ID, entry.getKey(), Field.Store.YES));
StringBuilder sb = new StringBuilder();
for (double fv : vector) {
if (sb.length() > 0) {
sb.append(' ');
}
sb.append(fv);
}
sb.append(fv);
}
doc.add(new TextField(FIELD_VECTOR, sb.toString(), indexArgs.stored ? Field.Store.YES : Field.Store.NO));
try {
indexWriter.addDocument(doc);
int cur = cnt.incrementAndGet();
if (cur % 100000 == 0) {
System.out.println(String.format("%s docs added", cnt));
doc.add(new TextField(FIELD_VECTOR, sb.toString(), indexArgs.stored ? Field.Store.YES : Field.Store.NO));
try {
indexWriter.addDocument(doc);
int cur = cnt.incrementAndGet();
if (cur % 100000 == 0) {
System.out.println(String.format("%s docs added", cnt));
}
} catch (IOException e) {
System.err.println("Error while indexing: " + e.getLocalizedMessage());
}
} catch (IOException e) {
System.err.println("Error while indexing: " + e.getLocalizedMessage());
}
}

Expand All @@ -171,8 +173,8 @@ public static void main(String[] args) throws Exception {
DurationFormatUtils.formatDuration(durationMillis, "HH:mm:ss")));
}

static Map<String, float[]> readGloVe(File input) throws IOException {
Map<String, float[]> vectors = new HashMap<>();
static Map<String, List<float[]>> readGloVe(File input) throws IOException {
Map<String, List<float[]>> vectors = new HashMap<>();
for (String line : IOUtils.readLines(new FileReader(input))) {
String[] s = line.split("\\s+");
if (s.length > 2) {
Expand All @@ -188,7 +190,13 @@ static Map<String, float[]> readGloVe(File input) throws IOException {
for (int i = 0; i < vector.length; i++) {
vector[i] = vector[i] / norm;
}
vectors.put(key, vector);
if (vectors.containsKey(key)) {
List<float[]> floats = new LinkedList<>(vectors.get(key));
floats.add(vector);
vectors.put(key, floats);
} else {
vectors.put(key, List.of(vector));
}
}
}
return vectors;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public void testSearchingLL() throws Exception {
SimpleNearestNeighborSearcher simpleNearestNeighborSearcher = new SimpleNearestNeighborSearcher(idxPath, "lexlsh");
SimpleNearestNeighborSearcher.Result[] results = simpleNearestNeighborSearcher.search("text", 2);
assertNotNull(results);
assertEquals(1, results.length);
assertEquals(2, results.length);
}

@Test
Expand All @@ -62,6 +62,6 @@ public void testMultiSearchingLL() throws Exception {
SimpleNearestNeighborSearcher.Result[][] results = simpleNearestNeighborSearcher.multisearch("text", 2, 2);
assertNotNull(results);
assertEquals(1, results.length);
assertEquals(1, results[0].length);
assertEquals(2, results[0].length);
}
}
3 changes: 2 additions & 1 deletion src/test/resources/mini-word-vectors.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
simple 0.3 0.2 0.2 0.9
foo 0.1 0.2 0.4 0.4
text 0.2 0.2 0.1 0.9
text 0.2 0.2 0.1 0.9
simple 0.3 0.2 0.1 0.9

0 comments on commit dd84a5a

Please sign in to comment.