diff --git a/cmd/config.go b/cmd/config.go index 16df674..2da0981 100644 --- a/cmd/config.go +++ b/cmd/config.go @@ -10,6 +10,7 @@ import "fmt" import "hodu" import "io" import "io/ioutil" +import "net/netip" import "os" import "strings" import "time" @@ -46,19 +47,26 @@ type ClientTLSConfig struct { ServerName string `yaml:"server-name"` } -type AuthConfig struct { +type HttpAccessRule struct { + Prefix string `yaml:"prefix"` + OrgNets []string `yaml:"origin-networks"` + Action string `yaml:"action"` +} + +type HttpAuthConfig struct { Enabled bool `yaml:"enabled"` Realm string `yaml:"realm"` Creds []string `yaml:"credentials"` TokenTtl string `yaml:"token-ttl"` TokenRsaKeyText string `yaml:"token-rsa-key-text"` TokenRsaKeyFile string `yaml:"token-rsa-key-file"` + AccessRules []HttpAccessRule `yaml:"access-rules"` } type CTLServiceConfig struct { Prefix string `yaml:"prefix"` // url prefix for control channel endpoints Addrs []string `yaml:"addresses"` - Auth AuthConfig `yaml:"auth"` + Auth HttpAuthConfig `yaml:"auth"` } type PXYServiceConfig struct { @@ -350,19 +358,20 @@ func make_tls_client_config(cfg *ClientTLSConfig) (*tls.Config, error) { } // -------------------------------------------------------------------- -func make_server_auth_config(cfg *AuthConfig) (*hodu.ServerAuthConfig, error) { - var config hodu.ServerAuthConfig +func make_server_auth_config(cfg *HttpAuthConfig) (*hodu.ServerHttpAuthConfig, error) { + var config hodu.ServerHttpAuthConfig var cred string var b []byte var x []string var rsa_key_text []byte var rk *rsa.PrivateKey var pb *pem.Block + var rule HttpAccessRule var err error config.Enabled = cfg.Enabled config.Realm = cfg.Realm - config.Creds = make(hodu.ServerAuthCredMap) + config.Creds = make(hodu.ServerHttpAuthCredMap) config.TokenTtl, err = hodu.ParseDurationString(cfg.TokenTtl) if err != nil { return nil, fmt.Errorf("invalid token ttl %s - %s", cred, err) @@ -382,7 +391,6 @@ func make_server_auth_config(cfg *AuthConfig) (*hodu.ServerAuthConfig, error) { config.Creds[x[0]] = x[1] } - // load rsa key if cfg.TokenRsaKeyText == "" && cfg.TokenRsaKeyFile != "" { rsa_key_text, err = os.ReadFile(cfg.TokenRsaKeyFile) @@ -404,5 +412,41 @@ func make_server_auth_config(cfg *AuthConfig) (*hodu.ServerAuthConfig, error) { } config.TokenRsaKey = rk + + // load access rules + for _, rule = range cfg.AccessRules { + var action hodu.HttpAccessAction + var orgnet string + var lastidx int + + if rule.Prefix == "" { + return nil, fmt.Errorf("blank access rule prefix not allowed") + } + + switch strings.ToLower(rule.Action) { + case "accept": + action = hodu.HTTP_ACCESS_ACCEPT + case "reject": + action = hodu.HTTP_ACCESS_REJECT + case "auth-required": + action = hodu.HTTP_ACCESS_AUTH_REQUIRED + default: + return nil, fmt.Errorf("invalid access rule action %s", rule.Action) + } + + config.AccessRules = append(config.AccessRules, hodu.HttpAccessRule{ + Prefix: rule.Prefix, + Action: action, + }) + + lastidx = len(config.AccessRules) - 1 + for _, orgnet = range rule.OrgNets { + var netpfx netip.Prefix + netpfx, err = netip.ParsePrefix(orgnet) + if err != nil { return nil, fmt.Errorf("invalid network %s - %s", orgnet, err.Error()) } + config.AccessRules[lastidx].OrgNets = append(config.AccessRules[lastidx].OrgNets, netpfx) + } + } + return &config, nil } diff --git a/hodu.go b/hodu.go index 66ee158..533c81d 100644 --- a/hodu.go +++ b/hodu.go @@ -48,6 +48,19 @@ type Service interface { WriteLog(id string, level LogLevel, fmtstr string, args ...interface{}) } +type HttpAccessAction int +const ( + HTTP_ACCESS_ACCEPT HttpAccessAction = iota + HTTP_ACCESS_REJECT + HTTP_ACCESS_AUTH_REQUIRED +) + +type HttpAccessRule struct { + Prefix string + OrgNets []netip.Prefix + Action HttpAccessAction +} + type JsonErrmsg struct { Text string `json:"error-text"` } diff --git a/server-ctl.go b/server-ctl.go index 7e50918..f0685d5 100644 --- a/server-ctl.go +++ b/server-ctl.go @@ -1,6 +1,7 @@ package hodu import "encoding/json" +import "fmt" import "net/http" import "strings" import "time" @@ -12,7 +13,7 @@ type ServerTokenClaim struct { type json_out_token struct { AccessToken string `json:"access-token"` - RefreshToken string `json:"refresh-token"` + RefreshToken string `json:"refresh-token,omitempty"` } type json_out_server_conn struct { @@ -78,10 +79,22 @@ func (ctl *server_ctl) Id() string { return ctl.id } -func (ctl *server_ctl) Authenticate(req *http.Request) string { +func (ctl *server_ctl) Authenticate(req *http.Request) (int, string) { var s *Server + var rule HttpAccessRule s = ctl.s + + for _, rule = range s.cfg.CtlAuth.AccessRules { + if req.URL.Path == rule.Prefix || strings.HasPrefix(req.URL.Path, rule.Prefix + "/") { + if rule.Action == HTTP_ACCESS_ACCEPT { + return http.StatusOK, "" + } else if rule.Action == HTTP_ACCESS_REJECT { + return http.StatusForbidden, "" + } + } + } + if s.cfg.CtlAuth != nil && s.cfg.CtlAuth.Enabled { var auth_hdr string var auth_parts []string @@ -92,10 +105,10 @@ func (ctl *server_ctl) Authenticate(req *http.Request) string { var err error auth_hdr = req.Header.Get("Authorization") - if auth_hdr == "" { return s.cfg.CtlAuth.Realm } + if auth_hdr == "" { return http.StatusUnauthorized, s.cfg.CtlAuth.Realm } auth_parts = strings.Fields(auth_hdr) - if len(auth_parts) == 2 && strings.EqualFold(auth_parts[0], "Bearer") { + if len(auth_parts) == 2 && strings.EqualFold(auth_parts[0], "Bearer") && s.cfg.CtlAuth.TokenRsaKey != nil { var jwt *JWT[ServerTokenClaim] var claim ServerTokenClaim jwt = NewJWT(s.cfg.CtlAuth.TokenRsaKey, &claim) @@ -104,19 +117,19 @@ func (ctl *server_ctl) Authenticate(req *http.Request) string { // verification ok. let's check the actual payload var now time.Time now = time.Now() - if now.After(time.Unix(claim.IssuedAt, 0)) && now.Before(time.Unix(claim.ExpiresAt, 0)) { return "" } // not expired + if now.After(time.Unix(claim.IssuedAt, 0)) && now.Before(time.Unix(claim.ExpiresAt, 0)) { return http.StatusOK, "" } // not expired } } // fall back to basic authentication username, password, ok = req.BasicAuth() - if !ok { return s.cfg.CtlAuth.Realm } + if !ok { return http.StatusUnauthorized, s.cfg.CtlAuth.Realm } credpass, ok = s.cfg.CtlAuth.Creds[username] - if !ok || credpass != password { return s.cfg.CtlAuth.Realm } + if !ok || credpass != password { return http.StatusUnauthorized, s.cfg.CtlAuth.Realm } } - return "" + return http.StatusOK, "" } // ------------------------------------ @@ -137,9 +150,11 @@ func (ctl *server_ctl_token) ServeHTTP(w http.ResponseWriter, req *http.Request) var tok string var now time.Time - if s.cfg.CtlAuth == nil || !s.cfg.CtlAuth.Enabled { - status_code = WriteEmptyRespHeader(w, http.StatusNotFound) - goto done + if s.cfg.CtlAuth == nil || !s.cfg.CtlAuth.Enabled || s.cfg.CtlAuth.TokenRsaKey == nil { + status_code = WriteJsonRespHeader(w, http.StatusForbidden) + err = fmt.Errorf("auth not enabled or token rsa key not set") + je.Encode(JsonErrmsg{Text: err.Error()}) + goto oops } now = time.Now() @@ -161,7 +176,7 @@ func (ctl *server_ctl_token) ServeHTTP(w http.ResponseWriter, req *http.Request) status_code = WriteEmptyRespHeader(w, http.StatusMethodNotAllowed) } -done: +//done: return status_code, nil oops: diff --git a/server-proxy.go b/server-proxy.go index 360ce15..e4c5612 100644 --- a/server-proxy.go +++ b/server-proxy.go @@ -190,8 +190,8 @@ func (pxy *server_proxy) Id() string { return pxy.id } -func (pxy *server_proxy) Authenticate(req *http.Request) string { - return "" +func (pxy *server_proxy) Authenticate(req *http.Request) (int, string) { + return http.StatusOK, "" } // ------------------------------------ diff --git a/server.go b/server.go index 9bb510a..cff2486 100644 --- a/server.go +++ b/server.go @@ -43,14 +43,15 @@ type ServerSvcPortMap = map[PortId]ConnRouteId type ServerWpxResponseTransformer func(r *ServerRouteProxyInfo, resp *http.Response) io.Reader type ServerWpxForeignPortProxyMaker func(wpx_type string, port_id string) (*ServerRouteProxyInfo, error) -type ServerAuthCredMap map[string]string +type ServerHttpAuthCredMap map[string]string -type ServerAuthConfig struct { +type ServerHttpAuthConfig struct { Enabled bool Realm string - Creds ServerAuthCredMap + Creds ServerHttpAuthCredMap TokenTtl time.Duration TokenRsaKey *rsa.PrivateKey + AccessRules []HttpAccessRule } type ServerConfig struct { @@ -62,7 +63,7 @@ type ServerConfig struct { CtlAddrs []string CtlTls *tls.Config CtlPrefix string - CtlAuth *ServerAuthConfig + CtlAuth *ServerHttpAuthConfig PxyAddrs []string PxyTls *tls.Config @@ -953,7 +954,7 @@ func (hlw *server_http_log_writer) Write(p []byte) (n int, err error) { type ServerHttpHandler interface { Id() string - Authenticate(req *http.Request) string + Authenticate(req *http.Request) (int, string) ServeHTTP (w http.ResponseWriter, req *http.Request) (int, error) } @@ -976,11 +977,10 @@ func (s *Server) wrap_http_handler(handler ServerHttpHandler) http.Handler { }() start_time = time.Now() - realm = handler.Authenticate(req) - if realm != "" { + status_code, realm = handler.Authenticate(req) + if status_code == http.StatusUnauthorized && realm != "" { w.Header().Set("WWW-Authenticate", fmt.Sprintf("Basic Realm=\"%s\"", realm)) - status_code = http.StatusUnauthorized - } else { + } else if status_code == http.StatusOK { status_code, err = handler.ServeHTTP(w, req) } time_taken = time.Now().Sub(start_time)