diff --git a/Makefile b/Makefile index 2cdee92..680eed9 100644 --- a/Makefile +++ b/Makefile @@ -1,3 +1,10 @@ +# make +# make GOARCH=386 +# make GOARCH=amd64 +# make GOOS=linux GOARCH=mips +# +# 'go tool dist list' for available os and architextures + NAME=hodu VERSION=1.0.0 diff --git a/client-ctl.go b/client-ctl.go index d7bc32e..7186a13 100644 --- a/client-ctl.go +++ b/client-ctl.go @@ -248,7 +248,7 @@ func (ctl *client_ctl_client_conns_id) ServeHTTP(w http.ResponseWriter, req *htt conn_id = req.PathValue("conn_id") - conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(conn_nid) * 8)) + conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(ConnId(0)) * 8)) if err != nil { status_code = http.StatusBadRequest; w.WriteHeader(status_code) if err = je.Encode(json_errmsg{Text: "wrong connection id - " + conn_id}); err != nil { goto oops } @@ -333,7 +333,7 @@ func (ctl *client_ctl_client_conns_id_routes) ServeHTTP(w http.ResponseWriter, r conn_id = req.PathValue("conn_id") - conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(conn_nid) * 8)) + conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(ConnId(0)) * 8)) if err != nil { status_code = http.StatusBadRequest; w.WriteHeader(status_code) if err = je.Encode(json_errmsg{Text: "wrong connection id - " + conn_id }); err != nil { goto oops } @@ -456,13 +456,13 @@ func (ctl *client_ctl_client_conns_id_routes_id) ServeHTTP(w http.ResponseWriter conn_id = req.PathValue("conn_id") route_id = req.PathValue("route_id") - conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(conn_nid) * 8)) + conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(ConnId(0)) * 8)) if err != nil { status_code = http.StatusBadRequest; w.WriteHeader(status_code) if err = je.Encode(json_errmsg{Text: "wrong connection id - " + conn_id}); err != nil { goto oops } goto done } - route_nid, err = strconv.ParseUint(route_id, 10, int(unsafe.Sizeof(route_nid) * 8)) + route_nid, err = strconv.ParseUint(route_id, 10, int(unsafe.Sizeof(RouteId(0)) * 8)) if err != nil { status_code = http.StatusBadRequest; w.WriteHeader(status_code) if err = je.Encode(json_errmsg{Text: "wrong route id - " + route_id}); err != nil { goto oops } @@ -538,13 +538,13 @@ func (ctl *client_ctl_client_conns_id_routes_id_peers) ServeHTTP(w http.Response conn_id = req.PathValue("conn_id") route_id = req.PathValue("route_id") - conn_nid, err = strconv.ParseUint(conn_id, 10, 32) + conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(ConnId(0)) * 8)) if err != nil { status_code = http.StatusBadRequest; w.WriteHeader(status_code) if err = je.Encode(json_errmsg{Text: "wrong connection id - " + conn_id}); err != nil { goto oops } goto done } - route_nid, err = strconv.ParseUint(route_id, 10, int(unsafe.Sizeof(route_nid) * 8)) + route_nid, err = strconv.ParseUint(route_id, 10, int(unsafe.Sizeof(RouteId(0)) * 8)) if err != nil { status_code = http.StatusBadRequest; w.WriteHeader(status_code) if err = je.Encode(json_errmsg{Text: "wrong route id - " + route_id}); err != nil { goto oops } @@ -622,19 +622,19 @@ func (ctl *client_ctl_client_conns_id_routes_id_peers_id) ServeHTTP(w http.Respo route_id = req.PathValue("route_id") peer_id = req.PathValue("peer_id") - conn_nid, err = strconv.ParseUint(conn_id, 10, 32) + conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(ConnId(0)) * 8)) if err != nil { status_code = http.StatusBadRequest; w.WriteHeader(status_code) if err = je.Encode(json_errmsg{Text: "wrong connection id - " + conn_id}); err != nil { goto oops } goto done } - route_nid, err = strconv.ParseUint(route_id, 10, int(unsafe.Sizeof(route_nid) * 8)) + route_nid, err = strconv.ParseUint(route_id, 10, int(unsafe.Sizeof(RouteId(0)) * 8)) if err != nil { status_code = http.StatusBadRequest; w.WriteHeader(status_code) if err = je.Encode(json_errmsg{Text: "wrong route id - " + route_id}); err != nil { goto oops } goto done } - peer_nid, err = strconv.ParseUint(peer_id, 10, 32) + peer_nid, err = strconv.ParseUint(peer_id, 10, int(unsafe.Sizeof(ConnId(0)) * 8)) if err != nil { status_code = http.StatusBadRequest; w.WriteHeader(status_code) if err = je.Encode(json_errmsg{Text: "wrong peer id - " + peer_id}); err != nil { goto oops } diff --git a/packet.go b/packet.go index 9111a94..f962725 100644 --- a/packet.go +++ b/packet.go @@ -1,10 +1,15 @@ package hodu -type ConnId uint64 -type RouteId uint32 // keep this in sync with the type of RouteId in hodu.proto -type PeerId uint32 // keep this in sync with the type of RouteId in hodu.proto +type ConnId uint64 +type RouteId uint32 // keep this in sync with the type of RouteId in hodu.proto +type PeerId uint32 // keep this in sync with the type of RouteId in hodu.proto type RouteOption uint32 +type ConnRouteId struct { + conn_id ConnId + route_id RouteId +} + func MakeRouteStartPacket(route_id RouteId, proto RouteOption, ptc_addr string, ptc_name string, svc_addr string, svc_net string) *Packet { return &Packet{ Kind: PACKET_KIND_ROUTE_START, diff --git a/server-ctl.go b/server-ctl.go index 0b8ab92..d78b4cc 100644 --- a/server-ctl.go +++ b/server-ctl.go @@ -155,7 +155,7 @@ func (ctl *server_ctl_server_conns_id) ServeHTTP(w http.ResponseWriter, req *htt conn_id = req.PathValue("conn_id") - conn_nid, err = strconv.ParseUint(conn_id, 10, 32) + conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(ConnId(0)) * 8)) if err != nil { status_code = http.StatusBadRequest; w.WriteHeader(status_code) if err = je.Encode(json_errmsg{Text: "wrong connection id - " + conn_id}); err != nil { goto oops } @@ -237,7 +237,7 @@ func (ctl *server_ctl_server_conns_id_routes) ServeHTTP(w http.ResponseWriter, r conn_id = req.PathValue("conn_id") - conn_nid, err = strconv.ParseUint(conn_id, 10, 32) + conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(ConnId(0)) * 8)) if err != nil { status_code = http.StatusBadRequest; w.WriteHeader(status_code) if err = je.Encode(json_errmsg{Text: "wrong connection id - " + conn_id}); err != nil { goto oops } @@ -296,14 +296,15 @@ oops: func (ctl *server_ctl_server_conns_id_routes_id) ServeHTTP(w http.ResponseWriter, req *http.Request) { var s *Server var status_code int - var err error + var port_id string var conn_id string - var conn_nid uint64 var route_id string + var port_nid uint64 + var conn_nid uint64 var route_nid uint64 var je *json.Encoder - var cts *ServerConn var r *ServerRoute + var err error defer func() { var err interface{} = recover() @@ -315,32 +316,54 @@ func (ctl *server_ctl_server_conns_id_routes_id) ServeHTTP(w http.ResponseWriter conn_id = req.PathValue("conn_id") route_id = req.PathValue("route_id") - - conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(conn_nid) * 8)) - if err != nil { - status_code = http.StatusBadRequest; w.WriteHeader(status_code) - if err = je.Encode(json_errmsg{Text: "wrong connection id - " + conn_id}); err != nil { goto oops } - goto done - } - route_nid, err = strconv.ParseUint(route_id, 10, int(unsafe.Sizeof(route_nid) * 8)) - if err != nil { - status_code = http.StatusBadRequest; w.WriteHeader(status_code) - if err = je.Encode(json_errmsg{Text: "wrong route id - " + route_id}); err != nil { goto oops } - goto done + if route_id == "_" { + port_id = conn_id + route_id = "" } - cts = s.FindServerConnById(ConnId(conn_nid)) - if cts == nil { - status_code = http.StatusNotFound; w.WriteHeader(status_code) - if err = je.Encode(json_errmsg{Text: "non-existent connection id - " + conn_id}); err != nil { goto oops } - goto done - } + if port_id != "" { + // this condition is for ssh proxy server. + port_nid, err = strconv.ParseUint(port_id, 10, int(unsafe.Sizeof(PortId(0)) * 8)) + if err != nil { + status_code = http.StatusBadRequest; w.WriteHeader(status_code) + if err = je.Encode(json_errmsg{Text: "wrong port id - " + port_id}); err != nil { goto oops } + goto done + } - r = cts.FindServerRouteById(RouteId(route_nid)) - if r == nil { - status_code = http.StatusNotFound; w.WriteHeader(status_code) - if err = je.Encode(json_errmsg{Text: "non-existent route id - " + conn_id}); err != nil { goto oops } - goto done + r = s.FindServerRouteByPortId(PortId(port_nid)) + if r == nil { + status_code = http.StatusNotFound; w.WriteHeader(status_code) + if err = je.Encode(json_errmsg{Text: "non-existent port id - " + port_id}); err != nil { goto oops } + } + } else { + var cts *ServerConn + + conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(ConnId(0)) * 8)) + if err != nil { + status_code = http.StatusBadRequest; w.WriteHeader(status_code) + if err = je.Encode(json_errmsg{Text: "wrong connection id - " + conn_id}); err != nil { goto oops } + goto done + } + route_nid, err = strconv.ParseUint(route_id, 10, int(unsafe.Sizeof(RouteId(0)) * 8)) + if err != nil { + status_code = http.StatusBadRequest; w.WriteHeader(status_code) + if err = je.Encode(json_errmsg{Text: "wrong route id - " + route_id}); err != nil { goto oops } + goto done + } + + cts = s.FindServerConnById(ConnId(conn_nid)) + if cts == nil { + status_code = http.StatusNotFound; w.WriteHeader(status_code) + if err = je.Encode(json_errmsg{Text: "non-existent connection id - " + conn_id}); err != nil { goto oops } + goto done + } + + r = cts.FindServerRouteById(RouteId(route_nid)) + if r == nil { + status_code = http.StatusNotFound; w.WriteHeader(status_code) + if err = je.Encode(json_errmsg{Text: "non-existent route id - " + conn_id}); err != nil { goto oops } + goto done + } } switch req.Method { diff --git a/server-proxy.go b/server-proxy.go index e1c6d86..dc87bdf 100644 --- a/server-proxy.go +++ b/server-proxy.go @@ -81,7 +81,7 @@ func delete_hop_by_hop_headers(header http.Header) { } } -func mutate_proxy_req_headers(req *http.Request, newreq *http.Request) { +func mutate_proxy_req_headers(req *http.Request, newreq *http.Request, add_path bool) { var hdr http.Header var newhdr http.Header var remote_addr string @@ -128,6 +128,13 @@ func mutate_proxy_req_headers(req *http.Request, newreq *http.Request) { if !ok { newhdr.Set("X-Forwarded-Host", req.Host) } + + if add_path { + _, ok = newhdr["X-Forwarded-Path"] + if !ok { + newhdr.Set("X-Forwarded-Path", req.URL.RawPath) + } + } } // ------------------------------------ @@ -151,12 +158,12 @@ func (pxy *server_proxy_http_init) ServeHTTP(w http.ResponseWriter, req *http.Re conn_id = req.PathValue("conn_id") route_id = req.PathValue("route_id") - conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(conn_nid) * 8)) + conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(ConnId(0)) * 8)) if err != nil { status_code = http.StatusBadRequest; w.WriteHeader(status_code) goto oops } - route_nid, err = strconv.ParseUint(route_id, 10, int(unsafe.Sizeof(route_nid) * 8)) + route_nid, err = strconv.ParseUint(route_id, 10, int(unsafe.Sizeof(RouteId(0)) * 8)) if err != nil { status_code = http.StatusBadRequest; w.WriteHeader(status_code) goto oops @@ -240,12 +247,12 @@ func (pxy *server_proxy_http_main) ServeHTTP(w http.ResponseWriter, req *http.Re goto oops } - conn_nid, err = strconv.ParseUint(ids[0], 10, int(unsafe.Sizeof(conn_nid) * 8)) + conn_nid, err = strconv.ParseUint(ids[0], 10, int(unsafe.Sizeof(ConnId(0)) * 8)) if err != nil { status_code = http.StatusBadRequest; w.WriteHeader(status_code) goto oops } - route_nid, err = strconv.ParseUint(ids[1], 10, int(unsafe.Sizeof(route_nid) * 8)) + route_nid, err = strconv.ParseUint(ids[1], 10, int(unsafe.Sizeof(RouteId(0)) * 8)) if err != nil { status_code = http.StatusBadRequest; w.WriteHeader(status_code) goto oops @@ -309,7 +316,7 @@ func (pxy *server_proxy_http_main) ServeHTTP(w http.ResponseWriter, req *http.Re if httpguts.HeaderValuesContainsToken(req.Header["Connection"], "Upgrade") { req_upgrade_type = req.Header.Get("Upgrade") } - mutate_proxy_req_headers(req, proxy_req) + mutate_proxy_req_headers(req, proxy_req, false) if httpguts.HeaderValuesContainsToken(req.Header["Te"], "trailers") { proxy_req.Header.Set("Te", "trailers") @@ -374,50 +381,106 @@ type server_proxy_xterm_session_info struct { } func (pxy *server_proxy_xterm_file) ServeHTTP(w http.ResponseWriter, req *http.Request) { + var s *Server + var status_code int + var err error + defer func() { var err interface{} = recover() if err != nil { dump_call_frame_and_exit(pxy.s.log, req, err) } }() - + + s = pxy.s + // TODO: logging switch pxy.file { case "xterm.js": w.Header().Set("Content-Type", "text/javascript") - w.WriteHeader(http.StatusOK) + status_code = http.StatusOK; w.WriteHeader(status_code) w.Write(xterm_js) case "xterm-addon-fit.js": w.Header().Set("Content-Type", "text/javascript") - w.WriteHeader(http.StatusOK) + status_code = http.StatusOK; w.WriteHeader(status_code) w.Write(xterm_addon_fit_js) case "xterm.css": w.Header().Set("Content-Type", "text/css") - w.WriteHeader(http.StatusOK) + status_code = http.StatusOK; w.WriteHeader(status_code) w.Write(xterm_css) case "xterm.html": var tmpl *template.Template - var err error + var port_id string + var conn_id string + var route_id string + var port_nid uint64 + var conn_nid uint64 + var route_nid uint64 + var r *ServerRoute + conn_id = req.PathValue("conn_id") + route_id = req.PathValue("route_id") + if route_id == "_" { port_id = conn_id } + + if port_id != "" { + port_nid, err = strconv.ParseUint(port_id, 10, int(unsafe.Sizeof(PortId(0)) * 8)) + if err != nil { + status_code = http.StatusBadRequest; w.WriteHeader(status_code) + goto oops + } + + r = s.FindServerRouteByPortId(PortId(port_nid)) + if r == nil { + status_code = http.StatusNotFound; w.WriteHeader(status_code) + goto done + } + } else { + conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(ConnId(0)) * 8)) + if err != nil { + status_code = http.StatusBadRequest; w.WriteHeader(status_code) + goto oops + } + route_nid, err = strconv.ParseUint(route_id, 10, int(unsafe.Sizeof(RouteId(0)) * 8)) + if err != nil { + status_code = http.StatusBadRequest; w.WriteHeader(status_code) + goto oops + } + + r = s.FindServerRouteById(ConnId(conn_nid), RouteId(route_nid)) + if r == nil { + status_code = http.StatusNotFound; w.WriteHeader(status_code) + goto done + } + } + tmpl = template.New("") _, err = tmpl.Parse(string(xterm_html)) if err != nil { - w.WriteHeader(http.StatusInternalServerError) + status_code = http.StatusInternalServerError; w.WriteHeader(status_code) + goto oops } else { w.Header().Set("Content-Type", "text/html") w.WriteHeader(http.StatusOK) tmpl.Execute(w, &server_proxy_xterm_session_info{ - ConnId: req.PathValue("conn_id"), - RouteId: req.PathValue("route_id"), + ConnId: conn_id, + RouteId: route_id, }) } case "_forbidden": - w.WriteHeader(http.StatusForbidden) + status_code = http.StatusForbidden; w.WriteHeader(status_code) default: - w.WriteHeader(http.StatusNotFound) + status_code = http.StatusNotFound; w.WriteHeader(status_code) } -// TODO: logging.. + +done: + s.log.Write("", LOG_INFO, "[%s] %s %s %d", req.RemoteAddr, req.Method, req.URL.String(), status_code) + return + +oops: + s.log.Write("", LOG_ERROR, "[%s] %s %s %d - %s", req.RemoteAddr, req.Method, req.URL.String(), status_code, err.Error()) + return } + // ------------------------------------ type server_proxy_ssh_ws struct { @@ -511,9 +574,11 @@ oops: func (pxy *server_proxy_ssh_ws) ServeWebsocket(ws *websocket.Conn) { var s *Server var req *http.Request + var port_id string var conn_id string - var conn_nid uint64 var route_id string + var port_nid uint64 + var conn_nid uint64 var route_nid uint64 var r *ServerRoute var username string @@ -539,24 +604,42 @@ func (pxy *server_proxy_ssh_ws) ServeWebsocket(ws *websocket.Conn) { conn_id = req.PathValue("conn_id") route_id = req.PathValue("route_id") + if route_id == "_" { port_id = conn_id } - conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(conn_nid) * 8)) - if err != nil { - // TODO: - goto done - } - route_nid, err = strconv.ParseUint(route_id, 10, int(unsafe.Sizeof(route_nid) * 8)) - if err != nil { - // TODO: - goto done - } + if port_id != "" { + port_nid, err = strconv.ParseUint(port_id, 10, int(unsafe.Sizeof(PortId(0)) * 8)) + if err != nil { + err = fmt.Errorf("invalid port id - %s", port_id) + pxy.send_ws_data(ws, "error", err.Error()) + goto done + } - r = s.FindServerRouteById(ConnId(conn_nid), RouteId(route_nid)) - if r == nil { - // TODO: enhance logging. original request, conn_nid, route_nid - pxy.send_ws_data(ws, "error", fmt.Sprintf("route(%d,%d) not found", conn_nid, route_nid)) - s.log.Write("", LOG_ERROR, "No server route(%d,%d) found", conn_nid, route_nid) - goto done + r = s.FindServerRouteByPortId(PortId(port_nid)) + if r == nil { + err = fmt.Errorf("port(%d) not found", conn_nid, route_nid) + pxy.send_ws_data(ws, "error", err.Error()) + goto done + } + } else { + conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(ConnId(0)) * 8)) + if err != nil { + err = fmt.Errorf("invalid connection id - %s", conn_id) + pxy.send_ws_data(ws, "error", err.Error()) + goto done + } + route_nid, err = strconv.ParseUint(route_id, 10, int(unsafe.Sizeof(RouteId(0)) * 8)) + if err != nil { + err = fmt.Errorf("invalid route id - %s", route_id) + pxy.send_ws_data(ws, "error", err.Error()) + goto done + } + + r = s.FindServerRouteById(ConnId(conn_nid), RouteId(route_nid)) + if r == nil { + err = fmt.Errorf("route(%d,%d) not found", conn_nid, route_nid) + pxy.send_ws_data(ws, "error", err.Error()) + goto done + } } wg.Add(1) @@ -598,11 +681,8 @@ ws_recv_loop: for { var msg []byte err = websocket.Message.Receive(ws, &msg) - if err != nil { - // TODO: check if EOF - s.log.Write("", LOG_ERROR, "Failed to read from websocket - %s", err.Error()) - goto done - } + if err != nil { goto done } + if len(msg) > 0 { var ev json_ssh_ws_event err = json.Unmarshal(msg, &ev) @@ -669,9 +749,7 @@ ws_recv_loop: if sess != nil { err = pxy.send_ws_data(ws, "status", "closed") - if err != nil { - s.log.Write("", LOG_ERROR, "Failed to write closed event to websocket - %s", err.Error()) - } + if err != nil { goto done } } done: @@ -680,5 +758,9 @@ done: if sess != nil { sess.Close() } if c != nil { c.Close() } wg.Wait() - s.log.Write("", LOG_DEBUG, "[%s] %s %s - ended", req.RemoteAddr, req.Method, req.URL.String()) + if err != nil { + s.log.Write("", LOG_ERROR, "[%s] %s %s - %s", req.RemoteAddr, req.Method, req.URL.String(), err.Error()) + } else { + s.log.Write("", LOG_DEBUG, "[%s] %s %s - ended", req.RemoteAddr, req.Method, req.URL.String()) + } } diff --git a/server.go b/server.go index 4ae435d..71b9c7c 100644 --- a/server.go +++ b/server.go @@ -22,10 +22,13 @@ import "google.golang.org/grpc/stats" const PTS_LIMIT int = 16384 const CTS_LIMIT int = 16384 +type PortId uint16 + type ServerConnMapByAddr = map[net.Addr]*ServerConn type ServerConnMap = map[ConnId]*ServerConn type ServerRouteMap = map[RouteId]*ServerRoute type ServerPeerConnMap = map[PeerId]*ServerPeerConn +type ServerSvcPortMap = map[PortId]ConnRouteId type Server struct { ctx context.Context @@ -65,6 +68,9 @@ type Server struct { log Logger + svc_port_mtx sync.Mutex + svc_port_map ServerSvcPortMap + stats struct { conns atomic.Int64 routes atomic.Int64 @@ -316,6 +322,8 @@ func (cts *ServerConn) make_route_listener(id RouteId, option RouteOption, svc_r var l *net.TCPListener var svcaddr *net.TCPAddr var nw string + var prev_cri ConnRouteId + var ok bool var err error if svc_requested_addr != "" { @@ -354,7 +362,20 @@ func (cts *ServerConn) make_route_listener(id RouteId, option RouteOption, svc_r } svcaddr = l.Addr().(*net.TCPAddr) - cts.svr.log.Write(cts.sid, LOG_DEBUG, "Route(%d) listening on %s", id, svcaddr.String()) + + cts.svr.svc_port_mtx.Lock() + prev_cri, ok = cts.svr.svc_port_map[PortId(svcaddr.Port)] + if ok { + cts.svr.log.Write(cts.sid, LOG_ERROR, + "Route(%d,%d) on %s not unique by port number - existing route(%d,%d)", + cts.id, id, prev_cri.conn_id, prev_cri.route_id, svcaddr.String()) + l.Close() + return nil, nil, err + } + cts.svr.svc_port_map[PortId(svcaddr.Port)] = ConnRouteId{conn_id: cts.id, route_id: id} + cts.svr.svc_port_mtx.Unlock() + + cts.svr.log.Write(cts.sid, LOG_DEBUG, "Route(%d,%d) listening on %s", cts.id, id, svcaddr.String()) return l, svcaddr, nil } @@ -403,6 +424,10 @@ func (cts *ServerConn) RemoveServerRoute(route *ServerRoute) error { cts.svr.stats.routes.Add(-1) cts.route_mtx.Unlock() + cts.svr.svc_port_mtx.Lock() + delete(cts.svr.svc_port_map, PortId(route.svc_addr.Port)) + cts.svr.svc_port_mtx.Unlock() + r.ReqStop() return nil } @@ -421,6 +446,10 @@ func (cts *ServerConn) RemoveServerRouteById(route_id RouteId) (*ServerRoute, er cts.svr.stats.routes.Add(-1) cts.route_mtx.Unlock() + cts.svr.svc_port_mtx.Lock() + delete(cts.svr.svc_port_map, PortId(r.svc_addr.Port)) + cts.svr.svc_port_mtx.Unlock() + r.ReqStop() return r, nil } @@ -913,6 +942,7 @@ func NewServer(ctx context.Context, logger Logger, ctl_addrs []string, rpc_addrs s.cts_next_id = 0 s.cts_map = make(ServerConnMap) s.cts_map_by_addr = make(ServerConnMapByAddr) + s.svc_port_map = make(ServerSvcPortMap) s.stop_chan = make(chan bool, 8) s.stop_req.Store(false) @@ -965,9 +995,7 @@ func NewServer(ctx context.Context, logger Logger, ctl_addrs []string, rpc_addrs s.pxy_ws = &server_proxy_ssh_ws{s: &s} s.pxy_mux = http.NewServeMux() // TODO: make /_init configurable... s.pxy_mux.Handle("/_ssh-ws/{conn_id}/{route_id}", - websocket.Handler(func(ws *websocket.Conn) { - s.pxy_ws.ServeWebsocket(ws) - })) + websocket.Handler(func(ws *websocket.Conn) { s.pxy_ws.ServeWebsocket(ws) })) s.pxy_mux.Handle("/_ssh/server-conns/{conn_id}/routes/{route_id}", &server_ctl_server_conns_id_routes_id{s: &s}) s.pxy_mux.Handle("/_ssh/{conn_id}/{route_id}/", &server_proxy_xterm_file{s: &s, file: "xterm.html"}) s.pxy_mux.Handle("/_ssh/xterm.js", &server_proxy_xterm_file{s: &s, file: "xterm.js"}) @@ -1337,13 +1365,23 @@ func (s *Server) FindServerRouteById(id ConnId, route_id RouteId) *ServerRoute { defer s.cts_mtx.Unlock() cts, ok = s.cts_map[id] - if !ok { - return nil - } + if !ok { return nil } return cts.FindServerRouteById(route_id) } +func (s *Server) FindServerRouteByPortId(port_id PortId) *ServerRoute { + var cri ConnRouteId + var ok bool + + s.svc_port_mtx.Lock() + defer s.svc_port_mtx.Unlock() + + cri, ok = s.svc_port_map[port_id] + if !ok { return nil } + return s.FindServerRouteById(cri.conn_id, cri.route_id) +} + func (s *Server) StartService(cfg interface{}) { s.wg.Add(1) diff --git a/xterm.html b/xterm.html index 458a2e8..a8aa23e 100644 --- a/xterm.html +++ b/xterm.html @@ -90,6 +90,9 @@