From 3a93ac5c0c9b31341d302b10fec58c80a04ec292 Mon Sep 17 00:00:00 2001 From: Asdine El Hrychy Date: Sat, 17 Feb 2024 14:25:05 +0400 Subject: [PATCH] db: only strict schemas --- cmd/chai/commands/app.go | 1 - cmd/chai/commands/insert.go | 127 -- cmd/chai/dbutil/dump.go | 32 +- cmd/chai/dbutil/dump_test.go | 8 +- cmd/chai/dbutil/exec_test.go | 2 +- cmd/chai/dbutil/insert.go | 111 -- cmd/chai/dbutil/insert_test.go | 136 --- cmd/chai/dbutil/schema.go | 21 +- cmd/chai/doc/doc.go | 74 -- cmd/chai/doc/doc_test.go | 100 -- cmd/chai/doc/functions.go | 47 - cmd/chai/doc/tokens.go | 17 - cmd/chai/shell/command.go | 23 +- cmd/chai/shell/command_test.go | 8 +- cmd/chai/shell/shell.go | 5 - db.go | 30 +- db_test.go | 36 +- driver/driver.go | 62 +- driver/driver_test.go | 66 +- driver/example_test.go | 26 +- example_test.go | 32 +- go.mod | 1 - go.sum | 2 - internal/database/catalog.go | 147 +-- internal/database/catalog_test.go | 133 +-- internal/database/catalogstore/store.go | 53 +- internal/database/constraint.go | 251 +--- internal/database/constraint_test.go | 151 +-- internal/database/encoding.go | 253 ++-- internal/database/encoding_test.go | 125 +- internal/database/index.go | 2 +- internal/database/index_test.go | 7 +- internal/database/info.go | 125 +- internal/database/iteration.go | 70 +- internal/database/row.go | 51 +- internal/database/sequence.go | 28 +- internal/database/sequence_test.go | 2 +- internal/database/table.go | 68 +- internal/database/table_test.go | 138 ++- internal/encoding/array.go | 110 -- internal/encoding/array_test.go | 49 - internal/encoding/conversion.go | 75 -- internal/encoding/document.go | 115 -- internal/encoding/document_test.go | 90 -- internal/encoding/encoding.go | 120 +- internal/encoding/helpers_test.go | 297 +++-- internal/encoding/times.go | 14 +- internal/encoding/times_test.go | 21 +- internal/environment/env.go | 44 +- internal/errors/errors.go | 10 +- internal/expr/arithmeric.go | 19 +- internal/expr/arithmetic_test.go | 49 +- internal/expr/column.go | 27 + internal/expr/comparison.go | 98 +- internal/expr/comparison_test.go | 62 +- internal/expr/constraint.go | 21 +- internal/expr/expr.go | 52 +- internal/expr/expr_test.go | 20 +- internal/expr/functions/builtins.go | 191 +-- internal/expr/functions/definition.go | 26 +- internal/expr/functions/definition_test.go | 38 +- internal/expr/functions/math.go | 46 +- internal/expr/functions/object.go | 73 -- internal/expr/functions/scalar_definition.go | 9 +- .../expr/functions/scalar_definition_test.go | 16 +- internal/expr/functions/strings.go | 42 - .../functions/testdata/builtin_functions.sql | 8 +- .../functions/testdata/math_functions.sql | 138 +-- internal/expr/literal.go | 108 +- internal/expr/operator.go | 5 +- internal/expr/operator_test.go | 4 +- internal/expr/path.go | 82 -- internal/expr/path_test.go | 126 -- internal/expr/wildcard.go | 28 + internal/kv/session_test.go | 16 +- internal/object/array.go | 210 ---- internal/object/array_test.go | 69 -- internal/object/cast.go | 267 ----- internal/object/create.go | 355 ------ internal/object/create_test.go | 216 ---- internal/object/diff.go | 188 --- internal/object/diff_test.go | 157 --- internal/object/object.go | 530 --------- internal/object/object_test.go | 652 ---------- internal/object/path.go | 182 --- internal/object/path_test.go | 49 - internal/object/scan_test.go | 313 ----- internal/planner/index_selection.go | 247 ++-- internal/planner/optimizer.go | 248 ++-- internal/planner/optimizer_test.go | 350 ++---- 
internal/query/statement/alter.go | 16 +- internal/query/statement/alter_test.go | 6 +- internal/query/statement/create.go | 6 +- internal/query/statement/create_test.go | 23 +- internal/query/statement/delete.go | 17 +- internal/query/statement/delete_test.go | 22 +- internal/query/statement/drop_test.go | 2 +- internal/query/statement/explain.go | 7 + internal/query/statement/explain_test.go | 14 +- internal/query/statement/insert.go | 55 +- internal/query/statement/insert_test.go | 64 +- internal/query/statement/reindex_test.go | 4 +- internal/query/statement/select.go | 52 +- internal/query/statement/select_test.go | 223 ++-- internal/query/statement/statement.go | 22 + internal/query/statement/stream.go | 16 +- internal/query/statement/update.go | 43 +- internal/query/statement/update_test.go | 80 +- internal/row/diff.go | 98 ++ internal/row/diff_test.go | 74 ++ internal/row/format.go | 138 +++ internal/{object => row}/json.go | 30 +- internal/row/object_test.go | 474 ++++++++ internal/row/row.go | 468 ++++++++ internal/row/row_test.go | 124 ++ internal/{object => row}/scan.go | 297 ++--- internal/row/scan_test.go | 124 ++ internal/sql/parser/alter.go | 22 +- internal/sql/parser/alter_test.go | 42 +- internal/sql/parser/create.go | 188 +-- internal/sql/parser/create_test.go | 24 +- internal/sql/parser/delete.go | 4 +- internal/sql/parser/delete_test.go | 2 +- internal/sql/parser/drop.go | 10 +- internal/sql/parser/explain.go | 4 +- internal/sql/parser/expr.go | 214 +--- internal/sql/parser/expr_test.go | 175 +-- internal/sql/parser/insert.go | 124 +- internal/sql/parser/insert_test.go | 150 ++- internal/sql/parser/options.go | 17 - internal/sql/parser/order_by.go | 14 +- internal/sql/parser/parser.go | 49 +- internal/sql/parser/reindex.go | 2 +- internal/sql/parser/select.go | 4 +- internal/sql/parser/select_test.go | 204 ++-- internal/sql/parser/transaction.go | 6 +- internal/sql/parser/update.go | 44 +- internal/sql/parser/update_test.go | 57 +- internal/sql/scanner/scanner.go | 8 - internal/sql/scanner/scanner_test.go | 4 - internal/sql/scanner/token.go | 9 - .../{add_field.sql => add_column.sql} | 43 +- internal/sqltests/CREATE_INDEX/base.sql | 8 +- internal/sqltests/CREATE_INDEX/undeclared.sql | 11 +- internal/sqltests/CREATE_SEQUENCE/base.sql | 32 +- internal/sqltests/CREATE_TABLE/base.sql | 20 +- internal/sqltests/CREATE_TABLE/check.sql | 18 +- .../sqltests/CREATE_TABLE/constraints.sql | 26 +- internal/sqltests/CREATE_TABLE/default.sql | 19 - internal/sqltests/CREATE_TABLE/not_null.sql | 12 +- .../sqltests/CREATE_TABLE/primary_key.sql | 39 +- internal/sqltests/CREATE_TABLE/types.sql | 54 +- .../sqltests/CREATE_TABLE/types_document.sql | 29 - internal/sqltests/CREATE_TABLE/unique.sql | 39 +- internal/sqltests/INSERT/check.sql | 39 +- internal/sqltests/INSERT/document.sql | 62 - internal/sqltests/INSERT/insert_select.sql | 62 +- internal/sqltests/INSERT/misc.sql | 68 +- internal/sqltests/INSERT/not_null.sql | 2 +- internal/sqltests/INSERT/primary_key.sql | 34 - internal/sqltests/INSERT/types.sql | 275 ----- internal/sqltests/INSERT/values.sql | 112 +- internal/sqltests/SELECT/STRINGS/lower.sql | 61 +- internal/sqltests/SELECT/STRINGS/ltrim.sql | 31 +- internal/sqltests/SELECT/STRINGS/rtrim.sql | 31 +- internal/sqltests/SELECT/STRINGS/trim.sql | 31 +- internal/sqltests/SELECT/STRINGS/upper.sql | 58 +- internal/sqltests/SELECT/WHERE/comp.sql | 1046 +++-------------- internal/sqltests/SELECT/distinct.sql | 80 +- internal/sqltests/SELECT/len.sql | 155 --- 
internal/sqltests/SELECT/nullable.sql | 40 +- internal/sqltests/SELECT/objects/fields.sql | 78 -- internal/sqltests/SELECT/order_by.sql | 2 + .../sqltests/SELECT/order_by_desc_index.sql | 2 +- .../SELECT/order_by_desc_pk_composite.sql | 2 +- internal/sqltests/SELECT/pk.sql | 13 - .../sqltests/SELECT/projection_no_table.sql | 14 +- internal/sqltests/SELECT/projection_table.sql | 38 +- internal/sqltests/SELECT/union.sql | 6 +- internal/sqltests/UPDATE/check.sql | 32 - internal/sqltests/UPDATE/pk.sql | 24 +- internal/sqltests/expr/arithmetic.sql | 40 +- internal/sqltests/expr/cast.sql | 75 -- internal/sqltests/expr/literal.sql | 38 +- internal/sqltests/expr/objects.sql | 35 - internal/sqltests/planning/between.sql | 6 +- internal/sqltests/planning/merge.gosave | 146 --- internal/sqltests/planning/order_by.sql | 8 +- .../sqltests/planning/order_by_composite.sql | 16 +- internal/sqltests/planning/precalculate.sql | 2 +- internal/sqltests/planning/where.sql | 14 +- internal/sqltests/planning/where_pk.sql | 10 +- internal/sqltests/sql_test.go | 8 +- internal/stream/index/delete.go | 8 +- internal/stream/index/insert.go | 10 +- internal/stream/index/scan.go | 2 +- internal/stream/index/scan_test.go | 309 ++--- internal/stream/index/validate.go | 8 +- internal/stream/on_conflict.go | 7 +- internal/stream/operator_test.go | 143 +-- internal/stream/path/rename.go | 22 +- .../path/{unset_test.go => rename_test.go} | 38 +- internal/stream/path/set.go | 39 +- internal/stream/path/set_test.go | 37 +- internal/stream/path/unset.go | 70 -- internal/stream/range.go | 12 +- internal/stream/rows/emit.go | 25 +- internal/stream/rows/emit_test.go | 48 - internal/stream/rows/group_aggregate.go | 29 +- internal/stream/rows/group_aggregate_test.go | 45 +- internal/stream/rows/project.go | 41 +- internal/stream/rows/project_test.go | 22 +- internal/stream/rows/skip.go | 3 +- internal/stream/rows/take.go | 3 +- internal/stream/rows/temp_tree_sort.go | 89 +- internal/stream/rows/temp_tree_sort_test.go | 56 +- internal/stream/stream.go | 2 +- internal/stream/stream_test.go | 94 +- internal/stream/table/delete.go | 2 +- internal/stream/table/insert.go | 2 +- internal/stream/table/replace.go | 4 +- internal/stream/table/table_test.go | 99 +- internal/stream/table/validate.go | 19 +- internal/stream/union.go | 50 +- internal/testutil/expr.go | 91 +- internal/testutil/object.go | 193 --- internal/testutil/row.go | 182 +++ internal/testutil/stream.go | 155 ++- internal/tree/key.go | 18 +- internal/tree/tree.go | 184 +-- internal/tree/tree_test.go | 150 +-- internal/types/array.go | 106 -- internal/types/bigint.go | 312 +++++ internal/types/blob.go | 51 + internal/types/boolean.go | 65 +- internal/{object => types}/cast_test.go | 50 +- internal/types/comparable.go | 222 ---- internal/types/comparable_test.go | 115 +- internal/types/double.go | 126 +- internal/types/encoding.go | 88 ++ internal/{encoding => types}/encoding_test.go | 51 +- internal/types/integer.go | 205 +++- internal/types/null.go | 51 +- internal/types/numeric.go | 45 + internal/types/object.go | 134 --- internal/types/text.go | 99 ++ internal/types/timestamp.go | 49 + internal/types/types.go | 184 ++- internal/types/value.go | 246 +--- internal/types/value_test.go | 69 +- 250 files changed, 6721 insertions(+), 13733 deletions(-) delete mode 100644 cmd/chai/commands/insert.go delete mode 100644 cmd/chai/dbutil/insert.go delete mode 100644 cmd/chai/dbutil/insert_test.go delete mode 100644 cmd/chai/doc/doc.go delete mode 100644 cmd/chai/doc/doc_test.go delete 
mode 100644 cmd/chai/doc/functions.go delete mode 100644 cmd/chai/doc/tokens.go delete mode 100644 internal/encoding/array.go delete mode 100644 internal/encoding/array_test.go delete mode 100644 internal/encoding/conversion.go delete mode 100644 internal/encoding/document.go delete mode 100644 internal/encoding/document_test.go create mode 100644 internal/expr/column.go delete mode 100644 internal/expr/functions/object.go delete mode 100644 internal/expr/path.go delete mode 100644 internal/expr/path_test.go create mode 100644 internal/expr/wildcard.go delete mode 100644 internal/object/array.go delete mode 100644 internal/object/array_test.go delete mode 100644 internal/object/cast.go delete mode 100644 internal/object/create.go delete mode 100644 internal/object/create_test.go delete mode 100644 internal/object/diff.go delete mode 100644 internal/object/diff_test.go delete mode 100644 internal/object/object.go delete mode 100644 internal/object/object_test.go delete mode 100644 internal/object/path.go delete mode 100644 internal/object/path_test.go delete mode 100644 internal/object/scan_test.go create mode 100644 internal/row/diff.go create mode 100644 internal/row/diff_test.go create mode 100644 internal/row/format.go rename internal/{object => row}/json.go (68%) create mode 100644 internal/row/object_test.go create mode 100644 internal/row/row.go create mode 100644 internal/row/row_test.go rename internal/{object => row}/scan.go (52%) create mode 100644 internal/row/scan_test.go delete mode 100644 internal/sql/parser/options.go rename internal/sqltests/ALTER_TABLE/{add_field.sql => add_column.sql} (85%) delete mode 100644 internal/sqltests/CREATE_TABLE/types_document.sql delete mode 100644 internal/sqltests/INSERT/document.sql delete mode 100644 internal/sqltests/INSERT/types.sql delete mode 100644 internal/sqltests/SELECT/len.sql delete mode 100644 internal/sqltests/SELECT/objects/fields.sql delete mode 100644 internal/sqltests/SELECT/pk.sql delete mode 100644 internal/sqltests/expr/objects.sql delete mode 100644 internal/sqltests/planning/merge.gosave rename internal/stream/path/{unset_test.go => rename_test.go} (51%) delete mode 100644 internal/stream/path/unset.go delete mode 100644 internal/stream/rows/emit_test.go delete mode 100644 internal/testutil/object.go create mode 100644 internal/testutil/row.go delete mode 100644 internal/types/array.go create mode 100644 internal/types/bigint.go rename internal/{object => types}/cast_test.go (69%) delete mode 100644 internal/types/comparable.go create mode 100644 internal/types/encoding.go rename internal/{encoding => types}/encoding_test.go (70%) delete mode 100644 internal/types/object.go diff --git a/cmd/chai/commands/app.go b/cmd/chai/commands/app.go index 04dc9ceb1..a6178ec27 100644 --- a/cmd/chai/commands/app.go +++ b/cmd/chai/commands/app.go @@ -19,7 +19,6 @@ func NewApp() *cli.App { app.EnableBashCompletion = true app.Commands = []*cli.Command{ - NewInsertCommand(), NewVersionCommand(), NewDumpCommand(), NewRestoreCommand(), diff --git a/cmd/chai/commands/insert.go b/cmd/chai/commands/insert.go deleted file mode 100644 index 2d50e88f7..000000000 --- a/cmd/chai/commands/insert.go +++ /dev/null @@ -1,127 +0,0 @@ -package commands - -import ( - "context" - "os" - "strconv" - "strings" - "time" - - "github.com/cockroachdb/errors" - "github.com/urfave/cli/v2" - - "github.com/chaisql/chai" - "github.com/chaisql/chai/cmd/chai/dbutil" -) - -// NewInsertCommand returns a cli.Command for "chai insert". 
-func NewInsertCommand() *cli.Command { - return &cli.Command{ - Name: "insert", - Usage: "Insert objects from arguments or standard input", - UsageText: "chai insert [options] [json...]", - Description: ` -The insert command inserts objects into an existing table. - -Insert can take JSON objects as separate arguments: - -$ chai insert --db mydb -t foo '{"a": 1}' '{"a": 2}' - -It is also possible to pass an array of objects: - -$ chai insert --db mydb -t foo '[{"a": 1}, {"a": 2}]' - -Also you can use -a flag to create database automatically. -This example will create a database with name 'data_${current unix timestamp}' -It can be combined with --db to select an existing database but automatically create the table. - -$ chai insert -a '[{"a": 1}, {"a": 2}]' - -Insert can also insert a stream of objects or an array of objects from standard input: - -$ echo '{"a": 1} {"a": 2}' | chai insert --db mydb -t foo -$ echo '[{"a": 1},{"a": 2}]' | chai insert --db mydb -t foo -$ curl https://api.github.com/repos/chaidb/chai/issues | chai insert --db mydb -t foo`, - Flags: []cli.Flag{ - &cli.StringFlag{ - Name: "db", - Usage: "path of the database", - Required: false, - }, - &cli.StringFlag{ - Name: "table", - Aliases: []string{"t"}, - Usage: "name of the table, it must already exist", - Required: false, - }, - &cli.BoolFlag{ - Name: "auto", - Aliases: []string{"a"}, - Usage: `automatically creates a database and a table whose name is equal to "data_" followed by the current unix timestamp.`, - Required: false, - Value: false, - }, - }, - Action: func(c *cli.Context) error { - dbPath := c.String("db") - table := c.String("table") - args := c.Args().Slice() - return runInsertCommand(c.Context, dbPath, table, c.Bool("auto"), args) - }, - } -} - -func runInsertCommand(ctx context.Context, dbPath, table string, auto bool, args []string) error { - generatedName := "data_" + strconv.FormatInt(time.Now().Unix(), 10) - createTable := false - if table == "" && auto { - table = generatedName - createTable = true - } - - if dbPath == "" && auto { - dbPath = generatedName - } - - db, err := dbutil.OpenDB(ctx, dbPath) - if err != nil { - return err - } - defer db.Close() - - err = insert(db, table, createTable, args...) - if err != nil { - if createTable { - _ = os.RemoveAll(dbPath) - } - - return err - } - - return nil -} - -func insert(db *chai.DB, table string, createTable bool, args ...string) error { - if createTable { - err := db.Exec("CREATE TABLE " + table) - if err != nil { - return err - } - } - - if dbutil.CanReadFromStandardInput() { - return dbutil.InsertJSON(db, table, os.Stdin) - } - - if len(args) == 0 { - return errors.New("no data to insert") - } - - for _, arg := range args { - if err := dbutil.InsertJSON(db, table, strings.NewReader(arg)); err != nil { - return err - } - } - - return nil -} diff --git a/cmd/chai/dbutil/dump.go b/cmd/chai/dbutil/dump.go index b8ee3d274..89494d052 100644 --- a/cmd/chai/dbutil/dump.go +++ b/cmd/chai/dbutil/dump.go @@ -3,6 +3,7 @@ package dbutil import ( "fmt" "io" + "strings" "github.com/chaisql/chai" "go.uber.org/multierr" @@ -57,14 +58,35 @@ func dumpTable(tx *chai.Tx, w io.Writer, query, tableName string) error { defer res.Close() // Inserts statements. 
- insert := fmt.Sprintf("INSERT INTO %s VALUES", tableName) return res.Iterate(func(r *chai.Row) error { - data, err := r.MarshalJSON() + cols, err := r.Columns() if err != nil { return err } - if _, err := fmt.Fprintf(w, "%s %s;\n", insert, string(data)); err != nil { + m := make(map[string]interface{}, len(cols)) + err = r.MapScan(m) + if err != nil { + return err + } + + var sb strings.Builder + + for i, c := range cols { + if i > 0 { + sb.WriteString(", ") + } + + v := m[c] + if v == nil { + sb.WriteString("NULL") + continue + } + + fmt.Fprintf(&sb, "%v", v) + } + + if _, err := fmt.Fprintf(w, "INSERT INTO %s VALUES (%s);\n", tableName, sb.String()); err != nil { return err } @@ -105,8 +127,8 @@ func dumpSchema(tx *chai.Tx, w io.Writer, query string, tableName string) error // Indexes statements. res, err := tx.Query(` SELECT sql FROM __chai_catalog WHERE - type = 'index' AND owner.table_name = ? OR - type = 'sequence' AND owner IS NULL + type = 'index' AND owner_table_name = ? OR + type = 'sequence' AND owner_table_name IS NULL `, tableName) if err != nil { return err diff --git a/cmd/chai/dbutil/dump_test.go b/cmd/chai/dbutil/dump_test.go index 0cdf0f6aa..8bf2b2744 100644 --- a/cmd/chai/dbutil/dump_test.go +++ b/cmd/chai/dbutil/dump_test.go @@ -53,7 +53,7 @@ func TestDump(t *testing.T) { writeToBuf("\n") } - q := fmt.Sprintf("CREATE TABLE %s (a INTEGER, b ANY, c ANY, ...);", table) + q := fmt.Sprintf("CREATE TABLE %s (a INTEGER, b INTEGER, c INTEGER);", table) err = db.Exec(q) assert.NoError(t, err) writeToBuf(q + "\n") @@ -68,17 +68,17 @@ func TestDump(t *testing.T) { assert.NoError(t, err) writeToBuf(q + "\n") - q = fmt.Sprintf(`INSERT INTO %s VALUES {"a": %d, "b": %d, "c": %d};`, table, 1, 2, 3) + q = fmt.Sprintf(`INSERT INTO %s VALUES (%d, %d, %d);`, table, 1, 2, 3) err = db.Exec(q) assert.NoError(t, err) writeToBuf(q + "\n") - q = fmt.Sprintf(`INSERT INTO %s VALUES {"a": %d, "b": %d, "c": %d};`, table, 2, 2, 2) + q = fmt.Sprintf(`INSERT INTO %s VALUES (%d, %d, %d);`, table, 2, 2, 2) err = db.Exec(q) assert.NoError(t, err) writeToBuf(q + "\n") - q = fmt.Sprintf(`INSERT INTO %s VALUES {"a": %d, "b": %d, "c": %d};`, table, 3, 2, 1) + q = fmt.Sprintf(`INSERT INTO %s VALUES (%d, %d, %d);`, table, 3, 2, 1) err = db.Exec(q) assert.NoError(t, err) writeToBuf(q + "\n") diff --git a/cmd/chai/dbutil/exec_test.go b/cmd/chai/dbutil/exec_test.go index cd567e10e..5dc54715a 100644 --- a/cmd/chai/dbutil/exec_test.go +++ b/cmd/chai/dbutil/exec_test.go @@ -18,7 +18,7 @@ func TestExecSQL(t *testing.T) { var got bytes.Buffer err = ExecSQL(context.Background(), db, strings.NewReader(` - CREATE TABLE test(a, ...); + CREATE TABLE test(a INT, b INT); CREATE INDEX idx_a ON test (a); INSERT INTO test (a, b) VALUES (1, 2), (2, 2), (3, 2); SELECT * FROM test; diff --git a/cmd/chai/dbutil/insert.go b/cmd/chai/dbutil/insert.go deleted file mode 100644 index 6ca08ab85..000000000 --- a/cmd/chai/dbutil/insert.go +++ /dev/null @@ -1,111 +0,0 @@ -package dbutil - -import ( - "bufio" - "encoding/json" - "fmt" - "io" - - "github.com/cockroachdb/errors" - - "github.com/chaisql/chai" - "github.com/chaisql/chai/internal/object" -) - -// InsertJSON reads json objects from r and inserts them into the selected table. -// The reader can be either a stream of json objects or an array of objects. 
-func InsertJSON(db *chai.DB, table string, r io.Reader) error { - tx, err := db.Begin(true) - if err != nil { - return err - } - defer tx.Rollback() - - q := fmt.Sprintf("INSERT INTO %s VALUES ?", table) - rd := bufio.NewReader(r) - - // read first non-white space byte to determine - // whether we are reading from a json stream or - // an array of json objects. - c, err := readByteIgnoreWhitespace(rd) - if err != nil { - return err - } - switch c { - case '{': // json stream - if err := rd.UnreadByte(); err != nil { - return err - } - - dec := json.NewDecoder(rd) - for { - var fb object.FieldBuffer - err := dec.Decode(&fb) - if errors.Is(err, io.EOF) { - break - } - if err != nil { - return err - } - - if err := tx.Exec(q, &fb); err != nil { - return err - } - } - - case '[': // Array of json objects - if err := rd.UnreadByte(); err != nil { - return err - } - - dec := json.NewDecoder(rd) - _, err := dec.Token() - if err != nil { - return err - } - - for dec.More() { - var fb object.FieldBuffer - err := dec.Decode(&fb) - if err != nil && !errors.Is(err, io.EOF) { - return err - } - - if err := tx.Exec(q, &fb); err != nil { - return err - } - } - - t, err := dec.Token() - if err != nil { - return err - } - d, ok := t.(json.Delim) - if ok && d.String() != "]" { - return fmt.Errorf("found %q, but expected ']'", c) - } - - default: - return fmt.Errorf("found %q, but expected '{' or '['", c) - } - - return tx.Commit() -} - -func readByteIgnoreWhitespace(r *bufio.Reader) (byte, error) { - var c byte - var err error - - for { - c, err = r.ReadByte() - if err != nil { - return c, err - } - - if c != '\n' && c != '\r' && c != ' ' && c != '\t' { - break - } - } - - return c, nil -} diff --git a/cmd/chai/dbutil/insert_test.go b/cmd/chai/dbutil/insert_test.go deleted file mode 100644 index 63b77c6c7..000000000 --- a/cmd/chai/dbutil/insert_test.go +++ /dev/null @@ -1,136 +0,0 @@ -package dbutil - -import ( - "bytes" - "strings" - "testing" - - "github.com/chaisql/chai" - "github.com/chaisql/chai/internal/testutil/assert" - "github.com/stretchr/testify/require" -) - -func TestInsertJSON(t *testing.T) { - tests := []struct { - name string - data string - want string - fails bool - }{ - {"Simple Json", `{"a": 1}`, `[{"a": 1}]`, false}, - {"JSON object", `{"a": {"b": [1, 2, 3]}}`, `[{"a": {"b": [1, 2, 3]}}]`, false}, - {"nested object", `{"a": {"b": [1, 2, 3]}}`, `[{"a": {"b": [1, 2, 3]}}]`, false}, - {"nested array multiple indexes", `{"a": {"b": [1, 2, [1, 2, {"c": "foo"}]]}}`, `[{"a": {"b": [1, 2, [1, 2, {"c": "foo"}]]}}]`, false}, - {"object in array", `{"a": [{"b":"foo"}, 2, 3]}`, `[{"a": [{"b":"foo"}, 2, 3]}]`, false}, - {"Non closed json array", `[{"foo":"bar"}`, ``, true}, - {"Non closed json stream", `{"foo":"bar"`, ``, true}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - db, err := chai.Open(":memory:") - assert.NoError(t, err) - defer db.Close() - - err = db.Exec(`CREATE TABLE foo`) - assert.NoError(t, err) - err = InsertJSON(db, "foo", strings.NewReader(tt.data)) - if tt.fails { - assert.Error(t, err) - return - } - - assert.NoError(t, err) - res, err := db.Query("SELECT * FROM foo") - defer res.Close() - assert.NoError(t, err) - - var buf bytes.Buffer - err = res.MarshalJSONTo(&buf) - assert.NoError(t, err) - require.JSONEq(t, tt.want, buf.String()) - }) - } - - t.Run(`Json Array`, func(t *testing.T) { - const jsonArray = ` - [ - {"Name": "Ed", "Text": "Knock knock."}, - {"Name": "Sam", "Text": "Who's there?"}, - {"Name": "Ed", "Text": "Go fmt."}, - {"Name": 
"Sam", "Text": "Go fmt who?"}, - {"Name": "Ed", "Text": "Go fmt yourself!"} - ] -` - jsonStreamResult := []string{`{"Name": "Ed", "Text": "Knock knock."}`, - `{"Name": "Sam", "Text": "Who's there?"}`, `{"Name": "Ed", "Text": "Go fmt."}`, - `{"Name": "Sam", "Text": "Go fmt who?"}`, - `{"Name": "Ed", "Text": "Go fmt yourself!"}`} - - db, err := chai.Open(":memory:") - assert.NoError(t, err) - defer db.Close() - - err = db.Exec(`CREATE TABLE foo`) - assert.NoError(t, err) - err = InsertJSON(db, "foo", strings.NewReader(jsonArray)) - assert.NoError(t, err) - res, err := db.Query("SELECT * FROM foo") - defer res.Close() - assert.NoError(t, err) - - i := 0 - _ = res.Iterate(func(r *chai.Row) error { - data, err := r.MarshalJSON() - assert.NoError(t, err) - require.JSONEq(t, jsonStreamResult[i], string(data)) - i++ - return nil - }) - }) - - t.Run(`Json Stream`, func(t *testing.T) { - const jsonStream = ` - {"Name": "Ed", "Text": "Knock knock."} - {"Name": "Sam", "Text": "Who's there?"} - {"Name": "Ed", "Text": "Go fmt."} - {"Name": "Sam", "Text": "Go fmt who?"} - {"Name": "Ed", "Text": "Go fmt yourself!"} - ` - jsonStreamResult := []string{`{"Name": "Ed", "Text": "Knock knock."}`, - `{"Name": "Sam", "Text": "Who's there?"}`, `{"Name": "Ed", "Text": "Go fmt."}`, - `{"Name": "Sam", "Text": "Go fmt who?"}`, - `{"Name": "Ed", "Text": "Go fmt yourself!"}`} - - db, err := chai.Open(":memory:") - defer db.Close() - assert.NoError(t, err) - - err = db.Exec(`CREATE TABLE foo`) - assert.NoError(t, err) - - err = InsertJSON(db, "foo", strings.NewReader(jsonStream)) - assert.NoError(t, err) - - res, err := db.Query("SELECT * FROM foo") - defer res.Close() - assert.NoError(t, err) - - i := 0 - _ = res.Iterate(func(r *chai.Row) error { - data, err := r.MarshalJSON() - assert.NoError(t, err) - require.JSONEq(t, jsonStreamResult[i], string(data)) - i++ - return nil - }) - - wantCount := 0 - err = res.Iterate(func(r *chai.Row) error { - wantCount++ - return nil - }) - assert.NoError(t, err) - require.Equal(t, wantCount, i) - }) -} diff --git a/cmd/chai/dbutil/schema.go b/cmd/chai/dbutil/schema.go index 60f2ff3ef..31f2fe1c6 100644 --- a/cmd/chai/dbutil/schema.go +++ b/cmd/chai/dbutil/schema.go @@ -1,6 +1,8 @@ package dbutil import ( + "fmt" + "github.com/chaisql/chai" "github.com/chaisql/chai/internal/query/statement" "github.com/chaisql/chai/internal/sql/parser" @@ -8,11 +10,24 @@ import ( func QueryTables(tx *chai.Tx, tables []string, fn func(name, query string) error) error { query := "SELECT name, sql FROM __chai_catalog WHERE type = 'table' AND name NOT LIKE '__chai_%'" + var args []any if len(tables) > 0 { - query += " AND name IN ?" + var arg string + + for i := range tables { + arg += "?" + + if i < len(tables)-1 { + arg += ", " + } + + args = append(args, tables[i]) + } + + query += fmt.Sprintf(" AND name IN (%s)", arg) } - res, err := tx.Query(query, tables) + res, err := tx.Query(query, args...) if err != nil { return err } @@ -33,7 +48,7 @@ func ListIndexes(db *chai.DB, tableName string) ([]string, error) { var listName []string q := "SELECT sql FROM __chai_catalog WHERE type = 'index'" if tableName != "" { - q += " AND owner.table_name = ?" + q += " AND owner_table_name = ?" 
} res, err := db.Query(q, tableName) if err != nil { diff --git a/cmd/chai/doc/doc.go b/cmd/chai/doc/doc.go deleted file mode 100644 index dc6e09231..000000000 --- a/cmd/chai/doc/doc.go +++ /dev/null @@ -1,74 +0,0 @@ -/* Package doc provides an API to access documentation for functions and tokens */ -package doc - -import ( - "fmt" - "strings" - - "github.com/chaisql/chai/internal/expr/functions" - "github.com/chaisql/chai/internal/sql/scanner" - "github.com/cockroachdb/errors" -) - -var ErrNotFound = errors.New("No documentation found") -var ErrInvalid = errors.New("Invalid documentation query") - -// DocString returns a string containing the documentation for a given expression. -// -// The expression is merely scanned and not parsed, looking for keywords tokens or IDENT -// tokens forming a function. In that last case, a function lookup is performed, yielding the -// documentation of that particular function. -func DocString(rawExpr string) (string, error) { - if rawExpr == "" { - return "", ErrInvalid - } - s := scanner.NewScanner(strings.NewReader(rawExpr)) - tok, _, _ := s.Scan() - if tok == scanner.ILLEGAL { - return "", ErrInvalid - } - if tok == scanner.IDENT { - s.Unscan() - return scanFuncDocString(s) - } - docstr, ok := tokenDocs[tok] - if ok { - return docstr, nil - } - return "", ErrNotFound -} - -func scanFuncDocString(s *scanner.Scanner) (string, error) { - tok1, _, lit1 := s.Scan() - if tok1 != scanner.IDENT { - return "", ErrInvalid - } - tok2, _, _ := s.Scan() - if tok2 != scanner.EOF && tok2 == scanner.DOT { - // tok1 is a package because tok2 is a "." - tok3, _, lit3 := s.Scan() - if tok3 != scanner.IDENT { - return "", ErrInvalid - } - return funcDocString(lit1, lit3) - } else { - // no package, it's a builtin function - return funcDocString("", lit1) - } -} - -func funcDocString(pkg string, name string) (string, error) { - table := functions.DefaultPackages() - f, err := table.GetFunc(pkg, name) - if err != nil { - return "", ErrNotFound - } - // Because we got the definition, we know that the package and function both exist. 
- p := packageDocs[pkg] - d := p[name] - if pkg != "" { - return fmt.Sprintf("%s.%s: %s", pkg, f.String(), d), nil - } else { - return fmt.Sprintf("%s: %s", f.String(), d), nil - } -} diff --git a/cmd/chai/doc/doc_test.go b/cmd/chai/doc/doc_test.go deleted file mode 100644 index b22ad52e6..000000000 --- a/cmd/chai/doc/doc_test.go +++ /dev/null @@ -1,100 +0,0 @@ -package doc_test - -import ( - "fmt" - "regexp" - "strings" - "testing" - - "github.com/chaisql/chai/cmd/chai/doc" - "github.com/chaisql/chai/internal/expr/functions" - "github.com/chaisql/chai/internal/sql/scanner" - "github.com/chaisql/chai/internal/testutil/assert" - "github.com/stretchr/testify/require" -) - -func TestFunctions(t *testing.T) { - packages := functions.DefaultPackages() - for pkgname, pkg := range packages { - for fname, def := range pkg { - if pkgname == "" { - var isAlias = false - for pkgname2, pkg2 := range packages { - if pkgname2 != "" { - _, ok := pkg2[strings.ToLower(fname)] - if ok { - isAlias = true - } - } - } - if !isAlias { - t.Run(fmt.Sprintf("%s is documented and has all its arguments mentioned", fname), func(t *testing.T) { - str, err := doc.DocString(fname) - assert.NoError(t, err) - for i := 0; i < def.Arity(); i++ { - require.Contains(t, trimDocPromt(str), fmt.Sprintf("arg%d", i+1)) - } - }) - } - } else { - t.Run(fmt.Sprintf("%s.%s is documented and has all its arguments mentioned", pkgname, fname), func(t *testing.T) { - str, err := doc.DocString(fmt.Sprintf("%s.%s", pkgname, fname)) - assert.NoError(t, err) - if def.Arity() > 0 { - for i := 0; i < def.Arity(); i++ { - require.Contains(t, trimDocPromt(str), fmt.Sprintf("arg%d", i+1)) - } - } - }) - } - } - } -} - -// trimDocPrompt returns the description part of the doc string, ignoring the promt. -func trimDocPromt(str string) string { - // Matches the doc description, ignoring the "package.funcname:" part. - r := regexp.MustCompile("[^:]+:(.*)") - subs := r.FindStringSubmatch(str) - return subs[1] -} - -func TestTokens(t *testing.T) { - for _, tok := range scanner.AllKeywords() { - t.Run(fmt.Sprintf("%s is documented", tok.String()), func(t *testing.T) { - str, err := doc.DocString(tok.String()) - assert.NoError(t, err) - require.NotEqual(t, "", str) - if str == "TODO" { - t.Logf("warning, %s is not yet documented", tok.String()) - } else { - // if the token is documented, its description should contain its own name. 
- require.Contains(t, str, tok.String()) - } - }) - } -} - -func TestDocString(t *testing.T) { - t.Run("OK", func(t *testing.T) { - str, err := doc.DocString("BY") - assert.NoError(t, err) - require.NotEmpty(t, str) - require.NotEqual(t, "TODO", str) - }) - - t.Run("NOK illegal input", func(t *testing.T) { - _, err := doc.DocString("😀") - assert.ErrorIs(t, err, doc.ErrInvalid) - }) - - t.Run("NOK empty input", func(t *testing.T) { - _, err := doc.DocString("") - assert.ErrorIs(t, err, doc.ErrInvalid) - }) - - t.Run("NOK no doc found", func(t *testing.T) { - _, err := doc.DocString("foo.bar") - assert.ErrorIs(t, err, doc.ErrNotFound) - }) -} diff --git a/cmd/chai/doc/functions.go b/cmd/chai/doc/functions.go deleted file mode 100644 index 30a755146..000000000 --- a/cmd/chai/doc/functions.go +++ /dev/null @@ -1,47 +0,0 @@ -package doc - -type functionDocs map[string]string - -var packageDocs = map[string]functionDocs{ - "": builtinDocs, - "math": mathDocs, - "strings": stringsDocs, - "objects": objectsDocs, -} - -var builtinDocs = functionDocs{ - "pk": "The pk() function returns the primary key for the current row", - "count": "Returns a count of the number of times that arg1 is not NULL in a group. The count(*) function (with no arguments) returns the total number of rows in the group.", - "min": "Returns the minimum value of the arg1 expression in a group.", - "max": "Returns the maximum value of the arg1 expressein in a group.", - "sum": "The sum function returns the sum of all values taken by the arg1 expression in a group.", - "avg": "The avg function returns the average of all values taken by the arg1 expression in a group.", - "typeof": "The typeof function returns the type of arg1.", - "len": "The len function returns length of the arg1 expression if arg1 evals to string, array or object, either returns NULL.", - "coalesce": "The coalesce function returns the first non-null argument. NULL is returned if all arguments are null.", -} - -var mathDocs = functionDocs{ - "abs": "Returns the absolute value of arg1.", - "acos": "Returns the arcosine, in radiant, of arg1.", - "acosh": "Returns the inverse hyperbolic cosine of arg1.", - "asin": "Returns the arsine, in radiant, of arg1.", - "asinh": "Returns the inverse hyperbolic sine of arg1.", - "atan": "Returns the arctangent, in radians, of arg1.", - "atan2": "Returns the arctangent of arg1/arg2, using the signs of the two to determine the quadrant of the return value.", - "floor": "Returns the greatest integer value less than or equal to arg1.", - "random": "The random function returns a random number between math.MinInt64 and math.MaxInt64.", - "sqrt": "The sqrt function returns the square root of arg1.", -} - -var stringsDocs = functionDocs{ - "lower": "The lower function returns arg1 to lower-case if arg1 evals to string", - "upper": "The upper function returns arg1 to upper-case if arg1 evals to string", - "trim": "The trim function returns arg1 with leading and trailing characters removed. space by default or arg2", - "ltrim": "The ltrim function returns arg1 with leading characters removed. space by default or arg2", - "rtrim": "The rtrim function returns arg1 with trailing characters removed. space by default or arg2", -} - -var objectsDocs = functionDocs{ - "fields": "The fields function returns the top-level fields of arg1 if arg1 evals to object, otherwise it returns null. 
It returns an array of TEXT.", -} diff --git a/cmd/chai/doc/tokens.go b/cmd/chai/doc/tokens.go deleted file mode 100644 index 0f83b6c91..000000000 --- a/cmd/chai/doc/tokens.go +++ /dev/null @@ -1,17 +0,0 @@ -package doc - -import "github.com/chaisql/chai/internal/sql/scanner" - -var tokenDocs map[scanner.Token]string - -func init() { - tokenDocs = make(map[scanner.Token]string) - // let's make sure the doc doesn't suggest that a keyword doesn't exist because - // it has no doc defined. - for _, tok := range scanner.AllKeywords() { - tokenDocs[tok] = "TODO" - } - - tokenDocs[scanner.BY] = "See GROUP BY, ORDER BY" - tokenDocs[scanner.FROM] = "FROM [TABLE] selects rows in the table named [TABLE]" -} diff --git a/cmd/chai/shell/command.go b/cmd/chai/shell/command.go index aab2b5986..3831e8b9b 100644 --- a/cmd/chai/shell/command.go +++ b/cmd/chai/shell/command.go @@ -13,9 +13,8 @@ import ( "github.com/chaisql/chai" "github.com/chaisql/chai/cmd/chai/dbutil" - "github.com/chaisql/chai/cmd/chai/doc" errs "github.com/chaisql/chai/internal/errors" - "github.com/chaisql/chai/internal/object" + "github.com/chaisql/chai/internal/row" ) type command struct { @@ -60,12 +59,6 @@ var commands = []command{ DisplayName: ".dump", Description: "Dump database content or table content as SQL statements.", }, - { - Name: ".doc", - Options: "[function_name]", - DisplayName: ".doc", - Description: "Display inline documentation for a function", - }, { Name: ".save", Options: "[filename]", @@ -120,16 +113,6 @@ func runHelpCmd(out io.Writer) error { return nil } -// runDocCommand prints the docstring for a given function -func runDocCmd(expr string, out io.Writer) error { - doc, err := doc.DocString(expr) - if err != nil { - return err - } - fmt.Fprintf(out, "%s\n", doc) - return nil -} - // runTablesCmd displays all tables. 
func runTablesCmd(db *chai.DB, w io.Writer) error { res, err := db.Query("SELECT name FROM __chai_catalog WHERE type = 'table' AND name NOT LIKE '__chai_%'") @@ -233,9 +216,9 @@ func runImportCmd(db *chai.DB, fileType, path, table string) error { baseQ := fmt.Sprintf("INSERT INTO %s VALUES ", table) buf := make([][]string, csvBatchSize) - fbs := make([]*object.FieldBuffer, csvBatchSize) + fbs := make([]*row.ColumnBuffer, csvBatchSize) for i := range fbs { - fbs[i] = object.NewFieldBuffer() + fbs[i] = row.NewColumnBuffer() } args := make([]any, csvBatchSize) for i := range args { diff --git a/cmd/chai/shell/command_test.go b/cmd/chai/shell/command_test.go index 79af3ab7e..57266a516 100644 --- a/cmd/chai/shell/command_test.go +++ b/cmd/chai/shell/command_test.go @@ -39,7 +39,7 @@ func TestRunTablesCmd(t *testing.T) { defer db.Close() for _, tb := range test.tables { - err := db.Exec("CREATE TABLE " + tb) + err := db.Exec("CREATE TABLE " + tb + "(a INT)") assert.NoError(t, err) } @@ -71,10 +71,10 @@ func TestIndexesCmd(t *testing.T) { defer db.Close() err = db.Exec(` - CREATE TABLE foo(a, b); + CREATE TABLE foo(a INT, b INT); CREATE INDEX idx_foo_a ON foo (a); CREATE INDEX idx_foo_b ON foo (b); - CREATE TABLE bar(a, b); + CREATE TABLE bar(a INT, b INT); CREATE INDEX idx_bar_a_b ON bar (a, b); `) assert.NoError(t, err) @@ -101,7 +101,7 @@ func TestSaveCommand(t *testing.T) { defer db.Close() err = db.Exec(` - CREATE TABLE test (a DOUBLE, b, ...); + CREATE TABLE test (a DOUBLE, b INT); CREATE INDEX idx_a_b ON test (a, b); `) assert.NoError(t, err) diff --git a/cmd/chai/shell/shell.go b/cmd/chai/shell/shell.go index 4757c5438..066be4f7f 100644 --- a/cmd/chai/shell/shell.go +++ b/cmd/chai/shell/shell.go @@ -349,11 +349,6 @@ func (sh *Shell) runCommand(ctx context.Context, in string, out io.Writer) error } return runImportCmd(sh.db, cmd[1], cmd[2], cmd[3]) - case ".doc": - if len(cmd) != 2 { - return fmt.Errorf(getUsage(".doc")) - } - return runDocCmd(cmd[1], out) case ".restore": if len(cmd) != 2 { return fmt.Errorf(getUsage(".restore")) diff --git a/db.go b/db.go index 27db27b5b..d672d66e2 100644 --- a/db.go +++ b/db.go @@ -15,9 +15,9 @@ import ( "github.com/chaisql/chai/internal/database/catalogstore" "github.com/chaisql/chai/internal/environment" errs "github.com/chaisql/chai/internal/errors" - "github.com/chaisql/chai/internal/object" "github.com/chaisql/chai/internal/query" "github.com/chaisql/chai/internal/query/statement" + "github.com/chaisql/chai/internal/row" "github.com/chaisql/chai/internal/sql/parser" "github.com/chaisql/chai/internal/stream" "github.com/chaisql/chai/internal/stream/rows" @@ -447,13 +447,13 @@ type Row struct { func (r *Row) Clone() *Row { var rr Row - fb := object.NewFieldBuffer() - err := fb.Copy(r.row.Object()) + cb := row.NewColumnBuffer() + err := cb.Copy(r.row) if err != nil { panic(err) } var br database.BasicRow - br.ResetWith(r.row.TableName(), r.row.Key(), fb) + br.ResetWith(r.row.TableName(), r.row.Key(), cb) rr.row = &br return &rr @@ -473,37 +473,29 @@ func (r *Row) Columns() ([]string, error) { } func (r *Row) GetColumnType(column string) (string, error) { v, err := r.row.Get(column) - if errors.Is(err, types.ErrFieldNotFound) { - return "", errors.New("column not found") + if errors.Is(err, types.ErrColumnNotFound) { + return "", err } return v.Type().String(), err } func (r *Row) ScanColumn(column string, dest any) error { - return object.ScanField(r.row.Object(), column, dest) + return row.ScanColumn(r.row, column, dest) } func (r *Row) Scan(dest 
...any) error { - return object.Scan(r.row.Object(), dest...) + return row.Scan(r.row, dest...) } func (r *Row) StructScan(dest any) error { - return object.StructScan(r.row.Object(), dest) + return row.StructScan(r.row, dest) } func (r *Row) MapScan(dest map[string]any) error { - return object.MapScan(r.row.Object(), dest) + return row.MapScan(r.row, dest) } func (r *Row) MarshalJSON() ([]byte, error) { - return r.row.Object().MarshalJSON() -} - -func (r *Row) Iterate(fn func(column string, value types.Value) error) error { - return r.row.Object().Iterate(fn) -} - -func (r *Row) Object() types.Object { - return r.row.Object() + return r.row.MarshalJSON() } diff --git a/db_test.go b/db_test.go index 29d39fc53..f9e44dff4 100644 --- a/db_test.go +++ b/db_test.go @@ -28,19 +28,19 @@ func ExampleTx() { } defer tx.Rollback() - err = tx.Exec("CREATE TABLE IF NOT EXISTS user") + err = tx.Exec("CREATE TABLE IF NOT EXISTS user (id INTEGER PRIMARY KEY, name TEXT, age INTEGER)") if err != nil { - log.Fatal(err) + panic(err) } err = tx.Exec("INSERT INTO user (id, name, age) VALUES (?, ?, ?)", 10, "foo", 15) if err != nil { - log.Fatal(err) + panic(err) } r, err := tx.QueryRow("SELECT id, name, age FROM user WHERE name = ?", "foo") if err != nil { - panic(err) + panic(fmt.Sprintf("%+v", err)) } var u User @@ -67,7 +67,7 @@ func ExampleTx() { panic(err) } - // Output: {10 foo 15 { }} + // Output: {10 foo 15} // 10 foo 15 } @@ -80,9 +80,9 @@ func TestOpen(t *testing.T) { assert.NoError(t, err) err = db.Exec(` - CREATE TABLE tableA (a INTEGER UNIQUE NOT NULL, b (c (d DOUBLE PRIMARY KEY))); + CREATE TABLE tableA (a INTEGER UNIQUE NOT NULL, b DOUBLE PRIMARY KEY); CREATE TABLE tableB (a TEXT NOT NULL DEFAULT 'hello', PRIMARY KEY (a)); - CREATE TABLE tableC (a INTEGER, b BOOL); + CREATE TABLE tableC (a INTEGER, b INTEGER); CREATE INDEX tableC_a_b_idx ON tableC(a, b); CREATE SEQUENCE seqD INCREMENT BY 10 CYCLE MINVALUE 100 NO MAXVALUE START 500; @@ -105,16 +105,16 @@ func TestOpen(t *testing.T) { var count int want := []string{ - `{"name":"__chai_catalog", "namespace":1, "sql":"CREATE TABLE __chai_catalog (name TEXT NOT NULL, type TEXT NOT NULL, namespace INTEGER, sql TEXT, rowid_sequence_name TEXT, owner (table_name TEXT NOT NULL, paths ARRAY), CONSTRAINT __chai_catalog_pk PRIMARY KEY (name))", "type":"table"}`, - `{"name":"__chai_sequence", "sql":"CREATE TABLE __chai_sequence (name TEXT NOT NULL, seq INTEGER, CONSTRAINT __chai_sequence_pk PRIMARY KEY (name))", "namespace":2, "type":"table"}`, - `{"name":"__chai_store_seq", "owner":{"table_name":"__chai_catalog"}, "sql":"CREATE SEQUENCE __chai_store_seq MAXVALUE 9223372036837998591 START WITH 10 CACHE 0", "type":"sequence"}`, - `{"name":"seqD", "sql":"CREATE SEQUENCE seqD INCREMENT BY 10 MINVALUE 100 START WITH 500 CYCLE", "type":"sequence"}`, - `{"name":"tableA", "sql":"CREATE TABLE tableA (a INTEGER NOT NULL, b (c (d DOUBLE NOT NULL)), CONSTRAINT tableA_a_unique UNIQUE (a), CONSTRAINT tableA_pk PRIMARY KEY (b.c.d))", "namespace":10, "type":"table"}`, - `{"name":"tableA_a_idx", "owner":{"table_name":"tableA", "paths":["a"]}, "sql":"CREATE UNIQUE INDEX tableA_a_idx ON tableA (a)", "namespace":11, "type":"index"}`, - `{"name":"tableB", "sql":"CREATE TABLE tableB (a TEXT NOT NULL DEFAULT \"hello\", CONSTRAINT tableB_pk PRIMARY KEY (a))", "namespace":12, "type":"table"}`, - `{"name":"tableC", "rowid_sequence_name":"tableC_seq", "sql":"CREATE TABLE tableC (a INTEGER, b BOOLEAN)", "namespace":13, "type":"table"}`, - `{"name":"tableC_a_b_idx", 
"owner":{"table_name":"tableC"}, "sql":"CREATE INDEX tableC_a_b_idx ON tableC (a, b)", "namespace":14, "type":"index"}`, - `{"name":"tableC_seq", "owner":{"table_name":"tableC"}, "sql":"CREATE SEQUENCE tableC_seq CACHE 64", "type":"sequence"}`, + `{"name":"__chai_catalog", "namespace":1, "owner_table_columns":null, "owner_table_name":null, "rowid_sequence_name":null, "sql":"CREATE TABLE __chai_catalog (name TEXT NOT NULL, type TEXT NOT NULL, namespace BIGINT, sql TEXT, rowid_sequence_name TEXT, owner_table_name TEXT, owner_table_columns TEXT, CONSTRAINT __chai_catalog_pk PRIMARY KEY (name))", "type":"table"}`, + `{"name":"__chai_sequence", "namespace":2, "owner_table_columns":null, "owner_table_name":null, "rowid_sequence_name":null, "sql":"CREATE TABLE __chai_sequence (name TEXT NOT NULL, seq BIGINT, CONSTRAINT __chai_sequence_pk PRIMARY KEY (name))", "type":"table"}`, + `{"name":"__chai_store_seq", "namespace":null, "owner_table_columns":null, "owner_table_name":"__chai_catalog", "rowid_sequence_name":null, "sql":"CREATE SEQUENCE __chai_store_seq MAXVALUE 9223372036837998591 START WITH 10 CACHE 0", "type":"sequence"}`, + `{"name":"seqD", "namespace":null, "owner_table_columns":null, "owner_table_name":null, "rowid_sequence_name":null, "sql":"CREATE SEQUENCE seqD INCREMENT BY 10 MINVALUE 100 START WITH 500 CYCLE", "type":"sequence"}`, + `{"name":"tableA", "namespace":10, "owner_table_columns":null, "owner_table_name":null, "rowid_sequence_name":null, "sql":"CREATE TABLE tableA (a INTEGER NOT NULL, b DOUBLE NOT NULL, CONSTRAINT tableA_a_unique UNIQUE (a), CONSTRAINT tableA_pk PRIMARY KEY (b))", "type":"table"}`, + `{"name":"tableA_a_idx", "namespace":11, "owner_table_columns":"a", "owner_table_name":"tableA", "rowid_sequence_name":null, "sql":"CREATE UNIQUE INDEX tableA_a_idx ON tableA (a)", "type":"index"}`, + `{"name":"tableB", "namespace":12, "owner_table_columns":null, "owner_table_name":null, "rowid_sequence_name":null, "sql":"CREATE TABLE tableB (a TEXT NOT NULL DEFAULT \"hello\", CONSTRAINT tableB_pk PRIMARY KEY (a))", "type":"table"}`, + `{"name":"tableC", "namespace":13, "owner_table_columns":null, "owner_table_name":null, "rowid_sequence_name":"tableC_seq", "sql":"CREATE TABLE tableC (a INTEGER, b INTEGER)", "type":"table"}`, + `{"name":"tableC_a_b_idx", "namespace":14, "owner_table_columns":null, "owner_table_name":"tableC", "rowid_sequence_name":null, "sql":"CREATE INDEX tableC_a_b_idx ON tableC (a, b)", "type":"index"}`, + `{"name":"tableC_seq", "namespace":null, "owner_table_columns":null, "owner_table_name":"tableC", "rowid_sequence_name":null, "sql":"CREATE SEQUENCE tableC_seq CACHE 64", "type":"sequence"}`, } err = res1.Iterate(func(r *chai.Row) error { count++ @@ -149,7 +149,7 @@ func TestQueryRow(t *testing.T) { assert.NoError(t, err) err = tx.Exec(` - CREATE TABLE test; + CREATE TABLE test(a INTEGER PRIMARY KEY, b TEXT NOT NULL); INSERT INTO test (a, b) VALUES (1, 'foo'), (2, 'bar') `) assert.NoError(t, err) diff --git a/driver/driver.go b/driver/driver.go index 0fa151147..9798a5cb1 100644 --- a/driver/driver.go +++ b/driver/driver.go @@ -11,7 +11,7 @@ import ( "github.com/chaisql/chai" "github.com/chaisql/chai/internal/environment" - "github.com/chaisql/chai/internal/object" + "github.com/chaisql/chai/internal/row" "github.com/chaisql/chai/internal/types" "github.com/cockroachdb/errors" ) @@ -167,28 +167,6 @@ func (s stmt) Exec(args []driver.Value) (driver.Result, error) { return nil, errors.New("not implemented") } -// CheckNamedValue has the same behaviour as 
driver.DefaultParameterConverter, except that -// it allows types.Object to be passed as parameters. -// It implements the driver.NamedValueChecker interface. -func (s stmt) CheckNamedValue(nv *driver.NamedValue) error { - if _, ok := nv.Value.(types.Object); ok { - return nil - } - - if _, ok := nv.Value.(object.Scanner); ok { - return nil - } - - var err error - val, err := driver.DefaultParameterConverter.ConvertValue(nv.Value) - if err == nil { - nv.Value = val - return nil - } - - return nil -} - // ExecContext executes a query that doesn't return rows, such // as an INSERT or UPDATE. func (s stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { @@ -258,12 +236,12 @@ var errStop = errors.New("stop") type recordStream struct { res *chai.Result cancelFn func() - c chan row + c chan recordRow wg sync.WaitGroup columns []string } -type row struct { +type recordRow struct { r *chai.Row err error } @@ -274,7 +252,7 @@ func newRecordStream(res *chai.Result) *recordStream { ds := recordStream{ res: res, cancelFn: cancel, - c: make(chan row), + c: make(chan recordRow), } ds.wg.Add(1) @@ -297,7 +275,7 @@ func (rs *recordStream) iterate(ctx context.Context) { select { case <-ctx.Done(): return errStop - case rs.c <- row{ + case rs.c <- recordRow{ r: r, }: @@ -314,7 +292,7 @@ func (rs *recordStream) iterate(ctx context.Context) { return } if err != nil { - rs.c <- row{ + rs.c <- recordRow{ err: err, } return @@ -334,7 +312,7 @@ func (rs *recordStream) Close() error { } func (rs *recordStream) Next(dest []driver.Value) error { - rs.c <- row{} + rs.c <- recordRow{} row, ok := <-rs.c if !ok { @@ -365,12 +343,19 @@ func (rs *recordStream) Next(dest []driver.Value) error { } dest[i] = b case types.TypeInteger.String(): - var ii int64 + var ii int32 err = row.r.ScanColumn(rs.columns[i], &ii) if err != nil { return err } dest[i] = ii + case types.TypeBigint.String(): + var bi int64 + err = row.r.ScanColumn(rs.columns[i], &bi) + if err != nil { + return err + } + dest[i] = bi case types.TypeDouble.String(): var d float64 err = row.r.ScanColumn(rs.columns[i], &d) @@ -399,20 +384,6 @@ func (rs *recordStream) Next(dest []driver.Value) error { return err } dest[i] = b - case types.TypeArray.String(): - var a []any - err = row.r.ScanColumn(rs.columns[i], &a) - if err != nil { - return err - } - dest[i] = a - case types.TypeObject.String(): - m := make(map[string]any) - err = row.r.ScanColumn(rs.columns[i], &m) - if err != nil { - return err - } - dest[i] = m default: err = row.r.ScanColumn(rs.columns[i], dest[i]) if err != nil { @@ -433,12 +404,12 @@ func (v valueScanner) Scan(src any) error { return r.StructScan(v.dest) } - vv, err := object.NewValue(src) + vv, err := row.NewValue(src) if err != nil { return err } - return object.ScanValue(vv, v.dest) + return row.ScanValue(vv, v.dest) } // Scanner turns a variable into a sql.Scanner.
diff --git a/driver/driver_test.go b/driver/driver_test.go index f0e025f5e..1e8463919 100644 --- a/driver/driver_test.go +++ b/driver/driver_test.go @@ -3,6 +3,7 @@ package driver import ( "context" "database/sql" + "fmt" "testing" "time" @@ -12,25 +13,23 @@ import ( type rowtest struct { A int - B []int - C struct{ Foo string } + B string + C bool } -type foo struct{ Foo string } - func TestDriver(t *testing.T) { db, err := sql.Open("chai", ":memory:") assert.NoError(t, err) defer db.Close() - res, err := db.Exec("CREATE TABLE test") + res, err := db.Exec("CREATE TABLE test(a INT, b TEXT, c BOOL)") assert.NoError(t, err) n, err := res.RowsAffected() assert.Error(t, err) require.EqualValues(t, 0, n) for i := 0; i < 10; i++ { - _, err = db.Exec("INSERT INTO test (a, b, c) VALUES (?, ?, ?)", i, []int{i + 1, i + 2, i + 3}, &foo{Foo: "bar"}) + _, err = db.Exec("INSERT INTO test (a, b, c) VALUES (?, ?, ?)", i, fmt.Sprintf("foo%d", i), i%2 == 0) assert.NoError(t, err) } @@ -40,11 +39,11 @@ func TestDriver(t *testing.T) { defer rows.Close() var count int - var dt rowtest + var rt rowtest for rows.Next() { - err = rows.Scan(Scanner(&dt)) + err = rows.Scan(Scanner(&rt)) assert.NoError(t, err) - require.Equal(t, rowtest{count, []int{count + 1, count + 2, count + 3}, foo{Foo: "bar"}}, dt) + require.Equal(t, rowtest{count, fmt.Sprintf("foo%d", count), count%2 == 0}, rt) count++ } @@ -59,31 +58,31 @@ func TestDriver(t *testing.T) { var count int var a int - var c foo + var c bool for rows.Next() { err = rows.Scan(&a, Scanner(&c)) assert.NoError(t, err) require.Equal(t, count, a) - require.Equal(t, foo{Foo: "bar"}, c) + require.Equal(t, count%2 == 0, c) count++ } assert.NoError(t, rows.Err()) require.Equal(t, 10, count) }) - t.Run("Multiple fields with ORDER BY", func(t *testing.T) { + t.Run("Multiple columns with ORDER BY", func(t *testing.T) { rows, err := db.Query("SELECT a, c FROM test ORDER BY a") assert.NoError(t, err) defer rows.Close() var count int var a int - var c foo + var c bool for rows.Next() { err = rows.Scan(&a, Scanner(&c)) assert.NoError(t, err) require.Equal(t, count, a) - require.Equal(t, foo{Foo: "bar"}, c) + require.Equal(t, count%2 == 0, c) count++ } assert.NoError(t, rows.Err()) @@ -96,11 +95,11 @@ func TestDriver(t *testing.T) { defer rows.Close() var count int - var dt rowtest + var rt rowtest for rows.Next() { - err = rows.Scan(Scanner(&dt)) + err = rows.Scan(Scanner(&rt)) assert.NoError(t, err) - require.Equal(t, rowtest{count, []int{count + 1, count + 2, count + 3}, foo{Foo: "bar"}}, dt) + require.Equal(t, rowtest{count, fmt.Sprintf("foo%d", count), count%2 == 0}, rt) count++ } assert.NoError(t, rows.Err()) @@ -113,18 +112,18 @@ func TestDriver(t *testing.T) { defer rows.Close() var count int - var dt rowtest + var rt rowtest for rows.Next() { - err = rows.Scan(Scanner(&dt)) + err = rows.Scan(Scanner(&rt)) assert.NoError(t, err) - require.Equal(t, rowtest{count, []int{count + 1, count + 2, count + 3}, foo{Foo: "bar"}}, dt) + require.Equal(t, rowtest{count, fmt.Sprintf("foo%d", count), count%2 == 0}, rt) count++ } assert.NoError(t, rows.Err()) require.Equal(t, 5, count) }) - t.Run("Multiple fields and wildcards", func(t *testing.T) { + t.Run("Multiple columns and wildcards", func(t *testing.T) { rows, err := db.Query("SELECT a, a, *, b, c, * FROM test") assert.NoError(t, err) defer rows.Close() @@ -132,17 +131,18 @@ func TestDriver(t *testing.T) { var count int var a int var aa int - var b []float32 - var c foo + var b string + var c bool var dt1, dt2 rowtest for rows.Next() 
{ err = rows.Scan(&a, Scanner(&aa), Scanner(&dt1), Scanner(&b), Scanner(&c), Scanner(&dt2)) assert.NoError(t, err) require.Equal(t, count, a) - require.Equal(t, []float32{float32(count + 1), float32(count + 2), float32(count + 3)}, b) - require.Equal(t, foo{Foo: "bar"}, c) - require.Equal(t, rowtest{count, []int{count + 1, count + 2, count + 3}, foo{Foo: "bar"}}, dt1) - require.Equal(t, rowtest{count, []int{count + 1, count + 2, count + 3}, foo{Foo: "bar"}}, dt2) + require.Equal(t, fmt.Sprintf("foo%d", count), b) + + require.Equal(t, count%2 == 0, c) + require.Equal(t, rowtest{count, fmt.Sprintf("foo%d", count), count%2 == 0}, dt1) + require.Equal(t, rowtest{count, fmt.Sprintf("foo%d", count), count%2 == 0}, dt2) count++ } assert.NoError(t, rows.Err()) @@ -197,7 +197,7 @@ func TestDriver(t *testing.T) { for rows.Next() { err = rows.Scan(Scanner(&dt)) assert.NoError(t, err) - require.Equal(t, rowtest{count, []int{count + 1, count + 2, count + 3}, foo{Foo: "bar"}}, dt) + require.Equal(t, rowtest{count, fmt.Sprintf("foo%d", count), count%2 == 0}, dt) count++ } assert.NoError(t, rows.Err()) @@ -207,7 +207,7 @@ func TestDriver(t *testing.T) { t.Run("Multiple queries", func(t *testing.T) { rows, err := db.Query(` SELECT * FROM test;;; - INSERT INTO test (a, b, c) VALUES (10, [11, 12, 13], {foo: "bar"}); + INSERT INTO test (a, b, c) VALUES (10, 'foo10', true); SELECT * FROM test; `) assert.NoError(t, err) @@ -218,7 +218,7 @@ func TestDriver(t *testing.T) { for rows.Next() { err = rows.Scan(Scanner(&dt)) assert.NoError(t, err) - require.Equal(t, rowtest{count, []int{count + 1, count + 2, count + 3}, foo{Foo: "bar"}}, dt) + require.Equal(t, rowtest{count, fmt.Sprintf("foo%d", count), count%2 == 0}, dt) count++ } assert.NoError(t, rows.Err()) @@ -232,7 +232,7 @@ func TestDriver(t *testing.T) { rows, err := tx.Query(` SELECT * FROM test;;; - INSERT INTO test (a, b, c) VALUES (11, [12, 13, 14], {foo: "bar"}); + INSERT INTO test (a, b, c) VALUES (11, 'foo11', false); SELECT * FROM test; `) assert.NoError(t, err) @@ -243,7 +243,7 @@ func TestDriver(t *testing.T) { for rows.Next() { err = rows.Scan(Scanner(&dt)) assert.NoError(t, err) - require.Equal(t, rowtest{count, []int{count + 1, count + 2, count + 3}, foo{Foo: "bar"}}, dt) + require.Equal(t, rowtest{count, fmt.Sprintf("foo%d", count), count%2 == 0}, dt) count++ } assert.NoError(t, rows.Err()) @@ -270,7 +270,7 @@ func TestDriverWithTimeValues(t *testing.T) { defer db.Close() now := time.Now().UTC() - _, err = db.Exec("CREATE TABLE test; INSERT INTO test (a) VALUES (?)", now) + _, err = db.Exec("CREATE TABLE test(a TIMESTAMP); INSERT INTO test (a) VALUES (?)", now) assert.NoError(t, err) tx, err := db.BeginTx(context.Background(), &sql.TxOptions{ReadOnly: true}) diff --git a/driver/example_test.go b/driver/example_test.go index cfff4938d..5767d09d5 100644 --- a/driver/example_test.go +++ b/driver/example_test.go @@ -3,7 +3,6 @@ package driver_test import ( "database/sql" "fmt" - "log" "github.com/chaisql/chai/driver" ) @@ -17,33 +16,28 @@ type User struct { func Example() { db, err := sql.Open("chai", ":memory:") if err != nil { - log.Fatal(err) + panic(err) } defer db.Close() - _, err = db.Exec("CREATE TABLE IF NOT EXISTS user (name, ...)") + _, err = db.Exec("CREATE TABLE IF NOT EXISTS user (id INT, name TEXT, age INT)") if err != nil { - log.Fatal(err) + panic(err) } _, err = db.Exec("CREATE INDEX IF NOT EXISTS idx_user_name ON user (name)") if err != nil { - log.Fatal(err) + panic(err) } _, err = db.Exec("INSERT INTO user (id, name, age) VALUES 
(?, ?, ?)", 10, "foo", 15) if err != nil { - log.Fatal(err) + panic(err) } - _, err = db.Exec("INSERT INTO user VALUES ?, ?", &User{ID: 1, Name: "bar", Age: 100}, &User{ID: 2, Name: "baz"}) + rows, err := db.Query("SELECT * FROM user WHERE name = ?", "foo") if err != nil { - log.Fatal(err) - } - - rows, err := db.Query("SELECT * FROM user WHERE name = ?", "bar") - if err != nil { - log.Fatal(err) + panic(err) } defer rows.Close() @@ -51,15 +45,15 @@ func Example() { var u User err = rows.Scan(driver.Scanner(&u)) if err != nil { - log.Fatal(err) + panic(err) } fmt.Println(u) } err = rows.Err() if err != nil { - log.Fatal(err) + panic(err) } - // Output: {1 bar 100} + // Output: {10 foo 15} } diff --git a/example_test.go b/example_test.go index 1348a7004..716497d35 100644 --- a/example_test.go +++ b/example_test.go @@ -7,13 +7,9 @@ import ( ) type User struct { - ID int64 - Name string - Age uint32 - Address struct { - City string - ZipCode string - } + ID int64 + Name string + Age uint32 } func Example() { @@ -24,8 +20,8 @@ func Example() { } defer db.Close() - // Create a table. Chai tables are schemaless by default, you don't need to specify a schema. - err = db.Exec("CREATE TABLE user (name text, ...)") + // Create a table. + err = db.Exec("CREATE TABLE user (id int, name text, age int)") if err != nil { panic(err) } @@ -42,19 +38,7 @@ func Example() { panic(err) } - // Insert some data using object notation - err = db.Exec(`INSERT INTO user VALUES {id: 12, "name": "bar", age: ?, address: {city: "Lyon", zipcode: "69001"}}`, 16) - if err != nil { - panic(err) - } - - // Structs can be used to describe a object - err = db.Exec("INSERT INTO user VALUES ?, ?", &User{ID: 1, Name: "baz", Age: 100}, &User{ID: 2, Name: "bat"}) - if err != nil { - panic(err) - } - - // Query some objects + // Query some rows stream, err := db.Query("SELECT * FROM user WHERE id > ?", 1) if err != nil { panic(err) @@ -79,7 +63,5 @@ func Example() { } // Output: - // {10 foo 15 { }} - // {12 bar 16 {Lyon 69001}} - // {2 bat 0 { }} + // {10 foo 15} } diff --git a/go.mod b/go.mod index a61aa719e..771fdffb8 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,6 @@ module github.com/chaisql/chai go 1.21 require ( - github.com/buger/jsonparser v1.1.1 github.com/cockroachdb/errors v1.11.1 github.com/cockroachdb/pebble v1.0.0 github.com/golang-module/carbon/v2 v2.2.14 diff --git a/go.sum b/go.sum index 13a6ada6d..0913c0885 100644 --- a/go.sum +++ b/go.sum @@ -2,8 +2,6 @@ github.com/DataDog/zstd v1.5.5 h1:oWf5W7GtOLgp6bciQYDmhHHjdhYkALu6S/5Ni9ZgSvQ= github.com/DataDog/zstd v1.5.5/go.mod h1:g4AWEaM3yOg3HYfnJ3YIawPnVdXJh9QME85blwSAmyw= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= -github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs= -github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cockroachdb/datadriven v1.0.3-0.20230801171734-e384cf455877 h1:1MLK4YpFtIEo3ZtMA5C795Wtv5VuUnrXX7mQG+aHg6o= diff --git a/internal/database/catalog.go b/internal/database/catalog.go index 95beeee5d..97bfcd2d5 100644 --- a/internal/database/catalog.go +++ b/internal/database/catalog.go @@ -7,8 +7,8 @@ import ( "strings" errs "github.com/chaisql/chai/internal/errors" - 
"github.com/chaisql/chai/internal/object" "github.com/chaisql/chai/internal/pkg/atomic" + "github.com/chaisql/chai/internal/row" "github.com/chaisql/chai/internal/tree" "github.com/chaisql/chai/internal/types" "github.com/cockroachdb/errors" @@ -169,7 +169,7 @@ func NewCatalogWriter(c *Catalog) *CatalogWriter { } func (c *CatalogWriter) Init(tx *Transaction) error { - // ensure the catalog schema is store in the catalog table + // ensure the catalog schema is stored in the catalog table err := c.ensureTableExists(tx, c.Catalog.CatalogTable.info) if err != nil { return err @@ -309,9 +309,9 @@ func (c *CatalogWriter) CreateIndex(tx *Transaction, info *IndexInfo) (*IndexInf return nil, err } - // check if the indexed fields exist - for _, p := range info.Paths { - fc := ti.GetFieldConstraintForPath(p) + // check if the indexed columns exist + for _, p := range info.Columns { + fc := ti.GetColumnConstraint(p) if fc == nil { return nil, errors.Errorf("field %q does not exist for table %q", p, ti.TableName) } @@ -345,8 +345,8 @@ func (c *CatalogWriter) DropIndex(tx *Transaction, name string) error { } // check if the index has been created by a table constraint - if len(info.Owner.Paths) > 0 { - return fmt.Errorf("cannot drop index %s because constraint on %s(%s) requires it", info.IndexName, info.Owner.TableName, info.Owner.Paths) + if len(info.Owner.Columns) > 0 { + return fmt.Errorf("cannot drop index %s because constraint on %s(%s) requires it", info.IndexName, info.Owner.TableName, info.Owner.Columns) } _, err = c.Cache.Delete(tx, RelationIndexType, name) @@ -366,8 +366,8 @@ func (c *CatalogWriter) dropIndex(tx *Transaction, info *IndexInfo) error { return c.CatalogTable.Delete(tx, info.IndexName) } -// AddFieldConstraint adds a field constraint to a table. -func (c *CatalogWriter) AddFieldConstraint(tx *Transaction, tableName string, fc *FieldConstraint, tcs TableConstraints) error { +// AddColumnConstraint adds a field constraint to a table. 
+func (c *CatalogWriter) AddColumnConstraint(tx *Transaction, tableName string, cc *ColumnConstraint, tcs TableConstraints) error { r, err := c.Cache.Get(RelationTableType, tableName) if err != nil { return err @@ -375,8 +375,8 @@ func (c *CatalogWriter) AddFieldConstraint(tx *Transaction, tableName string, fc ti := r.(*TableInfoRelation).Info clone := ti.Clone() - if fc != nil { - err = clone.AddFieldConstraint(fc) + if cc != nil { + err = clone.AddColumnConstraint(cc) if err != nil { return err } @@ -580,7 +580,7 @@ func (r *IndexInfoRelation) SetName(name string) { } func (r *IndexInfoRelation) GenerateBaseName() string { - return fmt.Sprintf("%s_%s_idx", r.Info.Owner.TableName, pathsToIndexName(r.Info.Paths)) + return fmt.Sprintf("%s_%s_idx", r.Info.Owner.TableName, columnsToIndexName(r.Info.Columns)) } func (r *IndexInfoRelation) Clone() Relation { @@ -589,18 +589,8 @@ func (r *IndexInfoRelation) Clone() Relation { return &clone } -func pathsToIndexName(paths []object.Path) string { - var s strings.Builder - - for i, p := range paths { - if i > 0 { - s.WriteRune('_') - } - - s.WriteString(p.String()) - } - - return s.String() +func columnsToIndexName(columns []string) string { + return strings.Join(columns, "_") } type catalogCache struct { @@ -723,7 +713,7 @@ func (c *catalogCache) Replace(tx *Transaction, o Relation) error { old, ok := m[o.Name()] if !ok { - return errors.WithStack(errs.NotFoundError{Name: o.Name()}) + return errs.NewNotFoundError(o.Name()) } m[o.Name()] = o @@ -740,7 +730,7 @@ func (c *catalogCache) Delete(tx *Transaction, tp, name string) (Relation, error o, ok := m[name] if !ok { - return nil, errors.WithStack(errs.NotFoundError{Name: name}) + return nil, errs.NewNotFoundError(name) } delete(m, name) @@ -757,7 +747,7 @@ func (c *catalogCache) Get(tp, name string) (Relation, error) { o, ok := m[name] if !ok { - return nil, errors.WithStack(&errs.NotFoundError{Name: name}) + return nil, errs.NewNotFoundError(name) } return o, nil @@ -800,58 +790,48 @@ func newCatalogStore() *CatalogStore { { Name: CatalogTableName + "_pk", PrimaryKey: true, - Paths: []object.Path{ - object.NewPath("name"), + Columns: []string{ + "name", }, }, }, - FieldConstraints: MustNewFieldConstraints( - &FieldConstraint{ + ColumnConstraints: MustNewColumnConstraints( + &ColumnConstraint{ Position: 0, - Field: "name", + Column: "name", Type: types.TypeText, IsNotNull: true, }, - &FieldConstraint{ + &ColumnConstraint{ Position: 1, - Field: "type", + Column: "type", Type: types.TypeText, IsNotNull: true, }, - &FieldConstraint{ + &ColumnConstraint{ Position: 2, - Field: "namespace", - Type: types.TypeInteger, + Column: "namespace", + Type: types.TypeBigint, }, - &FieldConstraint{ + &ColumnConstraint{ Position: 3, - Field: "sql", + Column: "sql", Type: types.TypeText, }, - &FieldConstraint{ + &ColumnConstraint{ Position: 4, - Field: "rowid_sequence_name", + Column: "rowid_sequence_name", Type: types.TypeText, }, - &FieldConstraint{ + &ColumnConstraint{ Position: 5, - Field: "owner", - Type: types.TypeObject, - AnonymousType: &AnonymousType{ - FieldConstraints: MustNewFieldConstraints( - &FieldConstraint{ - Position: 0, - Field: "table_name", - Type: types.TypeText, - IsNotNull: true, - }, - &FieldConstraint{ - Position: 1, - Field: "paths", - Type: types.TypeArray, - }, - ), - }, + Column: "owner_table_name", + Type: types.TypeText, + }, + &ColumnConstraint{ + Position: 6, + Column: "owner_table_columns", + Type: types.TypeText, // TODO: change to array }, ), } @@ -878,7 +858,7 @@ func (s 
*CatalogStore) Table(tx *Transaction) *Table { func (s *CatalogStore) Insert(tx *Transaction, r Relation) error { tb := s.Table(tx) - _, _, err := tb.Insert(relationToObject(r)) + _, _, err := tb.Insert(relationToRow(r)) if cerr, ok := err.(*ConstraintViolationError); ok && cerr.Constraint == "PRIMARY KEY" { return errors.WithStack(errs.AlreadyExistsError{Name: r.Name()}) } @@ -891,7 +871,7 @@ func (s *CatalogStore) Replace(tx *Transaction, name string, r Relation) error { tb := s.Table(tx) key := tree.NewKey(types.NewTextValue(name)) - _, err := tb.Replace(key, relationToObject(r)) + _, err := tb.Replace(key, relationToRow(r)) return err } @@ -903,24 +883,24 @@ func (s *CatalogStore) Delete(tx *Transaction, name string) error { return tb.Delete(key) } -func relationToObject(r Relation) types.Object { +func relationToRow(r Relation) row.Row { switch t := r.(type) { case *TableInfoRelation: - return tableInfoToObject(t.Info) + return tableInfoToRow(t.Info) case *IndexInfoRelation: - return indexInfoToObject(t.Info) + return indexInfoToRow(t.Info) case *Sequence: - return sequenceInfoToObject(t.Info) + return sequenceInfoToRow(t.Info) } panic(fmt.Sprintf("relationToObject: unknown type %q", r.Type())) } -func tableInfoToObject(ti *TableInfo) types.Object { - buf := object.NewFieldBuffer() +func tableInfoToRow(ti *TableInfo) row.Row { + buf := row.NewColumnBuffer() buf.Add("name", types.NewTextValue(ti.TableName)) buf.Add("type", types.NewTextValue(RelationTableType)) - buf.Add("namespace", types.NewIntegerValue(int64(ti.StoreNamespace))) + buf.Add("namespace", types.NewBigintValue(int64(ti.StoreNamespace))) buf.Add("sql", types.NewTextValue(ti.String())) if ti.RowidSequenceName != "" { buf.Add("rowid_sequence_name", types.NewTextValue(ti.RowidSequenceName)) @@ -929,40 +909,33 @@ func tableInfoToObject(ti *TableInfo) types.Object { return buf } -func indexInfoToObject(i *IndexInfo) types.Object { - buf := object.NewFieldBuffer() +func indexInfoToRow(i *IndexInfo) row.Row { + buf := row.NewColumnBuffer() buf.Add("name", types.NewTextValue(i.IndexName)) buf.Add("type", types.NewTextValue(RelationIndexType)) - buf.Add("namespace", types.NewIntegerValue(int64(i.StoreNamespace))) + buf.Add("namespace", types.NewBigintValue(int64(i.StoreNamespace))) buf.Add("sql", types.NewTextValue(i.String())) if i.Owner.TableName != "" { - buf.Add("owner", types.NewObjectValue(ownerToObject(&i.Owner))) + buf.Add("owner_table_name", types.NewTextValue(i.Owner.TableName)) + if len(i.Owner.Columns) > 0 { + buf.Add("owner_table_columns", types.NewTextValue(strings.Join(i.Owner.Columns, ","))) + } } return buf } -func sequenceInfoToObject(seq *SequenceInfo) types.Object { - buf := object.NewFieldBuffer() +func sequenceInfoToRow(seq *SequenceInfo) row.Row { + buf := row.NewColumnBuffer() buf.Add("name", types.NewTextValue(seq.Name)) buf.Add("type", types.NewTextValue(RelationSequenceType)) buf.Add("sql", types.NewTextValue(seq.String())) if seq.Owner.TableName != "" { - buf.Add("owner", types.NewObjectValue(ownerToObject(&seq.Owner))) - } - - return buf -} - -func ownerToObject(owner *Owner) types.Object { - buf := object.NewFieldBuffer().Add("table_name", types.NewTextValue(owner.TableName)) - if owner.Paths != nil { - vb := object.NewValueBuffer() - for _, p := range owner.Paths { - vb.Append(types.NewTextValue(p.String())) + buf.Add("owner_table_name", types.NewTextValue(seq.Owner.TableName)) + if len(seq.Owner.Columns) > 0 { + buf.Add("owner_table_columns", types.NewTextValue(strings.Join(seq.Owner.Columns, ","))) 
} - buf.Add("paths", types.NewArrayValue(vb)) } return buf diff --git a/internal/database/catalog_test.go b/internal/database/catalog_test.go index 2518fb392..1df837773 100644 --- a/internal/database/catalog_test.go +++ b/internal/database/catalog_test.go @@ -9,7 +9,6 @@ import ( "github.com/chaisql/chai/internal/database" errs "github.com/chaisql/chai/internal/errors" "github.com/chaisql/chai/internal/expr" - "github.com/chaisql/chai/internal/object" "github.com/chaisql/chai/internal/testutil" "github.com/chaisql/chai/internal/testutil/assert" "github.com/chaisql/chai/internal/tree" @@ -41,7 +40,7 @@ func updateCatalog(t testing.TB, db *database.Database, fn func(tx *database.Tra // - GetTable // - DropTable // - RenameTable -// - AddFieldConstraint +// - AddColumnConstraint func TestCatalogTable(t *testing.T) { t.Run("Get", func(t *testing.T) { db := testutil.NewTestDB(t) @@ -58,7 +57,7 @@ func TestCatalogTable(t *testing.T) { // Getting a table that doesn't exist should fail. _, err = catalog.GetTable(tx, "unknown") if !errs.IsNotFoundError(err) { - assert.ErrorIs(t, err, errs.NotFoundError{Name: "unknown"}) + assert.ErrorIs(t, err, errs.NewNotFoundError("unknown")) } return nil @@ -81,13 +80,13 @@ func TestCatalogTable(t *testing.T) { // Getting a table that has been dropped should fail. _, err = catalog.GetTable(tx, "test") if !errs.IsNotFoundError(err) { - assert.ErrorIs(t, err, errs.NotFoundError{Name: "test"}) + assert.ErrorIs(t, err, errs.NewNotFoundError("test")) } // Dropping a table that doesn't exist should fail. err = catalog.DropTable(tx, "test") if !errs.IsNotFoundError(err) { - assert.ErrorIs(t, err, errs.NotFoundError{Name: "test"}) + assert.ErrorIs(t, err, errs.NewNotFoundError("test")) } return errDontCommit @@ -100,22 +99,22 @@ func TestCatalogTable(t *testing.T) { db := testutil.NewTestDB(t) ti := &database.TableInfo{ - FieldConstraints: database.MustNewFieldConstraints( - &database.FieldConstraint{Field: "name", Type: types.TypeText, IsNotNull: true}, - &database.FieldConstraint{Field: "age", Type: types.TypeInteger}, - &database.FieldConstraint{Field: "gender", Type: types.TypeText}, - &database.FieldConstraint{Field: "city", Type: types.TypeText}, + ColumnConstraints: database.MustNewColumnConstraints( + &database.ColumnConstraint{Column: "name", Type: types.TypeText, IsNotNull: true}, + &database.ColumnConstraint{Column: "age", Type: types.TypeInteger}, + &database.ColumnConstraint{Column: "gender", Type: types.TypeText}, + &database.ColumnConstraint{Column: "city", Type: types.TypeText}, ), TableConstraints: []*database.TableConstraint{ - {Paths: []object.Path{testutil.ParseObjectPath(t, "age")}, PrimaryKey: true}, + {Columns: []string{"age"}, PrimaryKey: true}, }} updateCatalog(t, db, func(tx *database.Transaction, catalog *database.CatalogWriter) error { err := catalog.CreateTable(tx, "foo", ti) assert.NoError(t, err) - _, err = catalog.CreateIndex(tx, &database.IndexInfo{Paths: []object.Path{testutil.ParseObjectPath(t, "gender")}, IndexName: "idx_gender", Owner: database.Owner{TableName: "foo"}}) + _, err = catalog.CreateIndex(tx, &database.IndexInfo{Columns: []string{"gender"}, IndexName: "idx_gender", Owner: database.Owner{TableName: "foo"}}) assert.NoError(t, err) - _, err = catalog.CreateIndex(tx, &database.IndexInfo{Paths: []object.Path{testutil.ParseObjectPath(t, "city")}, IndexName: "idx_city", Owner: database.Owner{TableName: "foo"}, Unique: true}) + _, err = catalog.CreateIndex(tx, &database.IndexInfo{Columns: []string{"city"}, IndexName: "idx_city", 
Owner: database.Owner{TableName: "foo"}, Unique: true}) assert.NoError(t, err) seq := database.SequenceInfo{ @@ -143,14 +142,14 @@ func TestCatalogTable(t *testing.T) { // Getting the old table should return an error. _, err = catalog.GetTable(tx, "foo") if !errs.IsNotFoundError(err) { - assert.ErrorIs(t, err, errs.NotFoundError{Name: "foo"}) + assert.ErrorIs(t, err, errs.NewNotFoundError("foo")) } tb, err := catalog.GetTable(tx, "zoo") assert.NoError(t, err) // The field constraints should be the same. - require.Equal(t, ti.FieldConstraints, tb.Info.FieldConstraints) + require.Equal(t, ti.ColumnConstraints, tb.Info.ColumnConstraints) // Check that the indexes have been updated as well. idxs := catalog.ListIndexes(tb.Info.TableName) @@ -169,7 +168,7 @@ func TestCatalogTable(t *testing.T) { // Renaming a non existing table should return an error err = catalog.RenameTable(tx, "foo", "") if !errs.IsNotFoundError(err) { - assert.ErrorIs(t, err, errs.NotFoundError{Name: "foo"}) + assert.ErrorIs(t, err, errs.NewNotFoundError("foo")) } return errDontCommit @@ -178,16 +177,16 @@ func TestCatalogTable(t *testing.T) { require.Equal(t, clone, db.Catalog()) }) - t.Run("Add field constraint", func(t *testing.T) { + t.Run("Add column constraint", func(t *testing.T) { db := testutil.NewTestDB(t) - ti := &database.TableInfo{FieldConstraints: database.MustNewFieldConstraints( - &database.FieldConstraint{Field: "name", Type: types.TypeText, IsNotNull: true}, - &database.FieldConstraint{Field: "age", Type: types.TypeInteger}, - &database.FieldConstraint{Field: "gender", Type: types.TypeText}, - &database.FieldConstraint{Field: "city", Type: types.TypeText}, + ti := &database.TableInfo{ColumnConstraints: database.MustNewColumnConstraints( + &database.ColumnConstraint{Column: "name", Type: types.TypeText, IsNotNull: true}, + &database.ColumnConstraint{Column: "age", Type: types.TypeInteger}, + &database.ColumnConstraint{Column: "gender", Type: types.TypeText}, + &database.ColumnConstraint{Column: "city", Type: types.TypeText}, ), TableConstraints: []*database.TableConstraint{ - {Paths: []object.Path{testutil.ParseObjectPath(t, "age")}, PrimaryKey: true}, + {Columns: []string{"age"}, PrimaryKey: true}, }} updateCatalog(t, db, func(tx *database.Transaction, catalog *database.CatalogWriter) error { @@ -199,37 +198,37 @@ func TestCatalogTable(t *testing.T) { updateCatalog(t, db, func(tx *database.Transaction, catalog *database.CatalogWriter) error { // Add field constraint - fieldToAdd := database.FieldConstraint{ - Field: "last_name", Type: types.TypeText, + fieldToAdd := database.ColumnConstraint{ + Column: "last_name", Type: types.TypeText, } // Add table constraint var tcs database.TableConstraints tcs = append(tcs, &database.TableConstraint{ Check: expr.Constraint(testutil.ParseExpr(t, "last_name > first_name")), }) - err := catalog.AddFieldConstraint(tx, "foo", &fieldToAdd, tcs) + err := catalog.AddColumnConstraint(tx, "foo", &fieldToAdd, tcs) assert.NoError(t, err) tb, err := catalog.GetTable(tx, "foo") assert.NoError(t, err) // The field constraints should not be the same. 
- require.Contains(t, tb.Info.FieldConstraints.Ordered, &fieldToAdd) + require.Contains(t, tb.Info.ColumnConstraints.Ordered, &fieldToAdd) require.Equal(t, expr.Constraint(testutil.ParseExpr(t, "last_name > first_name")), tb.Info.TableConstraints[1].Check) // Renaming a non existing table should return an error - err = catalog.AddFieldConstraint(tx, "bar", &fieldToAdd, nil) + err = catalog.AddColumnConstraint(tx, "bar", &fieldToAdd, nil) if !errs.IsNotFoundError(err) { - assert.ErrorIs(t, err, errs.NotFoundError{Name: "bar"}) + assert.ErrorIs(t, err, errs.NewNotFoundError("bar")) } // Adding a existing field should return an error - err = catalog.AddFieldConstraint(tx, "foo", ti.FieldConstraints.Ordered[0], nil) + err = catalog.AddColumnConstraint(tx, "foo", ti.ColumnConstraints.Ordered[0], nil) assert.Error(t, err) // Adding a second primary key should return an error - err = catalog.AddFieldConstraint(tx, "foo", nil, database.TableConstraints{ - {Paths: []object.Path{testutil.ParseObjectPath(t, "age")}, PrimaryKey: true}, + err = catalog.AddColumnConstraint(tx, "foo", nil, database.TableConstraints{ + {Columns: []string{"age"}, PrimaryKey: true}, }) assert.Error(t, err) @@ -283,11 +282,11 @@ func TestCatalogCreateIndex(t *testing.T) { updateCatalog(t, db, func(tx *database.Transaction, catalog *database.CatalogWriter) error { return catalog.CreateTable(tx, "test", &database.TableInfo{ - FieldConstraints: database.MustNewFieldConstraints( - &database.FieldConstraint{Field: "a", Type: types.TypeText}, + ColumnConstraints: database.MustNewColumnConstraints( + &database.ColumnConstraint{Column: "a", Type: types.TypeText}, ), TableConstraints: []*database.TableConstraint{ - {Paths: []object.Path{testutil.ParseObjectPath(t, "a")}, PrimaryKey: true}, + {Columns: []string{"a"}, PrimaryKey: true}, }, }) }) @@ -296,7 +295,7 @@ func TestCatalogCreateIndex(t *testing.T) { updateCatalog(t, db, func(tx *database.Transaction, catalog *database.CatalogWriter) error { _, err := catalog.CreateIndex(tx, &database.IndexInfo{ - IndexName: "idx_a", Owner: database.Owner{TableName: "test"}, Paths: []object.Path{testutil.ParseObjectPath(t, "a")}, + IndexName: "idx_a", Owner: database.Owner{TableName: "test"}, Columns: []string{"a"}, }) assert.NoError(t, err) idx, err := catalog.GetIndex(tx, "idx_a") @@ -314,20 +313,20 @@ func TestCatalogCreateIndex(t *testing.T) { updateCatalog(t, db, func(tx *database.Transaction, catalog *database.CatalogWriter) error { return catalog.CreateTable(tx, "test", &database.TableInfo{ - FieldConstraints: database.MustNewFieldConstraints( - &database.FieldConstraint{Field: "foo", Type: types.TypeText}, + ColumnConstraints: database.MustNewColumnConstraints( + &database.ColumnConstraint{Column: "foo", Type: types.TypeText}, ), }) }) updateCatalog(t, db, func(tx *database.Transaction, catalog *database.CatalogWriter) error { _, err := catalog.CreateIndex(tx, &database.IndexInfo{ - IndexName: "idxFoo", Owner: database.Owner{TableName: "test"}, Paths: []object.Path{testutil.ParseObjectPath(t, "foo")}, + IndexName: "idxFoo", Owner: database.Owner{TableName: "test"}, Columns: []string{"foo"}, }) assert.NoError(t, err) _, err = catalog.CreateIndex(tx, &database.IndexInfo{ - IndexName: "idxFoo", Owner: database.Owner{TableName: "test"}, Paths: []object.Path{testutil.ParseObjectPath(t, "foo")}, + IndexName: "idxFoo", Owner: database.Owner{TableName: "test"}, Columns: []string{"foo"}, }) assert.ErrorIs(t, err, errs.AlreadyExistsError{Name: "idxFoo"}) return nil @@ -338,10 +337,10 @@ func 
TestCatalogCreateIndex(t *testing.T) { db := testutil.NewTestDB(t) updateCatalog(t, db, func(tx *database.Transaction, catalog *database.CatalogWriter) error { _, err := catalog.CreateIndex(tx, &database.IndexInfo{ - IndexName: "idxFoo", Owner: database.Owner{TableName: "test"}, Paths: []object.Path{testutil.ParseObjectPath(t, "foo")}, + IndexName: "idxFoo", Owner: database.Owner{TableName: "test"}, Columns: []string{"foo"}, }) if !errs.IsNotFoundError(err) { - assert.ErrorIs(t, err, errs.NotFoundError{Name: "test"}) + assert.ErrorIs(t, err, errs.NewNotFoundError("test")) } return nil @@ -353,36 +352,28 @@ func TestCatalogCreateIndex(t *testing.T) { updateCatalog(t, db, func(tx *database.Transaction, catalog *database.CatalogWriter) error { return catalog.CreateTable(tx, "test", &database.TableInfo{ - FieldConstraints: database.MustNewFieldConstraints( - &database.FieldConstraint{Field: "foo", Type: types.TypeObject, AnonymousType: &database.AnonymousType{ - FieldConstraints: database.MustNewFieldConstraints( - &database.FieldConstraint{Field: " bar ", Type: types.TypeObject, AnonymousType: &database.AnonymousType{ - FieldConstraints: database.MustNewFieldConstraints( - &database.FieldConstraint{Field: "c", Type: types.TypeText}, - ), - }}, - ), - }}, + ColumnConstraints: database.MustNewColumnConstraints( + &database.ColumnConstraint{Column: "foo", Type: types.TypeInteger}, ), }) }) updateCatalog(t, db, func(tx *database.Transaction, catalog *database.CatalogWriter) error { _, err := catalog.CreateIndex(tx, &database.IndexInfo{ - Owner: database.Owner{TableName: "test"}, Paths: []object.Path{testutil.ParseObjectPath(t, "foo.` bar `.c")}, + Owner: database.Owner{TableName: "test"}, Columns: []string{"foo"}, }) assert.NoError(t, err) - _, err = catalog.GetIndex(tx, "test_foo. bar .c_idx") + _, err = catalog.GetIndex(tx, "test_foo_idx") assert.NoError(t, err) // create another one _, err = catalog.CreateIndex(tx, &database.IndexInfo{ - Owner: database.Owner{TableName: "test"}, Paths: []object.Path{testutil.ParseObjectPath(t, "foo.` bar `.c")}, + Owner: database.Owner{TableName: "test"}, Columns: []string{"foo"}, }) assert.NoError(t, err) - _, err = catalog.GetIndex(tx, "test_foo. 
bar .c_idx1") + _, err = catalog.GetIndex(tx, "test_foo_idx1") assert.NoError(t, err) return nil }) @@ -395,18 +386,18 @@ func TestTxDropIndex(t *testing.T) { updateCatalog(t, db, func(tx *database.Transaction, catalog *database.CatalogWriter) error { err := catalog.CreateTable(tx, "test", &database.TableInfo{ - FieldConstraints: database.MustNewFieldConstraints( - &database.FieldConstraint{Field: "foo", Type: types.TypeText}, - &database.FieldConstraint{Field: "bar", Type: types.TypeAny}, + ColumnConstraints: database.MustNewColumnConstraints( + &database.ColumnConstraint{Column: "foo", Type: types.TypeText}, + &database.ColumnConstraint{Column: "bar", Type: types.TypeBoolean}, ), }) assert.NoError(t, err) _, err = catalog.CreateIndex(tx, &database.IndexInfo{ - IndexName: "idxFoo", Owner: database.Owner{TableName: "test"}, Paths: []object.Path{testutil.ParseObjectPath(t, "foo")}, + IndexName: "idxFoo", Owner: database.Owner{TableName: "test"}, Columns: []string{"foo"}, }) assert.NoError(t, err) _, err = catalog.CreateIndex(tx, &database.IndexInfo{ - IndexName: "idxBar", Owner: database.Owner{TableName: "test"}, Paths: []object.Path{testutil.ParseObjectPath(t, "bar")}, + IndexName: "idxBar", Owner: database.Owner{TableName: "test"}, Columns: []string{"bar"}, }) assert.NoError(t, err) return nil @@ -438,7 +429,7 @@ func TestTxDropIndex(t *testing.T) { updateCatalog(t, db, func(tx *database.Transaction, catalog *database.CatalogWriter) error { err := catalog.DropIndex(tx, "idxFoo") - assert.ErrorIs(t, err, &errs.NotFoundError{Name: "idxFoo"}) + assert.ErrorIs(t, err, errs.NewNotFoundError("idxFoo")) return nil }) }) @@ -450,8 +441,8 @@ func TestReadOnlyTables(t *testing.T) { defer db.Close() res, err := db.Query(` - CREATE TABLE foo (a int, b (c double unique)); - CREATE INDEX idx_foo_a ON foo(a); + CREATE TABLE foo (a int, b double unique, c text); + CREATE INDEX idx_foo_a ON foo(a, c); SELECT * FROM __chai_catalog `) assert.NoError(t, err) @@ -461,19 +452,19 @@ func TestReadOnlyTables(t *testing.T) { err = res.Iterate(func(r *chai.Row) error { switch i { case 0: - testutil.RequireJSONEq(t, r, `{"name":"__chai_catalog", "namespace":1, "sql":"CREATE TABLE __chai_catalog (name TEXT NOT NULL, type TEXT NOT NULL, namespace INTEGER, sql TEXT, rowid_sequence_name TEXT, owner (table_name TEXT NOT NULL, paths ARRAY), CONSTRAINT __chai_catalog_pk PRIMARY KEY (name))", "type":"table"}`) + testutil.RequireJSONEq(t, r, `{"name":"__chai_catalog", "namespace":1, "owner_table_name": null, "owner_table_columns": null, "rowid_sequence_name": null, "sql":"CREATE TABLE __chai_catalog (name TEXT NOT NULL, type TEXT NOT NULL, namespace BIGINT, sql TEXT, rowid_sequence_name TEXT, owner_table_name TEXT, owner_table_columns TEXT, CONSTRAINT __chai_catalog_pk PRIMARY KEY (name))", "type":"table"}`) case 1: - testutil.RequireJSONEq(t, r, `{"name":"__chai_sequence", "sql":"CREATE TABLE __chai_sequence (name TEXT NOT NULL, seq INTEGER, CONSTRAINT __chai_sequence_pk PRIMARY KEY (name))", "namespace":2, "type":"table"}`) + testutil.RequireJSONEq(t, r, `{"name":"__chai_sequence", "namespace":2, "owner_table_name": null, "owner_table_columns":null, "rowid_sequence_name": null, "sql":"CREATE TABLE __chai_sequence (name TEXT NOT NULL, seq BIGINT, CONSTRAINT __chai_sequence_pk PRIMARY KEY (name))", "type":"table"}`) case 2: - testutil.RequireJSONEq(t, r, `{"name":"__chai_store_seq", "owner":{"table_name":"__chai_catalog"}, "sql":"CREATE SEQUENCE __chai_store_seq MAXVALUE 9223372036837998591 START WITH 10 CACHE 0", 
"type":"sequence"}`) + testutil.RequireJSONEq(t, r, `{"name":"__chai_store_seq", "namespace":null, "owner_table_name": "__chai_catalog", "owner_table_columns":null, "rowid_sequence_name": null, "sql":"CREATE SEQUENCE __chai_store_seq MAXVALUE 9223372036837998591 START WITH 10 CACHE 0", "type":"sequence"}`) case 3: - testutil.RequireJSONEq(t, r, `{"name":"foo", "rowid_sequence_name":"foo_seq", "sql":"CREATE TABLE foo (a INTEGER, b (c DOUBLE), CONSTRAINT \"foo_b.c_unique\" UNIQUE (b.c))", "namespace":10, "type":"table"}`) + testutil.RequireJSONEq(t, r, `{"name":"foo", "namespace":10, "owner_table_name": null, "owner_table_columns":null, "rowid_sequence_name":"foo_seq", "sql":"CREATE TABLE foo (a INTEGER, b DOUBLE, c TEXT, CONSTRAINT foo_b_unique UNIQUE (b))", "namespace":10, "type":"table"}`) case 4: - testutil.RequireJSONEq(t, r, `{"name":"foo_b.c_idx", "owner":{"table_name":"foo", "paths":["b.c"]}, "sql":"CREATE UNIQUE INDEX `+"`foo_b.c_idx`"+` ON foo (b.c)", "namespace":11, "type":"index"}`) + testutil.RequireJSONEq(t, r, `{"name":"foo_b_idx", "namespace":11, "owner_table_name":"foo", "owner_table_columns": "b", "rowid_sequence_name": null, "sql":"CREATE UNIQUE INDEX foo_b_idx ON foo (b)", "type":"index"}`) case 5: - testutil.RequireJSONEq(t, r, `{"name":"foo_seq", "owner":{"table_name":"foo"}, "sql":"CREATE SEQUENCE foo_seq CACHE 64", "type":"sequence"}`) + testutil.RequireJSONEq(t, r, `{"name":"foo_seq", "namespace":null, "owner_table_name":"foo", "owner_table_columns":null, "rowid_sequence_name": null, "sql":"CREATE SEQUENCE foo_seq CACHE 64", "type":"sequence"}`) case 6: - testutil.RequireJSONEq(t, r, `{"name":"idx_foo_a", "sql":"CREATE INDEX idx_foo_a ON foo (a)", "namespace":12, "type":"index", "owner": {"table_name": "foo"}}`) + testutil.RequireJSONEq(t, r, `{"name":"idx_foo_a", "namespace":12, "owner_table_name":"foo", "owner_table_columns":null, "rowid_sequence_name": null, "sql":"CREATE INDEX idx_foo_a ON foo (a, c)", "type":"index", "owner_table_name":"foo"}`) default: t.Fatalf("count should be 6, got %d", i) } diff --git a/internal/database/catalogstore/store.go b/internal/database/catalogstore/store.go index bdbd78033..d978b31b1 100644 --- a/internal/database/catalogstore/store.go +++ b/internal/database/catalogstore/store.go @@ -62,7 +62,7 @@ func loadSequences(tx *database.Transaction, info []database.SequenceInfo) ([]da } v, err := r.Get("seq") - if err != nil && !errors.Is(err, types.ErrFieldNotFound) { + if err != nil && !errors.Is(err, types.ErrColumnNotFound) { return nil, err } @@ -139,7 +139,7 @@ func tableInfoFromRow(r database.Row) (*database.TableInfo, error) { ti.StoreNamespace = tree.Namespace(storeNamespace) v, err = r.Get("rowid_sequence_name") - if err != nil && !errors.Is(err, types.ErrFieldNotFound) { + if err != nil && !errors.Is(err, types.ErrColumnNotFound) { return nil, err } if err == nil && v.Type() != types.TypeNull { @@ -176,15 +176,11 @@ func indexInfoFromRow(r database.Row) (*database.IndexInfo, error) { i.StoreNamespace = tree.Namespace(storeNamespace) - v, err = r.Get("owner") - if err != nil && !errors.Is(err, types.ErrFieldNotFound) { + owner, err := ownerFromRow(r) + if err != nil { return nil, err } - if err == nil && v.Type() != types.TypeNull { - owner, err := ownerFromObject(types.AsObject(v)) - if err != nil { - return nil, err - } + if owner != nil { i.Owner = *owner } @@ -204,48 +200,37 @@ func sequenceInfoFromRow(r database.Row) (*database.SequenceInfo, error) { i := stmt.(*statement.CreateSequenceStmt).Info - v, err := r.Get("owner") 
- if err != nil && !errors.Is(err, types.ErrFieldNotFound) { - return nil, errors.Wrap(err, "failed to get owner field") + owner, err := ownerFromRow(r) + if err != nil { + return nil, err } - if err == nil && v.Type() != types.TypeNull { - owner, err := ownerFromObject(types.AsObject(v)) - if err != nil { - return nil, errors.Wrap(err, "failed to get owner") - } + if owner != nil { i.Owner = *owner } return &i, nil } -func ownerFromObject(o types.Object) (*database.Owner, error) { +func ownerFromRow(r database.Row) (*database.Owner, error) { var owner database.Owner - v, err := o.GetByField("table_name") - if err != nil { + v, err := r.Get("owner_table_name") + if err != nil && !errors.Is(err, types.ErrColumnNotFound) { return nil, err } + if err != nil || v.Type() == types.TypeNull { + return nil, nil + } owner.TableName = types.AsString(v) - v, err = o.GetByField("paths") - if err != nil && !errors.Is(err, types.ErrFieldNotFound) { + v, err = r.Get("owner_table_columns") + if err != nil && !errors.Is(err, types.ErrColumnNotFound) { return nil, err } if err == nil && v.Type() != types.TypeNull { - err = types.AsArray(v).Iterate(func(i int, value types.Value) error { - pp, err := parser.ParsePath(types.AsString(value)) - if err != nil { - return err - } - - owner.Paths = append(owner.Paths, pp) - return nil - }) - if err != nil { - return nil, err - } + cols := types.AsString(v) + owner.Columns = strings.Split(cols, ",") } return &owner, nil diff --git a/internal/database/constraint.go b/internal/database/constraint.go index 48cdcba9b..bac70b9ea 100644 --- a/internal/database/constraint.go +++ b/internal/database/constraint.go @@ -4,41 +4,32 @@ import ( "fmt" "strings" - "github.com/chaisql/chai/internal/encoding" - "github.com/chaisql/chai/internal/object" + "github.com/chaisql/chai/internal/row" "github.com/chaisql/chai/internal/stringutil" "github.com/chaisql/chai/internal/tree" "github.com/chaisql/chai/internal/types" "github.com/cockroachdb/errors" ) -// FieldConstraint describes constraints on a particular field. -type FieldConstraint struct { - Position int - Field string - Type types.Type - IsNotNull bool - DefaultValue TableExpression - AnonymousType *AnonymousType +// ColumnConstraint describes constraints on a particular column. +type ColumnConstraint struct { + Position int + Column string + Type types.Type + IsNotNull bool + DefaultValue TableExpression } -func (f *FieldConstraint) IsEmpty() bool { - return f.Field == "" && f.Type.IsAny() && !f.IsNotNull && f.DefaultValue == nil +func (f *ColumnConstraint) IsEmpty() bool { + return f.Column == "" && f.Type.IsAny() && !f.IsNotNull && f.DefaultValue == nil } -func (f *FieldConstraint) String() string { +func (f *ColumnConstraint) String() string { var s strings.Builder - s.WriteString(f.Field) - if f.Type != types.TypeObject { - s.WriteString(" ") - s.WriteString(strings.ToUpper(f.Type.String())) - } else if f.AnonymousType != nil { - s.WriteString(" ") - s.WriteString(f.AnonymousType.String()) - } else { - s.WriteString(" OBJECT (...)") - } + s.WriteString(f.Column) + s.WriteString(" ") + s.WriteString(strings.ToUpper(f.Type.String())) if f.IsNotNull { s.WriteString(" NOT NULL") @@ -52,172 +43,77 @@ func (f *FieldConstraint) String() string { return s.String() } -// FieldConstraints is a list of field constraints. -type FieldConstraints struct { - Ordered []*FieldConstraint - ByField map[string]*FieldConstraint - AllowExtraFields bool +// ColumnConstraints is a list of column constraints. 
+type ColumnConstraints struct { + Ordered []*ColumnConstraint + ByColumn map[string]*ColumnConstraint } -func NewFieldConstraints(constraints ...*FieldConstraint) (FieldConstraints, error) { - var fc FieldConstraints +func NewColumnConstraints(constraints ...*ColumnConstraint) (ColumnConstraints, error) { + var fc ColumnConstraints for _, c := range constraints { if err := fc.Add(c); err != nil { - return FieldConstraints{}, err + return ColumnConstraints{}, err } } return fc, nil } -func MustNewFieldConstraints(constraints ...*FieldConstraint) FieldConstraints { - fc, err := NewFieldConstraints(constraints...) +func MustNewColumnConstraints(constraints ...*ColumnConstraint) ColumnConstraints { + fc, err := NewColumnConstraints(constraints...) if err != nil { panic(err) } return fc } -// Add a field constraint to the list. If another constraint exists for the same path +// Add a column constraint to the list. If another constraint exists for the same column // and they are equal, an error is returned. -func (f *FieldConstraints) Add(newFc *FieldConstraint) error { - if f.ByField == nil { - f.ByField = make(map[string]*FieldConstraint) +func (f *ColumnConstraints) Add(newCc *ColumnConstraint) error { + if f.ByColumn == nil { + f.ByColumn = make(map[string]*ColumnConstraint) } - if c, ok := f.ByField[newFc.Field]; ok { - return fmt.Errorf("conflicting constraints: %q and %q: %#v", c.String(), newFc.String(), f.ByField) + if c, ok := f.ByColumn[newCc.Column]; ok { + return fmt.Errorf("conflicting constraints: %q and %q: %#v", c.String(), newCc.String(), f.ByColumn) } // ensure default value type is compatible - if newFc.DefaultValue != nil && !newFc.Type.IsAny() { + if newCc.DefaultValue != nil { // first, try to evaluate the default value - v, err := newFc.DefaultValue.Eval(nil, nil) + v, err := newCc.DefaultValue.Eval(nil, nil) // if there is no error, check if the default value can be converted to the type of the constraint if err == nil { - _, err = object.CastAs(v, newFc.Type) + _, err = v.CastAs(newCc.Type) if err != nil { - return fmt.Errorf("default value %q cannot be converted to type %q", newFc.DefaultValue, newFc.Type) + return fmt.Errorf("default value %q cannot be converted to type %q", newCc.DefaultValue, newCc.Type) } } else { // if there is an error, we know we are using a function that returns an integer (NEXT VALUE FOR) // which is the only one compatible for the moment. // Integers can be converted to other integers, doubles, texts and bools. - switch newFc.Type { - case types.TypeInteger, types.TypeDouble, types.TypeText, types.TypeBoolean: + // TODO: rework + switch newCc.Type { + case types.TypeInteger, types.TypeBigint, types.TypeDouble, types.TypeText: default: - return fmt.Errorf("default value %q cannot be converted to type %q", newFc.DefaultValue, newFc.Type) + return fmt.Errorf("default value %q cannot be converted to type %q", newCc.DefaultValue, newCc.Type) } } } - newFc.Position = len(f.Ordered) - f.Ordered = append(f.Ordered, newFc) - f.ByField[newFc.Field] = newFc + newCc.Position = len(f.Ordered) + f.Ordered = append(f.Ordered, newCc) + f.ByColumn[newCc.Column] = newCc return nil } -// ConversionFunc is called when the type of a value is different than the expected type -// and the value needs to be converted. -type ConversionFunc func(v types.Value, path object.Path, targetType types.Type) (types.Value, error) - -// CastConversion is a ConversionFunc that casts the value to the target type.
-func CastConversion(v types.Value, path object.Path, targetType types.Type) (types.Value, error) { - return object.CastAs(v, targetType) -} - -// ConvertValueAtPath converts the value using the field constraints that are applicable -// at the given path. -func (f FieldConstraints) ConvertValueAtPath(path object.Path, v types.Value, conversionFn ConversionFunc) (types.Value, error) { - switch v.Type() { - case types.TypeArray: - vb, err := f.convertArrayAtPath(path, types.AsArray(v), conversionFn) - return types.NewArrayValue(vb), err - case types.TypeObject: - fb, err := f.convertObjectAtPath(path, types.AsObject(v), conversionFn) - return types.NewObjectValue(fb), err - } - return f.convertScalarAtPath(path, v, conversionFn) -} - -// convert the value using field constraints type information. -// if there is a type constraint on a path, apply it. -// if a value is an integer and has no constraint, convert it to double. -// if a value is a timestamp and has no constraint, convert it to text. -func (f FieldConstraints) convertScalarAtPath(path object.Path, v types.Value, conversionFn ConversionFunc) (types.Value, error) { - fc := f.GetFieldConstraintForPath(path) - if fc != nil { - // check if the constraint enforces a particular type - // and if so convert the value to the new type. - if fc.Type != 0 { - newV, err := conversionFn(v, path, fc.Type) - if err != nil { - return v, err - } - - return newV, nil - } - } - - // no constraint have been found for this path. - // convert the value to the type that is stored in the index. - return encoding.ConvertAsIndexType(v, types.TypeAny) -} - -func (f FieldConstraints) GetFieldConstraintForPath(path object.Path) *FieldConstraint { - cur := f - for i := range path { - fc, ok := cur.ByField[path[i].FieldName] - if !ok { - break - } - - if i == len(path)-1 { - return fc - } - - if fc.AnonymousType == nil { - return nil - } - - cur = fc.AnonymousType.FieldConstraints - } - - return nil -} - -func (f FieldConstraints) convertObjectAtPath(path object.Path, d types.Object, conversionFn ConversionFunc) (*object.FieldBuffer, error) { - fb, ok := d.(*object.FieldBuffer) - if !ok { - fb = object.NewFieldBuffer() - err := fb.Copy(d) - if err != nil { - return nil, err - } - } - - err := fb.Apply(func(p object.Path, v types.Value) (types.Value, error) { - return f.convertScalarAtPath(append(path, p...), v, conversionFn) - }) - - return fb, err -} - -func (f FieldConstraints) convertArrayAtPath(path object.Path, a types.Array, conversionFn ConversionFunc) (*object.ValueBuffer, error) { - vb := object.NewValueBuffer() - err := vb.Copy(a) - if err != nil { - return nil, err - } - - err = vb.Apply(func(p object.Path, v types.Value) (types.Value, error) { - return f.convertScalarAtPath(append(path, p...), v, conversionFn) - }) - - return vb, err +func (f ColumnConstraints) GetColumnConstraint(column string) *ColumnConstraint { + return f.ByColumn[column] } type TableExpression interface { - Eval(tx *Transaction, o types.Object) (types.Value, error) + Eval(tx *Transaction, o row.Row) (types.Value, error) + Validate(info *TableInfo) error String() string } @@ -225,7 +121,7 @@ type TableExpression interface { // and not necessarily to a single field path. 
type TableConstraint struct { Name string - Paths object.Paths + Columns []string Check TableExpression Unique bool PrimaryKey bool @@ -245,11 +141,11 @@ func (t *TableConstraint) String() string { sb.WriteString(")") case t.PrimaryKey: sb.WriteString(" PRIMARY KEY (") - for i, pt := range t.Paths { + for i, c := range t.Columns { if i > 0 { sb.WriteString(", ") } - sb.WriteString(pt.String()) + sb.WriteString(c) if t.SortOrder.IsDesc(i) { sb.WriteString(" DESC") @@ -258,11 +154,11 @@ func (t *TableConstraint) String() string { sb.WriteString(")") case t.Unique: sb.WriteString(" UNIQUE (") - for i, pt := range t.Paths { + for i, c := range t.Columns { if i > 0 { sb.WriteString(", ") } - sb.WriteString(pt.String()) + sb.WriteString(c) if t.SortOrder.IsDesc(i) { sb.WriteString(" DESC") @@ -278,13 +174,13 @@ func (t *TableConstraint) String() string { type TableConstraints []*TableConstraint // ValidateRow checks all the table constraint for the given row. -func (t *TableConstraints) ValidateRow(tx *Transaction, r Row) error { +func (t *TableConstraints) ValidateRow(tx *Transaction, r row.Row) error { for _, tc := range *t { if tc.Check == nil { continue } - v, err := tc.Check.Eval(tx, r.Object()) + v, err := tc.Check.Eval(tx, r) if err != nil { return err } @@ -292,7 +188,7 @@ func (t *TableConstraints) ValidateRow(tx *Transaction, r Row) error { switch v.Type() { case types.TypeBoolean: ok = types.AsBool(v) - case types.TypeInteger: + case types.TypeInteger, types.TypeBigint: ok = types.AsInt64(v) != 0 case types.TypeDouble: ok = types.AsFloat64(v) != 0 @@ -308,53 +204,14 @@ func (t *TableConstraints) ValidateRow(tx *Transaction, r Row) error { return nil } -type AnonymousType struct { - FieldConstraints FieldConstraints -} - -func (an *AnonymousType) AddFieldConstraint(newFc *FieldConstraint) error { - if an.FieldConstraints.ByField == nil { - an.FieldConstraints.ByField = make(map[string]*FieldConstraint) - } - - return an.FieldConstraints.Add(newFc) -} - -func (an *AnonymousType) String() string { - var sb strings.Builder - - sb.WriteString("(") - - hasConstraints := false - for i, fc := range an.FieldConstraints.Ordered { - if i > 0 { - sb.WriteString(", ") - } - - sb.WriteString(fc.String()) - hasConstraints = true - } - - if an.FieldConstraints.AllowExtraFields { - if hasConstraints { - sb.WriteString(", ") - } - sb.WriteString("...") - } - - sb.WriteString(")") - - return sb.String() -} - type ConstraintViolationError struct { Constraint string - Paths []object.Path + Columns []string Key *tree.Key } func (c ConstraintViolationError) Error() string { - return fmt.Sprintf("%s constraint error: %s", c.Constraint, c.Paths) + return fmt.Sprintf("%s constraint error: %s", c.Constraint, c.Columns) } func IsConstraintViolationError(err error) bool { diff --git a/internal/database/constraint_test.go b/internal/database/constraint_test.go index 549de47d0..c8943edd0 100644 --- a/internal/database/constraint_test.go +++ b/internal/database/constraint_test.go @@ -1,101 +1,80 @@ package database_test import ( - "fmt" "testing" "github.com/chaisql/chai/internal/database" "github.com/chaisql/chai/internal/expr" - "github.com/chaisql/chai/internal/object" "github.com/chaisql/chai/internal/testutil" "github.com/chaisql/chai/internal/testutil/assert" "github.com/chaisql/chai/internal/types" "github.com/stretchr/testify/require" ) -func TestFieldConstraintsAdd(t *testing.T) { +func TestColumnConstraintsAdd(t *testing.T) { tests := []struct { name string - got []*database.FieldConstraint - add 
database.FieldConstraint - want []*database.FieldConstraint + got []*database.ColumnConstraint + add database.ColumnConstraint + want []*database.ColumnConstraint fails bool }{ { "Same path", - []*database.FieldConstraint{{Field: "a", Type: types.TypeInteger}}, - database.FieldConstraint{Field: "a", Type: types.TypeInteger}, + []*database.ColumnConstraint{{Column: "a", Type: types.TypeInteger}}, + database.ColumnConstraint{Column: "a", Type: types.TypeInteger}, nil, true, }, { "Different path", - []*database.FieldConstraint{{Field: "a", Type: types.TypeInteger}}, - database.FieldConstraint{Field: "b", Type: types.TypeInteger}, - []*database.FieldConstraint{ - {Position: 0, Field: "a", Type: types.TypeInteger}, - {Position: 1, Field: "b", Type: types.TypeInteger}, + []*database.ColumnConstraint{{Column: "a", Type: types.TypeInteger}}, + database.ColumnConstraint{Column: "b", Type: types.TypeInteger}, + []*database.ColumnConstraint{ + {Position: 0, Column: "a", Type: types.TypeInteger}, + {Position: 1, Column: "b", Type: types.TypeInteger}, }, false, }, { "Default value conversion, typed constraint", - []*database.FieldConstraint{{Field: "a", Type: types.TypeInteger}}, - database.FieldConstraint{Field: "b", Type: types.TypeInteger, DefaultValue: expr.Constraint(testutil.DoubleValue(5))}, - []*database.FieldConstraint{ - {Position: 0, Field: "a", Type: types.TypeInteger}, - {Position: 1, Field: "b", Type: types.TypeInteger, DefaultValue: expr.Constraint(testutil.DoubleValue(5))}, + []*database.ColumnConstraint{{Column: "a", Type: types.TypeInteger}}, + database.ColumnConstraint{Column: "b", Type: types.TypeInteger, DefaultValue: expr.Constraint(testutil.DoubleValue(5))}, + []*database.ColumnConstraint{ + {Position: 0, Column: "a", Type: types.TypeInteger}, + {Position: 1, Column: "b", Type: types.TypeInteger, DefaultValue: expr.Constraint(testutil.DoubleValue(5))}, }, false, }, { "Default value conversion, typed constraint, NEXT VALUE FOR", - []*database.FieldConstraint{{Field: "a", Type: types.TypeInteger}}, - database.FieldConstraint{Field: "b", Type: types.TypeInteger, DefaultValue: expr.Constraint(expr.NextValueFor{SeqName: "seq"})}, - []*database.FieldConstraint{ - {Position: 0, Field: "a", Type: types.TypeInteger}, - {Position: 1, Field: "b", Type: types.TypeInteger, DefaultValue: expr.Constraint(expr.NextValueFor{SeqName: "seq"})}, + []*database.ColumnConstraint{{Column: "a", Type: types.TypeInteger}}, + database.ColumnConstraint{Column: "b", Type: types.TypeInteger, DefaultValue: expr.Constraint(expr.NextValueFor{SeqName: "seq"})}, + []*database.ColumnConstraint{ + {Position: 0, Column: "a", Type: types.TypeInteger}, + {Position: 1, Column: "b", Type: types.TypeInteger, DefaultValue: expr.Constraint(expr.NextValueFor{SeqName: "seq"})}, }, false, }, { "Default value conversion, typed constraint, NEXT VALUE FOR with blob", - []*database.FieldConstraint{{Field: "a", Type: types.TypeInteger}}, - database.FieldConstraint{Field: "b", Type: types.TypeBlob, DefaultValue: expr.Constraint(expr.NextValueFor{SeqName: "seq"})}, + []*database.ColumnConstraint{{Column: "a", Type: types.TypeInteger}}, + database.ColumnConstraint{Column: "b", Type: types.TypeBlob, DefaultValue: expr.Constraint(expr.NextValueFor{SeqName: "seq"})}, nil, true, }, { "Default value conversion, typed constraint, incompatible value", - []*database.FieldConstraint{{Field: "a", Type: types.TypeInteger}}, - database.FieldConstraint{Field: "b", Type: types.TypeDouble, DefaultValue: expr.Constraint(testutil.BoolValue(true))}, + 
[]*database.ColumnConstraint{{Column: "a", Type: types.TypeInteger}}, + database.ColumnConstraint{Column: "b", Type: types.TypeDouble, DefaultValue: expr.Constraint(testutil.BoolValue(true))}, nil, true, }, - { - "Default value conversion, untyped constraint", - []*database.FieldConstraint{{Field: "a", Type: types.TypeInteger}}, - database.FieldConstraint{Field: "b", DefaultValue: expr.Constraint(testutil.IntegerValue(5))}, - []*database.FieldConstraint{ - {Position: 0, Field: "a", Type: types.TypeInteger}, - {Position: 1, Field: "b", DefaultValue: expr.Constraint(testutil.IntegerValue(5))}, - }, - false, - }, - { - "Default value on nested object column", - nil, - database.FieldConstraint{Field: "a.b", DefaultValue: expr.Constraint(testutil.IntegerValue(5))}, - []*database.FieldConstraint{ - {Field: "a.b", DefaultValue: expr.Constraint(testutil.IntegerValue(5))}, - }, - false, - }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - fcs := database.MustNewFieldConstraints(test.got...) + fcs := database.MustNewColumnConstraints(test.got...) err := fcs.Add(&test.add) if test.fails { assert.Error(t, err) @@ -106,81 +85,3 @@ func TestFieldConstraintsAdd(t *testing.T) { }) } } - -func TestFieldConstraintsConvert(t *testing.T) { - tests := []struct { - constraints []*database.FieldConstraint - path object.Path - in, want types.Value - fails bool - }{ - { - nil, - object.NewPath("a"), - types.NewIntegerValue(10), - types.NewDoubleValue(10), - false, - }, - { - []*database.FieldConstraint{{Field: "a", Type: types.TypeInteger}}, - object.NewPath("a"), - types.NewIntegerValue(10), - types.NewIntegerValue(10), - false, - }, - { - []*database.FieldConstraint{{Field: "a", Type: types.TypeInteger}}, - object.NewPath("a"), - types.NewDoubleValue(10.5), - types.NewIntegerValue(10), - false, - }, - { - []*database.FieldConstraint{{Field: "a", Type: types.TypeArray}}, - object.NewPath("a"), - types.NewArrayValue(testutil.MakeArray(t, `[10.5, 10.5]`)), - types.NewArrayValue(testutil.MakeArray(t, `[10.5, 10.5]`)), - false, - }, - { - []*database.FieldConstraint{{ - Field: "a", - Type: types.TypeObject, - AnonymousType: &database.AnonymousType{ - FieldConstraints: database.MustNewFieldConstraints(&database.FieldConstraint{ - Field: "b", - Type: types.TypeInteger, - })}}}, - object.NewPath("a"), - types.NewObjectValue(testutil.MakeObject(t, `{"b": 10.5, "c": 10.5}`)), - types.NewObjectValue(testutil.MakeObject(t, `{"b": 10, "c": 10.5}`)), - false, - }, - { - []*database.FieldConstraint{{Field: "a", Type: types.TypeInteger}}, - object.NewPath("a"), - types.NewTextValue("foo"), - types.NewTextValue("foo"), - true, - }, - { - []*database.FieldConstraint{{Field: "a", DefaultValue: expr.Constraint(testutil.IntegerValue(10))}}, - object.NewPath("a"), - types.NewTextValue("foo"), - types.NewTextValue("foo"), - false, - }, - } - - for _, test := range tests { - t.Run(fmt.Sprintf("%s / %v to %v", test.path, test.in, test.want), func(t *testing.T) { - got, err := database.MustNewFieldConstraints(test.constraints...).ConvertValueAtPath(test.path, test.in, database.CastConversion) - if test.fails { - assert.Error(t, err) - } else { - assert.NoError(t, err) - require.Equal(t, test.want, got) - } - }) - } -} diff --git a/internal/database/encoding.go b/internal/database/encoding.go index ed7864dce..46bcb31a3 100644 --- a/internal/database/encoding.go +++ b/internal/database/encoding.go @@ -1,89 +1,59 @@ package database import ( - "encoding/binary" - "github.com/chaisql/chai/internal/encoding" - 
"github.com/chaisql/chai/internal/object" + "github.com/chaisql/chai/internal/row" "github.com/chaisql/chai/internal/types" "github.com/cockroachdb/errors" ) -// EncodeObject validates a row against all the constraints of the table +// EncodeRow validates a row against all the constraints of the table // and encodes it. -func (t *TableInfo) EncodeObject(tx *Transaction, dst []byte, o types.Object) ([]byte, error) { - if ed, ok := o.(*encoding.EncodedObject); ok { - return ed.Encoded, nil +func (t *TableInfo) EncodeRow(tx *Transaction, dst []byte, r row.Row) ([]byte, error) { + if ed, ok := RowIsEncoded(r, &t.ColumnConstraints); ok { + return ed.encoded, nil } - return encodeObject(tx, dst, &t.FieldConstraints, o) + return encodeRow(tx, dst, &t.ColumnConstraints, r) } -func encodeObject(tx *Transaction, dst []byte, fcs *FieldConstraints, o types.Object) ([]byte, error) { - var err error - - // loop over all the defined field contraints in order. - for _, fc := range fcs.Ordered { +func encodeRow(tx *Transaction, dst []byte, ccs *ColumnConstraints, r row.Row) ([]byte, error) { + // loop over all the defined column contraints in order. + for _, cc := range ccs.Ordered { // get the column from the row - v, err := o.GetByField(fc.Field) - if err != nil && !errors.Is(err, types.ErrFieldNotFound) { + v, err := r.Get(cc.Column) + if err != nil && !errors.Is(err, types.ErrColumnNotFound) { return nil, err } - // if the field is not found OR NULL, and the field has a default value, use the default value, + // if the column is not found OR NULL, and the column has a default value, use the default value, // otherwise return an error if v == nil { - if fc.DefaultValue != nil { - v, err = fc.DefaultValue.Eval(tx, o) + if cc.DefaultValue != nil { + v, err = cc.DefaultValue.Eval(tx, r) if err != nil { return nil, err } } } - // if the field is not found OR NULL, and the field is required, return an error - if fc.IsNotNull && (v == nil || v.Type() == types.TypeNull) { - return nil, &ConstraintViolationError{Constraint: "NOT NULL", Paths: []object.Path{object.NewPath(fc.Field)}} - } - if v == nil { v = types.NewNullValue() } - // ensure the value is of the correct type - if fc.Type != types.TypeAny { - v, err = object.CastAs(v, fc.Type) - if err != nil { - return nil, err - } - } else { - v, err = encoding.ConvertAsStoreType(v) - if err != nil { - return nil, err - } + // if the column is not found OR NULL, and the column is required, return an error + if cc.IsNotNull && v.Type() == types.TypeNull { + return nil, &ConstraintViolationError{Constraint: "NOT NULL", Columns: []string{cc.Column}} } - // Encode the value only. - if v.Type() == types.TypeObject { - // encode map length - mlen := len(fc.AnonymousType.FieldConstraints.Ordered) - if fc.AnonymousType.FieldConstraints.AllowExtraFields { - mlen += 1 - } - dst = encoding.EncodeArrayLength(dst, mlen) - dst, err = encodeObject(tx, dst, &fc.AnonymousType.FieldConstraints, types.AsObject(v)) - } else { - dst, err = encoding.EncodeValue(dst, v, false) - } + // ensure the value is of the correct type + v, err = v.CastAs(cc.Type) if err != nil { return nil, err } - } - // encode the extra fields, if any. 
- if fcs.AllowExtraFields { - dst, err = encodeExtraFields(dst, fcs, o) + dst, err = v.Encode(dst) if err != nil { return nil, err } @@ -92,154 +62,61 @@ func encodeObject(tx *Transaction, dst []byte, fcs *FieldConstraints, o types.Ob return dst, nil } -func encodeExtraFields(dst []byte, fcs *FieldConstraints, d types.Object) ([]byte, error) { - // count the number of extra fields - extraFields := 0 - err := d.Iterate(func(field string, value types.Value) error { - _, ok := fcs.ByField[field] - if ok { - return nil - } - extraFields++ - return nil - }) - if err != nil { - return nil, err - } - - // encode row length - dst = encoding.EncodeObjectLength(dst, extraFields) - if extraFields == 0 { - return dst, nil - } - - fields := make(map[string]struct{}, extraFields) - - err = d.Iterate(func(field string, value types.Value) error { - _, ok := fcs.ByField[field] - if ok { - return nil - } - - // ensure the field is not repeated - if _, ok := fields[field]; ok { - return errors.New("duplicate field " + field) - } - fields[field] = struct{}{} - - // encode the field name first - dst = encoding.EncodeText(dst, field) - - // then make sure the value is stored as the correct type - value, err = encoding.ConvertAsStoreType(value) - if err != nil { - return err - } - - dst, err = encoding.EncodeValue(dst, value, false) - return err - }) - if err != nil { - return nil, err - } - - return dst, nil -} - -type EncodedObject struct { - encoded []byte - fieldConstraints *FieldConstraints +type EncodedRow struct { + encoded []byte + columnConstraints *ColumnConstraints } -func NewEncodedObject(fcs *FieldConstraints, data []byte) *EncodedObject { - e := EncodedObject{ - fieldConstraints: fcs, - encoded: data, +func NewEncodedRow(ccs *ColumnConstraints, data []byte) *EncodedRow { + e := EncodedRow{ + columnConstraints: ccs, + encoded: data, } return &e } -func (e *EncodedObject) ResetWith(fcs *FieldConstraints, data []byte) { - e.fieldConstraints = fcs +func (e *EncodedRow) ResetWith(ccs *ColumnConstraints, data []byte) { + e.columnConstraints = ccs e.encoded = data } -func (e *EncodedObject) skipToExtra(b []byte) int { - l := len(e.fieldConstraints.Ordered) - - var n int - for i := 0; i < l; i++ { - nn := encoding.Skip(b[n:]) - n += nn +func (e *EncodedRow) decodeValue(fc *ColumnConstraint, b []byte) (types.Value, int, error) { + if b[0] == encoding.NullValue { + return types.NewNullValue(), 1, nil } - return n -} - -func (e *EncodedObject) decodeValue(fc *FieldConstraint, b []byte) (types.Value, int, error) { - c := b[0] - - if fc.Type == types.TypeObject && c == encoding.ArrayValue { - // skip array - after := encoding.SkipArray(b[1:]) - - // skip type - b = b[1:] - - // skip length - _, n := binary.Uvarint(b) - b = b[n:] - - return types.NewObjectValue(NewEncodedObject(&fc.AnonymousType.FieldConstraints, b)), after + 1, nil - } - - v, n := encoding.DecodeValue(b, fc.Type == types.TypeAny || fc.Type == types.TypeArray /* intAsDouble */) - - var err error - v, err = encoding.ConvertFromStoreTo(v, fc.Type) - if err != nil { - return nil, 0, err - } + v, n := fc.Type.Def().Decode(b) return v, n, nil } -// GetByField decodes the selected field from the buffer. -func (e *EncodedObject) GetByField(field string) (v types.Value, err error) { +// Get decodes the selected column from the buffer. 
+func (e *EncodedRow) Get(column string) (v types.Value, err error) { b := e.encoded - // get the field from the list of field constraints - fc, ok := e.fieldConstraints.ByField[field] - if ok { - // skip all fields before the selected field - for i := 0; i < fc.Position; i++ { - n := encoding.Skip(b) - b = b[n:] - } - - v, _, err = e.decodeValue(fc, b) - return + // get the column from the list of column constraints + cc, ok := e.columnConstraints.ByColumn[column] + if !ok { + return nil, errors.Wrapf(types.ErrColumnNotFound, "%s not found", column) } - // if extra fields are not allowed, return an error - if !e.fieldConstraints.AllowExtraFields { - return nil, errors.Wrapf(types.ErrFieldNotFound, "field %q not found", field) + // skip all columns before the selected column + for i := 0; i < cc.Position; i++ { + n := encoding.Skip(b) + b = b[n:] } - // otherwise, decode the field from the extra fields - n := e.skipToExtra(b) - b = b[n:] - - return encoding.DecodeObject(b, true /* intAsDouble */).GetByField(field) + v, _, err = e.decodeValue(cc, b) + return } // Iterate decodes each columns one by one and passes them to fn // until the end of the row or until fn returns an error. -func (e *EncodedObject) Iterate(fn func(field string, value types.Value) error) error { +func (e *EncodedRow) Iterate(fn func(column string, value types.Value) error) error { b := e.encoded - for _, fc := range e.fieldConstraints.Ordered { + for _, fc := range e.columnConstraints.Ordered { v, n, err := e.decodeValue(fc, b) if err != nil { return err @@ -247,25 +124,35 @@ func (e *EncodedObject) Iterate(fn func(field string, value types.Value) error) b = b[n:] - if v.Type() == types.TypeNull { - continue - } - - err = fn(fc.Field, v) + err = fn(fc.Column, v) if err != nil { return err } } - if !e.fieldConstraints.AllowExtraFields { - return nil - } + return nil +} - return encoding.DecodeObject(b, true /* intAsDouble */).Iterate(func(field string, value types.Value) error { - return fn(field, value) - }) +func (e *EncodedRow) MarshalJSON() ([]byte, error) { + return row.MarshalJSON(e) } -func (e *EncodedObject) MarshalJSON() ([]byte, error) { - return object.MarshalJSON(e) +func RowIsEncoded(r row.Row, ccs *ColumnConstraints) (*EncodedRow, bool) { + br, ok := r.(*BasicRow) + if ok { + r = br.Row + } + ed, ok := r.(*EncodedRow) + if !ok { + return nil, false + } + + // if the pointers are the same, the column constraints are the same + // otherwise it means we created a copy of the constraints and probably + // altered them (ie. 
ALTER TABLE) + if ed.columnConstraints == ccs { + return ed, true + } + + return nil, false } diff --git a/internal/database/encoding_test.go b/internal/database/encoding_test.go index 77fcbf8a4..341cd3524 100644 --- a/internal/database/encoding_test.go +++ b/internal/database/encoding_test.go @@ -5,7 +5,7 @@ import ( "github.com/chaisql/chai/internal/database" "github.com/chaisql/chai/internal/expr" - "github.com/chaisql/chai/internal/object" + "github.com/chaisql/chai/internal/row" "github.com/chaisql/chai/internal/testutil" "github.com/chaisql/chai/internal/types" "github.com/stretchr/testify/require" @@ -14,139 +14,64 @@ import ( func TestEncoding(t *testing.T) { var ti database.TableInfo - err := ti.AddFieldConstraint(&database.FieldConstraint{ + err := ti.AddColumnConstraint(&database.ColumnConstraint{ Position: 0, - Field: "a", + Column: "a", Type: types.TypeInteger, }) require.NoError(t, err) - err = ti.AddFieldConstraint(&database.FieldConstraint{ + err = ti.AddColumnConstraint(&database.ColumnConstraint{ Position: 1, - Field: "b", + Column: "b", Type: types.TypeText, }) require.NoError(t, err) - err = ti.AddFieldConstraint(&database.FieldConstraint{ + err = ti.AddColumnConstraint(&database.ColumnConstraint{ Position: 2, - Field: "c", + Column: "c", Type: types.TypeDouble, IsNotNull: true, }) require.NoError(t, err) - err = ti.AddFieldConstraint(&database.FieldConstraint{ + err = ti.AddColumnConstraint(&database.ColumnConstraint{ Position: 3, - Field: "d", + Column: "d", Type: types.TypeDouble, DefaultValue: expr.Constraint(testutil.ParseExpr(t, `10`)), }) require.NoError(t, err) - err = ti.AddFieldConstraint(&database.FieldConstraint{ + err = ti.AddColumnConstraint(&database.ColumnConstraint{ Position: 4, - Field: "e", + Column: "e", Type: types.TypeDouble, }) require.NoError(t, err) - ti.FieldConstraints.AllowExtraFields = true - - doc := object.NewFromMap(map[string]any{ - "a": int64(1), - "b": "hello", - "c": float64(3.14), - "e": int64(100), - "f": int64(1000), - "g": float64(2000), - "array": []int{1, 2, 3}, - "doc": object.NewFromMap(map[string]int64{"a": 10}), + r := row.NewFromMap(map[string]any{ + "a": int64(1), + "b": "hello", + "c": float64(3.14), + "e": int64(100), }) var buf []byte - buf, err = ti.EncodeObject(nil, buf, doc) + buf, err = ti.EncodeRow(nil, buf, r) require.NoError(t, err) - d := database.NewEncodedObject(&ti.FieldConstraints, buf) + er := database.NewEncodedRow(&ti.ColumnConstraints, buf) require.NoError(t, err) - want := object.NewFromMap(map[string]any{ - "a": int64(1), - "b": "hello", - "c": float64(3.14), - "d": float64(10), - "e": float64(100), - "f": float64(1000), - "g": float64(2000), - "array": []float64{1, 2, 3}, - "doc": object.NewFromMap(map[string]float64{"a": 10}), + want := row.NewFromMap(map[string]any{ + "a": int64(1), + "b": "hello", + "c": float64(3.14), + "d": float64(10), + "e": float64(100), }) - testutil.RequireObjEqual(t, want, d) - - t.Run("with nested objects", func(t *testing.T) { - var ti database.TableInfo - - // a OBJECT(...) 
- err := ti.AddFieldConstraint(&database.FieldConstraint{ - Position: 0, - Field: "a", - Type: types.TypeObject, - AnonymousType: &database.AnonymousType{ - FieldConstraints: database.FieldConstraints{ - AllowExtraFields: true, - }, - }, - }) - require.NoError(t, err) - - // b OBJECT(d TEST) - var subfcs database.FieldConstraints - err = subfcs.Add(&database.FieldConstraint{ - Position: 0, - Field: "d", - Type: types.TypeText, - }) - subfcs.AllowExtraFields = true - require.NoError(t, err) - - err = ti.AddFieldConstraint(&database.FieldConstraint{ - Position: 1, - Field: "b", - Type: types.TypeObject, - AnonymousType: &database.AnonymousType{ - FieldConstraints: subfcs, - }, - }) - require.NoError(t, err) - - // c INT - err = ti.AddFieldConstraint(&database.FieldConstraint{ - Position: 2, - Field: "c", - Type: types.TypeInteger, - }) - require.NoError(t, err) - - doc := object.NewFromMap(map[string]any{ - "a": object.WithSortedFields(object.NewFromMap(map[string]any{"w": "hello", "x": int64(1)})), - "b": object.WithSortedFields(object.NewFromMap(map[string]any{"d": "bye", "e": int64(2)})), - "c": int64(100), - }) - - got, err := ti.EncodeObject(nil, nil, doc) - require.NoError(t, err) - - d := database.NewEncodedObject(&ti.FieldConstraints, got) - require.NoError(t, err) - - want := object.NewFromMap(map[string]any{ - "a": object.WithSortedFields(object.NewFromMap(map[string]any{"w": "hello", "x": float64(1)})), - "b": object.WithSortedFields(object.NewFromMap(map[string]any{"d": "bye", "e": float64(2)})), - "c": int64(100), - }) - - testutil.RequireObjEqual(t, want, d) - }) + testutil.RequireRowEqual(t, want, er) } diff --git a/internal/database/index.go b/internal/database/index.go index d76f7b98d..06cc5c84d 100644 --- a/internal/database/index.go +++ b/internal/database/index.go @@ -31,7 +31,7 @@ type Index struct { func NewIndex(tr *tree.Tree, opts IndexInfo) *Index { return &Index{ Tree: tr, - Arity: len(opts.Paths), + Arity: len(opts.Columns), } } diff --git a/internal/database/index_test.go b/internal/database/index_test.go index 47c695ab9..801b3fb1a 100644 --- a/internal/database/index_test.go +++ b/internal/database/index_test.go @@ -6,7 +6,6 @@ import ( "github.com/chaisql/chai/internal/database" "github.com/chaisql/chai/internal/kv" - "github.com/chaisql/chai/internal/object" "github.com/chaisql/chai/internal/testutil" "github.com/chaisql/chai/internal/testutil/assert" "github.com/chaisql/chai/internal/tree" @@ -33,11 +32,11 @@ func getIndex(t testing.TB, arity int) *database.Index { tr := tree.New(session, 10, 0) - var paths []object.Path + var columns []string for i := 0; i < arity; i++ { - paths = append(paths, object.NewPath(fmt.Sprintf("[%d]", i))) + columns = append(columns, fmt.Sprintf("[%d]", i)) } - idx := database.NewIndex(tr, database.IndexInfo{Paths: paths}) + idx := database.NewIndex(tr, database.IndexInfo{Columns: columns}) t.Cleanup(func() { session.Close() diff --git a/internal/database/info.go b/internal/database/info.go index c68a381b7..6a03024e9 100644 --- a/internal/database/info.go +++ b/internal/database/info.go @@ -3,10 +3,10 @@ package database import ( "fmt" "math" + "slices" "strconv" "strings" - "github.com/chaisql/chai/internal/object" "github.com/chaisql/chai/internal/stringutil" "github.com/chaisql/chai/internal/tree" "github.com/chaisql/chai/internal/types" @@ -24,35 +24,35 @@ type TableInfo struct { // Name of the rowid sequence if any. 
RowidSequenceName string - FieldConstraints FieldConstraints - TableConstraints TableConstraints + ColumnConstraints ColumnConstraints + TableConstraints TableConstraints PrimaryKey *PrimaryKey } -func (ti *TableInfo) AddFieldConstraint(newFc *FieldConstraint) error { - if ti.FieldConstraints.ByField == nil { - ti.FieldConstraints.ByField = make(map[string]*FieldConstraint) +func (ti *TableInfo) AddColumnConstraint(newCc *ColumnConstraint) error { + if ti.ColumnConstraints.ByColumn == nil { + ti.ColumnConstraints.ByColumn = make(map[string]*ColumnConstraint) } - return ti.FieldConstraints.Add(newFc) + return ti.ColumnConstraints.Add(newCc) } func (ti *TableInfo) AddTableConstraint(newTc *TableConstraint) error { // ensure the field paths exist - for _, p := range newTc.Paths { - if ti.GetFieldConstraintForPath(p) == nil { - return fmt.Errorf("field %q does not exist for table %q", p, ti.TableName) + for _, c := range newTc.Columns { + if ti.GetColumnConstraint(c) == nil { + return fmt.Errorf("column %q does not exist for table %q", c, ti.TableName) } } // ensure paths are not duplicated // i.e. PRIMARY KEY (a, a) is not allowed m := make(map[string]bool) - for _, p := range newTc.Paths { - ps := p.String() + for _, c := range newTc.Columns { + ps := c if _, ok := m[ps]; ok { - return fmt.Errorf("duplicate path %q for constraint", ps) + return fmt.Errorf("duplicate column %q for constraint", ps) } m[ps] = true } @@ -64,9 +64,9 @@ func (ti *TableInfo) AddTableConstraint(newTc *TableConstraint) error { return fmt.Errorf("multiple primary keys for table %q are not allowed", ti.TableName) } - // add NOT NULL constraint to paths - for _, p := range newTc.Paths { - fc := ti.GetFieldConstraintForPath(p) + // add NOT NULL constraint to columns + for _, p := range newTc.Columns { + fc := ti.GetColumnConstraint(p) fc.IsNotNull = true } @@ -94,14 +94,14 @@ func (ti *TableInfo) AddTableConstraint(newTc *TableConstraint) error { case newTc.Unique: // ensure there is only one unique constraint for the same paths for _, tc := range ti.TableConstraints { - if tc.Unique && tc.Paths.IsEqual(newTc.Paths) { - return errors.Errorf("duplicate UNIQUE table contraint on %q", newTc.Paths) + if tc.Unique && slices.Equal(tc.Columns, newTc.Columns) { + return errors.Errorf("duplicate UNIQUE table contraint on %q", newTc.Columns) } } // generate name if not provided if newTc.Name == "" { - newTc.Name = fmt.Sprintf("%s_%s_unique", ti.TableName, pathsToIndexName(newTc.Paths)) + newTc.Name = fmt.Sprintf("%s_%s_unique", ti.TableName, columnsToIndexName(newTc.Columns)) } default: return errors.New("invalid table constraint") @@ -120,6 +120,27 @@ func (ti *TableInfo) AddTableConstraint(newTc *TableConstraint) error { return nil } +// Validate ensures the constraints are valid. 
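With constraints now referencing plain column-name slices, the duplicate UNIQUE check above reduces to an order-sensitive slice comparison via the standard slices package. A hedged sketch of that check with toy types (not the real TableConstraint):

package main

import (
	"fmt"
	"slices"
)

// Illustrative sketch: with columns stored as []string, detecting a duplicate
// UNIQUE constraint is a plain, order-sensitive slice comparison.
type uniqueConstraint struct {
	name    string
	columns []string
}

func hasDuplicate(existing []uniqueConstraint, candidate []string) bool {
	for _, c := range existing {
		if slices.Equal(c.columns, candidate) {
			return true
		}
	}
	return false
}

func main() {
	existing := []uniqueConstraint{{name: "t_a_b_unique", columns: []string{"a", "b"}}}
	fmt.Println(hasDuplicate(existing, []string{"a", "b"})) // true
	fmt.Println(hasDuplicate(existing, []string{"b", "a"})) // false: column order matters
}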
+func (ti *TableInfo) Validate() error { + // ensure the primary key is valid + if ti.PrimaryKey != nil { + if len(ti.PrimaryKey.Columns) != len(ti.PrimaryKey.Types) { + return errors.New("invalid primary key") + } + } + + // ensure the constraints are valid + for _, tc := range ti.TableConstraints { + if tc.Check != nil { + if err := tc.Check.Validate(ti); err != nil { + return err + } + } + } + + return nil +} + func (ti *TableInfo) BuildPrimaryKey() { var pk PrimaryKey @@ -128,11 +149,11 @@ func (ti *TableInfo) BuildPrimaryKey() { continue } - pk.Paths = tc.Paths + pk.Columns = tc.Columns pk.SortOrder = tc.SortOrder - for _, pp := range tc.Paths { - fc := ti.GetFieldConstraintForPath(pp) + for _, pp := range tc.Columns { + fc := ti.GetColumnConstraint(pp) if fc != nil { pk.Types = append(pk.Types, fc.Type) } else { @@ -152,8 +173,8 @@ func (ti *TableInfo) PrimaryKeySortOrder() tree.SortOrder { return ti.PrimaryKey.SortOrder } -func (ti *TableInfo) GetFieldConstraintForPath(p object.Path) *FieldConstraint { - return ti.FieldConstraints.GetFieldConstraintForPath(p) +func (ti *TableInfo) GetColumnConstraint(column string) *ColumnConstraint { + return ti.ColumnConstraints.GetColumnConstraint(column) } func (ti *TableInfo) EncodeKey(key *tree.Key) ([]byte, error) { @@ -169,62 +190,48 @@ func (ti *TableInfo) EncodeKey(key *tree.Key) ([]byte, error) { func (ti *TableInfo) String() string { var s strings.Builder - fmt.Fprintf(&s, "CREATE TABLE %s", stringutil.NormalizeIdentifier(ti.TableName, '`')) - if len(ti.FieldConstraints.Ordered) > 0 || len(ti.TableConstraints) > 0 || ti.FieldConstraints.AllowExtraFields { - s.WriteString(" (") - } + fmt.Fprintf(&s, "CREATE TABLE %s (", stringutil.NormalizeIdentifier(ti.TableName, '`')) - var hasConstraints bool - for i, fc := range ti.FieldConstraints.Ordered { + for i, fc := range ti.ColumnConstraints.Ordered { if i > 0 { s.WriteString(", ") } s.WriteString(fc.String()) - - hasConstraints = true } for i, tc := range ti.TableConstraints { - if i > 0 || hasConstraints { + if i == 0 && len(ti.ColumnConstraints.Ordered) > 0 { s.WriteString(", ") } - - s.WriteString(tc.String()) - hasConstraints = true - } - - if ti.FieldConstraints.AllowExtraFields { - if hasConstraints { + if i > 0 { s.WriteString(", ") } - s.WriteString("...") - hasConstraints = true - } - if hasConstraints { - s.WriteString(")") + s.WriteString(tc.String()) } + s.WriteString(")") + return s.String() } // Clone creates another tableInfo with the same values. func (ti *TableInfo) Clone() *TableInfo { cp := *ti - cp.FieldConstraints.Ordered = nil - cp.FieldConstraints.ByField = make(map[string]*FieldConstraint) + cp.ColumnConstraints.Ordered = nil + cp.ColumnConstraints.ByColumn = make(map[string]*ColumnConstraint) cp.TableConstraints = nil - cp.FieldConstraints.Ordered = append(cp.FieldConstraints.Ordered, ti.FieldConstraints.Ordered...) - for i := range ti.FieldConstraints.Ordered { - cp.FieldConstraints.ByField[ti.FieldConstraints.Ordered[i].Field] = ti.FieldConstraints.Ordered[i] + cp.ColumnConstraints.Ordered = append(cp.ColumnConstraints.Ordered, ti.ColumnConstraints.Ordered...) + for i := range ti.ColumnConstraints.Ordered { + cp.ColumnConstraints.ByColumn[ti.ColumnConstraints.Ordered[i].Column] = ti.ColumnConstraints.Ordered[i] } cp.TableConstraints = append(cp.TableConstraints, ti.TableConstraints...) 
return &cp } type PrimaryKey struct { - Paths object.Paths + Columns []string Types []types.Type SortOrder tree.SortOrder } @@ -234,7 +241,7 @@ type IndexInfo struct { // namespace of the store associated with the index. StoreNamespace tree.Namespace IndexName string - Paths []object.Path + Columns []string // Sort order of each indexed field. KeySortOrder tree.SortOrder @@ -259,13 +266,13 @@ func (idx *IndexInfo) String() string { fmt.Fprintf(&s, "INDEX %s ON %s (", stringutil.NormalizeIdentifier(idx.IndexName, '`'), stringutil.NormalizeIdentifier(idx.Owner.TableName, '`')) - for i, p := range idx.Paths { + for i, p := range idx.Columns { if i > 0 { s.WriteString(", ") } - // Path - s.WriteString(p.String()) + // Column + s.WriteString(p) if idx.KeySortOrder.IsDesc(i) { s.WriteString(" DESC") @@ -281,10 +288,8 @@ func (idx *IndexInfo) String() string { func (i IndexInfo) Clone() *IndexInfo { c := i - c.Paths = make([]object.Path, len(i.Paths)) - for i, p := range i.Paths { - c.Paths[i] = p.Clone() - } + c.Columns = make([]string, len(i.Columns)) + copy(c.Columns, i.Columns) return &c } @@ -348,5 +353,5 @@ func (s SequenceInfo) Clone() *SequenceInfo { // path must also be filled. type Owner struct { TableName string - Paths object.Paths + Columns []string } diff --git a/internal/database/iteration.go b/internal/database/iteration.go index c7ca8f264..946a6e605 100644 --- a/internal/database/iteration.go +++ b/internal/database/iteration.go @@ -1,10 +1,6 @@ package database import ( - "math" - - "github.com/chaisql/chai/internal/encoding" - "github.com/chaisql/chai/internal/object" "github.com/chaisql/chai/internal/tree" "github.com/chaisql/chai/internal/types" ) @@ -17,29 +13,14 @@ type Range struct { Exact bool } -func (r *Range) ToTreeRange(constraints *FieldConstraints, paths []object.Path) (*tree.Range, error) { +func (r *Range) ToTreeRange(constraints *ColumnConstraints, columns []string) (*tree.Range, error) { var rng tree.Range - var err error if len(r.Min) > 0 { - for i := range r.Min { - r.Min[i], err = r.Convert(constraints, r.Min[i], paths[i], true) - if err != nil { - return nil, err - } - } - rng.Min = tree.NewKey(r.Min...) } if len(r.Max) > 0 { - for i := range r.Max { - r.Max[i], err = r.Convert(constraints, r.Max[i], paths[i], false) - if err != nil { - return nil, err - } - } - rng.Max = tree.NewKey(r.Max...) } @@ -60,55 +41,6 @@ func (r *Range) ToTreeRange(constraints *FieldConstraints, paths []object.Path) return &rng, nil } -func (r *Range) Convert(constraints *FieldConstraints, v types.Value, p object.Path, isMin bool) (types.Value, error) { - // ensure the operand satisfies all the constraints, index can work only on exact types. - // if a number is encountered, try to convert it to the right type if and only if the conversion - // is lossless. - // if a timestamp is encountered, ensure the field constraint is also a timestamp, otherwise convert it to text. - v, err := constraints.ConvertValueAtPath(p, v, func(v types.Value, path object.Path, targetType types.Type) (types.Value, error) { - if v.Type() == types.TypeDouble && targetType == types.TypeInteger { - f := types.AsFloat64(v) - if float64(int64(f)) == f { - return object.CastAsInteger(v) - } - - if r.Exact { - return v, nil - } - - // we want to convert a non rounded double to int in a way that preserves - // comparison logic with the index. 
ex: - // a > 1.1 -> a >= 2; exclusive -> false - // a >= 1.1 -> a >= 2; exclusive -> false - // a < 1.1 -> a < 2; exclusive -> true - // a <= 1.1 -> a < 2; exclusive -> true - // a BETWEEN 1.1 AND 2.2 -> a >= 2 AND a <= 3; exclusive -> false - - // First, we need to ceil the number. Ex: 1.1 -> 2 - v = types.NewIntegerValue(int64(math.Ceil(f))) - - // Next, we need to convert the boundaries - if isMin { - // (a > 1.1) or (a >= 1.1) must be transformed to (a >= 2) - r.Exclusive = false - } else { - // (a < 1.1) or (a <= 1.1) must be transformed to (a < 2) - // But there is an exception: if we are dealing with both min - // and max boundaries, we are operating a BETWEEN operation, - // meaning that we need to convert a BETWEEN 1.1 AND 2.2 to a >= 2 AND a <= 3, - // and thus have to set exclusive to false. - r.Exclusive = r.Min == nil || len(r.Min) == 0 - } - } else { - return encoding.ConvertAsIndexType(v, targetType) - } - - return v, nil - }) - - return v, err -} - func (r *Range) IsEqual(other *Range) bool { if r.Exact != other.Exact { return false diff --git a/internal/database/row.go b/internal/database/row.go index b9ab2439e..096297482 100644 --- a/internal/database/row.go +++ b/internal/database/row.go @@ -1,6 +1,7 @@ package database import ( + "github.com/chaisql/chai/internal/row" "github.com/chaisql/chai/internal/tree" "github.com/chaisql/chai/internal/types" ) @@ -22,14 +23,12 @@ type Row interface { // Key returns the row key. Key() *tree.Key - - Object() types.Object } var _ Row = (*LazyRow)(nil) -// LazyRow holds an LazyRow key and lazily loads the LazyRow on demand when the Iterate or GetByField method is called. -// It implements the Row and the object.Keyer interfaces. +// LazyRow holds an LazyRow key and lazily loads the LazyRow on demand when the Iterate or Get method is called. +// It implements the Row and the row.Keyer interfaces. 
type LazyRow struct { key *tree.Key table *Table @@ -55,10 +54,6 @@ func (r *LazyRow) Iterate(fn func(name string, value types.Value) error) error { } func (r *LazyRow) Get(name string) (types.Value, error) { - return r.GetByField(name) -} - -func (r *LazyRow) GetByField(field string) (types.Value, error) { var err error if r.row == nil { r.row, err = r.table.GetRow(r.key) @@ -67,7 +62,7 @@ func (r *LazyRow) GetByField(field string) (types.Value, error) { } } - return r.row.Get(field) + return r.row.Get(name) } func (r *LazyRow) MarshalJSON() ([]byte, error) { @@ -90,48 +85,24 @@ func (r *LazyRow) TableName() string { return r.table.Info.TableName } -func (r *LazyRow) Object() types.Object { - if r.row == nil { - var err error - r.row, err = r.table.GetRow(r.key) - if err != nil { - panic(err) - } - } - - return r.row.Object() -} - var _ Row = (*BasicRow)(nil) type BasicRow struct { + row.Row tableName string key *tree.Key - obj types.Object } -func NewBasicRow(obj types.Object) *BasicRow { +func NewBasicRow(r row.Row) *BasicRow { return &BasicRow{ - obj: obj, + Row: r, } } -func (r *BasicRow) ResetWith(tableName string, key *tree.Key, obj types.Object) { +func (r *BasicRow) ResetWith(tableName string, key *tree.Key, rr row.Row) { r.tableName = tableName r.key = key - r.obj = obj -} - -func (r *BasicRow) Iterate(fn func(name string, value types.Value) error) error { - return r.obj.Iterate(fn) -} - -func (r *BasicRow) Get(name string) (types.Value, error) { - return r.obj.GetByField(name) -} - -func (r *BasicRow) MarshalJSON() ([]byte, error) { - return r.obj.(interface{ MarshalJSON() ([]byte, error) }).MarshalJSON() + r.Row = rr } func (r *BasicRow) Key() *tree.Key { @@ -142,10 +113,6 @@ func (r *BasicRow) TableName() string { return r.tableName } -func (r *BasicRow) Object() types.Object { - return r.obj -} - type RowIterator interface { // Iterate goes through all the rows of the table and calls the given function by passing each one of them. // If the given function returns an error, the iteration stops. 
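BasicRow now embeds row.Row instead of wrapping a types.Object, so Get, Iterate and MarshalJSON are promoted from the embedded value rather than forwarded by hand. A small self-contained illustration of that embedding pattern with toy interfaces (not the chai ones):

package main

import "fmt"

// Toy illustration of the embedding used by BasicRow: the outer struct embeds
// the row interface, so the inner Get is promoted and satisfies the interface
// without explicit forwarding methods.
type Row interface {
	Get(column string) (string, error)
}

type mapRow map[string]string

func (m mapRow) Get(column string) (string, error) {
	v, ok := m[column]
	if !ok {
		return "", fmt.Errorf("column %q not found", column)
	}
	return v, nil
}

// basicRow adds table metadata on top of an embedded Row.
type basicRow struct {
	Row
	tableName string
}

func main() {
	r := basicRow{Row: mapRow{"a": "1"}, tableName: "test"}
	v, _ := r.Get("a") // promoted from the embedded Row
	fmt.Println(r.tableName, v)
}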
diff --git a/internal/database/sequence.go b/internal/database/sequence.go index 7b1acc953..e4b3b3150 100644 --- a/internal/database/sequence.go +++ b/internal/database/sequence.go @@ -5,7 +5,7 @@ import ( "strings" errs "github.com/chaisql/chai/internal/errors" - "github.com/chaisql/chai/internal/object" + "github.com/chaisql/chai/internal/row" "github.com/chaisql/chai/internal/tree" "github.com/chaisql/chai/internal/types" "github.com/cockroachdb/errors" @@ -15,24 +15,24 @@ var sequenceTableInfo = func() *TableInfo { info := &TableInfo{ TableName: SequenceTableName, StoreNamespace: SequenceTableNamespace, - FieldConstraints: MustNewFieldConstraints( - &FieldConstraint{ + ColumnConstraints: MustNewColumnConstraints( + &ColumnConstraint{ Position: 0, - Field: "name", + Column: "name", Type: types.TypeText, IsNotNull: true, }, - &FieldConstraint{ + &ColumnConstraint{ Position: 1, - Field: "seq", - Type: types.TypeInteger, + Column: "seq", + Type: types.TypeBigint, }, ), TableConstraints: []*TableConstraint{ { Name: SequenceTableName + "_pk", - Paths: []object.Path{ - object.NewPath("name"), + Columns: []string{ + "name", }, PrimaryKey: true, }, @@ -85,7 +85,7 @@ func (s *Sequence) Init(tx *Transaction) error { return err } - _, _, err = tb.Insert(object.NewFieldBuffer().Add("name", types.NewTextValue(s.Info.Name))) + _, _, err = tb.Insert(row.NewColumnBuffer().Add("name", types.NewTextValue(s.Info.Name))) return err } @@ -180,9 +180,9 @@ func (s *Sequence) SetLease(tx *Transaction, name string, v int64) error { k := s.key() _, err = tb.Put(k, - object.NewFieldBuffer(). + row.NewColumnBuffer(). Add("name", types.NewTextValue(name)). - Add("seq", types.NewIntegerValue(v)), + Add("seq", types.NewBigintValue(v)), ) return err } @@ -216,9 +216,9 @@ func (s *Sequence) SetName(name string) { func (s *Sequence) GenerateBaseName() string { var sb strings.Builder sb.WriteString(s.Info.Owner.TableName) - if len(s.Info.Owner.Paths) > 0 { + if len(s.Info.Owner.Columns) > 0 { sb.WriteString("_") - sb.WriteString(s.Info.Owner.Paths.String()) + sb.WriteString(strings.Join(s.Info.Owner.Columns, "_")) } sb.WriteString("_seq") return sb.String() diff --git a/internal/database/sequence_test.go b/internal/database/sequence_test.go index ed20b2443..9a2e394d0 100644 --- a/internal/database/sequence_test.go +++ b/internal/database/sequence_test.go @@ -24,7 +24,7 @@ func getLease(t testing.TB, tx *database.Transaction, catalog *database.Catalog, } v, err := d.Get("seq") - if errors.Is(err, types.ErrFieldNotFound) { + if errors.Is(err, types.ErrColumnNotFound) { return nil, nil } if err != nil { diff --git a/internal/database/table.go b/internal/database/table.go index 9477b13a4..542eeb145 100644 --- a/internal/database/table.go +++ b/internal/database/table.go @@ -5,7 +5,7 @@ import ( "github.com/chaisql/chai/internal/engine" errs "github.com/chaisql/chai/internal/errors" - "github.com/chaisql/chai/internal/object" + "github.com/chaisql/chai/internal/row" "github.com/chaisql/chai/internal/tree" "github.com/chaisql/chai/internal/types" "github.com/cockroachdb/errors" @@ -28,20 +28,20 @@ func (t *Table) Truncate() error { // Insert the object into the table. // If a primary key has been specified during the table creation, the field is expected to be present -// in the given object. +// in the given row. // If no primary key has been selected, a monotonic autoincremented integer key will be generated. // It returns the inserted object alongside its key. 
-func (t *Table) Insert(o types.Object) (*tree.Key, Row, error) { +func (t *Table) Insert(r row.Row) (*tree.Key, Row, error) { if t.Info.ReadOnly { return nil, nil, errors.New("cannot write to read-only table") } - key, isRowid, err := t.generateKey(t.Info, o) + key, isRowid, err := t.generateKey(t.Info, r) if err != nil { return nil, nil, err } - o, enc, err := t.encodeObject(o) + r, enc, err := t.encodeRow(r) if err != nil { return nil, nil, err } @@ -58,7 +58,7 @@ func (t *Table) Insert(o types.Object) (*tree.Key, Row, error) { if errors.Is(err, engine.ErrKeyAlreadyExists) { return nil, nil, &ConstraintViolationError{ Constraint: "PRIMARY KEY", - Paths: t.Info.PrimaryKey.Paths, + Columns: t.Info.PrimaryKey.Columns, Key: key, } } @@ -68,24 +68,24 @@ func (t *Table) Insert(o types.Object) (*tree.Key, Row, error) { return key, &BasicRow{ tableName: t.Info.TableName, - obj: o, + Row: r, key: key, }, nil } -func (t *Table) encodeObject(o types.Object) (types.Object, []byte, error) { - ed, ok := o.(*EncodedObject) +func (t *Table) encodeRow(r row.Row) (row.Row, []byte, error) { + ed, ok := r.(*EncodedRow) // pointer comparison is enough here - if ok && ed.fieldConstraints == &t.Info.FieldConstraints { - return o, ed.encoded, nil + if ok && ed.columnConstraints == &t.Info.ColumnConstraints { + return r, ed.encoded, nil } - dst, err := t.Info.EncodeObject(t.Tx, nil, o) + dst, err := t.Info.EncodeRow(t.Tx, nil, r) if err != nil { return nil, nil, err } - return NewEncodedObject(&t.Info.FieldConstraints, dst), dst, nil + return NewEncodedRow(&t.Info.ColumnConstraints, dst), dst, nil } // Delete a object by key. @@ -96,7 +96,7 @@ func (t *Table) Delete(key *tree.Key) error { err := t.Tree.Delete(key) if errors.Is(err, engine.ErrKeyNotFound) { - return errors.WithStack(errs.NewNotFoundError(key.String())) + return errs.NewNotFoundError(key.String()) } return err @@ -104,7 +104,7 @@ func (t *Table) Delete(key *tree.Key) error { // Replace a row by key. // An error is returned if the key doesn't exist. -func (t *Table) Replace(key *tree.Key, o types.Object) (Row, error) { +func (t *Table) Replace(key *tree.Key, r row.Row) (Row, error) { if t.Info.ReadOnly { return nil, errors.New("cannot write to read-only table") } @@ -118,16 +118,16 @@ func (t *Table) Replace(key *tree.Key, o types.Object) (Row, error) { return nil, errors.Wrapf(errs.NewNotFoundError(key.String()), "can't replace key %q", key) } - return t.Put(key, o) + return t.Put(key, r) } // Put a row by key. If the key doesn't exist, it is created. 
-func (t *Table) Put(key *tree.Key, o types.Object) (Row, error) { +func (t *Table) Put(key *tree.Key, r row.Row) (Row, error) { if t.Info.ReadOnly { return nil, errors.New("cannot write to read-only table") } - o, enc, err := t.encodeObject(o) + r, enc, err := t.encodeRow(r) if err != nil { return nil, err } @@ -136,35 +136,35 @@ func (t *Table) Put(key *tree.Key, o types.Object) (Row, error) { err = t.Tree.Put(key, enc) return &BasicRow{ tableName: t.Info.TableName, - obj: o, + Row: r, key: key, }, err } func (t *Table) IterateOnRange(rng *Range, reverse bool, fn func(key *tree.Key, r Row) error) error { - var paths []object.Path + var columns []string pk := t.Info.PrimaryKey if pk != nil { - paths = pk.Paths + columns = pk.Columns } var r *tree.Range var err error if rng != nil { - r, err = rng.ToTreeRange(&t.Info.FieldConstraints, paths) + r, err = rng.ToTreeRange(&t.Info.ColumnConstraints, columns) if err != nil { return err } } - e := EncodedObject{ - fieldConstraints: &t.Info.FieldConstraints, + e := EncodedRow{ + columnConstraints: &t.Info.ColumnConstraints, } row := BasicRow{ tableName: t.Info.TableName, - obj: &e, + Row: &e, } return t.Tree.IterateOnRange(r, reverse, func(k *tree.Key, enc []byte) error { @@ -179,14 +179,14 @@ func (t *Table) GetRow(key *tree.Key) (Row, error) { enc, err := t.Tree.Get(key) if err != nil { if errors.Is(err, engine.ErrKeyNotFound) { - return nil, errors.WithStack(errs.NewNotFoundError(key.String())) + return nil, errs.NewNotFoundError(key.String()) } return nil, fmt.Errorf("failed to fetch row %q: %w", key, err) } return &BasicRow{ tableName: t.Info.TableName, - obj: NewEncodedObject(&t.Info.FieldConstraints, enc), + Row: NewEncodedRow(&t.Info.ColumnConstraints, enc), key: key, }, nil } @@ -198,13 +198,13 @@ func (t *Table) GetRow(key *tree.Key) (Row, error) { // if there are no primary key in the table, a default // key is generated, called the rowid. // It returns a boolean indicating whether the key is a rowid or not. 
-func (t *Table) generateKey(info *TableInfo, o types.Object) (*tree.Key, bool, error) { +func (t *Table) generateKey(info *TableInfo, r row.Row) (*tree.Key, bool, error) { if pk := t.Info.PrimaryKey; pk != nil { - vs := make([]types.Value, 0, len(pk.Paths)) - for _, p := range pk.Paths { - v, err := p.GetValueFromObject(o) - if errors.Is(err, types.ErrFieldNotFound) { - return nil, false, fmt.Errorf("missing primary key at path %q", p) + vs := make([]types.Value, 0, len(pk.Columns)) + for _, c := range pk.Columns { + v, err := r.Get(c) + if errors.Is(err, types.ErrColumnNotFound) { + return nil, false, fmt.Errorf("missing primary key at path %q", c) } if err != nil { return nil, false, err @@ -225,5 +225,5 @@ func (t *Table) generateKey(info *TableInfo, o types.Object) (*tree.Key, bool, e return nil, true, err } - return tree.NewKey(types.NewIntegerValue(rowid)), true, nil + return tree.NewKey(types.NewBigintValue(rowid)), true, nil } diff --git a/internal/database/table_test.go b/internal/database/table_test.go index 105d3a5f0..f6db7d9a0 100644 --- a/internal/database/table_test.go +++ b/internal/database/table_test.go @@ -7,8 +7,8 @@ import ( "github.com/chaisql/chai/internal/database" "github.com/chaisql/chai/internal/database/catalogstore" errs "github.com/chaisql/chai/internal/errors" - "github.com/chaisql/chai/internal/object" "github.com/chaisql/chai/internal/query/statement" + "github.com/chaisql/chai/internal/row" "github.com/chaisql/chai/internal/testutil" "github.com/chaisql/chai/internal/testutil/assert" "github.com/chaisql/chai/internal/tree" @@ -42,8 +42,21 @@ func newTestTable(t testing.TB) (*database.Table, func()) { _, tx, fn := testutil.NewTestTx(t) - ti := database.TableInfo{TableName: "test"} - ti.FieldConstraints.AllowExtraFields = true + ti := database.TableInfo{ + TableName: "test", + ColumnConstraints: database.MustNewColumnConstraints( + &database.ColumnConstraint{ + Position: 0, + Column: "a", + Type: types.TypeText, + }, + &database.ColumnConstraint{ + Position: 0, + Column: "b", + Type: types.TypeText, + }, + ), + } return createTable(t, tx, ti), fn } @@ -79,14 +92,14 @@ func createTableIfNotExists(t testing.TB, tx *database.Transaction, info databas return tb } -func newObject() *object.FieldBuffer { - return object.NewFieldBuffer(). - Add("fielda", types.NewTextValue("a")). - Add("fieldb", types.NewTextValue("b")) +func newRow() *row.ColumnBuffer { + return row.NewColumnBuffer(). + Add("a", types.NewTextValue("a")). + Add("b", types.NewTextValue("b")) } -// TestTableGetObject verifies GetObject behaviour. -func TestTableGetObject(t *testing.T) { +// TestTableGetRow verifies GetRow behaviour. 
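The key generation above reads the primary-key columns straight from the row and falls back to an auto-incremented rowid when the table declares no primary key. A simplified, self-contained sketch of that decision using hypothetical types (no tree.Key or sequences):

package main

import (
	"fmt"
	"strings"
)

// Illustrative sketch of key generation under a strict schema: use the
// declared primary-key columns if any, otherwise fall back to a rowid.
type table struct {
	pkColumns []string
	nextRowid int64
}

// generateKey returns the key and whether it is an auto-generated rowid.
func (t *table) generateKey(r map[string]string) (string, bool, error) {
	if len(t.pkColumns) > 0 {
		parts := make([]string, 0, len(t.pkColumns))
		for _, c := range t.pkColumns {
			v, ok := r[c]
			if !ok {
				return "", false, fmt.Errorf("missing primary key column %q", c)
			}
			parts = append(parts, v)
		}
		return strings.Join(parts, "/"), false, nil
	}
	t.nextRowid++
	return fmt.Sprintf("rowid:%d", t.nextRowid), true, nil
}

func main() {
	withPK := &table{pkColumns: []string{"id"}}
	k, isRowid, _ := withPK.generateKey(map[string]string{"id": "42", "name": "x"})
	fmt.Println(k, isRowid) // 42 false

	noPK := &table{}
	k, isRowid, _ = noPK.generateKey(map[string]string{"name": "x"})
	fmt.Println(k, isRowid) // rowid:1 true
}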
+func TestTableGetRow(t *testing.T) { t.Run("Should fail if not found", func(t *testing.T) { tb, cleanup := newTestTable(t) defer cleanup() @@ -96,29 +109,26 @@ func TestTableGetObject(t *testing.T) { require.Nil(t, r) }) - t.Run("Should return the right object", func(t *testing.T) { + t.Run("Should return the right row", func(t *testing.T) { tb, cleanup := newTestTable(t) defer cleanup() - // create two objects, one with an additional field - doc1 := newObject() - vc := types.NewDoubleValue(40) - doc1.Add("fieldc", vc) - doc2 := newObject() + // create two rows + row1 := newRow() + row2 := newRow() + row2.Set("a", types.NewTextValue("c")) - key, _, err := tb.Insert(doc1) + key, _, err := tb.Insert(row1) assert.NoError(t, err) - _, _, err = tb.Insert(doc2) + _, _, err = tb.Insert(row2) assert.NoError(t, err) - // fetch doc1 and make sure it returns the right one + // fetch row1 and make sure it returns the right one res, err := tb.GetRow(key) assert.NoError(t, err) - fc, err := res.Get("fieldc") - assert.NoError(t, err) - ok, err := vc.EQ(fc) + v, err := res.Get("a") assert.NoError(t, err) - require.True(t, ok) + require.Equal(t, "a", types.AsString(v)) }) } @@ -131,19 +141,33 @@ func TestTableInsert(t *testing.T) { }) assert.NoError(t, err) - insertDoc := func(db *database.Database) (rawKey *tree.Key) { + insertRow := func(db *database.Database) (rawKey *tree.Key) { t.Helper() update(t, db, func(tx *database.Transaction) error { t.Helper() // create table if not exists - ti := database.TableInfo{TableName: "test"} - ti.FieldConstraints.AllowExtraFields = true + ti := database.TableInfo{ + TableName: "test", + ColumnConstraints: database.MustNewColumnConstraints( + &database.ColumnConstraint{ + Position: 0, + Column: "a", + Type: types.TypeText, + }, + &database.ColumnConstraint{ + Position: 0, + Column: "b", + Type: types.TypeText, + }, + ), + } + tb := createTableIfNotExists(t, tx, ti) - doc := newObject() - key, _, err := tb.Insert(doc) + r := newRow() + key, _, err := tb.Insert(r) assert.NoError(t, err) require.NotEmpty(t, key) rawKey = key @@ -152,7 +176,7 @@ func TestTableInsert(t *testing.T) { return } - key1 := insertDoc(db1) + key1 := insertRow(db1) err = db1.Close() assert.NoError(t, err) @@ -163,7 +187,7 @@ func TestTableInsert(t *testing.T) { }) assert.NoError(t, err) - key2 := insertDoc(db2) + key2 := insertRow(db2) vs, err := key1.Decode() assert.NoError(t, err) @@ -192,13 +216,13 @@ func TestTableDelete(t *testing.T) { defer cleanup() // create two objects, one with an additional field - doc1 := newObject() - doc1.Add("fieldc", types.NewIntegerValue(40)) - doc2 := newObject() + row1 := newRow() + row1.Add("fieldc", types.NewIntegerValue(40)) + row2 := newRow() - key1, _, err := tb.Insert(testutil.CloneObject(t, doc1)) + key1, _, err := tb.Insert(testutil.CloneRow(t, row1)) assert.NoError(t, err) - key2, _, err := tb.Insert(testutil.CloneObject(t, doc2)) + key2, _, err := tb.Insert(testutil.CloneRow(t, row2)) assert.NoError(t, err) // delete the object @@ -223,47 +247,47 @@ func TestTableReplace(t *testing.T) { tb, cleanup := newTestTable(t) defer cleanup() - _, err := tb.Replace(tree.NewKey(types.NewIntegerValue(10)), newObject()) + _, err := tb.Replace(tree.NewKey(types.NewIntegerValue(10)), newRow()) require.True(t, errs.IsNotFoundError(err)) }) - t.Run("Should replace the right object", func(t *testing.T) { + t.Run("Should replace the right row", func(t *testing.T) { tb, cleanup := newTestTable(t) defer cleanup() // create two different objects - doc1 := newObject() - 
doc2 := object.NewFieldBuffer(). - Add("fielda", types.NewTextValue("c")). - Add("fieldb", types.NewTextValue("d")) + row1 := newRow() + row2 := row.NewColumnBuffer(). + Add("a", types.NewTextValue("c")). + Add("b", types.NewTextValue("d")) - key1, _, err := tb.Insert(doc1) + key1, _, err := tb.Insert(row1) assert.NoError(t, err) - key2, _, err := tb.Insert(doc2) + key2, _, err := tb.Insert(row2) assert.NoError(t, err) // create a third object - doc3 := object.NewFieldBuffer(). - Add("fielda", types.NewTextValue("e")). - Add("fieldb", types.NewTextValue("f")) + doc3 := row.NewColumnBuffer(). + Add("a", types.NewTextValue("e")). + Add("b", types.NewTextValue("f")) - // replace doc1 with doc3 + // replace row1 with doc3 d3, err := tb.Replace(key1, doc3) assert.NoError(t, err) // make sure it replaced it correctly res, err := tb.GetRow(key1) assert.NoError(t, err) - f, err := res.Get("fielda") + f, err := res.Get("a") assert.NoError(t, err) require.Equal(t, "e", f.V().(string)) - testutil.RequireObjEqual(t, d3.Object(), res.Object()) + testutil.RequireRowEqual(t, d3, res) // make sure it didn't also replace the other one res, err = tb.GetRow(key2) assert.NoError(t, err) - f, err = res.Get("fielda") + f, err = res.Get("a") assert.NoError(t, err) require.Equal(t, "c", f.V().(string)) }) @@ -284,12 +308,12 @@ func TestTableTruncate(t *testing.T) { defer cleanup() // create two objects - doc1 := newObject() - doc2 := newObject() + row1 := newRow() + row2 := newRow() - _, _, err := tb.Insert(doc1) + _, _, err := tb.Insert(row1) assert.NoError(t, err) - _, _, err = tb.Insert(doc2) + _, _, err = tb.Insert(row2) assert.NoError(t, err) err = tb.Truncate() @@ -307,10 +331,10 @@ func TestTableTruncate(t *testing.T) { func BenchmarkTableInsert(b *testing.B) { for size := 1; size <= 10000; size *= 10 { b.Run(fmt.Sprintf("%.05d", size), func(b *testing.B) { - var fb object.FieldBuffer + var fb row.ColumnBuffer for i := int64(0); i < 10; i++ { - fb.Add(fmt.Sprintf("name-%d", i), types.NewIntegerValue(i)) + fb.Add(fmt.Sprintf("name-%d", i), types.NewBigintValue(i)) } b.ResetTimer() @@ -336,10 +360,10 @@ func BenchmarkTableScan(b *testing.B) { tb, cleanup := newTestTable(b) defer cleanup() - var fb object.FieldBuffer + var fb row.ColumnBuffer for i := int64(0); i < 10; i++ { - fb.Add(fmt.Sprintf("name-%d", i), types.NewIntegerValue(i)) + fb.Add(fmt.Sprintf("name-%d", i), types.NewBigintValue(i)) } for i := 0; i < size; i++ { diff --git a/internal/encoding/array.go b/internal/encoding/array.go deleted file mode 100644 index a1aced65e..000000000 --- a/internal/encoding/array.go +++ /dev/null @@ -1,110 +0,0 @@ -package encoding - -import ( - "encoding/binary" - - "github.com/chaisql/chai/internal/object" - "github.com/chaisql/chai/internal/types" -) - -func EncodeArray(dst []byte, a types.Array) ([]byte, error) { - if a == nil { - dst = EncodeArrayLength(dst, 0) - return dst, nil - } - - l, err := object.ArrayLength(a) - if err != nil { - return nil, err - } - if l == 0 { - return append(dst, byte(ArrayValue), 0), nil - } - - dst = EncodeArrayLength(dst, l) - - err = a.Iterate(func(i int, value types.Value) error { - dst, err = EncodeValue(dst, value, false) - return err - }) - if err != nil { - return nil, err - } - - return dst, nil -} - -func EncodeArrayLength(dst []byte, l int) []byte { - // encode the length as a varint - buf := make([]byte, binary.MaxVarintLen64+1) - buf[0] = ArrayValue - n := binary.PutUvarint(buf[1:], uint64(l)) - return append(dst, buf[:n+1]...) 
-} - -func DecodeArray(b []byte, intAsDouble bool) types.Array { - return &EncodedArray{ - enc: b[1:], - intAsDouble: intAsDouble, - } -} - -// An EncodedArray implements the types.Array interface on top of an -// encoded representation of an array. -// It is useful for avoiding decoding the entire array when -// only a few values are needed. -type EncodedArray struct { - enc []byte - intAsDouble bool -} - -// Iterate goes through all the values of the array and calls the -// given function by passing each one of them. -// If the given function returns an error, the iteration stops. -func (e *EncodedArray) Iterate(fn func(i int, value types.Value) error) error { - l, n := binary.Uvarint(e.enc) - if l == 0 { - return nil - } - b := e.enc[n:] - - ll := int(l) - for i := 0; i < ll; i++ { - v, n := DecodeValue(b, e.intAsDouble) - b = b[n:] - - err := fn(i, v) - if err != nil { - return err - } - } - - return nil -} - -// GetByIndex returns a value by index of the array. -func (e *EncodedArray) GetByIndex(idx int) (v types.Value, err error) { - l, n := binary.Uvarint(e.enc) - if l == 0 { - return nil, types.ErrValueNotFound - } - b := e.enc[n:] - - ll := int(l) - for i := 0; i < ll; i++ { - if i == idx { - v, _ := DecodeValue(b, e.intAsDouble) - return v, nil - } - - n = Skip(b) - b = b[n:] - } - - err = types.ErrValueNotFound - return -} - -func (e *EncodedArray) MarshalJSON() ([]byte, error) { - return object.MarshalJSONArray(e) -} diff --git a/internal/encoding/array_test.go b/internal/encoding/array_test.go deleted file mode 100644 index 731c75a95..000000000 --- a/internal/encoding/array_test.go +++ /dev/null @@ -1,49 +0,0 @@ -package encoding_test - -import ( - "encoding/binary" - "fmt" - "testing" - - "github.com/chaisql/chai/internal/encoding" - "github.com/chaisql/chai/internal/testutil" - "github.com/chaisql/chai/internal/types" - "github.com/stretchr/testify/require" -) - -func makeUvarint(n int) []byte { - var buf [10]byte - i := binary.PutUvarint(buf[:], uint64(n)) - return buf[:i] -} - -func TestEncodeDecodeArray(t *testing.T) { - tests := []struct { - input types.Array - want []byte - wantArray types.Array - }{ - {testutil.MakeArray(t, `[]`), []byte{byte(encoding.ArrayValue), makeUvarint(0)[0]}, testutil.MakeArray(t, `[]`)}, - {testutil.MakeArray(t, `[1]`), []byte{byte(encoding.ArrayValue), makeUvarint(1)[0], encoding.EncodeInt(nil, 1)[0]}, testutil.MakeArray(t, `[1.0]`)}, - {testutil.MakeArray(t, `[1, []]`), - []byte{ - byte(encoding.ArrayValue), - makeUvarint(2)[0], - encoding.EncodeInt(nil, 1)[0], - byte(encoding.ArrayValue), makeUvarint(0)[0], - }, - testutil.MakeArray(t, `[1.0, []]`), - }, - } - - for _, test := range tests { - t.Run(fmt.Sprintf("%d", test.input), func(t *testing.T) { - got, err := encoding.EncodeArray(nil, test.input) - require.NoError(t, err) - require.Equal(t, test.want, got) - - x := encoding.DecodeArray(got, true) - testutil.RequireArrayEqual(t, test.wantArray, x) - }) - } -} diff --git a/internal/encoding/conversion.go b/internal/encoding/conversion.go deleted file mode 100644 index a53e0949c..000000000 --- a/internal/encoding/conversion.go +++ /dev/null @@ -1,75 +0,0 @@ -package encoding - -import ( - "errors" - - "github.com/chaisql/chai/internal/object" - "github.com/chaisql/chai/internal/types" -) - -// ConvertFromStoreTo ensures the value read from the store is the same -// as the column type. Most types are stored as is but certain types are -// converted prior to being stored. 
-// Example: when there is no constraint on the column, integers are stored -// as doubles. -// The given target type is the type of the column, and it is used to determine -// whether there exists a constraint on the column or not. -func ConvertFromStoreTo(src types.Value, target types.Type) (types.Value, error) { - if src.Type() == target { - return src, nil - } - - switch src.Type() { - case types.TypeInteger: - return convertIntegerFromStore(src, target) - default: - // if there is no constraint on the column, then the stored type - // is the same as the runtime type. - return src, nil - } -} - -func convertIntegerFromStore(src types.Value, target types.Type) (types.Value, error) { - switch target { - case types.TypeAny: - return types.NewDoubleValue(float64(types.AsInt64(src))), nil - case types.TypeTimestamp: - return types.NewTimestampValue(ConvertToTimestamp(types.AsInt64(src))), nil - } - - return nil, errors.New("cannot convert from store to " + target.String()) -} - -// ConvertAsStoreType converts the value to the type that is stored in the store -// when there is no constraint on the column. -func ConvertAsStoreType(src types.Value) (types.Value, error) { - switch src.Type() { - case types.TypeTimestamp: - // without a type constraint, timestamp values must - // always be stored as text to avoid mixed representations. - return object.CastAsText(src) - } - - return src, nil -} - -// ConvertAsIndexType converts the value to the type that is stored in the index -// as a key. -func ConvertAsIndexType(src types.Value, target types.Type) (types.Value, error) { - switch src.Type() { - case types.TypeInteger: - if target == types.TypeAny || target == types.TypeDouble { - return object.CastAsDouble(src) - } - return src, nil - case types.TypeTimestamp: - // without a type constraint, timestamp values must - // always be stored as text to avoid mixed representations. - if target == types.TypeAny { - return object.CastAsText(src) - } - return src, nil - } - - return src, nil -} diff --git a/internal/encoding/document.go b/internal/encoding/document.go deleted file mode 100644 index 94819ec10..000000000 --- a/internal/encoding/document.go +++ /dev/null @@ -1,115 +0,0 @@ -package encoding - -import ( - "encoding/binary" - "fmt" - - "github.com/chaisql/chai/internal/object" - "github.com/chaisql/chai/internal/types" -) - -func EncodeObject(dst []byte, d types.Object) ([]byte, error) { - if d == nil { - dst = EncodeObjectLength(dst, 0) - return dst, nil - } - - l, err := object.Length(d) - if err != nil { - return nil, err - } - - // encode the length as a varint - dst = EncodeObjectLength(dst, l) - - fields := make(map[string]struct{}, l) - - err = d.Iterate(func(k string, v types.Value) error { - if _, ok := fields[k]; ok { - return fmt.Errorf("duplicate field %s", k) - } - fields[k] = struct{}{} - - dst = EncodeText(dst, k) - - dst, err = EncodeValue(dst, v, false) - return err - }) - if err != nil { - return nil, err - } - - return dst, nil -} - -func EncodeObjectLength(dst []byte, l int) []byte { - // encode the length as a varint - buf := make([]byte, binary.MaxVarintLen64+1) - buf[0] = ObjectValue - n := binary.PutUvarint(buf[1:], uint64(l)) - return append(dst, buf[:n+1]...) 
-} - -func DecodeObject(b []byte, intAsDouble bool) types.Object { - return &EncodedObject{ - Encoded: b[1:], - intAsDouble: intAsDouble, - } -} - -type EncodedObject struct { - Encoded []byte - intAsDouble bool -} - -func (e *EncodedObject) Iterate(fn func(k string, v types.Value) error) error { - l, n := binary.Uvarint(e.Encoded) - if l == 0 { - return nil - } - b := e.Encoded[n:] - - ll := int(l) - for i := 0; i < ll; i++ { - k, n := DecodeText(b) - b = b[n:] - - v, n := DecodeValue(b, e.intAsDouble) - b = b[n:] - - err := fn(k, v) - if err != nil { - return err - } - } - - return nil -} - -func (e *EncodedObject) GetByField(field string) (types.Value, error) { - l, n := binary.Uvarint(e.Encoded) - if l == 0 { - return nil, types.ErrFieldNotFound - } - b := e.Encoded[n:] - - ll := int(l) - for i := 0; i < ll; i++ { - k, n := DecodeText(b) - b = b[n:] - - if k == field { - v, _ := DecodeValue(b, e.intAsDouble) - return v, nil - } - - n = Skip(b) - b = b[n:] - } - - return nil, types.ErrFieldNotFound -} - -func (e *EncodedObject) MarshalJSON() ([]byte, error) { - return object.MarshalJSON(e) -} diff --git a/internal/encoding/document_test.go b/internal/encoding/document_test.go deleted file mode 100644 index e48895654..000000000 --- a/internal/encoding/document_test.go +++ /dev/null @@ -1,90 +0,0 @@ -package encoding_test - -import ( - "fmt" - "testing" - - "github.com/chaisql/chai/internal/encoding" - "github.com/chaisql/chai/internal/object" - "github.com/chaisql/chai/internal/testutil" - "github.com/chaisql/chai/internal/types" - "github.com/stretchr/testify/require" -) - -func makeByteSlice(b ...byte) []byte { - return b -} - -func mergeByteSlices(b ...[]byte) []byte { - var out []byte - for _, b := range b { - out = append(out, b...) - } - return out -} - -func TestEncodeDecodeObject(t *testing.T) { - tests := []struct { - input types.Object - want [][]byte - wantDoc types.Object - }{ - {testutil.MakeObject(t, `{}`), [][]byte{{byte(encoding.ObjectValue), makeUvarint(0)[0]}}, testutil.MakeObject(t, `{}`)}, - {testutil.MakeObject(t, `{"a": 1}`), [][]byte{ - makeByteSlice(byte(encoding.ObjectValue)), - makeUvarint(1), - encoding.EncodeText(nil, "a"), - encoding.EncodeInt(nil, 1), - }, testutil.MakeObject(t, `{"a": 1.0}`)}, - {testutil.MakeObject(t, `{"a": {"b": 1}, "c": 1}`), [][]byte{ - makeByteSlice(byte(encoding.ObjectValue)), - makeUvarint(2), - encoding.EncodeText(nil, "a"), makeByteSlice(byte(encoding.ObjectValue)), makeUvarint(1), encoding.EncodeText(nil, "b"), encoding.EncodeInt(nil, 1), - encoding.EncodeText(nil, "c"), encoding.EncodeInt(nil, 1), - }, - testutil.MakeObject(t, `{"a": {"b": 1.0}, "c": 1.0}`), - }, - } - - for _, test := range tests { - t.Run(fmt.Sprintf("%s", test.input), func(t *testing.T) { - got, err := encoding.EncodeObject(nil, test.input) - require.NoError(t, err) - - require.Equal(t, mergeByteSlices(test.want...), got) - - x := encoding.DecodeObject(got, true) - testutil.RequireObjEqual(t, test.wantDoc, x) - }) - } -} - -func TestObjectGetByField(t *testing.T) { - tests := []struct { - input types.Object - path object.Path - want types.Value - wantErr error - }{ - {testutil.MakeObject(t, `{}`), object.NewPath("a"), nil, types.ErrFieldNotFound}, - {testutil.MakeObject(t, `{"a": 1}`), object.NewPath("a"), types.NewDoubleValue(1), nil}, - {testutil.MakeObject(t, `{"a": 1}`), object.NewPath("b"), nil, types.ErrFieldNotFound}, - {testutil.MakeObject(t, `{"a": {"b": 1}}`), object.NewPath("a", "b"), types.NewDoubleValue(1), nil}, - } - - for _, test := range 
tests { - t.Run(fmt.Sprintf("%s", test.input), func(t *testing.T) { - got, err := encoding.EncodeObject(nil, test.input) - require.NoError(t, err) - - x := encoding.DecodeObject(got, true) - v, err := test.path.GetValueFromObject(x) - if test.wantErr != nil { - require.Equal(t, test.wantErr, err) - } else { - require.NoError(t, err) - require.Equal(t, test.want, v) - } - }) - } -} diff --git a/internal/encoding/encoding.go b/internal/encoding/encoding.go index 0f2aa42a0..45b96fddd 100644 --- a/internal/encoding/encoding.go +++ b/internal/encoding/encoding.go @@ -1,12 +1,5 @@ package encoding -import ( - "fmt" - "time" - - "github.com/chaisql/chai/internal/types" -) - func EncodeBoolean(dst []byte, x bool) []byte { if x { return append(dst, byte(TrueValue)) @@ -16,76 +9,13 @@ func EncodeBoolean(dst []byte, x bool) []byte { } func DecodeBoolean(b []byte) bool { - return b[0] == byte(TrueValue) + return b[0] == byte(TrueValue) || b[0] == byte(DESC_TrueValue) } func EncodeNull(dst []byte) []byte { return append(dst, byte(NullValue)) } -func EncodeValue(dst []byte, v types.Value, desc bool) ([]byte, error) { - newDst, err := encodeValueAsc(dst, v) - if err != nil { - return nil, err - } - - if desc { - newDst, _ = Desc(newDst, len(newDst)-len(dst)) - } - - return newDst, nil -} - -func encodeValueAsc(dst []byte, v types.Value) ([]byte, error) { - if v.V() == nil { - switch v.Type() { - case types.TypeNull: - return EncodeNull(dst), nil - case types.TypeBoolean: - return EncodeBoolean(dst, false), nil - case types.TypeInteger: - return EncodeInt(dst, 0), nil - case types.TypeDouble: - return EncodeFloat64(dst, 0), nil - case types.TypeTimestamp: - return EncodeTimestamp(dst, time.Time{}), nil - case types.TypeText: - return EncodeText(dst, ""), nil - case types.TypeBlob: - return EncodeBlob(dst, nil), nil - case types.TypeArray: - return EncodeArray(dst, nil) - case types.TypeObject: - return EncodeObject(dst, nil) - default: - panic(fmt.Sprintf("unsupported type %v", v.Type())) - } - } - - switch v.Type() { - case types.TypeNull: - return EncodeNull(dst), nil - case types.TypeBoolean: - return EncodeBoolean(dst, types.AsBool(v)), nil - case types.TypeInteger: - return EncodeInt(dst, types.AsInt64(v)), nil - case types.TypeDouble: - return EncodeFloat64(dst, types.AsFloat64(v)), nil - case types.TypeTimestamp: - return EncodeTimestamp(dst, types.AsTime(v)), nil - case types.TypeText: - return EncodeText(dst, types.AsString(v)), nil - case types.TypeBlob: - return EncodeBlob(dst, types.AsByteSlice(v)), nil - case types.TypeArray: - return EncodeArray(dst, types.AsArray(v)) - case types.TypeObject: - return EncodeObject(dst, types.AsObject(v)) - } - - return nil, fmt.Errorf("unsupported value type: %s", v.Type()) -} - // Desc changes the type of the encoded value to its descending counterpart. // It is meant to be used in combination with one of the Encode* functions. 
// @@ -99,51 +29,3 @@ func Desc(dst []byte, n int) ([]byte, int) { dst[len(dst)-n] = 255 - dst[len(dst)-n] return dst, n } - -func DecodeValue(b []byte, intAsDouble bool) (types.Value, int) { - t := b[0] - // deal with descending values - if t > 128 { - t = 255 - t - } - - if t >= IntSmallValue && t < Uint8Value { - x, n := DecodeInt(b) - if intAsDouble { - return types.NewDoubleValue(float64(x)), n - } - return types.NewIntegerValue(x), n - } - - switch t { - case NullValue: - return types.NewNullValue(), 1 - case FalseValue: - return types.NewBooleanValue(false), 1 - case TrueValue: - return types.NewBooleanValue(true), 1 - case Int8Value, Int16Value, Int32Value, Int64Value, Uint8Value, Uint16Value, Uint32Value, Uint64Value: - x, n := DecodeInt(b) - if intAsDouble { - return types.NewDoubleValue(float64(x)), n - } - return types.NewIntegerValue(x), n - case Float64Value: - x := DecodeFloat64(b[1:]) - return types.NewDoubleValue(x), 9 - case TextValue: - x, n := DecodeText(b) - return types.NewTextValue(x), n - case BlobValue: - x, n := DecodeBlob(b) - return types.NewBlobValue(x), n - case ArrayValue: - a := DecodeArray(b, intAsDouble) - return types.NewArrayValue(a), SkipArray(b[1:]) + 1 - case ObjectValue: - d := DecodeObject(b, intAsDouble) - return types.NewObjectValue(d), SkipObject(b[1:]) + 1 - } - - panic(fmt.Sprintf("unsupported value type: %d", t)) -} diff --git a/internal/encoding/helpers_test.go b/internal/encoding/helpers_test.go index 140bbe7ae..9de7133a8 100644 --- a/internal/encoding/helpers_test.go +++ b/internal/encoding/helpers_test.go @@ -1,6 +1,7 @@ package encoding_test import ( + "encoding/binary" "fmt" "math" "strings" @@ -8,9 +9,9 @@ import ( "github.com/chaisql/chai/internal/encoding" "github.com/chaisql/chai/internal/environment" - "github.com/chaisql/chai/internal/object" "github.com/chaisql/chai/internal/testutil" "github.com/chaisql/chai/internal/tree" + "github.com/chaisql/chai/internal/types" "github.com/stretchr/testify/require" ) @@ -19,130 +20,97 @@ func TestCompare(t *testing.T) { k1, k2 string cmp int }{ - // empty key - {`[]`, `[]`, 0}, - // null - {`[null]`, `[null]`, 0}, - {`[null]`, `[]`, 1}, - {`[]`, `[null]`, -1}, + {`(null)`, `(null)`, 0}, // booleans - {`[true]`, `[true]`, 0}, - {`[false]`, `[true]`, -1}, - {`[true]`, `[false]`, 1}, - {`[false]`, `[false]`, 0}, + {`(true)`, `(true)`, 0}, + {`(false)`, `(true)`, -1}, + {`(true)`, `(false)`, 1}, + {`(false)`, `(false)`, 0}, // ints - {`[1]`, `[1]`, 0}, - {`[1]`, `[2]`, -1}, - {`[2]`, `[1]`, 1}, - {`[1000000000]`, `[1]`, 33}, - {`[254]`, `[255]`, -1}, // 2x uint8 - {`[255]`, `[254]`, 1}, // 2x uint8 - {`[10000]`, `[10001]`, -1}, // 2x uint16 - {`[10001]`, `[10000]`, 1}, // 2x uint16 - {`[1000000]`, `[1000001]`, -1}, // 2x uint32 - {`[1000001]`, `[1000000]`, 1}, // 2x uint32 - {`[1000000000000000]`, `[1000000000000001]`, -1}, // 2x uint64 - {`[1000000000000001]`, `[1000000000000000]`, 1}, // 2x uint64 - {`[-126]`, `[-127]`, 1}, // 2x int8 - {`[-127]`, `[-126]`, -1}, // 2x int8 - {`[-10000]`, `[-10001]`, 1}, // 2x int16 - {`[-10001]`, `[-10000]`, -1}, // 2x int16 - {`[-1000000]`, `[-1000001]`, 1}, // 2x int32 - {`[-1000001]`, `[-1000000]`, -1}, // 2x int32 - {`[-1000000000000000]`, `[-1000000000000001]`, 1}, // 2x int64 - {`[-1000000000000001]`, `[-1000000000000000]`, -1}, // 2x int64 - {`[-1]`, `[1]`, -2}, // neg fixint < fixuint - {`[1]`, `[31]`, -30}, // neg fixint < fixuint - {`[-127]`, `[1]`, -34}, // int8 < fixuint - {`[-10000]`, `[1]`, -35}, // int16 < fixuint - {`[-1000000]`, `[1]`, -36}, // 
int32 < fixuint - {`[-1000000000000000]`, `[1]`, -37}, // int64 < fixuint - {`[-127]`, `[255]`, -65}, // int8 < uint8 - {`[-60000]`, `[255]`, -67}, // int16 < uint8 - {`[-1000000]`, `[255]`, -67}, // int32 < uint8 - {`[-1000000000000000]`, `[255]`, -68}, // int64 < uint8 + {`(1)`, `(1)`, 0}, + {`(1)`, `(2)`, -1}, + {`(2)`, `(1)`, 1}, + {`(1000000000)`, `(1)`, 33}, + {`(254)`, `(255)`, -1}, // 2x uint8 + {`(255)`, `(254)`, 1}, // 2x uint8 + {`(10000)`, `(10001)`, -1}, // 2x uint16 + {`(10001)`, `(10000)`, 1}, // 2x uint16 + {`(1000000)`, `(1000001)`, -1}, // 2x uint32 + {`(1000001)`, `(1000000)`, 1}, // 2x uint32 + {`(1000000000000000)`, `(1000000000000001)`, -1}, // 2x uint64 + {`(1000000000000001)`, `(1000000000000000)`, 1}, // 2x uint64 + {`(-126)`, `(-127)`, 1}, // 2x int8 + {`(-127)`, `(-126)`, -1}, // 2x int8 + {`(-10000)`, `(-10001)`, 1}, // 2x int16 + {`(-10001)`, `(-10000)`, -1}, // 2x int16 + {`(-1000000)`, `(-1000001)`, 1}, // 2x int32 + {`(-1000001)`, `(-1000000)`, -1}, // 2x int32 + {`(-1000000000000000)`, `(-1000000000000001)`, 1}, // 2x int64 + {`(-1000000000000001)`, `(-1000000000000000)`, -1}, // 2x int64 + {`(-1)`, `(1)`, -2}, // neg fixint < fixuint + {`(1)`, `(31)`, -30}, // neg fixint < fixuint + {`(-127)`, `(1)`, -34}, // int8 < fixuint + {`(-10000)`, `(1)`, -35}, // int16 < fixuint + {`(-1000000)`, `(1)`, -36}, // int32 < fixuint + {`(-1000000000000000)`, `(1)`, -37}, // int64 < fixuint + {`(-127)`, `(255)`, -65}, // int8 < uint8 + {`(-60000)`, `(255)`, -67}, // int16 < uint8 + {`(-1000000)`, `(255)`, -67}, // int32 < uint8 + {`(-1000000000000000)`, `(255)`, -68}, // int64 < uint8 // floats - {`[1.0]`, `[1.0]`, 0}, - {`[1.1]`, `[1.0]`, 1}, - {`[1.0]`, `[1.1]`, -1}, - {`[-1.0]`, `[-1.1]`, 1}, + {`(1.0)`, `(1.0)`, 0}, + {`(1.1)`, `(1.0)`, 1}, + {`(1.0)`, `(1.1)`, -1}, + {`(-1.0)`, `(-1.1)`, 1}, // doubles - {`[1e50]`, `[1e50]`, 0}, - {`[1e51]`, `[1e50]`, 1}, - {`[1e50]`, `[1e51]`, -1}, - {`[-1e50]`, `[-1e51]`, 1}, + {`(1e50)`, `(1e50)`, 0}, + {`(1e51)`, `(1e50)`, 1}, + {`(1e50)`, `(1e51)`, -1}, + {`(-1e50)`, `(-1e51)`, 1}, // floats and doubles - {`[1.0]`, `[1e50]`, -1}, - {`[1e50]`, `[1.0]`, 1}, + {`(1.0)`, `(1e50)`, -1}, + {`(1e50)`, `(1.0)`, 1}, // text - {`["a"]`, `["a"]`, 0}, - {`["b"]`, `["a"]`, 1}, - {`["a"]`, `["b"]`, -1}, - {`["a"]`, `["aa"]`, -1}, - {`["aaaa"]`, `["aab"]`, -1}, + {`("a")`, `("a")`, 0}, + {`("b")`, `("a")`, 1}, + {`("a")`, `("b")`, -1}, + {`("a")`, `("aa")`, -1}, + {`("aaaa")`, `("aab")`, -1}, // blob - {`["\xaa"]`, `["\xaa"]`, 0}, - {`["\xab"]`, `["\xaa"]`, 1}, - {`["\xaa"]`, `["\xab"]`, -1}, - {`["\xaa"]`, `["\xaaaa"]`, -1}, - - // arrays - {`[[]]`, `[[]]`, 0}, - {`[[1]]`, `[[]]`, 1}, - {`[[]]`, `[[1]]`, -1}, - {`[[1]]`, `[[1]]`, 0}, - {`[[1]]`, `[[1, 1]]`, -1}, - // array with more than 128 elements, i.e. 
2 bytes for length - {`[[1, 2, 3]]`, `[[1, 2, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]`, -1}, - - // maps - {`[{"a": 2}]`, `[{"a": 2}]`, 0}, - {`[{"a": 1}]`, `[{"a": 2}]`, -1}, - {`[{"a": 2}]`, `[{"a": 1}]`, 1}, - {`[{"a": 1}]`, `[{"b": 1}]`, -1}, - {`[{"b": 1}]`, `[{"a": 1}]`, 1}, - {`[{"a": 1}]`, `[{"a": 1, "b": 1}]`, -1}, - {`[{"a": 1, "b": 1}]`, `[{"a": 1}]`, 1}, - {`[{"a": 1, "b": 1}]`, `[{"a": 1, "b": 1}]`, 0}, - {`[{"a": 1, "b": 1}]`, `[{"a": 1, "b": 2}]`, -1}, - {`[{"a": 1, "b": 2}]`, `[{"a": 1, "b": 1}]`, 1}, - {`[{"a": {"c": [1, 2]}}]`, `[{"a": {"c": [1, 2]}}]`, 0}, - {`[{"a": {"c": [1, 3]}}]`, `[{"a": {"c": [1, 2]}}]`, 1}, - {`[{"a": {"c": []}}]`, `[{"a": {"c": [1, 2]}}]`, -1}, + {`("\xaa")`, `("\xaa")`, 0}, + {`("\xab")`, `("\xaa")`, 1}, + {`("\xaa")`, `("\xab")`, -1}, + {`("\xaa")`, `("\xaaaa")`, -1}, // different types - {`[null]`, `[true]`, -4}, - {`[true]`, `[1]`, -43}, - {`[1]`, `[1.0]`, -41}, - {`[1.0]`, `["a"]`, -8}, - {`["a"]`, `["\x00"]`, -5}, - {`["\x00"]`, `[[]]`, -7}, - {`[[]]`, `[{}]`, -10}, + {`(null)`, `(true)`, -4}, + {`(true)`, `(1)`, -43}, + {`(1)`, `(1.0)`, -41}, + {`(1.0)`, `("a")`, -8}, + {`("a")`, `("\x00")`, -5}, // consecutive values - {`[1, 2, 3]`, `[1, 2, 3]`, 0}, - {`[1, 2, 3]`, `[1, 2, 4]`, -1}, - {`[1, 2, 3]`, `[1, 2, 3, 4]`, -1}, + {`(1, 2, 3)`, `(1, 2, 3)`, 0}, + {`(1, 2, 3)`, `(1, 2, 4)`, -1}, + {`(1, 2, 3)`, `(1, 2, 3, 4)`, -1}, // consecutive mixed values - {`[1, true, 3.4, []]`, `[1, true, 3.4, []]`, 0}, + {`(1, true, 3.4)`, `(1, true, 3.4)`, 0}, } for _, test := range tests { t.Run(fmt.Sprintf("Compare(%s, %s)", test.k1, test.k2), func(t *testing.T) { - v1, err := testutil.ParseExpr(t, test.k1).Eval(&environment.Environment{}) + a1, err := testutil.ParseExprList(t, test.k1).EvalAll(&environment.Environment{}) require.NoError(t, err) - a1 := v1.V().(*object.ValueBuffer).Values k1 := mustNewKey(t, 0, 0, a1...) - v2, err := testutil.ParseExpr(t, test.k2).Eval(&environment.Environment{}) + a2, err := testutil.ParseExprList(t, test.k2).EvalAll(&environment.Environment{}) require.NoError(t, err) - a2 := v2.V().(*object.ValueBuffer).Values k2 := mustNewKey(t, 0, 0, a2...) 
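// Aside: the expected cmp signs in this table rely on the key encoding being
// order-preserving, i.e. a one-byte type tag followed by a big-endian payload, so that a
// plain bytewise comparison of two encoded keys agrees with the semantic ordering of the
// values. A minimal sketch of that idea using only the standard library (encodeTagged is a
// hypothetical helper for illustration, not chai's API):
//
//	func encodeTagged(tag byte, x uint16) []byte {
//		b := []byte{tag, 0, 0}
//		binary.BigEndian.PutUint16(b[1:], x) // big-endian keeps numeric order under bytes.Compare
//		return b
//	}
//
//	// bytes.Compare(encodeTagged(0x20, 10000), encodeTagged(0x20, 10001)) < 0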
require.Equal(t, test.cmp, encoding.Compare(k1, k2)) @@ -172,21 +140,20 @@ func TestCompareOrder(t *testing.T) { order []bool }{ // null - {`[null]`, `[null]`, 0, []bool{false}}, - {`[null]`, `[null]`, 0, []bool{true}}, - {`[1]`, `[2]`, -1, []bool{false}}, - {`[1]`, `[2]`, 1, []bool{true}}, - {`[60]`, `[30]`, 2, []bool{false}}, - {`[60]`, `[30]`, -2, []bool{true}}, - {`[30, "hello"]`, `[30, "bye"]`, 1, []bool{true, false}}, - {`[30, "hello"]`, `[30, "bye"]`, -1, []bool{true, true}}, + {`(null)`, `(null)`, 0, []bool{false}}, + {`(null)`, `(null)`, 0, []bool{true}}, + {`(1)`, `(2)`, -1, []bool{false}}, + {`(1)`, `(2)`, 1, []bool{true}}, + {`(60)`, `(30)`, 2, []bool{false}}, + {`(60)`, `(30)`, -2, []bool{true}}, + {`(30, "hello")`, `(30, "bye")`, 1, []bool{true, false}}, + {`(30, "hello")`, `(30, "bye")`, -1, []bool{true, true}}, } for _, test := range tests { t.Run(fmt.Sprintf("CompareOrder(%s, %s)", test.k1, test.k2), func(t *testing.T) { - v1, err := testutil.ParseExpr(t, test.k1).Eval(&environment.Environment{}) + a1, err := testutil.ParseExprList(t, test.k1).EvalAll(&environment.Environment{}) require.NoError(t, err) - a1 := v1.V().(*object.ValueBuffer).Values order := tree.SortOrder(0) for i := range a1 { if test.order[i] { @@ -196,9 +163,8 @@ func TestCompareOrder(t *testing.T) { k1 := mustNewKey(t, 0, order, a1...) - v2, err := testutil.ParseExpr(t, test.k2).Eval(&environment.Environment{}) + a2, err := testutil.ParseExprList(t, test.k2).EvalAll(&environment.Environment{}) require.NoError(t, err) - a2 := v2.V().(*object.ValueBuffer).Values k2 := mustNewKey(t, 0, order, a2...) require.Equal(t, test.cmp, encoding.Compare(k1, k2)) @@ -230,71 +196,59 @@ func TestAbbreviatedKey(t *testing.T) { k string want uint64 }{ - // empty key - {`[]`, 0}, // namespace only - {`[1]`, 0b_0000000000000001_00000000_0000000000000000000000000000000000000000}, - {`[400]`, 0b_0000000110010000_00000000_0000000000000000000000000000000000000000}, - {`[1000000]`, 0b_1111111111111111_00000000_0000000000000000000000000000000000000000}, // > 1 << 16 + {`(1)`, 0b_0000000000000001_00000000_0000000000000000000000000000000000000000}, + {`(400)`, 0b_0000000110010000_00000000_0000000000000000000000000000000000000000}, + {`(1000000)`, 0b_1111111111111111_00000000_0000000000000000000000000000000000000000}, // > 1 << 16 // null - {`[1, null]`, 1<<48 | uint64(encoding.NullValue)<<40}, + {`(1, null)`, 1<<48 | uint64(encoding.NullValue)<<40}, // bool - {`[1, false]`, 1<<48 | uint64(encoding.FalseValue)<<40}, - {`[1, true]`, 1<<48 | uint64(encoding.TrueValue)<<40}, + {`(1, false)`, 1<<48 | uint64(encoding.FalseValue)<<40}, + {`(1, true)`, 1<<48 | uint64(encoding.TrueValue)<<40}, // int - {`[1, 1]`, 1<<48 | (uint64(encoding.IntSmallValue)+32+1)<<40}, // positive int -> small value - {`[1, -10]`, 1<<48 | (uint64(encoding.IntSmallValue)+32-10)<<40}, // negative int -> small value - {`[1, 31]`, 1<<48 | (uint64(encoding.IntSmallValue)+32+31)<<40}, // positive int -> small value - {`[1, 100]`, 1<<48 | uint64(encoding.Uint8Value)<<40 | 100}, // positive int -> uint8 - {`[1, 128]`, 1<<48 | uint64(encoding.Uint8Value)<<40 | 128}, // positive int -> uint8 - {`[1, 255]`, 1<<48 | uint64(encoding.Uint8Value)<<40 | 255}, // positive int -> uint8 - {`[1, 256]`, 1<<48 | uint64(encoding.Uint16Value)<<40 | 256}, // positive int -> uint16 - {`[1, 999]`, 1<<48 | uint64(encoding.Uint16Value)<<40 | 999}, // positive int -> uint16 - {`[1, -5000000000]`, 1<<48 | uint64(encoding.Int64Value)<<40 | (uint64(i64)+math.MaxInt64+1)>>24}, // int64 - {`[1, 
-60000000]`, 1<<48 | uint64(encoding.Int32Value)<<40 | (uint64(i32) + math.MaxInt32 + 1)}, // int32 - {`[1, -10000]`, 1<<48 | uint64(encoding.Int16Value)<<40 | (uint64(i16) + math.MaxInt16 + 1)}, // int16 - {`[1, -127]`, 1<<48 | uint64(encoding.Int8Value)<<40 | (uint64(i8) + math.MaxInt8 + 1)}, // int8 - {`[1, 255]`, 1<<48 | uint64(encoding.Uint8Value)<<40 | 255}, // uint8 - {`[1, 50000]`, 1<<48 | uint64(encoding.Uint16Value)<<40 | 50000}, // uint16 - {`[1, 500000]`, 1<<48 | uint64(encoding.Uint32Value)<<40 | 500000}, // uint32 - {`[1, 5000000000]`, 1<<48 | uint64(encoding.Uint64Value)<<40 | 5000000000>>24}, // uint64 + {`(1, 1)`, 1<<48 | (uint64(encoding.IntSmallValue)+32+1)<<40}, // positive int -> small value + {`(1, -10)`, 1<<48 | (uint64(encoding.IntSmallValue)+32-10)<<40}, // negative int -> small value + {`(1, 31)`, 1<<48 | (uint64(encoding.IntSmallValue)+32+31)<<40}, // positive int -> small value + {`(1, 100)`, 1<<48 | uint64(encoding.Uint8Value)<<40 | 100}, // positive int -> uint8 + {`(1, 128)`, 1<<48 | uint64(encoding.Uint8Value)<<40 | 128}, // positive int -> uint8 + {`(1, 255)`, 1<<48 | uint64(encoding.Uint8Value)<<40 | 255}, // positive int -> uint8 + {`(1, 256)`, 1<<48 | uint64(encoding.Uint16Value)<<40 | 256}, // positive int -> uint16 + {`(1, 999)`, 1<<48 | uint64(encoding.Uint16Value)<<40 | 999}, // positive int -> uint16 + {`(1, -5000000000)`, 1<<48 | uint64(encoding.Int64Value)<<40 | (uint64(i64)+math.MaxInt64+1)>>24}, // int64 + {`(1, -60000000)`, 1<<48 | uint64(encoding.Int32Value)<<40 | (uint64(i32) + math.MaxInt32 + 1)}, // int32 + {`(1, -10000)`, 1<<48 | uint64(encoding.Int16Value)<<40 | (uint64(i16) + math.MaxInt16 + 1)}, // int16 + {`(1, -127)`, 1<<48 | uint64(encoding.Int8Value)<<40 | (uint64(i8) + math.MaxInt8 + 1)}, // int8 + {`(1, 255)`, 1<<48 | uint64(encoding.Uint8Value)<<40 | 255}, // uint8 + {`(1, 50000)`, 1<<48 | uint64(encoding.Uint16Value)<<40 | 50000}, // uint16 + {`(1, 500000)`, 1<<48 | uint64(encoding.Uint32Value)<<40 | 500000}, // uint32 + {`(1, 5000000000)`, 1<<48 | uint64(encoding.Uint64Value)<<40 | 5000000000>>24}, // uint64 // float / double - {`[1, 1.0]`, 1<<48 | uint64(encoding.Float64Value)<<40 | uint64(math.Float64bits(1)^(1<<63))>>24}, - {`[1, -1.0]`, 1<<48 | uint64(encoding.Float64Value)<<40 | uint64(math.Float64bits(-1)^(1<<64-1))>>24}, - {`[1, 1e50]`, 1<<48 | uint64(encoding.Float64Value)<<40 | uint64(math.Float64bits(1e50)^1<<63)>>24}, - {`[1, -1e50]`, 1<<48 | uint64(encoding.Float64Value)<<40 | uint64(math.Float64bits(-1e50)^(1<<64-1))>>24}, + {`(1, 1.0)`, 1<<48 | uint64(encoding.Float64Value)<<40 | uint64(math.Float64bits(1)^(1<<63))>>24}, + {`(1, -1.0)`, 1<<48 | uint64(encoding.Float64Value)<<40 | uint64(math.Float64bits(-1)^(1<<64-1))>>24}, + {`(1, 1e50)`, 1<<48 | uint64(encoding.Float64Value)<<40 | uint64(math.Float64bits(1e50)^1<<63)>>24}, + {`(1, -1e50)`, 1<<48 | uint64(encoding.Float64Value)<<40 | uint64(math.Float64bits(-1e50)^(1<<64-1))>>24}, // text - {`[1, "abc"]`, 1<<48 | uint64(encoding.TextValue)<<40 | uint64('a')<<32 | uint64('b')<<24 | uint64('c')<<16}, - {`[1, "abcdefghijkl"]`, 1<<48 | uint64(encoding.TextValue)<<40 | uint64('a')<<32 | uint64('b')<<24 | uint64('c')<<16 | uint64('d')<<8 | uint64('e')}, - {`[1, "abcdefghijkl` + strings.Repeat("m", 100) + `"]`, 1<<48 | uint64(encoding.TextValue)<<40 | uint64('a')<<32 | uint64('b')<<24 | uint64('c')<<16 | uint64('d')<<8 | uint64('e')}, - {`[1, "abcdefghijkl` + strings.Repeat("m", 10000) + `"]`, 1<<48 | uint64(encoding.TextValue)<<40 | uint64('a')<<32 | uint64('b')<<24 | 
uint64('c')<<16 | uint64('d')<<8 | uint64('e')}, + {`(1, "abc")`, 1<<48 | uint64(encoding.TextValue)<<40 | uint64('a')<<32 | uint64('b')<<24 | uint64('c')<<16}, + {`(1, "abcdefghijkl")`, 1<<48 | uint64(encoding.TextValue)<<40 | uint64('a')<<32 | uint64('b')<<24 | uint64('c')<<16 | uint64('d')<<8 | uint64('e')}, + {`(1, "abcdefghijkl` + strings.Repeat("m", 100) + `")`, 1<<48 | uint64(encoding.TextValue)<<40 | uint64('a')<<32 | uint64('b')<<24 | uint64('c')<<16 | uint64('d')<<8 | uint64('e')}, + {`(1, "abcdefghijkl` + strings.Repeat("m", 10000) + `")`, 1<<48 | uint64(encoding.TextValue)<<40 | uint64('a')<<32 | uint64('b')<<24 | uint64('c')<<16 | uint64('d')<<8 | uint64('e')}, // blob - {`[1, "\xab"]`, 1<<48 | uint64(encoding.BlobValue)<<40 | uint64(0xab)<<32}, - {`[1, "\xabcdefabcdef"]`, 1<<48 | uint64(encoding.BlobValue)<<40 | uint64(0xab)<<32 | uint64(0xcd)<<24 | uint64(0xef)<<16 | uint64(0xab)<<8 | uint64(0xcd)}, - {`[1, "\xabcdefabcdef` + strings.Repeat("c", 100) + `"]`, 1<<48 | uint64(encoding.BlobValue)<<40 | uint64(0xab)<<32 | uint64(0xcd)<<24 | uint64(0xef)<<16 | uint64(0xab)<<8 | uint64(0xcd)}, - {`[1, "\xabcdefabcdef` + strings.Repeat("c", 1000) + `"]`, 1<<48 | uint64(encoding.BlobValue)<<40 | uint64(0xab)<<32 | uint64(0xcd)<<24 | uint64(0xef)<<16 | uint64(0xab)<<8 | uint64(0xcd)}, - - // array - {`[1, []]`, 1<<48 | uint64(encoding.ArrayValue)<<40}, - {`[1, [1, 1]]`, 1<<48 | uint64(encoding.ArrayValue)<<40 | (uint64(encoding.IntSmallValue)+32+1)<<32}, - {`[1, [[]]]`, 1<<48 | uint64(encoding.ArrayValue)<<40 | uint64(encoding.ArrayValue)<<32}, - // doc - {`[1, {}]`, 1<<48 | uint64(encoding.ObjectValue)<<40}, - {`[1, {a: 1}]`, 1<<48 | uint64(encoding.ObjectValue)<<40 | uint64(encoding.TextValue)<<32 | uint64('a')<<24}, - {`[1, {a: 2}]`, 1<<48 | uint64(encoding.ObjectValue)<<40 | uint64(encoding.TextValue)<<32 | uint64('a')<<24}, + {`(1, "\xab")`, 1<<48 | uint64(encoding.BlobValue)<<40 | uint64(0xab)<<32}, + {`(1, "\xabcdefabcdef")`, 1<<48 | uint64(encoding.BlobValue)<<40 | uint64(0xab)<<32 | uint64(0xcd)<<24 | uint64(0xef)<<16 | uint64(0xab)<<8 | uint64(0xcd)}, + {`(1, "\xabcdefabcdef` + strings.Repeat("c", 100) + `")`, 1<<48 | uint64(encoding.BlobValue)<<40 | uint64(0xab)<<32 | uint64(0xcd)<<24 | uint64(0xef)<<16 | uint64(0xab)<<8 | uint64(0xcd)}, + {`(1, "\xabcdefabcdef` + strings.Repeat("c", 1000) + `")`, 1<<48 | uint64(encoding.BlobValue)<<40 | uint64(0xab)<<32 | uint64(0xcd)<<24 | uint64(0xef)<<16 | uint64(0xab)<<8 | uint64(0xcd)}, } for _, test := range tests { t.Run(fmt.Sprintf("AbbreviatedKey(%s)", test.k), func(t *testing.T) { - v, err := testutil.ParseExpr(t, test.k).Eval(&environment.Environment{}) + a, err := testutil.ParseExprList(t, test.k).EvalAll(&environment.Environment{}) require.NoError(t, err) - a := v.V().(*object.ValueBuffer).Values k := mustNewKey(t, 0, 0, a...) 
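// Aside: the expected values in this table all follow the same fixed bit layout for the
// abbreviated key: bits 63-48 carry the namespace, bits 47-40 the type tag of the first
// user value, and the low 40 bits a prefix of its payload. A minimal sketch of that packing
// (abbrev is a hypothetical helper for illustration, not chai's API):
//
//	func abbrev(namespace uint16, tag byte, payload uint64) uint64 {
//		return uint64(namespace)<<48 | uint64(tag)<<40 | payload&(1<<40-1)
//	}
//
//	// e.g. abbrev(1, byte(encoding.Uint8Value), 255) == 1<<48 | uint64(encoding.Uint8Value)<<40 | 255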
require.Equal(t, test.want, encoding.AbbreviatedKey(k)) @@ -306,21 +260,36 @@ func TestSeparator(t *testing.T) { tests := []struct { k1, k2 string }{ - {`[1, 1]`, `[1, 2]`}, - {`[1, 1]`, `[1, 3]`}, + {`(1, 1)`, `(1, 2)`}, + {`(1, 1)`, `(1, 3)`}, } for _, test := range tests { t.Run(fmt.Sprintf("Separator(%v, %v)", test.k1, test.k2), func(t *testing.T) { - v1, err := testutil.ParseExpr(t, test.k1).Eval(&environment.Environment{}) + v1, err := testutil.ParseExprList(t, test.k1).EvalAll(&environment.Environment{}) require.NoError(t, err) - v2, err := testutil.ParseExpr(t, test.k2).Eval(&environment.Environment{}) + v2, err := testutil.ParseExprList(t, test.k2).EvalAll(&environment.Environment{}) require.NoError(t, err) - k1 := mustNewKey(t, 0, 0, v1.V().(*object.ValueBuffer).Values...) - k2 := mustNewKey(t, 0, 0, v2.V().(*object.ValueBuffer).Values...) + k1 := mustNewKey(t, 0, 0, v1...) + k2 := mustNewKey(t, 0, 0, v2...) sep := encoding.Separator(nil, k1, k2) require.LessOrEqual(t, encoding.Compare(k1, sep), 0) require.Less(t, encoding.Compare(sep, k2), 0) }) } } + +func makeUvarint(n int) []byte { + var buf [10]byte + i := binary.PutUvarint(buf[:], uint64(n)) + return buf[:i] +} + +func mustNewKey(t testing.TB, namespace tree.Namespace, order tree.SortOrder, values ...types.Value) []byte { + k := tree.NewKey(values...) + + b, err := k.Encode(namespace, order) + require.NoError(t, err) + + return b +} diff --git a/internal/encoding/times.go b/internal/encoding/times.go index 94fd869ee..d28e50776 100644 --- a/internal/encoding/times.go +++ b/internal/encoding/times.go @@ -6,27 +6,27 @@ import ( ) var ( - epoch = time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC).UnixMicro() - maxTime = math.MaxInt64 - epoch - minTime = math.MinInt64 + epoch + Epoch = time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC).UnixMicro() + MaxTime = math.MaxInt64 - Epoch + MinTime = math.MinInt64 + Epoch ) func EncodeTimestamp(dst []byte, t time.Time) []byte { x := t.UnixMicro() - if x > maxTime || x < minTime { + if x > MaxTime || x < MinTime { panic("timestamp out of range") } - diff := x - epoch + diff := x - Epoch return EncodeInt(dst, diff) } func DecodeTimestamp(b []byte) (time.Time, int) { x, n := DecodeInt(b) - return time.UnixMicro(epoch + x).UTC(), n + return time.UnixMicro(Epoch + x).UTC(), n } func ConvertToTimestamp(x int64) time.Time { - return time.UnixMicro(epoch + x).UTC() + return time.UnixMicro(Epoch + x).UTC() } diff --git a/internal/encoding/times_test.go b/internal/encoding/times_test.go index 6bed75b1d..09b6b11c4 100644 --- a/internal/encoding/times_test.go +++ b/internal/encoding/times_test.go @@ -1,10 +1,11 @@ -package encoding +package encoding_test import ( "math" "testing" "time" + "github.com/chaisql/chai/internal/encoding" "github.com/stretchr/testify/require" ) @@ -19,51 +20,51 @@ func TestEncodeTimestamp(t *testing.T) { "epoch", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), - EncodeInt(nil, 0), + encoding.EncodeInt(nil, 0), }, { "nanosecond-precision-loss", time.Date(2000, 1, 1, 0, 0, 0, 1, time.UTC), time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), - EncodeInt(nil, 0), + encoding.EncodeInt(nil, 0), }, { "microsecond-precision", time.Date(2000, 1, 1, 0, 0, 0, 1000, time.UTC), time.Date(2000, 1, 1, 0, 0, 0, 1000, time.UTC), - EncodeInt(nil, 1), + encoding.EncodeInt(nil, 1), }, { "minute", time.Date(2000, 1, 1, 0, 1, 0, 0, time.UTC), time.Date(2000, 1, 1, 0, 1, 0, 0, time.UTC), - EncodeInt(nil, 60_000_000), + encoding.EncodeInt(nil, 60_000_000), }, { 
"negative-minute", time.Date(1999, 12, 31, 23, 59, 0, 0, time.UTC), time.Date(1999, 12, 31, 23, 59, 0, 0, time.UTC), - EncodeInt(nil, -60_000_000), + encoding.EncodeInt(nil, -60_000_000), }, { "max-date", time.Date(294_217, 1, 10, 4, 0, 54, 775_807_000, time.UTC), time.Date(294_217, 1, 10, 4, 0, 54, 775_807_000, time.UTC), - EncodeInt(nil, math.MaxInt64-epoch-epoch), + encoding.EncodeInt(nil, math.MaxInt64-encoding.Epoch-encoding.Epoch), }, { "min-date", time.Date(-290_278, 12, 22, 19, 59, 05, 224_192_000, time.UTC), time.Date(-290_278, 12, 22, 19, 59, 05, 224_192_000, time.UTC), - EncodeInt(nil, math.MinInt64), + encoding.EncodeInt(nil, math.MinInt64), }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - enc := EncodeTimestamp(nil, test.t) + enc := encoding.EncodeTimestamp(nil, test.t) require.Equal(t, test.enc, enc) - ts, _ := DecodeTimestamp(enc) + ts, _ := encoding.DecodeTimestamp(enc) require.Equal(t, test.dec, ts) }) } diff --git a/internal/environment/env.go b/internal/environment/env.go index 9d5006979..2df882122 100644 --- a/internal/environment/env.go +++ b/internal/environment/env.go @@ -4,7 +4,7 @@ import ( "fmt" "github.com/chaisql/chai/internal/database" - "github.com/chaisql/chai/internal/object" + "github.com/chaisql/chai/internal/row" "github.com/chaisql/chai/internal/types" ) @@ -21,13 +21,11 @@ type Param struct { // the expression is evaluated. type Environment struct { Params []Param - Vars *object.FieldBuffer - Row database.Row + Vars *row.ColumnBuffer + Row row.Row DB *database.Database Tx *database.Transaction - baseRow database.BasicRow - Outer *Environment } @@ -48,30 +46,30 @@ func (e *Environment) SetOuter(env *Environment) { e.Outer = env } -func (e *Environment) Get(path object.Path) (v types.Value, ok bool) { +func (e *Environment) Get(column string) (v types.Value, ok bool) { if e.Vars != nil { - v, err := path.GetValueFromObject(e.Vars) + v, err := e.Vars.Get(column) if err == nil { return v, true } } if e.Outer != nil { - return e.Outer.Get(path) + return e.Outer.Get(column) } return types.NewNullValue(), false } -func (e *Environment) Set(path object.Path, v types.Value) { +func (e *Environment) Set(column string, v types.Value) { if e.Vars == nil { - e.Vars = object.NewFieldBuffer() + e.Vars = row.NewColumnBuffer() } - e.Vars.Set(path, v) + e.Vars.Set(column, v) } -func (e *Environment) GetRow() (database.Row, bool) { +func (e *Environment) GetRow() (row.Row, bool) { if e.Row != nil { return e.Row, true } @@ -83,13 +81,21 @@ func (e *Environment) GetRow() (database.Row, bool) { return nil, false } -func (e *Environment) SetRow(r database.Row) { - e.Row = r +func (e *Environment) GetDatabaseRow() (database.Row, bool) { + if e.Row != nil { + r, ok := e.Row.(database.Row) + return r, ok + } + + if e.Outer != nil { + return e.Outer.GetDatabaseRow() + } + + return nil, false } -func (e *Environment) SetRowFromObject(o types.Object) { - e.baseRow.ResetWith("", nil, o) - e.Row = &e.baseRow +func (e *Environment) SetRow(r row.Row) { + e.Row = r } func (e *Environment) SetParams(params []Param) { @@ -105,7 +111,7 @@ func (e *Environment) GetParamByName(name string) (v types.Value, err error) { for _, nv := range e.Params { if nv.Name == name { - return object.NewValue(nv.Value) + return row.NewValue(nv.Value) } } @@ -124,7 +130,7 @@ func (e *Environment) GetParamByIndex(pos int) (types.Value, error) { return nil, fmt.Errorf("cannot find param number %d", pos) } - return object.NewValue(e.Params[idx].Value) + return 
row.NewValue(e.Params[idx].Value) } func (e *Environment) GetTx() *database.Transaction { diff --git a/internal/errors/errors.go b/internal/errors/errors.go index a1e3606a5..a0150636a 100644 --- a/internal/errors/errors.go +++ b/internal/errors/errors.go @@ -31,23 +31,23 @@ func IsAlreadyExistsError(err error) bool { // NotFoundError is returned when the requested table, index or sequence // doesn't exist. type NotFoundError struct { - Name string + name string } func NewRowNotFoundError() error { - return NewNotFoundError("row") + return errors.WithStack(NewNotFoundError("row")) } func NewNotFoundError(name string) error { - return &NotFoundError{Name: name} + return errors.WithStack(&NotFoundError{name: name}) } func (a NotFoundError) Error() string { - if a.Name == "row" { + if a.name == "row" { return "row not found" } - return fmt.Sprintf("%q not found", a.Name) + return fmt.Sprintf("%q not found", a.name) } func IsNotFoundError(err error) bool { diff --git a/internal/expr/arithmeric.go b/internal/expr/arithmeric.go index 0742420d4..362bf2ab1 100644 --- a/internal/expr/arithmeric.go +++ b/internal/expr/arithmeric.go @@ -40,12 +40,25 @@ func (op *arithmeticOperator) Eval(env *environment.Environment) (types.Value, e return a.Div(b) case scanner.MOD: return a.Mod(b) + } + + ia, ok := a.(types.Integral) + if !ok { + return NullLiteral, nil + } + + _, ok = b.(types.Integral) + if !ok { + return NullLiteral, nil + } + + switch op.simpleOperator.Tok { case scanner.BITWISEAND: - return a.BitwiseAnd(b) + return ia.BitwiseAnd(b) case scanner.BITWISEOR: - return a.BitwiseOr(b) + return ia.BitwiseOr(b) case scanner.BITWISEXOR: - return a.BitwiseXor(b) + return ia.BitwiseXor(b) } panic("unknown arithmetic token") diff --git a/internal/expr/arithmetic_test.go b/internal/expr/arithmetic_test.go index 8996881ba..ff95db8f9 100644 --- a/internal/expr/arithmetic_test.go +++ b/internal/expr/arithmetic_test.go @@ -5,7 +5,6 @@ import ( "testing" "github.com/chaisql/chai/internal/expr" - "github.com/chaisql/chai/internal/object" "github.com/chaisql/chai/internal/testutil/assert" "github.com/chaisql/chai/internal/types" "github.com/stretchr/testify/require" @@ -26,12 +25,10 @@ func TestValueAdd(t *testing.T) { {"integer(120)+integer(120)", types.NewIntegerValue(120), types.NewIntegerValue(120), types.NewIntegerValue(240), false}, {"integer(120)+float64(120)", types.NewIntegerValue(120), types.NewDoubleValue(120), types.NewDoubleValue(240), false}, {"integer(120)+float64(120.1)", types.NewIntegerValue(120), types.NewDoubleValue(120.1), types.NewDoubleValue(240.1), false}, - {"int64(max)+integer(10)", types.NewIntegerValue(math.MaxInt64), types.NewIntegerValue(10), types.NewDoubleValue(math.MaxInt64 + 10), false}, - {"int64(min)+integer(-10)", types.NewIntegerValue(math.MinInt64), types.NewIntegerValue(-10), types.NewDoubleValue(math.MinInt64 - 10), false}, + {"int64(max)+integer(10)", types.NewBigintValue(math.MaxInt64), types.NewIntegerValue(10), nil, true}, + {"int64(min)+integer(-10)", types.NewBigintValue(math.MinInt64), types.NewIntegerValue(-10), nil, true}, {"integer(120)+text('120')", types.NewIntegerValue(120), types.NewTextValue("120"), types.NewNullValue(), false}, {"text('120')+text('120')", types.NewTextValue("120"), types.NewTextValue("120"), types.NewNullValue(), false}, - {"object+object", types.NewObjectValue(object.NewFieldBuffer().Add("a", types.NewIntegerValue(10))), types.NewObjectValue(object.NewFieldBuffer().Add("a", types.NewIntegerValue(10))), types.NewNullValue(), false}, - 
{"array+array", types.NewArrayValue(object.NewValueBuffer(types.NewIntegerValue(10))), types.NewArrayValue(object.NewValueBuffer(types.NewIntegerValue(10))), types.NewNullValue(), false}, } for _, test := range tests { @@ -64,12 +61,10 @@ func TestValueSub(t *testing.T) { {"int16(250)-int16(220)", types.NewIntegerValue(250), types.NewIntegerValue(220), types.NewIntegerValue(30), false}, {"integer(120)-float64(620)", types.NewIntegerValue(120), types.NewDoubleValue(620), types.NewDoubleValue(-500), false}, {"integer(120)-float64(120.1)", types.NewIntegerValue(120), types.NewDoubleValue(120.1), types.NewDoubleValue(-0.09999999999999432), false}, - {"int64(min)-integer(10)", types.NewIntegerValue(math.MinInt64), types.NewIntegerValue(10), types.NewDoubleValue(math.MinInt64 - 10), false}, - {"int64(max)-integer(-10)", types.NewIntegerValue(math.MaxInt64), types.NewIntegerValue(-10), types.NewDoubleValue(math.MaxInt64 + 10), false}, + {"int64(min)-integer(10)", types.NewBigintValue(math.MinInt64), types.NewIntegerValue(10), nil, true}, + {"int64(max)-integer(-10)", types.NewBigintValue(math.MaxInt64), types.NewIntegerValue(-10), nil, true}, {"integer(120)-text('120')", types.NewIntegerValue(120), types.NewTextValue("120"), types.NewNullValue(), false}, {"text('120')-text('120')", types.NewTextValue("120"), types.NewTextValue("120"), types.NewNullValue(), false}, - {"object-object", types.NewObjectValue(object.NewFieldBuffer().Add("a", types.NewIntegerValue(10))), types.NewObjectValue(object.NewFieldBuffer().Add("a", types.NewIntegerValue(10))), types.NewNullValue(), false}, - {"array-array", types.NewArrayValue(object.NewValueBuffer(types.NewIntegerValue(10))), types.NewArrayValue(object.NewValueBuffer(types.NewIntegerValue(10))), types.NewNullValue(), false}, } for _, test := range tests { @@ -101,11 +96,9 @@ func TestValueMult(t *testing.T) { {"integer(10)*integer(10)", types.NewIntegerValue(10), types.NewIntegerValue(10), types.NewIntegerValue(100), false}, {"integer(10)*integer(80)", types.NewIntegerValue(10), types.NewIntegerValue(80), types.NewIntegerValue(800), false}, {"integer(10)*float64(80)", types.NewIntegerValue(10), types.NewDoubleValue(80), types.NewDoubleValue(800), false}, - {"int64(max)*int64(max)", types.NewIntegerValue(math.MaxInt64), types.NewIntegerValue(math.MaxInt64), types.NewDoubleValue(math.MaxInt64 * math.MaxInt64), false}, + {"int64(max)*int64(max)", types.NewBigintValue(math.MaxInt64), types.NewBigintValue(math.MaxInt64), nil, true}, {"integer(120)*text('120')", types.NewIntegerValue(120), types.NewTextValue("120"), types.NewNullValue(), false}, {"text('120')*text('120')", types.NewTextValue("120"), types.NewTextValue("120"), types.NewNullValue(), false}, - {"object*object", types.NewObjectValue(object.NewFieldBuffer().Add("a", types.NewIntegerValue(10))), types.NewObjectValue(object.NewFieldBuffer().Add("a", types.NewIntegerValue(10))), types.NewNullValue(), false}, - {"array*array", types.NewArrayValue(object.NewValueBuffer(types.NewIntegerValue(10))), types.NewArrayValue(object.NewValueBuffer(types.NewIntegerValue(10))), types.NewNullValue(), false}, } for _, test := range tests { @@ -133,16 +126,14 @@ func TestValueDiv(t *testing.T) { {"null/integer(10)", types.NewNullValue(), types.NewIntegerValue(10), types.NewNullValue(), false}, {"bool(true)/bool(true)", types.NewBooleanValue(true), types.NewBooleanValue(true), types.NewNullValue(), false}, {"bool(true)/bool(false)", types.NewBooleanValue(true), types.NewBooleanValue(false), types.NewNullValue(), false}, - 
{"integer(10)/integer(0)", types.NewIntegerValue(10), types.NewIntegerValue(0), types.NewNullValue(), false}, + {"integer(10)/integer(0)", types.NewIntegerValue(10), types.NewIntegerValue(0), types.NewNullValue(), true}, {"integer(10)/float64(0)", types.NewIntegerValue(10), types.NewDoubleValue(0), types.NewNullValue(), false}, {"integer(10)/integer(10)", types.NewIntegerValue(10), types.NewIntegerValue(10), types.NewIntegerValue(1), false}, {"integer(10)/integer(8)", types.NewIntegerValue(10), types.NewIntegerValue(8), types.NewIntegerValue(1), false}, {"integer(10)/float64(8)", types.NewIntegerValue(10), types.NewDoubleValue(8), types.NewDoubleValue(1.25), false}, - {"int64(maxint)/float64(maxint)", types.NewIntegerValue(math.MaxInt64), types.NewDoubleValue(math.MaxInt64), types.NewDoubleValue(1), false}, + {"int64(maxint)/float64(maxint)", types.NewBigintValue(math.MaxInt64), types.NewDoubleValue(math.MaxInt64), types.NewDoubleValue(1), false}, {"integer(120)/text('120')", types.NewIntegerValue(120), types.NewTextValue("120"), types.NewNullValue(), false}, {"text('120')/text('120')", types.NewTextValue("120"), types.NewTextValue("120"), types.NewNullValue(), false}, - {"object/object", types.NewObjectValue(object.NewFieldBuffer().Add("a", types.NewIntegerValue(10))), types.NewObjectValue(object.NewFieldBuffer().Add("a", types.NewIntegerValue(10))), types.NewNullValue(), false}, - {"array/array", types.NewArrayValue(object.NewValueBuffer(types.NewIntegerValue(10))), types.NewArrayValue(object.NewValueBuffer(types.NewIntegerValue(10))), types.NewNullValue(), false}, } for _, test := range tests { @@ -175,13 +166,11 @@ func TestValueMod(t *testing.T) { {"integer(10)%integer(10)", types.NewIntegerValue(10), types.NewIntegerValue(10), types.NewIntegerValue(0), false}, {"integer(10)%integer(8)", types.NewIntegerValue(10), types.NewIntegerValue(8), types.NewIntegerValue(2), false}, {"integer(10)%float64(8)", types.NewIntegerValue(10), types.NewDoubleValue(8), types.NewDoubleValue(2), false}, - {"int64(maxint)%float64(maxint)", types.NewIntegerValue(math.MaxInt64), types.NewDoubleValue(math.MaxInt64), types.NewDoubleValue(0), false}, + {"int64(maxint)%float64(maxint)", types.NewBigintValue(math.MaxInt64), types.NewDoubleValue(math.MaxInt64), types.NewDoubleValue(0), false}, {"double(> maxint)%int64(100)", types.NewDoubleValue(math.MaxInt64 + 1000), types.NewIntegerValue(100), types.NewDoubleValue(8), false}, {"int64(100)%float64(> maxint)", types.NewIntegerValue(100), types.NewDoubleValue(math.MaxInt64 + 1000), types.NewDoubleValue(100), false}, {"integer(120)%text('120')", types.NewIntegerValue(120), types.NewTextValue("120"), types.NewNullValue(), false}, {"text('120')%text('120')", types.NewTextValue("120"), types.NewTextValue("120"), types.NewNullValue(), false}, - {"object%object", types.NewObjectValue(object.NewFieldBuffer().Add("a", types.NewIntegerValue(10))), types.NewObjectValue(object.NewFieldBuffer().Add("a", types.NewIntegerValue(10))), types.NewNullValue(), false}, - {"array%array", types.NewArrayValue(object.NewValueBuffer(types.NewIntegerValue(10))), types.NewArrayValue(object.NewValueBuffer(types.NewIntegerValue(10))), types.NewNullValue(), false}, } for _, test := range tests { @@ -210,14 +199,12 @@ func TestValueBitwiseAnd(t *testing.T) { {"bool(true)&bool(true)", types.NewBooleanValue(true), types.NewBooleanValue(true), types.NewNullValue(), false}, {"bool(true)&bool(false)", types.NewBooleanValue(true), types.NewBooleanValue(false), types.NewNullValue(), false}, 
{"integer(10)&integer(0)", types.NewIntegerValue(10), types.NewIntegerValue(0), types.NewIntegerValue(0), false}, - {"double(10.5)&float64(3.2)", types.NewDoubleValue(10.5), types.NewDoubleValue(3.2), types.NewIntegerValue(2), false}, - {"integer(10)&float64(0)", types.NewIntegerValue(10), types.NewDoubleValue(0), types.NewIntegerValue(0), false}, + {"double(10.5)&double(3.2)", types.NewDoubleValue(10.5), types.NewDoubleValue(3.2), types.NewNullValue(), false}, + {"integer(10)&double(0)", types.NewIntegerValue(10), types.NewDoubleValue(0), types.NewNullValue(), false}, {"integer(10)&integer(10)", types.NewIntegerValue(10), types.NewIntegerValue(10), types.NewIntegerValue(10), false}, {"integer(10)&integer(8)", types.NewIntegerValue(10), types.NewIntegerValue(8), types.NewIntegerValue(8), false}, - {"integer(10)&float64(8)", types.NewIntegerValue(10), types.NewDoubleValue(8), types.NewIntegerValue(8), false}, + {"integer(10)&double(8)", types.NewIntegerValue(10), types.NewDoubleValue(8), types.NewNullValue(), false}, {"text('120')&text('120')", types.NewTextValue("120"), types.NewTextValue("120"), types.NewNullValue(), false}, - {"object&object", types.NewObjectValue(object.NewFieldBuffer().Add("a", types.NewIntegerValue(10))), types.NewObjectValue(object.NewFieldBuffer().Add("a", types.NewIntegerValue(10))), types.NewNullValue(), false}, - {"array&array", types.NewArrayValue(object.NewValueBuffer(types.NewIntegerValue(10))), types.NewArrayValue(object.NewValueBuffer(types.NewIntegerValue(10))), types.NewNullValue(), false}, } for _, test := range tests { @@ -246,13 +233,11 @@ func TestValueBitwiseOr(t *testing.T) { {"bool(true)|bool(true)", types.NewBooleanValue(true), types.NewBooleanValue(true), types.NewNullValue(), false}, {"bool(true)|bool(false)", types.NewBooleanValue(true), types.NewBooleanValue(false), types.NewNullValue(), false}, {"integer(10)|integer(0)", types.NewIntegerValue(10), types.NewIntegerValue(0), types.NewIntegerValue(10), false}, - {"double(10.5)|float64(3.2)", types.NewDoubleValue(10.5), types.NewDoubleValue(3.2), types.NewIntegerValue(11), false}, - {"integer(10)|float64(0)", types.NewIntegerValue(10), types.NewDoubleValue(0), types.NewIntegerValue(10), false}, + {"double(10.5)|double(3.2)", types.NewDoubleValue(10.5), types.NewDoubleValue(3.2), types.NewNullValue(), false}, + {"integer(10)|double(0)", types.NewIntegerValue(10), types.NewDoubleValue(0), types.NewNullValue(), false}, {"integer(10)|integer(10)", types.NewIntegerValue(10), types.NewIntegerValue(10), types.NewIntegerValue(10), false}, - {"integer(10)|float64(8)", types.NewIntegerValue(10), types.NewDoubleValue(8), types.NewIntegerValue(10), false}, + {"integer(10)|double(8)", types.NewIntegerValue(10), types.NewDoubleValue(8), types.NewNullValue(), false}, {"text('120')|text('120')", types.NewTextValue("120"), types.NewTextValue("120"), types.NewNullValue(), false}, - {"object|object", types.NewObjectValue(object.NewFieldBuffer().Add("a", types.NewIntegerValue(10))), types.NewObjectValue(object.NewFieldBuffer().Add("a", types.NewIntegerValue(10))), types.NewNullValue(), false}, - {"array|array", types.NewArrayValue(object.NewValueBuffer(types.NewIntegerValue(10))), types.NewArrayValue(object.NewValueBuffer(types.NewIntegerValue(10))), types.NewNullValue(), false}, } for _, test := range tests { @@ -281,12 +266,10 @@ func TestValueBitwiseXor(t *testing.T) { {"bool(true)^bool(true)", types.NewBooleanValue(true), types.NewBooleanValue(true), types.NewNullValue(), false}, {"bool(true)^bool(false)", 
types.NewBooleanValue(true), types.NewBooleanValue(false), types.NewNullValue(), false}, {"integer(10)^integer(0)", types.NewIntegerValue(10), types.NewIntegerValue(0), types.NewIntegerValue(10), false}, - {"double(10.5)^double(3.2)", types.NewDoubleValue(10.5), types.NewDoubleValue(3.2), types.NewIntegerValue(9), false}, - {"integer(10)^double(0)", types.NewIntegerValue(10), types.NewDoubleValue(0), types.NewIntegerValue(10), false}, + {"double(10.5)^double(3.2)", types.NewDoubleValue(10.5), types.NewDoubleValue(3.2), types.NewNullValue(), false}, + {"integer(10)^double(0)", types.NewIntegerValue(10), types.NewDoubleValue(0), types.NewNullValue(), false}, {"integer(10)^integer(10)", types.NewIntegerValue(10), types.NewIntegerValue(10), types.NewIntegerValue(0), false}, {"text('120')^text('120')", types.NewTextValue("120"), types.NewTextValue("120"), types.NewNullValue(), false}, - {"object^object", types.NewObjectValue(object.NewFieldBuffer().Add("a", types.NewIntegerValue(10))), types.NewObjectValue(object.NewFieldBuffer().Add("a", types.NewIntegerValue(10))), types.NewNullValue(), false}, - {"array^array", types.NewArrayValue(object.NewValueBuffer(types.NewIntegerValue(10))), types.NewArrayValue(object.NewValueBuffer(types.NewIntegerValue(10))), types.NewNullValue(), false}, } for _, test := range tests { diff --git a/internal/expr/column.go b/internal/expr/column.go new file mode 100644 index 000000000..04529a1a5 --- /dev/null +++ b/internal/expr/column.go @@ -0,0 +1,27 @@ +package expr + +import ( + "github.com/chaisql/chai/internal/environment" + "github.com/chaisql/chai/internal/types" + "github.com/cockroachdb/errors" +) + +type Column string + +func (c Column) String() string { + return string(c) +} + +func (c Column) Eval(env *environment.Environment) (types.Value, error) { + r, ok := env.GetRow() + if !ok { + return NullLiteral, errors.New("no table specified") + } + + v, err := r.Get(string(c)) + if err != nil { + return NullLiteral, err + } + + return v, nil +} diff --git a/internal/expr/comparison.go b/internal/expr/comparison.go index 530960d16..ccbcdfbdb 100644 --- a/internal/expr/comparison.go +++ b/internal/expr/comparison.go @@ -4,7 +4,6 @@ import ( "fmt" "github.com/chaisql/chai/internal/environment" - "github.com/chaisql/chai/internal/object" "github.com/chaisql/chai/internal/sql/scanner" "github.com/chaisql/chai/internal/types" ) @@ -143,25 +142,69 @@ func IsComparisonOperator(op Operator) bool { } type InOperator struct { - *simpleOperator + a Expr + b Expr + op scanner.Token } // In creates an expression that evaluates to the result of a IN b. 
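// Aside: the new Column expression introduced above replaces the old path lookup with a
// plain column-name lookup on the current row. A usage sketch, assuming the constructors
// used elsewhere in this patch behave as shown there (not verified here):
//
//	env := environment.New(database.NewBasicRow(
//		row.NewColumnBuffer().Add("a", types.NewIntegerValue(1)),
//	))
//	v, err := expr.Column("a").Eval(env) // v holds the integer 1
//	_, err = expr.Column("b").Eval(env)  // expected to fail: no such column on the row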
-func In(a, b Expr) Expr { - return &InOperator{&simpleOperator{a, b, scanner.IN}} +func In(a Expr, b Expr) Expr { + return &InOperator{a, b, scanner.IN} +} + +func (op InOperator) Precedence() int { + return op.op.Precedence() +} + +func (op InOperator) LeftHand() Expr { + return op.a +} + +func (op InOperator) RightHand() Expr { + return op.b +} + +func (op *InOperator) SetLeftHandExpr(a Expr) { + op.a = a +} + +func (op *InOperator) SetRightHandExpr(b Expr) {} + +func (op *InOperator) Token() scanner.Token { + return op.op +} + +func (op *InOperator) String() string { + return fmt.Sprintf("%v IN %v", op.a, op.b) } func (op *InOperator) Eval(env *environment.Environment) (types.Value, error) { - return op.simpleOperator.eval(env, func(a, b types.Value) (types.Value, error) { - if a.Type() == types.TypeNull || b.Type() == types.TypeNull { - return NullLiteral, nil - } + a, err := op.validateLeftExpression(op.a) + if err != nil { + return NullLiteral, err + } + + b, err := op.validateRightExpression(op.b) + if err != nil { + return NullLiteral, err + } + + va, err := a.Eval(env) + if err != nil { + return NullLiteral, err + } - if b.Type() != types.TypeArray { - return FalseLiteral, nil + if va.Type() == types.TypeNull { + return NullLiteral, nil + } + + for _, bb := range b { + v, err := bb.Eval(env) + if err != nil { + return NullLiteral, err } - ok, err := object.ArrayContains(types.AsArray(b), a) + ok, err := va.EQ(v) if err != nil { return NullLiteral, err } @@ -169,8 +212,33 @@ func (op *InOperator) Eval(env *environment.Environment) (types.Value, error) { if ok { return TrueLiteral, nil } - return FalseLiteral, nil - }) + } + + return FalseLiteral, nil +} + +func (op *InOperator) validateLeftExpression(a Expr) (Expr, error) { + switch t := a.(type) { + case Parentheses: + return op.validateLeftExpression(t.E) + case Column: + return a, nil + case LiteralValue: + return a, nil + } + + return nil, fmt.Errorf("invalid left expression for IN operator: %v", a) +} + +func (op *InOperator) validateRightExpression(b Expr) (LiteralExprList, error) { + switch t := b.(type) { + case Parentheses: + return LiteralExprList{b.(Parentheses).E}, nil + case LiteralExprList: + return t, nil + } + + return nil, fmt.Errorf("invalid right expression for IN operator: %v", b) } type NotInOperator struct { @@ -178,8 +246,8 @@ type NotInOperator struct { } // NotIn creates an expression that evaluates to the result of a NOT IN b. 
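// Aside: the rewritten IN above no longer evaluates its right-hand side to an array; it
// walks a literal expression list, returns NULL when the left operand is NULL, TRUE on the
// first EQ match, and FALSE otherwise. A minimal sketch of that decision logic (contains is
// a hypothetical helper mirroring the Eval method above, not chai's API):
//
//	func contains(left types.Value, candidates []types.Value) (types.Value, error) {
//		if left.Type() == types.TypeNull {
//			return types.NewNullValue(), nil
//		}
//		for _, c := range candidates {
//			ok, err := left.EQ(c)
//			if err != nil {
//				return nil, err
//			}
//			if ok {
//				return types.NewBooleanValue(true), nil
//			}
//		}
//		return types.NewBooleanValue(false), nil
//	}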
-func NotIn(a, b Expr) Expr { - return &NotInOperator{InOperator{&simpleOperator{a, b, scanner.NIN}}} +func NotIn(a Expr, b Expr) Expr { + return &NotInOperator{InOperator{a, b, scanner.NIN}} } func (op *NotInOperator) Eval(env *environment.Environment) (types.Value, error) { diff --git a/internal/expr/comparison_test.go b/internal/expr/comparison_test.go index a7c9e8e1b..eef722109 100644 --- a/internal/expr/comparison_test.go +++ b/internal/expr/comparison_test.go @@ -3,11 +3,17 @@ package expr_test import ( "testing" + "github.com/chaisql/chai/internal/database" "github.com/chaisql/chai/internal/environment" + "github.com/chaisql/chai/internal/row" "github.com/chaisql/chai/internal/testutil" "github.com/chaisql/chai/internal/types" ) +var envWithRow = environment.New(func() database.Row { + return database.NewBasicRow(row.NewColumnBuffer().Add("a", types.NewIntegerValue(1))) +}()) + func TestComparisonExpr(t *testing.T) { tests := []struct { expr string @@ -16,27 +22,27 @@ func TestComparisonExpr(t *testing.T) { }{ {"1 = a", types.NewBooleanValue(true), false}, {"1 = NULL", nullLiteral, false}, - {"1 = notFound", nullLiteral, false}, + {"1 = notFound", nullLiteral, true}, {"1 != a", types.NewBooleanValue(false), false}, {"1 != NULL", nullLiteral, false}, - {"1 != notFound", nullLiteral, false}, + {"1 != notFound", nullLiteral, true}, {"1 > a", types.NewBooleanValue(false), false}, {"1 > NULL", nullLiteral, false}, - {"1 > notFound", nullLiteral, false}, + {"1 > notFound", nullLiteral, true}, {"1 >= a", types.NewBooleanValue(true), false}, {"1 >= NULL", nullLiteral, false}, - {"1 >= notFound", nullLiteral, false}, + {"1 >= notFound", nullLiteral, true}, {"1 < a", types.NewBooleanValue(false), false}, {"1 < NULL", nullLiteral, false}, - {"1 < notFound", nullLiteral, false}, + {"1 < notFound", nullLiteral, true}, {"1 <= a", types.NewBooleanValue(true), false}, {"1 <= NULL", nullLiteral, false}, - {"1 <= notFound", nullLiteral, false}, + {"1 <= notFound", nullLiteral, true}, } for _, test := range tests { t.Run(test.expr, func(t *testing.T) { - testutil.TestExpr(t, test.expr, envWithDoc, test.res, test.fails) + testutil.TestExpr(t, test.expr, envWithRow, test.res, test.fails) }) } } @@ -47,21 +53,18 @@ func TestComparisonINExpr(t *testing.T) { res types.Value fails bool }{ - {"1 IN []", types.NewBooleanValue(false), false}, - {"1 IN [1, 2, 3]", types.NewBooleanValue(true), false}, - {"2 IN [2.1, 2.2, 2.0]", types.NewBooleanValue(true), false}, - {"1 IN [2, 3]", types.NewBooleanValue(false), false}, - {"[1] IN [1, 2, 3]", types.NewBooleanValue(false), false}, - {"[1] IN [[1], [2], [3]]", types.NewBooleanValue(true), false}, - {"1 IN {}", types.NewBooleanValue(false), false}, - {"[1, 2] IN 1", types.NewBooleanValue(false), false}, - {"1 IN NULL", nullLiteral, false}, - {"NULL IN [1, 2, NULL]", nullLiteral, false}, + {"1 IN (2)", types.NewBooleanValue(false), false}, + {"1 IN (1, 2, 3)", types.NewBooleanValue(true), false}, + {"2 IN (2.1, 2.2, 2.0)", types.NewBooleanValue(true), false}, + {"1 IN (2, 3)", types.NewBooleanValue(false), false}, + {"(1) IN (1, 2, 3)", types.NewBooleanValue(true), false}, + {"(1) IN (1), (2), (3)", types.NewBooleanValue(true), false}, + {"NULL IN (1, 2, NULL)", nullLiteral, false}, } for _, test := range tests { t.Run(test.expr, func(t *testing.T) { - testutil.TestExpr(t, test.expr, envWithDoc, test.res, test.fails) + testutil.TestExpr(t, test.expr, envWithRow, test.res, test.fails) }) } } @@ -72,20 +75,15 @@ func TestComparisonNOTINExpr(t *testing.T) { res 
types.Value fails bool }{ - {"1 NOT IN []", types.NewBooleanValue(true), false}, - {"1 NOT IN [1, 2, 3]", types.NewBooleanValue(false), false}, - {"1 NOT IN [2, 3]", types.NewBooleanValue(true), false}, - {"[1] NOT IN [1, 2, 3]", types.NewBooleanValue(true), false}, - {"[1] NOT IN [[1], [2], [3]]", types.NewBooleanValue(false), false}, - {"1 NOT IN {}", types.NewBooleanValue(true), false}, - {"[1, 2] NOT IN 1", types.NewBooleanValue(true), false}, - {"1 NOT IN NULL", nullLiteral, false}, - {"NULL NOT IN [1, 2, NULL]", nullLiteral, false}, + {"1 NOT IN (1, 2, 3)", types.NewBooleanValue(false), false}, + {"1 NOT IN (2, 3)", types.NewBooleanValue(true), false}, + {"(1) NOT IN (1, 2, 3)", types.NewBooleanValue(false), false}, + {"NULL NOT IN (1, 2, NULL)", nullLiteral, false}, } for _, test := range tests { t.Run(test.expr, func(t *testing.T) { - testutil.TestExpr(t, test.expr, envWithDoc, test.res, test.fails) + testutil.TestExpr(t, test.expr, envWithRow, test.res, test.fails) }) } } @@ -105,7 +103,7 @@ func TestComparisonISExpr(t *testing.T) { for _, test := range tests { t.Run(test.expr, func(t *testing.T) { - testutil.TestExpr(t, test.expr, envWithDoc, test.res, test.fails) + testutil.TestExpr(t, test.expr, envWithRow, test.res, test.fails) }) } } @@ -125,7 +123,7 @@ func TestComparisonISNOTExpr(t *testing.T) { for _, test := range tests { t.Run(test.expr, func(t *testing.T) { - testutil.TestExpr(t, test.expr, envWithDoc, test.res, test.fails) + testutil.TestExpr(t, test.expr, envWithRow, test.res, test.fails) }) } } @@ -142,7 +140,7 @@ func TestComparisonExprNoObject(t *testing.T) { {"1 >= a", nullLiteral, true}, {"1 < a", nullLiteral, true}, {"1 <= a", nullLiteral, true}, - {"1 IN [a]", nullLiteral, true}, + {"1 IN (a)", nullLiteral, true}, {"1 IS a", nullLiteral, true}, {"1 IS NOT a", nullLiteral, true}, } @@ -179,7 +177,7 @@ func TestComparisonBetweenExpr(t *testing.T) { for _, test := range tests { t.Run(test.expr, func(t *testing.T) { - testutil.TestExpr(t, test.expr, envWithDoc, test.res, test.fails) + testutil.TestExpr(t, test.expr, envWithRow, test.res, test.fails) }) } } diff --git a/internal/expr/constraint.go b/internal/expr/constraint.go index 0fedd4d5a..1ce0afa61 100644 --- a/internal/expr/constraint.go +++ b/internal/expr/constraint.go @@ -3,6 +3,7 @@ package expr import ( "github.com/chaisql/chai/internal/database" "github.com/chaisql/chai/internal/environment" + "github.com/chaisql/chai/internal/row" "github.com/chaisql/chai/internal/types" "github.com/cockroachdb/errors" ) @@ -17,10 +18,10 @@ func Constraint(e Expr) *ConstraintExpr { } } -func (t *ConstraintExpr) Eval(tx *database.Transaction, o types.Object) (types.Value, error) { +func (t *ConstraintExpr) Eval(tx *database.Transaction, r row.Row) (types.Value, error) { var env environment.Environment env.Tx = tx - env.SetRowFromObject(o) + env.SetRow(r) if t.Expr == nil { return NullLiteral, errors.New("missing expression") @@ -29,6 +30,22 @@ func (t *ConstraintExpr) Eval(tx *database.Transaction, o types.Object) (types.V return t.Expr.Eval(&env) } +func (t *ConstraintExpr) Validate(info *database.TableInfo) (err error) { + Walk(t.Expr, func(e Expr) bool { + switch e := e.(type) { + case Column: + if info.GetColumnConstraint(string(e)) == nil { + err = errors.Newf("column %q does not exist", e) + return false + } + } + + return true + }) + + return err +} + func (t *ConstraintExpr) String() string { return t.Expr.String() } diff --git a/internal/expr/expr.go b/internal/expr/expr.go index 493151f4b..d34da3538 100644 --- 
a/internal/expr/expr.go +++ b/internal/expr/expr.go @@ -152,18 +152,6 @@ func Walk(e Expr, fn func(Expr) bool) bool { return false } } - case LiteralExprList: - for _, e := range t { - if !Walk(e, fn) { - return false - } - } - case *KVPairs: - for _, e := range t.Pairs { - if !Walk(e.V, fn) { - return false - } - } } return true @@ -191,7 +179,7 @@ func (n NextValueFor) Eval(env *environment.Environment) (types.Value, error) { return NullLiteral, err } - return types.NewIntegerValue(i), nil + return types.NewBigintValue(i), nil } // IsEqual compares this expression with the other expression and returns @@ -212,3 +200,41 @@ func (n NextValueFor) IsEqual(other Expr) bool { func (n NextValueFor) String() string { return fmt.Sprintf("NEXT VALUE FOR %s", n.SeqName) } + +// // Type returns the expected type of the expression without evaluating it. +// // Query parameters are not allowed and will return an error. +// func Type(e Expr, info *database.TableInfo) (types.Type, error) { +// switch e := e.(type) { +// case Column: +// cc := info.GetColumnConstraint(string(e)) +// if cc == nil { +// return types.TypeNull, fmt.Errorf("column %q does not exist", e) +// } +// return cc.Type, nil +// case *NamedExpr: +// return Type(e.Expr, info) +// case Operator: +// l, err := Type(e.LeftHand(), info) +// if err != nil { +// return 0, err +// } +// r, err := Type(e.RightHand(), info) +// if err != nil { +// return 0, err +// } + +// // when types are different, determine if they are compatible +// // depending on the operator +// if l != r { +// if IsArithmeticOperator(e) { + +// } else if IsComparisonOperator(e) && l.IsComparableWith(r) { +// return types.TypeBoolean, nil +// } else { +// return 0, fmt.Errorf("mismatched types: %v and %v", l, r) +// } +// } +// } + +// return types.TypeNull, fmt.Errorf("unexpected expression type: %T", e) +// } diff --git a/internal/expr/expr_test.go b/internal/expr/expr_test.go index 4daf5a164..114875af8 100644 --- a/internal/expr/expr_test.go +++ b/internal/expr/expr_test.go @@ -5,25 +5,12 @@ import ( "strings" "testing" - "github.com/chaisql/chai/internal/database" - "github.com/chaisql/chai/internal/environment" - "github.com/chaisql/chai/internal/object" "github.com/chaisql/chai/internal/sql/parser" "github.com/chaisql/chai/internal/testutil/assert" "github.com/chaisql/chai/internal/types" "github.com/stretchr/testify/require" ) -var row database.Row = func() database.Row { - return database.NewBasicRow(object.NewFromJSON([]byte(`{ - "a": 1, - "b": {"foo bar": [1, 2]}, - "c": [1, {"foo": "bar"}, [1, 2]] - }`))) -}() - -var envWithDoc = environment.New(row) - var nullLiteral = types.NewNullValue() func TestString(t *testing.T) { @@ -31,11 +18,8 @@ func TestString(t *testing.T) { `10.4`, "true", "500", - `foo.bar[1]`, + `foo`, `"hello"`, - `[1, 2, "foo"]`, - `{a: "foo", b: 10}`, - "pk()", "CAST(10 AS integer)", } @@ -57,7 +41,7 @@ func TestString(t *testing.T) { } for _, op := range operators { - want := fmt.Sprintf("10.4 %s foo.bar[1]", op) + want := fmt.Sprintf("10.4 %s foo", op) testFn(want, want) } } diff --git a/internal/expr/functions/builtins.go b/internal/expr/functions/builtins.go index dc15e90ae..5e7c973df 100644 --- a/internal/expr/functions/builtins.go +++ b/internal/expr/functions/builtins.go @@ -2,10 +2,10 @@ package functions import ( "fmt" + "strings" "github.com/chaisql/chai/internal/environment" "github.com/chaisql/chai/internal/expr" - "github.com/chaisql/chai/internal/object" "github.com/chaisql/chai/internal/types" "github.com/cockroachdb/errors" 
) @@ -18,13 +18,6 @@ var builtinFunctions = Definitions{ return &TypeOf{Expr: args[0]}, nil }, }, - "pk": &definition{ - name: "pk", - arity: 0, - constructorFn: func(args ...expr.Expr) (expr.Function, error) { - return &PK{}, nil - }, - }, "count": &definition{ name: "count", arity: 1, @@ -82,29 +75,52 @@ var builtinFunctions = Definitions{ }, }, - // strings alias - "lower": stringsFunctions["lower"], - "upper": stringsFunctions["upper"], - "trim": stringsFunctions["trim"], - "ltrim": stringsFunctions["ltrim"], - "rtrim": stringsFunctions["rtrim"], - - // math alias - "floor": mathFunctions["floor"], - "abs": mathFunctions["abs"], - "acos": mathFunctions["acos"], - "acosh": mathFunctions["acosh"], - "asin": mathFunctions["asin"], - "asinh": mathFunctions["asinh"], - "atan": mathFunctions["atan"], - "atan2": mathFunctions["atan2"], - "random": mathFunctions["random"], - "sqrt": mathFunctions["sqrt"], -} + "lower": &definition{ + name: "lower", + arity: 1, + constructorFn: func(args ...expr.Expr) (expr.Function, error) { + return &Lower{Expr: args[0]}, nil + }, + }, + "upper": &definition{ + name: "upper", + arity: 1, + constructorFn: func(args ...expr.Expr) (expr.Function, error) { + return &Upper{Expr: args[0]}, nil + }, + }, + "trim": &definition{ + name: "trim", + arity: variadicArity, + constructorFn: func(args ...expr.Expr) (expr.Function, error) { + return &Trim{Expr: args, TrimFunc: strings.Trim, Name: "TRIM"}, nil + }, + }, + "ltrim": &definition{ + name: "ltrim", + arity: variadicArity, + constructorFn: func(args ...expr.Expr) (expr.Function, error) { + return &Trim{Expr: args, TrimFunc: strings.TrimLeft, Name: "LTRIM"}, nil + }, + }, + "rtrim": &definition{ + name: "rtrim", + arity: variadicArity, + constructorFn: func(args ...expr.Expr) (expr.Function, error) { + return &Trim{Expr: args, TrimFunc: strings.TrimRight, Name: "RTRIM"}, nil + }, + }, -// BuiltinDefinitions returns a map of builtin functions. -func BuiltinDefinitions() Definitions { - return builtinFunctions + "floor": floor, + "abs": abs, + "acos": acos, + "acosh": acosh, + "asin": asin, + "asinh": asinh, + "atan": atan, + "atan2": atan2, + "random": random, + "sqrt": sqrt, } type TypeOf struct { @@ -141,66 +157,6 @@ func (t *TypeOf) String() string { return fmt.Sprintf("typeof(%v)", t.Expr) } -// PK represents the pk() function. -// It returns the primary key of the current object. -type PK struct{} - -// Eval returns the primary key of the current object. -func (k *PK) Eval(env *environment.Environment) (types.Value, error) { - row, ok := env.GetRow() - if !ok { - return expr.NullLiteral, nil - } - - key := row.Key() - if key == nil { - return expr.NullLiteral, nil - } - - vs, err := key.Decode() - if err != nil { - return expr.NullLiteral, err - } - - info, err := env.GetTx().Catalog.GetTableInfo(row.TableName()) - if err != nil { - return nil, err - } - - pk := info.PrimaryKey - if pk != nil { - for i, tp := range pk.Types { - if !tp.IsAny() { - vs[i], err = object.CastAs(vs[i], tp) - if err != nil { - return nil, err - } - } - } - } - - vb := object.NewValueBuffer() - - for _, v := range vs { - vb.Append(v) - } - - return types.NewArrayValue(vb), nil -} - -func (*PK) Params() []expr.Expr { return nil } - -// IsEqual compares this expression with the other expression and returns -// true if they are equal. 
-func (k *PK) IsEqual(other expr.Expr) bool { - _, ok := other.(*PK) - return ok -} - -func (k *PK) String() string { - return "pk()" -} - var _ expr.AggregatorBuilder = (*Count)(nil) // Count is the COUNT aggregator function. It counts the number of objects @@ -270,7 +226,7 @@ func (c *CountAggregator) Aggregate(env *environment.Environment) error { } v, err := c.Fn.Expr.Eval(env) - if err != nil && !errors.Is(err, types.ErrFieldNotFound) { + if err != nil && !errors.Is(err, types.ErrColumnNotFound) { return err } if v.Type() != types.TypeNull { @@ -282,7 +238,7 @@ func (c *CountAggregator) Aggregate(env *environment.Environment) error { // Eval returns the result of the aggregation as an integer. func (c *CountAggregator) Eval(_ *environment.Environment) (types.Value, error) { - return types.NewIntegerValue(c.Count), nil + return types.NewBigintValue(c.Count), nil } func (c *CountAggregator) String() string { @@ -344,18 +300,18 @@ type MinAggregator struct { // then if the type is equal their value is compared. Numbers are considered of the same type. func (m *MinAggregator) Aggregate(env *environment.Environment) error { v, err := m.Fn.Expr.Eval(env) - if err != nil && !errors.Is(err, types.ErrFieldNotFound) { + if err != nil && !errors.Is(err, types.ErrColumnNotFound) { return err } if v.Type() == types.TypeNull { return nil } - // clone the value to avoid it being reused during next aggregation - v, err = object.CloneValue(v) - if err != nil { - return err - } + // // clone the value to avoid it being reused during next aggregation + // v, err = row.CloneValue(v) + // if err != nil { + // return err + // } if m.Min == nil { m.Min = v @@ -448,19 +404,13 @@ type MaxAggregator struct { // then if the type is equal their value is compared. Numbers are considered of the same type. func (m *MaxAggregator) Aggregate(env *environment.Environment) error { v, err := m.Fn.Expr.Eval(env) - if err != nil && !errors.Is(err, types.ErrFieldNotFound) { + if err != nil && !errors.Is(err, types.ErrColumnNotFound) { return err } if v.Type() == types.TypeNull { return nil } - // clone the value to avoid it being reused during next aggregation - v, err = object.CloneValue(v) - if err != nil { - return err - } - if m.Max == nil { m.Max = v return nil @@ -555,17 +505,18 @@ type SumAggregator struct { // If any of the value is a double, the returned result will be a double. func (s *SumAggregator) Aggregate(env *environment.Environment) error { v, err := s.Fn.Expr.Eval(env) - if err != nil && !errors.Is(err, types.ErrFieldNotFound) { + if err != nil && !errors.Is(err, types.ErrColumnNotFound) { return err } - if v.Type() != types.TypeInteger && v.Type() != types.TypeDouble { + if !v.Type().IsNumber() { return nil } if s.SumF != nil { - if v.Type() == types.TypeInteger { + switch v.Type() { + case types.TypeInteger, types.TypeBigint: *s.SumF += float64(types.AsInt64(v)) - } else { + default: *s.SumF += float64(types.AsFloat64(v)) } @@ -598,7 +549,7 @@ func (s *SumAggregator) Eval(_ *environment.Environment) (types.Value, error) { return types.NewDoubleValue(*s.SumF), nil } if s.SumI != nil { - return types.NewIntegerValue(*s.SumI), nil + return types.NewBigintValue(*s.SumI), nil } return types.NewNullValue(), nil @@ -663,12 +614,12 @@ type AvgAggregator struct { // Aggregate stores the average value of all non-NULL numeric values in the group. 
func (s *AvgAggregator) Aggregate(env *environment.Environment) error { v, err := s.Fn.Expr.Eval(env) - if err != nil && !errors.Is(err, types.ErrFieldNotFound) { + if err != nil && !errors.Is(err, types.ErrColumnNotFound) { return err } switch v.Type() { - case types.TypeInteger: + case types.TypeInteger, types.TypeBigint: s.Avg += float64(types.AsInt64(v)) case types.TypeDouble: s.Avg += types.AsFloat64(v) @@ -694,7 +645,7 @@ func (s *AvgAggregator) String() string { } // Len represents the len() function. -// It returns the length of string, array or object. +// It returns the length of string, array or row. // For other types len() returns NULL. type Len struct { Expr expr.Expr @@ -710,23 +661,11 @@ func (s *Len) Eval(env *environment.Environment) (types.Value, error) { switch val.Type() { case types.TypeText: length = len(types.AsString(val)) - case types.TypeArray: - arrayLen, err := object.ArrayLength(types.AsArray(val)) - if err != nil { - return nil, err - } - length = arrayLen - case types.TypeObject: - docLen, err := object.Length(types.AsObject(val)) - if err != nil { - return nil, err - } - length = docLen default: return types.NewNullValue(), nil } - return types.NewIntegerValue(int64(length)), nil + return types.NewBigintValue(int64(length)), nil } // IsEqual compares this expression with the other expression and returns diff --git a/internal/expr/functions/definition.go b/internal/expr/functions/definition.go index 3f2d091a4..b4d03e27f 100644 --- a/internal/expr/functions/definition.go +++ b/internal/expr/functions/definition.go @@ -15,36 +15,16 @@ type Definition interface { Name() string String() string Function(...expr.Expr) (expr.Function, error) - Arity() int } // Definitions table holds a map of definition, indexed by their names. type Definitions map[string]Definition -// Packages represent a table of SQL functions grouped by their packages -type Packages map[string]Definitions - -func DefaultPackages() Packages { - return Packages{ - "": BuiltinDefinitions(), - "math": MathFunctions(), - "strings": StringsDefinitions(), - "objects": ObjectsDefinitions(), - } -} - // GetFunc return a function definition by its package and name. 
-func (t Packages) GetFunc(pkg string, fname string) (Definition, error) { - fs, ok := t[pkg] - if !ok { - return nil, fmt.Errorf("no such package: %q", fname) - } - def, ok := fs[strings.ToLower(fname)] +func GetFunc(fname string) (Definition, error) { + def, ok := builtinFunctions[strings.ToLower(fname)] if !ok { - if pkg == "" { - return nil, fmt.Errorf("no such function: %q", fname) - } - return nil, fmt.Errorf("no such function: %q.%q", pkg, fname) + return nil, fmt.Errorf("no such function: %q", fname) } return def, nil } diff --git a/internal/expr/functions/definition_test.go b/internal/expr/functions/definition_test.go index 8a6a27d5e..74b0d7d3a 100644 --- a/internal/expr/functions/definition_test.go +++ b/internal/expr/functions/definition_test.go @@ -5,14 +5,12 @@ import ( "github.com/chaisql/chai/internal/expr" "github.com/chaisql/chai/internal/expr/functions" - "github.com/chaisql/chai/internal/object" "github.com/chaisql/chai/internal/testutil/assert" "github.com/stretchr/testify/require" ) func TestDefinitions(t *testing.T) { - packages := functions.DefaultPackages() - def, err := packages.GetFunc("", "count") + def, err := functions.GetFunc("count") assert.NoError(t, err) t.Run("String()", func(t *testing.T) { @@ -20,40 +18,8 @@ func TestDefinitions(t *testing.T) { }) t.Run("Function()", func(t *testing.T) { - fexpr, err := def.Function(expr.Path(object.NewPath("a"))) + fexpr, err := def.Function(expr.Column("a")) assert.NoError(t, err) require.NotNil(t, fexpr) }) - - t.Run("Arity()", func(t *testing.T) { - require.Equal(t, 1, def.Arity()) - }) -} - -func TestPackages(t *testing.T) { - table := functions.DefaultPackages() - - t.Run("OK GetFunc()", func(t *testing.T) { - def, err := table.GetFunc("math", "floor") - assert.NoError(t, err) - require.Equal(t, "floor", def.Name()) - def, err = table.GetFunc("", "count") - assert.NoError(t, err) - require.Equal(t, "count", def.Name()) - }) - - t.Run("NOK GetFunc() missing func", func(t *testing.T) { - def, err := table.GetFunc("math", "foobar") - assert.Error(t, err) - require.Nil(t, def) - def, err = table.GetFunc("", "foobar") - assert.Error(t, err) - require.Nil(t, def) - }) - - t.Run("NOK GetFunc() missing package", func(t *testing.T) { - def, err := table.GetFunc("foobar", "foobar") - assert.Error(t, err) - require.Nil(t, def) - }) } diff --git a/internal/expr/functions/math.go b/internal/expr/functions/math.go index cf18f8b82..64cf3f890 100644 --- a/internal/expr/functions/math.go +++ b/internal/expr/functions/math.go @@ -5,28 +5,9 @@ import ( "math" "math/rand" - "github.com/chaisql/chai/internal/object" "github.com/chaisql/chai/internal/types" ) -// MathFunctions returns all math package functions. 
-func MathFunctions() Definitions { - return mathFunctions -} - -var mathFunctions = Definitions{ - "floor": floor, - "abs": abs, - "acos": acos, - "acosh": acosh, - "asin": asin, - "asinh": asinh, - "atan": atan, - "atan2": atan2, - "random": random, - "sqrt": sqrt, -} - var floor = &ScalarDefinition{ name: "floor", arity: 1, @@ -34,7 +15,7 @@ var floor = &ScalarDefinition{ switch args[0].Type() { case types.TypeDouble: return types.NewDoubleValue(math.Floor(types.AsFloat64(args[0]))), nil - case types.TypeInteger: + case types.TypeInteger, types.TypeBigint: return args[0], nil default: return nil, fmt.Errorf("floor(arg1) expects arg1 to be a number") @@ -49,13 +30,16 @@ var abs = &ScalarDefinition{ if args[0].Type() == types.TypeNull { return types.NewNullValue(), nil } - v, err := object.CastAs(args[0], types.TypeDouble) + v, err := args[0].CastAs(types.TypeDouble) if err != nil { return nil, err } res := math.Abs(types.AsFloat64(v)) if args[0].Type() == types.TypeInteger { - return object.CastAs(types.NewDoubleValue(res), types.TypeInteger) + return types.NewDoubleValue(res).CastAs(types.TypeInteger) + } + if args[0].Type() == types.TypeBigint { + return types.NewDoubleValue(res).CastAs(types.TypeBigint) } return types.NewDoubleValue(res), nil }, @@ -68,7 +52,7 @@ var acos = &ScalarDefinition{ if args[0].Type() == types.TypeNull { return types.NewNullValue(), nil } - v, err := object.CastAs(args[0], types.TypeDouble) + v, err := args[0].CastAs(types.TypeDouble) if err != nil { return nil, err } @@ -88,7 +72,7 @@ var acosh = &ScalarDefinition{ if args[0].Type() == types.TypeNull { return types.NewNullValue(), nil } - v, err := object.CastAs(args[0], types.TypeDouble) + v, err := args[0].CastAs(types.TypeDouble) if err != nil { return nil, err } @@ -108,7 +92,7 @@ var asin = &ScalarDefinition{ if args[0].Type() == types.TypeNull { return types.NewNullValue(), nil } - v, err := object.CastAs(args[0], types.TypeDouble) + v, err := args[0].CastAs(types.TypeDouble) if err != nil { return nil, err } @@ -125,7 +109,7 @@ var asinh = &ScalarDefinition{ name: "asinh", arity: 1, callFn: func(args ...types.Value) (types.Value, error) { - v, err := object.CastAs(args[0], types.TypeDouble) + v, err := args[0].CastAs(types.TypeDouble) if err != nil || v.Type() == types.TypeNull { return v, err } @@ -139,7 +123,7 @@ var atan = &ScalarDefinition{ name: "atan", arity: 1, callFn: func(args ...types.Value) (types.Value, error) { - v, err := object.CastAs(args[0], types.TypeDouble) + v, err := args[0].CastAs(types.TypeDouble) if err != nil || v.Type() == types.TypeNull { return v, err } @@ -153,12 +137,12 @@ var atan2 = &ScalarDefinition{ name: "atan2", arity: 2, callFn: func(args ...types.Value) (types.Value, error) { - vA, err := object.CastAs(args[0], types.TypeDouble) + vA, err := args[0].CastAs(types.TypeDouble) if err != nil || vA.Type() == types.TypeNull { return vA, err } vvA := types.AsFloat64(vA) - vB, err := object.CastAs(args[1], types.TypeDouble) + vB, err := args[1].CastAs(types.TypeDouble) if err != nil || vB.Type() == types.TypeNull { return vB, err } @@ -173,7 +157,7 @@ var random = &ScalarDefinition{ arity: 0, callFn: func(args ...types.Value) (types.Value, error) { randomNum := rand.Int63() - return types.NewIntegerValue(randomNum), nil + return types.NewBigintValue(randomNum), nil }, } @@ -184,7 +168,7 @@ var sqrt = &ScalarDefinition{ if args[0].Type() != types.TypeDouble && args[0].Type() != types.TypeInteger { return types.NewNullValue(), nil } - v, err := object.CastAs(args[0], 
types.TypeDouble) + v, err := args[0].CastAs(types.TypeDouble) if err != nil { return nil, err } diff --git a/internal/expr/functions/object.go b/internal/expr/functions/object.go deleted file mode 100644 index 34cfe59bc..000000000 --- a/internal/expr/functions/object.go +++ /dev/null @@ -1,73 +0,0 @@ -package functions - -import ( - "fmt" - - "github.com/chaisql/chai/internal/environment" - "github.com/chaisql/chai/internal/expr" - "github.com/chaisql/chai/internal/object" - "github.com/chaisql/chai/internal/types" -) - -var objectsFunctions = Definitions{ - "fields": &definition{ - name: "fields", - arity: 1, - constructorFn: func(args ...expr.Expr) (expr.Function, error) { - return &ObjectFields{Expr: args[0]}, nil - }, - }, -} - -func ObjectsDefinitions() Definitions { - return objectsFunctions -} - -// ObjectFields implements the objects.fields function -// which returns the list of top-level fields of an object. -// If the argument is not an object, it returns null. -type ObjectFields struct { - Expr expr.Expr -} - -func (s *ObjectFields) Eval(env *environment.Environment) (types.Value, error) { - val, err := s.Expr.Eval(env) - if err != nil { - return nil, err - } - - if val.Type() != types.TypeObject { - return types.NewNullValue(), nil - } - - obj := types.AsObject(val) - var fields []string - err = obj.Iterate(func(k string, _ types.Value) error { - fields = append(fields, k) - return nil - }) - if err != nil { - return nil, err - } - - return types.NewArrayValue(object.NewArrayFromSlice(fields)), nil -} - -func (s *ObjectFields) IsEqual(other expr.Expr) bool { - if other == nil { - return false - } - - o, ok := other.(*ObjectFields) - if !ok { - return false - } - - return expr.Equal(s.Expr, o.Expr) -} - -func (s *ObjectFields) Params() []expr.Expr { return []expr.Expr{s.Expr} } - -func (s *ObjectFields) String() string { - return fmt.Sprintf("objects.fields(%v)", s.Expr) -} diff --git a/internal/expr/functions/scalar_definition.go b/internal/expr/functions/scalar_definition.go index 729756b3c..d8afa2685 100644 --- a/internal/expr/functions/scalar_definition.go +++ b/internal/expr/functions/scalar_definition.go @@ -12,7 +12,7 @@ import ( // A ScalarDefinition is the definition type for functions which operates on scalar values in contrast to other SQL functions // such as the SUM aggregator which operates on expressions instead. // -// This difference allows to simply define them with a CallFn function that takes multiple object.Value and +// This difference allows to simply define them with a CallFn function that takes multiple row.Value and // return another types.Value, rather than having to manually evaluate expressions (see Definition). type ScalarDefinition struct { name string @@ -49,11 +49,6 @@ func (fd *ScalarDefinition) Function(args ...expr.Expr) (expr.Function, error) { }, nil } -// Arity returns the arity of the defined function. -func (fd *ScalarDefinition) Arity() int { - return fd.arity -} - // A ScalarFunction is a function which operates on scalar values in contrast to other SQL functions // such as the SUM aggregator wich operates on expressions instead. type ScalarFunction struct { @@ -61,7 +56,7 @@ type ScalarFunction struct { params []expr.Expr } -// Eval returns a object.Value based on the given environment and the underlying function +// Eval returns a row.Value based on the given environment and the underlying function // definition. 
func (sf *ScalarFunction) Eval(env *environment.Environment) (types.Value, error) { args, err := sf.evalParams(env) diff --git a/internal/expr/functions/scalar_definition_test.go b/internal/expr/functions/scalar_definition_test.go index 01ec3e21a..926729381 100644 --- a/internal/expr/functions/scalar_definition_test.go +++ b/internal/expr/functions/scalar_definition_test.go @@ -7,7 +7,7 @@ import ( "github.com/chaisql/chai/internal/environment" "github.com/chaisql/chai/internal/expr" "github.com/chaisql/chai/internal/expr/functions" - "github.com/chaisql/chai/internal/object" + "github.com/chaisql/chai/internal/row" "github.com/chaisql/chai/internal/testutil/assert" "github.com/chaisql/chai/internal/types" "github.com/stretchr/testify/require" @@ -18,9 +18,9 @@ func TestScalarFunctionDef(t *testing.T) { "foo", 3, func(args ...types.Value) (types.Value, error) { - arg1 := args[0].V().(int64) - arg2 := args[1].V().(int64) - arg3 := args[2].V().(int64) + arg1 := args[0].V().(int32) + arg2 := args[1].V().(int32) + arg3 := args[2].V().(int32) return types.NewIntegerValue(arg1 + arg2 + arg3), nil }, @@ -30,21 +30,17 @@ func TestScalarFunctionDef(t *testing.T) { require.Equal(t, "foo", def.Name()) }) - t.Run("Arity()", func(t *testing.T) { - require.Equal(t, 3, def.Arity()) - }) - t.Run("String()", func(t *testing.T) { require.Equal(t, "foo(arg1, arg2, arg3)", def.String()) }) t.Run("Function()", func(t *testing.T) { - fb := object.NewFieldBuffer() + fb := row.NewColumnBuffer() fb = fb.Add("a", types.NewIntegerValue(2)) r := database.NewBasicRow(fb) env := environment.New(r) expr1 := expr.Add(expr.LiteralValue{Value: types.NewIntegerValue(1)}, expr.LiteralValue{Value: types.NewIntegerValue(0)}) - expr2 := expr.Path(object.NewPath("a")) + expr2 := expr.Column("a") expr3 := expr.Div(expr.LiteralValue{Value: types.NewIntegerValue(6)}, expr.LiteralValue{Value: types.NewIntegerValue(2)}) t.Run("OK", func(t *testing.T) { diff --git a/internal/expr/functions/strings.go b/internal/expr/functions/strings.go index 29440f6b8..45dbdb62a 100644 --- a/internal/expr/functions/strings.go +++ b/internal/expr/functions/strings.go @@ -9,48 +9,6 @@ import ( "github.com/chaisql/chai/internal/types" ) -var stringsFunctions = Definitions{ - "lower": &definition{ - name: "lower", - arity: 1, - constructorFn: func(args ...expr.Expr) (expr.Function, error) { - return &Lower{Expr: args[0]}, nil - }, - }, - "upper": &definition{ - name: "upper", - arity: 1, - constructorFn: func(args ...expr.Expr) (expr.Function, error) { - return &Upper{Expr: args[0]}, nil - }, - }, - "trim": &definition{ - name: "trim", - arity: variadicArity, - constructorFn: func(args ...expr.Expr) (expr.Function, error) { - return &Trim{Expr: args, TrimFunc: strings.Trim, Name: "TRIM"}, nil - }, - }, - "ltrim": &definition{ - name: "ltrim", - arity: variadicArity, - constructorFn: func(args ...expr.Expr) (expr.Function, error) { - return &Trim{Expr: args, TrimFunc: strings.TrimLeft, Name: "LTRIM"}, nil - }, - }, - "rtrim": &definition{ - name: "rtrim", - arity: variadicArity, - constructorFn: func(args ...expr.Expr) (expr.Function, error) { - return &Trim{Expr: args, TrimFunc: strings.TrimRight, Name: "RTRIM"}, nil - }, - }, -} - -func StringsDefinitions() Definitions { - return stringsFunctions -} - // Lower is the LOWER function // It returns the lower-case version of a string type Lower struct { diff --git a/internal/expr/functions/testdata/builtin_functions.sql b/internal/expr/functions/testdata/builtin_functions.sql index 089784ccc..721263593 
100644 --- a/internal/expr/functions/testdata/builtin_functions.sql +++ b/internal/expr/functions/testdata/builtin_functions.sql @@ -2,7 +2,7 @@ ! typeof() ! typeof(a) -'field not found' +'no table specified' > typeof(1) 'integer' @@ -25,12 +25,6 @@ > typeof('\xAA') 'blob' -> typeof([]) -'array' - -> typeof({}) -'object' - > typeof(NULL) 'null' diff --git a/internal/expr/functions/testdata/math_functions.sql b/internal/expr/functions/testdata/math_functions.sql index 49985abfa..3bf1f8a10 100644 --- a/internal/expr/functions/testdata/math_functions.sql +++ b/internal/expr/functions/testdata/math_functions.sql @@ -1,137 +1,137 @@ --- test: math.floor -> math.floor(2.3) +-- test: floor +> floor(2.3) 2.0 -> math.floor(2) +> floor(2) 2 -! math.floor('a') +! floor('a') 'floor(arg1) expects arg1 to be a number' --- test: math.abs -> math.abs(NULL) +-- test: abs +> abs(NULL) NULL -> math.abs(-2) +> abs(-2) 2 -> math.abs(-2.0) +> abs(-2.0) 2.0 -> math.abs('-2.0') +> abs('-2.0') 2.0 -! math.abs('foo') +! abs('foo') 'cannot cast "foo" as double' -! math.abs(-9223372036854775808) +! abs(-9223372036854775808) 'integer out of range' --- test: math.acos -> math.acos(NULL) +-- test: acos +> acos(NULL) NULL -> math.acos(1) +> acos(1) 0.0 -> math.acos(0.5) +> acos(0.5) 1.0471975511965976 -> math.acos('0.5') +> acos('0.5') 1.0471975511965976 -! math.acos(2) +! acos(2) 'out of range' -! math.acos(-2) +! acos(-2) 'out of range' -! math.acos(2.2) +! acos(2.2) 'out of range' -! math.acos(-2.2) +! acos(-2.2) 'out of range' -! math.acos('foo') +! acos('foo') 'cannot cast "foo" as double' --- test: math.acosh -> math.acosh(NULL) +-- test: acosh +> acosh(NULL) NULL -> math.acosh(1) +> acosh(1) 0.0 -> math.acosh(2) +> acosh(2) 1.3169578969248166 -> math.acosh('2') +> acosh('2') 1.3169578969248166 -> math.acosh(2.5) +> acosh(2.5) 1.566799236972411 -! math.acosh(0) +! acosh(0) 'out of range' -! math.acosh(0.99999999) +! acosh(0.99999999) 'out of range' -! math.acosh('foo') +! acosh('foo') 'cannot cast "foo" as double' --- test: math.asin -> math.asin(NULL) +-- test: asin +> asin(NULL) NULL -> math.asin(0) +> asin(0) 0.0 -> math.asin(0.5) +> asin(0.5) 0.5235987755982989 -! math.asin(2) +! asin(2) 'out of range' -! math.asin(-2) +! asin(-2) 'out of range' -! math.asin(2.2) +! asin(2.2) 'out of range' -! math.asin(-2.2) +! asin(-2.2) 'out of range' -! math.asin('foo') +! asin('foo') 'cannot cast "foo" as double' --- test: math.asinh -> math.asinh(NULL) +-- test: asinh +> asinh(NULL) NULL -> math.asinh(0) +> asinh(0) 0.0 -> math.asinh(0.5) +> asinh(0.5) 0.48121182505960347 -> math.asinh(1) +> asinh(1) 0.881373587019543 -> math.asinh(-1) +> asinh(-1) -0.881373587019543 -! math.asinh('foo') +! asinh('foo') 'cannot cast "foo" as double' --- test: math.atan -> math.atan(NULL) +-- test: atan +> atan(NULL) NULL -> math.atan(0) +> atan(0) 0.0 -> math.atan(0.5) +> atan(0.5) 0.4636476090008061 -> math.atan(1) +> atan(1) 0.7853981633974483 -> math.atan(-1) +> atan(-1) -0.7853981633974483 -! math.atan('foo') +! atan('foo') 'cannot cast "foo" as double' --- test: math.atan2 -> math.atan2(NULL, NULL) +-- test: atan2 +> atan2(NULL, NULL) NULL -> math.atan2(1, NULL) +> atan2(1, NULL) NULL -> math.atan2(NULL, 1) +> atan2(NULL, 1) NULL -> math.atan2(0, 0) +> atan2(0, 0) 0.0 -> math.atan2(1, 1) +> atan2(1, 1) 0.7853981633974483 -> math.atan2(1.1, 1.1) +> atan2(1.1, 1.1) 0.7853981633974483 -> math.atan2(1.1, -1.1) +> atan2(1.1, -1.1) 2.356194490192345 -! math.atan2('foo', 1) +! 
atan2('foo', 1) 'cannot cast "foo" as double' --- test: math.sqrt -> math.sqrt(NULL) +-- test: sqrt +> sqrt(NULL) NULL -> math.sqrt(4) +> sqrt(4) 2.0 -> math.sqrt(81) +> sqrt(81) 9.0 -> math.sqrt(15) +> sqrt(15) 3.872983346207417 -> math.sqrt(1.1) +> sqrt(1.1) 1.0488088481701516 -> math.sqrt('foo') +> sqrt('foo') NULL \ No newline at end of file diff --git a/internal/expr/literal.go b/internal/expr/literal.go index c19d932a2..044bfab5a 100644 --- a/internal/expr/literal.go +++ b/internal/expr/literal.go @@ -1,12 +1,10 @@ package expr import ( - "fmt" "strings" "github.com/chaisql/chai/internal/environment" - "github.com/chaisql/chai/internal/object" - "github.com/chaisql/chai/internal/stringutil" + "github.com/chaisql/chai/internal/row" "github.com/chaisql/chai/internal/types" ) @@ -41,11 +39,7 @@ type LiteralExprList []Expr // IsEqual compares this expression with the other expression and returns // true if they are equal. -func (l LiteralExprList) IsEqual(other Expr) bool { - o, ok := other.(LiteralExprList) - if !ok { - return false - } +func (l LiteralExprList) IsEqual(o LiteralExprList) bool { if len(l) != len(o) { return false } @@ -63,115 +57,59 @@ func (l LiteralExprList) IsEqual(other Expr) bool { func (l LiteralExprList) String() string { var b strings.Builder - b.WriteRune('[') + b.WriteRune('(') for i, e := range l { if i > 0 { b.WriteString(", ") } b.WriteString(e.String()) } - b.WriteRune(']') + b.WriteRune(')') return b.String() } -// Eval evaluates all the expressions and returns a literalValueList. It implements the Expr interface. func (l LiteralExprList) Eval(env *environment.Environment) (types.Value, error) { + panic("not implemented") +} + +// Eval evaluates all the expressions and returns a literalValueList. It implements the Expr interface. +func (l LiteralExprList) EvalAll(env *environment.Environment) ([]types.Value, error) { var err error if len(l) == 0 { - return types.NewArrayValue(object.NewValueBuffer()), nil + return nil, nil } values := make([]types.Value, len(l)) for i, e := range l { values[i], err = e.Eval(env) if err != nil { - return NullLiteral, err + return nil, err } } - return types.NewArrayValue(object.NewValueBuffer(values...)), nil + return values, nil } -// KVPair associates an identifier with an expression. -type KVPair struct { - K string - V Expr +type Row struct { + Columns []string + Exprs []Expr } -// String implements the fmt.Stringer interface. -func (p KVPair) String() string { - if stringutil.NeedsQuotes(p.K) { - return fmt.Sprintf("%q: %v", p.K, p.V) - } - return fmt.Sprintf("%s: %v", p.K, p.V) -} +func (r *Row) Eval(env *environment.Environment) (row.Row, error) { + var cb row.ColumnBuffer -// KVPairs is a list of KVPair. -type KVPairs struct { - Pairs []KVPair - SelfReferenced bool -} - -// IsEqual compares this expression with the other expression and returns -// true if they are equal. -func (kvp *KVPairs) IsEqual(other Expr) bool { - o, ok := other.(*KVPairs) - if !ok { - return false - } - if kvp.SelfReferenced != o.SelfReferenced { - return false - } - - if len(kvp.Pairs) != len(o.Pairs) { - return false - } - - for i := range kvp.Pairs { - if kvp.Pairs[i].K != o.Pairs[i].K { - return false - } - if !Equal(kvp.Pairs[i].V, o.Pairs[i].V) { - return false - } - } - - return true -} - -// Eval turns a list of KVPairs into an object. 
-func (kvp *KVPairs) Eval(env *environment.Environment) (types.Value, error) { - var fb object.FieldBuffer - if kvp.SelfReferenced { - if _, ok := env.GetRow(); !ok { - env.SetRowFromObject(&fb) - } - } - - for _, kv := range kvp.Pairs { - v, err := kv.V.Eval(env) + for i, e := range r.Exprs { + v, err := e.Eval(env) if err != nil { return nil, err } - fb.Add(kv.K, v) + cb.Set(r.Columns[i], v) } - return types.NewObjectValue(&fb), nil + return &cb, nil } -// String implements the fmt.Stringer interface. -func (kvp *KVPairs) String() string { - var b strings.Builder - - b.WriteRune('{') - for i, p := range kvp.Pairs { - if i > 0 { - b.WriteString(", ") - } - b.WriteString(p.String()) - } - b.WriteRune('}') - - return b.String() +func (r *Row) String() string { + return LiteralExprList(r.Exprs).String() } diff --git a/internal/expr/operator.go b/internal/expr/operator.go index 3d5eea379..a8aadddaf 100644 --- a/internal/expr/operator.go +++ b/internal/expr/operator.go @@ -4,7 +4,6 @@ import ( "fmt" "github.com/chaisql/chai/internal/environment" - "github.com/chaisql/chai/internal/object" "github.com/chaisql/chai/internal/sql/scanner" "github.com/chaisql/chai/internal/types" "github.com/cockroachdb/errors" @@ -117,14 +116,14 @@ type Cast struct { CastAs types.Type } -// Eval returns the primary key of the current object. +// Eval returns the primary key of the current row. func (c Cast) Eval(env *environment.Environment) (types.Value, error) { v, err := c.Expr.Eval(env) if err != nil { return v, err } - return object.CastAs(v, c.CastAs) + return v.CastAs(c.CastAs) } // IsEqual compares this expression with the other expression and returns diff --git a/internal/expr/operator_test.go b/internal/expr/operator_test.go index 13a7abf26..abd37c08e 100644 --- a/internal/expr/operator_test.go +++ b/internal/expr/operator_test.go @@ -15,13 +15,13 @@ func TestConcatExpr(t *testing.T) { }{ {"'a' || 'b'", types.NewTextValue("ab"), false}, {"'a' || NULL", nullLiteral, false}, - {"'a' || notFound", nullLiteral, false}, + {"'a' || notFound", nullLiteral, true}, {"'a' || 1", nullLiteral, false}, } for _, test := range tests { t.Run(test.expr, func(t *testing.T) { - testutil.TestExpr(t, test.expr, envWithDoc, test.res, test.fails) + testutil.TestExpr(t, test.expr, envWithRow, test.res, test.fails) }) } } diff --git a/internal/expr/path.go b/internal/expr/path.go deleted file mode 100644 index 93c4847b0..000000000 --- a/internal/expr/path.go +++ /dev/null @@ -1,82 +0,0 @@ -package expr - -import ( - "github.com/chaisql/chai/internal/environment" - "github.com/chaisql/chai/internal/object" - "github.com/chaisql/chai/internal/types" - "github.com/cockroachdb/errors" -) - -// A Path is an expression that extracts a value from a object at a given path. -type Path object.Path - -// Eval extracts the current value from the environment and returns the value stored at p. -// It implements the Expr interface. -func (p Path) Eval(env *environment.Environment) (types.Value, error) { - if len(p) == 0 { - return NullLiteral, nil - } - - r, ok := env.GetRow() - if !ok { - return NullLiteral, types.ErrFieldNotFound - } - dp := object.Path(p) - - v, ok := env.Get(dp) - if ok { - return v, nil - } - - v, err := dp.GetValueFromObject(r.Object()) - if errors.Is(err, types.ErrFieldNotFound) { - return NullLiteral, nil - } - - return v, err -} - -// IsEqual compares this expression with the other expression and returns -// true if they are equal. 
-func (p Path) IsEqual(other Expr) bool { - if other == nil { - return false - } - - o, ok := other.(Path) - if !ok { - return false - } - - return object.Path(p).IsEqual(object.Path(o)) -} - -func (p Path) String() string { - return object.Path(p).String() -} - -// A Wildcard is an expression that iterates over all the fields of a object. -type Wildcard struct{} - -func (w Wildcard) String() string { - return "*" -} - -func (w Wildcard) Eval(env *environment.Environment) (types.Value, error) { - r, ok := env.GetRow() - if !ok { - return nil, errors.New("no table specified") - } - - return types.NewObjectValue(r.Object()), nil -} - -// Iterate call the object iterate method. -func (w Wildcard) Iterate(env environment.Environment, fn func(field string, value types.Value) error) error { - r, ok := env.GetRow() - if !ok { - return errors.New("no table specified") - } - - return r.Iterate(fn) -} diff --git a/internal/expr/path_test.go b/internal/expr/path_test.go deleted file mode 100644 index 98add3a1f..000000000 --- a/internal/expr/path_test.go +++ /dev/null @@ -1,126 +0,0 @@ -package expr_test - -import ( - "encoding/json" - "fmt" - "testing" - - "github.com/chaisql/chai/internal/database" - "github.com/chaisql/chai/internal/environment" - "github.com/chaisql/chai/internal/expr" - "github.com/chaisql/chai/internal/object" - "github.com/chaisql/chai/internal/sql/parser" - "github.com/chaisql/chai/internal/testutil" - "github.com/chaisql/chai/internal/testutil/assert" - "github.com/chaisql/chai/internal/types" - "github.com/stretchr/testify/require" -) - -func TestPathExpr(t *testing.T) { - tests := []struct { - expr string - res types.Value - fails bool - }{ - {"a", types.NewIntegerValue(1), false}, - {"b", func() types.Value { - fb := object.NewFieldBuffer() - err := json.Unmarshal([]byte(`{"foo bar": [1, 2]}`), fb) - assert.NoError(t, err) - return types.NewObjectValue(fb) - }(), - false}, - {"b.`foo bar`[0]", types.NewIntegerValue(1), false}, - {"_v.b.`foo bar`[0]", types.NewNullValue(), false}, - {"b.`foo bar`[1]", types.NewIntegerValue(2), false}, - {"b.`foo bar`[2]", nullLiteral, false}, - {"b[0]", nullLiteral, false}, - {"c[0]", types.NewIntegerValue(1), false}, - {"c[1].foo", types.NewTextValue("bar"), false}, - {"c.foo", nullLiteral, false}, - {"d", nullLiteral, false}, - } - - r := database.NewBasicRow(object.NewFromJSON([]byte(`{ - "a": 1, - "b": {"foo bar": [1, 2]}, - "c": [1, {"foo": "bar"}, [1, 2]] - }`))) - - for _, test := range tests { - t.Run(test.expr, func(t *testing.T) { - testutil.TestExpr(t, test.expr, environment.New(r), test.res, test.fails) - }) - } - - t.Run("empty env", func(t *testing.T) { - testutil.TestExpr(t, "a", &environment.Environment{}, nullLiteral, true) - }) -} - -func TestPathIsEqual(t *testing.T) { - tests := []struct { - a, b string - isEqual bool - }{ - {`a`, `a`, true}, - {`a[0].b`, `a[0].b`, true}, - {`a[0].b`, `a[1].b`, false}, - } - - for _, test := range tests { - t.Run(fmt.Sprintf("%s = %s", test.a, test.b), func(t *testing.T) { - pa, err := parser.ParsePath(test.a) - assert.NoError(t, err) - ea := expr.Path(pa) - - pb, err := parser.ParsePath(test.b) - assert.NoError(t, err) - eb := expr.Path(pb) - - require.Equal(t, test.isEqual, ea.IsEqual(eb)) - }) - } -} - -func TestEnvPathExpr(t *testing.T) { - tests := []struct { - expr string - res types.Value - fails bool - }{ - {"a", types.NewIntegerValue(1), false}, - {"b", func() types.Value { - fb := object.NewFieldBuffer() - err := json.Unmarshal([]byte(`{"foo bar": [1, 2]}`), fb) - 
assert.NoError(t, err) - return types.NewObjectValue(fb) - }(), - false}, - {"b.`foo bar`[0]", types.NewIntegerValue(1), false}, - {"_v.b.`foo bar`[0]", types.NewNullValue(), false}, - {"b.`foo bar`[1]", types.NewIntegerValue(2), false}, - {"b.`foo bar`[2]", nullLiteral, false}, - {"b[0]", nullLiteral, false}, - {"c[0]", types.NewIntegerValue(1), false}, - {"c[1].foo", types.NewTextValue("bar"), false}, - {"c.foo", nullLiteral, false}, - {"d", nullLiteral, false}, - } - - r := database.NewBasicRow(object.NewFromJSON([]byte(`{ - "a": 1, - "b": {"foo bar": [1, 2]}, - "c": [1, {"foo": "bar"}, [1, 2]] - }`))) - - for _, test := range tests { - t.Run(test.expr, func(t *testing.T) { - testutil.TestExpr(t, test.expr, environment.New(r), test.res, test.fails) - }) - } - - t.Run("empty env", func(t *testing.T) { - testutil.TestExpr(t, "a", &environment.Environment{}, nullLiteral, true) - }) -} diff --git a/internal/expr/wildcard.go b/internal/expr/wildcard.go new file mode 100644 index 000000000..f2557946b --- /dev/null +++ b/internal/expr/wildcard.go @@ -0,0 +1,28 @@ +package expr + +import ( + "github.com/chaisql/chai/internal/environment" + "github.com/chaisql/chai/internal/types" + "github.com/cockroachdb/errors" +) + +// A Wildcard is an expression that iterates over all the columns of a row. +type Wildcard struct{} + +func (w Wildcard) String() string { + return "*" +} + +func (w Wildcard) Eval(env *environment.Environment) (types.Value, error) { + panic("not implemented") +} + +// Iterate call the object iterate method. +func (w Wildcard) Iterate(env environment.Environment, fn func(field string, value types.Value) error) error { + r, ok := env.GetRow() + if !ok { + return errors.New("no table specified") + } + + return r.Iterate(fn) +} diff --git a/internal/kv/session_test.go b/internal/kv/session_test.go index 60ba3be95..a6d9ed267 100644 --- a/internal/kv/session_test.go +++ b/internal/kv/session_test.go @@ -251,7 +251,7 @@ func TestQueries(t *testing.T) { assert.NoError(t, err) r, err := db.QueryRow(` - CREATE TABLE test; + CREATE TABLE test(a INT); INSERT INTO test (a) VALUES (1), (2), (3), (4); SELECT COUNT(*) FROM test; `) @@ -286,7 +286,7 @@ func TestQueries(t *testing.T) { assert.NoError(t, err) err = db.Exec(` - CREATE TABLE test; + CREATE TABLE test(a INT); INSERT INTO test (a) VALUES (1), (2), (3), (4); `) assert.NoError(t, err) @@ -299,7 +299,7 @@ func TestQueries(t *testing.T) { assert.NoError(t, err) st, err := db.Query(` - CREATE TABLE test; + CREATE TABLE test(a INT); INSERT INTO test (a) VALUES (1), (2), (3), (4); UPDATE test SET a = 5; SELECT * FROM test; @@ -318,7 +318,7 @@ func TestQueries(t *testing.T) { db, err := chai.Open(filepath.Join(dir, "pebble")) assert.NoError(t, err) - err = db.Exec("CREATE TABLE test") + err = db.Exec("CREATE TABLE test(a INT)") assert.NoError(t, err) err = db.Update(func(tx *chai.Tx) error { @@ -352,7 +352,7 @@ func TestQueriesSameTransaction(t *testing.T) { err = db.Update(func(tx *chai.Tx) error { r, err := tx.QueryRow(` - CREATE TABLE test; + CREATE TABLE test(a INT); INSERT INTO test (a) VALUES (1), (2), (3), (4); SELECT COUNT(*) FROM test; `) @@ -374,7 +374,7 @@ func TestQueriesSameTransaction(t *testing.T) { err = db.Update(func(tx *chai.Tx) error { err = tx.Exec(` - CREATE TABLE test; + CREATE TABLE test(a INT); INSERT INTO test (a) VALUES (1), (2), (3), (4); `) assert.NoError(t, err) @@ -391,7 +391,7 @@ func TestQueriesSameTransaction(t *testing.T) { err = db.Update(func(tx *chai.Tx) error { st, err := tx.Query(` - CREATE TABLE test; 
+ CREATE TABLE test(a INT); INSERT INTO test (a) VALUES (1), (2), (3), (4); UPDATE test SET a = 5; SELECT * FROM test; @@ -415,7 +415,7 @@ func TestQueriesSameTransaction(t *testing.T) { err = db.Update(func(tx *chai.Tx) error { r, err := tx.QueryRow(` - CREATE TABLE test; + CREATE TABLE test(a INT); INSERT INTO test (a) VALUES (1), (2), (3), (4), (5), (6), (7), (8), (9), (10); DELETE FROM test WHERE a > 2; SELECT COUNT(*) FROM test; diff --git a/internal/object/array.go b/internal/object/array.go deleted file mode 100644 index 47578ccc9..000000000 --- a/internal/object/array.go +++ /dev/null @@ -1,210 +0,0 @@ -package object - -import ( - "github.com/buger/jsonparser" - "github.com/chaisql/chai/internal/types" - "github.com/cockroachdb/errors" -) - -// ArrayLength returns the length of an array. -func ArrayLength(a types.Array) (int, error) { - if vb, ok := a.(*ValueBuffer); ok { - return len(vb.Values), nil - } - - var len int - err := a.Iterate(func(_ int, _ types.Value) error { - len++ - return nil - }) - return len, err -} - -var errStop = errors.New("stop") - -// ArrayContains iterates over a and returns whether v is equal to one of its values. -func ArrayContains(a types.Array, v types.Value) (bool, error) { - var found bool - - err := a.Iterate(func(i int, vv types.Value) error { - ok, err := vv.EQ(v) - if err != nil { - return err - } - if ok { - found = true - return errStop - } - - return nil - }) - - if err != nil && !errors.Is(err, errStop) { - return false, err - } - - return found, nil -} - -// ValueBuffer is an array that holds values in memory. -type ValueBuffer struct { - Values []types.Value -} - -// NewValueBuffer creates a buffer of values. -func NewValueBuffer(values ...types.Value) *ValueBuffer { - return &ValueBuffer{Values: values} -} - -// Iterate over all the values of the buffer. It implements the Array interface. -func (vb *ValueBuffer) Iterate(fn func(i int, value types.Value) error) error { - for i, v := range vb.Values { - err := fn(i, v) - if err != nil { - return err - } - } - - return nil -} - -// GetByIndex returns a value set at the given index. If the index is out of range it returns an error. -func (vb *ValueBuffer) GetByIndex(i int) (types.Value, error) { - if i >= len(vb.Values) { - return nil, types.ErrFieldNotFound - } - - return vb.Values[i], nil -} - -// Len returns the length the of array -func (vb *ValueBuffer) Len() int { - if vb == nil { - return 0 - } - - return len(vb.Values) -} - -// Append a value to the buffer and return a new buffer. -func (vb *ValueBuffer) Append(v types.Value) *ValueBuffer { - vb.Values = append(vb.Values, v) - return vb -} - -// ScanArray copies all the values of a to the buffer. -func (vb *ValueBuffer) ScanArray(a types.Array) error { - return a.Iterate(func(i int, v types.Value) error { - vb = vb.Append(v) - return nil - }) -} - -// Copy deep copies all the values from the given array. -// If a value is an object or an array, it will be stored as a *FieldBuffer or *ValueBuffer respectively. -func (vb *ValueBuffer) Copy(a types.Array) error { - return a.Iterate(func(i int, value types.Value) error { - v, err := CloneValue(value) - if err != nil { - return err - } - vb.Append(v) - return nil - }) -} - -// Reset the buffer. -func (vb *ValueBuffer) Reset() { - vb.Values = vb.Values[:0] -} - -// Apply a function to all the values of the buffer. 
-func (vb *ValueBuffer) Apply(fn func(p Path, v types.Value) (types.Value, error)) error { - path := Path{PathFragment{}} - - for i, v := range vb.Values { - path[0].ArrayIndex = i - - switch v.Type() { - case types.TypeObject: - buf, ok := types.Is[*FieldBuffer](v) - if !ok { - buf = NewFieldBuffer() - err := buf.Copy(types.AsObject(v)) - if err != nil { - return err - } - } - - err := buf.Apply(func(p Path, v types.Value) (types.Value, error) { - return fn(append(path, p...), v) - }) - if err != nil { - return err - } - vb.Values[i] = types.NewObjectValue(buf) - case types.TypeArray: - buf, ok := types.Is[*ValueBuffer](v) - if !ok { - buf = NewValueBuffer() - err := buf.Copy(types.AsArray(v)) - if err != nil { - return err - } - } - - err := buf.Apply(func(p Path, v types.Value) (types.Value, error) { - return fn(append(path, p...), v) - }) - if err != nil { - return err - } - vb.Values[i] = types.NewArrayValue(buf) - default: - var err error - v, err = fn(path, v) - if err != nil { - return err - } - vb.Values[i] = v - } - } - - return nil -} - -// Replace the value of the index by v. -func (vb *ValueBuffer) Replace(index int, v types.Value) error { - if len(vb.Values) <= index { - return types.ErrFieldNotFound - } - - vb.Values[index] = v - return nil -} - -// MarshalJSON implements the json.Marshaler interface. -func (vb ValueBuffer) MarshalJSON() ([]byte, error) { - return MarshalJSONArray(&vb) -} - -// UnmarshalJSON implements the json.Unmarshaler interface. -func (vb *ValueBuffer) UnmarshalJSON(data []byte) error { - var err error - _, perr := jsonparser.ArrayEach(data, func(value []byte, dataType jsonparser.ValueType, offset int, _ error) { - v, err := parseJSONValue(dataType, value) - if err != nil { - return - } - - vb.Values = append(vb.Values, v) - }) - if err != nil { - return err - } - if perr != nil { - return perr - } - - return nil -} diff --git a/internal/object/array_test.go b/internal/object/array_test.go deleted file mode 100644 index 58a567f06..000000000 --- a/internal/object/array_test.go +++ /dev/null @@ -1,69 +0,0 @@ -package object_test - -import ( - "encoding/json" - "testing" - - "github.com/chaisql/chai/internal/object" - "github.com/chaisql/chai/internal/testutil/assert" - "github.com/chaisql/chai/internal/types" - "github.com/stretchr/testify/require" -) - -func TestArrayContains(t *testing.T) { - arr := object.NewValueBuffer( - types.NewIntegerValue(1), - types.NewTextValue("foo"), - types.NewBlobValue([]byte{1, 2, 3}), - ) - - ok, err := object.ArrayContains(arr, types.NewDoubleValue(1)) - assert.NoError(t, err) - require.True(t, ok) - - ok, err = object.ArrayContains(arr, types.NewTextValue("foo")) - assert.NoError(t, err) - require.True(t, ok) - - ok, err = object.ArrayContains(arr, types.NewTextValue("bar")) - assert.NoError(t, err) - require.False(t, ok) -} - -func TestValueBufferCopy(t *testing.T) { - tests := []struct { - name string - want string - }{ - {"empty array", `[]`}, - {"flat", `[1.4,-5,"hello",true]`}, - {"nested", `[["foo","bar",1],{"a":1},[1,2]]`}, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - var from, to object.ValueBuffer - assert.NoError(t, from.UnmarshalJSON([]byte(test.want))) - err := to.Copy(&from) - assert.NoError(t, err) - got, err := json.Marshal(to) - assert.NoError(t, err) - require.Equal(t, test.want, string(got)) - }) - } -} - -func TestValueBufferApply(t *testing.T) { - var buf object.ValueBuffer - err := buf.UnmarshalJSON([]byte(`[1, [1, 3], {"4": 5}]`)) - assert.NoError(t, err) - - err = 
buf.Apply(func(p object.Path, v types.Value) (types.Value, error) { - return types.NewIntegerValue(6), nil - }) - assert.NoError(t, err) - - got, err := json.Marshal(buf) - assert.NoError(t, err) - require.JSONEq(t, `[6, [6, 6], {"4": 6}]`, string(got)) -} diff --git a/internal/object/cast.go b/internal/object/cast.go deleted file mode 100644 index e082b51ee..000000000 --- a/internal/object/cast.go +++ /dev/null @@ -1,267 +0,0 @@ -package object - -import ( - "encoding/base64" - "fmt" - "math" - "strconv" - "time" - - "github.com/chaisql/chai/internal/types" -) - -// CastAs casts v as the selected type when possible. -func CastAs(v types.Value, t types.Type) (types.Value, error) { - if v.Type() == t { - return v, nil - } - - switch t { - case types.TypeBoolean: - return CastAsBool(v) - case types.TypeInteger: - return CastAsInteger(v) - case types.TypeDouble: - return CastAsDouble(v) - case types.TypeTimestamp: - return CastAsTimestamp(v) - case types.TypeBlob: - return CastAsBlob(v) - case types.TypeText: - return CastAsText(v) - case types.TypeArray: - return CastAsArray(v) - case types.TypeObject: - return CastAsObject(v) - } - - return nil, fmt.Errorf("cannot cast %s as %q", v.Type(), t) -} - -// CastAsBool casts according to the following rules: -// Integer: true if truthy, otherwise false. -// Text: uses strconv.Parsebool to determine the boolean value, -// it fails if the text doesn't contain a valid boolean. -// Any other type is considered an invalid cast. -func CastAsBool(v types.Value) (types.Value, error) { - // Null values always remain null. - if v.Type() == types.TypeNull { - return v, nil - } - - switch v.Type() { - case types.TypeBoolean: - return v, nil - case types.TypeInteger: - return types.NewBooleanValue(types.AsInt64(v) != 0), nil - case types.TypeText: - b, err := strconv.ParseBool(types.AsString(v)) - if err != nil { - return nil, fmt.Errorf(`cannot cast %q as bool: %w`, v.V(), err) - } - return types.NewBooleanValue(b), nil - } - - return nil, fmt.Errorf("cannot cast %s as bool", v.Type()) -} - -// CastAsInteger casts according to the following rules: -// Bool: returns 1 if true, 0 if false. -// Double: cuts off the decimal and remaining numbers. -// Text: uses strconv.ParseInt to determine the integer value, -// then casts it to an integer. If it fails uses strconv.ParseFloat -// to determine the double value, then casts it to an integer -// It fails if the text doesn't contain a valid float value. -// Any other type is considered an invalid cast. -func CastAsInteger(v types.Value) (types.Value, error) { - // Null values always remain null. 
- if v.Type() == types.TypeNull { - return v, nil - } - - switch v.Type() { - case types.TypeInteger: - return v, nil - case types.TypeBoolean: - if types.AsBool(v) { - return types.NewIntegerValue(1), nil - } - return types.NewIntegerValue(0), nil - case types.TypeDouble: - f := types.AsFloat64(v) - if f > 0 && (int64(f) < 0 || f >= math.MaxInt64) { - return nil, fmt.Errorf("integer out of range") - } - return types.NewIntegerValue(int64(f)), nil - case types.TypeText: - i, err := strconv.ParseInt(types.AsString(v), 10, 64) - if err != nil { - intErr := err - f, err := strconv.ParseFloat(types.AsString(v), 64) - if err != nil { - return nil, fmt.Errorf(`cannot cast %q as integer: %w`, v.V(), intErr) - } - i = int64(f) - } - return types.NewIntegerValue(i), nil - } - - return nil, fmt.Errorf("cannot cast %s as integer", v.Type()) -} - -// CastAsDouble casts according to the following rules: -// Integer: returns a double version of the integer. -// Text: uses strconv.ParseFloat to determine the double value, -// it fails if the text doesn't contain a valid float value. -// Any other type is considered an invalid cast. -func CastAsDouble(v types.Value) (types.Value, error) { - // Null values always remain null. - if v.Type() == types.TypeNull { - return v, nil - } - - switch v.Type() { - case types.TypeDouble: - return v, nil - case types.TypeInteger: - return types.NewDoubleValue(float64(types.AsInt64(v))), nil - case types.TypeText: - f, err := strconv.ParseFloat(types.AsString(v), 64) - if err != nil { - return nil, fmt.Errorf(`cannot cast %q as double: %w`, v.V(), err) - } - return types.NewDoubleValue(f), nil - } - - return nil, fmt.Errorf("cannot cast %s as double", v.Type()) -} - -// CastAsTimestamp casts according to the following rules: -// Text: uses carbon.Parse to determine the timestamp value -// it fails if the text doesn't contain a valid timestamp. -// Any other type is considered an invalid cast. -func CastAsTimestamp(v types.Value) (types.Value, error) { - // Null values always remain null. - if v.Type() == types.TypeNull { - return v, nil - } - - switch v.Type() { - case types.TypeTimestamp: - return v, nil - case types.TypeText: - t, err := types.ParseTimestamp(types.AsString(v)) - if err != nil { - return nil, fmt.Errorf(`cannot cast %q as timestamp: %w`, v.V(), err) - } - return types.NewTimestampValue(t), nil - } - - return nil, fmt.Errorf("cannot cast %s as timestamp", v.Type()) -} - -// CastAsText returns a JSON representation of v. -// If the representation is a string, it gets unquoted. -func CastAsText(v types.Value) (types.Value, error) { - // Null values always remain null. - if v.Type() == types.TypeNull { - return v, nil - } - - switch v.Type() { - case types.TypeText: - return v, nil - case types.TypeBlob: - return types.NewTextValue(base64.StdEncoding.EncodeToString(types.AsByteSlice(v))), nil - case types.TypeTimestamp: - return types.NewTextValue(types.AsTime(v).Format(time.RFC3339Nano)), nil - } - - d, err := v.MarshalJSON() - if err != nil { - return nil, err - } - - s := string(d) - - return types.NewTextValue(s), nil -} - -// CastAsBlob casts according to the following rules: -// Text: decodes a base64 string, otherwise fails. -// Any other type is considered an invalid cast. -func CastAsBlob(v types.Value) (types.Value, error) { - // Null values always remain null. 
- if v.Type() == types.TypeNull { - return v, nil - } - - if v.Type() == types.TypeBlob { - return v, nil - } - - if v.Type() == types.TypeText { - // if the string starts with \x, read it as hex - s := types.AsString(v) - b, err := base64.StdEncoding.DecodeString(s) - if err != nil { - return nil, err - } - - return types.NewBlobValue(b), nil - } - - return nil, fmt.Errorf("cannot cast %s as blob", v.Type()) -} - -// CastAsArray casts according to the following rules: -// Text: decodes a JSON array, otherwise fails. -// Any other type is considered an invalid cast. -func CastAsArray(v types.Value) (types.Value, error) { - // Null values always remain null. - if v.Type() == types.TypeNull { - return v, nil - } - - if v.Type() == types.TypeArray { - return v, nil - } - - if v.Type() == types.TypeText { - var vb ValueBuffer - err := vb.UnmarshalJSON([]byte(types.AsString(v))) - if err != nil { - return nil, fmt.Errorf(`cannot cast %q as array: %w`, v.V(), err) - } - - return types.NewArrayValue(&vb), nil - } - - return nil, fmt.Errorf("cannot cast %s as array", v.Type()) -} - -// CastAsObject casts according to the following rules: -// Text: decodes a JSON object, otherwise fails. -// Any other type is considered an invalid cast. -func CastAsObject(v types.Value) (types.Value, error) { - // Null values always remain null. - if v.Type() == types.TypeNull { - return v, nil - } - - if v.Type() == types.TypeObject { - return v, nil - } - - if v.Type() == types.TypeText { - var fb FieldBuffer - err := fb.UnmarshalJSON([]byte(types.AsString(v))) - if err != nil { - return nil, fmt.Errorf(`cannot cast %q as object: %w`, v.V(), err) - } - - return types.NewObjectValue(&fb), nil - } - - return nil, fmt.Errorf("cannot cast %s as object", v.Type()) -} diff --git a/internal/object/create.go b/internal/object/create.go deleted file mode 100644 index 30c072708..000000000 --- a/internal/object/create.go +++ /dev/null @@ -1,355 +0,0 @@ -package object - -import ( - "fmt" - "math" - "reflect" - "strings" - "time" - - "github.com/buger/jsonparser" - "github.com/chaisql/chai/internal/types" - "github.com/cockroachdb/errors" -) - -// NewFromJSON creates an object from raw JSON data. -// The returned object will lazily decode the data. -// If data is not a valid json object, calls to Iterate or GetByField will -// return an error. -func NewFromJSON(data []byte) types.Object { - return &jsonEncodedObject{data} -} - -type jsonEncodedObject struct { - data []byte -} - -func (j jsonEncodedObject) Iterate(fn func(field string, value types.Value) error) error { - return jsonparser.ObjectEach(j.data, func(key, value []byte, dataType jsonparser.ValueType, offset int) error { - v, err := parseJSONValue(dataType, value) - if err != nil { - return err - } - - return fn(string(key), v) - }) -} - -func (j jsonEncodedObject) GetByField(field string) (types.Value, error) { - v, dt, _, err := jsonparser.Get(j.data, field) - if dt == jsonparser.NotExist { - return nil, types.ErrFieldNotFound - } - if err != nil { - return nil, err - } - - return parseJSONValue(dt, v) -} - -func (j jsonEncodedObject) MarshalJSON() ([]byte, error) { - return j.data, nil -} - -// NewFromMap creates an object from a map. -// Due to the way maps are designed, iteration order is not guaranteed. 
-func NewFromMap[T any](m map[string]T) types.Object { - return mapObject[T](m) -} - -type mapObject[T any] map[string]T - -var _ types.Object = (*mapObject[any])(nil) - -func (m mapObject[T]) Iterate(fn func(field string, value types.Value) error) error { - for k, v := range m { - v, err := NewValue(v) - if err != nil { - return err - } - - err = fn(k, v) - if err != nil { - return err - } - } - return nil -} - -func (m mapObject[T]) GetByField(field string) (types.Value, error) { - v, ok := m[field] - if !ok { - return nil, types.ErrFieldNotFound - } - - return NewValue(v) -} - -// MarshalJSON implements the json.Marshaler interface. -func (m mapObject[T]) MarshalJSON() ([]byte, error) { - return MarshalJSON(m) -} - -type reflectMapObject reflect.Value - -var _ types.Object = (*reflectMapObject)(nil) - -func (m reflectMapObject) Iterate(fn func(field string, value types.Value) error) error { - M := reflect.Value(m) - it := M.MapRange() - - for it.Next() { - v, err := NewValue(it.Value().Interface()) - if err != nil { - return err - } - - err = fn(it.Key().String(), v) - if err != nil { - return err - } - } - return nil -} - -func (m reflectMapObject) GetByField(field string) (types.Value, error) { - M := reflect.Value(m) - v := M.MapIndex(reflect.ValueOf(field)) - if v == (reflect.Value{}) { - return nil, types.ErrFieldNotFound - } - return NewValue(v.Interface()) -} - -// MarshalJSON implements the json.Marshaler interface. -func (m reflectMapObject) MarshalJSON() ([]byte, error) { - return MarshalJSON(m) -} - -// NewFromStruct creates an object from a struct using reflection. -func NewFromStruct(s interface{}) (types.Object, error) { - ref := reflect.Indirect(reflect.ValueOf(s)) - - if !ref.IsValid() || ref.Kind() != reflect.Struct { - return nil, errors.New("expected struct or pointer to struct") - } - - return newFromStruct(ref) -} - -func newFromStruct(ref reflect.Value) (types.Object, error) { - var fb FieldBuffer - l := ref.NumField() - tp := ref.Type() - - for i := 0; i < l; i++ { - f := ref.Field(i) - if !f.IsValid() { - continue - } - - if f.Kind() == reflect.Ptr { - if f.IsNil() { - continue - } - - f = f.Elem() - } - - sf := tp.Field(i) - - isUnexported := sf.PkgPath != "" - - if sf.Anonymous { - if isUnexported && f.Kind() != reflect.Struct { - continue - } - d, err := newFromStruct(f) - if err != nil { - return nil, err - } - err = d.Iterate(func(field string, value types.Value) error { - fb.Add(field, value) - return nil - }) - if err != nil { - return nil, err - } - continue - } else if isUnexported { - continue - } - - v, err := NewValue(f.Interface()) - if err != nil { - return nil, err - } - - field := strings.ToLower(sf.Name) - if gtag, ok := sf.Tag.Lookup("chai"); ok { - if gtag == "-" { - continue - } - field = gtag - } - - fb.Add(field, v) - } - - return &fb, nil -} - -// NewValue creates a value whose type is infered from x. -func NewValue(x any) (types.Value, error) { - // Attempt exact matches first: - switch v := x.(type) { - case time.Duration: - return types.NewIntegerValue(v.Nanoseconds()), nil - case time.Time: - return types.NewTimestampValue(v), nil - case nil: - return types.NewNullValue(), nil - case types.Object: - return types.NewObjectValue(v), nil - case types.Array: - return types.NewArrayValue(v), nil - } - - // Compare by kind to detect type definitions over built-in types. 
- v := reflect.ValueOf(x) - switch v.Kind() { - case reflect.Ptr: - if v.IsNil() { - return types.NewNullValue(), nil - } - return NewValue(reflect.Indirect(v).Interface()) - case reflect.Bool: - return types.NewBooleanValue(v.Bool()), nil - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return types.NewIntegerValue(v.Int()), nil - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - x := v.Uint() - if x > math.MaxInt64 { - return nil, fmt.Errorf("cannot convert unsigned integer struct field to int64: %d out of range", x) - } - return types.NewIntegerValue(int64(x)), nil - case reflect.Float32, reflect.Float64: - return types.NewDoubleValue(v.Float()), nil - case reflect.String: - return types.NewTextValue(v.String()), nil - case reflect.Interface: - if v.IsNil() { - return types.NewNullValue(), nil - } - return NewValue(v.Elem().Interface()) - case reflect.Struct: - doc, err := NewFromStruct(x) - if err != nil { - return nil, err - } - return types.NewObjectValue(doc), nil - case reflect.Array: - return types.NewArrayValue(&sliceArray{v}), nil - case reflect.Slice: - if reflect.TypeOf(v.Interface()).Elem().Kind() == reflect.Uint8 { - return types.NewBlobValue(v.Bytes()), nil - } - if v.IsNil() { - return types.NewNullValue(), nil - } - return types.NewArrayValue(&sliceArray{ref: v}), nil - case reflect.Map: - if v.Type().Key().Kind() != reflect.String { - return nil, &ErrUnsupportedType{x, "parameter must be a map with a string key"} - } - - // use fast generic map if possible - switch v.Type().Elem().Kind() { - case reflect.Bool: - return types.NewObjectValue(NewFromMap(x.(map[string]bool))), nil - case reflect.Int: - return types.NewObjectValue(NewFromMap(x.(map[string]int))), nil - case reflect.Int8: - return types.NewObjectValue(NewFromMap(x.(map[string]int8))), nil - case reflect.Int16: - return types.NewObjectValue(NewFromMap(x.(map[string]int16))), nil - case reflect.Int32: - return types.NewObjectValue(NewFromMap(x.(map[string]int32))), nil - case reflect.Int64: - return types.NewObjectValue(NewFromMap(x.(map[string]int64))), nil - case reflect.Float32: - return types.NewObjectValue(NewFromMap(x.(map[string]float32))), nil - case reflect.Float64: - return types.NewObjectValue(NewFromMap(x.(map[string]float64))), nil - case reflect.String: - return types.NewObjectValue(NewFromMap(x.(map[string]string))), nil - case reflect.Interface: - return types.NewObjectValue(NewFromMap(x.(map[string]any))), nil - } - - // use reflect based map for other types - return types.NewObjectValue(reflectMapObject(v)), nil - } - - return nil, &ErrUnsupportedType{x, ""} -} - -type sliceArray struct { - ref reflect.Value -} - -var _ types.Array = (*sliceArray)(nil) - -func (s sliceArray) Iterate(fn func(i int, v types.Value) error) error { - l := s.ref.Len() - - for i := 0; i < l; i++ { - f := s.ref.Index(i) - - v, err := NewValue(f.Interface()) - if err != nil { - if err.(*ErrUnsupportedType) != nil { - continue - } - return err - } - - err = fn(i, v) - if err != nil { - return err - } - } - - return nil -} - -func (s sliceArray) GetByIndex(i int) (types.Value, error) { - if i >= s.ref.Len() { - return nil, types.ErrFieldNotFound - } - - v := s.ref.Index(i) - if !v.IsValid() { - return nil, types.ErrFieldNotFound - } - - return NewValue(v.Interface()) -} - -func (s sliceArray) MarshalJSON() ([]byte, error) { - return MarshalJSONArray(s) -} - -// NewFromCSV takes a list of headers and columns and returns an object. 
-// Each header will be assigned as the key and each corresponding column as a text value. -// The length of headers and columns must be the same. -func NewFromCSV(headers, columns []string) types.Object { - fb := NewFieldBuffer() - fb.ScanCSV(headers, columns) - - return fb -} - -func NewArrayFromSlice[T any](l []T) types.Array { - return &sliceArray{ref: reflect.ValueOf(l)} -} diff --git a/internal/object/create_test.go b/internal/object/create_test.go deleted file mode 100644 index 6785055fa..000000000 --- a/internal/object/create_test.go +++ /dev/null @@ -1,216 +0,0 @@ -package object_test - -import ( - "testing" - "time" - - "github.com/chaisql/chai/internal/object" - "github.com/chaisql/chai/internal/testutil" - "github.com/chaisql/chai/internal/testutil/assert" - "github.com/chaisql/chai/internal/types" - "github.com/stretchr/testify/require" -) - -func TestNewValue(t *testing.T) { - type myBytes []byte - type myString string - type myUint uint - type myUint16 uint16 - type myUint32 uint32 - type myUint64 uint64 - type myInt int - type myInt8 int8 - type myInt16 int16 - type myInt64 int64 - type myFloat64 float64 - - mapAny := map[string]any{ - "a": 1, - "b": true, - } - - mapInt := map[string]int{ - "a": 1, - "b": 2, - } - - now := time.Now() - - tests := []struct { - name string - value, expected interface{} - }{ - {"bytes", []byte("bar"), []byte("bar")}, - {"string", "bar", "bar"}, - {"bool", true, true}, - {"uint", uint(10), int64(10)}, - {"uint8", uint8(10), int64(10)}, - {"uint16", uint16(10), int64(10)}, - {"uint32", uint32(10), int64(10)}, - {"uint64", uint64(10), int64(10)}, - {"int", int(10), int64(10)}, - {"int8", int8(10), int64(10)}, - {"int16", int16(10), int64(10)}, - {"int32", int32(10), int64(10)}, - {"int64", int64(10), int64(10)}, - {"float64", 10.1, float64(10.1)}, - {"null", nil, nil}, - {"object", object.NewFieldBuffer().Add("a", types.NewIntegerValue(10)), object.NewFieldBuffer().Add("a", types.NewIntegerValue(10))}, - {"array", object.NewValueBuffer(types.NewIntegerValue(10)), object.NewValueBuffer(types.NewIntegerValue(10))}, - {"time", now, now.UTC()}, - {"bytes", myBytes("bar"), []byte("bar")}, - {"string", myString("bar"), "bar"}, - {"myUint", myUint(10), int64(10)}, - {"myUint16", myUint16(500), int64(500)}, - {"myUint32", myUint32(90000), int64(90000)}, - {"myUint64", myUint64(100), int64(100)}, - {"myInt", myInt(7), int64(7)}, - {"myInt8", myInt8(3), int64(3)}, - {"myInt16", myInt16(500), int64(500)}, - {"myInt64", myInt64(10), int64(10)}, - {"myFloat64", myFloat64(10.1), float64(10.1)}, - {"map[string]any", mapAny, object.NewFromMap(mapAny)}, - {"map[string]int", mapInt, object.NewFromMap(mapInt)}, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - v, err := object.NewValue(test.value) - assert.NoError(t, err) - require.Equal(t, test.expected, v.V()) - }) - } -} - -func TestNewFromJSON(t *testing.T) { - tests := []struct { - name string - data string - expected *object.FieldBuffer - fails bool - }{ - {"empty object", "{}", object.NewFieldBuffer(), false}, - {"empty object, missing closing bracket", "{", nil, true}, - {"classic object", `{"a": 1, "b": true, "c": "hello", "d": [1, 2, 3], "e": {"f": "g"}}`, - object.NewFieldBuffer(). - Add("a", types.NewIntegerValue(1)). - Add("b", types.NewBooleanValue(true)). - Add("c", types.NewTextValue("hello")). - Add("d", types.NewArrayValue(object.NewValueBuffer(). - Append(types.NewIntegerValue(1)). - Append(types.NewIntegerValue(2)). - Append(types.NewIntegerValue(3)))). 
- Add("e", types.NewObjectValue(object.NewFieldBuffer().Add("f", types.NewTextValue("g")))), - false}, - {"string values", `{"a": "hello ciao"}`, object.NewFieldBuffer().Add("a", types.NewTextValue("hello ciao")), false}, - {"+integer values", `{"a": 1000}`, object.NewFieldBuffer().Add("a", types.NewIntegerValue(1000)), false}, - {"-integer values", `{"a": -1000}`, object.NewFieldBuffer().Add("a", types.NewIntegerValue(-1000)), false}, - {"+float values", `{"a": 10000000000.0}`, object.NewFieldBuffer().Add("a", types.NewDoubleValue(10000000000)), false}, - {"-float values", `{"a": -10000000000.0}`, object.NewFieldBuffer().Add("a", types.NewDoubleValue(-10000000000)), false}, - {"bool values", `{"a": true, "b": false}`, object.NewFieldBuffer().Add("a", types.NewBooleanValue(true)).Add("b", types.NewBooleanValue(false)), false}, - {"empty arrays", `{"a": []}`, object.NewFieldBuffer().Add("a", types.NewArrayValue(object.NewValueBuffer())), false}, - {"nested arrays", `{"a": [[1, 2]]}`, object.NewFieldBuffer(). - Add("a", types.NewArrayValue( - object.NewValueBuffer(). - Append(types.NewArrayValue( - object.NewValueBuffer(). - Append(types.NewIntegerValue(1)). - Append(types.NewIntegerValue(2)))))), false}, - {"missing comma", `{"a": 1 "b": 2}`, nil, true}, - {"missing closing brackets", `{"a": 1, "b": 2`, nil, true}, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - d := object.NewFromJSON([]byte(test.data)) - - fb := object.NewFieldBuffer() - err := fb.Copy(d) - - if test.fails { - assert.Error(t, err) - } else { - assert.NoError(t, err) - require.Equal(t, *test.expected, *fb) - } - }) - } - - t.Run("GetByField", func(t *testing.T) { - d := object.NewFromJSON([]byte(`{"a": 1000}`)) - - v, err := d.GetByField("a") - assert.NoError(t, err) - require.Equal(t, types.NewIntegerValue(1000), v) - - _, err = d.GetByField("b") - assert.ErrorIs(t, err, types.ErrFieldNotFound) - }) -} - -func TestNewFromMap(t *testing.T) { - m := map[string]interface{}{ - "name": "foo", - "age": 10, - "nilField": nil, - } - - doc := object.NewFromMap(m) - - t.Run("Iterate", func(t *testing.T) { - counter := make(map[string]int) - - err := doc.Iterate(func(f string, v types.Value) error { - counter[f]++ - switch f { - case "name": - require.Equal(t, m[f], types.AsString(v)) - default: - require.EqualValues(t, m[f], v.V()) - } - return nil - }) - assert.NoError(t, err) - require.Len(t, counter, 3) - require.Equal(t, counter["name"], 1) - require.Equal(t, counter["age"], 1) - require.Equal(t, counter["nilField"], 1) - }) - - t.Run("GetByField", func(t *testing.T) { - v, err := doc.GetByField("name") - assert.NoError(t, err) - require.Equal(t, types.NewTextValue("foo"), v) - - v, err = doc.GetByField("age") - assert.NoError(t, err) - require.Equal(t, types.NewIntegerValue(10), v) - - v, err = doc.GetByField("nilField") - assert.NoError(t, err) - require.Equal(t, types.NewNullValue(), v) - - _, err = doc.GetByField("bar") - require.Equal(t, types.ErrFieldNotFound, err) - }) -} - -func BenchmarkJSONToObject(b *testing.B) { - data := []byte(`{"_id":"5f8aefb8e443c6c13afdb305","index":0,"guid":"42c2719e-3371-4b2f-b855-d302a8b7eab0","isActive":true,"balance":"$1,064.79","picture":"http://placehold.it/32x32","age":40,"eyeColor":"blue","name":"Adele Webb","gender":"female","company":"EXTRAGEN","email":"adelewebb@extragen.com","phone":"+1 (964) 409-2397","address":"970 Charles Place, Watrous, Texas, 2522","about":"Amet non do ullamco duis velit sunt esse et cillum nisi mollit ea magna. 
Tempor ut occaecat proident laborum velit nisi et excepteur exercitation non est labore. Laboris pariatur enim proident et. Qui minim enim et incididunt incididunt adipisicing tempor. Occaecat adipisicing sint ex ut exercitation exercitation voluptate. Laboris adipisicing ut cillum eu cillum est sunt amet Lorem quis pariatur.\r\n","registered":"2016-05-25T10:36:44 -04:00","latitude":64.57112,"longitude":176.136138,"tags":["velit","minim","eiusmod","est","eu","voluptate","deserunt"],"friends":[{"id":0,"name":"Mathis Robertson"},{"id":1,"name":"Cecilia Donaldson"},{"id":2,"name":"Joann Goodwin"}],"greeting":"Hello, Adele Webb! You have 2 unread messages.","favoriteFruit":"apple"}`) - - b.ResetTimer() - for i := 0; i < b.N; i++ { - d := object.NewFromJSON(data) - d.Iterate(func(string, types.Value) error { - return nil - }) - } -} - -func TestNewFromCSV(t *testing.T) { - headers := []string{"a", "b", "c"} - columns := []string{"A", "B", "C"} - - d := object.NewFromCSV(headers, columns) - testutil.RequireJSONEq(t, d, `{"a": "A", "b": "B", "c": "C"}`) -} diff --git a/internal/object/diff.go b/internal/object/diff.go deleted file mode 100644 index 613aec563..000000000 --- a/internal/object/diff.go +++ /dev/null @@ -1,188 +0,0 @@ -package object - -import ( - "github.com/chaisql/chai/internal/types" - "github.com/cockroachdb/errors" -) - -// Diff returns the operations needed to transform the first object into the second. -func Diff(d1, d2 types.Object) ([]Op, error) { - return diff(nil, d1, d2) -} - -func diff(path Path, d1, d2 types.Object) ([]Op, error) { - var ops []Op - f1, err := types.Fields(d1) - if err != nil { - return nil, err - } - - f2, err := types.Fields(d2) - if err != nil { - return nil, err - } - - var i, j int - for { - for i < len(f1) && (j >= len(f2) || f1[i] < f2[j]) { - v, err := d1.GetByField(f1[j]) - if err != nil { - return nil, err - } - ops = append(ops, NewDeleteOp(path.ExtendField(f1[i]), v)) - i++ - } - - for j < len(f2) && (i >= len(f1) || f1[i] > f2[j]) { - v, err := d2.GetByField(f2[j]) - if err != nil { - return nil, err - } - ops = append(ops, NewSetOp(path.ExtendField(f2[j]), v)) - j++ - } - - if i == len(f1) && j == len(f2) { - break - } - - v1, err := d1.GetByField(f1[i]) - if err != nil { - return nil, err - } - - v2, err := d2.GetByField(f2[j]) - if err != nil { - return nil, err - } - - if v1.Type() != v2.Type() { - v, err := d2.GetByField(f2[j]) - if err != nil { - return nil, err - } - ops = append(ops, NewSetOp(path.ExtendField(f2[j]), v)) - } else { - switch v1.Type() { - case types.TypeObject: - subOps, err := diff(append(path, PathFragment{FieldName: f1[i]}), types.AsObject(v1), types.AsObject(v2)) - if err != nil { - return nil, err - } - ops = append(ops, subOps...) - case types.TypeArray: - subOps, err := arrayDiff(append(path, PathFragment{FieldName: f1[i]}), types.AsArray(v1), types.AsArray(v2)) - if err != nil { - return nil, err - } - ops = append(ops, subOps...) 
- default: - ok, err := v1.EQ(v2) - if err != nil { - return nil, err - } - if !ok { - ops = append(ops, NewSetOp(path.ExtendField(f2[j]), v2)) - } - } - } - i++ - j++ - } - - return ops, nil -} - -func arrayDiff(path Path, a1, a2 types.Array) ([]Op, error) { - var ops []Op - - var i int - for { - v1, err := a1.GetByIndex(i) - nov1 := errors.Is(err, types.ErrFieldNotFound) - if !nov1 && err != nil { - return nil, err - } - - v2, err := a2.GetByIndex(i) - nov2 := errors.Is(err, types.ErrFieldNotFound) - if !nov2 && err != nil { - return nil, err - } - - if nov1 && nov2 { - break - } - - if nov1 && !nov2 { - ops = append(ops, NewSetOp(path.ExtendIndex(i), v2)) - i++ - continue - } - if !nov1 && nov2 { - ops = append(ops, NewDeleteOp(path.ExtendIndex(i), v1)) - i++ - continue - } - - if v1.Type() != v2.Type() { - ops = append(ops, NewSetOp(path.ExtendIndex(i), v2)) - i++ - continue - } - - switch v1.Type() { - case types.TypeObject: - subOps, err := diff(append(path, PathFragment{ArrayIndex: i}), types.AsObject(v1), types.AsObject(v2)) - if err != nil { - return nil, err - } - ops = append(ops, subOps...) - case types.TypeArray: - subOps, err := arrayDiff(append(path, PathFragment{ArrayIndex: i}), types.AsArray(v1), types.AsArray(v2)) - if err != nil { - return nil, err - } - ops = append(ops, subOps...) - default: - ok, err := v1.EQ(v2) - if err != nil { - return nil, err - } - if !ok { - ops = append(ops, NewSetOp(path.ExtendIndex(i), v2)) - } - } - i++ - } - - return ops, nil -} - -// Op represents a single operation on an object. -// It is returned by the Diff function. -type Op struct { - Type string - Path Path - Value types.Value -} - -func NewSetOp(path Path, v types.Value) Op { - return newOp("set", path, v) -} - -func NewDeleteOp(path Path, v types.Value) Op { - return newOp("delete", path, v) -} - -func newOp(op string, path Path, v types.Value) Op { - return Op{ - Type: op, - Path: path, - Value: v, - } -} - -func (o *Op) MarshalBinary() ([]byte, error) { - panic("not implemented") // TODO: Implement -} diff --git a/internal/object/diff_test.go b/internal/object/diff_test.go deleted file mode 100644 index 1f054ac28..000000000 --- a/internal/object/diff_test.go +++ /dev/null @@ -1,157 +0,0 @@ -package object_test - -import ( - "testing" - - "github.com/chaisql/chai/internal/object" - "github.com/chaisql/chai/internal/testutil" - "github.com/chaisql/chai/internal/types" - "github.com/stretchr/testify/require" -) - -func TestDiff(t *testing.T) { - tests := []struct { - name string - d1, d2 string - want []object.Op - }{ - { - name: "empty", - d1: `{}`, - d2: `{}`, - want: nil, - }, - { - name: "add field", - d1: `{}`, - d2: `{"a": 1}`, - want: []object.Op{ - {"set", object.NewPath("a"), types.NewIntegerValue(1)}, - }, - }, - { - name: "remove field", - d1: `{"a": 1}`, - d2: `{}`, - want: []object.Op{ - {"delete", object.NewPath("a"), types.NewIntegerValue(1)}, - }, - }, - { - name: "same", - d1: `{"a": 1}`, - d2: `{"a": 1}`, - want: nil, - }, - { - name: "replace field", - d1: `{"a": 1}`, - d2: `{"a": 2}`, - want: []object.Op{ - {"set", object.NewPath("a"), types.NewIntegerValue(2)}, - }, - }, - { - name: "replace field: different type", - d1: `{"a": 1}`, - d2: `{"a": "hello"}`, - want: []object.Op{ - {"set", object.NewPath("a"), types.NewTextValue("hello")}, - }, - }, - { - name: "nested object: replace field", - d1: `{"a": {"b": 1}}`, - d2: `{"a": {"b": 2}}`, - want: []object.Op{ - {"set", object.NewPath("a", "b"), types.NewIntegerValue(2)}, - }, - }, - { - name: "nested object: 
add field", - d1: `{"a": {"b": 1}}`, - d2: `{"a": {"b": 1, "c": 2}}`, - want: []object.Op{ - {"set", object.NewPath("a", "c"), types.NewIntegerValue(2)}, - }, - }, - { - name: "nested object: remove field", - d1: `{"a": {"b": 1, "c": 2}}`, - d2: `{"a": {"b": 1}}`, - want: []object.Op{ - {"delete", object.NewPath("a", "c"), types.NewIntegerValue(2)}, - }, - }, - { - name: "nested array: replace index", - d1: `{"a": [1, 2, 3]}`, - d2: `{"a": [1, 2, 4]}`, - want: []object.Op{ - {"set", object.NewPath("a", "2"), types.NewIntegerValue(4)}, - }, - }, - { - name: "nested array: replace index with different type", - d1: `{"a": [1, 2, 3]}`, - d2: `{"a": [1, 2, 4.5]}`, - want: []object.Op{ - {"set", object.NewPath("a", "2"), types.NewDoubleValue(4.5)}, - }, - }, - { - name: "nested array: add index", - d1: `{"a": [1, 2, 3]}`, - d2: `{"a": [1, 2, 3, 4]}`, - want: []object.Op{ - {"set", object.NewPath("a", "3"), types.NewIntegerValue(4)}, - }, - }, - { - name: "nested array: remove index", - d1: `{"a": [1, 2, 3, 4]}`, - d2: `{"a": [1, 2, 3]}`, - want: []object.Op{ - {"delete", object.NewPath("a", "3"), types.NewIntegerValue(4)}, - }, - }, - { - name: "nested array: add in the middle", - d1: `{"a": [1, 2, 3]}`, - d2: `{"a": [1, 2, 2.5, 3]}`, - want: []object.Op{ - {"set", object.NewPath("a", "2"), types.NewDoubleValue(2.5)}, - {"set", object.NewPath("a", "3"), types.NewIntegerValue(3)}, - }, - }, - { - name: "nested array: with nested array", - d1: `{"a": [1, 2, []]}`, - d2: `{"a": [1, 2, [1], 3]}`, - want: []object.Op{ - {"set", object.NewPath("a", "2", "0"), types.NewIntegerValue(1)}, - {"set", object.NewPath("a", "3"), types.NewIntegerValue(3)}, - }, - }, - { - name: "nested array: with nested object", - d1: `{"a": [1, 2, {"b": [1]}]}`, - d2: `{"a": [1, 2, {"b": [2]}, 3]}`, - want: []object.Op{ - {"set", object.NewPath("a", "2", "b", "0"), types.NewIntegerValue(2)}, - {"set", object.NewPath("a", "3"), types.NewIntegerValue(3)}, - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - d1 := testutil.MakeObject(t, test.d1) - d2 := testutil.MakeObject(t, test.d2) - - got, err := object.Diff(d1, d2) - require.NoError(t, err) - require.Equal(t, test.want, got) - }) - } -} diff --git a/internal/object/object.go b/internal/object/object.go deleted file mode 100644 index e645a2124..000000000 --- a/internal/object/object.go +++ /dev/null @@ -1,530 +0,0 @@ -// Package object defines types to manipulate and compare objects and values. -package object - -import ( - "fmt" - "sort" - "strings" - - "github.com/buger/jsonparser" - "github.com/cockroachdb/errors" - - "github.com/chaisql/chai/internal/stringutil" - "github.com/chaisql/chai/internal/types" -) - -// ErrUnsupportedType is used to skip struct or array fields that are not supported. -type ErrUnsupportedType struct { - Value interface{} - Msg string -} - -func (e *ErrUnsupportedType) Error() string { - return fmt.Sprintf("unsupported type %T. %s", e.Value, e.Msg) -} - -// An Iterator can iterate over object keys. -type Iterator interface { - // Iterate goes through all the objects and calls the given function by passing each one of them. - // If the given function returns an error, the iteration stops. - Iterate(fn func(d types.Object) error) error -} - -// MarshalJSON encodes an object to json. -func MarshalJSON(d types.Object) ([]byte, error) { - return types.NewObjectValue(d).MarshalJSON() -} - -// MarshalJSONArray encodes an array to json. 
-func MarshalJSONArray(a types.Array) ([]byte, error) { - return types.NewArrayValue(a).MarshalJSON() -} - -// Length returns the length of an object. -func Length(d types.Object) (int, error) { - if fb, ok := d.(*FieldBuffer); ok { - return fb.Len(), nil - } - - var len int - err := d.Iterate(func(_ string, _ types.Value) error { - len++ - return nil - }) - return len, err -} - -// FieldBuffer stores a group of fields in memory. It implements the object interface. -type FieldBuffer struct { - fields []fieldValue -} - -// NewFieldBuffer creates a FieldBuffer. -func NewFieldBuffer() *FieldBuffer { - return new(FieldBuffer) -} - -// MarshalJSON implements the json.Marshaler interface. -func (fb *FieldBuffer) MarshalJSON() ([]byte, error) { - return MarshalJSON(fb) -} - -// UnmarshalJSON implements the json.Unmarshaler interface. -func (fb *FieldBuffer) UnmarshalJSON(data []byte) error { - return jsonparser.ObjectEach(data, func(key []byte, value []byte, dataType jsonparser.ValueType, offset int) error { - v, err := parseJSONValue(dataType, value) - if err != nil { - return err - } - - fb.Add(string(key), v) - return nil - }) -} - -func (fb *FieldBuffer) String() string { - s, _ := fb.MarshalJSON() - return string(s) -} - -type fieldValue struct { - Field string - Value types.Value -} - -// Add a field to the buffer. -func (fb *FieldBuffer) Add(field string, v types.Value) *FieldBuffer { - fb.fields = append(fb.fields, fieldValue{field, v}) - return fb -} - -// ScanObject copies all the fields of d to the buffer. -func (fb *FieldBuffer) ScanObject(d types.Object) error { - return d.Iterate(func(f string, v types.Value) error { - fb.Add(f, v) - return nil - }) -} - -// GetByField returns a value by field. Returns an error if the field doesn't exists. -func (fb FieldBuffer) GetByField(field string) (types.Value, error) { - for _, fv := range fb.fields { - if fv.Field == field { - return fv.Value, nil - } - } - - return nil, types.ErrFieldNotFound -} - -// setFieldValue replaces a field if it already exists or creates one if not. 
-func (fb *FieldBuffer) setFieldValue(field string, reqValue types.Value) error { - _, err := fb.GetByField(field) - switch err { - case types.ErrFieldNotFound: - fb.Add(field, reqValue) - return nil - case nil: - _ = fb.Replace(field, reqValue) - return nil - } - - return err -} - -// setValueAtPath deep replaces or creates a field at the given path -func setValueAtPath(v types.Value, p Path, newValue types.Value) (types.Value, error) { - switch v.Type() { - case types.TypeObject: - var buf FieldBuffer - err := buf.ScanObject(types.AsObject(v)) - if err != nil { - return v, err - } - - if len(p) == 1 { - err = buf.setFieldValue(p[0].FieldName, newValue) - return types.NewObjectValue(&buf), err - } - - // the field is an object but the path expects an array, - // return an error - if p[0].FieldName == "" { - return nil, types.ErrFieldNotFound - } - - va, err := buf.GetByField(p[0].FieldName) - if err != nil { - return v, err - } - - va, err = setValueAtPath(va, p[1:], newValue) - if err != nil { - return v, err - } - - err = buf.setFieldValue(p[0].FieldName, va) - return types.NewObjectValue(&buf), err - case types.TypeArray: - var vb ValueBuffer - err := vb.ScanArray(types.AsArray(v)) - if err != nil { - return v, err - } - - va, err := vb.GetByIndex(p[0].ArrayIndex) - if err != nil { - return v, err - } - - if len(p) == 1 { - err = vb.Replace(p[0].ArrayIndex, newValue) - return types.NewArrayValue(&vb), err - } - - va, err = setValueAtPath(va, p[1:], newValue) - if err != nil { - return v, err - } - err = vb.Replace(p[0].ArrayIndex, va) - return types.NewArrayValue(&vb), err - } - - return nil, types.ErrFieldNotFound -} - -// Set replaces a field if it already exists or creates one if not. -// TODO(asdine): Set should always fail with types.ErrFieldNotFound if the path -// doesn't resolve to an existing field. -func (fb *FieldBuffer) Set(path Path, v types.Value) error { - if len(path) == 0 || path[0].FieldName == "" { - return types.ErrFieldNotFound - } - - if len(path) == 1 { - return fb.setFieldValue(path[0].FieldName, v) - } - - container, err := fb.GetByField(path[0].FieldName) - if err != nil { - return err - } - - va, err := setValueAtPath(container, path[1:], v) - if err != nil { - return err - } - - err = fb.setFieldValue(path[0].FieldName, va) - if err != nil { - return err - } - - return nil -} - -// Iterate goes through all the fields of the object and calls the given function by passing each one of them. -// If the given function returns an error, the iteration stops. -func (fb FieldBuffer) Iterate(fn func(field string, value types.Value) error) error { - for _, fv := range fb.fields { - err := fn(fv.Field, fv.Value) - if err != nil { - return err - } - } - - return nil -} - -// Delete a field from the buffer. -func (fb *FieldBuffer) Delete(path Path) error { - if len(path) == 1 { - for i := range fb.fields { - if fb.fields[i].Field == path[0].FieldName { - fb.fields = append(fb.fields[0:i], fb.fields[i+1:]...) - return nil - } - } - } - - parentPath := path[:len(path)-1] - lastFragment := path[len(path)-1] - - // get parent doc or array - v, err := parentPath.GetValueFromObject(fb) - if err != nil { - return err - } - switch v.Type() { - case types.TypeObject: - subBuf, ok := types.Is[*FieldBuffer](v) - if !ok { - return errors.New("delete doesn't support non buffered object") - } - - for i := range subBuf.fields { - if subBuf.fields[i].Field == lastFragment.FieldName { - subBuf.fields = append(subBuf.fields[0:i], subBuf.fields[i+1:]...) 
- return nil - } - } - - return types.ErrFieldNotFound - case types.TypeArray: - subBuf, ok := types.Is[*ValueBuffer](v) - if !ok { - return errors.New("delete doesn't support non buffered array") - } - - idx := path[len(path)-1].ArrayIndex - if idx >= len(subBuf.Values) { - return types.ErrFieldNotFound - } - subBuf.Values = append(subBuf.Values[0:idx], subBuf.Values[idx+1:]...) - default: - return types.ErrFieldNotFound - } - - return nil -} - -// Replace the value of the field by v. -func (fb *FieldBuffer) Replace(field string, v types.Value) error { - for i := range fb.fields { - if fb.fields[i].Field == field { - fb.fields[i].Value = v - return nil - } - } - - return types.ErrFieldNotFound -} - -// Copy deep copies every value of the object to the buffer. -// If a value is an object or an array, it will be stored as a FieldBuffer or ValueBuffer respectively. -func (fb *FieldBuffer) Copy(d types.Object) error { - return d.Iterate(func(field string, value types.Value) error { - v, err := CloneValue(value) - if err != nil { - return err - } - fb.Add(strings.Clone(field), v) - return nil - }) -} - -func CloneValue(v types.Value) (types.Value, error) { - switch v.Type() { - case types.TypeNull: - return types.NewNullValue(), nil - case types.TypeBoolean: - return types.NewBooleanValue(types.AsBool(v)), nil - case types.TypeInteger: - return types.NewIntegerValue(types.AsInt64(v)), nil - case types.TypeDouble: - return types.NewDoubleValue(types.AsFloat64(v)), nil - case types.TypeTimestamp: - return types.NewTimestampValue(types.AsTime(v)), nil - case types.TypeText: - return types.NewTextValue(strings.Clone(types.AsString(v))), nil - case types.TypeBlob: - return types.NewBlobValue(append([]byte{}, types.AsByteSlice(v)...)), nil - case types.TypeArray: - vb := NewValueBuffer() - err := vb.Copy(types.AsArray(v)) - if err != nil { - return nil, err - } - return types.NewArrayValue(vb), nil - case types.TypeObject: - fb := NewFieldBuffer() - err := fb.Copy(types.AsObject(v)) - if err != nil { - return nil, err - } - return types.NewObjectValue(fb), nil - } - - panic(fmt.Sprintf("Unsupported value type: %s", v.Type())) -} - -// Apply a function to all the values of the buffer. -func (fb *FieldBuffer) Apply(fn func(p Path, v types.Value) (types.Value, error)) error { - path := Path{PathFragment{}} - var err error - - for i, f := range fb.fields { - path[0].FieldName = f.Field - - f.Value, err = fn(path, f.Value) - if err != nil { - return err - } - fb.fields[i].Value = f.Value - - switch f.Value.Type() { - case types.TypeObject: - buf, ok := types.Is[*FieldBuffer](f.Value) - if !ok { - buf = NewFieldBuffer() - err := buf.Copy(types.AsObject(f.Value)) - if err != nil { - return err - } - } - - err := buf.Apply(func(p Path, v types.Value) (types.Value, error) { - return fn(append(path, p...), v) - }) - if err != nil { - return err - } - fb.fields[i].Value = types.NewObjectValue(buf) - case types.TypeArray: - buf, ok := types.Is[*ValueBuffer](f.Value) - if !ok { - buf = NewValueBuffer() - err := buf.Copy(types.AsArray(f.Value)) - if err != nil { - return err - } - } - - err := buf.Apply(func(p Path, v types.Value) (types.Value, error) { - return fn(append(path, p...), v) - }) - if err != nil { - return err - } - fb.fields[i].Value = types.NewArrayValue(buf) - } - } - - return nil -} - -// Len of the buffer. -func (fb FieldBuffer) Len() int { - return len(fb.fields) -} - -// Reset the buffer. 
-func (fb *FieldBuffer) Reset() { - fb.fields = fb.fields[:0] -} - -func (fb *FieldBuffer) ScanCSV(headers, columns []string) { - for i, h := range headers { - if i >= len(columns) { - break - } - - fb.Add(h, types.NewTextValue(columns[i])) - } -} - -// MaskFields returns a new object that masks the given fields. -func MaskFields(d types.Object, fields ...string) types.Object { - return &maskObject{d, fields} -} - -type maskObject struct { - d types.Object - mask []string -} - -func (m *maskObject) Iterate(fn func(field string, value types.Value) error) error { - return m.d.Iterate(func(field string, value types.Value) error { - if !stringutil.Contains(m.mask, field) { - return fn(field, value) - } - - return nil - }) -} - -func (m *maskObject) GetByField(field string) (types.Value, error) { - if !stringutil.Contains(m.mask, field) { - return m.d.GetByField(field) - } - - return nil, types.ErrFieldNotFound -} - -func (m *maskObject) MarshalJSON() ([]byte, error) { - return MarshalJSON(m) -} - -// OnlyFields returns a new object that only contains the given fields. -func OnlyFields(d types.Object, fields ...string) types.Object { - return &onlyObject{d, fields} -} - -type onlyObject struct { - d types.Object - fields []string -} - -func (o *onlyObject) Iterate(fn func(field string, value types.Value) error) error { - for _, f := range o.fields { - v, err := o.d.GetByField(f) - if err != nil { - continue - } - - if err := fn(f, v); err != nil { - return err - } - } - - return nil -} - -func (o *onlyObject) GetByField(field string) (types.Value, error) { - if stringutil.Contains(o.fields, field) { - return o.d.GetByField(field) - } - - return nil, types.ErrFieldNotFound -} - -func (o *onlyObject) MarshalJSON() ([]byte, error) { - return MarshalJSON(o) -} - -func WithSortedFields(d types.Object) types.Object { - return &sortedObject{d} -} - -type sortedObject struct { - types.Object -} - -func (s *sortedObject) Iterate(fn func(field string, value types.Value) error) error { - // iterate first to get the list of fields - var fields []string - err := s.Object.Iterate(func(field string, value types.Value) error { - fields = append(fields, field) - return nil - }) - if err != nil { - return err - } - - // sort the fields - sort.Strings(fields) - - // iterate again - for _, f := range fields { - v, err := s.Object.GetByField(f) - if err != nil { - continue - } - - if err := fn(f, v); err != nil { - return err - } - } - - return nil -} diff --git a/internal/object/object_test.go b/internal/object/object_test.go deleted file mode 100644 index cf206c245..000000000 --- a/internal/object/object_test.go +++ /dev/null @@ -1,652 +0,0 @@ -package object_test - -import ( - "encoding/json" - "testing" - "time" - - "github.com/cockroachdb/errors" - "github.com/stretchr/testify/require" - - "github.com/chaisql/chai/internal/object" - "github.com/chaisql/chai/internal/sql/parser" - "github.com/chaisql/chai/internal/testutil" - "github.com/chaisql/chai/internal/testutil/assert" - "github.com/chaisql/chai/internal/types" -) - -var _ types.Object = new(object.FieldBuffer) - -func TestFieldBuffer(t *testing.T) { - var buf object.FieldBuffer - buf.Add("a", types.NewIntegerValue(10)) - buf.Add("b", types.NewTextValue("hello")) - - t.Run("Iterate", func(t *testing.T) { - var i int - err := buf.Iterate(func(f string, v types.Value) error { - switch i { - case 0: - require.Equal(t, "a", f) - require.Equal(t, types.NewIntegerValue(10), v) - case 1: - require.Equal(t, "b", f) - require.Equal(t, types.NewTextValue("hello"), 
v) - } - i++ - return nil - }) - assert.NoError(t, err) - require.Equal(t, 2, i) - }) - - t.Run("Add", func(t *testing.T) { - var buf object.FieldBuffer - buf.Add("a", types.NewIntegerValue(10)) - buf.Add("b", types.NewTextValue("hello")) - - c := types.NewBooleanValue(true) - buf.Add("c", c) - require.Equal(t, 3, buf.Len()) - }) - - t.Run("ScanObject", func(t *testing.T) { - var buf1, buf2 object.FieldBuffer - - buf1.Add("a", types.NewIntegerValue(10)) - buf1.Add("b", types.NewTextValue("hello")) - - buf2.Add("a", types.NewIntegerValue(20)) - buf2.Add("b", types.NewTextValue("bye")) - buf2.Add("c", types.NewBooleanValue(true)) - - err := buf1.ScanObject(&buf2) - assert.NoError(t, err) - - var buf object.FieldBuffer - buf.Add("a", types.NewIntegerValue(10)) - buf.Add("b", types.NewTextValue("hello")) - buf.Add("a", types.NewIntegerValue(20)) - buf.Add("b", types.NewTextValue("bye")) - buf.Add("c", types.NewBooleanValue(true)) - require.Equal(t, buf, buf1) - }) - - t.Run("GetByField", func(t *testing.T) { - v, err := buf.GetByField("a") - assert.NoError(t, err) - require.Equal(t, types.NewIntegerValue(10), v) - - v, err = buf.GetByField("not existing") - assert.ErrorIs(t, err, types.ErrFieldNotFound) - require.Zero(t, v) - }) - - t.Run("Set", func(t *testing.T) { - tests := []struct { - name string - data string - path string - value types.Value - want string - fails bool - }{ - {"root", `{}`, `a`, types.NewIntegerValue(1), `{"a": 1}`, false}, - {"add field", `{"a": {"b": [1, 2, 3]}}`, `c`, types.NewTextValue("foo"), `{"a": {"b": [1, 2, 3]}, "c": "foo"}`, false}, - {"non existing doc", `{}`, `a.b.c`, types.NewTextValue("foo"), ``, true}, - {"wrong type", `{"a": 1}`, `a.b.c`, types.NewTextValue("foo"), ``, true}, - {"nested doc", `{"a": "foo"}`, `a`, types.NewObjectValue(object.NewFieldBuffer(). - Add("b", types.NewArrayValue(object.NewValueBuffer(). - Append(types.NewIntegerValue(1)). - Append(types.NewIntegerValue(2)). - Append(types.NewIntegerValue(3))))), `{"a": {"b": [1, 2, 3]}}`, false}, - {"nested doc", `{"a": {"b": [1, 2, 3]}}`, `a.b`, types.NewArrayValue(object.NewValueBuffer(). - Append(types.NewIntegerValue(1)). - Append(types.NewIntegerValue(2)). 
- Append(types.NewIntegerValue(3))), `{"a": {"b": [1, 2, 3]}}`, false}, - {"nested array", `{"a": {"b": [1, 2, 3]}}`, `a.b[1]`, types.NewIntegerValue(1), `{"a": {"b": [1, 1, 3]}}`, false}, - {"nested array multiple indexes", `{"a": {"b": [1, 2, [1, 2, {"c": "foo"}]]}}`, `a.b[2][2].c`, types.NewTextValue("bar"), `{"a": {"b": [1, 2, [1, 2, {"c": "bar"}]]}}`, false}, - {"number field", `{"a": {"0": [1, 2, 3]}}`, "a.`0`[0]", types.NewIntegerValue(6), `{"a": {"0": [6, 2, 3]}}`, false}, - {"object in array", `{"a": [{"b":"foo"}, 2, 3]}`, `a[0].b`, types.NewTextValue("bar"), `{"a": [{"b": "bar"}, 2, 3]}`, false}, - // with errors or request ignored doc unchanged - {"field not found", `{"a": {"b": [1, 2, 3]}}`, `a.b.c`, types.NewIntegerValue(1), `{"a": {"b": [1, 2, 3]}}`, false}, - {"unknown path", `{"a": {"b": [1, 2, 3]}}`, `a.e.f`, types.NewIntegerValue(1), ``, true}, - {"index out of range", `{"a": {"b": [1, 2, 3]}}`, `a.b[1000]`, types.NewIntegerValue(1), ``, true}, - {"object not array", `{"a": {"b": "foo"}}`, `a[0].b`, types.NewTextValue("bar"), ``, true}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - var fb object.FieldBuffer - - d := object.NewFromJSON([]byte(tt.data)) - err := fb.Copy(d) - assert.NoError(t, err) - p, err := parser.ParsePath(tt.path) - assert.NoError(t, err) - err = fb.Set(p, tt.value) - if tt.fails { - assert.Error(t, err) - return - } - - assert.NoError(t, err) - data, err := object.MarshalJSON(&fb) - assert.NoError(t, err) - require.Equal(t, tt.want, string(data)) - }) - } - }) - - t.Run("Delete", func(t *testing.T) { - tests := []struct { - object string - deletePath string - expected string - fails bool - }{ - {`{"a": 10, "b": "hello"}`, "a", `{"b": "hello"}`, false}, - {`{"a": 10, "b": "hello"}`, "c", ``, true}, - {`{"a": [1], "b": "hello"}`, "a[0]", `{"a": [], "b": "hello"}`, false}, - {`{"a": [1, 2], "b": "hello"}`, "a[0]", `{"a": [2], "b": "hello"}`, false}, - {`{"a": [1, 2], "b": "hello"}`, "a[5]", ``, true}, - {`{"a": [1, {"c": [1]}], "b": "hello"}`, "a[1].c", `{"a": [1, {}], "b": "hello"}`, false}, - {`{"a": [1, {"c": [1]}], "b": "hello"}`, "a[1].d", ``, true}, - } - - for _, test := range tests { - t.Run(test.object, func(t *testing.T) { - var buf object.FieldBuffer - err := json.Unmarshal([]byte(test.object), &buf) - assert.NoError(t, err) - - path := testutil.ParseObjectPath(t, test.deletePath) - - err = buf.Delete(path) - if test.fails { - assert.Error(t, err) - } else { - assert.NoError(t, err) - got, err := json.Marshal(&buf) - assert.NoError(t, err) - require.JSONEq(t, test.expected, string(got)) - } - }) - } - }) - - t.Run("Replace", func(t *testing.T) { - var buf object.FieldBuffer - buf.Add("a", types.NewIntegerValue(10)) - buf.Add("b", types.NewTextValue("hello")) - - err := buf.Replace("a", types.NewBooleanValue(true)) - assert.NoError(t, err) - v, err := buf.GetByField("a") - assert.NoError(t, err) - require.Equal(t, types.NewBooleanValue(true), v) - err = buf.Replace("d", types.NewIntegerValue(11)) - assert.Error(t, err) - }) - - t.Run("Apply", func(t *testing.T) { - d := object.NewFromJSON([]byte(`{ - "a": "b", - "c": ["d", "e"], - "f": {"g": "h"} - }`)) - - buf := object.NewFieldBuffer() - err := buf.Copy(d) - assert.NoError(t, err) - - err = buf.Apply(func(p object.Path, v types.Value) (types.Value, error) { - if v.Type() == types.TypeArray || v.Type() == types.TypeObject { - return v, nil - } - - return types.NewIntegerValue(1), nil - }) - assert.NoError(t, err) - - got, err := json.Marshal(buf) - assert.NoError(t, 
err) - require.JSONEq(t, `{"a": 1, "c": [1, 1], "f": {"g": 1}}`, string(got)) - }) - - t.Run("CloneValue", func(t *testing.T) { - d := testutil.MakeObject(t, `{ - "a": "b", - "c": ["d", "e"], - "f": {"g": "h"} - }`) - - got, err := object.CloneValue(types.NewObjectValue(d)) - require.NoError(t, err) - testutil.RequireObjEqual(t, d, types.AsObject(got)) - }) - - t.Run("UnmarshalJSON", func(t *testing.T) { - tests := []struct { - name string - data string - expected *object.FieldBuffer - fails bool - }{ - {"empty object", "{}", object.NewFieldBuffer(), false}, - {"empty object, missing closing bracket", "{", nil, true}, - {"classic object", `{"a": 1, "b": true, "c": "hello", "d": [1, 2, 3], "e": {"f": "g"}}`, - object.NewFieldBuffer(). - Add("a", types.NewIntegerValue(1)). - Add("b", types.NewBooleanValue(true)). - Add("c", types.NewTextValue("hello")). - Add("d", types.NewArrayValue(object.NewValueBuffer(). - Append(types.NewIntegerValue(1)). - Append(types.NewIntegerValue(2)). - Append(types.NewIntegerValue(3)))). - Add("e", types.NewObjectValue(object.NewFieldBuffer().Add("f", types.NewTextValue("g")))), - false}, - {"string values", `{"a": "hello ciao"}`, object.NewFieldBuffer().Add("a", types.NewTextValue("hello ciao")), false}, - {"+integer values", `{"a": 1000}`, object.NewFieldBuffer().Add("a", types.NewIntegerValue(1000)), false}, - {"-integer values", `{"a": -1000}`, object.NewFieldBuffer().Add("a", types.NewIntegerValue(-1000)), false}, - {"+float values", `{"a": 10000000000.0}`, object.NewFieldBuffer().Add("a", types.NewDoubleValue(10000000000)), false}, - {"-float values", `{"a": -10000000000.0}`, object.NewFieldBuffer().Add("a", types.NewDoubleValue(-10000000000)), false}, - {"bool values", `{"a": true, "b": false}`, object.NewFieldBuffer().Add("a", types.NewBooleanValue(true)).Add("b", types.NewBooleanValue(false)), false}, - {"empty arrays", `{"a": []}`, object.NewFieldBuffer().Add("a", types.NewArrayValue(object.NewValueBuffer())), false}, - {"nested arrays", `{"a": [[1, 2]]}`, object.NewFieldBuffer(). - Add("a", types.NewArrayValue( - object.NewValueBuffer(). - Append(types.NewArrayValue( - object.NewValueBuffer(). - Append(types.NewIntegerValue(1)). 
- Append(types.NewIntegerValue(2)))))), false}, - {"missing comma", `{"a": 1 "b": 2}`, nil, true}, - {"missing closing brackets", `{"a": 1, "b": 2`, nil, true}, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - var buf object.FieldBuffer - - err := json.Unmarshal([]byte(test.data), &buf) - if test.fails { - assert.Error(t, err) - } else { - assert.NoError(t, err) - require.Equal(t, *test.expected, buf) - } - }) - } - }) -} - -func TestNewFromStruct(t *testing.T) { - type group struct { - Ig int - } - - type user struct { - A []byte - B string - C bool - D uint `chai:"la-reponse-d"` - E uint8 - F uint16 - G uint32 - H uint64 - I int - J int8 - K int16 - L int32 - M int64 - N float64 - // structs must be considered as objects - O group - - // nil pointers must be skipped - // otherwise they must be dereferenced - P *int - Q *int - - // struct pointers should be considered as objects - // if there are nil though, they must be skipped - R *group - S *group - - T []int - U []int - V []*int - W []user - X []interface{} - Y [3]int - Z interface{} - ZZ interface{} - - AA int `chai:"-"` // ignored - - *group - - BB time.Time // some have special encoding as object - - // unexported fields should be ignored - t int - } - - u := user{ - A: []byte("foo"), - B: "bar", - C: true, - D: 1, - E: 2, - F: 3, - G: 4, - H: 5, - I: 6, - J: 7, - K: 8, - L: 9, - M: 10, - N: 11.12, - Z: 26, - AA: 27, - group: &group{ - Ig: 100, - }, - BB: time.Date(2020, 11, 15, 16, 37, 10, 20, time.UTC), - t: 99, - } - - q := 5 - u.Q = &q - u.R = new(group) - u.T = []int{1, 2, 3} - u.V = []*int{&q} - u.W = []user{u} - u.X = []interface{}{1, "foo"} - - t.Run("Iterate", func(t *testing.T) { - doc, err := object.NewFromStruct(u) - assert.NoError(t, err) - - var counter int - - err = doc.Iterate(func(f string, v types.Value) error { - switch counter { - case 0: - require.Equal(t, u.A, types.AsByteSlice(v)) - case 1: - require.Equal(t, u.B, types.AsString(v)) - case 2: - require.Equal(t, u.C, types.AsBool(v)) - case 3: - require.Equal(t, "la-reponse-d", f) - require.EqualValues(t, u.D, types.AsInt64(v)) - case 4: - require.EqualValues(t, u.E, types.AsInt64(v)) - case 5: - require.EqualValues(t, u.F, types.AsInt64(v)) - case 6: - require.EqualValues(t, u.G, types.AsInt64(v)) - case 7: - require.EqualValues(t, u.H, types.AsInt64(v)) - case 8: - require.EqualValues(t, u.I, types.AsInt64(v)) - case 9: - require.EqualValues(t, u.J, types.AsInt64(v)) - case 10: - require.EqualValues(t, u.K, types.AsInt64(v)) - case 11: - require.EqualValues(t, u.L, types.AsInt64(v)) - case 12: - require.EqualValues(t, u.M, types.AsInt64(v)) - case 13: - require.Equal(t, u.N, types.AsFloat64(v)) - case 14: - require.EqualValues(t, types.TypeObject, v.Type()) - case 15: - require.EqualValues(t, *u.Q, types.AsInt64(v)) - case 16: - require.EqualValues(t, types.TypeObject, v.Type()) - case 17: - require.EqualValues(t, types.TypeArray, v.Type()) - case 18: - require.EqualValues(t, types.TypeNull, v.Type()) - case 19: - require.EqualValues(t, types.TypeArray, v.Type()) - case 20: - require.EqualValues(t, types.TypeArray, v.Type()) - case 21: - require.EqualValues(t, types.TypeArray, v.Type()) - case 22: - require.EqualValues(t, types.TypeArray, v.Type()) - case 23: - require.EqualValues(t, u.Z, types.AsInt64(v)) - case 24: - require.EqualValues(t, types.TypeNull, v.Type()) - case 25: - require.EqualValues(t, types.TypeInteger, v.Type()) - case 26: - require.EqualValues(t, types.TypeTimestamp, v.Type()) - default: - require.FailNowf(t, 
"", "unknown field %q", f) - } - - counter++ - - return nil - }) - assert.NoError(t, err) - require.Equal(t, 27, counter) - }) - - t.Run("GetByField", func(t *testing.T) { - doc, err := object.NewFromStruct(u) - assert.NoError(t, err) - - v, err := doc.GetByField("a") - assert.NoError(t, err) - require.Equal(t, u.A, types.AsByteSlice(v)) - v, err = doc.GetByField("b") - assert.NoError(t, err) - require.Equal(t, u.B, types.AsString(v)) - v, err = doc.GetByField("c") - assert.NoError(t, err) - require.Equal(t, u.C, types.AsBool(v)) - v, err = doc.GetByField("la-reponse-d") - assert.NoError(t, err) - require.EqualValues(t, u.D, types.AsInt64(v)) - v, err = doc.GetByField("e") - assert.NoError(t, err) - require.EqualValues(t, u.E, types.AsInt64(v)) - v, err = doc.GetByField("f") - assert.NoError(t, err) - require.EqualValues(t, u.F, types.AsInt64(v)) - v, err = doc.GetByField("g") - assert.NoError(t, err) - require.EqualValues(t, u.G, types.AsInt64(v)) - v, err = doc.GetByField("h") - assert.NoError(t, err) - require.EqualValues(t, u.H, types.AsInt64(v)) - v, err = doc.GetByField("i") - assert.NoError(t, err) - require.EqualValues(t, u.I, types.AsInt64(v)) - v, err = doc.GetByField("j") - assert.NoError(t, err) - require.EqualValues(t, u.J, types.AsInt64(v)) - v, err = doc.GetByField("k") - assert.NoError(t, err) - require.EqualValues(t, u.K, types.AsInt64(v)) - v, err = doc.GetByField("l") - assert.NoError(t, err) - require.EqualValues(t, u.L, types.AsInt64(v)) - v, err = doc.GetByField("m") - assert.NoError(t, err) - require.EqualValues(t, u.M, types.AsInt64(v)) - v, err = doc.GetByField("n") - assert.NoError(t, err) - require.Equal(t, u.N, types.AsFloat64(v)) - - v, err = doc.GetByField("o") - assert.NoError(t, err) - d, ok := types.Is[types.Object](v) - require.True(t, ok) - v, err = d.GetByField("ig") - assert.NoError(t, err) - require.EqualValues(t, 0, types.AsInt64(v)) - - v, err = doc.GetByField("ig") - assert.NoError(t, err) - require.EqualValues(t, 100, types.AsInt64(v)) - - v, err = doc.GetByField("t") - assert.NoError(t, err) - a, ok := types.Is[types.Array](v) - require.True(t, ok) - var count int - err = a.Iterate(func(i int, v types.Value) error { - count++ - require.EqualValues(t, i+1, types.AsInt64(v)) - return nil - }) - assert.NoError(t, err) - require.Equal(t, 3, count) - _, err = a.GetByIndex(10) - assert.ErrorIs(t, err, types.ErrFieldNotFound) - v, err = a.GetByIndex(1) - assert.NoError(t, err) - require.EqualValues(t, 2, types.AsInt64(v)) - - v, err = doc.GetByField("bb") - assert.NoError(t, err) - var tm time.Time - assert.NoError(t, object.ScanValue(v, &tm)) - require.Equal(t, u.BB, tm) - }) - - t.Run("pointers", func(t *testing.T) { - type s struct { - A *int - } - - d, err := object.NewFromStruct(new(s)) - assert.NoError(t, err) - _, err = d.GetByField("a") - assert.ErrorIs(t, err, types.ErrFieldNotFound) - - a := 10 - ss := s{A: &a} - d, err = object.NewFromStruct(&ss) - assert.NoError(t, err) - v, err := d.GetByField("a") - assert.NoError(t, err) - require.Equal(t, types.NewIntegerValue(10), v) - }) -} - -type foo struct { - A string - B int64 - C bool - D float64 -} - -func (f *foo) Iterate(fn func(field string, value types.Value) error) error { - var err error - - err = fn("a", types.NewTextValue(f.A)) - if err != nil { - return err - } - - err = fn("b", types.NewIntegerValue(f.B)) - if err != nil { - return err - } - - err = fn("c", types.NewBooleanValue(f.C)) - if err != nil { - return err - } - - err = fn("d", types.NewDoubleValue(f.D)) - if err != nil { - 
return err - } - - return nil -} - -func (f *foo) GetByField(field string) (types.Value, error) { - switch field { - case "a": - return types.NewTextValue(f.A), nil - case "b": - return types.NewIntegerValue(f.B), nil - case "c": - return types.NewBooleanValue(f.C), nil - case "d": - return types.NewDoubleValue(f.D), nil - } - - return nil, errors.New("unknown field") -} - -func TestJSONObject(t *testing.T) { - tests := []struct { - name string - o types.Object - expected string - }{ - { - "Flat", - object.NewFieldBuffer(). - Add("name", types.NewTextValue("John")). - Add("age", types.NewIntegerValue(10)). - Add(`"something with" quotes`, types.NewIntegerValue(10)), - `{"name":"John","age":10,"\"something with\" quotes":10}`, - }, - { - "Nested", - object.NewFieldBuffer(). - Add("name", types.NewTextValue("John")). - Add("age", types.NewIntegerValue(10)). - Add("address", types.NewObjectValue(object.NewFieldBuffer(). - Add("city", types.NewTextValue("Ajaccio")). - Add("country", types.NewTextValue("France")), - )). - Add("friends", types.NewArrayValue( - object.NewValueBuffer(). - Append(types.NewTextValue("fred")). - Append(types.NewTextValue("jamie")), - )), - `{"name":"John","age":10,"address":{"city":"Ajaccio","country":"France"},"friends":["fred","jamie"]}`, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - data, err := json.Marshal(test.o) - assert.NoError(t, err) - require.Equal(t, test.expected, string(data)) - assert.NoError(t, err) - }) - } -} - -func BenchmarkObjectIterate(b *testing.B) { - f := foo{ - A: "a", - B: 1000, - C: true, - D: 1e10, - } - - b.Run("Implementation", func(b *testing.B) { - for i := 0; i < b.N; i++ { - f.Iterate(func(string, types.Value) error { - return nil - }) - } - }) - -} diff --git a/internal/object/path.go b/internal/object/path.go deleted file mode 100644 index 1b59d3557..000000000 --- a/internal/object/path.go +++ /dev/null @@ -1,182 +0,0 @@ -package object - -import ( - "strconv" - "strings" - - "github.com/chaisql/chai/internal/types" - "github.com/cockroachdb/errors" -) - -// A Path represents the path to a particular value within an object. -type Path []PathFragment - -// NewPath creates a path from a list of strings representing either a field name -// or an array index in string form. -func NewPath(fragments ...string) Path { - var path Path - - for _, frag := range fragments { - idx, err := strconv.Atoi(frag) - if err != nil { - path = append(path, PathFragment{FieldName: frag}) - } else { - path = append(path, PathFragment{ArrayIndex: idx}) - } - } - - return path -} - -// PathFragment is a fragment of a path representing either a field name or -// the index of an array. -type PathFragment struct { - FieldName string - ArrayIndex int -} - -// String representation of all the fragments of the path. -// It implements the Stringer interface. -func (p Path) String() string { - var b strings.Builder - - for i := range p { - if p[i].FieldName != "" { - if i != 0 { - b.WriteRune('.') - } - b.WriteString(p[i].FieldName) - } else { - b.WriteString("[" + strconv.Itoa(p[i].ArrayIndex) + "]") - } - } - return b.String() -} - -// IsEqual returns whether other is equal to p. -func (p Path) IsEqual(other Path) bool { - if len(other) != len(p) { - return false - } - - for i := range p { - if other[i] != p[i] { - return false - } - } - - return true -} - -// GetValueFromObject returns the value at path p from d. 
-func (p Path) GetValueFromObject(d types.Object) (types.Value, error) { - if len(p) == 0 { - return nil, errors.WithStack(types.ErrFieldNotFound) - } - if p[0].FieldName == "" { - return nil, errors.WithStack(types.ErrFieldNotFound) - } - - v, err := d.GetByField(p[0].FieldName) - if err != nil { - return nil, err - } - - if len(p) == 1 { - return v, nil - } - - return p[1:].getValueFromValue(v) -} - -// GetValueFromArray returns the value at path p from a. -func (p Path) GetValueFromArray(a types.Array) (types.Value, error) { - if len(p) == 0 { - return nil, errors.WithStack(types.ErrFieldNotFound) - } - if p[0].FieldName != "" { - return nil, errors.WithStack(types.ErrFieldNotFound) - } - - v, err := a.GetByIndex(p[0].ArrayIndex) - if err != nil { - if errors.Is(err, types.ErrValueNotFound) { - return nil, errors.WithStack(types.ErrFieldNotFound) - } - - return nil, err - } - - if len(p) == 1 { - return v, nil - } - - return p[1:].getValueFromValue(v) -} - -func (p Path) Clone() Path { - c := make(Path, len(p)) - copy(c, p) - return c -} - -// Extend clones the path and appends the fragment to it. -func (p Path) Extend(f ...PathFragment) Path { - c := make(Path, len(p)+len(f)) - copy(c, p) - for i := range f { - c[len(p)+i] = f[i] - } - return c -} - -// Extend clones the path and appends the field to it. -func (p Path) ExtendField(field string) Path { - return p.Extend(PathFragment{FieldName: field}) -} - -// Extend clones the path and appends the array index to it. -func (p Path) ExtendIndex(index int) Path { - return p.Extend(PathFragment{ArrayIndex: index}) -} - -func (p Path) getValueFromValue(v types.Value) (types.Value, error) { - switch v.Type() { - case types.TypeObject: - return p.GetValueFromObject(types.AsObject(v)) - case types.TypeArray: - return p.GetValueFromArray(types.AsArray(v)) - } - - return nil, types.ErrFieldNotFound -} - -type Paths []Path - -func (p Paths) String() string { - var sb strings.Builder - - for i, pt := range p { - if i > 0 { - sb.WriteString(", ") - } - sb.WriteString(pt.String()) - } - - return sb.String() -} - -// IsEqual returns whether other is equal to p. 
-func (p Paths) IsEqual(other Paths) bool { - if len(other) != len(p) { - return false - } - - for i := range p { - if !other[i].IsEqual(p[i]) { - return false - } - } - - return true -} diff --git a/internal/object/path_test.go b/internal/object/path_test.go deleted file mode 100644 index 9afb3d754..000000000 --- a/internal/object/path_test.go +++ /dev/null @@ -1,49 +0,0 @@ -package object_test - -import ( - "encoding/json" - "testing" - - "github.com/chaisql/chai/internal/object" - "github.com/chaisql/chai/internal/sql/parser" - "github.com/chaisql/chai/internal/testutil/assert" - "github.com/stretchr/testify/require" -) - -func TestPath(t *testing.T) { - tests := []struct { - name string - data string - path string - result string - fails bool - }{ - {"root", `{"a": {"b": [1, 2, 3]}}`, `a`, `{"b": [1, 2, 3]}`, false}, - {"nested doc", `{"a": {"b": [1, 2, 3]}}`, `a.b`, `[1, 2, 3]`, false}, - {"nested array", `{"a": {"b": [1, 2, 3]}}`, `a.b[1]`, `2`, false}, - {"index out of range", `{"a": {"b": [1, 2, 3]}}`, `a.b[1000]`, ``, true}, - {"number field", `{"a": {"0": [1, 2, 3]}}`, "a.`0`", `[1, 2, 3]`, false}, - {"letter index", `{"a": {"b": [1, 2, 3]}}`, `a.b.c`, ``, true}, - {"unknown path", `{"a": {"b": [1, 2, 3]}}`, `a.e.f`, ``, true}, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - var buf object.FieldBuffer - - err := json.Unmarshal([]byte(test.data), &buf) - assert.NoError(t, err) - p, err := parser.ParsePath(test.path) - assert.NoError(t, err) - v, err := p.GetValueFromObject(&buf) - if test.fails { - assert.Error(t, err) - } else { - assert.NoError(t, err) - res, err := json.Marshal(v) - assert.NoError(t, err) - require.JSONEq(t, test.result, string(res)) - } - }) - } -} diff --git a/internal/object/scan_test.go b/internal/object/scan_test.go deleted file mode 100644 index 27d0a2ed3..000000000 --- a/internal/object/scan_test.go +++ /dev/null @@ -1,313 +0,0 @@ -package object_test - -import ( - "testing" - "time" - - "github.com/chaisql/chai/internal/encoding" - "github.com/chaisql/chai/internal/object" - "github.com/chaisql/chai/internal/testutil/assert" - "github.com/chaisql/chai/internal/types" - "github.com/stretchr/testify/require" -) - -func strPtr(s string) *string { - return &s -} - -func boolPtr(b bool) *bool { - return &b -} - -func TestScan(t *testing.T) { - now := time.Now() - - simpleDoc := object.NewFieldBuffer(). - Add("foo", types.NewTextValue("foo")). - Add("bar", types.NewTextValue("bar")). - Add("baz", types.NewArrayValue(object.NewValueBuffer( - types.NewIntegerValue(10), - types.NewDoubleValue(20.5), - ))) - - nestedDoc := object.NewFieldBuffer(). - Add("foo", types.NewObjectValue(simpleDoc)) - - var buf []byte - buf, err := encoding.EncodeValue(buf, types.NewObjectValue(nestedDoc), false) - assert.NoError(t, err) - - dec, _ := encoding.DecodeValue(buf, false) - assert.NoError(t, err) - - doc := object.NewFieldBuffer(). - Add("a", types.NewBlobValue([]byte("foo"))). - Add("b", types.NewTextValue("bar")). - Add("c", types.NewBooleanValue(true)). - Add("d", types.NewIntegerValue(10)). - Add("e", types.NewIntegerValue(10)). - Add("f", types.NewIntegerValue(10)). - Add("g", types.NewIntegerValue(10)). - Add("h", types.NewIntegerValue(10)). - Add("i", types.NewDoubleValue(10.5)). - Add("j", types.NewArrayValue( - object.NewValueBuffer(). - Append(types.NewBooleanValue(true)), - )). - Add("k", types.NewObjectValue( - object.NewFieldBuffer(). - Add("foo", types.NewTextValue("foo")). - Add("bar", types.NewTextValue("bar")), - )). 
- Add("l", types.NewObjectValue( - object.NewFieldBuffer(). - Add("foo", types.NewTextValue("foo")). - Add("bar", types.NewTextValue("bar")), - )). - Add("m", types.NewObjectValue( - object.NewFieldBuffer(). - Add("foo", types.NewTextValue("foo")). - Add("bar", types.NewTextValue("bar")). - Add("baz", types.NewTextValue("baz")). - Add("-", types.NewTextValue("bat")), - )). - Add("n", types.NewObjectValue( - object.NewFieldBuffer(). - Add("foo", types.NewTextValue("foo")). - Add("bar", types.NewTextValue("bar")), - )). - Add("o", types.NewNullValue()). - Add("p", types.NewTextValue(now.Format(time.RFC3339Nano))). - Add("r", dec). - Add("s", types.NewArrayValue(object.NewValueBuffer(types.NewBooleanValue(true), types.NewBooleanValue(false)))). - Add("u", types.NewArrayValue(object.NewValueBuffer( - types.NewObjectValue( - object.NewFieldBuffer(). - Add("foo", types.NewTextValue("a")). - Add("bar", types.NewTextValue("b")), - ), - types.NewObjectValue( - object.NewFieldBuffer(). - Add("foo", types.NewTextValue("c")). - Add("bar", types.NewTextValue("d")), - ), - ))). - Add("v", types.NewArrayValue(object.NewValueBuffer( - types.NewObjectValue( - object.NewFieldBuffer(). - Add("foo", types.NewTextValue("a")). - Add("bar", types.NewTextValue("b")), - ), - types.NewObjectValue( - object.NewFieldBuffer(). - Add("foo", types.NewTextValue("c")). - Add("bar", types.NewTextValue("d")), - ), - ))). - Add("w", types.NewArrayValue(object.NewValueBuffer( - types.NewIntegerValue(1), - types.NewIntegerValue(2), - types.NewIntegerValue(3), - types.NewIntegerValue(4), - ))). - Add("x", types.NewBlobValue([]byte{1, 2, 3, 4})). - Add("y", types.NewObjectValue( - object.NewFieldBuffer(). - Add("foo", types.NewTextValue("foo")). - Add("bar", types.NewTextValue("bar")). - Add("baz", types.NewTextValue("baz")). - Add("bat", types.NewTextValue("bat")). - Add("-", types.NewTextValue("bat")), - )). 
- Add("z", types.NewTimestampValue(now)) - - type foo struct { - Foo string - Pub *string `chai:"bar"` - Baz *string `chai:"-"` - } - - var a []byte - var b string - var c bool - var d int - var e int8 - var f int16 - var g int32 - var h int64 - var i float64 - var j []bool - var k foo - var l *foo = new(foo) - var m *foo - var n map[string]string - var o []int = []int{1, 2, 3} - var p time.Time - var r map[string]interface{} - var s []*bool - var u []foo - var v []*foo - var w [4]int - var x [4]uint8 - var y struct { - foo - Pub string `chai:"bar"` - Bat string - } - var z time.Time - - err = object.Scan(doc, &a, &b, &c, &d, &e, &f, &g, &h, &i, &j, &k, &l, &m, &n, &o, &p, &r, &s, &u, &v, &w, &x, &y, &z) - assert.NoError(t, err) - require.Equal(t, a, []byte("foo")) - require.Equal(t, b, "bar") - require.Equal(t, c, true) - require.Equal(t, d, int(10)) - require.Equal(t, e, int8(10)) - require.Equal(t, f, int16(10)) - require.Equal(t, g, int32(10)) - require.Equal(t, h, int64(10)) - require.Equal(t, i, float64(10.5)) - require.Equal(t, j, []bool{true}) - require.Equal(t, foo{Foo: "foo", Pub: strPtr("bar")}, k) - require.Equal(t, &foo{Foo: "foo", Pub: strPtr("bar")}, l) - require.Equal(t, &foo{Foo: "foo", Pub: strPtr("bar")}, m) - require.Equal(t, map[string]string{"foo": "foo", "bar": "bar"}, n) - require.Equal(t, []int(nil), o) - require.Equal(t, now.Format(time.RFC3339Nano), p.Format(time.RFC3339Nano)) - require.Equal(t, map[string]interface{}{ - "foo": map[string]interface{}{ - "foo": "foo", - "bar": "bar", - "baz": []interface{}{ - int64(10), float64(20.5), - }, - }, - }, r) - require.Equal(t, []*bool{boolPtr(true), boolPtr(false)}, s) - require.Equal(t, foo{Foo: "foo", Pub: strPtr("bar")}, k) - require.Equal(t, []foo{{Foo: "a", Pub: strPtr("b")}, {Foo: "c", Pub: strPtr("d")}}, u) - require.Equal(t, []*foo{{Foo: "a", Pub: strPtr("b")}, {Foo: "c", Pub: strPtr("d")}}, v) - require.Equal(t, [4]int{1, 2, 3, 4}, w) - require.Equal(t, [4]uint8{1, 2, 3, 4}, x) - require.Equal(t, now.UTC(), z) - - t.Run("objectcanner", func(t *testing.T) { - var ds objectScanner - ds.fn = func(d types.Object) error { - require.Equal(t, doc, d) - return nil - } - err := object.StructScan(doc, &ds) - assert.NoError(t, err) - }) - - t.Run("Map", func(t *testing.T) { - m := make(map[string]interface{}) - err := object.MapScan(doc, m) - assert.NoError(t, err) - require.Len(t, m, 24) - }) - - t.Run("MapPtr", func(t *testing.T) { - var m map[string]interface{} - err := object.MapScan(doc, &m) - assert.NoError(t, err) - require.Len(t, m, 24) - }) - - t.Run("Small Slice", func(t *testing.T) { - s := make([]int, 1) - arr := object.NewValueBuffer().Append(types.NewIntegerValue(1)).Append(types.NewIntegerValue(2)) - err := object.SliceScan(arr, &s) - assert.NoError(t, err) - require.Len(t, s, 2) - require.Equal(t, []int{1, 2}, s) - }) - - t.Run("Slice overwrite", func(t *testing.T) { - s := make([]int, 1) - arr := object.NewValueBuffer().Append(types.NewIntegerValue(1)).Append(types.NewIntegerValue(2)) - err := object.SliceScan(arr, &s) - assert.NoError(t, err) - err = object.SliceScan(arr, &s) - assert.NoError(t, err) - require.Len(t, s, 2) - require.Equal(t, []int{1, 2}, s) - }) - - t.Run("pointers", func(t *testing.T) { - type bar struct { - A *int - } - - b := bar{} - - d := object.NewFieldBuffer().Add("a", types.NewIntegerValue(10)) - err := object.StructScan(d, &b) - assert.NoError(t, err) - - a := 10 - require.Equal(t, bar{A: &a}, b) - }) - - t.Run("NULL with pointers", func(t *testing.T) { - type bar struct { - A 
*int - B *string - C *int - } - - c := 10 - b := bar{ - C: &c, - } - - d := object.NewFieldBuffer().Add("a", types.NewNullValue()) - err := object.StructScan(d, &b) - assert.NoError(t, err) - require.Equal(t, bar{}, b) - }) - - t.Run("Incompatible type", func(t *testing.T) { - var a struct { - A int - } - - d := object.NewFieldBuffer().Add("a", types.NewObjectValue(doc)) - err := object.StructScan(d, &a) - assert.Error(t, err) - }) - - t.Run("Interface member", func(t *testing.T) { - type foo struct { - A interface{} - } - - type bar struct { - B int - } - - var f foo - f.A = &bar{} - - d := object.NewFieldBuffer().Add("a", types.NewObjectValue(object.NewFieldBuffer().Add("b", types.NewIntegerValue(10)))) - err := object.StructScan(d, &f) - assert.NoError(t, err) - require.Equal(t, &foo{A: &bar{B: 10}}, &f) - }) - - t.Run("Pointer not to struct", func(t *testing.T) { - var b int - d := object.NewFieldBuffer().Add("a", types.NewIntegerValue(10)) - err := object.StructScan(d, &b) - assert.Error(t, err) - }) -} - -type objectScanner struct { - fn func(d types.Object) error -} - -func (ds objectScanner) ScanObject(d types.Object) error { - return ds.fn(d) -} diff --git a/internal/planner/index_selection.go b/internal/planner/index_selection.go index 061796736..77699959e 100644 --- a/internal/planner/index_selection.go +++ b/internal/planner/index_selection.go @@ -1,14 +1,15 @@ package planner import ( + "github.com/chaisql/chai/internal/database" "github.com/chaisql/chai/internal/expr" - "github.com/chaisql/chai/internal/object" "github.com/chaisql/chai/internal/sql/scanner" "github.com/chaisql/chai/internal/stream" "github.com/chaisql/chai/internal/stream/index" "github.com/chaisql/chai/internal/stream/rows" "github.com/chaisql/chai/internal/stream/table" "github.com/chaisql/chai/internal/tree" + "github.com/chaisql/chai/internal/types" ) // SelectIndex attempts to replace a sequential scan by an index scan or a pk scan by @@ -92,7 +93,7 @@ func SelectIndex(sctx *StreamContext) error { } // ensure the table exists - _, err := sctx.Catalog.GetTableInfo(seq.TableName) + info, err := sctx.Catalog.GetTableInfo(seq.TableName) if err != nil { return err } @@ -105,6 +106,7 @@ func SelectIndex(sctx *StreamContext) error { is := indexSelector{ tableScan: seq, sctx: sctx, + info: info, } return is.selectIndex() @@ -116,6 +118,7 @@ func SelectIndex(sctx *StreamContext) error { type indexSelector struct { tableScan *table.ScanOperator sctx *StreamContext + info *database.TableInfo } func (i *indexSelector) selectIndex() error { @@ -126,7 +129,11 @@ func (i *indexSelector) selectIndex() error { // get all contiguous filter nodes that can be indexed for _, f := range i.sctx.Filters { - filter := i.isFilterIndexable(f) + filter, err := i.isFilterIndexable(f) + if err != nil { + return err + } + if filter == nil { continue } @@ -158,7 +165,7 @@ func (i *indexSelector) selectIndex() error { } pk := tb.PrimaryKey if pk != nil { - selected = i.associateIndexWithNodes(tb.TableName, false, false, pk.Paths, pk.SortOrder, nodes) + selected = i.associateIndexWithNodes(tb.TableName, false, false, pk.Columns, pk.SortOrder, nodes) if selected != nil { cost = selected.Cost() } @@ -172,7 +179,7 @@ func (i *indexSelector) selectIndex() error { return err } - candidate := i.associateIndexWithNodes(idxInfo.IndexName, true, idxInfo.Unique, idxInfo.Paths, idxInfo.KeySortOrder, nodes) + candidate := i.associateIndexWithNodes(idxInfo.IndexName, true, idxInfo.Unique, idxInfo.Columns, idxInfo.KeySortOrder, nodes) if candidate == 
nil { continue @@ -224,44 +231,44 @@ func (i *indexSelector) selectIndex() error { return nil } -func (i *indexSelector) isFilterIndexable(f *rows.FilterOperator) *indexableNode { +func (i *indexSelector) isFilterIndexable(f *rows.FilterOperator) (*indexableNode, error) { // only operators can associate this node to an index op, ok := f.Expr.(expr.Operator) if !ok { - return nil + return nil, nil } // ensure the operator is compatible if !operatorIsIndexCompatible(op) { - return nil + return nil, nil } // determine if the operator could benefit from an index - ok, path, e := operatorCanUseIndex(op) - if !ok { - return nil + ok, path, e, err := i.operatorCanUseIndex(op) + if !ok || err != nil { + return nil, err } node := indexableNode{ node: f, - path: path, + col: path, operator: op.Token(), operand: e, } - return &node + return &node, nil } func (i *indexSelector) isTempTreeSortIndexable(n *rows.TempTreeSortOperator) *indexableNode { - // only paths can be associated with an index - path, ok := n.Expr.(expr.Path) + // only columns can be associated with an index + col, ok := n.Expr.(expr.Column) if !ok { return nil } return &indexableNode{ node: n, - path: object.Path(path), + col: string(col), desc: n.Desc, operator: scanner.ORDER, } @@ -282,14 +289,14 @@ func (i *indexSelector) isTempTreeSortIndexable(n *rows.TempTreeSortOperator) *i // -> range = {min: [3], exact: true} // rows.Filter(a IN (1, 2)) // -> ranges = [1], [2] -func (i *indexSelector) associateIndexWithNodes(treeName string, isIndex bool, isUnique bool, paths []object.Path, sortOrder tree.SortOrder, nodes indexableNodes) *candidate { - found := make([]*indexableNode, 0, len(paths)) +func (i *indexSelector) associateIndexWithNodes(treeName string, isIndex bool, isUnique bool, columns []string, sortOrder tree.SortOrder, nodes indexableNodes) *candidate { + found := make([]*indexableNode, 0, len(columns)) var desc bool var hasIn bool var sorter *indexableNode - for _, p := range paths { - ns := nodes.getByPath(p) + for _, p := range columns { + ns := nodes.getByColumn(p) if len(ns) == 0 { break } @@ -394,7 +401,7 @@ func (i *indexSelector) associateIndexWithNodes(treeName string, isIndex bool, i if !hasIn { ranges = stream.Ranges{i.buildRangeFromFilterNodes(found...)} } else { - ranges = i.buildRangesFromFilterNodes(paths, found) + ranges = i.buildRangesFromFilterNodes(columns, found) } c := candidate{ @@ -436,7 +443,7 @@ func (i *indexSelector) associateIndexWithNodes(treeName string, isIndex bool, i return &c } -func (i *indexSelector) buildRangesFromFilterNodes(paths []object.Path, filters []*indexableNode) stream.Ranges { +func (i *indexSelector) buildRangesFromFilterNodes(columns []string, filters []*indexableNode) stream.Ranges { // build a 2 dimentional list of all expressions // so that: rows.Filter(a IN (10, 11)) | rows.Filter(b = 20) | rows.Filter(c IN (30, 31)) // becomes: @@ -467,7 +474,7 @@ func (i *indexSelector) buildRangesFromFilterNodes(paths []object.Path, filters var ranges stream.Ranges i.walkExpr(l, func(row []expr.Expr) { - ranges = append(ranges, i.buildRangeFromOperator(scanner.EQ, paths[:len(row)], row...)) + ranges = append(ranges, i.buildRangeFromOperator(scanner.EQ, columns[:len(row)], row...)) }) return ranges @@ -496,23 +503,23 @@ func (i *indexSelector) walkExpr(l [][]expr.Expr, fn func(row []expr.Expr)) { } func (i *indexSelector) buildRangeFromFilterNodes(filters ...*indexableNode) stream.Range { - // first, generate a list of paths and a list of expressions - paths := make([]object.Path, 0, 
 len(filters)) + // first, generate a list of columns and a list of expressions + columns := make([]string, 0, len(filters)) el := make(expr.LiteralExprList, 0, len(filters)) for i := range filters { - paths = append(paths, filters[i].path) + columns = append(columns, filters[i].col) el = append(el, filters[i].operand) } // use last filter node to determine the direction of the range filter := filters[len(filters)-1] - return i.buildRangeFromOperator(filter.operator, paths, el...) + return i.buildRangeFromOperator(filter.operator, columns, el...) } -func (i *indexSelector) buildRangeFromOperator(lastOp scanner.Token, paths []object.Path, operands ...expr.Expr) stream.Range { +func (i *indexSelector) buildRangeFromOperator(lastOp scanner.Token, columns []string, operands ...expr.Expr) stream.Range { rng := stream.Range{ - Paths: paths, + Columns: columns, } el := expr.LiteralExprList(operands) @@ -569,21 +576,21 @@ type indexableNode struct { // For filter nodes // the expression of the node // has been broken into - // + // // Ex: WHERE a.b[0] > 5 + 5 // Gives: - // - path: a.b[0] + // - col: a.b[0] // - operator: scanner.GT // - operand: 5 + 5 // For TempTreeSort nodes // the expression of the node // has been broken into - // + // // Ex: ORDER BY a.b[0] ASC // Gives: - // - path: a.b[0] + // - col: a.b[0] // - desc: false - path object.Path + col string operator scanner.Token operand expr.Expr desc bool @@ -595,13 +602,13 @@ type indexableNode struct { type indexableNodes []*indexableNode -// getByPath returns all indexable nodes for the given path. +// getByColumn returns all indexable nodes for the given column. // TODO(asdine): add a rule that merges nodes that point to the // same path. -func (n indexableNodes) getByPath(p object.Path) []*indexableNode { +func (n indexableNodes) getByColumn(c string) []*indexableNode { var nodes []*indexableNode for _, fn := range n { - if fn.path.IsEqual(p) { + if fn.col == c { nodes = append(nodes, fn) } } @@ -654,62 +661,144 @@ func operatorIsIndexCompatible(op expr.Operator) bool { return false } -func operatorCanUseIndex(op expr.Operator) (bool, object.Path, expr.Expr) { - lf, leftIsPath := op.LeftHand().(expr.Path) - rf, rightIsPath := op.RightHand().(expr.Path) - - // Special case for IN operator: only left operand is valid for index usage - // valid: a IN [1, 2, 3] - // invalid: 1 IN a - // invalid: a IN (b + 1, 2) - if op.Token() == scanner.IN { - if leftIsPath && !rightIsPath && !exprContainsPath(op.RightHand()) { - rh := op.RightHand() - // The IN operator can use indexes only if the right hand side is an expression list. 
- if _, ok := rh.(expr.LiteralExprList); !ok { - return false, nil, nil - } - return true, object.Path(lf), rh +func (i *indexSelector) operatorCanUseIndex(op expr.Operator) (bool, string, expr.Expr, error) { + switch op.Token() { + case scanner.IN: + return i.inOperatorCanUseIndex(op) + case scanner.BETWEEN: + return i.betweenOperatorCanUseIndex(op) + } + + lh := op.LeftHand() + rh := op.RightHand() + lc, leftIsCol := lh.(expr.Column) + rc, rightIsCol := rh.(expr.Column) + + var cc *database.ColumnConstraint + if leftIsCol { + cc = i.info.ColumnConstraints.GetColumnConstraint(string(lc)) + } else if rightIsCol { + cc = i.info.ColumnConstraints.GetColumnConstraint(string(rc)) + } + if cc == nil { + return false, "", nil, nil + } + + // column OP literal + if leftIsCol { + ok, v, err := exprIsCompatibleLiteral(rh, cc.Type) + if !ok || err != nil { + return false, "", nil, err } - return false, nil, nil + return true, string(lc), v, nil } - // Special case for BETWEEN operator: Given this expression (x BETWEEN a AND b), - // we can only use the index if the "x" is a path and "a" and "b" don't contain path expressions. - if op.Token() == scanner.BETWEEN { - bt := op.(*expr.BetweenOperator) - x, xIsPath := bt.X.(expr.Path) - if !xIsPath || exprContainsPath(bt.LeftHand()) || exprContainsPath(bt.RightHand()) { - return false, nil, nil + // literal OP column + if rightIsCol { + ok, v, err := exprIsCompatibleLiteral(lh, cc.Type) + if !ok || err != nil { + return false, "", nil, err } - return true, object.Path(x), expr.LiteralExprList{bt.LeftHand(), bt.RightHand()} + return true, string(rc), v, nil } - // path OP expr - if leftIsPath && !rightIsPath && !exprContainsPath(op.RightHand()) { - return true, object.Path(lf), op.RightHand() + return false, "", nil, nil +} + +// Special case for IN operator: only left operand is valid for index usage +// valid: a IN (1, 2, 3) +// invalid: 1 IN a +// invalid: a IN (b + 1, 2) +func (i *indexSelector) inOperatorCanUseIndex(op expr.Operator) (bool, string, expr.Expr, error) { + rh := op.RightHand() + _, rightIsCol := rh.(expr.Column) + if rightIsCol { + return false, "", nil, nil } - // expr OP path - if rightIsPath && !leftIsPath && !exprContainsPath(op.LeftHand()) { - return true, object.Path(rf), op.LeftHand() + lh := op.LeftHand() + lc, leftIsCol := lh.(expr.Column) + + if !leftIsCol { + return false, "", nil, nil } - return false, nil, nil -} + // The IN operator can use indexes only if: + // - the right hand side is an expression list + // - each element of the list is a literal value + // - each value has the same type as the column + rlist, ok := rh.(expr.LiteralExprList) + if !ok { + return false, "", nil, nil + } -func exprContainsPath(e expr.Expr) bool { - var hasPath bool + cc := i.info.ColumnConstraints.GetColumnConstraint(string(lc)) + if cc == nil { + return false, "", nil, nil + } - expr.Walk(e, func(e expr.Expr) bool { - if _, ok := e.(expr.Path); ok { - hasPath = true - return false + // Ensure that each element of the list is a literal value + // and that each value has the same type as the column + for i, e := range rlist { + ok, v, err := exprIsCompatibleLiteral(e, cc.Type) + if !ok || err != nil { + return false, "", nil, err } - return true - }) - return hasPath + rlist[i] = v + } + + return true, string(lc), rlist, nil +} + +// Special case for BETWEEN operator: Given this expression (x BETWEEN a AND b), +// we can only use the index if the "x" is a column and "a" and "b" are literal values. 
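+// It returns the referenced column together with a literal list holding the lower and upper bounds, both cast to the column's type, which the caller can use to build an index range.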
+func (i *indexSelector) betweenOperatorCanUseIndex(op expr.Operator) (bool, string, expr.Expr, error) { + lh := op.LeftHand() + rh := op.RightHand() + + bt := op.(*expr.BetweenOperator) + x, xIsCol := bt.X.(expr.Column) + if !xIsCol { + return false, "", nil, nil + } + + cc := i.info.ColumnConstraints.GetColumnConstraint(string(x)) + if cc == nil { + return false, "", nil, nil + } + + lok, lv, err := exprIsCompatibleLiteral(lh, cc.Type) + if err != nil { + return false, "", nil, err + } + rok, rv, err := exprIsCompatibleLiteral(rh, cc.Type) + if err != nil { + return false, "", nil, err + } + if !xIsCol || !lok || !rok { + return false, "", nil, nil + } + + return true, string(x), expr.LiteralExprList{lv, rv}, nil +} + +func exprIsCompatibleLiteral(e expr.Expr, tp types.Type) (bool, expr.LiteralValue, error) { + l, ok := e.(expr.LiteralValue) + if !ok { + return false, expr.LiteralValue{}, nil + } + + if !l.Value.Type().Def().IsIndexComparableWith(tp) { + return false, expr.LiteralValue{}, nil + } + + v, err := l.Value.CastAs(tp) + if err != nil { + return false, expr.LiteralValue{}, err + } + + return true, expr.LiteralValue{Value: v}, nil } diff --git a/internal/planner/optimizer.go b/internal/planner/optimizer.go index 7bca88c8a..4a0f6c677 100644 --- a/internal/planner/optimizer.go +++ b/internal/planner/optimizer.go @@ -4,12 +4,13 @@ import ( "github.com/chaisql/chai/internal/database" "github.com/chaisql/chai/internal/environment" "github.com/chaisql/chai/internal/expr" - "github.com/chaisql/chai/internal/object" "github.com/chaisql/chai/internal/sql/scanner" "github.com/chaisql/chai/internal/stream" "github.com/chaisql/chai/internal/stream/path" "github.com/chaisql/chai/internal/stream/rows" + "github.com/chaisql/chai/internal/stream/table" "github.com/chaisql/chai/internal/types" + "github.com/cockroachdb/errors" ) var optimizerRules = []func(sctx *StreamContext) error{ @@ -25,11 +26,11 @@ var optimizerRules = []func(sctx *StreamContext) error{ // and returns an optimized tree. // Depending on the rule, the tree may be modified in place or // replaced by a new one. -func Optimize(s *stream.Stream, catalog *database.Catalog) (*stream.Stream, error) { +func Optimize(s *stream.Stream, catalog *database.Catalog, params []environment.Param) (*stream.Stream, error) { if firstNode, ok := s.First().(*stream.ConcatOperator); ok { // If the first operation is a concat, optimize all streams individually. for i, st := range firstNode.Streams { - ss, err := Optimize(st, catalog) + ss, err := Optimize(st, catalog, params) if err != nil { return nil, err } @@ -42,7 +43,7 @@ func Optimize(s *stream.Stream, catalog *database.Catalog) (*stream.Stream, erro if firstNode, ok := s.First().(*stream.UnionOperator); ok { // If the first operation is a union, optimize all streams individually. 
for i, st := range firstNode.Streams { - ss, err := Optimize(st, catalog) + ss, err := Optimize(st, catalog, params) if err != nil { return nil, err } @@ -52,20 +53,23 @@ func Optimize(s *stream.Stream, catalog *database.Catalog) (*stream.Stream, erro return s, nil } - return optimize(s, catalog) + return optimize(s, catalog, params) } type StreamContext struct { Catalog *database.Catalog + TableInfo *database.TableInfo + Params []environment.Param Stream *stream.Stream Filters []*rows.FilterOperator Projections []*rows.ProjectOperator TempTreeSorts []*rows.TempTreeSortOperator } -func NewStreamContext(s *stream.Stream) *StreamContext { +func NewStreamContext(s *stream.Stream, catalog *database.Catalog) *StreamContext { sctx := StreamContext{ - Stream: s, + Stream: s, + Catalog: catalog, } n := s.First() @@ -74,6 +78,14 @@ func NewStreamContext(s *stream.Stream) *StreamContext { for n != nil { switch t := n.(type) { + case *table.ScanOperator: + if catalog != nil { + ti, err := sctx.Catalog.GetTableInfo(t.TableName) + if err != nil { + panic(err) + } + sctx.TableInfo = ti + } case *rows.FilterOperator: if prevIsFilter || len(sctx.Filters) == 0 { sctx.Filters = append(sctx.Filters, t) @@ -130,9 +142,9 @@ func (sctx *StreamContext) removeProjectionNode(index int) { sctx.Projections = append(sctx.Projections[:index], sctx.Projections[index+1:]...) } -func optimize(s *stream.Stream, catalog *database.Catalog) (*stream.Stream, error) { - sctx := NewStreamContext(s) - sctx.Catalog = catalog +func optimize(s *stream.Stream, catalog *database.Catalog, params []environment.Param) (*stream.Stream, error) { + sctx := NewStreamContext(s, catalog) + sctx.Params = params for _, rule := range optimizerRules { err := rule(sctx) @@ -221,24 +233,25 @@ func PrecalculateExprRule(sctx *StreamContext) error { for n != nil { switch t := n.(type) { case *rows.FilterOperator: - t.Expr, err = precalculateExpr(t.Expr) + t.Expr, err = precalculateExpr(sctx, t.Expr) case *rows.ProjectOperator: for i := range t.Exprs { - t.Exprs[i], err = precalculateExpr(t.Exprs[i]) + t.Exprs[i], err = precalculateExpr(sctx, t.Exprs[i]) if err != nil { return err } } case *rows.TempTreeSortOperator: - t.Expr, err = precalculateExpr(t.Expr) + t.Expr, err = precalculateExpr(sctx, t.Expr) case *path.SetOperator: - t.Expr, err = precalculateExpr(t.Expr) + t.Expr, err = precalculateExpr(sctx, t.Expr) case *rows.EmitOperator: - for i := range t.Exprs { - t.Exprs[i], err = precalculateExpr(t.Exprs[i]) + for i := range t.Rows { + e, err := precalculateExpr(sctx, expr.LiteralExprList(t.Rows[i].Exprs)) if err != nil { return err } + t.Rows[i].Exprs = e.(expr.LiteralExprList) } } @@ -256,62 +269,24 @@ func PrecalculateExprRule(sctx *StreamContext) error { // expression nodes when possible. // it returns a new expression with simplified nodes. // if no simplification is possible it returns the same expression. -func precalculateExpr(e expr.Expr) (expr.Expr, error) { +func precalculateExpr(sctx *StreamContext, e expr.Expr) (expr.Expr, error) { switch t := e.(type) { case expr.LiteralExprList: // we assume that the list of expressions contains only literals // until proven wrong. 
- literalsOnly := true for i, te := range t { - newExpr, err := precalculateExpr(te) + newExpr, err := precalculateExpr(sctx, te) if err != nil { return nil, err } - if _, ok := newExpr.(expr.LiteralValue); !ok { - literalsOnly = false - } t[i] = newExpr } - - // if literalsOnly is still true, it means we have a list or expressions - // that only contain constant values (ex: [1, true]). - // We can transform that into a types.Array. - if literalsOnly { - var vb object.ValueBuffer - for i := range t { - vb.Append(t[i].(expr.LiteralValue).Value) - } - - return expr.LiteralValue{Value: types.NewArrayValue(&vb)}, nil - } - case *expr.KVPairs: - // we assume that the list of kvpairs contains only literals - // until proven wrong. - literalsOnly := true - - var err error - for i, kv := range t.Pairs { - kv.V, err = precalculateExpr(kv.V) - if err != nil { - return nil, err - } - if _, ok := kv.V.(expr.LiteralValue); !ok { - literalsOnly = false - } - t.Pairs[i] = kv - } - - // if literalsOnly is still true, it means we have a list of kvpairs - // that only contain constant values (ex: {"a": 1, "b": true}. - // We can transform that into a types.Object. - if literalsOnly { - var fb object.FieldBuffer - for i := range t.Pairs { - fb.Add(t.Pairs[i].K, types.Value(t.Pairs[i].V.(expr.LiteralValue).Value)) - } - - return expr.LiteralValue{Value: types.NewObjectValue(&fb)}, nil + case expr.PositionalParam, expr.NamedParam: + v, err := t.Eval(&environment.Environment{Params: sctx.Params}) + if err != nil { + return nil, err } + return expr.LiteralValue{Value: v}, nil case expr.Operator: // since expr.Operator is an interface, // this optimization must only be applied to @@ -324,11 +299,11 @@ func precalculateExpr(e expr.Expr) (expr.Expr, error) { return e, nil } - lh, err := precalculateExpr(t.LeftHand()) + lh, err := precalculateExpr(sctx, t.LeftHand()) if err != nil { return nil, err } - rh, err := precalculateExpr(t.RightHand()) + rh, err := precalculateExpr(sctx, t.RightHand()) if err != nil { return nil, err } @@ -336,7 +311,7 @@ func precalculateExpr(e expr.Expr) (expr.Expr, error) { t.SetRightHandExpr(rh) if b, ok := t.(*expr.BetweenOperator); ok { - b.X, err = precalculateExpr(b.X) + b.X, err = precalculateExpr(sctx, b.X) if err != nil { return nil, err } @@ -346,23 +321,139 @@ func precalculateExpr(e expr.Expr) (expr.Expr, error) { } } - _, leftIsLit := lh.(expr.LiteralValue) - _, rightIsLit := rh.(expr.LiteralValue) + lv, leftIsLit := lh.(expr.LiteralValue) + rv, rightIsLit := rh.(expr.LiteralValue) // if both operands are literals, we can precalculate them now if leftIsLit && rightIsLit { v, err := t.Eval(&environment.Environment{}) - // any error encountered here is unexpected if err != nil { - panic(err) + return nil, err } // we replace this expression with the result of its evaluation return expr.LiteralValue{Value: v}, nil } + + // if one operand is a column and the other is a literal + // we can check if the types are compatible + lc, leftIsCol := lh.(expr.Column) + rc, rightIsCol := rh.(expr.Column) + + if leftIsCol && rightIsLit { + tp := sctx.TableInfo.ColumnConstraints.GetColumnConstraint(string(lc)).Type + if !tp.Def().IsComparableWith(rv.Value.Type()) { + return nil, errors.Errorf("invalid input syntax for type %s: %s", tp, rh) + } + + if tp.Def().IsIndexComparableWith(rv.Value.Type()) { + v, err := rv.Value.CastAs(tp) + if err != nil { + return nil, errors.Errorf("invalid input syntax for type %s: %s", tp, rh) + } + t.SetRightHandExpr(expr.LiteralValue{Value: v}) + } + } + + if 
leftIsLit && rightIsCol { + tp := sctx.TableInfo.ColumnConstraints.GetColumnConstraint(string(rc)).Type + if !tp.Def().IsComparableWith(lv.Value.Type()) { + return nil, errors.Errorf("invalid input syntax for type %s: %s", tp, lh) + } + + if tp.Def().IsIndexComparableWith(lv.Value.Type()) { + v, err := lv.Value.CastAs(tp) + if err != nil { + return nil, errors.Errorf("invalid input syntax for type %s: %s", tp, lh) + } + t.SetLeftHandExpr(expr.LiteralValue{Value: v}) + } + } + + return t, nil } return e, nil } +func CheckExprTypeRule(sctx *StreamContext) error { + n := sctx.Stream.Op + var err error + + for n != nil { + switch t := n.(type) { + case *rows.FilterOperator: + err = checkExprType(sctx, t.Expr) + case *rows.ProjectOperator: + for i := range t.Exprs { + err = checkExprType(sctx, t.Exprs[i]) + if err != nil { + return err + } + } + case *rows.TempTreeSortOperator: + err = checkExprType(sctx, t.Expr) + case *path.SetOperator: + err = checkExprType(sctx, t.Expr) + case *rows.EmitOperator: + for i := range t.Rows { + err := checkExprType(sctx, expr.LiteralExprList(t.Rows[i].Exprs)) + if err != nil { + return err + } + } + } + + if err != nil { + return err + } + + n = n.GetPrev() + } + + return err +} + +func checkExprType(sctx *StreamContext, e expr.Expr) (err error) { + op, ok := e.(expr.Operator) + if !ok { + return nil + } + + lh := op.LeftHand() + rh := op.RightHand() + + lc, leftIsCol := lh.(expr.Column) + rc, rightIsCol := rh.(expr.Column) + + lv, leftIsLit := lh.(expr.LiteralValue) + rv, rightIsLit := rh.(expr.LiteralValue) + + if leftIsCol && rightIsCol { + return nil + } + + if leftIsCol && rightIsLit { + tp := sctx.TableInfo.ColumnConstraints.GetColumnConstraint(string(lc)).Type + _, err := rv.Value.CastAs(tp) + if err != nil { + return errors.Errorf("invalid input syntax for type %s: %s", tp, rh) + } + + return nil + } + + if leftIsLit && rightIsCol { + tp := sctx.TableInfo.ColumnConstraints.GetColumnConstraint(string(rc)).Type + _, err := lv.Value.CastAs(tp) + if err != nil { + return errors.Errorf("invalid input syntax for type %s: %s", tp, lh) + } + + return nil + } + + return nil +} + // RemoveUnnecessaryFilterNodesRule removes any filter node whose // condition is a constant expression that evaluates to a truthy value. 
// if it evaluates to a falsy value, it considers that the tree @@ -386,21 +477,6 @@ func RemoveUnnecessaryFilterNodesRule(sctx *StreamContext) error { // if the expr is truthy, we remove the node from the stream sctx.removeFilterNodeByIndex(i) - case *expr.InOperator: - // IN operator with empty array - // ex: WHERE a IN [] - lv, ok := t.RightHand().(expr.LiteralValue) - if ok && lv.Value.Type() == types.TypeArray { - l, err := object.ArrayLength(types.AsArray(lv.Value)) - if err != nil { - return err - } - // if the array is empty, we return an empty stream - if l == 0 { - sctx.Stream = new(stream.Stream) - return nil - } - } } } @@ -438,17 +514,17 @@ func RemoveUnnecessaryTempSortNodesRule(sctx *StreamContext) error { return nil } - lpath, ok := sctx.TempTreeSorts[0].Expr.(expr.Path) + lcol, ok := sctx.TempTreeSorts[0].Expr.(expr.Column) if !ok { return nil } - rpath, ok := sctx.TempTreeSorts[1].Expr.(expr.Path) + rcol, ok := sctx.TempTreeSorts[1].Expr.(expr.Column) if !ok { return nil } - if !lpath.IsEqual(rpath) { + if lcol != rcol { return nil } diff --git a/internal/planner/optimizer_test.go b/internal/planner/optimizer_test.go index ecb77505a..ef89b3bab 100644 --- a/internal/planner/optimizer_test.go +++ b/internal/planner/optimizer_test.go @@ -3,8 +3,8 @@ package planner_test import ( "testing" + "github.com/chaisql/chai/internal/environment" "github.com/chaisql/chai/internal/expr" - "github.com/chaisql/chai/internal/object" "github.com/chaisql/chai/internal/planner" "github.com/chaisql/chai/internal/sql/parser" "github.com/chaisql/chai/internal/stream" @@ -13,7 +13,6 @@ import ( "github.com/chaisql/chai/internal/stream/table" "github.com/chaisql/chai/internal/testutil" "github.com/chaisql/chai/internal/testutil/assert" - "github.com/chaisql/chai/internal/types" "github.com/stretchr/testify/require" ) @@ -81,7 +80,7 @@ func TestSplitANDConditionRule(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - sctx := planner.NewStreamContext(test.in) + sctx := planner.NewStreamContext(test.in, nil) err := planner.SplitANDConditionRule(sctx) assert.NoError(t, err) require.Equal(t, test.expected.String(), sctx.Stream.String()) @@ -111,66 +110,35 @@ func TestPrecalculateExprRule(t *testing.T) { }, { "constant sub-expr: a > 1 - 40 -> a > -39", - expr.Gt(expr.Path{object.PathFragment{FieldName: "a"}}, expr.Sub(testutil.IntegerValue(1), testutil.DoubleValue(40))), - expr.Gt(expr.Path{object.PathFragment{FieldName: "a"}}, testutil.DoubleValue(-39)), + expr.Gt(expr.Column("a"), expr.Sub(testutil.IntegerValue(1), testutil.DoubleValue(40))), + expr.Gt(expr.Column("a"), testutil.DoubleValue(-39)), }, { - "constant sub-expr: a IN [1, 2] -> a IN array([1, 2])", - expr.In(expr.Path{object.PathFragment{FieldName: "a"}}, expr.LiteralExprList{testutil.IntegerValue(1), testutil.IntegerValue(2)}), - expr.In(expr.Path{object.PathFragment{FieldName: "a"}}, expr.LiteralValue{Value: types.NewArrayValue(object.NewValueBuffer(). - Append(types.NewIntegerValue(1)). 
- Append(types.NewIntegerValue(2)))}), - }, - { - "non-constant expr list: [a, 1 - 40] -> [a, -39]", + "non-constant expr list: (a, 1 - 40) -> (a, -39)", expr.LiteralExprList{ - expr.Path{object.PathFragment{FieldName: "a"}}, + expr.Column("a"), expr.Sub(testutil.IntegerValue(1), testutil.DoubleValue(40)), }, expr.LiteralExprList{ - expr.Path{object.PathFragment{FieldName: "a"}}, + expr.Column("a"), testutil.DoubleValue(-39), }, }, - { - "constant expr list: [3, 1 - 40] -> array([3, -39])", - expr.LiteralExprList{ - testutil.IntegerValue(3), - expr.Sub(testutil.IntegerValue(1), testutil.DoubleValue(40)), - }, - expr.LiteralValue{Value: types.NewArrayValue(object.NewValueBuffer(). - Append(types.NewIntegerValue(3)). - Append(types.NewDoubleValue(-39)))}, - }, - { - `non-constant kvpair: {"a": d, "b": 1 - 40} -> {"a": 3, "b": -39}`, - &expr.KVPairs{Pairs: []expr.KVPair{ - {K: "a", V: expr.Path{object.PathFragment{FieldName: "d"}}}, - {K: "b", V: expr.Sub(testutil.IntegerValue(1), testutil.DoubleValue(40))}, - }}, - &expr.KVPairs{Pairs: []expr.KVPair{ - {K: "a", V: expr.Path{object.PathFragment{FieldName: "d"}}}, - {K: "b", V: testutil.DoubleValue(-39)}, - }}, - }, - { - `constant kvpair: {"a": 3, "b": 1 - 40} -> object({"a": 3, "b": -39})`, - &expr.KVPairs{Pairs: []expr.KVPair{ - {K: "a", V: testutil.IntegerValue(3)}, - {K: "b", V: expr.Sub(testutil.IntegerValue(1), testutil.DoubleValue(40))}, - }}, - expr.LiteralValue{Value: types.NewObjectValue(object.NewFieldBuffer(). - Add("a", types.NewIntegerValue(3)). - Add("b", types.NewDoubleValue(-39)), - )}, - }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { + db, tx, cleanup := testutil.NewTestTx(t) + defer cleanup() + + testutil.MustExec(t, db, tx, ` + CREATE TABLE foo (k INT PRIMARY KEY, a INT); + `) + s := stream.New(table.Scan("foo")). Pipe(rows.Filter(test.e)) - sctx := planner.NewStreamContext(s) + + sctx := planner.NewStreamContext(s, tx.Catalog) err := planner.PrecalculateExprRule(sctx) assert.NoError(t, err) require.Equal(t, stream.New(table.Scan("foo")).Pipe(rows.Filter(test.expected)).String(), sctx.Stream.String()) @@ -193,14 +161,6 @@ func TestRemoveUnnecessarySelectionNodesRule(t *testing.T) { stream.New(table.Scan("foo")).Pipe(rows.Filter(parser.MustParseExpr("10"))), stream.New(table.Scan("foo")), }, - { - "truthy constant expr with IN", - stream.New(table.Scan("foo")).Pipe(rows.Filter(expr.In( - expr.Path(object.NewPath("a")), - testutil.ArrayValue(object.NewValueBuffer()), - ))), - &stream.Stream{}, - }, { "falsy constant expr", stream.New(table.Scan("foo")).Pipe(rows.Filter(parser.MustParseExpr("0"))), @@ -210,7 +170,7 @@ func TestRemoveUnnecessarySelectionNodesRule(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - sctx := planner.NewStreamContext(test.root) + sctx := planner.NewStreamContext(test.root, nil) err := planner.RemoveUnnecessaryFilterNodesRule(sctx) assert.NoError(t, err) require.Equal(t, test.expected.String(), sctx.Stream.String()) @@ -272,31 +232,21 @@ func TestSelectIndex_Simple(t *testing.T) { Pipe(rows.Project(parser.MustParseExpr("a"))), }, { - "SELECT a FROM foo WHERE c = 'hello' AND b = 2", - stream.New(table.Scan("foo")). - Pipe(rows.Filter(parser.MustParseExpr("c = 'hello'"))). - Pipe(rows.Filter(parser.MustParseExpr("b = 2"))). - Pipe(rows.Project(parser.MustParseExpr("a"))), - stream.New(index.Scan("idx_foo_c", stream.Range{Min: exprList(testutil.TextValue("hello")), Exact: true})). - Pipe(rows.Filter(parser.MustParseExpr("b = 2"))). 
- Pipe(rows.Project(parser.MustParseExpr("a"))), - }, - { - "SELECT a FROM foo WHERE c = 'hello' AND d = 2", + "SELECT a FROM foo WHERE c = 3 AND d = 2", stream.New(table.Scan("foo")). - Pipe(rows.Filter(parser.MustParseExpr("c = 'hello'"))). + Pipe(rows.Filter(parser.MustParseExpr("c = 3"))). Pipe(rows.Filter(parser.MustParseExpr("d = 2"))). Pipe(rows.Project(parser.MustParseExpr("a"))), - stream.New(index.Scan("idx_foo_c", stream.Range{Min: exprList(testutil.TextValue("hello")), Exact: true})). + stream.New(index.Scan("idx_foo_c", stream.Range{Min: exprList(testutil.IntegerValue(3)), Exact: true})). Pipe(rows.Filter(parser.MustParseExpr("d = 2"))). Pipe(rows.Project(parser.MustParseExpr("a"))), }, { - "FROM foo WHERE a IN [1, 2]", + "FROM foo WHERE a IN (1, 2)", stream.New(table.Scan("foo")).Pipe(rows.Filter( expr.In( parser.MustParseExpr("a"), - testutil.ExprList(t, `[1, 2]`), + testutil.ExprList(t, `(1, 2)`), ), )), stream.New(index.Scan("idx_foo_a", stream.Range{Min: exprList(testutil.IntegerValue(1)), Exact: true}, stream.Range{Min: exprList(testutil.IntegerValue(2)), Exact: true})), @@ -340,18 +290,10 @@ func TestSelectIndex_Simple(t *testing.T) { stream.New(index.Scan("idx_foo_a", stream.Range{Min: exprList(testutil.IntegerValue(1)), Exact: true})). Pipe(rows.Filter(parser.MustParseExpr("k < 2"))), }, - { - "FROM foo WHERE a = 1 AND k = 'hello'", - stream.New(table.Scan("foo")). - Pipe(rows.Filter(parser.MustParseExpr("a = 1"))). - Pipe(rows.Filter(parser.MustParseExpr("k = 'hello'"))), - stream.New(table.Scan("foo", stream.Range{Min: exprList(testutil.TextValue("hello")), Exact: true})). - Pipe(rows.Filter(parser.MustParseExpr("a = 1"))), - }, { // c is an INT, 1.1 cannot be converted to int without precision loss, don't use the index "FROM foo WHERE c < 1.1", stream.New(table.Scan("foo")).Pipe(rows.Filter(parser.MustParseExpr("c < 1.1"))), - stream.New(index.Scan("idx_foo_c", stream.Range{Max: exprList(testutil.DoubleValue(1.1)), Exclusive: true})), + stream.New(table.Scan("foo")).Pipe(rows.Filter(parser.MustParseExpr("c < 1.1"))), }, // { // "FROM foo WHERE a = 1 OR b = 2", @@ -390,7 +332,7 @@ func TestSelectIndex_Simple(t *testing.T) { defer cleanup() testutil.MustExec(t, db, tx, ` - CREATE TABLE foo (k INT PRIMARY KEY, a INT, b INT, c INT, d ANY); + CREATE TABLE foo (k INT PRIMARY KEY, a INT, b INT, c INT, d INT); CREATE INDEX idx_foo_a ON foo(a); CREATE INDEX idx_foo_b ON foo(b); CREATE UNIQUE INDEX idx_foo_c ON foo(c); @@ -400,70 +342,14 @@ func TestSelectIndex_Simple(t *testing.T) { (3, 3, 3, 3, 3) `) - sctx := planner.NewStreamContext(test.root) + sctx := planner.NewStreamContext(test.root, tx.Catalog) sctx.Catalog = tx.Catalog - err := planner.SelectIndex(sctx) + st, err := planner.Optimize(test.root, tx.Catalog, nil) + // err := planner.SelectIndex(sctx) assert.NoError(t, err) - require.Equal(t, test.expected.String(), sctx.Stream.String()) + require.Equal(t, test.expected.String(), st.String()) }) } - - t.Run("array indexes", func(t *testing.T) { - tests := []struct { - name string - root, expected *stream.Stream - }{ - { - "non-indexed path", - stream.New(table.Scan("foo")).Pipe(rows.Filter(parser.MustParseExpr("b = [1, 1]"))), - stream.New(table.Scan("foo")).Pipe(rows.Filter(parser.MustParseExpr("b = [1, 1]"))), - }, - { - "FROM foo WHERE k = [1, 1]", - stream.New(table.Scan("foo")).Pipe(rows.Filter(parser.MustParseExpr("k = [1, 1]"))), - stream.New(table.Scan("foo", stream.Range{Min: exprList(testutil.ExprList(t, `[1, 1]`)), Exact: true})), - }, - { // constraint on 
k[0] INT should not modify the operand - "FROM foo WHERE k = [1.5, 1.5]", - stream.New(table.Scan("foo")).Pipe(rows.Filter(parser.MustParseExpr("k = [1.5, 1.5]"))), - stream.New(table.Scan("foo", stream.Range{Min: exprList(testutil.ExprList(t, `[1.5, 1.5]`)), Exact: true})), - }, - { - "FROM foo WHERE a = [1, 1]", - stream.New(table.Scan("foo")).Pipe(rows.Filter(parser.MustParseExpr("a = [1, 1]"))), - stream.New(index.Scan("idx_foo_a", stream.Range{Min: testutil.ExprList(t, `[[1, 1]]`), Exact: true})), - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - db, tx, cleanup := testutil.NewTestTx(t) - defer cleanup() - - testutil.MustExec(t, db, tx, ` - CREATE TABLE foo ( - k ARRAY PRIMARY KEY, - a ARRAY, - b ARRAY - ); - CREATE INDEX idx_foo_a ON foo(a); - INSERT INTO foo (k, a, b) VALUES - ([1, 1], [1, 1], [1, 1]), - ([2, 2], [2, 2], [2, 2]), - ([3, 3], [3, 3], [3, 3]) - `) - - sctx := planner.NewStreamContext(test.root) - sctx.Catalog = tx.Catalog - err := planner.PrecalculateExprRule(sctx) - assert.NoError(t, err) - - err = planner.SelectIndex(sctx) - assert.NoError(t, err) - require.Equal(t, test.expected.String(), sctx.Stream.String()) - }) - } - }) } func TestSelectIndex_Composite(t *testing.T) { @@ -476,42 +362,42 @@ func TestSelectIndex_Composite(t *testing.T) { stream.New(table.Scan("foo")). Pipe(rows.Filter(parser.MustParseExpr("a = 1"))). Pipe(rows.Filter(parser.MustParseExpr("d = 2"))), - stream.New(index.Scan("idx_foo_a_d", stream.Range{Min: testutil.ExprList(t, `[1, 2]`), Exact: true})), + stream.New(index.Scan("idx_foo_a_d", stream.Range{Min: testutil.ExprList(t, `(1, 2)`), Exact: true})), }, { "FROM foo WHERE a = 1 AND d > 2", stream.New(table.Scan("foo")). Pipe(rows.Filter(parser.MustParseExpr("a = 1"))). Pipe(rows.Filter(parser.MustParseExpr("d > 2"))), - stream.New(index.Scan("idx_foo_a_d", stream.Range{Min: testutil.ExprList(t, `[1, 2]`), Exclusive: true})), + stream.New(index.Scan("idx_foo_a_d", stream.Range{Min: testutil.ExprList(t, `(1, 2)`), Exclusive: true})), }, { "FROM foo WHERE a = 1 AND d < 2", stream.New(table.Scan("foo")). Pipe(rows.Filter(parser.MustParseExpr("a = 1"))). Pipe(rows.Filter(parser.MustParseExpr("d < 2"))), - stream.New(index.Scan("idx_foo_a_d", stream.Range{Max: testutil.ExprList(t, `[1, 2]`), Exclusive: true})), + stream.New(index.Scan("idx_foo_a_d", stream.Range{Max: testutil.ExprList(t, `(1, 2)`), Exclusive: true})), }, { "FROM foo WHERE a = 1 AND d <= 2", stream.New(table.Scan("foo")). Pipe(rows.Filter(parser.MustParseExpr("a = 1"))). Pipe(rows.Filter(parser.MustParseExpr("d <= 2"))), - stream.New(index.Scan("idx_foo_a_d", stream.Range{Max: testutil.ExprList(t, `[1, 2]`)})), + stream.New(index.Scan("idx_foo_a_d", stream.Range{Max: testutil.ExprList(t, `(1, 2)`)})), }, { "FROM foo WHERE a = 1 AND d >= 2", stream.New(table.Scan("foo")). Pipe(rows.Filter(parser.MustParseExpr("a = 1"))). Pipe(rows.Filter(parser.MustParseExpr("d >= 2"))), - stream.New(index.Scan("idx_foo_a_d", stream.Range{Min: testutil.ExprList(t, `[1, 2]`)})), + stream.New(index.Scan("idx_foo_a_d", stream.Range{Min: testutil.ExprList(t, `(1, 2)`)})), }, { "FROM foo WHERE a > 1 AND d > 2", stream.New(table.Scan("foo")). Pipe(rows.Filter(parser.MustParseExpr("a > 1"))). Pipe(rows.Filter(parser.MustParseExpr("d > 2"))), - stream.New(index.Scan("idx_foo_a", stream.Range{Min: testutil.ExprList(t, `[1]`), Exclusive: true})). + stream.New(index.Scan("idx_foo_a", stream.Range{Min: testutil.ExprList(t, `(1)`), Exclusive: true})). 
Pipe(rows.Filter(parser.MustParseExpr("d > 2"))), }, { @@ -519,8 +405,8 @@ func TestSelectIndex_Composite(t *testing.T) { stream.New(table.Scan("foo")). Pipe(rows.Filter(parser.MustParseExpr("a > ?"))). Pipe(rows.Filter(parser.MustParseExpr("d > ?"))), - stream.New(index.Scan("idx_foo_a", stream.Range{Min: testutil.ExprList(t, `[?]`), Exclusive: true})). - Pipe(rows.Filter(parser.MustParseExpr("d > ?"))), + stream.New(index.Scan("idx_foo_a", stream.Range{Min: testutil.ExprList(t, `(1)`), Exclusive: true})). + Pipe(rows.Filter(parser.MustParseExpr("d > 1"))), }, { "FROM foo WHERE a = 1 AND b = 2 AND c = 3", @@ -528,28 +414,28 @@ func TestSelectIndex_Composite(t *testing.T) { Pipe(rows.Filter(parser.MustParseExpr("a = 1"))). Pipe(rows.Filter(parser.MustParseExpr("b = 2"))). Pipe(rows.Filter(parser.MustParseExpr("c = 3"))), - stream.New(index.Scan("idx_foo_a_b_c", stream.Range{Min: testutil.ExprList(t, `[1, 2, 3]`), Exact: true})), + stream.New(index.Scan("idx_foo_a_b_c", stream.Range{Min: testutil.ExprList(t, `(1, 2, 3)`), Exact: true})), }, { "FROM foo WHERE a = 1 AND b = 2", // c is omitted, but it can still use idx_foo_a_b_c stream.New(table.Scan("foo")). Pipe(rows.Filter(parser.MustParseExpr("a = 1"))). Pipe(rows.Filter(parser.MustParseExpr("b = 2"))), - stream.New(index.Scan("idx_foo_a_b_c", stream.Range{Min: testutil.ExprList(t, `[1, 2]`), Exact: true})), + stream.New(index.Scan("idx_foo_a_b_c", stream.Range{Min: testutil.ExprList(t, `(1, 2)`), Exact: true})), }, { "FROM foo WHERE a = 1 AND b > 2", // c is omitted, but it can still use idx_foo_a_b_c, with > b stream.New(table.Scan("foo")). Pipe(rows.Filter(parser.MustParseExpr("a = 1"))). Pipe(rows.Filter(parser.MustParseExpr("b > 2"))), - stream.New(index.Scan("idx_foo_a_b_c", stream.Range{Min: testutil.ExprList(t, `[1, 2]`), Exclusive: true})), + stream.New(index.Scan("idx_foo_a_b_c", stream.Range{Min: testutil.ExprList(t, `(1, 2)`), Exclusive: true})), }, { "FROM foo WHERE a = 1 AND b < 2", // c is omitted, but it can still use idx_foo_a_b_c, with > b stream.New(table.Scan("foo")). Pipe(rows.Filter(parser.MustParseExpr("a = 1"))). Pipe(rows.Filter(parser.MustParseExpr("b < 2"))), - stream.New(index.Scan("idx_foo_a_b_c", stream.Range{Max: testutil.ExprList(t, `[1, 2]`), Exclusive: true})), + stream.New(index.Scan("idx_foo_a_b_c", stream.Range{Max: testutil.ExprList(t, `(1, 2)`), Exclusive: true})), }, { "FROM foo WHERE a = 1 AND b = 2 and k = 3", // c is omitted, but it can still use idx_foo_a_b_c @@ -557,7 +443,7 @@ func TestSelectIndex_Composite(t *testing.T) { Pipe(rows.Filter(parser.MustParseExpr("a = 1"))). Pipe(rows.Filter(parser.MustParseExpr("b = 2"))). Pipe(rows.Filter(parser.MustParseExpr("k = 3"))), - stream.New(index.Scan("idx_foo_a_b_c", stream.Range{Min: testutil.ExprList(t, `[1, 2]`), Exact: true})). + stream.New(index.Scan("idx_foo_a_b_c", stream.Range{Min: testutil.ExprList(t, `(1, 2)`), Exact: true})). Pipe(rows.Filter(parser.MustParseExpr("k = 3"))), }, // If a path is missing from the query, we can still the index, with paths after the missing one are @@ -589,98 +475,89 @@ func TestSelectIndex_Composite(t *testing.T) { Pipe(rows.Filter(parser.MustParseExpr("b = 1"))), }, { - "FROM foo WHERE a = 1 AND b = 2 AND c = 'a'", - stream.New(table.Scan("foo")). - Pipe(rows.Filter(parser.MustParseExpr("a = 1"))). - Pipe(rows.Filter(parser.MustParseExpr("b = 2"))). 
- Pipe(rows.Filter(parser.MustParseExpr("c = 'a'"))), - stream.New(index.Scan("idx_foo_a_b_c", stream.Range{Min: exprList(testutil.IntegerValue(1), testutil.IntegerValue(2), testutil.TextValue("a")), Exact: true})), - }, - - { - "FROM foo WHERE a IN [1, 2] AND d = 4", + "FROM foo WHERE a IN (1, 2) AND d = 4", stream.New(table.Scan("foo")). Pipe(rows.Filter( expr.In( parser.MustParseExpr("a"), - testutil.ExprList(t, `[1, 2]`), + testutil.ExprList(t, `(1, 2)`), ), )). Pipe(rows.Filter(parser.MustParseExpr("d = 4"))), stream.New(index.Scan("idx_foo_a_d", - stream.Range{Min: testutil.ExprList(t, `[1, 4]`), Exact: true}, - stream.Range{Min: testutil.ExprList(t, `[2, 4]`), Exact: true}, + stream.Range{Min: testutil.ExprList(t, `(1, 4)`), Exact: true}, + stream.Range{Min: testutil.ExprList(t, `(2, 4)`), Exact: true}, )), }, { - "FROM foo WHERE a IN [1, 2] AND b = 3 AND c = 4", + "FROM foo WHERE a IN (1, 2) AND b = 3 AND c = 4", stream.New(table.Scan("foo")). Pipe(rows.Filter( expr.In( parser.MustParseExpr("a"), - testutil.ExprList(t, `[1, 2]`), + testutil.ExprList(t, `(1, 2)`), ), )). Pipe(rows.Filter(parser.MustParseExpr("b = 3"))). Pipe(rows.Filter(parser.MustParseExpr("c = 4"))), stream.New(index.Scan("idx_foo_a_b_c", - stream.Range{Min: testutil.ExprList(t, `[1, 3, 4]`), Exact: true}, - stream.Range{Min: testutil.ExprList(t, `[2, 3, 4]`), Exact: true}, + stream.Range{Min: testutil.ExprList(t, `(1, 3, 4)`), Exact: true}, + stream.Range{Min: testutil.ExprList(t, `(2, 3, 4)`), Exact: true}, )), }, { - "FROM foo WHERE a IN [1, 2] AND b = 3 AND c > 4", + "FROM foo WHERE a IN (1, 2) AND b = 3 AND c > 4", stream.New(table.Scan("foo")). Pipe(rows.Filter( expr.In( parser.MustParseExpr("a"), - testutil.ExprList(t, `[1, 2]`), + testutil.ExprList(t, `(1, 2)`), ), )). Pipe(rows.Filter(parser.MustParseExpr("b = 3"))). Pipe(rows.Filter(parser.MustParseExpr("c > 4"))), stream.New(index.Scan("idx_foo_a_b_c", - stream.Range{Min: testutil.ExprList(t, `[1, 3]`), Exact: true}, - stream.Range{Min: testutil.ExprList(t, `[2, 3]`), Exact: true}, + stream.Range{Min: testutil.ExprList(t, `(1, 3)`), Exact: true}, + stream.Range{Min: testutil.ExprList(t, `(2, 3)`), Exact: true}, )).Pipe(rows.Filter(parser.MustParseExpr("c > 4"))), }, { - "FROM foo WHERE a IN [1, 2] AND b = 3 AND c < 4", + "FROM foo WHERE a IN (1, 2) AND b = 3 AND c < 4", stream.New(table.Scan("foo")). Pipe(rows.Filter( expr.In( parser.MustParseExpr("a"), - testutil.ExprList(t, `[1, 2]`), + testutil.ExprList(t, `(1, 2)`), ), )). Pipe(rows.Filter(parser.MustParseExpr("b = 3"))). Pipe(rows.Filter(parser.MustParseExpr("c < 4"))), stream.New(index.Scan("idx_foo_a_b_c", - stream.Range{Min: testutil.ExprList(t, `[1, 3]`), Exact: true}, - stream.Range{Min: testutil.ExprList(t, `[2, 3]`), Exact: true}, + stream.Range{Min: testutil.ExprList(t, `(1, 3)`), Exact: true}, + stream.Range{Min: testutil.ExprList(t, `(2, 3)`), Exact: true}, )).Pipe(rows.Filter(parser.MustParseExpr("c < 4"))), }, { - "FROM foo WHERE a IN [1, 2] AND b IN [3, 4] AND c > 5", + "FROM foo WHERE a IN (1, 2) AND b IN (3, 4) AND c > 5", stream.New(table.Scan("foo")). Pipe(rows.Filter( expr.In( parser.MustParseExpr("a"), - testutil.ExprList(t, `[1, 2]`), + testutil.ExprList(t, `(1, 2)`), ), )). Pipe(rows.Filter( expr.In( parser.MustParseExpr("b"), - testutil.ExprList(t, `[3, 4]`), + testutil.ExprList(t, `(3, 4)`), ), )). 
Pipe(rows.Filter(parser.MustParseExpr("c > 5"))), stream.New(index.Scan("idx_foo_a_b_c", - stream.Range{Min: testutil.ExprList(t, `[1, 3]`), Exact: true}, - stream.Range{Min: testutil.ExprList(t, `[1, 4]`), Exact: true}, - stream.Range{Min: testutil.ExprList(t, `[2, 3]`), Exact: true}, - stream.Range{Min: testutil.ExprList(t, `[2, 4]`), Exact: true}, + stream.Range{Min: testutil.ExprList(t, `(1, 3)`), Exact: true}, + stream.Range{Min: testutil.ExprList(t, `(1, 4)`), Exact: true}, + stream.Range{Min: testutil.ExprList(t, `(2, 3)`), Exact: true}, + stream.Range{Min: testutil.ExprList(t, `(2, 4)`), Exact: true}, )).Pipe(rows.Filter(parser.MustParseExpr("c > 5"))), }, { @@ -700,7 +577,7 @@ func TestSelectIndex_Composite(t *testing.T) { defer cleanup() testutil.MustExec(t, db, tx, ` - CREATE TABLE foo (k INT PRIMARY KEY, a ANY, b ANY, c INT, d ANY, x ANY, y ANY, z ANY); + CREATE TABLE foo (k INT PRIMARY KEY, a INT, b INT, c INT, d INT, x INT, y INT, z INT); CREATE INDEX idx_foo_a ON foo(a); CREATE INDEX idx_foo_b ON foo(b); CREATE UNIQUE INDEX idx_foo_c ON foo(c); @@ -713,68 +590,16 @@ func TestSelectIndex_Composite(t *testing.T) { (3, 3, 3, 3, 3) `) - sctx := planner.NewStreamContext(test.root) + sctx := planner.NewStreamContext(test.root, tx.Catalog) sctx.Catalog = tx.Catalog - err := planner.SelectIndex(sctx) + st, err := planner.Optimize(test.root, tx.Catalog, []environment.Param{ + {Value: 1}, + {Value: 2}, + }) assert.NoError(t, err) - require.Equal(t, test.expected.String(), sctx.Stream.String()) + require.Equal(t, test.expected.String(), st.String()) }) } - - t.Run("array indexes", func(t *testing.T) { - tests := []struct { - name string - root, expected *stream.Stream - }{ - { - "FROM foo WHERE a = [1, 1] AND b = [2, 2]", - stream.New(table.Scan("foo")). - Pipe(rows.Filter(parser.MustParseExpr("a = [1, 1]"))). - Pipe(rows.Filter(parser.MustParseExpr("b = [2, 2]"))), - stream.New(index.Scan("idx_foo_a_b", stream.Range{ - Min: testutil.ExprList(t, `[[1, 1], [2, 2]]`), - Exact: true})), - }, - { - "FROM foo WHERE a = [1, 1] AND b > [2, 2]", - stream.New(table.Scan("foo")). - Pipe(rows.Filter(parser.MustParseExpr("a = [1, 1]"))). 
- Pipe(rows.Filter(parser.MustParseExpr("b > [2, 2]"))), - stream.New(index.Scan("idx_foo_a_b", stream.Range{ - Min: testutil.ExprList(t, `[[1, 1], [2, 2]]`), - Exclusive: true})), - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - db, tx, cleanup := testutil.NewTestTx(t) - defer cleanup() - - testutil.MustExec(t, db, tx, ` - CREATE TABLE foo ( - k ARRAY PRIMARY KEY, - a ARRAY, - b ANY - ); - CREATE INDEX idx_foo_a_b ON foo(a, b); - INSERT INTO foo (k, a, b) VALUES - ([1, 1], [1, 1], [1, 1]), - ([2, 2], [2, 2], [2, 2]), - ([3, 3], [3, 3], [3, 3]) - `) - - sctx := planner.NewStreamContext(test.root) - sctx.Catalog = tx.Catalog - err := planner.PrecalculateExprRule(sctx) - assert.NoError(t, err) - - err = planner.SelectIndex(sctx) - assert.NoError(t, err) - require.Equal(t, test.expected.String(), sctx.Stream.String()) - }) - } - }) } func TestOptimize(t *testing.T) { @@ -783,20 +608,24 @@ func TestOptimize(t *testing.T) { db, tx, cleanup := testutil.NewTestTx(t) defer cleanup() testutil.MustExec(t, db, tx, ` - CREATE TABLE foo; - CREATE TABLE bar; + CREATE TABLE foo(a INT, b INT, c INT, d INT); + CREATE TABLE bar(a INT, b INT, c INT, d INT); `) got, err := planner.Optimize( stream.New(stream.Union( stream.New(stream.Concat( stream.New(table.Scan("foo")).Pipe(rows.Filter(parser.MustParseExpr("a = 1 + 2"))), - stream.New(table.Scan("bar")).Pipe(rows.Filter(parser.MustParseExpr("b = 1 + 2"))), + stream.New(table.Scan("bar")).Pipe(rows.Filter(parser.MustParseExpr("b = 1 + $1"))), )), stream.New(table.Scan("foo")).Pipe(rows.Filter(parser.MustParseExpr("c = 1 + 2"))), - stream.New(table.Scan("bar")).Pipe(rows.Filter(parser.MustParseExpr("d = 1 + 2"))), + stream.New(table.Scan("bar")).Pipe(rows.Filter(parser.MustParseExpr("d = 1 + $2"))), )), - tx.Catalog) + tx.Catalog, []environment.Param{ + {Name: "1", Value: 2}, + {Name: "2", Value: 3}, + }) + assert.NoError(t, err) want := stream.New(stream.Union( stream.New(stream.Concat( @@ -804,10 +633,9 @@ func TestOptimize(t *testing.T) { stream.New(table.Scan("bar")).Pipe(rows.Filter(parser.MustParseExpr("b = 3"))), )), stream.New(table.Scan("foo")).Pipe(rows.Filter(parser.MustParseExpr("c = 3"))), - stream.New(table.Scan("bar")).Pipe(rows.Filter(parser.MustParseExpr("d = 3"))), + stream.New(table.Scan("bar")).Pipe(rows.Filter(parser.MustParseExpr("d = 4"))), )) - assert.NoError(t, err) require.Equal(t, want.String(), got.String()) }) @@ -815,8 +643,8 @@ func TestOptimize(t *testing.T) { db, tx, cleanup := testutil.NewTestTx(t) defer cleanup() testutil.MustExec(t, db, tx, ` - CREATE TABLE foo; - CREATE TABLE bar; + CREATE TABLE foo(a INT, b INT, c INT, d INT); + CREATE TABLE bar(a INT, b INT, c INT, d INT); `) got, err := planner.Optimize( @@ -828,7 +656,7 @@ func TestOptimize(t *testing.T) { stream.New(table.Scan("foo")).Pipe(rows.Filter(parser.MustParseExpr("12"))), stream.New(table.Scan("bar")).Pipe(rows.Filter(parser.MustParseExpr("13"))), )), - tx.Catalog) + tx.Catalog, nil) want := stream.New(stream.Union( stream.New(stream.Concat( @@ -848,8 +676,8 @@ func TestOptimize(t *testing.T) { db, tx, cleanup := testutil.NewTestTx(t) defer cleanup() testutil.MustExec(t, db, tx, ` - CREATE TABLE foo(a any, d any); - CREATE TABLE bar(a, d); + CREATE TABLE foo(a INT, d INT); + CREATE TABLE bar(a INT, d INT); CREATE INDEX idx_foo_a_d ON foo(a, d); CREATE INDEX idx_bar_a_d ON bar(a, d); `) @@ -863,11 +691,11 @@ func TestOptimize(t *testing.T) { Pipe(rows.Filter(parser.MustParseExpr("a = 1"))). 
Pipe(rows.Filter(parser.MustParseExpr("d = 2"))), )), - tx.Catalog) + tx.Catalog, nil) want := stream.New(stream.Concat( - stream.New(index.Scan("idx_foo_a_d", stream.Range{Min: testutil.ExprList(t, `[1, 2]`), Exact: true})), - stream.New(index.Scan("idx_bar_a_d", stream.Range{Min: testutil.ExprList(t, `[1, 2]`), Exact: true})), + stream.New(index.Scan("idx_foo_a_d", stream.Range{Min: testutil.ExprList(t, `(1, 2)`), Exact: true})), + stream.New(index.Scan("idx_bar_a_d", stream.Range{Min: testutil.ExprList(t, `(1, 2)`), Exact: true})), )) assert.NoError(t, err) diff --git a/internal/query/statement/alter.go b/internal/query/statement/alter.go index 52e9173fe..68879fd77 100644 --- a/internal/query/statement/alter.go +++ b/internal/query/statement/alter.go @@ -43,7 +43,7 @@ func (stmt AlterTableRenameStmt) Run(ctx *Context) (Result, error) { type AlterTableAddColumnStmt struct { TableName string - FieldConstraint *database.FieldConstraint + ColumnConstraint *database.ColumnConstraint TableConstraints database.TableConstraints } @@ -58,7 +58,7 @@ func (stmt *AlterTableAddColumnStmt) IsReadOnly() bool { func (stmt *AlterTableAddColumnStmt) Run(ctx *Context) (Result, error) { var err error - // get the table before adding the field constraint + // get the table before adding the column constraint // and assign the table to the table.Scan operator // so that it can decode the records properly scan := table.Scan(stmt.TableName) @@ -70,11 +70,11 @@ func (stmt *AlterTableAddColumnStmt) Run(ctx *Context) (Result, error) { // get the current list of indexes indexNames := ctx.Tx.Catalog.ListIndexes(stmt.TableName) - // add the field constraint to the table - err = ctx.Tx.CatalogWriter().AddFieldConstraint( + // add the column constraint to the table + err = ctx.Tx.CatalogWriter().AddColumnConstraint( ctx.Tx, stmt.TableName, - stmt.FieldConstraint, + stmt.ColumnConstraint, stmt.TableConstraints) if err != nil { return Result{}, err @@ -86,11 +86,11 @@ func (stmt *AlterTableAddColumnStmt) Run(ctx *Context) (Result, error) { for _, tc := range stmt.TableConstraints { if tc.Unique { idx, err := ctx.Tx.CatalogWriter().CreateIndex(ctx.Tx, &database.IndexInfo{ - Paths: tc.Paths, - Unique: true, + Columns: tc.Columns, + Unique: true, Owner: database.Owner{ TableName: stmt.TableName, - Paths: tc.Paths, + Columns: tc.Columns, }, }) if err != nil { diff --git a/internal/query/statement/alter_test.go b/internal/query/statement/alter_test.go index 59b33260f..4be918724 100644 --- a/internal/query/statement/alter_test.go +++ b/internal/query/statement/alter_test.go @@ -14,11 +14,11 @@ func TestAlterTable(t *testing.T) { assert.NoError(t, err) defer db.Close() - err = db.Exec("CREATE TABLE foo") + err = db.Exec("CREATE TABLE foo(name TEXT, age INT)") assert.NoError(t, err) // Insert some data into foo - err = db.Exec(`INSERT INTO foo VALUES {name: "John Doe", age: 99}`) + err = db.Exec(`INSERT INTO foo VALUES ('John Doe', 99)`) assert.NoError(t, err) // Renaming the table to the same name should fail. @@ -31,7 +31,7 @@ func TestAlterTable(t *testing.T) { // Selecting from the old name should fail. 
err = db.Exec("SELECT * FROM foo") if !errs.IsNotFoundError(err) { - assert.ErrorIs(t, err, errs.NotFoundError{Name: "foo"}) + assert.ErrorIs(t, err, errs.NewNotFoundError("foo")) } r, err := db.QueryRow("SELECT * FROM bar") diff --git a/internal/query/statement/create.go b/internal/query/statement/create.go index a8357ae5d..9cd661a89 100644 --- a/internal/query/statement/create.go +++ b/internal/query/statement/create.go @@ -56,11 +56,11 @@ func (stmt *CreateTableStmt) Run(ctx *Context) (Result, error) { for _, tc := range stmt.Info.TableConstraints { if tc.Unique { _, err = ctx.Tx.CatalogWriter().CreateIndex(ctx.Tx, &database.IndexInfo{ - Paths: tc.Paths, - Unique: true, + Columns: tc.Columns, + Unique: true, Owner: database.Owner{ TableName: stmt.Info.TableName, - Paths: tc.Paths, + Columns: tc.Columns, }, KeySortOrder: tc.SortOrder, }) diff --git a/internal/query/statement/create_test.go b/internal/query/statement/create_test.go index 4e9d52a81..9a474ffb4 100644 --- a/internal/query/statement/create_test.go +++ b/internal/query/statement/create_test.go @@ -3,27 +3,10 @@ package statement_test import ( "testing" - "github.com/chaisql/chai/internal/object" - "github.com/chaisql/chai/internal/sql/parser" "github.com/chaisql/chai/internal/testutil" "github.com/chaisql/chai/internal/testutil/assert" ) -func ParseObjectPath(t testing.TB, str string) object.Path { - vp, err := parser.ParsePath(str) - assert.NoError(t, err) - return vp -} - -func ParseObjectPaths(t testing.TB, str ...string) []object.Path { - var paths []object.Path - for _, s := range str { - paths = append(paths, ParseObjectPath(t, s)) - } - - return paths -} - func TestCreateIndex(t *testing.T) { tests := []struct { name string @@ -31,8 +14,8 @@ func TestCreateIndex(t *testing.T) { fails bool }{ {"Basic", "CREATE INDEX idx ON test (foo)", false}, - {"If not exists", "CREATE INDEX IF NOT EXISTS idx ON test (foo.bar)", false}, - {"Duplicate", "CREATE INDEX idx ON test (foo.bar);CREATE INDEX idx ON test (foo.bar)", true}, + {"If not exists", "CREATE INDEX IF NOT EXISTS idx ON test (foo)", false}, + {"Duplicate", "CREATE INDEX idx ON test (foo);CREATE INDEX idx ON test (foo)", true}, {"Unique", "CREATE UNIQUE INDEX IF NOT EXISTS idx ON test (foo)", false}, {"No name", "CREATE UNIQUE INDEX ON test (foo)", false}, {"No name if not exists", "CREATE UNIQUE INDEX IF NOT EXISTS ON test (foo)", true}, @@ -46,7 +29,7 @@ func TestCreateIndex(t *testing.T) { db, tx, cleanup := testutil.NewTestTx(t) defer cleanup() - testutil.MustExec(t, db, tx, "CREATE TABLE test(foo (bar TEXT), baz any, baf any)") + testutil.MustExec(t, db, tx, "CREATE TABLE test(foo TEXT, baz INT, baf BOOL)") err := testutil.Exec(db, tx, test.query) if test.fails { diff --git a/internal/query/statement/delete.go b/internal/query/statement/delete.go index c836e9e27..ab6702bf2 100644 --- a/internal/query/statement/delete.go +++ b/internal/query/statement/delete.go @@ -16,7 +16,7 @@ type DeleteStmt struct { TableName string WhereExpr expr.Expr OffsetExpr expr.Expr - OrderBy expr.Path + OrderBy expr.Column LimitExpr expr.Expr OrderByDirection scanner.Token } @@ -36,10 +36,15 @@ func (stmt *DeleteStmt) Prepare(c *Context) (Statement, error) { s := stream.New(table.Scan(stmt.TableName)) if stmt.WhereExpr != nil { + err := ensureExprColumnsExist(c, stmt.TableName, stmt.WhereExpr) + if err != nil { + return nil, err + } + s = s.Pipe(rows.Filter(stmt.WhereExpr)) } - if stmt.OrderBy != nil { + if stmt.OrderBy != "" { if stmt.OrderByDirection == scanner.DESC { s = 
s.Pipe(rows.TempTreeSortReverse(stmt.OrderBy)) } else { @@ -48,10 +53,18 @@ func (stmt *DeleteStmt) Prepare(c *Context) (Statement, error) { } if stmt.OffsetExpr != nil { + err := ensureExprColumnsExist(c, stmt.TableName, stmt.OffsetExpr) + if err != nil { + return nil, err + } s = s.Pipe(rows.Skip(stmt.OffsetExpr)) } if stmt.LimitExpr != nil { + err := ensureExprColumnsExist(c, stmt.TableName, stmt.LimitExpr) + if err != nil { + return nil, err + } s = s.Pipe(rows.Take(stmt.LimitExpr)) } diff --git a/internal/query/statement/delete_test.go b/internal/query/statement/delete_test.go index 5ed3e5c22..d755aa550 100644 --- a/internal/query/statement/delete_test.go +++ b/internal/query/statement/delete_test.go @@ -18,12 +18,12 @@ func TestDeleteStmt(t *testing.T) { params []interface{} }{ {"No cond", `DELETE FROM test`, false, "[]", nil}, - {"With cond", "DELETE FROM test WHERE b = 'bar1'", false, `[{"d": "foo3", "b": "bar2", "e": "bar3", "n": 1}]`, nil}, - {"With offset", "DELETE FROM test OFFSET 1", false, `[{"a":"foo1", "b":"bar1", "c":"baz1", "n": 3}]`, nil}, - {"With order by then offset", "DELETE FROM test ORDER BY n OFFSET 1", false, `[{"d":"foo3", "b":"bar2", "e":"bar3", "n": 1}]`, nil}, - {"With order by DESC then offset", "DELETE FROM test ORDER BY n DESC OFFSET 1", false, `[{"a": "foo1", "b": "bar1", "c": "baz1", "n": 3}]`, nil}, - {"With limit", "DELETE FROM test ORDER BY n LIMIT 2", false, `[{"a":"foo1", "b":"bar1", "c":"baz1", "n": 3}]`, nil}, - {"With order by then limit then offset", "DELETE FROM test ORDER BY n LIMIT 1 OFFSET 1", false, `[{"a": "foo1", "b": "bar1", "c": "baz1", "n": 3}, {"d": "foo3", "b": "bar2", "e": "bar3", "n": 1}]`, nil}, + {"With cond", "DELETE FROM test WHERE b = 'bar1'", false, `[{"id": 3}]`, nil}, + {"With offset", "DELETE FROM test OFFSET 1", false, `[{"id":1}]`, nil}, + {"With order by then offset", "DELETE FROM test ORDER BY n OFFSET 1", false, `[{"id": 3}]`, nil}, + {"With order by DESC then offset", "DELETE FROM test ORDER BY n DESC OFFSET 1", false, `[{"id":1}]`, nil}, + {"With limit", "DELETE FROM test ORDER BY n LIMIT 2", false, `[{"id":1}]`, nil}, + {"With order by then limit then offset", "DELETE FROM test ORDER BY n LIMIT 1 OFFSET 1", false, `[{"id":1}, {"id": 3}]`, nil}, {"Table not found", "DELETE FROM foo WHERE b = 'bar1'", true, "[]", nil}, {"Read-only table", "DELETE FROM __chai_catalog", true, "[]", nil}, } @@ -34,13 +34,13 @@ func TestDeleteStmt(t *testing.T) { assert.NoError(t, err) defer db.Close() - err = db.Exec("CREATE TABLE test") + err = db.Exec("CREATE TABLE test(id INT PRIMARY KEY, a TEXT, b TEXT, c TEXT, d TEXT, e TEXT, n INT)") assert.NoError(t, err) - err = db.Exec("INSERT INTO test (a, b, c, n) VALUES ('foo1', 'bar1', 'baz1', 3)") + err = db.Exec("INSERT INTO test (id, a, b, c, n) VALUES (1, 'foo1', 'bar1', 'baz1', 3)") assert.NoError(t, err) - err = db.Exec("INSERT INTO test (a, b, n) VALUES ('foo2', 'bar1', 2)") + err = db.Exec("INSERT INTO test (id, a, b, n) VALUES (2, 'foo2', 'bar1', 2)") assert.NoError(t, err) - err = db.Exec("INSERT INTO test (d, b, e, n) VALUES ('foo3', 'bar2', 'bar3', 1)") + err = db.Exec("INSERT INTO test (id, d, b, e, n) VALUES (3, 'foo3', 'bar2', 'bar3', 1)") assert.NoError(t, err) err = db.Exec(test.query, test.params...) 
@@ -50,7 +50,7 @@ func TestDeleteStmt(t *testing.T) { } assert.NoError(t, err) - st, err := db.Query("SELECT * FROM test") + st, err := db.Query("SELECT id FROM test") assert.NoError(t, err) defer st.Close() diff --git a/internal/query/statement/drop_test.go b/internal/query/statement/drop_test.go index a8bdb50b1..55e359c25 100644 --- a/internal/query/statement/drop_test.go +++ b/internal/query/statement/drop_test.go @@ -16,7 +16,7 @@ func TestDropTable(t *testing.T) { assert.NoError(t, err) defer db.Close() - err = db.Exec("CREATE TABLE test1(a INT UNIQUE); CREATE TABLE test2; CREATE TABLE test3") + err = db.Exec("CREATE TABLE test1(a INT UNIQUE); CREATE TABLE test2(a INT); CREATE TABLE test3(a INT)") assert.NoError(t, err) err = db.Exec("DROP TABLE test1") diff --git a/internal/query/statement/explain.go b/internal/query/statement/explain.go index 78c86ba03..ac4bf4679 100644 --- a/internal/query/statement/explain.go +++ b/internal/query/statement/explain.go @@ -2,6 +2,7 @@ package statement import ( "github.com/chaisql/chai/internal/expr" + "github.com/chaisql/chai/internal/planner" "github.com/chaisql/chai/internal/stream" "github.com/chaisql/chai/internal/stream/rows" "github.com/chaisql/chai/internal/types" @@ -30,6 +31,12 @@ func (stmt *ExplainStmt) Run(ctx *Context) (Result, error) { return Result{}, errors.New("EXPLAIN only works on INSERT, SELECT, UPDATE AND DELETE statements") } + // Optimize the stream. + s.Stream, err = planner.Optimize(s.Stream, ctx.Tx.Catalog, ctx.Params) + if err != nil { + return Result{}, err + } + var plan string if s.Stream != nil { plan = s.Stream.String() diff --git a/internal/query/statement/explain_test.go b/internal/query/statement/explain_test.go index 7f3fcfc34..dd5063a74 100644 --- a/internal/query/statement/explain_test.go +++ b/internal/query/statement/explain_test.go @@ -23,10 +23,10 @@ func TestExplainStmt(t *testing.T) { {"EXPLAIN SELECT a + 1 FROM test WHERE c > 10", false, `"table.Scan(\"test\") | rows.Filter(c > 10) | rows.Project(a + 1)"`}, {"EXPLAIN SELECT a + 1 FROM test WHERE c > 10 AND d > 20", false, `"table.Scan(\"test\") | rows.Filter(c > 10) | rows.Filter(d > 20) | rows.Project(a + 1)"`}, {"EXPLAIN SELECT a + 1 FROM test WHERE c > 10 OR d > 20", false, `"table.Scan(\"test\") | rows.Filter(c > 10 OR d > 20) | rows.Project(a + 1)"`}, - {"EXPLAIN SELECT a + 1 FROM test WHERE c IN [1 + 1, 2 + 2]", false, `"table.Scan(\"test\") | rows.Filter(c IN [2, 4]) | rows.Project(a + 1)"`}, - {"EXPLAIN SELECT a + 1 FROM test WHERE a > 10", false, `"index.Scan(\"idx_a\", [{\"min\": [10], \"exclusive\": true}]) | rows.Project(a + 1)"`}, - {"EXPLAIN SELECT a + 1 FROM test WHERE x = 10 AND y > 5", false, `"index.Scan(\"idx_x_y\", [{\"min\": [10, 5], \"exclusive\": true}]) | rows.Project(a + 1)"`}, - {"EXPLAIN SELECT a + 1 FROM test WHERE a > 10 AND b > 20 AND c > 30", false, `"index.Scan(\"idx_b\", [{\"min\": [20], \"exclusive\": true}]) | rows.Filter(a > 10) | rows.Filter(c > 30) | rows.Project(a + 1)"`}, + {"EXPLAIN SELECT a + 1 FROM test WHERE c IN (1 + 1, 2 + 2)", false, `"table.Scan(\"test\") | rows.Filter(c IN (2, 4)) | rows.Project(a + 1)"`}, + {"EXPLAIN SELECT a + 1 FROM test WHERE a > 10", false, `"index.Scan(\"idx_a\", [{\"min\": (10), \"exclusive\": true}]) | rows.Project(a + 1)"`}, + {"EXPLAIN SELECT a + 1 FROM test WHERE x = 10 AND y > 5", false, `"index.Scan(\"idx_x_y\", [{\"min\": (10, 5), \"exclusive\": true}]) | rows.Project(a + 1)"`}, + {"EXPLAIN SELECT a + 1 FROM test WHERE a > 10 AND b > 20 AND c > 30", false, 
`"index.Scan(\"idx_b\", [{\"min\": (20), \"exclusive\": true}]) | rows.Filter(a > 10) | rows.Filter(c > 30) | rows.Project(a + 1)"`}, {"EXPLAIN SELECT a + 1 FROM test WHERE c > 30 ORDER BY d LIMIT 10 OFFSET 20", false, `"table.Scan(\"test\") | rows.Filter(c > 30) | rows.Project(a + 1) | rows.TempTreeSort(d) | rows.Skip(20) | rows.Take(10)"`}, {"EXPLAIN SELECT a + 1 FROM test WHERE c > 30 ORDER BY d DESC LIMIT 10 OFFSET 20", false, `"table.Scan(\"test\") | rows.Filter(c > 30) | rows.Project(a + 1) | rows.TempTreeSortReverse(d) | rows.Skip(20) | rows.Take(10)"`}, {"EXPLAIN SELECT a + 1 FROM test WHERE c > 30 ORDER BY a DESC LIMIT 10 OFFSET 20", false, `"index.ScanReverse(\"idx_a\") | rows.Filter(c > 30) | rows.Project(a + 1) | rows.Skip(20) | rows.Take(10)"`}, @@ -34,10 +34,10 @@ func TestExplainStmt(t *testing.T) { {"EXPLAIN SELECT a + 1 FROM test WHERE c > 30 GROUP BY a + 1 ORDER BY a DESC LIMIT 10 OFFSET 20", false, `"table.Scan(\"test\") | rows.Filter(c > 30) | rows.TempTreeSort(a + 1) | rows.GroupAggregate(a + 1) | rows.Project(a + 1) | rows.TempTreeSortReverse(a) | rows.Skip(20) | rows.Take(10)"`}, {"EXPLAIN UPDATE test SET a = 10", false, `"table.Scan(\"test\") | paths.Set(a, 10) | table.Validate(\"test\") | index.Delete(\"idx_a\") | index.Delete(\"idx_b\") | index.Delete(\"idx_x_y\") | table.Replace(\"test\") | index.Insert(\"idx_a\") | index.Validate(\"idx_b\") | index.Insert(\"idx_b\") | index.Insert(\"idx_x_y\") | discard()"`}, {"EXPLAIN UPDATE test SET a = 10 WHERE c > 10", false, `"table.Scan(\"test\") | rows.Filter(c > 10) | paths.Set(a, 10) | table.Validate(\"test\") | index.Delete(\"idx_a\") | index.Delete(\"idx_b\") | index.Delete(\"idx_x_y\") | table.Replace(\"test\") | index.Insert(\"idx_a\") | index.Validate(\"idx_b\") | index.Insert(\"idx_b\") | index.Insert(\"idx_x_y\") | discard()"`}, - {"EXPLAIN UPDATE test SET a = 10 WHERE a > 10", false, `"index.Scan(\"idx_a\", [{\"min\": [10], \"exclusive\": true}]) | paths.Set(a, 10) | table.Validate(\"test\") | index.Delete(\"idx_a\") | index.Delete(\"idx_b\") | index.Delete(\"idx_x_y\") | table.Replace(\"test\") | index.Insert(\"idx_a\") | index.Validate(\"idx_b\") | index.Insert(\"idx_b\") | index.Insert(\"idx_x_y\") | discard()"`}, + {"EXPLAIN UPDATE test SET a = 10 WHERE a > 10", false, `"index.Scan(\"idx_a\", [{\"min\": (10), \"exclusive\": true}]) | paths.Set(a, 10) | table.Validate(\"test\") | index.Delete(\"idx_a\") | index.Delete(\"idx_b\") | index.Delete(\"idx_x_y\") | table.Replace(\"test\") | index.Insert(\"idx_a\") | index.Validate(\"idx_b\") | index.Insert(\"idx_b\") | index.Insert(\"idx_x_y\") | discard()"`}, {"EXPLAIN DELETE FROM test", false, `"table.Scan(\"test\") | index.Delete(\"idx_a\") | index.Delete(\"idx_b\") | index.Delete(\"idx_x_y\") | table.Delete('test') | discard()"`}, {"EXPLAIN DELETE FROM test WHERE c > 10", false, `"table.Scan(\"test\") | rows.Filter(c > 10) | index.Delete(\"idx_a\") | index.Delete(\"idx_b\") | index.Delete(\"idx_x_y\") | table.Delete('test') | discard()"`}, - {"EXPLAIN DELETE FROM test WHERE a > 10", false, `"index.Scan(\"idx_a\", [{\"min\": [10], \"exclusive\": true}]) | index.Delete(\"idx_a\") | index.Delete(\"idx_b\") | index.Delete(\"idx_x_y\") | table.Delete('test') | discard()"`}, + {"EXPLAIN DELETE FROM test WHERE a > 10", false, `"index.Scan(\"idx_a\", [{\"min\": (10), \"exclusive\": true}]) | index.Delete(\"idx_a\") | index.Delete(\"idx_b\") | index.Delete(\"idx_x_y\") | table.Delete('test') | discard()"`}, } for _, test := range tests { @@ -46,7 +46,7 @@ func 
TestExplainStmt(t *testing.T) { assert.NoError(t, err) defer db.Close() - err = db.Exec("CREATE TABLE test (k INTEGER PRIMARY KEY, a any, b any, x any, y any)") + err = db.Exec("CREATE TABLE test (k INTEGER PRIMARY KEY, a INT, b INT, c INT, d INT, x INT, y INT)") assert.NoError(t, err) err = db.Exec(` CREATE INDEX idx_a ON test (a); diff --git a/internal/query/statement/insert.go b/internal/query/statement/insert.go index c46b5250b..515237df2 100644 --- a/internal/query/statement/insert.go +++ b/internal/query/statement/insert.go @@ -17,7 +17,7 @@ type InsertStmt struct { TableName string Values []expr.Expr - Fields []string + Columns []string SelectStmt Preparer Returning []expr.Expr OnConflict database.OnConflictAction @@ -43,35 +43,48 @@ func (stmt *InsertStmt) Prepare(c *Context) (Statement, error) { return nil, err } - // if no fields have been specified, we need to inject the fields from the defined table info - if len(stmt.Fields) == 0 { + var rowList []expr.Row + // if no columns have been specified, we need to inject the columns from the defined table info + if len(stmt.Columns) == 0 { + rowList = make([]expr.Row, 0, len(stmt.Values)) for i := range stmt.Values { - kvs, ok := stmt.Values[i].(*expr.KVPairs) + var r expr.Row + var ok bool + + r.Exprs, ok = stmt.Values[i].(expr.LiteralExprList) if !ok { continue } - for i := range kvs.Pairs { - if kvs.Pairs[i].K == "" { - if i >= len(ti.FieldConstraints.Ordered) { - return nil, errors.Errorf("too many values for %s", stmt.TableName) - } - - kvs.Pairs[i].K = ti.FieldConstraints.Ordered[i].Field - } + for i := range r.Exprs { + r.Columns = append(r.Columns, ti.ColumnConstraints.Ordered[i].Column) } + + rowList = append(rowList, r) } } else { - if !ti.FieldConstraints.AllowExtraFields { - for i := range stmt.Fields { - _, ok := ti.FieldConstraints.ByField[stmt.Fields[i]] - if !ok { - return nil, errors.Errorf("table has no field %s", stmt.Fields[i]) - } + rowList = make([]expr.Row, 0, len(stmt.Values)) + for i := range stmt.Columns { + _, ok := ti.ColumnConstraints.ByColumn[stmt.Columns[i]] + if !ok { + return nil, errors.Errorf("table has no column %s", stmt.Columns[i]) + } + } + + for i := range stmt.Values { + var r expr.Row + var ok bool + + r.Exprs, ok = stmt.Values[i].(expr.LiteralExprList) + if !ok { + continue } + + r.Columns = stmt.Columns + rowList = append(rowList, r) } } - s = stream.New(rows.Emit(stmt.Values...)) + s = stream.New(rows.Emit(rowList...)) } else { selectStream, err := stmt.SelectStmt.Prepare(c) if err != nil { @@ -86,8 +99,8 @@ func (stmt *InsertStmt) Prepare(c *Context) (Statement, error) { return nil, errors.New("cannot read and write to the same table") } - if len(stmt.Fields) > 0 { - s = s.Pipe(path.PathsRename(stmt.Fields...)) + if len(stmt.Columns) > 0 { + s = s.Pipe(path.PathsRename(stmt.Columns...)) } } diff --git a/internal/query/statement/insert_test.go b/internal/query/statement/insert_test.go index 9a3185ebf..e2f5208bc 100644 --- a/internal/query/statement/insert_test.go +++ b/internal/query/statement/insert_test.go @@ -19,11 +19,9 @@ func TestInsertStmt(t *testing.T) { expected string params []interface{} }{ - {"Values / Positional Params", "INSERT INTO test (a, b, c) VALUES (?, 'e', ?)", false, `[{"pk()":[1],"a":"d","b":"e","c":"f"}]`, []interface{}{"d", "f"}}, - {"Values / Named Params", "INSERT INTO test (a, b, c) VALUES ($d, 'e', $f)", false, `[{"pk()":[1],"a":"d","b":"e","c":"f"}]`, []interface{}{sql.Named("f", "f"), sql.Named("d", "d")}}, + {"Values / Positional Params", "INSERT INTO test 
(a, b, c) VALUES (?, 'e', ?)", false, `[{"a":"d","b":"e","c":"f"}]`, []interface{}{"d", "f"}}, + {"Values / Named Params", "INSERT INTO test (a, b, c) VALUES ($d, 'e', $f)", false, `[{"a":"d","b":"e","c":"f"}]`, []interface{}{sql.Named("f", "f"), sql.Named("d", "d")}}, {"Values / Invalid params", "INSERT INTO test (a, b, c) VALUES ('d', ?)", true, "", []interface{}{'e'}}, - {"Objects / Named Params", "INSERT INTO test VALUES {a: $a, b: 2.3, c: $c}", false, `[{"pk()":[1],"a":1,"b":2.3,"c":true}]`, []interface{}{sql.Named("c", true), sql.Named("a", 1)}}, - {"Objects / List ", "INSERT INTO test VALUES {a: [1, 2, 3]}", false, `[{"pk()":[1],"a":[1,2,3]}]`, nil}, {"Select / same table", "INSERT INTO test SELECT * FROM test", true, ``, nil}, } @@ -34,7 +32,7 @@ func TestInsertStmt(t *testing.T) { assert.NoError(t, err) defer db.Close() - err = db.Exec("CREATE TABLE test(a any, b any, c any)") + err = db.Exec("CREATE TABLE test(a TEXT, b TEXT, c TEXT)") assert.NoError(t, err) if withIndexes { err = db.Exec(` @@ -52,7 +50,7 @@ func TestInsertStmt(t *testing.T) { } assert.NoError(t, err) - st, err := db.Query("SELECT pk(), * FROM test") + st, err := db.Query("SELECT * FROM test") assert.NoError(t, err) defer st.Close() @@ -67,41 +65,17 @@ func TestInsertStmt(t *testing.T) { t.Run("With Index/"+test.name, testFn(true)) } - t.Run("with struct param", func(t *testing.T) { - db, err := chai.Open(":memory:") - assert.NoError(t, err) - defer db.Close() - - err = db.Exec("CREATE TABLE test") - assert.NoError(t, err) - - type foo struct { - A string - B string `chai:"b-b"` - } - - err = db.Exec("INSERT INTO test VALUES ?", &foo{A: "a", B: "b"}) - assert.NoError(t, err) - res, err := db.Query("SELECT * FROM test") - defer res.Close() - - assert.NoError(t, err) - buf, err := res.MarshalJSON() - assert.NoError(t, err) - require.JSONEq(t, `[{"a": "a", "b-b": "b"}]`, string(buf)) - }) - t.Run("with RETURNING", func(t *testing.T) { db, err := chai.Open(":memory:") assert.NoError(t, err) defer db.Close() - err = db.Exec(`CREATE TABLE test`) + err = db.Exec(`CREATE TABLE test(a INT)`) assert.NoError(t, err) - d, err := db.QueryRow(`insert into test (a) VALUES (1) RETURNING *, pk(), a AS A`) + d, err := db.QueryRow(`insert into test (a) VALUES (1) RETURNING *, a AS A`) assert.NoError(t, err) - testutil.RequireJSONEq(t, d, `{"a": 1, "pk()": [1], "A": 1}`) + testutil.RequireJSONEq(t, d, `{"a": 1, "A": 1}`) }) t.Run("ensure rollback", func(t *testing.T) { @@ -119,7 +93,7 @@ func TestInsertStmt(t *testing.T) { assert.NoError(t, err) defer res.Close() - testutil.RequireStreamEq(t, ``, res, false) + testutil.RequireStreamEq(t, ``, res) }) t.Run("with NEXT VALUE FOR", func(t *testing.T) { @@ -159,14 +133,14 @@ func TestInsertSelect(t *testing.T) { params []interface{} }{ {"Same table", `INSERT INTO foo SELECT * FROM foo`, true, ``, nil}, - {"No fields / No projection", `INSERT INTO foo SELECT * FROM bar`, false, `[{"pk()":[1], "a":1, "b":10}]`, nil}, - {"No fields / Projection", `INSERT INTO foo SELECT a FROM bar`, false, `[{"pk()":[1], "a":1}]`, nil}, - {"With fields / No Projection", `INSERT INTO foo (a, b) SELECT * FROM bar`, false, `[{"pk()":[1], "a":1, "b":10}]`, nil}, - {"With fields / Projection", `INSERT INTO foo (c, d) SELECT a, b FROM bar`, false, `[{"pk()":[1], "c":1, "d":10}]`, nil}, - {"Too many fields / No Projection", `INSERT INTO foo (c) SELECT * FROM bar`, true, ``, nil}, - {"Too many fields / Projection", `INSERT INTO foo (c, d) SELECT a, b, c FROM bar`, true, ``, nil}, - {"Too few fields / No 
Projection", `INSERT INTO foo (c, d, e) SELECT * FROM bar`, true, ``, nil}, - {"Too few fields / Projection", `INSERT INTO foo (c, d) SELECT a FROM bar`, true, ``, nil}, + {"No columns / No projection", `INSERT INTO foo SELECT * FROM bar`, false, `[{"a":1, "b":10, "c":null, "d":null, "e":null}]`, nil}, + {"No columns / Projection", `INSERT INTO foo SELECT a FROM bar`, false, `[{"a":1, "b":null, "c":null, "d":null, "e":null}]`, nil}, + {"With columns / No Projection", `INSERT INTO foo (a, b) SELECT * FROM bar`, true, ``, nil}, + {"With columns / Projection", `INSERT INTO foo (c, d) SELECT a, b FROM bar`, false, `[{"a":null, "b":null, "c":1, "d":10, "e":null}]`, nil}, + {"Too many columns / No Projection", `INSERT INTO foo (c) SELECT * FROM bar`, true, ``, nil}, + {"Too many columns / Projection", `INSERT INTO foo (c, d) SELECT a, b, c FROM bar`, true, ``, nil}, + {"Too few columns / No Projection", `INSERT INTO foo (c, d, e) SELECT * FROM bar`, true, ``, nil}, + {"Too few columns / Projection", `INSERT INTO foo (c, d) SELECT a FROM bar`, true, ``, nil}, } for _, test := range tests { @@ -176,8 +150,8 @@ func TestInsertSelect(t *testing.T) { defer db.Close() err = db.Exec(` - CREATE TABLE foo; - CREATE TABLE bar; + CREATE TABLE foo(a INT, b INT, c INT, d INT, e INT); + CREATE TABLE bar(a INT, b INT, c INT, d INT, e INT); INSERT INTO bar (a, b) VALUES (1, 10) `) assert.NoError(t, err) @@ -189,7 +163,7 @@ func TestInsertSelect(t *testing.T) { } assert.NoError(t, err) - st, err := db.Query("SELECT pk(), * FROM foo") + st, err := db.Query("SELECT * FROM foo") assert.NoError(t, err) defer st.Close() diff --git a/internal/query/statement/reindex_test.go b/internal/query/statement/reindex_test.go index 1d7eaa5fc..3a27738fe 100644 --- a/internal/query/statement/reindex_test.go +++ b/internal/query/statement/reindex_test.go @@ -29,8 +29,8 @@ func TestReIndex(t *testing.T) { defer cleanup() testutil.MustExec(t, db, tx, ` - CREATE TABLE test1(a any, b any); - CREATE TABLE test2(a any, b any); + CREATE TABLE test1(a TEXT, b TEXT); + CREATE TABLE test2(a TEXT, b TEXT); CREATE INDEX idx_test1_a ON test1(a); CREATE INDEX idx_test1_b ON test1(b); diff --git a/internal/query/statement/select.go b/internal/query/statement/select.go index ed3ed5e96..4f3315853 100644 --- a/internal/query/statement/select.go +++ b/internal/query/statement/select.go @@ -4,7 +4,6 @@ import ( "fmt" "github.com/chaisql/chai/internal/expr" - "github.com/chaisql/chai/internal/object" "github.com/chaisql/chai/internal/sql/scanner" "github.com/chaisql/chai/internal/stream" "github.com/chaisql/chai/internal/stream/rows" @@ -20,21 +19,35 @@ type SelectCoreStmt struct { ProjectionExprs []expr.Expr } -func (stmt *SelectCoreStmt) Prepare(*Context) (*StreamStmt, error) { +func (stmt *SelectCoreStmt) Prepare(ctx *Context) (*StreamStmt, error) { isReadOnly := true var s *stream.Stream if stmt.TableName != "" { + _, err := ctx.Tx.Catalog.GetTableInfo(stmt.TableName) + if err != nil { + return nil, err + } + s = s.Pipe(table.Scan(stmt.TableName)) } if stmt.WhereExpr != nil { + err := ensureExprColumnsExist(ctx, stmt.TableName, stmt.WhereExpr) + if err != nil { + return nil, err + } + s = s.Pipe(rows.Filter(stmt.WhereExpr)) } // when using GROUP BY, only aggregation functions or GroupByExpr can be selected if stmt.GroupByExpr != nil { + err := ensureExprColumnsExist(ctx, stmt.TableName, stmt.GroupByExpr) + if err != nil { + return nil, err + } var invalidProjectedField expr.Expr var aggregators []expr.AggregatorBuilder @@ -54,10 +67,10 @@ func 
(stmt *SelectCoreStmt) Prepare(*Context) (*StreamStmt, error) { // check if this is the same expression as the one used in the GROUP BY clause if expr.Equal(e, stmt.GroupByExpr) { - // if so, replace the expression with a path expression + // if so, replace the expression with a column expression stmt.ProjectionExprs[i] = &expr.NamedExpr{ ExprName: ne.ExprName, - Expr: expr.Path(object.NewPath(e.String())), + Expr: expr.Column(e.String()), } continue } @@ -79,16 +92,23 @@ func (stmt *SelectCoreStmt) Prepare(*Context) (*StreamStmt, error) { var aggregators []expr.AggregatorBuilder for _, pe := range stmt.ProjectionExprs { - ne, ok := pe.(*expr.NamedExpr) - if !ok { - continue - } - e := ne.Expr + expr.Walk(pe, func(e expr.Expr) bool { + // check if the projected expression contains an aggregation function + if agg, ok := e.(expr.AggregatorBuilder); ok { + aggregators = append(aggregators, agg) + return true + } - // check if the projected expression is an aggregation function - if agg, ok := e.(expr.AggregatorBuilder); ok { - aggregators = append(aggregators, agg) - } + if c, ok := e.(expr.Column); ok { + // check if the projected expression is a column + err := ensureExprColumnsExist(ctx, stmt.TableName, c) + if err != nil { + return false + } + } + + return true + }) } // add Aggregation node @@ -104,7 +124,7 @@ func (stmt *SelectCoreStmt) Prepare(*Context) (*StreamStmt, error) { for _, e := range stmt.ProjectionExprs { expr.Walk(e, func(e expr.Expr) bool { switch e.(type) { - case expr.Path, expr.Wildcard: + case expr.Column, expr.Wildcard: err = errors.New("no tables specified") return false default: @@ -148,7 +168,7 @@ type SelectStmt struct { CompoundSelect []*SelectCoreStmt CompoundOperators []scanner.Token - OrderBy expr.Path + OrderBy expr.Column OrderByDirection scanner.Token OffsetExpr expr.Expr LimitExpr expr.Expr @@ -211,7 +231,7 @@ func (stmt *SelectStmt) Prepare(ctx *Context) (Statement, error) { prev = tok } - if stmt.OrderBy != nil { + if stmt.OrderBy != "" { if stmt.OrderByDirection == scanner.DESC { s = s.Pipe(rows.TempTreeSortReverse(stmt.OrderBy)) } else { diff --git a/internal/query/statement/select_test.go b/internal/query/statement/select_test.go index 26f9d5b32..57ec3ea4c 100644 --- a/internal/query/statement/select_test.go +++ b/internal/query/statement/select_test.go @@ -20,69 +20,67 @@ func TestSelectStmt(t *testing.T) { expected string params []interface{} }{ - // {"No table, Add", "SELECT 1 + 1", false, `[{"1 + 1":2}]`, nil}, - // {"No table, Mult", "SELECT 2 * 3", false, `[{"2 * 3":6}]`, nil}, - // {"No table, Div", "SELECT 10 / 6", false, `[{"10 / 6":1}]`, nil}, - // {"No table, Mod", "SELECT 10 % 6", false, `[{"10 % 6":4}]`, nil}, - // {"No table, BitwiseAnd", "SELECT 10 & 6", false, `[{"10 & 6":2}]`, nil}, - // {"No table, BitwiseOr", "SELECT 10 | 6", false, `[{"10 | 6":14}]`, nil}, - // {"No table, BitwiseXor", "SELECT 10 ^ 6", false, `[{"10 ^ 6":12}]`, nil}, - // {"No table, function pk()", "SELECT pk()", false, `[{"pk()":[n]ull}]`, nil}, - // {"No table, field", "SELECT a", true, ``, nil}, - // {"No table, wildcard", "SELECT *", true, ``, nil}, - // {"No table, object", "SELECT {a: 1, b: 2 + 1}", false, `[{"{a: 1, b: 2 + 1}":{"a":1,"b":3}}]`, nil}, - // {"No cond", "SELECT * FROM test", false, `[{"k":1,"color":"red","size":10,"shape":"square"},{"k":2,"color":"blue","size":10,"weight":100},{"k":3,"height":100,"weight":200}]`, nil}, - // {"No cond Multiple wildcards", "SELECT *, *, color FROM test", false, 
`[{"k":1,"color":"red","size":10,"shape":"square","k":1,"color":"red","size":10,"shape":"square","color":"red"},{"k":2,"color":"blue","size":10,"weight":100,"k":2,"color":"blue","size":10,"weight":100,"color":"blue"},{"k":3,"height":100,"weight":200,"k":3,"height":100,"weight":200,"color":null}]`, nil}, - // {"With fields", "SELECT color, shape FROM test", false, `[{"color":"red","shape":"square"},{"color":"blue","shape":null},{"color":null,"shape":null}]`, nil}, - // {"No cond, wildcard and other field", "SELECT *, color FROM test", false, `[{"color": "red", "k": 1, "color": "red", "size": 10, "shape": "square"}, {"color": "blue", "k": 2, "color": "blue", "size": 10, "weight": 100}, {"color": null, "k": 3, "height": 100, "weight": 200}]`, nil}, - // {"With DISTINCT", "SELECT DISTINCT * FROM test", false, `[{"k":1,"color":"red","size":10,"shape":"square"},{"k":2,"color":"blue","size":10,"weight":100},{"k":3,"height":100,"weight":200}]`, nil}, - // {"With DISTINCT and expr", "SELECT DISTINCT 'a' FROM test", false, `[{"\"a\"":"a"}]`, nil}, - // {"With expr fields", "SELECT color, color != 'red' AS notred FROM test", false, `[{"color":"red","notred":false},{"color":"blue","notred":true},{"color":null,"notred":null}]`, nil}, - // {"With eq op", "SELECT * FROM test WHERE size = 10", false, `[{"k":1,"color":"red","size":10,"shape":"square"},{"k":2,"color":"blue","size":10,"weight":100}]`, nil}, - // {"With neq op", "SELECT * FROM test WHERE color != 'red'", false, `[{"k":2,"color":"blue","size":10,"weight":100}]`, nil}, - // {"With gt op", "SELECT * FROM test WHERE size > 10", false, `[]`, nil}, - // {"With gt bis", "SELECT * FROM test WHERE size > 9", false, `[{"k":1,"color":"red","size":10,"shape":"square"},{"k":2,"color":"blue","size":10,"weight":100}]`, nil}, - // {"With lt op", "SELECT * FROM test WHERE size < 15", false, `[{"k":1,"color":"red","size":10,"shape":"square"},{"k":2,"color":"blue","size":10,"weight":100}]`, nil}, - // {"With lte op", "SELECT * FROM test WHERE color <= 'salmon' ORDER BY k ASC", false, `[{"k":1,"color":"red","size":10,"shape":"square"},{"k":2,"color":"blue","size":10,"weight":100}]`, nil}, - // {"With add op", "SELECT size + 10 AS s FROM test ORDER BY k", false, `[{"s":20},{"s":20},{"s":null}]`, nil}, - // {"With sub op", "SELECT size - 10 AS s FROM test ORDER BY k", false, `[{"s":0},{"s":0},{"s":null}]`, nil}, - // {"With mul op", "SELECT size * 10 AS s FROM test ORDER BY k", false, `[{"s":100},{"s":100},{"s":null}]`, nil}, - // {"With div op", "SELECT size / 10 AS s FROM test ORDER BY k", false, `[{"s":1},{"s":1},{"s":null}]`, nil}, - // {"With IN op", "SELECT color FROM test WHERE color IN ['red', 'purple'] ORDER BY k", false, `[{"color":"red"}]`, nil}, - // {"With IN op on PK", "SELECT color FROM test WHERE k IN [1.1, 1.0] ORDER BY k", false, `[{"color":"red"}]`, nil}, - // {"With NOT IN op", "SELECT color FROM test WHERE color NOT IN ['red', 'purple'] ORDER BY k", false, `[{"color":"blue"}]`, nil}, - {"With field comparison", "SELECT * FROM test WHERE color < shape", false, `[{"k":1,"color":"red","size":10,"shape":"square"}]`, nil}, + {"No table, Add", "SELECT 1 + 1", false, `[{"1 + 1":2}]`, nil}, + {"No table, Mult", "SELECT 2 * 3", false, `[{"2 * 3":6}]`, nil}, + {"No table, Div", "SELECT 10 / 6", false, `[{"10 / 6":1}]`, nil}, + {"No table, Mod", "SELECT 10 % 6", false, `[{"10 % 6":4}]`, nil}, + {"No table, BitwiseAnd", "SELECT 10 & 6", false, `[{"10 & 6":2}]`, nil}, + {"No table, BitwiseOr", "SELECT 10 | 6", false, `[{"10 | 6":14}]`, nil}, + {"No table, 
BitwiseXor", "SELECT 10 ^ 6", false, `[{"10 ^ 6":12}]`, nil}, + {"No table, column", "SELECT a", true, ``, nil}, + {"No table, wildcard", "SELECT *", true, ``, nil}, + {"No cond", "SELECT * FROM test", false, `[{"k":1,"color":"red","size":10,"shape":"square","height":null,"weight":null},{"k":2,"color":"blue","size":10,"shape":null,"height":null,"weight":100},{"k":3,"color":null,"size":null,"shape":null,"height":100,"weight":200}]`, nil}, + {"No cond Multiple wildcards", "SELECT *, *, color FROM test", false, `[{"k":1,"color":"red","size":10,"shape":"square","height":null,"weight":null,"k":1,"color":"red","size":10,"shape":"square","height":null,"weight":null},{"k":2,"color":"blue","size":10,"shape":null,"height":null,"weight":100,"k":2,"color":"blue","size":10,"shape":null,"height":null,"weight":100},{"k":3,"color":null,"size":null,"shape":null,"height":100,"weight":200,"k":3,"color":null,"size":null,"shape":null,"height":100,"weight":200}]`, nil}, + {"With columns", "SELECT color, shape FROM test", false, `[{"color":"red","shape":"square"},{"color":"blue","shape":null},{"color":null,"shape":null}]`, nil}, + {"No cond, wildcard and other column", "SELECT *, color FROM test", false, `[{"k":1,"color":"red","size":10,"shape":"square","height":null,"weight":null,"color":"red"}, {"k":2,"color":"blue","size":10,"shape":null,"height":null,"weight":100,"color":"blue"}, {"k":3,"color":null,"size":null,"shape":null,"height":100,"weight":200,"color":null}]`, nil}, + {"With DISTINCT", "SELECT DISTINCT * FROM test", false, `[{"k":1,"color":"red","size":10,"shape":"square","height":null,"weight":null},{"k":2,"color":"blue","size":10,"shape":null,"height":null,"weight":100},{"k":3,"color":null,"size":null,"shape":null,"height":100,"weight":200}]`, nil}, + {"With DISTINCT and expr", "SELECT DISTINCT 'a' FROM test", false, `[{"\"a\"":"a"}]`, nil}, + {"With expr columns", "SELECT color, color != 'red' AS notred FROM test", false, `[{"color":"red","notred":false},{"color":"blue","notred":true},{"color":null,"notred":null}]`, nil}, + {"With eq op", "SELECT * FROM test WHERE size = 10", false, `[{"k":1,"color":"red","size":10,"shape":"square","height":null,"weight":null},{"k":2,"color":"blue","size":10,"shape":null,"height":null,"weight":100}]`, nil}, + {"With neq op", "SELECT * FROM test WHERE color != 'red'", false, `[{"k":2,"color":"blue","size":10,"shape":null,"height":null,"weight":100}]`, nil}, + {"With gt op", "SELECT * FROM test WHERE size > 10", false, `[]`, nil}, + {"With gt bis", "SELECT * FROM test WHERE size > 9", false, `[{"k":1,"color":"red","size":10,"shape":"square","height":null,"weight":null},{"k":2,"color":"blue","size":10,"shape":null,"height":null,"weight":100}]`, nil}, + {"With lt op", "SELECT * FROM test WHERE size < 15", false, `[{"k":1,"color":"red","size":10,"shape":"square","height":null,"weight":null},{"k":2,"color":"blue","size":10,"shape":null,"height":null,"weight":100}]`, nil}, + {"With lte op", "SELECT * FROM test WHERE color <= 'salmon' ORDER BY k ASC", false, `[{"k":1,"color":"red","size":10,"shape":"square","height":null,"weight":null},{"k":2,"color":"blue","size":10,"shape":null,"height":null,"weight":100}]`, nil}, + {"With add op", "SELECT size + 10 AS s FROM test ORDER BY k", false, `[{"s":20},{"s":20},{"s":null}]`, nil}, + {"With sub op", "SELECT size - 10 AS s FROM test ORDER BY k", false, `[{"s":0},{"s":0},{"s":null}]`, nil}, + {"With mul op", "SELECT size * 10 AS s FROM test ORDER BY k", false, `[{"s":100},{"s":100},{"s":null}]`, nil}, + {"With div op", "SELECT size 
/ 10 AS s FROM test ORDER BY k", false, `[{"s":1},{"s":1},{"s":null}]`, nil}, + {"With IN op", "SELECT color FROM test WHERE color IN ('red', 'purple') ORDER BY k", false, `[{"color":"red"}]`, nil}, + {"With IN op on PK", "SELECT color FROM test WHERE k IN (1.1, 1.0) ORDER BY k", false, `[{"color":"red"}]`, nil}, + {"With NOT IN op", "SELECT color FROM test WHERE color NOT IN ('red', 'purple') ORDER BY k", false, `[{"color":"blue"}]`, nil}, + {"With column comparison", "SELECT * FROM test WHERE color < shape", false, `[{"k":1,"color":"red","size":10,"shape":"square","height":null,"weight":null}]`, nil}, {"With group by", "SELECT color FROM test GROUP BY color", false, `[{"color":null},{"color":"blue"},{"color":"red"}]`, nil}, {"With group by expr", "SELECT weight / 2 as half FROM test GROUP BY weight / 2", false, `[{"half":null},{"half":50},{"half":100}]`, nil}, {"With group by and count", "SELECT COUNT(k) FROM test GROUP BY size", false, `[{"COUNT(k)":1},{"COUNT(k)":2}]`, nil}, {"With group by and count wildcard", "SELECT COUNT(* ) FROM test GROUP BY size", false, `[{"COUNT(*)":1},{"COUNT(*)":2}]`, nil}, - {"With order by", "SELECT * FROM test ORDER BY color", false, `[{"k":3,"height":100,"weight":200},{"k":2,"color":"blue","size":10,"weight":100},{"k":1,"color":"red","size":10,"shape":"square"}]`, nil}, + {"With order by", "SELECT * FROM test ORDER BY color", false, `[{"k":3,"color":null,"size":null,"shape":null,"height":100,"weight":200},{"k":2,"color":"blue","size":10,"shape":null,"height":null,"weight":100},{"k":1,"color":"red","size":10,"shape":"square","height":null,"weight":null}]`, nil}, {"With invalid group by / wildcard", "SELECT * FROM test WHERE age = 10 GROUP BY a.b.c", true, ``, nil}, {"With invalid group by / a.b", "SELECT a.b FROM test WHERE age = 10 GROUP BY a.b.c", true, ``, nil}, - {"With order by", "SELECT * FROM test ORDER BY color", false, `[{"k":3,"height":100,"weight":200},{"k":2,"color":"blue","size":10,"weight":100},{"k":1,"color":"red","size":10,"shape":"square"}]`, nil}, - {"With order by asc", "SELECT * FROM test ORDER BY color ASC", false, `[{"k":3,"height":100,"weight":200},{"k":2,"color":"blue","size":10,"weight":100},{"k":1,"color":"red","size":10,"shape":"square"}]`, nil}, - {"With order by asc numeric", "SELECT * FROM test ORDER BY weight ASC", false, `[{"k":1,"color":"red","size":10,"shape":"square"},{"k":2,"color":"blue","size":10,"weight":100},{"k":3,"height":100,"weight":200}]`, nil}, - {"With order by asc with limit 2", "SELECT * FROM test ORDER BY color LIMIT 2", false, `[{"k":3,"height":100,"weight":200},{"k":2,"color":"blue","size":10,"weight":100}]`, nil}, - {"With order by asc with limit 1", "SELECT * FROM test ORDER BY color LIMIT 1", false, `[{"k":3,"height":100,"weight":200}]`, nil}, - {"With order by asc with offset", "SELECT * FROM test ORDER BY color OFFSET 1", false, `[{"k":2,"color":"blue","size":10,"weight":100},{"k":1,"color":"red","size":10,"shape":"square"}]`, nil}, - {"With order by asc with limit offset", "SELECT * FROM test ORDER BY color LIMIT 1 OFFSET 1", false, `[{"k":2,"color":"blue","size":10,"weight":100}]`, nil}, - {"With order by desc", "SELECT * FROM test ORDER BY color DESC", false, `[{"k":1,"color":"red","size":10,"shape":"square"},{"k":2,"color":"blue","size":10,"weight":100},{"k":3,"height":100,"weight":200}]`, nil}, - {"With order by desc numeric", "SELECT * FROM test ORDER BY weight DESC", false, 
`[{"k":3,"height":100,"weight":200},{"k":2,"color":"blue","size":10,"weight":100},{"k":1,"color":"red","size":10,"shape":"square"}]`, nil}, - {"With order by desc with limit", "SELECT * FROM test ORDER BY color DESC LIMIT 2", false, `[{"k":1,"color":"red","size":10,"shape":"square"},{"k":2,"color":"blue","size":10,"weight":100}]`, nil}, - {"With order by desc with offset", "SELECT * FROM test ORDER BY color DESC OFFSET 1", false, `[{"k":2,"color":"blue","size":10,"weight":100},{"k":3,"height":100,"weight":200}]`, nil}, - {"With order by desc with limit offset", "SELECT * FROM test ORDER BY color DESC LIMIT 1 OFFSET 1", false, `[{"k":2,"color":"blue","size":10,"weight":100}]`, nil}, - {"With order by pk asc", "SELECT * FROM test ORDER BY k ASC", false, `[{"k":1,"color":"red","size":10,"shape":"square"},{"k":2,"color":"blue","size":10,"weight":100},{"k":3,"height":100,"weight":200}]`, nil}, - {"With order by pk desc", "SELECT * FROM test ORDER BY k DESC", false, `[{"k":3,"height":100,"weight":200},{"k":2,"color":"blue","size":10,"weight":100},{"k":1,"color":"red","size":10,"shape":"square"}]`, nil}, - {"With order by and where", "SELECT * FROM test WHERE color != 'blue' ORDER BY color DESC LIMIT 1", false, `[{"k":1,"color":"red","size":10,"shape":"square"}]`, nil}, - {"With limit", "SELECT * FROM test WHERE size = 10 LIMIT 1", false, `[{"k":1,"color":"red","size":10,"shape":"square"}]`, nil}, - {"With offset", "SELECT *, pk() FROM test WHERE size = 10 OFFSET 1", false, `[{"pk()":[2],"color":"blue","size":10,"weight":100,"k":2}]`, nil}, - {"With limit then offset", "SELECT * FROM test WHERE size = 10 LIMIT 1 OFFSET 1", false, `[{"k":2,"color":"blue","size":10,"weight":100,"k":2}]`, nil}, + {"With order by", "SELECT * FROM test ORDER BY color", false, `[{"k":3,"color":null,"size":null,"shape":null,"height":100,"weight":200},{"k":2,"color":"blue","size":10,"shape":null,"height":null,"weight":100},{"k":1,"color":"red","size":10,"shape":"square","height":null,"weight":null}]`, nil}, + {"With order by asc", "SELECT * FROM test ORDER BY color ASC", false, `[{"k":3,"color":null,"size":null,"shape":null,"height":100,"weight":200},{"k":2,"color":"blue","size":10,"shape":null,"height":null,"weight":100},{"k":1,"color":"red","size":10,"shape":"square","height":null,"weight":null}]`, nil}, + {"With order by asc numeric", "SELECT * FROM test ORDER BY weight ASC", false, `[{"k":1,"color":"red","size":10,"shape":"square","height":null,"weight":null},{"k":2,"color":"blue","size":10,"shape":null,"height":null,"weight":100},{"k":3,"color":null,"size":null,"shape":null,"height":100,"weight":200}]`, nil}, + {"With order by asc with limit 2", "SELECT * FROM test ORDER BY color LIMIT 2", false, `[{"k":3,"color":null,"size":null,"shape":null,"height":100,"weight":200},{"k":2,"color":"blue","size":10,"shape":null,"height":null,"weight":100}]`, nil}, + {"With order by asc with limit 1", "SELECT * FROM test ORDER BY color LIMIT 1", false, `[{"k":3,"color":null,"size":null,"shape":null,"height":100,"weight":200}]`, nil}, + {"With order by asc with offset", "SELECT * FROM test ORDER BY color OFFSET 1", false, `[{"k":2,"color":"blue","size":10,"shape":null,"height":null,"weight":100},{"k":1,"color":"red","size":10,"shape":"square","height":null,"weight":null}]`, nil}, + {"With order by asc with limit offset", "SELECT * FROM test ORDER BY color LIMIT 1 OFFSET 1", false, `[{"k":2,"color":"blue","size":10,"shape":null,"height":null,"weight":100}]`, nil}, + {"With order by desc", "SELECT * FROM test ORDER BY color DESC", 
false, `[{"k":1,"color":"red","size":10,"shape":"square","height":null,"weight":null},{"k":2,"color":"blue","size":10,"shape":null,"height":null,"weight":100},{"k":3,"color":null,"size":null,"shape":null,"height":100,"weight":200}]`, nil}, + {"With order by desc numeric", "SELECT * FROM test ORDER BY weight DESC", false, `[{"k":3,"color":null,"size":null,"shape":null,"height":100,"weight":200},{"k":2,"color":"blue","size":10,"shape":null,"height":null,"weight":100},{"k":1,"color":"red","size":10,"shape":"square","height":null,"weight":null}]`, nil}, + {"With order by desc with limit", "SELECT * FROM test ORDER BY color DESC LIMIT 2", false, `[{"k":1,"color":"red","size":10,"shape":"square","height":null,"weight":null},{"k":2,"color":"blue","size":10,"shape":null,"height":null,"weight":100}]`, nil}, + {"With order by desc with offset", "SELECT * FROM test ORDER BY color DESC OFFSET 1", false, `[{"k":2,"color":"blue","size":10,"shape":null,"height":null,"weight":100},{"k":3,"color":null,"size":null,"shape":null,"height":100,"weight":200}]`, nil}, + {"With order by desc with limit offset", "SELECT * FROM test ORDER BY color DESC LIMIT 1 OFFSET 1", false, `[{"k":2,"color":"blue","size":10,"shape":null,"height":null,"weight":100}]`, nil}, + {"With order by pk asc", "SELECT * FROM test ORDER BY k ASC", false, `[{"k":1,"color":"red","size":10,"shape":"square","height":null,"weight":null},{"k":2,"color":"blue","size":10,"shape":null,"height":null,"weight":100},{"k":3,"color":null,"size":null,"shape":null,"height":100,"weight":200}]`, nil}, + {"With order by pk desc", "SELECT * FROM test ORDER BY k DESC", false, `[{"k":3,"color":null,"size":null,"shape":null,"height":100,"weight":200},{"k":2,"color":"blue","size":10,"shape":null,"height":null,"weight":100},{"k":1,"color":"red","size":10,"shape":"square","height":null,"weight":null}]`, nil}, + {"With order by and where", "SELECT * FROM test WHERE color != 'blue' ORDER BY color DESC LIMIT 1", false, `[{"k":1,"color":"red","size":10,"shape":"square","height":null,"weight":null}]`, nil}, + {"With limit", "SELECT * FROM test WHERE size = 10 LIMIT 1", false, `[{"k":1,"color":"red","size":10,"shape":"square","height":null,"weight":null}]`, nil}, + {"With offset", "SELECT * FROM test WHERE size = 10 OFFSET 1", false, `[{"k":2,"color":"blue","size":10,"shape":null,"height":null,"weight":100}]`, nil}, + {"With limit then offset", "SELECT * FROM test WHERE size = 10 LIMIT 1 OFFSET 1", false, `[{"k":2,"k":2,"color":"blue","size":10,"shape":null,"height":null,"weight":100}]`, nil}, {"With offset then limit", "SELECT * FROM test WHERE size = 10 OFFSET 1 LIMIT 1", true, "", nil}, - {"With positional params", "SELECT * FROM test WHERE color = ? 
OR height = ?", false, `[{"k":1,"color":"red","size":10,"shape":"square"},{"k":3,"height":100,"weight":200}]`, []interface{}{"red", 100}}, - {"With named params", "SELECT * FROM test WHERE color = $a OR height = $d", false, `[{"k":1,"color":"red","size":10,"shape":"square"},{"k":3,"height":100,"weight":200}]`, []interface{}{sql.Named("a", "red"), sql.Named("d", 100)}}, - {"With pk()", "SELECT pk(), color FROM test", false, `[{"pk()":[1],"color":"red"},{"pk()":[2],"color":"blue"},{"pk()":[3],"color":null}]`, []interface{}{sql.Named("a", "red"), sql.Named("d", 100)}}, - {"With pk in cond, gt", "SELECT * FROM test WHERE k > 0 AND weight = 100", false, `[{"k":2,"color":"blue","size":10,"weight":100,"k":2}]`, nil}, - {"With pk in cond, =", "SELECT * FROM test WHERE k = 2.0 AND weight = 100", false, `[{"k":2,"color":"blue","size":10,"weight":100,"k":2}]`, nil}, + {"With positional params", "SELECT * FROM test WHERE color = ? OR height = ?", false, `[{"k":1,"color":"red","size":10,"shape":"square","height":null,"weight":null},{"k":3,"color":null,"size":null,"shape":null,"height":100,"weight":200}]`, []interface{}{"red", 100}}, + {"With named params", "SELECT * FROM test WHERE color = $a OR height = $d", false, `[{"k":1,"color":"red","size":10,"shape":"square","height":null,"weight":null},{"k":3,"color":null,"size":null,"shape":null,"height":100,"weight":200}]`, []interface{}{sql.Named("a", "red"), sql.Named("d", 100)}}, + {"With pk()", "SELECT color FROM test", false, `[{"color":"red"},{"color":"blue"},{"color":null}]`, []interface{}{sql.Named("a", "red"), sql.Named("d", 100)}}, + {"With pk in cond, gt", "SELECT * FROM test WHERE k > 0 AND weight = 100", false, `[{"k":2,"color":"blue","size":10,"shape":null,"height":null,"weight":100}]`, nil}, + {"With pk in cond, =", "SELECT * FROM test WHERE k = 2.0 AND weight = 100", false, `[{"k":2,"color":"blue","size":10,"shape":null,"height":null,"weight":100}]`, nil}, {"With count", "SELECT COUNT(k) FROM test", false, `[{"COUNT(k)": 3}]`, nil}, {"With count wildcard", "SELECT COUNT(*) FROM test", false, `[{"COUNT(*)": 3}]`, nil}, {"With multiple counts", "SELECT COUNT(k), COUNT(color) FROM test", false, `[{"COUNT(k)": 3, "COUNT(color)": 2}]`, nil}, @@ -92,11 +90,9 @@ func TestSelectStmt(t *testing.T) { {"With multiple maxs", "SELECT MAX(color), MAX(weight) FROM test", false, `[{"MAX(color)": "red", "MAX(weight)": 200}]`, nil}, {"With sum", "SELECT SUM(k) FROM test", false, `[{"SUM(k)": 6}]`, nil}, {"With multiple sums", "SELECT SUM(color), SUM(weight) FROM test", false, `[{"SUM(color)": null, "SUM(weight)": 300}]`, nil}, - {"With two non existing idents, =", "SELECT * FROM test WHERE z = y", false, `[]`, nil}, - {"With two non existing idents, >", "SELECT * FROM test WHERE z > y", false, `[]`, nil}, - {"With two non existing idents, !=", "SELECT * FROM test WHERE z != y", false, `[]`, nil}, - // See issue https://github.com/chaisql/chai/issues/283 - {"With empty WHERE and IN", "SELECT * FROM test WHERE [] IN [];", false, `[]`, nil}, + {"With two non existing idents, =", "SELECT * FROM test WHERE z = y", true, ``, nil}, + {"With two non existing idents, >", "SELECT * FROM test WHERE z > y", true, ``, nil}, + {"With two non existing idents, !=", "SELECT * FROM test WHERE z != y", true, ``, nil}, {"Invalid use of MIN() aggregator", "SELECT * FROM test LIMIT min(0)", true, ``, nil}, {"Invalid use of COUNT() aggregator", "SELECT * FROM test OFFSET count(*)", true, ``, nil}, {"Invalid use of MAX() aggregator", "SELECT * FROM test LIMIT max(0)", true, ``, nil}, 
@@ -186,39 +182,6 @@ func TestSelectStmt(t *testing.T) { require.JSONEq(t, `[{"foo": 2, "bar": "b"},{"foo": 3, "bar": "c"},{"foo": 4, "bar": "d"}]`, buf.String()) }) - t.Run("with objects", func(t *testing.T) { - db, err := chai.Open(":memory:") - assert.NoError(t, err) - defer db.Close() - - err = db.Exec("CREATE TABLE test") - assert.NoError(t, err) - - err = db.Exec(`INSERT INTO test VALUES {a: {b: 1}}, {a: 1}, {a: [1, 2, [8,9]]}`) - assert.NoError(t, err) - - call := func(q string, res ...string) { - st, err := db.Query(q) - assert.NoError(t, err) - defer st.Close() - - var i int - err = st.Iterate(func(r *chai.Row) error { - data, err := r.MarshalJSON() - assert.NoError(t, err) - require.JSONEq(t, res[i], string(data)) - i++ - return nil - }) - assert.NoError(t, err) - } - - call("SELECT *, a.b FROM test WHERE a = {b: 1}", `{"a": {"b":1}, "a.b": 1}`) - call("SELECT a.b FROM test", `{"a.b": 1}`, `{"a.b": null}`, `{"a.b": null}`) - call("SELECT a[1] FROM test", `{"a[1]": null}`, `{"a[1]": null}`, `{"a[1]": 2}`) - call("SELECT a[2][1] FROM test", `{"a[2][1]": null}`, `{"a[2][1]": null}`, `{"a[2][1]": 9}`) - }) - t.Run("table not found", func(t *testing.T) { db, err := chai.Open(":memory:") assert.NoError(t, err) @@ -233,10 +196,10 @@ func TestSelectStmt(t *testing.T) { assert.NoError(t, err) defer db.Close() - err = db.Exec("CREATE TABLE test(foo any); CREATE INDEX idx_foo ON test(foo);") + err = db.Exec("CREATE TABLE test(foo INT); CREATE INDEX idx_foo ON test(foo);") assert.NoError(t, err) - err = db.Exec(`INSERT INTO test (foo) VALUES (1), ('hello'), (2), (true)`) + err = db.Exec(`INSERT INTO test (foo) VALUES (4), (2), (1), (3)`) assert.NoError(t, err) st, err := db.Query("SELECT * FROM test ORDER BY foo") @@ -246,25 +209,7 @@ func TestSelectStmt(t *testing.T) { var buf bytes.Buffer err = st.MarshalJSONTo(&buf) assert.NoError(t, err) - require.JSONEq(t, `[{"foo": true},{"foo": 1}, {"foo": 2},{"foo": "hello"}]`, buf.String()) - }) - - // https://github.com/chaisql/chai/issues/208 - t.Run("group by with arrays", func(t *testing.T) { - db, err := chai.Open(":memory:") - assert.NoError(t, err) - defer db.Close() - - err = db.Exec("CREATE TABLE test; INSERT INTO test (a) VALUES ([1, 2, 3]);") - assert.NoError(t, err) - - d, err := db.QueryRow("SELECT MAX(a) from test GROUP BY a") - assert.NoError(t, err) - - enc, err := json.Marshal(d) - assert.NoError(t, err) - - require.JSONEq(t, `{"MAX(a)": [1, 2, 3]}`, string(enc)) + require.JSONEq(t, `[{"foo": 1},{"foo": 2}, {"foo": 3},{"foo": 4}]`, buf.String()) }) t.Run("empty table with aggregators", func(t *testing.T) { @@ -272,7 +217,7 @@ func TestSelectStmt(t *testing.T) { assert.NoError(t, err) defer db.Close() - err = db.Exec("CREATE TABLE test;") + err = db.Exec("CREATE TABLE test(a INTEGER, b INTEGER, id INTEGER PRIMARY KEY);") assert.NoError(t, err) d, err := db.QueryRow("SELECT MAX(a), MIN(b), COUNT(*), SUM(id) FROM test") @@ -284,44 +229,13 @@ func TestSelectStmt(t *testing.T) { require.JSONEq(t, `{"MAX(a)": null, "MIN(b)": null, "COUNT(*)": 0, "SUM(id)": null}`, string(enc)) }) - t.Run("array number comparison with no constraints", func(t *testing.T) { - db, err := chai.Open(":memory:") - assert.NoError(t, err) - defer db.Close() - - err = db.Exec(` - CREATE TABLE test(a any); - INSERT INTO test (a) VALUES ([1,2,3]), ([4, 5, 6]); - `) - assert.NoError(t, err) - - check := func() { - t.Helper() - - d, err := db.QueryRow("SELECT * FROM test WHERE a = [1,2,3];") - assert.NoError(t, err) - - enc, err := json.Marshal(d) - assert.NoError(t, 
err) - - require.JSONEq(t, `{"a": [1, 2, 3]}`, string(enc)) - } - - check() - - err = db.Exec("CREATE INDEX idx_test_a ON test(a);") - assert.NoError(t, err) - - check() - }) - t.Run("using sequences in SELECT must open read-write transaction instead of read-only", func(t *testing.T) { db, err := chai.Open(":memory:") assert.NoError(t, err) defer db.Close() err = db.Exec(` - CREATE TABLE test; + CREATE TABLE test(a INT); INSERT INTO test (a) VALUES (1); CREATE SEQUENCE seq; `) @@ -350,7 +264,7 @@ func TestSelectStmt(t *testing.T) { defer db.Close() err = db.Exec(` - CREATE TABLE test; + CREATE TABLE test(a INT); INSERT INTO test (a) VALUES (1), (2), (3); `) assert.NoError(t, err) @@ -378,9 +292,6 @@ func TestDistinct(t *testing.T) { {`text`, func(i, notUniqueCount int) (unique interface{}, notunique interface{}) { return strconv.Itoa(i), strconv.Itoa(i % notUniqueCount) }}, - {`array`, func(i, notUniqueCount int) (unique interface{}, notunique interface{}) { - return []interface{}{i}, []interface{}{i % notUniqueCount} - }}, } for _, typ := range tps { @@ -396,15 +307,15 @@ func TestDistinct(t *testing.T) { assert.NoError(t, err) defer tx.Rollback() - err = tx.Exec("CREATE TABLE test(a " + typ.name + " PRIMARY KEY, b " + typ.name + ", doc OBJECT, nullable " + typ.name + ");") + err = tx.Exec("CREATE TABLE test(a " + typ.name + " PRIMARY KEY, b " + typ.name + ", c TEXT, nullable " + typ.name + ");") assert.NoError(t, err) - err = tx.Exec("CREATE UNIQUE INDEX test_doc_index ON test(doc);") + err = tx.Exec("CREATE UNIQUE INDEX test_c_index ON test(c);") assert.NoError(t, err) for i := 0; i < total; i++ { unique, nonunique := typ.generateValue(i, notUnique) - err = tx.Exec(`INSERT INTO test VALUES {a: ?, b: ?, doc: {a: ?, b: ?}, nullable: null}`, unique, nonunique, unique, nonunique) + err = tx.Exec(`INSERT INTO test VALUES (?, ?, ?, null)`, unique, nonunique, unique) assert.NoError(t, err) } err = tx.Commit() @@ -417,11 +328,9 @@ func TestDistinct(t *testing.T) { }{ {`unique`, `SELECT DISTINCT a FROM test`, total}, {`non-unique`, `SELECT DISTINCT b FROM test`, notUnique}, - {`objects`, `SELECT DISTINCT doc FROM test`, total}, {`null`, `SELECT DISTINCT nullable FROM test`, 1}, {`wildcard`, `SELECT DISTINCT * FROM test`, total}, {`literal`, `SELECT DISTINCT 'a' FROM test`, 1}, - {`pk()`, `SELECT DISTINCT pk() FROM test`, total}, } for _, test := range tests { diff --git a/internal/query/statement/statement.go b/internal/query/statement/statement.go index 5304ef10d..60c6afb21 100644 --- a/internal/query/statement/statement.go +++ b/internal/query/statement/statement.go @@ -3,6 +3,7 @@ package statement import ( "github.com/chaisql/chai/internal/database" "github.com/chaisql/chai/internal/environment" + "github.com/chaisql/chai/internal/expr" "github.com/cockroachdb/errors" ) @@ -81,3 +82,24 @@ func (r *Result) Close() (err error) { return err } + +func ensureExprColumnsExist(ctx *Context, tableName string, e expr.Expr) (err error) { + info, err := ctx.Tx.Catalog.GetTableInfo(tableName) + if err != nil { + return err + } + expr.Walk(e, func(e expr.Expr) bool { + switch t := e.(type) { + case expr.Column: + cc := info.ColumnConstraints.GetColumnConstraint(string(t)) + if cc == nil { + err = errors.Newf("column %s does not exist", t) + return false + } + } + + return true + }) + + return err +} diff --git a/internal/query/statement/stream.go b/internal/query/statement/stream.go index e50a575f9..53b4c7ba4 100644 --- a/internal/query/statement/stream.go +++ b/internal/query/statement/stream.go @@ 
-16,13 +16,8 @@ type StreamStmt struct { // Prepare implements the Preparer interface. func (s *StreamStmt) Prepare(ctx *Context) (Statement, error) { - st, err := planner.Optimize(s.Stream, ctx.Tx.Catalog) - if err != nil { - return nil, err - } - return &PreparedStreamStmt{ - Stream: st, + Stream: s.Stream, ReadOnly: s.ReadOnly, }, nil } @@ -36,9 +31,14 @@ type PreparedStreamStmt struct { // Run returns a result containing the stream. The stream will be executed by calling the Iterate method of // the result. func (s *PreparedStreamStmt) Run(ctx *Context) (Result, error) { + st, err := planner.Optimize(s.Stream, ctx.Tx.Catalog, ctx.Params) + if err != nil { + return Result{}, err + } + return Result{ Iterator: &StreamStmtIterator{ - Stream: s.Stream, + Stream: st, Context: ctx, }, }, nil @@ -73,7 +73,7 @@ func (s *StreamStmtIterator) Iterate(fn func(r database.Row) error) error { return nil } - return fn(env.Row) + return fn(env.Row.(database.Row)) }) if errors.Is(err, stream.ErrStreamClosed) { err = nil diff --git a/internal/query/statement/update.go b/internal/query/statement/update.go index 7916a3922..f3a0bee1d 100644 --- a/internal/query/statement/update.go +++ b/internal/query/statement/update.go @@ -2,13 +2,11 @@ package statement import ( "github.com/chaisql/chai/internal/expr" - "github.com/chaisql/chai/internal/object" "github.com/chaisql/chai/internal/stream" "github.com/chaisql/chai/internal/stream/index" "github.com/chaisql/chai/internal/stream/path" "github.com/chaisql/chai/internal/stream/rows" "github.com/chaisql/chai/internal/stream/table" - "github.com/cockroachdb/errors" ) // UpdateConfig holds UPDATE configuration. @@ -18,14 +16,10 @@ type UpdateStmt struct { TableName string // SetPairs is used along with the Set clause. It holds - // each path with its corresponding value that - // should be set in the object. + // each column with its corresponding value that + // should be set in the row. SetPairs []UpdateSetPair - // UnsetFields is used along with the Unset clause. It holds - // each path that should be unset from the object. - UnsetFields []string - WhereExpr expr.Expr } @@ -41,8 +35,8 @@ func NewUpdateStatement() *UpdateStmt { } type UpdateSetPair struct { - Path object.Path - E expr.Expr + Column expr.Column + E expr.Expr } // Prepare implements the Preparer interface. 
@@ -56,36 +50,33 @@ func (stmt *UpdateStmt) Prepare(c *Context) (Statement, error) { s := stream.New(table.Scan(stmt.TableName)) if stmt.WhereExpr != nil { + err := ensureExprColumnsExist(c, stmt.TableName, stmt.WhereExpr) + if err != nil { + return nil, err + } + s = s.Pipe(rows.Filter(stmt.WhereExpr)) } var pkModified bool if stmt.SetPairs != nil { for _, pair := range stmt.SetPairs { + err := ensureExprColumnsExist(c, stmt.TableName, pair.Column) + if err != nil { + return nil, err + } + // if we modify the primary key, // we must remove the old row and create an new one if pk != nil && !pkModified { - for _, p := range pk.Paths { - if p.IsEqual(pair.Path) { + for _, c := range pk.Columns { + if c == string(pair.Column) { pkModified = true break } } } - s = s.Pipe(path.Set(pair.Path, pair.E)) - } - } else if stmt.UnsetFields != nil { - for _, name := range stmt.UnsetFields { - // ensure we do not unset any path the is used in the primary key - if pk != nil { - path := object.NewPath(name) - for _, p := range pk.Paths { - if p.IsEqual(path) { - return nil, errors.New("cannot unset primary key path") - } - } - } - s = s.Pipe(path.Unset(name)) + s = s.Pipe(path.Set(string(pair.Column), pair.E)) } } diff --git a/internal/query/statement/update_test.go b/internal/query/statement/update_test.go index 458464163..699fc1703 100644 --- a/internal/query/statement/update_test.go +++ b/internal/query/statement/update_test.go @@ -21,27 +21,17 @@ func TestUpdateStmt(t *testing.T) { {"No clause", `UPDATE test`, true, "", nil}, {"Read-only table", `UPDATE __chai_catalog SET a = 1`, true, "", nil}, - // SET tests. - {"SET / No cond", `UPDATE test SET a = 'boo'`, false, `[{"a":"boo","b":"bar1","c":"baz1"},{"a":"boo","b":"bar2"},{"a":"boo","d":"bar3","e":"baz3"}]`, nil}, - {"SET / No cond / with ident string", "UPDATE test SET `a` = 'boo'", false, `[{"a":"boo","b":"bar1","c":"baz1"},{"a":"boo","b":"bar2"},{"a":"boo","d":"bar3","e":"baz3"}]`, nil}, + {"SET / No cond", `UPDATE test SET a = 'boo'`, false, `[{"a":"boo","b":"bar1","c":"baz1","d":null,"e":null},{"a":"boo","b":"bar2","c":null,"d":null,"e":null},{"a":"boo","d":"bar3","e":"baz3","c":null,"b":null}]`, nil}, + {"SET / No cond / with ident string", "UPDATE test SET `a` = 'boo'", false, `[{"a":"boo","b":"bar1","c":"baz1","d":null,"e":null},{"a":"boo","b":"bar2","c":null,"d":null,"e":null},{"a":"boo","d":"bar3","e":"baz3","c":null,"b":null}]`, nil}, {"SET / No cond / with multiple idents and constraint", `UPDATE test SET a = c`, true, ``, nil}, - {"SET / No cond / with multiple idents", `UPDATE test SET b = c`, false, `[{"a":"foo1","b":"baz1","c":"baz1"},{"a":"foo2","b":null},{"a":"foo3","b":null,"d":"bar3","e":"baz3"}]`, nil}, - {"SET / No cond / with missing field", "UPDATE test SET f = 'boo'", false, `[{"a":"foo1","b":"bar1","c":"baz1","f":"boo"},{"a":"foo2","b":"bar2","f":"boo"},{"a":"foo3","d":"bar3","e":"baz3","f":"boo"}]`, nil}, + {"SET / No cond / with multiple idents", `UPDATE test SET b = c`, false, `[{"a":"foo1","b":"baz1","c":"baz1","d":null,"e":null},{"a":"foo2","b":null,"c":null,"d":null,"e":null},{"a":"foo3","b":null,"c":null,"d":"bar3","e":"baz3"}]`, nil}, + {"SET / No cond / with missing column", "UPDATE test SET f = 'boo'", true, "", nil}, {"SET / No cond / with string", `UPDATE test SET 'a' = 'boo'`, true, "", nil}, - {"SET / With cond", "UPDATE test SET a = 'FOO2', b = 2 WHERE a = 'foo2'", false, `[{"a":"foo1","b":"bar1","c":"baz1"},{"a":"FOO2","b":2},{"a":"foo3","d":"bar3","e":"baz3"}]`, nil}, - {"SET / With cond / with missing 
field", "UPDATE test SET f = 'boo' WHERE d = 'bar3'", false, `[{"a":"foo1","b":"bar1","c":"baz1"},{"a":"foo2","b":"bar2"},{"a":"foo3","d":"bar3","e":"baz3","f":"boo"}]`, nil}, - {"SET / Field not found", "UPDATE test SET a = 1, b = 2 WHERE a = f", false, `[{"a":"foo1","b":"bar1","c":"baz1"},{"a":"foo2","b":"bar2"},{"a":"foo3","d":"bar3","e":"baz3"}]`, nil}, - {"SET / Positional params", "UPDATE test SET a = ?, b = ? WHERE a = ?", false, `[{"a":"a","b":"b","c":"baz1"},{"a":"foo2","b":"bar2"},{"a":"foo3","d":"bar3","e":"baz3"}]`, []interface{}{"a", "b", "foo1"}}, - {"SET / Named params", "UPDATE test SET a = $a, b = $b WHERE a = $c", false, `[{"a":"a","b":"b","c":"baz1"},{"a":"foo2","b":"bar2"},{"a":"foo3","d":"bar3","e":"baz3"}]`, []interface{}{sql.Named("b", "b"), sql.Named("a", "a"), sql.Named("c", "foo1")}}, - {"SET / Nested objects on a / Wrong type", "UPDATE test SET a.b = 2", false, `[{"a":"foo1","b":"bar1","c":"baz1"},{"a":"foo2","b":"bar2"},{"a":"foo3","d":"bar3","e":"baz3"}]`, nil}, - {"SET / Nested objects on a / missing row", "UPDATE test SET g.h.i = 2", false, `[{"a":"foo1","b":"bar1","c":"baz1"},{"a":"foo2","b":"bar2"},{"a":"foo3","d":"bar3","e":"baz3"}]`, nil}, - - // UNSET tests. - {"UNSET / No cond", `UPDATE test UNSET b`, false, `[{"a":"foo1","c":"baz1"},{"a":"foo2"},{"a":"foo3","d":"bar3","e":"baz3"}]`, nil}, - {"UNSET / No cond / with ident string", "UPDATE test UNSET `a`", true, "", nil}, - {"UNSET / No cond / with missing field", "UPDATE test UNSET f", false, `[{"a":"foo1","b":"bar1","c":"baz1"},{"a":"foo2","b":"bar2"},{"a":"foo3","d":"bar3","e":"baz3"}]`, nil}, - {"UNSET / No cond / with string", `UPDATE test UNSET 'a'`, true, "", nil}, - {"UNSET / With cond", `UPDATE test UNSET b WHERE a = 'foo2'`, false, `[{"a":"foo1","b":"bar1","c":"baz1"},{"a":"foo2"},{"a":"foo3","d":"bar3","e":"baz3"}]`, nil}, + {"SET / With cond", "UPDATE test SET a = 'FOO2', b = 2 WHERE a = 'foo2'", false, `[{"a":"foo1","b":"bar1","c":"baz1","d":null,"e":null},{"a":"FOO2","b":"2","c":null,"d":null,"e":null},{"a":"foo3","b":null,"c":null,"d":"bar3","e":"baz3"}]`, nil}, + {"SET / With cond / with missing column", "UPDATE test SET f = 'boo' WHERE d = 'bar3'", true, ``, nil}, + {"SET / Field not found", "UPDATE test SET a = 1, b = 2 WHERE a = f", true, ``, nil}, + {"SET / Positional params", "UPDATE test SET a = ?, b = ? 
WHERE a = ?", false, `[{"a":"a","b":"b","c":"baz1","d":null,"e":null},{"a":"foo2","b":"bar2","c":null,"d":null,"e":null},{"a":"foo3","b":null,"c":null,"d":"bar3","e":"baz3"}]`, []interface{}{"a", "b", "foo1"}}, + {"SET / Named params", "UPDATE test SET a = $a, b = $b WHERE a = $c", false, `[{"a":"a","b":"b","c":"baz1","d":null,"e":null},{"a":"foo2","b":"bar2","c":null,"d":null,"e":null},{"a":"foo3","b":null,"c":null,"d":"bar3","e":"baz3"}]`, []interface{}{sql.Named("b", "b"), sql.Named("a", "a"), sql.Named("c", "foo1")}}, } for _, test := range tests { @@ -51,7 +41,7 @@ func TestUpdateStmt(t *testing.T) { assert.NoError(t, err) defer db.Close() - err = db.Exec("CREATE TABLE test (a text not null, ...)") + err = db.Exec("CREATE TABLE test (a text not null, b text, c text, d text, e text)") assert.NoError(t, err) if indexed { @@ -88,54 +78,4 @@ func TestUpdateStmt(t *testing.T) { runTest(true) }) } - - t.Run("with arrays", func(t *testing.T) { - tests := []struct { - name string - query string - fails bool - expected string - params []interface{} - }{ - {"SET / No cond add field ", `UPDATE foo set b = 0`, false, `[{"a": [1, 0, 0], "b": 0}, {"a": [2, 0], "b": 0}]`, nil}, - {"SET / No cond / with path at existing index only", `UPDATE foo SET a[2] = 10`, false, `[{"a": [1, 0, 10]}, {"a": [2, 0]}]`, nil}, - {"SET / No cond / with index array", `UPDATE foo SET a[1] = 10`, false, `[{"a": [1, 10, 0]}, {"a": [2, 10]}]`, nil}, - {"SET / No cond / with path on non existing field", `UPDATE foo SET a.foo[1] = 10`, false, `[{"a": [1, 0, 0]}, {"a": [2, 0]}]`, nil}, - {"SET / With cond / index array", `UPDATE foo SET a[0] = 1 WHERE a[0] = 2`, false, `[{"a": [1, 0, 0]}, {"a": [1, 0]}]`, nil}, - {"SET / No cond / index out of range", `UPDATE foo SET a[10] = 1`, false, `[{"a": [1, 0, 0]}, {"a": [2, 0]}]`, nil}, - {"SET / No cond / Nested array", `UPDATE foo SET a[1] = [1, 0, 0]`, false, `[{"a": [1, [1, 0, 0], 0]}, {"a": [2, [1, 0, 0]]}]`, nil}, - {"SET / No cond / with multiple idents", `UPDATE foo SET a[1] = [1, 0, 0], a[1][2] = 9`, false, `[{"a": [1, [1, 0, 9], 0]}, {"a": [2, [1, 0, 9]]}]`, nil}, - {"SET / No cond / add doc / with multiple idents with multiple indexes", `UPDATE foo SET a[1] = [1, 0, 0], a[1][2] = {"b": "foo"}`, false, `[{"a": [1, [1, 0, {"b":"foo"}], 0]}, {"a": [2, [1, 0, {"b":"foo"}]]}]`, nil}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - db, err := chai.Open(":memory:") - assert.NoError(t, err) - defer db.Close() - - err = db.Exec(`CREATE TABLE foo;`) - assert.NoError(t, err) - err = db.Exec(`INSERT INTO foo (a) VALUES ([1, 0, 0]), ([2, 0]);`) - assert.NoError(t, err) - - err = db.Exec(tt.query, tt.params...) - if tt.fails { - assert.Error(t, err) - return - } - assert.NoError(t, err) - - st, err := db.Query("SELECT * FROM foo") - assert.NoError(t, err) - defer st.Close() - - var buf bytes.Buffer - - err = st.MarshalJSONTo(&buf) - assert.NoError(t, err) - require.JSONEq(t, tt.expected, buf.String()) - }) - } - }) } diff --git a/internal/row/diff.go b/internal/row/diff.go new file mode 100644 index 000000000..42d2f80da --- /dev/null +++ b/internal/row/diff.go @@ -0,0 +1,98 @@ +package row + +import ( + "github.com/chaisql/chai/internal/types" +) + +// Diff returns the operations needed to transform the first row into the second. 
+func Diff(r1, r2 Row) ([]Op, error) { + var ops []Op + f1, err := Columns(r1) + if err != nil { + return nil, err + } + + f2, err := Columns(r2) + if err != nil { + return nil, err + } + + var i, j int + for { + for i < len(f1) && (j >= len(f2) || f1[i] < f2[j]) { + v, err := r1.Get(f1[i]) + if err != nil { + return nil, err + } + ops = append(ops, NewDeleteOp(f1[i], v)) + i++ + } + + for j < len(f2) && (i >= len(f1) || f1[i] > f2[j]) { + v, err := r2.Get(f2[j]) + if err != nil { + return nil, err + } + ops = append(ops, NewSetOp(f2[j], v)) + j++ + } + + if i == len(f1) && j == len(f2) { + break + } + + v1, err := r1.Get(f1[i]) + if err != nil { + return nil, err + } + + v2, err := r2.Get(f2[j]) + if err != nil { + return nil, err + } + + if v1.Type() != v2.Type() { + v, err := r2.Get(f2[j]) + if err != nil { + return nil, err + } + ops = append(ops, NewSetOp(f2[j], v)) + } else { + ok, err := v1.EQ(v2) + if err != nil { + return nil, err + } + if !ok { + ops = append(ops, NewSetOp(f2[j], v2)) + } + } + i++ + j++ + } + + return ops, nil +} + +// Op represents a single operation on a row. +// It is returned by the Diff function. +type Op struct { + Type string + Column string + Value types.Value +} + +func NewSetOp(column string, v types.Value) Op { + return newOp("set", column, v) +} + +func NewDeleteOp(column string, v types.Value) Op { + return newOp("delete", column, v) +} + +func newOp(op string, column string, v types.Value) Op { + return Op{ + Type: op, + Column: column, + Value: v, + } +} diff --git a/internal/row/diff_test.go b/internal/row/diff_test.go new file mode 100644 index 000000000..958f6cce1 --- /dev/null +++ b/internal/row/diff_test.go @@ -0,0 +1,74 @@ +package row_test + +import ( + "testing" + + "github.com/chaisql/chai/internal/row" + "github.com/chaisql/chai/internal/testutil" + "github.com/chaisql/chai/internal/types" + "github.com/stretchr/testify/require" +) + +func TestDiff(t *testing.T) { + tests := []struct { + name string + d1, d2 string + want []row.Op + }{ + { + name: "empty", + d1: `{}`, + d2: `{}`, + want: nil, + }, + { + name: "add field", + d1: `{}`, + d2: `{"a": 1}`, + want: []row.Op{ + {"set", "a", types.NewIntegerValue(1)}, + }, + }, + { + name: "remove field", + d1: `{"a": 1}`, + d2: `{}`, + want: []row.Op{ + {"delete", "a", types.NewIntegerValue(1)}, + }, + }, + { + name: "same", + d1: `{"a": 1}`, + d2: `{"a": 1}`, + want: nil, + }, + { + name: "replace field", + d1: `{"a": 1}`, + d2: `{"a": 2}`, + want: []row.Op{ + {"set", "a", types.NewIntegerValue(2)}, + }, + }, + { + name: "replace field: different type", + d1: `{"a": 1}`, + d2: `{"a": "hello"}`, + want: []row.Op{ + {"set", "a", types.NewTextValue("hello")}, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + d1 := testutil.MakeRow(t, test.d1) + d2 := testutil.MakeRow(t, test.d2) + + got, err := row.Diff(d1, d2) + require.NoError(t, err) + require.Equal(t, test.want, got) + }) + } +} diff --git a/internal/row/format.go b/internal/row/format.go new file mode 100644 index 000000000..286250165 --- /dev/null +++ b/internal/row/format.go @@ -0,0 +1,138 @@ +package row + +import ( + "bytes" + "encoding/hex" + "fmt" + "math" + "strconv" + "strings" + "time" + + "github.com/chaisql/chai/internal/stringutil" + "github.com/chaisql/chai/internal/types" +) + +// MarshalJSON encodes a row to json.
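+// Columns are emitted in sorted order. A rough sketch of the expected output
+// (illustrative, reusing ColumnBuffer):
+//
+//	cb := NewColumnBuffer().Add("name", types.NewTextValue("John")).Add("age", types.NewIntegerValue(10))
+//	data, _ := MarshalJSON(cb)
+//	// string(data) == `{"age": 10, "name": "John"}`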
+func MarshalJSON(r Row) ([]byte, error) { + var buf bytes.Buffer + + buf.WriteByte('{') + + var notFirst bool + err := SortColumns(r).Iterate(func(c string, v types.Value) error { + if notFirst { + buf.WriteString(", ") + } + notFirst = true + + buf.WriteString(strconv.Quote(c)) + buf.WriteString(": ") + + data, err := v.MarshalJSON() + if err != nil { + return err + } + _, err = buf.Write(data) + return err + }) + if err != nil { + return nil, err + } + + buf.WriteByte('}') + + return buf.Bytes(), nil +} + +func MarshalTextIndent(r Row, prefix, indent string) ([]byte, error) { + var buf bytes.Buffer + buf.WriteByte('{') + var i int + err := r.Iterate(func(field string, value types.Value) error { + if i > 0 { + buf.WriteByte(',') + if prefix == "" { + buf.WriteByte(' ') + } + } + newline(&buf, prefix, indent, 1) + i++ + + var ident string + if strings.HasPrefix(field, "\"") { + ident = stringutil.NormalizeIdentifier(field, '`') + } else { + ident = stringutil.NormalizeIdentifier(field, '"') + } + buf.WriteString(ident) + buf.WriteString(": ") + + return marshalText(&buf, value, prefix, indent, 1) + }) + if err != nil { + return nil, err + } + newline(&buf, prefix, indent, 0) + buf.WriteRune('}') + return buf.Bytes(), nil +} + +func marshalText(dst *bytes.Buffer, v types.Value, prefix, indent string, depth int) error { + if v.V() == nil { + dst.WriteString("NULL") + return nil + } + + switch v.Type() { + case types.TypeNull: + dst.WriteString("NULL") + return nil + case types.TypeBoolean: + dst.WriteString(strconv.FormatBool(types.AsBool(v))) + return nil + case types.TypeInteger, types.TypeBigint: + dst.WriteString(strconv.FormatInt(types.AsInt64(v), 10)) + return nil + case types.TypeDouble: + f := types.AsFloat64(v) + abs := math.Abs(f) + fmt := byte('f') + if abs != 0 { + if abs < 1e-6 || abs >= 1e15 { + fmt = 'e' + } + } + + // By default the precision is -1 to use the smallest number of digits. 
+ // See https://pkg.go.dev/strconv#FormatFloat + prec := -1 + // if the number is round, add .0 + if float64(int64(f)) == f { + prec = 1 + } + dst.WriteString(strconv.FormatFloat(types.AsFloat64(v), fmt, prec, 64)) + return nil + case types.TypeTimestamp: + dst.WriteString(strconv.Quote(types.AsTime(v).Format(time.RFC3339Nano))) + return nil + case types.TypeText: + dst.WriteString(strconv.Quote(types.AsString(v))) + return nil + case types.TypeBlob: + src := types.AsByteSlice(v) + dst.WriteString("\"\\x") + hex.NewEncoder(dst).Write(src) + dst.WriteByte('"') + return nil + default: + return fmt.Errorf("unexpected type: %d", v.Type()) + } +} + +func newline(dst *bytes.Buffer, prefix, indent string, depth int) { + dst.WriteString(prefix) + for i := 0; i < depth; i++ { + dst.WriteString(indent) + } +} diff --git a/internal/object/json.go b/internal/row/json.go similarity index 68% rename from internal/object/json.go rename to internal/row/json.go index ed0a204c6..bc3f3861b 100644 --- a/internal/object/json.go +++ b/internal/row/json.go @@ -1,8 +1,11 @@ -package object +package row import ( + "math" + "github.com/buger/jsonparser" "github.com/chaisql/chai/internal/types" + "github.com/cockroachdb/errors" ) func parseJSONValue(dataType jsonparser.ValueType, data []byte) (v types.Value, err error) { @@ -27,31 +30,18 @@ func parseJSONValue(dataType jsonparser.ValueType, data []byte) (v types.Value, return types.NewDoubleValue(f), nil } - return types.NewIntegerValue(i), nil + if i < math.MinInt32 || i > math.MaxInt32 { + return types.NewBigintValue(i), nil + } + + return types.NewIntegerValue(int32(i)), nil case jsonparser.String: s, err := jsonparser.ParseString(data) if err != nil { return nil, err } return types.NewTextValue(s), nil - case jsonparser.Array: - buf := NewValueBuffer() - err := buf.UnmarshalJSON(data) - if err != nil { - return nil, err - } - - return types.NewArrayValue(buf), nil - case jsonparser.Object: - buf := NewFieldBuffer() - err = buf.UnmarshalJSON(data) - if err != nil { - return nil, err - } - - return types.NewObjectValue(buf), nil default: + return nil, errors.Errorf("unsupported JSON type: %v", dataType) } - - return nil, nil } diff --git a/internal/row/object_test.go b/internal/row/object_test.go new file mode 100644 index 000000000..d51ca97d7 --- /dev/null +++ b/internal/row/object_test.go @@ -0,0 +1,474 @@ +package row_test + +import ( + "encoding/json" + "testing" + "time" + + "github.com/cockroachdb/errors" + "github.com/stretchr/testify/require" + + "github.com/chaisql/chai/internal/row" + "github.com/chaisql/chai/internal/testutil" + "github.com/chaisql/chai/internal/testutil/assert" + "github.com/chaisql/chai/internal/types" +) + +var _ row.Row = new(row.ColumnBuffer) + +func TestColumnBuffer(t *testing.T) { + var buf row.ColumnBuffer + buf.Add("a", types.NewIntegerValue(10)) + buf.Add("b", types.NewTextValue("hello")) + + t.Run("Iterate", func(t *testing.T) { + var i int + err := buf.Iterate(func(f string, v types.Value) error { + switch i { + case 0: + require.Equal(t, "a", f) + require.Equal(t, types.NewIntegerValue(10), v) + case 1: + require.Equal(t, "b", f) + require.Equal(t, types.NewTextValue("hello"), v) + } + i++ + return nil + }) + assert.NoError(t, err) + require.Equal(t, 2, i) + }) + + t.Run("Add", func(t *testing.T) { + var buf row.ColumnBuffer + buf.Add("a", types.NewIntegerValue(10)) + buf.Add("b", types.NewTextValue("hello")) + + c := types.NewBooleanValue(true) + buf.Add("c", c) + require.Equal(t, 3, buf.Len()) + }) + + t.Run("ScanRow", 
func(t *testing.T) { + var buf1, buf2 row.ColumnBuffer + + buf1.Add("a", types.NewIntegerValue(10)) + buf1.Add("b", types.NewTextValue("hello")) + + buf2.Add("a", types.NewIntegerValue(20)) + buf2.Add("b", types.NewTextValue("bye")) + buf2.Add("c", types.NewBooleanValue(true)) + + err := buf1.ScanRow(&buf2) + assert.NoError(t, err) + + var buf row.ColumnBuffer + buf.Add("a", types.NewIntegerValue(10)) + buf.Add("b", types.NewTextValue("hello")) + buf.Add("a", types.NewIntegerValue(20)) + buf.Add("b", types.NewTextValue("bye")) + buf.Add("c", types.NewBooleanValue(true)) + require.Equal(t, buf, buf1) + }) + + t.Run("Get", func(t *testing.T) { + v, err := buf.Get("a") + assert.NoError(t, err) + require.Equal(t, types.NewIntegerValue(10), v) + + v, err = buf.Get("not existing") + assert.ErrorIs(t, err, types.ErrColumnNotFound) + require.Zero(t, v) + }) + + t.Run("Set", func(t *testing.T) { + tests := []struct { + name string + data string + column string + value types.Value + want string + fails bool + }{ + {"root", `{}`, `a`, types.NewIntegerValue(1), `{"a": 1}`, false}, + {"add column", `{"a": 1}`, `c`, types.NewTextValue("foo"), `{"a": 1, "c": "foo"}`, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var fb row.ColumnBuffer + + r := testutil.MakeRow(t, tt.data) + err := fb.Copy(r) + assert.NoError(t, err) + err = fb.Set(tt.column, tt.value) + if tt.fails { + assert.Error(t, err) + return + } + + assert.NoError(t, err) + data, err := row.MarshalJSON(&fb) + assert.NoError(t, err) + require.Equal(t, tt.want, string(data)) + }) + } + }) + + t.Run("Delete", func(t *testing.T) { + tests := []struct { + object string + column string + expected string + fails bool + }{ + {`{"a": 10, "b": "hello"}`, "a", `{"b": "hello"}`, false}, + {`{"a": 10, "b": "hello"}`, "c", ``, true}, + } + + for _, test := range tests { + t.Run(test.object, func(t *testing.T) { + var buf row.ColumnBuffer + err := buf.Copy(testutil.MakeRow(t, test.object)) + assert.NoError(t, err) + + err = buf.Delete(test.column) + if test.fails { + assert.Error(t, err) + } else { + assert.NoError(t, err) + got, err := json.Marshal(&buf) + assert.NoError(t, err) + require.JSONEq(t, test.expected, string(got)) + } + }) + } + }) + + t.Run("Replace", func(t *testing.T) { + var buf row.ColumnBuffer + buf.Add("a", types.NewIntegerValue(10)) + buf.Add("b", types.NewTextValue("hello")) + + err := buf.Replace("a", types.NewBooleanValue(true)) + assert.NoError(t, err) + v, err := buf.Get("a") + assert.NoError(t, err) + require.Equal(t, types.NewBooleanValue(true), v) + err = buf.Replace("d", types.NewIntegerValue(11)) + assert.Error(t, err) + }) + + t.Run("Apply", func(t *testing.T) { + d := testutil.MakeRow(t, `{ + "a": "b", + "c": "d", + "e": "f" + }`) + + buf := row.NewColumnBuffer() + err := buf.Copy(d) + assert.NoError(t, err) + + err = buf.Apply(func(c string, v types.Value) (types.Value, error) { + return types.NewIntegerValue(1), nil + }) + assert.NoError(t, err) + + got, err := json.Marshal(buf) + assert.NoError(t, err) + require.JSONEq(t, `{"a":1, "c":1, "e":1}`, string(got)) + }) +} + +func TestNewFromStruct(t *testing.T) { + type group struct { + Ig int + } + + type user struct { + A []byte + B string + C bool + D uint `chai:"la-reponse-d"` + E uint8 + F uint16 + G uint32 + H uint64 + I int + J int8 + K int16 + L int32 + M int64 + N float64 + + // nil pointers must be skipped + // otherwise they must be dereferenced + P *int + Q *int + + Z interface{} + ZZ interface{} + + AA int `chai:"-"` // ignored + + 
*group + + BB time.Time // some have special encoding as object + + // unexported fields should be ignored + t int + } + + u := user{ + A: []byte("foo"), + B: "bar", + C: true, + D: 1, + E: 2, + F: 3, + G: 4, + H: 5, + I: 6, + J: 7, + K: 8, + L: 9, + M: 10, + N: 11.12, + Z: 26, + AA: 27, + group: &group{ + Ig: 100, + }, + BB: time.Date(2020, 11, 15, 16, 37, 10, 20, time.UTC), + t: 99, + } + + q := 5 + u.Q = &q + + t.Run("Iterate", func(t *testing.T) { + doc, err := row.NewFromStruct(u) + assert.NoError(t, err) + + var counter int + + err = doc.Iterate(func(f string, v types.Value) error { + switch counter { + case 0: + require.Equal(t, u.A, types.AsByteSlice(v)) + case 1: + require.Equal(t, u.B, types.AsString(v)) + case 2: + require.Equal(t, u.C, types.AsBool(v)) + case 3: + require.Equal(t, "la-reponse-d", f) + require.EqualValues(t, u.D, types.AsInt64(v)) + case 4: + require.EqualValues(t, u.E, types.AsInt64(v)) + case 5: + require.EqualValues(t, u.F, types.AsInt64(v)) + case 6: + require.EqualValues(t, u.G, types.AsInt64(v)) + case 7: + require.EqualValues(t, u.H, types.AsInt64(v)) + case 8: + require.EqualValues(t, u.I, types.AsInt64(v)) + case 9: + require.EqualValues(t, u.J, types.AsInt64(v)) + case 10: + require.EqualValues(t, u.K, types.AsInt64(v)) + case 11: + require.EqualValues(t, u.L, types.AsInt64(v)) + case 12: + require.EqualValues(t, u.M, types.AsInt64(v)) + case 13: + require.Equal(t, u.N, types.AsFloat64(v)) + case 14: + require.EqualValues(t, *u.Q, types.AsInt64(v)) + case 15: + require.EqualValues(t, u.Z, types.AsInt64(v)) + case 16: + require.EqualValues(t, types.TypeNull, v.Type()) + case 17: + require.EqualValues(t, types.TypeBigint, v.Type()) + case 18: + require.EqualValues(t, types.TypeTimestamp, v.Type()) + case 19: + default: + require.FailNowf(t, "", "unknown field %q", f) + } + + counter++ + + return nil + }) + assert.NoError(t, err) + require.Equal(t, 19, counter) + }) + + t.Run("Get", func(t *testing.T) { + doc, err := row.NewFromStruct(u) + assert.NoError(t, err) + + v, err := doc.Get("a") + assert.NoError(t, err) + require.Equal(t, u.A, types.AsByteSlice(v)) + v, err = doc.Get("b") + assert.NoError(t, err) + require.Equal(t, u.B, types.AsString(v)) + v, err = doc.Get("c") + assert.NoError(t, err) + require.Equal(t, u.C, types.AsBool(v)) + v, err = doc.Get("la-reponse-d") + assert.NoError(t, err) + require.EqualValues(t, u.D, types.AsInt64(v)) + v, err = doc.Get("e") + assert.NoError(t, err) + require.EqualValues(t, u.E, types.AsInt64(v)) + v, err = doc.Get("f") + assert.NoError(t, err) + require.EqualValues(t, u.F, types.AsInt64(v)) + v, err = doc.Get("g") + assert.NoError(t, err) + require.EqualValues(t, u.G, types.AsInt64(v)) + v, err = doc.Get("h") + assert.NoError(t, err) + require.EqualValues(t, u.H, types.AsInt64(v)) + v, err = doc.Get("i") + assert.NoError(t, err) + require.EqualValues(t, u.I, types.AsInt64(v)) + v, err = doc.Get("j") + assert.NoError(t, err) + require.EqualValues(t, u.J, types.AsInt64(v)) + v, err = doc.Get("k") + assert.NoError(t, err) + require.EqualValues(t, u.K, types.AsInt64(v)) + v, err = doc.Get("l") + assert.NoError(t, err) + require.EqualValues(t, u.L, types.AsInt64(v)) + v, err = doc.Get("m") + assert.NoError(t, err) + require.EqualValues(t, u.M, types.AsInt64(v)) + v, err = doc.Get("n") + assert.NoError(t, err) + require.Equal(t, u.N, types.AsFloat64(v)) + + v, err = doc.Get("bb") + assert.NoError(t, err) + var tm time.Time + assert.NoError(t, row.ScanValue(v, &tm)) + require.Equal(t, u.BB, tm) + }) + + 
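+ // Recap (illustrative): NewFromStruct lowercases exported field names,
+ // honours the `chai:"..."` tag (field D is exposed as "la-reponse-d"),
+ // and skips unexported fields as well as fields tagged `chai:"-"`.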
t.Run("pointers", func(t *testing.T) { + type s struct { + A *int + } + + d, err := row.NewFromStruct(new(s)) + assert.NoError(t, err) + _, err = d.Get("a") + assert.ErrorIs(t, err, types.ErrColumnNotFound) + + a := 10 + ss := s{A: &a} + d, err = row.NewFromStruct(&ss) + assert.NoError(t, err) + v, err := d.Get("a") + assert.NoError(t, err) + require.Equal(t, types.NewBigintValue(10), v) + }) +} + +type foo struct { + A string + B int64 + C bool + D float64 +} + +func (f *foo) Iterate(fn func(field string, value types.Value) error) error { + var err error + + err = fn("a", types.NewTextValue(f.A)) + if err != nil { + return err + } + + err = fn("b", types.NewBigintValue(f.B)) + if err != nil { + return err + } + + err = fn("c", types.NewBooleanValue(f.C)) + if err != nil { + return err + } + + err = fn("d", types.NewDoubleValue(f.D)) + if err != nil { + return err + } + + return nil +} + +func (f *foo) Get(field string) (types.Value, error) { + switch field { + case "a": + return types.NewTextValue(f.A), nil + case "b": + return types.NewBigintValue(f.B), nil + case "c": + return types.NewBooleanValue(f.C), nil + case "d": + return types.NewDoubleValue(f.D), nil + } + + return nil, errors.New("unknown field") +} + +func TestJSONObject(t *testing.T) { + tests := []struct { + name string + o row.Row + expected string + }{ + { + "Flat", + row.NewColumnBuffer(). + Add("name", types.NewTextValue("John")). + Add("age", types.NewIntegerValue(10)). + Add(`"something with" quotes`, types.NewIntegerValue(10)), + `{"\"something with\" quotes":10,"age":10,"name":"John"}`, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + data, err := json.Marshal(test.o) + assert.NoError(t, err) + require.Equal(t, test.expected, string(data)) + assert.NoError(t, err) + }) + } +} + +func BenchmarkObjectIterate(b *testing.B) { + f := foo{ + A: "a", + B: 1000, + C: true, + D: 1e10, + } + + b.Run("Implementation", func(b *testing.B) { + for i := 0; i < b.N; i++ { + f.Iterate(func(string, types.Value) error { + return nil + }) + } + }) + +} diff --git a/internal/row/row.go b/internal/row/row.go new file mode 100644 index 000000000..940e45d92 --- /dev/null +++ b/internal/row/row.go @@ -0,0 +1,468 @@ +package row + +import ( + "fmt" + "math" + "reflect" + "sort" + "strings" + "time" + + "github.com/buger/jsonparser" + "github.com/chaisql/chai/internal/types" + "github.com/cockroachdb/errors" +) + +type Row interface { + // Iterate goes through all the columns of the row and calls the given function + // by passing the column name + Iterate(fn func(column string, value types.Value) error) error + + // Get returns the value of the given column. + // If the column does not exist, it returns ErrColumnNotFound. + Get(name string) (types.Value, error) + + // MarshalJSON encodes the row as JSON. + MarshalJSON() ([]byte, error) +} + +// Length returns the number of columns of a row. +func Length(r Row) (int, error) { + if cb, ok := r.(*ColumnBuffer); ok { + return cb.Len(), nil + } + + var len int + err := r.Iterate(func(_ string, _ types.Value) error { + len++ + return nil + }) + return len, err +} + +func Columns(r Row) ([]string, error) { + var columns []string + err := r.Iterate(func(c string, _ types.Value) error { + columns = append(columns, c) + return nil + }) + return columns, err +} + +// NewFromMap creates an object from a map. +// Due to the way maps are designed, iteration order is not guaranteed. 
+func NewFromMap[T any](m map[string]T) Row { + return mapRow[T](m) +} + +type mapRow[T any] map[string]T + +var _ Row = (*mapRow[any])(nil) + +func (m mapRow[T]) Iterate(fn func(column string, value types.Value) error) error { + for k, v := range m { + v, err := NewValue(v) + if err != nil { + return err + } + + err = fn(k, v) + if err != nil { + return err + } + } + return nil +} + +func (m mapRow[T]) Get(column string) (types.Value, error) { + v, ok := m[column] + if !ok { + return nil, errors.Wrapf(types.ErrColumnNotFound, "%s not found", column) + } + + return NewValue(v) +} + +// MarshalJSON implements the json.Marshaler interface. +func (m mapRow[T]) MarshalJSON() ([]byte, error) { + return MarshalJSON(m) +} + +type reflectMapObject reflect.Value + +var _ Row = (*reflectMapObject)(nil) + +func (m reflectMapObject) Iterate(fn func(column string, value types.Value) error) error { + M := reflect.Value(m) + it := M.MapRange() + + for it.Next() { + v, err := NewValue(it.Value().Interface()) + if err != nil { + return err + } + + err = fn(it.Key().String(), v) + if err != nil { + return err + } + } + return nil +} + +func (m reflectMapObject) Get(column string) (types.Value, error) { + M := reflect.Value(m) + v := M.MapIndex(reflect.ValueOf(column)) + if v == (reflect.Value{}) { + return nil, errors.Wrapf(types.ErrColumnNotFound, "%s not found", column) + } + return NewValue(v.Interface()) +} + +// MarshalJSON implements the json.Marshaler interface. +func (m reflectMapObject) MarshalJSON() ([]byte, error) { + return MarshalJSON(m) +} + +// NewFromStruct creates an object from a struct using reflection. +func NewFromStruct(s any) (Row, error) { + ref := reflect.Indirect(reflect.ValueOf(s)) + + if !ref.IsValid() || ref.Kind() != reflect.Struct { + return nil, errors.New("expected struct or pointer to struct") + } + + return newFromStruct(ref) +} + +func newFromStruct(ref reflect.Value) (Row, error) { + var cb ColumnBuffer + l := ref.NumField() + tp := ref.Type() + + for i := 0; i < l; i++ { + f := ref.Field(i) + if !f.IsValid() { + continue + } + + if f.Kind() == reflect.Ptr { + if f.IsNil() { + continue + } + + f = f.Elem() + } + + sf := tp.Field(i) + + isUnexported := sf.PkgPath != "" + + if sf.Anonymous { + if isUnexported && f.Kind() != reflect.Struct { + continue + } + d, err := newFromStruct(f) + if err != nil { + return nil, err + } + err = d.Iterate(func(column string, value types.Value) error { + cb.Add(column, value) + return nil + }) + if err != nil { + return nil, err + } + continue + } else if isUnexported { + continue + } + + v, err := NewValue(f.Interface()) + if err != nil { + return nil, err + } + + column := strings.ToLower(sf.Name) + if gtag, ok := sf.Tag.Lookup("chai"); ok { + if gtag == "-" { + continue + } + column = gtag + } + + cb.Add(column, v) + } + + return &cb, nil +} + +// NewValue creates a value whose type is infered from x. +func NewValue(x any) (types.Value, error) { + // Attempt exact matches first: + switch v := x.(type) { + case time.Duration: + return types.NewBigintValue(v.Nanoseconds()), nil + case time.Time: + return types.NewTimestampValue(v), nil + case nil: + return types.NewNullValue(), nil + } + + // Compare by kind to detect type definitions over built-in types. 
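+ // For instance (illustrative), a value of `type myInt int` is not matched
+ // by the switch above, but its Kind is reflect.Int, so it is converted to
+ // a BIGINT exactly like a plain int.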
+ v := reflect.ValueOf(x) + switch v.Kind() { + case reflect.Ptr: + if v.IsNil() { + return types.NewNullValue(), nil + } + return NewValue(reflect.Indirect(v).Interface()) + case reflect.Bool: + return types.NewBooleanValue(v.Bool()), nil + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return types.NewBigintValue(v.Int()), nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + x := v.Uint() + if x > math.MaxInt64 { + return nil, fmt.Errorf("cannot convert unsigned integer struct column to int64: %d out of range", x) + } + return types.NewBigintValue(int64(x)), nil + case reflect.Float32, reflect.Float64: + return types.NewDoubleValue(v.Float()), nil + case reflect.String: + return types.NewTextValue(v.String()), nil + case reflect.Slice: + if reflect.TypeOf(v.Interface()).Elem().Kind() == reflect.Uint8 { + return types.NewBlobValue(v.Bytes()), nil + } + return nil, errors.Errorf("unsupported slice type: %T", x) + case reflect.Interface: + if v.IsNil() { + return types.NewNullValue(), nil + } + return NewValue(v.Elem().Interface()) + } + + return nil, NewErrUnsupportedType(x, "") +} + +// NewFromCSV takes a list of headers and columns and returns an row. +// Each header will be assigned as the key and each corresponding column as a text value. +// The length of headers and columns must be the same. +func NewFromCSV(headers, columns []string) Row { + fb := NewColumnBuffer() + fb.ScanCSV(headers, columns) + + return fb +} + +// ColumnBuffer stores a group of columns in memory. It implements the Row interface. +type ColumnBuffer struct { + columns []Column +} + +// NewColumnBuffer creates a ColumnBuffer. +func NewColumnBuffer() *ColumnBuffer { + return new(ColumnBuffer) +} + +// MarshalJSON implements the json.Marshaler interface. +func (cb *ColumnBuffer) MarshalJSON() ([]byte, error) { + return MarshalJSON(cb) +} + +func (cb *ColumnBuffer) UnmarshalJSON(data []byte) error { + return jsonparser.ObjectEach(data, func(key []byte, value []byte, dataType jsonparser.ValueType, offset int) error { + v, err := parseJSONValue(dataType, value) + if err != nil { + return err + } + + cb.Add(string(key), v) + return nil + }) +} + +func (cb *ColumnBuffer) String() string { + s, _ := cb.MarshalJSON() + return string(s) +} + +type Column struct { + Name string + Value types.Value +} + +// Add a field to the buffer. +func (cb *ColumnBuffer) Add(column string, v types.Value) *ColumnBuffer { + cb.columns = append(cb.columns, Column{column, v}) + return cb +} + +// ScanRow copies all the columns of d to the buffer. +func (cb *ColumnBuffer) ScanRow(r Row) error { + return r.Iterate(func(f string, v types.Value) error { + cb.Add(f, v) + return nil + }) +} + +// Get returns a value by column. Returns an error if the column doesn't exists. +func (cb ColumnBuffer) Get(column string) (types.Value, error) { + for _, fv := range cb.columns { + if fv.Name == column { + return fv.Value, nil + } + } + + return nil, errors.Wrapf(types.ErrColumnNotFound, "%s not found", column) +} + +// Set replaces a column if it already exists or creates one if not. +func (cb *ColumnBuffer) Set(column string, v types.Value) error { + _, err := cb.Get(column) + if errors.Is(err, types.ErrColumnNotFound) { + cb.Add(column, v) + return nil + } + if err != nil { + return err + } + + _ = cb.Replace(column, v) + return nil +} + +// Iterate goes through all the columns of the row and calls the given function by passing each one of them. 
+// If the given function returns an error, the iteration stops. +func (cb ColumnBuffer) Iterate(fn func(column string, value types.Value) error) error { + for _, cv := range cb.columns { + err := fn(cv.Name, cv.Value) + if err != nil { + return err + } + } + + return nil +} + +// Delete a column from the buffer. +func (cb *ColumnBuffer) Delete(column string) error { + for i := range cb.columns { + if cb.columns[i].Name == column { + cb.columns = append(cb.columns[0:i], cb.columns[i+1:]...) + return nil + } + } + + return errors.Wrapf(types.ErrColumnNotFound, "%s not found", column) +} + +// Replace the value of the column by v. +func (cb *ColumnBuffer) Replace(column string, v types.Value) error { + for i := range cb.columns { + if cb.columns[i].Name == column { + cb.columns[i].Value = v + return nil + } + } + + return errors.Wrapf(types.ErrColumnNotFound, "%s not found", column) +} + +// Copy every value of the row to the buffer. +func (cb *ColumnBuffer) Copy(r Row) error { + return r.Iterate(func(column string, value types.Value) error { + cb.Add(strings.Clone(column), value) + return nil + }) +} + +// Apply a function to all the values of the buffer. +func (cb *ColumnBuffer) Apply(fn func(column string, v types.Value) (types.Value, error)) error { + var err error + + for i, c := range cb.columns { + cb.columns[i].Value, err = fn(c.Name, c.Value) + if err != nil { + return err + } + } + + return nil +} + +// Len of the buffer. +func (cb ColumnBuffer) Len() int { + return len(cb.columns) +} + +// Reset the buffer. +func (cb *ColumnBuffer) Reset() { + cb.columns = cb.columns[:0] +} + +func (cb *ColumnBuffer) ScanCSV(headers, columns []string) { + for i, h := range headers { + if i >= len(columns) { + break + } + + cb.Add(h, types.NewTextValue(columns[i])) + } +} + +func SortColumns(r Row) Row { + return &sortedRow{r} +} + +type sortedRow struct { + Row +} + +func (s *sortedRow) Iterate(fn func(column string, value types.Value) error) error { + // iterate first to get the list of columns + var columns []string + err := s.Row.Iterate(func(column string, value types.Value) error { + columns = append(columns, column) + return nil + }) + if err != nil { + return err + } + + // sort the fields + sort.Strings(columns) + + // iterate again + for _, f := range columns { + v, err := s.Row.Get(f) + if err != nil { + continue + } + + if err := fn(f, v); err != nil { + return err + } + } + + return nil +} + +func Flatten(r Row) []types.Value { + var values []types.Value + r.Iterate(func(column string, v types.Value) error { + values = append(values, types.NewTextValue(column)) + values = append(values, v) + return nil + }) + return values +} + +func Unflatten(values []types.Value) Row { + cb := NewColumnBuffer() + for i := 0; i < len(values); i += 2 { + cb.Add(types.AsString(values[i]), values[i+1]) + } + return cb +} diff --git a/internal/row/row_test.go b/internal/row/row_test.go new file mode 100644 index 000000000..3c41203de --- /dev/null +++ b/internal/row/row_test.go @@ -0,0 +1,124 @@ +package row_test + +import ( + "testing" + "time" + + "github.com/chaisql/chai/internal/row" + "github.com/chaisql/chai/internal/testutil" + "github.com/chaisql/chai/internal/testutil/assert" + "github.com/chaisql/chai/internal/types" + "github.com/stretchr/testify/require" +) + +func TestNewValue(t *testing.T) { + type myBytes []byte + type myString string + type myUint uint + type myUint16 uint16 + type myUint32 uint32 + type myUint64 uint64 + type myInt int + type myInt8 int8 + type myInt16 int16 + type myInt64 
int64 + type myFloat64 float64 + + now := time.Now() + + tests := []struct { + name string + value, expected interface{} + }{ + {"bytes", []byte("bar"), []byte("bar")}, + {"string", "bar", "bar"}, + {"bool", true, true}, + {"uint", uint(10), int64(10)}, + {"uint8", uint8(10), int64(10)}, + {"uint16", uint16(10), int64(10)}, + {"uint32", uint32(10), int64(10)}, + {"uint64", uint64(10), int64(10)}, + {"int", int(10), int64(10)}, + {"int8", int8(10), int64(10)}, + {"int16", int16(10), int64(10)}, + {"int32", int32(10), int64(10)}, + {"int64", int64(10), int64(10)}, + {"float64", 10.1, float64(10.1)}, + {"null", nil, nil}, + {"time", now, now.UTC()}, + {"bytes", myBytes("bar"), []byte("bar")}, + {"string", myString("bar"), "bar"}, + {"myUint", myUint(10), int64(10)}, + {"myUint16", myUint16(500), int64(500)}, + {"myUint32", myUint32(90000), int64(90000)}, + {"myUint64", myUint64(100), int64(100)}, + {"myInt", myInt(7), int64(7)}, + {"myInt8", myInt8(3), int64(3)}, + {"myInt16", myInt16(500), int64(500)}, + {"myInt64", myInt64(10), int64(10)}, + {"myFloat64", myFloat64(10.1), float64(10.1)}, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + v, err := row.NewValue(test.value) + assert.NoError(t, err) + require.Equal(t, test.expected, v.V()) + }) + } +} + +func TestNewFromMap(t *testing.T) { + m := map[string]interface{}{ + "name": "foo", + "age": 10, + "nilField": nil, + } + + r := row.NewFromMap(m) + + t.Run("Iterate", func(t *testing.T) { + counter := make(map[string]int) + + err := r.Iterate(func(f string, v types.Value) error { + counter[f]++ + switch f { + case "name": + require.Equal(t, m[f], types.AsString(v)) + default: + require.EqualValues(t, m[f], v.V()) + } + return nil + }) + assert.NoError(t, err) + require.Len(t, counter, 3) + require.Equal(t, counter["name"], 1) + require.Equal(t, counter["age"], 1) + require.Equal(t, counter["nilField"], 1) + }) + + t.Run("Get", func(t *testing.T) { + v, err := r.Get("name") + assert.NoError(t, err) + require.Equal(t, types.NewTextValue("foo"), v) + + v, err = r.Get("age") + assert.NoError(t, err) + require.Equal(t, types.NewBigintValue(10), v) + + v, err = r.Get("nilField") + assert.NoError(t, err) + require.Equal(t, types.NewNullValue(), v) + + _, err = r.Get("bar") + require.ErrorIs(t, err, types.ErrColumnNotFound) + }) +} + +func TestNewFromCSV(t *testing.T) { + headers := []string{"a", "b", "c"} + columns := []string{"A", "B", "C"} + + d := row.NewFromCSV(headers, columns) + testutil.RequireJSONEq(t, d, `{"a": "A", "b": "B", "c": "C"}`) +} diff --git a/internal/object/scan.go b/internal/row/scan.go similarity index 52% rename from internal/object/scan.go rename to internal/row/scan.go index 0b733b148..36d6ca5e2 100644 --- a/internal/object/scan.go +++ b/internal/row/scan.go @@ -1,4 +1,4 @@ -package object +package row import ( "bytes" @@ -11,16 +11,33 @@ import ( "github.com/cockroachdb/errors" ) -// A Scanner can iterate over an object and scan all the fields. -type Scanner interface { - ScanObject(types.Object) error +// ErrUnsupportedType is used to skip struct or array fields that are not supported. +type ErrUnsupportedType struct { + Value interface{} + Msg string +} + +func NewErrUnsupportedType(value any, msg string) error { + return errors.WithStack(&ErrUnsupportedType{ + Value: value, + Msg: msg, + }) +} + +func (e *ErrUnsupportedType) Error() string { + return fmt.Sprintf("unsupported type %T. %s", e.Value, e.Msg) +} + +// A RowScanner can iterate over a row and scan all the columns. 
+type RowScanner interface { + ScanRow(Row) error } // Scan each field of the object into the given variables. -func Scan(d types.Object, targets ...interface{}) error { +func Scan(r Row, targets ...any) error { var i int - return d.Iterate(func(f string, v types.Value) error { + return r.Iterate(func(c string, v types.Value) error { if i >= len(targets) { return errors.New("target list too small") } @@ -30,7 +47,7 @@ func Scan(d types.Object, targets ...interface{}) error { ref := reflect.ValueOf(target) if !ref.IsValid() { - return &ErrUnsupportedType{target, fmt.Sprintf("Parameter %d is not valid", i)} + return NewErrUnsupportedType(target, fmt.Sprintf("Parameter %d is not valid", i)) } return scanValue(v, ref) @@ -39,14 +56,18 @@ func Scan(d types.Object, targets ...interface{}) error { // StructScan scans d into t. t is expected to be a pointer to a struct. // -// By default, each struct field name is lowercased and the object's GetByField method +// By default, each struct field name is lowercased and the row's Get method // is called with that name. If there is a match, the value is converted to the struct // field type when possible, otherwise an error is returned. // The decoding of each struct field can be customized by the format string stored // under the "chai" key stored in the struct field's tag. // The content of the format string is used instead of the struct field name and passed -// to the GetByField method. -func StructScan(d types.Object, t interface{}) error { +// to the Get method. +func StructScan(r Row, t any) error { + if cb, ok := t.(*ColumnBuffer); ok { + return cb.Copy(r) + } + ref := reflect.ValueOf(t) if !ref.IsValid() || ref.Kind() != reflect.Ptr { @@ -61,12 +82,12 @@ func StructScan(d types.Object, t interface{}) error { ref.Set(reflect.New(ref.Type().Elem())) } - return structScan(d, ref) + return structScan(r, ref) } -func structScan(d types.Object, ref reflect.Value) error { - if ref.Type().Implements(reflect.TypeOf((*Scanner)(nil)).Elem()) { - return ref.Interface().(Scanner).ScanObject(d) +func structScan(r Row, ref reflect.Value) error { + if ref.Type().Implements(reflect.TypeOf((*RowScanner)(nil)).Elem()) { + return ref.Interface().(RowScanner).ScanRow(r) } sref := reflect.Indirect(ref) @@ -76,7 +97,7 @@ func structScan(d types.Object, ref reflect.Value) error { f := sref.Field(i) sf := stp.Field(i) if sf.Anonymous { - err := structScan(d, f) + err := structScan(r, f) if err != nil { return err } @@ -92,8 +113,8 @@ func structScan(d types.Object, ref reflect.Value) error { } else { name = strings.ToLower(sf.Name) } - v, err := d.GetByField(name) - if errors.Is(err, types.ErrFieldNotFound) { + v, err := r.Get(name) + if errors.Is(err, types.ErrColumnNotFound) { v = types.NewNullValue() } else if err != nil { return err @@ -107,87 +128,11 @@ func structScan(d types.Object, ref reflect.Value) error { return nil } -// SliceScan scans an array into a slice or fixed size array. t must be a pointer -// to a valid slice or array. -// -// It t is a slice pointer and its capacity is too low, a new slice will be allocated. -// Otherwise, its length is set to 0 so that its content is overwritten. -// -// If t is an array pointer, its capacity must be bigger than the length of a, otherwise an error is -// returned. 
-func SliceScan(a types.Array, t interface{}) error { - return sliceScan(a, reflect.ValueOf(t)) -} - -func sliceScan(a types.Array, ref reflect.Value) error { - if !ref.IsValid() || ref.Kind() != reflect.Ptr || ref.IsNil() { - return errors.New("target must be pointer to a slice or array") - } - - tp := ref.Type() - k := tp.Elem().Kind() - if k != reflect.Array && k != reflect.Slice { - return errors.New("target must be pointer to a slice or array") - } - - al, err := ArrayLength(a) - if err != nil { - return err - } - - sref := reflect.Indirect(ref) - - // if array, make sure it is big enough - if k == reflect.Array && sref.Len() < al { - return errors.New("array length too small") - } - - // if slice, reduce its length to 0 to overwrite the buffer - if k == reflect.Slice { - if sref.Cap() < al { - sref.Set(reflect.MakeSlice(tp.Elem(), 0, al)) - } else { - sref.SetLen(0) - } - } - - stp := sref.Type() - - err = a.Iterate(func(i int, v types.Value) error { - if k == reflect.Array { - err := scanValue(v, sref.Index(i).Addr()) - if err != nil { - return err - } - } else { - newV := reflect.New(stp.Elem()) - - err := scanValue(v, newV) - if err != nil { - return err - } - - sref = reflect.Append(sref, reflect.Indirect(newV)) - } - - return nil - }) - if err != nil { - return err - } - - if k == reflect.Slice { - ref.Elem().Set(sref) - } - - return nil -} - -// MapScan decodes the object into a map. -func MapScan(d types.Object, t any) error { +// MapScan decodes the row into a map. +func MapScan(r Row, t any) error { ref := reflect.ValueOf(t) if !ref.IsValid() { - return &ErrUnsupportedType{ref, "t must be a valid reference"} + return NewErrUnsupportedType(ref, "t must be a valid reference") } if ref.Kind() == reflect.Ptr { @@ -195,22 +140,22 @@ func MapScan(d types.Object, t any) error { } if ref.Kind() != reflect.Map { - return &ErrUnsupportedType{ref, "t is not a map"} + return NewErrUnsupportedType(ref, "t is not a map") } - return mapScan(d, ref) + return mapScan(r, ref) } -func mapScan(d types.Object, ref reflect.Value) error { +func mapScan(r Row, ref reflect.Value) error { if ref.Type().Key().Kind() != reflect.String { - return &ErrUnsupportedType{ref, "map key must be a string"} + return NewErrUnsupportedType(ref, "map key must be a string") } if ref.IsNil() { ref.Set(reflect.MakeMap(ref.Type())) } - return d.Iterate(func(f string, v types.Value) error { + return r.Iterate(func(f string, v types.Value) error { newV := reflect.New(ref.Type().Elem()) err := scanValue(v, newV) @@ -230,7 +175,7 @@ func ScanValue(v types.Value, t any) error { func scanValue(v types.Value, ref reflect.Value) error { if !ref.IsValid() { - return &ErrUnsupportedType{ref, "parameter is not a valid reference"} + return NewErrUnsupportedType(ref, "parameter is not a valid reference") } if v.Type() == types.TypeNull { @@ -276,21 +221,21 @@ func scanValue(v types.Value, ref reflect.Value) error { switch ref.Kind() { case reflect.String: - v, err := CastAsText(v) + v, err := v.CastAs(types.TypeText) if err != nil { return err } ref.SetString(types.AsString(v)) return nil case reflect.Bool: - v, err := CastAsBool(v) + v, err := v.CastAs(types.TypeBoolean) if err != nil { return err } ref.SetBool(types.AsBool(v)) return nil case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - v, err := CastAsInteger(v) + v, err := v.CastAs(types.TypeBigint) if err != nil { return err } @@ -301,14 +246,14 @@ func scanValue(v types.Value, ref reflect.Value) error { ref.SetUint(uint64(x)) return nil case 
reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - v, err := CastAsInteger(v) + v, err := v.CastAs(types.TypeBigint) if err != nil { return err } ref.SetInt(types.AsInt64(v)) return nil case reflect.Float32, reflect.Float64: - v, err := CastAsDouble(v) + v, err := v.CastAs(types.TypeDouble) if err != nil { return err } @@ -319,20 +264,6 @@ func scanValue(v types.Value, ref reflect.Value) error { return scanValue(v, ref.Elem()) } switch v.Type() { - case types.TypeObject: - m := make(map[string]any) - vm := reflect.ValueOf(m) - ref.Set(vm) - return mapScan(types.AsObject(v), vm) - case types.TypeArray: - var s []interface{} - vs := reflect.ValueOf(&s) - err := sliceScan(types.AsArray(v), vs) - if err != nil { - return err - } - ref.Set(vs.Elem()) - return nil case types.TypeText: // copy the string to avoid // keeping a reference to the underlying buffer @@ -351,34 +282,6 @@ func scanValue(v types.Value, ref reflect.Value) error { ref.Set(reflect.ValueOf(v.V())) return nil - } - - // test with supported stdlib types - switch ref.Type().String() { - case "time.Time": - switch v.Type() { - case types.TypeText: - parsed, err := time.Parse(time.RFC3339Nano, types.AsString(v)) - if err != nil { - return err - } - - ref.Set(reflect.ValueOf(parsed)) - return nil - case types.TypeTimestamp: - ref.Set(reflect.ValueOf(types.AsTime(v))) - return nil - } - } - - switch ref.Kind() { - case reflect.Struct: - v, err := CastAsObject(v) - if err != nil { - return err - } - - return structScan(types.AsObject(v), ref) case reflect.Slice: if ref.Type().Elem().Kind() == reflect.Uint8 { if v.Type() != types.TypeText && v.Type() != types.TypeBlob { @@ -391,12 +294,7 @@ func scanValue(v types.Value, ref reflect.Value) error { } return nil } - v, err := CastAsArray(v) - if err != nil { - return err - } - - return sliceScan(types.AsArray(v), ref.Addr()) + return NewErrUnsupportedType(ref.Interface(), "Invalid type") case reflect.Array: if ref.Type().Elem().Kind() == reflect.Uint8 { if v.Type() != types.TypeText && v.Type() != types.TypeBlob { @@ -405,26 +303,32 @@ func scanValue(v types.Value, ref reflect.Value) error { reflect.Copy(ref, reflect.ValueOf(v.V())) return nil } - v, err := CastAsArray(v) - if err != nil { - return err - } + return NewErrUnsupportedType(ref.Interface(), "Invalid type") + } - return sliceScan(types.AsArray(v), ref.Addr()) - case reflect.Map: - v, err := CastAsObject(v) - if err != nil { - return err - } + // test with supported stdlib types + switch ref.Type().String() { + case "time.Time": + switch v.Type() { + case types.TypeText: + parsed, err := time.Parse(time.RFC3339Nano, types.AsString(v)) + if err != nil { + return err + } - return mapScan(types.AsObject(v), ref) + ref.Set(reflect.ValueOf(parsed)) + return nil + case types.TypeTimestamp: + ref.Set(reflect.ValueOf(types.AsTime(v))) + return nil + } } - return &ErrUnsupportedType{ref, "Invalid type"} + return NewErrUnsupportedType(ref.Interface(), "Invalid type") } // ScanRow scans a row into dest which must be either a struct pointer, a map or a map pointer. 
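+// A typical use, as a sketch (illustrative; struct fields follow the same
+// lowercase-name / `chai` tag rules as StructScan):
+//
+//	type user struct {
+//		Name string
+//		Age  int64
+//	}
+//	var u user
+//	err := ScanRow(r, &u) // r is a Row exposing "name" and "age" columns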
-func ScanRow(d types.Object, t interface{}) error { +func ScanRow(r Row, t any) error { ref := reflect.ValueOf(t) if !ref.IsValid() { @@ -433,67 +337,20 @@ func ScanRow(d types.Object, t interface{}) error { switch reflect.Indirect(ref).Kind() { case reflect.Map: - return mapScan(d, ref) + return mapScan(r, ref) case reflect.Struct: if ref.IsNil() { ref.Set(reflect.New(ref.Type().Elem())) } - return structScan(d, ref) + return structScan(r, ref) default: return errors.New("target must be a either a pointer to struct, a map or a map pointer") } } -// ScanIterator scans a row iterator into a slice or fixed size array. t must be a pointer -// to a valid slice or array. -// -// It t is a slice pointer and its capacity is too low, a new slice will be allocated. -// Otherwise, its length is set to 0 so that its content is overwritten. -// -// If t is an array pointer, its capacity must be bigger than the length of a, otherwise an error is -// returned. -func ScanIterator(it Iterator, t interface{}) error { - a := iteratorArray{it: it} - return SliceScan(&a, t) -} - -type iteratorArray struct { - it Iterator -} - -func (it *iteratorArray) Iterate(fn func(i int, value types.Value) error) error { - count := 0 - return it.it.Iterate(func(d types.Object) error { - err := fn(count, types.NewObjectValue(d)) - if err != nil { - return err - } - count++ - return nil - }) -} - -func (it *iteratorArray) GetByIndex(i int) (types.Value, error) { - panic("not implemented") -} - -func (it *iteratorArray) MarshalJSON() ([]byte, error) { - return MarshalJSONArray(it) -} - -// ScanField scans a single field into dest. -func ScanField(d types.Object, field string, dest interface{}) error { - v, err := d.GetByField(field) - if err != nil { - return err - } - - return ScanValue(v, dest) -} - -// ScanPath scans a single path into dest. -func ScanPath(d types.Object, path Path, dest interface{}) error { - v, err := path.GetValueFromObject(d) +// ScanColumn scans a single column into dest. +func ScanColumn(r Row, column string, dest any) error { + v, err := r.Get(column) if err != nil { return err } diff --git a/internal/row/scan_test.go b/internal/row/scan_test.go new file mode 100644 index 000000000..b086ee0cc --- /dev/null +++ b/internal/row/scan_test.go @@ -0,0 +1,124 @@ +package row_test + +import ( + "testing" + "time" + + "github.com/chaisql/chai/internal/row" + "github.com/chaisql/chai/internal/testutil/assert" + "github.com/chaisql/chai/internal/types" + "github.com/stretchr/testify/require" +) + +func TestScan(t *testing.T) { + now := time.Now() + + r := row.NewColumnBuffer(). + Add("a", types.NewBlobValue([]byte("foo"))). + Add("b", types.NewTextValue("bar")). + Add("c", types.NewBooleanValue(true)). + Add("d", types.NewIntegerValue(10)). + Add("e", types.NewIntegerValue(10)). + Add("f", types.NewIntegerValue(10)). + Add("g", types.NewIntegerValue(10)). + Add("h", types.NewIntegerValue(10)). + Add("i", types.NewDoubleValue(10.5)). + Add("j", types.NewNullValue()). + Add("k", types.NewTextValue(now.Format(time.RFC3339Nano))). + Add("l", types.NewBlobValue([]byte{1, 2, 3, 4})). 
+ Add("m", types.NewTimestampValue(now)) + + var a []byte + var b string + var c bool + var d int + var e int8 + var f int16 + var g int32 + var h int64 + var i float64 + var j int = 1 + var k time.Time + var l [4]uint8 + var m time.Time + + err := row.Scan(r, &a, &b, &c, &d, &e, &f, &g, &h, &i, &j, &k, &l, &m) + assert.NoError(t, err) + require.Equal(t, a, []byte("foo")) + require.Equal(t, b, "bar") + require.Equal(t, c, true) + require.Equal(t, d, int(10)) + require.Equal(t, e, int8(10)) + require.Equal(t, f, int16(10)) + require.Equal(t, g, int32(10)) + require.Equal(t, h, int64(10)) + require.Equal(t, i, float64(10.5)) + require.Equal(t, 0, j) + require.Equal(t, now.Format(time.RFC3339Nano), k.Format(time.RFC3339Nano)) + require.Equal(t, [4]uint8{1, 2, 3, 4}, l) + require.Equal(t, now.UTC(), m) + + t.Run("Map", func(t *testing.T) { + m := make(map[string]interface{}) + err := row.MapScan(r, m) + assert.NoError(t, err) + require.Len(t, m, 13) + }) + + t.Run("MapPtr", func(t *testing.T) { + var m map[string]interface{} + err := row.MapScan(r, &m) + assert.NoError(t, err) + require.Len(t, m, 13) + }) + + t.Run("pointers", func(t *testing.T) { + type bar struct { + A *int + } + + b := bar{} + + d := row.NewColumnBuffer().Add("a", types.NewIntegerValue(10)) + err := row.StructScan(d, &b) + assert.NoError(t, err) + + a := 10 + require.Equal(t, bar{A: &a}, b) + }) + + t.Run("NULL with pointers", func(t *testing.T) { + type bar struct { + A *int + B *string + C *int + } + + c := 10 + b := bar{ + C: &c, + } + + d := row.NewColumnBuffer().Add("a", types.NewNullValue()) + err := row.StructScan(d, &b) + assert.NoError(t, err) + require.Equal(t, bar{}, b) + }) + + t.Run("Incompatible type", func(t *testing.T) { + var a struct { + A float64 + } + + d := row.NewColumnBuffer().Add("a", types.NewTimestampValue(time.Now())) + err := row.StructScan(d, &a) + assert.Error(t, err) + }) + + t.Run("Pointer not to struct", func(t *testing.T) { + var b int + d := row.NewColumnBuffer().Add("a", types.NewIntegerValue(10)) + err := row.StructScan(d, &b) + assert.Error(t, err) + }) +} diff --git a/internal/sql/parser/alter.go b/internal/sql/parser/alter.go index 5ee6b4e59..463ebbaea 100644 --- a/internal/sql/parser/alter.go +++ b/internal/sql/parser/alter.go @@ -12,7 +12,7 @@ func (p *Parser) parseAlterTableRenameStatement(tableName string) (_ statement.A stmt.TableName = tableName // Parse "TO". - if err := p.parseTokens(scanner.TO); err != nil { + if err := p.ParseTokens(scanner.TO); err != nil { return stmt, err } @@ -25,35 +25,35 @@ func (p *Parser) parseAlterTableRenameStatement(tableName string) (_ statement.A return stmt, nil } -func (p *Parser) parseAlterTableAddFieldStatement(tableName string) (*statement.AlterTableAddColumnStmt, error) { +func (p *Parser) parseAlterTableAddColumnStatement(tableName string) (*statement.AlterTableAddColumnStmt, error) { var stmt statement.AlterTableAddColumnStmt stmt.TableName = tableName - // Parse "FIELD". - if err := p.parseTokens(scanner.COLUMN); err != nil { + // Parse "COLUMN". + if err := p.ParseTokens(scanner.COLUMN); err != nil { return nil, err } - // Parse new field definition. + // Parse new column definition. 
var err error - stmt.FieldConstraint, stmt.TableConstraints, err = p.parseFieldDefinition(nil) + stmt.ColumnConstraint, stmt.TableConstraints, err = p.parseColumnDefinition() if err != nil { return nil, err } - if stmt.FieldConstraint.IsEmpty() { - return nil, &ParseError{Message: "cannot add a field with no constraint"} + if stmt.ColumnConstraint.IsEmpty() { + return nil, &ParseError{Message: "cannot add a column with no constraint"} } return &stmt, nil } -// parseAlterStatement parses a Alter query string and returns a Statement AST object. +// parseAlterStatement parses a Alter query string and returns a Statement AST row. func (p *Parser) parseAlterStatement() (statement.Statement, error) { var err error // Parse "TABLE". - if err := p.parseTokens(scanner.ALTER, scanner.TABLE); err != nil { + if err := p.ParseTokens(scanner.ALTER, scanner.TABLE); err != nil { return nil, err } @@ -70,7 +70,7 @@ func (p *Parser) parseAlterStatement() (statement.Statement, error) { case scanner.RENAME: return p.parseAlterTableRenameStatement(tableName) case scanner.ADD_KEYWORD: - return p.parseAlterTableAddFieldStatement(tableName) + return p.parseAlterTableAddColumnStatement(tableName) } return nil, newParseError(scanner.Tokstr(tok, lit), []string{"ADD", "RENAME"}, pos) diff --git a/internal/sql/parser/alter_test.go b/internal/sql/parser/alter_test.go index 6fa646f37..b8244881b 100644 --- a/internal/sql/parser/alter_test.go +++ b/internal/sql/parser/alter_test.go @@ -5,7 +5,6 @@ import ( "github.com/chaisql/chai/internal/database" "github.com/chaisql/chai/internal/expr" - "github.com/chaisql/chai/internal/object" "github.com/chaisql/chai/internal/query/statement" "github.com/chaisql/chai/internal/sql/parser" "github.com/chaisql/chai/internal/testutil/assert" @@ -40,58 +39,53 @@ func TestParserAlterTable(t *testing.T) { } } -func TestParserAlterTableAddField(t *testing.T) { +func TestParserAlterTableAddColumn(t *testing.T) { tests := []struct { name string s string expected statement.Statement errored bool }{ - {"Basic", "ALTER TABLE foo ADD COLUMN bar", &statement.AlterTableAddColumnStmt{ - TableName: "foo", - FieldConstraint: &database.FieldConstraint{ - Field: "bar", - Type: types.TypeAny, - }, - }, false}, + {"Without type", "ALTER TABLE foo ADD COLUMN bar", nil, true}, {"With type", "ALTER TABLE foo ADD COLUMN bar integer", &statement.AlterTableAddColumnStmt{ TableName: "foo", - FieldConstraint: &database.FieldConstraint{ - Field: "bar", - Type: types.TypeInteger, + ColumnConstraint: &database.ColumnConstraint{ + Column: "bar", + Type: types.TypeInteger, }, }, false}, - {"With not null", "ALTER TABLE foo ADD COLUMN bar NOT NULL", &statement.AlterTableAddColumnStmt{ + {"With not null", "ALTER TABLE foo ADD COLUMN bar TEXT NOT NULL", &statement.AlterTableAddColumnStmt{ TableName: "foo", - FieldConstraint: &database.FieldConstraint{ - Field: "bar", + ColumnConstraint: &database.ColumnConstraint{ + Column: "bar", + Type: types.TypeText, IsNotNull: true, }, }, false}, - {"With primary key", "ALTER TABLE foo ADD COLUMN bar PRIMARY KEY", &statement.AlterTableAddColumnStmt{ + {"With primary key", "ALTER TABLE foo ADD COLUMN bar TEXT PRIMARY KEY", &statement.AlterTableAddColumnStmt{ TableName: "foo", - FieldConstraint: &database.FieldConstraint{ - Field: "bar", - Type: types.TypeAny, + ColumnConstraint: &database.ColumnConstraint{ + Column: "bar", + Type: types.TypeText, }, TableConstraints: database.TableConstraints{ &database.TableConstraint{ - Paths: object.Paths{object.NewPath("bar")}, + Columns: 
[]string{"bar"}, PrimaryKey: true, }, }, }, false}, {"With multiple constraints", "ALTER TABLE foo ADD COLUMN bar integer NOT NULL DEFAULT 0", &statement.AlterTableAddColumnStmt{ TableName: "foo", - FieldConstraint: &database.FieldConstraint{ - Field: "bar", + ColumnConstraint: &database.ColumnConstraint{ + Column: "bar", Type: types.TypeInteger, IsNotNull: true, DefaultValue: expr.Constraint(expr.LiteralValue{Value: types.NewIntegerValue(0)}), }, }, false}, - {"With error / missing FIELD keyword", "ALTER TABLE foo ADD bar", nil, true}, - {"With error / missing field name", "ALTER TABLE foo ADD COLUMN", nil, true}, + {"With error / missing COLUMN keyword", "ALTER TABLE foo ADD bar", nil, true}, + {"With error / missing column name", "ALTER TABLE foo ADD COLUMN", nil, true}, } for _, test := range tests { diff --git a/internal/sql/parser/create.go b/internal/sql/parser/create.go index 67604551f..8ed6351e6 100644 --- a/internal/sql/parser/create.go +++ b/internal/sql/parser/create.go @@ -6,17 +6,15 @@ import ( "github.com/chaisql/chai/internal/database" "github.com/chaisql/chai/internal/expr" - "github.com/chaisql/chai/internal/object" "github.com/chaisql/chai/internal/query/statement" "github.com/chaisql/chai/internal/sql/scanner" "github.com/chaisql/chai/internal/tree" - "github.com/chaisql/chai/internal/types" ) -// parseCreateStatement parses a create string and returns a Statement AST object. +// parseCreateStatement parses a create string and returns a Statement AST row. func (p *Parser) parseCreateStatement() (statement.Statement, error) { // Parse "CREATE". - if err := p.parseTokens(scanner.CREATE); err != nil { + if err := p.ParseTokens(scanner.CREATE); err != nil { return nil, err } @@ -39,7 +37,7 @@ func (p *Parser) parseCreateStatement() (statement.Statement, error) { return nil, newParseError(scanner.Tokstr(tok, lit), []string{"TABLE", "INDEX", "SEQUENCE"}, pos) } -// parseCreateTableStatement parses a create table string and returns a Statement AST object. +// parseCreateTableStatement parses a create table string and returns a Statement AST row. // This function assumes the CREATE TABLE tokens have already been consumed. func (p *Parser) parseCreateTableStatement() (*statement.CreateTableStmt, error) { var stmt statement.CreateTableStmt @@ -63,38 +61,27 @@ func (p *Parser) parseCreateTableStatement() (*statement.CreateTableStmt, error) return nil, err } - if len(stmt.Info.FieldConstraints.Ordered) == 0 { - stmt.Info.FieldConstraints.AllowExtraFields = true - } return &stmt, err } func (p *Parser) parseConstraints(stmt *statement.CreateTableStmt) error { // Parse ( token. - if ok, err := p.parseOptional(scanner.LPAREN); !ok || err != nil { - return err + tok, pos, lit := p.ScanIgnoreWhitespace() + if tok != scanner.LPAREN { + return newParseError(scanner.Tokstr(tok, lit), []string{"("}, pos) } // if set to true, the parser must no longer - // expect field definitions, but only table constraints. + // expect column definitions, but only table constraints. var parsingTableConstraints bool - stmt.Info.FieldConstraints, _ = database.NewFieldConstraints() + stmt.Info.ColumnConstraints, _ = database.NewColumnConstraints() var allTableConstraints []*database.TableConstraint // Parse constraints. for { - // start with the ellipsis token. - // if found, stop parsing constraints, as it should be the last one. 
- tok, _, _ := p.ScanIgnoreWhitespace() - if tok == scanner.ELLIPSIS { - stmt.Info.FieldConstraints.AllowExtraFields = true - break - } - p.Unscan() - - // then we check if it is a table constraint, + // check if it is a table constraint, // as it's easier to determine tc, err := p.parseTableConstraint(stmt) if err != nil { @@ -112,14 +99,14 @@ func (p *Parser) parseConstraints(stmt *statement.CreateTableStmt) error { allTableConstraints = append(allTableConstraints, tc) } - // if set to false, we are still parsing field definitions + // if set to false, we are still parsing column definitions if !parsingTableConstraints { - fc, tcs, err := p.parseFieldDefinition(object.Path{}) + cc, tcs, err := p.parseColumnDefinition() if err != nil { return err } - err = stmt.Info.AddFieldConstraint(fc) + err = stmt.Info.AddColumnConstraint(cc) if err != nil { return err } @@ -134,7 +121,7 @@ func (p *Parser) parseConstraints(stmt *statement.CreateTableStmt) error { } // Parse required ) token. - if err := p.parseTokens(scanner.RPAREN); err != nil { + if err := p.ParseTokens(scanner.RPAREN); err != nil { return err } @@ -149,59 +136,36 @@ func (p *Parser) parseConstraints(stmt *statement.CreateTableStmt) error { return nil } -func (p *Parser) parseFieldDefinition(parent object.Path) (*database.FieldConstraint, []*database.TableConstraint, error) { +func (p *Parser) parseColumnDefinition() (*database.ColumnConstraint, []*database.TableConstraint, error) { var err error - var fc database.FieldConstraint + var cc database.ColumnConstraint - fc.Field, err = p.parseIdent() + cc.Column, err = p.parseIdent() if err != nil { return nil, nil, err } - fc.Type, err = p.parseType() + cc.Type, err = p.parseType() if err != nil { - p.Unscan() + return nil, nil, err } - path := parent.ExtendField(fc.Field) - var tcs []*database.TableConstraint - if fc.Type.IsAny() || fc.Type == types.TypeObject { - anon, nestedTCs, err := p.parseObjectDefinition(path) - if err != nil { - return nil, nil, err - } - if anon != nil { - fc.Type = types.TypeObject - fc.AnonymousType = anon - } else if fc.Type == types.TypeObject { - // if the field constraint is an object but doesn't have any constraint, - // its AllowExtraFields is set to true - // i.e CREATE TABLE foo(a OBJECT) -> CREATE TABLE foo(a OBJECT (...)) - fc.AnonymousType = &database.AnonymousType{} - fc.AnonymousType.FieldConstraints.AllowExtraFields = true - } - - if len(nestedTCs) != 0 { - tcs = append(tcs, nestedTCs...) 
- } - } - LOOP: for { tok, pos, lit := p.ScanIgnoreWhitespace() switch tok { case scanner.PRIMARY: // Parse "KEY" - if err := p.parseTokens(scanner.KEY); err != nil { + if err := p.ParseTokens(scanner.KEY); err != nil { return nil, nil, err } tc := database.TableConstraint{ PrimaryKey: true, - Paths: object.Paths{path}, + Columns: []string{cc.Column}, } // if ASC is set, we ignore it, otherwise we check for DESC @@ -222,19 +186,19 @@ LOOP: tcs = append(tcs, &tc) case scanner.NOT: // Parse "NULL" - if err := p.parseTokens(scanner.NULL); err != nil { + if err := p.ParseTokens(scanner.NULL); err != nil { return nil, nil, err } // if it's already not null we return an error - if fc.IsNotNull { + if cc.IsNotNull { return nil, nil, newParseError(scanner.Tokstr(tok, lit), []string{"CONSTRAINT", ")"}, pos) } - fc.IsNotNull = true + cc.IsNotNull = true case scanner.DEFAULT: // if it has already a default value we return an error - if fc.DefaultValue != nil { + if cc.DefaultValue != nil { return nil, nil, newParseError(scanner.Tokstr(tok, lit), []string{"CONSTRAINT", ")"}, pos) } @@ -275,7 +239,7 @@ LOOP: return nil, nil, err } - fc.DefaultValue = expr.Constraint(e) + cc.DefaultValue = expr.Constraint(e) if withParentheses { _, err = p.parseOptional(scanner.RPAREN) @@ -285,18 +249,18 @@ LOOP: } case scanner.UNIQUE: tcs = append(tcs, &database.TableConstraint{ - Unique: true, - Paths: object.Paths{path}, + Unique: true, + Columns: []string{cc.Column}, }) case scanner.CHECK: - e, paths, err := p.parseCheckConstraint() + e, cols, err := p.parseCheckConstraint() if err != nil { return nil, nil, err } tcs = append(tcs, &database.TableConstraint{ - Check: expr.Constraint(e), - Paths: paths, + Check: expr.Constraint(e), + Columns: cols, }) default: p.Unscan() @@ -304,53 +268,7 @@ LOOP: } } - return &fc, tcs, nil -} - -func (p *Parser) parseObjectDefinition(parent object.Path) (*database.AnonymousType, []*database.TableConstraint, error) { - err := p.parseTokens(scanner.LPAREN) - if err != nil { - p.Unscan() - return nil, nil, nil - } - - var anon database.AnonymousType - var nestedTcs []*database.TableConstraint - - for { - // start with the ellipsis token. - // if found, stop parsing constraints, as it should be the last one. - tok, _, _ := p.ScanIgnoreWhitespace() - if tok == scanner.ELLIPSIS { - anon.FieldConstraints.AllowExtraFields = true - break - } - p.Unscan() - - fc, tcs, err := p.parseFieldDefinition(parent) - if err != nil { - return nil, nil, err - } - - err = anon.AddFieldConstraint(fc) - if err != nil { - return nil, nil, err - } - - nestedTcs = append(nestedTcs, tcs...) 
- - if tok, _, _ := p.ScanIgnoreWhitespace(); tok != scanner.COMMA { - p.Unscan() - break - } - } - - err = p.parseTokens(scanner.RPAREN) - if err != nil { - return nil, nil, err - } - - return &anon, nestedTcs, nil + return &cc, tcs, nil } func (p *Parser) parseTableConstraint(stmt *statement.CreateTableStmt) (*database.TableConstraint, error) { @@ -376,41 +294,41 @@ func (p *Parser) parseTableConstraint(stmt *statement.CreateTableStmt) (*databas switch tok { case scanner.PRIMARY: // Parse "KEY (" - err = p.parseTokens(scanner.KEY) + err = p.ParseTokens(scanner.KEY) if err != nil { return nil, err } tc.PrimaryKey = true - tc.Paths, order, err = p.parsePathList() + tc.Columns, order, err = p.parseColumnList() if err != nil { return nil, err } - if len(tc.Paths) == 0 { + if len(tc.Columns) == 0 { tok, pos, lit := p.ScanIgnoreWhitespace() return nil, newParseError(scanner.Tokstr(tok, lit), []string{"PATHS"}, pos) } tc.SortOrder = order case scanner.UNIQUE: tc.Unique = true - tc.Paths, order, err = p.parsePathList() + tc.Columns, order, err = p.parseColumnList() if err != nil { return nil, err } - if len(tc.Paths) == 0 { + if len(tc.Columns) == 0 { tok, pos, lit := p.ScanIgnoreWhitespace() return nil, newParseError(scanner.Tokstr(tok, lit), []string{"PATHS"}, pos) } tc.SortOrder = order case scanner.CHECK: - e, paths, err := p.parseCheckConstraint() + e, columns, err := p.parseCheckConstraint() if err != nil { return nil, err } tc.Check = expr.Constraint(e) - tc.Paths = paths + tc.Columns = columns default: if requiresTc { return nil, newParseError(scanner.Tokstr(tok, lit), []string{"PRIMARY", "UNIQUE", "CHECK"}, pos) @@ -423,7 +341,7 @@ func (p *Parser) parseTableConstraint(stmt *statement.CreateTableStmt) (*databas return &tc, nil } -// parseCreateIndexStatement parses a create index string and returns a Statement AST object. +// parseCreateIndexStatement parses a create index string and returns a Statement AST row. // This function assumes the CREATE INDEX or CREATE UNIQUE INDEX tokens have already been consumed. func (p *Parser) parseCreateIndexStatement(unique bool) (*statement.CreateIndexStmt, error) { var err error @@ -448,7 +366,7 @@ func (p *Parser) parseCreateIndexStatement(unique bool) (*statement.CreateIndexS } // Parse "ON" - if err := p.parseTokens(scanner.ON); err != nil { + if err := p.ParseTokens(scanner.ON); err != nil { return nil, err } @@ -458,16 +376,16 @@ func (p *Parser) parseCreateIndexStatement(unique bool) (*statement.CreateIndexS return nil, err } - paths, order, err := p.parsePathList() + columns, order, err := p.parseColumnList() if err != nil { return nil, err } - if len(paths) == 0 { + if len(columns) == 0 { tok, pos, lit := p.ScanIgnoreWhitespace() return nil, newParseError(scanner.Tokstr(tok, lit), []string{"("}, pos) } - stmt.Info.Paths = paths + stmt.Info.Columns = columns stmt.Info.KeySortOrder = order return &stmt, nil @@ -703,9 +621,9 @@ func (p *Parser) parseCreateSequenceStatement() (*statement.CreateSequenceStmt, // parseCheckConstraint parses a check constraint. // it assumes the CHECK token has already been parsed. 
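// Illustrative sketch, not part of the patch: parseCheckConstraint below walks
// the parsed expression and records every referenced column once, as plain
// strings rather than object paths. For example, for
//
//	CHECK (price > 0 AND price < max_price)
//
// the returned column list would be []string{"price", "max_price"}, with the
// duplicate reference to price collapsed.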
-func (p *Parser) parseCheckConstraint() (expr.Expr, []object.Path, error) { +func (p *Parser) parseCheckConstraint() (expr.Expr, []string, error) { // Parse "(" - err := p.parseTokens(scanner.LPAREN) + err := p.ParseTokens(scanner.LPAREN) if err != nil { return nil, nil, err } @@ -715,22 +633,22 @@ func (p *Parser) parseCheckConstraint() (expr.Expr, []object.Path, error) { return nil, nil, err } - var paths []object.Path + var columns []string // extract all the paths from the expression expr.Walk(e, func(e expr.Expr) bool { switch t := e.(type) { - case expr.Path: - pt := object.Path(t) + case expr.Column: + scol := string(t) // ensure that the path is not already in the list found := false - for _, p := range paths { - if p.IsEqual(pt) { + for _, c := range columns { + if c == scol { found = true break } } if !found { - paths = append(paths, object.Path(t)) + columns = append(columns, scol) } } @@ -738,10 +656,10 @@ func (p *Parser) parseCheckConstraint() (expr.Expr, []object.Path, error) { }) // Parse ")" - err = p.parseTokens(scanner.RPAREN) + err = p.ParseTokens(scanner.RPAREN) if err != nil { return nil, nil, err } - return e, paths, nil + return e, columns, nil } diff --git a/internal/sql/parser/create_test.go b/internal/sql/parser/create_test.go index 94931c6a7..88474f0fa 100644 --- a/internal/sql/parser/create_test.go +++ b/internal/sql/parser/create_test.go @@ -5,10 +5,8 @@ import ( "testing" "github.com/chaisql/chai/internal/database" - "github.com/chaisql/chai/internal/object" "github.com/chaisql/chai/internal/query/statement" "github.com/chaisql/chai/internal/sql/parser" - "github.com/chaisql/chai/internal/testutil" "github.com/chaisql/chai/internal/testutil/assert" "github.com/stretchr/testify/require" ) @@ -22,27 +20,27 @@ func TestParserCreateIndex(t *testing.T) { }{ {"Basic", "CREATE INDEX idx ON test (foo)", &statement.CreateIndexStmt{ Info: database.IndexInfo{ - IndexName: "idx", Owner: database.Owner{TableName: "test"}, Paths: []object.Path{object.Path(testutil.ParseObjectPath(t, "foo"))}, + IndexName: "idx", Owner: database.Owner{TableName: "test"}, Columns: []string{"foo"}, }}, false}, - {"If not exists", "CREATE INDEX IF NOT EXISTS idx ON test (foo.bar[1])", &statement.CreateIndexStmt{ + {"If not exists", "CREATE INDEX IF NOT EXISTS idx ON test (foo)", &statement.CreateIndexStmt{ Info: database.IndexInfo{ - IndexName: "idx", Owner: database.Owner{TableName: "test"}, Paths: []object.Path{object.Path(testutil.ParseObjectPath(t, "foo.bar[1]"))}, + IndexName: "idx", Owner: database.Owner{TableName: "test"}, Columns: []string{"foo"}, }, IfNotExists: true}, false}, - {"Unique", "CREATE UNIQUE INDEX IF NOT EXISTS idx ON test (foo[3].baz)", &statement.CreateIndexStmt{ + {"Unique", "CREATE UNIQUE INDEX IF NOT EXISTS idx ON test (foo)", &statement.CreateIndexStmt{ Info: database.IndexInfo{ - IndexName: "idx", Owner: database.Owner{TableName: "test"}, Paths: []object.Path{object.Path(testutil.ParseObjectPath(t, "foo[3].baz"))}, Unique: true, + IndexName: "idx", Owner: database.Owner{TableName: "test"}, Columns: []string{"foo"}, Unique: true, }, IfNotExists: true}, false}, - {"No name", "CREATE UNIQUE INDEX ON test (foo[3].baz)", &statement.CreateIndexStmt{ - Info: database.IndexInfo{Owner: database.Owner{TableName: "test"}, Paths: []object.Path{object.Path(testutil.ParseObjectPath(t, "foo[3].baz"))}, Unique: true}}, false}, - {"No name with IF NOT EXISTS", "CREATE UNIQUE INDEX IF NOT EXISTS ON test (foo[3].baz)", nil, true}, + {"No name", "CREATE UNIQUE INDEX ON test (foo)", 
&statement.CreateIndexStmt{ + Info: database.IndexInfo{Owner: database.Owner{TableName: "test"}, Columns: []string{"foo"}, Unique: true}}, false}, + {"No name with IF NOT EXISTS", "CREATE UNIQUE INDEX IF NOT EXISTS ON test (foo)", nil, true}, {"More than 1 path", "CREATE INDEX idx ON test (foo, bar)", &statement.CreateIndexStmt{ Info: database.IndexInfo{ IndexName: "idx", Owner: database.Owner{TableName: "test"}, - Paths: []object.Path{ - object.Path(testutil.ParseObjectPath(t, "foo")), - object.Path(testutil.ParseObjectPath(t, "bar")), + Columns: []string{ + "foo", + "bar", }, }, }, diff --git a/internal/sql/parser/delete.go b/internal/sql/parser/delete.go index c90867ab6..b52b9d648 100644 --- a/internal/sql/parser/delete.go +++ b/internal/sql/parser/delete.go @@ -7,13 +7,13 @@ import ( "github.com/chaisql/chai/internal/sql/scanner" ) -// parseDeleteStatement parses a delete string and returns a Statement AST object. +// parseDeleteStatement parses a delete string and returns a Statement AST row. func (p *Parser) parseDeleteStatement() (statement.Statement, error) { stmt := statement.NewDeleteStatement() var err error // Parse "DELETE FROM". - if err := p.parseTokens(scanner.DELETE, scanner.FROM); err != nil { + if err := p.ParseTokens(scanner.DELETE, scanner.FROM); err != nil { return nil, err } diff --git a/internal/sql/parser/delete_test.go b/internal/sql/parser/delete_test.go index abb5b9621..a5d4b7c10 100644 --- a/internal/sql/parser/delete_test.go +++ b/internal/sql/parser/delete_test.go @@ -65,7 +65,7 @@ func TestParserDelete(t *testing.T) { t.Run(test.name, func(t *testing.T) { db := testutil.NewTestDB(t) - testutil.MustExec(t, db, nil, "CREATE TABLE test") + testutil.MustExec(t, db, nil, "CREATE TABLE test(age int)") q, err := parser.ParseQuery(test.s) assert.NoError(t, err) diff --git a/internal/sql/parser/drop.go b/internal/sql/parser/drop.go index a65d2ef12..d063a4363 100644 --- a/internal/sql/parser/drop.go +++ b/internal/sql/parser/drop.go @@ -7,10 +7,10 @@ import ( "github.com/chaisql/chai/internal/sql/scanner" ) -// parseDropStatement parses a drop string and returns a Statement AST object. +// parseDropStatement parses a drop string and returns a Statement AST row. func (p *Parser) parseDropStatement() (statement.Statement, error) { // Parse "DROP". - if err := p.parseTokens(scanner.DROP); err != nil { + if err := p.ParseTokens(scanner.DROP); err != nil { return nil, err } @@ -27,7 +27,7 @@ func (p *Parser) parseDropStatement() (statement.Statement, error) { return nil, newParseError(scanner.Tokstr(tok, lit), []string{"TABLE", "INDEX", "SEQUENCE"}, pos) } -// parseDropTableStatement parses a drop table string and returns a Statement AST object. +// parseDropTableStatement parses a drop table string and returns a Statement AST row. // This function assumes the DROP TABLE tokens have already been consumed. func (p *Parser) parseDropTableStatement() (statement.DropTableStmt, error) { var stmt statement.DropTableStmt @@ -49,7 +49,7 @@ func (p *Parser) parseDropTableStatement() (statement.DropTableStmt, error) { return stmt, nil } -// parseDropIndexStatement parses a drop index string and returns a Statement AST object. +// parseDropIndexStatement parses a drop index string and returns a Statement AST row. // This function assumes the DROP INDEX tokens have already been consumed. 
func (p *Parser) parseDropIndexStatement() (statement.DropIndexStmt, error) { var stmt statement.DropIndexStmt @@ -71,7 +71,7 @@ func (p *Parser) parseDropIndexStatement() (statement.DropIndexStmt, error) { return stmt, nil } -// parseDropSequenceStatement parses a drop sequence string and returns a Statement AST object. +// parseDropSequenceStatement parses a drop sequence string and returns a Statement AST row. // This function assumes the DROP SEQUENCE tokens have already been consumed. func (p *Parser) parseDropSequenceStatement() (statement.DropSequenceStmt, error) { var stmt statement.DropSequenceStmt diff --git a/internal/sql/parser/explain.go b/internal/sql/parser/explain.go index 7da5acf9b..b7687d57a 100644 --- a/internal/sql/parser/explain.go +++ b/internal/sql/parser/explain.go @@ -5,11 +5,11 @@ import ( "github.com/chaisql/chai/internal/sql/scanner" ) -// parseExplainStatement parses any statement and returns an ExplainStmt object. +// parseExplainStatement parses any statement and returns an ExplainStmt row. // This function assumes the EXPLAIN token has already been consumed. func (p *Parser) parseExplainStatement() (statement.Statement, error) { // Parse "EXPLAIN". - if err := p.parseTokens(scanner.EXPLAIN); err != nil { + if err := p.ParseTokens(scanner.EXPLAIN); err != nil { return nil, err } diff --git a/internal/sql/parser/expr.go b/internal/sql/parser/expr.go index 0f1b17cca..3f0e70cb0 100644 --- a/internal/sql/parser/expr.go +++ b/internal/sql/parser/expr.go @@ -3,12 +3,13 @@ package parser import ( "encoding/hex" "fmt" + "math" "strconv" "strings" "github.com/chaisql/chai/internal/environment" "github.com/chaisql/chai/internal/expr" - "github.com/chaisql/chai/internal/object" + "github.com/chaisql/chai/internal/expr/functions" "github.com/chaisql/chai/internal/sql/scanner" "github.com/chaisql/chai/internal/types" "github.com/cockroachdb/errors" @@ -167,7 +168,7 @@ func (p *Parser) parseOperator(minPrecedence int, allowed ...scanner.Token) (fun if err != nil { return nil, op, err } - err = p.parseTokens(scanner.AND) + err = p.ParseTokens(scanner.AND) if err != nil { return nil, op, err } @@ -195,7 +196,7 @@ func (p *Parser) parseUnaryExpr(allowed ...scanner.Token) (expr.Expr, error) { return p.parseCastExpression() case scanner.IDENT: tok1, _, _ := p.ScanIgnoreWhitespace() - // if the next token is a left parenthesis, this is a global function + // if the next token is a left parenthesis, this is a function if tok1 == scanner.LPAREN { p.Unscan() if tk, _, _ := p.s.Curr(); tk == scanner.WS { @@ -203,22 +204,6 @@ func (p *Parser) parseUnaryExpr(allowed ...scanner.Token) (expr.Expr, error) { } p.Unscan() return p.parseFunction() - } else if tok1 == scanner.DOT { - // it may be a package function instead. 
- if tok2, _, _ := p.Scan(); tok2 == scanner.IDENT { - if tok3, _, _ := p.Scan(); tok3 == scanner.LPAREN { - p.Unscan() - p.Unscan() - p.Unscan() - p.Unscan() - return p.parseFunction() - } else { - p.Unscan() - p.Unscan() - } - } else { - p.Unscan() - } } p.Unscan() if tk, _, _ := p.s.Curr(); tk == scanner.WS { @@ -227,12 +212,7 @@ func (p *Parser) parseUnaryExpr(allowed ...scanner.Token) (expr.Expr, error) { p.Unscan() - field, err := p.parsePath() - if err != nil { - return nil, err - } - fs := expr.Path(field) - return fs, nil + return p.parseColumn() case scanner.NAMEDPARAM: if len(lit) == 1 { return nil, errors.WithStack(&ParseError{Message: "missing param name"}) @@ -286,20 +266,16 @@ func (p *Parser) parseUnaryExpr(allowed ...scanner.Token) (expr.Expr, error) { } return nil, errors.WithStack(&ParseError{Message: "unable to parse integer", Pos: pos}) } - return expr.LiteralValue{Value: types.NewIntegerValue(v)}, nil + if v > math.MaxInt32 || v < math.MinInt32 { + return expr.LiteralValue{Value: types.NewBigintValue(v)}, nil + } + return expr.LiteralValue{Value: types.NewIntegerValue(int32(v))}, nil case scanner.TRUE, scanner.FALSE: return expr.LiteralValue{Value: types.NewBooleanValue(tok == scanner.TRUE)}, nil case scanner.NULL: return expr.LiteralValue{Value: types.NewNullValue()}, nil case scanner.MUL: return expr.Wildcard{}, nil - case scanner.LBRACKET: - p.Unscan() - e, err := p.ParseObject() - return e, err - case scanner.LSBRACKET: - p.Unscan() - return p.parseExprList(scanner.LSBRACKET, scanner.RSBRACKET) case scanner.LPAREN: e, err := p.ParseExpr() if err != nil { @@ -329,7 +305,7 @@ func (p *Parser) parseUnaryExpr(allowed ...scanner.Token) (expr.Expr, error) { } return expr.Not(e), nil case scanner.NEXT: - err := p.parseTokens(scanner.VALUE, scanner.FOR) + err := p.ParseTokens(scanner.VALUE, scanner.FOR) if err != nil { return nil, err } @@ -429,16 +405,10 @@ func (p *Parser) parseParam() (expr.Expr, error) { func (p *Parser) parseType() (types.Type, error) { tok, pos, lit := p.ScanIgnoreWhitespace() switch tok { - case scanner.TYPEANY: - return types.TypeAny, nil - case scanner.TYPEARRAY: - return types.TypeArray, nil case scanner.TYPEBLOB, scanner.TYPEBYTES: return types.TypeBlob, nil case scanner.TYPEBOOL, scanner.TYPEBOOLEAN: return types.TypeBoolean, nil - case scanner.TYPEOBJECT: - return types.TypeObject, nil case scanner.TYPEREAL: return types.TypeDouble, nil case scanner.TYPEDOUBLE: @@ -448,9 +418,11 @@ func (p *Parser) parseType() (types.Type, error) { } p.Unscan() return types.TypeDouble, nil - case scanner.TYPEINTEGER, scanner.TYPEINT, scanner.TYPEINT2, scanner.TYPEINT8, scanner.TYPETINYINT, - scanner.TYPEBIGINT, scanner.TYPEMEDIUMINT, scanner.TYPESMALLINT: + case scanner.TYPEINTEGER, scanner.TYPEINT, scanner.TYPEINT2, scanner.TYPETINYINT, + scanner.TYPEMEDIUMINT, scanner.TYPESMALLINT: return types.TypeInteger, nil + case scanner.TYPEINT8, scanner.TYPEBIGINT: + return types.TypeBigint, nil case scanner.TYPETEXT: return types.TypeText, nil case scanner.TYPETIMESTAMP: @@ -475,135 +447,15 @@ func (p *Parser) parseType() (types.Type, error) { return 0, newParseError(scanner.Tokstr(tok, lit), []string{"type"}, pos) } -// ParseObject parses an object -func (p *Parser) ParseObject() (*expr.KVPairs, error) { - // Parse { token. - if err := p.parseTokens(scanner.LBRACKET); err != nil { - return nil, err - } - - var pairs expr.KVPairs - pairs.SelfReferenced = true - var pair expr.KVPair - var err error - - // Parse kv pairs. 
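// Illustrative sketch, not part of the patch: integer literals are now sized at
// parse time. A literal that fits in 32 bits becomes an INTEGER value and
// anything larger becomes a BIGINT, mirroring the new INT8/BIGINT type mapping:
//
//	e, _ := parser.ParseExpr("100")        // LiteralValue holding types.NewIntegerValue(100)
//	e, _ = parser.ParseExpr("3000000000")  // outside the int32 range, so types.NewBigintValue is used
//
// (parser.ParseExpr is the exported helper that appears later in this diff.)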
- for { - if pair, err = p.parseKV(); err != nil { - p.Unscan() - break - } - - pairs.Pairs = append(pairs.Pairs, pair) - - if tok, _, _ := p.ScanIgnoreWhitespace(); tok != scanner.COMMA { - p.Unscan() - break - } - } - - // Parse required } token. - if err := p.parseTokens(scanner.RBRACKET); err != nil { - return nil, err - } - - return &pairs, nil -} - -// parseKV parses a key-value pair in the form IDENT : Expr. -func (p *Parser) parseKV() (expr.KVPair, error) { - var k string - - tok, pos, lit := p.ScanIgnoreWhitespace() - if tok == scanner.IDENT || tok == scanner.STRING { - k = lit - } else { - return expr.KVPair{}, newParseError(scanner.Tokstr(tok, lit), []string{"ident", "string"}, pos) - } - - if err := p.parseTokens(scanner.COLON); err != nil { - p.Unscan() - return expr.KVPair{}, err - } - - e, err := p.ParseExpr() - if err != nil { - return expr.KVPair{}, err - } - - return expr.KVPair{ - K: k, - V: e, - }, nil -} - // parsePath parses a path to a specific value. -func (p *Parser) parsePath() (object.Path, error) { - var path object.Path +func (p *Parser) parseColumn() (expr.Column, error) { // parse first mandatory ident - chunk, err := p.parseIdent() + col, err := p.parseIdent() if err != nil { - return nil, err - } - path = append(path, object.PathFragment{ - FieldName: chunk, - }) - -LOOP: - for { - // scan the very next token. - // if can be either a '.' or a '[' - // Otherwise, unscan and return the path - tok, _, _ := p.Scan() - switch tok { - case scanner.DOT: - // scan the next token for an ident - tok, pos, lit := p.Scan() - if tok != scanner.IDENT { - return nil, newParseError(lit, []string{"identifier"}, pos) - } - path = append(path, object.PathFragment{ - FieldName: lit, - }) - case scanner.LSBRACKET: - // the next token can be either an integer or a quoted string - // if it's an integer, we have an array index - // if it's a quoted string, we have a field name - tok, pos, lit := p.Scan() - switch tok { - case scanner.INTEGER: - // is the number negative? - if lit[0] == '-' { - return nil, newParseError(lit, []string{"integer"}, pos) - } - // is the number too big? - if len(lit) > 10 { - return nil, newParseError(lit, []string{"integer"}, pos) - } - // parse the integer - i, err := strconv.ParseInt(lit, 10, 64) - if err != nil { - return nil, newParseError(lit, []string{"integer"}, pos) - } - path = append(path, object.PathFragment{ - ArrayIndex: int(i), - }) - case scanner.STRING: - path = append(path, object.PathFragment{ - FieldName: lit, - }) - } - // scan the next token for a closing left bracket - if err := p.parseTokens(scanner.RSBRACKET); err != nil { - return nil, err - } - default: - p.Unscan() - break LOOP - } + return "", err } - return path, nil + return expr.Column(col), nil } func (p *Parser) parseExprListUntil(rightToken scanner.Token) (expr.LiteralExprList, error) { @@ -627,7 +479,7 @@ func (p *Parser) parseExprListUntil(rightToken scanner.Token) (expr.LiteralExprL } // Parse required ) or ] token. - if err := p.parseTokens(rightToken); err != nil { + if err := p.ParseTokens(rightToken); err != nil { return nil, err } @@ -636,7 +488,7 @@ func (p *Parser) parseExprListUntil(rightToken scanner.Token) (expr.LiteralExprL func (p *Parser) parseExprList(leftToken, rightToken scanner.Token) (expr.LiteralExprList, error) { // Parse ( or [ token. 
- if err := p.parseTokens(leftToken); err != nil { + if err := p.ParseTokens(leftToken); err != nil { return nil, err } @@ -653,26 +505,14 @@ func (p *Parser) parseFunction() (expr.Expr, error) { return nil, err } - // Parse optional package name - var pkgName string - if tok, _, _ := p.Scan(); tok == scanner.DOT { - pkgName = funcName - funcName, err = p.parseIdent() - if err != nil { - return nil, err - } - } else { - p.Unscan() - } - // Parse required ( token. - if err := p.parseTokens(scanner.LPAREN); err != nil { + if err := p.ParseTokens(scanner.LPAREN); err != nil { return nil, err } // Check if the function is called without arguments. if tok, _, _ := p.ScanIgnoreWhitespace(); tok == scanner.RPAREN { - def, err := p.packagesTable.GetFunc(pkgName, funcName) + def, err := functions.GetFunc(funcName) if err != nil { return nil, err } @@ -698,11 +538,11 @@ func (p *Parser) parseFunction() (expr.Expr, error) { } // Parse required ) token. - if err := p.parseTokens(scanner.RPAREN); err != nil { + if err := p.ParseTokens(scanner.RPAREN); err != nil { return nil, err } - def, err := p.packagesTable.GetFunc(pkgName, funcName) + def, err := functions.GetFunc(funcName) if err != nil { return nil, err } @@ -712,7 +552,7 @@ func (p *Parser) parseFunction() (expr.Expr, error) { // parseCastExpression parses a string of the form CAST(expr AS type). func (p *Parser) parseCastExpression() (expr.Expr, error) { // Parse required CAST and ( tokens. - if err := p.parseTokens(scanner.CAST, scanner.LPAREN); err != nil { + if err := p.ParseTokens(scanner.CAST, scanner.LPAREN); err != nil { return nil, err } @@ -723,7 +563,7 @@ func (p *Parser) parseCastExpression() (expr.Expr, error) { } // Parse required AS token. - if err := p.parseTokens(scanner.AS); err != nil { + if err := p.ParseTokens(scanner.AS); err != nil { return nil, err } @@ -734,7 +574,7 @@ func (p *Parser) parseCastExpression() (expr.Expr, error) { } // Parse required ) token. 
- if err := p.parseTokens(scanner.RPAREN); err != nil { + if err := p.ParseTokens(scanner.RPAREN); err != nil { return nil, err } diff --git a/internal/sql/parser/expr_test.go b/internal/sql/parser/expr_test.go index 3cfd49134..5547efe44 100644 --- a/internal/sql/parser/expr_test.go +++ b/internal/sql/parser/expr_test.go @@ -6,7 +6,6 @@ import ( "github.com/chaisql/chai/internal/expr" "github.com/chaisql/chai/internal/expr/functions" - "github.com/chaisql/chai/internal/object" "github.com/chaisql/chai/internal/sql/parser" "github.com/chaisql/chai/internal/testutil" "github.com/chaisql/chai/internal/testutil/assert" @@ -41,41 +40,6 @@ func TestParserExpr(t *testing.T) { {"blob as hex string", `'\xff'`, testutil.BlobValue([]byte{255}), false}, {"invalid blob hex string", `'\xzz'`, nil, true}, - // objects - {"empty object", `{}`, &expr.KVPairs{SelfReferenced: true}, false}, - {"object values", `{a: 1, b: 1.0, c: true, d: 'string', e: "string", f: {foo: 'bar'}, g: h.i.j, k: [1, 2, 3]}`, - &expr.KVPairs{SelfReferenced: true, Pairs: []expr.KVPair{ - {K: "a", V: testutil.IntegerValue(1)}, - {K: "b", V: testutil.DoubleValue(1)}, - {K: "c", V: testutil.BoolValue(true)}, - {K: "d", V: testutil.TextValue("string")}, - {K: "e", V: testutil.TextValue("string")}, - {K: "f", V: &expr.KVPairs{SelfReferenced: true, Pairs: []expr.KVPair{ - {K: "foo", V: testutil.TextValue("bar")}, - }}}, - {K: "g", V: testutil.ParsePath(t, "h.i.j")}, - {K: "k", V: expr.LiteralExprList{testutil.IntegerValue(1), testutil.IntegerValue(2), testutil.IntegerValue(3)}}, - }}, - false}, - {"object keys", `{a: 1, "foo bar __&&))": 1, 'ola ': 1}`, - &expr.KVPairs{SelfReferenced: true, Pairs: []expr.KVPair{ - {K: "a", V: testutil.IntegerValue(1)}, - {K: "foo bar __&&))", V: testutil.IntegerValue(1)}, - {K: "ola ", V: testutil.IntegerValue(1)}, - }}, - false}, - {"object keys: same key", `{a: 1, a: 2, "a": 3}`, - &expr.KVPairs{SelfReferenced: true, Pairs: []expr.KVPair{ - {K: "a", V: testutil.IntegerValue(1)}, - {K: "a", V: testutil.IntegerValue(2)}, - {K: "a", V: testutil.IntegerValue(3)}, - }}, false}, - {"bad object keys: param", `{?: 1}`, nil, true}, - {"bad object keys: dot", `{a.b: 1}`, nil, true}, - {"bad object keys: space", `{a b: 1}`, nil, true}, - {"bad object: missing right bracket", `{a: 1`, nil, true}, - {"bad object: missing colon", `{a: 1, 'b'}`, nil, true}, - // parentheses {"parentheses: empty", "()", nil, true}, {"parentheses: values", `(1)`, @@ -97,39 +61,28 @@ func TestParserExpr(t *testing.T) { ), ), }, false}, - {"list with brackets: empty", "[]", expr.LiteralExprList(nil), false}, - {"list with brackets: values", `[1, true, {a: 1}, a.b.c, (-1), [-1]]`, - expr.LiteralExprList{ - testutil.IntegerValue(1), - testutil.BoolValue(true), - &expr.KVPairs{SelfReferenced: true, Pairs: []expr.KVPair{{K: "a", V: testutil.IntegerValue(1)}}}, - testutil.ParsePath(t, "a.b.c"), - expr.Parentheses{E: testutil.IntegerValue(-1)}, - expr.LiteralExprList{testutil.IntegerValue(-1)}, - }, false}, - {"list with brackets: missing bracket", `[1, true, {a: 1}, a.b.c, (-1), [-1]`, nil, true}, // operators - {"=", "age = 10", expr.Eq(testutil.ParsePath(t, "age"), testutil.IntegerValue(10)), false}, - {"!=", "age != 10", expr.Neq(testutil.ParsePath(t, "age"), testutil.IntegerValue(10)), false}, - {">", "age > 10", expr.Gt(testutil.ParsePath(t, "age"), testutil.IntegerValue(10)), false}, - {">=", "age >= 10", expr.Gte(testutil.ParsePath(t, "age"), testutil.IntegerValue(10)), false}, - {"<", "age < 10", expr.Lt(testutil.ParsePath(t, "age"), 
testutil.IntegerValue(10)), false}, - {"<=", "age <= 10", expr.Lte(testutil.ParsePath(t, "age"), testutil.IntegerValue(10)), false}, + {"=", "age = 10", expr.Eq(expr.Column("age"), testutil.IntegerValue(10)), false}, + {"!=", "age != 10", expr.Neq(expr.Column("age"), testutil.IntegerValue(10)), false}, + {">", "age > 10", expr.Gt(expr.Column("age"), testutil.IntegerValue(10)), false}, + {">=", "age >= 10", expr.Gte(expr.Column("age"), testutil.IntegerValue(10)), false}, + {"<", "age < 10", expr.Lt(expr.Column("age"), testutil.IntegerValue(10)), false}, + {"<=", "age <= 10", expr.Lte(expr.Column("age"), testutil.IntegerValue(10)), false}, {"BETWEEN", "1 BETWEEN 10 AND 11", expr.Between(testutil.IntegerValue(10))(testutil.IntegerValue(1), testutil.IntegerValue(11)), false}, - {"+", "age + 10", expr.Add(testutil.ParsePath(t, "age"), testutil.IntegerValue(10)), false}, - {"-", "age - 10", expr.Sub(testutil.ParsePath(t, "age"), testutil.IntegerValue(10)), false}, - {"*", "age * 10", expr.Mul(testutil.ParsePath(t, "age"), testutil.IntegerValue(10)), false}, - {"/", "age / 10", expr.Div(testutil.ParsePath(t, "age"), testutil.IntegerValue(10)), false}, - {"%", "age % 10", expr.Mod(testutil.ParsePath(t, "age"), testutil.IntegerValue(10)), false}, - {"&", "age & 10", expr.BitwiseAnd(testutil.ParsePath(t, "age"), testutil.IntegerValue(10)), false}, - {"||", "name || 'foo'", expr.Concat(testutil.ParsePath(t, "name"), testutil.TextValue("foo")), false}, - {"IN", "age IN ages", expr.In(testutil.ParsePath(t, "age"), testutil.ParsePath(t, "ages")), false}, - {"NOT IN", "age NOT IN ages", expr.NotIn(testutil.ParsePath(t, "age"), testutil.ParsePath(t, "ages")), false}, - {"IS", "age IS NULL", expr.Is(testutil.ParsePath(t, "age"), testutil.NullValue()), false}, - {"IS NOT", "age IS NOT NULL", expr.IsNot(testutil.ParsePath(t, "age"), testutil.NullValue()), false}, - {"LIKE", "name LIKE 'foo'", expr.Like(testutil.ParsePath(t, "name"), testutil.TextValue("foo")), false}, - {"NOT LIKE", "name NOT LIKE 'foo'", expr.NotLike(testutil.ParsePath(t, "name"), testutil.TextValue("foo")), false}, + {"+", "age + 10", expr.Add(expr.Column("age"), testutil.IntegerValue(10)), false}, + {"-", "age - 10", expr.Sub(expr.Column("age"), testutil.IntegerValue(10)), false}, + {"*", "age * 10", expr.Mul(expr.Column("age"), testutil.IntegerValue(10)), false}, + {"/", "age / 10", expr.Div(expr.Column("age"), testutil.IntegerValue(10)), false}, + {"%", "age % 10", expr.Mod(expr.Column("age"), testutil.IntegerValue(10)), false}, + {"&", "age & 10", expr.BitwiseAnd(expr.Column("age"), testutil.IntegerValue(10)), false}, + {"||", "name || 'foo'", expr.Concat(expr.Column("name"), testutil.TextValue("foo")), false}, + {"IN", "age IN ages", expr.In(expr.Column("age"), expr.Column("ages")), false}, + {"NOT IN", "age NOT IN ages", expr.NotIn(expr.Column("age"), expr.Column("ages")), false}, + {"IS", "age IS NULL", expr.Is(expr.Column("age"), testutil.NullValue()), false}, + {"IS NOT", "age IS NOT NULL", expr.IsNot(expr.Column("age"), testutil.NullValue()), false}, + {"LIKE", "name LIKE 'foo'", expr.Like(expr.Column("name"), testutil.TextValue("foo")), false}, + {"NOT LIKE", "name NOT LIKE 'foo'", expr.NotLike(expr.Column("name"), testutil.TextValue("foo")), false}, {"NOT =", "name NOT = 'foo'", nil, true}, {"precedence", "4 > 1 + 2", expr.Gt( testutil.IntegerValue(4), @@ -140,26 +93,26 @@ func TestParserExpr(t *testing.T) { ), false}, {"AND", "age = 10 AND age <= 11", expr.And( - expr.Eq(testutil.ParsePath(t, "age"), 
testutil.IntegerValue(10)), - expr.Lte(testutil.ParsePath(t, "age"), testutil.IntegerValue(11)), + expr.Eq(expr.Column("age"), testutil.IntegerValue(10)), + expr.Lte(expr.Column("age"), testutil.IntegerValue(11)), ), false}, {"OR", "age = 10 OR age = 11", expr.Or( - expr.Eq(testutil.ParsePath(t, "age"), testutil.IntegerValue(10)), - expr.Eq(testutil.ParsePath(t, "age"), testutil.IntegerValue(11)), + expr.Eq(expr.Column("age"), testutil.IntegerValue(10)), + expr.Eq(expr.Column("age"), testutil.IntegerValue(11)), ), false}, {"AND then OR", "age >= 10 AND age > $age OR age < 10.4", expr.Or( expr.And( - expr.Gte(testutil.ParsePath(t, "age"), testutil.IntegerValue(10)), - expr.Gt(testutil.ParsePath(t, "age"), expr.NamedParam("age")), + expr.Gte(expr.Column("age"), testutil.IntegerValue(10)), + expr.Gt(expr.Column("age"), expr.NamedParam("age")), ), - expr.Lt(testutil.ParsePath(t, "age"), testutil.DoubleValue(10.4)), + expr.Lt(expr.Column("age"), testutil.DoubleValue(10.4)), ), false}, - {"with NULL", "age > NULL", expr.Gt(testutil.ParsePath(t, "age"), testutil.NullValue()), false}, + {"with NULL", "age > NULL", expr.Gt(expr.Column("age"), testutil.NullValue()), false}, // unary operators - {"CAST", "CAST(a.b[1][0] AS TEXT)", expr.Cast{Expr: testutil.ParsePath(t, "a.b[1][0]"), CastAs: types.TypeText}, false}, + {"CAST", "CAST(a AS TEXT)", expr.Cast{Expr: expr.Column("a"), CastAs: types.TypeText}, false}, {"NOT", "NOT 10", expr.Not(testutil.IntegerValue(10)), false}, {"NOT", "NOT NOT", nil, true}, {"NOT", "NOT NOT 10", expr.Not(expr.Not(testutil.IntegerValue(10))), false}, @@ -168,11 +121,10 @@ func TestParserExpr(t *testing.T) { {"NEXT VALUE FOR", "NEXT VALUE FOR 10", nil, true}, // functions - {"pk() function", "pk()", &functions.PK{}, false}, - {"count(expr) function", "count(a)", &functions.Count{Expr: testutil.ParsePath(t, "a")}, false}, + {"count(expr) function", "count(a)", &functions.Count{Expr: expr.Column("a")}, false}, {"count(*) function", "count(*)", functions.NewCount(expr.Wildcard{}), false}, {"count (*) function with spaces", "count (*)", functions.NewCount(expr.Wildcard{}), false}, - {"packaged function", "math.floor(1.2)", testutil.FunctionExpr(t, "math.floor", testutil.DoubleValue(1.2)), false}, + {"packaged function", "floor(1.2)", testutil.FunctionExpr(t, "floor", testutil.DoubleValue(1.2)), false}, } for _, test := range tests { @@ -190,61 +142,6 @@ func TestParserExpr(t *testing.T) { } } -func TestParsePath(t *testing.T) { - tests := []struct { - name string - s string - expected object.Path - fails bool - }{ - {"one fragment", `a`, object.Path{ - object.PathFragment{FieldName: "a"}, - }, false}, - {"one fragment with quotes", "` \"a\"`", object.Path{ - object.PathFragment{FieldName: " \"a\""}, - }, false}, - {"multiple fragments", `a.b[100].c[1][2]`, object.Path{ - object.PathFragment{FieldName: "a"}, - object.PathFragment{FieldName: "b"}, - object.PathFragment{ArrayIndex: 100}, - object.PathFragment{FieldName: "c"}, - object.PathFragment{ArrayIndex: 1}, - object.PathFragment{ArrayIndex: 2}, - }, false}, - {"multiple fragments with brackets", `a["b"][100].c[1][2]`, object.Path{ - object.PathFragment{FieldName: "a"}, - object.PathFragment{FieldName: "b"}, - object.PathFragment{ArrayIndex: 100}, - object.PathFragment{FieldName: "c"}, - object.PathFragment{ArrayIndex: 1}, - object.PathFragment{ArrayIndex: 2}, - }, false}, - {"with quotes", "`some ident`.` with`[5].` \"quotes`", object.Path{ - object.PathFragment{FieldName: "some ident"}, - object.PathFragment{FieldName: " 
with"}, - object.PathFragment{ArrayIndex: 5}, - object.PathFragment{FieldName: " \"quotes"}, - }, false}, - - {"negative index", `a.b[-100].c`, nil, true}, - {"with spaces", `a. b[100]. c`, nil, true}, - {"starting with array", `[10].a`, nil, true}, - {"starting with brackets", `['a']`, nil, true}, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - vp, err := parser.ParsePath(test.s) - if test.fails { - assert.Error(t, err) - } else { - assert.NoError(t, err) - require.EqualValues(t, test.expected, vp) - } - }) - } -} - func TestParserParams(t *testing.T) { tests := []struct { name string @@ -252,17 +149,17 @@ func TestParserParams(t *testing.T) { expected expr.Expr errored bool }{ - {"one positional", "age = ?", expr.Eq(testutil.ParsePath(t, "age"), expr.PositionalParam(1)), false}, + {"one positional", "age = ?", expr.Eq(expr.Column("age"), expr.PositionalParam(1)), false}, {"multiple positional", "age = ? AND age <= ?", expr.And( - expr.Eq(testutil.ParsePath(t, "age"), expr.PositionalParam(1)), - expr.Lte(testutil.ParsePath(t, "age"), expr.PositionalParam(2)), + expr.Eq(expr.Column("age"), expr.PositionalParam(1)), + expr.Lte(expr.Column("age"), expr.PositionalParam(2)), ), false}, - {"one named", "age = $age", expr.Eq(testutil.ParsePath(t, "age"), expr.NamedParam("age")), false}, + {"one named", "age = $age", expr.Eq(expr.Column("age"), expr.NamedParam("age")), false}, {"multiple named", "age = $foo OR age = $bar", expr.Or( - expr.Eq(testutil.ParsePath(t, "age"), expr.NamedParam("foo")), - expr.Eq(testutil.ParsePath(t, "age"), expr.NamedParam("bar")), + expr.Eq(expr.Column("age"), expr.NamedParam("foo")), + expr.Eq(expr.Column("age"), expr.NamedParam("bar")), ), false}, {"mixed", "age >= ? AND age > $foo OR age < ?", nil, true}, } diff --git a/internal/sql/parser/insert.go b/internal/sql/parser/insert.go index a9ac3ca11..0760f19d6 100644 --- a/internal/sql/parser/insert.go +++ b/internal/sql/parser/insert.go @@ -10,13 +10,13 @@ import ( "github.com/cockroachdb/errors" ) -// parseInsertStatement parses an insert string and returns a Statement AST object. +// parseInsertStatement parses an insert string and returns a Statement AST row. func (p *Parser) parseInsertStatement() (*statement.InsertStmt, error) { stmt := statement.NewInsertStatement() var err error // Parse "INSERT INTO". - if err := p.parseTokens(scanner.INSERT, scanner.INTO); err != nil { + if err := p.ParseTokens(scanner.INSERT, scanner.INTO); err != nil { return nil, err } @@ -29,7 +29,7 @@ func (p *Parser) parseInsertStatement() (*statement.InsertStmt, error) { } // Parse path list: (a, b, c) - stmt.Fields, err = p.parseFieldList() + stmt.Columns, err = p.parseSimpleColumnList() if err != nil { return nil, err } @@ -39,7 +39,7 @@ func (p *Parser) parseInsertStatement() (*statement.InsertStmt, error) { switch tok { case scanner.VALUES: // Parse VALUES (v1, v2, v3) - stmt.Values, err = p.parseValues(stmt.Fields) + stmt.Values, err = p.parseValues(stmt.Columns) if err != nil { return nil, err } @@ -67,9 +67,9 @@ func (p *Parser) parseInsertStatement() (*statement.InsertStmt, error) { return stmt, nil } -// parseFieldList parses a list of fields in the form: (path, path, ...), if exists. +// parseColumnList parses a list of columns in the form: (column, column, ...), if exists. // If the list is empty, it returns an error. -func (p *Parser) parseFieldList() ([]string, error) { +func (p *Parser) parseSimpleColumnList() ([]string, error) { // Parse ( token. 
if ok, err := p.parseOptional(scanner.LPAREN); !ok || err != nil { p.Unscan() @@ -77,141 +77,61 @@ func (p *Parser) parseFieldList() ([]string, error) { } // Parse path list. - var fields []string + var columns []string var err error - if fields, err = p.parseIdentList(); err != nil { + if columns, err = p.parseIdentList(); err != nil { return nil, err } // Parse required ) token. - if err := p.parseTokens(scanner.RPAREN); err != nil { + if err := p.ParseTokens(scanner.RPAREN); err != nil { return nil, err } - return fields, nil + return columns, nil } // parseValues parses the "VALUES" clause of the query, if it exists. func (p *Parser) parseValues(fields []string) ([]expr.Expr, error) { - if len(fields) > 0 { - return p.parseObjectsWithFields(fields) - } - - tok, pos, lit := p.ScanIgnoreWhitespace() - p.Unscan() - switch tok { - case scanner.LPAREN: - return p.parseObjectsWithFields(fields) - case scanner.LBRACKET, scanner.NAMEDPARAM, scanner.POSITIONALPARAM: - return p.parseLiteralDocOrParamList() - } + var rows []expr.Expr - return nil, newParseError(scanner.Tokstr(tok, lit), []string{"(", "[", "?", "$"}, pos) -} - -// parseExprListValues parses the "VALUES" clause of the query, if it exists. -func (p *Parser) parseObjectsWithFields(fields []string) ([]expr.Expr, error) { - var docs []expr.Expr - - // Parse first (required) value list. - doc, err := p.parseExprListWithFields(fields) + // Parse first (required) row. + r, err := p.parseRowExprList(fields) if err != nil { return nil, err } - docs = append(docs, doc) + rows = append(rows, r) - // Parse remaining (optional) values. + // Parse remaining (optional) rows. for { if tok, _, _ := p.ScanIgnoreWhitespace(); tok != scanner.COMMA { p.Unscan() break } - doc, err := p.parseExprListWithFields(fields) + doc, err := p.parseRowExprList(fields) if err != nil { return nil, err } - docs = append(docs, doc) + rows = append(rows, doc) } - return docs, nil + return rows, nil } -func (p *Parser) parseExprListWithFields(fields []string) (*expr.KVPairs, error) { +func (p *Parser) parseRowExprList(fields []string) (expr.LiteralExprList, error) { list, err := p.parseExprList(scanner.LPAREN, scanner.RPAREN) if err != nil { return nil, err } - var pairs expr.KVPairs - pairs.Pairs = make([]expr.KVPair, len(list)) - - if len(fields) > 0 { - if len(fields) != len(list) { - return nil, fmt.Errorf("%d values for %d fields", len(list), len(fields)) - } - - for i := range list { - pairs.Pairs[i].K = fields[i] - pairs.Pairs[i].V = list[i] - } - } else { - for i := range list { - pairs.Pairs[i].V = list[i] - } - } - - return &pairs, nil -} - -// parseExprListValues parses the "VALUES" clause of the query, if it exists. -func (p *Parser) parseLiteralDocOrParamList() ([]expr.Expr, error) { - var docs []expr.Expr - - // Parse first (required) value list. - doc, err := p.parseParamOrObject() - if err != nil { - return nil, err + if len(fields) > 0 && len(fields) != len(list) { + return nil, fmt.Errorf("%d values for %d fields", len(list), len(fields)) } - docs = append(docs, doc) - - // Parse remaining (optional) values. - for { - if tok, _, _ := p.ScanIgnoreWhitespace(); tok != scanner.COMMA { - p.Unscan() - break - } - - doc, err := p.parseParamOrObject() - if err != nil { - return nil, err - } - - docs = append(docs, doc) - } - - return docs, nil -} - -// parseParamOrObject parses either a parameter or an object. 
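// Illustrative sketch, not part of the patch: with object literals and bare
// parameters removed from VALUES, each row is a plain parenthesized expression
// list that gets paired with the declared columns, and a mismatched number of
// values is rejected while parsing. The row below mirrors what insert_test.go
// expects for INSERT INTO test (a, b) VALUES ('c', 'd'), reusing the
// testutil.TextValue helper from those tests:
var exampleInsertRow = expr.Row{
	Columns: []string{"a", "b"},
	Exprs:   []expr.Expr{testutil.TextValue("c"), testutil.TextValue("d")},
}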
-func (p *Parser) parseParamOrObject() (expr.Expr, error) { - // Parse a param first - prm, err := p.parseParam() - if err != nil { - return nil, err - } - if prm != nil { - return prm, nil - } - - // If not a param, start over - p.Unscan() - - // Expect an object - return p.ParseObject() + return list, nil } func (p *Parser) parseOnConflictClause() (database.OnConflictAction, error) { diff --git a/internal/sql/parser/insert_test.go b/internal/sql/parser/insert_test.go index ba0007a9f..d4c484263 100644 --- a/internal/sql/parser/insert_test.go +++ b/internal/sql/parser/insert_test.go @@ -24,58 +24,15 @@ func TestParserInsert(t *testing.T) { expected *stream.Stream fails bool }{ - {"Objects", `INSERT INTO test VALUES {a: 1, "b": "foo", c: 'bar', d: 1 = 1, e: {f: "baz"}}`, - stream.New(rows.Emit( - &expr.KVPairs{SelfReferenced: true, Pairs: []expr.KVPair{ - {K: "a", V: testutil.IntegerValue(1)}, - {K: "b", V: testutil.TextValue("foo")}, - {K: "c", V: testutil.TextValue("bar")}, - {K: "d", V: testutil.BoolValue(true)}, - {K: "e", V: &expr.KVPairs{SelfReferenced: true, Pairs: []expr.KVPair{ - {K: "f", V: testutil.TextValue("baz")}, - }}}, - }}, - )). - Pipe(table.Validate("test")). - Pipe(table.Insert("test")). - Pipe(stream.Discard()), - false}, - {"Objects / Multiple", `INSERT INTO test VALUES {"a": 'a', b: -2.3}, {a: 1, d: true}`, - stream.New(rows.Emit( - &expr.KVPairs{SelfReferenced: true, Pairs: []expr.KVPair{ - {K: "a", V: testutil.TextValue("a")}, - {K: "b", V: testutil.DoubleValue(-2.3)}, - }}, - &expr.KVPairs{SelfReferenced: true, Pairs: []expr.KVPair{{K: "a", V: testutil.IntegerValue(1)}, {K: "d", V: testutil.BoolValue(true)}}}, - )). - Pipe(table.Validate("test")). - Pipe(table.Insert("test")). - Pipe(stream.Discard()), - false}, - {"Objects / Positional Param", "INSERT INTO test VALUES ?, ?", - stream.New(rows.Emit( - expr.PositionalParam(1), - expr.PositionalParam(2), - )). - Pipe(table.Validate("test")). - Pipe(table.Insert("test")). - Pipe(stream.Discard()), - false}, - {"Objects / Named Param", "INSERT INTO test VALUES $foo, $bar", - stream.New(rows.Emit( - expr.NamedParam("foo"), - expr.NamedParam("bar"), - )). - Pipe(table.Validate("test")). - Pipe(table.Insert("test")). - Pipe(stream.Discard()), - false}, {"Values / With fields", "INSERT INTO test (a, b) VALUES ('c', 'd')", stream.New(rows.Emit( - &expr.KVPairs{Pairs: []expr.KVPair{ - {K: "a", V: testutil.TextValue("c")}, - {K: "b", V: testutil.TextValue("d")}, - }}, + expr.Row{ + Columns: []string{"a", "b"}, + Exprs: []expr.Expr{ + testutil.TextValue("c"), + testutil.TextValue("d"), + }, + }, )). Pipe(table.Validate("test")). Pipe(table.Insert("test")). @@ -85,14 +42,20 @@ func TestParserInsert(t *testing.T) { nil, true}, {"Values / Multiple", "INSERT INTO test (a, b) VALUES ('c', 'd'), ('e', 'f')", stream.New(rows.Emit( - &expr.KVPairs{Pairs: []expr.KVPair{ - {K: "a", V: testutil.TextValue("c")}, - {K: "b", V: testutil.TextValue("d")}, - }}, - &expr.KVPairs{Pairs: []expr.KVPair{ - {K: "a", V: testutil.TextValue("e")}, - {K: "b", V: testutil.TextValue("f")}, - }}, + expr.Row{ + Columns: []string{"a", "b"}, + Exprs: []expr.Expr{ + testutil.TextValue("c"), + testutil.TextValue("d"), + }, + }, + expr.Row{ + Columns: []string{"a", "b"}, + Exprs: []expr.Expr{ + testutil.TextValue("e"), + testutil.TextValue("f"), + }, + }, )). Pipe(table.Validate("test")). Pipe(table.Insert("test")). 
@@ -100,10 +63,13 @@ func TestParserInsert(t *testing.T) { false}, {"Values / Returning", "INSERT INTO test (a, b) VALUES ('c', 'd') RETURNING *, a, b as B, c", stream.New(rows.Emit( - &expr.KVPairs{Pairs: []expr.KVPair{ - {K: "a", V: testutil.TextValue("c")}, - {K: "b", V: testutil.TextValue("d")}, - }}, + expr.Row{ + Columns: []string{"a", "b"}, + Exprs: []expr.Expr{ + testutil.TextValue("c"), + testutil.TextValue("d"), + }, + }, )). Pipe(table.Validate("test")). Pipe(table.Insert("test")). @@ -115,46 +81,62 @@ func TestParserInsert(t *testing.T) { nil, true}, {"Values / ON CONFLICT DO NOTHING", "INSERT INTO test (a, b) VALUES ('c', 'd') ON CONFLICT DO NOTHING RETURNING *", stream.New(rows.Emit( - &expr.KVPairs{Pairs: []expr.KVPair{ - {K: "a", V: testutil.TextValue("c")}, - {K: "b", V: testutil.TextValue("d")}, - }}, + expr.Row{ + Columns: []string{"a", "b"}, + Exprs: []expr.Expr{ + testutil.TextValue("c"), + testutil.TextValue("d"), + }, + }, )). Pipe(table.Validate("test")). Pipe(stream.OnConflict(nil)). - Pipe(table.Insert("test")), + Pipe(table.Insert("test")). + Pipe(rows.Project(expr.Wildcard{})), false}, {"Values / ON CONFLICT IGNORE", "INSERT INTO test (a, b) VALUES ('c', 'd') ON CONFLICT IGNORE RETURNING *", stream.New(rows.Emit( - &expr.KVPairs{Pairs: []expr.KVPair{ - {K: "a", V: testutil.TextValue("c")}, - {K: "b", V: testutil.TextValue("d")}, - }}, + expr.Row{ + Columns: []string{"a", "b"}, + Exprs: []expr.Expr{ + testutil.TextValue("c"), + testutil.TextValue("d"), + }, + }, )).Pipe(table.Validate("test")). Pipe(stream.OnConflict(nil)). - Pipe(table.Insert("test")), + Pipe(table.Insert("test")). + Pipe(rows.Project(expr.Wildcard{})), false}, {"Values / ON CONFLICT DO REPLACE", "INSERT INTO test (a, b) VALUES ('c', 'd') ON CONFLICT DO REPLACE RETURNING *", stream.New(rows.Emit( - &expr.KVPairs{Pairs: []expr.KVPair{ - {K: "a", V: testutil.TextValue("c")}, - {K: "b", V: testutil.TextValue("d")}, - }}, + expr.Row{ + Columns: []string{"a", "b"}, + Exprs: []expr.Expr{ + testutil.TextValue("c"), + testutil.TextValue("d"), + }, + }, )). Pipe(table.Validate("test")). Pipe(stream.OnConflict(stream.New(table.Replace("test")))). - Pipe(table.Insert("test")), + Pipe(table.Insert("test")). + Pipe(rows.Project(expr.Wildcard{})), false}, {"Values / ON CONFLICT REPLACE", "INSERT INTO test (a, b) VALUES ('c', 'd') ON CONFLICT REPLACE RETURNING *", stream.New(rows.Emit( - &expr.KVPairs{Pairs: []expr.KVPair{ - {K: "a", V: testutil.TextValue("c")}, - {K: "b", V: testutil.TextValue("d")}, - }}, + expr.Row{ + Columns: []string{"a", "b"}, + Exprs: []expr.Expr{ + testutil.TextValue("c"), + testutil.TextValue("d"), + }, + }, )). Pipe(table.Validate("test")). Pipe(stream.OnConflict(stream.New(table.Replace("test")))). - Pipe(table.Insert("test")), + Pipe(table.Insert("test")). + Pipe(rows.Project(expr.Wildcard{})), false}, {"Values / ON CONFLICT BLA", "INSERT INTO test (a, b) VALUES ('c', 'd') ON CONFLICT BLA RETURNING *", nil, true}, @@ -162,6 +144,7 @@ func TestParserInsert(t *testing.T) { nil, true}, {"Select / Without fields", "INSERT INTO test SELECT * FROM foo", stream.New(table.Scan("foo")). + Pipe(rows.Project(expr.Wildcard{})). Pipe(table.Validate("test")). Pipe(table.Insert("test")). Pipe(stream.Discard()), @@ -175,6 +158,7 @@ func TestParserInsert(t *testing.T) { false}, {"Select / With fields", "INSERT INTO test (a, b) SELECT * FROM foo", stream.New(table.Scan("foo")). + Pipe(rows.Project(expr.Wildcard{})). Pipe(path.PathsRename("a", "b")). Pipe(table.Validate("test")). 
Pipe(table.Insert("test")). @@ -219,7 +203,7 @@ func TestParserInsert(t *testing.T) { t.Run(test.name, func(t *testing.T) { db := testutil.NewTestDB(t) - testutil.MustExec(t, db, nil, "CREATE TABLE test; CREATE TABLE foo;") + testutil.MustExec(t, db, nil, "CREATE TABLE test(a TEXT, b TEXT); CREATE TABLE foo(c TEXT, d TEXT);") q, err := parser.ParseQuery(test.s) if test.fails { diff --git a/internal/sql/parser/options.go b/internal/sql/parser/options.go deleted file mode 100644 index 2879601ea..000000000 --- a/internal/sql/parser/options.go +++ /dev/null @@ -1,17 +0,0 @@ -package parser - -import ( - "github.com/chaisql/chai/internal/expr/functions" -) - -// Options of the SQL parser. -type Options struct { - // A table of function packages. - Packages functions.Packages -} - -func defaultOptions() *Options { - return &Options{ - Packages: functions.DefaultPackages(), - } -} diff --git a/internal/sql/parser/order_by.go b/internal/sql/parser/order_by.go index 987b8e6da..42c2e2866 100644 --- a/internal/sql/parser/order_by.go +++ b/internal/sql/parser/order_by.go @@ -7,26 +7,26 @@ import ( "github.com/chaisql/chai/internal/sql/scanner" ) -func (p *Parser) parseOrderBy() (expr.Path, scanner.Token, error) { +func (p *Parser) parseOrderBy() (expr.Column, scanner.Token, error) { // parse ORDER token ok, err := p.parseOptional(scanner.ORDER, scanner.BY) if err != nil || !ok { - return nil, 0, err + return "", 0, err } - // parse path - path, err := p.parsePath() + // parse col + col, err := p.parseColumn() if err != nil { - return nil, 0, err + return "", 0, err } // parse optional ASC or DESC if tok, _, _ := p.ScanIgnoreWhitespace(); tok == scanner.ASC || tok == scanner.DESC { - return expr.Path(path), tok, nil + return col, tok, nil } p.Unscan() - return expr.Path(path), 0, nil + return col, 0, nil } func (p *Parser) parseLimit() (expr.Expr, error) { diff --git a/internal/sql/parser/parser.go b/internal/sql/parser/parser.go index 13b8112a5..b2ebd0d55 100644 --- a/internal/sql/parser/parser.go +++ b/internal/sql/parser/parser.go @@ -6,8 +6,6 @@ import ( "strings" "github.com/chaisql/chai/internal/expr" - "github.com/chaisql/chai/internal/expr/functions" - "github.com/chaisql/chai/internal/object" "github.com/chaisql/chai/internal/query" "github.com/chaisql/chai/internal/query/statement" "github.com/chaisql/chai/internal/sql/scanner" @@ -20,21 +18,11 @@ type Parser struct { s *scanner.Scanner orderedParams int namedParams int - packagesTable functions.Packages } // NewParser returns a new instance of Parser. func NewParser(r io.Reader) *Parser { - return NewParserWithOptions(r, nil) -} - -// NewParserWithOptions returns a new instance of Parser using given Options. -func NewParserWithOptions(r io.Reader, opts *Options) *Parser { - if opts == nil { - opts = defaultOptions() - } - - return &Parser{s: scanner.NewScanner(r), packagesTable: opts.Packages} + return &Parser{s: scanner.NewScanner(r)} } // ParseQuery parses a query string and returns its AST representation. @@ -42,11 +30,6 @@ func ParseQuery(s string) (query.Query, error) { return NewParser(strings.NewReader(s)).ParseQuery() } -// ParsePath parses a path to a value in an object. -func ParsePath(s string) (object.Path, error) { - return NewParser(strings.NewReader(s)).parsePath() -} - // ParseExpr parses an expression. 
func ParseExpr(s string) (expr.Expr, error) { e, err := NewParser(strings.NewReader(s)).ParseExpr() @@ -177,24 +160,24 @@ func (p *Parser) parseCondition() (expr.Expr, error) { return expr, nil } -// parsePathList parses a list of paths in the form: (path, path, ...), if exists -func (p *Parser) parsePathList() ([]object.Path, tree.SortOrder, error) { +// parseColumnList parses a list of columns in the form: (path, path, ...), if exists +func (p *Parser) parseColumnList() ([]string, tree.SortOrder, error) { // Parse ( token. if ok, err := p.parseOptional(scanner.LPAREN); !ok || err != nil { return nil, 0, err } - var paths []object.Path + var columns []string var err error - var path object.Path + var col string var order tree.SortOrder - // Parse first (required) path. - if path, err = p.parsePath(); err != nil { + // Parse first (required) column. + if col, err = p.parseIdent(); err != nil { return nil, 0, err } - paths = append(paths, path) + columns = append(columns, col) // Parse optional ASC/DESC token. ok, err := p.parseOptional(scanner.DESC) @@ -211,7 +194,7 @@ func (p *Parser) parsePathList() ([]object.Path, tree.SortOrder, error) { } } - // Parse remaining (optional) paths. + // Parse remaining (optional) columns. i := 0 for { if tok, _, _ := p.ScanIgnoreWhitespace(); tok != scanner.COMMA { @@ -219,12 +202,12 @@ func (p *Parser) parsePathList() ([]object.Path, tree.SortOrder, error) { break } - vp, err := p.parsePath() + c, err := p.parseIdent() if err != nil { return nil, 0, err } - paths = append(paths, vp) + columns = append(columns, c) i++ @@ -245,11 +228,11 @@ func (p *Parser) parsePathList() ([]object.Path, tree.SortOrder, error) { } // Parse required ) token. - if err := p.parseTokens(scanner.RPAREN); err != nil { + if err := p.ParseTokens(scanner.RPAREN); err != nil { return nil, 0, err } - return paths, order, nil + return columns, order, nil } // Scan returns the next token from the underlying scanner. @@ -271,9 +254,9 @@ func (p *Parser) Unscan() { p.s.Unscan() } -// parseTokens parses all the given tokens one after the other. +// ParseTokens parses all the given tokens one after the other. // It returns an error if one of the token is missing. -func (p *Parser) parseTokens(tokens ...scanner.Token) error { +func (p *Parser) ParseTokens(tokens ...scanner.Token) error { for _, t := range tokens { if tok, pos, lit := p.ScanIgnoreWhitespace(); tok != t { return newParseError(scanner.Tokstr(tok, lit), []string{t.String()}, pos) @@ -297,7 +280,7 @@ func (p *Parser) parseOptional(tokens ...scanner.Token) (bool, error) { return true, nil } - err := p.parseTokens(tokens[1:]...) + err := p.ParseTokens(tokens[1:]...) return err == nil, err } diff --git a/internal/sql/parser/reindex.go b/internal/sql/parser/reindex.go index 907240af0..3fc538c2c 100644 --- a/internal/sql/parser/reindex.go +++ b/internal/sql/parser/reindex.go @@ -10,7 +10,7 @@ func (p *Parser) parseReIndexStatement() (statement.Statement, error) { stmt := statement.NewReIndexStatement() // Parse "REINDEX". - if err := p.parseTokens(scanner.REINDEX); err != nil { + if err := p.ParseTokens(scanner.REINDEX); err != nil { return nil, err } diff --git a/internal/sql/parser/select.go b/internal/sql/parser/select.go index 398264a15..5977d0cc2 100644 --- a/internal/sql/parser/select.go +++ b/internal/sql/parser/select.go @@ -7,7 +7,7 @@ import ( "github.com/cockroachdb/errors" ) -// parseSelectStatement parses a select string and returns a Statement AST object. 
+// parseSelectStatement parses a select string and returns a Statement AST row. // This function assumes the SELECT token has already been consumed. func (p *Parser) parseSelectStatement() (*statement.SelectStmt, error) { stmt := statement.NewSelectStatement() @@ -76,7 +76,7 @@ func (p *Parser) parseSelectCore() (*statement.SelectCoreStmt, error) { var err error // Parse "SELECT". - if err := p.parseTokens(scanner.SELECT); err != nil { + if err := p.ParseTokens(scanner.SELECT); err != nil { return nil, err } diff --git a/internal/sql/parser/select_test.go b/internal/sql/parser/select_test.go index 172c2e74d..f795d1719 100644 --- a/internal/sql/parser/select_test.go +++ b/internal/sql/parser/select_test.go @@ -6,7 +6,6 @@ import ( "github.com/chaisql/chai/internal/expr" "github.com/chaisql/chai/internal/expr/functions" - "github.com/chaisql/chai/internal/object" "github.com/chaisql/chai/internal/query" "github.com/chaisql/chai/internal/query/statement" "github.com/chaisql/chai/internal/sql/parser" @@ -30,23 +29,16 @@ func TestParserSelect(t *testing.T) { stream.New(rows.Project(testutil.ParseNamedExpr(t, "1"))), true, false, }, - {"NoTableWithTuple", "SELECT (1, 2)", - stream.New(rows.Project(testutil.ParseNamedExpr(t, "[1, 2]"))), - true, false, - }, - {"NoTableWithBrackets", "SELECT [1, 2]", - stream.New(rows.Project(testutil.ParseNamedExpr(t, "[1, 2]"))), - true, false, - }, {"NoTableWithINOperator", "SELECT 1 in (1, 2), 3", stream.New(rows.Project( - testutil.ParseNamedExpr(t, "1 IN [1, 2]"), + testutil.ParseNamedExpr(t, "1 IN (1, 2)"), testutil.ParseNamedExpr(t, "3"), )), true, false, }, {"NoCond", "SELECT * FROM test", - stream.New(table.Scan("test")), + stream.New(table.Scan("test")).Pipe(rows.Project(expr.Wildcard{})), + true, false, }, {"Multiple Wildcards", "SELECT *, * FROM test", @@ -57,10 +49,6 @@ func TestParserSelect(t *testing.T) { stream.New(table.Scan("test")).Pipe(rows.Project(testutil.ParseNamedExpr(t, "a"), testutil.ParseNamedExpr(t, "b"))), true, false, }, - {"WithFieldsWithQuotes", "SELECT `long \"path\"` FROM test", - stream.New(table.Scan("test")).Pipe(rows.Project(testutil.ParseNamedExpr(t, "`long \"path\"`", "long \"path\""))), - true, false, - }, {"WithAlias", "SELECT a AS A, b FROM test", stream.New(table.Scan("test")).Pipe(rows.Project(testutil.ParseNamedExpr(t, "a", "A"), testutil.ParseNamedExpr(t, "b"))), true, false, @@ -75,50 +63,57 @@ func TestParserSelect(t *testing.T) { }, {"WithCond", "SELECT * FROM test WHERE age = 10", stream.New(table.Scan("test")). - Pipe(rows.Filter(parser.MustParseExpr("age = 10"))), + Pipe(rows.Filter(parser.MustParseExpr("age = 10"))). + Pipe(rows.Project(expr.Wildcard{})), true, false, }, - {"WithGroupBy", "SELECT a.b.c FROM test WHERE age = 10 GROUP BY a.b.c", + {"WithGroupBy", "SELECT a FROM test WHERE age = 10 GROUP BY a", stream.New(table.Scan("test")). Pipe(rows.Filter(parser.MustParseExpr("age = 10"))). - Pipe(rows.TempTreeSort(parser.MustParseExpr("a.b.c"))). - Pipe(rows.GroupAggregate(parser.MustParseExpr("a.b.c"))). - Pipe(rows.Project(&expr.NamedExpr{ExprName: "a.b.c", Expr: expr.Path(object.NewPath("a.b.c"))})), + Pipe(rows.TempTreeSort(parser.MustParseExpr("a"))). + Pipe(rows.GroupAggregate(parser.MustParseExpr("a"))). + Pipe(rows.Project(&expr.NamedExpr{ExprName: "a", Expr: expr.Column("a")})), true, false, }, - {"WithOrderBy", "SELECT * FROM test WHERE age = 10 ORDER BY a.b.c", + {"WithOrderBy", "SELECT * FROM test WHERE age = 10 ORDER BY a", stream.New(table.Scan("test")). 
Pipe(rows.Filter(parser.MustParseExpr("age = 10"))). - Pipe(rows.TempTreeSort(testutil.ParsePath(t, "a.b.c"))), + Pipe(rows.Project(expr.Wildcard{})). + Pipe(rows.TempTreeSort(expr.Column("a"))), true, false, }, - {"WithOrderBy ASC", "SELECT * FROM test WHERE age = 10 ORDER BY a.b.c ASC", + {"WithOrderBy ASC", "SELECT * FROM test WHERE age = 10 ORDER BY a ASC", stream.New(table.Scan("test")). Pipe(rows.Filter(parser.MustParseExpr("age = 10"))). - Pipe(rows.TempTreeSort(testutil.ParsePath(t, "a.b.c"))), + Pipe(rows.Project(expr.Wildcard{})). + Pipe(rows.TempTreeSort(expr.Column("a"))), true, false, }, - {"WithOrderBy DESC", "SELECT * FROM test WHERE age = 10 ORDER BY a.b.c DESC", + {"WithOrderBy DESC", "SELECT * FROM test WHERE age = 10 ORDER BY a DESC", stream.New(table.Scan("test")). Pipe(rows.Filter(parser.MustParseExpr("age = 10"))). - Pipe(rows.TempTreeSortReverse(testutil.ParsePath(t, "a.b.c"))), + Pipe(rows.Project(expr.Wildcard{})). + Pipe(rows.TempTreeSortReverse(expr.Column("a"))), true, false, }, {"WithLimit", "SELECT * FROM test WHERE age = 10 LIMIT 20", stream.New(table.Scan("test")). Pipe(rows.Filter(parser.MustParseExpr("age = 10"))). + Pipe(rows.Project(expr.Wildcard{})). Pipe(rows.Take(parser.MustParseExpr("20"))), true, false, }, {"WithOffset", "SELECT * FROM test WHERE age = 10 OFFSET 20", stream.New(table.Scan("test")). Pipe(rows.Filter(parser.MustParseExpr("age = 10"))). + Pipe(rows.Project(expr.Wildcard{})). Pipe(rows.Skip(parser.MustParseExpr("20"))), true, false, }, {"WithLimitThenOffset", "SELECT * FROM test WHERE age = 10 LIMIT 10 OFFSET 20", stream.New(table.Scan("test")). Pipe(rows.Filter(parser.MustParseExpr("age = 10"))). + Pipe(rows.Project(expr.Wildcard{})). Pipe(rows.Skip(parser.MustParseExpr("20"))). Pipe(rows.Take(parser.MustParseExpr("10"))), true, false, @@ -135,16 +130,18 @@ func TestParserSelect(t *testing.T) { false, false}, {"WithUnionAll", "SELECT * FROM test1 UNION ALL SELECT * FROM test2", stream.New(stream.Concat( - stream.New(table.Scan("test1")), - stream.New(table.Scan("test2")), + stream.New(table.Scan("test1")).Pipe(rows.Project(expr.Wildcard{})), + stream.New(table.Scan("test2")).Pipe(rows.Project(expr.Wildcard{})), )), true, false, }, {"CondWithUnionAll", "SELECT * FROM test1 WHERE age = 10 UNION ALL SELECT * FROM test2", stream.New(stream.Concat( stream.New(table.Scan("test1")). - Pipe(rows.Filter(parser.MustParseExpr("age = 10"))), - stream.New(table.Scan("test2")), + Pipe(rows.Filter(parser.MustParseExpr("age = 10"))). + Pipe(rows.Project(expr.Wildcard{})), + stream.New(table.Scan("test2")). + Pipe(rows.Project(expr.Wildcard{})), )), true, false, }, @@ -162,45 +159,57 @@ func TestParserSelect(t *testing.T) { }, {"WithUnionAllAndOrderBy", "SELECT * FROM test1 UNION ALL SELECT * FROM test2 ORDER BY a", stream.New(stream.Concat( - stream.New(table.Scan("test1")), - stream.New(table.Scan("test2")), - )).Pipe(rows.TempTreeSort(testutil.ParsePath(t, "a"))), + stream.New(table.Scan("test1")). + Pipe(rows.Project(expr.Wildcard{})), + stream.New(table.Scan("test2")). + Pipe(rows.Project(expr.Wildcard{})), + )).Pipe(rows.TempTreeSort(expr.Column("a"))), true, false, }, {"WithUnionAllAndLimit", "SELECT * FROM test1 UNION ALL SELECT * FROM test2 LIMIT 10", stream.New(stream.Concat( - stream.New(table.Scan("test1")), - stream.New(table.Scan("test2")), + stream.New(table.Scan("test1")). + Pipe(rows.Project(expr.Wildcard{})), + stream.New(table.Scan("test2")). 
+ Pipe(rows.Project(expr.Wildcard{})), )).Pipe(rows.Take(parser.MustParseExpr("10"))), true, false, }, {"WithUnionAllAndOffset", "SELECT * FROM test1 UNION ALL SELECT * FROM test2 OFFSET 20", stream.New(stream.Concat( - stream.New(table.Scan("test1")), - stream.New(table.Scan("test2")), + stream.New(table.Scan("test1")). + Pipe(rows.Project(expr.Wildcard{})), + stream.New(table.Scan("test2")). + Pipe(rows.Project(expr.Wildcard{})), )).Pipe(rows.Skip(parser.MustParseExpr("20"))), true, false, }, {"WithUnionAllAndOrderByAndLimitAndOffset", "SELECT * FROM test1 UNION ALL SELECT * FROM test2 ORDER BY a LIMIT 10 OFFSET 20", stream.New(stream.Concat( - stream.New(table.Scan("test1")), - stream.New(table.Scan("test2")), - )).Pipe(rows.TempTreeSort(testutil.ParsePath(t, "a"))).Pipe(rows.Skip(parser.MustParseExpr("20"))).Pipe(rows.Take(parser.MustParseExpr("10"))), + stream.New(table.Scan("test1")). + Pipe(rows.Project(expr.Wildcard{})), + stream.New(table.Scan("test2")). + Pipe(rows.Project(expr.Wildcard{})), + )).Pipe(rows.TempTreeSort(expr.Column("a"))).Pipe(rows.Skip(parser.MustParseExpr("20"))).Pipe(rows.Take(parser.MustParseExpr("10"))), true, false, }, {"WithUnion", "SELECT * FROM test1 UNION SELECT * FROM test2", stream.New(stream.Union( - stream.New(table.Scan("test1")), - stream.New(table.Scan("test2")), + stream.New(table.Scan("test1")). + Pipe(rows.Project(expr.Wildcard{})), + stream.New(table.Scan("test2")). + Pipe(rows.Project(expr.Wildcard{})), )), true, false, }, {"CondWithUnion", "SELECT * FROM test1 WHERE age = 10 UNION SELECT * FROM test2", stream.New(stream.Union( stream.New(table.Scan("test1")). - Pipe(rows.Filter(parser.MustParseExpr("age = 10"))), - stream.New(table.Scan("test2")), + Pipe(rows.Filter(parser.MustParseExpr("age = 10"))). + Pipe(rows.Project(expr.Wildcard{})), + stream.New(table.Scan("test2")). + Pipe(rows.Project(expr.Wildcard{})), )), true, false, }, @@ -218,69 +227,91 @@ func TestParserSelect(t *testing.T) { }, {"WithUnionAndOrderBy", "SELECT * FROM test1 UNION SELECT * FROM test2 ORDER BY a", stream.New(stream.Union( - stream.New(table.Scan("test1")), - stream.New(table.Scan("test2")), - )).Pipe(rows.TempTreeSort(testutil.ParsePath(t, "a"))), + stream.New(table.Scan("test1")). + Pipe(rows.Project(expr.Wildcard{})), + stream.New(table.Scan("test2")). + Pipe(rows.Project(expr.Wildcard{})), + )).Pipe(rows.TempTreeSort(expr.Column("a"))), true, false, }, {"WithUnionAndLimit", "SELECT * FROM test1 UNION SELECT * FROM test2 LIMIT 10", stream.New(stream.Union( - stream.New(table.Scan("test1")), - stream.New(table.Scan("test2")), + stream.New(table.Scan("test1")). + Pipe(rows.Project(expr.Wildcard{})), + stream.New(table.Scan("test2")). + Pipe(rows.Project(expr.Wildcard{})), )).Pipe(rows.Take(parser.MustParseExpr("10"))), true, false, }, {"WithUnionAndOffset", "SELECT * FROM test1 UNION SELECT * FROM test2 OFFSET 20", stream.New(stream.Union( - stream.New(table.Scan("test1")), - stream.New(table.Scan("test2")), + stream.New(table.Scan("test1")). + Pipe(rows.Project(expr.Wildcard{})), + stream.New(table.Scan("test2")). 
+ Pipe(rows.Project(expr.Wildcard{})), )).Pipe(rows.Skip(parser.MustParseExpr("20"))), true, false, }, {"WithUnionAndOrderByAndLimitAndOffset", "SELECT * FROM test1 UNION SELECT * FROM test2 ORDER BY a LIMIT 10 OFFSET 20", stream.New(stream.Union( - stream.New(table.Scan("test1")), - stream.New(table.Scan("test2")), - )).Pipe(rows.TempTreeSort(testutil.ParsePath(t, "a"))).Pipe(rows.Skip(parser.MustParseExpr("20"))).Pipe(rows.Take(parser.MustParseExpr("10"))), + stream.New(table.Scan("test1")). + Pipe(rows.Project(expr.Wildcard{})), + stream.New(table.Scan("test2")). + Pipe(rows.Project(expr.Wildcard{})), + )).Pipe(rows.TempTreeSort(expr.Column("a"))).Pipe(rows.Skip(parser.MustParseExpr("20"))).Pipe(rows.Take(parser.MustParseExpr("10"))), true, false, }, {"WithMultipleCompoundOps/1", "SELECT * FROM a UNION ALL SELECT * FROM b UNION ALL SELECT * FROM c", stream.New(stream.Concat( - stream.New(table.Scan("a")), - stream.New(table.Scan("b")), - stream.New(table.Scan("c")), + stream.New(table.Scan("a")). + Pipe(rows.Project(expr.Wildcard{})), + stream.New(table.Scan("b")). + Pipe(rows.Project(expr.Wildcard{})), + stream.New(table.Scan("c")). + Pipe(rows.Project(expr.Wildcard{})), )), true, false, }, {"WithMultipleCompoundOps/2", "SELECT * FROM a UNION ALL SELECT * FROM b UNION SELECT * FROM c", stream.New(stream.Union( stream.New(stream.Concat( - stream.New(table.Scan("a")), - stream.New(table.Scan("b")), + stream.New(table.Scan("a")). + Pipe(rows.Project(expr.Wildcard{})), + stream.New(table.Scan("b")). + Pipe(rows.Project(expr.Wildcard{})), )), - stream.New(table.Scan("c")), + stream.New(table.Scan("c")). + Pipe(rows.Project(expr.Wildcard{})), )), true, false, }, {"WithMultipleCompoundOps/2", "SELECT * FROM a UNION ALL SELECT * FROM b UNION ALL SELECT * FROM c UNION SELECT * FROM d", stream.New(stream.Union( stream.New(stream.Concat( - stream.New(table.Scan("a")), - stream.New(table.Scan("b")), - stream.New(table.Scan("c")), + stream.New(table.Scan("a")). + Pipe(rows.Project(expr.Wildcard{})), + stream.New(table.Scan("b")). + Pipe(rows.Project(expr.Wildcard{})), + stream.New(table.Scan("c")). + Pipe(rows.Project(expr.Wildcard{})), )), - stream.New(table.Scan("d")), + stream.New(table.Scan("d")). + Pipe(rows.Project(expr.Wildcard{})), )), true, false, }, {"WithMultipleCompoundOps/3", "SELECT * FROM a UNION ALL SELECT * FROM b UNION SELECT * FROM c UNION SELECT * FROM d", stream.New(stream.Union( stream.New(stream.Concat( - stream.New(table.Scan("a")), - stream.New(table.Scan("b")), + stream.New(table.Scan("a")). + Pipe(rows.Project(expr.Wildcard{})), + stream.New(table.Scan("b")). + Pipe(rows.Project(expr.Wildcard{})), )), - stream.New(table.Scan("c")), - stream.New(table.Scan("d")), + stream.New(table.Scan("c")). + Pipe(rows.Project(expr.Wildcard{})), + stream.New(table.Scan("d")). + Pipe(rows.Project(expr.Wildcard{})), )), true, false, }, @@ -288,12 +319,16 @@ func TestParserSelect(t *testing.T) { stream.New(stream.Concat( stream.New(stream.Union( stream.New(stream.Concat( - stream.New(table.Scan("a")), - stream.New(table.Scan("b")), + stream.New(table.Scan("a")). + Pipe(rows.Project(expr.Wildcard{})), + stream.New(table.Scan("b")). + Pipe(rows.Project(expr.Wildcard{})), )), - stream.New(table.Scan("c")), + stream.New(table.Scan("c")). + Pipe(rows.Project(expr.Wildcard{})), )), - stream.New(table.Scan("d")), + stream.New(table.Scan("d")). 
+ Pipe(rows.Project(expr.Wildcard{})), )), true, false, }, @@ -301,10 +336,13 @@ func TestParserSelect(t *testing.T) { stream.New(stream.Concat( stream.New(stream.Union( stream.New(stream.Concat( - stream.New(table.Scan("a")), - stream.New(table.Scan("b")), + stream.New(table.Scan("a")). + Pipe(rows.Project(expr.Wildcard{})), + stream.New(table.Scan("b")). + Pipe(rows.Project(expr.Wildcard{})), )), - stream.New(table.Scan("c")), + stream.New(table.Scan("c")). + Pipe(rows.Project(expr.Wildcard{})), )), stream.New(table.Scan("d")).Pipe(rows.Project(testutil.ParseNamedExpr(t, "NEXT VALUE FOR foo"))), )), @@ -319,13 +357,13 @@ func TestParserSelect(t *testing.T) { db := testutil.NewTestDB(t) testutil.MustExec(t, db, nil, ` - CREATE TABLE test; - CREATE TABLE test1; - CREATE TABLE test2; - CREATE TABLE a; - CREATE TABLE b; - CREATE TABLE c; - CREATE TABLE d; + CREATE TABLE test(a TEXT, b TEXT, age int); + CREATE TABLE test1(age INT, a INT); + CREATE TABLE test2(age INT, a INT); + CREATE TABLE a(age INT, a INT); + CREATE TABLE b(age INT, a INT); + CREATE TABLE c(age INT, a INT); + CREATE TABLE d(age INT, a INT); `, ) @@ -346,6 +384,6 @@ func TestParserSelect(t *testing.T) { func BenchmarkSelect(b *testing.B) { for i := 0; i < b.N; i++ { - _, _ = parser.ParseQuery("SELECT a, b.c[100].d AS `foo` FROM `some table` WHERE d.e[100] >= 12 AND c.d IN ([1, true], [2, false]) GROUP BY d.e[0] LIMIT 10 + 10 OFFSET 20 - 20 ORDER BY d DESC") + _, _ = parser.ParseQuery("SELECT a, b AS `foo` FROM `some table` WHERE d.e[100] >= 12 AND c.d IN ([1, true], [2, false]) GROUP BY d.e[0] LIMIT 10 + 10 OFFSET 20 - 20 ORDER BY d DESC") } } diff --git a/internal/sql/parser/transaction.go b/internal/sql/parser/transaction.go index b90b5d43a..460fab2ed 100644 --- a/internal/sql/parser/transaction.go +++ b/internal/sql/parser/transaction.go @@ -9,7 +9,7 @@ import ( // parseBeginStatement parses a BEGIN statement. func (p *Parser) parseBeginStatement() (statement.Statement, error) { // Parse "BEGIN". - if err := p.parseTokens(scanner.BEGIN); err != nil { + if err := p.ParseTokens(scanner.BEGIN); err != nil { return nil, err } @@ -40,7 +40,7 @@ func (p *Parser) parseBeginStatement() (statement.Statement, error) { // parseRollbackStatement parses a ROLLBACK statement. func (p *Parser) parseRollbackStatement() (statement.Statement, error) { // Parse "ROLLBACK". - if err := p.parseTokens(scanner.ROLLBACK); err != nil { + if err := p.ParseTokens(scanner.ROLLBACK); err != nil { return nil, err } @@ -53,7 +53,7 @@ func (p *Parser) parseRollbackStatement() (statement.Statement, error) { // parseCommitStatement parses a COMMIT statement. func (p *Parser) parseCommitStatement() (statement.Statement, error) { // Parse "COMMIT". - if err := p.parseTokens(scanner.COMMIT); err != nil { + if err := p.ParseTokens(scanner.COMMIT); err != nil { return nil, err } // parse optional TRANSACTION token diff --git a/internal/sql/parser/update.go b/internal/sql/parser/update.go index 40eab0de8..f8ec3b3c1 100644 --- a/internal/sql/parser/update.go +++ b/internal/sql/parser/update.go @@ -6,13 +6,13 @@ import ( "github.com/cockroachdb/errors" ) -// parseUpdateStatement parses a update string and returns a Statement AST object. +// parseUpdateStatement parses a update string and returns a Statement AST row. func (p *Parser) parseUpdateStatement() (*statement.UpdateStmt, error) { stmt := statement.NewUpdateStatement() var err error // Parse "UPDATE". 
- if err := p.parseTokens(scanner.UPDATE); err != nil { + if err := p.ParseTokens(scanner.UPDATE); err != nil { return nil, err } @@ -24,15 +24,13 @@ func (p *Parser) parseUpdateStatement() (*statement.UpdateStmt, error) { return nil, pErr } - // Parse clause: SET or UNSET. + // Parse clause: SET. tok, pos, lit := p.ScanIgnoreWhitespace() switch tok { case scanner.SET: stmt.SetPairs, err = p.parseSetClause() - case scanner.UNSET: - stmt.UnsetFields, err = p.parseUnsetClause() default: - err = newParseError(scanner.Tokstr(tok, lit), []string{"SET", "UNSET"}, pos) + err = newParseError(scanner.Tokstr(tok, lit), []string{"SET"}, pos) } if err != nil { return nil, err @@ -62,8 +60,8 @@ func (p *Parser) parseSetClause() ([]statement.UpdateSetPair, error) { } } - // Scan the identifier for the path name. - path, err := p.parsePath() + // Scan the identifier for the col name. + col, err := p.parseColumn() if err != nil { pErr := errors.Unwrap(err).(*ParseError) pErr.Expected = []string{"path"} @@ -71,7 +69,7 @@ func (p *Parser) parseSetClause() ([]statement.UpdateSetPair, error) { } // Scan the eq sign - if err := p.parseTokens(scanner.EQ); err != nil { + if err := p.ParseTokens(scanner.EQ); err != nil { return nil, err } @@ -80,36 +78,10 @@ func (p *Parser) parseSetClause() ([]statement.UpdateSetPair, error) { if err != nil { return nil, err } - pairs = append(pairs, statement.UpdateSetPair{Path: path, E: expr}) + pairs = append(pairs, statement.UpdateSetPair{Column: col, E: expr}) firstPair = false } return pairs, nil } - -func (p *Parser) parseUnsetClause() ([]string, error) { - var fields []string - - firstField := true - for { - if !firstField { - // Scan for a comma. - tok, _, _ := p.ScanIgnoreWhitespace() - if tok != scanner.COMMA { - p.Unscan() - break - } - } - - // Scan the identifier for the path to unset. - lit, err := p.parseIdent() - if err != nil { - return nil, err - } - fields = append(fields, lit) - - firstField = false - } - return fields, nil -} diff --git a/internal/sql/parser/update_test.go b/internal/sql/parser/update_test.go index 20e2b77bc..d88a63602 100644 --- a/internal/sql/parser/update_test.go +++ b/internal/sql/parser/update_test.go @@ -4,7 +4,6 @@ import ( "context" "testing" - "github.com/chaisql/chai/internal/object" "github.com/chaisql/chai/internal/query" "github.com/chaisql/chai/internal/query/statement" "github.com/chaisql/chai/internal/sql/parser" @@ -26,68 +25,34 @@ func TestParserUpdate(t *testing.T) { }{ {"SET/No cond", "UPDATE test SET a = 1", stream.New(table.Scan("test")). - Pipe(path.Set(object.Path(testutil.ParsePath(t, "a")), testutil.IntegerValue(1))). + Pipe(path.Set("a", testutil.IntegerValue(1))). Pipe(table.Validate("test")). Pipe(table.Replace("test")). Pipe(stream.Discard()), false, }, - {"SET/With cond", "UPDATE test SET a = 1, b = 2 WHERE age = 10", + {"SET/With cond", "UPDATE test SET a = 1, b = 2 WHERE a = 10", stream.New(table.Scan("test")). - Pipe(rows.Filter(parser.MustParseExpr("age = 10"))). - Pipe(path.Set(object.Path(testutil.ParsePath(t, "a")), testutil.IntegerValue(1))). - Pipe(path.Set(object.Path(testutil.ParsePath(t, "b")), parser.MustParseExpr("2"))). + Pipe(rows.Filter(parser.MustParseExpr("a = 10"))). + Pipe(path.Set("a", testutil.IntegerValue(1))). + Pipe(path.Set("b", parser.MustParseExpr("2"))). Pipe(table.Validate("test")). Pipe(table.Replace("test")). Pipe(stream.Discard()), false, }, - {"SET/No cond path with backquotes", "UPDATE test SET ` some \"path\" ` = 1", - stream.New(table.Scan("test")). 
- Pipe(path.Set(object.Path(testutil.ParsePath(t, "` some \"path\" `")), testutil.IntegerValue(1))). - Pipe(table.Validate("test")). - Pipe(table.Replace("test")). - Pipe(stream.Discard()), - false, - }, - {"SET/No cond nested path", "UPDATE test SET a.b = 1", - stream.New(table.Scan("test")). - Pipe(path.Set(object.Path(testutil.ParsePath(t, "a.b")), testutil.IntegerValue(1))). - Pipe(table.Validate("test")). - Pipe(table.Replace("test")). - Pipe(stream.Discard()), - false, - }, - {"UNSET/No cond", "UPDATE test UNSET a", - stream.New(table.Scan("test")). - Pipe(path.Unset("a")). - Pipe(table.Validate("test")). - Pipe(table.Replace("test")). - Pipe(stream.Discard()), - false, - }, - {"UNSET/With cond", "UPDATE test UNSET a, b WHERE age = 10", - stream.New(table.Scan("test")). - Pipe(rows.Filter(parser.MustParseExpr("age = 10"))). - Pipe(path.Unset("a")). - Pipe(path.Unset("b")). - Pipe(table.Validate("test")). - Pipe(table.Replace("test")). - Pipe(stream.Discard()), - false, - }, - {"Trailing comma", "UPDATE test SET a = 1, WHERE age = 10", nil, true}, - {"No SET", "UPDATE test WHERE age = 10", nil, true}, - {"No pair", "UPDATE test SET WHERE age = 10", nil, true}, - {"query.Field only", "UPDATE test SET a WHERE age = 10", nil, true}, - {"No value", "UPDATE test SET a = WHERE age = 10", nil, true}, + {"Trailing comma", "UPDATE test SET a = 1, WHERE a = 10", nil, true}, + {"No SET", "UPDATE test WHERE a = 10", nil, true}, + {"No pair", "UPDATE test SET WHERE a = 10", nil, true}, + {"query.Field only", "UPDATE test SET a WHERE a = 10", nil, true}, + {"No value", "UPDATE test SET a = WHERE a = 10", nil, true}, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { db := testutil.NewTestDB(t) - testutil.MustExec(t, db, nil, "CREATE TABLE test") + testutil.MustExec(t, db, nil, "CREATE TABLE test(a INT, b TEXT)") q, err := parser.ParseQuery(test.s) if test.errored { diff --git a/internal/sql/scanner/scanner.go b/internal/sql/scanner/scanner.go index 872459dc9..6134ce582 100644 --- a/internal/sql/scanner/scanner.go +++ b/internal/sql/scanner/scanner.go @@ -69,14 +69,6 @@ func (s *scanner) Scan() (tok Token, pos Pos, lit string) { s.r.unread() return s.scanNumber() } - if ch1 == '.' { - ch2, _ := s.r.read() - if ch2 == '.' { - return ELLIPSIS, pos, "..." 
- } - - return ILLEGAL, pos, "" - } s.r.unread() return DOT, pos, "" case '$': diff --git a/internal/sql/scanner/scanner_test.go b/internal/sql/scanner/scanner_test.go index 5b3b5ab6d..309a20e55 100644 --- a/internal/sql/scanner/scanner_test.go +++ b/internal/sql/scanner/scanner_test.go @@ -26,7 +26,6 @@ func TestScanner_Scan(t *testing.T) { {s: "\n\r", tok: WS, lit: "\n\n"}, {s: " \n\t \r\n\t", tok: WS, lit: " \n\t \n\t"}, {s: " foo", tok: WS, lit: " "}, - {s: "...", tok: ELLIPSIS, lit: "..."}, // Numeric operators {s: `+`, tok: ADD}, @@ -177,7 +176,6 @@ func TestScanner_Scan(t *testing.T) { {s: `TRANSACTION`, tok: TRANSACTION}, {s: `UPDATE`, tok: UPDATE}, {s: `UNION`, tok: UNION}, - {s: `UNSET`, tok: UNSET}, {s: `VALUE`, tok: VALUE}, {s: `VALUES`, tok: VALUES}, {s: `WITH`, tok: WITH}, @@ -186,7 +184,6 @@ func TestScanner_Scan(t *testing.T) { {s: `seLECT`, tok: SELECT}, // case insensitive // types - {s: "ANY", tok: TYPEANY}, {s: "BYTES", tok: TYPEBYTES}, {s: "BOOL", tok: TYPEBOOL}, {s: "BOOLEAN", tok: TYPEBOOLEAN}, @@ -194,7 +191,6 @@ func TestScanner_Scan(t *testing.T) { {s: "INTEGER", tok: TYPEINTEGER}, {s: "TEXT", tok: TYPETEXT}, {s: "TIMESTAMP", tok: TYPETIMESTAMP}, - {s: "OBJECT", tok: TYPEOBJECT}, } for i, tt := range tests { diff --git a/internal/sql/scanner/token.go b/internal/sql/scanner/token.go index b420b7345..4f10f576e 100644 --- a/internal/sql/scanner/token.go +++ b/internal/sql/scanner/token.go @@ -29,7 +29,6 @@ const ( NULL // NULL REGEX // Regular expressions BADREGEX // `.* - ELLIPSIS // ... literalEnd operatorBeg @@ -140,7 +139,6 @@ const ( TRANSACTION UNION UNIQUE - UNSET UPDATE VALUE VALUES @@ -149,8 +147,6 @@ const ( WRITE // Types - TYPEANY - TYPEARRAY TYPEBIGINT TYPEBLOB TYPEBOOL @@ -163,7 +159,6 @@ const ( TYPEINT8 TYPEINTEGER TYPEMEDIUMINT - TYPEOBJECT TYPEREAL TYPESMALLINT TYPETEXT @@ -289,7 +284,6 @@ var tokens = [...]string{ TRANSACTION: "TRANSACTION", UNION: "UNION", UNIQUE: "UNIQUE", - UNSET: "UNSET", UPDATE: "UPDATE", VALUE: "VALUE", VALUES: "VALUES", @@ -297,8 +291,6 @@ var tokens = [...]string{ WHERE: "WHERE", WRITE: "WRITE", - TYPEANY: "ANY", - TYPEARRAY: "ARRAY", TYPEBIGINT: "BIGINT", TYPEBLOB: "BLOB", TYPEBOOL: "BOOL", @@ -311,7 +303,6 @@ var tokens = [...]string{ TYPEINT8: "INT8", TYPEINTEGER: "INTEGER", TYPEMEDIUMINT: "MEDIUMINT", - TYPEOBJECT: "OBJECT", TYPEREAL: "REAL", TYPESMALLINT: "SMALLINT", TYPETEXT: "TEXT", diff --git a/internal/sqltests/ALTER_TABLE/add_field.sql b/internal/sqltests/ALTER_TABLE/add_column.sql similarity index 85% rename from internal/sqltests/ALTER_TABLE/add_field.sql rename to internal/sqltests/ALTER_TABLE/add_column.sql index 157089874..dba21422f 100644 --- a/internal/sqltests/ALTER_TABLE/add_field.sql +++ b/internal/sqltests/ALTER_TABLE/add_column.sql @@ -1,7 +1,7 @@ -- setup: CREATE TABLE test(a int); --- test: field constraints are updated +-- test: column constraints are updated INSERT INTO test VALUES (1), (2); ALTER TABLE test ADD COLUMN b int DEFAULT 0; SELECT name, sql FROM __chai_catalog WHERE type = "table" AND name = "test"; @@ -39,7 +39,7 @@ SELECT * FROM test; /* result: { "a": 1, - "b": 10 + "b": 10, } { "a": 2, @@ -53,10 +53,12 @@ ALTER TABLE test ADD COLUMN b int UNIQUE; SELECT * FROM test; /* result: { - "a": 1 + "a": 1, + "b": null } { - "a": 2 + "a": 2, + "b": null } */ @@ -78,13 +80,15 @@ ALTER TABLE test ADD COLUMN b int PRIMARY KEY; -- test: primary key: without data ALTER TABLE test ADD COLUMN b int PRIMARY KEY; INSERT INTO test VALUES (1, 10), (2, 20); -SELECT pk() FROM test; +SELECT a, b FROM test; 
/* result: { - "pk()": [10] + "a": 1, + "b": 10 } { - "pk()": [20] + "a": 2, + "b": 20 } */ @@ -113,32 +117,15 @@ SELECT * FROM test; } */ --- test: no type +-- test: bad syntax: no type INSERT INTO test VALUES (1), (2); ALTER TABLE test ADD COLUMN b; -INSERT INTO test VALUES (3, 30), (4, 'hello'); -SELECT * FROM test; -/* result: -{ - "a": 1 -} -{ - "a": 2 -} -{ - "a": 3, - "b": 30.0 -} -{ - "a": 4, - "b": "hello" -} -*/ +-- error: --- test: bad syntax: no field name +-- test: bad syntax: no column name ALTER TABLE test ADD COLUMN; -- error: --- test: bad syntax: missing FIELD keyword +-- test: bad syntax: missing column keyword ALTER TABLE test ADD a int; -- error: \ No newline at end of file diff --git a/internal/sqltests/CREATE_INDEX/base.sql b/internal/sqltests/CREATE_INDEX/base.sql index bafeb5544..dcff339d2 100644 --- a/internal/sqltests/CREATE_INDEX/base.sql +++ b/internal/sqltests/CREATE_INDEX/base.sql @@ -3,7 +3,7 @@ CREATE TABLE test (a int); -- test: named index CREATE INDEX test_a_idx ON test(a); -SELECT name, owner.table_name AS table_name, sql FROM __chai_catalog WHERE type = "index"; +SELECT name, owner_table_name AS table_name, sql FROM __chai_catalog WHERE type = "index"; /* result: { "name": "test_a_idx", @@ -14,7 +14,7 @@ SELECT name, owner.table_name AS table_name, sql FROM __chai_catalog WHERE type -- test: named unique index CREATE UNIQUE INDEX test_a_idx ON test(a); -SELECT name, owner.table_name AS table_name, sql FROM __chai_catalog WHERE type = "index"; +SELECT name, owner_table_name AS table_name, sql FROM __chai_catalog WHERE type = "index"; /* result: { "name": "test_a_idx", @@ -36,7 +36,7 @@ CREATE UNIQUE INDEX test_a_idx ON test(a); -- test: IF NOT EXISTS CREATE INDEX test_a_idx ON test(a); CREATE INDEX IF NOT EXISTS test_a_idx ON test(a); -SELECT name, owner.table_name AS table_name, sql FROM __chai_catalog WHERE type = "index"; +SELECT name, owner_table_name AS table_name, sql FROM __chai_catalog WHERE type = "index"; /* result: { "name": "test_a_idx", @@ -50,7 +50,7 @@ CREATE INDEX ON test(a); CREATE INDEX ON test(a); CREATE INDEX test_a_idx2 ON test(a); CREATE INDEX ON test(a); -SELECT name, owner.table_name AS table_name, sql FROM __chai_catalog WHERE type = "index" ORDER BY name; +SELECT name, owner_table_name AS table_name, sql FROM __chai_catalog WHERE type = "index" ORDER BY name; /* result: { "name": "test_a_idx", diff --git a/internal/sqltests/CREATE_INDEX/undeclared.sql b/internal/sqltests/CREATE_INDEX/undeclared.sql index 9f3469b7b..4cda4b79f 100644 --- a/internal/sqltests/CREATE_INDEX/undeclared.sql +++ b/internal/sqltests/CREATE_INDEX/undeclared.sql @@ -1,19 +1,14 @@ --- test: undeclared field +-- test: undeclared column CREATE TABLE test; CREATE INDEX test_a_idx ON test(a); -- error: --- test: undeclared field: IF NOT EXISTS +-- test: undeclared column: IF NOT EXISTS CREATE TABLE test; CREATE INDEX IF NOT EXISTS test_a_idx ON test(a); -- error: --- test: undeclared field: other fields +-- test: undeclared column: other columns CREATE TABLE test(b int); CREATE INDEX test_a_idx ON test(a); -- error: - --- test: undeclared field: variadic -CREATE TABLE test(b int, ...); -CREATE INDEX test_a_idx ON test(a); --- error: diff --git a/internal/sqltests/CREATE_SEQUENCE/base.sql b/internal/sqltests/CREATE_SEQUENCE/base.sql index a456c6fb2..760e22a21 100644 --- a/internal/sqltests/CREATE_SEQUENCE/base.sql +++ b/internal/sqltests/CREATE_SEQUENCE/base.sql @@ -1,6 +1,6 @@ -- test: no config CREATE SEQUENCE seq; -SELECT * FROM __chai_catalog WHERE type = 
"sequence" AND name = "seq"; +SELECT name, type, sql FROM __chai_catalog WHERE type = "sequence" AND name = "seq"; /* result: { "name": "seq", @@ -11,7 +11,7 @@ SELECT * FROM __chai_catalog WHERE type = "sequence" AND name = "seq"; -- test: __chai_sequence table CREATE SEQUENCE seq; -SELECT * FROM __chai_sequence WHERE name = "seq"; +SELECT name FROM __chai_sequence WHERE name = "seq"; /* result: { "name": "seq" @@ -20,7 +20,7 @@ SELECT * FROM __chai_sequence WHERE name = "seq"; -- test: IF NOT EXISTS CREATE SEQUENCE IF NOT EXISTS seq; -SELECT * FROM __chai_catalog WHERE type = "sequence" AND name = "seq"; +SELECT name, type, sql FROM __chai_catalog WHERE type = "sequence" AND name = "seq"; /* result: { "name": "seq", @@ -31,7 +31,7 @@ SELECT * FROM __chai_catalog WHERE type = "sequence" AND name = "seq"; -- test: AS TINYINT CREATE SEQUENCE seq AS TINYINT; -SELECT * FROM __chai_catalog WHERE type = "sequence" AND name = "seq"; +SELECT name, type, sql FROM __chai_catalog WHERE type = "sequence" AND name = "seq"; /* result: { "name": "seq", @@ -46,7 +46,7 @@ CREATE SEQUENCE seq AS DOUBLE; -- test: INCREMENT 10 CREATE SEQUENCE seq INCREMENT 10; -SELECT * FROM __chai_catalog WHERE type = "sequence" AND name = "seq"; +SELECT name, type, sql FROM __chai_catalog WHERE type = "sequence" AND name = "seq"; /* result: { "name": "seq", @@ -57,7 +57,7 @@ SELECT * FROM __chai_catalog WHERE type = "sequence" AND name = "seq"; -- test: INCREMENT BY 10 CREATE SEQUENCE seq INCREMENT BY 10; -SELECT * FROM __chai_catalog WHERE type = "sequence" AND name = "seq"; +SELECT name, type, sql FROM __chai_catalog WHERE type = "sequence" AND name = "seq"; /* result: { "name": "seq", @@ -72,7 +72,7 @@ CREATE SEQUENCE seq INCREMENT BY 0; -- test: INCREMENT BY -10 CREATE SEQUENCE seq INCREMENT BY -10; -SELECT * FROM __chai_catalog WHERE type = "sequence" AND name = "seq"; +SELECT name, type, sql FROM __chai_catalog WHERE type = "sequence" AND name = "seq"; /* result: { "name": "seq", @@ -84,7 +84,7 @@ SELECT * FROM __chai_catalog WHERE type = "sequence" AND name = "seq"; -- test: NO MINVALUE CREATE SEQUENCE seq NO MINVALUE; -SELECT * FROM __chai_catalog WHERE type = "sequence" AND name = "seq"; +SELECT name, type, sql FROM __chai_catalog WHERE type = "sequence" AND name = "seq"; /* result: { "name": "seq", @@ -95,7 +95,7 @@ SELECT * FROM __chai_catalog WHERE type = "sequence" AND name = "seq"; -- test: NO MAXVALUE CREATE SEQUENCE seq NO MAXVALUE; -SELECT * FROM __chai_catalog WHERE type = "sequence" AND name = "seq"; +SELECT name, type, sql FROM __chai_catalog WHERE type = "sequence" AND name = "seq"; /* result: { "name": "seq", @@ -106,7 +106,7 @@ SELECT * FROM __chai_catalog WHERE type = "sequence" AND name = "seq"; -- test: NO CYCLE CREATE SEQUENCE seq NO CYCLE; -SELECT * FROM __chai_catalog WHERE type = "sequence" AND name = "seq"; +SELECT name, type, sql FROM __chai_catalog WHERE type = "sequence" AND name = "seq"; /* result: { "name": "seq", @@ -121,7 +121,7 @@ CREATE SEQUENCE seq NO SUGAR; -- test: MINVALUE 10 CREATE SEQUENCE seq MINVALUE 10; -SELECT * FROM __chai_catalog WHERE type = "sequence" AND name = "seq"; +SELECT name, type, sql FROM __chai_catalog WHERE type = "sequence" AND name = "seq"; /* result: { "name": "seq", @@ -136,7 +136,7 @@ CREATE SEQUENCE seq MINVALUE 'hello'; -- test: MAXVALUE 10 CREATE SEQUENCE seq MAXVALUE 10; -SELECT * FROM __chai_catalog WHERE type = "sequence" AND name = "seq"; +SELECT name, type, sql FROM __chai_catalog WHERE type = "sequence" AND name = "seq"; /* result: { "name": 
"seq", @@ -151,7 +151,7 @@ CREATE SEQUENCE seq MAXVALUE 'hello'; -- test: START WITH 10 CREATE SEQUENCE seq START WITH 10; -SELECT * FROM __chai_catalog WHERE type = "sequence" AND name = "seq"; +SELECT name, type, sql FROM __chai_catalog WHERE type = "sequence" AND name = "seq"; /* result: { "name": "seq", @@ -166,7 +166,7 @@ CREATE SEQUENCE seq START WITH 'hello'; -- test: START 10 CREATE SEQUENCE seq START 10; -SELECT * FROM __chai_catalog WHERE type = "sequence" AND name = "seq"; +SELECT name, type, sql FROM __chai_catalog WHERE type = "sequence" AND name = "seq"; /* result: { "name": "seq", @@ -177,7 +177,7 @@ SELECT * FROM __chai_catalog WHERE type = "sequence" AND name = "seq"; -- test: CACHE 10 CREATE SEQUENCE seq CACHE 10; -SELECT * FROM __chai_catalog WHERE type = "sequence" AND name = "seq"; +SELECT name, type, sql FROM __chai_catalog WHERE type = "sequence" AND name = "seq"; /* result: { "name": "seq", @@ -196,7 +196,7 @@ CREATE SEQUENCE seq CACHE -10; -- test: CACHE 10 CREATE SEQUENCE seq CYCLE; -SELECT * FROM __chai_catalog WHERE type = "sequence" AND name = "seq"; +SELECT name, type, sql FROM __chai_catalog WHERE type = "sequence" AND name = "seq"; /* result: { "name": "seq", diff --git a/internal/sqltests/CREATE_TABLE/base.sql b/internal/sqltests/CREATE_TABLE/base.sql index 634766a19..ee3b3809d 100644 --- a/internal/sqltests/CREATE_TABLE/base.sql +++ b/internal/sqltests/CREATE_TABLE/base.sql @@ -1,36 +1,36 @@ -- test: basic -CREATE TABLE test; +CREATE TABLE test(a int); SELECT name, sql FROM __chai_catalog WHERE type = "table" AND name = "test"; /* result: { "name": "test", - "sql": "CREATE TABLE test (...)" + "sql": "CREATE TABLE test (a INTEGER)" } */ -- test: duplicate -CREATE TABLE test; -CREATE TABLE test; +CREATE TABLE test(a int); +CREATE TABLE test(a int); -- error: -- test: if not exists -CREATE TABLE test; -CREATE TABLE IF NOT EXISTS test; +CREATE TABLE test(a int); +CREATE TABLE IF NOT EXISTS test(b int); SELECT name, sql FROM __chai_catalog WHERE type = "table" AND name = "test"; /* result: { "name": "test", - "sql": "CREATE TABLE test (...)" + "sql": "CREATE TABLE test (a INTEGER)" } */ -- test: if not exists, twice -CREATE TABLE IF NOT EXISTS test; -CREATE TABLE IF NOT EXISTS test; +CREATE TABLE IF NOT EXISTS test(a int); +CREATE TABLE IF NOT EXISTS test(a int); SELECT name, sql FROM __chai_catalog WHERE type = "table" AND name = "test"; /* result: { "name": "test", - "sql": "CREATE TABLE test (...)" + "sql": "CREATE TABLE test (a INTEGER)" } */ diff --git a/internal/sqltests/CREATE_TABLE/check.sql b/internal/sqltests/CREATE_TABLE/check.sql index 77da47d94..e086a3c5e 100644 --- a/internal/sqltests/CREATE_TABLE/check.sql +++ b/internal/sqltests/CREATE_TABLE/check.sql @@ -1,23 +1,23 @@ --- test: as field constraint +-- test: as column constraint CREATE TABLE test ( - a CHECK(a > 10 AND a < 20) + a INT CHECK(a > 10 AND a < 20) ); SELECT name, type, sql FROM __chai_catalog WHERE name = "test"; /* result: { name: "test", type: "table", - sql: "CREATE TABLE test (a ANY, CONSTRAINT test_check CHECK (a > 10 AND a < 20))" + sql: "CREATE TABLE test (a INTEGER, CONSTRAINT test_check CHECK (a > 10 AND a < 20))" } */ --- test: as field constraint: undeclared field +-- test: as column constraint: undeclared column CREATE TABLE test ( - a CHECK(b > 10) + a INT CHECK(b > 10) ); -- error: --- test: as field constraint, with other constraints +-- test: as column constraint, with other constraints CREATE TABLE test ( a INT CHECK (a > 10) DEFAULT 100 NOT NULL PRIMARY KEY ); @@ 
-30,13 +30,13 @@ SELECT name, type, sql FROM __chai_catalog WHERE name = "test"; } */ --- test: as field constraint, no parentheses +-- test: as column constraint, no parentheses CREATE TABLE test ( a INT CHECK a > 10 ); -- error: --- test: as field constraint, incompatible default value +-- test: as column constraint, incompatible default value CREATE TABLE test ( a INT CHECK (a > 10) DEFAULT 0 ); @@ -49,7 +49,7 @@ SELECT name, type, sql FROM __chai_catalog WHERE name = "test"; } */ --- test: as field constraint, reference other fields +-- test: as column constraint, reference other columns CREATE TABLE test ( a INT CHECK (a > 10 AND b < 10), b INT diff --git a/internal/sqltests/CREATE_TABLE/constraints.sql b/internal/sqltests/CREATE_TABLE/constraints.sql index accd31b90..c671859de 100644 --- a/internal/sqltests/CREATE_TABLE/constraints.sql +++ b/internal/sqltests/CREATE_TABLE/constraints.sql @@ -1,13 +1,3 @@ --- test: no constraint -CREATE TABLE test(a, b, c); -SELECT name, sql FROM __chai_catalog WHERE type = "table" AND name = "test"; -/* result: -{ - "name": "test", - "sql": "CREATE TABLE test (a ANY, b ANY, c ANY)" -} -*/ - -- test: type CREATE TABLE test(a INTEGER); SELECT name, sql FROM __chai_catalog WHERE type = "table" AND name = "test"; @@ -19,42 +9,42 @@ SELECT name, sql FROM __chai_catalog WHERE type = "table" AND name = "test"; */ -- test: NOT NULL -CREATE TABLE test(a NOT NULL); +CREATE TABLE test(a INT NOT NULL); SELECT name, sql FROM __chai_catalog WHERE type = "table" AND name = "test"; /* result: { "name": "test", - "sql": "CREATE TABLE test (a ANY NOT NULL)" + "sql": "CREATE TABLE test (a INTEGER NOT NULL)" } */ -- test: default -CREATE TABLE test(a DEFAULT 10); +CREATE TABLE test(a INT DEFAULT 10); SELECT name, sql FROM __chai_catalog WHERE type = "table" AND name = "test"; /* result: { "name": "test", - "sql": "CREATE TABLE test (a ANY DEFAULT 10)" + "sql": "CREATE TABLE test (a INTEGER DEFAULT 10)" } */ -- test: unique -CREATE TABLE test(a UNIQUE); +CREATE TABLE test(a INT UNIQUE); SELECT name, sql FROM __chai_catalog WHERE type = "table" AND name = "test"; /* result: { "name": "test", - "sql": "CREATE TABLE test (a ANY, CONSTRAINT test_a_unique UNIQUE (a))" + "sql": "CREATE TABLE test (a INTEGER, CONSTRAINT test_a_unique UNIQUE (a))" } */ -- test: check -CREATE TABLE test(a CHECK(a > 10)); +CREATE TABLE test(a INT CHECK(a > 10)); SELECT name, sql FROM __chai_catalog WHERE type = "table" AND name = "test"; /* result: { "name": "test", - "sql": "CREATE TABLE test (a ANY, CONSTRAINT test_check CHECK (a > 10))" + "sql": "CREATE TABLE test (a INTEGER, CONSTRAINT test_check CHECK (a > 10))" } */ diff --git a/internal/sqltests/CREATE_TABLE/default.sql b/internal/sqltests/CREATE_TABLE/default.sql index 92b792ef5..6070900c4 100644 --- a/internal/sqltests/CREATE_TABLE/default.sql +++ b/internal/sqltests/CREATE_TABLE/default.sql @@ -1,13 +1,3 @@ --- test: basic -CREATE TABLE test(a DEFAULT 10); -SELECT name, sql FROM __chai_catalog WHERE type = "table" AND name = "test"; -/* result: -{ - "name": "test", - "sql": "CREATE TABLE test (a ANY DEFAULT 10)" -} -*/ - -- test: same type CREATE TABLE test(a INT DEFAULT 10); SELECT name, sql FROM __chai_catalog WHERE type = "table" AND name = "test"; @@ -58,12 +48,3 @@ CREATE TABLE test(a BLOB DEFAULT 1 AND 1); CREATE TABLE test(a BLOB DEFAULT b); -- error: --- test: nested doc -CREATE TABLE test(a (b INT DEFAULT 10)); -SELECT name, sql FROM __chai_catalog WHERE type = "table" AND name = "test"; -/* result: -{ - "name": "test", - "sql": 
"CREATE TABLE test (a (b INTEGER DEFAULT 10))" -} -*/ diff --git a/internal/sqltests/CREATE_TABLE/not_null.sql b/internal/sqltests/CREATE_TABLE/not_null.sql index 75cd16227..b94efa21a 100644 --- a/internal/sqltests/CREATE_TABLE/not_null.sql +++ b/internal/sqltests/CREATE_TABLE/not_null.sql @@ -1,15 +1,5 @@ -- test: basic -CREATE TABLE test(a NOT NULL); -SELECT name, sql FROM __chai_catalog WHERE type = "table" AND name = "test"; -/* result: -{ - "name": "test", - "sql": "CREATE TABLE test (a ANY NOT NULL)" -} -*/ - --- test: with type -CREATE TABLE test(a INT NOT NULL); +CREATE TABLE test(a INTEGER NOT NULL); SELECT name, sql FROM __chai_catalog WHERE type = "table" AND name = "test"; /* result: { diff --git a/internal/sqltests/CREATE_TABLE/primary_key.sql b/internal/sqltests/CREATE_TABLE/primary_key.sql index 8943419a6..fe2fa70f4 100644 --- a/internal/sqltests/CREATE_TABLE/primary_key.sql +++ b/internal/sqltests/CREATE_TABLE/primary_key.sql @@ -1,15 +1,5 @@ -- test: basic -CREATE TABLE test(a PRIMARY KEY); -SELECT name, sql FROM __chai_catalog WHERE type = "table" AND name = "test"; -/* result: -{ - "name": "test", - "sql": "CREATE TABLE test (a ANY NOT NULL, CONSTRAINT test_pk PRIMARY KEY (a))" -} -*/ - --- test: with type -CREATE TABLE test(a INT PRIMARY KEY); +CREATE TABLE test(a INTEGER PRIMARY KEY); SELECT name, sql FROM __chai_catalog WHERE type = "table" AND name = "test"; /* result: { @@ -46,7 +36,7 @@ CREATE TABLE test(a INT PRIMARY KEY PRIMARY KEY); CREATE TABLE test(a INT PRIMARY KEY, b INT PRIMARY KEY); -- error: --- test: table constraint: one field +-- test: table constraint: one column CREATE TABLE test(a INT, PRIMARY KEY(a)); SELECT name, sql FROM __chai_catalog WHERE type = "table" AND name = "test"; /* result: @@ -56,7 +46,7 @@ SELECT name, sql FROM __chai_catalog WHERE type = "table" AND name = "test"; } */ --- test: table constraint: multiple fields +-- test: table constraint: multiple columns CREATE TABLE test(a INT, b INT, PRIMARY KEY(a, b)); SELECT name, sql FROM __chai_catalog WHERE type = "table" AND name = "test"; /* result: @@ -66,36 +56,25 @@ SELECT name, sql FROM __chai_catalog WHERE type = "table" AND name = "test"; } */ --- test: table constraint: nested fields -CREATE TABLE test(a (b INT), PRIMARY KEY(a.b)); -SELECT name, sql FROM __chai_catalog WHERE type = "table" AND name = "test"; -/* result: -{ - "name": "test", - "sql": "CREATE TABLE test (a (b INTEGER NOT NULL), CONSTRAINT test_pk PRIMARY KEY (a.b))" -} -*/ - - --- test: table constraint: multiple fields: with order -CREATE TABLE test(a INT, b INT, c (d INT), PRIMARY KEY(a DESC, b, c.d ASC)); +-- test: table constraint: multiple columns: with order +CREATE TABLE test(a INT, b INT, c INT, PRIMARY KEY(a DESC, b, c ASC)); SELECT name, sql FROM __chai_catalog WHERE type = "table" AND name = "test"; /* result: { "name": "test", - "sql": "CREATE TABLE test (a INTEGER NOT NULL, b INTEGER NOT NULL, c (d INTEGER NOT NULL), CONSTRAINT test_pk PRIMARY KEY (a DESC, b, c.d))" + "sql": "CREATE TABLE test (a INTEGER NOT NULL, b INTEGER NOT NULL, c INTEGER NOT NULL, CONSTRAINT test_pk PRIMARY KEY (a DESC, b, c))" } */ --- test: table constraint: undeclared fields +-- test: table constraint: undeclared columns CREATE TABLE test(a INT, b INT, PRIMARY KEY(a, b, c)); -- error: --- test: table constraint: same field twice +-- test: table constraint: same column twice CREATE TABLE test(a INT, b INT, PRIMARY KEY(a, a)); -- error: --- test: table constraint: same field twice, field constraint + table constraint +-- test: 
table constraint: same column twice, column constraint + table constraint CREATE TABLE test(a INT PRIMARY KEY, b INT, PRIMARY KEY(a)); -- error: diff --git a/internal/sqltests/CREATE_TABLE/types.sql b/internal/sqltests/CREATE_TABLE/types.sql index ca02246aa..6410e311c 100644 --- a/internal/sqltests/CREATE_TABLE/types.sql +++ b/internal/sqltests/CREATE_TABLE/types.sql @@ -8,6 +8,16 @@ SELECT name, sql FROM __chai_catalog WHERE type = "table" AND name = "test"; } */ +-- test: BIGINT +CREATE TABLE test (a BIGINT); +SELECT name, sql FROM __chai_catalog WHERE type = "table" AND name = "test"; +/* result: +{ + "name": "test", + "sql": "CREATE TABLE test (a BIGINT)" +} +*/ + -- test: DOUBLE CREATE TABLE test (a DOUBLE); SELECT name, sql FROM __chai_catalog WHERE type = "table" AND name = "test"; @@ -48,36 +58,6 @@ SELECT name, sql FROM __chai_catalog WHERE type = "table" AND name = "test"; } */ --- test: ARRAY -CREATE TABLE test (a ARRAY); -SELECT name, sql FROM __chai_catalog WHERE type = "table" AND name = "test"; -/* result: -{ - "name": "test", - "sql": "CREATE TABLE test (a ARRAY)" -} -*/ - --- test: OBJECT -CREATE TABLE test (a OBJECT); -SELECT name, sql FROM __chai_catalog WHERE type = "table" AND name = "test"; -/* result: -{ - "name": "test", - "sql": "CREATE TABLE test (a (...))" -} -*/ - --- test: ANY -CREATE TABLE test (a ANY); -SELECT name, sql FROM __chai_catalog WHERE type = "table" AND name = "test"; -/* result: -{ - "name": "test", - "sql": "CREATE TABLE test (a ANY)" -} -*/ - -- test: duplicate type CREATE TABLE test (a INT, a TEXT); -- error: @@ -102,16 +82,6 @@ SELECT name, sql FROM __chai_catalog WHERE type = "table" AND name = "test"; } */ --- test: INT ALIAS: BIGINT -CREATE TABLE test (a BIGINT); -SELECT name, sql FROM __chai_catalog WHERE type = "table" AND name = "test"; -/* result: -{ - "name": "test", - "sql": "CREATE TABLE test (a INTEGER)" -} -*/ - -- test: INT ALIAS: mediumint CREATE TABLE test (a mediumint); SELECT name, sql FROM __chai_catalog WHERE type = "table" AND name = "test"; @@ -132,13 +102,13 @@ SELECT name, sql FROM __chai_catalog WHERE type = "table" AND name = "test"; } */ --- test: INT ALIAS: INT8 +-- test: BIGINT ALIAS: INT8 CREATE TABLE test (a int8); SELECT name, sql FROM __chai_catalog WHERE type = "table" AND name = "test"; /* result: { "name": "test", - "sql": "CREATE TABLE test (a INTEGER)" + "sql": "CREATE TABLE test (a BIGINT)" } */ diff --git a/internal/sqltests/CREATE_TABLE/types_document.sql b/internal/sqltests/CREATE_TABLE/types_document.sql deleted file mode 100644 index 3a7a043d6..000000000 --- a/internal/sqltests/CREATE_TABLE/types_document.sql +++ /dev/null @@ -1,29 +0,0 @@ --- test: no keyword -CREATE TABLE test (a (b int)); -SELECT name, sql FROM __chai_catalog WHERE type = "table" AND name = "test"; -/* result: -{ - "name": "test", - "sql": "CREATE TABLE test (a (b INTEGER))" -} -*/ - --- test: with keyword -CREATE TABLE test (a OBJECT (b int)); -SELECT name, sql FROM __chai_catalog WHERE type = "table" AND name = "test"; -/* result: -{ - "name": "test", - "sql": "CREATE TABLE test (a (b INTEGER))" -} -*/ - --- test: with ellipsis -CREATE TABLE test (a OBJECT (b int, ...)); -SELECT name, sql FROM __chai_catalog WHERE type = "table" AND name = "test"; -/* result: -{ - "name": "test", - "sql": "CREATE TABLE test (a (b INTEGER, ...))" -} -*/ diff --git a/internal/sqltests/CREATE_TABLE/unique.sql b/internal/sqltests/CREATE_TABLE/unique.sql index 385a2b818..e964a24bf 100644 --- a/internal/sqltests/CREATE_TABLE/unique.sql +++ 
b/internal/sqltests/CREATE_TABLE/unique.sql @@ -1,22 +1,3 @@ --- test: ANY -CREATE TABLE test(a UNIQUE); -SELECT name, sql -FROM __chai_catalog -WHERE - (type = "table" AND name = "test") - OR - (type = "index" AND name = "test_a_idx"); -/* result: -{ - "name": "test", - "sql": "CREATE TABLE test (a ANY, CONSTRAINT test_a_unique UNIQUE (a))" -} -{ - "name": "test_a_idx", - "sql": "CREATE UNIQUE INDEX test_a_idx ON test (a)" -} -*/ - -- test: with type CREATE TABLE test(a INT UNIQUE); SELECT name, sql @@ -43,7 +24,7 @@ FROM __chai_catalog WHERE (type = "table" AND name = "test") OR - (type = "index" AND owner.table_name = "test"); + (type = "index" AND owner_table_name = "test"); /* result: { "name": "test", @@ -59,7 +40,7 @@ WHERE } */ --- test: table constraint: one field +-- test: table constraint: one column CREATE TABLE test(a INT, UNIQUE(a)); SELECT name, sql FROM __chai_catalog @@ -78,7 +59,7 @@ WHERE } */ --- test: table constraint: multiple fields +-- test: table constraint: multiple columns CREATE TABLE test(a INT, b INT, UNIQUE(a, b)); SELECT name, sql FROM __chai_catalog @@ -97,7 +78,7 @@ WHERE } */ --- test: table constraint: multiple fields with order +-- test: table constraint: multiple columns with order CREATE TABLE test(a INT, b INT, c INT, UNIQUE(a DESC, b ASC, c)); SELECT name, sql FROM __chai_catalog @@ -116,30 +97,30 @@ WHERE } */ --- test: table constraint: undeclared field +-- test: table constraint: undeclared column CREATE TABLE test(a INT, UNIQUE(b)); -- error: --- test: table constraint: undeclared fields +-- test: table constraint: undeclared columns CREATE TABLE test(a INT, b INT, UNIQUE(a, b, c)); -- error: --- test: table constraint: same field twice +-- test: table constraint: same column twice CREATE TABLE test(a INT, b INT, UNIQUE(a, a)); -- error: --- test: table constraint: same field twice, field constraint + table constraint +-- test: table constraint: same column twice, column constraint + table constraint CREATE TABLE test(a INT UNIQUE, b INT, UNIQUE(a)); -- error: --- test: table constraint: different fields +-- test: table constraint: different columns CREATE TABLE test(a INT UNIQUE, b INT, UNIQUE(b)); SELECT name, sql FROM __chai_catalog WHERE (type = "table" AND name = "test") OR - (type = "index" AND owner.table_name = "test"); + (type = "index" AND owner_table_name = "test"); /* result: { "name": "test", diff --git a/internal/sqltests/INSERT/check.sql b/internal/sqltests/INSERT/check.sql index cde0df478..a19f75ad7 100644 --- a/internal/sqltests/INSERT/check.sql +++ b/internal/sqltests/INSERT/check.sql @@ -23,10 +23,10 @@ SELECT * FROM test; } */ --- test: non-boolean check constraint, non-numeric result +-- test: non-boolean check constraint CREATE TABLE test (a text CHECK("hello")); INSERT INTO test (a) VALUES ("hello"); --- error: +-- error: row violates check constraint "test_check" -- test: non-boolean check constraint, NULL CREATE TABLE test (a text CHECK(NULL)); @@ -39,46 +39,46 @@ SELECT * FROM test; */ /* -Field types: These tests check the behavior of the check constraint depending -on the type of the field +Column types: These tests check the behavior of the check constraint depending +on the type of the column */ --- test: no type constraint, valid double -CREATE TABLE test (a CHECK(a > 10)); +-- test: valid int +CREATE TABLE test (a INT CHECK(a > 10)); INSERT INTO test (a) VALUES (11); SELECT * FROM test; /* result: { - a: 11.0 + a: 11 } */ --- test: no type constraint, invalid double -CREATE TABLE test (a CHECK(a > 10)); +-- test: 
invalid int +CREATE TABLE test (a INT CHECK(a > 10)); INSERT INTO test (a) VALUES (1); -- error: row violates check constraint "test_check" --- test: no type constraint, multiple checks, invalid double -CREATE TABLE test (a CHECK(a > 10), CHECK(a < 20)); +-- test: multiple checks, invalid int +CREATE TABLE test (a INT CHECK(a > 10), CHECK(a < 20)); INSERT INTO test (a) VALUES (40); -- error: row violates check constraint "test_check1" --- test: no type constraint, text -CREATE TABLE test (a CHECK(a > 10)); +-- test: text +CREATE TABLE test (a INT CHECK(a > 10)); INSERT INTO test (a) VALUES ("hello"); --- error: row violates check constraint "test_check" +-- error: cannot cast "hello" as integer: strconv.ParseInt: parsing "hello": invalid syntax --- test: no type constraint, null -CREATE TABLE test (a CHECK(a > 10), ...); +-- test: null +CREATE TABLE test (a INT CHECK(a > 10), b int); INSERT INTO test (b) VALUES (10); -SELECT * FROM test; +SELECT b FROM test; /* result: { - b: 10.0 + b: 10 } */ --- test: int type constraint, double +-- test: double CREATE TABLE test (a int CHECK(a > 10)); INSERT INTO test (a) VALUES (15.2); SELECT * FROM test; @@ -87,4 +87,3 @@ SELECT * FROM test; a: 15 } */ - diff --git a/internal/sqltests/INSERT/document.sql b/internal/sqltests/INSERT/document.sql deleted file mode 100644 index 1c9c1e346..000000000 --- a/internal/sqltests/INSERT/document.sql +++ /dev/null @@ -1,62 +0,0 @@ --- test: document -CREATE TABLE test (a TEXT, b DOUBLE, c BOOLEAN); -INSERT INTO test VALUES {a: 'a', b: 2.3, c: 1 = 1}; -SELECT pk(), * FROM test; -/* result: -{ - "pk()": [1], - "a": "a", - "b": 2.3, - "c": true -} -*/ - --- test: document, array -CREATE TABLE test (a ARRAY); -INSERT INTO test VALUES {a: [1, 2, 3]}; -SELECT pk(), * FROM test; -/* result: -{ - "pk()": [1], - "a": [ - 1.0, - 2.0, - 3.0 - ] -} -*/ - --- test: document, strings -CREATE TABLE test (a TEXT, b DOUBLE); -INSERT INTO test VALUES {'a': 'a', b: 2.3}; -SELECT pk(), * FROM test; -/* result: -{ - "pk()": [1], - "a": "a", - "b": 2.3 -} -*/ - --- test: document, double quotes -CREATE TABLE test (a TEXT); -INSERT INTO test VALUES {"a": "b"}; -SELECT pk(), * FROM test; -/* result: -{ - "pk()": [1], - "a": "b" -} -*/ - --- test: document, references to other field -CREATE TABLE test (a INT, b INT); -INSERT INTO test VALUES {a: 400, b: a * 4}; -SELECT pk(), * FROM test; -/* result: -{ - "pk()":[1], - "a":400, - "b":1600 -} -*/ diff --git a/internal/sqltests/INSERT/insert_select.sql b/internal/sqltests/INSERT/insert_select.sql index 549405581..3c3630b8f 100644 --- a/internal/sqltests/INSERT/insert_select.sql +++ b/internal/sqltests/INSERT/insert_select.sql @@ -1,52 +1,76 @@ -- setup: -CREATE TABLE foo; -CREATE TABLE bar; +CREATE TABLE foo(a INT, b INT, c INT, d INT, e INT); +CREATE TABLE bar(a INT, b INT); INSERT INTO bar (a, b) VALUES (1, 10); -- test: same table INSERT INTO foo SELECT * FROM foo; -- error: --- test: No fields / No projection +-- test: No columns / No projection INSERT INTO foo SELECT * FROM bar; -SELECT pk(), * FROM foo; +SELECT * FROM foo; /* result: -{"pk()": [1], "a":1.0, "b":10.0} +{ + "a":1, + "b":10, + "c":null, + "d":null, + "e":null +} */ --- test: No fields / Projection +-- test: No columns / Projection INSERT INTO foo SELECT a FROM bar; -SELECT pk(), * FROM foo; +SELECT * FROM foo; /* result: -{"pk()": [1], "a":1.0} +{ + "a":1, + "b":null, + "c":null, + "d":null, + "e":null +} */ --- test: With fields / No Projection +-- test: With columns / No Projection INSERT INTO foo (a, b) SELECT * FROM 
bar; -SELECT pk(), * FROM foo; +SELECT * FROM foo; /* result: -{"pk()": [1], "a":1.0, "b":10.0} +{ + "a":1, + "b":10, + "c":null, + "d":null, + "e":null +} */ --- test: With fields / Projection +-- test: With columns / Projection INSERT INTO foo (c, d) SELECT a, b FROM bar; -SELECT pk(), * FROM foo; +SELECT * FROM foo; /* result: -{"pk()": [1], "c":1.0, "d":10.0} +{ + "a":null, + "b":null, + "c":1, + "d":10, + "e":null +} */ --- test: Too many fields / No Projection +-- test: Too many columns / No Projection INSERT INTO foo (c) SELECT * FROM bar; -- error: --- test: Too many fields / Projection +-- test: Too many columns / Projection INSERT INTO foo (c, d) SELECT a, b, c FROM bar; -- error: --- test: Too few fields / No Projection +-- test: Too few columns / No Projection INSERT INTO foo (c, d, e) SELECT * FROM bar; -- error: --- test: Too few fields / Projection -INSERT INTO foo (c, d) SELECT a FROM bar`; +-- test: Too few columns / Projection +INSERT INTO foo (c, d) SELECT a FROM bar; -- error: diff --git a/internal/sqltests/INSERT/misc.sql b/internal/sqltests/INSERT/misc.sql index e0daa25ba..1e9888a2b 100644 --- a/internal/sqltests/INSERT/misc.sql +++ b/internal/sqltests/INSERT/misc.sql @@ -1,5 +1,5 @@ -- test: read-only tables -INSERT INTO __chai_catalog VALUES {a: 400, b: a * 4}; +INSERT INTO __chai_catalog (name, namespace) VALUES ('foo', 100); -- error: cannot write to read-only table -- test: insert with primary keys @@ -13,31 +13,13 @@ INSERT INTO testpk (bar, foo) VALUES (1, 2); INSERT INTO testpk (bar, foo) VALUES (1, 2); -- error: --- test: insert with shadowing -CREATE TABLE test (`pk()` INT); -INSERT INTO test (`pk()`) VALUES (10); -SELECT pk() AS pk, `pk()` from test; -/* result: -{ - "pk": [1], - "pk()": 10 -} -*/ - -- test: insert with types constraints CREATE TABLE test_tc( b bool, db double, - i integer, bb blob, byt bytes, - t text, a array, d object + i bigint, bb blob, byt bytes, + t text ); - -INSERT INTO test_tc -VALUES { - i: 10000000000, db: 21.21, b: true, - bb: "YmxvYlZhbHVlCg==", byt: "Ynl0ZXNWYWx1ZQ==", - t: "text", a: [1, "foo", true], d: {"foo": "bar"} -}; - +INSERT INTO test_tc (i, db, b, bb, byt, t) VALUES (10000000000, 21.21, true, "YmxvYlZhbHVlCg==", "Ynl0ZXNWYWx1ZQ==", "text"); SELECT * FROM test_tc; /* result: { @@ -46,15 +28,7 @@ SELECT * FROM test_tc; "i": 10000000000, "bb": CAST("YmxvYlZhbHVlCg==" AS BLOB), "byt": CAST("Ynl0ZXNWYWx1ZQ==" AS BYTES), - "t": "text", - "a": [ - 1.0, - "foo", - true - ], - "d": { - "foo": "bar" - } + "t": "text" } */ @@ -120,33 +94,33 @@ SELECT * FROM test_oc; */ -- test: insert with on conflict do replace, pk -CREATE TABLE test_oc(a INTEGER PRIMARY KEY, ...); +CREATE TABLE test_oc(a INTEGER PRIMARY KEY, b INTEGER, c INTEGER); INSERT INTO test_oc (a, b, c) VALUES (1, 1, 1); INSERT INTO test_oc (a, b, c) VALUES (1, 2, 3) ON CONFLICT DO REPLACE; SELECT * FROM test_oc; /* result: { a: 1, - b: 2.0, - c: 3.0 + b: 2, + c: 3 } */ -- test: insert with on conflict do replace, unique -CREATE TABLE test_oc(a INTEGER UNIQUE, ...); +CREATE TABLE test_oc(a INTEGER UNIQUE, b INTEGER, c INTEGER); INSERT INTO test_oc (a, b, c) VALUES (1, 1, 1); INSERT INTO test_oc (a, b, c) VALUES (1, 2, 3) ON CONFLICT DO REPLACE; SELECT * FROM test_oc; /* result: { a: 1, - b: 2.0, - c: 3.0 + b: 2, + c: 3 } */ -- test: insert with on conflict do replace, not null -CREATE TABLE test_oc(a INTEGER NOT NULL, ...); +CREATE TABLE test_oc(a INTEGER NOT NULL, b INTEGER, c INTEGER); INSERT INTO test_oc (b, c) VALUES (1, 1) ON CONFLICT DO REPLACE; -- error: @@ 
-168,25 +142,11 @@ SELECT * FROM test_oc; } */ --- test: default on nested fields -CREATE TABLE test_df (a (b TEXT DEFAULT "foo")); -INSERT INTO test_df VALUES {}; -SELECT * FROM test_df; -/* result: -{ -} -*/ - --- test: duplicate field names: root +-- test: duplicate column names: root CREATE TABLE test_df; INSERT INTO test_df(a, a) VALUES (1, 10); -- error: --- test: duplicate field names: nested -CREATE TABLE test_df; -insert into test_df(a) values ({b: 1, b: 10}); --- error: - -- test: inserts must be silent CREATE TABLE test (a int); INSERT INTO test VALUES (1); @@ -197,5 +157,5 @@ INSERT INTO test VALUES (1); CREATE TABLE test (a int); EXPLAIN INSERT INTO test (a) VALUES (1); /* result: -{plan: "rows.Emit({a: 1}) | table.Validate(\"test\") | table.Insert(\"test\") | discard()"} +{plan: "rows.Emit((1)) | table.Validate(\"test\") | table.Insert(\"test\") | discard()"} */ \ No newline at end of file diff --git a/internal/sqltests/INSERT/not_null.sql b/internal/sqltests/INSERT/not_null.sql index 7085381a1..c8cbad47e 100644 --- a/internal/sqltests/INSERT/not_null.sql +++ b/internal/sqltests/INSERT/not_null.sql @@ -13,7 +13,7 @@ CREATE TABLE test (a INT NOT NULL, b INT); INSERT INTO test (a, b) VALUES (NULL, 1); -- error: --- test: with missing field and default +-- test: with missing column and default CREATE TABLE test (a INT NOT NULL DEFAULT 10, b INT); INSERT INTO test (b) VALUES (1); SELECT a, b FROM test; diff --git a/internal/sqltests/INSERT/primary_key.sql b/internal/sqltests/INSERT/primary_key.sql index 44983a5c7..c893bb0ee 100644 --- a/internal/sqltests/INSERT/primary_key.sql +++ b/internal/sqltests/INSERT/primary_key.sql @@ -1,37 +1,3 @@ --- test: Should generate a key by default -CREATE TABLE test (a TEXT); -INSERT INTO test (a) VALUES ("foo"), ("bar"); -SELECT pk(), a FROM test; -/* result: -{ - "pk()": [1], - "a": "foo" -} -{ - "pk()": [2], - "a": "bar" -} -*/ - --- test: Should use the right field if primary key is specified -CREATE TABLE test (a (b TEXT PRIMARY KEY)); -INSERT INTO test (a) VALUES ({b: "foo"}), ({b:"bar"}); -SELECT pk(), a FROM test; -/* result: -{ - "pk()": ["bar"], - "a": { - "b": "bar" - } -} -{ - "pk()": ["foo"], - "a": { - "b": "foo" - } -} -*/ - -- test: Should fail if Pk not found CREATE TABLE test (a PRIMARY KEY, b INT); INSERT INTO test (b) VALUES (1); diff --git a/internal/sqltests/INSERT/types.sql b/internal/sqltests/INSERT/types.sql deleted file mode 100644 index e3902a89b..000000000 --- a/internal/sqltests/INSERT/types.sql +++ /dev/null @@ -1,275 +0,0 @@ --- test: insert with errors, not null without type constraint -CREATE TABLE test_e (a NOT NULL); -INSERT INTO test_e VALUES {}; --- error: - --- test: insert with errors, array / not null with type constraint -CREATE TABLE test_e (a ARRAY NOT NULL); -INSERT INTO test_e VALUES {}; --- error: - --- test: insert with errors, array / not null with non-respected type constraint -CREATE TABLE test_e (a ARRAY NOT NULL); -INSERT INTO test_e VALUES {a: 42}; --- error: - --- test: insert with errors, blob -CREATE TABLE test_e (a BLOB); -INSERT INTO test_e {a: true}; --- error: - --- test: blob / not null with type constraint -CREATE TABLE test_e (a BLOB NOT NULL); -INSERT INTO test_e {}; --- error: - --- test: blob / not null with non-respected type constraint -CREATE TABLE test_e (a BLOB NOT NULL); -INSERT INTO test_e {a: 42}; --- error: - --- test: bool / not null with type constraint -CREATE TABLE test_e (a BOOL NOT NULL); -INSERT INTO test_e {}; --- error: - --- test: bytes -CREATE TABLE test_e (a 
BYTES); -INSERT INTO test_e {a: [1,2,3]}; --- error: - --- test: bytes / not null with type constraint -CREATE TABLE test_e (a BYTES NOT NULL); -INSERT INTO test_e {}; --- error: - --- test: bytes / not null with non-respected type constraint -CREATE TABLE test_e (a BYTES NOT NULL); -INSERT INTO test_e {a: 42}; --- error: - --- test: document -CREATE TABLE test_e (a OBJECT); -INSERT INTO test_e {"a": "foo"}; --- error: - --- test: document / not null with type constraint -CREATE TABLE test_e (a OBJECT NOT NULL); -INSERT INTO test_e {}; --- error: - --- test: document / not null with non-respected type constraint -CREATE TABLE test_e (a OBJECT NOT NULL); -INSERT INTO test_e {a: false}; --- error: - --- test: double -CREATE TABLE test_e (a DOUBLE); -INSERT INTO test_e {a: "foo"}; --- error: - --- test: double / not null with type constraint -CREATE TABLE test_e (a DOUBLE NOT NULL); -INSERT INTO test_e {}; --- error: - --- test: double / not null with non-respected type constraint -CREATE TABLE test_e (a DOUBLE NOT NULL); -INSERT INTO test_e {a: [1,2,3]}; --- error: - --- test: double precision -CREATE TABLE test_e (a DOUBLE PRECISION); -INSERT INTO test_e {a: "foo"}; --- error: - --- test: double precision / not null with type constraint -CREATE TABLE test_e (a DOUBLE PRECISION NOT NULL); -INSERT INTO test_e {}; --- error: - --- test: double precision / not null with non-respected type constraint -CREATE TABLE test_e (a DOUBLE PRECISION NOT NULL); -INSERT INTO test_e {a: [1,2,3]}; --- error: - --- test: real -CREATE TABLE test_e (a REAL); -INSERT INTO test_e {a: "foo"}; --- error: - --- test: real / not null with type constraint -CREATE TABLE test_e (a REAL NOT NULL); -INSERT INTO test_e {}; --- error: - --- test: real / not null with non-respected type constraint -CREATE TABLE test_e (a REAL NOT NULL); -INSERT INTO test_e {a: [1,2,3]}; --- error: - --- test: integer -CREATE TABLE test_e (a INTEGER); -INSERT INTO test_e {a: "foo"}; --- error: - --- test: integer / not null with type constraint -CREATE TABLE test_e (a INTEGER NOT NULL); -INSERT INTO test_e {}; --- error: - --- test: integer / not null with non-respected type constraint -CREATE TABLE test_e (a INTEGER NOT NULL); -INSERT INTO test_e {a: [1,2,3]}; --- error: - --- test: int2 -CREATE TABLE test_e (a INT2); -INSERT INTO test_e {a: "foo"}; --- error: - --- test: int2 / not null with type constraint -CREATE TABLE test_e (a INT2 NOT NULL); -INSERT INTO test_e {}; --- error: - --- test: int2 / not null with non-respected type constraint -CREATE TABLE test_e (a INT NOT NULL); -INSERT INTO test_e {a: [1,2,3]}; --- error: - --- test: int8 -CREATE TABLE test_e (a INT8); -INSERT INTO test_e {a: "foo"}; --- error: - --- test: int8 / not null with type constraint -CREATE TABLE test_e (a INT8 NOT NULL); -INSERT INTO test_e {}; --- error: - --- test: int8 / not null with non-respected type constraint -CREATE TABLE test_e (a INT8 NOT NULL); -INSERT INTO test_e {a: [1,2,3]}; --- error: - --- test: tinyint -CREATE TABLE test_e (a TINYINT); -INSERT INTO test_e {a: "foo"}; --- error: - --- test: tinyint / not null with type constraint -CREATE TABLE test_e (a TINYINT NOT NULL); -INSERT INTO test_e {}; --- error: - --- test: tinyint / not null with non-respected type constraint -CREATE TABLE test_e (a TINYINT NOT NULL); -INSERT INTO test_e {a: [1,2,3]}; --- error: - --- test: bigint -CREATE TABLE test_e (a BIGINT); -INSERT INTO test_e {a: "foo"}; --- error: - --- test: bigint / not null with type constraint -CREATE TABLE test_e (a BIGINT NOT NULL); 
-INSERT INTO test_e {}; --- error: - --- test: bigint / not null with non-respected type constraint -CREATE TABLE test_e (a BIGINT NOT NULL); -INSERT INTO test_e {a: [1,2,3]}; --- error: - --- test: smallint -CREATE TABLE test_e (a SMALLINT); -INSERT INTO test_e {a: "foo"}; --- error: - --- test: smallint / not null with type constraint -CREATE TABLE test_e (a SMALLINT NOT NULL); -INSERT INTO test_e {}; --- error: - --- test: smallint / not null with non-respected type constraint -CREATE TABLE test_e (a SMALLINT NOT NULL); -INSERT INTO test_e {a: [1,2,3]}; --- error: - --- test: mediumint -CREATE TABLE test_e (a MEDIUMINT); -INSERT INTO test_e {a: "foo"}; --- error: - --- test: mediumint / not null with type constraint -CREATE TABLE test_e (a MEDIUMINT NOT NULL); -INSERT INTO test_e {}; --- error: - --- test: mediumint / not null with non-respected type constraint -CREATE TABLE test_e (a MEDIUMINT NOT NULL); -INSERT INTO test_e {a: [1,2,3]}; --- error: - --- test: text / not null with type constraint -CREATE TABLE test_e (a TEXT NOT NULL); -INSERT INTO test_e {}; --- error: - --- test: varchar / not null with type constraint -CREATE TABLE test_e (a VARCHAR(255) NOT NULL); -INSERT INTO test_e {}; --- error: - --- test: character / not null with type constraint -CREATE TABLE test_e (a CHARACTER(64) NOT NULL); -INSERT INTO test_e {}; --- error: - --- test: Should fail if the fields cannot be converted to specified field constraints -CREATE TABLE test (a DOUBLE); -INSERT INTO test VALUES ([1]); --- error: - --- test: Conversion -CREATE TABLE test (a INT); -INSERT INTO test (a) VALUES (1.5); -SELECT a FROM test; -/* result: -{ - "a": 1 -} -*/ - --- test: timestamp / no type -CREATE TABLE test_e (a); -INSERT INTO test_e VALUES {a: "2023-01-01T00:00:00Z"}; -SELECT typeof(a) FROM test_e; -/* result: -{ - "typeof(a)": "text" -} -*/ - --- test: timestamp / text -CREATE TABLE test_e (a TEXT); -INSERT INTO test_e VALUES {a: "2023-01-01T00:00:00Z"}; -SELECT typeof(a) FROM test_e; -/* result: -{ - "typeof(a)": "text" -} -*/ - --- test: timestamp / timestamp -CREATE TABLE test_e (a TIMESTAMP); -INSERT INTO test_e VALUES {a: "2023-01-01T00:00:00Z"}; -SELECT typeof(a), a FROM test_e; -/* result: -{ - "typeof(a)": "timestamp", - "a": "2023-01-01T00:00:00Z" -} -*/ \ No newline at end of file diff --git a/internal/sqltests/INSERT/values.sql b/internal/sqltests/INSERT/values.sql index a72e63b4c..03f9ba200 100644 --- a/internal/sqltests/INSERT/values.sql +++ b/internal/sqltests/INSERT/values.sql @@ -1,149 +1,75 @@ --- test: VALUES, with all fields +-- test: VALUES, with all columns CREATE TABLE test (a TEXT, b TEXT, c TEXT); INSERT INTO test (a, b, c) VALUES ('a', 'b', 'c'); -SELECT pk(), * FROM test; +SELECT * FROM test; /* result: { - "pk()": [1], "a": "a", "b": "b", "c": "c" } */ --- test: VALUES, with a few fields +-- test: VALUES, with a few columns CREATE TABLE test (a TEXT, b TEXT, c TEXT); INSERT INTO test (b, a) VALUES ('b', 'a'); -SELECT pk(), * FROM test; +SELECT * FROM test; /* result: { - "pk()": [1], "a": "a", - "b": "b" + "b": "b", + "c": null } */ --- test: VALUES, with too many fields +-- test: VALUES, with too many columns CREATE TABLE test (a TEXT, b TEXT, c TEXT); INSERT INTO test (b, a, c, d) VALUES ('b', 'a', 'c', 'd'); --- error: table has no field d - --- test: variadic, VALUES, with all fields -CREATE TABLE test (a TEXT, b TEXT, c TEXT, ...); -INSERT INTO test (a, b, c) VALUES ('a', 'b', 'c'); -SELECT pk(), * FROM test; -/* result: -{ - "pk()": [1], - "a": "a", - "b": "b", - "c": "c" -} 
-*/ +-- error: table has no column d --- test: variadic, VALUES, with a few fields -CREATE TABLE test (a TEXT, b TEXT, c TEXT, ...); -INSERT INTO test (b, a, d) VALUES ('b', 'a', 'd'); -SELECT pk(), * FROM test; -/* result: -{ - "pk()": [1], - "a": "a", - "b": "b", - "d": "d" -} -*/ - --- test: VALUES, no fields, all values +-- test: VALUES, no columns, all values CREATE TABLE test (a TEXT, b TEXT, c TEXT); INSERT INTO test VALUES ("a", 'b', 'c'); -SELECT pk(), * FROM test; +SELECT * FROM test; /* result: { - "pk()": [1], "a": "a", "b": "b", "c": "c" } */ --- test: VALUES, no fields, few values +-- test: VALUES, no columns, few values CREATE TABLE test (a TEXT, b TEXT, c TEXT); -INSERT INTO test VALUES ("a", 'b'); -SELECT pk(), * FROM test; +INSERT INTO test VALUES ('a', 'b'); +SELECT * FROM test; /* result: { - "pk()": [1], "a": "a", - "b": "b" -} -*/ - --- test: variadic, VALUES, no fields, few values -CREATE TABLE test (a TEXT, b TEXT, c TEXT, ...); -INSERT INTO test VALUES ("a", 'b'); -SELECT pk(), * FROM test; -/* result: -{ - "pk()": [1], - "a": "a", - "b": "b" + "b": "b", + "c": null } */ --- test: variadic, VALUES, no fields, all values and more -CREATE TABLE test (a TEXT, b TEXT, c TEXT, ...); -INSERT INTO test VALUES ("a", 'b', 'c', 'd', 'e'); --- error: - -- test: VALUES, ident CREATE TABLE test (a TEXT, b TEXT, c TEXT); INSERT INTO test (a) VALUES (a); --- error: field not found +-- error: no table specified -- test: VALUES, ident string CREATE TABLE test (a TEXT, b TEXT, c TEXT); INSERT INTO test (a) VALUES (`a`); --- error: field not found +-- error: no table specified --- test: VALUES, fields ident string +-- test: VALUES, columns ident string CREATE TABLE test (a TEXT, `foo bar` TEXT); INSERT INTO test (a, `foo bar`) VALUES ('a', 'foo bar'); -SELECT pk(), * FROM test; +SELECT * FROM test; /* result: { - "pk()": [1], "a": "a", "foo bar": "foo bar" } */ --- test: VALUES, array -CREATE TABLE test (a TEXT, b TEXT, c ARRAY); -INSERT INTO test (a, b, c) VALUES ("a", 'b', [1, 2, 3]); -SELECT pk(), * FROM test; -/* result: -{ - "pk()": [1], - "a": "a", - "b":"b", - "c": [1.0, 2.0, 3.0] -} -*/ - --- test: VALUES, generic object -CREATE TABLE test (a TEXT, b TEXT, c OBJECT); -INSERT INTO test (a, b, c) VALUES ("a", 'b', {c: 1, d: c + 1}); -SELECT pk(), * FROM test; -/* result: -{ - "pk()": [1], - "a": "a", - "b": "b", - "c": { - "c": 1.0, - "d": 2.0 - } -} -*/ diff --git a/internal/sqltests/SELECT/STRINGS/lower.sql b/internal/sqltests/SELECT/STRINGS/lower.sql index 83430a651..f748bf618 100644 --- a/internal/sqltests/SELECT/STRINGS/lower.sql +++ b/internal/sqltests/SELECT/STRINGS/lower.sql @@ -3,27 +3,18 @@ CREATE TABLE test( a TEXT, b INT, c BOOL, - d DOUBLE, - e ARRAY, - f ( - ... 
- ) + d DOUBLE ); -INSERT INTO test (a, b, c, d, e, f) VALUES ( +INSERT INTO test (a, b, c, d) VALUES ( "FOO", 42, true, - 42.42, - ["A", "b", "C", "d", "E"], - { - a: "HELLO", - b: "WorlD" - } + 42.42 ); -- test: TEXT value -SELECT strings.LOWER(a) FROM test; +SELECT LOWER(a) FROM test; /* result: { "LOWER(a)": "foo" @@ -32,7 +23,7 @@ SELECT strings.LOWER(a) FROM test; -- test: INT value -SELECT strings.LOWER(b) FROM test; +SELECT LOWER(b) FROM test; /* result: { "LOWER(b)": NULL @@ -41,7 +32,7 @@ SELECT strings.LOWER(b) FROM test; -- test: BOOL value -SELECT strings.LOWER(c) FROM test; +SELECT LOWER(c) FROM test; /* result: { "LOWER(c)": NULL @@ -49,31 +40,15 @@ SELECT strings.LOWER(c) FROM test; */ -- test: DOUBLE value -SELECT strings.LOWER(d) FROM test; +SELECT LOWER(d) FROM test; /* result: { "LOWER(d)": NULL } */ --- test: ARRAY value -SELECT strings.LOWER(e) FROM test; -/* result: -{ - "LOWER(e)": NULL -} -*/ - --- test: OBJECT value -SELECT strings.LOWER(f) FROM test; -/* result: -{ - "LOWER(f)": NULL -} -*/ - -- test: cast INT -SELECT strings.LOWER(CAST(b as TEXT)) FROM test; +SELECT LOWER(CAST(b as TEXT)) FROM test; /* result: { "LOWER(CAST(b AS text))": "42" @@ -81,7 +56,7 @@ SELECT strings.LOWER(CAST(b as TEXT)) FROM test; */ -- test: cast BOOL -SELECT strings.LOWER(CAST(c as TEXT)) FROM test; +SELECT LOWER(CAST(c as TEXT)) FROM test; /* result: { "LOWER(CAST(c AS text))": "true" @@ -89,25 +64,9 @@ SELECT strings.LOWER(CAST(c as TEXT)) FROM test; */ -- test: cast DOUBLE -SELECT strings.LOWER(CAST(d as TEXT)) FROM test; +SELECT LOWER(CAST(d as TEXT)) FROM test; /* result: { "LOWER(CAST(d AS text))": "42.42" } */ - --- test: cast ARRAY -SELECT strings.LOWER(CAST(e as TEXT)) FROM test; -/* result: -{ - "LOWER(CAST(e AS text))": "[\"a\", \"b\", \"c\", \"d\", \"e\"]" -} -*/ - --- test: cast OBJECT -SELECT strings.LOWER(CAST(f as TEXT)) FROM test; -/* result: -{ - "LOWER(CAST(f AS text))": "{\"a\": \"hello\", \"b\": \"world\"}" -} -*/ diff --git a/internal/sqltests/SELECT/STRINGS/ltrim.sql b/internal/sqltests/SELECT/STRINGS/ltrim.sql index e1b8bf761..c52fb1d24 100644 --- a/internal/sqltests/SELECT/STRINGS/ltrim.sql +++ b/internal/sqltests/SELECT/STRINGS/ltrim.sql @@ -6,7 +6,7 @@ CREATE TABLE test( INSERT INTO test (a) VALUES (" hello "), ("!hello!"), (" !hello! 
"); -- test: LTRIM TEXT default -SELECT strings.LTRIM(a) FROM test; +SELECT LTRIM(a) FROM test; /* result: { "LTRIM(a)": "hello " @@ -21,7 +21,7 @@ SELECT strings.LTRIM(a) FROM test; -- test: LTRIM TEXT with param -SELECT strings.LTRIM(a, "!") FROM test; +SELECT LTRIM(a, "!") FROM test; /* result: { "LTRIM(a, \"!\")": " hello " @@ -35,7 +35,7 @@ SELECT strings.LTRIM(a, "!") FROM test; */ -- test: LTRIM TEXT with multiple char params -SELECT strings.LTRIM(a, " !") FROM test; +SELECT LTRIM(a, " !") FROM test; /* result: { "LTRIM(a, \" !\")": "hello " @@ -50,7 +50,7 @@ SELECT strings.LTRIM(a, " !") FROM test; -- test: LTRIM TEXT with multiple char params -SELECT strings.LTRIM(a, "hel !") FROM test; +SELECT LTRIM(a, "hel !") FROM test; /* result: { "LTRIM(a, \"hel !\")": "o " @@ -65,7 +65,7 @@ SELECT strings.LTRIM(a, "hel !") FROM test; -- test: LTRIM BOOL -SELECT strings.LTRIM(true); +SELECT LTRIM(true); /* result: { "LTRIM(true)": NULL @@ -73,7 +73,7 @@ SELECT strings.LTRIM(true); */ -- test: LTRIM INT -SELECT strings.LTRIM(42); +SELECT LTRIM(42); /* result: { "LTRIM(42)": NULL @@ -81,30 +81,15 @@ SELECT strings.LTRIM(42); */ -- test: LTRIM DOUBLE -SELECT strings.LTRIM(42.42); +SELECT LTRIM(42.42); /* result: { "LTRIM(42.42)": NULL } */ --- test: LTRIM ARRAY -SELECT strings.LTRIM([1, 2]); -/* result: -{ - "LTRIM([1, 2])": NULL -} -*/ --- test: LTRIM OBJECT -SELECT strings.LTRIM({a: 1}); -/* result: -{ - "LTRIM({a: 1})": NULL -} -*/ - -- test: LTRIM STRING wrong param -SELECT strings.LTRIM(" hello ", 42); +SELECT LTRIM(" hello ", 42); /* result: { "LTRIM(\" hello \", 42)": NULL diff --git a/internal/sqltests/SELECT/STRINGS/rtrim.sql b/internal/sqltests/SELECT/STRINGS/rtrim.sql index 6f68ab82b..a54acb6e8 100644 --- a/internal/sqltests/SELECT/STRINGS/rtrim.sql +++ b/internal/sqltests/SELECT/STRINGS/rtrim.sql @@ -6,7 +6,7 @@ CREATE TABLE test( INSERT INTO test (a) VALUES (" hello "), ("!hello!"), (" !hello! 
"); -- test: RTRIM TEXT default -SELECT strings.RTRIM(a) FROM test; +SELECT RTRIM(a) FROM test; /* result: { "RTRIM(a)": " hello" @@ -21,7 +21,7 @@ SELECT strings.RTRIM(a) FROM test; -- test: RTRIM TEXT with param -SELECT strings.RTRIM(a, "!") FROM test; +SELECT RTRIM(a, "!") FROM test; /* result: { "RTRIM(a, \"!\")": " hello " @@ -35,7 +35,7 @@ SELECT strings.RTRIM(a, "!") FROM test; */ -- test: RTRIM TEXT with multiple char params -SELECT strings.RTRIM(a, " !") FROM test; +SELECT RTRIM(a, " !") FROM test; /* result: { "RTRIM(a, \" !\")": " hello" @@ -50,7 +50,7 @@ SELECT strings.RTRIM(a, " !") FROM test; -- test: RTRIM TEXT with multiple char params -SELECT strings.RTRIM(a, "hel !") FROM test; +SELECT RTRIM(a, "hel !") FROM test; /* result: { "RTRIM(a, \"hel !\")": " hello" @@ -65,7 +65,7 @@ SELECT strings.RTRIM(a, "hel !") FROM test; -- test: RTRIM BOOL -SELECT strings.RTRIM(true); +SELECT RTRIM(true); /* result: { "RTRIM(true)": NULL @@ -73,7 +73,7 @@ SELECT strings.RTRIM(true); */ -- test: RTRIM INT -SELECT strings.RTRIM(42); +SELECT RTRIM(42); /* result: { "RTRIM(42)": NULL @@ -81,30 +81,15 @@ SELECT strings.RTRIM(42); */ -- test: RTRIM DOUBLE -SELECT strings.RTRIM(42.42); +SELECT RTRIM(42.42); /* result: { "RTRIM(42.42)": NULL } */ --- test: RTRIM ARRAY -SELECT strings.RTRIM([1, 2]); -/* result: -{ - "RTRIM([1, 2])": NULL -} -*/ --- test: RTRIM OBJECT -SELECT strings.RTRIM({a: 1}); -/* result: -{ - "RTRIM({a: 1})": NULL -} -*/ - -- test: RTRIM STRING wrong param -SELECT strings.RTRIM(" hello ", 42); +SELECT RTRIM(" hello ", 42); /* result: { "RTRIM(\" hello \", 42)": NULL diff --git a/internal/sqltests/SELECT/STRINGS/trim.sql b/internal/sqltests/SELECT/STRINGS/trim.sql index deab7b333..ca0254d02 100644 --- a/internal/sqltests/SELECT/STRINGS/trim.sql +++ b/internal/sqltests/SELECT/STRINGS/trim.sql @@ -6,7 +6,7 @@ CREATE TABLE test( INSERT INTO test (a) VALUES (" hello "), ("!hello!"), (" !hello! 
"); -- test: TRIM TEXT default -SELECT strings.TRIM(a) FROM test; +SELECT TRIM(a) FROM test; /* result: { "TRIM(a)": "hello" @@ -21,7 +21,7 @@ SELECT strings.TRIM(a) FROM test; -- test: TRIM TEXT with param -SELECT strings.TRIM(a, "!") FROM test; +SELECT TRIM(a, "!") FROM test; /* result: { "TRIM(a, \"!\")": " hello " @@ -35,7 +35,7 @@ SELECT strings.TRIM(a, "!") FROM test; */ -- test: TRIM TEXT with multiple char params -SELECT strings.TRIM(a, " !") FROM test; +SELECT TRIM(a, " !") FROM test; /* result: { "TRIM(a, \" !\")": "hello" @@ -50,7 +50,7 @@ SELECT strings.TRIM(a, " !") FROM test; -- test: TRIM TEXT with multiple char params -SELECT strings.TRIM(a, "hel !") FROM test; +SELECT TRIM(a, "hel !") FROM test; /* result: { "TRIM(a, \"hel !\")": "o" @@ -65,7 +65,7 @@ SELECT strings.TRIM(a, "hel !") FROM test; -- test: TRIM BOOL -SELECT strings.TRIM(true); +SELECT TRIM(true); /* result: { "TRIM(true)": NULL @@ -73,7 +73,7 @@ SELECT strings.TRIM(true); */ -- test: TRIM INT -SELECT strings.TRIM(42); +SELECT TRIM(42); /* result: { "TRIM(42)": NULL @@ -81,30 +81,15 @@ SELECT strings.TRIM(42); */ -- test: TRIM DOUBLE -SELECT strings.TRIM(42.42); +SELECT TRIM(42.42); /* result: { "TRIM(42.42)": NULL } */ --- test: TRIM ARRAY -SELECT strings.TRIM([1, 2]); -/* result: -{ - "TRIM([1, 2])": NULL -} -*/ --- test: TRIM OBJECT -SELECT strings.TRIM({a: 1}); -/* result: -{ - "TRIM({a: 1})": NULL -} -*/ - -- test: TRIM STRING wrong param -SELECT strings.TRIM(" hello ", 42); +SELECT TRIM(" hello ", 42); /* result: { "TRIM(\" hello \", 42)": NULL diff --git a/internal/sqltests/SELECT/STRINGS/upper.sql b/internal/sqltests/SELECT/STRINGS/upper.sql index a2d762297..3b3d84ce9 100644 --- a/internal/sqltests/SELECT/STRINGS/upper.sql +++ b/internal/sqltests/SELECT/STRINGS/upper.sql @@ -3,27 +3,18 @@ CREATE TABLE test( a TEXT, b INT, c BOOL, - d DOUBLE, - e ARRAY, - f ( - ... 
- ) + d DOUBLE ); -INSERT INTO test (a, b, c, d, e, f) VALUES ( +INSERT INTO test (a, b, c, d) VALUES ( "foo", 42, true, - 42.42, - ["A", "b", "C", "d", "E"], - { - a: "hello", - b: "WorlD" - } + 42.42 ); -- test: TEXT value -SELECT strings.UPPER(a) FROM test; +SELECT UPPER(a) FROM test; /* result: { "UPPER(a)": "FOO" @@ -32,7 +23,7 @@ SELECT strings.UPPER(a) FROM test; -- test: INT value -SELECT strings.UPPER(b) FROM test; +SELECT UPPER(b) FROM test; /* result: { "UPPER(b)": NULL @@ -41,7 +32,7 @@ SELECT strings.UPPER(b) FROM test; -- test: BOOL value -SELECT strings.UPPER(c) FROM test; +SELECT UPPER(c) FROM test; /* result: { "UPPER(c)": NULL @@ -49,31 +40,15 @@ SELECT strings.UPPER(c) FROM test; */ -- test: DOUBLE value -SELECT strings.UPPER(d) FROM test; +SELECT UPPER(d) FROM test; /* result: { "UPPER(d)": NULL } */ --- test: ARRAY value -SELECT strings.UPPER(e) FROM test; -/* result: -{ - "UPPER(e)": NULL -} -*/ - --- test: OBJECT value -SELECT strings.UPPER(f) FROM test; -/* result: -{ - "UPPER(f)": NULL -} -*/ - -- test: cast INT -SELECT strings.UPPER(CAST(b as TEXT)) FROM test; +SELECT UPPER(CAST(b as TEXT)) FROM test; /* result: { "UPPER(CAST(b AS text))": "42" @@ -81,7 +56,7 @@ SELECT strings.UPPER(CAST(b as TEXT)) FROM test; */ -- test: cast BOOL -SELECT strings.UPPER(CAST(c as TEXT)) FROM test; +SELECT UPPER(CAST(c as TEXT)) FROM test; /* result: { "UPPER(CAST(c AS text))": "TRUE" @@ -89,25 +64,10 @@ SELECT strings.UPPER(CAST(c as TEXT)) FROM test; */ -- test: cast DOUBLE -SELECT strings.UPPER(CAST(d as TEXT)) FROM test; +SELECT UPPER(CAST(d as TEXT)) FROM test; /* result: { "UPPER(CAST(d AS text))": "42.42" } */ --- test: cast ARRAY -SELECT strings.UPPER(CAST(e as TEXT)) FROM test; -/* result: -{ - "UPPER(CAST(e AS text))": "[\"A\", \"B\", \"C\", \"D\", \"E\"]" -} -*/ - --- test: cast OBJECT -SELECT strings.UPPER(CAST(f as TEXT)) FROM test; -/* result: -{ - "UPPER(CAST(f AS text))": "{\"A\": \"HELLO\", \"B\": \"WORLD\"}" -} -*/ diff --git a/internal/sqltests/SELECT/WHERE/comp.sql b/internal/sqltests/SELECT/WHERE/comp.sql index fe3161089..785abd629 100644 --- a/internal/sqltests/SELECT/WHERE/comp.sql +++ b/internal/sqltests/SELECT/WHERE/comp.sql @@ -1,8 +1,8 @@ --- This file tests comparison operators with field paths +-- This file tests comparison operators with columns -- of different types. It ensures the behavior (modulo ordering) -- remains the same regardless of index usage. 
-- It contains one test suite with no index and one test suite --- per field where that particular field is indexed +-- per column where that particular column is indexed -- setup: CREATE TABLE test( @@ -11,52 +11,14 @@ CREATE TABLE test( b double, c boolean, d text, - e blob, - f (a int), -- f document - g ARRAY -- g array + e blob ); INSERT INTO test VALUES - { - id: 1, - a: 10, - b: 1.0, - c: false, - d: "a", - e: "\xaa", - f: {a: 1}, - g: [1.0] - }, - { - id: 2, - a: 20, - b: 2.0, - c: true, - d: "b", - e: "\xab", - f: {a: 2}, - g: [2.0] - }, - { - id: 3, - a: 30, - b: 3.0, - c: false, - d: "c", - e: "\xac", - f: {a: 3}, - g: [3.0] - }, - { - id: 4, - a: 40, - b: 4.0, - c: true, - d: "d", - e: "\xad", - f: {a: 4}, - g: [4.0] - }; + (1, 10, 1.0, false, "a", "\xaa"), + (2, 20, 2.0, true, "b", "\xab"), + (3, 30, 3.0, false, "c", "\xac"), + (4, 40, 4.0, true, "d", "\xad"); -- suite: no index @@ -75,12 +37,6 @@ CREATE INDEX ON test(d); -- suite: index on e CREATE INDEX ON test(e); --- suite: index on f -CREATE INDEX ON test(f); - --- suite: index on g -CREATE INDEX ON test(g); - -- test: pk = SELECT * FROM test WHERE id = 1; /* result: @@ -90,9 +46,7 @@ SELECT * FROM test WHERE id = 1; b: 1.0, c: false, d: "a", - e: "\xaa", - f: {a: 1}, - g: [1.0] + e: "\xaa" } */ @@ -105,9 +59,7 @@ SELECT * FROM test WHERE id != 1; b: 2.0, c: true, d: "b", - e: "\xab", - f: {a: 2}, - g: [2.0] + e: "\xab" } { id: 3, @@ -115,9 +67,7 @@ SELECT * FROM test WHERE id != 1; b: 3.0, c: false, d: "c", - e: "\xac", - f: {a: 3}, - g: [3.0] + e: "\xac" } { id: 4, @@ -125,9 +75,7 @@ SELECT * FROM test WHERE id != 1; b: 4.0, c: true, d: "d", - e: "\xad", - f: {a: 4}, - g: [4.0] + e: "\xad" } */ @@ -140,9 +88,7 @@ SELECT * FROM test WHERE id > 1; b: 2.0, c: true, d: "b", - e: "\xab", - f: {a: 2}, - g: [2.0] + e: "\xab" } { id: 3, @@ -150,9 +96,7 @@ SELECT * FROM test WHERE id > 1; b: 3.0, c: false, d: "c", - e: "\xac", - f: {a: 3}, - g: [3.0] + e: "\xac" } { id: 4, @@ -160,9 +104,7 @@ SELECT * FROM test WHERE id > 1; b: 4.0, c: true, d: "d", - e: "\xad", - f: {a: 4}, - g: [4.0] + e: "\xad" } */ @@ -175,9 +117,7 @@ SELECT * FROM test WHERE id >= 1; b: 1.0, c: false, d: "a", - e: "\xaa", - f: {a: 1}, - g: [1.0] + e: "\xaa" } { id: 2, @@ -185,9 +125,7 @@ SELECT * FROM test WHERE id >= 1; b: 2.0, c: true, d: "b", - e: "\xab", - f: {a: 2}, - g: [2.0] + e: "\xab" } { id: 3, @@ -195,9 +133,7 @@ SELECT * FROM test WHERE id >= 1; b: 3.0, c: false, d: "c", - e: "\xac", - f: {a: 3}, - g: [3.0] + e: "\xac" } { id: 4, @@ -205,9 +141,7 @@ SELECT * FROM test WHERE id >= 1; b: 4.0, c: true, d: "d", - e: "\xad", - f: {a: 4}, - g: [4.0] + e: "\xad" } */ @@ -220,9 +154,7 @@ SELECT * FROM test WHERE id < 3; b: 1.0, c: false, d: "a", - e: "\xaa", - f: {a: 1}, - g: [1.0] + e: "\xaa" } { id: 2, @@ -230,9 +162,7 @@ SELECT * FROM test WHERE id < 3; b: 2.0, c: true, d: "b", - e: "\xab", - f: {a: 2}, - g: [2.0] + e: "\xab" } */ @@ -245,9 +175,7 @@ SELECT * FROM test WHERE id <= 3; b: 1.0, c: false, d: "a", - e: "\xaa", - f: {a: 1}, - g: [1.0] + e: "\xaa" } { id: 2, @@ -255,9 +183,7 @@ SELECT * FROM test WHERE id <= 3; b: 2.0, c: true, d: "b", - e: "\xab", - f: {a: 2}, - g: [2.0] + e: "\xab" } { id: 3, @@ -265,9 +191,7 @@ SELECT * FROM test WHERE id <= 3; b: 3.0, c: false, d: "c", - e: "\xac", - f: {a: 3}, - g: [3.0] + e: "\xac" } */ @@ -280,9 +204,7 @@ SELECT * FROM test WHERE id IN (1, 3); b: 1.0, c: false, d: "a", - e: "\xaa", - f: {a: 1}, - g: [1.0] + e: "\xaa" } { id: 3, @@ -290,9 +212,7 @@ SELECT * FROM test WHERE id IN (1, 3); b: 3.0, c: 
false, d: "c", - e: "\xac", - f: {a: 3}, - g: [3.0] + e: "\xac" } */ @@ -305,9 +225,7 @@ SELECT * FROM test WHERE id NOT IN (1, 3); b: 2.0, c: true, d: "b", - e: "\xab", - f: {a: 2}, - g: [2.0] + e: "\xab" } { id: 4, @@ -315,9 +233,7 @@ SELECT * FROM test WHERE id NOT IN (1, 3); b: 4.0, c: true, d: "d", - e: "\xad", - f: {a: 4}, - g: [4.0] + e: "\xad" } */ @@ -330,9 +246,7 @@ SELECT * FROM test WHERE a = 10; b: 1.0, c: false, d: "a", - e: "\xaa", - f: {a: 1}, - g: [1.0] + e: "\xaa" } */ @@ -345,9 +259,7 @@ SELECT * FROM test WHERE a != 10; b: 2.0, c: true, d: "b", - e: "\xab", - f: {a: 2}, - g: [2.0] + e: "\xab" } { id: 3, @@ -355,9 +267,7 @@ SELECT * FROM test WHERE a != 10; b: 3.0, c: false, d: "c", - e: "\xac", - f: {a: 3}, - g: [3.0] + e: "\xac" } { id: 4, @@ -365,9 +275,7 @@ SELECT * FROM test WHERE a != 10; b: 4.0, c: true, d: "d", - e: "\xad", - f: {a: 4}, - g: [4.0] + e: "\xad" } */ @@ -380,9 +288,7 @@ SELECT * FROM test WHERE a > 10; b: 2.0, c: true, d: "b", - e: "\xab", - f: {a: 2}, - g: [2.0] + e: "\xab" } { id: 3, @@ -390,9 +296,7 @@ SELECT * FROM test WHERE a > 10; b: 3.0, c: false, d: "c", - e: "\xac", - f: {a: 3}, - g: [3.0] + e: "\xac" } { id: 4, @@ -400,9 +304,7 @@ SELECT * FROM test WHERE a > 10; b: 4.0, c: true, d: "d", - e: "\xad", - f: {a: 4}, - g: [4.0] + e: "\xad" } */ @@ -415,9 +317,7 @@ SELECT * FROM test WHERE a >= 10; b: 1.0, c: false, d: "a", - e: "\xaa", - f: {a: 1}, - g: [1.0] + e: "\xaa" } { id: 2, @@ -425,9 +325,7 @@ SELECT * FROM test WHERE a >= 10; b: 2.0, c: true, d: "b", - e: "\xab", - f: {a: 2}, - g: [2.0] + e: "\xab" } { id: 3, @@ -435,9 +333,7 @@ SELECT * FROM test WHERE a >= 10; b: 3.0, c: false, d: "c", - e: "\xac", - f: {a: 3}, - g: [3.0] + e: "\xac" } { id: 4, @@ -445,9 +341,7 @@ SELECT * FROM test WHERE a >= 10; b: 4.0, c: true, d: "d", - e: "\xad", - f: {a: 4}, - g: [4.0] + e: "\xad" } */ @@ -460,9 +354,7 @@ SELECT * FROM test WHERE a < 30; b: 1.0, c: false, d: "a", - e: "\xaa", - f: {a: 1}, - g: [1.0] + e: "\xaa" } { id: 2, @@ -470,9 +362,7 @@ SELECT * FROM test WHERE a < 30; b: 2.0, c: true, d: "b", - e: "\xab", - f: {a: 2}, - g: [2.0] + e: "\xab" } */ @@ -485,9 +375,7 @@ SELECT * FROM test WHERE a <= 30; b: 1.0, c: false, d: "a", - e: "\xaa", - f: {a: 1}, - g: [1.0] + e: "\xaa" } { id: 2, @@ -495,9 +383,7 @@ SELECT * FROM test WHERE a <= 30; b: 2.0, c: true, d: "b", - e: "\xab", - f: {a: 2}, - g: [2.0] + e: "\xab" } { id: 3, @@ -505,9 +391,7 @@ SELECT * FROM test WHERE a <= 30; b: 3.0, c: false, d: "c", - e: "\xac", - f: {a: 3}, - g: [3.0] + e: "\xac" } */ @@ -520,9 +404,7 @@ SELECT * FROM test WHERE a IN (10, 30); b: 1.0, c: false, d: "a", - e: "\xaa", - f: {a: 1}, - g: [1.0] + e: "\xaa" } { id: 3, @@ -530,9 +412,7 @@ SELECT * FROM test WHERE a IN (10, 30); b: 3.0, c: false, d: "c", - e: "\xac", - f: {a: 3}, - g: [3.0] + e: "\xac" } */ @@ -545,9 +425,7 @@ SELECT * FROM test WHERE a NOT IN (10, 30); b: 2.0, c: true, d: "b", - e: "\xab", - f: {a: 2}, - g: [2.0] + e: "\xab" } { id: 4, @@ -555,9 +433,7 @@ SELECT * FROM test WHERE a NOT IN (10, 30); b: 4.0, c: true, d: "d", - e: "\xad", - f: {a: 4}, - g: [4.0] + e: "\xad" } */ @@ -570,9 +446,7 @@ SELECT * FROM test WHERE b = 1.0; b: 1.0, c: false, d: "a", - e: "\xaa", - f: {a: 1}, - g: [1.0] + e: "\xaa" } */ @@ -585,9 +459,7 @@ SELECT * FROM test WHERE b != 1.0; b: 2.0, c: true, d: "b", - e: "\xab", - f: {a: 2}, - g: [2.0] + e: "\xab" } { id: 3, @@ -595,9 +467,7 @@ SELECT * FROM test WHERE b != 1.0; b: 3.0, c: false, d: "c", - e: "\xac", - f: {a: 3}, - g: [3.0] + e: "\xac" } { id: 4, @@ -605,9 
+475,7 @@ SELECT * FROM test WHERE b != 1.0; b: 4.0, c: true, d: "d", - e: "\xad", - f: {a: 4}, - g: [4.0] + e: "\xad" } */ @@ -620,9 +488,7 @@ SELECT * FROM test WHERE b > 1.0; b: 2.0, c: true, d: "b", - e: "\xab", - f: {a: 2}, - g: [2.0] + e: "\xab" } { id: 3, @@ -630,9 +496,7 @@ SELECT * FROM test WHERE b > 1.0; b: 3.0, c: false, d: "c", - e: "\xac", - f: {a: 3}, - g: [3.0] + e: "\xac" } { id: 4, @@ -640,9 +504,7 @@ SELECT * FROM test WHERE b > 1.0; b: 4.0, c: true, d: "d", - e: "\xad", - f: {a: 4}, - g: [4.0] + e: "\xad" } */ @@ -655,9 +517,7 @@ SELECT * FROM test WHERE b >= 1.0; b: 1.0, c: false, d: "a", - e: "\xaa", - f: {a: 1}, - g: [1.0] + e: "\xaa" } { id: 2, @@ -665,9 +525,7 @@ SELECT * FROM test WHERE b >= 1.0; b: 2.0, c: true, d: "b", - e: "\xab", - f: {a: 2}, - g: [2.0] + e: "\xab" } { id: 3, @@ -675,9 +533,7 @@ SELECT * FROM test WHERE b >= 1.0; b: 3.0, c: false, d: "c", - e: "\xac", - f: {a: 3}, - g: [3.0] + e: "\xac" } { id: 4, @@ -685,9 +541,7 @@ SELECT * FROM test WHERE b >= 1.0; b: 4.0, c: true, d: "d", - e: "\xad", - f: {a: 4}, - g: [4.0] + e: "\xad" } */ @@ -700,9 +554,7 @@ SELECT * FROM test WHERE b < 3.0; b: 1.0, c: false, d: "a", - e: "\xaa", - f: {a: 1}, - g: [1.0] + e: "\xaa" } { id: 2, @@ -710,9 +562,7 @@ SELECT * FROM test WHERE b < 3.0; b: 2.0, c: true, d: "b", - e: "\xab", - f: {a: 2}, - g: [2.0] + e: "\xab" } */ @@ -725,9 +575,7 @@ SELECT * FROM test WHERE b <= 3.0; b: 1.0, c: false, d: "a", - e: "\xaa", - f: {a: 1}, - g: [1.0] + e: "\xaa" } { id: 2, @@ -735,9 +583,7 @@ SELECT * FROM test WHERE b <= 3.0; b: 2.0, c: true, d: "b", - e: "\xab", - f: {a: 2}, - g: [2.0] + e: "\xab" } { id: 3, @@ -745,9 +591,7 @@ SELECT * FROM test WHERE b <= 3.0; b: 3.0, c: false, d: "c", - e: "\xac", - f: {a: 3}, - g: [3.0] + e: "\xac" } */ @@ -760,9 +604,7 @@ SELECT * FROM test WHERE b IN (1.0, 3.0); b: 1.0, c: false, d: "a", - e: "\xaa", - f: {a: 1}, - g: [1.0] + e: "\xaa" } { id: 3, @@ -770,9 +612,7 @@ SELECT * FROM test WHERE b IN (1.0, 3.0); b: 3.0, c: false, d: "c", - e: "\xac", - f: {a: 3}, - g: [3.0] + e: "\xac" } */ @@ -785,9 +625,7 @@ SELECT * FROM test WHERE b NOT IN (1.0, 3.0); b: 2.0, c: true, d: "b", - e: "\xab", - f: {a: 2}, - g: [2.0] + e: "\xab" } { id: 4, @@ -795,24 +633,20 @@ SELECT * FROM test WHERE b NOT IN (1.0, 3.0); b: 4.0, c: true, d: "d", - e: "\xad", - f: {a: 4}, - g: [4.0] + e: "\xad" } */ -- test: bool = SELECT * FROM test WHERE c = true; -/* sorted-result: +/* result: { id: 2, a: 20, b: 2.0, c: true, d: "b", - e: "\xab", - f: {a: 2}, - g: [2.0] + e: "\xab" } { id: 4, @@ -820,24 +654,20 @@ SELECT * FROM test WHERE c = true; b: 4.0, c: true, d: "d", - e: "\xad", - f: {a: 4}, - g: [4.0] + e: "\xad" } */ -- test: bool != SELECT * FROM test WHERE c != true; -/* sorted-result: +/* result: { id: 1, a: 10, b: 1.0, c: false, d: "a", - e: "\xaa", - f: {a: 1}, - g: [1.0] + e: "\xaa" } { id: 3, @@ -845,24 +675,20 @@ SELECT * FROM test WHERE c != true; b: 3.0, c: false, d: "c", - e: "\xac", - f: {a: 3}, - g: [3.0] + e: "\xac" } */ -- test: bool > SELECT * FROM test WHERE c > false; -/* sorted-result: +/* result: { id: 2, a: 20, b: 2.0, c: true, d: "b", - e: "\xab", - f: {a: 2}, - g: [2.0] + e: "\xab" } { id: 4, @@ -870,24 +696,20 @@ SELECT * FROM test WHERE c > false; b: 4.0, c: true, d: "d", - e: "\xad", - f: {a: 4}, - g: [4.0] + e: "\xad" } */ -- test: bool >= -SELECT * FROM test WHERE c >= false; -/* sorted-result: +SELECT * FROM test WHERE c >= false ORDER BY id; +/* result: { id: 1, a: 10, b: 1.0, c: false, d: "a", - e: "\xaa", - f: {a: 1}, - g: [1.0] + e: 
"\xaa" } { id: 2, @@ -895,9 +717,7 @@ SELECT * FROM test WHERE c >= false; b: 2.0, c: true, d: "b", - e: "\xab", - f: {a: 2}, - g: [2.0] + e: "\xab" } { id: 3, @@ -905,9 +725,7 @@ SELECT * FROM test WHERE c >= false; b: 3.0, c: false, d: "c", - e: "\xac", - f: {a: 3}, - g: [3.0] + e: "\xac" } { id: 4, @@ -915,24 +733,20 @@ SELECT * FROM test WHERE c >= false; b: 4.0, c: true, d: "d", - e: "\xad", - f: {a: 4}, - g: [4.0] + e: "\xad" } */ -- test: bool < -SELECT * FROM test WHERE c < true; -/* sorted-result: +SELECT * FROM test WHERE c < true ORDER BY id; +/* result: { id: 1, a: 10, b: 1.0, c: false, d: "a", - e: "\xaa", - f: {a: 1}, - g: [1.0] + e: "\xaa" } { id: 3, @@ -940,24 +754,20 @@ SELECT * FROM test WHERE c < true; b: 3.0, c: false, d: "c", - e: "\xac", - f: {a: 3}, - g: [3.0] + e: "\xac" } */ -- test: bool <= -SELECT * FROM test WHERE c <= true; -/* sorted-result: +SELECT * FROM test WHERE c <= true ORDER BY id; +/* result: { id: 1, a: 10, b: 1.0, c: false, d: "a", - e: "\xaa", - f: {a: 1}, - g: [1.0] + e: "\xaa" } { id: 2, @@ -965,9 +775,7 @@ SELECT * FROM test WHERE c <= true; b: 2.0, c: true, d: "b", - e: "\xab", - f: {a: 2}, - g: [2.0] + e: "\xab" } { id: 3, @@ -975,9 +783,7 @@ SELECT * FROM test WHERE c <= true; b: 3.0, c: false, d: "c", - e: "\xac", - f: {a: 3}, - g: [3.0] + e: "\xac" } { id: 4, @@ -985,24 +791,20 @@ SELECT * FROM test WHERE c <= true; b: 4.0, c: true, d: "d", - e: "\xad", - f: {a: 4}, - g: [4.0] + e: "\xad" } */ -- test: bool IN -SELECT * FROM test WHERE c IN (true, false); -/* sorted-result: +SELECT * FROM test WHERE c IN (true, false) ORDER BY id; +/* result: { id: 1, a: 10, b: 1.0, c: false, d: "a", - e: "\xaa", - f: {a: 1}, - g: [1.0] + e: "\xaa" } { id: 2, @@ -1010,9 +812,7 @@ SELECT * FROM test WHERE c IN (true, false); b: 2.0, c: true, d: "b", - e: "\xab", - f: {a: 2}, - g: [2.0] + e: "\xab" } { id: 3, @@ -1020,9 +820,7 @@ SELECT * FROM test WHERE c IN (true, false); b: 3.0, c: false, d: "c", - e: "\xac", - f: {a: 3}, - g: [3.0] + e: "\xac" } { id: 4, @@ -1030,24 +828,20 @@ SELECT * FROM test WHERE c IN (true, false); b: 4.0, c: true, d: "d", - e: "\xad", - f: {a: 4}, - g: [4.0] + e: "\xad" } */ -- test: bool NOT IN SELECT * FROM test WHERE c NOT IN (true, 3); -/* sorted-result: +/* result: { id: 1, a: 10, b: 1.0, c: false, d: "a", - e: "\xaa", - f: {a: 1}, - g: [1.0] + e: "\xaa" } { id: 3, @@ -1055,9 +849,7 @@ SELECT * FROM test WHERE c NOT IN (true, 3); b: 3.0, c: false, d: "c", - e: "\xac", - f: {a: 3}, - g: [3.0] + e: "\xac" } */ @@ -1070,9 +862,7 @@ SELECT * FROM test WHERE d = "a"; b: 1.0, c: false, d: "a", - e: "\xaa", - f: {a: 1}, - g: [1.0] + e: "\xaa" } */ @@ -1085,9 +875,7 @@ SELECT * FROM test WHERE d != "a"; b: 2.0, c: true, d: "b", - e: "\xab", - f: {a: 2}, - g: [2.0] + e: "\xab" } { id: 3, @@ -1095,9 +883,7 @@ SELECT * FROM test WHERE d != "a"; b: 3.0, c: false, d: "c", - e: "\xac", - f: {a: 3}, - g: [3.0] + e: "\xac" } { id: 4, @@ -1105,9 +891,7 @@ SELECT * FROM test WHERE d != "a"; b: 4.0, c: true, d: "d", - e: "\xad", - f: {a: 4}, - g: [4.0] + e: "\xad" } */ @@ -1120,9 +904,7 @@ SELECT * FROM test WHERE d > "a"; b: 2.0, c: true, d: "b", - e: "\xab", - f: {a: 2}, - g: [2.0] + e: "\xab" } { id: 3, @@ -1130,9 +912,7 @@ SELECT * FROM test WHERE d > "a"; b: 3.0, c: false, d: "c", - e: "\xac", - f: {a: 3}, - g: [3.0] + e: "\xac" } { id: 4, @@ -1140,9 +920,7 @@ SELECT * FROM test WHERE d > "a"; b: 4.0, c: true, d: "d", - e: "\xad", - f: {a: 4}, - g: [4.0] + e: "\xad" } */ @@ -1155,9 +933,7 @@ SELECT * FROM test WHERE d >= "a"; b: 1.0, c: 
false, d: "a", - e: "\xaa", - f: {a: 1}, - g: [1.0] + e: "\xaa" } { id: 2, @@ -1165,9 +941,7 @@ SELECT * FROM test WHERE d >= "a"; b: 2.0, c: true, d: "b", - e: "\xab", - f: {a: 2}, - g: [2.0] + e: "\xab" } { id: 3, @@ -1175,9 +949,7 @@ SELECT * FROM test WHERE d >= "a"; b: 3.0, c: false, d: "c", - e: "\xac", - f: {a: 3}, - g: [3.0] + e: "\xac" } { id: 4, @@ -1185,9 +957,7 @@ SELECT * FROM test WHERE d >= "a"; b: 4.0, c: true, d: "d", - e: "\xad", - f: {a: 4}, - g: [4.0] + e: "\xad" } */ @@ -1200,9 +970,7 @@ SELECT * FROM test WHERE d < "c"; b: 1.0, c: false, d: "a", - e: "\xaa", - f: {a: 1}, - g: [1.0] + e: "\xaa" } { id: 2, @@ -1210,9 +978,7 @@ SELECT * FROM test WHERE d < "c"; b: 2.0, c: true, d: "b", - e: "\xab", - f: {a: 2}, - g: [2.0] + e: "\xab" } */ @@ -1225,9 +991,7 @@ SELECT * FROM test WHERE d <= "c"; b: 1.0, c: false, d: "a", - e: "\xaa", - f: {a: 1}, - g: [1.0] + e: "\xaa" } { id: 2, @@ -1235,9 +999,7 @@ SELECT * FROM test WHERE d <= "c"; b: 2.0, c: true, d: "b", - e: "\xab", - f: {a: 2}, - g: [2.0] + e: "\xab" } { id: 3, @@ -1245,9 +1007,7 @@ SELECT * FROM test WHERE d <= "c"; b: 3.0, c: false, d: "c", - e: "\xac", - f: {a: 3}, - g: [3.0] + e: "\xac" } */ @@ -1260,9 +1020,7 @@ SELECT * FROM test WHERE d IN ("a", "c"); b: 1.0, c: false, d: "a", - e: "\xaa", - f: {a: 1}, - g: [1.0] + e: "\xaa" } { id: 3, @@ -1270,9 +1028,7 @@ SELECT * FROM test WHERE d IN ("a", "c"); b: 3.0, c: false, d: "c", - e: "\xac", - f: {a: 3}, - g: [3.0] + e: "\xac" } */ @@ -1285,9 +1041,7 @@ SELECT * FROM test WHERE d NOT IN ("a", "c"); b: 2.0, c: true, d: "b", - e: "\xab", - f: {a: 2}, - g: [2.0] + e: "\xab" } { id: 4, @@ -1295,9 +1049,7 @@ SELECT * FROM test WHERE d NOT IN ("a", "c"); b: 4.0, c: true, d: "d", - e: "\xad", - f: {a: 4}, - g: [4.0] + e: "\xad" } */ @@ -1310,9 +1062,7 @@ SELECT * FROM test WHERE e = "\xaa"; b: 1.0, c: false, d: "a", - e: "\xaa", - f: {a: 1}, - g: [1.0] + e: "\xaa" } */ @@ -1325,9 +1075,7 @@ SELECT * FROM test WHERE e != "\xaa"; b: 2.0, c: true, d: "b", - e: "\xab", - f: {a: 2}, - g: [2.0] + e: "\xab" } { id: 3, @@ -1335,9 +1083,7 @@ SELECT * FROM test WHERE e != "\xaa"; b: 3.0, c: false, d: "c", - e: "\xac", - f: {a: 3}, - g: [3.0] + e: "\xac" } { id: 4, @@ -1345,9 +1091,7 @@ SELECT * FROM test WHERE e != "\xaa"; b: 4.0, c: true, d: "d", - e: "\xad", - f: {a: 4}, - g: [4.0] + e: "\xad" } */ @@ -1360,9 +1104,7 @@ SELECT * FROM test WHERE e > "\xaa"; b: 2.0, c: true, d: "b", - e: "\xab", - f: {a: 2}, - g: [2.0] + e: "\xab" } { id: 3, @@ -1370,9 +1112,7 @@ SELECT * FROM test WHERE e > "\xaa"; b: 3.0, c: false, d: "c", - e: "\xac", - f: {a: 3}, - g: [3.0] + e: "\xac" } { id: 4, @@ -1380,9 +1120,7 @@ SELECT * FROM test WHERE e > "\xaa"; b: 4.0, c: true, d: "d", - e: "\xad", - f: {a: 4}, - g: [4.0] + e: "\xad" } */ @@ -1395,9 +1133,7 @@ SELECT * FROM test WHERE e >= "\xaa"; b: 1.0, c: false, d: "a", - e: "\xaa", - f: {a: 1}, - g: [1.0] + e: "\xaa" } { id: 2, @@ -1405,9 +1141,7 @@ SELECT * FROM test WHERE e >= "\xaa"; b: 2.0, c: true, d: "b", - e: "\xab", - f: {a: 2}, - g: [2.0] + e: "\xab" } { id: 3, @@ -1415,9 +1149,7 @@ SELECT * FROM test WHERE e >= "\xaa"; b: 3.0, c: false, d: "c", - e: "\xac", - f: {a: 3}, - g: [3.0] + e: "\xac" } { id: 4, @@ -1425,9 +1157,7 @@ SELECT * FROM test WHERE e >= "\xaa"; b: 4.0, c: true, d: "d", - e: "\xad", - f: {a: 4}, - g: [4.0] + e: "\xad" } */ @@ -1440,9 +1170,7 @@ SELECT * FROM test WHERE e < "\xac"; b: 1.0, c: false, d: "a", - e: "\xaa", - f: {a: 1}, - g: [1.0] + e: "\xaa" } { id: 2, @@ -1450,9 +1178,7 @@ SELECT * FROM test WHERE e < 
"\xac"; b: 2.0, c: true, d: "b", - e: "\xab", - f: {a: 2}, - g: [2.0] + e: "\xab" } */ @@ -1465,9 +1191,7 @@ SELECT * FROM test WHERE e <= "\xac"; b: 1.0, c: false, d: "a", - e: "\xaa", - f: {a: 1}, - g: [1.0] + e: "\xaa" } { id: 2, @@ -1475,9 +1199,7 @@ SELECT * FROM test WHERE e <= "\xac"; b: 2.0, c: true, d: "b", - e: "\xab", - f: {a: 2}, - g: [2.0] + e: "\xab" } { id: 3, @@ -1485,9 +1207,7 @@ SELECT * FROM test WHERE e <= "\xac"; b: 3.0, c: false, d: "c", - e: "\xac", - f: {a: 3}, - g: [3.0] + e: "\xac" } */ @@ -1500,9 +1220,7 @@ SELECT * FROM test WHERE e IN ("\xaa", "\xac"); b: 1.0, c: false, d: "a", - e: "\xaa", - f: {a: 1}, - g: [1.0] + e: "\xaa" } { id: 3, @@ -1510,9 +1228,7 @@ SELECT * FROM test WHERE e IN ("\xaa", "\xac"); b: 3.0, c: false, d: "c", - e: "\xac", - f: {a: 3}, - g: [3.0] + e: "\xac" } */ @@ -1525,9 +1241,7 @@ SELECT * FROM test WHERE e NOT IN ("\xaa", "\xac"); b: 2.0, c: true, d: "b", - e: "\xab", - f: {a: 2}, - g: [2.0] + e: "\xab" } { id: 4, @@ -1535,488 +1249,10 @@ SELECT * FROM test WHERE e NOT IN ("\xaa", "\xac"); b: 4.0, c: true, d: "d", - e: "\xad", - f: {a: 4}, - g: [4.0] + e: "\xad" } */ --- test: doc = -SELECT * FROM test WHERE f = {a: 1}; -/* result: - { - id: 1, - a: 10, - b: 1.0, - c: false, - d: "a", - e: "\xaa", - f: {a: 1}, - g: [1.0] - } -*/ --- test: doc != -SELECT * FROM test WHERE f != {a: 1}; -/* result: - { - id: 2, - a: 20, - b: 2.0, - c: true, - d: "b", - e: "\xab", - f: {a: 2}, - g: [2.0] - } - { - id: 3, - a: 30, - b: 3.0, - c: false, - d: "c", - e: "\xac", - f: {a: 3}, - g: [3.0] - } - { - id: 4, - a: 40, - b: 4.0, - c: true, - d: "d", - e: "\xad", - f: {a: 4}, - g: [4.0] - } -*/ --- test: doc > -SELECT * FROM test WHERE f > {a: 1}; -/* result: - { - id: 2, - a: 20, - b: 2.0, - c: true, - d: "b", - e: "\xab", - f: {a: 2}, - g: [2.0] - } - { - id: 3, - a: 30, - b: 3.0, - c: false, - d: "c", - e: "\xac", - f: {a: 3}, - g: [3.0] - } - { - id: 4, - a: 40, - b: 4.0, - c: true, - d: "d", - e: "\xad", - f: {a: 4}, - g: [4.0] - } -*/ --- test: doc >= -SELECT * FROM test WHERE f >= {a: 1}; -/* result: - { - id: 1, - a: 10, - b: 1.0, - c: false, - d: "a", - e: "\xaa", - f: {a: 1}, - g: [1.0] - } - { - id: 2, - a: 20, - b: 2.0, - c: true, - d: "b", - e: "\xab", - f: {a: 2}, - g: [2.0] - } - { - id: 3, - a: 30, - b: 3.0, - c: false, - d: "c", - e: "\xac", - f: {a: 3}, - g: [3.0] - } - { - id: 4, - a: 40, - b: 4.0, - c: true, - d: "d", - e: "\xad", - f: {a: 4}, - g: [4.0] - } -*/ - --- test: doc < -SELECT * FROM test WHERE f < {a: 3}; -/* result: - { - id: 1, - a: 10, - b: 1.0, - c: false, - d: "a", - e: "\xaa", - f: {a: 1}, - g: [1.0] - } - { - id: 2, - a: 20, - b: 2.0, - c: true, - d: "b", - e: "\xab", - f: {a: 2}, - g: [2.0] - } -*/ - --- test: doc <= -SELECT * FROM test WHERE f <= {a: 3}; -/* result: - { - id: 1, - a: 10, - b: 1.0, - c: false, - d: "a", - e: "\xaa", - f: {a: 1}, - g: [1.0] - } - { - id: 2, - a: 20, - b: 2.0, - c: true, - d: "b", - e: "\xab", - f: {a: 2}, - g: [2.0] - } - { - id: 3, - a: 30, - b: 3.0, - c: false, - d: "c", - e: "\xac", - f: {a: 3}, - g: [3.0] - } -*/ - --- test: doc IN -SELECT * FROM test WHERE f IN ({a: 1}, {a: 3}); -/* result: - { - id: 1, - a: 10, - b: 1.0, - c: false, - d: "a", - e: "\xaa", - f: {a: 1}, - g: [1.0] - } - { - id: 3, - a: 30, - b: 3.0, - c: false, - d: "c", - e: "\xac", - f: {a: 3}, - g: [3.0] - } -*/ - --- test: doc NOT IN -SELECT * FROM test WHERE f NOT IN ({a: 1}, {a: 3}); -/* result: - { - id: 2, - a: 20, - b: 2.0, - c: true, - d: "b", - e: "\xab", - f: {a: 2}, - g: [2.0] - } - { - id: 4, - a: 
40, - b: 4.0, - c: true, - d: "d", - e: "\xad", - f: {a: 4}, - g: [4.0] - } -*/ - --- test: array = -SELECT * FROM test WHERE g = [1]; -/* result: - { - id: 1, - a: 10, - b: 1.0, - c: false, - d: "a", - e: "\xaa", - f: {a: 1}, - g: [1.0] - } -*/ - --- test: array != -SELECT * FROM test WHERE g != [1]; -/* result: - { - id: 2, - a: 20, - b: 2.0, - c: true, - d: "b", - e: "\xab", - f: {a: 2}, - g: [2.0] - } - { - id: 3, - a: 30, - b: 3.0, - c: false, - d: "c", - e: "\xac", - f: {a: 3}, - g: [3.0] - } - { - id: 4, - a: 40, - b: 4.0, - c: true, - d: "d", - e: "\xad", - f: {a: 4}, - g: [4.0] - } -*/ - --- test: array > -SELECT * FROM test WHERE g > [1]; -/* result: - { - id: 2, - a: 20, - b: 2.0, - c: true, - d: "b", - e: "\xab", - f: {a: 2}, - g: [2.0] - } - { - id: 3, - a: 30, - b: 3.0, - c: false, - d: "c", - e: "\xac", - f: {a: 3}, - g: [3.0] - } - { - id: 4, - a: 40, - b: 4.0, - c: true, - d: "d", - e: "\xad", - f: {a: 4}, - g: [4.0] - } -*/ - --- test: array >= -SELECT * FROM test WHERE g >= [1]; -/* result: - { - id: 1, - a: 10, - b: 1.0, - c: false, - d: "a", - e: "\xaa", - f: {a: 1}, - g: [1.0] - } - { - id: 2, - a: 20, - b: 2.0, - c: true, - d: "b", - e: "\xab", - f: {a: 2}, - g: [2.0] - } - { - id: 3, - a: 30, - b: 3.0, - c: false, - d: "c", - e: "\xac", - f: {a: 3}, - g: [3.0] - } - { - id: 4, - a: 40, - b: 4.0, - c: true, - d: "d", - e: "\xad", - f: {a: 4}, - g: [4.0] - } -*/ - --- test: array < -SELECT * FROM test WHERE g < [3]; -/* result: - { - id: 1, - a: 10, - b: 1.0, - c: false, - d: "a", - e: "\xaa", - f: {a: 1}, - g: [1.0] - } - { - id: 2, - a: 20, - b: 2.0, - c: true, - d: "b", - e: "\xab", - f: {a: 2}, - g: [2.0] - } -*/ - --- test: array <= -SELECT * FROM test WHERE g <= [3]; -/* result: - { - id: 1, - a: 10, - b: 1.0, - c: false, - d: "a", - e: "\xaa", - f: {a: 1}, - g: [1.0] - } - { - id: 2, - a: 20, - b: 2.0, - c: true, - d: "b", - e: "\xab", - f: {a: 2}, - g: [2.0] - } - { - id: 3, - a: 30, - b: 3.0, - c: false, - d: "c", - e: "\xac", - f: {a: 3}, - g: [3.0] - } -*/ - --- test: array IN -SELECT * FROM test WHERE g IN ([1], [3]); -/* result: - { - id: 1, - a: 10, - b: 1.0, - c: false, - d: "a", - e: "\xaa", - f: {a: 1}, - g: [1.0] - } - { - id: 3, - a: 30, - b: 3.0, - c: false, - d: "c", - e: "\xac", - f: {a: 3}, - g: [3.0] - } -*/ - --- test: array NOT IN -SELECT * FROM test WHERE g NOT IN ([1], [3]); -/* result: - { - id: 2, - a: 20, - b: 2.0, - c: true, - d: "b", - e: "\xab", - f: {a: 2}, - g: [2.0] - } - { - id: 4, - a: 40, - b: 4.0, - c: true, - d: "d", - e: "\xad", - f: {a: 4}, - g: [4.0] - } -*/ \ No newline at end of file diff --git a/internal/sqltests/SELECT/distinct.sql b/internal/sqltests/SELECT/distinct.sql index 0a4155638..987e91cc4 100644 --- a/internal/sqltests/SELECT/distinct.sql +++ b/internal/sqltests/SELECT/distinct.sql @@ -1,12 +1,11 @@ -- setup: -CREATE TABLE test; +CREATE TABLE test(a INT, b TEXT, c bool); INSERT INTO test(a, b, c) VALUES - (1, {d: 1}, [true]), - (1, {d: 2}, [false]), - (1, {d: 2}, []), - (2, {d: 3}, []), - (2, {d: 3}, []), - ([true], 1, 1.5); + (1, 'foo', true), + (1, 'bar', false), + (1, 'bar', NULL), + (2, 'baz', NULL), + (2, 'baz', NULL); -- test: literal SELECT DISTINCT 'a' FROM test; @@ -20,77 +19,62 @@ SELECT DISTINCT 'a' FROM test; SELECT DISTINCT * FROM test; /* result: { - a: 1.0, - b: {d: 1.0}, - c: [true] + a: 1, + b: "bar", + c: null } { - a: 1.0, - b: {d: 2.0}, - c: [] + a: 1, + b: "bar", + c: false } { - a: 1.0, - b: {d: 2.0}, - c: [false] + a: 1, + b: "foo", + c: true } { - a: 2.0, - b: {d: 3.0}, - c: [] -} -{ - 
a: [true], - b: 1.0, - c: 1.5 + a: 2, + b: "baz", + c: null } */ --- test: field path +-- test: column SELECT DISTINCT a FROM test; /* result: { - a: 1.0, -} -{ - a: 2.0, + a: 1, } { - a: [true], + a: 2, } */ --- test: field path +-- test: column SELECT DISTINCT a FROM test; /* result: { - a: 1.0, + a: 1, } { - a: 2.0, -} -{ - a: [true], + a: 2, } */ --- test: multiple field paths -SELECT DISTINCT a, b.d FROM test; +-- test: multiple columns +SELECT DISTINCT a, b FROM test; /* result: { - a: 1.0, - "b.d": 1.0 -} -{ - a: 1.0, - "b.d": 2.0 + a: 1, + b: "bar" } { - a: 2.0, - "b.d": 3.0 + a: 1, + b: "foo" } { - a: [true], - "b.d": NULL + a: 2, + b: "baz" } */ \ No newline at end of file diff --git a/internal/sqltests/SELECT/len.sql b/internal/sqltests/SELECT/len.sql deleted file mode 100644 index f9696c549..000000000 --- a/internal/sqltests/SELECT/len.sql +++ /dev/null @@ -1,155 +0,0 @@ --- setup: -CREATE TABLE foo( - a TEXT, - b ARRAY, - c ( - ... - ) -); -INSERT INTO foo VALUES ( - "hello", - [1, 2, 3, [4, 5]], - { - a: 1, - b: 2, - c: { - d: 3 - } - } -); - --- test: text field -SELECT len(a) FROM foo; -/* result: -{ - "LEN(a)": 5 -} -*/ - --- test: array field -SELECT len(b) FROM foo; -/* result: -{ - "LEN(b)": 4 -} -*/ - --- test: document -SELECT len(c) FROM foo; -/* result: -{ - "LEN(c)": 3 -} -*/ - --- test: subarray -SELECT len(b[3]) FROM foo; -/* result: -{ - "LEN(b[3])": 2 -} -*/ - --- test: subdocument -SELECT len(c.c) FROM foo; -/* result: -{ - "LEN(c.c)": 1 -} -*/ - --- test: text expr -SELECT len("hello, world!"); -/* result: -{ - "LEN(\"hello, world!\")": 13 -} -*/ - --- test: zero text expr -SELECT len(''); -/* result: -{ - "LEN(\"\")": 0 -} -*/ - --- test: array expr -SELECT len([1, 2, 3, 4, 5]); -/* result: -{ - "LEN([1, 2, 3, 4, 5])": 5 -} -*/ - --- test: empty array expr -SELECT len([]); -/* result: -{ - "LEN([])": 0 -} -*/ - --- test: mixed type array expr -SELECT len([1, 2, 3, [1, 2, 3]]); -/* result: -{ - "LEN([1, 2, 3, [1, 2, 3]])": 4 -} -*/ - --- test: document expr -SELECT len({'a': 1, 'b': 2, 'c': 3}); -/* result: -{ - "LEN({a: 1, b: 2, c: 3})": 3 -} -*/ - --- test: empty document expr -SELECT len({}); -/* result: -{ - "LEN({})": 0 -} -*/ - --- test: integer expr -SELECT len(10); -/* result: -{ - "LEN(10)": NULL -} -*/ - --- test: float expr -SELECT len(1.0); -/* result: -{ - "LEN(1.0)": NULL -} -*/ - --- test: NULL expr -SELECT len(NULL); -/* result: -{ - "LEN(NULL)": NULL -} -*/ - --- test: NULL expr -SELECT len(NULL); -/* result: -{ - "LEN(NULL)": NULL -} -*/ - --- test: blob expr -SELECT len('\x323232') as l; -/* result: -{ - "l": NULL -} -*/ diff --git a/internal/sqltests/SELECT/nullable.sql b/internal/sqltests/SELECT/nullable.sql index 98b899998..c70c75ef6 100644 --- a/internal/sqltests/SELECT/nullable.sql +++ b/internal/sqltests/SELECT/nullable.sql @@ -1,34 +1,27 @@ --- test: document with no constraint -CREATE TABLE test (a OBJECT, c int); -INSERT INTO test (a) VALUES ({ b: 1 }), ({ b: 2 }); +-- test: one nullable column +CREATE TABLE test (a INT); +INSERT INTO test (a) VALUES (null), (null); SELECT * FROM test; /* result: { - a: { b: 1.0 }, + a: null } { - a: { b: 2.0 }, + a: null } */ --- test: one nullable column -CREATE TABLE test (a INT); -INSERT INTO test (a) VALUES (null), (null); -SELECT * FROM test; -/* result: -{} -{} -*/ - -- test: first column null CREATE TABLE test (a INT, b INT); INSERT INTO test (b) VALUES (1), (2); SELECT * FROM test; /* result: { + a: null, b: 1 } { + a: null, b: 2 } */ @@ -39,22 +32,11 @@ INSERT INTO test (a) VALUES (1), 
(2); SELECT * FROM test; /* result: { - a: 1 + a: 1, + b: null } { - a: 2 + a: 2, + b: null } */ - --- test: after a document -CREATE TABLE test (a OBJECT, b INT); -INSERT INTO test (a) VALUES ({ c: 1 }), ({ c: 2 }); -SELECT * FROM test; -/* result: -{ - a: { c: 1.0 } -} -{ - a: { c: 2.0 } -} -*/ \ No newline at end of file diff --git a/internal/sqltests/SELECT/objects/fields.sql b/internal/sqltests/SELECT/objects/fields.sql deleted file mode 100644 index aa3901c31..000000000 --- a/internal/sqltests/SELECT/objects/fields.sql +++ /dev/null @@ -1,78 +0,0 @@ --- setup: -CREATE TABLE test( - a TEXT, - b INT, - c BOOL, - d DOUBLE, - e ARRAY, - f OBJECT -); - -INSERT INTO test (a, b, c, d, e, f) VALUES ( - "FOO", - 42, - true, - 42.42, - ["A", "b", "C", "d", "E"], - { - a: "HELLO", - b: "WorlD" - } -); - --- test: TEXT value -SELECT objects.fields(a) FROM test; -/* result: -{ - "objects.fields(a)": NULL -} -*/ - --- test: INT value -SELECT objects.fields(b) FROM test; -/* result: -{ - "objects.fields(b)": NULL -} -*/ - - --- test: BOOL value -SELECT objects.fields(c) FROM test; -/* result: -{ - "objects.fields(c)": NULL -} -*/ - --- test: DOUBLE value -SELECT objects.fields(d) FROM test; -/* result: -{ - "objects.fields(d)": NULL -} -*/ - --- test: ARRAY value -SELECT objects.fields(e) FROM test; -/* result: -{ - "objects.fields(e)": NULL -} -*/ - --- test: OBJECT value -SELECT objects.fields(f) FROM test; -/* result: -{ - "objects.fields(f)": ["a", "b"] -} -*/ - --- test: wildcard -SELECT objects.fields(*) FROM test; -/* result: -{ - "objects.fields(*)": ["a", "b", "c", "d", "e", "f"] -} -*/ \ No newline at end of file diff --git a/internal/sqltests/SELECT/order_by.sql b/internal/sqltests/SELECT/order_by.sql index 2de5753a2..e15d7bad6 100644 --- a/internal/sqltests/SELECT/order_by.sql +++ b/internal/sqltests/SELECT/order_by.sql @@ -29,6 +29,7 @@ SELECT b FROM test ORDER BY a; SELECT * FROM test ORDER BY a; /* result: { + a: null, b: 1.0, } { @@ -79,6 +80,7 @@ SELECT * FROM test ORDER BY a DESC; b: 2.0 } { + a: null, b: 1.0 } */ diff --git a/internal/sqltests/SELECT/order_by_desc_index.sql b/internal/sqltests/SELECT/order_by_desc_index.sql index 78f273dfd..e27c50fc0 100644 --- a/internal/sqltests/SELECT/order_by_desc_index.sql +++ b/internal/sqltests/SELECT/order_by_desc_index.sql @@ -128,7 +128,7 @@ SELECT a, b FROM test WHERE a = 100 ORDER BY b DESC; EXPLAIN SELECT a, b FROM test WHERE a = 100 ORDER BY b DESC; /* result: { - plan: "index.Scan(\"test_a_b_idx\", [{\"min\": [100], \"exact\": true}]) | rows.Project(a, b)" + plan: "index.Scan(\"test_a_b_idx\", [{\"min\": (100), \"exact\": true}]) | rows.Project(a, b)" } */ diff --git a/internal/sqltests/SELECT/order_by_desc_pk_composite.sql b/internal/sqltests/SELECT/order_by_desc_pk_composite.sql index 02ed1746e..455bb8fdf 100644 --- a/internal/sqltests/SELECT/order_by_desc_pk_composite.sql +++ b/internal/sqltests/SELECT/order_by_desc_pk_composite.sql @@ -127,7 +127,7 @@ SELECT a, b FROM test WHERE a = 100 ORDER BY b DESC; EXPLAIN SELECT a, b FROM test WHERE a = 100 ORDER BY b DESC; /* result: { - plan: "table.Scan(\"test\", [{\"min\": [100], \"exact\": true}]) | rows.Project(a, b)" + plan: "table.Scan(\"test\", [{\"min\": (100), \"exact\": true}]) | rows.Project(a, b)" } */ diff --git a/internal/sqltests/SELECT/pk.sql b/internal/sqltests/SELECT/pk.sql deleted file mode 100644 index 47c19f51a..000000000 --- a/internal/sqltests/SELECT/pk.sql +++ /dev/null @@ -1,13 +0,0 @@ --- setup: -CREATE TABLE test; -INSERT INTO test(a) VALUES (1), (2), (3), 
(4), (5); - --- test: wildcard -SELECT pk(), a FROM test; -/* result: -{"pk()": [1], "a": 1.0} -{"pk()": [2], "a": 2.0} -{"pk()": [3], "a": 3.0} -{"pk()": [4], "a": 4.0} -{"pk()": [5], "a": 5.0} -*/ diff --git a/internal/sqltests/SELECT/projection_no_table.sql b/internal/sqltests/SELECT/projection_no_table.sql index 66e07908e..4319ca772 100644 --- a/internal/sqltests/SELECT/projection_no_table.sql +++ b/internal/sqltests/SELECT/projection_no_table.sql @@ -28,12 +28,6 @@ SELECT "'A'"; {`"'A'"`: "'A'"} */ --- test: document -SELECT {a: 1, b: 2 + 1}; -/* result: -{"{a: 1, b: 2 + 1}":{"a":1,"b":3}} -*/ - -- test: aliases SELECT 1 AS A; /* result: @@ -46,13 +40,7 @@ SELECT CAST(1 AS DOUBLE) AS A; {"A": 1.0} */ --- test: pk() -SELECT pk(); -/* result: -{"pk()": null} -*/ - --- test: field +-- test: column SELECT a; -- error: diff --git a/internal/sqltests/SELECT/projection_table.sql b/internal/sqltests/SELECT/projection_table.sql index 82a51b0c0..0eaad22cd 100644 --- a/internal/sqltests/SELECT/projection_table.sql +++ b/internal/sqltests/SELECT/projection_table.sql @@ -1,6 +1,6 @@ -- setup: -CREATE TABLE test(a double, ...); -INSERT INTO test(a, b, c) VALUES (1, {a: 1}, [true]); +CREATE TABLE test(a double, b int, c bool); +INSERT INTO test(a, b, c) VALUES (1, 1, true); -- suite: no index @@ -10,7 +10,7 @@ CREATE INDEX ON test(a); -- test: wildcard SELECT * FROM test; /* result: -{"a": 1.0, "b": {"a": 1.0}, "c": [true]} +{"a": 1.0, "b": 1, "c": true} */ -- test: multiple wildcards @@ -18,43 +18,43 @@ SELECT *, * FROM test; /* result: { "a": 1.0, - "b": {"a": 1.0}, - "c": [true], + "b": 1, + "c": true, "a": 1.0, - "b": {"a": 1.0}, - "c": [true] + "b": 1, + "c": true } */ --- test: field paths +-- test: column paths SELECT a, b, c FROM test; /* result: { "a": 1.0, - "b": {"a": 1.0}, - "c": [true] + "b": 1, + "c": true } */ --- test: field path, wildcards and expressions -SELECT a AS A, b.a + 1, * FROM test; +-- test: column path, wildcards and expressions +SELECT a AS A, b + 1, * FROM test; /* result: { "A": 1.0, - "b.a + 1": 2.0, + "b + 1": 2, "a": 1.0, - "b": {"a": 1.0}, - "c": [true] + "b": 1, + "c": true } */ --- test: wildcard and other field +-- test: wildcard and other column SELECT *, c FROM test; /* result: { "a": 1.0, - "b": {"a": 1.0}, - "c": [true], - "c": [true] + "b": 1, + "c": true, + "c": true } */ diff --git a/internal/sqltests/SELECT/union.sql b/internal/sqltests/SELECT/union.sql index 54bdde846..bac798230 100644 --- a/internal/sqltests/SELECT/union.sql +++ b/internal/sqltests/SELECT/union.sql @@ -1,7 +1,7 @@ -- setup: -CREATE TABLE foo; -CREATE TABLE bar; -CREATE TABLE baz; +CREATE TABLE foo(a DOUBLE, b DOUBLE); +CREATE TABLE bar(a DOUBLE, b DOUBLE); +CREATE TABLE baz(x TEXT, y TEXT); INSERT INTO foo (a,b) VALUES (1.0, 1.0), (2.0, 2.0); INSERT INTO bar (a,b) VALUES (2.0, 2.0), (3.0, 3.0); INSERT INTO baz (x,y) VALUES ("a", "a"), ("b", "b"); diff --git a/internal/sqltests/UPDATE/check.sql b/internal/sqltests/UPDATE/check.sql index f676f5631..bb2c7e954 100644 --- a/internal/sqltests/UPDATE/check.sql +++ b/internal/sqltests/UPDATE/check.sql @@ -1,35 +1,3 @@ --- test: no type constraint, valid double -CREATE TABLE test (a CHECK(a > 10)); -INSERT INTO test (a) VALUES (11); -UPDATE test SET a = 12; -SELECT * FROM test; -/* result: -{ - a: 12.0 -} -*/ - --- test: no type constraint, invalid double -CREATE TABLE test (a CHECK(a > 10)); -INSERT INTO test (a) VALUES (11); -UPDATE test SET a = 1; --- error: row violates check constraint "test_check" - --- test: no type constraint, text 
-CREATE TABLE test (a CHECK(a > 10)); -INSERT INTO test (a) VALUES (11); -UPDATE test SET a = "hello"; --- error: row violates check constraint "test_check" - --- test: no type constraint, null -CREATE TABLE test (a CHECK(a > 10)); -INSERT INTO test (a) VALUES (11); -UPDATE test UNSET a; -SELECT * FROM test; -/* result: -{} -*/ - -- test: int type constraint, double CREATE TABLE test (a int CHECK(a > 10)); INSERT INTO test (a) VALUES (11); diff --git a/internal/sqltests/UPDATE/pk.sql b/internal/sqltests/UPDATE/pk.sql index d29ba3649..a58609661 100644 --- a/internal/sqltests/UPDATE/pk.sql +++ b/internal/sqltests/UPDATE/pk.sql @@ -3,10 +3,10 @@ CREATE TABLE test (a int primary key, b int); INSERT INTO test (a, b) VALUES (1, 10); UPDATE test SET a = 2, b = 20 WHERE a = 1; INSERT INTO test (a, b) VALUES (1, 10); -SELECT pk(), * FROM test; +SELECT * FROM test; /* result: -{"pk()": [1], a: 1, b: 10} -{"pk()": [2], a: 2, b: 20} +{a: 1, b: 10} +{a: 2, b: 20} */ -- test: set primary key / conflict @@ -20,20 +20,8 @@ CREATE TABLE test (a int, b int, c int, PRIMARY KEY(a, b)); INSERT INTO test (a, b, c) VALUES (1, 10, 100); UPDATE test SET a = 2, b = 20 WHERE a = 1; INSERT INTO test (a, b, c) VALUES (1, 10, 100); -SELECT pk(), * FROM test; +SELECT * FROM test; /* result: -{"pk()": [1, 10], a: 1, b: 10, c: 100} -{"pk()": [2, 20], a: 2, b: 20, c: 100} +{a: 1, b: 10, c: 100} +{a: 2, b: 20, c: 100} */ - --- test: unset primary key -CREATE TABLE test (a int primary key, b int); -INSERT INTO test (a, b) VALUES (1, 10); -UPDATE test UNSET a WHERE a = 1; --- error: cannot unset primary key path - --- test: unset composite primary key -CREATE TABLE test (a int, b int, c int, PRIMARY KEY(a, b)); -INSERT INTO test (a, b, c) VALUES (1, 10, 100); -UPDATE test UNSET b WHERE a = 1; --- error: cannot unset primary key path diff --git a/internal/sqltests/expr/arithmetic.sql b/internal/sqltests/expr/arithmetic.sql index 117d56e10..8d5628656 100644 --- a/internal/sqltests/expr/arithmetic.sql +++ b/internal/sqltests/expr/arithmetic.sql @@ -48,33 +48,33 @@ NULL NULL -- test: divide by zero -> 1 / 0 -NULL +! 1 / 0 +'division by zero' --- test: arithmetic with unexisting field +-- test: arithmetic with unexisting column ! 1 + a -'field not found' +'no table specified' ! 1 - a -'field not found' +'no table specified' ! 1 * a -'field not found' +'no table specified' ! 1 / a -'field not found' +'no table specified' ! 1 % a -'field not found' +'no table specified' ! 1 & a -'field not found' +'no table specified' ! 1 | a -'field not found' +'no table specified' ! 1 ^ a -'field not found' +'no table specified' -- test: division > 1 / 2 @@ -93,23 +93,9 @@ NULL > 1 + true NULL -> 1 + [1] -NULL - -> 1 + {a: 1} -NULL - -> [1] + [1] -NULL - -> {a: 1} + {a: 1} -NULL - > 4.5 + 4.5 9.0 -> 1000000000 * 1000000000 -1000000000000000000 +! 1000000000 * 1000000000 -> 1000000000000000000 * 1000000000000000000 * 1000000000000000000 -1000000000000000000000000000000000000000000000000000000 +! 1000000000000000000 * 1000000000000000000 * 1000000000000000000 diff --git a/internal/sqltests/expr/cast.sql b/internal/sqltests/expr/cast.sql index d05a51875..b8e4282c5 100644 --- a/internal/sqltests/expr/cast.sql +++ b/internal/sqltests/expr/cast.sql @@ -14,12 +14,6 @@ true ! CAST (1 AS BLOB) 'cannot cast integer as blob' -! CAST (1 AS ARRAY) -'cannot cast integer as array' - -! CAST (1 AS OBJECT) -'cannot cast integer as object' - -- test: source(DOUBLE) > CAST (1.1 AS DOUBLE) 1.1 @@ -36,12 +30,6 @@ true ! 
CAST (1.1 AS BLOB) 'cannot cast double as blob' -! CAST (1.1 AS ARRAY) -'cannot cast double as array' - -! CAST (1.1 AS OBJECT) -'cannot cast double as object' - -- test: source(BOOL) > CAST (true AS BOOL) true @@ -61,12 +49,6 @@ true ! CAST (true AS BLOB) 'cannot cast boolean as blob' -! CAST (true AS ARRAY) -'cannot cast boolean as array' - -! CAST (true AS OBJECT) -'cannot cast boolean as object' - -- test: source(TEXT) > CAST ('a' AS TEXT) 'a' @@ -98,19 +80,6 @@ false > CAST ('YXNkaW5l' AS BLOB) '\x617364696e65' -> CAST ('[]' AS ARRAY) -[] - -> CAST ('[1, true, [], {" a": 1}]' AS ARRAY) -[1, true, [], {" a": 1}] - -! CAST ('[1, true, [], {" a": 1}' AS ARRAY) - -> CAST ('{"a": 1, "b": [1, true, [], {" a": 1}]}' AS OBJECT) -{"a": 1, "b": [1, true, [], {" a": 1}]} - -! CAST ('{"a": 1' AS OBJECT) - -- test: source(BLOB) > CAST ('\xAF' AS BLOB) '\xAF' @@ -123,47 +92,3 @@ false > CAST ('\x617364696e65' AS TEXT) 'YXNkaW5l' - -! CAST ('\xAF' AS ARRAY) -'cannot cast blob as array' - -! CAST ('\xAF' AS OBJECT) -'cannot cast blob as object' - --- test: source(ARRAY) -> CAST ([1] AS ARRAY) -[1] - -! CAST ([1] AS INTEGER) -'cannot cast array as integer' - -! CAST ([1] AS DOUBLE) -'cannot cast array as double' - -> CAST ([1, true, [], {" a": 1}] AS TEXT) -'[1, true, [], {" a": 1}]' - -! CAST ([1] AS BLOB) -'cannot cast array as blob' - -! CAST ([1] AS OBJECT) -'cannot cast array as object' - --- test: source(OBJECT) -> CAST ({a: 1} AS OBJECT) -{a: 1} - -! CAST ({a: 1} AS INTEGER) -'cannot cast object as integer' - -! CAST ({a: 1} AS DOUBLE) -'cannot cast object as double' - -> CAST ({"a": 1, "b": [1, true, [], {" a": 1}]} AS TEXT) -'{"a": 1, "b": [1, true, [], {" a": 1}]}' - -! CAST ({a: 1} AS BLOB) -'cannot cast object as blob' - -! CAST ({a: 1} AS ARRAY) -'cannot cast object as array' \ No newline at end of file diff --git a/internal/sqltests/expr/literal.sql b/internal/sqltests/expr/literal.sql index 88cac04cf..9d152ceab 100644 --- a/internal/sqltests/expr/literal.sql +++ b/internal/sqltests/expr/literal.sql @@ -11,6 +11,19 @@ > typeof(-1) 'integer' +-- test: literals/bigints +> 100000000000 +100000000000 + +> typeof(100000000000) +'bigint' + +> -100000000000 +-100000000000 + +> typeof(-100000000000) +'bigint' + -- test: literals/doubles > 1.0 @@ -82,28 +95,3 @@ false ! 
'\xhello' 'invalid hexadecimal digit: h' - --- test: literals/arrays - -> [1, true, ['hello'], {a: [1]}] -[1, true, ['hello'], {a: [1]}] - -> typeof([1, true, ['hello'], {a: [1]}]) -'array' - --- test: literals/objects - -> {a: 1} -{a: 1} - -> {"a": 1} -{a: 1} - -> {'a': 1} -{a: 1} - -> {a: 1, b: {c: [1, true, ['hello'], {a: [1]}]}} -{a: 1, b: {c: [1, true, ['hello'], {a: [1]}]}} - -> typeof({a: 1, b: {c: [1, true, ['hello'], {a: [1]}]}}) -'object' \ No newline at end of file diff --git a/internal/sqltests/expr/objects.sql b/internal/sqltests/expr/objects.sql deleted file mode 100644 index 6a07cc613..000000000 --- a/internal/sqltests/expr/objects.sql +++ /dev/null @@ -1,35 +0,0 @@ --- test: objects.fields -> objects.fields({}) -[] - -> objects.fields({a: 1}) -['a'] - -> objects.fields({a: 1, b: {c: 2}}) -['a', 'b'] - - -> objects.fields(NULL) -NULL - -> objects.fields(true) -NULL - -> objects.fields(false) -NULL - -> objects.fields(1) -NULL - -> objects.fields(1.0) -NULL - -> objects.fields('hello') -NULL - -> objects.fields('\xAA') -NULL - -> objects.fields([]) -NULL - diff --git a/internal/sqltests/planning/between.sql b/internal/sqltests/planning/between.sql index 5c77d6bac..d154c07c6 100644 --- a/internal/sqltests/planning/between.sql +++ b/internal/sqltests/planning/between.sql @@ -3,7 +3,7 @@ CREATE TABLE test(a int UNIQUE); EXPLAIN SELECT * FROM test WHERE a BETWEEN 1 AND 2; /* result: { - "plan": 'index.Scan("test_a_idx", [{"min": [1], "max": [2]}])' + "plan": 'index.Scan("test_a_idx", [{"min": (1), "max": (2)}])' } */ @@ -13,7 +13,7 @@ CREATE INDEX on test(a, b); EXPLAIN SELECT * FROM test WHERE a BETWEEN 1 AND 2 AND b BETWEEN 3 AND 4; /* result: { - "plan": 'index.Scan("test_a_b_idx", [{"min": [1], "max": [2]}]) | rows.Filter(b BETWEEN 3 AND 4)' + "plan": 'index.Scan("test_a_b_idx", [{"min": (1), "max": (2)}]) | rows.Filter(b BETWEEN 3 AND 4)' } */ @@ -23,6 +23,6 @@ CREATE INDEX on test(a, b, c, d); EXPLAIN SELECT * FROM test WHERE a = 1 AND b = 10 AND c = 100 AND d BETWEEN 1000 AND 2000 AND e > 10000; /* result: { - "plan": 'index.Scan("test_a_b_c_d_idx", [{"min": [1, 10, 100, 1000], "max": [1, 10, 100, 2000]}]) | rows.Filter(e > 10000)' + "plan": 'index.Scan("test_a_b_c_d_idx", [{"min": (1, 10, 100, 1000), "max": (1, 10, 100, 2000)}]) | rows.Filter(e > 10000)' } */ diff --git a/internal/sqltests/planning/merge.gosave b/internal/sqltests/planning/merge.gosave deleted file mode 100644 index 8b9a72fc9..000000000 --- a/internal/sqltests/planning/merge.gosave +++ /dev/null @@ -1,146 +0,0 @@ -// // MergeFilterNodes merges any two filter nodes that are related to -// // the same path and can constitute a single filter node or a BETWEEN operation. -// // It also detects invalid sequences of filter nodes that can never return any result. 
-// // Example: -// // objs.Filter(a > 2) | objs.Filter(a < 5) -// // -> objs.Filter(a BETWEEN 2 AND 5) -// // objs.Filter(a > 2) | objs.Filter(a < 5) | objs.Filter(a = 3) -// // -> objs.Filter(a = 3) -// // objs.Filter(a > 2) | objs.Filter(a >= 5) -// // -> objs.Filter(a >= 5) -// func MergeFilterNodes(sctx *StreamContext) error { -// type selected struct { -// path document.Path -// op scanner.Token -// operand types.Value -// exclusive bool -// f *stream.DocsFilterOperator -// } - -// m := make(map[string][]selected) - -// // build a map grouping the filters by their path -// for _, f := range sctx.Filters { -// switch t := f.Expr.(type) { -// case expr.Operator: -// if !operatorIsIndexCompatible(t) { -// continue -// } - -// // check if the filter expression is in the form: -// // 'path OP value' OR 'value OP path' -// path, operand := getPathAndOperandFromOp(t) -// if path == nil { -// continue -// } - -// // check if the operand is a literal value -// lv, ok := operand.(expr.LiteralValue) -// if !ok { -// continue -// } - -// m[path.String()] = append(m[path.String()], selected{ -// path: path, -// op: t.Token(), -// operand: lv.Value, -// exclusive: t.Token() == scanner.GT || t.Token() == scanner.LT, -// f: f, -// }) -// } -// } - -// // merge the filters that are related to the same path -// for _, v := range m { -// if len(v) == 1 { -// continue -// } - -// // ensure the operands are all the same type. -// // the only exception is if the operands are both numbers -// for i := 1; i < len(v); i++ { -// if v[i].operand.Type() != v[0].operand.Type() && !(v[i].operand.Type().IsNumber() && v[i].operand.Type().IsNumber()) { -// // return an empty stream if the operands are not the same type -// sctx.Stream = new(stream.Stream) -// return nil -// } -// } - -// // analyse the filters to determine the lower and upper bounds -// var lower, upper types.Value -// var exclusiveLower, exclusiveUpper bool -// for i, s := range v { -// switch s.op { -// case scanner.GT, scanner.GTE: -// if lower == nil { -// lower = &v[i] -// continue -// } - -// // keep the highest lower bound -// ok, err := types.IsGreaterThan(s.operand, lower.operand) -// if err != nil { -// return err -// } -// if ok { -// // remove the previous lower bound -// sctx.removeFilterNode(lower.f) -// lower = &v[i] -// continue -// } - -// // in case they are equal, and one of them is exclusive, keep that one -// ok, err := types.IsEqual(s.operand, lower.operand) -// if err != nil { -// return err -// } - -// if s.op == scanner.GT && (lower.op == scanner.GTE || lower.op == scanner.BETWEEN) { -// // if they are equal and the operator is GT and the previous lower bound is a GTE or a BETWEEN, -// // we keep the GT -// ok, err := types.IsEqual(s.operand, lower.operand) -// if err != nil { -// return err -// } - -// if ok { -// // remove the previous lower bound -// sctx.removeFilterNode(lower.f) -// lower = &v[i] -// } - -// // remove the filter node -// sctx.removeFilterNode(s.f) -// } else { -// } -// // remove the filter node -// sctx.removeFilterNode(s.f) -// } -// } -// } - -// return nil -// } - -// func getPathAndOperandFromOp(op expr.Operator) (document.Path, expr.Expr) { -// if op.Token() == scanner.BETWEEN { -// xf, xIsPath := op.(*expr.BetweenOperator).X.(expr.Path) -// if !xIsPath { -// return nil, nil -// } - -// return document.Path(xf), expr.LiteralExprList{op.LeftHand(), op.RightHand()} -// } - -// lf, leftIsPath := op.LeftHand().(expr.Path) -// rf, rightIsPath := op.RightHand().(expr.Path) - -// if !leftIsPath && 
!rightIsPath { -// return nil, nil -// } - -// if leftIsPath { -// return document.Path(lf), op.RightHand() -// } -// return document.Path(rf), op.LeftHand() -// } diff --git a/internal/sqltests/planning/order_by.sql b/internal/sqltests/planning/order_by.sql index 19d48f6a9..8fd4c9b50 100644 --- a/internal/sqltests/planning/order_by.sql +++ b/internal/sqltests/planning/order_by.sql @@ -14,7 +14,7 @@ VALUES (4, 4, 4), (5, 5, 5); --- test: non-indexed field path, ASC +-- test: non-indexed column path, ASC EXPLAIN SELECT * FROM test ORDER BY c; /* result: { @@ -22,7 +22,7 @@ EXPLAIN SELECT * FROM test ORDER BY c; } */ --- test: non-indexed field path, DESC +-- test: non-indexed column path, DESC EXPLAIN SELECT * FROM test ORDER BY c DESC; /* result: { @@ -30,7 +30,7 @@ EXPLAIN SELECT * FROM test ORDER BY c DESC; } */ --- test: indexed field path, ASC +-- test: indexed column path, ASC EXPLAIN SELECT * FROM test ORDER BY a; /* result: { @@ -38,7 +38,7 @@ EXPLAIN SELECT * FROM test ORDER BY a; } */ --- test: indexed field path, DESC +-- test: indexed column path, DESC EXPLAIN SELECT * FROM test ORDER BY a DESC; /* result: { diff --git a/internal/sqltests/planning/order_by_composite.sql b/internal/sqltests/planning/order_by_composite.sql index 9ba2a76df..0bb84fe22 100644 --- a/internal/sqltests/planning/order_by_composite.sql +++ b/internal/sqltests/planning/order_by_composite.sql @@ -12,7 +12,7 @@ VALUES (4, 4, 4), (5, 5, 5); --- test: non-indexed field path, ASC +-- test: non-indexed column path, ASC EXPLAIN SELECT * FROM test ORDER BY c; /* result: { @@ -20,7 +20,7 @@ EXPLAIN SELECT * FROM test ORDER BY c; } */ --- test: non-indexed field path, DESC +-- test: non-indexed column path, DESC EXPLAIN SELECT * FROM test ORDER BY c DESC; /* result: { @@ -28,7 +28,7 @@ EXPLAIN SELECT * FROM test ORDER BY c DESC; } */ --- test: indexed field path, ASC +-- test: indexed column path, ASC EXPLAIN SELECT * FROM test ORDER BY a; /* result: { @@ -36,7 +36,7 @@ EXPLAIN SELECT * FROM test ORDER BY a; } */ --- test: indexed field path, DESC +-- test: indexed column path, DESC EXPLAIN SELECT * FROM test ORDER BY a DESC; /* result: { @@ -44,7 +44,7 @@ EXPLAIN SELECT * FROM test ORDER BY a DESC; } */ --- test: indexed field path in second position, ASC +-- test: indexed column path in second position, ASC EXPLAIN SELECT * FROM test ORDER BY b; /* result: { @@ -52,7 +52,7 @@ EXPLAIN SELECT * FROM test ORDER BY b; } */ --- test: indexed field path in second position, DESC +-- test: indexed column path in second position, DESC EXPLAIN SELECT * FROM test ORDER BY b DESC; /* result: { @@ -64,7 +64,7 @@ EXPLAIN SELECT * FROM test ORDER BY b DESC; EXPLAIN SELECT * FROM test WHERE a > 10 ORDER BY b DESC; /* result: { - "plan": 'index.Scan("test_a_b", [{"min": [10], "exclusive": true}]) | rows.TempTreeSortReverse(b)' + "plan": 'index.Scan("test_a_b", [{"min": (10), "exclusive": true}]) | rows.TempTreeSortReverse(b)' } */ @@ -72,7 +72,7 @@ EXPLAIN SELECT * FROM test WHERE a > 10 ORDER BY b DESC; EXPLAIN SELECT * FROM test WHERE a = 10 ORDER BY b DESC; /* result: { - "plan": 'index.ScanReverse("test_a_b", [{"min": [10], "exact": true}])' + "plan": 'index.ScanReverse("test_a_b", [{"min": (10), "exact": true}])' } */ diff --git a/internal/sqltests/planning/precalculate.sql b/internal/sqltests/planning/precalculate.sql index 29c4b145c..da93735dd 100644 --- a/internal/sqltests/planning/precalculate.sql +++ b/internal/sqltests/planning/precalculate.sql @@ -1,5 +1,5 @@ -- setup: -CREATE table test; +CREATE table test(a int); -- 
test: precalculate constant EXPLAIN SELECT * FROM test WHERE 3 + 4 > a + 3 % 2; diff --git a/internal/sqltests/planning/where.sql b/internal/sqltests/planning/where.sql index 9680e1ecf..2c71e881d 100644 --- a/internal/sqltests/planning/where.sql +++ b/internal/sqltests/planning/where.sql @@ -18,7 +18,7 @@ VALUES EXPLAIN SELECT * FROM test WHERE a = 10 AND b = 5; /* result: { - "plan": 'index.Scan("test_a", [{"min": [10], "exact": true}]) | rows.Filter(b = 5)' + "plan": 'index.Scan("test_a", [{"min": (10), "exact": true}]) | rows.Filter(b = 5)' } */ @@ -26,7 +26,7 @@ EXPLAIN SELECT * FROM test WHERE a = 10 AND b = 5; EXPLAIN SELECT * FROM test WHERE a > 10 AND b = 5; /* result: { - "plan": 'index.Scan("test_b", [{"min": [5], "exact": true}]) | rows.Filter(a > 10)' + "plan": 'index.Scan("test_b", [{"min": (5), "exact": true}]) | rows.Filter(a > 10)' } */ @@ -34,7 +34,7 @@ EXPLAIN SELECT * FROM test WHERE a > 10 AND b = 5; EXPLAIN SELECT * FROM test WHERE a > 10 AND b > 5; /* result: { - "plan": 'index.Scan("test_a", [{"min": [10], "exclusive": true}]) | rows.Filter(b > 5)' + "plan": 'index.Scan("test_a", [{"min": (10), "exclusive": true}]) | rows.Filter(b > 5)' } */ @@ -42,7 +42,7 @@ EXPLAIN SELECT * FROM test WHERE a > 10 AND b > 5; EXPLAIN SELECT * FROM test WHERE a >= 10 AND b > 5; /* result: { - "plan": 'index.Scan("test_a", [{"min": [10]}]) | rows.Filter(b > 5)' + "plan": 'index.Scan("test_a", [{"min": (10)}]) | rows.Filter(b > 5)' } */ @@ -50,7 +50,7 @@ EXPLAIN SELECT * FROM test WHERE a >= 10 AND b > 5; EXPLAIN SELECT * FROM test WHERE a < 10 AND b > 5; /* result: { - "plan": 'index.Scan("test_a", [{"max": [10], "exclusive": true}]) | rows.Filter(b > 5)' + "plan": 'index.Scan("test_a", [{"max": (10), "exclusive": true}]) | rows.Filter(b > 5)' } */ @@ -58,7 +58,7 @@ EXPLAIN SELECT * FROM test WHERE a < 10 AND b > 5; EXPLAIN SELECT * FROM test WHERE a BETWEEN 4 AND 5 AND b > 5; /* result: { - "plan": 'index.Scan("test_a", [{"min": [4], "max": [5]}]) | rows.Filter(b > 5)' + "plan": 'index.Scan("test_a", [{"min": (4), "max": (5)}]) | rows.Filter(b > 5)' } */ @@ -82,6 +82,6 @@ EXPLAIN SELECT * FROM test WHERE a + 1 < b; EXPLAIN SELECT * FROM test WHERE a IN (1, b + 3); /* result: { - "plan": 'table.Scan("test") | rows.Filter(a IN [1, b + 3])' + "plan": 'table.Scan("test") | rows.Filter(a IN (1, b + 3))' } */ \ No newline at end of file diff --git a/internal/sqltests/planning/where_pk.sql b/internal/sqltests/planning/where_pk.sql index c00f980e5..4cb08d1aa 100644 --- a/internal/sqltests/planning/where_pk.sql +++ b/internal/sqltests/planning/where_pk.sql @@ -14,7 +14,7 @@ VALUES EXPLAIN SELECT * FROM test WHERE a = 10 AND b = 5; /* result: { - "plan": 'table.Scan("test", [{\"min\": [10, 5], \"exact\": true}])' + "plan": 'table.Scan("test", [{\"min\": (10, 5), \"exact\": true}])' } */ @@ -22,7 +22,7 @@ EXPLAIN SELECT * FROM test WHERE a = 10 AND b = 5; EXPLAIN SELECT * FROM test WHERE a > 10 AND b = 5; /* result: { - "plan": 'table.Scan(\"test\", [{"min": [10], "exclusive": true}]) | rows.Filter(b = 5)' + "plan": 'table.Scan(\"test\", [{"min": (10), "exclusive": true}]) | rows.Filter(b = 5)' } */ @@ -30,7 +30,7 @@ EXPLAIN SELECT * FROM test WHERE a > 10 AND b = 5; EXPLAIN SELECT * FROM test WHERE a > 10 AND b > 5; /* result: { - "plan": 'table.Scan("test", [{"min": [10], "exclusive": true}]) | rows.Filter(b > 5)' + "plan": 'table.Scan("test", [{"min": (10), "exclusive": true}]) | rows.Filter(b > 5)' } */ @@ -38,7 +38,7 @@ EXPLAIN SELECT * FROM test WHERE a > 10 AND b > 5; EXPLAIN SELECT * 
FROM test WHERE a >= 10 AND b > 5; /* result: { - "plan": 'table.Scan("test", [{"min": [10]}]) | rows.Filter(b > 5)' + "plan": 'table.Scan("test", [{"min": (10)}]) | rows.Filter(b > 5)' } */ @@ -46,6 +46,6 @@ EXPLAIN SELECT * FROM test WHERE a >= 10 AND b > 5; EXPLAIN SELECT * FROM test WHERE a < 10 AND b > 5; /* result: { - "plan": 'table.Scan("test", [{"max": [10], "exclusive": true}]) | rows.Filter(b > 5)' + "plan": 'table.Scan("test", [{"max": (10), "exclusive": true}]) | rows.Filter(b > 5)' } */ diff --git a/internal/sqltests/sql_test.go b/internal/sqltests/sql_test.go index e574bee22..883136884 100644 --- a/internal/sqltests/sql_test.go +++ b/internal/sqltests/sql_test.go @@ -133,10 +133,10 @@ func TestSQL(t *testing.T) { } } else { res, err := db.Query(test.Expr) - assert.NoError(t, err) + require.NoError(t, err, "Source: %s:%d", absPath, test.Line) defer res.Close() - testutil.RequireStreamEqf(t, test.Result, res, test.Sorted, "Source: %s:%d", absPath, test.Line) + testutil.RequireStreamEqf(t, test.Result, res, "Source: %s:%d", absPath, test.Line) } }) } @@ -157,7 +157,6 @@ type test struct { Result string ErrorMatch string Fails bool - Sorted bool Line int Only bool } @@ -243,9 +242,6 @@ func parse(r io.Reader, filename string) *testSuite { } case strings.HasPrefix(line, "/* result:"), strings.HasPrefix(line, "/*result:"): readingResult = true - case strings.HasPrefix(line, "/* sorted-result:"): - readingResult = true - curTest.Sorted = true case strings.HasPrefix(line, "-- error:"): error := strings.TrimPrefix(line, "-- error:") error = strings.TrimSpace(error) diff --git a/internal/stream/index/delete.go b/internal/stream/index/delete.go index 747965bce..81e02bd79 100644 --- a/internal/stream/index/delete.go +++ b/internal/stream/index/delete.go @@ -41,7 +41,7 @@ func (op *DeleteOperator) Iterate(in *environment.Environment, fn func(out *envi } return op.Prev.Iterate(in, func(out *environment.Environment) error { - row, ok := out.GetRow() + row, ok := out.GetDatabaseRow() if !ok { return errors.New("missing row") } @@ -51,9 +51,9 @@ func (op *DeleteOperator) Iterate(in *environment.Environment, fn func(out *envi return err } - vs := make([]types.Value, 0, len(info.Paths)) - for _, path := range info.Paths { - v, err := path.GetValueFromObject(old.Object()) + vs := make([]types.Value, 0, len(info.Columns)) + for _, column := range info.Columns { + v, err := old.Get(column) if err != nil { v = types.NewNullValue() } diff --git a/internal/stream/index/insert.go b/internal/stream/index/insert.go index 66c5283e2..0804e4b52 100644 --- a/internal/stream/index/insert.go +++ b/internal/stream/index/insert.go @@ -9,7 +9,7 @@ import ( "github.com/cockroachdb/errors" ) -// InsertOperator reads the input stream and indexes each object. +// InsertOperator reads the input stream and indexes each row. 
type InsertOperator struct { stream.BaseOperator @@ -41,14 +41,14 @@ func (op *InsertOperator) Iterate(in *environment.Environment, fn func(out *envi } return op.Prev.Iterate(in, func(out *environment.Environment) error { - r, ok := out.GetRow() + r, ok := out.GetDatabaseRow() if !ok { return errors.New("missing row") } - vs := make([]types.Value, 0, len(info.Paths)) - for _, path := range info.Paths { - v, err := path.GetValueFromObject(r.Object()) + vs := make([]types.Value, 0, len(info.Columns)) + for _, column := range info.Columns { + v, err := r.Get(column) if err != nil { v = types.NewNullValue() } diff --git a/internal/stream/index/scan.go b/internal/stream/index/scan.go index 5767eeb29..87a0ff5e3 100644 --- a/internal/stream/index/scan.go +++ b/internal/stream/index/scan.go @@ -76,7 +76,7 @@ func (it *ScanOperator) Iterate(in *environment.Environment, fn func(out *enviro } for _, rng := range ranges { - r, err := rng.ToTreeRange(&table.Info.FieldConstraints, info.Paths) + r, err := rng.ToTreeRange(&table.Info.ColumnConstraints, info.Columns) if err != nil { return err } diff --git a/internal/stream/index/scan_test.go b/internal/stream/index/scan_test.go index 4843301a0..788dcfb3f 100644 --- a/internal/stream/index/scan_test.go +++ b/internal/stream/index/scan_test.go @@ -5,7 +5,7 @@ import ( "github.com/chaisql/chai/internal/database" "github.com/chaisql/chai/internal/environment" - "github.com/chaisql/chai/internal/object" + "github.com/chaisql/chai/internal/row" "github.com/chaisql/chai/internal/stream" "github.com/chaisql/chai/internal/stream/index" "github.com/chaisql/chai/internal/testutil" @@ -27,31 +27,31 @@ func TestIndexScan(t *testing.T) { t.Run("String", func(t *testing.T) { t.Run("idx_test_a", func(t *testing.T) { - require.Equal(t, `index.Scan("idx_test_a", [{"min": [1], "max": [2]}])`, index.Scan("idx_test_a", stream.Range{ - Min: testutil.ExprList(t, `[1]`), Max: testutil.ExprList(t, `[2]`), + require.Equal(t, `index.Scan("idx_test_a", [{"min": (1), "max": (2)}])`, index.Scan("idx_test_a", stream.Range{ + Min: testutil.ExprList(t, `(1)`), Max: testutil.ExprList(t, `(2)`), }).String()) op := index.Scan("idx_test_a", stream.Range{ - Min: testutil.ExprList(t, `[1]`), Max: testutil.ExprList(t, `[2]`), + Min: testutil.ExprList(t, `(1)`), Max: testutil.ExprList(t, `(2)`), }) op.Reverse = true - require.Equal(t, `index.ScanReverse("idx_test_a", [{"min": [1], "max": [2]}])`, op.String()) + require.Equal(t, `index.ScanReverse("idx_test_a", [{"min": (1), "max": (2)}])`, op.String()) }) t.Run("idx_test_a_b", func(t *testing.T) { - require.Equal(t, `index.Scan("idx_test_a_b", [{"min": [1, 1], "max": [2, 2]}])`, index.Scan("idx_test_a_b", stream.Range{ - Min: testutil.ExprList(t, `[1, 1]`), - Max: testutil.ExprList(t, `[2, 2]`), + require.Equal(t, `index.Scan("idx_test_a_b", [{"min": (1, 1), "max": (2, 2)}])`, index.Scan("idx_test_a_b", stream.Range{ + Min: testutil.ExprList(t, `(1, 1)`), + Max: testutil.ExprList(t, `(2, 2)`), }).String()) op := index.Scan("idx_test_a_b", stream.Range{ - Min: testutil.ExprList(t, `[1, 1]`), - Max: testutil.ExprList(t, `[2, 2]`), + Min: testutil.ExprList(t, `(1, 1)`), + Max: testutil.ExprList(t, `(2, 2)`), }) op.Reverse = true - require.Equal(t, `index.ScanReverse("idx_test_a_b", [{"min": [1, 1], "max": [2, 2]}])`, op.String()) + require.Equal(t, `index.ScanReverse("idx_test_a_b", [{"min": (1, 1), "max": (2, 2)}])`, op.String()) }) }) } @@ -60,7 +60,7 @@ func testIndexScan(t *testing.T, getOp func(db *database.Database, tx *database. 
tests := []struct { name string indexOn string - docsInTable, expected testutil.Objs + docsInTable, expected testutil.Rows ranges stream.Ranges reverse bool fails bool @@ -68,338 +68,338 @@ func testIndexScan(t *testing.T, getOp func(db *database.Database, tx *database. {name: "empty", indexOn: "a"}, { "no range", "a", - testutil.MakeObjects(t, `{"a": 1}`, `{"a": 2}`), - testutil.MakeObjects(t, `{"a": 1}`, `{"a": 2}`), + testutil.MakeRows(t, `{"a": 1}`, `{"a": 2}`), + testutil.MakeRows(t, `{"a": 1, "b": null, "c": null}`, `{"a": 2, "b": null, "c": null}`), nil, false, false, }, { "no range", "a, b", - testutil.MakeObjects(t, `{"a": 1, "b": 2}`, `{"a": 2, "b": 3}`), - testutil.MakeObjects(t, `{"a": 1, "b": 2}`, `{"a": 2, "b": 3}`), + testutil.MakeRows(t, `{"a": 1, "b": 2}`, `{"a": 2, "b": 3}`), + testutil.MakeRows(t, `{"a": 1, "b": 2, "c": null}`, `{"a": 2, "b": 3, "c": null}`), nil, false, false, }, { "max:2", "a", - testutil.MakeObjects(t, `{"a": 1}`, `{"a": 2}`), - testutil.MakeObjects(t, `{"a": 1}`, `{"a": 2}`), + testutil.MakeRows(t, `{"a": 1}`, `{"a": 2}`), + testutil.MakeRows(t, `{"a": 1, "b": null, "c": null}`, `{"a": 2, "b": null, "c": null}`), stream.Ranges{ - stream.Range{Max: testutil.ExprList(t, `[2]`), Paths: []object.Path{testutil.ParseObjectPath(t, "a")}}, + stream.Range{Max: testutil.ExprList(t, `(2)`), Columns: []string{"a"}}, }, false, false, }, { "max:1.2", "a", - testutil.MakeObjects(t, `{"a": 1}`, `{"a": 2}`), - testutil.MakeObjects(t, `{"a": 1}`), + testutil.MakeRows(t, `{"a": 1}`, `{"a": 2}`), + nil, stream.Ranges{ - stream.Range{Max: testutil.ExprList(t, `[1.2]`), Paths: []object.Path{testutil.ParseObjectPath(t, "a")}}, + stream.Range{Max: testutil.ExprList(t, `(1.2)`), Columns: []string{"a"}}, }, false, false, }, { - "max:[2, 2]", "a, b", - testutil.MakeObjects(t, `{"a": 1, "b": 2}`, `{"a": 2, "b": 2}`), - testutil.MakeObjects(t, `{"a": 2, "b": 2}`), + "max:(2, 2)", "a, b", + testutil.MakeRows(t, `{"a": 1, "b": 2}`, `{"a": 2, "b": 2}`), + testutil.MakeRows(t, `{"a": 2, "b": 2, "c": null}`), stream.Ranges{ - stream.Range{Max: testutil.ExprList(t, `[2, 2]`), Paths: testutil.ParseObjectPaths(t, "a", "b")}, + stream.Range{Max: testutil.ExprList(t, `(2, 2)`), Columns: []string{"a", "b"}}, }, false, false, }, { - "max:[2, 2.2]", "a, b", - testutil.MakeObjects(t, `{"a": 1, "b": 2}`, `{"a": 2, "b": 2}`), - testutil.MakeObjects(t, `{"a": 2, "b": 2}`), + "max:(2, 2.2)", "a, b", + testutil.MakeRows(t, `{"a": 1, "b": 2}`, `{"a": 2, "b": 2}`), + nil, stream.Ranges{ - stream.Range{Max: testutil.ExprList(t, `[2, 2.2]`), Paths: testutil.ParseObjectPaths(t, "a", "b")}, + stream.Range{Max: testutil.ExprList(t, `(2, 2.2)`), Columns: []string{"a", "b"}}, }, false, false, }, { "max:1", "a", - testutil.MakeObjects(t, `{"a": 1}`, `{"a": 2}`), - testutil.MakeObjects(t, `{"a": 1}`), + testutil.MakeRows(t, `{"a": 1}`, `{"a": 2}`), + testutil.MakeRows(t, `{"a": 1, "b": null, "c": null}`), stream.Ranges{ - stream.Range{Max: testutil.ExprList(t, `[1]`), Paths: []object.Path{testutil.ParseObjectPath(t, "a")}}, + stream.Range{Max: testutil.ExprList(t, `(1)`), Columns: []string{"a"}}, }, false, false, }, { - "max:[1, 2]", "a, b", - testutil.MakeObjects(t, `{"a": 1, "b": 2}`, `{"a": 2, "b": 2}`), - testutil.MakeObjects(t, `{"a": 1, "b": 2}`), + "max:(1, 2)", "a, b", + testutil.MakeRows(t, `{"a": 1, "b": 2}`, `{"a": 2, "b": 2}`), + testutil.MakeRows(t, `{"a": 1, "b": 2, "c": null}`), stream.Ranges{ - stream.Range{Max: testutil.ExprList(t, `[1, 2]`), Paths: testutil.ParseObjectPaths(t, "a", "b")}, + 
stream.Range{Max: testutil.ExprList(t, `(1, 2)`), Columns: []string{"a", "b"}}, }, false, false, }, { - "max:[1.1, 2]", "a, b", - testutil.MakeObjects(t, `{"a": 1, "b": 2}`, `{"a": 2, "b": 2}`), - testutil.MakeObjects(t), + "max:(1.1, 2)", "a, b", + testutil.MakeRows(t, `{"a": 1, "b": 2, "c": null}`, `{"a": 2, "b": 2, "c": null}`), + nil, stream.Ranges{ - stream.Range{Max: testutil.ExprList(t, `[1.1, 2]`), Paths: testutil.ParseObjectPaths(t, "a", "b")}, + stream.Range{Max: testutil.ExprList(t, `(1.1, 2)`), Columns: []string{"a", "b"}}, }, false, false, }, { "min", "a", - testutil.MakeObjects(t, `{"a": 1}`, `{"a": 2}`), - testutil.MakeObjects(t, `{"a": 1}`, `{"a": 2}`), + testutil.MakeRows(t, `{"a": 1}`, `{"a": 2}`), + testutil.MakeRows(t, `{"a": 1, "b": null, "c": null}`, `{"a": 2, "b": null, "c": null}`), stream.Ranges{ - stream.Range{Min: testutil.ExprList(t, `[1]`), Paths: []object.Path{testutil.ParseObjectPath(t, "a")}}, + stream.Range{Min: testutil.ExprList(t, `(1)`), Columns: []string{"a"}}, }, false, false, }, { - "min:[1],exclusive", "a", - testutil.MakeObjects(t, `{"a": 1}`, `{"a": 2}`), - testutil.MakeObjects(t, `{"a": 2}`), + "min:(1),exclusive", "a", + testutil.MakeRows(t, `{"a": 1}`, `{"a": 2}`), + testutil.MakeRows(t, `{"a": 2, "b": null, "c": null}`), stream.Ranges{ - stream.Range{Min: testutil.ExprList(t, `[1]`), Paths: []object.Path{testutil.ParseObjectPath(t, "a")}, Exclusive: true}, + stream.Range{Min: testutil.ExprList(t, `(1)`), Columns: []string{"a"}, Exclusive: true}, }, false, false, }, { - "min:[1],exclusive", "a, b", - testutil.MakeObjects(t, `{"a": 1, "b": 1}`, `{"a": 2, "b": 2}`), - testutil.MakeObjects(t, `{"a": 2, "b": 2}`), + "min:(1),exclusive", "a, b", + testutil.MakeRows(t, `{"a": 1, "b": 1}`, `{"a": 2, "b": 2}`), + testutil.MakeRows(t, `{"a": 2, "b": 2, "c": null}`), stream.Ranges{ - stream.Range{Min: testutil.ExprList(t, `[1]`), Paths: testutil.ParseObjectPaths(t, "a", "b"), Exclusive: true}, + stream.Range{Min: testutil.ExprList(t, `(1)`), Columns: []string{"a", "b"}, Exclusive: true}, }, false, false, }, { - "min:[2, 1]", "a, b", - testutil.MakeObjects(t, `{"a": 1, "b": 2}`, `{"a": 2, "b": 2}`), - testutil.MakeObjects(t, `{"a": 2, "b": 2}`), + "min:(2, 1)", "a, b", + testutil.MakeRows(t, `{"a": 1, "b": 2}`, `{"a": 2, "b": 2}`), + testutil.MakeRows(t, `{"a": 2, "b": 2, "c": null}`), stream.Ranges{ stream.Range{ - Min: testutil.ExprList(t, `[2, 1]`), - Paths: testutil.ParseObjectPaths(t, "a", "b"), + Min: testutil.ExprList(t, `(2, 1)`), + Columns: []string{"a", "b"}, }, }, false, false, }, { - "min:[2, 1.5]", "a, b", - testutil.MakeObjects(t, `{"a": 1, "b": 2}`, `{"a": 2, "b": 2}`), - testutil.MakeObjects(t, `{"a": 2, "b": 2}`), + "min:(2, 1.5)", "a, b", + testutil.MakeRows(t, `{"a": 1, "b": 2, "c": null}`, `{"a": 2, "b": 2, "c": null}`), + nil, stream.Ranges{ stream.Range{ - Min: testutil.ExprList(t, `[2, 1.5]`), - Paths: testutil.ParseObjectPaths(t, "a", "b"), + Min: testutil.ExprList(t, `(2, 1.5)`), + Columns: []string{"a", "b"}, }, }, false, false, }, { "min/max", "a", - testutil.MakeObjects(t, `{"a": 1}`, `{"a": 2}`), - testutil.MakeObjects(t, `{"a": 1}`, `{"a": 2}`), + testutil.MakeRows(t, `{"a": 1}`, `{"a": 2}`), + testutil.MakeRows(t, `{"a": 1, "b": null, "c": null}`, `{"a": 2, "b": null, "c": null}`), stream.Ranges{ stream.Range{ - Min: testutil.ExprList(t, `[1]`), - Max: testutil.ExprList(t, `[2]`), - Paths: []object.Path{testutil.ParseObjectPath(t, "a")}, + Min: testutil.ExprList(t, `(1)`), + Max: testutil.ExprList(t, `(2)`), + Columns: 
[]string{"a"}, }, }, false, false, }, { - "min:[1, 1], max:[2,2]", "a, b", - testutil.MakeObjects(t, `{"a": 1, "b": 2}`, `{"a": 2, "b": 2}`), - testutil.MakeObjects(t, `{"a": 1, "b": 2}`, `{"a": 2, "b": 2}`), + "min:(1, 1), max:[2,2]", "a, b", + testutil.MakeRows(t, `{"a": 1, "b": 2}`, `{"a": 2, "b": 2}`), + testutil.MakeRows(t, `{"a": 1, "b": 2, "c": null}`, `{"a": 2, "b": 2, "c": null}`), stream.Ranges{ stream.Range{ - Min: testutil.ExprList(t, `[1, 1]`), - Max: testutil.ExprList(t, `[2, 2]`), - Paths: testutil.ParseObjectPaths(t, "a", "b"), + Min: testutil.ExprList(t, `(1, 1)`), + Max: testutil.ExprList(t, `(2, 2)`), + Columns: []string{"a", "b"}, }, }, false, false, }, { - "min:[1, 1], max:[2,2] bis", "a, b", - testutil.MakeObjects(t, `{"a": 1, "b": 3}`, `{"a": 2, "b": 2}`), - testutil.MakeObjects(t, `{"a": 1, "b": 3}`, `{"a": 2, "b": 2}`), // [1, 3] < [2, 2] + "min:(1, 1), max:[2,2] bis", "a, b", + testutil.MakeRows(t, `{"a": 1, "b": 3}`, `{"a": 2, "b": 2}`), + testutil.MakeRows(t, `{"a": 1, "b": 3, "c": null}`, `{"a": 2, "b": 2, "c": null}`), // [1, 3] < (2, 2) stream.Ranges{ stream.Range{ - Min: testutil.ExprList(t, `[1, 1]`), - Max: testutil.ExprList(t, `[2, 2]`), - Paths: testutil.ParseObjectPaths(t, "a", "b"), + Min: testutil.ExprList(t, `(1, 1)`), + Max: testutil.ExprList(t, `(2, 2)`), + Columns: []string{"a", "b"}, }, }, false, false, }, { "reverse/no range", "a", - testutil.MakeObjects(t, `{"a": 1}`, `{"a": 2}`), - testutil.MakeObjects(t, `{"a": 2}`, `{"a": 1}`), + testutil.MakeRows(t, `{"a": 1}`, `{"a": 2}`), + testutil.MakeRows(t, `{"a": 2, "b": null, "c": null}`, `{"a": 1, "b": null, "c": null}`), nil, true, false, }, { "reverse/max", "a", - testutil.MakeObjects(t, `{"a": 1}`, `{"a": 2}`), - testutil.MakeObjects(t, `{"a": 2}`, `{"a": 1}`), + testutil.MakeRows(t, `{"a": 1}`, `{"a": 2}`), + testutil.MakeRows(t, `{"a": 2, "b": null, "c": null}`, `{"a": 1, "b": null, "c": null}`), stream.Ranges{ - stream.Range{Max: testutil.ExprList(t, `[2]`), Paths: []object.Path{testutil.ParseObjectPath(t, "a")}}, + stream.Range{Max: testutil.ExprList(t, `(2)`), Columns: []string{"a"}}, }, true, false, }, { "reverse/max", "a, b", - testutil.MakeObjects(t, `{"a": 1, "b": 1}`, `{"a": 2, "b": 2}`), - testutil.MakeObjects(t, `{"a": 2, "b": 2}`), + testutil.MakeRows(t, `{"a": 1, "b": 1}`, `{"a": 2, "b": 2}`), + testutil.MakeRows(t, `{"a": 2, "b": 2, "c": null}`), stream.Ranges{ stream.Range{ - Max: testutil.ExprList(t, `[2, 2]`), - Paths: testutil.ParseObjectPaths(t, "a", "b"), + Max: testutil.ExprList(t, `(2, 2)`), + Columns: []string{"a", "b"}, }, }, true, false, }, { "reverse/min", "a", - testutil.MakeObjects(t, `{"a": 1}`, `{"a": 2}`), - testutil.MakeObjects(t, `{"a": 2}`, `{"a": 1}`), + testutil.MakeRows(t, `{"a": 1}`, `{"a": 2}`), + testutil.MakeRows(t, `{"a": 2, "b": null, "c": null}`, `{"a": 1, "b": null, "c": null}`), stream.Ranges{ - stream.Range{Min: testutil.ExprList(t, `[1]`), Paths: []object.Path{testutil.ParseObjectPath(t, "a")}}, + stream.Range{Min: testutil.ExprList(t, `(1)`), Columns: []string{"a"}}, }, true, false, }, { "reverse/min neg", "a", - testutil.MakeObjects(t, `{"a": 1}`, `{"a": -2}`), - testutil.MakeObjects(t, `{"a": 1}`), + testutil.MakeRows(t, `{"a": 1}`, `{"a": -2}`), + testutil.MakeRows(t, `{"a": 1, "b": null, "c": null}`), stream.Ranges{ - stream.Range{Min: testutil.ExprList(t, `[1]`), Paths: []object.Path{testutil.ParseObjectPath(t, "a")}}, + stream.Range{Min: testutil.ExprList(t, `(1)`), Columns: []string{"a"}}, }, true, false, }, { "reverse/min", "a, b", - 
testutil.MakeObjects(t, `{"a": 1, "b": 1}`, `{"a": 2, "b": 2}`), - testutil.MakeObjects(t, `{"a": 1, "b": 1}`), + testutil.MakeRows(t, `{"a": 1, "b": 1}`, `{"a": 2, "b": 2}`), + testutil.MakeRows(t, `{"a": 1, "b": 1, "c": null}`), stream.Ranges{ stream.Range{ - Min: testutil.ExprList(t, `[1, 1]`), - Paths: testutil.ParseObjectPaths(t, "a", "b"), + Min: testutil.ExprList(t, `(1, 1)`), + Columns: []string{"a", "b"}, }, }, true, false, }, { "reverse/min/max", "a", - testutil.MakeObjects(t, `{"a": 1}`, `{"a": 2}`), - testutil.MakeObjects(t, `{"a": 2}`, `{"a": 1}`), + testutil.MakeRows(t, `{"a": 1}`, `{"a": 2}`), + testutil.MakeRows(t, `{"a": 2, "b": null, "c": null}`, `{"a": 1, "b": null, "c": null}`), stream.Ranges{ stream.Range{ - Min: testutil.ExprList(t, `[1]`), - Max: testutil.ExprList(t, `[2]`), - Paths: []object.Path{testutil.ParseObjectPath(t, "a")}, + Min: testutil.ExprList(t, `(1)`), + Max: testutil.ExprList(t, `(2)`), + Columns: []string{"a"}, }, }, true, false, }, { "reverse/min/max", "a, b", - testutil.MakeObjects(t, `{"a": 1, "b": 1}`, `{"a": 2, "b": 2}`), - testutil.MakeObjects(t, `{"a": 2, "b": 2}`, `{"a": 1, "b": 1}`), + testutil.MakeRows(t, `{"a": 1, "b": 1}`, `{"a": 2, "b": 2}`), + testutil.MakeRows(t, `{"a": 2, "b": 2, "c": null}`, `{"a": 1, "b": 1, "c": null}`), stream.Ranges{ stream.Range{ - Min: testutil.ExprList(t, `[1, 1]`), - Max: testutil.ExprList(t, `[2, 2]`), - Paths: testutil.ParseObjectPaths(t, "a", "b"), + Min: testutil.ExprList(t, `(1, 1)`), + Max: testutil.ExprList(t, `(2, 2)`), + Columns: []string{"a", "b"}, }, }, true, false, }, { - "max:[1]", "a, b", - testutil.MakeObjects(t, `{"a": 1, "b": 1}`, `{"a": 2, "b": 2}`, `{"a": 1, "b": 9223372036854775807}`), - testutil.MakeObjects(t, `{"a": 1, "b": 1}`, `{"a": 1, "b": 9223372036854775807}`), + "max:(1)", "a, b", + testutil.MakeRows(t, `{"a": 1, "b": 1}`, `{"a": 2, "b": 2}`, `{"a": 1, "b": 9223372036854775807}`), + testutil.MakeRows(t, `{"a": 1, "b": 1, "c": null}`, `{"a": 1, "b": 9223372036854775807, "c": null}`), stream.Ranges{ stream.Range{ - Max: testutil.ExprList(t, `[1]`), - Paths: testutil.ParseObjectPaths(t, "a", "b"), + Max: testutil.ExprList(t, `(1)`), + Columns: []string{"a", "b"}, }, }, false, false, }, { - "reverse max:[1]", "a, b", - testutil.MakeObjects(t, `{"a": 1, "b": 1}`, `{"a": 2, "b": 2}`, `{"a": 1, "b": 9223372036854775807}`), - testutil.MakeObjects(t, `{"a": 1, "b": 9223372036854775807}`, `{"a": 1, "b": 1}`), + "reverse max:(1)", "a, b", + testutil.MakeRows(t, `{"a": 1, "b": 1}`, `{"a": 2, "b": 2}`, `{"a": 1, "b": 9223372036854775807}`), + testutil.MakeRows(t, `{"a": 1, "b": 9223372036854775807, "c": null}`, `{"a": 1, "b": 1, "c": null}`), stream.Ranges{ stream.Range{ - Max: testutil.ExprList(t, `[1]`), + Max: testutil.ExprList(t, `(1)`), Exclusive: false, Exact: false, - Paths: testutil.ParseObjectPaths(t, "a", "b"), + Columns: []string{"a", "b"}, }, }, true, false, }, { - "max:[1, 2]", "a, b, c", - testutil.MakeObjects(t, `{"a": 1, "b": 2, "c": 1}`, `{"a": 2, "b": 2, "c": 2}`, `{"a": 1, "b": 2, "c": 9223372036854775807}`), - testutil.MakeObjects(t, `{"a": 1, "b": 2, "c": 1}`, `{"a": 1, "b": 2, "c": 9223372036854775807}`), + "max:(1, 2)", "a, b, c", + testutil.MakeRows(t, `{"a": 1, "b": 2, "c": 1}`, `{"a": 2, "b": 2, "c": 2}`, `{"a": 1, "b": 2, "c": 9223372036854775807}`), + testutil.MakeRows(t, `{"a": 1, "b": 2, "c": 1}`, `{"a": 1, "b": 2, "c": 9223372036854775807}`), stream.Ranges{ stream.Range{ - Max: testutil.ExprList(t, `[1, 2]`), Paths: testutil.ParseObjectPaths(t, "a", "b", "c"), + 
Max: testutil.ExprList(t, `(1, 2)`), Columns: []string{"a", "b", "c"}, }, }, false, false, }, { - "min:[1]", "a, b", - testutil.MakeObjects(t, `{"a": 1, "b": -2}`, `{"a": -2, "b": 2}`, `{"a": 1, "b": 1}`), - testutil.MakeObjects(t, `{"a": 1, "b": -2}`, `{"a": 1, "b": 1}`), + "min:(1)", "a, b", + testutil.MakeRows(t, `{"a": 1, "b": -2}`, `{"a": -2, "b": 2}`, `{"a": 1, "b": 1}`), + testutil.MakeRows(t, `{"a": 1, "b": -2, "c": null}`, `{"a": 1, "b": 1, "c": null}`), stream.Ranges{ - stream.Range{Min: testutil.ExprList(t, `[1]`), Paths: testutil.ParseObjectPaths(t, "a", "b")}, + stream.Range{Min: testutil.ExprList(t, `(1)`), Columns: []string{"a", "b"}}, }, false, false, }, { - "min:[1]", "a, b, c", - testutil.MakeObjects(t, `{"a": 1, "b": -2, "c": 0}`, `{"a": -2, "b": 2, "c": 1}`, `{"a": 1, "b": 1, "c": 2}`), - testutil.MakeObjects(t, `{"a": 1, "b": -2, "c": 0}`, `{"a": 1, "b": 1, "c": 2}`), + "min:(1)", "a, b, c", + testutil.MakeRows(t, `{"a": 1, "b": -2, "c": 0}`, `{"a": -2, "b": 2, "c": 1}`, `{"a": 1, "b": 1, "c": 2}`), + testutil.MakeRows(t, `{"a": 1, "b": -2, "c": 0}`, `{"a": 1, "b": 1, "c": 2}`), stream.Ranges{ - stream.Range{Min: testutil.ExprList(t, `[1]`), Paths: testutil.ParseObjectPaths(t, "a", "b", "c")}, + stream.Range{Min: testutil.ExprList(t, `(1)`), Columns: []string{"a", "b", "c"}}, }, false, false, }, { - "reverse min:[1]", "a, b", - testutil.MakeObjects(t, `{"a": 1, "b": -2}`, `{"a": -2, "b": 2}`, `{"a": 1, "b": 1}`), - testutil.MakeObjects(t, `{"a": 1, "b": 1}`, `{"a": 1, "b": -2}`), + "reverse min:(1)", "a, b", + testutil.MakeRows(t, `{"a": 1, "b": -2}`, `{"a": -2, "b": 2}`, `{"a": 1, "b": 1}`), + testutil.MakeRows(t, `{"a": 1, "b": 1, "c": null}`, `{"a": 1, "b": -2, "c": null}`), stream.Ranges{ - stream.Range{Min: testutil.ExprList(t, `[1]`), Paths: testutil.ParseObjectPaths(t, "a", "b")}, + stream.Range{Min: testutil.ExprList(t, `(1)`), Columns: []string{"a", "b"}}, }, true, false, }, { - "min:[1], max[2]", "a, b", - testutil.MakeObjects(t, `{"a": 1, "b": -2}`, `{"a": -2, "b": 2}`, `{"a": 2, "b": 42}`, `{"a": 3, "b": -1}`), - testutil.MakeObjects(t, `{"a": 1, "b": -2}`, `{"a": 2, "b": 42}`), + "min:(1), max(2)", "a, b", + testutil.MakeRows(t, `{"a": 1, "b": -2}`, `{"a": -2, "b": 2}`, `{"a": 2, "b": 42}`, `{"a": 3, "b": -1}`), + testutil.MakeRows(t, `{"a": 1, "b": -2, "c": null}`, `{"a": 2, "b": 42, "c": null}`), stream.Ranges{ stream.Range{ - Min: testutil.ExprList(t, `[1]`), - Max: testutil.ExprList(t, `[2]`), - Paths: testutil.ParseObjectPaths(t, "a", "b"), + Min: testutil.ExprList(t, `(1)`), + Max: testutil.ExprList(t, `(2)`), + Columns: []string{"a", "b"}, }, }, false, false, }, { - "reverse min:[1], max[2]", "a, b", - testutil.MakeObjects(t, `{"a": 1, "b": -2}`, `{"a": -2, "b": 2}`, `{"a": 2, "b": 42}`, `{"a": 3, "b": -1}`), - testutil.MakeObjects(t, `{"a": 2, "b": 42}`, `{"a": 1, "b": -2}`), + "reverse min:(1), max(2)", "a, b", + testutil.MakeRows(t, `{"a": 1, "b": -2}`, `{"a": -2, "b": 2}`, `{"a": 2, "b": 42}`, `{"a": 3, "b": -1}`), + testutil.MakeRows(t, `{"a": 2, "b": 42, "c": null}`, `{"a": 1, "b": -2, "c": null}`), stream.Ranges{ stream.Range{ - Min: testutil.ExprList(t, `[1]`), - Max: testutil.ExprList(t, `[2]`), - Paths: testutil.ParseObjectPaths(t, "a", "b"), + Min: testutil.ExprList(t, `(1)`), + Max: testutil.ExprList(t, `(2)`), + Columns: []string{"a", "b"}, }, }, true, false, @@ -411,10 +411,13 @@ func testIndexScan(t *testing.T, getOp func(db *database.Database, tx *database. 
db, tx, cleanup := testutil.NewTestTx(t) defer cleanup() - testutil.MustExec(t, db, tx, "CREATE TABLE test (a INTEGER, b INTEGER, c INTEGER);") + testutil.MustExec(t, db, tx, "CREATE TABLE test (a BIGINT, b BIGINT, c BIGINT);") - for _, doc := range test.docsInTable { - testutil.MustExec(t, db, tx, "INSERT INTO test VALUES ?", environment.Param{Value: doc}) + for _, r := range test.docsInTable { + var a, b, c *int64 + err := row.Scan(r, &a, &b, &c) + require.NoError(t, err) + testutil.MustExec(t, db, tx, "INSERT INTO test VALUES (?, ?, ?)", environment.Param{Value: a}, environment.Param{Value: b}, environment.Param{Value: c}) } op := getOp(db, tx, "idx_test_a", test.indexOn, test.reverse, test.ranges...) @@ -424,19 +427,19 @@ func testIndexScan(t *testing.T, getOp func(db *database.Database, tx *database. env.Params = []environment.Param{{Name: "foo", Value: 1}} var i int - var got testutil.Objs + var got testutil.Rows err := op.Iterate(&env, func(env *environment.Environment) error { r, ok := env.GetRow() require.True(t, ok) - var fb object.FieldBuffer + var fb row.ColumnBuffer - err := fb.Copy(r.Object()) + err := fb.Copy(r) assert.NoError(t, err) got = append(got, &fb) v, err := env.GetParamByName("foo") assert.NoError(t, err) - require.Equal(t, types.NewIntegerValue(1), v) + require.Equal(t, types.NewBigintValue(1), v) i++ return nil }) diff --git a/internal/stream/index/validate.go b/internal/stream/index/validate.go index d43e91dcf..5eafc8dbd 100644 --- a/internal/stream/index/validate.go +++ b/internal/stream/index/validate.go @@ -46,14 +46,14 @@ func (op *ValidateOperator) Iterate(in *environment.Environment, fn func(out *en return errors.New("missing row") } - vs := make([]types.Value, 0, len(info.Paths)) + vs := make([]types.Value, 0, len(info.Columns)) // if the indexes values contain NULL somewhere, // we don't check for unicity. // cf: https://sqlite.org/lang_createindex.html#unique_indexes var hasNull bool - for _, path := range info.Paths { - v, err := path.GetValueFromObject(r.Object()) + for _, column := range info.Columns { + v, err := r.Get(column) if err != nil { hasNull = true v = types.NewNullValue() @@ -72,7 +72,7 @@ func (op *ValidateOperator) Iterate(in *environment.Environment, fn func(out *en if duplicate { return &database.ConstraintViolationError{ Constraint: "UNIQUE", - Paths: info.Paths, + Columns: info.Columns, Key: key, } } diff --git a/internal/stream/on_conflict.go b/internal/stream/on_conflict.go index d2c4581a4..3aec43221 100644 --- a/internal/stream/on_conflict.go +++ b/internal/stream/on_conflict.go @@ -5,6 +5,7 @@ import ( "github.com/chaisql/chai/internal/database" "github.com/chaisql/chai/internal/environment" + "github.com/cockroachdb/errors" ) // OnConflictOperator handles any conflicts that occur during the iteration. 
@@ -32,13 +33,13 @@ func (op *OnConflictOperator) Iterate(in *environment.Environment, fn func(out * } newEnv.SetOuter(out) - r, ok := out.GetRow() + r, ok := out.GetDatabaseRow() if !ok { - return fmt.Errorf("missing row") + return errors.New("missing row") } var br database.BasicRow - br.ResetWith(r.TableName(), cerr.Key, r.Object()) + br.ResetWith(r.TableName(), cerr.Key, r) newEnv.SetRow(&br) err = op.OnConflict.Iterate(&newEnv, func(out *environment.Environment) error { return nil }) diff --git a/internal/stream/operator_test.go b/internal/stream/operator_test.go index 203158df5..43f87221b 100644 --- a/internal/stream/operator_test.go +++ b/internal/stream/operator_test.go @@ -9,7 +9,7 @@ import ( "github.com/chaisql/chai/internal/database" "github.com/chaisql/chai/internal/environment" "github.com/chaisql/chai/internal/expr" - "github.com/chaisql/chai/internal/object" + "github.com/chaisql/chai/internal/row" "github.com/chaisql/chai/internal/sql/parser" "github.com/chaisql/chai/internal/stream" "github.com/chaisql/chai/internal/stream/path" @@ -17,7 +17,6 @@ import ( "github.com/chaisql/chai/internal/stream/table" "github.com/chaisql/chai/internal/testutil" "github.com/chaisql/chai/internal/testutil/assert" - "github.com/chaisql/chai/internal/types" "github.com/cockroachdb/errors" "github.com/stretchr/testify/require" ) @@ -25,31 +24,31 @@ import ( func TestFilter(t *testing.T) { tests := []struct { e expr.Expr - in []expr.Expr - out []types.Object + in []expr.Row + out []row.Row fails bool }{ { parser.MustParseExpr("1"), - testutil.ParseExprs(t, `{"a": 1}`), - testutil.MakeObjects(t, `{"a": 1}`), + testutil.MakeRowExprs(t, `{"a": 1}`), + testutil.MakeRows(t, `{"a": 1}`), false, }, { parser.MustParseExpr("a > 1"), - testutil.ParseExprs(t, `{"a": 1}`), + testutil.MakeRowExprs(t, `{"a": 1}`), nil, false, }, { parser.MustParseExpr("a >= 1"), - testutil.ParseExprs(t, `{"a": 1}`), - testutil.MakeObjects(t, `{"a": 1}`), + testutil.MakeRowExprs(t, `{"a": 1}`), + testutil.MakeRows(t, `{"a": 1}`), false, }, { parser.MustParseExpr("null"), - testutil.ParseExprs(t, `{"a": 1}`), + testutil.MakeRowExprs(t, `{"a": 1}`), nil, false, }, @@ -61,7 +60,7 @@ func TestFilter(t *testing.T) { i := 0 err := s.Iterate(new(environment.Environment), func(out *environment.Environment) error { r, _ := out.GetRow() - require.Equal(t, test.out[i], r.Object()) + testutil.RequireRowEqual(t, test.out[i], r) i++ return nil }) @@ -93,10 +92,10 @@ func TestTake(t *testing.T) { for _, test := range tests { t.Run(fmt.Sprintf("%d/%d", test.inNumber, test.n), func(t *testing.T) { - var ds []expr.Expr + var ds []expr.Row for i := 0; i < test.inNumber; i++ { - ds = append(ds, testutil.ParseExpr(t, `{"a": `+strconv.Itoa(i)+`}`)) + ds = append(ds, testutil.MakeRowExpr(t, `{"a": `+strconv.Itoa(i)+`}`)) } s := stream.New(rows.Emit(ds...)) @@ -138,10 +137,10 @@ func TestSkip(t *testing.T) { for _, test := range tests { t.Run(fmt.Sprintf("%d/%d", test.inNumber, test.n), func(t *testing.T) { - var ds []expr.Expr + var ds []expr.Row for i := 0; i < test.inNumber; i++ { - ds = append(ds, testutil.ParseExpr(t, `{"a": `+strconv.Itoa(i)+`}`)) + ds = append(ds, testutil.MakeRowExpr(t, `{"a": `+strconv.Itoa(i)+`}`)) } s := stream.New(rows.Emit(ds...)) @@ -170,14 +169,14 @@ func TestTableInsert(t *testing.T) { tests := []struct { name string in stream.Operator - out []types.Object + out []row.Row rowid int fails bool }{ { "doc with no key", - rows.Emit(testutil.ParseExpr(t, `{"a": 10}`), testutil.ParseExpr(t, `{"a": 11}`)), - 
[]types.Object{testutil.MakeObject(t, `{"a": 10}`), testutil.MakeObject(t, `{"a": 11}`)}, + rows.Emit(testutil.MakeRowExpr(t, `{"a": 10}`), testutil.MakeRowExpr(t, `{"a": 11}`)), + []row.Row{testutil.MakeRow(t, `{"a": 10}`), testutil.MakeRow(t, `{"a": 11}`)}, 1, false, }, @@ -200,7 +199,7 @@ func TestTableInsert(t *testing.T) { r, ok := out.GetRow() require.True(t, ok) - testutil.RequireObjEqual(t, test.out[i], r.Object()) + testutil.RequireRowEqual(t, test.out[i], r) i++ return nil }) @@ -219,17 +218,17 @@ func TestTableInsert(t *testing.T) { func TestTableReplace(t *testing.T) { tests := []struct { - name string - docsInTable testutil.Objs - op stream.Operator - expected testutil.Objs - fails bool + name string + a, b any + op stream.Operator + expected testutil.Rows + fails bool }{ { - "doc with key", - testutil.MakeObjects(t, `{"a": 1, "b": 1}`), - path.Set(testutil.ParseObjectPath(t, "b"), testutil.ParseExpr(t, "2")), - testutil.MakeObjects(t, `{"a": 1, "b": 2}`), + "row with key", + 1, 1, + path.Set("b", testutil.ParseExpr(t, "2")), + testutil.MakeRows(t, `{"a": 1, "b": 2}`), false, }, } @@ -241,9 +240,7 @@ func TestTableReplace(t *testing.T) { testutil.MustExec(t, db, tx, "CREATE TABLE test (a INTEGER PRIMARY KEY, b INTEGER)") - for _, doc := range test.docsInTable { - testutil.MustExec(t, db, tx, "INSERT INTO test VALUES ?", environment.Param{Value: doc}) - } + testutil.MustExec(t, db, tx, "INSERT INTO test VALUES (?, ?)", environment.Param{Value: test.a}, environment.Param{Value: test.b}) in := environment.Environment{} in.Tx = tx @@ -274,10 +271,10 @@ func TestTableReplace(t *testing.T) { res := testutil.MustQuery(t, db, tx, "SELECT * FROM test") defer res.Close() - var got []types.Object - err = res.Iterate(func(row database.Row) error { - var fb object.FieldBuffer - fb.Copy(row.Object()) + var got []row.Row + err = res.Iterate(func(r database.Row) error { + var fb row.ColumnBuffer + fb.Copy(r) got = append(got, &fb) return nil }) @@ -293,17 +290,17 @@ func TestTableReplace(t *testing.T) { func TestTableDelete(t *testing.T) { tests := []struct { - name string - docsInTable testutil.Objs - op stream.Operator - expected testutil.Objs - fails bool + name string + a []int + op stream.Operator + expected testutil.Rows + fails bool }{ { "doc with key", - testutil.MakeObjects(t, `{"a": 1}`, `{"a": 2}`, `{"a": 3}`), + []int{1, 2, 3}, rows.Filter(testutil.ParseExpr(t, `a > 1`)), - testutil.MakeObjects(t, `{"a": 1}`), + testutil.MakeRows(t, `{"a": 1}`), false, }, } @@ -315,8 +312,8 @@ func TestTableDelete(t *testing.T) { testutil.MustExec(t, db, tx, "CREATE TABLE test (a INTEGER PRIMARY KEY)") - for _, doc := range test.docsInTable { - testutil.MustExec(t, db, tx, "INSERT INTO test VALUES ?", environment.Param{Value: doc}) + for _, a := range test.a { + testutil.MustExec(t, db, tx, "INSERT INTO test VALUES (?)", environment.Param{Value: a}) } var env environment.Environment @@ -336,10 +333,10 @@ func TestTableDelete(t *testing.T) { res := testutil.MustQuery(t, db, tx, "SELECT * FROM test") defer res.Close() - var got []types.Object - err = res.Iterate(func(row database.Row) error { - var fb object.FieldBuffer - fb.Copy(row.Object()) + var got []row.Row + err = res.Iterate(func(r database.Row) error { + var fb row.ColumnBuffer + fb.Copy(r) got = append(got, &fb) return nil }) @@ -352,53 +349,3 @@ func TestTableDelete(t *testing.T) { require.Equal(t, table.Delete("test").String(), "table.Delete('test')") }) } - -func TestPathsRename(t *testing.T) { - tests := []struct { - fieldNames 
[]string - in []expr.Expr - out []types.Object - fails bool - }{ - { - []string{"c", "d"}, - testutil.ParseExprs(t, `{"a": 10, "b": 20}`), - testutil.MakeObjects(t, `{"c": 10, "d": 20}`), - false, - }, - { - []string{"c", "d", "e"}, - testutil.ParseExprs(t, `{"a": 10, "b": 20}`), - nil, - true, - }, - { - []string{"c"}, - testutil.ParseExprs(t, `{"a": 10, "b": 20}`), - nil, - true, - }, - } - - for _, test := range tests { - s := stream.New(rows.Emit(test.in...)).Pipe(path.PathsRename(test.fieldNames...)) - t.Run(s.String(), func(t *testing.T) { - i := 0 - err := s.Iterate(new(environment.Environment), func(out *environment.Environment) error { - r, _ := out.GetRow() - require.Equal(t, test.out[i], r.Object()) - i++ - return nil - }) - if test.fails { - assert.Error(t, err) - } else { - assert.NoError(t, err) - } - }) - } - - t.Run("String", func(t *testing.T) { - require.Equal(t, path.PathsRename("a", "b", "c").String(), "paths.Rename(a, b, c)") - }) -} diff --git a/internal/stream/path/rename.go b/internal/stream/path/rename.go index 11632a5b8..7b1b2a23e 100644 --- a/internal/stream/path/rename.go +++ b/internal/stream/path/rename.go @@ -6,7 +6,7 @@ import ( "github.com/chaisql/chai/internal/database" "github.com/chaisql/chai/internal/environment" - "github.com/chaisql/chai/internal/object" + "github.com/chaisql/chai/internal/row" "github.com/chaisql/chai/internal/stream" "github.com/chaisql/chai/internal/types" "github.com/cockroachdb/errors" @@ -19,7 +19,7 @@ type RenameOperator struct { } // PathsRename iterates over all columns of the incoming row in order and renames them. -// If the number of columns of the incoming row doesn't match the number of expected fields, +// If the number of columns of the incoming row doesn't match the number of expected columns, // it returns an error. func PathsRename(columnNames ...string) *RenameOperator { return &RenameOperator{ @@ -29,12 +29,12 @@ func PathsRename(columnNames ...string) *RenameOperator { // Iterate implements the Operator interface. 
func (op *RenameOperator) Iterate(in *environment.Environment, f func(out *environment.Environment) error) error { - var fb object.FieldBuffer + var cb row.ColumnBuffer var newEnv environment.Environment var br database.BasicRow return op.Prev.Iterate(in, func(out *environment.Environment) error { - fb.Reset() + cb.Reset() r, ok := out.GetRow() if !ok { @@ -45,14 +45,14 @@ func (op *RenameOperator) Iterate(in *environment.Environment, f func(out *envir err := r.Iterate(func(field string, value types.Value) error { // if there are too many columns in the incoming row if i >= len(op.ColumnNames) { - n, err := object.Length(r.Object()) + n, err := row.Length(r) if err != nil { return err } return fmt.Errorf("%d values for %d columns", n, len(op.ColumnNames)) } - fb.Add(op.ColumnNames[i], value) + cb.Add(op.ColumnNames[i], value) i++ return nil }) @@ -62,16 +62,20 @@ func (op *RenameOperator) Iterate(in *environment.Environment, f func(out *envir // if there are too few columns in the incoming row if i < len(op.ColumnNames) { - n, err := object.Length(r.Object()) + n, err := row.Length(r) if err != nil { return err } return fmt.Errorf("%d values for %d columns", n, len(op.ColumnNames)) } - br.ResetWith(r.TableName(), r.Key(), &fb) newEnv.SetOuter(out) - newEnv.SetRow(&br) + if dr, ok := r.(database.Row); ok { + br.ResetWith(dr.TableName(), dr.Key(), &cb) + newEnv.SetRow(&br) + } else { + newEnv.SetRow(&cb) + } return f(&newEnv) }) diff --git a/internal/stream/path/unset_test.go b/internal/stream/path/rename_test.go similarity index 51% rename from internal/stream/path/unset_test.go rename to internal/stream/path/rename_test.go index 71ab44d08..6cada001d 100644 --- a/internal/stream/path/unset_test.go +++ b/internal/stream/path/rename_test.go @@ -5,37 +5,49 @@ import ( "github.com/chaisql/chai/internal/environment" "github.com/chaisql/chai/internal/expr" + "github.com/chaisql/chai/internal/row" "github.com/chaisql/chai/internal/stream" "github.com/chaisql/chai/internal/stream/path" "github.com/chaisql/chai/internal/stream/rows" "github.com/chaisql/chai/internal/testutil" "github.com/chaisql/chai/internal/testutil/assert" - "github.com/chaisql/chai/internal/types" "github.com/stretchr/testify/require" ) -func TestUnset(t *testing.T) { +func TestPathsRename(t *testing.T) { tests := []struct { - path string - in []expr.Expr - out []types.Object - fails bool + fieldNames []string + in []expr.Row + out []row.Row + fails bool }{ { - "a", - testutil.ParseExprs(t, `{"a": 10, "b": 20}`), - testutil.MakeObjects(t, `{"b": 20}`), + []string{"c", "d"}, + testutil.MakeRowExprs(t, `{"a": 10, "b": 20}`), + testutil.MakeRows(t, `{"c": 10, "d": 20}`), false, }, + { + []string{"c", "d", "e"}, + testutil.MakeRowExprs(t, `{"a": 10, "b": 20}`), + nil, + true, + }, + { + []string{"c"}, + testutil.MakeRowExprs(t, `{"a": 10, "b": 20}`), + nil, + true, + }, } for _, test := range tests { - t.Run(test.path, func(t *testing.T) { - s := stream.New(rows.Emit(test.in...)).Pipe(path.Unset(test.path)) + s := stream.New(rows.Emit(test.in...)).Pipe(path.PathsRename(test.fieldNames...)) + t.Run(s.String(), func(t *testing.T) { i := 0 err := s.Iterate(new(environment.Environment), func(out *environment.Environment) error { r, _ := out.GetRow() - require.Equal(t, test.out[i], r.Object()) + testutil.RequireRowEqual(t, test.out[i], r) i++ return nil }) @@ -48,6 +60,6 @@ func TestUnset(t *testing.T) { } t.Run("String", func(t *testing.T) { - require.Equal(t, path.Unset("a").String(), "paths.Unset(a)") + require.Equal(t, 
path.PathsRename("a", "b", "c").String(), "paths.Rename(a, b, c)") }) } diff --git a/internal/stream/path/set.go b/internal/stream/path/set.go index b2cdf208c..94edf6610 100644 --- a/internal/stream/path/set.go +++ b/internal/stream/path/set.go @@ -6,36 +6,36 @@ import ( "github.com/chaisql/chai/internal/database" "github.com/chaisql/chai/internal/environment" "github.com/chaisql/chai/internal/expr" - "github.com/chaisql/chai/internal/object" + "github.com/chaisql/chai/internal/row" "github.com/chaisql/chai/internal/stream" "github.com/chaisql/chai/internal/types" "github.com/cockroachdb/errors" ) -// A SetOperator sets the value of a column or nested field in the current row. +// A SetOperator sets the value of a column in the current row. type SetOperator struct { stream.BaseOperator - Path object.Path - Expr expr.Expr + Column string + Expr expr.Expr } -// Set returns a SetOperator that sets the value of a column or nested field in the current row. -func Set(path object.Path, e expr.Expr) *SetOperator { +// Set returns a SetOperator that sets the value of a column in the current row. +func Set(column string, e expr.Expr) *SetOperator { return &SetOperator{ - Path: path, - Expr: e, + Column: column, + Expr: e, } } // Iterate implements the Operator interface. func (op *SetOperator) Iterate(in *environment.Environment, f func(out *environment.Environment) error) error { - var fb object.FieldBuffer + var cb row.ColumnBuffer var br database.BasicRow var newEnv environment.Environment return op.Prev.Iterate(in, func(out *environment.Environment) error { v, err := op.Expr.Eval(out) - if err != nil && !errors.Is(err, types.ErrFieldNotFound) { + if err != nil && !errors.Is(err, types.ErrColumnNotFound) { return err } @@ -44,29 +44,32 @@ func (op *SetOperator) Iterate(in *environment.Environment, f func(out *environm return errors.New("missing row") } - fb.Reset() - err = fb.Copy(r.Object()) + cb.Reset() + err = cb.Copy(r) if err != nil { return err } - err = fb.Set(op.Path, v) - if errors.Is(err, types.ErrFieldNotFound) { + err = cb.Set(op.Column, v) + if errors.Is(err, types.ErrColumnNotFound) { return nil } if err != nil { return err } - br.ResetWith(r.TableName(), r.Key(), &fb) - newEnv.SetOuter(out) - newEnv.SetRow(&br) + if dr, ok := r.(database.Row); ok { + br.ResetWith(dr.TableName(), dr.Key(), &cb) + newEnv.SetRow(&br) + } else { + newEnv.SetRow(&cb) + } return f(&newEnv) }) } func (op *SetOperator) String() string { - return fmt.Sprintf("paths.Set(%s, %s)", op.Path, op.Expr) + return fmt.Sprintf("paths.Set(%s, %s)", op.Column, op.Expr) } diff --git a/internal/stream/path/set_test.go b/internal/stream/path/set_test.go index 6d1b81d6c..c26545015 100644 --- a/internal/stream/path/set_test.go +++ b/internal/stream/path/set_test.go @@ -5,50 +5,47 @@ import ( "github.com/chaisql/chai/internal/environment" "github.com/chaisql/chai/internal/expr" - "github.com/chaisql/chai/internal/object" + "github.com/chaisql/chai/internal/row" "github.com/chaisql/chai/internal/sql/parser" "github.com/chaisql/chai/internal/stream" "github.com/chaisql/chai/internal/stream/path" "github.com/chaisql/chai/internal/stream/rows" "github.com/chaisql/chai/internal/testutil" "github.com/chaisql/chai/internal/testutil/assert" - "github.com/chaisql/chai/internal/types" "github.com/stretchr/testify/require" ) func TestSet(t *testing.T) { tests := []struct { - path string - e expr.Expr - in []expr.Expr - out []types.Object - fails bool + column string + e expr.Expr + in []expr.Row + out []row.Row + fails bool }{ { - 
"a[0].b", + "a", parser.MustParseExpr(`10`), - testutil.ParseExprs(t, `{"a": [{}]}`), - testutil.MakeObjects(t, `{"a": [{"b": 10}]}`), + testutil.MakeRowExprs(t, `{"a": true}`), + testutil.MakeRows(t, `{"a": 10}`), false, }, { - "a[2]", + "b", parser.MustParseExpr(`10`), - testutil.ParseExprs(t, `{"a": [1]}`, `{"a": [1, 2, 3]}`), - testutil.MakeObjects(t, `{"a": [1, 2, 10]}`), + testutil.MakeRowExprs(t, `{"a": 1}`, `{"a": 2}`), + testutil.MakeRows(t, `{"a": 1, "b": 10}`, `{"a": 2, "b": 10}`), false, }, } for _, test := range tests { - t.Run(test.path, func(t *testing.T) { - p, err := parser.ParsePath(test.path) - assert.NoError(t, err) - s := stream.New(rows.Emit(test.in...)).Pipe(path.Set(p, test.e)) + t.Run(test.column, func(t *testing.T) { + s := stream.New(rows.Emit(test.in...)).Pipe(path.Set(test.column, test.e)) i := 0 - err = s.Iterate(new(environment.Environment), func(out *environment.Environment) error { + err := s.Iterate(new(environment.Environment), func(out *environment.Environment) error { r, _ := out.GetRow() - require.Equal(t, test.out[i], r.Object()) + testutil.RequireRowEqual(t, test.out[i], r) i++ return nil }) @@ -61,6 +58,6 @@ func TestSet(t *testing.T) { } t.Run("String", func(t *testing.T) { - require.Equal(t, path.Set(object.NewPath("a", "b"), parser.MustParseExpr("1")).String(), "paths.Set(a.b, 1)") + require.Equal(t, path.Set("a", parser.MustParseExpr("1")).String(), "paths.Set(a, 1)") }) } diff --git a/internal/stream/path/unset.go b/internal/stream/path/unset.go deleted file mode 100644 index 71e162adb..000000000 --- a/internal/stream/path/unset.go +++ /dev/null @@ -1,70 +0,0 @@ -package path - -import ( - "fmt" - - "github.com/chaisql/chai/internal/database" - "github.com/chaisql/chai/internal/environment" - "github.com/chaisql/chai/internal/object" - "github.com/chaisql/chai/internal/stream" - "github.com/chaisql/chai/internal/types" - "github.com/cockroachdb/errors" -) - -// A UnsetOperator unsets the value of a column in the current row. -type UnsetOperator struct { - stream.BaseOperator - Column string -} - -// Unset returns a UnsetOperator that unsets the value of a column in the current row. -func Unset(field string) *UnsetOperator { - return &UnsetOperator{ - Column: field, - } -} - -// Iterate implements the Operator interface. 
-func (op *UnsetOperator) Iterate(in *environment.Environment, f func(out *environment.Environment) error) error { - var fb object.FieldBuffer - var br database.BasicRow - var newEnv environment.Environment - - return op.Prev.Iterate(in, func(out *environment.Environment) error { - fb.Reset() - - r, ok := out.GetRow() - if !ok { - return errors.New("missing row") - } - - _, err := r.Get(op.Column) - if err != nil { - if !errors.Is(err, types.ErrFieldNotFound) { - return err - } - - return f(out) - } - - err = fb.Copy(r.Object()) - if err != nil { - return err - } - - err = fb.Delete(object.NewPath(op.Column)) - if err != nil { - return err - } - - br.ResetWith(r.TableName(), r.Key(), &fb) - newEnv.SetOuter(out) - newEnv.SetRow(&br) - - return f(&newEnv) - }) -} - -func (op *UnsetOperator) String() string { - return fmt.Sprintf("paths.Unset(%s)", op.Column) -} diff --git a/internal/stream/range.go b/internal/stream/range.go index 0eafd86ac..1d68e4972 100644 --- a/internal/stream/range.go +++ b/internal/stream/range.go @@ -6,14 +6,13 @@ import ( "github.com/chaisql/chai/internal/database" "github.com/chaisql/chai/internal/environment" "github.com/chaisql/chai/internal/expr" - "github.com/chaisql/chai/internal/object" ) // Range represents a range to select values after or before // a given boundary. type Range struct { Min, Max expr.LiteralExprList - Paths []object.Path + Columns []string // Exclude Min and Max from the results. // By default, min and max are inclusive. // Exclusive and Exact cannot be set to true at the same time. @@ -29,23 +28,20 @@ func (r *Range) Eval(env *environment.Environment) (*database.Range, error) { Exclusive: r.Exclusive, Exact: r.Exact, } + var err error if len(r.Min) > 0 { - min, err := r.Min.Eval(env) + rng.Min, err = r.Min.EvalAll(env) if err != nil { return nil, err } - - rng.Min = min.V().(*object.ValueBuffer).Values } if len(r.Max) > 0 { - max, err := r.Max.Eval(env) + rng.Max, err = r.Max.EvalAll(env) if err != nil { return nil, err } - - rng.Max = max.V().(*object.ValueBuffer).Values } return &rng, nil diff --git a/internal/stream/rows/emit.go b/internal/stream/rows/emit.go index 65bc18bbe..cacb7c3d8 100644 --- a/internal/stream/rows/emit.go +++ b/internal/stream/rows/emit.go @@ -1,41 +1,36 @@ package rows import ( - "fmt" "strings" "github.com/chaisql/chai/internal/environment" "github.com/chaisql/chai/internal/expr" "github.com/chaisql/chai/internal/stream" - "github.com/chaisql/chai/internal/types" - "github.com/cockroachdb/errors" ) type EmitOperator struct { stream.BaseOperator - Exprs []expr.Expr + Rows []expr.Row } // Emit creates an operator that iterates over the given expressions. -// Each expression must evaluate to an object. +// Each expression must evaluate to a row.
+func Emit(rows ...expr.Row) *EmitOperator { + return &EmitOperator{Rows: rows} } func (op *EmitOperator) Iterate(in *environment.Environment, fn func(out *environment.Environment) error) error { var newEnv environment.Environment newEnv.SetOuter(in) - for _, e := range op.Exprs { - v, err := e.Eval(in) + for _, e := range op.Rows { + r, err := e.Eval(in) if err != nil { return err } - if v.Type() != types.TypeObject { - return errors.WithStack(stream.ErrInvalidResult) - } - newEnv.SetRowFromObject(types.AsObject(v)) + newEnv.SetRow(r) + err = fn(&newEnv) if err != nil { return err @@ -49,11 +44,11 @@ func (op *EmitOperator) String() string { var sb strings.Builder sb.WriteString("rows.Emit(") - for i, e := range op.Exprs { + for i, e := range op.Rows { if i > 0 { sb.WriteString(", ") } - sb.WriteString(e.(fmt.Stringer).String()) + sb.WriteString(e.String()) } sb.WriteByte(')') diff --git a/internal/stream/rows/emit_test.go b/internal/stream/rows/emit_test.go deleted file mode 100644 index f1e9aea7d..000000000 --- a/internal/stream/rows/emit_test.go +++ /dev/null @@ -1,48 +0,0 @@ -package rows_test - -import ( - "testing" - - "github.com/chaisql/chai/internal/environment" - "github.com/chaisql/chai/internal/expr" - "github.com/chaisql/chai/internal/sql/parser" - "github.com/chaisql/chai/internal/stream" - "github.com/chaisql/chai/internal/stream/rows" - "github.com/chaisql/chai/internal/testutil" - "github.com/chaisql/chai/internal/testutil/assert" - "github.com/chaisql/chai/internal/types" - "github.com/stretchr/testify/require" -) - -func TestRowsEmit(t *testing.T) { - tests := []struct { - e expr.Expr - output types.Object - fails bool - }{ - {parser.MustParseExpr("3 + 4"), nil, true}, - {parser.MustParseExpr("{a: 3 + 4}"), testutil.MakeObject(t, `{"a": 7}`), false}, - } - - for _, test := range tests { - t.Run(test.e.String(), func(t *testing.T) { - s := stream.New(rows.Emit(test.e)) - - err := s.Iterate(new(environment.Environment), func(env *environment.Environment) error { - r, ok := env.GetRow() - require.True(t, ok) - require.Equal(t, r.Object(), test.output) - return nil - }) - if test.fails { - assert.Error(t, err) - } else { - assert.NoError(t, err) - } - }) - } - - t.Run("String", func(t *testing.T) { - require.Equal(t, rows.Emit(parser.MustParseExpr("1 + 1"), parser.MustParseExpr("pk()")).String(), "rows.Emit(1 + 1, pk())") - }) -} diff --git a/internal/stream/rows/group_aggregate.go b/internal/stream/rows/group_aggregate.go index 1b7dc3b46..030d562eb 100644 --- a/internal/stream/rows/group_aggregate.go +++ b/internal/stream/rows/group_aggregate.go @@ -4,11 +4,13 @@ import ( "fmt" "strings" + "github.com/chaisql/chai/internal/database" "github.com/chaisql/chai/internal/environment" "github.com/chaisql/chai/internal/expr" - "github.com/chaisql/chai/internal/object" + "github.com/chaisql/chai/internal/row" "github.com/chaisql/chai/internal/stream" "github.com/chaisql/chai/internal/types" + "github.com/cockroachdb/errors" ) type GroupAggregateOperator struct { @@ -18,7 +20,7 @@ type GroupAggregateOperator struct { } // GroupAggregate consumes the incoming stream and outputs one value per group. -// It assumes the stream is sorted by groupBy. +// It assumes the stream is sorted by the groupBy expression. 
func GroupAggregate(groupBy expr.Expr, builders ...expr.AggregatorBuilder) *GroupAggregateOperator { return &GroupAggregateOperator{E: groupBy, Builders: builders} } @@ -42,13 +44,17 @@ func (op *GroupAggregateOperator) Iterate(in *environment.Environment, f func(ou } group, err := op.E.Eval(out) + if errors.Is(err, types.ErrColumnNotFound) { + group = types.NewNullValue() + err = nil + } if err != nil { return err } // handle the first object of the stream if lastGroup == nil { - lastGroup, err = object.CloneValue(group) + lastGroup = group if err != nil { return err } @@ -74,10 +80,7 @@ func (op *GroupAggregateOperator) Iterate(in *environment.Environment, f func(ou return err } - lastGroup, err = object.CloneValue(group) - if err != nil { - return err - } + lastGroup = group ga = newGroupAggregator(lastGroup, groupExpr, op.Builders) return ga.Aggregate(out) @@ -86,7 +89,7 @@ func (op *GroupAggregateOperator) Iterate(in *environment.Environment, f func(ou return err } - // if s is empty, we create a default group so that aggregators will + // if ga is empty, we create a default group so that aggregators will // return their default initial value. // Ex: For `SELECT COUNT(*) FROM foo`, if `foo` is empty // we want the following result: @@ -155,11 +158,11 @@ func (g *groupAggregator) Aggregate(env *environment.Environment) error { } func (g *groupAggregator) Flush(env *environment.Environment) (*environment.Environment, error) { - fb := object.NewFieldBuffer() + cb := row.NewColumnBuffer() // add the current group to the object if g.groupExpr != "" { - fb.Add(g.groupExpr, g.group) + cb.Add(g.groupExpr, g.group) } for _, agg := range g.aggregators { @@ -167,12 +170,14 @@ func (g *groupAggregator) Flush(env *environment.Environment) (*environment.Envi if err != nil { return nil, err } - fb.Add(agg.String(), v) + cb.Add(agg.String(), v) } var newEnv environment.Environment + var br database.BasicRow + br.ResetWith("", nil, cb) newEnv.SetOuter(env) - newEnv.SetRowFromObject(fb) + newEnv.SetRow(&br) return &newEnv, nil } diff --git a/internal/stream/rows/group_aggregate_test.go b/internal/stream/rows/group_aggregate_test.go index 32960dd16..21420496b 100644 --- a/internal/stream/rows/group_aggregate_test.go +++ b/internal/stream/rows/group_aggregate_test.go @@ -1,13 +1,12 @@ package rows_test import ( - "strconv" "testing" "github.com/chaisql/chai/internal/environment" "github.com/chaisql/chai/internal/expr" "github.com/chaisql/chai/internal/expr/functions" - "github.com/chaisql/chai/internal/object" + "github.com/chaisql/chai/internal/row" "github.com/chaisql/chai/internal/sql/parser" "github.com/chaisql/chai/internal/stream" "github.com/chaisql/chai/internal/stream/rows" @@ -23,32 +22,32 @@ func TestAggregate(t *testing.T) { name string groupBy expr.Expr builders []expr.AggregatorBuilder - in []types.Object - want []types.Object + in []int + want []row.Row fails bool }{ { "fake count", nil, makeAggregatorBuilders("agg"), - []types.Object{testutil.MakeObject(t, `{"a": 10}`)}, - []types.Object{testutil.MakeObject(t, `{"agg": 1}`)}, + []int{10}, + []row.Row{testutil.MakeRow(t, `{"agg": 1}`)}, false, }, { "count", nil, []expr.AggregatorBuilder{functions.NewCount(expr.Wildcard{})}, - []types.Object{testutil.MakeObject(t, `{"a": 10}`)}, - []types.Object{testutil.MakeObject(t, `{"COUNT(*)": 1}`)}, + []int{10}, + []row.Row{testutil.MakeRow(t, `{"COUNT(*)": 1}`)}, false, }, { "count/groupBy", parser.MustParseExpr("a % 2"), []expr.AggregatorBuilder{&functions.Count{Expr: parser.MustParseExpr("a")}, 
&functions.Avg{Expr: parser.MustParseExpr("a")}}, - generateSeqDocs(t, 10), - []types.Object{testutil.MakeObject(t, `{"a % 2": 0, "COUNT(a)": 5, "AVG(a)": 4.0}`), testutil.MakeObject(t, `{"a % 2": 1, "COUNT(a)": 5, "AVG(a)": 5.0}`)}, + generateSeq(t, 10), + []row.Row{testutil.MakeRow(t, `{"a % 2": 0, "COUNT(a)": 5, "AVG(a)": 4.0}`), testutil.MakeRow(t, `{"a % 2": 1, "COUNT(a)": 5, "AVG(a)": 5.0}`)}, false, }, { @@ -56,15 +55,15 @@ func TestAggregate(t *testing.T) { nil, []expr.AggregatorBuilder{&functions.Count{Expr: parser.MustParseExpr("a")}, &functions.Avg{Expr: parser.MustParseExpr("a")}}, nil, - []types.Object{testutil.MakeObject(t, `{"COUNT(a)": 0, "AVG(a)": 0.0}`)}, + []row.Row{testutil.MakeRow(t, `{"COUNT(a)": 0, "AVG(a)": 0.0}`)}, false, }, { "no aggregator", parser.MustParseExpr("a % 2"), nil, - generateSeqDocs(t, 4), - testutil.MakeObjects(t, `{"a % 2": 0}`, `{"a % 2": 1}`), + generateSeq(t, 4), + testutil.MakeRows(t, `{"a % 2": 0}`, `{"a % 2": 1}`), false, }, } @@ -76,8 +75,8 @@ func TestAggregate(t *testing.T) { testutil.MustExec(t, db, tx, "CREATE TABLE test(a int)") - for _, doc := range test.in { - testutil.MustExec(t, db, tx, "INSERT INTO test VALUES ?", environment.Param{Value: doc}) + for _, val := range test.in { + testutil.MustExec(t, db, tx, "INSERT INTO test VALUES (?)", environment.Param{Value: val}) } var env environment.Environment @@ -91,12 +90,12 @@ func TestAggregate(t *testing.T) { s = s.Pipe(rows.GroupAggregate(test.groupBy, test.builders...)) - var got []types.Object + var got []row.Row err := s.Iterate(&env, func(env *environment.Environment) error { r, ok := env.GetRow() require.True(t, ok) - var fb object.FieldBuffer - fb.Copy(r.Object()) + var fb row.ColumnBuffer + fb.Copy(r) got = append(got, &fb) return nil }) @@ -105,7 +104,7 @@ func TestAggregate(t *testing.T) { } else { assert.NoError(t, err) for i, doc := range test.want { - testutil.RequireObjEqual(t, doc, got[i]) + testutil.RequireRowEqual(t, doc, got[i]) } require.Equal(t, len(test.want), len(got)) @@ -126,7 +125,7 @@ type fakeAggregator struct { } func (f *fakeAggregator) Eval(env *environment.Environment) (types.Value, error) { - return types.NewIntegerValue(f.count), nil + return types.NewBigintValue(f.count), nil } func (f *fakeAggregator) Aggregate(env *environment.Environment) error { @@ -168,12 +167,12 @@ func makeAggregatorBuilders(names ...string) []expr.AggregatorBuilder { return aggs } -func generateSeqDocs(t testing.TB, max int) (docs []types.Object) { +func generateSeq(t testing.TB, max int) (vals []int) { t.Helper() for i := 0; i < max; i++ { - docs = append(docs, testutil.MakeObject(t, `{"a": `+strconv.Itoa(i)+`}`)) + vals = append(vals, i) } - return docs + return vals } diff --git a/internal/stream/rows/project.go b/internal/stream/rows/project.go index a5464e1cb..b1744db91 100644 --- a/internal/stream/rows/project.go +++ b/internal/stream/rows/project.go @@ -6,7 +6,7 @@ import ( "github.com/chaisql/chai/internal/environment" "github.com/chaisql/chai/internal/expr" - "github.com/chaisql/chai/internal/object" + "github.com/chaisql/chai/internal/row" "github.com/chaisql/chai/internal/stream" "github.com/chaisql/chai/internal/tree" "github.com/chaisql/chai/internal/types" @@ -38,10 +38,10 @@ func (op *ProjectOperator) Iterate(in *environment.Environment, f func(out *envi } return op.Prev.Iterate(in, func(env *environment.Environment) error { - row, ok := env.GetRow() + r, ok := env.GetDatabaseRow() if ok { - mask.tableName = row.TableName() - mask.key = row.Key() + mask.tableName = 
r.TableName() + mask.key = r.Key() } mask.Env = env mask.Exprs = op.Exprs @@ -76,19 +76,11 @@ func (m *RowMask) Key() *tree.Key { return m.key } -func (m *RowMask) Object() types.Object { - return m -} - func (m *RowMask) TableName() string { return m.tableName } func (m *RowMask) Get(column string) (v types.Value, err error) { - return m.GetByField(column) -} - -func (m *RowMask) GetByField(field string) (v types.Value, err error) { for _, e := range m.Exprs { if _, ok := e.(expr.Wildcard); ok { r, ok := m.Env.GetRow() @@ -96,23 +88,27 @@ func (m *RowMask) GetByField(field string) (v types.Value, err error) { continue } - v, err = r.Get(field) - if errors.Is(err, types.ErrFieldNotFound) { + v, err = r.Get(column) + if errors.Is(err, types.ErrColumnNotFound) { continue } return } - if ne, ok := e.(*expr.NamedExpr); ok && ne.Name() == field { + if ne, ok := e.(*expr.NamedExpr); ok && ne.Name() == column { return e.Eval(m.Env) } - if e.(fmt.Stringer).String() == field { + if col, ok := e.(expr.Column); ok && col.String() == column { + return e.Eval(m.Env) + } + + if e.(fmt.Stringer).String() == column { return e.Eval(m.Env) } } - err = types.ErrFieldNotFound + err = errors.Wrapf(types.ErrColumnNotFound, "%s not found", column) return } @@ -126,7 +122,7 @@ func (m *RowMask) Iterate(fn func(field string, value types.Value) error) error err := r.Iterate(fn) if err != nil { - return err + return errors.Wrap(err, "wildcard iteration") } continue @@ -153,11 +149,6 @@ func (m *RowMask) Iterate(fn func(field string, value types.Value) error) error return nil } -func (m *RowMask) String() string { - b, _ := types.NewObjectValue(m).MarshalText() - return string(b) -} - -func (d *RowMask) MarshalJSON() ([]byte, error) { - return object.MarshalJSON(d) +func (m *RowMask) MarshalJSON() ([]byte, error) { + return row.MarshalJSON(m) } diff --git a/internal/stream/rows/project_test.go b/internal/stream/rows/project_test.go index a8adcc6c4..5726189e3 100644 --- a/internal/stream/rows/project_test.go +++ b/internal/stream/rows/project_test.go @@ -6,7 +6,7 @@ import ( "github.com/chaisql/chai/internal/environment" "github.com/chaisql/chai/internal/expr" - "github.com/chaisql/chai/internal/object" + "github.com/chaisql/chai/internal/row" "github.com/chaisql/chai/internal/sql/parser" "github.com/chaisql/chai/internal/stream/rows" "github.com/chaisql/chai/internal/testutil" @@ -19,35 +19,35 @@ func TestProject(t *testing.T) { tests := []struct { name string exprs []expr.Expr - in types.Object + in row.Row out string fails bool }{ { "Constant", []expr.Expr{parser.MustParseExpr("10")}, - testutil.MakeObject(t, `{"a":1,"b":[true]}`), + testutil.MakeRow(t, `{"a":1,"b":true}`), `{"10":10}`, false, }, { "Wildcard", []expr.Expr{expr.Wildcard{}}, - testutil.MakeObject(t, `{"a":1,"b":[true]}`), - `{"a":1,"b":[true]}`, + testutil.MakeRow(t, `{"a":1,"b":true}`), + `{"a":1,"b":true}`, false, }, { "Multiple", []expr.Expr{expr.Wildcard{}, expr.Wildcard{}, parser.MustParseExpr("10")}, - testutil.MakeObject(t, `{"a":1,"b":[true]}`), - `{"a":1,"b":[true],"a":1,"b":[true],"10":10}`, + testutil.MakeRow(t, `{"a":1,"b":true}`), + `{"a":1,"b":true,"a":1,"b":true,"10":10}`, false, }, { "Named", []expr.Expr{&expr.NamedExpr{Expr: parser.MustParseExpr("10"), ExprName: "foo"}}, - testutil.MakeObject(t, `{"a":1,"b":[true]}`), + testutil.MakeRow(t, `{"a":1,"b":true}`), `{"foo":10}`, false, }, @@ -56,13 +56,13 @@ func TestProject(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { var inEnv 
environment.Environment - inEnv.SetRowFromObject(test.in) + inEnv.SetRow(test.in) err := rows.Project(test.exprs...).Iterate(&inEnv, func(out *environment.Environment) error { require.Equal(t, &inEnv, out.GetOuter()) r, ok := out.GetRow() require.True(t, ok) - tt, err := json.Marshal(types.NewObjectValue(r.Object())) + tt, err := json.Marshal(r) require.NoError(t, err) require.JSONEq(t, test.out, string(tt)) @@ -96,7 +96,7 @@ func TestProject(t *testing.T) { rows.Project(parser.MustParseExpr("1 + 1")).Iterate(new(environment.Environment), func(out *environment.Environment) error { r, ok := out.GetRow() require.True(t, ok) - enc, err := object.MarshalJSON(r.Object()) + enc, err := row.MarshalJSON(r) assert.NoError(t, err) require.JSONEq(t, `{"1 + 1": 2}`, string(enc)) return nil diff --git a/internal/stream/rows/skip.go b/internal/stream/rows/skip.go index 8a362d6d8..95202f7c0 100644 --- a/internal/stream/rows/skip.go +++ b/internal/stream/rows/skip.go @@ -5,7 +5,6 @@ import ( "github.com/chaisql/chai/internal/environment" "github.com/chaisql/chai/internal/expr" - "github.com/chaisql/chai/internal/object" "github.com/chaisql/chai/internal/stream" "github.com/chaisql/chai/internal/types" ) @@ -32,7 +31,7 @@ func (op *SkipOperator) Iterate(in *environment.Environment, f func(out *environ return fmt.Errorf("offset expression must evaluate to a number, got %q", v.Type()) } - v, err = object.CastAsInteger(v) + v, err = v.CastAs(types.TypeBigint) if err != nil { return err } diff --git a/internal/stream/rows/take.go b/internal/stream/rows/take.go index b976b5d25..f86d86235 100644 --- a/internal/stream/rows/take.go +++ b/internal/stream/rows/take.go @@ -5,7 +5,6 @@ import ( "github.com/chaisql/chai/internal/environment" "github.com/chaisql/chai/internal/expr" - "github.com/chaisql/chai/internal/object" "github.com/chaisql/chai/internal/stream" "github.com/chaisql/chai/internal/types" "github.com/cockroachdb/errors" @@ -33,7 +32,7 @@ func (op *TakeOperator) Iterate(in *environment.Environment, f func(out *environ return fmt.Errorf("limit expression must evaluate to a number, got %q", v.Type()) } - v, err = object.CastAsInteger(v) + v, err = v.CastAs(types.TypeBigint) if err != nil { return err } diff --git a/internal/stream/rows/temp_tree_sort.go b/internal/stream/rows/temp_tree_sort.go index 9f76a9025..32ea28b05 100644 --- a/internal/stream/rows/temp_tree_sort.go +++ b/internal/stream/rows/temp_tree_sort.go @@ -4,9 +4,9 @@ import ( "fmt" "github.com/chaisql/chai/internal/database" - "github.com/chaisql/chai/internal/encoding" "github.com/chaisql/chai/internal/environment" "github.com/chaisql/chai/internal/expr" + "github.com/chaisql/chai/internal/row" "github.com/chaisql/chai/internal/stream" "github.com/chaisql/chai/internal/tree" "github.com/chaisql/chai/internal/types" @@ -47,56 +47,54 @@ func (op *TempTreeSortOperator) Iterate(in *environment.Environment, fn func(out var buf []byte err = op.Prev.Iterate(in, func(out *environment.Environment) error { buf = buf[:0] + // evaluate the sort expression v, err := op.Expr.Eval(out) if err != nil { - return err + if !errors.Is(err, types.ErrColumnNotFound) { + return err + } + + v = nil } - if types.IsNull(v) { + if v == nil { // the expression might be pointing to the original row. v, err = op.Expr.Eval(out.Outer) if err != nil { - // the only valid error here is a missing field. - if !errors.Is(err, types.ErrFieldNotFound) { + // the only valid error here is a missing column. 
+ if !errors.Is(err, types.ErrColumnNotFound) { return err } } } - row, ok := out.GetRow() + r, ok := out.GetDatabaseRow() if !ok { return errors.New("missing row") } - var info *database.TableInfo - if row.TableName() != "" { - info, err = catalog.GetTableInfo(row.TableName()) - if err != nil { - return err - } + // TODO: we should find a way to encode using the table info. - buf, err = info.EncodeObject(in.GetTx(), buf, row.Object()) - if err != nil { - return err - } - } else { - buf, err = encoding.EncodeObject(buf, row.Object()) - if err != nil { - return err - } + buf, err = encodeTempRow(buf, r) + if err != nil { + return errors.Wrap(err, "failed to encode row") } var encKey []byte - key := row.Key() + key := r.Key() if key != nil { + info, err := catalog.GetTableInfo(r.TableName()) + if err != nil { + return err + } encKey, err = info.EncodeKey(key) if err != nil { return err } } - tk := tree.NewKey(v, types.NewTextValue(row.TableName()), types.NewBlobValue(encKey), types.NewIntegerValue(counter)) + tk := tree.NewKey(v, types.NewTextValue(r.TableName()), types.NewBlobValue(encKey), types.NewBigintValue(counter)) counter++ @@ -127,19 +125,9 @@ func (op *TempTreeSortOperator) Iterate(in *environment.Environment, fn func(out key = tree.NewEncodedKey(types.AsByteSlice(kf)) } - var obj types.Object - - if tableName != "" { - info, err := catalog.GetTableInfo(tableName) - if err != nil { - return err - } - obj = database.NewEncodedObject(&info.FieldConstraints, data) - } else { - obj = encoding.DecodeObject(data, false /* intAsDouble */) - } + r := decodeTempRow(data) - br.ResetWith(tableName, key, obj) + br.ResetWith(tableName, key, r) newEnv.SetRow(&br) @@ -154,3 +142,34 @@ func (op *TempTreeSortOperator) String() string { return fmt.Sprintf("rows.TempTreeSort(%s)", op.Expr) } + +func encodeTempRow(buf []byte, r row.Row) ([]byte, error) { + var values []types.Value + err := r.Iterate(func(column string, v types.Value) error { + values = append(values, types.NewTextValue(column)) + values = append(values, types.NewIntegerValue(int32(v.Type()))) + values = append(values, v) + return nil + }) + if err != nil { + return nil, errors.Wrap(err, "failed to iterate row") + } + + return types.EncodeValuesAsKey(buf, values...) 
+} + +func decodeTempRow(b []byte) row.Row { + cb := row.NewColumnBuffer() + + for len(b) > 0 { + colv, n := types.DecodeValue(b) + b = b[n:] + typev, n := types.DecodeValue(b) + b = b[n:] + v, n := types.Type(types.AsInt32(typev)).Def().Decode(b) + cb.Add(types.AsString(colv), v) + b = b[n:] + } + + return cb +} diff --git a/internal/stream/rows/temp_tree_sort_test.go b/internal/stream/rows/temp_tree_sort_test.go index ce56a5a77..5eb16805b 100644 --- a/internal/stream/rows/temp_tree_sort_test.go +++ b/internal/stream/rows/temp_tree_sort_test.go @@ -3,16 +3,16 @@ package rows_test import ( "testing" + "github.com/chaisql/chai/internal/database" "github.com/chaisql/chai/internal/environment" "github.com/chaisql/chai/internal/expr" - "github.com/chaisql/chai/internal/object" + "github.com/chaisql/chai/internal/row" "github.com/chaisql/chai/internal/sql/parser" "github.com/chaisql/chai/internal/stream" "github.com/chaisql/chai/internal/stream/rows" "github.com/chaisql/chai/internal/stream/table" "github.com/chaisql/chai/internal/testutil" "github.com/chaisql/chai/internal/testutil/assert" - "github.com/chaisql/chai/internal/types" "github.com/stretchr/testify/require" ) @@ -20,23 +20,19 @@ func TestTempTreeSort(t *testing.T) { tests := []struct { name string sortExpr expr.Expr - values []types.Object - want []types.Object + values []any + want []row.Row fails bool desc bool }{ { "ASC", parser.MustParseExpr("a"), - []types.Object{ - testutil.MakeObject(t, `{"a": 0}`), - testutil.MakeObject(t, `{"a": null}`), - testutil.MakeObject(t, `{"a": true}`), - }, - []types.Object{ - testutil.MakeObject(t, `{}`), - testutil.MakeObject(t, `{"a": 0}`), - testutil.MakeObject(t, `{"a": 1}`), + []any{0, nil, true}, + []row.Row{ + testutil.MakeRow(t, `{"a": null}`), + testutil.MakeRow(t, `{"a": 0}`), + testutil.MakeRow(t, `{"a": 1}`), }, false, false, @@ -44,15 +40,11 @@ func TestTempTreeSort(t *testing.T) { { "DESC", parser.MustParseExpr("a"), - []types.Object{ - testutil.MakeObject(t, `{"a": 0}`), - testutil.MakeObject(t, `{"a": null}`), - testutil.MakeObject(t, `{"a": true}`), - }, - []types.Object{ - testutil.MakeObject(t, `{"a": 1}`), - testutil.MakeObject(t, `{"a": 0}`), - testutil.MakeObject(t, `{}`), + []any{0, nil, true}, + []row.Row{ + testutil.MakeRow(t, `{"a": 1}`), + testutil.MakeRow(t, `{"a": 0}`), + testutil.MakeRow(t, `{"a": null}`), }, false, true, @@ -66,10 +58,18 @@ func TestTempTreeSort(t *testing.T) { testutil.MustExec(t, db, tx, "CREATE TABLE test(a int)") - for _, doc := range test.values { - testutil.MustExec(t, db, tx, "INSERT INTO test VALUES ?", environment.Param{Value: doc}) + for _, val := range test.values { + testutil.MustExec(t, db, tx, "INSERT INTO test VALUES (?)", environment.Param{Value: val}) } + testutil.MustQuery(t, db, tx, "SELECT * FROM test").Iterate(func(r database.Row) error { + d, err := r.MarshalJSON() + require.NoError(t, err) + + t.Log(string(d)) + return nil + + }) var env environment.Environment env.DB = db env.Tx = tx @@ -81,13 +81,13 @@ func TestTempTreeSort(t *testing.T) { s = s.Pipe(rows.TempTreeSort(test.sortExpr)) } - var got []types.Object + var got []row.Row err := s.Iterate(&env, func(env *environment.Environment) error { r, ok := env.GetRow() require.True(t, ok) - fb := object.NewFieldBuffer() - fb.Copy(r.Object()) + fb := row.NewColumnBuffer() + fb.Copy(r) got = append(got, fb) return nil }) @@ -98,7 +98,7 @@ func TestTempTreeSort(t *testing.T) { assert.NoError(t, err) require.Equal(t, len(got), len(test.want)) for i := range got { - 
testutil.RequireObjEqual(t, test.want[i], got[i]) + testutil.RequireRowEqual(t, test.want[i], got[i]) } } }) diff --git a/internal/stream/stream.go b/internal/stream/stream.go index bd9c1e970..11a15a10c 100644 --- a/internal/stream/stream.go +++ b/internal/stream/stream.go @@ -118,7 +118,7 @@ type DiscardOperator struct { BaseOperator } -// Discard is an operator that doesn't produce any object. +// Discard is an operator that doesn't produce any row. // It iterates over the previous operator and discards all the objects. func Discard() *DiscardOperator { return &DiscardOperator{} diff --git a/internal/stream/stream_test.go b/internal/stream/stream_test.go index 7bfaa9d8c..cfcdaad40 100644 --- a/internal/stream/stream_test.go +++ b/internal/stream/stream_test.go @@ -7,20 +7,19 @@ import ( "github.com/chaisql/chai/internal/environment" "github.com/chaisql/chai/internal/expr" - "github.com/chaisql/chai/internal/object" + "github.com/chaisql/chai/internal/row" "github.com/chaisql/chai/internal/sql/parser" "github.com/chaisql/chai/internal/stream" "github.com/chaisql/chai/internal/stream/rows" "github.com/chaisql/chai/internal/testutil" "github.com/chaisql/chai/internal/testutil/assert" - "github.com/chaisql/chai/internal/types" "github.com/stretchr/testify/require" ) func TestStream(t *testing.T) { s := stream.New(rows.Emit( - testutil.ParseExpr(t, `{"a": 1}`), - testutil.ParseExpr(t, `{"a": 2}`), + testutil.MakeRowExpr(t, `{"a": 1}`), + testutil.MakeRowExpr(t, `{"a": 2}`), )) s = s.Pipe(rows.Filter(parser.MustParseExpr("a > 1"))) @@ -43,40 +42,35 @@ func TestStream(t *testing.T) { func TestUnion(t *testing.T) { tests := []struct { name string - first, second, third []expr.Expr - expected testutil.Objs - fails bool + first, second, third []expr.Row + expected testutil.Rows }{ { - "same docs", - testutil.ParseExprs(t, `{"a": 1, "b": 1}`, `{"a": 2, "b": 2}`), - testutil.ParseExprs(t, `{"a": 1, "b": 1}`, `{"a": 2, "b": 2}`), - testutil.ParseExprs(t, `{"a": 1, "b": 1}`, `{"a": 2, "b": 2}`), - testutil.MakeObjects(t, `{"a": 1, "b": 1}`, `{"a": 2, "b": 2}`), - false, + "same rows", + testutil.MakeRowExprs(t, `{"a": 1, "b": 1}`, `{"a": 2, "b": 2}`), + testutil.MakeRowExprs(t, `{"a": 1, "b": 1}`, `{"a": 2, "b": 2}`), + testutil.MakeRowExprs(t, `{"a": 1, "b": 1}`, `{"a": 2, "b": 2}`), + testutil.MakeRows(t, `{"a": 1, "b": 1}`, `{"a": 2, "b": 2}`), }, { - "different docs", - testutil.ParseExprs(t, `{"a": 1, "b": 1}`, `{"a": 1, "b": 2}`), - testutil.ParseExprs(t, `{"a": 2, "b": 1}`, `{"a": 2, "b": 2}`), - testutil.ParseExprs(t, `{"a": 3, "b": 1}`, `{"a": 3, "b": 2}`), - testutil.MakeObjects(t, `{"a": 1, "b": 1}`, `{"a": 1, "b": 2}`, `{"a": 2, "b": 1}`, `{"a": 2, "b": 2}`, `{"a": 3, "b": 1}`, `{"a": 3, "b": 2}`), - false, + "different rows", + testutil.MakeRowExprs(t, `{"a": 1, "b": 1}`, `{"a": 1, "b": 2}`), + testutil.MakeRowExprs(t, `{"a": 2, "b": 1}`, `{"a": 2, "b": 2}`), + testutil.MakeRowExprs(t, `{"a": 3, "b": 1}`, `{"a": 3, "b": 2}`), + testutil.MakeRows(t, `{"a": 1, "b": 1}`, `{"a": 1, "b": 2}`, `{"a": 2, "b": 1}`, `{"a": 2, "b": 2}`, `{"a": 3, "b": 1}`, `{"a": 3, "b": 2}`), }, { "mixed", - testutil.ParseExprs(t, `{"a": 1}`, `{"a": 1}`, `{"a": 2}`), - testutil.ParseExprs(t, `{"a": 1}`, `{"a": 1}`, `{"a": 2}`), - testutil.ParseExprs(t, `{"a": 1}`, `{"a": 1}`, `{"a": 2}`), - testutil.MakeObjects(t, `{"a": 1}`, `{"a": 2}`), - false, + testutil.MakeRowExprs(t, `{"a": 1}`, `{"a": 1}`, `{"a": 2}`), + testutil.MakeRowExprs(t, `{"a": 1}`, `{"a": 1}`, `{"a": 2}`), + testutil.MakeRowExprs(t, `{"a": 1}`, 
`{"a": 1}`, `{"a": 2}`), + testutil.MakeRows(t, `{"a": 1}`, `{"a": 2}`), }, { "only one", - testutil.ParseExprs(t, `{"a": 1}`, `{"a": 1}`, `{"a": 2}`), + testutil.MakeRowExprs(t, `{"a": 1}`, `{"a": 1}`, `{"a": 2}`), nil, nil, - testutil.MakeObjects(t, `{"a": 1}`, `{"a": 2}`), - false, + testutil.MakeRows(t, `{"a": 1}`, `{"a": 2}`), }, } @@ -101,57 +95,36 @@ func TestUnion(t *testing.T) { env.Tx = tx env.DB = db - var i int - var got testutil.Objs - err := st.Iterate(&env, func(env *environment.Environment) error { - r, ok := env.GetRow() - require.True(t, ok) - - clone, err := object.CloneValue(types.NewObjectValue(r.Object())) - if err != nil { - return err - } - - got = append(got, types.AsObject(clone)) - i++ - return nil - }) - if test.fails { - assert.Error(t, err) - } else { - assert.NoError(t, err) - require.Equal(t, len(test.expected), i) - test.expected.RequireEqual(t, got) - } + test.expected.RequireEqualStream(t, &env, st) }) } t.Run("String", func(t *testing.T) { st := stream.New(stream.Union( - stream.New(rows.Emit(testutil.ParseExprs(t, `{"a": 1}`, `{"a": 2}`)...)), - stream.New(rows.Emit(testutil.ParseExprs(t, `{"a": 3}`, `{"a": 4}`)...)), - stream.New(rows.Emit(testutil.ParseExprs(t, `{"a": 5}`, `{"a": 6}`)...)), + stream.New(rows.Emit(testutil.MakeRowExprs(t, `{"a": 1}`, `{"a": 2}`)...)), + stream.New(rows.Emit(testutil.MakeRowExprs(t, `{"a": 3}`, `{"a": 4}`)...)), + stream.New(rows.Emit(testutil.MakeRowExprs(t, `{"a": 5}`, `{"a": 6}`)...)), )) - require.Equal(t, `union(rows.Emit({a: 1}, {a: 2}), rows.Emit({a: 3}, {a: 4}), rows.Emit({a: 5}, {a: 6}))`, st.String()) + require.Equal(t, `union(rows.Emit((1), (2)), rows.Emit((3), (4)), rows.Emit((5), (6)))`, st.String()) }) } func TestConcatOperator(t *testing.T) { - in1 := testutil.ParseExprs(t, `{"a": 10}`, `{"a": 11}`) - in2 := testutil.ParseExprs(t, `{"a": 12}`, `{"a": 13}`) + in1 := testutil.MakeRowExprs(t, `{"a": 10}`, `{"a": 11}`) + in2 := testutil.MakeRowExprs(t, `{"a": 12}`, `{"a": 13}`) s1 := stream.New(rows.Emit(in1...)) s2 := stream.New(rows.Emit(in2...)) s := stream.Concat(s1, s2) - var got []types.Object + var got []row.Row s.Iterate(new(environment.Environment), func(env *environment.Environment) error { r, ok := env.GetRow() require.True(t, ok) - var fb object.FieldBuffer - err := fb.Copy(r.Object()) + var fb row.ColumnBuffer + err := fb.Copy(r) if err != nil { return err } @@ -161,8 +134,7 @@ func TestConcatOperator(t *testing.T) { want := append(in1, in2...) 
for i, w := range want { - v, _ := w.Eval(new(environment.Environment)) - d := types.AsObject(v) - testutil.RequireObjEqual(t, d, got[i]) + r, _ := w.Eval(new(environment.Environment)) + testutil.RequireRowEqual(t, r, got[i]) } } diff --git a/internal/stream/table/delete.go b/internal/stream/table/delete.go index cd2f1247e..308174bc7 100644 --- a/internal/stream/table/delete.go +++ b/internal/stream/table/delete.go @@ -33,7 +33,7 @@ func (op *DeleteOperator) Iterate(in *environment.Environment, f func(out *envir } } - r, ok := out.GetRow() + r, ok := out.GetDatabaseRow() if !ok { return errors.New("missing row") } diff --git a/internal/stream/table/insert.go b/internal/stream/table/insert.go index 610d83345..c001c4149 100644 --- a/internal/stream/table/insert.go +++ b/internal/stream/table/insert.go @@ -41,7 +41,7 @@ func (op *InsertOperator) Iterate(in *environment.Environment, f func(out *envir } } - _, r, err = table.Insert(r.Object()) + _, r, err = table.Insert(r) if err != nil { return err } diff --git a/internal/stream/table/replace.go b/internal/stream/table/replace.go index d27335554..63484c95c 100644 --- a/internal/stream/table/replace.go +++ b/internal/stream/table/replace.go @@ -25,7 +25,7 @@ func (op *ReplaceOperator) Iterate(in *environment.Environment, f func(out *envi var table *database.Table it := func(out *environment.Environment) error { - r, ok := out.GetRow() + r, ok := out.GetDatabaseRow() if !ok { return errors.New("missing row") } @@ -38,7 +38,7 @@ func (op *ReplaceOperator) Iterate(in *environment.Environment, f func(out *envi } } - _, err := table.Replace(r.Key(), r.Object()) + _, err := table.Replace(r.Key(), r) if err != nil { return err } diff --git a/internal/stream/table/table_test.go b/internal/stream/table/table_test.go index ed4c17114..cb91f92aa 100644 --- a/internal/stream/table/table_test.go +++ b/internal/stream/table/table_test.go @@ -4,7 +4,7 @@ import ( "testing" "github.com/chaisql/chai/internal/environment" - "github.com/chaisql/chai/internal/object" + "github.com/chaisql/chai/internal/row" "github.com/chaisql/chai/internal/stream" "github.com/chaisql/chai/internal/stream/table" "github.com/chaisql/chai/internal/testutil" @@ -16,114 +16,114 @@ import ( func TestTableScan(t *testing.T) { tests := []struct { name string - docsInTable, expected testutil.Objs + docsInTable, expected testutil.Rows ranges stream.Ranges reverse bool fails bool }{ { "no-range", - testutil.MakeObjects(t, `{"a": 1}`, `{"a": 2}`), - testutil.MakeObjects(t, `{"a": 1}`, `{"a": 2}`), + testutil.MakeRows(t, `{"a": 1}`, `{"a": 2}`), + testutil.MakeRows(t, `{"a": 1}`, `{"a": 2}`), nil, false, false, }, { "no-range:reverse", - testutil.MakeObjects(t, `{"a": 1}`, `{"a": 2}`), - testutil.MakeObjects(t, `{"a": 2}`, `{"a": 1}`), + testutil.MakeRows(t, `{"a": 1}`, `{"a": 2}`), + testutil.MakeRows(t, `{"a": 2}`, `{"a": 1}`), nil, true, false, }, { "max:2", - testutil.MakeObjects(t, `{"a": 1}`, `{"a": 2}`), - testutil.MakeObjects(t, `{"a": 1}`, `{"a": 2}`), + testutil.MakeRows(t, `{"a": 1}`, `{"a": 2}`), + testutil.MakeRows(t, `{"a": 1}`, `{"a": 2}`), stream.Ranges{ - stream.Range{Max: testutil.ExprList(t, `[2]`)}, + stream.Range{Max: testutil.ExprList(t, `(2)`)}, }, false, false, }, { "max:1", - testutil.MakeObjects(t, `{"a": 1}`, `{"a": 2}`), - testutil.MakeObjects(t, `{"a": 1}`), + testutil.MakeRows(t, `{"a": 1}`, `{"a": 2}`), + testutil.MakeRows(t, `{"a": 1}`), stream.Ranges{ - stream.Range{Max: testutil.ExprList(t, `[1]`)}, + stream.Range{Max: testutil.ExprList(t, `(1)`)}, }, false, 
false, }, { "max:1.1", - testutil.MakeObjects(t, `{"a": 1}`, `{"a": 2}`), - testutil.MakeObjects(t, `{"a": 1}`), + testutil.MakeRows(t, `{"a": 1}`, `{"a": 2}`), + nil, stream.Ranges{ - stream.Range{Max: testutil.ExprList(t, `[1.1]`)}, + stream.Range{Max: testutil.ExprList(t, `(1.1)`)}, }, false, false, }, { "min", - testutil.MakeObjects(t, `{"a": 1}`, `{"a": 2}`), - testutil.MakeObjects(t, `{"a": 1}`, `{"a": 2}`), + testutil.MakeRows(t, `{"a": 1}`, `{"a": 2}`), + testutil.MakeRows(t, `{"a": 1}`, `{"a": 2}`), stream.Ranges{ - stream.Range{Min: testutil.ExprList(t, `[1]`)}, + stream.Range{Min: testutil.ExprList(t, `(1)`)}, }, false, false, }, { "min:0.5", - testutil.MakeObjects(t, `{"a": 1}`, `{"a": 2}`), - testutil.MakeObjects(t, `{"a": 1}`, `{"a": 2}`), + testutil.MakeRows(t, `{"a": 1}`, `{"a": 2}`), + nil, stream.Ranges{ - stream.Range{Min: testutil.ExprList(t, `[0.5]`)}, + stream.Range{Min: testutil.ExprList(t, `(0.5)`)}, }, false, false, }, { "min/max", - testutil.MakeObjects(t, `{"a": 1}`, `{"a": 2}`), - testutil.MakeObjects(t, `{"a": 1}`, `{"a": 2}`), + testutil.MakeRows(t, `{"a": 1}`, `{"a": 2}`), + testutil.MakeRows(t, `{"a": 1}`, `{"a": 2}`), stream.Ranges{ - stream.Range{Min: testutil.ExprList(t, `[1]`), Max: testutil.ExprList(t, `[2]`)}, + stream.Range{Min: testutil.ExprList(t, `(1)`), Max: testutil.ExprList(t, `(2)`)}, }, false, false, }, { "min/max:0.5/1.5", - testutil.MakeObjects(t, `{"a": 1}`, `{"a": 2}`), - testutil.MakeObjects(t, `{"a": 1}`, `{"a": 2}`), + testutil.MakeRows(t, `{"a": 1}`, `{"a": 2}`), + nil, stream.Ranges{ - stream.Range{Min: testutil.ExprList(t, `[0.5]`), Max: testutil.ExprList(t, `[1.5]`)}, + stream.Range{Min: testutil.ExprList(t, `(0.5)`), Max: testutil.ExprList(t, `(1.5)`)}, }, false, false, }, { "reverse/max", - testutil.MakeObjects(t, `{"a": 1}`, `{"a": 2}`), - testutil.MakeObjects(t, `{"a": 2}`, `{"a": 1}`), + testutil.MakeRows(t, `{"a": 1}`, `{"a": 2}`), + testutil.MakeRows(t, `{"a": 2}`, `{"a": 1}`), stream.Ranges{ - stream.Range{Max: testutil.ExprList(t, `[2]`)}, + stream.Range{Max: testutil.ExprList(t, `(2)`)}, }, true, false, }, { "reverse/min", - testutil.MakeObjects(t, `{"a": 1}`, `{"a": 2}`), - testutil.MakeObjects(t, `{"a": 2}`, `{"a": 1}`), + testutil.MakeRows(t, `{"a": 1}`, `{"a": 2}`), + testutil.MakeRows(t, `{"a": 2}`, `{"a": 1}`), stream.Ranges{ - stream.Range{Min: testutil.ExprList(t, `[1]`)}, + stream.Range{Min: testutil.ExprList(t, `(1)`)}, }, true, false, }, { "reverse/min/max", - testutil.MakeObjects(t, `{"a": 1}`, `{"a": 2}`), - testutil.MakeObjects(t, `{"a": 2}`, `{"a": 1}`), + testutil.MakeRows(t, `{"a": 1}`, `{"a": 2}`), + testutil.MakeRows(t, `{"a": 2}`, `{"a": 1}`), stream.Ranges{ - stream.Range{Min: testutil.ExprList(t, `[1]`), Max: testutil.ExprList(t, `[2]`)}, + stream.Range{Min: testutil.ExprList(t, `(1)`), Max: testutil.ExprList(t, `(2)`)}, }, true, false, }, @@ -136,8 +136,11 @@ func TestTableScan(t *testing.T) { testutil.MustExec(t, db, tx, "CREATE TABLE test (a INTEGER NOT NULL PRIMARY KEY)") - for _, doc := range test.docsInTable { - testutil.MustExec(t, db, tx, "INSERT INTO test VALUES ?", environment.Param{Value: doc}) + for _, r := range test.docsInTable { + v, err := r.Get("a") + require.NoError(t, err) + + testutil.MustExec(t, db, tx, "INSERT INTO test VALUES (?)", environment.Param{Value: types.AsInt64(v)}) } op := table.Scan("test", test.ranges...) 
@@ -147,19 +150,19 @@ func TestTableScan(t *testing.T) { env.Params = []environment.Param{{Name: "foo", Value: 1}} var i int - var got testutil.Objs + var got testutil.Rows err := op.Iterate(&env, func(env *environment.Environment) error { r, ok := env.GetRow() require.True(t, ok) - var fb object.FieldBuffer + var fb row.ColumnBuffer - err := fb.Copy(r.Object()) + err := fb.Copy(r) assert.NoError(t, err) got = append(got, &fb) v, err := env.GetParamByName("foo") assert.NoError(t, err) - require.Equal(t, types.NewIntegerValue(1), v) + require.Equal(t, types.NewBigintValue(1), v) i++ return nil }) @@ -174,17 +177,17 @@ func TestTableScan(t *testing.T) { } t.Run("String", func(t *testing.T) { - require.Equal(t, `table.Scan("test", [{"min": [1], "max": [2]}])`, table.Scan("test", stream.Range{ - Min: testutil.ExprList(t, `[1]`), Max: testutil.ExprList(t, `[2]`), + require.Equal(t, `table.Scan("test", [{"min": (1), "max": (2)}])`, table.Scan("test", stream.Range{ + Min: testutil.ExprList(t, `(1)`), Max: testutil.ExprList(t, `(2)`), }).String()) op := table.Scan("test", - stream.Range{Min: testutil.ExprList(t, `[1]`), Max: testutil.ExprList(t, `[2]`), Exclusive: true}, - stream.Range{Min: testutil.ExprList(t, `[10]`), Exact: true}, - stream.Range{Min: testutil.ExprList(t, `[100]`)}, + stream.Range{Min: testutil.ExprList(t, `(1)`), Max: testutil.ExprList(t, `(2)`), Exclusive: true}, + stream.Range{Min: testutil.ExprList(t, `(10)`), Exact: true}, + stream.Range{Min: testutil.ExprList(t, `(100)`)}, ) op.Reverse = true - require.Equal(t, `table.ScanReverse("test", [{"min": [1], "max": [2], "exclusive": true}, {"min": [10], "exact": true}, {"min": [100]}])`, op.String()) + require.Equal(t, `table.ScanReverse("test", [{"min": (1), "max": (2), "exclusive": true}, {"min": (10), "exact": true}, {"min": (100)}])`, op.String()) }) } diff --git a/internal/stream/table/validate.go b/internal/stream/table/validate.go index 0e90b6122..57eaf17fb 100644 --- a/internal/stream/table/validate.go +++ b/internal/stream/table/validate.go @@ -9,7 +9,7 @@ import ( "github.com/cockroachdb/errors" ) -// ValidateOperator validates and converts incoming rows against table and field constraints. +// ValidateOperator validates and converts incoming rows against table and column constraints. 
type ValidateOperator struct { stream.BaseOperator @@ -38,7 +38,7 @@ func (op *ValidateOperator) Iterate(in *environment.Environment, fn func(out *en var newEnv environment.Environment var br database.BasicRow - var eo database.EncodedObject + var eo database.EncodedRow return op.Prev.Iterate(in, func(out *environment.Environment) error { buf = buf[:0] newEnv.SetOuter(out) @@ -49,16 +49,21 @@ func (op *ValidateOperator) Iterate(in *environment.Environment, fn func(out *en } // generate default values, validate and encode row - buf, err = info.EncodeObject(tx, buf, row.Object()) + buf, err = info.EncodeRow(tx, buf, row) if err != nil { return err } // use the encoded row as the new row - eo.ResetWith(&info.FieldConstraints, buf) - - br.ResetWith(row.TableName(), row.Key(), &eo) - newEnv.SetRow(&br) + eo.ResetWith(&info.ColumnConstraints, buf) + + if dRow, ok := row.(database.Row); ok { + br.ResetWith(op.tableName, dRow.Key(), &eo) + newEnv.SetRow(&br) + } else { + br.ResetWith(op.tableName, nil, &eo) + newEnv.SetRow(&br) + } // validate CHECK constraints if any err := info.TableConstraints.ValidateRow(tx, newEnv.Row) diff --git a/internal/stream/union.go b/internal/stream/union.go index cca56c7df..24d8e5d2f 100644 --- a/internal/stream/union.go +++ b/internal/stream/union.go @@ -5,9 +5,8 @@ import ( "strings" "github.com/chaisql/chai/internal/database" - "github.com/chaisql/chai/internal/encoding" "github.com/chaisql/chai/internal/environment" - "github.com/chaisql/chai/internal/object" + "github.com/chaisql/chai/internal/row" "github.com/chaisql/chai/internal/tree" "github.com/chaisql/chai/internal/types" ) @@ -31,7 +30,7 @@ func (it *UnionOperator) Iterate(in *environment.Environment, fn func(out *envir defer func() { if cleanup != nil { e := cleanup() - if err != nil { + if err == nil { err = e } } @@ -39,15 +38,13 @@ func (it *UnionOperator) Iterate(in *environment.Environment, fn func(out *envir // iterate over all the streams and insert each key in the temporary table // to deduplicate them - fb := object.NewFieldBuffer() var buf []byte for _, s := range it.Streams { err := s.Iterate(in, func(out *environment.Environment) error { - fb.Reset() buf = buf[:0] - row, ok := out.GetRow() + r, ok := out.GetRow() if !ok { return errors.New("missing row") } @@ -62,27 +59,30 @@ func (it *UnionOperator) Iterate(in *environment.Environment, fn func(out *envir } } - key := tree.NewKey(types.NewObjectValue(row.Object())) + var tableName string + var encKey []byte - if row.Key() != nil { + if dr, ok := r.(database.Row); ok { // encode the row key and table name as the value - info, err := in.GetTx().Catalog.GetTableInfo(row.TableName()) - if err != nil { - return err - } - encKey, err := info.EncodeKey(row.Key()) + tableName = dr.TableName() + + info, err := in.GetTx().Catalog.GetTableInfo(tableName) if err != nil { return err } - fb.Add("key", types.NewBlobValue(encKey)) - fb.Add("table", types.NewTextValue(row.TableName())) - buf, err = encoding.EncodeObject(buf, fb) + encKey, err = info.EncodeKey(dr.Key()) if err != nil { return err } } + key := tree.NewKey(row.Flatten(r)...) 
+ buf, err = types.EncodeValuesAsKey(buf, types.NewBlobValue(encKey), types.NewTextValue(tableName)) + if err != nil { + return err + } + err = temp.Put(key, buf) if err == nil || errors.Is(err, database.ErrIndexDuplicateValue) { return nil @@ -102,11 +102,9 @@ func (it *UnionOperator) Iterate(in *environment.Environment, fn func(out *envir var newEnv environment.Environment newEnv.SetOuter(in) - var vb object.ValueBuffer var basicRow database.BasicRow // iterate over the temporary index return temp.IterateOnRange(nil, false, func(key *tree.Key, value []byte) error { - vb.Reset() kv, err := key.Decode() if err != nil { return err @@ -115,20 +113,12 @@ func (it *UnionOperator) Iterate(in *environment.Environment, fn func(out *envir var tableName string var pk *tree.Key - obj := types.AsObject(kv[0]) + obj := row.Unflatten(kv) if len(value) > 1 { - ser := encoding.DecodeObject(value, false) - pkf, err := ser.GetByField("key") - if err != nil { - return err - } - pk = tree.NewEncodedKey(types.AsByteSlice(pkf)) - tf, err := ser.GetByField("table") - if err != nil { - return err - } - tableName = types.AsString(tf) + ser := types.DecodeValues(value) + pk = tree.NewEncodedKey(types.AsByteSlice(ser[0])) + tableName = types.AsString(ser[1]) } basicRow.ResetWith(tableName, pk, obj) diff --git a/internal/testutil/expr.go b/internal/testutil/expr.go index a0f5604c3..da0eabb4d 100644 --- a/internal/testutil/expr.go +++ b/internal/testutil/expr.go @@ -12,7 +12,6 @@ import ( "github.com/chaisql/chai/internal/environment" "github.com/chaisql/chai/internal/expr" "github.com/chaisql/chai/internal/expr/functions" - "github.com/chaisql/chai/internal/object" "github.com/chaisql/chai/internal/sql/parser" "github.com/chaisql/chai/internal/testutil/assert" "github.com/chaisql/chai/internal/testutil/genexprtests" @@ -20,14 +19,9 @@ import ( "github.com/stretchr/testify/require" ) -// BlobValue creates a literal value of type Blob. -func BlobValue(v []byte) expr.LiteralValue { - return expr.LiteralValue{Value: types.NewBlobValue(v)} -} - -// TextValue creates a literal value of type Text. -func TextValue(v string) expr.LiteralValue { - return expr.LiteralValue{Value: types.NewTextValue(v)} +// NullValue creates a literal value of type Null. +func NullValue() expr.LiteralValue { + return expr.LiteralValue{Value: types.NewNullValue()} } // BoolValue creates a literal value of type Bool. @@ -36,28 +30,28 @@ func BoolValue(v bool) expr.LiteralValue { } // IntegerValue creates a literal value of type Integer. -func IntegerValue(v int64) expr.LiteralValue { +func IntegerValue(v int32) expr.LiteralValue { return expr.LiteralValue{Value: types.NewIntegerValue(v)} } +// BigintValue creates a literal value of type Bigint. +func BigintValue(v int64) expr.LiteralValue { + return expr.LiteralValue{Value: types.NewBigintValue(v)} +} + // DoubleValue creates a literal value of type Double. func DoubleValue(v float64) expr.LiteralValue { return expr.LiteralValue{Value: types.NewDoubleValue(v)} } -// NullValue creates a literal value of type Null. -func NullValue() expr.LiteralValue { - return expr.LiteralValue{Value: types.NewNullValue()} -} - -// ObjectValue creates a literal value of type Object. -func ObjectValue(d types.Object) expr.LiteralValue { - return expr.LiteralValue{Value: types.NewObjectValue(d)} +// TextValue creates a literal value of type Text. +func TextValue(v string) expr.LiteralValue { + return expr.LiteralValue{Value: types.NewTextValue(v)} } -// ArrayValue creates a literal value of type Array. 
-func ArrayValue(a types.Array) expr.LiteralValue { - return expr.LiteralValue{Value: types.NewArrayValue(a)} +// BlobValue creates a literal value of type Blob. +func BlobValue(v []byte) expr.LiteralValue { + return expr.LiteralValue{Value: types.NewBlobValue(v)} } func ExprList(t testing.TB, s string) expr.LiteralExprList { @@ -65,32 +59,16 @@ func ExprList(t testing.TB, s string) expr.LiteralExprList { e, err := parser.ParseExpr(s) assert.NoError(t, err) - require.IsType(t, e, expr.LiteralExprList{}) - - return e.(expr.LiteralExprList) -} - -func ParsePath(t testing.TB, p string) expr.Path { - t.Helper() - - return expr.Path(ParseObjectPath(t, p)) -} - -func ParseObjectPath(t testing.TB, p string) object.Path { - t.Helper() - - vp, err := parser.ParsePath(p) - assert.NoError(t, err) - return vp -} - -func ParseObjectPaths(t testing.TB, str ...string) []object.Path { - var paths []object.Path - for _, s := range str { - paths = append(paths, ParseObjectPath(t, s)) + switch e := e.(type) { + case expr.LiteralExprList: + return e + case expr.Parentheses: + return expr.LiteralExprList{e.E} + default: + t.Fatalf("unexpected expression type: %T", e) } - return paths + return nil } func ParseNamedExpr(t testing.TB, s string, name ...string) expr.Expr { @@ -128,6 +106,24 @@ func ParseExprs(t testing.TB, s ...string) []expr.Expr { return ex } +func ParseExprList(t testing.TB, s string) expr.LiteralExprList { + t.Helper() + + e, err := parser.ParseExpr(s) + assert.NoError(t, err) + + switch e := e.(type) { + case expr.LiteralExprList: + return e + case expr.Parentheses: + return expr.LiteralExprList{e.E} + default: + t.Fatalf("unexpected expression type: %T", e) + } + + return e.(expr.LiteralExprList) +} + func TestExpr(t testing.TB, exprStr string, env *environment.Environment, want types.Value, fails bool) { t.Helper() e, err := parser.NewParser(strings.NewReader(exprStr)).ParseExpr() @@ -143,8 +139,7 @@ func TestExpr(t testing.TB, exprStr string, env *environment.Environment, want t func FunctionExpr(t testing.TB, name string, args ...expr.Expr) expr.Expr { t.Helper() - n := strings.Split(name, ".") - def, err := functions.DefaultPackages().GetFunc(n[0], n[1]) + def, err := functions.GetFunc(name) assert.NoError(t, err) require.NotNil(t, def) expr, err := def.Function(args...) @@ -208,7 +203,7 @@ func ExprRunner(t *testing.T, testfile string) { } else { // eval it, it should return an error _, err = e.Eval(env) - require.NotNilf(t, err, "expected expr to return an error at %s:%\n`%s`, got nil", testfile, stmt.ExprLine, stmt.Expr) + require.NotNilf(t, err, "expected expr to return an error at %s:%d\n`%s`, got nil", testfile, stmt.ExprLine, stmt.Expr) require.Regexpf(t, regexp.MustCompile(regexp.QuoteMeta(stmt.Res)), err.Error(), "expected error message to match at %s:%d", testfile, stmt.ResLine) } }) diff --git a/internal/testutil/object.go b/internal/testutil/object.go deleted file mode 100644 index 38138ef4a..000000000 --- a/internal/testutil/object.go +++ /dev/null @@ -1,193 +0,0 @@ -package testutil - -import ( - "bufio" - "encoding/json" - "fmt" - "io" - "os" - "strings" - "testing" - - "github.com/chaisql/chai/internal/database" - "github.com/chaisql/chai/internal/object" - "github.com/chaisql/chai/internal/testutil/assert" - "github.com/chaisql/chai/internal/types" - "github.com/google/go-cmp/cmp" - "github.com/stretchr/testify/require" -) - -// MakeValue turns v into a types.Value. 
-func MakeValue(t testing.TB, v any) types.Value { - t.Helper() - - vv, err := object.NewValue(v) - assert.NoError(t, err) - return vv -} - -func MakeArrayValue(t testing.TB, vs ...any) types.Value { - t.Helper() - - vvs := []types.Value{} - for _, v := range vs { - vvs = append(vvs, MakeValue(t, v)) - } - - vb := object.NewValueBuffer(vvs...) - - return types.NewArrayValue(vb) -} - -// MakeObject creates an object from a json string. -func MakeObject(t testing.TB, jsonDoc string) types.Object { - t.Helper() - - var fb object.FieldBuffer - - err := fb.UnmarshalJSON([]byte(jsonDoc)) - assert.NoError(t, err) - - return &fb -} - -// MakeObjects creates a slice of objects from json strings. -func MakeObjects(t testing.TB, jsonDocs ...string) (docs Objs) { - for _, jsonDoc := range jsonDocs { - docs = append(docs, MakeObject(t, jsonDoc)) - } - return -} - -// MakeArray creates an array from a json string. -func MakeArray(t testing.TB, jsonArray string) types.Array { - t.Helper() - - var vb object.ValueBuffer - - err := vb.UnmarshalJSON([]byte(jsonArray)) - assert.NoError(t, err) - - return &vb -} - -func MakeValueBuffer(t testing.TB, jsonArray string) *object.ValueBuffer { - t.Helper() - - var vb object.ValueBuffer - - err := vb.UnmarshalJSON([]byte(jsonArray)) - assert.NoError(t, err) - - return &vb -} - -type Objs []types.Object - -func (o Objs) RequireEqual(t testing.TB, others Objs) { - t.Helper() - - require.Equal(t, len(o), len(others), fmt.Sprintf("expected len %d, got %d", len(o), len(others))) - - for i, d := range o { - RequireObjEqual(t, d, others[i]) - } -} - -// Dump a json representation of v to os.Stdout. -func Dump(t testing.TB, v interface{}) { - t.Helper() - - enc := json.NewEncoder(os.Stdout) - enc.SetIndent("", " ") - err := enc.Encode(v) - assert.NoError(t, err) -} - -func RequireJSONEq(t testing.TB, o any, expected string) { - t.Helper() - - data, err := json.Marshal(o) - assert.NoError(t, err) - require.JSONEq(t, expected, string(data)) -} - -// IteratorToJSONArray encodes all the objects of an iterator to a JSON array. 
-func IteratorToJSONArray(w io.Writer, s database.RowIterator) error { - buf := bufio.NewWriter(w) - - buf.WriteByte('[') - - first := true - err := s.Iterate(func(r database.Row) error { - if !first { - buf.WriteString(", ") - } else { - first = false - } - - data, err := r.MarshalJSON() - if err != nil { - return err - } - - _, err = buf.Write(data) - return err - }) - if err != nil { - return err - } - - buf.WriteByte(']') - return buf.Flush() -} - -func RequireObjEqual(t testing.TB, want, got types.Object) { - t.Helper() - - tWant, err := types.MarshalTextIndent(types.NewObjectValue(object.WithSortedFields(want)), "\n", " ") - require.NoError(t, err) - tGot, err := types.MarshalTextIndent(types.NewObjectValue(object.WithSortedFields(got)), "\n", " ") - require.NoError(t, err) - - if diff := cmp.Diff(string(tWant), string(tGot), cmp.Comparer(strings.EqualFold)); diff != "" { - require.Failf(t, "mismatched objects, (-want, +got)", "%s", diff) - } -} - -func RequireArrayEqual(t testing.TB, want, got types.Array) { - t.Helper() - - tWant, err := types.MarshalTextIndent(types.NewArrayValue(want), "\n", " ") - require.NoError(t, err) - tGot, err := types.MarshalTextIndent(types.NewArrayValue(got), "\n", " ") - require.NoError(t, err) - - if diff := cmp.Diff(string(tWant), string(tGot), cmp.Comparer(strings.EqualFold)); diff != "" { - require.Failf(t, "mismatched arrays, (-want, +got)", "%s", diff) - } -} - -func RequireValueEqual(t testing.TB, want, got types.Value, msg string, args ...any) { - t.Helper() - - tWant, err := types.MarshalTextIndent(want, "\n", " ") - require.NoError(t, err) - tGot, err := types.MarshalTextIndent(got, "\n", " ") - require.NoError(t, err) - - if diff := cmp.Diff(string(tWant), string(tGot), cmp.Comparer(strings.EqualFold)); diff != "" { - require.Failf(t, "mismatched values, (-want, +got)", "%s\n%s", diff, fmt.Sprintf(msg, args...)) - } -} - -func CloneObject(t testing.TB, d types.Object) *object.FieldBuffer { - t.Helper() - - var newFb object.FieldBuffer - - err := newFb.Copy(d) - assert.NoError(t, err) - - return &newFb -} diff --git a/internal/testutil/row.go b/internal/testutil/row.go new file mode 100644 index 000000000..f675a2dd7 --- /dev/null +++ b/internal/testutil/row.go @@ -0,0 +1,182 @@ +package testutil + +import ( + "bufio" + "encoding/json" + "fmt" + "io" + "os" + "strings" + "testing" + + "github.com/chaisql/chai/internal/database" + "github.com/chaisql/chai/internal/environment" + "github.com/chaisql/chai/internal/expr" + "github.com/chaisql/chai/internal/row" + "github.com/chaisql/chai/internal/stream" + "github.com/chaisql/chai/internal/testutil/assert" + "github.com/chaisql/chai/internal/types" + "github.com/google/go-cmp/cmp" + "github.com/stretchr/testify/require" +) + +func MakeRow(t testing.TB, s string) row.Row { + var cb row.ColumnBuffer + + err := json.Unmarshal([]byte(s), &cb) + assert.NoError(t, err) + return &cb +} + +func MakeRows(t testing.TB, s ...string) []row.Row { + var rows []row.Row + for _, v := range s { + rows = append(rows, MakeRow(t, v)) + } + return rows +} + +func MakeRowExpr(t testing.TB, s string) expr.Row { + r := MakeRow(t, s) + var er expr.Row + + r.Iterate(func(column string, value types.Value) error { + er.Columns = append(er.Columns, column) + er.Exprs = append(er.Exprs, expr.LiteralValue{Value: value}) + return nil + }) + + return er +} + +func MakeRowExprs(t testing.TB, s ...string) []expr.Row { + var rows []expr.Row + for _, v := range s { + rows = append(rows, MakeRowExpr(t, v)) + } + return rows +} + +// 
MakeValue turns v into a types.Value. +func MakeValue(t testing.TB, v any) types.Value { + t.Helper() + + vv, err := row.NewValue(v) + assert.NoError(t, err) + return vv +} + +type Rows []row.Row + +func (r Rows) RequireEqual(t testing.TB, others Rows) { + t.Helper() + + require.Equal(t, len(r), len(others), fmt.Sprintf("expected len %d, got %d", len(r), len(others))) + + for i := range r { + RequireRowEqual(t, r[i], others[i]) + } +} + +func (r Rows) RequireEqualStream(t testing.TB, env *environment.Environment, st *stream.Stream) { + t.Helper() + + var i int + + err := st.Iterate(env, func(env *environment.Environment) error { + rr, ok := env.GetRow() + require.True(t, ok) + + RequireRowEqual(t, r[i], rr) + i++ + return nil + }) + assert.NoError(t, err) + + require.Equal(t, len(r), i) +} + +// Dump a json representation of v to os.Stdout. +func Dump(t testing.TB, v interface{}) { + t.Helper() + + enc := json.NewEncoder(os.Stdout) + enc.SetIndent("", " ") + err := enc.Encode(v) + assert.NoError(t, err) +} + +func RequireJSONEq(t testing.TB, o any, expected string) { + t.Helper() + + data, err := json.Marshal(o) + assert.NoError(t, err) + require.JSONEq(t, expected, string(data)) +} + +// IteratorToJSONArray encodes all the objects of an iterator to a JSON array. +func IteratorToJSONArray(w io.Writer, s database.RowIterator) error { + buf := bufio.NewWriter(w) + + buf.WriteByte('[') + + first := true + err := s.Iterate(func(r database.Row) error { + if !first { + buf.WriteString(", ") + } else { + first = false + } + + data, err := r.MarshalJSON() + if err != nil { + return err + } + + _, err = buf.Write(data) + return err + }) + if err != nil { + return err + } + + buf.WriteByte(']') + return buf.Flush() +} + +func RequireRowEqual(t testing.TB, want, got row.Row) { + t.Helper() + + tWant, err := json.MarshalIndent(want, "", " ") + require.NoError(t, err) + tGot, err := json.MarshalIndent(got, "", " ") + require.NoError(t, err) + + if diff := cmp.Diff(string(tWant), string(tGot), cmp.Comparer(strings.EqualFold)); diff != "" { + require.Failf(t, "mismatched objects, (-want, +got)", "%s", diff) + } +} + +func RequireValueEqual(t testing.TB, want, got types.Value, msg string, args ...any) { + t.Helper() + + tWant, err := json.MarshalIndent(want, "", " ") + require.NoError(t, err) + tGot, err := json.MarshalIndent(got, "", " ") + require.NoError(t, err) + + if diff := cmp.Diff(string(tWant), string(tGot), cmp.Comparer(strings.EqualFold)); diff != "" { + require.Failf(t, "mismatched values, (-want, +got)", "%s\n%s", diff, fmt.Sprintf(msg, args...)) + } +} + +func CloneRow(t testing.TB, r row.Row) *row.ColumnBuffer { + t.Helper() + + var newFb row.ColumnBuffer + + err := newFb.Copy(r) + assert.NoError(t, err) + + return &newFb +} diff --git a/internal/testutil/stream.go b/internal/testutil/stream.go index 17b99704f..848ed4137 100644 --- a/internal/testutil/stream.go +++ b/internal/testutil/stream.go @@ -2,16 +2,15 @@ package testutil import ( "errors" - "sort" "strings" "testing" "github.com/chaisql/chai" "github.com/chaisql/chai/internal/environment" - "github.com/chaisql/chai/internal/object" + "github.com/chaisql/chai/internal/expr" + "github.com/chaisql/chai/internal/row" "github.com/chaisql/chai/internal/sql/parser" - "github.com/chaisql/chai/internal/testutil/assert" - "github.com/chaisql/chai/internal/types" + "github.com/chaisql/chai/internal/sql/scanner" "github.com/stretchr/testify/require" ) @@ -20,13 +19,69 @@ type ResultStream struct { env *environment.Environment } -func (ds 
*ResultStream) Next() (types.Value, error) { - exp, err := ds.Parser.ParseObject() - if err != nil { +func (ds *ResultStream) Next() (row.Row, error) { + return ds.ParseObject() +} + +func (p *ResultStream) ParseObject() (row.Row, error) { + // Parse { token. + if err := p.Parser.ParseTokens(scanner.LBRACKET); err != nil { return nil, err } - return exp.Eval(ds.env) + var cb row.ColumnBuffer + + // Parse kv pairs. + for { + column, e, err := p.parseKV() + if err != nil { + p.Unscan() + break + } + + v, err := e.Eval(p.env) + if err != nil { + return nil, err + } + + cb.Add(column, v) + + if tok, _, _ := p.ScanIgnoreWhitespace(); tok != scanner.COMMA { + p.Unscan() + break + } + } + + // Parse required } token. + if err := p.ParseTokens(scanner.RBRACKET); err != nil { + return nil, err + } + + return &cb, nil +} + +// parseKV parses a key-value pair in the form IDENT : Expr. +func (p *ResultStream) parseKV() (string, expr.Expr, error) { + var k string + + tok, _, lit := p.ScanIgnoreWhitespace() + if tok == scanner.IDENT || tok == scanner.STRING { + k = lit + } else { + return "", nil, errors.New("expected IDENT or STRING") + } + + if err := p.ParseTokens(scanner.COLON); err != nil { + p.Unscan() + return "", nil, err + } + + e, err := p.ParseExpr() + if err != nil { + return "", nil, err + } + + return k, e, nil } func ParseResultStream(stream string) *ResultStream { @@ -36,19 +91,20 @@ func ParseResultStream(stream string) *ResultStream { return &ResultStream{p, env} } -func RequireStreamEq(t *testing.T, raw string, res *chai.Result, sorted bool) { +func RequireStreamEq(t *testing.T, raw string, res *chai.Result) { t.Helper() - RequireStreamEqf(t, raw, res, sorted, "") + RequireStreamEqf(t, raw, res, "") } -func RequireStreamEqf(t *testing.T, raw string, res *chai.Result, sorted bool, msg string, args ...any) { +func RequireStreamEqf(t *testing.T, raw string, res *chai.Result, msg string, args ...any) { + errMsg := append([]any{msg}, args...) t.Helper() - objs := ParseResultStream(raw) + rows := ParseResultStream(raw) - want := object.NewValueBuffer() + var want []row.Row for { - v, err := objs.Next() + v, err := rows.Next() if err != nil { if perr, ok := err.(*parser.ParseError); ok { if perr.Found == "EOF" { @@ -60,63 +116,48 @@ func RequireStreamEqf(t *testing.T, raw string, res *chai.Result, sorted bool, m } } } - require.NoError(t, err, append([]any{msg}, args...)...) + require.NoError(t, err, errMsg...) - v, err = object.CloneValue(v) - require.NoError(t, err, append([]any{msg}, args...)...) - want.Append(v) + want = append(want, v) } - got := object.NewValueBuffer() + var got []row.Row err := res.Iterate(func(r *chai.Row) error { - var fb object.FieldBuffer - err := fb.Copy(r.Object()) - assert.NoError(t, err) + var cb row.ColumnBuffer + err := r.StructScan(&cb) + require.NoError(t, err, errMsg...) - got.Append(types.NewObjectValue(&fb)) + got = append(got, &cb) return nil }) - assert.NoError(t, err) + require.NoError(t, err, errMsg...) + + var expected strings.Builder + for i := range want { + data, err := row.MarshalTextIndent(want[i], "\n", " ") + require.NoError(t, err, errMsg...) 
+ if i > 0 { + expected.WriteString("\n") + } - if sorted { - swant := sortableValueBuffer(*want) - sgot := sortableValueBuffer(*got) - sort.Sort(&swant) - sort.Sort(&sgot) + expected.WriteString(string(data)) } - expected, err := types.MarshalTextIndent(types.NewArrayValue(want), "\n", " ") - assert.NoError(t, err) + var actual strings.Builder + for i := range got { + data, err := row.MarshalTextIndent(got[i], "\n", " ") + require.NoError(t, err, errMsg...) + if i > 0 { + actual.WriteString("\n") + } - actual, err := types.MarshalTextIndent(types.NewArrayValue(got), "\n", " ") - assert.NoError(t, err) + actual.WriteString(string(data)) + } if msg != "" { - require.Equal(t, string(expected), string(actual), append([]any{msg}, args...)...) + require.Equal(t, expected.String(), actual.String(), errMsg...) } else { - require.Equal(t, string(expected), string(actual)) - } -} - -type sortableValueBuffer object.ValueBuffer - -func (vb *sortableValueBuffer) Len() int { - return len(vb.Values) -} - -func (vb *sortableValueBuffer) Swap(i, j int) { - vb.Values[i], vb.Values[j] = vb.Values[j], vb.Values[i] -} - -func (vb *sortableValueBuffer) Less(i, j int) (ok bool) { - it, jt := vb.Values[i].Type(), vb.Values[j].Type() - if it == jt || (it.IsNumber() && jt.IsNumber()) { - // TODO(asdine) make the types package work with static objects - // to avoid having to deal with errors? - ok, _ = vb.Values[i].LT(vb.Values[j]) - return + require.Equal(t, expected.String(), actual.String()) } - - return it < jt } diff --git a/internal/tree/key.go b/internal/tree/key.go index 13ef842bd..1cbee85a8 100644 --- a/internal/tree/key.go +++ b/internal/tree/key.go @@ -1,8 +1,9 @@ package tree import ( + "strings" + "github.com/chaisql/chai/internal/encoding" - "github.com/chaisql/chai/internal/object" "github.com/chaisql/chai/internal/types" "github.com/cockroachdb/errors" ) @@ -38,7 +39,7 @@ func (k *Key) Encode(ns Namespace, order SortOrder) ([]byte, error) { for i, v := range k.values { // extract the sort order - buf, err = encoding.EncodeValue(buf, v, order.IsDesc(i)) + buf, err = types.EncodeValueAsKey(buf, v, order.IsDesc(i)) if err != nil { return nil, err } @@ -65,7 +66,7 @@ func (key *Key) Decode() ([]types.Value, error) { b = b[n:] for { - v, n := encoding.DecodeValue(b, false /* intAsDouble */) + v, n := types.DecodeValue(b) b = b[n:] values = append(values, v) @@ -83,5 +84,14 @@ func (k *Key) String() string { } values, _ := k.Decode() - return types.NewArrayValue(object.NewValueBuffer(values...)).String() + var sb strings.Builder + sb.WriteString("(") + for i, v := range values { + if i > 0 { + sb.WriteString(", ") + } + sb.WriteString(v.String()) + } + sb.WriteString(")") + return sb.String() } diff --git a/internal/tree/tree.go b/internal/tree/tree.go index 2a302e6e0..861b7eb42 100644 --- a/internal/tree/tree.go +++ b/internal/tree/tree.go @@ -2,12 +2,9 @@ package tree import ( "fmt" - "math" - "time" "github.com/chaisql/chai/internal/encoding" "github.com/chaisql/chai/internal/engine" - "github.com/chaisql/chai/internal/types" "github.com/cockroachdb/errors" ) @@ -128,7 +125,16 @@ func (t *Tree) Get(key *Key) ([]byte, error) { return nil, err } - return t.Session.Get(k) + v, err := t.Session.Get(k) + if err != nil { + return nil, err + } + + if len(v) == 0 || v[0] == 0 { + return nil, nil + } + + return v, nil } // Exists returns true if the key exists in the tree. 
@@ -208,6 +214,10 @@ func (t *Tree) IterateOnRange(rng *Range, reverse bool, fn func(*Key, []byte) er if err != nil { return err } + if len(v) == 0 || v[0] == 0 { + v = nil + } + err = fn(&k, v) if err != nil { return err @@ -285,10 +295,10 @@ func (t *Tree) buildMinKeyForType(max *Key, desc bool) ([]byte, error) { if len(max.values) == 1 { buf := encoding.EncodeInt(nil, int64(t.Namespace)) if desc { - return append(buf, byte(t.NewMinTypeForTypeDesc(max.values[0].Type()))), nil + return append(buf, max.values[0].Type().MinEnctypeDesc()), nil } - return append(buf, byte(t.NewMinTypeForType(max.values[0].Type()))), nil + return append(buf, max.values[0].Type().MinEnctype()), nil } buf, err := NewKey(max.values[:len(max.values)-1]...).Encode(t.Namespace, t.Order) @@ -297,10 +307,10 @@ func (t *Tree) buildMinKeyForType(max *Key, desc bool) ([]byte, error) { } i := len(max.values) - 1 if desc { - return append(buf, byte(t.NewMinTypeForTypeDesc(max.values[i].Type()))), nil + return append(buf, max.values[i].Type().MinEnctypeDesc()), nil } - return append(buf, byte(t.NewMinTypeForType(max.values[i].Type()))), nil + return append(buf, max.values[i].Type().MinEnctype()), nil } func (t *Tree) buildMaxKeyForType(min *Key, desc bool) ([]byte, error) { @@ -311,9 +321,9 @@ func (t *Tree) buildMaxKeyForType(min *Key, desc bool) ([]byte, error) { if len(min.values) == 1 { buf := encoding.EncodeInt(nil, int64(t.Namespace)) if desc { - return append(buf, byte(t.NewMaxTypeForTypeDesc(min.values[0].Type()))), nil + return append(buf, min.values[0].Type().MaxEnctypeDesc()), nil } - return append(buf, byte(t.NewMaxTypeForType(min.values[0].Type()))), nil + return append(buf, min.values[0].Type().MaxEnctype()), nil } buf, err := NewKey(min.values[:len(min.values)-1]...).Encode(t.Namespace, t.Order) @@ -322,10 +332,10 @@ func (t *Tree) buildMaxKeyForType(min *Key, desc bool) ([]byte, error) { } i := len(min.values) - 1 if desc { - return append(buf, byte(t.NewMaxTypeForTypeDesc(min.values[i].Type()))), nil + return append(buf, min.values[i].Type().MaxEnctypeDesc()), nil } - return append(buf, byte(t.NewMaxTypeForType(min.values[i].Type()))), nil + return append(buf, min.values[i].Type().MaxEnctype()), nil } func (t *Tree) buildLastKey() []byte { @@ -359,156 +369,6 @@ func (t *Tree) buildEndKeyExclusive(key *Key, desc bool) ([]byte, error) { return key.Encode(t.Namespace, t.Order) } -func (t *Tree) NewMinValueForType(tp types.Type) types.Value { - switch tp { - case types.TypeNull: - return types.NewNullValue() - case types.TypeBoolean: - return types.NewBooleanValue(false) - case types.TypeInteger: - return types.NewIntegerValue(math.MinInt64) - case types.TypeDouble: - return types.NewDoubleValue(-math.MaxFloat64) - case types.TypeTimestamp: - return types.NewTimestampValue(time.Time{}) - case types.TypeText: - return types.NewTextValue("") - case types.TypeBlob: - return types.NewBlobValue(nil) - case types.TypeArray: - return types.NewArrayValue(nil) - case types.TypeObject: - return types.NewObjectValue(nil) - default: - panic(fmt.Sprintf("unsupported type %v", t)) - } -} - -func (t *Tree) NewMinTypeForType(tp types.Type) byte { - switch tp { - case types.TypeNull: - return encoding.NullValue - case types.TypeBoolean: - return encoding.FalseValue - case types.TypeInteger: - return encoding.Int64Value - case types.TypeDouble: - return encoding.Float64Value - case types.TypeTimestamp: - return encoding.Int64Value - case types.TypeText: - return encoding.TextValue - case types.TypeBlob: - return 
encoding.BlobValue - case types.TypeArray: - return encoding.ArrayValue - case types.TypeObject: - return encoding.ObjectValue - default: - panic(fmt.Sprintf("unsupported type %v", t)) - } -} - -func (t *Tree) NewMinTypeForTypeDesc(tp types.Type) byte { - switch tp { - case types.TypeNull: - return encoding.DESC_NullValue - case types.TypeBoolean: - return encoding.DESC_TrueValue - case types.TypeInteger: - return encoding.DESC_Uint64Value - case types.TypeDouble: - return encoding.DESC_Float64Value - case types.TypeTimestamp: - return encoding.DESC_Uint64Value - case types.TypeText: - return encoding.DESC_TextValue - case types.TypeBlob: - return encoding.DESC_BlobValue - case types.TypeArray: - return encoding.DESC_ArrayValue - case types.TypeObject: - return encoding.DESC_ObjectValue - default: - panic(fmt.Sprintf("unsupported type %v", t)) - } -} - -func (t *Tree) NewMaxTypeForTypeDesc(tp types.Type) byte { - switch tp { - case types.TypeNull: - return encoding.DESC_NullValue + 1 - case types.TypeBoolean: - return encoding.DESC_FalseValue + 1 - case types.TypeInteger: - return encoding.DESC_Int64Value + 1 - case types.TypeDouble: - return encoding.DESC_Float64Value + 1 - case types.TypeTimestamp: - return encoding.DESC_Int64Value + 1 - case types.TypeText: - return encoding.DESC_TextValue + 1 - case types.TypeBlob: - return encoding.DESC_BlobValue + 1 - case types.TypeArray: - return encoding.DESC_ArrayValue + 1 - case types.TypeObject: - return encoding.DESC_ObjectValue + 1 - default: - panic(fmt.Sprintf("unsupported type %v", t)) - } -} - -func (t *Tree) NewMinValueForTypeDesc(tp types.Type) types.Value { - switch tp { - case types.TypeNull: - return types.NewNullValue() - case types.TypeBoolean: - return types.NewBooleanValue(true) - case types.TypeInteger: - return types.NewIntegerValue(math.MaxInt64) - case types.TypeDouble: - return types.NewDoubleValue(math.MaxFloat64) - case types.TypeTimestamp: - return types.NewIntegerValue(math.MaxInt64) - case types.TypeText: - return types.NewTextValue("") - case types.TypeBlob: - return types.NewBlobValue(nil) - case types.TypeArray: - return types.NewArrayValue(nil) - case types.TypeObject: - return types.NewObjectValue(nil) - default: - panic(fmt.Sprintf("unsupported type %v", t)) - } -} - -func (t *Tree) NewMaxTypeForType(tp types.Type) byte { - switch tp { - case types.TypeNull: - return encoding.NullValue + 1 - case types.TypeBoolean: - return encoding.TrueValue + 1 - case types.TypeInteger: - return encoding.Uint64Value + 1 - case types.TypeDouble: - return encoding.Float64Value + 1 - case types.TypeTimestamp: - return encoding.Uint64Value + 1 - case types.TypeText: - return encoding.TextValue + 1 - case types.TypeBlob: - return encoding.BlobValue + 1 - case types.TypeArray: - return encoding.ArrayValue + 1 - case types.TypeObject: - return encoding.ObjectValue + 1 - default: - panic(fmt.Sprintf("unsupported type %v", t)) - } -} - // A Range of keys to iterate on. // By default, Min and Max are inclusive. 
// If Exclusive is true, Min and Max are excluded diff --git a/internal/tree/tree_test.go b/internal/tree/tree_test.go index 4bf6f9598..44b12c58b 100644 --- a/internal/tree/tree_test.go +++ b/internal/tree/tree_test.go @@ -5,9 +5,8 @@ import ( "sort" "testing" - "github.com/chaisql/chai/internal/database" "github.com/chaisql/chai/internal/encoding" - "github.com/chaisql/chai/internal/object" + "github.com/chaisql/chai/internal/row" "github.com/chaisql/chai/internal/testutil" "github.com/chaisql/chai/internal/testutil/assert" "github.com/chaisql/chai/internal/tree" @@ -29,7 +28,7 @@ var key2 = func() *tree.Key { ) }() -var doc = object.NewFromMap(map[string]bool{ +var doc = row.NewFromMap(map[string]bool{ "a": true, }) @@ -37,7 +36,7 @@ func TestTreeGet(t *testing.T) { tests := []struct { name string key *tree.Key - d types.Object + r row.Row Fails bool }{ {"existing", key1, doc, false}, @@ -46,8 +45,6 @@ func TestTreeGet(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - var ti database.TableInfo - ti.FieldConstraints.AllowExtraFields = true tree := testutil.NewTestTree(t, 10) err := tree.Put(key1, []byte{1}) @@ -76,9 +73,6 @@ func TestTreeDelete(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - var ti database.TableInfo - ti.FieldConstraints.AllowExtraFields = true - tree := testutil.NewTestTree(t, 10) err := tree.Put(key1, []byte{1}) @@ -124,9 +118,9 @@ func TestTreeIterateOnRange(t *testing.T) { var keys []*tree.Key // keys: [bool, bool, int] * 100 - var c int64 // for unicity - for i := int64(0); i < 10; i++ { - for j := int64(0); j < 10; j++ { + var c int32 // for unicity + for i := int32(0); i < 10; i++ { + for j := int32(0); j < 10; j++ { keys = append(keys, tree.NewKey( types.NewBooleanValue(i%2 == 0), types.NewBooleanValue(j%2 == 0), @@ -137,7 +131,7 @@ func TestTreeIterateOnRange(t *testing.T) { } // keys: [int, text, double] * 1000 - for i := int64(0); i < 10; i++ { + for i := int32(0); i < 10; i++ { for j := 0; j < 10; j++ { for k := 0; k < 10; k++ { keys = append(keys, tree.NewKey( @@ -150,7 +144,7 @@ func TestTreeIterateOnRange(t *testing.T) { } // keys: [double, double] * 100 - for i := int64(0); i < 10; i++ { + for i := int32(0); i < 10; i++ { for j := 0; j < 10; j++ { keys = append(keys, tree.NewKey( types.NewDoubleValue(float64(i)), @@ -160,7 +154,7 @@ func TestTreeIterateOnRange(t *testing.T) { } // keys: [text, text] * 100 - for i := int64(0); i < 10; i++ { + for i := int32(0); i < 10; i++ { for j := 0; j < 10; j++ { keys = append(keys, tree.NewKey( types.NewTextValue(fmt.Sprintf("bar%d", i)), @@ -170,7 +164,7 @@ func TestTreeIterateOnRange(t *testing.T) { } // keys: [blob, blob] * 100 - for i := int64(0); i < 10; i++ { + for i := int32(0); i < 10; i++ { for j := 0; j < 10; j++ { keys = append(keys, tree.NewKey( types.NewBlobValue([]byte(fmt.Sprintf("bar%d", i))), @@ -179,26 +173,6 @@ func TestTreeIterateOnRange(t *testing.T) { } } - // keys: [array, array] * 100 - for i := int64(0); i < 10; i++ { - for j := int64(0); j < 10; j++ { - keys = append(keys, tree.NewKey( - types.NewArrayValue(object.NewValueBuffer(types.NewIntegerValue(i))), - types.NewArrayValue(object.NewValueBuffer(types.NewIntegerValue(j))), - )) - } - } - - // keys: [doc, doc] * 100 - for i := int64(0); i < 10; i++ { - for j := int64(0); j < 10; j++ { - keys = append(keys, tree.NewKey( - types.NewObjectValue(object.NewFieldBuffer().Add("foo", types.NewIntegerValue(i))), - types.NewObjectValue(object.NewFieldBuffer().Add("foo", 
types.NewIntegerValue(j))), - )) - } - } - for _, reversed := range []bool{false, true} { tests := []struct { name string @@ -208,7 +182,7 @@ func TestTreeIterateOnRange(t *testing.T) { order tree.SortOrder }{ // all - {"all", nil, nil, false, 0, 1600, 0}, + {"all", nil, nil, false, 0, 1400, 0}, // arity: 1 {"= 3", tree.NewKey(types.NewIntegerValue(3)), tree.NewKey(types.NewIntegerValue(3)), false, 400, 500, 0}, @@ -220,10 +194,10 @@ func TestTreeIterateOnRange(t *testing.T) { {"> 3 AND < 7", tree.NewKey(types.NewIntegerValue(3)), tree.NewKey(types.NewIntegerValue(7)), true, 500, 800, 0}, // arity 1, order desc - {"= 3 desc", tree.NewKey(types.NewIntegerValue(3)), tree.NewKey(types.NewIntegerValue(3)), false, 1100, 1200, tree.SortOrder(0).SetDesc(0)}, - {">= 3 desc", tree.NewKey(types.NewIntegerValue(3)), nil, false, 500, 1200, tree.SortOrder(0).SetDesc(0)}, - {"> 3 desc", tree.NewKey(types.NewIntegerValue(3)), nil, true, 500, 1100, tree.SortOrder(0).SetDesc(0)}, - {"<= 3 desc", nil, tree.NewKey(types.NewIntegerValue(3)), false, 1100, 1500, tree.SortOrder(0).SetDesc(0)}, + {"= 3 desc", tree.NewKey(types.NewIntegerValue(3)), tree.NewKey(types.NewIntegerValue(3)), false, 900, 1000, tree.SortOrder(0).SetDesc(0)}, + {">= 3 desc", tree.NewKey(types.NewIntegerValue(3)), nil, false, 300, 1000, tree.SortOrder(0).SetDesc(0)}, + {"> 3 desc", tree.NewKey(types.NewIntegerValue(3)), nil, true, 300, 900, tree.SortOrder(0).SetDesc(0)}, + {"<= 3 desc", nil, tree.NewKey(types.NewIntegerValue(3)), false, 900, 1300, tree.SortOrder(0).SetDesc(0)}, {"= 12 desc", tree.NewKey(types.NewIntegerValue(12)), tree.NewKey(types.NewIntegerValue(12)), false, 0, 0, tree.SortOrder(0).SetDesc(0)}, // arity 2 @@ -235,13 +209,13 @@ func TestTreeIterateOnRange(t *testing.T) { {"= 3 AND >= foo1 AND <= foo3", tree.NewKey(types.NewIntegerValue(3), types.NewTextValue("foo1")), tree.NewKey(types.NewIntegerValue(3), types.NewTextValue("foo3")), false, 410, 440, 0}, // arity 2 desc - {"= 3 AND = foo1 desc", tree.NewKey(types.NewIntegerValue(3), types.NewTextValue("foo1")), tree.NewKey(types.NewIntegerValue(3), types.NewTextValue("foo1")), false, 1180, 1190, tree.SortOrder(0).SetDesc(0).SetDesc(1)}, - {"= 3 AND >= foo1 desc", tree.NewKey(types.NewIntegerValue(3), types.NewTextValue("foo1")), nil, false, 1100, 1190, tree.SortOrder(0).SetDesc(0).SetDesc(1)}, - {"= 3 AND > foo1 desc", tree.NewKey(types.NewIntegerValue(3), types.NewTextValue("foo1")), nil, true, 1100, 1180, tree.SortOrder(0).SetDesc(0).SetDesc(1)}, - {"= 3 AND <= foo1 desc", nil, tree.NewKey(types.NewIntegerValue(3), types.NewTextValue("foo1")), false, 1180, 1200, tree.SortOrder(0).SetDesc(0).SetDesc(1)}, - {"= 3 AND < foo1 desc", nil, tree.NewKey(types.NewIntegerValue(3), types.NewTextValue("foo1")), true, 1190, 1200, tree.SortOrder(0).SetDesc(0).SetDesc(1)}, - {"= 3 AND >= foo1 AND <= foo3 desc", tree.NewKey(types.NewIntegerValue(3), types.NewTextValue("foo1")), tree.NewKey(types.NewIntegerValue(3), types.NewTextValue("foo3")), false, 1160, 1190, tree.SortOrder(0).SetDesc(0).SetDesc(1)}, - {"= 3 AND > foo1 AND < foo3 desc", tree.NewKey(types.NewIntegerValue(3), types.NewTextValue("foo1")), tree.NewKey(types.NewIntegerValue(3), types.NewTextValue("foo3")), true, 1170, 1180, tree.SortOrder(0).SetDesc(0).SetDesc(1)}, + {"= 3 AND = foo1 desc", tree.NewKey(types.NewIntegerValue(3), types.NewTextValue("foo1")), tree.NewKey(types.NewIntegerValue(3), types.NewTextValue("foo1")), false, 980, 990, tree.SortOrder(0).SetDesc(0).SetDesc(1)}, + {"= 3 AND >= foo1 desc", 
tree.NewKey(types.NewIntegerValue(3), types.NewTextValue("foo1")), nil, false, 900, 990, tree.SortOrder(0).SetDesc(0).SetDesc(1)}, + {"= 3 AND > foo1 desc", tree.NewKey(types.NewIntegerValue(3), types.NewTextValue("foo1")), nil, true, 900, 980, tree.SortOrder(0).SetDesc(0).SetDesc(1)}, + {"= 3 AND <= foo1 desc", nil, tree.NewKey(types.NewIntegerValue(3), types.NewTextValue("foo1")), false, 980, 1000, tree.SortOrder(0).SetDesc(0).SetDesc(1)}, + {"= 3 AND < foo1 desc", nil, tree.NewKey(types.NewIntegerValue(3), types.NewTextValue("foo1")), true, 990, 1000, tree.SortOrder(0).SetDesc(0).SetDesc(1)}, + {"= 3 AND >= foo1 AND <= foo3 desc", tree.NewKey(types.NewIntegerValue(3), types.NewTextValue("foo1")), tree.NewKey(types.NewIntegerValue(3), types.NewTextValue("foo3")), false, 960, 990, tree.SortOrder(0).SetDesc(0).SetDesc(1)}, + {"= 3 AND > foo1 AND < foo3 desc", tree.NewKey(types.NewIntegerValue(3), types.NewTextValue("foo1")), tree.NewKey(types.NewIntegerValue(3), types.NewTextValue("foo3")), true, 970, 980, tree.SortOrder(0).SetDesc(0).SetDesc(1)}, // arity 3 {"= 3 AND = foo1 AND = 5.0", tree.NewKey(types.NewIntegerValue(3), types.NewTextValue("foo1"), types.NewDoubleValue(5)), tree.NewKey(types.NewIntegerValue(3), types.NewTextValue("foo1"), types.NewDoubleValue(5)), false, 415, 416, 0}, @@ -251,11 +225,11 @@ func TestTreeIterateOnRange(t *testing.T) { {"= 3 AND = foo1 AND < 5.0", nil, tree.NewKey(types.NewIntegerValue(3), types.NewTextValue("foo1"), types.NewDoubleValue(5)), true, 410, 415, 0}, // arity 3 desc - {"= 3 AND = foo1 AND = 5.0", tree.NewKey(types.NewIntegerValue(3), types.NewTextValue("foo1"), types.NewDoubleValue(5)), tree.NewKey(types.NewIntegerValue(3), types.NewTextValue("foo1"), types.NewDoubleValue(5)), false, 1184, 1185, tree.SortOrder(0).SetDesc(0).SetDesc(1).SetDesc(2)}, - {"= 3 AND = foo1 AND >= 5.0", tree.NewKey(types.NewIntegerValue(3), types.NewTextValue("foo1"), types.NewDoubleValue(5)), nil, false, 1180, 1185, tree.SortOrder(0).SetDesc(0).SetDesc(1).SetDesc(2)}, - {"= 3 AND = foo1 AND > 5.0", tree.NewKey(types.NewIntegerValue(3), types.NewTextValue("foo1"), types.NewDoubleValue(5)), nil, true, 1180, 1184, tree.SortOrder(0).SetDesc(0).SetDesc(1).SetDesc(2)}, - {"= 3 AND = foo1 AND <= 5.0", nil, tree.NewKey(types.NewIntegerValue(3), types.NewTextValue("foo1"), types.NewDoubleValue(5)), false, 1184, 1190, tree.SortOrder(0).SetDesc(0).SetDesc(1).SetDesc(2)}, - {"= 3 AND = foo1 AND < 5.0", nil, tree.NewKey(types.NewIntegerValue(3), types.NewTextValue("foo1"), types.NewDoubleValue(5)), true, 1185, 1190, tree.SortOrder(0).SetDesc(0).SetDesc(1).SetDesc(2)}, + {"= 3 AND = foo1 AND = 5.0", tree.NewKey(types.NewIntegerValue(3), types.NewTextValue("foo1"), types.NewDoubleValue(5)), tree.NewKey(types.NewIntegerValue(3), types.NewTextValue("foo1"), types.NewDoubleValue(5)), false, 984, 985, tree.SortOrder(0).SetDesc(0).SetDesc(1).SetDesc(2)}, + {"= 3 AND = foo1 AND >= 5.0", tree.NewKey(types.NewIntegerValue(3), types.NewTextValue("foo1"), types.NewDoubleValue(5)), nil, false, 980, 985, tree.SortOrder(0).SetDesc(0).SetDesc(1).SetDesc(2)}, + {"= 3 AND = foo1 AND > 5.0", tree.NewKey(types.NewIntegerValue(3), types.NewTextValue("foo1"), types.NewDoubleValue(5)), nil, true, 980, 984, tree.SortOrder(0).SetDesc(0).SetDesc(1).SetDesc(2)}, + {"= 3 AND = foo1 AND <= 5.0", nil, tree.NewKey(types.NewIntegerValue(3), types.NewTextValue("foo1"), types.NewDoubleValue(5)), false, 984, 990, tree.SortOrder(0).SetDesc(0).SetDesc(1).SetDesc(2)}, + {"= 3 AND = foo1 AND < 5.0", nil, 
tree.NewKey(types.NewIntegerValue(3), types.NewTextValue("foo1"), types.NewDoubleValue(5)), true, 985, 990, tree.SortOrder(0).SetDesc(0).SetDesc(1).SetDesc(2)}, // other types @@ -269,13 +243,13 @@ func TestTreeIterateOnRange(t *testing.T) { {"< true", nil, tree.NewKey(types.NewBooleanValue(true)), true, 0, 50, 0}, // bool desc - {"= false desc", tree.NewKey(types.NewBooleanValue(false)), tree.NewKey(types.NewBooleanValue(false)), false, 1550, 1600, tree.SortOrder(0).SetDesc(0)}, - {"= true desc", tree.NewKey(types.NewBooleanValue(true)), tree.NewKey(types.NewBooleanValue(true)), false, 1500, 1550, tree.SortOrder(0).SetDesc(0)}, - {">= false desc", tree.NewKey(types.NewBooleanValue(false)), nil, false, 1500, 1600, tree.SortOrder(0).SetDesc(0)}, - {"> false desc", tree.NewKey(types.NewBooleanValue(false)), nil, true, 1500, 1550, tree.SortOrder(0).SetDesc(0)}, - {"<= false desc", nil, tree.NewKey(types.NewBooleanValue(false)), false, 1550, 1600, tree.SortOrder(0).SetDesc(0)}, + {"= false desc", tree.NewKey(types.NewBooleanValue(false)), tree.NewKey(types.NewBooleanValue(false)), false, 1350, 1400, tree.SortOrder(0).SetDesc(0)}, + {"= true desc", tree.NewKey(types.NewBooleanValue(true)), tree.NewKey(types.NewBooleanValue(true)), false, 1300, 1350, tree.SortOrder(0).SetDesc(0)}, + {">= false desc", tree.NewKey(types.NewBooleanValue(false)), nil, false, 1300, 1400, tree.SortOrder(0).SetDesc(0)}, + {"> false desc", tree.NewKey(types.NewBooleanValue(false)), nil, true, 1300, 1350, tree.SortOrder(0).SetDesc(0)}, + {"<= false desc", nil, tree.NewKey(types.NewBooleanValue(false)), false, 1350, 1400, tree.SortOrder(0).SetDesc(0)}, {"< false desc", nil, tree.NewKey(types.NewBooleanValue(false)), true, 0, 0, tree.SortOrder(0).SetDesc(0)}, - {"< true desc", nil, tree.NewKey(types.NewBooleanValue(true)), true, 1550, 1600, tree.SortOrder(0).SetDesc(0)}, + {"< true desc", nil, tree.NewKey(types.NewBooleanValue(true)), true, 1350, 1400, tree.SortOrder(0).SetDesc(0)}, // double {"= 3.0", tree.NewKey(types.NewDoubleValue(3)), tree.NewKey(types.NewDoubleValue(3)), false, 1130, 1140, 0}, @@ -285,11 +259,11 @@ func TestTreeIterateOnRange(t *testing.T) { {"< 3.0", nil, tree.NewKey(types.NewDoubleValue(3)), true, 1100, 1130, 0}, // double desc - {"= 3.0 desc", tree.NewKey(types.NewDoubleValue(3)), tree.NewKey(types.NewDoubleValue(3)), false, 460, 470, tree.SortOrder(0).SetDesc(0)}, - {">= 3.0 desc", tree.NewKey(types.NewDoubleValue(3)), nil, false, 400, 470, tree.SortOrder(0).SetDesc(0)}, - {"> 3.0 desc", tree.NewKey(types.NewDoubleValue(3)), nil, true, 400, 460, tree.SortOrder(0).SetDesc(0)}, - {"<= 3.0 desc", nil, tree.NewKey(types.NewDoubleValue(3)), false, 460, 500, tree.SortOrder(0).SetDesc(0)}, - {"< 3.0 desc", nil, tree.NewKey(types.NewDoubleValue(3)), true, 470, 500, tree.SortOrder(0).SetDesc(0)}, + {"= 3.0 desc", tree.NewKey(types.NewDoubleValue(3)), tree.NewKey(types.NewDoubleValue(3)), false, 260, 270, tree.SortOrder(0).SetDesc(0)}, + {">= 3.0 desc", tree.NewKey(types.NewDoubleValue(3)), nil, false, 200, 270, tree.SortOrder(0).SetDesc(0)}, + {"> 3.0 desc", tree.NewKey(types.NewDoubleValue(3)), nil, true, 200, 260, tree.SortOrder(0).SetDesc(0)}, + {"<= 3.0 desc", nil, tree.NewKey(types.NewDoubleValue(3)), false, 260, 300, tree.SortOrder(0).SetDesc(0)}, + {"< 3.0 desc", nil, tree.NewKey(types.NewDoubleValue(3)), true, 270, 300, tree.SortOrder(0).SetDesc(0)}, // text {"= bar3", tree.NewKey(types.NewTextValue("bar3")), tree.NewKey(types.NewTextValue("bar3")), false, 1230, 1240, 0}, @@ -299,11 +273,11 @@ 
func TestTreeIterateOnRange(t *testing.T) { {"< bar3", nil, tree.NewKey(types.NewTextValue("bar3")), true, 1200, 1230, 0}, // text desc - {"= bar3 desc", tree.NewKey(types.NewTextValue("bar3")), tree.NewKey(types.NewTextValue("bar3")), false, 360, 370, tree.SortOrder(0).SetDesc(0)}, - {">= bar3 desc", tree.NewKey(types.NewTextValue("bar3")), nil, false, 300, 370, tree.SortOrder(0).SetDesc(0)}, - {"> bar3 desc", tree.NewKey(types.NewTextValue("bar3")), nil, true, 300, 360, tree.SortOrder(0).SetDesc(0)}, - {"<= bar3 desc", nil, tree.NewKey(types.NewTextValue("bar3")), false, 360, 400, tree.SortOrder(0).SetDesc(0)}, - {"< bar3 desc", nil, tree.NewKey(types.NewTextValue("bar3")), true, 370, 400, tree.SortOrder(0).SetDesc(0)}, + {"= bar3 desc", tree.NewKey(types.NewTextValue("bar3")), tree.NewKey(types.NewTextValue("bar3")), false, 160, 170, tree.SortOrder(0).SetDesc(0)}, + {">= bar3 desc", tree.NewKey(types.NewTextValue("bar3")), nil, false, 100, 170, tree.SortOrder(0).SetDesc(0)}, + {"> bar3 desc", tree.NewKey(types.NewTextValue("bar3")), nil, true, 100, 160, tree.SortOrder(0).SetDesc(0)}, + {"<= bar3 desc", nil, tree.NewKey(types.NewTextValue("bar3")), false, 160, 200, tree.SortOrder(0).SetDesc(0)}, + {"< bar3 desc", nil, tree.NewKey(types.NewTextValue("bar3")), true, 170, 200, tree.SortOrder(0).SetDesc(0)}, // blob {"= bar3", tree.NewKey(types.NewBlobValue([]byte("bar3"))), tree.NewKey(types.NewBlobValue([]byte("bar3"))), false, 1330, 1340, 0}, @@ -313,39 +287,11 @@ func TestTreeIterateOnRange(t *testing.T) { {"< bar3", nil, tree.NewKey(types.NewBlobValue([]byte("bar3"))), true, 1300, 1330, 0}, // blob desc - {"= bar3 desc", tree.NewKey(types.NewBlobValue([]byte("bar3"))), tree.NewKey(types.NewBlobValue([]byte("bar3"))), false, 260, 270, tree.SortOrder(0).SetDesc(0)}, - {">= bar3 desc", tree.NewKey(types.NewBlobValue([]byte("bar3"))), nil, false, 200, 270, tree.SortOrder(0).SetDesc(0)}, - {"> bar3 desc", tree.NewKey(types.NewBlobValue([]byte("bar3"))), nil, true, 200, 260, tree.SortOrder(0).SetDesc(0)}, - {"<= bar3 desc", nil, tree.NewKey(types.NewBlobValue([]byte("bar3"))), false, 260, 300, tree.SortOrder(0).SetDesc(0)}, - {"< bar3 desc", nil, tree.NewKey(types.NewBlobValue([]byte("bar3"))), true, 270, 300, tree.SortOrder(0).SetDesc(0)}, - - // array - {"= [3]", tree.NewKey(types.NewArrayValue(object.NewValueBuffer(types.NewIntegerValue(3)))), tree.NewKey(types.NewArrayValue(object.NewValueBuffer(types.NewIntegerValue(3)))), false, 1430, 1440, 0}, - {">= [3]", tree.NewKey(types.NewArrayValue(object.NewValueBuffer(types.NewIntegerValue(3)))), nil, false, 1430, 1500, 0}, - {"> [3]", tree.NewKey(types.NewArrayValue(object.NewValueBuffer(types.NewIntegerValue(3)))), nil, true, 1440, 1500, 0}, - {"<= [3]", nil, tree.NewKey(types.NewArrayValue(object.NewValueBuffer(types.NewIntegerValue(3)))), false, 1400, 1440, 0}, - {"< [3]", nil, tree.NewKey(types.NewArrayValue(object.NewValueBuffer(types.NewIntegerValue(3)))), true, 1400, 1430, 0}, - - // array desc - {"= [3] desc", tree.NewKey(types.NewArrayValue(object.NewValueBuffer(types.NewIntegerValue(3)))), tree.NewKey(types.NewArrayValue(object.NewValueBuffer(types.NewIntegerValue(3)))), false, 160, 170, tree.SortOrder(0).SetDesc(0)}, - {">= [3] desc", tree.NewKey(types.NewArrayValue(object.NewValueBuffer(types.NewIntegerValue(3)))), nil, false, 100, 170, tree.SortOrder(0).SetDesc(0)}, - {"> [3] desc", tree.NewKey(types.NewArrayValue(object.NewValueBuffer(types.NewIntegerValue(3)))), nil, true, 100, 160, tree.SortOrder(0).SetDesc(0)}, - {"<= [3] 
desc", nil, tree.NewKey(types.NewArrayValue(object.NewValueBuffer(types.NewIntegerValue(3)))), false, 160, 200, tree.SortOrder(0).SetDesc(0)}, - {"< [3] desc", nil, tree.NewKey(types.NewArrayValue(object.NewValueBuffer(types.NewIntegerValue(3)))), true, 170, 200, tree.SortOrder(0).SetDesc(0)}, - - // object - {"= {foo: 3}", tree.NewKey(types.NewObjectValue(object.NewFieldBuffer().Add("foo", types.NewIntegerValue(3)))), tree.NewKey(types.NewObjectValue(object.NewFieldBuffer().Add("foo", types.NewIntegerValue(3)))), false, 1530, 1540, 0}, - {">= {foo: 3}", tree.NewKey(types.NewObjectValue(object.NewFieldBuffer().Add("foo", types.NewIntegerValue(3)))), nil, false, 1530, 1600, 0}, - {"> {foo: 3}", tree.NewKey(types.NewObjectValue(object.NewFieldBuffer().Add("foo", types.NewIntegerValue(3)))), nil, true, 1540, 1600, 0}, - {"<= {foo: 3}", nil, tree.NewKey(types.NewObjectValue(object.NewFieldBuffer().Add("foo", types.NewIntegerValue(3)))), false, 1500, 1540, 0}, - {"< {foo: 3}", nil, tree.NewKey(types.NewObjectValue(object.NewFieldBuffer().Add("foo", types.NewIntegerValue(3)))), true, 1500, 1530, 0}, - - // object desc - {"= {foo: 3} desc", tree.NewKey(types.NewObjectValue(object.NewFieldBuffer().Add("foo", types.NewIntegerValue(3)))), tree.NewKey(types.NewObjectValue(object.NewFieldBuffer().Add("foo", types.NewIntegerValue(3)))), false, 60, 70, tree.SortOrder(0).SetDesc(0)}, - {">= {foo: 3} desc", tree.NewKey(types.NewObjectValue(object.NewFieldBuffer().Add("foo", types.NewIntegerValue(3)))), nil, false, 0, 70, tree.SortOrder(0).SetDesc(0)}, - {"> {foo: 3} desc", tree.NewKey(types.NewObjectValue(object.NewFieldBuffer().Add("foo", types.NewIntegerValue(3)))), nil, true, 0, 60, tree.SortOrder(0).SetDesc(0)}, - {"<= {foo: 3} desc", nil, tree.NewKey(types.NewObjectValue(object.NewFieldBuffer().Add("foo", types.NewIntegerValue(3)))), false, 60, 100, tree.SortOrder(0).SetDesc(0)}, - {"< {foo: 3} desc", nil, tree.NewKey(types.NewObjectValue(object.NewFieldBuffer().Add("foo", types.NewIntegerValue(3)))), true, 70, 100, tree.SortOrder(0).SetDesc(0)}, + {"= bar3 desc", tree.NewKey(types.NewBlobValue([]byte("bar3"))), tree.NewKey(types.NewBlobValue([]byte("bar3"))), false, 60, 70, tree.SortOrder(0).SetDesc(0)}, + {">= bar3 desc", tree.NewKey(types.NewBlobValue([]byte("bar3"))), nil, false, 0, 70, tree.SortOrder(0).SetDesc(0)}, + {"> bar3 desc", tree.NewKey(types.NewBlobValue([]byte("bar3"))), nil, true, 0, 60, tree.SortOrder(0).SetDesc(0)}, + {"<= bar3 desc", nil, tree.NewKey(types.NewBlobValue([]byte("bar3"))), false, 60, 100, tree.SortOrder(0).SetDesc(0)}, + {"< bar3 desc", nil, tree.NewKey(types.NewBlobValue([]byte("bar3"))), true, 70, 100, tree.SortOrder(0).SetDesc(0)}, } for _, test := range tests { diff --git a/internal/types/array.go b/internal/types/array.go deleted file mode 100644 index 49c918a61..000000000 --- a/internal/types/array.go +++ /dev/null @@ -1,106 +0,0 @@ -package types - -import "github.com/cockroachdb/errors" - -var _ Value = NewArrayValue(nil) - -type ArrayValue struct { - a Array -} - -// NewArrayValue returns a SQL ARRAY value. -func NewArrayValue(x Array) *ArrayValue { - return &ArrayValue{ - a: x, - } -} - -func (v *ArrayValue) V() any { - return v.a -} - -func (v *ArrayValue) Type() Type { - return TypeArray -} - -func (v *ArrayValue) IsZero() (bool, error) { - // The zero value of an array is an empty array. - // Thus, if GetByIndex(0) returns the ErrValueNotFound - // it means that the array is empty. 
- _, err := v.a.GetByIndex(0) - if errors.Is(err, ErrValueNotFound) { - return true, nil - } - return false, err -} - -func (v *ArrayValue) String() string { - data, _ := v.MarshalText() - return string(data) -} - -func (v *ArrayValue) MarshalText() ([]byte, error) { - return MarshalTextIndent(v, "", "") -} - -func (v *ArrayValue) MarshalJSON() ([]byte, error) { - return jsonArray{Array: v.a}.MarshalJSON() -} - -func (v *ArrayValue) EQ(other Value) (bool, error) { - t := other.Type() - if t != TypeArray { - return false, nil - } - - return compareArrays(operatorEq, v.a, AsArray(other)) -} - -func (v *ArrayValue) GT(other Value) (bool, error) { - t := other.Type() - if t != TypeArray { - return false, nil - } - - return compareArrays(operatorGt, v.a, AsArray(other)) -} - -func (v *ArrayValue) GTE(other Value) (bool, error) { - t := other.Type() - if t != TypeArray { - return false, nil - } - - return compareArrays(operatorGte, v.a, AsArray(other)) -} - -func (v *ArrayValue) LT(other Value) (bool, error) { - t := other.Type() - if t != TypeArray { - return false, nil - } - - return compareArrays(operatorLt, v.a, AsArray(other)) -} - -func (v *ArrayValue) LTE(other Value) (bool, error) { - t := other.Type() - if t != TypeArray { - return false, nil - } - - return compareArrays(operatorLte, v.a, AsArray(other)) -} - -func (v *ArrayValue) Between(a, b Value) (bool, error) { - if a.Type() != TypeArray || b.Type() != TypeArray { - return false, nil - } - - ok, err := a.LTE(v) - if err != nil || !ok { - return false, err - } - - return b.GTE(v) -} diff --git a/internal/types/bigint.go b/internal/types/bigint.go new file mode 100644 index 000000000..cde6f2012 --- /dev/null +++ b/internal/types/bigint.go @@ -0,0 +1,312 @@ +package types + +import ( + "math" + "strconv" + + "github.com/chaisql/chai/internal/encoding" + "github.com/cockroachdb/errors" +) + +var _ TypeDefinition = BigintTypeDef{} + +type BigintTypeDef struct{} + +func (BigintTypeDef) New(v any) Value { + return NewBigintValue(v.(int64)) +} + +func (BigintTypeDef) Type() Type { + return TypeBigint +} + +func (BigintTypeDef) Decode(src []byte) (Value, int) { + x, n := encoding.DecodeInt(src) + return NewBigintValue(x), n +} + +func (BigintTypeDef) IsComparableWith(other Type) bool { + return other == TypeBigint || other == TypeInteger || other == TypeDouble +} + +func (BigintTypeDef) IsIndexComparableWith(other Type) bool { + return other == TypeBigint || other == TypeInteger +} + +var _ Numeric = NewBigintValue(0) +var _ Integral = NewBigintValue(0) +var _ Value = NewBigintValue(0) + +type BigintValue int64 + +// NewBigintValue returns a SQL BIGINT value. 
+func NewBigintValue(x int64) BigintValue {
+	return BigintValue(x)
+}
+
+func (v BigintValue) V() any {
+	return int64(v)
+}
+
+func (v BigintValue) Type() Type {
+	return TypeBigint
+}
+
+func (v BigintValue) TypeDef() TypeDefinition {
+	return BigintTypeDef{}
+}
+
+func (v BigintValue) IsZero() (bool, error) {
+	return v == 0, nil
+}
+
+func (v BigintValue) String() string {
+	return strconv.FormatInt(int64(v), 10)
+}
+
+func (v BigintValue) MarshalText() ([]byte, error) {
+	return []byte(strconv.FormatInt(int64(v), 10)), nil
+}
+
+func (v BigintValue) MarshalJSON() ([]byte, error) {
+	return v.MarshalText()
+}
+
+func (v BigintValue) Encode(dst []byte) ([]byte, error) {
+	return encoding.EncodeInt(dst, int64(v)), nil
+}
+
+func (v BigintValue) EncodeAsKey(dst []byte) ([]byte, error) {
+	return v.Encode(dst)
+}
+
+func (v BigintValue) CastAs(target Type) (Value, error) {
+	switch target {
+	case TypeBigint:
+		return v, nil
+	case TypeInteger:
+		if int64(v) > math.MaxInt32 || int64(v) < math.MinInt32 {
+			return nil, errors.Errorf("integer out of range")
+		}
+		return NewIntegerValue(int32(v)), nil
+	case TypeDouble:
+		return NewDoubleValue(float64(v)), nil
+	case TypeText:
+		return NewTextValue(v.String()), nil
+	}
+
+	return nil, errors.Errorf("cannot cast %s as %s", v.Type(), target)
+}
+
+func (v BigintValue) EQ(other Value) (bool, error) {
+	t := other.Type()
+	switch t {
+	case TypeBigint, TypeInteger:
+		return int64(v) == AsInt64(other), nil
+	case TypeDouble:
+		return float64(int64(v)) == AsFloat64(other), nil
+	default:
+		return false, nil
+	}
+}
+
+func (v BigintValue) GT(other Value) (bool, error) {
+	t := other.Type()
+	switch t {
+	case TypeBigint, TypeInteger:
+		return int64(v) > AsInt64(other), nil
+	case TypeDouble:
+		return float64(int64(v)) > AsFloat64(other), nil
+	default:
+		return false, nil
+	}
+}
+
+func (v BigintValue) GTE(other Value) (bool, error) {
+	t := other.Type()
+	switch t {
+	case TypeBigint, TypeInteger:
+		return int64(v) >= AsInt64(other), nil
+	case TypeDouble:
+		return float64(int64(v)) >= AsFloat64(other), nil
+	default:
+		return false, nil
+	}
+}
+
+func (v BigintValue) LT(other Value) (bool, error) {
+	t := other.Type()
+	switch t {
+	case TypeBigint, TypeInteger:
+		return int64(v) < AsInt64(other), nil
+	case TypeDouble:
+		return float64(int64(v)) < AsFloat64(other), nil
+	default:
+		return false, nil
+	}
+}
+
+func (v BigintValue) LTE(other Value) (bool, error) {
+	t := other.Type()
+	switch t {
+	case TypeBigint, TypeInteger:
+		return int64(v) <= AsInt64(other), nil
+	case TypeDouble:
+		return float64(int64(v)) <= AsFloat64(other), nil
+	default:
+		return false, nil
+	}
+}
+
+func (v BigintValue) Between(a, b Value) (bool, error) {
+	if !a.Type().IsNumber() || !b.Type().IsNumber() {
+		return false, nil
+	}
+
+	ok, err := a.LTE(v)
+	if err != nil || !ok {
+		return false, err
+	}
+
+	return b.GTE(v)
+}
+
+func (v BigintValue) Add(other Numeric) (Value, error) {
+	switch other.Type() {
+	case TypeBigint, TypeInteger:
+		xa := int64(v)
+		xb := AsInt64(other)
+		if isAddOverflow(xa, xb, math.MinInt64, math.MaxInt64) {
+			return nil, errors.New("bigint out of range")
+		}
+		xr := xa + xb
+		return NewBigintValue(xr), nil
+	case TypeDouble:
+		return NewDoubleValue(float64(int64(v)) + AsFloat64(other)), nil
+	}
+
+	return NewNullValue(), nil
+}
+
+func (v BigintValue) Sub(other Numeric) (Value, error) {
+	switch other.Type() {
+	case TypeBigint, TypeInteger:
+		xa := int64(v)
+		xb := AsInt64(other)
+		if isSubOverflow(xa, xb, math.MinInt64, math.MaxInt64) {
+
return nil, errors.New("bigint out of range") + } + xr := xa - xb + return NewBigintValue(xr), nil + case TypeDouble: + return NewDoubleValue(float64(int64(v)) - AsFloat64(other)), nil + } + + return NewNullValue(), nil +} + +func (v BigintValue) Mul(other Numeric) (Value, error) { + switch other.Type() { + case TypeBigint, TypeInteger: + xa := int64(v) + xb := AsInt64(other) + if xa == 0 || xb == 0 { + return NewBigintValue(0), nil + } + if isMulOverflow(xa, xb, math.MinInt64, math.MaxInt64) { + return nil, errors.New("bigint out of range") + } + xr := xa * xb + return NewBigintValue(xr), nil + case TypeDouble: + return NewDoubleValue(float64(int64(v)) * AsFloat64(other)), nil + } + + return NewNullValue(), nil +} + +func (v BigintValue) Div(other Numeric) (Value, error) { + switch other.Type() { + case TypeBigint, TypeInteger: + xa := int64(v) + xb := AsInt64(other) + if xb == 0 { + return NewNullValue(), nil + } + + return NewBigintValue(xa / xb), nil + case TypeDouble: + xa := float64(AsInt64(v)) + xb := AsFloat64(other) + if xb == 0 { + return NewNullValue(), nil + } + + return NewDoubleValue(xa / xb), nil + } + + return NewNullValue(), nil +} + +func (v BigintValue) Mod(other Numeric) (Value, error) { + switch other.Type() { + case TypeBigint, TypeInteger: + xa := int64(v) + xb := AsInt64(other) + if xb == 0 { + return NewNullValue(), nil + } + + return NewBigintValue(xa % xb), nil + case TypeDouble: + xa := float64(AsInt64(v)) + xb := AsFloat64(other) + mod := math.Mod(xa, xb) + if math.IsNaN(mod) { + return NewNullValue(), nil + } + + return NewDoubleValue(mod), nil + } + + return NewNullValue(), nil +} + +func (v BigintValue) BitwiseAnd(other Numeric) (Value, error) { + switch other.Type() { + case TypeBigint, TypeInteger: + return NewBigintValue(int64(v) & AsInt64(other)), nil + case TypeDouble: + xa := int64(v) + xb := int64(AsFloat64(other)) + return NewBigintValue(xa & xb), nil + } + + return NewNullValue(), nil +} + +func (v BigintValue) BitwiseOr(other Numeric) (Value, error) { + switch other.Type() { + case TypeBigint, TypeInteger: + return NewBigintValue(int64(v) | AsInt64(other)), nil + case TypeDouble: + xa := int64(v) + xb := int64(AsFloat64(other)) + return NewBigintValue(xa | xb), nil + } + + return NewNullValue(), nil +} + +func (v BigintValue) BitwiseXor(other Numeric) (Value, error) { + switch other.Type() { + case TypeBigint, TypeInteger: + return NewBigintValue(int64(v) ^ AsInt64(other)), nil + case TypeDouble: + xa := int64(v) + xb := int64(AsFloat64(other)) + return NewBigintValue(xa ^ xb), nil + } + + return NewNullValue(), nil +} diff --git a/internal/types/blob.go b/internal/types/blob.go index 64ec15596..7ed91282a 100644 --- a/internal/types/blob.go +++ b/internal/types/blob.go @@ -4,8 +4,36 @@ import ( "bytes" "encoding/base64" "encoding/hex" + + "github.com/chaisql/chai/internal/encoding" + "github.com/cockroachdb/errors" ) +var _ TypeDefinition = BlobTypeDef{} + +type BlobTypeDef struct{} + +func (BlobTypeDef) New(v any) Value { + return NewBlobValue(v.([]byte)) +} + +func (BlobTypeDef) Type() Type { + return TypeBlob +} + +func (BlobTypeDef) Decode(src []byte) (Value, int) { + x, n := encoding.DecodeBlob(src) + return NewBlobValue(x), n +} + +func (BlobTypeDef) IsComparableWith(other Type) bool { + return other == TypeBlob +} + +func (BlobTypeDef) IsIndexComparableWith(other Type) bool { + return other == TypeBlob +} + var _ Value = NewBlobValue(nil) type BlobValue []byte @@ -23,6 +51,10 @@ func (v BlobValue) Type() Type { return TypeBlob } +func (v 
BlobValue) TypeDef() TypeDefinition { + return BlobTypeDef{} +} + func (v BlobValue) IsZero() (bool, error) { return v == nil, nil } @@ -48,6 +80,25 @@ func (v BlobValue) MarshalJSON() ([]byte, error) { return dst, nil } +func (v BlobValue) Encode(dst []byte) ([]byte, error) { + return encoding.EncodeBlob(dst, v), nil +} + +func (v BlobValue) EncodeAsKey(dst []byte) ([]byte, error) { + return encoding.EncodeBlob(dst, v), nil +} + +func (v BlobValue) CastAs(target Type) (Value, error) { + switch target { + case TypeBlob: + return v, nil + case TypeText: + return NewTextValue(base64.StdEncoding.EncodeToString([]byte(v))), nil + } + + return nil, errors.Errorf("cannot cast %s as %s", v.Type(), target) +} + func (v BlobValue) EQ(other Value) (bool, error) { if other.Type() != TypeBlob { return false, nil diff --git a/internal/types/boolean.go b/internal/types/boolean.go index acdb17b09..66ee9d22e 100644 --- a/internal/types/boolean.go +++ b/internal/types/boolean.go @@ -1,6 +1,36 @@ package types -import "strconv" +import ( + "strconv" + + "github.com/chaisql/chai/internal/encoding" + "github.com/cockroachdb/errors" +) + +var _ TypeDefinition = BooleanTypeDef{} + +type BooleanTypeDef struct{} + +func (BooleanTypeDef) New(v any) Value { + return NewBooleanValue(v.(bool)) +} + +func (BooleanTypeDef) Type() Type { + return TypeBoolean +} + +func (t BooleanTypeDef) Decode(src []byte) (Value, int) { + b := encoding.DecodeBoolean(src) + return NewBooleanValue(b), 1 +} + +func (BooleanTypeDef) IsComparableWith(other Type) bool { + return other == TypeBoolean +} + +func (BooleanTypeDef) IsIndexComparableWith(other Type) bool { + return other == TypeBoolean +} var _ Value = NewBooleanValue(false) @@ -19,6 +49,10 @@ func (v BooleanValue) Type() Type { return TypeBoolean } +func (v BooleanValue) TypeDef() TypeDefinition { + return BooleanTypeDef{} +} + func (v BooleanValue) IsZero() (bool, error) { return !bool(v), nil } @@ -35,6 +69,35 @@ func (v BooleanValue) MarshalJSON() ([]byte, error) { return v.MarshalText() } +func (v BooleanValue) Encode(dst []byte) ([]byte, error) { + return encoding.EncodeBoolean(dst, bool(v)), nil +} + +func (v BooleanValue) EncodeAsKey(dst []byte) ([]byte, error) { + return v.Encode(dst) +} + +func (v BooleanValue) CastAs(target Type) (Value, error) { + switch target { + case TypeBoolean: + return v, nil + case TypeInteger: + if bool(v) { + return NewIntegerValue(1), nil + } + + return NewIntegerValue(0), nil + case TypeText: + return NewTextValue(v.String()), nil + } + + return nil, errors.Errorf("cannot cast %s as %s", v.Type(), target) +} + +func (v BooleanValue) ConvertToIndexedType(t Type) (Value, error) { + return v, nil +} + func (v BooleanValue) EQ(other Value) (bool, error) { if other.Type() != TypeBoolean { return false, nil diff --git a/internal/object/cast_test.go b/internal/types/cast_test.go similarity index 69% rename from internal/object/cast_test.go rename to internal/types/cast_test.go index d61072bc8..73bec105a 100644 --- a/internal/object/cast_test.go +++ b/internal/types/cast_test.go @@ -1,4 +1,4 @@ -package object +package types_test import ( "math" @@ -23,12 +23,6 @@ func TestCastAs(t *testing.T) { tsV := types.NewTimestampValue(now) textV := types.NewTextValue("foo") blobV := types.NewBlobValue([]byte("asdine")) - arrayV := types.NewArrayValue(NewValueBuffer(). - Append(types.NewTextValue("bar")). - Append(integerV)) - docV := types.NewObjectValue(NewFieldBuffer(). - Add("a", integerV). 
- Add("b", textV)) check := func(t *testing.T, targetType types.Type, tests []test) { t.Helper() @@ -37,7 +31,7 @@ func TestCastAs(t *testing.T) { t.Run(test.v.String(), func(t *testing.T) { t.Helper() - got, err := CastAs(test.v, targetType) + got, err := test.v.CastAs(targetType) if test.fails { assert.Error(t, err) } else { @@ -58,8 +52,6 @@ func TestCastAs(t *testing.T) { {types.NewTextValue("true"), boolV, false}, {types.NewTextValue("false"), types.NewBooleanValue(false), false}, {blobV, nil, true}, - {arrayV, nil, true}, - {docV, nil, true}, }) }) @@ -73,8 +65,6 @@ func TestCastAs(t *testing.T) { {types.NewTextValue("10"), integerV, false}, {types.NewTextValue("10.5"), integerV, false}, {blobV, nil, true}, - {arrayV, nil, true}, - {docV, nil, true}, {types.NewDoubleValue(math.MaxInt64 + 1), nil, true}, }) }) @@ -88,8 +78,6 @@ func TestCastAs(t *testing.T) { {types.NewTextValue("10"), types.NewDoubleValue(10), false}, {types.NewTextValue("10.5"), doubleV, false}, {blobV, nil, true}, - {arrayV, nil, true}, - {docV, nil, true}, }) }) @@ -100,8 +88,6 @@ func TestCastAs(t *testing.T) { {doubleV, nil, true}, {types.NewTextValue(now.Format(time.RFC3339Nano)), tsV, false}, {blobV, nil, true}, - {arrayV, nil, true}, - {docV, nil, true}, }) }) @@ -112,10 +98,6 @@ func TestCastAs(t *testing.T) { {doubleV, types.NewTextValue("10.5"), false}, {textV, textV, false}, {blobV, types.NewTextValue(`YXNkaW5l`), false}, - {arrayV, types.NewTextValue(`["bar", 10]`), false}, - {docV, - types.NewTextValue(`{"a": 10, "b": "foo"}`), - false}, }) }) @@ -127,34 +109,6 @@ func TestCastAs(t *testing.T) { {types.NewTextValue("YXNkaW5l"), types.NewBlobValue([]byte{0x61, 0x73, 0x64, 0x69, 0x6e, 0x65}), false}, {types.NewTextValue("not base64"), nil, true}, {blobV, blobV, false}, - {arrayV, nil, true}, - {docV, nil, true}, - }) - }) - - t.Run("array", func(t *testing.T) { - check(t, types.TypeArray, []test{ - {boolV, nil, true}, - {integerV, nil, true}, - {doubleV, nil, true}, - {types.NewTextValue(`["bar", 10]`), arrayV, false}, - {types.NewTextValue("abc"), nil, true}, - {blobV, nil, true}, - {arrayV, arrayV, false}, - {docV, nil, true}, - }) - }) - - t.Run("object", func(t *testing.T) { - check(t, types.TypeObject, []test{ - {boolV, nil, true}, - {integerV, nil, true}, - {doubleV, nil, true}, - {types.NewTextValue(`{"a": 10, "b": "foo"}`), docV, false}, - {types.NewTextValue("abc"), nil, true}, - {blobV, nil, true}, - {arrayV, nil, true}, - {docV, docV, false}, }) }) } diff --git a/internal/types/comparable.go b/internal/types/comparable.go deleted file mode 100644 index 96f6c965c..000000000 --- a/internal/types/comparable.go +++ /dev/null @@ -1,222 +0,0 @@ -package types - -import ( - "strings" -) - -type Comparable interface { - EQ(other Value) (bool, error) - GT(other Value) (bool, error) - GTE(other Value) (bool, error) - LT(other Value) (bool, error) - LTE(other Value) (bool, error) - Between(a, b Value) (bool, error) -} - -type operator uint8 - -const ( - operatorEq operator = iota + 1 - operatorGt - operatorGte - operatorLt - operatorLte -) - -func compareArrays(op operator, l Array, r Array) (bool, error) { - var i, j int - - for { - lv, lerr := l.GetByIndex(i) - rv, rerr := r.GetByIndex(j) - if lerr == nil { - i++ - } - if rerr == nil { - j++ - } - if lerr != nil || rerr != nil { - break - } - - if lv.Type().IsComparableWith(rv.Type()) { - isEq, err := lv.EQ(rv) - if err != nil { - return false, err - } - if !isEq { - switch op { - case operatorEq: - return false, nil - case operatorGt: - return 
lv.GT(rv) - case operatorGte: - return lv.GTE(rv) - case operatorLt: - return lv.LT(rv) - case operatorLte: - return lv.LTE(rv) - } - } - } else { - switch op { - case operatorEq: - return false, nil - case operatorGt, operatorGte: - return lv.Type() > rv.Type(), nil - case operatorLt, operatorLte: - return lv.Type() < rv.Type(), nil - } - } - } - - switch { - case i > j: - switch op { - case operatorEq, operatorLt, operatorLte: - return false, nil - default: - return true, nil - } - case i < j: - switch op { - case operatorEq, operatorGt, operatorGte: - return false, nil - default: - return true, nil - } - default: - switch op { - case operatorEq, operatorGte, operatorLte: - return true, nil - default: - return false, nil - } - } -} - -func compareObjects(op operator, l, r Object) (bool, error) { - lf, err := Fields(l) - if err != nil { - return false, err - } - rf, err := Fields(r) - if err != nil { - return false, err - } - - if len(lf) == 0 && len(rf) > 0 { - switch op { - case operatorEq: - return false, nil - case operatorGt: - return false, nil - case operatorGte: - return false, nil - case operatorLt: - return true, nil - case operatorLte: - return true, nil - } - } - - if len(rf) == 0 && len(lf) > 0 { - switch op { - case operatorEq: - return false, nil - case operatorGt: - return true, nil - case operatorGte: - return true, nil - case operatorLt: - return false, nil - case operatorLte: - return false, nil - } - } - - var i, j int - - for i < len(lf) && j < len(rf) { - if cmp := strings.Compare(lf[i], rf[j]); cmp != 0 { - switch op { - case operatorEq: - return false, nil - case operatorGt: - return cmp > 0, nil - case operatorGte: - return cmp >= 0, nil - case operatorLt: - return cmp < 0, nil - case operatorLte: - return cmp <= 0, nil - } - } - - lv, lerr := l.GetByField(lf[i]) - rv, rerr := r.GetByField(rf[j]) - if lerr == nil { - i++ - } - if rerr == nil { - j++ - } - if lerr != nil || rerr != nil { - break - } - - if lv.Type().IsComparableWith(rv.Type()) { - isEq, err := lv.EQ(rv) - if err != nil { - return false, err - } - if !isEq { - switch op { - case operatorEq: - return false, nil - case operatorGt: - return lv.GT(rv) - case operatorGte: - return lv.GTE(rv) - case operatorLt: - return lv.LT(rv) - case operatorLte: - return lv.LTE(rv) - } - } - } else { - switch op { - case operatorEq: - return false, nil - case operatorGt, operatorGte: - return lv.Type() > rv.Type(), nil - case operatorLt, operatorLte: - return lv.Type() < rv.Type(), nil - } - } - } - - switch { - case i > j: - switch op { - case operatorEq, operatorLt, operatorLte: - return false, nil - default: - return true, nil - } - case i < j: - switch op { - case operatorEq, operatorGt, operatorGte: - return false, nil - default: - return true, nil - } - default: - switch op { - case operatorEq, operatorGte, operatorLte: - return true, nil - default: - return false, nil - } - } -} diff --git a/internal/types/comparable_test.go b/internal/types/comparable_test.go index 3b42d9f76..30d415e09 100644 --- a/internal/types/comparable_test.go +++ b/internal/types/comparable_test.go @@ -6,7 +6,6 @@ import ( "testing" "time" - "github.com/chaisql/chai/internal/object" "github.com/chaisql/chai/internal/testutil/assert" "github.com/chaisql/chai/internal/types" "github.com/golang-module/carbon/v2" @@ -14,13 +13,21 @@ import ( ) func jsonToInteger(t testing.TB, x string) types.Value { - var i int64 + var i int32 err := json.Unmarshal([]byte(x), &i) assert.NoError(t, err) return types.NewIntegerValue(i) } +func jsonToBigint(t 
testing.TB, x string) types.Value { + var i int64 + err := json.Unmarshal([]byte(x), &i) + assert.NoError(t, err) + + return types.NewBigintValue(i) +} + func jsonToDouble(t testing.TB, x string) types.Value { var f float64 err := json.Unmarshal([]byte(x), &f) @@ -55,22 +62,6 @@ func toBlob(t testing.TB, x string) types.Value { return types.NewBlobValue([]byte(x)) } -func jsonToArray(t testing.TB, x string) types.Value { - var vb object.ValueBuffer - err := json.Unmarshal([]byte(x), &vb) - assert.NoError(t, err) - - return types.NewArrayValue(&vb) -} - -func jsonToObject(t testing.TB, x string) types.Value { - var fb object.FieldBuffer - err := json.Unmarshal([]byte(x), &fb) - assert.NoError(t, err) - - return types.NewObjectValue(&fb) -} - var now = time.Now().Format(time.RFC3339Nano) var nowPlusOne = time.Now().Add(time.Second).Format(time.RFC3339Nano) @@ -117,6 +108,24 @@ func TestCompare(t *testing.T) { {"<=", "1", "2", true, jsonToInteger}, {"<=", "2", "2", true, jsonToInteger}, + // bigint + {"=", "2", "1", false, jsonToBigint}, + {"=", "2", "2", true, jsonToBigint}, + {"!=", "2", "1", true, jsonToBigint}, + {"!=", "2", "2", false, jsonToBigint}, + {">", "2", "1", true, jsonToBigint}, + {">", "1", "2", false, jsonToBigint}, + {">", "2", "2", false, jsonToBigint}, + {">=", "2", "1", true, jsonToBigint}, + {">=", "1", "2", false, jsonToBigint}, + {">=", "2", "2", true, jsonToBigint}, + {"<", "2", "1", false, jsonToBigint}, + {"<", "1", "2", true, jsonToBigint}, + {"<", "2", "2", false, jsonToBigint}, + {"<=", "2", "1", false, jsonToBigint}, + {"<=", "1", "2", true, jsonToBigint}, + {"<=", "2", "2", true, jsonToBigint}, + // double {"=", "2", "1", false, jsonToDouble}, {"=", "2", "2", true, jsonToDouble}, @@ -189,76 +198,6 @@ func TestCompare(t *testing.T) { {"<=", "b", "a", false, toBlob}, {"<=", "a", "b", true, toBlob}, {"<=", "b", "b", true, toBlob}, - - // array - {"=", `[]`, `[]`, true, jsonToArray}, - {"=", `[1]`, `[1]`, true, jsonToArray}, - {"=", `[1]`, `[]`, false, jsonToArray}, - {"=", `[1.0, 2]`, `[1, 2]`, true, jsonToArray}, - {"=", `[1,2,3]`, `[1,2,3]`, true, jsonToArray}, - {"!=", `[1]`, `[5]`, true, jsonToArray}, - {"!=", `[1]`, `[1, 1]`, true, jsonToArray}, - {"!=", `[1,2,3]`, `[1,2,3]`, false, jsonToArray}, - {"!=", `[1]`, `[]`, true, jsonToArray}, - {">", `[2]`, `[1]`, true, jsonToArray}, - {">", `[2]`, `[1, 1000]`, true, jsonToArray}, - {">", `[1]`, `[1, 1000]`, false, jsonToArray}, - {">", `[1, 2]`, `[1, 1000]`, false, jsonToArray}, - {">", `[1, 10]`, `[1, true]`, true, jsonToArray}, - {">", `[1, true]`, `[1, 10]`, false, jsonToArray}, - {">", `[2, 1000]`, `[1]`, true, jsonToArray}, - {">", `[2, 1000]`, `[2]`, true, jsonToArray}, - {">", `[1,2,3]`, `[1,2,3]`, false, jsonToArray}, - {">", `[1,2,3]`, `[]`, true, jsonToArray}, - {">=", `[2]`, `[1]`, true, jsonToArray}, - {">=", `[2]`, `[2]`, true, jsonToArray}, - {">=", `[2]`, `[1, 1000]`, true, jsonToArray}, - {">=", `[1]`, `[1, 1000]`, false, jsonToArray}, - {">=", `[1, 2]`, `[1, 2]`, true, jsonToArray}, - {">=", `[1, 2]`, `[1, 1000]`, false, jsonToArray}, - {">=", `[1, 10]`, `[1, true]`, true, jsonToArray}, - {">=", `[1, true]`, `[1, 10]`, false, jsonToArray}, - {">=", `[2, 1000]`, `[1]`, true, jsonToArray}, - {">=", `[2, 1000]`, `[2]`, true, jsonToArray}, - {">=", `[1,2,3]`, `[1,2,3]`, true, jsonToArray}, - {">=", `[1,2,3]`, `[]`, true, jsonToArray}, - {"<", `[1]`, `[2]`, true, jsonToArray}, - {"<", `[1,2,3]`, `[1,2]`, false, jsonToArray}, - {"<", `[1,2,3]`, `[1,2,3]`, false, jsonToArray}, - {"<", `[1,2]`, 
`[1,2,3]`, true, jsonToArray}, - {"<", `[1, 1000]`, `[2]`, true, jsonToArray}, - {"<", `[2]`, `[2, 1000]`, true, jsonToArray}, - {"<", `[1,2,3]`, `[]`, false, jsonToArray}, - {"<", `[]`, `[1,2,3]`, true, jsonToArray}, - {"<", `[1, 10]`, `[1, true]`, false, jsonToArray}, - {"<", `[1, true]`, `[1, 10]`, true, jsonToArray}, - {"<=", `[1]`, `[2]`, true, jsonToArray}, - {"<=", `[1, 1000]`, `[2]`, true, jsonToArray}, - {"<=", `[1,2,3]`, `[1,2]`, false, jsonToArray}, - {">=", `[2]`, `[1]`, true, jsonToArray}, - {">=", `[2]`, `[2]`, true, jsonToArray}, - {">=", `[2]`, `[1, 1000]`, true, jsonToArray}, - {">=", `[2, 1000]`, `[1]`, true, jsonToArray}, - {"<=", `[1,2,3]`, `[1,2,3]`, true, jsonToArray}, - {"<=", `[]`, `[]`, true, jsonToArray}, - {"<=", `[]`, `[1,2,3]`, true, jsonToArray}, - - // object - {"=", `{}`, `{}`, true, jsonToObject}, - {"=", `{"a": 1}`, `{"a": 1}`, true, jsonToObject}, - {"=", `{"a": 1.0}`, `{"a": 1}`, true, jsonToObject}, - {"=", `{"a": 1, "b": 2}`, `{"b": 2, "a": 1}`, true, jsonToObject}, - {"=", `{"a": 1, "b": {"a": 1}}`, `{"b": {"a": 1}, "a": 1}`, true, jsonToObject}, - {">", `{"a": 2}`, `{"a": 1}`, true, jsonToObject}, - {">", `{"b": 1}`, `{"a": 1}`, true, jsonToObject}, - {">", `{"a": 1}`, `{"a": 1}`, false, jsonToObject}, - {">", `{"a": 1}`, `{"a": true}`, true, jsonToObject}, - {"<", `{"a": 1}`, `{"a": 2}`, true, jsonToObject}, - {"<", `{"a": 1}`, `{"b": 1}`, true, jsonToObject}, - {"<", `{"a": 1}`, `{"a": 1}`, false, jsonToObject}, - {"<", `{"a": 1}`, `{"a": true}`, false, jsonToObject}, - {">=", `{"a": 1}`, `{"a": 1}`, true, jsonToObject}, - {"<=", `{"a": 1}`, `{"a": 1}`, true, jsonToObject}, } for _, test := range tests { diff --git a/internal/types/double.go b/internal/types/double.go index cf9a289d7..c877187e9 100644 --- a/internal/types/double.go +++ b/internal/types/double.go @@ -3,8 +3,36 @@ package types import ( "math" "strconv" + + "github.com/chaisql/chai/internal/encoding" + "github.com/cockroachdb/errors" ) +var _ TypeDefinition = DoubleTypeDef{} + +type DoubleTypeDef struct{} + +func (DoubleTypeDef) New(v any) Value { + return NewDoubleValue(v.(float64)) +} + +func (DoubleTypeDef) Type() Type { + return TypeDouble +} + +func (DoubleTypeDef) Decode(src []byte) (Value, int) { + x, n := encoding.DecodeFloat(src) + return NewDoubleValue(x), n +} + +func (DoubleTypeDef) IsComparableWith(other Type) bool { + return other == TypeDouble || other == TypeInteger || other == TypeBigint +} + +func (DoubleTypeDef) IsIndexComparableWith(other Type) bool { + return other == TypeDouble +} + var _ Numeric = NewDoubleValue(0) type DoubleValue float64 @@ -22,6 +50,10 @@ func (v DoubleValue) Type() Type { return TypeDouble } +func (v DoubleValue) TypeDef() TypeDefinition { + return DoubleTypeDef{} +} + func (v DoubleValue) IsZero() (bool, error) { return v == 0, nil } @@ -66,12 +98,47 @@ func (v DoubleValue) MarshalJSON() ([]byte, error) { return strconv.AppendFloat(nil, AsFloat64(v), fmt, prec, 64), nil } +func (v DoubleValue) Encode(dst []byte) ([]byte, error) { + return encoding.EncodeFloat(dst, float64(v)), nil +} + +func (v DoubleValue) EncodeAsKey(dst []byte) ([]byte, error) { + return encoding.EncodeFloat64(dst, float64(v)), nil +} + +func (v DoubleValue) CastAs(target Type) (Value, error) { + switch target { + case TypeDouble: + return v, nil + case TypeInteger: + f := float64(v) + if f > 0 && (int32(f) < 0 || f >= math.MaxInt32) { + return nil, errors.New("integer out of range") + } + return NewIntegerValue(int32(v)), nil + case TypeBigint: + f := float64(v) + if f 
> 0 && (int64(f) < 0 || f >= math.MaxInt64) { + return nil, errors.New("integer out of range") + } + return NewBigintValue(int64(v)), nil + case TypeText: + enc, err := v.MarshalJSON() + if err != nil { + return nil, err + } + return NewTextValue(string(enc)), nil + } + + return nil, errors.Errorf("cannot cast %s as %s", v.Type(), target) +} + func (v DoubleValue) EQ(other Value) (bool, error) { t := other.Type() switch t { case TypeDouble: return float64(v) == AsFloat64(other), nil - case TypeInteger: + case TypeInteger, TypeBigint: return float64(v) == float64(AsInt64(other)), nil default: return false, nil @@ -83,7 +150,7 @@ func (v DoubleValue) GT(other Value) (bool, error) { switch t { case TypeDouble: return float64(v) > AsFloat64(other), nil - case TypeInteger: + case TypeInteger, TypeBigint: return float64(v) > float64(AsInt64(other)), nil default: return false, nil @@ -95,7 +162,7 @@ func (v DoubleValue) GTE(other Value) (bool, error) { switch t { case TypeDouble: return float64(v) >= AsFloat64(other), nil - case TypeInteger: + case TypeInteger, TypeBigint: return float64(v) >= float64(AsInt64(other)), nil default: return false, nil @@ -107,7 +174,7 @@ func (v DoubleValue) LT(other Value) (bool, error) { switch t { case TypeDouble: return float64(v) < AsFloat64(other), nil - case TypeInteger: + case TypeInteger, TypeBigint: return float64(v) < float64(AsInt64(other)), nil default: return false, nil @@ -119,7 +186,7 @@ func (v DoubleValue) LTE(other Value) (bool, error) { switch t { case TypeDouble: return float64(v) <= AsFloat64(other), nil - case TypeInteger: + case TypeInteger, TypeBigint: return float64(v) <= float64(AsInt64(other)), nil default: return false, nil @@ -141,7 +208,7 @@ func (v DoubleValue) Between(a, b Value) (bool, error) { func (v DoubleValue) Add(other Numeric) (Value, error) { switch other.Type() { - case TypeInteger: + case TypeInteger, TypeBigint: return NewDoubleValue(float64(v) + float64(AsInt64(other))), nil case TypeDouble: return NewDoubleValue(float64(v) + AsFloat64(other)), nil @@ -152,7 +219,7 @@ func (v DoubleValue) Add(other Numeric) (Value, error) { func (v DoubleValue) Sub(other Numeric) (Value, error) { switch other.Type() { - case TypeInteger: + case TypeInteger, TypeBigint: return NewDoubleValue(float64(v) - float64(AsInt64(other))), nil case TypeDouble: return NewDoubleValue(float64(v) - AsFloat64(other)), nil @@ -163,7 +230,7 @@ func (v DoubleValue) Sub(other Numeric) (Value, error) { func (v DoubleValue) Mul(other Numeric) (Value, error) { switch other.Type() { - case TypeInteger: + case TypeInteger, TypeBigint: return NewDoubleValue(float64(v) * float64(AsInt64(other))), nil case TypeDouble: return NewDoubleValue(float64(v) * AsFloat64(other)), nil @@ -174,7 +241,7 @@ func (v DoubleValue) Mul(other Numeric) (Value, error) { func (v DoubleValue) Div(other Numeric) (Value, error) { switch other.Type() { - case TypeInteger: + case TypeInteger, TypeBigint: xb := float64(AsInt64(other)) if xb == 0 { return NewNullValue(), nil @@ -195,7 +262,7 @@ func (v DoubleValue) Div(other Numeric) (Value, error) { func (v DoubleValue) Mod(other Numeric) (Value, error) { switch other.Type() { - case TypeInteger: + case TypeInteger, TypeBigint: xb := float64(AsInt64(other)) xr := math.Mod(float64(v), xb) if math.IsNaN(xr) { @@ -215,42 +282,3 @@ func (v DoubleValue) Mod(other Numeric) (Value, error) { return NewNullValue(), nil } - -func (v DoubleValue) BitwiseAnd(other Numeric) (Value, error) { - switch other.Type() { - case TypeInteger: - return 
NewIntegerValue(int64(v) & AsInt64(other)), nil - case TypeDouble: - xa := int64(v) - xb := int64(AsFloat64(other)) - return NewIntegerValue(xa & xb), nil - } - - return NewNullValue(), nil -} - -func (v DoubleValue) BitwiseOr(other Numeric) (Value, error) { - switch other.Type() { - case TypeInteger: - return NewIntegerValue(int64(v) | AsInt64(other)), nil - case TypeDouble: - xa := int64(v) - xb := int64(AsFloat64(other)) - return NewIntegerValue(xa | xb), nil - } - - return NewNullValue(), nil -} - -func (v DoubleValue) BitwiseXor(other Numeric) (Value, error) { - switch other.Type() { - case TypeInteger: - return NewIntegerValue(int64(v) ^ AsInt64(other)), nil - case TypeDouble: - xa := int64(v) - xb := int64(AsFloat64(other)) - return NewIntegerValue(xa ^ xb), nil - } - - return NewNullValue(), nil -} diff --git a/internal/types/encoding.go b/internal/types/encoding.go new file mode 100644 index 000000000..8a71eb000 --- /dev/null +++ b/internal/types/encoding.go @@ -0,0 +1,88 @@ +package types + +import ( + "github.com/chaisql/chai/internal/encoding" +) + +var encodedTypeToTypeDefs = map[byte]TypeDefinition{ + encoding.NullValue: NullTypeDef{}, + encoding.FalseValue: BooleanTypeDef{}, + encoding.TrueValue: BooleanTypeDef{}, + encoding.Int8Value: IntegerTypeDef{}, + encoding.Int16Value: IntegerTypeDef{}, + encoding.Int32Value: IntegerTypeDef{}, + encoding.Int64Value: BigintTypeDef{}, + encoding.Uint8Value: IntegerTypeDef{}, + encoding.Uint16Value: IntegerTypeDef{}, + encoding.Uint32Value: IntegerTypeDef{}, + encoding.Uint64Value: BigintTypeDef{}, + encoding.Float64Value: DoubleTypeDef{}, + encoding.TextValue: TextTypeDef{}, + encoding.BlobValue: BlobTypeDef{}, +} + +func DecodeValue(b []byte) (v Value, n int) { + t := b[0] + // deal with descending values + if t > 128 { + t = 255 - t + } + + if t >= encoding.IntSmallValue && t < encoding.Uint8Value { + return IntegerTypeDef{}.Decode(b) + } + + return encodedTypeToTypeDefs[t].Decode(b) +} + +func DecodeValues(b []byte) []Value { + var values []Value + + for len(b) > 0 { + v, n := DecodeValue(b) + values = append(values, v) + b = b[n:] + } + + return values +} + +func EncodeValuesAsKey(dst []byte, values ...Value) ([]byte, error) { + var err error + + for _, v := range values { + dst, err = EncodeValueAsKey(dst, v, false) + if err != nil { + return nil, err + } + } + + return dst, nil +} + +func EncodeValueAsKey(dst []byte, v Value, desc bool) ([]byte, error) { + newDst, err := v.EncodeAsKey(dst) + if err != nil { + return nil, err + } + + if desc { + newDst, _ = Desc(newDst, len(newDst)-len(dst)) + } + + return newDst, nil +} + +// Desc changes the type of the encoded value to its descending counterpart. +// It is meant to be used in combination with one of the Encode* functions. 
+// +// var buf []byte +// buf, n = encoding.Desc(encoding.EncodeInt(buf, 10)) +func Desc(dst []byte, n int) ([]byte, int) { + if n == 0 { + return dst, 0 + } + + dst[len(dst)-n] = 255 - dst[len(dst)-n] + return dst, n +} diff --git a/internal/encoding/encoding_test.go b/internal/types/encoding_test.go similarity index 70% rename from internal/encoding/encoding_test.go rename to internal/types/encoding_test.go index 1ef1a21e6..33b55461d 100644 --- a/internal/encoding/encoding_test.go +++ b/internal/types/encoding_test.go @@ -1,4 +1,4 @@ -package encoding_test +package types_test import ( "fmt" @@ -6,7 +6,7 @@ import ( "testing" "github.com/chaisql/chai/internal/encoding" - "github.com/chaisql/chai/internal/object" + "github.com/chaisql/chai/internal/row" "github.com/chaisql/chai/internal/tree" "github.com/chaisql/chai/internal/types" "github.com/stretchr/testify/require" @@ -71,41 +71,26 @@ func TestOrdering(t *testing.T) { } func TestEncodeDecode(t *testing.T) { - userMapDoc := object.NewFromMap(map[string]any{ + userMapDoc := row.NewFromMap(map[string]any{ "age": 10, "name": "john", }) - addressMapDoc := object.NewFromMap(map[string]any{ - "city": "Ajaccio", - "country": "France", - }) - - complexArray := object.NewValueBuffer(). - Append(types.NewBooleanValue(true)). - Append(types.NewIntegerValue(-40)). - Append(types.NewDoubleValue(-3.14)). - Append(types.NewDoubleValue(3)). - Append(types.NewBlobValue([]byte("blob"))). - Append(types.NewTextValue("hello")). - Append(types.NewObjectValue(addressMapDoc)). - Append(types.NewArrayValue(object.NewValueBuffer().Append(types.NewIntegerValue(11)))) - tests := []struct { name string - d types.Object + r row.Row expected string fails bool }{ { "empty doc", - object.NewFieldBuffer(), + row.NewColumnBuffer(), `{}`, false, }, { - "object.FieldBuffer", - object.NewFieldBuffer(). + "row.ColumnBuffer", + row.NewColumnBuffer(). Add("age", types.NewIntegerValue(10)). Add("name", types.NewTextValue("john")), `{"age": 10, "name": "john"}`, @@ -118,36 +103,26 @@ func TestEncodeDecode(t *testing.T) { false, }, { - "duplicate field name", - object.NewFieldBuffer(). + "duplicate column name", + row.NewColumnBuffer(). Add("age", types.NewIntegerValue(10)). Add("age", types.NewIntegerValue(10)), - ``, - true, - }, - { - "Nested types.Object", - object.NewFieldBuffer(). - Add("age", types.NewIntegerValue(10)). - Add("name", types.NewTextValue("john")). - Add("address", types.NewObjectValue(addressMapDoc)). - Add("array", types.NewArrayValue(complexArray)), - `{"age": 10, "name": "john", "address": {"city": "Ajaccio", "country": "France"}, "array": [true, -40, -3.14, 3, "YmxvYg==", "hello", {"city": "Ajaccio", "country": "France"}, [11]]}`, + `{"age": 10}`, false, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - buf, err := encoding.EncodeValue(nil, types.NewObjectValue(test.d), false) + buf, err := types.EncodeValuesAsKey(nil, row.Flatten(test.r)...) 
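
Side note on the key-encoding round trip exercised by this test (a standalone sketch, not code from the patch): EncodeValueAsKey appends the value with EncodeAsKey and, when desc is true, complements the leading type byte via Desc; DecodeValue detects tags above 128 and flips them back before dispatching on the TypeDefinition. Assuming the internal/types API exactly as introduced above:

    package main

    import (
        "fmt"

        "github.com/chaisql/chai/internal/types"
    )

    func main() {
        // Encode the same value as an ascending and a descending index key.
        asc, err := types.EncodeValueAsKey(nil, types.NewIntegerValue(7), false)
        if err != nil {
            panic(err)
        }
        desc, err := types.EncodeValueAsKey(nil, types.NewIntegerValue(7), true)
        if err != nil {
            panic(err)
        }

        // EncodeValueAsKey calls Desc on the bytes it just appended, which
        // complements the leading type byte so descending keys sort in reverse.
        fmt.Println(asc[0] == 255-desc[0]) // true
    }
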
if test.fails { require.Error(t, err) return } require.NoError(t, err) - v, _ := encoding.DecodeValue(buf, false) + r := row.Unflatten(types.DecodeValues(buf)) - data, err := v.MarshalJSON() + data, err := r.MarshalJSON() require.NoError(t, err) require.JSONEq(t, test.expected, string(data)) }) diff --git a/internal/types/integer.go b/internal/types/integer.go index dc3e05c5c..91f2349c9 100644 --- a/internal/types/integer.go +++ b/internal/types/integer.go @@ -3,25 +3,63 @@ package types import ( "math" "strconv" + + "github.com/chaisql/chai/internal/encoding" + "github.com/cockroachdb/errors" ) +var _ TypeDefinition = IntegerTypeDef{} + +type IntegerTypeDef struct{} + +func (IntegerTypeDef) New(v any) Value { + return NewIntegerValue(v.(int32)) +} + +func (IntegerTypeDef) Type() Type { + return TypeInteger +} + +func (IntegerTypeDef) Decode(src []byte) (Value, int) { + x, n := encoding.DecodeInt(src) + if x < math.MinInt32 || x > math.MaxInt32 { + panic(errors.New("integer out of range")) + } + + return NewIntegerValue(int32(x)), n +} + +func (IntegerTypeDef) IsComparableWith(other Type) bool { + return other == TypeInteger || other == TypeBigint || other == TypeDouble +} + +func (IntegerTypeDef) IsIndexComparableWith(other Type) bool { + return other == TypeInteger || other == TypeBigint +} + var _ Numeric = NewIntegerValue(0) +var _ Integral = NewIntegerValue(0) +var _ Value = NewIntegerValue(0) -type IntegerValue int64 +type IntegerValue int32 // NewIntegerValue returns a SQL INTEGER value. -func NewIntegerValue(x int64) IntegerValue { +func NewIntegerValue(x int32) IntegerValue { return IntegerValue(x) } func (v IntegerValue) V() any { - return int64(v) + return int32(v) } func (v IntegerValue) Type() Type { return TypeInteger } +func (v IntegerValue) TypeDef() TypeDefinition { + return IntegerTypeDef{} +} + func (v IntegerValue) IsZero() (bool, error) { return v == 0, nil } @@ -38,13 +76,40 @@ func (v IntegerValue) MarshalJSON() ([]byte, error) { return v.MarshalText() } +func (v IntegerValue) Encode(dst []byte) ([]byte, error) { + return encoding.EncodeInt(dst, int64(v)), nil +} + +func (v IntegerValue) EncodeAsKey(dst []byte) ([]byte, error) { + return v.Encode(dst) +} + +func (v IntegerValue) CastAs(target Type) (Value, error) { + switch target { + case TypeInteger: + return v, nil + case TypeBoolean: + return NewBooleanValue(int32(v) != 0), nil + case TypeBigint: + return NewBigintValue(int64(v)), nil + case TypeDouble: + return NewDoubleValue(float64(v)), nil + case TypeText: + return NewTextValue(v.String()), nil + } + + return nil, errors.Errorf("cannot cast %s as %s", v.Type(), target) +} + func (v IntegerValue) EQ(other Value) (bool, error) { t := other.Type() switch t { case TypeInteger: + return int32(v) == AsInt32(other), nil + case TypeBigint: return int64(v) == AsInt64(other), nil case TypeDouble: - return float64(int64(v)) == AsFloat64(other), nil + return float64(int32(v)) == AsFloat64(other), nil default: return false, nil } @@ -54,9 +119,11 @@ func (v IntegerValue) GT(other Value) (bool, error) { t := other.Type() switch t { case TypeInteger: + return int32(v) > AsInt32(other), nil + case TypeBigint: return int64(v) > AsInt64(other), nil case TypeDouble: - return float64(int64(v)) > AsFloat64(other), nil + return float64(int32(v)) > AsFloat64(other), nil default: return false, nil } @@ -66,9 +133,11 @@ func (v IntegerValue) GTE(other Value) (bool, error) { t := other.Type() switch t { case TypeInteger: + return int32(v) >= AsInt32(other), nil + case TypeBigint: 
return int64(v) >= AsInt64(other), nil case TypeDouble: - return float64(int64(v)) >= AsFloat64(other), nil + return float64(int32(v)) >= AsFloat64(other), nil default: return false, nil } @@ -78,9 +147,11 @@ func (v IntegerValue) LT(other Value) (bool, error) { t := other.Type() switch t { case TypeInteger: + return int32(v) < AsInt32(other), nil + case TypeBigint: return int64(v) < AsInt64(other), nil case TypeDouble: - return float64(int64(v)) <= AsFloat64(other), nil + return float64(int32(v)) <= AsFloat64(other), nil default: return false, nil } @@ -90,9 +161,11 @@ func (v IntegerValue) LTE(other Value) (bool, error) { t := other.Type() switch t { case TypeInteger: + return int32(v) <= AsInt32(other), nil + case TypeBigint: return int64(v) <= AsInt64(other), nil case TypeDouble: - return float64(int64(v)) <= AsFloat64(other), nil + return float64(int32(v)) <= AsFloat64(other), nil default: return false, nil } @@ -114,17 +187,25 @@ func (v IntegerValue) Between(a, b Value) (bool, error) { func (v IntegerValue) Add(other Numeric) (Value, error) { switch other.Type() { case TypeInteger: + xa := int32(v) + xb := AsInt32(other) + if isAddOverflow(xa, xb, math.MinInt32, math.MaxInt32) { + return nil, errors.New("integer out of range") + } + + xr := xa + xb + return NewIntegerValue(xr), nil + case TypeBigint: xa := int64(v) xb := AsInt64(other) - xr := xa + xb - // if there is an integer overflow - // convert to float - if (xr > xa) != (xb > 0) { - return NewDoubleValue(float64(xa) + float64(xb)), nil + if isAddOverflow(xa, xb, math.MinInt64, math.MaxInt64) { + return nil, errors.New("bigint out of range") } - return NewIntegerValue(xr), nil + + xr := xa + xb + return NewBigintValue(xr), nil case TypeDouble: - return NewDoubleValue(float64(int64(v)) + AsFloat64(other)), nil + return NewDoubleValue(float64(int32(v)) + AsFloat64(other)), nil } return NewNullValue(), nil @@ -133,17 +214,24 @@ func (v IntegerValue) Add(other Numeric) (Value, error) { func (v IntegerValue) Sub(other Numeric) (Value, error) { switch other.Type() { case TypeInteger: + xa := int32(v) + xb := AsInt32(other) + if isSubOverflow(xa, xb, math.MinInt32, math.MaxInt32) { + return nil, errors.New("integer out of range") + } + + xr := xa - xb + return NewIntegerValue(xr), nil + case TypeBigint: xa := int64(v) xb := AsInt64(other) - xr := xa - xb - // if there is an integer overflow - // convert to float - if (xr < xa) != (xb > 0) { - return NewDoubleValue(float64(xa) - float64(xb)), nil + if isSubOverflow(xa, xb, math.MinInt64, math.MaxInt64) { + return nil, errors.New("bigint out of range") } - return NewIntegerValue(xr), nil + xr := xa - xb + return NewBigintValue(xr), nil case TypeDouble: - return NewDoubleValue(float64(int64(v)) - AsFloat64(other)), nil + return NewDoubleValue(float64(int32(v)) - AsFloat64(other)), nil } return NewNullValue(), nil @@ -152,24 +240,25 @@ func (v IntegerValue) Sub(other Numeric) (Value, error) { func (v IntegerValue) Mul(other Numeric) (Value, error) { switch other.Type() { case TypeInteger: - xa := int64(v) - xb := AsInt64(other) - if xa == 0 || xb == 0 { - return NewIntegerValue(0), nil + xa := int32(v) + xb := AsInt32(other) + if isMulOverflow(xa, xb, math.MinInt32, math.MaxInt32) { + return nil, errors.New("integer out of range") } xr := xa * xb - // if there is no integer overflow - // return an int, otherwise - // convert to float - if (xr < 0) == ((xa < 0) != (xb < 0)) { - if xr/xb == xa { - return NewIntegerValue(xr), nil - } + + return NewIntegerValue(xr), nil + case TypeBigint: + 
xa := int64(v) + xb := AsInt64(other) + if isMulOverflow(xa, xb, math.MinInt64, math.MaxInt64) { + return nil, errors.New("integer out of range") } - return NewDoubleValue(float64(xa) * float64(xb)), nil + xr := xa * xb + return NewBigintValue(xr), nil case TypeDouble: - return NewDoubleValue(float64(int64(v)) * AsFloat64(other)), nil + return NewDoubleValue(float64(int32(v)) * AsFloat64(other)), nil } return NewNullValue(), nil @@ -178,13 +267,21 @@ func (v IntegerValue) Mul(other Numeric) (Value, error) { func (v IntegerValue) Div(other Numeric) (Value, error) { switch other.Type() { case TypeInteger: + xa := int32(v) + xb := AsInt32(other) + if xb == 0 { + return nil, errors.New("division by zero") + } + + return NewIntegerValue(xa / xb), nil + case TypeBigint: xa := int64(v) xb := AsInt64(other) if xb == 0 { - return NewNullValue(), nil + return nil, errors.New("division by zero") } - return NewIntegerValue(xa / xb), nil + return NewBigintValue(xa / xb), nil case TypeDouble: xa := float64(AsInt64(v)) xb := AsFloat64(other) @@ -201,13 +298,21 @@ func (v IntegerValue) Div(other Numeric) (Value, error) { func (v IntegerValue) Mod(other Numeric) (Value, error) { switch other.Type() { case TypeInteger: + xa := int32(v) + xb := AsInt32(other) + if xb == 0 { + return NewNullValue(), nil + } + + return NewIntegerValue(xa % xb), nil + case TypeBigint: xa := int64(v) xb := AsInt64(other) if xb == 0 { return NewNullValue(), nil } - return NewIntegerValue(xa % xb), nil + return NewBigintValue(xa % xb), nil case TypeDouble: xa := float64(AsInt64(v)) xb := AsFloat64(other) @@ -225,10 +330,12 @@ func (v IntegerValue) Mod(other Numeric) (Value, error) { func (v IntegerValue) BitwiseAnd(other Numeric) (Value, error) { switch other.Type() { case TypeInteger: - return NewIntegerValue(int64(v) & AsInt64(other)), nil + return NewIntegerValue(int32(v) & AsInt32(other)), nil + case TypeBigint: + return NewBigintValue(int64(v) & AsInt64(other)), nil case TypeDouble: - xa := int64(v) - xb := int64(AsFloat64(other)) + xa := int32(v) + xb := int32(AsFloat64(other)) return NewIntegerValue(xa & xb), nil } @@ -238,10 +345,12 @@ func (v IntegerValue) BitwiseAnd(other Numeric) (Value, error) { func (v IntegerValue) BitwiseOr(other Numeric) (Value, error) { switch other.Type() { case TypeInteger: - return NewIntegerValue(int64(v) | AsInt64(other)), nil + return NewIntegerValue(int32(v) | AsInt32(other)), nil + case TypeBigint: + return NewBigintValue(int64(v) | AsInt64(other)), nil case TypeDouble: - xa := int64(v) - xb := int64(AsFloat64(other)) + xa := int32(v) + xb := int32(AsFloat64(other)) return NewIntegerValue(xa | xb), nil } @@ -251,10 +360,12 @@ func (v IntegerValue) BitwiseOr(other Numeric) (Value, error) { func (v IntegerValue) BitwiseXor(other Numeric) (Value, error) { switch other.Type() { case TypeInteger: - return NewIntegerValue(int64(v) ^ AsInt64(other)), nil + return NewIntegerValue(int32(v) ^ AsInt32(other)), nil + case TypeBigint: + return NewBigintValue(int64(v) ^ AsInt64(other)), nil case TypeDouble: - xa := int64(v) - xb := int64(AsFloat64(other)) + xa := int32(v) + xb := int32(AsFloat64(other)) return NewIntegerValue(xa ^ xb), nil } diff --git a/internal/types/null.go b/internal/types/null.go index 97f616383..598837752 100644 --- a/internal/types/null.go +++ b/internal/types/null.go @@ -1,10 +1,43 @@ package types +import ( + "github.com/chaisql/chai/internal/encoding" + "github.com/cockroachdb/errors" +) + +var _ TypeDefinition = NullTypeDef{} + +type NullTypeDef struct{} + +func (NullTypeDef) 
New(v any) Value { + return NewNullValue() +} + +func (NullTypeDef) Type() Type { + return TypeNull +} + +func (NullTypeDef) Decode(src []byte) (Value, int) { + if src[0] != encoding.NullValue && src[0] != encoding.DESC_NullValue { + panic(errors.New("invalid encoded null value")) + } + + return NewNullValue(), 1 +} + +func (NullTypeDef) IsComparableWith(other Type) bool { + return other == TypeNull +} + +func (NullTypeDef) IsIndexComparableWith(other Type) bool { + return other == TypeNull +} + var _ Value = NewNullValue() type NullValue struct{} -// NewNullValue returns a SQL BOOLEAN value. +// NewNullValue returns a SQL NULL value. func NewNullValue() NullValue { return NullValue{} } @@ -17,6 +50,10 @@ func (v NullValue) Type() Type { return TypeNull } +func (v NullValue) TypeDef() TypeDefinition { + return NullTypeDef{} +} + func (v NullValue) IsZero() (bool, error) { return false, nil } @@ -33,6 +70,18 @@ func (v NullValue) MarshalJSON() ([]byte, error) { return []byte("null"), nil } +func (v NullValue) Encode(dst []byte) ([]byte, error) { + return encoding.EncodeNull(dst), nil +} + +func (v NullValue) EncodeAsKey(dst []byte) ([]byte, error) { + return v.Encode(dst) +} + +func (v NullValue) CastAs(target Type) (Value, error) { + return v, nil +} + func (v NullValue) EQ(other Value) (bool, error) { return other.Type() == TypeNull, nil } diff --git a/internal/types/numeric.go b/internal/types/numeric.go index 56a3fff3d..a4ec1519b 100644 --- a/internal/types/numeric.go +++ b/internal/types/numeric.go @@ -20,6 +20,9 @@ type Numeric interface { // Only numeric values and booleans can be calculated together. // If both v and u are integers, the result will be an integer. Mod(other Numeric) (Value, error) +} + +type Integral interface { // BitwiseAnd calculates v & u and returns the result. // Only numeric values and booleans can be calculated together. // If both v and u are integers, the result will be an integer. @@ -33,3 +36,45 @@ type Numeric interface { // If both v and u are integers, the result will be an integer. BitwiseXor(other Numeric) (Value, error) } + +func isMulOverflow[T int32 | int64](left, right, min, max T) bool { + if right > 0 { + if left > max/right { + return true + } + } else { + if left < min/right { + return true + } + } + + return false +} + +func isAddOverflow[T int32 | int64](left, right, min, max T) bool { + if right > 0 { + if left > max-right { + return true + } + } else { + if left < min-right { + return true + } + } + + return false +} + +func isSubOverflow[T int32 | int64](left, right, min, max T) bool { + if right > 0 { + if left < min+right { + return true + } + } else { + if left > max+right { + return true + } + } + + return false +} diff --git a/internal/types/object.go b/internal/types/object.go deleted file mode 100644 index d4b19bd0c..000000000 --- a/internal/types/object.go +++ /dev/null @@ -1,134 +0,0 @@ -package types - -import ( - "sort" - - "github.com/cockroachdb/errors" -) - -var _ Value = NewObjectValue(nil) - -type ObjectValue struct { - o Object -} - -// NewObjectValue returns a SQL INTEGER value. -func NewObjectValue(x Object) *ObjectValue { - return &ObjectValue{ - o: x, - } -} - -func (o *ObjectValue) V() any { - return o.o -} - -func (o *ObjectValue) Type() Type { - return TypeObject -} - -func (v *ObjectValue) IsZero() (bool, error) { - err := v.o.Iterate(func(_ string, _ Value) error { - // We return an error in the first iteration to stop it. 
- return errors.WithStack(errStop) - }) - if err == nil { - // If err is nil, it means that we didn't iterate, - // thus the object is empty. - return true, nil - } - if errors.Is(err, errStop) { - // If err is errStop, it means that we iterate - // at least once, thus the object is not empty. - return false, nil - } - // An unexpecting error occurs, let's return it! - return false, err -} - -func (o *ObjectValue) String() string { - data, _ := o.MarshalText() - return string(data) -} - -func (o *ObjectValue) MarshalText() ([]byte, error) { - return MarshalTextIndent(o, "", "") -} - -func (o *ObjectValue) MarshalJSON() ([]byte, error) { - return jsonObject{Object: o.o}.MarshalJSON() -} - -func (v *ObjectValue) EQ(other Value) (bool, error) { - t := other.Type() - if t != TypeObject { - return false, nil - } - - return compareObjects(operatorEq, v.o, AsObject(other)) -} - -func (v *ObjectValue) GT(other Value) (bool, error) { - t := other.Type() - if t != TypeObject { - return false, nil - } - - return compareObjects(operatorGt, v.o, AsObject(other)) -} - -func (v *ObjectValue) GTE(other Value) (bool, error) { - t := other.Type() - if t != TypeObject { - return false, nil - } - - return compareObjects(operatorGte, v.o, AsObject(other)) -} - -func (v *ObjectValue) LT(other Value) (bool, error) { - t := other.Type() - if t != TypeObject { - return false, nil - } - - return compareObjects(operatorLt, v.o, AsObject(other)) -} - -func (v *ObjectValue) LTE(other Value) (bool, error) { - t := other.Type() - if t != TypeObject { - return false, nil - } - - return compareObjects(operatorLte, v.o, AsObject(other)) -} - -func (v *ObjectValue) Between(a, b Value) (bool, error) { - if a.Type() != TypeObject || b.Type() != TypeObject { - return false, nil - } - - ok, err := a.LTE(v) - if err != nil || !ok { - return false, err - } - - return b.GTE(v) -} - -// Fields returns a list of all the fields at the root of the object -// sorted lexicographically. 
-func Fields(o Object) ([]string, error) { - var fields []string - err := o.Iterate(func(f string, _ Value) error { - fields = append(fields, f) - return nil - }) - if err != nil { - return nil, err - } - - sort.Strings(fields) - return fields, nil -} diff --git a/internal/types/text.go b/internal/types/text.go index cf293468f..008ac6351 100644 --- a/internal/types/text.go +++ b/internal/types/text.go @@ -1,10 +1,40 @@ package types import ( + "encoding/base64" + "fmt" "strconv" "strings" + + "github.com/chaisql/chai/internal/encoding" + "github.com/cockroachdb/errors" ) +var _ TypeDefinition = TextTypeDef{} + +type TextTypeDef struct{} + +func (TextTypeDef) New(v any) Value { + return NewTextValue(v.(string)) +} + +func (TextTypeDef) Type() Type { + return TypeText +} + +func (TextTypeDef) Decode(src []byte) (Value, int) { + x, n := encoding.DecodeText(src) + return NewTextValue(x), n +} + +func (TextTypeDef) IsComparableWith(other Type) bool { + return other == TypeNull || other == TypeText || other == TypeBoolean || other == TypeInteger || other == TypeBigint || other == TypeDouble || other == TypeTimestamp || other == TypeBlob +} + +func (t TextTypeDef) IsIndexComparableWith(other Type) bool { + return t.IsComparableWith(other) +} + var _ Value = NewTextValue("") type TextValue string @@ -22,6 +52,10 @@ func (v TextValue) Type() Type { return TypeText } +func (v TextValue) TypeDef() TypeDefinition { + return TextTypeDef{} +} + func (v TextValue) IsZero() (bool, error) { return v == "", nil } @@ -38,6 +72,71 @@ func (v TextValue) MarshalJSON() ([]byte, error) { return v.MarshalText() } +func (v TextValue) Encode(dst []byte) ([]byte, error) { + return encoding.EncodeText(dst, string(v)), nil +} + +func (v TextValue) EncodeAsKey(dst []byte) ([]byte, error) { + return v.Encode(dst) +} + +func (v TextValue) CastAs(target Type) (Value, error) { + switch target { + case TypeText: + return v, nil + case TypeBoolean: + b, err := strconv.ParseBool(string(v)) + if err != nil { + return nil, errors.Errorf(`cannot cast %q as bool: %w`, v.V(), err) + } + return NewBooleanValue(b), nil + case TypeInteger: + i, err := strconv.ParseInt(string(v), 10, 32) + if err != nil { + intErr := err + f, err := strconv.ParseFloat(string(v), 64) + if err != nil { + return nil, errors.Errorf(`cannot cast %q as integer: %w`, v.V(), intErr) + } + i = int64(f) + } + return NewIntegerValue(int32(i)), nil + case TypeBigint: + i, err := strconv.ParseInt(string(v), 10, 64) + if err != nil { + intErr := err + f, err := strconv.ParseFloat(string(v), 64) + if err != nil { + return nil, fmt.Errorf(`cannot cast %q as bigint: %w`, v.V(), intErr) + } + i = int64(f) + } + return NewBigintValue(i), nil + case TypeDouble: + f, err := strconv.ParseFloat(string(v), 64) + if err != nil { + return nil, fmt.Errorf(`cannot cast %q as double: %w`, v.V(), err) + } + return NewDoubleValue(f), nil + case TypeTimestamp: + t, err := ParseTimestamp(string(v)) + if err != nil { + return nil, fmt.Errorf(`cannot cast %q as timestamp: %w`, v.V(), err) + } + return NewTimestampValue(t), nil + case TypeBlob: + s := string(v) + b, err := base64.StdEncoding.DecodeString(s) + if err != nil { + return nil, err + } + + return NewBlobValue(b), nil + } + + return nil, errors.Errorf("cannot cast %s as %s", v.Type(), target) +} + func (v TextValue) EQ(other Value) (bool, error) { t := other.Type() switch t { diff --git a/internal/types/timestamp.go b/internal/types/timestamp.go index 950a71d53..5ed2d1076 100644 --- a/internal/types/timestamp.go +++ 
b/internal/types/timestamp.go @@ -5,10 +5,36 @@ import ( "strconv" "time" + "github.com/chaisql/chai/internal/encoding" "github.com/cockroachdb/errors" "github.com/golang-module/carbon/v2" ) +var _ TypeDefinition = TimestampTypeDef{} + +type TimestampTypeDef struct{} + +func (TimestampTypeDef) New(v any) Value { + return NewTimestampValue(v.(time.Time)) +} + +func (TimestampTypeDef) Type() Type { + return TypeTimestamp +} + +func (t TimestampTypeDef) Decode(src []byte) (Value, int) { + ts, n := encoding.DecodeTimestamp(src) + return NewTimestampValue(ts), n +} + +func (TimestampTypeDef) IsComparableWith(other Type) bool { + return other == TypeTimestamp || other == TypeText +} + +func (TimestampTypeDef) IsIndexComparableWith(other Type) bool { + return other == TypeTimestamp +} + var _ Value = NewTimestampValue(time.Time{}) var ( @@ -32,6 +58,10 @@ func (v TimestampValue) Type() Type { return TypeTimestamp } +func (v TimestampValue) TypeDef() TypeDefinition { + return TimestampTypeDef{} +} + func (v TimestampValue) IsZero() (bool, error) { return time.Time(v).IsZero(), nil } @@ -48,6 +78,25 @@ func (v TimestampValue) MarshalJSON() ([]byte, error) { return v.MarshalText() } +func (v TimestampValue) Encode(dst []byte) ([]byte, error) { + return encoding.EncodeTimestamp(dst, time.Time(v)), nil +} + +func (v TimestampValue) EncodeAsKey(dst []byte) ([]byte, error) { + return v.Encode(dst) +} + +func (v TimestampValue) CastAs(target Type) (Value, error) { + switch target { + case TypeTimestamp: + return v, nil + case TypeText: + return NewTextValue(v.String()), nil + } + + return nil, errors.Errorf("cannot cast %s as %s", v.Type(), target) +} + func (v TimestampValue) EQ(other Value) (bool, error) { t := other.Type() switch t { diff --git a/internal/types/types.go b/internal/types/types.go index 695b51f15..536cb3803 100644 --- a/internal/types/types.go +++ b/internal/types/types.go @@ -1,18 +1,16 @@ package types import ( + "fmt" + + "github.com/chaisql/chai/internal/encoding" "github.com/cockroachdb/errors" ) var ( - // ErrFieldNotFound must be returned by object implementations, when calling the GetByField method and - // the field wasn't found in the object. - ErrFieldNotFound = errors.New("field not found") - // ErrValueNotFound must be returned by Array implementations, when calling the GetByIndex method and - // the index wasn't found in the array. - ErrValueNotFound = errors.New("value not found") - - errStop = errors.New("stop") + // ErrColumnNotFound must be returned by row implementations, when calling the Get method and + // the column doesn't exist. + ErrColumnNotFound = errors.New("column not found") ) // Type represents a type supported by the database. 
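
Side note on the types.go hunk below: the scalar type list drops ARRAY/OBJECT, adds BIGINT, and wires every Type to its TypeDefinition through Def(). A minimal sketch of how these pieces fit together, assuming access to the internal/types API as introduced in this patch (hypothetical standalone usage, not code from the patch; the package is internal to the module, so this is illustrative only):

    package main

    import (
        "fmt"

        "github.com/chaisql/chai/internal/types"
    )

    func main() {
        // INTEGER is now a 32-bit value; BIGINT carries 64-bit values.
        i := types.NewIntegerValue(42)

        // Every value can encode itself.
        buf, err := i.Encode(nil)
        if err != nil {
            panic(err)
        }

        // Type.Def() returns the TypeDefinition, which knows how to decode
        // the bytes produced by Encode.
        def := types.TypeInteger.Def()
        v, _ := def.Decode(buf)
        fmt.Println(v.Type(), v.V()) // integer 42

        // Comparability is now declared per type: INTEGER compares with
        // DOUBLE, but is only index-comparable with INTEGER and BIGINT.
        fmt.Println(def.IsComparableWith(types.TypeDouble))      // true
        fmt.Println(def.IsIndexComparableWith(types.TypeDouble)) // false
    }
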
@@ -25,14 +23,36 @@ const ( TypeNull TypeBoolean TypeInteger + TypeBigint TypeDouble TypeTimestamp TypeText TypeBlob - TypeArray - TypeObject ) +func (t Type) Def() TypeDefinition { + switch t { + case TypeNull: + return NullTypeDef{} + case TypeBoolean: + return BooleanTypeDef{} + case TypeInteger: + return IntegerTypeDef{} + case TypeBigint: + return BigintTypeDef{} + case TypeDouble: + return DoubleTypeDef{} + case TypeTimestamp: + return TimestampTypeDef{} + case TypeText: + return TextTypeDef{} + case TypeBlob: + return BlobTypeDef{} + } + + return nil +} + func (t Type) String() string { switch t { case TypeNull: @@ -41,6 +61,8 @@ func (t Type) String() string { return "boolean" case TypeInteger: return "integer" + case TypeBigint: + return "bigint" case TypeDouble: return "double" case TypeTimestamp: @@ -49,18 +71,108 @@ func (t Type) String() string { return "blob" case TypeText: return "text" - case TypeArray: - return "array" - case TypeObject: - return "object" } - return "any" + panic(fmt.Sprintf("unsupported type %#v", t)) +} + +func (t Type) MinEnctype() byte { + switch t { + case TypeNull: + return encoding.NullValue + case TypeBoolean: + return encoding.FalseValue + case TypeInteger: + return encoding.Int32Value + case TypeBigint: + return encoding.Int64Value + case TypeDouble: + return encoding.Float64Value + case TypeTimestamp: + return encoding.Int64Value + case TypeText: + return encoding.TextValue + case TypeBlob: + return encoding.BlobValue + default: + panic(fmt.Sprintf("unsupported type %v", t)) + } +} + +func (t Type) MinEnctypeDesc() byte { + switch t { + case TypeNull: + return encoding.DESC_NullValue + case TypeBoolean: + return encoding.DESC_TrueValue + case TypeInteger: + return encoding.DESC_Uint32Value + case TypeBigint: + return encoding.DESC_Uint64Value + case TypeDouble: + return encoding.DESC_Float64Value + case TypeTimestamp: + return encoding.DESC_Uint64Value + case TypeText: + return encoding.DESC_TextValue + case TypeBlob: + return encoding.DESC_BlobValue + default: + panic(fmt.Sprintf("unsupported type %v", t)) + } +} + +func (t Type) MaxEnctype() byte { + switch t { + case TypeNull: + return encoding.NullValue + 1 + case TypeBoolean: + return encoding.TrueValue + 1 + case TypeInteger: + return encoding.Uint32Value + 1 + case TypeBigint: + return encoding.Uint64Value + 1 + case TypeDouble: + return encoding.Float64Value + 1 + case TypeTimestamp: + return encoding.Uint64Value + 1 + case TypeText: + return encoding.TextValue + 1 + case TypeBlob: + return encoding.BlobValue + 1 + default: + panic(fmt.Sprintf("unsupported type %v", t)) + } +} + +func (t Type) MaxEnctypeDesc() byte { + switch t { + case TypeNull: + return encoding.DESC_NullValue + 1 + case TypeBoolean: + return encoding.DESC_FalseValue + 1 + case TypeInteger: + return encoding.DESC_Int64Value + 1 + case TypeDouble: + return encoding.DESC_Float64Value + 1 + case TypeTimestamp: + return encoding.DESC_Int64Value + 1 + case TypeText: + return encoding.DESC_TextValue + 1 + case TypeBlob: + return encoding.DESC_BlobValue + 1 + default: + panic(fmt.Sprintf("unsupported type %v", t)) + } } // IsNumber returns true if t is either an integer or a float. func (t Type) IsNumber() bool { - return t == TypeInteger || t == TypeDouble + return t == TypeInteger || t == TypeBigint || t == TypeDouble +} + +func (t Type) IsInteger() bool { + return t == TypeInteger || t == TypeBigint } // IsTimestampCompatible returns true if t is either a timestamp, an integer, or a text. 
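
Side note on the integer.go and numeric.go hunks above: INTEGER/BIGINT arithmetic no longer falls back to DOUBLE on overflow; Add, Sub and Mul now return an out-of-range error, using the generic isAddOverflow/isSubOverflow/isMulOverflow helpers. A self-contained sketch of the same addition check with hypothetical names (addChecked is not part of the patch):

    package main

    import (
        "errors"
        "fmt"
        "math"
    )

    // addChecked mirrors the isAddOverflow test used by IntegerValue.Add:
    // adding right to left overflows when left already sits past the bound
    // that right would push it over.
    func addChecked(left, right int32) (int32, error) {
        if right > 0 {
            if left > math.MaxInt32-right {
                return 0, errors.New("integer out of range")
            }
        } else {
            if left < math.MinInt32-right {
                return 0, errors.New("integer out of range")
            }
        }
        return left + right, nil
    }

    func main() {
        if _, err := addChecked(math.MaxInt32, 1); err != nil {
            fmt.Println(err) // integer out of range
        }
        sum, _ := addChecked(40, 2)
        fmt.Println(sum) // 42
    }
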
@@ -98,31 +210,25 @@ type Value interface { String() string MarshalJSON() ([]byte, error) MarshalText() ([]byte, error) + TypeDef() TypeDefinition + Encode(dst []byte) ([]byte, error) + EncodeAsKey(dst []byte) ([]byte, error) + CastAs(t Type) (Value, error) } -// A Object represents a group of key value pairs. -type Object interface { - // Iterate goes through all the fields of the object and calls the given function by passing each one of them. - // If the given function returns an error, the iteration stops. - Iterate(fn func(field string, value Value) error) error - // GetByField returns a value by field name. - // Must return ErrFieldNotFound if the field doesn't exist. - GetByField(field string) (Value, error) - - // MarshalJSON implements the json.Marshaler interface. - // It returns a JSON representation of the object. - MarshalJSON() ([]byte, error) +type TypeDefinition interface { + New(v any) Value + Type() Type + Decode(src []byte) (Value, int) + IsComparableWith(other Type) bool + IsIndexComparableWith(other Type) bool } -// An Array contains a set of values. -type Array interface { - // Iterate goes through all the values of the array and calls the given function by passing each one of them. - // If the given function returns an error, the iteration stops. - Iterate(fn func(i int, value Value) error) error - // GetByIndex returns a value by index of the array. - GetByIndex(i int) (Value, error) - - // MarshalJSON implements the json.Marshaler interface. - // It returns a JSON representation of the array. - MarshalJSON() ([]byte, error) +type Comparable interface { + EQ(other Value) (bool, error) + GT(other Value) (bool, error) + GTE(other Value) (bool, error) + LT(other Value) (bool, error) + LTE(other Value) (bool, error) + Between(a, b Value) (bool, error) } diff --git a/internal/types/value.go b/internal/types/value.go index 5b84d3e17..89f76b61e 100644 --- a/internal/types/value.go +++ b/internal/types/value.go @@ -1,33 +1,43 @@ package types import ( - "bytes" - "encoding/hex" "fmt" "math" - "strconv" - "strings" "time" - - "github.com/chaisql/chai/internal/stringutil" ) func AsBool(v Value) bool { - bv, ok := v.(BooleanValue) - if !ok { - return v.V().(bool) + return v.V().(bool) +} + +func AsInt32(v Value) int32 { + iv, ok := v.(IntegerValue) + if ok { + return int32(iv) } - return bool(bv) + if bv, ok := v.(BigintValue); ok { + if bv < math.MinInt32 || bv > math.MaxInt32 { + panic(fmt.Errorf("value %d out of range for int32", bv)) + } + return int32(bv) + } + + return v.V().(int32) } func AsInt64(v Value) int64 { + biv, ok := v.(BigintValue) + if ok { + return int64(biv) + } + iv, ok := v.(IntegerValue) - if !ok { - return v.V().(int64) + if ok { + return int64(iv) } - return int64(iv) + return v.V().(int64) } func AsFloat64(v Value) float64 { @@ -66,29 +76,6 @@ func AsByteSlice(v Value) []byte { return bv } -func AsArray(v Value) Array { - av, ok := v.(*ArrayValue) - if !ok { - return v.V().(Array) - } - - return av.a -} - -func AsObject(v Value) Object { - ov, ok := v.(*ObjectValue) - if !ok { - return v.V().(Object) - } - - return ov.o -} - -func Is[T any](v Value) (T, bool) { - x, ok := v.V().(T) - return x, ok -} - func IsNull(v Value) bool { return v == nil || v.Type() == TypeNull } @@ -102,190 +89,3 @@ func IsTruthy(v Value) (bool, error) { b, err := v.IsZero() return !b, err } - -func MarshalTextIndent(v Value, prefix, indent string) ([]byte, error) { - var buf bytes.Buffer - - err := marshalText(&buf, v, prefix, indent, 0) - if err != nil { - return nil, err - 
} - - return buf.Bytes(), nil -} - -func marshalText(dst *bytes.Buffer, v Value, prefix, indent string, depth int) error { - if v.V() == nil { - dst.WriteString("NULL") - return nil - } - - switch v.Type() { - case TypeNull: - dst.WriteString("NULL") - return nil - case TypeBoolean: - dst.WriteString(strconv.FormatBool(AsBool(v))) - return nil - case TypeInteger: - dst.WriteString(strconv.FormatInt(AsInt64(v), 10)) - return nil - case TypeDouble: - f := AsFloat64(v) - abs := math.Abs(f) - fmt := byte('f') - if abs != 0 { - if abs < 1e-6 || abs >= 1e15 { - fmt = 'e' - } - } - - // By default the precision is -1 to use the smallest number of digits. - // See https://pkg.go.dev/strconv#FormatFloat - prec := -1 - // if the number is round, add .0 - if float64(int64(f)) == f { - prec = 1 - } - dst.WriteString(strconv.FormatFloat(AsFloat64(v), fmt, prec, 64)) - return nil - case TypeTimestamp: - dst.WriteString(strconv.Quote(AsTime(v).Format(time.RFC3339Nano))) - return nil - case TypeText: - dst.WriteString(strconv.Quote(AsString(v))) - return nil - case TypeBlob: - src := AsByteSlice(v) - dst.WriteString("\"\\x") - hex.NewEncoder(dst).Write(src) - dst.WriteByte('"') - return nil - case TypeArray: - var nonempty bool - dst.WriteByte('[') - err := AsArray(v).Iterate(func(i int, value Value) error { - nonempty = true - if i > 0 { - dst.WriteByte(',') - if prefix == "" { - dst.WriteByte(' ') - } - } - newline(dst, prefix, indent, depth+1) - - return marshalText(dst, value, prefix, indent, depth+1) - }) - if err != nil { - return err - } - if nonempty && prefix != "" { - newline(dst, prefix, indent, depth) - } - dst.WriteByte(']') - return nil - case TypeObject: - dst.WriteByte('{') - var i int - err := AsObject(v).Iterate(func(field string, value Value) error { - if i > 0 { - dst.WriteByte(',') - if prefix == "" { - dst.WriteByte(' ') - } - } - newline(dst, prefix, indent, depth+1) - i++ - - var ident string - if strings.HasPrefix(field, "\"") { - ident = stringutil.NormalizeIdentifier(field, '`') - } else { - ident = stringutil.NormalizeIdentifier(field, '"') - } - dst.WriteString(ident) - dst.WriteString(": ") - - return marshalText(dst, value, prefix, indent, depth+1) - }) - if err != nil { - return err - } - newline(dst, prefix, indent, depth) - dst.WriteRune('}') - return nil - default: - return fmt.Errorf("unexpected type: %d", v.Type()) - } -} - -func newline(dst *bytes.Buffer, prefix, indent string, depth int) { - dst.WriteString(prefix) - for i := 0; i < depth; i++ { - dst.WriteString(indent) - } -} - -type jsonArray struct { - Array -} - -func (j jsonArray) MarshalJSON() ([]byte, error) { - var buf bytes.Buffer - - buf.WriteRune('[') - err := j.Array.Iterate(func(i int, v Value) error { - if i > 0 { - buf.WriteString(", ") - } - - data, err := v.MarshalJSON() - if err != nil { - return err - } - - _, err = buf.Write(data) - return err - }) - if err != nil { - return nil, err - } - buf.WriteRune(']') - - return buf.Bytes(), nil -} - -type jsonObject struct { - Object -} - -func (j jsonObject) MarshalJSON() ([]byte, error) { - var buf bytes.Buffer - - buf.WriteByte('{') - - var notFirst bool - err := j.Object.Iterate(func(f string, v Value) error { - if notFirst { - buf.WriteString(", ") - } - notFirst = true - - buf.WriteString(strconv.Quote(f)) - buf.WriteString(": ") - - data, err := v.MarshalJSON() - if err != nil { - return err - } - _, err = buf.Write(data) - return err - }) - if err != nil { - return nil, err - } - - buf.WriteByte('}') - - return buf.Bytes(), nil -} diff --git 
a/internal/types/value_test.go b/internal/types/value_test.go index 747f91900..8a68dfd4e 100644 --- a/internal/types/value_test.go +++ b/internal/types/value_test.go @@ -6,7 +6,7 @@ import ( "time" "github.com/chaisql/chai/internal/environment" - "github.com/chaisql/chai/internal/object" + "github.com/chaisql/chai/internal/row" "github.com/chaisql/chai/internal/testutil" "github.com/chaisql/chai/internal/testutil/assert" "github.com/chaisql/chai/internal/types" @@ -24,25 +24,18 @@ func TestValueMarshalText(t *testing.T) { {"bytes", []byte("bar"), `"\x626172"`}, {"string", "bar", `"bar"`}, {"bool", true, "true"}, - {"int", int64(10), "10"}, + {"int", int32(10), "10"}, {"float64", 10.0, "10.0"}, {"float64", 10.1, "10.1"}, {"float64", math.MaxFloat64, "1.7976931348623157e+308"}, {"time", now, `"` + now.UTC().Format(time.RFC3339Nano) + `"`}, {"null", nil, "NULL"}, - {"object", object.NewFieldBuffer(). - Add("a", types.NewIntegerValue(10)). - Add("b c", types.NewTextValue("foo")). - Add(`"d e"`, types.NewTextValue("foo")), - "{a: 10, \"b c\": \"foo\", `\"d e\"`: \"foo\"}", - }, - {"array", object.NewValueBuffer(types.NewIntegerValue(10), types.NewTextValue("foo")), `[10, "foo"]`}, {"time", now, `"` + now.UTC().Format(time.RFC3339Nano) + `"`}, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - v, err := object.NewValue(test.value) + v, err := row.NewValue(test.value) assert.NoError(t, err) data, err := v.MarshalText() assert.NoError(t, err) @@ -57,60 +50,6 @@ func TestValueMarshalText(t *testing.T) { } } -func TestMarshalTextIndent(t *testing.T) { - now := time.Now() - - tests := []struct { - name string - value interface{} - expected string - }{ - {"bytes", []byte("bar"), `"\x626172"`}, - {"string", "bar", `"bar"`}, - {"bool", true, "true"}, - {"int", int64(10), "10"}, - {"float64", 10.0, "10.0"}, - {"float64", 10.1, "10.1"}, - {"time", now, `"` + now.UTC().Format(time.RFC3339Nano) + `"`}, - {"float64", math.MaxFloat64, "1.7976931348623157e+308"}, - {"null", nil, "NULL"}, - {"object", - object.NewFieldBuffer().Add("a", types.NewIntegerValue(10)).Add("b c", types.NewTextValue("foo")).Add("d", types.NewArrayValue(object.NewValueBuffer(types.NewIntegerValue(10), types.NewTextValue("foo")))), - `{ - a: 10, - "b c": "foo", - d: [ - 10, - "foo" - ] -}`}, - {"array", - object.NewValueBuffer(types.NewIntegerValue(10), types.NewTextValue("foo")), - `[ - 10, - "foo" -]`, - }, - {"time", now, `"` + now.UTC().Format(time.RFC3339Nano) + `"`}, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - v, err := object.NewValue(test.value) - assert.NoError(t, err) - data, err := types.MarshalTextIndent(v, "\n", " ") - assert.NoError(t, err) - require.Equal(t, test.expected, string(data)) - if test.name != "time" { - e := testutil.ParseExpr(t, string(data)) - got, err := e.Eval(&environment.Environment{}) - assert.NoError(t, err) - require.Equal(t, test.value, got.V()) - } - }) - } -} - func TestValueMarshalJSON(t *testing.T) { now := time.Now() tests := []struct { @@ -127,8 +66,6 @@ func TestValueMarshalJSON(t *testing.T) { {"time", types.NewTimestampValue(now), `"` + now.UTC().Format(time.RFC3339Nano) + `"`}, {"double with no decimal", types.NewDoubleValue(10), "10"}, {"big double", types.NewDoubleValue(1e15), "1e+15"}, - {"object", types.NewObjectValue(object.NewFieldBuffer().Add("a", types.NewIntegerValue(10))), "{\"a\": 10}"}, - {"array", types.NewArrayValue(object.NewValueBuffer(types.NewIntegerValue(10))), "[10]"}, } for _, test := range tests {