Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for kill statement #13371

Merged
merged 10 commits into from
Jul 10, 2023
14 changes: 6 additions & 8 deletions go/cmd/vtgateclienttest/services/callerid.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,19 @@ limitations under the License.
package services

import (
"context"
"encoding/json"
"fmt"
"strings"

"context"

"google.golang.org/protobuf/proto"

"vitess.io/vitess/go/sqltypes"
"vitess.io/vitess/go/vt/callerid"
"vitess.io/vitess/go/vt/vtgate/vtgateservice"

querypb "vitess.io/vitess/go/vt/proto/query"
vtgatepb "vitess.io/vitess/go/vt/proto/vtgate"
vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc"
"vitess.io/vitess/go/vt/vtgate/vtgateservice"
)

// CallerIDPrefix is the prefix to send with queries so they go
Expand Down Expand Up @@ -77,11 +75,11 @@ func (c *callerIDClient) checkCallerID(ctx context.Context, received string) (bo
return true, fmt.Errorf("SUCCESS: callerid matches")
}

func (c *callerIDClient) Execute(ctx context.Context, session *vtgatepb.Session, sql string, bindVariables map[string]*querypb.BindVariable) (*vtgatepb.Session, *sqltypes.Result, error) {
func (c *callerIDClient) Execute(ctx context.Context, mysqlCtx vtgateservice.MySQLConnection, session *vtgatepb.Session, sql string, bindVariables map[string]*querypb.BindVariable) (*vtgatepb.Session, *sqltypes.Result, error) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm thinking we want to introduce a type to be sent in here, instead of growing the arguments across all interface implementations. Like we have the planning-context when planning and the vcursor when running. WDYT?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's keep that refactor when we go about changing this definition again.

if ok, err := c.checkCallerID(ctx, sql); ok {
return session, nil, err
}
return c.fallbackClient.Execute(ctx, session, sql, bindVariables)
return c.fallbackClient.Execute(ctx, mysqlCtx, session, sql, bindVariables)
}

