diff --git a/README.md b/README.md index feaf8c4..74a8f77 100644 --- a/README.md +++ b/README.md @@ -65,7 +65,7 @@ go get github.com/veraison/go-cose@main import "github.com/veraison/go-cose" ``` -Construct a new COSE_Sign1 message, then sign it using ECDSA w/ SHA-256 and finally marshal it. For example: +Construct a new COSE_Sign1_Tagged message, then sign it using ECDSA w/ SHA-256 and finally marshal it. For example: ```go package main @@ -102,7 +102,7 @@ func SignP256(data []byte) ([]byte, error) { } ``` -Verify a raw COSE_Sign1 message. For example: +Verify a raw COSE_Sign1_Tagged message. For example: ```go package main @@ -132,6 +132,11 @@ func VerifyP256(publicKey crypto.PublicKey, sig []byte) error { See [example_test.go](./example_test.go) for more examples. +#### Untagged Signing and Verification + +Untagged COSE_Sign1 messages can be signed and verified as above, using +`cose.UntaggedSign1Message` instead of `cose.Sign1Message`. + ### About hashing `go-cose` does not import any hash package by its own to avoid linking unnecessary algorithms to the final binary. diff --git a/example_test.go b/example_test.go index d00a383..07ea624 100644 --- a/example_test.go +++ b/example_test.go @@ -139,7 +139,7 @@ func ExampleSign1Message() { // verification error as expected } -// This example demonstrates signing COSE_Sign1 signatures using Sign1(). +// This example demonstrates signing COSE_Sign1_Tagged signatures using Sign1(). func ExampleSign1() { // create a signer privateKey, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) @@ -170,3 +170,35 @@ func ExampleSign1() { // Output: // message signed } + +// This example demonstrates signing COSE_Sign1 signatures using Sign1Untagged(). +func ExampleSign1Untagged() { + // create a signer + privateKey, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) + if err != nil { + panic(err) + } + signer, err := cose.NewSigner(cose.AlgorithmES512, privateKey) + if err != nil { + panic(err) + } + + // sign message + headers := cose.Headers{ + Protected: cose.ProtectedHeader{ + cose.HeaderLabelAlgorithm: cose.AlgorithmES512, + }, + Unprotected: cose.UnprotectedHeader{ + cose.HeaderLabelKeyID: []byte("1"), + }, + } + sig, err := cose.Sign1Untagged(rand.Reader, signer, headers, []byte("hello world"), nil) + if err != nil { + panic(err) + } + + fmt.Println("message signed") + _ = sig // further process on sig + // Output: + // message signed +} diff --git a/sign1.go b/sign1.go index 56e884c..d85420c 100644 --- a/sign1.go +++ b/sign1.go @@ -52,22 +52,11 @@ func NewSign1Message() *Sign1Message { // MarshalCBOR encodes Sign1Message into a COSE_Sign1_Tagged object. func (m *Sign1Message) MarshalCBOR() ([]byte, error) { - if m == nil { - return nil, errors.New("cbor: MarshalCBOR on nil Sign1Message pointer") - } - if len(m.Signature) == 0 { - return nil, ErrEmptySignature - } - protected, unprotected, err := m.Headers.marshal() + content, err := m.getContent() if err != nil { return nil, err } - content := sign1Message{ - Protected: protected, - Unprotected: unprotected, - Payload: m.Payload, - Signature: m.Signature, - } + return encMode.Marshal(cbor.Tag{ Number: CBORTagSign1Message, Content: content, @@ -85,28 +74,7 @@ func (m *Sign1Message) UnmarshalCBOR(data []byte) error { return errors.New("cbor: invalid COSE_Sign1_Tagged object") } - // decode to sign1Message and parse - var raw sign1Message - if err := decModeWithTagsForbidden.Unmarshal(data[1:], &raw); err != nil { - return err - } - if len(raw.Signature) == 0 { - return ErrEmptySignature - } - msg := Sign1Message{ - Headers: Headers{ - RawProtected: raw.Protected, - RawUnprotected: raw.Unprotected, - }, - Payload: raw.Payload, - Signature: raw.Signature, - } - if err := msg.Headers.UnmarshalFromRaw(); err != nil { - return err - } - - *m = msg - return nil + return m.doUnmarshal(data[1:]) } // Sign signs a Sign1Message using the provided Signer. @@ -218,6 +186,53 @@ func (m *Sign1Message) toBeSigned(external []byte) ([]byte, error) { return encMode.Marshal(sigStructure) } +func (m *Sign1Message) getContent() (sign1Message, error) { + if m == nil { + return sign1Message{}, errors.New("cbor: MarshalCBOR on nil Sign1Message pointer") + } + if len(m.Signature) == 0 { + return sign1Message{}, ErrEmptySignature + } + protected, unprotected, err := m.Headers.marshal() + if err != nil { + return sign1Message{}, err + } + + content := sign1Message{ + Protected: protected, + Unprotected: unprotected, + Payload: m.Payload, + Signature: m.Signature, + } + + return content, nil +} + +func (m *Sign1Message) doUnmarshal(data []byte) error { + // decode to sign1Message and parse + var raw sign1Message + if err := decModeWithTagsForbidden.Unmarshal(data, &raw); err != nil { + return err + } + if len(raw.Signature) == 0 { + return ErrEmptySignature + } + msg := Sign1Message{ + Headers: Headers{ + RawProtected: raw.Protected, + RawUnprotected: raw.Unprotected, + }, + Payload: raw.Payload, + Signature: raw.Signature, + } + if err := msg.Headers.UnmarshalFromRaw(); err != nil { + return err + } + + *m = msg + return nil +} + // Sign1 signs a Sign1Message using the provided Signer. // // This method is a wrapper of `Sign1Message.Sign()`. @@ -234,3 +249,67 @@ func Sign1(rand io.Reader, signer Signer, headers Headers, payload []byte, exter } return msg.MarshalCBOR() } + +type UntaggedSign1Message Sign1Message + +// MarshalCBOR encodes UntaggedSign1Message into a COSE_Sign1 object. +func (m *UntaggedSign1Message) MarshalCBOR() ([]byte, error) { + content, err := (*Sign1Message)(m).getContent() + if err != nil { + return nil, err + } + + return encMode.Marshal(content) +} + +// UnmarshalCBOR decodes a COSE_Sign1 object into an UnataggedSign1Message. +func (m *UntaggedSign1Message) UnmarshalCBOR(data []byte) error { + if m == nil { + return errors.New("cbor: UnmarshalCBOR on nil UntaggedSign1Message pointer") + } + + // fast message check - ensure the frist byte indicates a four-element array + if data[0] != sign1MessagePrefix[1] { + return errors.New("cbor: invalid COSE_Sign1 object") + } + + return (*Sign1Message)(m).doUnmarshal(data) +} + +// Sign signs an UnttaggedSign1Message using the provided Signer. +// The signature is stored in m.Signature. +// +// Note that m.Signature is only valid as long as m.Headers.Protected and +// m.Payload remain unchanged after calling this method. +// It is possible to modify m.Headers.Unprotected after signing, +// i.e., add counter signatures or timestamps. +// +// Reference: https://datatracker.ietf.org/doc/html/rfc8152#section-4.4 +func (m *UntaggedSign1Message) Sign(rand io.Reader, external []byte, signer Signer) error { + return (*Sign1Message)(m).Sign(rand, external, signer) +} + +// Verify verifies the signature on the UntaggedSign1Message returning nil on success or +// a suitable error if verification fails. +// +// Reference: https://datatracker.ietf.org/doc/html/rfc8152#section-4.4 +func (m *UntaggedSign1Message) Verify(external []byte, verifier Verifier) error { + return (*Sign1Message)(m).Verify(external, verifier) +} + +// Sign1Untagged signs an UntaggedSign1Message using the provided Signer. +// +// This method is a wrapper of `UntaggedSign1Message.Sign()`. +// +// Reference: https://datatracker.ietf.org/doc/html/rfc8152#section-4.4 +func Sign1Untagged(rand io.Reader, signer Signer, headers Headers, payload []byte, external []byte) ([]byte, error) { + msg := UntaggedSign1Message{ + Headers: headers, + Payload: payload, + } + err := msg.Sign(rand, external, signer) + if err != nil { + return nil, err + } + return msg.MarshalCBOR() +} diff --git a/sign1_test.go b/sign1_test.go index c876f49..d7063e0 100644 --- a/sign1_test.go +++ b/sign1_test.go @@ -978,3 +978,114 @@ func TestSign1Message_toBeSigned(t *testing.T) { }) } } + +func TestUntaggedSign1Message_MarshalCBOR(t *testing.T) { + tests := []struct { + name string + m *UntaggedSign1Message + want []byte + wantErr string + }{ + { + name: "valid message", + m: &UntaggedSign1Message{ + Headers: Headers{ + Protected: ProtectedHeader{ + HeaderLabelAlgorithm: AlgorithmES256, + }, + Unprotected: UnprotectedHeader{ + HeaderLabelContentType: 42, + }, + }, + Payload: []byte("foo"), + Signature: []byte("bar"), + }, + want: []byte{ + 0x84, + 0x43, 0xa1, 0x01, 0x26, // protected + 0xa1, 0x03, 0x18, 0x2a, // unprotected + 0x43, 0x66, 0x6f, 0x6f, // payload + 0x43, 0x62, 0x61, 0x72, // signature + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.m.MarshalCBOR() + + if err != nil && (err.Error() != tt.wantErr) { + t.Errorf("UntaggedSign1Message.MarshalCBOR() error = %v, wantErr %v", err, tt.wantErr) + return + } else if err == nil && (tt.wantErr != "") { + t.Errorf("UntaggedSign1Message.MarshalCBOR() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("UntaggedSign1Message.MarshalCBOR() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestUntaggedSign1Message_UnmarshalCBOR(t *testing.T) { + // test others + tests := []struct { + name string + data []byte + want UntaggedSign1Message + wantErr string + }{ + { + name: "valid message", + data: []byte{ + 0x84, + 0x43, 0xa1, 0x01, 0x26, // protected + 0xa1, 0x03, 0x18, 0x2a, // unprotected + 0x43, 0x66, 0x6f, 0x6f, // payload + 0x43, 0x62, 0x61, 0x72, // signature + }, + want: UntaggedSign1Message{ + Headers: Headers{ + RawProtected: []byte{0x43, 0xa1, 0x01, 0x26}, + Protected: ProtectedHeader{ + HeaderLabelAlgorithm: AlgorithmES256, + }, + RawUnprotected: []byte{0xa1, 0x03, 0x18, 0x2a}, + Unprotected: UnprotectedHeader{ + HeaderLabelContentType: int64(42), + }, + }, + Payload: []byte("foo"), + Signature: []byte("bar"), + }, + }, + { + name: "tagged message", + data: []byte{ + 0xd2, // tag + 0x84, + 0x43, 0xa1, 0x01, 0x26, // protected + 0xa1, 0x03, 0x18, 0x2a, // unprotected + 0x43, 0x66, 0x6f, 0x6f, // payload + 0x43, 0x62, 0x61, 0x72, // signature + }, + wantErr: "cbor: invalid COSE_Sign1 object", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var got UntaggedSign1Message + err := got.UnmarshalCBOR(tt.data) + if (err != nil) && (err.Error() != tt.wantErr) { + t.Errorf("Sign1Message.UnmarshalCBOR() error = %v, wantErr %v", err, tt.wantErr) + return + } else if err == nil && (tt.wantErr != "") { + t.Errorf("Sign1Message.UnmarshalCBOR() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Sign1Message.UnmarshalCBOR() = %v, want %v", got, tt.want) + } + }) + } +}