added the version subcommand
added subjectAltNames to the embedded cert file
This commit is contained in:
146
cmd/main.go
146
cmd/main.go
@ -2,13 +2,11 @@ package main
|
||||
|
||||
import "context"
|
||||
import "crypto/tls"
|
||||
import "crypto/x509"
|
||||
import _ "embed"
|
||||
import "flag"
|
||||
import "fmt"
|
||||
import "hodu"
|
||||
import "io"
|
||||
import "io/ioutil"
|
||||
import "os"
|
||||
import "os/signal"
|
||||
import "path/filepath"
|
||||
@ -18,6 +16,11 @@ import "sync"
|
||||
import "syscall"
|
||||
import "time"
|
||||
|
||||
|
||||
// Don't change these items to 'const' as they can be overridden with a linker option
|
||||
var HODU_NAME string = "hodu"
|
||||
var HODU_VERSION string = "0.0.0"
|
||||
|
||||
//go:embed tls.crt
|
||||
var hodu_tls_cert_text []byte
|
||||
//go:embed tls.key
|
||||
@ -136,138 +139,6 @@ func (sh *signal_handler) WriteLog(id string, level hodu.LogLevel, fmt string, a
|
||||
sh.svc.WriteLog(id, level, fmt, args...)
|
||||
}
|
||||
|
||||
func tls_string_to_client_auth_type(str string) tls.ClientAuthType {
|
||||
switch str {
|
||||
case tls.NoClientCert.String():
|
||||
return tls.NoClientCert
|
||||
case tls.RequestClientCert.String():
|
||||
return tls.RequestClientCert
|
||||
case tls.RequireAnyClientCert.String():
|
||||
return tls.RequireAnyClientCert
|
||||
case tls.VerifyClientCertIfGiven.String():
|
||||
return tls.VerifyClientCertIfGiven
|
||||
case tls.RequireAndVerifyClientCert.String():
|
||||
return tls.RequireAndVerifyClientCert
|
||||
default:
|
||||
return tls.NoClientCert
|
||||
}
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------
|
||||
|
||||
func make_tls_server_config(cfg *ServerTLSConfig) (*tls.Config, error) {
|
||||
var tlscfg *tls.Config
|
||||
|
||||
if cfg.Enabled {
|
||||
var cert tls.Certificate
|
||||
var cert_pool *x509.CertPool
|
||||
var err error
|
||||
|
||||
if cfg.CertText != "" && cfg.KeyText != "" {
|
||||
cert, err = tls.X509KeyPair([]byte(cfg.CertText), []byte(cfg.KeyText))
|
||||
} else if cfg.CertFile != "" && cfg.KeyFile != "" {
|
||||
cert, err = tls.LoadX509KeyPair(cfg.CertFile, cfg.KeyFile)
|
||||
} else {
|
||||
// use the embedded certificate
|
||||
cert, err = tls.X509KeyPair(hodu_tls_cert_text, hodul_tls_key_text)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load key pair - %s", err)
|
||||
}
|
||||
|
||||
if cfg.ClientCACertText != "" || cfg.ClientCACertFile != ""{
|
||||
var ok bool
|
||||
|
||||
cert_pool = x509.NewCertPool()
|
||||
|
||||
if cfg.ClientCACertText != "" {
|
||||
ok = cert_pool.AppendCertsFromPEM([]byte(cfg.ClientCACertText))
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to append certificate to pool")
|
||||
}
|
||||
} else if cfg.ClientCACertFile != "" {
|
||||
var text []byte
|
||||
text, err = ioutil.ReadFile(cfg.ClientCACertFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load ca certficate file %s - %s", cfg.ClientCACertFile, err.Error())
|
||||
}
|
||||
ok = cert_pool.AppendCertsFromPEM(text)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to append certificate to pool")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tlscfg = &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
// If multiple certificates are configured, we may have to implement GetCertificate
|
||||
// GetCertificate: func (chi *tls.ClientHelloInfo) (*Certificate, error) { return cert, nil }
|
||||
ClientAuth: tls_string_to_client_auth_type(cfg.ClientAuthType),
|
||||
ClientCAs: cert_pool, // trusted CA certs for client certificate verification
|
||||
}
|
||||
}
|
||||
|
||||
return tlscfg, nil
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------
|
||||
|
||||
func make_tls_client_config(cfg *ClientTLSConfig) (*tls.Config, error) {
|
||||
var tlscfg *tls.Config
|
||||
|
||||
if cfg.Enabled {
|
||||
var cert tls.Certificate
|
||||
var cert_pool *x509.CertPool
|
||||
var err error
|
||||
|
||||
|
||||
if cfg.CertText != "" && cfg.KeyText != "" {
|
||||
cert, err = tls.X509KeyPair([]byte(cfg.CertText), []byte(cfg.KeyText))
|
||||
} else if cfg.CertFile != "" && cfg.KeyFile != "" {
|
||||
cert, err = tls.LoadX509KeyPair(cfg.CertFile, cfg.KeyFile)
|
||||
} else {
|
||||
// use the embedded certificate
|
||||
cert, err = tls.X509KeyPair(hodu_tls_cert_text, hodul_tls_key_text)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load key pair - %s", err)
|
||||
}
|
||||
|
||||
if cfg.ServerCACertText != "" || cfg.ServerCACertFile != ""{
|
||||
var ok bool
|
||||
|
||||
cert_pool = x509.NewCertPool()
|
||||
|
||||
if cfg.ServerCACertText != "" {
|
||||
ok = cert_pool.AppendCertsFromPEM([]byte(cfg.ServerCACertText))
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to append certificate to pool")
|
||||
}
|
||||
} else if cfg.ServerCACertFile != "" {
|
||||
var text []byte
|
||||
text, err = ioutil.ReadFile(cfg.ServerCACertFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load ca certficate file %s - %s", cfg.ServerCACertFile, err.Error())
|
||||
}
|
||||
ok = cert_pool.AppendCertsFromPEM(text)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to append certificate to pool")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tlscfg = &tls.Config{
|
||||
//Certificates: []tls.Certificate{cert},
|
||||
GetClientCertificate: func(cri *tls.CertificateRequestInfo) (*tls.Certificate, error) { return &cert, nil },
|
||||
RootCAs: cert_pool,
|
||||
InsecureSkipVerify: cfg.InsecureSkipVerify,
|
||||
ServerName: cfg.ServerName,
|
||||
}
|
||||
}
|
||||
|
||||
return tlscfg, nil
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------
|
||||
|
||||
func server_main(ctl_addrs []string, svcaddrs []string, cfg *ServerConfig) error {
|
||||
@ -385,7 +256,7 @@ func main() {
|
||||
}
|
||||
|
||||
if (cfgfile != "") {
|
||||
cfg, err = LoadServerConfig(cfgfile)
|
||||
cfg, err = load_server_config(cfgfile)
|
||||
if err != nil {
|
||||
fmt.Printf ("ERROR: failed to load configuration file %s - %s\n", cfgfile, err.Error())
|
||||
goto oops
|
||||
@ -431,7 +302,7 @@ func main() {
|
||||
}
|
||||
|
||||
if (cfgfile != "") {
|
||||
cfg, err = LoadClientConfig(cfgfile)
|
||||
cfg, err = load_client_config(cfgfile)
|
||||
if err != nil {
|
||||
fmt.Printf ("ERROR: failed to load configuration file %s - %s\n", cfgfile, err.Error())
|
||||
goto oops
|
||||
@ -443,6 +314,8 @@ func main() {
|
||||
fmt.Fprintf(os.Stderr, "ERROR: client error - %s\n", err.Error())
|
||||
goto oops
|
||||
}
|
||||
} else if strings.EqualFold(os.Args[1], "version") {
|
||||
fmt.Printf("%s %s\n", HODU_NAME, HODU_VERSION)
|
||||
} else {
|
||||
goto wrong_usage
|
||||
}
|
||||
@ -452,6 +325,7 @@ func main() {
|
||||
wrong_usage:
|
||||
fmt.Fprintf(os.Stderr, "USAGE: %s server --rpc-on=addr:port --ctl-on=addr:port\n", os.Args[0])
|
||||
fmt.Fprintf(os.Stderr, " %s client --rpc-server=addr:port --ctl-on=addr:port [peer-addr:peer-port ...]\n", os.Args[0])
|
||||
fmt.Fprintf(os.Stderr, " %s version\n", os.Args[0])
|
||||
os.Exit(1)
|
||||
|
||||
oops:
|
||||
|
Reference in New Issue
Block a user