updated to limit max rpc connecitons and peer connections.

changed to use time.Duration type for timeout values
This commit is contained in:
hyung-hwan 2024-12-10 14:37:14 +09:00
parent 6fdae92300
commit 07e11a8e25
4 changed files with 92 additions and 41 deletions

View File

@ -30,7 +30,7 @@ type ClientPeerCancelFuncMap = map[PeerId]context.CancelFunc
type ClientConfig struct { type ClientConfig struct {
ServerAddrs []string ServerAddrs []string
PeerAddrs []string PeerAddrs []string
ServerSeedTimeout int ServerSeedTmout time.Duration
ServerAuthority string // http2 :authority header ServerAuthority string // http2 :authority header
} }
@ -54,6 +54,9 @@ type Client struct {
ctl_mux *http.ServeMux ctl_mux *http.ServeMux
ctl []*http.Server // control server ctl []*http.Server // control server
ptc_tmout time.Duration // timeout seconds to connect to peer
ptc_limit int // global maximum number of peers
cts_limit int
cts_mtx sync.Mutex cts_mtx sync.Mutex
cts_map ClientConnMap cts_map ClientConnMap
@ -304,13 +307,14 @@ func (r *ClientRoute) ConnectToPeer(pts_id PeerId, pts_raddr string, pts_laddr s
var d net.Dialer var d net.Dialer
var ctx context.Context var ctx context.Context
var cancel context.CancelFunc var cancel context.CancelFunc
var tmout time.Duration
var ok bool var ok bool
defer wg.Done() defer wg.Done()
// TODO: make timeout value configurable tmout = time.Duration(r.cts.cli.ptc_tmout)
// TODO: fire the cancellation function upon stop request??? if tmout < 0 { tmout = 10 * time.Second}
ctx, cancel = context.WithTimeout(r.cts.cli.ctx, 10 * time.Second) ctx, cancel = context.WithTimeout(r.cts.cli.ctx, tmout)
r.ptc_mtx.Lock() r.ptc_mtx.Lock()
r.ptc_cancel_map[pts_id] = cancel r.ptc_cancel_map[pts_id] = cancel
r.ptc_mtx.Unlock() r.ptc_mtx.Unlock()
@ -441,8 +445,21 @@ func (r *ClientRoute) ReportEvent(pts_id PeerId, event_type PACKET_KIND, event_d
"Protocol error - invalid data in peer_started event(%d,%d)", r.id, pts_id) "Protocol error - invalid data in peer_started event(%d,%d)", r.id, pts_id)
r.ReqStop() r.ReqStop()
} else { } else {
r.ptc_wg.Add(1) if r.cts.cli.ptc_limit > 0 && int(r.cts.cli.stats.peers.Load()) >= r.cts.cli.ptc_limit {
go r.ConnectToPeer(pts_id, pd.RemoteAddrStr, pd.LocalAddrStr, &r.ptc_wg) r.cts.cli.log.Write(r.cts.sid, LOG_ERROR,
"Rejecting to connect to peer(%s)for route(%d,%d) - allowed max %d",
r.peer_addr, r.id, pts_id, r.cts.cli.ptc_limit)
err = r.cts.psc.Send(MakePeerAbortedPacket(r.id, pts_id, "", ""))
if err != nil {
r.cts.cli.log.Write(r.cts.sid, LOG_ERROR,
"Failed to send peer_aborted(%d,%d) for route(%d,%d,%s,%s) - %s",
r.id, pts_id, r.id, pts_id, "", "", err.Error())
}
} else {
r.ptc_wg.Add(1)
go r.ConnectToPeer(pts_id, pd.RemoteAddrStr, pd.LocalAddrStr, &r.ptc_wg)
}
} }
case PACKET_KIND_PEER_ABORTED: case PACKET_KIND_PEER_ABORTED:
@ -563,15 +580,23 @@ func NewClientConn(c *Client, cfg *ClientConfig) *ClientConn {
func (cts *ClientConn) AddNewClientRoute(addr string, server_peer_net string, proto ROUTE_PROTO) (*ClientRoute, error) { func (cts *ClientConn) AddNewClientRoute(addr string, server_peer_net string, proto ROUTE_PROTO) (*ClientRoute, error) {
var r *ClientRoute var r *ClientRoute
var id RouteId var id RouteId
var ok bool var nattempts RouteId
nattempts = 0
id = RouteId(rand.Uint32())
cts.route_mtx.Lock() cts.route_mtx.Lock()
id = RouteId(rand.Uint32())
for { for {
var ok bool
_, ok = cts.route_map[id] _, ok = cts.route_map[id]
if !ok { break } if !ok { break }
id++ id++
nattempts++
if nattempts == ^RouteId(0) {
cts.route_mtx.Unlock()
return nil, fmt.Errorf("route map full")
}
} }
//if cts.route_map[route_id] != nil { //if cts.route_map[route_id] != nil {
@ -718,14 +743,14 @@ func (cts *ClientConn) ReqStop() {
} }
} }
func timed_interceptor(tmout_sec int) grpc.UnaryClientInterceptor { func timed_interceptor(tmout time.Duration) grpc.UnaryClientInterceptor {
// The client calls GetSeed() as the first call to the server. // The client calls GetSeed() as the first call to the server.
// To simulate a kind of connect timeout to the server and find out an unresponsive server, // To simulate a kind of connect timeout to the server and find out an unresponsive server,
// Place a unary intercepter that places a new context with a timeout on the GetSeed() call. // Place a unary intercepter that places a new context with a timeout on the GetSeed() call.
return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
var cancel context.CancelFunc var cancel context.CancelFunc
if tmout_sec > 0 && method == Hodu_GetSeed_FullMethodName { if tmout > 0 && method == Hodu_GetSeed_FullMethodName {
ctx, cancel = context.WithTimeout(ctx, time.Duration(tmout_sec) * time.Second) ctx, cancel = context.WithTimeout(ctx, tmout)
defer cancel() defer cancel()
} }
return invoker(ctx, method, req, reply, cc, opts...) return invoker(ctx, method, req, reply, cc, opts...)
@ -759,8 +784,8 @@ start_over:
opts = append(opts, grpc.WithAuthority(cts.cli.rpctlscfg.ServerName)) opts = append(opts, grpc.WithAuthority(cts.cli.rpctlscfg.ServerName))
} }
} }
if cts.cfg.ServerSeedTimeout > 0 { if cts.cfg.ServerSeedTmout > 0 {
opts = append(opts, grpc.WithUnaryInterceptor(timed_interceptor(cts.cfg.ServerSeedTimeout))) opts = append(opts, grpc.WithUnaryInterceptor(timed_interceptor(cts.cfg.ServerSeedTmout)))
} }
cts.conn, err = grpc.NewClient(cts.cfg.ServerAddrs[cts.cfg.Index], opts...) cts.conn, err = grpc.NewClient(cts.cfg.ServerAddrs[cts.cfg.Index], opts...)
@ -770,9 +795,6 @@ start_over:
} }
cts.hdc = NewHoduClient(cts.conn) cts.hdc = NewHoduClient(cts.conn)
// TODO: HANDLE connection timeout.. may have to run GetSeed or PacketStream in anther goroutnine
// ctx, _/*cancel*/ := context.WithTimeout(context.Background(), time.Second)
// seed exchange is for furture expansion of the protocol // seed exchange is for furture expansion of the protocol
// there is nothing to do much about it for now. // there is nothing to do much about it for now.
c_seed.Version = HODU_RPC_VERSION c_seed.Version = HODU_RPC_VERSION
@ -1023,7 +1045,7 @@ func (hlw *client_ctl_log_writer) Write(p []byte) (n int, err error) {
return len(p), nil return len(p), nil
} }
func NewClient(ctx context.Context, ctl_addrs []string, logger Logger, ctl_prefix string, ctltlscfg *tls.Config, rpctlscfg *tls.Config) *Client { func NewClient(ctx context.Context, logger Logger, ctl_addrs []string, ctl_prefix string, ctltlscfg *tls.Config, rpctlscfg *tls.Config, rpc_max int, peer_max int, peer_conn_tmout time.Duration) *Client {
var c Client var c Client
var i int var i int
var hs_log *log.Logger var hs_log *log.Logger
@ -1032,6 +1054,9 @@ func NewClient(ctx context.Context, ctl_addrs []string, logger Logger, ctl_prefi
c.ctltlscfg = ctltlscfg c.ctltlscfg = ctltlscfg
c.rpctlscfg = rpctlscfg c.rpctlscfg = rpctlscfg
c.ext_svcs = make([]Service, 0, 1) c.ext_svcs = make([]Service, 0, 1)
c.ptc_tmout = peer_conn_tmout
c.ptc_limit = peer_max
c.cts_limit = rpc_max
c.cts_map = make(ClientConnMap) c.cts_map = make(ClientConnMap)
c.stop_req.Store(false) c.stop_req.Store(false)
c.stop_chan = make(chan bool, 8) c.stop_chan = make(chan bool, 8)
@ -1083,6 +1108,11 @@ func (c *Client) AddNewClientConn(cfg *ClientConfig) (*ClientConn, error) {
c.cts_mtx.Lock() c.cts_mtx.Lock()
if c.cts_limit > 0 && len(c.cts_map) >= c.cts_limit {
c.cts_mtx.Unlock()
return nil, fmt.Errorf("too many connections - %d", c.cts_limit)
}
//id = rand.Uint32() //id = rand.Uint32()
id = ConnId(monotonic_time() / 1000) id = ConnId(monotonic_time() / 1000)
for { for {

View File

@ -8,6 +8,7 @@ import "hodu"
import "io" import "io"
import "io/ioutil" import "io/ioutil"
import "os" import "os"
import "time"
import "gopkg.in/yaml.v3" import "gopkg.in/yaml.v3"
@ -48,25 +49,32 @@ type CTLServiceConfig struct {
type RPCServiceConfig struct { // rpc server-side configuration type RPCServiceConfig struct { // rpc server-side configuration
Addrs []string `yaml:"addresses"` Addrs []string `yaml:"addresses"`
MaxConns int `yaml:"max-connections"` // TODO: implement this item
} }
type RPCEndpointConfig struct { // rpc client-side configuration type RPCEndpointConfig struct { // rpc client-side configuration
Authority string `yaml:"authority"` Authority string `yaml:"authority"`
Addrs []string `yaml:"addresses"` Addrs []string `yaml:"addresses"`
SeedTimeout int `yaml:"seed-timeout"` SeedTmout time.Duration `yaml:"seed-timeout"`
} }
type AppConfig struct { type ServerAppConfig struct {
LogMask []string `yaml:"log-mask"` LogMask []string `yaml:"log-mask"`
LogFile string `yaml:"log-file"` LogFile string `yaml:"log-file"`
MaxPeers int `yaml:"max-peer-conns"` // maximum number of connections from peers
MaxRpcConns int `yaml:"max-rpc-conns"` // maximum number of rpc connections
}
type ClientAppConfig struct {
LogMask []string `yaml:"log-mask"`
LogFile string `yaml:"log-file"`
MaxPeers int `yaml:"max-peer-conns"` // maximum number of connections from peers
MaxRpcConns int `yaml:"max-rpc-conns"` // maximum number of rpc connections
PeerConnTmout time.Duration `yaml:"peer-conn-timeout"`
} }
type ServerConfig struct { type ServerConfig struct {
APP AppConfig `yaml:"app"` APP ServerAppConfig `yaml:"app"`
// TODO: add some limits
// max number of clients, max nubmer of peers
CTL struct { CTL struct {
Service CTLServiceConfig `yaml:"service"` Service CTLServiceConfig `yaml:"service"`
TLS ServerTLSConfig `yaml:"tls"` TLS ServerTLSConfig `yaml:"tls"`
@ -79,10 +87,8 @@ type ServerConfig struct {
} }
type ClientConfig struct { type ClientConfig struct {
APP AppConfig `yaml:"app"` APP ClientAppConfig `yaml:"app"`
// TODO: add some limits
// max nubmer of peers
CTL struct { CTL struct {
Service CTLServiceConfig `yaml:"service"` Service CTLServiceConfig `yaml:"service"`
TLS ServerTLSConfig `yaml:"tls"` TLS ServerTLSConfig `yaml:"tls"`

View File

@ -184,12 +184,14 @@ func server_main(ctl_addrs []string, rpc_addrs []string, cfg *ServerConfig) erro
logger = &AppLogger{id: "server", out: os.Stderr, mask: log_mask} logger = &AppLogger{id: "server", out: os.Stderr, mask: log_mask}
s, err = hodu.NewServer( s, err = hodu.NewServer(
context.Background(), context.Background(),
logger,
ctl_addrs, ctl_addrs,
rpc_addrs, rpc_addrs,
logger,
ctl_prefix, ctl_prefix,
ctltlscfg, ctltlscfg,
rpctlscfg) rpctlscfg,
cfg.APP.MaxRpcConns,
cfg.APP.MaxPeers)
if err != nil { if err != nil {
return fmt.Errorf("failed to create new server - %s", err.Error()) return fmt.Errorf("failed to create new server - %s", err.Error())
} }
@ -229,7 +231,7 @@ func client_main(ctl_addrs []string, rpc_addrs []string, peer_addrs []string, cf
if len(rpc_addrs) <= 0 { rpc_addrs = cfg.RPC.Endpoint.Addrs } if len(rpc_addrs) <= 0 { rpc_addrs = cfg.RPC.Endpoint.Addrs }
ctl_prefix = cfg.CTL.Service.Prefix ctl_prefix = cfg.CTL.Service.Prefix
cc.ServerSeedTimeout = cfg.RPC.Endpoint.SeedTimeout cc.ServerSeedTmout = cfg.RPC.Endpoint.SeedTmout
cc.ServerAuthority = cfg.RPC.Endpoint.Authority cc.ServerAuthority = cfg.RPC.Endpoint.Authority
log_mask = log_strings_to_mask(cfg.APP.LogMask) log_mask = log_strings_to_mask(cfg.APP.LogMask)
} }
@ -242,11 +244,14 @@ func client_main(ctl_addrs []string, rpc_addrs []string, peer_addrs []string, cf
logger = &AppLogger{id: "client", out: os.Stderr, mask: log_mask} logger = &AppLogger{id: "client", out: os.Stderr, mask: log_mask}
c = hodu.NewClient( c = hodu.NewClient(
context.Background(), context.Background(),
ctl_addrs,
logger, logger,
ctl_addrs,
ctl_prefix, ctl_prefix,
ctltlscfg, ctltlscfg,
rpctlscfg) rpctlscfg,
cfg.APP.MaxRpcConns,
cfg.APP.MaxPeers,
cfg.APP.PeerConnTmout)
c.StartService(&cc) c.StartService(&cc)
c.StartCtlService() // control channel c.StartCtlService() // control channel

View File

@ -51,6 +51,7 @@ type Server struct {
rpc_wg sync.WaitGroup rpc_wg sync.WaitGroup
rpc_svr *grpc.Server rpc_svr *grpc.Server
pts_limit int // global pts limit
cts_limit int cts_limit int
cts_mtx sync.Mutex cts_mtx sync.Mutex
cts_map ServerConnMap cts_map ServerConnMap
@ -247,6 +248,11 @@ func (r *ServerRoute) RunTask(wg *sync.WaitGroup) {
conn.Close() conn.Close()
} }
if r.cts.svr.pts_limit > 0 && int(r.cts.svr.stats.peers.Load()) >= r.cts.svr.pts_limit {
r.cts.svr.log.Write(r.cts.sid, LOG_DEBUG, "Rejected server-side peer %s to route(%d) - allowed max %d", raddr.String(), r.id, r.cts.svr.pts_limit)
conn.Close()
}
pts, err = r.AddNewServerPeerConn(conn) pts, err = r.AddNewServerPeerConn(conn)
if err != nil { if err != nil {
r.cts.svr.log.Write(r.cts.sid, LOG_ERROR, "Failed to add server-side peer %s to route(%d) - %s", r.id, raddr.String(), r.id, err.Error()) r.cts.svr.log.Write(r.cts.sid, LOG_ERROR, "Failed to add server-side peer %s to route(%d) - %s", r.id, raddr.String(), r.id, err.Error())
@ -851,7 +857,7 @@ func (hlw *server_ctl_log_writer) Write(p []byte) (n int, err error) {
return len(p), nil return len(p), nil
} }
func NewServer(ctx context.Context, ctl_addrs []string, rpc_addrs []string, logger Logger, ctl_prefix string, ctltlscfg *tls.Config, rpctlscfg *tls.Config) (*Server, error) { func NewServer(ctx context.Context, logger Logger, ctl_addrs []string, rpc_addrs []string, ctl_prefix string, ctltlscfg *tls.Config, rpctlscfg *tls.Config, rpc_max int, peer_max int) (*Server, error) {
var s Server var s Server
var l *net.TCPListener var l *net.TCPListener
var rpcaddr *net.TCPAddr var rpcaddr *net.TCPAddr
@ -888,7 +894,8 @@ func NewServer(ctx context.Context, ctl_addrs []string, rpc_addrs []string, logg
s.ctltlscfg = ctltlscfg s.ctltlscfg = ctltlscfg
s.rpctlscfg = rpctlscfg s.rpctlscfg = rpctlscfg
s.ext_svcs = make([]Service, 0, 1) s.ext_svcs = make([]Service, 0, 1)
s.cts_limit = CTS_LIMIT // TODO: accept this from configuration s.pts_limit = peer_max
s.cts_limit = rpc_max
s.cts_map = make(ServerConnMap) s.cts_map = make(ServerConnMap)
s.cts_map_by_addr = make(ServerConnMapByAddr) s.cts_map_by_addr = make(ServerConnMapByAddr)
s.stop_chan = make(chan bool, 8) s.stop_chan = make(chan bool, 8)
@ -1099,14 +1106,14 @@ func (s *Server) AddNewServerConn(remote_addr *net.Addr, local_addr *net.Addr, p
cts.stop_chan = make(chan bool, 8) cts.stop_chan = make(chan bool, 8)
s.cts_mtx.Lock() s.cts_mtx.Lock()
defer s.cts_mtx.Unlock()
if len(s.cts_map) > s.cts_limit { if s.cts_limit > 0 && len(s.cts_map) >= s.cts_limit {
s.cts_mtx.Unlock()
return nil, fmt.Errorf("too many connections - %d", s.cts_limit) return nil, fmt.Errorf("too many connections - %d", s.cts_limit)
} }
//id = rand.Uint32() //id = rand.Uint32()
id = ConnId(monotonic_time()/ 1000) id = ConnId(monotonic_time() / 1000)
for { for {
_, ok = s.cts_map[id] _, ok = s.cts_map[id]
if !ok { break } if !ok { break }
@ -1117,11 +1124,14 @@ func (s *Server) AddNewServerConn(remote_addr *net.Addr, local_addr *net.Addr, p
_, ok = s.cts_map_by_addr[cts.remote_addr] _, ok = s.cts_map_by_addr[cts.remote_addr]
if ok { if ok {
s.cts_mtx.Unlock()
return nil, fmt.Errorf("existing client - %s", cts.remote_addr.String()) return nil, fmt.Errorf("existing client - %s", cts.remote_addr.String())
} }
s.cts_map_by_addr[cts.remote_addr] = &cts s.cts_map_by_addr[cts.remote_addr] = &cts
s.cts_map[id] = &cts; s.cts_map[id] = &cts;
s.stats.conns.Store(int64(len(s.cts_map))) s.stats.conns.Store(int64(len(s.cts_map)))
s.cts_mtx.Unlock()
s.log.Write("", LOG_DEBUG, "Added client connection from %s", cts.remote_addr.String()) s.log.Write("", LOG_DEBUG, "Added client connection from %s", cts.remote_addr.String())
return &cts, nil return &cts, nil
} }