
Add multiple-column PK tests #970

Merged · 14 commits · Aug 6, 2019
7 changes: 4 additions & 3 deletions .ci/integration_test.groovy
@@ -90,12 +90,13 @@ def call(ghprbActualCommit, ghprbCommentBody, ghprbPullId, ghprbPullTitle, ghprb
sh """
cp -R /home/jenkins/git/tispark/. ./
git checkout -f ${ghprbActualCommit}
find core/src -name '*Suite*' > test
find core/src -name '*Suite*' | grep -v 'MultiColumnPKDataTypeSuite' > test
shuf test -o test2
mv test2 test
find core/src -name '*MultiColumnPKDataTypeSuite*' >> test
sed -i 's/core\\/src\\/test\\/scala\\///g' test
sed -i 's/\\//\\./g' test
sed -i 's/\\.scala//g' test
shuf test -o test2
mv test2 test
split test -n r/$PARALLEL_NUMBER test_unit_ -a 2 --numeric-suffixes=1
"""

@@ -26,7 +26,6 @@ class BaseDataSourceTest(val table: String,
override def beforeAll(): Unit = {
enableTidbConfigPropertiesInjectedToSpark = _enableTidbConfigPropertiesInjectedToSpark
super.beforeAllWithoutLoadData()
tidbStmt = tidbConn.createStatement()
initializeTimeZone()
}

@@ -17,19 +17,46 @@

package org.apache.spark.sql

import org.apache.spark.sql.test.generator.DataType.ReflectedDataType

trait BaseTestGenerationSpec {

protected val rowCount: Int

protected val preDescription: String = "Generating Data for "

protected var cols: List[ReflectedDataType] = List.empty[ReflectedDataType]

def getTableName(dataTypes: String*): String

def getTableNameWithDesc(desc: String, dataTypes: String*): String

def getColumnName(dataType: String): String = s"col_$dataType"

def getColumnNameByOffset(offset: Int): String = {
assert(
cols.size > offset,
"column length incorrect, maybe `cols` is not initialized correctly?"
)
val dataType = cols(offset)
val suffix = if (cols.count(_ == dataType) > 1) {
var cnt = 0
for (i <- 0 until offset) {
if (cols(i) == dataType) {
cnt += 1
}
}
s"$cnt"
} else {
""
}
s"${getColumnName(dataType.toString)}$suffix"
}

def getIndexName(dataTypes: String*): String =
s"idx_${dataTypes.map(getColumnName).mkString("_")}"

def getIndexNameByOffset(offsets: Int*): String =
s"idx_${offsets.map(getColumnNameByOffset).mkString("_")}"

}
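As an illustration of the naming scheme this trait produces, here is a self-contained sketch that uses plain strings in place of ReflectedDataType (the type list is invented; the take/count formulation is equivalent to the counting loop above):

object NamingSketch extends App {
  val cols = List("int", "varchar", "int")

  def columnName(offset: Int): String = {
    val dt = cols(offset)
    // duplicated data types get a 0-based occurrence suffix; unique types get none
    val suffix =
      if (cols.count(_ == dt) > 1) s"${cols.take(offset).count(_ == dt)}" else ""
    s"col_$dt$suffix"
  }

  println(cols.indices.map(columnName).mkString(", ")) // col_int0, col_varchar, col_int1
}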
@@ -31,8 +31,6 @@ import scala.collection.mutable.ArrayBuffer

class BaseTiSparkTest extends QueryTest with SharedSQLContext {

protected var tidbStmt: Statement = _

private val defaultTestDatabases: Seq[String] = Seq("tispark_test")

protected var tableNames: Seq[String] = _
@@ -166,12 +164,6 @@ class BaseTiSparkTest extends QueryTest with SharedSQLContext {
loadTestData()
}

protected def initializeTimeZone(): Unit = {
tidbStmt = tidbConn.createStatement()
// Set default time zone to GMT-7
tidbStmt.execute(s"SET time_zone = '$timeZoneOffset'")
}

Contributor: why do we remove initializeTimeZone?
Contributor: Never mind, I just thought you moved it to another file.
Author: added in SharedSQLContext

protected case class TestTables(dbName: String, tables: String*)

protected def refreshConnections(testTables: TestTables, isHiveEnabled: Boolean = false): Unit = {
@@ -95,6 +95,7 @@ class CatalogTestSuite extends BaseTiSparkTest {
spark.sql("desc extended full_data_type_table").explain()
spark.sql("desc extended full_data_type_table").show(200, truncate = false)
spark.sql("desc formatted full_data_type_table").show(200, truncate = false)
refreshConnections(true)
setCurrentDatabase("default")
spark.sql("drop table if exists t")
spark.sql("create table t(a int)")
@@ -45,6 +45,8 @@ trait SharedSQLContext extends SparkFunSuite with Eventually with BeforeAndAfter

protected def ti: TiContext = SharedSQLContext.ti

protected def tidbStmt: Statement = SharedSQLContext.tidbStmt

protected def tidbConn: Connection = SharedSQLContext.tidbConn

protected def tidbOptions: Map[String, String] = SharedSQLContext.tidbOptions
@@ -67,6 +69,8 @@

protected def initStatistics(): Unit = SharedSQLContext.initStatistics()

protected def initializeTimeZone(): Unit = SharedSQLContext.initializeTimeZone()

protected def defaultTimeZone: TimeZone = SharedSQLContext.timeZone

protected def refreshConnections(): Unit = SharedSQLContext.refreshConnections(false)
@@ -232,6 +236,12 @@ object SharedSQLContext extends Logging {
logger.info("Analyzing table finished.")
}

protected def initializeTimeZone(): Unit = {
_statement = _tidbConnection.createStatement()
// Set default time zone to GMT-7
_statement.execute(s"SET time_zone = '$timeZoneOffset'")
}
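For context, a minimal sketch of the sharing pattern this hunk adopts: a singleton owns the JDBC statement and a trait re-exports it, so every suite talks to the same session (connection wiring is omitted, and the "-7:00" offset literal is an assumption; the real code derives it from timeZoneOffset):

import java.sql.{Connection, Statement}

object SharedStateSketch {
  var connection: Connection = _ // created once per JVM in the real code
  var statement: Statement = _
  val timeZoneOffset = "-7:00"   // assumed literal

  def initializeTimeZone(): Unit = {
    statement = connection.createStatement()
    statement.execute(s"SET time_zone = '$timeZoneOffset'")
  }
}

trait SharedStateAccess {
  // suites see stable protected members while the singleton owns the resources
  protected def tidbStmt: Statement = SharedStateSketch.statement
  protected def initializeTimeZone(): Unit = SharedStateSketch.initializeTimeZone()
}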

protected def loadSQLFile(directory: String, file: String): Unit = {
val fullFileName = s"$directory/$file.sql"
try {
@@ -241,7 +251,7 @@
val queryString = source.mkString
source.close()
_tidbConnection.setCatalog("mysql")
_statement = _tidbConnection.createStatement()
initializeTimeZone()
_statement.execute(queryString)
logger.info(s"Load $fullFileName successfully.")
} catch {
@@ -273,7 +283,8 @@
jdbcUrl =
s"jdbc:mysql://address=(protocol=tcp)(host=$tidbAddr)(port=$tidbPort)/?user=$tidbUser&password=$tidbPassword" +
s"&useUnicode=true&characterEncoding=UTF-8&zeroDateTimeBehavior=convertToNull&useSSL=false" +
s"&rewriteBatchedStatements=true&autoReconnect=true&failOverReadOnly=false&maxReconnects=10"
s"&rewriteBatchedStatements=true&autoReconnect=true&failOverReadOnly=false&maxReconnects=10" +
s"&allowMultiQueries=true&serverTimezone=${timeZone.getDisplayName}"

_tidbConnection = TiDBUtils.createConnectionFactory(jdbcUrl)()
_statement = _tidbConnection.createStatement()
29 changes: 14 additions & 15 deletions core/src/test/scala/org/apache/spark/sql/test/generator/Data.scala
@@ -46,21 +46,20 @@ case class Data(schema: Schema, data: List[TiRow], directory: String) {
case null => null
case _: Boolean => value.toString
case _: Number => value.toString
case _: Array[Byte] =>
s"X\'${value
.asInstanceOf[Array[Byte]]
.map { b =>
String.format("%02x", new java.lang.Byte(b))
}
.mkString}\'"
case _: Array[Boolean] =>
s"b\'${value
.asInstanceOf[Array[Boolean]]
.map {
case true => "1"
case false => "0"
}
.mkString}\'"
case arr: Array[Byte] =>
s"X\'${arr.map { b =>
String.format("%02x", new java.lang.Byte(b))
}.mkString}\'"
case arr: Array[Boolean] =>
s"b\'${arr.map {
case true => "1"
case false => "0"
}.mkString}\'"
case ts: java.sql.Timestamp =>
// convert to Timestamp output with current TimeZone
val zonedDateTime = ts.toLocalDateTime.atZone(java.util.TimeZone.getDefault.toZoneId)
val milliseconds = zonedDateTime.toEpochSecond * 1000L + zonedDateTime.getNano / 1000000
s"\'${new java.sql.Timestamp(milliseconds)}\'"
case _ => s"\'$value\'"
}
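A self-contained sketch of the literal rewrites above (the values are arbitrary, and the timestamp line's output depends on the JVM default zone):

import java.sql.Timestamp
import java.util.TimeZone

object LiteralSketch extends App {
  // hex literal for binary columns, e.g. X'0a1b'
  val bytes = Array[Byte](10, 27)
  println(s"X'${bytes.map(b => String.format("%02x", Byte.box(b))).mkString}'")

  // bit literal for BIT columns, e.g. b'10'
  val bits = Array(true, false)
  println(s"b'${bits.map(b => if (b) "1" else "0").mkString}'")

  // timestamp re-anchored to the current default zone before printing
  val ts = new Timestamp(1565100000000L)
  val zoned = ts.toLocalDateTime.atZone(TimeZone.getDefault.toZoneId)
  val millis = zoned.toEpochSecond * 1000L + zoned.getNano / 1000000
  println(s"'${new Timestamp(millis)}'")
}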

@@ -39,13 +39,15 @@ case class ColumnInfo(columnName: String,

val isPrimaryKey: Boolean = desc.contains("primary key")
val nullable: Boolean = !isPrimaryKey && !desc.contains("not null")
val unsigned: Boolean = desc.contains("unsigned")
val noDefault: Boolean = !desc.contains("default")

private val breakDown = desc.split(" ")
val unsigned: Boolean = breakDown.contains("unsigned")
val noDefault: Boolean = !breakDown.contains("default")
val isUnique: Boolean = breakDown.contains("unique")
val default: String = {
if (noDefault) {
null
} else {
val breakDown = desc.split(" ")
val idx = breakDown.indexOf("default")
assert(idx >= 0)
if (idx == breakDown.length - 1) {
@@ -76,11 +78,19 @@
}

val generator: ValueGenerator =
ValueGenerator(dataType, len, decimal, nullable, unsigned, noDefault, default, isPrimaryKey)
ValueGenerator(
dataType,
len,
decimal,
nullable,
unsigned,
noDefault,
default,
isPrimaryKey,
isUnique
)

override def toString: String = {
"`" + columnName + "` " + s"${generator.toString}"
}
override def toString: String = s"`$columnName` ${generator.toString}"
}
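The move from desc.contains(...) to token matching on desc.split(" ") avoids substring false positives; a small sketch with invented descriptions:

object DescFlagsSketch extends App {
  def flags(desc: String): (Boolean, Boolean, Boolean) = {
    val tokens = desc.split(" ")
    (tokens.contains("unsigned"), tokens.contains("unique"), tokens.contains("default"))
  }

  println(flags("int unsigned not null"))             // (true,false,false)
  // a substring check would wrongly report "unsigned" for this one:
  println(flags("varchar default some_unsigned_tag")) // (false,false,true)
}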

case class IndexColumnInfo(column: String, length: Integer) {
@@ -47,13 +47,11 @@ case class Schema(database: String,

assert(indexInfo.count(_.isPrimary) <= 1, "more than one primary key exist in schema")

private val pkIndexInfo = indexInfo.filter(_.isPrimary)
private val pkColumnName = if (pkIndexInfo.isEmpty) {
val pkIndexInfo = indexInfo.filter(_.isPrimary)
val pkColumnName: String = if (pkIndexInfo.isEmpty) {
""
} else if (pkIndexInfo.head.indexColumns.size == 1) {
pkIndexInfo.head.indexColumns.head.column
} else {
throw new IllegalArgumentException("Multi-column Primary key/Unique index not supported yet")
pkIndexInfo.head.indexColumns.map(_.column).mkString(",")
}
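With the multi-column case no longer rejected, the comma-joined name can feed a composite key clause; a sketch with invented column names:

object PkClauseSketch extends App {
  val indexColumns = List("col_int", "col_varchar") // invented
  val pkColumnName = indexColumns.mkString(",")
  println(s"PRIMARY KEY ($pkColumnName)") // PRIMARY KEY (col_int,col_varchar)
}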

val columnInfo: List[ColumnInfo] = columnNames.map { col =>
@@ -61,9 +61,11 @@ object TestDataGenerator {
val dateAndDateTime: List[ReflectedDataType] = timestamps ::: dates ::: durations ::: years

val stringAndBinaries: List[ReflectedDataType] = strings ::: binaries
val charCharset: List[ReflectedDataType] = strings ::: texts
val binaryCharset: List[ReflectedDataType] = binaries ::: bytes
// TODO: support enum and set https://github.com/pingcap/tispark/issues/946
// val stringType: List[DataType] = texts ::: strings ::: binaries ::: enums ::: sets
val stringType: List[ReflectedDataType] = texts ::: strings ::: binaries ::: bytes
val stringType: List[ReflectedDataType] = charCharset ::: binaryCharset
val varString: List[ReflectedDataType] = List(VARCHAR, VARBINARY)

val unsignedType: List[ReflectedDataType] = numeric
@@ -100,6 +102,8 @@ object TestDataGenerator {
def isNumeric(dataType: ReflectedDataType): Boolean = numeric.contains(dataType)
def isStringType(dataType: ReflectedDataType): Boolean = stringType.contains(dataType)
def isVarString(dataType: ReflectedDataType): Boolean = varString.contains(dataType)
def isCharCharset(dataType: ReflectedDataType): Boolean = charCharset.contains(dataType)
def isBinaryCharset(dataType: ReflectedDataType): Boolean = binaryCharset.contains(dataType)
def isCharOrBinary(dataType: ReflectedDataType): Boolean = stringAndBinaries.contains(dataType)

def getLength(dataType: TiDataType): Long =
@@ -279,8 +283,11 @@
}

def hash(value: Any): String = value match {
case null => "null"
case b: Array[Byte] => b.mkString("[", ",", "]")
case null => "null"
case b: Array[Byte] => b.mkString("[", ",", "]")
case t: java.sql.Timestamp =>
// timestamp was indexed as Integer when treated as unique key
s"${t.getTime / 1000}"

Contributor: on what situations?
Author: unique key, as commented
case list: List[Any] =>
val ret = StringBuilder.newBuilder
ret ++= "("
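As a side note on the truncated hash function above, a tiny sketch of the second-precision rule for timestamps (the instant is arbitrary):

import java.sql.Timestamp

object HashSketch extends App {
  val t = Timestamp.valueOf("2019-08-06 00:00:00")
  // unique-key comparison keys timestamps by whole seconds
  println(t.getTime / 1000)
}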