diff --git a/core/src/test/scala/org/apache/spark/sql/test/generator/ColumnValueGenerator.scala b/core/src/test/scala/org/apache/spark/sql/test/generator/ColumnValueGenerator.scala index 9cf36ecbb7..03cbf0d69f 100644 --- a/core/src/test/scala/org/apache/spark/sql/test/generator/ColumnValueGenerator.scala +++ b/core/src/test/scala/org/apache/spark/sql/test/generator/ColumnValueGenerator.scala @@ -236,7 +236,7 @@ case class ColumnValueGenerator(dataType: ReflectedDataType, generatedRandomValues = if (generateUnique) { assert(n <= rangeSize, "random generator cannot generate unique value less than available") val set: mutable.Set[Any] = mutable.HashSet.empty[Any] - set += specialBound.map(TestDataGenerator.hash) + set ++= specialBound.map(TestDataGenerator.hash) (0L until n - specialBound.size).map { _ => randomUniqueValue(r, set) }.toList ++ specialBound @@ -245,9 +245,16 @@ case class ColumnValueGenerator(dataType: ReflectedDataType, randomValue(r) }.toList ++ specialBound } + + val expectedGeneratedRandomValuesLen = if (generateUnique) { + generatedRandomValues.toSet.size + } else { + generatedRandomValues.size + } + assert( - generatedRandomValues.size >= n, - s"Generate values size=$generatedRandomValues less than n=$n" + expectedGeneratedRandomValuesLen >= n, + s"Generate values size=$generatedRandomValues less than n=$n on datatype $dataType" ) curPos = 0 }