Skip to content

Commit

Permalink
Fix explodeMap implementation and add explodeMap tests
Browse files Browse the repository at this point in the history
  • Loading branch information
pomadchin committed Nov 8, 2021
1 parent 6e981c9 commit ecdd727
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 16 deletions.
45 changes: 36 additions & 9 deletions dataset/src/main/scala/frameless/TypedDataset.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1199,29 +1199,56 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val
import org.apache.spark.sql.functions.{explode => sparkExplode}

val trans =
df.withColumn(column.value.name,
sparkExplode(df(column.value.name))).as[Out](TypedExpressionEncoder[Out])
df
.withColumn(column.value.name, sparkExplode(df(column.value.name)))
.as[Out](TypedExpressionEncoder[Out])
TypedDataset.create[Out](trans)
}

def explodeMap[A, B, TRep <: HList, OutMod <: HList, OutModValues <: HList, Out]
/**
* Explodes a single column at a time. It only compiles if the type of column supports this operation.
*
* @example
*
* {{{
* case class X(i: Int, j: Map[Int, Int])
* case class Y(i: Int, j: (Int, Int))
*
* val f: TypedDataset[X] = ???
* val fNew: TypedDataset[Y] = f.explodeMap('j).as[Y]
* }}}
* @param column the column we wish to explode
*/
def explodeMap[A, B, V[_, _], TRep <: HList, OutMod <: HList, OutModValues <: HList, Out]
(column: Witness.Lt[Symbol])
(implicit
i0: TypedColumn.Exists[T, column.T, Map[A, B]],
i0: TypedColumn.Exists[T, column.T, V[A, B]],
i1: TypedEncoder[A],
i2: TypedEncoder[B],
i3: LabelledGeneric.Aux[T, TRep],
i4: Modifier.Aux[TRep, column.T, Map[A,B], Tuple2[A,B], OutMod],
i4: Modifier.Aux[TRep, column.T, V[A,B], (A, B), OutMod],
i5: Values.Aux[OutMod, OutModValues],
i6: Tupler.Aux[OutModValues, Out],
i7: TypedEncoder[Out]
): TypedDataset[Out] = {
val df = dataset.toDF()
import org.apache.spark.sql.functions.{explode => sparkExplode}

import org.apache.spark.sql.functions.{explode => sparkExplode, struct => sparkStruct, col => sparkCol}

// preserve the original list of columns
val columns = df.columns.toSeq.map(sparkCol)
// select all columns, all original columns and [key, value] columns appeared after the map explode
// .withColumn(column.value.name, sparkExplode(df(column.value.name))) in this case would not work
// since the map explode produces two columns
val exploded = df.select(sparkCol("*"), sparkExplode(df(column.value.name)))
val trans =
df.withColumn(column.value.name,
sparkExplode(df(column.value.name))).as[Out](TypedExpressionEncoder[Out])
exploded
// map explode explodes it into [key, value] columns
// the only way to put it into a column is to create a struct
// TODO: handle org.apache.spark.sql.AnalysisException: Reference 'key / value' is ambiguous, could be: key / value, key / value
.withColumn(column.value.name, sparkStruct(exploded("key"), exploded("value")))
// selecting only original columns, we don't need [key, value] columns left in the DataFrame after the map explode
.select(columns: _*)
.as[Out](TypedExpressionEncoder[Out])
TypedDataset.create[Out](trans)
}

Expand Down
28 changes: 21 additions & 7 deletions dataset/src/test/scala/frameless/ExplodeTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import org.scalacheck.Prop._

import scala.reflect.ClassTag


class ExplodeTests extends TypedDatasetSuite {
test("simple explode test") {
val ds = TypedDataset.create(Seq((1,Array(1,2))))
Expand Down Expand Up @@ -49,18 +48,33 @@ class ExplodeTests extends TypedDatasetSuite {
check(forAll(prop[String] _))
}

test("explode on Maps") {
def prop[A: TypedEncoder: ClassTag](xs: List[X1[Map[A, A]]]): Prop = {
test("explode on maps") {
def prop[A: TypedEncoder: ClassTag, B: TypedEncoder: ClassTag](xs: List[X1[Map[A, B]]]): Prop = {
val tds = TypedDataset.create(xs)

val framelessResults = tds.explodeMap('a).collect().run().toVector
val scalaResults = xs.flatMap(_.a.toList).map(t => Tuple2(t._1, t._2)).toVector
val scalaResults = xs.flatMap(_.a.toList).map(t => Tuple1(Tuple2(t._1, t._2))).toVector

framelessResults ?= scalaResults
}

check(forAll(prop[Long] _))
check(forAll(prop[Int] _))
check(forAll(prop[String] _))
check(forAll(prop[Long, String] _))
check(forAll(prop[Int, Long] _))
check(forAll(prop[String, Int] _))
}

test("explode on maps preserving other columns") {
def prop[K: TypedEncoder: ClassTag, A: TypedEncoder: ClassTag, B: TypedEncoder: ClassTag](xs: List[X2[K, Map[A, B]]]): Prop = {
val tds = TypedDataset.create(xs)

val framelessResults = tds.explodeMap('b).collect().run().toVector
val scalaResults = xs.flatMap { x2 => x2.b.toList.map((x2.a, _)) }.toVector

framelessResults ?= scalaResults
}

check(forAll(prop[Int, Long, String] _))
check(forAll(prop[String, Int, Long] _))
check(forAll(prop[Long, String, Int] _))
}
}

0 comments on commit ecdd727

Please sign in to comment.