Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

infoschema: fix inserting into a temporary table panics after the database is dropped #29263

Merged
merged 3 commits into from
Nov 1, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 83 additions & 0 deletions ddl/db_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2972,6 +2972,89 @@ func (s *testIntegrationSuite3) TestCreateTemporaryTable(c *C) {
tk.MustExec(updateSafePoint)
}

func (s *testIntegrationSuite3) TestAccessLocalTmpTableAfterDropDB(c *C) {
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("create database if not exists tmpdb")
tk.MustExec("create temporary table tmpdb.tmp(id int)")
tk.MustExec("drop database tmpdb")

tests := []struct {
sql string
errcode int
result []string
queryResult []string
}{
{
sql: "insert into tmpdb.tmp values(1)",
result: []string{"1"},
},
{
sql: "select * from tmpdb.tmp t1 join tmpdb.tmp t2 where t1.id=t2.id",
queryResult: []string{"1 1"},
},
{
sql: "select (select id from tmpdb.tmp) id1, t1.id id2 from (select * from tmpdb.tmp) t1 where t1.id=1",
queryResult: []string{"1 1"},
},
{
sql: "update tmpdb.tmp set id=2 where id=1",
result: []string{"2"},
},
{
sql: "delete from tmpdb.tmp where id=2",
result: []string{},
},
{
sql: "insert into tmpdb.tmp select 1 from dual",
result: []string{"1"},
},
{
sql: "update tmpdb.tmp t1, tmpdb.tmp t2 set t1.id=2 where t1.id=t2.id",
result: []string{"2"},
},
{
sql: "delete t1 from tmpdb.tmp t1 join tmpdb.tmp t2 where t1.id=t2.id",
result: []string{},
},
{
sql: "admin check table tmpdb.tmp",
errcode: errno.ErrOptOnTemporaryTable,
},
{
sql: "alter table tmpdb.tmp add column name char(10)",
errcode: errno.ErrUnsupportedDDLOperation,
},
}

executeTests := func() {
tk.MustExec("truncate table tmpdb.tmp")
for _, test := range tests {
switch {
case test.errcode != 0:
tk.MustGetErrCode(test.sql, test.errcode)
case test.queryResult != nil:
tk.MustQuery(test.sql).Check(testkit.Rows(test.queryResult...))
case test.result != nil:
tk.MustExec(test.sql)
tk.MustQuery("select * from tmpdb.tmp").Check(testkit.Rows(test.result...))
default:
tk.MustExec(test.sql)
}
}
}

executeTests()

// Create the database again.
tk.MustExec("create database tmpdb")
executeTests()

// Create another table in the database and drop the database again.
tk.MustExec("create temporary table tmpdb.tmp2(id int)")
tk.MustExec("drop database tmpdb")
executeTests()
}

func (s *testSerialDBSuite) TestPlacementOnTemporaryTable(c *C) {
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test")
Expand Down
2 changes: 1 addition & 1 deletion executor/ddl.go
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ func (e *DDLExec) createSessionTemporaryTable(s *ast.CreateTableStmt) error {
return err
}

return e.tempTableDDL.CreateLocalTemporaryTable(dbInfo.Name, tbInfo)
return e.tempTableDDL.CreateLocalTemporaryTable(dbInfo, tbInfo)
}

