some code to handle configuration file and tls

This commit is contained in:
hyung-hwan 2024-12-07 00:54:19 +09:00
parent e144a39c5c
commit e51077a749
4 changed files with 131 additions and 80 deletions

View File

@ -10,13 +10,16 @@ SRCS=\
server.go \ server.go \
server-ctl.go \ server-ctl.go \
server-peer.go \ server-peer.go \
server-ws.go \ server-ws.go
CMD_SRCS=\
cmd/config.go \
cmd/main.go cmd/main.go
all: hodu all: hodu
hodu: $(SRCS) hodu: $(SRCS) $(CMD_SRCS)
CGO_ENABLED=0 go build -x -o $@ cmd/main.go CGO_ENABLED=0 go build -x -o $@ $(CMD_SRCS)
clean: clean:
go clean -x -i go clean -x -i

View File

@ -7,6 +7,7 @@ import "flag"
import "fmt" import "fmt"
import "hodu" import "hodu"
import "io" import "io"
import "io/ioutil"
import "os" import "os"
import "os/signal" import "os/signal"
import "path/filepath" import "path/filepath"
@ -156,37 +157,94 @@ func (sh *signal_handler) WriteLog(id string, level hodu.LogLevel, fmt string, a
sh.svc.WriteLog(id, level, fmt, args...) sh.svc.WriteLog(id, level, fmt, args...)
} }
func server_main(ctl_addrs []string, svcaddrs []string) error { func tls_string_to_client_auth_type(str string) tls.ClientAuthType {
var s *hodu.Server switch str {
var err error case tls.NoClientCert.String():
var cert tls.Certificate return tls.NoClientCert
var cert_pool *x509.CertPool case tls.RequestClientCert.String():
return tls.RequestClientCert
case tls.RequireAnyClientCert.String():
return tls.RequireAnyClientCert
case tls.VerifyClientCertIfGiven.String():
return tls.VerifyClientCertIfGiven
case tls.RequireAndVerifyClientCert.String():
return tls.RequireAndVerifyClientCert
default:
return tls.NoClientCert
}
}
// --------------------------------------------------------------------
func make_server_tls_config(cfg *ServerTLSConfig) (*tls.Config, error) {
var tlscfg *tls.Config var tlscfg *tls.Config
if cfg.Enabled {
var cert tls.Certificate
var cert_pool *x509.CertPool
var err error
if cfg.CertText != "" && cfg.KeyText != "" {
cert, err = tls.X509KeyPair([]byte(cfg.CertText), []byte(cfg.KeyText))
} else if cfg.CertFile != "" && cfg.KeyFile != "" {
cert, err = tls.LoadX509KeyPair(cfg.CertFile, cfg.KeyFile)
} else {
// use the embedded certificate
cert, err = tls.X509KeyPair([]byte(rootCert), []byte(rootKey)) cert, err = tls.X509KeyPair([]byte(rootCert), []byte(rootKey))
if err != nil {
return fmt.Errorf("ERROR: failed to load key pair - %s\n", err)
} }
if err != nil {
return nil, fmt.Errorf("failed to load key pair - %s", err)
}
if cfg.ClientCACertText != "" || cfg.ClientCACertFile != ""{
var ok bool
cert_pool = x509.NewCertPool() cert_pool = x509.NewCertPool()
ok := cert_pool.AppendCertsFromPEM([]byte(rootCert))
if cfg.ClientCACertText != "" {
ok = cert_pool.AppendCertsFromPEM([]byte(cfg.ClientCACertText))
if !ok { if !ok {
return fmt.Errorf("ERROR: failed to append root certificate\n") return nil, fmt.Errorf("failed to append certificate to pool")
}
} else if cfg.ClientCACertFile != "" {
var text []byte
text, err = ioutil.ReadFile(cfg.ClientCACertFile)
if err != nil {
return nil, fmt.Errorf("failed to load client ca certficate file %s - %s", cfg.ClientCACertFile, err.Error())
}
ok = cert_pool.AppendCertsFromPEM(text)
if !ok {
return nil, fmt.Errorf("failed to append certificate to pool")
}
}
} }
/* /*
// Don't use `Certificates` it doesn't work with some certificate files. // Don't use `Certificates` it doesn't work with some certificate files.
// See, `getClientCertificate` in ${GOSRC}/src/crypto/tls/handshake_client.go for details // See, `getClientCertificate` in ${GOSRC}/src/crypto/tls/handshake_client.go for details
tlsConfig.GetClientCertificate = func(cri *tls.CertificateRequestInfo) (*tls.Certificate, error) { tlsConfig.GetClientCertificate = func(cri *tls.CertificateRequestInfo) (*tls.Certificate, error) {
return clientCert, nil return cert, nil
} }
*/ */
tlscfg = &tls.Config{ tlscfg = &tls.Config{
Certificates: []tls.Certificate{cert}, Certificates: []tls.Certificate{cert},
ClientAuth: tls.NoClientCert, // tls.RequestClientCert, tls.RequestAnyClientCert, VerifyClientCertIfGiven, RequireAndVerifyClientCert ClientAuth: tls_string_to_client_auth_type(cfg.ClientAuthType),
ClientCAs: cert_pool, // trusted CA certs for client certificate verification ClientCAs: cert_pool, // trusted CA certs for client certificate verification
ServerName: "hodu", //ServerName: "hodu",
}
}
return tlscfg, nil
}
func server_main(ctl_addrs []string, svcaddrs []string, cfg *ServerConfig) error {
var s *hodu.Server
var tlscfg *tls.Config
var err error
tlscfg, err = make_server_tls_config(&cfg.TLS)
if err != nil {
return err
} }
s, err = hodu.NewServer( s, err = hodu.NewServer(
@ -196,7 +254,7 @@ func server_main(ctl_addrs []string, svcaddrs []string) error {
&AppLogger{id: "server", out: os.Stderr}, &AppLogger{id: "server", out: os.Stderr},
tlscfg) tlscfg)
if err != nil { if err != nil {
return fmt.Errorf("ERROR: failed to create new server - %s", err.Error()) return fmt.Errorf("failed to create new server - %s", err.Error())
} }
s.StartService(nil) s.StartService(nil)
@ -209,58 +267,15 @@ func server_main(ctl_addrs []string, svcaddrs []string) error {
// -------------------------------------------------------------------- // --------------------------------------------------------------------
func client_main(ctl_addrs []string, server_addr string, peer_addrs []string) error { func client_main(ctl_addrs []string, server_addr string, peer_addrs []string, cfg *ClientConfig) error {
var c *hodu.Client var c *hodu.Client
var cert tls.Certificate
var cert_pool *x509.CertPool
var tlscfg *tls.Config var tlscfg *tls.Config
var cc hodu.ClientConfig var cc hodu.ClientConfig
var err error var err error
/* tlscfg, err = make_server_tls_config(&cfg.TLS)
cert_pool = x509.NewCertPool()
ok := cert_pool.AppendCertsFromPEM([]byte(rootCert))
if !ok {
fmt.Printf("failed to parse root certificate")
}
tlscfg = &tls.Config{
RootCAs: cert_pool,
ClientAuth:
ServerName: "hodu",
//InsecureSkipVerify: true,
}
tlscfg := &tls.Config{
MinVersion: tls.VersionTLS12,
CurvePreferences: []tls.CurveID{tls.CurveP521, tls.CurveP384, tls.CurveP256},
PreferServerCipherSuites: true,
CipherSuites: []uint16{
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
tls.TLS_RSA_WITH_AES_256_GCM_SHA384,
tls.TLS_RSA_WITH_AES_256_CBC_SHA,
},
}
*/
cert, err = tls.X509KeyPair([]byte(rootCert), []byte(rootKey))
if err != nil { if err != nil {
return fmt.Errorf("ERROR: failed to load key pair - %s\n", err) return err
}
cert_pool = x509.NewCertPool()
ok := cert_pool.AppendCertsFromPEM([]byte(rootCert))
if !ok {
return fmt.Errorf("ERROR: failed to append root certificate\n")
}
tlscfg = &tls.Config{
Certificates: []tls.Certificate{cert},
ClientAuth: tls.NoClientCert, // tls.RequestClientCert, tls.RequestAnyClientCert, VerifyClientCertIfGiven, RequireAndVerifyClientCert
ClientCAs: cert_pool, // trusted CA certs for client certificate verification
ServerName: "hodu",
} }
c = hodu.NewClient( c = hodu.NewClient(
@ -290,6 +305,8 @@ func main() {
if strings.EqualFold(os.Args[1], "server") { if strings.EqualFold(os.Args[1], "server") {
var rpc_addrs[] string var rpc_addrs[] string
var ctl_addrs[] string var ctl_addrs[] string
var cfgfile string
var cfg *ServerConfig
ctl_addrs = make([]string, 0) ctl_addrs = make([]string, 0)
rpc_addrs = make([]string, 0) rpc_addrs = make([]string, 0)
@ -303,6 +320,10 @@ func main() {
rpc_addrs = append(rpc_addrs, v) rpc_addrs = append(rpc_addrs, v)
return nil return nil
}) })
flgs.Func("config-file", "specify a configuration file path", func(v string) error {
cfgfile = v
return nil
})
flgs.SetOutput(io.Discard) // prevent usage output flgs.SetOutput(io.Discard) // prevent usage output
err = flgs.Parse(os.Args[2:]) err = flgs.Parse(os.Args[2:])
if err != nil { if err != nil {
@ -314,7 +335,15 @@ func main() {
goto wrong_usage goto wrong_usage
} }
err = server_main(ctl_addrs, rpc_addrs) if (cfgfile != "") {
cfg, err = LoadServerConfig(cfgfile)
if err != nil {
fmt.Printf ("ERROR: failed to load configuration file %s - %s\n", cfgfile, err.Error())
goto oops
}
}
err = server_main(ctl_addrs, rpc_addrs, cfg)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "ERROR: server error - %s\n", err.Error()) fmt.Fprintf(os.Stderr, "ERROR: server error - %s\n", err.Error())
goto oops goto oops
@ -322,6 +351,8 @@ func main() {
} else if strings.EqualFold(os.Args[1], "client") { } else if strings.EqualFold(os.Args[1], "client") {
var rpc_addrs []string var rpc_addrs []string
var ctl_addrs []string var ctl_addrs []string
var cfgfile string
var cfg *ClientConfig
ctl_addrs = make([]string, 0) ctl_addrs = make([]string, 0)
rpc_addrs = make([]string, 0) rpc_addrs = make([]string, 0)
@ -335,6 +366,10 @@ func main() {
rpc_addrs = append(rpc_addrs, v) rpc_addrs = append(rpc_addrs, v)
return nil return nil
}) })
flgs.Func("config-file", "specify a configuration file path", func(v string) error {
cfgfile = v
return nil
})
flgs.SetOutput(io.Discard) flgs.SetOutput(io.Discard)
err = flgs.Parse(os.Args[2:]) err = flgs.Parse(os.Args[2:])
if err != nil { if err != nil {
@ -345,7 +380,16 @@ func main() {
if len(rpc_addrs) < 1 { if len(rpc_addrs) < 1 {
goto wrong_usage goto wrong_usage
} }
err = client_main(ctl_addrs, rpc_addrs[0], flgs.Args())
if (cfgfile != "") {
cfg, err = LoadClientConfig(cfgfile)
if err != nil {
fmt.Printf ("ERROR: failed to load configuration file %s - %s\n", cfgfile, err.Error())
goto oops
}
}
err = client_main(ctl_addrs, rpc_addrs[0], flgs.Args(), cfg)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "ERROR: client error - %s\n", err.Error()) fmt.Fprintf(os.Stderr, "ERROR: client error - %s\n", err.Error())
goto oops goto oops

1
go.mod
View File

@ -13,4 +13,5 @@ require (
golang.org/x/sys v0.24.0 // indirect golang.org/x/sys v0.24.0 // indirect
golang.org/x/text v0.17.0 // indirect golang.org/x/text v0.17.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240814211410-ddb44dafa142 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240814211410-ddb44dafa142 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
) )

3
go.sum
View File

@ -14,3 +14,6 @@ google.golang.org/grpc v1.67.1 h1:zWnc1Vrcno+lHZCOofnIMvycFcc0QRGIzm9dhnDX68E=
google.golang.org/grpc v1.67.1/go.mod h1:1gLDyUQU7CTLJI90u3nXZ9ekeghjeM7pTDZlqFNg2AA= google.golang.org/grpc v1.67.1/go.mod h1:1gLDyUQU7CTLJI90u3nXZ9ekeghjeM7pTDZlqFNg2AA=
google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg= google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg=
google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw= google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=