Skip to content

Commit

Permalink
SNOW-1524314 Hide structured types behind context (#1175)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-pfus authored Jul 15, 2024
1 parent f2164e5 commit 9730225
Show file tree
Hide file tree
Showing 8 changed files with 249 additions and 75 deletions.
33 changes: 27 additions & 6 deletions converter.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ const numberDefaultPrecision = 38

type timezoneType int

var errNativeArrowWithoutProperContext = errors.New("structured types must be enabled to use with native arrow")

const (
// TimestampNTZType denotes a NTZ timezoneType for array binds
TimestampNTZType timezoneType = iota
Expand Down Expand Up @@ -134,6 +136,7 @@ func goTypeToSnowflake(v driver.Value, tsmode snowflakeType) snowflakeType {

// snowflakeTypeToGo translates Snowflake data type to Go data type.
func snowflakeTypeToGo(ctx context.Context, dbtype snowflakeType, scale int64, fields []fieldMetadata) reflect.Type {
structuredTypesEnabled := structuredTypesEnabled(ctx)
switch dbtype {
case fixedType:
if scale == 0 {
Expand All @@ -151,12 +154,12 @@ func snowflakeTypeToGo(ctx context.Context, dbtype snowflakeType, scale int64, f
case booleanType:
return reflect.TypeOf(true)
case objectType:
if len(fields) > 0 {
if len(fields) > 0 && structuredTypesEnabled {
return reflect.TypeOf(ObjectType{})
}
return reflect.TypeOf("")
case arrayType:
if len(fields) == 0 {
if len(fields) == 0 || !structuredTypesEnabled {
return reflect.TypeOf("")
}
if len(fields) != 1 {
Expand Down Expand Up @@ -188,6 +191,9 @@ func snowflakeTypeToGo(ctx context.Context, dbtype snowflakeType, scale int64, f
}
return nil
case mapType:
if !structuredTypesEnabled {
return reflect.TypeOf("")
}
switch getSnowflakeType(fields[0].Type) {
case textType:
return snowflakeTypeToGoForMaps[string](ctx, fields[1])
Expand Down Expand Up @@ -898,10 +904,11 @@ func stringToValue(ctx context.Context, dest *driver.Value, srcColumnMeta execRe
*dest = nil
return nil
}
structuredTypesEnabled := structuredTypesEnabled(ctx)
logger.Debugf("snowflake data type: %v, raw value: %v", srcColumnMeta.Type, *srcValue)
switch srcColumnMeta.Type {
case "object":
if len(srcColumnMeta.Fields) == 0 {
if len(srcColumnMeta.Fields) == 0 || !structuredTypesEnabled {
// semistructured type without schema
*dest = *srcValue
return nil
Expand Down Expand Up @@ -991,7 +998,7 @@ func stringToValue(ctx context.Context, dest *driver.Value, srcColumnMeta execRe
*dest = b
return nil
case "array":
if len(srcColumnMeta.Fields) == 0 {
if len(srcColumnMeta.Fields) == 0 || !structuredTypesEnabled {
*dest = *srcValue
return nil
}
Expand Down Expand Up @@ -1019,6 +1026,10 @@ func stringToValue(ctx context.Context, dest *driver.Value, srcColumnMeta execRe
}

func jsonToMap(ctx context.Context, keyMetadata, valueMetadata fieldMetadata, srcValue string, params map[string]*string) (snowflakeValue, error) {
structuredTypesEnabled := structuredTypesEnabled(ctx)
if !structuredTypesEnabled {
return srcValue, nil
}
switch keyMetadata.Type {
case "text":
var m map[string]any
Expand Down Expand Up @@ -1459,6 +1470,7 @@ func arrowToValues(
}

func arrowToValue(ctx context.Context, rowIdx int, srcColumnMeta fieldMetadata, srcValue arrow.Array, loc *time.Location, higherPrecision bool, params map[string]*string, snowflakeType snowflakeType) (snowflakeValue, error) {
structuredTypesEnabled := structuredTypesEnabled(ctx)
switch snowflakeType {
case fixedType:
// Snowflake data types that are fixed-point numbers will fall into this category
Expand Down Expand Up @@ -1489,7 +1501,7 @@ func arrowToValue(ctx context.Context, rowIdx int, srcColumnMeta fieldMetadata,
}
return nil, nil
case arrayType:
if len(srcColumnMeta.Fields) == 0 {
if len(srcColumnMeta.Fields) == 0 || !structuredTypesEnabled {
// semistructured type without schema
strings := srcValue.(*array.String)
if !srcValue.IsNull(rowIdx) {
Expand All @@ -1511,9 +1523,12 @@ func arrowToValue(ctx context.Context, rowIdx int, srcColumnMeta fieldMetadata,
}
return nil, nil
}
if !structuredTypesEnabled {
return nil, errNativeArrowWithoutProperContext
}
return buildListFromNativeArrow(ctx, rowIdx, srcColumnMeta.Fields[0], srcValue, loc, higherPrecision, params)
case objectType:
if len(srcColumnMeta.Fields) == 0 {
if len(srcColumnMeta.Fields) == 0 || !structuredTypesEnabled {
// semistructured type without schema
strings := srcValue.(*array.String)
if !srcValue.IsNull(rowIdx) {
Expand All @@ -1536,6 +1551,9 @@ func arrowToValue(ctx context.Context, rowIdx int, srcColumnMeta fieldMetadata,
return nil, nil
}
// structured objects as native arrow
if !structuredTypesEnabled {
return nil, errNativeArrowWithoutProperContext
}
if srcValue.IsNull(rowIdx) {
return nil, nil
}
Expand All @@ -1553,6 +1571,9 @@ func arrowToValue(ctx context.Context, rowIdx int, srcColumnMeta fieldMetadata,
}
} else {
// structured map as native arrow
if !structuredTypesEnabled {
return nil, errNativeArrowWithoutProperContext
}
return buildMapFromNativeArrow(ctx, rowIdx, srcColumnMeta.Fields[0], srcColumnMeta.Fields[1], srcValue, loc, higherPrecision, params)
}
case binaryType:
Expand Down
37 changes: 22 additions & 15 deletions converter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,26 +156,33 @@ func TestSnowflakeTypeToGo(t *testing.T) {
{in: timestampNtzType, scale: 0, out: reflect.TypeOf(time.Now()), ctx: context.Background()},
{in: timestampTzType, scale: 0, out: reflect.TypeOf(time.Now()), ctx: context.Background()},
{in: objectType, scale: 0, out: reflect.TypeOf(""), ctx: context.Background()},
{in: objectType, scale: 0, fields: []fieldMetadata{{}}, out: reflect.TypeOf(ObjectType{}), ctx: context.Background()},
{in: objectType, scale: 0, fields: []fieldMetadata{}, out: reflect.TypeOf(""), ctx: context.Background()},
{in: objectType, scale: 0, fields: []fieldMetadata{{}}, out: reflect.TypeOf(""), ctx: context.Background()},
{in: objectType, scale: 0, fields: []fieldMetadata{{}}, out: reflect.TypeOf(ObjectType{}), ctx: WithStructuredTypesEnabled(context.Background())},
{in: variantType, scale: 0, out: reflect.TypeOf(""), ctx: context.Background()},
{in: arrayType, scale: 0, out: reflect.TypeOf(""), ctx: context.Background()},
{in: binaryType, scale: 0, out: reflect.TypeOf([]byte{}), ctx: context.Background()},
{in: booleanType, scale: 0, out: reflect.TypeOf(true), ctx: context.Background()},
{in: sliceType, scale: 0, out: reflect.TypeOf(""), ctx: context.Background()},
{in: arrayType, scale: 0, fields: []fieldMetadata{{Type: "fixed", Scale: 0}}, out: reflect.TypeOf([]int64{}), ctx: context.Background()},
{in: arrayType, scale: 0, fields: []fieldMetadata{{Type: "fixed", Scale: 1}}, out: reflect.TypeOf([]float64{}), ctx: context.Background()},
{in: arrayType, scale: 0, fields: []fieldMetadata{{Type: "fixed", Scale: 0}}, out: reflect.TypeOf([]*big.Int{}), ctx: WithHigherPrecision(context.Background())},
{in: arrayType, scale: 0, fields: []fieldMetadata{{Type: "fixed", Scale: 1}}, out: reflect.TypeOf([]*big.Float{}), ctx: WithHigherPrecision(context.Background())},
{in: arrayType, scale: 0, fields: []fieldMetadata{{Type: "real", Scale: 1}}, out: reflect.TypeOf([]float64{}), ctx: context.Background()},
{in: arrayType, scale: 0, fields: []fieldMetadata{{Type: "text"}}, out: reflect.TypeOf([]string{}), ctx: context.Background()},
{in: arrayType, scale: 0, fields: []fieldMetadata{{Type: "date"}}, out: reflect.TypeOf([]time.Time{}), ctx: context.Background()},
{in: arrayType, scale: 0, fields: []fieldMetadata{{Type: "time"}}, out: reflect.TypeOf([]time.Time{}), ctx: context.Background()},
{in: arrayType, scale: 0, fields: []fieldMetadata{{Type: "timestamp_ntz"}}, out: reflect.TypeOf([]time.Time{}), ctx: context.Background()},
{in: arrayType, scale: 0, fields: []fieldMetadata{{Type: "timestamp_ltz"}}, out: reflect.TypeOf([]time.Time{}), ctx: context.Background()},
{in: arrayType, scale: 0, fields: []fieldMetadata{{Type: "timestamp_tz"}}, out: reflect.TypeOf([]time.Time{}), ctx: context.Background()},
{in: arrayType, scale: 0, fields: []fieldMetadata{{Type: "boolean"}}, out: reflect.TypeOf([]bool{}), ctx: context.Background()},
{in: arrayType, scale: 0, fields: []fieldMetadata{{Type: "binary"}}, out: reflect.TypeOf([][]byte{}), ctx: context.Background()},
{in: arrayType, scale: 0, fields: []fieldMetadata{{Type: "object"}}, out: reflect.TypeOf([]ObjectType{}), ctx: context.Background()},
{in: arrayType, scale: 0, fields: []fieldMetadata{{Type: "fixed", Scale: 0}}, out: reflect.TypeOf(""), ctx: context.Background()},
{in: arrayType, scale: 0, fields: []fieldMetadata{{Type: "fixed", Scale: 0}}, out: reflect.TypeOf([]int64{}), ctx: WithStructuredTypesEnabled(context.Background())},
{in: arrayType, scale: 0, fields: []fieldMetadata{{Type: "fixed", Scale: 1}}, out: reflect.TypeOf([]float64{}), ctx: WithStructuredTypesEnabled(context.Background())},
{in: arrayType, scale: 0, fields: []fieldMetadata{{Type: "fixed", Scale: 0}}, out: reflect.TypeOf([]*big.Int{}), ctx: WithStructuredTypesEnabled(WithHigherPrecision(context.Background()))},
{in: arrayType, scale: 0, fields: []fieldMetadata{{Type: "fixed", Scale: 1}}, out: reflect.TypeOf([]*big.Float{}), ctx: WithStructuredTypesEnabled(WithHigherPrecision(context.Background()))},
{in: arrayType, scale: 0, fields: []fieldMetadata{{Type: "real", Scale: 1}}, out: reflect.TypeOf([]float64{}), ctx: WithStructuredTypesEnabled(context.Background())},
{in: arrayType, scale: 0, fields: []fieldMetadata{{Type: "text"}}, out: reflect.TypeOf([]string{}), ctx: WithStructuredTypesEnabled(context.Background())},
{in: arrayType, scale: 0, fields: []fieldMetadata{{Type: "date"}}, out: reflect.TypeOf([]time.Time{}), ctx: WithStructuredTypesEnabled(context.Background())},
{in: arrayType, scale: 0, fields: []fieldMetadata{{Type: "time"}}, out: reflect.TypeOf([]time.Time{}), ctx: WithStructuredTypesEnabled(context.Background())},
{in: arrayType, scale: 0, fields: []fieldMetadata{{Type: "timestamp_ntz"}}, out: reflect.TypeOf([]time.Time{}), ctx: WithStructuredTypesEnabled(context.Background())},
{in: arrayType, scale: 0, fields: []fieldMetadata{{Type: "timestamp_ltz"}}, out: reflect.TypeOf([]time.Time{}), ctx: WithStructuredTypesEnabled(context.Background())},
{in: arrayType, scale: 0, fields: []fieldMetadata{{Type: "timestamp_tz"}}, out: reflect.TypeOf([]time.Time{}), ctx: WithStructuredTypesEnabled(context.Background())},
{in: arrayType, scale: 0, fields: []fieldMetadata{{Type: "boolean"}}, out: reflect.TypeOf([]bool{}), ctx: WithStructuredTypesEnabled(context.Background())},
{in: arrayType, scale: 0, fields: []fieldMetadata{{Type: "binary"}}, out: reflect.TypeOf([][]byte{}), ctx: WithStructuredTypesEnabled(context.Background())},
{in: arrayType, scale: 0, fields: []fieldMetadata{{Type: "object"}}, out: reflect.TypeOf([]ObjectType{}), ctx: WithStructuredTypesEnabled(context.Background())},
{in: mapType, fields: nil, out: reflect.TypeOf(""), ctx: context.Background()},
{in: mapType, fields: []fieldMetadata{}, out: reflect.TypeOf(""), ctx: context.Background()},
{in: mapType, fields: []fieldMetadata{{}, {}}, out: reflect.TypeOf(""), ctx: context.Background()},
{in: mapType, fields: []fieldMetadata{{Type: "text"}, {Type: "text"}}, out: reflect.TypeOf(map[string]string{}), ctx: WithStructuredTypesEnabled(context.Background())},
}
for _, test := range testcases {
t.Run(fmt.Sprintf("%v_%v", test.in, test.out), func(t *testing.T) {
Expand Down
3 changes: 2 additions & 1 deletion doc.go
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,8 @@ Additionally, `sf` tag can be added:
- first value is always a name of a field in an SQL object
- additionally `ignore` parameter can be passed to omit this field
2. Use it in regular scan:
2. Use WithStructuredTypesEnabled context while querying data.
3. Use it in regular scan:
var res simpleObject
err := rows.Scan(&res)
Expand Down
4 changes: 4 additions & 0 deletions rows_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ func TestRowsWithoutChunkDownloader(t *testing.T) {
},
}
rows.sc = sc
rows.ctx = context.Background()
rows.ChunkDownloader = &snowflakeChunkDownloader{
sc: sc,
ctx: context.Background(),
Expand Down Expand Up @@ -185,6 +186,7 @@ func TestRowsWithChunkDownloader(t *testing.T) {
},
}
rows.sc = sc
rows.ctx = context.Background()
rows.ChunkDownloader = &snowflakeChunkDownloader{
sc: sc,
ctx: context.Background(),
Expand Down Expand Up @@ -269,6 +271,7 @@ func TestRowsWithChunkDownloaderError(t *testing.T) {
},
}
rows.sc = sc
rows.ctx = context.Background()
rows.ChunkDownloader = &snowflakeChunkDownloader{
sc: sc,
ctx: context.Background(),
Expand Down Expand Up @@ -352,6 +355,7 @@ func TestRowsWithChunkDownloaderErrorFail(t *testing.T) {
},
}
rows.sc = sc
rows.ctx = context.Background()
rows.ChunkDownloader = &snowflakeChunkDownloader{
sc: sc,
ctx: context.Background(),
Expand Down
9 changes: 9 additions & 0 deletions structured_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -1150,6 +1150,15 @@ func (st *structuredType) fieldMetadataByFieldName(fieldName string) (fieldMetad
return fieldMetadata{}, errors.New("no metadata for field " + fieldName)
}

func structuredTypesEnabled(ctx context.Context) bool {
v := ctx.Value(enableStructuredTypes)
if v == nil {
return false
}
d, ok := v.(bool)
return ok && d
}

func mapValuesNullableEnabled(ctx context.Context) bool {
v := ctx.Value(mapValuesNullable)
if v == nil {
Expand Down
Loading

0 comments on commit 9730225

Please sign in to comment.