diff --git a/pkg/gofr/container/datasources.go b/pkg/gofr/container/datasources.go index 52c274221..bfa0e16d9 100644 --- a/pkg/gofr/container/datasources.go +++ b/pkg/gofr/container/datasources.go @@ -3,7 +3,6 @@ package container import ( "context" "database/sql" - "database/sql/driver" "github.com/redis/go-redis/v9" @@ -12,7 +11,6 @@ import ( ) type DB interface { - Driver() driver.Driver Query(query string, args ...interface{}) (*sql.Rows, error) QueryRow(query string, args ...interface{}) *sql.Row QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) @@ -23,6 +21,7 @@ type DB interface { Begin() (*gofrSQL.Tx, error) Select(ctx context.Context, data interface{}, query string, args ...interface{}) HealthCheck() *datasource.Health + Dialect() string } type Redis interface { diff --git a/pkg/gofr/container/mock_datasources.go b/pkg/gofr/container/mock_datasources.go index ba7a40b66..5d9b1fdb5 100644 --- a/pkg/gofr/container/mock_datasources.go +++ b/pkg/gofr/container/mock_datasources.go @@ -12,7 +12,6 @@ package container import ( context "context" sql "database/sql" - driver "database/sql/driver" reflect "reflect" time "time" @@ -60,18 +59,18 @@ func (mr *MockDBMockRecorder) Begin() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Begin", reflect.TypeOf((*MockDB)(nil).Begin)) } -// Driver mocks base method. -func (m *MockDB) Driver() driver.Driver { +// Dialect mocks base method. +func (m *MockDB) Dialect() string { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Driver") - ret0, _ := ret[0].(driver.Driver) + ret := m.ctrl.Call(m, "Dialect") + ret0, _ := ret[0].(string) return ret0 } -// Driver indicates an expected call of Driver. -func (mr *MockDBMockRecorder) Driver() *gomock.Call { +// Dialect indicates an expected call of Dialect. +func (mr *MockDBMockRecorder) Dialect() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Driver", reflect.TypeOf((*MockDB)(nil).Driver)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Dialect", reflect.TypeOf((*MockDB)(nil).Dialect)) } // Exec mocks base method. diff --git a/pkg/gofr/datasource/sql/db.go b/pkg/gofr/datasource/sql/db.go index 963c0a5cb..369ad2702 100644 --- a/pkg/gofr/datasource/sql/db.go +++ b/pkg/gofr/datasource/sql/db.go @@ -70,6 +70,10 @@ func (d *DB) Query(query string, args ...interface{}) (*sql.Rows, error) { return d.DB.Query(query, args...) } +func (d *DB) Dialect() string { + return d.config.Dialect +} + func (d *DB) QueryRow(query string, args ...interface{}) *sql.Row { defer d.logQuery(time.Now(), "QueryRow", query, args...) return d.DB.QueryRow(query, args...) diff --git a/pkg/gofr/migration/sql.go b/pkg/gofr/migration/sql.go index e6e3c3526..69d2e63da 100644 --- a/pkg/gofr/migration/sql.go +++ b/pkg/gofr/migration/sql.go @@ -5,9 +5,6 @@ import ( "database/sql" "time" - "github.com/go-sql-driver/mysql" - "github.com/lib/pq" - "gofr.dev/pkg/gofr/container" gofrSql "gofr.dev/pkg/gofr/datasource/sql" ) @@ -21,8 +18,6 @@ const ( constraint primary_key primary key (version, method) );` - checkSQLGoFrMigrationsTable = `SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'gofr_migrations');` - getLastSQLGoFrMigration = `SELECT COALESCE(MAX(version), 0) FROM gofr_migrations;` insertGoFrMigrationRowMySQL = `INSERT INTO gofr_migrations (version, method, start_time,duration) VALUES (?, ?, ?, ?);` @@ -90,35 +85,8 @@ func (s sqlMigratorObject) apply(m Migrator) Migrator { } func (d sqlMigrator) checkAndCreateMigrationTable(c *container.Container) error { - // this can be replaced with having switch case only in the exists variable - but we have chosen to differentiate based - // on driver because if new dialect comes will follow the same, also this complete has to be refactored as mentioned in RUN. - switch c.SQL.Driver().(type) { - case *mysql.MySQLDriver: - var exists int - - err := c.SQL.QueryRow(checkSQLGoFrMigrationsTable).Scan(&exists) - if err != nil { - return err - } - - if exists != 1 { - if _, err := c.SQL.Exec(createSQLGoFrMigrationsTable); err != nil { - return err - } - } - case *pq.Driver: - var exists bool - - err := c.SQL.QueryRow(checkSQLGoFrMigrationsTable).Scan(&exists) - if err != nil { - return err - } - - if !exists { - if _, err := c.SQL.Exec(createSQLGoFrMigrationsTable); err != nil { - return err - } - } + if _, err := c.SQL.Exec(createSQLGoFrMigrationsTable); err != nil { + return err } return d.Migrator.checkAndCreateMigrationTable(c) @@ -144,14 +112,14 @@ func (d sqlMigrator) getLastMigration(c *container.Container) int64 { } func (d sqlMigrator) commitMigration(c *container.Container, data migrationData) error { - switch c.SQL.Driver().(type) { - case *mysql.MySQLDriver: + switch c.SQL.Dialect() { + case "mysql": err := insertMigrationRecord(data.SQLTx, insertGoFrMigrationRowMySQL, data.MigrationNumber, data.StartTime) if err != nil { return err } - case *pq.Driver: + case "postgres": err := insertMigrationRecord(data.SQLTx, insertGoFrMigrationRowPostgres, data.MigrationNumber, data.StartTime) if err != nil { return err diff --git a/pkg/gofr/migration/sql_test.go b/pkg/gofr/migration/sql_test.go new file mode 100644 index 000000000..df0112de3 --- /dev/null +++ b/pkg/gofr/migration/sql_test.go @@ -0,0 +1,266 @@ +package migration + +import ( + "context" + "database/sql" + "errors" + "reflect" + "testing" + + "github.com/DATA-DOG/go-sqlmock" + "go.uber.org/mock/gomock" + + "gofr.dev/pkg/gofr/container" +) + +func TestNewMysql(t *testing.T) { + ctrl := gomock.NewController(t) + mockDB := container.NewMockDB(ctrl) + + sqlDB := newMysql(mockDB) + + if sqlDB.db != mockDB { + t.Errorf("newMysql should wrap the provided db, got: %v", sqlDB.db) + } +} + +func TestQuery(t *testing.T) { + t.Run("successful query", func(t *testing.T) { + ctrl := gomock.NewController(t) + mockDB := container.NewMockDB(ctrl) + expectedRows := &sql.Rows{} + + mockDB.EXPECT().Query("SELECT * FROM users", []interface{}{}).Return(expectedRows, nil) + sqlDB := newMysql(mockDB) + + rows, err := sqlDB.Query("SELECT * FROM users", []interface{}{}) + if rows.Err() != nil { + t.Errorf("unexpected row error: %v", rows.Err()) + } + + if err != nil { + t.Errorf("Query should return no error, got: %v", err) + } + + if rows != expectedRows { + t.Errorf("Query should return the expected rows, got: %v", rows) + } + }) + + t.Run("query error", func(t *testing.T) { + ctrl := gomock.NewController(t) + mockDB := container.NewMockDB(ctrl) + expectedErr := sql.ErrNoRows + + mockDB.EXPECT().Query("SELECT * FROM unknown_table", []interface{}{}).Return(nil, expectedErr) + sqlDB := newMysql(mockDB) + + rows, err := sqlDB.Query("SELECT * FROM unknown_table", []interface{}{}) + if rows != nil { + t.Errorf("unexpected rows error: %v", rows.Err()) + } + + if err == nil { + t.Errorf("Query should return an error") + } + + if !errors.Is(err, expectedErr) { + t.Errorf("Query should return the expected error, got: %v", err) + } + }) +} + +func TestQueryRow(t *testing.T) { + t.Run("successful query row", func(t *testing.T) { + ctrl := gomock.NewController(t) + mockDB := container.NewMockDB(ctrl) + expectedRow := &sql.Row{} + + mockDB.EXPECT().QueryRow("SELECT * FROM users WHERE id = ?", 1).Return(expectedRow) + sqlDB := newMysql(mockDB) + + row := sqlDB.QueryRow("SELECT * FROM users WHERE id = ?", 1) + + if row != expectedRow { + t.Errorf("QueryRow should return the expected row, got: %v", row) + } + }) +} + +func TestQueryRowContext(t *testing.T) { + ctx := context.Background() + + t.Run("successful query row context", func(t *testing.T) { + ctrl := gomock.NewController(t) + mockDB := container.NewMockDB(ctrl) + expectedRow := &sql.Row{} + mockDB.EXPECT().QueryRowContext(ctx, "SELECT * FROM users WHERE id = ?", 1).Return(expectedRow) + sqlDB := newMysql(mockDB) + + row := sqlDB.QueryRowContext(ctx, "SELECT * FROM users WHERE id = ?", 1) + + if row != expectedRow { + t.Errorf("QueryRowContext should return the expected row, got: %v", row) + } + }) +} + +func TestExec(t *testing.T) { + t.Run("successful exec", func(t *testing.T) { + ctrl := gomock.NewController(t) + mockDB := container.NewMockDB(ctrl) + expectedResult := sqlmock.NewResult(10, 1) + + mockDB.EXPECT().Exec("DELETE FROM users WHERE id = ?", 1).Return(expectedResult, nil) + sqlDB := newMysql(mockDB) + + result, err := sqlDB.Exec("DELETE FROM users WHERE id = ?", 1) + + if err != nil { + t.Errorf("Exec should return no error, got: %v", err) + } + + if !reflect.DeepEqual(result, expectedResult) { + t.Errorf("Exec should return the expected result, got: %v", result) + } + }) + + t.Run("exec error", func(t *testing.T) { + ctrl := gomock.NewController(t) + mockDB := container.NewMockDB(ctrl) + + expectedErr := sql.ErrNoRows + mockDB.EXPECT().Exec("UPDATE unknown_table SET name = ?", "John").Return(nil, expectedErr) + sqlDB := newMysql(mockDB) + + _, err := sqlDB.Exec("UPDATE unknown_table SET name = ?", "John") + + if err == nil { + t.Errorf("Exec should return an error") + } + + if !errors.Is(err, expectedErr) { + t.Errorf("Exec should return the expected error, got: %v", err) + } + }) +} + +func TestExecContext(t *testing.T) { + ctx := context.Background() + + t.Run("successful exec context", func(t *testing.T) { + ctrl := gomock.NewController(t) + mockDB := container.NewMockDB(ctrl) + + expectedResult := sqlmock.NewResult(10, 1) + mockDB.EXPECT().ExecContext(ctx, "DELETE FROM users WHERE id = ?", 1).Return(expectedResult, nil) + sqlDB := newMysql(mockDB) + + result, err := sqlDB.ExecContext(ctx, "DELETE FROM users WHERE id = ?", 1) + + if err != nil { + t.Errorf("ExecContext should return no error, got: %v", err) + } + + if !reflect.DeepEqual(result, expectedResult) { + t.Errorf("ExecContext should return the expected result, got: %v", result) + } + }) +} + +func TestCheckAndCreateMigrationTableSuccess(t *testing.T) { + ctrl := gomock.NewController(t) + mockDB := container.NewMockDB(ctrl) + mockMigrator := NewMockMigrator(ctrl) + mockContainer, mocks := container.NewMockContainer(t) + + mockMigrator.EXPECT().checkAndCreateMigrationTable(mockContainer) + mocks.SQL.EXPECT().Exec(createSQLGoFrMigrationsTable).Return(nil, nil) + + migrator := sqlMigrator{ + db: mockDB, + Migrator: mockMigrator, + } + + err := migrator.checkAndCreateMigrationTable(mockContainer) + + if err != nil { + t.Errorf("checkAndCreateMigrationTable should return no error, got: %v", err) + } +} + +func TestCheckAndCreateMigrationTableExecError(t *testing.T) { + ctrl := gomock.NewController(t) + mockDB := container.NewMockDB(ctrl) + mockMigrator := NewMockMigrator(ctrl) + mockContainer, mocks := container.NewMockContainer(t) + expectedErr := sql.ErrNoRows + + mocks.SQL.EXPECT().Exec(createSQLGoFrMigrationsTable).Return(nil, expectedErr) + + migrator := sqlMigrator{ + db: mockDB, + Migrator: mockMigrator, + } + + err := migrator.checkAndCreateMigrationTable(mockContainer) + + if err == nil { + t.Errorf("checkAndCreateMigrationTable should return an error") + } + + if !errors.Is(err, expectedErr) { + t.Errorf("checkAndCreateMigrationTable should return the expected error, got: %v", err) + } +} + +func TestBeginTransactionSuccess(t *testing.T) { + ctrl := gomock.NewController(t) + mockDB := container.NewMockDB(ctrl) + mockMigrator := NewMockMigrator(ctrl) + mockContainer, mocks := container.NewMockContainer(t) + expectedMigrationData := migrationData{} + + mocks.SQL.EXPECT().Begin() + mockMigrator.EXPECT().beginTransaction(mockContainer) + + migrator := sqlMigrator{ + db: mockDB, + Migrator: mockMigrator, + } + data := migrator.beginTransaction(mockContainer) + + if data != expectedMigrationData { + t.Errorf("beginTransaction should return data from Migrator, got: %v", data) + } +} + +var ( + errBeginTx = errors.New("failed to begin transaction") +) + +func TestBeginTransactionDBError(t *testing.T) { + ctrl := gomock.NewController(t) + mockDB := container.NewMockDB(ctrl) + mockMigrator := NewMockMigrator(ctrl) + mockContainer, mocks := container.NewMockContainer(t) + + mocks.SQL.EXPECT().Begin().Return(nil, errBeginTx) + + migrator := sqlMigrator{ + db: mockDB, + Migrator: mockMigrator, + } + data := migrator.beginTransaction(mockContainer) + + if data.SQLTx != nil { + t.Errorf("beginTransaction should not return a transaction on DB error") + } +} + +func TestRollbackNoTransaction(t *testing.T) { + mockContainer, _ := container.NewMockContainer(t) + + migrator := sqlMigrator{} + migrator.rollback(mockContainer, migrationData{}) +}