update the proxy code to resolve a route by the client token and the client-side peer name

This commit is contained in:
2025-08-24 19:00:26 +09:00
parent e2a3180ec7
commit c8cb71cf95
4 changed files with 106 additions and 26 deletions

View File

@ -41,6 +41,7 @@ type ServerConnMapByAddr map[net.Addr]*ServerConn
type ServerConnMapByClientToken map[string]*ServerConn
type ServerConnMap map[ConnId]*ServerConn
type ServerRouteMap map[RouteId]*ServerRoute
type ServerRouteMapByPtcName map[string]*ServerRoute
type ServerPeerConnMap map[PeerId]*ServerPeerConn
type ServerSvcPortMap map[PortId]ConnRouteId
@ -199,6 +200,7 @@ type ServerConn struct {
route_mtx sync.Mutex
route_map ServerRouteMap
route_map_by_ptc_name ServerRouteMapByPtcName
route_wg sync.WaitGroup
pts_mtx sync.Mutex
@ -581,10 +583,12 @@ func (cts *ServerConn) make_route_listener(id RouteId, option RouteOption, svc_r
func (cts *ServerConn) AddNewServerRoute(route_id RouteId, proto RouteOption, ptc_addr string, ptc_name string, svc_requested_addr string, svc_permitted_net string) (*ServerRoute, error) {
var r *ServerRoute
var ok bool
var err error
cts.route_mtx.Lock()
if cts.route_map[route_id] != nil {
_, ok = cts.route_map[route_id]
if ok {
// If this happens, something must be wrong between the server and the client
// most likely, it must be a logic error. the state must not go out of sync
// as the route_id and the peer_id are supposed to be the same between the client
@ -592,12 +596,27 @@ func (cts *ServerConn) AddNewServerRoute(route_id RouteId, proto RouteOption, pt
cts.route_mtx.Unlock()
return nil, fmt.Errorf("existent route id - %d", route_id)
}
if ptc_name != "" {
// ptc name can be empty. but if not empty, it must be unique
_, ok = cts.route_map_by_ptc_name[ptc_name]
if ok {
cts.route_mtx.Unlock()
return nil, fmt.Errorf("existent ptc name %s for route %d", ptc_name, route_id)
}
}
r, err = NewServerRoute(cts, route_id, proto, ptc_addr, ptc_name, svc_requested_addr, svc_permitted_net)
if err != nil {
cts.route_mtx.Unlock()
return nil, err
}
cts.route_map[route_id] = r
if ptc_name != "" {
cts.route_map_by_ptc_name[ptc_name] = r
}
cts.S.stats.routes.Add(1)
cts.route_mtx.Unlock()
@ -625,6 +644,9 @@ func (cts *ServerConn) RemoveServerRoute(route *ServerRoute) error {
cts.route_mtx.Unlock()
return fmt.Errorf("non-existent route - %d", route.Id)
}
if route.PtcName != "" {
delete(cts.route_map_by_ptc_name, route.PtcName)
}
delete(cts.route_map, route.Id)
cts.S.stats.routes.Add(-1)
cts.route_mtx.Unlock()
@ -647,6 +669,9 @@ func (cts *ServerConn) RemoveServerRouteById(route_id RouteId) (*ServerRoute, er
cts.route_mtx.Unlock()
return nil, fmt.Errorf("non-existent route id - %d", route_id)
}
if r.PtcName != "" {
delete(cts.route_map_by_ptc_name, r.PtcName)
}
delete(cts.route_map, route_id)
cts.S.stats.routes.Add(-1)
cts.route_mtx.Unlock()
@ -674,6 +699,21 @@ func (cts *ServerConn) FindServerRouteById(route_id RouteId) *ServerRoute {
return r
}
func (cts *ServerConn) FindServerRouteByPtcName(ptc_name string) *ServerRoute {
var r *ServerRoute
var ok bool
cts.route_mtx.Lock()
r, ok = cts.route_map_by_ptc_name[ptc_name]
if !ok {
cts.route_mtx.Unlock()
return nil
}
cts.route_mtx.Unlock()
return r
}
func (cts *ServerConn) FindServerPeerConnById(route_id RouteId, peer_id PeerId) *ServerPeerConn {
var r *ServerRoute
var ok bool
@ -1009,8 +1049,8 @@ func (cts *ServerConn) receive_from_stream(wg *sync.WaitGroup) {
} else {
cts.S.log.Write(cts.Sid, LOG_INFO,
"Added route(%d,%s,%s,%v,%v) for client %s to cts(%d)",
r.Id, r.PtcAddr, r.SvcAddr.String(), r.SvcOption, r.SvcPermNet, cts.RemoteAddr, cts.Id)
"Added route(%d,%s,%s,%v,%v,%s) for client %s to cts(%d)",
r.Id, r.PtcAddr, r.SvcAddr.String(), r.SvcOption, r.SvcPermNet, r.PtcName, cts.RemoteAddr, cts.Id)
err = cts.pss.Send(MakeRouteStartedPacket(r.Id, r.SvcOption, r.SvcAddr.String(), r.PtcName, r.SvcReqAddr, r.SvcPermNet.String()))
if err != nil {
r.ReqStop()
@ -2288,6 +2328,7 @@ func (s *Server) AddNewServerConn(remote_addr *net.Addr, local_addr *net.Addr, p
cts.S = s
cts.Created = time.Now()
cts.route_map = make(ServerRouteMap)
cts.route_map_by_ptc_name = make(ServerRouteMapByPtcName)
cts.RemoteAddr = *remote_addr
cts.LocalAddr = *local_addr
cts.pss = &GuardedPacketStreamServer{Hodu_PacketStreamServer: pss}
@ -2480,6 +2521,32 @@ func (s *Server) FindServerRouteById(id ConnId, route_id RouteId) *ServerRoute {
return cts.FindServerRouteById(route_id)
}
func (s *Server) FindServerRouteByClientTokenAndRouteId(token string, route_id RouteId) *ServerRoute {
var cts *ServerConn
var ok bool
s.cts_mtx.Lock()
defer s.cts_mtx.Unlock()
cts, ok = s.cts_map_by_token[token]
if !ok { return nil }
return cts.FindServerRouteById(route_id)
}
func (s *Server) FindServerRouteByClientTokenAndPtcName(token string, ptc_name string) *ServerRoute {
var cts *ServerConn
var ok bool
s.cts_mtx.Lock()
defer s.cts_mtx.Unlock()
cts, ok = s.cts_map_by_token[token]
if !ok { return nil }
return cts.FindServerRouteByPtcName(ptc_name)
}
func (s *Server) FindServerPeerConnById(id ConnId, route_id RouteId, peer_id PeerId) *ServerPeerConn {
var cts *ServerConn
var ok bool
@ -2556,10 +2623,10 @@ func (s *Server) FindServerPeerConnByIdStr(conn_id string, route_id string, peer
func (s *Server) FindServerRouteByIdStr(conn_id string, route_id string) (*ServerRoute, error) {
var r *ServerRoute
var err error
if route_id == PORT_ID_MARKER {
var port_nid uint64
var err error
port_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(PortId(0)) * 8))
if err != nil { return nil, fmt.Errorf("invalid port id %s - %s", conn_id, err.Error()) }
@ -2569,15 +2636,20 @@ func (s *Server) FindServerRouteByIdStr(conn_id string, route_id string) (*Serve
} else {
var conn_nid uint64
var route_nid uint64
var err1 error
var err2 error
conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(ConnId(0)) * 8))
if err != nil { return nil, fmt.Errorf("invalid connection id %s - %s", conn_id, err.Error()) }
conn_nid, err1 = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(ConnId(0)) * 8))
route_nid, err2 = strconv.ParseUint(route_id, 10, int(unsafe.Sizeof(RouteId(0)) * 8))
route_nid, err = strconv.ParseUint(route_id, 10, int(unsafe.Sizeof(RouteId(0)) * 8))
if err != nil { return nil, fmt.Errorf("invalid route id %s - %s", route_id, err.Error()) }
r = s.FindServerRouteById(ConnId(conn_nid), RouteId(route_nid))
if r == nil { return nil, fmt.Errorf("route(%d,%d) not found", conn_nid, route_nid) }
if err1 != nil || err2 != nil {
// if not numeric, attempt to use it as a token and ptc name
r = s.FindServerRouteByClientTokenAndPtcName(conn_id, route_id)
if r == nil { return nil, fmt.Errorf("route(%d,%d) not found", conn_id, route_id) }
} else {
r = s.FindServerRouteById(ConnId(conn_nid), RouteId(route_nid))
if r == nil { return nil, fmt.Errorf("route(%d,%d) not found", conn_nid, route_nid) }
}
}
return r, nil