added some code for control channel authentication

This commit is contained in:
hyung-hwan 2025-01-28 23:50:28 +09:00
parent a97be385ec
commit 2fa5817e88
7 changed files with 163 additions and 77 deletions

View File

@ -23,10 +23,6 @@ import "unsafe"
* GET get info
*/
type JsonErrmsg struct {
Text string `json:"error-text"`
}
type json_in_client_conn struct {
ServerAddrs []string `json:"server-addrs"`
}
@ -215,7 +211,8 @@ func (ctl *client_ctl_client_conns) ServeHTTP(w http.ResponseWriter, req *http.R
cts, err = c.start_service(&cc) // TODO: this can be blocking. do we have to resolve addresses before calling this? also not good because resolution succeed or fail at each attempt. however ok as ServeHTTP itself is in a goroutine?
if err != nil {
status_code = WriteJsonRespHeader(w, http.StatusInternalServerError)
if err = je.Encode(JsonErrmsg{Text: err.Error()}); err != nil { goto oops }
je.Encode(JsonErrmsg{Text: err.Error()})
goto oops
} else {
status_code = WriteJsonRespHeader(w, http.StatusCreated)
if err = je.Encode(json_out_client_conn_id{Id: cts.Id}); err != nil { goto oops }
@ -260,15 +257,15 @@ func (ctl *client_ctl_client_conns_id) ServeHTTP(w http.ResponseWriter, req *htt
conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(ConnId(0)) * 8))
if err != nil {
status_code = WriteJsonRespHeader(w, http.StatusBadRequest)
if err = je.Encode(JsonErrmsg{Text: "wrong connection id - " + conn_id}); err != nil { goto oops }
goto done
je.Encode(JsonErrmsg{Text: "wrong connection id - " + conn_id})
goto oops
}
cts = c.FindClientConnById(ConnId(conn_nid))
if cts == nil {
status_code = WriteJsonRespHeader(w, http.StatusNotFound)
if err = je.Encode(JsonErrmsg{Text: "non-existent connection id - " + conn_id}); err != nil { goto oops }
goto done
je.Encode(JsonErrmsg{Text: "non-existent connection id - " + conn_id})
goto oops
}
switch req.Method {
@ -313,7 +310,7 @@ func (ctl *client_ctl_client_conns_id) ServeHTTP(w http.ResponseWriter, req *htt
status_code = WriteEmptyRespHeader(w, http.StatusBadRequest)
}
done:
//done:
return status_code, nil
oops:
@ -339,15 +336,15 @@ func (ctl *client_ctl_client_conns_id_routes) ServeHTTP(w http.ResponseWriter, r
conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(ConnId(0)) * 8))
if err != nil {
status_code = WriteJsonRespHeader(w, http.StatusBadRequest)
if err = je.Encode(JsonErrmsg{Text: "wrong connection id - " + conn_id }); err != nil { goto oops }
goto done
je.Encode(JsonErrmsg{Text: "wrong connection id - " + conn_id })
goto oops
}
cts = c.FindClientConnById(ConnId(conn_nid))
if cts == nil {
status_code = WriteJsonRespHeader(w, http.StatusNotFound)
if err = je.Encode(JsonErrmsg{Text: "non-existent connection id - " + conn_id}); err != nil { goto oops }
goto done
je.Encode(JsonErrmsg{Text: "non-existent connection id - " + conn_id})
goto oops
}
switch req.Method {
@ -421,7 +418,8 @@ func (ctl *client_ctl_client_conns_id_routes) ServeHTTP(w http.ResponseWriter, r
r, err = cts.AddNewClientRoute(rc)
if err != nil {
status_code = WriteJsonRespHeader(w, http.StatusInternalServerError)
if err = je.Encode(JsonErrmsg{Text: err.Error()}); err != nil { goto oops }
je.Encode(JsonErrmsg{Text: err.Error()})
goto oops
} else {
status_code = WriteJsonRespHeader(w, http.StatusCreated)
if err = je.Encode(json_out_client_route_id{Id: r.Id, CtsId: r.cts.Id}); err != nil { goto oops }
@ -436,7 +434,7 @@ func (ctl *client_ctl_client_conns_id_routes) ServeHTTP(w http.ResponseWriter, r
status_code = WriteEmptyRespHeader(w, http.StatusBadRequest)
}
done:
//done:
return status_code, nil
oops:
@ -466,28 +464,28 @@ func (ctl *client_ctl_client_conns_id_routes_id) ServeHTTP(w http.ResponseWriter
conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(ConnId(0)) * 8))
if err != nil {
status_code = WriteJsonRespHeader(w, http.StatusBadRequest)
if err = je.Encode(JsonErrmsg{Text: "wrong connection id - " + conn_id}); err != nil { goto oops }
goto done
je.Encode(JsonErrmsg{Text: "wrong connection id - " + conn_id})
goto oops
}
route_nid, err = strconv.ParseUint(route_id, 10, int(unsafe.Sizeof(RouteId(0)) * 8))
if err != nil {
status_code = WriteJsonRespHeader(w, http.StatusBadRequest)
if err = je.Encode(JsonErrmsg{Text: "wrong route id - " + route_id}); err != nil { goto oops }
goto done
je.Encode(JsonErrmsg{Text: "wrong route id - " + route_id})
goto oops
}
cts = c.FindClientConnById(ConnId(conn_nid))
if cts == nil {
status_code = WriteJsonRespHeader(w, http.StatusNotFound)
if err = je.Encode(JsonErrmsg{Text: "non-existent connection id - " + conn_id}); err != nil { goto oops }
goto done
je.Encode(JsonErrmsg{Text: "non-existent connection id - " + conn_id})
goto oops
}
r = cts.FindClientRouteById(RouteId(route_nid))
if r == nil {
status_code = WriteJsonRespHeader(w, http.StatusNotFound)
if err = je.Encode(JsonErrmsg{Text: "non-existent route id - " + route_id}); err != nil { goto oops }
goto done
je.Encode(JsonErrmsg{Text: "non-existent route id - " + route_id})
goto oops
}
switch req.Method {
@ -543,7 +541,7 @@ func (ctl *client_ctl_client_conns_id_routes_id) ServeHTTP(w http.ResponseWriter
status_code = WriteEmptyRespHeader(w, http.StatusBadRequest)
}
done:
//done:
return status_code, nil
oops:
@ -573,28 +571,28 @@ func (ctl *client_ctl_client_conns_id_routes_spsp) ServeHTTP(w http.ResponseWrit
conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(ConnId(0)) * 8))
if err != nil {
status_code = WriteJsonRespHeader(w, http.StatusBadRequest)
if err = je.Encode(JsonErrmsg{Text: "wrong connection id - " + conn_id}); err != nil { goto oops }
goto done
je.Encode(JsonErrmsg{Text: "wrong connection id - " + conn_id})
goto oops
}
port_nid, err = strconv.ParseUint(port_id, 10, int(unsafe.Sizeof(PortId(0)) * 8))
if err != nil {
status_code = WriteJsonRespHeader(w, http.StatusBadRequest)
if err = je.Encode(JsonErrmsg{Text: "wrong route id - " + port_id}); err != nil { goto oops }
goto done
je.Encode(JsonErrmsg{Text: "wrong route id - " + port_id})
goto oops
}
cts = c.FindClientConnById(ConnId(conn_nid))
if cts == nil {
status_code = WriteJsonRespHeader(w, http.StatusNotFound)
if err = je.Encode(JsonErrmsg{Text: "non-existent connection id - " + conn_id}); err != nil { goto oops }
goto done
je.Encode(JsonErrmsg{Text: "non-existent connection id - " + conn_id})
goto oops
}
r = cts.FindClientRouteByServerPeerSvcPortId(PortId(port_nid))
if r == nil {
status_code = WriteJsonRespHeader(w, http.StatusNotFound)
if err = je.Encode(JsonErrmsg{Text: "non-existent server peer port id - " + port_id}); err != nil { goto oops }
goto done
je.Encode(JsonErrmsg{Text: "non-existent server peer port id - " + port_id})
goto oops
}
switch req.Method {
@ -648,7 +646,7 @@ func (ctl *client_ctl_client_conns_id_routes_spsp) ServeHTTP(w http.ResponseWrit
status_code = WriteEmptyRespHeader(w, http.StatusBadRequest)
}
done:
//done:
return status_code, nil
oops:
@ -677,21 +675,21 @@ func (ctl *client_ctl_client_conns_id_routes_id_peers) ServeHTTP(w http.Response
conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(ConnId(0)) * 8))
if err != nil {
status_code = WriteJsonRespHeader(w, http.StatusBadRequest)
if err = je.Encode(JsonErrmsg{Text: "wrong connection id - " + conn_id}); err != nil { goto oops }
goto done
je.Encode(JsonErrmsg{Text: "wrong connection id - " + conn_id})
goto oops
}
route_nid, err = strconv.ParseUint(route_id, 10, int(unsafe.Sizeof(RouteId(0)) * 8))
if err != nil {
status_code = WriteJsonRespHeader(w, http.StatusBadRequest)
if err = je.Encode(JsonErrmsg{Text: "wrong route id - " + route_id}); err != nil { goto oops }
goto done
je.Encode(JsonErrmsg{Text: "wrong route id - " + route_id})
goto oops
}
r = c.FindClientRouteById(ConnId(conn_nid), RouteId(route_nid))
if r == nil {
status_code = WriteJsonRespHeader(w, http.StatusNotFound)
if err = je.Encode(JsonErrmsg{Text: "non-existent connection/route id - " + conn_id + "/" + route_id}); err != nil { goto oops }
goto done
je.Encode(JsonErrmsg{Text: "non-existent connection/route id - " + conn_id + "/" + route_id})
goto oops
}
switch req.Method {
@ -723,7 +721,7 @@ func (ctl *client_ctl_client_conns_id_routes_id_peers) ServeHTTP(w http.Response
status_code = WriteEmptyRespHeader(w, http.StatusBadRequest)
}
done:
//done:
return status_code, nil
oops:
@ -755,27 +753,27 @@ func (ctl *client_ctl_client_conns_id_routes_id_peers_id) ServeHTTP(w http.Respo
conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(ConnId(0)) * 8))
if err != nil {
status_code = WriteJsonRespHeader(w, http.StatusBadRequest)
if err = je.Encode(JsonErrmsg{Text: "wrong connection id - " + conn_id}); err != nil { goto oops }
goto done
je.Encode(JsonErrmsg{Text: "wrong connection id - " + conn_id})
goto oops
}
route_nid, err = strconv.ParseUint(route_id, 10, int(unsafe.Sizeof(RouteId(0)) * 8))
if err != nil {
status_code = WriteJsonRespHeader(w, http.StatusBadRequest)
if err = je.Encode(JsonErrmsg{Text: "wrong route id - " + route_id}); err != nil { goto oops }
goto done
je.Encode(JsonErrmsg{Text: "wrong route id - " + route_id})
goto oops
}
peer_nid, err = strconv.ParseUint(peer_id, 10, int(unsafe.Sizeof(ConnId(0)) * 8))
if err != nil {
status_code = WriteJsonRespHeader(w, http.StatusBadRequest)
if err = je.Encode(JsonErrmsg{Text: "wrong peer id - " + peer_id}); err != nil { goto oops }
goto done
je.Encode(JsonErrmsg{Text: "wrong peer id - " + peer_id})
goto oops
}
p = c.FindClientPeerConnById(ConnId(conn_nid), RouteId(route_nid), PeerId(peer_nid))
if p == nil {
status_code = WriteJsonRespHeader(w, http.StatusNotFound)
if err = je.Encode(JsonErrmsg{Text: "non-existent connection/route/peer id - " + conn_id + "/" + route_id + "/" + peer_id}); err != nil { goto oops }
goto done
je.Encode(JsonErrmsg{Text: "non-existent connection/route/peer id - " + conn_id + "/" + route_id + "/" + peer_id})
goto oops
}
switch req.Method {
@ -801,7 +799,7 @@ func (ctl *client_ctl_client_conns_id_routes_id_peers_id) ServeHTTP(w http.Respo
status_code = WriteEmptyRespHeader(w, http.StatusBadRequest)
}
done:
//done:
return status_code, nil
oops:

View File

@ -354,6 +354,7 @@ func make_server_basic_auth_config(cfg *BasicAuthConfig) (*hodu.ServerBasicAuth,
config.Enabled = cfg.Enabled
config.Realm = cfg.Realm
config.Creds = make(hodu.ServerBasicAuthCredMap)
for _, cred = range cfg.Creds {
b, err = base64.StdEncoding.DecodeString(cred)
@ -365,7 +366,7 @@ func make_server_basic_auth_config(cfg *BasicAuthConfig) (*hodu.ServerBasicAuth,
return nil, fmt.Errorf("invalid basic auth credential - %s", cred)
}
config.Creds = append(config.Creds, hodu.ServerBasicAuthCred{ Username: x[0], Password: x[1] })
config.Creds[x[0]] = x[1]
}
return &config, nil

View File

@ -48,6 +48,10 @@ type Service interface {
WriteLog(id string, level LogLevel, fmtstr string, args ...interface{})
}
type JsonErrmsg struct {
Text string `json:"error-text"`
}
type json_out_go_stats struct {
CPUs int `json:"cpus"`
Goroutines int `json:"goroutines"`

11
jwt.go
View File

@ -16,7 +16,7 @@ func Sign(data []byte, privkey *rsa.PrivateKey) ([]byte, error) {
h = crypto.SHA512.New()
h.Write(data)
fmt.Printf("%+v\n", h.Sum(nil))
//fmt.Printf("%+v\n", h.Sum(nil))
return rsa.SignPKCS1v15(rand.Reader, privkey, crypto.SHA512, h.Sum(nil))
}
@ -79,7 +79,7 @@ func (j *JWT) Sign(claims interface{}) (string, error) {
sb, err = SignHS512([]byte(ss), "hello")
if err != nil { return "", err }
fmt.Printf ("%+v %+v %s\n", string(hb), string(cb), (ss + "." + base64.RawURLEncoding.EncodeToString(sb)))
//fmt.Printf ("%+v %+v %s\n", string(hb), string(cb), (ss + "." + base64.RawURLEncoding.EncodeToString(sb)))
return ss + "." + base64.RawURLEncoding.EncodeToString(sb), nil
}
@ -100,22 +100,23 @@ func (j *JWT) Verify(tok string) error {
if err != nil { return fmt.Errorf("invalid header - %s", err.Error()) }
err = json.Unmarshal(hb, &jh)
if err != nil { return fmt.Errorf("invalid header - %s", err.Error()) }
fmt.Printf ("DECODED HEADER [%+v]\n", jh)
//fmt.Printf ("DECODED HEADER [%+v]\n", jh)
cb, err = base64.RawURLEncoding.DecodeString(segs[1])
if err != nil { return fmt.Errorf("invalid claims - %s", err.Error()) }
err = json.Unmarshal(cb, &jcm)
if err != nil { return fmt.Errorf("invalid header - %s", err.Error()) }
fmt.Printf ("DECODED CLAIMS [%+v]\n", jcm)
//fmt.Printf ("DECODED CLAIMS [%+v]\n", jcm)
x, err = j.Sign(jcm)
if err != nil { return err }
fmt.Printf ("VERIFICATION OK...\n")
if x != tok { return fmt.Errorf("signature mismatch") }
//fmt.Printf ("VERIFICATION OK...[%s] [%s]\n", x, tok)
// sb, err = base64.RawURLEncoding.DecodeString(segs[2])
// if err != nil { return fmt.Errorf("invalid signature - %s", err.Error()) }
// TODO: check expiry and others...
_ = sb

View File

@ -2,6 +2,8 @@ package hodu
import "encoding/json"
import "net/http"
import "strings"
import "time"
type json_out_server_conn struct {
Id ConnId `json:"id"`
@ -36,6 +38,10 @@ type server_ctl struct {
id string
}
type server_ctl_token struct {
server_ctl
}
type server_ctl_server_conns struct {
server_ctl
}
@ -62,15 +68,85 @@ func (ctl *server_ctl) Id() string {
return ctl.id
}
func (ctl *server_ctl) Authenticate(req *http.Request) bool {
func (ctl *server_ctl) Authenticate(req *http.Request) string {
var s *Server
s = ctl.s
if s.cfg.CtlBasicAuth != nil && s.cfg.CtlBasicAuth.Enabled {
// perform basic authentication
var auth_hdr string
var auth_parts []string
var username string
var password string
var credpass string
var ok bool
var err error
auth_hdr = req.Header.Get("Authorization")
if auth_hdr == "" { return s.cfg.CtlBasicAuth.Realm }
auth_parts = strings.Fields(auth_hdr)
if len(auth_parts) == 2 && strings.EqualFold(auth_parts[0], "Bearer") {
var jwt JWT
err = jwt.Verify(strings.TrimSpace(auth_parts[1]))
if err == nil { return "" }
}
return true
// fall back to basic authentication
username, password, ok = req.BasicAuth()
if !ok { return s.cfg.CtlBasicAuth.Realm }
credpass, ok = s.cfg.CtlBasicAuth.Creds[username]
if !ok || credpass != password { return s.cfg.CtlBasicAuth.Realm }
}
return ""
}
// ------------------------------------
type ServerTokenClaim struct {
IssuedAt int64 `json:"iat"`
}
type json_out_token struct {
AccessToken string `json:"access-token"`
RefreshToken string `json:"refresh-token"`
}
func (ctl *server_ctl_token) ServeHTTP(w http.ResponseWriter, req *http.Request) (int, error) {
var status_code int
var je *json.Encoder
var err error
je = json.NewEncoder(w)
switch req.Method {
case http.MethodGet:
var jwt JWT
var jc ServerTokenClaim
var tok string
jc.IssuedAt = time.Now().Unix()
tok, err = jwt.Sign(&jc)
if err != nil {
status_code = WriteJsonRespHeader(w, http.StatusInternalServerError)
je.Encode(JsonErrmsg{Text: err.Error()})
goto oops
}
status_code = WriteJsonRespHeader(w, http.StatusOK)
err = je.Encode(json_out_token{ AccessToken: tok }) // TODO: refresh token
if err != nil { goto oops }
default:
status_code = WriteEmptyRespHeader(w, http.StatusBadRequest)
}
//done:
return status_code, nil
oops:
return status_code, err
}
// ------------------------------------
@ -78,8 +154,8 @@ func (ctl *server_ctl) Authenticate(req *http.Request) bool {
func (ctl *server_ctl_server_conns) ServeHTTP(w http.ResponseWriter, req *http.Request) (int, error) {
var s *Server
var status_code int
var err error
var je *json.Encoder
var err error
s = ctl.s
je = json.NewEncoder(w)
@ -152,8 +228,8 @@ func (ctl *server_ctl_server_conns_id) ServeHTTP(w http.ResponseWriter, req *htt
cts, err = s.FindServerConnByIdStr(conn_id)
if err != nil {
status_code = WriteJsonRespHeader(w, http.StatusNotFound)
if err = je.Encode(JsonErrmsg{Text: err.Error()}); err != nil { goto oops }
goto done
je.Encode(JsonErrmsg{Text: err.Error()})
goto oops
}
switch req.Method {
@ -194,7 +270,7 @@ func (ctl *server_ctl_server_conns_id) ServeHTTP(w http.ResponseWriter, req *htt
status_code = WriteEmptyRespHeader(w, http.StatusBadRequest)
}
done:
//done:
return status_code, nil
oops:
@ -218,8 +294,8 @@ func (ctl *server_ctl_server_conns_id_routes) ServeHTTP(w http.ResponseWriter, r
cts, err = s.FindServerConnByIdStr(conn_id)
if err != nil {
status_code = WriteJsonRespHeader(w, http.StatusNotFound)
if err = je.Encode(JsonErrmsg{Text: err.Error()}); err != nil { goto oops }
goto done
je.Encode(JsonErrmsg{Text: err.Error()})
goto oops
}
switch req.Method {
@ -253,7 +329,7 @@ func (ctl *server_ctl_server_conns_id_routes) ServeHTTP(w http.ResponseWriter, r
status_code = WriteEmptyRespHeader(w, http.StatusBadRequest)
}
done:
//done:
return status_code, nil
oops:
@ -308,8 +384,8 @@ func (ctl *server_ctl_server_conns_id_routes_id) ServeHTTP(w http.ResponseWriter
if err != nil {
status_code = WriteJsonRespHeader(w, http.StatusNotFound)
if err = je.Encode(JsonErrmsg{Text: err.Error()}); err != nil { goto oops }
goto done
je.Encode(JsonErrmsg{Text: err.Error()})
goto oops
}
switch req.Method {

View File

@ -190,8 +190,8 @@ func (pxy *server_proxy) Id() string {
return pxy.id
}
func (pxy *server_proxy) Authenticate(req *http.Request) bool {
return true
func (pxy *server_proxy) Authenticate(req *http.Request) string {
return ""
}
// ------------------------------------

View File

@ -42,15 +42,12 @@ type ServerSvcPortMap = map[PortId]ConnRouteId
type ServerWpxResponseTransformer func(r *ServerRouteProxyInfo, resp *http.Response) io.Reader
type ServerWpxForeignPortProxyMaker func(wpx_type string, port_id string) (*ServerRouteProxyInfo, error)
type ServerBasicAuthCred struct {
Username string
Password string
}
type ServerBasicAuthCredMap map[string]string
type ServerBasicAuth struct {
Enabled bool
Realm string
Creds []ServerBasicAuthCred
Creds ServerBasicAuthCredMap
}
type ServerConfig struct {
@ -953,7 +950,7 @@ func (hlw *server_http_log_writer) Write(p []byte) (n int, err error) {
type ServerHttpHandler interface {
Id() string
Authenticate(req *http.Request) bool
Authenticate(req *http.Request) string
ServeHTTP (w http.ResponseWriter, req *http.Request) (int, error)
}
@ -963,6 +960,7 @@ func (s *Server) wrap_http_handler(handler ServerHttpHandler) http.Handler {
var err error
var start_time time.Time
var time_taken time.Duration
var realm string
// this deferred function is to overcome the recovering implemenation
// from panic done in go's http server. in that implemenation, panic
@ -975,7 +973,13 @@ func (s *Server) wrap_http_handler(handler ServerHttpHandler) http.Handler {
}()
start_time = time.Now()
realm = handler.Authenticate(req)
if realm != "" {
w.Header().Set("WWW-Authenticate", fmt.Sprintf("Basic Realm=\"%s\"", realm))
status_code = http.StatusUnauthorized
} else {
status_code, err = handler.ServeHTTP(w, req)
}
time_taken = time.Now().Sub(start_time)
if status_code > 0 {
@ -1065,6 +1069,8 @@ func NewServer(ctx context.Context, name string, logger Logger, cfg *ServerConfi
s.wrap_http_handler(&server_ctl_server_conns_id_routes_id{server_ctl{s: &s, id: HS_ID_CTL}}))
s.ctl_mux.Handle(s.cfg.CtlPrefix + "/_ctl/stats",
s.wrap_http_handler(&server_ctl_stats{server_ctl{s: &s, id: HS_ID_CTL}}))
s.ctl_mux.Handle(s.cfg.CtlPrefix + "/_ctl/token",
s.wrap_http_handler(&server_ctl_token{server_ctl{s: &s, id: HS_ID_CTL}}))
// TODO: make this optional. add this endpoint only if it's enabled...
s.promreg = prometheus.NewRegistry()