Skip to content

Commit

Permalink
fb
Browse files Browse the repository at this point in the history
  • Loading branch information
christothes committed Jun 15, 2021
1 parent 921c67f commit 438d7e1
Show file tree
Hide file tree
Showing 7 changed files with 67 additions and 50 deletions.
15 changes: 10 additions & 5 deletions sdk/internal/testframework/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,13 @@ or indicate that the test IsFailed.
***Note**: an instance of TestContext should be initialized for each test.*

```go
type testState struct {
recording *testframework.Recording
client *TableServiceClient
context *testframework.TestContext
}
// a map to store our created test contexts
var clientsMap map[string]*testContext = make(map[string]*testContext)
var clientsMap map[string]*testState = make(map[string]*testState)

// recordedTestSetup is called before each test execution by the test suite's BeforeTest method
func recordedTestSetup(t *testing.T, testName string, mode testframework.RecordMode) {
Expand Down Expand Up @@ -80,11 +85,11 @@ The last step is to instrument your client by replacing its transport with your
assert.Nil(err)

// either return your client instance, or store it somewhere that your test can use it for test execution.
clientsMap[testName] = &testContext{client: client, recording: recording, context: &context}
clientsMap[testName] = &testState{client: client, recording: recording, context: &context}
}


func getTestContext(key string) *testContext {
func getTestState(key string) *testState {
return clientsMap[key]
}
```
Expand Down Expand Up @@ -132,7 +137,7 @@ type tableServiceClientLiveTests struct {

// Hookup to the testing framework
func TestServiceClient_Storage(t *testing.T) {
storage := tableServiceClientLiveTests{endpointType: StorageEndpoint, mode: testframework.Playback /* change to Record to re-record tests */}
storage := tableServiceClientLiveTests{mode: testframework.Playback /* change to Record to re-record tests */}
suite.Run(t, &storage)
}

Expand All @@ -151,7 +156,7 @@ func (s *tableServiceClientLiveTests) TestCreateTable() {

func (s *tableServiceClientLiveTests) BeforeTest(suite string, test string) {
// setup the test environment
recordedTestSetup(s.T(), s.T().Name(), s.endpointType, s.mode)
recordedTestSetup(s.T(), s.T().Name(), s.mode)
}

func (s *tableServiceClientLiveTests) AfterTest(suite string, test string) {
Expand Down
24 changes: 14 additions & 10 deletions sdk/internal/testframework/recording.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,11 @@ const (
type VariableType string

const (
Default VariableType = "default"
Secret_String VariableType = "secret_string"
// NoSanitization indicates that the recorded value should not be sanitized.
NoSanitization VariableType = "default"
// Secret_String indicates that the recorded value should be replaced with a sanitized value.
Secret_String VariableType = "secret_string"
// Secret_Base64String indicates that the recorded value should be replaced with a sanitized valid base-64 string value.
Secret_Base64String VariableType = "secret_base64String"
)

Expand Down Expand Up @@ -102,18 +105,18 @@ func NewRecording(c TestContext, mode RecordMode) (*Recording, error) {
}

// set the recorder Matcher
recording.Matcher = DefaultMatcher(c)
recording.Matcher = defaultMatcher(c)
rec.SetMatcher(recording.matchRequest)

// wire up the sanitizer
recording.Sanitizer = DefaultSanitizer(rec)
recording.Sanitizer = defaultSanitizer(rec)

return recording, err
}

// GetRecordedVariable returns a recorded variable. If the variable is not found we return an error
// variableType determines how the recorded variable will be saved. Default indicates that the value should be saved without any sanitation.
func (r *Recording) GetRecordedVariable(name string, variableType VariableType) (string, error) {
// GetEnvVar returns a recorded environment variable. If the variable is not found we return an error.
// variableType determines how the recorded variable will be saved.
func (r *Recording) GetEnvVar(name string, variableType VariableType) (string, error) {
var err error
result, ok := r.previousSessionVariables[name]
if !ok || r.Mode == Live {
Expand All @@ -128,9 +131,10 @@ func (r *Recording) GetRecordedVariable(name string, variableType VariableType)
return *result, err
}

// GetOptionalRecordedVariable returns a recorded variable with a fallback default value
// variableType determines how the recorded variable will be saved. Default indicates that the value should be saved without any sanitation.
func (r *Recording) GetOptionalRecordedVariable(name string, defaultValue string, variableType VariableType) string {
// GetOptionalEnvVar returns a recorded environment variable with a fallback default value.
// default Value configures the fallback value to be returned if the environment variable is not set.
// variableType determines how the recorded variable will be saved.
func (r *Recording) GetOptionalEnvVar(name string, defaultValue string, variableType VariableType) string {
result, ok := r.previousSessionVariables[name]
if !ok || r.Mode == Live {
result = getOptionalEnv(name, defaultValue)
Expand Down
4 changes: 3 additions & 1 deletion sdk/internal/testframework/recording_sanitizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ const SanitizedBase64Value string = "Kg=="

var sanitizedValueSlice = []string{SanitizedValue}

func DefaultSanitizer(recorder *recorder.Recorder) *RecordingSanitizer {
// defaultSanitizer returns a new RecordingSanitizer with the default sanitizing behavior.
// To customize sanitization, call AddSanitizedHeaders, AddBodySanitizer, or AddUrlSanitizer.
func defaultSanitizer(recorder *recorder.Recorder) *RecordingSanitizer {
// The default sanitizer sanitizes the Authorization header
s := &RecordingSanitizer{headersToSanitize: map[string]*string{"Authorization": nil}, recorder: recorder, urlSanitizer: DefaultStringSanitizer, bodySanitizer: DefaultStringSanitizer}
recorder.AddSaveFilter(s.applySaveFilter)
Expand Down
6 changes: 3 additions & 3 deletions sdk/internal/testframework/recording_sanitizer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func (s *recordingSanitizerTests) TestDefaultSanitizerSanitizesAuthHeader() {
rt := NewMockRoundTripper(server)
r, _ := recorder.NewAsMode(getTestFileName(s.T(), false), recorder.ModeRecording, rt)

DefaultSanitizer(r)
defaultSanitizer(r)

req, _ := http.NewRequest(http.MethodPost, server.URL(), nil)
req.Header.Add(authHeader, "superSecret")
Expand All @@ -65,7 +65,7 @@ func (s *recordingSanitizerTests) TestAddSanitizedHeadersSanitizes() {
rt := NewMockRoundTripper(server)
r, _ := recorder.NewAsMode(getTestFileName(s.T(), false), recorder.ModeRecording, rt)

target := DefaultSanitizer(r)
target := defaultSanitizer(r)
target.AddSanitizedHeaders(customHeader1, customHeader2)

req, _ := http.NewRequest(http.MethodPost, server.URL(), nil)
Expand Down Expand Up @@ -103,7 +103,7 @@ func (s *recordingSanitizerTests) TestAddUrlSanitizerSanitizes() {

baseUrl := server.URL() + "/"

target := DefaultSanitizer(r)
target := defaultSanitizer(r)
target.AddUrlSanitizer(func(url *string) {
*url = strings.Replace(*url, secret, SanitizedValue, -1)
})
Expand Down
14 changes: 7 additions & 7 deletions sdk/internal/testframework/recording_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,17 +70,17 @@ func (s *recordingTests) TestRecordedVariables() {
assert.Nil(err)

// optional variables always succeed.
assert.Equal(expectedVariableValue, target.GetOptionalRecordedVariable(nonExistingEnvVar, expectedVariableValue, Default))
assert.Equal(expectedVariableValue, target.GetOptionalEnvVar(nonExistingEnvVar, expectedVariableValue, NoSanitization))

// non existent variables return an error
val, err := target.GetRecordedVariable(nonExistingEnvVar, Default)
val, err := target.GetEnvVar(nonExistingEnvVar, NoSanitization)
// mark test as succeeded
assert.Equal(envNotExistsError(nonExistingEnvVar), err.Error())

// now create the env variable and check that it can be fetched
os.Setenv(nonExistingEnvVar, expectedVariableValue)
defer os.Unsetenv(nonExistingEnvVar)
val, err = target.GetRecordedVariable(nonExistingEnvVar, Default)
val, err = target.GetEnvVar(nonExistingEnvVar, NoSanitization)
assert.Equal(expectedVariableValue, val)

err = target.Stop()
Expand All @@ -107,10 +107,10 @@ func (s *recordingTests) TestRecordedVariablesSanitized() {
assert.Nil(err)

// call GetOptionalRecordedVariable with the Secret_String VariableType arg
assert.Equal(secret, target.GetOptionalRecordedVariable(SanitizedStringVar, secret, Secret_String))
assert.Equal(secret, target.GetOptionalEnvVar(SanitizedStringVar, secret, Secret_String))

// call GetOptionalRecordedVariable with the Secret_Base64String VariableType arg
assert.Equal(secretBase64, target.GetOptionalRecordedVariable(SanitizedBase64StrigVar, secretBase64, Secret_Base64String))
assert.Equal(secretBase64, target.GetOptionalEnvVar(SanitizedBase64StrigVar, secretBase64, Secret_Base64String))

// Calling Stop will save the variables and apply the sanitization options
err = target.Stop()
Expand Down Expand Up @@ -143,7 +143,7 @@ func (s *recordingTests) TestStopSavesVariablesIfExistAndReadsPreviousVariables(
target, err := NewRecording(context, Playback)
assert.Nil(err)

target.GetOptionalRecordedVariable(expectedVariableName, expectedVariableValue, Default)
target.GetOptionalEnvVar(expectedVariableName, expectedVariableValue, NoSanitization)

err = target.Stop()
assert.Nil(err)
Expand All @@ -159,7 +159,7 @@ func (s *recordingTests) TestStopSavesVariablesIfExistAndReadsPreviousVariables(
assert.Nil(err)

// add a new variable to the existing batch
target2.GetOptionalRecordedVariable(addedVariableName, addedVariableValue, Default)
target2.GetOptionalEnvVar(addedVariableName, addedVariableValue, NoSanitization)

err = target2.Stop()
assert.Nil(err)
Expand Down
38 changes: 22 additions & 16 deletions sdk/internal/testframework/request_matcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@ import (
)

type RequestMatcher struct {
context TestContext
IgnoredHeaders map[string]*string
context TestContext
// IgnoredHeaders is a map acting as a hash set of the header names that will be ignored for matching.
// Modifying the keys in the map will affect how headers are matched for recordings.
IgnoredHeaders map[string]struct{}
bodyMatcher StringMatcher
urlMatcher StringMatcher
methodMatcher StringMatcher
Expand All @@ -26,15 +28,15 @@ type RequestMatcher struct {
type StringMatcher func(reqVal string, recVal string) bool
type matcherWrapper func(matcher StringMatcher, testContext TestContext) bool

var ignoredHeaders = map[string]*string{
"Date": nil,
"X-Ms-Date": nil,
"x-ms-date": nil,
"x-ms-client-request-id": nil,
"User-Agent": nil,
"Request-Id": nil,
"traceparent": nil,
"Authorization": nil,
var ignoredHeaders = map[string]struct{}{
"Date": {},
"X-Ms-Date": {},
"x-ms-date": {},
"x-ms-client-request-id": {},
"User-Agent": {},
"Request-Id": {},
"traceparent": {},
"Authorization": {},
}

const (
Expand All @@ -46,25 +48,27 @@ const (
bodiesMismatch = "Test recording bodies do not match.\nrequest: %s\nrecording: %s"
)

func DefaultMatcher(testContext TestContext) *RequestMatcher {
// defaultMatcher returns a new RequestMatcher configured with the default matching behavior.
func defaultMatcher(testContext TestContext) *RequestMatcher {
// The default sanitizer sanitizes the Authorization header
matcher := &RequestMatcher{
context: testContext,
IgnoredHeaders: ignoredHeaders,
}
matcher.SetBodyMatcher(func(req string, rec string) bool {
return DefaultStringMatcher(req, rec)
return defaultStringMatcher(req, rec)
})
matcher.SetURLMatcher(func(req string, rec string) bool {
return DefaultStringMatcher(req, rec)
return defaultStringMatcher(req, rec)
})
matcher.SetMethodMatcher(func(req string, rec string) bool {
return DefaultStringMatcher(req, rec)
return defaultStringMatcher(req, rec)
})

return matcher
}

// SetBodyMatcher replaces the default matching behavior with a custom StringMatcher that compares the string value of the request body payload with the string value of the recorded body payload.
func (m *RequestMatcher) SetBodyMatcher(matcher StringMatcher) {
m.bodyMatcher = func(reqVal string, recVal string) bool {
isMatch := matcher(reqVal, recVal)
Expand All @@ -75,6 +79,7 @@ func (m *RequestMatcher) SetBodyMatcher(matcher StringMatcher) {
}
}

// SetURLMatcher replaces the default matching behavior with a custom StringMatcher that compares the string value of the request URL with the string value of the recorded URL
func (m *RequestMatcher) SetURLMatcher(matcher StringMatcher) {
m.urlMatcher = func(reqVal string, recVal string) bool {
isMatch := matcher(reqVal, recVal)
Expand All @@ -85,6 +90,7 @@ func (m *RequestMatcher) SetURLMatcher(matcher StringMatcher) {
}
}

// SetMethodMatcher replaces the default matching behavior with a custom StringMatcher that compares the string value of the request method with the string value of the recorded method
func (m *RequestMatcher) SetMethodMatcher(matcher StringMatcher) {
m.methodMatcher = func(reqVal string, recVal string) bool {
isMatch := matcher(reqVal, recVal)
Expand All @@ -95,7 +101,7 @@ func (m *RequestMatcher) SetMethodMatcher(matcher StringMatcher) {
}
}

func DefaultStringMatcher(s1 string, s2 string) bool {
func defaultStringMatcher(s1 string, s2 string) bool {
return s1 == s2
}

Expand Down
16 changes: 8 additions & 8 deletions sdk/internal/testframework/request_matcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ const unMatchedBody string = "This body does not match."
func (s *requestMatcherTests) TestCompareBodies() {
assert := assert.New(s.T())
context := NewTestContext(func(msg string) { assert.FailNow(msg) }, func(msg string) { s.T().Log(msg) }, func() string { return s.T().Name() })
matcher := DefaultMatcher(context)
matcher := defaultMatcher(context)

req := http.Request{Body: closerFromString(matchedBody)}
recReq := cassette.Request{Body: matchedBody}
Expand All @@ -53,7 +53,7 @@ func (s *requestMatcherTests) TestCompareBodies() {
func (s *requestMatcherTests) TestCompareHeadersIgnoresIgnoredHeaders() {
assert := assert.New(s.T())
context := NewTestContext(func(msg string) { assert.FailNow(msg) }, func(msg string) { s.T().Log(msg) }, func() string { return s.T().Name() })
matcher := DefaultMatcher(context)
matcher := defaultMatcher(context)

// populate only ignored headers that do not match
reqHeaders := make(http.Header)
Expand All @@ -73,7 +73,7 @@ func (s *requestMatcherTests) TestCompareHeadersIgnoresIgnoredHeaders() {
func (s *requestMatcherTests) TestCompareHeadersMatchesHeaders() {
assert := assert.New(s.T())
context := NewTestContext(func(msg string) { assert.FailNow(msg) }, func(msg string) { s.T().Log(msg) }, func() string { return s.T().Name() })
matcher := DefaultMatcher(context)
matcher := defaultMatcher(context)

// populate only ignored headers that do not match
reqHeaders := make(http.Header)
Expand All @@ -93,7 +93,7 @@ func (s *requestMatcherTests) TestCompareHeadersMatchesHeaders() {
func (s *requestMatcherTests) TestCompareHeadersFailsMissingRecHeader() {
assert := assert.New(s.T())
context := NewTestContext(func(msg string) { assert.FailNow(msg) }, func(msg string) { s.T().Log(msg) }, func() string { return s.T().Name() })
matcher := DefaultMatcher(context)
matcher := defaultMatcher(context)

// populate only ignored headers that do not match
reqHeaders := make(http.Header)
Expand All @@ -117,7 +117,7 @@ func (s *requestMatcherTests) TestCompareHeadersFailsMissingRecHeader() {
func (s *requestMatcherTests) TestCompareHeadersFailsMissingReqHeader() {
assert := assert.New(s.T())
context := NewTestContext(func(msg string) { assert.FailNow(msg) }, func(msg string) { s.T().Log(msg) }, func() string { return s.T().Name() })
matcher := DefaultMatcher(context)
matcher := defaultMatcher(context)

// populate only ignored headers that do not match
reqHeaders := make(http.Header)
Expand All @@ -141,7 +141,7 @@ func (s *requestMatcherTests) TestCompareHeadersFailsMissingReqHeader() {
func (s *requestMatcherTests) TestCompareHeadersFailsMismatchedValues() {
assert := assert.New(s.T())
context := NewTestContext(func(msg string) { assert.FailNow(msg) }, func(msg string) { s.T().Log(msg) }, func() string { return s.T().Name() })
matcher := DefaultMatcher(context)
matcher := defaultMatcher(context)

// populate only ignored headers that do not match
reqHeaders := make(http.Header)
Expand Down Expand Up @@ -171,7 +171,7 @@ func (s *requestMatcherTests) TestCompareURLs() {
host := "foo.bar"
req := http.Request{URL: &url.URL{Scheme: scheme, Host: host}}
recReq := cassette.Request{URL: scheme + "://" + host}
matcher := DefaultMatcher(context)
matcher := defaultMatcher(context)

assert.True(matcher.compareURLs(&req, recReq.URL))

Expand All @@ -187,7 +187,7 @@ func (s *requestMatcherTests) TestCompareMethods() {
methodPatch := "PATCH"
req := http.Request{Method: methodPost}
recReq := cassette.Request{Method: methodPost}
matcher := DefaultMatcher(context)
matcher := defaultMatcher(context)

assert.True(matcher.compareMethods(&req, recReq.Method))

Expand Down

0 comments on commit 438d7e1

Please sign in to comment.