added some experimental code using grpc
This commit is contained in:
790
server.go
Normal file
790
server.go
Normal file
@ -0,0 +1,790 @@
|
||||
package main
|
||||
|
||||
//import "bufio"
|
||||
//import "bytes"
|
||||
import "context"
|
||||
import "crypto/tls"
|
||||
import "fmt"
|
||||
import "io"
|
||||
import "math/rand"
|
||||
import "net"
|
||||
import "os"
|
||||
import "os/signal"
|
||||
import "sync"
|
||||
import "sync/atomic"
|
||||
import "syscall"
|
||||
import "time"
|
||||
|
||||
import "google.golang.org/grpc"
|
||||
import "google.golang.org/grpc/metadata"
|
||||
import "google.golang.org/grpc/peer"
|
||||
import "google.golang.org/grpc/stats"
|
||||
|
||||
const PTS_LIMIT = 8192
|
||||
//const CTS_LIMIT = 2048
|
||||
|
||||
type ClientConnMap = map[net.Addr]*ClientConn
|
||||
type ServerPeerConnMap = map[uint32]*ServerPeerConn
|
||||
type ServerRouteMap = map[uint32]*ServerRoute
|
||||
|
||||
type Server struct {
|
||||
tlscfg *tls.Config
|
||||
l []*net.TCPListener // central listener
|
||||
l_wg sync.WaitGroup
|
||||
|
||||
cts_mtx sync.Mutex
|
||||
cts_map ClientConnMap
|
||||
wg sync.WaitGroup
|
||||
stop_req atomic.Bool
|
||||
|
||||
// grpc stuffs
|
||||
gs *grpc.Server
|
||||
UnimplementedHoduServer
|
||||
}
|
||||
|
||||
// client connection to server.
|
||||
// client connect to the server, the server accept it, and makes a tunnel request
|
||||
type ClientConn struct {
|
||||
svr *Server
|
||||
caddr net.Addr // client address that created this structure
|
||||
pss Hodu_PacketStreamServer
|
||||
|
||||
cw_mtx sync.Mutex
|
||||
route_mtx sync.Mutex
|
||||
routes ServerRouteMap
|
||||
//route_wg sync.WaitGroup
|
||||
|
||||
wg sync.WaitGroup
|
||||
stop_req atomic.Bool
|
||||
greeted bool
|
||||
}
|
||||
|
||||
type ServerRoute struct {
|
||||
cts *ClientConn
|
||||
l *net.TCPListener
|
||||
laddr *net.TCPAddr
|
||||
id uint32
|
||||
|
||||
pts_mtx sync.Mutex
|
||||
pts_map ServerPeerConnMap
|
||||
pts_limit int
|
||||
pts_last_id uint32
|
||||
pts_wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// ------------------------------------
|
||||
|
||||
func (r *ServerRoute) AddNewServerPeerConn(c *net.TCPConn) (*ServerPeerConn, error) {
|
||||
var pts *ServerPeerConn
|
||||
var ok bool
|
||||
var start_id uint32
|
||||
|
||||
r.pts_mtx.Lock()
|
||||
defer r.pts_mtx.Unlock()
|
||||
|
||||
if len(r.pts_map) >= r.pts_limit {
|
||||
return nil, fmt.Errorf("peer-to-server connection table full")
|
||||
}
|
||||
|
||||
start_id = r.pts_last_id
|
||||
for {
|
||||
_, ok = r.pts_map[r.pts_last_id]
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
r.pts_last_id++
|
||||
if r.pts_last_id == start_id {
|
||||
// unlikely to happen but it cycled through the whole range.
|
||||
return nil, fmt.Errorf("failed to assign peer-to-server connection id")
|
||||
}
|
||||
}
|
||||
|
||||
pts = NewServerPeerConn(r, c, r.pts_last_id)
|
||||
r.pts_map[pts.conn_id] = pts
|
||||
r.pts_last_id++
|
||||
|
||||
return pts, nil
|
||||
}
|
||||
|
||||
func (r *ServerRoute) RemoveServerPeerConn(pts *ServerPeerConn) {
|
||||
r.pts_mtx.Lock()
|
||||
delete(r.pts_map, pts.conn_id)
|
||||
r.pts_mtx.Unlock()
|
||||
}
|
||||
|
||||
// ------------------------------------
|
||||
func (r *ServerRoute) RunTask() {
|
||||
var err error
|
||||
var conn *net.TCPConn
|
||||
var pts *ServerPeerConn
|
||||
|
||||
for {
|
||||
conn, err = r.l.AcceptTCP()
|
||||
if err != nil {
|
||||
// TODO: logging
|
||||
fmt.Printf("[%s,%d] accept failure - %s\n", r.cts.caddr.String(), r.id, err.Error())
|
||||
break
|
||||
}
|
||||
|
||||
pts, err = r.AddNewServerPeerConn(conn)
|
||||
if err != nil {
|
||||
// TODO: logging
|
||||
fmt.Printf("YYYYYYYY - %s\n", err.Error())
|
||||
conn.Close()
|
||||
} else {
|
||||
fmt.Printf("STARTED NEW SERVER PEER STAK\n")
|
||||
r.pts_wg.Add(1)
|
||||
go pts.RunTask()
|
||||
}
|
||||
}
|
||||
|
||||
r.l.Close() // don't care about double close. it could have been closed in StopTask
|
||||
r.pts_wg.Wait()
|
||||
|
||||
// cts.l_wg.Done()
|
||||
// TODO:inform that the job is done?
|
||||
}
|
||||
|
||||
func (r *ServerRoute) StopTask() {
|
||||
fmt.Printf ("stoppping stak..\n")
|
||||
// TODO: all pts stop...
|
||||
r.l.Close();
|
||||
// TODO: wait??
|
||||
}
|
||||
|
||||
func (r *ServerRoute) ReportEvent (pts_id uint32, event_type PACKET_KIND, event_data []byte) error {
|
||||
var spc *ServerPeerConn
|
||||
var ok bool
|
||||
|
||||
r.pts_mtx.Lock()
|
||||
spc, ok = r.pts_map[pts_id]
|
||||
if !ok {
|
||||
return fmt.Errorf("non-existent peer id - %u", pts_id)
|
||||
}
|
||||
r.pts_mtx.Unlock();
|
||||
|
||||
return spc.ReportEvent(event_type, event_data)
|
||||
}
|
||||
// ------------------------------------
|
||||
|
||||
func (cts *ClientConn) make_route_listener(proto ROUTE_PROTO) (*net.TCPListener, *net.TCPAddr, error) {
|
||||
var l *net.TCPListener
|
||||
var err error
|
||||
var laddr *net.TCPAddr
|
||||
var port int
|
||||
var tries int = 0
|
||||
var nw string
|
||||
|
||||
switch proto {
|
||||
case ROUTE_PROTO_TCP:
|
||||
nw = "tcp"
|
||||
case ROUTE_PROTO_TCP4:
|
||||
nw = "tcp4"
|
||||
case ROUTE_PROTO_TCP6:
|
||||
nw = "tcp6"
|
||||
}
|
||||
|
||||
for {
|
||||
port = rand.Intn(65535-32000+1) + 32000
|
||||
|
||||
laddr, err = net.ResolveTCPAddr(nw, fmt.Sprintf(":%d", port))
|
||||
if err == nil {
|
||||
l, err = net.ListenTCP(nw, laddr) // make the binding address configurable. support multiple binding addresses???
|
||||
if err == nil {
|
||||
fmt.Printf("listening .... on ... %d\n", port)
|
||||
return l, laddr, nil
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: implement max retries..
|
||||
tries++
|
||||
if tries >= 1000 {
|
||||
err = fmt.Errorf("unable to allocate port")
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
func NewServerRoute(cts *ClientConn, id uint32, proto ROUTE_PROTO) (*ServerRoute, error) {
|
||||
var r ServerRoute
|
||||
var l *net.TCPListener
|
||||
var laddr *net.TCPAddr
|
||||
var err error
|
||||
|
||||
l, laddr, err = cts.make_route_listener(proto);
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r.cts = cts
|
||||
r.id = id
|
||||
r.l = l
|
||||
r.laddr = laddr
|
||||
r.pts_limit = PTS_LIMIT
|
||||
r.pts_map = make(ServerPeerConnMap)
|
||||
r.pts_last_id = 0
|
||||
|
||||
return &r, nil;
|
||||
}
|
||||
|
||||
func (cts *ClientConn) AddNewServerRoute(route_id uint32, proto ROUTE_PROTO) (*ServerRoute, error) {
|
||||
var r *ServerRoute
|
||||
var err error
|
||||
|
||||
cts.route_mtx.Lock()
|
||||
if cts.routes[route_id] != nil {
|
||||
cts.route_mtx.Unlock()
|
||||
return nil, fmt.Errorf ("existent route id - %d", route_id)
|
||||
}
|
||||
r, err = NewServerRoute(cts, route_id, proto)
|
||||
if err != nil {
|
||||
cts.route_mtx.Unlock()
|
||||
return nil, err
|
||||
}
|
||||
cts.routes[route_id] = r;
|
||||
cts.route_mtx.Unlock()
|
||||
|
||||
go r.RunTask()
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func (cts *ClientConn) RemoveServerRoute (route_id uint32) error {
|
||||
var r *ServerRoute
|
||||
var ok bool
|
||||
|
||||
cts.route_mtx.Lock()
|
||||
r, ok = cts.routes[route_id]
|
||||
if (!ok) {
|
||||
cts.route_mtx.Unlock()
|
||||
return fmt.Errorf ("non-existent route id - %d", route_id)
|
||||
}
|
||||
delete(cts.routes, route_id)
|
||||
cts.route_mtx.Unlock()
|
||||
|
||||
r.StopTask() // TODO: make this unblocking or blocking?
|
||||
return nil;
|
||||
}
|
||||
|
||||
func (cts *ClientConn) ReportEvent (route_id uint32, pts_id uint32, event_type PACKET_KIND, event_data []byte) error {
|
||||
var r *ServerRoute
|
||||
var ok bool
|
||||
|
||||
cts.route_mtx.Lock()
|
||||
r, ok = cts.routes[route_id]
|
||||
if (!ok) {
|
||||
cts.route_mtx.Unlock()
|
||||
return fmt.Errorf ("non-existent route id - %d", route_id)
|
||||
}
|
||||
cts.route_mtx.Unlock()
|
||||
|
||||
return r.ReportEvent(pts_id, event_type, event_data)
|
||||
}
|
||||
|
||||
func (cts *ClientConn) ReqStop() {
|
||||
if cts.stop_req.CompareAndSwap(false, true) {
|
||||
var r *ServerRoute
|
||||
|
||||
for _, r = range cts.routes {
|
||||
r.StopTask()
|
||||
}
|
||||
|
||||
//cts.c.Close() // close the accepted connection from the client
|
||||
}
|
||||
}
|
||||
|
||||
// ------------------------------------
|
||||
|
||||
func handle_os_signals(s *Server, exit_chan chan<- bool) {
|
||||
var (
|
||||
sighup_chan chan os.Signal
|
||||
sigterm_chan chan os.Signal
|
||||
sig os.Signal
|
||||
)
|
||||
|
||||
sighup_chan = make(chan os.Signal, 1)
|
||||
sigterm_chan = make(chan os.Signal, 1)
|
||||
|
||||
signal.Notify(sighup_chan, syscall.SIGHUP)
|
||||
signal.Notify(sigterm_chan, syscall.SIGTERM, os.Interrupt)
|
||||
|
||||
chan_loop:
|
||||
for {
|
||||
select {
|
||||
case <-sighup_chan:
|
||||
// TODO:
|
||||
//s.RefreshConfig()
|
||||
case sig = <-sigterm_chan:
|
||||
// TODO: get timeout value from config
|
||||
//s.Shutdown(fmt.Sprintf("termination by signal %s", sig), 3*time.Second)
|
||||
s.ReqStop()
|
||||
//log.Debugf("termination by signal %s", sig)
|
||||
fmt.Printf("termination by signal %s\n", sig)
|
||||
exit_chan <- true
|
||||
break chan_loop
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------
|
||||
|
||||
func (s *Server) PacketStream(strm Hodu_PacketStreamServer) error {
|
||||
var ctx context.Context
|
||||
var p *peer.Peer
|
||||
var ok bool
|
||||
var pkt *Packet
|
||||
var err error
|
||||
var cts *ClientConn
|
||||
|
||||
ctx = strm.Context()
|
||||
p, ok = peer.FromContext(ctx)
|
||||
if (!ok) {
|
||||
return fmt.Errorf("failed to get peer from packet stream context")
|
||||
}
|
||||
|
||||
cts, err = s.AddNewClientConn(p.Addr, strm)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to add client %s - %s", p.Addr.String(), err.Error())
|
||||
}
|
||||
|
||||
|
||||
for {
|
||||
// exit if context is done
|
||||
// or continue
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
// no other case is ready.
|
||||
// without the default case, the select construct would block
|
||||
}
|
||||
|
||||
pkt, err = strm.Recv()
|
||||
if err == io.EOF {
|
||||
// return will close stream from server side
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
//log.Printf("receive error %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
switch pkt.Kind {
|
||||
case PACKET_KIND_ROUTE_START:
|
||||
var x *Packet_Route
|
||||
//var t *ServerRoute
|
||||
var ok bool
|
||||
x, ok = pkt.U.(*Packet_Route)
|
||||
if ok {
|
||||
var r* ServerRoute
|
||||
fmt.Printf ("ADDED SERVER ROUTE FOR CLEINT PEER %s\n", x.Route.AddrStr)
|
||||
r, err = cts.AddNewServerRoute(x.Route.RouteId, x.Route.Proto)
|
||||
if err != nil {
|
||||
// TODO: Send Error Response...
|
||||
} else {
|
||||
err = strm.Send(MakeRouteStartedPacket(r.id, x.Route.Proto, r.laddr.String()))
|
||||
if err != nil {
|
||||
// TODO:
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// TODO: send invalid request... or simply keep quiet?
|
||||
}
|
||||
|
||||
case PACKET_KIND_ROUTE_STOP:
|
||||
var x *Packet_Route
|
||||
var ok bool
|
||||
x, ok = pkt.U.(*Packet_Route)
|
||||
if ok {
|
||||
err = cts.RemoveServerRoute(x.Route.RouteId); // TODO: this must be unblocking. otherwide, other routes will get blocked...
|
||||
if err != nil {
|
||||
// TODO: Send Error Response...
|
||||
} else {
|
||||
err = strm.Send(MakeRouteStoppedPacket(x.Route.RouteId, x.Route.Proto))
|
||||
if err != nil {
|
||||
// TODO:
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// TODO: send invalid request... or simply keep quiet?
|
||||
}
|
||||
|
||||
case PACKET_KIND_PEER_STARTED:
|
||||
// the connection from the client to a peer has been established
|
||||
var x *Packet_Peer
|
||||
var ok bool
|
||||
x, ok = pkt.U.(*Packet_Peer)
|
||||
if ok {
|
||||
err = cts.ReportEvent(x.Peer.RouteId, x.Peer.PeerId, PACKET_KIND_PEER_STARTED, nil)
|
||||
if err != nil {
|
||||
// TODO:
|
||||
} else {
|
||||
// TODO:
|
||||
}
|
||||
} else {
|
||||
// TODO
|
||||
}
|
||||
|
||||
case PACKET_KIND_PEER_STOPPED:
|
||||
// the connection from the client to a peer has been established
|
||||
var x *Packet_Peer
|
||||
var ok bool
|
||||
x, ok = pkt.U.(*Packet_Peer)
|
||||
if ok {
|
||||
err = cts.ReportEvent(x.Peer.RouteId, x.Peer.PeerId, PACKET_KIND_PEER_STOPPED, nil)
|
||||
if err != nil {
|
||||
// TODO:
|
||||
} else {
|
||||
// TODO:
|
||||
}
|
||||
} else {
|
||||
// TODO
|
||||
}
|
||||
|
||||
case PACKET_KIND_PEER_DATA:
|
||||
// the connection from the client to a peer has been established
|
||||
var x *Packet_Data
|
||||
var ok bool
|
||||
x, ok = pkt.U.(*Packet_Data)
|
||||
if ok {
|
||||
err = cts.ReportEvent(x.Data.RouteId, x.Data.PeerId, PACKET_KIND_PEER_DATA, x.Data.Data)
|
||||
if err != nil {
|
||||
// TODO:
|
||||
} else {
|
||||
// TODO:
|
||||
}
|
||||
} else {
|
||||
// TODO
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ------------------------------------
|
||||
|
||||
type ConnCatcher struct {
|
||||
server *Server
|
||||
}
|
||||
|
||||
func (cc *ConnCatcher) TagRPC(ctx context.Context, info *stats.RPCTagInfo) context.Context {
|
||||
return ctx
|
||||
}
|
||||
|
||||
func (cc *ConnCatcher) HandleRPC(ctx context.Context, s stats.RPCStats) {
|
||||
}
|
||||
|
||||
func (cc *ConnCatcher) TagConn(ctx context.Context, info *stats.ConnTagInfo) context.Context {
|
||||
return ctx;
|
||||
//return context.TODO()
|
||||
}
|
||||
|
||||
func (cc *ConnCatcher) HandleConn(ctx context.Context, cs stats.ConnStats) {
|
||||
// fmt.Println(ctx.Value("user_id")) // Returns nil, can't access the value
|
||||
var p *peer.Peer
|
||||
var ok bool
|
||||
var addr string
|
||||
|
||||
p, ok = peer.FromContext(ctx)
|
||||
if (!ok) {
|
||||
addr = ""
|
||||
} else {
|
||||
addr = p.Addr.String()
|
||||
}
|
||||
|
||||
md,ok:=metadata.FromIncomingContext(ctx)
|
||||
fmt.Printf("%+v%+v\n",md,ok)
|
||||
if ok {
|
||||
}
|
||||
switch cs.(type) {
|
||||
case *stats.ConnBegin:
|
||||
fmt.Printf("**** client connected - [%s]\n", addr)
|
||||
case *stats.ConnEnd:
|
||||
fmt.Printf("**** client disconnected - [%s]\n", addr)
|
||||
cc.server.RemoveClientConnByAddr(p.Addr);
|
||||
}
|
||||
}
|
||||
|
||||
// wrappedStream wraps around the embedded grpc.ServerStream, and intercepts the RecvMsg and
|
||||
// SendMsg method call.
|
||||
type wrappedStream struct {
|
||||
grpc.ServerStream
|
||||
}
|
||||
|
||||
func (w *wrappedStream) RecvMsg(m any) error {
|
||||
fmt.Printf("Receive a message (Type: %T) at %s\n", m, time.Now().Format(time.RFC3339))
|
||||
return w.ServerStream.RecvMsg(m)
|
||||
}
|
||||
|
||||
func (w *wrappedStream) SendMsg(m any) error {
|
||||
fmt.Printf("Send a message (Type: %T) at %v\n", m, time.Now().Format(time.RFC3339))
|
||||
return w.ServerStream.SendMsg(m)
|
||||
}
|
||||
|
||||
func newWrappedStream(s grpc.ServerStream) grpc.ServerStream {
|
||||
return &wrappedStream{s}
|
||||
}
|
||||
|
||||
func streamInterceptor(srv any, ss grpc.ServerStream, _ *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
|
||||
// authentication (token verification)
|
||||
/*
|
||||
md, ok := metadata.FromIncomingContext(ss.Context())
|
||||
if !ok {
|
||||
return errMissingMetadata
|
||||
}
|
||||
if !valid(md["authorization"]) {
|
||||
return errInvalidToken
|
||||
}
|
||||
*/
|
||||
|
||||
err := handler(srv, newWrappedStream(ss))
|
||||
if err != nil {
|
||||
fmt.Printf("RPC failed with error: %v\n", err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func unaryInterceptor(ctx context.Context, req any, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
|
||||
// authentication (token verification)
|
||||
/*
|
||||
md, ok := metadata.FromIncomingContext(ctx)
|
||||
if !ok {
|
||||
return nil, errMissingMetadata
|
||||
}
|
||||
if !valid(md["authorization"]) {
|
||||
// return nil, errInvalidToken
|
||||
}
|
||||
*/
|
||||
m, err := handler(ctx, req)
|
||||
if err != nil {
|
||||
fmt.Printf("RPC failed with error: %v\n", err)
|
||||
}
|
||||
fmt.Printf ("RPC OK\n");
|
||||
return m, err
|
||||
}
|
||||
|
||||
func NewServer(laddrs []string, tlscfg *tls.Config) (*Server, error) {
|
||||
var s Server
|
||||
var l *net.TCPListener
|
||||
var laddr *net.TCPAddr
|
||||
var err error
|
||||
var addr string
|
||||
var gl *net.TCPListener
|
||||
|
||||
if len(laddrs) <= 0 {
|
||||
return nil, fmt.Errorf("no or too many addresses provided")
|
||||
}
|
||||
|
||||
/* create the specified number of listeners */
|
||||
s.l = make([]*net.TCPListener, 0)
|
||||
for _, addr = range laddrs {
|
||||
laddr, err = net.ResolveTCPAddr(NET_TYPE_TCP, addr)
|
||||
if err != nil {
|
||||
goto oops
|
||||
}
|
||||
|
||||
l, err = net.ListenTCP(NET_TYPE_TCP, laddr)
|
||||
if err != nil {
|
||||
goto oops
|
||||
}
|
||||
|
||||
s.l = append(s.l, l)
|
||||
}
|
||||
|
||||
s.tlscfg = tlscfg
|
||||
s.cts_map = make(ClientConnMap) // TODO: make it configurable...
|
||||
s.stop_req.Store(false)
|
||||
/*
|
||||
creds, err := credentials.NewServerTLSFromFile(data.Path("x509/server_cert.pem"), data.Path("x509/server_key.pem"))
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create credentials: %v", err)
|
||||
}
|
||||
gs = grpc.NewServer(grpc.Creds(creds))
|
||||
*/
|
||||
s.gs = grpc.NewServer(
|
||||
grpc.UnaryInterceptor(unaryInterceptor),
|
||||
grpc.StreamInterceptor(streamInterceptor),
|
||||
grpc.StatsHandler(&ConnCatcher{ server: &s }),
|
||||
) // TODO: have this outside the server struct?
|
||||
RegisterHoduServer (s.gs, &s)
|
||||
|
||||
return &s, nil
|
||||
|
||||
oops:
|
||||
/* TODO: check if gs needs to be closed... */
|
||||
if gl != nil {
|
||||
gl.Close()
|
||||
}
|
||||
|
||||
for _, l = range s.l {
|
||||
l.Close()
|
||||
}
|
||||
s.l = make([]*net.TCPListener, 0)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
func (s *Server) run_grpc_server(idx int) error {
|
||||
var l *net.TCPListener
|
||||
var err error
|
||||
|
||||
l = s.l[idx]
|
||||
fmt.Printf ("serving grpc on %d listener\n", idx)
|
||||
// it seems to be safe to call a single grpc server on differnt listening sockets multiple times
|
||||
// TODO: check if this assumption is ok
|
||||
err = s.gs.Serve(l);
|
||||
if err != nil {
|
||||
fmt.Printf ("XXXXXXXXXXXXXXXXXXXXXXXXXxx %s\n", err.Error());
|
||||
}
|
||||
|
||||
s.l_wg.Done();
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) MainLoop() error {
|
||||
var idx int
|
||||
|
||||
for idx, _ = range s.l {
|
||||
s.l_wg.Add(1)
|
||||
go s.run_grpc_server(idx)
|
||||
}
|
||||
|
||||
s.l_wg.Wait();
|
||||
s.ReqStop()
|
||||
s.wg.Wait()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) ReqStop() {
|
||||
if s.stop_req.CompareAndSwap(false, true) {
|
||||
var l *net.TCPListener
|
||||
var cts *ClientConn
|
||||
|
||||
//s.gs.GracefulStop()
|
||||
s.gs.Stop()
|
||||
for _, l = range s.l {
|
||||
l.Close()
|
||||
}
|
||||
|
||||
s.cts_mtx.Lock() // TODO: this mya create dead-lock. check possibility of dead lock???
|
||||
for _, cts = range s.cts_map {
|
||||
cts.ReqStop() // request to stop connections from/to peer held in the cts structure
|
||||
}
|
||||
s.cts_mtx.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) AddNewClientConn(addr net.Addr, pss Hodu_PacketStreamServer) (*ClientConn, error) {
|
||||
var cts ClientConn
|
||||
var ok bool
|
||||
|
||||
cts.svr = s
|
||||
cts.routes = make(ServerRouteMap)
|
||||
cts.caddr = addr
|
||||
cts.pss = pss
|
||||
|
||||
cts.stop_req.Store(false)
|
||||
cts.greeted = false
|
||||
|
||||
s.cts_mtx.Lock()
|
||||
defer s.cts_mtx.Unlock()
|
||||
|
||||
_, ok = s.cts_map[addr]
|
||||
if ok {
|
||||
return nil, fmt.Errorf("existing client - %s", addr.String())
|
||||
}
|
||||
|
||||
s.cts_map[addr] = &cts;
|
||||
fmt.Printf ("ADD total clients %d\n", len(s.cts_map));
|
||||
return &cts, nil
|
||||
}
|
||||
|
||||
func (s *Server) RemoveClientConn(cts *ClientConn) {
|
||||
s.cts_mtx.Lock()
|
||||
delete(s.cts_map, cts.caddr)
|
||||
fmt.Printf ("REMOVE total clients %d\n", len(s.cts_map));
|
||||
s.cts_mtx.Unlock()
|
||||
}
|
||||
|
||||
func (s *Server) RemoveClientConnByAddr(addr net.Addr) {
|
||||
var cts *ClientConn
|
||||
var ok bool
|
||||
|
||||
s.cts_mtx.Lock()
|
||||
defer s.cts_mtx.Unlock()
|
||||
|
||||
cts, ok = s.cts_map[addr]
|
||||
if ok {
|
||||
cts.ReqStop()
|
||||
delete(s.cts_map, cts.caddr)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) FindClientConnByAddr (addr net.Addr) *ClientConn {
|
||||
var cts *ClientConn
|
||||
var ok bool
|
||||
|
||||
s.cts_mtx.Lock()
|
||||
defer s.cts_mtx.Unlock()
|
||||
|
||||
cts, ok = s.cts_map[addr]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
return cts
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------
|
||||
|
||||
const serverKey = `-----BEGIN EC PARAMETERS-----
|
||||
BggqhkjOPQMBBw==
|
||||
-----END EC PARAMETERS-----
|
||||
-----BEGIN EC PRIVATE KEY-----
|
||||
MHcCAQEEIHg+g2unjA5BkDtXSN9ShN7kbPlbCcqcYdDu+QeV8XWuoAoGCCqGSM49
|
||||
AwEHoUQDQgAEcZpodWh3SEs5Hh3rrEiu1LZOYSaNIWO34MgRxvqwz1FMpLxNlx0G
|
||||
cSqrxhPubawptX5MSr02ft32kfOlYbaF5Q==
|
||||
-----END EC PRIVATE KEY-----
|
||||
`
|
||||
|
||||
const serverCert = `-----BEGIN CERTIFICATE-----
|
||||
MIIB+TCCAZ+gAwIBAgIJAL05LKXo6PrrMAoGCCqGSM49BAMCMFkxCzAJBgNVBAYT
|
||||
AkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBXaWRn
|
||||
aXRzIFB0eSBMdGQxEjAQBgNVBAMMCWxvY2FsaG9zdDAeFw0xNTEyMDgxNDAxMTNa
|
||||
Fw0yNTEyMDUxNDAxMTNaMFkxCzAJBgNVBAYTAkFVMRMwEQYDVQQIDApTb21lLVN0
|
||||
YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQxEjAQBgNVBAMM
|
||||
CWxvY2FsaG9zdDBZMBMGByqGSM49AgEGCCqGSM49AwEHA0IABHGaaHVod0hLOR4d
|
||||
66xIrtS2TmEmjSFjt+DIEcb6sM9RTKS8TZcdBnEqq8YT7m2sKbV+TEq9Nn7d9pHz
|
||||
pWG2heWjUDBOMB0GA1UdDgQWBBR0fqrecDJ44D/fiYJiOeBzfoqEijAfBgNVHSME
|
||||
GDAWgBR0fqrecDJ44D/fiYJiOeBzfoqEijAMBgNVHRMEBTADAQH/MAoGCCqGSM49
|
||||
BAMCA0gAMEUCIEKzVMF3JqjQjuM2rX7Rx8hancI5KJhwfeKu1xbyR7XaAiEA2UT7
|
||||
1xOP035EcraRmWPe7tO0LpXgMxlh2VItpc2uc2w=
|
||||
-----END CERTIFICATE-----
|
||||
`
|
||||
|
||||
func server_main(laddrs []string) error {
|
||||
var s *Server
|
||||
var err error
|
||||
var exit_chan chan bool
|
||||
var cert tls.Certificate
|
||||
|
||||
cert, err = tls.X509KeyPair([]byte(serverCert), []byte(serverKey))
|
||||
if err != nil {
|
||||
return fmt.Errorf("ERROR: failed to load key pair - %s\n", err)
|
||||
}
|
||||
|
||||
s, err = NewServer(laddrs, &tls.Config{Certificates: []tls.Certificate{cert}})
|
||||
if err != nil {
|
||||
return fmt.Errorf("ERROR: failed to create new server - %s", err.Error())
|
||||
}
|
||||
|
||||
exit_chan = make(chan bool, 1)
|
||||
go handle_os_signals(s, exit_chan)
|
||||
err = s.MainLoop() // this is blocking. ReqStop() will be called from a signal handler
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
<-exit_chan // wait until the term signal handler almost reaches the end
|
||||
return nil
|
||||
}
|
Reference in New Issue
Block a user