From 73030794f7aaf4f614486b511908841852807936 Mon Sep 17 00:00:00 2001 From: Nick Harper Date: Tue, 6 Aug 2024 21:00:01 +0000 Subject: [PATCH] Add DTLS 1.3 sequence number encryption Bug: 715 Change-Id: I87f8a08e9a2258dede21cffb1cfde5802608d30d Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/70667 Reviewed-by: Bob Beck Commit-Queue: Bob Beck --- ssl/dtls_record.cc | 75 ++++++++++++++++++++++----------- ssl/internal.h | 64 +++++++++++++++++++++++++++++ ssl/ssl_aead_ctx.cc | 85 ++++++++++++++++++++++++++++++++++++++ ssl/test/runner/conn.go | 91 +++++++++++++++++++++++++++++++++++++---- ssl/test/runner/dtls.go | 35 +++++++++++++--- ssl/tls13_enc.cc | 31 ++++++++++---- 6 files changed, 336 insertions(+), 45 deletions(-) diff --git a/ssl/dtls_record.cc b/ssl/dtls_record.cc index 783e950af9..a83d6b19ca 100644 --- a/ssl/dtls_record.cc +++ b/ssl/dtls_record.cc @@ -195,7 +195,7 @@ uint64_t reconstruct_seqnum(uint16_t wire_seq, uint64_t seq_mask, return seqnum; } -static bool parse_dtls13_record_header(SSL *ssl, CBS *in, size_t packet_size, +static bool parse_dtls13_record_header(SSL *ssl, CBS *in, Span packet, uint8_t type, CBS *out_body, uint64_t *out_sequence, uint16_t *out_epoch, @@ -206,29 +206,23 @@ static bool parse_dtls13_record_header(SSL *ssl, CBS *in, size_t packet_size, // Connection ID bit set, which we didn't negotiate. return false; } + // TODO(crbug.com/boringssl/715): Add a runner test that performs many // key updates to verify epoch reconstruction works for epochs larger than // 3. *out_epoch = reconstruct_epoch(type, ssl->d1->r_epoch); + size_t seqlen = 1; if ((type & 0x08) == 0x08) { - // 16-bit sequence number. - uint16_t seq; - if (!CBS_get_u16(in, &seq)) { - // The record header was incomplete or malformed. - return false; - } - *out_sequence = - reconstruct_seqnum(seq, 0xffff, ssl->d1->bitmap.max_seq_num); - } else { - // 8-bit sequence number. - uint8_t seq; - if (!CBS_get_u8(in, &seq)) { - // The record header was incomplete or malformed. - return false; - } - *out_sequence = reconstruct_seqnum(seq, 0xff, ssl->d1->bitmap.max_seq_num); + // If this bit is set, the sequence number is 16 bits long, otherwise it is + // 8 bits. The seqlen variable tracks the length of the sequence number in + // bytes. + seqlen = 2; + } + if (!CBS_skip(in, seqlen)) { + // The record header was incomplete or malformed. + return false; } - *out_header_len = packet_size - CBS_len(in); + *out_header_len = packet.size() - CBS_len(in); if ((type & 0x04) == 0x04) { *out_header_len += 2; // 16-bit length present @@ -244,6 +238,26 @@ static bool parse_dtls13_record_header(SSL *ssl, CBS *in, size_t packet_size, return false; } } + + // Decrypt and reconstruct the sequence number: + uint8_t mask[AES_BLOCK_SIZE]; + SSLAEADContext *aead = ssl->s3->aead_read_ctx.get(); + if (!aead->GenerateRecordNumberMask(mask, *out_body)) { + // GenerateRecordNumberMask most likely failed because the record body was + // not long enough. + return false; + } + // Apply the mask to the sequence number as it exists in the header. The + // header (with the decrypted sequence number bytes) is used as the + // additional data for the AEAD function. Since we don't support Connection + // ID, the sequence number starts immediately after the type byte. + uint64_t seq = 0; + for (size_t i = 0; i < seqlen; i++) { + packet[i + 1] ^= mask[i]; + seq = (seq << 8) | packet[i + 1]; + } + *out_sequence = reconstruct_seqnum(seq, (1 << (seqlen * 8)) - 1, + ssl->d1->bitmap.max_seq_num); return true; } @@ -321,9 +335,8 @@ enum ssl_open_record_t dtls_open_record(SSL *ssl, uint8_t *out_type, // records use the old record header format. if ((type & 0xe0) == 0x20 && !aead->is_null_cipher() && aead->ProtocolVersion() >= TLS1_3_VERSION) { - valid_record_header = - parse_dtls13_record_header(ssl, &cbs, in.size(), type, &body, &sequence, - &epoch, &record_header_len); + valid_record_header = parse_dtls13_record_header( + ssl, &cbs, in, type, &body, &sequence, &epoch, &record_header_len); } else { valid_record_header = parse_dtls_plaintext_record_header( ssl, &cbs, in.size(), type, &body, &sequence, &epoch, @@ -539,8 +552,24 @@ bool dtls_seal_record(SSL *ssl, uint8_t *out, size_t *out_len, size_t max_out, return false; } - // TODO(crbug.com/boringssl/715): Perform record number encryption (RFC 9147 - // section 4.2.3). + // Perform record number encryption (RFC 9147 section 4.2.3). + if (dtls13_header) { + // Record number encryption uses bytes from the ciphertext as a sample to + // generate the mask used for encryption. For simplicity, pass in the whole + // ciphertext as the sample - GenerateRecordNumberMask will read only what + // it needs (and error if |sample| is too short). + Span sample = + MakeConstSpan(out + record_header_len, ciphertext_len); + // AES cipher suites require the mask be exactly AES_BLOCK_SIZE; ChaCha20 + // cipher suites have no requirements on the mask size. We only need the + // first two bytes from the mask. + uint8_t mask[AES_BLOCK_SIZE]; + if (!aead->GenerateRecordNumberMask(mask, sample)) { + return false; + } + out[1] ^= mask[0]; + out[2] ^= mask[1]; + } (*seq)++; *out_len = record_header_len + ciphertext_len; diff --git a/ssl/internal.h b/ssl/internal.h index febb676c90..e6518286f1 100644 --- a/ssl/internal.h +++ b/ssl/internal.h @@ -155,6 +155,7 @@ #include #include +#include #include #include #include @@ -811,6 +812,16 @@ bool tls1_prf(const EVP_MD *digest, Span out, // Encryption layer. +class RecordNumberEncrypter { + public: + virtual ~RecordNumberEncrypter() = default; + static constexpr bool kAllowUniquePtr = true; + + virtual size_t KeySize() = 0; + virtual bool SetKey(Span key) = 0; + virtual bool GenerateMask(Span out, Span sample) = 0; +}; + // SSLAEADContext contains information about an AEAD that is being used to // encrypt an SSL connection. class SSLAEADContext { @@ -916,6 +927,17 @@ class SSLAEADContext { bool GetIV(const uint8_t **out_iv, size_t *out_iv_len) const; + RecordNumberEncrypter *GetRecordNumberEncrypter() { + return rn_encrypter_.get(); + } + + // GenerateRecordNumberMask computes the mask used for DTLS 1.3 record number + // encryption (RFC 9147 section 4.2.3), writing it to |out|. The |out| buffer + // must be sized to AES_BLOCK_SIZE. The |sample| buffer must be at least 16 + // bytes, as required by the AES and ChaCha20 cipher suites in RFC 9147. Extra + // bytes in |sample| will be ignored. + bool GenerateRecordNumberMask(Span out, Span sample); + private: // GetAdditionalData returns the additional data, writing into |storage| if // necessary. @@ -924,6 +946,8 @@ class SSLAEADContext { uint64_t seqnum, size_t plaintext_len, Span header); + void CreateRecordNumberEncrypter(); + const SSL_CIPHER *cipher_; ScopedEVP_AEAD_CTX ctx_; // fixed_nonce_ contains any bytes of the nonce that are fixed for all @@ -932,6 +956,7 @@ class SSLAEADContext { uint8_t fixed_nonce_len_ = 0, variable_nonce_len_ = 0; // version_ is the wire version that should be used with this AEAD. uint16_t version_; + UniquePtr rn_encrypter_; // is_dtls_ is whether DTLS is being used with this AEAD. bool is_dtls_; // variable_nonce_included_in_record_ is true if the variable nonce @@ -951,6 +976,45 @@ class SSLAEADContext { bool ad_is_header_ : 1; }; +class AESRecordNumberEncrypter : public RecordNumberEncrypter { + public: + bool SetKey(Span key) override; + bool GenerateMask(Span out, Span sample) override; + + private: + AES_KEY key_; +}; + +class AES128RecordNumberEncrypter : public AESRecordNumberEncrypter { + public: + size_t KeySize() override; +}; + +class AES256RecordNumberEncrypter : public AESRecordNumberEncrypter { + public: + size_t KeySize() override; +}; + +class ChaChaRecordNumberEncrypter : public RecordNumberEncrypter { + public: + size_t KeySize() override; + bool SetKey(Span key) override; + bool GenerateMask(Span out, Span sample) override; + + private: + static const size_t kKeySize = 32; + uint8_t key_[kKeySize]; +}; + +#if defined(BORINGSSL_UNSAFE_FUZZER_MODE) +class NullRecordNumberEncrypter : public RecordNumberEncrypter { + public: + size_t KeySize() override; + bool SetKey(Span key) override; + bool GenerateMask(Span out, Span sample) override; +}; +#endif // BORINGSSL_UNSAFE_FUZZER_MODE + // DTLS replay bitmap. diff --git a/ssl/ssl_aead_ctx.cc b/ssl/ssl_aead_ctx.cc index 85617a4c5a..4f532e90db 100644 --- a/ssl/ssl_aead_ctx.cc +++ b/ssl/ssl_aead_ctx.cc @@ -18,6 +18,7 @@ #include #include +#include #include #include @@ -44,6 +45,7 @@ SSLAEADContext::SSLAEADContext(uint16_t version_arg, bool is_dtls_arg, omit_length_in_ad_(false), ad_is_header_(false) { OPENSSL_memset(fixed_nonce_, 0, sizeof(fixed_nonce_)); + CreateRecordNumberEncrypter(); } SSLAEADContext::~SSLAEADContext() {} @@ -145,6 +147,23 @@ UniquePtr SSLAEADContext::Create( return aead_ctx; } +void SSLAEADContext::CreateRecordNumberEncrypter() { + if (!cipher_) { + return; + } +#if defined(BORINGSSL_UNSAFE_FUZZER_MODE) + rn_encrypter_ = MakeUnique(); +#else + if (cipher_->algorithm_enc == SSL_AES128GCM) { + rn_encrypter_ = MakeUnique(); + } else if (cipher_->algorithm_enc == SSL_AES256GCM) { + rn_encrypter_ = MakeUnique(); + } else if (cipher_->algorithm_enc == SSL_CHACHA20POLY1305) { + rn_encrypter_ = MakeUnique(); + } +#endif // BORINGSSL_UNSAFE_FUZZER_MODE +} + UniquePtr SSLAEADContext::CreatePlaceholderForQUIC( uint16_t version, const SSL_CIPHER *cipher) { return MakeUnique(version, false, cipher); @@ -427,4 +446,70 @@ bool SSLAEADContext::GetIV(const uint8_t **out_iv, size_t *out_iv_len) const { EVP_AEAD_CTX_get_iv(ctx_.get(), out_iv, out_iv_len); } +bool SSLAEADContext::GenerateRecordNumberMask(Span out, + Span sample) { + if (!rn_encrypter_) { + return false; + } + return rn_encrypter_->GenerateMask(out, sample); +} + +size_t AES128RecordNumberEncrypter::KeySize() { return 16; } + +size_t AES256RecordNumberEncrypter::KeySize() { return 32; } + +bool AESRecordNumberEncrypter::SetKey(Span key) { + return AES_set_encrypt_key(key.data(), key.size() * 8, &key_) == 0; +} + +bool AESRecordNumberEncrypter::GenerateMask(Span out, + Span sample) { + if (sample.size() < AES_BLOCK_SIZE || out.size() != AES_BLOCK_SIZE) { + return false; + } + AES_encrypt(sample.data(), out.data(), &key_); + return true; +} + +size_t ChaChaRecordNumberEncrypter::KeySize() { return kKeySize; } + +bool ChaChaRecordNumberEncrypter::SetKey(Span key) { + if (key.size() != kKeySize) { + return false; + } + OPENSSL_memcpy(key_, key.data(), key.size()); + return true; +} + +bool ChaChaRecordNumberEncrypter::GenerateMask(Span out, + Span sample) { + Array zeroes; + if (!zeroes.Init(out.size())) { + return false; + } + OPENSSL_memset(zeroes.data(), 0, zeroes.size()); + // RFC 9147 section 4.2.3 uses the first 4 bytes of the sample as the counter + // and the next 12 bytes as the nonce. If we have less than 4+12=16 bytes in + // the sample, then we'll read past the end of the |sample| buffer. + if (sample.size() < 16) { + return false; + } + uint32_t counter = CRYPTO_load_u32_be(sample.data()); + Span nonce = sample.subspan(4); + CRYPTO_chacha_20(out.data(), zeroes.data(), zeroes.size(), key_, nonce.data(), + counter); + return true; +} + +#if defined(BORINGSSL_UNSAFE_FUZZER_MODE) +size_t NullRecordNumberEncrypter::KeySize() { return 0; } +bool NullRecordNumberEncrypter::SetKey(Span key) { return true; } + +bool NullRecordNumberEncrypter::GenerateMask(Span out, + Span sample) { + OPENSSL_memset(out.data(), 0, out.size()); + return true; +} +#endif // BORINGSSL_UNSAFE_FUZZER_MODE + BSSL_NAMESPACE_END diff --git a/ssl/test/runner/conn.go b/ssl/test/runner/conn.go index ce425a064b..1988ab2a0d 100644 --- a/ssl/test/runner/conn.go +++ b/ssl/test/runner/conn.go @@ -8,6 +8,7 @@ package runner import ( "bytes" + "crypto/aes" "crypto/cipher" "crypto/ecdsa" "crypto/subtle" @@ -19,6 +20,9 @@ import ( "net" "sync" "time" + + "golang.org/x/crypto/chacha20" + "golang.org/x/crypto/cryptobyte" ) // A Conn represents a secured connection. @@ -175,15 +179,16 @@ func (c *Conn) SetWriteDeadline(t time.Time) error { type halfConn struct { sync.Mutex - err error // first permanent error - version uint16 // protocol version - wireVersion uint16 // wire version - isDTLS bool - cipher any // cipher algorithm - mac macFunction - seq [8]byte // 64-bit sequence number - outSeq [8]byte // Mapped sequence number - bfree *block // list of free blocks + err error // first permanent error + version uint16 // protocol version + wireVersion uint16 // wire version + isDTLS bool + cipher any // cipher algorithm + recordNumberEncrypter recordNumberEncrypter + mac macFunction + seq [8]byte // 64-bit sequence number + outSeq [8]byte // Mapped sequence number + bfree *block // list of free blocks nextCipher any // next encryption state nextMac macFunction // next MAC algorithm @@ -253,6 +258,17 @@ func (hc *halfConn) useTrafficSecret(version uint16, suite *cipherSuite, secret } hc.version = protocolVersion hc.cipher = deriveTrafficAEAD(version, suite, secret, side, hc.isDTLS) + if hc.isDTLS && !hc.config.Bugs.NullAllCiphers { + sn_key := hkdfExpandLabel(suite.hash(), secret, []byte("sn"), nil, suite.keyLen, hc.isDTLS) + switch suite.id { + case TLS_CHACHA20_POLY1305_SHA256: + hc.recordNumberEncrypter = newChachaRecordNumberEncrypter(sn_key) + case TLS_AES_128_GCM_SHA256, TLS_AES_256_GCM_SHA384: + hc.recordNumberEncrypter = newAESRecordNumberEncrypter(sn_key) + default: + panic("Cipher suite does not support TLS 1.3") + } + } if hc.config.Bugs.NullAllCiphers { hc.cipher = nullCipher{} } @@ -762,6 +778,63 @@ func (hc *halfConn) splitBlock(b *block, n int) (*block, *block) { return b, bb } +type recordNumberEncrypter interface { + // GenerateMask takes a sample of the encrypted record and returns the + // mask used to encrypt and decrypt record numbers. + generateMask(sample []byte) []byte +} + +type aesRecordNumberEncrypter struct { + aesCipher cipher.Block +} + +func newAESRecordNumberEncrypter(key []byte) *aesRecordNumberEncrypter { + aesCipher, err := aes.NewCipher(key) + if err != nil { + panic("Incorrect usage of newAESRecordNumberEncrypter") + } + return &aesRecordNumberEncrypter{ + aesCipher: aesCipher, + } +} + +func (a *aesRecordNumberEncrypter) generateMask(sample []byte) []byte { + out := make([]byte, len(sample)) + a.aesCipher.Encrypt(out, sample) + return out +} + +type chachaRecordNumberEncrypter struct { + key []byte +} + +func newChachaRecordNumberEncrypter(key []byte) *chachaRecordNumberEncrypter { + out := &chachaRecordNumberEncrypter{ + key: key, + } + fmt.Printf("new RNE with key %x\n", key) + return out +} + +func (c *chachaRecordNumberEncrypter) generateMask(sample []byte) []byte { + var counter uint32 + nonce := make([]byte, 12) + sampleReader := cryptobyte.String(sample) + if !sampleReader.ReadUint32(&counter) || !sampleReader.CopyBytes(nonce) { + panic("chachaRecordNumberEncrypter.GenerateMask called with wrong size sample") + } + cipher, err := chacha20.NewUnauthenticatedCipher(c.key, nonce) + if err != nil { + panic("Failed to create chacha20 cipher for record number encryption") + } + cipher.SetCounter(counter) + zeroes := make([]byte, 2) + out := make([]byte, 2) + cipher.XORKeyStream(out, zeroes) + fmt.Printf("golang generateMask: sample: %x, key: %x, mask: %x\n", sample[:16], c.key, out) + return out +} + func (c *Conn) useInTrafficSecret(level encryptionLevel, version uint16, suite *cipherSuite, secret []byte) error { if c.hand.Len() != 0 { return c.in.setErrorLocked(errors.New("tls: buffered handshake messages on cipher change")) diff --git a/ssl/test/runner/dtls.go b/ssl/test/runner/dtls.go index f4921d4e95..8c723f237d 100644 --- a/ssl/test/runner/dtls.go +++ b/ssl/test/runner/dtls.go @@ -55,7 +55,13 @@ func (c *Conn) readDTLS13RecordHeader(b *block) (headerLen int, recordLen int, r c.sendAlert(alertIllegalParameter) return 0, 0, 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: bad epoch")) } - wireSeq := binary.BigEndian.Uint16(b.data[1:3]) + wireSeq := b.data[1:3] + if !c.config.Bugs.NullAllCiphers { + sample := b.data[recordHeaderLen:] + mask := c.in.recordNumberEncrypter.generateMask(sample) + xorSlice(wireSeq, mask) + } + decWireSeq := binary.BigEndian.Uint16(wireSeq) // Reconstruct the sequence number from the low 16 bits on the wire. // A real implementation would compute the full sequence number that is // closest to the highest successfully decrypted record in the @@ -67,7 +73,7 @@ func (c *Conn) readDTLS13RecordHeader(b *block) (headerLen int, recordLen int, r seqInt := binary.BigEndian.Uint64(c.in.seq[:]) // c.in.seq has the epoch in the upper two bytes - clear those. seqInt = seqInt &^ (0xffff << 48) - newSeq := seqInt&^0xffff | uint64(wireSeq) + newSeq := seqInt&^0xffff | uint64(decWireSeq) if newSeq < seqInt { newSeq += 0x10000 } @@ -500,7 +506,10 @@ func (c *Conn) dtlsPackRecord(typ recordType, data []byte, mustPack bool) (n int } copy(b.data[recordHeaderLen+explicitIVLen:], data) recordLen := c.addTLS13Padding(b, recordHeaderLen, len(data), typ) - if c.out.version < VersionTLS13 || c.out.cipher == nil || (c.config.Bugs.DTLSUsePlaintextRecordHeader && c.handshakeComplete) { + useDTLS13RecordHeader := c.out.version >= VersionTLS13 && c.out.cipher != nil && !(c.config.Bugs.DTLSUsePlaintextRecordHeader && c.handshakeComplete) + if useDTLS13RecordHeader { + c.writeDTLS13RecordHeader(b, recordLen) + } else { b.data[0] = byte(typ) b.data[1] = byte(vers >> 8) b.data[2] = byte(vers) @@ -508,10 +517,26 @@ func (c *Conn) dtlsPackRecord(typ recordType, data []byte, mustPack bool) (n int copy(b.data[3:11], c.out.outSeq[0:]) b.data[11] = byte(recordLen >> 8) b.data[12] = byte(recordLen) - } else { - c.writeDTLS13RecordHeader(b, recordLen) } + // encrypt will increment the sequence number. Copy it here to use when + // performing sequence number encryption. + seqBytes := make([]byte, 2) + copy(seqBytes, c.out.outSeq[6:8]) c.out.encrypt(b, explicitIVLen, typ) + if useDTLS13RecordHeader && !c.config.Bugs.NullAllCiphers { + recordHeaderLen := c.out.writeRecordHeaderLen() + sample := b.data[recordHeaderLen:] + mask := c.out.recordNumberEncrypter.generateMask(sample) + if c.config.DTLSUseShortSeqNums { + seqBytes = seqBytes[1:2] + } + xorSlice(seqBytes, mask) + for i := range seqBytes { + // The sequence number starts at index 1 in the record + // header. + b.data[1+i] = seqBytes[i] + } + } // Flush the current pending packet if necessary. if !mustPack && len(b.data)+len(c.pendingPacket) > c.config.Bugs.PackHandshakeRecords { diff --git a/ssl/tls13_enc.cc b/ssl/tls13_enc.cc index 1613a3a158..7c193a3757 100644 --- a/ssl/tls13_enc.cc +++ b/ssl/tls13_enc.cc @@ -184,6 +184,8 @@ bool tls13_set_traffic_key(SSL *ssl, enum ssl_encryption_level_t level, const SSL_SESSION *session, Span traffic_secret) { uint16_t version = ssl_session_protocol_version(session); + const EVP_MD *digest = ssl_session_get_digest(session); + bool is_dtls = SSL_is_dtls(ssl); UniquePtr traffic_aead; Span secret_for_quic; if (ssl->quic_method != nullptr) { @@ -197,18 +199,16 @@ bool tls13_set_traffic_key(SSL *ssl, enum ssl_encryption_level_t level, const EVP_AEAD *aead; size_t discard; if (!ssl_cipher_get_evp_aead(&aead, &discard, &discard, session->cipher, - version, SSL_is_dtls(ssl))) { + version, is_dtls)) { return false; } - const EVP_MD *digest = ssl_session_get_digest(session); - // Derive the key. size_t key_len = EVP_AEAD_key_length(aead); uint8_t key_buf[EVP_AEAD_MAX_KEY_LENGTH]; auto key = MakeSpan(key_buf, key_len); if (!hkdf_expand_label(key, digest, traffic_secret, label_to_span("key"), - {}, SSL_is_dtls(ssl))) { + {}, is_dtls)) { return false; } @@ -217,19 +217,34 @@ bool tls13_set_traffic_key(SSL *ssl, enum ssl_encryption_level_t level, uint8_t iv_buf[EVP_AEAD_MAX_NONCE_LENGTH]; auto iv = MakeSpan(iv_buf, iv_len); if (!hkdf_expand_label(iv, digest, traffic_secret, label_to_span("iv"), {}, - SSL_is_dtls(ssl))) { + is_dtls)) { return false; } - traffic_aead = SSLAEADContext::Create(direction, session->ssl_version, - SSL_is_dtls(ssl), session->cipher, - key, Span(), iv); + traffic_aead = + SSLAEADContext::Create(direction, session->ssl_version, is_dtls, + session->cipher, key, Span(), iv); } if (!traffic_aead) { return false; } + if (is_dtls) { + RecordNumberEncrypter *rn_encrypter = + traffic_aead->GetRecordNumberEncrypter(); + if (!rn_encrypter) { + return false; + } + Array rne_key; + if (!rne_key.Init(rn_encrypter->KeySize()) || + !hkdf_expand_label(MakeSpan(rne_key), digest, traffic_secret, + label_to_span("sn"), {}, is_dtls) || + !rn_encrypter->SetKey(MakeSpan(rne_key))) { + return false; + } + } + if (traffic_secret.size() > OPENSSL_ARRAY_SIZE(ssl->s3->read_traffic_secret) || traffic_secret.size() >