Skip to content

Commit

Permalink
Merge pull request #1627 from dolthub/zachmu/signal
Browse files Browse the repository at this point in the history
Made signal statements work with user vars
  • Loading branch information
zachmu authored Mar 2, 2023
2 parents e08b02d + 9e8e516 commit 5d3d1d6
Show file tree
Hide file tree
Showing 7 changed files with 224 additions and 32 deletions.
25 changes: 14 additions & 11 deletions enginetest/memory_engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,23 +207,26 @@ func TestSingleScript(t *testing.T) {
t.Skip()
var scripts = []queries.ScriptTest{
{
Name: "create table as select distinct",
Name: "trigger with signal and user var",
SetUpScript: []string{
"CREATE TABLE t1 (a int, b varchar(10));",
"insert into t1 values (1, 'a'), (2, 'b'), (2, 'b'), (3, 'c');",
"create table t1 (id int primary key)",
"create table t2 (id int primary key)",
`
create trigger trigger1 before insert on t1
for each row
begin
set @myvar = concat('bro', 'ken');
SIGNAL SQLSTATE '45000' SET MESSAGE_TEXT = @myvar;
end;`,
},
Assertions: []queries.ScriptTestAssertion{
{
Query: "create table t2 as select distinct b, a from t1;",
Expected: []sql.Row{{types.OkResult{RowsAffected: 3}}},
Query: "insert into t1 values (1)",
ExpectedErrStr: "broken (errno 1644) (sqlstate 45000)",
},
{
Query: "select * from t2 order by a;",
Expected: []sql.Row{
{"a", 1},
{"b", 2},
{"c", 3},
},
Query: "select id from t1",
Expected: []sql.Row{},
},
},
},
Expand Down
30 changes: 29 additions & 1 deletion enginetest/queries/trigger_queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -1320,13 +1320,41 @@ begin
if
(select target_id from sn where id = NEW.upstream_edge_id) <> (select source_id from sn where id = NEW.downstream_edge_id)
then
SIGNAL SQLSTATE '45000' SET MESSAGE_TEXT = 'broken';
set @myvar = concat('bro', 'ken');
SIGNAL SQLSTATE '45000' SET MESSAGE_TEXT = @myvar;
end if;
end;`,
},
Assertions: []ScriptTestAssertion{
{
Query: "insert into rn values (1,1,1)",
},
{
Query: "select id from rn",
Expected: []sql.Row{{1}},
},
},
},
{
Name: "trigger with signal and user var",
SetUpScript: []string{
"create table t1 (id int primary key)",
"create table t2 (id int primary key)",
`
create trigger trigger1 before insert on t1
for each row
begin
set @myvar = concat('bro', 'ken');
SIGNAL SQLSTATE '45000' SET MESSAGE_TEXT = @myvar;
end;`,
},
Assertions: []ScriptTestAssertion{
{
Query: "insert into t1 values (1)",
ExpectedErrStr: "broken (errno 1644) (sqlstate 45000)",
},
{
Query: "select id from t1",
Expected: []sql.Row{},
},
},
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module github.com/dolthub/go-mysql-server
require (
github.com/cespare/xxhash v1.1.0
github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81
github.com/dolthub/vitess v0.0.0-20230223032306-95d4b04eabad
github.com/dolthub/vitess v0.0.0-20230301224006-436948ebe944
github.com/go-kit/kit v0.10.0
github.com/go-sql-driver/mysql v1.6.0
github.com/gocraft/dbr/v2 v2.7.2
Expand Down
6 changes: 2 additions & 4 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,8 @@ github.com/dolthub/jsonpath v0.0.0-20210609232853-d49537a30474 h1:xTrR+l5l+1Lfq0
github.com/dolthub/jsonpath v0.0.0-20210609232853-d49537a30474/go.mod h1:kMz7uXOXq4qRriCEyZ/LUeTqraLJCjf0WVZcUi6TxUY=
github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81 h1:7/v8q9XGFa6q5Ap4Z/OhNkAMBaK5YeuEzwJt+NZdhiE=
github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81/go.mod h1:siLfyv2c92W1eN/R4QqG/+RjjX5W2+gCTRjZxBjI3TY=
github.com/dolthub/vitess v0.0.0-20230216234925-189ffe819e56 h1:dHuKfUwaDUe847BVN3Wo+4GUGUNdlhuUif4RWkvG3Go=
github.com/dolthub/vitess v0.0.0-20230216234925-189ffe819e56/go.mod h1:oVFIBdqMFEkt4Xz2fzFJBNtzKhDEjwdCF0dzde39iKs=
github.com/dolthub/vitess v0.0.0-20230223032306-95d4b04eabad h1:9FPQtKoqyREEsHfGKNU2DImktOusXTXklLtvTxtIuZ0=
github.com/dolthub/vitess v0.0.0-20230223032306-95d4b04eabad/go.mod h1:oVFIBdqMFEkt4Xz2fzFJBNtzKhDEjwdCF0dzde39iKs=
github.com/dolthub/vitess v0.0.0-20230301224006-436948ebe944 h1:Rlccv6h7kWyJLxc8IiWwjLqwTlNkOvCFbtJzFu2kEcA=
github.com/dolthub/vitess v0.0.0-20230301224006-436948ebe944/go.mod h1:oVFIBdqMFEkt4Xz2fzFJBNtzKhDEjwdCF0dzde39iKs=
github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
github.com/eapache/go-resiliency v1.1.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs=
github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU=
Expand Down
43 changes: 30 additions & 13 deletions sql/parse/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -1608,24 +1608,41 @@ func convertSignal(ctx *sql.Context, s *sqlparser.Signal) (sql.Node, error) {
}

if si.ConditionItemName == plan.SignalConditionItemName_MysqlErrno {
number, err := strconv.ParseUint(string(info.Value.Val), 10, 16)
if err != nil || number == 0 {
// We use our own error instead
return nil, fmt.Errorf("invalid value '%s' for signal condition information item MYSQL_ERRNO", string(info.Value.Val))
switch v := info.Value.(type) {
case *sqlparser.SQLVal:
number, err := strconv.ParseUint(string(v.Val), 10, 16)
if err != nil || number == 0 {
// We use our own error instead
return nil, fmt.Errorf("invalid value '%s' for signal condition information item MYSQL_ERRNO", string(v.Val))
}
si.IntValue = int64(number)
default:
return nil, fmt.Errorf("invalid value '%v' for signal condition information item MYSQL_ERRNO", info.Value)
}
si.IntValue = int64(number)
} else if si.ConditionItemName == plan.SignalConditionItemName_MessageText {
val := string(info.Value.Val)
if len(val) > 128 {
return nil, fmt.Errorf("signal condition information item MESSAGE_TEXT has max length of 128")
switch v := info.Value.(type) {
case *sqlparser.SQLVal:
val := string(v.Val)
if len(val) > 128 {
return nil, fmt.Errorf("signal condition information item MESSAGE_TEXT has max length of 128")
}
si.StrValue = val
case *sqlparser.ColName:
si.ExprVal = expression.NewUnresolvedColumn(v.Name.String())
default:
return nil, fmt.Errorf("invalid value '%v' for signal condition information item MESSAGE_TEXT", info.Value)
}
si.StrValue = val
} else {
val := string(info.Value.Val)
if len(val) > 64 {
return nil, fmt.Errorf("signal condition information item %s has max length of 64", strings.ToUpper(string(si.ConditionItemName)))
switch v := info.Value.(type) {
case *sqlparser.SQLVal:
val := string(v.Val)
if len(val) > 64 {
return nil, fmt.Errorf("signal condition information item %s has max length of 64", strings.ToUpper(string(si.ConditionItemName)))
}
si.StrValue = val
default:
return nil, fmt.Errorf("invalid value '%v' for signal condition information item '%s''", info.Value, strings.ToUpper(string(si.ConditionItemName)))
}
si.StrValue = val
}
signalInfo[si.ConditionItemName] = si
}
Expand Down
39 changes: 39 additions & 0 deletions sql/parse/parse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5013,6 +5013,45 @@ func TestParseCreateTrigger(t *testing.T) {
time.Unix(0, 0),
"``@``",
),
`create trigger signal_with_user_var
BEFORE DELETE ON FOO FOR EACH ROW
BEGIN
SET @message_text = CONCAT('ouch', 'oof');
SIGNAL SQLSTATE '45000'
SET MESSAGE_TEXT = @message_text;
END`: plan.NewCreateTrigger(sql.UnresolvedDatabase(""),
"signal_with_user_var", "before", "delete",
nil,
plan.NewUnresolvedTable("FOO", ""),
plan.NewBeginEndBlock("", plan.NewBlock([]sql.Node{
plan.NewSet([]sql.Expression{
expression.NewSetField(
expression.NewUserVar("message_text"),
expression.NewUnresolvedFunction("concat", false, nil, expression.NewLiteral("ouch", types.LongText), expression.NewLiteral("oof", types.LongText)),
),
}),
plan.NewSignal("45000", map[plan.SignalConditionItemName]plan.SignalInfo{
plan.SignalConditionItemName_MessageText: {
ConditionItemName: plan.SignalConditionItemName_MessageText,
ExprVal: expression.NewUnresolvedColumn("@message_text"),
},
}),
},
)),
`create trigger signal_with_user_var
BEFORE DELETE ON FOO FOR EACH ROW
BEGIN
SET @message_text = CONCAT('ouch', 'oof');
SIGNAL SQLSTATE '45000'
SET MESSAGE_TEXT = @message_text;
END`,
`BEGIN
SET @message_text = CONCAT('ouch', 'oof');
SIGNAL SQLSTATE '45000'
SET MESSAGE_TEXT = @message_text;
END`,
time.Unix(0, 0),
"``@``"),
}

var queriesInOrder []string
Expand Down
111 changes: 109 additions & 2 deletions sql/plan/signal.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package plan

import (
"fmt"
"sort"
"strings"

"github.com/dolthub/vitess/go/mysql"
Expand Down Expand Up @@ -61,6 +62,7 @@ type SignalInfo struct {
ConditionItemName SignalConditionItemName
IntValue int64
StrValue string
ExprVal sql.Expression
}

// Signal represents the SIGNAL statement with a set SQLSTATE.
Expand All @@ -77,6 +79,7 @@ type SignalName struct {

var _ sql.Node = (*Signal)(nil)
var _ sql.Node = (*SignalName)(nil)
var _ sql.Expressioner = (*Signal)(nil)

// NewSignal returns a *Signal node.
func NewSignal(sqlstate string, info map[SignalConditionItemName]SignalInfo) *Signal {
Expand Down Expand Up @@ -129,6 +132,11 @@ func NewSignalName(name string, info map[SignalConditionItemName]SignalInfo) *Si

// Resolved implements the sql.Node interface.
func (s *Signal) Resolved() bool {
for _, e := range s.Expressions() {
if !e.Resolved() {
return false
}
}
return true
}

Expand All @@ -152,6 +160,26 @@ func (s *Signal) String() string {
return fmt.Sprintf("SIGNAL SQLSTATE '%s'%s", s.SqlStateValue, infoStr)
}

// DebugString implements the sql.DebugStringer interface.
func (s *Signal) DebugString() string {
infoStr := ""
if len(s.Info) > 0 {
infoStr = " SET"
i := 0
for _, k := range SignalItems {
// enforce deterministic ordering
if info, ok := s.Info[k]; ok {
if i > 0 {
infoStr += ","
}
infoStr += " " + info.DebugString()
i++
}
}
}
return fmt.Sprintf("SIGNAL SQLSTATE '%s'%s", s.SqlStateValue, infoStr)
}

// Schema implements the sql.Node interface.
func (s *Signal) Schema() sql.Schema {
return nil
Expand All @@ -167,6 +195,58 @@ func (s *Signal) WithChildren(children ...sql.Node) (sql.Node, error) {
return NillaryWithChildren(s, children...)
}

func (s *Signal) Expressions() []sql.Expression {
items := s.signalItemsWithExpressions()

var exprs []sql.Expression
for _, itemInfo := range items {
exprs = append(exprs, itemInfo.ExprVal)
}

return exprs
}

// signalItemsWithExpressions returns the subset of the Info map entries that have an expression value, sorted by
// item name
func (s *Signal) signalItemsWithExpressions() []SignalInfo {
var items []SignalInfo

for _, itemInfo := range s.Info {
if itemInfo.ExprVal != nil {
items = append(items, itemInfo)
}
}

// Very important to have a consistent sort order between here and the WithExpressions call
sort.Slice(items, func(i, j int) bool {
return items[i].ConditionItemName < items[j].ConditionItemName
})

return items
}

func (s Signal) WithExpressions(exprs ...sql.Expression) (sql.Node, error) {
itemsWithExprs := s.signalItemsWithExpressions()
if len(itemsWithExprs) != len(exprs) {
return nil, sql.ErrInvalidChildrenNumber.New(s, len(exprs), len(itemsWithExprs))
}

mapCopy := make(map[SignalConditionItemName]SignalInfo)
for k, v := range s.Info {
mapCopy[k] = v
}

for i := range exprs {
// transfer the expression to the new info map
newInfo := itemsWithExprs[i]
newInfo.ExprVal = exprs[i]
mapCopy[itemsWithExprs[i].ConditionItemName] = newInfo
}

s.Info = mapCopy
return &s, nil
}

// CheckPrivileges implements the interface sql.Node.
func (s *Signal) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool {
return true
Expand All @@ -188,10 +268,25 @@ func (s *Signal) RowIter(ctx *sql.Context, row sql.Row) (sql.RowIter, error) {
//TODO: implement warnings
return nil, fmt.Errorf("warnings not yet implemented")
} else {

messageItem := s.Info[SignalConditionItemName_MessageText]
strValue := messageItem.StrValue
if messageItem.ExprVal != nil {
exprResult, err := messageItem.ExprVal.Eval(ctx, nil)
if err != nil {
return nil, err
}
s, ok := exprResult.(string)
if !ok {
return nil, fmt.Errorf("message text expression did not evaluate to a string")
}
strValue = s
}

return nil, mysql.NewSQLError(
int(s.Info[SignalConditionItemName_MysqlErrno].IntValue),
s.SqlStateValue,
s.Info[SignalConditionItemName_MessageText].StrValue,
strValue,
)
}
}
Expand Down Expand Up @@ -245,7 +340,19 @@ func (s *SignalName) RowIter(ctx *sql.Context, row sql.Row) (sql.RowIter, error)

func (s SignalInfo) String() string {
itemName := strings.ToUpper(string(s.ConditionItemName))
if s.ConditionItemName == SignalConditionItemName_MysqlErrno {
if s.ExprVal != nil {
return fmt.Sprintf("%s = %s", itemName, s.ExprVal.String())
} else if s.ConditionItemName == SignalConditionItemName_MysqlErrno {
return fmt.Sprintf("%s = %d", itemName, s.IntValue)
}
return fmt.Sprintf("%s = %s", itemName, s.StrValue)
}

func (s SignalInfo) DebugString() string {
itemName := strings.ToUpper(string(s.ConditionItemName))
if s.ExprVal != nil {
return fmt.Sprintf("%s = %s", itemName, sql.DebugString(s.ExprVal))
} else if s.ConditionItemName == SignalConditionItemName_MysqlErrno {
return fmt.Sprintf("%s = %d", itemName, s.IntValue)
}
return fmt.Sprintf("%s = %s", itemName, s.StrValue)
Expand Down

0 comments on commit 5d3d1d6

Please sign in to comment.