Skip to content

Commit

Permalink
chore: expression & goqu improvements (#181)
Browse files Browse the repository at this point in the history
# Description

1. Rename `GoquExpressionToSQL` to `ParseGoquExpression`
2. Introduce `Expressions#Literal` 
3. `QueryCondition` returns an Expression instead of `string`
4. Register goqu dialects

## Security

- [x] The code changed/added as part of this pull request won't create
any security issues with how the software is being used.
  • Loading branch information
atzoum authored Sep 13, 2024
1 parent 4356576 commit 201389a
Show file tree
Hide file tree
Showing 9 changed files with 81 additions and 46 deletions.
11 changes: 7 additions & 4 deletions sqlconnect/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,16 +122,16 @@ type (
// The result is a RelationRef with case-sensitive fields, i.e. it can be safely quoted (see [QuoteTable] and, for instance, used for matching against the database's information schema.
ParseRelationRef(identifier string) (RelationRef, error)

// QueryCondition returns a dialect-specific query condition sql string for the provided identifier, operator and value(s).
// QueryCondition returns a dialect-specific query expression for the provided identifier, operator and value(s).
//
// E.g. QueryCondition("age", "gt", 18) returns "age > 18"
//
// Each operator has a different number of arguments, e.g. [eq] requires one argument, [in] requires at least one argument, etc.
// See [op] package for the list of supported operators
QueryCondition(identifier, operator string, args ...any) (sql string, err error)
QueryCondition(identifier, operator string, args ...any) (Expression, error)

// GoquExpressionToSQL converts an Expression to a SQL string
GoquExpressionToSQL(expression GoquExpression) (sql string, err error)
// ParseGoquExpression converts a goqu Expression to an Expression
ParseGoquExpression(goquExpression GoquExpression) (Expression, error)

// Expressions returns the dialect-specific expressions
Expressions() Expressions
Expand All @@ -150,6 +150,9 @@ type (
// The value can either be a string literal (column, timestamp, function etc.) or a [time.Time] value.
// Values are cast to [DATE].
DateAdd(dateValue any, interval int, unit string) (Expression, error)

// Literal creates a literal sql expression
Literal(sql string, args ...any) (Expression, error)
}

// Expression represents a dialect-specific expression.
Expand Down
90 changes: 49 additions & 41 deletions sqlconnect/internal/base/goqu_dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,115 +30,117 @@ type Expressions struct {
DateAdd func(date any, interval int, unit string) goqu.Expression
}

func (gq *GoquDialect) QueryCondition(identifier, operator string, args ...any) (sql string, err error) {
func (gq *GoquDialect) QueryCondition(identifier, operator string, args ...any) (sqlconnect.Expression, error) {
args = lo.Map(args, func(a any, _ int) any {
if s, ok := a.(sqlconnect.Expression); ok { // unwrap sqlconnect.Expression
return s.GoquExpression()
}
return a
})
var expr goqu.Expression
var goquExpression goqu.Expression
switch op.Operator(strings.ToLower(operator)) {
case op.Eq:
if len(args) != 1 {
return "", fmt.Errorf("%s operator requires exactly one argument, got %d", operator, len(args))
return nil, fmt.Errorf("%s operator requires exactly one argument, got %d", operator, len(args))
}
expr = goqu.C(identifier).Eq(args[0])
goquExpression = goqu.C(identifier).Eq(args[0])
case op.Neq:
if len(args) != 1 {
return "", fmt.Errorf("%s operator requires exactly one argument, got %d", operator, len(args))
return nil, fmt.Errorf("%s operator requires exactly one argument, got %d", operator, len(args))
}
expr = goqu.C(identifier).Neq(args[0])
goquExpression = goqu.C(identifier).Neq(args[0])
case op.In:
if len(args) == 0 {
return "", fmt.Errorf("%s operator requires at least one argument", operator)
return nil, fmt.Errorf("%s operator requires at least one argument", operator)
}
expr = goqu.C(identifier).In(args...)
goquExpression = goqu.C(identifier).In(args...)
case op.Nin:
if len(args) == 0 {
return "", fmt.Errorf("%s operator requires at least one argument", operator)
return nil, fmt.Errorf("%s operator requires at least one argument", operator)
}
expr = goqu.C(identifier).NotIn(args...)
goquExpression = goqu.C(identifier).NotIn(args...)
case op.Like:
if len(args) != 1 {
return "", fmt.Errorf("%s operator requires exactly one argument, got %d", operator, len(args))
return nil, fmt.Errorf("%s operator requires exactly one argument, got %d", operator, len(args))
}
expr = goqu.C(identifier).Like(args[0])
goquExpression = goqu.C(identifier).Like(args[0])
case op.NLike:
if len(args) != 1 {
return "", fmt.Errorf("%s operator requires exactly one argument, got %d", operator, len(args))
return nil, fmt.Errorf("%s operator requires exactly one argument, got %d", operator, len(args))
}
expr = goqu.C(identifier).NotLike(args[0])
goquExpression = goqu.C(identifier).NotLike(args[0])
case op.Nnull:
if len(args) != 0 {
return "", fmt.Errorf("%s operator requires no arguments, got %d", operator, len(args))
return nil, fmt.Errorf("%s operator requires no arguments, got %d", operator, len(args))
}
expr = goqu.C(identifier).IsNotNull()
goquExpression = goqu.C(identifier).IsNotNull()
case op.Null:
if len(args) != 0 {
return "", fmt.Errorf("%s operator requires no arguments, got %d", operator, len(args))
return nil, fmt.Errorf("%s operator requires no arguments, got %d", operator, len(args))
}
expr = goqu.C(identifier).IsNull()
goquExpression = goqu.C(identifier).IsNull()
case op.Gt:
if len(args) != 1 {
return "", fmt.Errorf("%s operator requires exactly one argument, got %d", operator, len(args))
return nil, fmt.Errorf("%s operator requires exactly one argument, got %d", operator, len(args))
}
expr = goqu.C(identifier).Gt(args[0])
goquExpression = goqu.C(identifier).Gt(args[0])
case op.Gte:
if len(args) != 1 {
return "", fmt.Errorf("%s operator requires exactly one argument, got %d", operator, len(args))
return nil, fmt.Errorf("%s operator requires exactly one argument, got %d", operator, len(args))
}
expr = goqu.C(identifier).Gte(args[0])
goquExpression = goqu.C(identifier).Gte(args[0])
case op.Lt:
if len(args) != 1 {
return "", fmt.Errorf("%s operator requires exactly one argument, got %d", operator, len(args))
return nil, fmt.Errorf("%s operator requires exactly one argument, got %d", operator, len(args))
}
expr = goqu.C(identifier).Lt(args[0])
goquExpression = goqu.C(identifier).Lt(args[0])
case op.Lte:
if len(args) != 1 {
return "", fmt.Errorf("%s operator requires exactly one argument, got %d", operator, len(args))
return nil, fmt.Errorf("%s operator requires exactly one argument, got %d", operator, len(args))
}
expr = goqu.C(identifier).Lte(args[0])
goquExpression = goqu.C(identifier).Lte(args[0])
case op.Btw:
if len(args) != 2 {
return "", fmt.Errorf("%s operator requires exactly two arguments, got %d", operator, len(args))
return nil, fmt.Errorf("%s operator requires exactly two arguments, got %d", operator, len(args))
}
expr = goqu.C(identifier).Between(exp.NewRangeVal(args[0], args[1]))
goquExpression = goqu.C(identifier).Between(exp.NewRangeVal(args[0], args[1]))
case op.Nbtw:
if len(args) != 2 {
return "", fmt.Errorf("%s operator requires exactly two arguments, got %d", operator, len(args))
return nil, fmt.Errorf("%s operator requires exactly two arguments, got %d", operator, len(args))
}
expr = goqu.C(identifier).NotBetween(exp.NewRangeVal(args[0], args[1]))
goquExpression = goqu.C(identifier).NotBetween(exp.NewRangeVal(args[0], args[1]))
case op.Inlast:
if len(args) != 2 {
return "", fmt.Errorf("%s operator requires exactly two arguments, got %d", operator, len(args))
return nil, fmt.Errorf("%s operator requires exactly two arguments, got %d", operator, len(args))
}
var (
interval int
unit string
ok bool
)
if interval, ok = args[0].(int); !ok {
return "", fmt.Errorf("nbfinterval operator requires first argument to be an integer")
return nil, fmt.Errorf("nbfinterval operator requires first argument to be an integer")
}
if unit, ok = args[1].(string); !ok {
return "", fmt.Errorf("nbfinterval operator requires second argument to be a string")
return nil, fmt.Errorf("nbfinterval operator requires second argument to be a string")
}
dateAddExpr, err := gq.DateAdd("CURRENT_DATE", -interval, unit)
if err != nil {
return "", err
return nil, err
}
expr = goqu.C(identifier).Gte(dateAddExpr.GoquExpression())
goquExpression = goqu.C(identifier).Gte(dateAddExpr.GoquExpression())
default:
return "", fmt.Errorf("unsupported operator: %s", operator)
return nil, fmt.Errorf("unsupported operator: %s", operator)
}

return gq.GoquExpressionToSQL(expr)
return gq.ParseGoquExpression(goquExpression)
}

func (gq *GoquDialect) GoquExpressionToSQL(expression sqlconnect.GoquExpression) (sql string, err error) {
sql, _, err = sqlgen.GenerateExpressionSQL(gq.esg, false, expression)
return
func (gq *GoquDialect) ParseGoquExpression(goquExpression sqlconnect.GoquExpression) (sqlconnect.Expression, error) {
sql, _, err := sqlgen.GenerateExpressionSQL(gq.esg, false, goquExpression)
if err != nil {
return nil, err
}
return &expression{Expression: goquExpression, sql: sql}, nil
}

func (gq *GoquDialect) Expressions() sqlconnect.Expressions {
Expand Down Expand Up @@ -191,6 +193,12 @@ func (gq *GoquDialect) DateAdd(timeValue any, interval int, unit string) (sqlcon
return &expression{Expression: goquExpression, sql: sql}, err
}

func (gq *GoquDialect) Literal(sql string, args ...any) (sqlconnect.Expression, error) {
goquExpression := goqu.L(sql, args...)
sql, _, err := sqlgen.GenerateExpressionSQL(gq.esg, false, goquExpression)
return &expression{Expression: goquExpression, sql: sql}, err
}

type expression struct {
goqu.Expression
sql string
Expand Down
4 changes: 4 additions & 0 deletions sqlconnect/internal/databricks/goqu_dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ import (
"github.com/rudderlabs/sqlconnect-go/sqlconnect/internal/base"
)

func init() {
goqu.RegisterDialect(DatabaseType, GoquDialectOptions())
}

func GoquDialectOptions() *sqlgen.SQLDialectOptions {
o := sqlgen.DefaultDialectOptions()
o.QuoteIdentifiers = false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ func TestDatabaseScenarios(t *testing.T, warehouse string, configJSON json.RawMe
getQueryCondition := func(t *testing.T, col, op string, val ...any) string {
sql, err := db.QueryCondition(col, op, val...)
require.NoError(t, err, "it should be able to generate a query condition")
return sql
return sql.String()
}

getTimestampAddExpression := func(t *testing.T, timeValue any, interval int, unit string) any {
Expand Down
4 changes: 4 additions & 0 deletions sqlconnect/internal/mysql/goqu_dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ import (
"github.com/rudderlabs/sqlconnect-go/sqlconnect/internal/base"
)

func init() {
goqu.RegisterDialect(DatabaseType, GoquDialectOptions())
}

func GoquDialectOptions() *sqlgen.SQLDialectOptions {
o := sqlgen.DefaultDialectOptions()
o.QuoteIdentifiers = false
Expand Down
4 changes: 4 additions & 0 deletions sqlconnect/internal/postgres/goqu_dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ import (
"github.com/rudderlabs/sqlconnect-go/sqlconnect/internal/base"
)

func init() {
goqu.RegisterDialect(DatabaseType, GoquDialectOptions())
}

func GoquDialectOptions() *sqlgen.SQLDialectOptions {
o := sqlgen.DefaultDialectOptions()
o.QuoteIdentifiers = false
Expand Down
4 changes: 4 additions & 0 deletions sqlconnect/internal/redshift/goqu_dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ import (
"github.com/rudderlabs/sqlconnect-go/sqlconnect/internal/base"
)

func init() {
goqu.RegisterDialect(DatabaseType, GoquDialectOptions())
}

func GoquDialectOptions() *sqlgen.SQLDialectOptions {
o := sqlgen.DefaultDialectOptions()
o.QuoteIdentifiers = false
Expand Down
4 changes: 4 additions & 0 deletions sqlconnect/internal/snowflake/goqu_dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ import (
"github.com/rudderlabs/sqlconnect-go/sqlconnect/internal/base"
)

func init() {
goqu.RegisterDialect(DatabaseType, GoquDialectOptions())
}

func GoquDialectOptions() *sqlgen.SQLDialectOptions {
o := sqlgen.DefaultDialectOptions()
o.QuoteIdentifiers = false
Expand Down
4 changes: 4 additions & 0 deletions sqlconnect/internal/trino/goqu_dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ import (
"github.com/rudderlabs/sqlconnect-go/sqlconnect/internal/base"
)

func init() {
goqu.RegisterDialect(DatabaseType, GoquDialectOptions())
}

func GoquDialectOptions() *sqlgen.SQLDialectOptions {
o := sqlgen.DefaultDialectOptions()
o.QuoteIdentifiers = false
Expand Down

0 comments on commit 201389a

Please sign in to comment.