177 lines
4.6 KiB
Go
177 lines
4.6 KiB
Go
package hodu
|
|
|
|
import "crypto/aes"
|
|
import "crypto/cipher"
|
|
import "crypto/rand"
|
|
import "crypto/rsa"
|
|
import "crypto/sha256"
|
|
import "encoding/base64"
|
|
import "fmt"
|
|
import "strconv"
|
|
import "strings"
|
|
import "time"
|
|
|
|
// currently, it supports rsa-aes-128-gcm only.
|
|
type RSAAES struct {
|
|
key *rsa.PrivateKey
|
|
}
|
|
|
|
// RSAAESToken is what RSAAES.DecipherToken hands back on success: the
// recovered token text and the validity window it was enciphered with.
type RSAAESToken struct {
	// Token is the token text originally passed to EncipherToken.
	Token string

	// IssuedAt is the start of the validity window (whole-second precision).
	IssuedAt time.Time

	// ExpiresAt is the first instant at which DecipherToken rejects the token.
	ExpiresAt time.Time
}
|
|
|
|
// encode_url_base64 renders data as unpadded base64url (RFC 4648 URL-safe
// alphabet, no '=' padding). This is the form of each '.'-separated field
// Encipher emits.
func encode_url_base64(data []byte) string {
	enc := base64.RawURLEncoding
	out := make([]byte, enc.EncodedLen(len(data)))
	enc.Encode(out, data)
	return string(out)
}
|
|
|
|
// decode_url_base64 is the inverse of encode_url_base64. It decodes unpadded
// base64url text and reports an error for characters outside the URL-safe
// alphabet, including '=' padding.
func decode_url_base64(text string) ([]byte, error) {
	enc := base64.RawURLEncoding
	out := make([]byte, enc.DecodedLen(len(text)))
	n, err := enc.Decode(out, []byte(text))
	return out[:n], err
}
|
|
|
|
func NewRSAAES(key *rsa.PrivateKey) *RSAAES {
|
|
return &RSAAES{key: key}
|
|
}
|
|
|
|
func (e *RSAAES) Encipher(data []byte) (string, error) {
|
|
var aes_key []byte
|
|
var block cipher.Block
|
|
var gcm cipher.AEAD
|
|
var nonce []byte
|
|
var ciphertext []byte
|
|
var encrypted_key []byte
|
|
var err error
|
|
|
|
if e.key == nil {
|
|
return "", fmt.Errorf("missing rsa key")
|
|
}
|
|
|
|
aes_key = make([]byte, 32)
|
|
_, err = rand.Read(aes_key)
|
|
if err != nil { return "", err }
|
|
|
|
block, err = aes.NewCipher(aes_key)
|
|
if err != nil { return "", err }
|
|
|
|
gcm, err = cipher.NewGCM(block)
|
|
if err != nil { return "", err }
|
|
|
|
nonce = make([]byte, gcm.NonceSize())
|
|
_, err = rand.Read(nonce)
|
|
if err != nil { return "", err }
|
|
|
|
ciphertext = gcm.Seal(nil, nonce, data, nil)
|
|
|
|
encrypted_key, err = rsa.EncryptOAEP(sha256.New(), rand.Reader, &e.key.PublicKey, aes_key, nil)
|
|
if err != nil { return "", err }
|
|
|
|
return encode_url_base64(encrypted_key) +
|
|
"." + encode_url_base64(nonce) +
|
|
"." + encode_url_base64(ciphertext), nil
|
|
}
|
|
|
|
func (e *RSAAES) Decipher(doc string) ([]byte, error) {
|
|
var parts []string
|
|
var encrypted_key []byte
|
|
var nonce []byte
|
|
var ciphertext []byte
|
|
var aes_key []byte
|
|
var block cipher.Block
|
|
var gcm cipher.AEAD
|
|
var plaintext []byte
|
|
var err error
|
|
|
|
if e.key == nil {
|
|
return nil, fmt.Errorf("missing rsa key")
|
|
}
|
|
|
|
parts = strings.Split(doc, ".")
|
|
if len(parts) != 3 {
|
|
return nil, fmt.Errorf("invalid serialized token document")
|
|
}
|
|
|
|
encrypted_key, err = decode_url_base64(parts[0])
|
|
if err != nil { return nil, fmt.Errorf("invalid encrypted key - %s", err.Error()) }
|
|
|
|
nonce, err = decode_url_base64(parts[1])
|
|
if err != nil { return nil, fmt.Errorf("invalid nonce - %s", err.Error()) }
|
|
|
|
ciphertext, err = decode_url_base64(parts[2])
|
|
if err != nil { return nil, fmt.Errorf("invalid ciphertext - %s", err.Error()) }
|
|
|
|
aes_key, err = rsa.DecryptOAEP(sha256.New(), rand.Reader, e.key, encrypted_key, nil)
|
|
if err != nil { return nil, fmt.Errorf("failed to decrypt aes key - %s", err.Error()) }
|
|
|
|
block, err = aes.NewCipher(aes_key)
|
|
if err != nil { return nil, fmt.Errorf("invalid aes key - %s", err.Error()) }
|
|
|
|
gcm, err = cipher.NewGCM(block)
|
|
if err != nil { return nil, err }
|
|
if len(nonce) != gcm.NonceSize() {
|
|
return nil, fmt.Errorf("invalid nonce size %d", len(nonce))
|
|
}
|
|
|
|
plaintext, err = gcm.Open(nil, nonce, ciphertext, nil)
|
|
if err != nil { return nil, fmt.Errorf("failed to decrypt ciphertext - %s", err.Error()) }
|
|
|
|
return plaintext, nil
|
|
}
|
|
|
|
func (e *RSAAES) EncipherToken(token string, issued_at time.Time, expires_at time.Time) (string, error) {
|
|
var plain string
|
|
|
|
plain = base64.RawURLEncoding.EncodeToString([]byte(token)) +
|
|
"|" + strconv.FormatInt(issued_at.Unix(), 10) +
|
|
"|" + strconv.FormatInt(expires_at.Unix(), 10)
|
|
|
|
return e.Encipher([]byte(plain))
|
|
}
|
|
|
|
func (e *RSAAES) DecipherToken(doc string, now time.Time) (*RSAAESToken, error) {
|
|
var data []byte
|
|
var parts []string
|
|
var token_data []byte
|
|
var issued_at_n int64
|
|
var expires_at_n int64
|
|
var token RSAAESToken
|
|
var err error
|
|
|
|
data, err = e.Decipher(doc)
|
|
if err != nil { return nil, err }
|
|
|
|
parts = strings.SplitN(string(data), "|", 3)
|
|
if len(parts) != 3 {
|
|
return nil, fmt.Errorf("invalid protected token payload")
|
|
}
|
|
|
|
token_data, err = base64.RawURLEncoding.DecodeString(parts[0])
|
|
if err != nil {
|
|
return nil, fmt.Errorf("invalid protected token text - %s", err.Error())
|
|
}
|
|
|
|
issued_at_n, err = strconv.ParseInt(parts[1], 10, 64)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("invalid protected token issued-at - %s", err.Error())
|
|
}
|
|
|
|
expires_at_n, err = strconv.ParseInt(parts[2], 10, 64)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("invalid protected token expiry - %s", err.Error())
|
|
}
|
|
|
|
token.Token = string(token_data)
|
|
token.IssuedAt = time.Unix(issued_at_n, 0)
|
|
token.ExpiresAt = time.Unix(expires_at_n, 0)
|
|
|
|
const time_format string = "2006-01-02 15:04:05 -0700"
|
|
if now.Before(token.IssuedAt) {
|
|
return nil, fmt.Errorf("protected token not valid until %s", token.IssuedAt.Format(time_format))
|
|
}
|
|
if !now.Before(token.ExpiresAt) {
|
|
return nil, fmt.Errorf("protected token expired at %s", token.ExpiresAt.Format(time_format))
|
|
}
|
|
|
|
return &token, nil
|
|
}
|