Skip to content

Commit

Permalink
Make GetHexEncoded return an error and improve error message format…
Browse files Browse the repository at this point in the history
…ting
  • Loading branch information
hslatman committed Apr 16, 2024
1 parent fac7036 commit 2fd7f84
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 36 deletions.
40 changes: 23 additions & 17 deletions kms/capi/capi.go
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,10 @@ func (k *CAPIKMS) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Cert
return nil, fmt.Errorf("failed to parse URI: %w", err)
}

sha1Hash := u.GetHexEncoded(HashArg)
sha1Hash, err := u.GetHexEncoded(HashArg)
if err != nil {
return nil, fmt.Errorf("failed getting %s from URI %q: %w", HashArg, req.Name, err)
}
keyID := u.Get(KeyIDArg)
issuerName := u.Get(IssuerNameArg)
serialNumber := u.Get(SerialNumberArg)
Expand All @@ -521,7 +524,7 @@ func (k *CAPIKMS) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Cert
case "machine":
certStoreLocation = certStoreLocalMachine
default:
return nil, fmt.Errorf("invalid cert store location %v", storeLocation)
return nil, fmt.Errorf("invalid cert store location %q", storeLocation)
}

var storeName string
Expand All @@ -538,7 +541,7 @@ func (k *CAPIKMS) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Cert
certStoreLocation,
uintptr(unsafe.Pointer(wide(storeName))))
if err != nil {
return nil, fmt.Errorf("CertOpenStore for the %v store %v returned: %w", storeLocation, storeName, err)
return nil, fmt.Errorf("CertOpenStore for the %q store %q returned: %w", storeLocation, storeName, err)
}

var certHandle *windows.CertContext
Expand All @@ -564,7 +567,7 @@ func (k *CAPIKMS) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Cert
return nil, fmt.Errorf("findCertificateInStore failed: %w", err)
}
if certHandle == nil {
return nil, apiv1.NotFoundError{Message: fmt.Sprintf("certificate with %v=%s not found", HashArg, keyID)}
return nil, apiv1.NotFoundError{Message: fmt.Sprintf("certificate with %s=%s not found", HashArg, keyID)}
}
defer windows.CertFreeCertificateContext(certHandle)
return certContextToX509(certHandle)
Expand All @@ -573,7 +576,7 @@ func (k *CAPIKMS) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Cert

keyIDBytes, err := hex.DecodeString(keyID)
if err != nil {
return nil, fmt.Errorf("%v must be in hex format: %w", KeyIDArg, err)
return nil, fmt.Errorf("%s must be in hex format: %w", KeyIDArg, err)
}
searchData := CERT_ID_KEYIDORHASH{
idChoice: CERT_ID_KEY_IDENTIFIER,
Expand All @@ -591,7 +594,7 @@ func (k *CAPIKMS) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Cert
return nil, fmt.Errorf("findCertificateInStore failed: %w", err)
}
if certHandle == nil {
return nil, apiv1.NotFoundError{Message: fmt.Sprintf("certificate with %v=%s not found", KeyIDArg, keyID)}
return nil, apiv1.NotFoundError{Message: fmt.Sprintf("certificate with %s=%s not found", KeyIDArg, keyID)}
}
defer windows.CertFreeCertificateContext(certHandle)
return certContextToX509(certHandle)
Expand All @@ -605,13 +608,13 @@ func (k *CAPIKMS) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Cert
serialNumber = strings.TrimPrefix(serialNumber, "00") // Comparison fails if leading 00 is not removed
serialBytes, err = hex.DecodeString(serialNumber)
if err != nil {
return nil, fmt.Errorf("invalid hex format for %v: %w", SerialNumberArg, err)
return nil, fmt.Errorf("invalid hex format for %s: %w", SerialNumberArg, err)
}
} else {
bi := new(big.Int)
bi, ok := bi.SetString(serialNumber, 10)
if !ok {
return nil, fmt.Errorf("invalid %v - must be in hex or integer format", SerialNumberArg)
return nil, fmt.Errorf("invalid %s - must be in hex or integer format", SerialNumberArg)
}
serialBytes = bi.Bytes()
}
Expand All @@ -628,7 +631,7 @@ func (k *CAPIKMS) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Cert
}

