adding unique indices test for batch write (#1014)
zhexuany authored Aug 17, 2019
1 parent 9551f7d commit e7d51f5
Showing 9 changed files with 309 additions and 96 deletions.

BaseDataSourceTest.scala
@@ -39,19 +39,29 @@ class BaseDataSourceTest(val table: String,
     TiSession.clearCache()
   }
 
-  protected def tidbWrite(rows: List[Row],
-                          schema: StructType,
-                          param: Option[Map[String, String]] = None): Unit = {
+  protected def dropTable(tblName: String): Unit = {
+    jdbcUpdate(s"drop table if exists `$database`.`$tblName`")
+  }
+
+  protected def tidbWriteWithTable(rows: List[Row],
+                                   schema: StructType,
+                                   tblName: String,
+                                   param: Option[Map[String, String]] = None): Unit = {
     val data: RDD[Row] = sc.makeRDD(rows)
     val df = sqlContext.createDataFrame(data, schema)
     df.write
       .format("tidb")
       .options(tidbOptions ++ param.getOrElse(Map.empty))
       .option("database", database)
-      .option("table", table)
+      .option("table", tblName)
       .mode("append")
       .save()
   }
+
+  protected def tidbWrite(rows: List[Row],
+                          schema: StructType,
+                          param: Option[Map[String, String]] = None): Unit = {
+    tidbWriteWithTable(rows, schema, table, param)
+  }
 
   protected def jdbcWrite(rows: List[Row],
                           schema: StructType,
@@ -67,23 +77,31 @@ class BaseDataSourceTest(val table: String,
       .save()
   }
 
-  protected def testTiDBSelect(expectedAnswer: Seq[Row],
-                               sortCol: String = "i",
-                               selectCol: String = null): Unit = {
+  protected def testTiDBSelectWithTable(expectedAnswer: Seq[Row],
+                                        sortCol: String = "i",
+                                        selectCol: String = null,
+                                        tableName: String): Unit = {
     // check data source result & expected answer
-    var df = queryDatasourceTiDB(sortCol)
+    var df = queryDatasourceTiDBWithTable(sortCol, tableName)
     if (selectCol != null) {
       df = df.select(selectCol)
     }
     checkAnswer(df, expectedAnswer)
 
     // check table scan
-    var df2 = queryDatasourceTableScan(sortCol)
+    var df2 = queryDatasourceTableScanWithTable(sortCol, tableName)
     if (selectCol != null) {
       df2 = df2.select(selectCol)
     }
     checkAnswer(df2, expectedAnswer)
   }
+
+  protected def testTiDBSelect(expectedAnswer: Seq[Row],
+                               sortCol: String = "i",
+                               selectCol: String = null): Unit = {
+    testTiDBSelectWithTable(expectedAnswer, sortCol, selectCol, table)
+  }
 
   protected def compareTiDBWriteFailureWithJDBC(
       data: List[Row],
@@ -175,17 +193,26 @@ class BaseDataSourceTest(val table: String,
 
   }
 
-  protected def compareTiDBSelectWithJDBC_V2(sortCol: String = "i"): Unit = {
-    val sql = s"select * from $dbtable order by $sortCol"
+  protected def compareTiDBSelectWithJDBCWithTable_V2(tblName: String,
+                                                      sortCol: String = "i"): Unit = {
+    val sql = s"select * from `$database`.`$tblName` order by $sortCol"
 
     // check jdbc result & data source result
     val jdbcResult = queryTiDBViaJDBC(sql)
-    val df = queryDatasourceTiDB(sortCol)
+    val df = queryDatasourceTiDBWithTable(sortCol, tableName = tblName)
     val tidbResult = seqRowToList(df.collect(), df.schema)
 
-    assert(
-      compSqlResult(sql, jdbcResult, tidbResult, checkLimit = false)
-    )
+    // print some context on mismatch before failing the assertion
+    val isEqual = compResult(jdbcResult, tidbResult)
+    if (!isEqual) {
+      println(s"failed on $tblName")
+      println(tidbResult)
+    }
+    assert(isEqual)
+  }
+
+  protected def compareTiDBSelectWithJDBC_V2(sortCol: String = "i"): Unit = {
+    compareTiDBSelectWithJDBCWithTable_V2(table, sortCol)
   }
 
   private def seqRowToList(rows: Seq[Row], schema: StructType): List[List[Any]] =
@@ -203,26 +230,33 @@ class BaseDataSourceTest(val table: String,
       })
       .toList
 
-  protected def queryDatasourceTableScan(sortCol: String): DataFrame = {
+  protected def queryDatasourceTableScanWithTable(sortCol: String, tblName: String): DataFrame = {
     sqlContext.read
       .format("tidb")
       .options(tidbOptions)
       .option("database", database)
-      .option("table", table)
+      .option("table", tblName)
       .option(TiConfigConst.ALLOW_INDEX_READ, "false")
       .load()
       .sort(sortCol)
   }
 
-  protected def queryDatasourceTiDB(sortCol: String): DataFrame =
+  protected def queryDatasourceTableScan(sortCol: String): DataFrame = {
+    queryDatasourceTableScanWithTable(sortCol, table)
+  }
+
+  protected def queryDatasourceTiDBWithTable(sortCol: String, tableName: String): DataFrame =
     sqlContext.read
       .format("tidb")
      .options(tidbOptions)
       .option("database", database)
-      .option("table", table)
+      .option("table", tableName)
       .load()
       .sort(sortCol)
 
+  protected def queryDatasourceTiDB(sortCol: String): DataFrame =
+    queryDatasourceTiDBWithTable(sortCol, table)
+
   protected def testTiDBSelectFilter(filter: String, expectedAnswer: Seq[Row]): Unit = {
     val loadedDf = sqlContext.read
       .format("tidb")
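
The net effect in this file: each original single-table helper is now a thin wrapper over a *WithTable variant that takes an explicit table name, so one suite can exercise many generated tables. A minimal usage sketch, assuming it runs inside a BaseDataSourceTest subclass; the table name, rows, and schema below are illustrative and not part of this commit:

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

val tblName = "demo_tbl" // hypothetical table
val schema = StructType(Seq(StructField("i", IntegerType), StructField("s", StringType)))
val rows = List(Row(1, "a"), Row(2, "b"))

dropTable(tblName) // clear leftover state from earlier runs
jdbcUpdate(s"create table `$database`.`$tblName`(i int, s varchar(16))")
tidbWriteWithTable(rows, schema, tblName) // append via the tidb data source
// verify through both an index-capable read and a forced table scan
testTiDBSelectWithTable(rows, sortCol = "i", selectCol = null, tableName = tblName)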

BatchWriteUniqueIndexSuite.scala (new file)
@@ -0,0 +1,80 @@
package org.apache.spark.sql.insertion

import com.pingcap.tikv.meta.TiColumnInfo
import com.pingcap.tikv.row.{Row => TiRow}
import com.pingcap.tispark.datasource.BaseDataSourceTest
import com.pingcap.tispark.utils.TiUtil
import org.apache.spark.sql.Row
import org.apache.spark.sql.test.generator.DataType.ReflectedDataType
import org.apache.spark.sql.test.generator.Schema
import org.apache.spark.sql.test.generator.TestDataGenerator._

class BatchWriteUniqueIndexSuite
    extends BaseDataSourceTest("batch_write_insertion_one_unique_index", "batch_write_test_index")
    with EnumerationUniqueIndexDataTypeTestAction {
  // TODO: support binary insertion.
  override val dataTypes: List[ReflectedDataType] = integers ::: decimals ::: doubles ::: charCharset
  override val unsignedDataTypes: List[ReflectedDataType] = integers ::: decimals ::: doubles
  override val dataTypeTestDir = "batch-write-test-index"
  override val database = "batch_write_test_index"
  override val testDesc = "Test for single PK column and multiple unique index type"

  override def beforeAll(): Unit = {
    super.beforeAll()
    tidbStmt.execute(s"drop database if exists $database")
    tidbStmt.execute(s"create database $database")
  }

  private def tiRowToSparkRow(row: TiRow, tiColsInfos: java.util.List[TiColumnInfo]) = {
    val sparkRow = new Array[Any](row.fieldCount())
    for (i <- 0 until row.fieldCount()) {
      val colTp = tiColsInfos.get(i).getType
      val colVal = row.get(i, colTp)
      sparkRow(i) = colVal
    }
    Row.fromSeq(sparkRow)
  }

  private def dropAndCreateTbl(schema: Schema): Unit = {
    // drop the table if it exists
    dropTable(schema.tableName)

    // create the table in TiDB first
    jdbcUpdate(schema.toString)
  }

  private def insertAndSelect(schema: Schema): Unit = {
    val tblName = schema.tableName

    val tiTblInfo = getTableInfo(database, tblName)
    val tiColInfos = tiTblInfo.getColumns
    // generate random rows matching the table's schema
    val rows =
      generateRandomRows(schema, rowCount, r).map(row => tiRowToSparkRow(row, tiColInfos))
    // insert the data into TiKV
    tidbWriteWithTable(rows, TiUtil.getSchemaFromTable(tiTblInfo), tblName)
    // select the data back and compare with what TiDB returns over JDBC
    compareTiDBSelectWithJDBCWithTable_V2(tblName = tblName, "col_bigint")
  }

  test("test unique indices cases") {
    val schemas = genSchema(dataTypes, table)

    schemas.foreach { schema =>
      dropAndCreateTbl(schema)
    }

    schemas.foreach { schema =>
      insertAndSelect(schema)
    }
  }

  // intentionally a no-op: this suite registers its own test above
  override def test(): Unit = {}

  override def afterAll(): Unit =
    try {
      dropTable()
    } finally {
      super.afterAll()
    }
}
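
For orientation, schema.toString is assumed to render the CREATE TABLE statement that jdbcUpdate executes in dropAndCreateTbl. A sketch of the shape one generated table might take; the column types, lengths, and the chosen unique key are invented, and only the database name, the table-name prefix, and col_bigint appear in this commit:

// hypothetical output of schema.toString for the first generated schema
val exampleDdl =
  """CREATE TABLE `batch_write_test_index`.`batch_write_insertion_one_unique_index0` (
    |  `col_bigint` bigint not null,
    |  `col_decimal` decimal(38,18) not null,
    |  `col_varchar` varchar(50),
    |  unique key(`col_varchar`(3))
    |)""".stripMargin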

EnumerationUniqueIndexDataTypeTestAction.scala (new file)
@@ -0,0 +1,88 @@
package org.apache.spark.sql.insertion

import org.apache.commons.math3.util.Combinations
import org.apache.spark.sql.BaseTestGenerationSpec
import org.apache.spark.sql.test.generator.DataType.{getBaseType, DECIMAL, ReflectedDataType}
import org.apache.spark.sql.test.generator.TestDataGenerator.{getDecimal, getLength, isCharOrBinary, isNumeric, isStringType, isVarString, schemaGenerator}
import org.apache.spark.sql.test.generator._
import org.apache.spark.sql.types.MultiColumnDataTypeTestSpec

import scala.util.Random

trait EnumerationUniqueIndexDataTypeTestAction
    extends MultiColumnDataTypeTestSpec
    with BaseTestGenerationSpec {
  private def genIndex(dataTypes: List[ReflectedDataType], r: Random): List[Index] = {
    val size = dataTypes.length
    // first, enumerate all candidate keys: every 1- and 2-column combination
    val keyList = scala.collection.mutable.ListBuffer.empty[Key]
    for (i <- 1 until 3) {
      val combination = new Combinations(size, i)
      val iterator = combination.iterator()
      while (iterator.hasNext) {
        val intArray = iterator.next()
        val indexColumnList = scala.collection.mutable.ListBuffer.empty[IndexColumn]
        // an index may cover multiple columns
        for (j <- 0 until intArray.length) {
          // offset the column index by one, since id 1 is reserved for the primary key
          if (isStringType(dataTypes(intArray(j)))) {
            // string columns are indexed by a random prefix of length 2 to 5
            indexColumnList += PrefixColumn(intArray(j) + 1, r.nextInt(4) + 2)
          } else {
            indexColumnList += DefaultColumn(intArray(j) + 1)
          }
        }

        keyList += Key(indexColumnList.toList)
      }
    }

    keyList.toList
  }

  // renders the length/precision suffix for a column definition of the given type
  def genLen(dataType: ReflectedDataType): String = {
    val baseType = getBaseType(dataType)
    val length = getLength(baseType)
    dataType match {
      case DECIMAL => s"$length,${getDecimal(baseType)}"
      case _ if isVarString(dataType) => s"$length"
      case _ if isCharOrBinary(dataType) => "10"
      case _ => ""
    }
  }

  // this generates schemas with exactly one unique index each
  def genSchema(dataTypes: List[ReflectedDataType], tablePrefix: String): List[Schema] = {
    val indices = genIndex(dataTypes, r)

    val dataTypesWithDescription = dataTypes.map { dataType =>
      val len = genLen(dataType)
      if (isNumeric(dataType)) {
        (dataType, len, "not null")
      } else {
        (dataType, len, "")
      }
    }

    indices.zipWithIndex.map {
      case (key, i) =>
        schemaGenerator(
          database,
          tablePrefix + i,
          r,
          dataTypesWithDescription,
          List(key)
        )
    }
  }

  // derive a deterministic table-name suffix from the data-type names
  private def toString(dataTypes: Seq[String]): String = dataTypes.hashCode().toString

  override val rowCount = 10

  override def getTableName(dataTypes: String*): String = s"test_${toString(dataTypes)}"

  override def getTableNameWithDesc(desc: String, dataTypes: String*): String =
    s"test_${desc}_${toString(dataTypes)}"

  override def getIndexName(dataTypes: String*): String = s"idx_${toString(dataTypes)}"
}
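
Because genIndex enumerates every one- and two-column combination (i <- 1 until 3), a run over n data types yields C(n,1) + C(n,2) keys, and genSchema emits one table per key. A self-contained sketch of that enumeration using the same commons-math3 Combinations class; the list size of 4 is an arbitrary example:

import org.apache.commons.math3.util.Combinations

import scala.collection.JavaConverters._

object IndexEnumerationDemo extends App {
  val size = 4 // pretend there are four candidate column types
  // mirror genIndex: key widths 1 and 2
  val keys = (1 until 3).flatMap { width =>
    new Combinations(size, width).iterator().asScala.map(_.toVector)
  }
  keys.foreach(println) // Vector(0), Vector(1), ..., Vector(2, 3)
  println(s"${keys.size} keys -> ${keys.size} generated tables") // 4 + 6 = 10
}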