diff --git a/jwcrypto/jwe.py b/jwcrypto/jwe.py index 5df500b..e01aa1d 100644 --- a/jwcrypto/jwe.py +++ b/jwcrypto/jwe.py @@ -82,7 +82,7 @@ class JWE: def __init__(self, plaintext=None, protected=None, unprotected=None, aad=None, algs=None, recipient=None, header=None, - header_registry=None): + header_registry=None, flattened=True): """Creates a JWE token. :param plaintext(bytes): An arbitrary plaintext to be encrypted. @@ -93,11 +93,13 @@ def __init__(self, plaintext=None, protected=None, unprotected=None, :param recipient: An optional, default recipient key :param header: An optional header for the default recipient :param header_registry: Optional additions to the header registry + :param flattened: Use flattened serialization syntax (default True) """ self._allowed_algs = None self.objects = {} self.plaintext = None self.header_registry = JWSEHeaderRegistry(JWEHeaderRegistry) + self.flattened = flattened if header_registry: self.header_registry.update(header_registry) if plaintext is not None: @@ -253,17 +255,20 @@ def add_recipient(self, key, header=None): if 'recipients' in self.objects: self.objects['recipients'].append(rec) - elif 'encrypted_key' in self.objects or 'header' in self.objects: - self.objects['recipients'] = [] - n = {} - if 'encrypted_key' in self.objects: - n['encrypted_key'] = self.objects.pop('encrypted_key') - if 'header' in self.objects: - n['header'] = self.objects.pop('header') - self.objects['recipients'].append(n) - self.objects['recipients'].append(rec) + elif self.flattened: + if 'encrypted_key' in self.objects or 'header' in self.objects: + self.objects['recipients'] = [] + n = {} + if 'encrypted_key' in self.objects: + n['encrypted_key'] = self.objects.pop('encrypted_key') + if 'header' in self.objects: + n['header'] = self.objects.pop('header') + self.objects['recipients'].append(n) + self.objects['recipients'].append(rec) + else: + self.objects.update(rec) else: - self.objects.update(rec) + self.objects['recipients'] = [rec] def serialize(self, compact=False): """Serializes the object into a JWE token. diff --git a/jwcrypto/tests.py b/jwcrypto/tests.py index 59049f8..d7b3b6b 100644 --- a/jwcrypto/tests.py +++ b/jwcrypto/tests.py @@ -1509,6 +1509,18 @@ def test_decrypt_keyset(self): with self.assertRaises(JWKeyNotFound): e4.deserialize(e3.serialize(), ks) + def test_serialize_not_flattened(self): + # JWE with flattened=False adds recipients in objects and in serialized + e = jwe.JWE(E_A1_ex['plaintext'], flattened=False) + e.add_recipient(E_A1_ex['key'], E_A1_ex['protected']) + self.assertIn('recipients', e.objects) + self.assertIn('recipients', e.serialize()) + + e = jwe.JWE(E_A1_ex['plaintext']) + e.add_recipient(E_A1_ex['key'], E_A1_ex['protected']) + self.assertNotIn('recipients', e.objects) + self.assertNotIn('recipients', e.serialize()) + MMA_vector_key = jwk.JWK(**E_A2_key) MMA_vector_ok_cek = \