diff --git a/cmd/config.go b/cmd/config.go index e03567c..1bc4573 100644 --- a/cmd/config.go +++ b/cmd/config.go @@ -44,16 +44,17 @@ type ClientTLSConfig struct { ServerName string `yaml:"server-name"` } -type BasicAuthConfig struct { +type AuthConfig struct { Enabled bool `yaml:"enabled"` Realm string `yaml:"realm"` Creds []string `yaml:"credentials"` + TokenTtl string `yaml:"token-ttl"` } type CTLServiceConfig struct { Prefix string `yaml:"prefix"` // url prefix for control channel endpoints Addrs []string `yaml:"addresses"` - BasicAuth BasicAuthConfig `yaml:"basic-auth"` + Auth AuthConfig `yaml:"auth"` } type PXYServiceConfig struct { @@ -345,8 +346,8 @@ func make_tls_client_config(cfg *ClientTLSConfig) (*tls.Config, error) { } // -------------------------------------------------------------------- -func make_server_basic_auth_config(cfg *BasicAuthConfig) (*hodu.ServerBasicAuth, error) { - var config hodu.ServerBasicAuth +func make_server_basic_auth_config(cfg *AuthConfig) (*hodu.ServerAuthConfig, error) { + var config hodu.ServerAuthConfig var cred string var b []byte var x []string @@ -354,7 +355,11 @@ func make_server_basic_auth_config(cfg *BasicAuthConfig) (*hodu.ServerBasicAuth, config.Enabled = cfg.Enabled config.Realm = cfg.Realm - config.Creds = make(hodu.ServerBasicAuthCredMap) + config.Creds = make(hodu.ServerAuthCredMap) + config.TokenTtl, err = hodu.ParseDurationString(cfg.TokenTtl) + if err != nil { + return nil, fmt.Errorf("invalid token ttl %s - %s", cred, err) + } for _, cred = range cfg.Creds { b, err = base64.StdEncoding.DecodeString(cred) diff --git a/cmd/main.go b/cmd/main.go index 96a8988..34634f3 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -125,7 +125,7 @@ func server_main(ctl_addrs []string, rpc_addrs []string, pxy_addrs []string, wpx if len(config.PxyAddrs) <= 0 { config.PxyAddrs = cfg.PXY.Service.Addrs } if len(config.WpxAddrs) <= 0 { config.WpxAddrs = cfg.WPX.Service.Addrs } - config.CtlBasicAuth, err = make_server_basic_auth_config(&cfg.CTL.Service.BasicAuth) + config.CtlAuth, err = make_server_basic_auth_config(&cfg.CTL.Service.Auth) if err != nil { return err } config.CtlPrefix = cfg.CTL.Service.Prefix diff --git a/server-ctl.go b/server-ctl.go index b33c630..f49e417 100644 --- a/server-ctl.go +++ b/server-ctl.go @@ -72,7 +72,7 @@ func (ctl *server_ctl) Authenticate(req *http.Request) string { var s *Server s = ctl.s - if s.cfg.CtlBasicAuth != nil && s.cfg.CtlBasicAuth.Enabled { + if s.cfg.CtlAuth != nil && s.cfg.CtlAuth.Enabled { var auth_hdr string var auth_parts []string var username string @@ -82,7 +82,7 @@ func (ctl *server_ctl) Authenticate(req *http.Request) string { var err error auth_hdr = req.Header.Get("Authorization") - if auth_hdr == "" { return s.cfg.CtlBasicAuth.Realm } + if auth_hdr == "" { return s.cfg.CtlAuth.Realm } auth_parts = strings.Fields(auth_hdr) if len(auth_parts) == 2 && strings.EqualFold(auth_parts[0], "Bearer") { @@ -93,10 +93,10 @@ func (ctl *server_ctl) Authenticate(req *http.Request) string { // fall back to basic authentication username, password, ok = req.BasicAuth() - if !ok { return s.cfg.CtlBasicAuth.Realm } + if !ok { return s.cfg.CtlAuth.Realm } - credpass, ok = s.cfg.CtlBasicAuth.Creds[username] - if !ok || credpass != password { return s.cfg.CtlBasicAuth.Realm } + credpass, ok = s.cfg.CtlAuth.Creds[username] + if !ok || credpass != password { return s.cfg.CtlAuth.Realm } } return "" @@ -105,6 +105,7 @@ func (ctl *server_ctl) Authenticate(req *http.Request) string { // ------------------------------------ type ServerTokenClaim struct { + ExpiresAt int64 `json:"exp"` IssuedAt int64 `json:"iat"` } @@ -114,10 +115,12 @@ type json_out_token struct { } func (ctl *server_ctl_token) ServeHTTP(w http.ResponseWriter, req *http.Request) (int, error) { + var s *Server var status_code int var je *json.Encoder var err error + s = ctl.s je = json.NewEncoder(w) switch req.Method { @@ -125,8 +128,16 @@ func (ctl *server_ctl_token) ServeHTTP(w http.ResponseWriter, req *http.Request) var jwt JWT var jc ServerTokenClaim var tok string + var now time.Time - jc.IssuedAt = time.Now().Unix() + if s.cfg.CtlAuth == nil || !s.cfg.CtlAuth.Enabled { + status_code = WriteEmptyRespHeader(w, http.StatusNotFound) + goto done + } + + now = time.Now() + jc.IssuedAt = now.Unix() + jc.ExpiresAt = now.Add(s.cfg.CtlAuth.TokenTtl).Unix() tok, err = jwt.Sign(&jc) if err != nil { status_code = WriteJsonRespHeader(w, http.StatusInternalServerError) @@ -139,10 +150,10 @@ func (ctl *server_ctl_token) ServeHTTP(w http.ResponseWriter, req *http.Request) if err != nil { goto oops } default: - status_code = WriteEmptyRespHeader(w, http.StatusBadRequest) + status_code = WriteEmptyRespHeader(w, http.StatusMethodNotAllowed) } -//done: +done: return status_code, nil oops: @@ -201,7 +212,7 @@ func (ctl *server_ctl_server_conns) ServeHTTP(w http.ResponseWriter, req *http.R status_code = WriteEmptyRespHeader(w, http.StatusNoContent) default: - status_code = WriteEmptyRespHeader(w, http.StatusBadRequest) + status_code = WriteEmptyRespHeader(w, http.StatusMethodNotAllowed) } //done: @@ -267,7 +278,7 @@ func (ctl *server_ctl_server_conns_id) ServeHTTP(w http.ResponseWriter, req *htt status_code = WriteEmptyRespHeader(w, http.StatusNoContent) default: - status_code = WriteEmptyRespHeader(w, http.StatusBadRequest) + status_code = WriteEmptyRespHeader(w, http.StatusMethodNotAllowed) } //done: @@ -326,7 +337,7 @@ func (ctl *server_ctl_server_conns_id_routes) ServeHTTP(w http.ResponseWriter, r status_code = WriteEmptyRespHeader(w, http.StatusNoContent) default: - status_code = WriteEmptyRespHeader(w, http.StatusBadRequest) + status_code = WriteEmptyRespHeader(w, http.StatusMethodNotAllowed) } //done: @@ -412,7 +423,7 @@ func (ctl *server_ctl_server_conns_id_routes_id) ServeHTTP(w http.ResponseWriter status_code = WriteEmptyRespHeader(w, http.StatusNoContent) default: - status_code = WriteEmptyRespHeader(w, http.StatusBadRequest) + status_code = WriteEmptyRespHeader(w, http.StatusMethodNotAllowed) } done: @@ -445,7 +456,7 @@ func (ctl *server_ctl_stats) ServeHTTP(w http.ResponseWriter, req *http.Request) if err = je.Encode(stats); err != nil { goto oops } default: - status_code = WriteEmptyRespHeader(w, http.StatusBadRequest) + status_code = WriteEmptyRespHeader(w, http.StatusMethodNotAllowed) } //done: diff --git a/server.go b/server.go index 0736ab3..930a708 100644 --- a/server.go +++ b/server.go @@ -42,12 +42,13 @@ type ServerSvcPortMap = map[PortId]ConnRouteId type ServerWpxResponseTransformer func(r *ServerRouteProxyInfo, resp *http.Response) io.Reader type ServerWpxForeignPortProxyMaker func(wpx_type string, port_id string) (*ServerRouteProxyInfo, error) -type ServerBasicAuthCredMap map[string]string +type ServerAuthCredMap map[string]string -type ServerBasicAuth struct { +type ServerAuthConfig struct { Enabled bool Realm string - Creds ServerBasicAuthCredMap + Creds ServerAuthCredMap + TokenTtl time.Duration } type ServerConfig struct { @@ -59,7 +60,7 @@ type ServerConfig struct { CtlAddrs []string CtlTls *tls.Config CtlPrefix string - CtlBasicAuth *ServerBasicAuth + CtlAuth *ServerAuthConfig PxyAddrs []string PxyTls *tls.Config