Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add internal/reflect package #30

Merged
merged 1 commit into from
Jun 1, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 96 additions & 0 deletions internal/reflect/helpers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
package reflect

import (
"context"
"errors"
"reflect"
"regexp"
"strings"

"github.com/hashicorp/terraform-plugin-go/tftypes"
)

// trueReflectValue returns the reflect.Value for `in` after derefencing all
// the pointers and unwrapping all the interfaces. It's the concrete value
// beneath it all.
func trueReflectValue(val reflect.Value) reflect.Value {
kind := val.Type().Kind()
for kind == reflect.Interface || kind == reflect.Ptr {
innerVal := val.Elem()
if !innerVal.IsValid() {
break
}
val = innerVal
kind = val.Type().Kind()
}
return val
}

// commaSeparatedString returns an English joining of the strings in `in`,
// using "and" and commas as appropriate.
func commaSeparatedString(in []string) string {
switch len(in) {
case 0:
return ""
case 1:
return in[0]
case 2:
return strings.Join(in, " and ")
default:
in[len(in)-1] = "and " + in[len(in)-1]
return strings.Join(in, ", ")
}
}

// getStructTags returns a map of Terraform field names to their position in
// the tags of the struct `in`. `in` must be a struct.
func getStructTags(ctx context.Context, in reflect.Value, path *tftypes.AttributePath) (map[string]int, error) {
tags := map[string]int{}
typ := trueReflectValue(in).Type()
if typ.Kind() != reflect.Struct {
return nil, path.NewErrorf("can't get struct tags of %s, is not a struct", in.Type())
}
for i := 0; i < typ.NumField(); i++ {
field := typ.Field(i)
if field.PkgPath != "" {
// skip unexported fields
continue
}
tag := field.Tag.Get(`tfsdk`)
if tag == "-" {
// skip explicitly excluded fields
continue
}
if tag == "" {
return nil, path.NewErrorf(`need a struct tag for "tfsdk" on %s`, field.Name)
}
path := path.WithAttributeName(tag)
if !isValidFieldName(tag) {
return nil, path.NewError(errors.New("invalid field name, must only use lowercase letters, underscores, and numbers, and must start with a letter"))
}
if other, ok := tags[tag]; ok {
return nil, path.NewErrorf("can't use field name for both %s and %s", typ.Field(other).Name, field.Name)
}
tags[tag] = i
}
return tags, nil
}

// isValidFieldName returns true if `name` can be used as a field name in a
// Terraform resource or data source.
func isValidFieldName(name string) bool {
re := regexp.MustCompile("^[a-z][a-z0-9_]*$")
return re.MatchString(name)
}

// canBeNil returns true if `target`'s type can hold a nil value
func canBeNil(target reflect.Value) bool {
switch target.Kind() {
case reflect.Ptr, reflect.Slice, reflect.Map, reflect.Interface:
// these types can all hold nils
return true
default:
// nothing else can be set to nil
return false
}
}
258 changes: 258 additions & 0 deletions internal/reflect/helpers_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,258 @@
package reflect

import (
"context"
"fmt"
"reflect"
"testing"

"github.com/hashicorp/terraform-plugin-go/tftypes"
)

func TestTrueReflectValue(t *testing.T) {
t.Parallel()

var iface, otherIface interface{}
var stru struct{}

// test that when nothing needs unwrapped, we get the right answer
if got := trueReflectValue(reflect.ValueOf(stru)).Kind(); got != reflect.Struct {
t.Errorf("Expected %s, got %s", reflect.Struct, got)
}

// test that we can unwrap pointers
if got := trueReflectValue(reflect.ValueOf(&stru)).Kind(); got != reflect.Struct {
t.Errorf("Expected %s, got %s", reflect.Struct, got)
}

// test that we can unwrap interfaces
iface = stru
if got := trueReflectValue(reflect.ValueOf(iface)).Kind(); got != reflect.Struct {
t.Errorf("Expected %s, got %s", reflect.Struct, got)
}

// test that we can unwrap pointers inside interfaces, and pointers to
// interfaces with pointers inside them
iface = &stru
if got := trueReflectValue(reflect.ValueOf(iface)).Kind(); got != reflect.Struct {
t.Errorf("Expected %s, got %s", reflect.Struct, got)
}
if got := trueReflectValue(reflect.ValueOf(&iface)).Kind(); got != reflect.Struct {
t.Errorf("Expected %s, got %s", reflect.Struct, got)
}

// test that we can unwrap pointers to interfaces inside other
// interfaces, and pointers to interfaces inside pointers to
// interfaces.
otherIface = &iface
if got := trueReflectValue(reflect.ValueOf(otherIface)).Kind(); got != reflect.Struct {
t.Errorf("Expected %s, got %s", reflect.Struct, got)
}
if got := trueReflectValue(reflect.ValueOf(&otherIface)).Kind(); got != reflect.Struct {
t.Errorf("Expected %s, got %s", reflect.Struct, got)
}
}

