From fcd61a36761fb1008465d47092d36cb009290221 Mon Sep 17 00:00:00 2001 From: Graham Davison Date: Wed, 17 Aug 2022 14:16:24 -0700 Subject: [PATCH 1/6] Adds basic support for Comprehend Document Classifier --- internal/provider/provider.go | 3 +- internal/service/comprehend/acc_test.go | 9 + internal/service/comprehend/common_model.go | 121 ++ internal/service/comprehend/consts.go | 9 +- .../service/comprehend/document_classifier.go | 749 +++++++ .../comprehend/document_classifier_test.go | 1879 +++++++++++++++++ .../service/comprehend/entity_recognizer.go | 155 +- .../comprehend/entity_recognizer_test.go | 8 - internal/service/comprehend/generate.go | 1 + .../document_classifier/documents.csv | 100 + .../generate/document_classifier/main.go | 76 + internal/tfresource/retry.go | 94 +- internal/tfresource/retry_test.go | 58 +- ...mprehend_document_classifier.html.markdown | 123 ++ ...comprehend_entity_recognizer.html.markdown | 8 +- 15 files changed, 3224 insertions(+), 169 deletions(-) create mode 100644 internal/service/comprehend/document_classifier.go create mode 100644 internal/service/comprehend/document_classifier_test.go create mode 100644 internal/service/comprehend/test-fixtures/document_classifier/documents.csv create mode 100644 internal/service/comprehend/test-fixtures/generate/document_classifier/main.go create mode 100644 website/docs/r/comprehend_document_classifier.html.markdown diff --git a/internal/provider/provider.go b/internal/provider/provider.go index 7293a2d3377..bfa06a13248 100644 --- a/internal/provider/provider.go +++ b/internal/provider/provider.go @@ -1177,7 +1177,8 @@ func New(_ context.Context) (*schema.Provider, error) { "aws_cognito_user_pool_domain": cognitoidp.ResourceUserPoolDomain(), "aws_cognito_user_pool_ui_customization": cognitoidp.ResourceUserPoolUICustomization(), - "aws_comprehend_entity_recognizer": comprehend.ResourceEntityRecognizer(), + "aws_comprehend_document_classifier": comprehend.ResourceDocumentClassifier(), + "aws_comprehend_entity_recognizer": comprehend.ResourceEntityRecognizer(), "aws_config_aggregate_authorization": configservice.ResourceAggregateAuthorization(), "aws_config_config_rule": configservice.ResourceConfigRule(), diff --git a/internal/service/comprehend/acc_test.go b/internal/service/comprehend/acc_test.go index f52bdec3e90..f8f9a36f174 100644 --- a/internal/service/comprehend/acc_test.go +++ b/internal/service/comprehend/acc_test.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/aws/aws-sdk-go-v2/service/comprehend" + "github.com/hashicorp/terraform-plugin-sdk/v2/helper/resource" "github.com/hashicorp/terraform-provider-aws/internal/acctest" "github.com/hashicorp/terraform-provider-aws/internal/conns" ) @@ -56,3 +57,11 @@ resource "aws_subnet" "test" { `, rName, subnetCount), ) } + +func uniqueIDPattern() string { + return prefixedUniqueIDPattern(resource.UniqueIdPrefix) +} + +func prefixedUniqueIDPattern(prefix string) string { + return fmt.Sprintf("%s[[:xdigit:]]{%d}", prefix, resource.UniqueIDSuffixLength) +} diff --git a/internal/service/comprehend/common_model.go b/internal/service/comprehend/common_model.go index 7fa0c1022f9..034314bdcc9 100644 --- a/internal/service/comprehend/common_model.go +++ b/internal/service/comprehend/common_model.go @@ -7,8 +7,11 @@ import ( "time" "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/comprehend/types" "github.com/aws/aws-sdk-go/service/ec2" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/resource" + "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" + "github.com/hashicorp/terraform-provider-aws/internal/flex" tfec2 "github.com/hashicorp/terraform-provider-aws/internal/service/ec2" ) @@ -93,3 +96,121 @@ func statusNetworkInterfaces(ctx context.Context, conn *ec2.EC2, initialENIs map return added, aws.ToString(added.Status), nil } } + +type resourceGetter interface { + Get(key string) any +} + +func flattenVPCConfig(apiObject *types.VpcConfig) []interface{} { + if apiObject == nil { + return nil + } + + m := map[string]interface{}{ + "security_group_ids": flex.FlattenStringValueSet(apiObject.SecurityGroupIds), + "subnets": flex.FlattenStringValueSet(apiObject.Subnets), + } + + return []interface{}{m} +} + +func expandVPCConfig(tfList []interface{}) *types.VpcConfig { + if len(tfList) == 0 { + return nil + } + + tfMap := tfList[0].(map[string]interface{}) + + a := &types.VpcConfig{ + SecurityGroupIds: flex.ExpandStringValueSet(tfMap["security_group_ids"].(*schema.Set)), + Subnets: flex.ExpandStringValueSet(tfMap["subnets"].(*schema.Set)), + } + + return a +} + +func flattenAugmentedManifests(apiObjects []types.AugmentedManifestsListItem) []interface{} { + if len(apiObjects) == 0 { + return nil + } + + var l []interface{} + + for _, apiObject := range apiObjects { + l = append(l, flattenAugmentedManifestsListItem(&apiObject)) + } + + return l +} + +func flattenAugmentedManifestsListItem(apiObject *types.AugmentedManifestsListItem) map[string]interface{} { + if apiObject == nil { + return nil + } + + m := map[string]interface{}{ + "attribute_names": flex.FlattenStringValueList(apiObject.AttributeNames), + "s3_uri": aws.ToString(apiObject.S3Uri), + "document_type": apiObject.DocumentType, + "split": apiObject.Split, + } + + if v := apiObject.AnnotationDataS3Uri; v != nil { + m["annotation_data_s3_uri"] = aws.ToString(v) + } + + if v := apiObject.SourceDocumentsS3Uri; v != nil { + m["source_documents_s3_uri"] = aws.ToString(v) + } + + return m +} + +func expandAugmentedManifests(tfSet *schema.Set) []types.AugmentedManifestsListItem { + if tfSet.Len() == 0 { + return nil + } + + var s []types.AugmentedManifestsListItem + + for _, r := range tfSet.List() { + m, ok := r.(map[string]interface{}) + + if !ok { + continue + } + + a := expandAugmentedManifestsListItem(m) + + if a == nil { + continue + } + + s = append(s, *a) + } + + return s +} + +func expandAugmentedManifestsListItem(tfMap map[string]interface{}) *types.AugmentedManifestsListItem { + if tfMap == nil { + return nil + } + + a := &types.AugmentedManifestsListItem{ + AttributeNames: flex.ExpandStringValueList(tfMap["attribute_names"].([]interface{})), + S3Uri: aws.String(tfMap["s3_uri"].(string)), + DocumentType: types.AugmentedManifestsDocumentTypeFormat(tfMap["document_type"].(string)), + Split: types.Split(tfMap["split"].(string)), + } + + if v, ok := tfMap["annotation_data_s3_uri"].(string); ok && v != "" { + a.AnnotationDataS3Uri = aws.String(v) + } + + if v, ok := tfMap["source_documents_s3_uri"].(string); ok && v != "" { + a.SourceDocumentsS3Uri = aws.String(v) + } + + return a +} diff --git a/internal/service/comprehend/consts.go b/internal/service/comprehend/consts.go index 8c926987a46..e57884a12d2 100644 --- a/internal/service/comprehend/consts.go +++ b/internal/service/comprehend/consts.go @@ -7,5 +7,12 @@ import ( const iamPropagationTimeout = 2 * time.Minute // Avoid service throttling -const entityRegcognizerDelay = 1 * time.Minute +const entityRegcognizerCreatedDelay = 10 * time.Minute +const entityRegcognizerStoppedDelay = 0 +const entityRegcognizerDeletedDelay = 5 * time.Minute const entityRegcognizerPollInterval = 1 * time.Minute + +const documentClassifierCreatedDelay = 15 * time.Minute +const documentClassifierStoppedDelay = 0 +const documentClassifierDeletedDelay = 5 * time.Minute +const documentClassifierPollInterval = 1 * time.Minute diff --git a/internal/service/comprehend/document_classifier.go b/internal/service/comprehend/document_classifier.go new file mode 100644 index 00000000000..4d29a6a5338 --- /dev/null +++ b/internal/service/comprehend/document_classifier.go @@ -0,0 +1,749 @@ +package comprehend + +import ( + "context" + "errors" + "fmt" + "log" + "reflect" + "regexp" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/aws/arn" + "github.com/aws/aws-sdk-go-v2/aws/ratelimit" + "github.com/aws/aws-sdk-go-v2/service/comprehend" + "github.com/aws/aws-sdk-go-v2/service/comprehend/types" + "github.com/aws/aws-sdk-go/service/ec2" + "github.com/hashicorp/go-multierror" + "github.com/hashicorp/terraform-plugin-sdk/v2/diag" + "github.com/hashicorp/terraform-plugin-sdk/v2/helper/customdiff" + "github.com/hashicorp/terraform-plugin-sdk/v2/helper/resource" + "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" + "github.com/hashicorp/terraform-provider-aws/internal/conns" + "github.com/hashicorp/terraform-provider-aws/internal/create" + awsdiag "github.com/hashicorp/terraform-provider-aws/internal/diag" + "github.com/hashicorp/terraform-provider-aws/internal/enum" + tfec2 "github.com/hashicorp/terraform-provider-aws/internal/service/ec2" + tftags "github.com/hashicorp/terraform-provider-aws/internal/tags" + "github.com/hashicorp/terraform-provider-aws/internal/tfresource" + "github.com/hashicorp/terraform-provider-aws/internal/verify" +) + +const ( + documentClassifierTagKey = "tf-aws_comprehend_document_classifier" +) + +func ResourceDocumentClassifier() *schema.Resource { + return &schema.Resource{ + CreateWithoutTimeout: resourceDocumentClassifierCreate, + ReadWithoutTimeout: resourceDocumentClassifierRead, + UpdateWithoutTimeout: resourceDocumentClassifierUpdate, + DeleteWithoutTimeout: resourceDocumentClassifierDelete, + + Importer: &schema.ResourceImporter{ + StateContext: schema.ImportStatePassthroughContext, + }, + + Timeouts: &schema.ResourceTimeout{ + Create: schema.DefaultTimeout(60 * time.Minute), + Update: schema.DefaultTimeout(60 * time.Minute), + Delete: schema.DefaultTimeout(30 * time.Minute), + }, + + Schema: map[string]*schema.Schema{ + "arn": { + Type: schema.TypeString, + Computed: true, + }, + "data_access_role_arn": { + Type: schema.TypeString, + Required: true, + }, + "input_data_config": { + Type: schema.TypeList, + Required: true, + MaxItems: 1, + Elem: &schema.Resource{ + Schema: map[string]*schema.Schema{ + "augmented_manifests": { + Type: schema.TypeSet, + Optional: true, + ExactlyOneOf: []string{"input_data_config.0.augmented_manifests", "input_data_config.0.s3_uri"}, + Elem: &schema.Resource{ + Schema: map[string]*schema.Schema{ + "annotation_data_s3_uri": { + Type: schema.TypeString, + Optional: true, + }, + "attribute_names": { + Type: schema.TypeList, + Required: true, + Elem: &schema.Schema{Type: schema.TypeString}, + }, + "document_type": { + Type: schema.TypeString, + Optional: true, + ValidateDiagFunc: enum.Validate[types.AugmentedManifestsDocumentTypeFormat](), + Default: types.AugmentedManifestsDocumentTypeFormatPlainTextDocument, + }, + "s3_uri": { + Type: schema.TypeString, + Required: true, + }, + "source_documents_s3_uri": { + Type: schema.TypeString, + Optional: true, + }, + "split": { + Type: schema.TypeString, + Optional: true, + ValidateDiagFunc: enum.Validate[types.Split](), + Default: types.SplitTrain, + }, + }, + }, + }, + "data_format": { + Type: schema.TypeString, + Optional: true, + ValidateDiagFunc: enum.Validate[types.DocumentClassifierDataFormat](), + Default: types.DocumentClassifierDataFormatComprehendCsv, + }, + // "label_delimiter":{ + // Type:schema.TypeString, + // Optional: true, + // ValidateDiagFunc: , + // Default: "|", + // }, + "s3_uri": { + Type: schema.TypeString, + Optional: true, + }, + "test_s3_uri": { + Type: schema.TypeString, + Optional: true, + }, + }, + }, + }, + "language_code": { + Type: schema.TypeString, + Required: true, + ValidateDiagFunc: enum.Validate[types.SyntaxLanguageCode](), + }, + "model_kms_key_id": { + Type: schema.TypeString, + Optional: true, + DiffSuppressFunc: diffSuppressKMSKeyId, + ValidateFunc: validateKMSKey, + }, + "name": { + Type: schema.TypeString, + Required: true, + ValidateFunc: validModelName, + }, + // "output_data_config" + "tags": tftags.TagsSchema(), + "tags_all": tftags.TagsSchemaComputed(), + "version_name": { + Type: schema.TypeString, + Optional: true, + Computed: true, + ValidateFunc: validModelVersionName, + ConflictsWith: []string{"version_name_prefix"}, + }, + "version_name_prefix": { + Type: schema.TypeString, + Optional: true, + Computed: true, + ValidateFunc: validModelVersionNamePrefix, + ConflictsWith: []string{"version_name"}, + }, + "volume_kms_key_id": { + Type: schema.TypeString, + Optional: true, + DiffSuppressFunc: diffSuppressKMSKeyId, + ValidateFunc: validateKMSKey, + }, + "vpc_config": { + Type: schema.TypeList, + Optional: true, + MaxItems: 1, + Elem: &schema.Resource{ + Schema: map[string]*schema.Schema{ + "security_group_ids": { + Type: schema.TypeSet, + Required: true, + Elem: &schema.Schema{Type: schema.TypeString}, + }, + "subnets": { + Type: schema.TypeSet, + Required: true, + Elem: &schema.Schema{Type: schema.TypeString}, + }, + }, + }, + }, + }, + + CustomizeDiff: customdiff.All( + verify.SetTagsDiff, + func(_ context.Context, diff *schema.ResourceDiff, _ interface{}) error { + tfMap := getDocumentClassifierInputDataConfig(diff) + if tfMap == nil { + return nil + } + + if format := types.DocumentClassifierDataFormat(tfMap["data_format"].(string)); format == types.DocumentClassifierDataFormatComprehendCsv { + if tfMap["s3_uri"] == nil { + return fmt.Errorf("s3_uri must be set when data_format is %s", format) + } + } else { + if tfMap["augmented_manifests"] == nil { + return fmt.Errorf("augmented_manifests must be set when data_format is %s", format) + } + } + + return nil + }, + ), + } +} + +func resourceDocumentClassifierCreate(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { + awsClient := meta.(*conns.AWSClient) + conn := awsClient.ComprehendConn + + var versionName *string + raw := d.GetRawConfig().GetAttr("version_name") + if raw.IsNull() { + versionName = aws.String(create.Name("", d.Get("version_name_prefix").(string))) + } else if v := raw.AsString(); v != "" { + versionName = aws.String(v) + } + + diags := documentClassifierPublishVersion(ctx, conn, d, versionName, create.ErrActionCreating, d.Timeout(schema.TimeoutCreate), awsClient) + if diags.HasError() { + return diags + } + + return append(diags, resourceDocumentClassifierRead(ctx, d, meta)...) +} + +func resourceDocumentClassifierRead(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { + conn := meta.(*conns.AWSClient).ComprehendConn + + out, err := FindDocumentClassifierByID(ctx, conn, d.Id()) + + if !d.IsNewResource() && tfresource.NotFound(err) { + log.Printf("[WARN] Comprehend Document Classifier (%s) not found, removing from state", d.Id()) + d.SetId("") + return nil + } + + if err != nil { + return diag.Errorf("reading Comprehend Document Classifier (%s): %s", d.Id(), err) + } + + d.Set("arn", out.DocumentClassifierArn) + d.Set("data_access_role_arn", out.DataAccessRoleArn) + d.Set("language_code", out.LanguageCode) + d.Set("model_kms_key_id", out.ModelKmsKeyId) + d.Set("version_name", out.VersionName) + d.Set("version_name_prefix", create.NamePrefixFromName(aws.ToString(out.VersionName))) + d.Set("volume_kms_key_id", out.VolumeKmsKeyId) + + // DescribeDocumentClassifier() doesn't return the model name + name, err := DocumentClassifierParseARN(aws.ToString(out.DocumentClassifierArn)) + if err != nil { + return diag.Errorf("reading Comprehend Document Classifier (%s): %s", d.Id(), err) + } + d.Set("name", name) + + if err := d.Set("input_data_config", flattenDocumentClassifierInputDataConfig(out.InputDataConfig)); err != nil { + return diag.Errorf("setting input_data_config: %s", err) + } + + if err := d.Set("vpc_config", flattenVPCConfig(out.VpcConfig)); err != nil { + return diag.Errorf("setting vpc_config: %s", err) + } + + tags, err := ListTags(ctx, conn, d.Id()) + if err != nil { + return diag.Errorf("listing tags for Comprehend Document Classifier (%s): %s", d.Id(), err) + } + + defaultTagsConfig := meta.(*conns.AWSClient).DefaultTagsConfig + ignoreTagsConfig := meta.(*conns.AWSClient).IgnoreTagsConfig + tags = tags.IgnoreAWS().IgnoreConfig(ignoreTagsConfig) + + if err := d.Set("tags", tags.RemoveDefaultConfig(defaultTagsConfig).Map()); err != nil { + return diag.Errorf("setting tags: %s", err) + } + + if err := d.Set("tags_all", tags.Map()); err != nil { + return diag.Errorf("setting tags_all: %s", err) + } + + return nil +} + +func resourceDocumentClassifierUpdate(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { + awsClient := meta.(*conns.AWSClient) + conn := awsClient.ComprehendConn + + var diags diag.Diagnostics + + if d.HasChangesExcept("tags", "tags_all") { + var versionName *string + if d.HasChange("version_name") { + versionName = aws.String(d.Get("version_name").(string)) + } else if v := d.Get("version_name_prefix").(string); v != "" { + versionName = aws.String(create.Name("", d.Get("version_name_prefix").(string))) + } + + diags := documentClassifierPublishVersion(ctx, conn, d, versionName, create.ErrActionUpdating, d.Timeout(schema.TimeoutUpdate), awsClient) + if diags.HasError() { + return diags + } + } else if d.HasChange("tags_all") { + // For a tags-only change. If tag changes are combined with version publishing, the tags are set + // by the CreateDocumentClassifier call + o, n := d.GetChange("tags_all") + + if err := UpdateTags(ctx, conn, d.Id(), o, n); err != nil { + return awsdiag.AppendErrorf(diags, "updating tags for Comprehend Document Classifier (%s): %s", d.Id(), err) + } + } + + return append(diags, resourceDocumentClassifierRead(ctx, d, meta)...) +} + +func resourceDocumentClassifierDelete(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { + conn := meta.(*conns.AWSClient).ComprehendConn + + log.Printf("[INFO] Stopping Comprehend Document Classifier (%s)", d.Id()) + + _, err := conn.StopTrainingDocumentClassifier(ctx, &comprehend.StopTrainingDocumentClassifierInput{ + DocumentClassifierArn: aws.String(d.Id()), + }) + if err != nil { + var nfe *types.ResourceNotFoundException + if errors.As(err, &nfe) { + return nil + } + + return diag.Errorf("stopping Comprehend Document Classifier (%s): %s", d.Id(), err) + } + + if _, err := waitDocumentClassifierStopped(ctx, conn, d.Id(), d.Timeout(schema.TimeoutDelete)); err != nil { + var nfe *types.ResourceNotFoundException + if errors.As(err, &nfe) { + return nil + } + + return diag.Errorf("waiting for Comprehend Document Classifier (%s) to be stopped: %s", d.Id(), err) + } + + name, err := DocumentClassifierParseARN(d.Id()) + if err != nil { + return diag.Errorf("deleting Comprehend Document Classifier (%s): %s", d.Id(), err) + } + + log.Printf("[INFO] Deleting Comprehend Document Classifier (%s)", name) + + versions, err := ListDocumentClassifierVersionsByName(ctx, conn, name) + if err != nil { + return diag.Errorf("deleting Comprehend Document Classifier (%s): %s", name, err) + } + + var g multierror.Group + for _, v := range versions { + v := v + g.Go(func() error { + _, err = conn.DeleteDocumentClassifier(ctx, &comprehend.DeleteDocumentClassifierInput{ + DocumentClassifierArn: v.DocumentClassifierArn, + }) + if err != nil { + var nfe *types.ResourceNotFoundException + if !errors.As(err, &nfe) { + return fmt.Errorf("deleting version (%s): %w", aws.ToString(v.VersionName), err) + } + } + + if _, err := waitDocumentClassifierDeleted(ctx, conn, aws.ToString(v.DocumentClassifierArn), d.Timeout(schema.TimeoutDelete)); err != nil { + return fmt.Errorf("waiting for version (%s) to be deleted: %s", aws.ToString(v.VersionName), err) + } + + ec2Conn := meta.(*conns.AWSClient).EC2Conn + networkInterfaces, err := tfec2.FindNetworkInterfacesWithContext(ctx, ec2Conn, &ec2.DescribeNetworkInterfacesInput{ + Filters: []*ec2.Filter{ + tfec2.NewFilter(fmt.Sprintf("tag:%s", documentClassifierTagKey), []string{aws.ToString(v.DocumentClassifierArn)}), + }, + }) + if err != nil { + return fmt.Errorf("finding ENIs for version (%s): %w", aws.ToString(v.VersionName), err) + } + + for _, v := range networkInterfaces { + v := v + g.Go(func() error { + networkInterfaceID := aws.ToString(v.NetworkInterfaceId) + + if v.Attachment != nil { + err = tfec2.DetachNetworkInterfaceWithContext(ctx, ec2Conn, networkInterfaceID, aws.ToString(v.Attachment.AttachmentId), d.Timeout(schema.TimeoutDelete)) + + if err != nil { + return fmt.Errorf("detaching ENI (%s): %w", networkInterfaceID, err) + } + } + + err = tfec2.DeleteNetworkInterfaceWithContext(ctx, ec2Conn, networkInterfaceID) + if err != nil { + return fmt.Errorf("deleting ENI (%s): %w", networkInterfaceID, err) + } + + return nil + }) + } + + return nil + }) + } + + if err := g.Wait(); err != nil { + return diag.Errorf("deleting Comprehend Document Classifier (%s): %s", name, err) + } + + return nil +} + +func fullTypeName(i interface{}) string { + return fullValueTypeName(reflect.ValueOf(i)) +} + +func fullValueTypeName(v reflect.Value) string { + if v.Kind() == reflect.Ptr { + return "*" + fullValueTypeName(reflect.Indirect(v)) + } + + requestType := v.Type() + return fmt.Sprintf("%s.%s", requestType.PkgPath(), requestType.Name()) +} + +func documentClassifierPublishVersion(ctx context.Context, conn *comprehend.Client, d *schema.ResourceData, versionName *string, action string, timeout time.Duration, awsClient *conns.AWSClient) diag.Diagnostics { + in := &comprehend.CreateDocumentClassifierInput{ + DataAccessRoleArn: aws.String(d.Get("data_access_role_arn").(string)), + InputDataConfig: expandDocumentClassifierInputDataConfig(getDocumentClassifierInputDataConfig(d)), + LanguageCode: types.LanguageCode(d.Get("language_code").(string)), + DocumentClassifierName: aws.String(d.Get("name").(string)), + VersionName: versionName, + VpcConfig: expandVPCConfig(d.Get("vpc_config").([]interface{})), + ClientRequestToken: aws.String(resource.UniqueId()), + } + + if v, ok := d.Get("model_kms_key_id").(string); ok && v != "" { + in.ModelKmsKeyId = aws.String(v) + } + + if v, ok := d.Get("volume_kms_key_id").(string); ok && v != "" { + in.VolumeKmsKeyId = aws.String(v) + } + + defaultTagsConfig := awsClient.DefaultTagsConfig + tags := defaultTagsConfig.MergeTags(tftags.New(d.Get("tags").(map[string]interface{}))) + + if len(tags) > 0 { + in.Tags = Tags(tags.IgnoreAWS()) + } + + // Because the IAM credentials aren't evaluated until training time, we need to ensure we wait for the IAM propagation delay + time.Sleep(iamPropagationTimeout) + + if in.VpcConfig != nil { + modelVPCENILock.Lock() + defer modelVPCENILock.Unlock() + } + + var out *comprehend.CreateDocumentClassifierOutput + err := tfresource.RetryContext(ctx, timeout, func() *resource.RetryError { + var err error + out, err = conn.CreateDocumentClassifier(ctx, in) + + if err != nil { + var tmre *types.TooManyRequestsException + var qee ratelimit.QuotaExceededError // This is not a typo: the ratelimit.QuotaExceededError is returned as a struct, not a pointer + if errors.As(err, &tmre) { + return resource.RetryableError(err) + } else if errors.As(err, &qee) { + // Unable to get a rate limit token + return resource.RetryableError(err) + } else { + return resource.NonRetryableError(err) + } + } + + return nil + }, tfresource.WithPollInterval(documentClassifierPollInterval)) + if tfresource.TimedOut(err) { + out, err = conn.CreateDocumentClassifier(ctx, in) + } + if err != nil { + return diag.Errorf("%s Amazon Comprehend Document Classifier (%s): %s", action, d.Get("name").(string), err) + } + + if out == nil || out.DocumentClassifierArn == nil { + return diag.Errorf("%s Amazon Comprehend Document Classifier (%s): empty output", action, d.Get("name").(string)) + } + + d.SetId(aws.ToString(out.DocumentClassifierArn)) + + var g multierror.Group + waitCtx, cancel := context.WithCancel(ctx) + + g.Go(func() error { + _, err := waitDocumentClassifierCreated(waitCtx, conn, d.Id(), timeout) + cancel() + return err + }) + + var diags diag.Diagnostics + var tobe string + if action == create.ErrActionCreating { + tobe = "to be created" + } else if action == create.ErrActionUpdating { + tobe = "to be updated" + } else { + tobe = "to complete action" + } + + if in.VpcConfig != nil { + g.Go(func() error { + ec2Conn := awsClient.EC2Conn + enis, err := findNetworkInterfaces(waitCtx, ec2Conn, in.VpcConfig.SecurityGroupIds, in.VpcConfig.Subnets) + if err != nil { + diags = awsdiag.AppendWarningf(diags, "waiting for Amazon Comprehend Document Classifier (%s) %s: %s", d.Id(), tobe, err) + return nil + } + initialENIIds := make(map[string]bool, len(enis)) + for _, v := range enis { + initialENIIds[aws.ToString(v.NetworkInterfaceId)] = true + } + + newENI, err := waitNetworkInterfaceCreated(waitCtx, ec2Conn, initialENIIds, in.VpcConfig.SecurityGroupIds, in.VpcConfig.Subnets, d.Timeout(schema.TimeoutCreate)) + if errors.Is(err, context.Canceled) { + diags = awsdiag.AppendWarningf(diags, "waiting for Amazon Comprehend Document Classifier (%s) %s: %s", d.Id(), tobe, "ENI not found") + return nil + } + if err != nil { + diags = awsdiag.AppendWarningf(diags, "waiting for Amazon Comprehend Document Classifier (%s) %s: %s", d.Id(), tobe, err) + return nil + } + + modelVPCENILock.Unlock() + + _, err = ec2Conn.CreateTagsWithContext(waitCtx, &ec2.CreateTagsInput{ + Resources: []*string{newENI.NetworkInterfaceId}, + Tags: []*ec2.Tag{ + { + Key: aws.String(documentClassifierTagKey), + Value: aws.String(d.Id()), + }, + }, + }) + if err != nil { + diags = awsdiag.AppendWarningf(diags, "waiting for Amazon Comprehend Document Classifier (%s) %s: %s", d.Id(), tobe, err) + return nil + } + + return nil + }) + } + + err = g.Wait().ErrorOrNil() + if err != nil { + diags = awsdiag.AppendErrorf(diags, "waiting for Amazon Comprehend Document Classifier (%s) %s: %s", d.Id(), tobe, err) + } + + return diags +} + +func FindDocumentClassifierByID(ctx context.Context, conn *comprehend.Client, id string) (*types.DocumentClassifierProperties, error) { + in := &comprehend.DescribeDocumentClassifierInput{ + DocumentClassifierArn: aws.String(id), + } + + out, err := conn.DescribeDocumentClassifier(ctx, in) + if err != nil { + var nfe *types.ResourceNotFoundException + if errors.As(err, &nfe) { + return nil, &resource.NotFoundError{ + LastError: err, + LastRequest: in, + } + } + + return nil, err + } + + if out == nil || out.DocumentClassifierProperties == nil { + return nil, tfresource.NewEmptyResultError(in) + } + + return out.DocumentClassifierProperties, nil +} + +func ListDocumentClassifierVersionsByName(ctx context.Context, conn *comprehend.Client, name string) ([]types.DocumentClassifierProperties, error) { + results := []types.DocumentClassifierProperties{} + + input := &comprehend.ListDocumentClassifiersInput{ + Filter: &types.DocumentClassifierFilter{ + DocumentClassifierName: aws.String(name), + }, + } + paginator := comprehend.NewListDocumentClassifiersPaginator(conn, input) + for paginator.HasMorePages() { + output, err := paginator.NextPage(ctx) + if err != nil { + return []types.DocumentClassifierProperties{}, err + } + results = append(results, output.DocumentClassifierPropertiesList...) + } + + return results, nil +} + +func waitDocumentClassifierCreated(ctx context.Context, conn *comprehend.Client, id string, timeout time.Duration) (*types.DocumentClassifierProperties, error) { + stateConf := &resource.StateChangeConf{ + Pending: enum.Slice(types.ModelStatusSubmitted, types.ModelStatusTraining), + Target: enum.Slice(types.ModelStatusTrained), + Refresh: statusDocumentClassifier(ctx, conn, id), + Delay: documentClassifierCreatedDelay, + PollInterval: documentClassifierPollInterval, + Timeout: timeout, + } + + outputRaw, err := stateConf.WaitForStateContext(ctx) + if out, ok := outputRaw.(*types.DocumentClassifierProperties); ok { + var ues *resource.UnexpectedStateError + if errors.As(err, &ues) { + if ues.State == string(types.ModelStatusInError) { + err = errors.New(aws.ToString(out.Message)) + } + } + return out, err + } + + return nil, err +} + +func waitDocumentClassifierStopped(ctx context.Context, conn *comprehend.Client, id string, timeout time.Duration) (*types.DocumentClassifierProperties, error) { + stateConf := &resource.StateChangeConf{ + Pending: enum.Slice(types.ModelStatusSubmitted, types.ModelStatusTraining, types.ModelStatusStopRequested), + Target: enum.Slice(types.ModelStatusTrained, types.ModelStatusStopped, types.ModelStatusInError, types.ModelStatusDeleting), + Refresh: statusDocumentClassifier(ctx, conn, id), + Delay: documentClassifierStoppedDelay, + PollInterval: documentClassifierPollInterval, + Timeout: timeout, + } + + outputRaw, err := stateConf.WaitForStateContext(ctx) + if out, ok := outputRaw.(*types.DocumentClassifierProperties); ok { + return out, err + } + + return nil, err +} + +func waitDocumentClassifierDeleted(ctx context.Context, conn *comprehend.Client, id string, timeout time.Duration) (*types.DocumentClassifierProperties, error) { + stateConf := &resource.StateChangeConf{ + Pending: enum.Slice(types.ModelStatusSubmitted, types.ModelStatusTraining, types.ModelStatusDeleting, types.ModelStatusInError, types.ModelStatusStopRequested), + Target: []string{}, + Refresh: statusDocumentClassifier(ctx, conn, id), + Delay: documentClassifierDeletedDelay, + PollInterval: documentClassifierPollInterval, + NotFoundChecks: 3, + Timeout: timeout, + } + + outputRaw, err := stateConf.WaitForStateContext(ctx) + if out, ok := outputRaw.(*types.DocumentClassifierProperties); ok { + return out, err + } + + return nil, err +} + +func statusDocumentClassifier(ctx context.Context, conn *comprehend.Client, id string) resource.StateRefreshFunc { + return func() (interface{}, string, error) { + out, err := FindDocumentClassifierByID(ctx, conn, id) + if tfresource.NotFound(err) { + return nil, "", nil + } + + if err != nil { + return nil, "", err + } + + return out, string(out.Status), nil + } +} + +func flattenDocumentClassifierInputDataConfig(apiObject *types.DocumentClassifierInputDataConfig) []interface{} { + if apiObject == nil { + return nil + } + + m := map[string]interface{}{ + "augmented_manifests": flattenAugmentedManifests(apiObject.AugmentedManifests), + "data_format": apiObject.DataFormat, + "s3_uri": aws.ToString(apiObject.S3Uri), + } + + if apiObject.TestS3Uri != nil { + m["test_s3_uri"] = aws.ToString(apiObject.TestS3Uri) + } + + return []interface{}{m} +} + +func getDocumentClassifierInputDataConfig(diff resourceGetter) map[string]any { + v := diff.Get("input_data_config").([]any) + if len(v) == 0 { + return nil + } + + return v[0].(map[string]any) +} + +func expandDocumentClassifierInputDataConfig(tfMap map[string]any) *types.DocumentClassifierInputDataConfig { + if len(tfMap) == 0 { + return nil + } + + a := &types.DocumentClassifierInputDataConfig{ + AugmentedManifests: expandAugmentedManifests(tfMap["augmented_manifests"].(*schema.Set)), + DataFormat: types.DocumentClassifierDataFormat(tfMap["data_format"].(string)), + // LabelDelimiter: aws.String(tfMap["label_delimiter"].(string)), + S3Uri: aws.String(tfMap["s3_uri"].(string)), + // TestS3Uri: aws.String(tfMap["test_s3_uri"].(string)), + } + + return a +} + +func DocumentClassifierParseARN(arnString string) (string, error) { + arn, err := arn.Parse(arnString) + if err != nil { + return "", err + } + re := regexp.MustCompile(`^document-classifier/([[:alnum:]-]+)`) + matches := re.FindStringSubmatch(arn.Resource) + if len(matches) != 2 { + return "", fmt.Errorf("unable to parse %q", arnString) + } + name := matches[1] + + return name, nil +} diff --git a/internal/service/comprehend/document_classifier_test.go b/internal/service/comprehend/document_classifier_test.go new file mode 100644 index 00000000000..1a415634e4b --- /dev/null +++ b/internal/service/comprehend/document_classifier_test.go @@ -0,0 +1,1879 @@ +package comprehend_test + +import ( + "context" + "fmt" + "regexp" + "testing" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/comprehend" + "github.com/aws/aws-sdk-go-v2/service/comprehend/types" + sdkacctest "github.com/hashicorp/terraform-plugin-sdk/v2/helper/acctest" + "github.com/hashicorp/terraform-plugin-sdk/v2/helper/resource" + "github.com/hashicorp/terraform-plugin-sdk/v2/terraform" + "github.com/hashicorp/terraform-provider-aws/internal/acctest" + "github.com/hashicorp/terraform-provider-aws/internal/conns" + tfcomprehend "github.com/hashicorp/terraform-provider-aws/internal/service/comprehend" + "github.com/hashicorp/terraform-provider-aws/names" +) + +func TestAccComprehendDocumentClassifier_basic(t *testing.T) { + if testing.Short() { + t.Skip("skipping long-running test in short mode") + } + + var documentclassifier types.DocumentClassifierProperties + rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix) + resourceName := "aws_comprehend_document_classifier.test" + + resource.ParallelTest(t, resource.TestCase{ + PreCheck: func() { + acctest.PreCheck(t) + acctest.PreCheckPartitionHasService(names.ComprehendEndpointID, t) + testAccPreCheck(t) + }, + ErrorCheck: acctest.ErrorCheck(t, names.ComprehendEndpointID), + ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories, + CheckDestroy: testAccCheckDocumentClassifierDestroy, + Steps: []resource.TestStep{ + { + Config: testAccDocumentClassifierConfig_basic(rName), + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckDocumentClassifierExists(resourceName, &documentclassifier), + testAccCheckDocumentClassifierPublishedVersions(resourceName, 1), + resource.TestCheckResourceAttr(resourceName, "name", rName), + resource.TestCheckResourceAttrPair(resourceName, "data_access_role_arn", "aws_iam_role.test", "arn"), + acctest.MatchResourceAttrRegionalARN(resourceName, "arn", "comprehend", regexp.MustCompile(fmt.Sprintf(`document-classifier/%s/version/%s$`, rName, uniqueIDPattern()))), + resource.TestCheckResourceAttr(resourceName, "input_data_config.#", "1"), + resource.TestCheckResourceAttr(resourceName, "input_data_config.0.augmented_manifests.#", "0"), + resource.TestCheckResourceAttr(resourceName, "input_data_config.0.data_format", string(types.DocumentClassifierDataFormatComprehendCsv)), + resource.TestCheckResourceAttrSet(resourceName, "input_data_config.0.s3_uri"), + resource.TestCheckNoResourceAttr(resourceName, "input_data_config.0.documents.0.test_s3_uri"), + resource.TestCheckResourceAttr(resourceName, "language_code", "en"), + resource.TestCheckResourceAttr(resourceName, "model_kms_key_id", ""), + resource.TestCheckResourceAttr(resourceName, "tags.%", "0"), + resource.TestCheckResourceAttr(resourceName, "tags_all.%", "0"), + acctest.CheckResourceAttrNameGenerated(resourceName, "version_name"), + resource.TestCheckResourceAttr(resourceName, "version_name_prefix", resource.UniqueIdPrefix), + resource.TestCheckResourceAttr(resourceName, "volume_kms_key_id", ""), + resource.TestCheckResourceAttr(resourceName, "vpc_config.#", "0"), + ), + }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + }, + }, + }) +} + +func TestAccComprehendDocumentClassifier_disappears(t *testing.T) { + if testing.Short() { + t.Skip("skipping long-running test in short mode") + } + + var documentclassifier types.DocumentClassifierProperties + rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix) + resourceName := "aws_comprehend_document_classifier.test" + + resource.ParallelTest(t, resource.TestCase{ + PreCheck: func() { + acctest.PreCheck(t) + acctest.PreCheckPartitionHasService(names.ComprehendEndpointID, t) + testAccPreCheck(t) + }, + ErrorCheck: acctest.ErrorCheck(t, names.ComprehendEndpointID), + ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories, + CheckDestroy: testAccCheckDocumentClassifierDestroy, + Steps: []resource.TestStep{ + { + Config: testAccDocumentClassifierConfig_basic(rName), + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckDocumentClassifierExists(resourceName, &documentclassifier), + acctest.CheckResourceDisappears(acctest.Provider, tfcomprehend.ResourceDocumentClassifier(), resourceName), + ), + ExpectNonEmptyPlan: true, + }, + }, + }) +} + +func TestAccComprehendDocumentClassifier_versionName(t *testing.T) { + if testing.Short() { + t.Skip("skipping long-running test in short mode") + } + + var documentclassifier types.DocumentClassifierProperties + rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix) + vName1 := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix) + vName2 := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix) + resourceName := "aws_comprehend_document_classifier.test" + + resource.ParallelTest(t, resource.TestCase{ + PreCheck: func() { + acctest.PreCheck(t) + acctest.PreCheckPartitionHasService(names.ComprehendEndpointID, t) + testAccPreCheck(t) + }, + ErrorCheck: acctest.ErrorCheck(t, names.ComprehendEndpointID), + ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories, + CheckDestroy: testAccCheckDocumentClassifierDestroy, + Steps: []resource.TestStep{ + { + Config: testAccDocumentClassifierConfig_versionName(rName, vName1, "key", "value1"), + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckDocumentClassifierExists(resourceName, &documentclassifier), + testAccCheckDocumentClassifierPublishedVersions(resourceName, 1), + resource.TestCheckResourceAttr(resourceName, "name", rName), + resource.TestCheckResourceAttr(resourceName, "version_name", vName1), + resource.TestCheckResourceAttr(resourceName, "version_name_prefix", ""), + acctest.MatchResourceAttrRegionalARN(resourceName, "arn", "comprehend", regexp.MustCompile(fmt.Sprintf(`document-classifier/%s/version/%s$`, rName, vName1))), + resource.TestCheckResourceAttr(resourceName, "tags.%", "1"), + resource.TestCheckResourceAttr(resourceName, "tags.key", "value1"), + ), + }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + }, + { + Config: testAccDocumentClassifierConfig_versionName(rName, vName2, "key", "value2"), + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckDocumentClassifierExists(resourceName, &documentclassifier), + testAccCheckDocumentClassifierPublishedVersions(resourceName, 2), + resource.TestCheckResourceAttr(resourceName, "name", rName), + resource.TestCheckResourceAttr(resourceName, "version_name", vName2), + resource.TestCheckResourceAttr(resourceName, "version_name_prefix", ""), + acctest.MatchResourceAttrRegionalARN(resourceName, "arn", "comprehend", regexp.MustCompile(fmt.Sprintf(`document-classifier/%s/version/%s$`, rName, vName2))), + resource.TestCheckResourceAttr(resourceName, "tags.%", "1"), + resource.TestCheckResourceAttr(resourceName, "tags.key", "value2"), + ), + }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + }, + }, + }) +} + +func TestAccComprehendDocumentClassifier_versionNameEmpty(t *testing.T) { + if testing.Short() { + t.Skip("skipping long-running test in short mode") + } + + var documentclassifier types.DocumentClassifierProperties + rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix) + resourceName := "aws_comprehend_document_classifier.test" + + resource.ParallelTest(t, resource.TestCase{ + PreCheck: func() { + acctest.PreCheck(t) + acctest.PreCheckPartitionHasService(names.ComprehendEndpointID, t) + testAccPreCheck(t) + }, + ErrorCheck: acctest.ErrorCheck(t, names.ComprehendEndpointID), + ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories, + CheckDestroy: testAccCheckDocumentClassifierDestroy, + Steps: []resource.TestStep{ + { + Config: testAccDocumentClassifierConfig_versionNameEmpty(rName), + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckDocumentClassifierExists(resourceName, &documentclassifier), + testAccCheckDocumentClassifierPublishedVersions(resourceName, 1), + resource.TestCheckResourceAttr(resourceName, "name", rName), + resource.TestCheckResourceAttr(resourceName, "version_name", ""), + resource.TestCheckResourceAttr(resourceName, "version_name_prefix", ""), + acctest.MatchResourceAttrRegionalARN(resourceName, "arn", "comprehend", regexp.MustCompile(fmt.Sprintf(`document-classifier/%s$`, rName))), + ), + }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + }, + }, + }) +} + +func TestAccComprehendDocumentClassifier_versionNameGenerated(t *testing.T) { + if testing.Short() { + t.Skip("skipping long-running test in short mode") + } + + var documentclassifier types.DocumentClassifierProperties + rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix) + resourceName := "aws_comprehend_document_classifier.test" + + resource.ParallelTest(t, resource.TestCase{ + PreCheck: func() { + acctest.PreCheck(t) + acctest.PreCheckPartitionHasService(names.ComprehendEndpointID, t) + testAccPreCheck(t) + }, + ErrorCheck: acctest.ErrorCheck(t, names.ComprehendEndpointID), + ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories, + CheckDestroy: testAccCheckDocumentClassifierDestroy, + Steps: []resource.TestStep{ + { + Config: testAccDocumentClassifierConfig_versionNameNotSet(rName), + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckDocumentClassifierExists(resourceName, &documentclassifier), + testAccCheckDocumentClassifierPublishedVersions(resourceName, 1), + acctest.CheckResourceAttrNameGenerated(resourceName, "version_name"), + resource.TestCheckResourceAttr(resourceName, "version_name_prefix", resource.UniqueIdPrefix), + acctest.MatchResourceAttrRegionalARN(resourceName, "arn", "comprehend", regexp.MustCompile(fmt.Sprintf(`document-classifier/%s/version/%s$`, rName, uniqueIDPattern()))), + ), + }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + }, + }, + }) +} + +func TestAccComprehendDocumentClassifier_versionNamePrefix(t *testing.T) { + if testing.Short() { + t.Skip("skipping long-running test in short mode") + } + + var documentclassifier types.DocumentClassifierProperties + rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix) + resourceName := "aws_comprehend_document_classifier.test" + + resource.ParallelTest(t, resource.TestCase{ + PreCheck: func() { + acctest.PreCheck(t) + acctest.PreCheckPartitionHasService(names.ComprehendEndpointID, t) + testAccPreCheck(t) + }, + ErrorCheck: acctest.ErrorCheck(t, names.ComprehendEndpointID), + ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories, + CheckDestroy: testAccCheckDocumentClassifierDestroy, + Steps: []resource.TestStep{ + { + Config: testAccDocumentClassifierConfig_versioNamePrefix(rName, "tf-acc-test-prefix-"), + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckDocumentClassifierExists(resourceName, &documentclassifier), + testAccCheckDocumentClassifierPublishedVersions(resourceName, 1), + acctest.CheckResourceAttrNameFromPrefix(resourceName, "version_name", "tf-acc-test-prefix-"), + resource.TestCheckResourceAttr(resourceName, "version_name_prefix", "tf-acc-test-prefix-"), + acctest.MatchResourceAttrRegionalARN(resourceName, "arn", "comprehend", regexp.MustCompile(fmt.Sprintf(`document-classifier/%s/version/%s$`, rName, prefixedUniqueIDPattern("tf-acc-test-prefix-")))), + ), + }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + }, + }, + }) +} + +// func TestAccComprehendDocumentClassifier_documents_testDocuments(t *testing.T) { +// if testing.Short() { +// t.Skip("skipping long-running test in short mode") +// } + +// var documentclassifier types.DocumentClassifierProperties +// rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix) +// resourceName := "aws_comprehend_document_classifier.test" + +// resource.ParallelTest(t, resource.TestCase{ +// PreCheck: func() { +// acctest.PreCheck(t) +// acctest.PreCheckPartitionHasService(names.ComprehendEndpointID, t) +// testAccPreCheck(t) +// }, +// ErrorCheck: acctest.ErrorCheck(t, names.ComprehendEndpointID), +// ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories, +// CheckDestroy: testAccCheckDocumentClassifierDestroy, +// Steps: []resource.TestStep{ +// { +// Config: testAccDocumentClassifierConfig_testDocuments(rName), +// Check: resource.ComposeAggregateTestCheckFunc( +// testAccCheckDocumentClassifierExists(resourceName, &documentclassifier), +// testAccCheckDocumentClassifierPublishedVersions(resourceName, 1), +// resource.TestCheckResourceAttr(resourceName, "name", rName), +// acctest.MatchResourceAttrRegionalARN(resourceName, "arn", "comprehend", regexp.MustCompile(fmt.Sprintf(`document-classifier/%s/version/%s$`, rName, uniqueIDPattern()))), +// resource.TestCheckResourceAttr(resourceName, "input_data_config.#", "1"), +// resource.TestCheckResourceAttr(resourceName, "input_data_config.0.augmented_manifests.#", "0"), +// resource.TestCheckResourceAttr(resourceName, "input_data_config.0.data_format", string(types.DocumentClassifierDataFormatComprehendCsv)), +// resource.TestCheckResourceAttr(resourceName, "language_code", "en"), +// resource.TestCheckResourceAttr(resourceName, "model_kms_key_id", ""), +// resource.TestCheckResourceAttr(resourceName, "tags.%", "0"), +// resource.TestCheckResourceAttr(resourceName, "tags_all.%", "0"), +// acctest.CheckResourceAttrNameGenerated(resourceName, "version_name"), +// resource.TestCheckResourceAttr(resourceName, "version_name_prefix", resource.UniqueIdPrefix), +// resource.TestCheckResourceAttr(resourceName, "volume_kms_key_id", ""), +// resource.TestCheckResourceAttr(resourceName, "vpc_config.#", "0"), +// ), +// }, +// { +// ResourceName: resourceName, +// ImportState: true, +// ImportStateVerify: true, +// }, +// }, +// }) +// } + +// func TestAccComprehendDocumentClassifier_annotations_basic(t *testing.T) { +// if testing.Short() { +// t.Skip("skipping long-running test in short mode") +// } + +// var documentclassifier types.DocumentClassifierProperties +// rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix) +// resourceName := "aws_comprehend_document_classifier.test" + +// resource.ParallelTest(t, resource.TestCase{ +// PreCheck: func() { +// acctest.PreCheck(t) +// acctest.PreCheckPartitionHasService(names.ComprehendEndpointID, t) +// testAccPreCheck(t) +// }, +// ErrorCheck: acctest.ErrorCheck(t, names.ComprehendEndpointID), +// ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories, +// CheckDestroy: testAccCheckDocumentClassifierDestroy, +// Steps: []resource.TestStep{ +// { +// Config: testAccDocumentClassifierConfig_annotations_basic(rName), +// Check: resource.ComposeAggregateTestCheckFunc( +// testAccCheckDocumentClassifierExists(resourceName, &documentclassifier), +// testAccCheckDocumentClassifierPublishedVersions(resourceName, 1), +// resource.TestCheckResourceAttr(resourceName, "name", rName), +// resource.TestCheckResourceAttrPair(resourceName, "data_access_role_arn", "aws_iam_role.test", "arn"), +// acctest.MatchResourceAttrRegionalARN(resourceName, "arn", "comprehend", regexp.MustCompile(fmt.Sprintf(`document-classifier/%s/version/%s$`, rName, uniqueIDPattern()))), +// resource.TestCheckResourceAttr(resourceName, "input_data_config.#", "1"), +// resource.TestCheckResourceAttr(resourceName, "input_data_config.0.augmented_manifests.#", "0"), +// resource.TestCheckResourceAttr(resourceName, "input_data_config.0.data_format", string(types.DocumentClassifierDataFormatComprehendCsv)), +// resource.TestCheckResourceAttr(resourceName, "language_code", "en"), +// resource.TestCheckResourceAttr(resourceName, "model_kms_key_id", ""), +// resource.TestCheckResourceAttr(resourceName, "tags.%", "0"), +// resource.TestCheckResourceAttr(resourceName, "tags_all.%", "0"), +// acctest.CheckResourceAttrNameGenerated(resourceName, "version_name"), +// resource.TestCheckResourceAttr(resourceName, "version_name_prefix", resource.UniqueIdPrefix), +// resource.TestCheckResourceAttr(resourceName, "volume_kms_key_id", ""), +// resource.TestCheckResourceAttr(resourceName, "vpc_config.#", "0"), +// ), +// }, +// { +// ResourceName: resourceName, +// ImportState: true, +// ImportStateVerify: true, +// }, +// }, +// }) +// } + +// func TestAccComprehendDocumentClassifier_annotations_testDocuments(t *testing.T) { +// if testing.Short() { +// t.Skip("skipping long-running test in short mode") +// } + +// var documentclassifier types.DocumentClassifierProperties +// rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix) +// resourceName := "aws_comprehend_document_classifier.test" + +// resource.ParallelTest(t, resource.TestCase{ +// PreCheck: func() { +// acctest.PreCheck(t) +// acctest.PreCheckPartitionHasService(names.ComprehendEndpointID, t) +// testAccPreCheck(t) +// }, +// ErrorCheck: acctest.ErrorCheck(t, names.ComprehendEndpointID), +// ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories, +// CheckDestroy: testAccCheckDocumentClassifierDestroy, +// Steps: []resource.TestStep{ +// { +// Config: testAccDocumentClassifierConfig_annotations_testDocuments(rName), +// Check: resource.ComposeAggregateTestCheckFunc( +// testAccCheckDocumentClassifierExists(resourceName, &documentclassifier), +// testAccCheckDocumentClassifierPublishedVersions(resourceName, 1), +// resource.TestCheckResourceAttr(resourceName, "name", rName), +// resource.TestCheckResourceAttrPair(resourceName, "data_access_role_arn", "aws_iam_role.test", "arn"), +// acctest.MatchResourceAttrRegionalARN(resourceName, "arn", "comprehend", regexp.MustCompile(fmt.Sprintf(`document-classifier/%s/version/%s$`, rName, uniqueIDPattern()))), +// resource.TestCheckResourceAttr(resourceName, "input_data_config.#", "1"), +// resource.TestCheckResourceAttr(resourceName, "input_data_config.0.augmented_manifests.#", "0"), +// resource.TestCheckResourceAttr(resourceName, "input_data_config.0.data_format", string(types.DocumentClassifierDataFormatComprehendCsv)), +// resource.TestCheckResourceAttr(resourceName, "language_code", "en"), +// resource.TestCheckResourceAttr(resourceName, "model_kms_key_id", ""), +// resource.TestCheckResourceAttr(resourceName, "tags.%", "0"), +// resource.TestCheckResourceAttr(resourceName, "tags_all.%", "0"), +// acctest.CheckResourceAttrNameGenerated(resourceName, "version_name"), +// resource.TestCheckResourceAttr(resourceName, "version_name_prefix", resource.UniqueIdPrefix), +// resource.TestCheckResourceAttr(resourceName, "volume_kms_key_id", ""), +// resource.TestCheckResourceAttr(resourceName, "vpc_config.#", "0"), +// ), +// }, +// { +// ResourceName: resourceName, +// ImportState: true, +// ImportStateVerify: true, +// }, +// }, +// }) +// } + +// func TestAccComprehendDocumentClassifier_annotations_validateNoTestDocuments(t *testing.T) { +// if testing.Short() { +// t.Skip("skipping long-running test in short mode") +// } + +// rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix) + +// resource.ParallelTest(t, resource.TestCase{ +// PreCheck: func() { +// acctest.PreCheck(t) +// acctest.PreCheckPartitionHasService(names.ComprehendEndpointID, t) +// testAccPreCheck(t) +// }, +// ErrorCheck: acctest.ErrorCheck(t, names.ComprehendEndpointID), +// ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories, +// CheckDestroy: testAccCheckDocumentClassifierDestroy, +// Steps: []resource.TestStep{ +// { +// Config: testAccDocumentClassifierConfig_annotations_noTestDocuments(rName), +// ExpectError: regexp.MustCompile("input_data_config.documents.test_s3_uri must be set when input_data_config.annotations.test_s3_uri is set"), +// }, +// }, +// }) +// } + +// func TestAccComprehendDocumentClassifier_annotations_validateNoTestAnnotations(t *testing.T) { +// if testing.Short() { +// t.Skip("skipping long-running test in short mode") +// } + +// rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix) + +// resource.ParallelTest(t, resource.TestCase{ +// PreCheck: func() { +// acctest.PreCheck(t) +// acctest.PreCheckPartitionHasService(names.ComprehendEndpointID, t) +// testAccPreCheck(t) +// }, +// ErrorCheck: acctest.ErrorCheck(t, names.ComprehendEndpointID), +// ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories, +// CheckDestroy: testAccCheckDocumentClassifierDestroy, +// Steps: []resource.TestStep{ +// { +// Config: testAccDocumentClassifierConfig_annotations_noTestAnnotations(rName), +// ExpectError: regexp.MustCompile("input_data_config.annotations.test_s3_uri must be set when input_data_config.documents.test_s3_uri is set"), +// }, +// }, +// }) +// } + +func TestAccComprehendDocumentClassifier_KMSKeys_CreateIDs(t *testing.T) { + if testing.Short() { + t.Skip("skipping long-running test in short mode") + } + + var documentclassifier types.DocumentClassifierProperties + rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix) + resourceName := "aws_comprehend_document_classifier.test" + + resource.ParallelTest(t, resource.TestCase{ + PreCheck: func() { + acctest.PreCheck(t) + acctest.PreCheckPartitionHasService(names.ComprehendEndpointID, t) + testAccPreCheck(t) + }, + ErrorCheck: acctest.ErrorCheck(t, names.ComprehendEndpointID), + ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories, + CheckDestroy: testAccCheckDocumentClassifierDestroy, + Steps: []resource.TestStep{ + { + Config: testAccDocumentClassifierConfig_kmsKeyIds(rName), + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckDocumentClassifierExists(resourceName, &documentclassifier), + testAccCheckDocumentClassifierPublishedVersions(resourceName, 1), + resource.TestCheckResourceAttrPair(resourceName, "model_kms_key_id", "aws_kms_key.model", "key_id"), + resource.TestCheckResourceAttrPair(resourceName, "volume_kms_key_id", "aws_kms_key.volume", "key_id"), + ), + }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + }, + { + Config: testAccDocumentClassifierConfig_kmsKeyARNs(rName), + PlanOnly: true, + }, + }, + }) +} + +func TestAccComprehendDocumentClassifier_KMSKeys_CreateARNs(t *testing.T) { + if testing.Short() { + t.Skip("skipping long-running test in short mode") + } + + var documentclassifier types.DocumentClassifierProperties + rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix) + resourceName := "aws_comprehend_document_classifier.test" + + resource.ParallelTest(t, resource.TestCase{ + PreCheck: func() { + acctest.PreCheck(t) + acctest.PreCheckPartitionHasService(names.ComprehendEndpointID, t) + testAccPreCheck(t) + }, + ErrorCheck: acctest.ErrorCheck(t, names.ComprehendEndpointID), + ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories, + CheckDestroy: testAccCheckDocumentClassifierDestroy, + Steps: []resource.TestStep{ + { + Config: testAccDocumentClassifierConfig_kmsKeyARNs(rName), + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckDocumentClassifierExists(resourceName, &documentclassifier), + testAccCheckDocumentClassifierPublishedVersions(resourceName, 1), + resource.TestCheckResourceAttrPair(resourceName, "model_kms_key_id", "aws_kms_key.model", "arn"), + resource.TestCheckResourceAttrPair(resourceName, "volume_kms_key_id", "aws_kms_key.volume", "arn"), + ), + }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + }, + { + Config: testAccDocumentClassifierConfig_kmsKeyIds(rName), + PlanOnly: true, + }, + }, + }) +} + +func TestAccComprehendDocumentClassifier_KMSKeys_Update(t *testing.T) { + if testing.Short() { + t.Skip("skipping long-running test in short mode") + } + + var v1, v2, v3, v4 types.DocumentClassifierProperties + rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix) + resourceName := "aws_comprehend_document_classifier.test" + + resource.ParallelTest(t, resource.TestCase{ + PreCheck: func() { + acctest.PreCheck(t) + acctest.PreCheckPartitionHasService(names.ComprehendEndpointID, t) + testAccPreCheck(t) + }, + ErrorCheck: acctest.ErrorCheck(t, names.ComprehendEndpointID), + ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories, + CheckDestroy: testAccCheckDocumentClassifierDestroy, + Steps: []resource.TestStep{ + { + Config: testAccDocumentClassifierConfig_kmsKeys_None(rName), + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckDocumentClassifierExists(resourceName, &v1), + testAccCheckDocumentClassifierPublishedVersions(resourceName, 1), + resource.TestCheckResourceAttr(resourceName, "model_kms_key_id", ""), + resource.TestCheckResourceAttr(resourceName, "volume_kms_key_id", ""), + ), + }, + { + Config: testAccDocumentClassifierConfig_kmsKeys_Set(rName), + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckDocumentClassifierExists(resourceName, &v2), + testAccCheckDocumentClassifierPublishedVersions(resourceName, 2), + resource.TestCheckResourceAttrPair(resourceName, "model_kms_key_id", "aws_kms_key.model", "key_id"), + resource.TestCheckResourceAttrPair(resourceName, "volume_kms_key_id", "aws_kms_key.volume", "key_id"), + ), + }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + }, + { + Config: testAccDocumentClassifierConfig_kmsKeys_Update(rName), + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckDocumentClassifierExists(resourceName, &v3), + testAccCheckDocumentClassifierPublishedVersions(resourceName, 3), + resource.TestCheckResourceAttrPair(resourceName, "model_kms_key_id", "aws_kms_key.model2", "key_id"), + resource.TestCheckResourceAttrPair(resourceName, "volume_kms_key_id", "aws_kms_key.volume2", "key_id"), + ), + }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + }, + { + Config: testAccDocumentClassifierConfig_kmsKeys_None(rName), + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckDocumentClassifierExists(resourceName, &v4), + testAccCheckDocumentClassifierPublishedVersions(resourceName, 4), + resource.TestCheckResourceAttr(resourceName, "model_kms_key_id", ""), + resource.TestCheckResourceAttr(resourceName, "volume_kms_key_id", ""), + ), + }, + }, + }) +} + +func TestAccComprehendDocumentClassifier_VPCConfig_Create(t *testing.T) { + if testing.Short() { + t.Skip("skipping long-running test in short mode") + } + + var dc1, dc2 types.DocumentClassifierProperties + rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix) + resourceName := "aws_comprehend_document_classifier.test" + + resource.ParallelTest(t, resource.TestCase{ + PreCheck: func() { + acctest.PreCheck(t) + acctest.PreCheckPartitionHasService(names.ComprehendEndpointID, t) + testAccPreCheck(t) + }, + ErrorCheck: acctest.ErrorCheck(t, names.ComprehendEndpointID), + ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories, + CheckDestroy: testAccCheckDocumentClassifierDestroy, + Steps: []resource.TestStep{ + { + Config: testAccDocumentClassifierConfig_vpcConfig(rName), + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckDocumentClassifierExists(resourceName, &dc1), + testAccCheckDocumentClassifierPublishedVersions(resourceName, 1), + resource.TestCheckResourceAttr(resourceName, "vpc_config.#", "1"), + resource.TestCheckResourceAttr(resourceName, "vpc_config.0.security_group_ids.#", "1"), + resource.TestCheckTypeSetElemAttrPair(resourceName, "vpc_config.0.security_group_ids.*", "aws_security_group.test.0", "id"), + resource.TestCheckResourceAttr(resourceName, "vpc_config.0.subnets.#", "2"), + resource.TestCheckTypeSetElemAttrPair(resourceName, "vpc_config.0.subnets.*", "aws_subnet.test.0", "id"), + resource.TestCheckTypeSetElemAttrPair(resourceName, "vpc_config.0.subnets.*", "aws_subnet.test.1", "id"), + ), + }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + }, + { + Config: testAccDocumentClassifierConfig_vpcConfig_Update(rName), + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckDocumentClassifierExists(resourceName, &dc2), + testAccCheckDocumentClassifierPublishedVersions(resourceName, 2), + resource.TestCheckResourceAttr(resourceName, "vpc_config.#", "1"), + resource.TestCheckResourceAttr(resourceName, "vpc_config.0.security_group_ids.#", "1"), + resource.TestCheckTypeSetElemAttrPair(resourceName, "vpc_config.0.security_group_ids.*", "aws_security_group.test.1", "id"), + resource.TestCheckResourceAttr(resourceName, "vpc_config.0.subnets.#", "2"), + resource.TestCheckTypeSetElemAttrPair(resourceName, "vpc_config.0.subnets.*", "aws_subnet.test.2", "id"), + resource.TestCheckTypeSetElemAttrPair(resourceName, "vpc_config.0.subnets.*", "aws_subnet.test.3", "id"), + ), + }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + }, + }, + }) +} + +func TestAccComprehendDocumentClassifier_VPCConfig_Update(t *testing.T) { + if testing.Short() { + t.Skip("skipping long-running test in short mode") + } + + var dc1, dc2, dc3 types.DocumentClassifierProperties + rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix) + resourceName := "aws_comprehend_document_classifier.test" + + resource.ParallelTest(t, resource.TestCase{ + PreCheck: func() { + acctest.PreCheck(t) + acctest.PreCheckPartitionHasService(names.ComprehendEndpointID, t) + testAccPreCheck(t) + }, + ErrorCheck: acctest.ErrorCheck(t, names.ComprehendEndpointID), + ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories, + CheckDestroy: testAccCheckDocumentClassifierDestroy, + Steps: []resource.TestStep{ + { + Config: testAccDocumentClassifierConfig_vpcConfig_None(rName), + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckDocumentClassifierExists(resourceName, &dc1), + testAccCheckDocumentClassifierPublishedVersions(resourceName, 1), + resource.TestCheckResourceAttr(resourceName, "vpc_config.#", "0"), + ), + }, + { + Config: testAccDocumentClassifierConfig_vpcConfig(rName), + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckDocumentClassifierExists(resourceName, &dc2), + testAccCheckDocumentClassifierPublishedVersions(resourceName, 2), + resource.TestCheckResourceAttr(resourceName, "vpc_config.#", "1"), + resource.TestCheckResourceAttr(resourceName, "vpc_config.0.security_group_ids.#", "1"), + resource.TestCheckTypeSetElemAttrPair(resourceName, "vpc_config.0.security_group_ids.*", "aws_security_group.test.0", "id"), + resource.TestCheckResourceAttr(resourceName, "vpc_config.0.subnets.#", "2"), + resource.TestCheckTypeSetElemAttrPair(resourceName, "vpc_config.0.subnets.*", "aws_subnet.test.0", "id"), + resource.TestCheckTypeSetElemAttrPair(resourceName, "vpc_config.0.subnets.*", "aws_subnet.test.1", "id"), + ), + }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + }, + { + Config: testAccDocumentClassifierConfig_vpcConfig_None(rName), + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckDocumentClassifierExists(resourceName, &dc3), + testAccCheckDocumentClassifierPublishedVersions(resourceName, 3), + resource.TestCheckResourceAttr(resourceName, "vpc_config.#", "0"), + ), + }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + }, + }, + }) +} + +func TestAccComprehendDocumentClassifier_tags(t *testing.T) { + if testing.Short() { + t.Skip("skipping long-running test in short mode") + } + + var v1, v2, v3 types.DocumentClassifierProperties + rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix) + resourceName := "aws_comprehend_document_classifier.test" + + resource.ParallelTest(t, resource.TestCase{ + PreCheck: func() { + acctest.PreCheck(t) + acctest.PreCheckPartitionHasService(names.ComprehendEndpointID, t) + testAccPreCheck(t) + }, + ErrorCheck: acctest.ErrorCheck(t, names.ComprehendEndpointID), + ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories, + CheckDestroy: testAccCheckDocumentClassifierDestroy, + Steps: []resource.TestStep{ + { + Config: testAccDocumentClassifierConfig_tags1(rName, "key1", "value1"), + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckDocumentClassifierExists(resourceName, &v1), + testAccCheckDocumentClassifierPublishedVersions(resourceName, 1), + resource.TestCheckResourceAttr(resourceName, "tags.%", "1"), + resource.TestCheckResourceAttr(resourceName, "tags.key1", "value1"), + ), + }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + }, + { + Config: testAccDocumentClassifierConfig_tags2(rName, "key1", "value1updated", "key2", "value2"), + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckDocumentClassifierExists(resourceName, &v2), + testAccCheckDocumentClassifierNotRecreated(&v1, &v2), + testAccCheckDocumentClassifierPublishedVersions(resourceName, 1), + resource.TestCheckResourceAttr(resourceName, "tags.%", "2"), + resource.TestCheckResourceAttr(resourceName, "tags.key1", "value1updated"), + resource.TestCheckResourceAttr(resourceName, "tags.key2", "value2"), + ), + }, + { + Config: testAccDocumentClassifierConfig_tags1(rName, "key2", "value2"), + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckDocumentClassifierExists(resourceName, &v3), + testAccCheckDocumentClassifierNotRecreated(&v2, &v3), + testAccCheckDocumentClassifierPublishedVersions(resourceName, 1), + resource.TestCheckResourceAttr(resourceName, "tags.%", "1"), + resource.TestCheckResourceAttr(resourceName, "tags.key2", "value2"), + ), + }, + }, + }) +} + +func TestAccComprehendDocumentClassifier_DefaultTags_providerOnly(t *testing.T) { + if testing.Short() { + t.Skip("skipping long-running test in short mode") + } + + var v1, v2, v3 types.DocumentClassifierProperties + rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix) + resourceName := "aws_comprehend_document_classifier.test" + + resource.ParallelTest(t, resource.TestCase{ + PreCheck: func() { + acctest.PreCheck(t) + acctest.PreCheckPartitionHasService(names.ComprehendEndpointID, t) + testAccPreCheck(t) + }, + ErrorCheck: acctest.ErrorCheck(t, names.ComprehendEndpointID), + ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories, + CheckDestroy: testAccCheckDocumentClassifierDestroy, + Steps: []resource.TestStep{ + { + Config: acctest.ConfigCompose( + acctest.ConfigDefaultTags_Tags1("providerkey1", "providervalue1"), + testAccDocumentClassifierConfig_tags0(rName), + ), + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckDocumentClassifierExists(resourceName, &v1), + testAccCheckDocumentClassifierPublishedVersions(resourceName, 1), + resource.TestCheckResourceAttr(resourceName, "tags.%", "0"), + resource.TestCheckResourceAttr(resourceName, "tags_all.%", "1"), + resource.TestCheckResourceAttr(resourceName, "tags_all.providerkey1", "providervalue1"), + ), + }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + }, + { + Config: acctest.ConfigCompose( + acctest.ConfigDefaultTags_Tags2("providerkey1", "providervalue1", "providerkey2", "providervalue2"), + testAccDocumentClassifierConfig_tags0(rName), + ), + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckDocumentClassifierExists(resourceName, &v2), + testAccCheckDocumentClassifierNotRecreated(&v1, &v2), + testAccCheckDocumentClassifierPublishedVersions(resourceName, 1), + resource.TestCheckResourceAttr(resourceName, "tags.%", "0"), + resource.TestCheckResourceAttr(resourceName, "tags_all.%", "2"), + resource.TestCheckResourceAttr(resourceName, "tags_all.providerkey1", "providervalue1"), + resource.TestCheckResourceAttr(resourceName, "tags_all.providerkey2", "providervalue2"), + ), + }, + { + Config: acctest.ConfigCompose( + acctest.ConfigDefaultTags_Tags1("providerkey1", "value1"), + testAccDocumentClassifierConfig_tags0(rName), + ), + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckDocumentClassifierExists(resourceName, &v3), + testAccCheckDocumentClassifierNotRecreated(&v2, &v3), + testAccCheckDocumentClassifierPublishedVersions(resourceName, 1), + resource.TestCheckResourceAttr(resourceName, "tags.%", "0"), + resource.TestCheckResourceAttr(resourceName, "tags_all.%", "1"), + resource.TestCheckResourceAttr(resourceName, "tags_all.providerkey1", "value1"), + ), + }, + }, + }) +} + +func testAccCheckDocumentClassifierDestroy(s *terraform.State) error { + conn := acctest.Provider.Meta().(*conns.AWSClient).ComprehendConn + ctx := context.Background() + + for _, rs := range s.RootModule().Resources { + if rs.Type != "aws_comprehend_document_classifier" { + continue + } + + name, err := tfcomprehend.DocumentClassifierParseARN(rs.Primary.ID) + if err != nil { + return err + } + + input := &comprehend.ListDocumentClassifiersInput{ + Filter: &types.DocumentClassifierFilter{ + DocumentClassifierName: aws.String(name), + }, + } + total := 0 + paginator := comprehend.NewListDocumentClassifiersPaginator(conn, input) + for paginator.HasMorePages() { + output, err := paginator.NextPage(ctx) + if err != nil { + return err + } + total += len(output.DocumentClassifierPropertiesList) + } + + if total != 0 { + return fmt.Errorf("Expected Comprehend Document Classifier (%s) to be destroyed, found %d versions", rs.Primary.ID, total) + } + return nil + } + + return nil +} + +func testAccCheckDocumentClassifierExists(name string, documentclassifier *types.DocumentClassifierProperties) resource.TestCheckFunc { + return func(s *terraform.State) error { + rs, ok := s.RootModule().Resources[name] + if !ok { + return fmt.Errorf("Not found: %s", name) + } + + if rs.Primary.ID == "" { + return fmt.Errorf("No Comprehend Document Classifier is set") + } + + conn := acctest.Provider.Meta().(*conns.AWSClient).ComprehendConn + ctx := context.Background() + + resp, err := tfcomprehend.FindDocumentClassifierByID(ctx, conn, rs.Primary.ID) + if err != nil { + return fmt.Errorf("Error describing Comprehend Document Classifier: %w", err) + } + + *documentclassifier = *resp + + return nil + } +} + +// func testAccCheckDocumentClassifierRecreated(before, after *types.DocumentClassifierProperties) resource.TestCheckFunc { +// return func(s *terraform.State) error { +// if documentClassifierIdentity(before, after) { +// return fmt.Errorf("Comprehend Document Classifier not recreated") +// } + +// return nil +// } +// } + +func testAccCheckDocumentClassifierNotRecreated(before, after *types.DocumentClassifierProperties) resource.TestCheckFunc { + return func(s *terraform.State) error { + if !documentClassifierIdentity(before, after) { + return fmt.Errorf("Comprehend Document Classifier recreated") + } + + return nil + } +} + +func documentClassifierIdentity(before, after *types.DocumentClassifierProperties) bool { + return aws.ToTime(before.SubmitTime).Equal(aws.ToTime(after.SubmitTime)) +} + +func testAccCheckDocumentClassifierPublishedVersions(name string, expected int) resource.TestCheckFunc { + return func(s *terraform.State) error { + rs, ok := s.RootModule().Resources[name] + if !ok { + return fmt.Errorf("Not found: %s", name) + } + + if rs.Primary.ID == "" { + return fmt.Errorf("No Comprehend Document Classifier is set") + } + + conn := acctest.Provider.Meta().(*conns.AWSClient).ComprehendConn + ctx := context.Background() + + name, err := tfcomprehend.DocumentClassifierParseARN(rs.Primary.ID) + if err != nil { + return err + } + + input := &comprehend.ListDocumentClassifiersInput{ + Filter: &types.DocumentClassifierFilter{ + DocumentClassifierName: aws.String(name), + }, + } + count := 0 + paginator := comprehend.NewListDocumentClassifiersPaginator(conn, input) + for paginator.HasMorePages() { + output, err := paginator.NextPage(ctx) + if err != nil { + return err + } + count += len(output.DocumentClassifierPropertiesList) + } + + if count != expected { + return fmt.Errorf("expected %d published versions, found %d", expected, count) + } + + return nil + } +} + +func testAccDocumentClassifierConfig_basic(rName string) string { + return acctest.ConfigCompose( + testAccDocumentClassifierBasicRoleConfig(rName), + testAccDocumentClassifierS3BucketConfig(rName), + testAccDocumentClassifierConfig_S3_documents, + fmt.Sprintf(` +data "aws_partition" "current" {} + +resource "aws_comprehend_document_classifier" "test" { + name = %[1]q + + data_access_role_arn = aws_iam_role.test.arn + + language_code = "en" + input_data_config { + s3_uri = "s3://${aws_s3_bucket.test.bucket}/${aws_s3_object.documents.id}" + } + + depends_on = [ + aws_iam_role_policy.test, + ] +} +`, rName)) +} + +func testAccDocumentClassifierConfig_versionName(rName, vName, key, value string) string { + return acctest.ConfigCompose( + testAccDocumentClassifierBasicRoleConfig(rName), + testAccDocumentClassifierS3BucketConfig(rName), + testAccDocumentClassifierConfig_S3_documents, + fmt.Sprintf(` +data "aws_partition" "current" {} + +resource "aws_comprehend_document_classifier" "test" { + name = %[1]q + version_name = %[2]q + + data_access_role_arn = aws_iam_role.test.arn + + tags = { + %[3]q = %[4]q + } + + language_code = "en" + input_data_config { + s3_uri = "s3://${aws_s3_bucket.test.bucket}/${aws_s3_object.documents.id}" + } + + depends_on = [ + aws_iam_role_policy.test, + ] +} +`, rName, vName, key, value)) +} + +func testAccDocumentClassifierConfig_versionNameEmpty(rName string) string { + return acctest.ConfigCompose( + testAccDocumentClassifierBasicRoleConfig(rName), + testAccDocumentClassifierS3BucketConfig(rName), + testAccDocumentClassifierConfig_S3_documents, + fmt.Sprintf(` +data "aws_partition" "current" {} + +resource "aws_comprehend_document_classifier" "test" { + name = %[1]q + version_name = "" + + data_access_role_arn = aws_iam_role.test.arn + + language_code = "en" + input_data_config { + s3_uri = "s3://${aws_s3_bucket.test.bucket}/${aws_s3_object.documents.id}" + } + + depends_on = [ + aws_iam_role_policy.test, + ] +} +`, rName)) +} + +func testAccDocumentClassifierConfig_versionNameNotSet(rName string) string { + return acctest.ConfigCompose( + testAccDocumentClassifierBasicRoleConfig(rName), + testAccDocumentClassifierS3BucketConfig(rName), + testAccDocumentClassifierConfig_S3_documents, + fmt.Sprintf(` +data "aws_partition" "current" {} + +resource "aws_comprehend_document_classifier" "test" { + name = %[1]q + + data_access_role_arn = aws_iam_role.test.arn + + language_code = "en" + input_data_config { + s3_uri = "s3://${aws_s3_bucket.test.bucket}/${aws_s3_object.documents.id}" + } + + depends_on = [ + aws_iam_role_policy.test, + ] +} +`, rName)) +} + +func testAccDocumentClassifierConfig_versioNamePrefix(rName, versionNamePrefix string) string { + return acctest.ConfigCompose( + testAccDocumentClassifierBasicRoleConfig(rName), + testAccDocumentClassifierS3BucketConfig(rName), + testAccDocumentClassifierConfig_S3_documents, + fmt.Sprintf(` +data "aws_partition" "current" {} + +resource "aws_comprehend_document_classifier" "test" { + name = %[1]q + version_name_prefix = %[2]q + + data_access_role_arn = aws_iam_role.test.arn + + language_code = "en" + input_data_config { + s3_uri = "s3://${aws_s3_bucket.test.bucket}/${aws_s3_object.documents.id}" + } + + depends_on = [ + aws_iam_role_policy.test, + ] +} +`, rName, versionNamePrefix)) +} + +func testAccDocumentClassifierConfig_testDocuments(rName string) string { + return acctest.ConfigCompose( + testAccDocumentClassifierBasicRoleConfig(rName), + testAccDocumentClassifierS3BucketConfig(rName), + testAccDocumentClassifierConfig_S3_documents, + fmt.Sprintf(` +data "aws_partition" "current" {} + +resource "aws_comprehend_document_classifier" "test" { + name = %[1]q + + data_access_role_arn = aws_iam_role.test.arn + + language_code = "en" + input_data_config { + s3_uri = "s3://${aws_s3_bucket.test.bucket}/${aws_s3_object.documents.id}" + } + + depends_on = [ + aws_iam_role_policy.test, + ] +} +`, rName)) +} + +func testAccDocumentClassifierConfig_kmsKeyIds(rName string) string { + return acctest.ConfigCompose( + testAccDocumentClassifierBasicRoleConfig(rName), + testAccDocumentClassifierS3BucketConfig(rName), + testAccDocumentClassifierConfig_S3_documents, + fmt.Sprintf(` +data "aws_partition" "current" {} + +resource "aws_comprehend_document_classifier" "test" { + name = %[1]q + + data_access_role_arn = aws_iam_role.test.arn + + model_kms_key_id = aws_kms_key.model.key_id + volume_kms_key_id = aws_kms_key.volume.key_id + + language_code = "en" + input_data_config { + s3_uri = "s3://${aws_s3_bucket.test.bucket}/${aws_s3_object.documents.id}" + } + + depends_on = [ + aws_iam_role_policy.test, + ] +} + +resource "aws_iam_role_policy" "kms_keys" { + role = aws_iam_role.test.name + + policy = data.aws_iam_policy_document.kms_keys.json +} + +data "aws_iam_policy_document" "kms_keys" { + statement { + actions = [ + "*", + ] + + resources = [ + aws_kms_key.model.arn, + ] + } + statement { + actions = [ + "*", + ] + + resources = [ + aws_kms_key.volume.arn, + ] + } +} + +resource "aws_kms_key" "model" { + deletion_window_in_days = 7 +} + +resource "aws_kms_key" "volume" { + deletion_window_in_days = 7 +} +`, rName)) +} + +func testAccDocumentClassifierConfig_kmsKeyARNs(rName string) string { + return acctest.ConfigCompose( + testAccDocumentClassifierBasicRoleConfig(rName), + testAccDocumentClassifierS3BucketConfig(rName), + testAccDocumentClassifierConfig_S3_documents, + fmt.Sprintf(` +data "aws_partition" "current" {} + +resource "aws_comprehend_document_classifier" "test" { + name = %[1]q + + data_access_role_arn = aws_iam_role.test.arn + + model_kms_key_id = aws_kms_key.model.arn + volume_kms_key_id = aws_kms_key.volume.arn + + language_code = "en" + input_data_config { + s3_uri = "s3://${aws_s3_bucket.test.bucket}/${aws_s3_object.documents.id}" + } + + depends_on = [ + aws_iam_role_policy.test, + ] +} + +resource "aws_iam_role_policy" "kms_keys" { + role = aws_iam_role.test.name + + policy = data.aws_iam_policy_document.kms_keys.json +} + +data "aws_iam_policy_document" "kms_keys" { + statement { + actions = [ + "*", + ] + + resources = [ + aws_kms_key.model.arn, + ] + } + statement { + actions = [ + "*", + ] + + resources = [ + aws_kms_key.volume.arn, + ] + } +} + +resource "aws_kms_key" "model" { + deletion_window_in_days = 7 +} + +resource "aws_kms_key" "volume" { + deletion_window_in_days = 7 +} +`, rName)) +} + +func testAccDocumentClassifierConfig_kmsKeys_None(rName string) string { + return acctest.ConfigCompose( + testAccDocumentClassifierBasicRoleConfig(rName), + testAccDocumentClassifierS3BucketConfig(rName), + testAccDocumentClassifierConfig_S3_documents, + fmt.Sprintf(` +data "aws_partition" "current" {} + +resource "aws_comprehend_document_classifier" "test" { + name = %[1]q + + data_access_role_arn = aws_iam_role.test.arn + + language_code = "en" + input_data_config { + s3_uri = "s3://${aws_s3_bucket.test.bucket}/${aws_s3_object.documents.id}" + } + + depends_on = [ + aws_iam_role_policy.test, + ] +} +`, rName)) +} + +func testAccDocumentClassifierConfig_kmsKeys_Set(rName string) string { + return acctest.ConfigCompose( + testAccDocumentClassifierBasicRoleConfig(rName), + testAccDocumentClassifierS3BucketConfig(rName), + testAccDocumentClassifierConfig_S3_documents, + fmt.Sprintf(` +data "aws_partition" "current" {} + +resource "aws_comprehend_document_classifier" "test" { + name = %[1]q + + data_access_role_arn = aws_iam_role.test.arn + + model_kms_key_id = aws_kms_key.model.key_id + volume_kms_key_id = aws_kms_key.volume.key_id + + language_code = "en" + input_data_config { + s3_uri = "s3://${aws_s3_bucket.test.bucket}/${aws_s3_object.documents.id}" + } + + depends_on = [ + aws_iam_role_policy.test, + ] +} + +resource "aws_iam_role_policy" "kms_keys" { + role = aws_iam_role.test.name + + policy = data.aws_iam_policy_document.kms_keys.json +} + +data "aws_iam_policy_document" "kms_keys" { + statement { + actions = [ + "*", + ] + + resources = [ + aws_kms_key.model.arn, + ] + } + statement { + actions = [ + "*", + ] + + resources = [ + aws_kms_key.volume.arn, + ] + } +} + +resource "aws_kms_key" "model" { + deletion_window_in_days = 7 +} + +resource "aws_kms_key" "volume" { + deletion_window_in_days = 7 +} +`, rName)) +} + +func testAccDocumentClassifierConfig_kmsKeys_Update(rName string) string { + return acctest.ConfigCompose( + testAccDocumentClassifierBasicRoleConfig(rName), + testAccDocumentClassifierS3BucketConfig(rName), + testAccDocumentClassifierConfig_S3_documents, + fmt.Sprintf(` +data "aws_partition" "current" {} + +resource "aws_comprehend_document_classifier" "test" { + name = %[1]q + + data_access_role_arn = aws_iam_role.test.arn + + model_kms_key_id = aws_kms_key.model2.key_id + volume_kms_key_id = aws_kms_key.volume2.key_id + + language_code = "en" + input_data_config { + s3_uri = "s3://${aws_s3_bucket.test.bucket}/${aws_s3_object.documents.id}" + } + + depends_on = [ + aws_iam_role_policy.test, + ] +} + +resource "aws_iam_role_policy" "kms_keys" { + role = aws_iam_role.test.name + + policy = data.aws_iam_policy_document.kms_keys.json +} + +data "aws_iam_policy_document" "kms_keys" { + statement { + actions = [ + "*", + ] + + resources = [ + aws_kms_key.model2.arn, + ] + } + statement { + actions = [ + "*", + ] + + resources = [ + aws_kms_key.volume2.arn, + ] + } +} + +resource "aws_kms_key" "model2" { + deletion_window_in_days = 7 +} + +resource "aws_kms_key" "volume2" { + deletion_window_in_days = 7 +} +`, rName)) +} + +func testAccDocumentClassifierConfig_tags0(rName string) string { + return acctest.ConfigCompose( + testAccDocumentClassifierBasicRoleConfig(rName), + testAccDocumentClassifierS3BucketConfig(rName), + testAccDocumentClassifierConfig_S3_documents, + fmt.Sprintf(` +data "aws_partition" "current" {} + +resource "aws_comprehend_document_classifier" "test" { + name = %[1]q + + data_access_role_arn = aws_iam_role.test.arn + + tags = {} + + language_code = "en" + input_data_config { + s3_uri = "s3://${aws_s3_bucket.test.bucket}/${aws_s3_object.documents.id}" + } + + depends_on = [ + aws_iam_role_policy.test, + ] +} +`, rName)) +} + +func testAccDocumentClassifierConfig_tags1(rName, tagKey1, tagValue1 string) string { + return acctest.ConfigCompose( + testAccDocumentClassifierBasicRoleConfig(rName), + testAccDocumentClassifierS3BucketConfig(rName), + testAccDocumentClassifierConfig_S3_documents, + fmt.Sprintf(` +data "aws_partition" "current" {} + +resource "aws_comprehend_document_classifier" "test" { + name = %[1]q + + data_access_role_arn = aws_iam_role.test.arn + + tags = { + %[2]q = %[3]q + } + + language_code = "en" + input_data_config { + s3_uri = "s3://${aws_s3_bucket.test.bucket}/${aws_s3_object.documents.id}" + } + + depends_on = [ + aws_iam_role_policy.test, + ] +} +`, rName, tagKey1, tagValue1)) +} + +func testAccDocumentClassifierConfig_tags2(rName, tagKey1, tagValue1, tagKey2, tagValue2 string) string { + return acctest.ConfigCompose( + testAccDocumentClassifierBasicRoleConfig(rName), + testAccDocumentClassifierS3BucketConfig(rName), + testAccDocumentClassifierConfig_S3_documents, + fmt.Sprintf(` +data "aws_partition" "current" {} + +resource "aws_comprehend_document_classifier" "test" { + name = %[1]q + + data_access_role_arn = aws_iam_role.test.arn + + tags = { + %[2]q = %[3]q + %[4]q = %[5]q + } + + language_code = "en" + input_data_config { + s3_uri = "s3://${aws_s3_bucket.test.bucket}/${aws_s3_object.documents.id}" + } + + depends_on = [ + aws_iam_role_policy.test, + ] +} +`, rName, tagKey1, tagValue1, tagKey2, tagValue2)) +} + +func testAccDocumentClassifierS3BucketConfig(rName string) string { + return fmt.Sprintf(` +resource "aws_s3_bucket" "test" { + bucket = %[1]q +} + +resource "aws_s3_bucket_public_access_block" "test" { + bucket = aws_s3_bucket.test.bucket + + block_public_acls = true + block_public_policy = true + ignore_public_acls = true + restrict_public_buckets = true +} + +resource "aws_s3_bucket_ownership_controls" "test" { + bucket = aws_s3_bucket.test.bucket + + rule { + object_ownership = "BucketOwnerEnforced" + } +} +`, rName) +} + +func testAccDocumentClassifierBasicRoleConfig(rName string) string { + return fmt.Sprintf(` +resource "aws_iam_role" "test" { + name = %[1]q + + assume_role_policy = < 0 { + o.Delay = delay + } + + if delayRand > 0 { + // Hitting the API at exactly the same time on each iteration of the retry is more likely to + // cause Throttling problems. We introduce randomness in order to help AWS be happier. + o.Delay = time.Duration(rand.Int63n(delayRand.Milliseconds())) * time.Millisecond + } + + if minPollInterval > 0 { + o.MinPollInterval = minPollInterval + } + + if pollInterval > 0 { + o.PollInterval = pollInterval + } + }) +} + +type Options struct { + Delay time.Duration // Wait this time before starting checks + MinPollInterval time.Duration // Smallest time to wait before refreshes (MinTimeout in resource.StateChangeConf) + PollInterval time.Duration // Override MinPollInterval/backoff and only poll this often +} + +func (o Options) Apply(c *resource.StateChangeConf) { + if o.Delay > 0 { + c.Delay = o.Delay + } + + if o.MinPollInterval > 0 { + c.MinTimeout = o.MinPollInterval + } + + if o.PollInterval > 0 { + c.PollInterval = o.PollInterval + } +} + +type OptionsFunc func(*Options) + +func WithDelay(delay time.Duration) OptionsFunc { + return func(o *Options) { + o.Delay = delay + } +} + +func WithMinPollInterval(minPollInterval time.Duration) OptionsFunc { + return func(o *Options) { + o.MinPollInterval = minPollInterval + } +} + +func WithPollInterval(pollInterval time.Duration) OptionsFunc { + return func(o *Options) { + o.PollInterval = pollInterval + } +} + +// RetryContext allows configuration of StateChangeConf's various time arguments. +// This is especially useful for AWS services that are prone to throttling, such as Route53, where +// the default durations cause problems. +func RetryContext(ctx context.Context, timeout time.Duration, f resource.RetryFunc, optFns ...OptionsFunc) error { // These are used to pull the error out of the function; need a mutex to // avoid a data race. var resultErr error var resultErrMu sync.Mutex + options := Options{} + for _, fn := range optFns { + fn(&options) + } + c := &resource.StateChangeConf{ Pending: []string{"retryableerror"}, Target: []string{"success"}, @@ -178,25 +248,7 @@ func RetryConfigContext(ctx context.Context, delay time.Duration, delayRand time }, } - if delay.Milliseconds() > 0 { - c.Delay = delay - } - - if delayRand.Milliseconds() > 0 { - // Hitting the API at exactly the same time on each iteration of the retry is more likely to - // cause Throttling problems. We introduce randomness in order to help AWS be happier. - rand.Seed(time.Now().UTC().UnixNano()) - - c.Delay = time.Duration(rand.Int63n(delayRand.Milliseconds())) * time.Millisecond - } - - if minTimeout.Milliseconds() > 0 { - c.MinTimeout = minTimeout - } - - if pollInterval.Milliseconds() > 0 { - c.PollInterval = pollInterval - } + options.Apply(c) _, waitErr := c.WaitForStateContext(ctx) diff --git a/internal/tfresource/retry_test.go b/internal/tfresource/retry_test.go index b8d1b410cd4..2126fbd58d4 100644 --- a/internal/tfresource/retry_test.go +++ b/internal/tfresource/retry_test.go @@ -348,7 +348,7 @@ func TestRetryUntilNotFound(t *testing.T) { } } -func TestRetryConfigContext_error(t *testing.T) { +func TestRetryContext_error(t *testing.T) { t.Parallel() expected := fmt.Errorf("nope") @@ -358,7 +358,7 @@ func TestRetryConfigContext_error(t *testing.T) { errCh := make(chan error) go func() { - errCh <- tfresource.RetryConfigContext(context.Background(), 0*time.Second, 0*time.Second, 0*time.Second, 0*time.Second, 1*time.Second, f) + errCh <- tfresource.RetryContext(context.Background(), 1*time.Second, f) }() select { @@ -370,3 +370,57 @@ func TestRetryConfigContext_error(t *testing.T) { t.Fatal("timeout") } } + +func TestOptionsApply(t *testing.T) { + testCases := map[string]struct { + options tfresource.Options + expected resource.StateChangeConf + }{ + "Nothing": { + options: tfresource.Options{}, + expected: resource.StateChangeConf{}, + }, + "Delay": { + options: tfresource.Options{ + Delay: 1 * time.Minute, + }, + expected: resource.StateChangeConf{ + Delay: 1 * time.Minute, + }, + }, + "MinPollInterval": { + options: tfresource.Options{ + MinPollInterval: 1 * time.Minute, + }, + expected: resource.StateChangeConf{ + MinTimeout: 1 * time.Minute, + }, + }, + "PollInterval": { + options: tfresource.Options{ + PollInterval: 1 * time.Minute, + }, + expected: resource.StateChangeConf{ + PollInterval: 1 * time.Minute, + }, + }, + } + + for name, testCase := range testCases { + t.Run(name, func(t *testing.T) { + conf := resource.StateChangeConf{} + + testCase.options.Apply(&conf) + + if a, e := conf.Delay, testCase.expected.Delay; a != e { + t.Errorf("Delay: expected %s, got %s", e, a) + } + if a, e := conf.MinTimeout, testCase.expected.MinTimeout; a != e { + t.Errorf("MinTimeout: expected %s, got %s", e, a) + } + if a, e := conf.PollInterval, testCase.expected.PollInterval; a != e { + t.Errorf("PollInterval: expected %s, got %s", e, a) + } + }) + } +} diff --git a/website/docs/r/comprehend_document_classifier.html.markdown b/website/docs/r/comprehend_document_classifier.html.markdown new file mode 100644 index 00000000000..dc555c370ad --- /dev/null +++ b/website/docs/r/comprehend_document_classifier.html.markdown @@ -0,0 +1,123 @@ +--- +subcategory: "Comprehend" +layout: "aws" +page_title: "AWS: aws_comprehend_document_classifier" +description: |- + Terraform resource for managing an AWS Comprehend Document Classifier. +--- + +# Resource: aws_comprehend_document_classifier + +Terraform resource for managing an AWS Comprehend Document Classifier. + +## Example Usage + +### Basic Usage + +```terraform +resource "aws_comprehend_document_classifier" "example" { + name = "example" + + data_access_role_arn = aws_iam_role.example.arn + + language_code = "en" + input_data_config { + s3_uri = "s3://${aws_s3_bucket.test.bucket}/${aws_s3_object.documents.id}" + } + + depends_on = [ + aws_iam_role_policy.example + ] +} + +resource "aws_s3_object" "documents" { + # ... +} + +resource "aws_s3_object" "entities" { + # ... +} +``` + +## Argument Reference + +The following arguments are required: + +* `data_access_role_arn` - (Required) The ARN for an IAM Role which allows Comprehend to read the training and testing data. +* `input_data_config` - (Required) Configuration for the training and testing data. + See the [`input_data_config` Configuration Block](#input_data_config-configuration-block) section below. +* `language_code` - (Required) Two-letter language code for the language. + One of `en`, `es`, `fr`, `it`, `de`, or `pt`. +* `name` - (Required) Name for the Document Classifier. + Has a maximum length of 63 characters. + Can contain upper- and lower-case letters, numbers, and hypen (`-`). + +The following arguments are optional: + +* `model_kms_key_id` - (Optional) The ID or ARN of a KMS Key used to encrypt trained Document Classifiers. +* `tags` - (Optional) A map of tags to assign to the resource. If configured with a provider [`default_tags` Configuration Block](/docs/providers/aws/index.html#default_tags-configuration-block) present, tags with matching keys will overwrite those defined at the provider-level. +* `version_name` - (Optional) Name for the version of the Document Classifier. + Each version must have a unique name within the Document Classifier. + If omitted, Terraform will assign a random, unique version name. + If explicitly set to `""`, no version name will be set. + Has a maximum length of 63 characters. + Can contain upper- and lower-case letters, numbers, and hypen (`-`). + Conflicts with `version_name_prefix`. +* `version_name_prefix` - (Optional) Creates a unique version name beginning with the specified prefix. + Has a maximum length of 37 characters. + Can contain upper- and lower-case letters, numbers, and hypen (`-`). + Conflicts with `version_name`. +* `volume_kms_key_id` - (Optional) ID or ARN of a KMS Key used to encrypt storage volumes during job processing. +* `vpc_config` - (Optional) Configuration parameters for VPC to contain Document Classifier resources. + See the [`vpc_config` Configuration Block](#vpc_config-configuration-block) section below. + +### `input_data_config` Configuration Block + +* `augmented_manifests` - (Optional) List of training datasets produced by Amazon SageMaker Ground Truth. + Used if `data_format` is `AUGMENTED_MANIFEST`. + See the [`augmented_manifests` Configuration Block](#augmented_manifests-configuration-block) section below. +* `data_format` - (Optional, Default: `COMPREHEND_CSV`) The format for the training data. + One of `COMPREHEND_CSV` or `AUGMENTED_MANIFEST`. +* `s3_uri` - (Optional) Location of training documents. + Used if `data_format` is `COMPREHEND_CSV`. +* `test_s3uri` - (Optional) Location of test documents. + +### `augmented_manifests` Configuration Block + +* `annotation_data_s3_uri` - (Optional) Location of annotation files. +* `attribute_names` - (Required) The JSON attribute that contains the annotations for the training documents. +* `document_type` - (Optional, Default: `PLAIN_TEXT_DOCUMENT`) Type of augmented manifest. + One of `PLAIN_TEXT_DOCUMENT` or `SEMI_STRUCTURED_DOCUMENT`. +* `s3_uri` - (Required) Location of augmented manifest file. +* `source_documents_s3_uri` - (Optional) Location of source PDF files. +* `split` - (Optional, Default: `TRAIN`) Purpose of data in augmented manifest. + One of `TRAIN` or `TEST`. + + +### `vpc_config` Configuration Block + +* `security_group_ids` - (Required) List of security group IDs. +* `subnets` - (Required) List of VPC subnets. + +## Attributes Reference + +In addition to all arguments above, the following attributes are exported: + +* `arn` - ARN of the Document Classifier version. +* `tags_all` - A map of tags assigned to the resource, including those inherited from the provider [`default_tags` configuration block](/docs/providers/aws/index.html#default_tags-configuration-block). + +## Timeouts + +`aws_comprehend_document_classifier` provides the following [Timeouts](https://www.terraform.io/docs/configuration/blocks/resources/syntax.html#operation-timeouts) configuration options: + +* `create` - (Optional, Default: `60m`) +* `update` - (Optional, Default: `60m`) +* `delete` - (Optional, Default: `30m`) + +## Import + +Comprehend Document Classifier can be imported using the ARN, e.g., + +``` +$ terraform import aws_comprehend_document_classifier.example arn:aws:comprehend:us-west-2:123456789012:document_classifier/example +``` diff --git a/website/docs/r/comprehend_entity_recognizer.html.markdown b/website/docs/r/comprehend_entity_recognizer.html.markdown index a2c9b8ba1cd..4cdb367781f 100644 --- a/website/docs/r/comprehend_entity_recognizer.html.markdown +++ b/website/docs/r/comprehend_entity_recognizer.html.markdown @@ -8,7 +8,7 @@ description: |- # Resource: aws_comprehend_entity_recognizer -Terraform resource for managing an AWS Comprehend EntityRecognizer. +Terraform resource for managing an AWS Comprehend Entity Recognizer. ## Example Usage @@ -152,9 +152,9 @@ In addition to all arguments above, the following attributes are exported: `aws_comprehend_entity_recognizer` provides the following [Timeouts](https://www.terraform.io/docs/configuration/blocks/resources/syntax.html#operation-timeouts) configuration options: -* `create` - (Optional, Default: `20m`) -* `update` - (Optional, Default: `20m`) -* `delete` - (Optional, Default: `20m`) +* `create` - (Optional, Default: `60m`) +* `update` - (Optional, Default: `60m`) +* `delete` - (Optional, Default: `30m`) ## Import From c55b774a8f8d6f481dc8774668ae0cf0b315d4e1 Mon Sep 17 00:00:00 2001 From: Graham Davison Date: Wed, 17 Aug 2022 16:42:53 -0700 Subject: [PATCH 2/6] Adds support for test data --- .../service/comprehend/document_classifier.go | 12 ++- .../comprehend/document_classifier_test.go | 96 ++++++++++--------- 2 files changed, 58 insertions(+), 50 deletions(-) diff --git a/internal/service/comprehend/document_classifier.go b/internal/service/comprehend/document_classifier.go index 4d29a6a5338..9d96eeeaba6 100644 --- a/internal/service/comprehend/document_classifier.go +++ b/internal/service/comprehend/document_classifier.go @@ -725,9 +725,15 @@ func expandDocumentClassifierInputDataConfig(tfMap map[string]any) *types.Docume a := &types.DocumentClassifierInputDataConfig{ AugmentedManifests: expandAugmentedManifests(tfMap["augmented_manifests"].(*schema.Set)), DataFormat: types.DocumentClassifierDataFormat(tfMap["data_format"].(string)), - // LabelDelimiter: aws.String(tfMap["label_delimiter"].(string)), - S3Uri: aws.String(tfMap["s3_uri"].(string)), - // TestS3Uri: aws.String(tfMap["test_s3_uri"].(string)), + S3Uri: aws.String(tfMap["s3_uri"].(string)), + } + + // if v, ok := tfMap["label_delimiter"].(string); ok && v != "" { + // a.LabelDelimiter = aws.String(v) + // } + + if v, ok := tfMap["test_s3_uri"].(string); ok && v != "" { + a.TestS3Uri = aws.String(v) } return a diff --git a/internal/service/comprehend/document_classifier_test.go b/internal/service/comprehend/document_classifier_test.go index 1a415634e4b..7c922582ee8 100644 --- a/internal/service/comprehend/document_classifier_test.go +++ b/internal/service/comprehend/document_classifier_test.go @@ -49,7 +49,7 @@ func TestAccComprehendDocumentClassifier_basic(t *testing.T) { resource.TestCheckResourceAttr(resourceName, "input_data_config.0.augmented_manifests.#", "0"), resource.TestCheckResourceAttr(resourceName, "input_data_config.0.data_format", string(types.DocumentClassifierDataFormatComprehendCsv)), resource.TestCheckResourceAttrSet(resourceName, "input_data_config.0.s3_uri"), - resource.TestCheckNoResourceAttr(resourceName, "input_data_config.0.documents.0.test_s3_uri"), + resource.TestCheckResourceAttr(resourceName, "input_data_config.0.test_s3_uri", ""), resource.TestCheckResourceAttr(resourceName, "language_code", "en"), resource.TestCheckResourceAttr(resourceName, "model_kms_key_id", ""), resource.TestCheckResourceAttr(resourceName, "tags.%", "0"), @@ -276,53 +276,54 @@ func TestAccComprehendDocumentClassifier_versionNamePrefix(t *testing.T) { }) } -// func TestAccComprehendDocumentClassifier_documents_testDocuments(t *testing.T) { -// if testing.Short() { -// t.Skip("skipping long-running test in short mode") -// } +func TestAccComprehendDocumentClassifier_testDocuments(t *testing.T) { + if testing.Short() { + t.Skip("skipping long-running test in short mode") + } -// var documentclassifier types.DocumentClassifierProperties -// rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix) -// resourceName := "aws_comprehend_document_classifier.test" + var documentclassifier types.DocumentClassifierProperties + rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix) + resourceName := "aws_comprehend_document_classifier.test" -// resource.ParallelTest(t, resource.TestCase{ -// PreCheck: func() { -// acctest.PreCheck(t) -// acctest.PreCheckPartitionHasService(names.ComprehendEndpointID, t) -// testAccPreCheck(t) -// }, -// ErrorCheck: acctest.ErrorCheck(t, names.ComprehendEndpointID), -// ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories, -// CheckDestroy: testAccCheckDocumentClassifierDestroy, -// Steps: []resource.TestStep{ -// { -// Config: testAccDocumentClassifierConfig_testDocuments(rName), -// Check: resource.ComposeAggregateTestCheckFunc( -// testAccCheckDocumentClassifierExists(resourceName, &documentclassifier), -// testAccCheckDocumentClassifierPublishedVersions(resourceName, 1), -// resource.TestCheckResourceAttr(resourceName, "name", rName), -// acctest.MatchResourceAttrRegionalARN(resourceName, "arn", "comprehend", regexp.MustCompile(fmt.Sprintf(`document-classifier/%s/version/%s$`, rName, uniqueIDPattern()))), -// resource.TestCheckResourceAttr(resourceName, "input_data_config.#", "1"), -// resource.TestCheckResourceAttr(resourceName, "input_data_config.0.augmented_manifests.#", "0"), -// resource.TestCheckResourceAttr(resourceName, "input_data_config.0.data_format", string(types.DocumentClassifierDataFormatComprehendCsv)), -// resource.TestCheckResourceAttr(resourceName, "language_code", "en"), -// resource.TestCheckResourceAttr(resourceName, "model_kms_key_id", ""), -// resource.TestCheckResourceAttr(resourceName, "tags.%", "0"), -// resource.TestCheckResourceAttr(resourceName, "tags_all.%", "0"), -// acctest.CheckResourceAttrNameGenerated(resourceName, "version_name"), -// resource.TestCheckResourceAttr(resourceName, "version_name_prefix", resource.UniqueIdPrefix), -// resource.TestCheckResourceAttr(resourceName, "volume_kms_key_id", ""), -// resource.TestCheckResourceAttr(resourceName, "vpc_config.#", "0"), -// ), -// }, -// { -// ResourceName: resourceName, -// ImportState: true, -// ImportStateVerify: true, -// }, -// }, -// }) -// } + resource.ParallelTest(t, resource.TestCase{ + PreCheck: func() { + acctest.PreCheck(t) + acctest.PreCheckPartitionHasService(names.ComprehendEndpointID, t) + testAccPreCheck(t) + }, + ErrorCheck: acctest.ErrorCheck(t, names.ComprehendEndpointID), + ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories, + CheckDestroy: testAccCheckDocumentClassifierDestroy, + Steps: []resource.TestStep{ + { + Config: testAccDocumentClassifierConfig_testDocuments(rName), + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckDocumentClassifierExists(resourceName, &documentclassifier), + testAccCheckDocumentClassifierPublishedVersions(resourceName, 1), + resource.TestCheckResourceAttr(resourceName, "name", rName), + acctest.MatchResourceAttrRegionalARN(resourceName, "arn", "comprehend", regexp.MustCompile(fmt.Sprintf(`document-classifier/%s/version/%s$`, rName, uniqueIDPattern()))), + resource.TestCheckResourceAttr(resourceName, "input_data_config.#", "1"), + resource.TestCheckResourceAttr(resourceName, "input_data_config.0.augmented_manifests.#", "0"), + resource.TestCheckResourceAttr(resourceName, "input_data_config.0.data_format", string(types.DocumentClassifierDataFormatComprehendCsv)), + resource.TestCheckResourceAttrSet(resourceName, "input_data_config.0.test_s3_uri"), + resource.TestCheckResourceAttr(resourceName, "language_code", "en"), + resource.TestCheckResourceAttr(resourceName, "model_kms_key_id", ""), + resource.TestCheckResourceAttr(resourceName, "tags.%", "0"), + resource.TestCheckResourceAttr(resourceName, "tags_all.%", "0"), + acctest.CheckResourceAttrNameGenerated(resourceName, "version_name"), + resource.TestCheckResourceAttr(resourceName, "version_name_prefix", resource.UniqueIdPrefix), + resource.TestCheckResourceAttr(resourceName, "volume_kms_key_id", ""), + resource.TestCheckResourceAttr(resourceName, "vpc_config.#", "0"), + ), + }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + }, + }, + }) +} // func TestAccComprehendDocumentClassifier_annotations_basic(t *testing.T) { // if testing.Short() { @@ -1148,7 +1149,8 @@ resource "aws_comprehend_document_classifier" "test" { language_code = "en" input_data_config { - s3_uri = "s3://${aws_s3_bucket.test.bucket}/${aws_s3_object.documents.id}" + s3_uri = "s3://${aws_s3_bucket.test.bucket}/${aws_s3_object.documents.id}" + test_s3_uri = "s3://${aws_s3_bucket.test.bucket}/${aws_s3_object.documents.id}" } depends_on = [ From 829160f900e5478bf8f8dc84431730a08261955c Mon Sep 17 00:00:00 2001 From: Graham Davison Date: Mon, 12 Sep 2022 13:44:41 -0700 Subject: [PATCH 3/6] Adds support for Multi-Label mode --- .../service/comprehend/document_classifier.go | 83 ++- .../comprehend/document_classifier_test.go | 571 +++++++++++++----- .../service/comprehend/entity_recognizer.go | 4 +- .../documents.csv | 100 +++ .../generate/document_classifier/main.go | 8 +- .../document_classifier_multilabel/main.go | 90 +++ ...mprehend_document_classifier.html.markdown | 6 + 7 files changed, 693 insertions(+), 169 deletions(-) create mode 100644 internal/service/comprehend/test-fixtures/document_classifier_multilabel/documents.csv create mode 100644 internal/service/comprehend/test-fixtures/generate/document_classifier_multilabel/main.go diff --git a/internal/service/comprehend/document_classifier.go b/internal/service/comprehend/document_classifier.go index 9d96eeeaba6..480a6d8ead7 100644 --- a/internal/service/comprehend/document_classifier.go +++ b/internal/service/comprehend/document_classifier.go @@ -15,11 +15,13 @@ import ( "github.com/aws/aws-sdk-go-v2/service/comprehend" "github.com/aws/aws-sdk-go-v2/service/comprehend/types" "github.com/aws/aws-sdk-go/service/ec2" + "github.com/hashicorp/go-cty/cty" "github.com/hashicorp/go-multierror" "github.com/hashicorp/terraform-plugin-sdk/v2/diag" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/customdiff" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/resource" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" + "github.com/hashicorp/terraform-plugin-sdk/v2/helper/validation" "github.com/hashicorp/terraform-provider-aws/internal/conns" "github.com/hashicorp/terraform-provider-aws/internal/create" awsdiag "github.com/hashicorp/terraform-provider-aws/internal/diag" @@ -110,12 +112,12 @@ func ResourceDocumentClassifier() *schema.Resource { ValidateDiagFunc: enum.Validate[types.DocumentClassifierDataFormat](), Default: types.DocumentClassifierDataFormatComprehendCsv, }, - // "label_delimiter":{ - // Type:schema.TypeString, - // Optional: true, - // ValidateDiagFunc: , - // Default: "|", - // }, + "label_delimiter": { + Type: schema.TypeString, + Optional: true, + Computed: true, + ValidateFunc: validation.StringInSlice(documentClassifierLabelSeparators(), false), + }, "s3_uri": { Type: schema.TypeString, Optional: true, @@ -132,6 +134,12 @@ func ResourceDocumentClassifier() *schema.Resource { Required: true, ValidateDiagFunc: enum.Validate[types.SyntaxLanguageCode](), }, + "mode": { + Type: schema.TypeString, + Optional: true, + ValidateDiagFunc: enum.Validate[types.DocumentClassifierMode](), + Default: types.DocumentClassifierModeMultiClass, + }, "model_kms_key_id": { Type: schema.TypeString, Optional: true, @@ -205,6 +213,20 @@ func ResourceDocumentClassifier() *schema.Resource { } } + return nil + }, + func(_ context.Context, diff *schema.ResourceDiff, _ interface{}) error { + mode := types.DocumentClassifierMode(diff.Get("mode").(string)) + + if mode == types.DocumentClassifierModeMultiClass { + config := diff.GetRawConfig() + inputDataConfig := config.GetAttr("input_data_config").Index(cty.NumberIntVal(0)) + labelDelimiter := inputDataConfig.GetAttr("label_delimiter") + if !labelDelimiter.IsNull() { + return fmt.Errorf("input_data_config.label_delimiter must not be set when mode is %s", types.DocumentClassifierModeMultiClass) + } + } + return nil }, ), @@ -249,6 +271,7 @@ func resourceDocumentClassifierRead(ctx context.Context, d *schema.ResourceData, d.Set("arn", out.DocumentClassifierArn) d.Set("data_access_role_arn", out.DataAccessRoleArn) d.Set("language_code", out.LanguageCode) + d.Set("mode", out.Mode) d.Set("model_kms_key_id", out.ModelKmsKeyId) d.Set("version_name", out.VersionName) d.Set("version_name_prefix", create.NamePrefixFromName(aws.ToString(out.VersionName))) @@ -435,9 +458,10 @@ func fullValueTypeName(v reflect.Value) string { func documentClassifierPublishVersion(ctx context.Context, conn *comprehend.Client, d *schema.ResourceData, versionName *string, action string, timeout time.Duration, awsClient *conns.AWSClient) diag.Diagnostics { in := &comprehend.CreateDocumentClassifierInput{ DataAccessRoleArn: aws.String(d.Get("data_access_role_arn").(string)), - InputDataConfig: expandDocumentClassifierInputDataConfig(getDocumentClassifierInputDataConfig(d)), + InputDataConfig: expandDocumentClassifierInputDataConfig(d), LanguageCode: types.LanguageCode(d.Get("language_code").(string)), DocumentClassifierName: aws.String(d.Get("name").(string)), + Mode: types.DocumentClassifierMode(d.Get("mode").(string)), VersionName: versionName, VpcConfig: expandVPCConfig(d.Get("vpc_config").([]interface{})), ClientRequestToken: aws.String(resource.UniqueId()), @@ -701,6 +725,10 @@ func flattenDocumentClassifierInputDataConfig(apiObject *types.DocumentClassifie "s3_uri": aws.ToString(apiObject.S3Uri), } + if apiObject.LabelDelimiter != nil { + m["label_delimiter"] = aws.ToString(apiObject.LabelDelimiter) + } + if apiObject.TestS3Uri != nil { m["test_s3_uri"] = aws.ToString(apiObject.TestS3Uri) } @@ -708,8 +736,8 @@ func flattenDocumentClassifierInputDataConfig(apiObject *types.DocumentClassifie return []interface{}{m} } -func getDocumentClassifierInputDataConfig(diff resourceGetter) map[string]any { - v := diff.Get("input_data_config").([]any) +func getDocumentClassifierInputDataConfig(d resourceGetter) map[string]any { + v := d.Get("input_data_config").([]any) if len(v) == 0 { return nil } @@ -717,7 +745,8 @@ func getDocumentClassifierInputDataConfig(diff resourceGetter) map[string]any { return v[0].(map[string]any) } -func expandDocumentClassifierInputDataConfig(tfMap map[string]any) *types.DocumentClassifierInputDataConfig { +func expandDocumentClassifierInputDataConfig(d *schema.ResourceData) *types.DocumentClassifierInputDataConfig { + tfMap := getDocumentClassifierInputDataConfig(d) if len(tfMap) == 0 { return nil } @@ -728,9 +757,9 @@ func expandDocumentClassifierInputDataConfig(tfMap map[string]any) *types.Docume S3Uri: aws.String(tfMap["s3_uri"].(string)), } - // if v, ok := tfMap["label_delimiter"].(string); ok && v != "" { - // a.LabelDelimiter = aws.String(v) - // } + if v, ok := tfMap["label_delimiter"].(string); ok && v != "" { + a.LabelDelimiter = aws.String(v) + } if v, ok := tfMap["test_s3_uri"].(string); ok && v != "" { a.TestS3Uri = aws.String(v) @@ -753,3 +782,31 @@ func DocumentClassifierParseARN(arnString string) (string, error) { return name, nil } + +const DocumentClassifierLabelSeparatorDefault = "|" + +func documentClassifierLabelSeparators() []string { + return []string{ + DocumentClassifierLabelSeparatorDefault, + "~", + "!", + "@", + "#", + "$", + "%", + "^", + "*", + "-", + "_", + "+", + "=", + "\\", + ":", + ";", + ">", + "?", + "/", + " ", + "\t", + } +} diff --git a/internal/service/comprehend/document_classifier_test.go b/internal/service/comprehend/document_classifier_test.go index 7c922582ee8..aee95f80e23 100644 --- a/internal/service/comprehend/document_classifier_test.go +++ b/internal/service/comprehend/document_classifier_test.go @@ -48,9 +48,11 @@ func TestAccComprehendDocumentClassifier_basic(t *testing.T) { resource.TestCheckResourceAttr(resourceName, "input_data_config.#", "1"), resource.TestCheckResourceAttr(resourceName, "input_data_config.0.augmented_manifests.#", "0"), resource.TestCheckResourceAttr(resourceName, "input_data_config.0.data_format", string(types.DocumentClassifierDataFormatComprehendCsv)), + resource.TestCheckResourceAttr(resourceName, "input_data_config.0.label_delimiter", ""), resource.TestCheckResourceAttrSet(resourceName, "input_data_config.0.s3_uri"), resource.TestCheckResourceAttr(resourceName, "input_data_config.0.test_s3_uri", ""), resource.TestCheckResourceAttr(resourceName, "language_code", "en"), + resource.TestCheckResourceAttr(resourceName, "mode", string(types.DocumentClassifierModeMultiClass)), resource.TestCheckResourceAttr(resourceName, "model_kms_key_id", ""), resource.TestCheckResourceAttr(resourceName, "tags.%", "0"), resource.TestCheckResourceAttr(resourceName, "tags_all.%", "0"), @@ -65,6 +67,10 @@ func TestAccComprehendDocumentClassifier_basic(t *testing.T) { ImportState: true, ImportStateVerify: true, }, + { + Config: testAccDocumentClassifierConfig_Mode_singleLabel(rName), + PlanOnly: true, + }, }, }) } @@ -325,153 +331,164 @@ func TestAccComprehendDocumentClassifier_testDocuments(t *testing.T) { }) } -// func TestAccComprehendDocumentClassifier_annotations_basic(t *testing.T) { -// if testing.Short() { -// t.Skip("skipping long-running test in short mode") -// } +func TestAccComprehendDocumentClassifier_SingleLabel_ValidateNoDelimiterSet(t *testing.T) { + if testing.Short() { + t.Skip("skipping long-running test in short mode") + } -// var documentclassifier types.DocumentClassifierProperties -// rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix) -// resourceName := "aws_comprehend_document_classifier.test" - -// resource.ParallelTest(t, resource.TestCase{ -// PreCheck: func() { -// acctest.PreCheck(t) -// acctest.PreCheckPartitionHasService(names.ComprehendEndpointID, t) -// testAccPreCheck(t) -// }, -// ErrorCheck: acctest.ErrorCheck(t, names.ComprehendEndpointID), -// ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories, -// CheckDestroy: testAccCheckDocumentClassifierDestroy, -// Steps: []resource.TestStep{ -// { -// Config: testAccDocumentClassifierConfig_annotations_basic(rName), -// Check: resource.ComposeAggregateTestCheckFunc( -// testAccCheckDocumentClassifierExists(resourceName, &documentclassifier), -// testAccCheckDocumentClassifierPublishedVersions(resourceName, 1), -// resource.TestCheckResourceAttr(resourceName, "name", rName), -// resource.TestCheckResourceAttrPair(resourceName, "data_access_role_arn", "aws_iam_role.test", "arn"), -// acctest.MatchResourceAttrRegionalARN(resourceName, "arn", "comprehend", regexp.MustCompile(fmt.Sprintf(`document-classifier/%s/version/%s$`, rName, uniqueIDPattern()))), -// resource.TestCheckResourceAttr(resourceName, "input_data_config.#", "1"), -// resource.TestCheckResourceAttr(resourceName, "input_data_config.0.augmented_manifests.#", "0"), -// resource.TestCheckResourceAttr(resourceName, "input_data_config.0.data_format", string(types.DocumentClassifierDataFormatComprehendCsv)), -// resource.TestCheckResourceAttr(resourceName, "language_code", "en"), -// resource.TestCheckResourceAttr(resourceName, "model_kms_key_id", ""), -// resource.TestCheckResourceAttr(resourceName, "tags.%", "0"), -// resource.TestCheckResourceAttr(resourceName, "tags_all.%", "0"), -// acctest.CheckResourceAttrNameGenerated(resourceName, "version_name"), -// resource.TestCheckResourceAttr(resourceName, "version_name_prefix", resource.UniqueIdPrefix), -// resource.TestCheckResourceAttr(resourceName, "volume_kms_key_id", ""), -// resource.TestCheckResourceAttr(resourceName, "vpc_config.#", "0"), -// ), -// }, -// { -// ResourceName: resourceName, -// ImportState: true, -// ImportStateVerify: true, -// }, -// }, -// }) -// } + rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix) -// func TestAccComprehendDocumentClassifier_annotations_testDocuments(t *testing.T) { -// if testing.Short() { -// t.Skip("skipping long-running test in short mode") -// } + resource.ParallelTest(t, resource.TestCase{ + PreCheck: func() { + acctest.PreCheck(t) + acctest.PreCheckPartitionHasService(names.ComprehendEndpointID, t) + testAccPreCheck(t) + }, + ErrorCheck: acctest.ErrorCheck(t, names.ComprehendEndpointID), + ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories, + CheckDestroy: testAccCheckDocumentClassifierDestroy, + Steps: []resource.TestStep{ + { + Config: testAccDocumentClassifierConfig_modeDefault_ValidateNoDelimiterSet(rName, tfcomprehend.DocumentClassifierLabelSeparatorDefault), + ExpectError: regexp.MustCompile(fmt.Sprintf(`input_data_config.label_delimiter must not be set when mode is %s`, types.DocumentClassifierModeMultiClass)), + }, + { + Config: testAccDocumentClassifierConfig_modeDefault_ValidateNoDelimiterSet(rName, ">"), + ExpectError: regexp.MustCompile(fmt.Sprintf(`input_data_config.label_delimiter must not be set when mode is %s`, types.DocumentClassifierModeMultiClass)), + }, + { + Config: testAccDocumentClassifierConfig_modeSingleLabel_ValidateNoDelimiterSet(rName, tfcomprehend.DocumentClassifierLabelSeparatorDefault), + ExpectError: regexp.MustCompile(fmt.Sprintf(`input_data_config.label_delimiter must not be set when mode is %s`, types.DocumentClassifierModeMultiClass)), + }, + { + Config: testAccDocumentClassifierConfig_modeSingleLabel_ValidateNoDelimiterSet(rName, ">"), + ExpectError: regexp.MustCompile(fmt.Sprintf(`input_data_config.label_delimiter must not be set when mode is %s`, types.DocumentClassifierModeMultiClass)), + }, + }, + }) +} -// var documentclassifier types.DocumentClassifierProperties -// rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix) -// resourceName := "aws_comprehend_document_classifier.test" - -// resource.ParallelTest(t, resource.TestCase{ -// PreCheck: func() { -// acctest.PreCheck(t) -// acctest.PreCheckPartitionHasService(names.ComprehendEndpointID, t) -// testAccPreCheck(t) -// }, -// ErrorCheck: acctest.ErrorCheck(t, names.ComprehendEndpointID), -// ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories, -// CheckDestroy: testAccCheckDocumentClassifierDestroy, -// Steps: []resource.TestStep{ -// { -// Config: testAccDocumentClassifierConfig_annotations_testDocuments(rName), -// Check: resource.ComposeAggregateTestCheckFunc( -// testAccCheckDocumentClassifierExists(resourceName, &documentclassifier), -// testAccCheckDocumentClassifierPublishedVersions(resourceName, 1), -// resource.TestCheckResourceAttr(resourceName, "name", rName), -// resource.TestCheckResourceAttrPair(resourceName, "data_access_role_arn", "aws_iam_role.test", "arn"), -// acctest.MatchResourceAttrRegionalARN(resourceName, "arn", "comprehend", regexp.MustCompile(fmt.Sprintf(`document-classifier/%s/version/%s$`, rName, uniqueIDPattern()))), -// resource.TestCheckResourceAttr(resourceName, "input_data_config.#", "1"), -// resource.TestCheckResourceAttr(resourceName, "input_data_config.0.augmented_manifests.#", "0"), -// resource.TestCheckResourceAttr(resourceName, "input_data_config.0.data_format", string(types.DocumentClassifierDataFormatComprehendCsv)), -// resource.TestCheckResourceAttr(resourceName, "language_code", "en"), -// resource.TestCheckResourceAttr(resourceName, "model_kms_key_id", ""), -// resource.TestCheckResourceAttr(resourceName, "tags.%", "0"), -// resource.TestCheckResourceAttr(resourceName, "tags_all.%", "0"), -// acctest.CheckResourceAttrNameGenerated(resourceName, "version_name"), -// resource.TestCheckResourceAttr(resourceName, "version_name_prefix", resource.UniqueIdPrefix), -// resource.TestCheckResourceAttr(resourceName, "volume_kms_key_id", ""), -// resource.TestCheckResourceAttr(resourceName, "vpc_config.#", "0"), -// ), -// }, -// { -// ResourceName: resourceName, -// ImportState: true, -// ImportStateVerify: true, -// }, -// }, -// }) -// } +func TestAccComprehendDocumentClassifier_multiLabel_basic(t *testing.T) { + if testing.Short() { + t.Skip("skipping long-running test in short mode") + } -// func TestAccComprehendDocumentClassifier_annotations_validateNoTestDocuments(t *testing.T) { -// if testing.Short() { -// t.Skip("skipping long-running test in short mode") -// } + var documentclassifier types.DocumentClassifierProperties + rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix) + resourceName := "aws_comprehend_document_classifier.test" -// rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix) - -// resource.ParallelTest(t, resource.TestCase{ -// PreCheck: func() { -// acctest.PreCheck(t) -// acctest.PreCheckPartitionHasService(names.ComprehendEndpointID, t) -// testAccPreCheck(t) -// }, -// ErrorCheck: acctest.ErrorCheck(t, names.ComprehendEndpointID), -// ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories, -// CheckDestroy: testAccCheckDocumentClassifierDestroy, -// Steps: []resource.TestStep{ -// { -// Config: testAccDocumentClassifierConfig_annotations_noTestDocuments(rName), -// ExpectError: regexp.MustCompile("input_data_config.documents.test_s3_uri must be set when input_data_config.annotations.test_s3_uri is set"), -// }, -// }, -// }) -// } + resource.ParallelTest(t, resource.TestCase{ + PreCheck: func() { + acctest.PreCheck(t) + acctest.PreCheckPartitionHasService(names.ComprehendEndpointID, t) + testAccPreCheck(t) + }, + ErrorCheck: acctest.ErrorCheck(t, names.ComprehendEndpointID), + ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories, + CheckDestroy: testAccCheckDocumentClassifierDestroy, + Steps: []resource.TestStep{ + { + Config: testAccDocumentClassifierConfig_multiLabel_basic(rName), + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckDocumentClassifierExists(resourceName, &documentclassifier), + testAccCheckDocumentClassifierPublishedVersions(resourceName, 1), + resource.TestCheckResourceAttr(resourceName, "name", rName), + resource.TestCheckResourceAttrPair(resourceName, "data_access_role_arn", "aws_iam_role.test", "arn"), + acctest.MatchResourceAttrRegionalARN(resourceName, "arn", "comprehend", regexp.MustCompile(fmt.Sprintf(`document-classifier/%s/version/%s$`, rName, uniqueIDPattern()))), + resource.TestCheckResourceAttr(resourceName, "input_data_config.#", "1"), + resource.TestCheckResourceAttr(resourceName, "input_data_config.0.augmented_manifests.#", "0"), + resource.TestCheckResourceAttr(resourceName, "input_data_config.0.data_format", string(types.DocumentClassifierDataFormatComprehendCsv)), + resource.TestCheckResourceAttr(resourceName, "input_data_config.0.label_delimiter", tfcomprehend.DocumentClassifierLabelSeparatorDefault), + resource.TestCheckResourceAttr(resourceName, "language_code", "en"), + resource.TestCheckResourceAttr(resourceName, "mode", string(types.DocumentClassifierModeMultiLabel)), + resource.TestCheckResourceAttr(resourceName, "model_kms_key_id", ""), + resource.TestCheckResourceAttr(resourceName, "tags.%", "0"), + resource.TestCheckResourceAttr(resourceName, "tags_all.%", "0"), + acctest.CheckResourceAttrNameGenerated(resourceName, "version_name"), + resource.TestCheckResourceAttr(resourceName, "version_name_prefix", resource.UniqueIdPrefix), + resource.TestCheckResourceAttr(resourceName, "volume_kms_key_id", ""), + resource.TestCheckResourceAttr(resourceName, "vpc_config.#", "0"), + ), + }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + }, + { + Config: testAccDocumentClassifierConfig_multiLabel_defaultDelimiter(rName), + PlanOnly: true, + }, + }, + }) +} -// func TestAccComprehendDocumentClassifier_annotations_validateNoTestAnnotations(t *testing.T) { -// if testing.Short() { -// t.Skip("skipping long-running test in short mode") -// } +func TestAccComprehendDocumentClassifier_multiLabel_labelDelimiter(t *testing.T) { + if testing.Short() { + t.Skip("skipping long-running test in short mode") + } -// rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix) - -// resource.ParallelTest(t, resource.TestCase{ -// PreCheck: func() { -// acctest.PreCheck(t) -// acctest.PreCheckPartitionHasService(names.ComprehendEndpointID, t) -// testAccPreCheck(t) -// }, -// ErrorCheck: acctest.ErrorCheck(t, names.ComprehendEndpointID), -// ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories, -// CheckDestroy: testAccCheckDocumentClassifierDestroy, -// Steps: []resource.TestStep{ -// { -// Config: testAccDocumentClassifierConfig_annotations_noTestAnnotations(rName), -// ExpectError: regexp.MustCompile("input_data_config.annotations.test_s3_uri must be set when input_data_config.documents.test_s3_uri is set"), -// }, -// }, -// }) -// } + var documentclassifier types.DocumentClassifierProperties + rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix) + resourceName := "aws_comprehend_document_classifier.test" + const delimiter = "~" + const delimiterUpdated = "/" + + resource.ParallelTest(t, resource.TestCase{ + PreCheck: func() { + acctest.PreCheck(t) + acctest.PreCheckPartitionHasService(names.ComprehendEndpointID, t) + testAccPreCheck(t) + }, + ErrorCheck: acctest.ErrorCheck(t, names.ComprehendEndpointID), + ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories, + CheckDestroy: testAccCheckDocumentClassifierDestroy, + Steps: []resource.TestStep{ + { + Config: testAccDocumentClassifierConfig_multiLabel_delimiter(rName, delimiter), + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckDocumentClassifierExists(resourceName, &documentclassifier), + testAccCheckDocumentClassifierPublishedVersions(resourceName, 1), + resource.TestCheckResourceAttr(resourceName, "name", rName), + resource.TestCheckResourceAttrPair(resourceName, "data_access_role_arn", "aws_iam_role.test", "arn"), + acctest.MatchResourceAttrRegionalARN(resourceName, "arn", "comprehend", regexp.MustCompile(fmt.Sprintf(`document-classifier/%s/version/%s$`, rName, uniqueIDPattern()))), + resource.TestCheckResourceAttr(resourceName, "input_data_config.#", "1"), + resource.TestCheckResourceAttr(resourceName, "input_data_config.0.augmented_manifests.#", "0"), + resource.TestCheckResourceAttr(resourceName, "input_data_config.0.data_format", string(types.DocumentClassifierDataFormatComprehendCsv)), + resource.TestCheckResourceAttr(resourceName, "input_data_config.0.label_delimiter", delimiter), + resource.TestCheckResourceAttr(resourceName, "language_code", "en"), + resource.TestCheckResourceAttr(resourceName, "mode", string(types.DocumentClassifierModeMultiLabel)), + resource.TestCheckResourceAttr(resourceName, "model_kms_key_id", ""), + resource.TestCheckResourceAttr(resourceName, "tags.%", "0"), + resource.TestCheckResourceAttr(resourceName, "tags_all.%", "0"), + acctest.CheckResourceAttrNameGenerated(resourceName, "version_name"), + resource.TestCheckResourceAttr(resourceName, "version_name_prefix", resource.UniqueIdPrefix), + resource.TestCheckResourceAttr(resourceName, "volume_kms_key_id", ""), + resource.TestCheckResourceAttr(resourceName, "vpc_config.#", "0"), + ), + }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + }, + { + Config: testAccDocumentClassifierConfig_multiLabel_delimiter(rName, delimiterUpdated), + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckDocumentClassifierExists(resourceName, &documentclassifier), + testAccCheckDocumentClassifierPublishedVersions(resourceName, 2), + resource.TestCheckResourceAttr(resourceName, "input_data_config.#", "1"), + resource.TestCheckResourceAttr(resourceName, "input_data_config.0.label_delimiter", delimiterUpdated), + ), + }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + }, + }, + }) +} func TestAccComprehendDocumentClassifier_KMSKeys_CreateIDs(t *testing.T) { if testing.Short() { @@ -555,12 +572,12 @@ func TestAccComprehendDocumentClassifier_KMSKeys_CreateARNs(t *testing.T) { }) } -func TestAccComprehendDocumentClassifier_KMSKeys_Update(t *testing.T) { +func TestAccComprehendDocumentClassifier_KMSKeys_Add(t *testing.T) { if testing.Short() { t.Skip("skipping long-running test in short mode") } - var v1, v2, v3, v4 types.DocumentClassifierProperties + var v1, v2 types.DocumentClassifierProperties rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix) resourceName := "aws_comprehend_document_classifier.test" @@ -597,11 +614,43 @@ func TestAccComprehendDocumentClassifier_KMSKeys_Update(t *testing.T) { ImportState: true, ImportStateVerify: true, }, + }, + }) +} + +func TestAccComprehendDocumentClassifier_KMSKeys_Update(t *testing.T) { + if testing.Short() { + t.Skip("skipping long-running test in short mode") + } + + var v1, v2 types.DocumentClassifierProperties + rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix) + resourceName := "aws_comprehend_document_classifier.test" + + resource.ParallelTest(t, resource.TestCase{ + PreCheck: func() { + acctest.PreCheck(t) + acctest.PreCheckPartitionHasService(names.ComprehendEndpointID, t) + testAccPreCheck(t) + }, + ErrorCheck: acctest.ErrorCheck(t, names.ComprehendEndpointID), + ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories, + CheckDestroy: testAccCheckDocumentClassifierDestroy, + Steps: []resource.TestStep{ + { + Config: testAccDocumentClassifierConfig_kmsKeys_Set(rName), + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckDocumentClassifierExists(resourceName, &v1), + testAccCheckDocumentClassifierPublishedVersions(resourceName, 1), + resource.TestCheckResourceAttrPair(resourceName, "model_kms_key_id", "aws_kms_key.model", "key_id"), + resource.TestCheckResourceAttrPair(resourceName, "volume_kms_key_id", "aws_kms_key.volume", "key_id"), + ), + }, { Config: testAccDocumentClassifierConfig_kmsKeys_Update(rName), Check: resource.ComposeAggregateTestCheckFunc( - testAccCheckDocumentClassifierExists(resourceName, &v3), - testAccCheckDocumentClassifierPublishedVersions(resourceName, 3), + testAccCheckDocumentClassifierExists(resourceName, &v2), + testAccCheckDocumentClassifierPublishedVersions(resourceName, 2), resource.TestCheckResourceAttrPair(resourceName, "model_kms_key_id", "aws_kms_key.model2", "key_id"), resource.TestCheckResourceAttrPair(resourceName, "volume_kms_key_id", "aws_kms_key.volume2", "key_id"), ), @@ -611,15 +660,52 @@ func TestAccComprehendDocumentClassifier_KMSKeys_Update(t *testing.T) { ImportState: true, ImportStateVerify: true, }, + }, + }) +} + +func TestAccComprehendDocumentClassifier_KMSKeys_Remove(t *testing.T) { + if testing.Short() { + t.Skip("skipping long-running test in short mode") + } + + var v1, v2 types.DocumentClassifierProperties + rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix) + resourceName := "aws_comprehend_document_classifier.test" + + resource.ParallelTest(t, resource.TestCase{ + PreCheck: func() { + acctest.PreCheck(t) + acctest.PreCheckPartitionHasService(names.ComprehendEndpointID, t) + testAccPreCheck(t) + }, + ErrorCheck: acctest.ErrorCheck(t, names.ComprehendEndpointID), + ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories, + CheckDestroy: testAccCheckDocumentClassifierDestroy, + Steps: []resource.TestStep{ + { + Config: testAccDocumentClassifierConfig_kmsKeys_Set(rName), + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckDocumentClassifierExists(resourceName, &v1), + testAccCheckDocumentClassifierPublishedVersions(resourceName, 1), + resource.TestCheckResourceAttrPair(resourceName, "model_kms_key_id", "aws_kms_key.model", "key_id"), + resource.TestCheckResourceAttrPair(resourceName, "volume_kms_key_id", "aws_kms_key.volume", "key_id"), + ), + }, { Config: testAccDocumentClassifierConfig_kmsKeys_None(rName), Check: resource.ComposeAggregateTestCheckFunc( - testAccCheckDocumentClassifierExists(resourceName, &v4), - testAccCheckDocumentClassifierPublishedVersions(resourceName, 4), + testAccCheckDocumentClassifierExists(resourceName, &v2), + testAccCheckDocumentClassifierPublishedVersions(resourceName, 2), resource.TestCheckResourceAttr(resourceName, "model_kms_key_id", ""), resource.TestCheckResourceAttr(resourceName, "volume_kms_key_id", ""), ), }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + }, }, }) } @@ -683,12 +769,12 @@ func TestAccComprehendDocumentClassifier_VPCConfig_Create(t *testing.T) { }) } -func TestAccComprehendDocumentClassifier_VPCConfig_Update(t *testing.T) { +func TestAccComprehendDocumentClassifier_VPCConfig_Add(t *testing.T) { if testing.Short() { t.Skip("skipping long-running test in short mode") } - var dc1, dc2, dc3 types.DocumentClassifierProperties + var dc1, dc2 types.DocumentClassifierProperties rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix) resourceName := "aws_comprehend_document_classifier.test" @@ -728,11 +814,52 @@ func TestAccComprehendDocumentClassifier_VPCConfig_Update(t *testing.T) { ImportState: true, ImportStateVerify: true, }, + }, + }) +} + +func TestAccComprehendDocumentClassifier_VPCConfig_Remove(t *testing.T) { + if testing.Short() { + t.Skip("skipping long-running test in short mode") + } + + var dc1, dc2 types.DocumentClassifierProperties + rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix) + resourceName := "aws_comprehend_document_classifier.test" + + resource.ParallelTest(t, resource.TestCase{ + PreCheck: func() { + acctest.PreCheck(t) + acctest.PreCheckPartitionHasService(names.ComprehendEndpointID, t) + testAccPreCheck(t) + }, + ErrorCheck: acctest.ErrorCheck(t, names.ComprehendEndpointID), + ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories, + CheckDestroy: testAccCheckDocumentClassifierDestroy, + Steps: []resource.TestStep{ + { + Config: testAccDocumentClassifierConfig_vpcConfig(rName), + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckDocumentClassifierExists(resourceName, &dc1), + testAccCheckDocumentClassifierPublishedVersions(resourceName, 1), + resource.TestCheckResourceAttr(resourceName, "vpc_config.#", "1"), + resource.TestCheckResourceAttr(resourceName, "vpc_config.0.security_group_ids.#", "1"), + resource.TestCheckTypeSetElemAttrPair(resourceName, "vpc_config.0.security_group_ids.*", "aws_security_group.test.0", "id"), + resource.TestCheckResourceAttr(resourceName, "vpc_config.0.subnets.#", "2"), + resource.TestCheckTypeSetElemAttrPair(resourceName, "vpc_config.0.subnets.*", "aws_subnet.test.0", "id"), + resource.TestCheckTypeSetElemAttrPair(resourceName, "vpc_config.0.subnets.*", "aws_subnet.test.1", "id"), + ), + }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + }, { Config: testAccDocumentClassifierConfig_vpcConfig_None(rName), Check: resource.ComposeAggregateTestCheckFunc( - testAccCheckDocumentClassifierExists(resourceName, &dc3), - testAccCheckDocumentClassifierPublishedVersions(resourceName, 3), + testAccCheckDocumentClassifierExists(resourceName, &dc2), + testAccCheckDocumentClassifierPublishedVersions(resourceName, 2), resource.TestCheckResourceAttr(resourceName, "vpc_config.#", "0"), ), }, @@ -1027,6 +1154,32 @@ resource "aws_comprehend_document_classifier" "test" { `, rName)) } +func testAccDocumentClassifierConfig_Mode_singleLabel(rName string) string { + return acctest.ConfigCompose( + testAccDocumentClassifierBasicRoleConfig(rName), + testAccDocumentClassifierS3BucketConfig(rName), + testAccDocumentClassifierConfig_S3_documents, + fmt.Sprintf(` +data "aws_partition" "current" {} + +resource "aws_comprehend_document_classifier" "test" { + name = %[1]q + + data_access_role_arn = aws_iam_role.test.arn + + language_code = "en" + mode = "MULTI_CLASS" + input_data_config { + s3_uri = "s3://${aws_s3_bucket.test.bucket}/${aws_s3_object.documents.id}" + } + + depends_on = [ + aws_iam_role_policy.test, + ] +} +`, rName)) +} + func testAccDocumentClassifierConfig_versionName(rName, vName, key, value string) string { return acctest.ConfigCompose( testAccDocumentClassifierBasicRoleConfig(rName), @@ -1160,6 +1313,116 @@ resource "aws_comprehend_document_classifier" "test" { `, rName)) } +func testAccDocumentClassifierConfig_modeDefault_ValidateNoDelimiterSet(rName, delimiter string) string { + return acctest.ConfigCompose( + testAccDocumentClassifierBasicRoleConfig(rName), + testAccDocumentClassifierS3BucketConfig(rName), + testAccDocumentClassifierConfig_S3_documents, + fmt.Sprintf(` +data "aws_partition" "current" {} + +resource "aws_comprehend_document_classifier" "test" { + name = %[1]q + + data_access_role_arn = aws_iam_role.test.arn + + language_code = "en" + input_data_config { + s3_uri = "s3://${aws_s3_bucket.test.bucket}/${aws_s3_object.documents.id}" + label_delimiter = %q + } + + depends_on = [ + aws_iam_role_policy.test, + ] +} +`, rName, delimiter)) +} + +func testAccDocumentClassifierConfig_modeSingleLabel_ValidateNoDelimiterSet(rName, delimiter string) string { + return acctest.ConfigCompose( + testAccDocumentClassifierBasicRoleConfig(rName), + testAccDocumentClassifierS3BucketConfig(rName), + testAccDocumentClassifierConfig_S3_documents, + fmt.Sprintf(` +data "aws_partition" "current" {} + +resource "aws_comprehend_document_classifier" "test" { + name = %[1]q + + data_access_role_arn = aws_iam_role.test.arn + + language_code = "en" + mode = "MULTI_CLASS" + input_data_config { + s3_uri = "s3://${aws_s3_bucket.test.bucket}/${aws_s3_object.documents.id}" + label_delimiter = %q + } + + depends_on = [ + aws_iam_role_policy.test, + ] +} +`, rName, delimiter)) +} + +func testAccDocumentClassifierConfig_multiLabel_basic(rName string) string { + return acctest.ConfigCompose( + testAccDocumentClassifierBasicRoleConfig(rName), + testAccDocumentClassifierS3BucketConfig(rName), + testAccDocumentClassifierConfig_S3_multilabel, + fmt.Sprintf(` +data "aws_partition" "current" {} + +resource "aws_comprehend_document_classifier" "test" { + name = %[1]q + + data_access_role_arn = aws_iam_role.test.arn + + language_code = "en" + mode = "MULTI_LABEL" + input_data_config { + s3_uri = "s3://${aws_s3_bucket.test.bucket}/${aws_s3_object.multilabel.id}" + } + + depends_on = [ + aws_iam_role_policy.test, + ] +} +`, rName)) +} + +func testAccDocumentClassifierConfig_multiLabel_defaultDelimiter(rName string) string { + return testAccDocumentClassifierConfig_multiLabel_delimiter(rName, tfcomprehend.DocumentClassifierLabelSeparatorDefault) +} + +func testAccDocumentClassifierConfig_multiLabel_delimiter(rName, delimiter string) string { + return acctest.ConfigCompose( + testAccDocumentClassifierBasicRoleConfig(rName), + testAccDocumentClassifierS3BucketConfig(rName), + testAccDocumentClassifierConfig_S3_multilabel, + fmt.Sprintf(` +data "aws_partition" "current" {} + +resource "aws_comprehend_document_classifier" "test" { + name = %[1]q + + data_access_role_arn = aws_iam_role.test.arn + + language_code = "en" + mode = "MULTI_LABEL" + input_data_config { + s3_uri = "s3://${aws_s3_bucket.test.bucket}/${aws_s3_object.multilabel.id}" + label_delimiter = %[2]q + } + + depends_on = [ + aws_iam_role_policy.test, + ] +} +`, rName, delimiter)) +} + func testAccDocumentClassifierConfig_kmsKeyIds(rName string) string { return acctest.ConfigCompose( testAccDocumentClassifierBasicRoleConfig(rName), @@ -1879,3 +2142,11 @@ resource "aws_s3_object" "documents" { source = "test-fixtures/document_classifier/documents.csv" } ` + +const testAccDocumentClassifierConfig_S3_multilabel = ` +resource "aws_s3_object" "multilabel" { + bucket = aws_s3_bucket.test.bucket + key = "documents.csv" + source = "test-fixtures/document_classifier_multilabel/documents.csv" +} +` diff --git a/internal/service/comprehend/entity_recognizer.go b/internal/service/comprehend/entity_recognizer.go index 05f9d38ae28..f01278c13be 100644 --- a/internal/service/comprehend/entity_recognizer.go +++ b/internal/service/comprehend/entity_recognizer.go @@ -851,8 +851,8 @@ func flattenEntityList(apiObject *types.EntityRecognizerEntityList) []interface{ return []interface{}{m} } -func getEntityRecognizerInputDataConfig(diff resourceGetter) map[string]any { - v := diff.Get("input_data_config").([]any) +func getEntityRecognizerInputDataConfig(d resourceGetter) map[string]any { + v := d.Get("input_data_config").([]any) if len(v) == 0 { return nil } diff --git a/internal/service/comprehend/test-fixtures/document_classifier_multilabel/documents.csv b/internal/service/comprehend/test-fixtures/document_classifier_multilabel/documents.csv new file mode 100644 index 00000000000..ce65a2536dc --- /dev/null +++ b/internal/service/comprehend/test-fixtures/document_classifier_multilabel/documents.csv @@ -0,0 +1,100 @@ +DRAMA|COMEDY,minus is dramatic and hilarious +DRAMA|COMEDY,reiciendis is gripping and funny +DRAMA|COMEDY,sed is poingnant and comedic +DRAMA,eum is gripping +COMEDY,ex is comedic +COMEDY,aut is hilarious +DRAMA|COMEDY,doloremque is poingnant and hilarious +DRAMA,quisquam is gripping +DRAMA|COMEDY,illo is gripping and comedic +DRAMA,reiciendis is dramatic +DRAMA|COMEDY,dignissimos is dramatic and hilarious +DRAMA|COMEDY,omnis is dramatic and hilarious +DRAMA|COMEDY,ut is gripping and funny +COMEDY,dolorem is comedic +COMEDY,velit is comedic +DRAMA|COMEDY,minus is gripping and comedic +DRAMA|COMEDY,aut is dramatic and funny +DRAMA|COMEDY,magni is dramatic and funny +COMEDY,eum is hilarious +COMEDY,aut is hilarious +DRAMA|COMEDY,totam is dramatic and funny +DRAMA,aut is poingnant +DRAMA|COMEDY,quae is poingnant and hilarious +COMEDY,doloremque is funny +DRAMA,delectus is dramatic +COMEDY,laudantium is comedic +DRAMA,non is dramatic +DRAMA|COMEDY,amet is gripping and comedic +DRAMA|COMEDY,aliquam is gripping and hilarious +DRAMA|COMEDY,nostrum is poingnant and comedic +DRAMA,omnis is dramatic +DRAMA|COMEDY,accusantium is dramatic and comedic +COMEDY,sint is comedic +COMEDY,consectetur is comedic +DRAMA|COMEDY,laboriosam is poingnant and hilarious +DRAMA,qui is gripping +DRAMA,dolores is dramatic +DRAMA|COMEDY,illo is gripping and hilarious +COMEDY,eius is hilarious +DRAMA|COMEDY,tenetur is dramatic and funny +DRAMA,rerum is dramatic +DRAMA|COMEDY,beatae is gripping and funny +DRAMA|COMEDY,voluptates is gripping and comedic +DRAMA|COMEDY,dolores is dramatic and funny +DRAMA|COMEDY,cupiditate is poingnant and funny +DRAMA|COMEDY,deserunt is poingnant and hilarious +COMEDY,rerum is funny +DRAMA|COMEDY,qui is gripping and comedic +COMEDY,porro is hilarious +COMEDY,aut is hilarious +COMEDY,similique is comedic +COMEDY,qui is comedic +DRAMA|COMEDY,quia is dramatic and comedic +COMEDY,omnis is funny +DRAMA,esse is dramatic +DRAMA|COMEDY,eligendi is dramatic and funny +DRAMA|COMEDY,aperiam is gripping and funny +COMEDY,consequatur is funny +DRAMA,sed is gripping +DRAMA|COMEDY,labore is dramatic and hilarious +COMEDY,necessitatibus is comedic +DRAMA|COMEDY,ratione is poingnant and funny +DRAMA,alias is gripping +COMEDY,tenetur is hilarious +DRAMA|COMEDY,natus is dramatic and funny +COMEDY,distinctio is funny +DRAMA,et is dramatic +COMEDY,aut is comedic +DRAMA,aut is poingnant +COMEDY,numquam is funny +COMEDY,recusandae is hilarious +DRAMA,repellendus is dramatic +DRAMA|COMEDY,hic is poingnant and comedic +COMEDY,quia is comedic +DRAMA|COMEDY,velit is gripping and hilarious +DRAMA|COMEDY,placeat is gripping and comedic +DRAMA,asperiores is gripping +DRAMA,ut is dramatic +DRAMA|COMEDY,eveniet is gripping and comedic +DRAMA|COMEDY,quia is dramatic and funny +COMEDY,est is funny +DRAMA|COMEDY,aut is dramatic and hilarious +DRAMA|COMEDY,porro is gripping and comedic +DRAMA|COMEDY,beatae is gripping and comedic +DRAMA|COMEDY,quisquam is gripping and hilarious +DRAMA,ea is poingnant +DRAMA,explicabo is poingnant +DRAMA|COMEDY,minus is poingnant and funny +DRAMA|COMEDY,eum is dramatic and hilarious +DRAMA|COMEDY,quo is poingnant and comedic +DRAMA,tempora is gripping +DRAMA|COMEDY,voluptates is dramatic and funny +COMEDY,accusantium is funny +DRAMA,odio is gripping +COMEDY,voluptas is hilarious +DRAMA|COMEDY,voluptatum is poingnant and comedic +DRAMA|COMEDY,ipsum is poingnant and comedic +DRAMA|COMEDY,veritatis is poingnant and funny +DRAMA|COMEDY,beatae is gripping and hilarious +COMEDY,animi is hilarious diff --git a/internal/service/comprehend/test-fixtures/generate/document_classifier/main.go b/internal/service/comprehend/test-fixtures/generate/document_classifier/main.go index d315da77f5e..76bf105df6f 100644 --- a/internal/service/comprehend/test-fixtures/generate/document_classifier/main.go +++ b/internal/service/comprehend/test-fixtures/generate/document_classifier/main.go @@ -42,7 +42,7 @@ func main() { log.Fatalf("error opening file %q: %s", "documents.csv", err) } defer closeFile(documentFile, "documents.csv") - annotationsWriter := csv.NewWriter(documentFile) + documentsWriter := csv.NewWriter(documentFile) for i := 0; i < 100; i++ { name := faker.Name().Name() @@ -61,12 +61,12 @@ func main() { line = fmt.Sprintf(doc, name, product, company) } - if err := annotationsWriter.Write([]string{doctype, line}); err != nil { - log.Fatalf("error writing to file %q: %s", "annotations.csv", err) + if err := documentsWriter.Write([]string{doctype, line}); err != nil { + log.Fatalf("error writing to file %q: %s", "documents.csv", err) } } - annotationsWriter.Flush() + documentsWriter.Flush() } func closeFile(f *os.File, name string) { diff --git a/internal/service/comprehend/test-fixtures/generate/document_classifier_multilabel/main.go b/internal/service/comprehend/test-fixtures/generate/document_classifier_multilabel/main.go new file mode 100644 index 00000000000..2c28768fa6d --- /dev/null +++ b/internal/service/comprehend/test-fixtures/generate/document_classifier_multilabel/main.go @@ -0,0 +1,90 @@ +//go:build generate +// +build generate + +package main + +import ( + "encoding/csv" + "fmt" + "log" + "math/rand" + "os" + "strings" + + "syreclabs.com/go/faker" +) + +const ( + defaultSeparator = "|" +) + +var doctypes = []string{ + "DRAMA", + "COMEDY", +} + +var dramaWords = []string{ + "dramatic", + "gripping", + "poingnant", +} + +var comedyWords = []string{ + "funny", + "comedic", + "hilarious", +} + +func main() { + log.SetFlags(0) + + seed := int64(1) // Default rand seed + rand.Seed(seed) + faker.Seed(seed) + + // documentFile, err := os.OpenFile("./test-fixtures/document_classifier_multilabel/documents.csv", os.O_TRUNC|os.O_CREATE|os.O_WRONLY, 0600) + documentFile, err := os.OpenFile("../../../test-fixtures/document_classifier_multilabel/documents.csv", os.O_TRUNC|os.O_CREATE|os.O_WRONLY, 0600) + if err != nil { + log.Fatalf("error opening file %q: %s", "documents.csv", err) + } + defer closeFile(documentFile, "documents.csv") + documentsWriter := csv.NewWriter(documentFile) + + for i := 0; i < 100; i++ { + f := rand.Intn(2) + var doctype string + if f == 0 { + doctype = doctypes[rand.Intn(len(doctypes))] + } else { + doctype = strings.Join(doctypes, defaultSeparator) + } + + title := faker.Lorem().Word() + + var desc string + if doctype == "DRAMA" { + desc = dramaWords[rand.Intn(len(dramaWords))] + } else if doctype == "COMEDY" { + desc = comedyWords[rand.Intn(len(comedyWords))] + } else { + desc = fmt.Sprintf("%s and %s", + dramaWords[rand.Intn(len(dramaWords))], + comedyWords[rand.Intn(len(comedyWords))], + ) + } + + line := fmt.Sprintf("%s is %s", title, desc) + + if err := documentsWriter.Write([]string{doctype, line}); err != nil { + log.Fatalf("error writing to file %q: %s", "documents.csv", err) + } + } + + documentsWriter.Flush() +} + +func closeFile(f *os.File, name string) { + if err := f.Close(); err != nil { + log.Fatalf("error closing file %q: %s", name, err) + } +} diff --git a/website/docs/r/comprehend_document_classifier.html.markdown b/website/docs/r/comprehend_document_classifier.html.markdown index dc555c370ad..fb4d3c994eb 100644 --- a/website/docs/r/comprehend_document_classifier.html.markdown +++ b/website/docs/r/comprehend_document_classifier.html.markdown @@ -54,6 +54,9 @@ The following arguments are required: The following arguments are optional: +* `mode` - (Optional, Default: `MULTI_CLASS`) The document classification mode. + One of `MULTI_CLASS` or `MULTI_LABEL`. + `MULTI_CLASS` is also known as "Single Label" in the AWS Console. * `model_kms_key_id` - (Optional) The ID or ARN of a KMS Key used to encrypt trained Document Classifiers. * `tags` - (Optional) A map of tags to assign to the resource. If configured with a provider [`default_tags` Configuration Block](/docs/providers/aws/index.html#default_tags-configuration-block) present, tags with matching keys will overwrite those defined at the provider-level. * `version_name` - (Optional) Name for the version of the Document Classifier. @@ -78,6 +81,9 @@ The following arguments are optional: See the [`augmented_manifests` Configuration Block](#augmented_manifests-configuration-block) section below. * `data_format` - (Optional, Default: `COMPREHEND_CSV`) The format for the training data. One of `COMPREHEND_CSV` or `AUGMENTED_MANIFEST`. +* `label_delimiter` - (Optional) Delimiter between labels when training a multi-label classifier. + Valid values are `|`, `~`, `!`, `@`, `#`, `$`, `%`, `^`, `*`, `-`, `_`, `+`, `=`, `\`, `:`, `;`, `>`, `?`, `/`, ``, and ``. + Default is `|`. * `s3_uri` - (Optional) Location of training documents. Used if `data_format` is `COMPREHEND_CSV`. * `test_s3uri` - (Optional) Location of test documents. From 5f95fbe75a472351bb3dfb9a40d8905bb28ef011 Mon Sep 17 00:00:00 2001 From: Graham Davison Date: Thu, 15 Sep 2022 10:08:39 -0700 Subject: [PATCH 4/6] Adds support for `output_data_config` --- internal/service/comprehend/diff_test.go | 1 - .../service/comprehend/document_classifier.go | 65 ++++++-- .../comprehend/document_classifier_test.go | 142 +++++++++++++++++- .../service/s3/bucket_ownership_controls.go | 10 +- ...mprehend_document_classifier.html.markdown | 7 + 5 files changed, 205 insertions(+), 20 deletions(-) diff --git a/internal/service/comprehend/diff_test.go b/internal/service/comprehend/diff_test.go index 5bbad568065..7d74b19f431 100644 --- a/internal/service/comprehend/diff_test.go +++ b/internal/service/comprehend/diff_test.go @@ -64,5 +64,4 @@ func TestDiffSuppressKMSKeyId(t *testing.T) { } }) } - } diff --git a/internal/service/comprehend/document_classifier.go b/internal/service/comprehend/document_classifier.go index 480a6d8ead7..b816760dd09 100644 --- a/internal/service/comprehend/document_classifier.go +++ b/internal/service/comprehend/document_classifier.go @@ -5,7 +5,6 @@ import ( "errors" "fmt" "log" - "reflect" "regexp" "time" @@ -151,7 +150,24 @@ func ResourceDocumentClassifier() *schema.Resource { Required: true, ValidateFunc: validModelName, }, - // "output_data_config" + "output_data_config": { + Type: schema.TypeList, + Optional: true, + MaxItems: 1, + DiffSuppressFunc: verify.SuppressMissingOptionalConfigurationBlock, + Elem: &schema.Resource{ + Schema: map[string]*schema.Schema{ + "s3_uri": { + Type: schema.TypeString, + Required: true, + }, + "output_s3_uri": { + Type: schema.TypeString, + Computed: true, + }, + }, + }, + }, "tags": tftags.TagsSchema(), "tags_all": tftags.TagsSchemaComputed(), "version_name": { @@ -288,6 +304,10 @@ func resourceDocumentClassifierRead(ctx context.Context, d *schema.ResourceData, return diag.Errorf("setting input_data_config: %s", err) } + if err := d.Set("output_data_config", flattenDocumentClassifierOutputDataConfig(d, out.OutputDataConfig)); err != nil { + return diag.Errorf("setting output_data_config: %s", err) + } + if err := d.Set("vpc_config", flattenVPCConfig(out.VpcConfig)); err != nil { return diag.Errorf("setting vpc_config: %s", err) } @@ -442,19 +462,6 @@ func resourceDocumentClassifierDelete(ctx context.Context, d *schema.ResourceDat return nil } -func fullTypeName(i interface{}) string { - return fullValueTypeName(reflect.ValueOf(i)) -} - -func fullValueTypeName(v reflect.Value) string { - if v.Kind() == reflect.Ptr { - return "*" + fullValueTypeName(reflect.Indirect(v)) - } - - requestType := v.Type() - return fmt.Sprintf("%s.%s", requestType.PkgPath(), requestType.Name()) -} - func documentClassifierPublishVersion(ctx context.Context, conn *comprehend.Client, d *schema.ResourceData, versionName *string, action string, timeout time.Duration, awsClient *conns.AWSClient) diag.Diagnostics { in := &comprehend.CreateDocumentClassifierInput{ DataAccessRoleArn: aws.String(d.Get("data_access_role_arn").(string)), @@ -462,6 +469,7 @@ func documentClassifierPublishVersion(ctx context.Context, conn *comprehend.Clie LanguageCode: types.LanguageCode(d.Get("language_code").(string)), DocumentClassifierName: aws.String(d.Get("name").(string)), Mode: types.DocumentClassifierMode(d.Get("mode").(string)), + OutputDataConfig: expandDocumentClassifierOutputDataConfig(d.Get("output_data_config").([]interface{})), VersionName: versionName, VpcConfig: expandVPCConfig(d.Get("vpc_config").([]interface{})), ClientRequestToken: aws.String(resource.UniqueId()), @@ -736,6 +744,19 @@ func flattenDocumentClassifierInputDataConfig(apiObject *types.DocumentClassifie return []interface{}{m} } +func flattenDocumentClassifierOutputDataConfig(d *schema.ResourceData, apiObject *types.DocumentClassifierOutputDataConfig) []interface{} { + if apiObject == nil || apiObject.S3Uri == nil { + return nil + } + + m := map[string]interface{}{ + "s3_uri": d.Get("output_data_config.0.s3_uri"), + "output_s3_uri": aws.ToString(apiObject.S3Uri), + } + + return []interface{}{m} +} + func getDocumentClassifierInputDataConfig(d resourceGetter) map[string]any { v := d.Get("input_data_config").([]any) if len(v) == 0 { @@ -768,6 +789,20 @@ func expandDocumentClassifierInputDataConfig(d *schema.ResourceData) *types.Docu return a } +func expandDocumentClassifierOutputDataConfig(tfList []interface{}) *types.DocumentClassifierOutputDataConfig { + if len(tfList) == 0 { + return nil + } + + tfMap := tfList[0].(map[string]interface{}) + + a := &types.DocumentClassifierOutputDataConfig{ + S3Uri: aws.String(tfMap["s3_uri"].(string)), + } + + return a +} + func DocumentClassifierParseARN(arnString string) (string, error) { arn, err := arn.Parse(arnString) if err != nil { diff --git a/internal/service/comprehend/document_classifier_test.go b/internal/service/comprehend/document_classifier_test.go index aee95f80e23..6f5443ca147 100644 --- a/internal/service/comprehend/document_classifier_test.go +++ b/internal/service/comprehend/document_classifier_test.go @@ -54,6 +54,7 @@ func TestAccComprehendDocumentClassifier_basic(t *testing.T) { resource.TestCheckResourceAttr(resourceName, "language_code", "en"), resource.TestCheckResourceAttr(resourceName, "mode", string(types.DocumentClassifierModeMultiClass)), resource.TestCheckResourceAttr(resourceName, "model_kms_key_id", ""), + resource.TestCheckResourceAttr(resourceName, "output_data_config.#", "0"), resource.TestCheckResourceAttr(resourceName, "tags.%", "0"), resource.TestCheckResourceAttr(resourceName, "tags_all.%", "0"), acctest.CheckResourceAttrNameGenerated(resourceName, "version_name"), @@ -423,6 +424,54 @@ func TestAccComprehendDocumentClassifier_multiLabel_basic(t *testing.T) { }) } +func TestAccComprehendDocumentClassifier_outputDataConfig_basic(t *testing.T) { + if testing.Short() { + t.Skip("skipping long-running test in short mode") + } + + var documentclassifier types.DocumentClassifierProperties + rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix) + resourceName := "aws_comprehend_document_classifier.test" + + resource.ParallelTest(t, resource.TestCase{ + PreCheck: func() { + acctest.PreCheck(t) + acctest.PreCheckPartitionHasService(names.ComprehendEndpointID, t) + testAccPreCheck(t) + }, + ErrorCheck: acctest.ErrorCheck(t, names.ComprehendEndpointID), + ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories, + CheckDestroy: testAccCheckDocumentClassifierDestroy, + Steps: []resource.TestStep{ + { + Config: testAccDocumentClassifierConfig_outputDataConfig_basic(rName), + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckDocumentClassifierExists(resourceName, &documentclassifier), + testAccCheckDocumentClassifierPublishedVersions(resourceName, 1), + resource.TestCheckResourceAttr(resourceName, "output_data_config.#", "1"), + resource.TestMatchResourceAttr(resourceName, "output_data_config.0.s3_uri", regexp.MustCompile(`s3:.+/[-A-Za-z0-9]+/`)), + resource.TestMatchResourceAttr(resourceName, "output_data_config.0.output_s3_uri", regexp.MustCompile(`s3:.+/[-A-Za-z0-9]+/output/output.tar.gz`)), + ), + }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + }, + { + Config: testAccDocumentClassifierConfig_outputDataConfig_basic2(rName), + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckDocumentClassifierExists(resourceName, &documentclassifier), + testAccCheckDocumentClassifierPublishedVersions(resourceName, 1), + resource.TestCheckResourceAttr(resourceName, "output_data_config.#", "1"), + resource.TestMatchResourceAttr(resourceName, "output_data_config.0.s3_uri", regexp.MustCompile(`s3:.+/[-A-Za-z0-9]+/`)), + resource.TestMatchResourceAttr(resourceName, "output_data_config.0.output_s3_uri", regexp.MustCompile(`s3:.+/[-A-Za-z0-9]+/output/output.tar.gz`)), + ), + }, + }, + }) +} + func TestAccComprehendDocumentClassifier_multiLabel_labelDelimiter(t *testing.T) { if testing.Short() { t.Skip("skipping long-running test in short mode") @@ -1329,7 +1378,7 @@ resource "aws_comprehend_document_classifier" "test" { language_code = "en" input_data_config { s3_uri = "s3://${aws_s3_bucket.test.bucket}/${aws_s3_object.documents.id}" - label_delimiter = %q + label_delimiter = %q } depends_on = [ @@ -1356,7 +1405,7 @@ resource "aws_comprehend_document_classifier" "test" { mode = "MULTI_CLASS" input_data_config { s3_uri = "s3://${aws_s3_bucket.test.bucket}/${aws_s3_object.documents.id}" - label_delimiter = %q + label_delimiter = %q } depends_on = [ @@ -1413,7 +1462,7 @@ resource "aws_comprehend_document_classifier" "test" { mode = "MULTI_LABEL" input_data_config { s3_uri = "s3://${aws_s3_bucket.test.bucket}/${aws_s3_object.multilabel.id}" - label_delimiter = %[2]q + label_delimiter = %[2]q } depends_on = [ @@ -1786,10 +1835,75 @@ resource "aws_comprehend_document_classifier" "test" { `, rName, tagKey1, tagValue1, tagKey2, tagValue2)) } +func testAccDocumentClassifierConfig_outputDataConfig_basic(rName string) string { + return acctest.ConfigCompose( + testAccDocumentClassifierBasicRoleConfig(rName), + testAccDocumentClassifierConfig_s3OutputRole(), + testAccDocumentClassifierS3BucketConfig(rName), + testAccDocumentClassifierConfig_S3_documents, + fmt.Sprintf(` +data "aws_partition" "current" {} + +resource "aws_comprehend_document_classifier" "test" { + name = %[1]q + + data_access_role_arn = aws_iam_role.test.arn + + language_code = "en" + input_data_config { + s3_uri = "s3://${aws_s3_bucket.test.bucket}/${aws_s3_object.documents.id}" + } + + output_data_config { + s3_uri = "s3://${aws_s3_bucket.test.bucket}/outputs" + } + + depends_on = [ + aws_iam_role_policy.test, + aws_iam_role_policy.s3_output, + ] +} +`, rName)) +} + +func testAccDocumentClassifierConfig_outputDataConfig_basic2(rName string) string { + return acctest.ConfigCompose( + testAccDocumentClassifierBasicRoleConfig(rName), + testAccDocumentClassifierConfig_s3OutputRole(), + testAccDocumentClassifierS3BucketConfig(rName), + testAccDocumentClassifierConfig_S3_documents, + fmt.Sprintf(` +data "aws_partition" "current" {} + +resource "aws_comprehend_document_classifier" "test" { + name = %[1]q + version_name = "2" + + data_access_role_arn = aws_iam_role.test.arn + + language_code = "en" + input_data_config { + s3_uri = "s3://${aws_s3_bucket.test.bucket}/${aws_s3_object.documents.id}" + } + + output_data_config { + s3_uri = "s3://${aws_s3_bucket.test.bucket}/outputs" + } + + depends_on = [ + aws_iam_role_policy.test, + aws_iam_role_policy.s3_output, + ] +} +`, rName)) +} + func testAccDocumentClassifierS3BucketConfig(rName string) string { return fmt.Sprintf(` resource "aws_s3_bucket" "test" { bucket = %[1]q + + force_destroy = true } resource "aws_s3_bucket_public_access_block" "test" { @@ -1891,6 +2005,28 @@ data "aws_iam_policy_document" "vpc_access" { ` } +func testAccDocumentClassifierConfig_s3OutputRole() string { + return ` +resource "aws_iam_role_policy" "s3_output" { + role = aws_iam_role.test.name + + policy = data.aws_iam_policy_document.s3_output.json +} + +data "aws_iam_policy_document" "s3_output" { + statement { + actions = [ + "s3:PutObject", + ] + + resources = [ + "${aws_s3_bucket.test.arn}/*", + ] + } +} +` +} + func testAccDocumentClassifierConfig_vpcConfig(rName string) string { const subnetCount = 2 return acctest.ConfigCompose( diff --git a/internal/service/s3/bucket_ownership_controls.go b/internal/service/s3/bucket_ownership_controls.go index 00a72f8715a..71db426e488 100644 --- a/internal/service/s3/bucket_ownership_controls.go +++ b/internal/service/s3/bucket_ownership_controls.go @@ -3,6 +3,7 @@ package s3 import ( "fmt" "log" + "time" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/s3" @@ -10,6 +11,7 @@ import ( "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/validation" "github.com/hashicorp/terraform-provider-aws/internal/conns" + "github.com/hashicorp/terraform-provider-aws/internal/tfresource" ) func ResourceBucketOwnershipControls() *schema.Resource { @@ -140,7 +142,13 @@ func resourceBucketOwnershipControlsDelete(d *schema.ResourceData, meta interfac Bucket: aws.String(d.Id()), } - _, err := conn.DeleteBucketOwnershipControls(input) + _, err := tfresource.RetryWhenAWSErrCodeEquals( + 5*time.Minute, + func() (any, error) { + return conn.DeleteBucketOwnershipControls(input) + }, + "OperationAborted", + ) if tfawserr.ErrCodeEquals(err, s3.ErrCodeNoSuchBucket) { return nil diff --git a/website/docs/r/comprehend_document_classifier.html.markdown b/website/docs/r/comprehend_document_classifier.html.markdown index fb4d3c994eb..ece21b67f6a 100644 --- a/website/docs/r/comprehend_document_classifier.html.markdown +++ b/website/docs/r/comprehend_document_classifier.html.markdown @@ -58,6 +58,8 @@ The following arguments are optional: One of `MULTI_CLASS` or `MULTI_LABEL`. `MULTI_CLASS` is also known as "Single Label" in the AWS Console. * `model_kms_key_id` - (Optional) The ID or ARN of a KMS Key used to encrypt trained Document Classifiers. +* `output_data_config` - (Optional) Configuration for the output results of training. + See the [`output_data_config` Configuration Block](#output_data_config-configuration-block) section below. * `tags` - (Optional) A map of tags to assign to the resource. If configured with a provider [`default_tags` Configuration Block](/docs/providers/aws/index.html#default_tags-configuration-block) present, tags with matching keys will overwrite those defined at the provider-level. * `version_name` - (Optional) Name for the version of the Document Classifier. Each version must have a unique name within the Document Classifier. @@ -99,6 +101,11 @@ The following arguments are optional: * `split` - (Optional, Default: `TRAIN`) Purpose of data in augmented manifest. One of `TRAIN` or `TEST`. +### `output_data_config` Configuration Block + +* `output_s3_uri` - (Computed) Full path for the output documents. +* `s3_uri` - (Required) Destination path for the output documents. + The full path to the output file will be returned in `output_s3_uri`. ### `vpc_config` Configuration Block From 21a6370e9c1fc6c3d5037adf72e0284058524af6 Mon Sep 17 00:00:00 2001 From: Graham Davison Date: Thu, 22 Sep 2022 15:14:10 -0700 Subject: [PATCH 5/6] Adds support for encrypting output --- internal/service/comprehend/diff.go | 78 +- .../service/comprehend/document_classifier.go | 36 +- .../comprehend/document_classifier_test.go | 1402 ++++++++++++----- internal/service/comprehend/validate.go | 40 +- internal/service/kms/validate.go | 13 +- ...mprehend_document_classifier.html.markdown | 8 +- 6 files changed, 1175 insertions(+), 402 deletions(-) diff --git a/internal/service/comprehend/diff.go b/internal/service/comprehend/diff.go index 014395c9c39..0a69b1ad077 100644 --- a/internal/service/comprehend/diff.go +++ b/internal/service/comprehend/diff.go @@ -2,13 +2,15 @@ package comprehend import ( "regexp" + "strings" "github.com/aws/aws-sdk-go-v2/aws/arn" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" + tfkms "github.com/hashicorp/terraform-provider-aws/internal/service/kms" "github.com/hashicorp/terraform-provider-aws/internal/verify" ) -func diffSuppressKMSKeyId(k, oldValue, newValue string, d *schema.ResourceData) bool { +func diffSuppressKMSKeyId(_, oldValue, newValue string, _ *schema.ResourceData) bool { if oldValue == newValue { return true } @@ -30,6 +32,41 @@ func diffSuppressKMSKeyId(k, oldValue, newValue string, d *schema.ResourceData) return false } +func diffSuppressKMSAlias(_, oldValue, newValue string, _ *schema.ResourceData) bool { + if oldValue == newValue { + return true + } + + oldAlias := oldValue + if arn.IsARN(oldValue) { + oldAlias = kmsKeyAliasFromARN(oldValue) + } + + newAlias := newValue + if arn.IsARN(newValue) { + newAlias = kmsKeyAliasFromARN(newValue) + } + + if oldAlias == newAlias { + return true + } + + return false +} + +func diffSuppressKMSKeyOrAlias(k, oldValue, newValue string, d *schema.ResourceData) bool { + if arn.IsARN(newValue) { + if isKMSKeyARN(newValue) { + return diffSuppressKMSKeyId(k, oldValue, newValue, d) + } else { + return diffSuppressKMSAlias(k, oldValue, newValue, d) + } + } else if isKMSAliasName(newValue) { + return diffSuppressKMSAlias(k, oldValue, newValue, d) + } + return diffSuppressKMSKeyId(k, oldValue, newValue, d) +} + func kmsKeyIdFromARN(s string) string { arn, err := arn.Parse(s) if err != nil { @@ -47,5 +84,44 @@ func kmsKeyIdFromARNResource(s string) string { } return matches[1] +} + +func kmsKeyAliasFromARN(s string) string { + arn, err := arn.Parse(s) + if err != nil { + return "" + } + + return kmsKeyAliasNameFromARNResource(arn.Resource) +} + +func kmsKeyAliasNameFromARNResource(s string) string { + re := regexp.MustCompile("^" + tfkms.AliasNameRegexPattern + "$") + if re.MatchString(s) { + return s + } + + return "" +} + +func isKMSKeyARN(s string) bool { + parsedARN, err := arn.Parse(s) + if err != nil { + return false + } + + return kmsKeyIdFromARNResource(parsedARN.Resource) != "" +} + +func isKMSAliasName(s string) bool { + return strings.HasPrefix(s, "alias/") +} + +func isKMSAliasARN(s string) bool { + parsedARN, err := arn.Parse(s) + if err != nil { + return false + } + return isKMSAliasName(parsedARN.Resource) } diff --git a/internal/service/comprehend/document_classifier.go b/internal/service/comprehend/document_classifier.go index b816760dd09..cf1a54bcdb2 100644 --- a/internal/service/comprehend/document_classifier.go +++ b/internal/service/comprehend/document_classifier.go @@ -6,6 +6,7 @@ import ( "fmt" "log" "regexp" + "strings" "time" "github.com/aws/aws-sdk-go-v2/aws" @@ -153,13 +154,25 @@ func ResourceDocumentClassifier() *schema.Resource { "output_data_config": { Type: schema.TypeList, Optional: true, + Computed: true, MaxItems: 1, DiffSuppressFunc: verify.SuppressMissingOptionalConfigurationBlock, Elem: &schema.Resource{ Schema: map[string]*schema.Schema{ + "kms_key_id": { + Type: schema.TypeString, + Optional: true, + DiffSuppressFunc: diffSuppressKMSKeyOrAlias, + ValidateFunc: validateKMSKeyOrAlias, + }, "s3_uri": { Type: schema.TypeString, Required: true, + DiffSuppressFunc: func(k, oldValue, newValue string, d *schema.ResourceData) bool { + o := strings.TrimRight(oldValue, "/") + n := strings.TrimRight(newValue, "/") + return o == n + }, }, "output_s3_uri": { Type: schema.TypeString, @@ -304,7 +317,7 @@ func resourceDocumentClassifierRead(ctx context.Context, d *schema.ResourceData, return diag.Errorf("setting input_data_config: %s", err) } - if err := d.Set("output_data_config", flattenDocumentClassifierOutputDataConfig(d, out.OutputDataConfig)); err != nil { + if err := d.Set("output_data_config", flattenDocumentClassifierOutputDataConfig(out.OutputDataConfig)); err != nil { return diag.Errorf("setting output_data_config: %s", err) } @@ -744,14 +757,25 @@ func flattenDocumentClassifierInputDataConfig(apiObject *types.DocumentClassifie return []interface{}{m} } -func flattenDocumentClassifierOutputDataConfig(d *schema.ResourceData, apiObject *types.DocumentClassifierOutputDataConfig) []interface{} { +func flattenDocumentClassifierOutputDataConfig(apiObject *types.DocumentClassifierOutputDataConfig) []interface{} { if apiObject == nil || apiObject.S3Uri == nil { return nil } + // On return, `S3Uri` contains the full path of the output documents, not the storage location + s3Uri := aws.ToString(apiObject.S3Uri) m := map[string]interface{}{ - "s3_uri": d.Get("output_data_config.0.s3_uri"), - "output_s3_uri": aws.ToString(apiObject.S3Uri), + "output_s3_uri": s3Uri, + } + + re := regexp.MustCompile(`^(s3://[-a-z0-9.]{3,63}(/.+)?/)[-a-zA-Z0-9]+/output/output\.tar\.gz`) + match := re.FindStringSubmatch(s3Uri) + if match != nil && match[1] != "" { + m["s3_uri"] = match[1] + } + + if apiObject.KmsKeyId != nil { + m["kms_key_id"] = aws.ToString(apiObject.KmsKeyId) } return []interface{}{m} @@ -800,6 +824,10 @@ func expandDocumentClassifierOutputDataConfig(tfList []interface{}) *types.Docum S3Uri: aws.String(tfMap["s3_uri"].(string)), } + if v, ok := tfMap["kms_key_id"].(string); ok && v != "" { + a.KmsKeyId = aws.String(v) + } + return a } diff --git a/internal/service/comprehend/document_classifier_test.go b/internal/service/comprehend/document_classifier_test.go index 6f5443ca147..b4aa61f182b 100644 --- a/internal/service/comprehend/document_classifier_test.go +++ b/internal/service/comprehend/document_classifier_test.go @@ -444,13 +444,13 @@ func TestAccComprehendDocumentClassifier_outputDataConfig_basic(t *testing.T) { CheckDestroy: testAccCheckDocumentClassifierDestroy, Steps: []resource.TestStep{ { - Config: testAccDocumentClassifierConfig_outputDataConfig_basic(rName), + Config: testAccDocumentClassifierConfig_outputDataConfig_basic(rName, "outputs"), Check: resource.ComposeAggregateTestCheckFunc( testAccCheckDocumentClassifierExists(resourceName, &documentclassifier), testAccCheckDocumentClassifierPublishedVersions(resourceName, 1), resource.TestCheckResourceAttr(resourceName, "output_data_config.#", "1"), - resource.TestMatchResourceAttr(resourceName, "output_data_config.0.s3_uri", regexp.MustCompile(`s3:.+/[-A-Za-z0-9]+/`)), - resource.TestMatchResourceAttr(resourceName, "output_data_config.0.output_s3_uri", regexp.MustCompile(`s3:.+/[-A-Za-z0-9]+/output/output.tar.gz`)), + resource.TestMatchResourceAttr(resourceName, "output_data_config.0.s3_uri", regexp.MustCompile(`s3://.+/outputs`)), + resource.TestMatchResourceAttr(resourceName, "output_data_config.0.output_s3_uri", regexp.MustCompile(`s3://.+/outputs/[-A-Za-z0-9]+/output/output.tar.gz`)), ), }, { @@ -459,20 +459,14 @@ func TestAccComprehendDocumentClassifier_outputDataConfig_basic(t *testing.T) { ImportStateVerify: true, }, { - Config: testAccDocumentClassifierConfig_outputDataConfig_basic2(rName), - Check: resource.ComposeAggregateTestCheckFunc( - testAccCheckDocumentClassifierExists(resourceName, &documentclassifier), - testAccCheckDocumentClassifierPublishedVersions(resourceName, 1), - resource.TestCheckResourceAttr(resourceName, "output_data_config.#", "1"), - resource.TestMatchResourceAttr(resourceName, "output_data_config.0.s3_uri", regexp.MustCompile(`s3:.+/[-A-Za-z0-9]+/`)), - resource.TestMatchResourceAttr(resourceName, "output_data_config.0.output_s3_uri", regexp.MustCompile(`s3:.+/[-A-Za-z0-9]+/output/output.tar.gz`)), - ), + Config: testAccDocumentClassifierConfig_outputDataConfig_basic(rName, "outputs/"), + PlanOnly: true, }, }, }) } -func TestAccComprehendDocumentClassifier_multiLabel_labelDelimiter(t *testing.T) { +func TestAccComprehendDocumentClassifier_outputDataConfig_kmsKeyCreateID(t *testing.T) { if testing.Short() { t.Skip("skipping long-running test in short mode") } @@ -480,8 +474,6 @@ func TestAccComprehendDocumentClassifier_multiLabel_labelDelimiter(t *testing.T) var documentclassifier types.DocumentClassifierProperties rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix) resourceName := "aws_comprehend_document_classifier.test" - const delimiter = "~" - const delimiterUpdated = "/" resource.ParallelTest(t, resource.TestCase{ PreCheck: func() { @@ -494,26 +486,12 @@ func TestAccComprehendDocumentClassifier_multiLabel_labelDelimiter(t *testing.T) CheckDestroy: testAccCheckDocumentClassifierDestroy, Steps: []resource.TestStep{ { - Config: testAccDocumentClassifierConfig_multiLabel_delimiter(rName, delimiter), + Config: testAccDocumentClassifierConfig_outputDataConfig_kmsKeyId(rName), Check: resource.ComposeAggregateTestCheckFunc( testAccCheckDocumentClassifierExists(resourceName, &documentclassifier), testAccCheckDocumentClassifierPublishedVersions(resourceName, 1), - resource.TestCheckResourceAttr(resourceName, "name", rName), - resource.TestCheckResourceAttrPair(resourceName, "data_access_role_arn", "aws_iam_role.test", "arn"), - acctest.MatchResourceAttrRegionalARN(resourceName, "arn", "comprehend", regexp.MustCompile(fmt.Sprintf(`document-classifier/%s/version/%s$`, rName, uniqueIDPattern()))), - resource.TestCheckResourceAttr(resourceName, "input_data_config.#", "1"), - resource.TestCheckResourceAttr(resourceName, "input_data_config.0.augmented_manifests.#", "0"), - resource.TestCheckResourceAttr(resourceName, "input_data_config.0.data_format", string(types.DocumentClassifierDataFormatComprehendCsv)), - resource.TestCheckResourceAttr(resourceName, "input_data_config.0.label_delimiter", delimiter), - resource.TestCheckResourceAttr(resourceName, "language_code", "en"), - resource.TestCheckResourceAttr(resourceName, "mode", string(types.DocumentClassifierModeMultiLabel)), - resource.TestCheckResourceAttr(resourceName, "model_kms_key_id", ""), - resource.TestCheckResourceAttr(resourceName, "tags.%", "0"), - resource.TestCheckResourceAttr(resourceName, "tags_all.%", "0"), - acctest.CheckResourceAttrNameGenerated(resourceName, "version_name"), - resource.TestCheckResourceAttr(resourceName, "version_name_prefix", resource.UniqueIdPrefix), - resource.TestCheckResourceAttr(resourceName, "volume_kms_key_id", ""), - resource.TestCheckResourceAttr(resourceName, "vpc_config.#", "0"), + resource.TestCheckResourceAttr(resourceName, "output_data_config.#", "1"), + resource.TestCheckResourceAttrPair(resourceName, "output_data_config.0.kms_key_id", "aws_kms_key.output", "key_id"), ), }, { @@ -522,12 +500,39 @@ func TestAccComprehendDocumentClassifier_multiLabel_labelDelimiter(t *testing.T) ImportStateVerify: true, }, { - Config: testAccDocumentClassifierConfig_multiLabel_delimiter(rName, delimiterUpdated), + Config: testAccDocumentClassifierConfig_outputDataConfig_kmsKeyARN(rName), + PlanOnly: true, + }, + }, + }) +} + +func TestAccComprehendDocumentClassifier_outputDataConfig_kmsKeyCreateARN(t *testing.T) { + if testing.Short() { + t.Skip("skipping long-running test in short mode") + } + + var documentclassifier types.DocumentClassifierProperties + rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix) + resourceName := "aws_comprehend_document_classifier.test" + + resource.ParallelTest(t, resource.TestCase{ + PreCheck: func() { + acctest.PreCheck(t) + acctest.PreCheckPartitionHasService(names.ComprehendEndpointID, t) + testAccPreCheck(t) + }, + ErrorCheck: acctest.ErrorCheck(t, names.ComprehendEndpointID), + ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories, + CheckDestroy: testAccCheckDocumentClassifierDestroy, + Steps: []resource.TestStep{ + { + Config: testAccDocumentClassifierConfig_outputDataConfig_kmsKeyARN(rName), Check: resource.ComposeAggregateTestCheckFunc( testAccCheckDocumentClassifierExists(resourceName, &documentclassifier), - testAccCheckDocumentClassifierPublishedVersions(resourceName, 2), - resource.TestCheckResourceAttr(resourceName, "input_data_config.#", "1"), - resource.TestCheckResourceAttr(resourceName, "input_data_config.0.label_delimiter", delimiterUpdated), + testAccCheckDocumentClassifierPublishedVersions(resourceName, 1), + resource.TestCheckResourceAttr(resourceName, "output_data_config.#", "1"), + resource.TestCheckResourceAttrPair(resourceName, "output_data_config.0.kms_key_id", "aws_kms_key.output", "arn"), ), }, { @@ -535,11 +540,15 @@ func TestAccComprehendDocumentClassifier_multiLabel_labelDelimiter(t *testing.T) ImportState: true, ImportStateVerify: true, }, + { + Config: testAccDocumentClassifierConfig_outputDataConfig_kmsKeyId(rName), + PlanOnly: true, + }, }, }) } -func TestAccComprehendDocumentClassifier_KMSKeys_CreateIDs(t *testing.T) { +func TestAccComprehendDocumentClassifier_outputDataConfig_kmsKeyCreateAliasName(t *testing.T) { if testing.Short() { t.Skip("skipping long-running test in short mode") } @@ -559,12 +568,12 @@ func TestAccComprehendDocumentClassifier_KMSKeys_CreateIDs(t *testing.T) { CheckDestroy: testAccCheckDocumentClassifierDestroy, Steps: []resource.TestStep{ { - Config: testAccDocumentClassifierConfig_kmsKeyIds(rName), + Config: testAccDocumentClassifierConfig_outputDataConfig_kmsKeyAliasName(rName), Check: resource.ComposeAggregateTestCheckFunc( testAccCheckDocumentClassifierExists(resourceName, &documentclassifier), testAccCheckDocumentClassifierPublishedVersions(resourceName, 1), - resource.TestCheckResourceAttrPair(resourceName, "model_kms_key_id", "aws_kms_key.model", "key_id"), - resource.TestCheckResourceAttrPair(resourceName, "volume_kms_key_id", "aws_kms_key.volume", "key_id"), + resource.TestCheckResourceAttr(resourceName, "output_data_config.#", "1"), + resource.TestCheckResourceAttrPair(resourceName, "output_data_config.0.kms_key_id", "aws_kms_alias.output", "name"), ), }, { @@ -573,14 +582,14 @@ func TestAccComprehendDocumentClassifier_KMSKeys_CreateIDs(t *testing.T) { ImportStateVerify: true, }, { - Config: testAccDocumentClassifierConfig_kmsKeyARNs(rName), + Config: testAccDocumentClassifierConfig_outputDataConfig_kmsKeyAliasARN(rName), PlanOnly: true, }, }, }) } -func TestAccComprehendDocumentClassifier_KMSKeys_CreateARNs(t *testing.T) { +func TestAccComprehendDocumentClassifier_outputDataConfig_kmsKeyCreateAliasARN(t *testing.T) { if testing.Short() { t.Skip("skipping long-running test in short mode") } @@ -600,12 +609,12 @@ func TestAccComprehendDocumentClassifier_KMSKeys_CreateARNs(t *testing.T) { CheckDestroy: testAccCheckDocumentClassifierDestroy, Steps: []resource.TestStep{ { - Config: testAccDocumentClassifierConfig_kmsKeyARNs(rName), + Config: testAccDocumentClassifierConfig_outputDataConfig_kmsKeyAliasARN(rName), Check: resource.ComposeAggregateTestCheckFunc( testAccCheckDocumentClassifierExists(resourceName, &documentclassifier), testAccCheckDocumentClassifierPublishedVersions(resourceName, 1), - resource.TestCheckResourceAttrPair(resourceName, "model_kms_key_id", "aws_kms_key.model", "arn"), - resource.TestCheckResourceAttrPair(resourceName, "volume_kms_key_id", "aws_kms_key.volume", "arn"), + resource.TestCheckResourceAttr(resourceName, "output_data_config.#", "1"), + resource.TestCheckResourceAttrPair(resourceName, "output_data_config.0.kms_key_id", "aws_kms_alias.output", "arn"), ), }, { @@ -614,14 +623,14 @@ func TestAccComprehendDocumentClassifier_KMSKeys_CreateARNs(t *testing.T) { ImportStateVerify: true, }, { - Config: testAccDocumentClassifierConfig_kmsKeyIds(rName), + Config: testAccDocumentClassifierConfig_outputDataConfig_kmsKeyAliasName(rName), PlanOnly: true, }, }, }) } -func TestAccComprehendDocumentClassifier_KMSKeys_Add(t *testing.T) { +func TestAccComprehendDocumentClassifier_outputDataConfig_kmsKeyAdd(t *testing.T) { if testing.Short() { t.Skip("skipping long-running test in short mode") } @@ -641,21 +650,21 @@ func TestAccComprehendDocumentClassifier_KMSKeys_Add(t *testing.T) { CheckDestroy: testAccCheckDocumentClassifierDestroy, Steps: []resource.TestStep{ { - Config: testAccDocumentClassifierConfig_kmsKeys_None(rName), + Config: testAccDocumentClassifierConfig_outputDataConfig_kmsKeyNone(rName), Check: resource.ComposeAggregateTestCheckFunc( testAccCheckDocumentClassifierExists(resourceName, &v1), testAccCheckDocumentClassifierPublishedVersions(resourceName, 1), - resource.TestCheckResourceAttr(resourceName, "model_kms_key_id", ""), - resource.TestCheckResourceAttr(resourceName, "volume_kms_key_id", ""), + resource.TestCheckResourceAttr(resourceName, "output_data_config.#", "1"), + resource.TestCheckResourceAttr(resourceName, "output_data_config.0.kms_key_id", ""), ), }, { - Config: testAccDocumentClassifierConfig_kmsKeys_Set(rName), + Config: testAccDocumentClassifierConfig_outputDataConfig_kmsKeySet(rName), Check: resource.ComposeAggregateTestCheckFunc( testAccCheckDocumentClassifierExists(resourceName, &v2), testAccCheckDocumentClassifierPublishedVersions(resourceName, 2), - resource.TestCheckResourceAttrPair(resourceName, "model_kms_key_id", "aws_kms_key.model", "key_id"), - resource.TestCheckResourceAttrPair(resourceName, "volume_kms_key_id", "aws_kms_key.volume", "key_id"), + resource.TestCheckResourceAttr(resourceName, "output_data_config.#", "1"), + resource.TestCheckResourceAttrPair(resourceName, "output_data_config.0.kms_key_id", "aws_kms_key.output", "key_id"), ), }, { @@ -667,7 +676,7 @@ func TestAccComprehendDocumentClassifier_KMSKeys_Add(t *testing.T) { }) } -func TestAccComprehendDocumentClassifier_KMSKeys_Update(t *testing.T) { +func TestAccComprehendDocumentClassifier_outputDataConfig_kmsKeyUpdate(t *testing.T) { if testing.Short() { t.Skip("skipping long-running test in short mode") } @@ -687,21 +696,21 @@ func TestAccComprehendDocumentClassifier_KMSKeys_Update(t *testing.T) { CheckDestroy: testAccCheckDocumentClassifierDestroy, Steps: []resource.TestStep{ { - Config: testAccDocumentClassifierConfig_kmsKeys_Set(rName), + Config: testAccDocumentClassifierConfig_outputDataConfig_kmsKeySet(rName), Check: resource.ComposeAggregateTestCheckFunc( testAccCheckDocumentClassifierExists(resourceName, &v1), testAccCheckDocumentClassifierPublishedVersions(resourceName, 1), - resource.TestCheckResourceAttrPair(resourceName, "model_kms_key_id", "aws_kms_key.model", "key_id"), - resource.TestCheckResourceAttrPair(resourceName, "volume_kms_key_id", "aws_kms_key.volume", "key_id"), + resource.TestCheckResourceAttr(resourceName, "output_data_config.#", "1"), + resource.TestCheckResourceAttrPair(resourceName, "output_data_config.0.kms_key_id", "aws_kms_key.output", "key_id"), ), }, { - Config: testAccDocumentClassifierConfig_kmsKeys_Update(rName), + Config: testAccDocumentClassifierConfig_outputDataConfig_kmsKeyUpdate(rName), Check: resource.ComposeAggregateTestCheckFunc( testAccCheckDocumentClassifierExists(resourceName, &v2), testAccCheckDocumentClassifierPublishedVersions(resourceName, 2), - resource.TestCheckResourceAttrPair(resourceName, "model_kms_key_id", "aws_kms_key.model2", "key_id"), - resource.TestCheckResourceAttrPair(resourceName, "volume_kms_key_id", "aws_kms_key.volume2", "key_id"), + resource.TestCheckResourceAttr(resourceName, "output_data_config.#", "1"), + resource.TestCheckResourceAttrPair(resourceName, "output_data_config.0.kms_key_id", "aws_kms_key.output2", "key_id"), ), }, { @@ -713,7 +722,7 @@ func TestAccComprehendDocumentClassifier_KMSKeys_Update(t *testing.T) { }) } -func TestAccComprehendDocumentClassifier_KMSKeys_Remove(t *testing.T) { +func TestAccComprehendDocumentClassifier_outputDataConfig_kmsKeyRemove(t *testing.T) { if testing.Short() { t.Skip("skipping long-running test in short mode") } @@ -733,21 +742,21 @@ func TestAccComprehendDocumentClassifier_KMSKeys_Remove(t *testing.T) { CheckDestroy: testAccCheckDocumentClassifierDestroy, Steps: []resource.TestStep{ { - Config: testAccDocumentClassifierConfig_kmsKeys_Set(rName), + Config: testAccDocumentClassifierConfig_outputDataConfig_kmsKeySet(rName), Check: resource.ComposeAggregateTestCheckFunc( testAccCheckDocumentClassifierExists(resourceName, &v1), testAccCheckDocumentClassifierPublishedVersions(resourceName, 1), - resource.TestCheckResourceAttrPair(resourceName, "model_kms_key_id", "aws_kms_key.model", "key_id"), - resource.TestCheckResourceAttrPair(resourceName, "volume_kms_key_id", "aws_kms_key.volume", "key_id"), + resource.TestCheckResourceAttr(resourceName, "output_data_config.#", "1"), + resource.TestCheckResourceAttrPair(resourceName, "output_data_config.0.kms_key_id", "aws_kms_key.output", "key_id"), ), }, { - Config: testAccDocumentClassifierConfig_kmsKeys_None(rName), + Config: testAccDocumentClassifierConfig_outputDataConfig_kmsKeyNone(rName), Check: resource.ComposeAggregateTestCheckFunc( testAccCheckDocumentClassifierExists(resourceName, &v2), testAccCheckDocumentClassifierPublishedVersions(resourceName, 2), - resource.TestCheckResourceAttr(resourceName, "model_kms_key_id", ""), - resource.TestCheckResourceAttr(resourceName, "volume_kms_key_id", ""), + resource.TestCheckResourceAttr(resourceName, "output_data_config.#", "1"), + resource.TestCheckResourceAttr(resourceName, "output_data_config.0.kms_key_id", ""), ), }, { @@ -759,14 +768,16 @@ func TestAccComprehendDocumentClassifier_KMSKeys_Remove(t *testing.T) { }) } -func TestAccComprehendDocumentClassifier_VPCConfig_Create(t *testing.T) { +func TestAccComprehendDocumentClassifier_multiLabel_labelDelimiter(t *testing.T) { if testing.Short() { t.Skip("skipping long-running test in short mode") } - var dc1, dc2 types.DocumentClassifierProperties + var documentclassifier types.DocumentClassifierProperties rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix) resourceName := "aws_comprehend_document_classifier.test" + const delimiter = "~" + const delimiterUpdated = "/" resource.ParallelTest(t, resource.TestCase{ PreCheck: func() { @@ -779,16 +790,26 @@ func TestAccComprehendDocumentClassifier_VPCConfig_Create(t *testing.T) { CheckDestroy: testAccCheckDocumentClassifierDestroy, Steps: []resource.TestStep{ { - Config: testAccDocumentClassifierConfig_vpcConfig(rName), + Config: testAccDocumentClassifierConfig_multiLabel_delimiter(rName, delimiter), Check: resource.ComposeAggregateTestCheckFunc( - testAccCheckDocumentClassifierExists(resourceName, &dc1), + testAccCheckDocumentClassifierExists(resourceName, &documentclassifier), testAccCheckDocumentClassifierPublishedVersions(resourceName, 1), - resource.TestCheckResourceAttr(resourceName, "vpc_config.#", "1"), - resource.TestCheckResourceAttr(resourceName, "vpc_config.0.security_group_ids.#", "1"), - resource.TestCheckTypeSetElemAttrPair(resourceName, "vpc_config.0.security_group_ids.*", "aws_security_group.test.0", "id"), - resource.TestCheckResourceAttr(resourceName, "vpc_config.0.subnets.#", "2"), - resource.TestCheckTypeSetElemAttrPair(resourceName, "vpc_config.0.subnets.*", "aws_subnet.test.0", "id"), - resource.TestCheckTypeSetElemAttrPair(resourceName, "vpc_config.0.subnets.*", "aws_subnet.test.1", "id"), + resource.TestCheckResourceAttr(resourceName, "name", rName), + resource.TestCheckResourceAttrPair(resourceName, "data_access_role_arn", "aws_iam_role.test", "arn"), + acctest.MatchResourceAttrRegionalARN(resourceName, "arn", "comprehend", regexp.MustCompile(fmt.Sprintf(`document-classifier/%s/version/%s$`, rName, uniqueIDPattern()))), + resource.TestCheckResourceAttr(resourceName, "input_data_config.#", "1"), + resource.TestCheckResourceAttr(resourceName, "input_data_config.0.augmented_manifests.#", "0"), + resource.TestCheckResourceAttr(resourceName, "input_data_config.0.data_format", string(types.DocumentClassifierDataFormatComprehendCsv)), + resource.TestCheckResourceAttr(resourceName, "input_data_config.0.label_delimiter", delimiter), + resource.TestCheckResourceAttr(resourceName, "language_code", "en"), + resource.TestCheckResourceAttr(resourceName, "mode", string(types.DocumentClassifierModeMultiLabel)), + resource.TestCheckResourceAttr(resourceName, "model_kms_key_id", ""), + resource.TestCheckResourceAttr(resourceName, "tags.%", "0"), + resource.TestCheckResourceAttr(resourceName, "tags_all.%", "0"), + acctest.CheckResourceAttrNameGenerated(resourceName, "version_name"), + resource.TestCheckResourceAttr(resourceName, "version_name_prefix", resource.UniqueIdPrefix), + resource.TestCheckResourceAttr(resourceName, "volume_kms_key_id", ""), + resource.TestCheckResourceAttr(resourceName, "vpc_config.#", "0"), ), }, { @@ -797,16 +818,12 @@ func TestAccComprehendDocumentClassifier_VPCConfig_Create(t *testing.T) { ImportStateVerify: true, }, { - Config: testAccDocumentClassifierConfig_vpcConfig_Update(rName), + Config: testAccDocumentClassifierConfig_multiLabel_delimiter(rName, delimiterUpdated), Check: resource.ComposeAggregateTestCheckFunc( - testAccCheckDocumentClassifierExists(resourceName, &dc2), + testAccCheckDocumentClassifierExists(resourceName, &documentclassifier), testAccCheckDocumentClassifierPublishedVersions(resourceName, 2), - resource.TestCheckResourceAttr(resourceName, "vpc_config.#", "1"), - resource.TestCheckResourceAttr(resourceName, "vpc_config.0.security_group_ids.#", "1"), - resource.TestCheckTypeSetElemAttrPair(resourceName, "vpc_config.0.security_group_ids.*", "aws_security_group.test.1", "id"), - resource.TestCheckResourceAttr(resourceName, "vpc_config.0.subnets.#", "2"), - resource.TestCheckTypeSetElemAttrPair(resourceName, "vpc_config.0.subnets.*", "aws_subnet.test.2", "id"), - resource.TestCheckTypeSetElemAttrPair(resourceName, "vpc_config.0.subnets.*", "aws_subnet.test.3", "id"), + resource.TestCheckResourceAttr(resourceName, "input_data_config.#", "1"), + resource.TestCheckResourceAttr(resourceName, "input_data_config.0.label_delimiter", delimiterUpdated), ), }, { @@ -818,12 +835,12 @@ func TestAccComprehendDocumentClassifier_VPCConfig_Create(t *testing.T) { }) } -func TestAccComprehendDocumentClassifier_VPCConfig_Add(t *testing.T) { +func TestAccComprehendDocumentClassifier_KMSKeys_CreateIDs(t *testing.T) { if testing.Short() { t.Skip("skipping long-running test in short mode") } - var dc1, dc2 types.DocumentClassifierProperties + var documentclassifier types.DocumentClassifierProperties rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix) resourceName := "aws_comprehend_document_classifier.test" @@ -838,24 +855,12 @@ func TestAccComprehendDocumentClassifier_VPCConfig_Add(t *testing.T) { CheckDestroy: testAccCheckDocumentClassifierDestroy, Steps: []resource.TestStep{ { - Config: testAccDocumentClassifierConfig_vpcConfig_None(rName), + Config: testAccDocumentClassifierConfig_kmsKeyIds(rName), Check: resource.ComposeAggregateTestCheckFunc( - testAccCheckDocumentClassifierExists(resourceName, &dc1), + testAccCheckDocumentClassifierExists(resourceName, &documentclassifier), testAccCheckDocumentClassifierPublishedVersions(resourceName, 1), - resource.TestCheckResourceAttr(resourceName, "vpc_config.#", "0"), - ), - }, - { - Config: testAccDocumentClassifierConfig_vpcConfig(rName), - Check: resource.ComposeAggregateTestCheckFunc( - testAccCheckDocumentClassifierExists(resourceName, &dc2), - testAccCheckDocumentClassifierPublishedVersions(resourceName, 2), - resource.TestCheckResourceAttr(resourceName, "vpc_config.#", "1"), - resource.TestCheckResourceAttr(resourceName, "vpc_config.0.security_group_ids.#", "1"), - resource.TestCheckTypeSetElemAttrPair(resourceName, "vpc_config.0.security_group_ids.*", "aws_security_group.test.0", "id"), - resource.TestCheckResourceAttr(resourceName, "vpc_config.0.subnets.#", "2"), - resource.TestCheckTypeSetElemAttrPair(resourceName, "vpc_config.0.subnets.*", "aws_subnet.test.0", "id"), - resource.TestCheckTypeSetElemAttrPair(resourceName, "vpc_config.0.subnets.*", "aws_subnet.test.1", "id"), + resource.TestCheckResourceAttrPair(resourceName, "model_kms_key_id", "aws_kms_key.model", "key_id"), + resource.TestCheckResourceAttrPair(resourceName, "volume_kms_key_id", "aws_kms_key.volume", "key_id"), ), }, { @@ -863,16 +868,20 @@ func TestAccComprehendDocumentClassifier_VPCConfig_Add(t *testing.T) { ImportState: true, ImportStateVerify: true, }, + { + Config: testAccDocumentClassifierConfig_kmsKeyARNs(rName), + PlanOnly: true, + }, }, }) } -func TestAccComprehendDocumentClassifier_VPCConfig_Remove(t *testing.T) { +func TestAccComprehendDocumentClassifier_KMSKeys_CreateARNs(t *testing.T) { if testing.Short() { t.Skip("skipping long-running test in short mode") } - var dc1, dc2 types.DocumentClassifierProperties + var documentclassifier types.DocumentClassifierProperties rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix) resourceName := "aws_comprehend_document_classifier.test" @@ -887,16 +896,12 @@ func TestAccComprehendDocumentClassifier_VPCConfig_Remove(t *testing.T) { CheckDestroy: testAccCheckDocumentClassifierDestroy, Steps: []resource.TestStep{ { - Config: testAccDocumentClassifierConfig_vpcConfig(rName), + Config: testAccDocumentClassifierConfig_kmsKeyARNs(rName), Check: resource.ComposeAggregateTestCheckFunc( - testAccCheckDocumentClassifierExists(resourceName, &dc1), + testAccCheckDocumentClassifierExists(resourceName, &documentclassifier), testAccCheckDocumentClassifierPublishedVersions(resourceName, 1), - resource.TestCheckResourceAttr(resourceName, "vpc_config.#", "1"), - resource.TestCheckResourceAttr(resourceName, "vpc_config.0.security_group_ids.#", "1"), - resource.TestCheckTypeSetElemAttrPair(resourceName, "vpc_config.0.security_group_ids.*", "aws_security_group.test.0", "id"), - resource.TestCheckResourceAttr(resourceName, "vpc_config.0.subnets.#", "2"), - resource.TestCheckTypeSetElemAttrPair(resourceName, "vpc_config.0.subnets.*", "aws_subnet.test.0", "id"), - resource.TestCheckTypeSetElemAttrPair(resourceName, "vpc_config.0.subnets.*", "aws_subnet.test.1", "id"), + resource.TestCheckResourceAttrPair(resourceName, "model_kms_key_id", "aws_kms_key.model", "arn"), + resource.TestCheckResourceAttrPair(resourceName, "volume_kms_key_id", "aws_kms_key.volume", "arn"), ), }, { @@ -905,28 +910,19 @@ func TestAccComprehendDocumentClassifier_VPCConfig_Remove(t *testing.T) { ImportStateVerify: true, }, { - Config: testAccDocumentClassifierConfig_vpcConfig_None(rName), - Check: resource.ComposeAggregateTestCheckFunc( - testAccCheckDocumentClassifierExists(resourceName, &dc2), - testAccCheckDocumentClassifierPublishedVersions(resourceName, 2), - resource.TestCheckResourceAttr(resourceName, "vpc_config.#", "0"), - ), - }, - { - ResourceName: resourceName, - ImportState: true, - ImportStateVerify: true, + Config: testAccDocumentClassifierConfig_kmsKeyIds(rName), + PlanOnly: true, }, }, }) } -func TestAccComprehendDocumentClassifier_tags(t *testing.T) { +func TestAccComprehendDocumentClassifier_KMSKeys_Add(t *testing.T) { if testing.Short() { t.Skip("skipping long-running test in short mode") } - var v1, v2, v3 types.DocumentClassifierProperties + var v1, v2 types.DocumentClassifierProperties rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix) resourceName := "aws_comprehend_document_classifier.test" @@ -941,50 +937,38 @@ func TestAccComprehendDocumentClassifier_tags(t *testing.T) { CheckDestroy: testAccCheckDocumentClassifierDestroy, Steps: []resource.TestStep{ { - Config: testAccDocumentClassifierConfig_tags1(rName, "key1", "value1"), + Config: testAccDocumentClassifierConfig_kmsKeys_None(rName), Check: resource.ComposeAggregateTestCheckFunc( testAccCheckDocumentClassifierExists(resourceName, &v1), testAccCheckDocumentClassifierPublishedVersions(resourceName, 1), - resource.TestCheckResourceAttr(resourceName, "tags.%", "1"), - resource.TestCheckResourceAttr(resourceName, "tags.key1", "value1"), + resource.TestCheckResourceAttr(resourceName, "model_kms_key_id", ""), + resource.TestCheckResourceAttr(resourceName, "volume_kms_key_id", ""), ), }, { - ResourceName: resourceName, - ImportState: true, - ImportStateVerify: true, - }, - { - Config: testAccDocumentClassifierConfig_tags2(rName, "key1", "value1updated", "key2", "value2"), + Config: testAccDocumentClassifierConfig_kmsKeys_Set(rName), Check: resource.ComposeAggregateTestCheckFunc( testAccCheckDocumentClassifierExists(resourceName, &v2), - testAccCheckDocumentClassifierNotRecreated(&v1, &v2), - testAccCheckDocumentClassifierPublishedVersions(resourceName, 1), - resource.TestCheckResourceAttr(resourceName, "tags.%", "2"), - resource.TestCheckResourceAttr(resourceName, "tags.key1", "value1updated"), - resource.TestCheckResourceAttr(resourceName, "tags.key2", "value2"), + testAccCheckDocumentClassifierPublishedVersions(resourceName, 2), + resource.TestCheckResourceAttrPair(resourceName, "model_kms_key_id", "aws_kms_key.model", "key_id"), + resource.TestCheckResourceAttrPair(resourceName, "volume_kms_key_id", "aws_kms_key.volume", "key_id"), ), }, { - Config: testAccDocumentClassifierConfig_tags1(rName, "key2", "value2"), - Check: resource.ComposeAggregateTestCheckFunc( - testAccCheckDocumentClassifierExists(resourceName, &v3), - testAccCheckDocumentClassifierNotRecreated(&v2, &v3), - testAccCheckDocumentClassifierPublishedVersions(resourceName, 1), - resource.TestCheckResourceAttr(resourceName, "tags.%", "1"), - resource.TestCheckResourceAttr(resourceName, "tags.key2", "value2"), - ), + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, }, }, }) } -func TestAccComprehendDocumentClassifier_DefaultTags_providerOnly(t *testing.T) { +func TestAccComprehendDocumentClassifier_KMSKeys_Update(t *testing.T) { if testing.Short() { t.Skip("skipping long-running test in short mode") } - var v1, v2, v3 types.DocumentClassifierProperties + var v1, v2 types.DocumentClassifierProperties rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix) resourceName := "aws_comprehend_document_classifier.test" @@ -999,24 +983,336 @@ func TestAccComprehendDocumentClassifier_DefaultTags_providerOnly(t *testing.T) CheckDestroy: testAccCheckDocumentClassifierDestroy, Steps: []resource.TestStep{ { - Config: acctest.ConfigCompose( - acctest.ConfigDefaultTags_Tags1("providerkey1", "providervalue1"), - testAccDocumentClassifierConfig_tags0(rName), - ), + Config: testAccDocumentClassifierConfig_kmsKeys_Set(rName), Check: resource.ComposeAggregateTestCheckFunc( testAccCheckDocumentClassifierExists(resourceName, &v1), testAccCheckDocumentClassifierPublishedVersions(resourceName, 1), - resource.TestCheckResourceAttr(resourceName, "tags.%", "0"), - resource.TestCheckResourceAttr(resourceName, "tags_all.%", "1"), - resource.TestCheckResourceAttr(resourceName, "tags_all.providerkey1", "providervalue1"), + resource.TestCheckResourceAttrPair(resourceName, "model_kms_key_id", "aws_kms_key.model", "key_id"), + resource.TestCheckResourceAttrPair(resourceName, "volume_kms_key_id", "aws_kms_key.volume", "key_id"), ), }, { - ResourceName: resourceName, - ImportState: true, - ImportStateVerify: true, - }, - { + Config: testAccDocumentClassifierConfig_kmsKeys_Update(rName), + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckDocumentClassifierExists(resourceName, &v2), + testAccCheckDocumentClassifierPublishedVersions(resourceName, 2), + resource.TestCheckResourceAttrPair(resourceName, "model_kms_key_id", "aws_kms_key.model2", "key_id"), + resource.TestCheckResourceAttrPair(resourceName, "volume_kms_key_id", "aws_kms_key.volume2", "key_id"), + ), + }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + }, + }, + }) +} + +func TestAccComprehendDocumentClassifier_KMSKeys_Remove(t *testing.T) { + if testing.Short() { + t.Skip("skipping long-running test in short mode") + } + + var v1, v2 types.DocumentClassifierProperties + rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix) + resourceName := "aws_comprehend_document_classifier.test" + + resource.ParallelTest(t, resource.TestCase{ + PreCheck: func() { + acctest.PreCheck(t) + acctest.PreCheckPartitionHasService(names.ComprehendEndpointID, t) + testAccPreCheck(t) + }, + ErrorCheck: acctest.ErrorCheck(t, names.ComprehendEndpointID), + ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories, + CheckDestroy: testAccCheckDocumentClassifierDestroy, + Steps: []resource.TestStep{ + { + Config: testAccDocumentClassifierConfig_kmsKeys_Set(rName), + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckDocumentClassifierExists(resourceName, &v1), + testAccCheckDocumentClassifierPublishedVersions(resourceName, 1), + resource.TestCheckResourceAttrPair(resourceName, "model_kms_key_id", "aws_kms_key.model", "key_id"), + resource.TestCheckResourceAttrPair(resourceName, "volume_kms_key_id", "aws_kms_key.volume", "key_id"), + ), + }, + { + Config: testAccDocumentClassifierConfig_kmsKeys_None(rName), + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckDocumentClassifierExists(resourceName, &v2), + testAccCheckDocumentClassifierPublishedVersions(resourceName, 2), + resource.TestCheckResourceAttr(resourceName, "model_kms_key_id", ""), + resource.TestCheckResourceAttr(resourceName, "volume_kms_key_id", ""), + ), + }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + }, + }, + }) +} + +func TestAccComprehendDocumentClassifier_VPCConfig_Create(t *testing.T) { + if testing.Short() { + t.Skip("skipping long-running test in short mode") + } + + var dc1, dc2 types.DocumentClassifierProperties + rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix) + resourceName := "aws_comprehend_document_classifier.test" + + resource.ParallelTest(t, resource.TestCase{ + PreCheck: func() { + acctest.PreCheck(t) + acctest.PreCheckPartitionHasService(names.ComprehendEndpointID, t) + testAccPreCheck(t) + }, + ErrorCheck: acctest.ErrorCheck(t, names.ComprehendEndpointID), + ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories, + CheckDestroy: testAccCheckDocumentClassifierDestroy, + Steps: []resource.TestStep{ + { + Config: testAccDocumentClassifierConfig_vpcConfig(rName), + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckDocumentClassifierExists(resourceName, &dc1), + testAccCheckDocumentClassifierPublishedVersions(resourceName, 1), + resource.TestCheckResourceAttr(resourceName, "vpc_config.#", "1"), + resource.TestCheckResourceAttr(resourceName, "vpc_config.0.security_group_ids.#", "1"), + resource.TestCheckTypeSetElemAttrPair(resourceName, "vpc_config.0.security_group_ids.*", "aws_security_group.test.0", "id"), + resource.TestCheckResourceAttr(resourceName, "vpc_config.0.subnets.#", "2"), + resource.TestCheckTypeSetElemAttrPair(resourceName, "vpc_config.0.subnets.*", "aws_subnet.test.0", "id"), + resource.TestCheckTypeSetElemAttrPair(resourceName, "vpc_config.0.subnets.*", "aws_subnet.test.1", "id"), + ), + }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + }, + { + Config: testAccDocumentClassifierConfig_vpcConfig_Update(rName), + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckDocumentClassifierExists(resourceName, &dc2), + testAccCheckDocumentClassifierPublishedVersions(resourceName, 2), + resource.TestCheckResourceAttr(resourceName, "vpc_config.#", "1"), + resource.TestCheckResourceAttr(resourceName, "vpc_config.0.security_group_ids.#", "1"), + resource.TestCheckTypeSetElemAttrPair(resourceName, "vpc_config.0.security_group_ids.*", "aws_security_group.test.1", "id"), + resource.TestCheckResourceAttr(resourceName, "vpc_config.0.subnets.#", "2"), + resource.TestCheckTypeSetElemAttrPair(resourceName, "vpc_config.0.subnets.*", "aws_subnet.test.2", "id"), + resource.TestCheckTypeSetElemAttrPair(resourceName, "vpc_config.0.subnets.*", "aws_subnet.test.3", "id"), + ), + }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + }, + }, + }) +} + +func TestAccComprehendDocumentClassifier_VPCConfig_Add(t *testing.T) { + if testing.Short() { + t.Skip("skipping long-running test in short mode") + } + + var dc1, dc2 types.DocumentClassifierProperties + rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix) + resourceName := "aws_comprehend_document_classifier.test" + + resource.ParallelTest(t, resource.TestCase{ + PreCheck: func() { + acctest.PreCheck(t) + acctest.PreCheckPartitionHasService(names.ComprehendEndpointID, t) + testAccPreCheck(t) + }, + ErrorCheck: acctest.ErrorCheck(t, names.ComprehendEndpointID), + ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories, + CheckDestroy: testAccCheckDocumentClassifierDestroy, + Steps: []resource.TestStep{ + { + Config: testAccDocumentClassifierConfig_vpcConfig_None(rName), + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckDocumentClassifierExists(resourceName, &dc1), + testAccCheckDocumentClassifierPublishedVersions(resourceName, 1), + resource.TestCheckResourceAttr(resourceName, "vpc_config.#", "0"), + ), + }, + { + Config: testAccDocumentClassifierConfig_vpcConfig(rName), + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckDocumentClassifierExists(resourceName, &dc2), + testAccCheckDocumentClassifierPublishedVersions(resourceName, 2), + resource.TestCheckResourceAttr(resourceName, "vpc_config.#", "1"), + resource.TestCheckResourceAttr(resourceName, "vpc_config.0.security_group_ids.#", "1"), + resource.TestCheckTypeSetElemAttrPair(resourceName, "vpc_config.0.security_group_ids.*", "aws_security_group.test.0", "id"), + resource.TestCheckResourceAttr(resourceName, "vpc_config.0.subnets.#", "2"), + resource.TestCheckTypeSetElemAttrPair(resourceName, "vpc_config.0.subnets.*", "aws_subnet.test.0", "id"), + resource.TestCheckTypeSetElemAttrPair(resourceName, "vpc_config.0.subnets.*", "aws_subnet.test.1", "id"), + ), + }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + }, + }, + }) +} + +func TestAccComprehendDocumentClassifier_VPCConfig_Remove(t *testing.T) { + if testing.Short() { + t.Skip("skipping long-running test in short mode") + } + + var dc1, dc2 types.DocumentClassifierProperties + rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix) + resourceName := "aws_comprehend_document_classifier.test" + + resource.ParallelTest(t, resource.TestCase{ + PreCheck: func() { + acctest.PreCheck(t) + acctest.PreCheckPartitionHasService(names.ComprehendEndpointID, t) + testAccPreCheck(t) + }, + ErrorCheck: acctest.ErrorCheck(t, names.ComprehendEndpointID), + ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories, + CheckDestroy: testAccCheckDocumentClassifierDestroy, + Steps: []resource.TestStep{ + { + Config: testAccDocumentClassifierConfig_vpcConfig(rName), + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckDocumentClassifierExists(resourceName, &dc1), + testAccCheckDocumentClassifierPublishedVersions(resourceName, 1), + resource.TestCheckResourceAttr(resourceName, "vpc_config.#", "1"), + resource.TestCheckResourceAttr(resourceName, "vpc_config.0.security_group_ids.#", "1"), + resource.TestCheckTypeSetElemAttrPair(resourceName, "vpc_config.0.security_group_ids.*", "aws_security_group.test.0", "id"), + resource.TestCheckResourceAttr(resourceName, "vpc_config.0.subnets.#", "2"), + resource.TestCheckTypeSetElemAttrPair(resourceName, "vpc_config.0.subnets.*", "aws_subnet.test.0", "id"), + resource.TestCheckTypeSetElemAttrPair(resourceName, "vpc_config.0.subnets.*", "aws_subnet.test.1", "id"), + ), + }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + }, + { + Config: testAccDocumentClassifierConfig_vpcConfig_None(rName), + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckDocumentClassifierExists(resourceName, &dc2), + testAccCheckDocumentClassifierPublishedVersions(resourceName, 2), + resource.TestCheckResourceAttr(resourceName, "vpc_config.#", "0"), + ), + }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + }, + }, + }) +} + +func TestAccComprehendDocumentClassifier_tags(t *testing.T) { + if testing.Short() { + t.Skip("skipping long-running test in short mode") + } + + var v1, v2, v3 types.DocumentClassifierProperties + rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix) + resourceName := "aws_comprehend_document_classifier.test" + + resource.ParallelTest(t, resource.TestCase{ + PreCheck: func() { + acctest.PreCheck(t) + acctest.PreCheckPartitionHasService(names.ComprehendEndpointID, t) + testAccPreCheck(t) + }, + ErrorCheck: acctest.ErrorCheck(t, names.ComprehendEndpointID), + ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories, + CheckDestroy: testAccCheckDocumentClassifierDestroy, + Steps: []resource.TestStep{ + { + Config: testAccDocumentClassifierConfig_tags1(rName, "key1", "value1"), + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckDocumentClassifierExists(resourceName, &v1), + testAccCheckDocumentClassifierPublishedVersions(resourceName, 1), + resource.TestCheckResourceAttr(resourceName, "tags.%", "1"), + resource.TestCheckResourceAttr(resourceName, "tags.key1", "value1"), + ), + }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + }, + { + Config: testAccDocumentClassifierConfig_tags2(rName, "key1", "value1updated", "key2", "value2"), + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckDocumentClassifierExists(resourceName, &v2), + testAccCheckDocumentClassifierNotRecreated(&v1, &v2), + testAccCheckDocumentClassifierPublishedVersions(resourceName, 1), + resource.TestCheckResourceAttr(resourceName, "tags.%", "2"), + resource.TestCheckResourceAttr(resourceName, "tags.key1", "value1updated"), + resource.TestCheckResourceAttr(resourceName, "tags.key2", "value2"), + ), + }, + { + Config: testAccDocumentClassifierConfig_tags1(rName, "key2", "value2"), + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckDocumentClassifierExists(resourceName, &v3), + testAccCheckDocumentClassifierNotRecreated(&v2, &v3), + testAccCheckDocumentClassifierPublishedVersions(resourceName, 1), + resource.TestCheckResourceAttr(resourceName, "tags.%", "1"), + resource.TestCheckResourceAttr(resourceName, "tags.key2", "value2"), + ), + }, + }, + }) +} + +func TestAccComprehendDocumentClassifier_DefaultTags_providerOnly(t *testing.T) { + if testing.Short() { + t.Skip("skipping long-running test in short mode") + } + + var v1, v2, v3 types.DocumentClassifierProperties + rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix) + resourceName := "aws_comprehend_document_classifier.test" + + resource.ParallelTest(t, resource.TestCase{ + PreCheck: func() { + acctest.PreCheck(t) + acctest.PreCheckPartitionHasService(names.ComprehendEndpointID, t) + testAccPreCheck(t) + }, + ErrorCheck: acctest.ErrorCheck(t, names.ComprehendEndpointID), + ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories, + CheckDestroy: testAccCheckDocumentClassifierDestroy, + Steps: []resource.TestStep{ + { + Config: acctest.ConfigCompose( + acctest.ConfigDefaultTags_Tags1("providerkey1", "providervalue1"), + testAccDocumentClassifierConfig_tags0(rName), + ), + Check: resource.ComposeAggregateTestCheckFunc( + testAccCheckDocumentClassifierExists(resourceName, &v1), + testAccCheckDocumentClassifierPublishedVersions(resourceName, 1), + resource.TestCheckResourceAttr(resourceName, "tags.%", "0"), + resource.TestCheckResourceAttr(resourceName, "tags_all.%", "1"), + resource.TestCheckResourceAttr(resourceName, "tags_all.providerkey1", "providervalue1"), + ), + }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + }, + { Config: acctest.ConfigCompose( acctest.ConfigDefaultTags_Tags2("providerkey1", "providervalue1", "providerkey2", "providervalue2"), testAccDocumentClassifierConfig_tags0(rName), @@ -1170,19 +1466,256 @@ func testAccCheckDocumentClassifierPublishedVersions(name string, expected int) count += len(output.DocumentClassifierPropertiesList) } - if count != expected { - return fmt.Errorf("expected %d published versions, found %d", expected, count) - } + if count != expected { + return fmt.Errorf("expected %d published versions, found %d", expected, count) + } + + return nil + } +} + +func testAccDocumentClassifierConfig_basic(rName string) string { + return acctest.ConfigCompose( + testAccDocumentClassifierBasicRoleConfig(rName), + testAccDocumentClassifierS3BucketConfig(rName), + testAccDocumentClassifierConfig_S3_documents, + fmt.Sprintf(` +data "aws_partition" "current" {} + +resource "aws_comprehend_document_classifier" "test" { + name = %[1]q + + data_access_role_arn = aws_iam_role.test.arn + + language_code = "en" + input_data_config { + s3_uri = "s3://${aws_s3_bucket.test.bucket}/${aws_s3_object.documents.id}" + } + + depends_on = [ + aws_iam_role_policy.test, + ] +} +`, rName)) +} + +func testAccDocumentClassifierConfig_Mode_singleLabel(rName string) string { + return acctest.ConfigCompose( + testAccDocumentClassifierBasicRoleConfig(rName), + testAccDocumentClassifierS3BucketConfig(rName), + testAccDocumentClassifierConfig_S3_documents, + fmt.Sprintf(` +data "aws_partition" "current" {} + +resource "aws_comprehend_document_classifier" "test" { + name = %[1]q + + data_access_role_arn = aws_iam_role.test.arn + + language_code = "en" + mode = "MULTI_CLASS" + input_data_config { + s3_uri = "s3://${aws_s3_bucket.test.bucket}/${aws_s3_object.documents.id}" + } + + depends_on = [ + aws_iam_role_policy.test, + ] +} +`, rName)) +} + +func testAccDocumentClassifierConfig_versionName(rName, vName, key, value string) string { + return acctest.ConfigCompose( + testAccDocumentClassifierBasicRoleConfig(rName), + testAccDocumentClassifierS3BucketConfig(rName), + testAccDocumentClassifierConfig_S3_documents, + fmt.Sprintf(` +data "aws_partition" "current" {} + +resource "aws_comprehend_document_classifier" "test" { + name = %[1]q + version_name = %[2]q + + data_access_role_arn = aws_iam_role.test.arn + + tags = { + %[3]q = %[4]q + } + + language_code = "en" + input_data_config { + s3_uri = "s3://${aws_s3_bucket.test.bucket}/${aws_s3_object.documents.id}" + } + + depends_on = [ + aws_iam_role_policy.test, + ] +} +`, rName, vName, key, value)) +} + +func testAccDocumentClassifierConfig_versionNameEmpty(rName string) string { + return acctest.ConfigCompose( + testAccDocumentClassifierBasicRoleConfig(rName), + testAccDocumentClassifierS3BucketConfig(rName), + testAccDocumentClassifierConfig_S3_documents, + fmt.Sprintf(` +data "aws_partition" "current" {} + +resource "aws_comprehend_document_classifier" "test" { + name = %[1]q + version_name = "" + + data_access_role_arn = aws_iam_role.test.arn + + language_code = "en" + input_data_config { + s3_uri = "s3://${aws_s3_bucket.test.bucket}/${aws_s3_object.documents.id}" + } + + depends_on = [ + aws_iam_role_policy.test, + ] +} +`, rName)) +} + +func testAccDocumentClassifierConfig_versionNameNotSet(rName string) string { + return acctest.ConfigCompose( + testAccDocumentClassifierBasicRoleConfig(rName), + testAccDocumentClassifierS3BucketConfig(rName), + testAccDocumentClassifierConfig_S3_documents, + fmt.Sprintf(` +data "aws_partition" "current" {} + +resource "aws_comprehend_document_classifier" "test" { + name = %[1]q + + data_access_role_arn = aws_iam_role.test.arn + + language_code = "en" + input_data_config { + s3_uri = "s3://${aws_s3_bucket.test.bucket}/${aws_s3_object.documents.id}" + } + + depends_on = [ + aws_iam_role_policy.test, + ] +} +`, rName)) +} + +func testAccDocumentClassifierConfig_versioNamePrefix(rName, versionNamePrefix string) string { + return acctest.ConfigCompose( + testAccDocumentClassifierBasicRoleConfig(rName), + testAccDocumentClassifierS3BucketConfig(rName), + testAccDocumentClassifierConfig_S3_documents, + fmt.Sprintf(` +data "aws_partition" "current" {} + +resource "aws_comprehend_document_classifier" "test" { + name = %[1]q + version_name_prefix = %[2]q + + data_access_role_arn = aws_iam_role.test.arn + + language_code = "en" + input_data_config { + s3_uri = "s3://${aws_s3_bucket.test.bucket}/${aws_s3_object.documents.id}" + } + + depends_on = [ + aws_iam_role_policy.test, + ] +} +`, rName, versionNamePrefix)) +} + +func testAccDocumentClassifierConfig_testDocuments(rName string) string { + return acctest.ConfigCompose( + testAccDocumentClassifierBasicRoleConfig(rName), + testAccDocumentClassifierS3BucketConfig(rName), + testAccDocumentClassifierConfig_S3_documents, + fmt.Sprintf(` +data "aws_partition" "current" {} + +resource "aws_comprehend_document_classifier" "test" { + name = %[1]q + + data_access_role_arn = aws_iam_role.test.arn + + language_code = "en" + input_data_config { + s3_uri = "s3://${aws_s3_bucket.test.bucket}/${aws_s3_object.documents.id}" + test_s3_uri = "s3://${aws_s3_bucket.test.bucket}/${aws_s3_object.documents.id}" + } + + depends_on = [ + aws_iam_role_policy.test, + ] +} +`, rName)) +} + +func testAccDocumentClassifierConfig_modeDefault_ValidateNoDelimiterSet(rName, delimiter string) string { + return acctest.ConfigCompose( + testAccDocumentClassifierBasicRoleConfig(rName), + testAccDocumentClassifierS3BucketConfig(rName), + testAccDocumentClassifierConfig_S3_documents, + fmt.Sprintf(` +data "aws_partition" "current" {} + +resource "aws_comprehend_document_classifier" "test" { + name = %[1]q + + data_access_role_arn = aws_iam_role.test.arn + + language_code = "en" + input_data_config { + s3_uri = "s3://${aws_s3_bucket.test.bucket}/${aws_s3_object.documents.id}" + label_delimiter = %q + } + + depends_on = [ + aws_iam_role_policy.test, + ] +} +`, rName, delimiter)) +} + +func testAccDocumentClassifierConfig_modeSingleLabel_ValidateNoDelimiterSet(rName, delimiter string) string { + return acctest.ConfigCompose( + testAccDocumentClassifierBasicRoleConfig(rName), + testAccDocumentClassifierS3BucketConfig(rName), + testAccDocumentClassifierConfig_S3_documents, + fmt.Sprintf(` +data "aws_partition" "current" {} + +resource "aws_comprehend_document_classifier" "test" { + name = %[1]q + + data_access_role_arn = aws_iam_role.test.arn - return nil - } + language_code = "en" + mode = "MULTI_CLASS" + input_data_config { + s3_uri = "s3://${aws_s3_bucket.test.bucket}/${aws_s3_object.documents.id}" + label_delimiter = %q + } + + depends_on = [ + aws_iam_role_policy.test, + ] +} +`, rName, delimiter)) } -func testAccDocumentClassifierConfig_basic(rName string) string { +func testAccDocumentClassifierConfig_multiLabel_basic(rName string) string { return acctest.ConfigCompose( testAccDocumentClassifierBasicRoleConfig(rName), testAccDocumentClassifierS3BucketConfig(rName), - testAccDocumentClassifierConfig_S3_documents, + testAccDocumentClassifierConfig_S3_multilabel, fmt.Sprintf(` data "aws_partition" "current" {} @@ -1192,8 +1725,9 @@ resource "aws_comprehend_document_classifier" "test" { data_access_role_arn = aws_iam_role.test.arn language_code = "en" + mode = "MULTI_LABEL" input_data_config { - s3_uri = "s3://${aws_s3_bucket.test.bucket}/${aws_s3_object.documents.id}" + s3_uri = "s3://${aws_s3_bucket.test.bucket}/${aws_s3_object.multilabel.id}" } depends_on = [ @@ -1203,11 +1737,15 @@ resource "aws_comprehend_document_classifier" "test" { `, rName)) } -func testAccDocumentClassifierConfig_Mode_singleLabel(rName string) string { +func testAccDocumentClassifierConfig_multiLabel_defaultDelimiter(rName string) string { + return testAccDocumentClassifierConfig_multiLabel_delimiter(rName, tfcomprehend.DocumentClassifierLabelSeparatorDefault) +} + +func testAccDocumentClassifierConfig_multiLabel_delimiter(rName, delimiter string) string { return acctest.ConfigCompose( testAccDocumentClassifierBasicRoleConfig(rName), testAccDocumentClassifierS3BucketConfig(rName), - testAccDocumentClassifierConfig_S3_documents, + testAccDocumentClassifierConfig_S3_multilabel, fmt.Sprintf(` data "aws_partition" "current" {} @@ -1217,19 +1755,20 @@ resource "aws_comprehend_document_classifier" "test" { data_access_role_arn = aws_iam_role.test.arn language_code = "en" - mode = "MULTI_CLASS" + mode = "MULTI_LABEL" input_data_config { - s3_uri = "s3://${aws_s3_bucket.test.bucket}/${aws_s3_object.documents.id}" + s3_uri = "s3://${aws_s3_bucket.test.bucket}/${aws_s3_object.multilabel.id}" + label_delimiter = %[2]q } depends_on = [ aws_iam_role_policy.test, ] } -`, rName)) +`, rName, delimiter)) } -func testAccDocumentClassifierConfig_versionName(rName, vName, key, value string) string { +func testAccDocumentClassifierConfig_kmsKeyIds(rName string) string { return acctest.ConfigCompose( testAccDocumentClassifierBasicRoleConfig(rName), testAccDocumentClassifierS3BucketConfig(rName), @@ -1238,14 +1777,12 @@ func testAccDocumentClassifierConfig_versionName(rName, vName, key, value string data "aws_partition" "current" {} resource "aws_comprehend_document_classifier" "test" { - name = %[1]q - version_name = %[2]q + name = %[1]q data_access_role_arn = aws_iam_role.test.arn - tags = { - %[3]q = %[4]q - } + model_kms_key_id = aws_kms_key.model.key_id + volume_kms_key_id = aws_kms_key.volume.key_id language_code = "en" input_data_config { @@ -1254,12 +1791,48 @@ resource "aws_comprehend_document_classifier" "test" { depends_on = [ aws_iam_role_policy.test, + aws_iam_role_policy.kms_keys, ] } -`, rName, vName, key, value)) + +resource "aws_kms_key" "model" { + deletion_window_in_days = 7 } -func testAccDocumentClassifierConfig_versionNameEmpty(rName string) string { +resource "aws_kms_key" "volume" { + deletion_window_in_days = 7 +} + +resource "aws_iam_role_policy" "kms_keys" { + role = aws_iam_role.test.name + + policy = data.aws_iam_policy_document.kms_keys.json +} + +data "aws_iam_policy_document" "kms_keys" { + statement { + actions = [ + "*", + ] + + resources = [ + aws_kms_key.model.arn, + ] + } + statement { + actions = [ + "*", + ] + + resources = [ + aws_kms_key.volume.arn, + ] + } +} +`, rName)) +} + +func testAccDocumentClassifierConfig_kmsKeyARNs(rName string) string { return acctest.ConfigCompose( testAccDocumentClassifierBasicRoleConfig(rName), testAccDocumentClassifierS3BucketConfig(rName), @@ -1268,11 +1841,13 @@ func testAccDocumentClassifierConfig_versionNameEmpty(rName string) string { data "aws_partition" "current" {} resource "aws_comprehend_document_classifier" "test" { - name = %[1]q - version_name = "" + name = %[1]q data_access_role_arn = aws_iam_role.test.arn + model_kms_key_id = aws_kms_key.model.arn + volume_kms_key_id = aws_kms_key.volume.arn + language_code = "en" input_data_config { s3_uri = "s3://${aws_s3_bucket.test.bucket}/${aws_s3_object.documents.id}" @@ -1282,10 +1857,45 @@ resource "aws_comprehend_document_classifier" "test" { aws_iam_role_policy.test, ] } + +resource "aws_iam_role_policy" "kms_keys" { + role = aws_iam_role.test.name + + policy = data.aws_iam_policy_document.kms_keys.json +} + +data "aws_iam_policy_document" "kms_keys" { + statement { + actions = [ + "*", + ] + + resources = [ + aws_kms_key.model.arn, + ] + } + statement { + actions = [ + "*", + ] + + resources = [ + aws_kms_key.volume.arn, + ] + } +} + +resource "aws_kms_key" "model" { + deletion_window_in_days = 7 +} + +resource "aws_kms_key" "volume" { + deletion_window_in_days = 7 +} `, rName)) } -func testAccDocumentClassifierConfig_versionNameNotSet(rName string) string { +func testAccDocumentClassifierConfig_kmsKeys_None(rName string) string { return acctest.ConfigCompose( testAccDocumentClassifierBasicRoleConfig(rName), testAccDocumentClassifierS3BucketConfig(rName), @@ -1310,7 +1920,7 @@ resource "aws_comprehend_document_classifier" "test" { `, rName)) } -func testAccDocumentClassifierConfig_versioNamePrefix(rName, versionNamePrefix string) string { +func testAccDocumentClassifierConfig_kmsKeys_Set(rName string) string { return acctest.ConfigCompose( testAccDocumentClassifierBasicRoleConfig(rName), testAccDocumentClassifierS3BucketConfig(rName), @@ -1319,11 +1929,13 @@ func testAccDocumentClassifierConfig_versioNamePrefix(rName, versionNamePrefix s data "aws_partition" "current" {} resource "aws_comprehend_document_classifier" "test" { - name = %[1]q - version_name_prefix = %[2]q + name = %[1]q data_access_role_arn = aws_iam_role.test.arn + model_kms_key_id = aws_kms_key.model.key_id + volume_kms_key_id = aws_kms_key.volume.key_id + language_code = "en" input_data_config { s3_uri = "s3://${aws_s3_bucket.test.bucket}/${aws_s3_object.documents.id}" @@ -1333,10 +1945,45 @@ resource "aws_comprehend_document_classifier" "test" { aws_iam_role_policy.test, ] } -`, rName, versionNamePrefix)) + +resource "aws_iam_role_policy" "kms_keys" { + role = aws_iam_role.test.name + + policy = data.aws_iam_policy_document.kms_keys.json } -func testAccDocumentClassifierConfig_testDocuments(rName string) string { +data "aws_iam_policy_document" "kms_keys" { + statement { + actions = [ + "*", + ] + + resources = [ + aws_kms_key.model.arn, + ] + } + statement { + actions = [ + "*", + ] + + resources = [ + aws_kms_key.volume.arn, + ] + } +} + +resource "aws_kms_key" "model" { + deletion_window_in_days = 7 +} + +resource "aws_kms_key" "volume" { + deletion_window_in_days = 7 +} +`, rName)) +} + +func testAccDocumentClassifierConfig_kmsKeys_Update(rName string) string { return acctest.ConfigCompose( testAccDocumentClassifierBasicRoleConfig(rName), testAccDocumentClassifierS3BucketConfig(rName), @@ -1349,20 +1996,57 @@ resource "aws_comprehend_document_classifier" "test" { data_access_role_arn = aws_iam_role.test.arn + model_kms_key_id = aws_kms_key.model2.key_id + volume_kms_key_id = aws_kms_key.volume2.key_id + language_code = "en" input_data_config { - s3_uri = "s3://${aws_s3_bucket.test.bucket}/${aws_s3_object.documents.id}" - test_s3_uri = "s3://${aws_s3_bucket.test.bucket}/${aws_s3_object.documents.id}" + s3_uri = "s3://${aws_s3_bucket.test.bucket}/${aws_s3_object.documents.id}" } depends_on = [ aws_iam_role_policy.test, ] } + +resource "aws_iam_role_policy" "kms_keys" { + role = aws_iam_role.test.name + + policy = data.aws_iam_policy_document.kms_keys.json +} + +data "aws_iam_policy_document" "kms_keys" { + statement { + actions = [ + "*", + ] + + resources = [ + aws_kms_key.model2.arn, + ] + } + statement { + actions = [ + "*", + ] + + resources = [ + aws_kms_key.volume2.arn, + ] + } +} + +resource "aws_kms_key" "model2" { + deletion_window_in_days = 7 +} + +resource "aws_kms_key" "volume2" { + deletion_window_in_days = 7 +} `, rName)) } -func testAccDocumentClassifierConfig_modeDefault_ValidateNoDelimiterSet(rName, delimiter string) string { +func testAccDocumentClassifierConfig_tags0(rName string) string { return acctest.ConfigCompose( testAccDocumentClassifierBasicRoleConfig(rName), testAccDocumentClassifierS3BucketConfig(rName), @@ -1375,20 +2059,21 @@ resource "aws_comprehend_document_classifier" "test" { data_access_role_arn = aws_iam_role.test.arn + tags = {} + language_code = "en" input_data_config { - s3_uri = "s3://${aws_s3_bucket.test.bucket}/${aws_s3_object.documents.id}" - label_delimiter = %q + s3_uri = "s3://${aws_s3_bucket.test.bucket}/${aws_s3_object.documents.id}" } depends_on = [ aws_iam_role_policy.test, ] } -`, rName, delimiter)) +`, rName)) } -func testAccDocumentClassifierConfig_modeSingleLabel_ValidateNoDelimiterSet(rName, delimiter string) string { +func testAccDocumentClassifierConfig_tags1(rName, tagKey1, tagValue1 string) string { return acctest.ConfigCompose( testAccDocumentClassifierBasicRoleConfig(rName), testAccDocumentClassifierS3BucketConfig(rName), @@ -1401,25 +2086,27 @@ resource "aws_comprehend_document_classifier" "test" { data_access_role_arn = aws_iam_role.test.arn + tags = { + %[2]q = %[3]q + } + language_code = "en" - mode = "MULTI_CLASS" input_data_config { - s3_uri = "s3://${aws_s3_bucket.test.bucket}/${aws_s3_object.documents.id}" - label_delimiter = %q + s3_uri = "s3://${aws_s3_bucket.test.bucket}/${aws_s3_object.documents.id}" } depends_on = [ aws_iam_role_policy.test, ] } -`, rName, delimiter)) +`, rName, tagKey1, tagValue1)) } -func testAccDocumentClassifierConfig_multiLabel_basic(rName string) string { +func testAccDocumentClassifierConfig_tags2(rName, tagKey1, tagValue1, tagKey2, tagValue2 string) string { return acctest.ConfigCompose( testAccDocumentClassifierBasicRoleConfig(rName), testAccDocumentClassifierS3BucketConfig(rName), - testAccDocumentClassifierConfig_S3_multilabel, + testAccDocumentClassifierConfig_S3_documents, fmt.Sprintf(` data "aws_partition" "current" {} @@ -1428,28 +2115,29 @@ resource "aws_comprehend_document_classifier" "test" { data_access_role_arn = aws_iam_role.test.arn + tags = { + %[2]q = %[3]q + %[4]q = %[5]q + } + language_code = "en" - mode = "MULTI_LABEL" input_data_config { - s3_uri = "s3://${aws_s3_bucket.test.bucket}/${aws_s3_object.multilabel.id}" + s3_uri = "s3://${aws_s3_bucket.test.bucket}/${aws_s3_object.documents.id}" } depends_on = [ aws_iam_role_policy.test, ] } -`, rName)) -} - -func testAccDocumentClassifierConfig_multiLabel_defaultDelimiter(rName string) string { - return testAccDocumentClassifierConfig_multiLabel_delimiter(rName, tfcomprehend.DocumentClassifierLabelSeparatorDefault) +`, rName, tagKey1, tagValue1, tagKey2, tagValue2)) } -func testAccDocumentClassifierConfig_multiLabel_delimiter(rName, delimiter string) string { +func testAccDocumentClassifierConfig_outputDataConfig_basic(rName, outputPath string) string { return acctest.ConfigCompose( testAccDocumentClassifierBasicRoleConfig(rName), + testAccDocumentClassifierConfig_s3OutputRole(), testAccDocumentClassifierS3BucketConfig(rName), - testAccDocumentClassifierConfig_S3_multilabel, + testAccDocumentClassifierConfig_S3_documents, fmt.Sprintf(` data "aws_partition" "current" {} @@ -1459,22 +2147,26 @@ resource "aws_comprehend_document_classifier" "test" { data_access_role_arn = aws_iam_role.test.arn language_code = "en" - mode = "MULTI_LABEL" input_data_config { - s3_uri = "s3://${aws_s3_bucket.test.bucket}/${aws_s3_object.multilabel.id}" - label_delimiter = %[2]q + s3_uri = "s3://${aws_s3_bucket.test.bucket}/${aws_s3_object.documents.id}" + } + + output_data_config { + s3_uri = "s3://${aws_s3_bucket.test.bucket}/%[2]s" } depends_on = [ aws_iam_role_policy.test, + aws_iam_role_policy.s3_output, ] } -`, rName, delimiter)) +`, rName, outputPath)) } -func testAccDocumentClassifierConfig_kmsKeyIds(rName string) string { +func testAccDocumentClassifierConfig_outputDataConfig_kmsKeyId(rName string) string { return acctest.ConfigCompose( testAccDocumentClassifierBasicRoleConfig(rName), + testAccDocumentClassifierConfig_s3OutputRole(), testAccDocumentClassifierS3BucketConfig(rName), testAccDocumentClassifierConfig_S3_documents, fmt.Sprintf(` @@ -1485,19 +2177,27 @@ resource "aws_comprehend_document_classifier" "test" { data_access_role_arn = aws_iam_role.test.arn - model_kms_key_id = aws_kms_key.model.key_id - volume_kms_key_id = aws_kms_key.volume.key_id - language_code = "en" input_data_config { s3_uri = "s3://${aws_s3_bucket.test.bucket}/${aws_s3_object.documents.id}" } + output_data_config { + s3_uri = "s3://${aws_s3_bucket.test.bucket}/outputs" + kms_key_id = aws_kms_key.output.key_id + } + depends_on = [ aws_iam_role_policy.test, + aws_iam_role_policy.s3_output, + aws_iam_role_policy.kms_keys, ] } +resource "aws_kms_key" "output" { + deletion_window_in_days = 7 +} + resource "aws_iam_role_policy" "kms_keys" { role = aws_iam_role.test.name @@ -1511,33 +2211,17 @@ data "aws_iam_policy_document" "kms_keys" { ] resources = [ - aws_kms_key.model.arn, - ] - } - statement { - actions = [ - "*", - ] - - resources = [ - aws_kms_key.volume.arn, + aws_kms_key.output.arn, ] } } - -resource "aws_kms_key" "model" { - deletion_window_in_days = 7 -} - -resource "aws_kms_key" "volume" { - deletion_window_in_days = 7 -} `, rName)) } -func testAccDocumentClassifierConfig_kmsKeyARNs(rName string) string { +func testAccDocumentClassifierConfig_outputDataConfig_kmsKeyARN(rName string) string { return acctest.ConfigCompose( testAccDocumentClassifierBasicRoleConfig(rName), + testAccDocumentClassifierConfig_s3OutputRole(), testAccDocumentClassifierS3BucketConfig(rName), testAccDocumentClassifierConfig_S3_documents, fmt.Sprintf(` @@ -1548,19 +2232,27 @@ resource "aws_comprehend_document_classifier" "test" { data_access_role_arn = aws_iam_role.test.arn - model_kms_key_id = aws_kms_key.model.arn - volume_kms_key_id = aws_kms_key.volume.arn - language_code = "en" input_data_config { s3_uri = "s3://${aws_s3_bucket.test.bucket}/${aws_s3_object.documents.id}" } + output_data_config { + s3_uri = "s3://${aws_s3_bucket.test.bucket}/outputs" + kms_key_id = aws_kms_key.output.arn + } + depends_on = [ aws_iam_role_policy.test, + aws_iam_role_policy.s3_output, + aws_iam_role_policy.kms_keys, ] } +resource "aws_kms_key" "output" { + deletion_window_in_days = 7 +} + resource "aws_iam_role_policy" "kms_keys" { role = aws_iam_role.test.name @@ -1574,33 +2266,17 @@ data "aws_iam_policy_document" "kms_keys" { ] resources = [ - aws_kms_key.model.arn, - ] - } - statement { - actions = [ - "*", - ] - - resources = [ - aws_kms_key.volume.arn, + aws_kms_key.output.arn, ] } } - -resource "aws_kms_key" "model" { - deletion_window_in_days = 7 -} - -resource "aws_kms_key" "volume" { - deletion_window_in_days = 7 -} `, rName)) } -func testAccDocumentClassifierConfig_kmsKeys_None(rName string) string { +func testAccDocumentClassifierConfig_outputDataConfig_kmsKeyAliasName(rName string) string { return acctest.ConfigCompose( testAccDocumentClassifierBasicRoleConfig(rName), + testAccDocumentClassifierConfig_s3OutputRole(), testAccDocumentClassifierS3BucketConfig(rName), testAccDocumentClassifierConfig_S3_documents, fmt.Sprintf(` @@ -1616,37 +2292,25 @@ resource "aws_comprehend_document_classifier" "test" { s3_uri = "s3://${aws_s3_bucket.test.bucket}/${aws_s3_object.documents.id}" } + output_data_config { + s3_uri = "s3://${aws_s3_bucket.test.bucket}/outputs" + kms_key_id = aws_kms_alias.output.name + } + depends_on = [ aws_iam_role_policy.test, + aws_iam_role_policy.s3_output, + aws_iam_role_policy.kms_keys, ] } -`, rName)) -} - -func testAccDocumentClassifierConfig_kmsKeys_Set(rName string) string { - return acctest.ConfigCompose( - testAccDocumentClassifierBasicRoleConfig(rName), - testAccDocumentClassifierS3BucketConfig(rName), - testAccDocumentClassifierConfig_S3_documents, - fmt.Sprintf(` -data "aws_partition" "current" {} -resource "aws_comprehend_document_classifier" "test" { - name = %[1]q - - data_access_role_arn = aws_iam_role.test.arn - - model_kms_key_id = aws_kms_key.model.key_id - volume_kms_key_id = aws_kms_key.volume.key_id - - language_code = "en" - input_data_config { - s3_uri = "s3://${aws_s3_bucket.test.bucket}/${aws_s3_object.documents.id}" - } +resource "aws_kms_alias" "output" { + name = "alias/%[1]s" + target_key_id = aws_kms_key.output.key_id +} - depends_on = [ - aws_iam_role_policy.test, - ] +resource "aws_kms_key" "output" { + deletion_window_in_days = 7 } resource "aws_iam_role_policy" "kms_keys" { @@ -1662,33 +2326,17 @@ data "aws_iam_policy_document" "kms_keys" { ] resources = [ - aws_kms_key.model.arn, - ] - } - statement { - actions = [ - "*", - ] - - resources = [ - aws_kms_key.volume.arn, + aws_kms_key.output.arn, ] } } - -resource "aws_kms_key" "model" { - deletion_window_in_days = 7 -} - -resource "aws_kms_key" "volume" { - deletion_window_in_days = 7 -} `, rName)) } -func testAccDocumentClassifierConfig_kmsKeys_Update(rName string) string { +func testAccDocumentClassifierConfig_outputDataConfig_kmsKeyAliasARN(rName string) string { return acctest.ConfigCompose( testAccDocumentClassifierBasicRoleConfig(rName), + testAccDocumentClassifierConfig_s3OutputRole(), testAccDocumentClassifierS3BucketConfig(rName), testAccDocumentClassifierConfig_S3_documents, fmt.Sprintf(` @@ -1699,19 +2347,32 @@ resource "aws_comprehend_document_classifier" "test" { data_access_role_arn = aws_iam_role.test.arn - model_kms_key_id = aws_kms_key.model2.key_id - volume_kms_key_id = aws_kms_key.volume2.key_id - language_code = "en" input_data_config { s3_uri = "s3://${aws_s3_bucket.test.bucket}/${aws_s3_object.documents.id}" } + output_data_config { + s3_uri = "s3://${aws_s3_bucket.test.bucket}/outputs" + kms_key_id = aws_kms_alias.output.arn + } + depends_on = [ aws_iam_role_policy.test, + aws_iam_role_policy.s3_output, + aws_iam_role_policy.kms_keys, ] } +resource "aws_kms_alias" "output" { + name = "alias/%[1]s" + target_key_id = aws_kms_key.output.key_id +} + +resource "aws_kms_key" "output" { + deletion_window_in_days = 7 +} + resource "aws_iam_role_policy" "kms_keys" { role = aws_iam_role.test.name @@ -1725,33 +2386,17 @@ data "aws_iam_policy_document" "kms_keys" { ] resources = [ - aws_kms_key.model2.arn, - ] - } - statement { - actions = [ - "*", - ] - - resources = [ - aws_kms_key.volume2.arn, + aws_kms_key.output.arn, ] } } - -resource "aws_kms_key" "model2" { - deletion_window_in_days = 7 -} - -resource "aws_kms_key" "volume2" { - deletion_window_in_days = 7 -} `, rName)) } -func testAccDocumentClassifierConfig_tags0(rName string) string { +func testAccDocumentClassifierConfig_outputDataConfig_kmsKeySet(rName string) string { return acctest.ConfigCompose( testAccDocumentClassifierBasicRoleConfig(rName), + testAccDocumentClassifierConfig_s3OutputRole(), testAccDocumentClassifierS3BucketConfig(rName), testAccDocumentClassifierConfig_S3_documents, fmt.Sprintf(` @@ -1762,80 +2407,48 @@ resource "aws_comprehend_document_classifier" "test" { data_access_role_arn = aws_iam_role.test.arn - tags = {} - language_code = "en" input_data_config { s3_uri = "s3://${aws_s3_bucket.test.bucket}/${aws_s3_object.documents.id}" } - depends_on = [ - aws_iam_role_policy.test, - ] -} -`, rName)) -} - -func testAccDocumentClassifierConfig_tags1(rName, tagKey1, tagValue1 string) string { - return acctest.ConfigCompose( - testAccDocumentClassifierBasicRoleConfig(rName), - testAccDocumentClassifierS3BucketConfig(rName), - testAccDocumentClassifierConfig_S3_documents, - fmt.Sprintf(` -data "aws_partition" "current" {} - -resource "aws_comprehend_document_classifier" "test" { - name = %[1]q - - data_access_role_arn = aws_iam_role.test.arn - - tags = { - %[2]q = %[3]q - } - - language_code = "en" - input_data_config { - s3_uri = "s3://${aws_s3_bucket.test.bucket}/${aws_s3_object.documents.id}" + output_data_config { + s3_uri = "s3://${aws_s3_bucket.test.bucket}/outputs" + kms_key_id = aws_kms_key.output.key_id } depends_on = [ aws_iam_role_policy.test, + aws_iam_role_policy.s3_output, + aws_iam_role_policy.kms_keys, ] } -`, rName, tagKey1, tagValue1)) -} -func testAccDocumentClassifierConfig_tags2(rName, tagKey1, tagValue1, tagKey2, tagValue2 string) string { - return acctest.ConfigCompose( - testAccDocumentClassifierBasicRoleConfig(rName), - testAccDocumentClassifierS3BucketConfig(rName), - testAccDocumentClassifierConfig_S3_documents, - fmt.Sprintf(` -data "aws_partition" "current" {} +resource "aws_kms_key" "output" { + deletion_window_in_days = 7 +} -resource "aws_comprehend_document_classifier" "test" { - name = %[1]q +resource "aws_iam_role_policy" "kms_keys" { + role = aws_iam_role.test.name - data_access_role_arn = aws_iam_role.test.arn + policy = data.aws_iam_policy_document.kms_keys.json +} - tags = { - %[2]q = %[3]q - %[4]q = %[5]q - } +data "aws_iam_policy_document" "kms_keys" { + statement { + actions = [ + "*", + ] - language_code = "en" - input_data_config { - s3_uri = "s3://${aws_s3_bucket.test.bucket}/${aws_s3_object.documents.id}" + resources = [ + aws_kms_key.output.arn, + ] } - - depends_on = [ - aws_iam_role_policy.test, - ] } -`, rName, tagKey1, tagValue1, tagKey2, tagValue2)) +`, rName)) } -func testAccDocumentClassifierConfig_outputDataConfig_basic(rName string) string { +func testAccDocumentClassifierConfig_outputDataConfig_kmsKeyNone(rName string) string { return acctest.ConfigCompose( testAccDocumentClassifierBasicRoleConfig(rName), testAccDocumentClassifierConfig_s3OutputRole(), @@ -1866,7 +2479,7 @@ resource "aws_comprehend_document_classifier" "test" { `, rName)) } -func testAccDocumentClassifierConfig_outputDataConfig_basic2(rName string) string { +func testAccDocumentClassifierConfig_outputDataConfig_kmsKeyUpdate(rName string) string { return acctest.ConfigCompose( testAccDocumentClassifierBasicRoleConfig(rName), testAccDocumentClassifierConfig_s3OutputRole(), @@ -1877,7 +2490,6 @@ data "aws_partition" "current" {} resource "aws_comprehend_document_classifier" "test" { name = %[1]q - version_name = "2" data_access_role_arn = aws_iam_role.test.arn @@ -1888,13 +2500,37 @@ resource "aws_comprehend_document_classifier" "test" { output_data_config { s3_uri = "s3://${aws_s3_bucket.test.bucket}/outputs" + kms_key_id = aws_kms_key.output2.key_id } depends_on = [ aws_iam_role_policy.test, aws_iam_role_policy.s3_output, + aws_iam_role_policy.kms_keys, ] } + +resource "aws_kms_key" "output2" { + deletion_window_in_days = 7 +} + +resource "aws_iam_role_policy" "kms_keys" { + role = aws_iam_role.test.name + + policy = data.aws_iam_policy_document.kms_keys.json +} + +data "aws_iam_policy_document" "kms_keys" { + statement { + actions = [ + "*", + ] + + resources = [ + aws_kms_key.output2.arn, + ] + } +} `, rName)) } diff --git a/internal/service/comprehend/validate.go b/internal/service/comprehend/validate.go index 947a8de8b4b..38725858662 100644 --- a/internal/service/comprehend/validate.go +++ b/internal/service/comprehend/validate.go @@ -7,6 +7,7 @@ import ( "github.com/aws/aws-sdk-go-v2/aws/arn" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/resource" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/validation" + tfkms "github.com/hashicorp/terraform-provider-aws/internal/service/kms" "github.com/hashicorp/terraform-provider-aws/internal/verify" ) @@ -39,6 +40,13 @@ var validateKMSKey = validation.Any( validateKMSKeyARN, ) +var validateKMSKeyOrAlias = validation.Any( + validateKMSKeyId, + validateKMSKeyARN, + validateKMSKeyAliasName, + validateKMSKeyAliasARN, +) + var validateKMSKeyId = validation.StringMatch(regexp.MustCompile("^"+verify.UUIDRegexPattern+"$"), "must be a KMS Key ID") func validateKMSKeyARN(v any, k string) (ws []string, errors []error) { @@ -52,19 +60,39 @@ func validateKMSKeyARN(v any, k string) (ws []string, errors []error) { return } - parsedARN, err := arn.Parse(value) - if err != nil { + if _, err := arn.Parse(value); err != nil { errors = append(errors, fmt.Errorf("%q (%s) is an invalid ARN: %s", k, value, err)) return } - if parsedARN.Service != "kms" { - errors = append(errors, fmt.Errorf("%q (%s) is not a valid KMS Key ARN: %s", k, value, err)) + if !isKMSKeyARN(value) { + errors = append(errors, fmt.Errorf("%q (%s) is not a valid KMS Key ARN", k, value)) + return + } + + return +} + +var validateKMSKeyAliasName = validation.StringMatch(regexp.MustCompile("^"+tfkms.AliasNameRegexPattern+"$"), "must be a KMS Key Alias") + +func validateKMSKeyAliasARN(v any, k string) (ws []string, errors []error) { + value, ok := v.(string) + if !ok { + errors = append(errors, fmt.Errorf("expected type of %s to be string", k)) + return + } + + if value == "" { + return + } + + if _, err := arn.Parse(value); err != nil { + errors = append(errors, fmt.Errorf("%q (%s) is an invalid ARN: %s", k, value, err)) return } - if id := kmsKeyIdFromARNResource(parsedARN.Resource); id == "" { - errors = append(errors, fmt.Errorf("%q (%s) is not a valid KMS Key ARN: %s", k, value, err)) + if !isKMSAliasARN(value) { + errors = append(errors, fmt.Errorf("%q (%s) is not a valid KMS Key Alias ARN", k, value)) return } diff --git a/internal/service/kms/validate.go b/internal/service/kms/validate.go index ef30634fa10..3dc9b33218d 100644 --- a/internal/service/kms/validate.go +++ b/internal/service/kms/validate.go @@ -5,6 +5,8 @@ import ( "regexp" ) +const AliasNameRegexPattern = `alias/[a-zA-Z0-9/_-]+` + func validGrantName(v interface{}, k string) (ws []string, es []error) { value := v.(string) @@ -22,7 +24,7 @@ func validGrantName(v interface{}, k string) (ws []string, es []error) { func validNameForDataSource(v interface{}, k string) (ws []string, es []error) { value := v.(string) - if !regexp.MustCompile(`^(alias/)[a-zA-Z0-9/_-]+$`).MatchString(value) { + if !regexp.MustCompile("^" + AliasNameRegexPattern + "$").MatchString(value) { es = append(es, fmt.Errorf( "%q must begin with 'alias/' and be comprised of only [a-zA-Z0-9/_-]", k)) } @@ -36,7 +38,7 @@ func validNameForResource(v interface{}, k string) (ws []string, es []error) { es = append(es, fmt.Errorf("%q cannot begin with reserved AWS CMK prefix 'alias/aws/'", k)) } - if !regexp.MustCompile(`^(alias/)[a-zA-Z0-9/_-]+$`).MatchString(value) { + if !regexp.MustCompile("^" + AliasNameRegexPattern + "$").MatchString(value) { es = append(es, fmt.Errorf( "%q must begin with 'alias/' and be comprised of only [a-zA-Z0-9/_-]", k)) } @@ -48,13 +50,12 @@ func validKey(v interface{}, k string) (ws []string, errors []error) { arnPrefixPattern := `arn:[^:]+:kms:[^:]+:[^:]+:` keyIdPattern := "[A-Za-z0-9-]+" keyArnPattern := arnPrefixPattern + "key/" + keyIdPattern - aliasNamePattern := "alias/[a-zA-Z0-9:/_-]+" - aliasArnPattern := arnPrefixPattern + aliasNamePattern + aliasArnPattern := arnPrefixPattern + AliasNameRegexPattern if !regexp.MustCompile(fmt.Sprintf("^%s$", keyIdPattern)).MatchString(value) && !regexp.MustCompile(fmt.Sprintf("^%s$", keyArnPattern)).MatchString(value) && - !regexp.MustCompile(fmt.Sprintf("^%s$", aliasNamePattern)).MatchString(value) && + !regexp.MustCompile(fmt.Sprintf("^%s$", AliasNameRegexPattern)).MatchString(value) && !regexp.MustCompile(fmt.Sprintf("^%s$", aliasArnPattern)).MatchString(value) { - errors = append(errors, fmt.Errorf("%q must be one of the following patterns: %s, %s, %s or %s", k, keyIdPattern, keyArnPattern, aliasNamePattern, aliasArnPattern)) + errors = append(errors, fmt.Errorf("%q must be one of the following patterns: %s, %s, %s or %s", k, keyIdPattern, keyArnPattern, AliasNameRegexPattern, aliasArnPattern)) } return } diff --git a/website/docs/r/comprehend_document_classifier.html.markdown b/website/docs/r/comprehend_document_classifier.html.markdown index ece21b67f6a..86b10bc86ba 100644 --- a/website/docs/r/comprehend_document_classifier.html.markdown +++ b/website/docs/r/comprehend_document_classifier.html.markdown @@ -57,7 +57,8 @@ The following arguments are optional: * `mode` - (Optional, Default: `MULTI_CLASS`) The document classification mode. One of `MULTI_CLASS` or `MULTI_LABEL`. `MULTI_CLASS` is also known as "Single Label" in the AWS Console. -* `model_kms_key_id` - (Optional) The ID or ARN of a KMS Key used to encrypt trained Document Classifiers. +* `model_kms_key_id` - (Optional) KMS Key used to encrypt trained Document Classifiers. + Can be a KMS Key ID or a KMS Key ARN. * `output_data_config` - (Optional) Configuration for the output results of training. See the [`output_data_config` Configuration Block](#output_data_config-configuration-block) section below. * `tags` - (Optional) A map of tags to assign to the resource. If configured with a provider [`default_tags` Configuration Block](/docs/providers/aws/index.html#default_tags-configuration-block) present, tags with matching keys will overwrite those defined at the provider-level. @@ -72,7 +73,8 @@ The following arguments are optional: Has a maximum length of 37 characters. Can contain upper- and lower-case letters, numbers, and hypen (`-`). Conflicts with `version_name`. -* `volume_kms_key_id` - (Optional) ID or ARN of a KMS Key used to encrypt storage volumes during job processing. +* `volume_kms_key_id` - (Optional) KMS Key used to encrypt storage volumes during job processing. + Can be a KMS Key ID or a KMS Key ARN. * `vpc_config` - (Optional) Configuration parameters for VPC to contain Document Classifier resources. See the [`vpc_config` Configuration Block](#vpc_config-configuration-block) section below. @@ -103,6 +105,8 @@ The following arguments are optional: ### `output_data_config` Configuration Block +* `kms_key_id` - (Optional) KMS Key used to encrypt the output documents. + Can be a KMS Key ID, a KMS Key ARN, a KMS Alias name, or a KMS Alias ARN. * `output_s3_uri` - (Computed) Full path for the output documents. * `s3_uri` - (Required) Destination path for the output documents. The full path to the output file will be returned in `output_s3_uri`. From 7d10b0ce764a768cbec94349b7df0e8fe9ad3953 Mon Sep 17 00:00:00 2001 From: Graham Davison Date: Fri, 23 Sep 2022 17:25:08 -0700 Subject: [PATCH 6/6] terrafmt --- .../comprehend/document_classifier_test.go | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/internal/service/comprehend/document_classifier_test.go b/internal/service/comprehend/document_classifier_test.go index b4aa61f182b..8d6423c08c5 100644 --- a/internal/service/comprehend/document_classifier_test.go +++ b/internal/service/comprehend/document_classifier_test.go @@ -2183,8 +2183,8 @@ resource "aws_comprehend_document_classifier" "test" { } output_data_config { - s3_uri = "s3://${aws_s3_bucket.test.bucket}/outputs" - kms_key_id = aws_kms_key.output.key_id + s3_uri = "s3://${aws_s3_bucket.test.bucket}/outputs" + kms_key_id = aws_kms_key.output.key_id } depends_on = [ @@ -2238,8 +2238,8 @@ resource "aws_comprehend_document_classifier" "test" { } output_data_config { - s3_uri = "s3://${aws_s3_bucket.test.bucket}/outputs" - kms_key_id = aws_kms_key.output.arn + s3_uri = "s3://${aws_s3_bucket.test.bucket}/outputs" + kms_key_id = aws_kms_key.output.arn } depends_on = [ @@ -2293,8 +2293,8 @@ resource "aws_comprehend_document_classifier" "test" { } output_data_config { - s3_uri = "s3://${aws_s3_bucket.test.bucket}/outputs" - kms_key_id = aws_kms_alias.output.name + s3_uri = "s3://${aws_s3_bucket.test.bucket}/outputs" + kms_key_id = aws_kms_alias.output.name } depends_on = [ @@ -2353,8 +2353,8 @@ resource "aws_comprehend_document_classifier" "test" { } output_data_config { - s3_uri = "s3://${aws_s3_bucket.test.bucket}/outputs" - kms_key_id = aws_kms_alias.output.arn + s3_uri = "s3://${aws_s3_bucket.test.bucket}/outputs" + kms_key_id = aws_kms_alias.output.arn } depends_on = [ @@ -2413,8 +2413,8 @@ resource "aws_comprehend_document_classifier" "test" { } output_data_config { - s3_uri = "s3://${aws_s3_bucket.test.bucket}/outputs" - kms_key_id = aws_kms_key.output.key_id + s3_uri = "s3://${aws_s3_bucket.test.bucket}/outputs" + kms_key_id = aws_kms_key.output.key_id } depends_on = [ @@ -2499,8 +2499,8 @@ resource "aws_comprehend_document_classifier" "test" { } output_data_config { - s3_uri = "s3://${aws_s3_bucket.test.bucket}/outputs" - kms_key_id = aws_kms_key.output2.key_id + s3_uri = "s3://${aws_s3_bucket.test.bucket}/outputs" + kms_key_id = aws_kms_key.output2.key_id } depends_on = [