Skip to content
This repository has been archived by the owner on Aug 25, 2023. It is now read-only.

Commit

Permalink
fix: send digest msg instead of raw to aws kms
Browse files Browse the repository at this point in the history
Signed-off-by: Firas Qutishat <firas.qutishat@securekey.com>
  • Loading branch information
fqutishat committed Oct 22, 2022
1 parent 412f152 commit bdb2cfd
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 110 deletions.
116 changes: 77 additions & 39 deletions pkg/aws/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,19 @@ SPDX-License-Identifier: Apache-2.0
package aws

import (
"crypto/elliptic"
"crypto/sha512"
"encoding/asn1"
"fmt"
"hash"
"math/big"
"regexp"
"strings"
"time"

"github.com/btcsuite/btcd/btcec"
"github.com/minio/sha256-simd"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/kms"
Expand All @@ -35,18 +43,36 @@ type metricsProvider interface {
VerifyTime(value time.Duration)
}

type ecdsaSignature struct {
R, S *big.Int
}

// Service aws kms.
type Service struct {
client awsClient
metrics metricsProvider
healthCheckKeyID string
}

const (
signingAlgorithmEcdsaSha256 = "ECDSA_SHA_256"
signingAlgorithmEcdsaSha384 = "ECDSA_SHA_384"
signingAlgorithmEcdsaSha512 = "ECDSA_SHA_512"
)

// nolint: gochecknoglobals
var kmsKeyTypes = map[string]arieskms.KeyType{
"ECDSA_SHA_256": arieskms.ECDSAP256DER,
"ECDSA_SHA_384": arieskms.ECDSAP384DER,
"ECDSA_SHA_521": arieskms.ECDSAP521DER,
signingAlgorithmEcdsaSha256: arieskms.ECDSAP256DER,
signingAlgorithmEcdsaSha384: arieskms.ECDSAP384DER,
signingAlgorithmEcdsaSha512: arieskms.ECDSAP521DER,
}

// nolint: gochecknoglobals
var keySpecToCurve = map[string]elliptic.Curve{
kms.KeySpecEccNistP256: elliptic.P256(),
kms.KeySpecEccNistP384: elliptic.P384(),
kms.KeySpecEccNistP521: elliptic.P521(),
kms.KeySpecEccSecgP256k1: btcec.S256(),
}

// New return aws service.
Expand Down Expand Up @@ -78,10 +104,15 @@ func (s *Service) Sign(msg []byte, kh interface{}) ([]byte, error) {
return nil, err
}

digest, err := hashMessage(msg, *describeKey.KeyMetadata.SigningAlgorithms[0])
if err != nil {
return nil, err
}

input := &kms.SignInput{
KeyId: aws.String(keyID),
Message: msg,
MessageType: aws.String("RAW"),
Message: digest,
MessageType: aws.String("DIGEST"),
SigningAlgorithm: describeKey.KeyMetadata.SigningAlgorithms[0],
}

Expand All @@ -90,7 +121,28 @@ func (s *Service) Sign(msg []byte, kh interface{}) ([]byte, error) {
return nil, err
}

return result.Signature, nil
ecdsaSignature := ecdsaSignature{}

_, err = asn1.Unmarshal(result.Signature, &ecdsaSignature)
if err != nil {
return result.Signature, nil
}

curveBits := keySpecToCurve[*describeKey.KeyMetadata.KeySpec].Params().BitSize

keyBytes := curveBits / 8
if curveBits%8 > 0 {
keyBytes++
}

copyPadded := func(source []byte, size int) []byte {
dest := make([]byte, size)
copy(dest[size-len(source):], source)

return dest
}

return append(copyPadded(ecdsaSignature.R.Bytes(), keyBytes), copyPadded(ecdsaSignature.S.Bytes(), keyBytes)...), nil
}

