diff --git a/src/cryptography/hazmat/backends/openssl/aead.py b/src/cryptography/hazmat/backends/openssl/aead.py index f1d990106474..dd2485481203 100644 --- a/src/cryptography/hazmat/backends/openssl/aead.py +++ b/src/cryptography/hazmat/backends/openssl/aead.py @@ -12,10 +12,9 @@ from cryptography.hazmat.backends.openssl.backend import Backend from cryptography.hazmat.primitives.ciphers.aead import ( AESCCM, - AESGCM, ) - _AEADTypes = typing.Union[AESCCM, AESGCM] + _AEADTypes = typing.Union[AESCCM] def _aead_cipher_supported(backend: Backend, cipher: _AEADTypes) -> bool: @@ -55,16 +54,10 @@ def _decrypt( def _evp_cipher_cipher_name(cipher: _AEADTypes) -> bytes: - from cryptography.hazmat.primitives.ciphers.aead import ( - AESCCM, - AESGCM, - ) + from cryptography.hazmat.primitives.ciphers.aead import AESCCM - if isinstance(cipher, AESCCM): - return f"aes-{len(cipher._key) * 8}-ccm".encode("ascii") - else: - assert isinstance(cipher, AESGCM) - return f"aes-{len(cipher._key) * 8}-gcm".encode("ascii") + assert isinstance(cipher, AESCCM) + return f"aes-{len(cipher._key) * 8}-ccm".encode("ascii") def _evp_cipher(cipher_name: bytes, backend: Backend): @@ -105,7 +98,8 @@ def _evp_cipher_aead_setup( if operation == _DECRYPT: assert tag is not None _evp_cipher_set_tag(backend, ctx, tag) - elif cipher_name.endswith(b"-ccm"): + else: + assert cipher_name.endswith(b"-ccm") res = backend._lib.EVP_CIPHER_CTX_ctrl( ctx, backend._lib.EVP_CTRL_AEAD_SET_TAG, @@ -188,8 +182,8 @@ def _evp_cipher_encrypt( # CCM requires us to pass the length of the data before processing # anything. # However calling this with any other AEAD results in an error - if isinstance(cipher, AESCCM): - _evp_cipher_set_length(backend, ctx, len(data)) + assert isinstance(cipher, AESCCM) + _evp_cipher_set_length(backend, ctx, len(data)) for ad in associated_data: _evp_cipher_process_aad(backend, ctx, ad) @@ -241,32 +235,21 @@ def _evp_cipher_decrypt( # CCM requires us to pass the length of the data before processing # anything. # However calling this with any other AEAD results in an error - if isinstance(cipher, AESCCM): - _evp_cipher_set_length(backend, ctx, len(data)) + assert isinstance(cipher, AESCCM) + _evp_cipher_set_length(backend, ctx, len(data)) for ad in associated_data: _evp_cipher_process_aad(backend, ctx, ad) # CCM has a different error path if the tag doesn't match. Errors are # raised in Update and Final is irrelevant. - if isinstance(cipher, AESCCM): - outlen = backend._ffi.new("int *") - buf = backend._ffi.new("unsigned char[]", len(data)) - d_ptr = backend._ffi.from_buffer(data) - res = backend._lib.EVP_CipherUpdate(ctx, buf, outlen, d_ptr, len(data)) - if res != 1: - backend._consume_errors() - raise InvalidTag - - processed_data = backend._ffi.buffer(buf, outlen[0])[:] - else: - processed_data = _evp_cipher_process_data(backend, ctx, data) - outlen = backend._ffi.new("int *") - # OCB can return up to 15 bytes (16 byte block - 1) in finalization - buf = backend._ffi.new("unsigned char[]", 16) - res = backend._lib.EVP_CipherFinal_ex(ctx, buf, outlen) - processed_data += backend._ffi.buffer(buf, outlen[0])[:] - if res == 0: - backend._consume_errors() - raise InvalidTag + outlen = backend._ffi.new("int *") + buf = backend._ffi.new("unsigned char[]", len(data)) + d_ptr = backend._ffi.from_buffer(data) + res = backend._lib.EVP_CipherUpdate(ctx, buf, outlen, d_ptr, len(data)) + if res != 1: + backend._consume_errors() + raise InvalidTag + + processed_data = backend._ffi.buffer(buf, outlen[0])[:] return processed_data diff --git a/src/cryptography/hazmat/bindings/_rust/openssl/aead.pyi b/src/cryptography/hazmat/bindings/_rust/openssl/aead.pyi index 81e801e30bb5..e274073f201e 100644 --- a/src/cryptography/hazmat/bindings/_rust/openssl/aead.pyi +++ b/src/cryptography/hazmat/bindings/_rust/openssl/aead.pyi @@ -2,6 +2,23 @@ # 2.0, and the BSD License. See the LICENSE file in the root of this repository # for complete details. +class AESGCM: + def __init__(self, key: bytes) -> None: ... + @staticmethod + def generate_key(key_size: int) -> bytes: ... + def encrypt( + self, + nonce: bytes, + data: bytes, + associated_data: bytes | None, + ) -> bytes: ... + def decrypt( + self, + nonce: bytes, + data: bytes, + associated_data: bytes | None, + ) -> bytes: ... + class ChaCha20Poly1305: def __init__(self, key: bytes) -> None: ... @staticmethod diff --git a/src/cryptography/hazmat/primitives/ciphers/aead.py b/src/cryptography/hazmat/primitives/ciphers/aead.py index 40f1b9b74459..e96b735b18f9 100644 --- a/src/cryptography/hazmat/primitives/ciphers/aead.py +++ b/src/cryptography/hazmat/primitives/ciphers/aead.py @@ -20,6 +20,7 @@ "AESSIV", ] +AESGCM = rust_openssl.aead.AESGCM ChaCha20Poly1305 = rust_openssl.aead.ChaCha20Poly1305 AESSIV = rust_openssl.aead.AESSIV AESOCB3 = rust_openssl.aead.AESOCB3 @@ -109,66 +110,3 @@ def _check_params( utils._check_byteslike("associated_data", associated_data) if not 7 <= len(nonce) <= 13: raise ValueError("Nonce must be between 7 and 13 bytes") - - -class AESGCM: - _MAX_SIZE = 2**31 - 1 - - def __init__(self, key: bytes): - utils._check_byteslike("key", key) - if len(key) not in (16, 24, 32): - raise ValueError("AESGCM key must be 128, 192, or 256 bits.") - - self._key = key - - @classmethod - def generate_key(cls, bit_length: int) -> bytes: - if not isinstance(bit_length, int): - raise TypeError("bit_length must be an integer") - - if bit_length not in (128, 192, 256): - raise ValueError("bit_length must be 128, 192, or 256") - - return os.urandom(bit_length // 8) - - def encrypt( - self, - nonce: bytes, - data: bytes, - associated_data: bytes | None, - ) -> bytes: - if associated_data is None: - associated_data = b"" - - if len(data) > self._MAX_SIZE or len(associated_data) > self._MAX_SIZE: - # This is OverflowError to match what cffi would raise - raise OverflowError( - "Data or associated data too long. Max 2**31 - 1 bytes" - ) - - self._check_params(nonce, data, associated_data) - return aead._encrypt(backend, self, nonce, data, [associated_data], 16) - - def decrypt( - self, - nonce: bytes, - data: bytes, - associated_data: bytes | None, - ) -> bytes: - if associated_data is None: - associated_data = b"" - - self._check_params(nonce, data, associated_data) - return aead._decrypt(backend, self, nonce, data, [associated_data], 16) - - def _check_params( - self, - nonce: bytes, - data: bytes, - associated_data: bytes, - ) -> None: - utils._check_byteslike("nonce", nonce) - utils._check_byteslike("data", data) - utils._check_byteslike("associated_data", associated_data) - if len(nonce) < 8 or len(nonce) > 128: - raise ValueError("Nonce must be between 8 and 128 bytes") diff --git a/src/rust/src/backend/aead.rs b/src/rust/src/backend/aead.rs index 9fd8a91ceeaf..b13a420c7588 100644 --- a/src/rust/src/backend/aead.rs +++ b/src/rust/src/backend/aead.rs @@ -486,9 +486,7 @@ impl ChaCha20Poly1305 { #[staticmethod] fn generate_key(py: pyo3::Python<'_>) -> CryptographyResult<&pyo3::PyAny> { - Ok(py - .import(pyo3::intern!(py, "os"))? - .call_method1(pyo3::intern!(py, "urandom"), (32,))?) + Ok(types::OS_URANDOM.get(py)?.call1((32,))?) } fn encrypt<'p>( @@ -532,6 +530,118 @@ impl ChaCha20Poly1305 { } } +#[pyo3::prelude::pyclass( + frozen, + module = "cryptography.hazmat.bindings._rust.openssl.aead", + name = "AESGCM" +)] +struct AesGcm { + #[cfg(any( + CRYPTOGRAPHY_OPENSSL_320_OR_GREATER, + CRYPTOGRAPHY_IS_LIBRESSL, + CRYPTOGRAPHY_IS_BORINGSSL, + not(CRYPTOGRAPHY_OPENSSL_300_OR_GREATER), + ))] + ctx: EvpCipherAead, + + #[cfg(not(any( + CRYPTOGRAPHY_OPENSSL_320_OR_GREATER, + CRYPTOGRAPHY_IS_LIBRESSL, + CRYPTOGRAPHY_IS_BORINGSSL, + not(CRYPTOGRAPHY_OPENSSL_300_OR_GREATER), + )))] + ctx: LazyEvpCipherAead, +} + +#[pyo3::prelude::pymethods] +impl AesGcm { + #[new] + fn new(py: pyo3::Python<'_>, key: pyo3::Py) -> CryptographyResult { + let key_buf = key.extract::>(py)?; + let cipher = match key_buf.as_bytes().len() { + 16 => openssl::cipher::Cipher::aes_128_gcm(), + 24 => openssl::cipher::Cipher::aes_192_gcm(), + 32 => openssl::cipher::Cipher::aes_256_gcm(), + _ => { + return Err(CryptographyError::from( + pyo3::exceptions::PyValueError::new_err( + "AESGCM key must be 128, 192, or 256 bits.", + ), + )) + } + }; + + cfg_if::cfg_if! { + if #[cfg(any( + CRYPTOGRAPHY_OPENSSL_320_OR_GREATER, + CRYPTOGRAPHY_IS_BORINGSSL, + CRYPTOGRAPHY_IS_LIBRESSL, + not(CRYPTOGRAPHY_OPENSSL_300_OR_GREATER, + )))] { + Ok(AesGcm { + ctx: EvpCipherAead::new(cipher, key_buf.as_bytes(), 16, false)?, + }) + } else { + Ok(AesGcm { + ctx: LazyEvpCipherAead::new(cipher, key, 16, false), + }) + + } + } + } + + #[staticmethod] + fn generate_key(py: pyo3::Python<'_>, bit_length: usize) -> CryptographyResult<&pyo3::PyAny> { + if bit_length != 128 && bit_length != 192 && bit_length != 256 { + return Err(CryptographyError::from( + pyo3::exceptions::PyValueError::new_err("bit_length must be 128, 192, or 256"), + )); + } + + Ok(types::OS_URANDOM.get(py)?.call1((bit_length / 8,))?) + } + + fn encrypt<'p>( + &self, + py: pyo3::Python<'p>, + nonce: CffiBuf<'_>, + data: CffiBuf<'_>, + associated_data: Option>, + ) -> CryptographyResult<&'p pyo3::types::PyBytes> { + let nonce_bytes = nonce.as_bytes(); + let aad = associated_data.map(Aad::Single); + + if nonce_bytes.len() < 8 || nonce_bytes.len() > 128 { + return Err(CryptographyError::from( + pyo3::exceptions::PyValueError::new_err("Nonce must be between 8 and 128 bytes"), + )); + } + + self.ctx + .encrypt(py, data.as_bytes(), aad, Some(nonce_bytes)) + } + + fn decrypt<'p>( + &self, + py: pyo3::Python<'p>, + nonce: CffiBuf<'_>, + data: CffiBuf<'_>, + associated_data: Option>, + ) -> CryptographyResult<&'p pyo3::types::PyBytes> { + let nonce_bytes = nonce.as_bytes(); + let aad = associated_data.map(Aad::Single); + + if nonce_bytes.len() < 8 || nonce_bytes.len() > 128 { + return Err(CryptographyError::from( + pyo3::exceptions::PyValueError::new_err("Nonce must be between 8 and 128 bytes"), + )); + } + + self.ctx + .decrypt(py, data.as_bytes(), aad, Some(nonce_bytes)) + } +} + #[pyo3::prelude::pyclass( frozen, module = "cryptography.hazmat.bindings._rust.openssl.aead", @@ -845,6 +955,7 @@ impl AesGcmSiv { pub(crate) fn create_module(py: pyo3::Python<'_>) -> pyo3::PyResult<&pyo3::prelude::PyModule> { let m = pyo3::prelude::PyModule::new(py, "aead")?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; diff --git a/tests/hazmat/primitives/test_aead.py b/tests/hazmat/primitives/test_aead.py index a4624cefc555..5228edbbd2d3 100644 --- a/tests/hazmat/primitives/test_aead.py +++ b/tests/hazmat/primitives/test_aead.py @@ -451,6 +451,8 @@ def test_invalid_nonce_length(self, length, backend): aesgcm = AESGCM(key) with pytest.raises(ValueError): aesgcm.encrypt(b"\x00" * length, b"hi", None) + with pytest.raises(ValueError): + aesgcm.decrypt(b"\x00" * length, b"hi", None) def test_bad_key(self, backend): with pytest.raises(TypeError):