Skip to content

Commit

Permalink
Merge pull request #20711 from lastlee/main
Browse files Browse the repository at this point in the history
feat: add `platform_identifier` to `r/aws_sagemaker_notebook_instance`
  • Loading branch information
ewbankkit authored Aug 30, 2021
2 parents 7481af6 + 150252f commit 3360648
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 0 deletions.
3 changes: 3 additions & 0 deletions .changelog/20711.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
```release-note:enhancement
resource/aws_sagemaker_notebook_instance: Add `platform_identifier` argument
```
19 changes: 19 additions & 0 deletions aws/resource_aws_sagemaker_notebook_instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"fmt"
"log"
"regexp"
"time"

"github.com/aws/aws-sdk-go/aws"
Expand Down Expand Up @@ -56,6 +57,15 @@ func resourceAwsSagemakerNotebookInstance() *schema.Resource {
Required: true,
ValidateFunc: validation.StringInSlice(sagemaker.InstanceType_Values(), false),
},

"platform_identifier": {
Type: schema.TypeString,
Optional: true,
Computed: true,
ForceNew: true,
ValidateFunc: validation.StringMatch(regexp.MustCompile(`^(notebook-al1-v1|notebook-al2-v1)$`), ""),
},

"additional_code_repositories": {
Type: schema.TypeSet,
Optional: true,
Expand Down Expand Up @@ -147,6 +157,10 @@ func resourceAwsSagemakerNotebookInstanceCreate(d *schema.ResourceData, meta int
createOpts.RootAccess = aws.String(v.(string))
}

if v, ok := d.GetOk("platform_identifier"); ok {
createOpts.PlatformIdentifier = aws.String(v.(string))
}

if v, ok := d.GetOk("direct_internet_access"); ok {
createOpts.DirectInternetAccess = aws.String(v.(string))
}
Expand Down Expand Up @@ -226,6 +240,11 @@ func resourceAwsSagemakerNotebookInstanceRead(d *schema.ResourceData, meta inter
if err := d.Set("instance_type", notebookInstance.InstanceType); err != nil {
return fmt.Errorf("error setting instance_type for sagemaker notebook instance (%s): %s", d.Id(), err)
}

if err := d.Set("platform_identifier", notebookInstance.PlatformIdentifier); err != nil {
return fmt.Errorf("error setting platform_identifier for sagemaker notebook instance (%s): %s", d.Id(), err)
}

if err := d.Set("subnet_id", notebookInstance.SubnetId); err != nil {
return fmt.Errorf("error setting subnet_id for sagemaker notebook instance (%s): %s", d.Id(), err)
}
Expand Down
45 changes: 45 additions & 0 deletions aws/resource_aws_sagemaker_notebook_instance_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ func TestAccAWSSagemakerNotebookInstance_basic(t *testing.T) {
testAccCheckAWSSagemakerNotebookInstanceExists(resourceName, &notebook),
resource.TestCheckResourceAttr(resourceName, "name", rName),
resource.TestCheckResourceAttr(resourceName, "instance_type", "ml.t2.medium"),
resource.TestCheckResourceAttr(resourceName, "platform_identifier", "notebook-al1-v1"),
resource.TestCheckResourceAttrPair(resourceName, "role_arn", "aws_iam_role.test", "arn"),
resource.TestCheckResourceAttr(resourceName, "direct_internet_access", "Enabled"),
resource.TestCheckResourceAttr(resourceName, "root_access", "Enabled"),
Expand Down Expand Up @@ -435,6 +436,40 @@ func TestAccAWSSagemakerNotebookInstance_root_access(t *testing.T) {
})
}

func TestAccAWSSagemakerNotebookInstance_platform_identifier(t *testing.T) {
var notebook sagemaker.DescribeNotebookInstanceOutput
rName := acctest.RandomWithPrefix("tf-acc-test")
resourceName := "aws_sagemaker_notebook_instance.test"

resource.ParallelTest(t, resource.TestCase{
PreCheck: func() { testAccPreCheck(t) },
ErrorCheck: testAccErrorCheck(t, sagemaker.EndpointsID),
Providers: testAccProviders,
CheckDestroy: testAccCheckAWSSagemakerNotebookInstanceDestroy,
Steps: []resource.TestStep{
{
Config: testAccAWSSagemakerNotebookInstanceConfigPlatformIdentifier(rName, "notebook-al2-v1"),
Check: resource.ComposeTestCheckFunc(
testAccCheckAWSSagemakerNotebookInstanceExists(resourceName, &notebook),
resource.TestCheckResourceAttr(resourceName, "platform_identifier", "notebook-al2-v1"),
),
},
{
ResourceName: resourceName,
ImportState: true,
ImportStateVerify: true,
},
{
Config: testAccAWSSagemakerNotebookInstanceConfigPlatformIdentifier(rName, "notebook-al1-v1"),
Check: resource.ComposeTestCheckFunc(
testAccCheckAWSSagemakerNotebookInstanceExists(resourceName, &notebook),
resource.TestCheckResourceAttr(resourceName, "platform_identifier", "notebook-al1-v1"),
),
},
},
})
}

func TestAccAWSSagemakerNotebookInstance_direct_internet_access(t *testing.T) {
var notebook sagemaker.DescribeNotebookInstanceOutput
rName := acctest.RandomWithPrefix("tf-acc-test")
Expand Down Expand Up @@ -701,6 +736,16 @@ resource "aws_sagemaker_notebook_instance" "test" {
`, rName, rootAccess)
}

func testAccAWSSagemakerNotebookInstanceConfigPlatformIdentifier(rName string, platformIdentifier string) string {
return testAccAWSSagemakerNotebookInstanceBaseConfig(rName) + fmt.Sprintf(`
resource "aws_sagemaker_notebook_instance" "test" {
name = %[1]q
role_arn = aws_iam_role.test.arn
instance_type = "ml.t2.medium"
platform_identifier = %[2]q
}
`, rName, platformIdentifier)
}
func testAccAWSSagemakerNotebookInstanceConfigDirectInternetAccess(rName string, directInternetAccess string) string {
return testAccAWSSagemakerNotebookInstanceBaseConfig(rName) +
fmt.Sprintf(`
Expand Down
1 change: 1 addition & 0 deletions website/docs/r/sagemaker_notebook_instance.html.markdown
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ The following arguments are supported:
* `name` - (Required) The name of the notebook instance (must be unique).
* `role_arn` - (Required) The ARN of the IAM role to be used by the notebook instance which allows SageMaker to call other services on your behalf.
* `instance_type` - (Required) The name of ML compute instance type.
* `platform_identifier` - (Optional) The platform identifier of the notebook instance runtime environment. This value can be either `notebook-al1-v1` or `notebook-al2-v1`, depending on which version of Amazon Linux you require.
* `volume_size` - (Optional) The size, in GB, of the ML storage volume to attach to the notebook instance. The default value is 5 GB.
* `subnet_id` - (Optional) The VPC subnet ID.
* `security_groups` - (Optional) The associated security groups.
Expand Down

0 comments on commit 3360648

Please sign in to comment.