diff --git a/client.go b/client.go index 5023c32..9613f85 100644 --- a/client.go +++ b/client.go @@ -431,42 +431,48 @@ func (r *ClientRoute) RunTask(wg *sync.WaitGroup) { r.Id, r.PeerAddr, r.ServerPeerOption, r.ReqServerPeerSvcAddr, r.ReqServerPeerSvcNet, r.cts.remote_addr_p) } - r.lifetime_mtx.Lock() - if r.Lifetime > 0 { - r.LifetimeStart = time.Now() - r.lifetime_timer = time.NewTimer(r.Lifetime) - } - r.lifetime_mtx.Unlock() + r.ptc_wg.Add(1) // increment counter here -main_loop: - for { - if r.lifetime_timer != nil { - select { - case <-r.stop_chan: - break main_loop + go func() { // and run the waiting loop in a goroutine to give a positive counter to r.ptc_wg.Wait() + r.lifetime_mtx.Lock() + if r.Lifetime > 0 { + r.LifetimeStart = time.Now() + r.lifetime_timer = time.NewTimer(r.Lifetime) + } + r.lifetime_mtx.Unlock() - case <-r.lifetime_timer.C: - r.cts.C.log.Write(r.cts.Sid, LOG_INFO, "route(%d,%s) reached end of lifetime(%v)", - r.Id, r.PeerAddr, r.Lifetime) - break main_loop - } - } else { - select { - case <-r.stop_chan: - break main_loop + waiting_loop: + for { + if r.lifetime_timer != nil { + select { + case <-r.stop_chan: + break waiting_loop + + case <-r.lifetime_timer.C: + r.cts.C.log.Write(r.cts.Sid, LOG_INFO, "route(%d,%s) reached end of lifetime(%v)", + r.Id, r.PeerAddr, r.Lifetime) + break waiting_loop + } + } else { + select { + case <-r.stop_chan: + break waiting_loop + } } } - } - r.lifetime_mtx.Lock() - if r.lifetime_timer != nil { - r.lifetime_timer.Stop() - r.lifetime_timer = nil - } - r.lifetime_mtx.Unlock() + r.lifetime_mtx.Lock() + if r.lifetime_timer != nil { + r.lifetime_timer.Stop() + r.lifetime_timer = nil + } + r.lifetime_mtx.Unlock() + + r.ReqStop() // just in case + r.ptc_wg.Done() + }() done: - r.ReqStop() r.ptc_wg.Wait() // wait for all peer tasks are finished err = r.cts.psc.Send(MakeRouteStopPacket(r.Id, r.ServerPeerOption, r.PeerAddr, r.PeerName, r.ReqServerPeerSvcAddr, r.ReqServerPeerSvcNet)) diff --git a/server.go b/server.go index 70520eb..74a708e 100644 --- a/server.go +++ b/server.go @@ -907,29 +907,33 @@ func (cts *ServerConn) RunTask(wg *sync.WaitGroup) { // Increment the wait count here before the loop begins cts.route_wg.Add(1) - for { - // exit if context is done - // or continue - select { - case <-ctx.Done(): // the stream context is done - cts.S.log.Write(cts.Sid, LOG_INFO, "RPC stream done - %s", ctx.Err().Error()) - goto done + // start the loop inside a goroutine so that route_wg counter + // is likely to be greater than 1 what Wait() is called. + go func() { + waiting_loop: + for { + // exit if context is done or continue + select { + case <-ctx.Done(): // the stream context is done + cts.S.log.Write(cts.Sid, LOG_INFO, "RPC stream done - %s", ctx.Err().Error()) + break waiting_loop - case <- cts.stop_chan: - // get out of the loop to eventually to exit from - // this handler to let the main grpc server to - // close this specific client connection. - goto done + case <- cts.stop_chan: + // get out of the loop to eventually to exit from + // this handler to let the main grpc server to + // close this specific client connection. + break waiting_loop - //default: - // no other case is ready. - // without the default case, the select construct would block + //default: + // no other case is ready. + // without the default case, the select construct would block + } } - } -done: - cts.ReqStop() // just in case - cts.route_wg.Done() + cts.ReqStop() // just in case + cts.route_wg.Done() + }() + cts.route_wg.Wait() cts.S.FireConnEvent(SERVER_EVENT_CONN_STOPPED, cts)