From 387e5094703221de9120fb91b7908461822d434b Mon Sep 17 00:00:00 2001 From: Alex Gaynor Date: Wed, 9 Aug 2023 22:29:55 -0400 Subject: [PATCH] Convert AESSIV AEAD to Rust (#9359) --- .../hazmat/backends/openssl/aead.py | 64 +----- .../bindings/_rust/openssl/__init__.pyi | 2 + .../hazmat/bindings/_rust/openssl/aead.pyi | 20 ++ .../hazmat/primitives/ciphers/aead.py | 86 +------ src/rust/src/backend/aead.rs | 210 ++++++++++++++++++ src/rust/src/backend/mod.rs | 2 + src/rust/src/exceptions.rs | 1 + 7 files changed, 258 insertions(+), 127 deletions(-) create mode 100644 src/cryptography/hazmat/bindings/_rust/openssl/aead.pyi create mode 100644 src/rust/src/backend/aead.rs diff --git a/src/cryptography/hazmat/backends/openssl/aead.py b/src/cryptography/hazmat/backends/openssl/aead.py index b36f535f3f8f..b7fef7a52634 100644 --- a/src/cryptography/hazmat/backends/openssl/aead.py +++ b/src/cryptography/hazmat/backends/openssl/aead.py @@ -14,13 +14,10 @@ AESCCM, AESGCM, AESOCB3, - AESSIV, ChaCha20Poly1305, ) - _AEADTypes = typing.Union[ - AESCCM, AESGCM, AESOCB3, AESSIV, ChaCha20Poly1305 - ] + _AEADTypes = typing.Union[AESCCM, AESGCM, AESOCB3, ChaCha20Poly1305] def _is_evp_aead_supported_cipher( @@ -44,16 +41,9 @@ def _aead_cipher_supported(backend: Backend, cipher: _AEADTypes) -> bool: cipher_name = _evp_cipher_cipher_name(cipher) if backend._fips_enabled and cipher_name not in backend._fips_aead: return False - # SIV isn't loaded through get_cipherbyname but instead a new fetch API - # only available in 3.0+. But if we know we're on 3.0+ then we know - # it's supported. - if cipher_name.endswith(b"-siv"): - return backend._lib.CRYPTOGRAPHY_OPENSSL_300_OR_GREATER == 1 - else: - return ( - backend._lib.EVP_get_cipherbyname(cipher_name) - != backend._ffi.NULL - ) + return ( + backend._lib.EVP_get_cipherbyname(cipher_name) != backend._ffi.NULL + ) def _aead_create_ctx( @@ -231,7 +221,6 @@ def _evp_cipher_cipher_name(cipher: _AEADTypes) -> bytes: AESCCM, AESGCM, AESOCB3, - AESSIV, ChaCha20Poly1305, ) @@ -241,26 +230,14 @@ def _evp_cipher_cipher_name(cipher: _AEADTypes) -> bytes: return f"aes-{len(cipher._key) * 8}-ccm".encode("ascii") elif isinstance(cipher, AESOCB3): return f"aes-{len(cipher._key) * 8}-ocb".encode("ascii") - elif isinstance(cipher, AESSIV): - return f"aes-{len(cipher._key) * 8 // 2}-siv".encode("ascii") else: assert isinstance(cipher, AESGCM) return f"aes-{len(cipher._key) * 8}-gcm".encode("ascii") def _evp_cipher(cipher_name: bytes, backend: Backend): - if cipher_name.endswith(b"-siv"): - evp_cipher = backend._lib.EVP_CIPHER_fetch( - backend._ffi.NULL, - cipher_name, - backend._ffi.NULL, - ) - backend.openssl_assert(evp_cipher != backend._ffi.NULL) - evp_cipher = backend._ffi.gc(evp_cipher, backend._lib.EVP_CIPHER_free) - else: - evp_cipher = backend._lib.EVP_get_cipherbyname(cipher_name) - backend.openssl_assert(evp_cipher != backend._ffi.NULL) - + evp_cipher = backend._lib.EVP_get_cipherbyname(cipher_name) + backend.openssl_assert(evp_cipher != backend._ffi.NULL) return evp_cipher @@ -389,10 +366,7 @@ def _evp_cipher_process_data(backend: Backend, ctx, data: bytes) -> bytes: buf = backend._ffi.new("unsigned char[]", len(data)) data_ptr = backend._ffi.from_buffer(data) res = backend._lib.EVP_CipherUpdate(ctx, buf, outlen, data_ptr, len(data)) - if res == 0: - # AES SIV can error here if the data is invalid on decrypt - backend._consume_errors() - raise InvalidTag + backend.openssl_assert(res != 0) return backend._ffi.buffer(buf, outlen[0])[:] @@ -405,7 +379,7 @@ def _evp_cipher_encrypt( tag_length: int, ctx: typing.Any = None, ) -> bytes: - from cryptography.hazmat.primitives.ciphers.aead import AESCCM, AESSIV + from cryptography.hazmat.primitives.ciphers.aead import AESCCM if ctx is None: cipher_name = _evp_cipher_cipher_name(cipher) @@ -445,14 +419,7 @@ def _evp_cipher_encrypt( backend.openssl_assert(res != 0) tag = backend._ffi.buffer(tag_buf)[:] - if isinstance(cipher, AESSIV): - # RFC 5297 defines the output as IV || C, where the tag we generate - # is the "IV" and C is the ciphertext. This is the opposite of our - # other AEADs, which are Ciphertext || Tag - backend.openssl_assert(len(tag) == 16) - return tag + processed_data - else: - return processed_data + tag + return processed_data + tag def _evp_cipher_decrypt( @@ -464,20 +431,13 @@ def _evp_cipher_decrypt( tag_length: int, ctx: typing.Any = None, ) -> bytes: - from cryptography.hazmat.primitives.ciphers.aead import AESCCM, AESSIV + from cryptography.hazmat.primitives.ciphers.aead import AESCCM if len(data) < tag_length: raise InvalidTag - if isinstance(cipher, AESSIV): - # RFC 5297 defines the output as IV || C, where the tag we generate - # is the "IV" and C is the ciphertext. This is the opposite of our - # other AEADs, which are Ciphertext || Tag - tag = data[:tag_length] - data = data[tag_length:] - else: - tag = data[-tag_length:] - data = data[:-tag_length] + tag = data[-tag_length:] + data = data[:-tag_length] if ctx is None: cipher_name = _evp_cipher_cipher_name(cipher) ctx = _evp_cipher_aead_setup( diff --git a/src/cryptography/hazmat/bindings/_rust/openssl/__init__.pyi b/src/cryptography/hazmat/bindings/_rust/openssl/__init__.pyi index d0e6ccaed238..1784c5ade9cd 100644 --- a/src/cryptography/hazmat/bindings/_rust/openssl/__init__.pyi +++ b/src/cryptography/hazmat/bindings/_rust/openssl/__init__.pyi @@ -5,6 +5,7 @@ import typing from cryptography.hazmat.bindings._rust.openssl import ( + aead, dh, dsa, ec, @@ -21,6 +22,7 @@ from cryptography.hazmat.bindings._rust.openssl import ( __all__ = [ "openssl_version", "raise_openssl_error", + "aead", "dh", "dsa", "ec", diff --git a/src/cryptography/hazmat/bindings/_rust/openssl/aead.pyi b/src/cryptography/hazmat/bindings/_rust/openssl/aead.pyi new file mode 100644 index 000000000000..57cf92ce5e75 --- /dev/null +++ b/src/cryptography/hazmat/bindings/_rust/openssl/aead.pyi @@ -0,0 +1,20 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +import typing + +class AESSIV: + def __init__(self, key: bytes) -> None: ... + @staticmethod + def generate_key(key_size: int) -> bytes: ... + def encrypt( + self, + nonce: bytes, + associated_data: typing.Optional[typing.List[bytes]], + ) -> bytes: ... + def decrypt( + self, + nonce: bytes, + associated_data: typing.Optional[typing.List[bytes]], + ) -> bytes: ... diff --git a/src/cryptography/hazmat/primitives/ciphers/aead.py b/src/cryptography/hazmat/primitives/ciphers/aead.py index 957b2d221b62..944060c0b3dd 100644 --- a/src/cryptography/hazmat/primitives/ciphers/aead.py +++ b/src/cryptography/hazmat/primitives/ciphers/aead.py @@ -11,6 +11,17 @@ from cryptography.hazmat.backends.openssl import aead from cryptography.hazmat.backends.openssl.backend import backend from cryptography.hazmat.bindings._rust import FixedPool +from cryptography.hazmat.bindings._rust import openssl as rust_openssl + +__all__ = [ + "ChaCha20Poly1305", + "AESCCM", + "AESGCM", + "AESOCB3", + "AESSIV", +] + +AESSIV = rust_openssl.aead.AESSIV class ChaCha20Poly1305: @@ -301,78 +312,3 @@ def _check_params( utils._check_byteslike("associated_data", associated_data) if len(nonce) < 12 or len(nonce) > 15: raise ValueError("Nonce must be between 12 and 15 bytes") - - -class AESSIV: - _MAX_SIZE = 2**31 - 1 - - def __init__(self, key: bytes): - utils._check_byteslike("key", key) - if len(key) not in (32, 48, 64): - raise ValueError("AESSIV key must be 256, 384, or 512 bits.") - - self._key = key - - if not backend.aead_cipher_supported(self): - raise exceptions.UnsupportedAlgorithm( - "AES-SIV is not supported by this version of OpenSSL", - exceptions._Reasons.UNSUPPORTED_CIPHER, - ) - - @classmethod - def generate_key(cls, bit_length: int) -> bytes: - if not isinstance(bit_length, int): - raise TypeError("bit_length must be an integer") - - if bit_length not in (256, 384, 512): - raise ValueError("bit_length must be 256, 384, or 512") - - return os.urandom(bit_length // 8) - - def encrypt( - self, - data: bytes, - associated_data: typing.Optional[typing.List[bytes]], - ) -> bytes: - if associated_data is None: - associated_data = [] - - self._check_params(data, associated_data) - - if len(data) > self._MAX_SIZE or any( - len(ad) > self._MAX_SIZE for ad in associated_data - ): - # This is OverflowError to match what cffi would raise - raise OverflowError( - "Data or associated data too long. Max 2**31 - 1 bytes" - ) - - return aead._encrypt(backend, self, b"", data, associated_data, 16) - - def decrypt( - self, - data: bytes, - associated_data: typing.Optional[typing.List[bytes]], - ) -> bytes: - if associated_data is None: - associated_data = [] - - self._check_params(data, associated_data) - - return aead._decrypt(backend, self, b"", data, associated_data, 16) - - def _check_params( - self, - data: bytes, - associated_data: typing.List[bytes], - ) -> None: - utils._check_byteslike("data", data) - if len(data) == 0: - raise ValueError("data must not be zero length") - - if not isinstance(associated_data, list): - raise TypeError( - "associated_data must be a list of bytes-like objects or None" - ) - for x in associated_data: - utils._check_byteslike("associated_data elements", x) diff --git a/src/rust/src/backend/aead.rs b/src/rust/src/backend/aead.rs new file mode 100644 index 000000000000..8f9c4829090e --- /dev/null +++ b/src/rust/src/backend/aead.rs @@ -0,0 +1,210 @@ +// This file is dual licensed under the terms of the Apache License, Version +// 2.0, and the BSD License. See the LICENSE file in the root of this repository +// for complete details. + +use crate::buf::CffiBuf; +use crate::error::{CryptographyError, CryptographyResult}; +use crate::exceptions; + +#[pyo3::prelude::pyclass( + frozen, + module = "cryptography.hazmat.bindings._rust.openssl.aead", + name = "AESSIV" +)] +struct AesSiv { + key: pyo3::Py, + cipher: openssl::cipher::Cipher, +} + +#[pyo3::prelude::pymethods] +impl AesSiv { + #[new] + fn new(py: pyo3::Python<'_>, key: pyo3::Py) -> CryptographyResult { + let key_buf = key.extract::>(py)?; + let cipher_name = match key_buf.as_bytes().len() { + 32 => "aes-128-siv", + 48 => "aes-192-siv", + 64 => "aes-256-siv", + _ => { + return Err(CryptographyError::from( + pyo3::exceptions::PyValueError::new_err( + "AESSIV key must be 256, 384, or 512 bits.", + ), + )) + } + }; + + #[cfg(not(CRYPTOGRAPHY_OPENSSL_300_OR_GREATER))] + { + return Err(CryptographyError::from( + exceptions::UnsupportedAlgorithm::new_err(( + "AES-SIV is not supported by this version of OpenSSL", + exceptions::Reasons::UNSUPPORTED_CIPHER, + )), + )); + } + #[cfg(CRYPTOGRAPHY_OPENSSL_300_OR_GREATER)] + { + if cryptography_openssl::fips::is_enabled() { + return Err(CryptographyError::from( + exceptions::UnsupportedAlgorithm::new_err(( + "AES-SIV is not supported by this version of OpenSSL", + exceptions::Reasons::UNSUPPORTED_CIPHER, + )), + )); + } + + let cipher = openssl::cipher::Cipher::fetch(None, cipher_name, None)?; + Ok(AesSiv { key, cipher }) + } + } + + #[staticmethod] + fn generate_key(py: pyo3::Python<'_>, bit_length: usize) -> CryptographyResult<&pyo3::PyAny> { + if bit_length != 256 && bit_length != 384 && bit_length != 512 { + return Err(CryptographyError::from( + pyo3::exceptions::PyValueError::new_err("bit_length must be 256, 384, or 512"), + )); + } + + Ok(py + .import(pyo3::intern!(py, "os"))? + .call_method1(pyo3::intern!(py, "urandom"), (bit_length / 8,))?) + } + + fn encrypt<'p>( + &self, + py: pyo3::Python<'p>, + data: CffiBuf<'_>, + associated_data: Option<&pyo3::types::PyList>, + ) -> CryptographyResult<&'p pyo3::types::PyBytes> { + let key_buf = self.key.extract::>(py)?; + let data_bytes = data.as_bytes(); + + if data_bytes.is_empty() { + return Err(CryptographyError::from( + pyo3::exceptions::PyValueError::new_err("data must not be zero length"), + )); + } else if data_bytes.len() > (i32::MAX as usize) { + // This is OverflowError to match what cffi would raise + return Err(CryptographyError::from( + pyo3::exceptions::PyOverflowError::new_err( + "Data or associated data too long. Max 2**31 - 1 bytes", + ), + )); + } + + let mut ctx = openssl::cipher_ctx::CipherCtx::new()?; + ctx.encrypt_init(Some(&self.cipher), Some(key_buf.as_bytes()), None)?; + + if let Some(ads) = associated_data { + for ad in ads.iter() { + let ad = ad.extract::>()?; + if ad.as_bytes().len() > (i32::MAX as usize) { + // This is OverflowError to match what cffi would raise + return Err(CryptographyError::from( + pyo3::exceptions::PyOverflowError::new_err( + "Data or associated data too long. Max 2**31 - 1 bytes", + ), + )); + } + + ctx.cipher_update(ad.as_bytes(), None)?; + } + } + + Ok(pyo3::types::PyBytes::new_with( + py, + data_bytes.len() + 16, + |b| { + // RFC 5297 defines the output as IV || C, where the tag we + // generate is the "IV" and C is the ciphertext. This is the + // opposite of our other AEADs, which are Ciphertext || Tag. + let (tag, ciphertext) = b.split_at_mut(16); + + let n = ctx + .cipher_update(data_bytes, Some(ciphertext)) + .map_err(CryptographyError::from)?; + assert_eq!(n, ciphertext.len()); + + let mut final_block = [0]; + let n = ctx + .cipher_final(&mut final_block) + .map_err(CryptographyError::from)?; + assert_eq!(n, 0); + + ctx.tag(tag).map_err(CryptographyError::from)?; + + Ok(()) + }, + )?) + } + + fn decrypt<'p>( + &self, + py: pyo3::Python<'p>, + data: CffiBuf<'_>, + associated_data: Option<&pyo3::types::PyList>, + ) -> CryptographyResult<&'p pyo3::types::PyBytes> { + let key_buf = self.key.extract::>(py)?; + let data_bytes = data.as_bytes(); + + if data_bytes.is_empty() { + return Err(CryptographyError::from( + pyo3::exceptions::PyValueError::new_err("data must not be zero length"), + )); + } + + let mut ctx = openssl::cipher_ctx::CipherCtx::new()?; + ctx.decrypt_init(Some(&self.cipher), Some(key_buf.as_bytes()), None)?; + + if data_bytes.len() < 16 { + return Err(CryptographyError::from(exceptions::InvalidTag::new_err(()))); + } + // RFC 5297 defines the output as IV || C, where the tag we generate + // is the "IV" and C is the ciphertext. This is the opposite of our + // other AEADs, which are Ciphertext || Tag. + let (tag, ciphertext) = data_bytes.split_at(16); + ctx.set_tag(tag)?; + + if let Some(ads) = associated_data { + for ad in ads.iter() { + let ad = ad.extract::>()?; + if ad.as_bytes().len() > (i32::MAX as usize) { + // This is OverflowError to match what cffi would raise + return Err(CryptographyError::from( + pyo3::exceptions::PyOverflowError::new_err( + "Data or associated data too long. Max 2**31 - 1 bytes", + ), + )); + } + + ctx.cipher_update(ad.as_bytes(), None)?; + } + } + + Ok(pyo3::types::PyBytes::new_with(py, ciphertext.len(), |b| { + // AES SIV can error here if the data is invalid on decrypt + let n = ctx + .cipher_update(ciphertext, Some(b)) + .map_err(|_| exceptions::InvalidTag::new_err(()))?; + assert_eq!(n, b.len()); + + let mut final_block = [0]; + let n = ctx + .cipher_final(&mut final_block) + .map_err(|_| exceptions::InvalidTag::new_err(()))?; + assert_eq!(n, 0); + + Ok(()) + })?) + } +} + +pub(crate) fn create_module(py: pyo3::Python<'_>) -> pyo3::PyResult<&pyo3::prelude::PyModule> { + let m = pyo3::prelude::PyModule::new(py, "aead")?; + + m.add_class::()?; + + Ok(m) +} diff --git a/src/rust/src/backend/mod.rs b/src/rust/src/backend/mod.rs index b032aaac4404..717a09af8ad4 100644 --- a/src/rust/src/backend/mod.rs +++ b/src/rust/src/backend/mod.rs @@ -2,6 +2,7 @@ // 2.0, and the BSD License. See the LICENSE file in the root of this repository // for complete details. +pub(crate) mod aead; pub(crate) mod dh; pub(crate) mod dsa; pub(crate) mod ec; @@ -20,6 +21,7 @@ pub(crate) mod x25519; pub(crate) mod x448; pub(crate) fn add_to_module(module: &pyo3::prelude::PyModule) -> pyo3::PyResult<()> { + module.add_submodule(aead::create_module(module.py())?)?; module.add_submodule(dh::create_module(module.py())?)?; module.add_submodule(dsa::create_module(module.py())?)?; module.add_submodule(ec::create_module(module.py())?)?; diff --git a/src/rust/src/exceptions.rs b/src/rust/src/exceptions.rs index e3feb38d1d8c..c9456513993d 100644 --- a/src/rust/src/exceptions.rs +++ b/src/rust/src/exceptions.rs @@ -26,6 +26,7 @@ pub(crate) enum Reasons { pyo3::import_exception!(cryptography.exceptions, AlreadyFinalized); pyo3::import_exception!(cryptography.exceptions, InternalError); pyo3::import_exception!(cryptography.exceptions, InvalidSignature); +pyo3::import_exception!(cryptography.exceptions, InvalidTag); pyo3::import_exception!(cryptography.exceptions, UnsupportedAlgorithm); pyo3::import_exception!(cryptography.x509, AttributeNotFound); pyo3::import_exception!(cryptography.x509, DuplicateExtension);