from abc import ABC, abstractmethod
from secrets import compare_digest
from typing import Final, Literal, override
from pycrypt.symmetric.aes.core import AESCore
from pycrypt.symmetric.aes.utils import (
inc_counter,
pad16,
validate_len,
validate_len_multiple,
)
from pycrypt.utils import PKCS7, xor_bytes
class _AESMode(ABC):
"""Abstract base class for AES block cipher modes.
This class provides the foundation for implementing different AES modes
(ECB, CBC, CTR, GCM). It handles low-level AES operations and provides
utility functions for block handling, counter management, and XOR operations.
Attributes:
_aes (AESCore): The underlying AES block cipher instance.
Example:
>>> from pycrypt.symmetric.aes.modes import AES_ECB
>>> key = b"0123456789abcdef"
>>> aes = AES_ECB(key)
>>> ciphertext = aes.encrypt(b"hello world")
>>> aes.decrypt(ciphertext)
b'hello world'
"""
def __init__(self, key: bytes):
"""Initialize an AES mode with a secret key.
Args:
key (bytes): The AES key (16, 24, or 32 bytes for AES-128/192/256).
"""
self._aes: AESCore = AESCore(key)
# --- Encryption / Decryption ---
@abstractmethod
def encrypt(self, *args, **kwargs) -> bytes: ... # pyright: ignore[reportUnknownParameterType, reportMissingParameterType]
@abstractmethod
def decrypt(self, *args, **kwargs) -> bytes: ... # pyright: ignore[reportUnknownParameterType, reportMissingParameterType]
# --- PRIVATE: Counter Logic for CTR/GCM ---
def _ctr(self, data: bytes, initial_counter: bytes) -> bytes:
validate_len("initial counter", initial_counter, 16)
cipher = self._aes.cipher
encrypted = bytearray()
for idx, block in enumerate(self._chunk_blocks(data, fixed_length=False)):
keystream = cipher(self._add_to_counter(initial_counter, idx))
encrypted.extend(xor_bytes(block, keystream[: len(block)]))
return bytes(encrypted)
# --- PRIVATE: Helper Functions ---
@staticmethod
def _chunk_blocks(data: bytes, block_size: int = 16, fixed_length: bool = True):
if fixed_length:
validate_len_multiple("Data length", data, block_size)
for i in range(0, len(data), block_size):
yield data[i : i + block_size]
@staticmethod
def _add_to_counter(counter: bytes, num: int) -> bytes:
counter_int = int.from_bytes(counter, "big") + num
return counter_int.to_bytes(len(counter), "big")
@override
def __repr__(self):
attrs: list[str] = []
for name in ("iv", "nonce", "aad"):
if hasattr(self, name):
attrs.append(f"{name}={getattr(self, name)!r}")
return f"{self.__class__.__name__}(key_len={len(self._aes._KEY)}, {', '.join(attrs)})"
class AES_ECB(_AESMode):
    """AES in ECB (Electronic Codebook) mode.

    Every 16-byte block is encrypted on its own. Messages whose length is not
    a multiple of 16 bytes need padding (e.g., PKCS7). ECB offers no
    integrity/authentication, and equal plaintext blocks always map to equal
    ciphertext blocks.

    Example:
        >>> key = b"0123456789abcdef"
        >>> aes = AES_ECB(key)
        >>> plaintext = b"Secret Message"
        >>> ct = aes.encrypt(plaintext)
        >>> aes.decrypt(ct)
        b'Secret Message'
    """

    def __init__(self, key: bytes):
        """Create an ECB cipher bound to ``key``."""
        super().__init__(key)

    # --- Encryption / Decryption ---
    @override
    def encrypt(self, plaintext: bytes, *, pad: bool = True) -> bytes:
        """Encrypt plaintext using AES-ECB.

        Args:
            plaintext (bytes): The data to encrypt.
            pad (bool, optional): Whether to apply PKCS7 padding (default True).
                When False the plaintext length is validated instead.

        Returns:
            bytes: The ciphertext.
        """
        if not pad:
            validate_len_multiple("Plaintext length", plaintext)
        else:
            plaintext = PKCS7.pad(plaintext)
        # Blocks are independent in ECB, so each one is simply mapped through
        # the block cipher and the results concatenated.
        return b"".join(map(self._aes.cipher, self._chunk_blocks(plaintext)))

    @override
    def decrypt(self, ciphertext: bytes, *, unpad: bool = True) -> bytes:
        """Decrypt ciphertext using AES-ECB.

        Args:
            ciphertext (bytes): The ciphertext to decrypt.
            unpad (bool, optional): Whether to remove PKCS7 padding (default True).

        Returns:
            bytes: The decrypted plaintext.
        """
        validate_len_multiple("Ciphertext length", ciphertext)
        recovered = b"".join(
            map(self._aes.inv_cipher, self._chunk_blocks(ciphertext))
        )
        if not unpad:
            return recovered
        return PKCS7.unpad(recovered)
