implementing websocket proxy
This commit is contained in:
parent
02e3e4561a
commit
91714b23ae
@ -4,6 +4,7 @@ import "context"
|
||||
import "crypto/tls"
|
||||
import _ "embed"
|
||||
import "encoding/json"
|
||||
import "errors"
|
||||
import "fmt"
|
||||
import "io"
|
||||
import "net"
|
||||
@ -94,6 +95,16 @@ func mutate_proxy_req_headers(req *http.Request, newreq *http.Request, path_pref
|
||||
|
||||
hdr = req.Header
|
||||
newhdr = newreq.Header
|
||||
|
||||
if httpguts.HeaderValuesContainsToken(hdr["Connection"], "Upgrade") {
|
||||
newhdr.Set("Connection", "Upgrade")
|
||||
newhdr.Set("Upgrade", hdr.Get("Upgrade"))
|
||||
}
|
||||
|
||||
if httpguts.HeaderValuesContainsToken(hdr["Te"], "trailers") {
|
||||
newhdr.Set("Te", "trailers")
|
||||
}
|
||||
|
||||
remote_addr, _, err = net.SplitHostPort(req.RemoteAddr)
|
||||
if err == nil {
|
||||
oldv, ok = hdr["X-Forwarded-For"]
|
||||
@ -149,7 +160,6 @@ func mutate_proxy_req_headers(req *http.Request, newreq *http.Request, path_pref
|
||||
|
||||
func (pxy *server_proxy_http_init) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
var s *Server
|
||||
var r *ServerRoute
|
||||
var status_code int
|
||||
var conn_id string
|
||||
var route_id string
|
||||
@ -168,14 +178,14 @@ func (pxy *server_proxy_http_init) ServeHTTP(w http.ResponseWriter, req *http.Re
|
||||
route_id = req.PathValue("route_id")
|
||||
path_prefix = fmt.Sprintf("%s/%s/%s", pxy.prefix, conn_id, route_id)
|
||||
|
||||
r, err = s.FindServerRouteByIdStr(conn_id, route_id)
|
||||
_, err = s.FindServerRouteByIdStr(conn_id, route_id)
|
||||
if err != nil {
|
||||
status_code = http.StatusNotFound; w.WriteHeader(status_code)
|
||||
goto oops
|
||||
}
|
||||
|
||||
hdr = w.Header()
|
||||
hdr.Add("Set-Cookie", fmt.Sprintf("%s=%s-%s; Path=/; HttpOnly", SERVER_PROXY_ID_COOKIE, r.cts.id, r.id)) // use numeric id
|
||||
hdr.Add("Set-Cookie", fmt.Sprintf("%s=%s-%s; Path=/; HttpOnly", SERVER_PROXY_ID_COOKIE, conn_id, route_id)) // use numeric id
|
||||
hdr.Set("Location", strings.TrimPrefix(req.URL.Path, path_prefix)) // use the original ids as in the request
|
||||
status_code = http.StatusFound; w.WriteHeader(status_code)
|
||||
|
||||
@ -194,6 +204,48 @@ func prevent_follow_redirect (req *http.Request, via []*http.Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
}
|
||||
|
||||
func (pxy *server_proxy_http_main) serve_websocket(w http.ResponseWriter, req *http.Request, ws_url string, target *net.TCPAddr) {
|
||||
pxy.s.log.Write("", LOG_INFO, "[%s] %s %s -> %+v", req.RemoteAddr, req.Method, req.URL.String(), ws_url)
|
||||
|
||||
websocket.Handler(func(wc *websocket.Conn) {
|
||||
var ws *websocket.Conn
|
||||
var err_chan chan error
|
||||
var err error
|
||||
|
||||
defer wc.Close()
|
||||
|
||||
// TODO: timeout or cancellation
|
||||
// TODO: use DialConfig??
|
||||
ws, err = websocket.Dial(ws_url, "", req.Header.Get("Origin"))
|
||||
if err != nil {
|
||||
// TODO: logging
|
||||
return
|
||||
}
|
||||
defer ws.Close()
|
||||
|
||||
err_chan = make(chan error, 2)
|
||||
|
||||
go func() {
|
||||
// client to server
|
||||
var err error
|
||||
_, err = io.Copy(ws, wc)
|
||||
err_chan <- err
|
||||
}()
|
||||
|
||||
go func() {
|
||||
// server to client
|
||||
var err error
|
||||
_, err = io.Copy(wc, ws)
|
||||
err_chan <- err
|
||||
}()
|
||||
|
||||
err = <-err_chan
|
||||
if err != nil && errors.Is(err, io.EOF) {
|
||||
// TODO: logging
|
||||
}
|
||||
}).ServeHTTP(w, req)
|
||||
}
|
||||
|
||||
func (pxy *server_proxy_http_main) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
var s *Server
|
||||
var r *ServerRoute
|
||||
@ -211,7 +263,6 @@ func (pxy *server_proxy_http_main) ServeHTTP(w http.ResponseWriter, req *http.Re
|
||||
var proxy_url *url.URL
|
||||
var proxy_url_path string
|
||||
var proxy_proto string
|
||||
var req_upgrade_type string
|
||||
var err error
|
||||
|
||||
defer func() {
|
||||
@ -226,7 +277,6 @@ func (pxy *server_proxy_http_main) ServeHTTP(w http.ResponseWriter, req *http.Re
|
||||
if ctx.Done() != nil {
|
||||
}
|
||||
*/
|
||||
|
||||
conn_id = req.PathValue("conn_id")
|
||||
route_id = req.PathValue("route_id")
|
||||
if conn_id == "" && route_id == "" {
|
||||
@ -236,6 +286,7 @@ func (pxy *server_proxy_http_main) ServeHTTP(w http.ResponseWriter, req *http.Re
|
||||
id, err = req.Cookie(SERVER_PROXY_ID_COOKIE)
|
||||
if err != nil {
|
||||
status_code = http.StatusBadRequest; w.WriteHeader(status_code)
|
||||
err = fmt.Errorf("%s cookie not found - %s", SERVER_PROXY_ID_COOKIE, err.Error())
|
||||
goto oops
|
||||
}
|
||||
|
||||
@ -265,6 +316,26 @@ func (pxy *server_proxy_http_main) ServeHTTP(w http.ResponseWriter, req *http.Re
|
||||
} else {
|
||||
addr.IP = net.IPv6loopback // net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}
|
||||
}
|
||||
|
||||
|
||||
if httpguts.HeaderValuesContainsToken(req.Header["Connection"], "Upgrade") {
|
||||
var ws_url string
|
||||
|
||||
if r.svc_option & RouteOption(ROUTE_OPTION_HTTPS) != 0 {
|
||||
proxy_proto = "wss"
|
||||
} else {
|
||||
proxy_proto = "ws"
|
||||
}
|
||||
proxy_url_path = req.URL.Path
|
||||
if path_prefix != "" { proxy_url_path = strings.TrimPrefix(proxy_url_path, path_prefix) }
|
||||
|
||||
ws_url = fmt.Sprintf("%s://%s%s", proxy_proto, r.ptc_addr, proxy_url_path)
|
||||
|
||||
pxy.serve_websocket(w, req, ws_url, &addr)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
tcp_conn, err = net.DialTCP("tcp", nil, &addr) // need to be specific between tcp4 and tcp6? maybe not
|
||||
if err != nil {
|
||||
status_code = http.StatusBadGateway; w.WriteHeader(status_code)
|
||||
@ -303,7 +374,7 @@ func (pxy *server_proxy_http_main) ServeHTTP(w http.ResponseWriter, req *http.Re
|
||||
Fragment: req.URL.Fragment,
|
||||
}
|
||||
|
||||
s.log.Write("", LOG_DEBUG, "[%s] %s %s -> %+v", req.RemoteAddr, req.Method, req.URL.String(), proxy_url)
|
||||
s.log.Write("", LOG_INFO, "[%s] %s %s -> %+v", req.RemoteAddr, req.Method, req.URL.String(), proxy_url)
|
||||
|
||||
// TODO: http.NewRequestWithContext().??
|
||||
proxy_req, err = http.NewRequest(req.Method, proxy_url.String(), req.Body)
|
||||
@ -312,20 +383,9 @@ func (pxy *server_proxy_http_main) ServeHTTP(w http.ResponseWriter, req *http.Re
|
||||
goto oops
|
||||
}
|
||||
|
||||
if httpguts.HeaderValuesContainsToken(req.Header["Connection"], "Upgrade") {
|
||||
req_upgrade_type = req.Header.Get("Upgrade")
|
||||
}
|
||||
mutate_proxy_req_headers(req, proxy_req, path_prefix)
|
||||
|
||||
if httpguts.HeaderValuesContainsToken(req.Header["Te"], "trailers") {
|
||||
proxy_req.Header.Set("Te", "trailers")
|
||||
}
|
||||
if req_upgrade_type != "" {
|
||||
proxy_req.Header.Set("Connection", "Upgrade")
|
||||
proxy_req.Header.Set("Upgrade", req_upgrade_type)
|
||||
}
|
||||
|
||||
//fmt.Printf ("proxy NEW req [%+v]\n", proxy_req)
|
||||
//fmt.Printf ("proxy NEW req [%+v]\n", proxy_req.Header)
|
||||
|
||||
resp, err = client.Do(proxy_req)
|
||||
if err != nil {
|
||||
@ -352,7 +412,7 @@ func (pxy *server_proxy_http_main) ServeHTTP(w http.ResponseWriter, req *http.Re
|
||||
}*/
|
||||
|
||||
if path_prefix == "" {
|
||||
hdr.Add("Set-Cookie", fmt.Sprintf("%s=%d-%d; Path=/; HttpOnly", SERVER_PROXY_ID_COOKIE, conn_id, route_id))
|
||||
hdr.Add("Set-Cookie", fmt.Sprintf("%s=%s-%s; Path=/; HttpOnly", SERVER_PROXY_ID_COOKIE, conn_id, route_id))
|
||||
}
|
||||
status_code = resp.StatusCode; w.WriteHeader(status_code)
|
||||
|
||||
|
24
server.go
24
server.go
@ -1393,28 +1393,22 @@ func (s *Server) FindServerRouteByIdStr(conn_id string, route_id string) (*Serve
|
||||
var port_nid uint64
|
||||
|
||||
port_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(PortId(0)) * 8))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid port id %s - %s", conn_id, err.Error())
|
||||
}
|
||||
if err != nil { return nil, fmt.Errorf("invalid port id %s - %s", conn_id, err.Error()) }
|
||||
|
||||
r = s.FindServerRouteByPortId(PortId(port_nid))
|
||||
if r == nil {
|
||||
return nil, fmt.Errorf("port(%d) not found", port_nid)
|
||||
}
|
||||
if r == nil { return nil, fmt.Errorf("port(%d) not found", port_nid) }
|
||||
} else {
|
||||
var conn_nid uint64
|
||||
var route_nid uint64
|
||||
|
||||
conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(ConnId(0)) * 8))
|
||||
if err != nil { return nil, fmt.Errorf("invalid connection id - %s", err.Error()) }
|
||||
if err != nil { return nil, fmt.Errorf("invalid connection id %s - %s", conn_id, err.Error()) }
|
||||
|
||||
route_nid, err = strconv.ParseUint(route_id, 10, int(unsafe.Sizeof(RouteId(0)) * 8))
|
||||
if err != nil { return nil, fmt.Errorf("invalid route id - %s", err.Error()) }
|
||||
if err != nil { return nil, fmt.Errorf("invalid route id %s - %s", route_id, err.Error()) }
|
||||
|
||||
r = s.FindServerRouteById(ConnId(conn_nid), RouteId(route_nid))
|
||||
if r == nil {
|
||||
return nil, fmt.Errorf("route(%d,%d) not found", conn_nid, route_nid)
|
||||
}
|
||||
if r == nil { return nil, fmt.Errorf("route(%d,%d) not found", conn_nid, route_nid) }
|
||||
}
|
||||
|
||||
return r, nil
|
||||
@ -1426,14 +1420,10 @@ func (s *Server) FindServerConnByIdStr(conn_id string) (*ServerConn, error) {
|
||||
var err error
|
||||
|
||||
conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(ConnId(0)) * 8))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid connection id %s - %s", conn_id, err.Error());
|
||||
}
|
||||
if err != nil { return nil, fmt.Errorf("invalid connection id %s - %s", conn_id, err.Error()); }
|
||||
|
||||
cts = s.FindServerConnById(ConnId(conn_nid))
|
||||
if cts == nil {
|
||||
return nil, fmt.Errorf("non-existent connection id %d", conn_nid)
|
||||
}
|
||||
if cts == nil { return nil, fmt.Errorf("non-existent connection id %d", conn_nid) }
|
||||
|
||||
return cts, nil
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user