224 lines
4.9 KiB
Python
224 lines
4.9 KiB
Python
#!/usr/bin/env python3
|
|
|
|
import base64
import os
import sys
import time
from typing import NoReturn

from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from cryptography.hazmat.primitives.serialization import load_pem_private_key
from cryptography.hazmat.primitives.serialization import load_pem_public_key
|
|
|
|
|
|
def fail(msg: str) -> NoReturn:
    """Write *msg* plus a newline to stderr and exit with status 1.

    Never returns (``sys.exit`` raises ``SystemExit``). Annotating it as
    ``NoReturn`` lets type checkers see that code after a ``fail(...)`` call
    in an error branch is unreachable.
    """
    sys.stderr.write(msg + "\n")
    sys.exit(1)
|
|
|
|
|
|
def usage() -> NoReturn:
    """Print the command-line usage text to stderr and exit with status 1.

    Delegates to ``fail``, so this function never returns.
    """
    fail(
        "USAGE: rsa-aes-256-gcm.py encipher public-key-file text-to-encipher [ttl-seconds]\n"
        " rsa-aes-256-gcm.py decipher private-key-file [document]\n\n"
        "If document is omitted, stdin is used. ttl-seconds defaults to 30."
    )
|
|
|
|
|
|
def load_key_text(path: str) -> bytes:
    """Return the raw bytes of the key file at *path*.

    Exits via ``fail`` with "unable to read <path>" if the file cannot be
    opened or read (any ``OSError``).
    """
    try:
        # The context manager closes the handle deterministically rather than
        # leaving it open until garbage collection.
        with open(path, "rb") as key_file:
            return key_file.read()
    except OSError:
        fail(f"unable to read {path}")
|
|
|
|
|
|
def load_private_key(path: str):
    """Load an unencrypted PEM private key from *path*.

    Exits via ``fail`` if the file cannot be read, or with
    "unable to load private key from <path>" if its contents cannot be loaded
    as a PEM private key without a password.
    """
    try:
        return load_pem_private_key(load_key_text(path), password=None)
    except (ValueError, TypeError):
        # ValueError: malformed / undecodable PEM data.
        # TypeError: the key is password-protected but no password is supplied
        # here; previously this escaped as an unhandled traceback.
        fail(f"unable to load private key from {path}")
|
|
|
|
|
|
def load_public_key(path: str):
    """Load a public key from the PEM file at *path*.

    The file may contain either a PEM public key or an unencrypted PEM private
    key. In the latter case the matching public key is derived from it.
    Exits via ``fail`` if the file cannot be read or neither form loads.
    """
    key_text = load_key_text(path)

    try:
        return load_pem_public_key(key_text)
    except ValueError:
        pass  # Not a public key; try interpreting it as a private key below.

    try:
        private_key = load_pem_private_key(key_text, password=None)
    except (ValueError, TypeError):
        # TypeError: the file is an encrypted private key and no password is
        # supplied; report it like any other unusable key file.
        fail(f"unable to load public key from {path}")

    return private_key.public_key()
|
|
|
|
|
|
def rsa_oaep_sha256_encrypt(public_key, plaintext: bytes) -> bytes:
    """Encrypt *plaintext* with RSA-OAEP, using SHA-256 for both OAEP and MGF1.

    Exits via ``fail`` if *public_key* is not an RSA public key, or if the
    encryption call raises ``ValueError``.
    """
    if isinstance(public_key, rsa.RSAPublicKey):
        try:
            sha256_oaep = padding.OAEP(
                mgf=padding.MGF1(algorithm=hashes.SHA256()),
                algorithm=hashes.SHA256(),
                label=None,
            )
            return public_key.encrypt(plaintext, sha256_oaep)
        except ValueError:
            fail("rsa encryption failed")
    else:
        fail("unable to get rsa public key details")
