From f1495bc206f2c7025a235ba50c6041eb2cb62c96 Mon Sep 17 00:00:00 2001 From: Jason Fulghum Date: Mon, 31 Oct 2022 12:39:06 -0700 Subject: [PATCH 1/3] Casting errors during ComPrepare as SQLErrors so that error codes show up for clients. Added unit test and simplified interface for CastSQLError. --- enginetest/evaluation.go | 3 +-- server/handler.go | 12 ++++-------- server/handler_test.go | 25 ++++++++++++++++++++----- sql/errors.go | 12 ++++++++---- sql/errors_test.go | 4 ++-- sql/plan/insert.go | 4 ++-- 6 files changed, 37 insertions(+), 23 deletions(-) diff --git a/enginetest/evaluation.go b/enginetest/evaluation.go index c1c88bd572..d577853383 100644 --- a/enginetest/evaluation.go +++ b/enginetest/evaluation.go @@ -623,8 +623,7 @@ func AssertErrWithCtx(t *testing.T, e *sqle.Engine, harness Harness, ctx *sql.Co } require.Error(t, err) if expectedErrKind != nil { - _, orig, _ := sql.CastSQLError(err) - require.True(t, expectedErrKind.Is(orig), "Expected error of type %s but got %s", expectedErrKind, err) + require.True(t, expectedErrKind.Is(err), "Expected error of type %s but got %s", expectedErrKind, err) } // If there are multiple error strings then we only match against the first if len(errStrs) >= 1 { diff --git a/server/handler.go b/server/handler.go index af7c039b91..daa592c1a4 100644 --- a/server/handler.go +++ b/server/handler.go @@ -112,6 +112,7 @@ func (h *Handler) ComPrepare(c *mysql.Conn, query string) ([]*query.Field, error analyzed, err = h.e.PrepareQuery(ctx, query) } if err != nil { + err := sql.CastSQLError(err) return nil, err } @@ -613,18 +614,13 @@ func (h *Handler) errorWrappedDoQuery( } remainder, err := h.doQuery(c, query, mode, bindings, callback) - err, _, ok := sql.CastSQLError(err) - - var retErr error - if !ok { - retErr = err - } + err = sql.CastSQLError(err) if h.sel != nil { - h.sel.QueryCompleted(retErr == nil, time.Since(start)) + h.sel.QueryCompleted(err == nil, time.Since(start)) } - return remainder, retErr + return remainder, err } // Periodically polls the connection socket to determine if it is has been closed by the client, returning an error diff --git a/server/handler_test.go b/server/handler_test.go index 1ecdf3de92..79cf884914 100644 --- a/server/handler_test.go +++ b/server/handler_test.go @@ -177,9 +177,10 @@ func TestHandlerComPrepare(t *testing.T) { handler.NewConnection(dummyConn) type testcase struct { - name string - statement string - expected []*query.Field + name string + statement string + expected []*query.Field + expectedErr *mysql.SQLError } for _, test := range []testcase{ @@ -205,12 +206,26 @@ func TestHandlerComPrepare(t *testing.T) { {Name: "c1", Type: query.Type_INT32, Charset: mysql.CharacterSetUtf8, ColumnLength: 11}, }, }, + { + name: "errors are cast to SQLError", + statement: "SELECT * from doesnotexist LIMIT ?", + expectedErr: mysql.NewSQLError(mysql.ERNoSuchTable, "", "table not found: %s", "doesnotexist"), + }, } { t.Run(test.name, func(t *testing.T) { handler.ComInitDB(dummyConn, "test") schema, err := handler.ComPrepare(dummyConn, test.statement) - require.NoError(t, err) - require.Equal(t, test.expected, schema) + if test.expectedErr == nil { + require.NoError(t, err) + require.Equal(t, test.expected, schema) + } else { + require.NotNil(t, err) + sqlErr, isSqlError := err.(*mysql.SQLError) + require.True(t, isSqlError) + require.Equal(t, test.expectedErr.Number(), sqlErr.Number()) + require.Equal(t, test.expectedErr.SQLState(), sqlErr.SQLState()) + require.Equal(t, test.expectedErr.Error(), sqlErr.Error()) + } }) } } diff --git a/sql/errors.go b/sql/errors.go index e1e61a3035..d51a48171a 100644 --- a/sql/errors.go +++ b/sql/errors.go @@ -650,12 +650,16 @@ var ( ErrNoTablesUsed = errors.NewKind("No tables used") ) -func CastSQLError(err error) (*mysql.SQLError, error, bool) { +// CastSQLError returns a *mysql.SQLError with the error code and in some cases, also a SQL state, populated for the +// specified error object. Using this method enables Vitess to return an error code, instead of just "unknown error". +// Many tools (e.g. ORMs, SQL workbenches) rely on this error metadata to work correctly. If the specified error is nil, +// nil will be returned. If the error is already of type *mysql.SQLError, the error will be returend as is. +func CastSQLError(err error) *mysql.SQLError { if err == nil { - return nil, nil, true + return nil } if mysqlErr, ok := err.(*mysql.SQLError); ok { - return mysqlErr, nil, false + return mysqlErr } var code int @@ -722,7 +726,7 @@ func CastSQLError(err error) (*mysql.SQLError, error, bool) { } // This uses the given error as a format string, so we have to escape any percentage signs else they'll show up as "%!(MISSING)" - return mysql.NewSQLError(code, sqlState, strings.Replace(err.Error(), `%`, `%%`, -1)), err, false // return the original error as well + return mysql.NewSQLError(code, sqlState, strings.Replace(err.Error(), `%`, `%%`, -1)) } type UniqueKeyError struct { diff --git a/sql/errors_test.go b/sql/errors_test.go index 5c156735a2..6cded7c684 100644 --- a/sql/errors_test.go +++ b/sql/errors_test.go @@ -25,8 +25,8 @@ func TestSQLErrorCast(t *testing.T) { for _, test := range tests { var nilErr *mysql.SQLError = nil t.Run(fmt.Sprintf("%v %v", test.err, test.code), func(t *testing.T) { - err, _, ok := CastSQLError(test.err) - if !ok { + err := CastSQLError(test.err) + if err != nil { require.Error(t, err) assert.Equal(t, err.Number(), test.code) } else { diff --git a/sql/plan/insert.go b/sql/plan/insert.go index 379ad5c63c..9eb5d2a1d9 100644 --- a/sql/plan/insert.go +++ b/sql/plan/insert.go @@ -569,7 +569,7 @@ func convertDataAndWarn(ctx *sql.Context, tableSchema sql.Schema, row sql.Row, c row[columnIdx] = tableSchema[columnIdx].Type.Zero() } - sqlerr, _, _ := sql.CastSQLError(err) + sqlerr := sql.CastSQLError(err) // Add a warning instead ctx.Session.Warn(&sql.Warning{ @@ -585,7 +585,7 @@ func warnOnIgnorableError(ctx *sql.Context, row sql.Row, err error) error { // Check that this error is a part of the list of Ignorable Errors and create the relevant warning for _, ie := range IgnorableErrors { if ie.Is(err) { - sqlerr, _, _ := sql.CastSQLError(err) + sqlerr := sql.CastSQLError(err) // Add a warning instead ctx.Session.Warn(&sql.Warning{ From 7e25b65fc3eb5b5266f8a867862210a0a79bfdc4 Mon Sep 17 00:00:00 2001 From: Jason Fulghum Date: Mon, 31 Oct 2022 13:26:46 -0700 Subject: [PATCH 2/3] Adding a new UnwrapError method for tests to use. --- enginetest/evaluation.go | 1 + sql/errors.go | 15 ++++++++++++++- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/enginetest/evaluation.go b/enginetest/evaluation.go index d577853383..370d03e708 100644 --- a/enginetest/evaluation.go +++ b/enginetest/evaluation.go @@ -623,6 +623,7 @@ func AssertErrWithCtx(t *testing.T, e *sqle.Engine, harness Harness, ctx *sql.Co } require.Error(t, err) if expectedErrKind != nil { + err = sql.UnwrapError(err) require.True(t, expectedErrKind.Is(err), "Expected error of type %s but got %s", expectedErrKind, err) } // If there are multiple error strings then we only match against the first diff --git a/sql/errors.go b/sql/errors.go index d51a48171a..e1062a3786 100644 --- a/sql/errors.go +++ b/sql/errors.go @@ -653,7 +653,7 @@ var ( // CastSQLError returns a *mysql.SQLError with the error code and in some cases, also a SQL state, populated for the // specified error object. Using this method enables Vitess to return an error code, instead of just "unknown error". // Many tools (e.g. ORMs, SQL workbenches) rely on this error metadata to work correctly. If the specified error is nil, -// nil will be returned. If the error is already of type *mysql.SQLError, the error will be returend as is. +// nil will be returned. If the error is already of type *mysql.SQLError, the error will be returned as is. func CastSQLError(err error) *mysql.SQLError { if err == nil { return nil @@ -729,6 +729,19 @@ func CastSQLError(err error) *mysql.SQLError { return mysql.NewSQLError(code, sqlState, strings.Replace(err.Error(), `%`, `%%`, -1)) } +// UnwrapError removes any wrapping errors (e.g. WrappedInsertError) around the specified error and +// returns the first non-wrapped error type. +func UnwrapError(err error) error { + switch wrappedError := err.(type) { + case WrappedInsertError: + return UnwrapError(wrappedError.Cause) + case WrappedTypeConversionError: + return UnwrapError(wrappedError.Err) + default: + return err + } +} + type UniqueKeyError struct { keyStr string IsPK bool From 1abc1c633ad245db58c6a9d61df691de2104ac42 Mon Sep 17 00:00:00 2001 From: Jason Fulghum Date: Mon, 31 Oct 2022 14:12:38 -0700 Subject: [PATCH 3/3] Only call CastSQLError if the error is not nil, to avoid Go's non-nil nil interface values. --- server/handler.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/server/handler.go b/server/handler.go index daa592c1a4..2ed2d7e863 100644 --- a/server/handler.go +++ b/server/handler.go @@ -614,7 +614,9 @@ func (h *Handler) errorWrappedDoQuery( } remainder, err := h.doQuery(c, query, mode, bindings, callback) - err = sql.CastSQLError(err) + if err != nil { + err = sql.CastSQLError(err) + } if h.sel != nil { h.sel.QueryCompleted(err == nil, time.Since(start))