diff --git a/client.go b/client.go index 53f31f4..0d81479 100644 --- a/client.go +++ b/client.go @@ -143,7 +143,7 @@ func NewClientRoute(cts *ServerConn, id uint32, addr *net.TCPAddr, proto ROUTE_P r.proto = proto r.peer_addr = addr r.stop_req.Store(false) - r.stop_chan = make(chan bool, 1) + r.stop_chan = make(chan bool, 8) return &r } @@ -335,7 +335,7 @@ func NewServerConn(c *Client, addr *net.TCPAddr, cfg *ClientConfig) *ServerConn cts.saddr = addr cts.cfg = cfg cts.stop_req.Store(false) - cts.stop_chan = make(chan bool, 1) + cts.stop_chan = make(chan bool, 8) // the actual connection to the server is established in the main task function // The cts.conn, cts.hdc, cts.psc fields are left unassigned here. @@ -361,7 +361,28 @@ fmt.Printf ("added client route.... %d -> %d\n", route_id, len(cts.route_map)) return r, nil } -func (cts *ServerConn) RemoveClientRoute (route_id uint32) error { +func (cts *ServerConn) RemoveClientRoute(route *ClientRoute) error { + var r *ClientRoute + var ok bool + + cts.route_mtx.Lock() + r, ok = cts.route_map[route.id] + if (!ok) { + cts.route_mtx.Unlock() + return fmt.Errorf ("non-existent route id - %d", route.id) + } + if r != route { + cts.route_mtx.Unlock() + return fmt.Errorf ("non-existent route id - %d", route.id) + } + delete(cts.route_map, route.id) + cts.route_mtx.Unlock() + + r.ReqStop() + return nil +} + +func (cts *ServerConn) RemoveClientRouteById(route_id uint32) error { var r *ClientRoute var ok bool @@ -374,27 +395,10 @@ func (cts *ServerConn) RemoveClientRoute (route_id uint32) error { delete(cts.route_map, route_id) cts.route_mtx.Unlock() - r.ReqStop() // TODO: make this unblocking or blocking? + r.ReqStop() return nil } -func (cts *ServerConn) RemoveClientRoutes () { - var r *ClientRoute - var id uint32 - - cts.route_mtx.Lock() - for _, r = range cts.route_map { - r.ReqStop() - } - - for id, r = range cts.route_map { - delete(cts.route_map, id) - } - - cts.route_map = make(ClientRouteMap) - cts.route_mtx.Unlock() -} - func (cts *ServerConn) AddClientRoutes (peer_addrs []string) error { var i int var v string @@ -421,9 +425,11 @@ func (cts *ServerConn) AddClientRoutes (peer_addrs []string) error { } } +// TODO: mutex protection for _, r = range cts.route_map { err = cts.psc.Send(MakeRouteStartPacket(r.id, r.proto, addr.String())) if err != nil { +// TODO: remove all routes??? return fmt.Errorf("unable to send route-start packet - %s", err.Error()) } } @@ -431,20 +437,41 @@ func (cts *ServerConn) AddClientRoutes (peer_addrs []string) error { return nil } +func (cts *ServerConn) RemoveClientRoutes () { + var r *ClientRoute + var id uint32 + + cts.route_mtx.Lock() + for _, r = range cts.route_map { + r.ReqStop() + } + + for id, r = range cts.route_map { + delete(cts.route_map, id) + } + + cts.route_map = make(ClientRouteMap) + cts.route_mtx.Unlock() + +// TODO: mutex protection? + for _, r = range cts.route_map { + cts.psc.Send(MakeRouteStopPacket(r.id, r.proto, r.peer_addr.String())) + } +} + func (cts *ServerConn) ReqStop() { if cts.stop_req.CompareAndSwap(false, true) { var r *ClientRoute cts.route_mtx.Lock() - for _, r = range cts.route_map { + for _, r = range cts.route_map { + cts.psc.Send(MakeRouteStopPacket(r.id, r.proto, r.peer_addr.String())) // don't care about failure r.ReqStop() } cts.route_mtx.Unlock() - // TODO: notify the server.. send term command??? cts.stop_chan <- true } -fmt.Printf ("*** Sent stop request to ServerConn..\n") } func (cts *ServerConn) RunTask(wg *sync.WaitGroup) { @@ -461,11 +488,10 @@ func (cts *ServerConn) RunTask(wg *sync.WaitGroup) { // TODO: HANDLE connection timeout.. // ctx, _/*cancel*/ := context.WithTimeout(context.Background(), time.Second) start_over: -fmt.Printf ("Connecting GRPC to [%s]\n", cts.saddr.String()) + cts.cli.log.Write("", LOG_INFO, "Connecting to server %s", cts.saddr.String()) conn, err = grpc.NewClient(cts.saddr.String(), grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { - // TODO: logging - fmt.Printf("ERROR: unable to make grpc client to %s - %s\n", cts.cfg.ServerAddr, err.Error()) + cts.cli.log.Write("", LOG_ERROR, "Failed to connect to server %s - %s", cts.saddr.String(), err.Error()) goto reconnect_to_server } @@ -477,7 +503,7 @@ fmt.Printf ("Connecting GRPC to [%s]\n", cts.saddr.String()) c_seed.Flags = 0 s_seed, err = hdc.GetSeed(cts.cli.ctx, &c_seed) if err != nil { - fmt.Printf("ERROR: unable to get seed from %s - %s\n", cts.cfg.ServerAddr, err.Error()) + cts.cli.log.Write("", LOG_ERROR, "Failed to get seed from server %s - %s", cts.saddr.String(), err.Error()) goto reconnect_to_server } cts.s_seed = *s_seed @@ -485,10 +511,12 @@ fmt.Printf ("Connecting GRPC to [%s]\n", cts.saddr.String()) psc, err = hdc.PacketStream(cts.cli.ctx) if err != nil { - fmt.Printf ("ERROR: unable to get grpc packet stream - %s\n", err.Error()) + cts.cli.log.Write("", LOG_ERROR, "Failed to get packet stream from server %s - %s", cts.saddr.String(), err.Error()) goto reconnect_to_server } + cts.cli.log.Write("", LOG_INFO, "Got packet stream from server %s", cts.saddr.String()) + cts.conn = conn cts.hdc = hdc //cts.psc = &GuardedPacketStreamClient{psc: psc} @@ -498,9 +526,10 @@ fmt.Printf ("Connecting GRPC to [%s]\n", cts.saddr.String()) // let's add routes to the client-side peers. err = cts.AddClientRoutes(cts.cfg.PeerAddrs) if err != nil { - fmt.Printf ("ERROR: unable to add routes to client-side peers - %s\n", err.Error()) + cts.cli.log.Write("", LOG_INFO, "Failed to add routes to server %s for %v - %s", cts.saddr.String(), cts.cfg.PeerAddrs, err.Error()) goto done } + fmt.Printf("[%v]\n", cts.route_map) for { @@ -508,7 +537,7 @@ fmt.Printf("[%v]\n", cts.route_map) select { case <-cts.cli.ctx.Done(): - fmt.Printf("context doine... error - %s\n", cts.cli.ctx.Err().Error()) +fmt.Printf("context doine... error - %s\n", cts.cli.ctx.Err().Error()) goto done case <-cts.stop_chan: @@ -520,13 +549,13 @@ fmt.Printf("[%v]\n", cts.route_map) } pkt, err = psc.Recv() - if errors.Is(err, io.EOF) { - fmt.Printf("server disconnected\n") - goto reconnect_to_server - } if err != nil { - fmt.Printf("server receive error - %s\n", err.Error()) - goto reconnect_to_server + if errors.Is(err, io.EOF) { + goto reconnect_to_server + } else { + cts.cli.log.Write("", LOG_INFO, "Failed to receive packet form server %s - %s", cts.saddr.String(), err.Error()) + goto reconnect_to_server + } } switch pkt.Kind { @@ -611,7 +640,7 @@ fmt.Printf("[%v]\n", cts.route_map) case PACKET_KIND_PEER_DATA: // the connection from the client to a peer has been established - fmt.Printf ("**** GOT PEER DATA\n") + //fmt.Printf ("**** GOT PEER DATA\n") var x *Packet_Data var ok bool x, ok = pkt.U.(*Packet_Data) @@ -630,24 +659,17 @@ fmt.Printf("[%v]\n", cts.route_map) } done: -fmt.Printf ("^^^^^^^^^^^^^^^^^^^^ Server Coon RunTask ending...\n") - if conn != nil { - conn.Close() - // TODO: need to reset c.sc, c.sg, c.psc to nil? - // for this we need to ensure that everyone is ending - } + cts.cli.log.Write("", LOG_INFO, "Disconnected from server %s", cts.saddr.String()) cts.RemoveClientRoutes() + if conn != nil { conn.Close() } cts.route_wg.Wait() // wait until all route tasks are finished return reconnect_to_server: - if conn != nil { - conn.Close() - // TODO: need to reset c.sc, c.sg, c.psc to nil? - // for this we need to ensure that everyone is ending - } cts.RemoveClientRoutes() - slpctx, _ = context.WithTimeout(cts.cli.ctx, 3 * time.Second) + if conn != nil { conn.Close() } + // wait for 2 seconds + slpctx, _ = context.WithTimeout(cts.cli.ctx, 2 * time.Second) select { case <-cts.cli.ctx.Done(): fmt.Printf("context doine... error - %s\n", cts.cli.ctx.Err().Error()) @@ -657,7 +679,7 @@ reconnect_to_server: case <- slpctx.Done(): // do nothing } - goto start_over + goto start_over // and reconnect } func (cts *ServerConn) ReportEvent (route_id uint32, pts_id uint32, event_type PACKET_KIND, event_data []byte) error { @@ -722,7 +744,7 @@ func NewClient(ctx context.Context, listen_on string, logger Logger, tlscfg *tls c.ext_svcs = make([]Service, 0, 1) c.cts_map = make(ServerConnMap) // TODO: make it configurable... c.stop_req.Store(false) - c.stop_chan = make(chan bool, 1) + c.stop_chan = make(chan bool, 8) c.log = logger c.ctl = &http.Server{ @@ -763,7 +785,9 @@ func (c *Client) ReqStop() { if c.stop_req.CompareAndSwap(false, true) { var cts *ServerConn - c.ctl.Shutdown(c.ctx) // to break c.ctl.ListenAndServe() + if (c.ctl != nil) { + c.ctl.Shutdown(c.ctx) // to break c.ctl.ListenAndServe() + } for _, cts = range c.cts_map { cts.ReqStop() diff --git a/cmd/main.go b/cmd/main.go index 0199866..4c89902 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -156,7 +156,7 @@ func server_main(laddrs []string) error { return fmt.Errorf("ERROR: failed to load key pair - %s\n", err) } - s, err = hodu.NewServer(laddrs, &AppLogger{id: "server", out: os.Stderr}, &tls.Config{Certificates: []tls.Certificate{cert}}) + s, err = hodu.NewServer(context.Background(), laddrs, &AppLogger{id: "server", out: os.Stderr}, &tls.Config{Certificates: []tls.Certificate{cert}}) if err != nil { return fmt.Errorf("ERROR: failed to create new server - %s", err.Error()) } diff --git a/packet.go b/packet.go index ff83544..f38868d 100644 --- a/packet.go +++ b/packet.go @@ -7,6 +7,12 @@ func MakeRouteStartPacket(route_id uint32, proto ROUTE_PROTO, addr string) *Pack U: &Packet_Route{Route: &RouteDesc{RouteId: route_id, Proto: proto, AddrStr: addr}}} } +func MakeRouteStopPacket(route_id uint32, proto ROUTE_PROTO, addr string) *Packet { + return &Packet{ + Kind: PACKET_KIND_ROUTE_STOP, + U: &Packet_Route{Route: &RouteDesc{RouteId: route_id, Proto: proto, AddrStr: addr}}} +} + func MakeRouteStartedPacket(route_id uint32, proto ROUTE_PROTO, addr string) *Packet { // the connection from a peer to the server has been established return &Packet{Kind: PACKET_KIND_ROUTE_STARTED, diff --git a/server.go b/server.go index 1e525d3..28e429d 100644 --- a/server.go +++ b/server.go @@ -25,10 +25,14 @@ type ServerPeerConnMap = map[uint32]*ServerPeerConn type ServerRouteMap = map[uint32]*ServerRoute type Server struct { + ctx context.Context + ctx_cancel context.CancelFunc tlscfg *tls.Config + wg sync.WaitGroup ext_svcs []Service stop_req atomic.Bool + stop_chan chan bool ctl *http.Server // control server @@ -179,7 +183,7 @@ func (r *ServerRoute) RunTask(wg *sync.WaitGroup) { conn, err = r.l.AcceptTCP() if err != nil { if errors.Is(err, net.ErrClosed) { - r.cts.svr.log.Write(log_id, LOG_INFO, "Rervice-side peer listener closed") + r.cts.svr.log.Write(log_id, LOG_INFO, "Server-side peer listener closed") } else { r.cts.svr.log.Write(log_id, LOG_INFO, "Server-side peer listener error - %s", err.Error()) } @@ -197,9 +201,12 @@ func (r *ServerRoute) RunTask(wg *sync.WaitGroup) { } } - r.l.Close() // don't care about double close. it could have been closed in ReqStop + r.ReqStop() + r.pts_wg.Wait() r.cts.svr.log.Write(log_id, LOG_DEBUG, "All service-side peer handlers completed") + + r.cts.RemoveServerRoute(r) // final phase... } func (r *ServerRoute) ReqStop() { @@ -295,7 +302,28 @@ func (cts *ClientConn) AddNewServerRoute(route_id uint32, proto ROUTE_PROTO) (*S return r, nil } -func (cts *ClientConn) RemoveServerRoute (route_id uint32) error { +func (cts *ClientConn) RemoveServerRoute (route* ServerRoute) error { + var r *ServerRoute + var ok bool + + cts.route_mtx.Lock() + r, ok = cts.route_map[route.id] + if (!ok) { + cts.route_mtx.Unlock() + return fmt.Errorf ("non-existent route id - %d", route.id) + } + if (r != route) { + cts.route_mtx.Unlock() + return fmt.Errorf ("non-existent route - %d", route.id) + } + delete(cts.route_map, route.id) + cts.route_mtx.Unlock() + + r.ReqStop() + return nil +} + +func (cts *ClientConn) RemoveServerRouteById (route_id uint32) (*ServerRoute, error) { var r *ServerRoute var ok bool @@ -303,13 +331,13 @@ func (cts *ClientConn) RemoveServerRoute (route_id uint32) error { r, ok = cts.route_map[route_id] if (!ok) { cts.route_mtx.Unlock() - return fmt.Errorf ("non-existent route id - %d", route_id) + return nil, fmt.Errorf ("non-existent route id - %d", route_id) } delete(cts.route_map, route_id) cts.route_mtx.Unlock() - r.ReqStop() // TODO: make this unblocking or blocking? - return nil + r.ReqStop() + return r, nil } func (cts *ClientConn) ReportEvent (route_id uint32, pts_id uint32, event_type PACKET_KIND, event_data []byte) error { @@ -336,37 +364,37 @@ func (cts *ClientConn) receive_from_stream(wg *sync.WaitGroup) { for { pkt, err = cts.pss.Recv() if errors.Is(err, io.EOF) { - // return will close stream from server side -// TODO: clean up route_map and server-side peers releated to the client connection 'cts' -fmt.Printf ("grpd stream ended\n") + cts.svr.log.Write("", LOG_INFO, "GRPC stream closed for client %s", cts.caddr) goto done } if err != nil { - //log.Printf("receive error %v", err) - fmt.Printf ("grpc stream error - %s\n", err.Error()) + cts.svr.log.Write("", LOG_ERROR, "GRPC stream error for client %s - %s", cts.caddr, err.Error()) goto done } switch pkt.Kind { case PACKET_KIND_ROUTE_START: var x *Packet_Route - //var t *ServerRoute var ok bool x, ok = pkt.U.(*Packet_Route) if ok { var r* ServerRoute - fmt.Printf ("ADDED SERVER ROUTE FOR CLEINT PEER %s\n", x.Route.AddrStr) + r, err = cts.AddNewServerRoute(x.Route.RouteId, x.Route.Proto) if err != nil { - // TODO: Send Error Response... + cts.svr.log.Write("", LOG_ERROR, "Failed to add server route for client %s peer %s", cts.caddr, x.Route.AddrStr) } else { + cts.svr.log.Write("", LOG_INFO, "Added server route(id=%d) for client %s peer %s", r.id, cts.caddr, x.Route.AddrStr) err = cts.pss.Send(MakeRouteStartedPacket(r.id, x.Route.Proto, r.laddr.String())) if err != nil { - // TODO: + r.ReqStop() + cts.svr.log.Write("", LOG_ERROR, "Failed to inform client %s of server route started for peer %s", cts.caddr, x.Route.AddrStr) + goto done } } } else { - // TODO: send invalid request... or simply keep quiet? + cts.svr.log.Write("", LOG_INFO, "Received invalid packet from %s", cts.caddr) + // TODO: need to abort this client? } case PACKET_KIND_ROUTE_STOP: @@ -374,17 +402,23 @@ fmt.Printf ("grpd stream ended\n") var ok bool x, ok = pkt.U.(*Packet_Route) if ok { - err = cts.RemoveServerRoute(x.Route.RouteId) // TODO: this must be unblocking. otherwide, other route_map will get blocked... + var r* ServerRoute + + r, err = cts.RemoveServerRouteById(x.Route.RouteId) if err != nil { - // TODO: Send Error Response... + cts.svr.log.Write("", LOG_ERROR, "Failed to delete server route(id=%d) for client %s peer %s", x.Route.RouteId, cts.caddr, x.Route.AddrStr) } else { + cts.svr.log.Write("", LOG_ERROR, "Deleted server route(id=%d) for client %s peer %s", x.Route.RouteId, cts.caddr, x.Route.AddrStr) err = cts.pss.Send(MakeRouteStoppedPacket(x.Route.RouteId, x.Route.Proto)) if err != nil { - // TODO: + r.ReqStop() + cts.svr.log.Write("", LOG_ERROR, "Failed to inform client %s of server route(id=%d) stopped for peer %s", cts.caddr, x.Route.RouteId, x.Route.AddrStr) + goto done } } } else { - // TODO: send invalid request... or simply keep quiet? + cts.svr.log.Write("", LOG_INFO, "Received invalid packet from %s", cts.caddr) + // TODO: need to abort this client? } case PACKET_KIND_PEER_STARTED: @@ -471,10 +505,13 @@ func (cts *ClientConn) RunTask(wg *sync.WaitGroup) { // or continue select { case <-ctx.Done(): // the stream context is done -fmt.Printf("grpd server done - %s\n", ctx.Err().Error()) +fmt.Printf("grpc server done - %s\n", ctx.Err().Error()) goto done case <- cts.stop_chan: + // get out of the loop to eventually to exit from + // this handler to let the main grpc server to + // close this specific client connection. goto done //default: @@ -485,6 +522,7 @@ fmt.Printf("grpd server done - %s\n", ctx.Err().Error()) done: fmt.Printf ("^^^^^^^^^^^^^^^^^ waiting for reoute_wg...\n") + cts.ReqStop() // just in case cts.route_wg.Wait() fmt.Printf ("^^^^^^^^^^^^^^^^^ waited for reoute_wg...\n") } @@ -497,8 +535,12 @@ func (cts *ClientConn) ReqStop() { r.ReqStop() } + // there is no good way to break a specific connection client to + // the grpc server. while the global grpc server is closed in + // ReqStop() for Server, the individuation connection is closed + // by returing from the grpc handler goroutine. See the comment + // RunTask() for ClientConn. cts.stop_chan <- true - //cts.c.Close() // close the accepted connection from the client } } @@ -647,7 +689,7 @@ func unaryInterceptor(ctx context.Context, req any, _ *grpc.UnaryServerInfo, han return m, err } -func NewServer(laddrs []string, logger Logger, tlscfg *tls.Config) (*Server, error) { +func NewServer(ctx context.Context, laddrs []string, logger Logger, tlscfg *tls.Config) (*Server, error) { var s Server var l *net.TCPListener var laddr *net.TCPAddr @@ -656,9 +698,10 @@ func NewServer(laddrs []string, logger Logger, tlscfg *tls.Config) (*Server, err var gl *net.TCPListener if len(laddrs) <= 0 { - return nil, fmt.Errorf("no or too many addresses provided") + return nil, fmt.Errorf("no server addresses provided") } + s.ctx, s.ctx_cancel = context.WithCancel(ctx) s.log = logger /* create the specified number of listeners */ s.l = make([]*net.TCPListener, 0) @@ -679,6 +722,7 @@ func NewServer(laddrs []string, logger Logger, tlscfg *tls.Config) (*Server, err s.tlscfg = tlscfg s.ext_svcs = make([]Service, 0, 1) s.cts_map = make(ClientConnMap) // TODO: make it configurable... + s.stop_chan = make(chan bool, 8) s.stop_req.Store(false) /* creds, err := credentials.NewServerTLSFromFile(data.Path("x509/server_cert.pem"), data.Path("x509/server_key.pem")) @@ -741,13 +785,25 @@ func (s *Server) RunTask(wg *sync.WaitGroup) { go s.run_grpc_server(idx, &s.l_wg) } - s.l_wg.Wait() - s.log.Write("", LOG_DEBUG, "All GRPC listeners completed") - s.cts_wg.Wait() - s.log.Write("", LOG_DEBUG, "All CTS handlers completed") + // most the work is done by in separate goroutines (s.run_grp_server) + // this loop serves as a placeholder to prevent the logic flow from + // descening down to s.ReqStop() +task_loop: + for { + select { + case <-s.stop_chan: + break task_loop + } + } s.ReqStop() + s.l_wg.Wait() + s.log.Write("", LOG_DEBUG, "All GRPC listeners completed") + + s.cts_wg.Wait() + s.log.Write("", LOG_DEBUG, "All CTS handlers completed") + // stop the main grpc server after all the other tasks are finished. s.gs.Stop() } @@ -770,6 +826,11 @@ func (s *Server) ReqStop() { var l *net.TCPListener var cts *ClientConn + if (s.ctl != nil) { + // shutdown the control server if ever started. + s.ctl.Shutdown(s.ctx) + } + //s.gs.GracefulStop() //s.gs.Stop() for _, l = range s.l { @@ -781,6 +842,9 @@ func (s *Server) ReqStop() { cts.ReqStop() // request to stop connections from/to peer held in the cts structure } s.cts_mtx.Unlock() + + s.stop_chan <- true + s.ctx_cancel() } } @@ -794,7 +858,7 @@ func (s *Server) AddNewClientConn(addr net.Addr, pss Hodu_PacketStreamServer) (* cts.pss = &GuardedPacketStreamServer{Hodu_PacketStreamServer: pss} cts.stop_req.Store(false) - cts.stop_chan = make(chan bool, 1) + cts.stop_chan = make(chan bool, 8) s.cts_mtx.Lock() defer s.cts_mtx.Unlock()