diff --git a/src/cryptography/hazmat/backends/openssl/aead.py b/src/cryptography/hazmat/backends/openssl/aead.py index b7fef7a526349..5674cb1932cae 100644 --- a/src/cryptography/hazmat/backends/openssl/aead.py +++ b/src/cryptography/hazmat/backends/openssl/aead.py @@ -14,36 +14,16 @@ AESCCM, AESGCM, AESOCB3, - ChaCha20Poly1305, ) - _AEADTypes = typing.Union[AESCCM, AESGCM, AESOCB3, ChaCha20Poly1305] - - -def _is_evp_aead_supported_cipher( - backend: Backend, cipher: _AEADTypes -) -> bool: - """ - Checks whether the given cipher is supported through - EVP_AEAD rather than the normal OpenSSL EVP_CIPHER API. - """ - from cryptography.hazmat.primitives.ciphers.aead import ChaCha20Poly1305 - - return backend._lib.Cryptography_HAS_EVP_AEAD and isinstance( - cipher, ChaCha20Poly1305 - ) + _AEADTypes = typing.Union[AESCCM, AESGCM, AESOCB3] def _aead_cipher_supported(backend: Backend, cipher: _AEADTypes) -> bool: - if _is_evp_aead_supported_cipher(backend, cipher): - return True - else: - cipher_name = _evp_cipher_cipher_name(cipher) - if backend._fips_enabled and cipher_name not in backend._fips_aead: - return False - return ( - backend._lib.EVP_get_cipherbyname(cipher_name) != backend._ffi.NULL - ) + cipher_name = _evp_cipher_cipher_name(cipher) + if backend._fips_enabled and cipher_name not in backend._fips_aead: + return False + return backend._lib.EVP_get_cipherbyname(cipher_name) != backend._ffi.NULL def _aead_create_ctx( @@ -51,10 +31,7 @@ def _aead_create_ctx( cipher: _AEADTypes, key: bytes, ): - if _is_evp_aead_supported_cipher(backend, cipher): - return _evp_aead_create_ctx(backend, cipher, key) - else: - return _evp_cipher_create_ctx(backend, cipher, key) + return _evp_cipher_create_ctx(backend, cipher, key) def _encrypt( @@ -66,14 +43,9 @@ def _encrypt( tag_length: int, ctx: typing.Any = None, ) -> bytes: - if _is_evp_aead_supported_cipher(backend, cipher): - return _evp_aead_encrypt( - backend, cipher, nonce, data, associated_data, tag_length, ctx - ) - else: - return _evp_cipher_encrypt( - backend, cipher, nonce, data, associated_data, tag_length, ctx - ) + return _evp_cipher_encrypt( + backend, cipher, nonce, data, associated_data, tag_length, ctx + ) def _decrypt( @@ -85,132 +57,10 @@ def _decrypt( tag_length: int, ctx: typing.Any = None, ) -> bytes: - if _is_evp_aead_supported_cipher(backend, cipher): - return _evp_aead_decrypt( - backend, cipher, nonce, data, associated_data, tag_length, ctx - ) - else: - return _evp_cipher_decrypt( - backend, cipher, nonce, data, associated_data, tag_length, ctx - ) - - -def _evp_aead_create_ctx( - backend: Backend, - cipher: _AEADTypes, - key: bytes, - tag_len: typing.Optional[int] = None, -): - aead_cipher = _evp_aead_get_cipher(backend, cipher) - assert aead_cipher is not None - key_ptr = backend._ffi.from_buffer(key) - tag_len = ( - backend._lib.EVP_AEAD_DEFAULT_TAG_LENGTH - if tag_len is None - else tag_len - ) - ctx = backend._lib.Cryptography_EVP_AEAD_CTX_new( - aead_cipher, key_ptr, len(key), tag_len - ) - backend.openssl_assert(ctx != backend._ffi.NULL) - ctx = backend._ffi.gc(ctx, backend._lib.EVP_AEAD_CTX_free) - return ctx - - -def _evp_aead_get_cipher(backend: Backend, cipher: _AEADTypes): - from cryptography.hazmat.primitives.ciphers.aead import ( - ChaCha20Poly1305, + return _evp_cipher_decrypt( + backend, cipher, nonce, data, associated_data, tag_length, ctx ) - # Currently only ChaCha20-Poly1305 is supported using this API - assert isinstance(cipher, ChaCha20Poly1305) - return backend._lib.EVP_aead_chacha20_poly1305() - - -def _evp_aead_encrypt( - backend: Backend, - cipher: _AEADTypes, - nonce: bytes, - data: bytes, - associated_data: typing.List[bytes], - tag_length: int, - ctx: typing.Any, -) -> bytes: - assert ctx is not None - - aead_cipher = _evp_aead_get_cipher(backend, cipher) - assert aead_cipher is not None - - out_len = backend._ffi.new("size_t *") - # max_out_len should be in_len plus the result of - # EVP_AEAD_max_overhead. - max_out_len = len(data) + backend._lib.EVP_AEAD_max_overhead(aead_cipher) - out_buf = backend._ffi.new("uint8_t[]", max_out_len) - data_ptr = backend._ffi.from_buffer(data) - nonce_ptr = backend._ffi.from_buffer(nonce) - aad = b"".join(associated_data) - aad_ptr = backend._ffi.from_buffer(aad) - - res = backend._lib.EVP_AEAD_CTX_seal( - ctx, - out_buf, - out_len, - max_out_len, - nonce_ptr, - len(nonce), - data_ptr, - len(data), - aad_ptr, - len(aad), - ) - backend.openssl_assert(res == 1) - encrypted_data = backend._ffi.buffer(out_buf, out_len[0])[:] - return encrypted_data - - -def _evp_aead_decrypt( - backend: Backend, - cipher: _AEADTypes, - nonce: bytes, - data: bytes, - associated_data: typing.List[bytes], - tag_length: int, - ctx: typing.Any, -) -> bytes: - if len(data) < tag_length: - raise InvalidTag - - assert ctx is not None - - out_len = backend._ffi.new("size_t *") - # max_out_len should at least in_len - max_out_len = len(data) - out_buf = backend._ffi.new("uint8_t[]", max_out_len) - data_ptr = backend._ffi.from_buffer(data) - nonce_ptr = backend._ffi.from_buffer(nonce) - aad = b"".join(associated_data) - aad_ptr = backend._ffi.from_buffer(aad) - - res = backend._lib.EVP_AEAD_CTX_open( - ctx, - out_buf, - out_len, - max_out_len, - nonce_ptr, - len(nonce), - data_ptr, - len(data), - aad_ptr, - len(aad), - ) - - if res == 0: - backend._consume_errors() - raise InvalidTag - - decrypted_data = backend._ffi.buffer(out_buf, out_len[0])[:] - return decrypted_data - _ENCRYPT = 1 _DECRYPT = 0 @@ -221,12 +71,9 @@ def _evp_cipher_cipher_name(cipher: _AEADTypes) -> bytes: AESCCM, AESGCM, AESOCB3, - ChaCha20Poly1305, ) - if isinstance(cipher, ChaCha20Poly1305): - return b"chacha20-poly1305" - elif isinstance(cipher, AESCCM): + if isinstance(cipher, AESCCM): return f"aes-{len(cipher._key) * 8}-ccm".encode("ascii") elif isinstance(cipher, AESOCB3): return f"aes-{len(cipher._key) * 8}-ocb".encode("ascii") diff --git a/src/cryptography/hazmat/bindings/_rust/openssl/aead.pyi b/src/cryptography/hazmat/bindings/_rust/openssl/aead.pyi index a3f722cde86a5..5594b77cd203f 100644 --- a/src/cryptography/hazmat/bindings/_rust/openssl/aead.pyi +++ b/src/cryptography/hazmat/bindings/_rust/openssl/aead.pyi @@ -4,6 +4,23 @@ import typing +class ChaCha20Poly1305: + def __init__(self, key: bytes) -> None: ... + @staticmethod + def generate_key() -> bytes: ... + def encrypt( + self, + nonce: bytes, + data: bytes, + associated_data: typing.Optional[bytes], + ) -> bytes: ... + def decrypt( + self, + nonce: bytes, + data: bytes, + associated_data: typing.Optional[bytes], + ) -> bytes: ... + class AESSIV: def __init__(self, key: bytes) -> None: ... @staticmethod diff --git a/src/cryptography/hazmat/primitives/ciphers/aead.py b/src/cryptography/hazmat/primitives/ciphers/aead.py index 944060c0b3dd4..e68e13601db32 100644 --- a/src/cryptography/hazmat/primitives/ciphers/aead.py +++ b/src/cryptography/hazmat/primitives/ciphers/aead.py @@ -10,7 +10,6 @@ from cryptography import exceptions, utils from cryptography.hazmat.backends.openssl import aead from cryptography.hazmat.backends.openssl.backend import backend -from cryptography.hazmat.bindings._rust import FixedPool from cryptography.hazmat.bindings._rust import openssl as rust_openssl __all__ = [ @@ -21,82 +20,10 @@ "AESSIV", ] +ChaCha20Poly1305 = rust_openssl.aead.ChaCha20Poly1305 AESSIV = rust_openssl.aead.AESSIV -class ChaCha20Poly1305: - _MAX_SIZE = 2**31 - 1 - - def __init__(self, key: bytes): - if not backend.aead_cipher_supported(self): - raise exceptions.UnsupportedAlgorithm( - "ChaCha20Poly1305 is not supported by this version of OpenSSL", - exceptions._Reasons.UNSUPPORTED_CIPHER, - ) - utils._check_byteslike("key", key) - - if len(key) != 32: - raise ValueError("ChaCha20Poly1305 key must be 32 bytes.") - - self._key = key - self._pool = FixedPool(self._create_fn) - - @classmethod - def generate_key(cls) -> bytes: - return os.urandom(32) - - def _create_fn(self): - return aead._aead_create_ctx(backend, self, self._key) - - def encrypt( - self, - nonce: bytes, - data: bytes, - associated_data: typing.Optional[bytes], - ) -> bytes: - if associated_data is None: - associated_data = b"" - - if len(data) > self._MAX_SIZE or len(associated_data) > self._MAX_SIZE: - # This is OverflowError to match what cffi would raise - raise OverflowError( - "Data or associated data too long. Max 2**31 - 1 bytes" - ) - - self._check_params(nonce, data, associated_data) - with self._pool.acquire() as ctx: - return aead._encrypt( - backend, self, nonce, data, [associated_data], 16, ctx - ) - - def decrypt( - self, - nonce: bytes, - data: bytes, - associated_data: typing.Optional[bytes], - ) -> bytes: - if associated_data is None: - associated_data = b"" - - self._check_params(nonce, data, associated_data) - with self._pool.acquire() as ctx: - return aead._decrypt( - backend, self, nonce, data, [associated_data], 16, ctx - ) - - def _check_params( - self, - nonce: bytes, - data: bytes, - associated_data: bytes, - ) -> None: - utils._check_byteslike("nonce", nonce) - utils._check_byteslike("data", data) - utils._check_byteslike("associated_data", associated_data) - if len(nonce) != 12: - raise ValueError("Nonce must be 12 bytes") - - class AESCCM: _MAX_SIZE = 2**31 - 1 diff --git a/src/rust/src/backend/aead.rs b/src/rust/src/backend/aead.rs index 2a6641afa371d..d6a63ec52b880 100644 --- a/src/rust/src/backend/aead.rs +++ b/src/rust/src/backend/aead.rs @@ -32,9 +32,11 @@ fn encrypt_value<'p>( |b| { let ciphertext; let tag; - // TODO: remove once we have a second AEAD implemented here. - assert!(tag_first); - (tag, ciphertext) = b.split_at_mut(tag_len); + if tag_first { + (tag, ciphertext) = b.split_at_mut(tag_len); + } else { + (ciphertext, tag) = b.split_at_mut(plaintext.len()); + }; let n = ctx .cipher_update(plaintext, Some(ciphertext)) @@ -76,6 +78,117 @@ fn decrypt_value<'p>( })?) } +#[pyo3::prelude::pyclass(frozen, module = "cryptography.hazmat.bindings._rust.openssl.aead")] +struct ChaCha20Poly1305 { + key: pyo3::Py, + cipher: &'static openssl::cipher::CipherRef, +} + +#[pyo3::prelude::pymethods] +impl ChaCha20Poly1305 { + #[new] + fn new(py: pyo3::Python<'_>, key: pyo3::Py) -> CryptographyResult { + let key_buf = key.extract::>(py)?; + if key_buf.as_bytes().len() != 32 { + return Err(CryptographyError::from( + pyo3::exceptions::PyValueError::new_err("ChaCha20Poly1305 key must be 32 bytes."), + )); + } + + // TODO: Handle if ChaChaPoly1305 isn't supported by this OpenSSL + // TODO: FixedPool? + + Ok(ChaCha20Poly1305 { + key, + cipher: openssl::cipher::Cipher::chacha20_poly1305(), + }) + } + + #[staticmethod] + fn generate_key(py: pyo3::Python<'_>) -> CryptographyResult<&pyo3::PyAny> { + Ok(py + .import(pyo3::intern!(py, "os"))? + .call_method1(pyo3::intern!(py, "urandom"), (32,))?) + } + + fn encrypt<'p>( + &self, + py: pyo3::Python<'p>, + nonce: CffiBuf<'_>, + data: CffiBuf<'_>, + associated_data: Option>, + ) -> CryptographyResult<&'p pyo3::types::PyBytes> { + let key_buf = self.key.extract::>(py)?; + let data_bytes = data.as_bytes(); + let nonce_bytes = nonce.as_bytes(); + let ad = associated_data + .as_ref() + .map(|ad| ad.as_bytes()) + .unwrap_or(b""); + + check_length(data_bytes)?; + check_length(ad)?; + + if nonce_bytes.len() != 12 { + return Err(CryptographyError::from( + pyo3::exceptions::PyValueError::new_err("Nonce must be 12 bytes"), + )); + } + + let mut ctx = openssl::cipher_ctx::CipherCtx::new()?; + ctx.encrypt_init( + Some(&self.cipher), + Some(key_buf.as_bytes()), + Some(nonce_bytes), + )?; + + ctx.cipher_update(ad, None)?; + + encrypt_value(py, ctx, data_bytes, 16, false) + } + + fn decrypt<'p>( + &self, + py: pyo3::Python<'p>, + nonce: CffiBuf<'_>, + data: CffiBuf<'_>, + associated_data: Option>, + ) -> CryptographyResult<&'p pyo3::types::PyBytes> { + let key_buf = self.key.extract::>(py)?; + let data_bytes = data.as_bytes(); + let nonce_bytes = nonce.as_bytes(); + let ad = associated_data + .as_ref() + .map(|ad| ad.as_bytes()) + .unwrap_or(b""); + + check_length(data_bytes)?; + check_length(ad)?; + if nonce_bytes.len() != 12 { + return Err(CryptographyError::from( + pyo3::exceptions::PyValueError::new_err("Nonce must be 12 bytes"), + )); + } + + let mut ctx = openssl::cipher_ctx::CipherCtx::new()?; + ctx.decrypt_init( + Some(&self.cipher), + Some(key_buf.as_bytes()), + Some(nonce_bytes), + )?; + + if data_bytes.len() < 16 { + return Err(CryptographyError::from(exceptions::InvalidTag::new_err(()))); + } + let (ciphertext, tag) = data_bytes.split_at(data_bytes.len() - 16); + ctx.set_tag(tag)?; + + ctx.cipher_update(ad, None)?; + + decrypt_value(py, ctx, ciphertext) + } +} + #[pyo3::prelude::pyclass( frozen, module = "cryptography.hazmat.bindings._rust.openssl.aead", @@ -215,6 +328,7 @@ impl AesSiv { pub(crate) fn create_module(py: pyo3::Python<'_>) -> pyo3::PyResult<&pyo3::prelude::PyModule> { let m = pyo3::prelude::PyModule::new(py, "aead")?; + m.add_class::()?; m.add_class::()?; Ok(m)