Skip to content

Commit

Permalink
Convert AESSIV AEAD to Rust
Browse files Browse the repository at this point in the history
  • Loading branch information
alex committed Aug 5, 2023
1 parent f7cfcef commit a284784
Show file tree
Hide file tree
Showing 7 changed files with 258 additions and 127 deletions.
64 changes: 12 additions & 52 deletions src/cryptography/hazmat/backends/openssl/aead.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,10 @@
AESCCM,
AESGCM,
AESOCB3,
AESSIV,
ChaCha20Poly1305,
)

_AEADTypes = typing.Union[
AESCCM, AESGCM, AESOCB3, AESSIV, ChaCha20Poly1305
]
_AEADTypes = typing.Union[AESCCM, AESGCM, AESOCB3, ChaCha20Poly1305]


def _is_evp_aead_supported_cipher(
Expand All @@ -44,16 +41,9 @@ def _aead_cipher_supported(backend: Backend, cipher: _AEADTypes) -> bool:
cipher_name = _evp_cipher_cipher_name(cipher)
if backend._fips_enabled and cipher_name not in backend._fips_aead:
return False
# SIV isn't loaded through get_cipherbyname but instead a new fetch API
# only available in 3.0+. But if we know we're on 3.0+ then we know
# it's supported.
if cipher_name.endswith(b"-siv"):
return backend._lib.CRYPTOGRAPHY_OPENSSL_300_OR_GREATER == 1
else:
return (
backend._lib.EVP_get_cipherbyname(cipher_name)
!= backend._ffi.NULL
)
return (
backend._lib.EVP_get_cipherbyname(cipher_name) != backend._ffi.NULL
)


def _aead_create_ctx(
Expand Down Expand Up @@ -231,7 +221,6 @@ def _evp_cipher_cipher_name(cipher: _AEADTypes) -> bytes:
AESCCM,
AESGCM,
AESOCB3,
AESSIV,
ChaCha20Poly1305,
)

Expand All @@ -241,26 +230,14 @@ def _evp_cipher_cipher_name(cipher: _AEADTypes) -> bytes:
return f"aes-{len(cipher._key) * 8}-ccm".encode("ascii")
elif isinstance(cipher, AESOCB3):
return f"aes-{len(cipher._key) * 8}-ocb".encode("ascii")
elif isinstance(cipher, AESSIV):
return f"aes-{len(cipher._key) * 8 // 2}-siv".encode("ascii")
else:
assert isinstance(cipher, AESGCM)
return f"aes-{len(cipher._key) * 8}-gcm".encode("ascii")


def _evp_cipher(cipher_name: bytes, backend: Backend):
if cipher_name.endswith(b"-siv"):
evp_cipher = backend._lib.EVP_CIPHER_fetch(
backend._ffi.NULL,
cipher_name,
backend._ffi.NULL,
)
backend.openssl_assert(evp_cipher != backend._ffi.NULL)
evp_cipher = backend._ffi.gc(evp_cipher, backend._lib.EVP_CIPHER_free)
else:
evp_cipher = backend._lib.EVP_get_cipherbyname(cipher_name)
backend.openssl_assert(evp_cipher != backend._ffi.NULL)

evp_cipher = backend._lib.EVP_get_cipherbyname(cipher_name)
backend.openssl_assert(evp_cipher != backend._ffi.NULL)
return evp_cipher


Expand Down Expand Up @@ -389,10 +366,7 @@ def _evp_cipher_process_data(backend: Backend, ctx, data: bytes) -> bytes:
buf = backend._ffi.new("unsigned char[]", len(data))
data_ptr = backend._ffi.from_buffer(data)
res = backend._lib.EVP_CipherUpdate(ctx, buf, outlen, data_ptr, len(data))
if res == 0:
# AES SIV can error here if the data is invalid on decrypt
backend._consume_errors()
raise InvalidTag
backend.openssl_assert(res != 0)
return backend._ffi.buffer(buf, outlen[0])[:]


Expand All @@ -405,7 +379,7 @@ def _evp_cipher_encrypt(
tag_length: int,
ctx: typing.Any = None,
) -> bytes:
from cryptography.hazmat.primitives.ciphers.aead import AESCCM, AESSIV
from cryptography.hazmat.primitives.ciphers.aead import AESCCM

if ctx is None:
cipher_name = _evp_cipher_cipher_name(cipher)
Expand Down Expand Up @@ -445,14 +419,7 @@ def _evp_cipher_encrypt(
backend.openssl_assert(res != 0)
tag = backend._ffi.buffer(tag_buf)[:]

if isinstance(cipher, AESSIV):
# RFC 5297 defines the output as IV || C, where the tag we generate
# is the "IV" and C is the ciphertext. This is the opposite of our
# other AEADs, which are Ciphertext || Tag
backend.openssl_assert(len(tag) == 16)
return tag + processed_data
else:
return processed_data + tag
return processed_data + tag


def _evp_cipher_decrypt(
Expand All @@ -464,20 +431,13 @@ def _evp_cipher_decrypt(
tag_length: int,
ctx: typing.Any = None,
) -> bytes:
from cryptography.hazmat.primitives.ciphers.aead import AESCCM, AESSIV
from cryptography.hazmat.primitives.ciphers.aead import AESCCM

if len(data) < tag_length:
raise InvalidTag

if isinstance(cipher, AESSIV):
# RFC 5297 defines the output as IV || C, where the tag we generate
# is the "IV" and C is the ciphertext. This is the opposite of our
# other AEADs, which are Ciphertext || Tag
tag = data[:tag_length]
data = data[tag_length:]
else:
tag = data[-tag_length:]
data = data[:-tag_length]
tag = data[-tag_length:]
data = data[:-tag_length]
if ctx is None:
cipher_name = _evp_cipher_cipher_name(cipher)
ctx = _evp_cipher_aead_setup(
Expand Down
2 changes: 2 additions & 0 deletions src/cryptography/hazmat/bindings/_rust/openssl/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import typing

from cryptography.hazmat.bindings._rust.openssl import (
aead,
dh,
dsa,
ec,
Expand All @@ -21,6 +22,7 @@ from cryptography.hazmat.bindings._rust.openssl import (
__all__ = [
"openssl_version",
"raise_openssl_error",
"aead",
"dh",
"dsa",
"ec",
Expand Down
20 changes: 20 additions & 0 deletions src/cryptography/hazmat/bindings/_rust/openssl/aead.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# This file is dual licensed under the terms of the Apache License, Version
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
# for complete details.

import typing

class AESSIV:
def __init__(self, key: bytes) -> None: ...
@staticmethod
def generate_key(key_size: int) -> bytes: ...
def encrypt(
self,
nonce: bytes,
associated_data: typing.Optional[typing.List[bytes]],
) -> bytes: ...
def decrypt(
self,
nonce: bytes,
associated_data: typing.Optional[typing.List[bytes]],
) -> bytes: ...
86 changes: 11 additions & 75 deletions src/cryptography/hazmat/primitives/ciphers/aead.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,17 @@
from cryptography.hazmat.backends.openssl import aead
from cryptography.hazmat.backends.openssl.backend import backend
from cryptography.hazmat.bindings._rust import FixedPool
from cryptography.hazmat.bindings._rust import openssl as rust_openssl

__all__ = [
"ChaCha20Poly1305",
"AESCCM",
"AESGCM",
"AESOCB3",
"AESSIV",
]

AESSIV = rust_openssl.aead.AESSIV


class ChaCha20Poly1305:
Expand Down Expand Up @@ -301,78 +312,3 @@ def _check_params(
utils._check_byteslike("associated_data", associated_data)
if len(nonce) < 12 or len(nonce) > 15:
raise ValueError("Nonce must be between 12 and 15 bytes")


class AESSIV:
_MAX_SIZE = 2**31 - 1

def __init__(self, key: bytes):
utils._check_byteslike("key", key)
if len(key) not in (32, 48, 64):
raise ValueError("AESSIV key must be 256, 384, or 512 bits.")

self._key = key

if not backend.aead_cipher_supported(self):
raise exceptions.UnsupportedAlgorithm(
"AES-SIV is not supported by this version of OpenSSL",
exceptions._Reasons.UNSUPPORTED_CIPHER,
)

@classmethod
def generate_key(cls, bit_length: int) -> bytes:
if not isinstance(bit_length, int):
raise TypeError("bit_length must be an integer")

if bit_length not in (256, 384, 512):
raise ValueError("bit_length must be 256, 384, or 512")

return os.urandom(bit_length // 8)

def encrypt(
self,
data: bytes,
associated_data: typing.Optional[typing.List[bytes]],
) -> bytes:
if associated_data is None:
associated_data = []

self._check_params(data, associated_data)

if len(data) > self._MAX_SIZE or any(
len(ad) > self._MAX_SIZE for ad in associated_data
):
# This is OverflowError to match what cffi would raise
raise OverflowError(
"Data or associated data too long. Max 2**31 - 1 bytes"
)

return aead._encrypt(backend, self, b"", data, associated_data, 16)

def decrypt(
self,
data: bytes,
associated_data: typing.Optional[typing.List[bytes]],
) -> bytes:
if associated_data is None:
associated_data = []

self._check_params(data, associated_data)

return aead._decrypt(backend, self, b"", data, associated_data, 16)

def _check_params(
self,
data: bytes,
associated_data: typing.List[bytes],
) -> None:
utils._check_byteslike("data", data)
if len(data) == 0:
raise ValueError("data must not be zero length")

if not isinstance(associated_data, list):
raise TypeError(
"associated_data must be a list of bytes-like objects or None"
)
for x in associated_data:
utils._check_byteslike("associated_data elements", x)
Loading

0 comments on commit a284784

Please sign in to comment.