update the proxy code to resolve a route by the client token and the client-side peer name

This commit is contained in:
2025-08-24 19:00:26 +09:00
parent e2a3180ec7
commit c8cb71cf95
4 changed files with 106 additions and 26 deletions

View File

@ -490,13 +490,13 @@ func (r *ClientRoute) RunTask(wg *sync.WaitGroup) {
err = r.cts.psc.Send(MakeRouteStartPacket(r.Id, r.ServerPeerOption, r.PeerAddr, r.PeerName, r.ReqServerPeerSvcAddr, r.ReqServerPeerSvcNet))
if err != nil {
r.cts.C.log.Write(r.cts.Sid, LOG_DEBUG,
"Failed to send route_start for route(%d,%s,%v,%s,%s) to %s - %s",
r.Id, r.PeerAddr, r.ServerPeerOption, r.ReqServerPeerSvcAddr, r.ReqServerPeerSvcNet, r.cts.remote_addr_p, err.Error())
"Failed to send route_start for route(%d,%s,%v,%s,%s,%s) to %s - %s",
r.Id, r.PeerAddr, r.ServerPeerOption, r.ReqServerPeerSvcAddr, r.ReqServerPeerSvcNet, r.PeerName, r.cts.remote_addr_p, err.Error())
goto done
} else {
r.cts.C.log.Write(r.cts.Sid, LOG_DEBUG,
"Sent route_start for route(%d,%s,%v,%s,%s) to %s",
r.Id, r.PeerAddr, r.ServerPeerOption, r.ReqServerPeerSvcAddr, r.ReqServerPeerSvcNet, r.cts.remote_addr_p)
"Sent route_start for route(%d,%s,%v,%s,%s,%s) to %s",
r.Id, r.PeerAddr, r.ServerPeerOption, r.ReqServerPeerSvcAddr, r.ReqServerPeerSvcNet, r.PeerName, r.cts.remote_addr_p)
}
r.ptc_wg.Add(1) // increment counter here
@ -702,9 +702,9 @@ func (r *ClientRoute) ReportPacket(pts_id PeerId, packet_type PACKET_KIND, event
r.cts.C.FireRouteEvent(CLIENT_EVENT_ROUTE_UPDATED, r)
r.cts.C.log.Write(r.cts.Sid, LOG_INFO,
"Ingested route_started(%d,%s,%s) for route(%d,%s,%v,%s,%s)",
"Ingested route_started(%d,%s,%s) for route(%d,%s,%v,%s,%s,%s)",
rd.RouteId, rd.TargetAddrStr, rd.ServiceNetStr,
r.Id, r.PeerAddr, r.ServerPeerOption, r.ReqServerPeerSvcAddr, r.ReqServerPeerSvcNet)
r.Id, r.PeerAddr, r.ServerPeerOption, r.ReqServerPeerSvcAddr, r.ReqServerPeerSvcNet, r.PeerName)
}
}
@ -721,9 +721,9 @@ func (r *ClientRoute) ReportPacket(pts_id PeerId, packet_type PACKET_KIND, event
r.cts.C.log.Write(r.cts.Sid, LOG_ERROR, "Protocol error - invalid data in route_started event(%d)", r.Id)
} else {
r.cts.C.log.Write(r.cts.Sid, LOG_INFO,
"Ingested route_stopped(%d,%s,%s) for route(%d,%s,%v,%s,%s)",
"Ingested route_stopped(%d,%s,%s) for route(%d,%s,%v,%s,%s,%s)",
rd.RouteId, rd.TargetAddrStr, rd.ServiceNetStr,
r.Id, r.PeerAddr, r.ServerPeerOption, r.ReqServerPeerSvcAddr, r.ReqServerPeerSvcNet)
r.Id, r.PeerAddr, r.ServerPeerOption, r.ReqServerPeerSvcAddr, r.ReqServerPeerSvcNet, r.PeerName)
}
r.ReqStop()

