diff --git a/client-ctl.go b/client-ctl.go index 2978f56..ddcc403 100644 --- a/client-ctl.go +++ b/client-ctl.go @@ -91,6 +91,10 @@ type client_ctl struct { id string } +type client_ctl_token struct { + client_ctl +} + type client_ctl_client_conns struct { client_ctl //c *Client @@ -131,6 +135,62 @@ func (ctl *client_ctl) Id() string { return ctl.id } +func (ctl *client_ctl) Authenticate(req *http.Request) (int, string) { + if ctl.c.ctl_auth == nil { return http.StatusOK, "" } + return ctl.c.ctl_auth.Authenticate(req) +} + +// ------------------------------------ + +func (ctl *client_ctl_token) ServeHTTP(w http.ResponseWriter, req *http.Request) (int, error) { + var c *Client + var status_code int + var je *json.Encoder + var err error + + c = ctl.c + je = json.NewEncoder(w) + + switch req.Method { + case http.MethodGet: + var jwt *JWT[ServerTokenClaim] + var claim ServerTokenClaim + var tok string + var now time.Time + + if c.ctl_auth == nil || !c.ctl_auth.Enabled || c.ctl_auth.TokenRsaKey == nil { + status_code = WriteJsonRespHeader(w, http.StatusForbidden) + err = fmt.Errorf("auth not enabled or token rsa key not set") + je.Encode(JsonErrmsg{Text: err.Error()}) + goto oops + } + + now = time.Now() + claim.IssuedAt = now.Unix() + claim.ExpiresAt = now.Add(c.ctl_auth.TokenTtl).Unix() + jwt = NewJWT(c.ctl_auth.TokenRsaKey, &claim) + tok, err = jwt.SignRS512() + if err != nil { + status_code = WriteJsonRespHeader(w, http.StatusInternalServerError) + je.Encode(JsonErrmsg{Text: err.Error()}) + goto oops + } + + status_code = WriteJsonRespHeader(w, http.StatusOK) + err = je.Encode(json_out_token{ AccessToken: tok }) // TODO: refresh token + if err != nil { goto oops } + + default: + status_code = WriteEmptyRespHeader(w, http.StatusMethodNotAllowed) + } + +//done: + return status_code, nil + +oops: + return status_code, err +} + // ------------------------------------ func (ctl *client_ctl_client_conns) ServeHTTP(w http.ResponseWriter, req *http.Request) (int, error) { diff --git a/client.go b/client.go index 8969f05..a138328 100644 --- a/client.go +++ b/client.go @@ -55,16 +55,17 @@ type ClientConfigActive struct { type Client struct { Named - ctx context.Context - ctx_cancel context.CancelFunc - ctltlscfg *tls.Config - rpctlscfg *tls.Config + ctx context.Context + ctx_cancel context.CancelFunc - ext_mtx sync.Mutex + ext_mtx sync.Mutex ext_svcs []Service + rpc_tls *tls.Config + ctl_tls *tls.Config ctl_addr []string - ctl_prefix string + ctl_prefix string + ctl_auth *HttpAuthConfig ctl_mux *http.ServeMux ctl []*http.Server // control server @@ -930,16 +931,16 @@ start_over: cts.State = CLIENT_CONN_CONNECTING cts.cfg.Index = (cts.cfg.Index + 1) % len(cts.cfg.ServerAddrs) cts.cli.log.Write(cts.Sid, LOG_INFO, "Connecting to server[%d] %s", cts.cfg.Index, cts.cfg.ServerAddrs[cts.cfg.Index]) - if cts.cli.rpctlscfg == nil { + if cts.cli.rpc_tls == nil { opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials())) if cts.cfg.ServerAuthority != "" { opts = append(opts, grpc.WithAuthority(cts.cfg.ServerAuthority)) } } else { - opts = append(opts, grpc.WithTransportCredentials(credentials.NewTLS(cts.cli.rpctlscfg))) + opts = append(opts, grpc.WithTransportCredentials(credentials.NewTLS(cts.cli.rpc_tls))) // set the http2 :authority header with tls server name defined. if cts.cfg.ServerAuthority != "" { opts = append(opts, grpc.WithAuthority(cts.cfg.ServerAuthority)) - } else if cts.cli.rpctlscfg.ServerName != "" { - opts = append(opts, grpc.WithAuthority(cts.cli.rpctlscfg.ServerName)) + } else if cts.cli.rpc_tls.ServerName != "" { + opts = append(opts, grpc.WithAuthority(cts.cli.rpc_tls.ServerName)) } } if cts.cfg.ServerSeedTmout > 0 { @@ -1225,6 +1226,7 @@ func (hlw *client_ctl_log_writer) Write(p []byte) (n int, err error) { type ClientHttpHandler interface { Id() string + Authenticate(req *http.Request) (int, string) ServeHTTP (w http.ResponseWriter, req *http.Request) (int, error) } @@ -1234,6 +1236,7 @@ func (c *Client) wrap_http_handler(handler ClientHttpHandler) http.Handler { var err error var start_time time.Time var time_taken time.Duration + var realm string // this deferred function is to overcome the recovering implemenation // from panic done in go's http server. in that implemenation, panic @@ -1247,11 +1250,15 @@ func (c *Client) wrap_http_handler(handler ClientHttpHandler) http.Handler { start_time = time.Now() - // TODO: some kind of authorization, especially for ctl - //req.BasicAuth() - //req.Header.Get("Authorization") - - status_code, err = handler.ServeHTTP(w, req) + status_code, realm = handler.Authenticate(req) + if status_code == http.StatusUnauthorized && realm != "" { + w.Header().Set("WWW-Authenticate", fmt.Sprintf("Basic Realm=\"%s\"", realm)) + WriteEmptyRespHeader(w, status_code) + } else if status_code == http.StatusOK { + status_code, err = handler.ServeHTTP(w, req) + } else { + WriteEmptyRespHeader(w, status_code) + } // TODO: statistics by status_code and end point types. time_taken = time.Now().Sub(start_time) @@ -1266,15 +1273,13 @@ func (c *Client) wrap_http_handler(handler ClientHttpHandler) http.Handler { }) } -func NewClient(ctx context.Context, name string, logger Logger, ctl_addrs []string, ctl_prefix string, ctltlscfg *tls.Config, rpctlscfg *tls.Config, rpc_max int, peer_max int, peer_conn_tmout time.Duration) *Client { +func NewClient(ctx context.Context, name string, logger Logger, ctl_addrs []string, ctl_prefix string, ctl_tls *tls.Config, ctl_auth *HttpAuthConfig, rpc_tls *tls.Config, rpc_max int, peer_max int, peer_conn_tmout time.Duration) *Client { var c Client var i int var hs_log *log.Logger c.name = name c.ctx, c.ctx_cancel = context.WithCancel(ctx) - c.ctltlscfg = ctltlscfg - c.rpctlscfg = rpctlscfg c.ext_svcs = make([]Service, 0, 1) c.ptc_tmout = peer_conn_tmout c.ptc_limit = peer_max @@ -1284,8 +1289,11 @@ func NewClient(ctx context.Context, name string, logger Logger, ctl_addrs []stri c.stop_req.Store(false) c.stop_chan = make(chan bool, 8) c.log = logger - c.ctl_prefix = ctl_prefix + c.rpc_tls = rpc_tls + c.ctl_auth = ctl_auth + c.ctl_tls = ctl_tls + c.ctl_prefix = ctl_prefix c.ctl_mux = http.NewServeMux() c.ctl_mux.Handle(c.ctl_prefix + "/_ctl/client-conns", c.wrap_http_handler(&client_ctl_client_conns{client_ctl{c: &c, id: HS_ID_CTL}})) @@ -1303,6 +1311,8 @@ func NewClient(ctx context.Context, name string, logger Logger, ctl_addrs []stri c.wrap_http_handler(&client_ctl_client_conns_id_routes_id_peers_id{client_ctl{c: &c, id: HS_ID_CTL}})) c.ctl_mux.Handle(c.ctl_prefix + "/_ctl/stats", c.wrap_http_handler(&client_ctl_stats{client_ctl{c: &c, id: HS_ID_CTL}})) + c.ctl_mux.Handle(c.ctl_prefix + "/_ctl/token", + c.wrap_http_handler(&client_ctl_token{client_ctl{c: &c, id: HS_ID_CTL}})) // TODO: make this optional. add this endpoint only if it's enabled... c.promreg = prometheus.NewRegistry() @@ -1322,7 +1332,7 @@ func NewClient(ctx context.Context, name string, logger Logger, ctl_addrs []stri c.ctl[i] = &http.Server{ Addr: ctl_addrs[i], Handler: c.ctl_mux, - TLSConfig: c.ctltlscfg, + TLSConfig: c.ctl_tls, ErrorLog: hs_log, // TODO: more settings } @@ -1575,10 +1585,10 @@ func (c *Client) RunCtlTask(wg *sync.WaitGroup) { // check it again to make the guard slightly more stable // although it's still possible that the stop request is made // after Listen() - if c.ctltlscfg == nil { + if c.ctl_tls == nil { err = cs.Serve(l) } else { - err = cs.ServeTLS(l, "", "") // c.ctltlscfg must provide a certificate and a key + err = cs.ServeTLS(l, "", "") // c.ctl_tls must provide a certificate and a key } } else { err = fmt.Errorf("stop requested") diff --git a/cmd/config.go b/cmd/config.go index 2da0981..d17e80e 100644 --- a/cmd/config.go +++ b/cmd/config.go @@ -358,8 +358,8 @@ func make_tls_client_config(cfg *ClientTLSConfig) (*tls.Config, error) { } // -------------------------------------------------------------------- -func make_server_auth_config(cfg *HttpAuthConfig) (*hodu.ServerHttpAuthConfig, error) { - var config hodu.ServerHttpAuthConfig +func make_http_auth_config(cfg *HttpAuthConfig) (*hodu.HttpAuthConfig, error) { + var config hodu.HttpAuthConfig var cred string var b []byte var x []string @@ -371,7 +371,7 @@ func make_server_auth_config(cfg *HttpAuthConfig) (*hodu.ServerHttpAuthConfig, e config.Enabled = cfg.Enabled config.Realm = cfg.Realm - config.Creds = make(hodu.ServerHttpAuthCredMap) + config.Creds = make(hodu.HttpAuthCredMap) config.TokenTtl, err = hodu.ParseDurationString(cfg.TokenTtl) if err != nil { return nil, fmt.Errorf("invalid token ttl %s - %s", cred, err) diff --git a/cmd/main.go b/cmd/main.go index 48752e1..9535b37 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -127,7 +127,7 @@ func server_main(ctl_addrs []string, rpc_addrs []string, pxy_addrs []string, wpx if len(config.PxyAddrs) <= 0 { config.PxyAddrs = cfg.PXY.Service.Addrs } if len(config.WpxAddrs) <= 0 { config.WpxAddrs = cfg.WPX.Service.Addrs } - config.CtlAuth, err = make_server_auth_config(&cfg.CTL.Service.Auth) + config.CtlAuth, err = make_http_auth_config(&cfg.CTL.Service.Auth) if err != nil { return err } config.CtlPrefix = cfg.CTL.Service.Prefix @@ -247,8 +247,9 @@ func parse_client_route_config(v string) (*hodu.ClientRouteConfig, error) { func client_main(ctl_addrs []string, rpc_addrs []string, route_configs []string, cfg *ClientConfig) error { var c *hodu.Client - var ctltlscfg *tls.Config var rpctlscfg *tls.Config + var ctltlscfg *tls.Config + var ctl_auth *hodu.HttpAuthConfig var ctl_prefix string var cc hodu.ClientConfig var logger *AppLogger @@ -265,18 +266,17 @@ func client_main(ctl_addrs []string, rpc_addrs []string, route_configs []string, logmask = hodu.LOG_ALL if cfg != nil { ctltlscfg, err = make_tls_server_config(&cfg.CTL.TLS) - if err != nil { - return err - } + if err != nil { return err } rpctlscfg, err = make_tls_client_config(&cfg.RPC.TLS) - if err != nil { - return err - } + if err != nil { return err } if len(ctl_addrs) <= 0 { ctl_addrs = cfg.CTL.Service.Addrs } if len(rpc_addrs) <= 0 { rpc_addrs = cfg.RPC.Endpoint.Addrs } ctl_prefix = cfg.CTL.Service.Prefix + ctl_auth, err = make_http_auth_config(&cfg.CTL.Service.Auth) + if err != nil { return err } + cc.ServerSeedTmout = cfg.RPC.Endpoint.SeedTmout cc.ServerAuthority = cfg.RPC.Endpoint.Authority logmask = log_strings_to_mask(cfg.APP.LogMask) @@ -314,6 +314,7 @@ func client_main(ctl_addrs []string, rpc_addrs []string, route_configs []string, ctl_addrs, ctl_prefix, ctltlscfg, + ctl_auth, rpctlscfg, max_rpc_conns, max_peers, diff --git a/hodu.go b/hodu.go index 533c81d..29b178f 100644 --- a/hodu.go +++ b/hodu.go @@ -1,10 +1,12 @@ package hodu +import "crypto/rsa" import "fmt" import "net" import "net/http" import "net/netip" import "os" +import "path/filepath" import "runtime" import "strings" import "sync" @@ -61,6 +63,17 @@ type HttpAccessRule struct { Action HttpAccessAction } +type HttpAuthCredMap map[string]string + +type HttpAuthConfig struct { + Enabled bool + Realm string + Creds HttpAuthCredMap + TokenTtl time.Duration + TokenRsaKey *rsa.PrivateKey + AccessRules []HttpAccessRule +} + type JsonErrmsg struct { Text string `json:"error-text"` } @@ -230,13 +243,15 @@ func DurationToSecString(d time.Duration) string { return fmt.Sprintf("%.09f", d.Seconds()) } +// ------------------------------------ + func WriteJsonRespHeader(w http.ResponseWriter, status_code int) int { w.Header().Set("Content-Type", "application/json") w.WriteHeader(status_code) return status_code } -func write_js_resp_header(w http.ResponseWriter, status_code int) int { +func WriteJsRespHeader(w http.ResponseWriter, status_code int) int { w.Header().Set("Content-Type", "application/javascript") w.WriteHeader(status_code) return status_code @@ -259,6 +274,8 @@ func WriteEmptyRespHeader(w http.ResponseWriter, status_code int) int { return status_code } +// ------------------------------------ + func server_route_to_proxy_info(r *ServerRoute) *ServerRouteProxyInfo { return &ServerRouteProxyInfo{ SvcOption: r.SvcOption, @@ -279,6 +296,8 @@ func proxy_info_to_server_route(pi *ServerRouteProxyInfo) *ServerRoute { } } +// ------------------------------------ + func (stats *json_out_go_stats) from_runtime_stats() { var mstat runtime.MemStats @@ -312,3 +331,80 @@ func (stats *json_out_go_stats) from_runtime_stats() { stats.GCSysBytes = mstat.GCSys stats.OtherSysBytes = mstat.OtherSys } + +// ------------------------------------ + +func (auth *HttpAuthConfig) Authenticate(req *http.Request) (int, string) { + var rule HttpAccessRule + var raddrport netip.AddrPort + var raddr netip.Addr + var err error + + raddrport, err = netip.ParseAddrPort(req.RemoteAddr) + if err == nil { raddr = raddrport.Addr() } + + for _, rule = range auth.AccessRules { + // i don't take into account X-Forwarded-For and similar headers + if req.URL.Path == rule.Prefix || strings.HasPrefix(req.URL.Path, filepath.Clean(rule.Prefix + "/")) { + var org_net_ok bool + + if len(rule.OrgNets) > 0 && raddr.IsValid() { + var netpfx netip.Prefix + + org_net_ok = false + for _, netpfx = range rule.OrgNets { + if err == nil && netpfx.Contains(raddr) { + org_net_ok = true + break + } + } + } else { + org_net_ok = true + } + + if org_net_ok { + if rule.Action == HTTP_ACCESS_ACCEPT { + return http.StatusOK, "" + } else if rule.Action == HTTP_ACCESS_REJECT { + return http.StatusForbidden, "" + } + } + } + } + + if auth != nil && auth.Enabled { + var auth_hdr string + var auth_parts []string + var username string + var password string + var credpass string + var ok bool + var err error + + auth_hdr = req.Header.Get("Authorization") + if auth_hdr == "" { return http.StatusUnauthorized, auth.Realm } + + auth_parts = strings.Fields(auth_hdr) + if len(auth_parts) == 2 && strings.EqualFold(auth_parts[0], "Bearer") && auth.TokenRsaKey != nil { + var jwt *JWT[ServerTokenClaim] + var claim ServerTokenClaim + jwt = NewJWT(auth.TokenRsaKey, &claim) + err = jwt.VerifyRS512(strings.TrimSpace(auth_parts[1])) + if err == nil { + // verification ok. let's check the actual payload + var now time.Time + now = time.Now() + if now.After(time.Unix(claim.IssuedAt, 0)) && now.Before(time.Unix(claim.ExpiresAt, 0)) { return http.StatusOK, "" } // not expired + } + } + + // fall back to basic authentication + username, password, ok = req.BasicAuth() + if !ok { return http.StatusUnauthorized, auth.Realm } + + credpass, ok = auth.Creds[username] + if !ok || credpass != password { return http.StatusUnauthorized, auth.Realm } + } + + return http.StatusOK, "" +} diff --git a/server-ctl.go b/server-ctl.go index 6f73428..21c43a7 100644 --- a/server-ctl.go +++ b/server-ctl.go @@ -3,9 +3,6 @@ package hodu import "encoding/json" import "fmt" import "net/http" -import "net/netip" -import "path/filepath" -import "strings" import "time" type ServerTokenClaim struct { @@ -82,81 +79,8 @@ func (ctl *server_ctl) Id() string { } func (ctl *server_ctl) Authenticate(req *http.Request) (int, string) { - var s *Server - var rule HttpAccessRule - var raddrport netip.AddrPort - var raddr netip.Addr - var err error - - s = ctl.s - - raddrport, err = netip.ParseAddrPort(req.RemoteAddr) - if err == nil { raddr = raddrport.Addr() } - - for _, rule = range s.cfg.CtlAuth.AccessRules { - // i don't take into account X-Forwarded-For and similar headers - if req.URL.Path == rule.Prefix || strings.HasPrefix(req.URL.Path, filepath.Clean(rule.Prefix + "/")) { - var org_net_ok bool - - if len(rule.OrgNets) > 0 && raddr.IsValid() { - var netpfx netip.Prefix - - org_net_ok = false - for _, netpfx = range rule.OrgNets { - if err == nil && netpfx.Contains(raddr) { - org_net_ok = true - break - } - } - } else { - org_net_ok = true - } - - if org_net_ok { - if rule.Action == HTTP_ACCESS_ACCEPT { - return http.StatusOK, "" - } else if rule.Action == HTTP_ACCESS_REJECT { - return http.StatusForbidden, "" - } - } - } - } - - if s.cfg.CtlAuth != nil && s.cfg.CtlAuth.Enabled { - var auth_hdr string - var auth_parts []string - var username string - var password string - var credpass string - var ok bool - var err error - - auth_hdr = req.Header.Get("Authorization") - if auth_hdr == "" { return http.StatusUnauthorized, s.cfg.CtlAuth.Realm } - - auth_parts = strings.Fields(auth_hdr) - if len(auth_parts) == 2 && strings.EqualFold(auth_parts[0], "Bearer") && s.cfg.CtlAuth.TokenRsaKey != nil { - var jwt *JWT[ServerTokenClaim] - var claim ServerTokenClaim - jwt = NewJWT(s.cfg.CtlAuth.TokenRsaKey, &claim) - err = jwt.VerifyRS512(strings.TrimSpace(auth_parts[1])) - if err == nil { - // verification ok. let's check the actual payload - var now time.Time - now = time.Now() - if now.After(time.Unix(claim.IssuedAt, 0)) && now.Before(time.Unix(claim.ExpiresAt, 0)) { return http.StatusOK, "" } // not expired - } - } - - // fall back to basic authentication - username, password, ok = req.BasicAuth() - if !ok { return http.StatusUnauthorized, s.cfg.CtlAuth.Realm } - - credpass, ok = s.cfg.CtlAuth.Creds[username] - if !ok || credpass != password { return http.StatusUnauthorized, s.cfg.CtlAuth.Realm } - } - - return http.StatusOK, "" + if ctl.s.cfg.CtlAuth == nil { return http.StatusOK, "" } + return ctl.s.cfg.CtlAuth.Authenticate(req) } // ------------------------------------ diff --git a/server-proxy.go b/server-proxy.go index e4c5612..e0e97e7 100644 --- a/server-proxy.go +++ b/server-proxy.go @@ -474,10 +474,10 @@ func (pxy *server_proxy_xterm_file) ServeHTTP(w http.ResponseWriter, req *http.R switch pxy.file { case "xterm.js": - status_code = write_js_resp_header(w, http.StatusOK) + status_code = WriteJsRespHeader(w, http.StatusOK) w.Write(xterm_js) case "xterm-addon-fit.js": - status_code = write_js_resp_header(w, http.StatusOK) + status_code = WriteJsRespHeader(w, http.StatusOK) w.Write(xterm_addon_fit_js) case "xterm.css": status_code = WriteCssRespHeader(w, http.StatusOK) diff --git a/server.go b/server.go index 4bea8cd..45de1de 100644 --- a/server.go +++ b/server.go @@ -1,7 +1,6 @@ package hodu import "context" -import "crypto/rsa" import "crypto/tls" import "errors" import "fmt" @@ -43,17 +42,6 @@ type ServerSvcPortMap = map[PortId]ConnRouteId type ServerWpxResponseTransformer func(r *ServerRouteProxyInfo, resp *http.Response) io.Reader type ServerWpxForeignPortProxyMaker func(wpx_type string, port_id string) (*ServerRouteProxyInfo, error) -type ServerHttpAuthCredMap map[string]string - -type ServerHttpAuthConfig struct { - Enabled bool - Realm string - Creds ServerHttpAuthCredMap - TokenTtl time.Duration - TokenRsaKey *rsa.PrivateKey - AccessRules []HttpAccessRule -} - type ServerConfig struct { RpcAddrs []string RpcTls *tls.Config @@ -63,7 +51,7 @@ type ServerConfig struct { CtlAddrs []string CtlTls *tls.Config CtlPrefix string - CtlAuth *ServerHttpAuthConfig + CtlAuth *HttpAuthConfig PxyAddrs []string PxyTls *tls.Config