Skip to content

Commit

Permalink
Cleanup PR
Browse files Browse the repository at this point in the history
  • Loading branch information
pomadchin committed Dec 12, 2021
1 parent 85158d6 commit d5a966b
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 42 deletions.
4 changes: 4 additions & 0 deletions core/src/main/scala/frameless/CatalystOrdered.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package frameless
import scala.annotation.implicitNotFound
import shapeless.{Generic, HList, Lazy}
import shapeless.ops.hlist.LiftAll
import java.time.{Duration, Instant, Period}

/** Types that can be ordered/compared by Catalyst. */
@implicitNotFound("Cannot compare columns of type ${A}.")
Expand All @@ -23,6 +24,9 @@ object CatalystOrdered {
implicit val framelessSQLDateOrdered : CatalystOrdered[SQLDate] = of[SQLDate]
implicit val framelessSQLTimestampOrdered: CatalystOrdered[SQLTimestamp] = of[SQLTimestamp]
implicit val framelessStringOrdered : CatalystOrdered[String] = of[String]
implicit val framelessInstantOrdered : CatalystOrdered[Instant] = of[Instant]
implicit val framelessDurationOrdered : CatalystOrdered[Duration] = of[Duration]
implicit val framelessPeriodOrdered : CatalystOrdered[Period] = of[Period]

implicit def injectionOrdered[A, B]
(implicit
Expand Down
51 changes: 14 additions & 37 deletions dataset/src/main/scala/frameless/TypedEncoder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@ object TypedEncoder {
)
}

/** java.time Encoders, Spark uses https://github.com/apache/spark/blob/v3.2.0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala for encoding / decoding. */
implicit val timeInstant: TypedEncoder[Instant] = new TypedEncoder[Instant] {
def nullable: Boolean = false

Expand All @@ -255,43 +256,19 @@ object TypedEncoder {
)
}

implicit val timeDuration: TypedEncoder[Duration] = new TypedEncoder[Duration] {
def nullable: Boolean = false

def jvmRepr: DataType = FramelessInternals.objectTypeFor[Duration]
def catalystRepr: DataType = LongType

def toCatalyst(path: Expression): Expression =
Invoke(path, "toMillis", LongType)

def fromCatalyst(path: Expression): Expression =
StaticInvoke(
staticObject = classOf[Duration],
dataType = jvmRepr,
functionName = "ofMillis",
arguments = path :: Nil,
propagateNull = true
)
}

implicit val timePeriod: TypedEncoder[Period] = new TypedEncoder[Period] {
def nullable: Boolean = false

def jvmRepr: DataType = FramelessInternals.objectTypeFor[Period]
def catalystRepr: DataType = IntegerType

def toCatalyst(path: Expression): Expression =
Invoke(path, "getDays", IntegerType)

def fromCatalyst(path: Expression): Expression =
StaticInvoke(
staticObject = classOf[Period],
dataType = jvmRepr,
functionName = "ofDays",
arguments = path :: Nil,
propagateNull = true
)
}
/**
* DayTimeIntervalType and YearMonthIntervalType in Spark 3.2.0.
* We maintain Spark 3.x cross compilation and handle Duration and Period as an injections to be compatible with Spark versions < 3.2
* See
* * https://github.com/apache/spark/blob/v3.2.0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala#L1031-L1047
* * https://github.com/apache/spark/blob/v3.2.0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala#L1075-L1087
*/
// DayTimeIntervalType
implicit val timeDurationInjection: Injection[Duration, Long] = Injection(_.toMillis, Duration.ofMillis)
// YearMonthIntervalType
implicit val timePeriodInjection: Injection[Period, Int] = Injection(_.getDays, Period.ofDays)
implicit val timePeriodEncoder: TypedEncoder[Period] = TypedEncoder.usingInjection
implicit val timeDurationEncoder: TypedEncoder[Duration] = TypedEncoder.usingInjection

implicit def arrayEncoder[T: ClassTag](
implicit i0: Lazy[RecordFieldEncoder[T]]): TypedEncoder[Array[T]] =
Expand Down
6 changes: 1 addition & 5 deletions dataset/src/test/scala/frameless/ColumnTests.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package frameless

import java.time.Instant

import org.scalacheck.Prop._
import org.scalacheck.{Arbitrary, Gen, Prop}, Arbitrary.arbitrary
import org.scalatest.matchers.should.Matchers
Expand All @@ -14,10 +13,7 @@ final class ColumnTests extends TypedDatasetSuite with Matchers {
private implicit object OrderingImplicits {
implicit val sqlDateOrdering: Ordering[SQLDate] = Ordering.by(_.days)
implicit val sqlTimestmapOrdering: Ordering[SQLTimestamp] = Ordering.by(_.us)
implicit val arbInstant: Arbitrary[Instant] =
Arbitrary(Gen.choose[Instant](Instant.EPOCH, Instant.now()))
implicit val instantAsLongInjection: Injection[Instant, Long] =
Injection(_.getEpochSecond, Instant.ofEpochSecond)
implicit val arbInstant: Arbitrary[Instant] = Arbitrary(Gen.choose[Instant](Instant.EPOCH, Instant.now))
}

test("select('a < 'b, 'a <= 'b, 'a > 'b, 'a >= 'b)") {
Expand Down

0 comments on commit d5a966b

Please sign in to comment.