Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert AESGCM AEAD to Rust #9181

Merged
merged 1 commit into from
Jan 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 19 additions & 36 deletions src/cryptography/hazmat/backends/openssl/aead.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,9 @@
from cryptography.hazmat.backends.openssl.backend import Backend
from cryptography.hazmat.primitives.ciphers.aead import (
AESCCM,
AESGCM,
)

_AEADTypes = typing.Union[AESCCM, AESGCM]
_AEADTypes = typing.Union[AESCCM]


def _aead_cipher_supported(backend: Backend, cipher: _AEADTypes) -> bool:
Expand Down Expand Up @@ -55,16 +54,10 @@ def _decrypt(


def _evp_cipher_cipher_name(cipher: _AEADTypes) -> bytes:
from cryptography.hazmat.primitives.ciphers.aead import (
AESCCM,
AESGCM,
)
from cryptography.hazmat.primitives.ciphers.aead import AESCCM

if isinstance(cipher, AESCCM):
return f"aes-{len(cipher._key) * 8}-ccm".encode("ascii")
else:
assert isinstance(cipher, AESGCM)
return f"aes-{len(cipher._key) * 8}-gcm".encode("ascii")
assert isinstance(cipher, AESCCM)
return f"aes-{len(cipher._key) * 8}-ccm".encode("ascii")


def _evp_cipher(cipher_name: bytes, backend: Backend):
Expand Down Expand Up @@ -105,7 +98,8 @@ def _evp_cipher_aead_setup(
if operation == _DECRYPT:
assert tag is not None
_evp_cipher_set_tag(backend, ctx, tag)
elif cipher_name.endswith(b"-ccm"):
else:
assert cipher_name.endswith(b"-ccm")
res = backend._lib.EVP_CIPHER_CTX_ctrl(
ctx,
backend._lib.EVP_CTRL_AEAD_SET_TAG,
Expand Down Expand Up @@ -188,8 +182,8 @@ def _evp_cipher_encrypt(
# CCM requires us to pass the length of the data before processing
# anything.
# However calling this with any other AEAD results in an error
if isinstance(cipher, AESCCM):
_evp_cipher_set_length(backend, ctx, len(data))
assert isinstance(cipher, AESCCM)
_evp_cipher_set_length(backend, ctx, len(data))

for ad in associated_data:
_evp_cipher_process_aad(backend, ctx, ad)
Expand Down Expand Up @@ -241,32 +235,21 @@ def _evp_cipher_decrypt(
# CCM requires us to pass the length of the data before processing
# anything.
# However calling this with any other AEAD results in an error
if isinstance(cipher, AESCCM):
_evp_cipher_set_length(backend, ctx, len(data))
assert isinstance(cipher, AESCCM)
_evp_cipher_set_length(backend, ctx, len(data))

for ad in associated_data:
_evp_cipher_process_aad(backend, ctx, ad)
# CCM has a different error path if the tag doesn't match. Errors are
# raised in Update and Final is irrelevant.
if isinstance(cipher, AESCCM):
outlen = backend._ffi.new("int *")
buf = backend._ffi.new("unsigned char[]", len(data))
d_ptr = backend._ffi.from_buffer(data)
res = backend._lib.EVP_CipherUpdate(ctx, buf, outlen, d_ptr, len(data))
if res != 1:
backend._consume_errors()
raise InvalidTag

processed_data = backend._ffi.buffer(buf, outlen[0])[:]
else:
processed_data = _evp_cipher_process_data(backend, ctx, data)
outlen = backend._ffi.new("int *")
# OCB can return up to 15 bytes (16 byte block - 1) in finalization
buf = backend._ffi.new("unsigned char[]", 16)
res = backend._lib.EVP_CipherFinal_ex(ctx, buf, outlen)
processed_data += backend._ffi.buffer(buf, outlen[0])[:]
if res == 0:
backend._consume_errors()
raise InvalidTag
outlen = backend._ffi.new("int *")
buf = backend._ffi.new("unsigned char[]", len(data))
d_ptr = backend._ffi.from_buffer(data)
res = backend._lib.EVP_CipherUpdate(ctx, buf, outlen, d_ptr, len(data))
if res != 1:
backend._consume_errors()
raise InvalidTag

processed_data = backend._ffi.buffer(buf, outlen[0])[:]

return processed_data
17 changes: 17 additions & 0 deletions src/cryptography/hazmat/bindings/_rust/openssl/aead.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,23 @@
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
# for complete details.

class AESGCM:
def __init__(self, key: bytes) -> None: ...
@staticmethod
def generate_key(key_size: int) -> bytes: ...
def encrypt(
self,
nonce: bytes,
data: bytes,
associated_data: bytes | None,
) -> bytes: ...
def decrypt(
self,
nonce: bytes,
data: bytes,
associated_data: bytes | None,
) -> bytes: ...

class ChaCha20Poly1305:
def __init__(self, key: bytes) -> None: ...
@staticmethod
Expand Down
64 changes: 1 addition & 63 deletions src/cryptography/hazmat/primitives/ciphers/aead.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
"AESSIV",
]

AESGCM = rust_openssl.aead.AESGCM
ChaCha20Poly1305 = rust_openssl.aead.ChaCha20Poly1305
AESSIV = rust_openssl.aead.AESSIV
AESOCB3 = rust_openssl.aead.AESOCB3
Expand Down Expand Up @@ -109,66 +110,3 @@ def _check_params(
utils._check_byteslike("associated_data", associated_data)
if not 7 <= len(nonce) <= 13:
raise ValueError("Nonce must be between 7 and 13 bytes")


class AESGCM:
_MAX_SIZE = 2**31 - 1

def __init__(self, key: bytes):
utils._check_byteslike("key", key)
if len(key) not in (16, 24, 32):
raise ValueError("AESGCM key must be 128, 192, or 256 bits.")

self._key = key

@classmethod
def generate_key(cls, bit_length: int) -> bytes:
if not isinstance(bit_length, int):
raise TypeError("bit_length must be an integer")

if bit_length not in (128, 192, 256):
raise ValueError("bit_length must be 128, 192, or 256")

return os.urandom(bit_length // 8)

def encrypt(
self,
nonce: bytes,
data: bytes,
associated_data: bytes | None,
) -> bytes:
if associated_data is None:
associated_data = b""

if len(data) > self._MAX_SIZE or len(associated_data) > self._MAX_SIZE:
# This is OverflowError to match what cffi would raise
raise OverflowError(
"Data or associated data too long. Max 2**31 - 1 bytes"
)

self._check_params(nonce, data, associated_data)
return aead._encrypt(backend, self, nonce, data, [associated_data], 16)

def decrypt(
self,
nonce: bytes,
data: bytes,
associated_data: bytes | None,
) -> bytes:
if associated_data is None:
associated_data = b""

self._check_params(nonce, data, associated_data)
return aead._decrypt(backend, self, nonce, data, [associated_data], 16)

def _check_params(
self,
nonce: bytes,
data: bytes,
associated_data: bytes,
) -> None:
utils._check_byteslike("nonce", nonce)
utils._check_byteslike("data", data)
utils._check_byteslike("associated_data", associated_data)
if len(nonce) < 8 or len(nonce) > 128:
raise ValueError("Nonce must be between 8 and 128 bytes")
117 changes: 114 additions & 3 deletions src/rust/src/backend/aead.rs
Original file line number Diff line number Diff line change
Expand Up @@ -486,9 +486,7 @@ impl ChaCha20Poly1305 {

#[staticmethod]
fn generate_key(py: pyo3::Python<'_>) -> CryptographyResult<&pyo3::PyAny> {
Ok(py
.import(pyo3::intern!(py, "os"))?
.call_method1(pyo3::intern!(py, "urandom"), (32,))?)
Ok(types::OS_URANDOM.get(py)?.call1((32,))?)
}

fn encrypt<'p>(
Expand Down Expand Up @@ -532,6 +530,118 @@ impl ChaCha20Poly1305 {
}
}

#[pyo3::prelude::pyclass(
frozen,
module = "cryptography.hazmat.bindings._rust.openssl.aead",
name = "AESGCM"
)]
struct AesGcm {
#[cfg(any(
CRYPTOGRAPHY_OPENSSL_320_OR_GREATER,
CRYPTOGRAPHY_IS_LIBRESSL,
CRYPTOGRAPHY_IS_BORINGSSL,
not(CRYPTOGRAPHY_OPENSSL_300_OR_GREATER),
))]
ctx: EvpCipherAead,

#[cfg(not(any(
CRYPTOGRAPHY_OPENSSL_320_OR_GREATER,
CRYPTOGRAPHY_IS_LIBRESSL,
CRYPTOGRAPHY_IS_BORINGSSL,
not(CRYPTOGRAPHY_OPENSSL_300_OR_GREATER),
)))]
ctx: LazyEvpCipherAead,
}

