Skip to content

Commit

Permalink
impl cast array
Browse files Browse the repository at this point in the history
  • Loading branch information
xiongjiwei committed Dec 22, 2022
1 parent 4adce4c commit 74a2864
Show file tree
Hide file tree
Showing 4 changed files with 151 additions and 7 deletions.
58 changes: 55 additions & 3 deletions expression/builtin_cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,7 @@ func (c *castAsArrayFunctionClass) verifyArgs(args []Expression) error {
}

if args[0].GetType().EvalType() != types.ETJson {
return types.ErrInvalidJSONData.GenWithStackByArgs("1", "cast_as_array")
return ErrInvalidTypeForJSON.GenWithStackByArgs(1, "cast_as_array")
}

return nil
Expand Down Expand Up @@ -467,9 +467,61 @@ func (b *castJSONAsArrayFunctionSig) evalJSON(row chunk.Row) (res types.BinaryJS
return types.BinaryJSON{}, false, ErrNotSupportedYet.GenWithStackByArgs("CAST-ing Non-JSON Array type to array")
}

// TODO: impl the cast(... as ... array) function
arrayVals := make([]any, 0, len(b.args))
f := convertJSON2Tp(b.tp.ArrayType())
originVal := b.ctx.GetSessionVars().StmtCtx.OverflowAsWarning
b.ctx.GetSessionVars().StmtCtx.OverflowAsWarning = false
defer func() {
b.ctx.GetSessionVars().StmtCtx.OverflowAsWarning = originVal
}()
for i := 0; i < val.GetElemCount(); i++ {
item, err := f(b, val.ArrayGetElem(i))
if err != nil {
return types.BinaryJSON{}, false, err
}
arrayVals = append(arrayVals, item)
}
return types.CreateBinaryJSON(arrayVals), false, nil
}

return types.BinaryJSON{}, false, nil
func convertJSON2Tp(tp *types.FieldType) func(*castJSONAsArrayFunctionSig, types.BinaryJSON) (any, error) {
switch tp.EvalType() {
case types.ETString:
return func(b *castJSONAsArrayFunctionSig, item types.BinaryJSON) (any, error) {
if item.TypeCode != types.JSONTypeCodeString {
return nil, errIncorrectArgs
}
return types.ProduceStrWithSpecifiedTp(string(item.GetString()), tp, b.ctx.GetSessionVars().StmtCtx, false)
}
default:
return func(b *castJSONAsArrayFunctionSig, item types.BinaryJSON) (any, error) {
switch tp.EvalType() {
case types.ETInt:
if item.TypeCode != types.JSONTypeCodeInt64 && item.TypeCode != types.JSONTypeCodeUint64 {
return nil, errIncorrectArgs
}
case types.ETReal, types.ETDecimal:
if item.TypeCode != types.JSONTypeCodeInt64 && item.TypeCode != types.JSONTypeCodeUint64 && item.TypeCode != types.JSONTypeCodeFloat64 {
return nil, errIncorrectArgs
}
case types.ETDatetime:
if item.TypeCode != types.JSONTypeCodeDatetime {
return nil, errIncorrectArgs
}
case types.ETTimestamp:
if item.TypeCode != types.JSONTypeCodeTimestamp {
return nil, errIncorrectArgs
}
case types.ETDuration:
if item.TypeCode != types.JSONTypeCodeDate {
return nil, errIncorrectArgs
}
}
d := types.NewJSONDatum(item)
to, err := d.ConvertTo(b.ctx.GetSessionVars().StmtCtx, tp)
return to.GetValue(), err
}
}
}

type castAsJSONFunctionClass struct {
Expand Down
70 changes: 70 additions & 0 deletions expression/builtin_cast_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1619,3 +1619,73 @@ func TestCastBinaryStringAsJSONSig(t *testing.T) {
require.Equal(t, tt.resultStr, res.String())
}
}

