From e177b7dc6b99c410d62a7b80c155738375112ab4 Mon Sep 17 00:00:00 2001 From: hyung-hwan Date: Tue, 17 Dec 2024 00:32:22 +0900 Subject: [PATCH] simplified some code of parsing connection ids and route ids --- server-ctl.go | 87 ++---------------- server-proxy.go | 235 ++++++++++++++++++++---------------------------- 2 files changed, 107 insertions(+), 215 deletions(-) diff --git a/server-ctl.go b/server-ctl.go index d78b4cc..cf97dec 100644 --- a/server-ctl.go +++ b/server-ctl.go @@ -3,8 +3,6 @@ package hodu import "encoding/json" import "net/http" import "runtime" -import "strconv" -import "unsafe" type json_out_server_conn struct { Id ConnId `json:"id"` @@ -142,7 +140,6 @@ func (ctl *server_ctl_server_conns_id) ServeHTTP(w http.ResponseWriter, req *htt var err error var je *json.Encoder var conn_id string - var conn_nid uint64 var cts *ServerConn defer func() { @@ -154,18 +151,10 @@ func (ctl *server_ctl_server_conns_id) ServeHTTP(w http.ResponseWriter, req *htt je = json.NewEncoder(w) conn_id = req.PathValue("conn_id") - - conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(ConnId(0)) * 8)) + cts, err = get_server_conn(s, conn_id) if err != nil { status_code = http.StatusBadRequest; w.WriteHeader(status_code) - if err = je.Encode(json_errmsg{Text: "wrong connection id - " + conn_id}); err != nil { goto oops } - goto done - } - - cts = s.FindServerConnById(ConnId(conn_nid)) - if cts == nil { - status_code = http.StatusNotFound; w.WriteHeader(status_code) - if err = je.Encode(json_errmsg{Text: "non-existent connection id - " + conn_id}); err != nil { goto oops } + if err = je.Encode(json_errmsg{Text: err.Error()}); err != nil { goto oops } goto done } @@ -223,7 +212,6 @@ func (ctl *server_ctl_server_conns_id_routes) ServeHTTP(w http.ResponseWriter, r var status_code int var err error var conn_id string - var conn_nid uint64 var je *json.Encoder var cts *ServerConn @@ -236,18 +224,10 @@ func (ctl *server_ctl_server_conns_id_routes) ServeHTTP(w http.ResponseWriter, r je = json.NewEncoder(w) conn_id = req.PathValue("conn_id") - - conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(ConnId(0)) * 8)) + cts, err = get_server_conn(s, conn_id) if err != nil { status_code = http.StatusBadRequest; w.WriteHeader(status_code) - if err = je.Encode(json_errmsg{Text: "wrong connection id - " + conn_id}); err != nil { goto oops } - goto done - } - - cts = s.FindServerConnById(ConnId(conn_nid)) - if cts == nil { - status_code = http.StatusNotFound; w.WriteHeader(status_code) - if err = je.Encode(json_errmsg{Text: "non-existent connection id - " + conn_id}); err != nil { goto oops } + if err = je.Encode(json_errmsg{Text: err.Error()}); err != nil { goto oops } goto done } @@ -296,12 +276,8 @@ oops: func (ctl *server_ctl_server_conns_id_routes_id) ServeHTTP(w http.ResponseWriter, req *http.Request) { var s *Server var status_code int - var port_id string var conn_id string var route_id string - var port_nid uint64 - var conn_nid uint64 - var route_nid uint64 var je *json.Encoder var r *ServerRoute var err error @@ -316,54 +292,10 @@ func (ctl *server_ctl_server_conns_id_routes_id) ServeHTTP(w http.ResponseWriter conn_id = req.PathValue("conn_id") route_id = req.PathValue("route_id") - if route_id == "_" { - port_id = conn_id - route_id = "" - } - - if port_id != "" { - // this condition is for ssh proxy server. - port_nid, err = strconv.ParseUint(port_id, 10, int(unsafe.Sizeof(PortId(0)) * 8)) - if err != nil { - status_code = http.StatusBadRequest; w.WriteHeader(status_code) - if err = je.Encode(json_errmsg{Text: "wrong port id - " + port_id}); err != nil { goto oops } - goto done - } - - r = s.FindServerRouteByPortId(PortId(port_nid)) - if r == nil { - status_code = http.StatusNotFound; w.WriteHeader(status_code) - if err = je.Encode(json_errmsg{Text: "non-existent port id - " + port_id}); err != nil { goto oops } - } - } else { - var cts *ServerConn - - conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(ConnId(0)) * 8)) - if err != nil { - status_code = http.StatusBadRequest; w.WriteHeader(status_code) - if err = je.Encode(json_errmsg{Text: "wrong connection id - " + conn_id}); err != nil { goto oops } - goto done - } - route_nid, err = strconv.ParseUint(route_id, 10, int(unsafe.Sizeof(RouteId(0)) * 8)) - if err != nil { - status_code = http.StatusBadRequest; w.WriteHeader(status_code) - if err = je.Encode(json_errmsg{Text: "wrong route id - " + route_id}); err != nil { goto oops } - goto done - } - - cts = s.FindServerConnById(ConnId(conn_nid)) - if cts == nil { - status_code = http.StatusNotFound; w.WriteHeader(status_code) - if err = je.Encode(json_errmsg{Text: "non-existent connection id - " + conn_id}); err != nil { goto oops } - goto done - } - - r = cts.FindServerRouteById(RouteId(route_nid)) - if r == nil { - status_code = http.StatusNotFound; w.WriteHeader(status_code) - if err = je.Encode(json_errmsg{Text: "non-existent route id - " + conn_id}); err != nil { goto oops } - goto done - } + r, err = get_server_route(s, conn_id, route_id) + if err != nil { + status_code = http.StatusNotFound; w.WriteHeader(status_code) + if err = je.Encode(json_errmsg{Text: err.Error()}); err != nil { goto oops } } switch req.Method { @@ -387,8 +319,7 @@ func (ctl *server_ctl_server_conns_id_routes_id) ServeHTTP(w http.ResponseWriter status_code = http.StatusBadRequest; w.WriteHeader(status_code) } -done: - // TODO: need to handle x-forwarded-for and other stuff? this is not a real web service, though +//done: s.log.Write("", LOG_DEBUG, "[%s] %s %s %d", req.RemoteAddr, req.Method, req.URL.String(), status_code) // TODO: time taken return diff --git a/server-proxy.go b/server-proxy.go index aecd9f5..e3db606 100644 --- a/server-proxy.go +++ b/server-proxy.go @@ -58,7 +58,7 @@ var hopHeaders = []string{ "Upgrade", } -func copy_headers(src http.Header, dst http.Header) { +func copy_headers(dst http.Header, src http.Header) { var key string var val string var vals []string @@ -89,7 +89,7 @@ func mutate_proxy_req_headers(req *http.Request, newreq *http.Request, add_path var conn_addr net.Addr //newreq.Header = req.Header.Clone() - copy_headers(req.Header, newreq.Header) + copy_headers(newreq.Header, req.Header) delete_hop_by_hop_headers(newreq.Header) hdr = req.Header @@ -133,6 +133,74 @@ func mutate_proxy_req_headers(req *http.Request, newreq *http.Request, add_path } } } + +func parse_conn_route_id(conn_id string, route_id string) (ConnId, RouteId, error) { + var conn_nid uint64 + var route_nid uint64 + var err error + + conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(ConnId(0)) * 8)) + if err != nil { return ConnId(0), RouteId(0), fmt.Errorf("invalid connection id - %s", err.Error()) } + + route_nid, err = strconv.ParseUint(route_id, 10, int(unsafe.Sizeof(RouteId(0)) * 8)) + if err != nil { return ConnId(0), RouteId(0), fmt.Errorf("invalid route id - %s", err.Error()) } + + return ConnId(conn_nid), RouteId(route_nid), nil +} + +func get_server_conn(s *Server, conn_id string) (*ServerConn, error) { + var conn_nid uint64 + var cts *ServerConn + var err error + + conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(ConnId(0)) * 8)) + if err != nil { + return nil, fmt.Errorf("invalid connection id %s - %s", conn_id, err.Error()); + } + + cts = s.FindServerConnById(ConnId(conn_nid)) + if cts == nil { + return nil, fmt.Errorf("non-existent connection id %d", conn_nid) + } + + return cts, nil +} + +func get_server_route(s *Server, conn_id string, route_id string) (*ServerRoute, error) { + var r *ServerRoute + var err error + + if route_id == "_" { + var port_nid uint64 + + port_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(PortId(0)) * 8)) + if err != nil { + return nil, fmt.Errorf("invalid port id %s - %s", conn_id, err.Error()) + } + + r = s.FindServerRouteByPortId(PortId(port_nid)) + if r == nil { + return nil, fmt.Errorf("port(%d) not found", port_nid) + } + } else { + var conn_nid uint64 + var route_nid uint64 + + conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(ConnId(0)) * 8)) + if err != nil { return nil, fmt.Errorf("invalid connection id - %s", err.Error()) } + + route_nid, err = strconv.ParseUint(route_id, 10, int(unsafe.Sizeof(RouteId(0)) * 8)) + if err != nil { return nil, fmt.Errorf("invalid route id - %s", err.Error()) } + + r = s.FindServerRouteById(ConnId(conn_nid), RouteId(route_nid)) + if r == nil { + return nil, fmt.Errorf("route(%d,%d) not found", conn_nid, route_nid) + } + } + + return r, nil +} + // ------------------------------------ func (pxy *server_proxy_http_init) ServeHTTP(w http.ResponseWriter, req *http.Request) { @@ -140,9 +208,8 @@ func (pxy *server_proxy_http_init) ServeHTTP(w http.ResponseWriter, req *http.Re var r *ServerRoute var status_code int var conn_id string - var conn_nid uint64 var route_id string - var route_nid uint64 + var hdr http.Header var err error defer func() { @@ -155,28 +222,18 @@ func (pxy *server_proxy_http_init) ServeHTTP(w http.ResponseWriter, req *http.Re conn_id = req.PathValue("conn_id") route_id = req.PathValue("route_id") - conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(ConnId(0)) * 8)) + r, err = get_server_route(s, conn_id, route_id) if err != nil { - status_code = http.StatusBadRequest; w.WriteHeader(status_code) - goto oops - } - route_nid, err = strconv.ParseUint(route_id, 10, int(unsafe.Sizeof(RouteId(0)) * 8)) - if err != nil { - status_code = http.StatusBadRequest; w.WriteHeader(status_code) - goto oops - } - - r = s.FindServerRouteById(ConnId(conn_nid), RouteId(route_nid)) - if r == nil { status_code = http.StatusNotFound; w.WriteHeader(status_code) - goto done + goto oops } - w.Header().Add("Set-Cookie", fmt.Sprintf("%s=%d-%d; Path=/; HttpOnly", SERVER_PROXY_ID_COOKIE, conn_nid, route_nid)) // use the interpreted ids. - w.Header().Set("Location", strings.TrimPrefix(req.URL.Path, fmt.Sprintf("/_init/%s/%s", conn_id, route_id))) // use the orignal id srings + hdr = w.Header() + hdr.Add("Set-Cookie", fmt.Sprintf("%s=%s-%s; Path=/; HttpOnly", SERVER_PROXY_ID_COOKIE, r.cts.id, r.id)) // use numeric id + hdr.Set("Location", strings.TrimPrefix(req.URL.Path, fmt.Sprintf("/_init/%s/%s", conn_id, route_id))) // use the original ids as in the request status_code = http.StatusFound; w.WriteHeader(status_code) -done: +//done: s.log.Write("", LOG_INFO, "[%s] %s %s %d", req.RemoteAddr, req.Method, req.URL.String(), status_code) return @@ -195,13 +252,9 @@ func (pxy *server_proxy_http_main) ServeHTTP(w http.ResponseWriter, req *http.Re var s *Server var r *ServerRoute var status_code int - var mode *http.Cookie var id *http.Cookie - var ids []string var conn_id string var route_id string - var conn_nid uint64 - var route_nid uint64 var prefixed bool var client *http.Client var resp *http.Response @@ -233,6 +286,7 @@ func (pxy *server_proxy_http_main) ServeHTTP(w http.ResponseWriter, req *http.Re route_id = req.PathValue("route_id") if conn_id == "" && route_id == "" { // it's not via /_http/<>/<> + var ids []string id, err = req.Cookie(SERVER_PROXY_ID_COOKIE) if err != nil { @@ -252,37 +306,10 @@ func (pxy *server_proxy_http_main) ServeHTTP(w http.ResponseWriter, req *http.Re route_id = ids[1] } - if route_id == "_" { - var port_nid uint64 - - port_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(PortId(0)) * 8)) - if err != nil { - status_code = http.StatusBadRequest; w.WriteHeader(status_code) - goto oops - } - - r = s.FindServerRouteByPortId(PortId(port_nid)) - if r == nil { - status_code = http.StatusNotFound; w.WriteHeader(status_code) - goto done - } - } else { - conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(ConnId(0)) * 8)) - if err != nil { - status_code = http.StatusBadRequest; w.WriteHeader(status_code) - goto oops - } - route_nid, err = strconv.ParseUint(route_id, 10, int(unsafe.Sizeof(RouteId(0)) * 8)) - if err != nil { - status_code = http.StatusBadRequest; w.WriteHeader(status_code) - goto oops - } - - r = s.FindServerRouteById(ConnId(conn_nid), RouteId(route_nid)) - if r == nil { - status_code = http.StatusNotFound; w.WriteHeader(status_code) - goto done - } + r, err = get_server_route(s, conn_id, route_id) + if err != nil { + status_code = http.StatusNotFound; w.WriteHeader(status_code) + goto oops } addr = *r.svc_addr; @@ -323,7 +350,6 @@ func (pxy *server_proxy_http_main) ServeHTTP(w http.ResponseWriter, req *http.Re proxy_url_path = strings.TrimPrefix(proxy_url_path, fmt.Sprintf("%s/%s/%s", pxy.prefix, conn_id, route_id)) } - //proxy_url = fmt.Sprintf("%s://%s%s", proxy_proto, r.ptc_addr, req.URL.Path) proxy_url = &url.URL{ Scheme: proxy_proto, Host: r.ptc_addr, @@ -371,7 +397,7 @@ func (pxy *server_proxy_http_main) ServeHTTP(w http.ResponseWriter, req *http.Re } hdr = w.Header() - copy_headers(resp.Header, hdr) + copy_headers(hdr, resp.Header) delete_hop_by_hop_headers(hdr) /* loc = hdr.Get("Location") @@ -380,14 +406,16 @@ func (pxy *server_proxy_http_main) ServeHTTP(w http.ResponseWriter, req *http.Re hdr.Set("Location", xxx) }*/ - w.Header().Add("Set-Cookie", fmt.Sprintf("%s=%d-%d; Path=/; HttpOnly", SERVER_PROXY_ID_COOKIE, conn_nid, route_nid)) + if !prefixed { + hdr.Add("Set-Cookie", fmt.Sprintf("%s=%d-%d; Path=/; HttpOnly", SERVER_PROXY_ID_COOKIE, conn_id, route_id)) + } status_code = resp.StatusCode; w.WriteHeader(status_code) io.Copy(w, resp.Body) // TODO: handle trailers } -done: +//done: s.log.Write("", LOG_INFO, "[%s] %s %s %d", req.RemoteAddr, req.Method, req.URL.String(), status_code) return @@ -436,49 +464,18 @@ func (pxy *server_proxy_xterm_file) ServeHTTP(w http.ResponseWriter, req *http.R w.Write(xterm_css) case "xterm.html": var tmpl *template.Template - var port_id string var conn_id string var route_id string - var port_nid uint64 - var conn_nid uint64 - var route_nid uint64 - var r *ServerRoute + //var r *ServerRoute conn_id = req.PathValue("conn_id") route_id = req.PathValue("route_id") - if route_id == "_" { port_id = conn_id } - - if port_id != "" { - port_nid, err = strconv.ParseUint(port_id, 10, int(unsafe.Sizeof(PortId(0)) * 8)) - if err != nil { - status_code = http.StatusBadRequest; w.WriteHeader(status_code) - goto oops - } - - r = s.FindServerRouteByPortId(PortId(port_nid)) - if r == nil { - status_code = http.StatusNotFound; w.WriteHeader(status_code) - goto done - } - } else { - conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(ConnId(0)) * 8)) - if err != nil { - status_code = http.StatusBadRequest; w.WriteHeader(status_code) - goto oops - } - route_nid, err = strconv.ParseUint(route_id, 10, int(unsafe.Sizeof(RouteId(0)) * 8)) - if err != nil { - status_code = http.StatusBadRequest; w.WriteHeader(status_code) - goto oops - } - - r = s.FindServerRouteById(ConnId(conn_nid), RouteId(route_nid)) - if r == nil { - status_code = http.StatusNotFound; w.WriteHeader(status_code) - goto done - } + _, err = get_server_route(s, conn_id, route_id) + if err != nil { + status_code = http.StatusNotFound; w.WriteHeader(status_code) + goto oops } - + tmpl = template.New("") _, err = tmpl.Parse(string(xterm_html)) if err != nil { @@ -500,7 +497,7 @@ func (pxy *server_proxy_xterm_file) ServeHTTP(w http.ResponseWriter, req *http.R } -done: +//done: s.log.Write("", LOG_INFO, "[%s] %s %s %d", req.RemoteAddr, req.Method, req.URL.String(), status_code) return @@ -602,12 +599,8 @@ oops: func (pxy *server_proxy_ssh_ws) ServeWebsocket(ws *websocket.Conn) { var s *Server var req *http.Request - var port_id string var conn_id string var route_id string - var port_nid uint64 - var conn_nid uint64 - var route_nid uint64 var r *ServerRoute var username string var password string @@ -632,42 +625,10 @@ func (pxy *server_proxy_ssh_ws) ServeWebsocket(ws *websocket.Conn) { conn_id = req.PathValue("conn_id") route_id = req.PathValue("route_id") - if route_id == "_" { port_id = conn_id } - - if port_id != "" { - port_nid, err = strconv.ParseUint(port_id, 10, int(unsafe.Sizeof(PortId(0)) * 8)) - if err != nil { - err = fmt.Errorf("invalid port id - %s", port_id) - pxy.send_ws_data(ws, "error", err.Error()) - goto done - } - - r = s.FindServerRouteByPortId(PortId(port_nid)) - if r == nil { - err = fmt.Errorf("port(%d) not found", conn_nid, route_nid) - pxy.send_ws_data(ws, "error", err.Error()) - goto done - } - } else { - conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(ConnId(0)) * 8)) - if err != nil { - err = fmt.Errorf("invalid connection id - %s", conn_id) - pxy.send_ws_data(ws, "error", err.Error()) - goto done - } - route_nid, err = strconv.ParseUint(route_id, 10, int(unsafe.Sizeof(RouteId(0)) * 8)) - if err != nil { - err = fmt.Errorf("invalid route id - %s", route_id) - pxy.send_ws_data(ws, "error", err.Error()) - goto done - } - - r = s.FindServerRouteById(ConnId(conn_nid), RouteId(route_nid)) - if r == nil { - err = fmt.Errorf("route(%d,%d) not found", conn_nid, route_nid) - pxy.send_ws_data(ws, "error", err.Error()) - goto done - } + r, err = get_server_route(s, conn_id, route_id) + if err != nil { + pxy.send_ws_data(ws, "error", err.Error()) + goto done } wg.Add(1)