#[pyo3::prelude::pymethods]
impl AesGcm {
#[new]
fn new(py: pyo3::Python<'_>, key: pyo3::Py<pyo3::PyAny>) -> CryptographyResult<AesGcm> {
let key_buf = key.extract::<CffiBuf<'_>>(py)?;
let cipher = match key_buf.as_bytes().len() {
16 => openssl::cipher::Cipher::aes_128_gcm(),
24 => openssl::cipher::Cipher::aes_192_gcm(),
32 => openssl::cipher::Cipher::aes_256_gcm(),
_ => {
return Err(CryptographyError::from(
pyo3::exceptions::PyValueError::new_err(
"AESGCM key must be 128, 192, or 256 bits.",
),
))
}
};

cfg_if::cfg_if! {
if #[cfg(any(
CRYPTOGRAPHY_OPENSSL_320_OR_GREATER,
CRYPTOGRAPHY_IS_BORINGSSL,
CRYPTOGRAPHY_IS_LIBRESSL,
not(CRYPTOGRAPHY_OPENSSL_300_OR_GREATER,
)))] {
Ok(AesGcm {
ctx: EvpCipherAead::new(cipher, key_buf.as_bytes(), 16, false)?,
})
} else {
Ok(AesGcm {
ctx: LazyEvpCipherAead::new(cipher, key, 16, false),
})

}
}
}

