From 7835696166a4db703a19557f89189fb6c08ae09e Mon Sep 17 00:00:00 2001 From: hyung-hwan Date: Sat, 21 Jun 2025 22:01:24 +0900 Subject: [PATCH] enhanced the websocket endpoints to return failure for normal https packets --- client.go | 27 +++++++++++++++++++++------ server.go | 22 ++++++++++++++++++---- 2 files changed, 39 insertions(+), 10 deletions(-) diff --git a/client.go b/client.go index 88b7093..a061ec5 100644 --- a/client.go +++ b/client.go @@ -10,6 +10,7 @@ import "net" import "net/http" import "slices" import "strconv" +import "strings" import "sync" import "sync/atomic" import "time" @@ -1538,7 +1539,20 @@ func (c *Client) WrapHttpHandler(handler ClientHttpHandler) http.Handler { }) } -func (s *Client) WrapWebsocketHandler(handler ClientWebsocketHandler) websocket.Handler { +func (c *Client) SafeWrapWebsocketHandler(handler websocket.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + if !strings.EqualFold(req.Header.Get("Upgrade"), "websocket") || + !strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade") { + var status_code int + status_code = WriteEmptyRespHeader(w, http.StatusBadRequest) + c.log.Write("", LOG_INFO, "[%s] %s %s %d[non-websocket]", req.RemoteAddr, req.Method, req.URL.String(), status_code) + return + } + handler.ServeHTTP(w, req) + }) +} + +func (c *Client) WrapWebsocketHandler(handler ClientWebsocketHandler) websocket.Handler { return websocket.Handler(func(ws *websocket.Conn) { var status_code int var err error @@ -1547,7 +1561,7 @@ func (s *Client) WrapWebsocketHandler(handler ClientWebsocketHandler) websocket. var req *http.Request req = ws.Request() - s.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s [ws]", req.RemoteAddr, req.Method, req.URL.String()) + c.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s [ws]", req.RemoteAddr, req.Method, req.URL.String()) start_time = time.Now() status_code, err = handler.ServeWebsocket(ws) @@ -1555,9 +1569,9 @@ func (s *Client) WrapWebsocketHandler(handler ClientWebsocketHandler) websocket. if status_code > 0 { if err != nil { - s.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s [ws] %d %.9f - %s", req.RemoteAddr, req.Method, req.URL.String(), status_code, time_taken.Seconds(), err.Error()) + c.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s [ws] %d %.9f - %s", req.RemoteAddr, req.Method, req.URL.String(), status_code, time_taken.Seconds(), err.Error()) } else { - s.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s [ws] %d %.9f", req.RemoteAddr, req.Method, req.URL.String(), status_code, time_taken.Seconds()) + c.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s [ws] %d %.9f", req.RemoteAddr, req.Method, req.URL.String(), status_code, time_taken.Seconds()) } } }) @@ -1630,10 +1644,11 @@ func NewClient(ctx context.Context, name string, logger Logger, cfg *ClientConfi promhttp.HandlerFor(c.promreg, promhttp.HandlerOpts{ EnableOpenMetrics: true })) c.ctl_mux.Handle("/_ctl/events", - c.WrapWebsocketHandler(&client_ctl_ws{client_ctl{c: &c, id: HS_ID_CTL}})) + c.SafeWrapWebsocketHandler(c.WrapWebsocketHandler(&client_ctl_ws{client_ctl{c: &c, id: HS_ID_CTL}}))) - c.ctl_mux.Handle("/_pts/ws", c.WrapWebsocketHandler(&client_pts_ws{C: &c, Id: HS_ID_CTL})) + c.ctl_mux.Handle("/_pts/ws", + c.SafeWrapWebsocketHandler(c.WrapWebsocketHandler(&client_pts_ws{C: &c, Id: HS_ID_CTL}))) c.ctl_mux.Handle("/_pts/xterm.js", c.WrapHttpHandler(&client_pts_xterm_file{client_ctl: client_ctl{c: &c, id: HS_ID_CTL}, file: "xterm.js"})) diff --git a/server.go b/server.go index 211df85..9bd8883 100644 --- a/server.go +++ b/server.go @@ -12,6 +12,7 @@ import "net/http" import "net/netip" import "slices" import "strconv" +import "strings" import "sync" import "sync/atomic" import "time" @@ -1233,6 +1234,19 @@ func (s *Server) WrapHttpHandler(handler ServerHttpHandler) http.Handler { }) } +func (s *Server) SafeWrapWebsocketHandler(handler websocket.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + if !strings.EqualFold(req.Header.Get("Upgrade"), "websocket") || + !strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade") { + var status_code int + status_code = WriteEmptyRespHeader(w, http.StatusBadRequest) + s.log.Write("", LOG_INFO, "[%s] %s %s %d[non-websocket]", req.RemoteAddr, req.Method, req.URL.String(), status_code) + return + } + handler.ServeHTTP(w, req) + }) +} + func (s *Server) WrapWebsocketHandler(handler ServerWebsocketHandler) websocket.Handler { return websocket.Handler(func(ws *websocket.Conn) { var status_code int @@ -1360,12 +1374,12 @@ func NewServer(ctx context.Context, name string, logger Logger, cfg *ServerConfi promhttp.HandlerFor(s.promreg, promhttp.HandlerOpts{ EnableOpenMetrics: true })) s.ctl_mux.Handle("/_ctl/events", - s.WrapWebsocketHandler(&server_ctl_ws{ServerCtl{S: &s, Id: HS_ID_CTL}})) + s.SafeWrapWebsocketHandler(s.WrapWebsocketHandler(&server_ctl_ws{ServerCtl{S: &s, Id: HS_ID_CTL}}))) /* // this part is duplcate of pxy_mux. s.ctl_mux.Handle("/_ssh/ws/{conn_id}/{route_id}", - s.WrapWebsocketHandler(&server_pxy_ssh_ws{S: &s, Id: HS_ID_PXY_WS})) + s.SafeWrapWebsocketHandler(s.WrapWebsocketHandler(&server_pxy_ssh_ws{S: &s, Id: HS_ID_PXY_WS}))) s.ctl_mux.Handle("/_ssh/server-conns/{conn_id}/routes/{route_id}", s.WrapHttpHandler(&server_ctl_server_conns_id_routes_id{ServerCtl{S: &s, Id: HS_ID_CTL, NoAuth: true}})) s.ctl_mux.Handle("/_ssh/xterm.js", @@ -1408,7 +1422,7 @@ func NewServer(ctx context.Context, name string, logger Logger, cfg *ServerConfi s.pxy_mux = http.NewServeMux() // TODO: make /_init,_ssh,_ssh/ws,_http configurable... s.pxy_mux.Handle("/_ssh/ws/{conn_id}/{route_id}", - s.WrapWebsocketHandler(&server_pxy_ssh_ws{S: &s, Id: HS_ID_PXY_WS})) + s.SafeWrapWebsocketHandler(s.WrapWebsocketHandler(&server_pxy_ssh_ws{S: &s, Id: HS_ID_PXY_WS}))) s.pxy_mux.Handle("/_ssh/server-conns/{conn_id}/routes/{route_id}", s.WrapHttpHandler(&server_ctl_server_conns_id_routes_id{ServerCtl{S: &s, Id: HS_ID_PXY, NoAuth: true}})) s.pxy_mux.Handle("/_ssh/xterm.js", @@ -1453,7 +1467,7 @@ func NewServer(ctx context.Context, name string, logger Logger, cfg *ServerConfi s.wpx_mux = http.NewServeMux() // TODO: make /_init,_ssh,_ssh/ws,_http configurable... s.wpx_mux.Handle("/_ssh/ws/{conn_id}/{route_id}", - s.WrapWebsocketHandler(&server_pxy_ssh_ws{S: &s, Id: "wpx-ssh"})) + s.SafeWrapWebsocketHandler(s.WrapWebsocketHandler(&server_pxy_ssh_ws{S: &s, Id: "wpx-ssh"}))) s.wpx_mux.Handle("/_ssh/server-conns/{conn_id}/routes/{route_id}", s.WrapHttpHandler(&server_ctl_server_conns_id_routes_id{ServerCtl{S: &s, Id: HS_ID_WPX, NoAuth: true}}))