#!/usr/bin/env python3
"""Encipher and decipher short-lived tokens with RSA-OAEP and AES-GCM.

A serialized document has three dot-separated fields, each encoded as
URL-safe base64 without padding::

    <rsa-encrypted-aes-key>.<nonce>.<aes-gcm-ciphertext>

The encrypted plaintext is a token payload of the form::

    <base64(token)>|<issued-at>|<issued-at + ttl>

Every error is reported on stderr and terminates the process with exit
status 1.
"""

import base64
import os
import sys
import time
from typing import NoReturn

from cryptography.exceptions import InvalidTag
from cryptography.exceptions import UnsupportedAlgorithm
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from cryptography.hazmat.primitives.serialization import load_pem_private_key
from cryptography.hazmat.primitives.serialization import load_pem_public_key


def fail(msg: str) -> NoReturn:
    """Write *msg* to stderr and exit with status 1."""
    sys.stderr.write(msg + "\n")
    sys.exit(1)


def usage() -> NoReturn:
    """Print the command-line synopsis to stderr and exit with status 1."""
    fail(
        "USAGE: rsa-aes-256-gcm.py encipher public-key-file text-to-encipher [ttl-seconds]\n"
        " rsa-aes-256-gcm.py decipher private-key-file [document]\n\n"
        "If document is omitted, stdin is used. ttl-seconds defaults to 30."
    )


def load_key_text(path: str) -> bytes:
    """Return the raw bytes of the key file at *path*, or exit on I/O error."""
    try:
        # The context manager closes the handle even if read() raises.
        with open(path, "rb") as handle:
            return handle.read()
    except OSError:
        fail(f"unable to read {path}")


def load_private_key(path: str):
    """Load an unencrypted PEM private key from *path*, or exit on failure."""
    key_text = load_key_text(path)
    try:
        return load_pem_private_key(key_text, password=None)
    except (ValueError, TypeError, UnsupportedAlgorithm):
        # TypeError is raised for a password-protected key when no password
        # is supplied; report it like any other unusable key.
        fail(f"unable to load private key from {path}")


def load_public_key(path: str):
    """Load a public key from the PEM file at *path*, or exit on failure.

    The file may hold either a public key or an unencrypted private key;
    in the latter case the matching public key is derived from it.
    """
    key_text = load_key_text(path)
    try:
        return load_pem_public_key(key_text)
    except (ValueError, UnsupportedAlgorithm):
        pass  # Not a public key; fall back to treating it as a private key.
    try:
        private_key = load_pem_private_key(key_text, password=None)
    except (ValueError, TypeError, UnsupportedAlgorithm):
        fail(f"unable to load public key from {path}")
    return private_key.public_key()


def _oaep_sha256() -> padding.OAEP:
    """Return the OAEP parameters (SHA-256, MGF1-SHA-256, no label)."""
    return padding.OAEP(
        mgf=padding.MGF1(algorithm=hashes.SHA256()),
        algorithm=hashes.SHA256(),
        label=None,
    )


def rsa_oaep_sha256_encrypt(public_key, plaintext: bytes) -> bytes:
    """Encrypt *plaintext* with RSA-OAEP-SHA256; exit if the key is not RSA."""
    if not isinstance(public_key, rsa.RSAPublicKey):
        fail("unable to get rsa public key details")
    try:
        return public_key.encrypt(plaintext, _oaep_sha256())
    except ValueError:
        fail("rsa encryption failed")


def rsa_oaep_sha256_decrypt(private_key, ciphertext: bytes) -> bytes:
    """Decrypt *ciphertext* with RSA-OAEP-SHA256; exit if the key is not RSA."""
    if not isinstance(private_key, rsa.RSAPrivateKey):
        fail("unable to get rsa private key details")
    try:
        return private_key.decrypt(ciphertext, _oaep_sha256())
    except ValueError:
        fail("rsa decryption failed")


def encode_url_base64(data: bytes) -> str:
    """Encode *data* as URL-safe base64 with the ``=`` padding removed."""
    return base64.urlsafe_b64encode(data).decode("ascii").rstrip("=")


