diff --git a/flint-core/src/main/java/org/opensearch/flint/core/field/bloomfilter/BloomFilterFactory.java b/flint-core/src/main/java/org/opensearch/flint/core/field/bloomfilter/BloomFilterFactory.java new file mode 100644 index 000000000..63dda18e9 --- /dev/null +++ b/flint-core/src/main/java/org/opensearch/flint/core/field/bloomfilter/BloomFilterFactory.java @@ -0,0 +1,154 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.core.field.bloomfilter; + +import java.io.InputStream; +import java.io.Serializable; +import java.util.Map; +import java.util.Optional; +import org.opensearch.flint.core.field.bloomfilter.adaptive.AdaptiveBloomFilter; +import org.opensearch.flint.core.field.bloomfilter.classic.ClassicBloomFilter; + +/** + * Bloom filter factory that builds bloom filter based on algorithm parameters. + */ +public abstract class BloomFilterFactory implements Serializable { + + /** + * Bloom filter adaptive key and default value. + */ + public static final String BLOOM_FILTER_ADAPTIVE_KEY = "adaptive"; + public static final boolean DEFAULT_BLOOM_FILTER_ADAPTIVE = true; + + /** + * Expected number of unique items key and default value. + */ + public static final String CLASSIC_BLOOM_FILTER_NUM_ITEMS_KEY = "num_items"; + public static final int DEFAULT_CLASSIC_BLOOM_FILTER_NUM_ITEMS = 10000; + + /** + * False positive probability (FPP) key and default value. + */ + public static final String CLASSIC_BLOOM_FILTER_FPP_KEY = "fpp"; + public static final double DEFAULT_CLASSIC_BLOOM_FILTER_FPP = 0.03; + + /** + * Number of candidate key and default value. + */ + public static final String ADAPTIVE_NUMBER_CANDIDATE_KEY = "num_candidates"; + public static final int DEFAULT_ADAPTIVE_NUMBER_CANDIDATE = 10; + + /** + * Bloom filter algorithm parameters. 
+ */ + private final Map<String, String> parameters; + + protected BloomFilterFactory(Map<String, String> parameters) { + this.parameters = parameters; + } + + /** + * @return all parameters including the default ones. + */ + public abstract Map<String, String> getParameters(); + + /** + * Create specific BloomFilter instance. + * + * @return BloomFilter instance + */ + public abstract BloomFilter create(); + + /** + * Create specific BloomFilter instance by deserialization. + * + * @param in input stream + * @return BloomFilter instance + */ + public abstract BloomFilter deserialize(InputStream in); + + /** + * Create specific BloomFilter factory given the parameters. + * + * @param parameters BloomFilter parameters + * @return BloomFilter factory instance + */ + public static BloomFilterFactory of(Map<String, String> parameters) { + if (isAdaptiveEnabled(parameters)) { + return createAdaptiveBloomFilterFactory(parameters); + } else { + return createClassicBloomFilterFactory(parameters); + } + } + + private static BloomFilterFactory createAdaptiveBloomFilterFactory(Map<String, String> parameters) { + return new BloomFilterFactory(parameters) { + @Override + public Map<String, String> getParameters() { + return Map.of( + BLOOM_FILTER_ADAPTIVE_KEY, "true", + ADAPTIVE_NUMBER_CANDIDATE_KEY, Integer.toString(numCandidates()), + CLASSIC_BLOOM_FILTER_FPP_KEY, Double.toString(fpp())); + } + + @Override + public BloomFilter create() { + return new AdaptiveBloomFilter(numCandidates(), fpp()); + } + + @Override + public BloomFilter deserialize(InputStream in) { + return AdaptiveBloomFilter.readFrom(numCandidates(), in); + } + }; + } + + private static BloomFilterFactory createClassicBloomFilterFactory(Map<String, String> parameters) { + return new BloomFilterFactory(parameters) { + @Override + public Map<String, String> getParameters() { + return Map.of( + BLOOM_FILTER_ADAPTIVE_KEY, "false", + CLASSIC_BLOOM_FILTER_NUM_ITEMS_KEY, Integer.toString(expectedNumItems()), + CLASSIC_BLOOM_FILTER_FPP_KEY, Double.toString(fpp())); + } + + @Override + public BloomFilter create() { + return new 
ClassicBloomFilter(expectedNumItems(), fpp()); + } + + @Override + public BloomFilter deserialize(InputStream in) { + return ClassicBloomFilter.readFrom(in); + } + }; + } + + private static boolean isAdaptiveEnabled(Map params) { + return Optional.ofNullable(params.get(BLOOM_FILTER_ADAPTIVE_KEY)) + .map(Boolean::parseBoolean) + .orElse(DEFAULT_BLOOM_FILTER_ADAPTIVE); + } + + protected int expectedNumItems() { + return Optional.ofNullable(parameters.get(CLASSIC_BLOOM_FILTER_NUM_ITEMS_KEY)) + .map(Integer::parseInt) + .orElse(DEFAULT_CLASSIC_BLOOM_FILTER_NUM_ITEMS); + } + + protected double fpp() { + return Optional.ofNullable(parameters.get(CLASSIC_BLOOM_FILTER_FPP_KEY)) + .map(Double::parseDouble) + .orElse(DEFAULT_CLASSIC_BLOOM_FILTER_FPP); + } + + protected int numCandidates() { + return Optional.ofNullable(parameters.get(ADAPTIVE_NUMBER_CANDIDATE_KEY)) + .map(Integer::parseInt) + .orElse(DEFAULT_ADAPTIVE_NUMBER_CANDIDATE); + } +} diff --git a/flint-core/src/main/java/org/opensearch/flint/core/field/bloomfilter/adaptive/AdaptiveBloomFilter.java b/flint-core/src/main/java/org/opensearch/flint/core/field/bloomfilter/adaptive/AdaptiveBloomFilter.java new file mode 100644 index 000000000..42d8889ff --- /dev/null +++ b/flint-core/src/main/java/org/opensearch/flint/core/field/bloomfilter/adaptive/AdaptiveBloomFilter.java @@ -0,0 +1,230 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.core.field.bloomfilter.adaptive; + +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.util.Arrays; +import java.util.Iterator; +import java.util.Objects; +import java.util.function.Function; +import org.opensearch.flint.core.field.bloomfilter.BloomFilter; +import org.opensearch.flint.core.field.bloomfilter.classic.ClassicBloomFilter; + +/** + * Adaptive BloomFilter implementation that generates a series 
of bloom filter candidates + * with different expected number of items (NDV) and finally chooses the best one. + */ +public class AdaptiveBloomFilter implements BloomFilter { + + /** + * Initial expected number of items for the first candidate. + */ + public static final int INITIAL_EXPECTED_NUM_ITEMS = 1024; + + /** + * Total number of distinct items seen so far. + */ + private int cardinality = 0; + + /** + * BloomFilter candidates. + */ + private final BloomFilterCandidate[] candidates; + + /** + * Construct adaptive BloomFilter instance with the given algorithm parameters. + * + * @param numCandidates number of candidates + * @param fpp false positive probability + */ + public AdaptiveBloomFilter(int numCandidates, double fpp) { + this.candidates = initializeCandidates(numCandidates, + expectedNumItems -> new ClassicBloomFilter(expectedNumItems, fpp)); + } + + /** + * Construct adaptive BloomFilter instance from BloomFilter array deserialized from input stream. + * + * @param cardinality total number of distinct items + * @param candidates BloomFilter candidates + */ + AdaptiveBloomFilter(int cardinality, BloomFilter[] candidates) { + this.cardinality = cardinality; + Iterator<BloomFilter> it = Arrays.stream(candidates).iterator(); + this.candidates = initializeCandidates(candidates.length, expectedNumItems -> it.next()); + } + + /** + * Deserialize adaptive BloomFilter instance from input stream. 
+ * + * @param numCandidates number of candidates + * @param in input stream of serialized adaptive BloomFilter instance + * @return adaptive BloomFilter instance + */ + public static BloomFilter readFrom(int numCandidates, InputStream in) { + try { + // Read total distinct counter + int cardinality = new DataInputStream(in).readInt(); + + // Read BloomFilter candidate array + BloomFilter[] candidates = new BloomFilter[numCandidates]; + for (int i = 0; i < numCandidates; i++) { + candidates[i] = ClassicBloomFilter.readFrom(in); + } + return new AdaptiveBloomFilter(cardinality, candidates); + } catch (IOException e) { + throw new RuntimeException("Failed to deserialize adaptive BloomFilter", e); + } + } + + /** + * @return best BloomFilter candidate which has expected number of item right above total distinct counter. + */ + public BloomFilterCandidate bestCandidate() { + return candidates[bestCandidateIndex()]; + } + + @Override + public long bitSize() { + return Arrays.stream(candidates) + .mapToLong(c -> c.bloomFilter.bitSize()) + .sum(); + } + + @Override + public boolean put(long item) { + // Only insert into candidate with larger expectedNumItems for efficiency + boolean bitChanged = false; + for (int i = bestCandidateIndex(); i < candidates.length; i++) { + bitChanged = candidates[i].bloomFilter.put(item); + } + + // Use the last candidate's put result which is the most accurate + if (bitChanged) { + cardinality++; + } + return bitChanged; + } + + @Override + public BloomFilter merge(BloomFilter other) { + AdaptiveBloomFilter otherBf = (AdaptiveBloomFilter) other; + cardinality += otherBf.cardinality; + + for (int i = bestCandidateIndex(); i < candidates.length; i++) { + candidates[i].bloomFilter.merge(otherBf.candidates[i].bloomFilter); + } + return this; + } + + @Override + public boolean mightContain(long item) { + // Use the last candidate which is the most accurate + return candidates[candidates.length - 1].bloomFilter.mightContain(item); + } + + 
@Override + public void writeTo(OutputStream out) throws IOException { + // Serialize the cardinality counter first + new DataOutputStream(out).writeInt(cardinality); + + // Serialize classic BloomFilter array + for (BloomFilterCandidate candidate : candidates) { + candidate.bloomFilter.writeTo(out); + } + } + + private BloomFilterCandidate[] initializeCandidates(int numCandidates, + Function<Integer, BloomFilter> initializer) { + BloomFilterCandidate[] candidates = new BloomFilterCandidate[numCandidates]; + int ndv = INITIAL_EXPECTED_NUM_ITEMS; + + // Initialize candidates with NDV doubled in each iteration + for (int i = 0; i < numCandidates; i++, ndv *= 2) { + candidates[i] = new BloomFilterCandidate(ndv, initializer.apply(ndv)); + } + return candidates; + } + + private int bestCandidateIndex() { + int index = Arrays.binarySearch(candidates, new BloomFilterCandidate(cardinality, null)); + if (index < 0) { + index = -(index + 1); + } + + /* + * Now 'index' represents the position where the current cardinality should be inserted, + * indicating the best candidate to choose based on its expected number of distinct values. + * The last one is chosen if cardinality exceeds each candidate's expected number. + */ + return Math.min(index, candidates.length - 1); + } + + /** + * BloomFilter candidate that records expected number of items for each candidate. + */ + public static class BloomFilterCandidate implements Comparable<BloomFilterCandidate> { + /** + * Expected number of items associated with this candidate. + */ + private final int expectedNumItems; + + /** + * BloomFilter instance. 
+ */ + private final BloomFilter bloomFilter; + + BloomFilterCandidate(int expectedNumItems, BloomFilter bloomFilter) { + this.expectedNumItems = expectedNumItems; + this.bloomFilter = bloomFilter; + } + + public int getExpectedNumItems() { + return expectedNumItems; + } + + public BloomFilter getBloomFilter() { + return bloomFilter; + } + + @Override + public int compareTo(BloomFilterCandidate other) { + return Integer.compare(expectedNumItems, other.expectedNumItems); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + BloomFilterCandidate that = (BloomFilterCandidate) o; + return expectedNumItems == that.expectedNumItems && Objects.equals(bloomFilter, that.bloomFilter); + } + + @Override + public int hashCode() { + return Objects.hash(expectedNumItems, bloomFilter); + } + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + AdaptiveBloomFilter that = (AdaptiveBloomFilter) o; + return cardinality == that.cardinality && Arrays.equals(candidates, that.candidates); + } + + @Override + public int hashCode() { + int result = Objects.hash(cardinality); + result = 31 * result + Arrays.hashCode(candidates); + return result; + } +} diff --git a/flint-core/src/test/java/org/opensearch/flint/core/field/bloomfilter/BloomFilterFactoryTest.java b/flint-core/src/test/java/org/opensearch/flint/core/field/bloomfilter/BloomFilterFactoryTest.java new file mode 100644 index 000000000..40a310b01 --- /dev/null +++ b/flint-core/src/test/java/org/opensearch/flint/core/field/bloomfilter/BloomFilterFactoryTest.java @@ -0,0 +1,106 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.core.field.bloomfilter; + +import static java.util.Collections.emptyMap; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static 
org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.opensearch.flint.core.field.bloomfilter.BloomFilterFactory.ADAPTIVE_NUMBER_CANDIDATE_KEY; +import static org.opensearch.flint.core.field.bloomfilter.BloomFilterFactory.BLOOM_FILTER_ADAPTIVE_KEY; +import static org.opensearch.flint.core.field.bloomfilter.BloomFilterFactory.CLASSIC_BLOOM_FILTER_FPP_KEY; +import static org.opensearch.flint.core.field.bloomfilter.BloomFilterFactory.CLASSIC_BLOOM_FILTER_NUM_ITEMS_KEY; +import static org.opensearch.flint.core.field.bloomfilter.BloomFilterFactory.DEFAULT_ADAPTIVE_NUMBER_CANDIDATE; +import static org.opensearch.flint.core.field.bloomfilter.BloomFilterFactory.DEFAULT_CLASSIC_BLOOM_FILTER_FPP; +import static org.opensearch.flint.core.field.bloomfilter.BloomFilterFactory.DEFAULT_CLASSIC_BLOOM_FILTER_NUM_ITEMS; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.util.Map; +import org.junit.Test; +import org.opensearch.flint.core.field.bloomfilter.adaptive.AdaptiveBloomFilter; +import org.opensearch.flint.core.field.bloomfilter.classic.ClassicBloomFilter; + +public class BloomFilterFactoryTest { + + @Test + public void createAdaptiveBloomFilter() { + BloomFilterFactory factory = BloomFilterFactory.of(emptyMap()); + + assertInstanceOf(AdaptiveBloomFilter.class, factory.create()); + assertEquals( + Map.of( + BLOOM_FILTER_ADAPTIVE_KEY, "true", + ADAPTIVE_NUMBER_CANDIDATE_KEY, Integer.toString(DEFAULT_ADAPTIVE_NUMBER_CANDIDATE), + CLASSIC_BLOOM_FILTER_FPP_KEY, Double.toString(DEFAULT_CLASSIC_BLOOM_FILTER_FPP)), + factory.getParameters()); + } + + @Test + public void createAdaptiveBloomFilterWithParameters() { + Map parameters = + Map.of( + BLOOM_FILTER_ADAPTIVE_KEY, "true", + ADAPTIVE_NUMBER_CANDIDATE_KEY, "8", + CLASSIC_BLOOM_FILTER_FPP_KEY, "0.01"); + BloomFilterFactory factory = BloomFilterFactory.of(parameters); + + assertInstanceOf(AdaptiveBloomFilter.class, factory.create()); + 
assertEquals(parameters, factory.getParameters()); + } + + @Test + public void shouldCreateClassicBloomFilter() { + BloomFilterFactory factory = BloomFilterFactory.of(Map.of(BLOOM_FILTER_ADAPTIVE_KEY, "false")); + + assertEquals( + Map.of( + BLOOM_FILTER_ADAPTIVE_KEY, "false", + CLASSIC_BLOOM_FILTER_NUM_ITEMS_KEY, Integer.toString(DEFAULT_CLASSIC_BLOOM_FILTER_NUM_ITEMS), + CLASSIC_BLOOM_FILTER_FPP_KEY, Double.toString(DEFAULT_CLASSIC_BLOOM_FILTER_FPP)), + factory.getParameters()); + assertInstanceOf(ClassicBloomFilter.class, factory.create()); + } + + @Test + public void shouldCreateClassicBloomFilterWithParameters() { + Map parameters = + Map.of( + BLOOM_FILTER_ADAPTIVE_KEY, "false", + CLASSIC_BLOOM_FILTER_NUM_ITEMS_KEY, "20000", + CLASSIC_BLOOM_FILTER_FPP_KEY, "0.02"); + BloomFilterFactory factory = BloomFilterFactory.of(parameters); + + assertInstanceOf(ClassicBloomFilter.class, factory.create()); + assertEquals(parameters, factory.getParameters()); + } + + @Test + public void deserializeAdaptiveBloomFilter() throws IOException { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + new AdaptiveBloomFilter( + DEFAULT_ADAPTIVE_NUMBER_CANDIDATE, + DEFAULT_CLASSIC_BLOOM_FILTER_FPP + ).writeTo(out); + ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray()); + + BloomFilterFactory factory = BloomFilterFactory.of(emptyMap()); + assertInstanceOf(AdaptiveBloomFilter.class, factory.deserialize(in)); + } + + @Test + public void deserializeClassicBloomFilter() throws IOException { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + new ClassicBloomFilter( + DEFAULT_CLASSIC_BLOOM_FILTER_NUM_ITEMS, + DEFAULT_CLASSIC_BLOOM_FILTER_FPP + ).writeTo(out); + ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray()); + + BloomFilterFactory factory = BloomFilterFactory.of(Map.of(BLOOM_FILTER_ADAPTIVE_KEY, "false")); + assertInstanceOf(ClassicBloomFilter.class, factory.deserialize(in)); + } +} \ No newline at end of file diff --git 
a/flint-core/src/test/java/org/opensearch/flint/core/field/bloomfilter/adaptive/AdaptiveBloomFilterTest.java b/flint-core/src/test/java/org/opensearch/flint/core/field/bloomfilter/adaptive/AdaptiveBloomFilterTest.java new file mode 100644 index 000000000..ad5948572 --- /dev/null +++ b/flint-core/src/test/java/org/opensearch/flint/core/field/bloomfilter/adaptive/AdaptiveBloomFilterTest.java @@ -0,0 +1,97 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.core.field.bloomfilter.adaptive; + +import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.opensearch.flint.core.field.bloomfilter.adaptive.AdaptiveBloomFilter.readFrom; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import org.junit.Test; +import org.opensearch.flint.core.field.bloomfilter.BloomFilter; + +public class AdaptiveBloomFilterTest { + + private final int numCandidates = 5; + + private final double fpp = 0.03; + + private final AdaptiveBloomFilter bloomFilter = new AdaptiveBloomFilter(numCandidates, fpp); + + @Test + public void shouldChooseBestCandidateAdaptively() { + // Insert 500 items should choose 1st candidate + for (int i = 0; i < 500; i++) { + bloomFilter.put(i); + } + assertEquals(1024, bloomFilter.bestCandidate().getExpectedNumItems()); + + // Insert 1000 (total 1500) items should choose 2nd candidate + for (int i = 500; i < 1500; i++) { + bloomFilter.put(i); + } + assertEquals(2048, bloomFilter.bestCandidate().getExpectedNumItems()); + + // Insert 4000 (total 5500) items should choose 4th candidate + for (int i = 1500; i < 5500; i++) { + bloomFilter.put(i); + } + assertEquals(8192, bloomFilter.bestCandidate().getExpectedNumItems()); + } + + @Test + public void shouldChooseLastCandidateForLargeCardinality() { + // Insert items more than last candidate's NDV 16384 + 
for (int i = 0; i < 20000; i++) { + bloomFilter.put(i); + } + + // Ensure that the last candidate is chosen due to the large cardinality + assertEquals(16384, bloomFilter.bestCandidate().getExpectedNumItems()); + } + + @Test + public void shouldBeTheSameAfterWriteToAndReadFrom() throws IOException { + // Insert some items to verify each candidate below + for (int i = 0; i < 10000; i++) { + bloomFilter.put(i); + } + + // Serialize and deserialize and assert the equality + ByteArrayOutputStream out = new ByteArrayOutputStream(); + bloomFilter.writeTo(out); + InputStream in = new ByteArrayInputStream(out.toByteArray()); + BloomFilter newBloomFilter = readFrom(numCandidates, in); + assertEquals(bloomFilter, newBloomFilter); + } + + @Test + public void shouldMergeTwoFiltersCorrectly() { + AdaptiveBloomFilter bloomFilter2 = new AdaptiveBloomFilter(numCandidates, fpp); + + // Insert items into the first filter + for (int i = 0; i < 1000; i++) { + bloomFilter.put(i); + } + + // Insert different items into the second filter + for (int i = 1000; i < 2000; i++) { + bloomFilter2.put(i); + } + + // Merge the second filter into the first one + bloomFilter.merge(bloomFilter2); + + // Check if the merged filter contains items from both filters + for (int i = 0; i < 2000; i++) { + assertTrue(bloomFilter.mightContain(i)); + } + assertEquals(2048, bloomFilter.bestCandidate().getExpectedNumItems()); + } +} \ No newline at end of file diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/bloomfilter/BloomFilterAgg.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/bloomfilter/BloomFilterAgg.scala index b40554335..4fb79d67d 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/bloomfilter/BloomFilterAgg.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/bloomfilter/BloomFilterAgg.scala @@ -7,8 +7,8 @@ package 
org.opensearch.flint.spark.skipping.bloomfilter import java.io.{ByteArrayInputStream, ByteArrayOutputStream} -import org.opensearch.flint.core.field.bloomfilter.BloomFilter -import org.opensearch.flint.core.field.bloomfilter.classic.ClassicBloomFilter +import org.opensearch.flint.core.field.bloomfilter.{BloomFilter, BloomFilterFactory} +import org.opensearch.flint.core.field.bloomfilter.adaptive.AdaptiveBloomFilter import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Expression @@ -26,21 +26,18 @@ import org.apache.spark.sql.types.{BinaryType, DataType} * * @param child * child expression that generate Long values for creating a bloom filter - * @param expectedNumItems - * expected maximum unique number of items - * @param fpp - * false positive probability + * @param bloomFilterFactory + * BloomFilter factory */ case class BloomFilterAgg( child: Expression, - expectedNumItems: Int, - fpp: Double, + bloomFilterFactory: BloomFilterFactory, override val mutableAggBufferOffset: Int, override val inputAggBufferOffset: Int) extends TypedImperativeAggregate[BloomFilter] { - def this(child: Expression, expectedNumItems: Int, fpp: Double) = { - this(child, expectedNumItems, fpp, 0, 0) + def this(child: Expression, bloomFilterFactory: BloomFilterFactory) = { + this(child, bloomFilterFactory, 0, 0) } override def nullable: Boolean = true @@ -50,7 +47,7 @@ case class BloomFilterAgg( override def children: Seq[Expression] = Seq(child) override def createAggregationBuffer(): BloomFilter = { - new ClassicBloomFilter(expectedNumItems, fpp) + bloomFilterFactory.create() } override def update(buffer: BloomFilter, inputRow: InternalRow): BloomFilter = { @@ -72,7 +69,14 @@ case class BloomFilterAgg( // There's no set bit in the Bloom filter and hence no not-null value is processed. 
return null } - serialize(buffer) + + // Serialize BloomFilter (best candidate if adaptive) as final result + buffer match { + case filter: AdaptiveBloomFilter => + serialize(filter.bestCandidate().getBloomFilter) + case _ => + serialize(buffer) + } } override def serialize(buffer: BloomFilter): Array[Byte] = { @@ -88,7 +92,7 @@ case class BloomFilterAgg( override def deserialize(bytes: Array[Byte]): BloomFilter = { val in = new ByteArrayInputStream(bytes) - val bloomFilter = ClassicBloomFilter.readFrom(in) + val bloomFilter = bloomFilterFactory.deserialize(in) in.close() bloomFilter } diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/bloomfilter/BloomFilterSkippingStrategy.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/bloomfilter/BloomFilterSkippingStrategy.scala index 3e91e235b..791f7cc9f 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/bloomfilter/BloomFilterSkippingStrategy.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/bloomfilter/BloomFilterSkippingStrategy.scala @@ -5,10 +5,12 @@ package org.opensearch.flint.spark.skipping.bloomfilter +import scala.collection.JavaConverters.{mapAsJavaMapConverter, mapAsScalaMapConverter} + +import org.opensearch.flint.core.field.bloomfilter.BloomFilterFactory import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.IndexColumnExtractor import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.{BLOOM_FILTER, SkippingKind} -import org.opensearch.flint.spark.skipping.bloomfilter.BloomFilterSkippingStrategy.{CLASSIC_BLOOM_FILTER_FPP_KEY, CLASSIC_BLOOM_FILTER_NUM_ITEMS_KEY, DEFAULT_CLASSIC_BLOOM_FILTER_FPP, DEFAULT_CLASSIC_BLOOM_FILTER_NUM_ITEMS} import org.apache.spark.sql.Column import org.apache.spark.sql.catalyst.expressions.{EqualTo, Expression, Literal} @@ -24,17 
+26,15 @@ case class BloomFilterSkippingStrategy( params: Map[String, String] = Map.empty) extends FlintSparkSkippingStrategy { - override val parameters: Map[String, String] = { - Map( - CLASSIC_BLOOM_FILTER_NUM_ITEMS_KEY -> expectedNumItems.toString, - CLASSIC_BLOOM_FILTER_FPP_KEY -> fpp.toString) - } + private val bloomFilterFactory: BloomFilterFactory = BloomFilterFactory.of(params.asJava) + + override val parameters: Map[String, String] = bloomFilterFactory.getParameters.asScala.toMap override def outputSchema(): Map[String, String] = Map(columnName -> "binary") override def getAggregators: Seq[Expression] = { Seq( - new BloomFilterAgg(xxhash64(col(columnName)).expr, expectedNumItems, fpp) + new BloomFilterAgg(xxhash64(col(columnName)).expr, bloomFilterFactory) .toAggregateExpression() ) // TODO: use xxhash64() for now } @@ -47,33 +47,4 @@ case class BloomFilterSkippingStrategy( case _ => None } } - - private def expectedNumItems: Int = { - params - .get(CLASSIC_BLOOM_FILTER_NUM_ITEMS_KEY) - .map(_.toInt) - .getOrElse(DEFAULT_CLASSIC_BLOOM_FILTER_NUM_ITEMS) - } - - private def fpp: Double = { - params - .get(CLASSIC_BLOOM_FILTER_FPP_KEY) - .map(_.toDouble) - .getOrElse(DEFAULT_CLASSIC_BLOOM_FILTER_FPP) - } -} - -object BloomFilterSkippingStrategy { - - /** - * Expected number of unique items key and default value. - */ - val CLASSIC_BLOOM_FILTER_NUM_ITEMS_KEY = "num_items" - val DEFAULT_CLASSIC_BLOOM_FILTER_NUM_ITEMS = 10000 - - /** - * False positive probability (FPP) key and default value. 
- */ - val CLASSIC_BLOOM_FILTER_FPP_KEY = "fpp" - val DEFAULT_CLASSIC_BLOOM_FILTER_FPP = 0.03 } diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/bloomfilter/BloomFilterSkippingStrategySuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/bloomfilter/BloomFilterSkippingStrategySuite.scala index c3db6fb1d..1a623e3e8 100644 --- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/bloomfilter/BloomFilterSkippingStrategySuite.scala +++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/bloomfilter/BloomFilterSkippingStrategySuite.scala @@ -5,8 +5,8 @@ package org.opensearch.flint.spark.skipping.bloomfilter +import org.opensearch.flint.core.field.bloomfilter.BloomFilterFactory._ import org.opensearch.flint.spark.skipping.{FlintSparkSkippingStrategy, FlintSparkSkippingStrategySuite} -import org.opensearch.flint.spark.skipping.bloomfilter.BloomFilterSkippingStrategy.{CLASSIC_BLOOM_FILTER_FPP_KEY, CLASSIC_BLOOM_FILTER_NUM_ITEMS_KEY, DEFAULT_CLASSIC_BLOOM_FILTER_FPP, DEFAULT_CLASSIC_BLOOM_FILTER_NUM_ITEMS} import org.scalatest.matchers.should.Matchers import org.apache.spark.FlintSuite @@ -22,7 +22,8 @@ class BloomFilterSkippingStrategySuite test("parameters") { strategy.parameters shouldBe Map( - CLASSIC_BLOOM_FILTER_NUM_ITEMS_KEY -> DEFAULT_CLASSIC_BLOOM_FILTER_NUM_ITEMS.toString, + BLOOM_FILTER_ADAPTIVE_KEY -> DEFAULT_BLOOM_FILTER_ADAPTIVE.toString, + ADAPTIVE_NUMBER_CANDIDATE_KEY -> DEFAULT_ADAPTIVE_NUMBER_CANDIDATE.toString, CLASSIC_BLOOM_FILTER_FPP_KEY -> DEFAULT_CLASSIC_BLOOM_FILTER_FPP.toString) } } diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexITSuite.scala index b3cbe4cae..2b0907a5f 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexITSuite.scala +++ 
b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexITSuite.scala @@ -91,7 +91,8 @@ class FlintSparkSkippingIndexITSuite extends FlintSparkSuite { | { | "kind": "BLOOM_FILTER", | "parameters": { - | "num_items": "10000", + | "adaptive": "true", + | "num_candidates": "10", | "fpp": "0.03" | }, | "columnName": "name",