Skip to content

Commit

Permalink
Add unit tests
Browse files Browse the repository at this point in the history
Also added feature flag to ALL_FEATURE_FLAG_SETTINGS so it is
recognized.

Also fixed a couple of bugs in ApproximatePointRangeQuery that I caught
with testing.

Signed-off-by: Michael Froh <froh@amazon.com>
  • Loading branch information
msfroh committed Oct 15, 2024
1 parent 2bf0e80 commit 914801f
Show file tree
Hide file tree
Showing 12 changed files with 298 additions and 238 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,8 @@ public class FeatureFlags {
STAR_TREE_INDEX_SETTING,
APPLICATION_BASED_CONFIGURATION_TEMPLATES_SETTING,
READER_WRITER_SPLIT_EXPERIMENTAL_SETTING,
TERM_VERSION_PRECOMMIT_ENABLE_SETTING
TERM_VERSION_PRECOMMIT_ENABLE_SETTING,
APPROXIMATE_POINT_RANGE_QUERY_SETTING
);

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -492,13 +492,9 @@ public Query rangeQuery(
name(),
pack(new long[] { l }).bytes,
pack(new long[] { u }).bytes,
new long[] { l }.length
) {
@Override
protected String toString(int dimension, byte[] value) {
return Long.toString(LongPoint.decodeDimension(value, 0));
}
}
new long[] { l }.length,
ApproximatePointRangeQuery.LONG_FORMAT
)
);
}
return query;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1086,13 +1086,9 @@ public Query rangeQuery(
field,
LongPoint.pack(new long[] { l }).bytes,
LongPoint.pack(new long[] { u }).bytes,
new long[] { l }.length
) {
@Override
protected String toString(int dimension, byte[] value) {
return Long.toString(LongPoint.decodeDimension(value, 0));
}
}
new long[] { l }.length,
ApproximatePointRangeQuery.LONG_FORMAT
)
);
}
return query;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.opensearch.search.sort.FieldSortBuilder;

import java.io.IOException;
import java.util.Objects;

