Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update charm libraries #559

Merged
merged 1 commit into from
Jan 15, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
179 changes: 114 additions & 65 deletions lib/charms/tls_certificates_interface/v2/tls_certificates.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ def _on_all_certificates_invalidated(self, event: AllCertificatesInvalidatedEven

# Increment this PATCH version before using `charmcraft publish-lib` or reset
# to 0 if you are raising the major API version
LIBPATCH = 20
LIBPATCH = 21

PYDEPS = ["cryptography", "jsonschema"]

Expand Down Expand Up @@ -693,6 +693,105 @@ def generate_ca(
return cert.public_bytes(serialization.Encoding.PEM)


def get_certificate_extensions(
authority_key_identifier: bytes,
csr: x509.CertificateSigningRequest,
alt_names: Optional[List[str]],
is_ca: bool,
) -> List[x509.Extension]:
"""Generates a list of certificate extensions from a CSR and other known information.

Args:
authority_key_identifier (bytes): Authority key identifier
csr (x509.CertificateSigningRequest): CSR
alt_names (list): List of alt names to put on cert - prefer putting SANs in CSR
is_ca (bool): Whether the certificate is a CA certificate

Returns:
List[x509.Extension]: List of extensions
"""
cert_extensions_list: List[x509.Extension] = [
x509.Extension(
oid=ExtensionOID.AUTHORITY_KEY_IDENTIFIER,
value=x509.AuthorityKeyIdentifier(
key_identifier=authority_key_identifier,
authority_cert_issuer=None,
authority_cert_serial_number=None,
),
critical=False,
),
x509.Extension(
oid=ExtensionOID.SUBJECT_KEY_IDENTIFIER,
value=x509.SubjectKeyIdentifier.from_public_key(csr.public_key()),
critical=False,
),
x509.Extension(
oid=ExtensionOID.BASIC_CONSTRAINTS,
critical=True,
value=x509.BasicConstraints(ca=is_ca, path_length=None),
),
]

sans: List[x509.GeneralName] = []
san_alt_names = [x509.DNSName(name) for name in alt_names] if alt_names else []
sans.extend(san_alt_names)
try:
loaded_san_ext = csr.extensions.get_extension_for_class(x509.SubjectAlternativeName)
sans.extend(
[x509.DNSName(name) for name in loaded_san_ext.value.get_values_for_type(x509.DNSName)]
)
sans.extend(
[x509.IPAddress(ip) for ip in loaded_san_ext.value.get_values_for_type(x509.IPAddress)]
)
sans.extend(
[
x509.RegisteredID(oid)
for oid in loaded_san_ext.value.get_values_for_type(x509.RegisteredID)
]
)
except x509.ExtensionNotFound:
pass

if sans:
cert_extensions_list.append(
x509.Extension(
oid=ExtensionOID.SUBJECT_ALTERNATIVE_NAME,
critical=False,
value=x509.SubjectAlternativeName(sans),
)
)

if is_ca:
cert_extensions_list.append(
x509.Extension(
ExtensionOID.KEY_USAGE,
critical=True,
value=x509.KeyUsage(
digital_signature=False,
content_commitment=False,
key_encipherment=False,
data_encipherment=False,
key_agreement=False,
key_cert_sign=True,
crl_sign=True,
encipher_only=False,
decipher_only=False,
),
)
)

existing_oids = {ext.oid for ext in cert_extensions_list}
for extension in csr.extensions:
if extension.oid == ExtensionOID.SUBJECT_ALTERNATIVE_NAME:
continue
if extension.oid in existing_oids:
logger.warning("Extension %s is managed by the TLS provider, ignoring.", extension.oid)
continue
cert_extensions_list.append(extension)

return cert_extensions_list


def generate_certificate(
csr: bytes,
ca: bytes,
Expand Down Expand Up @@ -730,74 +829,24 @@ def generate_certificate(
.serial_number(x509.random_serial_number())
.not_valid_before(datetime.utcnow())
.not_valid_after(datetime.utcnow() + timedelta(days=validity))
.add_extension(
x509.AuthorityKeyIdentifier(
key_identifier=ca_pem.extensions.get_extension_for_class(
x509.SubjectKeyIdentifier
).value.key_identifier,
authority_cert_issuer=None,
authority_cert_serial_number=None,
),
critical=False,
)
.add_extension(
x509.SubjectKeyIdentifier.from_public_key(csr_object.public_key()), critical=False
)
)

extensions_list = csr_object.extensions
san_ext: Optional[x509.Extension] = None
if alt_names:
full_sans_dns = alt_names.copy()
extensions = get_certificate_extensions(
authority_key_identifier=ca_pem.extensions.get_extension_for_class(
x509.SubjectKeyIdentifier
).value.key_identifier,
csr=csr_object,
alt_names=alt_names,
is_ca=is_ca,
)
for extension in extensions:
try:
loaded_san_ext = csr_object.extensions.get_extension_for_class(
x509.SubjectAlternativeName
certificate_builder = certificate_builder.add_extension(
extval=extension.value,
critical=extension.critical,
)
full_sans_dns.extend(loaded_san_ext.value.get_values_for_type(x509.DNSName))
except ExtensionNotFound:
pass
finally:
san_ext = Extension(
ExtensionOID.SUBJECT_ALTERNATIVE_NAME,
False,
x509.SubjectAlternativeName([x509.DNSName(name) for name in full_sans_dns]),
)
if not extensions_list:
extensions_list = x509.Extensions([san_ext])

for extension in extensions_list:
if extension.value.oid == ExtensionOID.SUBJECT_ALTERNATIVE_NAME and san_ext:
extension = san_ext

certificate_builder = certificate_builder.add_extension(
extension.value,
critical=extension.critical,
)

if is_ca:
certificate_builder = certificate_builder.add_extension(
x509.BasicConstraints(ca=True, path_length=None), critical=True
)
certificate_builder = certificate_builder.add_extension(
x509.KeyUsage(
digital_signature=False,
content_commitment=False,
key_encipherment=False,
data_encipherment=False,
key_agreement=False,
key_cert_sign=True,
crl_sign=True,
encipher_only=False,
decipher_only=False,
),
critical=True,
)
else:
certificate_builder = certificate_builder.add_extension(
x509.BasicConstraints(ca=False, path_length=None), critical=False
)
except ValueError as e:
logger.warning("Failed to add extension %s: %s", extension.oid, e)

certificate_builder._version = x509.Version.v3
cert = certificate_builder.sign(private_key, hashes.SHA256()) # type: ignore[arg-type]
return cert.public_bytes(serialization.Encoding.PEM)

Expand Down
Loading