func (e *DDLExec) executeCreateView(s *ast.CreateViewStmt) error {
Expand Down
47 changes: 24 additions & 23 deletions infoschema/infoschema.go
Original file line number Diff line number Diff line change
Expand Up @@ -457,20 +457,19 @@ func GetBundle(h InfoSchema, ids []int64) *placement.Bundle {
return &placement.Bundle{ID: placement.GroupID(id), Rules: newRules}
}

type schemaLocalTempSchemaTables struct {
tables map[string]table.Table
}

// LocalTemporaryTables store local temporary tables
type LocalTemporaryTables struct {
schemaMap map[string]*schemaLocalTempSchemaTables
// Local temporary tables can be accessed after the db is dropped, so there needs a way to retain the DBInfo.
// schemaTables.dbInfo will only be used when the db is dropped and it may be stale after the db is created again.
// But it's fine because we only need its name.
schemaMap map[string]*schemaTables
idx2table map[int64]table.Table
}

// NewLocalTemporaryTables creates a new NewLocalTemporaryTables object
func NewLocalTemporaryTables() *LocalTemporaryTables {
return &LocalTemporaryTables{
schemaMap: make(map[string]*schemaLocalTempSchemaTables),
schemaMap: make(map[string]*schemaTables),
idx2table: make(map[int64]table.Table),
}
}
Expand Down Expand Up @@ -498,8 +497,8 @@ func (is *LocalTemporaryTables) TableByID(id int64) (tbl table.Table, ok bool) {
}

// AddTable add a table
func (is *LocalTemporaryTables) AddTable(schema model.CIStr, tbl table.Table) error {
schemaTables := is.ensureSchema(schema)
func (is *LocalTemporaryTables) AddTable(db *model.DBInfo, tbl table.Table) error {
schemaTables := is.ensureSchema(db)

tblMeta := tbl.Meta()
if _, ok := schemaTables.tables[tblMeta.Name.L]; ok {
Expand Down Expand Up @@ -530,37 +529,40 @@ func (is *LocalTemporaryTables) RemoveTable(schema, table model.CIStr) (exist bo

delete(tbls.tables, table.L)
delete(is.idx2table, oldTable.Meta().ID)
if len(tbls.tables) == 0 {
delete(is.schemaMap, schema.L)
}
return true
}

// SchemaByTable get a table's schema name
func (is *LocalTemporaryTables) SchemaByTable(tableInfo *model.TableInfo) (string, bool) {
func (is *LocalTemporaryTables) SchemaByTable(tableInfo *model.TableInfo) (*model.DBInfo, bool) {
if tableInfo == nil {
return "", false
return nil, false
}

for schema, v := range is.schemaMap {
for _, v := range is.schemaMap {
if tbl, ok := v.tables[tableInfo.Name.L]; ok {
if tbl.Meta().ID == tableInfo.ID {
return schema, true
return v.dbInfo, true
}
}
}

return "", false
return nil, false
}

func (is *LocalTemporaryTables) ensureSchema(schema model.CIStr) *schemaLocalTempSchemaTables {
if tbls, ok := is.schemaMap[schema.L]; ok {
func (is *LocalTemporaryTables) ensureSchema(db *model.DBInfo) *schemaTables {
if tbls, ok := is.schemaMap[db.Name.L]; ok {
return tbls
}

tbls := &schemaLocalTempSchemaTables{tables: make(map[string]table.Table)}
is.schemaMap[schema.L] = tbls
tbls := &schemaTables{dbInfo: db, tables: make(map[string]table.Table)}
is.schemaMap[db.Name.L] = tbls
return tbls
}

func (is *LocalTemporaryTables) schemaTables(schema model.CIStr) *schemaLocalTempSchemaTables {
func (is *LocalTemporaryTables) schemaTables(schema model.CIStr) *schemaTables {
if is.schemaMap == nil {
return nil
}
Expand All @@ -574,8 +576,7 @@ func (is *LocalTemporaryTables) schemaTables(schema model.CIStr) *schemaLocalTem

// TemporaryTableAttachedInfoSchema implements InfoSchema
// Local temporary table has a loose relationship with database.
// So when a database is dropped, its temporary tables still exist and can be return by TableByName/TableByID.
// However SchemaByTable will return nil if database is dropped.
// So when a database is dropped, its temporary tables still exist and can be returned by TableByName/TableByID.
type TemporaryTableAttachedInfoSchema struct {
InfoSchema
LocalTemporaryTables *LocalTemporaryTables
Expand All @@ -599,14 +600,14 @@ func (ts *TemporaryTableAttachedInfoSchema) TableByID(id int64) (table.Table, bo
return ts.InfoSchema.TableByID(id)
}

// SchemaByTable implements InfoSchema.SchemaByTable
// SchemaByTable implements InfoSchema.SchemaByTable, it returns a stale DBInfo even if it's dropped.
func (ts *TemporaryTableAttachedInfoSchema) SchemaByTable(tableInfo *model.TableInfo) (*model.DBInfo, bool) {
if tableInfo == nil {
return nil, false
}

if schemaName, ok := ts.LocalTemporaryTables.SchemaByTable(tableInfo); ok {
return ts.SchemaByName(model.NewCIStr(schemaName))
if db, ok := ts.LocalTemporaryTables.SchemaByTable(tableInfo); ok {
return db, true
}

return ts.InfoSchema.SchemaByTable(tableInfo)
Expand Down
44 changes: 27 additions & 17 deletions infoschema/infoschema_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -474,15 +474,15 @@ func TestLocalTemporaryTables(t *testing.T) {
}
}

assertSchemaByTable := func(sc *infoschema.LocalTemporaryTables, schema model.CIStr, tb *model.TableInfo) {
assertSchemaByTable := func(sc *infoschema.LocalTemporaryTables, db *model.DBInfo, tb *model.TableInfo) {
got, ok := sc.SchemaByTable(tb)
if tb == nil {
require.True(t, schema.L == "")
require.Equal(t, "", got)
if db == nil {
require.Nil(t, got)
require.False(t, ok)
} else {
require.Equal(t, schema.L != "", ok)
require.Equal(t, got, schema.L)
require.NotNil(t, got)
require.Equal(t, db.Name.L, got.Name.L)
require.True(t, ok)
}
}

Expand Down Expand Up @@ -513,7 +513,7 @@ func TestLocalTemporaryTables(t *testing.T) {
}

for _, p := range prepareTables {
err = sc.AddTable(p.db.Name, p.tb)
err = sc.AddTable(p.db, p.tb)
require.NoError(t, err)
}

Expand Down Expand Up @@ -541,20 +541,20 @@ func TestLocalTemporaryTables(t *testing.T) {
)

assertTableByID(sc, p.tb.Meta().ID, p.db, p.tb)
assertSchemaByTable(sc, p.db.Name, p.tb.Meta())
assertSchemaByTable(sc, p.db, p.tb.Meta())
}

// test add dup table
err = sc.AddTable(db1.Name, tb11)
err = sc.AddTable(db1, tb11)
require.True(t, infoschema.ErrTableExists.Equal(err))
err = sc.AddTable(db1b.Name, tb15)
err = sc.AddTable(db1b, tb15)
require.True(t, infoschema.ErrTableExists.Equal(err))
err = sc.AddTable(db1b.Name, tb11)
err = sc.AddTable(db1b, tb11)
require.True(t, infoschema.ErrTableExists.Equal(err))
db1c := createNewSchemaInfo("db1")
err = sc.AddTable(db1c.Name, createNewTable(db1c.ID, "tb1", model.TempTableLocal))
err = sc.AddTable(db1c, createNewTable(db1c.ID, "tb1", model.TempTableLocal))
require.True(t, infoschema.ErrTableExists.Equal(err))
err = sc.AddTable(db1b.Name, tb11)
err = sc.AddTable(db1b, tb11)
require.True(t, infoschema.ErrTableExists.Equal(err))

// failed add has no effect
Expand Down Expand Up @@ -585,22 +585,23 @@ func TestLocalTemporaryTables(t *testing.T) {
}

// test non exist table schemaByTable
assertSchemaByTable(sc, model.NewCIStr(""), tb11.Meta())
assertSchemaByTable(sc, model.NewCIStr(""), tb22.Meta())
assertSchemaByTable(sc, model.NewCIStr(""), nil)
assertSchemaByTable(sc, nil, tb11.Meta())
assertSchemaByTable(sc, nil, tb22.Meta())
assertSchemaByTable(sc, nil, nil)

// test TemporaryTableAttachedInfoSchema
dbTest := createNewSchemaInfo("test")
tmpTbTestA := createNewTable(dbTest.ID, "tba", model.TempTableLocal)
normalTbTestA := createNewTable(dbTest.ID, "tba", model.TempTableNone)
normalTbTestB := createNewTable(dbTest.ID, "tbb", model.TempTableNone)
normalTbTestC := createNewTable(db1.ID, "tbc", model.TempTableNone)

is := &infoschema.TemporaryTableAttachedInfoSchema{
InfoSchema: infoschema.MockInfoSchema([]*model.TableInfo{normalTbTestA.Meta(), normalTbTestB.Meta()}),
LocalTemporaryTables: sc,
}

err = sc.AddTable(dbTest.Name, tmpTbTestA)
err = sc.AddTable(dbTest, tmpTbTestA)
require.NoError(t, err)

// test TableByName
Expand Down Expand Up @@ -641,7 +642,16 @@ func TestLocalTemporaryTables(t *testing.T) {
info, ok = is.SchemaByTable(tmpTbTestA.Meta())
require.True(t, ok)
require.Equal(t, dbTest.Name.L, info.Name.L)
// SchemaByTable also returns DBInfo when the schema is not in the infoSchema but the table is an existing tmp table.
info, ok = is.SchemaByTable(tb12.Meta())
require.True(t, ok)
require.Equal(t, db1.Name.L, info.Name.L)
// SchemaByTable returns nil when the schema is not in the infoSchema and the table is an non-existing normal table.
info, ok = is.SchemaByTable(normalTbTestC.Meta())
require.False(t, ok)
require.Nil(t, info)
// SchemaByTable returns nil when the schema is not in the infoSchema and the table is an non-existing tmp table.
info, ok = is.SchemaByTable(tb22.Meta())
require.False(t, ok)
require.Nil(t, info)
}
2 changes: 1 addition & 1 deletion planner/core/preprocess.go
Original file line number Diff line number Diff line change
Expand Up @@ -1406,7 +1406,7 @@ func (p *preprocessor) handleTableName(tn *ast.TableName) {
}

tableInfo := table.Meta()
dbInfo, _ := p.ensureInfoSchema().SchemaByName(tn.Schema)
dbInfo, _ := p.ensureInfoSchema().SchemaByTable(tableInfo)
// tableName should be checked as sequence object.
if p.flag&inSequenceFunction > 0 {
if !tableInfo.IsSequence() {
Expand Down
9 changes: 5 additions & 4 deletions table/temptable/ddl.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ import (

// TemporaryTableDDL is an interface providing ddl operations for temporary table
type TemporaryTableDDL interface {
CreateLocalTemporaryTable(schema model.CIStr, info *model.TableInfo) error
CreateLocalTemporaryTable(db *model.DBInfo, info *model.TableInfo) error
DropLocalTemporaryTable(schema model.CIStr, tblName model.CIStr) error
TruncateLocalTemporaryTable(schema model.CIStr, tblName model.CIStr) error
}
Expand All @@ -45,7 +45,7 @@ type temporaryTableDDL struct {
sctx sessionctx.Context
}

func (d *temporaryTableDDL) CreateLocalTemporaryTable(schema model.CIStr, info *model.TableInfo) error {
func (d *temporaryTableDDL) CreateLocalTemporaryTable(db *model.DBInfo, info *model.TableInfo) error {
if _, err := ensureSessionData(d.sctx); err != nil {
return err
}
Expand All @@ -55,7 +55,7 @@ func (d *temporaryTableDDL) CreateLocalTemporaryTable(schema model.CIStr, info *
return err
}

return ensureLocalTemporaryTables(d.sctx).AddTable(schema, tbl)
return ensureLocalTemporaryTables(d.sctx).AddTable(db, tbl)
}

func (d *temporaryTableDDL) DropLocalTemporaryTable(schema model.CIStr, tblName model.CIStr) error {
Expand All @@ -81,8 +81,9 @@ func (d *temporaryTableDDL) TruncateLocalTemporaryTable(schema model.CIStr, tblN
}

localTempTables := getLocalTemporaryTables(d.sctx)
db, _ := localTempTables.SchemaByTable(oldTblInfo)
localTempTables.RemoveTable(schema, tblName)
if err = localTempTables.AddTable(schema, newTbl); err != nil {
if err = localTempTables.AddTable(db, newTbl); err != nil {
return err
}

Expand Down
Loading