|
|
|
|
|
|
def rsa_oaep_sha256_decrypt(private_key, ciphertext: bytes) -> bytes:
    """Decrypt an RSA-OAEP *ciphertext* (SHA-256 for both OAEP and MGF1).

    Exits via ``fail`` if *private_key* is not an RSA private key, or if the
    decryption call raises ``ValueError``.
    """
    if isinstance(private_key, rsa.RSAPrivateKey):
        try:
            sha256_oaep = padding.OAEP(
                mgf=padding.MGF1(algorithm=hashes.SHA256()),
                algorithm=hashes.SHA256(),
                label=None,
            )
            return private_key.decrypt(ciphertext, sha256_oaep)
        except ValueError:
            fail("rsa decryption failed")
    else:
        fail("unable to get rsa private key details")
|
|
|
|
|
|
def encode_url_base64(data: bytes) -> str:
    """Return *data* as base64url text with the trailing ``=`` padding removed."""
    encoded = base64.urlsafe_b64encode(data)
    return encoded.rstrip(b"=").decode("ascii")
|
|
|
|
|
|
def decode_url_base64(text: str) -> bytes:
    """Decode base64url *text* whose trailing ``=`` padding may be omitted.

    Raises:
        ValueError: if *text* is not valid base64url data. The underlying
            decoder error is chained as ``__cause__`` for debugging.
    """
    # Restore the padding that encode_url_base64 strips, so the strict
    # (validate=True) decoder accepts the input.
    padded = text + "=" * ((-len(text)) % 4)
    try:
        return base64.b64decode(padded, altchars=b"-_", validate=True)
    except ValueError as err:
        # binascii.Error (bad alphabet or length) and the "non-ASCII str"
        # error are both ValueError subclasses; nothing else is expected here.
        raise ValueError("invalid base64url data") from err
|
|
|
|
|
|
def encipher(public_key, text: str) -> str:
    """Hybrid-encrypt *text* and return a three-segment dotted token.

    A fresh random 32-byte AES key and 12-byte nonce are generated. The UTF-8
    encoded *text* is sealed with AES-GCM (no associated data), and the AES
    key is wrapped with RSA-OAEP-SHA256 via ``rsa_oaep_sha256_encrypt``.

    Returns:
        ``base64url(wrapped_key) + "." + base64url(nonce) + "." +
        base64url(aes_gcm_output)``, each segment unpadded.
    """
    session_key = os.urandom(32)
    nonce = os.urandom(12)
    sealed = AESGCM(session_key).encrypt(nonce, text.encode("utf-8"), None)
    wrapped_key = rsa_oaep_sha256_encrypt(public_key, session_key)

    segments = (wrapped_key, nonce, sealed)
    return ".".join(encode_url_base64(segment) for segment in segments)
|
|
|
|
|
|
def _decode_token_part(text: str, what: str) -> bytes:
    """Base64url-decode one token segment, exiting with "invalid <what>" on error."""
    try:
        return decode_url_base64(text)
    except ValueError:
        fail(f"invalid {what}")


def decipher(private_key, doc: str) -> str:
    """Reverse ``encipher``: unwrap the AES key and open the AES-GCM payload.

    *doc* is ``base64url(wrapped_key).base64url(nonce).base64url(ciphertext)``.
    Surrounding whitespace is ignored, e.g. the trailing newline present when
    the token is piped through stdin or copied from ``print`` output.

    Returns the decrypted plaintext decoded as UTF-8. Exits via ``fail`` on a
    malformed document, bad segment encoding or size, RSA/AES-GCM decryption
    failure, or non-UTF-8 plaintext.

    NOTE(review): the AES key is wrapped with the *public* key, so anyone
    holding it can mint a token that deciphers successfully. This provides
    confidentiality, not proof of origin.
    """
    # Strict base64 decoding rejects "\n", so without strip() a token read
    # from stdin always failed with "invalid ciphertext".
    parts = doc.strip().split(".", 2)
    if len(parts) != 3:
        fail("invalid serialized token document")

    encrypted_key = _decode_token_part(parts[0], "encrypted key")
    nonce = _decode_token_part(parts[1], "nonce")
    ciphertext = _decode_token_part(parts[2], "ciphertext")

    # encipher always uses a 12-byte nonce; AES-GCM output is at least the
    # 16-byte authentication tag, even for an empty plaintext.
    if len(nonce) != 12:
        fail("invalid nonce size")
    if len(ciphertext) < 16:
        fail("invalid ciphertext size")

    aes_key = rsa_oaep_sha256_decrypt(private_key, encrypted_key)

    try:
        plaintext = AESGCM(aes_key).decrypt(nonce, ciphertext, None)
    except Exception:
        # Covers authentication-tag failure and an unwrapped key of invalid
        # length. The specific exception types come from the cryptography
        # package and are not imported here.
        fail("aes-gcm decryption failed")

    try:
        return plaintext.decode("utf-8")
    except UnicodeDecodeError:
        fail("invalid plaintext encoding")