class AES_CBC(_AESMode):
    """AES in CBC (Cipher Block Chaining) mode.

    Each plaintext block is XORed with the previous ciphertext block (the IV
    for the first block) before it is encrypted. A 16-byte IV (initialization
    vector) is required, and messages that are not a multiple of the block
    length must be padded.

    Example:
        >>> key = b"0123456789abcdef"
        >>> iv = b"abcdef0123456789"
        >>> aes = AES_CBC(key)
        >>> ct = aes.encrypt(b"Secret Message", iv=iv)
        >>> aes.decrypt(ct, iv=iv)
        b'Secret Message'
    """

    def __init__(self, key: bytes):
        """Create a CBC cipher bound to ``key``."""
        super().__init__(key)

    # --- Encryption / Decryption ---
    @override
    def encrypt(self, plaintext: bytes, *, iv: bytes, pad: bool = True) -> bytes:
        """Encrypt plaintext using AES-CBC.

        Args:
            plaintext (bytes): The data to encrypt.
            iv (bytes): 16-byte initialization vector.
            pad (bool, optional): Whether to apply PKCS7 padding (default True).
                When False the plaintext length is validated instead.

        Returns:
            bytes: The ciphertext.
        """
        if not pad:
            validate_len_multiple("Plaintext length", plaintext)
        else:
            plaintext = PKCS7.pad(plaintext)
        validate_len("iv length", iv, 16)

        encrypt_block = self._aes.cipher
        output = bytearray()
        chain = iv  # value XORed into the next plaintext block
        for chunk in self._chunk_blocks(plaintext):
            chain = encrypt_block(xor_bytes(chunk, chain))
            output.extend(chain)
        return bytes(output)

    @override
    def decrypt(self, ciphertext: bytes, *, iv: bytes, unpad: bool = True) -> bytes:
        """Decrypt ciphertext using AES-CBC.

        Args:
            ciphertext (bytes): The ciphertext to decrypt.
            iv (bytes): 16-byte initialization vector used during encryption.
            unpad (bool, optional): Whether to remove PKCS7 padding (default True).

        Returns:
            bytes: The decrypted plaintext.
        """
        validate_len_multiple("Ciphertext length", ciphertext)
        validate_len("iv length", iv, 16)

        decrypt_block = self._aes.inv_cipher
        output = bytearray()
        chain = iv  # previous ciphertext block (IV for the first one)
        for chunk in self._chunk_blocks(ciphertext):
            output.extend(xor_bytes(decrypt_block(chunk), chain))
            chain = chunk
        recovered = bytes(output)
        return PKCS7.unpad(recovered) if unpad else recovered
