adding unique indices test for batch write (#1014)
Showing 9 changed files with 309 additions and 96 deletions.
core/src/test/scala/org/apache/spark/sql/insertion/BatchWriteUniqueIndexSuite.scala (80 additions, 0 deletions)
@@ -0,0 +1,80 @@
package org.apache.spark.sql.insertion

import com.pingcap.tikv.meta.TiColumnInfo
import com.pingcap.tikv.row.{Row => TiRow} // assumed alias: tiRowToSparkRow takes a TiRow, which the imports shown in the diff do not cover
import com.pingcap.tispark.datasource.BaseDataSourceTest
import com.pingcap.tispark.utils.TiUtil
import org.apache.spark.sql.Row
import org.apache.spark.sql.test.generator.DataType.ReflectedDataType
import org.apache.spark.sql.test.generator.Schema
import org.apache.spark.sql.test.generator.TestDataGenerator._
class BatchWriteUniqueIndexSuite
    extends BaseDataSourceTest("batch_write_insertion_one_unique_index", "batch_write_test_index")
    with EnumerationUniqueIndexDataTypeTestAction {
  // TODO: support binary insertion.
  override val dataTypes: List[ReflectedDataType] = integers ::: decimals ::: doubles ::: charCharset
  override val unsignedDataTypes: List[ReflectedDataType] = integers ::: decimals ::: doubles
  override val dataTypeTestDir = "batch-write-test-index"
  override val database = "batch_write_test_index"
  override val testDesc = "Test for single PK column and multiple unique index type"

  override def beforeAll(): Unit = {
    super.beforeAll()
    tidbStmt.execute(s"drop database if exists $database")
    tidbStmt.execute(s"create database $database")
  }

  // convert a TiKV row into a Spark Row, column by column
  private def tiRowToSparkRow(row: TiRow, tiColsInfos: java.util.List[TiColumnInfo]) = {
    val sparkRow = new Array[Any](row.fieldCount())
    for (i <- 0 until row.fieldCount()) {
      val colTp = tiColsInfos.get(i).getType
      val colVal = row.get(i, colTp)
      sparkRow(i) = colVal
    }
    Row.fromSeq(sparkRow)
  }

  private def dropAndCreateTbl(schema: Schema): Unit = {
    // drop the table if it exists
    dropTable(schema.tableName)

    // create the table in TiDB first
    jdbcUpdate(schema.toString)
  }

  private def insertAndSelect(schema: Schema): Unit = {
    val tblName = schema.tableName

    val tiTblInfo = getTableInfo(database, tblName)
    val tiColInfos = tiTblInfo.getColumns
    // generate random data
    val rows =
      generateRandomRows(schema, rowCount, r).map(row => tiRowToSparkRow(row, tiColInfos))
    // insert the data into TiKV
    tidbWriteWithTable(rows, TiUtil.getSchemaFromTable(tiTblInfo), tblName)
    // read the data back through TiSpark and compare with a JDBC read from TiDB
    compareTiDBSelectWithJDBCWithTable_V2(tblName = tblName, "col_bigint")
  }

  test("test unique indices cases") {
    val schemas = genSchema(dataTypes, table)

    schemas.foreach { schema =>
      dropAndCreateTbl(schema)
    }

    schemas.foreach { schema =>
      insertAndSelect(schema)
    }
  }

  // intentionally left empty: the trait's generic test() entry point is not used by this suite
  override def test(): Unit = {}

  override def afterAll(): Unit =
    try {
      dropTable()
    } finally {
      super.afterAll()
    }
}
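A note on what `dropAndCreateTbl` executes: `jdbcUpdate(schema.toString)` relies on `Schema` rendering itself as a CREATE TABLE statement. Below is a minimal sketch of the kind of DDL one generated schema might produce; the column names, lengths, and index definition are hypothetical (the real layout comes from `schemaGenerator`), but it shows the `col_bigint` column that `insertAndSelect` later uses as its sort key.

// Hypothetical illustration only, not taken from the commit: schemaGenerator
// decides the real column set, lengths, and unique-key definition.
val exampleDdl: String =
  """CREATE TABLE `batch_write_test_index`.`batch_write_test_index0` (
    |  `col_bigint` bigint not null,
    |  `col_decimal` decimal(38,18) not null,
    |  `col_double` double not null,
    |  `col_varchar` varchar(50),
    |  unique key `uk_0`(`col_varchar`(3))
    |)""".stripMargin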
.../test/scala/org/apache/spark/sql/insertion/EnumerationUniqueIndexDataTypeTestAction.scala (88 additions, 0 deletions)
@@ -0,0 +1,88 @@
package org.apache.spark.sql.insertion

import org.apache.commons.math3.util.Combinations
import org.apache.spark.sql.BaseTestGenerationSpec
import org.apache.spark.sql.test.generator.DataType.{getBaseType, DECIMAL, ReflectedDataType}
import org.apache.spark.sql.test.generator.TestDataGenerator.{getDecimal, getLength, isCharOrBinary, isNumeric, isStringType, isVarString, schemaGenerator}
import org.apache.spark.sql.test.generator._
import org.apache.spark.sql.types.MultiColumnDataTypeTestSpec

import scala.util.Random

trait EnumerationUniqueIndexDataTypeTestAction
    extends MultiColumnDataTypeTestSpec
    with BaseTestGenerationSpec {
  private def genIndex(dataTypes: List[ReflectedDataType], r: Random): List[Index] = {
    val size = dataTypes.length
    // first, enumerate all possible keys over one or two columns
    val keyList = scala.collection.mutable.ListBuffer.empty[Key]
    for (i <- 1 until 3) {
      val combination = new Combinations(size, i)
      val iterator = combination.iterator()
      while (iterator.hasNext) {
        val intArray = iterator.next()
        val indexColumnList = scala.collection.mutable.ListBuffer.empty[IndexColumn]
        // an index may span multiple columns
        for (j <- 0 until intArray.length) {
          // shift the column id by one, since id 1 is reserved for the primary key
          if (isStringType(dataTypes(intArray(j)))) {
            indexColumnList += PrefixColumn(intArray(j) + 1, r.nextInt(4) + 2)
          } else {
            indexColumnList += DefaultColumn(intArray(j) + 1)
          }
        }

        keyList += Key(indexColumnList.toList)
      }
    }

    keyList.toList
  }

  def genLen(dataType: ReflectedDataType): String = {
    val baseType = getBaseType(dataType)
    val length = getLength(baseType)
    dataType match {
      case DECIMAL                       => s"$length,${getDecimal(baseType)}"
      case _ if isVarString(dataType)    => s"$length"
      case _ if isCharOrBinary(dataType) => "10"
      case _                             => ""
    }
  }

  // generates one schema per candidate index; each schema carries exactly one unique index
  def genSchema(dataTypes: List[ReflectedDataType], tablePrefix: String): List[Schema] = {
    val indices = genIndex(dataTypes, r)

    val dataTypesWithDescription = dataTypes.map { dataType =>
      val len = genLen(dataType)
      if (isNumeric(dataType)) {
        (dataType, len, "not null")
      } else {
        (dataType, len, "")
      }
    }

    indices.zipWithIndex.map { index =>
      schemaGenerator(
        database,
        tablePrefix + index._2,
        r,
        dataTypesWithDescription,
        List(index._1)
      )
    }
  }

  private def toString(dataTypes: Seq[String]): String = dataTypes.hashCode().toString

  override val rowCount = 10

  override def getTableName(dataTypes: String*): String = s"test_${toString(dataTypes)}"

  override def getTableNameWithDesc(desc: String, dataTypes: String*): String =
    s"test_${desc}_${toString(dataTypes)}"

  override def getIndexName(dataTypes: String*): String = s"idx_${toString(dataTypes)}"
}
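The enumeration in `genIndex` is easy to sanity-check: `Combinations(n, k)` from commons-math3 iterates over every k-element subset of column offsets, so with four data types and k ranging over 1 and 2, `genIndex` would emit C(4,1) + C(4,2) = 4 + 6 = 10 candidate keys. A self-contained sketch of that behavior follows; the object name and the `size` value are illustrative, not part of the commit.

import org.apache.commons.math3.util.Combinations

// Prints every 1- and 2-column offset combination for a hypothetical 4-column table:
// (0) (1) (2) (3), then (0, 1) (0, 2) (0, 3) (1, 2) (1, 3) (2, 3), i.e. 10 keys in total.
object CombinationsDemo extends App {
  val size = 4
  for (k <- 1 until 3) {
    val iterator = new Combinations(size, k).iterator()
    while (iterator.hasNext) {
      // each element is an int[] of zero-based column offsets
      println(iterator.next().mkString("(", ", ", ")"))
    }
  }
}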