update the proxy code to resolve a route by the client token and the client-side peer name
This commit is contained in:
16
client.go
16
client.go
@ -490,13 +490,13 @@ func (r *ClientRoute) RunTask(wg *sync.WaitGroup) {
|
||||
err = r.cts.psc.Send(MakeRouteStartPacket(r.Id, r.ServerPeerOption, r.PeerAddr, r.PeerName, r.ReqServerPeerSvcAddr, r.ReqServerPeerSvcNet))
|
||||
if err != nil {
|
||||
r.cts.C.log.Write(r.cts.Sid, LOG_DEBUG,
|
||||
"Failed to send route_start for route(%d,%s,%v,%s,%s) to %s - %s",
|
||||
r.Id, r.PeerAddr, r.ServerPeerOption, r.ReqServerPeerSvcAddr, r.ReqServerPeerSvcNet, r.cts.remote_addr_p, err.Error())
|
||||
"Failed to send route_start for route(%d,%s,%v,%s,%s,%s) to %s - %s",
|
||||
r.Id, r.PeerAddr, r.ServerPeerOption, r.ReqServerPeerSvcAddr, r.ReqServerPeerSvcNet, r.PeerName, r.cts.remote_addr_p, err.Error())
|
||||
goto done
|
||||
} else {
|
||||
r.cts.C.log.Write(r.cts.Sid, LOG_DEBUG,
|
||||
"Sent route_start for route(%d,%s,%v,%s,%s) to %s",
|
||||
r.Id, r.PeerAddr, r.ServerPeerOption, r.ReqServerPeerSvcAddr, r.ReqServerPeerSvcNet, r.cts.remote_addr_p)
|
||||
"Sent route_start for route(%d,%s,%v,%s,%s,%s) to %s",
|
||||
r.Id, r.PeerAddr, r.ServerPeerOption, r.ReqServerPeerSvcAddr, r.ReqServerPeerSvcNet, r.PeerName, r.cts.remote_addr_p)
|
||||
}
|
||||
|
||||
r.ptc_wg.Add(1) // increment counter here
|
||||
@ -702,9 +702,9 @@ func (r *ClientRoute) ReportPacket(pts_id PeerId, packet_type PACKET_KIND, event
|
||||
r.cts.C.FireRouteEvent(CLIENT_EVENT_ROUTE_UPDATED, r)
|
||||
|
||||
r.cts.C.log.Write(r.cts.Sid, LOG_INFO,
|
||||
"Ingested route_started(%d,%s,%s) for route(%d,%s,%v,%s,%s)",
|
||||
"Ingested route_started(%d,%s,%s) for route(%d,%s,%v,%s,%s,%s)",
|
||||
rd.RouteId, rd.TargetAddrStr, rd.ServiceNetStr,
|
||||
r.Id, r.PeerAddr, r.ServerPeerOption, r.ReqServerPeerSvcAddr, r.ReqServerPeerSvcNet)
|
||||
r.Id, r.PeerAddr, r.ServerPeerOption, r.ReqServerPeerSvcAddr, r.ReqServerPeerSvcNet, r.PeerName)
|
||||
}
|
||||
}
|
||||
|
||||
@ -721,9 +721,9 @@ func (r *ClientRoute) ReportPacket(pts_id PeerId, packet_type PACKET_KIND, event
|
||||
r.cts.C.log.Write(r.cts.Sid, LOG_ERROR, "Protocol error - invalid data in route_started event(%d)", r.Id)
|
||||
} else {
|
||||
r.cts.C.log.Write(r.cts.Sid, LOG_INFO,
|
||||
"Ingested route_stopped(%d,%s,%s) for route(%d,%s,%v,%s,%s)",
|
||||
"Ingested route_stopped(%d,%s,%s) for route(%d,%s,%v,%s,%s,%s)",
|
||||
rd.RouteId, rd.TargetAddrStr, rd.ServiceNetStr,
|
||||
r.Id, r.PeerAddr, r.ServerPeerOption, r.ReqServerPeerSvcAddr, r.ReqServerPeerSvcNet)
|
||||
r.Id, r.PeerAddr, r.ServerPeerOption, r.ReqServerPeerSvcAddr, r.ReqServerPeerSvcNet, r.PeerName)
|
||||
}
|
||||
r.ReqStop()
|
||||
|
||||
|
@ -51,14 +51,22 @@ func _is_ansi_tty(fd uintptr) bool {
|
||||
}
|
||||
|
||||
|
||||
func NewAppLogger (id string, w io.Writer, mask hodu.LogMask) *AppLogger {
|
||||
func NewAppLogger(id string, w io.Writer, mask hodu.LogMask) *AppLogger {
|
||||
var l *AppLogger
|
||||
var f *os.File
|
||||
var ok bool
|
||||
var use_color bool
|
||||
|
||||
use_color = false
|
||||
f, ok = w.(*os.File)
|
||||
if ok { use_color = _is_ansi_tty(f.Fd()) }
|
||||
|
||||
l = &AppLogger{
|
||||
id: id,
|
||||
out: w,
|
||||
mask: mask,
|
||||
msg_chan: make(chan app_logger_msg_t, 256),
|
||||
use_color: false,
|
||||
use_color: use_color,
|
||||
}
|
||||
l.closed.Store(false)
|
||||
l.wg.Add(1)
|
||||
@ -66,7 +74,7 @@ func NewAppLogger (id string, w io.Writer, mask hodu.LogMask) *AppLogger {
|
||||
return l
|
||||
}
|
||||
|
||||
func NewAppLoggerToFile (id string, file_name string, max_size int64, rotate int, mask hodu.LogMask) (*AppLogger, error) {
|
||||
func NewAppLoggerToFile(id string, file_name string, max_size int64, rotate int, mask hodu.LogMask) (*AppLogger, error) {
|
||||
var l *AppLogger
|
||||
var f *os.File
|
||||
var matched bool
|
||||
|
@ -191,7 +191,7 @@ func (pxy *server_pxy) Authenticate(req *http.Request) (int, string) {
|
||||
|
||||
// ------------------------------------
|
||||
|
||||
func prevent_follow_redirect (req *http.Request, via []*http.Request) error {
|
||||
func prevent_follow_redirect(req *http.Request, via []*http.Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
}
|
||||
|
||||
@ -289,7 +289,7 @@ func (pxy *server_pxy_http_main) serve_upgraded(w http.ResponseWriter, req *http
|
||||
return err
|
||||
}
|
||||
|
||||
func (pxy *server_pxy_http_main) addr_to_transport (ctx context.Context, addr *net.TCPAddr) (*http.Transport, error) {
|
||||
func (pxy *server_pxy_http_main) addr_to_transport(ctx context.Context, addr *net.TCPAddr) (*http.Transport, error) {
|
||||
var dialer *net.Dialer
|
||||
var waitctx context.Context
|
||||
var cancel_wait context.CancelFunc
|
||||
@ -318,7 +318,7 @@ func (pxy *server_pxy_http_main) addr_to_transport (ctx context.Context, addr *n
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (pxy *server_pxy_http_main) req_to_proxy_url (req *http.Request, r *ServerRouteProxyInfo) *url.URL {
|
||||
func (pxy *server_pxy_http_main) req_to_proxy_url(req *http.Request, r *ServerRouteProxyInfo) *url.URL {
|
||||
var proxy_proto string
|
||||
var proxy_url_path string
|
||||
|
||||
@ -568,7 +568,7 @@ func (pxy *server_pxy_ssh_ws) Identity() string {
|
||||
// TODO: put this task to sync group.
|
||||
// TODO: put the above proxy task to sync group too.
|
||||
|
||||
func (pxy *server_pxy_ssh_ws) connect_ssh (ctx context.Context, username string, password string, r *ServerRoute) (*ssh.Client, *ssh.Session, io.Writer, io.Reader, error) {
|
||||
func (pxy *server_pxy_ssh_ws) connect_ssh(ctx context.Context, username string, password string, r *ServerRoute) (*ssh.Client, *ssh.Session, io.Writer, io.Reader, error) {
|
||||
var cc *ssh.ClientConfig
|
||||
var addr *net.TCPAddr
|
||||
var dialer *net.Dialer
|
||||
|
94
server.go
94
server.go
@ -41,6 +41,7 @@ type ServerConnMapByAddr map[net.Addr]*ServerConn
|
||||
type ServerConnMapByClientToken map[string]*ServerConn
|
||||
type ServerConnMap map[ConnId]*ServerConn
|
||||
type ServerRouteMap map[RouteId]*ServerRoute
|
||||
type ServerRouteMapByPtcName map[string]*ServerRoute
|
||||
type ServerPeerConnMap map[PeerId]*ServerPeerConn
|
||||
type ServerSvcPortMap map[PortId]ConnRouteId
|
||||
|
||||
@ -199,6 +200,7 @@ type ServerConn struct {
|
||||
|
||||
route_mtx sync.Mutex
|
||||
route_map ServerRouteMap
|
||||
route_map_by_ptc_name ServerRouteMapByPtcName
|
||||
route_wg sync.WaitGroup
|
||||
|
||||
pts_mtx sync.Mutex
|
||||
@ -581,10 +583,12 @@ func (cts *ServerConn) make_route_listener(id RouteId, option RouteOption, svc_r
|
||||
|
||||
func (cts *ServerConn) AddNewServerRoute(route_id RouteId, proto RouteOption, ptc_addr string, ptc_name string, svc_requested_addr string, svc_permitted_net string) (*ServerRoute, error) {
|
||||
var r *ServerRoute
|
||||
var ok bool
|
||||
var err error
|
||||
|
||||
cts.route_mtx.Lock()
|
||||
if cts.route_map[route_id] != nil {
|
||||
_, ok = cts.route_map[route_id]
|
||||
if ok {
|
||||
// If this happens, something must be wrong between the server and the client
|
||||
// most likely, it must be a logic error. the state must not go out of sync
|
||||
// as the route_id and the peer_id are supposed to be the same between the client
|
||||
@ -592,12 +596,27 @@ func (cts *ServerConn) AddNewServerRoute(route_id RouteId, proto RouteOption, pt
|
||||
cts.route_mtx.Unlock()
|
||||
return nil, fmt.Errorf("existent route id - %d", route_id)
|
||||
}
|
||||
|
||||
if ptc_name != "" {
|
||||
// ptc name can be empty. but if not empty, it must be unique
|
||||
_, ok = cts.route_map_by_ptc_name[ptc_name]
|
||||
if ok {
|
||||
cts.route_mtx.Unlock()
|
||||
return nil, fmt.Errorf("existent ptc name %s for route %d", ptc_name, route_id)
|
||||
}
|
||||
}
|
||||
|
||||
r, err = NewServerRoute(cts, route_id, proto, ptc_addr, ptc_name, svc_requested_addr, svc_permitted_net)
|
||||
if err != nil {
|
||||
cts.route_mtx.Unlock()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cts.route_map[route_id] = r
|
||||
if ptc_name != "" {
|
||||
cts.route_map_by_ptc_name[ptc_name] = r
|
||||
}
|
||||
|
||||
cts.S.stats.routes.Add(1)
|
||||
cts.route_mtx.Unlock()
|
||||
|
||||
@ -625,6 +644,9 @@ func (cts *ServerConn) RemoveServerRoute(route *ServerRoute) error {
|
||||
cts.route_mtx.Unlock()
|
||||
return fmt.Errorf("non-existent route - %d", route.Id)
|
||||
}
|
||||
if route.PtcName != "" {
|
||||
delete(cts.route_map_by_ptc_name, route.PtcName)
|
||||
}
|
||||
delete(cts.route_map, route.Id)
|
||||
cts.S.stats.routes.Add(-1)
|
||||
cts.route_mtx.Unlock()
|
||||
@ -647,6 +669,9 @@ func (cts *ServerConn) RemoveServerRouteById(route_id RouteId) (*ServerRoute, er
|
||||
cts.route_mtx.Unlock()
|
||||
return nil, fmt.Errorf("non-existent route id - %d", route_id)
|
||||
}
|
||||
if r.PtcName != "" {
|
||||
delete(cts.route_map_by_ptc_name, r.PtcName)
|
||||
}
|
||||
delete(cts.route_map, route_id)
|
||||
cts.S.stats.routes.Add(-1)
|
||||
cts.route_mtx.Unlock()
|
||||
@ -674,6 +699,21 @@ func (cts *ServerConn) FindServerRouteById(route_id RouteId) *ServerRoute {
|
||||
return r
|
||||
}
|
||||
|
||||
func (cts *ServerConn) FindServerRouteByPtcName(ptc_name string) *ServerRoute {
|
||||
var r *ServerRoute
|
||||
var ok bool
|
||||
|
||||
cts.route_mtx.Lock()
|
||||
r, ok = cts.route_map_by_ptc_name[ptc_name]
|
||||
if !ok {
|
||||
cts.route_mtx.Unlock()
|
||||
return nil
|
||||
}
|
||||
cts.route_mtx.Unlock()
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
func (cts *ServerConn) FindServerPeerConnById(route_id RouteId, peer_id PeerId) *ServerPeerConn {
|
||||
var r *ServerRoute
|
||||
var ok bool
|
||||
@ -1009,8 +1049,8 @@ func (cts *ServerConn) receive_from_stream(wg *sync.WaitGroup) {
|
||||
|
||||
} else {
|
||||
cts.S.log.Write(cts.Sid, LOG_INFO,
|
||||
"Added route(%d,%s,%s,%v,%v) for client %s to cts(%d)",
|
||||
r.Id, r.PtcAddr, r.SvcAddr.String(), r.SvcOption, r.SvcPermNet, cts.RemoteAddr, cts.Id)
|
||||
"Added route(%d,%s,%s,%v,%v,%s) for client %s to cts(%d)",
|
||||
r.Id, r.PtcAddr, r.SvcAddr.String(), r.SvcOption, r.SvcPermNet, r.PtcName, cts.RemoteAddr, cts.Id)
|
||||
err = cts.pss.Send(MakeRouteStartedPacket(r.Id, r.SvcOption, r.SvcAddr.String(), r.PtcName, r.SvcReqAddr, r.SvcPermNet.String()))
|
||||
if err != nil {
|
||||
r.ReqStop()
|
||||
@ -2288,6 +2328,7 @@ func (s *Server) AddNewServerConn(remote_addr *net.Addr, local_addr *net.Addr, p
|
||||
cts.S = s
|
||||
cts.Created = time.Now()
|
||||
cts.route_map = make(ServerRouteMap)
|
||||
cts.route_map_by_ptc_name = make(ServerRouteMapByPtcName)
|
||||
cts.RemoteAddr = *remote_addr
|
||||
cts.LocalAddr = *local_addr
|
||||
cts.pss = &GuardedPacketStreamServer{Hodu_PacketStreamServer: pss}
|
||||
@ -2480,6 +2521,32 @@ func (s *Server) FindServerRouteById(id ConnId, route_id RouteId) *ServerRoute {
|
||||
return cts.FindServerRouteById(route_id)
|
||||
}
|
||||
|
||||
func (s *Server) FindServerRouteByClientTokenAndRouteId(token string, route_id RouteId) *ServerRoute {
|
||||
var cts *ServerConn
|
||||
var ok bool
|
||||
|
||||
s.cts_mtx.Lock()
|
||||
defer s.cts_mtx.Unlock()
|
||||
|
||||
cts, ok = s.cts_map_by_token[token]
|
||||
if !ok { return nil }
|
||||
|
||||
return cts.FindServerRouteById(route_id)
|
||||
}
|
||||
|
||||
func (s *Server) FindServerRouteByClientTokenAndPtcName(token string, ptc_name string) *ServerRoute {
|
||||
var cts *ServerConn
|
||||
var ok bool
|
||||
|
||||
s.cts_mtx.Lock()
|
||||
defer s.cts_mtx.Unlock()
|
||||
|
||||
cts, ok = s.cts_map_by_token[token]
|
||||
if !ok { return nil }
|
||||
|
||||
return cts.FindServerRouteByPtcName(ptc_name)
|
||||
}
|
||||
|
||||
func (s *Server) FindServerPeerConnById(id ConnId, route_id RouteId, peer_id PeerId) *ServerPeerConn {
|
||||
var cts *ServerConn
|
||||
var ok bool
|
||||
@ -2556,10 +2623,10 @@ func (s *Server) FindServerPeerConnByIdStr(conn_id string, route_id string, peer
|
||||
|
||||
func (s *Server) FindServerRouteByIdStr(conn_id string, route_id string) (*ServerRoute, error) {
|
||||
var r *ServerRoute
|
||||
var err error
|
||||
|
||||
if route_id == PORT_ID_MARKER {
|
||||
var port_nid uint64
|
||||
var err error
|
||||
|
||||
port_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(PortId(0)) * 8))
|
||||
if err != nil { return nil, fmt.Errorf("invalid port id %s - %s", conn_id, err.Error()) }
|
||||
@ -2569,15 +2636,20 @@ func (s *Server) FindServerRouteByIdStr(conn_id string, route_id string) (*Serve
|
||||
} else {
|
||||
var conn_nid uint64
|
||||
var route_nid uint64
|
||||
var err1 error
|
||||
var err2 error
|
||||
|
||||
conn_nid, err = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(ConnId(0)) * 8))
|
||||
if err != nil { return nil, fmt.Errorf("invalid connection id %s - %s", conn_id, err.Error()) }
|
||||
conn_nid, err1 = strconv.ParseUint(conn_id, 10, int(unsafe.Sizeof(ConnId(0)) * 8))
|
||||
route_nid, err2 = strconv.ParseUint(route_id, 10, int(unsafe.Sizeof(RouteId(0)) * 8))
|
||||
|
||||
route_nid, err = strconv.ParseUint(route_id, 10, int(unsafe.Sizeof(RouteId(0)) * 8))
|
||||
if err != nil { return nil, fmt.Errorf("invalid route id %s - %s", route_id, err.Error()) }
|
||||
|
||||
r = s.FindServerRouteById(ConnId(conn_nid), RouteId(route_nid))
|
||||
if r == nil { return nil, fmt.Errorf("route(%d,%d) not found", conn_nid, route_nid) }
|
||||
if err1 != nil || err2 != nil {
|
||||
// if not numeric, attempt to use it as a token and ptc name
|
||||
r = s.FindServerRouteByClientTokenAndPtcName(conn_id, route_id)
|
||||
if r == nil { return nil, fmt.Errorf("route(%d,%d) not found", conn_id, route_id) }
|
||||
} else {
|
||||
r = s.FindServerRouteById(ConnId(conn_nid), RouteId(route_nid))
|
||||
if r == nil { return nil, fmt.Errorf("route(%d,%d) not found", conn_nid, route_nid) }
|
||||
}
|
||||
}
|
||||
|
||||
return r, nil
|
||||
|
Reference in New Issue
Block a user