diff --git a/client.go b/client.go index cb621f5..bb3058c 100644 --- a/client.go +++ b/client.go @@ -5,7 +5,7 @@ import "crypto/tls" import "errors" import "fmt" import "log" -import "math/rand" +//import "math/rand" import "net" import "net/http" import "strings" @@ -59,6 +59,7 @@ type Client struct { ptc_limit int // global maximum number of peers cts_limit int cts_mtx sync.Mutex + cts_next_id ConnId cts_map ClientConnMap wg sync.WaitGroup @@ -76,26 +77,27 @@ type Client struct { // client connection to server type ClientConn struct { - cli *Client - cfg ClientConfigActive - id ConnId - sid string // id rendered in string + cli *Client + cfg ClientConfigActive + id ConnId + sid string // id rendered in string - local_addr string - remote_addr string - conn *grpc.ClientConn // grpc connection to the server - hdc HoduClient - psc *GuardedPacketStreamClient // guarded grpc stream + local_addr string + remote_addr string + conn *grpc.ClientConn // grpc connection to the server + hdc HoduClient + psc *GuardedPacketStreamClient // guarded grpc stream - s_seed Seed - c_seed Seed + s_seed Seed + c_seed Seed - route_mtx sync.Mutex - route_map ClientRouteMap - route_wg sync.WaitGroup + route_mtx sync.Mutex + route_next_id RouteId + route_map ClientRouteMap + route_wg sync.WaitGroup - stop_req atomic.Bool - stop_chan chan bool + stop_req atomic.Bool + stop_chan chan bool } type ClientRoute struct { @@ -579,6 +581,7 @@ func NewClientConn(c *Client, cfg *ClientConfig) *ClientConn { cts.cli = c cts.route_map = make(ClientRouteMap) + cts.route_next_id = 0 cts.cfg.ClientConfig = *cfg cts.stop_req.Store(false) cts.stop_chan = make(chan bool, 8) @@ -591,32 +594,30 @@ func NewClientConn(c *Client, cfg *ClientConfig) *ClientConn { func (cts *ClientConn) AddNewClientRoute(addr string, server_peer_svc_addr string, server_peer_svc_net string, option RouteOption) (*ClientRoute, error) { var r *ClientRoute - var id RouteId - var nattempts RouteId + var start_id RouteId - nattempts = 0 - id = RouteId(rand.Uint32()) + //start_id = RouteId(rand.Uint64()) + start_id = cts.route_next_id cts.route_mtx.Lock() for { var ok bool - - _, ok = cts.route_map[id] + _, ok = cts.route_map[cts.route_next_id] if !ok { break } - id++ - nattempts++ - if nattempts == ^RouteId(0) { + cts.route_next_id++ + if cts.route_next_id == start_id { cts.route_mtx.Unlock() - return nil, fmt.Errorf("route map full") + return nil, fmt.Errorf("unable to assign id") } } - r = NewClientRoute(cts, id, addr, server_peer_svc_addr, server_peer_svc_net, option) - cts.route_map[id] = r + r = NewClientRoute(cts, cts.route_next_id, addr, server_peer_svc_addr, server_peer_svc_net, option) + cts.route_map[r.id] = r + cts.route_next_id++ cts.cli.stats.routes.Add(1) cts.route_mtx.Unlock() - cts.cli.log.Write(cts.sid, LOG_INFO, "Added route(%d,%s)", id, addr) + cts.cli.log.Write(cts.sid, LOG_INFO, "Added route(%d,%s)", r.id, addr) cts.route_wg.Add(1) go r.RunTask(&cts.route_wg) @@ -847,6 +848,7 @@ start_over: } cts.s_seed = *s_seed cts.c_seed = c_seed + cts.route_next_id = 0 // reset this whenever a new connection is made. the number of routes must be zero. cts.cli.log.Write(cts.sid, LOG_INFO, "Got seed from server[%d] %s - ver=%#x", cts.cfg.Index, cts.cfg.ServerAddrs[cts.cfg.Index], cts.s_seed.Version) @@ -875,6 +877,7 @@ start_over: goto done } } +// TODO: remember the previouslyu POSTed routes and readd them?? for { var pkt *Packet @@ -1098,6 +1101,7 @@ func NewClient(ctx context.Context, logger Logger, ctl_addrs []string, ctl_prefi c.ptc_tmout = peer_conn_tmout c.ptc_limit = peer_max c.cts_limit = rpc_max + c.cts_next_id = 0 c.cts_map = make(ClientConnMap) c.stop_req.Store(false) c.stop_chan = make(chan bool, 8) @@ -1139,7 +1143,7 @@ func NewClient(ctx context.Context, logger Logger, ctl_addrs []string, ctl_prefi func (c *Client) AddNewClientConn(cfg *ClientConfig) (*ClientConn, error) { var cts *ClientConn var ok bool - var id ConnId + var start_id ConnId if len(cfg.ServerAddrs) <= 0 { return nil, fmt.Errorf("no server rpc address specified") @@ -1154,18 +1158,24 @@ func (c *Client) AddNewClientConn(cfg *ClientConfig) (*ClientConn, error) { return nil, fmt.Errorf("too many connections - %d", c.cts_limit) } - //id = rand.Uint32() - id = ConnId(monotonic_time() / 1000) + //start_id = rand.Uint64() + //start_id = ConnId(monotonic_time() / 1000) + start_id = c.cts_next_id for { - _, ok = c.cts_map[id] + _, ok = c.cts_map[c.cts_next_id] if !ok { break } - id++ + c.cts_next_id++ + if c.cts_next_id == start_id { + c.cts_mtx.Lock() + return nil, fmt.Errorf("unable to assign id") + } } - cts.id = id - cts.cfg.Id = id // store it again in the active configuration for easy access via control channel - cts.sid = fmt.Sprintf("%d", id) // id in string used for logging + cts.id = c.cts_next_id + cts.cfg.Id = cts.id // store it again in the active configuration for easy access via control channel + cts.sid = fmt.Sprintf("%d", cts.id) // id in string used for logging - c.cts_map[id] = cts + c.cts_map[cts.id] = cts + c.cts_next_id++ c.stats.conns.Add(1) c.cts_mtx.Unlock() diff --git a/server.go b/server.go index 1334bf3..0870c1d 100644 --- a/server.go +++ b/server.go @@ -57,6 +57,7 @@ type Server struct { pts_limit int // global pts limit cts_limit int + cts_next_id ConnId cts_mtx sync.Mutex cts_map ServerConnMap cts_map_by_addr ServerConnMapByAddr @@ -76,21 +77,21 @@ type Server struct { // connection from client. // client connect to the server, the server accept it, and makes a tunnel request type ServerConn struct { - svr *Server - id ConnId - sid string // for logging + svr *Server + id ConnId + sid string // for logging - remote_addr net.Addr // client address that created this structure - local_addr net.Addr // local address that the client is connected to - pss *GuardedPacketStreamServer + remote_addr net.Addr // client address that created this structure + local_addr net.Addr // local address that the client is connected to + pss *GuardedPacketStreamServer - route_mtx sync.Mutex - route_map ServerRouteMap - route_wg sync.WaitGroup + route_mtx sync.Mutex + route_map ServerRouteMap + route_wg sync.WaitGroup - wg sync.WaitGroup - stop_req atomic.Bool - stop_chan chan bool + wg sync.WaitGroup + stop_req atomic.Bool + stop_chan chan bool } type ServerRoute struct { @@ -107,7 +108,7 @@ type ServerRoute struct { pts_mtx sync.Mutex pts_map ServerPeerConnMap pts_limit int - pts_last_id PeerId + pts_next_id PeerId pts_wg sync.WaitGroup stop_req atomic.Bool } @@ -181,7 +182,7 @@ func NewServerRoute(cts *ServerConn, id RouteId, option RouteOption, ptc_addr st r.ptc_addr = ptc_addr r.pts_limit = PTS_LIMIT r.pts_map = make(ServerPeerConnMap) - r.pts_last_id = 0 + r.pts_next_id = 0 r.stop_req.Store(false) return &r, nil @@ -199,22 +200,22 @@ func (r *ServerRoute) AddNewServerPeerConn(c *net.TCPConn) (*ServerPeerConn, err return nil, fmt.Errorf("peer-to-server connection table full") } - start_id = r.pts_last_id + start_id = r.pts_next_id for { - _, ok = r.pts_map[r.pts_last_id] + _, ok = r.pts_map[r.pts_next_id] if !ok { break } - r.pts_last_id++ - if r.pts_last_id == start_id { + r.pts_next_id++ + if r.pts_next_id == start_id { // unlikely to happen but it cycled through the whole range. return nil, fmt.Errorf("failed to assign peer-to-server connection id") } } - pts = NewServerPeerConn(r, c, r.pts_last_id) + pts = NewServerPeerConn(r, c, r.pts_next_id) r.pts_map[pts.conn_id] = pts - r.pts_last_id++ + r.pts_next_id++ r.cts.svr.stats.peers.Add(1) return pts, nil @@ -905,6 +906,7 @@ func NewServer(ctx context.Context, logger Logger, ctl_addrs []string, rpc_addrs s.ext_svcs = make([]Service, 0, 1) s.pts_limit = peer_max s.cts_limit = rpc_max + s.cts_next_id = 0 s.cts_map = make(ServerConnMap) s.cts_map_by_addr = make(ServerConnMapByAddr) s.stop_chan = make(chan bool, 8) @@ -1185,7 +1187,7 @@ func (s *Server) ReqStop() { func (s *Server) AddNewServerConn(remote_addr *net.Addr, local_addr *net.Addr, pss Hodu_PacketStreamServer) (*ServerConn, error) { var cts ServerConn - var id ConnId + var start_id ConnId var ok bool cts.svr = s @@ -1204,15 +1206,20 @@ func (s *Server) AddNewServerConn(remote_addr *net.Addr, local_addr *net.Addr, p return nil, fmt.Errorf("too many connections - %d", s.cts_limit) } - //id = rand.Uint32() - id = ConnId(monotonic_time() / 1000) + //start_id = rand.Uint64() + //start_id = ConnId(monotonic_time() / 1000) + start_id = s.cts_next_id for { - _, ok = s.cts_map[id] + _, ok = s.cts_map[s.cts_next_id] if !ok { break } - id++ + s.cts_next_id++ + if s.cts_next_id == start_id { + s.cts_mtx.Unlock() + return nil, fmt.Errorf("unable to assign id") + } } - cts.id = id - cts.sid = fmt.Sprintf("%d", id) // id in string used for logging + cts.id = s.cts_next_id + cts.sid = fmt.Sprintf("%d", cts.id) // id in string used for logging _, ok = s.cts_map_by_addr[cts.remote_addr] if ok { @@ -1220,7 +1227,8 @@ func (s *Server) AddNewServerConn(remote_addr *net.Addr, local_addr *net.Addr, p return nil, fmt.Errorf("existing client - %s", cts.remote_addr.String()) } s.cts_map_by_addr[cts.remote_addr] = &cts - s.cts_map[id] = &cts; + s.cts_map[cts.id] = &cts; + s.cts_next_id++; s.stats.conns.Store(int64(len(s.cts_map))) s.cts_mtx.Unlock()