Files
hodu/rsa-aes-256-gcm.py

224 lines
4.9 KiB
Python

#!/usr/bin/env python3
import base64
import os
import sys
import time
from typing import NoReturn

from cryptography.exceptions import UnsupportedAlgorithm
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from cryptography.hazmat.primitives.serialization import load_pem_private_key
from cryptography.hazmat.primitives.serialization import load_pem_public_key
def fail(msg: str) -> NoReturn:
    """Write *msg* plus a newline to stderr and exit with status 1.

    Annotated ``NoReturn`` because ``sys.exit`` always raises ``SystemExit``;
    this lets type checkers see that code after a ``fail(...)`` call is
    unreachable (e.g. that a function ending in ``fail`` never returns None).
    """
    sys.stderr.write(msg + "\n")
    sys.exit(1)
def usage() -> NoReturn:
    """Write the command-line synopsis to stderr and exit with status 1.

    Delegates to ``fail``, so this never returns.
    """
    fail(
        "USAGE: rsa-aes-256-gcm.py encipher public-key-file text-to-encipher [ttl-seconds]\n"
        " rsa-aes-256-gcm.py decipher private-key-file [document]\n\n"
        "If document is omitted, stdin is used. ttl-seconds defaults to 30."
    )
def load_key_text(path: str) -> bytes:
    """Return the raw bytes of the key file at *path*.

    The file is opened in a ``with`` block so the handle is closed
    deterministically rather than whenever the file object is collected.
    Exits via ``fail`` if the file cannot be opened or read.
    """
    try:
        with open(path, "rb") as key_file:
            return key_file.read()
    except OSError:
        fail(f"unable to read {path}")
def load_private_key(path: str):
    """Load an unencrypted PEM private key from *path*.

    ``load_pem_private_key`` signals failure in three ways: ``ValueError``
    (undecodable PEM structure), ``TypeError`` (the key is encrypted but no
    password is supplied here), and ``UnsupportedAlgorithm`` (unsupported key
    type). All three are reported through ``fail`` instead of escaping as a
    traceback.
    """
    key_text = load_key_text(path)
    try:
        return load_pem_private_key(key_text, password=None)
    except (ValueError, TypeError, UnsupportedAlgorithm):
        fail(f"unable to load private key from {path}")
def load_public_key(path: str):
    """Load a public key from the PEM file at *path*.

    The file may hold either a PEM public key or an unencrypted PEM private
    key; in the latter case the private key's public half is returned. The
    key type is not checked here (``rsa_oaep_sha256_encrypt`` does that).
    Exits via ``fail`` if neither form can be parsed.
    """
    key_text = load_key_text(path)
    try:
        return load_pem_public_key(key_text)
    except (ValueError, UnsupportedAlgorithm):
        pass  # not a usable public-key PEM; try it as a private key below
    try:
        private_key = load_pem_private_key(key_text, password=None)
    except (ValueError, TypeError, UnsupportedAlgorithm):
        # TypeError: the private key is encrypted and no password is given.
        fail(f"unable to load public key from {path}")
    return private_key.public_key()
def rsa_oaep_sha256_encrypt(public_key, plaintext: bytes) -> bytes:
    """RSA-encrypt *plaintext* using OAEP with SHA-256 (hash and MGF1).

    Exits via ``fail`` when *public_key* is not an RSA public key, or when
    the encryption raises ``ValueError`` (e.g. plaintext too long for the key).
    """
    if not isinstance(public_key, rsa.RSAPublicKey):
        fail("unable to get rsa public key details")
    try:
        oaep_sha256 = padding.OAEP(
            mgf=padding.MGF1(algorithm=hashes.SHA256()),
            algorithm=hashes.SHA256(),
            label=None,
        )
        wrapped = public_key.encrypt(plaintext, oaep_sha256)
    except ValueError:
        fail("rsa encryption failed")
    return wrapped
