enhanced the websocket endpoints to return failure for normal https packets

This commit is contained in:
2025-06-21 22:01:24 +09:00
parent c5bac71eaf
commit 7835696166
2 changed files with 39 additions and 10 deletions

View File

@ -10,6 +10,7 @@ import "net"
import "net/http"
import "slices"
import "strconv"
import "strings"
import "sync"
import "sync/atomic"
import "time"
@ -1538,7 +1539,20 @@ func (c *Client) WrapHttpHandler(handler ClientHttpHandler) http.Handler {
})
}
func (s *Client) WrapWebsocketHandler(handler ClientWebsocketHandler) websocket.Handler {
func (c *Client) SafeWrapWebsocketHandler(handler websocket.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if !strings.EqualFold(req.Header.Get("Upgrade"), "websocket") ||
!strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade") {
var status_code int
status_code = WriteEmptyRespHeader(w, http.StatusBadRequest)
c.log.Write("", LOG_INFO, "[%s] %s %s %d[non-websocket]", req.RemoteAddr, req.Method, req.URL.String(), status_code)
return
}
handler.ServeHTTP(w, req)
})
}
func (c *Client) WrapWebsocketHandler(handler ClientWebsocketHandler) websocket.Handler {
return websocket.Handler(func(ws *websocket.Conn) {
var status_code int
var err error
@ -1547,7 +1561,7 @@ func (s *Client) WrapWebsocketHandler(handler ClientWebsocketHandler) websocket.
var req *http.Request
req = ws.Request()
s.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s [ws]", req.RemoteAddr, req.Method, req.URL.String())
c.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s [ws]", req.RemoteAddr, req.Method, req.URL.String())
start_time = time.Now()
status_code, err = handler.ServeWebsocket(ws)
@ -1555,9 +1569,9 @@ func (s *Client) WrapWebsocketHandler(handler ClientWebsocketHandler) websocket.
if status_code > 0 {
if err != nil {
s.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s [ws] %d %.9f - %s", req.RemoteAddr, req.Method, req.URL.String(), status_code, time_taken.Seconds(), err.Error())
c.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s [ws] %d %.9f - %s", req.RemoteAddr, req.Method, req.URL.String(), status_code, time_taken.Seconds(), err.Error())
} else {
s.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s [ws] %d %.9f", req.RemoteAddr, req.Method, req.URL.String(), status_code, time_taken.Seconds())
c.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s [ws] %d %.9f", req.RemoteAddr, req.Method, req.URL.String(), status_code, time_taken.Seconds())
}
}
})
@ -1630,10 +1644,11 @@ func NewClient(ctx context.Context, name string, logger Logger, cfg *ClientConfi
promhttp.HandlerFor(c.promreg, promhttp.HandlerOpts{ EnableOpenMetrics: true }))
c.ctl_mux.Handle("/_ctl/events",
c.WrapWebsocketHandler(&client_ctl_ws{client_ctl{c: &c, id: HS_ID_CTL}}))
c.SafeWrapWebsocketHandler(c.WrapWebsocketHandler(&client_ctl_ws{client_ctl{c: &c, id: HS_ID_CTL}})))
c.ctl_mux.Handle("/_pts/ws", c.WrapWebsocketHandler(&client_pts_ws{C: &c, Id: HS_ID_CTL}))
c.ctl_mux.Handle("/_pts/ws",
c.SafeWrapWebsocketHandler(c.WrapWebsocketHandler(&client_pts_ws{C: &c, Id: HS_ID_CTL})))
c.ctl_mux.Handle("/_pts/xterm.js",
c.WrapHttpHandler(&client_pts_xterm_file{client_ctl: client_ctl{c: &c, id: HS_ID_CTL}, file: "xterm.js"}))