Skip to content

Commit

Permalink
Support for Parameterized type
Browse files Browse the repository at this point in the history
* Separate AnyType. This will be helpful in match method
* Added support for ParameterizedFixedChar/VarChar/FixedBinary/Decimal
* Added parser support for Parameterized/PrecisionTimestamp/PrecisionTimestampTz
  • Loading branch information
anshuldata committed Aug 30, 2024
1 parent 58e4ba0 commit ff46f15
Show file tree
Hide file tree
Showing 10 changed files with 517 additions and 46 deletions.
15 changes: 15 additions & 0 deletions extensions/simple_extension.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
package extensions

import (
"errors"
"fmt"
"reflect"
"strings"

substraitgo "github.com/substrait-io/substrait-go"
"github.com/substrait-io/substrait-go/types"
"github.com/substrait-io/substrait-go/types/parser"
)

Expand Down Expand Up @@ -57,6 +59,7 @@ type TypeVariation struct {

type Argument interface {
toTypeString() string
ArgType() (types.Type, error)
}

type EnumArg struct {
Expand All @@ -69,6 +72,10 @@ func (EnumArg) toTypeString() string {
return "req"
}

func (EnumArg) ArgType() (types.Type, error) {
return nil, errors.New("unimplemented")

Check warning on line 76 in extensions/simple_extension.go

View check run for this annotation

Codecov / codecov/patch

extensions/simple_extension.go#L75-L76

Added lines #L75 - L76 were not covered by tests
}

type ValueArg struct {
Name string `yaml:",omitempty"`
Description string `yaml:",omitempty"`
Expand All @@ -80,6 +87,10 @@ func (v ValueArg) toTypeString() string {
return v.Value.Expr.(*parser.Type).ShortType()
}

func (v ValueArg) ArgType() (types.Type, error) {
return v.Value.Expr.(*parser.Type).Type()

Check warning on line 91 in extensions/simple_extension.go

View check run for this annotation

Codecov / codecov/patch

extensions/simple_extension.go#L90-L91

Added lines #L90 - L91 were not covered by tests
}

type TypeArg struct {
Name string `yaml:",omitempty"`
Description string `yaml:",omitempty"`
Expand All @@ -88,6 +99,10 @@ type TypeArg struct {

func (TypeArg) toTypeString() string { return "type" }

func (TypeArg) ArgType() (types.Type, error) {
return nil, errors.New("unimplemented")

Check warning on line 103 in extensions/simple_extension.go

View check run for this annotation

Codecov / codecov/patch

extensions/simple_extension.go#L102-L103

Added lines #L102 - L103 were not covered by tests
}

type ArgumentList []Argument

func (a *ArgumentList) UnmarshalYAML(fn func(interface{}) error) error {
Expand Down
58 changes: 58 additions & 0 deletions types/any_type.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package types

import (
"fmt"

"github.com/substrait-io/substrait-go/proto"
)

// AnyType to represent AnyType, this type is to indicate "any" type of argument
// This type is not used in function invocation. It is only used in function definition
type AnyType struct {
Name string
Nullability Nullability
}

func (*AnyType) isRootRef() {}
func (m *AnyType) WithNullability(nullability Nullability) Type {
m.Nullability = nullability
return m

Check warning on line 19 in types/any_type.go

View check run for this annotation

Codecov / codecov/patch

types/any_type.go#L16-L19

Added lines #L16 - L19 were not covered by tests
}
func (m *AnyType) GetType() Type { return m }

Check warning on line 21 in types/any_type.go

View check run for this annotation

Codecov / codecov/patch

types/any_type.go#L21

Added line #L21 was not covered by tests
func (m *AnyType) GetNullability() Nullability {
return m.Nullability
}
func (*AnyType) GetTypeVariationReference() uint32 {
panic("not allowed")

Check warning on line 26 in types/any_type.go

View check run for this annotation

Codecov / codecov/patch

types/any_type.go#L25-L26

Added lines #L25 - L26 were not covered by tests
}
func (*AnyType) Equals(rhs Type) bool {

Check warning on line 28 in types/any_type.go

View check run for this annotation

Codecov / codecov/patch

types/any_type.go#L28

Added line #L28 was not covered by tests
// equal to every other type
return true

Check warning on line 30 in types/any_type.go

View check run for this annotation

Codecov / codecov/patch

types/any_type.go#L30

Added line #L30 was not covered by tests
}

func (*AnyType) ToProtoFuncArg() *proto.FunctionArgument {
panic("not allowed")

Check warning on line 34 in types/any_type.go

View check run for this annotation

Codecov / codecov/patch

types/any_type.go#L33-L34

Added lines #L33 - L34 were not covered by tests
}

func (*AnyType) ToProto() *proto.Type {
panic("not allowed")

Check warning on line 38 in types/any_type.go

View check run for this annotation

Codecov / codecov/patch

types/any_type.go#L37-L38

Added lines #L37 - L38 were not covered by tests
}

func (t *AnyType) ShortString() string { return t.Name }
func (t *AnyType) String() string {
return fmt.Sprintf("%s%s", t.Name, strNullable(t))
}

// Below methods are for parser Def interface

func (*AnyType) Optional() bool {
panic("not allowed")

Check warning on line 49 in types/any_type.go

View check run for this annotation

Codecov / codecov/patch

types/any_type.go#L48-L49

Added lines #L48 - L49 were not covered by tests
}

func (m *AnyType) ShortType() string {
return "any"
}

func (m *AnyType) Type() (Type, error) {
return m, nil

Check warning on line 57 in types/any_type.go

View check run for this annotation

Codecov / codecov/patch

types/any_type.go#L56-L57

Added lines #L56 - L57 were not covered by tests
}
33 changes: 33 additions & 0 deletions types/any_type_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package types_test

import (
"testing"

"github.com/stretchr/testify/require"
"github.com/substrait-io/substrait-go/types"
)

func TestAnyType(t *testing.T) {
for _, td := range []struct {
testName string
argName string
nullability types.Nullability
expectedString string
}{
{"any", "any", types.NullabilityNullable, "any?"},
{"anyrequired", "any", types.NullabilityRequired, "any"},
{"anyOtherName", "any1", types.NullabilityNullable, "any1?"},
{"T name", "T", types.NullabilityNullable, "T?"},
} {
t.Run(td.testName, func(t *testing.T) {
arg := &types.AnyType{
Name: td.argName,
Nullability: td.nullability,
}
require.Equal(t, td.expectedString, arg.String())
require.Equal(t, td.nullability, arg.GetNullability())
require.Equal(t, td.argName, arg.ShortString())
require.Equal(t, "any", arg.ShortType())
})
}
}
55 changes: 55 additions & 0 deletions types/parameterized_decimal_type.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package types

import (
"fmt"

"github.com/substrait-io/substrait-go/proto"
)

type ParameterizedDecimal struct {
Nullability Nullability
TypeVariationRef uint32
Precision IntegerParam
Scale IntegerParam
}

func (*ParameterizedDecimal) isRootRef() {}
func (m *ParameterizedDecimal) WithNullability(n Nullability) Type {
m.Nullability = n
return m

Check warning on line 19 in types/parameterized_decimal_type.go

View check run for this annotation

Codecov / codecov/patch

types/parameterized_decimal_type.go#L16-L19

Added lines #L16 - L19 were not covered by tests
}

func (m *ParameterizedDecimal) GetType() Type { return m }

Check warning on line 22 in types/parameterized_decimal_type.go

View check run for this annotation

Codecov / codecov/patch

types/parameterized_decimal_type.go#L22

Added line #L22 was not covered by tests
func (m *ParameterizedDecimal) GetNullability() Nullability { return m.Nullability }
func (m *ParameterizedDecimal) GetTypeVariationReference() uint32 {
return m.TypeVariationRef

Check warning on line 25 in types/parameterized_decimal_type.go

View check run for this annotation

Codecov / codecov/patch

types/parameterized_decimal_type.go#L24-L25

Added lines #L24 - L25 were not covered by tests
}
func (m *ParameterizedDecimal) Equals(rhs Type) bool {
if o, ok := rhs.(*ParameterizedDecimal); ok {
return *o == *m
}
return false

Check warning on line 31 in types/parameterized_decimal_type.go

View check run for this annotation

Codecov / codecov/patch

types/parameterized_decimal_type.go#L31

Added line #L31 was not covered by tests
}

func (*ParameterizedDecimal) ToProtoFuncArg() *proto.FunctionArgument {

Check warning on line 34 in types/parameterized_decimal_type.go

View check run for this annotation

Codecov / codecov/patch

types/parameterized_decimal_type.go#L34

Added line #L34 was not covered by tests
// parameterized type are never on wire so to proto is not supported
panic("not supported")

Check warning on line 36 in types/parameterized_decimal_type.go

View check run for this annotation

Codecov / codecov/patch

types/parameterized_decimal_type.go#L36

Added line #L36 was not covered by tests
}

func (m *ParameterizedDecimal) ShortString() string {
t := &DecimalType{}
return t.ShortString()
}

func (m *ParameterizedDecimal) String() string {
return fmt.Sprintf("%s%s%s", m.BaseString(), strNullable(m), m.ParameterString())
}

func (m *ParameterizedDecimal) ParameterString() string {
return fmt.Sprintf("<%s,%s>", m.Precision.String(), m.Scale.String())
}

func (m *ParameterizedDecimal) BaseString() string {
t := &DecimalType{}
return t.BaseString()
}
114 changes: 114 additions & 0 deletions types/parameterized_types.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
package types

import (
"fmt"

"github.com/substrait-io/substrait-go/proto"
)

type IntegerParam struct {
Name string
}

func (m IntegerParam) Equals(o IntegerParam) bool {
return m == o

Check warning on line 14 in types/parameterized_types.go

View check run for this annotation

Codecov / codecov/patch

types/parameterized_types.go#L13-L14

Added lines #L13 - L14 were not covered by tests
}

func (p IntegerParam) ToProto() *proto.ParameterizedType_IntegerParameter {
panic("not implemented")

Check warning on line 18 in types/parameterized_types.go

View check run for this annotation

Codecov / codecov/patch

types/parameterized_types.go#L17-L18

Added lines #L17 - L18 were not covered by tests
}

func (m *IntegerParam) String() string {
return m.Name
}

type ParameterizedSingleIntegerType interface {
Type
WithIntegerOption(param IntegerParam) ParameterizedSingleIntegerType
}

type ParameterizedTypeSingleIntegerParam[T VarCharType | FixedCharType | FixedBinaryType | PrecisionTimestampType | PrecisionTimestampTzType] struct {
Nullability Nullability
TypeVariationRef uint32
IntegerOption IntegerParam
}

func (m *ParameterizedTypeSingleIntegerParam[T]) WithIntegerOption(integerOption IntegerParam) ParameterizedSingleIntegerType {
m.IntegerOption = integerOption
return m
}

func (*ParameterizedTypeSingleIntegerParam[T]) isRootRef() {}

Check warning on line 41 in types/parameterized_types.go

View check run for this annotation

Codecov / codecov/patch

types/parameterized_types.go#L41

Added line #L41 was not covered by tests
func (m *ParameterizedTypeSingleIntegerParam[T]) WithNullability(n Nullability) Type {
m.Nullability = n
return m
}

func (m *ParameterizedTypeSingleIntegerParam[T]) GetType() Type { return m }

Check warning on line 47 in types/parameterized_types.go

View check run for this annotation

Codecov / codecov/patch

types/parameterized_types.go#L47

Added line #L47 was not covered by tests
func (m *ParameterizedTypeSingleIntegerParam[T]) GetNullability() Nullability { return m.Nullability }
func (m *ParameterizedTypeSingleIntegerParam[T]) GetTypeVariationReference() uint32 {
return m.TypeVariationRef

Check warning on line 50 in types/parameterized_types.go

View check run for this annotation

Codecov / codecov/patch

types/parameterized_types.go#L49-L50

Added lines #L49 - L50 were not covered by tests
}
func (m *ParameterizedTypeSingleIntegerParam[T]) Equals(rhs Type) bool {
if o, ok := rhs.(*ParameterizedTypeSingleIntegerParam[T]); ok {
return *o == *m
}
return false

Check warning on line 56 in types/parameterized_types.go

View check run for this annotation

Codecov / codecov/patch

types/parameterized_types.go#L56

Added line #L56 was not covered by tests
}

func (*ParameterizedTypeSingleIntegerParam[T]) ToProtoFuncArg() *proto.FunctionArgument {

Check warning on line 59 in types/parameterized_types.go

View check run for this annotation

Codecov / codecov/patch

types/parameterized_types.go#L59

Added line #L59 was not covered by tests
// parameterized type are never on wire so to proto is not supported
panic("not supported")

Check warning on line 61 in types/parameterized_types.go

View check run for this annotation

Codecov / codecov/patch

types/parameterized_types.go#L61

Added line #L61 was not covered by tests
}

func (m *ParameterizedTypeSingleIntegerParam[T]) ShortString() string {
switch any(m).(type) {
case *ParameterizedVarCharType:
t := &VarCharType{}
return t.ShortString()
case *ParameterizedFixedCharType:
t := &FixedCharType{}
return t.ShortString()
case *ParameterizedFixedBinaryType:
t := &FixedBinaryType{}
return t.ShortString()
case *ParameterizedPrecisionTimestampType:
t := &PrecisionTimestampType{}
return t.ShortString()
case *ParameterizedPrecisionTimestampTzType:
t := &PrecisionTimestampTzType{}
return t.ShortString()
default:
panic("unknown type")

Check warning on line 82 in types/parameterized_types.go

View check run for this annotation

Codecov / codecov/patch

types/parameterized_types.go#L81-L82

Added lines #L81 - L82 were not covered by tests
}
}

func (m *ParameterizedTypeSingleIntegerParam[T]) String() string {
return fmt.Sprintf("%s%s%s", m.BaseString(), strNullable(m), m.ParameterString())
}

func (m *ParameterizedTypeSingleIntegerParam[T]) ParameterString() string {
return fmt.Sprintf("<%s>", m.IntegerOption.String())
}

func (m *ParameterizedTypeSingleIntegerParam[T]) BaseString() string {
switch any(m).(type) {
case *ParameterizedVarCharType:
t := &VarCharType{}
return t.BaseString()
case *ParameterizedFixedCharType:
t := &FixedCharType{}
return t.BaseString()
case *ParameterizedFixedBinaryType:
t := &FixedBinaryType{}
return t.BaseString()
case *ParameterizedPrecisionTimestampType:
t := &PrecisionTimestampType{}
return t.BaseString()
case *ParameterizedPrecisionTimestampTzType:
t := &PrecisionTimestampTzType{}
return t.BaseString()
default:
panic("unknown type")

Check warning on line 112 in types/parameterized_types.go

View check run for this annotation

Codecov / codecov/patch

types/parameterized_types.go#L111-L112

Added lines #L111 - L112 were not covered by tests
}
}
66 changes: 66 additions & 0 deletions types/parameterized_types_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package types_test

import (
"testing"

"github.com/stretchr/testify/require"
"github.com/substrait-io/substrait-go/types"
)

func TestParameterizedVarCharType(t *testing.T) {
for _, td := range []struct {
name string
typ types.ParameterizedSingleIntegerType
nullability types.Nullability
integerOption types.IntegerParam
expectedString string
expectedBaseString string
expectedShortString string
}{
{"nullable varchar", &types.ParameterizedVarCharType{}, types.NullabilityNullable, types.IntegerParam{Name: "L1"}, "varchar?<L1>", "varchar", "vchar"},
{"non nullable varchar", &types.ParameterizedVarCharType{}, types.NullabilityRequired, types.IntegerParam{Name: "L1"}, "varchar<L1>", "varchar", "vchar"},
{"nullable fixChar", &types.ParameterizedFixedCharType{}, types.NullabilityNullable, types.IntegerParam{Name: "L1"}, "char?<L1>", "char", "fchar"},
{"non nullable fixChar", &types.ParameterizedFixedCharType{}, types.NullabilityRequired, types.IntegerParam{Name: "L1"}, "char<L1>", "char", "fchar"},
{"nullable fixBinary", &types.ParameterizedFixedBinaryType{}, types.NullabilityNullable, types.IntegerParam{Name: "L1"}, "fixedbinary?<L1>", "fixedbinary", "fbin"},
{"non nullable fixBinary", &types.ParameterizedFixedBinaryType{}, types.NullabilityRequired, types.IntegerParam{Name: "L1"}, "fixedbinary<L1>", "fixedbinary", "fbin"},
{"nullable precisionTimeStamp", &types.ParameterizedPrecisionTimestampType{}, types.NullabilityNullable, types.IntegerParam{Name: "L1"}, "precision_timestamp?<L1>", "precision_timestamp", "prets"},
{"non nullable precisionTimeStamp", &types.ParameterizedPrecisionTimestampType{}, types.NullabilityRequired, types.IntegerParam{Name: "L1"}, "precision_timestamp<L1>", "precision_timestamp", "prets"},
{"nullable precisionTimeStampTz", &types.ParameterizedPrecisionTimestampTzType{}, types.NullabilityNullable, types.IntegerParam{Name: "L1"}, "precision_timestamp_tz?<L1>", "precision_timestamp_tz", "pretstz"},
{"non nullable precisionTimeStampTz", &types.ParameterizedPrecisionTimestampTzType{}, types.NullabilityRequired, types.IntegerParam{Name: "L1"}, "precision_timestamp_tz<L1>", "precision_timestamp_tz", "pretstz"},
} {
t.Run(td.name, func(t *testing.T) {
pt := td.typ.WithIntegerOption(td.integerOption).WithNullability(td.nullability)
require.Equal(t, td.expectedString, pt.String())
parameterizeType, ok := pt.(types.ParameterizedType)
require.True(t, ok)
require.Equal(t, td.expectedBaseString, parameterizeType.BaseString())
require.Equal(t, td.expectedShortString, pt.ShortString())
require.True(t, pt.Equals(pt))
})
}
}

func TestParameterizedDecimalType(t *testing.T) {
for _, td := range []struct {
name string
precision string
scale string
nullability types.Nullability
expectedString string
expectedBaseString string
expectedShortString string
}{
{"nullable decimal", "P", "S", types.NullabilityNullable, "decimal?<P,S>", "decimal", "dec"},
{"non nullable decimal", "P", "S", types.NullabilityRequired, "decimal<P,S>", "decimal", "dec"},
} {
t.Run(td.name, func(t *testing.T) {
precision := types.IntegerParam{Name: td.precision}
scale := types.IntegerParam{Name: td.scale}
pt := &types.ParameterizedDecimalType{Precision: precision, Scale: scale, Nullability: td.nullability}
require.Equal(t, td.expectedString, pt.String())
require.Equal(t, td.expectedBaseString, pt.BaseString())
require.Equal(t, td.expectedShortString, pt.ShortString())
require.True(t, pt.Equals(pt))
})
}
}
Loading

0 comments on commit ff46f15

Please sign in to comment.