Skip to content

Commit

Permalink
Add GetHexEncoded method to uri package
Browse files Browse the repository at this point in the history
  • Loading branch information
hslatman committed Apr 15, 2024
1 parent 75696c3 commit fac7036
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 19 deletions.
30 changes: 12 additions & 18 deletions kms/capi/capi.go
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,7 @@ func (k *CAPIKMS) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Cert
return nil, fmt.Errorf("failed to parse URI: %w", err)
}

sha1Hash := u.Get(HashArg)
sha1Hash := u.GetHexEncoded(HashArg)
keyID := u.Get(KeyIDArg)
issuerName := u.Get(IssuerNameArg)
serialNumber := u.Get(SerialNumberArg)
Expand Down Expand Up @@ -544,18 +544,15 @@ func (k *CAPIKMS) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Cert
var certHandle *windows.CertContext

switch {
case sha1Hash != "":
sha1Hash = strings.TrimPrefix(sha1Hash, "0x") // Support specifying the hash as 0x like with serial

sha1Bytes, err := hex.DecodeString(sha1Hash)
if err != nil {
return nil, fmt.Errorf("%s must be in hex format: %w", HashArg, err)
case len(sha1Hash) > 0:
if len(sha1Hash) != 20 {
return nil, fmt.Errorf("decoded %s has length %d; expected 20 bytes for SHA-1", HashArg, len(sha1Hash))
}
searchData := CERT_ID_KEYIDORHASH{
idChoice: CERT_ID_SHA1_HASH,
KeyIDOrHash: CRYPTOAPI_BLOB{
len: uint32(len(sha1Bytes)),
data: uintptr(unsafe.Pointer(&sha1Bytes[0])),
len: uint32(len(sha1Hash)),
data: uintptr(unsafe.Pointer(&sha1Hash[0])),
},
}
certHandle, err = findCertificateInStore(st,
Expand Down Expand Up @@ -728,7 +725,7 @@ func (k *CAPIKMS) DeleteCertificate(req *apiv1.DeleteCertificateRequest) error {
return fmt.Errorf("failed to parse URI: %w", err)
}

sha1Hash := u.Get(HashArg)
sha1Hash := u.GetHexEncoded(HashArg)
keyID := u.Get(KeyIDArg)
issuerName := u.Get(IssuerNameArg)
serialNumber := u.Get(SerialNumberArg)
Expand Down Expand Up @@ -766,18 +763,15 @@ func (k *CAPIKMS) DeleteCertificate(req *apiv1.DeleteCertificateRequest) error {
var certHandle *windows.CertContext

switch {
case sha1Hash != "":
sha1Hash = strings.TrimPrefix(sha1Hash, "0x") // Support specifying the hash as 0x like with serial

sha1Bytes, err := hex.DecodeString(sha1Hash)
if err != nil {
return fmt.Errorf("%s must be in hex format: %w", HashArg, err)
case len(sha1Hash) > 0:
if len(sha1Hash) != 20 {
return fmt.Errorf("decoded %s has length %d; expected 20 bytes for SHA-1", HashArg, len(sha1Hash))
}
searchData := CERT_ID_KEYIDORHASH{
idChoice: CERT_ID_SHA1_HASH,
KeyIDOrHash: CRYPTOAPI_BLOB{
len: uint32(len(sha1Bytes)),
data: uintptr(unsafe.Pointer(&sha1Bytes[0])),
len: uint32(len(sha1Hash)),
data: uintptr(unsafe.Pointer(&sha1Hash[0])),
},
}
certHandle, err = findCertificateInStore(st,
Expand Down
17 changes: 16 additions & 1 deletion kms/uri/uri.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,13 +150,28 @@ func (u *URI) GetEncoded(key string) []byte {
return nil
}
if len(v)%2 == 0 {
if b, err := hex.DecodeString(v); err == nil {
if b, err := hex.DecodeString(strings.TrimPrefix(v, "0x")); err == nil {
return b
}
}
return []byte(v)
}

// GetHexEncoded returns the first value in the uri with the given key, it will
// return nil if the field is not present, is empty, or is not hex encoded.
func (u *URI) GetHexEncoded(key string) []byte {
v := u.Get(key)
if v == "" {
return nil
}
if len(v)%2 == 0 {
if b, err := hex.DecodeString(strings.TrimPrefix(v, "0x")); err == nil {
return b
}
}
return nil
}

// Pin returns the pin encoded in the url. It will read the pin from the
// pin-value or the pin-source attributes.
func (u *URI) Pin() string {
Expand Down
32 changes: 32 additions & 0 deletions kms/uri/uri_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestNew(t *testing.T) {
Expand Down Expand Up @@ -237,6 +238,7 @@ func TestURI_GetEncoded(t *testing.T) {
want []byte
}{
{"ok", mustParse("yubikey:slot-id=9a"), args{"slot-id"}, []byte{0x9a}},
{"ok prefix", mustParse("yubikey:slot-id=0x9a"), args{"slot-id"}, []byte{0x9a}},
{"ok first", mustParse("yubikey:slot-id=9a9b;slot-id=9b"), args{"slot-id"}, []byte{0x9a, 0x9b}},
{"ok percent", mustParse("yubikey:slot-id=9a;foo=%9a%9b%9c"), args{"foo"}, []byte{0x9a, 0x9b, 0x9c}},
{"ok in query", mustParse("yubikey:slot-id=9a?foo=9a"), args{"foo"}, []byte{0x9a}},
Expand Down Expand Up @@ -342,3 +344,33 @@ func TestURI_GetInt(t *testing.T) {
})
}
}

func TestURI_GetHexEncoded(t *testing.T) {
mustParse := func(t *testing.T, s string) *URI {
t.Helper()
u, err := Parse(s)
require.NoError(t, err)
return u
}
type args struct {
key string
}
tests := []struct {
name string
uri *URI
args args
want []byte
}{
{"ok", mustParse(t, "capi:sha1=9a"), args{"sha1"}, []byte{0x9a}},
{"ok first", mustParse(t, "capi:sha1=9a9b;sha1=9b"), args{"sha1"}, []byte{0x9a, 0x9b}},
{"ok prefix", mustParse(t, "capi:sha1=0x9a9b;sha1=9b"), args{"sha1"}, []byte{0x9a, 0x9b}},
{"ok missing", mustParse(t, "capi:foo=9a"), args{"sha1"}, nil},
{"ok odd hex", mustParse(t, "capi:sha1=09a?bar=zar"), args{"sha1"}, nil},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := tt.uri.GetHexEncoded(tt.args.key)
assert.Equal(t, tt.want, got)
})
}
}

0 comments on commit fac7036

Please sign in to comment.