diff --git a/server-ctl.go b/server-ctl.go index cf97dec..40b0852 100644 --- a/server-ctl.go +++ b/server-ctl.go @@ -151,7 +151,7 @@ func (ctl *server_ctl_server_conns_id) ServeHTTP(w http.ResponseWriter, req *htt je = json.NewEncoder(w) conn_id = req.PathValue("conn_id") - cts, err = get_server_conn(s, conn_id) + cts, err = s.FindServerConnByIdStr(conn_id) if err != nil { status_code = http.StatusBadRequest; w.WriteHeader(status_code) if err = je.Encode(json_errmsg{Text: err.Error()}); err != nil { goto oops } @@ -224,7 +224,7 @@ func (ctl *server_ctl_server_conns_id_routes) ServeHTTP(w http.ResponseWriter, r je = json.NewEncoder(w) conn_id = req.PathValue("conn_id") - cts, err = get_server_conn(s, conn_id) + cts, err = s.FindServerConnByIdStr(conn_id) if err != nil { status_code = http.StatusBadRequest; w.WriteHeader(status_code) if err = je.Encode(json_errmsg{Text: err.Error()}); err != nil { goto oops } @@ -292,7 +292,7 @@ func (ctl *server_ctl_server_conns_id_routes_id) ServeHTTP(w http.ResponseWriter conn_id = req.PathValue("conn_id") route_id = req.PathValue("route_id") - r, err = get_server_route(s, conn_id, route_id) + r, err = s.FindServerRouteByIdStr(conn_id, route_id) if err != nil { status_code = http.StatusNotFound; w.WriteHeader(status_code) if err = je.Encode(json_errmsg{Text: err.Error()}); err != nil { goto oops } diff --git a/server-proxy.go b/server-proxy.go index e3db606..f7f7991 100644 --- a/server-proxy.go +++ b/server-proxy.go @@ -14,7 +14,6 @@ import "strings" import "sync" import "text/template" import "time" -import "unsafe" import "golang.org/x/crypto/ssh" import "golang.org/x/net/http/httpguts" @@ -134,73 +133,6 @@ func mutate_proxy_req_headers(req *http.Request, newreq *http.Request, add_path } } -func parse_conn_route_id(conn_id string, route_id string) (ConnId, RouteId, error) { - var conn_nid uint64 - var route_nid uint64 - var err error - - conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(ConnId(0)) * 8)) - if err != nil { return ConnId(0), RouteId(0), fmt.Errorf("invalid connection id - %s", err.Error()) } - - route_nid, err = strconv.ParseUint(route_id, 10, int(unsafe.Sizeof(RouteId(0)) * 8)) - if err != nil { return ConnId(0), RouteId(0), fmt.Errorf("invalid route id - %s", err.Error()) } - - return ConnId(conn_nid), RouteId(route_nid), nil -} - -func get_server_conn(s *Server, conn_id string) (*ServerConn, error) { - var conn_nid uint64 - var cts *ServerConn - var err error - - conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(ConnId(0)) * 8)) - if err != nil { - return nil, fmt.Errorf("invalid connection id %s - %s", conn_id, err.Error()); - } - - cts = s.FindServerConnById(ConnId(conn_nid)) - if cts == nil { - return nil, fmt.Errorf("non-existent connection id %d", conn_nid) - } - - return cts, nil -} - -func get_server_route(s *Server, conn_id string, route_id string) (*ServerRoute, error) { - var r *ServerRoute - var err error - - if route_id == "_" { - var port_nid uint64 - - port_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(PortId(0)) * 8)) - if err != nil { - return nil, fmt.Errorf("invalid port id %s - %s", conn_id, err.Error()) - } - - r = s.FindServerRouteByPortId(PortId(port_nid)) - if r == nil { - return nil, fmt.Errorf("port(%d) not found", port_nid) - } - } else { - var conn_nid uint64 - var route_nid uint64 - - conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(ConnId(0)) * 8)) - if err != nil { return nil, fmt.Errorf("invalid connection id - %s", err.Error()) } - - route_nid, err = strconv.ParseUint(route_id, 10, int(unsafe.Sizeof(RouteId(0)) * 8)) - if err != nil { return nil, fmt.Errorf("invalid route id - %s", err.Error()) } - - r = s.FindServerRouteById(ConnId(conn_nid), RouteId(route_nid)) - if r == nil { - return nil, fmt.Errorf("route(%d,%d) not found", conn_nid, route_nid) - } - } - - return r, nil -} - // ------------------------------------ func (pxy *server_proxy_http_init) ServeHTTP(w http.ResponseWriter, req *http.Request) { @@ -222,7 +154,7 @@ func (pxy *server_proxy_http_init) ServeHTTP(w http.ResponseWriter, req *http.Re conn_id = req.PathValue("conn_id") route_id = req.PathValue("route_id") - r, err = get_server_route(s, conn_id, route_id) + r, err = s.FindServerRouteByIdStr(conn_id, route_id) if err != nil { status_code = http.StatusNotFound; w.WriteHeader(status_code) goto oops @@ -306,7 +238,7 @@ func (pxy *server_proxy_http_main) ServeHTTP(w http.ResponseWriter, req *http.Re route_id = ids[1] } - r, err = get_server_route(s, conn_id, route_id) + r, err = s.FindServerRouteByIdStr(conn_id, route_id) if err != nil { status_code = http.StatusNotFound; w.WriteHeader(status_code) goto oops @@ -410,8 +342,11 @@ func (pxy *server_proxy_http_main) ServeHTTP(w http.ResponseWriter, req *http.Re hdr.Add("Set-Cookie", fmt.Sprintf("%s=%d-%d; Path=/; HttpOnly", SERVER_PROXY_ID_COOKIE, conn_id, route_id)) } status_code = resp.StatusCode; w.WriteHeader(status_code) + + // TODO" if prefixed { append prefix removed... io.Copy(w, resp.Body) + // TODO: handle trailers } @@ -470,7 +405,7 @@ func (pxy *server_proxy_xterm_file) ServeHTTP(w http.ResponseWriter, req *http.R conn_id = req.PathValue("conn_id") route_id = req.PathValue("route_id") - _, err = get_server_route(s, conn_id, route_id) + _, err = s.FindServerRouteByIdStr(conn_id, route_id) if err != nil { status_code = http.StatusNotFound; w.WriteHeader(status_code) goto oops @@ -625,7 +560,7 @@ func (pxy *server_proxy_ssh_ws) ServeWebsocket(ws *websocket.Conn) { conn_id = req.PathValue("conn_id") route_id = req.PathValue("route_id") - r, err = get_server_route(s, conn_id, route_id) + r, err = s.FindServerRouteByIdStr(conn_id, route_id) if err != nil { pxy.send_ws_data(ws, "error", err.Error()) goto done diff --git a/server.go b/server.go index f061cad..a981b2a 100644 --- a/server.go +++ b/server.go @@ -9,8 +9,10 @@ import "log" import "net" import "net/http" import "net/netip" +import "strconv" import "sync" import "sync/atomic" +import "unsafe" import "golang.org/x/net/websocket" import "google.golang.org/grpc" @@ -1383,6 +1385,58 @@ func (s *Server) FindServerRouteByPortId(port_id PortId) *ServerRoute { return s.FindServerRouteById(cri.conn_id, cri.route_id) } +func (s *Server) FindServerRouteByIdStr(conn_id string, route_id string) (*ServerRoute, error) { + var r *ServerRoute + var err error + + if route_id == "_" { + var port_nid uint64 + + port_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(PortId(0)) * 8)) + if err != nil { + return nil, fmt.Errorf("invalid port id %s - %s", conn_id, err.Error()) + } + + r = s.FindServerRouteByPortId(PortId(port_nid)) + if r == nil { + return nil, fmt.Errorf("port(%d) not found", port_nid) + } + } else { + var conn_nid uint64 + var route_nid uint64 + + conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(ConnId(0)) * 8)) + if err != nil { return nil, fmt.Errorf("invalid connection id - %s", err.Error()) } + + route_nid, err = strconv.ParseUint(route_id, 10, int(unsafe.Sizeof(RouteId(0)) * 8)) + if err != nil { return nil, fmt.Errorf("invalid route id - %s", err.Error()) } + + r = s.FindServerRouteById(ConnId(conn_nid), RouteId(route_nid)) + if r == nil { + return nil, fmt.Errorf("route(%d,%d) not found", conn_nid, route_nid) + } + } + + return r, nil +} + +func (s *Server) FindServerConnByIdStr(conn_id string) (*ServerConn, error) { + var conn_nid uint64 + var cts *ServerConn + var err error + + conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(ConnId(0)) * 8)) + if err != nil { + return nil, fmt.Errorf("invalid connection id %s - %s", conn_id, err.Error()); + } + + cts = s.FindServerConnById(ConnId(conn_nid)) + if cts == nil { + return nil, fmt.Errorf("non-existent connection id %d", conn_nid) + } + + return cts, nil +} func (s *Server) StartService(cfg interface{}) { s.wg.Add(1)