enhanced the websocket endpoints to return failure for normal https packets

This commit is contained in:
hyung-hwan 2025-06-21 22:01:24 +09:00
parent c5bac71eaf
commit 7835696166
2 changed files with 39 additions and 10 deletions

View File

@ -10,6 +10,7 @@ import "net"
import "net/http"
import "slices"
import "strconv"
import "strings"
import "sync"
import "sync/atomic"
import "time"
@ -1538,7 +1539,20 @@ func (c *Client) WrapHttpHandler(handler ClientHttpHandler) http.Handler {
})
}
func (s *Client) WrapWebsocketHandler(handler ClientWebsocketHandler) websocket.Handler {
func (c *Client) SafeWrapWebsocketHandler(handler websocket.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if !strings.EqualFold(req.Header.Get("Upgrade"), "websocket") ||
!strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade") {
var status_code int
status_code = WriteEmptyRespHeader(w, http.StatusBadRequest)
c.log.Write("", LOG_INFO, "[%s] %s %s %d[non-websocket]", req.RemoteAddr, req.Method, req.URL.String(), status_code)
return
}
handler.ServeHTTP(w, req)
})
}
func (c *Client) WrapWebsocketHandler(handler ClientWebsocketHandler) websocket.Handler {
return websocket.Handler(func(ws *websocket.Conn) {
var status_code int
var err error
@ -1547,7 +1561,7 @@ func (s *Client) WrapWebsocketHandler(handler ClientWebsocketHandler) websocket.
var req *http.Request
req = ws.Request()
s.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s [ws]", req.RemoteAddr, req.Method, req.URL.String())
c.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s [ws]", req.RemoteAddr, req.Method, req.URL.String())
start_time = time.Now()
status_code, err = handler.ServeWebsocket(ws)
@ -1555,9 +1569,9 @@ func (s *Client) WrapWebsocketHandler(handler ClientWebsocketHandler) websocket.
if status_code > 0 {
if err != nil {
s.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s [ws] %d %.9f - %s", req.RemoteAddr, req.Method, req.URL.String(), status_code, time_taken.Seconds(), err.Error())
c.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s [ws] %d %.9f - %s", req.RemoteAddr, req.Method, req.URL.String(), status_code, time_taken.Seconds(), err.Error())
} else {
s.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s [ws] %d %.9f", req.RemoteAddr, req.Method, req.URL.String(), status_code, time_taken.Seconds())
c.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s [ws] %d %.9f", req.RemoteAddr, req.Method, req.URL.String(), status_code, time_taken.Seconds())
}
}
})
@ -1630,10 +1644,11 @@ func NewClient(ctx context.Context, name string, logger Logger, cfg *ClientConfi
promhttp.HandlerFor(c.promreg, promhttp.HandlerOpts{ EnableOpenMetrics: true }))
c.ctl_mux.Handle("/_ctl/events",
c.WrapWebsocketHandler(&client_ctl_ws{client_ctl{c: &c, id: HS_ID_CTL}}))
c.SafeWrapWebsocketHandler(c.WrapWebsocketHandler(&client_ctl_ws{client_ctl{c: &c, id: HS_ID_CTL}})))
c.ctl_mux.Handle("/_pts/ws", c.WrapWebsocketHandler(&client_pts_ws{C: &c, Id: HS_ID_CTL}))
c.ctl_mux.Handle("/_pts/ws",
c.SafeWrapWebsocketHandler(c.WrapWebsocketHandler(&client_pts_ws{C: &c, Id: HS_ID_CTL})))
c.ctl_mux.Handle("/_pts/xterm.js",
c.WrapHttpHandler(&client_pts_xterm_file{client_ctl: client_ctl{c: &c, id: HS_ID_CTL}, file: "xterm.js"}))

View File

