diff --git a/dataset/src/main/scala/frameless/TypedDataset.scala b/dataset/src/main/scala/frameless/TypedDataset.scala
index 99903753e..c9f33411d 100644
--- a/dataset/src/main/scala/frameless/TypedDataset.scala
+++ b/dataset/src/main/scala/frameless/TypedDataset.scala
@@ -1199,29 +1199,56 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val
       import org.apache.spark.sql.functions.{explode => sparkExplode}
 
       val trans =
-        df.withColumn(column.value.name,
-          sparkExplode(df(column.value.name))).as[Out](TypedExpressionEncoder[Out])
+        df
+          .withColumn(column.value.name, sparkExplode(df(column.value.name)))
+          .as[Out](TypedExpressionEncoder[Out])
       TypedDataset.create[Out](trans)
   }
 
-  def explodeMap[A, B, TRep <: HList, OutMod <: HList, OutModValues <: HList, Out]
+  /**
+    * Explodes a single Map column at a time. It only compiles if the type of the column supports this operation.
+    *
+    * @example
+    *
+    * {{{
+    *   case class X(i: Int, j: Map[Int, Int])
+    *   case class Y(i: Int, j: (Int, Int))
+    *
+    *   val f: TypedDataset[X] = ???
+    *   val fNew: TypedDataset[Y] = f.explodeMap('j).as[Y]
+    * }}}
+    * @param column the column we wish to explode
+    */
+  def explodeMap[A, B, V[_, _], TRep <: HList, OutMod <: HList, OutModValues <: HList, Out]
     (column: Witness.Lt[Symbol])
     (implicit
-      i0: TypedColumn.Exists[T, column.T, Map[A, B]],
+      i0: TypedColumn.Exists[T, column.T, V[A, B]],
       i1: TypedEncoder[A],
       i2: TypedEncoder[B],
       i3: LabelledGeneric.Aux[T, TRep],
-      i4: Modifier.Aux[TRep, column.T, Map[A,B], Tuple2[A,B], OutMod],
+      i4: Modifier.Aux[TRep, column.T, V[A,B], (A, B), OutMod],
       i5: Values.Aux[OutMod, OutModValues],
       i6: Tupler.Aux[OutModValues, Out],
      i7: TypedEncoder[Out]
    ): TypedDataset[Out] = {
      val df = dataset.toDF()
-      import org.apache.spark.sql.functions.{explode => sparkExplode}
-
+      import org.apache.spark.sql.functions.{explode => sparkExplode, struct => sparkStruct, col => sparkCol}
+
+      // preserve the original list of columns
+      val columns = df.columns.toSeq.map(sparkCol)
+      // select all original columns plus the [key, value] columns produced by the map explode;
+      // .withColumn(column.value.name, sparkExplode(df(column.value.name))) would not work here,
+      // since exploding a map produces two columns
+      val exploded = df.select(sparkCol("*"), sparkExplode(df(column.value.name)))
      val trans =
-        df.withColumn(column.value.name,
-          sparkExplode(df(column.value.name))).as[Out](TypedExpressionEncoder[Out])
+        exploded
+          // the map explode yields separate [key, value] columns; the only way to put them
+          // back into a single column is to wrap them in a struct
+          // TODO: handle org.apache.spark.sql.AnalysisException: Reference 'key / value' is ambiguous, could be: key / value, key / value
+          .withColumn(column.value.name, sparkStruct(exploded("key"), exploded("value")))
+          // keep only the original columns; the extra [key, value] columns are no longer needed
+          .select(columns: _*)
+          .as[Out](TypedExpressionEncoder[Out])
      TypedDataset.create[Out](trans)
   }
diff --git a/dataset/src/test/scala/frameless/ExplodeTests.scala b/dataset/src/test/scala/frameless/ExplodeTests.scala
index 90a8ec05f..205b49038 100644
--- a/dataset/src/test/scala/frameless/ExplodeTests.scala
+++ b/dataset/src/test/scala/frameless/ExplodeTests.scala
@@ -7,7 +7,6 @@ import org.scalacheck.Prop._
 
 import scala.reflect.ClassTag
 
-
 class ExplodeTests extends TypedDatasetSuite {
   test("simple explode test") {
     val ds = TypedDataset.create(Seq((1,Array(1,2))))
@@ -49,18 +48,33 @@ class ExplodeTests extends TypedDatasetSuite {
     check(forAll(prop[String] _))
   }
 
-  test("explode on Maps") {
-    def prop[A: TypedEncoder: ClassTag](xs: List[X1[Map[A, A]]]): Prop = {
+  test("explode on maps") {
+    def prop[A: TypedEncoder: ClassTag, B: TypedEncoder: ClassTag](xs: List[X1[Map[A, B]]]): Prop = {
       val tds = TypedDataset.create(xs)
 
       val framelessResults = tds.explodeMap('a).collect().run().toVector
-      val scalaResults = xs.flatMap(_.a.toList).map(t => Tuple2(t._1, t._2)).toVector
+      val scalaResults = xs.flatMap(_.a.toList).map(t => Tuple1(Tuple2(t._1, t._2))).toVector
 
       framelessResults ?= scalaResults
     }
 
-    check(forAll(prop[Long] _))
-    check(forAll(prop[Int] _))
-    check(forAll(prop[String] _))
+    check(forAll(prop[Long, String] _))
+    check(forAll(prop[Int, Long] _))
+    check(forAll(prop[String, Int] _))
+  }
+
+  test("explode on maps preserving other columns") {
+    def prop[K: TypedEncoder: ClassTag, A: TypedEncoder: ClassTag, B: TypedEncoder: ClassTag](xs: List[X2[K, Map[A, B]]]): Prop = {
+      val tds = TypedDataset.create(xs)
+
+      val framelessResults = tds.explodeMap('b).collect().run().toVector
+      val scalaResults = xs.flatMap { x2 => x2.b.toList.map((x2.a, _)) }.toVector
+
+      framelessResults ?= scalaResults
+    }
+
+    check(forAll(prop[Int, Long, String] _))
+    check(forAll(prop[String, Int, Long] _))
+    check(forAll(prop[Long, String, Int] _))
   }
 }
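For reviewers, here is a minimal standalone sketch of the DataFrame-level steps the new explodeMap performs, written against plain Spark SQL rather than the frameless API. The object name, SparkSession setup, column names and sample data below are illustrative assumptions, not part of this patch.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, explode, struct}

object ExplodeMapSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("explodeMap-sketch").getOrCreate()
    import spark.implicits._

    // one ordinary column plus one Map column, mirroring X2[K, Map[A, B]] in the new test
    val df = Seq((1, Map("a" -> 10, "b" -> 20))).toDF("i", "j")

    // exploding a Map column yields two extra columns named `key` and `value`
    val exploded = df.select(col("*"), explode(df("j")))

    // pack key/value back into a single struct column under the original name,
    // then keep only the original columns
    val result = exploded
      .withColumn("j", struct(exploded("key"), exploded("value")))
      .select(df.columns.map(col): _*)

    // one output row per map entry, with `j` now holding a (key, value) struct
    result.show()
    spark.stop()
  }
}

The struct step is what lets the exploded pair be read back as a single (A, B) tuple column by TypedExpressionEncoder[Out], which is why the single withColumn call used by explode is not enough for the map case.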