Files
hodu/rsa-aes.go
2026-04-17 00:58:05 +09:00

177 lines
4.6 KiB
Go

package hodu
import "crypto/aes"
import "crypto/cipher"
import "crypto/rand"
import "crypto/rsa"
import "crypto/sha256"
import "encoding/base64"
import "fmt"
import "strconv"
import "strings"
import "time"
// currently, it supports rsa-aes-128-gcm only.
type RSAAES struct {
key *rsa.PrivateKey
}
type RSAAESToken struct {
Token string
IssuedAt time.Time
ExpiresAt time.Time
}
// encode_url_base64 renders data with the URL-safe base64 alphabet and no '='
// padding, so the output contains neither '.' nor '|' and can sit inside a
// '.'-separated document.
func encode_url_base64(data []byte) string {
	codec := base64.RawURLEncoding
	encoded := codec.EncodeToString(data)
	return encoded
}
// decode_url_base64 is the inverse of encode_url_base64. It accepts only the
// URL-safe alphabet without '=' padding and returns an error for anything else.
func decode_url_base64(text string) ([]byte, error) {
	var codec = base64.RawURLEncoding
	decoded, err := codec.DecodeString(text)
	return decoded, err
}
// NewRSAAES returns an RSAAES bound to key. The key is stored as-is. A nil key
// is not rejected here; Encipher and Decipher report "missing rsa key" when
// they are called.
func NewRSAAES(key *rsa.PrivateKey) *RSAAES {
	inst := new(RSAAES)
	inst.key = key
	return inst
}
// Encipher seals data into a three-part, '.'-separated document:
//
//	base64url(RSA-OAEP-SHA256(aesKey)) . base64url(nonce) . base64url(AES-GCM(data))
//
// Every call generates a fresh random 32-byte AES key and a fresh random
// nonce, so enciphering the same data twice gives different documents. All
// parts use unpadded URL-safe base64. An error is returned if the receiver
// has no RSA key or if any random read or cipher operation fails.
func (e *RSAAES) Encipher(data []byte) (string, error) {
	if e.key == nil {
		return "", fmt.Errorf("missing rsa key")
	}

	sessionKey := make([]byte, 32)
	if _, err := rand.Read(sessionKey); err != nil {
		return "", err
	}

	blk, err := aes.NewCipher(sessionKey)
	if err != nil {
		return "", err
	}
	aead, err := cipher.NewGCM(blk)
	if err != nil {
		return "", err
	}

	iv := make([]byte, aead.NonceSize())
	if _, err = rand.Read(iv); err != nil {
		return "", err
	}
	sealed := aead.Seal(nil, iv, data, nil)

	// Only the public half of the key is needed to wrap the session key.
	wrapped, err := rsa.EncryptOAEP(sha256.New(), rand.Reader, &e.key.PublicKey, sessionKey, nil)
	if err != nil {
		return "", err
	}

	const sep = "."
	return encode_url_base64(wrapped) + sep + encode_url_base64(iv) + sep + encode_url_base64(sealed), nil
}
// Decipher reverses Encipher. doc must consist of exactly three
// '.'-separated parts, each in unpadded URL-safe base64:
//
//	wrapped AES key . GCM nonce . GCM ciphertext (including the tag)
//
// The AES key is unwrapped with RSA-OAEP (SHA-256, no label) using the
// private key. The ciphertext is then authenticated and decrypted with
// AES-GCM. A malformed part, a wrong key or a tampered byte produces an error.
//
// Errors from the decoders and cipher primitives are wrapped with %w, so
// callers can inspect them with errors.Is/errors.As (for example against
// rsa.ErrDecryption). The rendered error text is unchanged.
func (e *RSAAES) Decipher(doc string) ([]byte, error) {
	if e.key == nil {
		return nil, fmt.Errorf("missing rsa key")
	}

	parts := strings.Split(doc, ".")
	if len(parts) != 3 {
		return nil, fmt.Errorf("invalid serialized token document")
	}

	encrypted_key, err := decode_url_base64(parts[0])
	if err != nil {
		return nil, fmt.Errorf("invalid encrypted key - %w", err)
	}
	nonce, err := decode_url_base64(parts[1])
	if err != nil {
		return nil, fmt.Errorf("invalid nonce - %w", err)
	}
	ciphertext, err := decode_url_base64(parts[2])
	if err != nil {
		return nil, fmt.Errorf("invalid ciphertext - %w", err)
	}

	aes_key, err := rsa.DecryptOAEP(sha256.New(), rand.Reader, e.key, encrypted_key, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to decrypt aes key - %w", err)
	}
	// aes.NewCipher accepts only 16-, 24- or 32-byte keys; any other
	// unwrapped key length is rejected here.
	block, err := aes.NewCipher(aes_key)
	if err != nil {
		return nil, fmt.Errorf("invalid aes key - %w", err)
	}
	gcm, err := cipher.NewGCM(block)
	if err != nil {
		return nil, err
	}
	// gcm.Open panics when the nonce length is wrong, so reject it up front.
	if len(nonce) != gcm.NonceSize() {
		return nil, fmt.Errorf("invalid nonce size %d", len(nonce))
	}

	plaintext, err := gcm.Open(nil, nonce, ciphertext, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to decrypt ciphertext - %w", err)
	}
	return plaintext, nil
}
// EncipherToken seals token together with its validity window. The plaintext
// payload is
//
//	base64url(token) "|" issued_at "|" expires_at
//
// where both instants are written as decimal Unix seconds, so sub-second
// precision is dropped. The payload is then sealed with Encipher.
func (e *RSAAES) EncipherToken(token string, issued_at time.Time, expires_at time.Time) (string, error) {
	payload := []byte(base64.RawURLEncoding.EncodeToString([]byte(token)))
	payload = append(payload, '|')
	payload = strconv.AppendInt(payload, issued_at.Unix(), 10)
	payload = append(payload, '|')
	payload = strconv.AppendInt(payload, expires_at.Unix(), 10)
	return e.Encipher(payload)
}
// DecipherToken deciphers a document produced by EncipherToken and checks
// that now falls inside the sealed validity window [IssuedAt, ExpiresAt).
//
// The deciphered payload must have the form
//
//	base64url(token) "|" issued-at Unix seconds "|" expires-at Unix seconds
//
// IssuedAt and ExpiresAt are rebuilt with time.Unix, so they are in the local
// time zone at whole-second precision. Parse errors wrap their cause with %w,
// so errors.Is/errors.As work on them; the rendered messages are unchanged.
func (e *RSAAES) DecipherToken(doc string, now time.Time) (*RSAAESToken, error) {
	const time_format string = "2006-01-02 15:04:05 -0700"

	data, err := e.Decipher(doc)
	if err != nil {
		return nil, err
	}

	// SplitN keeps any extra '|' in the last field, so ParseInt rejects it.
	parts := strings.SplitN(string(data), "|", 3)
	if len(parts) != 3 {
		return nil, fmt.Errorf("invalid protected token payload")
	}

	token_data, err := base64.RawURLEncoding.DecodeString(parts[0])
	if err != nil {
		return nil, fmt.Errorf("invalid protected token text - %w", err)
	}
	issued_at_n, err := strconv.ParseInt(parts[1], 10, 64)
	if err != nil {
		return nil, fmt.Errorf("invalid protected token issued-at - %w", err)
	}
	expires_at_n, err := strconv.ParseInt(parts[2], 10, 64)
	if err != nil {
		return nil, fmt.Errorf("invalid protected token expiry - %w", err)
	}

	token := &RSAAESToken{
		Token:     string(token_data),
		IssuedAt:  time.Unix(issued_at_n, 0),
		ExpiresAt: time.Unix(expires_at_n, 0),
	}

	// IssuedAt is inclusive and ExpiresAt is exclusive.
	if now.Before(token.IssuedAt) {
		return nil, fmt.Errorf("protected token not valid until %s", token.IssuedAt.Format(time_format))
	}
	if !now.Before(token.ExpiresAt) {
		return nil, fmt.Errorf("protected token expired at %s", token.ExpiresAt.Format(time_format))
	}
	return token, nil
}