From c1f7c5d4c2bf3fd7f841999f1b6fcf86e4c28fc5 Mon Sep 17 00:00:00 2001 From: hyung-hwan Date: Mon, 23 Dec 2024 01:27:47 +0900 Subject: [PATCH] enhancing http proxy code --- hodu.go | 18 +++ server-proxy.go | 355 +++++++++++++++++++++++++++++------------------- 2 files changed, 237 insertions(+), 136 deletions(-) diff --git a/hodu.go b/hodu.go index a0defa2..0c57319 100644 --- a/hodu.go +++ b/hodu.go @@ -1,5 +1,6 @@ package hodu +import "net" import "net/http" import "net/netip" import "os" @@ -112,3 +113,20 @@ func dump_call_frame_and_exit(log Logger, req *http.Request, err interface{}) { log.Write("", LOG_ERROR, "[%s] %s %s - %v\n%s", req.RemoteAddr, req.Method, req.URL.String(), err, string(buf)) os.Exit(99) // fatal error. treat panic() as a fatal runtime error } + +func svc_addr_to_dst_addr (svc_addr *net.TCPAddr) *net.TCPAddr { + var addr net.TCPAddr + + addr = *svc_addr + if addr.IP.To4() != nil { + if addr.IP.IsUnspecified() { + addr.IP = net.IPv4(127, 0, 0, 1) // net.IPv4loopback is not defined. so use net.IPv4() + } + } else { + if addr.IP.IsUnspecified() { + addr.IP = net.IPv6loopback // net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1} + } + } + + return &addr +} diff --git a/server-proxy.go b/server-proxy.go index fe99ee7..f6c4f0e 100644 --- a/server-proxy.go +++ b/server-proxy.go @@ -1,5 +1,6 @@ package hodu +import "bufio" import "context" import "crypto/tls" import _ "embed" @@ -101,9 +102,12 @@ func mutate_proxy_req_headers(req *http.Request, newreq *http.Request, path_pref newhdr.Set("Upgrade", hdr.Get("Upgrade")) } +/* if httpguts.HeaderValuesContainsToken(hdr["Te"], "trailers") { newhdr.Set("Te", "trailers") } +*/ + remote_addr, _, err = net.SplitHostPort(req.RemoteAddr) if err == nil { @@ -246,19 +250,128 @@ func (pxy *server_proxy_http_main) serve_websocket(w http.ResponseWriter, req *h }).ServeHTTP(w, req) } +func (pxy *server_proxy_http_main) get_route(req *http.Request) (*ServerRoute, string, string, string, error) { + var conn_id string + var route_id string + var r *ServerRoute + var path_prefix string + var err error + + conn_id = req.PathValue("conn_id") + route_id = req.PathValue("route_id") + if conn_id == "" && route_id == "" { + // it's not via /_http/<>/<>. + // get ids from the cookie. + var id *http.Cookie + var ids []string + + id, err = req.Cookie(SERVER_PROXY_ID_COOKIE) + if err != nil { + return nil, "", "", "", fmt.Errorf("%s cookie not found - %s", SERVER_PROXY_ID_COOKIE, err.Error()) + } + + ids = strings.Split(id.Value, "-") + if (len(ids) != 2) { + return nil, "", "", "", fmt.Errorf("invalid proxy id cookie value - %s", id.Value) + } + + conn_id = ids[0] + route_id = ids[1] + path_prefix = "" + } else { + path_prefix = fmt.Sprintf("%s/%s/%s", pxy.prefix, conn_id, route_id) + } + + r, err = pxy.s.FindServerRouteByIdStr(conn_id, route_id) + if err != nil { return nil, "", "", "", err } + + return r, path_prefix, conn_id, route_id, nil +} + +func (pxy *server_proxy_http_main) get_upgrade_type(hdr http.Header) string { + if httpguts.HeaderValuesContainsToken(hdr["Connection"], "Upgrade") { return hdr.Get("Upgrade") } + return "" +} + +func (pxy *server_proxy_http_main) serve_upgraded(w http.ResponseWriter, req *http.Request, proxy_res *http.Response) { + var err_chan chan error + var req_up_type string + var res_up_type string + var proxy_res_body io.ReadWriteCloser + var rc *http.ResponseController + var client_conn net.Conn + var buf_rw *bufio.ReadWriter + var ok bool + var err error + + req_up_type = pxy.get_upgrade_type(req.Header) + res_up_type = pxy.get_upgrade_type(proxy_res.Header) + if !strings.EqualFold(req_up_type, res_up_type) { + // TODO: error + return + } + + proxy_res_body, ok = proxy_res.Body.(io.ReadWriteCloser) + if !ok { + //p.getErrorHandler()(w, req, fmt.Errorf("internal error: 101 switching protocols response with non-writable body")) + return + } + defer proxy_res_body.Close() + + rc = http.NewResponseController(w) + client_conn, buf_rw, err = rc.Hijack() + if errors.Is(err, http.ErrNotSupported) { + //p.getErrorHandler()(w, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", w)) + return + } + + //if hijack_err != nil { + //p.getErrorHandler()(w, req, fmt.Errorf("Hijack failed on protocol switch: %v", hijackErr)) + // return + //} + defer client_conn.Close() + + copy_headers(w.Header(), proxy_res.Header) + + proxy_res.Header = w.Header() + proxy_res.Body = nil // so res.Write only writes the headers; we have res.Body in proxy_res_body above + err = proxy_res.Write(buf_rw) + if err != nil { + //p.getErrorHandler()(w, req, fmt.Errorf("response write: %v", err)) + return + } + err = buf_rw.Flush() + if err != nil { + //p.getErrorHandler()(rw, req, fmt.Errorf("response flush: %v", err)) + return + } + + err_chan = make(chan error, 2) + go func() { + var err error + _, err = io.Copy(client_conn, proxy_res_body) + err_chan <- err + }() + + go func() { + var err error + _, err = io.Copy(proxy_res_body, client_conn) + err_chan <- err + }() + <-err_chan +} + func (pxy *server_proxy_http_main) ServeHTTP(w http.ResponseWriter, req *http.Request) { var s *Server var r *ServerRoute var status_code int - var id *http.Cookie var conn_id string var route_id string var path_prefix string var client *http.Client var resp *http.Response - var tcp_conn *net.TCPConn var transport *http.Transport - var addr net.TCPAddr + var addr *net.TCPAddr var proxy_req *http.Request var proxy_url *url.URL var proxy_url_path string @@ -277,150 +390,125 @@ func (pxy *server_proxy_http_main) ServeHTTP(w http.ResponseWriter, req *http.Re if ctx.Done() != nil { } */ - conn_id = req.PathValue("conn_id") - route_id = req.PathValue("route_id") - if conn_id == "" && route_id == "" { - // it's not via /_http/<>/<> - var ids []string - id, err = req.Cookie(SERVER_PROXY_ID_COOKIE) - if err != nil { - status_code = http.StatusBadRequest; w.WriteHeader(status_code) - err = fmt.Errorf("%s cookie not found - %s", SERVER_PROXY_ID_COOKIE, err.Error()) - goto oops - } - - ids = strings.Split(id.Value, "-") - if (len(ids) != 2) { - status_code = http.StatusBadRequest; w.WriteHeader(status_code) - err = fmt.Errorf("invalid proxy id cookie value - %s", id.Value) - goto oops - } - - conn_id = ids[0] - route_id = ids[1] - path_prefix = "" - } else { - path_prefix = fmt.Sprintf("%s/%s/%s", pxy.prefix, conn_id, route_id) - } - - r, err = s.FindServerRouteByIdStr(conn_id, route_id) + r, path_prefix, conn_id, route_id, err = pxy.get_route(req) if err != nil { status_code = http.StatusNotFound; w.WriteHeader(status_code) goto oops } - addr = *r.svc_addr; - if addr.IP.To4() != nil { - addr.IP = net.IPv4(127, 0, 0, 1) // net.IPv4loopback is not defined. so use net.IPv4() - } else { - addr.IP = net.IPv6loopback // net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1} - } + addr = svc_addr_to_dst_addr(r.svc_addr) +/* + if httpguts.HeaderValuesContainsToken(req.Header["Connection"], "Upgrade") && + httpguts.HeaderValuesContainsToken(req.Header["Upgrade"], "websocket") { + // websocket upgrade + var ws_url string -if httpguts.HeaderValuesContainsToken(req.Header["Connection"], "Upgrade") { - var ws_url string + if r.svc_option & RouteOption(ROUTE_OPTION_HTTPS) != 0 { + proxy_proto = "wss" + } else { + proxy_proto = "ws" + } + proxy_url_path = req.URL.Path + if path_prefix != "" { proxy_url_path = strings.TrimPrefix(proxy_url_path, path_prefix) } - if r.svc_option & RouteOption(ROUTE_OPTION_HTTPS) != 0 { - proxy_proto = "wss" - } else { - proxy_proto = "ws" - } - proxy_url_path = req.URL.Path - if path_prefix != "" { proxy_url_path = strings.TrimPrefix(proxy_url_path, path_prefix) } + ws_url = fmt.Sprintf("%s://%s%s", proxy_proto, r.ptc_addr, proxy_url_path) - ws_url = fmt.Sprintf("%s://%s%s", proxy_proto, r.ptc_addr, proxy_url_path) + pxy.serve_websocket(w, req, ws_url, addr) - pxy.serve_websocket(w, req, ws_url, &addr) - return -} + } else */{ + var dialer *net.Dialer + var conn net.Conn - - tcp_conn, err = net.DialTCP("tcp", nil, &addr) // need to be specific between tcp4 and tcp6? maybe not - if err != nil { - status_code = http.StatusBadGateway; w.WriteHeader(status_code) - goto oops - } - - transport = &http.Transport{ - DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - return tcp_conn, nil - }, - TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, - } - - client = &http.Client{ - Transport: transport, - CheckRedirect: prevent_follow_redirect, - } - - // HTTP or HTTPS is actually a hint to the client-side peer - // Use the hint to compose the URL to the client via the server-side - // listening socket as if it connects to the client-side peer - if r.svc_option & RouteOption(ROUTE_OPTION_HTTPS) != 0 { - proxy_proto = "https" - } else { - proxy_proto = "http" - } - - proxy_url_path = req.URL.Path - if path_prefix != "" { proxy_url_path = strings.TrimPrefix(proxy_url_path, path_prefix) } - - proxy_url = &url.URL{ - Scheme: proxy_proto, - Host: r.ptc_addr, - Path: proxy_url_path, - RawQuery: req.URL.RawQuery, - Fragment: req.URL.Fragment, - } - - s.log.Write("", LOG_INFO, "[%s] %s %s -> %+v", req.RemoteAddr, req.Method, req.URL.String(), proxy_url) - -// TODO: http.NewRequestWithContext().?? - proxy_req, err = http.NewRequest(req.Method, proxy_url.String(), req.Body) - if err != nil { - status_code = http.StatusInternalServerError; w.WriteHeader(status_code) - goto oops - } - - mutate_proxy_req_headers(req, proxy_req, path_prefix) - -//fmt.Printf ("proxy NEW req [%+v]\n", proxy_req.Header) - - resp, err = client.Do(proxy_req) - if err != nil { - status_code = http.StatusInternalServerError; w.WriteHeader(status_code) - goto oops - } else { - var hdr http.Header - //var loc string - - defer resp.Body.Close() - - if resp.StatusCode == http.StatusSwitchingProtocols { - // TODO: + dialer = &net.Dialer{} + conn, err = dialer.DialContext(req.Context(), "tcp", addr.String()) + if err != nil { + status_code = http.StatusBadGateway; w.WriteHeader(status_code) + goto oops } - hdr = w.Header() - copy_headers(hdr, resp.Header) - delete_hop_by_hop_headers(hdr) - /* - loc = hdr.Get("Location") - if loc != "" { - strings.Replace(lv, r.ptc_addr, req.Host - hdr.Set("Location", xxx) - }*/ - - if path_prefix == "" { - hdr.Add("Set-Cookie", fmt.Sprintf("%s=%s-%s; Path=/; HttpOnly", SERVER_PROXY_ID_COOKIE, conn_id, route_id)) + transport = &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return conn, nil + }, + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, } - status_code = resp.StatusCode; w.WriteHeader(status_code) - // TODO" if prefixed { append prefix removed... - io.Copy(w, resp.Body) + client = &http.Client{ + Transport: transport, + CheckRedirect: prevent_follow_redirect, + } + // HTTP or HTTPS is actually a hint to the client-side peer + // Use the hint to compose the URL to the client via the server-side + // listening socket as if it connects to the client-side peer + if r.svc_option & RouteOption(ROUTE_OPTION_HTTPS) != 0 { + proxy_proto = "https" + } else { + proxy_proto = "http" + } - // TODO: handle trailers + proxy_url_path = req.URL.Path + if path_prefix != "" { proxy_url_path = strings.TrimPrefix(proxy_url_path, path_prefix) } + + proxy_url = &url.URL{ + Scheme: proxy_proto, + Host: r.ptc_addr, + Path: proxy_url_path, + RawQuery: req.URL.RawQuery, + Fragment: req.URL.Fragment, + } + + s.log.Write("", LOG_INFO, "[%s] %s %s -> %+v", req.RemoteAddr, req.Method, req.URL.String(), proxy_url) + + proxy_req, err = http.NewRequestWithContext(req.Context(), req.Method, proxy_url.String(), req.Body) + if err != nil { + status_code = http.StatusInternalServerError; w.WriteHeader(status_code) + goto oops + } + + mutate_proxy_req_headers(req, proxy_req, path_prefix) + + //fmt.Printf ("proxy NEW req [%+v]\n", proxy_req.Header) + resp, err = client.Do(proxy_req) + if err != nil { + status_code = http.StatusInternalServerError; w.WriteHeader(status_code) + goto oops + } else { + var hdr http.Header + //var loc string + + defer resp.Body.Close() + + if resp.StatusCode == http.StatusSwitchingProtocols { + if httpguts.HeaderValuesContainsToken(req.Header["Connection"], "Upgrade") { + pxy.serve_upgraded(w, req, resp) + return + } + } + + hdr = w.Header() + copy_headers(hdr, resp.Header) + delete_hop_by_hop_headers(hdr) + /* + loc = hdr.Get("Location") + if loc != "" { + strings.Replace(lv, r.ptc_addr, req.Host + hdr.Set("Location", xxx) + }*/ + + if path_prefix == "" { + hdr.Add("Set-Cookie", fmt.Sprintf("%s=%s-%s; Path=/; HttpOnly", SERVER_PROXY_ID_COOKIE, conn_id, route_id)) +//fmt.Printf("<<<%s=%s-%s; Path=/; HttpOnly>>>\n", SERVER_PROXY_ID_COOKIE, conn_id, route_id) + } + status_code = resp.StatusCode; w.WriteHeader(status_code) + + // TODO" if prefixed { append prefix removed... + io.Copy(w, resp.Body) + + // TODO: handle trailers + } } //done: @@ -540,7 +628,7 @@ func (pxy *server_proxy_ssh_ws) send_ws_data(ws *websocket.Conn, type_val string func (pxy *server_proxy_ssh_ws) connect_ssh (ctx context.Context, username string, password string, r *ServerRoute) ( *ssh.Client, *ssh.Session, io.Writer, io.Reader, error) { var cc *ssh.ClientConfig - var addr net.TCPAddr + var addr *net.TCPAddr var dialer *net.Dialer var conn net.Conn var ssh_conn ssh.Conn @@ -565,12 +653,7 @@ func (pxy *server_proxy_ssh_ws) connect_ssh (ctx context.Context, username strin //} // TODO: timeout... - addr = *r.svc_addr; - if addr.IP.To4() != nil { - addr.IP = net.IPv4(127, 0, 0, 1) // net.IPv4loopback is not defined. so use net.IPv4() - } else { - addr.IP = net.IPv6loopback // net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1} - } + addr = svc_addr_to_dst_addr(r.svc_addr); dialer = &net.Dialer{} conn, err = dialer.DialContext(ctx, "tcp", addr.String())