diff --git a/client.go b/client.go index 148678e..22dfe0b 100644 --- a/client.go +++ b/client.go @@ -240,11 +240,8 @@ func (r *ClientRoute) ReqStopAllClientPeerConns() { var c *ClientPeerConn r.ptc_mtx.Lock() - defer r.ptc_mtx.Unlock() - - for _, c = range r.ptc_map { - c.ReqStop() - } + for _, c = range r.ptc_map { c.ReqStop() } + r.ptc_mtx.Unlock() } func (r *ClientRoute) FindClientPeerConnById(conn_id PeerId) *ClientPeerConn { @@ -310,9 +307,11 @@ done: func (r *ClientRoute) ReqStop() { if r.stop_req.CompareAndSwap(false, true) { var ptc *ClientPeerConn - for _, ptc = range r.ptc_map { - ptc.ReqStop() - } + + r.ptc_mtx.Lock() + for _, ptc = range r.ptc_map { ptc.ReqStop() } + r.ptc_mtx.Unlock() + r.stop_chan <- true } } @@ -634,13 +633,9 @@ func (cts *ClientConn) AddNewClientRoute(rc *ClientRouteConfig) (*ClientRoute, e func (cts *ClientConn) ReqStopAllClientRoutes() { var r *ClientRoute - cts.route_mtx.Lock() - defer cts.route_mtx.Unlock() - - for _, r = range cts.route_map { - r.ReqStop() - } + for _, r = range cts.route_map { r.ReqStop() } + cts.route_mtx.Unlock() } /* @@ -741,9 +736,7 @@ func (cts *ClientConn) disconnect_from_server() { var r *ClientRoute cts.route_mtx.Lock() - for _, r = range cts.route_map { - r.ReqStop() - } + for _, r = range cts.route_map { r.ReqStop() } cts.route_mtx.Unlock() cts.conn.Close() @@ -1174,13 +1167,9 @@ func (c *Client) AddNewClientConn(cfg *ClientConfig) (*ClientConn, error) { func (c *Client) ReqStopAllClientConns() { var cts *ClientConn - c.cts_mtx.Lock() - defer c.cts_mtx.Unlock() - - for _, cts = range c.cts_map { - cts.ReqStop() - } + for _, cts = range c.cts_map { cts.ReqStop() } + c.cts_mtx.Unlock() } /* @@ -1312,9 +1301,9 @@ func (c *Client) ReqStop() { ctl.Shutdown(c.ctx) // to break c.ctl.ListenAndServe() } - for _, cts = range c.cts_map { - cts.ReqStop() - } + c.cts_mtx.Lock() + for _, cts = range c.cts_map { cts.ReqStop() } + c.cts_mtx.Unlock() c.stop_chan <- true } diff --git a/cmd/config.go b/cmd/config.go index c50a9bc..33532d6 100644 --- a/cmd/config.go +++ b/cmd/config.go @@ -93,6 +93,11 @@ type ServerConfig struct { TLS ServerTLSConfig `yaml:"tls"` } `yaml:"pxy"` + WPX struct { + Service PXYServiceConfig `yaml:"service"` + TLS ServerTLSConfig `yaml:"tls"` + } `yaml:"wpx"` + RPC struct { Service RPCServiceConfig `yaml:"service"` TLS ServerTLSConfig `yaml:"tls"` diff --git a/cmd/main.go b/cmd/main.go index 51ce13d..5e3ce6f 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -90,11 +90,12 @@ func (sh *signal_handler) WriteLog(id string, level hodu.LogLevel, fmt string, a // -------------------------------------------------------------------- -func server_main(ctl_addrs []string, rpc_addrs []string, pxy_addrs []string, cfg *ServerConfig) error { +func server_main(ctl_addrs []string, rpc_addrs []string, pxy_addrs []string, wpx_addrs []string, cfg *ServerConfig) error { var s *hodu.Server var ctltlscfg *tls.Config var rpctlscfg *tls.Config var pxytlscfg *tls.Config + var wpxtlscfg *tls.Config var ctl_prefix string var logger *AppLogger var log_mask hodu.LogMask @@ -114,18 +115,13 @@ func server_main(ctl_addrs []string, rpc_addrs []string, pxy_addrs []string, cfg if err != nil { return err } pxytlscfg, err = make_tls_server_config(&cfg.PXY.TLS) if err != nil { return err } + wpxtlscfg, err = make_tls_server_config(&cfg.WPX.TLS) + if err != nil { return err } - if len(ctl_addrs) <= 0 { - ctl_addrs = cfg.CTL.Service.Addrs - } - - if len(rpc_addrs) <= 0 { - rpc_addrs = cfg.RPC.Service.Addrs - } - - if len(pxy_addrs) <= 0 { - pxy_addrs = cfg.PXY.Service.Addrs - } + if len(ctl_addrs) <= 0 { ctl_addrs = cfg.CTL.Service.Addrs } + if len(rpc_addrs) <= 0 { rpc_addrs = cfg.RPC.Service.Addrs } + if len(pxy_addrs) <= 0 { pxy_addrs = cfg.PXY.Service.Addrs } + if len(wpx_addrs) <= 0 { wpx_addrs = cfg.WPX.Service.Addrs } ctl_prefix = cfg.CTL.Service.Prefix log_mask = log_strings_to_mask(cfg.APP.LogMask) @@ -155,10 +151,12 @@ func server_main(ctl_addrs []string, rpc_addrs []string, pxy_addrs []string, cfg ctl_addrs, rpc_addrs, pxy_addrs, + wpx_addrs, ctl_prefix, ctltlscfg, rpctlscfg, pxytlscfg, + wpxtlscfg, max_rpc_conns, max_peers) if err != nil { @@ -168,6 +166,7 @@ func server_main(ctl_addrs []string, rpc_addrs []string, pxy_addrs []string, cfg s.StartService(nil) s.StartCtlService() s.StartPxyService() + s.StartWpxService() s.StartExtService(&signal_handler{svc:s}, nil) s.WaitForTermination() logger.Close() @@ -331,6 +330,7 @@ func main() { var rpc_addrs []string var ctl_addrs []string var pxy_addrs []string + var wpx_addrs []string var cfgfile string var logfile string var cfg *ServerConfig @@ -338,6 +338,7 @@ func main() { ctl_addrs = make([]string, 0) rpc_addrs = make([]string, 0) pxy_addrs = make([]string, 0) + wpx_addrs = make([]string, 0) flgs = flag.NewFlagSet("", flag.ContinueOnError) flgs.Func("ctl-on", "specify a listening address for control channel", func(v string) error { @@ -352,6 +353,10 @@ func main() { pxy_addrs = append(pxy_addrs, v) return nil }) + flgs.Func("wpx-on", "specify a wpx listening address", func(v string) error { + wpx_addrs = append(wpx_addrs, v) + return nil + }) flgs.Func("log-file", "specify a log file", func(v string) error { logfile = v return nil @@ -379,7 +384,7 @@ func main() { } if logfile != "" { cfg.APP.LogFile = logfile } - err = server_main(ctl_addrs, rpc_addrs, pxy_addrs, cfg) + err = server_main(ctl_addrs, rpc_addrs, pxy_addrs, wpx_addrs, cfg) if err != nil { fmt.Fprintf(os.Stderr, "ERROR: server error - %s\n", err.Error()) goto oops @@ -443,7 +448,7 @@ func main() { os.Exit(0) wrong_usage: - fmt.Fprintf(os.Stderr, "USAGE: %s server --rpc-on=addr:port --ctl-on=addr:port --pxy-on=addr:port [--config-file=file]\n", os.Args[0]) + fmt.Fprintf(os.Stderr, "USAGE: %s server --rpc-on=addr:port --ctl-on=addr:port --pxy-on=addr:port --wpx-on=addr:port [--config-file=file]\n", os.Args[0]) fmt.Fprintf(os.Stderr, " %s client --rpc-server=addr:port --ctl-on=addr:port [--config-file=file] [peer-addr:peer-port ...]\n", os.Args[0]) fmt.Fprintf(os.Stderr, " %s version\n", os.Args[0]) os.Exit(1) diff --git a/server-proxy.go b/server-proxy.go index 7a9e9bc..54287f9 100644 --- a/server-proxy.go +++ b/server-proxy.go @@ -5,7 +5,6 @@ import "context" import "crypto/tls" import _ "embed" import "encoding/json" -import "errors" import "fmt" import "io" import "net" @@ -40,6 +39,7 @@ type server_proxy_http_init struct { type server_proxy_http_main struct { s *Server prefix string + restore bool // restore URLs in response text } type server_proxy_ssh struct { @@ -212,48 +212,6 @@ func prevent_follow_redirect (req *http.Request, via []*http.Request) error { return http.ErrUseLastResponse } -func (pxy *server_proxy_http_main) serve_websocket(w http.ResponseWriter, req *http.Request, ws_url string, target *net.TCPAddr) { - pxy.s.log.Write("", LOG_INFO, "[%s] %s %s -> %+v", req.RemoteAddr, req.Method, req.URL.String(), ws_url) - - websocket.Handler(func(wc *websocket.Conn) { - var ws *websocket.Conn - var err_chan chan error - var err error - - defer wc.Close() - -// TODO: timeout or cancellation -// TODO: use DialConfig?? - ws, err = websocket.Dial(ws_url, "", req.Header.Get("Origin")) - if err != nil { - // TODO: logging - return - } - defer ws.Close() - - err_chan = make(chan error, 2) - - go func() { - // client to server - var err error - _, err = io.Copy(ws, wc) - err_chan <- err - }() - - go func() { - // server to client - var err error - _, err = io.Copy(wc, ws) - err_chan <- err - }() - - err = <-err_chan - if err != nil && errors.Is(err, io.EOF) { - // TODO: logging - } - }).ServeHTTP(w, req) -} - func (pxy *server_proxy_http_main) get_route(req *http.Request) (*ServerRoute, string, string, string, error) { var conn_id string var route_id string @@ -261,8 +219,13 @@ func (pxy *server_proxy_http_main) get_route(req *http.Request) (*ServerRoute, s var path_prefix string var err error - conn_id = req.PathValue("conn_id") - route_id = req.PathValue("route_id") + if pxy.prefix == PORT_ID_MARKER { // for wpx + conn_id = req.PathValue("port_id") + route_id = pxy.prefix + } else { + conn_id = req.PathValue("conn_id") + route_id = req.PathValue("route_id") + } if conn_id == "" && route_id == "" { // it's not via /_http/<>/<>. // get ids from the cookie. @@ -283,7 +246,11 @@ func (pxy *server_proxy_http_main) get_route(req *http.Request) (*ServerRoute, s route_id = ids[1] path_prefix = "" } else { - path_prefix = fmt.Sprintf("%s/%s/%s", pxy.prefix, conn_id, route_id) + if pxy.prefix == PORT_ID_MARKER { // for wpx + path_prefix = fmt.Sprintf("/%s", conn_id) + } else { + path_prefix = fmt.Sprintf("%s/%s/%s", pxy.prefix, conn_id, route_id) + } } r, err = pxy.s.FindServerRouteByIdStr(conn_id, route_id) @@ -292,11 +259,6 @@ func (pxy *server_proxy_http_main) get_route(req *http.Request) (*ServerRoute, s return r, path_prefix, conn_id, route_id, nil } -func (pxy *server_proxy_http_main) get_upgrade_type(hdr http.Header) string { - if httpguts.HeaderValuesContainsToken(hdr["Connection"], "Upgrade") { return hdr.Get("Upgrade") } - return "" -} - func (pxy *server_proxy_http_main) serve_upgraded(w http.ResponseWriter, req *http.Request, proxy_res *http.Response) error { var err_chan chan error var proxy_res_body io.ReadWriteCloser @@ -431,6 +393,14 @@ func (pxy *server_proxy_http_main) ServeHTTP(w http.ResponseWriter, req *http.Re goto oops } +/* + if r.svc_option & (RouteOption(ROUTE_OPTION_HTTP) | RouteOption(ROUTE_OPTION_HTTPS)) == 0 { + status_code = http.StatusForbidden; w.WriteHeader(status_code) + err = fmt.Errorf("target not http/https") + goto oops + } +*/ + addr = svc_addr_to_dst_addr(r.svc_addr) //transport, err = pxy.addr_to_transport(req.Context(), addr) transport, err = pxy.addr_to_transport(s.ctx, addr) @@ -535,7 +505,6 @@ func (pxy *server_proxy_xterm_file) ServeHTTP(w http.ResponseWriter, req *http.R var tmpl *template.Template var conn_id string var route_id string - //var r *ServerRoute conn_id = req.PathValue("conn_id") route_id = req.PathValue("route_id") @@ -565,7 +534,6 @@ func (pxy *server_proxy_xterm_file) ServeHTTP(w http.ResponseWriter, req *http.R status_code = http.StatusNotFound; w.WriteHeader(status_code) } - //done: s.log.Write("", LOG_INFO, "[%s] %s %s %d", req.RemoteAddr, req.Method, req.URL.String(), status_code) return @@ -622,7 +590,7 @@ func (pxy *server_proxy_ssh_ws) connect_ssh (ctx context.Context, username strin /* Is this protection needed? if r.svc_option & RouteOption(ROUTE_OPTION_SSH) == 0 { - err = fmt.Errorf("peer not ssh") + err = fmt.Errorf("target not ssh") goto oops } */ diff --git a/server.go b/server.go index c8a7bd9..8351cb5 100644 --- a/server.go +++ b/server.go @@ -25,6 +25,7 @@ const PTS_LIMIT int = 16384 const CTS_LIMIT int = 16384 type PortId uint16 +const PORT_ID_MARKER string = "_" type ServerConnMapByAddr = map[net.Addr]*ServerConn type ServerConnMap = map[ConnId]*ServerConn @@ -36,6 +37,7 @@ type Server struct { ctx context.Context ctx_cancel context.CancelFunc pxytlscfg *tls.Config + wpxtlscfg *tls.Config ctltlscfg *tls.Config rpctlscfg *tls.Config @@ -51,6 +53,10 @@ type Server struct { pxy_mux *http.ServeMux pxy []*http.Server // proxy server + wpx_addr []string + wpx_mux *http.ServeMux + wpx []*http.Server // proxy server than handles http/https only + ctl_addr []string ctl_prefix string ctl_mux *http.ServeMux @@ -296,9 +302,9 @@ func (r *ServerRoute) ReqStop() { if r.stop_req.CompareAndSwap(false, true) { var pts *ServerPeerConn - for _, pts = range r.pts_map { - pts.ReqStop() - } + r.pts_mtx.Lock() + for _, pts = range r.pts_map { pts.ReqStop() } + r.pts_mtx.Unlock() r.svc_l.Close() } @@ -475,11 +481,8 @@ func (cts *ServerConn) ReqStopAllServerRoutes() { var r *ServerRoute cts.route_mtx.Lock() - defer cts.route_mtx.Unlock() - - for _, r = range cts.route_map { - r.ReqStop() - } + for _, r = range cts.route_map { r.ReqStop() } + cts.route_mtx.Unlock() } func (cts *ServerConn) ReportEvent(route_id RouteId, pts_id PeerId, event_type PACKET_KIND, event_data interface{}) error { @@ -725,9 +728,9 @@ func (cts *ServerConn) ReqStop() { if cts.stop_req.CompareAndSwap(false, true) { var r *ServerRoute - for _, r = range cts.route_map { - r.ReqStop() - } + cts.route_mtx.Lock() + for _, r = range cts.route_map { r.ReqStop() } + cts.route_mtx.Unlock() // there is no good way to break a specific connection client to // the grpc server. while the global grpc server is closed in @@ -902,7 +905,7 @@ func (hlw *server_http_log_writer) Write(p []byte) (n int, err error) { return len(p), nil } -func NewServer(ctx context.Context, logger Logger, ctl_addrs []string, rpc_addrs []string, pxy_addrs []string, ctl_prefix string, ctltlscfg *tls.Config, rpctlscfg *tls.Config, pxytlscfg *tls.Config, rpc_max int, peer_max int) (*Server, error) { +func NewServer(ctx context.Context, logger Logger, ctl_addrs []string, rpc_addrs []string, pxy_addrs []string, wpx_addrs []string, ctl_prefix string, ctltlscfg *tls.Config, rpctlscfg *tls.Config, pxytlscfg *tls.Config, wpxtlscfg *tls.Config, rpc_max int, peer_max int) (*Server, error) { var s Server var l *net.TCPListener var rpcaddr *net.TCPAddr @@ -938,6 +941,7 @@ func NewServer(ctx context.Context, logger Logger, ctl_addrs []string, rpc_addrs s.ctltlscfg = ctltlscfg s.rpctlscfg = rpctlscfg s.pxytlscfg = pxytlscfg + s.wpxtlscfg = wpxtlscfg s.ext_svcs = make([]Service, 0, 1) s.pts_limit = peer_max s.cts_limit = rpc_max @@ -995,7 +999,7 @@ func NewServer(ctx context.Context, logger Logger, ctl_addrs []string, rpc_addrs // --------------------------------------------------------- s.pxy_ws = &server_proxy_ssh_ws{s: &s} - s.pxy_mux = http.NewServeMux() // TODO: make /_init configurable... + s.pxy_mux = http.NewServeMux() // TODO: make /_init,_ssh,_ssh_ws,_http configurable... s.pxy_mux.Handle("/_ssh-ws/{conn_id}/{route_id}", websocket.Handler(func(ws *websocket.Conn) { s.pxy_ws.ServeWebsocket(ws) })) s.pxy_mux.Handle("/_ssh/server-conns/{conn_id}/routes/{route_id}", &server_ctl_server_conns_id_routes_id{s: &s}) @@ -1005,9 +1009,6 @@ func NewServer(ctx context.Context, logger Logger, ctl_addrs []string, rpc_addrs s.pxy_mux.Handle("/_ssh/xterm.css", &server_proxy_xterm_file{s: &s, file: "xterm.css"}) s.pxy_mux.Handle("/_ssh/", &server_proxy_xterm_file{s: &s, file: "_forbidden"}) - //cwd, _ = os.Getwd() // TODO: - //s.pxy_mux.Handle(s.ctl_prefix + "/ui/", http.StripPrefix(s.ctl_prefix, http.FileServer(http.Dir(cwd)))) // TODO: proper directory. it must not use the current working directory... - s.pxy_mux.Handle("/_http/{conn_id}/{route_id}/{trailer...}", &server_proxy_http_main{s: &s, prefix: "/_http"}) s.pxy_mux.Handle("/_init/{conn_id}/{route_id}/{trailer...}", &server_proxy_http_init{s: &s, prefix: "/_init"}) s.pxy_mux.Handle("/", &server_proxy_http_main{s: &s, prefix: ""}) @@ -1025,6 +1026,28 @@ func NewServer(ctx context.Context, logger Logger, ctl_addrs []string, rpc_addrs // TODO: more settings } } + + // --------------------------------------------------------- + + s.wpx_mux = http.NewServeMux() + s.wpx_mux.Handle("/{port_id}/{trailer...}", &server_proxy_http_main{s: &s, prefix: PORT_ID_MARKER}) + s.wpx_mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) { + w.WriteHeader(http.StatusForbidden) + }) + + s.wpx_addr = make([]string, len(wpx_addrs)) + s.wpx = make([]*http.Server, len(wpx_addrs)) + copy(s.wpx_addr, wpx_addrs) + + for i = 0; i < len(wpx_addrs); i++ { + s.wpx[i] = &http.Server{ + Addr: wpx_addrs[i], + Handler: s.wpx_mux, + TLSConfig: s.wpxtlscfg, + ErrorLog: hs_log, + // TODO: more settings + } + } // --------------------------------------------------------- s.stats.conns.Store(0) @@ -1191,6 +1214,49 @@ func (s *Server) RunPxyTask(wg *sync.WaitGroup) { l_wg.Wait() } +func (s *Server) RunWpxTask(wg *sync.WaitGroup) { + var err error + var wpx *http.Server + var idx int + var l_wg sync.WaitGroup + + defer wg.Done() + + for idx, wpx = range s.wpx { + l_wg.Add(1) + go func(i int, cs *http.Server) { + var l net.Listener + + s.log.Write("", LOG_INFO, "Wpx channel[%d] started on %s", i, s.wpx_addr[i]) + + if s.stop_req.Load() == false { + l, err = net.Listen(tcp_addr_str_class(cs.Addr), cs.Addr) + if err == nil { + if s.stop_req.Load() == false { + if s.wpxtlscfg == nil { // TODO: change this + err = cs.Serve(l) + } else { + err = cs.ServeTLS(l, "", "") // s.wpxtlscfg must provide a certificate and a key + } + } else { + err = fmt.Errorf("stop requested") + } + l.Close() + } + } else { + err = fmt.Errorf("stop requested") + } + if errors.Is(err, http.ErrServerClosed) { + s.log.Write("", LOG_INFO, "Wpx channel[%d] ended", i) + } else { + s.log.Write("", LOG_ERROR, "Wpx channel[%d] error - %s", i, err.Error()) + } + l_wg.Done() + }(idx, wpx) + } + l_wg.Wait() +} + func (s *Server) ReqStop() { if s.stop_req.CompareAndSwap(false, true) { var l *net.TCPListener @@ -1210,16 +1276,19 @@ func (s *Server) ReqStop() { hs.Shutdown(s.ctx) // to break s.pxy.Serve() } + for _, hs = range s.wpx { + hs.Shutdown(s.ctx) // to break s.wpx.Serve() + } + //s.rpc_svr.GracefulStop() //s.rpc_svr.Stop() for _, l = range s.rpc { l.Close() } + // request to stop connections from/to peer held in the cts structure s.cts_mtx.Lock() - for _, cts = range s.cts_map { - cts.ReqStop() // request to stop connections from/to peer held in the cts structure - } + for _, cts = range s.cts_map { cts.ReqStop() } s.cts_mtx.Unlock() s.stop_chan <- true @@ -1279,13 +1348,9 @@ func (s *Server) AddNewServerConn(remote_addr *net.Addr, local_addr *net.Addr, p func (s *Server) ReqStopAllServerConns() { var cts *ServerConn - s.cts_mtx.Lock() - defer s.cts_mtx.Unlock() - - for _, cts = range s.cts_map { - cts.ReqStop() - } + for _, cts = range s.cts_map { cts.ReqStop() } + s.cts_mtx.Unlock() } func (s *Server) RemoveServerConn(cts *ServerConn) error { @@ -1393,7 +1458,7 @@ func (s *Server) FindServerRouteByIdStr(conn_id string, route_id string) (*Serve var r *ServerRoute var err error - if route_id == "_" { + if route_id == PORT_ID_MARKER { var port_nid uint64 port_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(PortId(0)) * 8)) @@ -1455,6 +1520,11 @@ func (s *Server) StartPxyService() { go s.RunPxyTask(&s.wg) } +func (s *Server) StartWpxService() { + s.wg.Add(1) + go s.RunWpxTask(&s.wg) +} + func (s *Server) StopServices() { var ext_svc Service s.ReqStop() diff --git a/system.go b/system.go index 5f624fd..420c92d 100644 --- a/system.go +++ b/system.go @@ -24,7 +24,7 @@ func monotonic_time() uint64 { var r uintptr var sts syscall.Timespec r, _, _/*errno*/ = syscall.Syscall(syscall.SYS_CLOCK_GETTIME, unix.CLOCK_MONOTONIC, uintptr(unsafe.Pointer(&sts)), 0) - if r == ^uintptr(0) { return uint64(n) } // may be negative cast to unsigned. no other fall-back + if r == ^uintptr(0) { return uint64(n) } // may be a negative number cast to unsigned. no other fall-back return uint64(sts.Nano()) }