Skip to content

Commit

Permalink
chore: sync API with Bun
Browse files Browse the repository at this point in the history
  • Loading branch information
vmihailenco committed Oct 1, 2023
1 parent 48aa57b commit 12a9b1e
Show file tree
Hide file tree
Showing 15 changed files with 78 additions and 75 deletions.
1 change: 1 addition & 0 deletions ch/ch.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (

type (
Safe = chschema.Safe
Name = chschema.Name
Ident = chschema.Ident
CHModel = chschema.CHModel
AfterScanRowHook = chschema.AfterScanRowHook
Expand Down
8 changes: 4 additions & 4 deletions ch/chschema/formatter.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@ func NewFormatter() Formatter {
return Formatter{}
}

func (f Formatter) AppendIdent(b []byte, ident string) []byte {
return AppendIdent(b, ident)
func (f Formatter) AppendName(b []byte, ident string) []byte {
return AppendName(b, ident)
}

func (f Formatter) AppendFQN(b []byte, ident string) []byte {
return AppendFQN(b, ident)
func (f Formatter) AppendIdent(b []byte, ident string) []byte {
return AppendIdent(b, ident)
}

func (f Formatter) WithArg(arg NamedArgAppender) Formatter {
Expand Down
78 changes: 40 additions & 38 deletions ch/chschema/sqlfmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,60 +28,40 @@ func (s Safe) AppendQuery(fmter Formatter, b []byte) ([]byte, error) {

//------------------------------------------------------------------------------

// FQN represents a fully qualified SQL name, for example, table or column name.
type FQN string
// Name represents a SQL identifier, for example, table or column name.
type Name string

var _ QueryAppender = (*FQN)(nil)
var _ QueryAppender = (*Name)(nil)

func (s FQN) AppendQuery(fmter Formatter, b []byte) ([]byte, error) {
return fmter.AppendFQN(b, string(s)), nil
func (s Name) AppendQuery(fmter Formatter, b []byte) ([]byte, error) {
return fmter.AppendName(b, string(s)), nil
}

func AppendFQN(b []byte, field string) []byte {
return appendFQN(b, internal.Bytes(field))
func AppendName(b []byte, field string) []byte {
return appendName(b, internal.Bytes(field))
}

func appendFQN(b, src []byte) []byte {
func appendName(b, src []byte) []byte {
const quote = '"'

var quoted bool
loop:
b = append(b, quote)
for _, c := range src {
switch c {
case '*':
if !quoted {
b = append(b, '*')
continue loop
}
case '.':
if quoted {
b = append(b, quote)
quoted = false
}
b = append(b, '.')
continue loop
}

if !quoted {
b = append(b, quote)
quoted = true
}
if c == quote {
b = append(b, quote, quote)
} else {
b = append(b, c)
}
}
if quoted {
b = append(b, quote)
}
b = append(b, quote)
return b
}

// Ident represents a SQL identifier, for example, table or column name.
//------------------------------------------------------------------------------

// Ident represents a fully qualified SQL name, for example, table or column name.
type Ident string

var _ QueryAppender = (*Ident)(nil)
var _ QueryAppender = (*Name)(nil)

func (s Ident) AppendQuery(fmter Formatter, b []byte) ([]byte, error) {
return fmter.AppendIdent(b, string(s)), nil
Expand All @@ -94,15 +74,37 @@ func AppendIdent(b []byte, field string) []byte {
func appendIdent(b, src []byte) []byte {
const quote = '"'

b = append(b, quote)
var quoted bool
loop:
for _, c := range src {
switch c {
case '*':
if !quoted {
b = append(b, '*')
continue loop
}
case '.':
if quoted {
b = append(b, quote)
quoted = false
}
b = append(b, '.')
continue loop
}

if !quoted {
b = append(b, quote)
quoted = true
}
if c == quote {
b = append(b, quote, quote)
} else {
b = append(b, c)
}
}
b = append(b, quote)
if quoted {
b = append(b, quote)
}
return b
}

Expand All @@ -127,7 +129,7 @@ func SafeQuery(query string, args []any) QueryWithArgs {
}
}

func UnsafeIdent(ident string) QueryWithArgs {
func UnsafeName(ident string) QueryWithArgs {
return QueryWithArgs{Query: ident}
}

Expand All @@ -141,7 +143,7 @@ func (q QueryWithArgs) IsEmpty() bool {

func (q QueryWithArgs) AppendQuery(fmter Formatter, b []byte) ([]byte, error) {
if q.Args == nil {
return fmter.AppendIdent(b, q.Query), nil
return fmter.AppendName(b, q.Query), nil
}
return fmter.AppendQuery(b, q.Query, q.Args...), nil
}
Expand Down
2 changes: 1 addition & 1 deletion ch/chschema/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ func (t *Table) AppendNamedArg(
}

func quoteTableName(s string) Safe {
return Safe(appendFQN(nil, internal.Bytes(s)))
return Safe(appendIdent(nil, internal.Bytes(s)))
}

func quoteColumnName(s string) Safe {
Expand Down
2 changes: 1 addition & 1 deletion ch/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ func (db *DB) autoCreateDatabase() {
tmp := newDB(conf)
defer tmp.Close()

if _, err := tmp.Exec(query, Ident(db.conf.Database), Ident(db.conf.Cluster)); err != nil {
if _, err := tmp.Exec(query, Name(db.conf.Database), Name(db.conf.Cluster)); err != nil {
internal.Logger.Printf("create database %q failed: %s", db.conf.Database, err)
}
}
Expand Down
4 changes: 2 additions & 2 deletions ch/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func TestAutoCreateDatabase(t *testing.T) {
db := ch.Connect()
defer db.Close()

_, err := db.Exec("DROP DATABASE IF EXISTS ?", ch.Ident(dbName))
_, err := db.Exec("DROP DATABASE IF EXISTS ?", ch.Name(dbName))
require.NoError(t, err)
}

Expand Down Expand Up @@ -160,7 +160,7 @@ func TestPlaceholder(t *testing.T) {
params := struct {
A int
B int
Alias ch.Ident
Alias ch.Name
}{
A: 1,
B: 2,
Expand Down
2 changes: 1 addition & 1 deletion ch/query_base.go
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ func (q *baseQuery) addColumn(column chschema.QueryWithArgs) {
func (q *baseQuery) excludeColumn(columns []string) {
if q.columns == nil {
for _, f := range q.table.Fields {
q.columns = append(q.columns, chschema.UnsafeIdent(f.CHName))
q.columns = append(q.columns, chschema.UnsafeName(f.CHName))
}
}

Expand Down
6 changes: 3 additions & 3 deletions ch/query_insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func (q *InsertQuery) Model(model any) *InsertQuery {

func (q *InsertQuery) Table(tables ...string) *InsertQuery {
for _, table := range tables {
q.addTable(chschema.UnsafeIdent(table))
q.addTable(chschema.UnsafeName(table))
}
return q
}
Expand All @@ -44,7 +44,7 @@ func (q *InsertQuery) TableExpr(query string, args ...any) *InsertQuery {
}

func (q *InsertQuery) ModelTable(table string) *InsertQuery {
q.modelTableName = chschema.UnsafeIdent(table)
q.modelTableName = chschema.UnsafeName(table)
return q
}

Expand All @@ -62,7 +62,7 @@ func (q *InsertQuery) Setting(query string, args ...any) *InsertQuery {

func (q *InsertQuery) Column(columns ...string) *InsertQuery {
for _, column := range columns {
q.addColumn(chschema.UnsafeIdent(column))
q.addColumn(chschema.UnsafeName(column))
}
return q
}
Expand Down
18 changes: 9 additions & 9 deletions ch/query_select.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ func (q *SelectQuery) DistinctOn(query string, args ...any) *SelectQuery {

func (q *SelectQuery) Table(tables ...string) *SelectQuery {
for _, table := range tables {
q.addTable(chschema.UnsafeIdent(table))
q.addTable(chschema.UnsafeName(table))
}
return q
}
Expand All @@ -128,7 +128,7 @@ func (q *SelectQuery) TableExpr(query string, args ...any) *SelectQuery {
}

func (q *SelectQuery) ModelTable(table string) *SelectQuery {
q.modelTableName = chschema.UnsafeIdent(table)
q.modelTableName = chschema.UnsafeName(table)
return q
}

Expand All @@ -146,7 +146,7 @@ func (q *SelectQuery) Sample(query string, args ...any) *SelectQuery {

func (q *SelectQuery) Column(columns ...string) *SelectQuery {
for _, column := range columns {
q.addColumn(chschema.UnsafeIdent(column))
q.addColumn(chschema.UnsafeName(column))
}
return q
}
Expand Down Expand Up @@ -262,7 +262,7 @@ func (q *SelectQuery) WhereGroup(sep string, fn func(*SelectQuery) *SelectQuery)

func (q *SelectQuery) Group(columns ...string) *SelectQuery {
for _, column := range columns {
q.group = append(q.group, chschema.UnsafeIdent(column))
q.group = append(q.group, chschema.UnsafeName(column))
}
return q
}
Expand All @@ -285,7 +285,7 @@ func (q *SelectQuery) Order(orders ...string) *SelectQuery {

index := strings.IndexByte(order, ' ')
if index == -1 {
q.order = append(q.order, chschema.UnsafeIdent(order))
q.order = append(q.order, chschema.UnsafeName(order))
continue
}

Expand All @@ -296,11 +296,11 @@ func (q *SelectQuery) Order(orders ...string) *SelectQuery {
case "ASC", "DESC", "ASC NULLS FIRST", "DESC NULLS FIRST",
"ASC NULLS LAST", "DESC NULLS LAST":
q.order = append(q.order, chschema.SafeQuery("? ?", []any{
Ident(field),
Name(field),
Safe(sort),
}))
default:
q.order = append(q.order, chschema.UnsafeIdent(order))
q.order = append(q.order, chschema.UnsafeName(order))
}
}
return q
Expand Down Expand Up @@ -522,7 +522,7 @@ func (q *SelectQuery) appendWith(fmter chschema.Formatter, b []byte) (_ []byte,
}

if with.cte {
b = chschema.AppendIdent(b, with.name)
b = chschema.AppendName(b, with.name)
b = append(b, " AS "...)
b = append(b, "("...)
}
Expand All @@ -536,7 +536,7 @@ func (q *SelectQuery) appendWith(fmter chschema.Formatter, b []byte) (_ []byte,
b = append(b, ")"...)
} else {
b = append(b, " AS "...)
b = chschema.AppendIdent(b, with.name)
b = chschema.AppendName(b, with.name)
}
}
b = append(b, ' ')
Expand Down
8 changes: 4 additions & 4 deletions ch/query_table_create.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func (q *CreateTableQuery) Apply(fn func(*CreateTableQuery) *CreateTableQuery) *

func (q *CreateTableQuery) Table(tables ...string) *CreateTableQuery {
for _, table := range tables {
q.addTable(chschema.UnsafeIdent(table))
q.addTable(chschema.UnsafeName(table))
}
return q
}
Expand All @@ -54,7 +54,7 @@ func (q *CreateTableQuery) TableExpr(query string, args ...any) *CreateTableQuer
}

func (q *CreateTableQuery) ModelTable(table string) *CreateTableQuery {
q.modelTableName = chschema.UnsafeIdent(table)
q.modelTableName = chschema.UnsafeName(table)
return q
}

Expand All @@ -64,7 +64,7 @@ func (q *CreateTableQuery) ModelTableExpr(query string, args ...any) *CreateTabl
}

func (q *CreateTableQuery) As(table string) *CreateTableQuery {
q.as = chschema.UnsafeIdent(table)
q.as = chschema.UnsafeName(table)
return q
}

Expand All @@ -81,7 +81,7 @@ func (q *CreateTableQuery) IfNotExists() *CreateTableQuery {
}

func (q *CreateTableQuery) OnCluster(cluster string) *CreateTableQuery {
q.onCluster = chschema.UnsafeIdent(cluster)
q.onCluster = chschema.UnsafeName(cluster)
return q
}

Expand Down
4 changes: 2 additions & 2 deletions ch/query_table_drop.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func (q *DropTableQuery) Model(model any) *DropTableQuery {

func (q *DropTableQuery) Table(tables ...string) *DropTableQuery {
for _, table := range tables {
q.addTable(chschema.UnsafeIdent(table))
q.addTable(chschema.UnsafeName(table))
}
return q
}
Expand All @@ -58,7 +58,7 @@ func (q *DropTableQuery) IfExists() *DropTableQuery {
}

func (q *DropTableQuery) OnCluster(cluster string) *DropTableQuery {
q.onCluster = chschema.UnsafeIdent(cluster)
q.onCluster = chschema.UnsafeName(cluster)
return q
}

Expand Down
2 changes: 1 addition & 1 deletion ch/query_table_truncate.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func (q *TruncateTableQuery) Model(model any) *TruncateTableQuery {

func (q *TruncateTableQuery) Table(tables ...string) *TruncateTableQuery {
for _, table := range tables {
q.addTable(chschema.UnsafeIdent(table))
q.addTable(chschema.UnsafeName(table))
}
return q
}
Expand Down
2 changes: 1 addition & 1 deletion ch/query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ func TestQuery(t *testing.T) {
Table("my-table_dist").
As("my-table").
Engine("Distributed(?, currentDatabase(), ?, rand())",
ch.Ident("my-cluster"), ch.Ident("my-table")).
ch.Name("my-cluster"), ch.Name("my-table")).
OnCluster("my-cluster").
IfNotExists()
},
Expand Down
Loading

0 comments on commit 12a9b1e

Please sign in to comment.