def rsa_oaep_sha256_decrypt(private_key, ciphertext: bytes) -> bytes:
    """Reverse ``rsa_oaep_sha256_encrypt``: RSA-OAEP/SHA-256 decrypt *ciphertext*.

    Exits via ``fail`` when *private_key* is not an RSA private key, or when
    the decryption raises ``ValueError`` (wrong key or corrupted input).
    """
    if not isinstance(private_key, rsa.RSAPrivateKey):
        fail("unable to get rsa private key details")
    try:
        oaep_sha256 = padding.OAEP(
            mgf=padding.MGF1(algorithm=hashes.SHA256()),
            algorithm=hashes.SHA256(),
            label=None,
        )
        unwrapped = private_key.decrypt(ciphertext, oaep_sha256)
    except ValueError:
        fail("rsa decryption failed")
    return unwrapped
def encode_url_base64(data: bytes) -> str:
    """Encode *data* as URL-safe base64 text without trailing ``=`` padding."""
    encoded = base64.urlsafe_b64encode(data)
    return encoded.rstrip(b"=").decode("ascii")
def decode_url_base64(text: str) -> bytes:
    """Strictly decode unpadded URL-safe base64 *text* (as made by encode_url_base64).

    Raises ValueError for any malformed input, with the underlying decoder
    error chained as ``__cause__``.
    """
    # b64decode(altchars=b"-_") only translates '-'/'_' to '+'/'/'; it leaves
    # existing '+' and '/' alone, so without this check the standard alphabet
    # would be silently accepted by a decoder that is meant to be URL-safe.
    if "+" in text or "/" in text:
        raise ValueError("standard-alphabet base64 characters are not accepted")
    padded = text + "=" * (-len(text) % 4)
    try:
        # validate=True rejects anything outside the alphabet (incl. whitespace).
        return base64.b64decode(padded, altchars=b"-_", validate=True)
    except (TypeError, ValueError) as err:  # binascii.Error subclasses ValueError
        raise ValueError("invalid url-safe base64 text") from err
def encipher(public_key, text: str) -> str:
    """Seal *text* into a ``wrapped-key.nonce.ciphertext`` document.

    Each call draws a fresh 32-byte AES key and 12-byte nonce from
    ``os.urandom``, encrypts the UTF-8 bytes of *text* with AES-GCM (no
    associated data), and wraps the AES key with RSA-OAEP/SHA-256 under
    *public_key*. The three binary fields are emitted as unpadded URL-safe
    base64 joined with ".".
    """
    session_key = os.urandom(32)
    iv = os.urandom(12)
    sealed = AESGCM(session_key).encrypt(iv, text.encode("utf-8"), None)
    wrapped_key = rsa_oaep_sha256_encrypt(public_key, session_key)
    fields = (wrapped_key, iv, sealed)
    return ".".join(encode_url_base64(field) for field in fields)
def decipher(private_key, doc: str) -> str:
    """Open a ``wrapped-key.nonce.ciphertext`` document produced by ``encipher``.

    Surrounding whitespace is ignored. Each field is decoded as URL-safe
    base64; the AES key is unwrapped with *private_key* (RSA-OAEP/SHA-256)
    and the ciphertext is authenticated and decrypted with AES-GCM. Returns
    the plaintext as a str. Exits via ``fail`` on any malformed field, size
    mismatch, authentication failure, or non-UTF-8 plaintext.

    NOTE(review): ``encipher`` needs only the public key, so anyone holding it
    can mint a document that decrypts successfully here. The scheme provides
    confidentiality, not sender authentication.
    """
    # encipher's output is print()ed with a trailing newline and typically
    # arrives here via stdin; the strict base64 decoder would reject it.
    parts = doc.strip().split(".", 2)
    if len(parts) != 3:
        fail("invalid serialized token document")
    encoded_key, encoded_nonce, encoded_ciphertext = parts
    try:
        encrypted_key = decode_url_base64(encoded_key)
    except ValueError:
        fail("invalid encrypted key")
    try:
        nonce = decode_url_base64(encoded_nonce)
    except ValueError:
        fail("invalid nonce")
    try:
        ciphertext = decode_url_base64(encoded_ciphertext)
    except ValueError:
        fail("invalid ciphertext")
    if len(nonce) != 12:
        fail("invalid nonce size")
    if len(ciphertext) < 16:
        # Too short to even contain the 16-byte GCM authentication tag.
        fail("invalid ciphertext size")
    aes_key = rsa_oaep_sha256_decrypt(private_key, encrypted_key)
    if len(aes_key) != 32:
        # encipher always wraps a 32-byte (AES-256) key; reject anything else.
        fail("invalid key size")
    try:
        plaintext = AESGCM(aes_key).decrypt(nonce, ciphertext, None)
    except Exception:
        fail("aes-gcm decryption failed")
    try:
        return plaintext.decode("utf-8")
    except UnicodeDecodeError:
        fail("invalid plaintext encoding")
