From 2424a63db96b89217d3f32263ea5adb2215dde3d Mon Sep 17 00:00:00 2001 From: hyung-hwan Date: Wed, 25 Dec 2024 02:14:47 +0900 Subject: [PATCH] enhanced conneciton upgrade handling in http proxy code --- client.go | 2 +- server-peer.go | 2 +- server-proxy.go | 261 +++++++++++++++++++++--------------------------- server.go | 8 +- 4 files changed, 123 insertions(+), 150 deletions(-) diff --git a/client.go b/client.go index 38e0769..fdd2766 100644 --- a/client.go +++ b/client.go @@ -1299,6 +1299,7 @@ func (c *Client) ReqStop() { var cts *ClientConn var ctl *http.Server + c.ctx_cancel() for _, ctl = range c.ctl { ctl.Shutdown(c.ctx) // to break c.ctl.ListenAndServe() } @@ -1308,7 +1309,6 @@ func (c *Client) ReqStop() { } c.stop_chan <- true - c.ctx_cancel() } } diff --git a/server-peer.go b/server-peer.go index 223f6be..64a9ab4 100644 --- a/server-peer.go +++ b/server-peer.go @@ -64,7 +64,7 @@ func (spc *ServerPeerConn) RunTask(wg *sync.WaitGroup) { goto done_without_stop } - tmr = time.NewTimer(2 * time.Second) // TODO: make this configurable... + tmr = time.NewTimer(4 * time.Second) // TODO: make this configurable... wait_for_started: for { select { diff --git a/server-proxy.go b/server-proxy.go index f6c4f0e..cf0c546 100644 --- a/server-proxy.go +++ b/server-proxy.go @@ -80,7 +80,7 @@ func delete_hop_by_hop_headers(header http.Header) { } } -func mutate_proxy_req_headers(req *http.Request, newreq *http.Request, path_prefix string) { +func mutate_proxy_req_headers(req *http.Request, newreq *http.Request, path_prefix string) bool { var hdr http.Header var newhdr http.Header var remote_addr string @@ -89,17 +89,20 @@ func mutate_proxy_req_headers(req *http.Request, newreq *http.Request, path_pref var ok bool var err error var conn_addr net.Addr + var upgrade_required bool //newreq.Header = req.Header.Clone() - copy_headers(newreq.Header, req.Header) - delete_hop_by_hop_headers(newreq.Header) - hdr = req.Header newhdr = newreq.Header + copy_headers(newhdr, hdr) + delete_hop_by_hop_headers(newhdr) + + // put back the upgrade header removed by delete_hop_by_hop_headers if httpguts.HeaderValuesContainsToken(hdr["Connection"], "Upgrade") { newhdr.Set("Connection", "Upgrade") newhdr.Set("Upgrade", hdr.Get("Upgrade")) + upgrade_required = true } /* @@ -108,7 +111,6 @@ func mutate_proxy_req_headers(req *http.Request, newreq *http.Request, path_pref } */ - remote_addr, _, err = net.SplitHostPort(req.RemoteAddr) if err == nil { oldv, ok = hdr["X-Forwarded-For"] @@ -158,6 +160,8 @@ func mutate_proxy_req_headers(req *http.Request, newreq *http.Request, path_pref newhdr.Set("X-Forwarded-Prefix", v[0] + path_prefix) } } + + return upgrade_required } // ------------------------------------ @@ -293,10 +297,8 @@ func (pxy *server_proxy_http_main) get_upgrade_type(hdr http.Header) string { return "" } -func (pxy *server_proxy_http_main) serve_upgraded(w http.ResponseWriter, req *http.Request, proxy_res *http.Response) { +func (pxy *server_proxy_http_main) serve_upgraded(w http.ResponseWriter, req *http.Request, proxy_res *http.Response) error { var err_chan chan error - var req_up_type string - var res_up_type string var proxy_res_body io.ReadWriteCloser var rc *http.ResponseController var client_conn net.Conn @@ -304,61 +306,88 @@ func (pxy *server_proxy_http_main) serve_upgraded(w http.ResponseWriter, req *ht var ok bool var err error - req_up_type = pxy.get_upgrade_type(req.Header) - res_up_type = pxy.get_upgrade_type(proxy_res.Header) - if !strings.EqualFold(req_up_type, res_up_type) { - // TODO: error - return - } - proxy_res_body, ok = proxy_res.Body.(io.ReadWriteCloser) if !ok { - //p.getErrorHandler()(w, req, fmt.Errorf("internal error: 101 switching protocols response with non-writable body")) - return + return fmt.Errorf("internal error - unable to cast upgraded response body") } defer proxy_res_body.Close() rc = http.NewResponseController(w) - client_conn, buf_rw, err = rc.Hijack() - if errors.Is(err, http.ErrNotSupported) { - //p.getErrorHandler()(w, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", w)) - return - } + client_conn, buf_rw, err = rc.Hijack() // take over the connection. + if err != nil { return err } - //if hijack_err != nil { - //p.getErrorHandler()(w, req, fmt.Errorf("Hijack failed on protocol switch: %v", hijackErr)) - // return - //} - defer client_conn.Close() + defer client_conn.Close() - copy_headers(w.Header(), proxy_res.Header) + copy_headers(w.Header(), proxy_res.Header) + proxy_res.Header = w.Header() - proxy_res.Header = w.Header() - proxy_res.Body = nil // so res.Write only writes the headers; we have res.Body in proxy_res_body above - err = proxy_res.Write(buf_rw) - if err != nil { - //p.getErrorHandler()(w, req, fmt.Errorf("response write: %v", err)) - return - } - err = buf_rw.Flush() - if err != nil { - //p.getErrorHandler()(rw, req, fmt.Errorf("response flush: %v", err)) - return - } + // reset it to make Write() and Flush() to handle the headers only. + // the goroutines below will use the saved proxy_res_body. + proxy_res.Body = nil - err_chan = make(chan error, 2) - go func() { - var err error + err = proxy_res.Write(buf_rw) + if err != nil { return fmt.Errorf("unable to write upgraded response header - %s", err.Error()) } + + err = buf_rw.Flush() + if err != nil { return fmt.Errorf("unable to flush upgraded response header - %s", err.Error()) } + + err_chan = make(chan error, 2) + go func() { + var err error _, err = io.Copy(client_conn, proxy_res_body) err_chan <- err - }() + }() - go func() { - var err error + go func() { + var err error _, err = io.Copy(proxy_res_body, client_conn) err_chan <- err - }() - <-err_chan + }() + err =<-err_chan + + return err +} + +func (pxy *server_proxy_http_main) addr_to_transport (ctx context.Context, addr *net.TCPAddr) (*http.Transport, error) { + var err error + var dialer *net.Dialer + var conn net.Conn + + dialer = &net.Dialer{} + conn, err = dialer.DialContext(ctx, "tcp", addr.String()) + if err != nil { return nil, err } + + return &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return conn, nil + }, + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, // TODO: make this part configurable? + }, nil +} + +func (pxy *server_proxy_http_main) req_to_proxy_url (req *http.Request, r *ServerRoute, path_prefix string) *url.URL { + var proxy_proto string + var proxy_url_path string + + // HTTP or HTTPS is actually a hint to the client-side peer + // Use the hint to compose the URL to the client via the server-side + // listening socket as if it connects to the client-side peer + if r.svc_option & RouteOption(ROUTE_OPTION_HTTPS) != 0 { + proxy_proto = "https" + } else { + proxy_proto = "http" + } + + proxy_url_path = req.URL.Path + if path_prefix != "" { proxy_url_path = strings.TrimPrefix(proxy_url_path, path_prefix) } + + return &url.URL{ + Scheme: proxy_proto, + Host: r.ptc_addr, + Path: proxy_url_path, + RawQuery: req.URL.RawQuery, + Fragment: req.URL.Fragment, + } } func (pxy *server_proxy_http_main) ServeHTTP(w http.ResponseWriter, req *http.Request) { @@ -368,14 +397,13 @@ func (pxy *server_proxy_http_main) ServeHTTP(w http.ResponseWriter, req *http.Re var conn_id string var route_id string var path_prefix string - var client *http.Client var resp *http.Response var transport *http.Transport + var client *http.Client var addr *net.TCPAddr var proxy_req *http.Request var proxy_url *url.URL - var proxy_url_path string - var proxy_proto string + var upgrade_required bool var err error defer func() { @@ -398,115 +426,57 @@ func (pxy *server_proxy_http_main) ServeHTTP(w http.ResponseWriter, req *http.Re } addr = svc_addr_to_dst_addr(r.svc_addr) + //transport, err = pxy.addr_to_transport(req.Context(), addr) + transport, err = pxy.addr_to_transport(s.ctx, addr) + if err != nil { + status_code = http.StatusBadGateway; w.WriteHeader(status_code) + goto oops + } + proxy_url = pxy.req_to_proxy_url(req, r, path_prefix) -/* - if httpguts.HeaderValuesContainsToken(req.Header["Connection"], "Upgrade") && - httpguts.HeaderValuesContainsToken(req.Header["Upgrade"], "websocket") { - // websocket upgrade - var ws_url string + s.log.Write("", LOG_INFO, "[%s] %s %s -> %+v", req.RemoteAddr, req.Method, req.URL.String(), proxy_url) - if r.svc_option & RouteOption(ROUTE_OPTION_HTTPS) != 0 { - proxy_proto = "wss" - } else { - proxy_proto = "ws" - } - proxy_url_path = req.URL.Path - if path_prefix != "" { proxy_url_path = strings.TrimPrefix(proxy_url_path, path_prefix) } + //proxy_req, err = http.NewRequestWithContext(req.Context(), req.Method, proxy_url.String(), req.Body) + proxy_req, err = http.NewRequestWithContext(s.ctx, req.Method, proxy_url.String(), req.Body) + if err != nil { + status_code = http.StatusInternalServerError; w.WriteHeader(status_code) + goto oops + } - ws_url = fmt.Sprintf("%s://%s%s", proxy_proto, r.ptc_addr, proxy_url_path) - - pxy.serve_websocket(w, req, ws_url, addr) - - } else */{ - var dialer *net.Dialer - var conn net.Conn - - dialer = &net.Dialer{} - conn, err = dialer.DialContext(req.Context(), "tcp", addr.String()) - if err != nil { - status_code = http.StatusBadGateway; w.WriteHeader(status_code) - goto oops - } - - transport = &http.Transport{ - DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - return conn, nil - }, - TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, - } - - client = &http.Client{ - Transport: transport, - CheckRedirect: prevent_follow_redirect, - } - - // HTTP or HTTPS is actually a hint to the client-side peer - // Use the hint to compose the URL to the client via the server-side - // listening socket as if it connects to the client-side peer - if r.svc_option & RouteOption(ROUTE_OPTION_HTTPS) != 0 { - proxy_proto = "https" - } else { - proxy_proto = "http" - } - - proxy_url_path = req.URL.Path - if path_prefix != "" { proxy_url_path = strings.TrimPrefix(proxy_url_path, path_prefix) } - - proxy_url = &url.URL{ - Scheme: proxy_proto, - Host: r.ptc_addr, - Path: proxy_url_path, - RawQuery: req.URL.RawQuery, - Fragment: req.URL.Fragment, - } - - s.log.Write("", LOG_INFO, "[%s] %s %s -> %+v", req.RemoteAddr, req.Method, req.URL.String(), proxy_url) - - proxy_req, err = http.NewRequestWithContext(req.Context(), req.Method, proxy_url.String(), req.Body) - if err != nil { - status_code = http.StatusInternalServerError; w.WriteHeader(status_code) - goto oops - } - - mutate_proxy_req_headers(req, proxy_req, path_prefix) - - //fmt.Printf ("proxy NEW req [%+v]\n", proxy_req.Header) - resp, err = client.Do(proxy_req) - if err != nil { - status_code = http.StatusInternalServerError; w.WriteHeader(status_code) - goto oops + upgrade_required = mutate_proxy_req_headers(req, proxy_req, path_prefix) +fmt.Printf ("AAAAAAAAAAAAAAAAAAAAa\n") +//fmt.Printf ("proxy NEW req [%+v]\n", proxy_req.Header) + client = &http.Client{ + Transport: transport, + CheckRedirect: prevent_follow_redirect, + Timeout: 5 * time.Second, + } + resp, err = client.Do(proxy_req) +fmt.Printf ("BBBBBBBBBBBBBBBBBBBBBBBB\n") + //resp, err = transport.RoundTrip(proxy_req) + if err != nil { + status_code = http.StatusInternalServerError; w.WriteHeader(status_code) + goto oops + } else { + status_code = resp.StatusCode + if upgrade_required && resp.StatusCode == http.StatusSwitchingProtocols { + s.log.Write("", LOG_INFO, "[%s] %s %s %d", req.RemoteAddr, req.Method, req.URL.String(), status_code) + err = pxy.serve_upgraded(w, req, resp) + if err != nil { goto oops } + return // print the log mesage before calling serve_upgraded() and exit here } else { var hdr http.Header - //var loc string - defer resp.Body.Close() - if resp.StatusCode == http.StatusSwitchingProtocols { - if httpguts.HeaderValuesContainsToken(req.Header["Connection"], "Upgrade") { - pxy.serve_upgraded(w, req, resp) - return - } - } - hdr = w.Header() copy_headers(hdr, resp.Header) delete_hop_by_hop_headers(hdr) - /* - loc = hdr.Get("Location") - if loc != "" { - strings.Replace(lv, r.ptc_addr, req.Host - hdr.Set("Location", xxx) - }*/ if path_prefix == "" { hdr.Add("Set-Cookie", fmt.Sprintf("%s=%s-%s; Path=/; HttpOnly", SERVER_PROXY_ID_COOKIE, conn_id, route_id)) -//fmt.Printf("<<<%s=%s-%s; Path=/; HttpOnly>>>\n", SERVER_PROXY_ID_COOKIE, conn_id, route_id) } - status_code = resp.StatusCode; w.WriteHeader(status_code) - - // TODO" if prefixed { append prefix removed... + w.WriteHeader(status_code) io.Copy(w, resp.Body) - // TODO: handle trailers } } @@ -520,7 +490,6 @@ oops: return } - // ------------------------------------ type server_proxy_xterm_file struct { s *Server diff --git a/server.go b/server.go index 8c2da24..c8a7bd9 100644 --- a/server.go +++ b/server.go @@ -1010,7 +1010,7 @@ func NewServer(ctx context.Context, logger Logger, ctl_addrs []string, rpc_addrs s.pxy_mux.Handle("/_http/{conn_id}/{route_id}/{trailer...}", &server_proxy_http_main{s: &s, prefix: "/_http"}) s.pxy_mux.Handle("/_init/{conn_id}/{route_id}/{trailer...}", &server_proxy_http_init{s: &s, prefix: "/_init"}) - s.pxy_mux.Handle("/", &server_proxy_http_main{s: &s}) + s.pxy_mux.Handle("/", &server_proxy_http_main{s: &s, prefix: ""}) s.pxy_addr = make([]string, len(pxy_addrs)) s.pxy = make([]*http.Server, len(pxy_addrs)) @@ -1197,6 +1197,11 @@ func (s *Server) ReqStop() { var cts *ServerConn var hs *http.Server + // call cancellation function before anything else + // to break sub-tasks relying on this server context. + // for example, http.Client in server_proxy_http_main + s.ctx_cancel() + for _, hs = range s.ctl { hs.Shutdown(s.ctx) // to break s.ctl.Serve() } @@ -1218,7 +1223,6 @@ func (s *Server) ReqStop() { s.cts_mtx.Unlock() s.stop_chan <- true - s.ctx_cancel() } }