implementing websocket proxy

This commit is contained in:
hyung-hwan 2024-12-20 01:09:00 +09:00
parent 02e3e4561a
commit 91714b23ae
2 changed files with 86 additions and 36 deletions

View File

@ -4,6 +4,7 @@ import "context"
import "crypto/tls"
import _ "embed"
import "encoding/json"
import "errors"
import "fmt"
import "io"
import "net"
@ -94,6 +95,16 @@ func mutate_proxy_req_headers(req *http.Request, newreq *http.Request, path_pref
hdr = req.Header
newhdr = newreq.Header
if httpguts.HeaderValuesContainsToken(hdr["Connection"], "Upgrade") {
newhdr.Set("Connection", "Upgrade")
newhdr.Set("Upgrade", hdr.Get("Upgrade"))
}
if httpguts.HeaderValuesContainsToken(hdr["Te"], "trailers") {
newhdr.Set("Te", "trailers")
}
remote_addr, _, err = net.SplitHostPort(req.RemoteAddr)
if err == nil {
oldv, ok = hdr["X-Forwarded-For"]
@ -149,7 +160,6 @@ func mutate_proxy_req_headers(req *http.Request, newreq *http.Request, path_pref
func (pxy *server_proxy_http_init) ServeHTTP(w http.ResponseWriter, req *http.Request) {
var s *Server
var r *ServerRoute
var status_code int
var conn_id string
var route_id string
@ -168,14 +178,14 @@ func (pxy *server_proxy_http_init) ServeHTTP(w http.ResponseWriter, req *http.Re
route_id = req.PathValue("route_id")
path_prefix = fmt.Sprintf("%s/%s/%s", pxy.prefix, conn_id, route_id)
r, err = s.FindServerRouteByIdStr(conn_id, route_id)
_, err = s.FindServerRouteByIdStr(conn_id, route_id)
if err != nil {
status_code = http.StatusNotFound; w.WriteHeader(status_code)
goto oops
}
hdr = w.Header()
hdr.Add("Set-Cookie", fmt.Sprintf("%s=%s-%s; Path=/; HttpOnly", SERVER_PROXY_ID_COOKIE, r.cts.id, r.id)) // use numeric id
hdr.Add("Set-Cookie", fmt.Sprintf("%s=%s-%s; Path=/; HttpOnly", SERVER_PROXY_ID_COOKIE, conn_id, route_id)) // use numeric id
hdr.Set("Location", strings.TrimPrefix(req.URL.Path, path_prefix)) // use the original ids as in the request
status_code = http.StatusFound; w.WriteHeader(status_code)
@ -194,6 +204,48 @@ func prevent_follow_redirect (req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
}
func (pxy *server_proxy_http_main) serve_websocket(w http.ResponseWriter, req *http.Request, ws_url string, target *net.TCPAddr) {
pxy.s.log.Write("", LOG_INFO, "[%s] %s %s -> %+v", req.RemoteAddr, req.Method, req.URL.String(), ws_url)
websocket.Handler(func(wc *websocket.Conn) {
var ws *websocket.Conn
var err_chan chan error
var err error
defer wc.Close()
// TODO: timeout or cancellation
// TODO: use DialConfig??
ws, err = websocket.Dial(ws_url, "", req.Header.Get("Origin"))
if err != nil {
// TODO: logging
return
}
defer ws.Close()
err_chan = make(chan error, 2)
go func() {
// client to server
var err error
_, err = io.Copy(ws, wc)
err_chan <- err
}()
go func() {
// server to client
var err error
_, err = io.Copy(wc, ws)
err_chan <- err
}()
err = <-err_chan
if err != nil && errors.Is(err, io.EOF) {
// TODO: logging
}
}).ServeHTTP(w, req)
}
func (pxy *server_proxy_http_main) ServeHTTP(w http.ResponseWriter, req *http.Request) {
var s *Server
var r *ServerRoute
@ -211,7 +263,6 @@ func (pxy *server_proxy_http_main) ServeHTTP(w http.ResponseWriter, req *http.Re
var proxy_url *url.URL
var proxy_url_path string
var proxy_proto string
var req_upgrade_type string
var err error
defer func() {
@ -226,7 +277,6 @@ func (pxy *server_proxy_http_main) ServeHTTP(w http.ResponseWriter, req *http.Re
if ctx.Done() != nil {
}
*/
conn_id = req.PathValue("conn_id")
route_id = req.PathValue("route_id")
if conn_id == "" && route_id == "" {
@ -236,6 +286,7 @@ func (pxy *server_proxy_http_main) ServeHTTP(w http.ResponseWriter, req *http.Re
id, err = req.Cookie(SERVER_PROXY_ID_COOKIE)
if err != nil {
status_code = http.StatusBadRequest; w.WriteHeader(status_code)
err = fmt.Errorf("%s cookie not found - %s", SERVER_PROXY_ID_COOKIE, err.Error())
goto oops
}
@ -265,6 +316,26 @@ func (pxy *server_proxy_http_main) ServeHTTP(w http.ResponseWriter, req *http.Re
} else {
addr.IP = net.IPv6loopback // net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}
}
if httpguts.HeaderValuesContainsToken(req.Header["Connection"], "Upgrade") {
var ws_url string
if r.svc_option & RouteOption(ROUTE_OPTION_HTTPS) != 0 {
proxy_proto = "wss"
} else {
proxy_proto = "ws"
}
proxy_url_path = req.URL.Path
if path_prefix != "" { proxy_url_path = strings.TrimPrefix(proxy_url_path, path_prefix) }
ws_url = fmt.Sprintf("%s://%s%s", proxy_proto, r.ptc_addr, proxy_url_path)
pxy.serve_websocket(w, req, ws_url, &addr)
return
}
tcp_conn, err = net.DialTCP("tcp", nil, &addr) // need to be specific between tcp4 and tcp6? maybe not
if err != nil {
status_code = http.StatusBadGateway; w.WriteHeader(status_code)
@ -303,7 +374,7 @@ func (pxy *server_proxy_http_main) ServeHTTP(w http.ResponseWriter, req *http.Re
Fragment: req.URL.Fragment,
}
s.log.Write("", LOG_DEBUG, "[%s] %s %s -> %+v", req.RemoteAddr, req.Method, req.URL.String(), proxy_url)
s.log.Write("", LOG_INFO, "[%s] %s %s -> %+v", req.RemoteAddr, req.Method, req.URL.String(), proxy_url)
// TODO: http.NewRequestWithContext().??
proxy_req, err = http.NewRequest(req.Method, proxy_url.String(), req.Body)
@ -312,20 +383,9 @@ func (pxy *server_proxy_http_main) ServeHTTP(w http.ResponseWriter, req *http.Re
goto oops
}
if httpguts.HeaderValuesContainsToken(req.Header["Connection"], "Upgrade") {
req_upgrade_type = req.Header.Get("Upgrade")
}
mutate_proxy_req_headers(req, proxy_req, path_prefix)
if httpguts.HeaderValuesContainsToken(req.Header["Te"], "trailers") {
proxy_req.Header.Set("Te", "trailers")
}
if req_upgrade_type != "" {
proxy_req.Header.Set("Connection", "Upgrade")
proxy_req.Header.Set("Upgrade", req_upgrade_type)
}
//fmt.Printf ("proxy NEW req [%+v]\n", proxy_req)
//fmt.Printf ("proxy NEW req [%+v]\n", proxy_req.Header)
resp, err = client.Do(proxy_req)
if err != nil {
@ -352,7 +412,7 @@ func (pxy *server_proxy_http_main) ServeHTTP(w http.ResponseWriter, req *http.Re
}*/
if path_prefix == "" {
hdr.Add("Set-Cookie", fmt.Sprintf("%s=%d-%d; Path=/; HttpOnly", SERVER_PROXY_ID_COOKIE, conn_id, route_id))
hdr.Add("Set-Cookie", fmt.Sprintf("%s=%s-%s; Path=/; HttpOnly", SERVER_PROXY_ID_COOKIE, conn_id, route_id))
}
status_code = resp.StatusCode; w.WriteHeader(status_code)

View File

@ -1393,28 +1393,22 @@ func (s *Server) FindServerRouteByIdStr(conn_id string, route_id string) (*Serve
var port_nid uint64
port_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(PortId(0)) * 8))
if err != nil {
return nil, fmt.Errorf("invalid port id %s - %s", conn_id, err.Error())
}
if err != nil { return nil, fmt.Errorf("invalid port id %s - %s", conn_id, err.Error()) }
r = s.FindServerRouteByPortId(PortId(port_nid))
if r == nil {
return nil, fmt.Errorf("port(%d) not found", port_nid)
}
if r == nil { return nil, fmt.Errorf("port(%d) not found", port_nid) }
} else {
var conn_nid uint64
var route_nid uint64
conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(ConnId(0)) * 8))
if err != nil { return nil, fmt.Errorf("invalid connection id - %s", err.Error()) }
if err != nil { return nil, fmt.Errorf("invalid connection id %s - %s", conn_id, err.Error()) }
route_nid, err = strconv.ParseUint(route_id, 10, int(unsafe.Sizeof(RouteId(0)) * 8))
if err != nil { return nil, fmt.Errorf("invalid route id - %s", err.Error()) }
if err != nil { return nil, fmt.Errorf("invalid route id %s - %s", route_id, err.Error()) }
r = s.FindServerRouteById(ConnId(conn_nid), RouteId(route_nid))
if r == nil {
return nil, fmt.Errorf("route(%d,%d) not found", conn_nid, route_nid)
}
if r == nil { return nil, fmt.Errorf("route(%d,%d) not found", conn_nid, route_nid) }
}
return r, nil
@ -1426,14 +1420,10 @@ func (s *Server) FindServerConnByIdStr(conn_id string) (*ServerConn, error) {
var err error
conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(ConnId(0)) * 8))
if err != nil {
return nil, fmt.Errorf("invalid connection id %s - %s", conn_id, err.Error());
}
if err != nil { return nil, fmt.Errorf("invalid connection id %s - %s", conn_id, err.Error()); }
cts = s.FindServerConnById(ConnId(conn_nid))
if cts == nil {
return nil, fmt.Errorf("non-existent connection id %d", conn_nid)
}
if cts == nil { return nil, fmt.Errorf("non-existent connection id %d", conn_nid) }
return cts, nil
}