extended the tls configuration to rpc server

This commit is contained in:
hyung-hwan 2024-12-07 16:57:00 +09:00
parent 6ad7ffd1a6
commit dcf3d852d2
4 changed files with 136 additions and 34 deletions

View File

@ -13,6 +13,7 @@ import "time"
import "google.golang.org/grpc" import "google.golang.org/grpc"
import "google.golang.org/grpc/codes" import "google.golang.org/grpc/codes"
import "google.golang.org/grpc/credentials"
import "google.golang.org/grpc/credentials/insecure" import "google.golang.org/grpc/credentials/insecure"
import "google.golang.org/grpc/peer" import "google.golang.org/grpc/peer"
import "google.golang.org/grpc/status" import "google.golang.org/grpc/status"
@ -39,7 +40,8 @@ type ClientConfigActive struct {
type Client struct { type Client struct {
ctx context.Context ctx context.Context
ctx_cancel context.CancelFunc ctx_cancel context.CancelFunc
tlscfg *tls.Config ctltlscfg *tls.Config
rpctlscfg *tls.Config
ext_mtx sync.Mutex ext_mtx sync.Mutex
@ -709,7 +711,15 @@ func (cts *ClientConn) RunTask(wg *sync.WaitGroup) {
start_over: start_over:
cts.cli.log.Write(cts.sid, LOG_INFO, "Connecting to server %s", cts.cfg.ServerAddr) cts.cli.log.Write(cts.sid, LOG_INFO, "Connecting to server %s", cts.cfg.ServerAddr)
cts.conn, err = grpc.NewClient(cts.cfg.ServerAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) if cts.cli.rpctlscfg == nil {
cts.conn, err = grpc.NewClient(
cts.cfg.ServerAddr,
grpc.WithTransportCredentials(insecure.NewCredentials()))
} else {
cts.conn, err = grpc.NewClient(
cts.cfg.ServerAddr,
grpc.WithTransportCredentials(credentials.NewTLS(cts.cli.rpctlscfg)))
}
if err != nil { if err != nil {
cts.cli.log.Write(cts.sid, LOG_ERROR, "Failed to make client to server %s - %s", cts.cfg.ServerAddr, err.Error()) cts.cli.log.Write(cts.sid, LOG_ERROR, "Failed to make client to server %s - %s", cts.cfg.ServerAddr, err.Error())
goto reconnect_to_server goto reconnect_to_server
@ -958,12 +968,13 @@ func (cts *ClientConn) ReportEvent (route_id uint32, pts_id uint32, event_type P
// -------------------------------------------------------------------- // --------------------------------------------------------------------
func NewClient(ctx context.Context, ctl_addrs []string, logger Logger, tlscfg *tls.Config) *Client { func NewClient(ctx context.Context, ctl_addrs []string, logger Logger, ctltlscfg *tls.Config, rpctlscfg *tls.Config) *Client {
var c Client var c Client
var i int var i int
c.ctx, c.ctx_cancel = context.WithCancel(ctx) c.ctx, c.ctx_cancel = context.WithCancel(ctx)
c.tlscfg = tlscfg c.ctltlscfg = ctltlscfg
c.rpctlscfg = rpctlscfg
c.ext_svcs = make([]Service, 0, 1) c.ext_svcs = make([]Service, 0, 1)
c.cts_map_by_addr = make(ClientConnMapByAddr) c.cts_map_by_addr = make(ClientConnMapByAddr)
c.cts_map = make(ClientConnMap) c.cts_map = make(ClientConnMap)
@ -987,7 +998,7 @@ func NewClient(ctx context.Context, ctl_addrs []string, logger Logger, tlscfg *t
c.ctl[i] = &http.Server{ c.ctl[i] = &http.Server{
Addr: ctl_addrs[i], Addr: ctl_addrs[i],
Handler: c.ctl_mux, Handler: c.ctl_mux,
TLSConfig: tlscfg, TLSConfig: c.ctltlscfg,
// TODO: more settings // TODO: more settings
} }
} }
@ -1184,7 +1195,7 @@ func (c *Client) RunCtlTask(wg *sync.WaitGroup) {
l_wg.Add(1) l_wg.Add(1)
go func(i int, cs *http.Server) { go func(i int, cs *http.Server) {
c.log.Write ("", LOG_INFO, "Control channel[%d] started on %s", i, c.ctl_addr[i]) c.log.Write ("", LOG_INFO, "Control channel[%d] started on %s", i, c.ctl_addr[i])
if c.tlscfg == nil { if c.ctltlscfg == nil {
err = cs.ListenAndServe() err = cs.ListenAndServe()
} else { } else {
err = cs.ListenAndServeTLS("", "") // c.tlscfg must provide a certificate and a key err = cs.ListenAndServeTLS("", "") // c.tlscfg must provide a certificate and a key

View File

@ -33,15 +33,26 @@ type ClientTLSConfig struct {
ServerCACertFile string `yaml:"server-ca-cert-file"` ServerCACertFile string `yaml:"server-ca-cert-file"`
ServerCACertText string `yaml:"server-ca-cert-text"` ServerCACertText string `yaml:"server-ca-cert-text"`
InsecureSkipVerify bool `yaml:"skip-verify"` InsecureSkipVerify bool `yaml:"skip-verify"`
ServerName string `yaml:"server-name"`
} }
type ServerConfig struct { type ServerConfig struct {
TLS ServerTLSConfig `yaml:"tls"` CTL struct {
TLS ServerTLSConfig `yaml:"tls"`
} `yaml:"ctl"`
RPC struct {
TLS ServerTLSConfig `yaml:"tls"`
} `yaml:"rpc"`
} }
type ClientConfig struct { type ClientConfig struct {
TLS ServerTLSConfig `yaml:"tls"` CTL struct {
TLS ServerTLSConfig `yaml:"tls"`
} `yaml:"ctl"`
RPC struct {
TLS ClientTLSConfig `yaml:"tls"`
} `yaml:"rpc"`
} }

View File

@ -189,7 +189,7 @@ func make_tls_server_config(cfg *ServerTLSConfig) (*tls.Config, error) {
var text []byte var text []byte
text, err = ioutil.ReadFile(cfg.ClientCACertFile) text, err = ioutil.ReadFile(cfg.ClientCACertFile)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to load client ca certficate file %s - %s", cfg.ClientCACertFile, err.Error()) return nil, fmt.Errorf("failed to load ca certficate file %s - %s", cfg.ClientCACertFile, err.Error())
} }
ok = cert_pool.AppendCertsFromPEM(text) ok = cert_pool.AppendCertsFromPEM(text)
if !ok { if !ok {
@ -198,15 +198,10 @@ func make_tls_server_config(cfg *ServerTLSConfig) (*tls.Config, error) {
} }
} }
/*
// Don't use `Certificates` it doesn't work with some certificate files.
// See, `getClientCertificate` in ${GOSRC}/src/crypto/tls/handshake_client.go for details
tlsConfig.GetClientCertificate = func(cri *tls.CertificateRequestInfo) (*tls.Certificate, error) {
return cert, nil
}
*/
tlscfg = &tls.Config{ tlscfg = &tls.Config{
Certificates: []tls.Certificate{cert}, Certificates: []tls.Certificate{cert},
// If multiple certificates are configured, we may have to implement GetCertificate
// GetCertificate: func (chi *tls.ClientHelloInfo) (*Certificate, error) { return cert, nil }
ClientAuth: tls_string_to_client_auth_type(cfg.ClientAuthType), ClientAuth: tls_string_to_client_auth_type(cfg.ClientAuthType),
ClientCAs: cert_pool, // trusted CA certs for client certificate verification ClientCAs: cert_pool, // trusted CA certs for client certificate verification
} }
@ -215,13 +210,78 @@ func make_tls_server_config(cfg *ServerTLSConfig) (*tls.Config, error) {
return tlscfg, nil return tlscfg, nil
} }
// --------------------------------------------------------------------
func make_tls_client_config(cfg *ClientTLSConfig) (*tls.Config, error) {
var tlscfg *tls.Config
if cfg.Enabled {
var cert tls.Certificate
var cert_pool *x509.CertPool
var err error
if cfg.CertText != "" && cfg.KeyText != "" {
cert, err = tls.X509KeyPair([]byte(cfg.CertText), []byte(cfg.KeyText))
} else if cfg.CertFile != "" && cfg.KeyFile != "" {
cert, err = tls.LoadX509KeyPair(cfg.CertFile, cfg.KeyFile)
} else {
// use the embedded certificate
cert, err = tls.X509KeyPair(hodu_tls_cert_text, hodul_tls_key_text)
}
if err != nil {
return nil, fmt.Errorf("failed to load key pair - %s", err)
}
if cfg.ServerCACertText != "" || cfg.ServerCACertFile != ""{
var ok bool
cert_pool = x509.NewCertPool()
if cfg.ServerCACertText != "" {
ok = cert_pool.AppendCertsFromPEM([]byte(cfg.ServerCACertText))
if !ok {
return nil, fmt.Errorf("failed to append certificate to pool")
}
} else if cfg.ServerCACertFile != "" {
var text []byte
text, err = ioutil.ReadFile(cfg.ServerCACertFile)
if err != nil {
return nil, fmt.Errorf("failed to load ca certficate file %s - %s", cfg.ServerCACertFile, err.Error())
}
ok = cert_pool.AppendCertsFromPEM(text)
if !ok {
return nil, fmt.Errorf("failed to append certificate to pool")
}
}
}
tlscfg = &tls.Config{
//Certificates: []tls.Certificate{cert},
GetClientCertificate: func(cri *tls.CertificateRequestInfo) (*tls.Certificate, error) { return &cert, nil },
RootCAs: cert_pool,
InsecureSkipVerify: cfg.InsecureSkipVerify,
ServerName: cfg.ServerName,
}
}
return tlscfg, nil
}
// --------------------------------------------------------------------
func server_main(ctl_addrs []string, svcaddrs []string, cfg *ServerConfig) error { func server_main(ctl_addrs []string, svcaddrs []string, cfg *ServerConfig) error {
var s *hodu.Server var s *hodu.Server
var tlscfg *tls.Config var ctltlscfg *tls.Config
var rpctlscfg *tls.Config
var err error var err error
if cfg != nil { if cfg != nil {
tlscfg, err = make_tls_server_config(&cfg.TLS) ctltlscfg, err = make_tls_server_config(&cfg.CTL.TLS)
if err != nil {
return err
}
rpctlscfg, err = make_tls_server_config(&cfg.RPC.TLS)
if err != nil { if err != nil {
return err return err
} }
@ -232,7 +292,8 @@ func server_main(ctl_addrs []string, svcaddrs []string, cfg *ServerConfig) error
ctl_addrs, ctl_addrs,
svcaddrs, svcaddrs,
&AppLogger{id: "server", out: os.Stderr}, &AppLogger{id: "server", out: os.Stderr},
tlscfg) ctltlscfg,
rpctlscfg)
if err != nil { if err != nil {
return fmt.Errorf("failed to create new server - %s", err.Error()) return fmt.Errorf("failed to create new server - %s", err.Error())
} }
@ -249,12 +310,17 @@ func server_main(ctl_addrs []string, svcaddrs []string, cfg *ServerConfig) error
func client_main(ctl_addrs []string, server_addr string, peer_addrs []string, cfg *ClientConfig) error { func client_main(ctl_addrs []string, server_addr string, peer_addrs []string, cfg *ClientConfig) error {
var c *hodu.Client var c *hodu.Client
var tlscfg *tls.Config var ctltlscfg *tls.Config
var rpctlscfg *tls.Config
var cc hodu.ClientConfig var cc hodu.ClientConfig
var err error var err error
if cfg != nil { if cfg != nil {
tlscfg, err = make_tls_server_config(&cfg.TLS) ctltlscfg, err = make_tls_server_config(&cfg.CTL.TLS)
if err != nil {
return err
}
rpctlscfg, err = make_tls_client_config(&cfg.RPC.TLS)
if err != nil { if err != nil {
return err return err
} }
@ -264,7 +330,8 @@ func client_main(ctl_addrs []string, server_addr string, peer_addrs []string, cf
context.Background(), context.Background(),
ctl_addrs, ctl_addrs,
&AppLogger{id: "client", out: os.Stderr}, &AppLogger{id: "client", out: os.Stderr},
tlscfg) ctltlscfg,
rpctlscfg)
cc.ServerAddr = server_addr cc.ServerAddr = server_addr
cc.PeerAddrs = peer_addrs cc.PeerAddrs = peer_addrs

View File

@ -14,6 +14,7 @@ import "sync"
import "sync/atomic" import "sync/atomic"
import "google.golang.org/grpc" import "google.golang.org/grpc"
import "google.golang.org/grpc/credentials"
//import "google.golang.org/grpc/metadata" //import "google.golang.org/grpc/metadata"
import "google.golang.org/grpc/peer" import "google.golang.org/grpc/peer"
import "google.golang.org/grpc/stats" import "google.golang.org/grpc/stats"
@ -28,7 +29,8 @@ type ServerRouteMap = map[uint32]*ServerRoute
type Server struct { type Server struct {
ctx context.Context ctx context.Context
ctx_cancel context.CancelFunc ctx_cancel context.CancelFunc
tlscfg *tls.Config ctltlscfg *tls.Config
rpctlscfg *tls.Config
wg sync.WaitGroup wg sync.WaitGroup
stop_req atomic.Bool stop_req atomic.Bool
@ -792,7 +794,7 @@ func unaryInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServ
return v, err return v, err
} }
func NewServer(ctx context.Context, ctl_addrs []string, rpc_addrs []string, logger Logger, tlscfg *tls.Config) (*Server, error) { func NewServer(ctx context.Context, ctl_addrs []string, rpc_addrs []string, logger Logger, ctltlscfg *tls.Config, rpctlscfg *tls.Config) (*Server, error) {
var s Server var s Server
var l *net.TCPListener var l *net.TCPListener
var rpcaddr *net.TCPAddr var rpcaddr *net.TCPAddr
@ -824,12 +826,14 @@ func NewServer(ctx context.Context, ctl_addrs []string, rpc_addrs []string, logg
s.rpc = append(s.rpc, l) s.rpc = append(s.rpc, l)
} }
s.tlscfg = tlscfg s.ctltlscfg = ctltlscfg
s.rpctlscfg = rpctlscfg
s.ext_svcs = make([]Service, 0, 1) s.ext_svcs = make([]Service, 0, 1)
s.cts_map = make(ServerConnMap) s.cts_map = make(ServerConnMap)
s.cts_map_by_addr = make(ServerConnMapByAddr) s.cts_map_by_addr = make(ServerConnMapByAddr)
s.stop_chan = make(chan bool, 8) s.stop_chan = make(chan bool, 8)
s.stop_req.Store(false) s.stop_req.Store(false)
/* /*
creds, err := credentials.NewServerTLSFromFile(data.Path("x509/server_cert.pem"), data.Path("x509/server_key.pem")) creds, err := credentials.NewServerTLSFromFile(data.Path("x509/server_cert.pem"), data.Path("x509/server_key.pem"))
if err != nil { if err != nil {
@ -837,11 +841,20 @@ func NewServer(ctx context.Context, ctl_addrs []string, rpc_addrs []string, logg
} }
gs = grpc.NewServer(grpc.Creds(creds)) gs = grpc.NewServer(grpc.Creds(creds))
*/ */
s.rpc_svr = grpc.NewServer( if s.rpctlscfg == nil {
//grpc.UnaryInterceptor(unaryInterceptor), s.rpc_svr = grpc.NewServer(
//grpc.StreamInterceptor(streamInterceptor), //grpc.UnaryInterceptor(unaryInterceptor),
grpc.StatsHandler(&ConnCatcher{server: &s}), //grpc.StreamInterceptor(streamInterceptor),
) grpc.StatsHandler(&ConnCatcher{server: &s}),
)
} else {
s.rpc_svr = grpc.NewServer(
grpc.Creds(credentials.NewTLS(s.rpctlscfg)),
//grpc.UnaryInterceptor(unaryInterceptor),
//grpc.StreamInterceptor(streamInterceptor),
grpc.StatsHandler(&ConnCatcher{server: &s}),
)
}
RegisterHoduServer(s.rpc_svr, &s) RegisterHoduServer(s.rpc_svr, &s)
s.ctl_prefix = "" // TODO: s.ctl_prefix = "" // TODO:
@ -862,7 +875,7 @@ func NewServer(ctx context.Context, ctl_addrs []string, rpc_addrs []string, logg
s.ctl[i] = &http.Server{ s.ctl[i] = &http.Server{
Addr: ctl_addrs[i], Addr: ctl_addrs[i],
Handler: s.ctl_mux, Handler: s.ctl_mux,
TLSConfig: s.tlscfg, TLSConfig: s.ctltlscfg,
// TODO: more settings // TODO: more settings
} }
} }
@ -949,10 +962,10 @@ func (s *Server) RunCtlTask(wg *sync.WaitGroup) {
l_wg.Add(1) l_wg.Add(1)
go func(i int, cs *http.Server) { go func(i int, cs *http.Server) {
s.log.Write ("", LOG_INFO, "Control channel[%d] started on %s", i, s.ctl_addr[i]) s.log.Write ("", LOG_INFO, "Control channel[%d] started on %s", i, s.ctl_addr[i])
if s.tlscfg == nil { if s.ctltlscfg == nil {
err = cs.ListenAndServe() err = cs.ListenAndServe()
} else { } else {
err = cs.ListenAndServeTLS("", "") // c.tlscfg must provide a certificate and a key err = cs.ListenAndServeTLS("", "") // c.ctltlscfg must provide a certificate and a key
} }
if errors.Is(err, http.ErrServerClosed) { if errors.Is(err, http.ErrServerClosed) {
s.log.Write("", LOG_DEBUG, "Control channel[%d] ended", i) s.log.Write("", LOG_DEBUG, "Control channel[%d] ended", i)