Skip to content

Commit

Permalink
server event listener (#685)
Browse files Browse the repository at this point in the history
  • Loading branch information
bheni authored Dec 16, 2021
1 parent f81f87e commit 71f612c
Show file tree
Hide file tree
Showing 3 changed files with 150 additions and 8 deletions.
34 changes: 29 additions & 5 deletions server/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,20 +69,26 @@ type Handler struct {
sm *SessionManager
readTimeout time.Duration
disableMultiStmts bool
sel ServerEventListener
}

// NewHandler creates a new Handler given a SQLe engine.
func NewHandler(e *sqle.Engine, sm *SessionManager, rt time.Duration, disableMultiStmts bool) *Handler {
func NewHandler(e *sqle.Engine, sm *SessionManager, rt time.Duration, disableMultiStmts bool, listener ServerEventListener) *Handler {
return &Handler{
e: e,
sm: sm,
readTimeout: rt,
disableMultiStmts: disableMultiStmts,
sel: listener,
}
}

// NewConnection reports that a new connection has been established.
func (h *Handler) NewConnection(c *mysql.Conn) {
if h.sel != nil {
h.sel.ClientConnected()
}

c.DisableClientMultiStatements = h.disableMultiStmts
logrus.WithField(sqle.ConnectionIdLogField, c.ConnectionID).WithField("DisableClientMultiStatements", c.DisableClientMultiStatements).Infof("NewConnection")
}
Expand All @@ -107,6 +113,7 @@ func (h *Handler) ComPrepare(c *mysql.Conn, query string) ([]*query.Field, error
}

func (h *Handler) ComStmtExecute(c *mysql.Conn, prepare *mysql.PrepareData, callback func(*sqltypes.Result) error) error {

_, err := h.errorWrappedDoQuery(c, prepare.PrepareStmt, MultiStmtModeOff, prepare.BindVars, func(res *sqltypes.Result, more bool) error {
return callback(res)
})
Expand All @@ -119,6 +126,12 @@ func (h *Handler) ComResetConnection(c *mysql.Conn) {

// ConnectionClosed reports that a connection has been closed.
func (h *Handler) ConnectionClosed(c *mysql.Conn) {
defer func() {
if h.sel != nil {
h.sel.ClientDisconnected()
}
}()

ctx, _ := h.sm.NewContextWithQuery(c, "")
h.sm.CloseConn(c)

Expand Down Expand Up @@ -541,13 +554,24 @@ func (h *Handler) errorWrappedDoQuery(
bindings map[string]*query.BindVariable,
callback func(*sqltypes.Result, bool) error,
) (string, error) {
start := time.Now()
if h.sel != nil {
h.sel.QueryStarted()
}

remainder, err := h.doQuery(c, query, mode, bindings, callback)
err, _, ok := sql.CastSQLError(err)
if ok {
return remainder, nil
} else {
return remainder, err

var retErr error
if !ok {
retErr = err
}

if h.sel != nil {
h.sel.QueryCompleted(retErr == nil, time.Since(start))
}

return remainder, retErr
}

// Periodically polls the connection socket to determine if it is has been closed by the client, returning an error
Expand Down
107 changes: 107 additions & 0 deletions server/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ func TestHandlerOutput(t *testing.T) {
),
0,
false,
nil,
)
handler.NewConnection(dummyConn)

Expand Down Expand Up @@ -158,6 +159,7 @@ func TestHandlerComPrepare(t *testing.T) {
),
0,
false,
nil,
)
handler.NewConnection(dummyConn)

Expand Down Expand Up @@ -200,6 +202,106 @@ func TestHandlerComPrepare(t *testing.T) {
}
}

type TestListener struct {
Connections int
Queries int
Disconnects int
Successes int
Failures int
}

func (tl *TestListener) ClientConnected() {
tl.Connections++
}

func (tl *TestListener) ClientDisconnected() {
tl.Disconnects++
}

func (tl *TestListener) QueryStarted() {
tl.Queries++
}

func (tl *TestListener) QueryCompleted(success bool, duration time.Duration) {
if success {
tl.Successes++
} else {
tl.Failures++
}
}

func TestServerEventListener(t *testing.T) {
require := require.New(t)
e := setupMemDB(require)
listener := &TestListener{}
handler := NewHandler(
e,
NewSessionManager(
func(ctx context.Context, conn *mysql.Conn, addr string) (sql.Session, error) {
return sql.NewBaseSessionWithClientServer(addr, sql.Client{Capabilities: conn.Capabilities}, conn.ConnectionID), nil
},
opentracing.NoopTracer{},
func(db string) bool { return db == "test" },
e.MemoryManager,
e.ProcessList,
"foo",
),
0,
false,
listener,
)

cb := func(res *sqltypes.Result, more bool) error {
return nil
}

require.Equal(listener.Connections, 0)
require.Equal(listener.Disconnects, 0)
require.Equal(listener.Queries, 0)
require.Equal(listener.Successes, 0)
require.Equal(listener.Failures, 0)

conn1 := newConn(1)
handler.NewConnection(conn1)
require.Equal(listener.Connections, 1)
require.Equal(listener.Disconnects, 0)

err := handler.sm.SetDB(conn1, "test")
require.NoError(err)

err = handler.ComQuery(conn1, "SELECT 1", cb)
require.NoError(err)
require.Equal(listener.Queries, 1)
require.Equal(listener.Successes, 1)
require.Equal(listener.Failures, 0)

conn2 := newConn(2)
handler.NewConnection(conn2)
require.Equal(listener.Connections, 2)
require.Equal(listener.Disconnects, 0)

handler.ComInitDB(conn2, "test")
err = handler.ComQuery(conn2, "select 1", cb)
require.NoError(err)
require.Equal(listener.Queries, 2)
require.Equal(listener.Successes, 2)
require.Equal(listener.Failures, 0)

err = handler.ComQuery(conn1, "select bad_col from bad_table with illegal syntax", cb)
require.Error(err)
require.Equal(listener.Queries, 3)
require.Equal(listener.Successes, 2)
require.Equal(listener.Failures, 1)

handler.ConnectionClosed(conn1)
require.Equal(listener.Connections, 2)
require.Equal(listener.Disconnects, 1)

handler.ConnectionClosed(conn2)
require.Equal(listener.Connections, 2)
require.Equal(listener.Disconnects, 2)
}

func TestHandlerKill(t *testing.T) {
require := require.New(t)
e := setupMemDB(require)
Expand All @@ -218,6 +320,7 @@ func TestHandlerKill(t *testing.T) {
),
0,
false,
nil,
)

conn1 := newConn(1)
Expand Down Expand Up @@ -298,6 +401,7 @@ func TestHandlerTimeout(t *testing.T) {
"foo"),
1*time.Second,
false,
nil,
)

noTimeOutHandler := NewHandler(
Expand All @@ -309,6 +413,7 @@ func TestHandlerTimeout(t *testing.T) {
"foo"),
0,
false,
nil,
)
require.Equal(1*time.Second, timeOutHandler.readTimeout)
require.Equal(0*time.Second, noTimeOutHandler.readTimeout)
Expand Down Expand Up @@ -364,6 +469,7 @@ func TestOkClosedConnection(t *testing.T) {
),
0,
false,
nil,
)
c := newConn(1)
h.NewConnection(c)
Expand Down Expand Up @@ -517,6 +623,7 @@ func TestHandlerFoundRowsCapabilities(t *testing.T) {
),
0,
false,
nil,
)

