added http auth config to the client-side control channel

This commit is contained in:
hyung-hwan 2025-02-01 00:06:05 +09:00
parent 16327fc576
commit 0fb57cb77b
8 changed files with 206 additions and 127 deletions

View File

@ -91,6 +91,10 @@ type client_ctl struct {
id string
}
type client_ctl_token struct {
client_ctl
}
type client_ctl_client_conns struct {
client_ctl
//c *Client
@ -131,6 +135,62 @@ func (ctl *client_ctl) Id() string {
return ctl.id
}
func (ctl *client_ctl) Authenticate(req *http.Request) (int, string) {
if ctl.c.ctl_auth == nil { return http.StatusOK, "" }
return ctl.c.ctl_auth.Authenticate(req)
}
// ------------------------------------
func (ctl *client_ctl_token) ServeHTTP(w http.ResponseWriter, req *http.Request) (int, error) {
var c *Client
var status_code int
var je *json.Encoder
var err error
c = ctl.c
je = json.NewEncoder(w)
switch req.Method {
case http.MethodGet:
var jwt *JWT[ServerTokenClaim]
var claim ServerTokenClaim
var tok string
var now time.Time
if c.ctl_auth == nil || !c.ctl_auth.Enabled || c.ctl_auth.TokenRsaKey == nil {
status_code = WriteJsonRespHeader(w, http.StatusForbidden)
err = fmt.Errorf("auth not enabled or token rsa key not set")
je.Encode(JsonErrmsg{Text: err.Error()})
goto oops
}
now = time.Now()
claim.IssuedAt = now.Unix()
claim.ExpiresAt = now.Add(c.ctl_auth.TokenTtl).Unix()
jwt = NewJWT(c.ctl_auth.TokenRsaKey, &claim)
tok, err = jwt.SignRS512()
if err != nil {
status_code = WriteJsonRespHeader(w, http.StatusInternalServerError)
je.Encode(JsonErrmsg{Text: err.Error()})
goto oops
}
status_code = WriteJsonRespHeader(w, http.StatusOK)
err = je.Encode(json_out_token{ AccessToken: tok }) // TODO: refresh token
if err != nil { goto oops }
default:
status_code = WriteEmptyRespHeader(w, http.StatusMethodNotAllowed)
}
//done:
return status_code, nil
oops:
return status_code, err
}
// ------------------------------------
func (ctl *client_ctl_client_conns) ServeHTTP(w http.ResponseWriter, req *http.Request) (int, error) {

View File

@ -55,16 +55,17 @@ type ClientConfigActive struct {
type Client struct {
Named
ctx context.Context
ctx_cancel context.CancelFunc
ctltlscfg *tls.Config
rpctlscfg *tls.Config
ctx context.Context
ctx_cancel context.CancelFunc
ext_mtx sync.Mutex
ext_mtx sync.Mutex
ext_svcs []Service
rpc_tls *tls.Config
ctl_tls *tls.Config
ctl_addr []string
ctl_prefix string
ctl_prefix string
ctl_auth *HttpAuthConfig
ctl_mux *http.ServeMux
ctl []*http.Server // control server
@ -930,16 +931,16 @@ start_over:
cts.State = CLIENT_CONN_CONNECTING
cts.cfg.Index = (cts.cfg.Index + 1) % len(cts.cfg.ServerAddrs)
cts.cli.log.Write(cts.Sid, LOG_INFO, "Connecting to server[%d] %s", cts.cfg.Index, cts.cfg.ServerAddrs[cts.cfg.Index])
if cts.cli.rpctlscfg == nil {
if cts.cli.rpc_tls == nil {
opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials()))
if cts.cfg.ServerAuthority != "" { opts = append(opts, grpc.WithAuthority(cts.cfg.ServerAuthority)) }
} else {
opts = append(opts, grpc.WithTransportCredentials(credentials.NewTLS(cts.cli.rpctlscfg)))
opts = append(opts, grpc.WithTransportCredentials(credentials.NewTLS(cts.cli.rpc_tls)))
// set the http2 :authority header with tls server name defined.
if cts.cfg.ServerAuthority != "" {
opts = append(opts, grpc.WithAuthority(cts.cfg.ServerAuthority))
} else if cts.cli.rpctlscfg.ServerName != "" {
opts = append(opts, grpc.WithAuthority(cts.cli.rpctlscfg.ServerName))
} else if cts.cli.rpc_tls.ServerName != "" {
opts = append(opts, grpc.WithAuthority(cts.cli.rpc_tls.ServerName))
}
}
if cts.cfg.ServerSeedTmout > 0 {
@ -1225,6 +1226,7 @@ func (hlw *client_ctl_log_writer) Write(p []byte) (n int, err error) {
type ClientHttpHandler interface {
Id() string
Authenticate(req *http.Request) (int, string)
ServeHTTP (w http.ResponseWriter, req *http.Request) (int, error)
}
@ -1234,6 +1236,7 @@ func (c *Client) wrap_http_handler(handler ClientHttpHandler) http.Handler {
var err error
var start_time time.Time
var time_taken time.Duration
var realm string
// this deferred function is to overcome the recovering implemenation
// from panic done in go's http server. in that implemenation, panic
@ -1247,11 +1250,15 @@ func (c *Client) wrap_http_handler(handler ClientHttpHandler) http.Handler {
start_time = time.Now()
// TODO: some kind of authorization, especially for ctl
//req.BasicAuth()
//req.Header.Get("Authorization")
status_code, err = handler.ServeHTTP(w, req)
status_code, realm = handler.Authenticate(req)
if status_code == http.StatusUnauthorized && realm != "" {
w.Header().Set("WWW-Authenticate", fmt.Sprintf("Basic Realm=\"%s\"", realm))
WriteEmptyRespHeader(w, status_code)
} else if status_code == http.StatusOK {
status_code, err = handler.ServeHTTP(w, req)
} else {
WriteEmptyRespHeader(w, status_code)
}
// TODO: statistics by status_code and end point types.
time_taken = time.Now().Sub(start_time)
@ -1266,15 +1273,13 @@ func (c *Client) wrap_http_handler(handler ClientHttpHandler) http.Handler {
})
}
func NewClient(ctx context.Context, name string, logger Logger, ctl_addrs []string, ctl_prefix string, ctltlscfg *tls.Config, rpctlscfg *tls.Config, rpc_max int, peer_max int, peer_conn_tmout time.Duration) *Client {
func NewClient(ctx context.Context, name string, logger Logger, ctl_addrs []string, ctl_prefix string, ctl_tls *tls.Config, ctl_auth *HttpAuthConfig, rpc_tls *tls.Config, rpc_max int, peer_max int, peer_conn_tmout time.Duration) *Client {
var c Client
var i int
var hs_log *log.Logger
c.name = name
c.ctx, c.ctx_cancel = context.WithCancel(ctx)
c.ctltlscfg = ctltlscfg
c.rpctlscfg = rpctlscfg
c.ext_svcs = make([]Service, 0, 1)
c.ptc_tmout = peer_conn_tmout
c.ptc_limit = peer_max
@ -1284,8 +1289,11 @@ func NewClient(ctx context.Context, name string, logger Logger, ctl_addrs []stri
c.stop_req.Store(false)
c.stop_chan = make(chan bool, 8)
c.log = logger
c.ctl_prefix = ctl_prefix
c.rpc_tls = rpc_tls
c.ctl_auth = ctl_auth
c.ctl_tls = ctl_tls
c.ctl_prefix = ctl_prefix
c.ctl_mux = http.NewServeMux()
c.ctl_mux.Handle(c.ctl_prefix + "/_ctl/client-conns",
c.wrap_http_handler(&client_ctl_client_conns{client_ctl{c: &c, id: HS_ID_CTL}}))
@ -1303,6 +1311,8 @@ func NewClient(ctx context.Context, name string, logger Logger, ctl_addrs []stri
c.wrap_http_handler(&client_ctl_client_conns_id_routes_id_peers_id{client_ctl{c: &c, id: HS_ID_CTL}}))
c.ctl_mux.Handle(c.ctl_prefix + "/_ctl/stats",
c.wrap_http_handler(&client_ctl_stats{client_ctl{c: &c, id: HS_ID_CTL}}))
c.ctl_mux.Handle(c.ctl_prefix + "/_ctl/token",
c.wrap_http_handler(&client_ctl_token{client_ctl{c: &c, id: HS_ID_CTL}}))
// TODO: make this optional. add this endpoint only if it's enabled...
c.promreg = prometheus.NewRegistry()
@ -1322,7 +1332,7 @@ func NewClient(ctx context.Context, name string, logger Logger, ctl_addrs []stri
c.ctl[i] = &http.Server{
Addr: ctl_addrs[i],
Handler: c.ctl_mux,
TLSConfig: c.ctltlscfg,
TLSConfig: c.ctl_tls,
ErrorLog: hs_log,
// TODO: more settings
}
@ -1575,10 +1585,10 @@ func (c *Client) RunCtlTask(wg *sync.WaitGroup) {
// check it again to make the guard slightly more stable
// although it's still possible that the stop request is made
// after Listen()
if c.ctltlscfg == nil {
if c.ctl_tls == nil {
err = cs.Serve(l)
} else {
err = cs.ServeTLS(l, "", "") // c.ctltlscfg must provide a certificate and a key
err = cs.ServeTLS(l, "", "") // c.ctl_tls must provide a certificate and a key
}
} else {
err = fmt.Errorf("stop requested")

View File

@ -358,8 +358,8 @@ func make_tls_client_config(cfg *ClientTLSConfig) (*tls.Config, error) {
}
// --------------------------------------------------------------------
func make_server_auth_config(cfg *HttpAuthConfig) (*hodu.ServerHttpAuthConfig, error) {
var config hodu.ServerHttpAuthConfig
func make_http_auth_config(cfg *HttpAuthConfig) (*hodu.HttpAuthConfig, error) {
var config hodu.HttpAuthConfig
var cred string
var b []byte
var x []string
@ -371,7 +371,7 @@ func make_server_auth_config(cfg *HttpAuthConfig) (*hodu.ServerHttpAuthConfig, e
config.Enabled = cfg.Enabled
config.Realm = cfg.Realm
config.Creds = make(hodu.ServerHttpAuthCredMap)
config.Creds = make(hodu.HttpAuthCredMap)
config.TokenTtl, err = hodu.ParseDurationString(cfg.TokenTtl)
if err != nil {
return nil, fmt.Errorf("invalid token ttl %s - %s", cred, err)

View File

@ -127,7 +127,7 @@ func server_main(ctl_addrs []string, rpc_addrs []string, pxy_addrs []string, wpx
if len(config.PxyAddrs) <= 0 { config.PxyAddrs = cfg.PXY.Service.Addrs }
if len(config.WpxAddrs) <= 0 { config.WpxAddrs = cfg.WPX.Service.Addrs }
config.CtlAuth, err = make_server_auth_config(&cfg.CTL.Service.Auth)
config.CtlAuth, err = make_http_auth_config(&cfg.CTL.Service.Auth)
if err != nil { return err }
config.CtlPrefix = cfg.CTL.Service.Prefix
@ -247,8 +247,9 @@ func parse_client_route_config(v string) (*hodu.ClientRouteConfig, error) {
func client_main(ctl_addrs []string, rpc_addrs []string, route_configs []string, cfg *ClientConfig) error {
var c *hodu.Client
var ctltlscfg *tls.Config
var rpctlscfg *tls.Config
var ctltlscfg *tls.Config
var ctl_auth *hodu.HttpAuthConfig
var ctl_prefix string
var cc hodu.ClientConfig
var logger *AppLogger
@ -265,18 +266,17 @@ func client_main(ctl_addrs []string, rpc_addrs []string, route_configs []string,
logmask = hodu.LOG_ALL
if cfg != nil {
ctltlscfg, err = make_tls_server_config(&cfg.CTL.TLS)
if err != nil {
return err
}
if err != nil { return err }
rpctlscfg, err = make_tls_client_config(&cfg.RPC.TLS)
if err != nil {
return err
}
if err != nil { return err }
if len(ctl_addrs) <= 0 { ctl_addrs = cfg.CTL.Service.Addrs }
if len(rpc_addrs) <= 0 { rpc_addrs = cfg.RPC.Endpoint.Addrs }
ctl_prefix = cfg.CTL.Service.Prefix
ctl_auth, err = make_http_auth_config(&cfg.CTL.Service.Auth)
if err != nil { return err }
cc.ServerSeedTmout = cfg.RPC.Endpoint.SeedTmout
cc.ServerAuthority = cfg.RPC.Endpoint.Authority
logmask = log_strings_to_mask(cfg.APP.LogMask)
@ -314,6 +314,7 @@ func client_main(ctl_addrs []string, rpc_addrs []string, route_configs []string,
ctl_addrs,
ctl_prefix,
ctltlscfg,
ctl_auth,
rpctlscfg,
max_rpc_conns,
max_peers,

98
hodu.go
View File

@ -1,10 +1,12 @@
package hodu
import "crypto/rsa"
import "fmt"
import "net"
import "net/http"
import "net/netip"
import "os"
import "path/filepath"
import "runtime"
import "strings"
import "sync"
@ -61,6 +63,17 @@ type HttpAccessRule struct {
Action HttpAccessAction
}
type HttpAuthCredMap map[string]string
type HttpAuthConfig struct {
Enabled bool
Realm string
Creds HttpAuthCredMap
TokenTtl time.Duration
TokenRsaKey *rsa.PrivateKey
AccessRules []HttpAccessRule
}
type JsonErrmsg struct {
Text string `json:"error-text"`
}
@ -230,13 +243,15 @@ func DurationToSecString(d time.Duration) string {
return fmt.Sprintf("%.09f", d.Seconds())
}
// ------------------------------------
func WriteJsonRespHeader(w http.ResponseWriter, status_code int) int {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status_code)
return status_code
}
func write_js_resp_header(w http.ResponseWriter, status_code int) int {
func WriteJsRespHeader(w http.ResponseWriter, status_code int) int {
w.Header().Set("Content-Type", "application/javascript")
w.WriteHeader(status_code)
return status_code
@ -259,6 +274,8 @@ func WriteEmptyRespHeader(w http.ResponseWriter, status_code int) int {
return status_code
}
// ------------------------------------
func server_route_to_proxy_info(r *ServerRoute) *ServerRouteProxyInfo {
return &ServerRouteProxyInfo{
SvcOption: r.SvcOption,
@ -279,6 +296,8 @@ func proxy_info_to_server_route(pi *ServerRouteProxyInfo) *ServerRoute {
}
}
// ------------------------------------
func (stats *json_out_go_stats) from_runtime_stats() {
var mstat runtime.MemStats
@ -312,3 +331,80 @@ func (stats *json_out_go_stats) from_runtime_stats() {
stats.GCSysBytes = mstat.GCSys
stats.OtherSysBytes = mstat.OtherSys
}
// ------------------------------------
func (auth *HttpAuthConfig) Authenticate(req *http.Request) (int, string) {
var rule HttpAccessRule
var raddrport netip.AddrPort
var raddr netip.Addr
var err error
raddrport, err = netip.ParseAddrPort(req.RemoteAddr)
if err == nil { raddr = raddrport.Addr() }
for _, rule = range auth.AccessRules {
// i don't take into account X-Forwarded-For and similar headers
if req.URL.Path == rule.Prefix || strings.HasPrefix(req.URL.Path, filepath.Clean(rule.Prefix + "/")) {
var org_net_ok bool
if len(rule.OrgNets) > 0 && raddr.IsValid() {
var netpfx netip.Prefix
org_net_ok = false
for _, netpfx = range rule.OrgNets {
if err == nil && netpfx.Contains(raddr) {
org_net_ok = true
break
}
}
} else {
org_net_ok = true
}
if org_net_ok {
if rule.Action == HTTP_ACCESS_ACCEPT {
return http.StatusOK, ""
} else if rule.Action == HTTP_ACCESS_REJECT {
return http.StatusForbidden, ""
}
}
}
}
if auth != nil && auth.Enabled {
var auth_hdr string
var auth_parts []string
var username string
var password string
var credpass string
var ok bool
var err error
auth_hdr = req.Header.Get("Authorization")
if auth_hdr == "" { return http.StatusUnauthorized, auth.Realm }
auth_parts = strings.Fields(auth_hdr)
if len(auth_parts) == 2 && strings.EqualFold(auth_parts[0], "Bearer") && auth.TokenRsaKey != nil {
var jwt *JWT[ServerTokenClaim]
var claim ServerTokenClaim
jwt = NewJWT(auth.TokenRsaKey, &claim)
err = jwt.VerifyRS512(strings.TrimSpace(auth_parts[1]))
if err == nil {
// verification ok. let's check the actual payload
var now time.Time
now = time.Now()
if now.After(time.Unix(claim.IssuedAt, 0)) && now.Before(time.Unix(claim.ExpiresAt, 0)) { return http.StatusOK, "" } // not expired
}
}
// fall back to basic authentication
username, password, ok = req.BasicAuth()
if !ok { return http.StatusUnauthorized, auth.Realm }
credpass, ok = auth.Creds[username]
if !ok || credpass != password { return http.StatusUnauthorized, auth.Realm }
}
return http.StatusOK, ""
}

View File

@ -3,9 +3,6 @@ package hodu
import "encoding/json"
import "fmt"
import "net/http"
import "net/netip"
import "path/filepath"
import "strings"
import "time"
type ServerTokenClaim struct {
@ -82,81 +79,8 @@ func (ctl *server_ctl) Id() string {
}
func (ctl *server_ctl) Authenticate(req *http.Request) (int, string) {
var s *Server
var rule HttpAccessRule
var raddrport netip.AddrPort
var raddr netip.Addr
var err error
s = ctl.s
raddrport, err = netip.ParseAddrPort(req.RemoteAddr)
if err == nil { raddr = raddrport.Addr() }
for _, rule = range s.cfg.CtlAuth.AccessRules {
// i don't take into account X-Forwarded-For and similar headers
if req.URL.Path == rule.Prefix || strings.HasPrefix(req.URL.Path, filepath.Clean(rule.Prefix + "/")) {
var org_net_ok bool
if len(rule.OrgNets) > 0 && raddr.IsValid() {
var netpfx netip.Prefix
org_net_ok = false
for _, netpfx = range rule.OrgNets {
if err == nil && netpfx.Contains(raddr) {
org_net_ok = true
break
}
}
} else {
org_net_ok = true
}
if org_net_ok {
if rule.Action == HTTP_ACCESS_ACCEPT {
return http.StatusOK, ""
} else if rule.Action == HTTP_ACCESS_REJECT {
return http.StatusForbidden, ""
}
}
}
}
if s.cfg.CtlAuth != nil && s.cfg.CtlAuth.Enabled {
var auth_hdr string
var auth_parts []string
var username string
var password string
var credpass string
var ok bool
var err error
auth_hdr = req.Header.Get("Authorization")
if auth_hdr == "" { return http.StatusUnauthorized, s.cfg.CtlAuth.Realm }
auth_parts = strings.Fields(auth_hdr)
if len(auth_parts) == 2 && strings.EqualFold(auth_parts[0], "Bearer") && s.cfg.CtlAuth.TokenRsaKey != nil {
var jwt *JWT[ServerTokenClaim]
var claim ServerTokenClaim
jwt = NewJWT(s.cfg.CtlAuth.TokenRsaKey, &claim)
err = jwt.VerifyRS512(strings.TrimSpace(auth_parts[1]))
if err == nil {
// verification ok. let's check the actual payload
var now time.Time
now = time.Now()
if now.After(time.Unix(claim.IssuedAt, 0)) && now.Before(time.Unix(claim.ExpiresAt, 0)) { return http.StatusOK, "" } // not expired
}
}
// fall back to basic authentication
username, password, ok = req.BasicAuth()
if !ok { return http.StatusUnauthorized, s.cfg.CtlAuth.Realm }
credpass, ok = s.cfg.CtlAuth.Creds[username]
if !ok || credpass != password { return http.StatusUnauthorized, s.cfg.CtlAuth.Realm }
}
return http.StatusOK, ""
if ctl.s.cfg.CtlAuth == nil { return http.StatusOK, "" }
return ctl.s.cfg.CtlAuth.Authenticate(req)
}
// ------------------------------------

View File

@ -474,10 +474,10 @@ func (pxy *server_proxy_xterm_file) ServeHTTP(w http.ResponseWriter, req *http.R
switch pxy.file {
case "xterm.js":
status_code = write_js_resp_header(w, http.StatusOK)
status_code = WriteJsRespHeader(w, http.StatusOK)
w.Write(xterm_js)
case "xterm-addon-fit.js":
status_code = write_js_resp_header(w, http.StatusOK)
status_code = WriteJsRespHeader(w, http.StatusOK)
w.Write(xterm_addon_fit_js)
case "xterm.css":
status_code = WriteCssRespHeader(w, http.StatusOK)

View File

@ -1,7 +1,6 @@
package hodu
import "context"
import "crypto/rsa"
import "crypto/tls"
import "errors"
import "fmt"
@ -43,17 +42,6 @@ type ServerSvcPortMap = map[PortId]ConnRouteId
type ServerWpxResponseTransformer func(r *ServerRouteProxyInfo, resp *http.Response) io.Reader
type ServerWpxForeignPortProxyMaker func(wpx_type string, port_id string) (*ServerRouteProxyInfo, error)
type ServerHttpAuthCredMap map[string]string
type ServerHttpAuthConfig struct {
Enabled bool
Realm string
Creds ServerHttpAuthCredMap
TokenTtl time.Duration
TokenRsaKey *rsa.PrivateKey
AccessRules []HttpAccessRule
}
type ServerConfig struct {
RpcAddrs []string
RpcTls *tls.Config
@ -63,7 +51,7 @@ type ServerConfig struct {
CtlAddrs []string
CtlTls *tls.Config
CtlPrefix string
CtlAuth *ServerHttpAuthConfig
CtlAuth *HttpAuthConfig
PxyAddrs []string
PxyTls *tls.Config