fixed the issue of no mutex lock/unlock when accessing some maps.

added a wpx service
This commit is contained in:
hyung-hwan 2024-12-26 00:20:44 +09:00
parent fcb7ae5ade
commit 6809cfdeb6
6 changed files with 158 additions and 121 deletions

View File

@ -240,11 +240,8 @@ func (r *ClientRoute) ReqStopAllClientPeerConns() {
var c *ClientPeerConn
r.ptc_mtx.Lock()
defer r.ptc_mtx.Unlock()
for _, c = range r.ptc_map {
c.ReqStop()
}
for _, c = range r.ptc_map { c.ReqStop() }
r.ptc_mtx.Unlock()
}
func (r *ClientRoute) FindClientPeerConnById(conn_id PeerId) *ClientPeerConn {
@ -310,9 +307,11 @@ done:
func (r *ClientRoute) ReqStop() {
if r.stop_req.CompareAndSwap(false, true) {
var ptc *ClientPeerConn
for _, ptc = range r.ptc_map {
ptc.ReqStop()
}
r.ptc_mtx.Lock()
for _, ptc = range r.ptc_map { ptc.ReqStop() }
r.ptc_mtx.Unlock()
r.stop_chan <- true
}
}
@ -634,13 +633,9 @@ func (cts *ClientConn) AddNewClientRoute(rc *ClientRouteConfig) (*ClientRoute, e
func (cts *ClientConn) ReqStopAllClientRoutes() {
var r *ClientRoute
cts.route_mtx.Lock()
defer cts.route_mtx.Unlock()
for _, r = range cts.route_map {
r.ReqStop()
}
for _, r = range cts.route_map { r.ReqStop() }
cts.route_mtx.Unlock()
}
/*
@ -741,9 +736,7 @@ func (cts *ClientConn) disconnect_from_server() {
var r *ClientRoute
cts.route_mtx.Lock()
for _, r = range cts.route_map {
r.ReqStop()
}
for _, r = range cts.route_map { r.ReqStop() }
cts.route_mtx.Unlock()
cts.conn.Close()
@ -1174,13 +1167,9 @@ func (c *Client) AddNewClientConn(cfg *ClientConfig) (*ClientConn, error) {
func (c *Client) ReqStopAllClientConns() {
var cts *ClientConn
c.cts_mtx.Lock()
defer c.cts_mtx.Unlock()
for _, cts = range c.cts_map {
cts.ReqStop()
}
for _, cts = range c.cts_map { cts.ReqStop() }
c.cts_mtx.Unlock()
}
/*
@ -1312,9 +1301,9 @@ func (c *Client) ReqStop() {
ctl.Shutdown(c.ctx) // to break c.ctl.ListenAndServe()
}
for _, cts = range c.cts_map {
cts.ReqStop()
}
c.cts_mtx.Lock()
for _, cts = range c.cts_map { cts.ReqStop() }
c.cts_mtx.Unlock()
c.stop_chan <- true
}

View File

@ -93,6 +93,11 @@ type ServerConfig struct {
TLS ServerTLSConfig `yaml:"tls"`
} `yaml:"pxy"`
WPX struct {
Service PXYServiceConfig `yaml:"service"`
TLS ServerTLSConfig `yaml:"tls"`
} `yaml:"wpx"`
RPC struct {
Service RPCServiceConfig `yaml:"service"`
TLS ServerTLSConfig `yaml:"tls"`

View File

@ -90,11 +90,12 @@ func (sh *signal_handler) WriteLog(id string, level hodu.LogLevel, fmt string, a
// --------------------------------------------------------------------
func server_main(ctl_addrs []string, rpc_addrs []string, pxy_addrs []string, cfg *ServerConfig) error {
func server_main(ctl_addrs []string, rpc_addrs []string, pxy_addrs []string, wpx_addrs []string, cfg *ServerConfig) error {
var s *hodu.Server
var ctltlscfg *tls.Config
var rpctlscfg *tls.Config
var pxytlscfg *tls.Config
var wpxtlscfg *tls.Config
var ctl_prefix string
var logger *AppLogger
var log_mask hodu.LogMask
@ -114,18 +115,13 @@ func server_main(ctl_addrs []string, rpc_addrs []string, pxy_addrs []string, cfg
if err != nil { return err }
pxytlscfg, err = make_tls_server_config(&cfg.PXY.TLS)
if err != nil { return err }
wpxtlscfg, err = make_tls_server_config(&cfg.WPX.TLS)
if err != nil { return err }
if len(ctl_addrs) <= 0 {
ctl_addrs = cfg.CTL.Service.Addrs
}
if len(rpc_addrs) <= 0 {
rpc_addrs = cfg.RPC.Service.Addrs
}
if len(pxy_addrs) <= 0 {
pxy_addrs = cfg.PXY.Service.Addrs
}
if len(ctl_addrs) <= 0 { ctl_addrs = cfg.CTL.Service.Addrs }
if len(rpc_addrs) <= 0 { rpc_addrs = cfg.RPC.Service.Addrs }
if len(pxy_addrs) <= 0 { pxy_addrs = cfg.PXY.Service.Addrs }
if len(wpx_addrs) <= 0 { wpx_addrs = cfg.WPX.Service.Addrs }
ctl_prefix = cfg.CTL.Service.Prefix
log_mask = log_strings_to_mask(cfg.APP.LogMask)
@ -155,10 +151,12 @@ func server_main(ctl_addrs []string, rpc_addrs []string, pxy_addrs []string, cfg
ctl_addrs,
rpc_addrs,
pxy_addrs,
wpx_addrs,
ctl_prefix,
ctltlscfg,
rpctlscfg,
pxytlscfg,
wpxtlscfg,
max_rpc_conns,
max_peers)
if err != nil {
@ -168,6 +166,7 @@ func server_main(ctl_addrs []string, rpc_addrs []string, pxy_addrs []string, cfg
s.StartService(nil)
s.StartCtlService()
s.StartPxyService()
s.StartWpxService()
s.StartExtService(&signal_handler{svc:s}, nil)
s.WaitForTermination()
logger.Close()
@ -331,6 +330,7 @@ func main() {
var rpc_addrs []string
var ctl_addrs []string
var pxy_addrs []string
var wpx_addrs []string
var cfgfile string
var logfile string
var cfg *ServerConfig
@ -338,6 +338,7 @@ func main() {
ctl_addrs = make([]string, 0)
rpc_addrs = make([]string, 0)
pxy_addrs = make([]string, 0)
wpx_addrs = make([]string, 0)
flgs = flag.NewFlagSet("", flag.ContinueOnError)
flgs.Func("ctl-on", "specify a listening address for control channel", func(v string) error {
@ -352,6 +353,10 @@ func main() {
pxy_addrs = append(pxy_addrs, v)
return nil
})
flgs.Func("wpx-on", "specify a wpx listening address", func(v string) error {
wpx_addrs = append(wpx_addrs, v)
return nil
})
flgs.Func("log-file", "specify a log file", func(v string) error {
logfile = v
return nil
@ -379,7 +384,7 @@ func main() {
}
if logfile != "" { cfg.APP.LogFile = logfile }
err = server_main(ctl_addrs, rpc_addrs, pxy_addrs, cfg)
err = server_main(ctl_addrs, rpc_addrs, pxy_addrs, wpx_addrs, cfg)
if err != nil {
fmt.Fprintf(os.Stderr, "ERROR: server error - %s\n", err.Error())
goto oops
@ -443,7 +448,7 @@ func main() {
os.Exit(0)
wrong_usage:
fmt.Fprintf(os.Stderr, "USAGE: %s server --rpc-on=addr:port --ctl-on=addr:port --pxy-on=addr:port [--config-file=file]\n", os.Args[0])
fmt.Fprintf(os.Stderr, "USAGE: %s server --rpc-on=addr:port --ctl-on=addr:port --pxy-on=addr:port --wpx-on=addr:port [--config-file=file]\n", os.Args[0])
fmt.Fprintf(os.Stderr, " %s client --rpc-server=addr:port --ctl-on=addr:port [--config-file=file] [peer-addr:peer-port ...]\n", os.Args[0])
fmt.Fprintf(os.Stderr, " %s version\n", os.Args[0])
os.Exit(1)

View File

@ -5,7 +5,6 @@ import "context"
import "crypto/tls"
import _ "embed"
import "encoding/json"
import "errors"
import "fmt"
import "io"
import "net"
@ -40,6 +39,7 @@ type server_proxy_http_init struct {
type server_proxy_http_main struct {
s *Server
prefix string
restore bool // restore URLs in response text
}
type server_proxy_ssh struct {
@ -212,48 +212,6 @@ func prevent_follow_redirect (req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
}
func (pxy *server_proxy_http_main) serve_websocket(w http.ResponseWriter, req *http.Request, ws_url string, target *net.TCPAddr) {
pxy.s.log.Write("", LOG_INFO, "[%s] %s %s -> %+v", req.RemoteAddr, req.Method, req.URL.String(), ws_url)
websocket.Handler(func(wc *websocket.Conn) {
var ws *websocket.Conn
var err_chan chan error
var err error
defer wc.Close()
// TODO: timeout or cancellation
// TODO: use DialConfig??
ws, err = websocket.Dial(ws_url, "", req.Header.Get("Origin"))
if err != nil {
// TODO: logging
return
}
defer ws.Close()
err_chan = make(chan error, 2)
go func() {
// client to server
var err error
_, err = io.Copy(ws, wc)
err_chan <- err
}()
go func() {
// server to client
var err error
_, err = io.Copy(wc, ws)
err_chan <- err
}()
err = <-err_chan
if err != nil && errors.Is(err, io.EOF) {
// TODO: logging
}
}).ServeHTTP(w, req)
}
func (pxy *server_proxy_http_main) get_route(req *http.Request) (*ServerRoute, string, string, string, error) {
var conn_id string
var route_id string
@ -261,8 +219,13 @@ func (pxy *server_proxy_http_main) get_route(req *http.Request) (*ServerRoute, s
var path_prefix string
var err error
conn_id = req.PathValue("conn_id")
route_id = req.PathValue("route_id")
if pxy.prefix == PORT_ID_MARKER { // for wpx
conn_id = req.PathValue("port_id")
route_id = pxy.prefix
} else {
conn_id = req.PathValue("conn_id")
route_id = req.PathValue("route_id")
}
if conn_id == "" && route_id == "" {
// it's not via /_http/<<conn-id>>/<<route-id>>.
// get ids from the cookie.
@ -283,7 +246,11 @@ func (pxy *server_proxy_http_main) get_route(req *http.Request) (*ServerRoute, s
route_id = ids[1]
path_prefix = ""
} else {
path_prefix = fmt.Sprintf("%s/%s/%s", pxy.prefix, conn_id, route_id)
if pxy.prefix == PORT_ID_MARKER { // for wpx
path_prefix = fmt.Sprintf("/%s", conn_id)
} else {
path_prefix = fmt.Sprintf("%s/%s/%s", pxy.prefix, conn_id, route_id)
}
}
r, err = pxy.s.FindServerRouteByIdStr(conn_id, route_id)
@ -292,11 +259,6 @@ func (pxy *server_proxy_http_main) get_route(req *http.Request) (*ServerRoute, s
return r, path_prefix, conn_id, route_id, nil
}
func (pxy *server_proxy_http_main) get_upgrade_type(hdr http.Header) string {
if httpguts.HeaderValuesContainsToken(hdr["Connection"], "Upgrade") { return hdr.Get("Upgrade") }
return ""
}
func (pxy *server_proxy_http_main) serve_upgraded(w http.ResponseWriter, req *http.Request, proxy_res *http.Response) error {
var err_chan chan error
var proxy_res_body io.ReadWriteCloser
@ -431,6 +393,14 @@ func (pxy *server_proxy_http_main) ServeHTTP(w http.ResponseWriter, req *http.Re
goto oops
}
/*
if r.svc_option & (RouteOption(ROUTE_OPTION_HTTP) | RouteOption(ROUTE_OPTION_HTTPS)) == 0 {
status_code = http.StatusForbidden; w.WriteHeader(status_code)
err = fmt.Errorf("target not http/https")
goto oops
}
*/
addr = svc_addr_to_dst_addr(r.svc_addr)
//transport, err = pxy.addr_to_transport(req.Context(), addr)
transport, err = pxy.addr_to_transport(s.ctx, addr)
@ -535,7 +505,6 @@ func (pxy *server_proxy_xterm_file) ServeHTTP(w http.ResponseWriter, req *http.R
var tmpl *template.Template
var conn_id string
var route_id string
//var r *ServerRoute
conn_id = req.PathValue("conn_id")
route_id = req.PathValue("route_id")
@ -565,7 +534,6 @@ func (pxy *server_proxy_xterm_file) ServeHTTP(w http.ResponseWriter, req *http.R
status_code = http.StatusNotFound; w.WriteHeader(status_code)
}
//done:
s.log.Write("", LOG_INFO, "[%s] %s %s %d", req.RemoteAddr, req.Method, req.URL.String(), status_code)
return
@ -622,7 +590,7 @@ func (pxy *server_proxy_ssh_ws) connect_ssh (ctx context.Context, username strin
/* Is this protection needed?
if r.svc_option & RouteOption(ROUTE_OPTION_SSH) == 0 {
err = fmt.Errorf("peer not ssh")
err = fmt.Errorf("target not ssh")
goto oops
}
*/

122
server.go
View File

@ -25,6 +25,7 @@ const PTS_LIMIT int = 16384
const CTS_LIMIT int = 16384
type PortId uint16
const PORT_ID_MARKER string = "_"
type ServerConnMapByAddr = map[net.Addr]*ServerConn
type ServerConnMap = map[ConnId]*ServerConn
@ -36,6 +37,7 @@ type Server struct {
ctx context.Context
ctx_cancel context.CancelFunc
pxytlscfg *tls.Config
wpxtlscfg *tls.Config
ctltlscfg *tls.Config
rpctlscfg *tls.Config
@ -51,6 +53,10 @@ type Server struct {
pxy_mux *http.ServeMux
pxy []*http.Server // proxy server
wpx_addr []string
wpx_mux *http.ServeMux
wpx []*http.Server // proxy server than handles http/https only
ctl_addr []string
ctl_prefix string
ctl_mux *http.ServeMux
@ -296,9 +302,9 @@ func (r *ServerRoute) ReqStop() {
if r.stop_req.CompareAndSwap(false, true) {
var pts *ServerPeerConn
for _, pts = range r.pts_map {
pts.ReqStop()
}
r.pts_mtx.Lock()
for _, pts = range r.pts_map { pts.ReqStop() }
r.pts_mtx.Unlock()
r.svc_l.Close()
}
@ -475,11 +481,8 @@ func (cts *ServerConn) ReqStopAllServerRoutes() {
var r *ServerRoute
cts.route_mtx.Lock()
defer cts.route_mtx.Unlock()
for _, r = range cts.route_map {
r.ReqStop()
}
for _, r = range cts.route_map { r.ReqStop() }
cts.route_mtx.Unlock()
}
func (cts *ServerConn) ReportEvent(route_id RouteId, pts_id PeerId, event_type PACKET_KIND, event_data interface{}) error {
@ -725,9 +728,9 @@ func (cts *ServerConn) ReqStop() {
if cts.stop_req.CompareAndSwap(false, true) {
var r *ServerRoute
for _, r = range cts.route_map {
r.ReqStop()
}
cts.route_mtx.Lock()
for _, r = range cts.route_map { r.ReqStop() }
cts.route_mtx.Unlock()
// there is no good way to break a specific connection client to
// the grpc server. while the global grpc server is closed in
@ -902,7 +905,7 @@ func (hlw *server_http_log_writer) Write(p []byte) (n int, err error) {
return len(p), nil
}
func NewServer(ctx context.Context, logger Logger, ctl_addrs []string, rpc_addrs []string, pxy_addrs []string, ctl_prefix string, ctltlscfg *tls.Config, rpctlscfg *tls.Config, pxytlscfg *tls.Config, rpc_max int, peer_max int) (*Server, error) {
func NewServer(ctx context.Context, logger Logger, ctl_addrs []string, rpc_addrs []string, pxy_addrs []string, wpx_addrs []string, ctl_prefix string, ctltlscfg *tls.Config, rpctlscfg *tls.Config, pxytlscfg *tls.Config, wpxtlscfg *tls.Config, rpc_max int, peer_max int) (*Server, error) {
var s Server
var l *net.TCPListener
var rpcaddr *net.TCPAddr
@ -938,6 +941,7 @@ func NewServer(ctx context.Context, logger Logger, ctl_addrs []string, rpc_addrs
s.ctltlscfg = ctltlscfg
s.rpctlscfg = rpctlscfg
s.pxytlscfg = pxytlscfg
s.wpxtlscfg = wpxtlscfg
s.ext_svcs = make([]Service, 0, 1)
s.pts_limit = peer_max
s.cts_limit = rpc_max
@ -995,7 +999,7 @@ func NewServer(ctx context.Context, logger Logger, ctl_addrs []string, rpc_addrs
// ---------------------------------------------------------
s.pxy_ws = &server_proxy_ssh_ws{s: &s}
s.pxy_mux = http.NewServeMux() // TODO: make /_init configurable...
s.pxy_mux = http.NewServeMux() // TODO: make /_init,_ssh,_ssh_ws,_http configurable...
s.pxy_mux.Handle("/_ssh-ws/{conn_id}/{route_id}",
websocket.Handler(func(ws *websocket.Conn) { s.pxy_ws.ServeWebsocket(ws) }))
s.pxy_mux.Handle("/_ssh/server-conns/{conn_id}/routes/{route_id}", &server_ctl_server_conns_id_routes_id{s: &s})
@ -1005,9 +1009,6 @@ func NewServer(ctx context.Context, logger Logger, ctl_addrs []string, rpc_addrs
s.pxy_mux.Handle("/_ssh/xterm.css", &server_proxy_xterm_file{s: &s, file: "xterm.css"})
s.pxy_mux.Handle("/_ssh/", &server_proxy_xterm_file{s: &s, file: "_forbidden"})
//cwd, _ = os.Getwd() // TODO:
//s.pxy_mux.Handle(s.ctl_prefix + "/ui/", http.StripPrefix(s.ctl_prefix, http.FileServer(http.Dir(cwd)))) // TODO: proper directory. it must not use the current working directory...
s.pxy_mux.Handle("/_http/{conn_id}/{route_id}/{trailer...}", &server_proxy_http_main{s: &s, prefix: "/_http"})
s.pxy_mux.Handle("/_init/{conn_id}/{route_id}/{trailer...}", &server_proxy_http_init{s: &s, prefix: "/_init"})
s.pxy_mux.Handle("/", &server_proxy_http_main{s: &s, prefix: ""})
@ -1025,6 +1026,28 @@ func NewServer(ctx context.Context, logger Logger, ctl_addrs []string, rpc_addrs
// TODO: more settings
}
}
// ---------------------------------------------------------
s.wpx_mux = http.NewServeMux()
s.wpx_mux.Handle("/{port_id}/{trailer...}", &server_proxy_http_main{s: &s, prefix: PORT_ID_MARKER})
s.wpx_mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
w.WriteHeader(http.StatusForbidden)
})
s.wpx_addr = make([]string, len(wpx_addrs))
s.wpx = make([]*http.Server, len(wpx_addrs))
copy(s.wpx_addr, wpx_addrs)
for i = 0; i < len(wpx_addrs); i++ {
s.wpx[i] = &http.Server{
Addr: wpx_addrs[i],
Handler: s.wpx_mux,
TLSConfig: s.wpxtlscfg,
ErrorLog: hs_log,
// TODO: more settings
}
}
// ---------------------------------------------------------
s.stats.conns.Store(0)
@ -1191,6 +1214,49 @@ func (s *Server) RunPxyTask(wg *sync.WaitGroup) {
l_wg.Wait()
}
func (s *Server) RunWpxTask(wg *sync.WaitGroup) {
var err error
var wpx *http.Server
var idx int
var l_wg sync.WaitGroup
defer wg.Done()
for idx, wpx = range s.wpx {
l_wg.Add(1)
go func(i int, cs *http.Server) {
var l net.Listener
s.log.Write("", LOG_INFO, "Wpx channel[%d] started on %s", i, s.wpx_addr[i])
if s.stop_req.Load() == false {
l, err = net.Listen(tcp_addr_str_class(cs.Addr), cs.Addr)
if err == nil {
if s.stop_req.Load() == false {
if s.wpxtlscfg == nil { // TODO: change this
err = cs.Serve(l)
} else {
err = cs.ServeTLS(l, "", "") // s.wpxtlscfg must provide a certificate and a key
}
} else {
err = fmt.Errorf("stop requested")
}
l.Close()
}
} else {
err = fmt.Errorf("stop requested")
}
if errors.Is(err, http.ErrServerClosed) {
s.log.Write("", LOG_INFO, "Wpx channel[%d] ended", i)
} else {
s.log.Write("", LOG_ERROR, "Wpx channel[%d] error - %s", i, err.Error())
}
l_wg.Done()
}(idx, wpx)
}
l_wg.Wait()
}
func (s *Server) ReqStop() {
if s.stop_req.CompareAndSwap(false, true) {
var l *net.TCPListener
@ -1210,16 +1276,19 @@ func (s *Server) ReqStop() {
hs.Shutdown(s.ctx) // to break s.pxy.Serve()
}
for _, hs = range s.wpx {
hs.Shutdown(s.ctx) // to break s.wpx.Serve()
}
//s.rpc_svr.GracefulStop()
//s.rpc_svr.Stop()
for _, l = range s.rpc {
l.Close()
}
// request to stop connections from/to peer held in the cts structure
s.cts_mtx.Lock()
for _, cts = range s.cts_map {
cts.ReqStop() // request to stop connections from/to peer held in the cts structure
}
for _, cts = range s.cts_map { cts.ReqStop() }
s.cts_mtx.Unlock()
s.stop_chan <- true
@ -1279,13 +1348,9 @@ func (s *Server) AddNewServerConn(remote_addr *net.Addr, local_addr *net.Addr, p
func (s *Server) ReqStopAllServerConns() {
var cts *ServerConn
s.cts_mtx.Lock()
defer s.cts_mtx.Unlock()
for _, cts = range s.cts_map {
cts.ReqStop()
}
for _, cts = range s.cts_map { cts.ReqStop() }
s.cts_mtx.Unlock()
}
func (s *Server) RemoveServerConn(cts *ServerConn) error {
@ -1393,7 +1458,7 @@ func (s *Server) FindServerRouteByIdStr(conn_id string, route_id string) (*Serve
var r *ServerRoute
var err error
if route_id == "_" {
if route_id == PORT_ID_MARKER {
var port_nid uint64
port_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(PortId(0)) * 8))
@ -1455,6 +1520,11 @@ func (s *Server) StartPxyService() {
go s.RunPxyTask(&s.wg)
}
func (s *Server) StartWpxService() {
s.wg.Add(1)
go s.RunWpxTask(&s.wg)
}
func (s *Server) StopServices() {
var ext_svc Service
s.ReqStop()

View File

@ -24,7 +24,7 @@ func monotonic_time() uint64 {
var r uintptr
var sts syscall.Timespec
r, _, _/*errno*/ = syscall.Syscall(syscall.SYS_CLOCK_GETTIME, unix.CLOCK_MONOTONIC, uintptr(unsafe.Pointer(&sts)), 0)
if r == ^uintptr(0) { return uint64(n) } // may be negative cast to unsigned. no other fall-back
if r == ^uintptr(0) { return uint64(n) } // may be a negative number cast to unsigned. no other fall-back
return uint64(sts.Nano())
}