updated to support ssh by port number

This commit is contained in:
2024-12-16 15:19:01 +09:00
parent de96a75af8
commit bf2c70fa2c
7 changed files with 258 additions and 97 deletions

View File

@@ -81,7 +81,7 @@ func delete_hop_by_hop_headers(header http.Header) {
}
}
func mutate_proxy_req_headers(req *http.Request, newreq *http.Request) {
func mutate_proxy_req_headers(req *http.Request, newreq *http.Request, add_path bool) {
var hdr http.Header
var newhdr http.Header
var remote_addr string
@@ -128,6 +128,13 @@ func mutate_proxy_req_headers(req *http.Request, newreq *http.Request) {
if !ok {
newhdr.Set("X-Forwarded-Host", req.Host)
}
if add_path {
_, ok = newhdr["X-Forwarded-Path"]
if !ok {
newhdr.Set("X-Forwarded-Path", req.URL.RawPath)
}
}
}
// ------------------------------------
@@ -151,12 +158,12 @@ func (pxy *server_proxy_http_init) ServeHTTP(w http.ResponseWriter, req *http.Re
conn_id = req.PathValue("conn_id")
route_id = req.PathValue("route_id")
conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(conn_nid) * 8))
conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(ConnId(0)) * 8))
if err != nil {
status_code = http.StatusBadRequest; w.WriteHeader(status_code)
goto oops
}
route_nid, err = strconv.ParseUint(route_id, 10, int(unsafe.Sizeof(route_nid) * 8))
route_nid, err = strconv.ParseUint(route_id, 10, int(unsafe.Sizeof(RouteId(0)) * 8))
if err != nil {
status_code = http.StatusBadRequest; w.WriteHeader(status_code)
goto oops
@@ -240,12 +247,12 @@ func (pxy *server_proxy_http_main) ServeHTTP(w http.ResponseWriter, req *http.Re
goto oops
}
conn_nid, err = strconv.ParseUint(ids[0], 10, int(unsafe.Sizeof(conn_nid) * 8))
conn_nid, err = strconv.ParseUint(ids[0], 10, int(unsafe.Sizeof(ConnId(0)) * 8))
if err != nil {
status_code = http.StatusBadRequest; w.WriteHeader(status_code)
goto oops
}
route_nid, err = strconv.ParseUint(ids[1], 10, int(unsafe.Sizeof(route_nid) * 8))
route_nid, err = strconv.ParseUint(ids[1], 10, int(unsafe.Sizeof(RouteId(0)) * 8))
if err != nil {
status_code = http.StatusBadRequest; w.WriteHeader(status_code)
goto oops
@@ -309,7 +316,7 @@ func (pxy *server_proxy_http_main) ServeHTTP(w http.ResponseWriter, req *http.Re
if httpguts.HeaderValuesContainsToken(req.Header["Connection"], "Upgrade") {
req_upgrade_type = req.Header.Get("Upgrade")
}
mutate_proxy_req_headers(req, proxy_req)
mutate_proxy_req_headers(req, proxy_req, false)
if httpguts.HeaderValuesContainsToken(req.Header["Te"], "trailers") {
proxy_req.Header.Set("Te", "trailers")
@@ -374,50 +381,106 @@ type server_proxy_xterm_session_info struct {
}
func (pxy *server_proxy_xterm_file) ServeHTTP(w http.ResponseWriter, req *http.Request) {
var s *Server
var status_code int
var err error
defer func() {
var err interface{} = recover()
if err != nil { dump_call_frame_and_exit(pxy.s.log, req, err) }
}()
s = pxy.s
// TODO: logging
switch pxy.file {
case "xterm.js":
w.Header().Set("Content-Type", "text/javascript")
w.WriteHeader(http.StatusOK)
status_code = http.StatusOK; w.WriteHeader(status_code)
w.Write(xterm_js)
case "xterm-addon-fit.js":
w.Header().Set("Content-Type", "text/javascript")
w.WriteHeader(http.StatusOK)
status_code = http.StatusOK; w.WriteHeader(status_code)
w.Write(xterm_addon_fit_js)
case "xterm.css":
w.Header().Set("Content-Type", "text/css")
w.WriteHeader(http.StatusOK)
status_code = http.StatusOK; w.WriteHeader(status_code)
w.Write(xterm_css)
case "xterm.html":
var tmpl *template.Template
var err error
var port_id string
var conn_id string
var route_id string
var port_nid uint64
var conn_nid uint64
var route_nid uint64
var r *ServerRoute
conn_id = req.PathValue("conn_id")
route_id = req.PathValue("route_id")
if route_id == "_" { port_id = conn_id }
if port_id != "" {
port_nid, err = strconv.ParseUint(port_id, 10, int(unsafe.Sizeof(PortId(0)) * 8))
if err != nil {
status_code = http.StatusBadRequest; w.WriteHeader(status_code)
goto oops
}
r = s.FindServerRouteByPortId(PortId(port_nid))
if r == nil {
status_code = http.StatusNotFound; w.WriteHeader(status_code)
goto done
}
} else {
conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(ConnId(0)) * 8))
if err != nil {
status_code = http.StatusBadRequest; w.WriteHeader(status_code)
goto oops
}
route_nid, err = strconv.ParseUint(route_id, 10, int(unsafe.Sizeof(RouteId(0)) * 8))
if err != nil {
status_code = http.StatusBadRequest; w.WriteHeader(status_code)
goto oops
}
r = s.FindServerRouteById(ConnId(conn_nid), RouteId(route_nid))
if r == nil {
status_code = http.StatusNotFound; w.WriteHeader(status_code)
goto done
}
}
tmpl = template.New("")
_, err = tmpl.Parse(string(xterm_html))
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
status_code = http.StatusInternalServerError; w.WriteHeader(status_code)
goto oops
} else {
w.Header().Set("Content-Type", "text/html")
w.WriteHeader(http.StatusOK)
tmpl.Execute(w,
&server_proxy_xterm_session_info{
ConnId: req.PathValue("conn_id"),
RouteId: req.PathValue("route_id"),
ConnId: conn_id,
RouteId: route_id,
})
}
case "_forbidden":
w.WriteHeader(http.StatusForbidden)
status_code = http.StatusForbidden; w.WriteHeader(status_code)
default:
w.WriteHeader(http.StatusNotFound)
status_code = http.StatusNotFound; w.WriteHeader(status_code)
}
// TODO: logging..
done:
s.log.Write("", LOG_INFO, "[%s] %s %s %d", req.RemoteAddr, req.Method, req.URL.String(), status_code)
return
oops:
s.log.Write("", LOG_ERROR, "[%s] %s %s %d - %s", req.RemoteAddr, req.Method, req.URL.String(), status_code, err.Error())
return
}
// ------------------------------------
type server_proxy_ssh_ws struct {
@@ -511,9 +574,11 @@ oops:
func (pxy *server_proxy_ssh_ws) ServeWebsocket(ws *websocket.Conn) {
var s *Server
var req *http.Request
var port_id string
var conn_id string
var conn_nid uint64
var route_id string
var port_nid uint64
var conn_nid uint64
var route_nid uint64
var r *ServerRoute
var username string
@@ -539,24 +604,42 @@ func (pxy *server_proxy_ssh_ws) ServeWebsocket(ws *websocket.Conn) {
conn_id = req.PathValue("conn_id")
route_id = req.PathValue("route_id")
if route_id == "_" { port_id = conn_id }
conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(conn_nid) * 8))
if err != nil {
// TODO:
goto done
}
route_nid, err = strconv.ParseUint(route_id, 10, int(unsafe.Sizeof(route_nid) * 8))
if err != nil {
// TODO:
goto done
}
if port_id != "" {
port_nid, err = strconv.ParseUint(port_id, 10, int(unsafe.Sizeof(PortId(0)) * 8))
if err != nil {
err = fmt.Errorf("invalid port id - %s", port_id)
pxy.send_ws_data(ws, "error", err.Error())
goto done
}
r = s.FindServerRouteById(ConnId(conn_nid), RouteId(route_nid))
if r == nil {
// TODO: enhance logging. original request, conn_nid, route_nid
pxy.send_ws_data(ws, "error", fmt.Sprintf("route(%d,%d) not found", conn_nid, route_nid))
s.log.Write("", LOG_ERROR, "No server route(%d,%d) found", conn_nid, route_nid)
goto done
r = s.FindServerRouteByPortId(PortId(port_nid))
if r == nil {
err = fmt.Errorf("port(%d) not found", conn_nid, route_nid)
pxy.send_ws_data(ws, "error", err.Error())
goto done
}
} else {
conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(ConnId(0)) * 8))
if err != nil {
err = fmt.Errorf("invalid connection id - %s", conn_id)
pxy.send_ws_data(ws, "error", err.Error())
goto done
}
route_nid, err = strconv.ParseUint(route_id, 10, int(unsafe.Sizeof(RouteId(0)) * 8))
if err != nil {
err = fmt.Errorf("invalid route id - %s", route_id)
pxy.send_ws_data(ws, "error", err.Error())
goto done
}
r = s.FindServerRouteById(ConnId(conn_nid), RouteId(route_nid))
if r == nil {
err = fmt.Errorf("route(%d,%d) not found", conn_nid, route_nid)
pxy.send_ws_data(ws, "error", err.Error())
goto done
}
}
wg.Add(1)
@@ -598,11 +681,8 @@ ws_recv_loop:
for {
var msg []byte
err = websocket.Message.Receive(ws, &msg)
if err != nil {
// TODO: check if EOF
s.log.Write("", LOG_ERROR, "Failed to read from websocket - %s", err.Error())
goto done
}
if err != nil { goto done }
if len(msg) > 0 {
var ev json_ssh_ws_event
err = json.Unmarshal(msg, &ev)
@@ -669,9 +749,7 @@ ws_recv_loop:
if sess != nil {
err = pxy.send_ws_data(ws, "status", "closed")
if err != nil {
s.log.Write("", LOG_ERROR, "Failed to write closed event to websocket - %s", err.Error())
}
if err != nil { goto done }
}
done:
@@ -680,5 +758,9 @@ done:
if sess != nil { sess.Close() }
if c != nil { c.Close() }
wg.Wait()
s.log.Write("", LOG_DEBUG, "[%s] %s %s - ended", req.RemoteAddr, req.Method, req.URL.String())
if err != nil {
s.log.Write("", LOG_ERROR, "[%s] %s %s - %s", req.RemoteAddr, req.Method, req.URL.String(), err.Error())
} else {
s.log.Write("", LOG_DEBUG, "[%s] %s %s - ended", req.RemoteAddr, req.Method, req.URL.String())
}
}