[SPARK-49630][SS] Add flatten option to process collection types with state data source reader

### What changes were proposed in this pull request?
Add a `flattenCollectionTypes` option that controls how the state data source reader emits collection state types (e.g. list and map state): flattened to one entry per row, or packed into a single collection-valued row.

### Why are the changes needed?
The changes are needed so that collection entries can be processed row by row when there is not enough memory to fit an entire collection inside a single row.

### Does this PR introduce _any_ user-facing change?
Yes

Users can provide the following query option:
```
        val stateReaderDf = spark.read
          .format("statestore")
          .option(StateSourceOptions.PATH, <state_checkpoint_loc>)
          .option(StateSourceOptions.STATE_VAR_NAME, <state_var_name>)
          .option(StateSourceOptions.FLATTEN_COLLECTION_TYPES, <true | false>)
          .load()
```
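
When `FLATTEN_COLLECTION_TYPES` is true (the default), each entry of a collection state variable is returned as its own row; when false, all entries for a key are packed into a single collection-valued row. A minimal sketch of the difference, assuming a list state variable named `listState` (the variable name and the output column names mentioned in the comments are illustrative, not part of this PR):
```
// Flattened (default): one output row per list entry, e.g. columns like (key, list_element).
val flattenedDf = spark.read
  .format("statestore")
  .option(StateSourceOptions.PATH, "<state_checkpoint_loc>")
  .option(StateSourceOptions.STATE_VAR_NAME, "listState")
  .option(StateSourceOptions.FLATTEN_COLLECTION_TYPES, "true")
  .load()

// Not flattened: one output row per key, with all entries packed into a single
// collection-valued column, e.g. (key, list_value). This requires the whole
// collection for a key to fit inside one row.
val packedDf = spark.read
  .format("statestore")
  .option(StateSourceOptions.PATH, "<state_checkpoint_loc>")
  .option(StateSourceOptions.STATE_VAR_NAME, "listState")
  .option(StateSourceOptions.FLATTEN_COLLECTION_TYPES, "false")
  .load()
```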

### How was this patch tested?
Added unit tests

```
[info] Run completed in 1 minute, 10 seconds.
[info] Total number of tests run: 12
[info] Suites: completed 1, aborted 0
[info] Tests: succeeded 12, failed 0, canceled 0, ignored 0, pending 0
[info] All tests passed.
```
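
For reference, a rough sketch of one kind of check such tests could cover — rejecting a non-boolean value for the new option (the test name, the `checkpointDir` helper, and the caught exception type below are assumptions, not the actual test code in this PR; only the error message text comes from the change itself):
```
test("state data source - invalid flattenCollectionTypes value fails") {
  // Hypothetical sketch: "checkpointDir" is assumed to point at a valid
  // transformWithState checkpoint produced by an earlier streaming query.
  val ex = intercept[Exception] {
    spark.read
      .format("statestore")
      .option(StateSourceOptions.PATH, checkpointDir.getAbsolutePath)
      .option(StateSourceOptions.STATE_VAR_NAME, "listState")
      .option(StateSourceOptions.FLATTEN_COLLECTION_TYPES, "invalid")
      .load()
  }
  assert(ex.getMessage.contains("Boolean value is expected"))
}
```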

### Was this patch authored or co-authored using generative AI tooling?
No

Closes apache#48110 from anishshri-db/task/SPARK-49630.

Authored-by: Anish Shrigondekar <anish.shrigondekar@databricks.com>
Signed-off-by: Jungtaek Lim <kabhwan.opensource@gmail.com>
anishshri-db authored and HeartSaVioR committed Sep 24, 2024
1 parent 982028e commit 73d6bd7
Showing 6 changed files with 320 additions and 133 deletions.
@@ -303,13 +303,15 @@ case class StateSourceOptions(
readChangeFeed: Boolean,
fromSnapshotOptions: Option[FromSnapshotOptions],
readChangeFeedOptions: Option[ReadChangeFeedOptions],
stateVarName: Option[String]) {
stateVarName: Option[String],
flattenCollectionTypes: Boolean) {
def stateCheckpointLocation: Path = new Path(resolvedCpLocation, DIR_NAME_STATE)

override def toString: String = {
var desc = s"StateSourceOptions(checkpointLocation=$resolvedCpLocation, batchId=$batchId, " +
s"operatorId=$operatorId, storeName=$storeName, joinSide=$joinSide, " +
s"stateVarName=${stateVarName.getOrElse("None")}"
s"stateVarName=${stateVarName.getOrElse("None")}, +" +
s"flattenCollectionTypes=$flattenCollectionTypes"
if (fromSnapshotOptions.isDefined) {
desc += s", snapshotStartBatchId=${fromSnapshotOptions.get.snapshotStartBatchId}"
desc += s", snapshotPartitionId=${fromSnapshotOptions.get.snapshotPartitionId}"
@@ -334,6 +336,7 @@ object StateSourceOptions extends DataSourceOptions {
val CHANGE_START_BATCH_ID = newOption("changeStartBatchId")
val CHANGE_END_BATCH_ID = newOption("changeEndBatchId")
val STATE_VAR_NAME = newOption("stateVarName")
val FLATTEN_COLLECTION_TYPES = newOption("flattenCollectionTypes")

object JoinSideValues extends Enumeration {
type JoinSideValues = Value
@@ -374,6 +377,15 @@ object StateSourceOptions extends DataSourceOptions {
val stateVarName = Option(options.get(STATE_VAR_NAME))
.map(_.trim)

val flattenCollectionTypes = try {
Option(options.get(FLATTEN_COLLECTION_TYPES))
.map(_.toBoolean).getOrElse(true)
} catch {
case _: IllegalArgumentException =>
throw StateDataSourceErrors.invalidOptionValue(FLATTEN_COLLECTION_TYPES,
"Boolean value is expected")
}

val joinSide = try {
Option(options.get(JOIN_SIDE))
.map(JoinSideValues.withName).getOrElse(JoinSideValues.none)
@@ -477,7 +489,8 @@ object StateSourceOptions extends DataSourceOptions {

StateSourceOptions(
resolvedCpLocation, batchId.get, operatorId, storeName, joinSide,
readChangeFeed, fromSnapshotOptions, readChangeFeedOptions, stateVarName)
readChangeFeed, fromSnapshotOptions, readChangeFeedOptions, stateVarName,
flattenCollectionTypes)
}

private def resolvedCheckpointLocation(
@@ -19,7 +19,6 @@ package org.apache.spark.sql.execution.datasources.v2.state
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow}
import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil
import org.apache.spark.sql.execution.streaming.{StateVariableType, TransformWithStateVariableInfo}
@@ -75,9 +74,11 @@ abstract class StatePartitionReaderBase(
StructType(Array(StructField("__dummy__", NullType)))

protected val keySchema = {
if (!SchemaUtil.isMapStateVariable(stateVariableInfoOpt)) {
if (SchemaUtil.checkVariableType(stateVariableInfoOpt, StateVariableType.MapState)) {
SchemaUtil.getCompositeKeySchema(schema, partition.sourceOptions)
} else {
SchemaUtil.getSchemaAsDataType(schema, "key").asInstanceOf[StructType]
} else SchemaUtil.getCompositeKeySchema(schema)
}
}

protected val valueSchema = if (stateVariableInfoOpt.isDefined) {
@@ -98,12 +99,8 @@ abstract class StatePartitionReaderBase(
false
}

val useMultipleValuesPerKey = if (stateVariableInfoOpt.isDefined &&
stateVariableInfoOpt.get.stateVariableType == StateVariableType.ListState) {
true
} else {
false
}
val useMultipleValuesPerKey = SchemaUtil.checkVariableType(stateVariableInfoOpt,
StateVariableType.ListState)

val provider = StateStoreProvider.createAndInit(
stateStoreProviderId, keySchema, valueSchema, keyStateEncoderSpec,
@@ -149,7 +146,7 @@ abstract class StatePartitionReaderBase(

/**
* An implementation of [[StatePartitionReaderBase]] for the normal mode of State Data
* Source. It reads the the state at a particular batchId.
* Source. It reads the state at a particular batchId.
*/
class StatePartitionReader(
storeConf: StateStoreConf,
@@ -181,41 +178,17 @@ class StatePartitionReader(
override lazy val iter: Iterator[InternalRow] = {
val stateVarName = stateVariableInfoOpt
.map(_.stateName).getOrElse(StateStore.DEFAULT_COL_FAMILY_NAME)
if (SchemaUtil.isMapStateVariable(stateVariableInfoOpt)) {
SchemaUtil.unifyMapStateRowPair(
store.iterator(stateVarName), keySchema, partition.partition)

if (stateVariableInfoOpt.isDefined) {
val stateVariableInfo = stateVariableInfoOpt.get
val stateVarType = stateVariableInfo.stateVariableType
SchemaUtil.processStateEntries(stateVarType, stateVarName, store,
keySchema, partition.partition, partition.sourceOptions)
} else {
store
.iterator(stateVarName)
.map { pair =>
stateVariableInfoOpt match {
case Some(stateVarInfo) =>
val stateVarType = stateVarInfo.stateVariableType

stateVarType match {
case StateVariableType.ValueState =>
SchemaUtil.unifyStateRowPair((pair.key, pair.value), partition.partition)

case StateVariableType.ListState =>
val key = pair.key
val result = store.valuesIterator(key, stateVarName)
var unsafeRowArr: Seq[UnsafeRow] = Seq.empty
result.foreach { entry =>
unsafeRowArr = unsafeRowArr :+ entry.copy()
}
// convert the list of values to array type
val arrData = new GenericArrayData(unsafeRowArr.toArray)
SchemaUtil.unifyStateRowPairWithMultipleValues((pair.key, arrData),
partition.partition)

case _ =>
throw new IllegalStateException(
s"Unsupported state variable type: $stateVarType")
}

case None =>
SchemaUtil.unifyStateRowPair((pair.key, pair.value), partition.partition)
}
SchemaUtil.unifyStateRowPair((pair.key, pair.value), partition.partition)
}
}
}
