Skip to content

Commit

Permalink
use constants for dialectors
Browse files Browse the repository at this point in the history
  • Loading branch information
FrancoLiberali committed Sep 7, 2023
1 parent 420e5cf commit 28ccb7e
Show file tree
Hide file tree
Showing 7 changed files with 73 additions and 65 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,10 @@ jobs:
map: |
{
"postgresql": {
"dialector": "postgresql"
"dialector": "postgres"
},
"cockroachdb": {
"dialector": "postgresql"
"dialector": "postgres"
},
"mysql": {
"dialector": "mysql"
Expand Down
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@ sqlserver:
docker compose -f "docker/sqlserver/docker-compose.yml" up -d --build

test_integration_postgresql: postgresql
DB=postgresql gotestsum --format testname ./testintegration
DB=postgres gotestsum --format testname ./testintegration

test_integration_cockroachdb: cockroachdb
DB=postgresql gotestsum --format testname ./testintegration -tags=cockroachdb
DB=postgres gotestsum --format testname ./testintegration -tags=cockroachdb

test_integration_mysql: mysql
DB=mysql gotestsum --format testname ./testintegration -tags=mysql
Expand Down
19 changes: 16 additions & 3 deletions orm/query/gorm_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ func (query *GormQuery) Order(field IFieldIdentifier, descending bool, joinNumbe
return err
}

switch query.GormDB.Dialector.Name() {
case "postgres":
switch query.Dialector() {
case Postgres:
// postgres supports only order by selected fields
query.AddSelect(table, field)
query.GormDB = query.GormDB.Order(
Expand All @@ -42,7 +42,7 @@ func (query *GormQuery) Order(field IFieldIdentifier, descending bool, joinNumbe
)

return nil
case "sqlserver", "sqlite", "mysql":
case SQLServer, SQLite, MySQL:
query.GormDB = query.GormDB.Order(
clause.OrderByColumn{
Column: clause.Column{
Expand Down Expand Up @@ -173,6 +173,19 @@ func (query GormQuery) ColumnName(table Table, fieldName string) string {
return query.GormDB.NamingStrategy.ColumnName(table.Name, fieldName)
}

type Dialector string

const (
Postgres Dialector = "postgres"
MySQL Dialector = "mysql"
SQLite Dialector = "sqlite"
SQLServer Dialector = "sqlserver"
)

func (query GormQuery) Dialector() Dialector {
return Dialector(query.GormDB.Dialector.Name())
}

func NewGormQuery(db *gorm.DB, initialModel model.Model, initialTable Table) *GormQuery {
query := &GormQuery{
GormDB: db.Select(initialTable.Name + ".*"),
Expand Down
73 changes: 37 additions & 36 deletions testintegration/operators_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/ditrit/badaas/orm/mysql"
"github.com/ditrit/badaas/orm/operator"
"github.com/ditrit/badaas/orm/psql"
"github.com/ditrit/badaas/orm/query"
"github.com/ditrit/badaas/orm/sqlite"
"github.com/ditrit/badaas/testintegration/conditions"
"github.com/ditrit/badaas/testintegration/models"
Expand Down Expand Up @@ -250,12 +251,12 @@ func (ts *OperatorsIntTestSuite) TestIsTrue() {
var entities []*models.Product

switch getDBDialector() {
case postgreSQL, mySQL, sqLite:
case query.Postgres, query.MySQL, query.SQLite:
entities, err = orm.NewQuery[models.Product](
ts.db,
conditions.Product.BoolIs().True(),
).Find()
case sqlServer:
case query.SQLServer:
entities, err = orm.NewQuery[models.Product](
ts.db,
conditions.Product.BoolIs().Eq(true),
Expand All @@ -277,12 +278,12 @@ func (ts *OperatorsIntTestSuite) TestIsFalse() {
var entities []*models.Product

switch getDBDialector() {
case postgreSQL, mySQL, sqLite:
case query.Postgres, query.MySQL, query.SQLite:
entities, err = orm.NewQuery[models.Product](
ts.db,
conditions.Product.BoolIs().False(),
).Find()
case sqlServer:
case query.SQLServer:
entities, err = orm.NewQuery[models.Product](
ts.db,
conditions.Product.BoolIs().Eq(false),
Expand Down Expand Up @@ -310,12 +311,12 @@ func (ts *OperatorsIntTestSuite) TestIsNotTrue() {
var entities []*models.Product

switch getDBDialector() {
case postgreSQL, mySQL, sqLite:
case query.Postgres, query.MySQL, query.SQLite:
entities, err = orm.NewQuery[models.Product](
ts.db,
conditions.Product.NullBoolIs().NotTrue(),
).Find()
case sqlServer:
case query.SQLServer:
entities, err = orm.NewQuery[models.Product](
ts.db,
conditions.Product.NullBoolIs().Distinct(true),
Expand Down Expand Up @@ -343,12 +344,12 @@ func (ts *OperatorsIntTestSuite) TestIsNotFalse() {
var entities []*models.Product

switch getDBDialector() {
case postgreSQL, mySQL, sqLite:
case query.Postgres, query.MySQL, query.SQLite:
entities, err = orm.NewQuery[models.Product](
ts.db,
conditions.Product.NullBoolIs().NotFalse(),
).Find()
case sqlServer:
case query.SQLServer:
entities, err = orm.NewQuery[models.Product](
ts.db,
conditions.Product.NullBoolIs().Distinct(false),
Expand Down Expand Up @@ -376,12 +377,12 @@ func (ts *OperatorsIntTestSuite) TestIsUnknown() {
var entities []*models.Product

switch getDBDialector() {
case postgreSQL, mySQL:
case query.Postgres, query.MySQL:
entities, err = orm.NewQuery[models.Product](
ts.db,
conditions.Product.NullBoolIs().Unknown(),
).Find()
case sqlServer, sqLite:
case query.SQLServer, query.SQLite:
entities, err = orm.NewQuery[models.Product](
ts.db,
conditions.Product.NullBoolIs().Null(),
Expand Down Expand Up @@ -409,12 +410,12 @@ func (ts *OperatorsIntTestSuite) TestIsNotUnknown() {
var entities []*models.Product

switch getDBDialector() {
case postgreSQL, mySQL:
case query.Postgres, query.MySQL:
entities, err = orm.NewQuery[models.Product](
ts.db,
conditions.Product.NullBoolIs().NotUnknown(),
).Find()
case sqlServer, sqLite:
case query.SQLServer, query.SQLite:
entities, err = orm.NewQuery[models.Product](
ts.db,
conditions.Product.NullBoolIs().NotNull(),
Expand All @@ -428,7 +429,7 @@ func (ts *OperatorsIntTestSuite) TestIsNotUnknown() {

func (ts *OperatorsIntTestSuite) TestIsDistinct() {
switch getDBDialector() {
case postgreSQL, sqlServer, sqLite:
case query.Postgres, query.SQLServer, query.SQLite:
match1 := ts.createProduct("match", 3, 0, false, nil)
match2 := ts.createProduct("match", 4, 0, false, nil)
ts.createProduct("not_match", 2, 0, false, nil)
Expand All @@ -440,14 +441,14 @@ func (ts *OperatorsIntTestSuite) TestIsDistinct() {
ts.Nil(err)

EqualList(&ts.Suite, []*models.Product{match1, match2}, entities)
case mySQL:
case query.MySQL:
log.Println("IsDistinct not compatible")
}
}

func (ts *OperatorsIntTestSuite) TestIsNotDistinct() {
switch getDBDialector() {
case postgreSQL, sqlServer, sqLite:
case query.Postgres, query.SQLServer, query.SQLite:
match := ts.createProduct("match", 3, 0, false, nil)
ts.createProduct("not_match", 4, 0, false, nil)
ts.createProduct("not_match", 2, 0, false, nil)
Expand All @@ -459,7 +460,7 @@ func (ts *OperatorsIntTestSuite) TestIsNotDistinct() {
ts.Nil(err)

EqualList(&ts.Suite, []*models.Product{match}, entities)
case mySQL:
case query.MySQL:
log.Println("IsNotDistinct not compatible")
}
}
Expand Down Expand Up @@ -532,9 +533,9 @@ func (ts *OperatorsIntTestSuite) TestLikeEscape() {

func (ts *OperatorsIntTestSuite) TestLikeOnNumeric() {
switch getDBDialector() {
case postgreSQL, sqlServer, sqLite:
case query.Postgres, query.SQLServer, query.SQLite:
log.Println("Like with numeric not compatible")
case mySQL:
case query.MySQL:
match1 := ts.createProduct("", 10, 0, false, nil)
match2 := ts.createProduct("", 100, 0, false, nil)

Expand All @@ -555,9 +556,9 @@ func (ts *OperatorsIntTestSuite) TestLikeOnNumeric() {

func (ts *OperatorsIntTestSuite) TestILike() {
switch getDBDialector() {
case mySQL, sqlServer, sqLite:
case query.MySQL, query.SQLServer, query.SQLite:
log.Println("ILike not compatible")
case postgreSQL:
case query.Postgres:
match1 := ts.createProduct("basd", 0, 0, false, nil)
match2 := ts.createProduct("cape", 0, 0, false, nil)
match3 := ts.createProduct("bAsd", 0, 0, false, nil)
Expand All @@ -579,9 +580,9 @@ func (ts *OperatorsIntTestSuite) TestILike() {

func (ts *OperatorsIntTestSuite) TestSimilarTo() {
switch getDBDialector() {
case mySQL, sqlServer, sqLite:
case query.MySQL, query.SQLServer, query.SQLite:
log.Println("SimilarTo not compatible")
case postgreSQL:
case query.Postgres:
match1 := ts.createProduct("abc", 0, 0, false, nil)
match2 := ts.createProduct("aabcc", 0, 0, false, nil)

Expand Down Expand Up @@ -611,11 +612,11 @@ func (ts *OperatorsIntTestSuite) TestPosixRegexCaseSensitive() {
var posixRegexOperator operator.Operator[string]

switch getDBDialector() {
case sqlServer, mySQL:
case query.SQLServer, query.MySQL:
log.Println("PosixRegex not compatible")
case postgreSQL:
case query.Postgres:
posixRegexOperator = psql.POSIXMatch("^a(b|x)")
case sqLite:
case query.SQLite:
posixRegexOperator = sqlite.Glob("a[bx]")
}

Expand Down Expand Up @@ -643,11 +644,11 @@ func (ts *OperatorsIntTestSuite) TestPosixRegexCaseInsensitive() {
var posixRegexOperator operator.Operator[string]

switch getDBDialector() {
case sqlServer, sqLite:
case query.SQLServer, query.SQLite:
log.Println("PosixRegex Case Insensitive not compatible")
case mySQL:
case query.MySQL:
posixRegexOperator = mysql.RegexP("^a(b|x)")
case postgreSQL:
case query.Postgres:
posixRegexOperator = psql.POSIXIMatch("^a(b|x)")
}

Expand Down Expand Up @@ -744,7 +745,7 @@ func (ts *OperatorsIntTestSuite) TestUnsafeOperatorInCaseTypesNotMatchConvertibl

func (ts *OperatorsIntTestSuite) TestUnsafeOperatorInCaseTypesNotMatchNotConvertible() {
switch getDBDialector() {
case sqLite:
case query.SQLite:
// comparisons between types are allowed and matches nothing if not convertible
ts.createProduct("", 0, 0, false, nil)
ts.createProduct("", 0, 2, false, nil)
Expand All @@ -757,7 +758,7 @@ func (ts *OperatorsIntTestSuite) TestUnsafeOperatorInCaseTypesNotMatchNotConvert
ts.Nil(err)

EqualList(&ts.Suite, []*models.Product{}, entities)
case mySQL:
case query.MySQL:
// comparisons between types are allowed but matches 0s if not convertible
match := ts.createProduct("", 0, 0, false, nil)
ts.createProduct("", 0, 2, false, nil)
Expand All @@ -770,14 +771,14 @@ func (ts *OperatorsIntTestSuite) TestUnsafeOperatorInCaseTypesNotMatchNotConvert
ts.Nil(err)

EqualList(&ts.Suite, []*models.Product{match}, entities)
case sqlServer:
case query.SQLServer:
// returns an error
_, err := orm.NewQuery[models.Product](
ts.db,
conditions.Product.FloatIs().Unsafe().Eq("not_convertible_to_float"),
).Find()
ts.ErrorContains(err, "mssql: Error converting data type nvarchar to float.")
case postgreSQL:
case query.Postgres:
// returns an error
_, err := orm.NewQuery[models.Product](
ts.db,
Expand All @@ -789,7 +790,7 @@ func (ts *OperatorsIntTestSuite) TestUnsafeOperatorInCaseTypesNotMatchNotConvert

func (ts *OperatorsIntTestSuite) TestUnsafeOperatorInCaseFieldWithTypesNotMatch() {
switch getDBDialector() {
case sqLite:
case query.SQLite:
// comparisons between fields with different types are allowed
match1 := ts.createProduct("0", 0, 0, false, nil)
match2 := ts.createProduct("1", 0, 1, false, nil)
Expand All @@ -803,7 +804,7 @@ func (ts *OperatorsIntTestSuite) TestUnsafeOperatorInCaseFieldWithTypesNotMatch(
ts.Nil(err)

EqualList(&ts.Suite, []*models.Product{match1, match2}, entities)
case mySQL:
case query.MySQL:
// comparisons between fields with different types are allowed but matches 0s on not convertible
match1 := ts.createProduct("0", 1, 0, false, nil)
match2 := ts.createProduct("1", 2, 1, false, nil)
Expand All @@ -817,7 +818,7 @@ func (ts *OperatorsIntTestSuite) TestUnsafeOperatorInCaseFieldWithTypesNotMatch(
ts.Nil(err)

EqualList(&ts.Suite, []*models.Product{match1, match2, match3}, entities)
case sqlServer:
case query.SQLServer:
// comparisons between fields with different types are allowed and returns error only if at least one is not convertible
match1 := ts.createProduct("0", 1, 0, false, nil)
match2 := ts.createProduct("1", 2, 1, false, nil)
Expand All @@ -838,7 +839,7 @@ func (ts *OperatorsIntTestSuite) TestUnsafeOperatorInCaseFieldWithTypesNotMatch(
conditions.Product.FloatIs().Unsafe().Eq(conditions.Product.String),
).Find()
ts.ErrorContains(err, "mssql: Error converting data type nvarchar to float.")
case postgreSQL:
case query.Postgres:
// returns an error
_, err := orm.NewQuery[models.Product](
ts.db,
Expand Down
24 changes: 8 additions & 16 deletions testintegration/orm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (

"github.com/ditrit/badaas/orm"
"github.com/ditrit/badaas/orm/logger"
"github.com/ditrit/badaas/orm/query"
"github.com/ditrit/badaas/persistence/database"
"github.com/ditrit/badaas/persistence/gormfx"
)
Expand All @@ -33,15 +34,6 @@ const (
dbName = "badaas_db"
)

type dbDialector string

const (
postgreSQL dbDialector = "postgresql"
mySQL dbDialector = "mysql"
sqLite dbDialector = "sqlite"
sqlServer dbDialector = "sqlserver"
)

func TestBaDaaSORM(t *testing.T) {
tGlobal = t

Expand Down Expand Up @@ -84,13 +76,13 @@ func NewDBConnection() (*gorm.DB, error) {
var dialector gorm.Dialector

switch getDBDialector() {
case postgreSQL:
case query.Postgres:
dialector = postgres.Open(orm.CreatePostgreSQLDSN(host, username, password, sslMode, dbName, port))
case mySQL:
dialector = mysql.Open(orm.CreateMySQLDSN(host, username, password, dbName, port))
case sqLite:
case query.SQLite:
dialector = sqlite.Open(orm.CreateSQLiteDSN(host))
case sqlServer:
case query.MySQL:
dialector = mysql.Open(orm.CreateMySQLDSN(host, username, password, dbName, port))
case query.SQLServer:
dialector = sqlserver.Open(orm.CreateSQLServerDSN(host, username, password, dbName, port))
default:
return nil, fmt.Errorf("unknown db %s", getDBDialector())
Expand All @@ -103,6 +95,6 @@ func NewDBConnection() (*gorm.DB, error) {
)
}

func getDBDialector() dbDialector {
return dbDialector(os.Getenv(dbTypeEnvKey))
func getDBDialector() query.Dialector {
return query.Dialector(os.Getenv(dbTypeEnvKey))
}
Loading

0 comments on commit 28ccb7e

Please sign in to comment.