From 2fd7f84ea0917b1492ebb5f2f8a7387b3b1e8856 Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Tue, 16 Apr 2024 17:11:13 +0200 Subject: [PATCH] Make `GetHexEncoded` return an error and improve error message formatting --- kms/capi/capi.go | 40 +++++++++++++++++++++++----------------- kms/uri/uri.go | 21 ++++++++++++--------- kms/uri/uri_test.go | 28 ++++++++++++++++++---------- 3 files changed, 53 insertions(+), 36 deletions(-) diff --git a/kms/capi/capi.go b/kms/capi/capi.go index 8d5705c5..f266d8ef 100644 --- a/kms/capi/capi.go +++ b/kms/capi/capi.go @@ -503,7 +503,10 @@ func (k *CAPIKMS) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Cert return nil, fmt.Errorf("failed to parse URI: %w", err) } - sha1Hash := u.GetHexEncoded(HashArg) + sha1Hash, err := u.GetHexEncoded(HashArg) + if err != nil { + return nil, fmt.Errorf("failed getting %s from URI %q: %w", HashArg, req.Name, err) + } keyID := u.Get(KeyIDArg) issuerName := u.Get(IssuerNameArg) serialNumber := u.Get(SerialNumberArg) @@ -521,7 +524,7 @@ func (k *CAPIKMS) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Cert case "machine": certStoreLocation = certStoreLocalMachine default: - return nil, fmt.Errorf("invalid cert store location %v", storeLocation) + return nil, fmt.Errorf("invalid cert store location %q", storeLocation) } var storeName string @@ -538,7 +541,7 @@ func (k *CAPIKMS) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Cert certStoreLocation, uintptr(unsafe.Pointer(wide(storeName)))) if err != nil { - return nil, fmt.Errorf("CertOpenStore for the %v store %v returned: %w", storeLocation, storeName, err) + return nil, fmt.Errorf("CertOpenStore for the %q store %q returned: %w", storeLocation, storeName, err) } var certHandle *windows.CertContext @@ -564,7 +567,7 @@ func (k *CAPIKMS) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Cert return nil, fmt.Errorf("findCertificateInStore failed: %w", err) } if certHandle == nil { - return nil, apiv1.NotFoundError{Message: fmt.Sprintf("certificate with %v=%s not found", HashArg, keyID)} + return nil, apiv1.NotFoundError{Message: fmt.Sprintf("certificate with %s=%s not found", HashArg, keyID)} } defer windows.CertFreeCertificateContext(certHandle) return certContextToX509(certHandle) @@ -573,7 +576,7 @@ func (k *CAPIKMS) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Cert keyIDBytes, err := hex.DecodeString(keyID) if err != nil { - return nil, fmt.Errorf("%v must be in hex format: %w", KeyIDArg, err) + return nil, fmt.Errorf("%s must be in hex format: %w", KeyIDArg, err) } searchData := CERT_ID_KEYIDORHASH{ idChoice: CERT_ID_KEY_IDENTIFIER, @@ -591,7 +594,7 @@ func (k *CAPIKMS) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Cert return nil, fmt.Errorf("findCertificateInStore failed: %w", err) } if certHandle == nil { - return nil, apiv1.NotFoundError{Message: fmt.Sprintf("certificate with %v=%s not found", KeyIDArg, keyID)} + return nil, apiv1.NotFoundError{Message: fmt.Sprintf("certificate with %s=%s not found", KeyIDArg, keyID)} } defer windows.CertFreeCertificateContext(certHandle) return certContextToX509(certHandle) @@ -605,13 +608,13 @@ func (k *CAPIKMS) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Cert serialNumber = strings.TrimPrefix(serialNumber, "00") // Comparison fails if leading 00 is not removed serialBytes, err = hex.DecodeString(serialNumber) if err != nil { - return nil, fmt.Errorf("invalid hex format for %v: %w", SerialNumberArg, err) + return nil, fmt.Errorf("invalid hex format for %s: %w", SerialNumberArg, err) } } else { bi := new(big.Int) bi, ok := bi.SetString(serialNumber, 10) if !ok { - return nil, fmt.Errorf("invalid %v - must be in hex or integer format", SerialNumberArg) + return nil, fmt.Errorf("invalid %s - must be in hex or integer format", SerialNumberArg) } serialBytes = bi.Bytes() } @@ -628,7 +631,7 @@ func (k *CAPIKMS) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Cert } if certHandle == nil { - return nil, apiv1.NotFoundError{Message: fmt.Sprintf("certificate with %v=%v and %v=%v not found", IssuerNameArg, issuerName, SerialNumberArg, serialNumber)} + return nil, apiv1.NotFoundError{Message: fmt.Sprintf("certificate with %s=%q and %s=%q not found", IssuerNameArg, issuerName, SerialNumberArg, serialNumber)} } x509Cert, err := certContextToX509(certHandle) @@ -645,7 +648,7 @@ func (k *CAPIKMS) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Cert prevCert = certHandle } default: - return nil, fmt.Errorf("%s, %s, or %s and %s is required to find a certificate", HashArg, KeyIDArg, IssuerNameArg, SerialNumberArg) + return nil, fmt.Errorf("%q, %q, or %q and %q is required to find a certificate", HashArg, KeyIDArg, IssuerNameArg, SerialNumberArg) } } @@ -667,7 +670,7 @@ func (k *CAPIKMS) StoreCertificate(req *apiv1.StoreCertificateRequest) error { case "machine": certStoreLocation = certStoreLocalMachine default: - return fmt.Errorf("invalid cert store location %v", storeLocation) + return fmt.Errorf("invalid cert store location %q", storeLocation) } var storeName string @@ -700,7 +703,7 @@ func (k *CAPIKMS) StoreCertificate(req *apiv1.StoreCertificateRequest) error { certStoreLocation, uintptr(unsafe.Pointer(wide(storeName)))) if err != nil { - return fmt.Errorf("CertOpenStore for the %v store %v returned: %w", storeLocation, storeName, err) + return fmt.Errorf("CertOpenStore for the %q store %q returned: %w", storeLocation, storeName, err) } // Add the cert context to the system certificate store @@ -725,7 +728,10 @@ func (k *CAPIKMS) DeleteCertificate(req *apiv1.DeleteCertificateRequest) error { return fmt.Errorf("failed to parse URI: %w", err) } - sha1Hash := u.GetHexEncoded(HashArg) + sha1Hash, err := u.GetHexEncoded(HashArg) + if err != nil { + return fmt.Errorf("failed getting %s from URI %q: %w", HashArg, req.Name, err) + } keyID := u.Get(KeyIDArg) issuerName := u.Get(IssuerNameArg) serialNumber := u.Get(SerialNumberArg) @@ -742,7 +748,7 @@ func (k *CAPIKMS) DeleteCertificate(req *apiv1.DeleteCertificateRequest) error { case "machine": certStoreLocation = certStoreLocalMachine default: - return fmt.Errorf("invalid cert store location %v", storeLocation) + return fmt.Errorf("invalid cert store location %q", storeLocation) } var storeName string @@ -757,7 +763,7 @@ func (k *CAPIKMS) DeleteCertificate(req *apiv1.DeleteCertificateRequest) error { certStoreLocation, uintptr(unsafe.Pointer(wide(storeName)))) if err != nil { - return fmt.Errorf("CertOpenStore for the %v store %v returned: %w", storeLocation, storeName, err) + return fmt.Errorf("CertOpenStore for the %q store %q returned: %w", storeLocation, storeName, err) } var certHandle *windows.CertContext @@ -832,13 +838,13 @@ func (k *CAPIKMS) DeleteCertificate(req *apiv1.DeleteCertificateRequest) error { serialNumber = strings.TrimPrefix(serialNumber, "00") // Comparison fails if leading 00 is not removed serialBytes, err = hex.DecodeString(serialNumber) if err != nil { - return fmt.Errorf("invalid hex format for %v: %w", SerialNumberArg, err) + return fmt.Errorf("invalid hex format for %s: %w", SerialNumberArg, err) } } else { bi := new(big.Int) bi, ok := bi.SetString(serialNumber, 10) if !ok { - return fmt.Errorf("invalid %v - must be in hex or integer format", SerialNumberArg) + return fmt.Errorf("invalid %s - must be in hex or integer format", SerialNumberArg) } serialBytes = bi.Bytes() } diff --git a/kms/uri/uri.go b/kms/uri/uri.go index df4a1f94..a3e325b8 100644 --- a/kms/uri/uri.go +++ b/kms/uri/uri.go @@ -3,6 +3,7 @@ package uri import ( "bytes" "encoding/hex" + "fmt" "net/url" "os" "strconv" @@ -157,19 +158,21 @@ func (u *URI) GetEncoded(key string) []byte { return []byte(v) } -// GetHexEncoded returns the first value in the uri with the given key, it will -// return nil if the field is not present, is empty, or is not hex encoded. -func (u *URI) GetHexEncoded(key string) []byte { +// GetHexEncoded returns the first value in the uri with the given key. It +// returns nil if the field is not present or is empty. It will return an +// error if the the value is not properly hex encoded. +func (u *URI) GetHexEncoded(key string) ([]byte, error) { v := u.Get(key) if v == "" { - return nil + return nil, nil } - if len(v)%2 == 0 { - if b, err := hex.DecodeString(strings.TrimPrefix(v, "0x")); err == nil { - return b - } + + b, err := hex.DecodeString(strings.TrimPrefix(v, "0x")) + if err != nil { + return nil, fmt.Errorf("failed decoding %q: %w", v, err) } - return nil + + return b, nil } // Pin returns the pin encoded in the url. It will read the pin from the diff --git a/kms/uri/uri_test.go b/kms/uri/uri_test.go index f691516c..6688ee73 100644 --- a/kms/uri/uri_test.go +++ b/kms/uri/uri_test.go @@ -356,20 +356,28 @@ func TestURI_GetHexEncoded(t *testing.T) { key string } tests := []struct { - name string - uri *URI - args args - want []byte + name string + uri *URI + args args + want []byte + wantErr bool }{ - {"ok", mustParse(t, "capi:sha1=9a"), args{"sha1"}, []byte{0x9a}}, - {"ok first", mustParse(t, "capi:sha1=9a9b;sha1=9b"), args{"sha1"}, []byte{0x9a, 0x9b}}, - {"ok prefix", mustParse(t, "capi:sha1=0x9a9b;sha1=9b"), args{"sha1"}, []byte{0x9a, 0x9b}}, - {"ok missing", mustParse(t, "capi:foo=9a"), args{"sha1"}, nil}, - {"ok odd hex", mustParse(t, "capi:sha1=09a?bar=zar"), args{"sha1"}, nil}, + {"ok", mustParse(t, "capi:sha1=9a"), args{"sha1"}, []byte{0x9a}, false}, + {"ok first", mustParse(t, "capi:sha1=9a9b;sha1=9b"), args{"sha1"}, []byte{0x9a, 0x9b}, false}, + {"ok prefix", mustParse(t, "capi:sha1=0x9a9b;sha1=9b"), args{"sha1"}, []byte{0x9a, 0x9b}, false}, + {"ok missing", mustParse(t, "capi:foo=9a"), args{"sha1"}, nil, false}, + {"fail odd hex", mustParse(t, "capi:sha1=09a?bar=zar"), args{"sha1"}, nil, true}, + {"fail invalid hex", mustParse(t, "capi:sha1=9z?bar=zar"), args{"sha1"}, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := tt.uri.GetHexEncoded(tt.args.key) + got, err := tt.uri.GetHexEncoded(tt.args.key) + if tt.wantErr { + assert.Error(t, err) + assert.Nil(t, got) + return + } + assert.Equal(t, tt.want, got) }) }