diff --git a/jwcrypto/jwe.py b/jwcrypto/jwe.py index 5df500b..cda6bf2 100644 --- a/jwcrypto/jwe.py +++ b/jwcrypto/jwe.py @@ -82,7 +82,7 @@ class JWE: def __init__(self, plaintext=None, protected=None, unprotected=None, aad=None, algs=None, recipient=None, header=None, - header_registry=None): + header_registry=None, flattened=True): """Creates a JWE token. :param plaintext(bytes): An arbitrary plaintext to be encrypted. @@ -93,11 +93,13 @@ def __init__(self, plaintext=None, protected=None, unprotected=None, :param recipient: An optional, default recipient key :param header: An optional header for the default recipient :param header_registry: Optional additions to the header registry + :param flattened: Use flattened serialization syntax (default True) """ self._allowed_algs = None self.objects = {} self.plaintext = None self.header_registry = JWSEHeaderRegistry(JWEHeaderRegistry) + self.flattened = flattened if header_registry: self.header_registry.update(header_registry) if plaintext is not None: @@ -251,19 +253,26 @@ def add_recipient(self, key, header=None): if 'ciphertext' not in self.objects: self._encrypt(alg, enc, jh) - if 'recipients' in self.objects: - self.objects['recipients'].append(rec) - elif 'encrypted_key' in self.objects or 'header' in self.objects: - self.objects['recipients'] = [] - n = {} - if 'encrypted_key' in self.objects: - n['encrypted_key'] = self.objects.pop('encrypted_key') - if 'header' in self.objects: - n['header'] = self.objects.pop('header') - self.objects['recipients'].append(n) - self.objects['recipients'].append(rec) + if self.flattened: + if 'recipients' in self.objects: + self.objects['recipients'].append(rec) + elif 'encrypted_key' in self.objects or 'header' in self.objects: + self.objects['recipients'] = [] + n = {} + if 'encrypted_key' in self.objects: + n['encrypted_key'] = self.objects.pop('encrypted_key') + if 'header' in self.objects: + n['header'] = self.objects.pop('header') + self.objects['recipients'].append(n) + self.objects['recipients'].append(rec) + else: + self.objects.update(rec) else: - self.objects.update(rec) + if 'recipients' in self.objects: + self.objects['recipients'].append(rec) + else: + self.objects['recipients'] = [rec] + def serialize(self, compact=False): """Serializes the object into a JWE token.