From 2aa476d2f92d5599543e5db4e8f6aa14d3cb9369 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=F0=9F=8D=81KoNekoD?= Date: Mon, 2 Dec 2024 00:04:08 +0300 Subject: [PATCH] [patch-1] chore: use cast-safe option --- named_args.go | 15 ++++- named_args_test.go | 152 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 166 insertions(+), 1 deletion(-) diff --git a/named_args.go b/named_args.go index f2463034a..e837c6291 100644 --- a/named_args.go +++ b/named_args.go @@ -107,7 +107,20 @@ func rawState(l *sqlLexer) stateFn { return singleQuoteState case '"': return doubleQuoteState - case '@', ':': + case ':': + nextRune, _ := utf8.DecodeRuneInString(l.src[l.pos:]) + prevRune := rune(0) + if l.pos > 1 { + prevRune, _ = utf8.DecodeRuneInString(l.src[l.pos-2:]) + } + if nextRune != ':' && prevRune != ':' && (isLetter(nextRune) || nextRune == '_') { + if l.pos-l.start > 0 { + l.parts = append(l.parts, l.src[l.start:l.pos-width]) + } + l.start = l.pos + return namedArgState + } + case '@': nextRune, _ := utf8.DecodeRuneInString(l.src[l.pos:]) if isLetter(nextRune) || nextRune == '_' { if l.pos-l.start > 0 { diff --git a/named_args_test.go b/named_args_test.go index 8cab2f4d2..c742e8233 100644 --- a/named_args_test.go +++ b/named_args_test.go @@ -160,3 +160,155 @@ func TestStrictNamedArgsRewriteQuery(t *testing.T) { } } } + +func TestNamedArgsRewriteQuery2(t *testing.T) { + t.Parallel() + + for i, tt := range []struct { + sql string + args []any + namedArgs pgx.NamedArgs + expectedSQL string + expectedArgs []any + }{ + { + sql: "select * from users where id = :id", + namedArgs: pgx.NamedArgs{"id": int32(42)}, + expectedSQL: "select * from users where id = $1", + expectedArgs: []any{int32(42)}, + }, + { + sql: "select * from t where foo < :abc and baz = :def and bar < :abc", + namedArgs: pgx.NamedArgs{"abc": int32(42), "def": int32(1)}, + expectedSQL: "select * from t where foo < $1 and baz = $2 and bar < $1", + expectedArgs: []any{int32(42), int32(1)}, + }, + { + sql: "select :a::int, :b::text", + namedArgs: pgx.NamedArgs{"a": int32(42), "b": "foo"}, + expectedSQL: "select $1::int, $2::text", + expectedArgs: []any{int32(42), "foo"}, + }, + { + sql: "select :Abc::int, :b_4::text, :_c::int", + namedArgs: pgx.NamedArgs{"Abc": int32(42), "b_4": "foo", "_c": int32(1)}, + expectedSQL: "select $1::int, $2::text, $3::int", + expectedArgs: []any{int32(42), "foo", int32(1)}, + }, + { + sql: "at end :", + namedArgs: pgx.NamedArgs{"a": int32(42), "b": "foo"}, + expectedSQL: "at end :", + expectedArgs: []any{}, + }, + { + sql: "ignores without valid character after : foo bar", + namedArgs: pgx.NamedArgs{"a": int32(42), "b": "foo"}, + expectedSQL: "ignores without valid character after : foo bar", + expectedArgs: []any{}, + }, + { + sql: "name cannot start with number :1 foo bar", + namedArgs: pgx.NamedArgs{"a": int32(42), "b": "foo"}, + expectedSQL: "name cannot start with number :1 foo bar", + expectedArgs: []any{}, + }, + { + sql: `select *, ':foo' as ":bar" from users where id = :id`, + namedArgs: pgx.NamedArgs{"id": int32(42)}, + expectedSQL: `select *, ':foo' as ":bar" from users where id = $1`, + expectedArgs: []any{int32(42)}, + }, + { + sql: `select * -- :foo + from users -- :single line comments + where id = :id;`, + namedArgs: pgx.NamedArgs{"id": int32(42)}, + expectedSQL: `select * -- :foo + from users -- :single line comments + where id = $1;`, + expectedArgs: []any{int32(42)}, + }, + { + sql: `select * /* :multi line + :comment + */ + /* /* with :nesting */ */ + from users + where id = :id;`, + namedArgs: pgx.NamedArgs{"id": int32(42)}, + expectedSQL: `select * /* :multi line + :comment + */ + /* /* with :nesting */ */ + from users + where id = $1;`, + expectedArgs: []any{int32(42)}, + }, + { + sql: "extra provided argument", + namedArgs: pgx.NamedArgs{"extra": int32(1)}, + expectedSQL: "extra provided argument", + expectedArgs: []any{}, + }, + { + sql: ":missing argument", + namedArgs: pgx.NamedArgs{}, + expectedSQL: "$1 argument", + expectedArgs: []any{nil}, + }, + + // test comments and quotes + } { + sql, args, err := tt.namedArgs.RewriteQuery(context.Background(), nil, tt.sql, tt.args) + require.NoError(t, err) + assert.Equalf(t, tt.expectedSQL, sql, "%d", i) + assert.Equalf(t, tt.expectedArgs, args, "%d", i) + } +} + +func TestStrictNamedArgsRewriteQuery2(t *testing.T) { + t.Parallel() + + for i, tt := range []struct { + sql string + namedArgs pgx.StrictNamedArgs + expectedSQL string + expectedArgs []any + isExpectedError bool + }{ + { + sql: "no arguments", + namedArgs: pgx.StrictNamedArgs{}, + expectedSQL: "no arguments", + expectedArgs: []any{}, + isExpectedError: false, + }, + { + sql: ":all :matches", + namedArgs: pgx.StrictNamedArgs{"all": int32(1), "matches": int32(2)}, + expectedSQL: "$1 $2", + expectedArgs: []any{int32(1), int32(2)}, + isExpectedError: false, + }, + { + sql: "extra provided argument", + namedArgs: pgx.StrictNamedArgs{"extra": int32(1)}, + isExpectedError: true, + }, + { + sql: ":missing argument", + namedArgs: pgx.StrictNamedArgs{}, + isExpectedError: true, + }, + } { + sql, args, err := tt.namedArgs.RewriteQuery(context.Background(), nil, tt.sql, nil) + if tt.isExpectedError { + assert.Errorf(t, err, "%d", i) + } else { + require.NoErrorf(t, err, "%d", i) + assert.Equalf(t, tt.expectedSQL, sql, "%d", i) + assert.Equalf(t, tt.expectedArgs, args, "%d", i) + } + } +}