Skip to content
This repository has been archived by the owner on Oct 23, 2023. It is now read-only.

Support union and none type in flyteidl #401

Merged
merged 4 commits into from
May 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions clients/go/coreutils/extract_literal.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,15 @@ func ExtractFromLiteral(literal *core.Literal) (interface{}, error) {
return scalarValue.Generic, nil
case *core.Scalar_StructuredDataset:
return scalarValue.StructuredDataset.Uri, nil
case *core.Scalar_Union:
// extract the value of the union but not the actual union object
extractedVal, err := ExtractFromLiteral(scalarValue.Union.Value)
if err != nil {
return nil, err
}
return extractedVal, nil
case *core.Scalar_NoneType:
return nil, nil
default:
return nil, fmt.Errorf("unsupported literal scalar type %T", scalarValue)
}
Expand Down
40 changes: 39 additions & 1 deletion clients/go/coreutils/extract_literal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ func TestFetchLiteral(t *testing.T) {
assert.NoError(t, err)
assert.NotNil(t, p.GetScalar())
_, err = ExtractFromLiteral(p)
assert.NotNil(t, err)
assert.Nil(t, err)
})

t.Run("Generic", func(t *testing.T) {
Expand Down Expand Up @@ -199,4 +199,42 @@ func TestFetchLiteral(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, literalVal, extractedLiteralVal)
})

t.Run("Union", func(t *testing.T) {
literalVal := int64(1)
var literalType = &core.LiteralType{
Type: &core.LiteralType_UnionType{
UnionType: &core.UnionType{
Variants: []*core.LiteralType{
{Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}},
{Type: &core.LiteralType_Simple{Simple: core.SimpleType_FLOAT}},
},
},
},
}
lit, err := MakeLiteralForType(literalType, literalVal)
assert.NoError(t, err)
extractedLiteralVal, err := ExtractFromLiteral(lit)
assert.NoError(t, err)
assert.Equal(t, literalVal, extractedLiteralVal)
})

t.Run("Union with None", func(t *testing.T) {
var literalType = &core.LiteralType{
Type: &core.LiteralType_UnionType{
UnionType: &core.UnionType{
Variants: []*core.LiteralType{
{Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}},
{Type: &core.LiteralType_Simple{Simple: core.SimpleType_NONE}},
},
},
},
}
lit, err := MakeLiteralForType(literalType, nil)

assert.NoError(t, err)
extractedLiteralVal, err := ExtractFromLiteral(lit)
assert.NoError(t, err)
assert.Nil(t, extractedLiteralVal)
})
}
48 changes: 48 additions & 0 deletions clients/go/coreutils/literals.go
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,28 @@ func MakeDefaultLiteralForType(typ *core.LiteralType) (*core.Literal, error) {
return MakeLiteralForType(typ, nil)
case *core.LiteralType_Schema:
return MakeLiteralForType(typ, nil)
case *core.LiteralType_UnionType:
if len(t.UnionType.Variants) == 0 {
return nil, errors.Errorf("Union type must have at least one variant")
}
// For union types, we just return the default for the first variant
val, err := MakeDefaultLiteralForType(t.UnionType.Variants[0])
if err != nil {
return nil, errors.Errorf("Failed to create default literal for first union type variant [%v]", t.UnionType.Variants[0])
}
res := &core.Literal{
Value: &core.Literal_Scalar{
Scalar: &core.Scalar{
Value: &core.Scalar_Union{
Union: &core.Union{
Type: t.UnionType.Variants[0],
Value: val,
},
},
},
},
}
return res, nil
}

return nil, fmt.Errorf("failed to convert to a known Literal. Input Type [%v] not supported", typ.String())
Expand Down Expand Up @@ -588,6 +610,32 @@ func MakeLiteralForType(t *core.LiteralType, v interface{}) (*core.Literal, erro
}
return MakePrimitiveLiteral(newV)

case *core.LiteralType_UnionType:
// Try different types in the variants, return the first one matched
found := false
for _, subType := range newT.UnionType.Variants {
lv, err := MakeLiteralForType(subType, v)
if err == nil {
l = &core.Literal{
Value: &core.Literal_Scalar{
Scalar: &core.Scalar{
Value: &core.Scalar_Union{
Union: &core.Union{
Value: lv,
Type: subType,
},
},
},
},
}
found = true
break
}
}
if !found {
return nil, fmt.Errorf("incorrect union value [%s], supported values %+v", v, newT.UnionType.Variants)
}

default:
return nil, fmt.Errorf("unsupported type %s", t.String())
}
Expand Down
47 changes: 47 additions & 0 deletions clients/go/coreutils/literals_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,23 @@ func TestMakeDefaultLiteralForType(t *testing.T) {
Value: &core.Scalar_Primitive{Primitive: &core.Primitive{Value: &core.Primitive_StringValue{StringValue: "x"}}}}}}
assert.Equal(t, expected, l)
})

t.Run("union", func(t *testing.T) {
l, err := MakeDefaultLiteralForType(
&core.LiteralType{
Type: &core.LiteralType_UnionType{
UnionType: &core.UnionType{
Variants: []*core.LiteralType{
{Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}},
{Type: &core.LiteralType_Simple{Simple: core.SimpleType_FLOAT}},
},
},
},
},
)
assert.NoError(t, err)
assert.Equal(t, "*core.Union", reflect.TypeOf(l.GetScalar().GetUnion()).String())
})
}

func TestMustMakeDefaultLiteralForType(t *testing.T) {
Expand Down Expand Up @@ -715,4 +732,34 @@ func TestMakeLiteralForType(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, expectedVal, actualVal)
})

t.Run("Union", func(t *testing.T) {
var literalType = &core.LiteralType{
Type: &core.LiteralType_UnionType{
UnionType: &core.UnionType{
Variants: []*core.LiteralType{
{Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}},
{Type: &core.LiteralType_Simple{Simple: core.SimpleType_FLOAT}},
},
},
},
}
expectedLV := &core.Literal{Value: &core.Literal_Scalar{Scalar: &core.Scalar{
Value: &core.Scalar_Union{
Union: &core.Union{
Type: &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_FLOAT}},
Value: &core.Literal{Value: &core.Literal_Scalar{Scalar: &core.Scalar{
Value: &core.Scalar_Primitive{Primitive: &core.Primitive{Value: &core.Primitive_FloatValue{FloatValue: 0.1}}}}}},
},
},
}}}
lv, err := MakeLiteralForType(literalType, float64(0.1))
assert.NoError(t, err)
assert.Equal(t, expectedLV, lv)
expectedVal, err := ExtractFromLiteral(expectedLV)
assert.NoError(t, err)
actualVal, err := ExtractFromLiteral(lv)
assert.NoError(t, err)
assert.Equal(t, expectedVal, actualVal)
})
}
4 changes: 2 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,8 @@ require (
k8s.io/klog/v2 v2.5.0 // indirect
)

// These 2 versions were wrongly published.
// These 2 versions were wrongly published.
retract (
v1.4.0
v1.4.2
v1.4.0
)