Skip to content

Commit

Permalink
Fix unbase64 failing tests
Browse files Browse the repository at this point in the history
  • Loading branch information
pomadchin committed Feb 25, 2023
1 parent 207a118 commit 1bec59b
Showing 1 changed file with 11 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package frameless
package functions

import java.io.File
import java.util.Base64
import java.nio.charset.StandardCharsets

import frameless.functions.nonAggregate._
import org.apache.commons.io.FileUtils
Expand All @@ -14,7 +16,6 @@ import scala.annotation.nowarn
class NonAggregateFunctionsTests extends TypedDatasetSuite {
val testTempFiles = "target/testoutput"


object NonNegativeGenerators {
val doubleGen = for {
s <- Gen.chooseNum(1, Int.MaxValue)
Expand All @@ -37,6 +38,12 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
implicit val arbByte: Arbitrary[Byte] = Arbitrary(byteGen)
}

private val base64Encoder = Base64.getEncoder
private def base64X1String(x1: X1[String]): X1[String] = {
def base64(str: String): String = base64Encoder.encodeToString(str.getBytes(StandardCharsets.UTF_8))
x1.copy(a = base64(x1.a))
}

override def afterAll(): Unit = {
FileUtils.deleteDirectory(new File(testTempFiles))
super.afterAll()
Expand Down Expand Up @@ -1610,13 +1617,14 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
import spark.implicits._

def prop(values: List[X1[String]])(implicit encX1: Encoder[X1[String]]) = {
val cDS = session.createDataset(values)
val valuesBase64 = values.map(base64X1String)
val cDS = session.createDataset(valuesBase64)
val resCompare = cDS
.select(sparkFunctions.unbase64(cDS("a")))
.map(_.getAs[Array[Byte]](0))
.collect().toList

val typedDS = TypedDataset.create(values)
val typedDS = TypedDataset.create(valuesBase64)
val res = typedDS
.select(unbase64(typedDS('a)))
.collect()
Expand Down

0 comments on commit 1bec59b

Please sign in to comment.