From c8cb71cf9536f696a8269f6face999a29a6108fe Mon Sep 17 00:00:00 2001 From: hyung-hwan Date: Sun, 24 Aug 2025 19:00:26 +0900 Subject: [PATCH] update the proxy code to resolve a route by the client token and the client-side peer name --- client.go | 16 ++++----- cmd/logger.go | 14 ++++++-- server-pxy.go | 8 ++--- server.go | 94 +++++++++++++++++++++++++++++++++++++++++++++------ 4 files changed, 106 insertions(+), 26 deletions(-) diff --git a/client.go b/client.go index 8de9ba5..29bd159 100644 --- a/client.go +++ b/client.go @@ -490,13 +490,13 @@ func (r *ClientRoute) RunTask(wg *sync.WaitGroup) { err = r.cts.psc.Send(MakeRouteStartPacket(r.Id, r.ServerPeerOption, r.PeerAddr, r.PeerName, r.ReqServerPeerSvcAddr, r.ReqServerPeerSvcNet)) if err != nil { r.cts.C.log.Write(r.cts.Sid, LOG_DEBUG, - "Failed to send route_start for route(%d,%s,%v,%s,%s) to %s - %s", - r.Id, r.PeerAddr, r.ServerPeerOption, r.ReqServerPeerSvcAddr, r.ReqServerPeerSvcNet, r.cts.remote_addr_p, err.Error()) + "Failed to send route_start for route(%d,%s,%v,%s,%s,%s) to %s - %s", + r.Id, r.PeerAddr, r.ServerPeerOption, r.ReqServerPeerSvcAddr, r.ReqServerPeerSvcNet, r.PeerName, r.cts.remote_addr_p, err.Error()) goto done } else { r.cts.C.log.Write(r.cts.Sid, LOG_DEBUG, - "Sent route_start for route(%d,%s,%v,%s,%s) to %s", - r.Id, r.PeerAddr, r.ServerPeerOption, r.ReqServerPeerSvcAddr, r.ReqServerPeerSvcNet, r.cts.remote_addr_p) + "Sent route_start for route(%d,%s,%v,%s,%s,%s) to %s", + r.Id, r.PeerAddr, r.ServerPeerOption, r.ReqServerPeerSvcAddr, r.ReqServerPeerSvcNet, r.PeerName, r.cts.remote_addr_p) } r.ptc_wg.Add(1) // increment counter here @@ -702,9 +702,9 @@ func (r *ClientRoute) ReportPacket(pts_id PeerId, packet_type PACKET_KIND, event r.cts.C.FireRouteEvent(CLIENT_EVENT_ROUTE_UPDATED, r) r.cts.C.log.Write(r.cts.Sid, LOG_INFO, - "Ingested route_started(%d,%s,%s) for route(%d,%s,%v,%s,%s)", + "Ingested route_started(%d,%s,%s) for route(%d,%s,%v,%s,%s,%s)", rd.RouteId, rd.TargetAddrStr, rd.ServiceNetStr, - r.Id, r.PeerAddr, r.ServerPeerOption, r.ReqServerPeerSvcAddr, r.ReqServerPeerSvcNet) + r.Id, r.PeerAddr, r.ServerPeerOption, r.ReqServerPeerSvcAddr, r.ReqServerPeerSvcNet, r.PeerName) } } @@ -721,9 +721,9 @@ func (r *ClientRoute) ReportPacket(pts_id PeerId, packet_type PACKET_KIND, event r.cts.C.log.Write(r.cts.Sid, LOG_ERROR, "Protocol error - invalid data in route_started event(%d)", r.Id) } else { r.cts.C.log.Write(r.cts.Sid, LOG_INFO, - "Ingested route_stopped(%d,%s,%s) for route(%d,%s,%v,%s,%s)", + "Ingested route_stopped(%d,%s,%s) for route(%d,%s,%v,%s,%s,%s)", rd.RouteId, rd.TargetAddrStr, rd.ServiceNetStr, - r.Id, r.PeerAddr, r.ServerPeerOption, r.ReqServerPeerSvcAddr, r.ReqServerPeerSvcNet) + r.Id, r.PeerAddr, r.ServerPeerOption, r.ReqServerPeerSvcAddr, r.ReqServerPeerSvcNet, r.PeerName) } r.ReqStop() diff --git a/cmd/logger.go b/cmd/logger.go index 2edf8cc..4fe5bf3 100644 --- a/cmd/logger.go +++ b/cmd/logger.go @@ -51,14 +51,22 @@ func _is_ansi_tty(fd uintptr) bool { } -func NewAppLogger (id string, w io.Writer, mask hodu.LogMask) *AppLogger { +func NewAppLogger(id string, w io.Writer, mask hodu.LogMask) *AppLogger { var l *AppLogger + var f *os.File + var ok bool + var use_color bool + + use_color = false + f, ok = w.(*os.File) + if ok { use_color = _is_ansi_tty(f.Fd()) } + l = &AppLogger{ id: id, out: w, mask: mask, msg_chan: make(chan app_logger_msg_t, 256), - use_color: false, + use_color: use_color, } l.closed.Store(false) l.wg.Add(1) @@ -66,7 +74,7 @@ func NewAppLogger (id string, w io.Writer, mask hodu.LogMask) *AppLogger { return l } -func NewAppLoggerToFile (id string, file_name string, max_size int64, rotate int, mask hodu.LogMask) (*AppLogger, error) { +func NewAppLoggerToFile(id string, file_name string, max_size int64, rotate int, mask hodu.LogMask) (*AppLogger, error) { var l *AppLogger var f *os.File var matched bool diff --git a/server-pxy.go b/server-pxy.go index 0185b4f..fd653e1 100644 --- a/server-pxy.go +++ b/server-pxy.go @@ -191,7 +191,7 @@ func (pxy *server_pxy) Authenticate(req *http.Request) (int, string) { // ------------------------------------ -func prevent_follow_redirect (req *http.Request, via []*http.Request) error { +func prevent_follow_redirect(req *http.Request, via []*http.Request) error { return http.ErrUseLastResponse } @@ -289,7 +289,7 @@ func (pxy *server_pxy_http_main) serve_upgraded(w http.ResponseWriter, req *http return err } -func (pxy *server_pxy_http_main) addr_to_transport (ctx context.Context, addr *net.TCPAddr) (*http.Transport, error) { +func (pxy *server_pxy_http_main) addr_to_transport(ctx context.Context, addr *net.TCPAddr) (*http.Transport, error) { var dialer *net.Dialer var waitctx context.Context var cancel_wait context.CancelFunc @@ -318,7 +318,7 @@ func (pxy *server_pxy_http_main) addr_to_transport (ctx context.Context, addr *n }, nil } -func (pxy *server_pxy_http_main) req_to_proxy_url (req *http.Request, r *ServerRouteProxyInfo) *url.URL { +func (pxy *server_pxy_http_main) req_to_proxy_url(req *http.Request, r *ServerRouteProxyInfo) *url.URL { var proxy_proto string var proxy_url_path string @@ -568,7 +568,7 @@ func (pxy *server_pxy_ssh_ws) Identity() string { // TODO: put this task to sync group. // TODO: put the above proxy task to sync group too. -func (pxy *server_pxy_ssh_ws) connect_ssh (ctx context.Context, username string, password string, r *ServerRoute) (*ssh.Client, *ssh.Session, io.Writer, io.Reader, error) { +func (pxy *server_pxy_ssh_ws) connect_ssh(ctx context.Context, username string, password string, r *ServerRoute) (*ssh.Client, *ssh.Session, io.Writer, io.Reader, error) { var cc *ssh.ClientConfig var addr *net.TCPAddr var dialer *net.Dialer diff --git a/server.go b/server.go index d27cfa9..c168065 100644 --- a/server.go +++ b/server.go @@ -41,6 +41,7 @@ type ServerConnMapByAddr map[net.Addr]*ServerConn type ServerConnMapByClientToken map[string]*ServerConn type ServerConnMap map[ConnId]*ServerConn type ServerRouteMap map[RouteId]*ServerRoute +type ServerRouteMapByPtcName map[string]*ServerRoute type ServerPeerConnMap map[PeerId]*ServerPeerConn type ServerSvcPortMap map[PortId]ConnRouteId @@ -199,6 +200,7 @@ type ServerConn struct { route_mtx sync.Mutex route_map ServerRouteMap + route_map_by_ptc_name ServerRouteMapByPtcName route_wg sync.WaitGroup pts_mtx sync.Mutex @@ -581,10 +583,12 @@ func (cts *ServerConn) make_route_listener(id RouteId, option RouteOption, svc_r func (cts *ServerConn) AddNewServerRoute(route_id RouteId, proto RouteOption, ptc_addr string, ptc_name string, svc_requested_addr string, svc_permitted_net string) (*ServerRoute, error) { var r *ServerRoute + var ok bool var err error cts.route_mtx.Lock() - if cts.route_map[route_id] != nil { + _, ok = cts.route_map[route_id] + if ok { // If this happens, something must be wrong between the server and the client // most likely, it must be a logic error. the state must not go out of sync // as the route_id and the peer_id are supposed to be the same between the client @@ -592,12 +596,27 @@ func (cts *ServerConn) AddNewServerRoute(route_id RouteId, proto RouteOption, pt cts.route_mtx.Unlock() return nil, fmt.Errorf("existent route id - %d", route_id) } + + if ptc_name != "" { + // ptc name can be empty. but if not empty, it must be unique + _, ok = cts.route_map_by_ptc_name[ptc_name] + if ok { + cts.route_mtx.Unlock() + return nil, fmt.Errorf("existent ptc name %s for route %d", ptc_name, route_id) + } + } + r, err = NewServerRoute(cts, route_id, proto, ptc_addr, ptc_name, svc_requested_addr, svc_permitted_net) if err != nil { cts.route_mtx.Unlock() return nil, err } + cts.route_map[route_id] = r + if ptc_name != "" { + cts.route_map_by_ptc_name[ptc_name] = r + } + cts.S.stats.routes.Add(1) cts.route_mtx.Unlock() @@ -625,6 +644,9 @@ func (cts *ServerConn) RemoveServerRoute(route *ServerRoute) error { cts.route_mtx.Unlock() return fmt.Errorf("non-existent route - %d", route.Id) } + if route.PtcName != "" { + delete(cts.route_map_by_ptc_name, route.PtcName) + } delete(cts.route_map, route.Id) cts.S.stats.routes.Add(-1) cts.route_mtx.Unlock() @@ -647,6 +669,9 @@ func (cts *ServerConn) RemoveServerRouteById(route_id RouteId) (*ServerRoute, er cts.route_mtx.Unlock() return nil, fmt.Errorf("non-existent route id - %d", route_id) } + if r.PtcName != "" { + delete(cts.route_map_by_ptc_name, r.PtcName) + } delete(cts.route_map, route_id) cts.S.stats.routes.Add(-1) cts.route_mtx.Unlock() @@ -674,6 +699,21 @@ func (cts *ServerConn) FindServerRouteById(route_id RouteId) *ServerRoute { return r } +func (cts *ServerConn) FindServerRouteByPtcName(ptc_name string) *ServerRoute { + var r *ServerRoute + var ok bool + + cts.route_mtx.Lock() + r, ok = cts.route_map_by_ptc_name[ptc_name] + if !ok { + cts.route_mtx.Unlock() + return nil + } + cts.route_mtx.Unlock() + + return r +} + func (cts *ServerConn) FindServerPeerConnById(route_id RouteId, peer_id PeerId) *ServerPeerConn { var r *ServerRoute var ok bool @@ -1009,8 +1049,8 @@ func (cts *ServerConn) receive_from_stream(wg *sync.WaitGroup) { } else { cts.S.log.Write(cts.Sid, LOG_INFO, - "Added route(%d,%s,%s,%v,%v) for client %s to cts(%d)", - r.Id, r.PtcAddr, r.SvcAddr.String(), r.SvcOption, r.SvcPermNet, cts.RemoteAddr, cts.Id) + "Added route(%d,%s,%s,%v,%v,%s) for client %s to cts(%d)", + r.Id, r.PtcAddr, r.SvcAddr.String(), r.SvcOption, r.SvcPermNet, r.PtcName, cts.RemoteAddr, cts.Id) err = cts.pss.Send(MakeRouteStartedPacket(r.Id, r.SvcOption, r.SvcAddr.String(), r.PtcName, r.SvcReqAddr, r.SvcPermNet.String())) if err != nil { r.ReqStop() @@ -2288,6 +2328,7 @@ func (s *Server) AddNewServerConn(remote_addr *net.Addr, local_addr *net.Addr, p cts.S = s cts.Created = time.Now() cts.route_map = make(ServerRouteMap) + cts.route_map_by_ptc_name = make(ServerRouteMapByPtcName) cts.RemoteAddr = *remote_addr cts.LocalAddr = *local_addr cts.pss = &GuardedPacketStreamServer{Hodu_PacketStreamServer: pss} @@ -2480,6 +2521,32 @@ func (s *Server) FindServerRouteById(id ConnId, route_id RouteId) *ServerRoute { return cts.FindServerRouteById(route_id) } +func (s *Server) FindServerRouteByClientTokenAndRouteId(token string, route_id RouteId) *ServerRoute { + var cts *ServerConn + var ok bool + + s.cts_mtx.Lock() + defer s.cts_mtx.Unlock() + + cts, ok = s.cts_map_by_token[token] + if !ok { return nil } + + return cts.FindServerRouteById(route_id) +} + +func (s *Server) FindServerRouteByClientTokenAndPtcName(token string, ptc_name string) *ServerRoute { + var cts *ServerConn + var ok bool + + s.cts_mtx.Lock() + defer s.cts_mtx.Unlock() + + cts, ok = s.cts_map_by_token[token] + if !ok { return nil } + + return cts.FindServerRouteByPtcName(ptc_name) +} + func (s *Server) FindServerPeerConnById(id ConnId, route_id RouteId, peer_id PeerId) *ServerPeerConn { var cts *ServerConn var ok bool @@ -2556,10 +2623,10 @@ func (s *Server) FindServerPeerConnByIdStr(conn_id string, route_id string, peer func (s *Server) FindServerRouteByIdStr(conn_id string, route_id string) (*ServerRoute, error) { var r *ServerRoute - var err error if route_id == PORT_ID_MARKER { var port_nid uint64 + var err error port_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(PortId(0)) * 8)) if err != nil { return nil, fmt.Errorf("invalid port id %s - %s", conn_id, err.Error()) } @@ -2569,15 +2636,20 @@ func (s *Server) FindServerRouteByIdStr(conn_id string, route_id string) (*Serve } else { var conn_nid uint64 var route_nid uint64 + var err1 error + var err2 error - conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(ConnId(0)) * 8)) - if err != nil { return nil, fmt.Errorf("invalid connection id %s - %s", conn_id, err.Error()) } + conn_nid, err1 = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(ConnId(0)) * 8)) + route_nid, err2 = strconv.ParseUint(route_id, 10, int(unsafe.Sizeof(RouteId(0)) * 8)) - route_nid, err = strconv.ParseUint(route_id, 10, int(unsafe.Sizeof(RouteId(0)) * 8)) - if err != nil { return nil, fmt.Errorf("invalid route id %s - %s", route_id, err.Error()) } - - r = s.FindServerRouteById(ConnId(conn_nid), RouteId(route_nid)) - if r == nil { return nil, fmt.Errorf("route(%d,%d) not found", conn_nid, route_nid) } + if err1 != nil || err2 != nil { + // if not numeric, attempt to use it as a token and ptc name + r = s.FindServerRouteByClientTokenAndPtcName(conn_id, route_id) + if r == nil { return nil, fmt.Errorf("route(%d,%d) not found", conn_id, route_id) } + } else { + r = s.FindServerRouteById(ConnId(conn_nid), RouteId(route_nid)) + if r == nil { return nil, fmt.Errorf("route(%d,%d) not found", conn_nid, route_nid) } + } } return r, nil