diff --git a/internal/planModifiers/ignoreCaseModifier_test.go b/internal/planModifiers/ignoreCaseModifier_test.go index 9bca6c7..4fddd6d 100644 --- a/internal/planModifiers/ignoreCaseModifier_test.go +++ b/internal/planModifiers/ignoreCaseModifier_test.go @@ -16,38 +16,38 @@ func TestIgnoreCaseModifier(t *testing.T) { }{ "empty state": { request: tfsdk.ModifyAttributePlanRequest{ - AttributeState: types.String{Unknown: true}, - AttributePlan: types.String{Value: "plannedValue"}, + AttributeState: types.StringUnknown(), + AttributePlan: types.StringValue("plannedValue"), }, - expectedValue: types.String{Value: "plannedValue"}, + expectedValue: types.StringValue("plannedValue"), }, "empty plan": { request: tfsdk.ModifyAttributePlanRequest{ - AttributeState: types.String{Value: "stateValue"}, - AttributePlan: types.String{Null: true}, + AttributeState: types.StringValue("stateValue"), + AttributePlan: types.StringNull(), }, - expectedValue: types.String{Null: true}, + expectedValue: types.StringNull(), }, "non string": { request: tfsdk.ModifyAttributePlanRequest{ - AttributeState: types.Int64{Value: 246}, - AttributePlan: types.Int64{Value: 45763}, + AttributeState: types.Int64Value(246), + AttributePlan: types.Int64Value(45763), }, - expectedValue: types.Int64{Value: 45763}, + expectedValue: types.Int64Value(45763), }, "matching case": { request: tfsdk.ModifyAttributePlanRequest{ - AttributeState: types.String{Value: "matchingCase"}, - AttributePlan: types.String{Value: "matchingCase"}, + AttributeState: types.StringValue("matchingCase"), + AttributePlan: types.StringValue("matchingCase"), }, - expectedValue: types.String{Value: "matchingCase"}, + expectedValue: types.StringValue("matchingCase"), }, "not matching case": { request: tfsdk.ModifyAttributePlanRequest{ - AttributeState: types.String{Value: "NotMatchingCase"}, - AttributePlan: types.String{Value: "NOTMATCHINGCASE"}, + AttributeState: types.StringValue("NotMatchingCase"), + AttributePlan: types.StringValue("NOTMATCHINGCASE"), }, - expectedValue: types.String{Value: "NotMatchingCase"}, + expectedValue: types.StringValue("NotMatchingCase"), }, } diff --git a/internal/provider/provider.go b/internal/provider/provider.go index 1e0b2e9..b3d0e75 100644 --- a/internal/provider/provider.go +++ b/internal/provider/provider.go @@ -176,8 +176,8 @@ func (p *mssqlProvider) ValidateConfig(ctx context.Context, request provider.Val utils.StopOnError(ctx). Then(func() { data = utils.GetData[providerData](ctx, request.Config) }). Then(func() { - if data.AzureAuth.IsNull() && data.SqlAuth.IsNull() { + if data.AzureAuth == nil && data.SqlAuth == nil { utils.AddError(ctx, "Missing SQL authentication config", errors.New("One of authentication methods must be provided: sql_auth, azure_auth")) } }) -} \ No newline at end of file +} diff --git a/internal/provider/providerData.go b/internal/provider/providerData.go index 5b60573..66b74b1 100644 --- a/internal/provider/providerData.go +++ b/internal/provider/providerData.go @@ -23,18 +23,18 @@ type azureAuth struct { type providerData struct { Hostname types.String `tfsdk:"hostname"` Port types.Int64 `tfsdk:"port"` - SqlAuth types.Object `tfsdk:"sql_auth"` - AzureAuth types.Object `tfsdk:"azure_auth"` + SqlAuth *sqlAuth `tfsdk:"sql_auth"` + AzureAuth *azureAuth `tfsdk:"azure_auth"` } -func (pd providerData) asConnectionDetails(ctx context.Context) (sql.ConnectionDetails, diag.Diagnostics) { +func (pd providerData) asConnectionDetails(context.Context) (sql.ConnectionDetails, diag.Diagnostics) { diags := diag.Diagnostics{} var addComputedError = func(summary string) { diags.AddError(summary, "SQL connection details must be known during plan execution") } - if pd.Hostname.Unknown { + if pd.Hostname.IsUnknown() { addComputedError("Hostname cannot be a computed value") } @@ -42,51 +42,45 @@ func (pd providerData) asConnectionDetails(ctx context.Context) (sql.ConnectionD Host: os.Getenv("MSSQL_HOSTNAME"), } - if !pd.Hostname.Null { - connDetails.Host = pd.Hostname.Value + if !pd.Hostname.IsNull() { + connDetails.Host = pd.Hostname.ValueString() } - if !pd.Port.Null { - connDetails.Host = fmt.Sprintf("%s:%d", connDetails.Host, pd.Port.Value) + if !pd.Port.IsNull() { + connDetails.Host = fmt.Sprintf("%s:%d", connDetails.Host, pd.Port.ValueInt64()) } else if envPort := os.Getenv("MSSQL_PORT"); envPort != "" { connDetails.Host = fmt.Sprintf("%s:%s", connDetails.Host, envPort) } - if !pd.SqlAuth.Null { - var auth sqlAuth - diags.Append(pd.SqlAuth.As(ctx, &auth, types.ObjectAsOptions{})...) - - if auth.Username.Unknown { + if pd.SqlAuth != nil { + if pd.SqlAuth.Username.IsUnknown() { addComputedError("SQL username cannot be a computed value") } - if auth.Password.Unknown { + if pd.SqlAuth.Password.IsUnknown() { addComputedError("SQL password cannot be a computed value") } - connDetails.Auth = sql.ConnectionAuthSql{Username: auth.Username.Value, Password: auth.Password.Value} + connDetails.Auth = sql.ConnectionAuthSql{Username: pd.SqlAuth.Username.ValueString(), Password: pd.SqlAuth.Password.ValueString()} } - if !pd.AzureAuth.Null { - var auth azureAuth - diags.Append(pd.AzureAuth.As(ctx, &auth, types.ObjectAsOptions{})...) - - if auth.ClientId.Unknown { + if pd.AzureAuth != nil { + if pd.AzureAuth.ClientId.IsUnknown() { addComputedError("Azure AD Service Principal client_id cannot be a computed value") } - if auth.ClientSecret.Unknown { + if pd.AzureAuth.ClientSecret.IsUnknown() { addComputedError("Azure AD Service Principal client_secret cannot be a computed value") } - if auth.TenantId.Unknown { + if pd.AzureAuth.TenantId.IsUnknown() { addComputedError("Azure AD Service Principal tenant_id cannot be a computed value") } connAuth := sql.ConnectionAuthAzure{ - ClientId: auth.ClientId.Value, - ClientSecret: auth.ClientSecret.Value, - TenantId: auth.TenantId.Value, + ClientId: pd.AzureAuth.ClientId.ValueString(), + ClientSecret: pd.AzureAuth.ClientSecret.ValueString(), + TenantId: pd.AzureAuth.TenantId.ValueString(), } if connAuth.ClientId == "" { diff --git a/internal/provider/providerData_test.go b/internal/provider/providerData_test.go index af86d82..eee74c4 100644 --- a/internal/provider/providerData_test.go +++ b/internal/provider/providerData_test.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "github.com/PGSSoft/terraform-provider-mssql/internal/sql" - "github.com/hashicorp/terraform-plugin-framework/attr" "github.com/hashicorp/terraform-plugin-framework/diag" "github.com/hashicorp/terraform-plugin-framework/types" "github.com/stretchr/testify/assert" @@ -21,83 +20,62 @@ func TestProviderDataAsConnectionDetails(t *testing.T) { errSummary string }{ "Hostname": { - pd: providerData{Hostname: types.String{Unknown: true}}, + pd: providerData{ + Hostname: types.StringUnknown(), + }, errSummary: "Hostname cannot be a computed value", }, "SQL Username": { - pd: providerData{SqlAuth: types.Object{ - AttrTypes: map[string]attr.Type{ - "username": types.StringType, - "password": types.StringType, - }, - Attrs: map[string]attr.Value{ - "username": types.String{Unknown: true}, - "password": types.String{Null: true}, + pd: providerData{ + SqlAuth: &sqlAuth{ + Username: types.StringUnknown(), + Password: types.StringNull(), }, - }}, + }, errSummary: "SQL username cannot be a computed value", }, "SQL Password": { - pd: providerData{SqlAuth: types.Object{ - AttrTypes: map[string]attr.Type{ - "username": types.StringType, - "password": types.StringType, - }, - Attrs: map[string]attr.Value{ - "username": types.String{Null: true}, - "password": types.String{Unknown: true}, + pd: providerData{ + SqlAuth: &sqlAuth{ + Username: types.StringNull(), + Password: types.StringUnknown(), }, - }}, + }, errSummary: "SQL password cannot be a computed value", }, "Azure auth ClientId": { - pd: providerData{AzureAuth: types.Object{ - AttrTypes: map[string]attr.Type{ - "client_id": types.StringType, - "client_secret": types.StringType, - "tenant_id": types.StringType, - }, - Attrs: map[string]attr.Value{ - "client_id": types.String{Unknown: true}, - "client_secret": types.String{Null: true}, - "tenant_id": types.String{Null: true}, + pd: providerData{ + AzureAuth: &azureAuth{ + ClientId: types.StringUnknown(), + ClientSecret: types.StringNull(), + TenantId: types.StringNull(), }, - }}, + }, errSummary: "Azure AD Service Principal client_id cannot be a computed value", }, "Azure auth ClientSecret": { - pd: providerData{AzureAuth: types.Object{ - AttrTypes: map[string]attr.Type{ - "client_id": types.StringType, - "client_secret": types.StringType, - "tenant_id": types.StringType, - }, - Attrs: map[string]attr.Value{ - "client_id": types.String{Null: true}, - "client_secret": types.String{Unknown: true}, - "tenant_id": types.String{Null: true}, + pd: providerData{ + AzureAuth: &azureAuth{ + ClientId: types.StringNull(), + ClientSecret: types.StringUnknown(), + TenantId: types.StringNull(), }, - }}, + }, errSummary: "Azure AD Service Principal client_secret cannot be a computed value", }, "Azure auth TenantId": { - pd: providerData{AzureAuth: types.Object{ - AttrTypes: map[string]attr.Type{ - "client_id": types.StringType, - "client_secret": types.StringType, - "tenant_id": types.StringType, - }, - Attrs: map[string]attr.Value{ - "client_id": types.String{Null: true}, - "client_secret": types.String{Null: true}, - "tenant_id": types.String{Unknown: true}, + pd: providerData{ + AzureAuth: &azureAuth{ + ClientId: types.StringNull(), + ClientSecret: types.StringNull(), + TenantId: types.StringUnknown(), }, - }}, + }, errSummary: "Azure AD Service Principal tenant_id cannot be a computed value", }, } @@ -124,22 +102,22 @@ func TestProviderDataAsConnectionDetails(t *testing.T) { }{ "With port": { pd: providerData{ - Hostname: types.String{Value: "test_hostname"}, - Port: types.Int64{Value: 123}, + Hostname: types.StringValue("test_hostname"), + Port: types.Int64Value(123), }, host: "test_hostname:123", }, "Without port": { pd: providerData{ - Hostname: types.String{Value: "test_hostname2"}, - Port: types.Int64{Null: true}, + Hostname: types.StringValue("test_hostname2"), + Port: types.Int64Null(), }, host: "test_hostname2", }, "Env variable hostname": { pd: providerData{ - Hostname: types.String{Null: true}, - Port: types.Int64{Null: true}, + Hostname: types.StringNull(), + Port: types.Int64Null(), }, env: map[string]string{ "MSSQL_HOSTNAME": "env_test_hostname", @@ -148,8 +126,8 @@ func TestProviderDataAsConnectionDetails(t *testing.T) { }, "Env variable hostname and port": { pd: providerData{ - Hostname: types.String{Null: true}, - Port: types.Int64{Null: true}, + Hostname: types.StringNull(), + Port: types.Int64Null(), }, env: map[string]string{ "MSSQL_HOSTNAME": "env_test_hostname2", @@ -159,8 +137,8 @@ func TestProviderDataAsConnectionDetails(t *testing.T) { }, "Env variables and attributes": { pd: providerData{ - Hostname: types.String{Value: "test_hostname"}, - Port: types.Int64{Value: 123}, + Hostname: types.StringValue("test_hostname"), + Port: types.Int64Value(123), }, env: map[string]string{ "MSSQL_HOSTNAME": "env_test_hostname2", @@ -190,17 +168,10 @@ func TestProviderDataAsConnectionDetails(t *testing.T) { t.Run("SQL auth", func(t *testing.T) { pd := providerData{ - SqlAuth: types.Object{ - AttrTypes: map[string]attr.Type{ - "username": types.StringType, - "password": types.StringType, - }, - Attrs: map[string]attr.Value{ - "username": types.String{Value: "test_username"}, - "password": types.String{Value: "test_password"}, - }, + SqlAuth: &sqlAuth{ + Username: types.StringValue("test_username"), + Password: types.StringValue("test_password"), }, - AzureAuth: types.Object{Null: true}, } cd, _ := pd.asConnectionDetails(ctx) @@ -213,18 +184,10 @@ func TestProviderDataAsConnectionDetails(t *testing.T) { t.Run("Azure auth", func(t *testing.T) { pd := providerData{ - SqlAuth: types.Object{Null: true}, - AzureAuth: types.Object{ - AttrTypes: map[string]attr.Type{ - "client_id": types.StringType, - "client_secret": types.StringType, - "tenant_id": types.StringType, - }, - Attrs: map[string]attr.Value{ - "client_id": types.String{Value: "test_client_id"}, - "client_secret": types.String{Value: "test_client_secret"}, - "tenant_id": types.String{Value: "test_tenant_id"}, - }, + AzureAuth: &azureAuth{ + ClientId: types.StringValue("test_client_id"), + ClientSecret: types.StringValue("test_client_secret"), + TenantId: types.StringValue("test_tenant_id"), }, } @@ -239,8 +202,7 @@ func TestProviderDataAsConnectionDetails(t *testing.T) { t.Run("Azure auth env variables", func(t *testing.T) { pd := providerData{ - SqlAuth: types.Object{}, - AzureAuth: types.Object{}, + AzureAuth: &azureAuth{}, } os.Setenv("ARM_CLIENT_ID", "env_test_client_id") os.Setenv("ARM_CLIENT_SECRET", "env_test_client_secret") @@ -257,18 +219,10 @@ func TestProviderDataAsConnectionDetails(t *testing.T) { t.Run("Azure auth and env variables", func(t *testing.T) { pd := providerData{ - SqlAuth: types.Object{}, - AzureAuth: types.Object{ - AttrTypes: map[string]attr.Type{ - "client_id": types.StringType, - "client_secret": types.StringType, - "tenant_id": types.StringType, - }, - Attrs: map[string]attr.Value{ - "client_id": types.String{Value: "test_client_id"}, - "client_secret": types.String{Value: "test_client_secret"}, - "tenant_id": types.String{Value: "test_tenant_id"}, - }, + AzureAuth: &azureAuth{ + ClientId: types.StringValue("test_client_id"), + ClientSecret: types.StringValue("test_client_secret"), + TenantId: types.StringValue("test_tenant_id"), }, } os.Setenv("ARM_CLIENT_ID", "env_test_client_id") diff --git a/internal/services/azureADServicePrincipal/base.go b/internal/services/azureADServicePrincipal/base.go index dc6eb4b..32c847f 100644 --- a/internal/services/azureADServicePrincipal/base.go +++ b/internal/services/azureADServicePrincipal/base.go @@ -39,8 +39,8 @@ type resourceData struct { func (d resourceData) toSettings() sql.UserSettings { return sql.UserSettings{ - Name: d.Name.Value, - AADObjectId: sql.AADObjectId(d.ClientId.Value), + Name: d.Name.ValueString(), + AADObjectId: sql.AADObjectId(d.ClientId.ValueString()), Type: sql.USER_TYPE_AZUREAD, } } @@ -51,7 +51,7 @@ func (d resourceData) withSettings(ctx context.Context, settings sql.UserSetting return d } - d.Name = types.String{Value: settings.Name} - d.ClientId = types.String{Value: strings.ToUpper(fmt.Sprint(settings.AADObjectId))} + d.Name = types.StringValue(settings.Name) + d.ClientId = types.StringValue(strings.ToUpper(fmt.Sprint(settings.AADObjectId))) return d } diff --git a/internal/services/azureADServicePrincipal/data.go b/internal/services/azureADServicePrincipal/data.go index 233ce12..280035f 100644 --- a/internal/services/azureADServicePrincipal/data.go +++ b/internal/services/azureADServicePrincipal/data.go @@ -40,26 +40,26 @@ func (d *dataSource) Read(ctx context.Context, request datasource.ReadRequest[re ) request. - Then(func() { db = common2.GetResourceDb(ctx, request.Conn, request.Config.DatabaseId.Value) }). + Then(func() { db = common2.GetResourceDb(ctx, request.Conn, request.Config.DatabaseId.ValueString()) }). Then(func() { if !request.Config.Name.IsNull() && !request.Config.Name.IsUnknown() { - user = sql.GetUserByName(ctx, db, request.Config.Name.Value) + user = sql.GetUserByName(ctx, db, request.Config.Name.ValueString()) return } for _, u := range sql.GetUsers(ctx, db) { settings := u.GetSettings(ctx) - if settings.Type == sql.USER_TYPE_AZUREAD && strings.EqualFold(fmt.Sprint(settings.AADObjectId), request.Config.ClientId.Value) { + if settings.Type == sql.USER_TYPE_AZUREAD && strings.EqualFold(fmt.Sprint(settings.AADObjectId), request.Config.ClientId.ValueString()) { user = u return } } - utils.AddError(ctx, "User does not exist", fmt.Errorf("could not find user with name=%q and client_id=%q", request.Config.Name.Value, request.Config.ClientId.Value)) + utils.AddError(ctx, "User does not exist", fmt.Errorf("could not find user with name=%q and client_id=%q", request.Config.Name.ValueString(), request.Config.ClientId.ValueString())) }). Then(func() { state := request.Config - state.Id = types.String{Value: common2.DbObjectId[sql.UserId]{DbId: db.GetId(ctx), ObjectId: user.GetId(ctx)}.String()} + state.Id = types.StringValue(common2.DbObjectId[sql.UserId]{DbId: db.GetId(ctx), ObjectId: user.GetId(ctx)}.String()) state = state.withSettings(ctx, user.GetSettings(ctx)) response.SetState(state) }) diff --git a/internal/services/azureADServicePrincipal/resource.go b/internal/services/azureADServicePrincipal/resource.go index 5b7b99c..ca86813 100644 --- a/internal/services/azureADServicePrincipal/resource.go +++ b/internal/services/azureADServicePrincipal/resource.go @@ -51,10 +51,10 @@ func (r *res) Create(ctx context.Context, req resource.CreateRequest[resourceDat ) req. - Then(func() { db = common2.GetResourceDb(ctx, req.Conn, req.Plan.DatabaseId.Value) }). + Then(func() { db = common2.GetResourceDb(ctx, req.Conn, req.Plan.DatabaseId.ValueString()) }). Then(func() { user = sql.CreateUser(ctx, db, req.Plan.toSettings()) }). Then(func() { - req.Plan.Id = types.String{Value: common2.DbObjectId[sql.UserId]{DbId: db.GetId(ctx), ObjectId: user.GetId(ctx)}.String()} + req.Plan.Id = types.StringValue(common2.DbObjectId[sql.UserId]{DbId: db.GetId(ctx), ObjectId: user.GetId(ctx)}.String()) }). Then(func() { resp.State = req.Plan }) } @@ -67,12 +67,12 @@ func (r *res) Read(ctx context.Context, req resource.ReadRequest[resourceData], ) req. - Then(func() { id = common2.ParseDbObjectId[sql.UserId](ctx, req.State.Id.Value) }). + Then(func() { id = common2.ParseDbObjectId[sql.UserId](ctx, req.State.Id.ValueString()) }). Then(func() { db = sql.GetDatabase(ctx, req.Conn, id.DbId) }). Then(func() { user = sql.GetUser(ctx, db, id.ObjectId) }). Then(func() { state := req.State.withSettings(ctx, user.GetSettings(ctx)) - state.DatabaseId = types.String{Value: fmt.Sprint(id.DbId)} + state.DatabaseId = types.StringValue(fmt.Sprint(id.DbId)) resp.SetState(state) }) } @@ -90,8 +90,8 @@ func (r *res) Delete(ctx context.Context, req resource.DeleteRequest[resourceDat req. Then(func() { - db = common2.GetResourceDb(ctx, req.Conn, req.State.DatabaseId.Value) - id = common2.ParseDbObjectId[sql.UserId](ctx, req.State.Id.Value) + db = common2.GetResourceDb(ctx, req.Conn, req.State.DatabaseId.ValueString()) + id = common2.ParseDbObjectId[sql.UserId](ctx, req.State.Id.ValueString()) }). Then(func() { user = sql.GetUser(ctx, db, id.ObjectId) }). Then(func() { user.Drop(ctx) }) diff --git a/internal/services/azureADUser/base.go b/internal/services/azureADUser/base.go index 8888596..9e9297f 100644 --- a/internal/services/azureADUser/base.go +++ b/internal/services/azureADUser/base.go @@ -38,8 +38,8 @@ type resourceData struct { func (d resourceData) toSettings() sql.UserSettings { return sql.UserSettings{ - Name: d.Name.Value, - AADObjectId: sql.AADObjectId(d.UserObjectId.Value), + Name: d.Name.ValueString(), + AADObjectId: sql.AADObjectId(d.UserObjectId.ValueString()), Type: sql.USER_TYPE_AZUREAD, } } @@ -50,7 +50,7 @@ func (d resourceData) withSettings(ctx context.Context, settings sql.UserSetting return d } - d.Name = types.String{Value: settings.Name} - d.UserObjectId = types.String{Value: strings.ToUpper(fmt.Sprint(settings.AADObjectId))} + d.Name = types.StringValue(settings.Name) + d.UserObjectId = types.StringValue(strings.ToUpper(fmt.Sprint(settings.AADObjectId))) return d } diff --git a/internal/services/azureADUser/data.go b/internal/services/azureADUser/data.go index e38eb1f..718636b 100644 --- a/internal/services/azureADUser/data.go +++ b/internal/services/azureADUser/data.go @@ -41,26 +41,26 @@ func (d *dataSource) Read(ctx context.Context, req datasource.ReadRequest[resour ) req. - Then(func() { db = common2.GetResourceDb(ctx, req.Conn, req.Config.DatabaseId.Value) }). + Then(func() { db = common2.GetResourceDb(ctx, req.Conn, req.Config.DatabaseId.ValueString()) }). Then(func() { if !req.Config.Name.IsNull() && !req.Config.Name.IsUnknown() { - user = sql.GetUserByName(ctx, db, req.Config.Name.Value) + user = sql.GetUserByName(ctx, db, req.Config.Name.ValueString()) return } for _, u := range sql.GetUsers(ctx, db) { settings := u.GetSettings(ctx) - if settings.Type == sql.USER_TYPE_AZUREAD && strings.ToUpper(fmt.Sprint(settings.AADObjectId)) == strings.ToUpper(req.Config.UserObjectId.Value) { + if settings.Type == sql.USER_TYPE_AZUREAD && strings.ToUpper(fmt.Sprint(settings.AADObjectId)) == strings.ToUpper(req.Config.UserObjectId.ValueString()) { user = u return } } - utils.AddError(ctx, "User does not exist", fmt.Errorf("could not find user with name=%q and object_id=%q", req.Config.Name.Value, req.Config.UserObjectId.Value)) + utils.AddError(ctx, "User does not exist", fmt.Errorf("could not find user with name=%q and object_id=%q", req.Config.Name.ValueString(), req.Config.UserObjectId.ValueString())) }). Then(func() { state := req.Config.withSettings(ctx, user.GetSettings(ctx)) - state.Id = types.String{Value: common2.DbObjectId[sql.UserId]{DbId: db.GetId(ctx), ObjectId: user.GetId(ctx)}.String()} + state.Id = types.StringValue(common2.DbObjectId[sql.UserId]{DbId: db.GetId(ctx), ObjectId: user.GetId(ctx)}.String()) resp.SetState(state) }) } diff --git a/internal/services/azureADUser/resource.go b/internal/services/azureADUser/resource.go index a8905ef..2cb487f 100644 --- a/internal/services/azureADUser/resource.go +++ b/internal/services/azureADUser/resource.go @@ -51,10 +51,10 @@ func (r *res) Create(ctx context.Context, req resource.CreateRequest[resourceDat ) req. - Then(func() { db = common2.GetResourceDb(ctx, req.Conn, req.Plan.DatabaseId.Value) }). + Then(func() { db = common2.GetResourceDb(ctx, req.Conn, req.Plan.DatabaseId.ValueString()) }). Then(func() { user = sql.CreateUser(ctx, db, req.Plan.toSettings()) }). Then(func() { - req.Plan.Id = types.String{Value: common2.DbObjectId[sql.UserId]{DbId: db.GetId(ctx), ObjectId: user.GetId(ctx)}.String()} + req.Plan.Id = types.StringValue(common2.DbObjectId[sql.UserId]{DbId: db.GetId(ctx), ObjectId: user.GetId(ctx)}.String()) }). Then(func() { resp.State = req.Plan }) } @@ -67,12 +67,12 @@ func (r *res) Read(ctx context.Context, req resource.ReadRequest[resourceData], ) req. - Then(func() { id = common2.ParseDbObjectId[sql.UserId](ctx, req.State.Id.Value) }). + Then(func() { id = common2.ParseDbObjectId[sql.UserId](ctx, req.State.Id.ValueString()) }). Then(func() { db = sql.GetDatabase(ctx, req.Conn, id.DbId) }). Then(func() { user = sql.GetUser(ctx, db, id.ObjectId) }). Then(func() { state := req.State.withSettings(ctx, user.GetSettings(ctx)) - state.DatabaseId = types.String{Value: fmt.Sprint(id.DbId)} + state.DatabaseId = types.StringValue(fmt.Sprint(id.DbId)) resp.SetState(state) }) } @@ -90,8 +90,8 @@ func (r *res) Delete(ctx context.Context, req resource.DeleteRequest[resourceDat req. Then(func() { - db = common2.GetResourceDb(ctx, req.Conn, req.State.DatabaseId.Value) - id = common2.ParseDbObjectId[sql.UserId](ctx, req.State.Id.Value) + db = common2.GetResourceDb(ctx, req.Conn, req.State.DatabaseId.ValueString()) + id = common2.ParseDbObjectId[sql.UserId](ctx, req.State.Id.ValueString()) }). Then(func() { user = sql.GetUser(ctx, db, id.ObjectId) }). Then(func() { user.Drop(ctx) }) diff --git a/internal/services/database/base.go b/internal/services/database/base.go index 32d9e00..57d1868 100644 --- a/internal/services/database/base.go +++ b/internal/services/database/base.go @@ -35,14 +35,14 @@ type resourceData struct { } func (d resourceData) getDbId(ctx context.Context) sql.DatabaseId { - if d.Id.Unknown || d.Id.Null { + if !common.IsAttrSet(d.Id) { return sql.NullDatabaseId } - id, err := strconv.Atoi(d.Id.Value) + id, err := strconv.Atoi(d.Id.ValueString()) if err != nil { - utils.AddError(ctx, fmt.Sprintf("Failed to convert resource ID '%s'", d.Id.Value), err) + utils.AddError(ctx, fmt.Sprintf("Failed to convert resource ID '%s'", d.Id.ValueString()), err) } return sql.DatabaseId(id) @@ -50,19 +50,21 @@ func (d resourceData) getDbId(ctx context.Context) sql.DatabaseId { func (d resourceData) toSettings() sql.DatabaseSettings { return sql.DatabaseSettings{ - Name: d.Name.Value, - Collation: d.Collation.Value, + Name: d.Name.ValueString(), + Collation: d.Collation.ValueString(), } } func (d resourceData) withSettings(settings sql.DatabaseSettings) resourceData { - return resourceData{ - Id: d.Id, - Name: types.String{Value: settings.Name}, + resData := resourceData{ + Id: d.Id, + Name: types.StringValue(settings.Name), + Collation: types.StringValue(settings.Collation), + } - Collation: types.String{ - Value: settings.Collation, - Null: settings.Collation == "", - }, + if settings.Collation == "" { + resData.Collation = types.StringNull() } + + return resData } diff --git a/internal/services/database/data.go b/internal/services/database/data.go index fc790fa..3a2d3f2 100644 --- a/internal/services/database/data.go +++ b/internal/services/database/data.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "github.com/PGSSoft/terraform-provider-mssql/internal/core/datasource" + "github.com/PGSSoft/terraform-provider-mssql/internal/services/common" "github.com/PGSSoft/terraform-provider-mssql/internal/utils" "github.com/PGSSoft/terraform-provider-mssql/internal/sql" @@ -36,17 +37,17 @@ func (d *dataSource) Read(ctx context.Context, req datasource.ReadRequest[resour req. Then(func() { - db = sql.GetDatabaseByName(ctx, req.Conn, req.Config.Name.Value) + db = sql.GetDatabaseByName(ctx, req.Conn, req.Config.Name.ValueString()) if db == nil || !db.Exists(ctx) { - utils.AddError(ctx, "DB does not exist", fmt.Errorf("could not find DB '%s'", req.Config.Name.Value)) + utils.AddError(ctx, "DB does not exist", fmt.Errorf("could not find DB '%s'", req.Config.Name.ValueString())) } }). Then(func() { state := req.Config.withSettings(db.GetSettings(ctx)) - if state.Id.Unknown || state.Id.Null { - state.Id = types.String{Value: fmt.Sprint(db.GetId(ctx))} + if !common.IsAttrSet(state.Id) { + state.Id = types.StringValue(fmt.Sprint(db.GetId(ctx))) } resp.SetState(state) diff --git a/internal/services/database/list.go b/internal/services/database/list.go index ddb93d0..b0b1dcb 100644 --- a/internal/services/database/list.go +++ b/internal/services/database/list.go @@ -51,13 +51,13 @@ func (l *listDataSource) Read(ctx context.Context, req datasource.ReadRequest[li Then(func() { dbs = sql.GetDatabases(ctx, req.Conn) }). Then(func() { result := listDataSourceData{ - Id: types.String{Value: ""}, + Id: types.StringValue(""), Databases: []resourceData{}, } for id, db := range dbs { r := resourceData{ - Id: types.String{Value: fmt.Sprint(id)}, + Id: types.StringValue(fmt.Sprint(id)), } result.Databases = append(result.Databases, r.withSettings(db.GetSettings(ctx))) } diff --git a/internal/services/database/resource.go b/internal/services/database/resource.go index 8c180ad..d68d0b9 100644 --- a/internal/services/database/resource.go +++ b/internal/services/database/resource.go @@ -39,7 +39,7 @@ func (r *res) Create(ctx context.Context, req resource.CreateRequest[resourceDat req. Then(func() { db = sql.CreateDatabase(ctx, req.Conn, req.Plan.toSettings()) }). Then(func() { resp.State = req.Plan.withSettings(db.GetSettings(ctx)) }). - Then(func() { resp.State.Id = types.String{Value: fmt.Sprint(db.GetId(ctx))} }) + Then(func() { resp.State.Id = types.StringValue(fmt.Sprint(db.GetId(ctx))) }) } func (r *res) Read(ctx context.Context, req resource.ReadRequest[resourceData], resp *resource.ReadResponse[resourceData]) { @@ -66,13 +66,13 @@ func (r *res) Update(ctx context.Context, req resource.UpdateRequest[resourceDat req. Then(func() { db = sql.GetDatabase(ctx, req.Conn, req.Plan.getDbId(ctx)) }). Then(func() { - if req.State.Name.Value != req.Plan.Name.Value { - db.Rename(ctx, req.Plan.Name.Value) + if req.State.Name.ValueString() != req.Plan.Name.ValueString() { + db.Rename(ctx, req.Plan.Name.ValueString()) } }). Then(func() { - if req.State.Collation.Value != req.Plan.Collation.Value { - db.SetCollation(ctx, req.Plan.Collation.Value) + if req.State.Collation.ValueString() != req.Plan.Collation.ValueString() { + db.SetCollation(ctx, req.Plan.Collation.ValueString()) } }). Then(func() { diff --git a/internal/services/databaseRole/base.go b/internal/services/databaseRole/base.go index 80bf6ec..1f325c6 100644 --- a/internal/services/databaseRole/base.go +++ b/internal/services/databaseRole/base.go @@ -58,10 +58,10 @@ func (d resourceData) withRoleData(ctx context.Context, role sql.DatabaseRole) r dbId := role.GetDb(ctx).GetId(ctx) return resourceData{ - Id: types.String{Value: common2.DbObjectId[sql.DatabaseRoleId]{DbId: dbId, ObjectId: role.GetId(ctx)}.String()}, - Name: types.String{Value: role.GetName(ctx)}, - DatabaseId: types.String{Value: fmt.Sprint(dbId)}, - OwnerId: types.String{Value: common2.DbObjectId[sql.GenericDatabasePrincipalId]{DbId: dbId, ObjectId: role.GetOwnerId(ctx)}.String()}, + Id: types.StringValue(common2.DbObjectId[sql.DatabaseRoleId]{DbId: dbId, ObjectId: role.GetId(ctx)}.String()), + Name: types.StringValue(role.GetName(ctx)), + DatabaseId: types.StringValue(fmt.Sprint(dbId)), + OwnerId: types.StringValue(common2.DbObjectId[sql.GenericDatabasePrincipalId]{DbId: dbId, ObjectId: role.GetOwnerId(ctx)}.String()), } } @@ -82,10 +82,10 @@ type dataSourceData struct { func (d dataSourceData) withRoleData(ctx context.Context, role sql.DatabaseRole) dataSourceData { dbId := role.GetDb(ctx).GetId(ctx) data := dataSourceData{ - Id: types.String{Value: common2.DbObjectId[sql.DatabaseRoleId]{DbId: dbId, ObjectId: role.GetId(ctx)}.String()}, - Name: types.String{Value: role.GetName(ctx)}, - DatabaseId: types.String{Value: fmt.Sprint(dbId)}, - OwnerId: types.String{Value: common2.DbObjectId[sql.GenericDatabasePrincipalId]{DbId: dbId, ObjectId: role.GetOwnerId(ctx)}.String()}, + Id: types.StringValue(common2.DbObjectId[sql.DatabaseRoleId]{DbId: dbId, ObjectId: role.GetId(ctx)}.String()), + Name: types.StringValue(role.GetName(ctx)), + DatabaseId: types.StringValue(fmt.Sprint(dbId)), + OwnerId: types.StringValue(common2.DbObjectId[sql.GenericDatabasePrincipalId]{DbId: dbId, ObjectId: role.GetOwnerId(ctx)}.String()), } mapType := func(typ sql.DatabasePrincipalType) string { @@ -104,9 +104,9 @@ func (d dataSourceData) withRoleData(ctx context.Context, role sql.DatabaseRole) for id, member := range role.GetMembers(ctx) { memberData := resourceRoleMembersData{ - Id: types.String{Value: common2.DbObjectId[sql.GenericDatabasePrincipalId]{DbId: dbId, ObjectId: id}.String()}, - Name: types.String{Value: member.Name}, - Type: types.String{Value: mapType(member.Type)}, + Id: types.StringValue(common2.DbObjectId[sql.GenericDatabasePrincipalId]{DbId: dbId, ObjectId: id}.String()), + Name: types.StringValue(member.Name), + Type: types.StringValue(mapType(member.Type)), } data.Members = append(data.Members, memberData) } diff --git a/internal/services/databaseRole/data.go b/internal/services/databaseRole/data.go index 168bafc..16604da 100644 --- a/internal/services/databaseRole/data.go +++ b/internal/services/databaseRole/data.go @@ -48,7 +48,7 @@ func (d *dataSource) Read(ctx context.Context, req datasource.ReadRequest[dataSo ) req. - Then(func() { db = common2.GetResourceDb(ctx, req.Conn, req.Config.DatabaseId.Value) }). - Then(func() { role = sql.GetDatabaseRoleByName(ctx, db, req.Config.Name.Value) }). + Then(func() { db = common2.GetResourceDb(ctx, req.Conn, req.Config.DatabaseId.ValueString()) }). + Then(func() { role = sql.GetDatabaseRoleByName(ctx, db, req.Config.Name.ValueString()) }). Then(func() { resp.SetState(req.Config.withRoleData(ctx, role)) }) } diff --git a/internal/services/databaseRole/list.go b/internal/services/databaseRole/list.go index 2644ca6..a5749ac 100644 --- a/internal/services/databaseRole/list.go +++ b/internal/services/databaseRole/list.go @@ -56,12 +56,12 @@ func (l *listDataSource) Read(ctx context.Context, req datasource.ReadRequest[li ) req. - Then(func() { db = common2.GetResourceDb(ctx, req.Conn, req.Config.DatabaseId.Value) }). + Then(func() { db = common2.GetResourceDb(ctx, req.Conn, req.Config.DatabaseId.ValueString()) }). Then(func() { dbId = db.GetId(ctx) }). Then(func() { roles = sql.GetDatabaseRoles(ctx, db) }). Then(func() { state := listDataSourceData{ - DatabaseId: types.String{Value: fmt.Sprint(dbId)}, + DatabaseId: types.StringValue(fmt.Sprint(dbId)), } state.Id = state.DatabaseId diff --git a/internal/services/databaseRole/resource.go b/internal/services/databaseRole/resource.go index 0e5bd12..2759720 100644 --- a/internal/services/databaseRole/resource.go +++ b/internal/services/databaseRole/resource.go @@ -44,11 +44,11 @@ func (r *res) Create(ctx context.Context, req resource.CreateRequest[resourceDat ownerId := common2.DbObjectId[sql.GenericDatabasePrincipalId]{IsEmpty: true} req. - Then(func() { db = common2.GetResourceDb(ctx, req.Conn, req.Plan.DatabaseId.Value) }). + Then(func() { db = common2.GetResourceDb(ctx, req.Conn, req.Plan.DatabaseId.ValueString()) }). Then(func() { dbId = db.GetId(ctx) }). Then(func() { - if !req.Plan.OwnerId.Null && !req.Plan.OwnerId.Unknown { - ownerId = common2.ParseDbObjectId[sql.GenericDatabasePrincipalId](ctx, req.Plan.OwnerId.Value) + if common2.IsAttrSet(req.Plan.OwnerId) { + ownerId = common2.ParseDbObjectId[sql.GenericDatabasePrincipalId](ctx, req.Plan.OwnerId.ValueString()) if ownerId.DbId != dbId { utils.AddError(ctx, "Role owner must be principal defined in the same DB as the role", errors.New("owner and principal DBs are different")) @@ -57,9 +57,9 @@ func (r *res) Create(ctx context.Context, req resource.CreateRequest[resourceDat }). Then(func() { if ownerId.IsEmpty { - role = sql.CreateDatabaseRole(ctx, db, req.Plan.Name.Value, sql.EmptyDatabasePrincipalId) + role = sql.CreateDatabaseRole(ctx, db, req.Plan.Name.ValueString(), sql.EmptyDatabasePrincipalId) } else { - role = sql.CreateDatabaseRole(ctx, db, req.Plan.Name.Value, ownerId.ObjectId) + role = sql.CreateDatabaseRole(ctx, db, req.Plan.Name.ValueString(), ownerId.ObjectId) } }). Then(func() { resp.State = req.Plan.withRoleData(ctx, role) }) @@ -73,7 +73,7 @@ func (r *res) Read(ctx context.Context, req resource.ReadRequest[resourceData], ) req. - Then(func() { roleId = common2.ParseDbObjectId[sql.DatabaseRoleId](ctx, req.State.Id.Value) }). + Then(func() { roleId = common2.ParseDbObjectId[sql.DatabaseRoleId](ctx, req.State.Id.ValueString()) }). Then(func() { db = sql.GetDatabase(ctx, req.Conn, roleId.DbId) }). Then(func() { role = sql.GetDatabaseRole(ctx, db, roleId.ObjectId) }). Then(func() { resp.SetState(req.State.withRoleData(ctx, role)) }) @@ -89,13 +89,13 @@ func (r *res) Update(ctx context.Context, req resource.UpdateRequest[resourceDat ownerId := common2.DbObjectId[sql.GenericDatabasePrincipalId]{IsEmpty: true} req. - Then(func() { db = common2.GetResourceDb(ctx, req.Conn, req.Plan.DatabaseId.Value) }). + Then(func() { db = common2.GetResourceDb(ctx, req.Conn, req.Plan.DatabaseId.ValueString()) }). Then(func() { dbId = db.GetId(ctx) }). - Then(func() { roleId = common2.ParseDbObjectId[sql.DatabaseRoleId](ctx, req.Plan.Id.Value) }). + Then(func() { roleId = common2.ParseDbObjectId[sql.DatabaseRoleId](ctx, req.Plan.Id.ValueString()) }). Then(func() { role = sql.GetDatabaseRole(ctx, db, roleId.ObjectId) }). Then(func() { - if !req.Plan.OwnerId.Null && !req.Plan.OwnerId.Unknown { - ownerId = common2.ParseDbObjectId[sql.GenericDatabasePrincipalId](ctx, req.Plan.OwnerId.Value) + if common2.IsAttrSet(req.Plan.OwnerId) { + ownerId = common2.ParseDbObjectId[sql.GenericDatabasePrincipalId](ctx, req.Plan.OwnerId.ValueString()) if ownerId.DbId != dbId { utils.AddError(ctx, "Role owner must be principal defined in the same DB as the role", errors.New("owner and principal DBs are different")) @@ -103,8 +103,8 @@ func (r *res) Update(ctx context.Context, req resource.UpdateRequest[resourceDat } }). Then(func() { - if role.GetName(ctx) != req.Plan.Name.Value && !utils.HasError(ctx) { - role.Rename(ctx, req.Plan.Name.Value) + if role.GetName(ctx) != req.Plan.Name.ValueString() && !utils.HasError(ctx) { + role.Rename(ctx, req.Plan.Name.ValueString()) } }). Then(func() { @@ -128,8 +128,8 @@ func (r *res) Delete(ctx context.Context, req resource.DeleteRequest[resourceDat ) req. - Then(func() { db = common2.GetResourceDb(ctx, req.Conn, req.State.DatabaseId.Value) }). - Then(func() { roleId = common2.ParseDbObjectId[sql.DatabaseRoleId](ctx, req.State.Id.Value) }). + Then(func() { db = common2.GetResourceDb(ctx, req.Conn, req.State.DatabaseId.ValueString()) }). + Then(func() { roleId = common2.ParseDbObjectId[sql.DatabaseRoleId](ctx, req.State.Id.ValueString()) }). Then(func() { role = sql.GetDatabaseRole(ctx, db, roleId.ObjectId) }). Then(func() { role.Drop(ctx) }) } diff --git a/internal/services/databaseRoleMember/resource.go b/internal/services/databaseRoleMember/resource.go index 56c8059..bea70f6 100644 --- a/internal/services/databaseRoleMember/resource.go +++ b/internal/services/databaseRoleMember/resource.go @@ -47,8 +47,8 @@ func (r *res) Create(ctx context.Context, req resource.CreateRequest[resourceDat req. Then(func() { - roleId = common2.ParseDbObjectId[sql.DatabaseRoleId](ctx, req.Plan.RoleId.Value) - memberId = common2.ParseDbObjectId[sql.GenericDatabasePrincipalId](ctx, req.Plan.MemberId.Value) + roleId = common2.ParseDbObjectId[sql.DatabaseRoleId](ctx, req.Plan.RoleId.ValueString()) + memberId = common2.ParseDbObjectId[sql.GenericDatabasePrincipalId](ctx, req.Plan.MemberId.ValueString()) }). Then(func() { if roleId.DbId != memberId.DbId { @@ -59,7 +59,7 @@ func (r *res) Create(ctx context.Context, req resource.CreateRequest[resourceDat Then(func() { role = sql.GetDatabaseRole(ctx, db, roleId.ObjectId) }). Then(func() { role.AddMember(ctx, memberId.ObjectId) }). Then(func() { - req.Plan.Id = types.String{Value: common2.DbObjectMemberId[sql.DatabaseRoleId, sql.GenericDatabasePrincipalId]{DbObjectId: roleId, MemberId: memberId.ObjectId}.String()} + req.Plan.Id = types.StringValue(common2.DbObjectMemberId[sql.DatabaseRoleId, sql.GenericDatabasePrincipalId]{DbObjectId: roleId, MemberId: memberId.ObjectId}.String()) resp.State = req.Plan }) } @@ -73,14 +73,14 @@ func (r *res) Read(ctx context.Context, req resource.ReadRequest[resourceData], req. Then(func() { - id = common2.ParseDbObjectMemberId[sql.DatabaseRoleId, sql.GenericDatabasePrincipalId](ctx, req.State.Id.Value) + id = common2.ParseDbObjectMemberId[sql.DatabaseRoleId, sql.GenericDatabasePrincipalId](ctx, req.State.Id.ValueString()) }). Then(func() { db = sql.GetDatabase(ctx, req.Conn, id.DbId) }). Then(func() { role = sql.GetDatabaseRole(ctx, db, id.ObjectId) }). Then(func() { if role.HasMember(ctx, id.MemberId) { - req.State.RoleId = types.String{Value: id.DbObjectId.String()} - req.State.MemberId = types.String{Value: id.GetMemberId().String()} + req.State.RoleId = types.StringValue(id.DbObjectId.String()) + req.State.MemberId = types.StringValue(id.GetMemberId().String()) resp.SetState(req.State) } }) @@ -99,7 +99,7 @@ func (r *res) Delete(ctx context.Context, req resource.DeleteRequest[resourceDat req. Then(func() { - id = common2.ParseDbObjectMemberId[sql.DatabaseRoleId, sql.GenericDatabasePrincipalId](ctx, req.State.Id.Value) + id = common2.ParseDbObjectMemberId[sql.DatabaseRoleId, sql.GenericDatabasePrincipalId](ctx, req.State.Id.ValueString()) }). Then(func() { db = sql.GetDatabase(ctx, req.Conn, id.DbId) }). Then(func() { role = sql.GetDatabaseRole(ctx, db, id.ObjectId) }). diff --git a/internal/services/schema/base.go b/internal/services/schema/base.go index e20de1a..05583f9 100644 --- a/internal/services/schema/base.go +++ b/internal/services/schema/base.go @@ -38,9 +38,9 @@ func (d resourceData) withSchemaData(ctx context.Context, schema sql.Schema) res dbId := schema.GetDb(ctx).GetId(ctx) return resourceData{ - Id: types.String{Value: common.DbObjectId[sql.SchemaId]{DbId: dbId, ObjectId: schema.GetId(ctx)}.String()}, - Name: types.String{Value: schema.GetName(ctx)}, - DatabaseId: types.String{Value: fmt.Sprint(dbId)}, - OwnerId: types.String{Value: common.DbObjectId[sql.GenericDatabasePrincipalId]{DbId: dbId, ObjectId: schema.GetOwnerId(ctx)}.String()}, + Id: types.StringValue(common.DbObjectId[sql.SchemaId]{DbId: dbId, ObjectId: schema.GetId(ctx)}.String()), + Name: types.StringValue(schema.GetName(ctx)), + DatabaseId: types.StringValue(fmt.Sprint(dbId)), + OwnerId: types.StringValue(common.DbObjectId[sql.GenericDatabasePrincipalId]{DbId: dbId, ObjectId: schema.GetOwnerId(ctx)}.String()), } } diff --git a/internal/services/schema/data.go b/internal/services/schema/data.go index ed4d6fd..2ab39a9 100644 --- a/internal/services/schema/data.go +++ b/internal/services/schema/data.go @@ -59,19 +59,19 @@ func (d dataSource) Read(ctx context.Context, req datasource.ReadRequest[resourc schema sql.Schema ) - schemaId := common.ParseDbObjectId[sql.SchemaId](ctx, req.Config.Id.Value) + schemaId := common.ParseDbObjectId[sql.SchemaId](ctx, req.Config.Id.ValueString()) req. Then(func() { if schemaId.IsEmpty { - db = common.GetResourceDb(ctx, req.Conn, req.Config.DatabaseId.Value) + db = common.GetResourceDb(ctx, req.Conn, req.Config.DatabaseId.ValueString()) } else { db = sql.GetDatabase(ctx, req.Conn, schemaId.DbId) } }). Then(func() { if schemaId.IsEmpty { - schema = sql.GetSchemaByName(ctx, db, req.Config.Name.Value) + schema = sql.GetSchemaByName(ctx, db, req.Config.Name.ValueString()) } else { schema = sql.GetSchema(ctx, db, schemaId.ObjectId) } diff --git a/internal/services/schema/list.go b/internal/services/schema/list.go index 17915ab..fcfbca9 100644 --- a/internal/services/schema/list.go +++ b/internal/services/schema/list.go @@ -52,7 +52,7 @@ func (l listDataSource) Read(ctx context.Context, req datasource.ReadRequest[lis var schemas map[sql.SchemaId]sql.Schema var dbId sql.DatabaseId - db := common.GetResourceDb(ctx, req.Conn, req.Config.DatabaseId.Value) + db := common.GetResourceDb(ctx, req.Conn, req.Config.DatabaseId.ValueString()) req. Then(func() { @@ -61,7 +61,7 @@ func (l listDataSource) Read(ctx context.Context, req datasource.ReadRequest[lis }). Then(func() { data := listDataSourceData{ - DatabaseId: types.String{Value: fmt.Sprint(dbId)}, + DatabaseId: types.StringValue(fmt.Sprint(dbId)), Schemas: []resourceData{}, } data.Id = data.DatabaseId diff --git a/internal/services/schema/resource.go b/internal/services/schema/resource.go index 9785d0d..6aa66c8 100644 --- a/internal/services/schema/resource.go +++ b/internal/services/schema/resource.go @@ -16,7 +16,7 @@ func (r res) GetName() string { return "schema" } -func (r res) GetSchema(ctx context.Context) tfsdk.Schema { +func (r res) GetSchema(context.Context) tfsdk.Schema { return tfsdk.Schema{ MarkdownDescription: "Manages single DB schema.", Attributes: map[string]tfsdk.Attribute{ @@ -42,7 +42,7 @@ func (r res) Read(ctx context.Context, req resource.ReadRequest[resourceData], r ) req. - Then(func() { schemaId = common.ParseDbObjectId[sql.SchemaId](ctx, req.State.Id.Value) }). + Then(func() { schemaId = common.ParseDbObjectId[sql.SchemaId](ctx, req.State.Id.ValueString()) }). Then(func() { db = sql.GetDatabase(ctx, req.Conn, schemaId.DbId) }). Then(func() { schema = sql.GetSchema(ctx, db, schemaId.ObjectId) }). Then(func() { resp.SetState(req.State.withSchemaData(ctx, schema)) }) @@ -56,7 +56,7 @@ func (r res) Create(ctx context.Context, req resource.CreateRequest[resourceData ) req. - Then(func() { db = common.GetResourceDb(ctx, req.Conn, req.Plan.DatabaseId.Value) }). + Then(func() { db = common.GetResourceDb(ctx, req.Conn, req.Plan.DatabaseId.ValueString()) }). Then(func() { ownerId = r.getOwnerId(ctx, req.Plan, db) }). Then(func() { owner := sql.EmptyDatabasePrincipalId @@ -65,7 +65,7 @@ func (r res) Create(ctx context.Context, req resource.CreateRequest[resourceData owner = ownerId.ObjectId } - schema = sql.CreateSchema(ctx, db, req.Plan.Name.Value, owner) + schema = sql.CreateSchema(ctx, db, req.Plan.Name.ValueString(), owner) }). Then(func() { resp.State = req.Plan.withSchemaData(ctx, schema) }) } @@ -79,8 +79,8 @@ func (r res) Update(ctx context.Context, req resource.UpdateRequest[resourceData ) req. - Then(func() { schemaId = common.ParseDbObjectId[sql.SchemaId](ctx, req.Plan.Id.Value) }). - Then(func() { db = common.GetResourceDb(ctx, req.Conn, req.Plan.DatabaseId.Value) }). + Then(func() { schemaId = common.ParseDbObjectId[sql.SchemaId](ctx, req.Plan.Id.ValueString()) }). + Then(func() { db = common.GetResourceDb(ctx, req.Conn, req.Plan.DatabaseId.ValueString()) }). Then(func() { ownerId = r.getOwnerId(ctx, req.Plan, db) }). Then(func() { schema = sql.GetSchema(ctx, db, schemaId.ObjectId) }). Then(func() { @@ -95,7 +95,7 @@ func (r res) Update(ctx context.Context, req resource.UpdateRequest[resourceData Then(func() { resp.State = req.Plan.withSchemaData(ctx, schema) }) } -func (r res) Delete(ctx context.Context, req resource.DeleteRequest[resourceData], resp *resource.DeleteResponse[resourceData]) { +func (r res) Delete(ctx context.Context, req resource.DeleteRequest[resourceData], _ *resource.DeleteResponse[resourceData]) { var ( db sql.Database schemaId common.DbObjectId[sql.SchemaId] @@ -103,8 +103,8 @@ func (r res) Delete(ctx context.Context, req resource.DeleteRequest[resourceData ) req. - Then(func() { schemaId = common.ParseDbObjectId[sql.SchemaId](ctx, req.State.Id.Value) }). - Then(func() { db = common.GetResourceDb(ctx, req.Conn, req.State.DatabaseId.Value) }). + Then(func() { schemaId = common.ParseDbObjectId[sql.SchemaId](ctx, req.State.Id.ValueString()) }). + Then(func() { db = common.GetResourceDb(ctx, req.Conn, req.State.DatabaseId.ValueString()) }). Then(func() { schema = sql.GetSchema(ctx, db, schemaId.ObjectId) }). Then(func() { schema.Drop(ctx) }) } @@ -120,7 +120,9 @@ func (r res) getOwnerId(ctx context.Context, data resourceData, db sql.Database) ) utils.StopOnError(ctx). - Then(func() { ownerId = common.ParseDbObjectId[sql.GenericDatabasePrincipalId](ctx, data.OwnerId.Value) }). + Then(func() { + ownerId = common.ParseDbObjectId[sql.GenericDatabasePrincipalId](ctx, data.OwnerId.ValueString()) + }). Then(func() { dbId = db.GetId(ctx) }). Then(func() { if ownerId.DbId != dbId { diff --git a/internal/services/script/data.go b/internal/services/script/data.go index 4dbc41b..db1ed63 100644 --- a/internal/services/script/data.go +++ b/internal/services/script/data.go @@ -56,11 +56,11 @@ func (d *dataSource) Read(ctx context.Context, req datasource.ReadRequest[dataSo ) req. - Then(func() { db = common.GetResourceDb(ctx, req.Conn, req.Config.DatabaseId.Value) }). - Then(func() { result = db.Query(ctx, req.Config.Query.Value) }). + Then(func() { db = common.GetResourceDb(ctx, req.Conn, req.Config.DatabaseId.ValueString()) }). + Then(func() { result = db.Query(ctx, req.Config.Query.ValueString()) }). Then(func() { req.Config.Result = result - req.Config.Id = types.String{Value: "query"} + req.Config.Id = types.StringValue("query") resp.SetState(req.Config) }) } diff --git a/internal/services/script/resource.go b/internal/services/script/resource.go index 5b09238..5bc35f4 100644 --- a/internal/services/script/resource.go +++ b/internal/services/script/resource.go @@ -79,17 +79,17 @@ func (r *res) Create(ctx context.Context, req resource.CreateRequest[resourceDat req. Then(func() { r.queryState(ctx, req.Conn, req.Plan) }). // report error if planned read script produces and error Then(func() { - script := req.Plan.UpdateScript.Value + script := req.Plan.UpdateScript.ValueString() if common.IsAttrSet(req.Plan.CreateScript) { - script = req.Plan.CreateScript.Value + script = req.Plan.CreateScript.ValueString() } r.execScript(ctx, req.Conn, script, req.Plan) }). Then(func() { resp.State = req.Plan - resp.State.Id = types.String{Value: "script"} + resp.State.Id = types.StringValue("script") }) } @@ -100,13 +100,13 @@ func (r *res) Update(ctx context.Context, req resource.UpdateRequest[resourceDat r.queryState(ctx, req.Conn, req.Plan) // report error if planned read script produces and error } }). - Then(func() { r.execScript(ctx, req.Conn, req.Plan.UpdateScript.Value, req.Plan) }). + Then(func() { r.execScript(ctx, req.Conn, req.Plan.UpdateScript.ValueString(), req.Plan) }). Then(func() { resp.State = req.Plan }) } func (r *res) Delete(ctx context.Context, req resource.DeleteRequest[resourceData], _ *resource.DeleteResponse[resourceData]) { if common.IsAttrSet(req.State.DeleteScript) { - req.Then(func() { r.execScript(ctx, req.Conn, req.State.DeleteScript.Value, req.State) }) + req.Then(func() { r.execScript(ctx, req.Conn, req.State.DeleteScript.ValueString(), req.State) }) } } @@ -114,7 +114,7 @@ func (r *res) execScript(ctx context.Context, conn sql.Connection, script string var db sql.Database utils.StopOnError(ctx). - Then(func() { db = common.GetResourceDb(ctx, conn, data.DatabaseId.Value) }). + Then(func() { db = common.GetResourceDb(ctx, conn, data.DatabaseId.ValueString()) }). Then(func() { db.Exec(ctx, script) }) } @@ -127,8 +127,8 @@ func (r *res) queryState(ctx context.Context, conn sql.Connection, data resource state := map[string]types.String{} utils.StopOnError(ctx). - Then(func() { db = common.GetResourceDb(ctx, conn, data.DatabaseId.Value) }). - Then(func() { queryRes = db.Query(ctx, data.ReadScript.Value) }). + Then(func() { db = common.GetResourceDb(ctx, conn, data.DatabaseId.ValueString()) }). + Then(func() { queryRes = db.Query(ctx, data.ReadScript.ValueString()) }). Then(func() { if len(queryRes) != 1 { utils.AddError(ctx, "Invalid read_script result", fmt.Errorf("expected 1 row, got %d", len(queryRes))) @@ -136,7 +136,7 @@ func (r *res) queryState(ctx context.Context, conn sql.Connection, data resource }). Then(func() { for name, val := range queryRes[0] { - state[name] = types.String{Value: val} + state[name] = types.StringValue(val) } }) diff --git a/internal/services/sqlLogin/base.go b/internal/services/sqlLogin/base.go index 94abb59..d7025bd 100644 --- a/internal/services/sqlLogin/base.go +++ b/internal/services/sqlLogin/base.go @@ -54,11 +54,11 @@ type dataSourceData struct { func (d dataSourceData) withSettings(settings sql.SqlLoginSettings) dataSourceData { return dataSourceData{ Id: d.Id, - Name: types.String{Value: settings.Name}, - MustChangePassword: types.Bool{Value: settings.MustChangePassword}, - DefaultDatabaseId: types.String{Value: fmt.Sprint(settings.DefaultDatabaseId)}, - DefaultLanguage: types.String{Value: settings.DefaultLanguage}, - CheckPasswordExpiration: types.Bool{Value: settings.CheckPasswordExpiration}, - CheckPasswordPolicy: types.Bool{Value: settings.CheckPasswordPolicy}, + Name: types.StringValue(settings.Name), + MustChangePassword: types.BoolValue(settings.MustChangePassword), + DefaultDatabaseId: types.StringValue(fmt.Sprint(settings.DefaultDatabaseId)), + DefaultLanguage: types.StringValue(settings.DefaultLanguage), + CheckPasswordExpiration: types.BoolValue(settings.CheckPasswordExpiration), + CheckPasswordPolicy: types.BoolValue(settings.CheckPasswordPolicy), } } diff --git a/internal/services/sqlLogin/data.go b/internal/services/sqlLogin/data.go index 4615a9c..17a04c6 100644 --- a/internal/services/sqlLogin/data.go +++ b/internal/services/sqlLogin/data.go @@ -36,15 +36,15 @@ func (d *dataSource) Read(ctx context.Context, req datasource.ReadRequest[dataSo req. Then(func() { - login = sql.GetSqlLoginByName(ctx, req.Conn, req.Config.Name.Value) + login = sql.GetSqlLoginByName(ctx, req.Conn, req.Config.Name.ValueString()) if login == nil || !login.Exists(ctx) { - utils.AddError(ctx, "Login does not exist", fmt.Errorf("could not find SQL Login '%s'", req.Config.Name.Value)) + utils.AddError(ctx, "Login does not exist", fmt.Errorf("could not find SQL Login '%s'", req.Config.Name.ValueString())) } }). Then(func() { state := req.Config.withSettings(login.GetSettings(ctx)) - state.Id = types.String{Value: fmt.Sprint(login.GetId(ctx))} + state.Id = types.StringValue(fmt.Sprint(login.GetId(ctx))) resp.SetState(state) }) diff --git a/internal/services/sqlLogin/list.go b/internal/services/sqlLogin/list.go index e65ef1f..ab07acd 100644 --- a/internal/services/sqlLogin/list.go +++ b/internal/services/sqlLogin/list.go @@ -51,12 +51,12 @@ func (l *listDataSource) Read(ctx context.Context, req datasource.ReadRequest[li Then(func() { logins = sql.GetSqlLogins(ctx, req.Conn) }). Then(func() { result := listDataSourceData{ - Id: types.String{Value: ""}, + Id: types.StringValue(""), } for id, login := range logins { s := login.GetSettings(ctx) - r := dataSourceData{Id: types.String{Value: fmt.Sprint(id)}} + r := dataSourceData{Id: types.StringValue(fmt.Sprint(id))} result.Logins = append(result.Logins, r.withSettings(s)) } diff --git a/internal/services/sqlLogin/resource.go b/internal/services/sqlLogin/resource.go index 38b665c..aceed77 100644 --- a/internal/services/sqlLogin/resource.go +++ b/internal/services/sqlLogin/resource.go @@ -27,8 +27,8 @@ type resourceData struct { func (d resourceData) toSettings(ctx context.Context) sql.SqlLoginSettings { var dbId int - if !d.DefaultDatabaseId.Null && !d.DefaultDatabaseId.Unknown { - if id, err := strconv.Atoi(d.DefaultDatabaseId.Value); err == nil { + if common2.IsAttrSet(d.DefaultDatabaseId) { + if id, err := strconv.Atoi(d.DefaultDatabaseId.ValueString()); err == nil { dbId = id } else { utils.AddError(ctx, "Failed to parse DB id", err) @@ -36,41 +36,41 @@ func (d resourceData) toSettings(ctx context.Context) sql.SqlLoginSettings { } return sql.SqlLoginSettings{ - Name: d.Name.Value, - Password: d.Password.Value, - MustChangePassword: d.MustChangePassword.Value, + Name: d.Name.ValueString(), + Password: d.Password.ValueString(), + MustChangePassword: d.MustChangePassword.ValueBool(), DefaultDatabaseId: sql.DatabaseId(dbId), - DefaultLanguage: d.DefaultLanguage.Value, - CheckPasswordExpiration: d.CheckPasswordExpiration.Value, - CheckPasswordPolicy: d.CheckPasswordPolicy.Value || d.CheckPasswordPolicy.Null || d.CheckPasswordPolicy.Unknown, + DefaultLanguage: d.DefaultLanguage.ValueString(), + CheckPasswordExpiration: d.CheckPasswordExpiration.ValueBool(), + CheckPasswordPolicy: d.CheckPasswordPolicy.ValueBool() || d.CheckPasswordPolicy.IsNull() || d.CheckPasswordPolicy.IsUnknown(), } } func (d resourceData) withSettings(settings sql.SqlLoginSettings, isAzure bool) resourceData { - d.Name = types.String{Value: settings.Name} + d.Name = types.StringValue(settings.Name) if isAzure { return d } if common2.IsAttrSet(d.MustChangePassword) { - d.MustChangePassword.Value = settings.MustChangePassword + d.MustChangePassword = types.BoolValue(settings.MustChangePassword) } if common2.IsAttrSet(d.DefaultDatabaseId) { - d.DefaultDatabaseId = types.String{Value: fmt.Sprint(settings.DefaultDatabaseId)} + d.DefaultDatabaseId = types.StringValue(fmt.Sprint(settings.DefaultDatabaseId)) } if common2.IsAttrSet(d.DefaultLanguage) { - d.DefaultLanguage = types.String{Value: settings.DefaultLanguage} + d.DefaultLanguage = types.StringValue(settings.DefaultLanguage) } if common2.IsAttrSet(d.CheckPasswordExpiration) { - d.CheckPasswordExpiration = types.Bool{Value: settings.CheckPasswordExpiration} + d.CheckPasswordExpiration = types.BoolValue(settings.CheckPasswordExpiration) } if common2.IsAttrSet(d.CheckPasswordPolicy) { - d.CheckPasswordPolicy = types.Bool{Value: settings.CheckPasswordPolicy} + d.CheckPasswordPolicy = types.BoolValue(settings.CheckPasswordPolicy) } return d @@ -141,7 +141,7 @@ func (r *res) Create(ctx context.Context, req resource.CreateRequest[resourceDat Then(func() { login = sql.CreateSqlLogin(ctx, req.Conn, req.Plan.toSettings(ctx)) }). Then(func() { resp.State = req.Plan.withSettings(login.GetSettings(ctx), req.Conn.IsAzure(ctx)) - resp.State.Id = types.String{Value: string(login.GetId(ctx))} + resp.State.Id = types.StringValue(string(login.GetId(ctx))) }) } @@ -152,7 +152,7 @@ func (r *res) Read(ctx context.Context, req resource.ReadRequest[resourceData], ) req. - Then(func() { login = sql.GetSqlLogin(ctx, req.Conn, sql.LoginId(req.State.Id.Value)) }). + Then(func() { login = sql.GetSqlLogin(ctx, req.Conn, sql.LoginId(req.State.Id.ValueString())) }). Then(func() { exists = login.Exists(ctx) }). Then(func() { if exists { @@ -165,7 +165,7 @@ func (r *res) Update(ctx context.Context, req resource.UpdateRequest[resourceDat var login sql.SqlLogin req. - Then(func() { login = sql.GetSqlLogin(ctx, req.Conn, sql.LoginId(req.Plan.Id.Value)) }). + Then(func() { login = sql.GetSqlLogin(ctx, req.Conn, sql.LoginId(req.Plan.Id.ValueString())) }). Then(func() { login.UpdateSettings(ctx, req.Plan.toSettings(ctx)) }). Then(func() { resp.State = req.Plan.withSettings(login.GetSettings(ctx), req.Conn.IsAzure(ctx)) }) } @@ -174,6 +174,6 @@ func (r *res) Delete(ctx context.Context, req resource.DeleteRequest[resourceDat var login sql.SqlLogin req. - Then(func() { login = sql.GetSqlLogin(ctx, req.Conn, sql.LoginId(req.State.Id.Value)) }). + Then(func() { login = sql.GetSqlLogin(ctx, req.Conn, sql.LoginId(req.State.Id.ValueString())) }). Then(func() { login.Drop(ctx) }) } diff --git a/internal/services/sqlUser/base.go b/internal/services/sqlUser/base.go index b31d78b..06801d1 100644 --- a/internal/services/sqlUser/base.go +++ b/internal/services/sqlUser/base.go @@ -35,8 +35,8 @@ type resourceData struct { func (d resourceData) toSettings() sql.UserSettings { return sql.UserSettings{ - Name: d.Name.Value, - LoginId: sql.LoginId(d.LoginId.Value), + Name: d.Name.ValueString(), + LoginId: sql.LoginId(d.LoginId.ValueString()), Type: sql.USER_TYPE_SQL, } } @@ -45,15 +45,15 @@ func (d resourceData) withSettings(settings sql.UserSettings) resourceData { return resourceData{ Id: d.Id, DatabaseId: d.DatabaseId, - Name: types.String{Value: settings.Name}, - LoginId: types.String{Value: fmt.Sprint(settings.LoginId)}, + Name: types.StringValue(settings.Name), + LoginId: types.StringValue(fmt.Sprint(settings.LoginId)), } } func (d resourceData) withIds(dbId sql.DatabaseId, userId sql.UserId) resourceData { return resourceData{ - Id: types.String{Value: fmt.Sprintf("%v/%v", dbId, userId)}, - DatabaseId: types.String{Value: fmt.Sprint(dbId)}, + Id: types.StringValue(fmt.Sprintf("%v/%v", dbId, userId)), + DatabaseId: types.StringValue(fmt.Sprint(dbId)), Name: d.Name, LoginId: d.LoginId, } diff --git a/internal/services/sqlUser/data.go b/internal/services/sqlUser/data.go index 3b4432e..d83b556 100644 --- a/internal/services/sqlUser/data.go +++ b/internal/services/sqlUser/data.go @@ -35,8 +35,8 @@ func (d *dataSource) Read(ctx context.Context, req datasource.ReadRequest[resour var user sql.User req. - Then(func() { db = common2.GetResourceDb(ctx, req.Conn, req.Config.DatabaseId.Value) }). - Then(func() { user = sql.GetUserByName(ctx, db, req.Config.Name.Value) }). + Then(func() { db = common2.GetResourceDb(ctx, req.Conn, req.Config.DatabaseId.ValueString()) }). + Then(func() { user = sql.GetUserByName(ctx, db, req.Config.Name.ValueString()) }). Then(func() { state := req.Config.withIds(db.GetId(ctx), user.GetId(ctx)) resp.SetState(state.withSettings(user.GetSettings(ctx))) diff --git a/internal/services/sqlUser/list.go b/internal/services/sqlUser/list.go index b1054c2..3312ea2 100644 --- a/internal/services/sqlUser/list.go +++ b/internal/services/sqlUser/list.go @@ -58,11 +58,11 @@ func (l *listDataSource) Read(ctx context.Context, req datasource.ReadRequest[li var dbId sql.DatabaseId req. - Then(func() { db = common2.GetResourceDb(ctx, req.Conn, req.Config.DatabaseId.Value) }). + Then(func() { db = common2.GetResourceDb(ctx, req.Conn, req.Config.DatabaseId.ValueString()) }). Then(func() { dbId = db.GetId(ctx) }). Then(func() { state := listDataSourceData{ - DatabaseId: types.String{Value: fmt.Sprint(dbId)}, + DatabaseId: types.StringValue(fmt.Sprint(dbId)), } state.Id = state.DatabaseId diff --git a/internal/services/sqlUser/resource.go b/internal/services/sqlUser/resource.go index 17972c9..afcee48 100644 --- a/internal/services/sqlUser/resource.go +++ b/internal/services/sqlUser/resource.go @@ -37,7 +37,7 @@ func (r *res) Create(ctx context.Context, req resource.CreateRequest[resourceDat ) req. - Then(func() { db = common2.GetResourceDb(ctx, req.Conn, req.Plan.DatabaseId.Value) }). + Then(func() { db = common2.GetResourceDb(ctx, req.Conn, req.Plan.DatabaseId.ValueString()) }). Then(func() { user = sql.CreateUser(ctx, db, req.Plan.toSettings()) }). Then(func() { resp.State = req.Plan.withIds(db.GetId(ctx), user.GetId(ctx)) }) } @@ -71,7 +71,7 @@ func (r *res) Delete(ctx context.Context, req resource.DeleteRequest[resourceDat } func getUser(ctx context.Context, conn sql.Connection, data resourceData) sql.User { - idSegments := strings.Split(data.Id.Value, "/") + idSegments := strings.Split(data.Id.ValueString(), "/") id, err := strconv.Atoi(idSegments[1]) if err != nil { utils.AddError(ctx, "Error converting user ID", err) diff --git a/internal/validators/sqlIdentifier.go b/internal/validators/sqlIdentifier.go index 74cb06d..fccfbec 100644 --- a/internal/validators/sqlIdentifier.go +++ b/internal/validators/sqlIdentifier.go @@ -3,6 +3,7 @@ package validators import ( "context" "fmt" + "github.com/PGSSoft/terraform-provider-mssql/internal/services/common" "regexp" "github.com/hashicorp/terraform-plugin-framework/tfsdk" @@ -29,15 +30,15 @@ func (s sqlIdentifierValidator) Validate(ctx context.Context, request tfsdk.Vali return } - if str.Unknown || str.Null { + if !common.IsAttrSet(str) { return } - if match, _ := regexp.Match("^[a-zA-Z_@#][a-zA-Z\\d@$#_-]*$", []byte(str.Value)); !match { + if match, _ := regexp.Match("^[a-zA-Z_@#][a-zA-Z\\d@$#_-]*$", []byte(str.ValueString())); !match { response.Diagnostics.AddAttributeError( request.AttributePath, "Invalid SQL identifier", - fmt.Sprintf("%s, got: %s", s.Description(ctx), str.Value), + fmt.Sprintf("%s, got: %s", s.Description(ctx), str.ValueString()), ) } } diff --git a/internal/validators/sqlIdentifier_test.go b/internal/validators/sqlIdentifier_test.go index 60da6cd..6201e1b 100644 --- a/internal/validators/sqlIdentifier_test.go +++ b/internal/validators/sqlIdentifier_test.go @@ -10,28 +10,28 @@ func TestSqlIdentifierValidate(t *testing.T) { testCases := map[string]validatorTestCase{ "Wrong type": { - val: types.Int64{Value: 2}, + val: types.Int64Value(2), expectedSummary: "Value Conversion Error", }, "Unknown": { - val: types.String{Unknown: true}, + val: types.StringUnknown(), }, "Null": { - val: types.String{Null: true}, + val: types.StringNull(), }, "Valid": { - val: types.String{Value: "_idenTif@$#_er"}, + val: types.StringValue("_idenTif@$#_er"), }, "startingWithDigit": { - val: types.String{Value: "2ndIdentifier"}, + val: types.StringValue("2ndIdentifier"), expectedSummary: validationErrSummary, }, "withSpace": { - val: types.String{Value: "has space"}, + val: types.StringValue("has space"), expectedSummary: validationErrSummary, }, "forbiddenChar": { - val: types.String{Value: "has&inName"}, + val: types.StringValue("has&inName"), expectedSummary: validationErrSummary, }, } diff --git a/internal/validators/stringLength.go b/internal/validators/stringLength.go index c251976..842af17 100644 --- a/internal/validators/stringLength.go +++ b/internal/validators/stringLength.go @@ -3,6 +3,7 @@ package validators import ( "context" "fmt" + "github.com/PGSSoft/terraform-provider-mssql/internal/services/common" "github.com/hashicorp/terraform-plugin-framework/tfsdk" "github.com/hashicorp/terraform-plugin-framework/types" ) @@ -30,17 +31,17 @@ func (s stringLengthValidator) Validate(ctx context.Context, request tfsdk.Valid return } - if str.Unknown || str.Null { + if !common.IsAttrSet(str) { return } - strLen := len(str.Value) + strLen := len(str.ValueString()) if strLen < s.Min || strLen > s.Max { response.Diagnostics.AddAttributeError( request.AttributePath, "Invalid String Length", - fmt.Sprintf("%s, got: %s (%d).", s.Description(ctx), str.Value, strLen), + fmt.Sprintf("%s, got: %s (%d).", s.Description(ctx), str.ValueString(), strLen), ) } } diff --git a/internal/validators/stringLength_test.go b/internal/validators/stringLength_test.go index 9b3af6c..d679043 100644 --- a/internal/validators/stringLength_test.go +++ b/internal/validators/stringLength_test.go @@ -10,24 +10,24 @@ func TestStringLengthValidate(t *testing.T) { testCases := map[string]validatorTestCase{ "Wrong type": { - val: types.Int64{Value: 2}, + val: types.Int64Value(2), expectedSummary: "Value Conversion Error", }, "Unknown": { - val: types.String{Unknown: true}, + val: types.StringUnknown(), }, "Null": { - val: types.String{Null: true}, + val: types.StringNull(), }, "Valid": { - val: types.String{Value: "xxxxx"}, + val: types.StringValue("xxxxx"), }, "TooShort": { - val: types.String{Value: "xx"}, + val: types.StringValue("xx"), expectedSummary: validationErrSummary, }, "TooLong": { - val: types.String{Value: "xxxxx xxxxx"}, + val: types.StringValue("xxxxx xxxxx"), expectedSummary: validationErrSummary, }, }