func (c *callerIDClient) ExecuteBatch(ctx context.Context, session *vtgatepb.Session, sqlList []string, bindVariablesList []map[string]*querypb.BindVariable) (*vtgatepb.Session, []sqltypes.QueryResponse, error) {
Expand All @@ -93,9 +91,9 @@ func (c *callerIDClient) ExecuteBatch(ctx context.Context, session *vtgatepb.Ses
return c.fallbackClient.ExecuteBatch(ctx, session, sqlList, bindVariablesList)
}

func (c *callerIDClient) StreamExecute(ctx context.Context, session *vtgatepb.Session, sql string, bindVariables map[string]*querypb.BindVariable, callback func(*sqltypes.Result) error) (*vtgatepb.Session, error) {
func (c *callerIDClient) StreamExecute(ctx context.Context, mysqlCtx vtgateservice.MySQLConnection, session *vtgatepb.Session, sql string, bindVariables map[string]*querypb.BindVariable, callback func(*sqltypes.Result) error) (*vtgatepb.Session, error) {
if ok, err := c.checkCallerID(ctx, sql); ok {
return session, err
}
return c.fallbackClient.StreamExecute(ctx, session, sql, bindVariables, callback)
return c.fallbackClient.StreamExecute(ctx, mysqlCtx, session, sql, bindVariables, callback)
}
11 changes: 5 additions & 6 deletions go/cmd/vtgateclienttest/services/echo.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,12 @@ package services

import (
"bytes"
"context"
"fmt"
"reflect"
"sort"
"strings"

"context"

"vitess.io/vitess/go/sqltypes"
"vitess.io/vitess/go/vt/callerid"
"vitess.io/vitess/go/vt/vtgate/vtgateservice"
Expand Down Expand Up @@ -98,7 +97,7 @@ func echoQueryResult(vals map[string]any) *sqltypes.Result {
return qr
}

func (c *echoClient) Execute(ctx context.Context, session *vtgatepb.Session, sql string, bindVariables map[string]*querypb.BindVariable) (*vtgatepb.Session, *sqltypes.Result, error) {
func (c *echoClient) Execute(ctx context.Context, mysqlCtx vtgateservice.MySQLConnection, session *vtgatepb.Session, sql string, bindVariables map[string]*querypb.BindVariable) (*vtgatepb.Session, *sqltypes.Result, error) {
if strings.HasPrefix(sql, EchoPrefix) {
return session, echoQueryResult(map[string]any{
"callerId": callerid.EffectiveCallerIDFromContext(ctx),
Expand All @@ -107,10 +106,10 @@ func (c *echoClient) Execute(ctx context.Context, session *vtgatepb.Session, sql
"session": session,
}), nil
}
return c.fallbackClient.Execute(ctx, session, sql, bindVariables)
return c.fallbackClient.Execute(ctx, mysqlCtx, session, sql, bindVariables)
}

func (c *echoClient) StreamExecute(ctx context.Context, session *vtgatepb.Session, sql string, bindVariables map[string]*querypb.BindVariable, callback func(*sqltypes.Result) error) (*vtgatepb.Session, error) {
func (c *echoClient) StreamExecute(ctx context.Context, mysqlCtx vtgateservice.MySQLConnection, session *vtgatepb.Session, sql string, bindVariables map[string]*querypb.BindVariable, callback func(*sqltypes.Result) error) (*vtgatepb.Session, error) {
if strings.HasPrefix(sql, EchoPrefix) {
callback(echoQueryResult(map[string]any{
"callerId": callerid.EffectiveCallerIDFromContext(ctx),
Expand All @@ -120,7 +119,7 @@ func (c *echoClient) StreamExecute(ctx context.Context, session *vtgatepb.Sessio
}))
return session, nil
}
return c.fallbackClient.StreamExecute(ctx, session, sql, bindVariables, callback)
return c.fallbackClient.StreamExecute(ctx, mysqlCtx, session, sql, bindVariables, callback)
}

func (c *echoClient) ExecuteBatch(ctx context.Context, session *vtgatepb.Session, sqlList []string, bindVariablesList []map[string]*querypb.BindVariable) (*vtgatepb.Session, []sqltypes.QueryResponse, error) {
Expand Down
11 changes: 5 additions & 6 deletions go/cmd/vtgateclienttest/services/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,8 @@ limitations under the License.
package services

import (
"strings"

"context"
"strings"

"vitess.io/vitess/go/sqltypes"
"vitess.io/vitess/go/vt/vterrors"
Expand Down Expand Up @@ -111,14 +110,14 @@ func trimmedRequestToError(received string) error {
}
}

func (c *errorClient) Execute(ctx context.Context, session *vtgatepb.Session, sql string, bindVariables map[string]*querypb.BindVariable) (*vtgatepb.Session, *sqltypes.Result, error) {
func (c *errorClient) Execute(ctx context.Context, mysqlCtx vtgateservice.MySQLConnection, session *vtgatepb.Session, sql string, bindVariables map[string]*querypb.BindVariable) (*vtgatepb.Session, *sqltypes.Result, error) {
if err := requestToPartialError(sql, session); err != nil {
return session, nil, err
}
if err := requestToError(sql); err != nil {
return session, nil, err
}
return c.fallbackClient.Execute(ctx, session, sql, bindVariables)
return c.fallbackClient.Execute(ctx, mysqlCtx, session, sql, bindVariables)
}

func (c *errorClient) ExecuteBatch(ctx context.Context, session *vtgatepb.Session, sqlList []string, bindVariablesList []map[string]*querypb.BindVariable) (*vtgatepb.Session, []sqltypes.QueryResponse, error) {
Expand All @@ -133,11 +132,11 @@ func (c *errorClient) ExecuteBatch(ctx context.Context, session *vtgatepb.Sessio
return c.fallbackClient.ExecuteBatch(ctx, session, sqlList, bindVariablesList)
}

func (c *errorClient) StreamExecute(ctx context.Context, session *vtgatepb.Session, sql string, bindVariables map[string]*querypb.BindVariable, callback func(*sqltypes.Result) error) (*vtgatepb.Session, error) {
func (c *errorClient) StreamExecute(ctx context.Context, mysqlCtx vtgateservice.MySQLConnection, session *vtgatepb.Session, sql string, bindVariables map[string]*querypb.BindVariable, callback func(*sqltypes.Result) error) (*vtgatepb.Session, error) {
if err := requestToError(sql); err != nil {
return session, err
}
return c.fallbackClient.StreamExecute(ctx, session, sql, bindVariables, callback)
return c.fallbackClient.StreamExecute(ctx, mysqlCtx, session, sql, bindVariables, callback)
}

func (c *errorClient) Prepare(ctx context.Context, session *vtgatepb.Session, sql string, bindVariables map[string]*querypb.BindVariable) (*vtgatepb.Session, []*querypb.Field, error) {
Expand Down
8 changes: 4 additions & 4 deletions go/cmd/vtgateclienttest/services/fallback.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,16 @@ func newFallbackClient(fallback vtgateservice.VTGateService) fallbackClient {
return fallbackClient{fallback: fallback}
}

func (c fallbackClient) Execute(ctx context.Context, session *vtgatepb.Session, sql string, bindVariables map[string]*querypb.BindVariable) (*vtgatepb.Session, *sqltypes.Result, error) {
return c.fallback.Execute(ctx, session, sql, bindVariables)
func (c fallbackClient) Execute(ctx context.Context, mysqlCtx vtgateservice.MySQLConnection, session *vtgatepb.Session, sql string, bindVariables map[string]*querypb.BindVariable) (*vtgatepb.Session, *sqltypes.Result, error) {
return c.fallback.Execute(ctx, mysqlCtx, session, sql, bindVariables)
}

func (c fallbackClient) ExecuteBatch(ctx context.Context, session *vtgatepb.Session, sqlList []string, bindVariablesList []map[string]*querypb.BindVariable) (*vtgatepb.Session, []sqltypes.QueryResponse, error) {
return c.fallback.ExecuteBatch(ctx, session, sqlList, bindVariablesList)
}

func (c fallbackClient) StreamExecute(ctx context.Context, session *vtgatepb.Session, sql string, bindVariables map[string]*querypb.BindVariable, callback func(*sqltypes.Result) error) (*vtgatepb.Session, error) {
return c.fallback.StreamExecute(ctx, session, sql, bindVariables, callback)
func (c fallbackClient) StreamExecute(ctx context.Context, mysqlCtx vtgateservice.MySQLConnection, session *vtgatepb.Session, sql string, bindVariables map[string]*querypb.BindVariable, callback func(*sqltypes.Result) error) (*vtgatepb.Session, error) {
return c.fallback.StreamExecute(ctx, mysqlCtx, session, sql, bindVariables, callback)
}

func (c fallbackClient) Prepare(ctx context.Context, session *vtgatepb.Session, sql string, bindVariables map[string]*querypb.BindVariable) (*vtgatepb.Session, []*querypb.Field, error) {
Expand Down
7 changes: 4 additions & 3 deletions go/cmd/vtgateclienttest/services/terminal.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@ limitations under the License.
package services

import (
"context"
"errors"
"fmt"

"context"
"vitess.io/vitess/go/vt/vtgate/vtgateservice"

"vitess.io/vitess/go/sqltypes"
"vitess.io/vitess/go/tb"
Expand All @@ -42,7 +43,7 @@ func newTerminalClient() *terminalClient {
return &terminalClient{}
}

func (c *terminalClient) Execute(ctx context.Context, session *vtgatepb.Session, sql string, bindVariables map[string]*querypb.BindVariable) (*vtgatepb.Session, *sqltypes.Result, error) {
func (c *terminalClient) Execute(ctx context.Context, mysqlCtx vtgateservice.MySQLConnection, session *vtgatepb.Session, sql string, bindVariables map[string]*querypb.BindVariable) (*vtgatepb.Session, *sqltypes.Result, error) {
if sql == "quit://" {
log.Fatal("Received quit:// query. Going down.")
}
Expand All @@ -58,7 +59,7 @@ func (c *terminalClient) ExecuteBatch(ctx context.Context, session *vtgatepb.Ses
return session, nil, errTerminal
}

func (c *terminalClient) StreamExecute(ctx context.Context, session *vtgatepb.Session, sql string, bindVariables map[string]*querypb.BindVariable, callback func(*sqltypes.Result) error) (*vtgatepb.Session, error) {
func (c *terminalClient) StreamExecute(ctx context.Context, mysqlCtx vtgateservice.MySQLConnection, session *vtgatepb.Session, sql string, bindVariables map[string]*querypb.BindVariable, callback func(*sqltypes.Result) error) (*vtgatepb.Session, error) {
return session, errTerminal
}

Expand Down
55 changes: 51 additions & 4 deletions go/mysql/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package mysql

import (
"bufio"
"context"
"crypto/tls"
"crypto/x509"
"errors"
Expand All @@ -29,11 +30,9 @@ import (
"sync/atomic"
"time"

"vitess.io/vitess/go/bucketpool"
"vitess.io/vitess/go/mysql/collations"

"vitess.io/vitess/go/sqlescape"

"vitess.io/vitess/go/bucketpool"
"vitess.io/vitess/go/sqltypes"
"vitess.io/vitess/go/vt/log"
querypb "vitess.io/vitess/go/vt/proto/query"
Expand Down Expand Up @@ -199,6 +198,15 @@ type Conn struct {
// enableQueryInfo controls whether we parse the INFO field in QUERY_OK packets
// See: ConnParams.EnableQueryInfo
enableQueryInfo bool

// mu protects the fields below
mu sync.Mutex
// cancel keep the cancel function for the current executing query.
// this is used by `kill [query|connection] ID` command from other connection.
cancel context.CancelFunc
// this is used to mark the connection to be closed so that the command phase for the connection can be stopped and
// the connection gets closed.
closing bool
}

// splitStatementFunciton is the function that is used to split the statement in case of a multi-statement query.
Expand Down Expand Up @@ -767,7 +775,7 @@ func (c *Conn) writeOKPacketWithHeader(packetOk *PacketOK, headerType byte) erro

bytes, pos := c.startEphemeralPacketWithHeader(length)
data := &coder{data: bytes, pos: pos}
data.writeByte(headerType) //header - OK or EOF
data.writeByte(headerType) // header - OK or EOF
data.writeLenEncInt(packetOk.affectedRows)
data.writeLenEncInt(packetOk.lastInsertID)
data.writeUint16(packetOk.statusFlags)
Expand Down Expand Up @@ -896,6 +904,10 @@ func (c *Conn) handleNextCommand(handler Handler) bool {
if len(data) == 0 {
return false
}
// before continue to process the packet, check if the connection should be closed or not.
if c.IsMarkedForClose() {
return false
}

switch data[0] {
case ComQuit:
Expand Down Expand Up @@ -1632,3 +1644,38 @@ func (c *Conn) IsUnixSocket() bool {
func (c *Conn) GetRawConn() net.Conn {
return c.conn
}

// CancelCtx aborts an existing running query
func (c *Conn) CancelCtx() {
c.mu.Lock()
defer c.mu.Unlock()
if c.cancel != nil {
c.cancel()
}
}

// UpdateCancelCtx updates the cancel function on the connection.
func (c *Conn) UpdateCancelCtx(cancel context.CancelFunc) {
c.mu.Lock()
defer c.mu.Unlock()
c.cancel = cancel
}

// MarkForClose marks the connection for close.
func (c *Conn) MarkForClose() {
c.mu.Lock()
defer c.mu.Unlock()
c.closing = true
}

// IsMarkedForClose return true if the connection should be closed.
func (c *Conn) IsMarkedForClose() bool {
c.mu.Lock()
defer c.mu.Unlock()
return c.closing
}

// GetTestConn returns a conn for testing purpose only.
func GetTestConn() *Conn {
return newConn(testConn{})
}
83 changes: 83 additions & 0 deletions go/mysql/conn_fake.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/*
Copyright 2023 The Vitess Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package mysql

import (
"fmt"
"net"
"time"
)

// testConn to be used for testing only as net.Conn interface implementation.
type testConn struct {
writeToPass []bool
pos int
queryPacket []byte
}

func (t testConn) Read(b []byte) (n int, err error) {
copy(b, t.queryPacket)
return len(b), nil
}

func (t testConn) Write(b []byte) (n int, err error) {
t.pos = t.pos + 1
if t.writeToPass[t.pos] {
return 0, nil
}
return 0, fmt.Errorf("error in writing to connection")
}

func (t testConn) Close() error {
return nil
}

func (t testConn) LocalAddr() net.Addr {
panic("implement me")
}

func (t testConn) RemoteAddr() net.Addr {
return mockAddress{s: "a"}
}

func (t testConn) SetDeadline(t1 time.Time) error {
panic("implement me")
}

func (t testConn) SetReadDeadline(t1 time.Time) error {
panic("implement me")
}

func (t testConn) SetWriteDeadline(t1 time.Time) error {
panic("implement me")
}

var _ net.Conn = (*testConn)(nil)

type mockAddress struct {
s string
}

func (m mockAddress) Network() string {
return m.s
}

func (m mockAddress) String() string {
return m.s
}

var _ net.Addr = (*mockAddress)(nil)
Loading