Implement BloomFilter query rewrite (without pushdown optimization) #248

Merged
2 changes: 1 addition & 1 deletion docs/index.md
@@ -25,7 +25,7 @@ Please see the following example in which Index Building Logic and Query Rewrite
| Partition | CREATE SKIPPING INDEX<br>ON alb_logs<br> (<br>&nbsp;&nbsp;year PARTITION,<br>&nbsp;&nbsp;month PARTITION,<br>&nbsp;&nbsp;day PARTITION,<br>&nbsp;&nbsp;hour PARTITION<br>) | INSERT INTO flint_alb_logs_skipping_index<br>SELECT<br>&nbsp;&nbsp;FIRST(year) AS year,<br>&nbsp;&nbsp;FIRST(month) AS month,<br>&nbsp;&nbsp;FIRST(day) AS day,<br>&nbsp;&nbsp;FIRST(hour) AS hour,<br>&nbsp;&nbsp;input_file_name() AS file_path<br>FROM alb_logs<br>GROUP BY<br>&nbsp;&nbsp;input_file_name() | SELECT *<br>FROM alb_logs<br>WHERE year = 2023 AND month = 4<br>=><br>SELECT *<br>FROM alb_logs (input_files = <br>&nbsp;&nbsp;SELECT file_path<br>&nbsp;&nbsp;FROM flint_alb_logs_skipping_index<br>&nbsp;&nbsp;WHERE year = 2023 AND month = 4<br>)<br>WHERE year = 2023 AND month = 4 |
| ValueSet | CREATE SKIPPING INDEX<br>ON alb_logs<br> (<br>&nbsp;&nbsp;elb_status_code VALUE_SET<br>) | INSERT INTO flint_alb_logs_skipping_index<br>SELECT<br>&nbsp;&nbsp;COLLECT_SET(elb_status_code) AS elb_status_code,<br>&nbsp;&nbsp;input_file_name() AS file_path<br>FROM alb_logs<br>GROUP BY<br>&nbsp;&nbsp;input_file_name() | SELECT *<br>FROM alb_logs<br>WHERE elb_status_code = 404<br>=><br>SELECT *<br>FROM alb_logs (input_files = <br>&nbsp;&nbsp;SELECT file_path<br>&nbsp;&nbsp;FROM flint_alb_logs_skipping_index<br>&nbsp;&nbsp;WHERE ARRAY_CONTAINS(elb_status_code, 404)<br>)<br>WHERE elb_status_code = 404 |
| MinMax | CREATE SKIPPING INDEX<br>ON alb_logs<br> (<br>&nbsp;&nbsp;request_processing_time MIN_MAX<br>) | INSERT INTO flint_alb_logs_skipping_index<br>SELECT<br>&nbsp;&nbsp;MIN(request_processing_time) AS request_processing_time_min,<br>&nbsp;&nbsp;MAX(request_processing_time) AS request_processing_time_max,<br>&nbsp;&nbsp;input_file_name() AS file_path<br>FROM alb_logs<br>GROUP BY<br>&nbsp;&nbsp;input_file_name() | SELECT *<br>FROM alb_logs<br>WHERE request_processing_time = 100<br>=><br>SELECT *<br>FROM alb_logs (input_files = <br> SELECT file_path<br>&nbsp;&nbsp;FROM flint_alb_logs_skipping_index<br>&nbsp;&nbsp;WHERE request_processing_time_min <= 100<br>&nbsp;&nbsp;&nbsp;&nbsp;AND 100 <= request_processing_time_max<br>)<br>WHERE request_processing_time = 100 |
| BloomFilter | CREATE SKIPPING INDEX<br>ON alb_logs<br> (<br>&nbsp;&nbsp;client_ip BLOOM_FILTER<br>) | INSERT INTO flint_alb_logs_skipping_index<br>SELECT<br>&nbsp;&nbsp;BLOOM_FILTER_AGG(client_ip) AS client_ip,<br>&nbsp;&nbsp;input_file_name() AS file_path<br>FROM alb_logs<br>GROUP BY<br>&nbsp;&nbsp;input_file_name() | SELECT *<br>FROM alb_logs<br>WHERE client_ip = '127.0.0.1'<br>=><br>SELECT *<br>FROM alb_logs (input_files = <br>&nbsp;&nbsp;SELECT file_path<br>&nbsp;&nbsp;FROM flint_alb_logs_skipping_index<br>&nbsp;&nbsp;WHERE BLOOM_FILTER_MIGHT_CONTAIN(client_ip, '127.0.0.1') = true<br>)<br>WHERE client_ip = '127.0.0.1' |
| BloomFilter | CREATE SKIPPING INDEX<br>ON alb_logs<br> (<br>&nbsp;&nbsp;client_ip BLOOM_FILTER<br>) | INSERT INTO flint_alb_logs_skipping_index<br>SELECT<br>&nbsp;&nbsp;BLOOM_FILTER_AGG(client_ip) AS client_ip,<br>&nbsp;&nbsp;input_file_name() AS file_path<br>FROM alb_logs<br>GROUP BY<br>&nbsp;&nbsp;input_file_name() | SELECT *<br>FROM alb_logs<br>WHERE client_ip = '127.0.0.1'<br>=><br>SELECT *<br>FROM alb_logs (input_files = <br>&nbsp;&nbsp;SELECT file_path<br>&nbsp;&nbsp;FROM flint_alb_logs_skipping_index<br>&nbsp;&nbsp;WHERE BLOOM_FILTER_MIGHT_CONTAIN(client_ip, '127.0.0.1')<br>)<br>WHERE client_ip = '127.0.0.1' |
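To make the new BloomFilter row concrete, here is a minimal usage sketch (assumed setup: a source table `alb_logs` with a `client_ip` column and a Spark session with the Flint extension enabled); Flint applies the rewrite shown in the last column of the table automatically:

```scala
// Minimal sketch; `alb_logs` and the Flint-enabled SparkSession are assumptions here.
// 1. Define a bloom filter skipping index on client_ip (syntax from the table above).
spark.sql("CREATE SKIPPING INDEX ON alb_logs (client_ip BLOOM_FILTER)")

// 2. Once the index is populated, an equality filter on client_ip is rewritten so that
//    only source files whose bloom filter might contain the value are scanned.
spark.sql("SELECT * FROM alb_logs WHERE client_ip = '127.0.0.1'").show()
```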

### Flint Index Refresh

@@ -132,16 +132,23 @@ public void writeTo(OutputStream out) throws IOException {
* @param in input stream
* @return bloom filter
*/
Collaborator Author: need to try-catch as Spark codegen doesn't allow checked exception

-  public static BloomFilter readFrom(InputStream in) throws IOException {
-    DataInputStream dis = new DataInputStream(in);
-
-    int version = dis.readInt();
-    if (version != Version.V1.getVersionNumber()) {
-      throw new IOException("Unexpected Bloom filter version number (" + version + ")");
-    }
-    int numHashFunctions = dis.readInt();
-    BitArray bits = BitArray.readFrom(dis);
-    return new ClassicBloomFilter(bits, numHashFunctions);
+  public static BloomFilter readFrom(InputStream in) {
+    try {
+      DataInputStream dis = new DataInputStream(in);
+
+      // Check version compatibility
+      int version = dis.readInt();
+      if (version != Version.V1.getVersionNumber()) {
+        throw new IllegalStateException("Unexpected Bloom filter version number (" + version + ")");
+      }
+
+      // Read bloom filter content
+      int numHashFunctions = dis.readInt();
+      BitArray bits = BitArray.readFrom(dis);
+      return new ClassicBloomFilter(bits, numHashFunctions);
+    } catch (IOException e) {
+      throw new RuntimeException(e);
+    }
   }

private static int optimalNumOfHashFunctions(long n, long m) {
@@ -0,0 +1,125 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.spark.skipping.bloomfilter

import java.io.ByteArrayInputStream

import org.opensearch.flint.core.field.bloomfilter.classic.ClassicBloomFilter

import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.{BinaryComparison, Expression}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper
import org.apache.spark.sql.functions.{col, lit}
import org.apache.spark.sql.types._

/**
* Bloom filter function that returns the membership check result for values of `valueExpression`
* in the bloom filter represented by `bloomFilterExpression`.
*
* @param bloomFilterExpression
* binary expression that represents bloom filter data
* @param valueExpression
* Long value expression to be tested
*/
case class BloomFilterMightContain(bloomFilterExpression: Expression, valueExpression: Expression)
extends BinaryComparison {

override def nullable: Boolean = true

override def left: Expression = bloomFilterExpression

override def right: Expression = valueExpression

override def prettyName: String = "bloom_filter_might_contain"

override def dataType: DataType = BooleanType

override def symbol: String = "BLOOM_FILTER_MIGHT_CONTAIN"

override def checkInputDataTypes(): TypeCheckResult = {
(left.dataType, right.dataType) match {
case (BinaryType, NullType) | (NullType, LongType) | (NullType, NullType) |
(BinaryType, LongType) =>
TypeCheckResult.TypeCheckSuccess
case _ =>
TypeCheckResult.TypeCheckFailure(s"""
| Input to function $prettyName should be Binary expression followed by a Long value,
| but it's [${left.dataType.catalogString}, ${right.dataType.catalogString}].
| """.stripMargin)
}
}

override protected def withNewChildrenInternal(
newBloomFilterExpression: Expression,
newValueExpression: Expression): BloomFilterMightContain =
copy(bloomFilterExpression = newBloomFilterExpression, valueExpression = newValueExpression)

override def eval(input: InternalRow): Any = {
val value = valueExpression.eval(input)
if (value == null) {
null
Collaborator: Why is the eval result null here? Should bloomFilter.test(null) return false?

Collaborator Author: Following Spark SQL NULL semantics, NULL is ignored in BloomFilterAgg, so NULL is returned for bloom_filter_might_contain(clientip, NULL). Reference: https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BloomFilterMightContain.scala#L100

Collaborator Author: As I understand it, the case discussed here only arises for WHERE clientip = NULL. The concern is that it could be rewritten to bloom_filter_might_contain(clientip, NULL), which would skip source files by mistake.

I did some testing and found that col = NULL is optimized away by Spark directly, because it always returns an empty result:

spark-sql> EXPLAIN SELECT `@timestamp`, request FROM ds_tables.http_logs WHERE clientip = null;
== Physical Plan ==
LocalTableScan <empty>, [@timestamp#5, request#7]

} else {
val bytes = bloomFilterExpression.eval(input).asInstanceOf[Array[Byte]]
val bloomFilter = ClassicBloomFilter.readFrom(new ByteArrayInputStream(bytes))
bloomFilter.mightContain(value.asInstanceOf[Long])
}
}

/**
* Generate expression code for Spark codegen execution. Sample result code:
* ```
* boolean filter_isNull_0 = true;
* boolean filter_value_0 = false;
* if (!right_isNull) {
* filter_isNull_0 = false;
* filter_value_0 =
* org.opensearch.flint.core.field.bloomfilter.classic.ClassicBloomFilter.readFrom(
* new java.io.ByteArrayInputStream(left_value)
* ).mightContain(right_value);
* }
* ```
*/
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val leftGen = left.genCode(ctx)
val rightGen = right.genCode(ctx)
val bloomFilterEncoder = classOf[ClassicBloomFilter].getCanonicalName.stripSuffix("$")
val bf = s"$bloomFilterEncoder.readFrom(new java.io.ByteArrayInputStream(${leftGen.value}))"
val result = s"$bf.mightContain(${rightGen.value})"
val resultCode =
s"""
|if (!(${rightGen.isNull})) {
| ${leftGen.code}
| ${ev.isNull} = false;
| ${ev.value} = $result;
|}
""".stripMargin
ev.copy(code = code"""
${rightGen.code}
boolean ${ev.isNull} = true;
boolean ${ev.value} = false;
$resultCode""")
}
}

object BloomFilterMightContain {

/**
* Generate bloom filter might contain function given the bloom filter column and value.
*
* @param colName
* column name
* @param value
* value
* @return
* bloom filter might contain expression
*/
def bloom_filter_might_contain(colName: String, value: Any): Column = {
new Column(BloomFilterMightContain(col(colName).expr, lit(value).expr))
}
}
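A quick illustration of the NULL semantics discussed in the review thread above: this minimal, test-style sketch (not part of the PR, assuming BloomFilterMightContain is on the classpath) checks that the expression evaluates to NULL rather than false when the value under test is NULL.

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.types.{BinaryType, LongType}

// A NULL test value short-circuits eval before the bloom filter bytes are read,
// so the result is null (unknown) rather than false (definitely absent).
val expr = BloomFilterMightContain(
  Literal.create(null, BinaryType), // bloom filter bytes; irrelevant when the value is NULL
  Literal.create(null, LongType))   // NULL value under test
assert(expr.eval(InternalRow.empty) == null)
```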
@@ -6,10 +6,12 @@
package org.opensearch.flint.spark.skipping.bloomfilter

import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy
import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.IndexColumnExtractor
import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.{BLOOM_FILTER, SkippingKind}
import org.opensearch.flint.spark.skipping.bloomfilter.BloomFilterSkippingStrategy.{CLASSIC_BLOOM_FILTER_FPP_KEY, CLASSIC_BLOOM_FILTER_NUM_ITEMS_KEY, DEFAULT_CLASSIC_BLOOM_FILTER_FPP, DEFAULT_CLASSIC_BLOOM_FILTER_NUM_ITEMS}

import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.expressions.{EqualTo, Expression, Literal}
import org.apache.spark.sql.functions.{col, xxhash64}

/**
@@ -37,7 +39,14 @@ case class BloomFilterSkippingStrategy(
) // TODO: use xxhash64() for now
}

override def rewritePredicate(predicate: Expression): Option[Expression] = None
override def rewritePredicate(predicate: Expression): Option[Expression] = {
val IndexColumn = IndexColumnExtractor(columnName)
predicate match {
case EqualTo(IndexColumn(indexCol), value: Literal) =>
Some(BloomFilterMightContain(indexCol.expr, xxhash64(new Column(value)).expr))
case _ => None
}
}

private def expectedNumItems: Int = {
params
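To illustrate the intent of the new rewritePredicate above, here is a minimal sketch (not part of the PR; `indexDf` standing for the skipping index data is an assumption) of the filter the rewrite produces for an equality predicate, using the same xxhash64 hashing noted in the aggregation TODO above:

```scala
import org.apache.spark.sql.functions.{lit, xxhash64}
import org.opensearch.flint.spark.skipping.bloomfilter.BloomFilterMightContain.bloom_filter_might_contain

// Source predicate `client_ip = '127.0.0.1'` is turned into a bloom filter membership test,
// with the literal hashed the same way the indexed values were (xxhash64).
val indexFilter = bloom_filter_might_contain("client_ip", xxhash64(lit("127.0.0.1")))

// Applied to the skipping index data, it keeps only the files that might contain the value:
// indexDf.filter(indexFilter).select("file_path")
```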
@@ -0,0 +1,45 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.spark.skipping.bloomfilter

import org.apache.spark.FlintSuite
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult._
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.types.{BinaryType, DoubleType, LongType, StringType}
import org.apache.spark.unsafe.types.UTF8String

class BloomFilterMightContainSuite extends FlintSuite {

test("checkInputDataTypes should succeed for valid input types") {
val binaryExpression = Literal(Array[Byte](1, 2, 3), BinaryType)
val longExpression = Literal(42L, LongType)

val bloomFilterMightContain = BloomFilterMightContain(binaryExpression, longExpression)
assert(bloomFilterMightContain.checkInputDataTypes() == TypeCheckSuccess)
}

test("checkInputDataTypes should succeed for valid input types with nulls") {
val binaryExpression = Literal.create(null, BinaryType)
val longExpression = Literal.create(null, LongType)

val bloomFilterMightContain = BloomFilterMightContain(binaryExpression, longExpression)
assert(bloomFilterMightContain.checkInputDataTypes() == TypeCheckSuccess)
}

test("checkInputDataTypes should fail for invalid input types") {
val stringExpression = Literal(UTF8String.fromString("invalid"), StringType)
val doubleExpression = Literal(3.14, DoubleType)

val bloomFilterMightContain = BloomFilterMightContain(stringExpression, doubleExpression)
val expectedErrorMsg =
s"""
| Input to function bloom_filter_might_contain should be Binary expression followed by a Long value,
| but it's [${stringExpression.dataType.catalogString}, ${doubleExpression.dataType.catalogString}].
| """.stripMargin

assert(bloomFilterMightContain.checkInputDataTypes() == TypeCheckFailure(expectedErrorMsg))
}
}
@@ -12,17 +12,20 @@ import org.opensearch.flint.core.FlintVersion.current
import org.opensearch.flint.spark.FlintSparkIndex.ID_COLUMN
import org.opensearch.flint.spark.skipping.FlintSparkSkippingFileIndex
import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.getSkippingIndexName
import org.opensearch.flint.spark.skipping.bloomfilter.BloomFilterMightContain.bloom_filter_might_contain
import org.opensearch.index.query.QueryBuilders
import org.opensearch.index.reindex.DeleteByQueryRequest
import org.scalatest.matchers.{Matcher, MatchResult}
import org.scalatest.matchers.must.Matchers._
import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper

import org.apache.spark.sql.{Column, Row}
import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode
import org.apache.spark.sql.execution.{FileSourceScanExec, SparkPlan}
import org.apache.spark.sql.execution.datasources.HadoopFsRelation
import org.apache.spark.sql.flint.config.FlintSparkConf._
import org.apache.spark.sql.functions.{col, isnull}
import org.apache.spark.sql.functions.{col, isnull, lit, xxhash64}
import org.apache.spark.sql.internal.SQLConf

class FlintSparkSkippingIndexITSuite extends FlintSparkSuite {

@@ -390,7 +393,24 @@ class FlintSparkSkippingIndexITSuite extends FlintSparkSuite {
// Assert index data
flint.queryIndex(testIndex).collect() should have size 2

// TODO: Assert query rewrite result
// Assert query result and rewrite
def assertQueryRewrite(): Unit = {
val query = sql(s"SELECT name FROM $testTable WHERE age = 50")
checkAnswer(query, Row("Java"))
query.queryExecution.executedPlan should
useFlintSparkSkippingFileIndex(
hasIndexFilter(bloom_filter_might_contain("age", xxhash64(lit(50)))))
}

// Test expression with codegen enabled by default
assertQueryRewrite()

// Test expression evaluation with codegen disabled
withSQLConf(
SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false",
SQLConf.CODEGEN_FACTORY_MODE.key -> CodegenObjectFactoryMode.NO_CODEGEN.toString) {
assertQueryRewrite()
}
}

test("should rewrite applicable query with table name without database specified") {