updated to support cors in primitive manner

This commit is contained in:
hyung-hwan 2025-02-10 14:48:18 +09:00
parent ec51c101ec
commit 3dc5d9c91e
8 changed files with 122 additions and 72 deletions

View File

@ -135,6 +135,10 @@ func (ctl *client_ctl) Id() string {
return ctl.id
}
func (ctl *client_ctl) Cors(req *http.Request) bool {
return ctl.c.ctl_cors
}
func (ctl *client_ctl) Authenticate(req *http.Request) (int, string) {
if ctl.c.ctl_auth == nil { return http.StatusOK, "" }
return ctl.c.ctl_auth.Authenticate(req)
@ -257,7 +261,7 @@ func (ctl *client_ctl_client_conns) ServeHTTP(w http.ResponseWriter, req *http.R
// after hacing connected to the server. therefore, the json_in_client_conn
// type contains a server address field only.
var s json_in_client_conn
var cc ClientConfig
var cc ClientConnConfig
var cts *ClientConn
err = json.NewDecoder(req.Body).Decode(&s)

View File

@ -39,16 +39,29 @@ type ClientRouteConfig struct {
Static bool
}
type ClientConfig struct {
type ClientConnConfig struct {
ServerAddrs []string
Routes []ClientRouteConfig
ServerSeedTmout time.Duration
ServerAuthority string // http2 :authority header
}
type ClientConfigActive struct {
type ClientConnConfigActive struct {
Index int
ClientConfig
ClientConnConfig
}
type ClientConfig struct {
CtlAddrs []string
CtlTls *tls.Config
CtlPrefix string
CtlAuth *HttpAuthConfig
CtlCors bool
RpcTls *tls.Config
RpcConnMax int
PeerConnMax int
PeerConnTmout time.Duration
}
type Client struct {
@ -64,6 +77,7 @@ type Client struct {
ctl_tls *tls.Config
ctl_addr []string
ctl_prefix string
ctl_cors bool
ctl_auth *HttpAuthConfig
ctl_mux *http.ServeMux
ctl []*http.Server // control server
@ -102,7 +116,7 @@ const (
// client connection to server
type ClientConn struct {
cli *Client
cfg ClientConfigActive
cfg ClientConnConfigActive
Id ConnId
Sid string // id rendered in string
State ClientConnState
@ -660,14 +674,14 @@ func (r *ClientRoute) ReportEvent(pts_id PeerId, event_type PACKET_KIND, event_d
}
// --------------------------------------------------------------------
func NewClientConn(c *Client, cfg *ClientConfig) *ClientConn {
func NewClientConn(c *Client, cfg *ClientConnConfig) *ClientConn {
var cts ClientConn
var i int
cts.cli = c
cts.route_map = make(ClientRouteMap)
cts.route_next_id = 1
cts.cfg.ClientConfig = *cfg
cts.cfg.ClientConnConfig = *cfg
cts.stop_req.Store(false)
cts.stop_chan = make(chan bool, 8)
@ -1225,6 +1239,7 @@ func (hlw *client_ctl_log_writer) Write(p []byte) (n int, err error) {
type ClientHttpHandler interface {
Id() string
Cors(req *http.Request) bool
Authenticate(req *http.Request) (int, string)
ServeHTTP (w http.ResponseWriter, req *http.Request) (int, error)
}
@ -1249,6 +1264,13 @@ func (c *Client) wrap_http_handler(handler ClientHttpHandler) http.Handler {
start_time = time.Now()
if handler.Cors(req) {
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Headers", "*")
}
if req.Method == http.MethodOptions {
status_code = WriteEmptyRespHeader(w, http.StatusOK)
} else {
status_code, realm = handler.Authenticate(req)
if status_code == http.StatusUnauthorized {
if realm != "" {
@ -1260,6 +1282,7 @@ func (c *Client) wrap_http_handler(handler ClientHttpHandler) http.Handler {
} else {
WriteEmptyRespHeader(w, status_code)
}
}
// TODO: statistics by status_code and end point types.
time_taken = time.Now().Sub(start_time)
@ -1274,7 +1297,7 @@ func (c *Client) wrap_http_handler(handler ClientHttpHandler) http.Handler {
})
}
func NewClient(ctx context.Context, name string, logger Logger, ctl_addrs []string, ctl_prefix string, ctl_tls *tls.Config, ctl_auth *HttpAuthConfig, rpc_tls *tls.Config, rpc_max int, peer_max int, peer_conn_tmout time.Duration) *Client {
func NewClient(ctx context.Context, name string, logger Logger, cfg *ClientConfig) *Client {
var c Client
var i int
var hs_log *log.Logger
@ -1282,19 +1305,20 @@ func NewClient(ctx context.Context, name string, logger Logger, ctl_addrs []stri
c.name = name
c.ctx, c.ctx_cancel = context.WithCancel(ctx)
c.ext_svcs = make([]Service, 0, 1)
c.ptc_tmout = peer_conn_tmout
c.ptc_limit = peer_max
c.cts_limit = rpc_max
c.ptc_tmout = cfg.PeerConnTmout
c.ptc_limit = cfg.PeerConnMax
c.cts_limit = cfg.RpcConnMax
c.cts_next_id = 1
c.cts_map = make(ClientConnMap)
c.stop_req.Store(false)
c.stop_chan = make(chan bool, 8)
c.log = logger
c.rpc_tls = rpc_tls
c.ctl_auth = ctl_auth
c.ctl_tls = ctl_tls
c.ctl_prefix = ctl_prefix
c.rpc_tls = cfg.RpcTls
c.ctl_auth = cfg.CtlAuth
c.ctl_tls = cfg.CtlTls
c.ctl_prefix = cfg.CtlPrefix
c.ctl_cors = cfg.CtlCors
c.ctl_mux = http.NewServeMux()
c.ctl_mux.Handle(c.ctl_prefix + "/_ctl/client-conns",
c.wrap_http_handler(&client_ctl_client_conns{client_ctl{c: &c, id: HS_ID_CTL}}))
@ -1323,15 +1347,15 @@ func NewClient(ctx context.Context, name string, logger Logger, ctl_addrs []stri
promhttp.HandlerFor(c.promreg, promhttp.HandlerOpts{ EnableOpenMetrics: true }))
c.ctl_addr = make([]string, len(ctl_addrs))
c.ctl = make([]*http.Server, len(ctl_addrs))
copy(c.ctl_addr, ctl_addrs)
c.ctl_addr = make([]string, len(cfg.CtlAddrs))
c.ctl = make([]*http.Server, len(cfg.CtlAddrs))
copy(c.ctl_addr, cfg.CtlAddrs)
hs_log = log.New(&client_ctl_log_writer{cli: &c}, "", 0)
for i = 0; i < len(ctl_addrs); i++ {
for i = 0; i < len(cfg.CtlAddrs); i++ {
c.ctl[i] = &http.Server{
Addr: ctl_addrs[i],
Addr: cfg.CtlAddrs[i],
Handler: c.ctl_mux,
TLSConfig: c.ctl_tls,
ErrorLog: hs_log,
@ -1346,7 +1370,7 @@ func NewClient(ctx context.Context, name string, logger Logger, ctl_addrs []stri
return &c
}
func (c *Client) AddNewClientConn(cfg *ClientConfig) (*ClientConn, error) {
func (c *Client) AddNewClientConn(cfg *ClientConnConfig) (*ClientConn, error) {
var cts *ClientConn
var ok bool
var start_id ConnId
@ -1622,7 +1646,7 @@ func (c *Client) RunTask(wg *sync.WaitGroup) {
// so no call to wg.Done()
}
func (c *Client) start_service(cfg *ClientConfig) (*ClientConn, error) {
func (c *Client) start_service(cfg *ClientConnConfig) (*ClientConn, error) {
var cts *ClientConn
var err error
@ -1639,10 +1663,10 @@ func (c *Client) start_service(cfg *ClientConfig) (*ClientConn, error) {
}
func (c *Client) StartService(data interface{}) {
var cfg *ClientConfig
var cfg *ClientConnConfig
var ok bool
cfg, ok = data.(*ClientConfig)
cfg, ok = data.(*ClientConnConfig)
if !ok {
c.log.Write("", LOG_ERROR, "Failed to start service - invalid configuration - %v", data)
} else {

View File

@ -66,6 +66,7 @@ type HttpAuthConfig struct {
type CTLServiceConfig struct {
Prefix string `yaml:"prefix"` // url prefix for control channel endpoints
Addrs []string `yaml:"addresses"`
Cors bool `yaml:"cors"`
Auth HttpAuthConfig `yaml:"auth"`
}

View File

@ -1,7 +1,6 @@
package main
import "context"
import "crypto/tls"
import _ "embed"
import "flag"
import "fmt"
@ -13,7 +12,6 @@ import "os/signal"
import "strings"
import "sync"
import "syscall"
import "time"
// Don't change these items to 'const' as they can be overridden externally with a linker option
var HODU_NAME string = "hodu"
@ -126,6 +124,7 @@ func server_main(ctl_addrs []string, rpc_addrs []string, pxy_addrs []string, wpx
if len(config.PxyAddrs) <= 0 { config.PxyAddrs = cfg.PXY.Service.Addrs }
if len(config.WpxAddrs) <= 0 { config.WpxAddrs = cfg.WPX.Service.Addrs }
config.CtlCors = cfg.CTL.Service.Cors
config.CtlAuth, err = make_http_auth_config(&cfg.CTL.Service.Auth)
if err != nil { return err }
@ -246,33 +245,33 @@ func parse_client_route_config(v string) (*hodu.ClientRouteConfig, error) {
func client_main(ctl_addrs []string, rpc_addrs []string, route_configs []string, logfile string, cfg *ClientConfig) error {
var c *hodu.Client
var rpctlscfg *tls.Config
var ctltlscfg *tls.Config
var ctl_auth *hodu.HttpAuthConfig
var ctl_prefix string
var cc hodu.ClientConfig
var config *hodu.ClientConfig
var cc hodu.ClientConnConfig
var logger *AppLogger
var logmask hodu.LogMask
var logfile_maxsize int64
var logfile_rotate int
var max_rpc_conns int
var max_peers int
var peer_conn_tmout time.Duration
var i int
var err error
logmask = hodu.LOG_ALL
config = &hodu.ClientConfig{
CtlAddrs: ctl_addrs,
}
if cfg != nil {
ctltlscfg, err = make_tls_server_config(&cfg.CTL.TLS)
config.CtlTls, err = make_tls_server_config(&cfg.CTL.TLS)
if err != nil { return err }
rpctlscfg, err = make_tls_client_config(&cfg.RPC.TLS)
config.RpcTls, err = make_tls_client_config(&cfg.RPC.TLS)
if err != nil { return err }
if len(ctl_addrs) <= 0 { ctl_addrs = cfg.CTL.Service.Addrs }
if len(rpc_addrs) <= 0 { rpc_addrs = cfg.RPC.Endpoint.Addrs }
ctl_prefix = cfg.CTL.Service.Prefix
if len(config.CtlAddrs) <= 0 { config.CtlAddrs = cfg.CTL.Service.Addrs }
ctl_auth, err = make_http_auth_config(&cfg.CTL.Service.Auth)
config.CtlPrefix = cfg.CTL.Service.Prefix
config.CtlCors = cfg.CTL.Service.Cors
config.CtlAuth, err = make_http_auth_config(&cfg.CTL.Service.Auth)
if err != nil { return err }
cc.ServerSeedTmout = cfg.RPC.Endpoint.SeedTmout
@ -281,9 +280,9 @@ func client_main(ctl_addrs []string, rpc_addrs []string, route_configs []string,
if logfile == "" { logfile = cfg.APP.LogFile }
logfile_maxsize = cfg.APP.LogMaxSize
logfile_rotate = cfg.APP.LogRotate
max_rpc_conns = cfg.APP.MaxRpcConns
max_peers = cfg.APP.MaxPeers
peer_conn_tmout = cfg.APP.PeerConnTmout
config.RpcConnMax = cfg.APP.MaxRpcConns
config.PeerConnMax = cfg.APP.MaxPeers
config.PeerConnTmout = cfg.APP.PeerConnTmout
}
// unlke the server, we allow the client to start with no rpc address.
@ -305,18 +304,7 @@ func client_main(ctl_addrs []string, rpc_addrs []string, route_configs []string,
return fmt.Errorf("failed to initialize logger - %s", err.Error())
}
}
c = hodu.NewClient(
context.Background(),
HODU_NAME,
logger,
ctl_addrs,
ctl_prefix,
ctltlscfg,
ctl_auth,
rpctlscfg,
max_rpc_conns,
max_peers,
peer_conn_tmout)
c = hodu.NewClient(context.Background(), HODU_NAME, logger, config)
c.StartService(&cc)
c.StartCtlService() // control channel

14
hodu.go
View File

@ -1,6 +1,7 @@
package hodu
import "crypto/rsa"
import "encoding/base64"
import "fmt"
import "net"
import "net/http"
@ -403,8 +404,21 @@ func (auth *HttpAuthConfig) Authenticate(req *http.Request) (int, string) {
}
}
// this application wants these two header values to be base64-encoded
username = req.Header.Get("X-Auth-Username")
password = req.Header.Get("X-Auth-Password")
if username != "" {
var tmp []byte
tmp, err = base64.StdEncoding.DecodeString(username)
if err != nil { return http.StatusBadRequest, "" }
username = string(tmp)
}
if password != "" {
var tmp []byte
tmp, err = base64.StdEncoding.DecodeString(password)
if err != nil { return http.StatusBadRequest, "" }
password = string(tmp)
}
// fall back to basic authentication
if username == "" && password == "" && auth.Realm != "" {

View File

@ -78,6 +78,10 @@ func (ctl *server_ctl) Id() string {
return ctl.id
}
func (ctl *server_ctl) Cors(req *http.Request) bool {
return ctl.s.cfg.CtlCors
}
func (ctl *server_ctl) Authenticate(req *http.Request) (int, string) {
if ctl.s.cfg.CtlAuth == nil { return http.StatusOK, "" }
return ctl.s.cfg.CtlAuth.Authenticate(req)

View File

@ -190,6 +190,10 @@ func (pxy *server_proxy) Id() string {
return pxy.id
}
func (pxy *server_proxy) Cors(req *http.Request) bool {
return false
}
func (pxy *server_proxy) Authenticate(req *http.Request) (int, string) {
return http.StatusOK, ""
}

View File

@ -52,6 +52,7 @@ type ServerConfig struct {
CtlTls *tls.Config
CtlPrefix string
CtlAuth *HttpAuthConfig
CtlCors bool
PxyAddrs []string
PxyTls *tls.Config
@ -942,6 +943,7 @@ func (hlw *server_http_log_writer) Write(p []byte) (n int, err error) {
type ServerHttpHandler interface {
Id() string
Cors(req *http.Request) bool
Authenticate(req *http.Request) (int, string)
ServeHTTP (w http.ResponseWriter, req *http.Request) (int, error)
}
@ -965,6 +967,14 @@ func (s *Server) wrap_http_handler(handler ServerHttpHandler) http.Handler {
}()
start_time = time.Now()
if handler.Cors(req) {
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Headers", "*")
}
if req.Method == http.MethodOptions {
status_code = WriteEmptyRespHeader(w, http.StatusOK)
} else {
status_code, realm = handler.Authenticate(req)
if status_code == http.StatusUnauthorized {
if realm != "" {
@ -976,6 +986,7 @@ func (s *Server) wrap_http_handler(handler ServerHttpHandler) http.Handler {
} else {
WriteEmptyRespHeader(w, status_code)
}
}
time_taken = time.Now().Sub(start_time)
if status_code > 0 {