diff --git a/enginetest/enginetests.go b/enginetest/enginetests.go index 65de29756a..4096bf3a89 100644 --- a/enginetest/enginetests.go +++ b/enginetest/enginetests.go @@ -508,9 +508,9 @@ func TestInsertIgnoreInto(t *testing.T, harness Harness) { } // todo: merge this into the above test when https://github.com/dolthub/dolt/issues/3836 is fixed -func TestInsertIgnoreIntoWithDuplicateUniqueKeyKeyless(t *testing.T, harness Harness) { +func TestIgnoreIntoWithDuplicateUniqueKeyKeyless(t *testing.T, harness Harness) { harness.Setup(setup.MydbData) - for _, script := range queries.InsertIgnoreIntoWithDuplicateUniqueKeyKeylessScripts { + for _, script := range queries.IgnoreWithDuplicateUniqueKeyKeylessScripts { TestScript(t, harness, script) } @@ -583,6 +583,17 @@ func TestUpdate(t *testing.T, harness Harness) { } } +func TestUpdateIgnore(t *testing.T, harness Harness) { + harness.Setup(setup.MydbData, setup.MytableData, setup.Mytable_del_idxData, setup.FloattableData, setup.NiltableData, setup.TypestableData, setup.Pk_tablesData, setup.OthertableData, setup.TabletestData) + for _, tt := range queries.UpdateIgnoreTests { + RunWriteQueryTest(t, harness, tt) + } + + for _, script := range queries.UpdateIgnoreScripts { + TestScript(t, harness, script) + } +} + func TestUpdateErrors(t *testing.T, harness Harness) { harness.Setup(setup.MydbData, setup.MytableData, setup.FloattableData, setup.TypestableData) for _, expectedFailure := range queries.GenericUpdateErrorTests { diff --git a/enginetest/memory_engine_test.go b/enginetest/memory_engine_test.go index c59888f714..4945fdac4b 100644 --- a/enginetest/memory_engine_test.go +++ b/enginetest/memory_engine_test.go @@ -387,8 +387,8 @@ func TestInsertIgnoreInto(t *testing.T) { enginetest.TestInsertIgnoreInto(t, enginetest.NewDefaultMemoryHarness()) } -func TestInsertIgnoreIntoWithDuplicateUniqueKeyKeyless(t *testing.T) { - enginetest.TestInsertIgnoreIntoWithDuplicateUniqueKeyKeyless(t, enginetest.NewDefaultMemoryHarness()) +func TestIgnoreIntoWithDuplicateUniqueKeyKeyless(t *testing.T) { + enginetest.TestIgnoreIntoWithDuplicateUniqueKeyKeyless(t, enginetest.NewDefaultMemoryHarness()) } func TestInsertIntoErrors(t *testing.T) { @@ -435,6 +435,10 @@ func TestUpdate(t *testing.T) { enginetest.TestUpdate(t, enginetest.NewMemoryHarness("default", 1, testNumPartitions, true, mergableIndexDriver)) } +func TestUpdateIgnore(t *testing.T) { + enginetest.TestUpdateIgnore(t, enginetest.NewMemoryHarness("default", 1, testNumPartitions, true, mergableIndexDriver)) +} + func TestUpdateErrors(t *testing.T) { enginetest.TestUpdateErrors(t, enginetest.NewMemoryHarness("default", 1, testNumPartitions, true, mergableIndexDriver)) } diff --git a/enginetest/queries/insert_queries.go b/enginetest/queries/insert_queries.go index 3bd658433d..9931779761 100644 --- a/enginetest/queries/insert_queries.go +++ b/enginetest/queries/insert_queries.go @@ -1739,7 +1739,7 @@ var InsertIgnoreScripts = []ScriptTest{ }, } -var InsertIgnoreIntoWithDuplicateUniqueKeyKeylessScripts = []ScriptTest{ +var IgnoreWithDuplicateUniqueKeyKeylessScripts = []ScriptTest{ { Name: "Test that INSERT IGNORE INTO works with unique keys on a keyless table", SetUpScript: []string{ @@ -1777,6 +1777,84 @@ var InsertIgnoreIntoWithDuplicateUniqueKeyKeylessScripts = []ScriptTest{ }, }, }, + { + Name: "INSERT IGNORE INTO multiple violations of a unique secondary index", + SetUpScript: []string{ + "CREATE TABLE keyless(pk int, val int)", + "INSERT INTO keyless values (1, 1), (2, 2), (3, 3)", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "INSERT IGNORE INTO keyless VALUES (1, 2);", + Expected: []sql.Row{{sql.NewOkResult(1)}}, + }, + { + Query: "ALTER TABLE keyless ADD CONSTRAINT c UNIQUE(val)", + ExpectedErr: sql.ErrUniqueKeyViolation, + }, + { + Query: "DELETE FROM keyless where pk = 1 and val = 2", + Expected: []sql.Row{{sql.NewOkResult(1)}}, + }, + { + Query: "ALTER TABLE keyless ADD CONSTRAINT c UNIQUE(val)", + Expected: []sql.Row{{sql.NewOkResult(0)}}, + }, + { + Query: "INSERT IGNORE INTO keyless VALUES (1, 3)", + Expected: []sql.Row{{sql.NewOkResult(0)}}, + ExpectedWarning: mysql.ERDupEntry, + }, + }, + }, + { + Name: "UPDATE IGNORE keyless tables and secondary indexes", + SetUpScript: []string{ + "CREATE TABLE keyless(pk int, val int)", + "INSERT INTO keyless VALUES (1, 1), (2, 2), (3, 3)", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "UPDATE IGNORE keyless SET val = 2 where pk = 1", + Expected: []sql.Row{{newUpdateResult(1, 1)}}, + }, + { + Query: "SELECT * FROM keyless ORDER BY pk", + Expected: []sql.Row{{1, 2}, {2, 2}, {3, 3}}, + }, + { + Query: "ALTER TABLE keyless ADD CONSTRAINT c UNIQUE(val)", + ExpectedErr: sql.ErrUniqueKeyViolation, + }, + { + Query: "UPDATE IGNORE keyless SET val = 1 where pk = 1", + Expected: []sql.Row{{newUpdateResult(1, 1)}}, + ExpectedWarning: mysql.ERDupEntry, + }, + { + Query: "ALTER TABLE keyless ADD CONSTRAINT c UNIQUE(val)", + Expected: []sql.Row{{sql.NewOkResult(0)}}, + }, + { + Query: "UPDATE IGNORE keyless SET val = 3 where pk = 1", + Expected: []sql.Row{{newUpdateResult(1, 0)}}, + ExpectedWarning: mysql.ERDupEntry, + }, + { + Query: "SELECT * FROM keyless ORDER BY pk", + Expected: []sql.Row{{1, 1}, {2, 2}, {3, 3}}, + }, + { + Query: "UPDATE IGNORE keyless SET val = val + 1 ORDER BY pk", + Expected: []sql.Row{{newUpdateResult(3, 1)}}, + ExpectedWarning: mysql.ERDupEntry, + }, + { + Query: "SELECT * FROM keyless ORDER BY pk", + Expected: []sql.Row{{1, 1}, {2, 2}, {3, 4}}, + }, + }, + }, } var InsertBrokenScripts = []ScriptTest{ diff --git a/enginetest/queries/update_queries.go b/enginetest/queries/update_queries.go index f5d745458c..9b9dfe79fb 100644 --- a/enginetest/queries/update_queries.go +++ b/enginetest/queries/update_queries.go @@ -15,6 +15,8 @@ package queries import ( + "github.com/dolthub/vitess/go/mysql" + "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/plan" ) @@ -532,6 +534,161 @@ var GenericUpdateErrorTests = []GenericErrorQueryTest{ }, } +var UpdateIgnoreTests = []WriteQueryTest{ + { + WriteQuery: "UPDATE IGNORE mytable SET i = 2 where i = 1", + ExpectedWriteResult: []sql.Row{{newUpdateResult(1, 0)}}, + SelectQuery: "SELECT * FROM mytable order by i", + ExpectedSelect: []sql.Row{ + sql.NewRow(1, "first row"), + sql.NewRow(2, "second row"), + sql.NewRow(3, "third row"), + }, + }, + { + WriteQuery: "UPDATE IGNORE mytable SET i = i+1 where i = 1", + ExpectedWriteResult: []sql.Row{{newUpdateResult(1, 0)}}, + SelectQuery: "SELECT * FROM mytable order by i", + ExpectedSelect: []sql.Row{ + sql.NewRow(1, "first row"), + sql.NewRow(2, "second row"), + sql.NewRow(3, "third row"), + }, + }, +} + +var UpdateIgnoreScripts = []ScriptTest{ + { + Name: "UPDATE IGNORE with primary keys and indexes", + SetUpScript: []string{ + "CREATE TABLE pkTable(pk int, val int, primary key(pk, val))", + "CREATE TABLE idxTable(pk int primary key, val int UNIQUE)", + "INSERT INTO pkTable VALUES (1, 1), (2, 2), (3, 3)", + "INSERT INTO idxTable VALUES (1, 1), (2, 2), (3, 3)", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "UPDATE IGNORE pkTable set pk = pk + 1, val = val + 1", + Expected: []sql.Row{{newUpdateResult(3, 1)}}, + ExpectedWarning: mysql.ERDupEntry, + }, + { + Query: "SELECT * FROM pkTable order by pk", + Expected: []sql.Row{{1, 1}, {2, 2}, {4, 4}}, + }, + { + Query: "UPDATE IGNORE idxTable set val = val + 1", + Expected: []sql.Row{{newUpdateResult(3, 1)}}, + ExpectedWarning: mysql.ERDupEntry, + }, + { + Query: "SELECT * FROM idxTable order by pk", + Expected: []sql.Row{{1, 1}, {2, 2}, {3, 4}}, + }, + { + Query: "UPDATE IGNORE pkTable set val = val + 1 where pk = 2", + Expected: []sql.Row{{newUpdateResult(1, 1)}}, + }, + { + Query: "SELECT * FROM pkTable order by pk", + Expected: []sql.Row{{1, 1}, {2, 3}, {4, 4}}, + }, + { + Query: "UPDATE IGNORE pkTable SET pk = NULL", + Expected: []sql.Row{{newUpdateResult(3, 3)}}, + ExpectedWarning: mysql.ERBadNullError, + }, + { + Query: "SELECT * FROM pkTable order by pk", + Expected: []sql.Row{{0, 1}, {0, 3}, {0, 4}}, + }, + { + Query: "UPDATE IGNORE pkTable SET val = NULL", + Expected: []sql.Row{{newUpdateResult(3, 1)}}, + }, + { + Query: "SELECT * FROM pkTable order by pk", + Expected: []sql.Row{{0, 0}, {0, 3}, {0, 4}}, + }, + { + Query: "UPDATE IGNORE idxTable set pk = pk + 1, val = val + 1", // two bad updates + Expected: []sql.Row{{newUpdateResult(3, 1)}}, + ExpectedWarning: mysql.ERDupEntry, + }, + { + Query: "SELECT * FROM idxTable order by pk", + Expected: []sql.Row{{1, 1}, {2, 2}, {4, 5}}, + }, + }, + }, + { + Name: "UPDATE IGNORE with type conversions", + SetUpScript: []string{ + "CREATE TABLE t1 (pk int primary key, v1 int, v2 int)", + "INSERT INTO t1 VALUES (1, 1, 1)", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "UPDATE IGNORE t1 SET v1 = 'dsddads'", + Expected: []sql.Row{{newUpdateResult(1, 1)}}, + ExpectedWarning: mysql.ERTruncatedWrongValueForField, + }, + { + Query: "SELECT * FROM t1", + Expected: []sql.Row{{1, 0, 1}}, + }, + { + Query: "UPDATE IGNORE t1 SET pk = 'dasda', v2 = 'dsddads'", + Expected: []sql.Row{{newUpdateResult(1, 1)}}, + ExpectedWarning: mysql.ERTruncatedWrongValueForField, + }, + { + Query: "SELECT * FROM t1", + Expected: []sql.Row{{0, 0, 0}}, + }, + }, + }, + { + Name: "UPDATE IGNORE with foreign keys", + SetUpScript: []string{ + "CREATE TABLE colors ( id INT NOT NULL, color VARCHAR(32) NOT NULL, PRIMARY KEY (id), INDEX color_index(color));", + "CREATE TABLE objects (id INT NOT NULL, name VARCHAR(64) NOT NULL,color VARCHAR(32), PRIMARY KEY(id),FOREIGN KEY (color) REFERENCES colors(color))", + "INSERT INTO colors (id,color) VALUES (1,'red'),(2,'green'),(3,'blue'),(4,'purple')", + "INSERT INTO objects (id,name,color) VALUES (1,'truck','red'),(2,'ball','green'),(3,'shoe','blue')", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "UPDATE IGNORE objects SET color = 'orange' where id = 2", + Expected: []sql.Row{{newUpdateResult(1, 0)}}, + ExpectedWarning: mysql.ErNoReferencedRow2, + }, + { + Query: "SELECT * FROM objects ORDER BY id", + Expected: []sql.Row{{1, "truck", "red"}, {2, "ball", "green"}, {3, "shoe", "blue"}}, + }, + }, + }, + { + Name: "UPDATE IGNORE with check constraints", + SetUpScript: []string{ + "CREATE TABLE checksTable(pk int primary key)", + "ALTER TABLE checksTable ADD CONSTRAINT mycx CHECK (pk < 5)", + "INSERT INTO checksTable VALUES (1),(2),(3),(4)", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "UPDATE IGNORE checksTable SET pk = pk + 1 where pk = 4", + Expected: []sql.Row{{newUpdateResult(1, 0)}}, + ExpectedWarning: mysql.ERUnknownError, + }, + { + Query: "SELECT * from checksTable ORDER BY pk", + Expected: []sql.Row{{1}, {2}, {3}, {4}}, + }, + }, + }, +} + var UpdateErrorTests = []QueryErrorTest{ { Query: `UPDATE keyless INNER JOIN one_pk on keyless.c0 = one_pk.pk SET keyless.c0 = keyless.c0 + 1`, diff --git a/sql/errors.go b/sql/errors.go index fbf8216eff..8d0b6d193e 100644 --- a/sql/errors.go +++ b/sql/errors.go @@ -615,6 +615,10 @@ func CastSQLError(err error) (*mysql.SQLError, error, bool) { return CastSQLError(w.Cause) } + if wm, ok := err.(WrappedTypeConversionError); ok { + return CastSQLError(wm.Err) + } + switch { case ErrTableNotFound.Is(err): code = mysql.ERNoSuchTable @@ -710,14 +714,30 @@ func (w WrappedInsertError) Error() string { return w.Cause.Error() } -type ErrInsertIgnore struct { +// IgnorableError is used propagate information about an error that needs to be ignored and does not interfere with +// any update accumulators +type IgnorableError struct { OffendingRow Row } -func NewErrInsertIgnore(row Row) ErrInsertIgnore { - return ErrInsertIgnore{OffendingRow: row} +func NewIgnorableError(row Row) IgnorableError { + return IgnorableError{OffendingRow: row} +} + +func (e IgnorableError) Error() string { + return "An ignorable error should never be printed" +} + +type WrappedTypeConversionError struct { + OffendingVal interface{} + OffendingIdx int + Err error +} + +func NewWrappedTypeConversionError(offendingVal interface{}, idx int, err error) WrappedTypeConversionError { + return WrappedTypeConversionError{OffendingVal: offendingVal, OffendingIdx: idx, Err: err} } -func (e ErrInsertIgnore) Error() string { - return "Insert ignore error shoudl never be printed" +func (w WrappedTypeConversionError) Error() string { + return w.Err.Error() } diff --git a/sql/expression/set.go b/sql/expression/set.go index eb0644b339..c23fa40b7f 100644 --- a/sql/expression/set.go +++ b/sql/expression/set.go @@ -67,9 +67,9 @@ func (s *SetField) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { if err != nil { // Fill in error with information if sql.ErrLengthBeyondLimit.Is(err) { - return nil, sql.ErrLengthBeyondLimit.New(val, getField.Name()) + return nil, sql.NewWrappedTypeConversionError(val, getField.fieldIndex, sql.ErrLengthBeyondLimit.New(val, getField.Name())) } - return nil, err + return nil, sql.NewWrappedTypeConversionError(val, getField.fieldIndex, err) } val = convertedVal } diff --git a/sql/parse/parse.go b/sql/parse/parse.go index f727ff1983..800d48bd9e 100644 --- a/sql/parse/parse.go +++ b/sql/parse/parse.go @@ -1909,7 +1909,9 @@ func convertUpdate(ctx *sql.Context, d *sqlparser.Update) (sql.Node, error) { } } - return plan.NewUpdate(node, updateExprs), nil + ignore := d.Ignore != "" + + return plan.NewUpdate(node, ignore, updateExprs), nil } func convertLoad(ctx *sql.Context, d *sqlparser.Load) (sql.Node, error) { diff --git a/sql/parse/parse_test.go b/sql/parse/parse_test.go index b7db95d39c..832317ffe5 100644 --- a/sql/parse/parse_test.go +++ b/sql/parse/parse_test.go @@ -1705,16 +1705,13 @@ CREATE TABLE t2 expression.NewLiteral("a", sql.LongText), &expression.DefaultColumn{}, }}), false, []string{"col1", "col2"}, []sql.Expression{}, false), - `UPDATE t1 SET col1 = ?, col2 = ? WHERE id = ?`: plan.NewUpdate( - plan.NewFilter( - expression.NewEquals(expression.NewUnresolvedColumn("id"), expression.NewBindVar("v3")), - plan.NewUnresolvedTable("t1", ""), - ), - []sql.Expression{ - expression.NewSetField(expression.NewUnresolvedColumn("col1"), expression.NewBindVar("v1")), - expression.NewSetField(expression.NewUnresolvedColumn("col2"), expression.NewBindVar("v2")), - }, - ), + `UPDATE t1 SET col1 = ?, col2 = ? WHERE id = ?`: plan.NewUpdate(plan.NewFilter( + expression.NewEquals(expression.NewUnresolvedColumn("id"), expression.NewBindVar("v3")), + plan.NewUnresolvedTable("t1", ""), + ), false, []sql.Expression{ + expression.NewSetField(expression.NewUnresolvedColumn("col1"), expression.NewBindVar("v1")), + expression.NewSetField(expression.NewUnresolvedColumn("col2"), expression.NewBindVar("v2")), + }), `REPLACE INTO t1 (col1, col2) VALUES ('a', 1)`: plan.NewInsertInto(sql.UnresolvedDatabase(""), plan.NewUnresolvedTable("t1", ""), plan.NewValues([][]sql.Expression{{ expression.NewLiteral("a", sql.LongText), expression.NewLiteral(int8(1), sql.Int8), @@ -3703,15 +3700,12 @@ var triggerFixtures = map[string]sql.Node{ plan.NewUnresolvedTable("foo", ""), plan.NewBeginEndBlock( plan.NewBlock([]sql.Node{ - plan.NewUpdate( - plan.NewFilter( - expression.NewEquals(expression.NewUnresolvedColumn("z"), expression.NewUnresolvedQualifiedColumn("new", "y")), - plan.NewUnresolvedTable("bar", ""), - ), - []sql.Expression{ - expression.NewSetField(expression.NewUnresolvedColumn("x"), expression.NewUnresolvedQualifiedColumn("old", "y")), - }, - ), + plan.NewUpdate(plan.NewFilter( + expression.NewEquals(expression.NewUnresolvedColumn("z"), expression.NewUnresolvedQualifiedColumn("new", "y")), + plan.NewUnresolvedTable("bar", ""), + ), false, []sql.Expression{ + expression.NewSetField(expression.NewUnresolvedColumn("x"), expression.NewUnresolvedQualifiedColumn("old", "y")), + }), plan.NewDeleteFrom( plan.NewFilter( expression.NewEquals(expression.NewUnresolvedColumn("a"), expression.NewUnresolvedQualifiedColumn("old", "b")), diff --git a/sql/plan/insert.go b/sql/plan/insert.go index b6a2e2dfeb..f9cfeebd4a 100644 --- a/sql/plan/insert.go +++ b/sql/plan/insert.go @@ -332,7 +332,7 @@ func (i *insertIter) Next(ctx *sql.Context) (returnRow sql.Row, returnErr error) } if err != nil { - return i.ignoreOrClose(ctx, row, err) + return nil, i.ignoreOrClose(ctx, row, err) } // Prune the row down to the size of the schema. It can be larger in the case of running with an outer scope, in which @@ -343,12 +343,12 @@ func (i *insertIter) Next(ctx *sql.Context) (returnRow sql.Row, returnErr error) err = i.validateNullability(ctx, i.schema, row) if err != nil { - return i.ignoreOrClose(ctx, row, err) + return nil, i.ignoreOrClose(ctx, row, err) } err = i.evaluateChecks(ctx, row) if err != nil { - return i.ignoreOrClose(ctx, row, err) + return nil, i.ignoreOrClose(ctx, row, err) } // Do any necessary type conversions to the target schema @@ -357,10 +357,7 @@ func (i *insertIter) Next(ctx *sql.Context) (returnRow sql.Row, returnErr error) converted, cErr := col.Type.Convert(row[idx]) // allows for better error handling if cErr != nil { if i.ignore { - row, err = i.convertDataAndWarn(ctx, row, idx, cErr) - if err != nil { - return nil, err - } + row = convertDataAndWarn(ctx, i.schema, row, idx, cErr) continue } else { // Fill in error with information @@ -409,7 +406,7 @@ func (i *insertIter) Next(ctx *sql.Context) (returnRow sql.Row, returnErr error) } else { if err := i.inserter.Insert(ctx, row); err != nil { if (!sql.ErrPrimaryKeyViolation.Is(err) && !sql.ErrUniqueKeyViolation.Is(err) && !sql.ErrDuplicateEntry.Is(err)) || len(i.updateExprs) == 0 { - return i.ignoreOrClose(ctx, row, err) + return nil, i.ignoreOrClose(ctx, row, err) } ue := err.(*errors.Error).Cause().(sql.UniqueKeyError) @@ -438,10 +435,7 @@ func (i *insertIter) handleOnDuplicateKeyUpdate(ctx *sql.Context, row, rowToUpda return nil, err } - val, err = i.convertDataAndWarn(ctx, row, idx, err) - if err != nil { - return nil, err - } + val = convertDataAndWarn(ctx, i.schema, row, idx, err) } else { return nil, err } @@ -453,12 +447,12 @@ func (i *insertIter) handleOnDuplicateKeyUpdate(ctx *sql.Context, row, rowToUpda // Should revaluate the check conditions. err = i.evaluateChecks(ctx, newRow) if err != nil { - return i.ignoreOrClose(ctx, newRow, err) + return nil, i.ignoreOrClose(ctx, newRow, err) } err = i.updater.Update(ctx, rowToUpdate, newRow) if err != nil { - return i.ignoreOrClose(ctx, newRow, err) + return nil, i.ignoreOrClose(ctx, newRow, err) } // In the case that we attempted an update, return a concatenated [old,new] row just like update. @@ -556,29 +550,23 @@ func (i *insertIter) updateLastInsertId(ctx *sql.Context, row sql.Row) { } } -func (i *insertIter) ignoreOrClose(ctx *sql.Context, row sql.Row, err error) (sql.Row, error) { - if i.ignore { - err = i.warnOnIgnorableError(ctx, row, err) - if err != nil { - return nil, err - } - return nil, nil - } else { - i.rowSource.Close(ctx) - i.rowSource = nil - return nil, sql.NewWrappedInsertError(row, err) +func (i *insertIter) ignoreOrClose(ctx *sql.Context, row sql.Row, err error) error { + if !i.ignore { + return sql.NewWrappedInsertError(row, err) } + + return warnOnIgnorableError(ctx, row, err) } -// convertDataAndWarn modifies a row with data conversion issues in INSERT IGNORE calls +// convertDataAndWarn modifies a row with data conversion issues in INSERT/UPDATE IGNORE calls // Per MySQL docs "Rows set to values that would cause data conversion errors are set to the closest valid values instead" // cc. https://dev.mysql.com/doc/refman/8.0/en/sql-mode.html#sql-mode-strict -func (i *insertIter) convertDataAndWarn(ctx *sql.Context, row sql.Row, columnIdx int, err error) (sql.Row, error) { +func convertDataAndWarn(ctx *sql.Context, tableSchema sql.Schema, row sql.Row, columnIdx int, err error) sql.Row { if sql.ErrLengthBeyondLimit.Is(err) { - maxLength := i.schema[columnIdx].Type.(sql.StringType).MaxCharacterLength() + maxLength := tableSchema[columnIdx].Type.(sql.StringType).MaxCharacterLength() row[columnIdx] = row[columnIdx].(string)[:maxLength] // truncate string } else { - row[columnIdx] = i.schema[columnIdx].Type.Zero() + row[columnIdx] = tableSchema[columnIdx].Type.Zero() } sqlerr, _, _ := sql.CastSQLError(err) @@ -590,14 +578,10 @@ func (i *insertIter) convertDataAndWarn(ctx *sql.Context, row sql.Row, columnIdx Message: err.Error(), }) - return row, nil + return row } -func (i *insertIter) warnOnIgnorableError(ctx *sql.Context, row sql.Row, err error) error { - if !i.ignore { - return err - } - +func warnOnIgnorableError(ctx *sql.Context, row sql.Row, err error) error { // Check that this error is a part of the list of Ignorable Errors and create the relevant warning for _, ie := range IgnorableErrors { if ie.Is(err) { @@ -616,7 +600,7 @@ func (i *insertIter) warnOnIgnorableError(ctx *sql.Context, row sql.Row, err err } // Return the InsertIgnore err to ensure our accumulator doesn't count this row. - return sql.NewErrInsertIgnore(row) + return sql.NewIgnorableError(row) } } @@ -737,7 +721,7 @@ func (i *insertIter) validateNullability(ctx *sql.Context, dstSchema sql.Schema, // In the case of an IGNORE we set the nil value to a default and add a warning if i.ignore { row[count] = col.Type.Zero() - _ = i.warnOnIgnorableError(ctx, row, sql.ErrInsertIntoNonNullableProvidedNull.New(col.Name)) // will always return nil + _ = warnOnIgnorableError(ctx, row, sql.ErrInsertIntoNonNullableProvidedNull.New(col.Name)) // will always return nil } else { return sql.ErrInsertIntoNonNullableProvidedNull.New(col.Name) } diff --git a/sql/plan/row_update_accumulator.go b/sql/plan/row_update_accumulator.go index 63a5d0b366..115aa00ef6 100644 --- a/sql/plan/row_update_accumulator.go +++ b/sql/plan/row_update_accumulator.go @@ -89,6 +89,12 @@ type accumulatorRowHandler interface { okResult() sql.OkResult } +// TODO: Extend this to UPDATE IGNORE JOIN +type updateIgnoreAccumulatorRowHandler interface { + accumulatorRowHandler + handleRowUpdateWithIgnore(row sql.Row, ignore bool) error +} + type insertRowHandler struct { rowsAffected int } @@ -185,6 +191,15 @@ func (u *updateRowHandler) handleRowUpdate(row sql.Row) error { return nil } +func (u *updateRowHandler) handleRowUpdateWithIgnore(row sql.Row, ignore bool) error { + if !ignore { + return u.handleRowUpdate(row) + } + + u.rowsMatched++ + return nil +} + func (u *updateRowHandler) okResult() sql.OkResult { affected := u.rowsAffected if u.clientFoundRowsCapability { @@ -312,14 +327,14 @@ func (a *accumulatorIter) Next(ctx *sql.Context) (r sql.Row, err error) { for { row, err := a.iter.Next(ctx) - _, isIg := err.(sql.ErrInsertIgnore) + igErr, isIg := err.(sql.IgnorableError) select { case <-ctx.Done(): return nil, ctx.Err() default: } if err == io.EOF { - res := a.updateRowHandler.okResult() + res := a.updateRowHandler.okResult() // TODO: Should add warnings here // TODO: The information flow here is pretty gnarly. We // set some session variables based on the result, and @@ -344,14 +359,19 @@ func (a *accumulatorIter) Next(ctx *sql.Context) (r sql.Row, err error) { return sql.NewRow(res), nil } else if isIg { - continue + if ui, ok := a.updateRowHandler.(updateIgnoreAccumulatorRowHandler); ok { + err = ui.handleRowUpdateWithIgnore(igErr.OffendingRow, true) + if err != nil { + return nil, err + } + } } else if err != nil { return nil, err - } - - err = a.updateRowHandler.handleRowUpdate(row) - if err != nil { - return nil, err + } else { + err = a.updateRowHandler.handleRowUpdate(row) + if err != nil { + return nil, err + } } } } diff --git a/sql/plan/set.go b/sql/plan/set.go index d7ebecd492..eeee7c2323 100644 --- a/sql/plan/set.go +++ b/sql/plan/set.go @@ -215,3 +215,20 @@ func (s *Set) DebugString() string { } return strings.Join(children, ", ") } + +// Applies the update expressions given to the row given, returning the new resultant row. +func applyUpdateExpressions(ctx *sql.Context, updateExprs []sql.Expression, row sql.Row) (sql.Row, error) { + var ok bool + prev := row + for _, updateExpr := range updateExprs { + val, err := updateExpr.Eval(ctx, prev) + if err != nil { + return nil, err + } + prev, ok = val.(sql.Row) + if !ok { + return nil, ErrUpdateUnexpectedSetResult.New(val) + } + } + return prev, nil +} diff --git a/sql/plan/table_editor.go b/sql/plan/table_editor.go index 39a469eb01..90dbe68379 100644 --- a/sql/plan/table_editor.go +++ b/sql/plan/table_editor.go @@ -65,7 +65,7 @@ func (s *tableEditorIter) Next(ctx *sql.Context) (sql.Row, error) { func (s *tableEditorIter) Close(ctx *sql.Context) error { err := s.errorEncountered - _, ok := err.(sql.ErrInsertIgnore) + _, ok := err.(sql.IgnorableError) if err != nil && !ok { err = s.editor.DiscardChanges(ctx, s.errorEncountered) diff --git a/sql/plan/update.go b/sql/plan/update.go index 32d433959d..347662cead 100644 --- a/sql/plan/update.go +++ b/sql/plan/update.go @@ -30,17 +30,21 @@ var ErrUpdateUnexpectedSetResult = errors.NewKind("attempted to set field but ex type Update struct { UnaryNode Checks sql.CheckConstraints + Ignore bool } var _ sql.Databaseable = (*Update)(nil) // NewUpdate creates an Update node. -func NewUpdate(n sql.Node, updateExprs []sql.Expression) *Update { +func NewUpdate(n sql.Node, ignore bool, updateExprs []sql.Expression) *Update { return &Update{ UnaryNode: UnaryNode{NewUpdateSource( n, + ignore, updateExprs, - )}} + )}, + Ignore: ignore, + } } func GetUpdatable(node sql.Node) (sql.UpdatableTable, error) { @@ -144,6 +148,7 @@ type updateIter struct { updater sql.RowUpdater checks sql.CheckConstraints closed bool + ignore bool } func (u *updateIter) Next(ctx *sql.Context) (sql.Row, error) { @@ -154,7 +159,6 @@ func (u *updateIter) Next(ctx *sql.Context) (sql.Row, error) { oldRow, newRow := oldAndNewRow[:len(oldAndNewRow)/2], oldAndNewRow[len(oldAndNewRow)/2:] if equals, err := oldRow.Equals(newRow, u.schema); err == nil { - // TODO: we aren't enforcing other kinds of constraints here, like nullability if !equals { // apply check constraints for _, check := range u.checks { @@ -168,18 +172,18 @@ func (u *updateIter) Next(ctx *sql.Context) (sql.Row, error) { } if sql.IsFalse(res) { - return nil, sql.ErrCheckConstraintViolated.New(check.Name) + return nil, u.ignoreOrError(ctx, newRow, sql.ErrCheckConstraintViolated.New(check.Name)) } } - err := u.validateNullability(newRow, u.schema) + err := u.validateNullability(ctx, newRow, u.schema) if err != nil { - return nil, err + return nil, u.ignoreOrError(ctx, newRow, err) } err = u.updater.Update(ctx, oldRow, newRow) if err != nil { - return nil, err + return nil, u.ignoreOrError(ctx, newRow, err) } } } else { @@ -189,15 +193,23 @@ func (u *updateIter) Next(ctx *sql.Context) (sql.Row, error) { return oldAndNewRow, nil } -// Applies the update expressions given to the row given, returning the new resultant row. +// Applies the update expressions given to the row given, returning the new resultant row. In the case that ignore is +// provided and there is a type conversion error, this function sets the value to the zero value as per the MySQL standard. // TODO: a set of update expressions should probably be its own expression type with an Eval method that does this -func applyUpdateExpressions(ctx *sql.Context, updateExprs []sql.Expression, row sql.Row) (sql.Row, error) { +func applyUpdateExpressionsWithIgnore(ctx *sql.Context, updateExprs []sql.Expression, tableSchema sql.Schema, row sql.Row, ignore bool) (sql.Row, error) { var ok bool prev := row for _, updateExpr := range updateExprs { val, err := updateExpr.Eval(ctx, prev) if err != nil { - return nil, err + wtce, ok2 := err.(sql.WrappedTypeConversionError) + if !ok2 || !ignore { + return nil, err + } + + cpy := prev.Copy() + cpy[wtce.OffendingIdx] = wtce.OffendingVal // Needed for strings + val = convertDataAndWarn(ctx, tableSchema, cpy, wtce.OffendingIdx, wtce.Err) } prev, ok = val.(sql.Row) if !ok { @@ -207,12 +219,18 @@ func applyUpdateExpressions(ctx *sql.Context, updateExprs []sql.Expression, row return prev, nil } -func (u *updateIter) validateNullability(row sql.Row, schema sql.Schema) error { +func (u *updateIter) validateNullability(ctx *sql.Context, row sql.Row, schema sql.Schema) error { for idx := 0; idx < len(row); idx++ { col := schema[idx] if !col.Nullable && row[idx] == nil { // In the case of an IGNORE we set the nil value to a default and add a warning - return sql.ErrInsertIntoNonNullableProvidedNull.New(col.Name) + if u.ignore { + row[idx] = col.Type.Zero() + _ = warnOnIgnorableError(ctx, row, sql.ErrInsertIntoNonNullableProvidedNull.New(col.Name)) // will always return nil + } else { + return sql.ErrInsertIntoNonNullableProvidedNull.New(col.Name) + } + } } return nil @@ -229,18 +247,37 @@ func (u *updateIter) Close(ctx *sql.Context) error { return nil } +func (u *updateIter) ignoreOrError(ctx *sql.Context, row sql.Row, err error) error { + if !u.ignore { + return err + } + + return warnOnIgnorableError(ctx, row, err) +} + func newUpdateIter( childIter sql.RowIter, schema sql.Schema, updater sql.RowUpdater, checks sql.CheckConstraints, + ignore bool, ) sql.RowIter { - return NewTableEditorIter(updater, &updateIter{ - childIter: childIter, - updater: updater, - schema: schema, - checks: checks, - }) + if ignore { + return NewCheckpointingTableEditorIter(updater, &updateIter{ + childIter: childIter, + updater: updater, + schema: schema, + checks: checks, + ignore: true, + }) + } else { + return NewTableEditorIter(updater, &updateIter{ + childIter: childIter, + updater: updater, + schema: schema, + checks: checks, + }) + } } // RowIter implements the Node interface. @@ -256,7 +293,7 @@ func (u *Update) RowIter(ctx *sql.Context, row sql.Row) (sql.RowIter, error) { return nil, err } - return newUpdateIter(iter, updatable.Schema(), updater, u.Checks), nil + return newUpdateIter(iter, updatable.Schema(), updater, u.Checks, u.Ignore), nil } // WithChildren implements the Node interface. diff --git a/sql/plan/update_source.go b/sql/plan/update_source.go index 0098b800e5..a976f3b89b 100644 --- a/sql/plan/update_source.go +++ b/sql/plan/update_source.go @@ -25,13 +25,15 @@ import ( type UpdateSource struct { UnaryNode UpdateExprs []sql.Expression + Ignore bool } // NewUpdateSource returns a new UpdateSource from the node and expressions given. -func NewUpdateSource(node sql.Node, updateExprs []sql.Expression) *UpdateSource { +func NewUpdateSource(node sql.Node, ignore bool, updateExprs []sql.Expression) *UpdateSource { return &UpdateSource{ UnaryNode: UnaryNode{node}, UpdateExprs: updateExprs, + Ignore: ignore, } } @@ -45,7 +47,7 @@ func (u *UpdateSource) WithExpressions(newExprs ...sql.Expression) (sql.Node, er if len(newExprs) != len(u.UpdateExprs) { return nil, sql.ErrInvalidChildrenNumber.New(u, len(u.UpdateExprs), 1) } - return NewUpdateSource(u.Child, newExprs), nil + return NewUpdateSource(u.Child, u.Ignore, newExprs), nil } // Schema implements sql.Node. The schema of an update is a concatenation of the old and new rows. @@ -92,6 +94,7 @@ type updateSourceIter struct { childIter sql.RowIter updateExprs []sql.Expression tableSchema sql.Schema + ignore bool } func (u *updateSourceIter) Next(ctx *sql.Context) (sql.Row, error) { @@ -100,7 +103,7 @@ func (u *updateSourceIter) Next(ctx *sql.Context) (sql.Row, error) { return nil, err } - newRow, err := applyUpdateExpressions(ctx, u.updateExprs, oldRow) + newRow, err := applyUpdateExpressionsWithIgnore(ctx, u.updateExprs, u.tableSchema, oldRow, u.ignore) if err != nil { return nil, err } @@ -149,6 +152,7 @@ func (u *UpdateSource) RowIter(ctx *sql.Context, row sql.Row) (sql.RowIter, erro childIter: rowIter, updateExprs: u.UpdateExprs, tableSchema: schema, + ignore: u.Ignore, }, nil } @@ -156,7 +160,7 @@ func (u *UpdateSource) WithChildren(children ...sql.Node) (sql.Node, error) { if len(children) != 1 { return nil, sql.ErrInvalidChildrenNumber.New(u, len(children), 1) } - return NewUpdateSource(children[0], u.UpdateExprs), nil + return NewUpdateSource(children[0], u.Ignore, u.UpdateExprs), nil } // CheckPrivileges implements the interface sql.Node. diff --git a/sql/plan/update_test.go b/sql/plan/update_test.go new file mode 100644 index 0000000000..32396145a2 --- /dev/null +++ b/sql/plan/update_test.go @@ -0,0 +1,103 @@ +// Copyright 2022 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package plan + +import ( + "testing" + "time" + + "github.com/dolthub/vitess/go/sqltypes" + "github.com/stretchr/testify/require" + + "github.com/dolthub/go-mysql-server/memory" + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/expression" +) + +func TestUpdateIgnoreConversions(t *testing.T) { + ctx := sql.NewEmptyContext() + testCases := []struct { + name string + colType sql.Type + value interface{} + valueType sql.Type + expected interface{} + }{ + { + name: "inserting a string into a integer defaults to a 0", + colType: sql.Int64, + value: "dadasd", + valueType: sql.Text, + expected: int64(0), + }, + { + name: "string too long gets truncated", + colType: sql.MustCreateStringWithDefaults(sqltypes.VarChar, 2), + value: "dadsa", + valueType: sql.Text, + expected: "da", + }, + { + name: "inserting a string into a datetime results in 0 time", + colType: sql.Datetime, + value: "dadasd", + valueType: sql.Text, + expected: time.Unix(-62167219200, 0).UTC(), + }, + { + name: "inserting a negative into an unsigned int results in 0", + colType: sql.Uint64, + value: -1, + valueType: sql.Int8, + expected: uint64(0), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + sch := sql.NewPrimaryKeySchema(sql.Schema{ + {Name: "c1", Source: "foo", Type: tc.colType, Nullable: true}, + }) + table := memory.NewTable("foo", sch, nil) + + err := table.Insert(ctx, sql.Row{nil}) + require.NoError(t, err) + + // Run the UPDATE IGNORE + sf := expression.NewSetField(expression.NewGetField(0, tc.colType, "c1", true), expression.NewLiteral(tc.value, tc.valueType)) + updatePlan := NewUpdate(NewResolvedTable(table, nil, nil), true, []sql.Expression{sf}) + + ri, err := updatePlan.RowIter(ctx, nil) + require.NoError(t, err) + + _, err = sql.RowIterToRows(ctx, sch.Schema, ri) + require.NoError(t, err) + + // Run a SELECT to see the updated data + selectPlan := NewProject([]sql.Expression{ + expression.NewGetField(0, tc.colType, "c1", true), + }, NewResolvedTable(table, nil, nil)) + + ri, err = selectPlan.RowIter(ctx, nil) + require.NoError(t, err) + + rows, err := sql.RowIterToRows(ctx, sch.Schema, ri) + require.NoError(t, err) + + require.Equal(t, 1, len(rows)) + require.Equal(t, tc.expected, rows[0][0]) + }) + } +}