Skip to content

Commit

Permalink
Use CREATE TABLE to resolve INSERT columns
Browse files Browse the repository at this point in the history
  • Loading branch information
mpchadwick committed Nov 1, 2020
1 parent dd90c66 commit c621fe3
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 10 deletions.
25 changes: 15 additions & 10 deletions src/processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@ import (
type LineProcessor struct {
Config *Config
Provider ProviderInterface
nextTable string
currentTable sqlparser.Statement
}

var nextTable = ""
var currentTable []string

func NewLineProcessor(c *Config, p ProviderInterface) *LineProcessor {
return &LineProcessor{Config: c, Provider: p, nextTable: ""}
return &LineProcessor{Config: c, Provider: p}
}

func (p LineProcessor) ProcessLine(s string) string {
Expand All @@ -28,21 +29,25 @@ func (p LineProcessor) ProcessLine(s string) string {
}

func (p LineProcessor) findNextTable(s string) {
if len(p.nextTable) > 0 {
if len(nextTable) > 0 {
// TODO: Are we guaranteed this will delimit the end of the CREATE TABLE?
j := strings.Index(s, "/*!40101")
if j == 0 {
stmt, _ := sqlparser.Parse(p.nextTable)
p.currentTable = stmt
p.nextTable = ""
stmt, _ := sqlparser.Parse(nextTable)
currentTable = nil
createTable := stmt.(*sqlparser.CreateTable)
for _, col := range createTable.Columns {
currentTable = append(currentTable, col.Name)
}
nextTable = ""
} else {
p.nextTable += s
nextTable += s
}
}

k := strings.Index(s, "CREATE TABLE")
if k == 0 {
p.nextTable += s
nextTable += s
}
}

Expand All @@ -63,7 +68,7 @@ func (p LineProcessor) processInsert(s string) string {
rows := insert.Rows.(sqlparser.Values)
for _, vt := range rows {
for i, e := range vt {
column := insert.Columns[i].String()
column := currentTable[i]

result, dataType := p.Config.ProcessColumn(table, column)

Expand Down
5 changes: 5 additions & 0 deletions src/processor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@ func TestProcessLine(t *testing.T) {
t.Error("Got bob wanted no bob")
}

processor.ProcessLine("CREATE TABLE `admin_user` (")
processor.ProcessLine(" `user_id` int(10) unsigned NOT NULL AUTO_INCREMENT COMMENT 'User ID'")
processor.ProcessLine(") ENGINE=InnoDB AUTO_INCREMENT=2 DEFAULT CHARSET=utf8 COMMENT='Admin User Table'")
processor.ProcessLine("/*!40101 SET character_set_client = @saved_cs_client */;")

r3 := processor.ProcessLine("INSERT INTO `admin_user` (`user_id`) VALUES (1337);")
if !strings.Contains(r3, "1337") {
t.Error("Got no 1337 wanted 1337")
Expand Down

0 comments on commit c621fe3

Please sign in to comment.