def decode_url_base64(text: str) -> bytes:
    """Decode URL-safe base64 *text* that may lack its ``=`` padding.

    Raises:
        ValueError: if *text* is not valid base64.
    """
    padding_len = (-len(text)) % 4
    try:
        return base64.b64decode(
            text + ("=" * padding_len), altchars=b"-_", validate=True
        )
    except (ValueError, TypeError) as exc:
        # binascii.Error is a ValueError subclass; normalise every decode
        # failure to a plain ValueError for callers.
        raise ValueError("invalid url-safe base64") from exc


def encipher(public_key, text: str) -> str:
    """Encrypt *text* and return the serialized three-field document.

    A fresh random 32-byte AES key and 12-byte nonce are generated per call;
    the AES key is wrapped with *public_key* using RSA-OAEP-SHA256.
    """
    aes_key = os.urandom(32)
    nonce = os.urandom(12)
    ciphertext = AESGCM(aes_key).encrypt(nonce, text.encode("utf-8"), None)
    encrypted_key = rsa_oaep_sha256_encrypt(public_key, aes_key)
    return ".".join(
        encode_url_base64(field) for field in (encrypted_key, nonce, ciphertext)
    )


def decipher(private_key, doc: str) -> str:
    """Decrypt the serialized document *doc* and return its plaintext.

    Surrounding whitespace is ignored, so a document piped through stdin
    with a trailing newline is accepted. Any malformed field or failed
    decryption terminates the process via ``fail``.
    """
    parts = doc.strip().split(".", 2)
    if len(parts) != 3:
        fail("invalid serialized token document")
    try:
        encrypted_key = decode_url_base64(parts[0])
    except ValueError:
        fail("invalid encrypted key")
    try:
        nonce = decode_url_base64(parts[1])
    except ValueError:
        fail("invalid nonce")
    try:
        ciphertext = decode_url_base64(parts[2])
    except ValueError:
        fail("invalid ciphertext")
    if len(nonce) != 12:
        fail("invalid nonce size")
    # An AES-GCM ciphertext carries at least the 16-byte authentication tag.
    if len(ciphertext) < 16:
        fail("invalid ciphertext size")
    aes_key = rsa_oaep_sha256_decrypt(private_key, encrypted_key)
    try:
        plaintext = AESGCM(aes_key).decrypt(nonce, ciphertext, None)
    except (InvalidTag, ValueError):
        # ValueError covers an unwrapped key of an unsupported length.
        fail("aes-gcm decryption failed")
    try:
        return plaintext.decode("utf-8")
    except UnicodeDecodeError:
        fail("invalid plaintext encoding")


def make_token_payload(token: str, now: int, ttl: int) -> str:
    """Return ``<base64(token)>|<now>|<now + ttl>`` for *token*."""
    encoded_token = encode_url_base64(token.encode("utf-8"))
    return f"{encoded_token}|{now}|{now + ttl}"


def parse_token_payload(payload: str) -> str:
    """Return ``<token>|<issued>|<expires>`` with the token base64-decoded.

    The two timestamp fields are passed through unchanged; they are not
    parsed or compared against the current time here.
    """
    parts = payload.split("|", 2)
    if len(parts) != 3:
        fail("invalid protected token payload")
    try:
        token = decode_url_base64(parts[0]).decode("utf-8")
    except ValueError:
        # Covers both malformed base64 and non-UTF-8 token bytes
        # (UnicodeDecodeError is a ValueError subclass).
        fail("invalid protected token text")
    return token + "|" + parts[1] + "|" + parts[2]


def main() -> None:
    """Dispatch the ``encipher`` / ``decipher`` command-line modes."""
    if len(sys.argv) < 3:
        usage()
    mode, key_file = sys.argv[1], sys.argv[2]

    if mode == "encipher":
        if len(sys.argv) < 4:
            fail("missing token")
        token = sys.argv[3]
        ttl = 30
        if len(sys.argv) >= 5:
            try:
                ttl = int(sys.argv[4])
            except ValueError:
                fail("invalid ttl-seconds")
        now = int(time.time())
        public_key = load_public_key(key_file)
        print(encipher(public_key, make_token_payload(token, now, ttl)))
        return

    if mode == "decipher":
        document = sys.argv[3] if len(sys.argv) >= 4 else sys.stdin.read()
        private_key = load_private_key(key_file)
        print(parse_token_payload(decipher(private_key, document)))
        return

    usage()


if __name__ == "__main__":
    main()