diff --git a/Makefile b/Makefile index 7ebf45a..900ea62 100644 --- a/Makefile +++ b/Makefile @@ -33,6 +33,7 @@ DATA = \ xterm.html CMD_DATA=\ + cmd/rsa.key \ cmd/tls.crt \ cmd/tls.key @@ -81,4 +82,7 @@ cmd/tls.crt: cmd/tls.key: openssl req -x509 -newkey rsa:4096 -keyout cmd/tls.key -out cmd/tls.crt -sha256 -days 36500 -nodes -subj "/CN=$(NAME)" --addext "subjectAltName=DNS:$(NAME),IP:10.0.0.1,IP:::1" +cmd/rsa.key: + openssl genrsa -traditional -out cmd/rsa.key 2048 + .PHONY: clean test diff --git a/cmd/config.go b/cmd/config.go index 1bc4573..16df674 100644 --- a/cmd/config.go +++ b/cmd/config.go @@ -1,8 +1,10 @@ package main +import "crypto/rsa" import "crypto/tls" import "crypto/x509" import "encoding/base64" +import "encoding/pem" import "errors" import "fmt" import "hodu" @@ -49,6 +51,8 @@ type AuthConfig struct { Realm string `yaml:"realm"` Creds []string `yaml:"credentials"` TokenTtl string `yaml:"token-ttl"` + TokenRsaKeyText string `yaml:"token-rsa-key-text"` + TokenRsaKeyFile string `yaml:"token-rsa-key-file"` } type CTLServiceConfig struct { @@ -346,11 +350,14 @@ func make_tls_client_config(cfg *ClientTLSConfig) (*tls.Config, error) { } // -------------------------------------------------------------------- -func make_server_basic_auth_config(cfg *AuthConfig) (*hodu.ServerAuthConfig, error) { +func make_server_auth_config(cfg *AuthConfig) (*hodu.ServerAuthConfig, error) { var config hodu.ServerAuthConfig var cred string var b []byte var x []string + var rsa_key_text []byte + var rk *rsa.PrivateKey + var pb *pem.Block var err error config.Enabled = cfg.Enabled @@ -361,6 +368,7 @@ func make_server_basic_auth_config(cfg *AuthConfig) (*hodu.ServerAuthConfig, err return nil, fmt.Errorf("invalid token ttl %s - %s", cred, err) } + // convert user credentials for _, cred = range cfg.Creds { b, err = base64.StdEncoding.DecodeString(cred) if err == nil { cred = string(b) } @@ -368,11 +376,33 @@ func make_server_basic_auth_config(cfg *AuthConfig) (*hodu.ServerAuthConfig, err // each entry must be of the form username:password x = strings.Split(cred, ":") if len(x) != 2 { - return nil, fmt.Errorf("invalid basic auth credential - %s", cred) + return nil, fmt.Errorf("invalid auth credential - %s", cred) } config.Creds[x[0]] = x[1] } + + // load rsa key + if cfg.TokenRsaKeyText == "" && cfg.TokenRsaKeyFile != "" { + rsa_key_text, err = os.ReadFile(cfg.TokenRsaKeyFile) + if err != nil { + return nil, fmt.Errorf("unable to read %s - %s", cfg.TokenRsaKeyFile, err.Error()) + } + } + if len(rsa_key_text) == 0 { rsa_key_text = []byte(cfg.TokenRsaKeyText) } + if len(rsa_key_text) == 0 { rsa_key_text = hodu_rsa_key_text } + + pb, b = pem.Decode(rsa_key_text) + if pb == nil || len(b) > 0 { + return nil, fmt.Errorf("invalid token rsa key text %s - no block or too many blocks", string(rsa_key_text)) + } + + rk, err = x509.ParsePKCS1PrivateKey(pb.Bytes) + if err != nil { + return nil, fmt.Errorf("invalid token rsa key text %s - %s", string(rsa_key_text), err.Error()) + } + + config.TokenRsaKey = rk return &config, nil } diff --git a/cmd/main.go b/cmd/main.go index 34634f3..48752e1 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -23,6 +23,8 @@ var HODU_VERSION string = "0.0.0" var hodu_tls_cert_text []byte //go:embed tls.key var hodu_tls_key_text []byte +//go:embed rsa.key +var hodu_rsa_key_text []byte // -------------------------------------------------------------------- type signal_handler struct { @@ -125,7 +127,7 @@ func server_main(ctl_addrs []string, rpc_addrs []string, pxy_addrs []string, wpx if len(config.PxyAddrs) <= 0 { config.PxyAddrs = cfg.PXY.Service.Addrs } if len(config.WpxAddrs) <= 0 { config.WpxAddrs = cfg.WPX.Service.Addrs } - config.CtlAuth, err = make_server_basic_auth_config(&cfg.CTL.Service.Auth) + config.CtlAuth, err = make_server_auth_config(&cfg.CTL.Service.Auth) if err != nil { return err } config.CtlPrefix = cfg.CTL.Service.Prefix diff --git a/cmd/rsa.key b/cmd/rsa.key new file mode 100644 index 0000000..74e879e --- /dev/null +++ b/cmd/rsa.key @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEpAIBAAKCAQEAsTC9roInjDzu12tjv1CsOM4jvuB6/5vv+cmOMF5GLMVTnJCW +6U9onsOi6iN2rzlf5glkjdtijXCPL6QEX3YLYPD4NFCiOGIPhCHjWC4nBjI7LEEm +0SqrArMhPiyYLmnkA961a7mDw9dcr5JQBDq2ZyTe917N229Jr4PCZbHLboOxNlp3 +QLSyxE5tfKZea53qm8SUF8maBvnOH8igvuYOek3iRMg3T+GoxCqy2gE1qznvwsaK +PdmTTzbIbc7XNU7t5yT6fZTvjUqs4WBuHqud4unE//KAT5vfxDdQFGcb45oMwxcK +bf03w4ZsBNvAcgCkWW+ophEOZRPkKrluHjVdNwIDAQABAoIBAARZ/5aNEL6TcoQs +2X7F0uz0NxGFfs/POxYF2q2aaxvHXtXOAT7KmfWoNVSNuWj1PkMugN8w/5scpA+V +9huIESB42oeiYVGEKwBiOqycOY4f5q8gDH1/kEKZNpxJyRT+ucBUlF0IadGB9P9E +1x07eeZPlAA8Pk8AzSz3zerkcmwM2lYYG851QyuiiTReSec3LLDcJvG5xAXZrIY0 +Zwm7qv8uvjJGqMVYlywMnRngeNywP9ZaOJ38vdmWMu4bBF+QwydOAB9A7O9zluDZ +wK7OBedAZkRT15luZ1lkuTrKVZEaugD8dbt6BBLuhbPRRGuFb4WoNaVI3CRu9RSX +72gYkRECgYEA9x0IAFGc8DmCHOP/S+uy0VjvLGYh4QN3/0UOLRvoREzF0FtAxqci +bPASGmSCJEDL93JNjlxhITDUUawyOGRgAAXyAkE9MWmv18+pfTNTDeoaeXsBqcLz +f9LCNc3mCx93tvCK7gfIYs8Ef0QKfdrsQwMGlutgXmjE+pexNXPFWEcCgYEAt4/8 +gsXi7tsCQp1YiP7VFZjoXSLejq+7pQrGV58PzlZKiOH/M5S6YS8wgm5oIEMLq2UP +nUn+FBCJ/I2b6HIdVq/Jr77XHcBFSZZEQbXe2gxTTucj6BTja1kSEilOquaaPvbR +WEs0+50rsgH0nLqSbMZZRkxOAUu9nObFvHA6O5ECgYEAzzd8+id13suam/dkoZlo +PbzB8w1B45oxCdIybQk13/AxAONEklCcwZUe2RrnNtdPMpSbDIHSwS5dHI+1HSyu +g9Z4dgOW+NSTK/lrOx3Ky6Q/xxaq8lwULF/jk5KxESq2DKXxGmFUW+cU8lNwKNFn +xVnIMM335bMdWrXRV+1Y0wkCgYBbXYOl47Esij35wi+LIKwW7+DYWr7D7pxLba2D +d1x6q2C1+Sb5GZIbRU2z3hhd1oE8cjTvaSDaA9Fqr2FmtUX9G8obe7W+zTCvi+e1 +fTzK80+T+mBY5+y6Rb9E4uKRFe64YEma1PQuOPDCzU5fpE21bpSI9PnukzBxpDvP +q1yQwQKBgQCXiW8UghuwIp3INFzBTedBHNKBwRd82ZIhBWLcgWxC/EyWsRRFpJj4 +HlVRYOvi2Q3DV6+Yn8zg3OeBhudGfCRCTkENbzAalcWqr9qb3Q4y26tZZQ9yNKk1 +jJ2OfVw4K/6L49iVNF/2kLdbRebQXwngQUmiZSai5MlrHOFYkkiwaA== +-----END RSA PRIVATE KEY----- diff --git a/jwt.go b/jwt.go index 6433c51..afb6d97 100644 --- a/jwt.go +++ b/jwt.go @@ -1,7 +1,7 @@ package hodu import "crypto" -import "crypto/hmac" +//import "crypto/hmac" import "crypto/rand" import "crypto/rsa" import "encoding/base64" @@ -10,6 +10,7 @@ import "fmt" import "hash" import "strings" +/* func Sign(data []byte, privkey *rsa.PrivateKey) ([]byte, error) { var h hash.Hash @@ -47,8 +48,11 @@ func VerifyHS512(data []byte, key string, sig []byte) error { if !hmac.Equal(h.Sum(nil), sig) { return fmt.Errorf("invalid signature") } return nil } +*/ -type JWT struct { +type JWT[T any] struct { + key *rsa.PrivateKey + claims *T } type JWTHeader struct { @@ -58,39 +62,47 @@ type JWTHeader struct { type JWTClaimMap map[string]interface{} -func (j *JWT) Sign(claims interface{}) (string, error) { +func NewJWT[T any](key *rsa.PrivateKey, claims *T) *JWT[T] { + return &JWT[T]{key: key, claims: claims} +} + +func (j *JWT[T]) SignRS512() (string, error) { var h JWTHeader var hb []byte var cb []byte var ss string var sb []byte + var hs hash.Hash var err error - h.Algo = "HS512" + h.Algo = "RS512" h.Type = "JWT" hb, err = json.Marshal(h) if err != nil { return "", err } - cb, err = json.Marshal(claims) + cb, err = json.Marshal(j.claims) if err != nil { return "", err } ss = base64.RawURLEncoding.EncodeToString(hb) + "." + base64.RawURLEncoding.EncodeToString(cb) - sb, err = SignHS512([]byte(ss), "hello") + + hs = crypto.SHA512.New() + hs.Write([]byte(ss)) + + sb, err = rsa.SignPKCS1v15(rand.Reader, j.key, crypto.SHA512, hs.Sum(nil)) if err != nil { return "", err } //fmt.Printf ("%+v %+v %s\n", string(hb), string(cb), (ss + "." + base64.RawURLEncoding.EncodeToString(sb))) return ss + "." + base64.RawURLEncoding.EncodeToString(sb), nil } -func (j *JWT) Verify(tok string) error { +func (j *JWT[T]) VerifyRS512(tok string) error { var segs []string var hb []byte var cb []byte - var sb []byte + var ss []byte var jh JWTHeader - var jcm JWTClaimMap - var x string + var hs hash.Hash var err error segs = strings.Split(tok, ".") @@ -100,25 +112,23 @@ func (j *JWT) Verify(tok string) error { if err != nil { return fmt.Errorf("invalid header - %s", err.Error()) } err = json.Unmarshal(hb, &jh) if err != nil { return fmt.Errorf("invalid header - %s", err.Error()) } -//fmt.Printf ("DECODED HEADER [%+v]\n", jh) + + if jh.Algo != "RS512" || jh.Type != "JWT" { return fmt.Errorf("invalid header content %+v", jh) } cb, err = base64.RawURLEncoding.DecodeString(segs[1]) if err != nil { return fmt.Errorf("invalid claims - %s", err.Error()) } - err = json.Unmarshal(cb, &jcm) - if err != nil { return fmt.Errorf("invalid header - %s", err.Error()) } -//fmt.Printf ("DECODED CLAIMS [%+v]\n", jcm) + err = json.Unmarshal(cb, j.claims) + if err != nil { return fmt.Errorf("invalid claims - %s", err.Error()) } - x, err = j.Sign(jcm) - if err != nil { return err } + ss, err = base64.RawURLEncoding.DecodeString(segs[2]) + if err != nil { return fmt.Errorf("invalid signature - %s", err.Error()) } - if x != tok { return fmt.Errorf("signature mismatch") } -//fmt.Printf ("VERIFICATION OK...[%s] [%s]\n", x, tok) - -// sb, err = base64.RawURLEncoding.DecodeString(segs[2]) -// if err != nil { return fmt.Errorf("invalid signature - %s", err.Error()) } -// TODO: check expiry and others... - - _ = sb + hs = crypto.SHA512.New() + hs.Write([]byte(segs[0])) + hs.Write([]byte(".")) + hs.Write([]byte(segs[1])) + err = rsa.VerifyPKCS1v15(&j.key.PublicKey, crypto.SHA512, hs.Sum(nil), ss) + if err != nil { return fmt.Errorf("unverifiable signature - %s", err.Error()) } return nil } diff --git a/jwt_test.go b/jwt_test.go index cf9e004..b86303e 100644 --- a/jwt_test.go +++ b/jwt_test.go @@ -1,10 +1,11 @@ package hodu_test +import "crypto/rand" +import "crypto/rsa" import "hodu" import "testing" func TestJwt(t *testing.T) { - var j hodu.JWT var tok string var err error @@ -18,9 +19,21 @@ func TestJwt(t *testing.T) { jc.Abc = "def" jc.Donkey = "kong" jc.IssuedAt = 111 - tok, err = j.Sign(&jc) + + var key *rsa.PrivateKey + key, err = rsa.GenerateKey(rand.Reader, 2048) + if err != nil { t.Fatalf("keygen failure - %s", err.Error()) } + + var j *hodu.JWT[JWTClaim] + j = hodu.NewJWT(key, &jc) + tok, err = j.SignRS512() if err != nil { t.Fatalf("signing failure - %s", err.Error()) } - err = j.Verify(tok) + jc = JWTClaim{} + err = j.VerifyRS512(tok) if err != nil { t.Fatalf("verification failure - %s", err.Error()) } + + if jc.Abc != "def" { t.Fatal("decoding failure of Abc field") } + if jc.Donkey != "kong" { t.Fatal("decoding failure of Donkey field") } + if jc.IssuedAt != 111 { t.Fatal("decoding failure of Issued field") } } diff --git a/server-ctl.go b/server-ctl.go index f49e417..7e50918 100644 --- a/server-ctl.go +++ b/server-ctl.go @@ -5,6 +5,16 @@ import "net/http" import "strings" import "time" +type ServerTokenClaim struct { + ExpiresAt int64 `json:"exp"` + IssuedAt int64 `json:"iat"` +} + +type json_out_token struct { + AccessToken string `json:"access-token"` + RefreshToken string `json:"refresh-token"` +} + type json_out_server_conn struct { Id ConnId `json:"id"` ServerAddr string `json:"server-addr"` @@ -86,9 +96,16 @@ func (ctl *server_ctl) Authenticate(req *http.Request) string { auth_parts = strings.Fields(auth_hdr) if len(auth_parts) == 2 && strings.EqualFold(auth_parts[0], "Bearer") { - var jwt JWT - err = jwt.Verify(strings.TrimSpace(auth_parts[1])) - if err == nil { return "" } + var jwt *JWT[ServerTokenClaim] + var claim ServerTokenClaim + jwt = NewJWT(s.cfg.CtlAuth.TokenRsaKey, &claim) + err = jwt.VerifyRS512(strings.TrimSpace(auth_parts[1])) + if err == nil { + // verification ok. let's check the actual payload + var now time.Time + now = time.Now() + if now.After(time.Unix(claim.IssuedAt, 0)) && now.Before(time.Unix(claim.ExpiresAt, 0)) { return "" } // not expired + } } // fall back to basic authentication @@ -104,16 +121,6 @@ func (ctl *server_ctl) Authenticate(req *http.Request) string { // ------------------------------------ -type ServerTokenClaim struct { - ExpiresAt int64 `json:"exp"` - IssuedAt int64 `json:"iat"` -} - -type json_out_token struct { - AccessToken string `json:"access-token"` - RefreshToken string `json:"refresh-token"` -} - func (ctl *server_ctl_token) ServeHTTP(w http.ResponseWriter, req *http.Request) (int, error) { var s *Server var status_code int @@ -125,8 +132,8 @@ func (ctl *server_ctl_token) ServeHTTP(w http.ResponseWriter, req *http.Request) switch req.Method { case http.MethodGet: - var jwt JWT - var jc ServerTokenClaim + var jwt *JWT[ServerTokenClaim] + var claim ServerTokenClaim var tok string var now time.Time @@ -136,9 +143,10 @@ func (ctl *server_ctl_token) ServeHTTP(w http.ResponseWriter, req *http.Request) } now = time.Now() - jc.IssuedAt = now.Unix() - jc.ExpiresAt = now.Add(s.cfg.CtlAuth.TokenTtl).Unix() - tok, err = jwt.Sign(&jc) + claim.IssuedAt = now.Unix() + claim.ExpiresAt = now.Add(s.cfg.CtlAuth.TokenTtl).Unix() + jwt = NewJWT(s.cfg.CtlAuth.TokenRsaKey, &claim) + tok, err = jwt.SignRS512() if err != nil { status_code = WriteJsonRespHeader(w, http.StatusInternalServerError) je.Encode(JsonErrmsg{Text: err.Error()}) diff --git a/server.go b/server.go index 930a708..9bb510a 100644 --- a/server.go +++ b/server.go @@ -1,6 +1,7 @@ package hodu import "context" +import "crypto/rsa" import "crypto/tls" import "errors" import "fmt" @@ -49,6 +50,7 @@ type ServerAuthConfig struct { Realm string Creds ServerAuthCredMap TokenTtl time.Duration + TokenRsaKey *rsa.PrivateKey } type ServerConfig struct {