From bcd6432f8ca6f57a60438c2f846aacf31ae9441b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Chantepie?= Date: Sat, 7 Aug 2021 22:51:53 +0200 Subject: [PATCH] Get nested field for struct column --- .../src/main/scala/frameless/RecordEncoder.scala | 2 +- dataset/src/main/scala/frameless/TypedColumn.scala | 13 +++++++++++++ dataset/src/test/scala/frameless/ColumnTests.scala | 13 +++++++++++++ 3 files changed, 27 insertions(+), 1 deletion(-) diff --git a/dataset/src/main/scala/frameless/RecordEncoder.scala b/dataset/src/main/scala/frameless/RecordEncoder.scala index 0beaba5b..fa5bd0c6 100644 --- a/dataset/src/main/scala/frameless/RecordEncoder.scala +++ b/dataset/src/main/scala/frameless/RecordEncoder.scala @@ -69,7 +69,7 @@ object NewInstanceExprs { Literal.fromObject(()) +: tail.from(exprs) } - implicit def deriveNonUnit[K <: Symbol, V , T <: HList] + implicit def deriveNonUnit[K <: Symbol, V, T <: HList] (implicit notUnit: V =:!= Unit, tail: NewInstanceExprs[T] diff --git a/dataset/src/main/scala/frameless/TypedColumn.scala b/dataset/src/main/scala/frameless/TypedColumn.scala index c96550b6..c83aa550 100644 --- a/dataset/src/main/scala/frameless/TypedColumn.scala +++ b/dataset/src/main/scala/frameless/TypedColumn.scala @@ -30,6 +30,7 @@ sealed class TypedColumn[T, U](expr: Expression)( } override def typed[W, U1: TypedEncoder](c: Column): TypedColumn[W, U1] = c.typedColumn + override def lit[U1: TypedEncoder](c: U1): TypedColumn[T,U1] = flit(c) } @@ -828,6 +829,18 @@ abstract class AbstractTypedColumn[T, U] w1: With.Aux[TT2, W1, W2] ): ThisType[W2, Boolean] = typed(self.untyped.between(lowerBound.untyped, upperBound.untyped)) + + /** + * Returns a nested column matching the field `symbol`. + * + * @param V the type of the nested field + */ + def field[V](symbol: Witness.Lt[Symbol])(implicit + i0: TypedColumn.Exists[U, symbol.T, V], + i1: TypedEncoder[V] + ): ThisType[T, V] = + typed(self.untyped.getField(symbol.value.name)) + } diff --git a/dataset/src/test/scala/frameless/ColumnTests.scala b/dataset/src/test/scala/frameless/ColumnTests.scala index 58ff98da..ba64515f 100644 --- a/dataset/src/test/scala/frameless/ColumnTests.scala +++ b/dataset/src/test/scala/frameless/ColumnTests.scala @@ -425,4 +425,17 @@ class ColumnTests extends TypedDatasetSuite with Matchers { "ds.select(ds('_1).opt.map(x => x))" shouldNot typeCheck "ds.select(ds('_2).opt.map(x => x))" shouldNot typeCheck } + + test("field") { + val ds = TypedDataset.create((1, (2.3F, "a")) :: Nil) + val rs = ds.select(ds('_2).field('_2)).collect().run() + + rs shouldEqual Seq("a") + } + + test("field compiles only for valid field") { + val ds = TypedDataset.create((1, (2.3F, "a")) :: Nil) + + "ds.select(ds('_2).field('_3))" shouldNot typeCheck + } }