diff --git a/client-ctl.go b/client-ctl.go index ddcc403..0d213fe 100644 --- a/client-ctl.go +++ b/client-ctl.go @@ -135,6 +135,10 @@ func (ctl *client_ctl) Id() string { return ctl.id } +func (ctl *client_ctl) Cors(req *http.Request) bool { + return ctl.c.ctl_cors +} + func (ctl *client_ctl) Authenticate(req *http.Request) (int, string) { if ctl.c.ctl_auth == nil { return http.StatusOK, "" } return ctl.c.ctl_auth.Authenticate(req) @@ -257,7 +261,7 @@ func (ctl *client_ctl_client_conns) ServeHTTP(w http.ResponseWriter, req *http.R // after hacing connected to the server. therefore, the json_in_client_conn // type contains a server address field only. var s json_in_client_conn - var cc ClientConfig + var cc ClientConnConfig var cts *ClientConn err = json.NewDecoder(req.Body).Decode(&s) diff --git a/client.go b/client.go index 3a078f1..ecc2716 100644 --- a/client.go +++ b/client.go @@ -39,16 +39,29 @@ type ClientRouteConfig struct { Static bool } -type ClientConfig struct { +type ClientConnConfig struct { ServerAddrs []string Routes []ClientRouteConfig ServerSeedTmout time.Duration ServerAuthority string // http2 :authority header } -type ClientConfigActive struct { +type ClientConnConfigActive struct { Index int - ClientConfig + ClientConnConfig +} + +type ClientConfig struct { + CtlAddrs []string + CtlTls *tls.Config + CtlPrefix string + CtlAuth *HttpAuthConfig + CtlCors bool + + RpcTls *tls.Config + RpcConnMax int + PeerConnMax int + PeerConnTmout time.Duration } type Client struct { @@ -64,6 +77,7 @@ type Client struct { ctl_tls *tls.Config ctl_addr []string ctl_prefix string + ctl_cors bool ctl_auth *HttpAuthConfig ctl_mux *http.ServeMux ctl []*http.Server // control server @@ -102,7 +116,7 @@ const ( // client connection to server type ClientConn struct { cli *Client - cfg ClientConfigActive + cfg ClientConnConfigActive Id ConnId Sid string // id rendered in string State ClientConnState @@ -660,14 +674,14 @@ func (r *ClientRoute) ReportEvent(pts_id PeerId, event_type PACKET_KIND, event_d } // -------------------------------------------------------------------- -func NewClientConn(c *Client, cfg *ClientConfig) *ClientConn { +func NewClientConn(c *Client, cfg *ClientConnConfig) *ClientConn { var cts ClientConn var i int cts.cli = c cts.route_map = make(ClientRouteMap) cts.route_next_id = 1 - cts.cfg.ClientConfig = *cfg + cts.cfg.ClientConnConfig = *cfg cts.stop_req.Store(false) cts.stop_chan = make(chan bool, 8) @@ -1225,6 +1239,7 @@ func (hlw *client_ctl_log_writer) Write(p []byte) (n int, err error) { type ClientHttpHandler interface { Id() string + Cors(req *http.Request) bool Authenticate(req *http.Request) (int, string) ServeHTTP (w http.ResponseWriter, req *http.Request) (int, error) } @@ -1249,16 +1264,24 @@ func (c *Client) wrap_http_handler(handler ClientHttpHandler) http.Handler { start_time = time.Now() - status_code, realm = handler.Authenticate(req) - if status_code == http.StatusUnauthorized { - if realm != "" { - w.Header().Set("WWW-Authenticate", fmt.Sprintf("Basic Realm=\"%s\"", realm)) - } - WriteEmptyRespHeader(w, status_code) - } else if status_code == http.StatusOK { - status_code, err = handler.ServeHTTP(w, req) + if handler.Cors(req) { + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Headers", "*") + } + if req.Method == http.MethodOptions { + status_code = WriteEmptyRespHeader(w, http.StatusOK) } else { - WriteEmptyRespHeader(w, status_code) + status_code, realm = handler.Authenticate(req) + if status_code == http.StatusUnauthorized { + if realm != "" { + w.Header().Set("WWW-Authenticate", fmt.Sprintf("Basic Realm=\"%s\"", realm)) + } + WriteEmptyRespHeader(w, status_code) + } else if status_code == http.StatusOK { + status_code, err = handler.ServeHTTP(w, req) + } else { + WriteEmptyRespHeader(w, status_code) + } } // TODO: statistics by status_code and end point types. @@ -1274,7 +1297,7 @@ func (c *Client) wrap_http_handler(handler ClientHttpHandler) http.Handler { }) } -func NewClient(ctx context.Context, name string, logger Logger, ctl_addrs []string, ctl_prefix string, ctl_tls *tls.Config, ctl_auth *HttpAuthConfig, rpc_tls *tls.Config, rpc_max int, peer_max int, peer_conn_tmout time.Duration) *Client { +func NewClient(ctx context.Context, name string, logger Logger, cfg *ClientConfig) *Client { var c Client var i int var hs_log *log.Logger @@ -1282,19 +1305,20 @@ func NewClient(ctx context.Context, name string, logger Logger, ctl_addrs []stri c.name = name c.ctx, c.ctx_cancel = context.WithCancel(ctx) c.ext_svcs = make([]Service, 0, 1) - c.ptc_tmout = peer_conn_tmout - c.ptc_limit = peer_max - c.cts_limit = rpc_max + c.ptc_tmout = cfg.PeerConnTmout + c.ptc_limit = cfg.PeerConnMax + c.cts_limit = cfg.RpcConnMax c.cts_next_id = 1 c.cts_map = make(ClientConnMap) c.stop_req.Store(false) c.stop_chan = make(chan bool, 8) c.log = logger - c.rpc_tls = rpc_tls - c.ctl_auth = ctl_auth - c.ctl_tls = ctl_tls - c.ctl_prefix = ctl_prefix + c.rpc_tls = cfg.RpcTls + c.ctl_auth = cfg.CtlAuth + c.ctl_tls = cfg.CtlTls + c.ctl_prefix = cfg.CtlPrefix + c.ctl_cors = cfg.CtlCors c.ctl_mux = http.NewServeMux() c.ctl_mux.Handle(c.ctl_prefix + "/_ctl/client-conns", c.wrap_http_handler(&client_ctl_client_conns{client_ctl{c: &c, id: HS_ID_CTL}})) @@ -1323,15 +1347,15 @@ func NewClient(ctx context.Context, name string, logger Logger, ctl_addrs []stri promhttp.HandlerFor(c.promreg, promhttp.HandlerOpts{ EnableOpenMetrics: true })) - c.ctl_addr = make([]string, len(ctl_addrs)) - c.ctl = make([]*http.Server, len(ctl_addrs)) - copy(c.ctl_addr, ctl_addrs) + c.ctl_addr = make([]string, len(cfg.CtlAddrs)) + c.ctl = make([]*http.Server, len(cfg.CtlAddrs)) + copy(c.ctl_addr, cfg.CtlAddrs) hs_log = log.New(&client_ctl_log_writer{cli: &c}, "", 0) - for i = 0; i < len(ctl_addrs); i++ { + for i = 0; i < len(cfg.CtlAddrs); i++ { c.ctl[i] = &http.Server{ - Addr: ctl_addrs[i], + Addr: cfg.CtlAddrs[i], Handler: c.ctl_mux, TLSConfig: c.ctl_tls, ErrorLog: hs_log, @@ -1346,7 +1370,7 @@ func NewClient(ctx context.Context, name string, logger Logger, ctl_addrs []stri return &c } -func (c *Client) AddNewClientConn(cfg *ClientConfig) (*ClientConn, error) { +func (c *Client) AddNewClientConn(cfg *ClientConnConfig) (*ClientConn, error) { var cts *ClientConn var ok bool var start_id ConnId @@ -1622,7 +1646,7 @@ func (c *Client) RunTask(wg *sync.WaitGroup) { // so no call to wg.Done() } -func (c *Client) start_service(cfg *ClientConfig) (*ClientConn, error) { +func (c *Client) start_service(cfg *ClientConnConfig) (*ClientConn, error) { var cts *ClientConn var err error @@ -1639,10 +1663,10 @@ func (c *Client) start_service(cfg *ClientConfig) (*ClientConn, error) { } func (c *Client) StartService(data interface{}) { - var cfg *ClientConfig + var cfg *ClientConnConfig var ok bool - cfg, ok = data.(*ClientConfig) + cfg, ok = data.(*ClientConnConfig) if !ok { c.log.Write("", LOG_ERROR, "Failed to start service - invalid configuration - %v", data) } else { diff --git a/cmd/config.go b/cmd/config.go index b7ed371..e8452aa 100644 --- a/cmd/config.go +++ b/cmd/config.go @@ -66,6 +66,7 @@ type HttpAuthConfig struct { type CTLServiceConfig struct { Prefix string `yaml:"prefix"` // url prefix for control channel endpoints Addrs []string `yaml:"addresses"` + Cors bool `yaml:"cors"` Auth HttpAuthConfig `yaml:"auth"` } diff --git a/cmd/main.go b/cmd/main.go index dc6dea5..82ec653 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -1,7 +1,6 @@ package main import "context" -import "crypto/tls" import _ "embed" import "flag" import "fmt" @@ -13,7 +12,6 @@ import "os/signal" import "strings" import "sync" import "syscall" -import "time" // Don't change these items to 'const' as they can be overridden externally with a linker option var HODU_NAME string = "hodu" @@ -126,6 +124,7 @@ func server_main(ctl_addrs []string, rpc_addrs []string, pxy_addrs []string, wpx if len(config.PxyAddrs) <= 0 { config.PxyAddrs = cfg.PXY.Service.Addrs } if len(config.WpxAddrs) <= 0 { config.WpxAddrs = cfg.WPX.Service.Addrs } + config.CtlCors = cfg.CTL.Service.Cors config.CtlAuth, err = make_http_auth_config(&cfg.CTL.Service.Auth) if err != nil { return err } @@ -246,33 +245,33 @@ func parse_client_route_config(v string) (*hodu.ClientRouteConfig, error) { func client_main(ctl_addrs []string, rpc_addrs []string, route_configs []string, logfile string, cfg *ClientConfig) error { var c *hodu.Client - var rpctlscfg *tls.Config - var ctltlscfg *tls.Config - var ctl_auth *hodu.HttpAuthConfig - var ctl_prefix string - var cc hodu.ClientConfig + var config *hodu.ClientConfig + var cc hodu.ClientConnConfig var logger *AppLogger var logmask hodu.LogMask var logfile_maxsize int64 var logfile_rotate int - var max_rpc_conns int - var max_peers int - var peer_conn_tmout time.Duration var i int var err error logmask = hodu.LOG_ALL + + config = &hodu.ClientConfig{ + CtlAddrs: ctl_addrs, + } + if cfg != nil { - ctltlscfg, err = make_tls_server_config(&cfg.CTL.TLS) + config.CtlTls, err = make_tls_server_config(&cfg.CTL.TLS) if err != nil { return err } - rpctlscfg, err = make_tls_client_config(&cfg.RPC.TLS) + config.RpcTls, err = make_tls_client_config(&cfg.RPC.TLS) if err != nil { return err } - if len(ctl_addrs) <= 0 { ctl_addrs = cfg.CTL.Service.Addrs } if len(rpc_addrs) <= 0 { rpc_addrs = cfg.RPC.Endpoint.Addrs } - ctl_prefix = cfg.CTL.Service.Prefix + if len(config.CtlAddrs) <= 0 { config.CtlAddrs = cfg.CTL.Service.Addrs } - ctl_auth, err = make_http_auth_config(&cfg.CTL.Service.Auth) + config.CtlPrefix = cfg.CTL.Service.Prefix + config.CtlCors = cfg.CTL.Service.Cors + config.CtlAuth, err = make_http_auth_config(&cfg.CTL.Service.Auth) if err != nil { return err } cc.ServerSeedTmout = cfg.RPC.Endpoint.SeedTmout @@ -281,9 +280,9 @@ func client_main(ctl_addrs []string, rpc_addrs []string, route_configs []string, if logfile == "" { logfile = cfg.APP.LogFile } logfile_maxsize = cfg.APP.LogMaxSize logfile_rotate = cfg.APP.LogRotate - max_rpc_conns = cfg.APP.MaxRpcConns - max_peers = cfg.APP.MaxPeers - peer_conn_tmout = cfg.APP.PeerConnTmout + config.RpcConnMax = cfg.APP.MaxRpcConns + config.PeerConnMax = cfg.APP.MaxPeers + config.PeerConnTmout = cfg.APP.PeerConnTmout } // unlke the server, we allow the client to start with no rpc address. @@ -305,18 +304,7 @@ func client_main(ctl_addrs []string, rpc_addrs []string, route_configs []string, return fmt.Errorf("failed to initialize logger - %s", err.Error()) } } - c = hodu.NewClient( - context.Background(), - HODU_NAME, - logger, - ctl_addrs, - ctl_prefix, - ctltlscfg, - ctl_auth, - rpctlscfg, - max_rpc_conns, - max_peers, - peer_conn_tmout) + c = hodu.NewClient(context.Background(), HODU_NAME, logger, config) c.StartService(&cc) c.StartCtlService() // control channel diff --git a/hodu.go b/hodu.go index 7c4f008..5b29cb3 100644 --- a/hodu.go +++ b/hodu.go @@ -1,6 +1,7 @@ package hodu import "crypto/rsa" +import "encoding/base64" import "fmt" import "net" import "net/http" @@ -403,8 +404,21 @@ func (auth *HttpAuthConfig) Authenticate(req *http.Request) (int, string) { } } + // this application wants these two header values to be base64-encoded username = req.Header.Get("X-Auth-Username") password = req.Header.Get("X-Auth-Password") + if username != "" { + var tmp []byte + tmp, err = base64.StdEncoding.DecodeString(username) + if err != nil { return http.StatusBadRequest, "" } + username = string(tmp) + } + if password != "" { + var tmp []byte + tmp, err = base64.StdEncoding.DecodeString(password) + if err != nil { return http.StatusBadRequest, "" } + password = string(tmp) + } // fall back to basic authentication if username == "" && password == "" && auth.Realm != "" { diff --git a/server-ctl.go b/server-ctl.go index 21c43a7..b32957c 100644 --- a/server-ctl.go +++ b/server-ctl.go @@ -78,6 +78,10 @@ func (ctl *server_ctl) Id() string { return ctl.id } +func (ctl *server_ctl) Cors(req *http.Request) bool { + return ctl.s.cfg.CtlCors +} + func (ctl *server_ctl) Authenticate(req *http.Request) (int, string) { if ctl.s.cfg.CtlAuth == nil { return http.StatusOK, "" } return ctl.s.cfg.CtlAuth.Authenticate(req) diff --git a/server-proxy.go b/server-proxy.go index e0e97e7..f10e02a 100644 --- a/server-proxy.go +++ b/server-proxy.go @@ -190,6 +190,10 @@ func (pxy *server_proxy) Id() string { return pxy.id } +func (pxy *server_proxy) Cors(req *http.Request) bool { + return false +} + func (pxy *server_proxy) Authenticate(req *http.Request) (int, string) { return http.StatusOK, "" } diff --git a/server.go b/server.go index 3c94038..4903b89 100644 --- a/server.go +++ b/server.go @@ -52,6 +52,7 @@ type ServerConfig struct { CtlTls *tls.Config CtlPrefix string CtlAuth *HttpAuthConfig + CtlCors bool PxyAddrs []string PxyTls *tls.Config @@ -942,6 +943,7 @@ func (hlw *server_http_log_writer) Write(p []byte) (n int, err error) { type ServerHttpHandler interface { Id() string + Cors(req *http.Request) bool Authenticate(req *http.Request) (int, string) ServeHTTP (w http.ResponseWriter, req *http.Request) (int, error) } @@ -965,16 +967,25 @@ func (s *Server) wrap_http_handler(handler ServerHttpHandler) http.Handler { }() start_time = time.Now() - status_code, realm = handler.Authenticate(req) - if status_code == http.StatusUnauthorized { - if realm != "" { - w.Header().Set("WWW-Authenticate", fmt.Sprintf("Basic Realm=\"%s\"", realm)) - } - WriteEmptyRespHeader(w, status_code) - } else if status_code == http.StatusOK { - status_code, err = handler.ServeHTTP(w, req) + + if handler.Cors(req) { + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Headers", "*") + } + if req.Method == http.MethodOptions { + status_code = WriteEmptyRespHeader(w, http.StatusOK) } else { - WriteEmptyRespHeader(w, status_code) + status_code, realm = handler.Authenticate(req) + if status_code == http.StatusUnauthorized { + if realm != "" { + w.Header().Set("WWW-Authenticate", fmt.Sprintf("Basic Realm=\"%s\"", realm)) + } + WriteEmptyRespHeader(w, status_code) + } else if status_code == http.StatusOK { + status_code, err = handler.ServeHTTP(w, req) + } else { + WriteEmptyRespHeader(w, status_code) + } } time_taken = time.Now().Sub(start_time)