added signal handler to the client

This commit is contained in:
hyung-hwan 2024-11-13 02:20:25 +09:00
parent f02536bf24
commit 9c927464b0
3 changed files with 69 additions and 31 deletions

View File

@ -10,10 +10,11 @@ import "fmt"
import "io" import "io"
import "log" import "log"
import "net" import "net"
//import "os" import "os"
import "os/signal"
import "sync" import "sync"
import "sync/atomic" import "sync/atomic"
//import "syscall" import "syscall"
//import "time" //import "time"
//import "github.com/google/uuid" //import "github.com/google/uuid"
@ -337,6 +338,8 @@ func (c *Client) RunTask(ctx context.Context) {
var cts *ServerConn var cts *ServerConn
var err error var err error
defer c.wg.Done();
// TODO: HANDLE connection timeout.. // TODO: HANDLE connection timeout..
// ctx, _/*cancel*/ := context.WithTimeout(context.Background(), time.Second) // ctx, _/*cancel*/ := context.WithTimeout(context.Background(), time.Second)
conn, err = grpc.NewClient(c.saddr.String(), grpc.WithTransportCredentials(insecure.NewCredentials())) conn, err = grpc.NewClient(c.saddr.String(), grpc.WithTransportCredentials(insecure.NewCredentials()))
@ -476,8 +479,9 @@ func (c *Client) RunTask(ctx context.Context) {
//done: //done:
c.ReqStop() // just in case... c.ReqStop() // just in case...
c.wg.Wait()
c.sc.Close() c.sc.Close()
syscall.Kill(syscall.Getpid(), syscall.SIGTERM) // TODO: find a better to terminate the signal handler...
} }
func (c *Client) ReqStop() { func (c *Client) ReqStop() {
@ -487,6 +491,41 @@ func (c *Client) ReqStop() {
} }
} }
// --------------------------------------------------------------------
func (c *Client) handle_os_signals() {
var sighup_chan chan os.Signal
var sigterm_chan chan os.Signal
var sig os.Signal
defer c.wg.Done()
sighup_chan = make(chan os.Signal, 1)
sigterm_chan = make(chan os.Signal, 1)
signal.Notify(sighup_chan, syscall.SIGHUP)
signal.Notify(sigterm_chan, syscall.SIGTERM, os.Interrupt)
chan_loop:
for {
select {
case <-sighup_chan:
// TODO:
//s.RefreshConfig()
case sig = <-sigterm_chan:
// TODO: get timeout value from config
//c.Shutdown(fmt.Sprintf("termination by signal %s", sig), 3*time.Second)
c.ReqStop()
//log.Debugf("termination by signal %s", sig)
fmt.Printf("termination by signal %s\n", sig)
break chan_loop
}
}
fmt.Printf ("end of signal handler...\n");
}
// -------------------------------------------------------------------- // --------------------------------------------------------------------
const rootCert = `-----BEGIN CERTIFICATE----- const rootCert = `-----BEGIN CERTIFICATE-----
@ -510,7 +549,6 @@ func client_main(server_addr string, peer_addrs []string) error {
var cert_pool *x509.CertPool var cert_pool *x509.CertPool
var tlscfg *tls.Config var tlscfg *tls.Config
var cc ClientConfig var cc ClientConfig
var wg sync.WaitGroup
cert_pool = x509.NewCertPool() cert_pool = x509.NewCertPool()
ok := cert_pool.AppendCertsFromPEM([]byte(rootCert)) ok := cert_pool.AppendCertsFromPEM([]byte(rootCert))
@ -527,9 +565,13 @@ func client_main(server_addr string, peer_addrs []string) error {
return err return err
} }
wg.Add(1) fmt.Printf ("XXXXXXXXXXXXXXXXXXXXXXXXXXXX\n");
c.wg.Add(1)
go c.handle_os_signals()
c.wg.Add(1)
go c.RunTask(context.Background()); go c.RunTask(context.Background());
c.wg.Wait();
fmt.Printf ("YYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYY\n");
wg.Wait();
return nil return nil
} }

2
go.mod
View File

