diff --git a/dataset/src/main/scala/frameless/functions/package.scala b/dataset/src/main/scala/frameless/functions/package.scala index b60c5008a..acb96d483 100644 --- a/dataset/src/main/scala/frameless/functions/package.scala +++ b/dataset/src/main/scala/frameless/functions/package.scala @@ -1,10 +1,16 @@ package frameless +import scala.reflect.ClassTag + +import shapeless._ +import shapeless.labelled.FieldType +import shapeless.ops.hlist.IsHCons +import shapeless.ops.record.Values + import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.expressions.Literal -package object functions - extends Udf with UnaryFunctions with LowPriorityFunctions { +package object functions extends Udf with UnaryFunctions { object aggregate extends AggregateFunctions object nonAggregate extends NonAggregateFunctions @@ -16,9 +22,6 @@ package object functions */ def litAggr[A: TypedEncoder, T](value: A): TypedAggregate[T, A] = new TypedAggregate[T,A](lit(value).expr) -} - -private[frameless] sealed trait LowPriorityFunctions { /** Creates a [[frameless.TypedColumn]] of literal value. If A is to be encoded using an Injection make * sure the injection instance is in scope. @@ -47,4 +50,40 @@ private[frameless] sealed trait LowPriorityFunctions { ) } } + + /** Creates a [[frameless.TypedColumn]] of literal value + * for a Value class `A`. + * + * @tparam A the value class + * @tparam T the row type + */ + def litValue[A <: AnyVal, T, G <: ::[_, HNil], H <: ::[_ <: FieldType[_ <: Symbol, _], HNil], V, VS <: HList](value: A)( + implicit + i0: LabelledGeneric.Aux[A, G], + i1: DropUnitValues.Aux[G, H], + i2: IsHCons.Aux[H, _ <: FieldType[_, V], HNil], + i3: Values.Aux[H, VS], + i4: IsHCons.Aux[VS, V, HNil], + i5: TypedEncoder[V], + i6: ClassTag[A] + ): TypedColumn[T, A] = { + val expr = { + val field: H = i1(i0.to(value)) + val v: V = i4.head(i3(field)) + + new Literal(v, i5.jvmRepr) + } + + implicit val enc: TypedEncoder[A] = + RecordFieldEncoder.valueClass[A, G, H, V].encoder + + new TypedColumn[T, A]( + new Lit( + dataType = i5.catalystRepr, + nullable = i5.nullable, + toCatalyst = i5.toCatalyst(expr).genCode(_), + show = value.toString + ) + ) + } } diff --git a/dataset/src/test/scala/frameless/LitTests.scala b/dataset/src/test/scala/frameless/LitTests.scala index dc6b55cc1..ba7d522de 100644 --- a/dataset/src/test/scala/frameless/LitTests.scala +++ b/dataset/src/test/scala/frameless/LitTests.scala @@ -1,10 +1,10 @@ package frameless -import frameless.functions.lit +import frameless.functions.{ lit, litValue } import org.scalatest.matchers.should.Matchers -import org.scalacheck.{ Arbitrary, Gen, Prop }, Prop._ +import org.scalacheck.Prop, Prop._ import RecordEncoderTests.Name @@ -51,11 +51,6 @@ class LitTests extends TypedDatasetSuite with Matchers { check(prop[Food] _) - implicit def nameArb: Arbitrary[Name] = - Arbitrary(Gen.alphaStr.map(new Name(_))) - - check(prop[Name] _) - // doesn't work, object has to be serializable // check(prop[frameless.LocalDateTime] _) } @@ -70,7 +65,7 @@ class LitTests extends TypedDatasetSuite with Matchers { val lorem = new Name("Lorem") - ds.withColumnReplaced('name, lit(lorem)). + ds.withColumnReplaced('name, litValue(lorem)). collect.run() shouldBe initial.map(_.copy(name = lorem)) } diff --git a/dataset/src/test/scala/frameless/RecordEncoderTests.scala b/dataset/src/test/scala/frameless/RecordEncoderTests.scala index 5e4a905f4..2a75b42e6 100644 --- a/dataset/src/test/scala/frameless/RecordEncoderTests.scala +++ b/dataset/src/test/scala/frameless/RecordEncoderTests.scala @@ -151,7 +151,10 @@ class RecordEncoderTests extends TypedDatasetSuite with Matchers { safeDs.collect.run() shouldBe expected - // TODO: withColumnReplaced + val lorem = new Name("Lorem") + + safeDs.withColumnReplaced('name, functions.litValue(lorem)). + collect.run() shouldBe expected.map(_.copy(name = lorem)) } test("Case class with value class as optional field") {