updated to support ssh by port number
This commit is contained in:
168
server-proxy.go
168
server-proxy.go
@@ -81,7 +81,7 @@ func delete_hop_by_hop_headers(header http.Header) {
|
||||
}
|
||||
}
|
||||
|
||||
func mutate_proxy_req_headers(req *http.Request, newreq *http.Request) {
|
||||
func mutate_proxy_req_headers(req *http.Request, newreq *http.Request, add_path bool) {
|
||||
var hdr http.Header
|
||||
var newhdr http.Header
|
||||
var remote_addr string
|
||||
@@ -128,6 +128,13 @@ func mutate_proxy_req_headers(req *http.Request, newreq *http.Request) {
|
||||
if !ok {
|
||||
newhdr.Set("X-Forwarded-Host", req.Host)
|
||||
}
|
||||
|
||||
if add_path {
|
||||
_, ok = newhdr["X-Forwarded-Path"]
|
||||
if !ok {
|
||||
newhdr.Set("X-Forwarded-Path", req.URL.RawPath)
|
||||
}
|
||||
}
|
||||
}
|
||||
// ------------------------------------
|
||||
|
||||
@@ -151,12 +158,12 @@ func (pxy *server_proxy_http_init) ServeHTTP(w http.ResponseWriter, req *http.Re
|
||||
conn_id = req.PathValue("conn_id")
|
||||
route_id = req.PathValue("route_id")
|
||||
|
||||
conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(conn_nid) * 8))
|
||||
conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(ConnId(0)) * 8))
|
||||
if err != nil {
|
||||
status_code = http.StatusBadRequest; w.WriteHeader(status_code)
|
||||
goto oops
|
||||
}
|
||||
route_nid, err = strconv.ParseUint(route_id, 10, int(unsafe.Sizeof(route_nid) * 8))
|
||||
route_nid, err = strconv.ParseUint(route_id, 10, int(unsafe.Sizeof(RouteId(0)) * 8))
|
||||
if err != nil {
|
||||
status_code = http.StatusBadRequest; w.WriteHeader(status_code)
|
||||
goto oops
|
||||
@@ -240,12 +247,12 @@ func (pxy *server_proxy_http_main) ServeHTTP(w http.ResponseWriter, req *http.Re
|
||||
goto oops
|
||||
}
|
||||
|
||||
conn_nid, err = strconv.ParseUint(ids[0], 10, int(unsafe.Sizeof(conn_nid) * 8))
|
||||
conn_nid, err = strconv.ParseUint(ids[0], 10, int(unsafe.Sizeof(ConnId(0)) * 8))
|
||||
if err != nil {
|
||||
status_code = http.StatusBadRequest; w.WriteHeader(status_code)
|
||||
goto oops
|
||||
}
|
||||
route_nid, err = strconv.ParseUint(ids[1], 10, int(unsafe.Sizeof(route_nid) * 8))
|
||||
route_nid, err = strconv.ParseUint(ids[1], 10, int(unsafe.Sizeof(RouteId(0)) * 8))
|
||||
if err != nil {
|
||||
status_code = http.StatusBadRequest; w.WriteHeader(status_code)
|
||||
goto oops
|
||||
@@ -309,7 +316,7 @@ func (pxy *server_proxy_http_main) ServeHTTP(w http.ResponseWriter, req *http.Re
|
||||
if httpguts.HeaderValuesContainsToken(req.Header["Connection"], "Upgrade") {
|
||||
req_upgrade_type = req.Header.Get("Upgrade")
|
||||
}
|
||||
mutate_proxy_req_headers(req, proxy_req)
|
||||
mutate_proxy_req_headers(req, proxy_req, false)
|
||||
|
||||
if httpguts.HeaderValuesContainsToken(req.Header["Te"], "trailers") {
|
||||
proxy_req.Header.Set("Te", "trailers")
|
||||
@@ -374,50 +381,106 @@ type server_proxy_xterm_session_info struct {
|
||||
}
|
||||
|
||||
func (pxy *server_proxy_xterm_file) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
var s *Server
|
||||
var status_code int
|
||||
var err error
|
||||
|
||||
defer func() {
|
||||
var err interface{} = recover()
|
||||
if err != nil { dump_call_frame_and_exit(pxy.s.log, req, err) }
|
||||
}()
|
||||
|
||||
|
||||
s = pxy.s
|
||||
|
||||
// TODO: logging
|
||||
switch pxy.file {
|
||||
case "xterm.js":
|
||||
w.Header().Set("Content-Type", "text/javascript")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
status_code = http.StatusOK; w.WriteHeader(status_code)
|
||||
w.Write(xterm_js)
|
||||
case "xterm-addon-fit.js":
|
||||
w.Header().Set("Content-Type", "text/javascript")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
status_code = http.StatusOK; w.WriteHeader(status_code)
|
||||
w.Write(xterm_addon_fit_js)
|
||||
case "xterm.css":
|
||||
w.Header().Set("Content-Type", "text/css")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
status_code = http.StatusOK; w.WriteHeader(status_code)
|
||||
w.Write(xterm_css)
|
||||
case "xterm.html":
|
||||
var tmpl *template.Template
|
||||
var err error
|
||||
var port_id string
|
||||
var conn_id string
|
||||
var route_id string
|
||||
var port_nid uint64
|
||||
var conn_nid uint64
|
||||
var route_nid uint64
|
||||
var r *ServerRoute
|
||||
|
||||
conn_id = req.PathValue("conn_id")
|
||||
route_id = req.PathValue("route_id")
|
||||
if route_id == "_" { port_id = conn_id }
|
||||
|
||||
if port_id != "" {
|
||||
port_nid, err = strconv.ParseUint(port_id, 10, int(unsafe.Sizeof(PortId(0)) * 8))
|
||||
if err != nil {
|
||||
status_code = http.StatusBadRequest; w.WriteHeader(status_code)
|
||||
goto oops
|
||||
}
|
||||
|
||||
r = s.FindServerRouteByPortId(PortId(port_nid))
|
||||
if r == nil {
|
||||
status_code = http.StatusNotFound; w.WriteHeader(status_code)
|
||||
goto done
|
||||
}
|
||||
} else {
|
||||
conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(ConnId(0)) * 8))
|
||||
if err != nil {
|
||||
status_code = http.StatusBadRequest; w.WriteHeader(status_code)
|
||||
goto oops
|
||||
}
|
||||
route_nid, err = strconv.ParseUint(route_id, 10, int(unsafe.Sizeof(RouteId(0)) * 8))
|
||||
if err != nil {
|
||||
status_code = http.StatusBadRequest; w.WriteHeader(status_code)
|
||||
goto oops
|
||||
}
|
||||
|
||||
r = s.FindServerRouteById(ConnId(conn_nid), RouteId(route_nid))
|
||||
if r == nil {
|
||||
status_code = http.StatusNotFound; w.WriteHeader(status_code)
|
||||
goto done
|
||||
}
|
||||
}
|
||||
|
||||
tmpl = template.New("")
|
||||
_, err = tmpl.Parse(string(xterm_html))
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
status_code = http.StatusInternalServerError; w.WriteHeader(status_code)
|
||||
goto oops
|
||||
} else {
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
tmpl.Execute(w,
|
||||
&server_proxy_xterm_session_info{
|
||||
ConnId: req.PathValue("conn_id"),
|
||||
RouteId: req.PathValue("route_id"),
|
||||
ConnId: conn_id,
|
||||
RouteId: route_id,
|
||||
})
|
||||
}
|
||||
case "_forbidden":
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
status_code = http.StatusForbidden; w.WriteHeader(status_code)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
status_code = http.StatusNotFound; w.WriteHeader(status_code)
|
||||
}
|
||||
|
||||
// TODO: logging..
|
||||
|
||||
done:
|
||||
s.log.Write("", LOG_INFO, "[%s] %s %s %d", req.RemoteAddr, req.Method, req.URL.String(), status_code)
|
||||
return
|
||||
|
||||
oops:
|
||||
s.log.Write("", LOG_ERROR, "[%s] %s %s %d - %s", req.RemoteAddr, req.Method, req.URL.String(), status_code, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// ------------------------------------
|
||||
|
||||
type server_proxy_ssh_ws struct {
|
||||
@@ -511,9 +574,11 @@ oops:
|
||||
func (pxy *server_proxy_ssh_ws) ServeWebsocket(ws *websocket.Conn) {
|
||||
var s *Server
|
||||
var req *http.Request
|
||||
var port_id string
|
||||
var conn_id string
|
||||
var conn_nid uint64
|
||||
var route_id string
|
||||
var port_nid uint64
|
||||
var conn_nid uint64
|
||||
var route_nid uint64
|
||||
var r *ServerRoute
|
||||
var username string
|
||||
@@ -539,24 +604,42 @@ func (pxy *server_proxy_ssh_ws) ServeWebsocket(ws *websocket.Conn) {
|
||||
|
||||
conn_id = req.PathValue("conn_id")
|
||||
route_id = req.PathValue("route_id")
|
||||
if route_id == "_" { port_id = conn_id }
|
||||
|
||||
conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(conn_nid) * 8))
|
||||
if err != nil {
|
||||
// TODO:
|
||||
goto done
|
||||
}
|
||||
route_nid, err = strconv.ParseUint(route_id, 10, int(unsafe.Sizeof(route_nid) * 8))
|
||||
if err != nil {
|
||||
// TODO:
|
||||
goto done
|
||||
}
|
||||
if port_id != "" {
|
||||
port_nid, err = strconv.ParseUint(port_id, 10, int(unsafe.Sizeof(PortId(0)) * 8))
|
||||
if err != nil {
|
||||
err = fmt.Errorf("invalid port id - %s", port_id)
|
||||
pxy.send_ws_data(ws, "error", err.Error())
|
||||
goto done
|
||||
}
|
||||
|
||||
r = s.FindServerRouteById(ConnId(conn_nid), RouteId(route_nid))
|
||||
if r == nil {
|
||||
// TODO: enhance logging. original request, conn_nid, route_nid
|
||||
pxy.send_ws_data(ws, "error", fmt.Sprintf("route(%d,%d) not found", conn_nid, route_nid))
|
||||
s.log.Write("", LOG_ERROR, "No server route(%d,%d) found", conn_nid, route_nid)
|
||||
goto done
|
||||
r = s.FindServerRouteByPortId(PortId(port_nid))
|
||||
if r == nil {
|
||||
err = fmt.Errorf("port(%d) not found", conn_nid, route_nid)
|
||||
pxy.send_ws_data(ws, "error", err.Error())
|
||||
goto done
|
||||
}
|
||||
} else {
|
||||
conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(ConnId(0)) * 8))
|
||||
if err != nil {
|
||||
err = fmt.Errorf("invalid connection id - %s", conn_id)
|
||||
pxy.send_ws_data(ws, "error", err.Error())
|
||||
goto done
|
||||
}
|
||||
route_nid, err = strconv.ParseUint(route_id, 10, int(unsafe.Sizeof(RouteId(0)) * 8))
|
||||
if err != nil {
|
||||
err = fmt.Errorf("invalid route id - %s", route_id)
|
||||
pxy.send_ws_data(ws, "error", err.Error())
|
||||
goto done
|
||||
}
|
||||
|
||||
r = s.FindServerRouteById(ConnId(conn_nid), RouteId(route_nid))
|
||||
if r == nil {
|
||||
err = fmt.Errorf("route(%d,%d) not found", conn_nid, route_nid)
|
||||
pxy.send_ws_data(ws, "error", err.Error())
|
||||
goto done
|
||||
}
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
@@ -598,11 +681,8 @@ ws_recv_loop:
|
||||
for {
|
||||
var msg []byte
|
||||
err = websocket.Message.Receive(ws, &msg)
|
||||
if err != nil {
|
||||
// TODO: check if EOF
|
||||
s.log.Write("", LOG_ERROR, "Failed to read from websocket - %s", err.Error())
|
||||
goto done
|
||||
}
|
||||
if err != nil { goto done }
|
||||
|
||||
if len(msg) > 0 {
|
||||
var ev json_ssh_ws_event
|
||||
err = json.Unmarshal(msg, &ev)
|
||||
@@ -669,9 +749,7 @@ ws_recv_loop:
|
||||
|
||||
if sess != nil {
|
||||
err = pxy.send_ws_data(ws, "status", "closed")
|
||||
if err != nil {
|
||||
s.log.Write("", LOG_ERROR, "Failed to write closed event to websocket - %s", err.Error())
|
||||
}
|
||||
if err != nil { goto done }
|
||||
}
|
||||
|
||||
done:
|
||||
@@ -680,5 +758,9 @@ done:
|
||||
if sess != nil { sess.Close() }
|
||||
if c != nil { c.Close() }
|
||||
wg.Wait()
|
||||
s.log.Write("", LOG_DEBUG, "[%s] %s %s - ended", req.RemoteAddr, req.Method, req.URL.String())
|
||||
if err != nil {
|
||||
s.log.Write("", LOG_ERROR, "[%s] %s %s - %s", req.RemoteAddr, req.Method, req.URL.String(), err.Error())
|
||||
} else {
|
||||
s.log.Write("", LOG_DEBUG, "[%s] %s %s - ended", req.RemoteAddr, req.Method, req.URL.String())
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user