From dcf3d852d2ba6f810fa926bfaacd7926fe49ebd0 Mon Sep 17 00:00:00 2001 From: hyung-hwan Date: Sat, 7 Dec 2024 16:57:00 +0900 Subject: [PATCH] extended the tls configuration to rpc server --- client.go | 23 +++++++++---- cmd/config.go | 17 +++++++-- cmd/main.go | 95 +++++++++++++++++++++++++++++++++++++++++++-------- server.go | 35 +++++++++++++------ 4 files changed, 136 insertions(+), 34 deletions(-) diff --git a/client.go b/client.go index eda780d..26ae3d5 100644 --- a/client.go +++ b/client.go @@ -13,6 +13,7 @@ import "time" import "google.golang.org/grpc" import "google.golang.org/grpc/codes" +import "google.golang.org/grpc/credentials" import "google.golang.org/grpc/credentials/insecure" import "google.golang.org/grpc/peer" import "google.golang.org/grpc/status" @@ -39,7 +40,8 @@ type ClientConfigActive struct { type Client struct { ctx context.Context ctx_cancel context.CancelFunc - tlscfg *tls.Config + ctltlscfg *tls.Config + rpctlscfg *tls.Config ext_mtx sync.Mutex @@ -709,7 +711,15 @@ func (cts *ClientConn) RunTask(wg *sync.WaitGroup) { start_over: cts.cli.log.Write(cts.sid, LOG_INFO, "Connecting to server %s", cts.cfg.ServerAddr) - cts.conn, err = grpc.NewClient(cts.cfg.ServerAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) + if cts.cli.rpctlscfg == nil { + cts.conn, err = grpc.NewClient( + cts.cfg.ServerAddr, + grpc.WithTransportCredentials(insecure.NewCredentials())) + } else { + cts.conn, err = grpc.NewClient( + cts.cfg.ServerAddr, + grpc.WithTransportCredentials(credentials.NewTLS(cts.cli.rpctlscfg))) + } if err != nil { cts.cli.log.Write(cts.sid, LOG_ERROR, "Failed to make client to server %s - %s", cts.cfg.ServerAddr, err.Error()) goto reconnect_to_server @@ -958,12 +968,13 @@ func (cts *ClientConn) ReportEvent (route_id uint32, pts_id uint32, event_type P // -------------------------------------------------------------------- -func NewClient(ctx context.Context, ctl_addrs []string, logger Logger, tlscfg *tls.Config) *Client { +func NewClient(ctx context.Context, ctl_addrs []string, logger Logger, ctltlscfg *tls.Config, rpctlscfg *tls.Config) *Client { var c Client var i int c.ctx, c.ctx_cancel = context.WithCancel(ctx) - c.tlscfg = tlscfg + c.ctltlscfg = ctltlscfg + c.rpctlscfg = rpctlscfg c.ext_svcs = make([]Service, 0, 1) c.cts_map_by_addr = make(ClientConnMapByAddr) c.cts_map = make(ClientConnMap) @@ -987,7 +998,7 @@ func NewClient(ctx context.Context, ctl_addrs []string, logger Logger, tlscfg *t c.ctl[i] = &http.Server{ Addr: ctl_addrs[i], Handler: c.ctl_mux, - TLSConfig: tlscfg, + TLSConfig: c.ctltlscfg, // TODO: more settings } } @@ -1184,7 +1195,7 @@ func (c *Client) RunCtlTask(wg *sync.WaitGroup) { l_wg.Add(1) go func(i int, cs *http.Server) { c.log.Write ("", LOG_INFO, "Control channel[%d] started on %s", i, c.ctl_addr[i]) - if c.tlscfg == nil { + if c.ctltlscfg == nil { err = cs.ListenAndServe() } else { err = cs.ListenAndServeTLS("", "") // c.tlscfg must provide a certificate and a key diff --git a/cmd/config.go b/cmd/config.go index 057b761..bd75165 100644 --- a/cmd/config.go +++ b/cmd/config.go @@ -33,15 +33,26 @@ type ClientTLSConfig struct { ServerCACertFile string `yaml:"server-ca-cert-file"` ServerCACertText string `yaml:"server-ca-cert-text"` InsecureSkipVerify bool `yaml:"skip-verify"` + ServerName string `yaml:"server-name"` } type ServerConfig struct { - TLS ServerTLSConfig `yaml:"tls"` - + CTL struct { + TLS ServerTLSConfig `yaml:"tls"` + } `yaml:"ctl"` + + RPC struct { + TLS ServerTLSConfig `yaml:"tls"` + } `yaml:"rpc"` } type ClientConfig struct { - TLS ServerTLSConfig `yaml:"tls"` + CTL struct { + TLS ServerTLSConfig `yaml:"tls"` + } `yaml:"ctl"` + RPC struct { + TLS ClientTLSConfig `yaml:"tls"` + } `yaml:"rpc"` } diff --git a/cmd/main.go b/cmd/main.go index 087a725..bff1aec 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -189,7 +189,7 @@ func make_tls_server_config(cfg *ServerTLSConfig) (*tls.Config, error) { var text []byte text, err = ioutil.ReadFile(cfg.ClientCACertFile) if err != nil { - return nil, fmt.Errorf("failed to load client ca certficate file %s - %s", cfg.ClientCACertFile, err.Error()) + return nil, fmt.Errorf("failed to load ca certficate file %s - %s", cfg.ClientCACertFile, err.Error()) } ok = cert_pool.AppendCertsFromPEM(text) if !ok { @@ -198,15 +198,10 @@ func make_tls_server_config(cfg *ServerTLSConfig) (*tls.Config, error) { } } - /* - // Don't use `Certificates` it doesn't work with some certificate files. - // See, `getClientCertificate` in ${GOSRC}/src/crypto/tls/handshake_client.go for details - tlsConfig.GetClientCertificate = func(cri *tls.CertificateRequestInfo) (*tls.Certificate, error) { - return cert, nil - } - */ tlscfg = &tls.Config{ Certificates: []tls.Certificate{cert}, + // If multiple certificates are configured, we may have to implement GetCertificate + // GetCertificate: func (chi *tls.ClientHelloInfo) (*Certificate, error) { return cert, nil } ClientAuth: tls_string_to_client_auth_type(cfg.ClientAuthType), ClientCAs: cert_pool, // trusted CA certs for client certificate verification } @@ -215,13 +210,78 @@ func make_tls_server_config(cfg *ServerTLSConfig) (*tls.Config, error) { return tlscfg, nil } +// -------------------------------------------------------------------- + +func make_tls_client_config(cfg *ClientTLSConfig) (*tls.Config, error) { + var tlscfg *tls.Config + + if cfg.Enabled { + var cert tls.Certificate + var cert_pool *x509.CertPool + var err error + + + if cfg.CertText != "" && cfg.KeyText != "" { + cert, err = tls.X509KeyPair([]byte(cfg.CertText), []byte(cfg.KeyText)) + } else if cfg.CertFile != "" && cfg.KeyFile != "" { + cert, err = tls.LoadX509KeyPair(cfg.CertFile, cfg.KeyFile) + } else { + // use the embedded certificate + cert, err = tls.X509KeyPair(hodu_tls_cert_text, hodul_tls_key_text) + } + if err != nil { + return nil, fmt.Errorf("failed to load key pair - %s", err) + } + + if cfg.ServerCACertText != "" || cfg.ServerCACertFile != ""{ + var ok bool + + cert_pool = x509.NewCertPool() + + if cfg.ServerCACertText != "" { + ok = cert_pool.AppendCertsFromPEM([]byte(cfg.ServerCACertText)) + if !ok { + return nil, fmt.Errorf("failed to append certificate to pool") + } + } else if cfg.ServerCACertFile != "" { + var text []byte + text, err = ioutil.ReadFile(cfg.ServerCACertFile) + if err != nil { + return nil, fmt.Errorf("failed to load ca certficate file %s - %s", cfg.ServerCACertFile, err.Error()) + } + ok = cert_pool.AppendCertsFromPEM(text) + if !ok { + return nil, fmt.Errorf("failed to append certificate to pool") + } + } + } + + tlscfg = &tls.Config{ + //Certificates: []tls.Certificate{cert}, + GetClientCertificate: func(cri *tls.CertificateRequestInfo) (*tls.Certificate, error) { return &cert, nil }, + RootCAs: cert_pool, + InsecureSkipVerify: cfg.InsecureSkipVerify, + ServerName: cfg.ServerName, + } + } + + return tlscfg, nil +} + +// -------------------------------------------------------------------- + func server_main(ctl_addrs []string, svcaddrs []string, cfg *ServerConfig) error { var s *hodu.Server - var tlscfg *tls.Config + var ctltlscfg *tls.Config + var rpctlscfg *tls.Config var err error if cfg != nil { - tlscfg, err = make_tls_server_config(&cfg.TLS) + ctltlscfg, err = make_tls_server_config(&cfg.CTL.TLS) + if err != nil { + return err + } + rpctlscfg, err = make_tls_server_config(&cfg.RPC.TLS) if err != nil { return err } @@ -232,7 +292,8 @@ func server_main(ctl_addrs []string, svcaddrs []string, cfg *ServerConfig) error ctl_addrs, svcaddrs, &AppLogger{id: "server", out: os.Stderr}, - tlscfg) + ctltlscfg, + rpctlscfg) if err != nil { return fmt.Errorf("failed to create new server - %s", err.Error()) } @@ -249,12 +310,17 @@ func server_main(ctl_addrs []string, svcaddrs []string, cfg *ServerConfig) error func client_main(ctl_addrs []string, server_addr string, peer_addrs []string, cfg *ClientConfig) error { var c *hodu.Client - var tlscfg *tls.Config + var ctltlscfg *tls.Config + var rpctlscfg *tls.Config var cc hodu.ClientConfig var err error if cfg != nil { - tlscfg, err = make_tls_server_config(&cfg.TLS) + ctltlscfg, err = make_tls_server_config(&cfg.CTL.TLS) + if err != nil { + return err + } + rpctlscfg, err = make_tls_client_config(&cfg.RPC.TLS) if err != nil { return err } @@ -264,7 +330,8 @@ func client_main(ctl_addrs []string, server_addr string, peer_addrs []string, cf context.Background(), ctl_addrs, &AppLogger{id: "client", out: os.Stderr}, - tlscfg) + ctltlscfg, + rpctlscfg) cc.ServerAddr = server_addr cc.PeerAddrs = peer_addrs diff --git a/server.go b/server.go index 52e3fec..1da9a18 100644 --- a/server.go +++ b/server.go @@ -14,6 +14,7 @@ import "sync" import "sync/atomic" import "google.golang.org/grpc" +import "google.golang.org/grpc/credentials" //import "google.golang.org/grpc/metadata" import "google.golang.org/grpc/peer" import "google.golang.org/grpc/stats" @@ -28,7 +29,8 @@ type ServerRouteMap = map[uint32]*ServerRoute type Server struct { ctx context.Context ctx_cancel context.CancelFunc - tlscfg *tls.Config + ctltlscfg *tls.Config + rpctlscfg *tls.Config wg sync.WaitGroup stop_req atomic.Bool @@ -792,7 +794,7 @@ func unaryInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServ return v, err } -func NewServer(ctx context.Context, ctl_addrs []string, rpc_addrs []string, logger Logger, tlscfg *tls.Config) (*Server, error) { +func NewServer(ctx context.Context, ctl_addrs []string, rpc_addrs []string, logger Logger, ctltlscfg *tls.Config, rpctlscfg *tls.Config) (*Server, error) { var s Server var l *net.TCPListener var rpcaddr *net.TCPAddr @@ -824,12 +826,14 @@ func NewServer(ctx context.Context, ctl_addrs []string, rpc_addrs []string, logg s.rpc = append(s.rpc, l) } - s.tlscfg = tlscfg + s.ctltlscfg = ctltlscfg + s.rpctlscfg = rpctlscfg s.ext_svcs = make([]Service, 0, 1) s.cts_map = make(ServerConnMap) s.cts_map_by_addr = make(ServerConnMapByAddr) s.stop_chan = make(chan bool, 8) s.stop_req.Store(false) + /* creds, err := credentials.NewServerTLSFromFile(data.Path("x509/server_cert.pem"), data.Path("x509/server_key.pem")) if err != nil { @@ -837,11 +841,20 @@ func NewServer(ctx context.Context, ctl_addrs []string, rpc_addrs []string, logg } gs = grpc.NewServer(grpc.Creds(creds)) */ - s.rpc_svr = grpc.NewServer( - //grpc.UnaryInterceptor(unaryInterceptor), - //grpc.StreamInterceptor(streamInterceptor), - grpc.StatsHandler(&ConnCatcher{server: &s}), - ) + if s.rpctlscfg == nil { + s.rpc_svr = grpc.NewServer( + //grpc.UnaryInterceptor(unaryInterceptor), + //grpc.StreamInterceptor(streamInterceptor), + grpc.StatsHandler(&ConnCatcher{server: &s}), + ) + } else { + s.rpc_svr = grpc.NewServer( + grpc.Creds(credentials.NewTLS(s.rpctlscfg)), + //grpc.UnaryInterceptor(unaryInterceptor), + //grpc.StreamInterceptor(streamInterceptor), + grpc.StatsHandler(&ConnCatcher{server: &s}), + ) + } RegisterHoduServer(s.rpc_svr, &s) s.ctl_prefix = "" // TODO: @@ -862,7 +875,7 @@ func NewServer(ctx context.Context, ctl_addrs []string, rpc_addrs []string, logg s.ctl[i] = &http.Server{ Addr: ctl_addrs[i], Handler: s.ctl_mux, - TLSConfig: s.tlscfg, + TLSConfig: s.ctltlscfg, // TODO: more settings } } @@ -949,10 +962,10 @@ func (s *Server) RunCtlTask(wg *sync.WaitGroup) { l_wg.Add(1) go func(i int, cs *http.Server) { s.log.Write ("", LOG_INFO, "Control channel[%d] started on %s", i, s.ctl_addr[i]) - if s.tlscfg == nil { + if s.ctltlscfg == nil { err = cs.ListenAndServe() } else { - err = cs.ListenAndServeTLS("", "") // c.tlscfg must provide a certificate and a key + err = cs.ListenAndServeTLS("", "") // c.ctltlscfg must provide a certificate and a key } if errors.Is(err, http.ErrServerClosed) { s.log.Write("", LOG_DEBUG, "Control channel[%d] ended", i)