updated Authenticate to return status code
This commit is contained in:
parent
8bee855aa8
commit
148dfbcfe1
@ -10,6 +10,7 @@ import "fmt"
|
|||||||
import "hodu"
|
import "hodu"
|
||||||
import "io"
|
import "io"
|
||||||
import "io/ioutil"
|
import "io/ioutil"
|
||||||
|
import "net/netip"
|
||||||
import "os"
|
import "os"
|
||||||
import "strings"
|
import "strings"
|
||||||
import "time"
|
import "time"
|
||||||
@ -46,19 +47,26 @@ type ClientTLSConfig struct {
|
|||||||
ServerName string `yaml:"server-name"`
|
ServerName string `yaml:"server-name"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type AuthConfig struct {
|
type HttpAccessRule struct {
|
||||||
|
Prefix string `yaml:"prefix"`
|
||||||
|
OrgNets []string `yaml:"origin-networks"`
|
||||||
|
Action string `yaml:"action"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type HttpAuthConfig struct {
|
||||||
Enabled bool `yaml:"enabled"`
|
Enabled bool `yaml:"enabled"`
|
||||||
Realm string `yaml:"realm"`
|
Realm string `yaml:"realm"`
|
||||||
Creds []string `yaml:"credentials"`
|
Creds []string `yaml:"credentials"`
|
||||||
TokenTtl string `yaml:"token-ttl"`
|
TokenTtl string `yaml:"token-ttl"`
|
||||||
TokenRsaKeyText string `yaml:"token-rsa-key-text"`
|
TokenRsaKeyText string `yaml:"token-rsa-key-text"`
|
||||||
TokenRsaKeyFile string `yaml:"token-rsa-key-file"`
|
TokenRsaKeyFile string `yaml:"token-rsa-key-file"`
|
||||||
|
AccessRules []HttpAccessRule `yaml:"access-rules"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type CTLServiceConfig struct {
|
type CTLServiceConfig struct {
|
||||||
Prefix string `yaml:"prefix"` // url prefix for control channel endpoints
|
Prefix string `yaml:"prefix"` // url prefix for control channel endpoints
|
||||||
Addrs []string `yaml:"addresses"`
|
Addrs []string `yaml:"addresses"`
|
||||||
Auth AuthConfig `yaml:"auth"`
|
Auth HttpAuthConfig `yaml:"auth"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type PXYServiceConfig struct {
|
type PXYServiceConfig struct {
|
||||||
@ -350,19 +358,20 @@ func make_tls_client_config(cfg *ClientTLSConfig) (*tls.Config, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// --------------------------------------------------------------------
|
// --------------------------------------------------------------------
|
||||||
func make_server_auth_config(cfg *AuthConfig) (*hodu.ServerAuthConfig, error) {
|
func make_server_auth_config(cfg *HttpAuthConfig) (*hodu.ServerHttpAuthConfig, error) {
|
||||||
var config hodu.ServerAuthConfig
|
var config hodu.ServerHttpAuthConfig
|
||||||
var cred string
|
var cred string
|
||||||
var b []byte
|
var b []byte
|
||||||
var x []string
|
var x []string
|
||||||
var rsa_key_text []byte
|
var rsa_key_text []byte
|
||||||
var rk *rsa.PrivateKey
|
var rk *rsa.PrivateKey
|
||||||
var pb *pem.Block
|
var pb *pem.Block
|
||||||
|
var rule HttpAccessRule
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
config.Enabled = cfg.Enabled
|
config.Enabled = cfg.Enabled
|
||||||
config.Realm = cfg.Realm
|
config.Realm = cfg.Realm
|
||||||
config.Creds = make(hodu.ServerAuthCredMap)
|
config.Creds = make(hodu.ServerHttpAuthCredMap)
|
||||||
config.TokenTtl, err = hodu.ParseDurationString(cfg.TokenTtl)
|
config.TokenTtl, err = hodu.ParseDurationString(cfg.TokenTtl)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("invalid token ttl %s - %s", cred, err)
|
return nil, fmt.Errorf("invalid token ttl %s - %s", cred, err)
|
||||||
@ -382,7 +391,6 @@ func make_server_auth_config(cfg *AuthConfig) (*hodu.ServerAuthConfig, error) {
|
|||||||
config.Creds[x[0]] = x[1]
|
config.Creds[x[0]] = x[1]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// load rsa key
|
// load rsa key
|
||||||
if cfg.TokenRsaKeyText == "" && cfg.TokenRsaKeyFile != "" {
|
if cfg.TokenRsaKeyText == "" && cfg.TokenRsaKeyFile != "" {
|
||||||
rsa_key_text, err = os.ReadFile(cfg.TokenRsaKeyFile)
|
rsa_key_text, err = os.ReadFile(cfg.TokenRsaKeyFile)
|
||||||
@ -404,5 +412,41 @@ func make_server_auth_config(cfg *AuthConfig) (*hodu.ServerAuthConfig, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
config.TokenRsaKey = rk
|
config.TokenRsaKey = rk
|
||||||
|
|
||||||
|
// load access rules
|
||||||
|
for _, rule = range cfg.AccessRules {
|
||||||
|
var action hodu.HttpAccessAction
|
||||||
|
var orgnet string
|
||||||
|
var lastidx int
|
||||||
|
|
||||||
|
if rule.Prefix == "" {
|
||||||
|
return nil, fmt.Errorf("blank access rule prefix not allowed")
|
||||||
|
}
|
||||||
|
|
||||||
|
switch strings.ToLower(rule.Action) {
|
||||||
|
case "accept":
|
||||||
|
action = hodu.HTTP_ACCESS_ACCEPT
|
||||||
|
case "reject":
|
||||||
|
action = hodu.HTTP_ACCESS_REJECT
|
||||||
|
case "auth-required":
|
||||||
|
action = hodu.HTTP_ACCESS_AUTH_REQUIRED
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("invalid access rule action %s", rule.Action)
|
||||||
|
}
|
||||||
|
|
||||||
|
config.AccessRules = append(config.AccessRules, hodu.HttpAccessRule{
|
||||||
|
Prefix: rule.Prefix,
|
||||||
|
Action: action,
|
||||||
|
})
|
||||||
|
|
||||||
|
lastidx = len(config.AccessRules) - 1
|
||||||
|
for _, orgnet = range rule.OrgNets {
|
||||||
|
var netpfx netip.Prefix
|
||||||
|
netpfx, err = netip.ParsePrefix(orgnet)
|
||||||
|
if err != nil { return nil, fmt.Errorf("invalid network %s - %s", orgnet, err.Error()) }
|
||||||
|
config.AccessRules[lastidx].OrgNets = append(config.AccessRules[lastidx].OrgNets, netpfx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return &config, nil
|
return &config, nil
|
||||||
}
|
}
|
||||||
|
13
hodu.go
13
hodu.go
@ -48,6 +48,19 @@ type Service interface {
|
|||||||
WriteLog(id string, level LogLevel, fmtstr string, args ...interface{})
|
WriteLog(id string, level LogLevel, fmtstr string, args ...interface{})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type HttpAccessAction int
|
||||||
|
const (
|
||||||
|
HTTP_ACCESS_ACCEPT HttpAccessAction = iota
|
||||||
|
HTTP_ACCESS_REJECT
|
||||||
|
HTTP_ACCESS_AUTH_REQUIRED
|
||||||
|
)
|
||||||
|
|
||||||
|
type HttpAccessRule struct {
|
||||||
|
Prefix string
|
||||||
|
OrgNets []netip.Prefix
|
||||||
|
Action HttpAccessAction
|
||||||
|
}
|
||||||
|
|
||||||
type JsonErrmsg struct {
|
type JsonErrmsg struct {
|
||||||
Text string `json:"error-text"`
|
Text string `json:"error-text"`
|
||||||
}
|
}
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package hodu
|
package hodu
|
||||||
|
|
||||||
import "encoding/json"
|
import "encoding/json"
|
||||||
|
import "fmt"
|
||||||
import "net/http"
|
import "net/http"
|
||||||
import "strings"
|
import "strings"
|
||||||
import "time"
|
import "time"
|
||||||
@ -12,7 +13,7 @@ type ServerTokenClaim struct {
|
|||||||
|
|
||||||
type json_out_token struct {
|
type json_out_token struct {
|
||||||
AccessToken string `json:"access-token"`
|
AccessToken string `json:"access-token"`
|
||||||
RefreshToken string `json:"refresh-token"`
|
RefreshToken string `json:"refresh-token,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type json_out_server_conn struct {
|
type json_out_server_conn struct {
|
||||||
@ -78,10 +79,22 @@ func (ctl *server_ctl) Id() string {
|
|||||||
return ctl.id
|
return ctl.id
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ctl *server_ctl) Authenticate(req *http.Request) string {
|
func (ctl *server_ctl) Authenticate(req *http.Request) (int, string) {
|
||||||
var s *Server
|
var s *Server
|
||||||
|
var rule HttpAccessRule
|
||||||
|
|
||||||
s = ctl.s
|
s = ctl.s
|
||||||
|
|
||||||
|
for _, rule = range s.cfg.CtlAuth.AccessRules {
|
||||||
|
if req.URL.Path == rule.Prefix || strings.HasPrefix(req.URL.Path, rule.Prefix + "/") {
|
||||||
|
if rule.Action == HTTP_ACCESS_ACCEPT {
|
||||||
|
return http.StatusOK, ""
|
||||||
|
} else if rule.Action == HTTP_ACCESS_REJECT {
|
||||||
|
return http.StatusForbidden, ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if s.cfg.CtlAuth != nil && s.cfg.CtlAuth.Enabled {
|
if s.cfg.CtlAuth != nil && s.cfg.CtlAuth.Enabled {
|
||||||
var auth_hdr string
|
var auth_hdr string
|
||||||
var auth_parts []string
|
var auth_parts []string
|
||||||
@ -92,10 +105,10 @@ func (ctl *server_ctl) Authenticate(req *http.Request) string {
|
|||||||
var err error
|
var err error
|
||||||
|
|
||||||
auth_hdr = req.Header.Get("Authorization")
|
auth_hdr = req.Header.Get("Authorization")
|
||||||
if auth_hdr == "" { return s.cfg.CtlAuth.Realm }
|
if auth_hdr == "" { return http.StatusUnauthorized, s.cfg.CtlAuth.Realm }
|
||||||
|
|
||||||
auth_parts = strings.Fields(auth_hdr)
|
auth_parts = strings.Fields(auth_hdr)
|
||||||
if len(auth_parts) == 2 && strings.EqualFold(auth_parts[0], "Bearer") {
|
if len(auth_parts) == 2 && strings.EqualFold(auth_parts[0], "Bearer") && s.cfg.CtlAuth.TokenRsaKey != nil {
|
||||||
var jwt *JWT[ServerTokenClaim]
|
var jwt *JWT[ServerTokenClaim]
|
||||||
var claim ServerTokenClaim
|
var claim ServerTokenClaim
|
||||||
jwt = NewJWT(s.cfg.CtlAuth.TokenRsaKey, &claim)
|
jwt = NewJWT(s.cfg.CtlAuth.TokenRsaKey, &claim)
|
||||||
@ -104,19 +117,19 @@ func (ctl *server_ctl) Authenticate(req *http.Request) string {
|
|||||||
// verification ok. let's check the actual payload
|
// verification ok. let's check the actual payload
|
||||||
var now time.Time
|
var now time.Time
|
||||||
now = time.Now()
|
now = time.Now()
|
||||||
if now.After(time.Unix(claim.IssuedAt, 0)) && now.Before(time.Unix(claim.ExpiresAt, 0)) { return "" } // not expired
|
if now.After(time.Unix(claim.IssuedAt, 0)) && now.Before(time.Unix(claim.ExpiresAt, 0)) { return http.StatusOK, "" } // not expired
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// fall back to basic authentication
|
// fall back to basic authentication
|
||||||
username, password, ok = req.BasicAuth()
|
username, password, ok = req.BasicAuth()
|
||||||
if !ok { return s.cfg.CtlAuth.Realm }
|
if !ok { return http.StatusUnauthorized, s.cfg.CtlAuth.Realm }
|
||||||
|
|
||||||
credpass, ok = s.cfg.CtlAuth.Creds[username]
|
credpass, ok = s.cfg.CtlAuth.Creds[username]
|
||||||
if !ok || credpass != password { return s.cfg.CtlAuth.Realm }
|
if !ok || credpass != password { return http.StatusUnauthorized, s.cfg.CtlAuth.Realm }
|
||||||
}
|
}
|
||||||
|
|
||||||
return ""
|
return http.StatusOK, ""
|
||||||
}
|
}
|
||||||
|
|
||||||
// ------------------------------------
|
// ------------------------------------
|
||||||
@ -137,9 +150,11 @@ func (ctl *server_ctl_token) ServeHTTP(w http.ResponseWriter, req *http.Request)
|
|||||||
var tok string
|
var tok string
|
||||||
var now time.Time
|
var now time.Time
|
||||||
|
|
||||||
if s.cfg.CtlAuth == nil || !s.cfg.CtlAuth.Enabled {
|
if s.cfg.CtlAuth == nil || !s.cfg.CtlAuth.Enabled || s.cfg.CtlAuth.TokenRsaKey == nil {
|
||||||
status_code = WriteEmptyRespHeader(w, http.StatusNotFound)
|
status_code = WriteJsonRespHeader(w, http.StatusForbidden)
|
||||||
goto done
|
err = fmt.Errorf("auth not enabled or token rsa key not set")
|
||||||
|
je.Encode(JsonErrmsg{Text: err.Error()})
|
||||||
|
goto oops
|
||||||
}
|
}
|
||||||
|
|
||||||
now = time.Now()
|
now = time.Now()
|
||||||
@ -161,7 +176,7 @@ func (ctl *server_ctl_token) ServeHTTP(w http.ResponseWriter, req *http.Request)
|
|||||||
status_code = WriteEmptyRespHeader(w, http.StatusMethodNotAllowed)
|
status_code = WriteEmptyRespHeader(w, http.StatusMethodNotAllowed)
|
||||||
}
|
}
|
||||||
|
|
||||||
done:
|
//done:
|
||||||
return status_code, nil
|
return status_code, nil
|
||||||
|
|
||||||
oops:
|
oops:
|
||||||
|
@ -190,8 +190,8 @@ func (pxy *server_proxy) Id() string {
|
|||||||
return pxy.id
|
return pxy.id
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pxy *server_proxy) Authenticate(req *http.Request) string {
|
func (pxy *server_proxy) Authenticate(req *http.Request) (int, string) {
|
||||||
return ""
|
return http.StatusOK, ""
|
||||||
}
|
}
|
||||||
|
|
||||||
// ------------------------------------
|
// ------------------------------------
|
||||||
|
18
server.go
18
server.go
@ -43,14 +43,15 @@ type ServerSvcPortMap = map[PortId]ConnRouteId
|
|||||||
type ServerWpxResponseTransformer func(r *ServerRouteProxyInfo, resp *http.Response) io.Reader
|
type ServerWpxResponseTransformer func(r *ServerRouteProxyInfo, resp *http.Response) io.Reader
|
||||||
type ServerWpxForeignPortProxyMaker func(wpx_type string, port_id string) (*ServerRouteProxyInfo, error)
|
type ServerWpxForeignPortProxyMaker func(wpx_type string, port_id string) (*ServerRouteProxyInfo, error)
|
||||||
|
|
||||||
type ServerAuthCredMap map[string]string
|
type ServerHttpAuthCredMap map[string]string
|
||||||
|
|
||||||
type ServerAuthConfig struct {
|
type ServerHttpAuthConfig struct {
|
||||||
Enabled bool
|
Enabled bool
|
||||||
Realm string
|
Realm string
|
||||||
Creds ServerAuthCredMap
|
Creds ServerHttpAuthCredMap
|
||||||
TokenTtl time.Duration
|
TokenTtl time.Duration
|
||||||
TokenRsaKey *rsa.PrivateKey
|
TokenRsaKey *rsa.PrivateKey
|
||||||
|
AccessRules []HttpAccessRule
|
||||||
}
|
}
|
||||||
|
|
||||||
type ServerConfig struct {
|
type ServerConfig struct {
|
||||||
@ -62,7 +63,7 @@ type ServerConfig struct {
|
|||||||
CtlAddrs []string
|
CtlAddrs []string
|
||||||
CtlTls *tls.Config
|
CtlTls *tls.Config
|
||||||
CtlPrefix string
|
CtlPrefix string
|
||||||
CtlAuth *ServerAuthConfig
|
CtlAuth *ServerHttpAuthConfig
|
||||||
|
|
||||||
PxyAddrs []string
|
PxyAddrs []string
|
||||||
PxyTls *tls.Config
|
PxyTls *tls.Config
|
||||||
@ -953,7 +954,7 @@ func (hlw *server_http_log_writer) Write(p []byte) (n int, err error) {
|
|||||||
|
|
||||||
type ServerHttpHandler interface {
|
type ServerHttpHandler interface {
|
||||||
Id() string
|
Id() string
|
||||||
Authenticate(req *http.Request) string
|
Authenticate(req *http.Request) (int, string)
|
||||||
ServeHTTP (w http.ResponseWriter, req *http.Request) (int, error)
|
ServeHTTP (w http.ResponseWriter, req *http.Request) (int, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -976,11 +977,10 @@ func (s *Server) wrap_http_handler(handler ServerHttpHandler) http.Handler {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
start_time = time.Now()
|
start_time = time.Now()
|
||||||
realm = handler.Authenticate(req)
|
status_code, realm = handler.Authenticate(req)
|
||||||
if realm != "" {
|
if status_code == http.StatusUnauthorized && realm != "" {
|
||||||
w.Header().Set("WWW-Authenticate", fmt.Sprintf("Basic Realm=\"%s\"", realm))
|
w.Header().Set("WWW-Authenticate", fmt.Sprintf("Basic Realm=\"%s\"", realm))
|
||||||
status_code = http.StatusUnauthorized
|
} else if status_code == http.StatusOK {
|
||||||
} else {
|
|
||||||
status_code, err = handler.ServeHTTP(w, req)
|
status_code, err = handler.ServeHTTP(w, req)
|
||||||
}
|
}
|
||||||
time_taken = time.Now().Sub(start_time)
|
time_taken = time.Now().Sub(start_time)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user