diff --git a/c-peer.go b/c-peer.go index dc0740a..920f6c8 100644 --- a/c-peer.go +++ b/c-peer.go @@ -2,6 +2,7 @@ package main import "fmt" import "net" +import "sync" func NewClientPeerConn(r *ClientRoute, c net.Conn, id uint32) (*ClientPeerConn) { var cpc ClientPeerConn @@ -17,13 +18,14 @@ func NewClientPeerConn(r *ClientRoute, c net.Conn, id uint32) (*ClientPeerConn) return &cpc } -func (cpc *ClientPeerConn) RunTask() error { +func (cpc *ClientPeerConn) RunTask(wg *sync.WaitGroup) error { //var conn *net.TCPConn //var addr *net.TCPAddr var err error var buf [4096]byte var n int + defer wg.Done() fmt.Printf("CONNECTION ESTABLISHED TO PEER... ABOUT TO READ DATA...\n") for { @@ -43,8 +45,6 @@ func (cpc *ClientPeerConn) RunTask() error { //done: cpc.ReqStop() - //cpc.c.RemoveClientPeerConn(cpc) - //cpc.c.wg.Done() return nil } diff --git a/client.go b/client.go index 337a5f9..4c57fa3 100644 --- a/client.go +++ b/client.go @@ -75,9 +75,8 @@ type ServerConn struct { route_mtx sync.Mutex route_map ClientRouteMap - //route_wg sync.WaitGroup + route_wg sync.WaitGroup - //wg sync.WaitGroup stop_req atomic.Bool stop_chan chan bool } @@ -120,9 +119,10 @@ func NewClientRoute(cts *ServerConn, id uint32, addr *net.TCPAddr, proto ROUTE_P return &r; } -func (r *ClientRoute) RunTask() { +func (r *ClientRoute) RunTask(wg *sync.WaitGroup) { // this task on the route object isn't actually necessary. // most useful works are triggered by ReportEvent() and done by ConnectToPeer() + defer wg.Done() main_loop: for { @@ -131,6 +131,8 @@ main_loop: break main_loop } } + + r.ptc_wg.Wait() // wait for all peer tasks are finished fmt.Printf ("*** End fo Client Roue Task\n") } @@ -177,17 +179,14 @@ func (r* ClientRoute) ConnectToPeer(pts_id uint32) { } fmt.Printf("STARTED NEW SERVER PEER STAK\n") - //r.ptc_wg.Add(1) - //go ptc.RunTask() - //r.ptc_wg.Wait() - ptc.RunTask() - conn.Close() // don't care about double close. it could have been closed in StopTask + r.ptc_wg.Add(1) + go ptc.RunTask(&r.ptc_wg) } func (r* ClientRoute) ReportEvent (pts_id uint32, event_type PACKET_KIND, event_data []byte) error { switch event_type { case PACKET_KIND_PEER_STARTED: - go r.ConnectToPeer(pts_id) + r.ConnectToPeer(pts_id) // TODO: other types } @@ -196,6 +195,21 @@ func (r* ClientRoute) ReportEvent (pts_id uint32, event_type PACKET_KIND, event_ } // -------------------------------------------------------------------- +func NewServerConn(c *Client, addr *net.TCPAddr, cfg *ClientConfig) *ServerConn { + var cts ServerConn + + cts.cli = c + cts.route_map = make(ClientRouteMap) + cts.saddr = addr + cts.cfg = cfg + cts.stop_req.Store(false) + cts.stop_chan = make(chan bool, 1) + + // the actual connection to the server is established in the main task function + // The cts.conn, cts.hdc, cts.psc fields are left unassigned here. + + return &cts +} func (cts *ServerConn) AddNewClientRoute(route_id uint32, addr *net.TCPAddr, proto ROUTE_PROTO) (*ClientRoute, error) { var r *ClientRoute @@ -210,7 +224,8 @@ func (cts *ServerConn) AddNewClientRoute(route_id uint32, addr *net.TCPAddr, pro cts.route_mtx.Unlock() fmt.Printf ("added client route.... %d -> %d\n", route_id, len(cts.route_map)) - go r.RunTask() + cts.route_wg.Add(1) + go r.RunTask(&cts.route_wg) return r, nil } @@ -231,6 +246,23 @@ func (cts *ServerConn) RemoveClientRoute (route_id uint32) error { return nil; } +func (cts *ServerConn) RemoveClientRoutes () { + var r *ClientRoute + var id uint32 + + cts.route_mtx.Lock() + for _, r = range cts.route_map { + r.ReqStop() + } + + for id, r = range cts.route_map { + delete(cts.route_map, id) + } + + cts.route_map = make(ClientRouteMap) + cts.route_mtx.Unlock() +} + func (cts *ServerConn) AddClientRoutes (peer_addrs []string) error { var i int var v string @@ -253,7 +285,7 @@ func (cts *ServerConn) AddClientRoutes (peer_addrs []string) error { _, err = cts.AddNewClientRoute(uint32(i), addr, proto) if err != nil { - return fmt.Errorf("unable to add client route for %s", addr) + return fmt.Errorf("unable to add client route for %s - %s", addr, err.Error()) } } @@ -270,9 +302,12 @@ func (cts *ServerConn) AddClientRoutes (peer_addrs []string) error { func (cts *ServerConn) ReqStop() { if cts.stop_req.CompareAndSwap(false, true) { var r *ClientRoute + + cts.route_mtx.Lock() for _, r = range cts.route_map { r.ReqStop() } + cts.route_mtx.Unlock() // TODO: notify the server.. send term command??? cts.stop_chan <- true @@ -284,25 +319,27 @@ func (cts *ServerConn) RunTask(wg *sync.WaitGroup) { var conn *grpc.ClientConn = nil var hdc HoduClient var psc PacketStreamClient + var slpctx context.Context var err error defer wg.Done() // arrange to call at the end of this function // TODO: HANDLE connection timeout.. // ctx, _/*cancel*/ := context.WithTimeout(context.Background(), time.Second) -fmt.Printf (">>>[%s]\n", cts.saddr.String()) +start_over: +fmt.Printf ("Connecting GRPC to [%s]\n", cts.saddr.String()) conn, err = grpc.NewClient(cts.saddr.String(), grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { // TODO: logging - fmt.Printf("ERROR: unable to connect to %s - %s", cts.cfg.server_addr, err.Error()) - goto done + fmt.Printf("ERROR: unable to make grpc client to %s - %s\n", cts.cfg.server_addr, err.Error()) + goto reconnect_to_server } hdc = NewHoduClient(conn) - psc, err = hdc.PacketStream(cts.cli.ctx) // TODO: accept external context and use it.L + psc, err = hdc.PacketStream(cts.cli.ctx) if err != nil { - fmt.Printf ("ERROR: unable to get the packet stream - %s", err.Error()) - goto done + fmt.Printf ("ERROR: unable to get grpc packet stream - %s\n", err.Error()) + goto reconnect_to_server } cts.conn = conn @@ -313,21 +350,21 @@ fmt.Printf (">>>[%s]\n", cts.saddr.String()) // let's add routes to the client-side peers. err = cts.AddClientRoutes(cts.cfg.peer_addrs) if err != nil { - fmt.Printf ("ERROR: unable to add routes to client-side peers - %s", err.Error()) + fmt.Printf ("ERROR: unable to add routes to client-side peers - %s\n", err.Error()) goto done } +fmt.Printf("[%v]\n", cts.route_map) -main_loop: for { var pkt *Packet select { case <-cts.cli.ctx.Done(): fmt.Printf("context doine... error - %s\n", cts.cli.ctx.Err().Error()) - break main_loop + goto done case <-cts.stop_chan: - break main_loop + goto done default: // no other case is ready. @@ -337,11 +374,11 @@ main_loop: pkt, err = psc.Recv() if err == io.EOF { fmt.Printf("server disconnected\n") - break + goto reconnect_to_server } if err != nil { fmt.Printf("server receive error - %s\n", err.Error()) - break + goto reconnect_to_server } switch pkt.Kind { @@ -434,6 +471,28 @@ fmt.Printf ("^^^^^^^^^^^^^^^^^^^^ Server Coon RunTask ending...\n") // TODO: need to reset c.sc, c.sg, c.psc to nil? // for this we need to ensure that everyone is ending } + cts.RemoveClientRoutes() + cts.route_wg.Wait() // wait until all route tasks are finished + return + +reconnect_to_server: + if conn != nil { + conn.Close() + // TODO: need to reset c.sc, c.sg, c.psc to nil? + // for this we need to ensure that everyone is ending + } + cts.RemoveClientRoutes() + slpctx, _ = context.WithTimeout(cts.cli.ctx, 3 * time.Second) + select { + case <-cts.cli.ctx.Done(): + fmt.Printf("context doine... error - %s\n", cts.cli.ctx.Err().Error()) + goto done + case <-cts.stop_chan: + goto done + case <- slpctx.Done(): + // do nothing + } + goto start_over } func (cts *ServerConn) ReportEvent (route_id uint32, pts_id uint32, event_type PACKET_KIND, event_data []byte) error { @@ -483,6 +542,7 @@ func (r *ClientRoute) AddNewClientPeerConn (c net.Conn) (*ClientPeerConn, error) return ptc, nil } + // -------------------------------------------------------------------- @@ -504,18 +564,10 @@ func NewClient(ctx context.Context, listen_on string, tlscfg *tls.Config) *Clien } func (c *Client) AddNewServerConn(addr *net.TCPAddr, cfg *ClientConfig) (*ServerConn, error) { - var cts ServerConn + var cts *ServerConn var ok bool - cts.cli = c - cts.route_map = make(ClientRouteMap) - cts.saddr = addr - cts.cfg = cfg - //cts.conn = conn - //cts.hdc = hdc - //cts.psc = psc - cts.stop_req.Store(false) - cts.stop_chan = make(chan bool, 1) + cts = NewServerConn(c, addr, cfg) c.cts_mtx.Lock() defer c.cts_mtx.Unlock() @@ -525,9 +577,9 @@ func (c *Client) AddNewServerConn(addr *net.TCPAddr, cfg *ClientConfig) (*Server return nil, fmt.Errorf("existing server - %s", addr.String()) } - c.cts_map[addr] = &cts; + c.cts_map[addr] = cts; fmt.Printf ("ADD total servers %d\n", len(c.cts_map)); - return &cts, nil + return cts, nil } func (c *Client) RemoveServerConn(cts *ServerConn) { diff --git a/main.go b/main.go index 43c993a..0dffc2a 100644 --- a/main.go +++ b/main.go @@ -2,26 +2,20 @@ package main import "flag" import "fmt" +import "io" import "os" import "strings" -type VoidWriter struct { -} - -func (w *VoidWriter) Write(p []byte) (int, error) { - return len(p), nil -} - func main() { var err error var flgs *flag.FlagSet - if len(os.Args) < 2 { goto wrong_usage } if strings.EqualFold(os.Args[1], "server") { var la []string + la = make([]string, 0) flgs = flag.NewFlagSet("", flag.ContinueOnError) @@ -29,10 +23,10 @@ func main() { la = append(la, v) return nil }) - flgs.SetOutput(&VoidWriter{}) // prevent usage output + flgs.SetOutput(io.Discard) // prevent usage output err = flgs.Parse(os.Args[2:]) if err != nil { - fmt.Printf ("ERROR: unable to parse command arguments - %s\n", err.Error()) + fmt.Printf ("ERROR: %s\n", err.Error()) goto wrong_usage } @@ -61,10 +55,10 @@ func main() { sa = append(sa, v) return nil }) - flgs.SetOutput(&VoidWriter{}) // prevent usage output + flgs.SetOutput(io.Discard) err = flgs.Parse(os.Args[2:]) if err != nil { - fmt.Printf ("ERROR: unable to parse command arguments - %s\n", err.Error()) + fmt.Printf ("ERROR: %s\n", err.Error()) goto wrong_usage }