Skip to content

Commit

Permalink
Introduce functions.litValue
Browse files Browse the repository at this point in the history
  • Loading branch information
cchantep committed Aug 29, 2021
1 parent 657d0a3 commit 1f44ab9
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 14 deletions.
49 changes: 44 additions & 5 deletions dataset/src/main/scala/frameless/functions/package.scala
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
package frameless

import scala.reflect.ClassTag

import shapeless._
import shapeless.labelled.FieldType
import shapeless.ops.hlist.IsHCons
import shapeless.ops.record.Values

import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.expressions.Literal

package object functions
extends Udf with UnaryFunctions with LowPriorityFunctions {
package object functions extends Udf with UnaryFunctions {

object aggregate extends AggregateFunctions
object nonAggregate extends NonAggregateFunctions
Expand All @@ -16,9 +22,6 @@ package object functions
*/
def litAggr[A: TypedEncoder, T](value: A): TypedAggregate[T, A] =
new TypedAggregate[T,A](lit(value).expr)
}

private[frameless] sealed trait LowPriorityFunctions {

/** Creates a [[frameless.TypedColumn]] of literal value. If A is to be encoded using an Injection make
* sure the injection instance is in scope.
Expand Down Expand Up @@ -47,4 +50,40 @@ private[frameless] sealed trait LowPriorityFunctions {
)
}
}

/** Creates a [[frameless.TypedColumn]] of literal value
* for a Value class `A`.
*
* @tparam A the value class
* @tparam T the row type
*/
def litValue[A <: AnyVal, T, G <: ::[_, HNil], H <: ::[_ <: FieldType[_ <: Symbol, _], HNil], V, VS <: HList](value: A)(
implicit
i0: LabelledGeneric.Aux[A, G],
i1: DropUnitValues.Aux[G, H],
i2: IsHCons.Aux[H, _ <: FieldType[_, V], HNil],
i3: Values.Aux[H, VS],
i4: IsHCons.Aux[VS, V, HNil],
i5: TypedEncoder[V],
i6: ClassTag[A]
): TypedColumn[T, A] = {
val expr = {
val field: H = i1(i0.to(value))
val v: V = i4.head(i3(field))

new Literal(v, i5.jvmRepr)
}

implicit val enc: TypedEncoder[A] =
RecordFieldEncoder.valueClass[A, G, H, V].encoder

new TypedColumn[T, A](
new Lit(
dataType = i5.catalystRepr,
nullable = i5.nullable,
toCatalyst = i5.toCatalyst(expr).genCode(_),
show = value.toString
)
)
}
}
11 changes: 3 additions & 8 deletions dataset/src/test/scala/frameless/LitTests.scala
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
package frameless

import frameless.functions.lit
import frameless.functions.{ lit, litValue }

import org.scalatest.matchers.should.Matchers

import org.scalacheck.{ Arbitrary, Gen, Prop }, Prop._
import org.scalacheck.Prop, Prop._

import RecordEncoderTests.Name

Expand Down Expand Up @@ -51,11 +51,6 @@ class LitTests extends TypedDatasetSuite with Matchers {

check(prop[Food] _)

implicit def nameArb: Arbitrary[Name] =
Arbitrary(Gen.alphaStr.map(new Name(_)))

check(prop[Name] _)

// doesn't work, object has to be serializable
// check(prop[frameless.LocalDateTime] _)
}
Expand All @@ -70,7 +65,7 @@ class LitTests extends TypedDatasetSuite with Matchers {

val lorem = new Name("Lorem")

ds.withColumnReplaced('name, lit(lorem)).
ds.withColumnReplaced('name, litValue(lorem)).
collect.run() shouldBe initial.map(_.copy(name = lorem))
}

Expand Down
5 changes: 4 additions & 1 deletion dataset/src/test/scala/frameless/RecordEncoderTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,10 @@ class RecordEncoderTests extends TypedDatasetSuite with Matchers {

safeDs.collect.run() shouldBe expected

// TODO: withColumnReplaced
val lorem = new Name("Lorem")

safeDs.withColumnReplaced('name, functions.litValue(lorem)).
collect.run() shouldBe expected.map(_.copy(name = lorem))
}

test("Case class with value class as optional field") {
Expand Down

0 comments on commit 1f44ab9

Please sign in to comment.