Skip to content

Commit

Permalink
Migrate ChaCha20Poly1305 AEAD to Rust
Browse files Browse the repository at this point in the history
  • Loading branch information
alex committed Oct 28, 2023
1 parent 3b39f65 commit 035e0ab
Show file tree
Hide file tree
Showing 6 changed files with 295 additions and 239 deletions.
177 changes: 12 additions & 165 deletions src/cryptography/hazmat/backends/openssl/aead.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,47 +13,24 @@
from cryptography.hazmat.primitives.ciphers.aead import (
AESCCM,
AESGCM,
ChaCha20Poly1305,
)

_AEADTypes = typing.Union[AESCCM, AESGCM, ChaCha20Poly1305]


def _is_evp_aead_supported_cipher(
backend: Backend, cipher: _AEADTypes
) -> bool:
"""
Checks whether the given cipher is supported through
EVP_AEAD rather than the normal OpenSSL EVP_CIPHER API.
"""
from cryptography.hazmat.primitives.ciphers.aead import ChaCha20Poly1305

return backend._lib.Cryptography_HAS_EVP_AEAD and isinstance(
cipher, ChaCha20Poly1305
)
_AEADTypes = typing.Union[AESCCM, AESGCM]


def _aead_cipher_supported(backend: Backend, cipher: _AEADTypes) -> bool:
if _is_evp_aead_supported_cipher(backend, cipher):
return True
else:
cipher_name = _evp_cipher_cipher_name(cipher)
if backend._fips_enabled and cipher_name not in backend._fips_aead:
return False
return (
backend._lib.EVP_get_cipherbyname(cipher_name) != backend._ffi.NULL
)
cipher_name = _evp_cipher_cipher_name(cipher)
if backend._fips_enabled and cipher_name not in backend._fips_aead:
return False
return backend._lib.EVP_get_cipherbyname(cipher_name) != backend._ffi.NULL


def _aead_create_ctx(
backend: Backend,
cipher: _AEADTypes,
key: bytes,
):
if _is_evp_aead_supported_cipher(backend, cipher):
return _evp_aead_create_ctx(backend, cipher, key)
else:
return _evp_cipher_create_ctx(backend, cipher, key)
return _evp_cipher_create_ctx(backend, cipher, key)


