Skip to content

Commit

Permalink
Add ability to parse raw text into docvectors on-the-fly for impact i…
Browse files Browse the repository at this point in the history
…ndexes #2122 (#2148)
  • Loading branch information
AileenLin authored Aug 8, 2023
1 parent 6aabfe7 commit 9cdcf0e
Show file tree
Hide file tree
Showing 8 changed files with 572 additions and 51 deletions.
415 changes: 399 additions & 16 deletions src/main/java/io/anserini/search/SimpleImpactSearcher.java

Large diffs are not rendered by default.

23 changes: 19 additions & 4 deletions src/main/java/io/anserini/search/query/QueryEncoder.java
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ protected static long[] convertTokensToIds(BertFullTokenizer tokenizer, List<Str
return tokenIds;
}

protected String generateEncodedQuery(Map<String, Float> tokenWeightMap) {
public String generateEncodedQuery(Map<String, Float> tokenWeightMap) {
/*
* This function generates the encoded query.
*/
Expand All @@ -72,7 +72,23 @@ protected String generateEncodedQuery(Map<String, Float> tokenWeightMap) {
return String.join(" ", encodedQuery);
}

static Map<String, Float> getTokenWeightMap(long[] indexes, float[] computedWeights, DefaultVocabulary vocab) {
public Map<String, Integer> getEncodedQueryMap(Map<String, Float> tokenWeightMap) throws OrtException {
Map<String, Integer> encodedQuery = new HashMap<>();
for (Map.Entry<String, Float> entry : tokenWeightMap.entrySet()) {
String token = entry.getKey();
Float tokenWeight = entry.getValue();
int weightQuanted = Math.round(tokenWeight / weightRange * quantRange);
encodedQuery.put(token, weightQuanted);
}
return encodedQuery;
}

public Map<String, Integer> getEncodedQueryMap(String query) throws OrtException {
Map<String, Float> tokenWeightMap = getTokenWeightMap(query);
return getEncodedQueryMap(tokenWeightMap);
}

static protected Map<String, Float> getTokenWeightMap(long[] indexes, float[] computedWeights, DefaultVocabulary vocab) {
/*
* This function returns a map of token to its weight.
*/
Expand All @@ -87,6 +103,5 @@ static Map<String, Float> getTokenWeightMap(long[] indexes, float[] computedWeig
return tokenWeightMap;
}

public abstract Map<String, Float> getTokenWeightMap(String query) throws OrtException;

protected abstract Map<String, Float> getTokenWeightMap(String query) throws OrtException;
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ public String encode(String query) throws OrtException {
}

@Override
public Map<String, Float> getTokenWeightMap(String query) throws OrtException {
protected Map<String, Float> getTokenWeightMap(String query) throws OrtException {
List<String> queryTokens = new ArrayList<>();
queryTokens.add("[CLS]");
queryTokens.addAll(tokenizer.tokenize(query));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ public String encode(String query) throws OrtException {
}

@Override
public Map<String, Float> getTokenWeightMap(String query) throws OrtException {
protected Map<String, Float> getTokenWeightMap(String query) throws OrtException {
List<String> queryTokens = new ArrayList<>();
queryTokens.add("[CLS]");
queryTokens.addAll(tokenizer.tokenize(query));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ private Map<String, Float> getTokenWeightMap(List<String> tokens, float[] comput
}

@Override
public Map<String, Float> getTokenWeightMap(String query) throws OrtException {
protected Map<String, Float> getTokenWeightMap(String query) throws OrtException {
List<String> queryTokens = new ArrayList<>();
queryTokens.add("[CLS]");
queryTokens.addAll(tokenizer.tokenize(query));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,16 @@ public void testSearch1() throws Exception {

SimpleImpactSearcher.Result[] hits;

Map<String, Float> query = new HashMap<>();
query.put("##ing", 1.0f);
Map<String, Integer> query = new HashMap<>();
query.put("##ing", 1);

hits = searcher.search(query, 10);
assertEquals(1, hits.length);
assertEquals("2000001", hits[0].docid);
assertEquals(2, (int) hits[0].score);

query = new HashMap<>();
query.put("test", 1.0f);
query.put("test", 1);
hits = searcher.search(query, 10);
assertEquals(1, hits.length);
assertEquals("2000000", hits[0].docid);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,16 @@ public void testSearch1() throws Exception {

SimpleImpactSearcher.Result[] hits;

Map<String, Float> query = new HashMap<>();
query.put("##ing", 1.0f);
Map<String, Integer> query = new HashMap<>();
query.put("##ing", 1);

hits = searcher.search(query, 10);
assertEquals(1, hits.length);
assertEquals("2000001", hits[0].docid);
assertEquals(2, (int) hits[0].score);

query = new HashMap<>();
query.put("test", 1.0f);
query.put("test", 1);
hits = searcher.search(query, 10);
assertEquals(1, hits.length);
assertEquals("2000000", hits[0].docid);
Expand Down
167 changes: 145 additions & 22 deletions src/test/java/io/anserini/search/SimpleImpactSearcherTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,17 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import io.anserini.search.SimpleImpactSearcher.Result;


public class SimpleImpactSearcherTest extends IndexerTestBase {

private static Map<String, Float> EXPECTED_ENCODED_QUERY = new HashMap<>();
private static Map<String, Integer> EXPECTED_ENCODED_QUERY = new HashMap<>();

static {
EXPECTED_ENCODED_QUERY.put("here", 3.05345f);
EXPECTED_ENCODED_QUERY.put("a", 0.59636426f);
EXPECTED_ENCODED_QUERY.put("test", 2.9012794f);
EXPECTED_ENCODED_QUERY.put("here", 156);
EXPECTED_ENCODED_QUERY.put("a", 31);
EXPECTED_ENCODED_QUERY.put("test", 149);
}

@Test
Expand Down Expand Up @@ -109,14 +111,22 @@ public void testGetRaw() throws Exception {
public void testSearch1() throws Exception {
SimpleImpactSearcher searcher = new SimpleImpactSearcher(super.tempDir1.toString());

Map<String, Float> testQuery = new HashMap<>();
testQuery.put("test", 1.2f);
Map<String, Integer> testQuery = new HashMap<>();
testQuery.put("test", 1);

SimpleImpactSearcher.Result[] hits = searcher.search(testQuery, 10);
SimpleImpactSearcher.Result[] hits_string = searcher.search("test", 10);
assertEquals(hits_string.length, hits.length);
assertEquals(hits_string[0].docid, hits[0].docid);
assertEquals(hits_string[0].lucene_docid, hits[0].lucene_docid);
assertEquals(hits_string[0].score, hits[0].score, 10e-6);
assertEquals(hits_string[0].contents, hits[0].contents);
assertEquals(hits_string[0].raw, hits[0].raw);

assertEquals(1, hits.length);
assertEquals("doc3", hits[0].docid);
assertEquals(2, hits[0].lucene_docid);
assertEquals(1.2f, hits[0].score, 10e-6);
assertEquals(1.0f, hits[0].score, 10e-6);
assertEquals("here is a test", hits[0].contents);
assertEquals("{\"contents\": \"here is a test\"}", hits[0].raw);

Expand All @@ -135,16 +145,16 @@ public void testSearch1() throws Exception {
public void testSearch2() throws Exception {
SimpleImpactSearcher searcher = new SimpleImpactSearcher(super.tempDir1.toString());

Map<String, Float> testQuery = new HashMap<>();
testQuery.put("text", 1.2f);
Map<String, Integer> testQuery = new HashMap<>();
testQuery.put("text", 1);

SimpleImpactSearcher.Result[] results;

results = searcher.search(testQuery, 1);
assertEquals(1, results.length);
assertEquals("doc1", results[0].docid);
assertEquals(0, results[0].lucene_docid);
assertEquals(2.4f, results[0].score, 10e-6);
assertEquals(2.0f, results[0].score, 10e-6);
assertEquals("here is some text here is some more text. city.", results[0].contents);
assertEquals("{\"contents\": \"here is some text here is some more text. city.\"}", results[0].raw);

Expand All @@ -154,31 +164,31 @@ public void testSearch2() throws Exception {
assertEquals(0, results[0].lucene_docid);
assertEquals("doc2", results[1].docid);
assertEquals(1, results[1].lucene_docid);
assertEquals(2.4f, results[0].score, 10e-6);
assertEquals(1.2f, results[1].score, 10e-6);
assertEquals(2.0f, results[0].score, 10e-6);
assertEquals(1.0f, results[1].score, 10e-6);

Map<String, Float> testQuery2 = new HashMap<>();
testQuery2.put("test", 0.125f);
Map<String, Integer> testQuery2 = new HashMap<>();
testQuery2.put("test", 1);

results = searcher.search(testQuery2);
assertEquals(1, results.length);
assertEquals("doc3", results[0].docid);
assertEquals(2, results[0].lucene_docid);
assertEquals(0.125f, results[0].score, 10e-6);
assertEquals(1.0f, results[0].score, 10e-6);

searcher.close();
}

@Test
public void testBatchSearch() throws Exception {
SimpleImpactSearcher searcher = new SimpleImpactSearcher(super.tempDir1.toString());
Map<String, Float> testQuery1 = new HashMap<>();
testQuery1.put("tests", 0.1f);
testQuery1.put("test", 0.1f);
Map<String, Float> testQuery2 = new HashMap<>();
testQuery2.put("more", 1.5f);
Map<String, Integer> testQuery1 = new HashMap<>();
testQuery1.put("tests", 1);
testQuery1.put("test", 1);
Map<String, Integer> testQuery2 = new HashMap<>();
testQuery2.put("more", 3);

List<Map<String, Float>> queries = new ArrayList<>();
List<Map<String, Integer>> queries = new ArrayList<>();
queries.add(testQuery1);
queries.add(testQuery2);

Expand All @@ -205,14 +215,127 @@ public void testTotalNumDocuments() throws Exception {
assertEquals(3 ,searcher.get_total_num_docs());
}

@Test
public void testOnnxEncodedQuery() throws Exception {
SimpleImpactSearcher searcher = new SimpleImpactSearcher(super.tempDir1.toString());
Map<String, Integer> testQuery1 = new HashMap<>();
testQuery1.put("text", 2);
String encodedQuery = searcher.encodeWithOnnx(testQuery1);
assertEquals("text text" ,encodedQuery);
}

@Test
public void testOnnxEncoder() throws Exception{
SimpleImpactSearcher searcher = new SimpleImpactSearcher();
searcher.set_onnx_query_encoder("SpladePlusPlusEnsembleDistil");

Map<String, Float> encoded_query = searcher.encode_with_onnx("here is a test");
Map<String, Integer> encoded_query = searcher.encodeWithOnnx("here is a test");
assertEquals(encoded_query.get("here"), EXPECTED_ENCODED_QUERY.get("here"), 2e-4);
assertEquals(encoded_query.get("a"), EXPECTED_ENCODED_QUERY.get("a"), 2e-4);
assertEquals(encoded_query.get("test"), EXPECTED_ENCODED_QUERY.get("test"), 2e-4);
}

@Test
public void testSearch3() throws Exception {
SimpleImpactSearcher searcher = new SimpleImpactSearcher(super.tempDir1.toString());
searcher.set_rm3();
assertTrue(searcher.use_rm3());

Result[] results;

Map<String, Integer> testQuery1 = new HashMap<>();
testQuery1.put("text", 1);

results = searcher.search(testQuery1, 1);
assertEquals(1, results.length);
assertEquals("doc1", results[0].docid);
assertEquals(0, results[0].lucene_docid);
assertEquals(1.0f, results[0].score, 10e-5);

Map<String, Integer> testQuery2 = new HashMap<>();
testQuery2.put("test", 1);

results = searcher.search(testQuery2);
assertEquals(1, results.length);
assertEquals("doc3", results[0].docid);
assertEquals(2, results[0].lucene_docid);
assertEquals(0.5f, results[0].score, 10e-5);

Map<String, Integer> testQuery3 = new HashMap<>();
testQuery3.put("more", 1);

results = searcher.search(testQuery3);

assertEquals(2, results.length);
assertEquals("doc1", results[0].docid);
assertEquals(0, results[0].lucene_docid);
assertEquals(0.5f, results[0].score, 10e-5);
assertEquals("doc2", results[1].docid);
assertEquals(1, results[1].lucene_docid);
assertEquals(0.5f, results[1].score, 10e-5);

searcher.unset_rm3();
assertFalse(searcher.use_rm3());

results = searcher.search(testQuery1, 1);
assertEquals(1, results.length);
assertEquals("doc1", results[0].docid);
assertEquals(0, results[0].lucene_docid);
assertEquals(2.0f, results[0].score, 10e-5);

searcher.close();
}

@Test
public void testSearch4() throws Exception {
// This adds Rocchio on top of "testSearch1"
SimpleImpactSearcher searcher = new SimpleImpactSearcher(super.tempDir1.toString());
searcher.set_rocchio();
assertTrue(searcher.use_rocchio());

Result[] results;

Map<String, Integer> testQuery1 = new HashMap<>();
testQuery1.put("text", 1);

results = searcher.search(testQuery1, 1);
assertEquals(1, results.length);
assertEquals("doc1", results[0].docid);
assertEquals(0, results[0].lucene_docid);
assertEquals(2.0f, results[0].score, 10e-5);

Map<String, Integer> testQuery2 = new HashMap<>();
testQuery2.put("test", 1);

results = searcher.search(testQuery2);
assertEquals(1, results.length);
assertEquals("doc3", results[0].docid);
assertEquals(2, results[0].lucene_docid);
assertEquals(1.0f, results[0].score, 10e-5);

Map<String, Integer> testQuery3 = new HashMap<>();
testQuery3.put("more", 1);

results = searcher.search(testQuery3);
assertEquals(2, results.length);
assertEquals("doc1", results[0].docid);
assertEquals(0, results[0].lucene_docid);
assertEquals(1.0f, results[0].score, 10e-5);
assertEquals("doc2", results[1].docid);
assertEquals(1, results[1].lucene_docid);
assertEquals(1.0f, results[1].score, 10e-5);


searcher.unset_rocchio();
assertFalse(searcher.use_rocchio());

results = searcher.search(testQuery1, 1);
assertEquals(1, results.length);
assertEquals("doc1", results[0].docid);
assertEquals(0, results[0].lucene_docid);
assertEquals(2.0f, results[0].score, 10e-5);

searcher.close();
}

}

0 comments on commit 9cdcf0e

Please sign in to comment.