Skip to content

Commit

Permalink
fix: invalid sequece (#240)
Browse files Browse the repository at this point in the history
  • Loading branch information
dk-lockdown committed Aug 10, 2022
1 parent 66137ba commit b318070
Show file tree
Hide file tree
Showing 35 changed files with 598 additions and 1,030 deletions.
95 changes: 63 additions & 32 deletions pkg/driver/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,7 @@ func (conn *BackendConnection) WriteComStmtClose(statementID uint32) (err error)
}

// ReadQueryResult gets the result from the last written query.
func (conn *BackendConnection) ReadQueryResult(wantFields bool) (result *mysql.Result, more bool, warnings uint16, err error) {
func (conn *BackendConnection) ReadQueryResult(ctx context.Context, wantFields bool) (result *mysql.Result, more bool, warnings uint16, err error) {
// Get the result.
affectedRows, lastInsertID, colNumber, more, warnings, err := conn.ReadComQueryResponse()
if err != nil {
Expand Down Expand Up @@ -565,8 +565,57 @@ func (conn *BackendConnection) ReadQueryResult(wantFields bool) (result *mysql.R
}
}

result.Rows = mysql.NewRows(conn.Conn, result.Fields)
return
// read each row until EOF or OK packet.
for {
data, err := conn.ReadPacket()
if err != nil {
return nil, false, 0, err
}

if packet.IsEOFPacket(data) {
// Strip the partial Fields before returning.
if !wantFields {
result.Fields = nil
}
result.AffectedRows = uint64(len(result.Rows))

// The deprecated EOF packets change means that this is either an
// EOF packet or an OK packet with the EOF type code.
if conn.capabilities&constant.CapabilityClientDeprecateEOF == 0 {
warnings, more, err = packet.ParseEOFPacket(data)
if err != nil {
return nil, false, 0, err
}
} else {
var statusFlags uint16
_, _, statusFlags, warnings, err = packet.ParseOKPacket(data)
if err != nil {
return nil, false, 0, err
}
more = (statusFlags & constant.ServerMoreResultsExists) != 0
}
return result, more, warnings, nil

} else if packet.IsErrorPacket(data) {
// Error packet.
return nil, false, 0, packet.ParseErrorPacket(data)
}

//// Check we're not over the limit before we add more.
//if len(result.Rows) == maxrows {
// if err := conn.DrainResults(); err != nil {
// return nil, false, 0, err
// }
// return nil, false, 0, err2.NewSQLError(constant.ERMaxRowsExceeded, constant.SSUnknownSQLState, "Row count exceeded %d")
//}

// Regular row.
row, err := conn.ParseRow(ctx, data, result.Fields)
if err != nil {
return nil, false, 0, err
}
result.Rows = append(result.Rows, row)
}
}

func (conn *BackendConnection) ReadComQueryResponse() (affectedRows uint64, lastInsertID uint64, status int, more bool, warnings uint16, err error) {
Expand Down Expand Up @@ -775,24 +824,6 @@ func (conn *BackendConnection) ReadColumnDefinitionType(field *mysql.Field, inde
return nil
}

// DrainResults will read all packets for a result set and ignore them.
func (conn *BackendConnection) DrainResults() error {
for {
data, err := conn.ReadEphemeralPacket()
if err != nil {
return err2.NewSQLError(constant.CRServerLost, constant.SSUnknownSQLState, "%v", err)
}
if packet.IsEOFPacket(data) {
conn.RecycleReadPacket()
return nil
} else if packet.IsErrorPacket(data) {
defer conn.RecycleReadPacket()
return packet.ParseErrorPacket(data)
}
conn.RecycleReadPacket()
}
}

func (conn *BackendConnection) ReadColumnDefinitions() ([]*mysql.Field, error) {
result := make([]*mysql.Field, 0)
i := 0
Expand Down Expand Up @@ -847,15 +878,15 @@ func (conn *BackendConnection) Ping(ctx context.Context) (err error) {
//
// 2. if the server closes the connection when a command is in flight,
// ReadComQueryResponse will fail, and we'll return CRServerLost(2013).
func (conn *BackendConnection) Execute(query string, wantFields bool) (result *mysql.Result, err error) {
result, _, err = conn.ExecuteMulti(query, wantFields)
func (conn *BackendConnection) Execute(ctx context.Context, query string, wantFields bool) (result *mysql.Result, err error) {
result, _, err = conn.ExecuteMulti(ctx, query, wantFields)
return
}

// ExecuteMulti is for fetching multiple results from a multi-statement result.
// It returns an additional 'more' flag. If it is set, you must fetch the additional
// results using ReadQueryResult.
func (conn *BackendConnection) ExecuteMulti(query string, wantFields bool) (result *mysql.Result, more bool, err error) {
func (conn *BackendConnection) ExecuteMulti(ctx context.Context, query string, wantFields bool) (result *mysql.Result, more bool, err error) {
defer func() {
if err != nil {
if sqlerr, ok := err.(*err2.SQLError); ok {
Expand All @@ -869,7 +900,7 @@ func (conn *BackendConnection) ExecuteMulti(query string, wantFields bool) (resu
return nil, false, err
}

result, more, _, err = conn.ReadQueryResult(wantFields)
result, more, _, err = conn.ReadQueryResult(ctx, wantFields)
return
}

Expand All @@ -893,16 +924,16 @@ func (conn *BackendConnection) ExecuteWithWarningCount(ctx context.Context, quer
return nil, 0, err
}

result, _, warnings, err = conn.ReadQueryResult(wantFields)
result, _, warnings, err = conn.ReadQueryResult(ctx, wantFields)
return
}

func (conn *BackendConnection) PrepareExecuteArgs(query string, args []interface{}) (result *mysql.Result, warnings uint16, err error) {
func (conn *BackendConnection) PrepareExecuteArgs(ctx context.Context, query string, args []interface{}) (result *mysql.Result, warnings uint16, err error) {
stmt, err := conn.prepare(query)
if err != nil {
return nil, 0, err
}
return stmt.execArgs(args)
return stmt.execArgs(ctx, args)
}

func (conn *BackendConnection) PrepareQueryArgs(ctx context.Context, query string, args []interface{}) (Result *mysql.Result, warnings uint16, err error) {
Expand All @@ -914,23 +945,23 @@ func (conn *BackendConnection) PrepareQueryArgs(ctx context.Context, query strin
span.RecordError(err)
return nil, 0, err
}
return stmt.queryArgs(args)
return stmt.queryArgs(ctx, args)
}

func (conn *BackendConnection) PrepareExecute(query string, data []byte) (result *mysql.Result, warnings uint16, err error) {
func (conn *BackendConnection) PrepareExecute(ctx context.Context, query string, data []byte) (result *mysql.Result, warnings uint16, err error) {
stmt, err := conn.prepare(query)
if err != nil {
return nil, 0, err
}
return stmt.exec(data)
}

func (conn *BackendConnection) PrepareQuery(query string, data []byte) (Result *mysql.Result, warnings uint16, err error) {
func (conn *BackendConnection) PrepareQuery(ctx context.Context, query string, data []byte) (Result *mysql.Result, warnings uint16, err error) {
stmt, err := conn.prepare(query)
if err != nil {
return nil, 0, err
}
return stmt.query(data)
return stmt.query(ctx, data)
}

func (conn *BackendConnection) prepare(query string) (*BackendStatement, error) {
Expand Down
11 changes: 6 additions & 5 deletions pkg/driver/statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package driver

import (
"context"
"database/sql/driver"
"encoding/binary"
"encoding/json"
Expand Down Expand Up @@ -358,7 +359,7 @@ func (stmt *BackendStatement) writeExecutePacket(args []interface{}) error {
return bc.WritePacket(data[4:])
}

func (stmt *BackendStatement) execArgs(args []interface{}) (*mysql.Result, uint16, error) {
func (stmt *BackendStatement) execArgs(ctx context.Context, args []interface{}) (*mysql.Result, uint16, error) {
nargs := make([]interface{}, len(args))
for i, arg := range args {
var err error
Expand Down Expand Up @@ -395,7 +396,7 @@ func (stmt *BackendStatement) execArgs(args []interface{}) (*mysql.Result, uint1
}, warnings, nil
}

func (stmt *BackendStatement) queryArgs(args []interface{}) (*mysql.Result, uint16, error) {
func (stmt *BackendStatement) queryArgs(ctx context.Context, args []interface{}) (*mysql.Result, uint16, error) {
nargs := make([]interface{}, len(args))
for i, arg := range args {
var err error
Expand All @@ -410,7 +411,7 @@ func (stmt *BackendStatement) queryArgs(args []interface{}) (*mysql.Result, uint
return nil, 0, err
}

result, _, warnings, err := stmt.conn.ReadQueryResult(true)
result, _, warnings, err := stmt.conn.ReadQueryResult(ctx, true)
return result, warnings, err
}

Expand Down Expand Up @@ -457,7 +458,7 @@ func (stmt *BackendStatement) exec(args []byte) (*mysql.Result, uint16, error) {
}, warnings, nil
}

func (stmt *BackendStatement) query(args []byte) (*mysql.Result, uint16, error) {
func (stmt *BackendStatement) query(ctx context.Context, args []byte) (*mysql.Result, uint16, error) {
args[1] = byte(stmt.id)
args[2] = byte(stmt.id >> 8)
args[3] = byte(stmt.id >> 16)
Expand All @@ -471,6 +472,6 @@ func (stmt *BackendStatement) query(args []byte) (*mysql.Result, uint16, error)
return nil, 0, err
}

result, _, warnings, err := stmt.conn.ReadQueryResult(true)
result, _, warnings, err := stmt.conn.ReadQueryResult(ctx, true)
return result, warnings, err
}
8 changes: 4 additions & 4 deletions pkg/dt/mysql_undo_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -236,20 +236,20 @@ func (executor MysqlUndoExecutor) queryCurrentRecords(tx proto.Tx) (*schema.Tabl

if executor.sqlUndoLog.IsBinary {
selectSql := executor.buildCurrentRecordsForPrepareSql(tableMeta, pkName, pkValues)
dataTable, _, err := tx.ExecuteSql(context.Background(), selectSql, pkValues...)
dataTable, _, err := tx.ExecuteSqlDirectly(selectSql, pkValues...)
if err != nil {
return nil, err
}
dt := dataTable.(*mysql.Result)
return schema.BuildBinaryRecords(tableMeta, dt), nil
return schema.BuildTableRecords(tableMeta, dt), nil
} else {
selectSql := executor.buildCurrentRecordsForQuerySql(tableMeta, pkName, pkValues)
dataTable, _, err := tx.Query(context.Background(), selectSql)
dataTable, _, err := tx.QueryDirectly(selectSql)
if err != nil {
return nil, err
}
dt := dataTable.(*mysql.Result)
return schema.BuildTextRecords(tableMeta, dt), nil
return schema.BuildTableRecords(tableMeta, dt), nil
}
}

Expand Down
24 changes: 10 additions & 14 deletions pkg/dt/mysql_undo_log_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"fmt"
"time"

"github.com/cectc/dbpack/pkg/constant"
"github.com/cectc/dbpack/pkg/driver"
"github.com/cectc/dbpack/pkg/dt/undolog"
"github.com/cectc/dbpack/pkg/log"
Expand Down Expand Up @@ -76,21 +77,15 @@ func (manager MysqlUndoLogManager) Undo(db proto.DB, xid string) ([]string, erro
return lockKeys, err
}

if result, _, err = tx.ExecuteSql(context.Background(), SelectUndoLogSql, xid); err != nil {
if result, _, err = tx.ExecuteSqlDirectly(SelectUndoLogSql, xid); err != nil {
return lockKeys, err
}

exists := false
undoLogs := make([]*undolog.SqlUndoLog, 0)
rlt := result.(*mysql.Result)
for {
row, err := rlt.Rows.Next()
if err != nil {
break
}

binaryRow := mysql.BinaryRow{Row: row}
values, err := binaryRow.Decode()
for _, row := range rlt.Rows {
values, err := row.Decode()
if err != nil {
break
}
Expand Down Expand Up @@ -133,7 +128,7 @@ func (manager MysqlUndoLogManager) Undo(db proto.DB, xid string) ([]string, erro
}

if exists {
_, _, err := tx.ExecuteSql(context.Background(), DeleteUndoLogByXIDSql, xid)
_, _, err := tx.ExecuteSqlDirectly(DeleteUndoLogByXIDSql, xid)
if err != nil {
if _, err := tx.Rollback(context.Background(), nil); err != nil {
return lockKeys, err
Expand All @@ -154,7 +149,7 @@ func (manager MysqlUndoLogManager) Undo(db proto.DB, xid string) ([]string, erro
}

func (manager MysqlUndoLogManager) DeleteUndoLogByID(db proto.DB, id int64) error {
result, _, err := db.ExecuteSql(context.Background(), DeleteUndoLogByIDSql, id)
result, _, err := db.ExecuteSqlDirectly(DeleteUndoLogByIDSql, id)
if err != nil {
return err
}
Expand All @@ -164,7 +159,7 @@ func (manager MysqlUndoLogManager) DeleteUndoLogByID(db proto.DB, id int64) erro
}

func (manager MysqlUndoLogManager) DeleteUndoLogByXID(db proto.DB, xid string) error {
result, _, err := db.ExecuteSql(context.Background(), DeleteUndoLogByXIDSql, xid)
result, _, err := db.ExecuteSqlDirectly(DeleteUndoLogByXIDSql, xid)
if err != nil {
return err
}
Expand All @@ -175,7 +170,7 @@ func (manager MysqlUndoLogManager) DeleteUndoLogByXID(db proto.DB, xid string) e

func (manager MysqlUndoLogManager) DeleteUndoLogByLogCreated(db proto.DB, logCreated time.Time, limitRows int) error {
// TODO pass ctx.
result, _, err := db.ExecuteSql(context.Background(), DeleteUndoLogByCreateSql, logCreated, limitRows)
result, _, err := db.ExecuteSqlDirectly(DeleteUndoLogByCreateSql, logCreated, limitRows)
if err != nil {
return err
}
Expand All @@ -200,7 +195,8 @@ func (manager MysqlUndoLogManager) insertUndoLog(conn proto.Connection, xid stri
undoLogContent []byte, state State) error {
bc := conn.(*driver.BackendConnection)
args := []interface{}{xid, branchID, rollbackCtx, undoLogContent, state}
_, _, err := bc.PrepareExecuteArgs(InsertUndoLogSql, args)
ctx := proto.WithCommandType(context.Background(), constant.ComQuery)
_, _, err := bc.PrepareExecuteArgs(ctx, InsertUndoLogSql, args)
return err
}

Expand Down
50 changes: 3 additions & 47 deletions pkg/dt/schema/table_records.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,18 +83,12 @@ func BuildLockKey(lockKeyRecords *TableRecords) string {
return sb.String()
}

func BuildBinaryRecords(meta TableMeta, result *mysql.Result) *TableRecords {
func BuildTableRecords(meta TableMeta, result *mysql.Result) *TableRecords {
records := NewTableRecords(meta)
rs := make([]*Row, 0)

for {
row, err := result.Rows.Next()
if err != nil {
break
}

binaryRow := mysql.BinaryRow{Row: row}
values, err := binaryRow.Decode()
for _, row := range result.Rows {
values, err := row.Decode()
if err != nil {
break
}
Expand All @@ -115,45 +109,7 @@ func BuildBinaryRecords(meta TableMeta, result *mysql.Result) *TableRecords {
r := &Row{Fields: fields}
rs = append(rs, r)
}
if len(rs) == 0 {
return nil
}
records.Rows = rs
return records
}

func BuildTextRecords(meta TableMeta, result *mysql.Result) *TableRecords {
records := NewTableRecords(meta)
rs := make([]*Row, 0)

for {
row, err := result.Rows.Next()
if err != nil {
break
}

textRow := mysql.TextRow{Row: row}
values, err := textRow.Decode()
if err != nil {
break
}
fields := make([]*Field, 0, len(result.Fields))
for i, col := range result.Fields {
field := &Field{
Name: col.FiledName(),
Type: meta.AllColumns[col.FiledName()].DataType,
}
if values[i] != nil {
field.Value = values[i].Val
}
if strings.EqualFold(col.FiledName(), meta.GetPKName()) {
field.KeyType = PrimaryKey
}
fields = append(fields, field)
}
r := &Row{Fields: fields}
rs = append(rs, r)
}
if len(rs) == 0 {
return nil
}
Expand Down
Loading

0 comments on commit b318070

Please sign in to comment.