def _encrypt(
Expand All @@ -65,14 +42,9 @@ def _encrypt(
tag_length: int,
ctx: typing.Any = None,
) -> bytes:
if _is_evp_aead_supported_cipher(backend, cipher):
return _evp_aead_encrypt(
backend, cipher, nonce, data, associated_data, tag_length, ctx
)
else:
return _evp_cipher_encrypt(
backend, cipher, nonce, data, associated_data, tag_length, ctx
)
return _evp_cipher_encrypt(
backend, cipher, nonce, data, associated_data, tag_length, ctx
)


def _decrypt(
Expand All @@ -84,132 +56,10 @@ def _decrypt(
tag_length: int,
ctx: typing.Any = None,
) -> bytes:
if _is_evp_aead_supported_cipher(backend, cipher):
return _evp_aead_decrypt(
backend, cipher, nonce, data, associated_data, tag_length, ctx
)
else:
return _evp_cipher_decrypt(
backend, cipher, nonce, data, associated_data, tag_length, ctx
)


def _evp_aead_create_ctx(
backend: Backend,
cipher: _AEADTypes,
key: bytes,
tag_len: int | None = None,
):
aead_cipher = _evp_aead_get_cipher(backend, cipher)
assert aead_cipher is not None
key_ptr = backend._ffi.from_buffer(key)
tag_len = (
backend._lib.EVP_AEAD_DEFAULT_TAG_LENGTH
if tag_len is None
else tag_len
)
ctx = backend._lib.Cryptography_EVP_AEAD_CTX_new(
aead_cipher, key_ptr, len(key), tag_len
)
backend.openssl_assert(ctx != backend._ffi.NULL)
ctx = backend._ffi.gc(ctx, backend._lib.EVP_AEAD_CTX_free)
return ctx


def _evp_aead_get_cipher(backend: Backend, cipher: _AEADTypes):
from cryptography.hazmat.primitives.ciphers.aead import (
ChaCha20Poly1305,
return _evp_cipher_decrypt(
backend, cipher, nonce, data, associated_data, tag_length, ctx
)

# Currently only ChaCha20-Poly1305 is supported using this API
assert isinstance(cipher, ChaCha20Poly1305)
return backend._lib.EVP_aead_chacha20_poly1305()


def _evp_aead_encrypt(
backend: Backend,
cipher: _AEADTypes,
nonce: bytes,
data: bytes,
associated_data: list[bytes],
tag_length: int,
ctx: typing.Any,
) -> bytes:
assert ctx is not None

aead_cipher = _evp_aead_get_cipher(backend, cipher)
assert aead_cipher is not None

out_len = backend._ffi.new("size_t *")
# max_out_len should be in_len plus the result of
# EVP_AEAD_max_overhead.
max_out_len = len(data) + backend._lib.EVP_AEAD_max_overhead(aead_cipher)
out_buf = backend._ffi.new("uint8_t[]", max_out_len)
data_ptr = backend._ffi.from_buffer(data)
nonce_ptr = backend._ffi.from_buffer(nonce)
aad = b"".join(associated_data)
aad_ptr = backend._ffi.from_buffer(aad)

res = backend._lib.EVP_AEAD_CTX_seal(
ctx,
out_buf,
out_len,
max_out_len,
nonce_ptr,
len(nonce),
data_ptr,
len(data),
aad_ptr,
len(aad),
)
backend.openssl_assert(res == 1)
encrypted_data = backend._ffi.buffer(out_buf, out_len[0])[:]
return encrypted_data


def _evp_aead_decrypt(
backend: Backend,
cipher: _AEADTypes,
nonce: bytes,
data: bytes,
associated_data: list[bytes],
tag_length: int,
ctx: typing.Any,
) -> bytes:
if len(data) < tag_length:
raise InvalidTag

assert ctx is not None

out_len = backend._ffi.new("size_t *")
# max_out_len should at least in_len
max_out_len = len(data)
out_buf = backend._ffi.new("uint8_t[]", max_out_len)
data_ptr = backend._ffi.from_buffer(data)
nonce_ptr = backend._ffi.from_buffer(nonce)
aad = b"".join(associated_data)
aad_ptr = backend._ffi.from_buffer(aad)

res = backend._lib.EVP_AEAD_CTX_open(
ctx,
out_buf,
out_len,
max_out_len,
nonce_ptr,
len(nonce),
data_ptr,
len(data),
aad_ptr,
len(aad),
)

if res == 0:
backend._consume_errors()
raise InvalidTag

decrypted_data = backend._ffi.buffer(out_buf, out_len[0])[:]
return decrypted_data


_ENCRYPT = 1
_DECRYPT = 0
Expand All @@ -219,12 +69,9 @@ def _evp_cipher_cipher_name(cipher: _AEADTypes) -> bytes:
from cryptography.hazmat.primitives.ciphers.aead import (
AESCCM,
AESGCM,
ChaCha20Poly1305,
)

if isinstance(cipher, ChaCha20Poly1305):
return b"chacha20-poly1305"
elif isinstance(cipher, AESCCM):
if isinstance(cipher, AESCCM):
return f"aes-{len(cipher._key) * 8}-ccm".encode("ascii")
else:
assert isinstance(cipher, AESGCM)
Expand Down
17 changes: 17 additions & 0 deletions src/cryptography/hazmat/bindings/_rust/openssl/aead.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,23 @@
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
# for complete details.

class ChaCha20Poly1305:
def __init__(self, key: bytes) -> None: ...
@staticmethod
def generate_key() -> bytes: ...
def encrypt(
self,
nonce: bytes,
data: bytes,
associated_data: bytes | None,
) -> bytes: ...
def decrypt(
self,
nonce: bytes,
data: bytes,
associated_data: bytes | None,
) -> bytes: ...

class AESSIV:
def __init__(self, key: bytes) -> None: ...
@staticmethod
Expand Down
75 changes: 1 addition & 74 deletions src/cryptography/hazmat/primitives/ciphers/aead.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from cryptography import exceptions, utils
from cryptography.hazmat.backends.openssl import aead
from cryptography.hazmat.backends.openssl.backend import backend
from cryptography.hazmat.bindings._rust import FixedPool
from cryptography.hazmat.bindings._rust import openssl as rust_openssl

__all__ = [
Expand All @@ -20,83 +19,11 @@
"AESSIV",
]

ChaCha20Poly1305 = rust_openssl.aead.ChaCha20Poly1305
AESSIV = rust_openssl.aead.AESSIV
AESOCB3 = rust_openssl.aead.AESOCB3


class ChaCha20Poly1305:
_MAX_SIZE = 2**31 - 1

def __init__(self, key: bytes):
if not backend.aead_cipher_supported(self):
raise exceptions.UnsupportedAlgorithm(
"ChaCha20Poly1305 is not supported by this version of OpenSSL",
exceptions._Reasons.UNSUPPORTED_CIPHER,
)
utils._check_byteslike("key", key)

if len(key) != 32:
raise ValueError("ChaCha20Poly1305 key must be 32 bytes.")

self._key = key
self._pool = FixedPool(self._create_fn)

@classmethod
def generate_key(cls) -> bytes:
return os.urandom(32)

def _create_fn(self):
return aead._aead_create_ctx(backend, self, self._key)

def encrypt(
self,
nonce: bytes,
data: bytes,
associated_data: bytes | None,
) -> bytes:
if associated_data is None:
associated_data = b""

if len(data) > self._MAX_SIZE or len(associated_data) > self._MAX_SIZE:
# This is OverflowError to match what cffi would raise
raise OverflowError(
"Data or associated data too long. Max 2**31 - 1 bytes"
)

self._check_params(nonce, data, associated_data)
with self._pool.acquire() as ctx:
return aead._encrypt(
backend, self, nonce, data, [associated_data], 16, ctx
)

def decrypt(
self,
nonce: bytes,
data: bytes,
associated_data: bytes | None,
) -> bytes:
if associated_data is None:
associated_data = b""

self._check_params(nonce, data, associated_data)
with self._pool.acquire() as ctx:
return aead._decrypt(
backend, self, nonce, data, [associated_data], 16, ctx
)

def _check_params(
self,
nonce: bytes,
data: bytes,
associated_data: bytes,
) -> None:
utils._check_byteslike("nonce", nonce)
utils._check_byteslike("data", data)
utils._check_byteslike("associated_data", associated_data)
if len(nonce) != 12:
raise ValueError("Nonce must be 12 bytes")


class AESCCM:
_MAX_SIZE = 2**31 - 1

Expand Down
Loading

0 comments on commit 035e0ab

Please sign in to comment.