diff --git a/dataset/src/main/scala/frameless/TypedDataset.scala b/dataset/src/main/scala/frameless/TypedDataset.scala index 7f459684f..add2170b2 100644 --- a/dataset/src/main/scala/frameless/TypedDataset.scala +++ b/dataset/src/main/scala/frameless/TypedDataset.scala @@ -210,7 +210,7 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val def count[F[_]]()(implicit F: SparkDelay[F]): F[Long] = F.delay(dataset.count()) - /** Returns `TypedColumn` of type `A` given its name. + /** Returns `TypedColumn` of type `A` given its name (alias for `col`). * * {{{ * tf('id) @@ -250,7 +250,7 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val def col[A](x: Function1[T, A]): TypedColumn[T, A] = macro TypedColumnMacroImpl.applyImpl[T, A] - /** Projects the entire TypedDataset[T] into a single column of type TypedColumn[T,T] + /** Projects the entire `TypedDataset[T]` into a single column of type `TypedColumn[T,T]`. * {{{ * ts: TypedDataset[Foo] = ... * ts.select(ts.asCol, ts.asCol): TypedDataset[(Foo,Foo)] @@ -261,12 +261,28 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val case StructType(_) => val allColumns: Array[Column] = dataset.columns.map(dataset.col) org.apache.spark.sql.functions.struct(allColumns.toSeq: _*) + case _ => dataset.col(dataset.columns.head) } + new TypedColumn[T,T](projectedColumn) } + /** References the entire `TypedDataset[T]` as a single column + * of type `TypedColumn[T,T]` so it can be used in a join operation. + * + * {{{ + * def nameJoin(ds1: TypedDataset[Person], ds2: TypedDataset[Name]) = + * ds1.joinLeftSemi(ds2)(ds1.col('name) === ds2.asJoinColValue) + * }}} + */ + def asJoinColValue(implicit i0: IsValueClass[T]): TypedColumn[T, T] = { + import _root_.frameless.syntax._ + + dataset.col("value").typedColumn + } + object colMany extends SingletonProductArgs { def applyProduct[U <: HList, Out](columns: U) (implicit @@ -635,11 +651,13 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val def joinInner[U](other: TypedDataset[U])(condition: TypedColumn[T with U, Boolean]) (implicit e: TypedEncoder[(T, U)]): TypedDataset[(T, U)] = { import FramelessInternals._ + val leftPlan = logicalPlan(dataset) val rightPlan = logicalPlan(other.dataset) val join = disambiguate(Join(leftPlan, rightPlan, Inner, Some(condition.expr), JoinHint.NONE)) val joinedPlan = joinPlan(dataset, join, leftPlan, rightPlan) val joinedDs = mkDataset(dataset.sqlContext, joinedPlan, TypedExpressionEncoder[(T, U)]) + TypedDataset.create[(T, U)](joinedDs) } @@ -1291,8 +1309,9 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val i7: TypedEncoder[Out] ): TypedDataset[Out] = { val df = dataset.toDF() - val trans = - df.filter(df(column.value.name).isNotNull).as[Out](TypedExpressionEncoder[Out]) + val trans = df.filter(df(column.value.name).isNotNull). + as[Out](TypedExpressionEncoder[Out]) + TypedDataset.create[Out](trans) } } @@ -1304,6 +1323,7 @@ object TypedDataset { sqlContext: SparkSession ): TypedDataset[A] = { val dataset = sqlContext.createDataset(data)(TypedExpressionEncoder[A]) + TypedDataset.create[A](dataset) } @@ -1313,10 +1333,12 @@ object TypedDataset { sqlContext: SparkSession ): TypedDataset[A] = { val dataset = sqlContext.createDataset(data)(TypedExpressionEncoder[A]) + TypedDataset.create[A](dataset) } - def create[A: TypedEncoder](dataset: Dataset[A]): TypedDataset[A] = createUnsafe(dataset.toDF()) + def create[A: TypedEncoder](dataset: Dataset[A]): TypedDataset[A] = + createUnsafe(dataset.toDF()) /** * Creates a [[frameless.TypedDataset]] from a Spark [[org.apache.spark.sql.DataFrame]]. diff --git a/dataset/src/test/scala/frameless/ColumnTests.scala b/dataset/src/test/scala/frameless/ColumnTests.scala index 3dea3e568..9f1af57a5 100644 --- a/dataset/src/test/scala/frameless/ColumnTests.scala +++ b/dataset/src/test/scala/frameless/ColumnTests.scala @@ -389,7 +389,7 @@ final class ColumnTests extends TypedDatasetSuite with Matchers { test("asCol with numeric operators") { def prop(a: Seq[Long]) = { val ds: TypedDataset[Long] = TypedDataset.create(a) - val (first,second) = (2L,5L) + val (first, second) = (2L, 5L) val frameless: Seq[(Long, Long, Long)] = ds.select(ds.asCol, ds.asCol+first, ds.asCol*second).collect().run() @@ -402,6 +402,22 @@ final class ColumnTests extends TypedDatasetSuite with Matchers { check(forAll(prop _)) } + test("reference Value class so can join on") { + import RecordEncoderTests.{ Name, Person } + + val bar = new Name("bar") + + val ds1: TypedDataset[Person] = TypedDataset.create( + Seq(Person(bar, 23), Person(new Name("foo"), 11))) + + val ds2: TypedDataset[Name] = + TypedDataset.create(Seq(new Name("lorem"), bar)) + + val joined = ds1.joinLeftSemi(ds2)(ds1.col('name) === ds2.asJoinColValue) + + joined.collect().run() shouldEqual Seq(Person(bar, 23)) + } + test("unary_!") { val ds = TypedDataset.create((true, false) :: Nil)