diff --git a/cbor.go b/cbor.go index 5f2ee91..15bdc54 100644 --- a/cbor.go +++ b/cbor.go @@ -82,3 +82,48 @@ func (s *byteString) UnmarshalCBOR(data []byte) error { } return decModeWithTagsForbidden.Unmarshal(data, (*[]byte)(s)) } + +// deterministicBinaryString converts a bstr into the deterministic encoding. +// +// Reference: https://www.rfc-editor.org/rfc/rfc9052.html#section-9 +func deterministicBinaryString(data cbor.RawMessage) (cbor.RawMessage, error) { + if len(data) == 0 { + return nil, io.EOF + } + if data[0]>>5 != 2 { // major type 2: bstr + return nil, errors.New("cbor: require bstr type") + } + + // fast path: return immediately if bstr is already deterministic + if err := decModeWithTagsForbidden.Valid(data); err != nil { + return nil, err + } + ai := data[0] & 0x1f + if ai < 24 { + return data, nil + } + switch ai { + case 24: + if data[1] >= 24 { + return data, nil + } + case 25: + if data[1] != 0 { + return data, nil + } + case 26: + if data[1] != 0 || data[2] != 0 { + return data, nil + } + case 27: + if data[1] != 0 || data[2] != 0 || data[3] != 0 || data[4] != 0 { + return data, nil + } + } + + // slow path: convert by re-encoding + // error checking is not required since `data` has been validataed + var s []byte + _ = decModeWithTagsForbidden.Unmarshal(data, &s) + return encMode.Marshal(s) +} diff --git a/cbor_test.go b/cbor_test.go index 2fdd6b9..00c8093 100644 --- a/cbor_test.go +++ b/cbor_test.go @@ -2,7 +2,10 @@ package cose import ( "bytes" + "reflect" "testing" + + "github.com/fxamacker/cbor/v2" ) func Test_byteString_UnmarshalCBOR(t *testing.T) { @@ -75,3 +78,130 @@ func Test_byteString_UnmarshalCBOR(t *testing.T) { }) } } + +func Test_deterministicBinaryString(t *testing.T) { + gen := func(initial []byte, size int) []byte { + data := make([]byte, size+len(initial)) + copy(data, initial) + return data + } + tests := []struct { + name string + data cbor.RawMessage + want cbor.RawMessage + wantErr bool + }{ + { + name: "empty input", + data: nil, + wantErr: true, + }, + { + name: "not bstr", + data: []byte{0x00}, + wantErr: true, + }, + { + name: "short length", + data: gen([]byte{0x57}, 23), + want: gen([]byte{0x57}, 23), + }, + { + name: "optimal uint8 length", + data: gen([]byte{0x58, 0x18}, 24), + want: gen([]byte{0x58, 0x18}, 24), + }, + { + name: "non-optimal uint8 length", + data: gen([]byte{0x58, 0x17}, 23), + want: gen([]byte{0x57}, 23), + }, + { + name: "optimal uint16 length", + data: gen([]byte{0x59, 0x01, 0x00}, 256), + want: gen([]byte{0x59, 0x01, 0x00}, 256), + }, + { + name: "non-optimal uint16 length, target short", + data: gen([]byte{0x59, 0x00, 0x17}, 23), + want: gen([]byte{0x57}, 23), + }, + { + name: "non-optimal uint16 length, target uint8", + data: gen([]byte{0x59, 0x00, 0x18}, 24), + want: gen([]byte{0x58, 0x18}, 24), + }, + { + name: "optimal uint32 length", + data: gen([]byte{0x5a, 0x00, 0x01, 0x00, 0x00}, 65536), + want: gen([]byte{0x5a, 0x00, 0x01, 0x00, 0x00}, 65536), + }, + { + name: "non-optimal uint32 length, target short", + data: gen([]byte{0x5a, 0x00, 0x00, 0x00, 0x17}, 23), + want: gen([]byte{0x57}, 23), + }, + { + name: "non-optimal uint32 length, target uint8", + data: gen([]byte{0x5a, 0x00, 0x00, 0x00, 0x18}, 24), + want: gen([]byte{0x58, 0x18}, 24), + }, + { + name: "non-optimal uint32 length, target uint16", + data: gen([]byte{0x5a, 0x00, 0x00, 0x01, 0x00}, 256), + want: gen([]byte{0x59, 0x01, 0x00}, 256), + }, + { + name: "non-optimal uint64 length, target short", + data: gen([]byte{0x5b, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x17, + }, 23), + want: gen([]byte{0x57}, 23), + }, + { + name: "non-optimal uint64 length, target uint8", + data: gen([]byte{0x5b, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x18, + }, 24), + want: gen([]byte{0x58, 0x18}, 24), + }, + { + name: "non-optimal uint64 length, target uint16", + data: gen([]byte{0x5b, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x01, 0x00, + }, 256), + want: gen([]byte{0x59, 0x01, 0x00}, 256), + }, + { + name: "non-optimal uint64 length, target uint32", + data: gen([]byte{0x5b, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x01, 0x00, 0x00, + }, 65536), + want: gen([]byte{0x5a, 0x00, 0x01, 0x00, 0x00}, 65536), + }, + { + name: "early EOF", + data: gen([]byte{0x5b, + 0x01, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + }, 42), + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := deterministicBinaryString(tt.data) + if (err != nil) != tt.wantErr { + t.Errorf("deterministicBinaryString() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("deterministicBinaryString() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/sign.go b/sign.go index 54d9ef2..cd86d11 100644 --- a/sign.go +++ b/sign.go @@ -220,8 +220,16 @@ func (s *Signature) toBeSigned(bodyProtected cbor.RawMessage, payload, external // external_aad : bstr, // payload : bstr // ] + bodyProtected, err := deterministicBinaryString(bodyProtected) + if err != nil { + return nil, err + } var signProtected cbor.RawMessage - signProtected, err := s.Headers.MarshalProtected() + signProtected, err = s.Headers.MarshalProtected() + if err != nil { + return nil, err + } + signProtected, err = deterministicBinaryString(signProtected) if err != nil { return nil, err } diff --git a/sign1.go b/sign1.go index 3b2d111..56e884c 100644 --- a/sign1.go +++ b/sign1.go @@ -199,6 +199,10 @@ func (m *Sign1Message) toBeSigned(external []byte) ([]byte, error) { if err != nil { return nil, err } + protected, err = deterministicBinaryString(protected) + if err != nil { + return nil, err + } if external == nil { external = []byte{} } diff --git a/sign1_test.go b/sign1_test.go index 5a45fcf..f2590ca 100644 --- a/sign1_test.go +++ b/sign1_test.go @@ -5,6 +5,8 @@ import ( "crypto/rand" "reflect" "testing" + + "github.com/fxamacker/cbor/v2" ) func TestSign1Message_MarshalCBOR(t *testing.T) { @@ -837,3 +839,129 @@ func TestSign1Message_Verify(t *testing.T) { } }) } + +// TestSign1Message_Verify_issue119: non-minimal protected header length +func TestSign1Message_Verify_issue119(t *testing.T) { + // generate key and set up signer / verifier + alg := AlgorithmES256 + key := generateTestECDSAKey(t) + signer, err := NewSigner(alg, key) + if err != nil { + t.Fatalf("NewSigner() error = %v", err) + } + verifier, err := NewVerifier(alg, key.Public()) + if err != nil { + t.Fatalf("NewVerifier() error = %v", err) + } + + // generate message and sign + msg := &Sign1Message{ + Headers: Headers{ + Protected: ProtectedHeader{ + HeaderLabelAlgorithm: AlgorithmES256, + }, + }, + Payload: []byte("hello"), + } + if err := msg.Sign(rand.Reader, nil, signer); err != nil { + t.Fatalf("Sign1Message.Sign() error = %v", err) + } + data, err := msg.MarshalCBOR() + if err != nil { + t.Fatalf("Sign1Message.MarshalCBOR() error = %v", err) + } + + // decanonicalize protected header + decanonicalize := func(data []byte) ([]byte, error) { + var content sign1Message + if err := decModeWithTagsForbidden.Unmarshal(data[1:], &content); err != nil { + return nil, err + } + + protected := make([]byte, len(content.Protected)+1) + copy(protected[2:], content.Protected[1:]) + protected[0] = 0x58 + protected[1] = content.Protected[0] & 0x1f + content.Protected = protected + + return encMode.Marshal(cbor.Tag{ + Number: CBORTagSign1Message, + Content: content, + }) + } + if data, err = decanonicalize(data); err != nil { + t.Fatalf("fail to decanonicalize: %v", err) + } + + // verify message + var decoded Sign1Message + if err = decoded.UnmarshalCBOR(data); err != nil { + t.Fatalf("Sign1Message.UnmarshalCBOR() error = %v", err) + } + if err := decoded.Verify(nil, verifier); err != nil { + t.Fatalf("Sign1Message.Verify() error = %v", err) + } +} + +func TestSign1Message_toBeSigned(t *testing.T) { + tests := []struct { + name string + m *Sign1Message + external []byte + want []byte + wantErr bool + }{ + { + name: "valid message", + m: &Sign1Message{ + Headers: Headers{ + Protected: ProtectedHeader{ + HeaderLabelAlgorithm: algorithmMock, + }, + }, + Payload: []byte("hello world"), + }, + want: []byte{ + 0x84, // array type + 0x6a, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, 0x31, // context + 0x47, 0xa1, 0x01, 0x3a, 0x6d, 0x6f, 0x63, 0x6a, // protected + 0x40, // external + 0x4b, 0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x20, 0x77, 0x6f, 0x72, 0x6c, 0x64, // payload + }, + }, + { + name: "invalid protected header", + m: &Sign1Message{ + Headers: Headers{ + Protected: ProtectedHeader{ + 1.5: nil, + }, + }, + Payload: []byte{}, + }, + wantErr: true, + }, + { + name: "invalid raw protected header", + m: &Sign1Message{ + Headers: Headers{ + RawProtected: []byte{0x00}, + }, + Payload: []byte{}, + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.m.toBeSigned(tt.external) + if (err != nil) != tt.wantErr { + t.Errorf("Sign1Message.toBeSigned() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Sign1Message.toBeSigned() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/sign_test.go b/sign_test.go index ee865ec..289ea97 100644 --- a/sign_test.go +++ b/sign_test.go @@ -2189,3 +2189,85 @@ func TestSignMessage_Verify(t *testing.T) { } }) } + +func TestSignature_toBeSigned(t *testing.T) { + tests := []struct { + name string + s *Signature + protected cbor.RawMessage + payload []byte + external []byte + want []byte + wantErr bool + }{ + { + name: "valid signature", + s: &Signature{ + Headers: Headers{ + Protected: ProtectedHeader{ + HeaderLabelAlgorithm: algorithmMock, + }, + }, + }, + protected: []byte{0x40, 0xa1, 0x00, 0x00}, + payload: []byte("hello world"), + want: []byte{ + 0x85, // array type + 0x69, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, // context + 0x40, 0xa1, 0x00, 0x00, // body_protected + 0x47, 0xa1, 0x01, 0x3a, 0x6d, 0x6f, 0x63, 0x6a, // sign_protected + 0x40, // external + 0x4b, 0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x20, 0x77, 0x6f, 0x72, 0x6c, 0x64, // payload + }, + }, + { + name: "invalid body protected header", + s: &Signature{ + Headers: Headers{ + Protected: ProtectedHeader{ + HeaderLabelAlgorithm: AlgorithmES512, + }, + }, + }, + protected: []byte{0x00}, + payload: []byte{}, + wantErr: true, + }, + { + name: "invalid sign protected header", + s: &Signature{ + Headers: Headers{ + Protected: ProtectedHeader{ + 1.5: nil, + }, + }, + }, + protected: []byte{0x40}, + payload: []byte{}, + wantErr: true, + }, + { + name: "invalid raw sign protected header", + s: &Signature{ + Headers: Headers{ + RawProtected: []byte{0x00}, + }, + }, + protected: []byte{0x40}, + payload: []byte{}, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.s.toBeSigned(tt.protected, tt.payload, tt.external) + if (err != nil) != tt.wantErr { + t.Errorf("Signature.toBeSigned() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Signature.toBeSigned() = %v, want %v", got, tt.want) + } + }) + } +}