From 4d8b845da6930e83b94697ed1ae471479db5c2a9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20Cie=C5=9Blak?= Date: Thu, 8 Aug 2024 11:31:19 +0200 Subject: [PATCH 01/19] wip save (testing on function) save --- pkg/acceptance/check_destroy.go | 6 +- pkg/datasources/procedures.go | 4 +- pkg/resources/procedure.go | 1566 ++++++++--------- pkg/resources/procedure_acceptance_test.go | 2 +- pkg/sdk/external_functions_impl_gen.go | 58 +- pkg/sdk/functions_gen_test.go | 9 +- pkg/sdk/functions_impl_gen.go | 2 + pkg/sdk/identifier_helpers.go | 22 +- pkg/sdk/identifier_helpers_test.go | 40 + pkg/sdk/poc/main.go | 1 + pkg/sdk/procedures_def.go | 17 +- pkg/sdk/procedures_dto_builders_gen.go | 292 ++- pkg/sdk/procedures_dto_gen.go | 29 +- pkg/sdk/procedures_gen.go | 51 +- pkg/sdk/procedures_gen_test.go | 53 +- pkg/sdk/procedures_impl_gen.go | 42 +- .../testint/procedures_integration_test.go | 305 ++-- 17 files changed, 1260 insertions(+), 1239 deletions(-) diff --git a/pkg/acceptance/check_destroy.go b/pkg/acceptance/check_destroy.go index 617c90e5d2..6cd049eed0 100644 --- a/pkg/acceptance/check_destroy.go +++ b/pkg/acceptance/check_destroy.go @@ -118,9 +118,9 @@ var showByIdFunctions = map[resources.Resource]showByIdFunc{ resources.FileFormat: func(ctx context.Context, client *sdk.Client, id sdk.ObjectIdentifier) error { return runShowById(ctx, id, client.FileFormats.ShowByID) }, - resources.Function: func(ctx context.Context, client *sdk.Client, id sdk.ObjectIdentifier) error { - return runShowById(ctx, id, client.Functions.ShowByID) - }, + //resources.Function: func(ctx context.Context, client *sdk.Client, id sdk.ObjectIdentifier) error { + // return runShowById(ctx, id, client.Functions.ShowByID) + //}, resources.ManagedAccount: func(ctx context.Context, client *sdk.Client, id sdk.ObjectIdentifier) error { return runShowById(ctx, id, client.ManagedAccounts.ShowByID) }, diff --git a/pkg/datasources/procedures.go b/pkg/datasources/procedures.go index 4bbfe1df40..c4aaed7c69 100644 --- a/pkg/datasources/procedures.go +++ b/pkg/datasources/procedures.go @@ -77,10 +77,10 @@ func ReadContextProcedures(ctx context.Context, d *schema.ResourceData, meta int req := sdk.NewShowProcedureRequest() if databaseName != "" { - req.WithIn(&sdk.In{Database: sdk.NewAccountObjectIdentifier(databaseName)}) + req.WithIn(sdk.In{Database: sdk.NewAccountObjectIdentifier(databaseName)}) } if schemaName != "" { - req.WithIn(&sdk.In{Schema: sdk.NewDatabaseObjectIdentifier(databaseName, schemaName)}) + req.WithIn(sdk.In{Schema: sdk.NewDatabaseObjectIdentifier(databaseName, schemaName)}) } procedures, err := client.Procedures.Show(ctx, req) if err != nil { diff --git a/pkg/resources/procedure.go b/pkg/resources/procedure.go index 118bcde253..c9794742b6 100644 --- a/pkg/resources/procedure.go +++ b/pkg/resources/procedure.go @@ -1,795 +1,785 @@ package resources -import ( - "context" - "fmt" - "log" - "regexp" - "slices" - "strings" - - "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/provider" - "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" - "github.com/hashicorp/go-cty/cty" - "github.com/hashicorp/terraform-plugin-sdk/v2/diag" - "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" - "github.com/hashicorp/terraform-plugin-sdk/v2/helper/validation" -) - -var procedureSchema = map[string]*schema.Schema{ - "name": { - Type: schema.TypeString, - Required: true, - Description: "Specifies the identifier for the procedure; does not have to be unique for the schema in which the procedure is created. Don't use the | character.", - }, - "database": { - Type: schema.TypeString, - Required: true, - Description: "The database in which to create the procedure. Don't use the | character.", - ForceNew: true, - }, - "schema": { - Type: schema.TypeString, - Required: true, - Description: "The schema in which to create the procedure. Don't use the | character.", - ForceNew: true, - }, - "secure": { - Type: schema.TypeBool, - Optional: true, - Description: "Specifies that the procedure is secure. For more information about secure procedures, see Protecting Sensitive Information with Secure UDFs and Stored Procedures.", - Default: false, - }, - "arguments": { - Type: schema.TypeList, - Elem: &schema.Resource{ - Schema: map[string]*schema.Schema{ - "name": { - Type: schema.TypeString, - Required: true, - // Suppress the diff shown if the values are equal when both compared in lower case. - DiffSuppressFunc: func(k, old, new string, d *schema.ResourceData) bool { - return strings.EqualFold(old, new) - }, - Description: "The argument name", - }, - "type": { - Type: schema.TypeString, - Required: true, - ValidateFunc: dataTypeValidateFunc, - DiffSuppressFunc: dataTypeDiffSuppressFunc, - Description: "The argument type", - }, - }, - }, - Optional: true, - Description: "List of the arguments for the procedure", - ForceNew: true, - }, - "return_type": { - Type: schema.TypeString, - Description: "The return type of the procedure", - // Suppress the diff shown if the values are equal when both compared in lower case. - DiffSuppressFunc: func(k, old, new string, d *schema.ResourceData) bool { - if strings.EqualFold(old, new) { - return true - } - - varcharType := []string{"VARCHAR(16777216)", "VARCHAR", "text", "string", "NVARCHAR", "NVARCHAR2", "CHAR VARYING", "NCHAR VARYING"} - if slices.Contains(varcharType, strings.ToUpper(old)) && slices.Contains(varcharType, strings.ToUpper(new)) { - return true - } - - // all these types are equivalent https://docs.snowflake.com/en/sql-reference/data-types-numeric.html#int-integer-bigint-smallint-tinyint-byteint - integerTypes := []string{"INT", "INTEGER", "BIGINT", "SMALLINT", "TINYINT", "BYTEINT", "NUMBER(38,0)"} - if slices.Contains(integerTypes, strings.ToUpper(old)) && slices.Contains(integerTypes, strings.ToUpper(new)) { - return true - } - return false - }, - Required: true, - ForceNew: true, - }, - "statement": { - Type: schema.TypeString, - Required: true, - Description: "Specifies the code used to create the procedure.", - ForceNew: true, - DiffSuppressFunc: DiffSuppressStatement, - }, - "language": { - Type: schema.TypeString, - Optional: true, - Default: "SQL", - DiffSuppressFunc: func(k, old, new string, d *schema.ResourceData) bool { - return strings.EqualFold(old, new) - }, - ValidateFunc: validation.StringInSlice([]string{"javascript", "java", "scala", "SQL", "python"}, true), - Description: "Specifies the language of the stored procedure code.", - }, - "execute_as": { - Type: schema.TypeString, - Optional: true, - Default: "OWNER", - DiffSuppressFunc: func(k, old, new string, d *schema.ResourceData) bool { - return strings.EqualFold(old, new) - }, - ValidateFunc: validation.StringInSlice([]string{"CALLER", "OWNER"}, true), - Description: "Sets execution context. Allowed values are CALLER and OWNER (consult a proper section in the [docs](https://docs.snowflake.com/en/sql-reference/sql/create-procedure#id1)). For more information see [caller's rights and owner's rights](https://docs.snowflake.com/en/developer-guide/stored-procedure/stored-procedures-rights).", - }, - "null_input_behavior": { - Type: schema.TypeString, - Optional: true, - Default: "CALLED ON NULL INPUT", - ForceNew: true, - // We do not use STRICT, because Snowflake then in the Read phase returns RETURNS NULL ON NULL INPUT - ValidateFunc: validation.StringInSlice([]string{"CALLED ON NULL INPUT", "RETURNS NULL ON NULL INPUT"}, false), - Description: "Specifies the behavior of the procedure when called with null inputs.", - }, - "return_behavior": { - Type: schema.TypeString, - Optional: true, - Default: "VOLATILE", - ForceNew: true, - ValidateFunc: validation.StringInSlice([]string{"VOLATILE", "IMMUTABLE"}, false), - Description: "Specifies the behavior of the function when returning results", - Deprecated: "These keywords are deprecated for stored procedures. These keywords are not intended to apply to stored procedures. In a future release, these keywords will be removed from the documentation.", - }, - "comment": { - Type: schema.TypeString, - Optional: true, - Default: "user-defined procedure", - Description: "Specifies a comment for the procedure.", - }, - "runtime_version": { - Type: schema.TypeString, - Optional: true, - ForceNew: true, - Description: "Required for Python procedures. Specifies Python runtime version.", - }, - "packages": { - Type: schema.TypeList, - Elem: &schema.Schema{ - Type: schema.TypeString, - }, - Optional: true, - ForceNew: true, - Description: "List of package imports to use for Java / Python procedures. For Java, package imports should be of the form: package_name:version_number, where package_name is snowflake_domain:package. For Python use it should be: ('numpy','pandas','xgboost==1.5.0').", - }, - "imports": { - Type: schema.TypeList, - Elem: &schema.Schema{ - Type: schema.TypeString, - }, - Optional: true, - ForceNew: true, - Description: "Imports for Java / Python procedures. For Java this a list of jar files, for Python this is a list of Python files.", - }, - "handler": { - Type: schema.TypeString, - Optional: true, - ForceNew: true, - Description: "The handler method for Java / Python procedures.", - }, -} - -// Procedure returns a pointer to the resource representing a stored procedure. +import "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" + +// var procedureSchema = map[string]*schema.Schema{ +// "name": { +// Type: schema.TypeString, +// Required: true, +// Description: "Specifies the identifier for the procedure; does not have to be unique for the schema in which the procedure is created. Don't use the | character.", +// }, +// "database": { +// Type: schema.TypeString, +// Required: true, +// Description: "The database in which to create the procedure. Don't use the | character.", +// ForceNew: true, +// }, +// "schema": { +// Type: schema.TypeString, +// Required: true, +// Description: "The schema in which to create the procedure. Don't use the | character.", +// ForceNew: true, +// }, +// "secure": { +// Type: schema.TypeBool, +// Optional: true, +// Description: "Specifies that the procedure is secure. For more information about secure procedures, see Protecting Sensitive Information with Secure UDFs and Stored Procedures.", +// Default: false, +// }, +// "arguments": { +// Type: schema.TypeList, +// Elem: &schema.Resource{ +// Schema: map[string]*schema.Schema{ +// "name": { +// Type: schema.TypeString, +// Required: true, +// // Suppress the diff shown if the values are equal when both compared in lower case. +// DiffSuppressFunc: func(k, old, new string, d *schema.ResourceData) bool { +// return strings.EqualFold(old, new) +// }, +// Description: "The argument name", +// }, +// "type": { +// Type: schema.TypeString, +// Required: true, +// ValidateFunc: dataTypeValidateFunc, +// DiffSuppressFunc: dataTypeDiffSuppressFunc, +// Description: "The argument type", +// }, +// }, +// }, +// Optional: true, +// Description: "List of the arguments for the procedure", +// ForceNew: true, +// }, +// "return_type": { +// Type: schema.TypeString, +// Description: "The return type of the procedure", +// // Suppress the diff shown if the values are equal when both compared in lower case. +// DiffSuppressFunc: func(k, old, new string, d *schema.ResourceData) bool { +// if strings.EqualFold(old, new) { +// return true +// } +// +// varcharType := []string{"VARCHAR(16777216)", "VARCHAR", "text", "string", "NVARCHAR", "NVARCHAR2", "CHAR VARYING", "NCHAR VARYING"} +// if slices.Contains(varcharType, strings.ToUpper(old)) && slices.Contains(varcharType, strings.ToUpper(new)) { +// return true +// } +// +// // all these types are equivalent https://docs.snowflake.com/en/sql-reference/data-types-numeric.html#int-integer-bigint-smallint-tinyint-byteint +// integerTypes := []string{"INT", "INTEGER", "BIGINT", "SMALLINT", "TINYINT", "BYTEINT", "NUMBER(38,0)"} +// if slices.Contains(integerTypes, strings.ToUpper(old)) && slices.Contains(integerTypes, strings.ToUpper(new)) { +// return true +// } +// return false +// }, +// Required: true, +// ForceNew: true, +// }, +// "statement": { +// Type: schema.TypeString, +// Required: true, +// Description: "Specifies the code used to create the procedure.", +// ForceNew: true, +// DiffSuppressFunc: DiffSuppressStatement, +// }, +// "language": { +// Type: schema.TypeString, +// Optional: true, +// Default: "SQL", +// DiffSuppressFunc: func(k, old, new string, d *schema.ResourceData) bool { +// return strings.EqualFold(old, new) +// }, +// ValidateFunc: validation.StringInSlice([]string{"javascript", "java", "scala", "SQL", "python"}, true), +// Description: "Specifies the language of the stored procedure code.", +// }, +// "execute_as": { +// Type: schema.TypeString, +// Optional: true, +// Default: "OWNER", +// DiffSuppressFunc: func(k, old, new string, d *schema.ResourceData) bool { +// return strings.EqualFold(old, new) +// }, +// ValidateFunc: validation.StringInSlice([]string{"CALLER", "OWNER"}, true), +// Description: "Sets execution context. Allowed values are CALLER and OWNER (consult a proper section in the [docs](https://docs.snowflake.com/en/sql-reference/sql/create-procedure#id1)). For more information see [caller's rights and owner's rights](https://docs.snowflake.com/en/developer-guide/stored-procedure/stored-procedures-rights).", +// }, +// "null_input_behavior": { +// Type: schema.TypeString, +// Optional: true, +// Default: "CALLED ON NULL INPUT", +// ForceNew: true, +// // We do not use STRICT, because Snowflake then in the Read phase returns RETURNS NULL ON NULL INPUT +// ValidateFunc: validation.StringInSlice([]string{"CALLED ON NULL INPUT", "RETURNS NULL ON NULL INPUT"}, false), +// Description: "Specifies the behavior of the procedure when called with null inputs.", +// }, +// "return_behavior": { +// Type: schema.TypeString, +// Optional: true, +// Default: "VOLATILE", +// ForceNew: true, +// ValidateFunc: validation.StringInSlice([]string{"VOLATILE", "IMMUTABLE"}, false), +// Description: "Specifies the behavior of the function when returning results", +// Deprecated: "These keywords are deprecated for stored procedures. These keywords are not intended to apply to stored procedures. In a future release, these keywords will be removed from the documentation.", +// }, +// "comment": { +// Type: schema.TypeString, +// Optional: true, +// Default: "user-defined procedure", +// Description: "Specifies a comment for the procedure.", +// }, +// "runtime_version": { +// Type: schema.TypeString, +// Optional: true, +// ForceNew: true, +// Description: "Required for Python procedures. Specifies Python runtime version.", +// }, +// "packages": { +// Type: schema.TypeList, +// Elem: &schema.Schema{ +// Type: schema.TypeString, +// }, +// Optional: true, +// ForceNew: true, +// Description: "List of package imports to use for Java / Python procedures. For Java, package imports should be of the form: package_name:version_number, where package_name is snowflake_domain:package. For Python use it should be: ('numpy','pandas','xgboost==1.5.0').", +// }, +// "imports": { +// Type: schema.TypeList, +// Elem: &schema.Schema{ +// Type: schema.TypeString, +// }, +// Optional: true, +// ForceNew: true, +// Description: "Imports for Java / Python procedures. For Java this a list of jar files, for Python this is a list of Python files.", +// }, +// "handler": { +// Type: schema.TypeString, +// Optional: true, +// ForceNew: true, +// Description: "The handler method for Java / Python procedures.", +// }, +// } +// +// // Procedure returns a pointer to the resource representing a stored procedure. func Procedure() *schema.Resource { return &schema.Resource{ - SchemaVersion: 1, - - CreateContext: CreateContextProcedure, - ReadContext: ReadContextProcedure, - UpdateContext: UpdateContextProcedure, - DeleteContext: DeleteContextProcedure, - - Schema: procedureSchema, - Importer: &schema.ResourceImporter{ - StateContext: schema.ImportStatePassthroughContext, - }, - - StateUpgraders: []schema.StateUpgrader{ - { - Version: 0, - // setting type to cty.EmptyObject is a bit hacky here but following https://developer.hashicorp.com/terraform/plugin/framework/migrating/resources/state-upgrade#sdkv2-1 would require lots of repetitive code; this should work with cty.EmptyObject - Type: cty.EmptyObject, - Upgrade: v085ProcedureStateUpgrader, - }, - }, - } -} - -func CreateContextProcedure(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { - lang := strings.ToUpper(d.Get("language").(string)) - switch lang { - case "JAVA": - return createJavaProcedure(ctx, d, meta) - case "JAVASCRIPT": - return createJavaScriptProcedure(ctx, d, meta) - case "PYTHON": - return createPythonProcedure(ctx, d, meta) - case "SCALA": - return createScalaProcedure(ctx, d, meta) - case "SQL": - return createSQLProcedure(ctx, d, meta) - default: - return diag.Diagnostics{ - diag.Diagnostic{ - Severity: diag.Error, - Summary: "Invalid language", - Detail: fmt.Sprintf("Language %s is not supported", lang), - }, - } - } -} - -func createJavaProcedure(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { - client := meta.(*provider.Context).Client - name := d.Get("name").(string) - schema := d.Get("schema").(string) - database := d.Get("database").(string) - id := sdk.NewSchemaObjectIdentifier(database, schema, name) - - returns, diags := parseProcedureReturnsRequest(d.Get("return_type").(string)) - if diags != nil { - return diags - } - procedureDefinition := d.Get("statement").(string) - runtimeVersion := d.Get("runtime_version").(string) - packages := []sdk.ProcedurePackageRequest{} - for _, item := range d.Get("packages").([]interface{}) { - packages = append(packages, *sdk.NewProcedurePackageRequest(item.(string))) - } - handler := d.Get("handler").(string) - req := sdk.NewCreateForJavaProcedureRequest(id, *returns, runtimeVersion, packages, handler) - req.WithProcedureDefinition(sdk.String(procedureDefinition)) - args, diags := getProcedureArguments(d) - if diags != nil { - return diags - } - if len(args) > 0 { - req.WithArguments(args) - } - - // read optional params - if v, ok := d.GetOk("execute_as"); ok { - if strings.ToUpper(v.(string)) == "OWNER" { - req.WithExecuteAs(sdk.Pointer(sdk.ExecuteAsOwner)) - } else if strings.ToUpper(v.(string)) == "CALLER" { - req.WithExecuteAs(sdk.Pointer(sdk.ExecuteAsCaller)) - } - } - if v, ok := d.GetOk("comment"); ok { - req.WithComment(sdk.String(v.(string))) - } - if v, ok := d.GetOk("secure"); ok { - req.WithSecure(sdk.Bool(v.(bool))) - } - if _, ok := d.GetOk("imports"); ok { - imports := []sdk.ProcedureImportRequest{} - for _, item := range d.Get("imports").([]interface{}) { - imports = append(imports, *sdk.NewProcedureImportRequest(item.(string))) - } - req.WithImports(imports) - } - - if err := client.Procedures.CreateForJava(ctx, req); err != nil { - return diag.FromErr(err) - } - argTypes := make([]sdk.DataType, 0, len(args)) - for _, item := range args { - argTypes = append(argTypes, item.ArgDataType) - } - sid := sdk.NewSchemaObjectIdentifierWithArgumentsOld(database, schema, name, argTypes) - d.SetId(sid.FullyQualifiedName()) - return ReadContextProcedure(ctx, d, meta) -} - -func createJavaScriptProcedure(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { - client := meta.(*provider.Context).Client - name := d.Get("name").(string) - schema := d.Get("schema").(string) - database := d.Get("database").(string) - id := sdk.NewSchemaObjectIdentifier(database, schema, name) - - returnType := d.Get("return_type").(string) - returnDataType, diags := convertProcedureDataType(returnType) - if diags != nil { - return diags - } - procedureDefinition := d.Get("statement").(string) - req := sdk.NewCreateForJavaScriptProcedureRequest(id, returnDataType, procedureDefinition) - args, diags := getProcedureArguments(d) - if diags != nil { - return diags - } - if len(args) > 0 { - req.WithArguments(args) - } - - // read optional params - if v, ok := d.GetOk("execute_as"); ok { - if strings.ToUpper(v.(string)) == "OWNER" { - req.WithExecuteAs(sdk.Pointer(sdk.ExecuteAsOwner)) - } else if strings.ToUpper(v.(string)) == "CALLER" { - req.WithExecuteAs(sdk.Pointer(sdk.ExecuteAsCaller)) - } - } - if v, ok := d.GetOk("null_input_behavior"); ok { - req.WithNullInputBehavior(sdk.Pointer(sdk.NullInputBehavior(v.(string)))) - } - if v, ok := d.GetOk("comment"); ok { - req.WithComment(sdk.String(v.(string))) - } - if v, ok := d.GetOk("secure"); ok { - req.WithSecure(sdk.Bool(v.(bool))) - } - - if err := client.Procedures.CreateForJavaScript(ctx, req); err != nil { - return diag.FromErr(err) - } - argTypes := make([]sdk.DataType, 0, len(args)) - for _, item := range args { - argTypes = append(argTypes, item.ArgDataType) - } - sid := sdk.NewSchemaObjectIdentifierWithArgumentsOld(database, schema, name, argTypes) - d.SetId(sid.FullyQualifiedName()) - return ReadContextProcedure(ctx, d, meta) -} - -func createScalaProcedure(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { - client := meta.(*provider.Context).Client - name := d.Get("name").(string) - schema := d.Get("schema").(string) - database := d.Get("database").(string) - id := sdk.NewSchemaObjectIdentifier(database, schema, name) - - returns, diags := parseProcedureReturnsRequest(d.Get("return_type").(string)) - if diags != nil { - return diags - } - procedureDefinition := d.Get("statement").(string) - runtimeVersion := d.Get("runtime_version").(string) - packages := []sdk.ProcedurePackageRequest{} - for _, item := range d.Get("packages").([]interface{}) { - packages = append(packages, *sdk.NewProcedurePackageRequest(item.(string))) - } - handler := d.Get("handler").(string) - req := sdk.NewCreateForScalaProcedureRequest(id, *returns, runtimeVersion, packages, handler) - req.WithProcedureDefinition(sdk.String(procedureDefinition)) - args, diags := getProcedureArguments(d) - if diags != nil { - return diags - } - if len(args) > 0 { - req.WithArguments(args) - } - - // read optional params - if v, ok := d.GetOk("execute_as"); ok { - if strings.ToUpper(v.(string)) == "OWNER" { - req.WithExecuteAs(sdk.Pointer(sdk.ExecuteAsOwner)) - } else if strings.ToUpper(v.(string)) == "CALLER" { - req.WithExecuteAs(sdk.Pointer(sdk.ExecuteAsCaller)) - } - } - if v, ok := d.GetOk("comment"); ok { - req.WithComment(sdk.String(v.(string))) - } - if v, ok := d.GetOk("secure"); ok { - req.WithSecure(sdk.Bool(v.(bool))) - } - if _, ok := d.GetOk("imports"); ok { - imports := []sdk.ProcedureImportRequest{} - for _, item := range d.Get("imports").([]interface{}) { - imports = append(imports, *sdk.NewProcedureImportRequest(item.(string))) - } - req.WithImports(imports) - } - - if err := client.Procedures.CreateForScala(ctx, req); err != nil { - return diag.FromErr(err) - } - argTypes := make([]sdk.DataType, 0, len(args)) - for _, item := range args { - argTypes = append(argTypes, item.ArgDataType) - } - sid := sdk.NewSchemaObjectIdentifierWithArgumentsOld(database, schema, name, argTypes) - d.SetId(sid.FullyQualifiedName()) - return ReadContextProcedure(ctx, d, meta) -} - -func createSQLProcedure(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { - client := meta.(*provider.Context).Client - name := d.Get("name").(string) - schema := d.Get("schema").(string) - database := d.Get("database").(string) - id := sdk.NewSchemaObjectIdentifier(database, schema, name) - - returns, diags := parseProcedureSQLReturnsRequest(d.Get("return_type").(string)) - if diags != nil { - return diags - } - procedureDefinition := d.Get("statement").(string) - req := sdk.NewCreateForSQLProcedureRequest(id, *returns, procedureDefinition) - args, diags := getProcedureArguments(d) - if diags != nil { - return diags - } - if len(args) > 0 { - req.WithArguments(args) - } - - // read optional params - if v, ok := d.GetOk("execute_as"); ok { - if strings.ToUpper(v.(string)) == "OWNER" { - req.WithExecuteAs(sdk.Pointer(sdk.ExecuteAsOwner)) - } else if strings.ToUpper(v.(string)) == "CALLER" { - req.WithExecuteAs(sdk.Pointer(sdk.ExecuteAsCaller)) - } - } - if v, ok := d.GetOk("null_input_behavior"); ok { - req.WithNullInputBehavior(sdk.Pointer(sdk.NullInputBehavior(v.(string)))) - } - if v, ok := d.GetOk("comment"); ok { - req.WithComment(sdk.String(v.(string))) - } - if v, ok := d.GetOk("secure"); ok { - req.WithSecure(sdk.Bool(v.(bool))) + //SchemaVersion: 1, + // + //CreateContext: CreateContextProcedure, + //ReadContext: ReadContextProcedure, + //UpdateContext: UpdateContextProcedure, + //DeleteContext: DeleteContextProcedure, + // + //Schema: procedureSchema, + //Importer: &schema.ResourceImporter{ + // StateContext: schema.ImportStatePassthroughContext, + //}, + + //StateUpgraders: []schema.StateUpgrader{ + // { + // Version: 0, + // // setting type to cty.EmptyObject is a bit hacky here but following https://developer.hashicorp.com/terraform/plugin/framework/migrating/resources/state-upgrade#sdkv2-1 would require lots of repetitive code; this should work with cty.EmptyObject + // Type: cty.EmptyObject, + // Upgrade: v085ProcedureStateUpgrader, + // }, + //}, } - - if err := client.Procedures.CreateForSQL(ctx, req); err != nil { - return diag.FromErr(err) - } - argTypes := make([]sdk.DataType, 0, len(args)) - for _, item := range args { - argTypes = append(argTypes, item.ArgDataType) - } - sid := sdk.NewSchemaObjectIdentifierWithArgumentsOld(database, schema, name, argTypes) - d.SetId(sid.FullyQualifiedName()) - return ReadContextProcedure(ctx, d, meta) -} - -func createPythonProcedure(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { - client := meta.(*provider.Context).Client - name := d.Get("name").(string) - schema := d.Get("schema").(string) - database := d.Get("database").(string) - id := sdk.NewSchemaObjectIdentifier(database, schema, name) - - returns, diags := parseProcedureReturnsRequest(d.Get("return_type").(string)) - if diags != nil { - return diags - } - procedureDefinition := d.Get("statement").(string) - runtimeVersion := d.Get("runtime_version").(string) - packages := []sdk.ProcedurePackageRequest{} - for _, item := range d.Get("packages").([]interface{}) { - packages = append(packages, *sdk.NewProcedurePackageRequest(item.(string))) - } - handler := d.Get("handler").(string) - req := sdk.NewCreateForPythonProcedureRequest(id, *returns, runtimeVersion, packages, handler) - req.WithProcedureDefinition(sdk.String(procedureDefinition)) - args, diags := getProcedureArguments(d) - if diags != nil { - return diags - } - if len(args) > 0 { - req.WithArguments(args) - } - - // read optional params - if v, ok := d.GetOk("execute_as"); ok { - if strings.ToUpper(v.(string)) == "OWNER" { - req.WithExecuteAs(sdk.Pointer(sdk.ExecuteAsOwner)) - } else if strings.ToUpper(v.(string)) == "CALLER" { - req.WithExecuteAs(sdk.Pointer(sdk.ExecuteAsCaller)) - } - } - - // [ { CALLED ON NULL INPUT | { RETURNS NULL ON NULL INPUT | STRICT } } ] does not work for java, scala or python - // posted in docs-discuss channel, either docs need to be updated to reflect reality or this feature needs to be added - // https://snowflake.slack.com/archives/C6380540P/p1707511734666249 - // if v, ok := d.GetOk("null_input_behavior"); ok { - // req.WithNullInputBehavior(sdk.Pointer(sdk.NullInputBehavior(v.(string)))) - // } - - if v, ok := d.GetOk("comment"); ok { - req.WithComment(sdk.String(v.(string))) - } - if v, ok := d.GetOk("secure"); ok { - req.WithSecure(sdk.Bool(v.(bool))) - } - if _, ok := d.GetOk("imports"); ok { - imports := []sdk.ProcedureImportRequest{} - for _, item := range d.Get("imports").([]interface{}) { - imports = append(imports, *sdk.NewProcedureImportRequest(item.(string))) - } - req.WithImports(imports) - } - - if err := client.Procedures.CreateForPython(ctx, req); err != nil { - return diag.FromErr(err) - } - argTypes := make([]sdk.DataType, 0, len(args)) - for _, item := range args { - argTypes = append(argTypes, item.ArgDataType) - } - sid := sdk.NewSchemaObjectIdentifierWithArgumentsOld(database, schema, name, argTypes) - d.SetId(sid.FullyQualifiedName()) - return ReadContextProcedure(ctx, d, meta) } -func ReadContextProcedure(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { - diags := diag.Diagnostics{} - client := meta.(*provider.Context).Client - - id := sdk.NewSchemaObjectIdentifierFromFullyQualifiedName(d.Id()) - if err := d.Set("name", id.Name()); err != nil { - return diag.FromErr(err) - } - if err := d.Set("database", id.DatabaseName()); err != nil { - return diag.FromErr(err) - } - if err := d.Set("schema", id.SchemaName()); err != nil { - return diag.FromErr(err) - } - args := d.Get("arguments").([]interface{}) - argTypes := make([]string, len(args)) - for i, arg := range args { - argTypes[i] = arg.(map[string]interface{})["type"].(string) - } - procedureDetails, err := client.Procedures.Describe(ctx, sdk.NewDescribeProcedureRequest(id.WithoutArguments(), id.Arguments())) - if err != nil { - // if procedure is not found then mark resource to be removed from state file during apply or refresh - d.SetId("") - return diag.Diagnostics{ - diag.Diagnostic{ - Severity: diag.Warning, - Summary: "Describe procedure failed.", - Detail: fmt.Sprintf("Describe procedure failed: %v", err), - }, - } - } - for _, desc := range procedureDetails { - switch desc.Property { - case "signature": - // Format in Snowflake DB is: (argName argType, argName argType, ...) - args := strings.ReplaceAll(strings.ReplaceAll(desc.Value, "(", ""), ")", "") - - if args != "" { // Do nothing for functions without arguments - argPairs := strings.Split(args, ", ") - args := []interface{}{} - - for _, argPair := range argPairs { - argItem := strings.Split(argPair, " ") - - arg := map[string]interface{}{} - arg["name"] = argItem[0] - arg["type"] = argItem[1] - args = append(args, arg) - } - - if err := d.Set("arguments", args); err != nil { - return diag.FromErr(err) - } - } - case "null handling": - if err := d.Set("null_input_behavior", desc.Value); err != nil { - return diag.FromErr(err) - } - case "body": - if err := d.Set("statement", desc.Value); err != nil { - return diag.FromErr(err) - } - case "execute as": - if err := d.Set("execute_as", desc.Value); err != nil { - return diag.FromErr(err) - } - case "returns": - if err := d.Set("return_type", desc.Value); err != nil { - return diag.FromErr(err) - } - case "language": - if err := d.Set("language", desc.Value); err != nil { - return diag.FromErr(err) - } - case "runtime_version": - if err := d.Set("runtime_version", desc.Value); err != nil { - return diag.FromErr(err) - } - case "packages": - packagesString := strings.ReplaceAll(strings.ReplaceAll(strings.ReplaceAll(desc.Value, "[", ""), "]", ""), "'", "") - if packagesString != "" { // Do nothing for Java / Python functions without packages - packages := strings.Split(packagesString, ",") - if err := d.Set("packages", packages); err != nil { - return diag.FromErr(err) - } - } - case "imports": - importsString := strings.ReplaceAll(strings.ReplaceAll(strings.ReplaceAll(strings.ReplaceAll(desc.Value, "[", ""), "]", ""), "'", ""), " ", "") - if importsString != "" { // Do nothing for Java functions without imports - imports := strings.Split(importsString, ",") - if err := d.Set("imports", imports); err != nil { - return diag.FromErr(err) - } - } - case "handler": - if err := d.Set("handler", desc.Value); err != nil { - return diag.FromErr(err) - } - case "volatility": - if err := d.Set("return_behavior", desc.Value); err != nil { - return diag.FromErr(err) - } - default: - log.Printf("[INFO] Unexpected procedure property %v returned from Snowflake with value %v", desc.Property, desc.Value) - } - } - - request := sdk.NewShowProcedureRequest().WithIn(&sdk.In{Schema: sdk.NewDatabaseObjectIdentifier(id.DatabaseName(), id.SchemaName())}).WithLike(&sdk.Like{Pattern: sdk.String(id.Name())}) - - procedures, err := client.Procedures.Show(ctx, request) - if err != nil { - return diag.FromErr(err) - } - // procedure names can be overloaded with different argument types so we iterate over and find the correct one - // the ShowByID function should probably be updated to also require the list of arg types, like describe procedure - for _, procedure := range procedures { - argumentSignature := strings.Split(procedure.Arguments, " RETURN ")[0] - argumentSignature = strings.ReplaceAll(argumentSignature, " ", "") - if argumentSignature == id.ArgumentsSignature() { - if err := d.Set("secure", procedure.IsSecure); err != nil { - return diag.FromErr(err) - } - if err := d.Set("comment", procedure.Description); err != nil { - return diag.FromErr(err) - } - } - } - - return diags -} - -func UpdateContextProcedure(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { - client := meta.(*provider.Context).Client - - id := sdk.NewSchemaObjectIdentifierFromFullyQualifiedName(d.Id()) - if d.HasChange("name") { - newId := sdk.NewSchemaObjectIdentifierWithArgumentsOld(id.DatabaseName(), id.SchemaName(), d.Get("name").(string), id.Arguments()) - - err := client.Procedures.Alter(ctx, sdk.NewAlterProcedureRequest(id.WithoutArguments(), id.Arguments()).WithRenameTo(sdk.Pointer(newId.WithoutArguments()))) - if err != nil { - return diag.FromErr(err) - } - - d.SetId(newId.FullyQualifiedName()) - id = newId - } - - if d.HasChange("comment") { - comment := d.Get("comment") - if comment != "" { - if err := client.Procedures.Alter(ctx, sdk.NewAlterProcedureRequest(id.WithoutArguments(), id.Arguments()).WithSetComment(sdk.String(comment.(string)))); err != nil { - return diag.FromErr(err) - } - } else { - if err := client.Procedures.Alter(ctx, sdk.NewAlterProcedureRequest(id.WithoutArguments(), id.Arguments()).WithUnsetComment(sdk.Bool(true))); err != nil { - return diag.FromErr(err) - } - } - } - - if d.HasChange("execute_as") { - req := sdk.NewAlterProcedureRequest(id.WithoutArguments(), id.Arguments()) - executeAs := d.Get("execute_as").(string) - if strings.ToUpper(executeAs) == "OWNER" { - req.WithExecuteAs(sdk.Pointer(sdk.ExecuteAsOwner)) - } else if strings.ToUpper(executeAs) == "CALLER" { - req.WithExecuteAs(sdk.Pointer(sdk.ExecuteAsCaller)) - } - if err := client.Procedures.Alter(ctx, req); err != nil { - return diag.FromErr(err) - } - } - - return ReadContextProcedure(ctx, d, meta) -} - -func DeleteContextProcedure(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { - client := meta.(*provider.Context).Client - - id := sdk.NewSchemaObjectIdentifierFromFullyQualifiedName(d.Id()) - if err := client.Procedures.Drop(ctx, sdk.NewDropProcedureRequest(id.WithoutArguments(), id.Arguments())); err != nil { - return diag.FromErr(err) - } - d.SetId("") - return nil -} - -func getProcedureArguments(d *schema.ResourceData) ([]sdk.ProcedureArgumentRequest, diag.Diagnostics) { - args := make([]sdk.ProcedureArgumentRequest, 0) - if v, ok := d.GetOk("arguments"); ok { - for _, arg := range v.([]interface{}) { - argName := arg.(map[string]interface{})["name"].(string) - argType := arg.(map[string]interface{})["type"].(string) - argDataType, diags := convertProcedureDataType(argType) - if diags != nil { - return nil, diags - } - args = append(args, sdk.ProcedureArgumentRequest{ArgName: argName, ArgDataType: argDataType}) - } - } - return args, nil -} - -func convertProcedureDataType(s string) (sdk.DataType, diag.Diagnostics) { - dataType, err := sdk.ToDataType(s) - if err != nil { - return dataType, diag.FromErr(err) - } - return dataType, nil -} - -func convertProcedureColumns(s string) ([]sdk.ProcedureColumn, diag.Diagnostics) { - pattern := regexp.MustCompile(`(\w+)\s+(\w+)`) - matches := pattern.FindAllStringSubmatch(s, -1) - var columns []sdk.ProcedureColumn - for _, match := range matches { - if len(match) == 3 { - dataType, err := sdk.ToDataType(match[2]) - if err != nil { - return nil, diag.FromErr(err) - } - columns = append(columns, sdk.ProcedureColumn{ - ColumnName: match[1], - ColumnDataType: dataType, - }) - } - } - return columns, nil -} - -func parseProcedureReturnsRequest(s string) (*sdk.ProcedureReturnsRequest, diag.Diagnostics) { - returns := sdk.NewProcedureReturnsRequest() - if strings.HasPrefix(strings.ToLower(s), "table") { - columns, diags := convertProcedureColumns(s) - if diags != nil { - return nil, diags - } - var cr []sdk.ProcedureColumnRequest - for _, item := range columns { - cr = append(cr, *sdk.NewProcedureColumnRequest(item.ColumnName, item.ColumnDataType)) - } - returns.WithTable(sdk.NewProcedureReturnsTableRequest().WithColumns(cr)) - } else { - returnDataType, diags := convertProcedureDataType(s) - if diags != nil { - return nil, diags - } - returns.WithResultDataType(sdk.NewProcedureReturnsResultDataTypeRequest(returnDataType)) - } - return returns, nil -} - -func parseProcedureSQLReturnsRequest(s string) (*sdk.ProcedureSQLReturnsRequest, diag.Diagnostics) { - returns := sdk.NewProcedureSQLReturnsRequest() - if strings.HasPrefix(strings.ToLower(s), "table") { - columns, diags := convertProcedureColumns(s) - if diags != nil { - return nil, diags - } - var cr []sdk.ProcedureColumnRequest - for _, item := range columns { - cr = append(cr, *sdk.NewProcedureColumnRequest(item.ColumnName, item.ColumnDataType)) - } - returns.WithTable(sdk.NewProcedureReturnsTableRequest().WithColumns(cr)) - } else { - returnDataType, diags := convertProcedureDataType(s) - if diags != nil { - return nil, diags - } - returns.WithResultDataType(sdk.NewProcedureReturnsResultDataTypeRequest(returnDataType)) - } - return returns, nil -} +// +//func CreateContextProcedure(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { +// lang := strings.ToUpper(d.Get("language").(string)) +// switch lang { +// case "JAVA": +// return createJavaProcedure(ctx, d, meta) +// case "JAVASCRIPT": +// return createJavaScriptProcedure(ctx, d, meta) +// case "PYTHON": +// return createPythonProcedure(ctx, d, meta) +// case "SCALA": +// return createScalaProcedure(ctx, d, meta) +// case "SQL": +// return createSQLProcedure(ctx, d, meta) +// default: +// return diag.Diagnostics{ +// diag.Diagnostic{ +// Severity: diag.Error, +// Summary: "Invalid language", +// Detail: fmt.Sprintf("Language %s is not supported", lang), +// }, +// } +// } +//} +// +//func createJavaProcedure(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { +// client := meta.(*provider.Context).Client +// name := d.Get("name").(string) +// schema := d.Get("schema").(string) +// database := d.Get("database").(string) +// id := sdk.NewSchemaObjectIdentifier(database, schema, name) +// +// returns, diags := parseProcedureReturnsRequest(d.Get("return_type").(string)) +// if diags != nil { +// return diags +// } +// procedureDefinition := d.Get("statement").(string) +// runtimeVersion := d.Get("runtime_version").(string) +// packages := []sdk.ProcedurePackageRequest{} +// for _, item := range d.Get("packages").([]interface{}) { +// packages = append(packages, *sdk.NewProcedurePackageRequest(item.(string))) +// } +// handler := d.Get("handler").(string) +// req := sdk.NewCreateForJavaProcedureRequest(id, *returns, runtimeVersion, packages, handler) +// req.WithProcedureDefinition(procedureDefinition) +// args, diags := getProcedureArguments(d) +// if diags != nil { +// return diags +// } +// if len(args) > 0 { +// req.WithArguments(args) +// } +// +// // read optional params +// if v, ok := d.GetOk("execute_as"); ok { +// if strings.ToUpper(v.(string)) == "OWNER" { +// req.WithExecuteAs(sdk.ExecuteAsOwner) +// } else if strings.ToUpper(v.(string)) == "CALLER" { +// req.WithExecuteAs(sdk.ExecuteAsCaller) +// } +// } +// if v, ok := d.GetOk("comment"); ok { +// req.WithComment(v.(string)) +// } +// if v, ok := d.GetOk("secure"); ok { +// req.WithSecure(v.(bool)) +// } +// if _, ok := d.GetOk("imports"); ok { +// imports := []sdk.ProcedureImportRequest{} +// for _, item := range d.Get("imports").([]interface{}) { +// imports = append(imports, *sdk.NewProcedureImportRequest(item.(string))) +// } +// req.WithImports(imports) +// } +// +// if err := client.Procedures.CreateForJava(ctx, req); err != nil { +// return diag.FromErr(err) +// } +// argTypes := make([]sdk.DataType, 0, len(args)) +// for _, item := range args { +// argTypes = append(argTypes, item.ArgDataType) +// } +// sid := sdk.NewSchemaObjectIdentifierWithArgumentsOld(database, schema, name, argTypes) +// d.SetId(sid.FullyQualifiedName()) +// return ReadContextProcedure(ctx, d, meta) +//} +// +//func createJavaScriptProcedure(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { +// client := meta.(*provider.Context).Client +// name := d.Get("name").(string) +// schema := d.Get("schema").(string) +// database := d.Get("database").(string) +// id := sdk.NewSchemaObjectIdentifier(database, schema, name) +// +// returnType := d.Get("return_type").(string) +// returnDataType, diags := convertProcedureDataType(returnType) +// if diags != nil { +// return diags +// } +// procedureDefinition := d.Get("statement").(string) +// req := sdk.NewCreateForJavaScriptProcedureRequest(id, returnDataType, procedureDefinition) +// args, diags := getProcedureArguments(d) +// if diags != nil { +// return diags +// } +// if len(args) > 0 { +// req.WithArguments(args) +// } +// +// // read optional params +// if v, ok := d.GetOk("execute_as"); ok { +// if strings.ToUpper(v.(string)) == "OWNER" { +// req.WithExecuteAs(sdk.ExecuteAsOwner) +// } else if strings.ToUpper(v.(string)) == "CALLER" { +// req.WithExecuteAs(sdk.ExecuteAsCaller) +// } +// } +// if v, ok := d.GetOk("null_input_behavior"); ok { +// req.WithNullInputBehavior(sdk.NullInputBehavior(v.(string))) +// } +// if v, ok := d.GetOk("comment"); ok { +// req.WithComment(v.(string)) +// } +// if v, ok := d.GetOk("secure"); ok { +// req.WithSecure(v.(bool)) +// } +// +// if err := client.Procedures.CreateForJavaScript(ctx, req); err != nil { +// return diag.FromErr(err) +// } +// argTypes := make([]sdk.DataType, 0, len(args)) +// for _, item := range args { +// argTypes = append(argTypes, item.ArgDataType) +// } +// sid := sdk.NewSchemaObjectIdentifierWithArgumentsOld(database, schema, name, argTypes) +// d.SetId(sid.FullyQualifiedName()) +// return ReadContextProcedure(ctx, d, meta) +//} +// +//func createScalaProcedure(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { +// client := meta.(*provider.Context).Client +// name := d.Get("name").(string) +// schema := d.Get("schema").(string) +// database := d.Get("database").(string) +// id := sdk.NewSchemaObjectIdentifier(database, schema, name) +// +// returns, diags := parseProcedureReturnsRequest(d.Get("return_type").(string)) +// if diags != nil { +// return diags +// } +// procedureDefinition := d.Get("statement").(string) +// runtimeVersion := d.Get("runtime_version").(string) +// packages := []sdk.ProcedurePackageRequest{} +// for _, item := range d.Get("packages").([]interface{}) { +// packages = append(packages, *sdk.NewProcedurePackageRequest(item.(string))) +// } +// handler := d.Get("handler").(string) +// req := sdk.NewCreateForScalaProcedureRequest(id, *returns, runtimeVersion, packages, handler) +// req.WithProcedureDefinition(procedureDefinition) +// args, diags := getProcedureArguments(d) +// if diags != nil { +// return diags +// } +// if len(args) > 0 { +// req.WithArguments(args) +// } +// +// // read optional params +// if v, ok := d.GetOk("execute_as"); ok { +// if strings.ToUpper(v.(string)) == "OWNER" { +// req.WithExecuteAs(sdk.ExecuteAsOwner) +// } else if strings.ToUpper(v.(string)) == "CALLER" { +// req.WithExecuteAs(sdk.ExecuteAsCaller) +// } +// } +// if v, ok := d.GetOk("comment"); ok { +// req.WithComment(v.(string)) +// } +// if v, ok := d.GetOk("secure"); ok { +// req.WithSecure(v.(bool)) +// } +// if _, ok := d.GetOk("imports"); ok { +// imports := []sdk.ProcedureImportRequest{} +// for _, item := range d.Get("imports").([]interface{}) { +// imports = append(imports, *sdk.NewProcedureImportRequest(item.(string))) +// } +// req.WithImports(imports) +// } +// +// if err := client.Procedures.CreateForScala(ctx, req); err != nil { +// return diag.FromErr(err) +// } +// argTypes := make([]sdk.DataType, 0, len(args)) +// for _, item := range args { +// argTypes = append(argTypes, item.ArgDataType) +// } +// sid := sdk.NewSchemaObjectIdentifierWithArgumentsOld(database, schema, name, argTypes) +// d.SetId(sid.FullyQualifiedName()) +// return ReadContextProcedure(ctx, d, meta) +//} +// +//func createSQLProcedure(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { +// client := meta.(*provider.Context).Client +// name := d.Get("name").(string) +// schema := d.Get("schema").(string) +// database := d.Get("database").(string) +// id := sdk.NewSchemaObjectIdentifier(database, schema, name) +// +// returns, diags := parseProcedureSQLReturnsRequest(d.Get("return_type").(string)) +// if diags != nil { +// return diags +// } +// procedureDefinition := d.Get("statement").(string) +// req := sdk.NewCreateForSQLProcedureRequest(id, *returns, procedureDefinition) +// args, diags := getProcedureArguments(d) +// if diags != nil { +// return diags +// } +// if len(args) > 0 { +// req.WithArguments(args) +// } +// +// // read optional params +// if v, ok := d.GetOk("execute_as"); ok { +// if strings.ToUpper(v.(string)) == "OWNER" { +// req.WithExecuteAs(sdk.ExecuteAsOwner) +// } else if strings.ToUpper(v.(string)) == "CALLER" { +// req.WithExecuteAs(sdk.ExecuteAsCaller) +// } +// } +// if v, ok := d.GetOk("null_input_behavior"); ok { +// req.WithNullInputBehavior(sdk.NullInputBehavior(v.(string))) +// } +// if v, ok := d.GetOk("comment"); ok { +// req.WithComment(v.(string)) +// } +// if v, ok := d.GetOk("secure"); ok { +// req.WithSecure(v.(bool)) +// } +// +// if err := client.Procedures.CreateForSQL(ctx, req); err != nil { +// return diag.FromErr(err) +// } +// argTypes := make([]sdk.DataType, 0, len(args)) +// for _, item := range args { +// argTypes = append(argTypes, item.ArgDataType) +// } +// sid := sdk.NewSchemaObjectIdentifierWithArgumentsOld(database, schema, name, argTypes) +// d.SetId(sid.FullyQualifiedName()) +// return ReadContextProcedure(ctx, d, meta) +//} +// +//func createPythonProcedure(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { +// client := meta.(*provider.Context).Client +// name := d.Get("name").(string) +// schema := d.Get("schema").(string) +// database := d.Get("database").(string) +// id := sdk.NewSchemaObjectIdentifier(database, schema, name) +// +// returns, diags := parseProcedureReturnsRequest(d.Get("return_type").(string)) +// if diags != nil { +// return diags +// } +// procedureDefinition := d.Get("statement").(string) +// runtimeVersion := d.Get("runtime_version").(string) +// packages := []sdk.ProcedurePackageRequest{} +// for _, item := range d.Get("packages").([]interface{}) { +// packages = append(packages, *sdk.NewProcedurePackageRequest(item.(string))) +// } +// handler := d.Get("handler").(string) +// req := sdk.NewCreateForPythonProcedureRequest(id, *returns, runtimeVersion, packages, handler) +// req.WithProcedureDefinition(procedureDefinition) +// args, diags := getProcedureArguments(d) +// if diags != nil { +// return diags +// } +// if len(args) > 0 { +// req.WithArguments(args) +// } +// +// // read optional params +// if v, ok := d.GetOk("execute_as"); ok { +// if strings.ToUpper(v.(string)) == "OWNER" { +// req.WithExecuteAs(sdk.ExecuteAsOwner) +// } else if strings.ToUpper(v.(string)) == "CALLER" { +// req.WithExecuteAs(sdk.ExecuteAsCaller) +// } +// } +// +// // [ { CALLED ON NULL INPUT | { RETURNS NULL ON NULL INPUT | STRICT } } ] does not work for java, scala or python +// // posted in docs-discuss channel, either docs need to be updated to reflect reality or this feature needs to be added +// // https://snowflake.slack.com/archives/C6380540P/p1707511734666249 +// // if v, ok := d.GetOk("null_input_behavior"); ok { +// // req.WithNullInputBehavior(sdk.Pointer(sdk.NullInputBehavior(v.(string)))) +// // } +// +// if v, ok := d.GetOk("comment"); ok { +// req.WithComment(v.(string)) +// } +// if v, ok := d.GetOk("secure"); ok { +// req.WithSecure(v.(bool)) +// } +// if _, ok := d.GetOk("imports"); ok { +// imports := []sdk.ProcedureImportRequest{} +// for _, item := range d.Get("imports").([]interface{}) { +// imports = append(imports, *sdk.NewProcedureImportRequest(item.(string))) +// } +// req.WithImports(imports) +// } +// +// if err := client.Procedures.CreateForPython(ctx, req); err != nil { +// return diag.FromErr(err) +// } +// argTypes := make([]sdk.DataType, 0, len(args)) +// for _, item := range args { +// argTypes = append(argTypes, item.ArgDataType) +// } +// sid := sdk.NewSchemaObjectIdentifierWithArgumentsOld(database, schema, name, argTypes) +// d.SetId(sid.FullyQualifiedName()) +// return ReadContextProcedure(ctx, d, meta) +//} +// +//func ReadContextProcedure(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { +// diags := diag.Diagnostics{} +// client := meta.(*provider.Context).Client +// +// id, err := sdk.NewSchemaObjectIdentifierWithArgumentsFromFullyQualifiedName(d.Id()) +// if err != nil { +// return diag.FromErr(err) +// } +// if err := d.Set("name", id.Name()); err != nil { +// return diag.FromErr(err) +// } +// if err := d.Set("database", id.DatabaseName()); err != nil { +// return diag.FromErr(err) +// } +// if err := d.Set("schema", id.SchemaName()); err != nil { +// return diag.FromErr(err) +// } +// args := d.Get("arguments").([]interface{}) +// argTypes := make([]string, len(args)) +// for i, arg := range args { +// argTypes[i] = arg.(map[string]interface{})["type"].(string) +// } +// procedureDetails, err := client.Procedures.Describe(ctx, id) +// if err != nil { +// // if procedure is not found then mark resource to be removed from state file during apply or refresh +// d.SetId("") +// return diag.Diagnostics{ +// diag.Diagnostic{ +// Severity: diag.Warning, +// Summary: "Describe procedure failed.", +// Detail: fmt.Sprintf("Describe procedure failed: %v", err), +// }, +// } +// } +// for _, desc := range procedureDetails { +// switch desc.Property { +// case "signature": +// // Format in Snowflake DB is: (argName argType, argName argType, ...) +// args := strings.ReplaceAll(strings.ReplaceAll(desc.Value, "(", ""), ")", "") +// +// if args != "" { // Do nothing for functions without arguments +// argPairs := strings.Split(args, ", ") +// args := []interface{}{} +// +// for _, argPair := range argPairs { +// argItem := strings.Split(argPair, " ") +// +// arg := map[string]interface{}{} +// arg["name"] = argItem[0] +// arg["type"] = argItem[1] +// args = append(args, arg) +// } +// +// if err := d.Set("arguments", args); err != nil { +// return diag.FromErr(err) +// } +// } +// case "null handling": +// if err := d.Set("null_input_behavior", desc.Value); err != nil { +// return diag.FromErr(err) +// } +// case "body": +// if err := d.Set("statement", desc.Value); err != nil { +// return diag.FromErr(err) +// } +// case "execute as": +// if err := d.Set("execute_as", desc.Value); err != nil { +// return diag.FromErr(err) +// } +// case "returns": +// if err := d.Set("return_type", desc.Value); err != nil { +// return diag.FromErr(err) +// } +// case "language": +// if err := d.Set("language", desc.Value); err != nil { +// return diag.FromErr(err) +// } +// case "runtime_version": +// if err := d.Set("runtime_version", desc.Value); err != nil { +// return diag.FromErr(err) +// } +// case "packages": +// packagesString := strings.ReplaceAll(strings.ReplaceAll(strings.ReplaceAll(desc.Value, "[", ""), "]", ""), "'", "") +// if packagesString != "" { // Do nothing for Java / Python functions without packages +// packages := strings.Split(packagesString, ",") +// if err := d.Set("packages", packages); err != nil { +// return diag.FromErr(err) +// } +// } +// case "imports": +// importsString := strings.ReplaceAll(strings.ReplaceAll(strings.ReplaceAll(strings.ReplaceAll(desc.Value, "[", ""), "]", ""), "'", ""), " ", "") +// if importsString != "" { // Do nothing for Java functions without imports +// imports := strings.Split(importsString, ",") +// if err := d.Set("imports", imports); err != nil { +// return diag.FromErr(err) +// } +// } +// case "handler": +// if err := d.Set("handler", desc.Value); err != nil { +// return diag.FromErr(err) +// } +// case "volatility": +// if err := d.Set("return_behavior", desc.Value); err != nil { +// return diag.FromErr(err) +// } +// default: +// log.Printf("[INFO] Unexpected procedure property %v returned from Snowflake with value %v", desc.Property, desc.Value) +// } +// } +// +// request := sdk.NewShowProcedureRequest().WithIn(sdk.In{Schema: sdk.NewDatabaseObjectIdentifier(id.DatabaseName(), id.SchemaName())}).WithLike(sdk.Like{Pattern: sdk.String(id.Name())}) +// +// procedures, err := client.Procedures.Show(ctx, request) +// if err != nil { +// return diag.FromErr(err) +// } +// // procedure names can be overloaded with different argument types so we iterate over and find the correct one +// // the ShowByID function should probably be updated to also require the list of arg types, like describe procedure +// for _, procedure := range procedures { +// argumentSignature := strings.Split(procedure.Arguments, " RETURN ")[0] +// argumentSignature = strings.ReplaceAll(argumentSignature, " ", "") +// if argumentSignature == id.ArgumentsSignature() { +// if err := d.Set("secure", procedure.IsSecure); err != nil { +// return diag.FromErr(err) +// } +// if err := d.Set("comment", procedure.Description); err != nil { +// return diag.FromErr(err) +// } +// } +// } +// +// return diags +//} +// +//func UpdateContextProcedure(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { +// client := meta.(*provider.Context).Client +// +// id := sdk.NewSchemaObjectIdentifierFromFullyQualifiedName(d.Id()) +// if d.HasChange("name") { +// newId := sdk.NewSchemaObjectIdentifierWithArgumentsOld(id.DatabaseName(), id.SchemaName(), d.Get("name").(string), id.Arguments()) +// +// err := client.Procedures.Alter(ctx, sdk.NewAlterProcedureRequest(id.WithoutArguments(), id.Arguments()).WithRenameTo(newId.WithoutArguments())) +// if err != nil { +// return diag.FromErr(err) +// } +// +// d.SetId(newId.FullyQualifiedName()) +// id = newId +// } +// +// if d.HasChange("comment") { +// comment := d.Get("comment") +// if comment != "" { +// if err := client.Procedures.Alter(ctx, sdk.NewAlterProcedureRequest(id.WithoutArguments(), id.Arguments()).WithSetComment(comment.(string))); err != nil { +// return diag.FromErr(err) +// } +// } else { +// if err := client.Procedures.Alter(ctx, sdk.NewAlterProcedureRequest(id.WithoutArguments(), id.Arguments()).WithUnsetComment(true)); err != nil { +// return diag.FromErr(err) +// } +// } +// } +// +// if d.HasChange("execute_as") { +// req := sdk.NewAlterProcedureRequest(id.WithoutArguments(), id.Arguments()) +// executeAs := d.Get("execute_as").(string) +// if strings.ToUpper(executeAs) == "OWNER" { +// req.WithExecuteAs(sdk.ExecuteAsOwner) +// } else if strings.ToUpper(executeAs) == "CALLER" { +// req.WithExecuteAs(sdk.ExecuteAsCaller) +// } +// if err := client.Procedures.Alter(ctx, req); err != nil { +// return diag.FromErr(err) +// } +// } +// +// return ReadContextProcedure(ctx, d, meta) +//} +// +//func DeleteContextProcedure(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { +// client := meta.(*provider.Context).Client +// +// id := sdk.NewSchemaObjectIdentifierFromFullyQualifiedName(d.Id()) +// if err := client.Procedures.Drop(ctx, sdk.NewDropProcedureRequest(id.WithoutArguments(), id.Arguments())); err != nil { +// return diag.FromErr(err) +// } +// d.SetId("") +// return nil +//} +// +//func getProcedureArguments(d *schema.ResourceData) ([]sdk.ProcedureArgumentRequest, diag.Diagnostics) { +// args := make([]sdk.ProcedureArgumentRequest, 0) +// if v, ok := d.GetOk("arguments"); ok { +// for _, arg := range v.([]interface{}) { +// argName := arg.(map[string]interface{})["name"].(string) +// argType := arg.(map[string]interface{})["type"].(string) +// argDataType, diags := convertProcedureDataType(argType) +// if diags != nil { +// return nil, diags +// } +// args = append(args, sdk.ProcedureArgumentRequest{ArgName: argName, ArgDataType: argDataType}) +// } +// } +// return args, nil +//} +// +//func convertProcedureDataType(s string) (sdk.DataType, diag.Diagnostics) { +// dataType, err := sdk.ToDataType(s) +// if err != nil { +// return dataType, diag.FromErr(err) +// } +// return dataType, nil +//} +// +//func convertProcedureColumns(s string) ([]sdk.ProcedureColumn, diag.Diagnostics) { +// pattern := regexp.MustCompile(`(\w+)\s+(\w+)`) +// matches := pattern.FindAllStringSubmatch(s, -1) +// var columns []sdk.ProcedureColumn +// for _, match := range matches { +// if len(match) == 3 { +// dataType, err := sdk.ToDataType(match[2]) +// if err != nil { +// return nil, diag.FromErr(err) +// } +// columns = append(columns, sdk.ProcedureColumn{ +// ColumnName: match[1], +// ColumnDataType: dataType, +// }) +// } +// } +// return columns, nil +//} +// +//func parseProcedureReturnsRequest(s string) (*sdk.ProcedureReturnsRequest, diag.Diagnostics) { +// returns := sdk.NewProcedureReturnsRequest() +// if strings.HasPrefix(strings.ToLower(s), "table") { +// columns, diags := convertProcedureColumns(s) +// if diags != nil { +// return nil, diags +// } +// var cr []sdk.ProcedureColumnRequest +// for _, item := range columns { +// cr = append(cr, *sdk.NewProcedureColumnRequest(item.ColumnName, item.ColumnDataType)) +// } +// returns.WithTable(sdk.NewProcedureReturnsTableRequest().WithColumns(cr)) +// } else { +// returnDataType, diags := convertProcedureDataType(s) +// if diags != nil { +// return nil, diags +// } +// returns.WithResultDataType(*sdk.NewProcedureReturnsResultDataTypeRequest(returnDataType)) +// } +// return returns, nil +//} +// +//func parseProcedureSQLReturnsRequest(s string) (*sdk.ProcedureSQLReturnsRequest, diag.Diagnostics) { +// returns := sdk.NewProcedureSQLReturnsRequest() +// if strings.HasPrefix(strings.ToLower(s), "table") { +// columns, diags := convertProcedureColumns(s) +// if diags != nil { +// return nil, diags +// } +// var cr []sdk.ProcedureColumnRequest +// for _, item := range columns { +// cr = append(cr, *sdk.NewProcedureColumnRequest(item.ColumnName, item.ColumnDataType)) +// } +// returns.WithTable(*sdk.NewProcedureReturnsTableRequest().WithColumns(cr)) +// } else { +// returnDataType, diags := convertProcedureDataType(s) +// if diags != nil { +// return nil, diags +// } +// returns.WithResultDataType(*sdk.NewProcedureReturnsResultDataTypeRequest(returnDataType)) +// } +// return returns, nil +//} diff --git a/pkg/resources/procedure_acceptance_test.go b/pkg/resources/procedure_acceptance_test.go index 29e89243c1..b601583a9a 100644 --- a/pkg/resources/procedure_acceptance_test.go +++ b/pkg/resources/procedure_acceptance_test.go @@ -257,7 +257,7 @@ resource "snowflake_procedure" "p" { } func TestAcc_Procedure_proveArgsPermanentDiff(t *testing.T) { - id := acc.TestClient().Ids.RandomSchemaObjectIdentifierWithArgumentsOld(sdk.DataTypeVARCHAR, sdk.DataTypeNumber) + id := acc.TestClient().Ids.RandomSchemaObjectIdentifierWithArguments(sdk.DataTypeVARCHAR, sdk.DataTypeNumber) name := id.Name() resourceName := "snowflake_procedure.p" diff --git a/pkg/sdk/external_functions_impl_gen.go b/pkg/sdk/external_functions_impl_gen.go index fe81d14f56..2eac8e7c0f 100644 --- a/pkg/sdk/external_functions_impl_gen.go +++ b/pkg/sdk/external_functions_impl_gen.go @@ -2,9 +2,6 @@ package sdk import ( "context" - "strings" - - "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/internal/collections" ) var _ ExternalFunctions = (*externalFunctions)(nil) @@ -34,30 +31,32 @@ func (v *externalFunctions) Show(ctx context.Context, request *ShowExternalFunct } func (v *externalFunctions) ShowByID(ctx context.Context, id SchemaObjectIdentifier) (*ExternalFunction, error) { - arguments := id.Arguments() - externalFunctions, err := v.Show(ctx, NewShowExternalFunctionRequest(). - WithIn(&In{Schema: id.SchemaId()}). - WithLike(&Like{Pattern: String(id.Name())})) - if err != nil { - return nil, err - } - return collections.FindOne(externalFunctions, func(r ExternalFunction) bool { - database := strings.Trim(r.CatalogName, `"`) - schema := strings.Trim(r.SchemaName, `"`) - if r.Name != id.Name() || database != id.DatabaseName() || schema != id.SchemaName() { - return false - } - var sb strings.Builder - sb.WriteString("(") - for i, argument := range arguments { - sb.WriteString(string(argument)) - if i < len(arguments)-1 { - sb.WriteString(", ") - } - } - sb.WriteString(")") - return strings.Contains(r.Arguments, sb.String()) - }) + return nil, nil + // TODO + //arguments := id.Arguments() + //externalFunctions, err := v.Show(ctx, NewShowExternalFunctionRequest(). + // WithIn(&In{Schema: id.SchemaId()}). + // WithLike(&Like{Pattern: String(id.Name())})) + //if err != nil { + // return nil, err + //} + //return collections.FindOne(externalFunctions, func(r ExternalFunction) bool { + // database := strings.Trim(r.CatalogName, `"`) + // schema := strings.Trim(r.SchemaName, `"`) + // if r.Name != id.Name() || database != id.DatabaseName() || schema != id.SchemaName() { + // return false + // } + // var sb strings.Builder + // sb.WriteString("(") + // for i, argument := range arguments { + // sb.WriteString(string(argument)) + // if i < len(arguments)-1 { + // sb.WriteString(", ") + // } + // } + // sb.WriteString(")") + // return strings.Contains(r.Arguments, sb.String()) + //}) } func (v *externalFunctions) Describe(ctx context.Context, request *DescribeExternalFunctionRequest) ([]ExternalFunctionProperty, error) { @@ -112,9 +111,8 @@ func (r *CreateExternalFunctionRequest) toOpts() *CreateExternalFunctionOptions func (r *AlterExternalFunctionRequest) toOpts() *AlterExternalFunctionOptions { opts := &AlterExternalFunctionOptions{ - IfExists: r.IfExists, - name: r.name.WithoutArguments(), - ArgumentDataTypes: r.ArgumentDataTypes, + IfExists: r.IfExists, + name: r.name, } if r.Set != nil { opts.Set = &ExternalFunctionSet{ diff --git a/pkg/sdk/functions_gen_test.go b/pkg/sdk/functions_gen_test.go index 0bb4778832..5234a1fd7f 100644 --- a/pkg/sdk/functions_gen_test.go +++ b/pkg/sdk/functions_gen_test.go @@ -431,8 +431,7 @@ func TestFunctions_CreateForSQL(t *testing.T) { } func TestFunctions_Drop(t *testing.T) { - noArgsId := randomSchemaObjectIdentifierWithArguments() - id := randomSchemaObjectIdentifierWithArguments(DataTypeVARCHAR, DataTypeNumber) + id := randomSchemaObjectIdentifier() defaultOpts := func() *DropFunctionOptions { return &DropFunctionOptions{ @@ -467,8 +466,7 @@ func TestFunctions_Drop(t *testing.T) { } func TestFunctions_Alter(t *testing.T) { - id := randomSchemaObjectIdentifierWithArguments(DataTypeVARCHAR, DataTypeNumber) - noArgsId := randomSchemaObjectIdentifierWithArguments() + id := randomSchemaObjectIdentifier() defaultOpts := func() *AlterFunctionOptions { return &AlterFunctionOptions{ @@ -616,8 +614,7 @@ func TestFunctions_Show(t *testing.T) { } func TestFunctions_Describe(t *testing.T) { - id := randomSchemaObjectIdentifierWithArguments(DataTypeVARCHAR, DataTypeNumber) - noArgsId := randomSchemaObjectIdentifierWithArguments() + id := randomSchemaObjectIdentifier() defaultOpts := func() *DescribeFunctionOptions { return &DescribeFunctionOptions{ diff --git a/pkg/sdk/functions_impl_gen.go b/pkg/sdk/functions_impl_gen.go index f2123639e0..7cbe401daf 100644 --- a/pkg/sdk/functions_impl_gen.go +++ b/pkg/sdk/functions_impl_gen.go @@ -6,6 +6,8 @@ import ( "strings" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/internal/collections" + "log" + "strings" ) var _ Functions = (*functions)(nil) diff --git a/pkg/sdk/identifier_helpers.go b/pkg/sdk/identifier_helpers.go index 8bc4a33852..c8d7a127d1 100644 --- a/pkg/sdk/identifier_helpers.go +++ b/pkg/sdk/identifier_helpers.go @@ -232,7 +232,7 @@ type SchemaObjectIdentifier struct { databaseName string schemaName string name string - // TODO(next prs): left right now for backward compatibility for procedures and externalFunctions + // TODO(next prs ???): left right now for backward compatibility for procedures and externalFunctions arguments []DataType } @@ -343,15 +343,21 @@ type SchemaObjectIdentifierWithArguments struct { func NewSchemaObjectIdentifierWithArguments(databaseName, schemaName, name string, argumentDataTypes ...DataType) SchemaObjectIdentifierWithArguments { return SchemaObjectIdentifierWithArguments{ - databaseName: strings.Trim(databaseName, `"`), - schemaName: strings.Trim(schemaName, `"`), - name: strings.Trim(name, `"`), - argumentDataTypes: argumentDataTypes, + databaseName: strings.Trim(databaseName, `"`), + schemaName: strings.Trim(schemaName, `"`), + name: strings.Trim(name, `"`), } } -func NewSchemaObjectIdentifierWithArgumentsInSchema(schemaId DatabaseObjectIdentifier, name string, argumentDataTypes ...DataType) SchemaObjectIdentifierWithArguments { - return NewSchemaObjectIdentifierWithArguments(schemaId.DatabaseName(), schemaId.Name(), name, argumentDataTypes...) +func NewSchemaObjectIdentifierWithArgumentsFromFullyQualifiedName(fullyQualifiedName string) SchemaObjectIdentifierWithArguments { + parts := strings.Split(fullyQualifiedName, ".") + id := SchemaObjectIdentifierWithArguments{ + databaseName: strings.Trim(parts[0], `"`), + schemaName: strings.Trim(parts[1], `"`), + name: strings.Trim(parts[2], `"`), + // TODO: Arguments + } + return id } func (i SchemaObjectIdentifierWithArguments) DatabaseName() string { @@ -386,7 +392,7 @@ func (i SchemaObjectIdentifierWithArguments) FullyQualifiedName() string { if i.schemaName == "" && i.databaseName == "" && i.name == "" && len(i.argumentDataTypes) == 0 { return "" } - return fmt.Sprintf(`"%v"."%v"."%v"(%v)`, i.databaseName, i.schemaName, i.name, strings.Join(AsStringList(i.argumentDataTypes), ",")) + return fmt.Sprintf(`"%v"."%v"."%v"(%v)`, i.databaseName, i.schemaName, i.name, strings.Join(i.arguments, ",")) } type TableColumnIdentifier struct { diff --git a/pkg/sdk/identifier_helpers_test.go b/pkg/sdk/identifier_helpers_test.go index d3b0a98bd3..2f409cc84d 100644 --- a/pkg/sdk/identifier_helpers_test.go +++ b/pkg/sdk/identifier_helpers_test.go @@ -1,6 +1,7 @@ package sdk import ( + "fmt" "testing" "github.com/stretchr/testify/assert" @@ -82,3 +83,42 @@ func TestDatabaseObjectIdentifier(t *testing.T) { assert.Equal(t, `"aaa"."bbb"`, identifier.FullyQualifiedName()) }) } + +func TestNewSchemaObjectIdentifierWithArgumentsFromFullyQualifiedName(t *testing.T) { + testCases := []struct { + RawInput string + Input SchemaObjectIdentifierWithArguments + Error string + }{ + {Input: NewSchemaObjectIdentifierWithArguments(`abc`, `def`, `ghi`, DataTypeFloat, DataTypeNumber, DataTypeTimestampTZ)}, + {Input: NewSchemaObjectIdentifierWithArguments(`abc`, `def`, `ghi`, DataTypeFloat, "VECTOR(INT, 20)")}, + {Input: NewSchemaObjectIdentifierWithArguments(`abc`, `def`, `ghi`, "VECTOR(INT, 20)", DataTypeFloat)}, + {Input: NewSchemaObjectIdentifierWithArguments(`abc`, `def`, `ghi`, DataTypeFloat, "VECTOR(INT, 20)", "VECTOR(INT, 10)")}, + {Input: NewSchemaObjectIdentifierWithArguments(`abc`, `def`, `ghi`, DataTypeTime, "VECTOR(INT, 20)", "VECTOR(FLOAT, 10)", DataTypeFloat)}, + // TODO(): Won't work, because of the assumption that identifiers are not containing '(' and ')' parentheses + {Input: NewSchemaObjectIdentifierWithArguments(`ab()c`, `def()`, `()ghi`, DataTypeTime, "VECTOR(INT, 20)", "VECTOR(FLOAT, 10)", DataTypeFloat), Error: `unable to read identifier: "ab`}, + {Input: NewSchemaObjectIdentifierWithArguments(`ab(,)c`, `,def()`, `()ghi,`, DataTypeTime, "VECTOR(INT, 20)", "VECTOR(FLOAT, 10)", DataTypeFloat), Error: `unable to read identifier: "ab`}, + {Input: NewSchemaObjectIdentifierWithArguments(`abc`, `def`, `ghi`)}, + {Input: NewSchemaObjectIdentifierWithArguments(`abc`, `def`, `ghi`), RawInput: `abc.def.ghi()`}, + {Input: NewSchemaObjectIdentifierWithArguments(`abc`, `def`, `ghi`, DataTypeFloat, "VECTOR(INT, 20)"), RawInput: `abc.def.ghi(FLOAT, VECTOR(INT, 20))`}, + } + + for _, testCase := range testCases { + t.Run(fmt.Sprintf("processing %s", testCase.Input.FullyQualifiedName()), func(t *testing.T) { + var id SchemaObjectIdentifierWithArguments + var err error + if testCase.RawInput != "" { + id, err = NewSchemaObjectIdentifierWithArgumentsFromFullyQualifiedName(testCase.RawInput) + } else { + id, err = NewSchemaObjectIdentifierWithArgumentsFromFullyQualifiedName(testCase.Input.FullyQualifiedName()) + } + + if testCase.Error != "" { + assert.ErrorContains(t, err, testCase.Error) + } else { + assert.NoError(t, err) + assert.Equal(t, testCase.Input.FullyQualifiedName(), id.FullyQualifiedName()) + } + }) + } +} diff --git a/pkg/sdk/poc/main.go b/pkg/sdk/poc/main.go index f8f1014bdb..7f79ec3361 100644 --- a/pkg/sdk/poc/main.go +++ b/pkg/sdk/poc/main.go @@ -5,6 +5,7 @@ package main import ( "bytes" "fmt" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/genhelpers" "io" "log" "os" diff --git a/pkg/sdk/procedures_def.go b/pkg/sdk/procedures_def.go index 199eea62dd..1788c99a44 100644 --- a/pkg/sdk/procedures_def.go +++ b/pkg/sdk/procedures_def.go @@ -66,7 +66,7 @@ var procedureWithClause = g.NewQueryStruct("ProcedureWithClause"). var ProceduresDef = g.NewInterface( "Procedures", "Procedure", - g.KindOfT[SchemaObjectIdentifier](), + g.KindOfT[SchemaObjectIdentifierWithArguments](), ).CustomOperation( "CreateForJava", "https://docs.snowflake.com/en/sql-reference/sql/create-procedure#java-handler", @@ -75,7 +75,7 @@ var ProceduresDef = g.NewInterface( OrReplace(). OptionalSQL("SECURE"). SQL("PROCEDURE"). - Name(). + Identifier("name", g.KindOfT[SchemaObjectIdentifier](), g.IdentifierOptions().Required()). ListQueryStructField( "Arguments", procedureArgument, @@ -119,7 +119,7 @@ var ProceduresDef = g.NewInterface( OrReplace(). OptionalSQL("SECURE"). SQL("PROCEDURE"). - Name(). + Identifier("name", g.KindOfT[SchemaObjectIdentifier](), g.IdentifierOptions().Required()). ListQueryStructField( "Arguments", procedureArgument, @@ -143,7 +143,7 @@ var ProceduresDef = g.NewInterface( OrReplace(). OptionalSQL("SECURE"). SQL("PROCEDURE"). - Name(). + Identifier("name", g.KindOfT[SchemaObjectIdentifier](), g.IdentifierOptions().Required()). ListQueryStructField( "Arguments", procedureArgument, @@ -186,7 +186,7 @@ var ProceduresDef = g.NewInterface( OrReplace(). OptionalSQL("SECURE"). SQL("PROCEDURE"). - Name(). + Identifier("name", g.KindOfT[SchemaObjectIdentifier](), g.IdentifierOptions().Required()). ListQueryStructField( "Arguments", procedureArgument, @@ -228,7 +228,7 @@ var ProceduresDef = g.NewInterface( OrReplace(). OptionalSQL("SECURE"). SQL("PROCEDURE"). - Name(). + Identifier("name", g.KindOfT[SchemaObjectIdentifier](), g.IdentifierOptions().Required()). ListQueryStructField( "Arguments", procedureArgument, @@ -254,7 +254,6 @@ var ProceduresDef = g.NewInterface( SQL("PROCEDURE"). IfExists(). Name(). - PredefinedQueryStructField("ArgumentDataTypes", "[]DataType", g.KeywordOptions().MustParentheses().Required()). OptionalIdentifier("RenameTo", g.KindOfT[SchemaObjectIdentifier](), g.IdentifierOptions().SQL("RENAME TO")). OptionalTextAssignment("SET COMMENT", g.ParameterOptions().SingleQuotes()). OptionalTextAssignment("SET LOG_LEVEL", g.ParameterOptions().SingleQuotes()). @@ -273,7 +272,6 @@ var ProceduresDef = g.NewInterface( SQL("PROCEDURE"). IfExists(). Name(). - PredefinedQueryStructField("ArgumentDataTypes", "[]DataType", g.KeywordOptions().MustParentheses().Required()). WithValidation(g.ValidIdentifier, "name"), ).ShowOperation( "https://docs.snowflake.com/en/sql-reference/sql/show-procedures", @@ -325,14 +323,13 @@ var ProceduresDef = g.NewInterface( Describe(). SQL("PROCEDURE"). Name(). - PredefinedQueryStructField("ArgumentDataTypes", "[]DataType", g.KeywordOptions().MustParentheses().Required()). WithValidation(g.ValidIdentifier, "name"), ).CustomOperation( "Call", "https://docs.snowflake.com/en/sql-reference/sql/call", g.NewQueryStruct("Call"). SQL("CALL"). - Name(). + Identifier("name", g.KindOfT[AccountObjectIdentifier](), g.IdentifierOptions().Required()). PredefinedQueryStructField("CallArguments", "[]string", g.KeywordOptions().MustParentheses()). PredefinedQueryStructField("ScriptingVariable", "*string", g.ParameterOptions().NoEquals().NoQuotes().SQL("INTO")). WithValidation(g.ValidIdentifier, "name"), diff --git a/pkg/sdk/procedures_dto_builders_gen.go b/pkg/sdk/procedures_dto_builders_gen.go index 208aaceecd..75671ecf4e 100644 --- a/pkg/sdk/procedures_dto_builders_gen.go +++ b/pkg/sdk/procedures_dto_builders_gen.go @@ -20,13 +20,13 @@ func NewCreateForJavaProcedureRequest( return &s } -func (s *CreateForJavaProcedureRequest) WithOrReplace(OrReplace *bool) *CreateForJavaProcedureRequest { - s.OrReplace = OrReplace +func (s *CreateForJavaProcedureRequest) WithOrReplace(OrReplace bool) *CreateForJavaProcedureRequest { + s.OrReplace = &OrReplace return s } -func (s *CreateForJavaProcedureRequest) WithSecure(Secure *bool) *CreateForJavaProcedureRequest { - s.Secure = Secure +func (s *CreateForJavaProcedureRequest) WithSecure(Secure bool) *CreateForJavaProcedureRequest { + s.Secure = &Secure return s } @@ -35,8 +35,8 @@ func (s *CreateForJavaProcedureRequest) WithArguments(Arguments []ProcedureArgum return s } -func (s *CreateForJavaProcedureRequest) WithCopyGrants(CopyGrants *bool) *CreateForJavaProcedureRequest { - s.CopyGrants = CopyGrants +func (s *CreateForJavaProcedureRequest) WithCopyGrants(CopyGrants bool) *CreateForJavaProcedureRequest { + s.CopyGrants = &CopyGrants return s } @@ -55,28 +55,28 @@ func (s *CreateForJavaProcedureRequest) WithSecrets(Secrets []Secret) *CreateFor return s } -func (s *CreateForJavaProcedureRequest) WithTargetPath(TargetPath *string) *CreateForJavaProcedureRequest { - s.TargetPath = TargetPath +func (s *CreateForJavaProcedureRequest) WithTargetPath(TargetPath string) *CreateForJavaProcedureRequest { + s.TargetPath = &TargetPath return s } -func (s *CreateForJavaProcedureRequest) WithNullInputBehavior(NullInputBehavior *NullInputBehavior) *CreateForJavaProcedureRequest { - s.NullInputBehavior = NullInputBehavior +func (s *CreateForJavaProcedureRequest) WithNullInputBehavior(NullInputBehavior NullInputBehavior) *CreateForJavaProcedureRequest { + s.NullInputBehavior = &NullInputBehavior return s } -func (s *CreateForJavaProcedureRequest) WithComment(Comment *string) *CreateForJavaProcedureRequest { - s.Comment = Comment +func (s *CreateForJavaProcedureRequest) WithComment(Comment string) *CreateForJavaProcedureRequest { + s.Comment = &Comment return s } -func (s *CreateForJavaProcedureRequest) WithExecuteAs(ExecuteAs *ExecuteAs) *CreateForJavaProcedureRequest { - s.ExecuteAs = ExecuteAs +func (s *CreateForJavaProcedureRequest) WithExecuteAs(ExecuteAs ExecuteAs) *CreateForJavaProcedureRequest { + s.ExecuteAs = &ExecuteAs return s } -func (s *CreateForJavaProcedureRequest) WithProcedureDefinition(ProcedureDefinition *string) *CreateForJavaProcedureRequest { - s.ProcedureDefinition = ProcedureDefinition +func (s *CreateForJavaProcedureRequest) WithProcedureDefinition(ProcedureDefinition string) *CreateForJavaProcedureRequest { + s.ProcedureDefinition = &ProcedureDefinition return s } @@ -90,8 +90,8 @@ func NewProcedureArgumentRequest( return &s } -func (s *ProcedureArgumentRequest) WithDefaultValue(DefaultValue *string) *ProcedureArgumentRequest { - s.DefaultValue = DefaultValue +func (s *ProcedureArgumentRequest) WithDefaultValue(DefaultValue string) *ProcedureArgumentRequest { + s.DefaultValue = &DefaultValue return s } @@ -99,13 +99,13 @@ func NewProcedureReturnsRequest() *ProcedureReturnsRequest { return &ProcedureReturnsRequest{} } -func (s *ProcedureReturnsRequest) WithResultDataType(ResultDataType *ProcedureReturnsResultDataTypeRequest) *ProcedureReturnsRequest { - s.ResultDataType = ResultDataType +func (s *ProcedureReturnsRequest) WithResultDataType(ResultDataType ProcedureReturnsResultDataTypeRequest) *ProcedureReturnsRequest { + s.ResultDataType = &ResultDataType return s } -func (s *ProcedureReturnsRequest) WithTable(Table *ProcedureReturnsTableRequest) *ProcedureReturnsRequest { - s.Table = Table +func (s *ProcedureReturnsRequest) WithTable(Table ProcedureReturnsTableRequest) *ProcedureReturnsRequest { + s.Table = &Table return s } @@ -117,13 +117,13 @@ func NewProcedureReturnsResultDataTypeRequest( return &s } -func (s *ProcedureReturnsResultDataTypeRequest) WithNull(Null *bool) *ProcedureReturnsResultDataTypeRequest { - s.Null = Null +func (s *ProcedureReturnsResultDataTypeRequest) WithNull(Null bool) *ProcedureReturnsResultDataTypeRequest { + s.Null = &Null return s } -func (s *ProcedureReturnsResultDataTypeRequest) WithNotNull(NotNull *bool) *ProcedureReturnsResultDataTypeRequest { - s.NotNull = NotNull +func (s *ProcedureReturnsResultDataTypeRequest) WithNotNull(NotNull bool) *ProcedureReturnsResultDataTypeRequest { + s.NotNull = &NotNull return s } @@ -174,13 +174,13 @@ func NewCreateForJavaScriptProcedureRequest( return &s } -func (s *CreateForJavaScriptProcedureRequest) WithOrReplace(OrReplace *bool) *CreateForJavaScriptProcedureRequest { - s.OrReplace = OrReplace +func (s *CreateForJavaScriptProcedureRequest) WithOrReplace(OrReplace bool) *CreateForJavaScriptProcedureRequest { + s.OrReplace = &OrReplace return s } -func (s *CreateForJavaScriptProcedureRequest) WithSecure(Secure *bool) *CreateForJavaScriptProcedureRequest { - s.Secure = Secure +func (s *CreateForJavaScriptProcedureRequest) WithSecure(Secure bool) *CreateForJavaScriptProcedureRequest { + s.Secure = &Secure return s } @@ -189,28 +189,28 @@ func (s *CreateForJavaScriptProcedureRequest) WithArguments(Arguments []Procedur return s } -func (s *CreateForJavaScriptProcedureRequest) WithCopyGrants(CopyGrants *bool) *CreateForJavaScriptProcedureRequest { - s.CopyGrants = CopyGrants +func (s *CreateForJavaScriptProcedureRequest) WithCopyGrants(CopyGrants bool) *CreateForJavaScriptProcedureRequest { + s.CopyGrants = &CopyGrants return s } -func (s *CreateForJavaScriptProcedureRequest) WithNotNull(NotNull *bool) *CreateForJavaScriptProcedureRequest { - s.NotNull = NotNull +func (s *CreateForJavaScriptProcedureRequest) WithNotNull(NotNull bool) *CreateForJavaScriptProcedureRequest { + s.NotNull = &NotNull return s } -func (s *CreateForJavaScriptProcedureRequest) WithNullInputBehavior(NullInputBehavior *NullInputBehavior) *CreateForJavaScriptProcedureRequest { - s.NullInputBehavior = NullInputBehavior +func (s *CreateForJavaScriptProcedureRequest) WithNullInputBehavior(NullInputBehavior NullInputBehavior) *CreateForJavaScriptProcedureRequest { + s.NullInputBehavior = &NullInputBehavior return s } -func (s *CreateForJavaScriptProcedureRequest) WithComment(Comment *string) *CreateForJavaScriptProcedureRequest { - s.Comment = Comment +func (s *CreateForJavaScriptProcedureRequest) WithComment(Comment string) *CreateForJavaScriptProcedureRequest { + s.Comment = &Comment return s } -func (s *CreateForJavaScriptProcedureRequest) WithExecuteAs(ExecuteAs *ExecuteAs) *CreateForJavaScriptProcedureRequest { - s.ExecuteAs = ExecuteAs +func (s *CreateForJavaScriptProcedureRequest) WithExecuteAs(ExecuteAs ExecuteAs) *CreateForJavaScriptProcedureRequest { + s.ExecuteAs = &ExecuteAs return s } @@ -230,13 +230,13 @@ func NewCreateForPythonProcedureRequest( return &s } -func (s *CreateForPythonProcedureRequest) WithOrReplace(OrReplace *bool) *CreateForPythonProcedureRequest { - s.OrReplace = OrReplace +func (s *CreateForPythonProcedureRequest) WithOrReplace(OrReplace bool) *CreateForPythonProcedureRequest { + s.OrReplace = &OrReplace return s } -func (s *CreateForPythonProcedureRequest) WithSecure(Secure *bool) *CreateForPythonProcedureRequest { - s.Secure = Secure +func (s *CreateForPythonProcedureRequest) WithSecure(Secure bool) *CreateForPythonProcedureRequest { + s.Secure = &Secure return s } @@ -245,8 +245,8 @@ func (s *CreateForPythonProcedureRequest) WithArguments(Arguments []ProcedureArg return s } -func (s *CreateForPythonProcedureRequest) WithCopyGrants(CopyGrants *bool) *CreateForPythonProcedureRequest { - s.CopyGrants = CopyGrants +func (s *CreateForPythonProcedureRequest) WithCopyGrants(CopyGrants bool) *CreateForPythonProcedureRequest { + s.CopyGrants = &CopyGrants return s } @@ -265,23 +265,23 @@ func (s *CreateForPythonProcedureRequest) WithSecrets(Secrets []Secret) *CreateF return s } -func (s *CreateForPythonProcedureRequest) WithNullInputBehavior(NullInputBehavior *NullInputBehavior) *CreateForPythonProcedureRequest { - s.NullInputBehavior = NullInputBehavior +func (s *CreateForPythonProcedureRequest) WithNullInputBehavior(NullInputBehavior NullInputBehavior) *CreateForPythonProcedureRequest { + s.NullInputBehavior = &NullInputBehavior return s } -func (s *CreateForPythonProcedureRequest) WithComment(Comment *string) *CreateForPythonProcedureRequest { - s.Comment = Comment +func (s *CreateForPythonProcedureRequest) WithComment(Comment string) *CreateForPythonProcedureRequest { + s.Comment = &Comment return s } -func (s *CreateForPythonProcedureRequest) WithExecuteAs(ExecuteAs *ExecuteAs) *CreateForPythonProcedureRequest { - s.ExecuteAs = ExecuteAs +func (s *CreateForPythonProcedureRequest) WithExecuteAs(ExecuteAs ExecuteAs) *CreateForPythonProcedureRequest { + s.ExecuteAs = &ExecuteAs return s } -func (s *CreateForPythonProcedureRequest) WithProcedureDefinition(ProcedureDefinition *string) *CreateForPythonProcedureRequest { - s.ProcedureDefinition = ProcedureDefinition +func (s *CreateForPythonProcedureRequest) WithProcedureDefinition(ProcedureDefinition string) *CreateForPythonProcedureRequest { + s.ProcedureDefinition = &ProcedureDefinition return s } @@ -301,13 +301,13 @@ func NewCreateForScalaProcedureRequest( return &s } -func (s *CreateForScalaProcedureRequest) WithOrReplace(OrReplace *bool) *CreateForScalaProcedureRequest { - s.OrReplace = OrReplace +func (s *CreateForScalaProcedureRequest) WithOrReplace(OrReplace bool) *CreateForScalaProcedureRequest { + s.OrReplace = &OrReplace return s } -func (s *CreateForScalaProcedureRequest) WithSecure(Secure *bool) *CreateForScalaProcedureRequest { - s.Secure = Secure +func (s *CreateForScalaProcedureRequest) WithSecure(Secure bool) *CreateForScalaProcedureRequest { + s.Secure = &Secure return s } @@ -316,8 +316,8 @@ func (s *CreateForScalaProcedureRequest) WithArguments(Arguments []ProcedureArgu return s } -func (s *CreateForScalaProcedureRequest) WithCopyGrants(CopyGrants *bool) *CreateForScalaProcedureRequest { - s.CopyGrants = CopyGrants +func (s *CreateForScalaProcedureRequest) WithCopyGrants(CopyGrants bool) *CreateForScalaProcedureRequest { + s.CopyGrants = &CopyGrants return s } @@ -326,28 +326,28 @@ func (s *CreateForScalaProcedureRequest) WithImports(Imports []ProcedureImportRe return s } -func (s *CreateForScalaProcedureRequest) WithTargetPath(TargetPath *string) *CreateForScalaProcedureRequest { - s.TargetPath = TargetPath +func (s *CreateForScalaProcedureRequest) WithTargetPath(TargetPath string) *CreateForScalaProcedureRequest { + s.TargetPath = &TargetPath return s } -func (s *CreateForScalaProcedureRequest) WithNullInputBehavior(NullInputBehavior *NullInputBehavior) *CreateForScalaProcedureRequest { - s.NullInputBehavior = NullInputBehavior +func (s *CreateForScalaProcedureRequest) WithNullInputBehavior(NullInputBehavior NullInputBehavior) *CreateForScalaProcedureRequest { + s.NullInputBehavior = &NullInputBehavior return s } -func (s *CreateForScalaProcedureRequest) WithComment(Comment *string) *CreateForScalaProcedureRequest { - s.Comment = Comment +func (s *CreateForScalaProcedureRequest) WithComment(Comment string) *CreateForScalaProcedureRequest { + s.Comment = &Comment return s } -func (s *CreateForScalaProcedureRequest) WithExecuteAs(ExecuteAs *ExecuteAs) *CreateForScalaProcedureRequest { - s.ExecuteAs = ExecuteAs +func (s *CreateForScalaProcedureRequest) WithExecuteAs(ExecuteAs ExecuteAs) *CreateForScalaProcedureRequest { + s.ExecuteAs = &ExecuteAs return s } -func (s *CreateForScalaProcedureRequest) WithProcedureDefinition(ProcedureDefinition *string) *CreateForScalaProcedureRequest { - s.ProcedureDefinition = ProcedureDefinition +func (s *CreateForScalaProcedureRequest) WithProcedureDefinition(ProcedureDefinition string) *CreateForScalaProcedureRequest { + s.ProcedureDefinition = &ProcedureDefinition return s } @@ -363,13 +363,13 @@ func NewCreateForSQLProcedureRequest( return &s } -func (s *CreateForSQLProcedureRequest) WithOrReplace(OrReplace *bool) *CreateForSQLProcedureRequest { - s.OrReplace = OrReplace +func (s *CreateForSQLProcedureRequest) WithOrReplace(OrReplace bool) *CreateForSQLProcedureRequest { + s.OrReplace = &OrReplace return s } -func (s *CreateForSQLProcedureRequest) WithSecure(Secure *bool) *CreateForSQLProcedureRequest { - s.Secure = Secure +func (s *CreateForSQLProcedureRequest) WithSecure(Secure bool) *CreateForSQLProcedureRequest { + s.Secure = &Secure return s } @@ -378,23 +378,23 @@ func (s *CreateForSQLProcedureRequest) WithArguments(Arguments []ProcedureArgume return s } -func (s *CreateForSQLProcedureRequest) WithCopyGrants(CopyGrants *bool) *CreateForSQLProcedureRequest { - s.CopyGrants = CopyGrants +func (s *CreateForSQLProcedureRequest) WithCopyGrants(CopyGrants bool) *CreateForSQLProcedureRequest { + s.CopyGrants = &CopyGrants return s } -func (s *CreateForSQLProcedureRequest) WithNullInputBehavior(NullInputBehavior *NullInputBehavior) *CreateForSQLProcedureRequest { - s.NullInputBehavior = NullInputBehavior +func (s *CreateForSQLProcedureRequest) WithNullInputBehavior(NullInputBehavior NullInputBehavior) *CreateForSQLProcedureRequest { + s.NullInputBehavior = &NullInputBehavior return s } -func (s *CreateForSQLProcedureRequest) WithComment(Comment *string) *CreateForSQLProcedureRequest { - s.Comment = Comment +func (s *CreateForSQLProcedureRequest) WithComment(Comment string) *CreateForSQLProcedureRequest { + s.Comment = &Comment return s } -func (s *CreateForSQLProcedureRequest) WithExecuteAs(ExecuteAs *ExecuteAs) *CreateForSQLProcedureRequest { - s.ExecuteAs = ExecuteAs +func (s *CreateForSQLProcedureRequest) WithExecuteAs(ExecuteAs ExecuteAs) *CreateForSQLProcedureRequest { + s.ExecuteAs = &ExecuteAs return s } @@ -402,58 +402,56 @@ func NewProcedureSQLReturnsRequest() *ProcedureSQLReturnsRequest { return &ProcedureSQLReturnsRequest{} } -func (s *ProcedureSQLReturnsRequest) WithResultDataType(ResultDataType *ProcedureReturnsResultDataTypeRequest) *ProcedureSQLReturnsRequest { - s.ResultDataType = ResultDataType +func (s *ProcedureSQLReturnsRequest) WithResultDataType(ResultDataType ProcedureReturnsResultDataTypeRequest) *ProcedureSQLReturnsRequest { + s.ResultDataType = &ResultDataType return s } -func (s *ProcedureSQLReturnsRequest) WithTable(Table *ProcedureReturnsTableRequest) *ProcedureSQLReturnsRequest { - s.Table = Table +func (s *ProcedureSQLReturnsRequest) WithTable(Table ProcedureReturnsTableRequest) *ProcedureSQLReturnsRequest { + s.Table = &Table return s } -func (s *ProcedureSQLReturnsRequest) WithNotNull(NotNull *bool) *ProcedureSQLReturnsRequest { - s.NotNull = NotNull +func (s *ProcedureSQLReturnsRequest) WithNotNull(NotNull bool) *ProcedureSQLReturnsRequest { + s.NotNull = &NotNull return s } func NewAlterProcedureRequest( - name SchemaObjectIdentifier, - ArgumentDataTypes []DataType, + name SchemaObjectIdentifierWithArguments, ) *AlterProcedureRequest { s := AlterProcedureRequest{} s.name = name - s.ArgumentDataTypes = ArgumentDataTypes return &s } -func (s *AlterProcedureRequest) WithIfExists(IfExists *bool) *AlterProcedureRequest { - s.IfExists = IfExists +func (s *AlterProcedureRequest) WithIfExists(IfExists bool) *AlterProcedureRequest { + s.IfExists = &IfExists return s } -func (s *AlterProcedureRequest) WithRenameTo(RenameTo *SchemaObjectIdentifier) *AlterProcedureRequest { - s.RenameTo = RenameTo +func (s *AlterProcedureRequest) WithRenameTo(RenameTo SchemaObjectIdentifier) *AlterProcedureRequest { + s.RenameTo = &RenameTo return s } -func (s *AlterProcedureRequest) WithSetComment(SetComment *string) *AlterProcedureRequest { - s.SetComment = SetComment +func (s *AlterProcedureRequest) WithSetComment(SetComment string) *AlterProcedureRequest { + s.SetComment = &SetComment return s } -func (s *AlterProcedureRequest) WithSetLogLevel(SetLogLevel *string) *AlterProcedureRequest { - s.SetLogLevel = SetLogLevel +func (s *AlterProcedureRequest) WithSetLogLevel(SetLogLevel string) *AlterProcedureRequest { + s.SetLogLevel = &SetLogLevel return s } -func (s *AlterProcedureRequest) WithSetTraceLevel(SetTraceLevel *string) *AlterProcedureRequest { - s.SetTraceLevel = SetTraceLevel +func (s *AlterProcedureRequest) WithSetTraceLevel(SetTraceLevel string) *AlterProcedureRequest { + s.SetTraceLevel = &SetTraceLevel return s } -func (s *AlterProcedureRequest) WithUnsetComment(UnsetComment *bool) *AlterProcedureRequest { - s.UnsetComment = UnsetComment +func (s *AlterProcedureRequest) WithUnsetComment(UnsetComment bool) *AlterProcedureRequest { + s.UnsetComment = &UnsetComment return s } @@ -467,23 +465,21 @@ func (s *AlterProcedureRequest) WithUnsetTags(UnsetTags []ObjectIdentifier) *Alt return s } -func (s *AlterProcedureRequest) WithExecuteAs(ExecuteAs *ExecuteAs) *AlterProcedureRequest { - s.ExecuteAs = ExecuteAs +func (s *AlterProcedureRequest) WithExecuteAs(ExecuteAs ExecuteAs) *AlterProcedureRequest { + s.ExecuteAs = &ExecuteAs return s } func NewDropProcedureRequest( - name SchemaObjectIdentifier, - ArgumentDataTypes []DataType, + name SchemaObjectIdentifierWithArguments, ) *DropProcedureRequest { s := DropProcedureRequest{} s.name = name - s.ArgumentDataTypes = ArgumentDataTypes return &s } -func (s *DropProcedureRequest) WithIfExists(IfExists *bool) *DropProcedureRequest { - s.IfExists = IfExists +func (s *DropProcedureRequest) WithIfExists(IfExists bool) *DropProcedureRequest { + s.IfExists = &IfExists return s } @@ -491,23 +487,21 @@ func NewShowProcedureRequest() *ShowProcedureRequest { return &ShowProcedureRequest{} } -func (s *ShowProcedureRequest) WithLike(Like *Like) *ShowProcedureRequest { - s.Like = Like +func (s *ShowProcedureRequest) WithLike(Like Like) *ShowProcedureRequest { + s.Like = &Like return s } -func (s *ShowProcedureRequest) WithIn(In *In) *ShowProcedureRequest { - s.In = In +func (s *ShowProcedureRequest) WithIn(In In) *ShowProcedureRequest { + s.In = &In return s } func NewDescribeProcedureRequest( - name SchemaObjectIdentifier, - ArgumentDataTypes []DataType, + name SchemaObjectIdentifierWithArguments, ) *DescribeProcedureRequest { s := DescribeProcedureRequest{} s.name = name - s.ArgumentDataTypes = ArgumentDataTypes return &s } @@ -524,8 +518,8 @@ func (s *CallProcedureRequest) WithCallArguments(CallArguments []string) *CallPr return s } -func (s *CallProcedureRequest) WithScriptingVariable(ScriptingVariable *string) *CallProcedureRequest { - s.ScriptingVariable = ScriptingVariable +func (s *CallProcedureRequest) WithScriptingVariable(ScriptingVariable string) *CallProcedureRequest { + s.ScriptingVariable = &ScriptingVariable return s } @@ -557,18 +551,18 @@ func (s *CreateAndCallForJavaProcedureRequest) WithImports(Imports []ProcedureIm return s } -func (s *CreateAndCallForJavaProcedureRequest) WithNullInputBehavior(NullInputBehavior *NullInputBehavior) *CreateAndCallForJavaProcedureRequest { - s.NullInputBehavior = NullInputBehavior +func (s *CreateAndCallForJavaProcedureRequest) WithNullInputBehavior(NullInputBehavior NullInputBehavior) *CreateAndCallForJavaProcedureRequest { + s.NullInputBehavior = &NullInputBehavior return s } -func (s *CreateAndCallForJavaProcedureRequest) WithProcedureDefinition(ProcedureDefinition *string) *CreateAndCallForJavaProcedureRequest { - s.ProcedureDefinition = ProcedureDefinition +func (s *CreateAndCallForJavaProcedureRequest) WithProcedureDefinition(ProcedureDefinition string) *CreateAndCallForJavaProcedureRequest { + s.ProcedureDefinition = &ProcedureDefinition return s } -func (s *CreateAndCallForJavaProcedureRequest) WithWithClause(WithClause *ProcedureWithClauseRequest) *CreateAndCallForJavaProcedureRequest { - s.WithClause = WithClause +func (s *CreateAndCallForJavaProcedureRequest) WithWithClause(WithClause ProcedureWithClauseRequest) *CreateAndCallForJavaProcedureRequest { + s.WithClause = &WithClause return s } @@ -577,8 +571,8 @@ func (s *CreateAndCallForJavaProcedureRequest) WithCallArguments(CallArguments [ return s } -func (s *CreateAndCallForJavaProcedureRequest) WithScriptingVariable(ScriptingVariable *string) *CreateAndCallForJavaProcedureRequest { - s.ScriptingVariable = ScriptingVariable +func (s *CreateAndCallForJavaProcedureRequest) WithScriptingVariable(ScriptingVariable string) *CreateAndCallForJavaProcedureRequest { + s.ScriptingVariable = &ScriptingVariable return s } @@ -625,13 +619,13 @@ func (s *CreateAndCallForScalaProcedureRequest) WithImports(Imports []ProcedureI return s } -func (s *CreateAndCallForScalaProcedureRequest) WithNullInputBehavior(NullInputBehavior *NullInputBehavior) *CreateAndCallForScalaProcedureRequest { - s.NullInputBehavior = NullInputBehavior +func (s *CreateAndCallForScalaProcedureRequest) WithNullInputBehavior(NullInputBehavior NullInputBehavior) *CreateAndCallForScalaProcedureRequest { + s.NullInputBehavior = &NullInputBehavior return s } -func (s *CreateAndCallForScalaProcedureRequest) WithProcedureDefinition(ProcedureDefinition *string) *CreateAndCallForScalaProcedureRequest { - s.ProcedureDefinition = ProcedureDefinition +func (s *CreateAndCallForScalaProcedureRequest) WithProcedureDefinition(ProcedureDefinition string) *CreateAndCallForScalaProcedureRequest { + s.ProcedureDefinition = &ProcedureDefinition return s } @@ -645,8 +639,8 @@ func (s *CreateAndCallForScalaProcedureRequest) WithCallArguments(CallArguments return s } -func (s *CreateAndCallForScalaProcedureRequest) WithScriptingVariable(ScriptingVariable *string) *CreateAndCallForScalaProcedureRequest { - s.ScriptingVariable = ScriptingVariable +func (s *CreateAndCallForScalaProcedureRequest) WithScriptingVariable(ScriptingVariable string) *CreateAndCallForScalaProcedureRequest { + s.ScriptingVariable = &ScriptingVariable return s } @@ -669,13 +663,13 @@ func (s *CreateAndCallForJavaScriptProcedureRequest) WithArguments(Arguments []P return s } -func (s *CreateAndCallForJavaScriptProcedureRequest) WithNotNull(NotNull *bool) *CreateAndCallForJavaScriptProcedureRequest { - s.NotNull = NotNull +func (s *CreateAndCallForJavaScriptProcedureRequest) WithNotNull(NotNull bool) *CreateAndCallForJavaScriptProcedureRequest { + s.NotNull = &NotNull return s } -func (s *CreateAndCallForJavaScriptProcedureRequest) WithNullInputBehavior(NullInputBehavior *NullInputBehavior) *CreateAndCallForJavaScriptProcedureRequest { - s.NullInputBehavior = NullInputBehavior +func (s *CreateAndCallForJavaScriptProcedureRequest) WithNullInputBehavior(NullInputBehavior NullInputBehavior) *CreateAndCallForJavaScriptProcedureRequest { + s.NullInputBehavior = &NullInputBehavior return s } @@ -689,8 +683,8 @@ func (s *CreateAndCallForJavaScriptProcedureRequest) WithCallArguments(CallArgum return s } -func (s *CreateAndCallForJavaScriptProcedureRequest) WithScriptingVariable(ScriptingVariable *string) *CreateAndCallForJavaScriptProcedureRequest { - s.ScriptingVariable = ScriptingVariable +func (s *CreateAndCallForJavaScriptProcedureRequest) WithScriptingVariable(ScriptingVariable string) *CreateAndCallForJavaScriptProcedureRequest { + s.ScriptingVariable = &ScriptingVariable return s } @@ -722,13 +716,13 @@ func (s *CreateAndCallForPythonProcedureRequest) WithImports(Imports []Procedure return s } -func (s *CreateAndCallForPythonProcedureRequest) WithNullInputBehavior(NullInputBehavior *NullInputBehavior) *CreateAndCallForPythonProcedureRequest { - s.NullInputBehavior = NullInputBehavior +func (s *CreateAndCallForPythonProcedureRequest) WithNullInputBehavior(NullInputBehavior NullInputBehavior) *CreateAndCallForPythonProcedureRequest { + s.NullInputBehavior = &NullInputBehavior return s } -func (s *CreateAndCallForPythonProcedureRequest) WithProcedureDefinition(ProcedureDefinition *string) *CreateAndCallForPythonProcedureRequest { - s.ProcedureDefinition = ProcedureDefinition +func (s *CreateAndCallForPythonProcedureRequest) WithProcedureDefinition(ProcedureDefinition string) *CreateAndCallForPythonProcedureRequest { + s.ProcedureDefinition = &ProcedureDefinition return s } @@ -742,8 +736,8 @@ func (s *CreateAndCallForPythonProcedureRequest) WithCallArguments(CallArguments return s } -func (s *CreateAndCallForPythonProcedureRequest) WithScriptingVariable(ScriptingVariable *string) *CreateAndCallForPythonProcedureRequest { - s.ScriptingVariable = ScriptingVariable +func (s *CreateAndCallForPythonProcedureRequest) WithScriptingVariable(ScriptingVariable string) *CreateAndCallForPythonProcedureRequest { + s.ScriptingVariable = &ScriptingVariable return s } @@ -766,8 +760,8 @@ func (s *CreateAndCallForSQLProcedureRequest) WithArguments(Arguments []Procedur return s } -func (s *CreateAndCallForSQLProcedureRequest) WithNullInputBehavior(NullInputBehavior *NullInputBehavior) *CreateAndCallForSQLProcedureRequest { - s.NullInputBehavior = NullInputBehavior +func (s *CreateAndCallForSQLProcedureRequest) WithNullInputBehavior(NullInputBehavior NullInputBehavior) *CreateAndCallForSQLProcedureRequest { + s.NullInputBehavior = &NullInputBehavior return s } @@ -781,7 +775,7 @@ func (s *CreateAndCallForSQLProcedureRequest) WithCallArguments(CallArguments [] return s } -func (s *CreateAndCallForSQLProcedureRequest) WithScriptingVariable(ScriptingVariable *string) *CreateAndCallForSQLProcedureRequest { - s.ScriptingVariable = ScriptingVariable +func (s *CreateAndCallForSQLProcedureRequest) WithScriptingVariable(ScriptingVariable string) *CreateAndCallForSQLProcedureRequest { + s.ScriptingVariable = &ScriptingVariable return s } diff --git a/pkg/sdk/procedures_dto_gen.go b/pkg/sdk/procedures_dto_gen.go index 02ded4ed1e..398169a682 100644 --- a/pkg/sdk/procedures_dto_gen.go +++ b/pkg/sdk/procedures_dto_gen.go @@ -145,23 +145,21 @@ type ProcedureSQLReturnsRequest struct { } type AlterProcedureRequest struct { - IfExists *bool - name SchemaObjectIdentifier // required - ArgumentDataTypes []DataType // required - RenameTo *SchemaObjectIdentifier - SetComment *string - SetLogLevel *string - SetTraceLevel *string - UnsetComment *bool - SetTags []TagAssociation - UnsetTags []ObjectIdentifier - ExecuteAs *ExecuteAs + IfExists *bool + name SchemaObjectIdentifierWithArguments // required + RenameTo *SchemaObjectIdentifier + SetComment *string + SetLogLevel *string + SetTraceLevel *string + UnsetComment *bool + SetTags []TagAssociation + UnsetTags []ObjectIdentifier + ExecuteAs *ExecuteAs } type DropProcedureRequest struct { - IfExists *bool - name SchemaObjectIdentifier // required - ArgumentDataTypes []DataType // required + IfExists *bool + name SchemaObjectIdentifierWithArguments // required } type ShowProcedureRequest struct { @@ -170,8 +168,7 @@ type ShowProcedureRequest struct { } type DescribeProcedureRequest struct { - name SchemaObjectIdentifier // required - ArgumentDataTypes []DataType // required + name SchemaObjectIdentifierWithArguments // required } type CallProcedureRequest struct { diff --git a/pkg/sdk/procedures_gen.go b/pkg/sdk/procedures_gen.go index 21f488a27b..9e6acbf39f 100644 --- a/pkg/sdk/procedures_gen.go +++ b/pkg/sdk/procedures_gen.go @@ -14,8 +14,8 @@ type Procedures interface { Alter(ctx context.Context, request *AlterProcedureRequest) error Drop(ctx context.Context, request *DropProcedureRequest) error Show(ctx context.Context, request *ShowProcedureRequest) ([]Procedure, error) - ShowByID(ctx context.Context, id SchemaObjectIdentifier) (*Procedure, error) - Describe(ctx context.Context, request *DescribeProcedureRequest) ([]ProcedureDetail, error) + ShowByID(ctx context.Context, id SchemaObjectIdentifierWithArguments) (*Procedure, error) + Describe(ctx context.Context, id SchemaObjectIdentifierWithArguments) ([]ProcedureDetail, error) Call(ctx context.Context, request *CallProcedureRequest) error CreateAndCallForJava(ctx context.Context, request *CreateAndCallForJavaProcedureRequest) error CreateAndCallForScala(ctx context.Context, request *CreateAndCallForScalaProcedureRequest) error @@ -170,28 +170,26 @@ type ProcedureSQLReturns struct { // AlterProcedureOptions is based on https://docs.snowflake.com/en/sql-reference/sql/alter-procedure. type AlterProcedureOptions struct { - alter bool `ddl:"static" sql:"ALTER"` - procedure bool `ddl:"static" sql:"PROCEDURE"` - IfExists *bool `ddl:"keyword" sql:"IF EXISTS"` - name SchemaObjectIdentifier `ddl:"identifier"` - ArgumentDataTypes []DataType `ddl:"keyword,must_parentheses"` - RenameTo *SchemaObjectIdentifier `ddl:"identifier" sql:"RENAME TO"` - SetComment *string `ddl:"parameter,single_quotes" sql:"SET COMMENT"` - SetLogLevel *string `ddl:"parameter,single_quotes" sql:"SET LOG_LEVEL"` - SetTraceLevel *string `ddl:"parameter,single_quotes" sql:"SET TRACE_LEVEL"` - UnsetComment *bool `ddl:"keyword" sql:"UNSET COMMENT"` - SetTags []TagAssociation `ddl:"keyword" sql:"SET TAG"` - UnsetTags []ObjectIdentifier `ddl:"keyword" sql:"UNSET TAG"` - ExecuteAs *ExecuteAs `ddl:"keyword"` + alter bool `ddl:"static" sql:"ALTER"` + procedure bool `ddl:"static" sql:"PROCEDURE"` + IfExists *bool `ddl:"keyword" sql:"IF EXISTS"` + name SchemaObjectIdentifierWithArguments `ddl:"identifier"` + RenameTo *SchemaObjectIdentifier `ddl:"identifier" sql:"RENAME TO"` + SetComment *string `ddl:"parameter,single_quotes" sql:"SET COMMENT"` + SetLogLevel *string `ddl:"parameter,single_quotes" sql:"SET LOG_LEVEL"` + SetTraceLevel *string `ddl:"parameter,single_quotes" sql:"SET TRACE_LEVEL"` + UnsetComment *bool `ddl:"keyword" sql:"UNSET COMMENT"` + SetTags []TagAssociation `ddl:"keyword" sql:"SET TAG"` + UnsetTags []ObjectIdentifier `ddl:"keyword" sql:"UNSET TAG"` + ExecuteAs *ExecuteAs `ddl:"keyword"` } // DropProcedureOptions is based on https://docs.snowflake.com/en/sql-reference/sql/drop-procedure. type DropProcedureOptions struct { - drop bool `ddl:"static" sql:"DROP"` - procedure bool `ddl:"static" sql:"PROCEDURE"` - IfExists *bool `ddl:"keyword" sql:"IF EXISTS"` - name SchemaObjectIdentifier `ddl:"identifier"` - ArgumentDataTypes []DataType `ddl:"keyword,must_parentheses"` + drop bool `ddl:"static" sql:"DROP"` + procedure bool `ddl:"static" sql:"PROCEDURE"` + IfExists *bool `ddl:"keyword" sql:"IF EXISTS"` + name SchemaObjectIdentifierWithArguments `ddl:"identifier"` } // ShowProcedureOptions is based on https://docs.snowflake.com/en/sql-reference/sql/show-procedures. @@ -236,16 +234,16 @@ type Procedure struct { IsSecure bool } -func (v *Procedure) ID() SchemaObjectIdentifier { - return NewSchemaObjectIdentifier(v.CatalogName, v.SchemaName, v.Name) +func (v *Procedure) ID() SchemaObjectIdentifierWithArguments { + //return NewSchemaObjectIdentifier(v.CatalogName, v.SchemaName, v.Name) + return NewSchemaObjectIdentifierWithArguments("", "", "", "") } // DescribeProcedureOptions is based on https://docs.snowflake.com/en/sql-reference/sql/desc-procedure. type DescribeProcedureOptions struct { - describe bool `ddl:"static" sql:"DESCRIBE"` - procedure bool `ddl:"static" sql:"PROCEDURE"` - name SchemaObjectIdentifier `ddl:"identifier"` - ArgumentDataTypes []DataType `ddl:"keyword,must_parentheses"` + describe bool `ddl:"static" sql:"DESCRIBE"` + procedure bool `ddl:"static" sql:"PROCEDURE"` + name SchemaObjectIdentifierWithArguments `ddl:"identifier"` } type procedureDetailRow struct { @@ -286,7 +284,6 @@ type CreateAndCallForJavaProcedureOptions struct { CallArguments []string `ddl:"keyword,must_parentheses"` ScriptingVariable *string `ddl:"parameter,no_quotes,no_equals" sql:"INTO"` } - type ProcedureWithClause struct { prefix bool `ddl:"static" sql:","` CteName AccountObjectIdentifier `ddl:"identifier"` diff --git a/pkg/sdk/procedures_gen_test.go b/pkg/sdk/procedures_gen_test.go index 015dc6075b..6345324d0b 100644 --- a/pkg/sdk/procedures_gen_test.go +++ b/pkg/sdk/procedures_gen_test.go @@ -403,7 +403,8 @@ func TestProcedures_CreateForSQL(t *testing.T) { } func TestProcedures_Drop(t *testing.T) { - id := randomSchemaObjectIdentifier() + noArgsId := randomSchemaObjectIdentifierWithArguments() + id := randomSchemaObjectIdentifierWithArguments(DataTypeVARCHAR, DataTypeNumber) defaultOpts := func() *DropProcedureOptions { return &DropProcedureOptions{ @@ -417,31 +418,31 @@ func TestProcedures_Drop(t *testing.T) { t.Run("validation: incorrect identifier", func(t *testing.T) { opts := defaultOpts() - opts.name = emptySchemaObjectIdentifier + opts.name = emptySchemaObjectIdentifierWithArguments assertOptsInvalidJoinedErrors(t, opts, ErrInvalidObjectIdentifier) }) t.Run("no arguments", func(t *testing.T) { opts := defaultOpts() - assertOptsValidAndSQLEquals(t, opts, `DROP PROCEDURE %s ()`, id.FullyQualifiedName()) + opts.name = noArgsId + assertOptsValidAndSQLEquals(t, opts, `DROP PROCEDURE %s`, noArgsId.FullyQualifiedName()) }) t.Run("all options", func(t *testing.T) { opts := defaultOpts() opts.IfExists = Bool(true) - opts.ArgumentDataTypes = []DataType{DataTypeVARCHAR, DataTypeNumber} - assertOptsValidAndSQLEquals(t, opts, `DROP PROCEDURE IF EXISTS %s (VARCHAR, NUMBER)`, id.FullyQualifiedName()) + assertOptsValidAndSQLEquals(t, opts, `DROP PROCEDURE IF EXISTS %s`, id.FullyQualifiedName()) }) } func TestProcedures_Alter(t *testing.T) { - id := randomSchemaObjectIdentifier() + noArgsId := randomSchemaObjectIdentifierWithArguments() + id := randomSchemaObjectIdentifierWithArguments(DataTypeVARCHAR, DataTypeNumber) defaultOpts := func() *AlterProcedureOptions { return &AlterProcedureOptions{ - name: id, - IfExists: Bool(true), - ArgumentDataTypes: []DataType{DataTypeVARCHAR, DataTypeNumber}, + name: id, + IfExists: Bool(true), } } @@ -452,7 +453,7 @@ func TestProcedures_Alter(t *testing.T) { t.Run("validation: incorrect identifier", func(t *testing.T) { opts := defaultOpts() - opts.name = emptySchemaObjectIdentifier + opts.name = emptySchemaObjectIdentifierWithArguments assertOptsInvalidJoinedErrors(t, opts, ErrInvalidObjectIdentifier) }) @@ -467,45 +468,45 @@ func TestProcedures_Alter(t *testing.T) { opts := defaultOpts() target := randomSchemaObjectIdentifierInSchema(id.SchemaId()) opts.RenameTo = &target - assertOptsValidAndSQLEquals(t, opts, `ALTER PROCEDURE IF EXISTS %s (VARCHAR, NUMBER) RENAME TO %s`, id.FullyQualifiedName(), opts.RenameTo.FullyQualifiedName()) + assertOptsValidAndSQLEquals(t, opts, `ALTER PROCEDURE IF EXISTS %s RENAME TO %s`, id.FullyQualifiedName(), opts.RenameTo.FullyQualifiedName()) }) t.Run("alter: execute as", func(t *testing.T) { opts := defaultOpts() executeAs := ExecuteAsCaller opts.ExecuteAs = &executeAs - assertOptsValidAndSQLEquals(t, opts, `ALTER PROCEDURE IF EXISTS %s (VARCHAR, NUMBER) EXECUTE AS CALLER`, id.FullyQualifiedName()) + assertOptsValidAndSQLEquals(t, opts, `ALTER PROCEDURE IF EXISTS %s EXECUTE AS CALLER`, id.FullyQualifiedName()) }) t.Run("alter: set log level", func(t *testing.T) { opts := defaultOpts() opts.SetLogLevel = String("DEBUG") - assertOptsValidAndSQLEquals(t, opts, `ALTER PROCEDURE IF EXISTS %s (VARCHAR, NUMBER) SET LOG_LEVEL = 'DEBUG'`, id.FullyQualifiedName()) + assertOptsValidAndSQLEquals(t, opts, `ALTER PROCEDURE IF EXISTS %s SET LOG_LEVEL = 'DEBUG'`, id.FullyQualifiedName()) }) t.Run("alter: set log level with no arguments", func(t *testing.T) { opts := defaultOpts() - opts.ArgumentDataTypes = nil + opts.name = noArgsId opts.SetLogLevel = String("DEBUG") - assertOptsValidAndSQLEquals(t, opts, `ALTER PROCEDURE IF EXISTS %s () SET LOG_LEVEL = 'DEBUG'`, id.FullyQualifiedName()) + assertOptsValidAndSQLEquals(t, opts, `ALTER PROCEDURE IF EXISTS %s SET LOG_LEVEL = 'DEBUG'`, noArgsId.FullyQualifiedName()) }) t.Run("alter: set trace level", func(t *testing.T) { opts := defaultOpts() opts.SetTraceLevel = String("DEBUG") - assertOptsValidAndSQLEquals(t, opts, `ALTER PROCEDURE IF EXISTS %s (VARCHAR, NUMBER) SET TRACE_LEVEL = 'DEBUG'`, id.FullyQualifiedName()) + assertOptsValidAndSQLEquals(t, opts, `ALTER PROCEDURE IF EXISTS %s SET TRACE_LEVEL = 'DEBUG'`, id.FullyQualifiedName()) }) t.Run("alter: set comment", func(t *testing.T) { opts := defaultOpts() opts.SetComment = String("comment") - assertOptsValidAndSQLEquals(t, opts, `ALTER PROCEDURE IF EXISTS %s (VARCHAR, NUMBER) SET COMMENT = 'comment'`, id.FullyQualifiedName()) + assertOptsValidAndSQLEquals(t, opts, `ALTER PROCEDURE IF EXISTS %s SET COMMENT = 'comment'`, id.FullyQualifiedName()) }) t.Run("alter: unset comment", func(t *testing.T) { opts := defaultOpts() opts.UnsetComment = Bool(true) - assertOptsValidAndSQLEquals(t, opts, `ALTER PROCEDURE IF EXISTS %s (VARCHAR, NUMBER) UNSET COMMENT`, id.FullyQualifiedName()) + assertOptsValidAndSQLEquals(t, opts, `ALTER PROCEDURE IF EXISTS %s UNSET COMMENT`, id.FullyQualifiedName()) }) t.Run("alter: set tags", func(t *testing.T) { @@ -516,7 +517,7 @@ func TestProcedures_Alter(t *testing.T) { Value: "value1", }, } - assertOptsValidAndSQLEquals(t, opts, `ALTER PROCEDURE IF EXISTS %s (VARCHAR, NUMBER) SET TAG "tag1" = 'value1'`, id.FullyQualifiedName()) + assertOptsValidAndSQLEquals(t, opts, `ALTER PROCEDURE IF EXISTS %s SET TAG "tag1" = 'value1'`, id.FullyQualifiedName()) }) t.Run("alter: unset tags", func(t *testing.T) { @@ -525,7 +526,7 @@ func TestProcedures_Alter(t *testing.T) { NewAccountObjectIdentifier("tag1"), NewAccountObjectIdentifier("tag2"), } - assertOptsValidAndSQLEquals(t, opts, `ALTER PROCEDURE IF EXISTS %s (VARCHAR, NUMBER) UNSET TAG "tag1", "tag2"`, id.FullyQualifiedName()) + assertOptsValidAndSQLEquals(t, opts, `ALTER PROCEDURE IF EXISTS %s UNSET TAG "tag1", "tag2"`, id.FullyQualifiedName()) }) } @@ -562,7 +563,8 @@ func TestProcedures_Show(t *testing.T) { } func TestProcedures_Describe(t *testing.T) { - id := randomSchemaObjectIdentifier() + noArgsId := randomSchemaObjectIdentifierWithArguments() + id := randomSchemaObjectIdentifierWithArguments(DataTypeVARCHAR, DataTypeNumber) defaultOpts := func() *DescribeProcedureOptions { return &DescribeProcedureOptions{ @@ -577,19 +579,19 @@ func TestProcedures_Describe(t *testing.T) { t.Run("validation: incorrect identifier", func(t *testing.T) { opts := defaultOpts() - opts.name = emptySchemaObjectIdentifier + opts.name = emptySchemaObjectIdentifierWithArguments assertOptsInvalidJoinedErrors(t, opts, ErrInvalidObjectIdentifier) }) t.Run("no arguments", func(t *testing.T) { opts := defaultOpts() - assertOptsValidAndSQLEquals(t, opts, `DESCRIBE PROCEDURE %s ()`, id.FullyQualifiedName()) + opts.name = noArgsId + assertOptsValidAndSQLEquals(t, opts, `DESCRIBE PROCEDURE %s`, noArgsId.FullyQualifiedName()) }) t.Run("all options", func(t *testing.T) { opts := defaultOpts() - opts.ArgumentDataTypes = []DataType{DataTypeVARCHAR, DataTypeNumber} - assertOptsValidAndSQLEquals(t, opts, `DESCRIBE PROCEDURE %s (VARCHAR, NUMBER)`, id.FullyQualifiedName()) + assertOptsValidAndSQLEquals(t, opts, `DESCRIBE PROCEDURE %s`, id.FullyQualifiedName()) }) } @@ -615,6 +617,7 @@ func TestProcedures_Call(t *testing.T) { t.Run("no arguments", func(t *testing.T) { opts := defaultOpts() + opts.name = id assertOptsValidAndSQLEquals(t, opts, `CALL %s ()`, id.FullyQualifiedName()) }) diff --git a/pkg/sdk/procedures_impl_gen.go b/pkg/sdk/procedures_impl_gen.go index c665829e7e..24bae558d5 100644 --- a/pkg/sdk/procedures_impl_gen.go +++ b/pkg/sdk/procedures_impl_gen.go @@ -57,17 +57,19 @@ func (v *procedures) Show(ctx context.Context, request *ShowProcedureRequest) ([ return resultList, nil } -func (v *procedures) ShowByID(ctx context.Context, id SchemaObjectIdentifier) (*Procedure, error) { - request := NewShowProcedureRequest().WithIn(&In{Schema: id.SchemaId()}).WithLike(&Like{String(id.Name())}) - procedures, err := v.Show(ctx, request) +func (v *procedures) ShowByID(ctx context.Context, id SchemaObjectIdentifierWithArguments) (*Procedure, error) { + // TODO: adjust request if e.g. LIKE is supported for the resource + procedures, err := v.Show(ctx, NewShowProcedureRequest()) if err != nil { return nil, err } return collections.FindOne(procedures, func(r Procedure) bool { return r.Name == id.Name() }) } -func (v *procedures) Describe(ctx context.Context, request *DescribeProcedureRequest) ([]ProcedureDetail, error) { - opts := request.toOpts() +func (v *procedures) Describe(ctx context.Context, id SchemaObjectIdentifierWithArguments) ([]ProcedureDetail, error) { + opts := &DescribeProcedureOptions{ + name: id, + } rows, err := validateAndQuery[procedureDetailRow](v.client, ctx, opts) if err != nil { return nil, err @@ -352,26 +354,24 @@ func (r *CreateForSQLProcedureRequest) toOpts() *CreateForSQLProcedureOptions { func (r *AlterProcedureRequest) toOpts() *AlterProcedureOptions { opts := &AlterProcedureOptions{ - IfExists: r.IfExists, - name: r.name, - ArgumentDataTypes: r.ArgumentDataTypes, - RenameTo: r.RenameTo, - SetComment: r.SetComment, - SetLogLevel: r.SetLogLevel, - SetTraceLevel: r.SetTraceLevel, - UnsetComment: r.UnsetComment, - SetTags: r.SetTags, - UnsetTags: r.UnsetTags, - ExecuteAs: r.ExecuteAs, + IfExists: r.IfExists, + name: r.name, + RenameTo: r.RenameTo, + SetComment: r.SetComment, + SetLogLevel: r.SetLogLevel, + SetTraceLevel: r.SetTraceLevel, + UnsetComment: r.UnsetComment, + SetTags: r.SetTags, + UnsetTags: r.UnsetTags, + ExecuteAs: r.ExecuteAs, } return opts } func (r *DropProcedureRequest) toOpts() *DropProcedureOptions { opts := &DropProcedureOptions{ - IfExists: r.IfExists, - name: r.name, - ArgumentDataTypes: r.ArgumentDataTypes, + IfExists: r.IfExists, + name: r.name, } return opts } @@ -408,8 +408,7 @@ func (r procedureRow) convert() *Procedure { func (r *DescribeProcedureRequest) toOpts() *DescribeProcedureOptions { opts := &DescribeProcedureOptions{ - name: r.name, - ArgumentDataTypes: r.ArgumentDataTypes, + name: r.name, } return opts } @@ -427,7 +426,6 @@ func (r procedureDetailRow) convert() *ProcedureDetail { func (r *CallProcedureRequest) toOpts() *CallProcedureOptions { opts := &CallProcedureOptions{ name: r.name, - CallArguments: r.CallArguments, ScriptingVariable: r.ScriptingVariable, } return opts diff --git a/pkg/sdk/testint/procedures_integration_test.go b/pkg/sdk/testint/procedures_integration_test.go index ea4df83a82..70b8a7b9b4 100644 --- a/pkg/sdk/testint/procedures_integration_test.go +++ b/pkg/sdk/testint/procedures_integration_test.go @@ -18,9 +18,9 @@ func TestInt_CreateProcedures(t *testing.T) { client := testClient(t) ctx := testContext(t) - cleanupProcedureHandle := func(id sdk.SchemaObjectIdentifier, ats []sdk.DataType) func() { + cleanupProcedureHandle := func(id sdk.SchemaObjectIdentifierWithArguments) func() { return func() { - err := client.Procedures.Drop(ctx, sdk.NewDropProcedureRequest(id, ats)) + err := client.Procedures.Drop(ctx, sdk.NewDropProcedureRequest(id)) if errors.Is(err, sdk.ErrObjectNotExistOrAuthorized) { return } @@ -31,7 +31,7 @@ func TestInt_CreateProcedures(t *testing.T) { t.Run("create procedure for Java: returns result data type", func(t *testing.T) { // https://docs.snowflake.com/en/developer-guide/stored-procedure/stored-procedures-java#reading-a-dynamically-specified-file-with-inputstream name := "file_reader_java_proc_snowflakefile" - id := testClientHelper().Ids.NewSchemaObjectIdentifier(name) + id := testClientHelper().Ids.NewSchemaObjectIdentifierWithArguments(name, sdk.DataTypeVARCHAR) definition := ` import java.io.InputStream; @@ -47,16 +47,16 @@ func TestInt_CreateProcedures(t *testing.T) { }` dt := sdk.NewProcedureReturnsResultDataTypeRequest(sdk.DataTypeVARCHAR) - returns := sdk.NewProcedureReturnsRequest().WithResultDataType(dt) + returns := sdk.NewProcedureReturnsRequest().WithResultDataType(*dt) argument := sdk.NewProcedureArgumentRequest("input", sdk.DataTypeVARCHAR) packages := []sdk.ProcedurePackageRequest{*sdk.NewProcedurePackageRequest("com.snowflake:snowpark:latest")} - request := sdk.NewCreateForJavaProcedureRequest(id, *returns, "11", packages, "FileReader.execute"). - WithOrReplace(sdk.Bool(true)). + request := sdk.NewCreateForJavaProcedureRequest(id.SchemaObjectId(), *returns, "11", packages, "FileReader.execute"). + WithOrReplace(true). WithArguments([]sdk.ProcedureArgumentRequest{*argument}). - WithProcedureDefinition(sdk.String(definition)) + WithProcedureDefinition(definition) err := client.Procedures.CreateForJava(ctx, request) require.NoError(t, err) - t.Cleanup(cleanupProcedureHandle(id, []sdk.DataType{sdk.DataTypeVARCHAR})) + t.Cleanup(cleanupProcedureHandle(id)) procedures, err := client.Procedures.Show(ctx, sdk.NewShowProcedureRequest()) require.NoError(t, err) @@ -66,7 +66,7 @@ func TestInt_CreateProcedures(t *testing.T) { t.Run("create procedure for Java: returns table", func(t *testing.T) { // https://docs.snowflake.com/en/developer-guide/stored-procedure/stored-procedures-java#specifying-return-column-names-and-types name := "filter_by_role" - id := testClientHelper().Ids.NewSchemaObjectIdentifier(name) + id := testClientHelper().Ids.NewSchemaObjectIdentifierWithArguments(name, sdk.DataTypeVARCHAR, sdk.DataTypeVARCHAR) definition := ` import com.snowflake.snowpark_java.*; @@ -81,17 +81,17 @@ func TestInt_CreateProcedures(t *testing.T) { column2 := sdk.NewProcedureColumnRequest("name", sdk.DataTypeVARCHAR) column3 := sdk.NewProcedureColumnRequest("role", sdk.DataTypeVARCHAR) returnsTable := sdk.NewProcedureReturnsTableRequest().WithColumns([]sdk.ProcedureColumnRequest{*column1, *column2, *column3}) - returns := sdk.NewProcedureReturnsRequest().WithTable(returnsTable) + returns := sdk.NewProcedureReturnsRequest().WithTable(*returnsTable) arg1 := sdk.NewProcedureArgumentRequest("table_name", sdk.DataTypeVARCHAR) arg2 := sdk.NewProcedureArgumentRequest("role", sdk.DataTypeVARCHAR) packages := []sdk.ProcedurePackageRequest{*sdk.NewProcedurePackageRequest("com.snowflake:snowpark:latest")} - request := sdk.NewCreateForJavaProcedureRequest(id, *returns, "11", packages, "Filter.filterByRole"). - WithOrReplace(sdk.Bool(true)). + request := sdk.NewCreateForJavaProcedureRequest(id.SchemaObjectId(), *returns, "11", packages, "Filter.filterByRole"). + WithOrReplace(true). WithArguments([]sdk.ProcedureArgumentRequest{*arg1, *arg2}). - WithProcedureDefinition(sdk.String(definition)) + WithProcedureDefinition(definition) err := client.Procedures.CreateForJava(ctx, request) require.NoError(t, err) - t.Cleanup(cleanupProcedureHandle(id, []sdk.DataType{sdk.DataTypeVARCHAR, sdk.DataTypeVARCHAR})) + t.Cleanup(cleanupProcedureHandle(id)) procedures, err := client.Procedures.Show(ctx, sdk.NewShowProcedureRequest()) require.NoError(t, err) @@ -101,7 +101,7 @@ func TestInt_CreateProcedures(t *testing.T) { t.Run("create procedure for Javascript", func(t *testing.T) { // https://docs.snowflake.com/en/developer-guide/stored-procedure/stored-procedures-javascript#basic-examples name := "stproc1" - id := testClientHelper().Ids.NewSchemaObjectIdentifier(name) + id := testClientHelper().Ids.NewSchemaObjectIdentifierWithArguments(name, sdk.DataTypeFloat) definition := ` var sql_command = "INSERT INTO stproc_test_table1 (num_col1) VALUES (" + FLOAT_PARAM1 + ")"; @@ -115,13 +115,13 @@ func TestInt_CreateProcedures(t *testing.T) { return "Failed: " + err; // Return a success/error indicator. }` argument := sdk.NewProcedureArgumentRequest("FLOAT_PARAM1", sdk.DataTypeFloat) - request := sdk.NewCreateForJavaScriptProcedureRequest(id, sdk.DataTypeString, definition). + request := sdk.NewCreateForJavaScriptProcedureRequest(id.SchemaObjectId(), sdk.DataTypeString, definition). WithArguments([]sdk.ProcedureArgumentRequest{*argument}). - WithNullInputBehavior(sdk.NullInputBehaviorPointer(sdk.NullInputBehaviorStrict)). - WithExecuteAs(sdk.ExecuteAsPointer(sdk.ExecuteAsCaller)) + WithNullInputBehavior(*sdk.NullInputBehaviorPointer(sdk.NullInputBehaviorStrict)). + WithExecuteAs(*sdk.ExecuteAsPointer(sdk.ExecuteAsCaller)) err := client.Procedures.CreateForJavaScript(ctx, request) require.NoError(t, err) - t.Cleanup(cleanupProcedureHandle(id, []sdk.DataType{sdk.DataTypeFloat})) + t.Cleanup(cleanupProcedureHandle(id)) procedures, err := client.Procedures.Show(ctx, sdk.NewShowProcedureRequest()) require.NoError(t, err) @@ -131,13 +131,13 @@ func TestInt_CreateProcedures(t *testing.T) { t.Run("create procedure for Javascript: no arguments", func(t *testing.T) { // https://docs.snowflake.com/en/developer-guide/stored-procedure/stored-procedures-javascript#basic-examples name := "sp_pi" - id := testClientHelper().Ids.NewSchemaObjectIdentifier(name) + id := testClientHelper().Ids.NewSchemaObjectIdentifierWithArguments(name) definition := `return 3.1415926;` - request := sdk.NewCreateForJavaScriptProcedureRequest(id, sdk.DataTypeFloat, definition).WithNotNull(sdk.Bool(true)).WithOrReplace(sdk.Bool(true)) + request := sdk.NewCreateForJavaScriptProcedureRequest(id.SchemaObjectId(), sdk.DataTypeFloat, definition).WithNotNull(true).WithOrReplace(true) err := client.Procedures.CreateForJavaScript(ctx, request) require.NoError(t, err) - t.Cleanup(cleanupProcedureHandle(id, nil)) + t.Cleanup(cleanupProcedureHandle(id)) procedures, err := client.Procedures.Show(ctx, sdk.NewShowProcedureRequest()) require.NoError(t, err) @@ -147,7 +147,7 @@ func TestInt_CreateProcedures(t *testing.T) { t.Run("create procedure for Scala: returns result data type", func(t *testing.T) { // https://docs.snowflake.com/en/developer-guide/stored-procedure/stored-procedures-scala#reading-a-dynamically-specified-file-with-snowflakefile name := "file_reader_scala_proc_snowflakefile" - id := testClientHelper().Ids.NewSchemaObjectIdentifier(name) + id := testClientHelper().Ids.NewSchemaObjectIdentifierWithArguments(name, sdk.DataTypeVARCHAR) definition := ` import java.io.InputStream @@ -161,16 +161,16 @@ func TestInt_CreateProcedures(t *testing.T) { } }` dt := sdk.NewProcedureReturnsResultDataTypeRequest(sdk.DataTypeVARCHAR) - returns := sdk.NewProcedureReturnsRequest().WithResultDataType(dt) + returns := sdk.NewProcedureReturnsRequest().WithResultDataType(*dt) argument := sdk.NewProcedureArgumentRequest("input", sdk.DataTypeVARCHAR) packages := []sdk.ProcedurePackageRequest{*sdk.NewProcedurePackageRequest("com.snowflake:snowpark:latest")} - request := sdk.NewCreateForScalaProcedureRequest(id, *returns, "2.12", packages, "FileReader.execute"). - WithOrReplace(sdk.Bool(true)). + request := sdk.NewCreateForScalaProcedureRequest(id.SchemaObjectId(), *returns, "2.12", packages, "FileReader.execute"). + WithOrReplace(true). WithArguments([]sdk.ProcedureArgumentRequest{*argument}). - WithProcedureDefinition(sdk.String(definition)) + WithProcedureDefinition(definition) err := client.Procedures.CreateForScala(ctx, request) require.NoError(t, err) - t.Cleanup(cleanupProcedureHandle(id, []sdk.DataType{sdk.DataTypeVARCHAR})) + t.Cleanup(cleanupProcedureHandle(id)) procedures, err := client.Procedures.Show(ctx, sdk.NewShowProcedureRequest()) require.NoError(t, err) @@ -180,7 +180,7 @@ func TestInt_CreateProcedures(t *testing.T) { t.Run("create procedure for Scala: returns table", func(t *testing.T) { // https://docs.snowflake.com/en/developer-guide/stored-procedure/stored-procedures-scala#specifying-return-column-names-and-types name := "filter_by_role" - id := testClientHelper().Ids.NewSchemaObjectIdentifier(name) + id := testClientHelper().Ids.NewSchemaObjectIdentifierWithArguments(name, sdk.DataTypeVARCHAR, sdk.DataTypeVARCHAR) definition := ` import com.snowflake.snowpark.functions._ @@ -196,17 +196,17 @@ func TestInt_CreateProcedures(t *testing.T) { column2 := sdk.NewProcedureColumnRequest("name", sdk.DataTypeVARCHAR) column3 := sdk.NewProcedureColumnRequest("role", sdk.DataTypeVARCHAR) returnsTable := sdk.NewProcedureReturnsTableRequest().WithColumns([]sdk.ProcedureColumnRequest{*column1, *column2, *column3}) - returns := sdk.NewProcedureReturnsRequest().WithTable(returnsTable) + returns := sdk.NewProcedureReturnsRequest().WithTable(*returnsTable) arg1 := sdk.NewProcedureArgumentRequest("table_name", sdk.DataTypeVARCHAR) arg2 := sdk.NewProcedureArgumentRequest("role", sdk.DataTypeVARCHAR) packages := []sdk.ProcedurePackageRequest{*sdk.NewProcedurePackageRequest("com.snowflake:snowpark:latest")} - request := sdk.NewCreateForScalaProcedureRequest(id, *returns, "2.12", packages, "Filter.filterByRole"). - WithOrReplace(sdk.Bool(true)). + request := sdk.NewCreateForScalaProcedureRequest(id.SchemaObjectId(), *returns, "2.12", packages, "Filter.filterByRole"). + WithOrReplace(true). WithArguments([]sdk.ProcedureArgumentRequest{*arg1, *arg2}). - WithProcedureDefinition(sdk.String(definition)) + WithProcedureDefinition(definition) err := client.Procedures.CreateForScala(ctx, request) require.NoError(t, err) - t.Cleanup(cleanupProcedureHandle(id, []sdk.DataType{sdk.DataTypeVARCHAR, sdk.DataTypeVARCHAR})) + t.Cleanup(cleanupProcedureHandle(id)) procedures, err := client.Procedures.Show(ctx, sdk.NewShowProcedureRequest()) require.NoError(t, err) @@ -216,7 +216,7 @@ func TestInt_CreateProcedures(t *testing.T) { t.Run("create procedure for Python: returns result data type", func(t *testing.T) { // https://docs.snowflake.com/en/developer-guide/stored-procedure/stored-procedures-python#running-concurrent-tasks-with-worker-processes name := "joblib_multiprocessing_proc" - id := testClientHelper().Ids.NewSchemaObjectIdentifier(name) + id := testClientHelper().Ids.NewSchemaObjectIdentifierWithArguments(name, "INT") definition := ` import joblib @@ -226,19 +226,19 @@ def joblib_multiprocessing(session, i): return str(result)` dt := sdk.NewProcedureReturnsResultDataTypeRequest(sdk.DataTypeString) - returns := sdk.NewProcedureReturnsRequest().WithResultDataType(dt) + returns := sdk.NewProcedureReturnsRequest().WithResultDataType(*dt) argument := sdk.NewProcedureArgumentRequest("i", "INT") packages := []sdk.ProcedurePackageRequest{ *sdk.NewProcedurePackageRequest("snowflake-snowpark-python"), *sdk.NewProcedurePackageRequest("joblib"), } - request := sdk.NewCreateForPythonProcedureRequest(id, *returns, "3.8", packages, "joblib_multiprocessing"). - WithOrReplace(sdk.Bool(true)). + request := sdk.NewCreateForPythonProcedureRequest(id.SchemaObjectId(), *returns, "3.8", packages, "joblib_multiprocessing"). + WithOrReplace(true). WithArguments([]sdk.ProcedureArgumentRequest{*argument}). - WithProcedureDefinition(sdk.String(definition)) + WithProcedureDefinition(definition) err := client.Procedures.CreateForPython(ctx, request) require.NoError(t, err) - t.Cleanup(cleanupProcedureHandle(id, []sdk.DataType{"INT"})) + t.Cleanup(cleanupProcedureHandle(id)) procedures, err := client.Procedures.Show(ctx, sdk.NewShowProcedureRequest()) require.NoError(t, err) @@ -248,7 +248,7 @@ def joblib_multiprocessing(session, i): t.Run("create procedure for Python: returns table", func(t *testing.T) { // https://docs.snowflake.com/en/developer-guide/stored-procedure/stored-procedures-python#specifying-return-column-names-and-types name := "filterByRole" - id := testClientHelper().Ids.NewSchemaObjectIdentifier(name) + id := testClientHelper().Ids.NewSchemaObjectIdentifierWithArguments(name, sdk.DataTypeVARCHAR, sdk.DataTypeVARCHAR) definition := ` from snowflake.snowpark.functions import col @@ -259,17 +259,17 @@ def filter_by_role(session, table_name, role): column2 := sdk.NewProcedureColumnRequest("name", sdk.DataTypeVARCHAR) column3 := sdk.NewProcedureColumnRequest("role", sdk.DataTypeVARCHAR) returnsTable := sdk.NewProcedureReturnsTableRequest().WithColumns([]sdk.ProcedureColumnRequest{*column1, *column2, *column3}) - returns := sdk.NewProcedureReturnsRequest().WithTable(returnsTable) + returns := sdk.NewProcedureReturnsRequest().WithTable(*returnsTable) arg1 := sdk.NewProcedureArgumentRequest("table_name", sdk.DataTypeVARCHAR) arg2 := sdk.NewProcedureArgumentRequest("role", sdk.DataTypeVARCHAR) packages := []sdk.ProcedurePackageRequest{*sdk.NewProcedurePackageRequest("snowflake-snowpark-python")} - request := sdk.NewCreateForPythonProcedureRequest(id, *returns, "3.8", packages, "filter_by_role"). - WithOrReplace(sdk.Bool(true)). + request := sdk.NewCreateForPythonProcedureRequest(id.SchemaObjectId(), *returns, "3.8", packages, "filter_by_role"). + WithOrReplace(true). WithArguments([]sdk.ProcedureArgumentRequest{*arg1, *arg2}). - WithProcedureDefinition(sdk.String(definition)) + WithProcedureDefinition(definition) err := client.Procedures.CreateForPython(ctx, request) require.NoError(t, err) - t.Cleanup(cleanupProcedureHandle(id, []sdk.DataType{sdk.DataTypeVARCHAR, sdk.DataTypeVARCHAR})) + t.Cleanup(cleanupProcedureHandle(id)) procedures, err := client.Procedures.Show(ctx, sdk.NewShowProcedureRequest()) require.NoError(t, err) @@ -279,7 +279,7 @@ def filter_by_role(session, table_name, role): t.Run("create procedure for SQL: returns result data type", func(t *testing.T) { // https://docs.snowflake.com/en/developer-guide/stored-procedure/stored-procedures-snowflake-scripting name := "output_message" - id := testClientHelper().Ids.NewSchemaObjectIdentifier(name) + id := testClientHelper().Ids.NewSchemaObjectIdentifierWithArguments(name, sdk.DataTypeVARCHAR) definition := ` BEGIN @@ -287,10 +287,10 @@ def filter_by_role(session, table_name, role): END;` dt := sdk.NewProcedureReturnsResultDataTypeRequest(sdk.DataTypeVARCHAR) - returns := sdk.NewProcedureSQLReturnsRequest().WithResultDataType(dt).WithNotNull(sdk.Bool(true)) + returns := sdk.NewProcedureSQLReturnsRequest().WithResultDataType(*dt).WithNotNull(true) argument := sdk.NewProcedureArgumentRequest("message", sdk.DataTypeVARCHAR) - request := sdk.NewCreateForSQLProcedureRequest(id, *returns, definition). - WithOrReplace(sdk.Bool(true)). + request := sdk.NewCreateForSQLProcedureRequest(id.SchemaObjectId(), *returns, definition). + WithOrReplace(true). // Suddenly this is erroring out, when it used to not have an problem. Must be an error with the Snowflake API. // Created issue in docs-discuss channel. https://snowflake.slack.com/archives/C6380540P/p1707511734666249 // Error: Received unexpected error: @@ -301,7 +301,7 @@ def filter_by_role(session, table_name, role): WithArguments([]sdk.ProcedureArgumentRequest{*argument}) err := client.Procedures.CreateForSQL(ctx, request) require.NoError(t, err) - t.Cleanup(cleanupProcedureHandle(id, []sdk.DataType{sdk.DataTypeVARCHAR})) + t.Cleanup(cleanupProcedureHandle(id)) procedures, err := client.Procedures.Show(ctx, sdk.NewShowProcedureRequest()) require.NoError(t, err) @@ -310,7 +310,7 @@ def filter_by_role(session, table_name, role): t.Run("create procedure for SQL: returns table", func(t *testing.T) { name := "find_invoice_by_id" - id := testClientHelper().Ids.NewSchemaObjectIdentifier(name) + id := testClientHelper().Ids.NewSchemaObjectIdentifierWithArguments(name, sdk.DataTypeVARCHAR) definition := ` DECLARE @@ -321,16 +321,16 @@ def filter_by_role(session, table_name, role): column1 := sdk.NewProcedureColumnRequest("id", "INTEGER") column2 := sdk.NewProcedureColumnRequest("price", "NUMBER(12,2)") returnsTable := sdk.NewProcedureReturnsTableRequest().WithColumns([]sdk.ProcedureColumnRequest{*column1, *column2}) - returns := sdk.NewProcedureSQLReturnsRequest().WithTable(returnsTable) + returns := sdk.NewProcedureSQLReturnsRequest().WithTable(*returnsTable) argument := sdk.NewProcedureArgumentRequest("id", sdk.DataTypeVARCHAR) - request := sdk.NewCreateForSQLProcedureRequest(id, *returns, definition). - WithOrReplace(sdk.Bool(true)). + request := sdk.NewCreateForSQLProcedureRequest(id.SchemaObjectId(), *returns, definition). + WithOrReplace(true). // SNOW-1051627 todo: uncomment once null input behavior working again // WithNullInputBehavior(sdk.NullInputBehaviorPointer(sdk.NullInputBehaviorReturnNullInput)). WithArguments([]sdk.ProcedureArgumentRequest{*argument}) err := client.Procedures.CreateForSQL(ctx, request) require.NoError(t, err) - t.Cleanup(cleanupProcedureHandle(id, []sdk.DataType{sdk.DataTypeVARCHAR})) + t.Cleanup(cleanupProcedureHandle(id)) procedures, err := client.Procedures.Show(ctx, sdk.NewShowProcedureRequest()) require.NoError(t, err) @@ -345,7 +345,7 @@ func TestInt_OtherProcedureFunctions(t *testing.T) { tagTest, tagCleanup := testClientHelper().Tag.CreateTag(t) t.Cleanup(tagCleanup) - assertProcedure := func(t *testing.T, id sdk.SchemaObjectIdentifier, secure bool) { + assertProcedure := func(t *testing.T, id sdk.SchemaObjectIdentifierWithArguments, secure bool) { t.Helper() procedure, err := client.Procedures.ShowByID(ctx, id) @@ -366,9 +366,9 @@ func TestInt_OtherProcedureFunctions(t *testing.T) { assert.Equal(t, secure, procedure.IsSecure) } - cleanupProcedureHandle := func(id sdk.SchemaObjectIdentifier, ats []sdk.DataType) func() { + cleanupProcedureHandle := func(id sdk.SchemaObjectIdentifierWithArguments) func() { return func() { - err := client.Procedures.Drop(ctx, sdk.NewDropProcedureRequest(id, ats)) + err := client.Procedures.Drop(ctx, sdk.NewDropProcedureRequest(id)) if errors.Is(err, sdk.ErrObjectNotExistOrAuthorized) { return } @@ -383,27 +383,27 @@ func TestInt_OtherProcedureFunctions(t *testing.T) { BEGIN RETURN message; END;` - id := testClientHelper().Ids.RandomSchemaObjectIdentifier() + id := testClientHelper().Ids.RandomSchemaObjectIdentifierWithArguments(sdk.DataTypeVARCHAR) dt := sdk.NewProcedureReturnsResultDataTypeRequest(sdk.DataTypeVARCHAR) - returns := sdk.NewProcedureSQLReturnsRequest().WithResultDataType(dt).WithNotNull(sdk.Bool(true)) + returns := sdk.NewProcedureSQLReturnsRequest().WithResultDataType(*dt).WithNotNull(true) argument := sdk.NewProcedureArgumentRequest("message", sdk.DataTypeVARCHAR) - request := sdk.NewCreateForSQLProcedureRequest(id, *returns, definition). - WithSecure(sdk.Bool(true)). - WithOrReplace(sdk.Bool(true)). + request := sdk.NewCreateForSQLProcedureRequest(id.SchemaObjectId(), *returns, definition). + WithSecure(true). + WithOrReplace(true). WithArguments([]sdk.ProcedureArgumentRequest{*argument}). - WithExecuteAs(sdk.ExecuteAsPointer(sdk.ExecuteAsCaller)) + WithExecuteAs(*sdk.ExecuteAsPointer(sdk.ExecuteAsCaller)) err := client.Procedures.CreateForSQL(ctx, request) require.NoError(t, err) if cleanup { - t.Cleanup(cleanupProcedureHandle(id, []sdk.DataType{sdk.DataTypeVARCHAR})) + t.Cleanup(cleanupProcedureHandle(id)) } procedure, err := client.Procedures.ShowByID(ctx, id) require.NoError(t, err) return procedure } - defaultAlterRequest := func(id sdk.SchemaObjectIdentifier) *sdk.AlterProcedureRequest { - return sdk.NewAlterProcedureRequest(id, []sdk.DataType{sdk.DataTypeVARCHAR}) + defaultAlterRequest := func(id sdk.SchemaObjectIdentifierWithArguments) *sdk.AlterProcedureRequest { + return sdk.NewAlterProcedureRequest(id) } t.Run("alter procedure: rename", func(t *testing.T) { @@ -411,18 +411,20 @@ func TestInt_OtherProcedureFunctions(t *testing.T) { id := f.ID() nid := testClientHelper().Ids.RandomSchemaObjectIdentifier() - err := client.Procedures.Alter(ctx, defaultAlterRequest(id).WithRenameTo(&nid)) + nidWithArguments := sdk.NewSchemaObjectIdentifierWithArguments(nid.DatabaseName(), nid.SchemaName(), nid.Name(), id.ArgumentDataTypes()...) + + err := client.Procedures.Alter(ctx, defaultAlterRequest(id).WithRenameTo(nid)) if err != nil { - t.Cleanup(cleanupProcedureHandle(id, []sdk.DataType{sdk.DataTypeVARCHAR})) + t.Cleanup(cleanupProcedureHandle(id)) } else { - t.Cleanup(cleanupProcedureHandle(nid, []sdk.DataType{sdk.DataTypeVARCHAR})) + t.Cleanup(cleanupProcedureHandle(nidWithArguments)) } require.NoError(t, err) _, err = client.Procedures.ShowByID(ctx, id) assert.ErrorIs(t, err, collections.ErrObjectNotFound) - e, err := client.Procedures.ShowByID(ctx, nid) + e, err := client.Procedures.ShowByID(ctx, nidWithArguments) require.NoError(t, err) require.Equal(t, nid.Name(), e.Name) }) @@ -431,7 +433,7 @@ func TestInt_OtherProcedureFunctions(t *testing.T) { f := createProcedureForSQLHandle(t, true) id := f.ID() - err := client.Procedures.Alter(ctx, defaultAlterRequest(id).WithSetLogLevel(sdk.String("DEBUG"))) + err := client.Procedures.Alter(ctx, defaultAlterRequest(id).WithSetLogLevel("DEBUG")) require.NoError(t, err) assertProcedure(t, id, true) }) @@ -440,7 +442,7 @@ func TestInt_OtherProcedureFunctions(t *testing.T) { f := createProcedureForSQLHandle(t, true) id := f.ID() - err := client.Procedures.Alter(ctx, defaultAlterRequest(id).WithSetTraceLevel(sdk.String("ALWAYS"))) + err := client.Procedures.Alter(ctx, defaultAlterRequest(id).WithSetTraceLevel("ALWAYS")) require.NoError(t, err) assertProcedure(t, id, true) }) @@ -449,7 +451,7 @@ func TestInt_OtherProcedureFunctions(t *testing.T) { f := createProcedureForSQLHandle(t, true) id := f.ID() - err := client.Procedures.Alter(ctx, defaultAlterRequest(id).WithSetComment(sdk.String("comment"))) + err := client.Procedures.Alter(ctx, defaultAlterRequest(id).WithSetComment("comment")) require.NoError(t, err) assertProcedure(t, id, true) }) @@ -458,7 +460,7 @@ func TestInt_OtherProcedureFunctions(t *testing.T) { f := createProcedureForSQLHandle(t, true) id := f.ID() - err := client.Procedures.Alter(ctx, defaultAlterRequest(id).WithUnsetComment(sdk.Bool(true))) + err := client.Procedures.Alter(ctx, defaultAlterRequest(id).WithUnsetComment(true)) require.NoError(t, err) assertProcedure(t, id, true) }) @@ -467,7 +469,7 @@ func TestInt_OtherProcedureFunctions(t *testing.T) { f := createProcedureForSQLHandle(t, true) id := f.ID() - err := client.Procedures.Alter(ctx, defaultAlterRequest(id).WithExecuteAs(sdk.ExecuteAsPointer(sdk.ExecuteAsOwner))) + err := client.Procedures.Alter(ctx, defaultAlterRequest(id).WithExecuteAs(*sdk.ExecuteAsPointer(sdk.ExecuteAsOwner))) require.NoError(t, err) assertProcedure(t, id, true) }) @@ -510,7 +512,7 @@ func TestInt_OtherProcedureFunctions(t *testing.T) { f1 := createProcedureForSQLHandle(t, true) f2 := createProcedureForSQLHandle(t, true) - procedures, err := client.Procedures.Show(ctx, sdk.NewShowProcedureRequest().WithLike(&sdk.Like{Pattern: &f1.Name})) + procedures, err := client.Procedures.Show(ctx, sdk.NewShowProcedureRequest().WithLike(sdk.Like{Pattern: &f1.Name})) require.NoError(t, err) require.Equal(t, 1, len(procedures)) @@ -519,7 +521,7 @@ func TestInt_OtherProcedureFunctions(t *testing.T) { }) t.Run("show procedure for SQL: no matches", func(t *testing.T) { - procedures, err := client.Procedures.Show(ctx, sdk.NewShowProcedureRequest().WithLike(&sdk.Like{Pattern: sdk.String("non-existing-id-pattern")})) + procedures, err := client.Procedures.Show(ctx, sdk.NewShowProcedureRequest().WithLike(sdk.Like{Pattern: sdk.String("non-existing-id-pattern")})) require.NoError(t, err) require.Equal(t, 0, len(procedures)) }) @@ -528,8 +530,7 @@ func TestInt_OtherProcedureFunctions(t *testing.T) { f := createProcedureForSQLHandle(t, true) id := f.ID() - request := sdk.NewDescribeProcedureRequest(id, []sdk.DataType{sdk.DataTypeString}) - details, err := client.Procedures.Describe(ctx, request) + details, err := client.Procedures.Describe(ctx, id) require.NoError(t, err) pairs := make(map[string]string) for _, detail := range details { @@ -546,18 +547,18 @@ func TestInt_OtherProcedureFunctions(t *testing.T) { BEGIN RETURN message; END;` - id := testClientHelper().Ids.RandomSchemaObjectIdentifier() + id := testClientHelper().Ids.RandomSchemaObjectIdentifierWithArguments(sdk.DataTypeVARCHAR) dt := sdk.NewProcedureReturnsResultDataTypeRequest(sdk.DataTypeVARCHAR) - returns := sdk.NewProcedureSQLReturnsRequest().WithResultDataType(dt).WithNotNull(sdk.Bool(true)) + returns := sdk.NewProcedureSQLReturnsRequest().WithResultDataType(*dt).WithNotNull(true) argument := sdk.NewProcedureArgumentRequest("message", sdk.DataTypeVARCHAR) - request := sdk.NewCreateForSQLProcedureRequest(id, *returns, definition). - WithOrReplace(sdk.Bool(true)). + request := sdk.NewCreateForSQLProcedureRequest(id.SchemaObjectId(), *returns, definition). + WithOrReplace(true). WithArguments([]sdk.ProcedureArgumentRequest{*argument}). - WithExecuteAs(sdk.ExecuteAsPointer(sdk.ExecuteAsCaller)) + WithExecuteAs(*sdk.ExecuteAsPointer(sdk.ExecuteAsCaller)) err := client.Procedures.CreateForSQL(ctx, request) require.NoError(t, err) - err = client.Procedures.Drop(ctx, sdk.NewDropProcedureRequest(id, []sdk.DataType{sdk.DataTypeVARCHAR})) + err = client.Procedures.Drop(ctx, sdk.NewDropProcedureRequest(id)) require.NoError(t, err) }) } @@ -567,9 +568,9 @@ func TestInt_CallProcedure(t *testing.T) { ctx := testContext(t) databaseTest, schemaTest := testDb(t), testSchema(t) - cleanupProcedureHandle := func(id sdk.SchemaObjectIdentifier, ats []sdk.DataType) func() { + cleanupProcedureHandle := func(id sdk.SchemaObjectIdentifierWithArguments) func() { return func() { - err := client.Procedures.Drop(ctx, sdk.NewDropProcedureRequest(id, ats)) + err := client.Procedures.Drop(ctx, sdk.NewDropProcedureRequest(id)) if errors.Is(err, sdk.ErrObjectNotExistOrAuthorized) { return } @@ -601,19 +602,19 @@ func TestInt_CallProcedure(t *testing.T) { BEGIN RETURN message; END;` - id := testClientHelper().Ids.RandomSchemaObjectIdentifier() + id := testClientHelper().Ids.RandomSchemaObjectIdentifierWithArguments(sdk.DataTypeVARCHAR) dt := sdk.NewProcedureReturnsResultDataTypeRequest(sdk.DataTypeVARCHAR) - returns := sdk.NewProcedureSQLReturnsRequest().WithResultDataType(dt).WithNotNull(sdk.Bool(true)) + returns := sdk.NewProcedureSQLReturnsRequest().WithResultDataType(*dt).WithNotNull(true) argument := sdk.NewProcedureArgumentRequest("message", sdk.DataTypeVARCHAR) - request := sdk.NewCreateForSQLProcedureRequest(id, *returns, definition). - WithSecure(sdk.Bool(true)). - WithOrReplace(sdk.Bool(true)). + request := sdk.NewCreateForSQLProcedureRequest(id.SchemaObjectId(), *returns, definition). + WithSecure(true). + WithOrReplace(true). WithArguments([]sdk.ProcedureArgumentRequest{*argument}). - WithExecuteAs(sdk.ExecuteAsPointer(sdk.ExecuteAsCaller)) + WithExecuteAs(*sdk.ExecuteAsPointer(sdk.ExecuteAsCaller)) err := client.Procedures.CreateForSQL(ctx, request) require.NoError(t, err) if cleanup { - t.Cleanup(cleanupProcedureHandle(id, []sdk.DataType{sdk.DataTypeVARCHAR})) + t.Cleanup(cleanupProcedureHandle(id)) } procedure, err := client.Procedures.ShowByID(ctx, id) require.NoError(t, err) @@ -622,20 +623,20 @@ func TestInt_CallProcedure(t *testing.T) { t.Run("call procedure for SQL: argument positions", func(t *testing.T) { f := createProcedureForSQLHandle(t, true) - err := client.Procedures.Call(ctx, sdk.NewCallProcedureRequest(f.ID()).WithCallArguments([]string{"'hi'"})) + err := client.Procedures.Call(ctx, sdk.NewCallProcedureRequest(f.ID().SchemaObjectId()).WithCallArguments([]string{"'hi'"})) require.NoError(t, err) }) t.Run("call procedure for SQL: argument names", func(t *testing.T) { f := createProcedureForSQLHandle(t, true) - err := client.Procedures.Call(ctx, sdk.NewCallProcedureRequest(f.ID()).WithCallArguments([]string{"message => 'hi'"})) + err := client.Procedures.Call(ctx, sdk.NewCallProcedureRequest(f.ID().SchemaObjectId()).WithCallArguments([]string{"message => 'hi'"})) require.NoError(t, err) }) t.Run("call procedure for Java: returns table", func(t *testing.T) { // https://docs.snowflake.com/en/developer-guide/stored-procedure/stored-procedures-java#omitting-return-column-names-and-types name := "filter_by_role" - id := sdk.NewSchemaObjectIdentifier(databaseTest.Name, schemaTest.Name, name) + id := sdk.NewSchemaObjectIdentifierWithArguments(databaseTest.Name, schemaTest.Name, name, sdk.DataTypeVARCHAR, sdk.DataTypeVARCHAR) definition := ` import com.snowflake.snowpark_java.*; @@ -650,27 +651,27 @@ func TestInt_CallProcedure(t *testing.T) { column2 := sdk.NewProcedureColumnRequest("name", sdk.DataTypeVARCHAR) column3 := sdk.NewProcedureColumnRequest("role", sdk.DataTypeVARCHAR) returnsTable := sdk.NewProcedureReturnsTableRequest().WithColumns([]sdk.ProcedureColumnRequest{*column1, *column2, *column3}) - returns := sdk.NewProcedureReturnsRequest().WithTable(returnsTable) + returns := sdk.NewProcedureReturnsRequest().WithTable(*returnsTable) arg1 := sdk.NewProcedureArgumentRequest("name", sdk.DataTypeVARCHAR) arg2 := sdk.NewProcedureArgumentRequest("role", sdk.DataTypeVARCHAR) packages := []sdk.ProcedurePackageRequest{*sdk.NewProcedurePackageRequest("com.snowflake:snowpark:latest")} - request := sdk.NewCreateForJavaProcedureRequest(id, *returns, "11", packages, "Filter.filterByRole"). - WithOrReplace(sdk.Bool(true)). + request := sdk.NewCreateForJavaProcedureRequest(id.SchemaObjectId(), *returns, "11", packages, "Filter.filterByRole"). + WithOrReplace(true). WithArguments([]sdk.ProcedureArgumentRequest{*arg1, *arg2}). - WithProcedureDefinition(sdk.String(definition)) + WithProcedureDefinition(definition) err := client.Procedures.CreateForJava(ctx, request) require.NoError(t, err) - t.Cleanup(cleanupProcedureHandle(id, []sdk.DataType{sdk.DataTypeVARCHAR, sdk.DataTypeVARCHAR})) + t.Cleanup(cleanupProcedureHandle(id)) ca := []string{fmt.Sprintf(`'%s'`, tid.FullyQualifiedName()), "'dev'"} - err = client.Procedures.Call(ctx, sdk.NewCallProcedureRequest(id).WithCallArguments(ca)) + err = client.Procedures.Call(ctx, sdk.NewCallProcedureRequest(id.SchemaObjectId()).WithCallArguments(ca)) require.NoError(t, err) }) t.Run("call procedure for Scala: returns table", func(t *testing.T) { // https://docs.snowflake.com/en/developer-guide/stored-procedure/stored-procedures-scala#omitting-return-column-names-and-types name := "filter_by_role" - id := sdk.NewSchemaObjectIdentifier(databaseTest.Name, schemaTest.Name, name) + id := sdk.NewSchemaObjectIdentifierWithArguments(databaseTest.Name, schemaTest.Name, name, sdk.DataTypeVARCHAR, sdk.DataTypeVARCHAR) definition := ` import com.snowflake.snowpark.functions._ @@ -684,27 +685,27 @@ func TestInt_CallProcedure(t *testing.T) { } }` returnsTable := sdk.NewProcedureReturnsTableRequest().WithColumns([]sdk.ProcedureColumnRequest{}) - returns := sdk.NewProcedureReturnsRequest().WithTable(returnsTable) + returns := sdk.NewProcedureReturnsRequest().WithTable(*returnsTable) arg1 := sdk.NewProcedureArgumentRequest("name", sdk.DataTypeVARCHAR) arg2 := sdk.NewProcedureArgumentRequest("role", sdk.DataTypeVARCHAR) packages := []sdk.ProcedurePackageRequest{*sdk.NewProcedurePackageRequest("com.snowflake:snowpark:latest")} - request := sdk.NewCreateForScalaProcedureRequest(id, *returns, "2.12", packages, "Filter.filterByRole"). - WithOrReplace(sdk.Bool(true)). + request := sdk.NewCreateForScalaProcedureRequest(id.SchemaObjectId(), *returns, "2.12", packages, "Filter.filterByRole"). + WithOrReplace(true). WithArguments([]sdk.ProcedureArgumentRequest{*arg1, *arg2}). - WithProcedureDefinition(sdk.String(definition)) + WithProcedureDefinition(definition) err := client.Procedures.CreateForScala(ctx, request) require.NoError(t, err) - t.Cleanup(cleanupProcedureHandle(id, []sdk.DataType{sdk.DataTypeVARCHAR, sdk.DataTypeVARCHAR})) + t.Cleanup(cleanupProcedureHandle(id)) ca := []string{fmt.Sprintf(`'%s'`, tid.FullyQualifiedName()), "'dev'"} - err = client.Procedures.Call(ctx, sdk.NewCallProcedureRequest(id).WithCallArguments(ca)) + err = client.Procedures.Call(ctx, sdk.NewCallProcedureRequest(id.SchemaObjectId()).WithCallArguments(ca)) require.NoError(t, err) }) t.Run("call procedure for Javascript", func(t *testing.T) { // https://docs.snowflake.com/en/developer-guide/stored-procedure/stored-procedures-javascript#basic-examples name := "stproc1" - id := sdk.NewSchemaObjectIdentifier(databaseTest.Name, schemaTest.Name, name) + id := sdk.NewSchemaObjectIdentifierWithArguments(databaseTest.Name, schemaTest.Name, name, sdk.DataTypeFloat) definition := ` var sql_command = "INSERT INTO stproc_test_table1 (num_col1) VALUES (" + FLOAT_PARAM1 + ")"; @@ -718,37 +719,37 @@ func TestInt_CallProcedure(t *testing.T) { return "Failed: " + err; // Return a success/error indicator. }` arg := sdk.NewProcedureArgumentRequest("FLOAT_PARAM1", sdk.DataTypeFloat) - request := sdk.NewCreateForJavaScriptProcedureRequest(id, sdk.DataTypeString, definition). - WithOrReplace(sdk.Bool(true)). + request := sdk.NewCreateForJavaScriptProcedureRequest(id.SchemaObjectId(), sdk.DataTypeString, definition). + WithOrReplace(true). WithArguments([]sdk.ProcedureArgumentRequest{*arg}). - WithNullInputBehavior(sdk.NullInputBehaviorPointer(sdk.NullInputBehaviorStrict)). - WithExecuteAs(sdk.ExecuteAsPointer(sdk.ExecuteAsOwner)) + WithNullInputBehavior(*sdk.NullInputBehaviorPointer(sdk.NullInputBehaviorStrict)). + WithExecuteAs(*sdk.ExecuteAsPointer(sdk.ExecuteAsOwner)) err := client.Procedures.CreateForJavaScript(ctx, request) require.NoError(t, err) - t.Cleanup(cleanupProcedureHandle(id, []sdk.DataType{sdk.DataTypeFloat})) + t.Cleanup(cleanupProcedureHandle(id)) - err = client.Procedures.Call(ctx, sdk.NewCallProcedureRequest(id).WithCallArguments([]string{"5.14::FLOAT"})) + err = client.Procedures.Call(ctx, sdk.NewCallProcedureRequest(id.SchemaObjectId()).WithCallArguments([]string{"5.14::FLOAT"})) require.NoError(t, err) }) t.Run("call procedure for Javascript: no arguments", func(t *testing.T) { // https://docs.snowflake.com/en/developer-guide/stored-procedure/stored-procedures-javascript#basic-examples name := "sp_pi" - id := sdk.NewSchemaObjectIdentifier(databaseTest.Name, schemaTest.Name, name) + id := sdk.NewSchemaObjectIdentifierWithArguments(databaseTest.Name, schemaTest.Name, name) definition := `return 3.1415926;` - request := sdk.NewCreateForJavaScriptProcedureRequest(id, sdk.DataTypeFloat, definition).WithNotNull(sdk.Bool(true)).WithOrReplace(sdk.Bool(true)) + request := sdk.NewCreateForJavaScriptProcedureRequest(id.SchemaObjectId(), sdk.DataTypeFloat, definition).WithNotNull(true).WithOrReplace(true) err := client.Procedures.CreateForJavaScript(ctx, request) require.NoError(t, err) - t.Cleanup(cleanupProcedureHandle(id, nil)) + t.Cleanup(cleanupProcedureHandle(id)) - err = client.Procedures.Call(ctx, sdk.NewCallProcedureRequest(id)) + err = client.Procedures.Call(ctx, sdk.NewCallProcedureRequest(id.SchemaObjectId())) require.NoError(t, err) }) t.Run("call procedure for Python: returns table", func(t *testing.T) { // https://docs.snowflake.com/en/developer-guide/stored-procedure/stored-procedures-python#omitting-return-column-names-and-types - id := sdk.NewSchemaObjectIdentifier(databaseTest.Name, schemaTest.Name, "filterByRole") + id := sdk.NewSchemaObjectIdentifierWithArguments(databaseTest.Name, schemaTest.Name, "filterByRole", sdk.DataTypeVARCHAR, sdk.DataTypeVARCHAR) definition := ` from snowflake.snowpark.functions import col @@ -756,21 +757,21 @@ def filter_by_role(session, name, role): df = session.table(name) return df.filter(col("role") == role)` returnsTable := sdk.NewProcedureReturnsTableRequest().WithColumns([]sdk.ProcedureColumnRequest{}) - returns := sdk.NewProcedureReturnsRequest().WithTable(returnsTable) + returns := sdk.NewProcedureReturnsRequest().WithTable(*returnsTable) arg1 := sdk.NewProcedureArgumentRequest("name", sdk.DataTypeVARCHAR) arg2 := sdk.NewProcedureArgumentRequest("role", sdk.DataTypeVARCHAR) packages := []sdk.ProcedurePackageRequest{*sdk.NewProcedurePackageRequest("snowflake-snowpark-python")} - request := sdk.NewCreateForPythonProcedureRequest(id, *returns, "3.8", packages, "filter_by_role"). - WithOrReplace(sdk.Bool(true)). + request := sdk.NewCreateForPythonProcedureRequest(id.SchemaObjectId(), *returns, "3.8", packages, "filter_by_role"). + WithOrReplace(true). WithArguments([]sdk.ProcedureArgumentRequest{*arg1, *arg2}). - WithProcedureDefinition(sdk.String(definition)) + WithProcedureDefinition(definition) err := client.Procedures.CreateForPython(ctx, request) require.NoError(t, err) - t.Cleanup(cleanupProcedureHandle(id, []sdk.DataType{sdk.DataTypeVARCHAR, sdk.DataTypeVARCHAR})) + t.Cleanup(cleanupProcedureHandle(id)) - id = sdk.NewSchemaObjectIdentifier(databaseTest.Name, schemaTest.Name, "filterByRole") + id = sdk.NewSchemaObjectIdentifierWithArguments(databaseTest.Name, schemaTest.Name, "filterByRole", sdk.DataTypeVARCHAR) ca := []string{fmt.Sprintf(`'%s'`, tid.FullyQualifiedName()), "'dev'"} - err = client.Procedures.Call(ctx, sdk.NewCallProcedureRequest(id).WithCallArguments(ca)) + err = client.Procedures.Call(ctx, sdk.NewCallProcedureRequest(id.SchemaObjectId()).WithCallArguments(ca)) require.NoError(t, err) }) } @@ -815,14 +816,14 @@ func TestInt_CreateAndCallProcedures(t *testing.T) { column2 := sdk.NewProcedureColumnRequest("name", sdk.DataTypeVARCHAR) column3 := sdk.NewProcedureColumnRequest("role", sdk.DataTypeVARCHAR) returnsTable := sdk.NewProcedureReturnsTableRequest().WithColumns([]sdk.ProcedureColumnRequest{*column1, *column2, *column3}) - returns := sdk.NewProcedureReturnsRequest().WithTable(returnsTable) + returns := sdk.NewProcedureReturnsRequest().WithTable(*returnsTable) arg1 := sdk.NewProcedureArgumentRequest("name", sdk.DataTypeVARCHAR) arg2 := sdk.NewProcedureArgumentRequest("role", sdk.DataTypeVARCHAR) packages := []sdk.ProcedurePackageRequest{*sdk.NewProcedurePackageRequest("com.snowflake:snowpark:latest")} ca := []string{fmt.Sprintf(`'%s'`, tid.FullyQualifiedName()), "'dev'"} request := sdk.NewCreateAndCallForJavaProcedureRequest(name, *returns, "11", packages, "Filter.filterByRole", name). WithArguments([]sdk.ProcedureArgumentRequest{*arg1, *arg2}). - WithProcedureDefinition(sdk.String(definition)). + WithProcedureDefinition(definition). WithCallArguments(ca) err := client.Procedures.CreateAndCallForJava(ctx, request) require.NoError(t, err) @@ -848,14 +849,14 @@ func TestInt_CreateAndCallProcedures(t *testing.T) { column2 := sdk.NewProcedureColumnRequest("name", sdk.DataTypeVARCHAR) column3 := sdk.NewProcedureColumnRequest("role", sdk.DataTypeVARCHAR) returnsTable := sdk.NewProcedureReturnsTableRequest().WithColumns([]sdk.ProcedureColumnRequest{*column1, *column2, *column3}) - returns := sdk.NewProcedureReturnsRequest().WithTable(returnsTable) + returns := sdk.NewProcedureReturnsRequest().WithTable(*returnsTable) arg1 := sdk.NewProcedureArgumentRequest("name", sdk.DataTypeVARCHAR) arg2 := sdk.NewProcedureArgumentRequest("role", sdk.DataTypeVARCHAR) packages := []sdk.ProcedurePackageRequest{*sdk.NewProcedurePackageRequest("com.snowflake:snowpark:latest")} ca := []string{fmt.Sprintf(`'%s'`, tid.FullyQualifiedName()), "'dev'"} request := sdk.NewCreateAndCallForScalaProcedureRequest(name, *returns, "2.12", packages, "Filter.filterByRole", name). WithArguments([]sdk.ProcedureArgumentRequest{*arg1, *arg2}). - WithProcedureDefinition(sdk.String(definition)). + WithProcedureDefinition(definition). WithCallArguments(ca) err := client.Procedures.CreateAndCallForScala(ctx, request) require.NoError(t, err) @@ -880,7 +881,7 @@ func TestInt_CreateAndCallProcedures(t *testing.T) { arg := sdk.NewProcedureArgumentRequest("FLOAT_PARAM1", sdk.DataTypeFloat) request := sdk.NewCreateAndCallForJavaScriptProcedureRequest(name, sdk.DataTypeString, definition, name). WithArguments([]sdk.ProcedureArgumentRequest{*arg}). - WithNullInputBehavior(sdk.NullInputBehaviorPointer(sdk.NullInputBehaviorStrict)). + WithNullInputBehavior(*sdk.NullInputBehaviorPointer(sdk.NullInputBehaviorStrict)). WithCallArguments([]string{"5.14::FLOAT"}) err := client.Procedures.CreateAndCallForJavaScript(ctx, request) require.NoError(t, err) @@ -892,7 +893,7 @@ func TestInt_CreateAndCallProcedures(t *testing.T) { name := sdk.NewAccountObjectIdentifier("sp_pi") definition := `return 3.1415926;` - request := sdk.NewCreateAndCallForJavaScriptProcedureRequest(name, sdk.DataTypeFloat, definition, name).WithNotNull(sdk.Bool(true)) + request := sdk.NewCreateAndCallForJavaScriptProcedureRequest(name, sdk.DataTypeFloat, definition, name).WithNotNull(true) err := client.Procedures.CreateAndCallForJavaScript(ctx, request) require.NoError(t, err) }) @@ -905,7 +906,7 @@ func TestInt_CreateAndCallProcedures(t *testing.T) { name := testClientHelper().Ids.RandomAccountObjectIdentifier() dt := sdk.NewProcedureReturnsResultDataTypeRequest(sdk.DataTypeVARCHAR) - returns := sdk.NewProcedureReturnsRequest().WithResultDataType(dt) + returns := sdk.NewProcedureReturnsRequest().WithResultDataType(*dt) argument := sdk.NewProcedureArgumentRequest("message", sdk.DataTypeVARCHAR) request := sdk.NewCreateAndCallForSQLProcedureRequest(name, *returns, definition, name). WithArguments([]sdk.ProcedureArgumentRequest{*argument}). @@ -924,14 +925,14 @@ def filter_by_role(session, name, role): df = session.table(name) return df.filter(col("role") == role)` returnsTable := sdk.NewProcedureReturnsTableRequest().WithColumns([]sdk.ProcedureColumnRequest{}) - returns := sdk.NewProcedureReturnsRequest().WithTable(returnsTable) + returns := sdk.NewProcedureReturnsRequest().WithTable(*returnsTable) arg1 := sdk.NewProcedureArgumentRequest("name", sdk.DataTypeVARCHAR) arg2 := sdk.NewProcedureArgumentRequest("role", sdk.DataTypeVARCHAR) packages := []sdk.ProcedurePackageRequest{*sdk.NewProcedurePackageRequest("snowflake-snowpark-python")} ca := []string{fmt.Sprintf(`'%s'`, tid.FullyQualifiedName()), "'dev'"} request := sdk.NewCreateAndCallForPythonProcedureRequest(name, *returns, "3.8", packages, "filter_by_role", name). WithArguments([]sdk.ProcedureArgumentRequest{*arg1, *arg2}). - WithProcedureDefinition(sdk.String(definition)). + WithProcedureDefinition(definition). WithCallArguments(ca) err := client.Procedures.CreateAndCallForPython(ctx, request) require.NoError(t, err) @@ -954,7 +955,7 @@ def filter_by_role(session, name, role): column2 := sdk.NewProcedureColumnRequest("name", sdk.DataTypeVARCHAR) column3 := sdk.NewProcedureColumnRequest("role", sdk.DataTypeVARCHAR) returnsTable := sdk.NewProcedureReturnsTableRequest().WithColumns([]sdk.ProcedureColumnRequest{*column1, *column2, *column3}) - returns := sdk.NewProcedureReturnsRequest().WithTable(returnsTable) + returns := sdk.NewProcedureReturnsRequest().WithTable(*returnsTable) arg1 := sdk.NewProcedureArgumentRequest("name", sdk.DataTypeVARCHAR) arg2 := sdk.NewProcedureArgumentRequest("role", sdk.DataTypeVARCHAR) packages := []sdk.ProcedurePackageRequest{*sdk.NewProcedurePackageRequest("com.snowflake:snowpark:latest")} @@ -966,8 +967,8 @@ def filter_by_role(session, name, role): clause := sdk.NewProcedureWithClauseRequest(cte, statement).WithCteColumns([]string{"name", "role"}) request := sdk.NewCreateAndCallForJavaProcedureRequest(name, *returns, "11", packages, "Filter.filterByRole", name). WithArguments([]sdk.ProcedureArgumentRequest{*arg1, *arg2}). - WithProcedureDefinition(sdk.String(definition)). - WithWithClause(clause). + WithProcedureDefinition(definition). + WithWithClause(*clause). WithCallArguments(ca) err := client.Procedures.CreateAndCallForJava(ctx, request) require.NoError(t, err) @@ -978,9 +979,9 @@ func TestInt_ProceduresShowByID(t *testing.T) { client := testClient(t) ctx := testContext(t) - cleanupProcedureHandle := func(id sdk.SchemaObjectIdentifier, dts []sdk.DataType) func() { + cleanupProcedureHandle := func(id sdk.SchemaObjectIdentifierWithArguments) func() { return func() { - err := client.Procedures.Drop(ctx, sdk.NewDropProcedureRequest(id, dts)) + err := client.Procedures.Drop(ctx, sdk.NewDropProcedureRequest(id)) if errors.Is(err, sdk.ErrObjectNotExistOrAuthorized) { return } @@ -988,7 +989,7 @@ func TestInt_ProceduresShowByID(t *testing.T) { } } - createProcedureForSQLHandle := func(t *testing.T, id sdk.SchemaObjectIdentifier) { + createProcedureForSQLHandle := func(t *testing.T, id sdk.SchemaObjectIdentifierWithArguments) { t.Helper() definition := ` @@ -996,22 +997,22 @@ func TestInt_ProceduresShowByID(t *testing.T) { RETURN message; END;` dt := sdk.NewProcedureReturnsResultDataTypeRequest(sdk.DataTypeVARCHAR) - returns := sdk.NewProcedureSQLReturnsRequest().WithResultDataType(dt).WithNotNull(sdk.Bool(true)) + returns := sdk.NewProcedureSQLReturnsRequest().WithResultDataType(*dt).WithNotNull(true) argument := sdk.NewProcedureArgumentRequest("message", sdk.DataTypeVARCHAR) - request := sdk.NewCreateForSQLProcedureRequest(id, *returns, definition). + request := sdk.NewCreateForSQLProcedureRequest(id.SchemaObjectId(), *returns, definition). WithArguments([]sdk.ProcedureArgumentRequest{*argument}). - WithExecuteAs(sdk.ExecuteAsPointer(sdk.ExecuteAsCaller)) + WithExecuteAs(*sdk.ExecuteAsPointer(sdk.ExecuteAsCaller)) err := client.Procedures.CreateForSQL(ctx, request) require.NoError(t, err) - t.Cleanup(cleanupProcedureHandle(id, []sdk.DataType{sdk.DataTypeVARCHAR})) + t.Cleanup(cleanupProcedureHandle(id)) } t.Run("show by id - same name in different schemas", func(t *testing.T) { schema, schemaCleanup := testClientHelper().Schema.CreateSchema(t) t.Cleanup(schemaCleanup) - id1 := testClientHelper().Ids.RandomSchemaObjectIdentifier() - id2 := testClientHelper().Ids.NewSchemaObjectIdentifierInSchema(id1.Name(), schema.ID()) + id1 := testClientHelper().Ids.RandomSchemaObjectIdentifierWithArguments() + id2 := testClientHelper().Ids.NewSchemaObjectIdentifierWithArgumentsInSchema(id1.Name(), schema.ID()) createProcedureForSQLHandle(t, id1) createProcedureForSQLHandle(t, id2) From 14067c57f4df27cca3462f5d39b6e29062bbc575 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20Cie=C5=9Blak?= Date: Thu, 8 Aug 2024 13:30:05 +0200 Subject: [PATCH 02/19] wip --- pkg/acceptance/check_destroy.go | 6 +- pkg/datasources/procedures.go | 2 +- pkg/resources/function_acceptance_test.go | 170 +- pkg/resources/function_state_upgraders.go | 2 + pkg/resources/procedure.go | 1563 +++++++++-------- pkg/resources/procedure_acceptance_test.go | 11 +- pkg/resources/procedure_state_upgraders.go | 2 + pkg/sdk/functions_impl_gen.go | 5 +- pkg/sdk/identifier_helpers.go | 22 +- pkg/sdk/identifier_helpers_test.go | 40 - pkg/sdk/poc/main.go | 1 - pkg/sdk/procedures_gen.go | 6 +- pkg/sdk/procedures_impl_gen.go | 19 +- .../testint/procedures_integration_test.go | 6 +- 14 files changed, 942 insertions(+), 913 deletions(-) diff --git a/pkg/acceptance/check_destroy.go b/pkg/acceptance/check_destroy.go index 6cd049eed0..617c90e5d2 100644 --- a/pkg/acceptance/check_destroy.go +++ b/pkg/acceptance/check_destroy.go @@ -118,9 +118,9 @@ var showByIdFunctions = map[resources.Resource]showByIdFunc{ resources.FileFormat: func(ctx context.Context, client *sdk.Client, id sdk.ObjectIdentifier) error { return runShowById(ctx, id, client.FileFormats.ShowByID) }, - //resources.Function: func(ctx context.Context, client *sdk.Client, id sdk.ObjectIdentifier) error { - // return runShowById(ctx, id, client.Functions.ShowByID) - //}, + resources.Function: func(ctx context.Context, client *sdk.Client, id sdk.ObjectIdentifier) error { + return runShowById(ctx, id, client.Functions.ShowByID) + }, resources.ManagedAccount: func(ctx context.Context, client *sdk.Client, id sdk.ObjectIdentifier) error { return runShowById(ctx, id, client.ManagedAccounts.ShowByID) }, diff --git a/pkg/datasources/procedures.go b/pkg/datasources/procedures.go index c4aaed7c69..28a1ddc563 100644 --- a/pkg/datasources/procedures.go +++ b/pkg/datasources/procedures.go @@ -103,7 +103,7 @@ func ReadContextProcedures(ctx context.Context, d *schema.ResourceData, meta int procedureMap["database"] = procedure.CatalogName procedureMap["schema"] = procedure.SchemaName procedureMap["comment"] = procedure.Description - procedureSignatureMap, err := parseArguments(procedure.Arguments) + procedureSignatureMap, err := parseArguments(procedure.ArgumentsRaw) if err != nil { return diag.FromErr(err) } diff --git a/pkg/resources/function_acceptance_test.go b/pkg/resources/function_acceptance_test.go index b4e99be301..104cc98e67 100644 --- a/pkg/resources/function_acceptance_test.go +++ b/pkg/resources/function_acceptance_test.go @@ -236,40 +236,6 @@ func TestAcc_Function_migrateFromVersion085(t *testing.T) { }) } -func TestAcc_Function_EnsureSmoothResourceIdMigrationToV0950(t *testing.T) { - name := acc.TestClient().Ids.RandomAccountObjectIdentifier().Name() - resourceName := "snowflake_function.f" - - resource.Test(t, resource.TestCase{ - PreCheck: func() { acc.TestAccPreCheck(t) }, - TerraformVersionChecks: []tfversion.TerraformVersionCheck{ - tfversion.RequireAbove(tfversion.Version1_5_0), - }, - CheckDestroy: acc.CheckDestroy(t, resources.Function), - Steps: []resource.TestStep{ - { - ExternalProviders: map[string]resource.ExternalProvider{ - "snowflake": { - VersionConstraint: "=0.94.1", - Source: "Snowflake-Labs/snowflake", - }, - }, - Config: functionConfigWithMoreArguments(acc.TestDatabaseName, acc.TestSchemaName, name), - Check: resource.ComposeTestCheckFunc( - resource.TestCheckResourceAttr(resourceName, "id", fmt.Sprintf(`"%s"."%s"."%s"(VARCHAR, FLOAT, NUMBER)`, acc.TestDatabaseName, acc.TestSchemaName, name)), - ), - }, - { - ProtoV6ProviderFactories: acc.TestAccProtoV6ProviderFactories, - Config: functionConfigWithMoreArguments(acc.TestDatabaseName, acc.TestSchemaName, name), - Check: resource.ComposeTestCheckFunc( - resource.TestCheckResourceAttr(resourceName, "id", fmt.Sprintf(`"%s"."%s"."%s"(VARCHAR, FLOAT, NUMBER)`, acc.TestDatabaseName, acc.TestSchemaName, name)), - ), - }, - }, - }) -} - func TestAcc_Function_Rename(t *testing.T) { name := acc.TestClient().Ids.Alpha() newName := acc.TestClient().Ids.Alpha() @@ -311,32 +277,6 @@ func TestAcc_Function_Rename(t *testing.T) { }) } -func functionConfigWithMoreArguments(database string, schema string, name string) string { - return fmt.Sprintf(` -resource "snowflake_function" "f" { - database = "%[1]s" - schema = "%[2]s" - name = "%[3]s" - return_type = "VARCHAR" - return_behavior = "IMMUTABLE" - statement = "SELECT A" - - arguments { - name = "A" - type = "VARCHAR" - } - arguments { - name = "B" - type = "FLOAT" - } - arguments { - name = "C" - type = "NUMBER" - } -} -`, database, schema, name) -} - func functionConfig(database string, schema string, name string, comment string) string { return fmt.Sprintf(` resource "snowflake_function" "f" { @@ -395,3 +335,113 @@ resource "snowflake_function" "f" { } `, database, schema, name) } + +// TODO: test new state upgrader + +func TestAcc_Function_EnsureSmoothResourceIdMigrationToV0950(t *testing.T) { + name := acc.TestClient().Ids.RandomAccountObjectIdentifier().Name() + resourceName := "snowflake_function.f" + + resource.Test(t, resource.TestCase{ + PreCheck: func() { acc.TestAccPreCheck(t) }, + TerraformVersionChecks: []tfversion.TerraformVersionCheck{ + tfversion.RequireAbove(tfversion.Version1_5_0), + }, + CheckDestroy: acc.CheckDestroy(t, resources.Function), + Steps: []resource.TestStep{ + { + ExternalProviders: map[string]resource.ExternalProvider{ + "snowflake": { + VersionConstraint: "=0.94.1", + Source: "Snowflake-Labs/snowflake", + }, + }, + Config: functionConfigWithMoreArguments(acc.TestDatabaseName, acc.TestSchemaName, name), + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr(resourceName, "id", fmt.Sprintf(`"%s"."%s"."%s"(VARCHAR, FLOAT, NUMBER)`, acc.TestDatabaseName, acc.TestSchemaName, name)), + ), + }, + { + ProtoV6ProviderFactories: acc.TestAccProtoV6ProviderFactories, + Config: functionConfigWithMoreArguments(acc.TestDatabaseName, acc.TestSchemaName, name), + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr(resourceName, "id", fmt.Sprintf(`"%s"."%s"."%s"(VARCHAR, FLOAT, NUMBER)`, acc.TestDatabaseName, acc.TestSchemaName, name)), + ), + }, + }, + }) +} + +func functionConfigWithMoreArguments(database string, schema string, name string) string { + return fmt.Sprintf(` +resource "snowflake_function" "f" { + database = "%[1]s" + schema = "%[2]s" + name = "%[3]s" + return_type = "VARCHAR" + return_behavior = "IMMUTABLE" + statement = "SELECT A" + + arguments { + name = "A" + type = "VARCHAR" + } + arguments { + name = "B" + type = "FLOAT" + } + arguments { + name = "C" + type = "NUMBER" + } +} +`, database, schema, name) +} + +func TestAcc_Function_EnsureSmoothResourceIdMigrationToV0950_WithoutArguments(t *testing.T) { + name := acc.TestClient().Ids.RandomAccountObjectIdentifier().Name() + resourceName := "snowflake_function.f" + + resource.Test(t, resource.TestCase{ + PreCheck: func() { acc.TestAccPreCheck(t) }, + TerraformVersionChecks: []tfversion.TerraformVersionCheck{ + tfversion.RequireAbove(tfversion.Version1_5_0), + }, + CheckDestroy: acc.CheckDestroy(t, resources.Function), + Steps: []resource.TestStep{ + { + ExternalProviders: map[string]resource.ExternalProvider{ + "snowflake": { + VersionConstraint: "=0.94.1", + Source: "Snowflake-Labs/snowflake", + }, + }, + Config: functionConfigWithoutArguments(acc.TestDatabaseName, acc.TestSchemaName, name), + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr(resourceName, "id", fmt.Sprintf(`"%s"."%s"."%s"`, acc.TestDatabaseName, acc.TestSchemaName, name)), + ), + }, + { + // TODO: Fails + ProtoV6ProviderFactories: acc.TestAccProtoV6ProviderFactories, + Config: functionConfigWithoutArguments(acc.TestDatabaseName, acc.TestSchemaName, name), + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr(resourceName, "id", fmt.Sprintf(`"%s"."%s"."%s"()`, acc.TestDatabaseName, acc.TestSchemaName, name)), + ), + }, + }, + }) +} + +func functionConfigWithoutArguments(database string, schema string, name string) string { + return fmt.Sprintf(` +resource "snowflake_function" "f" { + database = "%[1]s" + schema = "%[2]s" + name = "%[3]s" + return_type = "VARCHAR" + return_behavior = "IMMUTABLE" + statement = "SELECT 'abc'" +} +`, database, schema, name) +} diff --git a/pkg/resources/function_state_upgraders.go b/pkg/resources/function_state_upgraders.go index 501e44f1dc..bae81663a1 100644 --- a/pkg/resources/function_state_upgraders.go +++ b/pkg/resources/function_state_upgraders.go @@ -60,3 +60,5 @@ func v085FunctionIdStateUpgrader(ctx context.Context, rawState map[string]interf return rawState, nil } + +// TODO: state upgrader for empty args (without '()') diff --git a/pkg/resources/procedure.go b/pkg/resources/procedure.go index c9794742b6..69d2cd9b8f 100644 --- a/pkg/resources/procedure.go +++ b/pkg/resources/procedure.go @@ -1,785 +1,792 @@ package resources -import "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" - -// var procedureSchema = map[string]*schema.Schema{ -// "name": { -// Type: schema.TypeString, -// Required: true, -// Description: "Specifies the identifier for the procedure; does not have to be unique for the schema in which the procedure is created. Don't use the | character.", -// }, -// "database": { -// Type: schema.TypeString, -// Required: true, -// Description: "The database in which to create the procedure. Don't use the | character.", -// ForceNew: true, -// }, -// "schema": { -// Type: schema.TypeString, -// Required: true, -// Description: "The schema in which to create the procedure. Don't use the | character.", -// ForceNew: true, -// }, -// "secure": { -// Type: schema.TypeBool, -// Optional: true, -// Description: "Specifies that the procedure is secure. For more information about secure procedures, see Protecting Sensitive Information with Secure UDFs and Stored Procedures.", -// Default: false, -// }, -// "arguments": { -// Type: schema.TypeList, -// Elem: &schema.Resource{ -// Schema: map[string]*schema.Schema{ -// "name": { -// Type: schema.TypeString, -// Required: true, -// // Suppress the diff shown if the values are equal when both compared in lower case. -// DiffSuppressFunc: func(k, old, new string, d *schema.ResourceData) bool { -// return strings.EqualFold(old, new) -// }, -// Description: "The argument name", -// }, -// "type": { -// Type: schema.TypeString, -// Required: true, -// ValidateFunc: dataTypeValidateFunc, -// DiffSuppressFunc: dataTypeDiffSuppressFunc, -// Description: "The argument type", -// }, -// }, -// }, -// Optional: true, -// Description: "List of the arguments for the procedure", -// ForceNew: true, -// }, -// "return_type": { -// Type: schema.TypeString, -// Description: "The return type of the procedure", -// // Suppress the diff shown if the values are equal when both compared in lower case. -// DiffSuppressFunc: func(k, old, new string, d *schema.ResourceData) bool { -// if strings.EqualFold(old, new) { -// return true -// } -// -// varcharType := []string{"VARCHAR(16777216)", "VARCHAR", "text", "string", "NVARCHAR", "NVARCHAR2", "CHAR VARYING", "NCHAR VARYING"} -// if slices.Contains(varcharType, strings.ToUpper(old)) && slices.Contains(varcharType, strings.ToUpper(new)) { -// return true -// } -// -// // all these types are equivalent https://docs.snowflake.com/en/sql-reference/data-types-numeric.html#int-integer-bigint-smallint-tinyint-byteint -// integerTypes := []string{"INT", "INTEGER", "BIGINT", "SMALLINT", "TINYINT", "BYTEINT", "NUMBER(38,0)"} -// if slices.Contains(integerTypes, strings.ToUpper(old)) && slices.Contains(integerTypes, strings.ToUpper(new)) { -// return true -// } -// return false -// }, -// Required: true, -// ForceNew: true, -// }, -// "statement": { -// Type: schema.TypeString, -// Required: true, -// Description: "Specifies the code used to create the procedure.", -// ForceNew: true, -// DiffSuppressFunc: DiffSuppressStatement, -// }, -// "language": { -// Type: schema.TypeString, -// Optional: true, -// Default: "SQL", -// DiffSuppressFunc: func(k, old, new string, d *schema.ResourceData) bool { -// return strings.EqualFold(old, new) -// }, -// ValidateFunc: validation.StringInSlice([]string{"javascript", "java", "scala", "SQL", "python"}, true), -// Description: "Specifies the language of the stored procedure code.", -// }, -// "execute_as": { -// Type: schema.TypeString, -// Optional: true, -// Default: "OWNER", -// DiffSuppressFunc: func(k, old, new string, d *schema.ResourceData) bool { -// return strings.EqualFold(old, new) -// }, -// ValidateFunc: validation.StringInSlice([]string{"CALLER", "OWNER"}, true), -// Description: "Sets execution context. Allowed values are CALLER and OWNER (consult a proper section in the [docs](https://docs.snowflake.com/en/sql-reference/sql/create-procedure#id1)). For more information see [caller's rights and owner's rights](https://docs.snowflake.com/en/developer-guide/stored-procedure/stored-procedures-rights).", -// }, -// "null_input_behavior": { -// Type: schema.TypeString, -// Optional: true, -// Default: "CALLED ON NULL INPUT", -// ForceNew: true, -// // We do not use STRICT, because Snowflake then in the Read phase returns RETURNS NULL ON NULL INPUT -// ValidateFunc: validation.StringInSlice([]string{"CALLED ON NULL INPUT", "RETURNS NULL ON NULL INPUT"}, false), -// Description: "Specifies the behavior of the procedure when called with null inputs.", -// }, -// "return_behavior": { -// Type: schema.TypeString, -// Optional: true, -// Default: "VOLATILE", -// ForceNew: true, -// ValidateFunc: validation.StringInSlice([]string{"VOLATILE", "IMMUTABLE"}, false), -// Description: "Specifies the behavior of the function when returning results", -// Deprecated: "These keywords are deprecated for stored procedures. These keywords are not intended to apply to stored procedures. In a future release, these keywords will be removed from the documentation.", -// }, -// "comment": { -// Type: schema.TypeString, -// Optional: true, -// Default: "user-defined procedure", -// Description: "Specifies a comment for the procedure.", -// }, -// "runtime_version": { -// Type: schema.TypeString, -// Optional: true, -// ForceNew: true, -// Description: "Required for Python procedures. Specifies Python runtime version.", -// }, -// "packages": { -// Type: schema.TypeList, -// Elem: &schema.Schema{ -// Type: schema.TypeString, -// }, -// Optional: true, -// ForceNew: true, -// Description: "List of package imports to use for Java / Python procedures. For Java, package imports should be of the form: package_name:version_number, where package_name is snowflake_domain:package. For Python use it should be: ('numpy','pandas','xgboost==1.5.0').", -// }, -// "imports": { -// Type: schema.TypeList, -// Elem: &schema.Schema{ -// Type: schema.TypeString, -// }, -// Optional: true, -// ForceNew: true, -// Description: "Imports for Java / Python procedures. For Java this a list of jar files, for Python this is a list of Python files.", -// }, -// "handler": { -// Type: schema.TypeString, -// Optional: true, -// ForceNew: true, -// Description: "The handler method for Java / Python procedures.", -// }, -// } -// -// // Procedure returns a pointer to the resource representing a stored procedure. +import ( + "context" + "fmt" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/provider" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" + "github.com/hashicorp/go-cty/cty" + "github.com/hashicorp/terraform-plugin-sdk/v2/diag" + "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" + "github.com/hashicorp/terraform-plugin-sdk/v2/helper/validation" + "log" + "regexp" + "slices" + "strings" +) + +var procedureSchema = map[string]*schema.Schema{ + "name": { + Type: schema.TypeString, + Required: true, + Description: "Specifies the identifier for the procedure; does not have to be unique for the schema in which the procedure is created. Don't use the | character.", + }, + "database": { + Type: schema.TypeString, + Required: true, + Description: "The database in which to create the procedure. Don't use the | character.", + ForceNew: true, + }, + "schema": { + Type: schema.TypeString, + Required: true, + Description: "The schema in which to create the procedure. Don't use the | character.", + ForceNew: true, + }, + "secure": { + Type: schema.TypeBool, + Optional: true, + Description: "Specifies that the procedure is secure. For more information about secure procedures, see Protecting Sensitive Information with Secure UDFs and Stored Procedures.", + Default: false, + }, + "arguments": { + Type: schema.TypeList, + Elem: &schema.Resource{ + Schema: map[string]*schema.Schema{ + "name": { + Type: schema.TypeString, + Required: true, + // Suppress the diff shown if the values are equal when both compared in lower case. + DiffSuppressFunc: func(k, old, new string, d *schema.ResourceData) bool { + return strings.EqualFold(old, new) + }, + Description: "The argument name", + }, + "type": { + Type: schema.TypeString, + Required: true, + ValidateFunc: dataTypeValidateFunc, + DiffSuppressFunc: dataTypeDiffSuppressFunc, + Description: "The argument type", + }, + }, + }, + Optional: true, + Description: "List of the arguments for the procedure", + ForceNew: true, + }, + "return_type": { + Type: schema.TypeString, + Description: "The return type of the procedure", + // Suppress the diff shown if the values are equal when both compared in lower case. + DiffSuppressFunc: func(k, old, new string, d *schema.ResourceData) bool { + if strings.EqualFold(old, new) { + return true + } + + varcharType := []string{"VARCHAR(16777216)", "VARCHAR", "text", "string", "NVARCHAR", "NVARCHAR2", "CHAR VARYING", "NCHAR VARYING"} + if slices.Contains(varcharType, strings.ToUpper(old)) && slices.Contains(varcharType, strings.ToUpper(new)) { + return true + } + + // all these types are equivalent https://docs.snowflake.com/en/sql-reference/data-types-numeric.html#int-integer-bigint-smallint-tinyint-byteint + integerTypes := []string{"INT", "INTEGER", "BIGINT", "SMALLINT", "TINYINT", "BYTEINT", "NUMBER(38,0)"} + if slices.Contains(integerTypes, strings.ToUpper(old)) && slices.Contains(integerTypes, strings.ToUpper(new)) { + return true + } + return false + }, + Required: true, + ForceNew: true, + }, + "statement": { + Type: schema.TypeString, + Required: true, + Description: "Specifies the code used to create the procedure.", + ForceNew: true, + DiffSuppressFunc: DiffSuppressStatement, + }, + "language": { + Type: schema.TypeString, + Optional: true, + Default: "SQL", + DiffSuppressFunc: func(k, old, new string, d *schema.ResourceData) bool { + return strings.EqualFold(old, new) + }, + ValidateFunc: validation.StringInSlice([]string{"javascript", "java", "scala", "SQL", "python"}, true), + Description: "Specifies the language of the stored procedure code.", + }, + "execute_as": { + Type: schema.TypeString, + Optional: true, + Default: "OWNER", + DiffSuppressFunc: func(k, old, new string, d *schema.ResourceData) bool { + return strings.EqualFold(old, new) + }, + ValidateFunc: validation.StringInSlice([]string{"CALLER", "OWNER"}, true), + Description: "Sets execution context. Allowed values are CALLER and OWNER (consult a proper section in the [docs](https://docs.snowflake.com/en/sql-reference/sql/create-procedure#id1)). For more information see [caller's rights and owner's rights](https://docs.snowflake.com/en/developer-guide/stored-procedure/stored-procedures-rights).", + }, + "null_input_behavior": { + Type: schema.TypeString, + Optional: true, + Default: "CALLED ON NULL INPUT", + ForceNew: true, + // We do not use STRICT, because Snowflake then in the Read phase returns RETURNS NULL ON NULL INPUT + ValidateFunc: validation.StringInSlice([]string{"CALLED ON NULL INPUT", "RETURNS NULL ON NULL INPUT"}, false), + Description: "Specifies the behavior of the procedure when called with null inputs.", + }, + "return_behavior": { + Type: schema.TypeString, + Optional: true, + Default: "VOLATILE", + ForceNew: true, + ValidateFunc: validation.StringInSlice([]string{"VOLATILE", "IMMUTABLE"}, false), + Description: "Specifies the behavior of the function when returning results", + Deprecated: "These keywords are deprecated for stored procedures. These keywords are not intended to apply to stored procedures. In a future release, these keywords will be removed from the documentation.", + }, + "comment": { + Type: schema.TypeString, + Optional: true, + Default: "user-defined procedure", + Description: "Specifies a comment for the procedure.", + }, + "runtime_version": { + Type: schema.TypeString, + Optional: true, + ForceNew: true, + Description: "Required for Python procedures. Specifies Python runtime version.", + }, + "packages": { + Type: schema.TypeList, + Elem: &schema.Schema{ + Type: schema.TypeString, + }, + Optional: true, + ForceNew: true, + Description: "List of package imports to use for Java / Python procedures. For Java, package imports should be of the form: package_name:version_number, where package_name is snowflake_domain:package. For Python use it should be: ('numpy','pandas','xgboost==1.5.0').", + }, + "imports": { + Type: schema.TypeList, + Elem: &schema.Schema{ + Type: schema.TypeString, + }, + Optional: true, + ForceNew: true, + Description: "Imports for Java / Python procedures. For Java this a list of jar files, for Python this is a list of Python files.", + }, + "handler": { + Type: schema.TypeString, + Optional: true, + ForceNew: true, + Description: "The handler method for Java / Python procedures.", + }, +} + +// Procedure returns a pointer to the resource representing a stored procedure. func Procedure() *schema.Resource { return &schema.Resource{ - //SchemaVersion: 1, - // - //CreateContext: CreateContextProcedure, - //ReadContext: ReadContextProcedure, - //UpdateContext: UpdateContextProcedure, - //DeleteContext: DeleteContextProcedure, - // - //Schema: procedureSchema, - //Importer: &schema.ResourceImporter{ - // StateContext: schema.ImportStatePassthroughContext, - //}, - - //StateUpgraders: []schema.StateUpgrader{ - // { - // Version: 0, - // // setting type to cty.EmptyObject is a bit hacky here but following https://developer.hashicorp.com/terraform/plugin/framework/migrating/resources/state-upgrade#sdkv2-1 would require lots of repetitive code; this should work with cty.EmptyObject - // Type: cty.EmptyObject, - // Upgrade: v085ProcedureStateUpgrader, - // }, - //}, + SchemaVersion: 1, + + CreateContext: CreateContextProcedure, + ReadContext: ReadContextProcedure, + UpdateContext: UpdateContextProcedure, + DeleteContext: DeleteContextProcedure, + + Schema: procedureSchema, + Importer: &schema.ResourceImporter{ + StateContext: schema.ImportStatePassthroughContext, + }, + + StateUpgraders: []schema.StateUpgrader{ + { + Version: 0, + // setting type to cty.EmptyObject is a bit hacky here but following https://developer.hashicorp.com/terraform/plugin/framework/migrating/resources/state-upgrade#sdkv2-1 would require lots of repetitive code; this should work with cty.EmptyObject + Type: cty.EmptyObject, + Upgrade: v085ProcedureStateUpgrader, + }, + }, + } +} + +func CreateContextProcedure(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { + lang := strings.ToUpper(d.Get("language").(string)) + switch lang { + case "JAVA": + return createJavaProcedure(ctx, d, meta) + case "JAVASCRIPT": + return createJavaScriptProcedure(ctx, d, meta) + case "PYTHON": + return createPythonProcedure(ctx, d, meta) + case "SCALA": + return createScalaProcedure(ctx, d, meta) + case "SQL": + return createSQLProcedure(ctx, d, meta) + default: + return diag.Diagnostics{ + diag.Diagnostic{ + Severity: diag.Error, + Summary: "Invalid language", + Detail: fmt.Sprintf("Language %s is not supported", lang), + }, + } + } +} + +func createJavaProcedure(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { + client := meta.(*provider.Context).Client + name := d.Get("name").(string) + schema := d.Get("schema").(string) + database := d.Get("database").(string) + args, diags := getProcedureArguments(d) + if diags != nil { + return diags + } + argDataTypes := make([]sdk.DataType, len(args)) + for i, arg := range args { + argDataTypes[i] = arg.ArgDataType + } + id := sdk.NewSchemaObjectIdentifierWithArguments(database, schema, name, argDataTypes...) + + returns, diags := parseProcedureReturnsRequest(d.Get("return_type").(string)) + if diags != nil { + return diags + } + procedureDefinition := d.Get("statement").(string) + runtimeVersion := d.Get("runtime_version").(string) + packages := []sdk.ProcedurePackageRequest{} + for _, item := range d.Get("packages").([]interface{}) { + packages = append(packages, *sdk.NewProcedurePackageRequest(item.(string))) + } + handler := d.Get("handler").(string) + req := sdk.NewCreateForJavaProcedureRequest(id.SchemaObjectId(), *returns, runtimeVersion, packages, handler) + req.WithProcedureDefinition(procedureDefinition) + if len(args) > 0 { + req.WithArguments(args) + } + + // read optional params + if v, ok := d.GetOk("execute_as"); ok { + if strings.ToUpper(v.(string)) == "OWNER" { + req.WithExecuteAs(sdk.ExecuteAsOwner) + } else if strings.ToUpper(v.(string)) == "CALLER" { + req.WithExecuteAs(sdk.ExecuteAsCaller) + } + } + if v, ok := d.GetOk("comment"); ok { + req.WithComment(v.(string)) + } + if v, ok := d.GetOk("secure"); ok { + req.WithSecure(v.(bool)) + } + if _, ok := d.GetOk("imports"); ok { + imports := []sdk.ProcedureImportRequest{} + for _, item := range d.Get("imports").([]interface{}) { + imports = append(imports, *sdk.NewProcedureImportRequest(item.(string))) + } + req.WithImports(imports) + } + + if err := client.Procedures.CreateForJava(ctx, req); err != nil { + return diag.FromErr(err) + } + d.SetId(id.FullyQualifiedName()) + return ReadContextProcedure(ctx, d, meta) +} + +func createJavaScriptProcedure(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { + client := meta.(*provider.Context).Client + name := d.Get("name").(string) + schema := d.Get("schema").(string) + database := d.Get("database").(string) + args, diags := getProcedureArguments(d) + if diags != nil { + return diags + } + argDataTypes := make([]sdk.DataType, len(args)) + for i, arg := range args { + argDataTypes[i] = arg.ArgDataType + } + id := sdk.NewSchemaObjectIdentifierWithArguments(database, schema, name, argDataTypes...) + + returnType := d.Get("return_type").(string) + returnDataType, diags := convertProcedureDataType(returnType) + if diags != nil { + return diags + } + procedureDefinition := d.Get("statement").(string) + req := sdk.NewCreateForJavaScriptProcedureRequest(id.SchemaObjectId(), returnDataType, procedureDefinition) + if len(args) > 0 { + req.WithArguments(args) + } + + // read optional params + if v, ok := d.GetOk("execute_as"); ok { + if strings.ToUpper(v.(string)) == "OWNER" { + req.WithExecuteAs(sdk.ExecuteAsOwner) + } else if strings.ToUpper(v.(string)) == "CALLER" { + req.WithExecuteAs(sdk.ExecuteAsCaller) + } } + if v, ok := d.GetOk("null_input_behavior"); ok { + req.WithNullInputBehavior(sdk.NullInputBehavior(v.(string))) + } + if v, ok := d.GetOk("comment"); ok { + req.WithComment(v.(string)) + } + if v, ok := d.GetOk("secure"); ok { + req.WithSecure(v.(bool)) + } + + if err := client.Procedures.CreateForJavaScript(ctx, req); err != nil { + return diag.FromErr(err) + } + d.SetId(id.FullyQualifiedName()) + return ReadContextProcedure(ctx, d, meta) } -// -//func CreateContextProcedure(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { -// lang := strings.ToUpper(d.Get("language").(string)) -// switch lang { -// case "JAVA": -// return createJavaProcedure(ctx, d, meta) -// case "JAVASCRIPT": -// return createJavaScriptProcedure(ctx, d, meta) -// case "PYTHON": -// return createPythonProcedure(ctx, d, meta) -// case "SCALA": -// return createScalaProcedure(ctx, d, meta) -// case "SQL": -// return createSQLProcedure(ctx, d, meta) -// default: -// return diag.Diagnostics{ -// diag.Diagnostic{ -// Severity: diag.Error, -// Summary: "Invalid language", -// Detail: fmt.Sprintf("Language %s is not supported", lang), -// }, -// } -// } -//} -// -//func createJavaProcedure(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { -// client := meta.(*provider.Context).Client -// name := d.Get("name").(string) -// schema := d.Get("schema").(string) -// database := d.Get("database").(string) -// id := sdk.NewSchemaObjectIdentifier(database, schema, name) -// -// returns, diags := parseProcedureReturnsRequest(d.Get("return_type").(string)) -// if diags != nil { -// return diags -// } -// procedureDefinition := d.Get("statement").(string) -// runtimeVersion := d.Get("runtime_version").(string) -// packages := []sdk.ProcedurePackageRequest{} -// for _, item := range d.Get("packages").([]interface{}) { -// packages = append(packages, *sdk.NewProcedurePackageRequest(item.(string))) -// } -// handler := d.Get("handler").(string) -// req := sdk.NewCreateForJavaProcedureRequest(id, *returns, runtimeVersion, packages, handler) -// req.WithProcedureDefinition(procedureDefinition) -// args, diags := getProcedureArguments(d) -// if diags != nil { -// return diags -// } -// if len(args) > 0 { -// req.WithArguments(args) -// } -// -// // read optional params -// if v, ok := d.GetOk("execute_as"); ok { -// if strings.ToUpper(v.(string)) == "OWNER" { -// req.WithExecuteAs(sdk.ExecuteAsOwner) -// } else if strings.ToUpper(v.(string)) == "CALLER" { -// req.WithExecuteAs(sdk.ExecuteAsCaller) -// } -// } -// if v, ok := d.GetOk("comment"); ok { -// req.WithComment(v.(string)) -// } -// if v, ok := d.GetOk("secure"); ok { -// req.WithSecure(v.(bool)) -// } -// if _, ok := d.GetOk("imports"); ok { -// imports := []sdk.ProcedureImportRequest{} -// for _, item := range d.Get("imports").([]interface{}) { -// imports = append(imports, *sdk.NewProcedureImportRequest(item.(string))) -// } -// req.WithImports(imports) -// } -// -// if err := client.Procedures.CreateForJava(ctx, req); err != nil { -// return diag.FromErr(err) -// } -// argTypes := make([]sdk.DataType, 0, len(args)) -// for _, item := range args { -// argTypes = append(argTypes, item.ArgDataType) -// } -// sid := sdk.NewSchemaObjectIdentifierWithArgumentsOld(database, schema, name, argTypes) -// d.SetId(sid.FullyQualifiedName()) -// return ReadContextProcedure(ctx, d, meta) -//} -// -//func createJavaScriptProcedure(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { -// client := meta.(*provider.Context).Client -// name := d.Get("name").(string) -// schema := d.Get("schema").(string) -// database := d.Get("database").(string) -// id := sdk.NewSchemaObjectIdentifier(database, schema, name) -// -// returnType := d.Get("return_type").(string) -// returnDataType, diags := convertProcedureDataType(returnType) -// if diags != nil { -// return diags -// } -// procedureDefinition := d.Get("statement").(string) -// req := sdk.NewCreateForJavaScriptProcedureRequest(id, returnDataType, procedureDefinition) -// args, diags := getProcedureArguments(d) -// if diags != nil { -// return diags -// } -// if len(args) > 0 { -// req.WithArguments(args) -// } -// -// // read optional params -// if v, ok := d.GetOk("execute_as"); ok { -// if strings.ToUpper(v.(string)) == "OWNER" { -// req.WithExecuteAs(sdk.ExecuteAsOwner) -// } else if strings.ToUpper(v.(string)) == "CALLER" { -// req.WithExecuteAs(sdk.ExecuteAsCaller) -// } -// } -// if v, ok := d.GetOk("null_input_behavior"); ok { -// req.WithNullInputBehavior(sdk.NullInputBehavior(v.(string))) -// } -// if v, ok := d.GetOk("comment"); ok { -// req.WithComment(v.(string)) -// } -// if v, ok := d.GetOk("secure"); ok { -// req.WithSecure(v.(bool)) -// } -// -// if err := client.Procedures.CreateForJavaScript(ctx, req); err != nil { -// return diag.FromErr(err) -// } -// argTypes := make([]sdk.DataType, 0, len(args)) -// for _, item := range args { -// argTypes = append(argTypes, item.ArgDataType) -// } -// sid := sdk.NewSchemaObjectIdentifierWithArgumentsOld(database, schema, name, argTypes) -// d.SetId(sid.FullyQualifiedName()) -// return ReadContextProcedure(ctx, d, meta) -//} -// -//func createScalaProcedure(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { -// client := meta.(*provider.Context).Client -// name := d.Get("name").(string) -// schema := d.Get("schema").(string) -// database := d.Get("database").(string) -// id := sdk.NewSchemaObjectIdentifier(database, schema, name) -// -// returns, diags := parseProcedureReturnsRequest(d.Get("return_type").(string)) -// if diags != nil { -// return diags -// } -// procedureDefinition := d.Get("statement").(string) -// runtimeVersion := d.Get("runtime_version").(string) -// packages := []sdk.ProcedurePackageRequest{} -// for _, item := range d.Get("packages").([]interface{}) { -// packages = append(packages, *sdk.NewProcedurePackageRequest(item.(string))) -// } -// handler := d.Get("handler").(string) -// req := sdk.NewCreateForScalaProcedureRequest(id, *returns, runtimeVersion, packages, handler) -// req.WithProcedureDefinition(procedureDefinition) -// args, diags := getProcedureArguments(d) -// if diags != nil { -// return diags -// } -// if len(args) > 0 { -// req.WithArguments(args) -// } -// -// // read optional params -// if v, ok := d.GetOk("execute_as"); ok { -// if strings.ToUpper(v.(string)) == "OWNER" { -// req.WithExecuteAs(sdk.ExecuteAsOwner) -// } else if strings.ToUpper(v.(string)) == "CALLER" { -// req.WithExecuteAs(sdk.ExecuteAsCaller) -// } -// } -// if v, ok := d.GetOk("comment"); ok { -// req.WithComment(v.(string)) -// } -// if v, ok := d.GetOk("secure"); ok { -// req.WithSecure(v.(bool)) -// } -// if _, ok := d.GetOk("imports"); ok { -// imports := []sdk.ProcedureImportRequest{} -// for _, item := range d.Get("imports").([]interface{}) { -// imports = append(imports, *sdk.NewProcedureImportRequest(item.(string))) -// } -// req.WithImports(imports) -// } -// -// if err := client.Procedures.CreateForScala(ctx, req); err != nil { -// return diag.FromErr(err) -// } -// argTypes := make([]sdk.DataType, 0, len(args)) -// for _, item := range args { -// argTypes = append(argTypes, item.ArgDataType) -// } -// sid := sdk.NewSchemaObjectIdentifierWithArgumentsOld(database, schema, name, argTypes) -// d.SetId(sid.FullyQualifiedName()) -// return ReadContextProcedure(ctx, d, meta) -//} -// -//func createSQLProcedure(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { -// client := meta.(*provider.Context).Client -// name := d.Get("name").(string) -// schema := d.Get("schema").(string) -// database := d.Get("database").(string) -// id := sdk.NewSchemaObjectIdentifier(database, schema, name) -// -// returns, diags := parseProcedureSQLReturnsRequest(d.Get("return_type").(string)) -// if diags != nil { -// return diags -// } -// procedureDefinition := d.Get("statement").(string) -// req := sdk.NewCreateForSQLProcedureRequest(id, *returns, procedureDefinition) -// args, diags := getProcedureArguments(d) -// if diags != nil { -// return diags -// } -// if len(args) > 0 { -// req.WithArguments(args) -// } -// -// // read optional params -// if v, ok := d.GetOk("execute_as"); ok { -// if strings.ToUpper(v.(string)) == "OWNER" { -// req.WithExecuteAs(sdk.ExecuteAsOwner) -// } else if strings.ToUpper(v.(string)) == "CALLER" { -// req.WithExecuteAs(sdk.ExecuteAsCaller) -// } -// } -// if v, ok := d.GetOk("null_input_behavior"); ok { -// req.WithNullInputBehavior(sdk.NullInputBehavior(v.(string))) -// } -// if v, ok := d.GetOk("comment"); ok { -// req.WithComment(v.(string)) -// } -// if v, ok := d.GetOk("secure"); ok { -// req.WithSecure(v.(bool)) -// } -// -// if err := client.Procedures.CreateForSQL(ctx, req); err != nil { -// return diag.FromErr(err) -// } -// argTypes := make([]sdk.DataType, 0, len(args)) -// for _, item := range args { -// argTypes = append(argTypes, item.ArgDataType) -// } -// sid := sdk.NewSchemaObjectIdentifierWithArgumentsOld(database, schema, name, argTypes) -// d.SetId(sid.FullyQualifiedName()) -// return ReadContextProcedure(ctx, d, meta) -//} -// -//func createPythonProcedure(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { -// client := meta.(*provider.Context).Client -// name := d.Get("name").(string) -// schema := d.Get("schema").(string) -// database := d.Get("database").(string) -// id := sdk.NewSchemaObjectIdentifier(database, schema, name) -// -// returns, diags := parseProcedureReturnsRequest(d.Get("return_type").(string)) -// if diags != nil { -// return diags -// } -// procedureDefinition := d.Get("statement").(string) -// runtimeVersion := d.Get("runtime_version").(string) -// packages := []sdk.ProcedurePackageRequest{} -// for _, item := range d.Get("packages").([]interface{}) { -// packages = append(packages, *sdk.NewProcedurePackageRequest(item.(string))) -// } -// handler := d.Get("handler").(string) -// req := sdk.NewCreateForPythonProcedureRequest(id, *returns, runtimeVersion, packages, handler) -// req.WithProcedureDefinition(procedureDefinition) -// args, diags := getProcedureArguments(d) -// if diags != nil { -// return diags -// } -// if len(args) > 0 { -// req.WithArguments(args) -// } -// -// // read optional params -// if v, ok := d.GetOk("execute_as"); ok { -// if strings.ToUpper(v.(string)) == "OWNER" { -// req.WithExecuteAs(sdk.ExecuteAsOwner) -// } else if strings.ToUpper(v.(string)) == "CALLER" { -// req.WithExecuteAs(sdk.ExecuteAsCaller) -// } -// } -// -// // [ { CALLED ON NULL INPUT | { RETURNS NULL ON NULL INPUT | STRICT } } ] does not work for java, scala or python -// // posted in docs-discuss channel, either docs need to be updated to reflect reality or this feature needs to be added -// // https://snowflake.slack.com/archives/C6380540P/p1707511734666249 -// // if v, ok := d.GetOk("null_input_behavior"); ok { -// // req.WithNullInputBehavior(sdk.Pointer(sdk.NullInputBehavior(v.(string)))) -// // } -// -// if v, ok := d.GetOk("comment"); ok { -// req.WithComment(v.(string)) -// } -// if v, ok := d.GetOk("secure"); ok { -// req.WithSecure(v.(bool)) -// } -// if _, ok := d.GetOk("imports"); ok { -// imports := []sdk.ProcedureImportRequest{} -// for _, item := range d.Get("imports").([]interface{}) { -// imports = append(imports, *sdk.NewProcedureImportRequest(item.(string))) -// } -// req.WithImports(imports) -// } -// -// if err := client.Procedures.CreateForPython(ctx, req); err != nil { -// return diag.FromErr(err) -// } -// argTypes := make([]sdk.DataType, 0, len(args)) -// for _, item := range args { -// argTypes = append(argTypes, item.ArgDataType) -// } -// sid := sdk.NewSchemaObjectIdentifierWithArgumentsOld(database, schema, name, argTypes) -// d.SetId(sid.FullyQualifiedName()) -// return ReadContextProcedure(ctx, d, meta) -//} -// -//func ReadContextProcedure(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { -// diags := diag.Diagnostics{} -// client := meta.(*provider.Context).Client -// -// id, err := sdk.NewSchemaObjectIdentifierWithArgumentsFromFullyQualifiedName(d.Id()) -// if err != nil { -// return diag.FromErr(err) -// } -// if err := d.Set("name", id.Name()); err != nil { -// return diag.FromErr(err) -// } -// if err := d.Set("database", id.DatabaseName()); err != nil { -// return diag.FromErr(err) -// } -// if err := d.Set("schema", id.SchemaName()); err != nil { -// return diag.FromErr(err) -// } -// args := d.Get("arguments").([]interface{}) -// argTypes := make([]string, len(args)) -// for i, arg := range args { -// argTypes[i] = arg.(map[string]interface{})["type"].(string) -// } -// procedureDetails, err := client.Procedures.Describe(ctx, id) -// if err != nil { -// // if procedure is not found then mark resource to be removed from state file during apply or refresh -// d.SetId("") -// return diag.Diagnostics{ -// diag.Diagnostic{ -// Severity: diag.Warning, -// Summary: "Describe procedure failed.", -// Detail: fmt.Sprintf("Describe procedure failed: %v", err), -// }, -// } -// } -// for _, desc := range procedureDetails { -// switch desc.Property { -// case "signature": -// // Format in Snowflake DB is: (argName argType, argName argType, ...) -// args := strings.ReplaceAll(strings.ReplaceAll(desc.Value, "(", ""), ")", "") -// -// if args != "" { // Do nothing for functions without arguments -// argPairs := strings.Split(args, ", ") -// args := []interface{}{} -// -// for _, argPair := range argPairs { -// argItem := strings.Split(argPair, " ") -// -// arg := map[string]interface{}{} -// arg["name"] = argItem[0] -// arg["type"] = argItem[1] -// args = append(args, arg) -// } -// -// if err := d.Set("arguments", args); err != nil { -// return diag.FromErr(err) -// } -// } -// case "null handling": -// if err := d.Set("null_input_behavior", desc.Value); err != nil { -// return diag.FromErr(err) -// } -// case "body": -// if err := d.Set("statement", desc.Value); err != nil { -// return diag.FromErr(err) -// } -// case "execute as": -// if err := d.Set("execute_as", desc.Value); err != nil { -// return diag.FromErr(err) -// } -// case "returns": -// if err := d.Set("return_type", desc.Value); err != nil { -// return diag.FromErr(err) -// } -// case "language": -// if err := d.Set("language", desc.Value); err != nil { -// return diag.FromErr(err) -// } -// case "runtime_version": -// if err := d.Set("runtime_version", desc.Value); err != nil { -// return diag.FromErr(err) -// } -// case "packages": -// packagesString := strings.ReplaceAll(strings.ReplaceAll(strings.ReplaceAll(desc.Value, "[", ""), "]", ""), "'", "") -// if packagesString != "" { // Do nothing for Java / Python functions without packages -// packages := strings.Split(packagesString, ",") -// if err := d.Set("packages", packages); err != nil { -// return diag.FromErr(err) -// } -// } -// case "imports": -// importsString := strings.ReplaceAll(strings.ReplaceAll(strings.ReplaceAll(strings.ReplaceAll(desc.Value, "[", ""), "]", ""), "'", ""), " ", "") -// if importsString != "" { // Do nothing for Java functions without imports -// imports := strings.Split(importsString, ",") -// if err := d.Set("imports", imports); err != nil { -// return diag.FromErr(err) -// } -// } -// case "handler": -// if err := d.Set("handler", desc.Value); err != nil { -// return diag.FromErr(err) -// } -// case "volatility": -// if err := d.Set("return_behavior", desc.Value); err != nil { -// return diag.FromErr(err) -// } -// default: -// log.Printf("[INFO] Unexpected procedure property %v returned from Snowflake with value %v", desc.Property, desc.Value) -// } -// } -// -// request := sdk.NewShowProcedureRequest().WithIn(sdk.In{Schema: sdk.NewDatabaseObjectIdentifier(id.DatabaseName(), id.SchemaName())}).WithLike(sdk.Like{Pattern: sdk.String(id.Name())}) -// -// procedures, err := client.Procedures.Show(ctx, request) -// if err != nil { -// return diag.FromErr(err) -// } -// // procedure names can be overloaded with different argument types so we iterate over and find the correct one -// // the ShowByID function should probably be updated to also require the list of arg types, like describe procedure -// for _, procedure := range procedures { -// argumentSignature := strings.Split(procedure.Arguments, " RETURN ")[0] -// argumentSignature = strings.ReplaceAll(argumentSignature, " ", "") -// if argumentSignature == id.ArgumentsSignature() { -// if err := d.Set("secure", procedure.IsSecure); err != nil { -// return diag.FromErr(err) -// } -// if err := d.Set("comment", procedure.Description); err != nil { -// return diag.FromErr(err) -// } -// } -// } -// -// return diags -//} -// -//func UpdateContextProcedure(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { -// client := meta.(*provider.Context).Client -// -// id := sdk.NewSchemaObjectIdentifierFromFullyQualifiedName(d.Id()) -// if d.HasChange("name") { -// newId := sdk.NewSchemaObjectIdentifierWithArgumentsOld(id.DatabaseName(), id.SchemaName(), d.Get("name").(string), id.Arguments()) -// -// err := client.Procedures.Alter(ctx, sdk.NewAlterProcedureRequest(id.WithoutArguments(), id.Arguments()).WithRenameTo(newId.WithoutArguments())) -// if err != nil { -// return diag.FromErr(err) -// } -// -// d.SetId(newId.FullyQualifiedName()) -// id = newId -// } -// -// if d.HasChange("comment") { -// comment := d.Get("comment") -// if comment != "" { -// if err := client.Procedures.Alter(ctx, sdk.NewAlterProcedureRequest(id.WithoutArguments(), id.Arguments()).WithSetComment(comment.(string))); err != nil { -// return diag.FromErr(err) -// } -// } else { -// if err := client.Procedures.Alter(ctx, sdk.NewAlterProcedureRequest(id.WithoutArguments(), id.Arguments()).WithUnsetComment(true)); err != nil { -// return diag.FromErr(err) -// } -// } -// } -// -// if d.HasChange("execute_as") { -// req := sdk.NewAlterProcedureRequest(id.WithoutArguments(), id.Arguments()) -// executeAs := d.Get("execute_as").(string) -// if strings.ToUpper(executeAs) == "OWNER" { -// req.WithExecuteAs(sdk.ExecuteAsOwner) -// } else if strings.ToUpper(executeAs) == "CALLER" { -// req.WithExecuteAs(sdk.ExecuteAsCaller) -// } -// if err := client.Procedures.Alter(ctx, req); err != nil { -// return diag.FromErr(err) -// } -// } -// -// return ReadContextProcedure(ctx, d, meta) -//} -// -//func DeleteContextProcedure(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { -// client := meta.(*provider.Context).Client -// -// id := sdk.NewSchemaObjectIdentifierFromFullyQualifiedName(d.Id()) -// if err := client.Procedures.Drop(ctx, sdk.NewDropProcedureRequest(id.WithoutArguments(), id.Arguments())); err != nil { -// return diag.FromErr(err) -// } -// d.SetId("") -// return nil -//} -// -//func getProcedureArguments(d *schema.ResourceData) ([]sdk.ProcedureArgumentRequest, diag.Diagnostics) { -// args := make([]sdk.ProcedureArgumentRequest, 0) -// if v, ok := d.GetOk("arguments"); ok { -// for _, arg := range v.([]interface{}) { -// argName := arg.(map[string]interface{})["name"].(string) -// argType := arg.(map[string]interface{})["type"].(string) -// argDataType, diags := convertProcedureDataType(argType) -// if diags != nil { -// return nil, diags -// } -// args = append(args, sdk.ProcedureArgumentRequest{ArgName: argName, ArgDataType: argDataType}) -// } -// } -// return args, nil -//} -// -//func convertProcedureDataType(s string) (sdk.DataType, diag.Diagnostics) { -// dataType, err := sdk.ToDataType(s) -// if err != nil { -// return dataType, diag.FromErr(err) -// } -// return dataType, nil -//} -// -//func convertProcedureColumns(s string) ([]sdk.ProcedureColumn, diag.Diagnostics) { -// pattern := regexp.MustCompile(`(\w+)\s+(\w+)`) -// matches := pattern.FindAllStringSubmatch(s, -1) -// var columns []sdk.ProcedureColumn -// for _, match := range matches { -// if len(match) == 3 { -// dataType, err := sdk.ToDataType(match[2]) -// if err != nil { -// return nil, diag.FromErr(err) -// } -// columns = append(columns, sdk.ProcedureColumn{ -// ColumnName: match[1], -// ColumnDataType: dataType, -// }) -// } -// } -// return columns, nil -//} -// -//func parseProcedureReturnsRequest(s string) (*sdk.ProcedureReturnsRequest, diag.Diagnostics) { -// returns := sdk.NewProcedureReturnsRequest() -// if strings.HasPrefix(strings.ToLower(s), "table") { -// columns, diags := convertProcedureColumns(s) -// if diags != nil { -// return nil, diags -// } -// var cr []sdk.ProcedureColumnRequest -// for _, item := range columns { -// cr = append(cr, *sdk.NewProcedureColumnRequest(item.ColumnName, item.ColumnDataType)) -// } -// returns.WithTable(sdk.NewProcedureReturnsTableRequest().WithColumns(cr)) -// } else { -// returnDataType, diags := convertProcedureDataType(s) -// if diags != nil { -// return nil, diags -// } -// returns.WithResultDataType(*sdk.NewProcedureReturnsResultDataTypeRequest(returnDataType)) -// } -// return returns, nil -//} -// -//func parseProcedureSQLReturnsRequest(s string) (*sdk.ProcedureSQLReturnsRequest, diag.Diagnostics) { -// returns := sdk.NewProcedureSQLReturnsRequest() -// if strings.HasPrefix(strings.ToLower(s), "table") { -// columns, diags := convertProcedureColumns(s) -// if diags != nil { -// return nil, diags -// } -// var cr []sdk.ProcedureColumnRequest -// for _, item := range columns { -// cr = append(cr, *sdk.NewProcedureColumnRequest(item.ColumnName, item.ColumnDataType)) -// } -// returns.WithTable(*sdk.NewProcedureReturnsTableRequest().WithColumns(cr)) -// } else { -// returnDataType, diags := convertProcedureDataType(s) -// if diags != nil { -// return nil, diags -// } -// returns.WithResultDataType(*sdk.NewProcedureReturnsResultDataTypeRequest(returnDataType)) -// } -// return returns, nil -//} +func createScalaProcedure(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { + client := meta.(*provider.Context).Client + name := d.Get("name").(string) + schema := d.Get("schema").(string) + database := d.Get("database").(string) + args, diags := getProcedureArguments(d) + if diags != nil { + return diags + } + argDataTypes := make([]sdk.DataType, len(args)) + for i, arg := range args { + argDataTypes[i] = arg.ArgDataType + } + id := sdk.NewSchemaObjectIdentifierWithArguments(database, schema, name, argDataTypes...) + + returns, diags := parseProcedureReturnsRequest(d.Get("return_type").(string)) + if diags != nil { + return diags + } + procedureDefinition := d.Get("statement").(string) + runtimeVersion := d.Get("runtime_version").(string) + packages := []sdk.ProcedurePackageRequest{} + for _, item := range d.Get("packages").([]interface{}) { + packages = append(packages, *sdk.NewProcedurePackageRequest(item.(string))) + } + handler := d.Get("handler").(string) + req := sdk.NewCreateForScalaProcedureRequest(id.SchemaObjectId(), *returns, runtimeVersion, packages, handler) + req.WithProcedureDefinition(procedureDefinition) + if len(args) > 0 { + req.WithArguments(args) + } + + // read optional params + if v, ok := d.GetOk("execute_as"); ok { + if strings.ToUpper(v.(string)) == "OWNER" { + req.WithExecuteAs(sdk.ExecuteAsOwner) + } else if strings.ToUpper(v.(string)) == "CALLER" { + req.WithExecuteAs(sdk.ExecuteAsCaller) + } + } + if v, ok := d.GetOk("comment"); ok { + req.WithComment(v.(string)) + } + if v, ok := d.GetOk("secure"); ok { + req.WithSecure(v.(bool)) + } + if _, ok := d.GetOk("imports"); ok { + imports := []sdk.ProcedureImportRequest{} + for _, item := range d.Get("imports").([]interface{}) { + imports = append(imports, *sdk.NewProcedureImportRequest(item.(string))) + } + req.WithImports(imports) + } + + if err := client.Procedures.CreateForScala(ctx, req); err != nil { + return diag.FromErr(err) + } + d.SetId(id.FullyQualifiedName()) + return ReadContextProcedure(ctx, d, meta) +} + +func createSQLProcedure(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { + client := meta.(*provider.Context).Client + name := d.Get("name").(string) + schema := d.Get("schema").(string) + database := d.Get("database").(string) + args, diags := getProcedureArguments(d) + if diags != nil { + return diags + } + argDataTypes := make([]sdk.DataType, len(args)) + for i, arg := range args { + argDataTypes[i] = arg.ArgDataType + } + id := sdk.NewSchemaObjectIdentifierWithArguments(database, schema, name, argDataTypes...) + + returns, diags := parseProcedureSQLReturnsRequest(d.Get("return_type").(string)) + if diags != nil { + return diags + } + procedureDefinition := d.Get("statement").(string) + req := sdk.NewCreateForSQLProcedureRequest(id.SchemaObjectId(), *returns, procedureDefinition) + if len(args) > 0 { + req.WithArguments(args) + } + + // read optional params + if v, ok := d.GetOk("execute_as"); ok { + if strings.ToUpper(v.(string)) == "OWNER" { + req.WithExecuteAs(sdk.ExecuteAsOwner) + } else if strings.ToUpper(v.(string)) == "CALLER" { + req.WithExecuteAs(sdk.ExecuteAsCaller) + } + } + if v, ok := d.GetOk("null_input_behavior"); ok { + req.WithNullInputBehavior(sdk.NullInputBehavior(v.(string))) + } + if v, ok := d.GetOk("comment"); ok { + req.WithComment(v.(string)) + } + if v, ok := d.GetOk("secure"); ok { + req.WithSecure(v.(bool)) + } + + if err := client.Procedures.CreateForSQL(ctx, req); err != nil { + return diag.FromErr(err) + } + d.SetId(id.FullyQualifiedName()) + return ReadContextProcedure(ctx, d, meta) +} + +func createPythonProcedure(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { + client := meta.(*provider.Context).Client + name := d.Get("name").(string) + schema := d.Get("schema").(string) + database := d.Get("database").(string) + args, diags := getProcedureArguments(d) + if diags != nil { + return diags + } + argDataTypes := make([]sdk.DataType, len(args)) + for i, arg := range args { + argDataTypes[i] = arg.ArgDataType + } + id := sdk.NewSchemaObjectIdentifierWithArguments(database, schema, name, argDataTypes...) + + returns, diags := parseProcedureReturnsRequest(d.Get("return_type").(string)) + if diags != nil { + return diags + } + procedureDefinition := d.Get("statement").(string) + runtimeVersion := d.Get("runtime_version").(string) + packages := []sdk.ProcedurePackageRequest{} + for _, item := range d.Get("packages").([]interface{}) { + packages = append(packages, *sdk.NewProcedurePackageRequest(item.(string))) + } + handler := d.Get("handler").(string) + req := sdk.NewCreateForPythonProcedureRequest(id.SchemaObjectId(), *returns, runtimeVersion, packages, handler) + req.WithProcedureDefinition(procedureDefinition) + if len(args) > 0 { + req.WithArguments(args) + } + + // read optional params + if v, ok := d.GetOk("execute_as"); ok { + if strings.ToUpper(v.(string)) == "OWNER" { + req.WithExecuteAs(sdk.ExecuteAsOwner) + } else if strings.ToUpper(v.(string)) == "CALLER" { + req.WithExecuteAs(sdk.ExecuteAsCaller) + } + } + + // [ { CALLED ON NULL INPUT | { RETURNS NULL ON NULL INPUT | STRICT } } ] does not work for java, scala or python + // posted in docs-discuss channel, either docs need to be updated to reflect reality or this feature needs to be added + // https://snowflake.slack.com/archives/C6380540P/p1707511734666249 + // if v, ok := d.GetOk("null_input_behavior"); ok { + // req.WithNullInputBehavior(sdk.Pointer(sdk.NullInputBehavior(v.(string)))) + // } + + if v, ok := d.GetOk("comment"); ok { + req.WithComment(v.(string)) + } + if v, ok := d.GetOk("secure"); ok { + req.WithSecure(v.(bool)) + } + if _, ok := d.GetOk("imports"); ok { + imports := []sdk.ProcedureImportRequest{} + for _, item := range d.Get("imports").([]interface{}) { + imports = append(imports, *sdk.NewProcedureImportRequest(item.(string))) + } + req.WithImports(imports) + } + + if err := client.Procedures.CreateForPython(ctx, req); err != nil { + return diag.FromErr(err) + } + d.SetId(id.FullyQualifiedName()) + return ReadContextProcedure(ctx, d, meta) +} + +func ReadContextProcedure(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { + diags := diag.Diagnostics{} + client := meta.(*provider.Context).Client + + id, err := sdk.ParseSchemaObjectIdentifierWithArguments(d.Id()) + if err != nil { + return diag.FromErr(err) + } + if err := d.Set("name", id.Name()); err != nil { + return diag.FromErr(err) + } + if err := d.Set("database", id.DatabaseName()); err != nil { + return diag.FromErr(err) + } + if err := d.Set("schema", id.SchemaName()); err != nil { + return diag.FromErr(err) + } + args := d.Get("arguments").([]interface{}) + argTypes := make([]string, len(args)) + for i, arg := range args { + argTypes[i] = arg.(map[string]interface{})["type"].(string) + } + procedureDetails, err := client.Procedures.Describe(ctx, id) + if err != nil { + // if procedure is not found then mark resource to be removed from state file during apply or refresh + d.SetId("") + return diag.Diagnostics{ + diag.Diagnostic{ + Severity: diag.Warning, + Summary: "Describe procedure failed.", + Detail: fmt.Sprintf("Describe procedure failed: %v", err), + }, + } + } + for _, desc := range procedureDetails { + switch desc.Property { + case "signature": + // Format in Snowflake DB is: (argName argType, argName argType, ...) + args := strings.ReplaceAll(strings.ReplaceAll(desc.Value, "(", ""), ")", "") + + if args != "" { // Do nothing for functions without arguments + argPairs := strings.Split(args, ", ") + args := []interface{}{} + + for _, argPair := range argPairs { + argItem := strings.Split(argPair, " ") + + arg := map[string]interface{}{} + arg["name"] = argItem[0] + arg["type"] = argItem[1] + args = append(args, arg) + } + + if err := d.Set("arguments", args); err != nil { + return diag.FromErr(err) + } + } + case "null handling": + if err := d.Set("null_input_behavior", desc.Value); err != nil { + return diag.FromErr(err) + } + case "body": + if err := d.Set("statement", desc.Value); err != nil { + return diag.FromErr(err) + } + case "execute as": + if err := d.Set("execute_as", desc.Value); err != nil { + return diag.FromErr(err) + } + case "returns": + if err := d.Set("return_type", desc.Value); err != nil { + return diag.FromErr(err) + } + case "language": + if err := d.Set("language", desc.Value); err != nil { + return diag.FromErr(err) + } + case "runtime_version": + if err := d.Set("runtime_version", desc.Value); err != nil { + return diag.FromErr(err) + } + case "packages": + packagesString := strings.ReplaceAll(strings.ReplaceAll(strings.ReplaceAll(desc.Value, "[", ""), "]", ""), "'", "") + if packagesString != "" { // Do nothing for Java / Python functions without packages + packages := strings.Split(packagesString, ",") + if err := d.Set("packages", packages); err != nil { + return diag.FromErr(err) + } + } + case "imports": + importsString := strings.ReplaceAll(strings.ReplaceAll(strings.ReplaceAll(strings.ReplaceAll(desc.Value, "[", ""), "]", ""), "'", ""), " ", "") + if importsString != "" { // Do nothing for Java functions without imports + imports := strings.Split(importsString, ",") + if err := d.Set("imports", imports); err != nil { + return diag.FromErr(err) + } + } + case "handler": + if err := d.Set("handler", desc.Value); err != nil { + return diag.FromErr(err) + } + case "volatility": + if err := d.Set("return_behavior", desc.Value); err != nil { + return diag.FromErr(err) + } + default: + log.Printf("[INFO] Unexpected procedure property %v returned from Snowflake with value %v", desc.Property, desc.Value) + } + } + + procedure, err := client.Procedures.ShowByID(ctx, id) + if err != nil { + return diag.FromErr(err) + } + + if err := d.Set("secure", procedure.IsSecure); err != nil { + return diag.FromErr(err) + } + + if err := d.Set("comment", procedure.Description); err != nil { + return diag.FromErr(err) + } + + return diags +} + +func UpdateContextProcedure(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { + client := meta.(*provider.Context).Client + + id, err := sdk.ParseSchemaObjectIdentifierWithArguments(d.Id()) + if err != nil { + return diag.FromErr(err) + } + + if d.HasChange("name") { + newId := sdk.NewSchemaObjectIdentifier(id.DatabaseName(), id.SchemaName(), d.Get("name").(string)) + newIdWithArguments := sdk.NewSchemaObjectIdentifierWithArguments(id.DatabaseName(), id.SchemaName(), d.Get("name").(string), id.ArgumentDataTypes()...) + + err := client.Procedures.Alter(ctx, sdk.NewAlterProcedureRequest(id).WithRenameTo(newId.WithoutArguments())) + if err != nil { + return diag.FromErr(err) + } + + d.SetId(newIdWithArguments.FullyQualifiedName()) + id = newIdWithArguments + } + + if d.HasChange("comment") { + comment := d.Get("comment") + if comment != "" { + if err := client.Procedures.Alter(ctx, sdk.NewAlterProcedureRequest(id).WithSetComment(comment.(string))); err != nil { + return diag.FromErr(err) + } + } else { + if err := client.Procedures.Alter(ctx, sdk.NewAlterProcedureRequest(id).WithUnsetComment(true)); err != nil { + return diag.FromErr(err) + } + } + } + + if d.HasChange("execute_as") { + req := sdk.NewAlterProcedureRequest(id) + executeAs := d.Get("execute_as").(string) + if strings.ToUpper(executeAs) == "OWNER" { + req.WithExecuteAs(sdk.ExecuteAsOwner) + } else if strings.ToUpper(executeAs) == "CALLER" { + req.WithExecuteAs(sdk.ExecuteAsCaller) + } + if err := client.Procedures.Alter(ctx, req); err != nil { + return diag.FromErr(err) + } + } + + return ReadContextProcedure(ctx, d, meta) +} + +func DeleteContextProcedure(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { + client := meta.(*provider.Context).Client + + id, err := sdk.ParseSchemaObjectIdentifierWithArguments(d.Id()) + if err != nil { + return diag.FromErr(err) + } + if err := client.Procedures.Drop(ctx, sdk.NewDropProcedureRequest(id)); err != nil { + return diag.FromErr(err) + } + d.SetId("") + return nil +} + +func getProcedureArguments(d *schema.ResourceData) ([]sdk.ProcedureArgumentRequest, diag.Diagnostics) { + args := make([]sdk.ProcedureArgumentRequest, 0) + if v, ok := d.GetOk("arguments"); ok { + for _, arg := range v.([]interface{}) { + argName := arg.(map[string]interface{})["name"].(string) + argType := arg.(map[string]interface{})["type"].(string) + argDataType, diags := convertProcedureDataType(argType) + if diags != nil { + return nil, diags + } + args = append(args, sdk.ProcedureArgumentRequest{ArgName: argName, ArgDataType: argDataType}) + } + } + return args, nil +} + +func convertProcedureDataType(s string) (sdk.DataType, diag.Diagnostics) { + dataType, err := sdk.ToDataType(s) + if err != nil { + return dataType, diag.FromErr(err) + } + return dataType, nil +} + +func convertProcedureColumns(s string) ([]sdk.ProcedureColumn, diag.Diagnostics) { + pattern := regexp.MustCompile(`(\w+)\s+(\w+)`) + matches := pattern.FindAllStringSubmatch(s, -1) + var columns []sdk.ProcedureColumn + for _, match := range matches { + if len(match) == 3 { + dataType, err := sdk.ToDataType(match[2]) + if err != nil { + return nil, diag.FromErr(err) + } + columns = append(columns, sdk.ProcedureColumn{ + ColumnName: match[1], + ColumnDataType: dataType, + }) + } + } + return columns, nil +} + +func parseProcedureReturnsRequest(s string) (*sdk.ProcedureReturnsRequest, diag.Diagnostics) { + returns := sdk.NewProcedureReturnsRequest() + if strings.HasPrefix(strings.ToLower(s), "table") { + columns, diags := convertProcedureColumns(s) + if diags != nil { + return nil, diags + } + var cr []sdk.ProcedureColumnRequest + for _, item := range columns { + cr = append(cr, *sdk.NewProcedureColumnRequest(item.ColumnName, item.ColumnDataType)) + } + returns.WithTable(*sdk.NewProcedureReturnsTableRequest().WithColumns(cr)) + } else { + returnDataType, diags := convertProcedureDataType(s) + if diags != nil { + return nil, diags + } + returns.WithResultDataType(*sdk.NewProcedureReturnsResultDataTypeRequest(returnDataType)) + } + return returns, nil +} + +func parseProcedureSQLReturnsRequest(s string) (*sdk.ProcedureSQLReturnsRequest, diag.Diagnostics) { + returns := sdk.NewProcedureSQLReturnsRequest() + if strings.HasPrefix(strings.ToLower(s), "table") { + columns, diags := convertProcedureColumns(s) + if diags != nil { + return nil, diags + } + var cr []sdk.ProcedureColumnRequest + for _, item := range columns { + cr = append(cr, *sdk.NewProcedureColumnRequest(item.ColumnName, item.ColumnDataType)) + } + returns.WithTable(*sdk.NewProcedureReturnsTableRequest().WithColumns(cr)) + } else { + returnDataType, diags := convertProcedureDataType(s) + if diags != nil { + return nil, diags + } + returns.WithResultDataType(*sdk.NewProcedureReturnsResultDataTypeRequest(returnDataType)) + } + return returns, nil +} diff --git a/pkg/resources/procedure_acceptance_test.go b/pkg/resources/procedure_acceptance_test.go index b601583a9a..d10316ebdd 100644 --- a/pkg/resources/procedure_acceptance_test.go +++ b/pkg/resources/procedure_acceptance_test.go @@ -225,8 +225,13 @@ func TestAcc_Procedure_migrateFromVersion085(t *testing.T) { ), }, { - ProtoV6ProviderFactories: acc.TestAccProtoV6ProviderFactories, - Config: procedureConfig(acc.TestDatabaseName, acc.TestSchemaName, name), + ExternalProviders: map[string]resource.ExternalProvider{ + "snowflake": { + VersionConstraint: "=0.94.1", + Source: "Snowflake-Labs/snowflake", + }, + }, + Config: procedureConfig(acc.TestDatabaseName, acc.TestSchemaName, name), ConfigPlanChecks: resource.ConfigPlanChecks{ PreApply: []plancheck.PlanCheck{plancheck.ExpectEmptyPlan()}, }, @@ -241,6 +246,8 @@ func TestAcc_Procedure_migrateFromVersion085(t *testing.T) { }) } +// TODO: test new state upgrader + func procedureConfig(database string, schema string, name string) string { return fmt.Sprintf(` resource "snowflake_procedure" "p" { diff --git a/pkg/resources/procedure_state_upgraders.go b/pkg/resources/procedure_state_upgraders.go index 24e47d7d9f..27866b7387 100644 --- a/pkg/resources/procedure_state_upgraders.go +++ b/pkg/resources/procedure_state_upgraders.go @@ -60,3 +60,5 @@ func v085ProcedureStateUpgrader(ctx context.Context, rawState map[string]interfa return rawState, nil } + +// TODO: state upgrader for empty args (without '()') diff --git a/pkg/sdk/functions_impl_gen.go b/pkg/sdk/functions_impl_gen.go index 7cbe401daf..e2122633d5 100644 --- a/pkg/sdk/functions_impl_gen.go +++ b/pkg/sdk/functions_impl_gen.go @@ -6,8 +6,6 @@ import ( "strings" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/internal/collections" - "log" - "strings" ) var _ Functions = (*functions)(nil) @@ -62,8 +60,7 @@ func (v *functions) Show(ctx context.Context, request *ShowFunctionRequest) ([]F } func (v *functions) ShowByID(ctx context.Context, id SchemaObjectIdentifierWithArguments) (*Function, error) { - request := NewShowFunctionRequest().WithIn(In{Schema: id.SchemaId()}).WithLike(Like{String(id.Name())}) - functions, err := v.Show(ctx, request) + functions, err := v.Show(ctx, NewShowFunctionRequest().WithIn(In{Schema: id.SchemaId()}).WithLike(Like{String(id.Name())})) if err != nil { return nil, err } diff --git a/pkg/sdk/identifier_helpers.go b/pkg/sdk/identifier_helpers.go index c8d7a127d1..8bc4a33852 100644 --- a/pkg/sdk/identifier_helpers.go +++ b/pkg/sdk/identifier_helpers.go @@ -232,7 +232,7 @@ type SchemaObjectIdentifier struct { databaseName string schemaName string name string - // TODO(next prs ???): left right now for backward compatibility for procedures and externalFunctions + // TODO(next prs): left right now for backward compatibility for procedures and externalFunctions arguments []DataType } @@ -343,21 +343,15 @@ type SchemaObjectIdentifierWithArguments struct { func NewSchemaObjectIdentifierWithArguments(databaseName, schemaName, name string, argumentDataTypes ...DataType) SchemaObjectIdentifierWithArguments { return SchemaObjectIdentifierWithArguments{ - databaseName: strings.Trim(databaseName, `"`), - schemaName: strings.Trim(schemaName, `"`), - name: strings.Trim(name, `"`), + databaseName: strings.Trim(databaseName, `"`), + schemaName: strings.Trim(schemaName, `"`), + name: strings.Trim(name, `"`), + argumentDataTypes: argumentDataTypes, } } -func NewSchemaObjectIdentifierWithArgumentsFromFullyQualifiedName(fullyQualifiedName string) SchemaObjectIdentifierWithArguments { - parts := strings.Split(fullyQualifiedName, ".") - id := SchemaObjectIdentifierWithArguments{ - databaseName: strings.Trim(parts[0], `"`), - schemaName: strings.Trim(parts[1], `"`), - name: strings.Trim(parts[2], `"`), - // TODO: Arguments - } - return id +func NewSchemaObjectIdentifierWithArgumentsInSchema(schemaId DatabaseObjectIdentifier, name string, argumentDataTypes ...DataType) SchemaObjectIdentifierWithArguments { + return NewSchemaObjectIdentifierWithArguments(schemaId.DatabaseName(), schemaId.Name(), name, argumentDataTypes...) } func (i SchemaObjectIdentifierWithArguments) DatabaseName() string { @@ -392,7 +386,7 @@ func (i SchemaObjectIdentifierWithArguments) FullyQualifiedName() string { if i.schemaName == "" && i.databaseName == "" && i.name == "" && len(i.argumentDataTypes) == 0 { return "" } - return fmt.Sprintf(`"%v"."%v"."%v"(%v)`, i.databaseName, i.schemaName, i.name, strings.Join(i.arguments, ",")) + return fmt.Sprintf(`"%v"."%v"."%v"(%v)`, i.databaseName, i.schemaName, i.name, strings.Join(AsStringList(i.argumentDataTypes), ",")) } type TableColumnIdentifier struct { diff --git a/pkg/sdk/identifier_helpers_test.go b/pkg/sdk/identifier_helpers_test.go index 2f409cc84d..d3b0a98bd3 100644 --- a/pkg/sdk/identifier_helpers_test.go +++ b/pkg/sdk/identifier_helpers_test.go @@ -1,7 +1,6 @@ package sdk import ( - "fmt" "testing" "github.com/stretchr/testify/assert" @@ -83,42 +82,3 @@ func TestDatabaseObjectIdentifier(t *testing.T) { assert.Equal(t, `"aaa"."bbb"`, identifier.FullyQualifiedName()) }) } - -func TestNewSchemaObjectIdentifierWithArgumentsFromFullyQualifiedName(t *testing.T) { - testCases := []struct { - RawInput string - Input SchemaObjectIdentifierWithArguments - Error string - }{ - {Input: NewSchemaObjectIdentifierWithArguments(`abc`, `def`, `ghi`, DataTypeFloat, DataTypeNumber, DataTypeTimestampTZ)}, - {Input: NewSchemaObjectIdentifierWithArguments(`abc`, `def`, `ghi`, DataTypeFloat, "VECTOR(INT, 20)")}, - {Input: NewSchemaObjectIdentifierWithArguments(`abc`, `def`, `ghi`, "VECTOR(INT, 20)", DataTypeFloat)}, - {Input: NewSchemaObjectIdentifierWithArguments(`abc`, `def`, `ghi`, DataTypeFloat, "VECTOR(INT, 20)", "VECTOR(INT, 10)")}, - {Input: NewSchemaObjectIdentifierWithArguments(`abc`, `def`, `ghi`, DataTypeTime, "VECTOR(INT, 20)", "VECTOR(FLOAT, 10)", DataTypeFloat)}, - // TODO(): Won't work, because of the assumption that identifiers are not containing '(' and ')' parentheses - {Input: NewSchemaObjectIdentifierWithArguments(`ab()c`, `def()`, `()ghi`, DataTypeTime, "VECTOR(INT, 20)", "VECTOR(FLOAT, 10)", DataTypeFloat), Error: `unable to read identifier: "ab`}, - {Input: NewSchemaObjectIdentifierWithArguments(`ab(,)c`, `,def()`, `()ghi,`, DataTypeTime, "VECTOR(INT, 20)", "VECTOR(FLOAT, 10)", DataTypeFloat), Error: `unable to read identifier: "ab`}, - {Input: NewSchemaObjectIdentifierWithArguments(`abc`, `def`, `ghi`)}, - {Input: NewSchemaObjectIdentifierWithArguments(`abc`, `def`, `ghi`), RawInput: `abc.def.ghi()`}, - {Input: NewSchemaObjectIdentifierWithArguments(`abc`, `def`, `ghi`, DataTypeFloat, "VECTOR(INT, 20)"), RawInput: `abc.def.ghi(FLOAT, VECTOR(INT, 20))`}, - } - - for _, testCase := range testCases { - t.Run(fmt.Sprintf("processing %s", testCase.Input.FullyQualifiedName()), func(t *testing.T) { - var id SchemaObjectIdentifierWithArguments - var err error - if testCase.RawInput != "" { - id, err = NewSchemaObjectIdentifierWithArgumentsFromFullyQualifiedName(testCase.RawInput) - } else { - id, err = NewSchemaObjectIdentifierWithArgumentsFromFullyQualifiedName(testCase.Input.FullyQualifiedName()) - } - - if testCase.Error != "" { - assert.ErrorContains(t, err, testCase.Error) - } else { - assert.NoError(t, err) - assert.Equal(t, testCase.Input.FullyQualifiedName(), id.FullyQualifiedName()) - } - }) - } -} diff --git a/pkg/sdk/poc/main.go b/pkg/sdk/poc/main.go index 7f79ec3361..f8f1014bdb 100644 --- a/pkg/sdk/poc/main.go +++ b/pkg/sdk/poc/main.go @@ -5,7 +5,6 @@ package main import ( "bytes" "fmt" - "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/genhelpers" "io" "log" "os" diff --git a/pkg/sdk/procedures_gen.go b/pkg/sdk/procedures_gen.go index 9e6acbf39f..b12fd6beeb 100644 --- a/pkg/sdk/procedures_gen.go +++ b/pkg/sdk/procedures_gen.go @@ -226,7 +226,8 @@ type Procedure struct { IsAnsi bool MinNumArguments int MaxNumArguments int - Arguments string + Arguments []DataType + ArgumentsRaw string Description string CatalogName string IsTableFunction bool @@ -235,8 +236,7 @@ type Procedure struct { } func (v *Procedure) ID() SchemaObjectIdentifierWithArguments { - //return NewSchemaObjectIdentifier(v.CatalogName, v.SchemaName, v.Name) - return NewSchemaObjectIdentifierWithArguments("", "", "", "") + return NewSchemaObjectIdentifierWithArguments(v.CatalogName, v.SchemaName, v.Name, v.Arguments...) } // DescribeProcedureOptions is based on https://docs.snowflake.com/en/sql-reference/sql/desc-procedure. diff --git a/pkg/sdk/procedures_impl_gen.go b/pkg/sdk/procedures_impl_gen.go index 24bae558d5..21ff800c8f 100644 --- a/pkg/sdk/procedures_impl_gen.go +++ b/pkg/sdk/procedures_impl_gen.go @@ -2,6 +2,8 @@ package sdk import ( "context" + "log" + "strings" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/internal/collections" ) @@ -58,12 +60,11 @@ func (v *procedures) Show(ctx context.Context, request *ShowProcedureRequest) ([ } func (v *procedures) ShowByID(ctx context.Context, id SchemaObjectIdentifierWithArguments) (*Procedure, error) { - // TODO: adjust request if e.g. LIKE is supported for the resource - procedures, err := v.Show(ctx, NewShowProcedureRequest()) + procedures, err := v.Show(ctx, NewShowProcedureRequest().WithIn(In{Schema: id.SchemaId()}).WithLike(Like{String(id.Name())})) if err != nil { return nil, err } - return collections.FindOne(procedures, func(r Procedure) bool { return r.Name == id.Name() }) + return collections.FindOne(procedures, func(r Procedure) bool { return r.ID().FullyQualifiedName() == id.FullyQualifiedName() }) } func (v *procedures) Describe(ctx context.Context, id SchemaObjectIdentifierWithArguments) ([]ProcedureDetail, error) { @@ -394,12 +395,20 @@ func (r procedureRow) convert() *Procedure { IsAnsi: r.IsAnsi == "Y", MinNumArguments: r.MinNumArguments, MaxNumArguments: r.MaxNumArguments, - Arguments: r.Arguments, + ArgumentsRaw: r.Arguments, Description: r.Description, CatalogName: r.CatalogName, IsTableFunction: r.IsTableFunction == "Y", ValidForClustering: r.ValidForClustering == "Y", } + arguments := strings.TrimLeft(r.Arguments, r.Name) + returnIndex := strings.Index(arguments, ") RETURN ") + dataTypes, err := ParseFunctionArgumentsFromString(arguments[:returnIndex+1]) + if err != nil { + log.Printf("[DEBUG] failed to parse function arguments, err = %s", err) + } else { + e.Arguments = dataTypes + } if r.IsSecure.Valid { e.IsSecure = r.IsSecure.String == "Y" } @@ -425,7 +434,9 @@ func (r procedureDetailRow) convert() *ProcedureDetail { func (r *CallProcedureRequest) toOpts() *CallProcedureOptions { opts := &CallProcedureOptions{ + call: false, name: r.name, + CallArguments: r.CallArguments, ScriptingVariable: r.ScriptingVariable, } return opts diff --git a/pkg/sdk/testint/procedures_integration_test.go b/pkg/sdk/testint/procedures_integration_test.go index 70b8a7b9b4..811c8357ab 100644 --- a/pkg/sdk/testint/procedures_integration_test.go +++ b/pkg/sdk/testint/procedures_integration_test.go @@ -359,6 +359,7 @@ func TestInt_OtherProcedureFunctions(t *testing.T) { assert.Equal(t, 1, procedure.MinNumArguments) assert.Equal(t, 1, procedure.MaxNumArguments) assert.NotEmpty(t, procedure.Arguments) + assert.NotEmpty(t, procedure.ArgumentsRaw) assert.NotEmpty(t, procedure.Description) assert.NotEmpty(t, procedure.CatalogName) assert.Equal(t, false, procedure.IsTableFunction) @@ -769,7 +770,6 @@ def filter_by_role(session, name, role): require.NoError(t, err) t.Cleanup(cleanupProcedureHandle(id)) - id = sdk.NewSchemaObjectIdentifierWithArguments(databaseTest.Name, schemaTest.Name, "filterByRole", sdk.DataTypeVARCHAR) ca := []string{fmt.Sprintf(`'%s'`, tid.FullyQualifiedName()), "'dev'"} err = client.Procedures.Call(ctx, sdk.NewCallProcedureRequest(id.SchemaObjectId()).WithCallArguments(ca)) require.NoError(t, err) @@ -1011,8 +1011,8 @@ func TestInt_ProceduresShowByID(t *testing.T) { schema, schemaCleanup := testClientHelper().Schema.CreateSchema(t) t.Cleanup(schemaCleanup) - id1 := testClientHelper().Ids.RandomSchemaObjectIdentifierWithArguments() - id2 := testClientHelper().Ids.NewSchemaObjectIdentifierWithArgumentsInSchema(id1.Name(), schema.ID()) + id1 := testClientHelper().Ids.RandomSchemaObjectIdentifierWithArguments(sdk.DataTypeVARCHAR) + id2 := testClientHelper().Ids.NewSchemaObjectIdentifierWithArgumentsInSchema(id1.Name(), schema.ID(), sdk.DataTypeVARCHAR) createProcedureForSQLHandle(t, id1) createProcedureForSQLHandle(t, id2) From 31521e7f26b6bf2a18ce0a61c5d9a8a32c8fc027 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20Cie=C5=9Blak?= Date: Thu, 8 Aug 2024 16:16:11 +0200 Subject: [PATCH 03/19] wip --- pkg/resources/external_function.go | 1001 +++++++++-------- pkg/sdk/external_functions_def.go | 9 +- .../external_functions_dto_builders_gen.go | 123 +- pkg/sdk/external_functions_dto_gen.go | 13 +- pkg/sdk/external_functions_gen.go | 31 +- ...external_functions_gen_integration_test.go | 27 + pkg/sdk/external_functions_gen_test.go | 42 +- pkg/sdk/external_functions_impl_gen.go | 56 +- pkg/sdk/functions_gen_test.go | 9 +- pkg/sdk/procedures_impl_gen.go | 2 +- .../external_functions_integration_test.go | 157 ++- 11 files changed, 740 insertions(+), 730 deletions(-) create mode 100644 pkg/sdk/external_functions_gen_integration_test.go diff --git a/pkg/resources/external_function.go b/pkg/resources/external_function.go index e57e43ce51..51c8dcf8e3 100644 --- a/pkg/resources/external_function.go +++ b/pkg/resources/external_function.go @@ -1,505 +1,510 @@ package resources -import ( - "context" - "encoding/json" - "log" - "regexp" - "strconv" - "strings" - - "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/provider" - "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" - "github.com/hashicorp/go-cty/cty" - "github.com/hashicorp/terraform-plugin-sdk/v2/diag" - "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" - "github.com/hashicorp/terraform-plugin-sdk/v2/helper/validation" -) - -var externalFunctionSchema = map[string]*schema.Schema{ - "name": { - Type: schema.TypeString, - Required: true, - ForceNew: true, - Description: "Specifies the identifier for the external function. The identifier can contain the schema name and database name, as well as the function name. The function's signature (name and argument data types) must be unique within the schema.", - }, - "schema": { - Type: schema.TypeString, - Required: true, - ForceNew: true, - Description: "The schema in which to create the external function.", - }, - "database": { - Type: schema.TypeString, - Required: true, - ForceNew: true, - Description: "The database in which to create the external function.", - }, - "arg": { - Type: schema.TypeList, - Optional: true, - ForceNew: true, - Description: "Specifies the arguments/inputs for the external function. These should correspond to the arguments that the remote service expects.", - Elem: &schema.Resource{ - Schema: map[string]*schema.Schema{ - "name": { - Type: schema.TypeString, - Required: true, - // Suppress the diff shown if the values are equal when both compared in lower case. - DiffSuppressFunc: func(k, old, new string, d *schema.ResourceData) bool { - return strings.EqualFold(strings.ToLower(old), strings.ToLower(new)) - }, - Description: "Argument name", - }, - "type": { - Type: schema.TypeString, - Required: true, - // Suppress the diff shown if the values are equal when both compared in lower case. - DiffSuppressFunc: func(k, old, new string, d *schema.ResourceData) bool { - return strings.EqualFold(strings.ToLower(old), strings.ToLower(new)) - }, - Description: "Argument type, e.g. VARCHAR", - }, - }, - }, - }, - "null_input_behavior": { - Type: schema.TypeString, - Optional: true, - Default: "CALLED ON NULL INPUT", - ForceNew: true, - ValidateFunc: validation.StringInSlice([]string{"CALLED ON NULL INPUT", "RETURNS NULL ON NULL INPUT", "STRICT"}, false), - Description: "Specifies the behavior of the external function when called with null inputs.", - }, - "return_type": { - Type: schema.TypeString, - Required: true, - ForceNew: true, - // Suppress the diff shown if the values are equal when both compared in lower case. - DiffSuppressFunc: func(k, old, new string, d *schema.ResourceData) bool { - return strings.EqualFold(strings.ToLower(old), strings.ToLower(new)) - }, - Description: "Specifies the data type returned by the external function.", - }, - "return_null_allowed": { - Type: schema.TypeBool, - Optional: true, - ForceNew: true, - Description: "Indicates whether the function can return NULL values (true) or must return only NON-NULL values (false).", - Default: true, - }, - "return_behavior": { - Type: schema.TypeString, - Required: true, - ForceNew: true, - ValidateFunc: validation.StringInSlice([]string{"VOLATILE", "IMMUTABLE"}, false), - Description: "Specifies the behavior of the function when returning results", - }, - "api_integration": { - Type: schema.TypeString, - Required: true, - ForceNew: true, - Description: "The name of the API integration object that should be used to authenticate the call to the proxy service.", - }, - "header": { - Type: schema.TypeSet, - Optional: true, - ForceNew: true, - Description: "Allows users to specify key-value metadata that is sent with every request as HTTP headers.", - Elem: &schema.Resource{ - Schema: map[string]*schema.Schema{ - "name": { - Type: schema.TypeString, - Required: true, - ForceNew: true, - Description: "Header name", - }, - "value": { - Type: schema.TypeString, - Required: true, - ForceNew: true, - Description: "Header value", - }, - }, - }, - }, - "context_headers": { - Type: schema.TypeList, - Elem: &schema.Schema{Type: schema.TypeString}, - Optional: true, - ForceNew: true, - // Suppress the diff shown if the values are equal when both compared in lower case. - DiffSuppressFunc: func(k, old, new string, d *schema.ResourceData) bool { - return strings.EqualFold(strings.ToLower(old), strings.ToLower(new)) - }, - Description: "Binds Snowflake context function results to HTTP headers.", - }, - "max_batch_rows": { - Type: schema.TypeInt, - Optional: true, - ForceNew: true, - Description: "This specifies the maximum number of rows in each batch sent to the proxy service.", - }, - "compression": { - Type: schema.TypeString, - Optional: true, - Default: "AUTO", - ForceNew: true, - ValidateFunc: validation.StringInSlice([]string{"NONE", "AUTO", "GZIP", "DEFLATE"}, false), - Description: "If specified, the JSON payload is compressed when sent from Snowflake to the proxy service, and when sent back from the proxy service to Snowflake.", - }, - "request_translator": { - Type: schema.TypeString, - Optional: true, - ForceNew: true, - Description: "This specifies the name of the request translator function", - }, - "response_translator": { - Type: schema.TypeString, - Optional: true, - ForceNew: true, - Description: "This specifies the name of the response translator function.", - }, - "url_of_proxy_and_resource": { - Type: schema.TypeString, - Required: true, - ForceNew: true, - Description: "This is the invocation URL of the proxy service and resource through which Snowflake calls the remote service.", - }, - "comment": { - Type: schema.TypeString, - Optional: true, - Default: "user-defined function", - Description: "A description of the external function.", - }, - "created_on": { - Type: schema.TypeString, - Computed: true, - Description: "Date and time when the external function was created.", - }, -} - -// ExternalFunction returns a pointer to the resource representing an external function. +import "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" + +// import ( +// +// "context" +// "encoding/json" +// "log" +// "regexp" +// "strconv" +// "strings" +// +// "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/provider" +// "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" +// "github.com/hashicorp/go-cty/cty" +// "github.com/hashicorp/terraform-plugin-sdk/v2/diag" +// "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" +// "github.com/hashicorp/terraform-plugin-sdk/v2/helper/validation" +// +// ) +// +// var externalFunctionSchema = map[string]*schema.Schema{ +// "name": { +// Type: schema.TypeString, +// Required: true, +// ForceNew: true, +// Description: "Specifies the identifier for the external function. The identifier can contain the schema name and database name, as well as the function name. The function's signature (name and argument data types) must be unique within the schema.", +// }, +// "schema": { +// Type: schema.TypeString, +// Required: true, +// ForceNew: true, +// Description: "The schema in which to create the external function.", +// }, +// "database": { +// Type: schema.TypeString, +// Required: true, +// ForceNew: true, +// Description: "The database in which to create the external function.", +// }, +// "arg": { +// Type: schema.TypeList, +// Optional: true, +// ForceNew: true, +// Description: "Specifies the arguments/inputs for the external function. These should correspond to the arguments that the remote service expects.", +// Elem: &schema.Resource{ +// Schema: map[string]*schema.Schema{ +// "name": { +// Type: schema.TypeString, +// Required: true, +// // Suppress the diff shown if the values are equal when both compared in lower case. +// DiffSuppressFunc: func(k, old, new string, d *schema.ResourceData) bool { +// return strings.EqualFold(strings.ToLower(old), strings.ToLower(new)) +// }, +// Description: "Argument name", +// }, +// "type": { +// Type: schema.TypeString, +// Required: true, +// // Suppress the diff shown if the values are equal when both compared in lower case. +// DiffSuppressFunc: func(k, old, new string, d *schema.ResourceData) bool { +// return strings.EqualFold(strings.ToLower(old), strings.ToLower(new)) +// }, +// Description: "Argument type, e.g. VARCHAR", +// }, +// }, +// }, +// }, +// "null_input_behavior": { +// Type: schema.TypeString, +// Optional: true, +// Default: "CALLED ON NULL INPUT", +// ForceNew: true, +// ValidateFunc: validation.StringInSlice([]string{"CALLED ON NULL INPUT", "RETURNS NULL ON NULL INPUT", "STRICT"}, false), +// Description: "Specifies the behavior of the external function when called with null inputs.", +// }, +// "return_type": { +// Type: schema.TypeString, +// Required: true, +// ForceNew: true, +// // Suppress the diff shown if the values are equal when both compared in lower case. +// DiffSuppressFunc: func(k, old, new string, d *schema.ResourceData) bool { +// return strings.EqualFold(strings.ToLower(old), strings.ToLower(new)) +// }, +// Description: "Specifies the data type returned by the external function.", +// }, +// "return_null_allowed": { +// Type: schema.TypeBool, +// Optional: true, +// ForceNew: true, +// Description: "Indicates whether the function can return NULL values (true) or must return only NON-NULL values (false).", +// Default: true, +// }, +// "return_behavior": { +// Type: schema.TypeString, +// Required: true, +// ForceNew: true, +// ValidateFunc: validation.StringInSlice([]string{"VOLATILE", "IMMUTABLE"}, false), +// Description: "Specifies the behavior of the function when returning results", +// }, +// "api_integration": { +// Type: schema.TypeString, +// Required: true, +// ForceNew: true, +// Description: "The name of the API integration object that should be used to authenticate the call to the proxy service.", +// }, +// "header": { +// Type: schema.TypeSet, +// Optional: true, +// ForceNew: true, +// Description: "Allows users to specify key-value metadata that is sent with every request as HTTP headers.", +// Elem: &schema.Resource{ +// Schema: map[string]*schema.Schema{ +// "name": { +// Type: schema.TypeString, +// Required: true, +// ForceNew: true, +// Description: "Header name", +// }, +// "value": { +// Type: schema.TypeString, +// Required: true, +// ForceNew: true, +// Description: "Header value", +// }, +// }, +// }, +// }, +// "context_headers": { +// Type: schema.TypeList, +// Elem: &schema.Schema{Type: schema.TypeString}, +// Optional: true, +// ForceNew: true, +// // Suppress the diff shown if the values are equal when both compared in lower case. +// DiffSuppressFunc: func(k, old, new string, d *schema.ResourceData) bool { +// return strings.EqualFold(strings.ToLower(old), strings.ToLower(new)) +// }, +// Description: "Binds Snowflake context function results to HTTP headers.", +// }, +// "max_batch_rows": { +// Type: schema.TypeInt, +// Optional: true, +// ForceNew: true, +// Description: "This specifies the maximum number of rows in each batch sent to the proxy service.", +// }, +// "compression": { +// Type: schema.TypeString, +// Optional: true, +// Default: "AUTO", +// ForceNew: true, +// ValidateFunc: validation.StringInSlice([]string{"NONE", "AUTO", "GZIP", "DEFLATE"}, false), +// Description: "If specified, the JSON payload is compressed when sent from Snowflake to the proxy service, and when sent back from the proxy service to Snowflake.", +// }, +// "request_translator": { +// Type: schema.TypeString, +// Optional: true, +// ForceNew: true, +// Description: "This specifies the name of the request translator function", +// }, +// "response_translator": { +// Type: schema.TypeString, +// Optional: true, +// ForceNew: true, +// Description: "This specifies the name of the response translator function.", +// }, +// "url_of_proxy_and_resource": { +// Type: schema.TypeString, +// Required: true, +// ForceNew: true, +// Description: "This is the invocation URL of the proxy service and resource through which Snowflake calls the remote service.", +// }, +// "comment": { +// Type: schema.TypeString, +// Optional: true, +// Default: "user-defined function", +// Description: "A description of the external function.", +// }, +// "created_on": { +// Type: schema.TypeString, +// Computed: true, +// Description: "Date and time when the external function was created.", +// }, +// } +// +// // ExternalFunction returns a pointer to the resource representing an external function. func ExternalFunction() *schema.Resource { return &schema.Resource{ - SchemaVersion: 1, - - CreateContext: CreateContextExternalFunction, - ReadContext: ReadContextExternalFunction, - UpdateContext: UpdateContextExternalFunction, - DeleteContext: DeleteContextExternalFunction, - - Schema: externalFunctionSchema, - Importer: &schema.ResourceImporter{ - StateContext: schema.ImportStatePassthroughContext, - }, - - StateUpgraders: []schema.StateUpgrader{ - { - Version: 0, - // setting type to cty.EmptyObject is a bit hacky here but following https://developer.hashicorp.com/terraform/plugin/framework/migrating/resources/state-upgrade#sdkv2-1 would require lots of repetitive code; this should work with cty.EmptyObject - Type: cty.EmptyObject, - Upgrade: v085ExternalFunctionStateUpgrader, - }, - }, - } -} - -func CreateContextExternalFunction(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { - client := meta.(*provider.Context).Client - database := d.Get("database").(string) - schemaName := d.Get("schema").(string) - name := d.Get("name").(string) - id := sdk.NewSchemaObjectIdentifier(database, schemaName, name) - - returnType := d.Get("return_type").(string) - resultDataType, err := sdk.ToDataType(returnType) - if err != nil { - return diag.FromErr(err) - } - apiIntegration := sdk.NewAccountObjectIdentifier(d.Get("api_integration").(string)) - urlOfProxyAndResource := d.Get("url_of_proxy_and_resource").(string) - req := sdk.NewCreateExternalFunctionRequest(id, resultDataType, &apiIntegration, urlOfProxyAndResource) - - // Set optionals - args := make([]sdk.ExternalFunctionArgumentRequest, 0) - if v, ok := d.GetOk("arg"); ok { - for _, arg := range v.([]interface{}) { - argName := arg.(map[string]interface{})["name"].(string) - argType := arg.(map[string]interface{})["type"].(string) - argDataType, err := sdk.ToDataType(argType) - if err != nil { - return diag.FromErr(err) - } - args = append(args, sdk.ExternalFunctionArgumentRequest{ArgName: argName, ArgDataType: argDataType}) - } - } - if len(args) > 0 { - req.WithArguments(args) - } - - if v, ok := d.GetOk("return_null_allowed"); ok { - if v.(bool) { - req.WithReturnNullValues(&sdk.ReturnNullValuesNull) - } else { - req.WithReturnNullValues(&sdk.ReturnNullValuesNotNull) - } - } - - if v, ok := d.GetOk("return_behavior"); ok { - if v.(string) == "VOLATILE" { - req.WithReturnResultsBehavior(&sdk.ReturnResultsBehaviorVolatile) - } else { - req.WithReturnResultsBehavior(&sdk.ReturnResultsBehaviorImmutable) - } - } - - if v, ok := d.GetOk("null_input_behavior"); ok { - switch { - case v.(string) == "CALLED ON NULL INPUT": - req.WithNullInputBehavior(sdk.Pointer(sdk.NullInputBehaviorCalledOnNullInput)) - case v.(string) == "RETURNS NULL ON NULL INPUT": - req.WithNullInputBehavior(sdk.Pointer(sdk.NullInputBehaviorReturnNullInput)) - default: - req.WithNullInputBehavior(sdk.Pointer(sdk.NullInputBehaviorStrict)) - } - } - - if v, ok := d.GetOk("comment"); ok { - req.WithComment(sdk.String(v.(string))) - } - - if _, ok := d.GetOk("header"); ok { - headers := make([]sdk.ExternalFunctionHeaderRequest, 0) - for _, header := range d.Get("header").(*schema.Set).List() { - m := header.(map[string]interface{}) - headerName := m["name"].(string) - headerValue := m["value"].(string) - headers = append(headers, sdk.ExternalFunctionHeaderRequest{ - Name: headerName, - Value: headerValue, - }) - } - req.WithHeaders(headers) - } - - if v, ok := d.GetOk("context_headers"); ok { - contextHeadersList := expandStringList(v.([]interface{})) - contextHeaders := make([]sdk.ExternalFunctionContextHeaderRequest, 0) - for _, header := range contextHeadersList { - contextHeaders = append(contextHeaders, sdk.ExternalFunctionContextHeaderRequest{ - ContextFunction: header, - }) - } - req.WithContextHeaders(contextHeaders) - } - - if v, ok := d.GetOk("max_batch_rows"); ok { - req.WithMaxBatchRows(sdk.Int(v.(int))) - } - - if v, ok := d.GetOk("compression"); ok { - req.WithCompression(sdk.String(v.(string))) - } - - if v, ok := d.GetOk("request_translator"); ok { - req.WithRequestTranslator(sdk.Pointer(sdk.NewSchemaObjectIdentifierFromFullyQualifiedName(v.(string)))) - } - - if v, ok := d.GetOk("response_translator"); ok { - req.WithResponseTranslator(sdk.Pointer(sdk.NewSchemaObjectIdentifierFromFullyQualifiedName(v.(string)))) - } - - if err := client.ExternalFunctions.Create(ctx, req); err != nil { - return diag.FromErr(err) + //SchemaVersion: 1, + // + //CreateContext: CreateContextExternalFunction, + //ReadContext: ReadContextExternalFunction, + //UpdateContext: UpdateContextExternalFunction, + //DeleteContext: DeleteContextExternalFunction, + // + //Schema: externalFunctionSchema, + //Importer: &schema.ResourceImporter{ + // StateContext: schema.ImportStatePassthroughContext, + //}, + // + //StateUpgraders: []schema.StateUpgrader{ + // { + // Version: 0, + // // setting type to cty.EmptyObject is a bit hacky here but following https://developer.hashicorp.com/terraform/plugin/framework/migrating/resources/state-upgrade#sdkv2-1 would require lots of repetitive code; this should work with cty.EmptyObject + // Type: cty.EmptyObject, + // Upgrade: v085ExternalFunctionStateUpgrader, + // }, + //}, } - argTypes := make([]sdk.DataType, 0, len(args)) - for _, item := range args { - argTypes = append(argTypes, item.ArgDataType) - } - sid := sdk.NewSchemaObjectIdentifierWithArgumentsOld(database, schemaName, name, argTypes) - d.SetId(sid.FullyQualifiedName()) - return ReadContextExternalFunction(ctx, d, meta) -} - -func ReadContextExternalFunction(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { - client := meta.(*provider.Context).Client - - id := sdk.NewSchemaObjectIdentifierFromFullyQualifiedName(d.Id()) - externalFunction, err := client.ExternalFunctions.ShowByID(ctx, id) - if err != nil { - d.SetId("") - return nil - } - - // Some properties can come from the SHOW EXTERNAL FUNCTION call - if err := d.Set("name", externalFunction.Name); err != nil { - return diag.FromErr(err) - } - - if err := d.Set("schema", strings.Trim(externalFunction.SchemaName, "\"")); err != nil { - return diag.FromErr(err) - } - - if err := d.Set("database", strings.Trim(externalFunction.CatalogName, "\"")); err != nil { - return diag.FromErr(err) - } - - if err := d.Set("comment", externalFunction.Description); err != nil { - return diag.FromErr(err) - } - - if err := d.Set("created_on", externalFunction.CreatedOn); err != nil { - return diag.FromErr(err) - } - - // Some properties come from the DESCRIBE FUNCTION call - externalFunctionPropertyRows, err := client.ExternalFunctions.Describe(ctx, sdk.NewDescribeExternalFunctionRequest(id.WithoutArguments(), id.Arguments())) - if err != nil { - d.SetId("") - return nil - } - - for _, row := range externalFunctionPropertyRows { - switch row.Property { - case "signature": - // Format in Snowflake DB is: (argName argType, argName argType, ...) - args := strings.ReplaceAll(strings.ReplaceAll(row.Value, "(", ""), ")", "") - - if args != "" { // Do nothing for functions without arguments - argPairs := strings.Split(args, ", ") - args := []interface{}{} - - for _, argPair := range argPairs { - argItem := strings.Split(argPair, " ") - - arg := map[string]interface{}{} - arg["name"] = argItem[0] - arg["type"] = argItem[1] - args = append(args, arg) - } - - if err := d.Set("arg", args); err != nil { - return diag.Errorf("error setting arg: %v", err) - } - } - case "returns": - returnType := row.Value - // We first check for VARIANT or OBJECT - if returnType == "VARIANT" || returnType == "OBJECT" { - if err := d.Set("return_type", returnType); err != nil { - return diag.Errorf("error setting return_type: %v", err) - } - break - } - - // otherwise, format in Snowflake DB is returnType() - re := regexp.MustCompile(`^(\w+)\([0-9]*\)$`) - match := re.FindStringSubmatch(row.Value) - if len(match) < 2 { - return diag.Errorf("return_type %s not recognized", returnType) - } - if err := d.Set("return_type", match[1]); err != nil { - return diag.Errorf("error setting return_type: %v", err) - } - - case "null handling": - if err := d.Set("null_input_behavior", row.Value); err != nil { - return diag.Errorf("error setting null_input_behavior: %v", err) - } - case "volatility": - if err := d.Set("return_behavior", row.Value); err != nil { - return diag.Errorf("error setting return_behavior: %v", err) - } - case "headers": - if row.Value != "" && row.Value != "null" { - // Format in Snowflake DB is: {"head1":"val1","head2":"val2"} - var jsonHeaders map[string]string - err := json.Unmarshal([]byte(row.Value), &jsonHeaders) - if err != nil { - return diag.Errorf("error unmarshalling headers: %v", err) - } - - headers := make([]any, 0, len(jsonHeaders)) - for key, value := range jsonHeaders { - headers = append(headers, map[string]any{ - "name": key, - "value": value, - }) - } - - if err := d.Set("header", headers); err != nil { - return diag.Errorf("error setting return_behavior: %v", err) - } - } - case "context_headers": - if row.Value != "" && row.Value != "null" { - // Format in Snowflake DB is: ["CONTEXT_FUNCTION_1","CONTEXT_FUNCTION_2"] - contextHeaders := strings.Split(strings.Trim(row.Value, "[]"), ",") - for i, v := range contextHeaders { - contextHeaders[i] = strings.Trim(v, "\"") - } - if err := d.Set("context_headers", contextHeaders); err != nil { - return diag.Errorf("error setting context_headers: %v", err) - } - } - case "max_batch_rows": - if row.Value != "not set" { - maxBatchRows, err := strconv.ParseInt(row.Value, 10, 64) - if err != nil { - return diag.Errorf("error parsing max_batch_rows: %v", err) - } - if err := d.Set("max_batch_rows", maxBatchRows); err != nil { - return diag.Errorf("error setting max_batch_rows: %v", err) - } - } - case "compression": - if err := d.Set("compression", row.Value); err != nil { - return diag.Errorf("error setting compression: %v", err) - } - case "body": - if err := d.Set("url_of_proxy_and_resource", row.Value); err != nil { - return diag.Errorf("error setting url_of_proxy_and_resource: %v", err) - } - case "language": - // To ignore - default: - log.Printf("[WARN] unexpected external function property %v returned from Snowflake", row.Property) - } - } - - return nil -} - -func UpdateContextExternalFunction(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { - client := meta.(*provider.Context).Client - - id := sdk.NewSchemaObjectIdentifierFromFullyQualifiedName(d.Id()) - req := sdk.NewAlterFunctionRequest(sdk.NewSchemaObjectIdentifierWithArguments(id.DatabaseName(), id.SchemaName(), id.Name(), id.Arguments()...)) - if d.HasChange("comment") { - _, new := d.GetChange("comment") - if new == "" { - req.UnsetComment = sdk.Bool(true) - } else { - req.SetComment = sdk.String(new.(string)) - } - err := client.Functions.Alter(ctx, req) - if err != nil { - return diag.FromErr(err) - } - } - return ReadContextExternalFunction(ctx, d, meta) } -func DeleteContextExternalFunction(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { - client := meta.(*provider.Context).Client - - id := sdk.NewSchemaObjectIdentifierFromFullyQualifiedName(d.Id()) - req := sdk.NewDropFunctionRequest(sdk.NewSchemaObjectIdentifierWithArguments(id.DatabaseName(), id.SchemaName(), id.Name(), id.Arguments()...)) - if err := client.Functions.Drop(ctx, req); err != nil { - return diag.FromErr(err) - } - - d.SetId("") - return nil -} +// +//func CreateContextExternalFunction(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { +// client := meta.(*provider.Context).Client +// database := d.Get("database").(string) +// schemaName := d.Get("schema").(string) +// name := d.Get("name").(string) +// id := sdk.NewSchemaObjectIdentifier(database, schemaName, name) +// +// returnType := d.Get("return_type").(string) +// resultDataType, err := sdk.ToDataType(returnType) +// if err != nil { +// return diag.FromErr(err) +// } +// apiIntegration := sdk.NewAccountObjectIdentifier(d.Get("api_integration").(string)) +// urlOfProxyAndResource := d.Get("url_of_proxy_and_resource").(string) +// req := sdk.NewCreateExternalFunctionRequest(id, resultDataType, &apiIntegration, urlOfProxyAndResource) +// +// // Set optionals +// args := make([]sdk.ExternalFunctionArgumentRequest, 0) +// if v, ok := d.GetOk("arg"); ok { +// for _, arg := range v.([]interface{}) { +// argName := arg.(map[string]interface{})["name"].(string) +// argType := arg.(map[string]interface{})["type"].(string) +// argDataType, err := sdk.ToDataType(argType) +// if err != nil { +// return diag.FromErr(err) +// } +// args = append(args, sdk.ExternalFunctionArgumentRequest{ArgName: argName, ArgDataType: argDataType}) +// } +// } +// if len(args) > 0 { +// req.WithArguments(args) +// } +// +// if v, ok := d.GetOk("return_null_allowed"); ok { +// if v.(bool) { +// req.WithReturnNullValues(&sdk.ReturnNullValuesNull) +// } else { +// req.WithReturnNullValues(&sdk.ReturnNullValuesNotNull) +// } +// } +// +// if v, ok := d.GetOk("return_behavior"); ok { +// if v.(string) == "VOLATILE" { +// req.WithReturnResultsBehavior(&sdk.ReturnResultsBehaviorVolatile) +// } else { +// req.WithReturnResultsBehavior(&sdk.ReturnResultsBehaviorImmutable) +// } +// } +// +// if v, ok := d.GetOk("null_input_behavior"); ok { +// switch { +// case v.(string) == "CALLED ON NULL INPUT": +// req.WithNullInputBehavior(sdk.Pointer(sdk.NullInputBehaviorCalledOnNullInput)) +// case v.(string) == "RETURNS NULL ON NULL INPUT": +// req.WithNullInputBehavior(sdk.Pointer(sdk.NullInputBehaviorReturnNullInput)) +// default: +// req.WithNullInputBehavior(sdk.Pointer(sdk.NullInputBehaviorStrict)) +// } +// } +// +// if v, ok := d.GetOk("comment"); ok { +// req.WithComment(sdk.String(v.(string))) +// } +// +// if _, ok := d.GetOk("header"); ok { +// headers := make([]sdk.ExternalFunctionHeaderRequest, 0) +// for _, header := range d.Get("header").(*schema.Set).List() { +// m := header.(map[string]interface{}) +// headerName := m["name"].(string) +// headerValue := m["value"].(string) +// headers = append(headers, sdk.ExternalFunctionHeaderRequest{ +// Name: headerName, +// Value: headerValue, +// }) +// } +// req.WithHeaders(headers) +// } +// +// if v, ok := d.GetOk("context_headers"); ok { +// contextHeadersList := expandStringList(v.([]interface{})) +// contextHeaders := make([]sdk.ExternalFunctionContextHeaderRequest, 0) +// for _, header := range contextHeadersList { +// contextHeaders = append(contextHeaders, sdk.ExternalFunctionContextHeaderRequest{ +// ContextFunction: header, +// }) +// } +// req.WithContextHeaders(contextHeaders) +// } +// +// if v, ok := d.GetOk("max_batch_rows"); ok { +// req.WithMaxBatchRows(sdk.Int(v.(int))) +// } +// +// if v, ok := d.GetOk("compression"); ok { +// req.WithCompression(sdk.String(v.(string))) +// } +// +// if v, ok := d.GetOk("request_translator"); ok { +// req.WithRequestTranslator(sdk.Pointer(sdk.NewSchemaObjectIdentifierFromFullyQualifiedName(v.(string)))) +// } +// +// if v, ok := d.GetOk("response_translator"); ok { +// req.WithResponseTranslator(sdk.Pointer(sdk.NewSchemaObjectIdentifierFromFullyQualifiedName(v.(string)))) +// } +// +// if err := client.ExternalFunctions.Create(ctx, req); err != nil { +// return diag.FromErr(err) +// } +// argTypes := make([]sdk.DataType, 0, len(args)) +// for _, item := range args { +// argTypes = append(argTypes, item.ArgDataType) +// } +// sid := sdk.NewSchemaObjectIdentifierWithArgumentsOld(database, schemaName, name, argTypes) +// d.SetId(sid.FullyQualifiedName()) +// return ReadContextExternalFunction(ctx, d, meta) +//} +// +//func ReadContextExternalFunction(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { +// client := meta.(*provider.Context).Client +// +// id := sdk.NewSchemaObjectIdentifierFromFullyQualifiedName(d.Id()) +// externalFunction, err := client.ExternalFunctions.ShowByID(ctx, id) +// if err != nil { +// d.SetId("") +// return nil +// } +// +// // Some properties can come from the SHOW EXTERNAL FUNCTION call +// if err := d.Set("name", externalFunction.Name); err != nil { +// return diag.FromErr(err) +// } +// +// if err := d.Set("schema", strings.Trim(externalFunction.SchemaName, "\"")); err != nil { +// return diag.FromErr(err) +// } +// +// if err := d.Set("database", strings.Trim(externalFunction.CatalogName, "\"")); err != nil { +// return diag.FromErr(err) +// } +// +// if err := d.Set("comment", externalFunction.Description); err != nil { +// return diag.FromErr(err) +// } +// +// if err := d.Set("created_on", externalFunction.CreatedOn); err != nil { +// return diag.FromErr(err) +// } +// +// // Some properties come from the DESCRIBE FUNCTION call +// externalFunctionPropertyRows, err := client.ExternalFunctions.Describe(ctx, sdk.NewDescribeExternalFunctionRequest(id.WithoutArguments(), id.Arguments())) +// if err != nil { +// d.SetId("") +// return nil +// } +// +// for _, row := range externalFunctionPropertyRows { +// switch row.Property { +// case "signature": +// // Format in Snowflake DB is: (argName argType, argName argType, ...) +// args := strings.ReplaceAll(strings.ReplaceAll(row.Value, "(", ""), ")", "") +// +// if args != "" { // Do nothing for functions without arguments +// argPairs := strings.Split(args, ", ") +// args := []interface{}{} +// +// for _, argPair := range argPairs { +// argItem := strings.Split(argPair, " ") +// +// arg := map[string]interface{}{} +// arg["name"] = argItem[0] +// arg["type"] = argItem[1] +// args = append(args, arg) +// } +// +// if err := d.Set("arg", args); err != nil { +// return diag.Errorf("error setting arg: %v", err) +// } +// } +// case "returns": +// returnType := row.Value +// // We first check for VARIANT or OBJECT +// if returnType == "VARIANT" || returnType == "OBJECT" { +// if err := d.Set("return_type", returnType); err != nil { +// return diag.Errorf("error setting return_type: %v", err) +// } +// break +// } +// +// // otherwise, format in Snowflake DB is returnType() +// re := regexp.MustCompile(`^(\w+)\([0-9]*\)$`) +// match := re.FindStringSubmatch(row.Value) +// if len(match) < 2 { +// return diag.Errorf("return_type %s not recognized", returnType) +// } +// if err := d.Set("return_type", match[1]); err != nil { +// return diag.Errorf("error setting return_type: %v", err) +// } +// +// case "null handling": +// if err := d.Set("null_input_behavior", row.Value); err != nil { +// return diag.Errorf("error setting null_input_behavior: %v", err) +// } +// case "volatility": +// if err := d.Set("return_behavior", row.Value); err != nil { +// return diag.Errorf("error setting return_behavior: %v", err) +// } +// case "headers": +// if row.Value != "" && row.Value != "null" { +// // Format in Snowflake DB is: {"head1":"val1","head2":"val2"} +// var jsonHeaders map[string]string +// err := json.Unmarshal([]byte(row.Value), &jsonHeaders) +// if err != nil { +// return diag.Errorf("error unmarshalling headers: %v", err) +// } +// +// headers := make([]any, 0, len(jsonHeaders)) +// for key, value := range jsonHeaders { +// headers = append(headers, map[string]any{ +// "name": key, +// "value": value, +// }) +// } +// +// if err := d.Set("header", headers); err != nil { +// return diag.Errorf("error setting return_behavior: %v", err) +// } +// } +// case "context_headers": +// if row.Value != "" && row.Value != "null" { +// // Format in Snowflake DB is: ["CONTEXT_FUNCTION_1","CONTEXT_FUNCTION_2"] +// contextHeaders := strings.Split(strings.Trim(row.Value, "[]"), ",") +// for i, v := range contextHeaders { +// contextHeaders[i] = strings.Trim(v, "\"") +// } +// if err := d.Set("context_headers", contextHeaders); err != nil { +// return diag.Errorf("error setting context_headers: %v", err) +// } +// } +// case "max_batch_rows": +// if row.Value != "not set" { +// maxBatchRows, err := strconv.ParseInt(row.Value, 10, 64) +// if err != nil { +// return diag.Errorf("error parsing max_batch_rows: %v", err) +// } +// if err := d.Set("max_batch_rows", maxBatchRows); err != nil { +// return diag.Errorf("error setting max_batch_rows: %v", err) +// } +// } +// case "compression": +// if err := d.Set("compression", row.Value); err != nil { +// return diag.Errorf("error setting compression: %v", err) +// } +// case "body": +// if err := d.Set("url_of_proxy_and_resource", row.Value); err != nil { +// return diag.Errorf("error setting url_of_proxy_and_resource: %v", err) +// } +// case "language": +// // To ignore +// default: +// log.Printf("[WARN] unexpected external function property %v returned from Snowflake", row.Property) +// } +// } +// +// return nil +//} +// +//func UpdateContextExternalFunction(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { +// client := meta.(*provider.Context).Client +// +// id := sdk.NewSchemaObjectIdentifierFromFullyQualifiedName(d.Id()) +// req := sdk.NewAlterFunctionRequest(sdk.NewSchemaObjectIdentifierWithArguments(id.DatabaseName(), id.SchemaName(), id.Name(), id.Arguments()...)) +// if d.HasChange("comment") { +// _, new := d.GetChange("comment") +// if new == "" { +// req.UnsetComment = sdk.Bool(true) +// } else { +// req.SetComment = sdk.String(new.(string)) +// } +// err := client.Functions.Alter(ctx, req) +// if err != nil { +// return diag.FromErr(err) +// } +// } +// return ReadContextExternalFunction(ctx, d, meta) +//} +// +//func DeleteContextExternalFunction(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { +// client := meta.(*provider.Context).Client +// +// id := sdk.NewSchemaObjectIdentifierFromFullyQualifiedName(d.Id()) +// req := sdk.NewDropFunctionRequest(sdk.NewSchemaObjectIdentifierWithArguments(id.DatabaseName(), id.SchemaName(), id.Name(), id.Arguments()...)) +// if err := client.Functions.Drop(ctx, req); err != nil { +// return diag.FromErr(err) +// } +// +// d.SetId("") +// return nil +//} diff --git a/pkg/sdk/external_functions_def.go b/pkg/sdk/external_functions_def.go index 093800ec8a..af7aeee9de 100644 --- a/pkg/sdk/external_functions_def.go +++ b/pkg/sdk/external_functions_def.go @@ -46,7 +46,7 @@ var externalFunctionUnset = g.NewQueryStruct("ExternalFunctionUnset"). var ExternalFunctionsDef = g.NewInterface( "ExternalFunctions", "ExternalFunction", - g.KindOfT[SchemaObjectIdentifier](), + g.KindOfT[SchemaObjectIdentifierWithArguments](), ).CreateOperation( "https://docs.snowflake.com/en/sql-reference/sql/create-external-function", g.NewQueryStruct("CreateExternalFunction"). @@ -54,7 +54,7 @@ var ExternalFunctionsDef = g.NewInterface( OrReplace(). OptionalSQL("SECURE"). SQL("EXTERNAL FUNCTION"). - Name(). + Identifier("name", g.KindOfT[SchemaObjectIdentifier](), g.IdentifierOptions().Required()). ListQueryStructField( "Arguments", externalFunctionArgument, @@ -92,7 +92,6 @@ var ExternalFunctionsDef = g.NewInterface( SQL("FUNCTION"). IfExists(). Name(). - PredefinedQueryStructField("ArgumentDataTypes", g.KindOfTSlice[DataType](), g.KeywordOptions().MustParentheses().Required()). OptionalQueryStructField( "Set", externalFunctionSet, @@ -148,7 +147,8 @@ var ExternalFunctionsDef = g.NewInterface( g.NewQueryStruct("ShowFunctions"). Show(). SQL("EXTERNAL FUNCTIONS"). - OptionalLike(), + OptionalLike(). + OptionalIn(), ).ShowByIdOperation().DescribeOperation( g.DescriptionMappingKindSlice, "https://docs.snowflake.com/en/sql-reference/sql/desc-function", @@ -162,6 +162,5 @@ var ExternalFunctionsDef = g.NewInterface( Describe(). SQL("FUNCTION"). Name(). - PredefinedQueryStructField("ArgumentDataTypes", g.KindOfTSlice[DataType](), g.KeywordOptions().MustParentheses().Required()). WithValidation(g.ValidIdentifier, "name"), ) diff --git a/pkg/sdk/external_functions_dto_builders_gen.go b/pkg/sdk/external_functions_dto_builders_gen.go index 1522073db3..e43eecb785 100644 --- a/pkg/sdk/external_functions_dto_builders_gen.go +++ b/pkg/sdk/external_functions_dto_builders_gen.go @@ -2,6 +2,8 @@ package sdk +import () + func NewCreateExternalFunctionRequest( name SchemaObjectIdentifier, ResultDataType DataType, @@ -16,13 +18,13 @@ func NewCreateExternalFunctionRequest( return &s } -func (s *CreateExternalFunctionRequest) WithOrReplace(OrReplace *bool) *CreateExternalFunctionRequest { - s.OrReplace = OrReplace +func (s *CreateExternalFunctionRequest) WithOrReplace(OrReplace bool) *CreateExternalFunctionRequest { + s.OrReplace = &OrReplace return s } -func (s *CreateExternalFunctionRequest) WithSecure(Secure *bool) *CreateExternalFunctionRequest { - s.Secure = Secure +func (s *CreateExternalFunctionRequest) WithSecure(Secure bool) *CreateExternalFunctionRequest { + s.Secure = &Secure return s } @@ -31,23 +33,23 @@ func (s *CreateExternalFunctionRequest) WithArguments(Arguments []ExternalFuncti return s } -func (s *CreateExternalFunctionRequest) WithReturnNullValues(ReturnNullValues *ReturnNullValues) *CreateExternalFunctionRequest { - s.ReturnNullValues = ReturnNullValues +func (s *CreateExternalFunctionRequest) WithReturnNullValues(ReturnNullValues ReturnNullValues) *CreateExternalFunctionRequest { + s.ReturnNullValues = &ReturnNullValues return s } -func (s *CreateExternalFunctionRequest) WithNullInputBehavior(NullInputBehavior *NullInputBehavior) *CreateExternalFunctionRequest { - s.NullInputBehavior = NullInputBehavior +func (s *CreateExternalFunctionRequest) WithNullInputBehavior(NullInputBehavior NullInputBehavior) *CreateExternalFunctionRequest { + s.NullInputBehavior = &NullInputBehavior return s } -func (s *CreateExternalFunctionRequest) WithReturnResultsBehavior(ReturnResultsBehavior *ReturnResultsBehavior) *CreateExternalFunctionRequest { - s.ReturnResultsBehavior = ReturnResultsBehavior +func (s *CreateExternalFunctionRequest) WithReturnResultsBehavior(ReturnResultsBehavior ReturnResultsBehavior) *CreateExternalFunctionRequest { + s.ReturnResultsBehavior = &ReturnResultsBehavior return s } -func (s *CreateExternalFunctionRequest) WithComment(Comment *string) *CreateExternalFunctionRequest { - s.Comment = Comment +func (s *CreateExternalFunctionRequest) WithComment(Comment string) *CreateExternalFunctionRequest { + s.Comment = &Comment return s } @@ -61,23 +63,23 @@ func (s *CreateExternalFunctionRequest) WithContextHeaders(ContextHeaders []Exte return s } -func (s *CreateExternalFunctionRequest) WithMaxBatchRows(MaxBatchRows *int) *CreateExternalFunctionRequest { - s.MaxBatchRows = MaxBatchRows +func (s *CreateExternalFunctionRequest) WithMaxBatchRows(MaxBatchRows int) *CreateExternalFunctionRequest { + s.MaxBatchRows = &MaxBatchRows return s } -func (s *CreateExternalFunctionRequest) WithCompression(Compression *string) *CreateExternalFunctionRequest { - s.Compression = Compression +func (s *CreateExternalFunctionRequest) WithCompression(Compression string) *CreateExternalFunctionRequest { + s.Compression = &Compression return s } -func (s *CreateExternalFunctionRequest) WithRequestTranslator(RequestTranslator *SchemaObjectIdentifier) *CreateExternalFunctionRequest { - s.RequestTranslator = RequestTranslator +func (s *CreateExternalFunctionRequest) WithRequestTranslator(RequestTranslator SchemaObjectIdentifier) *CreateExternalFunctionRequest { + s.RequestTranslator = &RequestTranslator return s } -func (s *CreateExternalFunctionRequest) WithResponseTranslator(ResponseTranslator *SchemaObjectIdentifier) *CreateExternalFunctionRequest { - s.ResponseTranslator = ResponseTranslator +func (s *CreateExternalFunctionRequest) WithResponseTranslator(ResponseTranslator SchemaObjectIdentifier) *CreateExternalFunctionRequest { + s.ResponseTranslator = &ResponseTranslator return s } @@ -110,27 +112,25 @@ func NewExternalFunctionContextHeaderRequest( } func NewAlterExternalFunctionRequest( - name SchemaObjectIdentifier, - ArgumentDataTypes []DataType, + name SchemaObjectIdentifierWithArguments, ) *AlterExternalFunctionRequest { s := AlterExternalFunctionRequest{} s.name = name - s.ArgumentDataTypes = ArgumentDataTypes return &s } -func (s *AlterExternalFunctionRequest) WithIfExists(IfExists *bool) *AlterExternalFunctionRequest { - s.IfExists = IfExists +func (s *AlterExternalFunctionRequest) WithIfExists(IfExists bool) *AlterExternalFunctionRequest { + s.IfExists = &IfExists return s } -func (s *AlterExternalFunctionRequest) WithSet(Set *ExternalFunctionSetRequest) *AlterExternalFunctionRequest { - s.Set = Set +func (s *AlterExternalFunctionRequest) WithSet(Set ExternalFunctionSetRequest) *AlterExternalFunctionRequest { + s.Set = &Set return s } -func (s *AlterExternalFunctionRequest) WithUnset(Unset *ExternalFunctionUnsetRequest) *AlterExternalFunctionRequest { - s.Unset = Unset +func (s *AlterExternalFunctionRequest) WithUnset(Unset ExternalFunctionUnsetRequest) *AlterExternalFunctionRequest { + s.Unset = &Unset return s } @@ -138,8 +138,8 @@ func NewExternalFunctionSetRequest() *ExternalFunctionSetRequest { return &ExternalFunctionSetRequest{} } -func (s *ExternalFunctionSetRequest) WithApiIntegration(ApiIntegration *AccountObjectIdentifier) *ExternalFunctionSetRequest { - s.ApiIntegration = ApiIntegration +func (s *ExternalFunctionSetRequest) WithApiIntegration(ApiIntegration AccountObjectIdentifier) *ExternalFunctionSetRequest { + s.ApiIntegration = &ApiIntegration return s } @@ -153,23 +153,23 @@ func (s *ExternalFunctionSetRequest) WithContextHeaders(ContextHeaders []Externa return s } -func (s *ExternalFunctionSetRequest) WithMaxBatchRows(MaxBatchRows *int) *ExternalFunctionSetRequest { - s.MaxBatchRows = MaxBatchRows +func (s *ExternalFunctionSetRequest) WithMaxBatchRows(MaxBatchRows int) *ExternalFunctionSetRequest { + s.MaxBatchRows = &MaxBatchRows return s } -func (s *ExternalFunctionSetRequest) WithCompression(Compression *string) *ExternalFunctionSetRequest { - s.Compression = Compression +func (s *ExternalFunctionSetRequest) WithCompression(Compression string) *ExternalFunctionSetRequest { + s.Compression = &Compression return s } -func (s *ExternalFunctionSetRequest) WithRequestTranslator(RequestTranslator *SchemaObjectIdentifier) *ExternalFunctionSetRequest { - s.RequestTranslator = RequestTranslator +func (s *ExternalFunctionSetRequest) WithRequestTranslator(RequestTranslator SchemaObjectIdentifier) *ExternalFunctionSetRequest { + s.RequestTranslator = &RequestTranslator return s } -func (s *ExternalFunctionSetRequest) WithResponseTranslator(ResponseTranslator *SchemaObjectIdentifier) *ExternalFunctionSetRequest { - s.ResponseTranslator = ResponseTranslator +func (s *ExternalFunctionSetRequest) WithResponseTranslator(ResponseTranslator SchemaObjectIdentifier) *ExternalFunctionSetRequest { + s.ResponseTranslator = &ResponseTranslator return s } @@ -177,43 +177,43 @@ func NewExternalFunctionUnsetRequest() *ExternalFunctionUnsetRequest { return &ExternalFunctionUnsetRequest{} } -func (s *ExternalFunctionUnsetRequest) WithComment(Comment *bool) *ExternalFunctionUnsetRequest { - s.Comment = Comment +func (s *ExternalFunctionUnsetRequest) WithComment(Comment bool) *ExternalFunctionUnsetRequest { + s.Comment = &Comment return s } -func (s *ExternalFunctionUnsetRequest) WithHeaders(Headers *bool) *ExternalFunctionUnsetRequest { - s.Headers = Headers +func (s *ExternalFunctionUnsetRequest) WithHeaders(Headers bool) *ExternalFunctionUnsetRequest { + s.Headers = &Headers return s } -func (s *ExternalFunctionUnsetRequest) WithContextHeaders(ContextHeaders *bool) *ExternalFunctionUnsetRequest { - s.ContextHeaders = ContextHeaders +func (s *ExternalFunctionUnsetRequest) WithContextHeaders(ContextHeaders bool) *ExternalFunctionUnsetRequest { + s.ContextHeaders = &ContextHeaders return s } -func (s *ExternalFunctionUnsetRequest) WithMaxBatchRows(MaxBatchRows *bool) *ExternalFunctionUnsetRequest { - s.MaxBatchRows = MaxBatchRows +func (s *ExternalFunctionUnsetRequest) WithMaxBatchRows(MaxBatchRows bool) *ExternalFunctionUnsetRequest { + s.MaxBatchRows = &MaxBatchRows return s } -func (s *ExternalFunctionUnsetRequest) WithCompression(Compression *bool) *ExternalFunctionUnsetRequest { - s.Compression = Compression +func (s *ExternalFunctionUnsetRequest) WithCompression(Compression bool) *ExternalFunctionUnsetRequest { + s.Compression = &Compression return s } -func (s *ExternalFunctionUnsetRequest) WithSecure(Secure *bool) *ExternalFunctionUnsetRequest { - s.Secure = Secure +func (s *ExternalFunctionUnsetRequest) WithSecure(Secure bool) *ExternalFunctionUnsetRequest { + s.Secure = &Secure return s } -func (s *ExternalFunctionUnsetRequest) WithRequestTranslator(RequestTranslator *bool) *ExternalFunctionUnsetRequest { - s.RequestTranslator = RequestTranslator +func (s *ExternalFunctionUnsetRequest) WithRequestTranslator(RequestTranslator bool) *ExternalFunctionUnsetRequest { + s.RequestTranslator = &RequestTranslator return s } -func (s *ExternalFunctionUnsetRequest) WithResponseTranslator(ResponseTranslator *bool) *ExternalFunctionUnsetRequest { - s.ResponseTranslator = ResponseTranslator +func (s *ExternalFunctionUnsetRequest) WithResponseTranslator(ResponseTranslator bool) *ExternalFunctionUnsetRequest { + s.ResponseTranslator = &ResponseTranslator return s } @@ -221,22 +221,15 @@ func NewShowExternalFunctionRequest() *ShowExternalFunctionRequest { return &ShowExternalFunctionRequest{} } -func (s *ShowExternalFunctionRequest) WithLike(Like *Like) *ShowExternalFunctionRequest { - s.Like = Like - return s -} - -func (s *ShowExternalFunctionRequest) WithIn(In *In) *ShowExternalFunctionRequest { - s.In = In +func (s *ShowExternalFunctionRequest) WithLike(Like Like) *ShowExternalFunctionRequest { + s.Like = &Like return s } func NewDescribeExternalFunctionRequest( - name SchemaObjectIdentifier, - ArgumentDataTypes []DataType, + name SchemaObjectIdentifierWithArguments, ) *DescribeExternalFunctionRequest { s := DescribeExternalFunctionRequest{} s.name = name - s.ArgumentDataTypes = ArgumentDataTypes return &s } diff --git a/pkg/sdk/external_functions_dto_gen.go b/pkg/sdk/external_functions_dto_gen.go index 975a7fd2f9..50b16edaac 100644 --- a/pkg/sdk/external_functions_dto_gen.go +++ b/pkg/sdk/external_functions_dto_gen.go @@ -44,11 +44,10 @@ type ExternalFunctionContextHeaderRequest struct { } type AlterExternalFunctionRequest struct { - IfExists *bool - name SchemaObjectIdentifier // required - ArgumentDataTypes []DataType // required - Set *ExternalFunctionSetRequest - Unset *ExternalFunctionUnsetRequest + IfExists *bool + name SchemaObjectIdentifierWithArguments // required + Set *ExternalFunctionSetRequest + Unset *ExternalFunctionUnsetRequest } type ExternalFunctionSetRequest struct { @@ -74,10 +73,8 @@ type ExternalFunctionUnsetRequest struct { type ShowExternalFunctionRequest struct { Like *Like - In *In } type DescribeExternalFunctionRequest struct { - name SchemaObjectIdentifier // required - ArgumentDataTypes []DataType // required + name SchemaObjectIdentifierWithArguments // required } diff --git a/pkg/sdk/external_functions_gen.go b/pkg/sdk/external_functions_gen.go index d201f65dfe..7a380b4fa1 100644 --- a/pkg/sdk/external_functions_gen.go +++ b/pkg/sdk/external_functions_gen.go @@ -9,8 +9,8 @@ type ExternalFunctions interface { Create(ctx context.Context, request *CreateExternalFunctionRequest) error Alter(ctx context.Context, request *AlterExternalFunctionRequest) error Show(ctx context.Context, request *ShowExternalFunctionRequest) ([]ExternalFunction, error) - ShowByID(ctx context.Context, id SchemaObjectIdentifier) (*ExternalFunction, error) - Describe(ctx context.Context, request *DescribeExternalFunctionRequest) ([]ExternalFunctionProperty, error) + ShowByID(ctx context.Context, id SchemaObjectIdentifierWithArguments) (*ExternalFunction, error) + Describe(ctx context.Context, id SchemaObjectIdentifierWithArguments) ([]ExternalFunctionProperty, error) } // CreateExternalFunctionOptions is based on https://docs.snowflake.com/en/sql-reference/sql/create-external-function. @@ -52,13 +52,12 @@ type ExternalFunctionContextHeader struct { // AlterExternalFunctionOptions is based on https://docs.snowflake.com/en/sql-reference/sql/alter-function. type AlterExternalFunctionOptions struct { - alter bool `ddl:"static" sql:"ALTER"` - function bool `ddl:"static" sql:"FUNCTION"` - IfExists *bool `ddl:"keyword" sql:"IF EXISTS"` - name SchemaObjectIdentifier `ddl:"identifier"` - ArgumentDataTypes []DataType `ddl:"keyword,must_parentheses"` - Set *ExternalFunctionSet `ddl:"keyword" sql:"SET"` - Unset *ExternalFunctionUnset `ddl:"list,no_parentheses" sql:"UNSET"` + alter bool `ddl:"static" sql:"ALTER"` + function bool `ddl:"static" sql:"FUNCTION"` + IfExists *bool `ddl:"keyword" sql:"IF EXISTS"` + name SchemaObjectIdentifierWithArguments `ddl:"identifier"` + Set *ExternalFunctionSet `ddl:"keyword" sql:"SET"` + Unset *ExternalFunctionUnset `ddl:"list,no_parentheses" sql:"UNSET"` } type ExternalFunctionSet struct { @@ -120,7 +119,8 @@ type ExternalFunction struct { IsAnsi bool MinNumArguments int MaxNumArguments int - Arguments string + Arguments []DataType + ArgumentsRaw string Description string CatalogName string IsTableFunction bool @@ -132,16 +132,15 @@ type ExternalFunction struct { IsDataMetric bool } -func (v *ExternalFunction) ID() SchemaObjectIdentifier { - return NewSchemaObjectIdentifier(v.CatalogName, v.SchemaName, v.Name) +func (v *ExternalFunction) ID() SchemaObjectIdentifierWithArguments { + return NewSchemaObjectIdentifierWithArguments(v.CatalogName, v.SchemaName, v.Name, v.Arguments...) } // DescribeExternalFunctionOptions is based on https://docs.snowflake.com/en/sql-reference/sql/desc-function. type DescribeExternalFunctionOptions struct { - describe bool `ddl:"static" sql:"DESCRIBE"` - function bool `ddl:"static" sql:"FUNCTION"` - name SchemaObjectIdentifier `ddl:"identifier"` - ArgumentDataTypes []DataType `ddl:"keyword,must_parentheses"` + describe bool `ddl:"static" sql:"DESCRIBE"` + function bool `ddl:"static" sql:"FUNCTION"` + name SchemaObjectIdentifierWithArguments `ddl:"identifier"` } type externalFunctionPropertyRow struct { diff --git a/pkg/sdk/external_functions_gen_integration_test.go b/pkg/sdk/external_functions_gen_integration_test.go new file mode 100644 index 0000000000..d6b43eb45e --- /dev/null +++ b/pkg/sdk/external_functions_gen_integration_test.go @@ -0,0 +1,27 @@ +package sdk + +import "testing" + +func TestInt_ExternalFunctions(t *testing.T) { + // TODO: prepare common resources + + t.Run("Create", func(t *testing.T) { + // TODO: fill me + }) + + t.Run("Alter", func(t *testing.T) { + // TODO: fill me + }) + + t.Run("Show", func(t *testing.T) { + // TODO: fill me + }) + + t.Run("ShowByID", func(t *testing.T) { + // TODO: fill me + }) + + t.Run("Describe", func(t *testing.T) { + // TODO: fill me + }) +} diff --git a/pkg/sdk/external_functions_gen_test.go b/pkg/sdk/external_functions_gen_test.go index 073d9b3caf..939fe0488d 100644 --- a/pkg/sdk/external_functions_gen_test.go +++ b/pkg/sdk/external_functions_gen_test.go @@ -94,13 +94,13 @@ func TestExternalFunctions_Create(t *testing.T) { } func TestExternalFunctions_Alter(t *testing.T) { - id := randomSchemaObjectIdentifier() + noArgsId := randomSchemaObjectIdentifierWithArguments() + id := randomSchemaObjectIdentifierWithArguments(DataTypeVARCHAR, DataTypeNumber) defaultOpts := func() *AlterExternalFunctionOptions { return &AlterExternalFunctionOptions{ - name: id, - IfExists: Bool(true), - ArgumentDataTypes: []DataType{DataTypeVARCHAR, DataTypeNumber}, + name: id, + IfExists: Bool(true), } } @@ -111,7 +111,7 @@ func TestExternalFunctions_Alter(t *testing.T) { t.Run("validation: incorrect identifier", func(t *testing.T) { opts := defaultOpts() - opts.name = emptySchemaObjectIdentifier + opts.name = emptySchemaObjectIdentifierWithArguments assertOptsInvalidJoinedErrors(t, opts, ErrInvalidObjectIdentifier) }) @@ -156,7 +156,7 @@ func TestExternalFunctions_Alter(t *testing.T) { opts.Set = &ExternalFunctionSet{ ApiIntegration: &integration, } - assertOptsValidAndSQLEquals(t, opts, `ALTER FUNCTION IF EXISTS %s (VARCHAR, NUMBER) SET API_INTEGRATION = "api_integration"`, id.FullyQualifiedName()) + assertOptsValidAndSQLEquals(t, opts, `ALTER FUNCTION IF EXISTS %s SET API_INTEGRATION = "api_integration"`, id.FullyQualifiedName()) }) t.Run("alter: set headers", func(t *testing.T) { @@ -173,7 +173,7 @@ func TestExternalFunctions_Alter(t *testing.T) { }, }, } - assertOptsValidAndSQLEquals(t, opts, `ALTER FUNCTION IF EXISTS %s (VARCHAR, NUMBER) SET HEADERS = ('header1' = 'value1', 'header2' = 'value2')`, id.FullyQualifiedName()) + assertOptsValidAndSQLEquals(t, opts, `ALTER FUNCTION IF EXISTS %s SET HEADERS = ('header1' = 'value1', 'header2' = 'value2')`, id.FullyQualifiedName()) }) t.Run("alter: set max batch rows", func(t *testing.T) { @@ -181,7 +181,7 @@ func TestExternalFunctions_Alter(t *testing.T) { opts.Set = &ExternalFunctionSet{ MaxBatchRows: Int(100), } - assertOptsValidAndSQLEquals(t, opts, `ALTER FUNCTION IF EXISTS %s (VARCHAR, NUMBER) SET MAX_BATCH_ROWS = 100`, id.FullyQualifiedName()) + assertOptsValidAndSQLEquals(t, opts, `ALTER FUNCTION IF EXISTS %s SET MAX_BATCH_ROWS = 100`, id.FullyQualifiedName()) }) t.Run("alter: set compression", func(t *testing.T) { @@ -189,7 +189,7 @@ func TestExternalFunctions_Alter(t *testing.T) { opts.Set = &ExternalFunctionSet{ Compression: String("GZIP"), } - assertOptsValidAndSQLEquals(t, opts, `ALTER FUNCTION IF EXISTS %s (VARCHAR, NUMBER) SET COMPRESSION = GZIP`, id.FullyQualifiedName()) + assertOptsValidAndSQLEquals(t, opts, `ALTER FUNCTION IF EXISTS %s SET COMPRESSION = GZIP`, id.FullyQualifiedName()) }) t.Run("alter: set context headers", func(t *testing.T) { @@ -204,7 +204,7 @@ func TestExternalFunctions_Alter(t *testing.T) { }, }, } - assertOptsValidAndSQLEquals(t, opts, `ALTER FUNCTION IF EXISTS %s (VARCHAR, NUMBER) SET CONTEXT_HEADERS = (CURRENT_ACCOUNT, CURRENT_USER)`, id.FullyQualifiedName()) + assertOptsValidAndSQLEquals(t, opts, `ALTER FUNCTION IF EXISTS %s SET CONTEXT_HEADERS = (CURRENT_ACCOUNT, CURRENT_USER)`, id.FullyQualifiedName()) }) t.Run("alter: set request translator", func(t *testing.T) { @@ -213,7 +213,7 @@ func TestExternalFunctions_Alter(t *testing.T) { opts.Set = &ExternalFunctionSet{ RequestTranslator: &rt, } - assertOptsValidAndSQLEquals(t, opts, `ALTER FUNCTION IF EXISTS %s (VARCHAR, NUMBER) SET REQUEST_TRANSLATOR = %s`, id.FullyQualifiedName(), rt.FullyQualifiedName()) + assertOptsValidAndSQLEquals(t, opts, `ALTER FUNCTION IF EXISTS %s SET REQUEST_TRANSLATOR = %s`, id.FullyQualifiedName(), rt.FullyQualifiedName()) }) t.Run("alter: set response translator", func(t *testing.T) { @@ -222,12 +222,11 @@ func TestExternalFunctions_Alter(t *testing.T) { opts.Set = &ExternalFunctionSet{ ResponseTranslator: &st, } - assertOptsValidAndSQLEquals(t, opts, `ALTER FUNCTION IF EXISTS %s (VARCHAR, NUMBER) SET RESPONSE_TRANSLATOR = %s`, id.FullyQualifiedName(), st.FullyQualifiedName()) + assertOptsValidAndSQLEquals(t, opts, `ALTER FUNCTION IF EXISTS %s SET RESPONSE_TRANSLATOR = %s`, id.FullyQualifiedName(), st.FullyQualifiedName()) }) t.Run("alter: unset", func(t *testing.T) { opts := defaultOpts() - opts.ArgumentDataTypes = []DataType{DataTypeVARCHAR, DataTypeNumber} opts.Unset = &ExternalFunctionUnset{ Comment: Bool(true), Headers: Bool(true), @@ -238,12 +237,12 @@ func TestExternalFunctions_Alter(t *testing.T) { RequestTranslator: Bool(true), ResponseTranslator: Bool(true), } - assertOptsValidAndSQLEquals(t, opts, `ALTER FUNCTION IF EXISTS %s (VARCHAR, NUMBER) UNSET COMMENT, HEADERS, CONTEXT_HEADERS, MAX_BATCH_ROWS, COMPRESSION, SECURE, REQUEST_TRANSLATOR, RESPONSE_TRANSLATOR`, id.FullyQualifiedName()) + assertOptsValidAndSQLEquals(t, opts, `ALTER FUNCTION IF EXISTS %s UNSET COMMENT, HEADERS, CONTEXT_HEADERS, MAX_BATCH_ROWS, COMPRESSION, SECURE, REQUEST_TRANSLATOR, RESPONSE_TRANSLATOR`, id.FullyQualifiedName()) }) t.Run("alter: unset with no arguments", func(t *testing.T) { opts := defaultOpts() - opts.ArgumentDataTypes = nil + opts.name = noArgsId opts.Unset = &ExternalFunctionUnset{ Comment: Bool(true), Headers: Bool(true), @@ -254,7 +253,7 @@ func TestExternalFunctions_Alter(t *testing.T) { RequestTranslator: Bool(true), ResponseTranslator: Bool(true), } - assertOptsValidAndSQLEquals(t, opts, `ALTER FUNCTION IF EXISTS %s () UNSET COMMENT, HEADERS, CONTEXT_HEADERS, MAX_BATCH_ROWS, COMPRESSION, SECURE, REQUEST_TRANSLATOR, RESPONSE_TRANSLATOR`, id.FullyQualifiedName()) + assertOptsValidAndSQLEquals(t, opts, `ALTER FUNCTION IF EXISTS %s UNSET COMMENT, HEADERS, CONTEXT_HEADERS, MAX_BATCH_ROWS, COMPRESSION, SECURE, REQUEST_TRANSLATOR, RESPONSE_TRANSLATOR`, noArgsId.FullyQualifiedName()) }) } @@ -292,7 +291,8 @@ func TestExternalFunctions_Show(t *testing.T) { } func TestExternalFunctions_Describe(t *testing.T) { - id := randomSchemaObjectIdentifier() + noArgsId := randomSchemaObjectIdentifierWithArguments() + id := randomSchemaObjectIdentifierWithArguments(DataTypeVARCHAR, DataTypeNumber) defaultOpts := func() *DescribeExternalFunctionOptions { return &DescribeExternalFunctionOptions{ @@ -307,18 +307,18 @@ func TestExternalFunctions_Describe(t *testing.T) { t.Run("validation: incorrect identifier", func(t *testing.T) { opts := defaultOpts() - opts.name = emptySchemaObjectIdentifier + opts.name = emptySchemaObjectIdentifierWithArguments assertOptsInvalidJoinedErrors(t, opts, ErrInvalidObjectIdentifier) }) t.Run("no arguments", func(t *testing.T) { opts := defaultOpts() - assertOptsValidAndSQLEquals(t, opts, `DESCRIBE FUNCTION %s ()`, id.FullyQualifiedName()) + opts.name = noArgsId + assertOptsValidAndSQLEquals(t, opts, `DESCRIBE FUNCTION %s`, noArgsId.FullyQualifiedName()) }) t.Run("all options", func(t *testing.T) { opts := defaultOpts() - opts.ArgumentDataTypes = []DataType{DataTypeVARCHAR, DataTypeNumber} - assertOptsValidAndSQLEquals(t, opts, `DESCRIBE FUNCTION %s (VARCHAR, NUMBER)`, id.FullyQualifiedName()) + assertOptsValidAndSQLEquals(t, opts, `DESCRIBE FUNCTION %s`, id.FullyQualifiedName()) }) } diff --git a/pkg/sdk/external_functions_impl_gen.go b/pkg/sdk/external_functions_impl_gen.go index 2eac8e7c0f..ca27121431 100644 --- a/pkg/sdk/external_functions_impl_gen.go +++ b/pkg/sdk/external_functions_impl_gen.go @@ -2,6 +2,9 @@ package sdk import ( "context" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/internal/collections" + "log" + "strings" ) var _ ExternalFunctions = (*externalFunctions)(nil) @@ -30,37 +33,18 @@ func (v *externalFunctions) Show(ctx context.Context, request *ShowExternalFunct return resultList, nil } -func (v *externalFunctions) ShowByID(ctx context.Context, id SchemaObjectIdentifier) (*ExternalFunction, error) { - return nil, nil - // TODO - //arguments := id.Arguments() - //externalFunctions, err := v.Show(ctx, NewShowExternalFunctionRequest(). - // WithIn(&In{Schema: id.SchemaId()}). - // WithLike(&Like{Pattern: String(id.Name())})) - //if err != nil { - // return nil, err - //} - //return collections.FindOne(externalFunctions, func(r ExternalFunction) bool { - // database := strings.Trim(r.CatalogName, `"`) - // schema := strings.Trim(r.SchemaName, `"`) - // if r.Name != id.Name() || database != id.DatabaseName() || schema != id.SchemaName() { - // return false - // } - // var sb strings.Builder - // sb.WriteString("(") - // for i, argument := range arguments { - // sb.WriteString(string(argument)) - // if i < len(arguments)-1 { - // sb.WriteString(", ") - // } - // } - // sb.WriteString(")") - // return strings.Contains(r.Arguments, sb.String()) - //}) +func (v *externalFunctions) ShowByID(ctx context.Context, id SchemaObjectIdentifierWithArguments) (*ExternalFunction, error) { + externalFunctions, err := v.Show(ctx, NewShowExternalFunctionRequest().WithLike(Like{String(id.Name())})) + if err != nil { + return nil, err + } + return collections.FindOne(externalFunctions, func(r ExternalFunction) bool { return r.ID().FullyQualifiedName() == id.FullyQualifiedName() }) } -func (v *externalFunctions) Describe(ctx context.Context, request *DescribeExternalFunctionRequest) ([]ExternalFunctionProperty, error) { - opts := request.toOpts() +func (v *externalFunctions) Describe(ctx context.Context, id SchemaObjectIdentifierWithArguments) ([]ExternalFunctionProperty, error) { + opts := &DescribeFunctionOptions{ + name: id, + } rows, err := validateAndQuery[externalFunctionPropertyRow](v.client, ctx, opts) if err != nil { return nil, err @@ -156,7 +140,6 @@ func (r *AlterExternalFunctionRequest) toOpts() *AlterExternalFunctionOptions { func (r *ShowExternalFunctionRequest) toOpts() *ShowExternalFunctionOptions { opts := &ShowExternalFunctionOptions{ Like: r.Like, - In: r.In, } return opts } @@ -170,13 +153,21 @@ func (r externalFunctionRow) convert() *ExternalFunction { IsAnsi: r.IsAnsi == "Y", MinNumArguments: r.MinNumArguments, MaxNumArguments: r.MaxNumArguments, - Arguments: r.Arguments, + ArgumentsRaw: r.Arguments, Description: r.Description, IsTableFunction: r.IsTableFunction == "Y", ValidForClustering: r.ValidForClustering == "Y", IsExternalFunction: r.IsExternalFunction == "Y", Language: r.Language, } + arguments := strings.TrimLeft(r.Arguments, r.Name) + returnIndex := strings.Index(arguments, ") RETURN ") + dataTypes, err := ParseFunctionArgumentsFromString(arguments[:returnIndex+1]) + if err != nil { + log.Printf("[DEBUG] failed to parse external function arguments, err = %s", err) + } else { + e.Arguments = dataTypes + } if r.SchemaName.Valid { e.SchemaName = r.SchemaName.String } @@ -197,8 +188,7 @@ func (r externalFunctionRow) convert() *ExternalFunction { func (r *DescribeExternalFunctionRequest) toOpts() *DescribeExternalFunctionOptions { opts := &DescribeExternalFunctionOptions{ - name: r.name, - ArgumentDataTypes: r.ArgumentDataTypes, + name: r.name, } return opts } diff --git a/pkg/sdk/functions_gen_test.go b/pkg/sdk/functions_gen_test.go index 5234a1fd7f..0bb4778832 100644 --- a/pkg/sdk/functions_gen_test.go +++ b/pkg/sdk/functions_gen_test.go @@ -431,7 +431,8 @@ func TestFunctions_CreateForSQL(t *testing.T) { } func TestFunctions_Drop(t *testing.T) { - id := randomSchemaObjectIdentifier() + noArgsId := randomSchemaObjectIdentifierWithArguments() + id := randomSchemaObjectIdentifierWithArguments(DataTypeVARCHAR, DataTypeNumber) defaultOpts := func() *DropFunctionOptions { return &DropFunctionOptions{ @@ -466,7 +467,8 @@ func TestFunctions_Drop(t *testing.T) { } func TestFunctions_Alter(t *testing.T) { - id := randomSchemaObjectIdentifier() + id := randomSchemaObjectIdentifierWithArguments(DataTypeVARCHAR, DataTypeNumber) + noArgsId := randomSchemaObjectIdentifierWithArguments() defaultOpts := func() *AlterFunctionOptions { return &AlterFunctionOptions{ @@ -614,7 +616,8 @@ func TestFunctions_Show(t *testing.T) { } func TestFunctions_Describe(t *testing.T) { - id := randomSchemaObjectIdentifier() + id := randomSchemaObjectIdentifierWithArguments(DataTypeVARCHAR, DataTypeNumber) + noArgsId := randomSchemaObjectIdentifierWithArguments() defaultOpts := func() *DescribeFunctionOptions { return &DescribeFunctionOptions{ diff --git a/pkg/sdk/procedures_impl_gen.go b/pkg/sdk/procedures_impl_gen.go index 21ff800c8f..0f58bc946b 100644 --- a/pkg/sdk/procedures_impl_gen.go +++ b/pkg/sdk/procedures_impl_gen.go @@ -405,7 +405,7 @@ func (r procedureRow) convert() *Procedure { returnIndex := strings.Index(arguments, ") RETURN ") dataTypes, err := ParseFunctionArgumentsFromString(arguments[:returnIndex+1]) if err != nil { - log.Printf("[DEBUG] failed to parse function arguments, err = %s", err) + log.Printf("[DEBUG] failed to parse procedure arguments, err = %s", err) } else { e.Arguments = dataTypes } diff --git a/pkg/sdk/testint/external_functions_integration_test.go b/pkg/sdk/testint/external_functions_integration_test.go index 2901d5e3f7..c84555a1cf 100644 --- a/pkg/sdk/testint/external_functions_integration_test.go +++ b/pkg/sdk/testint/external_functions_integration_test.go @@ -18,35 +18,34 @@ func TestInt_ExternalFunctions(t *testing.T) { integration, integrationCleanup := testClientHelper().ApiIntegration.CreateApiIntegration(t) t.Cleanup(integrationCleanup) - cleanupExternalFunctionHandle := func(id sdk.SchemaObjectIdentifier, dts []sdk.DataType) func() { + cleanupExternalFunctionHandle := func(id sdk.SchemaObjectIdentifierWithArguments) func() { return func() { - err := client.Functions.Drop(ctx, sdk.NewDropFunctionRequest(sdk.NewSchemaObjectIdentifierWithArguments(id.DatabaseName(), id.SchemaName(), id.Name(), dts...)).WithIfExists(true)) + err := client.Functions.Drop(ctx, sdk.NewDropFunctionRequest(id).WithIfExists(true)) require.NoError(t, err) } } // TODO [SNOW-999049]: id returned on purpose; address during identifiers rework - createExternalFunction := func(t *testing.T) (*sdk.ExternalFunction, sdk.SchemaObjectIdentifier) { + createExternalFunction := func(t *testing.T) *sdk.ExternalFunction { t.Helper() - id := testClientHelper().Ids.RandomSchemaObjectIdentifierWithArgumentsOld(defaultDataTypes...) + id := testClientHelper().Ids.RandomSchemaObjectIdentifierWithArguments(sdk.DataTypeVARCHAR) argument := sdk.NewExternalFunctionArgumentRequest("x", defaultDataTypes[0]) as := "https://xyz.execute-api.us-west-2.amazonaws.com/production/remote_echo" - request := sdk.NewCreateExternalFunctionRequest(id, sdk.DataTypeVariant, sdk.Pointer(integration.ID()), as). - WithOrReplace(sdk.Bool(true)). - WithSecure(sdk.Bool(true)). + request := sdk.NewCreateExternalFunctionRequest(id.SchemaObjectId(), sdk.DataTypeVariant, sdk.Pointer(integration.ID()), as). + WithOrReplace(true). + WithSecure(true). WithArguments([]sdk.ExternalFunctionArgumentRequest{*argument}) err := client.ExternalFunctions.Create(ctx, request) require.NoError(t, err) - t.Cleanup(cleanupExternalFunctionHandle(id.WithoutArguments(), []sdk.DataType{sdk.DataTypeVariant})) + t.Cleanup(cleanupExternalFunctionHandle(id)) e, err := client.ExternalFunctions.ShowByID(ctx, id) require.NoError(t, err) - return e, id + return e } - assertExternalFunction := func(t *testing.T, id sdk.SchemaObjectIdentifier, secure bool) { + assertExternalFunction := func(t *testing.T, id sdk.SchemaObjectIdentifierWithArguments, secure bool) { t.Helper() - dts := id.Arguments() e, err := client.ExternalFunctions.ShowByID(ctx, id) require.NoError(t, err) @@ -57,14 +56,15 @@ func TestInt_ExternalFunctions(t *testing.T) { require.Equal(t, false, e.IsBuiltin) require.Equal(t, false, e.IsAggregate) require.Equal(t, false, e.IsAnsi) - if len(dts) > 0 { + if len(id.ArgumentDataTypes()) > 0 { + require.NotEmpty(t, e.Arguments) require.Equal(t, 1, e.MinNumArguments) require.Equal(t, 1, e.MaxNumArguments) } else { + require.Empty(t, e.Arguments) require.Equal(t, 0, e.MinNumArguments) require.Equal(t, 0, e.MaxNumArguments) } - require.NotEmpty(t, e.Arguments) require.NotEmpty(t, e.Description) require.NotEmpty(t, e.CatalogName) require.Equal(t, false, e.IsTableFunction) @@ -77,7 +77,7 @@ func TestInt_ExternalFunctions(t *testing.T) { } t.Run("create external function", func(t *testing.T) { - id := testClientHelper().Ids.RandomSchemaObjectIdentifierWithArgumentsOld(defaultDataTypes...) + id := testClientHelper().Ids.RandomSchemaObjectIdentifierWithArguments(defaultDataTypes...) argument := sdk.NewExternalFunctionArgumentRequest("x", sdk.DataTypeVARCHAR) headers := []sdk.ExternalFunctionHeaderRequest{ { @@ -94,46 +94,46 @@ func TestInt_ExternalFunctions(t *testing.T) { }, } as := "https://xyz.execute-api.us-west-2.amazonaws.com/production/remote_echo" - request := sdk.NewCreateExternalFunctionRequest(id, sdk.DataTypeVariant, sdk.Pointer(integration.ID()), as). - WithOrReplace(sdk.Bool(true)). - WithSecure(sdk.Bool(true)). + request := sdk.NewCreateExternalFunctionRequest(id.SchemaObjectId(), sdk.DataTypeVariant, sdk.Pointer(integration.ID()), as). + WithOrReplace(true). + WithSecure(true). WithArguments([]sdk.ExternalFunctionArgumentRequest{*argument}). - WithNullInputBehavior(sdk.NullInputBehaviorPointer(sdk.NullInputBehaviorCalledOnNullInput)). + WithNullInputBehavior(*sdk.NullInputBehaviorPointer(sdk.NullInputBehaviorCalledOnNullInput)). WithHeaders(headers). WithContextHeaders(ch). - WithMaxBatchRows(sdk.Int(10)). - WithCompression(sdk.String("GZIP")) + WithMaxBatchRows(10). + WithCompression("GZIP") err := client.ExternalFunctions.Create(ctx, request) require.NoError(t, err) - t.Cleanup(cleanupExternalFunctionHandle(id.WithoutArguments(), []sdk.DataType{sdk.DataTypeVariant})) + t.Cleanup(cleanupExternalFunctionHandle(id)) assertExternalFunction(t, id, true) }) t.Run("create external function without arguments", func(t *testing.T) { - id := testClientHelper().Ids.RandomSchemaObjectIdentifierWithArgumentsOld() + id := testClientHelper().Ids.RandomSchemaObjectIdentifierWithArguments() as := "https://xyz.execute-api.us-west-2.amazonaws.com/production/remote_echo" - request := sdk.NewCreateExternalFunctionRequest(id, sdk.DataTypeVariant, sdk.Pointer(integration.ID()), as) + request := sdk.NewCreateExternalFunctionRequest(id.SchemaObjectId(), sdk.DataTypeVariant, sdk.Pointer(integration.ID()), as) err := client.ExternalFunctions.Create(ctx, request) require.NoError(t, err) - t.Cleanup(cleanupExternalFunctionHandle(id, nil)) + t.Cleanup(cleanupExternalFunctionHandle(id)) assertExternalFunction(t, id, false) }) t.Run("alter external function: set api integration", func(t *testing.T) { - _, id := createExternalFunction(t) + externalFunction := createExternalFunction(t) set := sdk.NewExternalFunctionSetRequest(). - WithApiIntegration(sdk.Pointer(integration.ID())) - request := sdk.NewAlterExternalFunctionRequest(id, defaultDataTypes).WithSet(set) + WithApiIntegration(integration.ID()) + request := sdk.NewAlterExternalFunctionRequest(externalFunction.ID()).WithSet(*set) err := client.ExternalFunctions.Alter(ctx, request) require.NoError(t, err) - assertExternalFunction(t, id, true) + assertExternalFunction(t, externalFunction.ID(), true) }) t.Run("alter external function: set headers", func(t *testing.T) { - _, id := createExternalFunction(t) + externalFunction := createExternalFunction(t) headers := []sdk.ExternalFunctionHeaderRequest{ { @@ -142,14 +142,14 @@ func TestInt_ExternalFunctions(t *testing.T) { }, } set := sdk.NewExternalFunctionSetRequest().WithHeaders(headers) - request := sdk.NewAlterExternalFunctionRequest(id, defaultDataTypes).WithSet(set) + request := sdk.NewAlterExternalFunctionRequest(externalFunction.ID()).WithSet(*set) err := client.ExternalFunctions.Alter(ctx, request) require.NoError(t, err) - assertExternalFunction(t, id, true) + assertExternalFunction(t, externalFunction.ID(), true) }) t.Run("alter external function: set context headers", func(t *testing.T) { - _, id := createExternalFunction(t) + externalFunction := createExternalFunction(t) ch := []sdk.ExternalFunctionContextHeaderRequest{ { @@ -160,50 +160,50 @@ func TestInt_ExternalFunctions(t *testing.T) { }, } set := sdk.NewExternalFunctionSetRequest().WithContextHeaders(ch) - request := sdk.NewAlterExternalFunctionRequest(id, defaultDataTypes).WithSet(set) + request := sdk.NewAlterExternalFunctionRequest(externalFunction.ID()).WithSet(*set) err := client.ExternalFunctions.Alter(ctx, request) require.NoError(t, err) - assertExternalFunction(t, id, true) + assertExternalFunction(t, externalFunction.ID(), true) }) t.Run("alter external function: set compression", func(t *testing.T) { - _, id := createExternalFunction(t) + externalFunction := createExternalFunction(t) - set := sdk.NewExternalFunctionSetRequest().WithCompression(sdk.String("AUTO")) - request := sdk.NewAlterExternalFunctionRequest(id, defaultDataTypes).WithSet(set) + set := sdk.NewExternalFunctionSetRequest().WithCompression("AUTO") + request := sdk.NewAlterExternalFunctionRequest(externalFunction.ID()).WithSet(*set) err := client.ExternalFunctions.Alter(ctx, request) require.NoError(t, err) - assertExternalFunction(t, id, true) + assertExternalFunction(t, externalFunction.ID(), true) }) t.Run("alter external function: set max batch rows", func(t *testing.T) { - _, id := createExternalFunction(t) + externalFunction := createExternalFunction(t) - set := sdk.NewExternalFunctionSetRequest().WithMaxBatchRows(sdk.Int(20)) - request := sdk.NewAlterExternalFunctionRequest(id, defaultDataTypes).WithSet(set) + set := sdk.NewExternalFunctionSetRequest().WithMaxBatchRows(20) + request := sdk.NewAlterExternalFunctionRequest(externalFunction.ID()).WithSet(*set) err := client.ExternalFunctions.Alter(ctx, request) require.NoError(t, err) - assertExternalFunction(t, id, true) + assertExternalFunction(t, externalFunction.ID(), true) }) t.Run("alter external function: unset", func(t *testing.T) { - _, id := createExternalFunction(t) + externalFunction := createExternalFunction(t) unset := sdk.NewExternalFunctionUnsetRequest(). - WithComment(sdk.Bool(true)). - WithHeaders(sdk.Bool(true)) - request := sdk.NewAlterExternalFunctionRequest(id, defaultDataTypes).WithUnset(unset) + WithComment(true). + WithHeaders(true) + request := sdk.NewAlterExternalFunctionRequest(externalFunction.ID()).WithUnset(*unset) err := client.ExternalFunctions.Alter(ctx, request) require.NoError(t, err) - assertExternalFunction(t, id, true) + assertExternalFunction(t, externalFunction.ID(), true) }) t.Run("show external function: with like", func(t *testing.T) { - e1, _ := createExternalFunction(t) - e2, _ := createExternalFunction(t) + e1 := createExternalFunction(t) + e2 := createExternalFunction(t) - es, err := client.ExternalFunctions.Show(ctx, sdk.NewShowExternalFunctionRequest().WithLike(&sdk.Like{Pattern: sdk.String(e1.Name)})) + es, err := client.ExternalFunctions.Show(ctx, sdk.NewShowExternalFunctionRequest().WithLike(sdk.Like{Pattern: sdk.String(e1.Name)})) require.NoError(t, err) require.Equal(t, 1, len(es)) @@ -211,50 +211,47 @@ func TestInt_ExternalFunctions(t *testing.T) { require.NotContains(t, es, *e2) }) - t.Run("show external function: with in", func(t *testing.T) { - otherDb, otherDbCleanup := testClientHelper().Database.CreateDatabase(t) - t.Cleanup(otherDbCleanup) - - e1, _ := createExternalFunction(t) - - es, err := client.ExternalFunctions.Show(ctx, sdk.NewShowExternalFunctionRequest().WithIn(&sdk.In{Schema: e1.ID().SchemaId()})) - require.NoError(t, err) - - require.Contains(t, es, *e1) - - es, err = client.ExternalFunctions.Show(ctx, sdk.NewShowExternalFunctionRequest().WithIn(&sdk.In{Database: testClientHelper().Ids.DatabaseId()})) - require.NoError(t, err) - - require.Contains(t, es, *e1) - - es, err = client.ExternalFunctions.Show(ctx, sdk.NewShowExternalFunctionRequest().WithIn(&sdk.In{Database: otherDb.ID()})) - require.NoError(t, err) - - require.Empty(t, es) - }) + // TODO: Uncomment + //t.Run("show external function: with in", func(t *testing.T) { + // otherDb, otherDbCleanup := testClientHelper().Database.CreateDatabase(t) + // t.Cleanup(otherDbCleanup) + // + // e1 := createExternalFunction(t) + // + // es, err := client.ExternalFunctions.Show(ctx, sdk.NewShowExternalFunctionRequest().WithIn(&sdk.In{Schema: e1.ID().SchemaId()})) + // require.NoError(t, err) + // + // require.Contains(t, es, *e1) + // + // es, err = client.ExternalFunctions.Show(ctx, sdk.NewShowExternalFunctionRequest().WithIn(&sdk.In{Database: testClientHelper().Ids.DatabaseId()})) + // require.NoError(t, err) + // + // require.Contains(t, es, *e1) + // + // es, err = client.ExternalFunctions.Show(ctx, sdk.NewShowExternalFunctionRequest().WithIn(&sdk.In{Database: otherDb.ID()})) + // require.NoError(t, err) + // + // require.Empty(t, es) + //}) t.Run("show external function: no matches", func(t *testing.T) { - es, err := client.ExternalFunctions.Show(ctx, sdk.NewShowExternalFunctionRequest().WithLike(&sdk.Like{Pattern: sdk.String("non-existing-id-pattern")})) + es, err := client.ExternalFunctions.Show(ctx, sdk.NewShowExternalFunctionRequest().WithLike(sdk.Like{Pattern: sdk.String("non-existing-id-pattern")})) require.NoError(t, err) require.Equal(t, 0, len(es)) }) t.Run("show external function by id", func(t *testing.T) { - e, id := createExternalFunction(t) + e := createExternalFunction(t) - es, err := client.ExternalFunctions.ShowByID(ctx, id) + es, err := client.ExternalFunctions.ShowByID(ctx, e.ID()) require.NoError(t, err) require.Equal(t, *e, *es) - - _, err = client.ExternalFunctions.ShowByID(ctx, id.WithoutArguments()) - require.Error(t, err, sdk.ErrObjectNotExistOrAuthorized) }) t.Run("describe external function", func(t *testing.T) { - e, _ := createExternalFunction(t) + e := createExternalFunction(t) - request := sdk.NewDescribeExternalFunctionRequest(e.ID(), []sdk.DataType{sdk.DataTypeVARCHAR}) - details, err := client.ExternalFunctions.Describe(ctx, request) + details, err := client.ExternalFunctions.Describe(ctx, e.ID()) require.NoError(t, err) pairs := make(map[string]string) for _, detail := range details { From f0b3fe276f081012ce38ff359b8ffc5cf6d9f89d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20Cie=C5=9Blak?= Date: Fri, 9 Aug 2024 10:44:37 +0200 Subject: [PATCH 04/19] wip --- pkg/resources/external_function.go | 1019 +++++++++-------- .../external_function_acceptance_test.go | 130 +++ pkg/resources/function.go | 10 +- pkg/resources/function_acceptance_test.go | 3 - pkg/resources/identifiers_state_upgraders.go | 21 + pkg/resources/procedure.go | 19 +- pkg/resources/procedure_acceptance_test.go | 117 +- .../external_functions_dto_builders_gen.go | 5 + pkg/sdk/external_functions_dto_gen.go | 1 + pkg/sdk/external_functions_impl_gen.go | 6 +- .../external_functions_integration_test.go | 46 +- .../testint/procedures_integration_test.go | 20 +- 12 files changed, 843 insertions(+), 554 deletions(-) create mode 100644 pkg/resources/identifiers_state_upgraders.go diff --git a/pkg/resources/external_function.go b/pkg/resources/external_function.go index 51c8dcf8e3..cce334978f 100644 --- a/pkg/resources/external_function.go +++ b/pkg/resources/external_function.go @@ -1,510 +1,523 @@ package resources -import "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" - -// import ( -// -// "context" -// "encoding/json" -// "log" -// "regexp" -// "strconv" -// "strings" -// -// "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/provider" -// "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" -// "github.com/hashicorp/go-cty/cty" -// "github.com/hashicorp/terraform-plugin-sdk/v2/diag" -// "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" -// "github.com/hashicorp/terraform-plugin-sdk/v2/helper/validation" -// -// ) -// -// var externalFunctionSchema = map[string]*schema.Schema{ -// "name": { -// Type: schema.TypeString, -// Required: true, -// ForceNew: true, -// Description: "Specifies the identifier for the external function. The identifier can contain the schema name and database name, as well as the function name. The function's signature (name and argument data types) must be unique within the schema.", -// }, -// "schema": { -// Type: schema.TypeString, -// Required: true, -// ForceNew: true, -// Description: "The schema in which to create the external function.", -// }, -// "database": { -// Type: schema.TypeString, -// Required: true, -// ForceNew: true, -// Description: "The database in which to create the external function.", -// }, -// "arg": { -// Type: schema.TypeList, -// Optional: true, -// ForceNew: true, -// Description: "Specifies the arguments/inputs for the external function. These should correspond to the arguments that the remote service expects.", -// Elem: &schema.Resource{ -// Schema: map[string]*schema.Schema{ -// "name": { -// Type: schema.TypeString, -// Required: true, -// // Suppress the diff shown if the values are equal when both compared in lower case. -// DiffSuppressFunc: func(k, old, new string, d *schema.ResourceData) bool { -// return strings.EqualFold(strings.ToLower(old), strings.ToLower(new)) -// }, -// Description: "Argument name", -// }, -// "type": { -// Type: schema.TypeString, -// Required: true, -// // Suppress the diff shown if the values are equal when both compared in lower case. -// DiffSuppressFunc: func(k, old, new string, d *schema.ResourceData) bool { -// return strings.EqualFold(strings.ToLower(old), strings.ToLower(new)) -// }, -// Description: "Argument type, e.g. VARCHAR", -// }, -// }, -// }, -// }, -// "null_input_behavior": { -// Type: schema.TypeString, -// Optional: true, -// Default: "CALLED ON NULL INPUT", -// ForceNew: true, -// ValidateFunc: validation.StringInSlice([]string{"CALLED ON NULL INPUT", "RETURNS NULL ON NULL INPUT", "STRICT"}, false), -// Description: "Specifies the behavior of the external function when called with null inputs.", -// }, -// "return_type": { -// Type: schema.TypeString, -// Required: true, -// ForceNew: true, -// // Suppress the diff shown if the values are equal when both compared in lower case. -// DiffSuppressFunc: func(k, old, new string, d *schema.ResourceData) bool { -// return strings.EqualFold(strings.ToLower(old), strings.ToLower(new)) -// }, -// Description: "Specifies the data type returned by the external function.", -// }, -// "return_null_allowed": { -// Type: schema.TypeBool, -// Optional: true, -// ForceNew: true, -// Description: "Indicates whether the function can return NULL values (true) or must return only NON-NULL values (false).", -// Default: true, -// }, -// "return_behavior": { -// Type: schema.TypeString, -// Required: true, -// ForceNew: true, -// ValidateFunc: validation.StringInSlice([]string{"VOLATILE", "IMMUTABLE"}, false), -// Description: "Specifies the behavior of the function when returning results", -// }, -// "api_integration": { -// Type: schema.TypeString, -// Required: true, -// ForceNew: true, -// Description: "The name of the API integration object that should be used to authenticate the call to the proxy service.", -// }, -// "header": { -// Type: schema.TypeSet, -// Optional: true, -// ForceNew: true, -// Description: "Allows users to specify key-value metadata that is sent with every request as HTTP headers.", -// Elem: &schema.Resource{ -// Schema: map[string]*schema.Schema{ -// "name": { -// Type: schema.TypeString, -// Required: true, -// ForceNew: true, -// Description: "Header name", -// }, -// "value": { -// Type: schema.TypeString, -// Required: true, -// ForceNew: true, -// Description: "Header value", -// }, -// }, -// }, -// }, -// "context_headers": { -// Type: schema.TypeList, -// Elem: &schema.Schema{Type: schema.TypeString}, -// Optional: true, -// ForceNew: true, -// // Suppress the diff shown if the values are equal when both compared in lower case. -// DiffSuppressFunc: func(k, old, new string, d *schema.ResourceData) bool { -// return strings.EqualFold(strings.ToLower(old), strings.ToLower(new)) -// }, -// Description: "Binds Snowflake context function results to HTTP headers.", -// }, -// "max_batch_rows": { -// Type: schema.TypeInt, -// Optional: true, -// ForceNew: true, -// Description: "This specifies the maximum number of rows in each batch sent to the proxy service.", -// }, -// "compression": { -// Type: schema.TypeString, -// Optional: true, -// Default: "AUTO", -// ForceNew: true, -// ValidateFunc: validation.StringInSlice([]string{"NONE", "AUTO", "GZIP", "DEFLATE"}, false), -// Description: "If specified, the JSON payload is compressed when sent from Snowflake to the proxy service, and when sent back from the proxy service to Snowflake.", -// }, -// "request_translator": { -// Type: schema.TypeString, -// Optional: true, -// ForceNew: true, -// Description: "This specifies the name of the request translator function", -// }, -// "response_translator": { -// Type: schema.TypeString, -// Optional: true, -// ForceNew: true, -// Description: "This specifies the name of the response translator function.", -// }, -// "url_of_proxy_and_resource": { -// Type: schema.TypeString, -// Required: true, -// ForceNew: true, -// Description: "This is the invocation URL of the proxy service and resource through which Snowflake calls the remote service.", -// }, -// "comment": { -// Type: schema.TypeString, -// Optional: true, -// Default: "user-defined function", -// Description: "A description of the external function.", -// }, -// "created_on": { -// Type: schema.TypeString, -// Computed: true, -// Description: "Date and time when the external function was created.", -// }, -// } -// -// // ExternalFunction returns a pointer to the resource representing an external function. +import ( + "context" + "encoding/json" + "log" + "regexp" + "strconv" + "strings" + + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/provider" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" + "github.com/hashicorp/go-cty/cty" + "github.com/hashicorp/terraform-plugin-sdk/v2/diag" + "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" + "github.com/hashicorp/terraform-plugin-sdk/v2/helper/validation" +) + +var externalFunctionSchema = map[string]*schema.Schema{ + "name": { + Type: schema.TypeString, + Required: true, + ForceNew: true, + Description: "Specifies the identifier for the external function. The identifier can contain the schema name and database name, as well as the function name. The function's signature (name and argument data types) must be unique within the schema.", + }, + "schema": { + Type: schema.TypeString, + Required: true, + ForceNew: true, + Description: "The schema in which to create the external function.", + }, + "database": { + Type: schema.TypeString, + Required: true, + ForceNew: true, + Description: "The database in which to create the external function.", + }, + "arg": { + Type: schema.TypeList, + Optional: true, + ForceNew: true, + Description: "Specifies the arguments/inputs for the external function. These should correspond to the arguments that the remote service expects.", + Elem: &schema.Resource{ + Schema: map[string]*schema.Schema{ + "name": { + Type: schema.TypeString, + Required: true, + // Suppress the diff shown if the values are equal when both compared in lower case. + DiffSuppressFunc: func(k, old, new string, d *schema.ResourceData) bool { + return strings.EqualFold(strings.ToLower(old), strings.ToLower(new)) + }, + Description: "Argument name", + }, + "type": { + Type: schema.TypeString, + Required: true, + // Suppress the diff shown if the values are equal when both compared in lower case. + DiffSuppressFunc: func(k, old, new string, d *schema.ResourceData) bool { + return strings.EqualFold(strings.ToLower(old), strings.ToLower(new)) + }, + Description: "Argument type, e.g. VARCHAR", + }, + }, + }, + }, + "null_input_behavior": { + Type: schema.TypeString, + Optional: true, + Default: "CALLED ON NULL INPUT", + ForceNew: true, + ValidateFunc: validation.StringInSlice([]string{"CALLED ON NULL INPUT", "RETURNS NULL ON NULL INPUT", "STRICT"}, false), + Description: "Specifies the behavior of the external function when called with null inputs.", + }, + "return_type": { + Type: schema.TypeString, + Required: true, + ForceNew: true, + // Suppress the diff shown if the values are equal when both compared in lower case. + DiffSuppressFunc: func(k, old, new string, d *schema.ResourceData) bool { + return strings.EqualFold(strings.ToLower(old), strings.ToLower(new)) + }, + Description: "Specifies the data type returned by the external function.", + }, + "return_null_allowed": { + Type: schema.TypeBool, + Optional: true, + ForceNew: true, + Description: "Indicates whether the function can return NULL values (true) or must return only NON-NULL values (false).", + Default: true, + }, + "return_behavior": { + Type: schema.TypeString, + Required: true, + ForceNew: true, + ValidateFunc: validation.StringInSlice([]string{"VOLATILE", "IMMUTABLE"}, false), + Description: "Specifies the behavior of the function when returning results", + }, + "api_integration": { + Type: schema.TypeString, + Required: true, + ForceNew: true, + Description: "The name of the API integration object that should be used to authenticate the call to the proxy service.", + }, + "header": { + Type: schema.TypeSet, + Optional: true, + ForceNew: true, + Description: "Allows users to specify key-value metadata that is sent with every request as HTTP headers.", + Elem: &schema.Resource{ + Schema: map[string]*schema.Schema{ + "name": { + Type: schema.TypeString, + Required: true, + ForceNew: true, + Description: "Header name", + }, + "value": { + Type: schema.TypeString, + Required: true, + ForceNew: true, + Description: "Header value", + }, + }, + }, + }, + "context_headers": { + Type: schema.TypeList, + Elem: &schema.Schema{Type: schema.TypeString}, + Optional: true, + ForceNew: true, + // Suppress the diff shown if the values are equal when both compared in lower case. + DiffSuppressFunc: func(k, old, new string, d *schema.ResourceData) bool { + return strings.EqualFold(strings.ToLower(old), strings.ToLower(new)) + }, + Description: "Binds Snowflake context function results to HTTP headers.", + }, + "max_batch_rows": { + Type: schema.TypeInt, + Optional: true, + ForceNew: true, + Description: "This specifies the maximum number of rows in each batch sent to the proxy service.", + }, + "compression": { + Type: schema.TypeString, + Optional: true, + Default: "AUTO", + ForceNew: true, + ValidateFunc: validation.StringInSlice([]string{"NONE", "AUTO", "GZIP", "DEFLATE"}, false), + Description: "If specified, the JSON payload is compressed when sent from Snowflake to the proxy service, and when sent back from the proxy service to Snowflake.", + }, + "request_translator": { + Type: schema.TypeString, + Optional: true, + ForceNew: true, + Description: "This specifies the name of the request translator function", + }, + "response_translator": { + Type: schema.TypeString, + Optional: true, + ForceNew: true, + Description: "This specifies the name of the response translator function.", + }, + "url_of_proxy_and_resource": { + Type: schema.TypeString, + Required: true, + ForceNew: true, + Description: "This is the invocation URL of the proxy service and resource through which Snowflake calls the remote service.", + }, + "comment": { + Type: schema.TypeString, + Optional: true, + Default: "user-defined function", + Description: "A description of the external function.", + }, + "created_on": { + Type: schema.TypeString, + Computed: true, + Description: "Date and time when the external function was created.", + }, +} + func ExternalFunction() *schema.Resource { return &schema.Resource{ - //SchemaVersion: 1, - // - //CreateContext: CreateContextExternalFunction, - //ReadContext: ReadContextExternalFunction, - //UpdateContext: UpdateContextExternalFunction, - //DeleteContext: DeleteContextExternalFunction, - // - //Schema: externalFunctionSchema, - //Importer: &schema.ResourceImporter{ - // StateContext: schema.ImportStatePassthroughContext, - //}, - // - //StateUpgraders: []schema.StateUpgrader{ - // { - // Version: 0, - // // setting type to cty.EmptyObject is a bit hacky here but following https://developer.hashicorp.com/terraform/plugin/framework/migrating/resources/state-upgrade#sdkv2-1 would require lots of repetitive code; this should work with cty.EmptyObject - // Type: cty.EmptyObject, - // Upgrade: v085ExternalFunctionStateUpgrader, - // }, - //}, + SchemaVersion: 2, + + CreateContext: CreateContextExternalFunction, + ReadContext: ReadContextExternalFunction, + UpdateContext: UpdateContextExternalFunction, + DeleteContext: DeleteContextExternalFunction, + + Schema: externalFunctionSchema, + Importer: &schema.ResourceImporter{ + StateContext: schema.ImportStatePassthroughContext, + }, + + StateUpgraders: []schema.StateUpgrader{ + { + Version: 0, + // setting type to cty.EmptyObject is a bit hacky here but following https://developer.hashicorp.com/terraform/plugin/framework/migrating/resources/state-upgrade#sdkv2-1 would require lots of repetitive code; this should work with cty.EmptyObject + Type: cty.EmptyObject, + Upgrade: v085ExternalFunctionStateUpgrader, + }, + { + Version: 1, + // setting type to cty.EmptyObject is a bit hacky here but following https://developer.hashicorp.com/terraform/plugin/framework/migrating/resources/state-upgrade#sdkv2-1 would require lots of repetitive code; this should work with cty.EmptyObject + Type: cty.EmptyObject, + Upgrade: v0941ResourceIdentifierWithArguments, + }, + }, + } +} + +func CreateContextExternalFunction(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { + client := meta.(*provider.Context).Client + database := d.Get("database").(string) + schemaName := d.Get("schema").(string) + name := d.Get("name").(string) + args := make([]sdk.ExternalFunctionArgumentRequest, 0) + if v, ok := d.GetOk("arg"); ok { + for _, arg := range v.([]interface{}) { + argName := arg.(map[string]interface{})["name"].(string) + argType := arg.(map[string]interface{})["type"].(string) + argDataType, err := sdk.ToDataType(argType) + if err != nil { + return diag.FromErr(err) + } + args = append(args, sdk.ExternalFunctionArgumentRequest{ArgName: argName, ArgDataType: argDataType}) + } + } + argTypes := make([]sdk.DataType, 0, len(args)) + for _, item := range args { + argTypes = append(argTypes, item.ArgDataType) + } + id := sdk.NewSchemaObjectIdentifierWithArguments(database, schemaName, name, argTypes...) + + returnType := d.Get("return_type").(string) + resultDataType, err := sdk.ToDataType(returnType) + if err != nil { + return diag.FromErr(err) + } + apiIntegration := sdk.NewAccountObjectIdentifier(d.Get("api_integration").(string)) + urlOfProxyAndResource := d.Get("url_of_proxy_and_resource").(string) + req := sdk.NewCreateExternalFunctionRequest(id.SchemaObjectId(), resultDataType, &apiIntegration, urlOfProxyAndResource) + + // Set optionals + if len(args) > 0 { + req.WithArguments(args) + } + + if v, ok := d.GetOk("return_null_allowed"); ok { + if v.(bool) { + req.WithReturnNullValues(sdk.ReturnNullValuesNull) + } else { + req.WithReturnNullValues(sdk.ReturnNullValuesNotNull) + } + } + + if v, ok := d.GetOk("return_behavior"); ok { + if v.(string) == "VOLATILE" { + req.WithReturnResultsBehavior(sdk.ReturnResultsBehaviorVolatile) + } else { + req.WithReturnResultsBehavior(sdk.ReturnResultsBehaviorImmutable) + } + } + + if v, ok := d.GetOk("null_input_behavior"); ok { + switch { + case v.(string) == "CALLED ON NULL INPUT": + req.WithNullInputBehavior(sdk.NullInputBehaviorCalledOnNullInput) + case v.(string) == "RETURNS NULL ON NULL INPUT": + req.WithNullInputBehavior(sdk.NullInputBehaviorReturnNullInput) + default: + req.WithNullInputBehavior(sdk.NullInputBehaviorStrict) + } + } + + if v, ok := d.GetOk("comment"); ok { + req.WithComment(v.(string)) + } + + if _, ok := d.GetOk("header"); ok { + headers := make([]sdk.ExternalFunctionHeaderRequest, 0) + for _, header := range d.Get("header").(*schema.Set).List() { + m := header.(map[string]interface{}) + headerName := m["name"].(string) + headerValue := m["value"].(string) + headers = append(headers, sdk.ExternalFunctionHeaderRequest{ + Name: headerName, + Value: headerValue, + }) + } + req.WithHeaders(headers) + } + + if v, ok := d.GetOk("context_headers"); ok { + contextHeadersList := expandStringList(v.([]interface{})) + contextHeaders := make([]sdk.ExternalFunctionContextHeaderRequest, 0) + for _, header := range contextHeadersList { + contextHeaders = append(contextHeaders, sdk.ExternalFunctionContextHeaderRequest{ + ContextFunction: header, + }) + } + req.WithContextHeaders(contextHeaders) + } + + if v, ok := d.GetOk("max_batch_rows"); ok { + req.WithMaxBatchRows(v.(int)) + } + + if v, ok := d.GetOk("compression"); ok { + req.WithCompression(v.(string)) + } + + if v, ok := d.GetOk("request_translator"); ok { + req.WithRequestTranslator(sdk.NewSchemaObjectIdentifierFromFullyQualifiedName(v.(string))) + } + + if v, ok := d.GetOk("response_translator"); ok { + req.WithResponseTranslator(sdk.NewSchemaObjectIdentifierFromFullyQualifiedName(v.(string))) + } + + if err := client.ExternalFunctions.Create(ctx, req); err != nil { + return diag.FromErr(err) + } + + d.SetId(id.FullyQualifiedName()) + + return ReadContextExternalFunction(ctx, d, meta) +} + +func ReadContextExternalFunction(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { + client := meta.(*provider.Context).Client + + id, err := sdk.ParseSchemaObjectIdentifierWithArguments(d.Id()) + if err != nil { + return diag.FromErr(err) + } + + externalFunction, err := client.ExternalFunctions.ShowByID(ctx, id) + if err != nil { + d.SetId("") + return nil + } + + // Some properties can come from the SHOW EXTERNAL FUNCTION call + if err := d.Set("name", externalFunction.Name); err != nil { + return diag.FromErr(err) + } + + if err := d.Set("schema", strings.Trim(externalFunction.SchemaName, "\"")); err != nil { + return diag.FromErr(err) + } + + if err := d.Set("database", strings.Trim(externalFunction.CatalogName, "\"")); err != nil { + return diag.FromErr(err) + } + + if err := d.Set("comment", externalFunction.Description); err != nil { + return diag.FromErr(err) + } + + if err := d.Set("created_on", externalFunction.CreatedOn); err != nil { + return diag.FromErr(err) + } + + // Some properties come from the DESCRIBE FUNCTION call + externalFunctionPropertyRows, err := client.ExternalFunctions.Describe(ctx, id) + if err != nil { + d.SetId("") + return nil + } + + for _, row := range externalFunctionPropertyRows { + switch row.Property { + case "signature": + // Format in Snowflake DB is: (argName argType, argName argType, ...) + args := strings.ReplaceAll(strings.ReplaceAll(row.Value, "(", ""), ")", "") + + if args != "" { // Do nothing for functions without arguments + argPairs := strings.Split(args, ", ") + args := []interface{}{} + + for _, argPair := range argPairs { + argItem := strings.Split(argPair, " ") + + arg := map[string]interface{}{} + arg["name"] = argItem[0] + arg["type"] = argItem[1] + args = append(args, arg) + } + + if err := d.Set("arg", args); err != nil { + return diag.Errorf("error setting arg: %v", err) + } + } + case "returns": + returnType := row.Value + // We first check for VARIANT or OBJECT + if returnType == "VARIANT" || returnType == "OBJECT" { + if err := d.Set("return_type", returnType); err != nil { + return diag.Errorf("error setting return_type: %v", err) + } + break + } + + // otherwise, format in Snowflake DB is returnType() + re := regexp.MustCompile(`^(\w+)\([0-9]*\)$`) + match := re.FindStringSubmatch(row.Value) + if len(match) < 2 { + return diag.Errorf("return_type %s not recognized", returnType) + } + if err := d.Set("return_type", match[1]); err != nil { + return diag.Errorf("error setting return_type: %v", err) + } + + case "null handling": + if err := d.Set("null_input_behavior", row.Value); err != nil { + return diag.Errorf("error setting null_input_behavior: %v", err) + } + case "volatility": + if err := d.Set("return_behavior", row.Value); err != nil { + return diag.Errorf("error setting return_behavior: %v", err) + } + case "headers": + if row.Value != "" && row.Value != "null" { + // Format in Snowflake DB is: {"head1":"val1","head2":"val2"} + var jsonHeaders map[string]string + err := json.Unmarshal([]byte(row.Value), &jsonHeaders) + if err != nil { + return diag.Errorf("error unmarshalling headers: %v", err) + } + + headers := make([]any, 0, len(jsonHeaders)) + for key, value := range jsonHeaders { + headers = append(headers, map[string]any{ + "name": key, + "value": value, + }) + } + + if err := d.Set("header", headers); err != nil { + return diag.Errorf("error setting return_behavior: %v", err) + } + } + case "context_headers": + if row.Value != "" && row.Value != "null" { + // Format in Snowflake DB is: ["CONTEXT_FUNCTION_1","CONTEXT_FUNCTION_2"] + contextHeaders := strings.Split(strings.Trim(row.Value, "[]"), ",") + for i, v := range contextHeaders { + contextHeaders[i] = strings.Trim(v, "\"") + } + if err := d.Set("context_headers", contextHeaders); err != nil { + return diag.Errorf("error setting context_headers: %v", err) + } + } + case "max_batch_rows": + if row.Value != "not set" { + maxBatchRows, err := strconv.ParseInt(row.Value, 10, 64) + if err != nil { + return diag.Errorf("error parsing max_batch_rows: %v", err) + } + if err := d.Set("max_batch_rows", maxBatchRows); err != nil { + return diag.Errorf("error setting max_batch_rows: %v", err) + } + } + case "compression": + if err := d.Set("compression", row.Value); err != nil { + return diag.Errorf("error setting compression: %v", err) + } + case "body": + if err := d.Set("url_of_proxy_and_resource", row.Value); err != nil { + return diag.Errorf("error setting url_of_proxy_and_resource: %v", err) + } + case "language": + // To ignore + default: + log.Printf("[WARN] unexpected external function property %v returned from Snowflake", row.Property) + } + } + + return nil +} + +func UpdateContextExternalFunction(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { + client := meta.(*provider.Context).Client + + id, err := sdk.ParseSchemaObjectIdentifierWithArguments(d.Id()) + if err != nil { + return diag.FromErr(err) } + + req := sdk.NewAlterFunctionRequest(id) + if d.HasChange("comment") { + _, new := d.GetChange("comment") + if new == "" { + req.UnsetComment = sdk.Bool(true) + } else { + req.SetComment = sdk.String(new.(string)) + } + err := client.Functions.Alter(ctx, req) + if err != nil { + return diag.FromErr(err) + } + } + return ReadContextExternalFunction(ctx, d, meta) } -// -//func CreateContextExternalFunction(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { -// client := meta.(*provider.Context).Client -// database := d.Get("database").(string) -// schemaName := d.Get("schema").(string) -// name := d.Get("name").(string) -// id := sdk.NewSchemaObjectIdentifier(database, schemaName, name) -// -// returnType := d.Get("return_type").(string) -// resultDataType, err := sdk.ToDataType(returnType) -// if err != nil { -// return diag.FromErr(err) -// } -// apiIntegration := sdk.NewAccountObjectIdentifier(d.Get("api_integration").(string)) -// urlOfProxyAndResource := d.Get("url_of_proxy_and_resource").(string) -// req := sdk.NewCreateExternalFunctionRequest(id, resultDataType, &apiIntegration, urlOfProxyAndResource) -// -// // Set optionals -// args := make([]sdk.ExternalFunctionArgumentRequest, 0) -// if v, ok := d.GetOk("arg"); ok { -// for _, arg := range v.([]interface{}) { -// argName := arg.(map[string]interface{})["name"].(string) -// argType := arg.(map[string]interface{})["type"].(string) -// argDataType, err := sdk.ToDataType(argType) -// if err != nil { -// return diag.FromErr(err) -// } -// args = append(args, sdk.ExternalFunctionArgumentRequest{ArgName: argName, ArgDataType: argDataType}) -// } -// } -// if len(args) > 0 { -// req.WithArguments(args) -// } -// -// if v, ok := d.GetOk("return_null_allowed"); ok { -// if v.(bool) { -// req.WithReturnNullValues(&sdk.ReturnNullValuesNull) -// } else { -// req.WithReturnNullValues(&sdk.ReturnNullValuesNotNull) -// } -// } -// -// if v, ok := d.GetOk("return_behavior"); ok { -// if v.(string) == "VOLATILE" { -// req.WithReturnResultsBehavior(&sdk.ReturnResultsBehaviorVolatile) -// } else { -// req.WithReturnResultsBehavior(&sdk.ReturnResultsBehaviorImmutable) -// } -// } -// -// if v, ok := d.GetOk("null_input_behavior"); ok { -// switch { -// case v.(string) == "CALLED ON NULL INPUT": -// req.WithNullInputBehavior(sdk.Pointer(sdk.NullInputBehaviorCalledOnNullInput)) -// case v.(string) == "RETURNS NULL ON NULL INPUT": -// req.WithNullInputBehavior(sdk.Pointer(sdk.NullInputBehaviorReturnNullInput)) -// default: -// req.WithNullInputBehavior(sdk.Pointer(sdk.NullInputBehaviorStrict)) -// } -// } -// -// if v, ok := d.GetOk("comment"); ok { -// req.WithComment(sdk.String(v.(string))) -// } -// -// if _, ok := d.GetOk("header"); ok { -// headers := make([]sdk.ExternalFunctionHeaderRequest, 0) -// for _, header := range d.Get("header").(*schema.Set).List() { -// m := header.(map[string]interface{}) -// headerName := m["name"].(string) -// headerValue := m["value"].(string) -// headers = append(headers, sdk.ExternalFunctionHeaderRequest{ -// Name: headerName, -// Value: headerValue, -// }) -// } -// req.WithHeaders(headers) -// } -// -// if v, ok := d.GetOk("context_headers"); ok { -// contextHeadersList := expandStringList(v.([]interface{})) -// contextHeaders := make([]sdk.ExternalFunctionContextHeaderRequest, 0) -// for _, header := range contextHeadersList { -// contextHeaders = append(contextHeaders, sdk.ExternalFunctionContextHeaderRequest{ -// ContextFunction: header, -// }) -// } -// req.WithContextHeaders(contextHeaders) -// } -// -// if v, ok := d.GetOk("max_batch_rows"); ok { -// req.WithMaxBatchRows(sdk.Int(v.(int))) -// } -// -// if v, ok := d.GetOk("compression"); ok { -// req.WithCompression(sdk.String(v.(string))) -// } -// -// if v, ok := d.GetOk("request_translator"); ok { -// req.WithRequestTranslator(sdk.Pointer(sdk.NewSchemaObjectIdentifierFromFullyQualifiedName(v.(string)))) -// } -// -// if v, ok := d.GetOk("response_translator"); ok { -// req.WithResponseTranslator(sdk.Pointer(sdk.NewSchemaObjectIdentifierFromFullyQualifiedName(v.(string)))) -// } -// -// if err := client.ExternalFunctions.Create(ctx, req); err != nil { -// return diag.FromErr(err) -// } -// argTypes := make([]sdk.DataType, 0, len(args)) -// for _, item := range args { -// argTypes = append(argTypes, item.ArgDataType) -// } -// sid := sdk.NewSchemaObjectIdentifierWithArgumentsOld(database, schemaName, name, argTypes) -// d.SetId(sid.FullyQualifiedName()) -// return ReadContextExternalFunction(ctx, d, meta) -//} -// -//func ReadContextExternalFunction(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { -// client := meta.(*provider.Context).Client -// -// id := sdk.NewSchemaObjectIdentifierFromFullyQualifiedName(d.Id()) -// externalFunction, err := client.ExternalFunctions.ShowByID(ctx, id) -// if err != nil { -// d.SetId("") -// return nil -// } -// -// // Some properties can come from the SHOW EXTERNAL FUNCTION call -// if err := d.Set("name", externalFunction.Name); err != nil { -// return diag.FromErr(err) -// } -// -// if err := d.Set("schema", strings.Trim(externalFunction.SchemaName, "\"")); err != nil { -// return diag.FromErr(err) -// } -// -// if err := d.Set("database", strings.Trim(externalFunction.CatalogName, "\"")); err != nil { -// return diag.FromErr(err) -// } -// -// if err := d.Set("comment", externalFunction.Description); err != nil { -// return diag.FromErr(err) -// } -// -// if err := d.Set("created_on", externalFunction.CreatedOn); err != nil { -// return diag.FromErr(err) -// } -// -// // Some properties come from the DESCRIBE FUNCTION call -// externalFunctionPropertyRows, err := client.ExternalFunctions.Describe(ctx, sdk.NewDescribeExternalFunctionRequest(id.WithoutArguments(), id.Arguments())) -// if err != nil { -// d.SetId("") -// return nil -// } -// -// for _, row := range externalFunctionPropertyRows { -// switch row.Property { -// case "signature": -// // Format in Snowflake DB is: (argName argType, argName argType, ...) -// args := strings.ReplaceAll(strings.ReplaceAll(row.Value, "(", ""), ")", "") -// -// if args != "" { // Do nothing for functions without arguments -// argPairs := strings.Split(args, ", ") -// args := []interface{}{} -// -// for _, argPair := range argPairs { -// argItem := strings.Split(argPair, " ") -// -// arg := map[string]interface{}{} -// arg["name"] = argItem[0] -// arg["type"] = argItem[1] -// args = append(args, arg) -// } -// -// if err := d.Set("arg", args); err != nil { -// return diag.Errorf("error setting arg: %v", err) -// } -// } -// case "returns": -// returnType := row.Value -// // We first check for VARIANT or OBJECT -// if returnType == "VARIANT" || returnType == "OBJECT" { -// if err := d.Set("return_type", returnType); err != nil { -// return diag.Errorf("error setting return_type: %v", err) -// } -// break -// } -// -// // otherwise, format in Snowflake DB is returnType() -// re := regexp.MustCompile(`^(\w+)\([0-9]*\)$`) -// match := re.FindStringSubmatch(row.Value) -// if len(match) < 2 { -// return diag.Errorf("return_type %s not recognized", returnType) -// } -// if err := d.Set("return_type", match[1]); err != nil { -// return diag.Errorf("error setting return_type: %v", err) -// } -// -// case "null handling": -// if err := d.Set("null_input_behavior", row.Value); err != nil { -// return diag.Errorf("error setting null_input_behavior: %v", err) -// } -// case "volatility": -// if err := d.Set("return_behavior", row.Value); err != nil { -// return diag.Errorf("error setting return_behavior: %v", err) -// } -// case "headers": -// if row.Value != "" && row.Value != "null" { -// // Format in Snowflake DB is: {"head1":"val1","head2":"val2"} -// var jsonHeaders map[string]string -// err := json.Unmarshal([]byte(row.Value), &jsonHeaders) -// if err != nil { -// return diag.Errorf("error unmarshalling headers: %v", err) -// } -// -// headers := make([]any, 0, len(jsonHeaders)) -// for key, value := range jsonHeaders { -// headers = append(headers, map[string]any{ -// "name": key, -// "value": value, -// }) -// } -// -// if err := d.Set("header", headers); err != nil { -// return diag.Errorf("error setting return_behavior: %v", err) -// } -// } -// case "context_headers": -// if row.Value != "" && row.Value != "null" { -// // Format in Snowflake DB is: ["CONTEXT_FUNCTION_1","CONTEXT_FUNCTION_2"] -// contextHeaders := strings.Split(strings.Trim(row.Value, "[]"), ",") -// for i, v := range contextHeaders { -// contextHeaders[i] = strings.Trim(v, "\"") -// } -// if err := d.Set("context_headers", contextHeaders); err != nil { -// return diag.Errorf("error setting context_headers: %v", err) -// } -// } -// case "max_batch_rows": -// if row.Value != "not set" { -// maxBatchRows, err := strconv.ParseInt(row.Value, 10, 64) -// if err != nil { -// return diag.Errorf("error parsing max_batch_rows: %v", err) -// } -// if err := d.Set("max_batch_rows", maxBatchRows); err != nil { -// return diag.Errorf("error setting max_batch_rows: %v", err) -// } -// } -// case "compression": -// if err := d.Set("compression", row.Value); err != nil { -// return diag.Errorf("error setting compression: %v", err) -// } -// case "body": -// if err := d.Set("url_of_proxy_and_resource", row.Value); err != nil { -// return diag.Errorf("error setting url_of_proxy_and_resource: %v", err) -// } -// case "language": -// // To ignore -// default: -// log.Printf("[WARN] unexpected external function property %v returned from Snowflake", row.Property) -// } -// } -// -// return nil -//} -// -//func UpdateContextExternalFunction(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { -// client := meta.(*provider.Context).Client -// -// id := sdk.NewSchemaObjectIdentifierFromFullyQualifiedName(d.Id()) -// req := sdk.NewAlterFunctionRequest(sdk.NewSchemaObjectIdentifierWithArguments(id.DatabaseName(), id.SchemaName(), id.Name(), id.Arguments()...)) -// if d.HasChange("comment") { -// _, new := d.GetChange("comment") -// if new == "" { -// req.UnsetComment = sdk.Bool(true) -// } else { -// req.SetComment = sdk.String(new.(string)) -// } -// err := client.Functions.Alter(ctx, req) -// if err != nil { -// return diag.FromErr(err) -// } -// } -// return ReadContextExternalFunction(ctx, d, meta) -//} -// -//func DeleteContextExternalFunction(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { -// client := meta.(*provider.Context).Client -// -// id := sdk.NewSchemaObjectIdentifierFromFullyQualifiedName(d.Id()) -// req := sdk.NewDropFunctionRequest(sdk.NewSchemaObjectIdentifierWithArguments(id.DatabaseName(), id.SchemaName(), id.Name(), id.Arguments()...)) -// if err := client.Functions.Drop(ctx, req); err != nil { -// return diag.FromErr(err) -// } -// -// d.SetId("") -// return nil -//} +func DeleteContextExternalFunction(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { + client := meta.(*provider.Context).Client + + id, err := sdk.ParseSchemaObjectIdentifierWithArguments(d.Id()) + if err != nil { + return diag.FromErr(err) + } + + req := sdk.NewDropFunctionRequest(id).WithIfExists(true) + if err := client.Functions.Drop(ctx, req); err != nil { + return diag.FromErr(err) + } + + d.SetId("") + return nil +} diff --git a/pkg/resources/external_function_acceptance_test.go b/pkg/resources/external_function_acceptance_test.go index d147071fd2..df4ddd66d2 100644 --- a/pkg/resources/external_function_acceptance_test.go +++ b/pkg/resources/external_function_acceptance_test.go @@ -559,3 +559,133 @@ resource "snowflake_external_function" "f" { `, id.DatabaseName(), id.SchemaName(), id.Name()) } + +func TestAcc_ExternalFunction_EnsureSmoothResourceIdMigrationToV0950(t *testing.T) { + name := acc.TestClient().Ids.RandomAccountObjectIdentifier().Name() + resourceName := "snowflake_external_function.f" + + resource.Test(t, resource.TestCase{ + PreCheck: func() { acc.TestAccPreCheck(t) }, + TerraformVersionChecks: []tfversion.TerraformVersionCheck{ + tfversion.RequireAbove(tfversion.Version1_5_0), + }, + CheckDestroy: acc.CheckDestroy(t, resources.ExternalFunction), + Steps: []resource.TestStep{ + { + ExternalProviders: map[string]resource.ExternalProvider{ + "snowflake": { + VersionConstraint: "=0.94.1", + Source: "Snowflake-Labs/snowflake", + }, + }, + Config: externalFunctionConfigWithMoreArguments(acc.TestDatabaseName, acc.TestSchemaName, name), + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr(resourceName, "id", fmt.Sprintf(`"%s"."%s"."%s"(VARCHAR, FLOAT, NUMBER)`, acc.TestDatabaseName, acc.TestSchemaName, name)), + ), + }, + { + ProtoV6ProviderFactories: acc.TestAccProtoV6ProviderFactories, + Config: externalFunctionConfigWithMoreArguments(acc.TestDatabaseName, acc.TestSchemaName, name), + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr(resourceName, "id", fmt.Sprintf(`"%s"."%s"."%s"(VARCHAR, FLOAT, NUMBER)`, acc.TestDatabaseName, acc.TestSchemaName, name)), + ), + }, + }, + }) +} + +func externalFunctionConfigWithMoreArguments(database string, schema string, name string) string { + return fmt.Sprintf(` +resource "snowflake_api_integration" "test_api_int" { + name = "%[3]s" + api_provider = "aws_api_gateway" + api_aws_role_arn = "arn:aws:iam::000000000001:/role/test" + api_allowed_prefixes = ["https://123456.execute-api.us-west-2.amazonaws.com/prod/"] + enabled = true +} + +resource "snowflake_external_function" "f" { + database = "%[1]s" + schema = "%[2]s" + name = "%[3]s" + + arg { + name = "ARG1" + type = "VARCHAR" + } + + arg { + name = "ARG2" + type = "FLOAT" + } + + arg { + name = "ARG3" + type = "NUMBER" + } + + return_type = "VARIANT" + return_behavior = "IMMUTABLE" + api_integration = snowflake_api_integration.test_api_int.name + url_of_proxy_and_resource = "https://123456.execute-api.us-west-2.amazonaws.com/prod/test_func" +} +`, database, schema, name) +} + +func TestAcc_ExternalFunction_EnsureSmoothResourceIdMigrationToV0950_WithoutArguments(t *testing.T) { + name := acc.TestClient().Ids.RandomAccountObjectIdentifier().Name() + resourceName := "snowflake_external_function.f" + + resource.Test(t, resource.TestCase{ + PreCheck: func() { acc.TestAccPreCheck(t) }, + TerraformVersionChecks: []tfversion.TerraformVersionCheck{ + tfversion.RequireAbove(tfversion.Version1_5_0), + }, + CheckDestroy: acc.CheckDestroy(t, resources.ExternalFunction), + Steps: []resource.TestStep{ + { + ExternalProviders: map[string]resource.ExternalProvider{ + "snowflake": { + VersionConstraint: "=0.94.1", + Source: "Snowflake-Labs/snowflake", + }, + }, + Config: externalFunctionConfigWithoutArguments(acc.TestDatabaseName, acc.TestSchemaName, name), + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr(resourceName, "id", fmt.Sprintf(`"%s"."%s"."%s"`, acc.TestDatabaseName, acc.TestSchemaName, name)), + ), + }, + { + ProtoV6ProviderFactories: acc.TestAccProtoV6ProviderFactories, + Config: externalFunctionConfigWithoutArguments(acc.TestDatabaseName, acc.TestSchemaName, name), + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr(resourceName, "id", fmt.Sprintf(`"%s"."%s"."%s"()`, acc.TestDatabaseName, acc.TestSchemaName, name)), + ), + }, + }, + }) +} + +func externalFunctionConfigWithoutArguments(database string, schema string, name string) string { + return fmt.Sprintf(` +resource "snowflake_api_integration" "test_api_int" { + name = "%[3]s" + api_provider = "aws_api_gateway" + api_aws_role_arn = "arn:aws:iam::000000000001:/role/test" + api_allowed_prefixes = ["https://123456.execute-api.us-west-2.amazonaws.com/prod/"] + enabled = true +} + +resource "snowflake_external_function" "f" { + database = "%[1]s" + schema = "%[2]s" + name = "%[3]s" + + return_type = "VARIANT" + return_behavior = "IMMUTABLE" + api_integration = snowflake_api_integration.test_api_int.name + url_of_proxy_and_resource = "https://123456.execute-api.us-west-2.amazonaws.com/prod/test_func" +} + +`, database, schema, name) +} diff --git a/pkg/resources/function.go b/pkg/resources/function.go index f222c9b0a7..dd914a2c05 100644 --- a/pkg/resources/function.go +++ b/pkg/resources/function.go @@ -161,7 +161,7 @@ var functionSchema = map[string]*schema.Schema{ func Function() *schema.Resource { return &schema.Resource{ - SchemaVersion: 1, + SchemaVersion: 2, CreateContext: CreateContextFunction, ReadContext: ReadContextFunction, @@ -180,6 +180,12 @@ func Function() *schema.Resource { Type: cty.EmptyObject, Upgrade: v085FunctionIdStateUpgrader, }, + { + Version: 1, + // setting type to cty.EmptyObject is a bit hacky here but following https://developer.hashicorp.com/terraform/plugin/framework/migrating/resources/state-upgrade#sdkv2-1 would require lots of repetitive code; this should work with cty.EmptyObject + Type: cty.EmptyObject, + Upgrade: v0941ResourceIdentifierWithArguments, + }, }, } } @@ -703,7 +709,7 @@ func DeleteContextFunction(ctx context.Context, d *schema.ResourceData, meta int if err != nil { return diag.FromErr(err) } - if err := client.Functions.Drop(ctx, sdk.NewDropFunctionRequest(id)); err != nil { + if err := client.Functions.Drop(ctx, sdk.NewDropFunctionRequest(id).WithIfExists(true)); err != nil { return diag.FromErr(err) } d.SetId("") diff --git a/pkg/resources/function_acceptance_test.go b/pkg/resources/function_acceptance_test.go index 104cc98e67..ded26d4c73 100644 --- a/pkg/resources/function_acceptance_test.go +++ b/pkg/resources/function_acceptance_test.go @@ -336,8 +336,6 @@ resource "snowflake_function" "f" { `, database, schema, name) } -// TODO: test new state upgrader - func TestAcc_Function_EnsureSmoothResourceIdMigrationToV0950(t *testing.T) { name := acc.TestClient().Ids.RandomAccountObjectIdentifier().Name() resourceName := "snowflake_function.f" @@ -422,7 +420,6 @@ func TestAcc_Function_EnsureSmoothResourceIdMigrationToV0950_WithoutArguments(t ), }, { - // TODO: Fails ProtoV6ProviderFactories: acc.TestAccProtoV6ProviderFactories, Config: functionConfigWithoutArguments(acc.TestDatabaseName, acc.TestSchemaName, name), Check: resource.ComposeTestCheckFunc( diff --git a/pkg/resources/identifiers_state_upgraders.go b/pkg/resources/identifiers_state_upgraders.go new file mode 100644 index 0000000000..c16d516b79 --- /dev/null +++ b/pkg/resources/identifiers_state_upgraders.go @@ -0,0 +1,21 @@ +package resources + +import ( + "context" + + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" +) + +// v0941ResourceIdentifierWithArguments migrates functions, procedures, and external functions to use the new identifier type. +// They're already using old identifier with arguments, but the only case where parentheses weren't specified +// (which are essential in the new identifier) is for empty argument list. +func v0941ResourceIdentifierWithArguments(ctx context.Context, rawState map[string]any, meta any) (map[string]any, error) { + if rawState == nil { + return rawState, nil + } + + id := sdk.NewSchemaObjectIdentifierFromFullyQualifiedName(rawState["id"].(string)) + rawState["id"] = sdk.NewSchemaObjectIdentifierWithArguments(id.DatabaseName(), id.SchemaName(), id.Name(), id.Arguments()...).FullyQualifiedName() + + return rawState, nil +} diff --git a/pkg/resources/procedure.go b/pkg/resources/procedure.go index 69d2cd9b8f..2510c39749 100644 --- a/pkg/resources/procedure.go +++ b/pkg/resources/procedure.go @@ -3,16 +3,17 @@ package resources import ( "context" "fmt" + "log" + "regexp" + "slices" + "strings" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/provider" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" "github.com/hashicorp/go-cty/cty" "github.com/hashicorp/terraform-plugin-sdk/v2/diag" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/validation" - "log" - "regexp" - "slices" - "strings" ) var procedureSchema = map[string]*schema.Schema{ @@ -175,7 +176,7 @@ var procedureSchema = map[string]*schema.Schema{ // Procedure returns a pointer to the resource representing a stored procedure. func Procedure() *schema.Resource { return &schema.Resource{ - SchemaVersion: 1, + SchemaVersion: 2, CreateContext: CreateContextProcedure, ReadContext: ReadContextProcedure, @@ -194,6 +195,12 @@ func Procedure() *schema.Resource { Type: cty.EmptyObject, Upgrade: v085ProcedureStateUpgrader, }, + { + Version: 1, + // setting type to cty.EmptyObject is a bit hacky here but following https://developer.hashicorp.com/terraform/plugin/framework/migrating/resources/state-upgrade#sdkv2-1 would require lots of repetitive code; this should work with cty.EmptyObject + Type: cty.EmptyObject, + Upgrade: v0941ResourceIdentifierWithArguments, + }, }, } } @@ -697,7 +704,7 @@ func DeleteContextProcedure(ctx context.Context, d *schema.ResourceData, meta in if err != nil { return diag.FromErr(err) } - if err := client.Procedures.Drop(ctx, sdk.NewDropProcedureRequest(id)); err != nil { + if err := client.Procedures.Drop(ctx, sdk.NewDropProcedureRequest(id).WithIfExists(true)); err != nil { return diag.FromErr(err) } d.SetId("") diff --git a/pkg/resources/procedure_acceptance_test.go b/pkg/resources/procedure_acceptance_test.go index d10316ebdd..782a14813e 100644 --- a/pkg/resources/procedure_acceptance_test.go +++ b/pkg/resources/procedure_acceptance_test.go @@ -246,8 +246,6 @@ func TestAcc_Procedure_migrateFromVersion085(t *testing.T) { }) } -// TODO: test new state upgrader - func procedureConfig(database string, schema string, name string) string { return fmt.Sprintf(` resource "snowflake_procedure" "p" { @@ -395,3 +393,118 @@ END; } `, database, schema, name) } + +func TestAcc_Procedure_EnsureSmoothResourceIdMigrationToV0950(t *testing.T) { + name := acc.TestClient().Ids.RandomAccountObjectIdentifier().Name() + resourceName := "snowflake_procedure.p" + + resource.Test(t, resource.TestCase{ + PreCheck: func() { acc.TestAccPreCheck(t) }, + TerraformVersionChecks: []tfversion.TerraformVersionCheck{ + tfversion.RequireAbove(tfversion.Version1_5_0), + }, + CheckDestroy: acc.CheckDestroy(t, resources.Procedure), + Steps: []resource.TestStep{ + { + ExternalProviders: map[string]resource.ExternalProvider{ + "snowflake": { + VersionConstraint: "=0.94.1", + Source: "Snowflake-Labs/snowflake", + }, + }, + Config: procedureConfigWithMoreArguments(acc.TestDatabaseName, acc.TestSchemaName, name), + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr(resourceName, "id", fmt.Sprintf(`"%s"."%s"."%s"(VARCHAR, FLOAT, NUMBER)`, acc.TestDatabaseName, acc.TestSchemaName, name)), + ), + }, + { + ProtoV6ProviderFactories: acc.TestAccProtoV6ProviderFactories, + Config: procedureConfigWithMoreArguments(acc.TestDatabaseName, acc.TestSchemaName, name), + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr(resourceName, "id", fmt.Sprintf(`"%s"."%s"."%s"(VARCHAR, FLOAT, NUMBER)`, acc.TestDatabaseName, acc.TestSchemaName, name)), + ), + }, + }, + }) +} + +func procedureConfigWithMoreArguments(database string, schema string, name string) string { + return fmt.Sprintf(` +resource "snowflake_procedure" "p" { + database = "%[1]s" + schema = "%[2]s" + name = "%[3]s" + language = "SQL" + return_type = "NUMBER(38,0)" + statement = < Date: Mon, 12 Aug 2024 10:45:17 +0200 Subject: [PATCH 05/19] wip --- .../external_function_acceptance_test.go | 18 ++++++++++++++---- pkg/sdk/identifier_helpers.go | 2 +- pkg/sdk/identifier_parsers.go | 18 +++++++++++++----- pkg/sdk/identifier_parsers_test.go | 2 ++ 4 files changed, 30 insertions(+), 10 deletions(-) diff --git a/pkg/resources/external_function_acceptance_test.go b/pkg/resources/external_function_acceptance_test.go index df4ddd66d2..0a00ea95d8 100644 --- a/pkg/resources/external_function_acceptance_test.go +++ b/pkg/resources/external_function_acceptance_test.go @@ -255,8 +255,13 @@ func TestAcc_ExternalFunction_migrateFromVersion085(t *testing.T) { ExpectNonEmptyPlan: true, }, { - ProtoV6ProviderFactories: acc.TestAccProtoV6ProviderFactories, - Config: externalFunctionConfig(acc.TestDatabaseName, acc.TestSchemaName, name), + ExternalProviders: map[string]resource.ExternalProvider{ + "snowflake": { + VersionConstraint: "=0.94.1", + Source: "Snowflake-Labs/snowflake", + }, + }, + Config: externalFunctionConfig(acc.TestDatabaseName, acc.TestSchemaName, name), ConfigPlanChecks: resource.ConfigPlanChecks{ PreApply: []plancheck.PlanCheck{plancheck.ExpectEmptyPlan()}, }, @@ -298,8 +303,13 @@ func TestAcc_ExternalFunction_migrateFromVersion085_issue2694_previousValuePrese ExpectNonEmptyPlan: true, }, { - ProtoV6ProviderFactories: acc.TestAccProtoV6ProviderFactories, - Config: externalFunctionConfig(acc.TestDatabaseName, acc.TestSchemaName, name), + ExternalProviders: map[string]resource.ExternalProvider{ + "snowflake": { + VersionConstraint: "=0.94.1", + Source: "Snowflake-Labs/snowflake", + }, + }, + Config: externalFunctionConfig(acc.TestDatabaseName, acc.TestSchemaName, name), ConfigPlanChecks: resource.ConfigPlanChecks{ PreApply: []plancheck.PlanCheck{plancheck.ExpectEmptyPlan()}, }, diff --git a/pkg/sdk/identifier_helpers.go b/pkg/sdk/identifier_helpers.go index 8bc4a33852..dc7229b12a 100644 --- a/pkg/sdk/identifier_helpers.go +++ b/pkg/sdk/identifier_helpers.go @@ -386,7 +386,7 @@ func (i SchemaObjectIdentifierWithArguments) FullyQualifiedName() string { if i.schemaName == "" && i.databaseName == "" && i.name == "" && len(i.argumentDataTypes) == 0 { return "" } - return fmt.Sprintf(`"%v"."%v"."%v"(%v)`, i.databaseName, i.schemaName, i.name, strings.Join(AsStringList(i.argumentDataTypes), ",")) + return fmt.Sprintf(`"%v"."%v"."%v"(%v)`, i.databaseName, i.schemaName, i.name, strings.Join(AsStringList(i.argumentDataTypes), ", ")) } type TableColumnIdentifier struct { diff --git a/pkg/sdk/identifier_parsers.go b/pkg/sdk/identifier_parsers.go index b92a215f5b..358dd586f2 100644 --- a/pkg/sdk/identifier_parsers.go +++ b/pkg/sdk/identifier_parsers.go @@ -167,10 +167,10 @@ func ParseSchemaObjectIdentifierWithArguments(fullyQualifiedName string) (Schema } // ParseFunctionArgumentsFromString parses function argument from arguments string with optional argument names. -// Varying types are not supported (e.g. VARCHAR(200)), because Snowflake outputs them in shortened version -// (VARCHAR in this case). The only exception is newly added type VECTOR which has the following structure +// Varying types are not supported (e.g. VARCHAR(200)), because Snowflake outputs them in a shortened version +// (VARCHAR in this case). The only exception is newly added type VECTOR that has the following structure // VECTOR(, n) where right now can be either INT or FLOAT and n is the number of elements in the VECTOR. -// Snowflake returns vectors with their exact type and this function supports it. +// Snowflake returns vectors with their exact type, and this function supports it. func ParseFunctionArgumentsFromString(arguments string) ([]DataType, error) { dataTypes := make([]DataType, 0) @@ -180,9 +180,18 @@ func ParseFunctionArgumentsFromString(arguments string) ([]DataType, error) { stringBuffer := bytes.NewBufferString(arguments) for stringBuffer.Len() > 0 { + stringBuffer = bytes.NewBufferString(strings.TrimSpace(stringBuffer.String())) + + // When a function is created with a default value for an argument, in the SHOW output ("arguments" column) + // the argument's data type is prefixed with "DEFAULT ", e.g. "(DEFAULT INT, DEFAULT VARCHAR)". + if strings.HasPrefix(stringBuffer.String(), "DEFAULT") { + if _, err := stringBuffer.ReadString(' '); err != nil { + return nil, fmt.Errorf("failed to skip default keyword, err = %w", err) + } + } + // We use another buffer to peek into next data type (needed for vector parsing) peekDataType, _ := bytes.NewBufferString(stringBuffer.String()).ReadString(',') - peekDataType = strings.TrimSpace(peekDataType) switch { // For now, only vectors need special parsing behavior @@ -234,7 +243,6 @@ func ParseFunctionArgumentsFromString(arguments string) ([]DataType, error) { if err == nil { dataType = dataType[:len(dataType)-1] } - dataType = strings.TrimSpace(dataType) dataTypes = append(dataTypes, DataType(dataType)) } } diff --git a/pkg/sdk/identifier_parsers_test.go b/pkg/sdk/identifier_parsers_test.go index 078cbbc30d..cf67c499ec 100644 --- a/pkg/sdk/identifier_parsers_test.go +++ b/pkg/sdk/identifier_parsers_test.go @@ -281,6 +281,8 @@ func Test_ParseFunctionArgumentsFromString(t *testing.T) { {Arguments: `()`, Expected: []DataType{}}, {Arguments: `(FLOAT, NUMBER, TIME)`, Expected: []DataType{DataTypeFloat, DataTypeNumber, DataTypeTime}}, {Arguments: `FLOAT, NUMBER, TIME`, Expected: []DataType{DataTypeFloat, DataTypeNumber, DataTypeTime}}, + {Arguments: `(DEFAULT FLOAT, DEFAULT NUMBER, DEFAULT TIME)`, Expected: []DataType{DataTypeFloat, DataTypeNumber, DataTypeTime}}, + {Arguments: `DEFAULT FLOAT, DEFAULT NUMBER, DEFAULT TIME`, Expected: []DataType{DataTypeFloat, DataTypeNumber, DataTypeTime}}, {Arguments: `(FLOAT, NUMBER, VECTOR(FLOAT, 20))`, Expected: []DataType{DataTypeFloat, DataTypeNumber, DataType("VECTOR(FLOAT, 20)")}}, {Arguments: `FLOAT, NUMBER, VECTOR(FLOAT, 20)`, Expected: []DataType{DataTypeFloat, DataTypeNumber, DataType("VECTOR(FLOAT, 20)")}}, {Arguments: `(VECTOR(FLOAT, 10), NUMBER, VECTOR(FLOAT, 20))`, Expected: []DataType{DataType("VECTOR(FLOAT, 10)"), DataTypeNumber, DataType("VECTOR(FLOAT, 20)")}}, From a88eeae59bf5f8790a1103539aa832920f3115d2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20Cie=C5=9Blak?= Date: Mon, 12 Aug 2024 13:31:11 +0200 Subject: [PATCH 06/19] wip --- pkg/sdk/testint/functions_integration_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/sdk/testint/functions_integration_test.go b/pkg/sdk/testint/functions_integration_test.go index e3a945934d..3d9a68bf74 100644 --- a/pkg/sdk/testint/functions_integration_test.go +++ b/pkg/sdk/testint/functions_integration_test.go @@ -266,7 +266,7 @@ func TestInt_OtherFunctions(t *testing.T) { f := createFunctionForSQLHandle(t, false, true) id := f.ID() - nid := testClientHelper().Ids.RandomSchemaObjectIdentifierWithArguments() + nid := testClientHelper().Ids.RandomSchemaObjectIdentifierWithArguments(sdk.DataTypeFloat) err := client.Functions.Alter(ctx, sdk.NewAlterFunctionRequest(id).WithRenameTo(nid.SchemaObjectId())) if err != nil { t.Cleanup(cleanupFunctionHandle(id)) From 116bb2cb7a4d98e9f05f9a58e502424bc78ad7bd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20Cie=C5=9Blak?= Date: Mon, 12 Aug 2024 13:39:51 +0200 Subject: [PATCH 07/19] wip --- pkg/resources/function_state_upgraders.go | 2 -- pkg/resources/procedure_state_upgraders.go | 2 -- ...external_functions_gen_integration_test.go | 27 ------------------- 3 files changed, 31 deletions(-) delete mode 100644 pkg/sdk/external_functions_gen_integration_test.go diff --git a/pkg/resources/function_state_upgraders.go b/pkg/resources/function_state_upgraders.go index bae81663a1..501e44f1dc 100644 --- a/pkg/resources/function_state_upgraders.go +++ b/pkg/resources/function_state_upgraders.go @@ -60,5 +60,3 @@ func v085FunctionIdStateUpgrader(ctx context.Context, rawState map[string]interf return rawState, nil } - -// TODO: state upgrader for empty args (without '()') diff --git a/pkg/resources/procedure_state_upgraders.go b/pkg/resources/procedure_state_upgraders.go index 27866b7387..24e47d7d9f 100644 --- a/pkg/resources/procedure_state_upgraders.go +++ b/pkg/resources/procedure_state_upgraders.go @@ -60,5 +60,3 @@ func v085ProcedureStateUpgrader(ctx context.Context, rawState map[string]interfa return rawState, nil } - -// TODO: state upgrader for empty args (without '()') diff --git a/pkg/sdk/external_functions_gen_integration_test.go b/pkg/sdk/external_functions_gen_integration_test.go deleted file mode 100644 index d6b43eb45e..0000000000 --- a/pkg/sdk/external_functions_gen_integration_test.go +++ /dev/null @@ -1,27 +0,0 @@ -package sdk - -import "testing" - -func TestInt_ExternalFunctions(t *testing.T) { - // TODO: prepare common resources - - t.Run("Create", func(t *testing.T) { - // TODO: fill me - }) - - t.Run("Alter", func(t *testing.T) { - // TODO: fill me - }) - - t.Run("Show", func(t *testing.T) { - // TODO: fill me - }) - - t.Run("ShowByID", func(t *testing.T) { - // TODO: fill me - }) - - t.Run("Describe", func(t *testing.T) { - // TODO: fill me - }) -} From d4b2bb298c153866a8572e67064eba519f85008e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20Cie=C5=9Blak?= Date: Tue, 13 Aug 2024 14:07:06 +0200 Subject: [PATCH 08/19] changes after review --- .../helpers/external_function_client.go | 56 +++++++++++++++++ pkg/acceptance/helpers/function_client.go | 55 +++++++++++++++++ pkg/acceptance/helpers/procedure_client.go | 55 +++++++++++++++++ pkg/acceptance/helpers/test_client.go | 6 ++ pkg/resources/procedure_acceptance_test.go | 60 +++++++++++++++++++ pkg/sdk/external_functions_impl_gen.go | 4 +- pkg/sdk/identifier_helpers.go | 12 +++- .../external_functions_integration_test.go | 23 +++++++ pkg/sdk/testint/functions_integration_test.go | 23 +++++++ .../testint/procedures_integration_test.go | 27 ++++++++- 10 files changed, 316 insertions(+), 5 deletions(-) create mode 100644 pkg/acceptance/helpers/external_function_client.go create mode 100644 pkg/acceptance/helpers/function_client.go create mode 100644 pkg/acceptance/helpers/procedure_client.go diff --git a/pkg/acceptance/helpers/external_function_client.go b/pkg/acceptance/helpers/external_function_client.go new file mode 100644 index 0000000000..8e66ba076b --- /dev/null +++ b/pkg/acceptance/helpers/external_function_client.go @@ -0,0 +1,56 @@ +package helpers + +import ( + "context" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" + "github.com/stretchr/testify/require" + "testing" +) + +type ExternalFunctionClient struct { + context *TestClientContext + ids *IdsGenerator +} + +func NewExternalFunctionClient(context *TestClientContext, idsGenerator *IdsGenerator) *ExternalFunctionClient { + return &ExternalFunctionClient{ + context: context, + ids: idsGenerator, + } +} + +func (c *ExternalFunctionClient) client() sdk.ExternalFunctions { + return c.context.client.ExternalFunctions +} + +func (c *ExternalFunctionClient) Create(t *testing.T, apiIntegrationId sdk.AccountObjectIdentifier, arguments ...sdk.DataType) *sdk.ExternalFunction { + t.Helper() + return c.CreateWithIdentifier(t, apiIntegrationId, c.ids.RandomSchemaObjectIdentifierWithArguments(arguments...)) +} + +func (c *ExternalFunctionClient) CreateWithIdentifier(t *testing.T, apiIntegrationId sdk.AccountObjectIdentifier, id sdk.SchemaObjectIdentifierWithArguments) *sdk.ExternalFunction { + t.Helper() + ctx := context.Background() + argumentRequests := make([]sdk.ExternalFunctionArgumentRequest, len(id.ArgumentDataTypes())) + for i, argumentDataType := range id.ArgumentDataTypes() { + argumentRequests[i] = *sdk.NewExternalFunctionArgumentRequest(c.ids.Alpha(), argumentDataType) + } + err := c.client().Create(ctx, + sdk.NewCreateExternalFunctionRequest( + id.SchemaObjectId(), + sdk.DataTypeVariant, + &apiIntegrationId, + "https://xyz.execute-api.us-west-2.amazonaws.com/production/remote_echo", + ).WithArguments(argumentRequests), + ) + require.NoError(t, err) + + t.Cleanup(func() { + require.NoError(t, c.context.client.Functions.Drop(ctx, sdk.NewDropFunctionRequest(id).WithIfExists(true))) + }) + + externalFunction, err := c.client().ShowByID(ctx, id) + require.NoError(t, err) + + return externalFunction +} diff --git a/pkg/acceptance/helpers/function_client.go b/pkg/acceptance/helpers/function_client.go new file mode 100644 index 0000000000..f2f058d6ab --- /dev/null +++ b/pkg/acceptance/helpers/function_client.go @@ -0,0 +1,55 @@ +package helpers + +import ( + "context" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" + "github.com/stretchr/testify/require" + "testing" +) + +type FunctionClient struct { + context *TestClientContext + ids *IdsGenerator +} + +func NewFunctionClient(context *TestClientContext, idsGenerator *IdsGenerator) *FunctionClient { + return &FunctionClient{ + context: context, + ids: idsGenerator, + } +} + +func (c *FunctionClient) client() sdk.Functions { + return c.context.client.Functions +} + +func (c *FunctionClient) Create(t *testing.T, arguments ...sdk.DataType) *sdk.Function { + t.Helper() + return c.CreateWithIdentifier(t, c.ids.RandomSchemaObjectIdentifierWithArguments(arguments...)) +} + +func (c *FunctionClient) CreateWithIdentifier(t *testing.T, id sdk.SchemaObjectIdentifierWithArguments) *sdk.Function { + t.Helper() + ctx := context.Background() + argumentRequests := make([]sdk.FunctionArgumentRequest, len(id.ArgumentDataTypes())) + for i, argumentDataType := range id.ArgumentDataTypes() { + argumentRequests[i] = *sdk.NewFunctionArgumentRequest(c.ids.Alpha(), argumentDataType) + } + err := c.client().CreateForSQL(ctx, + sdk.NewCreateForSQLFunctionRequest( + id.SchemaObjectId(), + *sdk.NewFunctionReturnsRequest().WithResultDataType(*sdk.NewFunctionReturnsResultDataTypeRequest(sdk.DataTypeInt)), + "SELECT 1", + ).WithArguments(argumentRequests), + ) + require.NoError(t, err) + + t.Cleanup(func() { + require.NoError(t, c.context.client.Functions.Drop(ctx, sdk.NewDropFunctionRequest(id).WithIfExists(true))) + }) + + function, err := c.client().ShowByID(ctx, id) + require.NoError(t, err) + + return function +} diff --git a/pkg/acceptance/helpers/procedure_client.go b/pkg/acceptance/helpers/procedure_client.go new file mode 100644 index 0000000000..d4df476e44 --- /dev/null +++ b/pkg/acceptance/helpers/procedure_client.go @@ -0,0 +1,55 @@ +package helpers + +import ( + "context" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" + "github.com/stretchr/testify/require" + "testing" +) + +type ProcedureClient struct { + context *TestClientContext + ids *IdsGenerator +} + +func NewProcedureClient(context *TestClientContext, idsGenerator *IdsGenerator) *ProcedureClient { + return &ProcedureClient{ + context: context, + ids: idsGenerator, + } +} + +func (c *ProcedureClient) client() sdk.Procedures { + return c.context.client.Procedures +} + +func (c *ProcedureClient) Create(t *testing.T, arguments ...sdk.DataType) *sdk.Procedure { + t.Helper() + return c.CreateWithIdentifier(t, c.ids.RandomSchemaObjectIdentifierWithArguments(arguments...)) +} + +func (c *ProcedureClient) CreateWithIdentifier(t *testing.T, id sdk.SchemaObjectIdentifierWithArguments) *sdk.Procedure { + t.Helper() + ctx := context.Background() + argumentRequests := make([]sdk.ProcedureArgumentRequest, len(id.ArgumentDataTypes())) + for i, argumentDataType := range id.ArgumentDataTypes() { + argumentRequests[i] = *sdk.NewProcedureArgumentRequest(c.ids.Alpha(), argumentDataType) + } + err := c.client().CreateForSQL(ctx, + sdk.NewCreateForSQLProcedureRequest( + id.SchemaObjectId(), + *sdk.NewProcedureSQLReturnsRequest().WithResultDataType(*sdk.NewProcedureReturnsResultDataTypeRequest(sdk.DataTypeInt)), + "SELECT 1", + ).WithArguments(argumentRequests), + ) + require.NoError(t, err) + + t.Cleanup(func() { + require.NoError(t, c.context.client.Procedures.Drop(ctx, sdk.NewDropProcedureRequest(id).WithIfExists(true))) + }) + + procedure, err := c.client().ShowByID(ctx, id) + require.NoError(t, err) + + return procedure +} diff --git a/pkg/acceptance/helpers/test_client.go b/pkg/acceptance/helpers/test_client.go index e443b4fac1..2f68673466 100644 --- a/pkg/acceptance/helpers/test_client.go +++ b/pkg/acceptance/helpers/test_client.go @@ -24,9 +24,11 @@ type TestClient struct { DataMetricFunctionReferences *DataMetricFunctionReferencesClient DynamicTable *DynamicTableClient ExternalAccessIntegration *ExternalAccessIntegrationClient + ExternalFunction *ExternalFunctionClient ExternalVolume *ExternalVolumeClient FailoverGroup *FailoverGroupClient FileFormat *FileFormatClient + Function *FunctionClient Grant *GrantClient MaskingPolicy *MaskingPolicyClient MaterializedView *MaterializedViewClient @@ -35,6 +37,7 @@ type TestClient struct { Parameter *ParameterClient PasswordPolicy *PasswordPolicyClient Pipe *PipeClient + Procedure *ProcedureClient ProjectionPolicy *ProjectionPolicyClient PolicyReferences *PolicyReferencesClient ResourceMonitor *ResourceMonitorClient @@ -83,9 +86,11 @@ func NewTestClient(c *sdk.Client, database string, schema string, warehouse stri DataMetricFunctionReferences: NewDataMetricFunctionReferencesClient(context), DynamicTable: NewDynamicTableClient(context, idsGenerator), ExternalAccessIntegration: NewExternalAccessIntegrationClient(context, idsGenerator), + ExternalFunction: NewExternalFunctionClient(context, idsGenerator), ExternalVolume: NewExternalVolumeClient(context, idsGenerator), FailoverGroup: NewFailoverGroupClient(context, idsGenerator), FileFormat: NewFileFormatClient(context, idsGenerator), + Function: NewFunctionClient(context, idsGenerator), Grant: NewGrantClient(context, idsGenerator), MaskingPolicy: NewMaskingPolicyClient(context, idsGenerator), MaterializedView: NewMaterializedViewClient(context, idsGenerator), @@ -94,6 +99,7 @@ func NewTestClient(c *sdk.Client, database string, schema string, warehouse stri Parameter: NewParameterClient(context), PasswordPolicy: NewPasswordPolicyClient(context, idsGenerator), Pipe: NewPipeClient(context, idsGenerator), + Procedure: NewProcedureClient(context, idsGenerator), ProjectionPolicy: NewProjectionPolicyClient(context, idsGenerator), PolicyReferences: NewPolicyReferencesClient(context), ResourceMonitor: NewResourceMonitorClient(context, idsGenerator), diff --git a/pkg/resources/procedure_acceptance_test.go b/pkg/resources/procedure_acceptance_test.go index 782a14813e..16f1a677fe 100644 --- a/pkg/resources/procedure_acceptance_test.go +++ b/pkg/resources/procedure_acceptance_test.go @@ -508,3 +508,63 @@ resource "snowflake_procedure" "p" { } `, database, schema, name) } + +func TestAcc_Procedure_EnsureSmoothResourceIdMigrationToV0950_ArgumentSynonyms(t *testing.T) { + name := acc.TestClient().Ids.RandomAccountObjectIdentifier().Name() + resourceName := "snowflake_procedure.p" + + resource.Test(t, resource.TestCase{ + PreCheck: func() { acc.TestAccPreCheck(t) }, + TerraformVersionChecks: []tfversion.TerraformVersionCheck{ + tfversion.RequireAbove(tfversion.Version1_5_0), + }, + CheckDestroy: acc.CheckDestroy(t, resources.Procedure), + Steps: []resource.TestStep{ + { + ExternalProviders: map[string]resource.ExternalProvider{ + "snowflake": { + VersionConstraint: "=0.94.1", + Source: "Snowflake-Labs/snowflake", + }, + }, + Config: procedureConfigWithArgumentSynonyms(acc.TestDatabaseName, acc.TestSchemaName, name), + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr(resourceName, "id", fmt.Sprintf(`"%s"."%s"."%s"(NUMBER, VARCHAR)`, acc.TestDatabaseName, acc.TestSchemaName, name)), + ), + }, + { + ProtoV6ProviderFactories: acc.TestAccProtoV6ProviderFactories, + Config: procedureConfigWithArgumentSynonyms(acc.TestDatabaseName, acc.TestSchemaName, name), + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr(resourceName, "id", fmt.Sprintf(`"%s"."%s"."%s"(NUMBER, VARCHAR)`, acc.TestDatabaseName, acc.TestSchemaName, name)), + ), + }, + }, + }) +} + +func procedureConfigWithArgumentSynonyms(database string, schema string, name string) string { + return fmt.Sprintf(` +resource "snowflake_procedure" "p" { + database = "%[1]s" + schema = "%[2]s" + name = "%[3]s" + return_type = "VARCHAR" + return_behavior = "IMMUTABLE" + statement = < Date: Wed, 14 Aug 2024 13:30:01 +0200 Subject: [PATCH 09/19] fix tests --- pkg/acceptance/helpers/external_function_client.go | 3 ++- pkg/acceptance/helpers/function_client.go | 3 ++- pkg/acceptance/helpers/procedure_client.go | 6 +++--- pkg/sdk/testint/external_functions_integration_test.go | 3 +-- 4 files changed, 8 insertions(+), 7 deletions(-) diff --git a/pkg/acceptance/helpers/external_function_client.go b/pkg/acceptance/helpers/external_function_client.go index 8e66ba076b..5ddec73441 100644 --- a/pkg/acceptance/helpers/external_function_client.go +++ b/pkg/acceptance/helpers/external_function_client.go @@ -2,9 +2,10 @@ package helpers import ( "context" + "testing" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" "github.com/stretchr/testify/require" - "testing" ) type ExternalFunctionClient struct { diff --git a/pkg/acceptance/helpers/function_client.go b/pkg/acceptance/helpers/function_client.go index f2f058d6ab..cbfb4da00b 100644 --- a/pkg/acceptance/helpers/function_client.go +++ b/pkg/acceptance/helpers/function_client.go @@ -2,9 +2,10 @@ package helpers import ( "context" + "testing" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" "github.com/stretchr/testify/require" - "testing" ) type FunctionClient struct { diff --git a/pkg/acceptance/helpers/procedure_client.go b/pkg/acceptance/helpers/procedure_client.go index d4df476e44..e9a4375f2d 100644 --- a/pkg/acceptance/helpers/procedure_client.go +++ b/pkg/acceptance/helpers/procedure_client.go @@ -2,9 +2,10 @@ package helpers import ( "context" + "testing" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" "github.com/stretchr/testify/require" - "testing" ) type ProcedureClient struct { @@ -39,8 +40,7 @@ func (c *ProcedureClient) CreateWithIdentifier(t *testing.T, id sdk.SchemaObject sdk.NewCreateForSQLProcedureRequest( id.SchemaObjectId(), *sdk.NewProcedureSQLReturnsRequest().WithResultDataType(*sdk.NewProcedureReturnsResultDataTypeRequest(sdk.DataTypeInt)), - "SELECT 1", - ).WithArguments(argumentRequests), + `BEGIN RETURN 1; END`).WithArguments(argumentRequests), ) require.NoError(t, err) diff --git a/pkg/sdk/testint/external_functions_integration_test.go b/pkg/sdk/testint/external_functions_integration_test.go index 0fe311ee8f..1a14c84185 100644 --- a/pkg/sdk/testint/external_functions_integration_test.go +++ b/pkg/sdk/testint/external_functions_integration_test.go @@ -2,7 +2,6 @@ package testint import ( "context" - "fmt" "testing" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" @@ -51,7 +50,7 @@ func TestInt_ExternalFunctions(t *testing.T) { require.NotEmpty(t, e.CreatedOn) require.Equal(t, id.Name(), e.Name) - require.Equal(t, fmt.Sprintf(`"%v"`, id.SchemaName()), e.SchemaName) + require.Equal(t, id.SchemaName(), e.SchemaName) require.Equal(t, false, e.IsBuiltin) require.Equal(t, false, e.IsAggregate) require.Equal(t, false, e.IsAnsi) From b26d5fce58069e4a70b82d659243efdf673d4749 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20Cie=C5=9Blak?= Date: Fri, 16 Aug 2024 12:07:11 +0200 Subject: [PATCH 10/19] Changes after review --- pkg/sdk/testint/external_functions_integration_test.go | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/pkg/sdk/testint/external_functions_integration_test.go b/pkg/sdk/testint/external_functions_integration_test.go index 1a14c84185..3d23fd6d6d 100644 --- a/pkg/sdk/testint/external_functions_integration_test.go +++ b/pkg/sdk/testint/external_functions_integration_test.go @@ -49,17 +49,14 @@ func TestInt_ExternalFunctions(t *testing.T) { require.NoError(t, err) require.NotEmpty(t, e.CreatedOn) - require.Equal(t, id.Name(), e.Name) - require.Equal(t, id.SchemaName(), e.SchemaName) + require.Equal(t, id, e.ID()) require.Equal(t, false, e.IsBuiltin) require.Equal(t, false, e.IsAggregate) require.Equal(t, false, e.IsAnsi) if len(id.ArgumentDataTypes()) > 0 { - require.NotEmpty(t, e.Arguments) require.Equal(t, 1, e.MinNumArguments) require.Equal(t, 1, e.MaxNumArguments) } else { - require.Empty(t, e.Arguments) require.Equal(t, 0, e.MinNumArguments) require.Equal(t, 0, e.MaxNumArguments) } From ceae043065b3f1155c2fcb98c5dc73b494b2d6b5 Mon Sep 17 00:00:00 2001 From: Jakub Michalak Date: Fri, 16 Aug 2024 14:48:33 +0200 Subject: [PATCH 11/19] Use new function identifiers in grants --- pkg/acceptance/helpers/function_client.go | 49 ++++++++ pkg/acceptance/helpers/share_client.go | 1 - pkg/resources/grant_ownership.go | 25 ++-- .../grant_ownership_acceptance_test.go | 102 +++++++++++++++ .../grant_ownership_identifier_test.go | 28 +++++ .../grant_privileges_to_account_role.go | 76 +++++++++--- ...vileges_to_account_role_acceptance_test.go | 116 +++++++++++++++++- ...t_privileges_to_account_role_identifier.go | 16 ++- ...vileges_to_account_role_identifier_test.go | 34 +++++ .../grant_privileges_to_database_role.go | 67 ++++++---- ...ileges_to_database_role_acceptance_test.go | 110 +++++++++++++++++ ..._privileges_to_database_role_identifier.go | 16 ++- ...ileges_to_database_role_identifier_test.go | 34 +++++ pkg/resources/grant_privileges_to_share.go | 80 +++++++----- ...ant_privileges_to_share_acceptance_test.go | 99 ++++++++++++++- .../grant_privileges_to_share_identifier.go | 56 ++++----- ...ant_privileges_to_share_identifier_test.go | 50 ++++---- .../OnObject_Procedure_ToAccountRole/test.tf | 39 ++++++ .../variables.tf | 15 +++ .../OnObject_Procedure_ToDatabaseRole/test.tf | 36 ++++++ .../variables.tf | 15 +++ .../OnSchemaObject_OnFunction/test.tf | 10 ++ .../OnSchemaObject_OnFunction/variables.tf | 27 ++++ .../OnSchemaObject_OnFunction/test.tf | 10 ++ .../OnSchemaObject_OnFunction/variables.tf | 27 ++++ .../OnFunction/test.tf | 13 ++ .../OnFunction/variables.tf | 23 ++++ pkg/sdk/grants.go | 36 +++--- pkg/sdk/identifier_parsers.go | 35 +++++- pkg/sdk/identifier_parsers_test.go | 28 +++++ pkg/sdk/testint/grants_integration_test.go | 37 +++++- v1-preparations/ESSENTIAL_GA_OBJECTS.MD | 6 +- 32 files changed, 1134 insertions(+), 182 deletions(-) create mode 100644 pkg/resources/testdata/TestAcc_GrantOwnership/OnObject_Procedure_ToAccountRole/test.tf create mode 100644 pkg/resources/testdata/TestAcc_GrantOwnership/OnObject_Procedure_ToAccountRole/variables.tf create mode 100644 pkg/resources/testdata/TestAcc_GrantOwnership/OnObject_Procedure_ToDatabaseRole/test.tf create mode 100644 pkg/resources/testdata/TestAcc_GrantOwnership/OnObject_Procedure_ToDatabaseRole/variables.tf create mode 100644 pkg/resources/testdata/TestAcc_GrantPrivilegesToAccountRole/OnSchemaObject_OnFunction/test.tf create mode 100644 pkg/resources/testdata/TestAcc_GrantPrivilegesToAccountRole/OnSchemaObject_OnFunction/variables.tf create mode 100644 pkg/resources/testdata/TestAcc_GrantPrivilegesToDatabaseRole/OnSchemaObject_OnFunction/test.tf create mode 100644 pkg/resources/testdata/TestAcc_GrantPrivilegesToDatabaseRole/OnSchemaObject_OnFunction/variables.tf create mode 100644 pkg/resources/testdata/TestAcc_GrantPrivilegesToShare/OnFunction/test.tf create mode 100644 pkg/resources/testdata/TestAcc_GrantPrivilegesToShare/OnFunction/variables.tf diff --git a/pkg/acceptance/helpers/function_client.go b/pkg/acceptance/helpers/function_client.go index cbfb4da00b..742640d1d5 100644 --- a/pkg/acceptance/helpers/function_client.go +++ b/pkg/acceptance/helpers/function_client.go @@ -54,3 +54,52 @@ func (c *FunctionClient) CreateWithIdentifier(t *testing.T, id sdk.SchemaObjectI return function } + +func (c *FunctionClient) CreateFunction(t *testing.T) (*sdk.Function, func()) { + t.Helper() + definition := "3.141592654::FLOAT" + id := c.ids.RandomSchemaObjectIdentifierWithArguments(sdk.DataTypeFloat) + dt := sdk.NewFunctionReturnsResultDataTypeRequest(sdk.DataTypeFloat) + returns := sdk.NewFunctionReturnsRequest().WithResultDataType(*dt) + argument := sdk.NewFunctionArgumentRequest("x", sdk.DataTypeFloat) + return c.CreateFunctionWithRequest(t, id, + sdk.NewCreateForSQLFunctionRequest(id.SchemaObjectId(), *returns, definition). + WithSecure(true). + WithArguments([]sdk.FunctionArgumentRequest{*argument}), + ) +} + +func (c *FunctionClient) CreateFunctionWithoutArguments(t *testing.T) (*sdk.Function, func()) { + t.Helper() + definition := "3.141592654::FLOAT" + id := c.ids.RandomSchemaObjectIdentifierWithArguments() + dt := sdk.NewFunctionReturnsResultDataTypeRequest(sdk.DataTypeFloat) + returns := sdk.NewFunctionReturnsRequest().WithResultDataType(*dt) + return c.CreateFunctionWithRequest(t, id, + sdk.NewCreateForSQLFunctionRequest(id.SchemaObjectId(), *returns, definition). + WithSecure(true), + ) +} + +func (c *FunctionClient) CreateFunctionWithRequest(t *testing.T, id sdk.SchemaObjectIdentifierWithArguments, request *sdk.CreateForSQLFunctionRequest) (*sdk.Function, func()) { + t.Helper() + ctx := context.Background() + + err := c.client().CreateForSQL(ctx, request) + require.NoError(t, err) + + Function, err := c.client().ShowByID(ctx, id) + require.NoError(t, err) + + return Function, c.DropFunctionFunc(t, id) +} + +func (c *FunctionClient) DropFunctionFunc(t *testing.T, id sdk.SchemaObjectIdentifierWithArguments) func() { + t.Helper() + ctx := context.Background() + + return func() { + err := c.client().Drop(ctx, sdk.NewDropFunctionRequest(id).WithIfExists(true)) + require.NoError(t, err) + } +} diff --git a/pkg/acceptance/helpers/share_client.go b/pkg/acceptance/helpers/share_client.go index e61a6f53ea..0b67ac6c1b 100644 --- a/pkg/acceptance/helpers/share_client.go +++ b/pkg/acceptance/helpers/share_client.go @@ -26,7 +26,6 @@ func (c *ShareClient) client() sdk.Shares { func (c *ShareClient) CreateShare(t *testing.T) (*sdk.Share, func()) { t.Helper() - // TODO(SNOW-1058419): Try with identifier containing dot during identifiers rework return c.CreateShareWithIdentifier(t, c.ids.RandomAccountObjectIdentifier()) } diff --git a/pkg/resources/grant_ownership.go b/pkg/resources/grant_ownership.go index ff7d9ecb29..37aea7768b 100644 --- a/pkg/resources/grant_ownership.go +++ b/pkg/resources/grant_ownership.go @@ -6,8 +6,6 @@ import ( "log" "strings" - "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/helpers" - "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/logging" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/provider" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" @@ -400,11 +398,6 @@ func ReadGrantOwnership(ctx context.Context, d *schema.ResourceData, meta any) d // TODO(SNOW-1229218): Make sdk.ObjectType + string objectName to sdk.ObjectIdentifier mapping available in the sdk (for all object types). func getOnObjectIdentifier(objectType sdk.ObjectType, objectName string) (sdk.ObjectIdentifier, error) { - identifier, err := helpers.DecodeSnowflakeParameterID(objectName) - if err != nil { - return nil, err - } - switch objectType { case sdk.ObjectTypeComputePool, sdk.ObjectTypeDatabase, @@ -416,12 +409,10 @@ func getOnObjectIdentifier(objectType sdk.ObjectType, objectName string) (sdk.Ob sdk.ObjectTypeRole, sdk.ObjectTypeUser, sdk.ObjectTypeWarehouse: - return sdk.NewAccountObjectIdentifier(objectName), nil + return sdk.ParseAccountObjectIdentifier(objectName) case sdk.ObjectTypeDatabaseRole, sdk.ObjectTypeSchema: - if _, ok := identifier.(sdk.DatabaseObjectIdentifier); !ok { - return nil, sdk.NewError(fmt.Sprintf("invalid object_name %s, expected database object identifier", objectName)) - } + return sdk.ParseDatabaseObjectIdentifier(objectName) case sdk.ObjectTypeAggregationPolicy, sdk.ObjectTypeAlert, sdk.ObjectTypeAuthenticationPolicy, @@ -430,7 +421,6 @@ func getOnObjectIdentifier(objectType sdk.ObjectType, objectName string) (sdk.Ob sdk.ObjectTypeEventTable, sdk.ObjectTypeExternalTable, sdk.ObjectTypeFileFormat, - sdk.ObjectTypeFunction, sdk.ObjectTypeGitRepository, sdk.ObjectTypeHybridTable, sdk.ObjectTypeIcebergTable, @@ -439,7 +429,6 @@ func getOnObjectIdentifier(objectType sdk.ObjectType, objectName string) (sdk.Ob sdk.ObjectTypeNetworkRule, sdk.ObjectTypePackagesPolicy, sdk.ObjectTypePipe, - sdk.ObjectTypeProcedure, sdk.ObjectTypeMaskingPolicy, sdk.ObjectTypePasswordPolicy, sdk.ObjectTypeProjectionPolicy, @@ -453,14 +442,14 @@ func getOnObjectIdentifier(objectType sdk.ObjectType, objectName string) (sdk.Ob sdk.ObjectTypeTag, sdk.ObjectTypeTask, sdk.ObjectTypeView: - if _, ok := identifier.(sdk.SchemaObjectIdentifier); !ok { - return nil, sdk.NewError(fmt.Sprintf("invalid object_name %s, expected schema object identifier", objectName)) - } + return sdk.ParseSchemaObjectIdentifier(objectName) + case sdk.ObjectTypeFunction, + sdk.ObjectTypeProcedure, + sdk.ObjectTypeExternalFunction: + return sdk.ParseSchemaObjectIdentifierWithArguments(objectName) default: return nil, sdk.NewError(fmt.Sprintf("object_type %s is not supported, please create a feature request for the provider if given object_type should be supported", objectType)) } - - return identifier, nil } func getOwnershipGrantOn(d *schema.ResourceData) (*sdk.OwnershipGrantOn, error) { diff --git a/pkg/resources/grant_ownership_acceptance_test.go b/pkg/resources/grant_ownership_acceptance_test.go index 2417174ddb..e3c3dac8a4 100644 --- a/pkg/resources/grant_ownership_acceptance_test.go +++ b/pkg/resources/grant_ownership_acceptance_test.go @@ -323,6 +323,108 @@ func TestAcc_GrantOwnership_OnObject_Table_ToDatabaseRole(t *testing.T) { }) } +func TestAcc_GrantOwnership_OnObject_Procedure_ToAccountRole(t *testing.T) { + databaseId := acc.TestClient().Ids.RandomAccountObjectIdentifier() + databaseName := databaseId.Name() + schemaId := acc.TestClient().Ids.RandomDatabaseObjectIdentifierInDatabase(databaseId) + schemaName := schemaId.Name() + procedureId := acc.TestClient().Ids.NewSchemaObjectIdentifierWithArgumentsInSchema(acc.TestClient().Ids.Alpha(), schemaId, sdk.DataTypeFloat) + accountRoleId := acc.TestClient().Ids.RandomAccountObjectIdentifier() + accountRoleName := accountRoleId.Name() + + configVariables := config.Variables{ + "account_role_name": config.StringVariable(accountRoleName), + "database_name": config.StringVariable(databaseName), + "schema_name": config.StringVariable(schemaName), + "procedure_name": config.StringVariable(procedureId.Name()), + } + resourceName := "snowflake_grant_ownership.test" + + resource.Test(t, resource.TestCase{ + ProtoV6ProviderFactories: acc.TestAccProtoV6ProviderFactories, + PreCheck: func() { acc.TestAccPreCheck(t) }, + TerraformVersionChecks: []tfversion.TerraformVersionCheck{ + tfversion.RequireAbove(tfversion.Version1_5_0), + }, + Steps: []resource.TestStep{ + { + ConfigDirectory: acc.ConfigurationDirectory("TestAcc_GrantOwnership/OnObject_Procedure_ToAccountRole"), + ConfigVariables: configVariables, + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr(resourceName, "account_role_name", accountRoleName), + resource.TestCheckResourceAttr(resourceName, "on.0.object_type", "PROCEDURE"), + resource.TestCheckResourceAttr(resourceName, "on.0.object_name", procedureId.FullyQualifiedName()), + resource.TestCheckResourceAttr(resourceName, "id", fmt.Sprintf("ToAccountRole|%s||OnObject|PROCEDURE|%s", accountRoleId.FullyQualifiedName(), procedureId.FullyQualifiedName())), + checkResourceOwnershipIsGranted(&sdk.ShowGrantOptions{ + To: &sdk.ShowGrantsTo{ + Role: accountRoleId, + }, + }, sdk.ObjectTypeProcedure, accountRoleName, procedureId.FullyQualifiedName()), + ), + }, + { + ConfigDirectory: acc.ConfigurationDirectory("TestAcc_GrantOwnership/OnObject_Procedure_ToAccountRole"), + ConfigVariables: configVariables, + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + }, + }, + }) +} + +func TestAcc_GrantOwnership_OnObject_ProcedureWithoutArguments_ToDatabaseRole(t *testing.T) { + databaseId := acc.TestClient().Ids.RandomAccountObjectIdentifier() + databaseName := databaseId.Name() + schemaId := acc.TestClient().Ids.RandomDatabaseObjectIdentifierInDatabase(databaseId) + schemaName := schemaId.Name() + procedureId := acc.TestClient().Ids.NewSchemaObjectIdentifierWithArgumentsInSchema(acc.TestClient().Ids.Alpha(), schemaId) + + databaseRoleId := acc.TestClient().Ids.RandomDatabaseObjectIdentifierInDatabase(databaseId) + databaseRoleName := databaseRoleId.Name() + databaseRoleFullyQualifiedName := databaseRoleId.FullyQualifiedName() + + configVariables := config.Variables{ + "database_role_name": config.StringVariable(databaseRoleName), + "database_name": config.StringVariable(databaseName), + "schema_name": config.StringVariable(schemaName), + "procedure_name": config.StringVariable(procedureId.Name()), + } + resourceName := "snowflake_grant_ownership.test" + + resource.Test(t, resource.TestCase{ + ProtoV6ProviderFactories: acc.TestAccProtoV6ProviderFactories, + PreCheck: func() { acc.TestAccPreCheck(t) }, + TerraformVersionChecks: []tfversion.TerraformVersionCheck{ + tfversion.RequireAbove(tfversion.Version1_5_0), + }, + Steps: []resource.TestStep{ + { + ConfigDirectory: acc.ConfigurationDirectory("TestAcc_GrantOwnership/OnObject_Procedure_ToDatabaseRole"), + ConfigVariables: configVariables, + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr(resourceName, "database_role_name", databaseRoleFullyQualifiedName), + resource.TestCheckResourceAttr(resourceName, "on.0.object_type", "PROCEDURE"), + resource.TestCheckResourceAttr(resourceName, "on.0.object_name", procedureId.FullyQualifiedName()), + resource.TestCheckResourceAttr(resourceName, "id", fmt.Sprintf("ToDatabaseRole|%s||OnObject|PROCEDURE|%s", databaseRoleFullyQualifiedName, procedureId.FullyQualifiedName())), + checkResourceOwnershipIsGranted(&sdk.ShowGrantOptions{ + To: &sdk.ShowGrantsTo{ + DatabaseRole: databaseRoleId, + }, + }, sdk.ObjectTypeProcedure, databaseRoleName, procedureId.FullyQualifiedName()), + ), + }, + { + ConfigDirectory: acc.ConfigurationDirectory("TestAcc_GrantOwnership/OnObject_Procedure_ToDatabaseRole"), + ConfigVariables: configVariables, + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + }, + }, + }) +} + func TestAcc_GrantOwnership_OnAll_InDatabase_ToAccountRole(t *testing.T) { databaseId := acc.TestClient().Ids.RandomAccountObjectIdentifier() accountRoleId := acc.TestClient().Ids.RandomAccountObjectIdentifier() diff --git a/pkg/resources/grant_ownership_identifier_test.go b/pkg/resources/grant_ownership_identifier_test.go index e0345209ef..54ea170adc 100644 --- a/pkg/resources/grant_ownership_identifier_test.go +++ b/pkg/resources/grant_ownership_identifier_test.go @@ -130,6 +130,34 @@ func TestParseGrantOwnershipId(t *testing.T) { }, }, }, + { + Name: "grant ownership on function to account role", + Identifier: `ToAccountRole|"account-role"|COPY|OnObject|FUNCTION|"database-name"."schema-name"."function-name"(FLOAT)`, + Expected: GrantOwnershipId{ + GrantOwnershipTargetRoleKind: ToAccountGrantOwnershipTargetRoleKind, + AccountRoleName: sdk.NewAccountObjectIdentifier("account-role"), + OutboundPrivilegesBehavior: sdk.Pointer(CopyOutboundPrivilegesBehavior), + Kind: OnObjectGrantOwnershipKind, + Data: &OnObjectGrantOwnershipData{ + ObjectType: sdk.ObjectTypeFunction, + ObjectName: sdk.NewSchemaObjectIdentifierWithArguments("database-name", "schema-name", "function-name", sdk.DataTypeFloat), + }, + }, + }, + { + Name: "grant ownership on function without arguments to database role", + Identifier: `ToDatabaseRole|"database-name"."database-role"|REVOKE|OnObject|FUNCTION|"database-name"."schema-name"."function-name"()`, + Expected: GrantOwnershipId{ + GrantOwnershipTargetRoleKind: ToDatabaseGrantOwnershipTargetRoleKind, + DatabaseRoleName: sdk.NewDatabaseObjectIdentifier("database-name", "database-role"), + OutboundPrivilegesBehavior: sdk.Pointer(RevokeOutboundPrivilegesBehavior), + Kind: OnObjectGrantOwnershipKind, + Data: &OnObjectGrantOwnershipData{ + ObjectType: sdk.ObjectTypeFunction, + ObjectName: sdk.NewSchemaObjectIdentifierWithArguments("database-name", "schema-name", "function-name", []sdk.DataType{}...), + }, + }, + }, { Name: "validation: not enough parts", Identifier: `ToDatabaseRole|"database-name"."role-name"|`, diff --git a/pkg/resources/grant_privileges_to_account_role.go b/pkg/resources/grant_privileges_to_account_role.go index 56fab09757..f014746ccb 100644 --- a/pkg/resources/grant_privileges_to_account_role.go +++ b/pkg/resources/grant_privileges_to_account_role.go @@ -220,7 +220,6 @@ var grantPrivilegesToAccountRoleSchema = map[string]*schema.Schema{ "on_schema_object.0.all", "on_schema_object.0.future", }, - ValidateDiagFunc: IsValidIdentifier[sdk.SchemaObjectIdentifier](), }, "all": { Type: schema.TypeList, @@ -405,13 +404,20 @@ func CreateGrantPrivilegesToAccountRole(ctx context.Context, d *schema.ResourceD logging.DebugLogger.Printf("[DEBUG] Entering create grant privileges to account role") client := meta.(*provider.Context).Client - id := createGrantPrivilegesToAccountRoleIdFromSchema(d) + id, err := createGrantPrivilegesToAccountRoleIdFromSchema(d) + if err != nil { + return diag.FromErr(err) + } logging.DebugLogger.Printf("[DEBUG] created identifier from schema: %s", id.String()) - err := client.Grants.GrantPrivilegesToAccountRole( + grantOn, err := getAccountRoleGrantOn(d) + if err != nil { + return diag.FromErr(err) + } + err = client.Grants.GrantPrivilegesToAccountRole( ctx, getAccountRolePrivilegesFromSchema(d), - getAccountRoleGrantOn(d), + grantOn, sdk.NewAccountObjectIdentifierFromFullyQualifiedName(d.Get("account_role_name").(string)), &sdk.GrantPrivilegesToAccountRoleOptions{ WithGrantOption: sdk.Bool(d.Get("with_grant_option").(bool)), @@ -456,17 +462,20 @@ func UpdateGrantPrivilegesToAccountRole(ctx context.Context, d *schema.ResourceD // handle all_privileges -> privileges change (revoke all privileges) if d.HasChange("all_privileges") { _, allPrivileges := d.GetChange("all_privileges") + grantOn, err := getAccountRoleGrantOn(d) + if err != nil { + return diag.FromErr(err) + } if !allPrivileges.(bool) { logging.DebugLogger.Printf("[DEBUG] Revoking all privileges") err = client.Grants.RevokePrivilegesFromAccountRole(ctx, &sdk.AccountRoleGrantPrivileges{ AllPrivileges: sdk.Bool(true), }, - getAccountRoleGrantOn(d), + grantOn, id.RoleName, new(sdk.RevokePrivilegesFromAccountRoleOptions), ) - if err != nil { return diag.Diagnostics{ diag.Diagnostic{ @@ -513,7 +522,10 @@ func UpdateGrantPrivilegesToAccountRole(ctx context.Context, d *schema.ResourceD } } - grantOn := getAccountRoleGrantOn(d) + grantOn, err := getAccountRoleGrantOn(d) + if err != nil { + return diag.FromErr(err) + } if len(privilegesToAdd) > 0 { logging.DebugLogger.Printf("[DEBUG] Granting privileges: %v", privilegesToAdd) @@ -589,14 +601,17 @@ func UpdateGrantPrivilegesToAccountRole(ctx context.Context, d *schema.ResourceD if allPrivileges.(bool) { logging.DebugLogger.Printf("[DEBUG] Granting all privileges") + grantOn, err := getAccountRoleGrantOn(d) + if err != nil { + return diag.FromErr(err) + } err = client.Grants.GrantPrivilegesToAccountRole(ctx, &sdk.AccountRoleGrantPrivileges{ AllPrivileges: sdk.Bool(true), }, - getAccountRoleGrantOn(d), + grantOn, id.RoleName, new(sdk.GrantPrivilegesToAccountRoleOptions), ) - if err != nil { return diag.Diagnostics{ diag.Diagnostic{ @@ -617,10 +632,14 @@ func UpdateGrantPrivilegesToAccountRole(ctx context.Context, d *schema.ResourceD if id.AlwaysApply { logging.DebugLogger.Printf("[DEBUG] Performing always_apply re-grant") - err := client.Grants.GrantPrivilegesToAccountRole( + grantOn, err := getAccountRoleGrantOn(d) + if err != nil { + return diag.FromErr(err) + } + err = client.Grants.GrantPrivilegesToAccountRole( ctx, getAccountRolePrivilegesFromSchema(d), - getAccountRoleGrantOn(d), + grantOn, id.RoleName, &sdk.GrantPrivilegesToAccountRoleOptions{ WithGrantOption: &id.WithGrantOption, @@ -659,10 +678,15 @@ func DeleteGrantPrivilegesToAccountRole(ctx context.Context, d *schema.ResourceD } logging.DebugLogger.Printf("[DEBUG] Parsed identifier: %s", id.String()) + grantOn, err := getAccountRoleGrantOn(d) + if err != nil { + return diag.FromErr(err) + } + err = client.Grants.RevokePrivilegesFromAccountRole( ctx, getAccountRolePrivilegesFromSchema(d), - getAccountRoleGrantOn(d), + grantOn, id.RoleName, &sdk.RevokePrivilegesFromAccountRoleOptions{}, ) @@ -966,7 +990,7 @@ func getAccountRolePrivileges(allPrivileges bool, privileges []string, onAccount return accountRoleGrantPrivileges } -func getAccountRoleGrantOn(d *schema.ResourceData) *sdk.AccountRoleGrantOn { +func getAccountRoleGrantOn(d *schema.ResourceData) (*sdk.AccountRoleGrantOn, error) { _, onAccountOk := d.GetOk("on_account") onAccountObjectBlock, onAccountObjectOk := d.GetOk("on_account_object") onSchemaBlock, onSchemaOk := d.GetOk("on_schema") @@ -1050,9 +1074,20 @@ func getAccountRoleGrantOn(d *schema.ResourceData) *sdk.AccountRoleGrantOn { switch { case objectTypeOk && objectNameOk: + objectType := sdk.ObjectType(objectType) + var id sdk.ObjectIdentifier + if slices.Contains([]sdk.ObjectType{sdk.ObjectTypeFunction, sdk.ObjectTypeProcedure, sdk.ObjectTypeExternalFunction}, objectType) { + var err error + id, err = sdk.ParseSchemaObjectIdentifierWithArguments(objectName) + if err != nil { + return nil, err + } + } else { + id = sdk.NewSchemaObjectIdentifierFromFullyQualifiedName(objectName) + } grantOnSchemaObject.SchemaObject = &sdk.Object{ - ObjectType: sdk.ObjectType(objectType), - Name: sdk.NewSchemaObjectIdentifierFromFullyQualifiedName(objectName), + ObjectType: objectType, + Name: id, } case allOk: grantOnSchemaObject.All = getGrantOnSchemaObjectIn(all[0].(map[string]any)) @@ -1063,10 +1098,10 @@ func getAccountRoleGrantOn(d *schema.ResourceData) *sdk.AccountRoleGrantOn { on.SchemaObject = grantOnSchemaObject } - return on + return on, nil } -func createGrantPrivilegesToAccountRoleIdFromSchema(d *schema.ResourceData) *GrantPrivilegesToAccountRoleId { +func createGrantPrivilegesToAccountRoleIdFromSchema(d *schema.ResourceData) (*GrantPrivilegesToAccountRoleId, error) { id := new(GrantPrivilegesToAccountRoleId) id.RoleName = sdk.NewAccountObjectIdentifierFromFullyQualifiedName(d.Get("account_role_name").(string)) id.AllPrivileges = d.Get("all_privileges").(bool) @@ -1076,7 +1111,10 @@ func createGrantPrivilegesToAccountRoleIdFromSchema(d *schema.ResourceData) *Gra id.WithGrantOption = d.Get("with_grant_option").(bool) id.AlwaysApply = d.Get("always_apply").(bool) - on := getAccountRoleGrantOn(d) + on, err := getAccountRoleGrantOn(d) + if err != nil { + return nil, err + } switch { case on.Account != nil: id.Kind = OnAccountAccountRoleGrantKind @@ -1149,5 +1187,5 @@ func createGrantPrivilegesToAccountRoleIdFromSchema(d *schema.ResourceData) *Gra id.Data = onSchemaObjectGrantData } - return id + return id, nil } diff --git a/pkg/resources/grant_privileges_to_account_role_acceptance_test.go b/pkg/resources/grant_privileges_to_account_role_acceptance_test.go index 2699b0744d..5f8a21e782 100644 --- a/pkg/resources/grant_privileges_to_account_role_acceptance_test.go +++ b/pkg/resources/grant_privileges_to_account_role_acceptance_test.go @@ -445,6 +445,118 @@ func TestAcc_GrantPrivilegesToAccountRole_OnSchemaObject_OnObject(t *testing.T) }) } +func TestAcc_GrantPrivilegesToAccountRole_OnSchemaObject_OnFunction(t *testing.T) { + acc.TestAccPreCheck(t) + + roleId := acc.TestClient().Ids.RandomAccountObjectIdentifier() + roleFullyQualifiedName := roleId.FullyQualifiedName() + function, functionCleanup := acc.TestClient().Function.CreateFunction(t) + t.Cleanup(functionCleanup) + configVariables := config.Variables{ + "name": config.StringVariable(roleFullyQualifiedName), + "function_name": config.StringVariable(function.ID().Name()), + "privileges": config.ListVariable( + config.StringVariable(string(sdk.SchemaObjectPrivilegeUsage)), + ), + "database": config.StringVariable(acc.TestDatabaseName), + "schema": config.StringVariable(acc.TestSchemaName), + "with_grant_option": config.BoolVariable(false), + "argument_type": config.StringVariable(string(sdk.DataTypeFloat)), + } + resourceName := "snowflake_grant_privileges_to_account_role.test" + + resource.Test(t, resource.TestCase{ + ProtoV6ProviderFactories: acc.TestAccProtoV6ProviderFactories, + TerraformVersionChecks: []tfversion.TerraformVersionCheck{ + tfversion.RequireAbove(tfversion.Version1_5_0), + }, + CheckDestroy: acc.CheckAccountRolePrivilegesRevoked(t), + Steps: []resource.TestStep{ + { + PreConfig: func() { + _, roleCleanup := acc.TestClient().Role.CreateRoleWithIdentifier(t, roleId) + t.Cleanup(roleCleanup) + }, + ConfigDirectory: acc.ConfigurationDirectory("TestAcc_GrantPrivilegesToAccountRole/OnSchemaObject_OnFunction"), + ConfigVariables: configVariables, + Check: resource.ComposeAggregateTestCheckFunc( + resource.TestCheckResourceAttr(resourceName, "account_role_name", roleFullyQualifiedName), + resource.TestCheckResourceAttr(resourceName, "privileges.#", "1"), + resource.TestCheckResourceAttr(resourceName, "privileges.0", string(sdk.SchemaObjectPrivilegeUsage)), + resource.TestCheckResourceAttr(resourceName, "on_schema_object.#", "1"), + resource.TestCheckResourceAttr(resourceName, "on_schema_object.0.object_type", string(sdk.ObjectTypeFunction)), + resource.TestCheckResourceAttr(resourceName, "on_schema_object.0.object_name", function.ID().FullyQualifiedName()), + resource.TestCheckResourceAttr(resourceName, "with_grant_option", "false"), + resource.TestCheckResourceAttr(resourceName, "id", fmt.Sprintf("%s|false|false|USAGE|OnSchemaObject|OnObject|FUNCTION|%s", roleFullyQualifiedName, function.ID().FullyQualifiedName())), + ), + }, + { + ConfigDirectory: acc.ConfigurationDirectory("TestAcc_GrantPrivilegesToAccountRole/OnSchemaObject_OnFunction"), + ConfigVariables: configVariables, + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + }, + }, + }) +} + +func TestAcc_GrantPrivilegesToAccountRole_OnSchemaObject_OnFunctionWithoutArguments(t *testing.T) { + acc.TestAccPreCheck(t) + + roleId := acc.TestClient().Ids.RandomAccountObjectIdentifier() + roleFullyQualifiedName := roleId.FullyQualifiedName() + function, functionCleanup := acc.TestClient().Function.CreateFunctionWithoutArguments(t) + t.Cleanup(functionCleanup) + configVariables := config.Variables{ + "name": config.StringVariable(roleFullyQualifiedName), + "function_name": config.StringVariable(function.ID().Name()), + "privileges": config.ListVariable( + config.StringVariable(string(sdk.SchemaObjectPrivilegeUsage)), + ), + "database": config.StringVariable(acc.TestDatabaseName), + "schema": config.StringVariable(acc.TestSchemaName), + "with_grant_option": config.BoolVariable(false), + "argument_type": config.StringVariable(""), + } + resourceName := "snowflake_grant_privileges_to_account_role.test" + + resource.Test(t, resource.TestCase{ + ProtoV6ProviderFactories: acc.TestAccProtoV6ProviderFactories, + TerraformVersionChecks: []tfversion.TerraformVersionCheck{ + tfversion.RequireAbove(tfversion.Version1_5_0), + }, + CheckDestroy: acc.CheckAccountRolePrivilegesRevoked(t), + Steps: []resource.TestStep{ + { + PreConfig: func() { + _, roleCleanup := acc.TestClient().Role.CreateRoleWithIdentifier(t, roleId) + t.Cleanup(roleCleanup) + }, + ConfigDirectory: acc.ConfigurationDirectory("TestAcc_GrantPrivilegesToAccountRole/OnSchemaObject_OnFunction"), + ConfigVariables: configVariables, + Check: resource.ComposeAggregateTestCheckFunc( + resource.TestCheckResourceAttr(resourceName, "account_role_name", roleFullyQualifiedName), + resource.TestCheckResourceAttr(resourceName, "privileges.#", "1"), + resource.TestCheckResourceAttr(resourceName, "privileges.0", string(sdk.SchemaObjectPrivilegeUsage)), + resource.TestCheckResourceAttr(resourceName, "on_schema_object.#", "1"), + resource.TestCheckResourceAttr(resourceName, "on_schema_object.0.object_type", string(sdk.ObjectTypeFunction)), + resource.TestCheckResourceAttr(resourceName, "on_schema_object.0.object_name", function.ID().FullyQualifiedName()), + resource.TestCheckResourceAttr(resourceName, "with_grant_option", "false"), + resource.TestCheckResourceAttr(resourceName, "id", fmt.Sprintf("%s|false|false|USAGE|OnSchemaObject|OnObject|FUNCTION|%s", roleFullyQualifiedName, function.ID().FullyQualifiedName())), + ), + }, + { + ConfigDirectory: acc.ConfigurationDirectory("TestAcc_GrantPrivilegesToAccountRole/OnSchemaObject_OnFunction"), + ConfigVariables: configVariables, + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + }, + }, + }) +} + func TestAcc_GrantPrivilegesToAccountRole_OnSchemaObject_OnObject_OwnershipPrivilege(t *testing.T) { roleId := acc.TestClient().Ids.RandomAccountObjectIdentifier() name := roleId.Name() @@ -1588,8 +1700,8 @@ func createExternalVolume(t *testing.T, externalVolumeName string) func() { ctx := context.Background() _, err := client.ExecForTests(ctx, fmt.Sprintf(`create external volume "%s" storage_locations = ( ( - name = 'test' - storage_provider = 's3' + name = 'test' + storage_provider = 's3' storage_base_url = 's3://my_example_bucket/' storage_aws_role_arn = 'arn:aws:iam::123456789012:role/myrole' encryption=(type='aws_sse_kms' kms_key_id='1234abcd-12ab-34cd-56ef-1234567890ab') diff --git a/pkg/resources/grant_privileges_to_account_role_identifier.go b/pkg/resources/grant_privileges_to_account_role_identifier.go index 024e18bc11..6612f5b3c1 100644 --- a/pkg/resources/grant_privileges_to_account_role_identifier.go +++ b/pkg/resources/grant_privileges_to_account_role_identifier.go @@ -2,6 +2,7 @@ package resources import ( "fmt" + "slices" "strconv" "strings" @@ -137,9 +138,20 @@ func ParseGrantPrivilegesToAccountRoleId(id string) (GrantPrivilegesToAccountRol if len(parts) != 8 { return accountRoleId, sdk.NewError(`account role identifier should hold 8 parts "||||OnSchemaObject|OnObject||"`) } + objectType := sdk.ObjectType(parts[6]) + var id sdk.ObjectIdentifier + if slices.Contains([]sdk.ObjectType{sdk.ObjectTypeFunction, sdk.ObjectTypeProcedure, sdk.ObjectTypeExternalFunction}, objectType) { + var err error + id, err = sdk.ParseSchemaObjectIdentifierWithArguments(parts[7]) + if err != nil { + return accountRoleId, err + } + } else { + id = sdk.NewSchemaObjectIdentifierFromFullyQualifiedName(parts[7]) + } onSchemaObjectGrantData.Object = &sdk.Object{ - ObjectType: sdk.ObjectType(parts[6]), - Name: sdk.NewSchemaObjectIdentifierFromFullyQualifiedName(parts[7]), + ObjectType: objectType, + Name: id, } case OnAllSchemaObjectGrantKind, OnFutureSchemaObjectGrantKind: bulkOperationGrantData := &BulkOperationGrantData{ diff --git a/pkg/resources/grant_privileges_to_account_role_identifier_test.go b/pkg/resources/grant_privileges_to_account_role_identifier_test.go index 41905d4705..8a047021d6 100644 --- a/pkg/resources/grant_privileges_to_account_role_identifier_test.go +++ b/pkg/resources/grant_privileges_to_account_role_identifier_test.go @@ -122,6 +122,40 @@ func TestParseGrantPrivilegesToAccountRoleId(t *testing.T) { }, }, }, + { + Name: "grant account role on function", + Identifier: `"account-role"|false|false|USAGE|OnSchemaObject|OnObject|FUNCTION|"database-name"."schema-name"."function-name"(FLOAT)`, + Expected: GrantPrivilegesToAccountRoleId{ + RoleName: sdk.NewAccountObjectIdentifier("account-role"), + WithGrantOption: false, + Privileges: []string{"USAGE"}, + Kind: OnSchemaObjectAccountRoleGrantKind, + Data: &OnSchemaObjectGrantData{ + Kind: OnObjectSchemaObjectGrantKind, + Object: &sdk.Object{ + ObjectType: sdk.ObjectTypeFunction, + Name: sdk.NewSchemaObjectIdentifierWithArguments("database-name", "schema-name", "function-name", sdk.DataTypeFloat), + }, + }, + }, + }, + { + Name: "grant account role on function without arguments", + Identifier: `"account-role"|false|false|USAGE|OnSchemaObject|OnObject|FUNCTION|"database-name"."schema-name"."function-name"()`, + Expected: GrantPrivilegesToAccountRoleId{ + RoleName: sdk.NewAccountObjectIdentifier("account-role"), + WithGrantOption: false, + Privileges: []string{"USAGE"}, + Kind: OnSchemaObjectAccountRoleGrantKind, + Data: &OnSchemaObjectGrantData{ + Kind: OnObjectSchemaObjectGrantKind, + Object: &sdk.Object{ + ObjectType: sdk.ObjectTypeFunction, + Name: sdk.NewSchemaObjectIdentifierWithArguments("database-name", "schema-name", "function-name", []sdk.DataType{}...), + }, + }, + }, + }, { Name: "grant account role on schema object with on all option", Identifier: `"account-role"|false|false|CREATE SCHEMA,USAGE,MONITOR|OnSchemaObject|OnAll|TABLES`, diff --git a/pkg/resources/grant_privileges_to_database_role.go b/pkg/resources/grant_privileges_to_database_role.go index 75b151f11c..07f1dd8788 100644 --- a/pkg/resources/grant_privileges_to_database_role.go +++ b/pkg/resources/grant_privileges_to_database_role.go @@ -172,7 +172,6 @@ var grantPrivilegesToDatabaseRoleSchema = map[string]*schema.Schema{ "on_schema_object.0.all", "on_schema_object.0.future", }, - ValidateDiagFunc: IsValidIdentifier[sdk.SchemaObjectIdentifier](), }, "all": { Type: schema.TypeList, @@ -342,11 +341,19 @@ func ImportGrantPrivilegesToDatabaseRole(ctx context.Context, d *schema.Resource func CreateGrantPrivilegesToDatabaseRole(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client - id := createGrantPrivilegesToDatabaseRoleIdFromSchema(d) - err := client.Grants.GrantPrivilegesToDatabaseRole( + id, err := createGrantPrivilegesToDatabaseRoleIdFromSchema(d) + if err != nil { + return diag.FromErr(err) + } + grantOn, err := getDatabaseRoleGrantOn(d) + if err != nil { + return diag.FromErr(err) + } + + err = client.Grants.GrantPrivilegesToDatabaseRole( ctx, getDatabaseRolePrivilegesFromSchema(d), - getDatabaseRoleGrantOn(d), + grantOn, sdk.NewDatabaseObjectIdentifierFromFullyQualifiedName(d.Get("database_role_name").(string)), &sdk.GrantPrivilegesToDatabaseRoleOptions{ WithGrantOption: sdk.Bool(d.Get("with_grant_option").(bool)), @@ -384,6 +391,11 @@ func UpdateGrantPrivilegesToDatabaseRole(ctx context.Context, d *schema.Resource id.WithGrantOption = d.Get("with_grant_option").(bool) } + grantOn, err := getDatabaseRoleGrantOn(d) + if err != nil { + return diag.FromErr(err) + } + // handle all_privileges -> privileges change (revoke all privileges) if d.HasChange("all_privileges") { _, allPrivileges := d.GetChange("all_privileges") @@ -392,11 +404,10 @@ func UpdateGrantPrivilegesToDatabaseRole(ctx context.Context, d *schema.Resource err = client.Grants.RevokePrivilegesFromDatabaseRole(ctx, &sdk.DatabaseRoleGrantPrivileges{ AllPrivileges: sdk.Bool(true), }, - getDatabaseRoleGrantOn(d), + grantOn, id.DatabaseRoleName, new(sdk.RevokePrivilegesFromDatabaseRoleOptions), ) - if err != nil { return diag.Diagnostics{ diag.Diagnostic{ @@ -441,8 +452,6 @@ func UpdateGrantPrivilegesToDatabaseRole(ctx context.Context, d *schema.Resource } } - grantOn := getDatabaseRoleGrantOn(d) - if len(privilegesToAdd) > 0 { privilegesToGrant := getDatabaseRolePrivileges( false, @@ -515,11 +524,10 @@ func UpdateGrantPrivilegesToDatabaseRole(ctx context.Context, d *schema.Resource err = client.Grants.GrantPrivilegesToDatabaseRole(ctx, &sdk.DatabaseRoleGrantPrivileges{ AllPrivileges: sdk.Bool(true), }, - getDatabaseRoleGrantOn(d), + grantOn, id.DatabaseRoleName, new(sdk.GrantPrivilegesToDatabaseRoleOptions), ) - if err != nil { return diag.Diagnostics{ diag.Diagnostic{ @@ -539,10 +547,10 @@ func UpdateGrantPrivilegesToDatabaseRole(ctx context.Context, d *schema.Resource } if id.AlwaysApply { - err := client.Grants.GrantPrivilegesToDatabaseRole( + err = client.Grants.GrantPrivilegesToDatabaseRole( ctx, getDatabaseRolePrivilegesFromSchema(d), - getDatabaseRoleGrantOn(d), + grantOn, id.DatabaseRoleName, &sdk.GrantPrivilegesToDatabaseRoleOptions{ WithGrantOption: &id.WithGrantOption, @@ -576,11 +584,14 @@ func DeleteGrantPrivilegesToDatabaseRole(ctx context.Context, d *schema.Resource }, } } - + grantOn, err := getDatabaseRoleGrantOn(d) + if err != nil { + return diag.FromErr(err) + } err = client.Grants.RevokePrivilegesFromDatabaseRole( ctx, getDatabaseRolePrivilegesFromSchema(d), - getDatabaseRoleGrantOn(d), + grantOn, id.DatabaseRoleName, &sdk.RevokePrivilegesFromDatabaseRoleOptions{}, ) @@ -840,7 +851,7 @@ func getDatabaseRolePrivileges(allPrivileges bool, privileges []string, onDataba return databaseRoleGrantPrivileges } -func getDatabaseRoleGrantOn(d *schema.ResourceData) *sdk.DatabaseRoleGrantOn { +func getDatabaseRoleGrantOn(d *schema.ResourceData) (*sdk.DatabaseRoleGrantOn, error) { onDatabase, onDatabaseOk := d.GetOk("on_database") onSchemaBlock, onSchemaOk := d.GetOk("on_schema") onSchemaObjectBlock, onSchemaObjectOk := d.GetOk("on_schema_object") @@ -892,9 +903,20 @@ func getDatabaseRoleGrantOn(d *schema.ResourceData) *sdk.DatabaseRoleGrantOn { switch { case objectTypeOk && objectNameOk: + objectType := sdk.ObjectType(objectType) + var id sdk.ObjectIdentifier + if slices.Contains([]sdk.ObjectType{sdk.ObjectTypeFunction, sdk.ObjectTypeProcedure, sdk.ObjectTypeExternalFunction}, objectType) { + var err error + id, err = sdk.ParseSchemaObjectIdentifierWithArguments(objectName) + if err != nil { + return nil, err + } + } else { + id = sdk.NewSchemaObjectIdentifierFromFullyQualifiedName(objectName) + } grantOnSchemaObject.SchemaObject = &sdk.Object{ - ObjectType: sdk.ObjectType(objectType), - Name: sdk.NewSchemaObjectIdentifierFromFullyQualifiedName(objectName), + ObjectType: objectType, + Name: id, } case allOk: grantOnSchemaObject.All = getGrantOnSchemaObjectIn(all[0].(map[string]any)) @@ -905,10 +927,10 @@ func getDatabaseRoleGrantOn(d *schema.ResourceData) *sdk.DatabaseRoleGrantOn { on.SchemaObject = grantOnSchemaObject } - return on + return on, nil } -func createGrantPrivilegesToDatabaseRoleIdFromSchema(d *schema.ResourceData) *GrantPrivilegesToDatabaseRoleId { +func createGrantPrivilegesToDatabaseRoleIdFromSchema(d *schema.ResourceData) (*GrantPrivilegesToDatabaseRoleId, error) { id := new(GrantPrivilegesToDatabaseRoleId) id.DatabaseRoleName = sdk.NewDatabaseObjectIdentifierFromFullyQualifiedName(d.Get("database_role_name").(string)) id.AllPrivileges = d.Get("all_privileges").(bool) @@ -918,7 +940,10 @@ func createGrantPrivilegesToDatabaseRoleIdFromSchema(d *schema.ResourceData) *Gr id.WithGrantOption = d.Get("with_grant_option").(bool) id.AlwaysApply = d.Get("always_apply").(bool) - on := getDatabaseRoleGrantOn(d) + on, err := getDatabaseRoleGrantOn(d) + if err != nil { + return nil, err + } switch { case on.Database != nil: id.Kind = OnDatabaseDatabaseRoleGrantKind @@ -961,5 +986,5 @@ func createGrantPrivilegesToDatabaseRoleIdFromSchema(d *schema.ResourceData) *Gr id.Data = onSchemaObjectGrantData } - return id + return id, nil } diff --git a/pkg/resources/grant_privileges_to_database_role_acceptance_test.go b/pkg/resources/grant_privileges_to_database_role_acceptance_test.go index e00b5c842c..e2c49c0cf6 100644 --- a/pkg/resources/grant_privileges_to_database_role_acceptance_test.go +++ b/pkg/resources/grant_privileges_to_database_role_acceptance_test.go @@ -625,6 +625,116 @@ func TestAcc_GrantPrivilegesToDatabaseRole_OnSchemaObject_OnAll_Streamlits_InDat }) } +func TestAcc_GrantPrivilegesToDatabaseRole_OnSchemaObject_OnFunction(t *testing.T) { + acc.TestAccPreCheck(t) + + databaseRoleId := acc.TestClient().Ids.RandomDatabaseObjectIdentifier() + function, functionCleanup := acc.TestClient().Function.CreateFunction(t) + t.Cleanup(functionCleanup) + configVariables := config.Variables{ + "name": config.StringVariable(databaseRoleId.FullyQualifiedName()), + "function_name": config.StringVariable(function.ID().Name()), + "privileges": config.ListVariable( + config.StringVariable(string(sdk.SchemaObjectPrivilegeUsage)), + ), + "database": config.StringVariable(acc.TestDatabaseName), + "schema": config.StringVariable(acc.TestSchemaName), + "with_grant_option": config.BoolVariable(false), + "argument_type": config.StringVariable(string(sdk.DataTypeFloat)), + } + resourceName := "snowflake_grant_privileges_to_database_role.test" + + resource.Test(t, resource.TestCase{ + ProtoV6ProviderFactories: acc.TestAccProtoV6ProviderFactories, + TerraformVersionChecks: []tfversion.TerraformVersionCheck{ + tfversion.RequireAbove(tfversion.Version1_5_0), + }, + CheckDestroy: acc.CheckAccountRolePrivilegesRevoked(t), + Steps: []resource.TestStep{ + { + PreConfig: func() { + _, roleCleanup := acc.TestClient().DatabaseRole.CreateDatabaseRoleWithName(t, databaseRoleId.Name()) + t.Cleanup(roleCleanup) + }, + ConfigDirectory: acc.ConfigurationDirectory("TestAcc_GrantPrivilegesToDatabaseRole/OnSchemaObject_OnFunction"), + ConfigVariables: configVariables, + Check: resource.ComposeAggregateTestCheckFunc( + resource.TestCheckResourceAttr(resourceName, "database_role_name", databaseRoleId.FullyQualifiedName()), + resource.TestCheckResourceAttr(resourceName, "privileges.#", "1"), + resource.TestCheckResourceAttr(resourceName, "privileges.0", string(sdk.SchemaObjectPrivilegeUsage)), + resource.TestCheckResourceAttr(resourceName, "on_schema_object.#", "1"), + resource.TestCheckResourceAttr(resourceName, "on_schema_object.0.object_type", string(sdk.ObjectTypeFunction)), + resource.TestCheckResourceAttr(resourceName, "on_schema_object.0.object_name", function.ID().FullyQualifiedName()), + resource.TestCheckResourceAttr(resourceName, "with_grant_option", "false"), + resource.TestCheckResourceAttr(resourceName, "id", fmt.Sprintf("%s|false|false|USAGE|OnSchemaObject|OnObject|FUNCTION|%s", databaseRoleId.FullyQualifiedName(), function.ID().FullyQualifiedName())), + ), + }, + { + ConfigDirectory: acc.ConfigurationDirectory("TestAcc_GrantPrivilegesToDatabaseRole/OnSchemaObject_OnFunction"), + ConfigVariables: configVariables, + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + }, + }, + }) +} + +func TestAcc_GrantPrivilegesToDatabaseRole_OnSchemaObject_OnFunctionWithoutArguments(t *testing.T) { + acc.TestAccPreCheck(t) + + databaseRoleId := acc.TestClient().Ids.RandomDatabaseObjectIdentifier() + function, functionCleanup := acc.TestClient().Function.CreateFunctionWithoutArguments(t) + t.Cleanup(functionCleanup) + configVariables := config.Variables{ + "name": config.StringVariable(databaseRoleId.FullyQualifiedName()), + "function_name": config.StringVariable(function.ID().Name()), + "privileges": config.ListVariable( + config.StringVariable(string(sdk.SchemaObjectPrivilegeUsage)), + ), + "database": config.StringVariable(acc.TestDatabaseName), + "schema": config.StringVariable(acc.TestSchemaName), + "with_grant_option": config.BoolVariable(false), + "argument_type": config.StringVariable(""), + } + resourceName := "snowflake_grant_privileges_to_database_role.test" + + resource.Test(t, resource.TestCase{ + ProtoV6ProviderFactories: acc.TestAccProtoV6ProviderFactories, + TerraformVersionChecks: []tfversion.TerraformVersionCheck{ + tfversion.RequireAbove(tfversion.Version1_5_0), + }, + CheckDestroy: acc.CheckAccountRolePrivilegesRevoked(t), + Steps: []resource.TestStep{ + { + PreConfig: func() { + _, roleCleanup := acc.TestClient().DatabaseRole.CreateDatabaseRoleWithName(t, databaseRoleId.Name()) + t.Cleanup(roleCleanup) + }, + ConfigDirectory: acc.ConfigurationDirectory("TestAcc_GrantPrivilegesToDatabaseRole/OnSchemaObject_OnFunction"), + ConfigVariables: configVariables, + Check: resource.ComposeAggregateTestCheckFunc( + resource.TestCheckResourceAttr(resourceName, "database_role_name", databaseRoleId.FullyQualifiedName()), + resource.TestCheckResourceAttr(resourceName, "privileges.#", "1"), + resource.TestCheckResourceAttr(resourceName, "privileges.0", string(sdk.SchemaObjectPrivilegeUsage)), + resource.TestCheckResourceAttr(resourceName, "on_schema_object.#", "1"), + resource.TestCheckResourceAttr(resourceName, "on_schema_object.0.object_type", string(sdk.ObjectTypeFunction)), + resource.TestCheckResourceAttr(resourceName, "on_schema_object.0.object_name", function.ID().FullyQualifiedName()), + resource.TestCheckResourceAttr(resourceName, "with_grant_option", "false"), + resource.TestCheckResourceAttr(resourceName, "id", fmt.Sprintf("%s|false|false|USAGE|OnSchemaObject|OnObject|FUNCTION|%s", databaseRoleId.FullyQualifiedName(), function.ID().FullyQualifiedName())), + ), + }, + { + ConfigDirectory: acc.ConfigurationDirectory("TestAcc_GrantPrivilegesToDatabaseRole/OnSchemaObject_OnFunction"), + ConfigVariables: configVariables, + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + }, + }, + }) +} + func TestAcc_GrantPrivilegesToDatabaseRole_UpdatePrivileges(t *testing.T) { databaseRoleId := acc.TestClient().Ids.RandomDatabaseObjectIdentifier() diff --git a/pkg/resources/grant_privileges_to_database_role_identifier.go b/pkg/resources/grant_privileges_to_database_role_identifier.go index 973b7ab024..4c45cfa5b9 100644 --- a/pkg/resources/grant_privileges_to_database_role_identifier.go +++ b/pkg/resources/grant_privileges_to_database_role_identifier.go @@ -2,6 +2,7 @@ package resources import ( "fmt" + "slices" "strconv" "strings" @@ -127,9 +128,20 @@ func ParseGrantPrivilegesToDatabaseRoleId(id string) (GrantPrivilegesToDatabaseR if len(parts) != 8 { return databaseRoleId, sdk.NewError(`database role identifier should hold 8 parts "||||OnSchemaObject|OnObject||"`) } + objectType := sdk.ObjectType(parts[6]) + var id sdk.ObjectIdentifier + if slices.Contains([]sdk.ObjectType{sdk.ObjectTypeFunction, sdk.ObjectTypeProcedure, sdk.ObjectTypeExternalFunction}, objectType) { + var err error + id, err = sdk.ParseSchemaObjectIdentifierWithArguments(parts[7]) + if err != nil { + return databaseRoleId, err + } + } else { + id = sdk.NewSchemaObjectIdentifierFromFullyQualifiedName(parts[7]) + } onSchemaObjectGrantData.Object = &sdk.Object{ - ObjectType: sdk.ObjectType(parts[6]), - Name: sdk.NewSchemaObjectIdentifierFromFullyQualifiedName(parts[7]), + ObjectType: objectType, + Name: id, } case OnAllSchemaObjectGrantKind, OnFutureSchemaObjectGrantKind: bulkOperationGrantData := &BulkOperationGrantData{ diff --git a/pkg/resources/grant_privileges_to_database_role_identifier_test.go b/pkg/resources/grant_privileges_to_database_role_identifier_test.go index ea8cd4404b..a9ed631605 100644 --- a/pkg/resources/grant_privileges_to_database_role_identifier_test.go +++ b/pkg/resources/grant_privileges_to_database_role_identifier_test.go @@ -114,6 +114,40 @@ func TestParseGrantPrivilegesToDatabaseRoleId(t *testing.T) { }, }, }, + { + Name: "grant database role on function", + Identifier: `"database-name"."database-role"|false|false|USAGE|OnSchemaObject|OnObject|FUNCTION|"database-name"."schema-name"."function-name"(FLOAT)`, + Expected: GrantPrivilegesToDatabaseRoleId{ + DatabaseRoleName: sdk.NewDatabaseObjectIdentifier("database-name", "database-role"), + WithGrantOption: false, + Privileges: []string{"USAGE"}, + Kind: OnSchemaObjectDatabaseRoleGrantKind, + Data: &OnSchemaObjectGrantData{ + Kind: OnObjectSchemaObjectGrantKind, + Object: &sdk.Object{ + ObjectType: sdk.ObjectTypeFunction, + Name: sdk.NewSchemaObjectIdentifierWithArguments("database-name", "schema-name", "function-name", sdk.DataTypeFloat), + }, + }, + }, + }, + { + Name: "grant database role on function without arguments", + Identifier: `"database-name"."database-role"|false|false|USAGE|OnSchemaObject|OnObject|FUNCTION|"database-name"."schema-name"."function-name"()`, + Expected: GrantPrivilegesToDatabaseRoleId{ + DatabaseRoleName: sdk.NewDatabaseObjectIdentifier("database-name", "database-role"), + WithGrantOption: false, + Privileges: []string{"USAGE"}, + Kind: OnSchemaObjectDatabaseRoleGrantKind, + Data: &OnSchemaObjectGrantData{ + Kind: OnObjectSchemaObjectGrantKind, + Object: &sdk.Object{ + ObjectType: sdk.ObjectTypeFunction, + Name: sdk.NewSchemaObjectIdentifierWithArguments("database-name", "schema-name", "function-name", []sdk.DataType{}...), + }, + }, + }, + }, { Name: "grant database role on schema object with on all option", Identifier: `"database-name"."database-role"|false|false|CREATE SCHEMA,USAGE,MONITOR|OnSchemaObject|OnAll|TABLES`, diff --git a/pkg/resources/grant_privileges_to_share.go b/pkg/resources/grant_privileges_to_share.go index bc7cd6573b..6d5e232d47 100644 --- a/pkg/resources/grant_privileges_to_share.go +++ b/pkg/resources/grant_privileges_to_share.go @@ -17,7 +17,7 @@ import ( var grantPrivilegesToShareGrantExactlyOneOfValidation = []string{ "on_database", "on_schema", - // TODO(SNOW-990811): "function_name", + "on_function", "on_table", "on_all_tables_in_schema", "on_tag", @@ -54,15 +54,6 @@ var grantPrivilegesToShareSchema = map[string]*schema.Schema{ ValidateDiagFunc: IsValidIdentifier[sdk.DatabaseObjectIdentifier](), ExactlyOneOf: grantPrivilegesToShareGrantExactlyOneOfValidation, }, - // TODO(SNOW-1021686): Because function identifier contains arguments which are not supported right now - // "function_name": { - // Type: schema.TypeString, - // Optional: true, - // ForceNew: true, - // Description: "The fully qualified name of the function on which privileges will be granted.", - // ValidateDiagFunc: IsValidIdentifier[sdk.FunctionIdentifier](), - // ExactlyOneOf: grantPrivilegesToShareGrantExactlyOneOfValidation, - // }, "on_table": { Type: schema.TypeString, Optional: true, @@ -95,6 +86,13 @@ var grantPrivilegesToShareSchema = map[string]*schema.Schema{ ValidateDiagFunc: IsValidIdentifier[sdk.SchemaObjectIdentifier](), ExactlyOneOf: grantPrivilegesToShareGrantExactlyOneOfValidation, }, + "on_function": { + Type: schema.TypeString, + Optional: true, + ForceNew: true, + Description: "The fully qualified name of the function on which privileges will be granted.", + ExactlyOneOf: grantPrivilegesToShareGrantExactlyOneOfValidation, + }, } func GrantPrivilegesToShare() *schema.Resource { @@ -133,10 +131,10 @@ func ImportGrantPrivilegesToShare() func(ctx context.Context, d *schema.Resource if err := d.Set("on_schema", id.Identifier.FullyQualifiedName()); err != nil { return nil, err } - // TODO(SNOW-990811) case OnFunctionShareGrantKind: - // if err := d.Set("function_name", id.Identifier.FullyQualifiedName()); err != nil { - // return nil, err - // } + case OnFunctionShareGrantKind: + if err := d.Set("on_function", id.Identifier.FullyQualifiedName()); err != nil { + return nil, err + } case OnTableShareGrantKind: if err := d.Set("on_table", id.Identifier.FullyQualifiedName()); err != nil { return nil, err @@ -161,10 +159,17 @@ func ImportGrantPrivilegesToShare() func(ctx context.Context, d *schema.Resource func CreateGrantPrivilegesToShare(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { client := meta.(*provider.Context).Client - id := createGrantPrivilegesToShareIdFromSchema(d) + id, err := createGrantPrivilegesToShareIdFromSchema(d) + if err != nil { + return diag.FromErr(err) + } log.Printf("[DEBUG] created identifier from schema: %s", id.String()) + grantOn, err := getShareGrantOn(d) + if err != nil { + return diag.FromErr(err) + } - err := client.Grants.GrantPrivilegeToShare(ctx, getObjectPrivilegesFromSchema(d), getShareGrantOn(d), id.ShareName) + err = client.Grants.GrantPrivilegeToShare(ctx, getObjectPrivilegesFromSchema(d), grantOn, id.ShareName) if err != nil { return diag.Diagnostics{ diag.Diagnostic{ @@ -213,7 +218,10 @@ func UpdateGrantPrivilegesToShare(ctx context.Context, d *schema.ResourceData, m } } - grantOn := getShareGrantOn(d) + grantOn, err := getShareGrantOn(d) + if err != nil { + return diag.FromErr(err) + } if len(privilegesToAdd) > 0 { err = client.Grants.GrantPrivilegeToShare( @@ -271,8 +279,12 @@ func DeleteGrantPrivilegesToShare(ctx context.Context, d *schema.ResourceData, m }, } } + grantOn, err := getShareGrantOn(d) + if err != nil { + return diag.FromErr(err) + } - err = client.Grants.RevokePrivilegeFromShare(ctx, getObjectPrivilegesFromSchema(d), getShareGrantOn(d), id.ShareName) + err = client.Grants.RevokePrivilegeFromShare(ctx, getObjectPrivilegesFromSchema(d), grantOn, id.ShareName) if err != nil { return diag.Diagnostics{ diag.Diagnostic{ @@ -375,14 +387,14 @@ func ReadGrantPrivilegesToShare(ctx context.Context, d *schema.ResourceData, met return nil } -func createGrantPrivilegesToShareIdFromSchema(d *schema.ResourceData) *GrantPrivilegesToShareId { +func createGrantPrivilegesToShareIdFromSchema(d *schema.ResourceData) (*GrantPrivilegesToShareId, error) { id := new(GrantPrivilegesToShareId) id.ShareName = sdk.NewAccountObjectIdentifier(d.Get("to_share").(string)) id.Privileges = expandStringList(d.Get("privileges").(*schema.Set).List()) databaseName, databaseNameOk := d.GetOk("on_database") schemaName, schemaNameOk := d.GetOk("on_schema") - // TODO(SNOW-990811) functionName, functionNameOk := d.GetOk("function_name") + functionName, functionNameOk := d.GetOk("on_function") tableName, tableNameOk := d.GetOk("on_table") allTablesInSchema, allTablesInSchemaOk := d.GetOk("on_all_tables_in_schema") tagName, tagNameOk := d.GetOk("on_tag") @@ -395,9 +407,13 @@ func createGrantPrivilegesToShareIdFromSchema(d *schema.ResourceData) *GrantPriv case schemaNameOk: id.Kind = OnSchemaShareGrantKind id.Identifier = sdk.NewDatabaseObjectIdentifierFromFullyQualifiedName(schemaName.(string)) - // TODO(SNOW-990811) case functionNameOk: - // id.Kind = OnFunctionShareGrantKind - // id.Identifier = sdk.NewSchemaObjectIdentifierFromFullyQualifiedName(functionName.(string)) + case functionNameOk: + id.Kind = OnFunctionShareGrantKind + parsed, err := sdk.ParseSchemaObjectIdentifierWithArguments(functionName.(string)) + if err != nil { + return nil, err + } + id.Identifier = parsed case tableNameOk: id.Kind = OnTableShareGrantKind id.Identifier = sdk.NewSchemaObjectIdentifierFromFullyQualifiedName(tableName.(string)) @@ -412,7 +428,7 @@ func createGrantPrivilegesToShareIdFromSchema(d *schema.ResourceData) *GrantPriv id.Identifier = sdk.NewSchemaObjectIdentifierFromFullyQualifiedName(viewName.(string)) } - return id + return id, nil } func getObjectPrivilegesFromSchema(d *schema.ResourceData) []sdk.ObjectPrivilege { @@ -424,12 +440,12 @@ func getObjectPrivilegesFromSchema(d *schema.ResourceData) []sdk.ObjectPrivilege return objectPrivileges } -func getShareGrantOn(d *schema.ResourceData) *sdk.ShareGrantOn { +func getShareGrantOn(d *schema.ResourceData) (*sdk.ShareGrantOn, error) { grantOn := new(sdk.ShareGrantOn) databaseName, databaseNameOk := d.GetOk("on_database") schemaName, schemaNameOk := d.GetOk("on_schema") - // TODO(SNOW-990811) functionName, functionNameOk := d.GetOk("on_function") + functionName, functionNameOk := d.GetOk("on_function") tableName, tableNameOk := d.GetOk("on_table") allTablesInSchema, allTablesInSchemaOk := d.GetOk("on_all_tables_in_schema") tagName, tagNameOk := d.GetOk("on_tag") @@ -440,8 +456,12 @@ func getShareGrantOn(d *schema.ResourceData) *sdk.ShareGrantOn { grantOn.Database = sdk.NewAccountObjectIdentifierFromFullyQualifiedName(databaseName.(string)) case len(schemaName.(string)) > 0 && schemaNameOk: grantOn.Schema = sdk.NewDatabaseObjectIdentifierFromFullyQualifiedName(schemaName.(string)) - // TODO(SNOW-990811) case len(functionName.(string)) > 0 && functionNameOk: - // grantOn.Function = sdk.NewSchemaObjectIdentifierFromFullyQualifiedName(functionName.(string)) + case len(functionName.(string)) > 0 && functionNameOk: + id, err := sdk.ParseSchemaObjectIdentifierWithArguments(functionName.(string)) + if err != nil { + return nil, err + } + grantOn.Function = id case len(tableName.(string)) > 0 && tableNameOk: grantOn.Table = &sdk.OnTable{ Name: sdk.NewSchemaObjectIdentifierFromFullyQualifiedName(tableName.(string)), @@ -456,7 +476,7 @@ func getShareGrantOn(d *schema.ResourceData) *sdk.ShareGrantOn { grantOn.View = sdk.NewSchemaObjectIdentifierFromFullyQualifiedName(viewName.(string)) } - return grantOn + return grantOn, nil } func prepareShowGrantsRequestForShare(id GrantPrivilegesToShareId) (*sdk.ShowGrantOptions, sdk.ObjectType) { @@ -477,6 +497,8 @@ func prepareShowGrantsRequestForShare(id GrantPrivilegesToShareId) (*sdk.ShowGra objectType = sdk.ObjectTypeTag case OnViewShareGrantKind: objectType = sdk.ObjectTypeView + case OnFunctionShareGrantKind: + objectType = sdk.ObjectTypeFunction } opts.On = &sdk.ShowGrantsOn{ diff --git a/pkg/resources/grant_privileges_to_share_acceptance_test.go b/pkg/resources/grant_privileges_to_share_acceptance_test.go index cefa39629d..4833733c36 100644 --- a/pkg/resources/grant_privileges_to_share_acceptance_test.go +++ b/pkg/resources/grant_privileges_to_share_acceptance_test.go @@ -1,6 +1,7 @@ package resources_test import ( + "fmt" "regexp" "testing" @@ -116,8 +117,6 @@ func TestAcc_GrantPrivilegesToShare_OnSchema(t *testing.T) { }) } -// TODO(SNOW-1021686): Add on_function test - func TestAcc_GrantPrivilegesToShare_OnTable(t *testing.T) { databaseId := acc.TestClient().Ids.RandomAccountObjectIdentifier() schemaId := acc.TestClient().Ids.RandomDatabaseObjectIdentifierInDatabase(databaseId) @@ -338,6 +337,102 @@ func TestAcc_GrantPrivilegesToShare_OnTag(t *testing.T) { }) } +func TestAcc_GrantPrivilegesToShare_OnSchemaObject_OnFunction(t *testing.T) { + acc.TestAccPreCheck(t) + + share, shareCleanup := acc.TestClient().Share.CreateShare(t) + t.Cleanup(shareCleanup) + function, functionCleanup := acc.TestClient().Function.CreateFunction(t) + t.Cleanup(functionCleanup) + configVariables := config.Variables{ + "name": config.StringVariable(share.ID().Name()), + "function_name": config.StringVariable(function.ID().Name()), + "privileges": config.ListVariable( + config.StringVariable(string(sdk.SchemaObjectPrivilegeUsage)), + ), + "database": config.StringVariable(acc.TestDatabaseName), + "schema": config.StringVariable(acc.TestSchemaName), + "argument_type": config.StringVariable(string(sdk.DataTypeFloat)), + } + resourceName := "snowflake_grant_privileges_to_share.test" + + resource.Test(t, resource.TestCase{ + ProtoV6ProviderFactories: acc.TestAccProtoV6ProviderFactories, + TerraformVersionChecks: []tfversion.TerraformVersionCheck{ + tfversion.RequireAbove(tfversion.Version1_5_0), + }, + CheckDestroy: acc.CheckAccountRolePrivilegesRevoked(t), + Steps: []resource.TestStep{ + { + ConfigDirectory: acc.ConfigurationDirectory("TestAcc_GrantPrivilegesToShare/OnFunction"), + ConfigVariables: configVariables, + Check: resource.ComposeAggregateTestCheckFunc( + resource.TestCheckResourceAttr(resourceName, "to_share", share.ID().Name()), + resource.TestCheckResourceAttr(resourceName, "privileges.#", "1"), + resource.TestCheckResourceAttr(resourceName, "privileges.0", string(sdk.SchemaObjectPrivilegeUsage)), + resource.TestCheckResourceAttr(resourceName, "on_function", function.ID().FullyQualifiedName()), + resource.TestCheckResourceAttr(resourceName, "id", fmt.Sprintf("%s|USAGE|OnFunction|%s", share.ID().FullyQualifiedName(), function.ID().FullyQualifiedName())), + ), + }, + { + ConfigDirectory: acc.ConfigurationDirectory("TestAcc_GrantPrivilegesToShare/OnFunction"), + ConfigVariables: configVariables, + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + }, + }, + }) +} + +func TestAcc_GrantPrivilegesToShare_OnSchemaObject_OnFunctionWithoutArguments(t *testing.T) { + acc.TestAccPreCheck(t) + + share, shareCleanup := acc.TestClient().Share.CreateShare(t) + t.Cleanup(shareCleanup) + function, functionCleanup := acc.TestClient().Function.CreateFunctionWithoutArguments(t) + t.Cleanup(functionCleanup) + configVariables := config.Variables{ + "name": config.StringVariable(share.ID().Name()), + "function_name": config.StringVariable(function.ID().Name()), + "privileges": config.ListVariable( + config.StringVariable(string(sdk.SchemaObjectPrivilegeUsage)), + ), + "database": config.StringVariable(acc.TestDatabaseName), + "schema": config.StringVariable(acc.TestSchemaName), + "argument_type": config.StringVariable(""), + } + resourceName := "snowflake_grant_privileges_to_share.test" + + resource.Test(t, resource.TestCase{ + ProtoV6ProviderFactories: acc.TestAccProtoV6ProviderFactories, + TerraformVersionChecks: []tfversion.TerraformVersionCheck{ + tfversion.RequireAbove(tfversion.Version1_5_0), + }, + CheckDestroy: acc.CheckAccountRolePrivilegesRevoked(t), + Steps: []resource.TestStep{ + { + ConfigDirectory: acc.ConfigurationDirectory("TestAcc_GrantPrivilegesToShare/OnFunction"), + ConfigVariables: configVariables, + Check: resource.ComposeAggregateTestCheckFunc( + resource.TestCheckResourceAttr(resourceName, "to_share", share.ID().Name()), + resource.TestCheckResourceAttr(resourceName, "privileges.#", "1"), + resource.TestCheckResourceAttr(resourceName, "privileges.0", string(sdk.SchemaObjectPrivilegeUsage)), + resource.TestCheckResourceAttr(resourceName, "on_function", function.ID().FullyQualifiedName()), + resource.TestCheckResourceAttr(resourceName, "id", fmt.Sprintf("%s|USAGE|OnFunction|%s", share.ID().FullyQualifiedName(), function.ID().FullyQualifiedName())), + ), + }, + { + ConfigDirectory: acc.ConfigurationDirectory("TestAcc_GrantPrivilegesToShare/OnFunction"), + ConfigVariables: configVariables, + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + }, + }, + }) +} + func TestAcc_GrantPrivilegesToShare_OnPrivilegeUpdate(t *testing.T) { databaseName := acc.TestClient().Ids.RandomAccountObjectIdentifier() shareName := acc.TestClient().Ids.RandomAccountObjectIdentifier() diff --git a/pkg/resources/grant_privileges_to_share_identifier.go b/pkg/resources/grant_privileges_to_share_identifier.go index 9996c38add..f2422c97ac 100644 --- a/pkg/resources/grant_privileges_to_share_identifier.go +++ b/pkg/resources/grant_privileges_to_share_identifier.go @@ -11,10 +11,9 @@ import ( type ShareGrantKind string const ( - OnDatabaseShareGrantKind ShareGrantKind = "OnDatabase" - OnSchemaShareGrantKind ShareGrantKind = "OnSchema" - // TODO(SNOW-1021686): Because function identifier contains arguments which are not supported right now - // OnFunctionShareGrantKind ShareGrantKind = "OnFunction" + OnDatabaseShareGrantKind ShareGrantKind = "OnDatabase" + OnSchemaShareGrantKind ShareGrantKind = "OnSchema" + OnFunctionShareGrantKind ShareGrantKind = "OnFunction" OnTableShareGrantKind ShareGrantKind = "OnTable" OnAllTablesInSchemaShareGrantKind ShareGrantKind = "OnAllTablesInSchema" OnTagShareGrantKind ShareGrantKind = "OnTag" @@ -53,42 +52,31 @@ func ParseGrantPrivilegesToShareId(idString string) (GrantPrivilegesToShareId, e grantPrivilegesToShareId.Privileges = privileges grantPrivilegesToShareId.Kind = ShareGrantKind(parts[2]) - id, err := helpers.DecodeSnowflakeParameterID(parts[3]) - if err != nil { - return grantPrivilegesToShareId, err - } - switch grantPrivilegesToShareId.Kind { case OnDatabaseShareGrantKind: - if typedIdentifier, ok := id.(sdk.AccountObjectIdentifier); ok { - grantPrivilegesToShareId.Identifier = typedIdentifier - } else { - return grantPrivilegesToShareId, fmt.Errorf( - "invalid identifier, expected fully qualified name of account object: %s, but instead got: %s", - getExpectedIdentifierRepresentationFromGeneric[sdk.AccountObjectIdentifier](), - getExpectedIdentifierRepresentationFromParam(id), - ) + id, err := sdk.ParseAccountObjectIdentifier(parts[3]) + if err != nil { + return grantPrivilegesToShareId, sdk.NewError(fmt.Sprintf("invalid identifier, expected fully qualified name of database object%s: ", parts[3]), err) } + grantPrivilegesToShareId.Identifier = id case OnSchemaShareGrantKind, OnAllTablesInSchemaShareGrantKind: - if typedIdentifier, ok := id.(sdk.DatabaseObjectIdentifier); ok { - grantPrivilegesToShareId.Identifier = typedIdentifier - } else { - return grantPrivilegesToShareId, fmt.Errorf( - "invalid identifier, expected fully qualified name of database object: %s, but instead got: %s", - getExpectedIdentifierRepresentationFromGeneric[sdk.DatabaseObjectIdentifier](), - getExpectedIdentifierRepresentationFromParam(id), - ) + id, err := sdk.ParseDatabaseObjectIdentifier(parts[3]) + if err != nil { + return grantPrivilegesToShareId, sdk.NewError(fmt.Sprintf("could not parse database object identifier %s: ", parts[3]), err) + } + grantPrivilegesToShareId.Identifier = id + case OnTableShareGrantKind, OnViewShareGrantKind, OnTagShareGrantKind: + id, err := sdk.ParseSchemaObjectIdentifier(parts[3]) + if err != nil { + return grantPrivilegesToShareId, sdk.NewError(fmt.Sprintf("could not parse schema object identifier %s: ", parts[3]), err) } - case OnTableShareGrantKind, OnViewShareGrantKind, OnTagShareGrantKind: // TODO(SNOW-1021686) , OnFunctionShareGrantKind: - if typedIdentifier, ok := id.(sdk.SchemaObjectIdentifier); ok { - grantPrivilegesToShareId.Identifier = typedIdentifier - } else { - return grantPrivilegesToShareId, fmt.Errorf( - "invalid identifier, expected fully qualified name of schema object: %s, but instead got: %s", - getExpectedIdentifierRepresentationFromGeneric[sdk.SchemaObjectIdentifier](), - getExpectedIdentifierRepresentationFromParam(id), - ) + grantPrivilegesToShareId.Identifier = id + case OnFunctionShareGrantKind: + id, err := sdk.ParseSchemaObjectIdentifierWithArguments(parts[3]) + if err != nil { + return grantPrivilegesToShareId, sdk.NewError(fmt.Sprintf("could not parse schema object identifier with arguments %s: ", parts[3]), err) } + grantPrivilegesToShareId.Identifier = id default: return grantPrivilegesToShareId, fmt.Errorf("unexpected share grant kind: %v", grantPrivilegesToShareId.Kind) } diff --git a/pkg/resources/grant_privileges_to_share_identifier_test.go b/pkg/resources/grant_privileges_to_share_identifier_test.go index d892d2800d..2c64f3f847 100644 --- a/pkg/resources/grant_privileges_to_share_identifier_test.go +++ b/pkg/resources/grant_privileges_to_share_identifier_test.go @@ -34,17 +34,26 @@ func TestParseGrantPrivilegesToShareId(t *testing.T) { Identifier: sdk.NewDatabaseObjectIdentifier("on-database-name", "on-schema-name"), }, }, - // TODO(SNOW-1021686): This is wrong and should be fixed (function's last part of identifier cannot be enclosed with quotes like that) - //{ - // Name: "grant privileges on function to share", - // Identifier: `"share-name"|USAGE|OnFunction|"on-database-name"."on-schema-name".on-function-name(INT, VARCHAR)`, - // Expected: GrantPrivilegesToShareId{ - // ShareName: sdk.NewExternalObjectIdentifierFromFullyQualifiedName("share-name"), - // Privileges: []string{"USAGE"}, - // Kind: OnFunctionShareGrantKind, - // Identifier: sdk.NewSchemaObjectIdentifier("on-database-name", "on-schema-name", "on-function-name(INT, VARCHAR)"), - // }, - // }, + { + Name: "grant privileges on function to share", + Identifier: `"share-name"|USAGE|OnFunction|"on-database-name"."on-schema-name".on-function-name(INT, VARCHAR)`, + Expected: GrantPrivilegesToShareId{ + ShareName: sdk.NewAccountObjectIdentifier("share-name"), + Privileges: []string{"USAGE"}, + Kind: OnFunctionShareGrantKind, + Identifier: sdk.NewSchemaObjectIdentifierWithArguments("on-database-name", "on-schema-name", "on-function-name", sdk.DataTypeInt, sdk.DataTypeVARCHAR), + }, + }, + { + Name: "grant privileges on function without arguments to share", + Identifier: `"share-name"|READ|OnFunction|"on-database-name"."on-schema-name"."on-view-name"()`, + Expected: GrantPrivilegesToShareId{ + ShareName: sdk.NewAccountObjectIdentifier("share-name"), + Privileges: []string{"READ"}, + Kind: OnFunctionShareGrantKind, + Identifier: sdk.NewSchemaObjectIdentifierWithArguments("on-database-name", "on-schema-name", "on-view-name", []sdk.DataType{}...), + }, + }, { Name: "grant privileges on table to share", Identifier: `"share-name"|EVOLVE SCHEMA|OnTable|"on-database-name"."on-schema-name"."on-table-name"`, @@ -103,22 +112,22 @@ func TestParseGrantPrivilegesToShareId(t *testing.T) { { Name: "validation: invalid identifier", Identifier: `"share-name"|SELECT|OnDatabase|one.two.three.four.five.six.seven.eight.nine.ten`, - Error: `unable to classify identifier: one.two.three.four.five.six.seven.eight.nine.ten`, + Error: `unexpected number of parts 10 in identifier one.two.three.four.five.six.seven.eight.nine.ten, expected 1 in a form of ""`, }, { Name: "validation: invalid account object identifier", Identifier: `"share-name"|SELECT|OnDatabase|one.two`, - Error: `invalid identifier, expected fully qualified name of account object: , but instead got: .`, + Error: `unexpected number of parts 2 in identifier one.two, expected 1 in a form of ""`, }, { Name: "validation: invalid database object identifier", Identifier: `"share-name"|SELECT|OnSchema|one.two.three`, - Error: `invalid identifier, expected fully qualified name of database object: ., but instead got: ..`, + Error: `unexpected number of parts 3 in identifier one.two.three, expected 2 in a form of ".`, }, { Name: "validation: invalid schema object identifier", Identifier: `"share-name"|SELECT|OnTable|one`, - Error: `invalid identifier, expected fully qualified name of schema object: .., but instead got: `, + Error: `unexpected number of parts 1 in identifier one, expected 3 in a form of ".."`, }, } @@ -163,17 +172,6 @@ func TestGrantPrivilegesToShareIdString(t *testing.T) { }, Expected: `"share-name"|USAGE|OnSchema|"database-name"."schema-name"`, }, - // TODO(SNOW-1021686): This is wrong and should be fixed (function's last part of identifier cannot be enclosed with quotes like that) - //{ - // Name: "grant privileges on function to share", - // Identifier: GrantPrivilegesToShareId{ - // ShareName: sdk.NewExternalObjectIdentifierFromFullyQualifiedName("share-name"), - // Privileges: []string{"USAGE"}, - // Kind: OnFunctionShareGrantKind, - // Identifier: sdk.NewSchemaObjectIdentifier("database-name", "schema-name", "function-name(INT, VARCHAR)"), - // }, - // Expected: `"share-name"|USAGE|OnFunction|"database-name"."schema-name".\"function-name(INT, VARCHAR)\"`, - // }, { Name: "grant privileges on table to share", Identifier: GrantPrivilegesToShareId{ diff --git a/pkg/resources/testdata/TestAcc_GrantOwnership/OnObject_Procedure_ToAccountRole/test.tf b/pkg/resources/testdata/TestAcc_GrantOwnership/OnObject_Procedure_ToAccountRole/test.tf new file mode 100644 index 0000000000..4022ad0b83 --- /dev/null +++ b/pkg/resources/testdata/TestAcc_GrantOwnership/OnObject_Procedure_ToAccountRole/test.tf @@ -0,0 +1,39 @@ +resource "snowflake_account_role" "test" { + name = var.account_role_name +} + +resource "snowflake_database" "test" { + name = var.database_name +} + +resource "snowflake_schema" "test" { + name = var.schema_name + database = snowflake_database.test.name +} + +resource "snowflake_procedure" "test" { + name = var.procedure_name + database = snowflake_database.test.name + schema = snowflake_schema.test.name + language = "JAVASCRIPT" + arguments { + name = "ARG1" + type = "FLOAT" + } + return_type = "FLOAT" + execute_as = "CALLER" + return_behavior = "VOLATILE" + null_input_behavior = "RETURNS NULL ON NULL INPUT" + statement = <.."( ...):" +// Return type is not part of an identifier. +// TODO(SNOW-1625030): Remove and use ParseSchemaObjectIdentifierWithArguments instead +func ParseSchemaObjectIdentifierWithArgumentsAndReturnType(fullyQualifiedName string) (SchemaObjectIdentifierWithArguments, error) { + parts, err := ParseIdentifierStringWithOpts(fullyQualifiedName, func(r *csv.Reader) { + r.Comma = IdDelimiter + }) + if err != nil { + return SchemaObjectIdentifierWithArguments{}, err + } + functionHeader := parts[2] + leftParenthesisIndex := strings.IndexRune(functionHeader, '(') + functionName := functionHeader[:leftParenthesisIndex] + argsRaw := functionHeader[leftParenthesisIndex:] + returnTypeIndex := strings.LastIndex(argsRaw, ":") + if returnTypeIndex != -1 { + argsRaw = argsRaw[:returnTypeIndex] + } + dataTypes, err := ParseFunctionArgumentsFromString(argsRaw) + if err != nil { + return SchemaObjectIdentifierWithArguments{}, err + } + return NewSchemaObjectIdentifierWithArguments( + parts[0], + parts[1], + functionName, + dataTypes..., + ), nil +} + // ParseFunctionArgumentsFromString parses function argument from arguments string with optional argument names. // Varying types are not supported (e.g. VARCHAR(200)), because Snowflake outputs them in a shortened version // (VARCHAR in this case). The only exception is newly added type VECTOR that has the following structure @@ -239,10 +269,11 @@ func ParseFunctionArgumentsFromString(arguments string) ([]DataType, error) { } dataTypes = append(dataTypes, DataType(vectorDataType)) default: - dataType, err := stringBuffer.ReadString(',') + argument, err := stringBuffer.ReadString(',') if err == nil { - dataType = dataType[:len(dataType)-1] + argument = argument[:len(argument)-1] } + dataType := argument[strings.IndexRune(argument, ' ')+1:] dataTypes = append(dataTypes, DataType(dataType)) } } diff --git a/pkg/sdk/identifier_parsers_test.go b/pkg/sdk/identifier_parsers_test.go index cf67c499ec..ab0887dad1 100644 --- a/pkg/sdk/identifier_parsers_test.go +++ b/pkg/sdk/identifier_parsers_test.go @@ -301,6 +301,7 @@ func Test_ParseFunctionArgumentsFromString(t *testing.T) { {Arguments: `(FLOAT, NUMBER(10, 2), TIME)`, Expected: []DataType{DataTypeFloat, DataType("NUMBER(10"), DataType("2)"), DataTypeTime}}, {Arguments: `(FLOAT, NUMBER(10, 2))`, Expected: []DataType{DataTypeFloat, DataType("NUMBER(10"), DataType("2)")}}, {Arguments: `(NUMBER(10, 2), FLOAT)`, Expected: []DataType{DataType("NUMBER(10"), DataType("2)"), DataTypeFloat}}, + {Arguments: `(ab NUMBER(10, 2), x FLOAT, FLOAT)`, Expected: []DataType{DataType("NUMBER(10"), DataType("2)"), DataTypeFloat, DataTypeFloat}}, } for _, testCase := range testCases { @@ -371,3 +372,30 @@ func TestNewSchemaObjectIdentifierWithArgumentsFromFullyQualifiedName_WithRawInp }) } } + +func TestNewSchemaObjectIdentifierWithArgumentsAndReturnTypeFromFullyQualifiedName_WithRawInput(t *testing.T) { + testCases := []struct { + RawInput string + ExpectedIdentifierStructure SchemaObjectIdentifierWithArguments + Error string + }{ + {RawInput: `abc.def.ghi()`, ExpectedIdentifierStructure: NewSchemaObjectIdentifierWithArguments(`abc`, `def`, `ghi`)}, + {RawInput: `abc.def.ghi(FLOAT, VECTOR(INT, 20))`, ExpectedIdentifierStructure: NewSchemaObjectIdentifierWithArguments(`abc`, `def`, `ghi`, DataTypeFloat, "VECTOR(INT, 20)")}, + {RawInput: `abc.def.ghi():FLOAT`, ExpectedIdentifierStructure: NewSchemaObjectIdentifierWithArguments(`abc`, `def`, `ghi`)}, + {RawInput: `abc.def."ghi(FLOAT, VECTOR(INT, 20)):NUMBER(10,2)"`, ExpectedIdentifierStructure: NewSchemaObjectIdentifierWithArguments(`abc`, `def`, `ghi`, DataTypeFloat, "VECTOR(INT, 20)")}, + {RawInput: `abc.def."ghi(FLOAT, VECTOR(INT, 20)):NUMBER"`, ExpectedIdentifierStructure: NewSchemaObjectIdentifierWithArguments(`abc`, `def`, `ghi`, DataTypeFloat, "VECTOR(INT, 20)")}, + } + + for _, testCase := range testCases { + t.Run(fmt.Sprintf("processing %s", testCase.ExpectedIdentifierStructure.FullyQualifiedName()), func(t *testing.T) { + id, err := ParseSchemaObjectIdentifierWithArgumentsAndReturnType(testCase.RawInput) + + if testCase.Error != "" { + assert.ErrorContains(t, err, testCase.Error) + } else { + assert.NoError(t, err) + assert.Equal(t, testCase.ExpectedIdentifierStructure.FullyQualifiedName(), id.FullyQualifiedName()) + } + }) + } +} diff --git a/pkg/sdk/testint/grants_integration_test.go b/pkg/sdk/testint/grants_integration_test.go index 112ad3b7ec..2c85485ba0 100644 --- a/pkg/sdk/testint/grants_integration_test.go +++ b/pkg/sdk/testint/grants_integration_test.go @@ -842,10 +842,10 @@ func TestInt_GrantAndRevokePrivilegesToDatabaseRole(t *testing.T) { func TestInt_GrantPrivilegeToShare(t *testing.T) { client := testClient(t) ctx := testContext(t) - shareTest, shareCleanup := testClientHelper().Share.CreateShare(t) + shareTest, shareCleanup := testClientHelper().Share.CreateShareWithIdentifier(t, testClientHelper().Ids.RandomAccountObjectIdentifierContaining(".foo.bar")) t.Cleanup(shareCleanup) - assertGrant := func(t *testing.T, grants []sdk.Grant, onId sdk.ObjectIdentifier, privilege sdk.ObjectPrivilege) { + assertGrant := func(t *testing.T, grants []sdk.Grant, onId sdk.ObjectIdentifier, privilege sdk.ObjectPrivilege, grantedOn sdk.ObjectType, granteeName sdk.ObjectIdentifier) { t.Helper() var shareGrant *sdk.Grant for i, grant := range grants { @@ -855,8 +855,9 @@ func TestInt_GrantPrivilegeToShare(t *testing.T) { } } assert.NotNil(t, shareGrant) - assert.Equal(t, sdk.ObjectTypeTable, shareGrant.GrantedOn) + assert.Equal(t, grantedOn, shareGrant.GrantedOn) assert.Equal(t, sdk.ObjectTypeShare, shareGrant.GrantedTo) + assert.Equal(t, granteeName.FullyQualifiedName(), shareGrant.GranteeName.FullyQualifiedName()) assert.Equal(t, onId.FullyQualifiedName(), shareGrant.Name.FullyQualifiedName()) } @@ -892,7 +893,35 @@ func TestInt_GrantPrivilegeToShare(t *testing.T) { }, }) require.NoError(t, err) - assertGrant(t, grants, table.ID(), sdk.ObjectPrivilegeSelect) + assertGrant(t, grants, table.ID(), sdk.ObjectPrivilegeSelect, sdk.ObjectTypeTable, shareTest.ID()) + + _, err = client.Grants.Show(ctx, &sdk.ShowGrantOptions{ + To: &sdk.ShowGrantsTo{ + Share: &sdk.ShowGrantsToShare{ + Name: shareTest.ID(), + }, + }, + }) + require.NoError(t, err) + + function, functionCleanup := testClientHelper().Function.CreateFunction(t) + t.Cleanup(functionCleanup) + + err = client.Grants.GrantPrivilegeToShare(ctx, []sdk.ObjectPrivilege{sdk.ObjectPrivilegeUsage}, &sdk.ShareGrantOn{ + Function: function.ID(), + }, shareTest.ID()) + require.NoError(t, err) + + grants, err = client.Grants.Show(ctx, &sdk.ShowGrantOptions{ + On: &sdk.ShowGrantsOn{ + Object: &sdk.Object{ + ObjectType: sdk.ObjectTypeFunction, + Name: function.ID(), + }, + }, + }) + require.NoError(t, err) + assertGrant(t, grants, function.ID(), sdk.ObjectPrivilegeUsage, sdk.ObjectTypeFunction, shareTest.ID()) _, err = client.Grants.Show(ctx, &sdk.ShowGrantOptions{ To: &sdk.ShowGrantsTo{ diff --git a/v1-preparations/ESSENTIAL_GA_OBJECTS.MD b/v1-preparations/ESSENTIAL_GA_OBJECTS.MD index b2536e25f9..ac0cb168e4 100644 --- a/v1-preparations/ESSENTIAL_GA_OBJECTS.MD +++ b/v1-preparations/ESSENTIAL_GA_OBJECTS.MD @@ -28,12 +28,12 @@ newer provider versions. We will address these while working on the given object | PROCEDURE | ❌ | [#2735](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2735), [#2623](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2623), [#2257](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2257), [#2146](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2146), [#1855](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/1855), [#1695](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/1695), [#1640](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/1640), [#1195](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/1195), [#1189](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/1189), [#1178](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/1178), [#1050](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/1050) | | ROW ACCESS POLICY | ❌ | [#2053](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2053), [#1600](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/1600), [#1151](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/1151) | | SCHEMA | 🚀 | [#2826](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2826), [#2211](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2211), [#1243](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/1243), [#506](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/506) | -| STAGE | ❌ | [#2818](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2818), [#2505](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2505), [#1911](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/1911), [#1903](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/1903), [#1795](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/1795), [#1705](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/1705), [#1544](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/1544), [#1491](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/1491), [#1087](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/1087), [#265](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/265) | +| STAGE | ❌ | [#2995](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2995), [#2818](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2818), [#2505](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2505), [#1911](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/1911), [#1903](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/1903), [#1795](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/1795), [#1705](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/1705), [#1544](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/1544), [#1491](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/1491), [#1087](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/1087), [#265](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/265) | | STREAM | ❌ | [#2975](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2975), [#2413](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2413), [#2201](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2201), [#1150](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/1150) | | STREAMLIT | 🚀 | [#1933](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/1933) | -| TABLE | ❌ | [#2844](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2844), [#2839](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2839), [#2735](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2735), [#2733](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2733), [#2683](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2683), [#2676](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2676), [#2674](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2674), [#2629](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2629), [#2418](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2418), [#2415](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2415), [#2406](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2406), [#2236](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2236), [#2035](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2035), [#1823](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/1823), [#1799](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/1799), [#1764](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/1764), [#1600](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/1600), [#1387](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/1387), [#1272](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/1272), [#1271](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/1271), [#1248](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/1248), [#1241](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/1241), [#1146](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/1146), [#1032](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/1032), [#420](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/420) | +| TABLE | ❌ | [#2997](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2997), [#2844](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2844), [#2839](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2839), [#2735](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2735), [#2733](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2733), [#2683](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2683), [#2676](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2676), [#2674](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2674), [#2629](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2629), [#2418](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2418), [#2415](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2415), [#2406](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2406), [#2236](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2236), [#2035](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2035), [#1823](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/1823), [#1799](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/1799), [#1764](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/1764), [#1600](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/1600), [#1387](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/1387), [#1272](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/1272), [#1271](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/1271), [#1248](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/1248), [#1241](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/1241), [#1146](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/1146), [#1032](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/1032), [#420](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/420) | | TAG | ❌ | [#2943](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2902), [#2598](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2598), [#1910](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/1910), [#1909](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/1909), [#1862](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/1862), [#1806](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/1806), [#1657](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/1657), [#1496](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/1496), [#1443](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/1443), [#1394](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/1394), [#1372](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/1372), [#1074](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/1074) | | TASK | ❌ | [#1419](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/1419), [#1250](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/1250), [#1194](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/1194), [#1088](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/1088) | -| VIEW | 👨‍💻 | [#2430](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2430), [#2085](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2085), [#2055](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2055), [#2031](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2031), [#1526](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/1526), [#1253](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/1253), [#1049](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/1049) | +| VIEW | 👨‍💻 | [#3000](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/3000), [#2430](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2430), [#2085](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2085), [#2055](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2055), [#2031](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2031), [#1526](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/1526), [#1253](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/1253), [#1049](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/1049) | | snowflake_unsafe_execute | ❌ | [#2934](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2934) | From 8544eefa2beb5188b2499e79c9a124a11b865f11 Mon Sep 17 00:00:00 2001 From: Jakub Michalak Date: Fri, 16 Aug 2024 15:00:27 +0200 Subject: [PATCH 12/19] Cleanup --- pkg/acceptance/helpers/function_client.go | 51 +------------------ ...vileges_to_account_role_acceptance_test.go | 6 +-- ...ileges_to_database_role_acceptance_test.go | 6 +-- ...ant_privileges_to_share_acceptance_test.go | 6 +-- 4 files changed, 7 insertions(+), 62 deletions(-) diff --git a/pkg/acceptance/helpers/function_client.go b/pkg/acceptance/helpers/function_client.go index 742640d1d5..05327f633c 100644 --- a/pkg/acceptance/helpers/function_client.go +++ b/pkg/acceptance/helpers/function_client.go @@ -41,7 +41,7 @@ func (c *FunctionClient) CreateWithIdentifier(t *testing.T, id sdk.SchemaObjectI id.SchemaObjectId(), *sdk.NewFunctionReturnsRequest().WithResultDataType(*sdk.NewFunctionReturnsResultDataTypeRequest(sdk.DataTypeInt)), "SELECT 1", - ).WithArguments(argumentRequests), + ).WithArguments(argumentRequests).WithSecure(true), ) require.NoError(t, err) @@ -54,52 +54,3 @@ func (c *FunctionClient) CreateWithIdentifier(t *testing.T, id sdk.SchemaObjectI return function } - -func (c *FunctionClient) CreateFunction(t *testing.T) (*sdk.Function, func()) { - t.Helper() - definition := "3.141592654::FLOAT" - id := c.ids.RandomSchemaObjectIdentifierWithArguments(sdk.DataTypeFloat) - dt := sdk.NewFunctionReturnsResultDataTypeRequest(sdk.DataTypeFloat) - returns := sdk.NewFunctionReturnsRequest().WithResultDataType(*dt) - argument := sdk.NewFunctionArgumentRequest("x", sdk.DataTypeFloat) - return c.CreateFunctionWithRequest(t, id, - sdk.NewCreateForSQLFunctionRequest(id.SchemaObjectId(), *returns, definition). - WithSecure(true). - WithArguments([]sdk.FunctionArgumentRequest{*argument}), - ) -} - -func (c *FunctionClient) CreateFunctionWithoutArguments(t *testing.T) (*sdk.Function, func()) { - t.Helper() - definition := "3.141592654::FLOAT" - id := c.ids.RandomSchemaObjectIdentifierWithArguments() - dt := sdk.NewFunctionReturnsResultDataTypeRequest(sdk.DataTypeFloat) - returns := sdk.NewFunctionReturnsRequest().WithResultDataType(*dt) - return c.CreateFunctionWithRequest(t, id, - sdk.NewCreateForSQLFunctionRequest(id.SchemaObjectId(), *returns, definition). - WithSecure(true), - ) -} - -func (c *FunctionClient) CreateFunctionWithRequest(t *testing.T, id sdk.SchemaObjectIdentifierWithArguments, request *sdk.CreateForSQLFunctionRequest) (*sdk.Function, func()) { - t.Helper() - ctx := context.Background() - - err := c.client().CreateForSQL(ctx, request) - require.NoError(t, err) - - Function, err := c.client().ShowByID(ctx, id) - require.NoError(t, err) - - return Function, c.DropFunctionFunc(t, id) -} - -func (c *FunctionClient) DropFunctionFunc(t *testing.T, id sdk.SchemaObjectIdentifierWithArguments) func() { - t.Helper() - ctx := context.Background() - - return func() { - err := c.client().Drop(ctx, sdk.NewDropFunctionRequest(id).WithIfExists(true)) - require.NoError(t, err) - } -} diff --git a/pkg/resources/grant_privileges_to_account_role_acceptance_test.go b/pkg/resources/grant_privileges_to_account_role_acceptance_test.go index 5f8a21e782..2a938f75ae 100644 --- a/pkg/resources/grant_privileges_to_account_role_acceptance_test.go +++ b/pkg/resources/grant_privileges_to_account_role_acceptance_test.go @@ -450,8 +450,7 @@ func TestAcc_GrantPrivilegesToAccountRole_OnSchemaObject_OnFunction(t *testing.T roleId := acc.TestClient().Ids.RandomAccountObjectIdentifier() roleFullyQualifiedName := roleId.FullyQualifiedName() - function, functionCleanup := acc.TestClient().Function.CreateFunction(t) - t.Cleanup(functionCleanup) + function := acc.TestClient().Function.Create(t, sdk.DataTypeFloat) configVariables := config.Variables{ "name": config.StringVariable(roleFullyQualifiedName), "function_name": config.StringVariable(function.ID().Name()), @@ -506,8 +505,7 @@ func TestAcc_GrantPrivilegesToAccountRole_OnSchemaObject_OnFunctionWithoutArgume roleId := acc.TestClient().Ids.RandomAccountObjectIdentifier() roleFullyQualifiedName := roleId.FullyQualifiedName() - function, functionCleanup := acc.TestClient().Function.CreateFunctionWithoutArguments(t) - t.Cleanup(functionCleanup) + function := acc.TestClient().Function.Create(t) configVariables := config.Variables{ "name": config.StringVariable(roleFullyQualifiedName), "function_name": config.StringVariable(function.ID().Name()), diff --git a/pkg/resources/grant_privileges_to_database_role_acceptance_test.go b/pkg/resources/grant_privileges_to_database_role_acceptance_test.go index e2c49c0cf6..897b42090f 100644 --- a/pkg/resources/grant_privileges_to_database_role_acceptance_test.go +++ b/pkg/resources/grant_privileges_to_database_role_acceptance_test.go @@ -629,8 +629,7 @@ func TestAcc_GrantPrivilegesToDatabaseRole_OnSchemaObject_OnFunction(t *testing. acc.TestAccPreCheck(t) databaseRoleId := acc.TestClient().Ids.RandomDatabaseObjectIdentifier() - function, functionCleanup := acc.TestClient().Function.CreateFunction(t) - t.Cleanup(functionCleanup) + function := acc.TestClient().Function.Create(t, sdk.DataTypeFloat) configVariables := config.Variables{ "name": config.StringVariable(databaseRoleId.FullyQualifiedName()), "function_name": config.StringVariable(function.ID().Name()), @@ -684,8 +683,7 @@ func TestAcc_GrantPrivilegesToDatabaseRole_OnSchemaObject_OnFunctionWithoutArgum acc.TestAccPreCheck(t) databaseRoleId := acc.TestClient().Ids.RandomDatabaseObjectIdentifier() - function, functionCleanup := acc.TestClient().Function.CreateFunctionWithoutArguments(t) - t.Cleanup(functionCleanup) + function := acc.TestClient().Function.Create(t) configVariables := config.Variables{ "name": config.StringVariable(databaseRoleId.FullyQualifiedName()), "function_name": config.StringVariable(function.ID().Name()), diff --git a/pkg/resources/grant_privileges_to_share_acceptance_test.go b/pkg/resources/grant_privileges_to_share_acceptance_test.go index 4833733c36..a0e494a65f 100644 --- a/pkg/resources/grant_privileges_to_share_acceptance_test.go +++ b/pkg/resources/grant_privileges_to_share_acceptance_test.go @@ -342,8 +342,7 @@ func TestAcc_GrantPrivilegesToShare_OnSchemaObject_OnFunction(t *testing.T) { share, shareCleanup := acc.TestClient().Share.CreateShare(t) t.Cleanup(shareCleanup) - function, functionCleanup := acc.TestClient().Function.CreateFunction(t) - t.Cleanup(functionCleanup) + function := acc.TestClient().Function.Create(t, sdk.DataTypeFloat) configVariables := config.Variables{ "name": config.StringVariable(share.ID().Name()), "function_name": config.StringVariable(function.ID().Name()), @@ -390,8 +389,7 @@ func TestAcc_GrantPrivilegesToShare_OnSchemaObject_OnFunctionWithoutArguments(t share, shareCleanup := acc.TestClient().Share.CreateShare(t) t.Cleanup(shareCleanup) - function, functionCleanup := acc.TestClient().Function.CreateFunctionWithoutArguments(t) - t.Cleanup(functionCleanup) + function := acc.TestClient().Function.Create(t) configVariables := config.Variables{ "name": config.StringVariable(share.ID().Name()), "function_name": config.StringVariable(function.ID().Name()), From 2ff65e8a75926148b469dd5eddd1a7a22b9c25b7 Mon Sep 17 00:00:00 2001 From: Jakub Michalak Date: Fri, 16 Aug 2024 15:36:19 +0200 Subject: [PATCH 13/19] Cleanup --- docs/resources/grant_privileges_to_share.md | 1 + pkg/resources/grant_privileges_to_account_role.go | 2 +- pkg/resources/grant_privileges_to_account_role_identifier.go | 3 +-- pkg/resources/grant_privileges_to_database_role.go | 2 +- pkg/resources/grant_privileges_to_database_role_identifier.go | 3 +-- .../OnObject_Procedure_ToAccountRole/test.tf | 2 +- .../OnObject_Procedure_ToDatabaseRole/test.tf | 2 +- pkg/sdk/grants.go | 4 +--- pkg/sdk/object_types.go | 4 ++++ pkg/sdk/testint/grants_integration_test.go | 3 +-- 10 files changed, 13 insertions(+), 13 deletions(-) diff --git a/docs/resources/grant_privileges_to_share.md b/docs/resources/grant_privileges_to_share.md index cece122b27..10a66027cb 100644 --- a/docs/resources/grant_privileges_to_share.md +++ b/docs/resources/grant_privileges_to_share.md @@ -108,6 +108,7 @@ resource "snowflake_grant_privileges_to_share" "example" { - `on_all_tables_in_schema` (String) The fully qualified identifier for the schema for which the specified privilege will be granted for all tables. - `on_database` (String) The fully qualified name of the database on which privileges will be granted. +- `on_function` (String) The fully qualified name of the function on which privileges will be granted. - `on_schema` (String) The fully qualified name of the schema on which privileges will be granted. - `on_table` (String) The fully qualified name of the table on which privileges will be granted. - `on_tag` (String) The fully qualified name of the tag on which privileges will be granted. diff --git a/pkg/resources/grant_privileges_to_account_role.go b/pkg/resources/grant_privileges_to_account_role.go index f014746ccb..ae00859732 100644 --- a/pkg/resources/grant_privileges_to_account_role.go +++ b/pkg/resources/grant_privileges_to_account_role.go @@ -1076,7 +1076,7 @@ func getAccountRoleGrantOn(d *schema.ResourceData) (*sdk.AccountRoleGrantOn, err case objectTypeOk && objectNameOk: objectType := sdk.ObjectType(objectType) var id sdk.ObjectIdentifier - if slices.Contains([]sdk.ObjectType{sdk.ObjectTypeFunction, sdk.ObjectTypeProcedure, sdk.ObjectTypeExternalFunction}, objectType) { + if objectType.IsWithArguments() { var err error id, err = sdk.ParseSchemaObjectIdentifierWithArguments(objectName) if err != nil { diff --git a/pkg/resources/grant_privileges_to_account_role_identifier.go b/pkg/resources/grant_privileges_to_account_role_identifier.go index 6612f5b3c1..6c16423c33 100644 --- a/pkg/resources/grant_privileges_to_account_role_identifier.go +++ b/pkg/resources/grant_privileges_to_account_role_identifier.go @@ -2,7 +2,6 @@ package resources import ( "fmt" - "slices" "strconv" "strings" @@ -140,7 +139,7 @@ func ParseGrantPrivilegesToAccountRoleId(id string) (GrantPrivilegesToAccountRol } objectType := sdk.ObjectType(parts[6]) var id sdk.ObjectIdentifier - if slices.Contains([]sdk.ObjectType{sdk.ObjectTypeFunction, sdk.ObjectTypeProcedure, sdk.ObjectTypeExternalFunction}, objectType) { + if objectType.IsWithArguments() { var err error id, err = sdk.ParseSchemaObjectIdentifierWithArguments(parts[7]) if err != nil { diff --git a/pkg/resources/grant_privileges_to_database_role.go b/pkg/resources/grant_privileges_to_database_role.go index 07f1dd8788..34052cf66d 100644 --- a/pkg/resources/grant_privileges_to_database_role.go +++ b/pkg/resources/grant_privileges_to_database_role.go @@ -905,7 +905,7 @@ func getDatabaseRoleGrantOn(d *schema.ResourceData) (*sdk.DatabaseRoleGrantOn, e case objectTypeOk && objectNameOk: objectType := sdk.ObjectType(objectType) var id sdk.ObjectIdentifier - if slices.Contains([]sdk.ObjectType{sdk.ObjectTypeFunction, sdk.ObjectTypeProcedure, sdk.ObjectTypeExternalFunction}, objectType) { + if objectType.IsWithArguments() { var err error id, err = sdk.ParseSchemaObjectIdentifierWithArguments(objectName) if err != nil { diff --git a/pkg/resources/grant_privileges_to_database_role_identifier.go b/pkg/resources/grant_privileges_to_database_role_identifier.go index 4c45cfa5b9..1b7db1c86a 100644 --- a/pkg/resources/grant_privileges_to_database_role_identifier.go +++ b/pkg/resources/grant_privileges_to_database_role_identifier.go @@ -2,7 +2,6 @@ package resources import ( "fmt" - "slices" "strconv" "strings" @@ -130,7 +129,7 @@ func ParseGrantPrivilegesToDatabaseRoleId(id string) (GrantPrivilegesToDatabaseR } objectType := sdk.ObjectType(parts[6]) var id sdk.ObjectIdentifier - if slices.Contains([]sdk.ObjectType{sdk.ObjectTypeFunction, sdk.ObjectTypeProcedure, sdk.ObjectTypeExternalFunction}, objectType) { + if objectType.IsWithArguments() { var err error id, err = sdk.ParseSchemaObjectIdentifierWithArguments(parts[7]) if err != nil { diff --git a/pkg/resources/testdata/TestAcc_GrantOwnership/OnObject_Procedure_ToAccountRole/test.tf b/pkg/resources/testdata/TestAcc_GrantOwnership/OnObject_Procedure_ToAccountRole/test.tf index 4022ad0b83..d33faa2af3 100644 --- a/pkg/resources/testdata/TestAcc_GrantOwnership/OnObject_Procedure_ToAccountRole/test.tf +++ b/pkg/resources/testdata/TestAcc_GrantOwnership/OnObject_Procedure_ToAccountRole/test.tf @@ -34,6 +34,6 @@ resource "snowflake_grant_ownership" "test" { account_role_name = snowflake_account_role.test.name on { object_type = "PROCEDURE" - object_name = "\"${snowflake_database.test.name}\".\"${snowflake_schema.test.name}\".\"${snowflake_procedure.test.name}\"(FLOAT)" + object_name = snowflake_procedure.test.fully_qualified_name } } diff --git a/pkg/resources/testdata/TestAcc_GrantOwnership/OnObject_Procedure_ToDatabaseRole/test.tf b/pkg/resources/testdata/TestAcc_GrantOwnership/OnObject_Procedure_ToDatabaseRole/test.tf index 91a696da17..bd1ae0af07 100644 --- a/pkg/resources/testdata/TestAcc_GrantOwnership/OnObject_Procedure_ToDatabaseRole/test.tf +++ b/pkg/resources/testdata/TestAcc_GrantOwnership/OnObject_Procedure_ToDatabaseRole/test.tf @@ -31,6 +31,6 @@ resource "snowflake_grant_ownership" "test" { database_role_name = "\"${snowflake_database.test.name}\".\"${snowflake_database_role.test.name}\"" on { object_type = "PROCEDURE" - object_name = "\"${snowflake_database.test.name}\".\"${snowflake_schema.test.name}\".\"${snowflake_procedure.test.name}\"()" + object_name = snowflake_procedure.test.fully_qualified_name } } diff --git a/pkg/sdk/grants.go b/pkg/sdk/grants.go index 138d2ea5ec..e4128a4d27 100644 --- a/pkg/sdk/grants.go +++ b/pkg/sdk/grants.go @@ -3,7 +3,6 @@ package sdk import ( "context" "log" - "slices" "strings" "time" ) @@ -262,8 +261,7 @@ func (row grantRow) convert() *Grant { var name ObjectIdentifier var err error - // external function is represented as FUNCTION - if slices.Contains([]string{"FUNCTION", "PROCEDURE"}, row.GrantedOn) { + if ObjectType(row.GrantedOn).IsWithArguments() { name, err = ParseSchemaObjectIdentifierWithArgumentsAndReturnType(row.Name) } else { name, err = ParseObjectIdentifierString(row.Name) diff --git a/pkg/sdk/object_types.go b/pkg/sdk/object_types.go index e22da6e9b2..d54b3337fd 100644 --- a/pkg/sdk/object_types.go +++ b/pkg/sdk/object_types.go @@ -81,6 +81,10 @@ func (o ObjectType) String() string { return string(o) } +func (o ObjectType) IsWithArguments() bool { + return slices.Contains([]ObjectType{ObjectTypeExternalFunction, ObjectTypeFunction, ObjectTypeProcedure}, o) +} + func objectTypeSingularToPluralMap() map[ObjectType]PluralObjectType { return map[ObjectType]PluralObjectType{ ObjectTypeAccount: PluralObjectTypeAccounts, diff --git a/pkg/sdk/testint/grants_integration_test.go b/pkg/sdk/testint/grants_integration_test.go index 2c85485ba0..33d344b226 100644 --- a/pkg/sdk/testint/grants_integration_test.go +++ b/pkg/sdk/testint/grants_integration_test.go @@ -904,8 +904,7 @@ func TestInt_GrantPrivilegeToShare(t *testing.T) { }) require.NoError(t, err) - function, functionCleanup := testClientHelper().Function.CreateFunction(t) - t.Cleanup(functionCleanup) + function := testClientHelper().Function.Create(t) err = client.Grants.GrantPrivilegeToShare(ctx, []sdk.ObjectPrivilege{sdk.ObjectPrivilegeUsage}, &sdk.ShareGrantOn{ Function: function.ID(), From fddfa48bb17aac2fb95c1a91638b6798170565b9 Mon Sep 17 00:00:00 2001 From: Jakub Michalak Date: Fri, 16 Aug 2024 15:39:25 +0200 Subject: [PATCH 14/19] Fix error msg --- pkg/resources/grant_privileges_to_share_identifier.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/resources/grant_privileges_to_share_identifier.go b/pkg/resources/grant_privileges_to_share_identifier.go index f2422c97ac..9c7bf09c22 100644 --- a/pkg/resources/grant_privileges_to_share_identifier.go +++ b/pkg/resources/grant_privileges_to_share_identifier.go @@ -56,7 +56,7 @@ func ParseGrantPrivilegesToShareId(idString string) (GrantPrivilegesToShareId, e case OnDatabaseShareGrantKind: id, err := sdk.ParseAccountObjectIdentifier(parts[3]) if err != nil { - return grantPrivilegesToShareId, sdk.NewError(fmt.Sprintf("invalid identifier, expected fully qualified name of database object%s: ", parts[3]), err) + return grantPrivilegesToShareId, sdk.NewError(fmt.Sprintf("invalid identifier, expected fully qualified name of account object%s: ", parts[3]), err) } grantPrivilegesToShareId.Identifier = id case OnSchemaShareGrantKind, OnAllTablesInSchemaShareGrantKind: From d3585f1f119edcc11ac043c12766fdd1ab9ca944 Mon Sep 17 00:00:00 2001 From: Jakub Michalak Date: Mon, 19 Aug 2024 11:16:49 +0200 Subject: [PATCH 15/19] Fix tests --- pkg/resources/external_function_acceptance_test.go | 2 +- pkg/resources/procedure_acceptance_test.go | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pkg/resources/external_function_acceptance_test.go b/pkg/resources/external_function_acceptance_test.go index 1599b6cfe7..1ce5b27cf1 100644 --- a/pkg/resources/external_function_acceptance_test.go +++ b/pkg/resources/external_function_acceptance_test.go @@ -152,7 +152,7 @@ func TestAcc_ExternalFunction_no_arguments(t *testing.T) { } func TestAcc_ExternalFunction_complete(t *testing.T) { - id := acc.TestClient().Ids.RandomSchemaObjectIdentifier() + id := acc.TestClient().Ids.RandomSchemaObjectIdentifierWithArguments() m := func() map[string]config.Variable { return map[string]config.Variable{ diff --git a/pkg/resources/procedure_acceptance_test.go b/pkg/resources/procedure_acceptance_test.go index 1ecb7c3b85..6d0ec93b58 100644 --- a/pkg/resources/procedure_acceptance_test.go +++ b/pkg/resources/procedure_acceptance_test.go @@ -17,8 +17,8 @@ import ( func testAccProcedure(t *testing.T, configDirectory string, args ...sdk.DataType) { t.Helper() - oldId := acc.TestClient().Ids.RandomSchemaObjectIdentifierWithArgumentsOld(args...) - newId := acc.TestClient().Ids.RandomSchemaObjectIdentifierWithArgumentsOld(args...) + oldId := acc.TestClient().Ids.RandomSchemaObjectIdentifierWithArguments(args...) + newId := acc.TestClient().Ids.RandomSchemaObjectIdentifierWithArguments(args...) resourceName := "snowflake_procedure.p" m := func() map[string]config.Variable { From bfe34b402d120f6c8c0a0023d7410e51457ff905 Mon Sep 17 00:00:00 2001 From: Jakub Michalak Date: Mon, 19 Aug 2024 15:46:27 +0200 Subject: [PATCH 16/19] Review suggestions --- pkg/acceptance/helpers/function_client.go | 33 +++++++-- .../grant_ownership_acceptance_test.go | 31 +++----- pkg/resources/grant_ownership_test.go | 10 +-- ...vileges_to_account_role_acceptance_test.go | 9 ++- ...ileges_to_database_role_acceptance_test.go | 9 ++- ...ant_privileges_to_share_acceptance_test.go | 64 +++++++++++++++- pkg/sdk/testint/grants_integration_test.go | 74 ++++++++++++++----- 7 files changed, 170 insertions(+), 60 deletions(-) diff --git a/pkg/acceptance/helpers/function_client.go b/pkg/acceptance/helpers/function_client.go index 05327f633c..421623b17d 100644 --- a/pkg/acceptance/helpers/function_client.go +++ b/pkg/acceptance/helpers/function_client.go @@ -31,18 +31,37 @@ func (c *FunctionClient) Create(t *testing.T, arguments ...sdk.DataType) *sdk.Fu func (c *FunctionClient) CreateWithIdentifier(t *testing.T, id sdk.SchemaObjectIdentifierWithArguments) *sdk.Function { t.Helper() - ctx := context.Background() - argumentRequests := make([]sdk.FunctionArgumentRequest, len(id.ArgumentDataTypes())) - for i, argumentDataType := range id.ArgumentDataTypes() { - argumentRequests[i] = *sdk.NewFunctionArgumentRequest(c.ids.Alpha(), argumentDataType) - } - err := c.client().CreateForSQL(ctx, + + return c.CreateWithRequest(t, id, sdk.NewCreateForSQLFunctionRequest( id.SchemaObjectId(), *sdk.NewFunctionReturnsRequest().WithResultDataType(*sdk.NewFunctionReturnsResultDataTypeRequest(sdk.DataTypeInt)), "SELECT 1", - ).WithArguments(argumentRequests).WithSecure(true), + ), ) +} + +func (c *FunctionClient) CreateSecure(t *testing.T, arguments ...sdk.DataType) *sdk.Function { + t.Helper() + id := c.ids.RandomSchemaObjectIdentifierWithArguments(arguments...) + + return c.CreateWithRequest(t, id, + sdk.NewCreateForSQLFunctionRequest( + id.SchemaObjectId(), + *sdk.NewFunctionReturnsRequest().WithResultDataType(*sdk.NewFunctionReturnsResultDataTypeRequest(sdk.DataTypeInt)), + "SELECT 1", + ).WithSecure(true), + ) +} + +func (c *FunctionClient) CreateWithRequest(t *testing.T, id sdk.SchemaObjectIdentifierWithArguments, req *sdk.CreateForSQLFunctionRequest) *sdk.Function { + t.Helper() + ctx := context.Background() + argumentRequests := make([]sdk.FunctionArgumentRequest, len(id.ArgumentDataTypes())) + for i, argumentDataType := range id.ArgumentDataTypes() { + argumentRequests[i] = *sdk.NewFunctionArgumentRequest(c.ids.Alpha(), argumentDataType) + } + err := c.client().CreateForSQL(ctx, req.WithArguments(argumentRequests)) require.NoError(t, err) t.Cleanup(func() { diff --git a/pkg/resources/grant_ownership_acceptance_test.go b/pkg/resources/grant_ownership_acceptance_test.go index e3c3dac8a4..1d2357b00d 100644 --- a/pkg/resources/grant_ownership_acceptance_test.go +++ b/pkg/resources/grant_ownership_acceptance_test.go @@ -323,19 +323,16 @@ func TestAcc_GrantOwnership_OnObject_Table_ToDatabaseRole(t *testing.T) { }) } -func TestAcc_GrantOwnership_OnObject_Procedure_ToAccountRole(t *testing.T) { +func TestAcc_GrantOwnership_OnObject_ProcedureWithArguments_ToAccountRole(t *testing.T) { databaseId := acc.TestClient().Ids.RandomAccountObjectIdentifier() - databaseName := databaseId.Name() schemaId := acc.TestClient().Ids.RandomDatabaseObjectIdentifierInDatabase(databaseId) - schemaName := schemaId.Name() procedureId := acc.TestClient().Ids.NewSchemaObjectIdentifierWithArgumentsInSchema(acc.TestClient().Ids.Alpha(), schemaId, sdk.DataTypeFloat) accountRoleId := acc.TestClient().Ids.RandomAccountObjectIdentifier() - accountRoleName := accountRoleId.Name() configVariables := config.Variables{ - "account_role_name": config.StringVariable(accountRoleName), - "database_name": config.StringVariable(databaseName), - "schema_name": config.StringVariable(schemaName), + "account_role_name": config.StringVariable(accountRoleId.Name()), + "database_name": config.StringVariable(databaseId.Name()), + "schema_name": config.StringVariable(schemaId.Name()), "procedure_name": config.StringVariable(procedureId.Name()), } resourceName := "snowflake_grant_ownership.test" @@ -351,7 +348,7 @@ func TestAcc_GrantOwnership_OnObject_Procedure_ToAccountRole(t *testing.T) { ConfigDirectory: acc.ConfigurationDirectory("TestAcc_GrantOwnership/OnObject_Procedure_ToAccountRole"), ConfigVariables: configVariables, Check: resource.ComposeTestCheckFunc( - resource.TestCheckResourceAttr(resourceName, "account_role_name", accountRoleName), + resource.TestCheckResourceAttr(resourceName, "account_role_name", accountRoleId.Name()), resource.TestCheckResourceAttr(resourceName, "on.0.object_type", "PROCEDURE"), resource.TestCheckResourceAttr(resourceName, "on.0.object_name", procedureId.FullyQualifiedName()), resource.TestCheckResourceAttr(resourceName, "id", fmt.Sprintf("ToAccountRole|%s||OnObject|PROCEDURE|%s", accountRoleId.FullyQualifiedName(), procedureId.FullyQualifiedName())), @@ -359,7 +356,7 @@ func TestAcc_GrantOwnership_OnObject_Procedure_ToAccountRole(t *testing.T) { To: &sdk.ShowGrantsTo{ Role: accountRoleId, }, - }, sdk.ObjectTypeProcedure, accountRoleName, procedureId.FullyQualifiedName()), + }, sdk.ObjectTypeProcedure, accountRoleId.Name(), procedureId.FullyQualifiedName()), ), }, { @@ -375,19 +372,15 @@ func TestAcc_GrantOwnership_OnObject_Procedure_ToAccountRole(t *testing.T) { func TestAcc_GrantOwnership_OnObject_ProcedureWithoutArguments_ToDatabaseRole(t *testing.T) { databaseId := acc.TestClient().Ids.RandomAccountObjectIdentifier() - databaseName := databaseId.Name() schemaId := acc.TestClient().Ids.RandomDatabaseObjectIdentifierInDatabase(databaseId) - schemaName := schemaId.Name() procedureId := acc.TestClient().Ids.NewSchemaObjectIdentifierWithArgumentsInSchema(acc.TestClient().Ids.Alpha(), schemaId) databaseRoleId := acc.TestClient().Ids.RandomDatabaseObjectIdentifierInDatabase(databaseId) - databaseRoleName := databaseRoleId.Name() - databaseRoleFullyQualifiedName := databaseRoleId.FullyQualifiedName() configVariables := config.Variables{ - "database_role_name": config.StringVariable(databaseRoleName), - "database_name": config.StringVariable(databaseName), - "schema_name": config.StringVariable(schemaName), + "database_role_name": config.StringVariable(databaseRoleId.Name()), + "database_name": config.StringVariable(databaseId.Name()), + "schema_name": config.StringVariable(schemaId.Name()), "procedure_name": config.StringVariable(procedureId.Name()), } resourceName := "snowflake_grant_ownership.test" @@ -403,15 +396,15 @@ func TestAcc_GrantOwnership_OnObject_ProcedureWithoutArguments_ToDatabaseRole(t ConfigDirectory: acc.ConfigurationDirectory("TestAcc_GrantOwnership/OnObject_Procedure_ToDatabaseRole"), ConfigVariables: configVariables, Check: resource.ComposeTestCheckFunc( - resource.TestCheckResourceAttr(resourceName, "database_role_name", databaseRoleFullyQualifiedName), + resource.TestCheckResourceAttr(resourceName, "database_role_name", databaseRoleId.FullyQualifiedName()), resource.TestCheckResourceAttr(resourceName, "on.0.object_type", "PROCEDURE"), resource.TestCheckResourceAttr(resourceName, "on.0.object_name", procedureId.FullyQualifiedName()), - resource.TestCheckResourceAttr(resourceName, "id", fmt.Sprintf("ToDatabaseRole|%s||OnObject|PROCEDURE|%s", databaseRoleFullyQualifiedName, procedureId.FullyQualifiedName())), + resource.TestCheckResourceAttr(resourceName, "id", fmt.Sprintf("ToDatabaseRole|%s||OnObject|PROCEDURE|%s", databaseRoleId.FullyQualifiedName(), procedureId.FullyQualifiedName())), checkResourceOwnershipIsGranted(&sdk.ShowGrantOptions{ To: &sdk.ShowGrantsTo{ DatabaseRole: databaseRoleId, }, - }, sdk.ObjectTypeProcedure, databaseRoleName, procedureId.FullyQualifiedName()), + }, sdk.ObjectTypeProcedure, databaseRoleId.Name(), procedureId.FullyQualifiedName()), ), }, { diff --git a/pkg/resources/grant_ownership_test.go b/pkg/resources/grant_ownership_test.go index 25d898b9a0..6c0cd938e3 100644 --- a/pkg/resources/grant_ownership_test.go +++ b/pkg/resources/grant_ownership_test.go @@ -55,14 +55,14 @@ func TestGetOnObjectIdentifier(t *testing.T) { { Name: "account object identifier with dots", ObjectType: sdk.ObjectTypeDatabase, - ObjectName: "database.name.with.dots", + ObjectName: "\"database.name.with.dots\"", Expected: sdk.NewAccountObjectIdentifier("database.name.with.dots"), }, { Name: "validation - valid identifier", ObjectType: sdk.ObjectTypeDatabase, ObjectName: "to.many.parts.in.this.identifier", - Error: "unable to classify identifier", + Error: "unexpected number of parts 6 in identifier to.many.parts.in.this.identifier, expected 1 in a form of \"\"", }, { Name: "validation - unsupported type", @@ -74,13 +74,13 @@ func TestGetOnObjectIdentifier(t *testing.T) { Name: "validation - invalid database object identifier", ObjectType: sdk.ObjectTypeSchema, ObjectName: "test_database.test_schema.test_table", - Error: "invalid object_name test_database.test_schema.test_table, expected database object identifier", + Error: "unexpected number of parts 3 in identifier test_database.test_schema.test_table, expected 2 in a form of \".\"", }, { Name: "validation - invalid schema object identifier", ObjectType: sdk.ObjectTypeTable, ObjectName: "test_database.test_schema.test_table.column_name", - Error: "invalid object_name test_database.test_schema.test_table.column_name, expected schema object identifier", + Error: "unexpected number of parts 4 in identifier test_database.test_schema.test_table.column_name, expected 3 in a form of \"..\"", }, } @@ -248,7 +248,7 @@ func TestGetOwnershipGrantOn(t *testing.T) { "object_type": "SCHEMA", "object_name": "test_database.test_schema.test_table", }, - Error: "invalid object_name test_database.test_schema.test_table, expected database object identifier", + Error: "unexpected number of parts 3 in identifier test_database.test_schema.test_table, expected 2 in a form of \".\"", }, } diff --git a/pkg/resources/grant_privileges_to_account_role_acceptance_test.go b/pkg/resources/grant_privileges_to_account_role_acceptance_test.go index 2a938f75ae..6386860639 100644 --- a/pkg/resources/grant_privileges_to_account_role_acceptance_test.go +++ b/pkg/resources/grant_privileges_to_account_role_acceptance_test.go @@ -7,6 +7,7 @@ import ( "testing" acc "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/acceptance" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/acceptance/testenvs" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" "github.com/hashicorp/terraform-plugin-testing/config" @@ -445,12 +446,13 @@ func TestAcc_GrantPrivilegesToAccountRole_OnSchemaObject_OnObject(t *testing.T) }) } -func TestAcc_GrantPrivilegesToAccountRole_OnSchemaObject_OnFunction(t *testing.T) { +func TestAcc_GrantPrivilegesToAccountRole_OnSchemaObject_OnFunctionWithArguments(t *testing.T) { + _ = testenvs.GetOrSkipTest(t, testenvs.EnableAcceptance) acc.TestAccPreCheck(t) roleId := acc.TestClient().Ids.RandomAccountObjectIdentifier() roleFullyQualifiedName := roleId.FullyQualifiedName() - function := acc.TestClient().Function.Create(t, sdk.DataTypeFloat) + function := acc.TestClient().Function.CreateSecure(t, sdk.DataTypeFloat) configVariables := config.Variables{ "name": config.StringVariable(roleFullyQualifiedName), "function_name": config.StringVariable(function.ID().Name()), @@ -501,11 +503,12 @@ func TestAcc_GrantPrivilegesToAccountRole_OnSchemaObject_OnFunction(t *testing.T } func TestAcc_GrantPrivilegesToAccountRole_OnSchemaObject_OnFunctionWithoutArguments(t *testing.T) { + _ = testenvs.GetOrSkipTest(t, testenvs.EnableAcceptance) acc.TestAccPreCheck(t) roleId := acc.TestClient().Ids.RandomAccountObjectIdentifier() roleFullyQualifiedName := roleId.FullyQualifiedName() - function := acc.TestClient().Function.Create(t) + function := acc.TestClient().Function.CreateSecure(t) configVariables := config.Variables{ "name": config.StringVariable(roleFullyQualifiedName), "function_name": config.StringVariable(function.ID().Name()), diff --git a/pkg/resources/grant_privileges_to_database_role_acceptance_test.go b/pkg/resources/grant_privileges_to_database_role_acceptance_test.go index 897b42090f..b98d23578a 100644 --- a/pkg/resources/grant_privileges_to_database_role_acceptance_test.go +++ b/pkg/resources/grant_privileges_to_database_role_acceptance_test.go @@ -7,6 +7,7 @@ import ( "testing" acc "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/acceptance" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/acceptance/testenvs" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" "github.com/hashicorp/terraform-plugin-testing/config" @@ -625,11 +626,12 @@ func TestAcc_GrantPrivilegesToDatabaseRole_OnSchemaObject_OnAll_Streamlits_InDat }) } -func TestAcc_GrantPrivilegesToDatabaseRole_OnSchemaObject_OnFunction(t *testing.T) { +func TestAcc_GrantPrivilegesToDatabaseRole_OnSchemaObject_OnFunctionWithArguments(t *testing.T) { + _ = testenvs.GetOrSkipTest(t, testenvs.EnableAcceptance) acc.TestAccPreCheck(t) databaseRoleId := acc.TestClient().Ids.RandomDatabaseObjectIdentifier() - function := acc.TestClient().Function.Create(t, sdk.DataTypeFloat) + function := acc.TestClient().Function.CreateSecure(t, sdk.DataTypeFloat) configVariables := config.Variables{ "name": config.StringVariable(databaseRoleId.FullyQualifiedName()), "function_name": config.StringVariable(function.ID().Name()), @@ -680,10 +682,11 @@ func TestAcc_GrantPrivilegesToDatabaseRole_OnSchemaObject_OnFunction(t *testing. } func TestAcc_GrantPrivilegesToDatabaseRole_OnSchemaObject_OnFunctionWithoutArguments(t *testing.T) { + _ = testenvs.GetOrSkipTest(t, testenvs.EnableAcceptance) acc.TestAccPreCheck(t) databaseRoleId := acc.TestClient().Ids.RandomDatabaseObjectIdentifier() - function := acc.TestClient().Function.Create(t) + function := acc.TestClient().Function.CreateSecure(t) configVariables := config.Variables{ "name": config.StringVariable(databaseRoleId.FullyQualifiedName()), "function_name": config.StringVariable(function.ID().Name()), diff --git a/pkg/resources/grant_privileges_to_share_acceptance_test.go b/pkg/resources/grant_privileges_to_share_acceptance_test.go index a0e494a65f..3920ad3983 100644 --- a/pkg/resources/grant_privileges_to_share_acceptance_test.go +++ b/pkg/resources/grant_privileges_to_share_acceptance_test.go @@ -6,6 +6,7 @@ import ( "testing" acc "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/acceptance" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/acceptance/testenvs" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" "github.com/hashicorp/terraform-plugin-testing/config" @@ -337,12 +338,13 @@ func TestAcc_GrantPrivilegesToShare_OnTag(t *testing.T) { }) } -func TestAcc_GrantPrivilegesToShare_OnSchemaObject_OnFunction(t *testing.T) { +func TestAcc_GrantPrivilegesToShare_OnSchemaObject_OnFunctionWithArguments(t *testing.T) { + _ = testenvs.GetOrSkipTest(t, testenvs.EnableAcceptance) acc.TestAccPreCheck(t) share, shareCleanup := acc.TestClient().Share.CreateShare(t) t.Cleanup(shareCleanup) - function := acc.TestClient().Function.Create(t, sdk.DataTypeFloat) + function := acc.TestClient().Function.CreateSecure(t, sdk.DataTypeFloat) configVariables := config.Variables{ "name": config.StringVariable(share.ID().Name()), "function_name": config.StringVariable(function.ID().Name()), @@ -385,11 +387,12 @@ func TestAcc_GrantPrivilegesToShare_OnSchemaObject_OnFunction(t *testing.T) { } func TestAcc_GrantPrivilegesToShare_OnSchemaObject_OnFunctionWithoutArguments(t *testing.T) { + _ = testenvs.GetOrSkipTest(t, testenvs.EnableAcceptance) acc.TestAccPreCheck(t) share, shareCleanup := acc.TestClient().Share.CreateShare(t) t.Cleanup(shareCleanup) - function := acc.TestClient().Function.Create(t) + function := acc.TestClient().Function.CreateSecure(t) configVariables := config.Variables{ "name": config.StringVariable(share.ID().Name()), "function_name": config.StringVariable(function.ID().Name()), @@ -648,3 +651,58 @@ func TestAcc_GrantPrivilegesToShare_RemoveShareOutsideTerraform(t *testing.T) { }, }) } + +func TestAcc_GrantPrivilegesToShareWithNameContainingDots_OnTable(t *testing.T) { + databaseId := acc.TestClient().Ids.RandomAccountObjectIdentifier() + schemaId := acc.TestClient().Ids.RandomDatabaseObjectIdentifierInDatabase(databaseId) + tableId := acc.TestClient().Ids.RandomSchemaObjectIdentifierInSchema(schemaId) + shareId := acc.TestClient().Ids.RandomAccountObjectIdentifierContaining(".foo.bar") + + configVariables := func(withGrant bool) config.Variables { + variables := config.Variables{ + "to_share": config.StringVariable(shareId.Name()), + "database": config.StringVariable(databaseId.Name()), + "schema": config.StringVariable(schemaId.Name()), + "on_table": config.StringVariable(tableId.Name()), + } + if withGrant { + variables["privileges"] = config.ListVariable( + config.StringVariable(sdk.ObjectPrivilegeSelect.String()), + ) + } + return variables + } + resourceName := "snowflake_grant_privileges_to_share.test" + + resource.Test(t, resource.TestCase{ + ProtoV6ProviderFactories: acc.TestAccProtoV6ProviderFactories, + PreCheck: func() { acc.TestAccPreCheck(t) }, + TerraformVersionChecks: []tfversion.TerraformVersionCheck{ + tfversion.RequireAbove(tfversion.Version1_5_0), + }, + Steps: []resource.TestStep{ + { + ConfigDirectory: acc.ConfigurationDirectory("TestAcc_GrantPrivilegesToShare/OnTable"), + ConfigVariables: configVariables(true), + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr(resourceName, "to_share", shareId.Name()), + resource.TestCheckResourceAttr(resourceName, "privileges.#", "1"), + resource.TestCheckResourceAttr(resourceName, "privileges.0", sdk.ObjectPrivilegeSelect.String()), + resource.TestCheckResourceAttr(resourceName, "on_table", tableId.FullyQualifiedName()), + ), + }, + { + ConfigDirectory: acc.ConfigurationDirectory("TestAcc_GrantPrivilegesToShare/OnTable"), + ConfigVariables: configVariables(true), + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + }, + { + ConfigDirectory: acc.ConfigurationDirectory("TestAcc_GrantPrivilegesToShare/OnTable_NoGrant"), + ConfigVariables: configVariables(false), + Check: acc.CheckSharePrivilegesRevoked(t), + }, + }, + }) +} diff --git a/pkg/sdk/testint/grants_integration_test.go b/pkg/sdk/testint/grants_integration_test.go index 33d344b226..bde1ef78d3 100644 --- a/pkg/sdk/testint/grants_integration_test.go +++ b/pkg/sdk/testint/grants_integration_test.go @@ -842,42 +842,42 @@ func TestInt_GrantAndRevokePrivilegesToDatabaseRole(t *testing.T) { func TestInt_GrantPrivilegeToShare(t *testing.T) { client := testClient(t) ctx := testContext(t) - shareTest, shareCleanup := testClientHelper().Share.CreateShareWithIdentifier(t, testClientHelper().Ids.RandomAccountObjectIdentifierContaining(".foo.bar")) + shareTest, shareCleanup := testClientHelper().Share.CreateShare(t) t.Cleanup(shareCleanup) - assertGrant := func(t *testing.T, grants []sdk.Grant, onId sdk.ObjectIdentifier, privilege sdk.ObjectPrivilege, grantedOn sdk.ObjectType, granteeName sdk.ObjectIdentifier) { + assertGrant := func(t *testing.T, grants []sdk.Grant, onId sdk.ObjectIdentifier, privilege sdk.ObjectPrivilege, grantedOn sdk.ObjectType, granteeName sdk.ObjectIdentifier, shareName string) { t.Helper() - var shareGrant *sdk.Grant - for i, grant := range grants { - if grant.GranteeName.Name() == shareTest.ID().Name() && grant.Privilege == string(privilege) { - shareGrant = &grants[i] - break - } - } - assert.NotNil(t, shareGrant) - assert.Equal(t, grantedOn, shareGrant.GrantedOn) - assert.Equal(t, sdk.ObjectTypeShare, shareGrant.GrantedTo) - assert.Equal(t, granteeName.FullyQualifiedName(), shareGrant.GranteeName.FullyQualifiedName()) - assert.Equal(t, onId.FullyQualifiedName(), shareGrant.Name.FullyQualifiedName()) + actualGrant, err := collections.FindOne(grants, func(grant sdk.Grant) bool { + return grant.GranteeName.Name() == shareName && grant.Privilege == string(privilege) + }) + require.NoError(t, err) + assert.Equal(t, grantedOn, actualGrant.GrantedOn) + assert.Equal(t, sdk.ObjectTypeShare, actualGrant.GrantedTo) + assert.Equal(t, granteeName.FullyQualifiedName(), actualGrant.GranteeName.FullyQualifiedName()) + assert.Equal(t, onId.FullyQualifiedName(), actualGrant.Name.FullyQualifiedName()) } - t.Run("with options", func(t *testing.T) { + grantShareOnDatabase := func(t *testing.T, share *sdk.Share) { + t.Helper() err := client.Grants.GrantPrivilegeToShare(ctx, []sdk.ObjectPrivilege{sdk.ObjectPrivilegeUsage}, &sdk.ShareGrantOn{ Database: testDb(t).ID(), - }, shareTest.ID()) + }, share.ID()) require.NoError(t, err) t.Cleanup(func() { err := client.Grants.RevokePrivilegeFromShare(ctx, []sdk.ObjectPrivilege{sdk.ObjectPrivilegeUsage}, &sdk.ShareGrantOn{ Database: testDb(t).ID(), - }, shareTest.ID()) + }, share.ID()) assert.NoError(t, err) }) + } + t.Run("with options", func(t *testing.T) { + grantShareOnDatabase(t, shareTest) table, tableCleanup := testClientHelper().Table.CreateTable(t) t.Cleanup(tableCleanup) - err = client.Grants.GrantPrivilegeToShare(ctx, []sdk.ObjectPrivilege{sdk.ObjectPrivilegeSelect}, &sdk.ShareGrantOn{ + err := client.Grants.GrantPrivilegeToShare(ctx, []sdk.ObjectPrivilege{sdk.ObjectPrivilegeSelect}, &sdk.ShareGrantOn{ Table: &sdk.OnTable{ AllInSchema: testSchema(t).ID(), }, @@ -893,7 +893,7 @@ func TestInt_GrantPrivilegeToShare(t *testing.T) { }, }) require.NoError(t, err) - assertGrant(t, grants, table.ID(), sdk.ObjectPrivilegeSelect, sdk.ObjectTypeTable, shareTest.ID()) + assertGrant(t, grants, table.ID(), sdk.ObjectPrivilegeSelect, sdk.ObjectTypeTable, shareTest.ID(), shareTest.ID().Name()) _, err = client.Grants.Show(ctx, &sdk.ShowGrantOptions{ To: &sdk.ShowGrantsTo{ @@ -920,7 +920,7 @@ func TestInt_GrantPrivilegeToShare(t *testing.T) { }, }) require.NoError(t, err) - assertGrant(t, grants, function.ID(), sdk.ObjectPrivilegeUsage, sdk.ObjectTypeFunction, shareTest.ID()) + assertGrant(t, grants, function.ID(), sdk.ObjectPrivilegeUsage, sdk.ObjectTypeFunction, shareTest.ID(), shareTest.ID().Name()) _, err = client.Grants.Show(ctx, &sdk.ShowGrantOptions{ To: &sdk.ShowGrantsTo{ @@ -953,6 +953,40 @@ func TestInt_GrantPrivilegeToShare(t *testing.T) { }, shareTest.ID()) require.NoError(t, err) }) + t.Run("with a name containing dots", func(t *testing.T) { + shareTest, shareCleanup := testClientHelper().Share.CreateShareWithIdentifier(t, testClientHelper().Ids.RandomAccountObjectIdentifierContaining(".foo.bar")) + t.Cleanup(shareCleanup) + grantShareOnDatabase(t, shareTest) + table, tableCleanup := testClientHelper().Table.CreateTable(t) + t.Cleanup(tableCleanup) + + err := client.Grants.GrantPrivilegeToShare(ctx, []sdk.ObjectPrivilege{sdk.ObjectPrivilegeSelect}, &sdk.ShareGrantOn{ + Table: &sdk.OnTable{ + AllInSchema: testSchema(t).ID(), + }, + }, shareTest.ID()) + require.NoError(t, err) + + grants, err := client.Grants.Show(ctx, &sdk.ShowGrantOptions{ + On: &sdk.ShowGrantsOn{ + Object: &sdk.Object{ + ObjectType: sdk.ObjectTypeTable, + Name: table.ID(), + }, + }, + }) + require.NoError(t, err) + assertGrant(t, grants, table.ID(), sdk.ObjectPrivilegeSelect, sdk.ObjectTypeTable, shareTest.ID(), shareTest.ID().Name()) + + _, err = client.Grants.Show(ctx, &sdk.ShowGrantOptions{ + To: &sdk.ShowGrantsTo{ + Share: &sdk.ShowGrantsToShare{ + Name: shareTest.ID(), + }, + }, + }) + require.NoError(t, err) + }) } func TestInt_RevokePrivilegeToShare(t *testing.T) { From c5ad5b28c29a91822ee1d7739d32b7318596bf47 Mon Sep 17 00:00:00 2001 From: Jakub Michalak Date: Mon, 19 Aug 2024 16:54:59 +0200 Subject: [PATCH 17/19] Fix tests --- pkg/sdk/grants.go | 12 ++++++++---- pkg/sdk/testint/grants_integration_test.go | 2 +- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/pkg/sdk/grants.go b/pkg/sdk/grants.go index e4128a4d27..d584429fa7 100644 --- a/pkg/sdk/grants.go +++ b/pkg/sdk/grants.go @@ -231,11 +231,15 @@ func (row grantRow) convert() *Grant { grantTo := ObjectType(strings.ReplaceAll(row.GrantTo, "_", " ")) var granteeName AccountObjectIdentifier if grantedTo == ObjectTypeShare { - index := strings.IndexRune(row.GranteeName, '.') - if index == -1 { + // TODO(SNOW-1058419): Change this logic during identifiers rework + parts := strings.Split(row.GranteeName, ".") + switch { + case len(parts) == 1: + granteeName = NewAccountObjectIdentifier(parts[0]) + case len(parts) == 2: + granteeName = NewAccountObjectIdentifier(parts[1]) + default: log.Printf("unsupported case for share's grantee name: %s", row.GranteeName) - } else { - granteeName = NewAccountObjectIdentifier(row.GranteeName[index+1:]) } } else { granteeName = NewAccountObjectIdentifier(row.GranteeName) diff --git a/pkg/sdk/testint/grants_integration_test.go b/pkg/sdk/testint/grants_integration_test.go index bde1ef78d3..1bdba32fa8 100644 --- a/pkg/sdk/testint/grants_integration_test.go +++ b/pkg/sdk/testint/grants_integration_test.go @@ -904,7 +904,7 @@ func TestInt_GrantPrivilegeToShare(t *testing.T) { }) require.NoError(t, err) - function := testClientHelper().Function.Create(t) + function := testClientHelper().Function.CreateSecure(t) err = client.Grants.GrantPrivilegeToShare(ctx, []sdk.ObjectPrivilege{sdk.ObjectPrivilegeUsage}, &sdk.ShareGrantOn{ Function: function.ID(), From 148b73ace18e07d6cd68766e1c3777a8eff36444 Mon Sep 17 00:00:00 2001 From: Jakub Michalak Date: Tue, 20 Aug 2024 11:24:11 +0200 Subject: [PATCH 18/19] Review suggestions --- pkg/resources/grant_privileges_to_account_role.go | 1 + .../grant_privileges_to_account_role_identifier.go | 1 + pkg/resources/grant_privileges_to_database_role.go | 1 + .../grant_privileges_to_database_role_identifier.go | 1 + pkg/sdk/grants.go | 1 + pkg/sdk/identifier_parsers.go | 10 +++++++--- pkg/sdk/identifier_parsers_test.go | 1 + 7 files changed, 13 insertions(+), 3 deletions(-) diff --git a/pkg/resources/grant_privileges_to_account_role.go b/pkg/resources/grant_privileges_to_account_role.go index ae00859732..d08207c9fb 100644 --- a/pkg/resources/grant_privileges_to_account_role.go +++ b/pkg/resources/grant_privileges_to_account_role.go @@ -1076,6 +1076,7 @@ func getAccountRoleGrantOn(d *schema.ResourceData) (*sdk.AccountRoleGrantOn, err case objectTypeOk && objectNameOk: objectType := sdk.ObjectType(objectType) var id sdk.ObjectIdentifier + // TODO(SNOW-1569535): use a mapper from object type to parsing function if objectType.IsWithArguments() { var err error id, err = sdk.ParseSchemaObjectIdentifierWithArguments(objectName) diff --git a/pkg/resources/grant_privileges_to_account_role_identifier.go b/pkg/resources/grant_privileges_to_account_role_identifier.go index 6c16423c33..4f6f45ce4b 100644 --- a/pkg/resources/grant_privileges_to_account_role_identifier.go +++ b/pkg/resources/grant_privileges_to_account_role_identifier.go @@ -139,6 +139,7 @@ func ParseGrantPrivilegesToAccountRoleId(id string) (GrantPrivilegesToAccountRol } objectType := sdk.ObjectType(parts[6]) var id sdk.ObjectIdentifier + // TODO(SNOW-1569535): use a mapper from object type to parsing function if objectType.IsWithArguments() { var err error id, err = sdk.ParseSchemaObjectIdentifierWithArguments(parts[7]) diff --git a/pkg/resources/grant_privileges_to_database_role.go b/pkg/resources/grant_privileges_to_database_role.go index 34052cf66d..411f127e3f 100644 --- a/pkg/resources/grant_privileges_to_database_role.go +++ b/pkg/resources/grant_privileges_to_database_role.go @@ -905,6 +905,7 @@ func getDatabaseRoleGrantOn(d *schema.ResourceData) (*sdk.DatabaseRoleGrantOn, e case objectTypeOk && objectNameOk: objectType := sdk.ObjectType(objectType) var id sdk.ObjectIdentifier + // TODO(SNOW-1569535): use a mapper from object type to parsing function if objectType.IsWithArguments() { var err error id, err = sdk.ParseSchemaObjectIdentifierWithArguments(objectName) diff --git a/pkg/resources/grant_privileges_to_database_role_identifier.go b/pkg/resources/grant_privileges_to_database_role_identifier.go index 1b7db1c86a..45ad11a432 100644 --- a/pkg/resources/grant_privileges_to_database_role_identifier.go +++ b/pkg/resources/grant_privileges_to_database_role_identifier.go @@ -129,6 +129,7 @@ func ParseGrantPrivilegesToDatabaseRoleId(id string) (GrantPrivilegesToDatabaseR } objectType := sdk.ObjectType(parts[6]) var id sdk.ObjectIdentifier + // TODO(SNOW-1569535): use a mapper from object type to parsing function if objectType.IsWithArguments() { var err error id, err = sdk.ParseSchemaObjectIdentifierWithArguments(parts[7]) diff --git a/pkg/sdk/grants.go b/pkg/sdk/grants.go index d584429fa7..0ed2c7b0ac 100644 --- a/pkg/sdk/grants.go +++ b/pkg/sdk/grants.go @@ -265,6 +265,7 @@ func (row grantRow) convert() *Grant { var name ObjectIdentifier var err error + // TODO(SNOW-1569535): use a mapper from object type to parsing function if ObjectType(row.GrantedOn).IsWithArguments() { name, err = ParseSchemaObjectIdentifierWithArgumentsAndReturnType(row.Name) } else { diff --git a/pkg/sdk/identifier_parsers.go b/pkg/sdk/identifier_parsers.go index a6852f4572..3b1dc6fc28 100644 --- a/pkg/sdk/identifier_parsers.go +++ b/pkg/sdk/identifier_parsers.go @@ -222,7 +222,12 @@ func ParseFunctionArgumentsFromString(arguments string) ([]DataType, error) { // We use another buffer to peek into next data type (needed for vector parsing) peekDataType, _ := bytes.NewBufferString(stringBuffer.String()).ReadString(',') - + // Skip argument name, if present + if strings.ContainsRune(peekDataType, ' ') && !strings.HasPrefix(peekDataType, "VECTOR(") { + // Ignore err, because stringBuffer always contains ' ' here + _, _ = stringBuffer.ReadString(' ') + peekDataType, _ = bytes.NewBufferString(stringBuffer.String()).ReadString(',') + } switch { // For now, only vectors need special parsing behavior case strings.HasPrefix(peekDataType, "VECTOR"): @@ -273,8 +278,7 @@ func ParseFunctionArgumentsFromString(arguments string) ([]DataType, error) { if err == nil { argument = argument[:len(argument)-1] } - dataType := argument[strings.IndexRune(argument, ' ')+1:] - dataTypes = append(dataTypes, DataType(dataType)) + dataTypes = append(dataTypes, DataType(argument)) } } diff --git a/pkg/sdk/identifier_parsers_test.go b/pkg/sdk/identifier_parsers_test.go index ab0887dad1..58aeabb2f3 100644 --- a/pkg/sdk/identifier_parsers_test.go +++ b/pkg/sdk/identifier_parsers_test.go @@ -384,6 +384,7 @@ func TestNewSchemaObjectIdentifierWithArgumentsAndReturnTypeFromFullyQualifiedNa {RawInput: `abc.def.ghi():FLOAT`, ExpectedIdentifierStructure: NewSchemaObjectIdentifierWithArguments(`abc`, `def`, `ghi`)}, {RawInput: `abc.def."ghi(FLOAT, VECTOR(INT, 20)):NUMBER(10,2)"`, ExpectedIdentifierStructure: NewSchemaObjectIdentifierWithArguments(`abc`, `def`, `ghi`, DataTypeFloat, "VECTOR(INT, 20)")}, {RawInput: `abc.def."ghi(FLOAT, VECTOR(INT, 20)):NUMBER"`, ExpectedIdentifierStructure: NewSchemaObjectIdentifierWithArguments(`abc`, `def`, `ghi`, DataTypeFloat, "VECTOR(INT, 20)")}, + {RawInput: `abc.def."ghi(ab FLOAT, VECTOR VECTOR(INT, 20), FLOAT):NUMBER"`, ExpectedIdentifierStructure: NewSchemaObjectIdentifierWithArguments(`abc`, `def`, `ghi`, DataTypeFloat, "VECTOR(INT, 20)", DataTypeFloat)}, } for _, testCase := range testCases { From 9d71f59eaedd12fbe9362829a8d9cd6f3d5a14c1 Mon Sep 17 00:00:00 2001 From: Jakub Michalak Date: Tue, 20 Aug 2024 11:49:59 +0200 Subject: [PATCH 19/19] Fix tests --- pkg/sdk/grants.go | 4 +++- pkg/sdk/testint/grants_integration_test.go | 1 + 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/pkg/sdk/grants.go b/pkg/sdk/grants.go index 0ed2c7b0ac..0a6b3cc948 100644 --- a/pkg/sdk/grants.go +++ b/pkg/sdk/grants.go @@ -239,7 +239,9 @@ func (row grantRow) convert() *Grant { case len(parts) == 2: granteeName = NewAccountObjectIdentifier(parts[1]) default: - log.Printf("unsupported case for share's grantee name: %s", row.GranteeName) + fallback := row.GranteeName[strings.IndexRune(row.GranteeName, '.')+1:] + log.Printf("unsupported case for share's grantee name: %s Falling back to account object identifier: %s", row.GranteeName, fallback) + granteeName = NewAccountObjectIdentifier(fallback) } } else { granteeName = NewAccountObjectIdentifier(row.GranteeName) diff --git a/pkg/sdk/testint/grants_integration_test.go b/pkg/sdk/testint/grants_integration_test.go index 1bdba32fa8..d009e32142 100644 --- a/pkg/sdk/testint/grants_integration_test.go +++ b/pkg/sdk/testint/grants_integration_test.go @@ -953,6 +953,7 @@ func TestInt_GrantPrivilegeToShare(t *testing.T) { }, shareTest.ID()) require.NoError(t, err) }) + t.Run("with a name containing dots", func(t *testing.T) { shareTest, shareCleanup := testClientHelper().Share.CreateShareWithIdentifier(t, testClientHelper().Ids.RandomAccountObjectIdentifierContaining(".foo.bar")) t.Cleanup(shareCleanup)