class AES_CTR(_AESMode):
    """AES in CTR (Counter) mode.

    CTR mode turns AES into a stream cipher: a nonce is combined with a
    counter to produce a keystream that is XORed with the data. Encryption
    and decryption are therefore the same operation, and no padding is
    needed. Requires an 8-byte nonce.

    Example:
        >>> key = b"0123456789abcdef"
        >>> nonce = b"12345678"
        >>> aes = AES_CTR(key)
        >>> ct = aes.encrypt(b"Secret Message", nonce=nonce)
        >>> aes.decrypt(ct, nonce=nonce)
        b'Secret Message'
    """

    def __init__(self, key: bytes):
        """Create a CTR cipher bound to ``key``."""
        super().__init__(key)

    # --- Encryption / Decryption ---
    @override
    def encrypt(self, plaintext: bytes, *, nonce: bytes) -> bytes:
        """Encrypt plaintext using AES-CTR.

        Args:
            plaintext (bytes): Data to encrypt.
            nonce (bytes): 8-byte nonce for the counter block.

        Returns:
            bytes: Ciphertext.
        """
        return self._operate(plaintext, nonce)

    @override
    def decrypt(self, ciphertext: bytes, *, nonce: bytes) -> bytes:
        """Decrypt ciphertext using AES-CTR.

        Args:
            ciphertext (bytes): Data to decrypt.
            nonce (bytes): 8-byte nonce used during encryption.

        Returns:
            bytes: Decrypted plaintext.
        """
        return self._operate(ciphertext, nonce)

    # --- PRIVATE: Helper Function ---
    def _operate(self, data: bytes, nonce: bytes) -> bytes:
        """Run the shared CTR transform used by both encrypt and decrypt."""
        validate_len("nonce", nonce, 8)
        # Initial counter block: the 8-byte nonce followed by eight zero bytes.
        zero_counter = bytes(8)
        return self._ctr(data, nonce + zero_counter)
