Skip to content

Commit

Permalink
Add TPU Tensorflow datasource (hashicorp#561)
Browse files Browse the repository at this point in the history
Signed-off-by: Modular Magician <magic-modules@google.com>
  • Loading branch information
modular-magician authored and emilymye committed Mar 27, 2019
1 parent 52f49b5 commit 9a12bc6
Show file tree
Hide file tree
Showing 7 changed files with 198 additions and 9 deletions.
82 changes: 82 additions & 0 deletions google-beta/data_source_tpu_tensorflow_versions.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package google

import (
"fmt"
"log"
"sort"
"time"

"github.com/hashicorp/terraform/helper/schema"
)

func dataSourceTpuTensorflowVersions() *schema.Resource {
return &schema.Resource{
Read: dataSourceTpuTensorFlowVersionsRead,
Schema: map[string]*schema.Schema{
"project": {
Type: schema.TypeString,
Optional: true,
Computed: true,
},
"zone": {
Type: schema.TypeString,
Optional: true,
Computed: true,
},
"versions": {
Type: schema.TypeList,
Computed: true,
Elem: &schema.Schema{Type: schema.TypeString},
},
},
}
}

func dataSourceTpuTensorFlowVersionsRead(d *schema.ResourceData, meta interface{}) error {
config := meta.(*Config)

project, err := getProject(d, config)
if err != nil {
return err
}

zone, err := getZone(d, config)
if err != nil {
return err
}

url, err := replaceVars(d, config, "https://tpu.googleapis.com/v1/projects/{{project}}/locations/{{zone}}/tensorflowVersions")
if err != nil {
return err
}

versionsRaw, err := paginatedListRequest(url, config, flattenTpuTensorflowVersions)
if err != nil {
return fmt.Errorf("Error listing TPU Tensorflow versions: %s", err)
}

versions := make([]string, len(versionsRaw))
for i, ver := range versionsRaw {
versions[i] = ver.(string)
}
sort.Strings(versions)

log.Printf("[DEBUG] Received Google TPU Tensorflow Versions: %q", versions)

d.Set("versions", versions)
d.Set("zone", zone)
d.Set("project", project)
d.SetId(time.Now().UTC().String())

return nil
}

func flattenTpuTensorflowVersions(resp map[string]interface{}) []interface{} {
verObjList := resp["tensorflowVersions"].([]interface{})
versions := make([]interface{}, len(verObjList))
for i, v := range verObjList {
verObj := v.(map[string]interface{})
versions[i] = verObj["version"]
}
return versions
}
72 changes: 72 additions & 0 deletions google-beta/data_source_tpu_tensorflow_versions_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package google

import (
"errors"
"fmt"
"strconv"
"testing"

"github.com/hashicorp/terraform/helper/resource"
"github.com/hashicorp/terraform/terraform"
"regexp"
)

func TestAccTpuTensorflowVersions_basic(t *testing.T) {
t.Parallel()

resource.Test(t, resource.TestCase{
PreCheck: func() { testAccPreCheck(t) },
Providers: testAccProviders,
Steps: []resource.TestStep{
{
Config: testAccTpuTensorFlowVersionsConfig,
Check: resource.ComposeTestCheckFunc(
testAccCheckGoogleTpuTensorflowVersions("data.google_tpu_tensorflow_versions.available"),
),
},
},
})
}

func testAccCheckGoogleTpuTensorflowVersions(n string) resource.TestCheckFunc {
return func(s *terraform.State) error {
rs, ok := s.RootModule().Resources[n]
if !ok {
return fmt.Errorf("Can't find TPU Tensorflow versions data source: %s", n)
}

if rs.Primary.ID == "" {
return errors.New("data source ID not set.")
}

count, ok := rs.Primary.Attributes["versions.#"]
if !ok {
return errors.New("can't find 'names' attribute")
}

cnt, err := strconv.Atoi(count)
if err != nil {
return errors.New("failed to read number of version")
}
if cnt < 2 {
return fmt.Errorf("expected at least 2 versions, received %d, this is most likely a bug", cnt)
}

for i := 0; i < cnt; i++ {
idx := fmt.Sprintf("versions.%d", i)
v, ok := rs.Primary.Attributes[idx]
if !ok {
return fmt.Errorf("expected %q, version not found", idx)
}

if !regexp.MustCompile(`^([0-9]+\.)+[0-9]+$`).MatchString(v) {
return fmt.Errorf("unexpected version format for %q, value is %v", idx, v)
}
}
return nil
}
}

var testAccTpuTensorFlowVersionsConfig = `
data "google_tpu_tensorflow_versions" "available" {}
`
1 change: 1 addition & 0 deletions google-beta/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ func Provider() terraform.ResourceProvider {
"google_storage_object_signed_url": dataSourceGoogleSignedUrl(),
"google_storage_project_service_account": dataSourceGoogleStorageProjectServiceAccount(),
"google_storage_transfer_project_service_account": dataSourceGoogleStorageTransferProjectServiceAccount(),
"google_tpu_tensorflow_versions": dataSourceTpuTensorflowVersions(),
},

ResourcesMap: ResourceMap(),
Expand Down
8 changes: 6 additions & 2 deletions google-beta/resource_tpu_node_generated_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,14 @@ func TestAccTpuNode_tpuNodeBasicExample(t *testing.T) {

func testAccTpuNode_tpuNodeBasicExample(context map[string]interface{}) string {
return Nprintf(`
data "google_tpu_tensorflow_versions" "available" { }
resource "google_tpu_node" "tpu" {
name = "test-tpu-%{random_suffix}"
zone = "us-central1-b"
accelerator_type = "v3-8"
tensorflow_version = "1.13"
tensorflow_version = "${data.google_tpu_tensorflow_versions.available.versions[0]}"
cidr_block = "10.2.0.0/29"
}
`, context)
Expand Down Expand Up @@ -94,14 +96,16 @@ resource "google_compute_network" "tpu_network" {
auto_create_subnetworks = false
}
data "google_tpu_tensorflow_versions" "available" { }
resource "google_tpu_node" "tpu" {
name = "test-tpu-%{random_suffix}"
zone = "us-central1-b"
accelerator_type = "v3-8"
cidr_block = "10.3.0.0/29"
tensorflow_version = "1.13"
tensorflow_version = "${data.google_tpu_tensorflow_versions.available.versions[0]}"
description = "Terraform Google Provider test TPU"
network = "${google_compute_network.tpu_network.name}"
Expand Down
12 changes: 7 additions & 5 deletions google-beta/resource_tpu_node_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ func TestAccTpuNode_tpuNodeBUpdateTensorFlowVersion(t *testing.T) {
CheckDestroy: testAccCheckTpuNodeDestroy,
Steps: []resource.TestStep{
{
Config: testAccTpuNode_tpuNodeTensorFlow(nodeId, "1.11"),
Config: testAccTpuNode_tpuNodeTensorFlow(nodeId, 0),
},
{
ResourceName: "google_tpu_node.tpu",
Expand All @@ -28,7 +28,7 @@ func TestAccTpuNode_tpuNodeBUpdateTensorFlowVersion(t *testing.T) {
ImportStateVerifyIgnore: []string{"zone"},
},
{
Config: testAccTpuNode_tpuNodeTensorFlow(nodeId, "1.12"),
Config: testAccTpuNode_tpuNodeTensorFlow(nodeId, 1),
},
{
ResourceName: "google_tpu_node.tpu",
Expand All @@ -43,15 +43,17 @@ func TestAccTpuNode_tpuNodeBUpdateTensorFlowVersion(t *testing.T) {
// WARNING: cidr_block must not overlap with other existing TPU blocks
// Make sure if you change this value that it does not overlap with the
// autogenerated examples.
func testAccTpuNode_tpuNodeTensorFlow(nodeId, tensorFlowVer string) string {
func testAccTpuNode_tpuNodeTensorFlow(nodeId string, versionIdx int) string {
return fmt.Sprintf(`
data "google_tpu_tensorflow_versions" "available" { }
resource "google_tpu_node" "tpu" {
name = "%s"
zone = "us-central1-b"
accelerator_type = "v3-8"
tensorflow_version = "%s"
tensorflow_version = "${data.google_tpu_tensorflow_versions.available.versions[%d]}"
cidr_block = "10.1.0.0/29"
}
`, nodeId, tensorFlowVer)
`, nodeId, versionIdx)
}
24 changes: 24 additions & 0 deletions google-beta/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -423,3 +423,27 @@ func serviceAccountFQN(serviceAccount string, d TerraformResourceData, config *C

return fmt.Sprintf("projects/-/serviceAccounts/%s@%s.iam.gserviceaccount.com", serviceAccount, project), nil
}

func paginatedListRequest(baseUrl string, config *Config, flattener func(map[string]interface{}) []interface{}) ([]interface{}, error) {
res, err := sendRequest(config, "GET", baseUrl, nil)
if err != nil {
return nil, err
}

ls := flattener(res)
pageToken, ok := res["pageToken"]
for ok {
if pageToken.(string) == "" {
break
}
url := fmt.Sprintf("%s?pageToken=%s", baseUrl, pageToken.(string))
res, err = sendRequest(config, "GET", url, nil)
if err != nil {
return nil, err
}
ls = append(ls, flattener(res))
pageToken, ok = res["pageToken"]
}

return ls, nil
}
8 changes: 6 additions & 2 deletions website/docs/r/tpu_node.html.markdown
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,14 @@ To get more information about Node, see:


```hcl
data "google_tpu_tensorflow_versions" "available" { }
resource "google_tpu_node" "tpu" {
name = "test-tpu"
zone = "us-central1-b"
accelerator_type = "v3-8"
tensorflow_version = "1.13"
tensorflow_version = "${data.google_tpu_tensorflow_versions.available.versions[0]}"
cidr_block = "10.2.0.0/29"
}
```
Expand All @@ -62,14 +64,16 @@ resource "google_compute_network" "tpu_network" {
auto_create_subnetworks = false
}
data "google_tpu_tensorflow_versions" "available" { }
resource "google_tpu_node" "tpu" {
name = "test-tpu"
zone = "us-central1-b"
accelerator_type = "v3-8"
cidr_block = "10.3.0.0/29"
tensorflow_version = "1.13"
tensorflow_version = "${data.google_tpu_tensorflow_versions.available.versions[0]}"
description = "Terraform Google Provider test TPU"
network = "${google_compute_network.tpu_network.name}"
Expand Down

0 comments on commit 9a12bc6

Please sign in to comment.