if certHandle == nil {
return nil, apiv1.NotFoundError{Message: fmt.Sprintf("certificate with %v=%v and %v=%v not found", IssuerNameArg, issuerName, SerialNumberArg, serialNumber)}
return nil, apiv1.NotFoundError{Message: fmt.Sprintf("certificate with %s=%q and %s=%q not found", IssuerNameArg, issuerName, SerialNumberArg, serialNumber)}
}

x509Cert, err := certContextToX509(certHandle)
Expand All @@ -645,7 +648,7 @@ func (k *CAPIKMS) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Cert
prevCert = certHandle
}
default:
return nil, fmt.Errorf("%s, %s, or %s and %s is required to find a certificate", HashArg, KeyIDArg, IssuerNameArg, SerialNumberArg)
return nil, fmt.Errorf("%q, %q, or %q and %q is required to find a certificate", HashArg, KeyIDArg, IssuerNameArg, SerialNumberArg)
}
}

Expand All @@ -667,7 +670,7 @@ func (k *CAPIKMS) StoreCertificate(req *apiv1.StoreCertificateRequest) error {
case "machine":
certStoreLocation = certStoreLocalMachine
default:
return fmt.Errorf("invalid cert store location %v", storeLocation)
return fmt.Errorf("invalid cert store location %q", storeLocation)
}

var storeName string
Expand Down Expand Up @@ -700,7 +703,7 @@ func (k *CAPIKMS) StoreCertificate(req *apiv1.StoreCertificateRequest) error {
certStoreLocation,
uintptr(unsafe.Pointer(wide(storeName))))
if err != nil {
return fmt.Errorf("CertOpenStore for the %v store %v returned: %w", storeLocation, storeName, err)
return fmt.Errorf("CertOpenStore for the %q store %q returned: %w", storeLocation, storeName, err)
}

// Add the cert context to the system certificate store
Expand All @@ -725,7 +728,10 @@ func (k *CAPIKMS) DeleteCertificate(req *apiv1.DeleteCertificateRequest) error {
return fmt.Errorf("failed to parse URI: %w", err)
}

sha1Hash := u.GetHexEncoded(HashArg)
sha1Hash, err := u.GetHexEncoded(HashArg)
if err != nil {
return fmt.Errorf("failed getting %s from URI %q: %w", HashArg, req.Name, err)
}
keyID := u.Get(KeyIDArg)
issuerName := u.Get(IssuerNameArg)
serialNumber := u.Get(SerialNumberArg)
Expand All @@ -742,7 +748,7 @@ func (k *CAPIKMS) DeleteCertificate(req *apiv1.DeleteCertificateRequest) error {
case "machine":
certStoreLocation = certStoreLocalMachine
default:
return fmt.Errorf("invalid cert store location %v", storeLocation)
return fmt.Errorf("invalid cert store location %q", storeLocation)
}

var storeName string
Expand All @@ -757,7 +763,7 @@ func (k *CAPIKMS) DeleteCertificate(req *apiv1.DeleteCertificateRequest) error {
certStoreLocation,
uintptr(unsafe.Pointer(wide(storeName))))
if err != nil {
return fmt.Errorf("CertOpenStore for the %v store %v returned: %w", storeLocation, storeName, err)
return fmt.Errorf("CertOpenStore for the %q store %q returned: %w", storeLocation, storeName, err)
}

