diff --git a/sdk/internal/testframework/README.md b/sdk/internal/testframework/README.md index d545700c3535..a3a496f048bd 100644 --- a/sdk/internal/testframework/README.md +++ b/sdk/internal/testframework/README.md @@ -21,8 +21,13 @@ or indicate that the test IsFailed. ***Note**: an instance of TestContext should be initialized for each test.* ```go +type testState struct { + recording *testframework.Recording + client *TableServiceClient + context *testframework.TestContext +} // a map to store our created test contexts -var clientsMap map[string]*testContext = make(map[string]*testContext) +var clientsMap map[string]*testState = make(map[string]*testState) // recordedTestSetup is called before each test execution by the test suite's BeforeTest method func recordedTestSetup(t *testing.T, testName string, mode testframework.RecordMode) { @@ -80,11 +85,11 @@ The last step is to instrument your client by replacing its transport with your assert.Nil(err) // either return your client instance, or store it somewhere that your test can use it for test execution. - clientsMap[testName] = &testContext{client: client, recording: recording, context: &context} + clientsMap[testName] = &testState{client: client, recording: recording, context: &context} } -func getTestContext(key string) *testContext { +func getTestState(key string) *testState { return clientsMap[key] } ``` @@ -132,7 +137,7 @@ type tableServiceClientLiveTests struct { // Hookup to the testing framework func TestServiceClient_Storage(t *testing.T) { - storage := tableServiceClientLiveTests{endpointType: StorageEndpoint, mode: testframework.Playback /* change to Record to re-record tests */} + storage := tableServiceClientLiveTests{mode: testframework.Playback /* change to Record to re-record tests */} suite.Run(t, &storage) } @@ -151,7 +156,7 @@ func (s *tableServiceClientLiveTests) TestCreateTable() { func (s *tableServiceClientLiveTests) BeforeTest(suite string, test string) { // setup the test environment - recordedTestSetup(s.T(), s.T().Name(), s.endpointType, s.mode) + recordedTestSetup(s.T(), s.T().Name(), s.mode) } func (s *tableServiceClientLiveTests) AfterTest(suite string, test string) { diff --git a/sdk/internal/testframework/recording.go b/sdk/internal/testframework/recording.go index 38e4495945be..9a74549a1568 100644 --- a/sdk/internal/testframework/recording.go +++ b/sdk/internal/testframework/recording.go @@ -64,8 +64,11 @@ const ( type VariableType string const ( - Default VariableType = "default" - Secret_String VariableType = "secret_string" + // NoSanitization indicates that the recorded value should not be sanitized. + NoSanitization VariableType = "default" + // Secret_String indicates that the recorded value should be replaced with a sanitized value. + Secret_String VariableType = "secret_string" + // Secret_Base64String indicates that the recorded value should be replaced with a sanitized valid base-64 string value. Secret_Base64String VariableType = "secret_base64String" ) @@ -102,18 +105,18 @@ func NewRecording(c TestContext, mode RecordMode) (*Recording, error) { } // set the recorder Matcher - recording.Matcher = DefaultMatcher(c) + recording.Matcher = defaultMatcher(c) rec.SetMatcher(recording.matchRequest) // wire up the sanitizer - recording.Sanitizer = DefaultSanitizer(rec) + recording.Sanitizer = defaultSanitizer(rec) return recording, err } -// GetRecordedVariable returns a recorded variable. If the variable is not found we return an error -// variableType determines how the recorded variable will be saved. Default indicates that the value should be saved without any sanitation. -func (r *Recording) GetRecordedVariable(name string, variableType VariableType) (string, error) { +// GetEnvVar returns a recorded environment variable. If the variable is not found we return an error. +// variableType determines how the recorded variable will be saved. +func (r *Recording) GetEnvVar(name string, variableType VariableType) (string, error) { var err error result, ok := r.previousSessionVariables[name] if !ok || r.Mode == Live { @@ -128,9 +131,10 @@ func (r *Recording) GetRecordedVariable(name string, variableType VariableType) return *result, err } -// GetOptionalRecordedVariable returns a recorded variable with a fallback default value -// variableType determines how the recorded variable will be saved. Default indicates that the value should be saved without any sanitation. -func (r *Recording) GetOptionalRecordedVariable(name string, defaultValue string, variableType VariableType) string { +// GetOptionalEnvVar returns a recorded environment variable with a fallback default value. +// default Value configures the fallback value to be returned if the environment variable is not set. +// variableType determines how the recorded variable will be saved. +func (r *Recording) GetOptionalEnvVar(name string, defaultValue string, variableType VariableType) string { result, ok := r.previousSessionVariables[name] if !ok || r.Mode == Live { result = getOptionalEnv(name, defaultValue) diff --git a/sdk/internal/testframework/recording_sanitizer.go b/sdk/internal/testframework/recording_sanitizer.go index 873844b7703b..16d728688d83 100644 --- a/sdk/internal/testframework/recording_sanitizer.go +++ b/sdk/internal/testframework/recording_sanitizer.go @@ -26,7 +26,9 @@ const SanitizedBase64Value string = "Kg==" var sanitizedValueSlice = []string{SanitizedValue} -func DefaultSanitizer(recorder *recorder.Recorder) *RecordingSanitizer { +// defaultSanitizer returns a new RecordingSanitizer with the default sanitizing behavior. +// To customize sanitization, call AddSanitizedHeaders, AddBodySanitizer, or AddUrlSanitizer. +func defaultSanitizer(recorder *recorder.Recorder) *RecordingSanitizer { // The default sanitizer sanitizes the Authorization header s := &RecordingSanitizer{headersToSanitize: map[string]*string{"Authorization": nil}, recorder: recorder, urlSanitizer: DefaultStringSanitizer, bodySanitizer: DefaultStringSanitizer} recorder.AddSaveFilter(s.applySaveFilter) diff --git a/sdk/internal/testframework/recording_sanitizer_test.go b/sdk/internal/testframework/recording_sanitizer_test.go index 570dfb3b005b..23adad100ea4 100644 --- a/sdk/internal/testframework/recording_sanitizer_test.go +++ b/sdk/internal/testframework/recording_sanitizer_test.go @@ -39,7 +39,7 @@ func (s *recordingSanitizerTests) TestDefaultSanitizerSanitizesAuthHeader() { rt := NewMockRoundTripper(server) r, _ := recorder.NewAsMode(getTestFileName(s.T(), false), recorder.ModeRecording, rt) - DefaultSanitizer(r) + defaultSanitizer(r) req, _ := http.NewRequest(http.MethodPost, server.URL(), nil) req.Header.Add(authHeader, "superSecret") @@ -65,7 +65,7 @@ func (s *recordingSanitizerTests) TestAddSanitizedHeadersSanitizes() { rt := NewMockRoundTripper(server) r, _ := recorder.NewAsMode(getTestFileName(s.T(), false), recorder.ModeRecording, rt) - target := DefaultSanitizer(r) + target := defaultSanitizer(r) target.AddSanitizedHeaders(customHeader1, customHeader2) req, _ := http.NewRequest(http.MethodPost, server.URL(), nil) @@ -103,7 +103,7 @@ func (s *recordingSanitizerTests) TestAddUrlSanitizerSanitizes() { baseUrl := server.URL() + "/" - target := DefaultSanitizer(r) + target := defaultSanitizer(r) target.AddUrlSanitizer(func(url *string) { *url = strings.Replace(*url, secret, SanitizedValue, -1) }) diff --git a/sdk/internal/testframework/recording_test.go b/sdk/internal/testframework/recording_test.go index 4ed73f5805ce..15e4ea2e4b04 100644 --- a/sdk/internal/testframework/recording_test.go +++ b/sdk/internal/testframework/recording_test.go @@ -70,17 +70,17 @@ func (s *recordingTests) TestRecordedVariables() { assert.Nil(err) // optional variables always succeed. - assert.Equal(expectedVariableValue, target.GetOptionalRecordedVariable(nonExistingEnvVar, expectedVariableValue, Default)) + assert.Equal(expectedVariableValue, target.GetOptionalEnvVar(nonExistingEnvVar, expectedVariableValue, NoSanitization)) // non existent variables return an error - val, err := target.GetRecordedVariable(nonExistingEnvVar, Default) + val, err := target.GetEnvVar(nonExistingEnvVar, NoSanitization) // mark test as succeeded assert.Equal(envNotExistsError(nonExistingEnvVar), err.Error()) // now create the env variable and check that it can be fetched os.Setenv(nonExistingEnvVar, expectedVariableValue) defer os.Unsetenv(nonExistingEnvVar) - val, err = target.GetRecordedVariable(nonExistingEnvVar, Default) + val, err = target.GetEnvVar(nonExistingEnvVar, NoSanitization) assert.Equal(expectedVariableValue, val) err = target.Stop() @@ -107,10 +107,10 @@ func (s *recordingTests) TestRecordedVariablesSanitized() { assert.Nil(err) // call GetOptionalRecordedVariable with the Secret_String VariableType arg - assert.Equal(secret, target.GetOptionalRecordedVariable(SanitizedStringVar, secret, Secret_String)) + assert.Equal(secret, target.GetOptionalEnvVar(SanitizedStringVar, secret, Secret_String)) // call GetOptionalRecordedVariable with the Secret_Base64String VariableType arg - assert.Equal(secretBase64, target.GetOptionalRecordedVariable(SanitizedBase64StrigVar, secretBase64, Secret_Base64String)) + assert.Equal(secretBase64, target.GetOptionalEnvVar(SanitizedBase64StrigVar, secretBase64, Secret_Base64String)) // Calling Stop will save the variables and apply the sanitization options err = target.Stop() @@ -143,7 +143,7 @@ func (s *recordingTests) TestStopSavesVariablesIfExistAndReadsPreviousVariables( target, err := NewRecording(context, Playback) assert.Nil(err) - target.GetOptionalRecordedVariable(expectedVariableName, expectedVariableValue, Default) + target.GetOptionalEnvVar(expectedVariableName, expectedVariableValue, NoSanitization) err = target.Stop() assert.Nil(err) @@ -159,7 +159,7 @@ func (s *recordingTests) TestStopSavesVariablesIfExistAndReadsPreviousVariables( assert.Nil(err) // add a new variable to the existing batch - target2.GetOptionalRecordedVariable(addedVariableName, addedVariableValue, Default) + target2.GetOptionalEnvVar(addedVariableName, addedVariableValue, NoSanitization) err = target2.Stop() assert.Nil(err) diff --git a/sdk/internal/testframework/request_matcher.go b/sdk/internal/testframework/request_matcher.go index c9405d03593a..0e4fa6be1ec6 100644 --- a/sdk/internal/testframework/request_matcher.go +++ b/sdk/internal/testframework/request_matcher.go @@ -16,8 +16,10 @@ import ( ) type RequestMatcher struct { - context TestContext - IgnoredHeaders map[string]*string + context TestContext + // IgnoredHeaders is a map acting as a hash set of the header names that will be ignored for matching. + // Modifying the keys in the map will affect how headers are matched for recordings. + IgnoredHeaders map[string]struct{} bodyMatcher StringMatcher urlMatcher StringMatcher methodMatcher StringMatcher @@ -26,15 +28,15 @@ type RequestMatcher struct { type StringMatcher func(reqVal string, recVal string) bool type matcherWrapper func(matcher StringMatcher, testContext TestContext) bool -var ignoredHeaders = map[string]*string{ - "Date": nil, - "X-Ms-Date": nil, - "x-ms-date": nil, - "x-ms-client-request-id": nil, - "User-Agent": nil, - "Request-Id": nil, - "traceparent": nil, - "Authorization": nil, +var ignoredHeaders = map[string]struct{}{ + "Date": {}, + "X-Ms-Date": {}, + "x-ms-date": {}, + "x-ms-client-request-id": {}, + "User-Agent": {}, + "Request-Id": {}, + "traceparent": {}, + "Authorization": {}, } const ( @@ -46,25 +48,27 @@ const ( bodiesMismatch = "Test recording bodies do not match.\nrequest: %s\nrecording: %s" ) -func DefaultMatcher(testContext TestContext) *RequestMatcher { +// defaultMatcher returns a new RequestMatcher configured with the default matching behavior. +func defaultMatcher(testContext TestContext) *RequestMatcher { // The default sanitizer sanitizes the Authorization header matcher := &RequestMatcher{ context: testContext, IgnoredHeaders: ignoredHeaders, } matcher.SetBodyMatcher(func(req string, rec string) bool { - return DefaultStringMatcher(req, rec) + return defaultStringMatcher(req, rec) }) matcher.SetURLMatcher(func(req string, rec string) bool { - return DefaultStringMatcher(req, rec) + return defaultStringMatcher(req, rec) }) matcher.SetMethodMatcher(func(req string, rec string) bool { - return DefaultStringMatcher(req, rec) + return defaultStringMatcher(req, rec) }) return matcher } +// SetBodyMatcher replaces the default matching behavior with a custom StringMatcher that compares the string value of the request body payload with the string value of the recorded body payload. func (m *RequestMatcher) SetBodyMatcher(matcher StringMatcher) { m.bodyMatcher = func(reqVal string, recVal string) bool { isMatch := matcher(reqVal, recVal) @@ -75,6 +79,7 @@ func (m *RequestMatcher) SetBodyMatcher(matcher StringMatcher) { } } +// SetURLMatcher replaces the default matching behavior with a custom StringMatcher that compares the string value of the request URL with the string value of the recorded URL func (m *RequestMatcher) SetURLMatcher(matcher StringMatcher) { m.urlMatcher = func(reqVal string, recVal string) bool { isMatch := matcher(reqVal, recVal) @@ -85,6 +90,7 @@ func (m *RequestMatcher) SetURLMatcher(matcher StringMatcher) { } } +// SetMethodMatcher replaces the default matching behavior with a custom StringMatcher that compares the string value of the request method with the string value of the recorded method func (m *RequestMatcher) SetMethodMatcher(matcher StringMatcher) { m.methodMatcher = func(reqVal string, recVal string) bool { isMatch := matcher(reqVal, recVal) @@ -95,7 +101,7 @@ func (m *RequestMatcher) SetMethodMatcher(matcher StringMatcher) { } } -func DefaultStringMatcher(s1 string, s2 string) bool { +func defaultStringMatcher(s1 string, s2 string) bool { return s1 == s2 } diff --git a/sdk/internal/testframework/request_matcher_test.go b/sdk/internal/testframework/request_matcher_test.go index f85e7110ee22..1dd07f5c006a 100644 --- a/sdk/internal/testframework/request_matcher_test.go +++ b/sdk/internal/testframework/request_matcher_test.go @@ -33,7 +33,7 @@ const unMatchedBody string = "This body does not match." func (s *requestMatcherTests) TestCompareBodies() { assert := assert.New(s.T()) context := NewTestContext(func(msg string) { assert.FailNow(msg) }, func(msg string) { s.T().Log(msg) }, func() string { return s.T().Name() }) - matcher := DefaultMatcher(context) + matcher := defaultMatcher(context) req := http.Request{Body: closerFromString(matchedBody)} recReq := cassette.Request{Body: matchedBody} @@ -53,7 +53,7 @@ func (s *requestMatcherTests) TestCompareBodies() { func (s *requestMatcherTests) TestCompareHeadersIgnoresIgnoredHeaders() { assert := assert.New(s.T()) context := NewTestContext(func(msg string) { assert.FailNow(msg) }, func(msg string) { s.T().Log(msg) }, func() string { return s.T().Name() }) - matcher := DefaultMatcher(context) + matcher := defaultMatcher(context) // populate only ignored headers that do not match reqHeaders := make(http.Header) @@ -73,7 +73,7 @@ func (s *requestMatcherTests) TestCompareHeadersIgnoresIgnoredHeaders() { func (s *requestMatcherTests) TestCompareHeadersMatchesHeaders() { assert := assert.New(s.T()) context := NewTestContext(func(msg string) { assert.FailNow(msg) }, func(msg string) { s.T().Log(msg) }, func() string { return s.T().Name() }) - matcher := DefaultMatcher(context) + matcher := defaultMatcher(context) // populate only ignored headers that do not match reqHeaders := make(http.Header) @@ -93,7 +93,7 @@ func (s *requestMatcherTests) TestCompareHeadersMatchesHeaders() { func (s *requestMatcherTests) TestCompareHeadersFailsMissingRecHeader() { assert := assert.New(s.T()) context := NewTestContext(func(msg string) { assert.FailNow(msg) }, func(msg string) { s.T().Log(msg) }, func() string { return s.T().Name() }) - matcher := DefaultMatcher(context) + matcher := defaultMatcher(context) // populate only ignored headers that do not match reqHeaders := make(http.Header) @@ -117,7 +117,7 @@ func (s *requestMatcherTests) TestCompareHeadersFailsMissingRecHeader() { func (s *requestMatcherTests) TestCompareHeadersFailsMissingReqHeader() { assert := assert.New(s.T()) context := NewTestContext(func(msg string) { assert.FailNow(msg) }, func(msg string) { s.T().Log(msg) }, func() string { return s.T().Name() }) - matcher := DefaultMatcher(context) + matcher := defaultMatcher(context) // populate only ignored headers that do not match reqHeaders := make(http.Header) @@ -141,7 +141,7 @@ func (s *requestMatcherTests) TestCompareHeadersFailsMissingReqHeader() { func (s *requestMatcherTests) TestCompareHeadersFailsMismatchedValues() { assert := assert.New(s.T()) context := NewTestContext(func(msg string) { assert.FailNow(msg) }, func(msg string) { s.T().Log(msg) }, func() string { return s.T().Name() }) - matcher := DefaultMatcher(context) + matcher := defaultMatcher(context) // populate only ignored headers that do not match reqHeaders := make(http.Header) @@ -171,7 +171,7 @@ func (s *requestMatcherTests) TestCompareURLs() { host := "foo.bar" req := http.Request{URL: &url.URL{Scheme: scheme, Host: host}} recReq := cassette.Request{URL: scheme + "://" + host} - matcher := DefaultMatcher(context) + matcher := defaultMatcher(context) assert.True(matcher.compareURLs(&req, recReq.URL)) @@ -187,7 +187,7 @@ func (s *requestMatcherTests) TestCompareMethods() { methodPatch := "PATCH" req := http.Request{Method: methodPost} recReq := cassette.Request{Method: methodPost} - matcher := DefaultMatcher(context) + matcher := defaultMatcher(context) assert.True(matcher.compareMethods(&req, recReq.Method))