class AES_GCM(_AESMode):
    """AES in GCM (Galois/Counter Mode) with authentication.

    Provides both confidentiality and integrity. Requires a 12-byte nonce.
    Optional additional authenticated data (AAD) can be provided. Raises
    `AES_GCM.GCMAuthenticationError` if authentication fails.

    Example:
        >>> key = b"0123456789abcdef"
        >>> nonce = b"123456789012"
        >>> aes = AES_GCM(key)
        >>> ct, tag = aes.encrypt(b"Secret Message", nonce=nonce)
        >>> aes.decrypt(ct, nonce=nonce, tag=tag)
        b'Secret Message'
    """

    class GCMAuthenticationError(Exception):
        """Raised when GCM authentication fails."""

    # Reduction constant used by _gf_mul: 0xE1 followed by 120 zero bits.
    _R: Final[int] = 0xE1000000000000000000000000000000
    _MASK128: Final[int] = (1 << 128) - 1
    # Length in bytes of the authentication tag produced and accepted.
    _TAG_LENGTH: Final[int] = 16

    def __init__(self, key: bytes):
        """Create a GCM cipher bound to ``key`` and derive the hash subkey."""
        super().__init__(key)
        # Hash subkey H: the encryption of the all-zero block, kept as an int.
        self._H: Final[int] = int.from_bytes(self._aes.cipher(b"\x00" * 16), "big")

    # --- Encryption / Decryption ---
    @override
    def encrypt(  # pyright: ignore[reportIncompatibleMethodOverride]
        self, plaintext: bytes, *, nonce: bytes, aad: bytes = b""
    ) -> tuple[bytes, bytes]:
        """Encrypt and authenticate data using AES-GCM.

        Args:
            plaintext (bytes): Data to encrypt.
            nonce (bytes): 12-byte nonce.
            aad (bytes, optional): Additional authenticated data.

        Returns:
            tuple[bytes, bytes]: Ciphertext and 16-byte authentication tag.
        """
        return self._operate(plaintext, nonce, aad)

    @override
    def decrypt(
        self, ciphertext: bytes, *, nonce: bytes, tag: bytes, aad: bytes = b""
    ) -> bytes:
        """Decrypt and verify data using AES-GCM.

        Args:
            ciphertext (bytes): Ciphertext to decrypt.
            nonce (bytes): 12-byte nonce used during encryption.
            tag (bytes): 16-byte authentication tag from encryption.
            aad (bytes, optional): Additional authenticated data.

        Returns:
            bytes: Decrypted plaintext.

        Raises:
            AES_GCM.GCMAuthenticationError: If authentication tag verification fails.
        """
        validate_len("tag", tag, self._TAG_LENGTH)
        plaintext, computed_tag = self._operate(ciphertext, nonce, aad, mode="decrypt")
        # compare_digest is used for the tag comparison; the plaintext is
        # only returned after the tags match.
        if not compare_digest(tag, computed_tag):
            raise AES_GCM.GCMAuthenticationError("GCM Authentication tag mismatch")
        return plaintext

    # --- PRIVATE: Helper Functions ---
    def _operate(
        self,
        data: bytes,
        nonce: bytes,
        aad: bytes = b"",
        mode: Literal["encrypt", "decrypt"] = "encrypt",
    ) -> tuple[bytes, bytes]:
        """Run the CTR transform on ``data`` and compute the authentication tag.

        Args:
            data (bytes): Plaintext (encrypt) or ciphertext (decrypt).
            nonce (bytes): 12-byte nonce.
            aad (bytes, optional): Additional authenticated data.
            mode: ``"encrypt"`` or ``"decrypt"``; selects which side of the
                transform is the ciphertext that gets authenticated.

        Returns:
            tuple[bytes, bytes]: The transformed data and the tag computed
            over the AAD and the ciphertext.
        """
        validate_len("nonce", nonce, 12)
        # Pre-counter block: nonce followed by a 32-bit counter set to 1.
        precounter = nonce + b"\x00\x00\x00\x01"
        # The data keystream starts at the block after the pre-counter block.
        operated = self._ctr(data, inc_counter(precounter, 32))

        # The tag always covers the ciphertext: the output when encrypting,
        # the input when decrypting.
        ciphertext = operated if mode == "encrypt" else data

        # The closing GHASH block holds the AAD and ciphertext lengths as
        # 64-bit big-endian *bit* counts (NIST SP 800-38D), hence the "* 8".
        length_block = (len(aad) * 8).to_bytes(8, "big") + (
            len(ciphertext) * 8
        ).to_bytes(8, "big")
        hashed_data = self._ghash(pad16(aad) + pad16(ciphertext) + length_block)

        # Mask the hash with the keystream block of the pre-counter itself.
        tag = self._ctr(hashed_data, precounter)[: self._TAG_LENGTH]
        return operated, tag

    def _ghash(self, data: bytes) -> bytes:
        """Fold 16-byte blocks of ``data`` into one 16-byte hash using ``_H``.

        Starting from zero, each block is XORed into the accumulator, which is
        then multiplied by the hash subkey via ``_gf_mul``.
        """
        validate_len_multiple("Data length", data)
        y = 0
        for block in self._chunk_blocks(data):
            b = int.from_bytes(block, "big")
            y = self._gf_mul(y ^ b, self._H)
        return y.to_bytes(16, "big")

    @staticmethod
    def _gf_mul(x: int, y: int) -> int:
        """Multiply two 128-bit field elements using shift-and-reduce.

        Args:
            x (int): First operand, ``0 <= x < 2**128``.
            y (int): Second operand, ``0 <= y < 2**128``.

        Returns:
            int: The 128-bit product.

        Raises:
            ValueError: If either operand does not fit in 128 bits.
        """
        if x >> 128 or y >> 128:
            raise ValueError("Inputs must be 128-bit integers (0 <= value < 2**128)")
        z = 0
        v = x
        # Walk the bits of y from most significant to least significant.
        for i in range(128):
            if (y >> (127 - i)) & 1:
                z ^= v
            # Shift v right by one; if a set bit fell off, fold in _R.
            lsb = v & 1
            v >>= 1
            if lsb:
                v ^= AES_GCM._R
        return z & AES_GCM._MASK128