diff --git a/azurerm/resource_arm_role_assignment.go b/azurerm/resource_arm_role_assignment.go index 89b4c7a528f8..aaa77fb127cf 100644 --- a/azurerm/resource_arm_role_assignment.go +++ b/azurerm/resource_arm_role_assignment.go @@ -182,7 +182,6 @@ func validateRoleDefinitionName(i interface{}, k string) ([]string, []error) { } func retryRoleAssignmentsClient(scope string, name string, properties authorization.RoleAssignmentCreateParameters, meta interface{}) func() *resource.RetryError { - return func() *resource.RetryError { roleAssignmentsClient := meta.(*ArmClient).roleAssignmentsClient ctx := meta.(*ArmClient).StopContext @@ -190,7 +189,11 @@ func retryRoleAssignmentsClient(scope string, name string, properties authorizat _, err := roleAssignmentsClient.Create(ctx, scope, name, properties) if err != nil { - return resource.RetryableError(err) + if utils.ResponseErrorIsRetryable(err) { + return resource.RetryableError(err) + } else { + return resource.NonRetryableError(err) + } } return nil diff --git a/azurerm/utils/response.go b/azurerm/utils/response.go index 287cec4071fa..98d7ec20ef9e 100644 --- a/azurerm/utils/response.go +++ b/azurerm/utils/response.go @@ -1,6 +1,7 @@ package utils import ( + "net" "net/http" "github.com/Azure/go-autorest/autorest" @@ -14,6 +15,21 @@ func ResponseWasNotFound(resp autorest.Response) bool { return responseWasStatusCode(resp, http.StatusNotFound) } +func ResponseErrorIsRetryable(err error) bool { + if arerr, ok := err.(autorest.DetailedError); ok { + err = arerr.Original + } + + switch e := err.(type) { + case net.Error: + if e.Temporary() || e.Timeout() { + return true + } + } + + return false +} + func responseWasStatusCode(resp autorest.Response, statusCode int) bool { if r := resp.Response; r != nil { if r.StatusCode == statusCode { diff --git a/azurerm/utils/response_test.go b/azurerm/utils/response_test.go index 809189f512bc..afd4de17a7f2 100644 --- a/azurerm/utils/response_test.go +++ b/azurerm/utils/response_test.go @@ -1,6 +1,7 @@ package utils import ( + "fmt" "net/http" "testing" @@ -37,3 +38,40 @@ func TestResponseNotFound_StatusCodes(t *testing.T) { } } } + +type testNetError struct { + timeout bool + temporary bool +} + +// testNetError fulfills net.Error interface +func (e testNetError) Error() string { return "testError" } +func (e testNetError) Timeout() bool { return e.timeout } +func (e testNetError) Temporary() bool { return e.temporary } + +func TestResponseErrorIsRetryable(t *testing.T) { + testCases := []struct { + desc string + err error + expectedResult bool + }{ + {"Unhandled error types are not retryable", fmt.Errorf("Some other error"), false}, + {"Temporary AND timeout errors are retryable", testNetError{true, true}, true}, + {"Timeout errors are retryable", testNetError{true, false}, true}, + {"Temporary errors are retryable", testNetError{false, true}, true}, + {"net.Errors that are neither temporary nor timeouts are not retryable", testNetError{false, false}, false}, + {"Retryable error nested in autorest.DetailedError is retryable", autorest.DetailedError{ + Original: testNetError{true, true}}, true}, + {"Unhandled error nested in autorest.DetailedError is not retryable", autorest.DetailedError{ + Original: fmt.Errorf("Some other error")}, false}, + {"nil is handled as non-retryable", nil, false}, + } + + for _, test := range testCases { + result := ResponseErrorIsRetryable(test.err) + if test.expectedResult != result { + t.Errorf("Expected '%v' for case '%s' - got '%v'", + test.expectedResult, test.desc, result) + } + } +}