Skip to content

Commit

Permalink
feat: Support binary key and IV parameter (#14)
Browse files Browse the repository at this point in the history
* feat: Support binary key and IV parameter

* fix: Use Union typing for backward compatibility

* fix: Typing leftover
  • Loading branch information
Laerte authored Dec 13, 2023
1 parent 7197137 commit f3c6e09
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 20 deletions.
5 changes: 3 additions & 2 deletions aes_pkcs5/algorithms/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from abc import ABCMeta, abstractmethod
from base64 import b64decode, b64encode
from binascii import unhexlify
from typing import Union

from cryptography.hazmat.primitives.ciphers import Cipher

Expand All @@ -10,8 +11,8 @@
class AESCommon(metaclass=ABCMeta):
"""Common AES interface"""

def __init__(self, key: str, output_format: str) -> None:
self._key = key.encode()
def __init__(self, key: Union[str, bytes], output_format: str) -> None:
self._key = key if isinstance(key, bytes) else key.encode()

if output_format not in OUTPUT_FORMATS:
raise NotImplementedError(
Expand Down
13 changes: 11 additions & 2 deletions aes_pkcs5/algorithms/aes_cbc_pkcs5_padding.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Union

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.ciphers.algorithms import AES
from cryptography.hazmat.primitives.ciphers.modes import CBC
Expand All @@ -10,9 +12,16 @@ class AESCBCPKCS5Padding(AESCommon):
Implements AES algorithm with CBC mode of operation and padding scheme PKCS5.
"""

def __init__(self, key: str, output_format: str, iv_parameter: str):
def __init__(
self,
key: Union[str, bytes],
output_format: str,
iv_parameter: Union[str, bytes],
):
super(AESCBCPKCS5Padding, self).__init__(key=key, output_format=output_format)
self._iv_parameter = iv_parameter.encode()
self._iv_parameter = (
iv_parameter if isinstance(iv_parameter, bytes) else iv_parameter.encode()
)

def _get_cipher(self):
"""Return AES/CBC/PKCS5Padding Cipher"""
Expand Down
4 changes: 3 additions & 1 deletion aes_pkcs5/algorithms/aes_ecb_pkcs5_padding.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Union

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.ciphers.algorithms import AES
from cryptography.hazmat.primitives.ciphers.modes import ECB
Expand All @@ -10,7 +12,7 @@ class AESECBPKCS5Padding(AESCommon):
Implements AES algorithm with ECB mode of operation and padding scheme PKCS5.
"""

def __init__(self, key: str, output_format: str):
def __init__(self, key: Union[str, bytes], output_format: str):
super(AESECBPKCS5Padding, self).__init__(key=key, output_format=output_format)

def _get_cipher(self):
Expand Down
17 changes: 10 additions & 7 deletions tests/algorithms/test_aes_cbc_pkcs5_padding.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# isort: skip_file
from typing import Dict, List
from typing import Dict, List, Union

from pytest import mark

Expand Down Expand Up @@ -41,10 +41,13 @@
],
)
def test_encrypt_and_decrypt_and_output_formats(
key: str, expected_outputs: Dict[str, List[str]]
key: Union[str, bytes], expected_outputs: Dict[str, List[str]]
):
for output_format in OUTPUT_FORMATS:
cipher = AESCBCPKCS5Padding(key, output_format, IV)
encrypted_output = cipher.encrypt(INPUT_VALUE)
assert encrypted_output == expected_outputs.get(output_format)
assert cipher.decrypt(encrypted_output) == INPUT_VALUE
for key in [key, key.encode()]:
for output_format in OUTPUT_FORMATS:
cipher = AESCBCPKCS5Padding(
key, output_format, IV.encode() if isinstance(key, str) else IV
)
encrypted_output = cipher.encrypt(INPUT_VALUE)
assert encrypted_output == expected_outputs.get(output_format)
assert cipher.decrypt(encrypted_output) == INPUT_VALUE
3 changes: 2 additions & 1 deletion tests/algorithms/test_aes_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,5 @@ def test_not_implemented_output_format():


def test_dummy_get_cipher():
assert DummyCipher("key", "b64")._get_cipher() == "dummy_cipher"
for key in ["key", "key".encode()]:
assert DummyCipher(key, "b64")._get_cipher() == "dummy_cipher"
15 changes: 8 additions & 7 deletions tests/algorithms/test_aes_ecb_pkcs5_padding.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# isort: skip_file
from typing import Dict, List
from typing import Dict, List, Union

from pytest import mark

Expand Down Expand Up @@ -40,10 +40,11 @@
],
)
def test_encrypt_and_decrypt_and_output_formats(
key: str, expected_outputs: Dict[str, List[str]]
key: Union[str, bytes], expected_outputs: Dict[str, List[str]]
):
for output_format in OUTPUT_FORMATS:
cipher = AESECBPKCS5Padding(key, output_format)
encrypted_output = cipher.encrypt(INPUT_VALUE)
assert encrypted_output == expected_outputs.get(output_format)
assert cipher.decrypt(encrypted_output) == INPUT_VALUE
for key in [key, key.encode()]:
for output_format in OUTPUT_FORMATS:
cipher = AESECBPKCS5Padding(key, output_format)
encrypted_output = cipher.encrypt(INPUT_VALUE)
assert encrypted_output == expected_outputs.get(output_format)
assert cipher.decrypt(encrypted_output) == INPUT_VALUE

0 comments on commit f3c6e09

Please sign in to comment.