@ -12,6 +12,7 @@ import "net/http"
import "net/netip"
import "slices"
import "strconv"
import "strings"
import "sync"
import "sync/atomic"
import "time"
@ -1233,6 +1234,19 @@ func (s *Server) WrapHttpHandler(handler ServerHttpHandler) http.Handler {
})
}
func (s *Server) SafeWrapWebsocketHandler(handler websocket.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if !strings.EqualFold(req.Header.Get("Upgrade"), "websocket") ||
!strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade") {
var status_code int
status_code = WriteEmptyRespHeader(w, http.StatusBadRequest)
s.log.Write("", LOG_INFO, "[%s] %s %s %d[non-websocket]", req.RemoteAddr, req.Method, req.URL.String(), status_code)
return
}
handler.ServeHTTP(w, req)
})
}
func (s *Server) WrapWebsocketHandler(handler ServerWebsocketHandler) websocket.Handler {
return websocket.Handler(func(ws *websocket.Conn) {
var status_code int
@ -1360,12 +1374,12 @@ func NewServer(ctx context.Context, name string, logger Logger, cfg *ServerConfi
promhttp.HandlerFor(s.promreg, promhttp.HandlerOpts{ EnableOpenMetrics: true }))
s.ctl_mux.Handle("/_ctl/events",
s.WrapWebsocketHandler(&server_ctl_ws{ServerCtl{S: &s, Id: HS_ID_CTL}}))
s.SafeWrapWebsocketHandler(s.WrapWebsocketHandler(&server_ctl_ws{ServerCtl{S: &s, Id: HS_ID_CTL}})))
/*
// this part is duplcate of pxy_mux.
s.ctl_mux.Handle("/_ssh/ws/{conn_id}/{route_id}",
s.WrapWebsocketHandler(&server_pxy_ssh_ws{S: &s, Id: HS_ID_PXY_WS}))
s.SafeWrapWebsocketHandler(s.WrapWebsocketHandler(&server_pxy_ssh_ws{S: &s, Id: HS_ID_PXY_WS})))
s.ctl_mux.Handle("/_ssh/server-conns/{conn_id}/routes/{route_id}",
s.WrapHttpHandler(&server_ctl_server_conns_id_routes_id{ServerCtl{S: &s, Id: HS_ID_CTL, NoAuth: true}}))
s.ctl_mux.Handle("/_ssh/xterm.js",
@ -1408,7 +1422,7 @@ func NewServer(ctx context.Context, name string, logger Logger, cfg *ServerConfi
s.pxy_mux = http.NewServeMux() // TODO: make /_init,_ssh,_ssh/ws,_http configurable...
s.pxy_mux.Handle("/_ssh/ws/{conn_id}/{route_id}",
s.WrapWebsocketHandler(&server_pxy_ssh_ws{S: &s, Id: HS_ID_PXY_WS}))
s.SafeWrapWebsocketHandler(s.WrapWebsocketHandler(&server_pxy_ssh_ws{S: &s, Id: HS_ID_PXY_WS})))
s.pxy_mux.Handle("/_ssh/server-conns/{conn_id}/routes/{route_id}",
s.WrapHttpHandler(&server_ctl_server_conns_id_routes_id{ServerCtl{S: &s, Id: HS_ID_PXY, NoAuth: true}}))
s.pxy_mux.Handle("/_ssh/xterm.js",
@ -1453,7 +1467,7 @@ func NewServer(ctx context.Context, name string, logger Logger, cfg *ServerConfi
s.wpx_mux = http.NewServeMux() // TODO: make /_init,_ssh,_ssh/ws,_http configurable...
s.wpx_mux.Handle("/_ssh/ws/{conn_id}/{route_id}",
s.WrapWebsocketHandler(&server_pxy_ssh_ws{S: &s, Id: "wpx-ssh"}))
s.SafeWrapWebsocketHandler(s.WrapWebsocketHandler(&server_pxy_ssh_ws{S: &s, Id: "wpx-ssh"})))
s.wpx_mux.Handle("/_ssh/server-conns/{conn_id}/routes/{route_id}",
s.WrapHttpHandler(&server_ctl_server_conns_id_routes_id{ServerCtl{S: &s, Id: HS_ID_WPX, NoAuth: true}}))