diff --git a/Makefile b/Makefile
index 2cdee92..680eed9 100644
--- a/Makefile
+++ b/Makefile
@@ -1,3 +1,10 @@
+# make
+# make GOARCH=386
+# make GOARCH=amd64
+# make GOOS=linux GOARCH=mips
+#
+# 'go tool dist list' for available os and architextures
+
NAME=hodu
VERSION=1.0.0
diff --git a/client-ctl.go b/client-ctl.go
index d7bc32e..7186a13 100644
--- a/client-ctl.go
+++ b/client-ctl.go
@@ -248,7 +248,7 @@ func (ctl *client_ctl_client_conns_id) ServeHTTP(w http.ResponseWriter, req *htt
conn_id = req.PathValue("conn_id")
- conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(conn_nid) * 8))
+ conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(ConnId(0)) * 8))
if err != nil {
status_code = http.StatusBadRequest; w.WriteHeader(status_code)
if err = je.Encode(json_errmsg{Text: "wrong connection id - " + conn_id}); err != nil { goto oops }
@@ -333,7 +333,7 @@ func (ctl *client_ctl_client_conns_id_routes) ServeHTTP(w http.ResponseWriter, r
conn_id = req.PathValue("conn_id")
- conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(conn_nid) * 8))
+ conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(ConnId(0)) * 8))
if err != nil {
status_code = http.StatusBadRequest; w.WriteHeader(status_code)
if err = je.Encode(json_errmsg{Text: "wrong connection id - " + conn_id }); err != nil { goto oops }
@@ -456,13 +456,13 @@ func (ctl *client_ctl_client_conns_id_routes_id) ServeHTTP(w http.ResponseWriter
conn_id = req.PathValue("conn_id")
route_id = req.PathValue("route_id")
- conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(conn_nid) * 8))
+ conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(ConnId(0)) * 8))
if err != nil {
status_code = http.StatusBadRequest; w.WriteHeader(status_code)
if err = je.Encode(json_errmsg{Text: "wrong connection id - " + conn_id}); err != nil { goto oops }
goto done
}
- route_nid, err = strconv.ParseUint(route_id, 10, int(unsafe.Sizeof(route_nid) * 8))
+ route_nid, err = strconv.ParseUint(route_id, 10, int(unsafe.Sizeof(RouteId(0)) * 8))
if err != nil {
status_code = http.StatusBadRequest; w.WriteHeader(status_code)
if err = je.Encode(json_errmsg{Text: "wrong route id - " + route_id}); err != nil { goto oops }
@@ -538,13 +538,13 @@ func (ctl *client_ctl_client_conns_id_routes_id_peers) ServeHTTP(w http.Response
conn_id = req.PathValue("conn_id")
route_id = req.PathValue("route_id")
- conn_nid, err = strconv.ParseUint(conn_id, 10, 32)
+ conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(ConnId(0)) * 8))
if err != nil {
status_code = http.StatusBadRequest; w.WriteHeader(status_code)
if err = je.Encode(json_errmsg{Text: "wrong connection id - " + conn_id}); err != nil { goto oops }
goto done
}
- route_nid, err = strconv.ParseUint(route_id, 10, int(unsafe.Sizeof(route_nid) * 8))
+ route_nid, err = strconv.ParseUint(route_id, 10, int(unsafe.Sizeof(RouteId(0)) * 8))
if err != nil {
status_code = http.StatusBadRequest; w.WriteHeader(status_code)
if err = je.Encode(json_errmsg{Text: "wrong route id - " + route_id}); err != nil { goto oops }
@@ -622,19 +622,19 @@ func (ctl *client_ctl_client_conns_id_routes_id_peers_id) ServeHTTP(w http.Respo
route_id = req.PathValue("route_id")
peer_id = req.PathValue("peer_id")
- conn_nid, err = strconv.ParseUint(conn_id, 10, 32)
+ conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(ConnId(0)) * 8))
if err != nil {
status_code = http.StatusBadRequest; w.WriteHeader(status_code)
if err = je.Encode(json_errmsg{Text: "wrong connection id - " + conn_id}); err != nil { goto oops }
goto done
}
- route_nid, err = strconv.ParseUint(route_id, 10, int(unsafe.Sizeof(route_nid) * 8))
+ route_nid, err = strconv.ParseUint(route_id, 10, int(unsafe.Sizeof(RouteId(0)) * 8))
if err != nil {
status_code = http.StatusBadRequest; w.WriteHeader(status_code)
if err = je.Encode(json_errmsg{Text: "wrong route id - " + route_id}); err != nil { goto oops }
goto done
}
- peer_nid, err = strconv.ParseUint(peer_id, 10, 32)
+ peer_nid, err = strconv.ParseUint(peer_id, 10, int(unsafe.Sizeof(ConnId(0)) * 8))
if err != nil {
status_code = http.StatusBadRequest; w.WriteHeader(status_code)
if err = je.Encode(json_errmsg{Text: "wrong peer id - " + peer_id}); err != nil { goto oops }
diff --git a/packet.go b/packet.go
index 9111a94..f962725 100644
--- a/packet.go
+++ b/packet.go
@@ -1,10 +1,15 @@
package hodu
-type ConnId uint64
-type RouteId uint32 // keep this in sync with the type of RouteId in hodu.proto
-type PeerId uint32 // keep this in sync with the type of RouteId in hodu.proto
+type ConnId uint64
+type RouteId uint32 // keep this in sync with the type of RouteId in hodu.proto
+type PeerId uint32 // keep this in sync with the type of RouteId in hodu.proto
type RouteOption uint32
+type ConnRouteId struct {
+ conn_id ConnId
+ route_id RouteId
+}
+
func MakeRouteStartPacket(route_id RouteId, proto RouteOption, ptc_addr string, ptc_name string, svc_addr string, svc_net string) *Packet {
return &Packet{
Kind: PACKET_KIND_ROUTE_START,
diff --git a/server-ctl.go b/server-ctl.go
index 0b8ab92..d78b4cc 100644
--- a/server-ctl.go
+++ b/server-ctl.go
@@ -155,7 +155,7 @@ func (ctl *server_ctl_server_conns_id) ServeHTTP(w http.ResponseWriter, req *htt
conn_id = req.PathValue("conn_id")
- conn_nid, err = strconv.ParseUint(conn_id, 10, 32)
+ conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(ConnId(0)) * 8))
if err != nil {
status_code = http.StatusBadRequest; w.WriteHeader(status_code)
if err = je.Encode(json_errmsg{Text: "wrong connection id - " + conn_id}); err != nil { goto oops }
@@ -237,7 +237,7 @@ func (ctl *server_ctl_server_conns_id_routes) ServeHTTP(w http.ResponseWriter, r
conn_id = req.PathValue("conn_id")
- conn_nid, err = strconv.ParseUint(conn_id, 10, 32)
+ conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(ConnId(0)) * 8))
if err != nil {
status_code = http.StatusBadRequest; w.WriteHeader(status_code)
if err = je.Encode(json_errmsg{Text: "wrong connection id - " + conn_id}); err != nil { goto oops }
@@ -296,14 +296,15 @@ oops:
func (ctl *server_ctl_server_conns_id_routes_id) ServeHTTP(w http.ResponseWriter, req *http.Request) {
var s *Server
var status_code int
- var err error
+ var port_id string
var conn_id string
- var conn_nid uint64
var route_id string
+ var port_nid uint64
+ var conn_nid uint64
var route_nid uint64
var je *json.Encoder
- var cts *ServerConn
var r *ServerRoute
+ var err error
defer func() {
var err interface{} = recover()
@@ -315,32 +316,54 @@ func (ctl *server_ctl_server_conns_id_routes_id) ServeHTTP(w http.ResponseWriter
conn_id = req.PathValue("conn_id")
route_id = req.PathValue("route_id")
-
- conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(conn_nid) * 8))
- if err != nil {
- status_code = http.StatusBadRequest; w.WriteHeader(status_code)
- if err = je.Encode(json_errmsg{Text: "wrong connection id - " + conn_id}); err != nil { goto oops }
- goto done
- }
- route_nid, err = strconv.ParseUint(route_id, 10, int(unsafe.Sizeof(route_nid) * 8))
- if err != nil {
- status_code = http.StatusBadRequest; w.WriteHeader(status_code)
- if err = je.Encode(json_errmsg{Text: "wrong route id - " + route_id}); err != nil { goto oops }
- goto done
+ if route_id == "_" {
+ port_id = conn_id
+ route_id = ""
}
- cts = s.FindServerConnById(ConnId(conn_nid))
- if cts == nil {
- status_code = http.StatusNotFound; w.WriteHeader(status_code)
- if err = je.Encode(json_errmsg{Text: "non-existent connection id - " + conn_id}); err != nil { goto oops }
- goto done
- }
+ if port_id != "" {
+ // this condition is for ssh proxy server.
+ port_nid, err = strconv.ParseUint(port_id, 10, int(unsafe.Sizeof(PortId(0)) * 8))
+ if err != nil {
+ status_code = http.StatusBadRequest; w.WriteHeader(status_code)
+ if err = je.Encode(json_errmsg{Text: "wrong port id - " + port_id}); err != nil { goto oops }
+ goto done
+ }
- r = cts.FindServerRouteById(RouteId(route_nid))
- if r == nil {
- status_code = http.StatusNotFound; w.WriteHeader(status_code)
- if err = je.Encode(json_errmsg{Text: "non-existent route id - " + conn_id}); err != nil { goto oops }
- goto done
+ r = s.FindServerRouteByPortId(PortId(port_nid))
+ if r == nil {
+ status_code = http.StatusNotFound; w.WriteHeader(status_code)
+ if err = je.Encode(json_errmsg{Text: "non-existent port id - " + port_id}); err != nil { goto oops }
+ }
+ } else {
+ var cts *ServerConn
+
+ conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(ConnId(0)) * 8))
+ if err != nil {
+ status_code = http.StatusBadRequest; w.WriteHeader(status_code)
+ if err = je.Encode(json_errmsg{Text: "wrong connection id - " + conn_id}); err != nil { goto oops }
+ goto done
+ }
+ route_nid, err = strconv.ParseUint(route_id, 10, int(unsafe.Sizeof(RouteId(0)) * 8))
+ if err != nil {
+ status_code = http.StatusBadRequest; w.WriteHeader(status_code)
+ if err = je.Encode(json_errmsg{Text: "wrong route id - " + route_id}); err != nil { goto oops }
+ goto done
+ }
+
+ cts = s.FindServerConnById(ConnId(conn_nid))
+ if cts == nil {
+ status_code = http.StatusNotFound; w.WriteHeader(status_code)
+ if err = je.Encode(json_errmsg{Text: "non-existent connection id - " + conn_id}); err != nil { goto oops }
+ goto done
+ }
+
+ r = cts.FindServerRouteById(RouteId(route_nid))
+ if r == nil {
+ status_code = http.StatusNotFound; w.WriteHeader(status_code)
+ if err = je.Encode(json_errmsg{Text: "non-existent route id - " + conn_id}); err != nil { goto oops }
+ goto done
+ }
}
switch req.Method {
diff --git a/server-proxy.go b/server-proxy.go
index e1c6d86..dc87bdf 100644
--- a/server-proxy.go
+++ b/server-proxy.go
@@ -81,7 +81,7 @@ func delete_hop_by_hop_headers(header http.Header) {
}
}
-func mutate_proxy_req_headers(req *http.Request, newreq *http.Request) {
+func mutate_proxy_req_headers(req *http.Request, newreq *http.Request, add_path bool) {
var hdr http.Header
var newhdr http.Header
var remote_addr string
@@ -128,6 +128,13 @@ func mutate_proxy_req_headers(req *http.Request, newreq *http.Request) {
if !ok {
newhdr.Set("X-Forwarded-Host", req.Host)
}
+
+ if add_path {
+ _, ok = newhdr["X-Forwarded-Path"]
+ if !ok {
+ newhdr.Set("X-Forwarded-Path", req.URL.RawPath)
+ }
+ }
}
// ------------------------------------
@@ -151,12 +158,12 @@ func (pxy *server_proxy_http_init) ServeHTTP(w http.ResponseWriter, req *http.Re
conn_id = req.PathValue("conn_id")
route_id = req.PathValue("route_id")
- conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(conn_nid) * 8))
+ conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(ConnId(0)) * 8))
if err != nil {
status_code = http.StatusBadRequest; w.WriteHeader(status_code)
goto oops
}
- route_nid, err = strconv.ParseUint(route_id, 10, int(unsafe.Sizeof(route_nid) * 8))
+ route_nid, err = strconv.ParseUint(route_id, 10, int(unsafe.Sizeof(RouteId(0)) * 8))
if err != nil {
status_code = http.StatusBadRequest; w.WriteHeader(status_code)
goto oops
@@ -240,12 +247,12 @@ func (pxy *server_proxy_http_main) ServeHTTP(w http.ResponseWriter, req *http.Re
goto oops
}
- conn_nid, err = strconv.ParseUint(ids[0], 10, int(unsafe.Sizeof(conn_nid) * 8))
+ conn_nid, err = strconv.ParseUint(ids[0], 10, int(unsafe.Sizeof(ConnId(0)) * 8))
if err != nil {
status_code = http.StatusBadRequest; w.WriteHeader(status_code)
goto oops
}
- route_nid, err = strconv.ParseUint(ids[1], 10, int(unsafe.Sizeof(route_nid) * 8))
+ route_nid, err = strconv.ParseUint(ids[1], 10, int(unsafe.Sizeof(RouteId(0)) * 8))
if err != nil {
status_code = http.StatusBadRequest; w.WriteHeader(status_code)
goto oops
@@ -309,7 +316,7 @@ func (pxy *server_proxy_http_main) ServeHTTP(w http.ResponseWriter, req *http.Re
if httpguts.HeaderValuesContainsToken(req.Header["Connection"], "Upgrade") {
req_upgrade_type = req.Header.Get("Upgrade")
}
- mutate_proxy_req_headers(req, proxy_req)
+ mutate_proxy_req_headers(req, proxy_req, false)
if httpguts.HeaderValuesContainsToken(req.Header["Te"], "trailers") {
proxy_req.Header.Set("Te", "trailers")
@@ -374,50 +381,106 @@ type server_proxy_xterm_session_info struct {
}
func (pxy *server_proxy_xterm_file) ServeHTTP(w http.ResponseWriter, req *http.Request) {
+ var s *Server
+ var status_code int
+ var err error
+
defer func() {
var err interface{} = recover()
if err != nil { dump_call_frame_and_exit(pxy.s.log, req, err) }
}()
-
+
+ s = pxy.s
+
// TODO: logging
switch pxy.file {
case "xterm.js":
w.Header().Set("Content-Type", "text/javascript")
- w.WriteHeader(http.StatusOK)
+ status_code = http.StatusOK; w.WriteHeader(status_code)
w.Write(xterm_js)
case "xterm-addon-fit.js":
w.Header().Set("Content-Type", "text/javascript")
- w.WriteHeader(http.StatusOK)
+ status_code = http.StatusOK; w.WriteHeader(status_code)
w.Write(xterm_addon_fit_js)
case "xterm.css":
w.Header().Set("Content-Type", "text/css")
- w.WriteHeader(http.StatusOK)
+ status_code = http.StatusOK; w.WriteHeader(status_code)
w.Write(xterm_css)
case "xterm.html":
var tmpl *template.Template
- var err error
+ var port_id string
+ var conn_id string
+ var route_id string
+ var port_nid uint64
+ var conn_nid uint64
+ var route_nid uint64
+ var r *ServerRoute
+ conn_id = req.PathValue("conn_id")
+ route_id = req.PathValue("route_id")
+ if route_id == "_" { port_id = conn_id }
+
+ if port_id != "" {
+ port_nid, err = strconv.ParseUint(port_id, 10, int(unsafe.Sizeof(PortId(0)) * 8))
+ if err != nil {
+ status_code = http.StatusBadRequest; w.WriteHeader(status_code)
+ goto oops
+ }
+
+ r = s.FindServerRouteByPortId(PortId(port_nid))
+ if r == nil {
+ status_code = http.StatusNotFound; w.WriteHeader(status_code)
+ goto done
+ }
+ } else {
+ conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(ConnId(0)) * 8))
+ if err != nil {
+ status_code = http.StatusBadRequest; w.WriteHeader(status_code)
+ goto oops
+ }
+ route_nid, err = strconv.ParseUint(route_id, 10, int(unsafe.Sizeof(RouteId(0)) * 8))
+ if err != nil {
+ status_code = http.StatusBadRequest; w.WriteHeader(status_code)
+ goto oops
+ }
+
+ r = s.FindServerRouteById(ConnId(conn_nid), RouteId(route_nid))
+ if r == nil {
+ status_code = http.StatusNotFound; w.WriteHeader(status_code)
+ goto done
+ }
+ }
+
tmpl = template.New("")
_, err = tmpl.Parse(string(xterm_html))
if err != nil {
- w.WriteHeader(http.StatusInternalServerError)
+ status_code = http.StatusInternalServerError; w.WriteHeader(status_code)
+ goto oops
} else {
w.Header().Set("Content-Type", "text/html")
w.WriteHeader(http.StatusOK)
tmpl.Execute(w,
&server_proxy_xterm_session_info{
- ConnId: req.PathValue("conn_id"),
- RouteId: req.PathValue("route_id"),
+ ConnId: conn_id,
+ RouteId: route_id,
})
}
case "_forbidden":
- w.WriteHeader(http.StatusForbidden)
+ status_code = http.StatusForbidden; w.WriteHeader(status_code)
default:
- w.WriteHeader(http.StatusNotFound)
+ status_code = http.StatusNotFound; w.WriteHeader(status_code)
}
-// TODO: logging..
+
+done:
+ s.log.Write("", LOG_INFO, "[%s] %s %s %d", req.RemoteAddr, req.Method, req.URL.String(), status_code)
+ return
+
+oops:
+ s.log.Write("", LOG_ERROR, "[%s] %s %s %d - %s", req.RemoteAddr, req.Method, req.URL.String(), status_code, err.Error())
+ return
}
+
// ------------------------------------
type server_proxy_ssh_ws struct {
@@ -511,9 +574,11 @@ oops:
func (pxy *server_proxy_ssh_ws) ServeWebsocket(ws *websocket.Conn) {
var s *Server
var req *http.Request
+ var port_id string
var conn_id string
- var conn_nid uint64
var route_id string
+ var port_nid uint64
+ var conn_nid uint64
var route_nid uint64
var r *ServerRoute
var username string
@@ -539,24 +604,42 @@ func (pxy *server_proxy_ssh_ws) ServeWebsocket(ws *websocket.Conn) {
conn_id = req.PathValue("conn_id")
route_id = req.PathValue("route_id")
+ if route_id == "_" { port_id = conn_id }
- conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(conn_nid) * 8))
- if err != nil {
- // TODO:
- goto done
- }
- route_nid, err = strconv.ParseUint(route_id, 10, int(unsafe.Sizeof(route_nid) * 8))
- if err != nil {
- // TODO:
- goto done
- }
+ if port_id != "" {
+ port_nid, err = strconv.ParseUint(port_id, 10, int(unsafe.Sizeof(PortId(0)) * 8))
+ if err != nil {
+ err = fmt.Errorf("invalid port id - %s", port_id)
+ pxy.send_ws_data(ws, "error", err.Error())
+ goto done
+ }
- r = s.FindServerRouteById(ConnId(conn_nid), RouteId(route_nid))
- if r == nil {
- // TODO: enhance logging. original request, conn_nid, route_nid
- pxy.send_ws_data(ws, "error", fmt.Sprintf("route(%d,%d) not found", conn_nid, route_nid))
- s.log.Write("", LOG_ERROR, "No server route(%d,%d) found", conn_nid, route_nid)
- goto done
+ r = s.FindServerRouteByPortId(PortId(port_nid))
+ if r == nil {
+ err = fmt.Errorf("port(%d) not found", conn_nid, route_nid)
+ pxy.send_ws_data(ws, "error", err.Error())
+ goto done
+ }
+ } else {
+ conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(ConnId(0)) * 8))
+ if err != nil {
+ err = fmt.Errorf("invalid connection id - %s", conn_id)
+ pxy.send_ws_data(ws, "error", err.Error())
+ goto done
+ }
+ route_nid, err = strconv.ParseUint(route_id, 10, int(unsafe.Sizeof(RouteId(0)) * 8))
+ if err != nil {
+ err = fmt.Errorf("invalid route id - %s", route_id)
+ pxy.send_ws_data(ws, "error", err.Error())
+ goto done
+ }
+
+ r = s.FindServerRouteById(ConnId(conn_nid), RouteId(route_nid))
+ if r == nil {
+ err = fmt.Errorf("route(%d,%d) not found", conn_nid, route_nid)
+ pxy.send_ws_data(ws, "error", err.Error())
+ goto done
+ }
}
wg.Add(1)
@@ -598,11 +681,8 @@ ws_recv_loop:
for {
var msg []byte
err = websocket.Message.Receive(ws, &msg)
- if err != nil {
- // TODO: check if EOF
- s.log.Write("", LOG_ERROR, "Failed to read from websocket - %s", err.Error())
- goto done
- }
+ if err != nil { goto done }
+
if len(msg) > 0 {
var ev json_ssh_ws_event
err = json.Unmarshal(msg, &ev)
@@ -669,9 +749,7 @@ ws_recv_loop:
if sess != nil {
err = pxy.send_ws_data(ws, "status", "closed")
- if err != nil {
- s.log.Write("", LOG_ERROR, "Failed to write closed event to websocket - %s", err.Error())
- }
+ if err != nil { goto done }
}
done:
@@ -680,5 +758,9 @@ done:
if sess != nil { sess.Close() }
if c != nil { c.Close() }
wg.Wait()
- s.log.Write("", LOG_DEBUG, "[%s] %s %s - ended", req.RemoteAddr, req.Method, req.URL.String())
+ if err != nil {
+ s.log.Write("", LOG_ERROR, "[%s] %s %s - %s", req.RemoteAddr, req.Method, req.URL.String(), err.Error())
+ } else {
+ s.log.Write("", LOG_DEBUG, "[%s] %s %s - ended", req.RemoteAddr, req.Method, req.URL.String())
+ }
}
diff --git a/server.go b/server.go
index 4ae435d..71b9c7c 100644
--- a/server.go
+++ b/server.go
@@ -22,10 +22,13 @@ import "google.golang.org/grpc/stats"
const PTS_LIMIT int = 16384
const CTS_LIMIT int = 16384
+type PortId uint16
+
type ServerConnMapByAddr = map[net.Addr]*ServerConn
type ServerConnMap = map[ConnId]*ServerConn
type ServerRouteMap = map[RouteId]*ServerRoute
type ServerPeerConnMap = map[PeerId]*ServerPeerConn
+type ServerSvcPortMap = map[PortId]ConnRouteId
type Server struct {
ctx context.Context
@@ -65,6 +68,9 @@ type Server struct {
log Logger
+ svc_port_mtx sync.Mutex
+ svc_port_map ServerSvcPortMap
+
stats struct {
conns atomic.Int64
routes atomic.Int64
@@ -316,6 +322,8 @@ func (cts *ServerConn) make_route_listener(id RouteId, option RouteOption, svc_r
var l *net.TCPListener
var svcaddr *net.TCPAddr
var nw string
+ var prev_cri ConnRouteId
+ var ok bool
var err error
if svc_requested_addr != "" {
@@ -354,7 +362,20 @@ func (cts *ServerConn) make_route_listener(id RouteId, option RouteOption, svc_r
}
svcaddr = l.Addr().(*net.TCPAddr)
- cts.svr.log.Write(cts.sid, LOG_DEBUG, "Route(%d) listening on %s", id, svcaddr.String())
+
+ cts.svr.svc_port_mtx.Lock()
+ prev_cri, ok = cts.svr.svc_port_map[PortId(svcaddr.Port)]
+ if ok {
+ cts.svr.log.Write(cts.sid, LOG_ERROR,
+ "Route(%d,%d) on %s not unique by port number - existing route(%d,%d)",
+ cts.id, id, prev_cri.conn_id, prev_cri.route_id, svcaddr.String())
+ l.Close()
+ return nil, nil, err
+ }
+ cts.svr.svc_port_map[PortId(svcaddr.Port)] = ConnRouteId{conn_id: cts.id, route_id: id}
+ cts.svr.svc_port_mtx.Unlock()
+
+ cts.svr.log.Write(cts.sid, LOG_DEBUG, "Route(%d,%d) listening on %s", cts.id, id, svcaddr.String())
return l, svcaddr, nil
}
@@ -403,6 +424,10 @@ func (cts *ServerConn) RemoveServerRoute(route *ServerRoute) error {
cts.svr.stats.routes.Add(-1)
cts.route_mtx.Unlock()
+ cts.svr.svc_port_mtx.Lock()
+ delete(cts.svr.svc_port_map, PortId(route.svc_addr.Port))
+ cts.svr.svc_port_mtx.Unlock()
+
r.ReqStop()
return nil
}
@@ -421,6 +446,10 @@ func (cts *ServerConn) RemoveServerRouteById(route_id RouteId) (*ServerRoute, er
cts.svr.stats.routes.Add(-1)
cts.route_mtx.Unlock()
+ cts.svr.svc_port_mtx.Lock()
+ delete(cts.svr.svc_port_map, PortId(r.svc_addr.Port))
+ cts.svr.svc_port_mtx.Unlock()
+
r.ReqStop()
return r, nil
}
@@ -913,6 +942,7 @@ func NewServer(ctx context.Context, logger Logger, ctl_addrs []string, rpc_addrs
s.cts_next_id = 0
s.cts_map = make(ServerConnMap)
s.cts_map_by_addr = make(ServerConnMapByAddr)
+ s.svc_port_map = make(ServerSvcPortMap)
s.stop_chan = make(chan bool, 8)
s.stop_req.Store(false)
@@ -965,9 +995,7 @@ func NewServer(ctx context.Context, logger Logger, ctl_addrs []string, rpc_addrs
s.pxy_ws = &server_proxy_ssh_ws{s: &s}
s.pxy_mux = http.NewServeMux() // TODO: make /_init configurable...
s.pxy_mux.Handle("/_ssh-ws/{conn_id}/{route_id}",
- websocket.Handler(func(ws *websocket.Conn) {
- s.pxy_ws.ServeWebsocket(ws)
- }))
+ websocket.Handler(func(ws *websocket.Conn) { s.pxy_ws.ServeWebsocket(ws) }))
s.pxy_mux.Handle("/_ssh/server-conns/{conn_id}/routes/{route_id}", &server_ctl_server_conns_id_routes_id{s: &s})
s.pxy_mux.Handle("/_ssh/{conn_id}/{route_id}/", &server_proxy_xterm_file{s: &s, file: "xterm.html"})
s.pxy_mux.Handle("/_ssh/xterm.js", &server_proxy_xterm_file{s: &s, file: "xterm.js"})
@@ -1337,13 +1365,23 @@ func (s *Server) FindServerRouteById(id ConnId, route_id RouteId) *ServerRoute {
defer s.cts_mtx.Unlock()
cts, ok = s.cts_map[id]
- if !ok {
- return nil
- }
+ if !ok { return nil }
return cts.FindServerRouteById(route_id)
}
+func (s *Server) FindServerRouteByPortId(port_id PortId) *ServerRoute {
+ var cri ConnRouteId
+ var ok bool
+
+ s.svc_port_mtx.Lock()
+ defer s.svc_port_mtx.Unlock()
+
+ cri, ok = s.svc_port_map[port_id]
+ if !ok { return nil }
+ return s.FindServerRouteById(cri.conn_id, cri.route_id)
+}
+
func (s *Server) StartService(cfg interface{}) {
s.wg.Add(1)
diff --git a/xterm.html b/xterm.html
index 458a2e8..a8aa23e 100644
--- a/xterm.html
+++ b/xterm.html
@@ -90,6 +90,9 @@