func TestCastArrayFunc(t *testing.T) {
ctx := createContext(t)
tbl := []struct {
input interface{}
expected interface{}
tp *types.FieldType
success bool
buildFuncSuccess bool
}{
{
[]interface{}{int64(-1), int64(2), int64(3)},
[]interface{}{int64(-1), int64(2), int64(3)},
types.NewFieldTypeBuilder().SetType(mysql.TypeLonglong).SetCharset(charset.CharsetBin).SetCollate(charset.CollationBin).SetArray(true).BuildP(),
true,
true,
},
{
[]interface{}{int64(-1), int64(2), int64(3)},
nil,
types.NewFieldTypeBuilder().SetType(mysql.TypeString).SetCharset(charset.CharsetUTF8MB4).SetCollate(charset.CollationUTF8MB4).SetArray(true).BuildP(),
false,
true,
},
{
[]interface{}{"1"},
nil,
types.NewFieldTypeBuilder().SetType(mysql.TypeLonglong).SetCharset(charset.CharsetBin).SetCollate(charset.CharsetBin).SetArray(true).BuildP(),
false,
true,
},
{
[]interface{}{"1", "2"},
nil,
types.NewFieldTypeBuilder().SetType(mysql.TypeDouble).SetCharset(charset.CharsetBin).SetCollate(charset.CharsetBin).SetArray(true).BuildP(),
false,
true,
},
{
[]interface{}{int64(-1), 2.1, int64(3)},
[]interface{}{int64(-1), 2.1, int64(3)},
types.NewFieldTypeBuilder().SetType(mysql.TypeDouble).SetCharset(charset.CharsetBin).SetCollate(charset.CharsetBin).SetArray(true).BuildP(),
false,
true,
},
}
for _, tt := range tbl {
f, err := BuildCastFunctionWithCheck(ctx, datumsToConstants(types.MakeDatums(types.CreateBinaryJSON(tt.input)))[0], tt.tp)
if tt.buildFuncSuccess {
require.NoError(t, err, tt.input)
} else {
require.Error(t, err, tt.input)
continue
}

val, isNull, err := f.EvalJSON(ctx, chunk.Row{})
if tt.success {
require.NoError(t, err, tt.input)
if tt.expected == nil {
require.True(t, isNull, tt.input)
} else {
j1 := types.CreateBinaryJSON(tt.expected)
cmp := types.CompareBinaryJSON(j1, val)
require.Equal(t, 0, cmp, tt.input)
}
} else {
require.Error(t, err, tt.input)
}
}
}
22 changes: 22 additions & 0 deletions expression/multi_valued_index_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,25 @@ func TestMultiValuedIndexDDL(t *testing.T) {
tk.MustExec("drop table t")
tk.MustExec("create table t(a json, b int, index idx3(b, (cast(a as signed array))));")
}

func TestMultiValuedIndexDML(t *testing.T) {
store := testkit.CreateMockStore(t)

tk := testkit.NewTestKit(t, store)
tk.MustExec("USE test;")
tk.MustExec("create table t(a json, index idx((cast(a as unsigned array))));")

tk.MustExec("insert into t values ('[1,2,3]')")
tk.MustGetErrCode("insert into t values ('[-1]')", errno.ErrDataOutOfRange)
tk.MustGetErrCode(`insert into t values ('["1"]')`, errno.ErrWrongArguments)
tk.MustGetErrCode(`insert into t values ('["a"]')`, errno.ErrWrongArguments)
tk.MustGetErrCode(`insert into t values ('[1.2]')`, errno.ErrWrongArguments)
tk.MustGetErrCode(`insert into t values ('[1.0]')`, errno.ErrWrongArguments)

tk.MustExec("set @@sql_mode=''")
tk.MustGetErrCode("insert into t values ('[-1]')", errno.ErrDataOutOfRange)
tk.MustGetErrCode(`insert into t values ('["1"]')`, errno.ErrWrongArguments)
tk.MustGetErrCode(`insert into t values ('["a"]')`, errno.ErrWrongArguments)
tk.MustGetErrCode(`insert into t values ('[1.2]')`, errno.ErrWrongArguments)
tk.MustGetErrCode(`insert into t values ('[1.0]')`, errno.ErrWrongArguments)
}
8 changes: 4 additions & 4 deletions planner/core/expression_rewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -1184,10 +1184,10 @@ func (er *expressionRewriter) Leave(originInNode ast.Node) (retNode ast.Node, ok
er.disableFoldCounter--
}
case *ast.FuncCastExpr:
if v.Tp.IsArray() && !er.b.allowBuildCastArray {
er.err = expression.ErrNotSupportedYet.GenWithStackByArgs("Use of CAST( .. AS .. ARRAY) outside of functional index in CREATE(non-SELECT)/ALTER TABLE or in general expressions")
return retNode, false
}
//if v.Tp.IsArray() && !er.b.allowBuildCastArray {
// er.err = expression.ErrNotSupportedYet.GenWithStackByArgs("Use of CAST( .. AS .. ARRAY) outside of functional index in CREATE(non-SELECT)/ALTER TABLE or in general expressions")
// return retNode, false
//}
arg := er.ctxStack[len(er.ctxStack)-1]
er.err = expression.CheckArgsNotMultiColumnRow(arg)
if er.err != nil {
Expand Down

0 comments on commit 74a2864

Please sign in to comment.