// Get key handle.
Expand Down Expand Up @@ -146,39 +198,7 @@ func (s *Service) ExportPubKeyBytes(keyURI string) ([]byte, arieskms.KeyType, er

// Verify signature.
func (s *Service) Verify(signature, msg []byte, kh interface{}) error {
startTime := time.Now()

defer func() {
if s.metrics != nil {
s.metrics.VerifyTime(time.Since(startTime))
}
}()

if s.metrics != nil {
s.metrics.VerifyCount()
}

keyID, err := getKeyID(kh.(string))
if err != nil {
return err
}

describeKey, err := s.client.DescribeKey(&kms.DescribeKeyInput{KeyId: &keyID})
if err != nil {
return err
}

input := &kms.VerifyInput{
KeyId: aws.String(keyID),
Message: msg,
MessageType: aws.String("RAW"),
Signature: signature,
SigningAlgorithm: describeKey.KeyMetadata.SigningAlgorithms[0],
}

_, err = s.client.Verify(input)

return err
return fmt.Errorf("not implemented")
}

// Create key.
Expand Down Expand Up @@ -257,3 +277,21 @@ func getKeyID(keyURI string) (string, error) {

return r[4], nil
}

func hashMessage(message []byte, algorithm string) ([]byte, error) {
var digest hash.Hash

switch algorithm {
case signingAlgorithmEcdsaSha256:
digest = sha256.New()
case signingAlgorithmEcdsaSha384:
digest = sha512.New384()
case signingAlgorithmEcdsaSha512:
digest = sha512.New()
default:
return []byte{}, fmt.Errorf("unknown signing algorithm")
}

digest.Write(message)
return digest.Sum(nil), nil
}
71 changes: 0 additions & 71 deletions pkg/aws/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -314,77 +314,6 @@ func TestPubKeyBytes(t *testing.T) {
})
}

func TestVerify(t *testing.T) {
t.Run("success", func(t *testing.T) {
endpoint := localhost
awsSession, err := session.NewSession(&aws.Config{
Endpoint: &endpoint,
Region: aws.String("ca"),
CredentialsChainVerboseErrors: aws.Bool(true),
})
require.NoError(t, err)

svc := New(awsSession, &mockMetrics{}, "")

svc.client = &mockAWSClient{verifyFunc: func(input *kms.VerifyInput) (*kms.VerifyOutput, error) {
return &kms.VerifyOutput{}, nil
}, describeKeyFunc: func(input *kms.DescribeKeyInput) (*kms.DescribeKeyOutput, error) {
return &kms.DescribeKeyOutput{
KeyMetadata: &kms.KeyMetadata{
SigningAlgorithms: []*string{aws.String("ECDSA_SHA_256")},
},
}, nil
}}

err = svc.Verify([]byte("sign"), []byte("data"),
"aws-kms://arn:aws:kms:ca-central-1:111122223333:key/800d5768-3fd7-4edd-a4b8-4c81c3e4c147")
require.NoError(t, err)
})

t.Run("failed to verify", func(t *testing.T) {
endpoint := localhost
awsSession, err := session.NewSession(&aws.Config{
Endpoint: &endpoint,
Region: aws.String("ca"),
CredentialsChainVerboseErrors: aws.Bool(true),
})
require.NoError(t, err)

svc := New(awsSession, &mockMetrics{}, "")

svc.client = &mockAWSClient{verifyFunc: func(input *kms.VerifyInput) (*kms.VerifyOutput, error) {
return nil, fmt.Errorf("failed to verify")
}, describeKeyFunc: func(input *kms.DescribeKeyInput) (*kms.DescribeKeyOutput, error) {
return &kms.DescribeKeyOutput{
KeyMetadata: &kms.KeyMetadata{
SigningAlgorithms: []*string{aws.String("ECDSA_SHA_256")},
},
}, nil
}}

err = svc.Verify([]byte("data"), []byte("msg"),
"aws-kms://arn:aws:kms:ca-central-1:111122223333:key/800d5768-3fd7-4edd-a4b8-4c81c3e4c147")
require.Error(t, err)
require.Contains(t, err.Error(), "failed to verify")
})

t.Run("failed to parse key id", func(t *testing.T) {
endpoint := localhost
awsSession, err := session.NewSession(&aws.Config{
Endpoint: &endpoint,
Region: aws.String("ca"),
CredentialsChainVerboseErrors: aws.Bool(true),
})
require.NoError(t, err)

svc := New(awsSession, &mockMetrics{}, "")

err = svc.Verify([]byte("sign"), []byte("msg"), "aws-kms://arn:aws:kms:key1")
require.Error(t, err)
require.Contains(t, err.Error(), "extracting key id from URI failed")
})
}

type mockAWSClient struct {
signFunc func(input *kms.SignInput) (*kms.SignOutput, error)
getPublicKeyFunc func(input *kms.GetPublicKeyInput) (*kms.GetPublicKeyOutput, error)
Expand Down

0 comments on commit bdb2cfd

Please sign in to comment.