diff --git a/client.go b/client.go index 44fada9..cb4dbe6 100644 --- a/client.go +++ b/client.go @@ -10,10 +10,11 @@ import "fmt" import "io" import "log" import "net" -//import "os" +import "os" +import "os/signal" import "sync" import "sync/atomic" -//import "syscall" +import "syscall" //import "time" //import "github.com/google/uuid" @@ -331,12 +332,14 @@ func NewClient(cfg *ClientConfig, tlscfg *tls.Config) (*Client, error) { return &c, nil } - + func (c *Client) RunTask(ctx context.Context) { var conn *grpc.ClientConn var cts *ServerConn var err error + defer c.wg.Done(); + // TODO: HANDLE connection timeout.. // ctx, _/*cancel*/ := context.WithTimeout(context.Background(), time.Second) conn, err = grpc.NewClient(c.saddr.String(), grpc.WithTransportCredentials(insecure.NewCredentials())) @@ -476,8 +479,9 @@ func (c *Client) RunTask(ctx context.Context) { //done: c.ReqStop() // just in case... - c.wg.Wait() c.sc.Close() + + syscall.Kill(syscall.Getpid(), syscall.SIGTERM) // TODO: find a better to terminate the signal handler... } func (c *Client) ReqStop() { @@ -487,6 +491,41 @@ func (c *Client) ReqStop() { } } + +// -------------------------------------------------------------------- + +func (c *Client) handle_os_signals() { + var sighup_chan chan os.Signal + var sigterm_chan chan os.Signal + var sig os.Signal + + defer c.wg.Done() + + sighup_chan = make(chan os.Signal, 1) + sigterm_chan = make(chan os.Signal, 1) + + signal.Notify(sighup_chan, syscall.SIGHUP) + signal.Notify(sigterm_chan, syscall.SIGTERM, os.Interrupt) + +chan_loop: + for { + select { + case <-sighup_chan: + // TODO: + //s.RefreshConfig() + case sig = <-sigterm_chan: + // TODO: get timeout value from config + //c.Shutdown(fmt.Sprintf("termination by signal %s", sig), 3*time.Second) + c.ReqStop() + //log.Debugf("termination by signal %s", sig) + fmt.Printf("termination by signal %s\n", sig) + break chan_loop + } + } + +fmt.Printf ("end of signal handler...\n"); +} + // -------------------------------------------------------------------- const rootCert = `-----BEGIN CERTIFICATE----- @@ -510,7 +549,6 @@ func client_main(server_addr string, peer_addrs []string) error { var cert_pool *x509.CertPool var tlscfg *tls.Config var cc ClientConfig - var wg sync.WaitGroup cert_pool = x509.NewCertPool() ok := cert_pool.AppendCertsFromPEM([]byte(rootCert)) @@ -527,9 +565,13 @@ func client_main(server_addr string, peer_addrs []string) error { return err } - wg.Add(1) +fmt.Printf ("XXXXXXXXXXXXXXXXXXXXXXXXXXXX\n"); + c.wg.Add(1) + go c.handle_os_signals() + c.wg.Add(1) go c.RunTask(context.Background()); + c.wg.Wait(); +fmt.Printf ("YYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYY\n"); - wg.Wait(); return nil } diff --git a/go.mod b/go.mod index 09caa43..d31ecbf 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module hodu -go 1.22.0 +go 1.21.0 require ( github.com/google/uuid v1.6.0 diff --git a/server.go b/server.go index e315a2b..86e4b96 100644 --- a/server.go +++ b/server.go @@ -16,7 +16,7 @@ import "syscall" import "time" import "google.golang.org/grpc" -import "google.golang.org/grpc/metadata" +//import "google.golang.org/grpc/metadata" import "google.golang.org/grpc/peer" import "google.golang.org/grpc/stats" @@ -37,7 +37,6 @@ type Server struct { wg sync.WaitGroup stop_req atomic.Bool - // grpc stuffs gs *grpc.Server UnimplementedHoduServer } @@ -296,12 +295,12 @@ func (cts *ClientConn) ReqStop() { // ------------------------------------ -func handle_os_signals(s *Server, exit_chan chan<- bool) { - var ( - sighup_chan chan os.Signal - sigterm_chan chan os.Signal - sig os.Signal - ) +func (s *Server) handle_os_signals() { + var sighup_chan chan os.Signal + var sigterm_chan chan os.Signal + var sig os.Signal + + defer s.wg.Done() sighup_chan = make(chan os.Signal, 1) sigterm_chan = make(chan os.Signal, 1) @@ -321,7 +320,6 @@ chan_loop: s.ReqStop() //log.Debugf("termination by signal %s", sig) fmt.Printf("termination by signal %s\n", sig) - exit_chan <- true break chan_loop } } @@ -491,11 +489,11 @@ func (cc *ConnCatcher) HandleConn(ctx context.Context, cs stats.ConnStats) { } else { addr = p.Addr.String() } - +/* md,ok:=metadata.FromIncomingContext(ctx) fmt.Printf("%+v%+v\n",md,ok) if ok { -} +}*/ switch cs.(type) { case *stats.ConnBegin: fmt.Printf("**** client connected - [%s]\n", addr) @@ -640,9 +638,11 @@ func (s *Server) run_grpc_server(idx int) error { return nil } -func (s *Server) MainLoop() error { +func (s *Server) MainLoop() { var idx int + defer s.wg.Done() + for idx, _ = range s.l { s.l_wg.Add(1) go s.run_grpc_server(idx) @@ -650,9 +650,7 @@ func (s *Server) MainLoop() error { s.l_wg.Wait(); s.ReqStop() - s.wg.Wait() - - return nil + syscall.Kill(syscall.Getpid(), syscall.SIGTERM) } func (s *Server) ReqStop() { @@ -765,7 +763,7 @@ BAMCA0gAMEUCIEKzVMF3JqjQjuM2rX7Rx8hancI5KJhwfeKu1xbyR7XaAiEA2UT7 func server_main(laddrs []string) error { var s *Server var err error - var exit_chan chan bool + var cert tls.Certificate cert, err = tls.X509KeyPair([]byte(serverCert), []byte(serverKey)) @@ -778,13 +776,11 @@ func server_main(laddrs []string) error { return fmt.Errorf("ERROR: failed to create new server - %s", err.Error()) } - exit_chan = make(chan bool, 1) - go handle_os_signals(s, exit_chan) - err = s.MainLoop() // this is blocking. ReqStop() will be called from a signal handler - if err != nil { - return err - } + s.wg.Add(1) + go s.handle_os_signals() + s.wg.Add(1) + go s.MainLoop() // this is blocking. ReqStop() will be called from a signal handler + s.wg.Wait() - <-exit_chan // wait until the term signal handler almost reaches the end return nil }