tests := []struct {
Expand Down
17 changes: 14 additions & 3 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,29 @@
package server

import (
"time"

"github.com/dolthub/vitess/go/mysql"
"github.com/opentracing/opentracing-go"

sqle "github.com/dolthub/go-mysql-server"
)

type ServerEventListener interface {
ClientConnected()
ClientDisconnected()
QueryStarted()
QueryCompleted(success bool, duration time.Duration)
}

// NewDefaultServer creates a Server with the default session builder.
func NewDefaultServer(cfg Config, e *sqle.Engine) (*Server, error) {
return NewServer(cfg, e, DefaultSessionBuilder)
return NewServer(cfg, e, DefaultSessionBuilder, nil)
}

// NewServer creates a server with the given protocol, address, authentication
// details given a SQLe engine and a session builder.
func NewServer(cfg Config, e *sqle.Engine, sb SessionBuilder) (*Server, error) {
func NewServer(cfg Config, e *sqle.Engine, sb SessionBuilder, listener ServerEventListener) (*Server, error) {
var tracer opentracing.Tracer
if cfg.Tracer != nil {
tracer = cfg.Tracer
Expand Down Expand Up @@ -57,7 +66,9 @@ func NewServer(cfg Config, e *sqle.Engine, sb SessionBuilder) (*Server, error) {
e.ProcessList,
cfg.Address),
cfg.ConnReadTimeout,
cfg.DisableClientMultiStatements)
cfg.DisableClientMultiStatements,
listener,
)
a := cfg.Auth.Mysql()
l, err := NewListener(cfg.Protocol, cfg.Address, handler)
if err != nil {
Expand Down

0 comments on commit 71f612c

Please sign in to comment.