func TestCommaSeparatedString(t *testing.T) {
t.Parallel()
type testCase struct {
input []string
expected string
}
tests := map[string]testCase{
"empty": {
input: []string{},
expected: "",
},
"oneWord": {
input: []string{"red"},
expected: "red",
},
"twoWords": {
input: []string{"red", "blue"},
expected: "red and blue",
},
"threeWords": {
input: []string{"red", "blue", "green"},
expected: "red, blue, and green",
},
"fourWords": {
input: []string{"red", "blue", "green", "purple"},
expected: "red, blue, green, and purple",
},
}
for name, test := range tests {
name, test := name, test
t.Run(name, func(t *testing.T) {
t.Parallel()
got := commaSeparatedString(test.input)
if got != test.expected {
t.Errorf("Expected %q, got %q", test.expected, got)
}
})
}
}

func TestGetStructTags_success(t *testing.T) {
t.Parallel()

type testStruct struct {
ExportedAndTagged string `tfsdk:"exported_and_tagged"`
unexported string //nolint:structcheck,unused
unexportedAndTagged string `tfsdk:"unexported_and_tagged"`
ExportedAndExcluded string `tfsdk:"-"`
}

res, err := getStructTags(context.Background(), reflect.ValueOf(testStruct{}), tftypes.NewAttributePath())
if err != nil {
t.Errorf("Unexpected error: %s", err)
}
if len(res) != 1 {
t.Errorf("Unexpected result: %v", res)
}
if res["exported_and_tagged"] != 0 {
t.Errorf("Unexpected result: %v", res)
}
}

func TestGetStructTags_untagged(t *testing.T) {
t.Parallel()
type testStruct struct {
ExportedAndUntagged string
}
_, err := getStructTags(context.Background(), reflect.ValueOf(testStruct{}), tftypes.NewAttributePath())
if err == nil {
t.Error("Expected error, got nil")
}
expected := `: need a struct tag for "tfsdk" on ExportedAndUntagged`
if err.Error() != expected {
t.Errorf("Expected error to be %q, got %q", expected, err.Error())
}
}

func TestGetStructTags_invalidTag(t *testing.T) {
t.Parallel()
type testStruct struct {
InvalidTag string `tfsdk:"invalidTag"`
}
_, err := getStructTags(context.Background(), reflect.ValueOf(testStruct{}), tftypes.NewAttributePath())
if err == nil {
t.Errorf("Expected error, got nil")
}
expected := `AttributeName("invalidTag"): invalid field name, must only use lowercase letters, underscores, and numbers, and must start with a letter`
if err.Error() != expected {
t.Errorf("Expected error to be %q, got %q", expected, err.Error())
}
}

func TestGetStructTags_duplicateTag(t *testing.T) {
t.Parallel()
type testStruct struct {
Field1 string `tfsdk:"my_field"`
Field2 string `tfsdk:"my_field"`
}
_, err := getStructTags(context.Background(), reflect.ValueOf(testStruct{}), tftypes.NewAttributePath())
if err == nil {
t.Errorf("Expected error, got nil")
}
expected := `AttributeName("my_field"): can't use field name for both Field1 and Field2`
if err.Error() != expected {
t.Errorf("Expected error to be %q, got %q", expected, err.Error())
}
}

func TestGetStructTags_notAStruct(t *testing.T) {
t.Parallel()
var testStruct string

_, err := getStructTags(context.Background(), reflect.ValueOf(testStruct), tftypes.NewAttributePath())
if err == nil {
t.Errorf("Expected error, got nil")
}
expected := `: can't get struct tags of string, is not a struct`
if err.Error() != expected {
t.Errorf("Expected error to be %q, got %q", expected, err.Error())
}
}

func TestIsValidFieldName(t *testing.T) {
t.Parallel()
tests := map[string]bool{
"": false,
"a": true,
"1": false,
"1a": false,
"a1": true,
"A": false,
"a-b": false,
"a_b": true,
}
for in, expected := range tests {
in, expected := in, expected
t.Run(fmt.Sprintf("input=%q", in), func(t *testing.T) {
t.Parallel()

result := isValidFieldName(in)
if result != expected {
t.Errorf("Expected %v, got %v", expected, result)
}
})
}
}

func TestCanBeNil_struct(t *testing.T) {
t.Parallel()

var stru struct{}

got := canBeNil(reflect.ValueOf(stru))
if got {
t.Error("Expected structs to not be nillable, but canBeNil said they were")
}
}

func TestCanBeNil_structPointer(t *testing.T) {
t.Parallel()

var stru struct{}
struPtr := &stru

got := canBeNil(reflect.ValueOf(struPtr))
if !got {
t.Error("Expected pointers to structs to be nillable, but canBeNil said they weren't")
}
}

func TestCanBeNil_slice(t *testing.T) {
t.Parallel()

slice := []string{}
got := canBeNil(reflect.ValueOf(slice))
if !got {
t.Errorf("Expected slices to be nillable, but canBeNil said they weren't")
}
}

func TestCanBeNil_map(t *testing.T) {
t.Parallel()

m := map[string]string{}
got := canBeNil(reflect.ValueOf(m))
if !got {
t.Errorf("Expected maps to be nillable, but canBeNil said they weren't")
}
}

func TestCanBeNil_interface(t *testing.T) {
t.Parallel()

type myStruct struct {
Value interface{}
}

var s myStruct
got := canBeNil(reflect.ValueOf(s).FieldByName("Value"))
if !got {
t.Errorf("Expected interfaces to be nillable, but canBeNil said they weren't")
}
}
Loading