From 91714b23aeec0ea451fa5ae7a738623bd98e98ec Mon Sep 17 00:00:00 2001 From: hyung-hwan Date: Fri, 20 Dec 2024 01:09:00 +0900 Subject: [PATCH] implementing websocket proxy --- server-proxy.go | 98 +++++++++++++++++++++++++++++++++++++++---------- server.go | 24 ++++-------- 2 files changed, 86 insertions(+), 36 deletions(-) diff --git a/server-proxy.go b/server-proxy.go index 42f4431..fe99ee7 100644 --- a/server-proxy.go +++ b/server-proxy.go @@ -4,6 +4,7 @@ import "context" import "crypto/tls" import _ "embed" import "encoding/json" +import "errors" import "fmt" import "io" import "net" @@ -94,6 +95,16 @@ func mutate_proxy_req_headers(req *http.Request, newreq *http.Request, path_pref hdr = req.Header newhdr = newreq.Header + + if httpguts.HeaderValuesContainsToken(hdr["Connection"], "Upgrade") { + newhdr.Set("Connection", "Upgrade") + newhdr.Set("Upgrade", hdr.Get("Upgrade")) + } + + if httpguts.HeaderValuesContainsToken(hdr["Te"], "trailers") { + newhdr.Set("Te", "trailers") + } + remote_addr, _, err = net.SplitHostPort(req.RemoteAddr) if err == nil { oldv, ok = hdr["X-Forwarded-For"] @@ -149,7 +160,6 @@ func mutate_proxy_req_headers(req *http.Request, newreq *http.Request, path_pref func (pxy *server_proxy_http_init) ServeHTTP(w http.ResponseWriter, req *http.Request) { var s *Server - var r *ServerRoute var status_code int var conn_id string var route_id string @@ -168,14 +178,14 @@ func (pxy *server_proxy_http_init) ServeHTTP(w http.ResponseWriter, req *http.Re route_id = req.PathValue("route_id") path_prefix = fmt.Sprintf("%s/%s/%s", pxy.prefix, conn_id, route_id) - r, err = s.FindServerRouteByIdStr(conn_id, route_id) + _, err = s.FindServerRouteByIdStr(conn_id, route_id) if err != nil { status_code = http.StatusNotFound; w.WriteHeader(status_code) goto oops } hdr = w.Header() - hdr.Add("Set-Cookie", fmt.Sprintf("%s=%s-%s; Path=/; HttpOnly", SERVER_PROXY_ID_COOKIE, r.cts.id, r.id)) // use numeric id + hdr.Add("Set-Cookie", fmt.Sprintf("%s=%s-%s; Path=/; HttpOnly", SERVER_PROXY_ID_COOKIE, conn_id, route_id)) // use numeric id hdr.Set("Location", strings.TrimPrefix(req.URL.Path, path_prefix)) // use the original ids as in the request status_code = http.StatusFound; w.WriteHeader(status_code) @@ -194,6 +204,48 @@ func prevent_follow_redirect (req *http.Request, via []*http.Request) error { return http.ErrUseLastResponse } +func (pxy *server_proxy_http_main) serve_websocket(w http.ResponseWriter, req *http.Request, ws_url string, target *net.TCPAddr) { + pxy.s.log.Write("", LOG_INFO, "[%s] %s %s -> %+v", req.RemoteAddr, req.Method, req.URL.String(), ws_url) + + websocket.Handler(func(wc *websocket.Conn) { + var ws *websocket.Conn + var err_chan chan error + var err error + + defer wc.Close() + +// TODO: timeout or cancellation +// TODO: use DialConfig?? + ws, err = websocket.Dial(ws_url, "", req.Header.Get("Origin")) + if err != nil { + // TODO: logging + return + } + defer ws.Close() + + err_chan = make(chan error, 2) + + go func() { + // client to server + var err error + _, err = io.Copy(ws, wc) + err_chan <- err + }() + + go func() { + // server to client + var err error + _, err = io.Copy(wc, ws) + err_chan <- err + }() + + err = <-err_chan + if err != nil && errors.Is(err, io.EOF) { + // TODO: logging + } + }).ServeHTTP(w, req) +} + func (pxy *server_proxy_http_main) ServeHTTP(w http.ResponseWriter, req *http.Request) { var s *Server var r *ServerRoute @@ -211,7 +263,6 @@ func (pxy *server_proxy_http_main) ServeHTTP(w http.ResponseWriter, req *http.Re var proxy_url *url.URL var proxy_url_path string var proxy_proto string - var req_upgrade_type string var err error defer func() { @@ -226,7 +277,6 @@ func (pxy *server_proxy_http_main) ServeHTTP(w http.ResponseWriter, req *http.Re if ctx.Done() != nil { } */ - conn_id = req.PathValue("conn_id") route_id = req.PathValue("route_id") if conn_id == "" && route_id == "" { @@ -236,6 +286,7 @@ func (pxy *server_proxy_http_main) ServeHTTP(w http.ResponseWriter, req *http.Re id, err = req.Cookie(SERVER_PROXY_ID_COOKIE) if err != nil { status_code = http.StatusBadRequest; w.WriteHeader(status_code) + err = fmt.Errorf("%s cookie not found - %s", SERVER_PROXY_ID_COOKIE, err.Error()) goto oops } @@ -265,6 +316,26 @@ func (pxy *server_proxy_http_main) ServeHTTP(w http.ResponseWriter, req *http.Re } else { addr.IP = net.IPv6loopback // net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1} } + + +if httpguts.HeaderValuesContainsToken(req.Header["Connection"], "Upgrade") { + var ws_url string + + if r.svc_option & RouteOption(ROUTE_OPTION_HTTPS) != 0 { + proxy_proto = "wss" + } else { + proxy_proto = "ws" + } + proxy_url_path = req.URL.Path + if path_prefix != "" { proxy_url_path = strings.TrimPrefix(proxy_url_path, path_prefix) } + + ws_url = fmt.Sprintf("%s://%s%s", proxy_proto, r.ptc_addr, proxy_url_path) + + pxy.serve_websocket(w, req, ws_url, &addr) + return +} + + tcp_conn, err = net.DialTCP("tcp", nil, &addr) // need to be specific between tcp4 and tcp6? maybe not if err != nil { status_code = http.StatusBadGateway; w.WriteHeader(status_code) @@ -303,7 +374,7 @@ func (pxy *server_proxy_http_main) ServeHTTP(w http.ResponseWriter, req *http.Re Fragment: req.URL.Fragment, } - s.log.Write("", LOG_DEBUG, "[%s] %s %s -> %+v", req.RemoteAddr, req.Method, req.URL.String(), proxy_url) + s.log.Write("", LOG_INFO, "[%s] %s %s -> %+v", req.RemoteAddr, req.Method, req.URL.String(), proxy_url) // TODO: http.NewRequestWithContext().?? proxy_req, err = http.NewRequest(req.Method, proxy_url.String(), req.Body) @@ -312,20 +383,9 @@ func (pxy *server_proxy_http_main) ServeHTTP(w http.ResponseWriter, req *http.Re goto oops } - if httpguts.HeaderValuesContainsToken(req.Header["Connection"], "Upgrade") { - req_upgrade_type = req.Header.Get("Upgrade") - } mutate_proxy_req_headers(req, proxy_req, path_prefix) - if httpguts.HeaderValuesContainsToken(req.Header["Te"], "trailers") { - proxy_req.Header.Set("Te", "trailers") - } - if req_upgrade_type != "" { - proxy_req.Header.Set("Connection", "Upgrade") - proxy_req.Header.Set("Upgrade", req_upgrade_type) - } - -//fmt.Printf ("proxy NEW req [%+v]\n", proxy_req) +//fmt.Printf ("proxy NEW req [%+v]\n", proxy_req.Header) resp, err = client.Do(proxy_req) if err != nil { @@ -352,7 +412,7 @@ func (pxy *server_proxy_http_main) ServeHTTP(w http.ResponseWriter, req *http.Re }*/ if path_prefix == "" { - hdr.Add("Set-Cookie", fmt.Sprintf("%s=%d-%d; Path=/; HttpOnly", SERVER_PROXY_ID_COOKIE, conn_id, route_id)) + hdr.Add("Set-Cookie", fmt.Sprintf("%s=%s-%s; Path=/; HttpOnly", SERVER_PROXY_ID_COOKIE, conn_id, route_id)) } status_code = resp.StatusCode; w.WriteHeader(status_code) diff --git a/server.go b/server.go index dc55fe7..8c2da24 100644 --- a/server.go +++ b/server.go @@ -1393,28 +1393,22 @@ func (s *Server) FindServerRouteByIdStr(conn_id string, route_id string) (*Serve var port_nid uint64 port_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(PortId(0)) * 8)) - if err != nil { - return nil, fmt.Errorf("invalid port id %s - %s", conn_id, err.Error()) - } + if err != nil { return nil, fmt.Errorf("invalid port id %s - %s", conn_id, err.Error()) } r = s.FindServerRouteByPortId(PortId(port_nid)) - if r == nil { - return nil, fmt.Errorf("port(%d) not found", port_nid) - } + if r == nil { return nil, fmt.Errorf("port(%d) not found", port_nid) } } else { var conn_nid uint64 var route_nid uint64 conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(ConnId(0)) * 8)) - if err != nil { return nil, fmt.Errorf("invalid connection id - %s", err.Error()) } + if err != nil { return nil, fmt.Errorf("invalid connection id %s - %s", conn_id, err.Error()) } route_nid, err = strconv.ParseUint(route_id, 10, int(unsafe.Sizeof(RouteId(0)) * 8)) - if err != nil { return nil, fmt.Errorf("invalid route id - %s", err.Error()) } + if err != nil { return nil, fmt.Errorf("invalid route id %s - %s", route_id, err.Error()) } r = s.FindServerRouteById(ConnId(conn_nid), RouteId(route_nid)) - if r == nil { - return nil, fmt.Errorf("route(%d,%d) not found", conn_nid, route_nid) - } + if r == nil { return nil, fmt.Errorf("route(%d,%d) not found", conn_nid, route_nid) } } return r, nil @@ -1426,14 +1420,10 @@ func (s *Server) FindServerConnByIdStr(conn_id string) (*ServerConn, error) { var err error conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(ConnId(0)) * 8)) - if err != nil { - return nil, fmt.Errorf("invalid connection id %s - %s", conn_id, err.Error()); - } + if err != nil { return nil, fmt.Errorf("invalid connection id %s - %s", conn_id, err.Error()); } cts = s.FindServerConnById(ConnId(conn_nid)) - if cts == nil { - return nil, fmt.Errorf("non-existent connection id %d", conn_nid) - } + if cts == nil { return nil, fmt.Errorf("non-existent connection id %d", conn_nid) } return cts, nil }