def make_token_payload(token: str, now: int, ttl: int) -> str:
    """Build the ``token|issued|expires`` payload that gets enciphered.

    *token* is UTF-8 encoded and written as unpadded URL-safe base64, so any
    ``|`` inside it cannot collide with the field separators. *issued* is
    *now*, and *expires* is ``now + ttl``.
    """
    encoded_token = base64.urlsafe_b64encode(token.encode("utf-8")).rstrip(b"=").decode("ascii")
    expires_at = now + ttl
    return "|".join((encoded_token, str(now), str(expires_at)))
def parse_token_payload(payload: str) -> str:
    """Decode a ``token|issued|expires`` payload built by ``make_token_payload``.

    The token field is strictly decoded as unpadded URL-safe base64 of UTF-8
    text. The two timestamp fields must parse as integers and are echoed back
    unchanged. Returns ``token|issued|expires`` with the token decoded. Exits
    via ``fail`` on any malformed field.
    """
    parts = payload.split("|", 2)
    if len(parts) != 3:
        fail("invalid protected token payload")
    encoded_token, issued, expires = parts
    try:
        # Strict decoder: urlsafe_b64decode would silently drop stray characters.
        # UnicodeDecodeError is a ValueError subclass, so it is caught here too.
        token = decode_url_base64(encoded_token).decode("utf-8")
    except ValueError:
        fail("invalid protected token text")
    try:
        int(issued)
        int(expires)
    except ValueError:
        fail("invalid protected token timestamps")
    return token + "|" + issued + "|" + expires
def main() -> None:
    """Command-line entry point: dispatch on the mode given in ``sys.argv[1]``.

    ``encipher``: wrap argv[3] (with an optional TTL in argv[4], default 30
    seconds) into a timestamped payload and print the enciphered document.
    ``decipher``: read the document from argv[3] or stdin, decipher it, and
    print the decoded ``token|issued|expires`` payload. Any other mode, or
    fewer than two arguments, prints usage and exits.
    """
    argv = sys.argv
    if len(argv) < 3:
        usage()
    mode, key_file = argv[1], argv[2]

    if mode == "encipher":
        if len(argv) < 4:
            fail("missing token")
        token_text = argv[3]
        ttl_seconds = 30
        if len(argv) >= 5:
            try:
                ttl_seconds = int(argv[4])
            except ValueError:
                fail("invalid ttl-seconds")
        issued_at = int(time.time())
        recipient_key = load_public_key(key_file)
        print(encipher(recipient_key, make_token_payload(token_text, issued_at, ttl_seconds)))
        return

    if mode == "decipher":
        document = argv[3] if len(argv) >= 4 else sys.stdin.read()
        own_key = load_private_key(key_file)
        print(parse_token_payload(decipher(own_key, document)))
        return

    usage()
if __name__ == "__main__":
    # main() returns None, so this exits with status 0 on success; failures
    # exit earlier through fail()/usage().
    sys.exit(main())