diff --git a/sdk/tables/aztable/shared_policy_shared_key_credential.go b/sdk/tables/aztable/shared_policy_shared_key_credential.go index c9c2d7beeb1a..ca011f0b5240 100644 --- a/sdk/tables/aztable/shared_policy_shared_key_credential.go +++ b/sdk/tables/aztable/shared_policy_shared_key_credential.go @@ -53,19 +53,15 @@ func (c *SharedKeyCredential) SetAccountKey(accountKey string) error { } // computeHMACSHA256 generates a hash signature for an HTTP request or for a SAS. -func (c *SharedKeyCredential) ComputeHMACSHA256(message string) (base64String string) { +func (c *SharedKeyCredential) ComputeHMACSHA256(message string) (string, error) { h := hmac.New(sha256.New, c.accountKey.Load().([]byte)) - h.Write([]byte(message)) - return base64.StdEncoding.EncodeToString(h.Sum(nil)) + _, err := h.Write([]byte(message)) + return base64.StdEncoding.EncodeToString(h.Sum(nil)), err } func (c *SharedKeyCredential) buildStringToSign(req *http.Request) (string, error) { // https://docs.microsoft.com/en-us/rest/api/storageservices/authentication-for-the-azure-storage-services headers := req.Header - contentLength := headers.Get(azcore.HeaderContentLength) - if contentLength == "0" { - contentLength = "" - } canonicalizedResource, err := c.buildCanonicalizedResource(req.URL) if err != nil { @@ -79,6 +75,7 @@ func (c *SharedKeyCredential) buildStringToSign(req *http.Request) (string, erro return stringToSign, nil } +//nolint func (c *SharedKeyCredential) buildCanonicalizedHeader(headers http.Header) string { cm := map[string][]string{} for k, v := range headers { @@ -105,7 +102,7 @@ func (c *SharedKeyCredential) buildCanonicalizedHeader(headers http.Header) stri ch.WriteRune(':') ch.WriteString(strings.Join(cm[key], ",")) } - return string(ch.Bytes()) + return ch.String() } func (c *SharedKeyCredential) buildCanonicalizedResource(u *url.URL) (string, error) { @@ -133,7 +130,7 @@ func (c *SharedKeyCredential) buildCanonicalizedResource(u *url.URL) (string, er //do something here cr.WriteString("?" + "comp=" + compVal[0]) } - return string(cr.Bytes()), nil + return cr.String(), nil } // AuthenticationPolicy implements the Credential interface on SharedKeyCredential. @@ -147,7 +144,10 @@ func (c *SharedKeyCredential) AuthenticationPolicy(azcore.AuthenticationPolicyOp if err != nil { return nil, err } - signature := c.ComputeHMACSHA256(stringToSign) + signature, err := c.ComputeHMACSHA256(stringToSign) + if err != nil { + return nil, err + } authHeader := strings.Join([]string{"SharedKeyLite ", c.AccountName(), ":", signature}, "") req.Request.Header.Set(azcore.HeaderAuthorization, authHeader) diff --git a/sdk/tables/aztable/table_client.go b/sdk/tables/aztable/table_client.go index 21f9959f7f79..88f199249a78 100644 --- a/sdk/tables/aztable/table_client.go +++ b/sdk/tables/aztable/table_client.go @@ -74,7 +74,7 @@ func (t *TableClient) GetEntity(ctx context.Context, partitionKey string, rowKey if err != nil { return resp, err } - castAndRemoveAnnotations(&resp.Value) + err = castAndRemoveAnnotations(&resp.Value) return resp, err } diff --git a/sdk/tables/aztable/table_client_test.go b/sdk/tables/aztable/table_client_test.go index d5e297971a67..98b0cd8fa1ca 100644 --- a/sdk/tables/aztable/table_client_test.go +++ b/sdk/tables/aztable/table_client_test.go @@ -6,6 +6,7 @@ package aztable import ( "bytes" "errors" + "fmt" "io/ioutil" "net/http" "testing" @@ -180,7 +181,7 @@ func (s *tableClientLiveTests) TestUpsertEntity() { assert.Equalf(postMerge[mergeProp], val, "%s property should equal %s", mergeProp, val) } -func (s *tableClientLiveTests) _TestGetEntity() { +func (s *tableClientLiveTests) TestGetEntity() { assert := assert.New(s.T()) require := require.New(s.T()) client, delete := s.init(true) @@ -233,7 +234,8 @@ func (s *tableClientLiveTests) TestQuerySimpleEntity() { for pager.NextPage(ctx) { resp = pager.PageResponse() models = make([]simpleEntity, len(resp.TableEntityQueryResponse.Value)) - resp.TableEntityQueryResponse.AsModels(&models) + err := resp.TableEntityQueryResponse.AsModels(&models) + assert.Nil(err) assert.Equal(len(resp.TableEntityQueryResponse.Value), expectedCount) } resp = pager.PageResponse() @@ -442,7 +444,8 @@ func (s *tableClientLiveTests) TestBatchError() { assert.Equal(error_empty_transaction, err.Error()) // Add the last entity to the table prior to adding it as part of the batch to cause a batch failure. - client.AddEntity(ctx, (*entitiesToCreate)[2]) + _, err = client.AddEntity(ctx, (*entitiesToCreate)[2]) + assert.Nil(err) // Add the entities to the batch for i := 0; i < cap(batch); i++ { @@ -498,7 +501,10 @@ func (s *tableClientLiveTests) init(doCreate bool) (*TableClient, func()) { } } return client, func() { - client.Delete(ctx) + _, err := client.Delete(ctx) + if err != nil { + fmt.Printf("Error deleting table. %v\n", err.Error()) + } } } @@ -515,7 +521,7 @@ func getStringFromBody(e *runtime.ResponseError) string { if err != nil { return "" } - b = ioutil.NopCloser(&body) + _ = ioutil.NopCloser(&body) } return body.String() } diff --git a/sdk/tables/aztable/table_pagers.go b/sdk/tables/aztable/table_pagers.go index c5c54d408f4d..1f275ca15f21 100644 --- a/sdk/tables/aztable/table_pagers.go +++ b/sdk/tables/aztable/table_pagers.go @@ -176,7 +176,7 @@ func (p *tableQueryResponsePager) Err() error { func castAndRemoveAnnotationsSlice(entities *[]map[string]interface{}) { for _, e := range *entities { - castAndRemoveAnnotations(&e) + castAndRemoveAnnotations(&e) //nolint:errcheck } } @@ -212,7 +212,7 @@ func castAndRemoveAnnotations(entity *map[string]interface{}) error { } value[valueKey] = i default: - return errors.New(fmt.Sprintf("unsupported annotation found: %s", k)) + return fmt.Errorf("unsupported annotation found: %s", k) } // remove the annotation key delete(value, k) @@ -249,7 +249,7 @@ func toOdataAnnotatedDictionary(entity *map[string]interface{}) error { entMap[k] = time.UTC().Format(ISO8601) continue default: - return errors.New(fmt.Sprintf("Invalid struct for entity field '%s' of type '%s'", k, tn)) + return fmt.Errorf("Invalid struct for entity field '%s' of type '%s'", k, tn) } case reflect.Float32, reflect.Float64: entMap[odataType(k)] = edmDouble @@ -321,7 +321,7 @@ func toMap(ent interface{}) (*map[string]interface{}, error) { entMap[name] = time.UTC().Format(ISO8601) continue default: - return nil, errors.New(fmt.Sprintf("Invalid struct for entity field '%s' of type '%s'", typeOfT.Field(i).Name, tn)) + return nil, fmt.Errorf("Invalid struct for entity field '%s' of type '%s'", typeOfT.Field(i).Name, tn) } case reflect.Float32, reflect.Float64: entMap[odataType(name)] = edmDouble diff --git a/sdk/tables/aztable/table_pagers_test.go b/sdk/tables/aztable/table_pagers_test.go index 23641cb50dec..7bb923567b83 100644 --- a/sdk/tables/aztable/table_pagers_test.go +++ b/sdk/tables/aztable/table_pagers_test.go @@ -21,8 +21,6 @@ import ( "github.com/stretchr/testify/assert" ) -type pagerTests struct{} - func TestCastAndRemoveAnnotations(t *testing.T) { assert := assert.New(t) @@ -75,8 +73,10 @@ func BenchmarkUnMarshal_AsJson_CastAndRemove_Map(b *testing.B) { bt := []byte(complexPayload) for i := 0; i < b.N; i++ { var val = make(map[string]interface{}) - json.Unmarshal(bt, &val) - castAndRemoveAnnotations(&val) + err := json.Unmarshal(bt, &val) + assert.Nil(err) + err = castAndRemoveAnnotations(&val) + assert.Nil(err) assert.Equal("somePartition", val["PartitionKey"]) } } @@ -87,28 +87,41 @@ func BenchmarkUnMarshal_FromMap_Entity(b *testing.B) { bt := []byte(complexPayload) for i := 0; i < b.N; i++ { var val = make(map[string]interface{}) - json.Unmarshal(bt, &val) + err := json.Unmarshal(bt, &val) + if err != nil { + panic(err) + } result := complexEntity{} - err := EntityMapAsModel(val, &result) + err = EntityMapAsModel(val, &result) assert.Nil(err) assert.Equal("somePartition", result.PartitionKey) } } +func check(e error) { + if e != nil { + panic(e) + } +} + func BenchmarkMarshal_Entity_ToMap_ToOdataDict_Map(b *testing.B) { ent := createComplexEntity() for i := 0; i < b.N; i++ { m, _ := toMap(ent) - toOdataAnnotatedDictionary(m) - json.Marshal(m) + err := toOdataAnnotatedDictionary(m) + check(err) + _, err = json.Marshal(m) + check(err) } } func BenchmarkMarshal_Map_ToOdataDict_Map(b *testing.B) { ent := createComplexEntityMap() for i := 0; i < b.N; i++ { - toOdataAnnotatedDictionary(&ent) - json.Marshal(ent) + err := toOdataAnnotatedDictionary(&ent) + check(err) + _, err = json.Marshal(ent) + check(err) } } @@ -180,11 +193,12 @@ func TestDeserializeFromMap(t *testing.T) { expected := createComplexEntity() bt := []byte(complexPayload) var val = make(map[string]interface{}) - json.Unmarshal(bt, &val) + err := json.Unmarshal(bt, &val) + assert.Nil(err) result := complexEntity{} // tt := reflect.TypeOf(complexEntity{}) // err := fromMap(tt, getTypeValueMap(tt), &val, reflect.ValueOf(&result).Elem()) - err := EntityMapAsModel(val, &result) + err = EntityMapAsModel(val, &result) assert.Nil(err) assert.EqualValues(expected, result) } diff --git a/sdk/tables/aztable/table_service_client.go b/sdk/tables/aztable/table_service_client.go index c74dbff76e66..7907eef6963b 100644 --- a/sdk/tables/aztable/table_service_client.go +++ b/sdk/tables/aztable/table_service_client.go @@ -75,8 +75,6 @@ func (t *TableServiceClient) Query(queryOptions *QueryOptions) TableQueryRespons } func isCosmosEndpoint(url string) bool { - isCosmosEmulator := strings.Index(url, "localhost") >= 0 && strings.Index(url, "8902") >= 0 - return isCosmosEmulator || - strings.Index(url, CosmosTableDomain) >= 0 || - strings.Index(url, LegacyCosmosTableDomain) >= 0 + isCosmosEmulator := strings.Contains(url, "localhost") && strings.Contains(url, "8902") + return isCosmosEmulator || strings.Contains(url, CosmosTableDomain) || strings.Contains(url, LegacyCosmosTableDomain) } diff --git a/sdk/tables/aztable/table_service_client_test.go b/sdk/tables/aztable/table_service_client_test.go index 62af3cfd42b0..bfba4026dd94 100644 --- a/sdk/tables/aztable/table_service_client_test.go +++ b/sdk/tables/aztable/table_service_client_test.go @@ -37,10 +37,17 @@ func (s *tableServiceClientLiveTests) TestServiceErrors() { assert := assert.New(s.T()) context := getTestContext(s.T().Name()) tableName, err := getTableName(context) + failIfNotNil(assert, err) _, err = context.client.Create(ctx, tableName) - defer context.client.Delete(ctx, tableName) - assert.Nil(err) + delete := func() { + _, err := context.client.Delete(ctx, tableName) + if err != nil { + fmt.Printf("Error cleaning up test. %v\n", err.Error()) + } + } + defer delete() + failIfNotNil(assert, err) // Create a duplicate table to produce an error _, err = context.client.Create(ctx, tableName) @@ -53,11 +60,18 @@ func (s *tableServiceClientLiveTests) TestCreateTable() { assert := assert.New(s.T()) context := getTestContext(s.T().Name()) tableName, err := getTableName(context) + failIfNotNil(assert, err) resp, err := context.client.Create(ctx, tableName) - defer context.client.Delete(ctx, tableName) + delete := func() { + _, err := context.client.Delete(ctx, tableName) + if err != nil { + fmt.Printf("Error cleaning up test. %v\n", err.Error()) + } + } + defer delete() - assert.Nil(err) + failIfNotNil(assert, err) assert.Equal(*resp.TableResponse.TableName, tableName) } diff --git a/sdk/tables/aztable/table_transactional_batch.go b/sdk/tables/aztable/table_transactional_batch.go index 82801a8ca12f..3336f3b84943 100644 --- a/sdk/tables/aztable/table_transactional_batch.go +++ b/sdk/tables/aztable/table_transactional_batch.go @@ -96,6 +96,7 @@ type TableSubmitTransactionOptions struct { RequestID *string } +//nolint var defaultChangesetHeaders = map[string]string{ "Accept": "application/json;odata=minimalmetadata", "Content-Type": "application/json", @@ -131,17 +132,26 @@ func (t *TableClient) submitTransactionInternal(ctx context.Context, transaction boundary := fmt.Sprintf("batch_%s", batchUuid.String()) body := new(bytes.Buffer) writer := multipart.NewWriter(body) - writer.SetBoundary(boundary) + err = writer.SetBoundary(boundary) + if err != nil { + return TableTransactionResponse{}, err + } h := make(textproto.MIMEHeader) h.Set(headerContentType, fmt.Sprintf("multipart/mixed; boundary=%s", changesetBoundary)) batchWriter, err := writer.CreatePart(h) if err != nil { return TableTransactionResponse{}, err } - batchWriter.Write(changeSetBody.Bytes()) + _, err = batchWriter.Write(changeSetBody.Bytes()) + if err != nil { + return TableTransactionResponse{}, err + } writer.Close() - req.SetBody(azcore.NopCloser(bytes.NewReader(body.Bytes())), fmt.Sprintf("multipart/mixed; boundary=%s", boundary)) + err = req.SetBody(azcore.NopCloser(bytes.NewReader(body.Bytes())), fmt.Sprintf("multipart/mixed; boundary=%s", boundary)) + if err != nil { + return TableTransactionResponse{}, err + } resp, err := t.client.con.Pipeline().Do(req) if err != nil { @@ -198,8 +208,8 @@ func buildTransactionResponse(req *azcore.Request, resp *azcore.Response, itemCo } outerBoundary := getBoundaryName(bytesBody) mpReader := multipart.NewReader(reader, outerBoundary) - outerPart, err := mpReader.NextPart() - innerBytes, err := ioutil.ReadAll(outerPart) + outerPart, err := mpReader.NextPart() //nolint There is an error here + innerBytes, err := ioutil.ReadAll(outerPart) //nolint There is an error here innerBoundary := getBoundaryName(innerBytes) reader = bytes.NewReader(innerBytes) mpReader = multipart.NewReader(reader, innerBoundary) @@ -258,7 +268,10 @@ func (t *TableClient) generateChangesetBody(changesetBoundary string, transactio body := new(bytes.Buffer) writer := multipart.NewWriter(body) - writer.SetBoundary(changesetBoundary) + err := writer.SetBoundary(changesetBoundary) + if err != nil { + return nil, err + } for _, be := range *transactionActions { err := t.generateEntitySubset(&be, writer) @@ -300,47 +313,81 @@ func (t *TableClient) generateEntitySubset(transactionAction *TableTransactionAc switch transactionAction.ActionType { case Delete: req, err = t.client.deleteEntityCreateRequest(ctx, t.Name, entity[partitionKey].(string), entity[rowKey].(string), transactionAction.ETag, &TableDeleteEntityOptions{}, qo) + if err != nil { + return err + } case Add: - toOdataAnnotatedDictionary(&entity) + err = toOdataAnnotatedDictionary(&entity) + if err != nil { + return err + } req, err = t.client.insertEntityCreateRequest(ctx, t.Name, &TableInsertEntityOptions{TableEntityProperties: entity, ResponsePreference: ResponseFormatReturnNoContent.ToPtr()}, qo) + if err != nil { + return err + } case UpdateMerge: fallthrough case UpsertMerge: - toOdataAnnotatedDictionary(&entity) + err = toOdataAnnotatedDictionary(&entity) + if err != nil { + return err + } opts := &TableMergeEntityOptions{TableEntityProperties: entity} if len(transactionAction.ETag) > 0 { opts.IfMatch = &transactionAction.ETag } req, err = t.client.mergeEntityCreateRequest(ctx, t.Name, entity[partitionKey].(string), entity[rowKey].(string), opts, qo) + if err != nil { + return err + } if isCosmosEndpoint(t.client.con.Endpoint()) { transformPatchToCosmosPost(req) } case UpdateReplace: fallthrough case UpsertReplace: - toOdataAnnotatedDictionary(&entity) + err = toOdataAnnotatedDictionary(&entity) + if err != nil { + return err + } req, err = t.client.updateEntityCreateRequest(ctx, t.Name, entity[partitionKey].(string), entity[rowKey].(string), &TableUpdateEntityOptions{TableEntityProperties: entity, IfMatch: &transactionAction.ETag}, qo) + if err != nil { + return err + } } urlAndVerb := fmt.Sprintf("%s %s HTTP/1.1\r\n", req.Method, req.URL) - operationWriter.Write([]byte(urlAndVerb)) - writeHeaders(req.Header, &operationWriter) - operationWriter.Write([]byte("\r\n")) // additional \r\n is needed per changeset separating the "headers" and the body. + _, err = operationWriter.Write([]byte(urlAndVerb)) + if err != nil { + return err + } + err = writeHeaders(req.Header, &operationWriter) + if err != nil { + return err + } + _, err = operationWriter.Write([]byte("\r\n")) // additional \r\n is needed per changeset separating the "headers" and the body. + if err != nil { + return err + } if req.Body != nil { - io.Copy(operationWriter, req.Body) + _, err = io.Copy(operationWriter, req.Body) + } - return nil + return err } -func writeHeaders(h http.Header, writer *io.Writer) { +func writeHeaders(h http.Header, writer *io.Writer) error { // This way it is guaranteed the headers will be written in a sorted order var keys []string for k := range h { keys = append(keys, k) } sort.Strings(keys) + var err error for _, k := range keys { - (*writer).Write([]byte(fmt.Sprintf("%s: %s\r\n", k, h.Get(k)))) + _, err = (*writer).Write([]byte(fmt.Sprintf("%s: %s\r\n", k, h.Get(k)))) + } + return err } diff --git a/sdk/tables/aztable/zc_table_constants.go b/sdk/tables/aztable/zc_table_constants.go index ee29206d8195..cd265edf81fb 100644 --- a/sdk/tables/aztable/zc_table_constants.go +++ b/sdk/tables/aztable/zc_table_constants.go @@ -5,6 +5,7 @@ package aztable import "errors" +//nolint const ( timestamp = "Timestamp" partitionKey = "PartitionKey" diff --git a/sdk/tables/aztable/zt_table_recorded_tests.go b/sdk/tables/aztable/zt_table_recorded_tests.go index 8af8521fec02..d39371813473 100644 --- a/sdk/tables/aztable/zt_table_recorded_tests.go +++ b/sdk/tables/aztable/zt_table_recorded_tests.go @@ -15,8 +15,6 @@ import ( "github.com/stretchr/testify/assert" ) -type tablesRecordedTests struct{} - type testContext struct { recording *recording.Recording client *TableServiceClient @@ -53,6 +51,12 @@ func cosmosURI(accountName string, endpointSuffix string) string { return "https://" + accountName + ".table." + endpointSuffix } +func failIfNotNil(a *assert.Assertions, e error) { + if e != nil { + a.FailNow(e.Error()) + } +} + // create the test specific TableClient and wire it up to recordings func recordedTestSetup(t *testing.T, testName string, endpointType EndpointType, mode recording.RecordMode) { var accountName string @@ -69,15 +73,21 @@ func recordedTestSetup(t *testing.T, testName string, endpointType EndpointType, if endpointType == StorageEndpoint { accountName, err = r.GetRecordedVariable(storageAccountNameEnvVar, recording.Default) + failIfNotNil(assert, err) suffix = r.GetOptionalRecordedVariable(storageEndpointSuffixEnvVar, DefaultStorageSuffix, recording.Default) secret, err = r.GetRecordedVariable(storageAccountKeyEnvVar, recording.Secret_Base64String) - cred, _ = NewSharedKeyCredential(accountName, secret) + failIfNotNil(assert, err) + cred, err = NewSharedKeyCredential(accountName, secret) + failIfNotNil(assert, err) uri = storageURI(accountName, suffix) } else { accountName, err = r.GetRecordedVariable(cosmosAccountNameEnnVar, recording.Default) + failIfNotNil(assert, err) suffix = r.GetOptionalRecordedVariable(cosmosEndpointSuffixEnvVar, DefaultCosmosSuffix, recording.Default) secret, err = r.GetRecordedVariable(cosmosAccountKeyEnvVar, recording.Secret_Base64String) - cred, _ = NewSharedKeyCredential(accountName, secret) + failIfNotNil(assert, err) + cred, err = NewSharedKeyCredential(accountName, secret) + failIfNotNil(assert, err) uri = cosmosURI(accountName, suffix) } @@ -89,7 +99,10 @@ func recordedTestSetup(t *testing.T, testName string, endpointType EndpointType, func recordedTestTeardown(key string) { context, ok := clientsMap[key] if ok && !(*context.context).IsFailed() { - context.recording.Stop() + err := context.recording.Stop() + if err != nil { + fmt.Printf("Error tearing down tests. %v\n", err.Error()) + } } } @@ -100,12 +113,18 @@ func cleanupTables(context *testContext, tables *[]string) { pager := c.Query(nil) for pager.NextPage(ctx) { for _, t := range pager.PageResponse().TableQueryResponse.Value { - c.Delete(ctx, *t.TableName) + _, err := c.Delete(ctx, *t.TableName) + if err != nil { + fmt.Printf("Error cleaning up tables. %v\n", err.Error()) + } } } } else { for _, t := range *tables { - c.Delete(ctx, t) + _, err := c.Delete(ctx, t) + if err != nil { + fmt.Printf("There was an error cleaning up tests. %v\n", err.Error()) + } } } }