Skip to content

Commit

Permalink
Add utility asJoinColValue to allow join on Value class dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
cchantep committed Mar 14, 2023
1 parent a27b04b commit 258702c
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 6 deletions.
32 changes: 27 additions & 5 deletions dataset/src/main/scala/frameless/TypedDataset.scala
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val
def count[F[_]]()(implicit F: SparkDelay[F]): F[Long] =
F.delay(dataset.count())

/** Returns `TypedColumn` of type `A` given its name.
/** Returns `TypedColumn` of type `A` given its name (alias for `col`).
*
* {{{
* tf('id)
Expand Down Expand Up @@ -250,7 +250,7 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val
def col[A](x: Function1[T, A]): TypedColumn[T, A] =
macro TypedColumnMacroImpl.applyImpl[T, A]

/** Projects the entire TypedDataset[T] into a single column of type TypedColumn[T,T]
/** Projects the entire `TypedDataset[T]` into a single column of type `TypedColumn[T,T]`.
* {{{
* ts: TypedDataset[Foo] = ...
* ts.select(ts.asCol, ts.asCol): TypedDataset[(Foo,Foo)]
Expand All @@ -261,12 +261,28 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val
case StructType(_) =>
val allColumns: Array[Column] = dataset.columns.map(dataset.col)
org.apache.spark.sql.functions.struct(allColumns.toSeq: _*)

case _ =>
dataset.col(dataset.columns.head)
}

new TypedColumn[T,T](projectedColumn)
}

/** References the entire `TypedDataset[T]` as a single column
* of type `TypedColumn[T,T]` so it can be used in a join operation.
*
* {{{
* def nameJoin(ds1: TypedDataset[Person], ds2: TypedDataset[Name]) =
* ds1.joinLeftSemi(ds2)(ds1.col('name) === ds2.asJoinColValue)
* }}}
*/
def asJoinColValue(implicit i0: IsValueClass[T]): TypedColumn[T, T] = {
import _root_.frameless.syntax._

dataset.col("value").typedColumn
}

object colMany extends SingletonProductArgs {
def applyProduct[U <: HList, Out](columns: U)
(implicit
Expand Down Expand Up @@ -635,11 +651,13 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val
def joinInner[U](other: TypedDataset[U])(condition: TypedColumn[T with U, Boolean])
(implicit e: TypedEncoder[(T, U)]): TypedDataset[(T, U)] = {
import FramelessInternals._

val leftPlan = logicalPlan(dataset)
val rightPlan = logicalPlan(other.dataset)
val join = disambiguate(Join(leftPlan, rightPlan, Inner, Some(condition.expr), JoinHint.NONE))
val joinedPlan = joinPlan(dataset, join, leftPlan, rightPlan)
val joinedDs = mkDataset(dataset.sqlContext, joinedPlan, TypedExpressionEncoder[(T, U)])

TypedDataset.create[(T, U)](joinedDs)
}

Expand Down Expand Up @@ -1291,8 +1309,9 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val
i7: TypedEncoder[Out]
): TypedDataset[Out] = {
val df = dataset.toDF()
val trans =
df.filter(df(column.value.name).isNotNull).as[Out](TypedExpressionEncoder[Out])
val trans = df.filter(df(column.value.name).isNotNull).
as[Out](TypedExpressionEncoder[Out])

TypedDataset.create[Out](trans)
}
}
Expand All @@ -1304,6 +1323,7 @@ object TypedDataset {
sqlContext: SparkSession
): TypedDataset[A] = {
val dataset = sqlContext.createDataset(data)(TypedExpressionEncoder[A])

TypedDataset.create[A](dataset)
}

Expand All @@ -1313,10 +1333,12 @@ object TypedDataset {
sqlContext: SparkSession
): TypedDataset[A] = {
val dataset = sqlContext.createDataset(data)(TypedExpressionEncoder[A])

TypedDataset.create[A](dataset)
}

def create[A: TypedEncoder](dataset: Dataset[A]): TypedDataset[A] = createUnsafe(dataset.toDF())
def create[A: TypedEncoder](dataset: Dataset[A]): TypedDataset[A] =
createUnsafe(dataset.toDF())

/**
* Creates a [[frameless.TypedDataset]] from a Spark [[org.apache.spark.sql.DataFrame]].
Expand Down
18 changes: 17 additions & 1 deletion dataset/src/test/scala/frameless/ColumnTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ final class ColumnTests extends TypedDatasetSuite with Matchers {
test("asCol with numeric operators") {
def prop(a: Seq[Long]) = {
val ds: TypedDataset[Long] = TypedDataset.create(a)
val (first,second) = (2L,5L)
val (first, second) = (2L, 5L)
val frameless: Seq[(Long, Long, Long)] =
ds.select(ds.asCol, ds.asCol+first, ds.asCol*second).collect().run()

Expand All @@ -402,6 +402,22 @@ final class ColumnTests extends TypedDatasetSuite with Matchers {
check(forAll(prop _))
}

test("reference Value class so can join on") {
import RecordEncoderTests.{ Name, Person }

val bar = new Name("bar")

val ds1: TypedDataset[Person] = TypedDataset.create(
Seq(Person(bar, 23), Person(new Name("foo"), 11)))

val ds2: TypedDataset[Name] =
TypedDataset.create(Seq(new Name("lorem"), bar))

val joined = ds1.joinLeftSemi(ds2)(ds1.col('name) === ds2.asJoinColValue)

joined.collect().run() shouldEqual Seq(Person(bar, 23))
}

test("unary_!") {
val ds = TypedDataset.create((true, false) :: Nil)

Expand Down

0 comments on commit 258702c

Please sign in to comment.