@ -1,6 +1,6 @@
module hodu module hodu
go 1.22.0 go 1.21.0
require ( require (
github.com/google/uuid v1.6.0 github.com/google/uuid v1.6.0

View File

@ -16,7 +16,7 @@ import "syscall"
import "time" import "time"
import "google.golang.org/grpc" import "google.golang.org/grpc"
import "google.golang.org/grpc/metadata" //import "google.golang.org/grpc/metadata"
import "google.golang.org/grpc/peer" import "google.golang.org/grpc/peer"
import "google.golang.org/grpc/stats" import "google.golang.org/grpc/stats"
@ -37,7 +37,6 @@ type Server struct {
wg sync.WaitGroup wg sync.WaitGroup
stop_req atomic.Bool stop_req atomic.Bool
// grpc stuffs
gs *grpc.Server gs *grpc.Server
UnimplementedHoduServer UnimplementedHoduServer
} }
@ -296,12 +295,12 @@ func (cts *ClientConn) ReqStop() {
// ------------------------------------ // ------------------------------------
func handle_os_signals(s *Server, exit_chan chan<- bool) { func (s *Server) handle_os_signals() {
var ( var sighup_chan chan os.Signal
sighup_chan chan os.Signal var sigterm_chan chan os.Signal
sigterm_chan chan os.Signal var sig os.Signal
sig os.Signal
) defer s.wg.Done()
sighup_chan = make(chan os.Signal, 1) sighup_chan = make(chan os.Signal, 1)
sigterm_chan = make(chan os.Signal, 1) sigterm_chan = make(chan os.Signal, 1)
@ -321,7 +320,6 @@ chan_loop:
s.ReqStop() s.ReqStop()
//log.Debugf("termination by signal %s", sig) //log.Debugf("termination by signal %s", sig)
fmt.Printf("termination by signal %s\n", sig) fmt.Printf("termination by signal %s\n", sig)
exit_chan <- true
break chan_loop break chan_loop
} }
} }
@ -491,11 +489,11 @@ func (cc *ConnCatcher) HandleConn(ctx context.Context, cs stats.ConnStats) {
} else { } else {
addr = p.Addr.String() addr = p.Addr.String()
} }
/*
md,ok:=metadata.FromIncomingContext(ctx) md,ok:=metadata.FromIncomingContext(ctx)
fmt.Printf("%+v%+v\n",md,ok) fmt.Printf("%+v%+v\n",md,ok)
if ok { if ok {
} }*/
switch cs.(type) { switch cs.(type) {
case *stats.ConnBegin: case *stats.ConnBegin:
fmt.Printf("**** client connected - [%s]\n", addr) fmt.Printf("**** client connected - [%s]\n", addr)
@ -640,9 +638,11 @@ func (s *Server) run_grpc_server(idx int) error {
return nil return nil
} }
func (s *Server) MainLoop() error { func (s *Server) MainLoop() {
var idx int var idx int
defer s.wg.Done()
for idx, _ = range s.l { for idx, _ = range s.l {
s.l_wg.Add(1) s.l_wg.Add(1)
go s.run_grpc_server(idx) go s.run_grpc_server(idx)
@ -650,9 +650,7 @@ func (s *Server) MainLoop() error {
s.l_wg.Wait(); s.l_wg.Wait();
s.ReqStop() s.ReqStop()
s.wg.Wait() syscall.Kill(syscall.Getpid(), syscall.SIGTERM)
return nil
} }
func (s *Server) ReqStop() { func (s *Server) ReqStop() {
@ -765,7 +763,7 @@ BAMCA0gAMEUCIEKzVMF3JqjQjuM2rX7Rx8hancI5KJhwfeKu1xbyR7XaAiEA2UT7
func server_main(laddrs []string) error { func server_main(laddrs []string) error {
var s *Server var s *Server
var err error var err error
var exit_chan chan bool
var cert tls.Certificate var cert tls.Certificate
cert, err = tls.X509KeyPair([]byte(serverCert), []byte(serverKey)) cert, err = tls.X509KeyPair([]byte(serverCert), []byte(serverKey))
@ -778,13 +776,11 @@ func server_main(laddrs []string) error {
return fmt.Errorf("ERROR: failed to create new server - %s", err.Error()) return fmt.Errorf("ERROR: failed to create new server - %s", err.Error())
} }
exit_chan = make(chan bool, 1) s.wg.Add(1)
go handle_os_signals(s, exit_chan) go s.handle_os_signals()
err = s.MainLoop() // this is blocking. ReqStop() will be called from a signal handler s.wg.Add(1)
if err != nil { go s.MainLoop() // this is blocking. ReqStop() will be called from a signal handler
return err s.wg.Wait()
}
<-exit_chan // wait until the term signal handler almost reaches the end
return nil return nil
} }