Skip to content

Commit

Permalink
IsValueClass evidence
Browse files Browse the repository at this point in the history
  • Loading branch information
cchantep committed Sep 6, 2021
1 parent 1f44ab9 commit 7b460a8
Show file tree
Hide file tree
Showing 9 changed files with 65 additions and 17 deletions.
17 changes: 17 additions & 0 deletions dataset/src/main/scala/frameless/IsValueClass.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package frameless

import shapeless._
import shapeless.labelled.FieldType

/** Evidence that `T` is a Value class */
@annotation.implicitNotFound(msg = "${T} is not a Value class")
final class IsValueClass[T] private() {}

object IsValueClass {
/** Provides an evidence `A` is a Value class */
implicit def apply[A <: AnyVal, G <: ::[_, HNil], H <: ::[_ <: FieldType[_ <: Symbol, _], HNil]](
implicit
i0: LabelledGeneric.Aux[A, G],
i1: DropUnitValues.Aux[G, H]): IsValueClass[A] = new IsValueClass[A]

}
4 changes: 2 additions & 2 deletions dataset/src/main/scala/frameless/RecordEncoder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ object RecordFieldEncoder extends RecordFieldEncoderLowPriority {
* @tparam K the key type for the fields
* @tparam V the inner value type
*/
implicit def optionValueClass[F <: AnyVal, G <: ::[_, HNil], H <: ::[_ <: FieldType[_ <: Symbol, _], HNil], K <: Symbol, V, KS <: ::[_ <: Symbol, HNil]]
implicit def optionValueClass[F : IsValueClass, G <: ::[_, HNil], H <: ::[_ <: FieldType[_ <: Symbol, _], HNil], K <: Symbol, V, KS <: ::[_ <: Symbol, HNil]]
(implicit
i0: LabelledGeneric.Aux[F, G],
i1: DropUnitValues.Aux[G, H],
Expand Down Expand Up @@ -229,7 +229,7 @@ object RecordFieldEncoder extends RecordFieldEncoderLowPriority {
* @tparam H the single field of the value class (with guarantee it's not a `Unit` value)
* @tparam V the inner value type
*/
implicit def valueClass[F <: AnyVal, G <: ::[_, HNil], H <: ::[_, HNil], V]
implicit def valueClass[F : IsValueClass, G <: ::[_, HNil], H <: ::[_, HNil], V]
(implicit
i0: LabelledGeneric.Aux[F, G],
i1: DropUnitValues.Aux[G, H],
Expand Down
2 changes: 1 addition & 1 deletion dataset/src/main/scala/frameless/TypedColumn.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ sealed class TypedColumn[T, U](expr: Expression)(

override def typed[W, U1: TypedEncoder](c: Column): TypedColumn[W, U1] = c.typedColumn

override def lit[U1: TypedEncoder](c: U1): TypedColumn[T,U1] = flit(c)
override def lit[U1: TypedEncoder](c: U1): TypedColumn[T, U1] = flit(c)
}

/** Expression used in `agg`-like constructions.
Expand Down
7 changes: 3 additions & 4 deletions dataset/src/main/scala/frameless/functions/Lit.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,15 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.{Expression, NonSQLExpression}
import org.apache.spark.sql.types.DataType

case class Lit[T <: AnyVal] private[frameless] (
val dataType: DataType,
val nullable: Boolean,
private[frameless] case class Lit[T <: AnyVal](
dataType: DataType,
nullable: Boolean,
toCatalyst: CodegenContext => ExprCode,
show: () => String)
extends Expression
with NonSQLExpression {
override def toString: String = s"FramelessLit(${show()})"

@SuppressWarnings(Array("AsInstanceOf", "MethodReturningAny"))
def eval(input: InternalRow): Any = {
val ctx = new CodegenContext()
val eval = genCode(ctx)
Expand Down
10 changes: 5 additions & 5 deletions dataset/src/main/scala/frameless/functions/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ package object functions extends Udf with UnaryFunctions {
*
* apache/spark
*/
def litAggr[A: TypedEncoder, T](value: A): TypedAggregate[T, A] =
new TypedAggregate[T,A](lit(value).expr)
def litAggr[A, T](value: A)(implicit i0: TypedEncoder[A], i1: Refute[IsValueClass[A]]): TypedAggregate[T, A] =
new TypedAggregate[T, A](lit(value).expr)

/** Creates a [[frameless.TypedColumn]] of literal value. If A is to be encoded using an Injection make
* sure the injection instance is in scope.
Expand All @@ -41,7 +41,7 @@ package object functions extends Udf with UnaryFunctions {
val expr = new Literal(value, encoder.jvmRepr)

new TypedColumn[T, A](
new functions.Lit(
Lit(
dataType = encoder.catalystRepr,
nullable = encoder.nullable,
toCatalyst = encoder.toCatalyst(expr).genCode(_),
Expand All @@ -57,7 +57,7 @@ package object functions extends Udf with UnaryFunctions {
* @tparam A the value class
* @tparam T the row type
*/
def litValue[A <: AnyVal, T, G <: ::[_, HNil], H <: ::[_ <: FieldType[_ <: Symbol, _], HNil], V, VS <: HList](value: A)(
def litValue[A : IsValueClass, T, G <: ::[_, HNil], H <: ::[_ <: FieldType[_ <: Symbol, _], HNil], V, VS <: HList](value: A)(
implicit
i0: LabelledGeneric.Aux[A, G],
i1: DropUnitValues.Aux[G, H],
Expand All @@ -78,7 +78,7 @@ package object functions extends Udf with UnaryFunctions {
RecordFieldEncoder.valueClass[A, G, H, V].encoder

new TypedColumn[T, A](
new Lit(
Lit(
dataType = i5.catalystRepr,
nullable = i5.nullable,
toCatalyst = i5.toCatalyst(expr).genCode(_),
Expand Down
3 changes: 2 additions & 1 deletion dataset/src/test/resources/log4j.properties
Original file line number Diff line number Diff line change
Expand Up @@ -146,4 +146,5 @@ log4j.logger.org.spark-project.jetty=ERROR
log4j.logger.Remoting=ERROR

# To debug expressions:
#log4j.logger.org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator=DEBUG
#log4j.logger.org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator=DEBUG

2 changes: 1 addition & 1 deletion dataset/src/test/scala/frameless/BitwiseTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import org.scalacheck.Prop
import org.scalacheck.Prop._
import org.scalatest.matchers.should.Matchers

class BitwiseTests extends TypedDatasetSuite with Matchers{
class BitwiseTests extends TypedDatasetSuite with Matchers {

/**
* providing instances with implementations for bitwise operations since in the tests
Expand Down
31 changes: 31 additions & 0 deletions dataset/src/test/scala/frameless/IsValueClassTests.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package frameless

import shapeless.Refute
import shapeless.test.illTyped

import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.matchers.should.Matchers

final class IsValueClassTests extends AnyFunSuite with Matchers {
test("Case class is not Value class") {
illTyped("IsValueClass[P]")
illTyped("IsValueClass[Q]")
}

test("Scala value type is not Value class (excluded)") {
illTyped("implicitly[IsValueClass[Double]]")
illTyped("implicitly[IsValueClass[Float]]")
illTyped("implicitly[IsValueClass[Long]]")
illTyped("implicitly[IsValueClass[Int]]")
illTyped("implicitly[IsValueClass[Char]]")
illTyped("implicitly[IsValueClass[Short]]")
illTyped("implicitly[IsValueClass[Byte]]")
illTyped("implicitly[IsValueClass[Unit]]")
illTyped("implicitly[IsValueClass[Boolean]]")
}

test("Value class evidence") {
implicitly[IsValueClass[RecordEncoderTests.Name]]
illTyped("implicitly[Refute[IsValueClass[RecordEncoderTests.Name]]]")
}
}
6 changes: 3 additions & 3 deletions dataset/src/test/scala/frameless/LitTests.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package frameless

import frameless.functions.{ lit, litValue }
import frameless.functions.lit

import org.scalatest.matchers.should.Matchers

Expand All @@ -9,7 +9,7 @@ import org.scalacheck.Prop, Prop._
import RecordEncoderTests.Name

class LitTests extends TypedDatasetSuite with Matchers {
def prop[A: TypedEncoder](value: A): Prop = {
def prop[A: TypedEncoder](value: A)(implicit i0: shapeless.Refute[IsValueClass[A]]): Prop = {
val df: TypedDataset[Int] = TypedDataset.create(1 :: Nil)

val l: TypedColumn[Int, A] = lit(value)
Expand Down Expand Up @@ -65,7 +65,7 @@ class LitTests extends TypedDatasetSuite with Matchers {

val lorem = new Name("Lorem")

ds.withColumnReplaced('name, litValue(lorem)).
ds.withColumnReplaced('name, functions.litValue(lorem)).
collect.run() shouldBe initial.map(_.copy(name = lorem))
}

Expand Down

0 comments on commit 7b460a8

Please sign in to comment.