diff --git a/client-ctl.go b/client-ctl.go index 5190bcd..15b99fc 100644 --- a/client-ctl.go +++ b/client-ctl.go @@ -96,7 +96,7 @@ func (ctl *client_ctl_client_conns) ServeHTTP(w http.ResponseWriter, req *http.R js = make([]json_out_client_conn, 0) c.cts_mtx.Lock() - for _, cts = range c.cts_map_by_id { + for _, cts = range c.cts_map { var r *ClientRoute var jsp []json_out_client_route @@ -145,11 +145,11 @@ func (ctl *client_ctl_client_conns) ServeHTTP(w http.ResponseWriter, req *http.R } case http.MethodDelete: - // delete all server conneections - var cts *ClientConn - c.cts_mtx.Lock() - for _, cts = range c.cts_map { cts.ReqStop() } - c.cts_mtx.Unlock() + // delete all client connections to servers. if we request to stop all + // client connections, they will remove themselves from the client. + // we do passive deletion rather than doing active deletion by calling + // c.RemoveAllClientConns() + c.ReqStopAllClientConns() status_code = http.StatusNoContent; w.WriteHeader(status_code) default: @@ -177,7 +177,6 @@ func (ctl *client_ctl_client_conns_id) ServeHTTP(w http.ResponseWriter, req *htt var conn_nid uint64 var je *json.Encoder - c = ctl.c je = json.NewEncoder(w) @@ -190,7 +189,6 @@ func (ctl *client_ctl_client_conns_id) ServeHTTP(w http.ResponseWriter, req *htt goto done } - switch req.Method { case http.MethodGet: var r *ClientRoute @@ -234,7 +232,6 @@ func (ctl *client_ctl_client_conns_id) ServeHTTP(w http.ResponseWriter, req *htt } return - done: // TODO: need to handle x-forwarded-for and other stuff? this is not a real web service, though c.log.Write("", LOG_DEBUG, "[%s] %s %s %d", req.RemoteAddr, req.Method, req.URL.String(), status_code) // TODO: time taken @@ -243,9 +240,10 @@ done: oops: c.log.Write("", LOG_ERROR, "[%s] %s %s - %s", req.RemoteAddr, req.Method, req.URL.String(), err.Error()) return - } +// ------------------------------------ + func (ctl *client_ctl_client_conns_id_routes) ServeHTTP(w http.ResponseWriter, req *http.Request) { var c *Client var status_code int @@ -313,7 +311,8 @@ func (ctl *client_ctl_client_conns_id_routes) ServeHTTP(w http.ResponseWriter, r } case http.MethodDelete: - cts.RemoveClientRoutes() + //cts.RemoveAllClientRoutes() + cts.ReqStopAllClientRoutes() status_code = http.StatusNoContent; w.WriteHeader(status_code) default: diff --git a/client.go b/client.go index 8d15d9f..d3f12ff 100644 --- a/client.go +++ b/client.go @@ -18,8 +18,8 @@ import "google.golang.org/grpc/status" type PacketStreamClient grpc.BidiStreamingClient[Packet, Packet] -type ClientConnMap = map[string]*ClientConn -type ClientConnMapById = map[uint32]*ClientConn +type ClientConnMapByAddr = map[string]*ClientConn +type ClientConnMap = map[uint32]*ClientConn type ClientPeerConnMap = map[uint32]*ClientPeerConn type ClientRouteMap = map[uint32]*ClientRoute type ClientPeerCancelFuncMap = map[uint32]context.CancelFunc @@ -45,8 +45,8 @@ type Client struct { ctl *http.Server // control server cts_mtx sync.Mutex + cts_map_by_addr ClientConnMapByAddr cts_map ClientConnMap - cts_map_by_id ClientConnMapById wg sync.WaitGroup stop_req atomic.Bool @@ -408,15 +408,27 @@ fmt.Printf("added client route.... %d -> %d\n", id, len(cts.route_map)) return r, nil } -func (cts *ClientConn) RemoveClientRoutes() { +func (cts *ClientConn) ReqStopAllClientRoutes() { var r *ClientRoute cts.route_mtx.Lock() + defer cts.route_mtx.Unlock() + + for _, r = range cts.route_map { + r.ReqStop() + } +} + +func (cts *ClientConn) RemoveAllClientRoutes() { + var r *ClientRoute + + cts.route_mtx.Lock() + defer cts.route_mtx.Unlock() + for _, r = range cts.route_map { delete(cts.route_map, r.id) r.ReqStop() } - cts.route_mtx.Unlock() } func (cts *ClientConn) RemoveClientRoute(route *ClientRoute) error { @@ -766,8 +778,8 @@ func NewClient(ctx context.Context, listen_on string, logger Logger, tlscfg *tls c.ctx, c.ctx_cancel = context.WithCancel(ctx) c.tlscfg = tlscfg c.ext_svcs = make([]Service, 0, 1) + c.cts_map_by_addr = make(ClientConnMapByAddr) c.cts_map = make(ClientConnMap) - c.cts_map_by_id = make(ClientConnMapById) c.stop_req.Store(false) c.stop_chan = make(chan bool, 8) c.log = logger @@ -800,14 +812,14 @@ func (c *Client) AddNewClientConn(cfg *ClientConfig) (*ClientConn, error) { c.cts_mtx.Lock() defer c.cts_mtx.Unlock() - _, ok = c.cts_map[cfg.ServerAddr] + _, ok = c.cts_map_by_addr[cfg.ServerAddr] if ok { return nil, fmt.Errorf("existing server - %s", cfg.ServerAddr) } id = rand.Uint32() for { - _, ok = c.cts_map_by_id[id] + _, ok = c.cts_map[id] if !ok { break } id++ } @@ -815,19 +827,43 @@ func (c *Client) AddNewClientConn(cfg *ClientConfig) (*ClientConn, error) { cts.cfg.Id = id // store it again in the active configuration for easy access via control channel cts.lid = fmt.Sprintf("%d", id) // id in string used for logging - c.cts_map[cfg.ServerAddr] = cts - c.cts_map_by_id[id] = cts -fmt.Printf("ADD total servers %d\n", len(c.cts_map)) + c.cts_map_by_addr[cfg.ServerAddr] = cts + c.cts_map[id] = cts +fmt.Printf("ADD total servers %d\n", len(c.cts_map_by_addr)) return cts, nil } +func (c* Client) ReqStopAllClientConns() { + var cts *ClientConn + + c.cts_mtx.Lock() + defer c.cts_mtx.Unlock() + + for _, cts = range c.cts_map { + cts.ReqStop() + } +} + +func (c *Client) RemoveAllClientConns() { + var cts *ClientConn + + c.cts_mtx.Lock() + defer c.cts_mtx.Unlock() + + for _, cts = range c.cts_map { + delete(c.cts_map_by_addr, cts.cfg.ServerAddr) + delete(c.cts_map, cts.id) + cts.ReqStop() + } +} + func (c *Client) RemoveClientConn(cts *ClientConn) error { var conn *ClientConn var ok bool c.cts_mtx.Lock() - conn, ok = c.cts_map[cts.cfg.ServerAddr] + conn, ok = c.cts_map[cts.id] if !ok { c.cts_mtx.Unlock() return fmt.Errorf("non-existent connection id - %d", cts.id) @@ -837,9 +873,9 @@ func (c *Client) RemoveClientConn(cts *ClientConn) error { return fmt.Errorf("non-existent connection id - %d", cts.id) } - delete(c.cts_map, cts.cfg.ServerAddr) - delete(c.cts_map_by_id, cts.id) -fmt.Printf("REMOVEDDDDDD CONNECTION FROM %s total servers %d\n", cts.cfg.ServerAddr, len(c.cts_map)) + delete(c.cts_map_by_addr, cts.cfg.ServerAddr) + delete(c.cts_map, cts.id) +fmt.Printf("REMOVEDDDDDD CONNECTION FROM %s total servers %d\n", cts.cfg.ServerAddr, len(c.cts_map_by_addr)) c.cts_mtx.Unlock() c.ReqStop() @@ -852,7 +888,7 @@ func (c *Client) RemoveClientConnById(conn_id uint32) error { c.cts_mtx.Lock() - cts, ok = c.cts_map_by_id[conn_id] + cts, ok = c.cts_map[conn_id] if !ok { c.cts_mtx.Unlock() return fmt.Errorf("non-existent connection id - %d", conn_id) @@ -860,9 +896,9 @@ func (c *Client) RemoveClientConnById(conn_id uint32) error { // NOTE: removal by id doesn't perform identity check - delete(c.cts_map, cts.cfg.ServerAddr) - delete(c.cts_map_by_id, cts.id) -fmt.Printf("REMOVEDDDDDD CONNECTION FROM %s total servers %d\n", cts.cfg.ServerAddr, len(c.cts_map)) + delete(c.cts_map_by_addr, cts.cfg.ServerAddr) + delete(c.cts_map, cts.id) +fmt.Printf("REMOVEDDDDDD CONNECTION FROM %s total servers %d\n", cts.cfg.ServerAddr, len(c.cts_map_by_addr)) c.cts_mtx.Unlock() cts.ReqStop() @@ -876,7 +912,7 @@ func (c *Client) FindClientConnById(id uint32) *ClientConn { c.cts_mtx.Lock() defer c.cts_mtx.Unlock() - cts, ok = c.cts_map_by_id[id] + cts, ok = c.cts_map[id] if !ok { return nil }