Skip to content

Commit

Permalink
sql reader在构建数据库表名时根据具体的数据库类型构建 (#704)
Browse files Browse the repository at this point in the history
* fix sql invalid in postgres

* fix test

* fix tablename

* add test case

* add comment
  • Loading branch information
xxh2000 authored and wonderflow committed Aug 13, 2018
1 parent 2d7bf6f commit 9732d12
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 7 deletions.
28 changes: 24 additions & 4 deletions reader/sql/sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -1507,7 +1507,7 @@ func (r *Reader) getValidData(connectStr, curDB, matchData, matchStr string,
continue
}

rawSql, err := getRawSqls(queryType, s)
rawSql, err := r.getRawSqls(queryType, s)
if err != nil {
return validData, sqls, err
}
Expand Down Expand Up @@ -1678,14 +1678,34 @@ func (r *Reader) getCheckAll(queryType int) (checkAll bool, err error) {

return true, nil
}
//根据数据库类型返回表名
func getWrappedTableName(dbtype string, table string) (tableName string, err error) {
switch dbtype {
case reader.ModeMySQL:
tableName = "`" + table + "`"
case reader.ModeMSSQL, reader.ModePostgreSQL:
tableName = "\"" + table + "\""
default:
err = fmt.Errorf("%v mode not support in sql reader", dbtype)
}
return
}

// 根据 queryType 获取表中所有记录或者表中所有数据的条数的sql语句
func getRawSqls(queryType int, table string) (sqls string, err error) {
func (r *Reader) getRawSqls(queryType int, table string) (sqls string, err error) {
switch queryType {
case TABLE:
sqls += "Select * From `" + table + "`;"
tableName, err := getWrappedTableName(r.dbtype, table)
if err != nil {
return "", err
}
sqls += "Select * From " + tableName + ";"
case COUNT:
sqls += "Select Count(*) From `" + table + "`;"
tableName, err := getWrappedTableName(r.dbtype, table)
if err != nil {
return "", err
}
sqls += "Select Count(*) From " + tableName + ";"
case DATABASE:
default:
return "", fmt.Errorf("%v queryType is not support get sql now", queryType)
Expand Down
45 changes: 42 additions & 3 deletions reader/sql/sql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1017,9 +1017,24 @@ func Test_getCheckAll(t *testing.T) {
assert.EqualValues(t, test.expRes, checkHistory)
}
}
func Test_getWrappedTableName(t *testing.T) {
dbtype := reader.ModeMySQL
tname, err := getWrappedTableName(dbtype, "my_table")
expRes := "`my_table`"
assert.NoError(t, err)
assert.EqualValues(t, expRes, tname)

dbtype = reader.ModePostgreSQL
tname, err = getWrappedTableName(dbtype, "my_table")
expRes = "\"my_table\""
assert.NoError(t, err)
assert.EqualValues(t, expRes, tname)
}
func Test_getRawSQLs(t *testing.T) {
tests := []struct {
r := &Reader{
dbtype: reader.ModeMySQL,
}
mysqltests := []struct {
queryType int
expSQLs string
}{
Expand All @@ -1037,11 +1052,35 @@ func Test_getRawSQLs(t *testing.T) {
},
}

for _, test := range tests {
sqls, err := getRawSqls(test.queryType, "my_table")
for _, test := range mysqltests {
sqls, err := r.getRawSqls(test.queryType, "my_table")
assert.NoError(t, err)
assert.EqualValues(t, test.expSQLs, sqls)
}
r.dbtype = reader.ModePostgreSQL
pgtests := []struct {
queryType int
expSQLs string
}{
{
queryType: TABLE,
expSQLs: "Select * From \"my_table\";",
},
{
queryType: COUNT,
expSQLs: "Select Count(*) From \"my_table\";",
},
{
queryType: DATABASE,
expSQLs: "",
},
}
for _, test := range pgtests {
sqls, err := r.getRawSqls(test.queryType, "my_table")
assert.NoError(t, err)
assert.EqualValues(t, test.expSQLs, sqls)
}

}

func Test_getConnectStr(t *testing.T) {
Expand Down

0 comments on commit 9732d12

Please sign in to comment.