/**
* Replaces match-all query with a less expensive query if possible.
Expand All @@ -28,6 +29,7 @@ public class ApproximateMatchAllQuery extends ApproximateQuery {

@Override
protected boolean canApproximate(SearchContext context) {
approximation = null;
if (context == null) {
return false;
}
Expand Down Expand Up @@ -66,7 +68,11 @@ public void visit(QueryVisitor visitor) {

@Override
public boolean equals(Object o) {
return sameClassAs(o);
if (sameClassAs(o)) {
ApproximateMatchAllQuery other = (ApproximateMatchAllQuery) o;
return Objects.equals(approximation, other.approximation);
}
return false;
}

@Override
Expand All @@ -79,6 +85,6 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
if (approximation == null) {
throw new IllegalStateException("rewrite called without setting context or query could not be approximated");
}
return approximation;
return approximation.rewrite(indexSearcher);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

package org.opensearch.search.approximate;

import org.apache.lucene.document.LongPoint;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.PointValues;
Expand All @@ -24,41 +25,57 @@
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.DocIdSetBuilder;
import org.apache.lucene.util.IntsRef;
import org.opensearch.index.query.RangeQueryBuilder;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.sort.FieldSortBuilder;
import org.opensearch.search.sort.SortOrder;

import java.io.IOException;
import java.util.Arrays;
import java.util.Objects;
import java.util.function.Function;

/**
* An approximate-able version of {@link PointRangeQuery}. It creates an instance of {@link PointRangeQuery} but short-circuits the intersect logic
* after {@code size} is hit
*/
public abstract class ApproximatePointRangeQuery extends ApproximateQuery {
public class ApproximatePointRangeQuery extends ApproximateQuery {
public static final Function<byte[], String> LONG_FORMAT = new Function<byte[], String>() {
@Override
public String apply(byte[] bytes) {
return Long.toString(LongPoint.decodeDimension(bytes, 0));
}
};

private int size;

private SortOrder sortOrder;

public final PointRangeQuery pointRangeQuery;

protected ApproximatePointRangeQuery(String field, byte[] lowerPoint, byte[] upperPoint, int numDims) {
this(field, lowerPoint, upperPoint, numDims, 10_000, null);
}
private final PointRangeQuery pointRangeQuery;

protected ApproximatePointRangeQuery(String field, byte[] lowerPoint, byte[] upperPoint, int numDims, int size) {
this(field, lowerPoint, upperPoint, numDims, size, null);
public ApproximatePointRangeQuery(
String field,
byte[] lowerPoint,
byte[] upperPoint,
int numDims,
Function<byte[], String> valueToString
) {
this(field, lowerPoint, upperPoint, numDims, SearchContext.DEFAULT_TRACK_TOTAL_HITS_UP_TO, null, valueToString);
}

protected ApproximatePointRangeQuery(String field, byte[] lowerPoint, byte[] upperPoint, int numDims, int size, SortOrder sortOrder) {
protected ApproximatePointRangeQuery(
String field,
byte[] lowerPoint,
byte[] upperPoint,
int numDims,
int size,
SortOrder sortOrder,
Function<byte[], String> valueToString
) {
this.size = size;
this.sortOrder = sortOrder;
this.pointRangeQuery = new PointRangeQuery(field, lowerPoint, upperPoint, numDims) {
@Override
protected String toString(int dimension, byte[] value) {
return super.toString(field);
return valueToString.apply(value);
}
};
}
Expand Down Expand Up @@ -435,17 +452,22 @@ public boolean canApproximate(SearchContext context) {
}
// size 0 could be set for caching
if (context.from() + context.size() == 0) {
this.setSize(10_000);
this.setSize(SearchContext.DEFAULT_TRACK_TOTAL_HITS_UP_TO);
} else {
this.setSize(Math.max(context.from() + context.size(), context.trackTotalHitsUpTo()));
}
this.setSize(Math.max(context.from() + context.size(), context.trackTotalHitsUpTo()));
if (context.request() != null && context.request().source() != null) {
FieldSortBuilder primarySortField = FieldSortBuilder.getPrimaryFieldSortOrNull(context.request().source());
if (primarySortField != null
&& primarySortField.missing() == null
&& primarySortField.getFieldName().equals(((RangeQueryBuilder) context.request().source().query()).fieldName())) {
if (primarySortField.order() == SortOrder.DESC) {
this.setSortOrder(SortOrder.DESC);
if (primarySortField != null) {
if (!primarySortField.fieldName().equals(pointRangeQuery.getField())) {
// Cannot sort on a different field.
return false;
}
if (primarySortField.missing() != null) {
// Cannot sort documents missing this field.
return false;
}
this.setSortOrder(primarySortField.order());
}
}
return true;
Expand All @@ -462,56 +484,16 @@ public final boolean equals(Object o) {
}

private boolean equalsTo(ApproximatePointRangeQuery other) {
return Objects.equals(pointRangeQuery.getField(), other.pointRangeQuery.getField())
&& pointRangeQuery.getNumDims() == other.pointRangeQuery.getNumDims()
&& pointRangeQuery.getBytesPerDim() == other.pointRangeQuery.getBytesPerDim()
&& Arrays.equals(pointRangeQuery.getLowerPoint(), other.pointRangeQuery.getLowerPoint())
&& Arrays.equals(pointRangeQuery.getUpperPoint(), other.pointRangeQuery.getUpperPoint());
return Objects.equals(pointRangeQuery, other.pointRangeQuery);
}

@Override
public final String toString(String field) {
final StringBuilder sb = new StringBuilder();
if (pointRangeQuery.getField().equals(field) == false) {
sb.append(pointRangeQuery.getField());
sb.append(':');
}

// print ourselves as "range per dimension"
for (int i = 0; i < pointRangeQuery.getNumDims(); i++) {
if (i > 0) {
sb.append(',');
}

int startOffset = pointRangeQuery.getBytesPerDim() * i;

sb.append('[');
sb.append(
toString(
i,
ArrayUtil.copyOfSubArray(pointRangeQuery.getLowerPoint(), startOffset, startOffset + pointRangeQuery.getBytesPerDim())
)
);
sb.append(" TO ");
sb.append(
toString(
i,
ArrayUtil.copyOfSubArray(pointRangeQuery.getUpperPoint(), startOffset, startOffset + pointRangeQuery.getBytesPerDim())
)
);
sb.append(']');
}
sb.append("Approximate(");
sb.append(pointRangeQuery.toString());
sb.append(")");

return sb.toString();
}

/**
* Returns a string of a single value in a human-readable format for debugging. This is used by
* {@link #toString()}.
*
* @param dimension dimension of the particular value
* @param value single value, never null
* @return human readable value for debugging
*/
protected abstract String toString(int dimension, byte[] value);
}
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,10 @@
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.search.approximate.ApproximatePointRangeQuery;
import org.opensearch.search.approximate.ApproximateScoreQuery;
import org.opensearch.test.FeatureFlagSetter;
import org.opensearch.test.TestSearchContext;
import org.joda.time.DateTimeZone;
import org.junit.Before;

import java.io.IOException;
import java.time.ZoneOffset;
Expand All @@ -82,6 +85,13 @@ public class DateFieldTypeTests extends FieldTypeTestCase {

private static final long nowInMillis = 0;

@Override
@Before
public void setUp() throws Exception {
super.setUp();
FeatureFlagSetter.set(FeatureFlags.APPROXIMATE_POINT_RANGE_QUERY);
}

public void testIsFieldWithinRangeEmptyReader() throws IOException {
QueryRewriteContext context = new QueryRewriteContext(xContentRegistry(), writableRegistry(), null, () -> nowInMillis);
IndexReader reader = new MultiReader();
Expand Down Expand Up @@ -222,13 +232,9 @@ public void testTermQuery() {
"field",
pack(new long[] { instant }).bytes,
pack(new long[] { instant + 999 }).bytes,
new long[] { instant }.length
) {
@Override
protected String toString(int dimension, byte[] value) {
return Long.toString(LongPoint.decodeDimension(value, 0));
}
}
new long[] { instant }.length,
ApproximatePointRangeQuery.LONG_FORMAT
)
);
assumeThat(
"Using Approximate Range Query as default",
Expand Down Expand Up @@ -281,32 +287,22 @@ public void testRangeQuery() throws IOException {
String date2 = "2016-04-28T11:33:52";
long instant1 = DateFormatters.from(DateFieldMapper.getDefaultDateTimeFormatter().parse(date1)).toInstant().toEpochMilli();
long instant2 = DateFormatters.from(DateFieldMapper.getDefaultDateTimeFormatter().parse(date2)).toInstant().toEpochMilli() + 999;
Query expected = new ApproximateScoreQuery(
new IndexOrDocValuesQuery(
LongPoint.newRangeQuery("field", instant1, instant2),
SortedNumericDocValuesField.newSlowRangeQuery("field", instant1, instant2)
),
new ApproximatePointRangeQuery(
"field",
pack(new long[] { instant1 }).bytes,
pack(new long[] { instant2 }).bytes,
new long[] { instant1 }.length
) {
@Override
protected String toString(int dimension, byte[] value) {
return Long.toString(LongPoint.decodeDimension(value, 0));
}
}
Query expected = new ApproximatePointRangeQuery(
"field",
pack(new long[] { instant1 }).bytes,
pack(new long[] { instant2 }).bytes,
new long[] { instant1 }.length,
ApproximatePointRangeQuery.LONG_FORMAT
);
assumeThat(
"Using Approximate Range Query as default",
FeatureFlags.isEnabled(FeatureFlags.APPROXIMATE_POINT_RANGE_QUERY),
is(true)
);
assertEquals(
expected,
ft.rangeQuery(date1, date2, true, true, null, null, null, context).rewrite(new IndexSearcher(new MultiReader()))
);
Query rangeQuery = ft.rangeQuery(date1, date2, true, true, null, null, null, context);
assertTrue(rangeQuery instanceof ApproximateScoreQuery);
((ApproximateScoreQuery) rangeQuery).setContext(new TestSearchContext(context));
assertEquals(expected, rangeQuery.rewrite(new IndexSearcher(new MultiReader())));

instant1 = nowInMillis;
instant2 = instant1 + 100;
Expand All @@ -320,13 +316,9 @@ protected String toString(int dimension, byte[] value) {
"field",
pack(new long[] { instant1 }).bytes,
pack(new long[] { instant2 }).bytes,
new long[] { instant1 }.length
) {
@Override
protected String toString(int dimension, byte[] value) {
return Long.toString(LongPoint.decodeDimension(value, 0));
}
}
new long[] { instant1 }.length,
ApproximatePointRangeQuery.LONG_FORMAT
)
)
);
assumeThat(
Expand Down Expand Up @@ -391,23 +383,19 @@ public void testRangeQueryWithIndexSort() {
long instant2 = DateFormatters.from(DateFieldMapper.getDefaultDateTimeFormatter().parse(date2)).toInstant().toEpochMilli() + 999;

Query dvQuery = SortedNumericDocValuesField.newSlowRangeQuery("field", instant1, instant2);
Query expected = new IndexSortSortedNumericDocValuesRangeQuery(
"field",
instant1,
instant2,
new ApproximateScoreQuery(
new IndexOrDocValuesQuery(LongPoint.newRangeQuery("field", instant1, instant2), dvQuery),
new ApproximatePointRangeQuery(
"field",
pack(new long[] { instant1 }).bytes,
pack(new long[] { instant2 }).bytes,
new long[] { instant1 }.length
) {
@Override
protected String toString(int dimension, byte[] value) {
return Long.toString(LongPoint.decodeDimension(value, 0));
}
}
Query expected = new ApproximateScoreQuery(
new IndexSortSortedNumericDocValuesRangeQuery(
"field",
instant1,
instant2,
new IndexOrDocValuesQuery(LongPoint.newRangeQuery("field", instant1, instant2), dvQuery)
),
new ApproximatePointRangeQuery(
"field",
pack(new long[] { instant1 }).bytes,
pack(new long[] { instant2 }).bytes,
new long[] { instant1 }.length,
ApproximatePointRangeQuery.LONG_FORMAT
)
);
assumeThat(
Expand Down
Loading

0 comments on commit 914801f

Please sign in to comment.