Skip to content

Commit

Permalink
Add support for CWT Claims & Type in Protected Headers (#189)
Browse files Browse the repository at this point in the history
Signed-off-by: steve lasker <stevenlasker@hotmail.com>
Co-authored-by: Orie Steele <orie@transmute.industries>
  • Loading branch information
SteveLasker and OR13 authored Jul 15, 2024
1 parent 96ea810 commit 8c458e2
Show file tree
Hide file tree
Showing 3 changed files with 153 additions and 6 deletions.
20 changes: 20 additions & 0 deletions cwt.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package cose

// https://www.iana.org/assignments/cwt/cwt.xhtml#claims-registry
const (
CWTClaimIssuer int64 = 1
CWTClaimSubject int64 = 2
CWTClaimAudience int64 = 3
CWTClaimExpirationTime int64 = 4
CWTClaimNotBefore int64 = 5
CWTClaimIssuedAt int64 = 6
CWTClaimCWTID int64 = 7
CWTClaimConfirmation int64 = 8
CWTClaimScope int64 = 9

// TODO: the rest upon request
)

// CWTClaims contains parameters that are to be cryptographically
// protected.
type CWTClaims map[any]any
82 changes: 82 additions & 0 deletions cwt_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package cose_test

import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"fmt"

"github.com/veraison/go-cose"
)

// This example demonstrates signing and verifying COSE_Sign1 signatures.
func ExampleCWTMessage() {
// create message to be signed
msgToSign := cose.NewSign1Message()
msgToSign.Payload = []byte("hello world")
msgToSign.Headers.Protected.SetAlgorithm(cose.AlgorithmES512)

msgToSign.Headers.Protected.SetType("application/cwt")
claims := cose.CWTClaims{
cose.CWTClaimIssuer: "issuer.example",
cose.CWTClaimSubject: "subject.example",
}
msgToSign.Headers.Protected.SetCWTClaims(claims)

msgToSign.Headers.Unprotected[cose.HeaderLabelKeyID] = []byte("1")

// create a signer
privateKey, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
if err != nil {
panic(err)
}
signer, err := cose.NewSigner(cose.AlgorithmES512, privateKey)
if err != nil {
panic(err)
}

// sign message
err = msgToSign.Sign(rand.Reader, nil, signer)
if err != nil {
panic(err)
}
sig, err := msgToSign.MarshalCBOR()
// uncomment to review EDN
// coseSign1Diagnostic, err := cbor.Diagnose(sig)
// fmt.Println(coseSign1Diagnostic)
if err != nil {
panic(err)
}
fmt.Println("message signed")

// create a verifier from a trusted public key
publicKey := privateKey.Public()
verifier, err := cose.NewVerifier(cose.AlgorithmES512, publicKey)
if err != nil {
panic(err)
}

// verify message
var msgToVerify cose.Sign1Message
err = msgToVerify.UnmarshalCBOR(sig)
if err != nil {
panic(err)
}
err = msgToVerify.Verify(nil, verifier)
if err != nil {
panic(err)
}
fmt.Println("message verified")

// tamper the message and verification should fail
msgToVerify.Payload = []byte("foobar")
err = msgToVerify.Verify(nil, verifier)
if err != cose.ErrVerification {
panic(err)
}
fmt.Println("verification error as expected")
// Output:
// message signed
// message verified
// verification error as expected
}
57 changes: 51 additions & 6 deletions headers.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ const (
HeaderLabelCounterSignature0 int64 = 9
HeaderLabelCounterSignatureV2 int64 = 11
HeaderLabelCounterSignature0V2 int64 = 12
HeaderLabelCWTClaims int64 = 15
HeaderLabelType int64 = 16
HeaderLabelX5Bag int64 = 32
HeaderLabelX5Chain int64 = 33
HeaderLabelX5T int64 = 34
Expand Down Expand Up @@ -97,11 +99,35 @@ func (h *ProtectedHeader) UnmarshalCBOR(data []byte) error {
return nil
}

// SetAlgorithm sets the algorithm value to the algorithm header.
// SetAlgorithm sets the algorithm value of the protected header.
func (h ProtectedHeader) SetAlgorithm(alg Algorithm) {
h[HeaderLabelAlgorithm] = alg
}

// SetType sets the type of the cose object in the protected header.
func (h ProtectedHeader) SetType(typ any) (any, error) {
if !canTstr(typ) && !canUint(typ) {
return typ, errors.New("header parameter: type: require tstr / uint type")
}
h[HeaderLabelType] = typ
return typ, nil
}

// SetCWTClaims sets the CWT Claims value of the protected header.
func (h ProtectedHeader) SetCWTClaims(claims CWTClaims) (CWTClaims, error) {
iss, hasIss := claims[1]
if hasIss && !canTstr(iss) {
return claims, errors.New("cwt claim: iss: require tstr")
}
sub, hasSub := claims[2]
if hasSub && !canTstr(sub) {
return claims, errors.New("cwt claim: sub: require tstr")
}
// TODO: validate claims, other claims
h[HeaderLabelCWTClaims] = claims
return claims, nil
}

// Algorithm gets the algorithm value from the algorithm header.
func (h ProtectedHeader) Algorithm() (Algorithm, error) {
value, ok := h[HeaderLabelAlgorithm]
Expand Down Expand Up @@ -460,8 +486,8 @@ func validateHeaderParameters(h map[any]any, protected bool) error {
// Reference: https://datatracker.ietf.org/doc/html/rfc8152#section-3.1
switch label {
case HeaderLabelAlgorithm:
_, is_alg := value.(Algorithm)
if !is_alg && !canInt(value) && !canTstr(value) {
_, isAlg := value.(Algorithm)
if !isAlg && !canInt(value) && !canTstr(value) {
return errors.New("header parameter: alg: require int / tstr type")
}
case HeaderLabelCritical:
Expand All @@ -471,12 +497,31 @@ func validateHeaderParameters(h map[any]any, protected bool) error {
if err := ensureCritical(value, h); err != nil {
return fmt.Errorf("header parameter: crit: %w", err)
}
case HeaderLabelType:
isTstr := canTstr(value)
if !isTstr && !canUint(value) {
return errors.New("header parameter: type: require tstr / uint type")
}
if isTstr {
v := value.(string)
if len(v) == 0 {
return errors.New("header parameter: type: require non-empty string")
}
if v[0] == ' ' || v[len(v)-1] == ' ' {
return errors.New("header parameter: type: require no leading/trailing whitespace")
}
// Basic check that the content type is of form type/subtype.
// We don't check the precise definition though (RFC 6838 Section 4.2).
if strings.Count(v, "/") != 1 {
return errors.New("header parameter: type: require text of form type/subtype")
}
}
case HeaderLabelContentType:
is_tstr := canTstr(value)
if !is_tstr && !canUint(value) {
isTstr := canTstr(value)
if !isTstr && !canUint(value) {
return errors.New("header parameter: content type: require tstr / uint type")
}
if is_tstr {
if isTstr {
v := value.(string)
if len(v) == 0 {
return errors.New("header parameter: content type: require non-empty string")
Expand Down

0 comments on commit 8c458e2

Please sign in to comment.