From 53a96bf5a5b5614ac7089ff5954fc9c2545c44a0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Chantepie?= Date: Sun, 8 Aug 2021 16:15:21 +0200 Subject: [PATCH 1/9] Unit test about encoding Value class as dataframe row --- .../scala/frameless/RecordEncoderTests.scala | 23 ++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/dataset/src/test/scala/frameless/RecordEncoderTests.scala b/dataset/src/test/scala/frameless/RecordEncoderTests.scala index 44b40418..bcd8752e 100644 --- a/dataset/src/test/scala/frameless/RecordEncoderTests.scala +++ b/dataset/src/test/scala/frameless/RecordEncoderTests.scala @@ -1,7 +1,9 @@ package frameless import org.apache.spark.sql.Row -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{ + ObjectType, StringType, StructField, StructType +} import shapeless.{HList, LabelledGeneric} import shapeless.test.illTyped import org.scalatest.matchers.should.Matchers @@ -22,6 +24,8 @@ object RecordEncoderTests { case class C(b: B) } +class Name(val value: String) extends AnyVal + class RecordEncoderTests extends TypedDatasetSuite with Matchers { test("Unable to encode products made from units only") { illTyped("""TypedEncoder[UnitsOnly]""") @@ -76,4 +80,21 @@ class RecordEncoderTests extends TypedDatasetSuite with Matchers { val ds = session.createDataset(rdd)(TypedExpressionEncoder[C]) ds.collect.head shouldBe obj } + + test("Scalar value class") { + val encoder = TypedEncoder[Name] + + encoder.jvmRepr shouldBe ObjectType(classOf[Name]) + + encoder.catalystRepr shouldBe StructType( + Seq(StructField("value", StringType, false))) + + val sqlContext = session.sqlContext + import sqlContext.implicits._ + + TypedDataset + .createUnsafe[Name](Seq("Foo", "Bar").toDF)(encoder) + .collect().run() shouldBe Seq(new Name("Foo"), new Name("Bar")) + + } } From 3bf12ac0ed382e9acfd3c949f35258021aff2f94 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Chantepie?= Date: Sun, 8 Aug 2021 16:29:48 
+0200 Subject: [PATCH 2/9] Failing test for Value class as field --- .../scala/frameless/RecordEncoderTests.scala | 22 ++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/dataset/src/test/scala/frameless/RecordEncoderTests.scala b/dataset/src/test/scala/frameless/RecordEncoderTests.scala index bcd8752e..f35f791f 100644 --- a/dataset/src/test/scala/frameless/RecordEncoderTests.scala +++ b/dataset/src/test/scala/frameless/RecordEncoderTests.scala @@ -2,7 +2,7 @@ package frameless import org.apache.spark.sql.Row import org.apache.spark.sql.types.{ - ObjectType, StringType, StructField, StructType + IntegerType, ObjectType, StringType, StructField, StructType } import shapeless.{HList, LabelledGeneric} import shapeless.test.illTyped @@ -22,9 +22,11 @@ object RecordEncoderTests { case class A(x: Int) case class B(a: Seq[A]) case class C(b: B) -} -class Name(val value: String) extends AnyVal + class Name(val value: String) extends AnyVal + + case class Person(name: Name, age: Int) +} class RecordEncoderTests extends TypedDatasetSuite with Matchers { test("Unable to encode products made from units only") { @@ -82,6 +84,8 @@ class RecordEncoderTests extends TypedDatasetSuite with Matchers { } test("Scalar value class") { + import RecordEncoderTests._ + val encoder = TypedEncoder[Name] encoder.jvmRepr shouldBe ObjectType(classOf[Name]) @@ -97,4 +101,16 @@ class RecordEncoderTests extends TypedDatasetSuite with Matchers { .collect().run() shouldBe Seq(new Name("Foo"), new Name("Bar")) } + + test("Case class with value class field") { + import RecordEncoderTests._ + + val encoder = TypedEncoder[Person] + + encoder.jvmRepr shouldBe ObjectType(classOf[Person]) + + encoder.catalystRepr shouldBe StructType(Seq( + StructField("name", StringType, false), + StructField("age", IntegerType, false))) + } } From 34d631f9e7ea871e5653c276ed9f5df97fe0ec19 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Chantepie?= Date: Mon, 9 Aug 2021 13:58:40 +0200 
Subject: [PATCH 3/9] Failing test for Value class as optional field --- .../scala/frameless/RecordEncoderTests.scala | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/dataset/src/test/scala/frameless/RecordEncoderTests.scala b/dataset/src/test/scala/frameless/RecordEncoderTests.scala index f35f791f..f788b80b 100644 --- a/dataset/src/test/scala/frameless/RecordEncoderTests.scala +++ b/dataset/src/test/scala/frameless/RecordEncoderTests.scala @@ -2,7 +2,7 @@ package frameless import org.apache.spark.sql.Row import org.apache.spark.sql.types.{ - IntegerType, ObjectType, StringType, StructField, StructType + IntegerType, LongType, ObjectType, StringType, StructField, StructType } import shapeless.{HList, LabelledGeneric} import shapeless.test.illTyped @@ -26,6 +26,8 @@ object RecordEncoderTests { class Name(val value: String) extends AnyVal case class Person(name: Name, age: Int) + + case class User(id: Long, name: Option[Name]) } class RecordEncoderTests extends TypedDatasetSuite with Matchers { @@ -113,4 +115,19 @@ class RecordEncoderTests extends TypedDatasetSuite with Matchers { StructField("name", StringType, false), StructField("age", IntegerType, false))) } + + test("Case class with value class as optional field") { + import RecordEncoderTests._ + + // Encode as a Person field + val encoder = TypedEncoder[User] + + encoder.jvmRepr shouldBe ObjectType(classOf[User]) + + val expectedPersonStructType = StructType(Seq( + StructField("id", LongType, false), + StructField("name", StringType, true))) + + encoder.catalystRepr shouldBe expectedPersonStructType + } } From b40f1a54b890688aed6ef8b94460613980cd15dc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Chantepie?= Date: Sun, 8 Aug 2021 21:10:04 +0200 Subject: [PATCH 4/9] Support Value class in record encoder --- .../main/scala/frameless/RecordEncoder.scala | 103 ++++++++++++++++-- .../main/scala/frameless/TypedColumn.scala | 4 +- .../frameless/TypedExpressionEncoder.scala 
| 21 ++-- dataset/src/test/resources/log4j.properties | 3 + .../scala/frameless/RecordEncoderTests.scala | 99 ++++++++++++++++- 5 files changed, 206 insertions(+), 24 deletions(-) diff --git a/dataset/src/main/scala/frameless/RecordEncoder.scala b/dataset/src/main/scala/frameless/RecordEncoder.scala index fa5bd0c6..2f293e29 100644 --- a/dataset/src/main/scala/frameless/RecordEncoder.scala +++ b/dataset/src/main/scala/frameless/RecordEncoder.scala @@ -2,11 +2,15 @@ package frameless import org.apache.spark.sql.FramelessInternals import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, NewInstance} +import org.apache.spark.sql.catalyst.expressions.objects.{ + Invoke, NewInstance, UnwrapOption, WrapOption +} import org.apache.spark.sql.types._ + import shapeless._ import shapeless.labelled.FieldType import shapeless.ops.hlist.IsHCons +import shapeless.ops.record.Keys import scala.reflect.ClassTag @@ -25,24 +29,22 @@ object RecordEncoderFields { implicit def deriveRecordLast[K <: Symbol, H] (implicit key: Witness.Aux[K], - head: TypedEncoder[H] + head: RecordFieldEncoder[H] ): RecordEncoderFields[FieldType[K, H] :: HNil] = new RecordEncoderFields[FieldType[K, H] :: HNil] { - def value: List[RecordEncoderField] = RecordEncoderField(0, key.value.name, head) :: Nil + def value: List[RecordEncoderField] = fieldEncoder[K, H] :: Nil } implicit def deriveRecordCons[K <: Symbol, H, T <: HList] (implicit key: Witness.Aux[K], - head: TypedEncoder[H], + head: RecordFieldEncoder[H], tail: RecordEncoderFields[T] ): RecordEncoderFields[FieldType[K, H] :: T] = new RecordEncoderFields[FieldType[K, H] :: T] { - def value: List[RecordEncoderField] = { - val fieldName = key.value.name - val fieldEncoder = RecordEncoderField(0, fieldName, head) + def value: List[RecordEncoderField] = + fieldEncoder[K, H] :: tail.value.map(x => x.copy(ordinal = x.ordinal + 1)) + } - fieldEncoder :: tail.value.map(x => x.copy(ordinal = x.ordinal + 
1)) - } - } + private def fieldEncoder[K <: Symbol, H](implicit key: Witness.Aux[K], e: RecordFieldEncoder[H]): RecordEncoderField = RecordEncoderField(0, key.value.name, e.encoder) } /** @@ -156,6 +158,7 @@ class RecordEncoder[F, G <: HList, H <: HList] val createExpr = CreateNamedStruct(exprs) val nullExpr = Literal.create(null, createExpr.dataType) + If(IsNull(path), nullExpr, createExpr) } @@ -168,6 +171,86 @@ class RecordEncoder[F, G <: HList, H <: HList] val newExpr = NewInstance(classTag.runtimeClass, newArgs, jvmRepr, propagateNull = true) val nullExpr = Literal.create(null, jvmRepr) + If(IsNull(path), nullExpr, newExpr) } } + +final class RecordFieldEncoder[T]( + val encoder: TypedEncoder[T]) extends Serializable + +object RecordFieldEncoder extends RecordFieldEncoderLowPriority { + + /** + * @tparam F the value class + * @tparam G the single field of the value class + * @tparam H the single field of the value class (with guarantee it's not a `Unit` value) + * @tparam K the key type for the fields + * @tparam V the inner value type + */ + implicit def optionValueClass[F <: AnyVal, G <: ::[_, HNil], H <: ::[_ <: FieldType[_ <: Symbol, _], HNil], K <: Symbol, V, KS <: ::[_ <: Symbol, HNil]] + (implicit + i0: LabelledGeneric.Aux[F, G], + i1: DropUnitValues.Aux[G, H], + i2: IsHCons.Aux[H, _ <: FieldType[K, V], HNil], + i3: Keys.Aux[H, KS], + i4: IsHCons.Aux[KS, K, HNil], + i5: TypedEncoder[V], + i6: ClassTag[F] + ): RecordFieldEncoder[Option[F]] = RecordFieldEncoder[Option[F]](new TypedEncoder[Option[F]] { + val nullable = true + + val jvmRepr = ObjectType(classOf[Option[F]]) + + @inline def catalystRepr: DataType = i5.catalystRepr + + val innerJvmRepr = ObjectType(i6.runtimeClass) + + def fromCatalyst(path: Expression): Expression = { + val javaValue = i5.fromCatalyst(path) + val value = NewInstance(i6.runtimeClass, Seq(javaValue), innerJvmRepr) + + WrapOption(value, innerJvmRepr) + } + + @inline def toCatalyst(path: Expression): Expression = { + val value = 
UnwrapOption(innerJvmRepr, path) + + val fieldName = i4.head(i3()).name + val javaValue = Invoke(value, fieldName, i5.jvmRepr, Nil) + + i5.toCatalyst(javaValue) + } + }) + + /** + * @tparam F the value class + * @tparam G the single field of the value class + * @tparam H the single field of the value class (with guarantee it's not a `Unit` value) + * @tparam V the inner value type + */ + implicit def valueClass[F <: AnyVal, G <: ::[_, HNil], H <: ::[_, HNil], V] + (implicit + i0: LabelledGeneric.Aux[F, G], + i1: DropUnitValues.Aux[G, H], + i2: IsHCons.Aux[H, _ <: FieldType[_, V], HNil], + i3: TypedEncoder[V], + i4: ClassTag[F] + ): RecordFieldEncoder[F] = RecordFieldEncoder[F](new TypedEncoder[F] { + def nullable = i3.nullable + + def jvmRepr = i3.jvmRepr + + def catalystRepr: DataType = i3.catalystRepr + + def fromCatalyst(path: Expression): Expression = + i3.fromCatalyst(path) + + @inline def toCatalyst(path: Expression): Expression = + i3.toCatalyst(path) + }) +} + +private[frameless] sealed trait RecordFieldEncoderLowPriority { + implicit def apply[T](implicit e: TypedEncoder[T]): RecordFieldEncoder[T] = new RecordFieldEncoder[T](e) +} diff --git a/dataset/src/main/scala/frameless/TypedColumn.scala b/dataset/src/main/scala/frameless/TypedColumn.scala index a518d8c6..9890d9fd 100644 --- a/dataset/src/main/scala/frameless/TypedColumn.scala +++ b/dataset/src/main/scala/frameless/TypedColumn.scala @@ -49,6 +49,7 @@ sealed class TypedAggregate[T, U](expr: Expression)( } override def typed[W, U1: TypedEncoder](c: Column): TypedAggregate[W, U1] = c.typedAggregate + override def lit[U1: TypedEncoder](c: U1): TypedAggregate[T, U1] = litAggr(c) } @@ -835,7 +836,8 @@ abstract class AbstractTypedColumn[T, U] /** * Returns a nested column matching the field `symbol`. 
* - * @param V the type of the nested field + * @param symbol the field symbol + * @tparam V the type of the nested field */ def field[V](symbol: Witness.Lt[Symbol])(implicit i0: TypedColumn.Exists[U, symbol.T, V], diff --git a/dataset/src/main/scala/frameless/TypedExpressionEncoder.scala b/dataset/src/main/scala/frameless/TypedExpressionEncoder.scala index c8fbf88d..5b78cd29 100644 --- a/dataset/src/main/scala/frameless/TypedExpressionEncoder.scala +++ b/dataset/src/main/scala/frameless/TypedExpressionEncoder.scala @@ -10,30 +10,33 @@ object TypedExpressionEncoder { /** In Spark, DataFrame has always schema of StructType * - * DataFrames of primitive types become records with a single field called "value" set in ExpressionEncoder. + * DataFrames of primitive types become records + * with a single field called "value" set in ExpressionEncoder. */ - def targetStructType[A](encoder: TypedEncoder[A]): StructType = { + def targetStructType[A](encoder: TypedEncoder[A]): StructType = encoder.catalystRepr match { case x: StructType => if (encoder.nullable) StructType(x.fields.map(_.copy(nullable = true))) else x + case dt => new StructType().add("value", dt, nullable = encoder.nullable) } - } - def apply[T: TypedEncoder]: Encoder[T] = { - val encoder = TypedEncoder[T] + def apply[T](implicit encoder: TypedEncoder[T]): Encoder[T] = { val in = BoundReference(0, encoder.jvmRepr, encoder.nullable) val (out, serializer) = encoder.toCatalyst(in) match { - case it @ If(_, _, _: CreateNamedStruct) => + case it @ If(_, _, _: CreateNamedStruct) => { val out = GetColumnByOrdinal(0, encoder.catalystRepr) - (out, it) - case other => + out -> it + } + + case other => { val out = GetColumnByOrdinal(0, encoder.catalystRepr) - (out, other) + out -> other + } } new ExpressionEncoder[T]( diff --git a/dataset/src/test/resources/log4j.properties b/dataset/src/test/resources/log4j.properties index 044f9440..727d0d89 100644 --- a/dataset/src/test/resources/log4j.properties +++ 
b/dataset/src/test/resources/log4j.properties @@ -144,3 +144,6 @@ log4j.logger.org.spark-project.jetty.util.thread.QueuedThreadPool=ERROR log4j.logger.org.spark-project.jetty.util.thread.Timeout=ERROR log4j.logger.org.spark-project.jetty=ERROR log4j.logger.Remoting=ERROR + +# To debug expressions: +#log4j.logger.org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator=DEBUG \ No newline at end of file diff --git a/dataset/src/test/scala/frameless/RecordEncoderTests.scala b/dataset/src/test/scala/frameless/RecordEncoderTests.scala index f788b80b..384aeb3d 100644 --- a/dataset/src/test/scala/frameless/RecordEncoderTests.scala +++ b/dataset/src/test/scala/frameless/RecordEncoderTests.scala @@ -1,11 +1,13 @@ package frameless -import org.apache.spark.sql.Row +import org.apache.spark.sql.{Row, functions => F} import org.apache.spark.sql.types.{ IntegerType, LongType, ObjectType, StringType, StructField, StructType } + import shapeless.{HList, LabelledGeneric} import shapeless.test.illTyped + import org.scalatest.matchers.should.Matchers case class UnitsOnly(a: Unit, b: Unit) @@ -23,7 +25,7 @@ object RecordEncoderTests { case class B(a: Seq[A]) case class C(b: B) - class Name(val value: String) extends AnyVal + class Name(val value: String) extends AnyVal with Serializable case class Person(name: Name, age: Int) @@ -32,7 +34,7 @@ object RecordEncoderTests { class RecordEncoderTests extends TypedDatasetSuite with Matchers { test("Unable to encode products made from units only") { - illTyped("""TypedEncoder[UnitsOnly]""") + illTyped("TypedEncoder[UnitsOnly]") } test("Dropping fields") { @@ -107,18 +109,64 @@ class RecordEncoderTests extends TypedDatasetSuite with Matchers { test("Case class with value class field") { import RecordEncoderTests._ + illTyped( + // As `Person` is not a Value class + "val _: RecordFieldEncoder[Person] = RecordFieldEncoder.valueClass") + + val fieldEncoder: RecordFieldEncoder[Name] = RecordFieldEncoder.valueClass + + 
fieldEncoder.encoder.catalystRepr shouldBe StringType + fieldEncoder.encoder.jvmRepr shouldBe ObjectType(classOf[String]) + + // Encode as a Person field val encoder = TypedEncoder[Person] encoder.jvmRepr shouldBe ObjectType(classOf[Person]) - encoder.catalystRepr shouldBe StructType(Seq( + val expectedPersonStructType = StructType(Seq( StructField("name", StringType, false), StructField("age", IntegerType, false))) + + encoder.catalystRepr shouldBe expectedPersonStructType + + val unsafeDs: TypedDataset[Person] = { + val rdd = sc.parallelize(Seq( + Row.fromTuple("Foo" -> 2), + Row.fromTuple("Bar" -> 3) + )) + val df = session.createDataFrame(rdd, expectedPersonStructType) + + TypedDataset.createUnsafe(df)(encoder) + } + + val expected = Seq( + Person(new Name("Foo"), 2), Person(new Name("Bar"), 3)) + + unsafeDs.collect.run() shouldBe expected + + // Safely created DS + val safeDs = TypedDataset.create(expected) + + safeDs.collect.run() shouldBe expected + + // TODO: withColumnReplaced } test("Case class with value class as optional field") { import RecordEncoderTests._ + illTyped( // As `Person` is not a Value class + """val _: RecordFieldEncoder[Option[Person]] = + RecordFieldEncoder.optionValueClass""") + + val fieldEncoder: RecordFieldEncoder[Option[Name]] = + RecordFieldEncoder.optionValueClass + + fieldEncoder.encoder.catalystRepr shouldBe StringType + + fieldEncoder.encoder. 
// !StringType + jvmRepr shouldBe ObjectType(classOf[Option[_]]) + // Encode as a Person field val encoder = TypedEncoder[User] @@ -129,5 +177,48 @@ class RecordEncoderTests extends TypedDatasetSuite with Matchers { StructField("name", StringType, true))) encoder.catalystRepr shouldBe expectedPersonStructType + + val ds1: TypedDataset[User] = { + val rdd = sc.parallelize(Seq( + Row(1L, null), + Row(2L, "Foo") + )) + + val df = session.createDataFrame(rdd, expectedPersonStructType) + + TypedDataset.createUnsafe(df)(encoder) + } + + ds1.collect.run() shouldBe Seq( + User(1L, None), + User(2L, Some(new Name("Foo")))) + + val ds2: TypedDataset[User] = { + val sqlContext = session.sqlContext + import sqlContext.implicits._ + + val df1 = Seq( + """{"id":3,"label":"unused"}""", + """{"id":4,"name":"Lorem"}""", + """{"id":5,"name":null}""" + ).toDF + + val df2 = df1.withColumn( + "jsonValue", + F.from_json(df1.col("value"), expectedPersonStructType)). + select("jsonValue.id", "jsonValue.name") + + TypedDataset.createUnsafe[User](df2) + } + + val expected = Seq( + User(3L, None), + User(4L, Some(new Name("Lorem"))), + User(5L, None)) + + ds2.collect.run() shouldBe expected + + // Safely created ds + TypedDataset.create(expected).collect.run() shouldBe expected } } From b74abbfba25946a0db8a8451707e897fbe3cc6db Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Chantepie?= Date: Sun, 29 Aug 2021 17:49:32 +0200 Subject: [PATCH 5/9] Failing test in LitTest: ``` - support value class *** FAILED *** (101 milliseconds) [info] Array(Q(1,[0,1000000005,6d65726f4c]), Q(2,[0,1000000005,6d65726f4c])) was not equal to List(Q(1,Lorem), Q(2,Lorem)) (LitTests.scala:70) ``` --- .../src/test/scala/frameless/LitTests.scala | 33 ++++++++++++++++--- .../scala/frameless/RecordEncoderTests.scala | 4 ++- 2 files changed, 31 insertions(+), 6 deletions(-) diff --git a/dataset/src/test/scala/frameless/LitTests.scala b/dataset/src/test/scala/frameless/LitTests.scala index 4d9760c6..0c81f65f 100644 
--- a/dataset/src/test/scala/frameless/LitTests.scala +++ b/dataset/src/test/scala/frameless/LitTests.scala @@ -2,10 +2,13 @@ package frameless import frameless.functions.lit -import org.scalacheck.Prop -import org.scalacheck.Prop._ +import org.scalatest.matchers.should.Matchers -class LitTests extends TypedDatasetSuite { +import org.scalacheck.{ Arbitrary, Gen, Prop }, Prop._ + +import RecordEncoderTests.Name + +class LitTests extends TypedDatasetSuite with Matchers { def prop[A: TypedEncoder](value: A): Prop = { val df: TypedDataset[Int] = TypedDataset.create(1 :: Nil) @@ -21,7 +24,6 @@ class LitTests extends TypedDatasetSuite { .run() .toVector - (localElems ?= Vector(value)) && (elems ?= Vector(value)) } @@ -45,10 +47,29 @@ class LitTests extends TypedDatasetSuite { check(prop[Food] _) + implicit def nameArb: Arbitrary[Name] = + Arbitrary(Gen.alphaStr.map(new Name(_))) + + check(prop[Name] _) + // doesn't work, object has to be serializable // check(prop[frameless.LocalDateTime] _) } + test("support value class") { + val initial = Seq( + Q(name = new Name("Foo"), id = 1), + Q(name = new Name("Bar"), id = 2)) + val ds = TypedDataset.create(initial) + + ds.collect.run() shouldBe initial + + val lorem = new Name("Lorem") + + ds.withColumnReplaced('name, lit(lorem)). 
+ collect.run() shouldBe initial.map(_.copy(name = lorem)) + } + test("#205: comparing literals encoded using Injection") { import org.apache.spark.sql.catalyst.util.DateTimeUtils implicit val dateAsInt: Injection[java.sql.Date, Int] = @@ -58,8 +79,10 @@ class LitTests extends TypedDatasetSuite { val data = Vector(P(42, today)) val tds = TypedDataset.create(data) - tds.filter(tds('d) === today).collect().run() + tds.filter(tds('d) === today).collect.run().map(_.i) shouldBe Seq(42) } } final case class P(i: Int, d: java.sql.Date) + +final case class Q(id: Int, name: Name) diff --git a/dataset/src/test/scala/frameless/RecordEncoderTests.scala b/dataset/src/test/scala/frameless/RecordEncoderTests.scala index 384aeb3d..5e4a905f 100644 --- a/dataset/src/test/scala/frameless/RecordEncoderTests.scala +++ b/dataset/src/test/scala/frameless/RecordEncoderTests.scala @@ -25,7 +25,9 @@ object RecordEncoderTests { case class B(a: Seq[A]) case class C(b: B) - class Name(val value: String) extends AnyVal with Serializable + class Name(val value: String) extends AnyVal with Serializable { + override def toString = value + } case class Person(name: Name, age: Int) From 0ab91eec504f87234718150d8de3e9fd74d405bc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Chantepie?= Date: Sun, 29 Aug 2021 16:14:50 +0200 Subject: [PATCH 6/9] Prepare literal refactor, making sure first there is no regression --- .../frameless/functions/FramelessLit.scala | 58 +++++++++++++++++++ .../main/scala/frameless/functions/Lit.scala | 32 +++++----- .../scala/frameless/functions/package.scala | 25 ++++++-- .../src/test/scala/frameless/LitTests.scala | 10 +++- 4 files changed, 103 insertions(+), 22 deletions(-) create mode 100644 dataset/src/main/scala/frameless/functions/FramelessLit.scala diff --git a/dataset/src/main/scala/frameless/functions/FramelessLit.scala b/dataset/src/main/scala/frameless/functions/FramelessLit.scala new file mode 100644 index 00000000..5cd79891 --- /dev/null +++ 
b/dataset/src/main/scala/frameless/functions/FramelessLit.scala @@ -0,0 +1,58 @@ +package frameless.functions + +import frameless.TypedEncoder +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, NonSQLExpression} +import org.apache.spark.sql.types.DataType + +@deprecated("Use `Lit[A]`", "0.10.2") +case class FramelessLit[A](obj: A, encoder: TypedEncoder[A]) extends Expression with NonSQLExpression { + override def nullable: Boolean = encoder.nullable + override def toString: String = s"FramelessLit($obj)" + + def eval(input: InternalRow): Any = { + val ctx = new CodegenContext() + val eval = genCode(ctx) + + val codeBody = s""" + public scala.Function1 generate(Object[] references) { + return new FramelessLitEvalImpl(references); + } + + class FramelessLitEvalImpl extends scala.runtime.AbstractFunction1 { + private final Object[] references; + ${ctx.declareMutableStates()} + ${ctx.declareAddedFunctions()} + + public FramelessLitEvalImpl(Object[] references) { + this.references = references; + ${ctx.initMutableStates()} + } + + public java.lang.Object apply(java.lang.Object z) { + InternalRow ${ctx.INPUT_ROW} = (InternalRow) z; + ${eval.code} + return ${eval.isNull} ? ((Object)null) : ((Object)${eval.value}); + } + } + """ + + val code = CodeFormatter.stripOverlappingComments( + new CodeAndComment(codeBody, ctx.getPlaceHolderToComments())) + + val (clazz, _) = CodeGenerator.compile(code) + val codegen = clazz.generate(ctx.references.toArray).asInstanceOf[InternalRow => AnyRef] + + codegen(input) + } + + def dataType: DataType = encoder.catalystRepr + def children: Seq[Expression] = Nil + + override def genCode(ctx: CodegenContext): ExprCode = { + encoder.toCatalyst(new Literal(obj, encoder.jvmRepr)).genCode(ctx) + } + + protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = ??? 
+} diff --git a/dataset/src/main/scala/frameless/functions/Lit.scala b/dataset/src/main/scala/frameless/functions/Lit.scala index 75393ea2..e2ee5a55 100644 --- a/dataset/src/main/scala/frameless/functions/Lit.scala +++ b/dataset/src/main/scala/frameless/functions/Lit.scala @@ -1,30 +1,35 @@ package frameless.functions -import frameless.TypedEncoder import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, NonSQLExpression} +import org.apache.spark.sql.catalyst.expressions.{Expression, NonSQLExpression} import org.apache.spark.sql.types.DataType -case class FramelessLit[A](obj: A, encoder: TypedEncoder[A]) extends Expression with NonSQLExpression { - override def nullable: Boolean = encoder.nullable - override def toString: String = s"FramelessLit($obj)" +case class Lit[T <: AnyVal] private[frameless] ( + val dataType: DataType, + val nullable: Boolean, + toCatalyst: CodegenContext => ExprCode, + show: () => String) + extends Expression + with NonSQLExpression { + override def toString: String = s"FramelessLit(${show()})" + @SuppressWarnings(Array("AsInstanceOf", "MethodReturningAny")) def eval(input: InternalRow): Any = { val ctx = new CodegenContext() val eval = genCode(ctx) val codeBody = s""" public scala.Function1 generate(Object[] references) { - return new FramelessLitEvalImpl(references); + return new LiteralEvalImpl(references); } - class FramelessLitEvalImpl extends scala.runtime.AbstractFunction1 { + class LiteralEvalImpl extends scala.runtime.AbstractFunction1 { private final Object[] references; ${ctx.declareMutableStates()} ${ctx.declareAddedFunctions()} - public FramelessLitEvalImpl(Object[] references) { + public LiteralEvalImpl(Object[] references) { this.references = references; ${ctx.initMutableStates()} } @@ -38,20 +43,19 @@ case class FramelessLit[A](obj: A, encoder: TypedEncoder[A]) extends Expression """ val code = 
CodeFormatter.stripOverlappingComments( - new CodeAndComment(codeBody, ctx.getPlaceHolderToComments())) + new CodeAndComment(codeBody, ctx.getPlaceHolderToComments()) + ) val (clazz, _) = CodeGenerator.compile(code) - val codegen = clazz.generate(ctx.references.toArray).asInstanceOf[InternalRow => AnyRef] + val codegen = + clazz.generate(ctx.references.toArray).asInstanceOf[InternalRow => AnyRef] codegen(input) } - def dataType: DataType = encoder.catalystRepr def children: Seq[Expression] = Nil - override def genCode(ctx: CodegenContext): ExprCode = { - encoder.toCatalyst(new Literal(obj, encoder.jvmRepr)).genCode(ctx) - } + override def genCode(ctx: CodegenContext): ExprCode = toCatalyst(ctx) protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = ??? } diff --git a/dataset/src/main/scala/frameless/functions/package.scala b/dataset/src/main/scala/frameless/functions/package.scala index f1e72a0e..b60c5008 100644 --- a/dataset/src/main/scala/frameless/functions/package.scala +++ b/dataset/src/main/scala/frameless/functions/package.scala @@ -3,7 +3,9 @@ package frameless import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.expressions.Literal -package object functions extends Udf with UnaryFunctions { +package object functions + extends Udf with UnaryFunctions with LowPriorityFunctions { + object aggregate extends AggregateFunctions object nonAggregate extends NonAggregateFunctions @@ -14,22 +16,35 @@ package object functions extends Udf with UnaryFunctions { */ def litAggr[A: TypedEncoder, T](value: A): TypedAggregate[T, A] = new TypedAggregate[T,A](lit(value).expr) +} +private[frameless] sealed trait LowPriorityFunctions { /** Creates a [[frameless.TypedColumn]] of literal value. If A is to be encoded using an Injection make * sure the injection instance is in scope. 
* * apache/spark + * + * @tparam A the literal value type + * @tparam T the row type */ - def lit[A: TypedEncoder, T](value: A): TypedColumn[T, A] = { - val encoder = TypedEncoder[A] + def lit[A, T](value: A)( + implicit encoder: TypedEncoder[A]): TypedColumn[T, A] = { if (ScalaReflection.isNativeType(encoder.jvmRepr) && encoder.catalystRepr == encoder.jvmRepr) { val expr = Literal(value, encoder.catalystRepr) new TypedColumn(expr) } else { - val expr = FramelessLit(value, encoder) - new TypedColumn(expr) + val expr = new Literal(value, encoder.jvmRepr) + + new TypedColumn[T, A]( + new functions.Lit( + dataType = encoder.catalystRepr, + nullable = encoder.nullable, + toCatalyst = encoder.toCatalyst(expr).genCode(_), + show = value.toString + ) + ) } } } diff --git a/dataset/src/test/scala/frameless/LitTests.scala b/dataset/src/test/scala/frameless/LitTests.scala index 0c81f65f..dc6b55cc 100644 --- a/dataset/src/test/scala/frameless/LitTests.scala +++ b/dataset/src/test/scala/frameless/LitTests.scala @@ -12,19 +12,23 @@ class LitTests extends TypedDatasetSuite with Matchers { def prop[A: TypedEncoder](value: A): Prop = { val df: TypedDataset[Int] = TypedDataset.create(1 :: Nil) + val l: TypedColumn[Int, A] = lit(value) + // filter forces whole codegen - val elems = df.deserialized.filter((_:Int) => true).select(lit(value)) + val elems = df.deserialized.filter((_:Int) => true).select(l) .collect() .run() .toVector // otherwise it uses local relation - val localElems = df.select(lit(value)) + val localElems = df.select(l) .collect() .run() .toVector - (localElems ?= Vector(value)) && (elems ?= Vector(value)) + val expected = Vector(value) + + (localElems ?= expected) && (elems ?= expected) } test("select(lit(...))") { From b43b0d8f974bcae3fb934a711ea9f15c2b928fcd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Chantepie?= Date: Sun, 29 Aug 2021 16:41:52 +0200 Subject: [PATCH 7/9] Introduce functions.litValue --- .../scala/frameless/functions/package.scala | 49 
+++++++++++++++++-- .../src/test/scala/frameless/LitTests.scala | 11 ++--- .../scala/frameless/RecordEncoderTests.scala | 5 +- 3 files changed, 51 insertions(+), 14 deletions(-) diff --git a/dataset/src/main/scala/frameless/functions/package.scala b/dataset/src/main/scala/frameless/functions/package.scala index b60c5008..acb96d48 100644 --- a/dataset/src/main/scala/frameless/functions/package.scala +++ b/dataset/src/main/scala/frameless/functions/package.scala @@ -1,10 +1,16 @@ package frameless +import scala.reflect.ClassTag + +import shapeless._ +import shapeless.labelled.FieldType +import shapeless.ops.hlist.IsHCons +import shapeless.ops.record.Values + import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.expressions.Literal -package object functions - extends Udf with UnaryFunctions with LowPriorityFunctions { +package object functions extends Udf with UnaryFunctions { object aggregate extends AggregateFunctions object nonAggregate extends NonAggregateFunctions @@ -16,9 +22,6 @@ package object functions */ def litAggr[A: TypedEncoder, T](value: A): TypedAggregate[T, A] = new TypedAggregate[T,A](lit(value).expr) -} - -private[frameless] sealed trait LowPriorityFunctions { /** Creates a [[frameless.TypedColumn]] of literal value. If A is to be encoded using an Injection make * sure the injection instance is in scope. @@ -47,4 +50,40 @@ private[frameless] sealed trait LowPriorityFunctions { ) } } + + /** Creates a [[frameless.TypedColumn]] of literal value + * for a Value class `A`. 
+ * + * @tparam A the value class + * @tparam T the row type + */ + def litValue[A <: AnyVal, T, G <: ::[_, HNil], H <: ::[_ <: FieldType[_ <: Symbol, _], HNil], V, VS <: HList](value: A)( + implicit + i0: LabelledGeneric.Aux[A, G], + i1: DropUnitValues.Aux[G, H], + i2: IsHCons.Aux[H, _ <: FieldType[_, V], HNil], + i3: Values.Aux[H, VS], + i4: IsHCons.Aux[VS, V, HNil], + i5: TypedEncoder[V], + i6: ClassTag[A] + ): TypedColumn[T, A] = { + val expr = { + val field: H = i1(i0.to(value)) + val v: V = i4.head(i3(field)) + + new Literal(v, i5.jvmRepr) + } + + implicit val enc: TypedEncoder[A] = + RecordFieldEncoder.valueClass[A, G, H, V].encoder + + new TypedColumn[T, A]( + new Lit( + dataType = i5.catalystRepr, + nullable = i5.nullable, + toCatalyst = i5.toCatalyst(expr).genCode(_), + show = value.toString + ) + ) + } } diff --git a/dataset/src/test/scala/frameless/LitTests.scala b/dataset/src/test/scala/frameless/LitTests.scala index dc6b55cc..ba7d522d 100644 --- a/dataset/src/test/scala/frameless/LitTests.scala +++ b/dataset/src/test/scala/frameless/LitTests.scala @@ -1,10 +1,10 @@ package frameless -import frameless.functions.lit +import frameless.functions.{ lit, litValue } import org.scalatest.matchers.should.Matchers -import org.scalacheck.{ Arbitrary, Gen, Prop }, Prop._ +import org.scalacheck.Prop, Prop._ import RecordEncoderTests.Name @@ -51,11 +51,6 @@ class LitTests extends TypedDatasetSuite with Matchers { check(prop[Food] _) - implicit def nameArb: Arbitrary[Name] = - Arbitrary(Gen.alphaStr.map(new Name(_))) - - check(prop[Name] _) - // doesn't work, object has to be serializable // check(prop[frameless.LocalDateTime] _) } @@ -70,7 +65,7 @@ class LitTests extends TypedDatasetSuite with Matchers { val lorem = new Name("Lorem") - ds.withColumnReplaced('name, lit(lorem)). + ds.withColumnReplaced('name, litValue(lorem)). 
collect.run() shouldBe initial.map(_.copy(name = lorem)) } diff --git a/dataset/src/test/scala/frameless/RecordEncoderTests.scala b/dataset/src/test/scala/frameless/RecordEncoderTests.scala index 5e4a905f..2a75b42e 100644 --- a/dataset/src/test/scala/frameless/RecordEncoderTests.scala +++ b/dataset/src/test/scala/frameless/RecordEncoderTests.scala @@ -151,7 +151,10 @@ class RecordEncoderTests extends TypedDatasetSuite with Matchers { safeDs.collect.run() shouldBe expected - // TODO: withColumnReplaced + val lorem = new Name("Lorem") + + safeDs.withColumnReplaced('name, functions.litValue(lorem)). + collect.run() shouldBe expected.map(_.copy(name = lorem)) } test("Case class with value class as optional field") { From 4ae661a736ee3fa39af1beace0c090c91ec4af09 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Chantepie?= Date: Sun, 29 Aug 2021 22:33:01 +0200 Subject: [PATCH 8/9] IsValueClass evidence --- .../main/scala/frameless/IsValueClass.scala | 17 ++++++++++ .../main/scala/frameless/RecordEncoder.scala | 4 +-- .../main/scala/frameless/TypedColumn.scala | 2 +- .../main/scala/frameless/functions/Lit.scala | 7 ++--- .../scala/frameless/functions/package.scala | 10 +++--- dataset/src/test/resources/log4j.properties | 3 +- .../test/scala/frameless/BitwiseTests.scala | 2 +- .../scala/frameless/IsValueClassTests.scala | 31 +++++++++++++++++++ .../src/test/scala/frameless/LitTests.scala | 6 ++-- 9 files changed, 65 insertions(+), 17 deletions(-) create mode 100644 dataset/src/main/scala/frameless/IsValueClass.scala create mode 100644 dataset/src/test/scala/frameless/IsValueClassTests.scala diff --git a/dataset/src/main/scala/frameless/IsValueClass.scala b/dataset/src/main/scala/frameless/IsValueClass.scala new file mode 100644 index 00000000..78605c13 --- /dev/null +++ b/dataset/src/main/scala/frameless/IsValueClass.scala @@ -0,0 +1,17 @@ +package frameless + +import shapeless._ +import shapeless.labelled.FieldType + +/** Evidence that `T` is a Value class */ 
+@annotation.implicitNotFound(msg = "${T} is not a Value class") +final class IsValueClass[T] private() {} + +object IsValueClass { + /** Provides evidence that `A` is a Value class */ + implicit def apply[A <: AnyVal, G <: ::[_, HNil], H <: ::[_ <: FieldType[_ <: Symbol, _], HNil]]( + implicit + i0: LabelledGeneric.Aux[A, G], + i1: DropUnitValues.Aux[G, H]): IsValueClass[A] = new IsValueClass[A] + +} diff --git a/dataset/src/main/scala/frameless/RecordEncoder.scala b/dataset/src/main/scala/frameless/RecordEncoder.scala index 2f293e29..b51a51eb 100644 --- a/dataset/src/main/scala/frameless/RecordEncoder.scala +++ b/dataset/src/main/scala/frameless/RecordEncoder.scala @@ -188,7 +188,7 @@ object RecordFieldEncoder extends RecordFieldEncoderLowPriority { * @tparam K the key type for the fields * @tparam V the inner value type */ - implicit def optionValueClass[F <: AnyVal, G <: ::[_, HNil], H <: ::[_ <: FieldType[_ <: Symbol, _], HNil], K <: Symbol, V, KS <: ::[_ <: Symbol, HNil]] + implicit def optionValueClass[F : IsValueClass, G <: ::[_, HNil], H <: ::[_ <: FieldType[_ <: Symbol, _], HNil], K <: Symbol, V, KS <: ::[_ <: Symbol, HNil]] (implicit i0: LabelledGeneric.Aux[F, G], i1: DropUnitValues.Aux[G, H], @@ -229,7 +229,7 @@ object RecordFieldEncoder extends RecordFieldEncoderLowPriority { * @tparam H the single field of the value class (with guarantee it's not a `Unit` value) * @tparam V the inner value type */ - implicit def valueClass[F <: AnyVal, G <: ::[_, HNil], H <: ::[_, HNil], V] + implicit def valueClass[F : IsValueClass, G <: ::[_, HNil], H <: ::[_, HNil], V] (implicit i0: LabelledGeneric.Aux[F, G], i1: DropUnitValues.Aux[G, H], diff --git a/dataset/src/main/scala/frameless/TypedColumn.scala b/dataset/src/main/scala/frameless/TypedColumn.scala index 9890d9fd..26d74dba 100644 --- a/dataset/src/main/scala/frameless/TypedColumn.scala +++ b/dataset/src/main/scala/frameless/TypedColumn.scala @@ -33,7 +33,7 @@ sealed class TypedColumn[T, U](expr: Expression)(
override def typed[W, U1: TypedEncoder](c: Column): TypedColumn[W, U1] = c.typedColumn - override def lit[U1: TypedEncoder](c: U1): TypedColumn[T,U1] = flit(c) + override def lit[U1: TypedEncoder](c: U1): TypedColumn[T, U1] = flit(c) } /** Expression used in `agg`-like constructions. diff --git a/dataset/src/main/scala/frameless/functions/Lit.scala b/dataset/src/main/scala/frameless/functions/Lit.scala index e2ee5a55..fcffd30b 100644 --- a/dataset/src/main/scala/frameless/functions/Lit.scala +++ b/dataset/src/main/scala/frameless/functions/Lit.scala @@ -5,16 +5,15 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.{Expression, NonSQLExpression} import org.apache.spark.sql.types.DataType -case class Lit[T <: AnyVal] private[frameless] ( - val dataType: DataType, - val nullable: Boolean, +private[frameless] case class Lit[T <: AnyVal]( + dataType: DataType, + nullable: Boolean, toCatalyst: CodegenContext => ExprCode, show: () => String) extends Expression with NonSQLExpression { override def toString: String = s"FramelessLit(${show()})" - @SuppressWarnings(Array("AsInstanceOf", "MethodReturningAny")) def eval(input: InternalRow): Any = { val ctx = new CodegenContext() val eval = genCode(ctx) diff --git a/dataset/src/main/scala/frameless/functions/package.scala b/dataset/src/main/scala/frameless/functions/package.scala index acb96d48..8ffd5665 100644 --- a/dataset/src/main/scala/frameless/functions/package.scala +++ b/dataset/src/main/scala/frameless/functions/package.scala @@ -20,8 +20,8 @@ package object functions extends Udf with UnaryFunctions { * * apache/spark */ - def litAggr[A: TypedEncoder, T](value: A): TypedAggregate[T, A] = - new TypedAggregate[T,A](lit(value).expr) + def litAggr[A, T](value: A)(implicit i0: TypedEncoder[A], i1: Refute[IsValueClass[A]]): TypedAggregate[T, A] = + new TypedAggregate[T, A](lit(value).expr) /** Creates a [[frameless.TypedColumn]] of literal value. 
If A is to be encoded using an Injection make * sure the injection instance is in scope. @@ -41,7 +41,7 @@ package object functions extends Udf with UnaryFunctions { val expr = new Literal(value, encoder.jvmRepr) new TypedColumn[T, A]( - new functions.Lit( + Lit( dataType = encoder.catalystRepr, nullable = encoder.nullable, toCatalyst = encoder.toCatalyst(expr).genCode(_), @@ -57,7 +57,7 @@ package object functions extends Udf with UnaryFunctions { * @tparam A the value class * @tparam T the row type */ - def litValue[A <: AnyVal, T, G <: ::[_, HNil], H <: ::[_ <: FieldType[_ <: Symbol, _], HNil], V, VS <: HList](value: A)( + def litValue[A : IsValueClass, T, G <: ::[_, HNil], H <: ::[_ <: FieldType[_ <: Symbol, _], HNil], V, VS <: HList](value: A)( implicit i0: LabelledGeneric.Aux[A, G], i1: DropUnitValues.Aux[G, H], @@ -78,7 +78,7 @@ package object functions extends Udf with UnaryFunctions { RecordFieldEncoder.valueClass[A, G, H, V].encoder new TypedColumn[T, A]( - new Lit( + Lit( dataType = i5.catalystRepr, nullable = i5.nullable, toCatalyst = i5.toCatalyst(expr).genCode(_), diff --git a/dataset/src/test/resources/log4j.properties b/dataset/src/test/resources/log4j.properties index 727d0d89..d3d35c98 100644 --- a/dataset/src/test/resources/log4j.properties +++ b/dataset/src/test/resources/log4j.properties @@ -146,4 +146,5 @@ log4j.logger.org.spark-project.jetty=ERROR log4j.logger.Remoting=ERROR # To debug expressions: -#log4j.logger.org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator=DEBUG \ No newline at end of file +#log4j.logger.org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator=DEBUG + diff --git a/dataset/src/test/scala/frameless/BitwiseTests.scala b/dataset/src/test/scala/frameless/BitwiseTests.scala index b83662f6..f58c906a 100644 --- a/dataset/src/test/scala/frameless/BitwiseTests.scala +++ b/dataset/src/test/scala/frameless/BitwiseTests.scala @@ -4,7 +4,7 @@ import org.scalacheck.Prop import org.scalacheck.Prop._ import 
org.scalatest.matchers.should.Matchers -class BitwiseTests extends TypedDatasetSuite with Matchers{ +class BitwiseTests extends TypedDatasetSuite with Matchers { /** * providing instances with implementations for bitwise operations since in the tests diff --git a/dataset/src/test/scala/frameless/IsValueClassTests.scala b/dataset/src/test/scala/frameless/IsValueClassTests.scala new file mode 100644 index 00000000..379da451 --- /dev/null +++ b/dataset/src/test/scala/frameless/IsValueClassTests.scala @@ -0,0 +1,31 @@ +package frameless + +import shapeless.Refute +import shapeless.test.illTyped + +import org.scalatest.funsuite.AnyFunSuite +import org.scalatest.matchers.should.Matchers + +final class IsValueClassTests extends AnyFunSuite with Matchers { + test("Case class is not Value class") { + illTyped("IsValueClass[P]") + illTyped("IsValueClass[Q]") + } + + test("Scala value type is not Value class (excluded)") { + illTyped("implicitly[IsValueClass[Double]]") + illTyped("implicitly[IsValueClass[Float]]") + illTyped("implicitly[IsValueClass[Long]]") + illTyped("implicitly[IsValueClass[Int]]") + illTyped("implicitly[IsValueClass[Char]]") + illTyped("implicitly[IsValueClass[Short]]") + illTyped("implicitly[IsValueClass[Byte]]") + illTyped("implicitly[IsValueClass[Unit]]") + illTyped("implicitly[IsValueClass[Boolean]]") + } + + test("Value class evidence") { + implicitly[IsValueClass[RecordEncoderTests.Name]] + illTyped("implicitly[Refute[IsValueClass[RecordEncoderTests.Name]]]") + } +} diff --git a/dataset/src/test/scala/frameless/LitTests.scala b/dataset/src/test/scala/frameless/LitTests.scala index ba7d522d..4bbd0782 100644 --- a/dataset/src/test/scala/frameless/LitTests.scala +++ b/dataset/src/test/scala/frameless/LitTests.scala @@ -1,6 +1,6 @@ package frameless -import frameless.functions.{ lit, litValue } +import frameless.functions.lit import org.scalatest.matchers.should.Matchers @@ -9,7 +9,7 @@ import org.scalacheck.Prop, Prop._ import RecordEncoderTests.Name 
class LitTests extends TypedDatasetSuite with Matchers { - def prop[A: TypedEncoder](value: A): Prop = { + def prop[A: TypedEncoder](value: A)(implicit i0: shapeless.Refute[IsValueClass[A]]): Prop = { val df: TypedDataset[Int] = TypedDataset.create(1 :: Nil) val l: TypedColumn[Int, A] = lit(value) @@ -65,7 +65,7 @@ class LitTests extends TypedDatasetSuite with Matchers { val lorem = new Name("Lorem") - ds.withColumnReplaced('name, litValue(lorem)). + ds.withColumnReplaced('name, functions.litValue(lorem)). collect.run() shouldBe initial.map(_.copy(name = lorem)) } From ab190189bcdb05edb445780226cb37288365e4fe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Chantepie?= Date: Mon, 6 Sep 2021 15:24:58 +0200 Subject: [PATCH 9/9] Prepare 0.11 --- README.md | 21 +++---- build.sbt | 32 +++++++--- .../frameless/functions/FramelessLit.scala | 58 ------------------- 3 files changed, 36 insertions(+), 75 deletions(-) delete mode 100644 dataset/src/main/scala/frameless/functions/FramelessLit.scala diff --git a/README.md b/README.md index 8b0df089..294bc674 100644 --- a/README.md +++ b/README.md @@ -26,16 +26,17 @@ associated channels (e.g. 
GitHub, Discord) to be a safe and friendly environment The compatible versions of [Spark](http://spark.apache.org/) and [cats](https://github.com/typelevel/cats) are as follows: -| Frameless | Spark | Cats | Cats-Effect | Scala | -| --- | --- | --- | --- | --- | -| 0.4.0 | 2.2.0 | 1.0.0-IF | 0.4 | 2.11 -| 0.4.1 | 2.2.0 | 1.x | 0.8 | 2.11 -| 0.5.2 | 2.2.1 | 1.x | 0.8 | 2.11 -| 0.6.1 | 2.3.0 | 1.x | 0.8 | 2.11 -| 0.7.0 | 2.3.1 | 1.x | 1.x | 2.11 -| 0.8.0 | 2.4.0 | 1.x | 1.x | 2.11/2.12 -| 0.9.0 | 3.0.0 | 1.x | 1.x | 2.12 -| 0.10.1 | 3.1.0 | 2.x | 2.x | 2.12 +| Frameless | Spark | Cats | Cats-Effect | Scala +| --------- | ----- | -------- | ----------- | --- +| 0.4.0 | 2.2.0 | 1.0.0-IF | 0.4 | 2.11 +| 0.4.1 | 2.2.0 | 1.x | 0.8 | 2.11 +| 0.5.2 | 2.2.1 | 1.x | 0.8 | 2.11 +| 0.6.1 | 2.3.0 | 1.x | 0.8 | 2.11 +| 0.7.0 | 2.3.1 | 1.x | 1.x | 2.11 +| 0.8.0 | 2.4.0 | 1.x | 1.x | 2.11/2.12 +| 0.9.0 | 3.0.0 | 1.x | 1.x | 2.12 +| 0.10.1 | 3.1.0 | 2.x | 2.x | 2.12 +| 0.11.0 | 3.1.0 | 2.x | 2.x | 2.12 Versions 0.5.x and 0.6.x have identical features. The first is compatible with Spark 2.2.1 and the second with 2.3.0. 
diff --git a/build.sbt b/build.sbt index c37cf00e..dcf87cb7 100644 --- a/build.sbt +++ b/build.sbt @@ -50,13 +50,31 @@ lazy val cats = project lazy val dataset = project .settings(name := "frameless-dataset") - .settings(framelessSettings: _*) - .settings(framelessTypedDatasetREPL: _*) - .settings(publishSettings: _*) - .settings(libraryDependencies ++= Seq( - "org.apache.spark" %% "spark-core" % sparkVersion % Provided, - "org.apache.spark" %% "spark-sql" % sparkVersion % Provided, - "net.ceedubs" %% "irrec-regex-gen" % irrecVersion % Test + .settings(framelessSettings) + .settings(framelessTypedDatasetREPL) + .settings(publishSettings) + .settings(Seq( + libraryDependencies ++= Seq( + "org.apache.spark" %% "spark-core" % sparkVersion % Provided, + "org.apache.spark" %% "spark-sql" % sparkVersion % Provided, + "net.ceedubs" %% "irrec-regex-gen" % irrecVersion % Test + ), + mimaBinaryIssueFilters ++= { + import com.typesafe.tools.mima.core._ + + val imt = ProblemFilters.exclude[IncompatibleMethTypeProblem](_) + val mc = ProblemFilters.exclude[MissingClassProblem](_) + val dmm = ProblemFilters.exclude[DirectMissingMethodProblem](_) + + // TODO: Remove after version bump + Seq( + imt("frameless.RecordEncoderFields.deriveRecordCons"), + imt("frameless.RecordEncoderFields.deriveRecordLast"), + mc("frameless.functions.FramelessLit"), + mc(f"frameless.functions.FramelessLit$$"), + dmm("frameless.functions.package.litAggr") + ) + } )) .dependsOn(core % "test->test;compile->compile") diff --git a/dataset/src/main/scala/frameless/functions/FramelessLit.scala b/dataset/src/main/scala/frameless/functions/FramelessLit.scala deleted file mode 100644 index 5cd79891..00000000 --- a/dataset/src/main/scala/frameless/functions/FramelessLit.scala +++ /dev/null @@ -1,58 +0,0 @@ -package frameless.functions - -import frameless.TypedEncoder -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen._ -import
org.apache.spark.sql.catalyst.expressions.{Expression, Literal, NonSQLExpression} -import org.apache.spark.sql.types.DataType - -@deprecated("Use `Lit[A]`", "0.10.2") -case class FramelessLit[A](obj: A, encoder: TypedEncoder[A]) extends Expression with NonSQLExpression { - override def nullable: Boolean = encoder.nullable - override def toString: String = s"FramelessLit($obj)" - - def eval(input: InternalRow): Any = { - val ctx = new CodegenContext() - val eval = genCode(ctx) - - val codeBody = s""" - public scala.Function1 generate(Object[] references) { - return new FramelessLitEvalImpl(references); - } - - class FramelessLitEvalImpl extends scala.runtime.AbstractFunction1 { - private final Object[] references; - ${ctx.declareMutableStates()} - ${ctx.declareAddedFunctions()} - - public FramelessLitEvalImpl(Object[] references) { - this.references = references; - ${ctx.initMutableStates()} - } - - public java.lang.Object apply(java.lang.Object z) { - InternalRow ${ctx.INPUT_ROW} = (InternalRow) z; - ${eval.code} - return ${eval.isNull} ? ((Object)null) : ((Object)${eval.value}); - } - } - """ - - val code = CodeFormatter.stripOverlappingComments( - new CodeAndComment(codeBody, ctx.getPlaceHolderToComments())) - - val (clazz, _) = CodeGenerator.compile(code) - val codegen = clazz.generate(ctx.references.toArray).asInstanceOf[InternalRow => AnyRef] - - codegen(input) - } - - def dataType: DataType = encoder.catalystRepr - def children: Seq[Expression] = Nil - - override def genCode(ctx: CodegenContext): ExprCode = { - encoder.toCatalyst(new Literal(obj, encoder.jvmRepr)).genCode(ctx) - } - - protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = ??? -}