#[staticmethod]
fn generate_key(py: pyo3::Python<'_>, bit_length: usize) -> CryptographyResult<&pyo3::PyAny> {
if bit_length != 128 && bit_length != 192 && bit_length != 256 {
return Err(CryptographyError::from(
pyo3::exceptions::PyValueError::new_err("bit_length must be 128, 192, or 256"),
));
}

Ok(types::OS_URANDOM.get(py)?.call1((bit_length / 8,))?)
}

fn encrypt<'p>(
&self,
py: pyo3::Python<'p>,
nonce: CffiBuf<'_>,
data: CffiBuf<'_>,
associated_data: Option<CffiBuf<'_>>,
) -> CryptographyResult<&'p pyo3::types::PyBytes> {
let nonce_bytes = nonce.as_bytes();
let aad = associated_data.map(Aad::Single);

if nonce_bytes.len() < 8 || nonce_bytes.len() > 128 {
return Err(CryptographyError::from(
pyo3::exceptions::PyValueError::new_err("Nonce must be between 8 and 128 bytes"),
));
}

self.ctx
.encrypt(py, data.as_bytes(), aad, Some(nonce_bytes))
}

fn decrypt<'p>(
&self,
py: pyo3::Python<'p>,
nonce: CffiBuf<'_>,
data: CffiBuf<'_>,
associated_data: Option<CffiBuf<'_>>,
) -> CryptographyResult<&'p pyo3::types::PyBytes> {
let nonce_bytes = nonce.as_bytes();
let aad = associated_data.map(Aad::Single);

if nonce_bytes.len() < 8 || nonce_bytes.len() > 128 {
return Err(CryptographyError::from(
pyo3::exceptions::PyValueError::new_err("Nonce must be between 8 and 128 bytes"),
));
}

self.ctx
.decrypt(py, data.as_bytes(), aad, Some(nonce_bytes))
}
}

#[pyo3::prelude::pyclass(
frozen,
module = "cryptography.hazmat.bindings._rust.openssl.aead",
Expand Down Expand Up @@ -845,6 +955,7 @@ impl AesGcmSiv {
pub(crate) fn create_module(py: pyo3::Python<'_>) -> pyo3::PyResult<&pyo3::prelude::PyModule> {
let m = pyo3::prelude::PyModule::new(py, "aead")?;

m.add_class::<AesGcm>()?;
m.add_class::<ChaCha20Poly1305>()?;
m.add_class::<AesSiv>()?;
m.add_class::<AesOcb3>()?;
Expand Down
2 changes: 2 additions & 0 deletions tests/hazmat/primitives/test_aead.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,8 @@ def test_invalid_nonce_length(self, length, backend):
aesgcm = AESGCM(key)
with pytest.raises(ValueError):
aesgcm.encrypt(b"\x00" * length, b"hi", None)
with pytest.raises(ValueError):
aesgcm.decrypt(b"\x00" * length, b"hi", None)

def test_bad_key(self, backend):
with pytest.raises(TypeError):
Expand Down
Loading