From 1bec59b8c563358fe94806066ddf485be439dc94 Mon Sep 17 00:00:00 2001 From: Grigory Pomadchin Date: Fri, 24 Feb 2023 20:03:23 -0500 Subject: [PATCH] Fix unbase64 failing tests --- .../functions/NonAggregateFunctionsTests.scala | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/dataset/src/test/scala/frameless/functions/NonAggregateFunctionsTests.scala b/dataset/src/test/scala/frameless/functions/NonAggregateFunctionsTests.scala index 9a9f892f..470d58e5 100644 --- a/dataset/src/test/scala/frameless/functions/NonAggregateFunctionsTests.scala +++ b/dataset/src/test/scala/frameless/functions/NonAggregateFunctionsTests.scala @@ -2,6 +2,8 @@ package frameless package functions import java.io.File +import java.util.Base64 +import java.nio.charset.StandardCharsets import frameless.functions.nonAggregate._ import org.apache.commons.io.FileUtils @@ -14,7 +16,6 @@ import scala.annotation.nowarn class NonAggregateFunctionsTests extends TypedDatasetSuite { val testTempFiles = "target/testoutput" - object NonNegativeGenerators { val doubleGen = for { s <- Gen.chooseNum(1, Int.MaxValue) @@ -37,6 +38,12 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { implicit val arbByte: Arbitrary[Byte] = Arbitrary(byteGen) } + private val base64Encoder = Base64.getEncoder + private def base64X1String(x1: X1[String]): X1[String] = { + def base64(str: String): String = base64Encoder.encodeToString(str.getBytes(StandardCharsets.UTF_8)) + x1.copy(a = base64(x1.a)) + } + override def afterAll(): Unit = { FileUtils.deleteDirectory(new File(testTempFiles)) super.afterAll() @@ -1610,13 +1617,14 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { import spark.implicits._ def prop(values: List[X1[String]])(implicit encX1: Encoder[X1[String]]) = { - val cDS = session.createDataset(values) + val valuesBase64 = values.map(base64X1String) + val cDS = session.createDataset(valuesBase64) val resCompare = cDS .select(sparkFunctions.unbase64(cDS("a"))) .map(_.getAs[Array[Byte]](0)) .collect().toList - val typedDS = TypedDataset.create(values) + val typedDS = TypedDataset.create(valuesBase64) val res = typedDS .select(unbase64(typedDS('a))) .collect()