Skip to content

Commit

Permalink
Migrate EC support to Rust
Browse files Browse the repository at this point in the history
  • Loading branch information
alex committed Jun 23, 2023
1 parent 4ae49a4 commit f50ed6c
Show file tree
Hide file tree
Showing 12 changed files with 695 additions and 614 deletions.
261 changes: 17 additions & 244 deletions src/cryptography/hazmat/backends/openssl/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,12 @@
import contextlib
import itertools
import typing
from contextlib import contextmanager

from cryptography import utils, x509
from cryptography.exceptions import UnsupportedAlgorithm, _Reasons
from cryptography.hazmat.backends.openssl import aead
from cryptography.hazmat.backends.openssl.ciphers import _CipherContext
from cryptography.hazmat.backends.openssl.cmac import _CMACContext
from cryptography.hazmat.backends.openssl.ec import (
_EllipticCurvePrivateKey,
_EllipticCurvePublicKey,
)
from cryptography.hazmat.backends.openssl.rsa import (
_RSAPrivateKey,
_RSAPublicKey,
Expand Down Expand Up @@ -542,10 +537,9 @@ def _evp_pkey_to_private_key(
int(self._ffi.cast("uintptr_t", evp_pkey))
)
elif key_type == self._lib.EVP_PKEY_EC:
ec_cdata = self._lib.EVP_PKEY_get1_EC_KEY(evp_pkey)
self.openssl_assert(ec_cdata != self._ffi.NULL)
ec_cdata = self._ffi.gc(ec_cdata, self._lib.EC_KEY_free)
return _EllipticCurvePrivateKey(self, ec_cdata, evp_pkey)
return rust_openssl.ec.private_key_from_ptr(
int(self._ffi.cast("uintptr_t", evp_pkey))
)
elif key_type in self._dh_types:
return rust_openssl.dh.private_key_from_ptr(
int(self._ffi.cast("uintptr_t", evp_pkey))
Expand Down Expand Up @@ -603,12 +597,9 @@ def _evp_pkey_to_public_key(self, evp_pkey) -> PublicKeyTypes:
int(self._ffi.cast("uintptr_t", evp_pkey))
)
elif key_type == self._lib.EVP_PKEY_EC:
ec_cdata = self._lib.EVP_PKEY_get1_EC_KEY(evp_pkey)
if ec_cdata == self._ffi.NULL:
errors = self._consume_errors()
raise ValueError("Unable to load EC key", errors)
ec_cdata = self._ffi.gc(ec_cdata, self._lib.EC_KEY_free)
return _EllipticCurvePublicKey(self, ec_cdata, evp_pkey)
return rust_openssl.ec.public_key_from_ptr(
int(self._ffi.cast("uintptr_t", evp_pkey))
)
elif key_type in self._dh_types:
return rust_openssl.dh.public_key_from_ptr(
int(self._ffi.cast("uintptr_t", evp_pkey))
Expand Down Expand Up @@ -944,20 +935,7 @@ def elliptic_curve_supported(self, curve: ec.EllipticCurve) -> bool:
):
return False

try:
curve_nid = self._elliptic_curve_to_nid(curve)
except UnsupportedAlgorithm:
curve_nid = self._lib.NID_undef

group = self._lib.EC_GROUP_new_by_curve_name(curve_nid)

if group == self._ffi.NULL:
self._consume_errors()
return False
else:
self.openssl_assert(curve_nid != self._lib.NID_undef)
self._lib.EC_GROUP_free(group)
return True
return rust_openssl.ec.curve_supported(curve)

def elliptic_curve_signature_algorithm_supported(
self,
Expand All @@ -979,158 +957,27 @@ def generate_elliptic_curve_private_key(
"""
Generate a new private key on the named curve.
"""

if self.elliptic_curve_supported(curve):
ec_cdata = self._ec_key_new_by_curve(curve)

res = self._lib.EC_KEY_generate_key(ec_cdata)
self.openssl_assert(res == 1)

evp_pkey = self._ec_cdata_to_evp_pkey(ec_cdata)

return _EllipticCurvePrivateKey(self, ec_cdata, evp_pkey)
else:
raise UnsupportedAlgorithm(
f"Backend object does not support {curve.name}.",
_Reasons.UNSUPPORTED_ELLIPTIC_CURVE,
)
return rust_openssl.ec.generate_private_key(curve)

def load_elliptic_curve_private_numbers(
self, numbers: ec.EllipticCurvePrivateNumbers
) -> ec.EllipticCurvePrivateKey:
public = numbers.public_numbers

ec_cdata = self._ec_key_new_by_curve(public.curve)

private_value = self._ffi.gc(
self._int_to_bn(numbers.private_value), self._lib.BN_clear_free
)
res = self._lib.EC_KEY_set_private_key(ec_cdata, private_value)
if res != 1:
self._consume_errors()
raise ValueError("Invalid EC key.")

with self._tmp_bn_ctx() as bn_ctx:
self._ec_key_set_public_key_affine_coordinates(
ec_cdata, public.x, public.y, bn_ctx
)
# derive the expected public point and compare it to the one we
# just set based on the values we were given. If they don't match
# this isn't a valid key pair.
group = self._lib.EC_KEY_get0_group(ec_cdata)
self.openssl_assert(group != self._ffi.NULL)
set_point = backend._lib.EC_KEY_get0_public_key(ec_cdata)
self.openssl_assert(set_point != self._ffi.NULL)
computed_point = self._lib.EC_POINT_new(group)
self.openssl_assert(computed_point != self._ffi.NULL)
computed_point = self._ffi.gc(
computed_point, self._lib.EC_POINT_free
)
res = self._lib.EC_POINT_mul(
group,
computed_point,
private_value,
self._ffi.NULL,
self._ffi.NULL,
bn_ctx,
)
self.openssl_assert(res == 1)
if (
self._lib.EC_POINT_cmp(
group, set_point, computed_point, bn_ctx
)
!= 0
):
raise ValueError("Invalid EC key.")

evp_pkey = self._ec_cdata_to_evp_pkey(ec_cdata)

return _EllipticCurvePrivateKey(self, ec_cdata, evp_pkey)
return rust_openssl.ec.from_private_numbers(numbers)

def load_elliptic_curve_public_numbers(
self, numbers: ec.EllipticCurvePublicNumbers
) -> ec.EllipticCurvePublicKey:
ec_cdata = self._ec_key_new_by_curve(numbers.curve)
with self._tmp_bn_ctx() as bn_ctx:
self._ec_key_set_public_key_affine_coordinates(
ec_cdata, numbers.x, numbers.y, bn_ctx
)
evp_pkey = self._ec_cdata_to_evp_pkey(ec_cdata)

return _EllipticCurvePublicKey(self, ec_cdata, evp_pkey)
return rust_openssl.ec.from_public_numbers(numbers)

def load_elliptic_curve_public_bytes(
self, curve: ec.EllipticCurve, point_bytes: bytes
) -> ec.EllipticCurvePublicKey:
ec_cdata = self._ec_key_new_by_curve(curve)
group = self._lib.EC_KEY_get0_group(ec_cdata)
self.openssl_assert(group != self._ffi.NULL)
point = self._lib.EC_POINT_new(group)
self.openssl_assert(point != self._ffi.NULL)
point = self._ffi.gc(point, self._lib.EC_POINT_free)
with self._tmp_bn_ctx() as bn_ctx:
res = self._lib.EC_POINT_oct2point(
group, point, point_bytes, len(point_bytes), bn_ctx
)
if res != 1:
self._consume_errors()
raise ValueError("Invalid public bytes for the given curve")

res = self._lib.EC_KEY_set_public_key(ec_cdata, point)
self.openssl_assert(res == 1)
evp_pkey = self._ec_cdata_to_evp_pkey(ec_cdata)
return _EllipticCurvePublicKey(self, ec_cdata, evp_pkey)
return rust_openssl.ec.from_public_bytes(curve, point_bytes)

def derive_elliptic_curve_private_key(
self, private_value: int, curve: ec.EllipticCurve
) -> ec.EllipticCurvePrivateKey:
ec_cdata = self._ec_key_new_by_curve(curve)

group = self._lib.EC_KEY_get0_group(ec_cdata)
self.openssl_assert(group != self._ffi.NULL)

point = self._lib.EC_POINT_new(group)
self.openssl_assert(point != self._ffi.NULL)
point = self._ffi.gc(point, self._lib.EC_POINT_free)

value = self._int_to_bn(private_value)
value = self._ffi.gc(value, self._lib.BN_clear_free)

with self._tmp_bn_ctx() as bn_ctx:
res = self._lib.EC_POINT_mul(
group, point, value, self._ffi.NULL, self._ffi.NULL, bn_ctx
)
self.openssl_assert(res == 1)

bn_x = self._lib.BN_CTX_get(bn_ctx)
bn_y = self._lib.BN_CTX_get(bn_ctx)

res = self._lib.EC_POINT_get_affine_coordinates(
group, point, bn_x, bn_y, bn_ctx
)
if res != 1:
self._consume_errors()
raise ValueError("Unable to derive key from private_value")

res = self._lib.EC_KEY_set_public_key(ec_cdata, point)
self.openssl_assert(res == 1)
private = self._int_to_bn(private_value)
private = self._ffi.gc(private, self._lib.BN_clear_free)
res = self._lib.EC_KEY_set_private_key(ec_cdata, private)
self.openssl_assert(res == 1)

evp_pkey = self._ec_cdata_to_evp_pkey(ec_cdata)

return _EllipticCurvePrivateKey(self, ec_cdata, evp_pkey)

def _ec_key_new_by_curve(self, curve: ec.EllipticCurve):
curve_nid = self._elliptic_curve_to_nid(curve)
return self._ec_key_new_by_curve_nid(curve_nid)

def _ec_key_new_by_curve_nid(self, curve_nid: int):
ec_cdata = self._lib.EC_KEY_new_by_curve_name(curve_nid)
self.openssl_assert(ec_cdata != self._ffi.NULL)
return self._ffi.gc(ec_cdata, self._lib.EC_KEY_free)
return rust_openssl.ec.derive_private_key(private_value, curve)

def elliptic_curve_exchange_algorithm_supported(
self, algorithm: ec.ECDH, curve: ec.EllipticCurve
Expand All @@ -1139,73 +986,6 @@ def elliptic_curve_exchange_algorithm_supported(
algorithm, ec.ECDH
)

def _ec_cdata_to_evp_pkey(self, ec_cdata):
evp_pkey = self._create_evp_pkey_gc()
res = self._lib.EVP_PKEY_set1_EC_KEY(evp_pkey, ec_cdata)
self.openssl_assert(res == 1)
return evp_pkey

def _elliptic_curve_to_nid(self, curve: ec.EllipticCurve) -> int:
"""
Get the NID for a curve name.
"""

curve_aliases = {"secp192r1": "prime192v1", "secp256r1": "prime256v1"}

curve_name = curve_aliases.get(curve.name, curve.name)

curve_nid = self._lib.OBJ_sn2nid(curve_name.encode())
if curve_nid == self._lib.NID_undef:
raise UnsupportedAlgorithm(
f"{curve.name} is not a supported elliptic curve",
_Reasons.UNSUPPORTED_ELLIPTIC_CURVE,
)
return curve_nid

@contextmanager
def _tmp_bn_ctx(self):
bn_ctx = self._lib.BN_CTX_new()
self.openssl_assert(bn_ctx != self._ffi.NULL)
bn_ctx = self._ffi.gc(bn_ctx, self._lib.BN_CTX_free)
self._lib.BN_CTX_start(bn_ctx)
try:
yield bn_ctx
finally:
self._lib.BN_CTX_end(bn_ctx)

def _ec_key_set_public_key_affine_coordinates(
self,
ec_cdata,
x: int,
y: int,
bn_ctx,
) -> None:
"""
Sets the public key point in the EC_KEY context to the affine x and y
values.
"""

if x < 0 or y < 0:
raise ValueError(
"Invalid EC key. Both x and y must be non-negative."
)

x = self._ffi.gc(self._int_to_bn(x), self._lib.BN_free)
y = self._ffi.gc(self._int_to_bn(y), self._lib.BN_free)
group = self._lib.EC_KEY_get0_group(ec_cdata)
self.openssl_assert(group != self._ffi.NULL)
point = self._lib.EC_POINT_new(group)
self.openssl_assert(point != self._ffi.NULL)
point = self._ffi.gc(point, self._lib.EC_POINT_free)
res = self._lib.EC_POINT_set_affine_coordinates(
group, point, x, y, bn_ctx
)
if res != 1:
self._consume_errors()
raise ValueError("Invalid EC key.")
res = self._lib.EC_KEY_set_public_key(ec_cdata, point)
self.openssl_assert(res == 1)

def _private_key_bytes(
self,
encoding: serialization.Encoding,
Expand Down Expand Up @@ -1278,11 +1058,8 @@ def _private_key_bytes(
key_type = self._lib.EVP_PKEY_id(evp_pkey)

if encoding is serialization.Encoding.PEM:
if key_type == self._lib.EVP_PKEY_RSA:
write_bio = self._lib.PEM_write_bio_RSAPrivateKey
else:
assert key_type == self._lib.EVP_PKEY_EC
write_bio = self._lib.PEM_write_bio_ECPrivateKey
assert key_type == self._lib.EVP_PKEY_RSA
write_bio = self._lib.PEM_write_bio_RSAPrivateKey
return self._private_key_bytes_via_bio(
write_bio, cdata, password
)
Expand All @@ -1293,11 +1070,8 @@ def _private_key_bytes(
"Encryption is not supported for DER encoded "
"traditional OpenSSL keys"
)
if key_type == self._lib.EVP_PKEY_RSA:
write_bio = self._lib.i2d_RSAPrivateKey_bio
else:
assert key_type == self._lib.EVP_PKEY_EC
write_bio = self._lib.i2d_ECPrivateKey_bio
assert key_type == self._lib.EVP_PKEY_RSA
write_bio = self._lib.i2d_RSAPrivateKey_bio
return self._bio_func_output(write_bio, cdata)

raise ValueError("Unsupported encoding for TraditionalOpenSSL")
Expand Down Expand Up @@ -1374,8 +1148,7 @@ def _public_key_bytes(
if format is serialization.PublicFormat.PKCS1:
# Only RSA is supported here.
key_type = self._lib.EVP_PKEY_id(evp_pkey)
if key_type != self._lib.EVP_PKEY_RSA:
raise ValueError("PKCS1 format is supported only for RSA keys")
self.openssl_assert(key_type == self._lib.EVP_PKEY_RSA)

if encoding is serialization.Encoding.PEM:
write_bio = self._lib.PEM_write_bio_RSAPublicKey
Expand Down
Loading

0 comments on commit f50ed6c

Please sign in to comment.