View File

@ -51,14 +51,22 @@ func _is_ansi_tty(fd uintptr) bool {
}
func NewAppLogger (id string, w io.Writer, mask hodu.LogMask) *AppLogger {
func NewAppLogger(id string, w io.Writer, mask hodu.LogMask) *AppLogger {
var l *AppLogger
var f *os.File
var ok bool
var use_color bool
use_color = false
f, ok = w.(*os.File)
if ok { use_color = _is_ansi_tty(f.Fd()) }
l = &AppLogger{
id: id,
out: w,
mask: mask,
msg_chan: make(chan app_logger_msg_t, 256),
use_color: false,
use_color: use_color,
}
l.closed.Store(false)
l.wg.Add(1)
@ -66,7 +74,7 @@ func NewAppLogger (id string, w io.Writer, mask hodu.LogMask) *AppLogger {
return l
}
func NewAppLoggerToFile (id string, file_name string, max_size int64, rotate int, mask hodu.LogMask) (*AppLogger, error) {
func NewAppLoggerToFile(id string, file_name string, max_size int64, rotate int, mask hodu.LogMask) (*AppLogger, error) {
var l *AppLogger
var f *os.File
var matched bool

View File

@ -191,7 +191,7 @@ func (pxy *server_pxy) Authenticate(req *http.Request) (int, string) {
// ------------------------------------
func prevent_follow_redirect (req *http.Request, via []*http.Request) error {
func prevent_follow_redirect(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
}
@ -289,7 +289,7 @@ func (pxy *server_pxy_http_main) serve_upgraded(w http.ResponseWriter, req *http
return err
}
func (pxy *server_pxy_http_main) addr_to_transport (ctx context.Context, addr *net.TCPAddr) (*http.Transport, error) {
func (pxy *server_pxy_http_main) addr_to_transport(ctx context.Context, addr *net.TCPAddr) (*http.Transport, error) {
var dialer *net.Dialer
var waitctx context.Context
var cancel_wait context.CancelFunc
@ -318,7 +318,7 @@ func (pxy *server_pxy_http_main) addr_to_transport (ctx context.Context, addr *n
}, nil
}
func (pxy *server_pxy_http_main) req_to_proxy_url (req *http.Request, r *ServerRouteProxyInfo) *url.URL {
func (pxy *server_pxy_http_main) req_to_proxy_url(req *http.Request, r *ServerRouteProxyInfo) *url.URL {
var proxy_proto string
var proxy_url_path string
@ -568,7 +568,7 @@ func (pxy *server_pxy_ssh_ws) Identity() string {
// TODO: put this task to sync group.
// TODO: put the above proxy task to sync group too.
func (pxy *server_pxy_ssh_ws) connect_ssh (ctx context.Context, username string, password string, r *ServerRoute) (*ssh.Client, *ssh.Session, io.Writer, io.Reader, error) {
func (pxy *server_pxy_ssh_ws) connect_ssh(ctx context.Context, username string, password string, r *ServerRoute) (*ssh.Client, *ssh.Session, io.Writer, io.Reader, error) {
var cc *ssh.ClientConfig
var addr *net.TCPAddr
var dialer *net.Dialer

View File

@ -41,6 +41,7 @@ type ServerConnMapByAddr map[net.Addr]*ServerConn
type ServerConnMapByClientToken map[string]*ServerConn
type ServerConnMap map[ConnId]*ServerConn
type ServerRouteMap map[RouteId]*ServerRoute
type ServerRouteMapByPtcName map[string]*ServerRoute
type ServerPeerConnMap map[PeerId]*ServerPeerConn
type ServerSvcPortMap map[PortId]ConnRouteId
@ -199,6 +200,7 @@ type ServerConn struct {
route_mtx sync.Mutex
route_map ServerRouteMap
route_map_by_ptc_name ServerRouteMapByPtcName
route_wg sync.WaitGroup
pts_mtx sync.Mutex
@ -581,10 +583,12 @@ func (cts *ServerConn) make_route_listener(id RouteId, option RouteOption, svc_r
func (cts *ServerConn) AddNewServerRoute(route_id RouteId, proto RouteOption, ptc_addr string, ptc_name string, svc_requested_addr string, svc_permitted_net string) (*ServerRoute, error) {
var r *ServerRoute
var ok bool
var err error
cts.route_mtx.Lock()
if cts.route_map[route_id] != nil {
_, ok = cts.route_map[route_id]
if ok {
// If this happens, something must be wrong between the server and the client
// most likely, it must be a logic error. the state must not go out of sync
// as the route_id and the peer_id are supposed to be the same between the client
@ -592,12 +596,27 @@ func (cts *ServerConn) AddNewServerRoute(route_id RouteId, proto RouteOption, pt
cts.route_mtx.Unlock()
return nil, fmt.Errorf("existent route id - %d", route_id)
}
if ptc_name != "" {
// ptc name can be empty. but if not empty, it must be unique
_, ok = cts.route_map_by_ptc_name[ptc_name]
if ok {
cts.route_mtx.Unlock()
return nil, fmt.Errorf("existent ptc name %s for route %d", ptc_name, route_id)
}
}
r, err = NewServerRoute(cts, route_id, proto, ptc_addr, ptc_name, svc_requested_addr, svc_permitted_net)
if err != nil {
cts.route_mtx.Unlock()
return nil, err
}
cts.route_map[route_id] = r
if ptc_name != "" {
cts.route_map_by_ptc_name[ptc_name] = r
}
cts.S.stats.routes.Add(1)
cts.route_mtx.Unlock()
@ -625,6 +644,9 @@ func (cts *ServerConn) RemoveServerRoute(route *ServerRoute) error {
cts.route_mtx.Unlock()
return fmt.Errorf("non-existent route - %d", route.Id)
}
if route.PtcName != "" {
delete(cts.route_map_by_ptc_name, route.PtcName)
}
delete(cts.route_map, route.Id)
cts.S.stats.routes.Add(-1)
cts.route_mtx.Unlock()
@ -647,6 +669,9 @@ func (cts *ServerConn) RemoveServerRouteById(route_id RouteId) (*ServerRoute, er
cts.route_mtx.Unlock()
return nil, fmt.Errorf("non-existent route id - %d", route_id)
}
if r.PtcName != "" {
delete(cts.route_map_by_ptc_name, r.PtcName)
}
delete(cts.route_map, route_id)
cts.S.stats.routes.Add(-1)
cts.route_mtx.Unlock()
@ -674,6 +699,21 @@ func (cts *ServerConn) FindServerRouteById(route_id RouteId) *ServerRoute {
return r
}
func (cts *ServerConn) FindServerRouteByPtcName(ptc_name string) *ServerRoute {
var r *ServerRoute
var ok bool
cts.route_mtx.Lock()
r, ok = cts.route_map_by_ptc_name[ptc_name]
if !ok {
cts.route_mtx.Unlock()
return nil
}
cts.route_mtx.Unlock()
return r
}
func (cts *ServerConn) FindServerPeerConnById(route_id RouteId, peer_id PeerId) *ServerPeerConn {
var r *ServerRoute
var ok bool
@ -1009,8 +1049,8 @@ func (cts *ServerConn) receive_from_stream(wg *sync.WaitGroup) {
} else {
cts.S.log.Write(cts.Sid, LOG_INFO,
"Added route(%d,%s,%s,%v,%v) for client %s to cts(%d)",
r.Id, r.PtcAddr, r.SvcAddr.String(), r.SvcOption, r.SvcPermNet, cts.RemoteAddr, cts.Id)
"Added route(%d,%s,%s,%v,%v,%s) for client %s to cts(%d)",
r.Id, r.PtcAddr, r.SvcAddr.String(), r.SvcOption, r.SvcPermNet, r.PtcName, cts.RemoteAddr, cts.Id)
err = cts.pss.Send(MakeRouteStartedPacket(r.Id, r.SvcOption, r.SvcAddr.String(), r.PtcName, r.SvcReqAddr, r.SvcPermNet.String()))
if err != nil {
r.ReqStop()
@ -2288,6 +2328,7 @@ func (s *Server) AddNewServerConn(remote_addr *net.Addr, local_addr *net.Addr, p
cts.S = s
cts.Created = time.Now()
cts.route_map = make(ServerRouteMap)
cts.route_map_by_ptc_name = make(ServerRouteMapByPtcName)
cts.RemoteAddr = *remote_addr
cts.LocalAddr = *local_addr
cts.pss = &GuardedPacketStreamServer{Hodu_PacketStreamServer: pss}
@ -2480,6 +2521,32 @@ func (s *Server) FindServerRouteById(id ConnId, route_id RouteId) *ServerRoute {
return cts.FindServerRouteById(route_id)
}
func (s *Server) FindServerRouteByClientTokenAndRouteId(token string, route_id RouteId) *ServerRoute {
var cts *ServerConn
var ok bool
s.cts_mtx.Lock()
defer s.cts_mtx.Unlock()
cts, ok = s.cts_map_by_token[token]
if !ok { return nil }
return cts.FindServerRouteById(route_id)
}
func (s *Server) FindServerRouteByClientTokenAndPtcName(token string, ptc_name string) *ServerRoute {
var cts *ServerConn
var ok bool
s.cts_mtx.Lock()
defer s.cts_mtx.Unlock()
cts, ok = s.cts_map_by_token[token]
if !ok { return nil }
return cts.FindServerRouteByPtcName(ptc_name)
}
func (s *Server) FindServerPeerConnById(id ConnId, route_id RouteId, peer_id PeerId) *ServerPeerConn {
var cts *ServerConn
var ok bool
@ -2556,10 +2623,10 @@ func (s *Server) FindServerPeerConnByIdStr(conn_id string, route_id string, peer
func (s *Server) FindServerRouteByIdStr(conn_id string, route_id string) (*ServerRoute, error) {
var r *ServerRoute
var err error
if route_id == PORT_ID_MARKER {
var port_nid uint64
var err error
port_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(PortId(0)) * 8))
if err != nil { return nil, fmt.Errorf("invalid port id %s - %s", conn_id, err.Error()) }
@ -2569,15 +2636,20 @@ func (s *Server) FindServerRouteByIdStr(conn_id string, route_id string) (*Serve
} else {
var conn_nid uint64
var route_nid uint64
var err1 error
var err2 error
conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(ConnId(0)) * 8))
if err != nil { return nil, fmt.Errorf("invalid connection id %s - %s", conn_id, err.Error()) }
conn_nid, err1 = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(ConnId(0)) * 8))
route_nid, err2 = strconv.ParseUint(route_id, 10, int(unsafe.Sizeof(RouteId(0)) * 8))
route_nid, err = strconv.ParseUint(route_id, 10, int(unsafe.Sizeof(RouteId(0)) * 8))
if err != nil { return nil, fmt.Errorf("invalid route id %s - %s", route_id, err.Error()) }
r = s.FindServerRouteById(ConnId(conn_nid), RouteId(route_nid))
if r == nil { return nil, fmt.Errorf("route(%d,%d) not found", conn_nid, route_nid) }
if err1 != nil || err2 != nil {
// if not numeric, attempt to use it as a token and ptc name
r = s.FindServerRouteByClientTokenAndPtcName(conn_id, route_id)
if r == nil { return nil, fmt.Errorf("route(%d,%d) not found", conn_id, route_id) }
} else {
r = s.FindServerRouteById(ConnId(conn_nid), RouteId(route_nid))
if r == nil { return nil, fmt.Errorf("route(%d,%d) not found", conn_nid, route_nid) }
}
}
return r, nil