diff --git a/dataset/src/main/scala/frameless/TypedDataset.scala b/dataset/src/main/scala/frameless/TypedDataset.scala
index fe4bce84..845e52c9 100644
--- a/dataset/src/main/scala/frameless/TypedDataset.scala
+++ b/dataset/src/main/scala/frameless/TypedDataset.scala
@@ -1195,12 +1195,60 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val
       i6: Tupler.Aux[OutModValues, Out],
       i7: TypedEncoder[Out]
     ): TypedDataset[Out] = {
-      val df = dataset.toDF()
       import org.apache.spark.sql.functions.{explode => sparkExplode}
 
+      val df = dataset.toDF()
+
+      val trans =
+        df
+          .withColumn(column.value.name, sparkExplode(df(column.value.name)))
+          .as[Out](TypedExpressionEncoder[Out])
+      TypedDataset.create[Out](trans)
+  }
+
+  /**
+    * Explodes a single column at a time. It only compiles if the type of column supports this operation.
+    *
+    * @example
+    *
+    * {{{
+    *   case class X(i: Int, j: Map[Int, Int])
+    *   case class Y(i: Int, j: (Int, Int))
+    *
+    *   val f: TypedDataset[X] = ???
+    *   val fNew: TypedDataset[Y] = f.explodeMap('j).as[Y]
+    * }}}
+    * @param column the column we wish to explode
+    */
+  def explodeMap[A, B, V[_, _], TRep <: HList, OutMod <: HList, OutModValues <: HList, Out]
+    (column: Witness.Lt[Symbol])
+    (implicit
+      i0: TypedColumn.Exists[T, column.T, V[A, B]],
+      i1: TypedEncoder[A],
+      i2: TypedEncoder[B],
+      i3: LabelledGeneric.Aux[T, TRep],
+      i4: Modifier.Aux[TRep, column.T, V[A,B], (A, B), OutMod],
+      i5: Values.Aux[OutMod, OutModValues],
+      i6: Tupler.Aux[OutModValues, Out],
+      i7: TypedEncoder[Out]
+    ): TypedDataset[Out] = {
+      import org.apache.spark.sql.functions.{explode => sparkExplode, struct => sparkStruct, col => sparkCol}
+      val df = dataset.toDF()
+      // preserve the original list of columns
+      val columns = df.columns.toSeq.map(sparkCol)
+      // select all columns, all original columns and [key, value] columns appeared after the map explode
+      // .withColumn(column.value.name, sparkExplode(df(column.value.name))) in this case would not work
+      // since the map explode produces two columns
+      val exploded = df.select(sparkCol("*"), sparkExplode(df(column.value.name)))
       val trans =
-        df.withColumn(column.value.name,
-          sparkExplode(df(column.value.name))).as[Out](TypedExpressionEncoder[Out])
+        exploded
+          // map explode explodes it into [key, value] columns
+          // the only way to put it into a column is to create a struct
+          // TODO: handle org.apache.spark.sql.AnalysisException: Reference 'key / value' is ambiguous, could be: key / value, key / value
+          .withColumn(column.value.name, sparkStruct(exploded("key"), exploded("value")))
+          // selecting only original columns, we don't need [key, value] columns left in the DataFrame after the map explode
+          .select(columns: _*)
+          .as[Out](TypedExpressionEncoder[Out])
       TypedDataset.create[Out](trans)
   }
 
diff --git a/dataset/src/test/scala/frameless/ExplodeTests.scala b/dataset/src/test/scala/frameless/ExplodeTests.scala
index 7e503158..205b4903 100644
--- a/dataset/src/test/scala/frameless/ExplodeTests.scala
+++ b/dataset/src/test/scala/frameless/ExplodeTests.scala
@@ -7,7 +7,6 @@ import org.scalacheck.Prop._
 
 import scala.reflect.ClassTag
 
-
 class ExplodeTests extends TypedDatasetSuite {
   test("simple explode test") {
     val ds = TypedDataset.create(Seq((1,Array(1,2))))
@@ -48,4 +47,34 @@ class ExplodeTests extends TypedDatasetSuite {
     check(forAll(prop[Int] _))
     check(forAll(prop[String] _))
   }
+
+  test("explode on maps") {
+    def prop[A: TypedEncoder: ClassTag, B: TypedEncoder: ClassTag](xs: List[X1[Map[A, B]]]): Prop = {
+      val tds = TypedDataset.create(xs)
+
+      val framelessResults = tds.explodeMap('a).collect().run().toVector
+      val scalaResults = xs.flatMap(_.a.toList).map(t => Tuple1(Tuple2(t._1, t._2))).toVector
+
+      framelessResults ?= scalaResults
+    }
+
+    check(forAll(prop[Long, String] _))
+    check(forAll(prop[Int, Long] _))
+    check(forAll(prop[String, Int] _))
+  }
+
+  test("explode on maps preserving other columns") {
+    def prop[K: TypedEncoder: ClassTag, A: TypedEncoder: ClassTag, B: TypedEncoder: ClassTag](xs: List[X2[K, Map[A, B]]]): Prop = {
+      val tds = TypedDataset.create(xs)
+
+      val framelessResults = tds.explodeMap('b).collect().run().toVector
+      val scalaResults = xs.flatMap { x2 => x2.b.toList.map((x2.a, _)) }.toVector
+
+      framelessResults ?= scalaResults
+    }
+
+    check(forAll(prop[Int, Long, String] _))
+    check(forAll(prop[String, Int, Long] _))
+    check(forAll(prop[Long, String, Int] _))
+  }
 }