|
|
|
|
|
|
def make_token_payload(token: str, now: int, ttl: int) -> str:
    """Build the ``b64url(token)|issued_at|expires_at`` plaintext payload.

    *token* is UTF-8 encoded, then base64url-encoded without padding. That
    alphabet contains no ``|``, so the separators stay unambiguous.
    ``expires_at`` is ``now + ttl``.
    """
    token_b64 = base64.urlsafe_b64encode(token.encode("utf-8")).rstrip(b"=").decode("ascii")
    expires_at = now + ttl
    return f"{token_b64}|{now}|{expires_at}"
|
|
|
|
|
|
def parse_token_payload(payload: str) -> str:
    """Decode the token field of a ``b64url(token)|issued_at|expires_at`` payload.

    Returns ``token|issued_at|expires_at`` with the token restored to text.
    The two timestamp fields are passed through unchanged; no expiry check is
    performed here. Exits via ``fail`` if the payload has fewer than three
    ``|``-separated fields, or if the token field is not valid unpadded
    base64url-encoded UTF-8.
    """
    parts = payload.split("|", 2)
    if len(parts) != 3:
        fail("invalid protected token payload")

    encoded_token, issued_at, expires_at = parts
    padded = encoded_token + "=" * ((-len(encoded_token)) % 4)
    try:
        # validate=True rejects stray characters instead of silently dropping
        # them; ValueError covers both binascii.Error and UnicodeDecodeError.
        token = base64.b64decode(padded, altchars=b"-_", validate=True).decode("utf-8")
    except ValueError:
        fail("invalid protected token text")

    return f"{token}|{issued_at}|{expires_at}"
|
|
|
|
|
|
def main() -> None:
    """Command-line entry point; see ``usage`` for the accepted arguments.

    ``encipher``: wraps text-to-encipher with the current Unix time and its
    expiry (now + ttl-seconds, default 30) and prints the resulting token.
    ``decipher``: reads the token from argv[3] or stdin and prints
    ``token|issued_at|expires_at``.
    Any other mode, or too few arguments, prints usage and exits with status 1.
    """
    if len(sys.argv) < 3:
        usage()

    mode = sys.argv[1]
    key_file = sys.argv[2]

    if mode == "encipher":
        if len(sys.argv) < 4:
            fail("missing token")
        input_text = sys.argv[3]

        ttl = 30
        if len(sys.argv) >= 5:
            try:
                ttl = int(sys.argv[4])
            except ValueError:
                fail("invalid ttl-seconds")
            # A negative TTL would produce a token that expires before it is
            # issued; reject it like any other unusable value.
            if ttl < 0:
                fail("invalid ttl-seconds")

        now = int(time.time())
        public_key = load_public_key(key_file)
        print(encipher(public_key, make_token_payload(input_text, now, ttl)))
        return

    if mode == "decipher":
        if len(sys.argv) >= 4:
            input_text = sys.argv[3]
        else:
            input_text = sys.stdin.read()
        private_key = load_private_key(key_file)
        print(parse_token_payload(decipher(private_key, input_text)))
        return

    usage()
|
|
|
|
|
|
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|