var certHandle *windows.CertContext
Expand Down Expand Up @@ -832,13 +838,13 @@ func (k *CAPIKMS) DeleteCertificate(req *apiv1.DeleteCertificateRequest) error {
serialNumber = strings.TrimPrefix(serialNumber, "00") // Comparison fails if leading 00 is not removed
serialBytes, err = hex.DecodeString(serialNumber)
if err != nil {
return fmt.Errorf("invalid hex format for %v: %w", SerialNumberArg, err)
return fmt.Errorf("invalid hex format for %s: %w", SerialNumberArg, err)
}
} else {
bi := new(big.Int)
bi, ok := bi.SetString(serialNumber, 10)
if !ok {
return fmt.Errorf("invalid %v - must be in hex or integer format", SerialNumberArg)
return fmt.Errorf("invalid %s - must be in hex or integer format", SerialNumberArg)
}
serialBytes = bi.Bytes()
}
Expand Down
21 changes: 12 additions & 9 deletions kms/uri/uri.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package uri
import (
"bytes"
"encoding/hex"
"fmt"
"net/url"
"os"
"strconv"
Expand Down Expand Up @@ -157,19 +158,21 @@ func (u *URI) GetEncoded(key string) []byte {
return []byte(v)
}

// GetHexEncoded returns the first value in the uri with the given key, it will
// return nil if the field is not present, is empty, or is not hex encoded.
func (u *URI) GetHexEncoded(key string) []byte {
// GetHexEncoded returns the first value in the uri with the given key. It
// returns nil if the field is not present or is empty. It will return an
// error if the the value is not properly hex encoded.
func (u *URI) GetHexEncoded(key string) ([]byte, error) {
v := u.Get(key)
if v == "" {
return nil
return nil, nil
}
if len(v)%2 == 0 {
if b, err := hex.DecodeString(strings.TrimPrefix(v, "0x")); err == nil {
return b
}

b, err := hex.DecodeString(strings.TrimPrefix(v, "0x"))
if err != nil {
return nil, fmt.Errorf("failed decoding %q: %w", v, err)
}
return nil

return b, nil
}

// Pin returns the pin encoded in the url. It will read the pin from the
Expand Down
28 changes: 18 additions & 10 deletions kms/uri/uri_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -356,20 +356,28 @@ func TestURI_GetHexEncoded(t *testing.T) {
key string
}
tests := []struct {
name string
uri *URI
args args
want []byte
name string
uri *URI
args args
want []byte
wantErr bool
}{
{"ok", mustParse(t, "capi:sha1=9a"), args{"sha1"}, []byte{0x9a}},
{"ok first", mustParse(t, "capi:sha1=9a9b;sha1=9b"), args{"sha1"}, []byte{0x9a, 0x9b}},
{"ok prefix", mustParse(t, "capi:sha1=0x9a9b;sha1=9b"), args{"sha1"}, []byte{0x9a, 0x9b}},
{"ok missing", mustParse(t, "capi:foo=9a"), args{"sha1"}, nil},
{"ok odd hex", mustParse(t, "capi:sha1=09a?bar=zar"), args{"sha1"}, nil},
{"ok", mustParse(t, "capi:sha1=9a"), args{"sha1"}, []byte{0x9a}, false},
{"ok first", mustParse(t, "capi:sha1=9a9b;sha1=9b"), args{"sha1"}, []byte{0x9a, 0x9b}, false},
{"ok prefix", mustParse(t, "capi:sha1=0x9a9b;sha1=9b"), args{"sha1"}, []byte{0x9a, 0x9b}, false},
{"ok missing", mustParse(t, "capi:foo=9a"), args{"sha1"}, nil, false},
{"fail odd hex", mustParse(t, "capi:sha1=09a?bar=zar"), args{"sha1"}, nil, true},
{"fail invalid hex", mustParse(t, "capi:sha1=9z?bar=zar"), args{"sha1"}, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := tt.uri.GetHexEncoded(tt.args.key)
got, err := tt.uri.GetHexEncoded(tt.args.key)
if tt.wantErr {
assert.Error(t, err)
assert.Nil(t, got)
return
}

assert.Equal(t, tt.want, got)
})
}
Expand Down

0 comments on commit 2fd7f84

Please sign in to comment.