From a63c3621cbeb5d4a3f5c05f393b27285afb76a78 Mon Sep 17 00:00:00 2001 From: Firas Qutishat Date: Fri, 13 Jan 2023 10:27:58 -0500 Subject: [PATCH] chore: update aws client to v2 Signed-off-by: Firas Qutishat --- go.mod | 7 +- go.sum | 14 +- pkg/aws/service.go | 71 +++++----- pkg/aws/service_test.go | 307 ++++++++++++++++------------------------ 4 files changed, 178 insertions(+), 221 deletions(-) diff --git a/go.mod b/go.mod index 609d6bd..82b1b91 100644 --- a/go.mod +++ b/go.mod @@ -8,6 +8,7 @@ go 1.18 require ( github.com/aws/aws-sdk-go v1.43.9 + github.com/aws/aws-sdk-go-v2/service/kms v1.20.0 github.com/btcsuite/btcd v0.22.1 github.com/golang/mock v1.6.0 github.com/google/tink/go v1.7.0 @@ -16,7 +17,6 @@ require ( github.com/hyperledger/aries-framework-go/component/storageutil v0.0.0-20220610133818-119077b0ec85 github.com/hyperledger/aries-framework-go/spi v0.0.0-20221025204933-b807371b6f1e github.com/igor-pavlenko/httpsignatures-go v0.0.23 - github.com/minio/sha256-simd v0.1.1 github.com/piprate/json-gold v0.4.2 github.com/prometheus/client_golang v1.11.0 github.com/rs/xid v1.3.0 @@ -29,6 +29,10 @@ require ( require ( github.com/VictoriaMetrics/fastcache v1.5.7 // indirect + github.com/aws/aws-sdk-go-v2 v1.17.3 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.27 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.21 // indirect + github.com/aws/smithy-go v1.13.5 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/bluele/gcache v0.0.2 // indirect github.com/btcsuite/btcutil v1.0.3-0.20201208143702-a53e38424cce // indirect @@ -45,6 +49,7 @@ require ( github.com/kilic/bls12-381 v0.1.1-0.20210503002446-7b7597926c69 // indirect github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect github.com/minio/blake2b-simd v0.0.0-20160723061019-3f5f724cb5b1 // indirect + github.com/minio/sha256-simd v0.1.1 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/mr-tron/base58 v1.2.0 // indirect github.com/multiformats/go-base32 v0.1.0 // indirect diff --git a/go.sum b/go.sum index ddab9b3..b3e0042 100644 --- a/go.sum +++ b/go.sum @@ -14,6 +14,16 @@ github.com/allegro/bigcache v1.2.1-0.20190218064605-e24eb225f156/go.mod h1:Cb/ax github.com/aws/aws-sdk-go v1.35.1/go.mod h1:H7NKnBqNVzoTJpGfLrQkkD+ytBA93eiDYi/+8rV9s48= github.com/aws/aws-sdk-go v1.43.9 h1:k1S/29Bp2QD5ZopnGzIn0Sp63yyt3WH1JRE2OOU3Aig= github.com/aws/aws-sdk-go v1.43.9/go.mod h1:y4AeaBuwd2Lk+GepC1E9v0qOiTws0MIWAX4oIKwKHZo= +github.com/aws/aws-sdk-go-v2 v1.17.3 h1:shN7NlnVzvDUgPQ+1rLMSxY8OWRNDRYtiqe0p/PgrhY= +github.com/aws/aws-sdk-go-v2 v1.17.3/go.mod h1:uzbQtefpm44goOPmdKyAlXSNcwlRgF3ePWVW6EtJvvw= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.27 h1:I3cakv2Uy1vNmmhRQmFptYDxOvBnwCdNwyw63N0RaRU= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.27/go.mod h1:a1/UpzeyBBerajpnP5nGZa9mGzsBn5cOKxm6NWQsvoI= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.21 h1:5NbbMrIzmUn/TXFqAle6mgrH5m9cOvMLRGL7pnG8tRE= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.21/go.mod h1:+Gxn8jYn5k9ebfHEqlhrMirFjSW0v0C9fI+KN5vk2kE= +github.com/aws/aws-sdk-go-v2/service/kms v1.20.0 h1:1mEQ1BVRfxU2KzcUUIzqDQ8p6yPkhzHrHT++sjtLJts= +github.com/aws/aws-sdk-go-v2/service/kms v1.20.0/go.mod h1:13sjgMH7Xu4e46+0BEDhSnNh+cImHSYS5PpBjV3oXcU= +github.com/aws/smithy-go v1.13.5 h1:hgz0X/DX0dGqTYpGALqXJoRKRj5oQ7150i5FdTePzO8= +github.com/aws/smithy-go v1.13.5/go.mod h1:Tg+OJXh4MB2R/uN61Ko2f6hTZwB/ZYGOtib8J3gBHzA= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= @@ -77,7 +87,8 @@ github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.7 h1:81/ik6ipDQS2aGcBfIN5dHDB36BwrStyeAQquSYCV4o= +github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg= +github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/tink/go v1.7.0 h1:6Eox8zONGebBFcCBqkVmt60LaWZa6xg1cl/DwAh/J1w= github.com/google/tink/go v1.7.0/go.mod h1:GAUOd+QE3pgj9q8VKIGTCP33c/B7eb4NhxLcgTJZStM= @@ -291,7 +302,6 @@ golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= diff --git a/pkg/aws/service.go b/pkg/aws/service.go index 745d9ca..9bb53e1 100644 --- a/pkg/aws/service.go +++ b/pkg/aws/service.go @@ -7,6 +7,7 @@ SPDX-License-Identifier: Apache-2.0 package aws import ( + "context" "crypto/elliptic" "crypto/sha256" "crypto/sha512" @@ -18,20 +19,24 @@ import ( "strings" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/kms" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/kms" + "github.com/aws/aws-sdk-go-v2/service/kms/types" "github.com/btcsuite/btcd/btcec" arieskms "github.com/hyperledger/aries-framework-go/pkg/kms" ) type awsClient interface { //nolint:dupl - Sign(input *kms.SignInput) (*kms.SignOutput, error) - GetPublicKey(input *kms.GetPublicKeyInput) (*kms.GetPublicKeyOutput, error) - Verify(input *kms.VerifyInput) (*kms.VerifyOutput, error) - DescribeKey(input *kms.DescribeKeyInput) (*kms.DescribeKeyOutput, error) - CreateKey(input *kms.CreateKeyInput) (*kms.CreateKeyOutput, error) - CreateAlias(input *kms.CreateAliasInput) (*kms.CreateAliasOutput, error) + Sign(ctx context.Context, params *kms.SignInput, optFns ...func(*kms.Options)) (*kms.SignOutput, error) + GetPublicKey(ctx context.Context, params *kms.GetPublicKeyInput, + optFns ...func(*kms.Options)) (*kms.GetPublicKeyOutput, error) + Verify(ctx context.Context, params *kms.VerifyInput, optFns ...func(*kms.Options)) (*kms.VerifyOutput, error) + DescribeKey(ctx context.Context, params *kms.DescribeKeyInput, + optFns ...func(*kms.Options)) (*kms.DescribeKeyOutput, error) + CreateKey(ctx context.Context, params *kms.CreateKeyInput, + optFns ...func(*kms.Options)) (*kms.CreateKeyOutput, error) + CreateAlias(ctx context.Context, params *kms.CreateAliasInput, + optFns ...func(*kms.Options)) (*kms.CreateAliasOutput, error) } type metricsProvider interface { @@ -63,19 +68,19 @@ const ( ) // nolint: gochecknoglobals -var kmsKeyTypes = map[string]arieskms.KeyType{ +var kmsKeyTypes = map[types.SigningAlgorithmSpec]arieskms.KeyType{ signingAlgorithmEcdsaSha256: arieskms.ECDSAP256DER, signingAlgorithmEcdsaSha384: arieskms.ECDSAP384DER, signingAlgorithmEcdsaSha512: arieskms.ECDSAP521DER, } // nolint: gochecknoglobals -var keySpecToCurve = map[string]elliptic.Curve{ - kms.KeySpecEccSecgP256k1: btcec.S256(), +var keySpecToCurve = map[types.KeySpec]elliptic.Curve{ + types.KeySpecEccSecgP256k1: btcec.S256(), } // New return aws service. -func New(awsSession *session.Session, awsConfig *aws.Config, metrics metricsProvider, +func New(awsConfig aws.Config, metrics metricsProvider, healthCheckKeyID string, opts ...Opts) *Service { options := newOpts() @@ -85,7 +90,7 @@ func New(awsSession *session.Session, awsConfig *aws.Config, metrics metricsProv return &Service{ options: options, - client: kms.New(awsSession, awsConfig), + client: kms.NewFromConfig(awsConfig), metrics: metrics, healthCheckKeyID: healthCheckKeyID, } @@ -110,12 +115,12 @@ func (s *Service) Sign(msg []byte, kh interface{}) ([]byte, error) { //nolint: f return nil, err } - describeKey, err := s.client.DescribeKey(&kms.DescribeKeyInput{KeyId: &keyID}) + describeKey, err := s.client.DescribeKey(context.Background(), &kms.DescribeKeyInput{KeyId: &keyID}) if err != nil { return nil, err } - digest, err := hashMessage(msg, *describeKey.KeyMetadata.SigningAlgorithms[0]) + digest, err := hashMessage(msg, describeKey.KeyMetadata.SigningAlgorithms[0]) if err != nil { return nil, err } @@ -123,16 +128,16 @@ func (s *Service) Sign(msg []byte, kh interface{}) ([]byte, error) { //nolint: f input := &kms.SignInput{ KeyId: aws.String(keyID), Message: digest, - MessageType: aws.String("DIGEST"), + MessageType: types.MessageTypeDigest, SigningAlgorithm: describeKey.KeyMetadata.SigningAlgorithms[0], } - result, err := s.client.Sign(input) + result, err := s.client.Sign(context.Background(), input) if err != nil { return nil, err } - if *describeKey.KeyMetadata.KeySpec == kms.KeySpecEccSecgP256k1 { + if describeKey.KeyMetadata.KeySpec == types.KeySpecEccSecgP256k1 { signature := ecdsaSignature{} _, err = asn1.Unmarshal(result.Signature, &signature) @@ -140,7 +145,7 @@ func (s *Service) Sign(msg []byte, kh interface{}) ([]byte, error) { //nolint: f return nil, err } - curveBits := keySpecToCurve[*describeKey.KeyMetadata.KeySpec].Params().BitSize + curveBits := keySpecToCurve[describeKey.KeyMetadata.KeySpec].Params().BitSize keyBytes := curveBits / bitSize if curveBits%bitSize > 0 { @@ -172,7 +177,7 @@ func (s *Service) HealthCheck() error { return err } - _, err = s.client.DescribeKey(&kms.DescribeKeyInput{KeyId: &keyID}) + _, err = s.client.DescribeKey(context.Background(), &kms.DescribeKeyInput{KeyId: &keyID}) if err != nil { return err } @@ -203,12 +208,12 @@ func (s *Service) ExportPubKeyBytes(keyURI string) ([]byte, arieskms.KeyType, er KeyId: aws.String(keyID), } - result, err := s.client.GetPublicKey(input) + result, err := s.client.GetPublicKey(context.Background(), input) if err != nil { return nil, "", err } - return result.PublicKey, kmsKeyTypes[*result.SigningAlgorithms[0]], nil + return result.PublicKey, kmsKeyTypes[result.SigningAlgorithms[0]], nil } // Verify signature. @@ -218,24 +223,25 @@ func (s *Service) Verify(signature, msg []byte, kh interface{}) error { // Create key. func (s *Service) Create(kt arieskms.KeyType) (string, interface{}, error) { - keyUsage := kms.KeyUsageTypeSignVerify + keyUsage := types.KeyUsageTypeSignVerify - keySpec := "" + var keySpec types.KeySpec switch string(kt) { case arieskms.ECDSAP256DER: - keySpec = kms.KeySpecEccNistP256 + keySpec = types.KeySpecEccNistP256 case arieskms.ECDSAP384DER: - keySpec = kms.KeySpecEccNistP384 + keySpec = types.KeySpecEccNistP384 case arieskms.ECDSAP521DER: - keySpec = kms.KeySpecEccNistP521 + keySpec = types.KeySpecEccNistP521 case arieskms.ECDSASecp256k1DER: - keySpec = kms.KeySpecEccSecgP256k1 + keySpec = types.KeySpecEccSecgP256k1 default: return "", nil, fmt.Errorf("key not supported %s", kt) } - result, err := s.client.CreateKey(&kms.CreateKeyInput{KeySpec: &keySpec, KeyUsage: &keyUsage}) + result, err := s.client.CreateKey(context.Background(), + &kms.CreateKeyInput{KeySpec: keySpec, KeyUsage: keyUsage}) if err != nil { return "", nil, err } @@ -244,7 +250,8 @@ func (s *Service) Create(kt arieskms.KeyType) (string, interface{}, error) { if strings.TrimSpace(aliasPrefix) != "" { aliasName := fmt.Sprintf("alias/%s-%s", aliasPrefix, *result.KeyMetadata.KeyId) - _, err = s.client.CreateAlias(&kms.CreateAliasInput{AliasName: &aliasName, TargetKeyId: result.KeyMetadata.KeyId}) + _, err = s.client.CreateAlias(context.Background(), + &kms.CreateAliasInput{AliasName: &aliasName, TargetKeyId: result.KeyMetadata.KeyId}) if err != nil { return "", nil, err } @@ -308,7 +315,7 @@ func (s *Service) getKeyID(keyURI string) (string, error) { return r[4], nil } -func hashMessage(message []byte, algorithm string) ([]byte, error) { +func hashMessage(message []byte, algorithm types.SigningAlgorithmSpec) ([]byte, error) { var digest hash.Hash switch algorithm { diff --git a/pkg/aws/service_test.go b/pkg/aws/service_test.go index 52afa24..f287936 100644 --- a/pkg/aws/service_test.go +++ b/pkg/aws/service_test.go @@ -7,43 +7,39 @@ SPDX-License-Identifier: Apache-2.0 package aws //nolint:testpackage import ( + "context" "fmt" "testing" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/kms" arieskms "github.com/hyperledger/aries-framework-go/pkg/kms" - "github.com/stretchr/testify/require" -) -const ( - localhost = "http://localhost" + "github.com/aws/aws-sdk-go-v2/service/kms/types" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/kms" + "github.com/stretchr/testify/require" ) func TestSign(t *testing.T) { - t.Run("success", func(t *testing.T) { - endpoint := localhost - awsSession, err := session.NewSession(&aws.Config{ - Endpoint: &endpoint, - Region: aws.String("ca"), - CredentialsChainVerboseErrors: aws.Bool(true), - }) - - require.NoError(t, err) + awsConfig := aws.Config{ + Region: "ca", + } - svc := New(awsSession, nil, &mockMetrics{}, "", []Opts{}...) + t.Run("success", func(t *testing.T) { + svc := New(awsConfig, &mockMetrics{}, "", []Opts{}...) - svc.client = &mockAWSClient{signFunc: func(input *kms.SignInput) (*kms.SignOutput, error) { + svc.client = &mockAWSClient{signFunc: func(ctx context.Context, params *kms.SignInput, + optFns ...func(*kms.Options)) (*kms.SignOutput, error) { return &kms.SignOutput{ Signature: []byte("data"), }, nil - }, describeKeyFunc: func(input *kms.DescribeKeyInput) (*kms.DescribeKeyOutput, error) { + }, describeKeyFunc: func(ctx context.Context, params *kms.DescribeKeyInput, + optFns ...func(*kms.Options)) (*kms.DescribeKeyOutput, error) { return &kms.DescribeKeyOutput{ - KeyMetadata: &kms.KeyMetadata{ - SigningAlgorithms: []*string{aws.String("ECDSA_SHA_256")}, - KeySpec: aws.String(kms.KeySpecEccNistP256), + KeyMetadata: &types.KeyMetadata{ + SigningAlgorithms: []types.SigningAlgorithmSpec{types.SigningAlgorithmSpecEcdsaSha256}, + KeySpec: types.KeySpecEccNistP256, }, }, nil }} @@ -55,110 +51,83 @@ func TestSign(t *testing.T) { }) t.Run("failed to sign", func(t *testing.T) { - endpoint := localhost - awsSession, err := session.NewSession(&aws.Config{ - Endpoint: &endpoint, - Region: aws.String("ca"), - CredentialsChainVerboseErrors: aws.Bool(true), - }) - require.NoError(t, err) - - svc := New(awsSession, nil, &mockMetrics{}, "", []Opts{}...) + svc := New(awsConfig, &mockMetrics{}, "", []Opts{}...) - svc.client = &mockAWSClient{signFunc: func(input *kms.SignInput) (*kms.SignOutput, error) { + svc.client = &mockAWSClient{signFunc: func(ctx context.Context, params *kms.SignInput, + optFns ...func(*kms.Options)) (*kms.SignOutput, error) { return nil, fmt.Errorf("failed to sign") - }, describeKeyFunc: func(input *kms.DescribeKeyInput) (*kms.DescribeKeyOutput, error) { + }, describeKeyFunc: func(ctx context.Context, params *kms.DescribeKeyInput, + optFns ...func(*kms.Options)) (*kms.DescribeKeyOutput, error) { return &kms.DescribeKeyOutput{ - KeyMetadata: &kms.KeyMetadata{ - SigningAlgorithms: []*string{aws.String("ECDSA_SHA_256")}, + KeyMetadata: &types.KeyMetadata{ + SigningAlgorithms: []types.SigningAlgorithmSpec{types.SigningAlgorithmSpecEcdsaSha256}, }, }, nil }} - _, err = svc.Sign([]byte("msg"), + _, err := svc.Sign([]byte("msg"), "aws-kms://arn:aws:kms:ca-central-1:111122223333:key/800d5768-3fd7-4edd-a4b8-4c81c3e4c147") require.Error(t, err) require.Contains(t, err.Error(), "failed to sign") }) t.Run("failed to parse key id", func(t *testing.T) { - endpoint := localhost - awsSession, err := session.NewSession(&aws.Config{ - Endpoint: &endpoint, - Region: aws.String("ca"), - CredentialsChainVerboseErrors: aws.Bool(true), - }) - require.NoError(t, err) + svc := New(awsConfig, &mockMetrics{}, "", []Opts{}...) - svc := New(awsSession, nil, &mockMetrics{}, "", []Opts{}...) - - _, err = svc.Sign([]byte("msg"), "aws-kms://arn:aws:kms:key1") + _, err := svc.Sign([]byte("msg"), "aws-kms://arn:aws:kms:key1") require.Error(t, err) require.Contains(t, err.Error(), "extracting key id from URI failed") }) } func TestHealthCheck(t *testing.T) { - t.Run("success", func(t *testing.T) { - endpoint := localhost - awsSession, err := session.NewSession(&aws.Config{ - Endpoint: &endpoint, - Region: aws.String("ca"), - CredentialsChainVerboseErrors: aws.Bool(true), - }) - require.NoError(t, err) + awsConfig := aws.Config{ + Region: "ca", + } - svc := New(awsSession, nil, &mockMetrics{}, + t.Run("success", func(t *testing.T) { + svc := New(awsConfig, &mockMetrics{}, "aws-kms://arn:aws:kms:ca-central-1:111122223333:key/800d5768-3fd7-4edd-a4b8-4c81c3e4c147", []Opts{}...) - svc.client = &mockAWSClient{describeKeyFunc: func(input *kms.DescribeKeyInput) (*kms.DescribeKeyOutput, error) { + svc.client = &mockAWSClient{describeKeyFunc: func(ctx context.Context, params *kms.DescribeKeyInput, + optFns ...func(*kms.Options)) (*kms.DescribeKeyOutput, error) { return &kms.DescribeKeyOutput{}, nil }} - err = svc.HealthCheck() + err := svc.HealthCheck() require.NoError(t, err) }) t.Run("failed to list keys", func(t *testing.T) { - endpoint := localhost - awsSession, err := session.NewSession(&aws.Config{ - Endpoint: &endpoint, - Region: aws.String("ca"), - CredentialsChainVerboseErrors: aws.Bool(true), - }) - require.NoError(t, err) - - svc := New(awsSession, nil, &mockMetrics{}, + svc := New(awsConfig, &mockMetrics{}, "aws-kms://arn:aws:kms:ca-central-1:111122223333:key/800d5768-3fd7-4edd-a4b8-4c81c3e4c147", []Opts{}...) - svc.client = &mockAWSClient{describeKeyFunc: func(input *kms.DescribeKeyInput) (*kms.DescribeKeyOutput, error) { + svc.client = &mockAWSClient{describeKeyFunc: func(ctx context.Context, params *kms.DescribeKeyInput, + optFns ...func(*kms.Options)) (*kms.DescribeKeyOutput, error) { return nil, fmt.Errorf("failed to list keys") }} - err = svc.HealthCheck() + err := svc.HealthCheck() require.Error(t, err) require.Contains(t, err.Error(), "failed to list keys") }) } func TestCreate(t *testing.T) { - t.Run("success", func(t *testing.T) { - endpoint := localhost - awsSession, err := session.NewSession(&aws.Config{ - Endpoint: &endpoint, - Region: aws.String("ca"), - CredentialsChainVerboseErrors: aws.Bool(true), - }) - require.NoError(t, err) + awsConfig := aws.Config{ + Region: "ca", + } - svc := New(awsSession, nil, &mockMetrics{}, "", []Opts{}...) + t.Run("success", func(t *testing.T) { + svc := New(awsConfig, &mockMetrics{}, "", []Opts{}...) keyID := "key1" - svc.client = &mockAWSClient{createKeyFunc: func(input *kms.CreateKeyInput) (*kms.CreateKeyOutput, error) { - return &kms.CreateKeyOutput{KeyMetadata: &kms.KeyMetadata{KeyId: &keyID}}, nil + svc.client = &mockAWSClient{createKeyFunc: func(ctx context.Context, params *kms.CreateKeyInput, + optFns ...func(*kms.Options)) (*kms.CreateKeyOutput, error) { + return &kms.CreateKeyOutput{KeyMetadata: &types.KeyMetadata{KeyId: &keyID}}, nil }} result, _, err := svc.Create(arieskms.ECDSAP256DER) @@ -167,23 +136,17 @@ func TestCreate(t *testing.T) { }) t.Run("success: with key alias prefix", func(t *testing.T) { - endpoint := localhost - awsSession, err := session.NewSession(&aws.Config{ - Endpoint: &endpoint, - Region: aws.String("ca"), - CredentialsChainVerboseErrors: aws.Bool(true), - }) - require.NoError(t, err) - - svc := New(awsSession, nil, &mockMetrics{}, "", WithKeyAliasPrefix("dummyKeyAlias")) + svc := New(awsConfig, &mockMetrics{}, "", WithKeyAliasPrefix("dummyKeyAlias")) keyID := "key1" svc.client = &mockAWSClient{ - createKeyFunc: func(input *kms.CreateKeyInput) (*kms.CreateKeyOutput, error) { - return &kms.CreateKeyOutput{KeyMetadata: &kms.KeyMetadata{KeyId: &keyID}}, nil + createKeyFunc: func(ctx context.Context, params *kms.CreateKeyInput, + optFns ...func(*kms.Options)) (*kms.CreateKeyOutput, error) { + return &kms.CreateKeyOutput{KeyMetadata: &types.KeyMetadata{KeyId: &keyID}}, nil }, - createAliasFunc: func(input *kms.CreateAliasInput) (*kms.CreateAliasOutput, error) { + createAliasFunc: func(ctx context.Context, params *kms.CreateAliasInput, + optFns ...func(*kms.Options)) (*kms.CreateAliasOutput, error) { return &kms.CreateAliasOutput{}, nil }, } @@ -194,33 +157,21 @@ func TestCreate(t *testing.T) { }) t.Run("key not supported", func(t *testing.T) { - endpoint := localhost - awsSession, err := session.NewSession(&aws.Config{ - Endpoint: &endpoint, - Region: aws.String("ca"), - CredentialsChainVerboseErrors: aws.Bool(true), - }) - require.NoError(t, err) - - svc := New(awsSession, nil, &mockMetrics{}, "", []Opts{}...) + svc := New(awsConfig, &mockMetrics{}, "", []Opts{}...) - _, _, err = svc.Create(arieskms.ED25519) + _, _, err := svc.Create(arieskms.ED25519) require.Error(t, err) require.Contains(t, err.Error(), "key not supported ED25519") }) } func TestGet(t *testing.T) { - t.Run("success", func(t *testing.T) { - endpoint := localhost - awsSession, err := session.NewSession(&aws.Config{ - Endpoint: &endpoint, - Region: aws.String("ca"), - CredentialsChainVerboseErrors: aws.Bool(true), - }) - require.NoError(t, err) + awsConfig := aws.Config{ + Region: "ca", + } - svc := New(awsSession, nil, &mockMetrics{}, "", []Opts{}...) + t.Run("success", func(t *testing.T) { + svc := New(awsConfig, &mockMetrics{}, "", []Opts{}...) keyID, err := svc.Get("key1") require.NoError(t, err) @@ -229,30 +180,26 @@ func TestGet(t *testing.T) { } func TestCreateAndPubKeyBytes(t *testing.T) { - t.Run("success", func(t *testing.T) { - endpoint := localhost - awsSession, err := session.NewSession(&aws.Config{ - Endpoint: &endpoint, - Region: aws.String("ca"), - CredentialsChainVerboseErrors: aws.Bool(true), - }) - require.NoError(t, err) + awsConfig := aws.Config{ + Region: "ca", + } + t.Run("success", func(t *testing.T) { keyID := "aws-kms://arn:aws:kms:ca-central-1:111122223333:key/800d5768-3fd7-4edd-a4b8-4c81c3e4c147" - svc := New(awsSession, nil, &mockMetrics{}, "", []Opts{}...) + svc := New(awsConfig, &mockMetrics{}, "", []Opts{}...) svc.client = &mockAWSClient{ - getPublicKeyFunc: func(input *kms.GetPublicKeyInput) (*kms.GetPublicKeyOutput, error) { - signingAlgo := "ECDSA_SHA_256" - + getPublicKeyFunc: func(ctx context.Context, params *kms.GetPublicKeyInput, + optFns ...func(*kms.Options)) (*kms.GetPublicKeyOutput, error) { return &kms.GetPublicKeyOutput{ PublicKey: []byte("publickey"), - SigningAlgorithms: []*string{&signingAlgo}, + SigningAlgorithms: []types.SigningAlgorithmSpec{types.SigningAlgorithmSpecEcdsaSha256}, }, nil }, - createKeyFunc: func(input *kms.CreateKeyInput) (*kms.CreateKeyOutput, error) { - return &kms.CreateKeyOutput{KeyMetadata: &kms.KeyMetadata{KeyId: &keyID}}, nil + createKeyFunc: func(ctx context.Context, params *kms.CreateKeyInput, + optFns ...func(*kms.Options)) (*kms.CreateKeyOutput, error) { + return &kms.CreateKeyOutput{KeyMetadata: &types.KeyMetadata{KeyId: &keyID}}, nil }, } @@ -264,39 +211,30 @@ func TestCreateAndPubKeyBytes(t *testing.T) { } func TestSignMulti(t *testing.T) { - endpoint := localhost - awsSession, err := session.NewSession(&aws.Config{ - Endpoint: &endpoint, - Region: aws.String("ca"), - CredentialsChainVerboseErrors: aws.Bool(true), - }) - require.NoError(t, err) + awsConfig := aws.Config{ + Region: "ca", + } - svc := New(awsSession, nil, &mockMetrics{}, "", []Opts{}...) + svc := New(awsConfig, &mockMetrics{}, "", []Opts{}...) - _, err = svc.SignMulti(nil, nil) + _, err := svc.SignMulti(nil, nil) require.Error(t, err) require.Contains(t, err.Error(), "not implemented") } func TestPubKeyBytes(t *testing.T) { - t.Run("success", func(t *testing.T) { - endpoint := localhost - awsSession, err := session.NewSession(&aws.Config{ - Endpoint: &endpoint, - Region: aws.String("ca"), - CredentialsChainVerboseErrors: aws.Bool(true), - }) - require.NoError(t, err) - - svc := New(awsSession, nil, &mockMetrics{}, "", []Opts{}...) + awsConfig := aws.Config{ + Region: "ca", + } - svc.client = &mockAWSClient{getPublicKeyFunc: func(input *kms.GetPublicKeyInput) (*kms.GetPublicKeyOutput, error) { - signingAlgo := "ECDSA_SHA_256" + t.Run("success", func(t *testing.T) { + svc := New(awsConfig, &mockMetrics{}, "", []Opts{}...) + svc.client = &mockAWSClient{getPublicKeyFunc: func(ctx context.Context, params *kms.GetPublicKeyInput, + optFns ...func(*kms.Options)) (*kms.GetPublicKeyOutput, error) { return &kms.GetPublicKeyOutput{ PublicKey: []byte("publickey"), - SigningAlgorithms: []*string{&signingAlgo}, + SigningAlgorithms: []types.SigningAlgorithmSpec{types.SigningAlgorithmSpecEcdsaSha256}, }, nil }} @@ -308,95 +246,92 @@ func TestPubKeyBytes(t *testing.T) { }) t.Run("failed to export public key", func(t *testing.T) { - endpoint := localhost - awsSession, err := session.NewSession(&aws.Config{ - Endpoint: &endpoint, - Region: aws.String("ca"), - CredentialsChainVerboseErrors: aws.Bool(true), - }) - require.NoError(t, err) - - svc := New(awsSession, nil, &mockMetrics{}, "", []Opts{}...) + svc := New(awsConfig, &mockMetrics{}, "", []Opts{}...) - svc.client = &mockAWSClient{getPublicKeyFunc: func(input *kms.GetPublicKeyInput) (*kms.GetPublicKeyOutput, error) { + svc.client = &mockAWSClient{getPublicKeyFunc: func(ctx context.Context, params *kms.GetPublicKeyInput, + optFns ...func(*kms.Options)) (*kms.GetPublicKeyOutput, error) { return nil, fmt.Errorf("failed to export public key") }} - _, _, err = svc.ExportPubKeyBytes( + _, _, err := svc.ExportPubKeyBytes( "aws-kms://arn:aws:kms:ca-central-1:111122223333:key/800d5768-3fd7-4edd-a4b8-4c81c3e4c147") require.Error(t, err) require.Contains(t, err.Error(), "failed to export public key") }) t.Run("failed to parse key id", func(t *testing.T) { - endpoint := localhost - awsSession, err := session.NewSession(&aws.Config{ - Endpoint: &endpoint, - Region: aws.String("ca"), - CredentialsChainVerboseErrors: aws.Bool(true), - }) - require.NoError(t, err) - - svc := New(awsSession, nil, &mockMetrics{}, "", []Opts{}...) + svc := New(awsConfig, &mockMetrics{}, "", []Opts{}...) - _, _, err = svc.ExportPubKeyBytes("aws-kms://arn:aws:kms:key1") + _, _, err := svc.ExportPubKeyBytes("aws-kms://arn:aws:kms:key1") require.Error(t, err) require.Contains(t, err.Error(), "extracting key id from URI failed") }) } type mockAWSClient struct { - signFunc func(input *kms.SignInput) (*kms.SignOutput, error) - getPublicKeyFunc func(input *kms.GetPublicKeyInput) (*kms.GetPublicKeyOutput, error) - verifyFunc func(input *kms.VerifyInput) (*kms.VerifyOutput, error) - describeKeyFunc func(input *kms.DescribeKeyInput) (*kms.DescribeKeyOutput, error) - createKeyFunc func(input *kms.CreateKeyInput) (*kms.CreateKeyOutput, error) - createAliasFunc func(input *kms.CreateAliasInput) (*kms.CreateAliasOutput, error) + signFunc func(ctx context.Context, params *kms.SignInput, + optFns ...func(*kms.Options)) (*kms.SignOutput, error) + getPublicKeyFunc func(ctx context.Context, params *kms.GetPublicKeyInput, + optFns ...func(*kms.Options)) (*kms.GetPublicKeyOutput, error) + verifyFunc func(ctx context.Context, params *kms.VerifyInput, + optFns ...func(*kms.Options)) (*kms.VerifyOutput, error) + describeKeyFunc func(ctx context.Context, params *kms.DescribeKeyInput, + optFns ...func(*kms.Options)) (*kms.DescribeKeyOutput, error) + createKeyFunc func(ctx context.Context, params *kms.CreateKeyInput, + optFns ...func(*kms.Options)) (*kms.CreateKeyOutput, error) + createAliasFunc func(ctx context.Context, params *kms.CreateAliasInput, + optFns ...func(*kms.Options)) (*kms.CreateAliasOutput, error) } -func (m *mockAWSClient) Sign(input *kms.SignInput) (*kms.SignOutput, error) { +func (m *mockAWSClient) Sign(ctx context.Context, params *kms.SignInput, + optFns ...func(*kms.Options)) (*kms.SignOutput, error) { if m.signFunc != nil { - return m.signFunc(input) + return m.signFunc(ctx, params, optFns...) } return nil, nil //nolint:nilnil } -func (m *mockAWSClient) GetPublicKey(input *kms.GetPublicKeyInput) (*kms.GetPublicKeyOutput, error) { +func (m *mockAWSClient) GetPublicKey(ctx context.Context, params *kms.GetPublicKeyInput, + optFns ...func(*kms.Options)) (*kms.GetPublicKeyOutput, error) { if m.getPublicKeyFunc != nil { - return m.getPublicKeyFunc(input) + return m.getPublicKeyFunc(ctx, params, optFns...) } return nil, nil //nolint:nilnil } -func (m *mockAWSClient) Verify(input *kms.VerifyInput) (*kms.VerifyOutput, error) { +func (m *mockAWSClient) Verify(ctx context.Context, params *kms.VerifyInput, + optFns ...func(*kms.Options)) (*kms.VerifyOutput, error) { if m.verifyFunc != nil { - return m.verifyFunc(input) + return m.verifyFunc(ctx, params, optFns...) } return nil, nil //nolint:nilnil } -func (m *mockAWSClient) DescribeKey(input *kms.DescribeKeyInput) (*kms.DescribeKeyOutput, error) { +func (m *mockAWSClient) DescribeKey(ctx context.Context, params *kms.DescribeKeyInput, + optFns ...func(*kms.Options)) (*kms.DescribeKeyOutput, error) { if m.describeKeyFunc != nil { - return m.describeKeyFunc(input) + return m.describeKeyFunc(ctx, params, optFns...) } return nil, nil //nolint:nilnil } -func (m *mockAWSClient) CreateKey(input *kms.CreateKeyInput) (*kms.CreateKeyOutput, error) { +func (m *mockAWSClient) CreateKey(ctx context.Context, params *kms.CreateKeyInput, + optFns ...func(*kms.Options)) (*kms.CreateKeyOutput, error) { if m.createKeyFunc != nil { - return m.createKeyFunc(input) + return m.createKeyFunc(ctx, params, optFns...) } return nil, nil //nolint:nilnil } -func (m *mockAWSClient) CreateAlias(input *kms.CreateAliasInput) (*kms.CreateAliasOutput, error) { +func (m *mockAWSClient) CreateAlias(ctx context.Context, params *kms.CreateAliasInput, + optFns ...func(*kms.Options)) (*kms.CreateAliasOutput, error) { if m.createAliasFunc != nil { - return m.createAliasFunc(input) + return m.createAliasFunc(ctx, params, optFns...) } return nil, nil //nolint:nilnil