Skip to content

Commit

Permalink
support require clauses in grant stmt (#666) (#698)
Browse files Browse the repository at this point in the history
  • Loading branch information
lysu authored Dec 24, 2019
1 parent 7df8c2c commit 93f4d5e
Show file tree
Hide file tree
Showing 5 changed files with 3,644 additions and 3,616 deletions.
52 changes: 30 additions & 22 deletions ast/misc.go
Original file line number Diff line number Diff line change
Expand Up @@ -976,12 +976,12 @@ const (
Subject
)

type TslOption struct {
type TLSOption struct {
Type int
Value string
}

func (t *TslOption) Restore(ctx *RestoreCtx) error {
func (t *TLSOption) Restore(ctx *RestoreCtx) error {
switch t.Type {
case TslNone:
ctx.WriteKeyWord("NONE")
Expand All @@ -996,10 +996,10 @@ func (t *TslOption) Restore(ctx *RestoreCtx) error {
ctx.WriteKeyWord("ISSUER ")
ctx.WriteString(t.Value)
case Subject:
ctx.WriteKeyWord("CIPHER")
ctx.WriteKeyWord("SUBJECT ")
ctx.WriteString(t.Value)
default:
return errors.Errorf("Unsupported TslOption.Type %d", t.Type)
return errors.Errorf("Unsupported TLSOption.Type %d", t.Type)
}
return nil
}
Expand Down Expand Up @@ -1077,7 +1077,7 @@ type CreateUserStmt struct {
IsCreateRole bool
IfNotExists bool
Specs []*UserSpec
TslOptions []*TslOption
TLSOptions []*TLSOption
ResourceOptions []*ResourceOption
PasswordOrLockOptions []*PasswordOrLockOption
}
Expand All @@ -1101,19 +1101,16 @@ func (n *CreateUserStmt) Restore(ctx *RestoreCtx) error {
}
}

tslOptionLen := len(n.TslOptions)

if tslOptionLen != 0 {
if len(n.TLSOptions) != 0 {
ctx.WriteKeyWord(" REQUIRE ")
}

// Restore `tslOptions` reversely to keep order the same with original sql
for i := tslOptionLen; i > 0; i-- {
if i != tslOptionLen {
for i, option := range n.TLSOptions {
if i != 0 {
ctx.WriteKeyWord(" AND ")
}
if err := n.TslOptions[i-1].Restore(ctx); err != nil {
return errors.Annotatef(err, "An error occurred while restore CreateUserStmt.TslOptions[%d]", i)
if err := option.Restore(ctx); err != nil {
return errors.Annotatef(err, "An error occurred while restore CreateUserStmt.TLSOptions[%d]", i)
}
}

Expand Down Expand Up @@ -1166,7 +1163,7 @@ type AlterUserStmt struct {
IfExists bool
CurrentAuth *AuthOption
Specs []*UserSpec
TslOptions []*TslOption
TLSOptions []*TLSOption
ResourceOptions []*ResourceOption
PasswordOrLockOptions []*PasswordOrLockOption
}
Expand All @@ -1193,19 +1190,16 @@ func (n *AlterUserStmt) Restore(ctx *RestoreCtx) error {
}
}

tslOptionLen := len(n.TslOptions)

if tslOptionLen != 0 {
if len(n.TLSOptions) != 0 {
ctx.WriteKeyWord(" REQUIRE ")
}

// Restore `tslOptions` reversely to keep order the same with original sql
for i := tslOptionLen; i > 0; i-- {
if i != tslOptionLen {
for i, option := range n.TLSOptions {
if i != 0 {
ctx.WriteKeyWord(" AND ")
}
if err := n.TslOptions[i-1].Restore(ctx); err != nil {
return errors.Annotatef(err, "An error occurred while restore AlterUserStmt.TslOptions[%d]", i)
if err := option.Restore(ctx); err != nil {
return errors.Annotatef(err, "An error occurred while restore AlterUserStmt.TLSOptions[%d]", i)
}
}

Expand Down Expand Up @@ -1869,6 +1863,7 @@ type GrantStmt struct {
ObjectType ObjectTypeType
Level *GrantLevel
Users []*UserSpec
TLSOptions []*TLSOption
WithGrant bool
}

Expand Down Expand Up @@ -1904,6 +1899,19 @@ func (n *GrantStmt) Restore(ctx *RestoreCtx) error {
return errors.Annotatef(err, "An error occurred while restore GrantStmt.Users[%d]", i)
}
}
if n.TLSOptions != nil {
if len(n.TLSOptions) != 0 {
ctx.WriteKeyWord(" REQUIRE ")
}
for i, option := range n.TLSOptions {
if i != 0 {
ctx.WriteKeyWord(" AND ")
}
if err := option.Restore(ctx); err != nil {
return errors.Annotatef(err, "An error occurred while restore GrantStmt.TLSOptions[%d]", i)
}
}
}
if n.WithGrant {
ctx.WriteKeyWord(" WITH GRANT OPTION")
}
Expand Down
2 changes: 2 additions & 0 deletions mysql/const.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,8 @@ const (
const (
// SystemDB is the name of system database.
SystemDB = "mysql"
// GlobalPrivTable is the table in system db contains global scope privilege info.
GlobalPrivTable = "global_priv"
// UserTable is the table in system db contains user info.
UserTable = "User"
// DBTable is the table in system db contains db scope privilege info.
Expand Down
Loading

0 comments on commit 93f4d5e

Please sign in to comment.