diff --git a/README.md b/README.md index 2d70b61..9a5a42c 100644 --- a/README.md +++ b/README.md @@ -159,12 +159,12 @@ The syntax is case-insensitive. | Create index | `CREATE INDEX ...;` | | | Delete index | `DROP INDEX ...;` | | | Query | `SELECT ...;` | | -| DML | `INSERT / UPDATE / DELETE ...;` | | -| Partitioned DML | | Not supported yet | +| DML | `(INSERT\|UPDATE\|DELETE) ...;` | | +| Partitioned DML | `PARTITIONED (UPDATE\|DELETE) ...;` | | | Show Query Execution Plan | `EXPLAIN SELECT ...;` | | | Show DML Execution Plan | `EXPLAIN INSERT / UPDATE / DELETE ...;` | EXPERIMENTAL | | Show Query Execution Plan with Stats | `EXPLAIN ANALYZE SELECT ...;` | EXPERIMENTAL | -| Show DML Execution Plan with Stats | `EXPLAIN ANALYZE INSERT / UPDATE / DELETE ...;` | EXPERIMENTAL | +| Show DML Execution Plan with Stats | `EXPLAIN ANALYZE (INSERT\|UPDATE\|DELETE) ...;` | EXPERIMENTAL | | Start Read-Write Transaction | `BEGIN (RW);` | | | Commit Read-Write Transaction | `COMMIT;` | | | Rollback Read-Write Transaction | `ROLLBACK;` | | diff --git a/cli.go b/cli.go index ef49717..d0a99bf 100644 --- a/cli.go +++ b/cli.go @@ -392,11 +392,18 @@ func resultLine(result *Result, verbose bool) string { } if result.IsMutation { + var affectedRowsPrefix string + if result.AffectedRowsType == rowCountTypeLowerBound { + // For Partitioned DML the result's row count is lower bounded number, so we add "at least" to express ambiguity. + // See https://cloud.google.com/spanner/docs/reference/rpc/google.spanner.v1?hl=en#resultsetstats + affectedRowsPrefix = "at least " + } if verbose && timestamp != "" { - return fmt.Sprintf("Query OK, %d rows affected (%s)\ntimestamp: %s\n", result.AffectedRows, - result.Stats.ElapsedTime, timestamp) + return fmt.Sprintf("Query OK, %s%d rows affected (%s)\ntimestamp: %s\n", + affectedRowsPrefix, result.AffectedRows, result.Stats.ElapsedTime, timestamp) } - return fmt.Sprintf("Query OK, %d rows affected (%s)\n", result.AffectedRows, result.Stats.ElapsedTime) + return fmt.Sprintf("Query OK, %s%d rows affected (%s)\n", + affectedRowsPrefix, result.AffectedRows, result.Stats.ElapsedTime) } var set string diff --git a/integration_test.go b/integration_test.go index 939842b..ff37283 100644 --- a/integration_test.go +++ b/integration_test.go @@ -696,3 +696,37 @@ func TestTruncateTable(t *testing.T) { t.Errorf("TRUNCATE TABLE executed, but %d rows are remained", count) } } + +func TestPartitionedDML(t *testing.T) { + if skipIntegrateTest { + t.Skip("Integration tests skipped") + } + + ctx, cancel := context.WithTimeout(context.Background(), 180*time.Second) + defer cancel() + + session, tableId, tearDown := setup(t, ctx, []string{ + "INSERT INTO [[TABLE]] (id, active) VALUES (1, false)", + }) + defer tearDown() + + stmt, err := BuildStatement(fmt.Sprintf("PARTITIONED UPDATE %s SET active = true WHERE true", tableId)) + if err != nil { + t.Fatalf("invalid statement: %v", err) + } + + if _, err := stmt.Execute(session); err != nil { + t.Fatalf("execution failed: %v", err) + } + + selectStmt := spanner.NewStatement(fmt.Sprintf("SELECT active FROM %s", tableId)) + var got bool + if err := session.client.Single().Query(ctx, selectStmt).Do(func(r *spanner.Row) error { + return r.Column(0, &got) + }); err != nil { + t.Fatalf("query failed: %v", err) + } + if want := true; want != got { + t.Errorf("PARTITIONED UPDATE was executed, but rows were not updated") + } +} diff --git a/statement.go b/statement.go index 856e51f..2e89dd7 100644 --- a/statement.go +++ b/statement.go @@ -35,15 +35,26 @@ type Statement interface { Execute(session *Session) (*Result, error) } +// rowCountType is type of modified rows count by DML. +type rowCountType int + +const ( + // rowCountTypeExact is exact count type for DML result. + rowCountTypeExact rowCountType = iota + // rowCountTypeLowerBound is lower bound type for Partitioned DML result. + rowCountTypeLowerBound +) + type Result struct { - ColumnNames []string - Rows []Row - Predicates []string - AffectedRows int - Stats QueryStats - IsMutation bool - Timestamp time.Time - ForceVerbose bool + ColumnNames []string + Rows []Row + Predicates []string + AffectedRows int + AffectedRowsType rowCountType + Stats QueryStats + IsMutation bool + Timestamp time.Time + ForceVerbose bool } type Row struct { @@ -78,6 +89,11 @@ var ( // DML dmlRe = regexp.MustCompile(`(?is)^(INSERT|UPDATE|DELETE)\s+.+$`) + // Partitioned DML + // In fact, INSERT is not supported in a Partitioned DML, but accept it for showing better error message. + // https://cloud.google.com/spanner/docs/dml-partitioned#features_that_arent_supported + pdmlRe = regexp.MustCompile(`(?is)^PARTITIONED\s+((?:INSERT|UPDATE|DELETE)\s+.+$)`) + // Transaction beginRwRe = regexp.MustCompile(`(?is)^BEGIN(\s+RW)?$`) beginRoRe = regexp.MustCompile(`(?is)^BEGIN\s+RO(?:\s+([^\s]+))?$`) @@ -157,6 +173,9 @@ func BuildStatement(input string) (Statement, error) { return &ShowIndexStatement{Table: unquoteIdentifier(matched[1])}, nil case dmlRe.MatchString(input): return &DmlStatement{Dml: input}, nil + case pdmlRe.MatchString(input): + matched := pdmlRe.FindStringSubmatch(input) + return &PartitionedDmlStatement{Dml: matched[1]}, nil case beginRwRe.MatchString(input): return &BeginRwStatement{}, nil case beginRoRe.MatchString(input): @@ -768,6 +787,32 @@ func (s *DmlStatement) Execute(session *Session) (*Result, error) { return result, nil } +type PartitionedDmlStatement struct { + Dml string +} + +func (s *PartitionedDmlStatement) Execute(session *Session) (*Result, error) { + if session.InRwTxn() { + // PartitionedUpdate creates a new transaction and it could cause dead lock with the current running transaction. + return nil, errors.New(`Partitioned DML statement can not be run in a read-write transaction`) + } + if session.InRoTxn() { + // Just for user-friendly. + return nil, errors.New(`Partitioned DML statement can not be run in a read-only transaction`) + } + + stmt := spanner.NewStatement(s.Dml) + count, err := session.client.PartitionedUpdate(session.ctx, stmt) + if err != nil { + return nil, err + } + return &Result{ + IsMutation: true, + AffectedRows: int(count), + AffectedRowsType: rowCountTypeLowerBound, + }, nil +} + type ExplainDmlStatement struct { Dml string } diff --git a/statement_test.go b/statement_test.go index 70ee1af..dc23883 100644 --- a/statement_test.go +++ b/statement_test.go @@ -119,6 +119,16 @@ func TestBuildStatement(t *testing.T) { input: "DELETE FROM t1 WHERE id = 1", want: &DmlStatement{Dml: "DELETE FROM t1 WHERE id = 1"}, }, + { + desc: "PARTITIONED UPDATE statement", + input: "PARTITIONED UPDATE t1 SET name = hello WHERE id > 1", + want: &PartitionedDmlStatement{Dml: "UPDATE t1 SET name = hello WHERE id > 1"}, + }, + { + desc: "PARTITIONED DELETE statement", + input: "PARTITIONED DELETE FROM t1 WHERE id > 1", + want: &PartitionedDmlStatement{Dml: "DELETE FROM t1 WHERE id > 1"}, + }, { desc: "EXPLAIN INSERT statement", input: "EXPLAIN INSERT INTO t1 (id, name) VALUES (1, 'yuki')",