Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: dryrun migration should run select #251

Merged
merged 1 commit into from
Feb 7, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 22 additions & 12 deletions migrator.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,18 @@ type Migrator struct {
migrator.Migrator
}

// select querys ignore dryrun
func (m Migrator) queryRaw(sql string, values ...interface{}) (tx *gorm.DB) {
queryTx := m.DB
if m.DB.DryRun {
queryTx = m.DB.Session(&gorm.Session{})
queryTx.DryRun = false
}
return queryTx.Raw(sql, values...)
}

func (m Migrator) CurrentDatabase() (name string) {
m.DB.Raw("SELECT CURRENT_DATABASE()").Scan(&name)
m.queryRaw("SELECT CURRENT_DATABASE()").Scan(&name)
return
}

Expand Down Expand Up @@ -87,7 +97,7 @@ func (m Migrator) HasIndex(value interface{}, name string) bool {
}
}
currentSchema, curTable := m.CurrentSchema(stmt, stmt.Table)
return m.DB.Raw(
return m.queryRaw(
"SELECT count(*) FROM pg_indexes WHERE tablename = ? AND indexname = ? AND schemaname = ?", curTable, name, currentSchema,
).Scan(&count).Error
})
Expand Down Expand Up @@ -155,7 +165,7 @@ func (m Migrator) DropIndex(value interface{}, name string) error {

func (m Migrator) GetTables() (tableList []string, err error) {
currentSchema, _ := m.CurrentSchema(m.DB.Statement, "")
return tableList, m.DB.Raw("SELECT table_name FROM information_schema.tables WHERE table_schema = ? AND table_type = ?", currentSchema, "BASE TABLE").Scan(&tableList).Error
return tableList, m.queryRaw("SELECT table_name FROM information_schema.tables WHERE table_schema = ? AND table_type = ?", currentSchema, "BASE TABLE").Scan(&tableList).Error
}

func (m Migrator) CreateTable(values ...interface{}) (err error) {
Expand Down Expand Up @@ -189,7 +199,7 @@ func (m Migrator) HasTable(value interface{}) bool {
var count int64
m.RunWithValue(value, func(stmt *gorm.Statement) error {
currentSchema, curTable := m.CurrentSchema(stmt, stmt.Table)
return m.DB.Raw("SELECT count(*) FROM information_schema.tables WHERE table_schema = ? AND table_name = ? AND table_type = ?", currentSchema, curTable, "BASE TABLE").Scan(&count).Error
return m.queryRaw("SELECT count(*) FROM information_schema.tables WHERE table_schema = ? AND table_name = ? AND table_type = ?", currentSchema, curTable, "BASE TABLE").Scan(&count).Error
})
return count > 0
}
Expand Down Expand Up @@ -241,7 +251,7 @@ func (m Migrator) HasColumn(value interface{}, field string) bool {
}

currentSchema, curTable := m.CurrentSchema(stmt, stmt.Table)
return m.DB.Raw(
return m.queryRaw(
"SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_schema = ? AND table_name = ? AND column_name = ?",
currentSchema, curTable, name,
).Scan(&count).Error
Expand All @@ -266,7 +276,7 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
checkSQL += "WHERE objsubid = (SELECT ordinal_position FROM information_schema.columns WHERE table_schema = ? AND table_name = ? AND column_name = ?) "
checkSQL += "AND objoid = (SELECT oid FROM pg_catalog.pg_class WHERE relname = ? AND relnamespace = "
checkSQL += "(SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = ?))"
m.DB.Raw(checkSQL, values...).Scan(&description)
m.queryRaw(checkSQL, values...).Scan(&description)

comment := strings.Trim(field.Comment, "'")
comment = strings.Trim(comment, `"`)
Expand Down Expand Up @@ -414,7 +424,7 @@ func (m Migrator) HasConstraint(value interface{}, name string) bool {
}
currentSchema, curTable := m.CurrentSchema(stmt, table)

return m.DB.Raw(
return m.queryRaw(
"SELECT count(*) FROM INFORMATION_SCHEMA.table_constraints WHERE table_schema = ? AND table_name = ? AND constraint_name = ?",
currentSchema, curTable, name,
).Scan(&count).Error
Expand All @@ -429,7 +439,7 @@ func (m Migrator) ColumnTypes(value interface{}) (columnTypes []gorm.ColumnType,
var (
currentDatabase = m.DB.Migrator().CurrentDatabase()
currentSchema, table = m.CurrentSchema(stmt, stmt.Table)
columns, err = m.DB.Raw(
columns, err = m.queryRaw(
"SELECT c.column_name, c.is_nullable = 'YES', c.udt_name, c.character_maximum_length, c.numeric_precision, c.numeric_precision_radix, c.numeric_scale, c.datetime_precision, 8 * typlen, c.column_default, pd.description, c.identity_increment FROM information_schema.columns AS c JOIN pg_type AS pgt ON c.udt_name = pgt.typname LEFT JOIN pg_catalog.pg_description as pd ON pd.objsubid = c.ordinal_position AND pd.objoid = (SELECT oid FROM pg_catalog.pg_class WHERE relname = c.table_name AND relnamespace = (SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = c.table_schema)) where table_catalog = ? AND table_schema = ? AND table_name = ?",
currentDatabase, currentSchema, table).Rows()
)
Expand Down Expand Up @@ -503,7 +513,7 @@ func (m Migrator) ColumnTypes(value interface{}) (columnTypes []gorm.ColumnType,

// check primary, unique field
{
columnTypeRows, err := m.DB.Raw("SELECT constraint_name FROM information_schema.table_constraints tc JOIN information_schema.constraint_column_usage AS ccu USING (constraint_schema, constraint_catalog, table_name, constraint_name) JOIN information_schema.columns AS c ON c.table_schema = tc.constraint_schema AND tc.table_name = c.table_name AND ccu.column_name = c.column_name WHERE constraint_type IN ('PRIMARY KEY', 'UNIQUE') AND c.table_catalog = ? AND c.table_schema = ? AND c.table_name = ? AND constraint_type = ?", currentDatabase, currentSchema, table, "UNIQUE").Rows()
columnTypeRows, err := m.queryRaw("SELECT constraint_name FROM information_schema.table_constraints tc JOIN information_schema.constraint_column_usage AS ccu USING (constraint_schema, constraint_catalog, table_name, constraint_name) JOIN information_schema.columns AS c ON c.table_schema = tc.constraint_schema AND tc.table_name = c.table_name AND ccu.column_name = c.column_name WHERE constraint_type IN ('PRIMARY KEY', 'UNIQUE') AND c.table_catalog = ? AND c.table_schema = ? AND c.table_name = ? AND constraint_type = ?", currentDatabase, currentSchema, table, "UNIQUE").Rows()
if err != nil {
return err
}
Expand All @@ -515,7 +525,7 @@ func (m Migrator) ColumnTypes(value interface{}) (columnTypes []gorm.ColumnType,
}
columnTypeRows.Close()

columnTypeRows, err = m.DB.Raw("SELECT c.column_name, constraint_name, constraint_type FROM information_schema.table_constraints tc JOIN information_schema.constraint_column_usage AS ccu USING (constraint_schema, constraint_catalog, table_name, constraint_name) JOIN information_schema.columns AS c ON c.table_schema = tc.constraint_schema AND tc.table_name = c.table_name AND ccu.column_name = c.column_name WHERE constraint_type IN ('PRIMARY KEY', 'UNIQUE') AND c.table_catalog = ? AND c.table_schema = ? AND c.table_name = ?", currentDatabase, currentSchema, table).Rows()
columnTypeRows, err = m.queryRaw("SELECT c.column_name, constraint_name, constraint_type FROM information_schema.table_constraints tc JOIN information_schema.constraint_column_usage AS ccu USING (constraint_schema, constraint_catalog, table_name, constraint_name) JOIN information_schema.columns AS c ON c.table_schema = tc.constraint_schema AND tc.table_name = c.table_name AND ccu.column_name = c.column_name WHERE constraint_type IN ('PRIMARY KEY', 'UNIQUE') AND c.table_catalog = ? AND c.table_schema = ? AND c.table_name = ?", currentDatabase, currentSchema, table).Rows()
if err != nil {
return err
}
Expand All @@ -542,7 +552,7 @@ func (m Migrator) ColumnTypes(value interface{}) (columnTypes []gorm.ColumnType,

// check column type
{
dataTypeRows, err := m.DB.Raw(`SELECT a.attname as column_name, format_type(a.atttypid, a.atttypmod) AS data_type
dataTypeRows, err := m.queryRaw(`SELECT a.attname as column_name, format_type(a.atttypid, a.atttypmod) AS data_type
FROM pg_attribute a JOIN pg_class b ON a.attrelid = b.oid AND relnamespace = (SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = ?)
WHERE a.attnum > 0 -- hide internal columns
AND NOT a.attisdropped -- hide deleted columns
Expand Down Expand Up @@ -700,7 +710,7 @@ func (m Migrator) GetIndexes(value interface{}) ([]gorm.Index, error) {

err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
result := make([]*Index, 0)
scanErr := m.DB.Raw(indexSql, stmt.Table).Scan(&result).Error
scanErr := m.queryRaw(indexSql, stmt.Table).Scan(&result).Error
if scanErr != nil {
return scanErr
}
Expand Down
Loading