# Source code for pycrypt.asymmetric.rsa.core

from typing import Literal, Self

from egcd import egcd

from pycrypt.asymmetric.rsa.keyformat import (
    pem_to_priv_key,
    pem_to_pub_key,
    priv_key_to_pem,
    pub_key_to_pem,
)
from pycrypt.asymmetric.rsa.utils import (
    generate_large_prime,
    oaep_decode,
    oaep_encode,
    pss_encode,
    pss_verify,
)
from pycrypt.hash import SHA256


[docs] class RSAKey: """Represents an RSA key pair supporting encryption, decryption, signing, and verification. This class provides both low-level RSA primitives and high-level operations using OAEP (Optimal Asymmetric Encryption Padding) for encryption and PSS (Probabilistic Signature Scheme) for signing, along with key generation and PEM serialization. Example: >>> key = RSAKey.generate(2048) >>> ciphertext = key.oaep_encrypt(b"hello") >>> plaintext = key.oaep_decrypt(ciphertext) >>> assert plaintext == b"hello" Attributes: n (int): RSA modulus (product of two large primes, `p` and `q`). e (int): RSA public exponent. d (int | None): RSA private exponent (None for public-only keys). p (int | None): First prime factor of `n` (None for public-only keys). q (int | None): Second prime factor of `n` (None for public-only keys). dP (int | None): d mod (p−1), used for CRT optimization. dQ (int | None): d mod (q−1), used for CRT optimization. qInv (int | None): Multiplicative inverse of q mod p. k (int): Key length in bytes. """ def __init__( self, n: int, e: int, d: int | None = None, p: int | None = None, q: int | None = None, ) -> None: """Initializes an RSAKey instance. Args: n (int): RSA modulus. e (int): RSA public exponent. d (int | None, optional): Private exponent. Defaults to None. p (int | None, optional): Prime factor p. Defaults to None. q (int | None, optional): Prime factor q. Defaults to None. Raises: ValueError: If invalid parameters are provided. 
""" self.n: int = n self.e: int = e self.d: int | None = d self.p: int | None = p self.q: int | None = q if all(param is not None for param in (d, p, q)): self.qInv, self.dP, self.dQ = self._precompute_crt(self.d, self.p, self.q) else: self.qInv = self.dP = self.dQ = None self.k: int = (self.n.bit_length() + 7) // 8 @property def PUBLIC_KEY(self) -> tuple[int, int]: """tuple[int, int]: Returns the public key (n, e).""" return self.n, self.e @property def PRIVATE_KEY(self) -> tuple[int, int, int | None, ...]: """tuple[int, int, int | None, ...]: Returns the full private key components (n, e, d, p, q, dP, dQ, qInv). """ return self.n, self.e, self.d, self.p, self.q, self.dP, self.dQ, self.qInv
[docs] def primitive_encrypt(self, message: int) -> int: """Performs raw RSA encryption (modular exponentiation). Args: message (int): The message as an integer. Returns: int: The ciphertext integer. """ return pow(message, self.e, self.n)
[docs] def primitive_decrypt(self, ciphertext: int) -> int: """Performs raw RSA decryption, optionally using CRT optimization. Args: ciphertext (int): The ciphertext as an integer. Returns: int: The decrypted message as an integer. Raises: TypeError: If the private exponent is missing. """ if self.p and self.q: m1 = pow(ciphertext % self.p, self.dP, self.p) m2 = pow(ciphertext % self.q, self.dQ, self.q) h = (m1 - m2) * self.qInv % self.p m = m2 + h * self.q return m % self.n else: if self.d is None: raise TypeError( "Private exponent missing: cannot decrypt/sign with public-only key" ) return pow(ciphertext, self.d, self.n)
[docs] def primitive_sign(self, message: int) -> int: """Performs raw RSA signing (equivalent to decryption). Args: message (int): The message as an integer. Returns: int: The signature as an integer. """ return self.primitive_decrypt(message)
[docs] def primitive_verify(self, signature: int) -> int: """Performs raw RSA signature verification (equivalent to encryption). Args: signature (int): The signature as an integer. Returns: int: The verified message integer. """ return self.primitive_encrypt(signature)
[docs] def oaep_encrypt( self, message: bytes, label: bytes = b"", hash: type = SHA256 ) -> bytes: """Encrypts a message using RSA with OAEP padding. Args: message (bytes): The plaintext message. label (bytes, optional): Optional label for OAEP encoding. Defaults to b"". hash (type, optional): Hash function class used in OAEP. Defaults to SHA256. Returns: bytes: The ciphertext. Raises: ValueError: If message length exceeds maximum allowed size. """ em = oaep_encode(message, self.k, label, hash) m = self._os2ip(em) c = self.primitive_encrypt(m) ciphertext = self._i2osp(c, self.k) return ciphertext
[docs] def oaep_decrypt( self, ciphertext: bytes, label: bytes = b"", hash: type = SHA256 ) -> bytes: """Decrypts a message encrypted with RSA-OAEP. Args: ciphertext (bytes): The ciphertext to decrypt. label (bytes, optional): Label used during encryption. Defaults to b"". hash (type, optional): Hash function used during encryption. Defaults to SHA256. Returns: bytes: The decrypted plaintext. Raises: ValueError: If ciphertext length is invalid or decryption fails. TypeError: If private key is missing. """ if len(ciphertext) != self.k: raise ValueError("Decryption Error: ciphertext length mismatch") c = self._os2ip(ciphertext) m = self.primitive_decrypt(c) em = self._i2osp(m, self.k) plaintext = oaep_decode(em, self.k, label, hash) return plaintext
[docs] def pss_sign( self, message: bytes, slen: int | None = None, hash: type = SHA256 ) -> bytes: """Generates a digital signature using RSA-PSS. Args: message (bytes): The message to sign. slen (int | None, optional): Salt length. Defaults to hash length. hash (type, optional): Hash function used for PSS. Defaults to SHA256. Returns: bytes: The signature. """ em = pss_encode(message, self.k - 1, slen, hash) m = self._os2ip(em) s = self.primitive_sign(m) signature = self._i2osp(s, self.k) return signature
[docs] def pss_verify( self, message: bytes, signature: bytes, slen: int | None = None, hash: type = SHA256, ) -> bool: """Verifies a signature created with RSA-PSS. Args: message (bytes): The original message. signature (bytes): The RSA-PSS signature. slen (int | None, optional): Salt length. Defaults to hash length. hash (type, optional): Hash function used for PSS. Defaults to SHA256. Returns: bool: True if the signature is valid, False otherwise. """ if len(signature) != self.k: return False s = self._os2ip(signature) m = self.primitive_verify(s) em = self._i2osp(m, self.k) return pss_verify(message, em[1:], slen, hash)
[docs] def export_key(self, type: Literal["public", "private"] = "public") -> str: """Exports the RSA key in PEM format. Args: type (Literal["public", "private"], optional): Key type to export. Must be either "public" or "private". Defaults to "public". Returns: str: The PEM-encoded RSA key. Raises: ValueError: If type is invalid. TypeError: If private key components are missing for private export. """ if type not in ("public", "private"): raise ValueError("type must be either 'public' or 'private'") if type == "public": pem = pub_key_to_pem(*self.PUBLIC_KEY) elif type == "private": if self.d is None: raise TypeError( "Private exponent missing: cannot export private key with public-only key" ) pem = priv_key_to_pem(*self.PRIVATE_KEY) return pem
[docs] @classmethod def import_key(cls, pem: str) -> Self: """Imports an RSA key from a PEM-formatted string. Args: pem (str): The PEM-encoded RSA key. Returns: RSAKey: An RSAKey instance initialized with the imported key. Raises: ValueError: If the PEM cannot be parsed as a valid RSA key. """ try: key = pem_to_pub_key(pem) return cls(key["n"], key["e"]) except Exception: try: key = pem_to_priv_key(pem) return cls(key["n"], key["e"], key["d"], key["p"], key["q"]) except Exception: raise ValueError( "Could not parse PEM as a valid RSA public or private key" )
[docs] @classmethod def generate(cls, bits: int = 2048, e: int = 65537) -> Self: """Generates a new RSA key pair. Args: bits (int, optional): Key size in bits. Defaults to 2048. e (int, optional): Public exponent. Defaults to 65537. Returns: RSAKey: A new RSA key pair. """ half = bits // 2 while True: p = generate_large_prime(half) q = generate_large_prime(bits - half) if p == q: continue n = p * q phi = (p - 1) * (q - 1) gcd, d, _ = egcd(e, phi) assert abs(gcd) == 1 if d < 0: d += phi return cls(n, e, d, p, q)
@staticmethod def _precompute_crt(d: int, p: int, q: int) -> tuple[int, int, int]: dP = d % (p - 1) dQ = d % (q - 1) _, qInv, _ = egcd(q, p) qInv %= p return qInv, dP, dQ @staticmethod def _os2ip(b: bytes) -> int: return int.from_bytes(b, "big") @staticmethod def _i2osp(x: int, x_len: int) -> bytes: return x.to_bytes(x_len, "big")