From 09593fd678aca016aa0542c80a9f9e4d784220ef Mon Sep 17 00:00:00 2001 From: hyung-hwan Date: Tue, 11 Mar 2025 21:12:05 +0900 Subject: [PATCH] fixed some potential concurrency issues in the client side. enhanded route_started log messages renamed service to svc for some items added two more fields to denote requested service address/network to ServerRoute --- Makefile | 5 +- README.md | 4 +- bulletin.go | 38 +++++++++- bulletin_test.go | 2 +- client-ctl.go | 108 ++++++++++++++++----------- client.go | 188 ++++++++++++++++++++++++++++++----------------- hodu.proto | 16 ++-- server-ctl.go | 4 +- server.go | 11 ++- 9 files changed, 250 insertions(+), 126 deletions(-) diff --git a/Makefile b/Makefile index c153978..d819e12 100644 --- a/Makefile +++ b/Makefile @@ -50,9 +50,12 @@ $(NAME): $(DATA) $(SRCS) $(CMD_DATA) $(CMD_SRCS) CGO_ENABLED=1 go build -x -ldflags "-X 'main.HODU_NAME=$(NAME)' -X 'main.HODU_VERSION=$(VERSION)'" -o $@ $(CMD_SRCS) ##CGO_ENABLED=1 go build -x -ldflags "-X 'main.HODU_NAME=$(NAME)' -X 'main.HODU_VERSION=$(VERSION)' -linkmode external -extldflags=-static" -o $@ $(CMD_SRCS) +$(NAME).debug: $(DATA) $(SRCS) $(CMD_DATA) $(CMD_SRCS) + CGO_ENABLED=1 go build -race -x -ldflags "-X 'main.HODU_NAME=$(NAME)' -X 'main.HODU_VERSION=$(VERSION)'" -o $@ $(CMD_SRCS) + clean: go clean -x -i - rm -f $(NAME) + rm -f $(NAME) $(NAME).debug test: go test -x diff --git a/README.md b/README.md index 7501c40..6bcdd1b 100644 --- a/README.md +++ b/README.md @@ -29,8 +29,8 @@ "client-peer-addr": "192.168.1.104:22", "client-peer-name": "Star gate", "server-peer-option": "tcp4 ssh", - "server-peer-service-addr": "0.0.0.0:0", - "server-peer-service-net": "", + "server-peer-svc-addr": "0.0.0.0:0", + "server-peer-svc-net": "", "lifetime": "0" } ``` diff --git a/bulletin.go b/bulletin.go index cd48b89..f691967 100644 --- a/bulletin.go +++ b/bulletin.go @@ -1,6 +1,7 @@ package hodu import "container/list" +import "container/ring" import "errors" import "sync" @@ -19,11 +20,19 @@ type Bulletin[T interface{}] struct { sbsc_map BulletinSubscriptionMap sbsc_mtx sync.RWMutex closed bool + + r_mtx sync.RWMutex + r *ring.Ring + r_capa int + r_full bool } -func NewBulletin[T interface{}]() *Bulletin[T] { +func NewBulletin[T interface{}](capa int) *Bulletin[T] { return &Bulletin[T]{ sbsc_map: make(BulletinSubscriptionMap, 0), + r: ring.New(capa), + r_capa: capa, + r_full: false, } } @@ -134,3 +143,30 @@ func (b *Bulletin[T]) Publish(topic string, data T) { b.sbsc_mtx.Unlock() } +func (b *Bulletin[T]) Enqueue(topic string, data T) { + b.r_mtx.Lock() + b.r.Value = data // update the value at the current position + b.r = b.r.Next() // move the current position + b.r_mtx.Unlock() +} + +func (b *Bulletin[T]) Dequeue() { + b.r_mtx.Lock() + b.r_mtx.Unlock() +} + +/* +func (b *Bulletin[T]) RunTask(wg *sync.WaitGroup) { + var done bool + var msg T + var ok bool + + defer wg.Done() + + for !done { + select { + case msg, ok = <- b.C: + if !ok { done = true } + } + } +}*/ diff --git a/bulletin_test.go b/bulletin_test.go index 7c89ed7..d2d316e 100644 --- a/bulletin_test.go +++ b/bulletin_test.go @@ -14,7 +14,7 @@ func TestBulletin(t *testing.T) { var nmsgs1 int var nmsgs2 int - b = hodu.NewBulletin[string]() + b = hodu.NewBulletin[string](100) s1, _ = b.Subscribe("t1") s2, _ = b.Subscribe("t2") diff --git a/client-ctl.go b/client-ctl.go index 08ca96b..5b18991 100644 --- a/client-ctl.go +++ b/client-ctl.go @@ -32,8 +32,12 @@ type json_in_client_route struct { ClientPeerAddr string `json:"client-peer-addr"` ClientPeerName string `json:"client-peer-name"` ServerPeerOption string `json:"server-peer-option"` - ServerPeerServiceAddr string `json:"server-peer-service-addr"` // desired listening address on the server side - ServerPeerServiceNet string `json:"server-peer-service-net"` // permitted network in prefix notation + + // the following two fields in the input structure is the requested values. + // the actual values are returned in json_out_clinet_route and may be different from the requested ones + ServerPeerSvcAddr string `json:"server-peer-svc-addr"` // requested listening address on the server side - not actual + ServerPeerSvcNet string `json:"server-peer-svc-net"` // requested permitted network in prefix notation - not actual + Lifetime string `json:"lifetime"` } @@ -65,8 +69,12 @@ type json_out_client_route struct { ClientPeerAddr string `json:"client-peer-addr"` ClientPeerName string `json:"client-peer-name"` ServerPeerOption string `json:"server-peer-option"` - ServerPeerListenAddr string `json:"server-peer-service-addr"` - ServerPeerNet string `json:"server-peer-service-net"` + + // These two values are actual addresses and networks listening on peer service port + // and may be different from the requested values conveyed in json_in_client_route. + ServerPeerSvcAddr string `json:"server-peer-svc-addr"` + ServerPeerSvcNet string `json:"server-peer-svc-net"` + Lifetime string `json:"lifetime"` LifetimeStart int64 `json:"lifetime-start"` } @@ -231,34 +239,40 @@ func (ctl *client_ctl_client_conns) ServeHTTP(w http.ResponseWriter, req *http.R var cts *ClientConn var jsp []json_out_client_route var ri RouteId + var local_addr string + var remote_addr string cts = c.cts_map[ci] jsp = make([]json_out_client_route, 0) cts.route_mtx.Lock() for _, ri = range cts.route_map.get_sorted_keys() { var r *ClientRoute - var spla string = "" + var lftsta time.Time + var lftdur time.Duration r = cts.route_map[ri] - if r.server_peer_listen_addr != nil { spla = r.server_peer_listen_addr.String() } + + lftsta, lftdur = r.GetLifetimeInfo() jsp = append(jsp, json_out_client_route{ Id: r.Id, ClientPeerAddr: r.PeerAddr, ClientPeerName: r.PeerName, - ServerPeerListenAddr: spla, - ServerPeerNet: r.ServerPeerNet, + ServerPeerSvcAddr: r.ServerPeerSvcAddr, + ServerPeerSvcNet: r.ServerPeerSvcNet, ServerPeerOption: r.ServerPeerOption.String(), - Lifetime: DurationToSecString(r.Lifetime), - LifetimeStart: r.LifetimeStart.Unix(), + Lifetime: DurationToSecString(lftdur), + LifetimeStart: lftsta.Unix(), }) } cts.route_mtx.Unlock() + + local_addr, remote_addr = cts.GetAddrInfo() js = append(js, json_out_client_conn{ Id: cts.Id, ReqServerAddrs: cts.cfg.ServerAddrs, CurrentServerIndex: cts.cfg.Index, - ServerAddr: cts.remote_addr, - ClientAddr: cts.local_addr, + ServerAddr: remote_addr, + ClientAddr: local_addr, ClientToken: cts.Token, Routes: jsp, }) @@ -344,33 +358,38 @@ func (ctl *client_ctl_client_conns_id) ServeHTTP(w http.ResponseWriter, req *htt var jsp []json_out_client_route var js *json_out_client_conn var ri RouteId + var local_addr string + var remote_addr string jsp = make([]json_out_client_route, 0) cts.route_mtx.Lock() for _, ri = range cts.route_map.get_sorted_keys() { var r *ClientRoute - var spla string = "" + var lftsta time.Time + var lftdur time.Duration r = cts.route_map[ri] - if r.server_peer_listen_addr != nil { spla = r.server_peer_listen_addr.String() } + + lftsta, lftdur = r.GetLifetimeInfo() jsp = append(jsp, json_out_client_route{ Id: r.Id, ClientPeerAddr: r.PeerAddr, ClientPeerName: r.PeerName, - ServerPeerListenAddr: spla, - ServerPeerNet: r.ServerPeerNet, + ServerPeerSvcAddr: r.ServerPeerSvcAddr, + ServerPeerSvcNet: r.ServerPeerSvcNet, ServerPeerOption: r.ServerPeerOption.String(), - Lifetime: DurationToSecString(r.Lifetime), - LifetimeStart: r.LifetimeStart.Unix(), + Lifetime: DurationToSecString(lftdur), + LifetimeStart: lftsta.Unix(), }) } cts.route_mtx.Unlock() + local_addr, remote_addr = cts.GetAddrInfo() js = &json_out_client_conn{ Id: cts.Id, ReqServerAddrs: cts.cfg.ServerAddrs, CurrentServerIndex: cts.cfg.Index, - ServerAddr: cts.local_addr, - ClientAddr: cts.remote_addr, + ServerAddr: local_addr, + ClientAddr: remote_addr, ClientToken: cts.Token, Routes: jsp, } @@ -424,19 +443,21 @@ func (ctl *client_ctl_client_conns_id_routes) ServeHTTP(w http.ResponseWriter, r cts.route_mtx.Lock() for _, ri = range cts.route_map.get_sorted_keys() { var r *ClientRoute - var spla string = "" + var lftsta time.Time + var lftdur time.Duration r = cts.route_map[ri] - if r.server_peer_listen_addr != nil { spla = r.server_peer_listen_addr.String() } + + lftsta, lftdur = r.GetLifetimeInfo() jsp = append(jsp, json_out_client_route{ Id: r.Id, ClientPeerAddr: r.PeerAddr, ClientPeerName: r.PeerName, - ServerPeerListenAddr: spla, - ServerPeerNet: r.ServerPeerNet, + ServerPeerSvcAddr: r.ServerPeerSvcAddr, + ServerPeerSvcNet: r.ServerPeerSvcNet, ServerPeerOption: r.ServerPeerOption.String(), - Lifetime: DurationToSecString(r.Lifetime), - LifetimeStart: r.LifetimeStart.Unix(), + Lifetime: DurationToSecString(lftdur), + LifetimeStart: lftsta.Unix(), }) } cts.route_mtx.Unlock() @@ -482,8 +503,8 @@ func (ctl *client_ctl_client_conns_id_routes) ServeHTTP(w http.ResponseWriter, r PeerAddr: jcr.ClientPeerAddr, PeerName: jcr.ClientPeerName, Option: server_peer_option, - ServiceAddr: jcr.ServerPeerServiceAddr, - ServiceNet: jcr.ServerPeerServiceNet, + ServiceAddr: jcr.ServerPeerSvcAddr, + ServiceNet: jcr.ServerPeerSvcNet, Lifetime: lifetime, Static: false, } @@ -539,19 +560,21 @@ func (ctl *client_ctl_client_conns_id_routes_id) ServeHTTP(w http.ResponseWriter switch req.Method { case http.MethodGet: - var spla string = "" - if r.server_peer_listen_addr != nil { spla = r.server_peer_listen_addr.String() } + var lftsta time.Time + var lftdur time.Duration + status_code = WriteJsonRespHeader(w, http.StatusOK) + + lftsta, lftdur = r.GetLifetimeInfo() err = je.Encode(json_out_client_route{ Id: r.Id, ClientPeerAddr: r.PeerAddr, ClientPeerName: r.PeerName, - ServerPeerListenAddr: spla, - ServerPeerNet: r.ServerPeerNet, + ServerPeerSvcAddr: r.ServerPeerSvcAddr, + ServerPeerSvcNet: r.ServerPeerSvcNet, ServerPeerOption: r.ServerPeerOption.String(), - Lifetime: DurationToSecString(r.Lifetime), - LifetimeStart: r.LifetimeStart.Unix(), - + Lifetime: DurationToSecString(lftdur), + LifetimeStart: lftsta.Unix(), }) if err != nil { goto oops } @@ -624,18 +647,21 @@ func (ctl *client_ctl_client_conns_id_routes_spsp) ServeHTTP(w http.ResponseWrit switch req.Method { case http.MethodGet: - var spla string = "" - if r.server_peer_listen_addr != nil { spla = r.server_peer_listen_addr.String() } + var lftsta time.Time + var lftdur time.Duration + status_code = WriteJsonRespHeader(w, http.StatusOK) + + lftsta, lftdur = r.GetLifetimeInfo() err = je.Encode(json_out_client_route{ Id: r.Id, ClientPeerAddr: r.PeerAddr, ClientPeerName: r.PeerName, - ServerPeerListenAddr: spla, - ServerPeerNet: r.ServerPeerNet, + ServerPeerSvcAddr: r.ServerPeerSvcAddr, + ServerPeerSvcNet: r.ServerPeerSvcNet, ServerPeerOption: r.ServerPeerOption.String(), - Lifetime: DurationToSecString(r.Lifetime), - LifetimeStart: r.LifetimeStart.Unix(), + Lifetime: DurationToSecString(lftdur), + LifetimeStart: lftsta.Unix(), }) if err != nil { goto oops } diff --git a/client.go b/client.go index 1f76edd..89094bf 100644 --- a/client.go +++ b/client.go @@ -36,8 +36,8 @@ type ClientRouteConfig struct { PeerAddr string PeerName string Option RouteOption - ServiceAddr string // server-peer-service-addr - ServiceNet string // server-peer-service-net + ServiceAddr string // server-peer-svc-addr + ServiceNet string // server-peer-svc-net Lifetime time.Duration Static bool } @@ -117,8 +117,7 @@ type Client struct { } } -type ClientConnState int - +type ClientConnState = int32 const ( CLIENT_CONN_CONNECTING ClientConnState = iota CLIENT_CONN_CONNECTED @@ -132,11 +131,15 @@ type ClientConn struct { cfg ClientConnConfigActive Id ConnId Sid string // id rendered in string - State ClientConnState + State atomic.Int32 // ClientConnState Token string + addr_mtx sync.Mutex // because the following fields are updated concurrently local_addr string remote_addr string + local_addr_p string // not reset when disconnected. used mostly in writing log messages without addr_mtx + remote_addr_p string // not reset when disconnected. used mostly in writing log messages without addr_mtx + conn *grpc.ClientConn // grpc connection to the server hdc HoduClient psc *GuardedPacketStreamClient // guarded grpc stream @@ -151,6 +154,8 @@ type ClientConn struct { stop_req atomic.Bool stop_chan chan bool + + discon_mtx sync.Mutex } type ClientRoute struct { @@ -162,9 +167,11 @@ type ClientRoute struct { PeerName string PeerOption RouteOption - server_peer_listen_addr *net.TCPAddr // actual service-side service address - ServerPeerAddr string // desired server-side service address - ServerPeerNet string + ReqServerPeerSvcAddr string // requested server-side service address + ReqServerPeerSvcNet string // requested server-side service address + ServerPeerListenAddr *net.TCPAddr // actual service-side service address + ServerPeerSvcAddr string // actual server-side service address + ServerPeerSvcNet string // actual server-side network ServerPeerOption RouteOption ptc_mtx sync.Mutex @@ -237,8 +244,8 @@ func NewClientRoute(cts *ClientConn, id RouteId, static bool, client_peer_addr s // if the client_peer_addr is a domain name, it can't tell between tcp4 and tcp6 r.PeerOption = StringToRouteOption(TcpAddrStrClass(client_peer_addr)) - r.ServerPeerAddr = server_peer_svc_addr - r.ServerPeerNet = server_peer_svc_net // permitted network for server-side peer + r.ReqServerPeerSvcAddr = server_peer_svc_addr + r.ReqServerPeerSvcNet = server_peer_svc_net // permitted network for server-side peer r.ServerPeerOption = server_peer_option r.LifetimeStart = time.Now() r.Lifetime = lifetime @@ -340,6 +347,12 @@ func (r *ClientRoute) ResetLifetime(lifetime time.Duration) error { } } +func (r *ClientRoute) GetLifetimeInfo() (time.Time, time.Duration) { + r.lifetime_mtx.Lock() + defer r.lifetime_mtx.Unlock() + return r.LifetimeStart, r.Lifetime +} + func (r *ClientRoute) RunTask(wg *sync.WaitGroup) { var err error @@ -348,16 +361,16 @@ func (r *ClientRoute) RunTask(wg *sync.WaitGroup) { // it merely implements some timeout if set. defer wg.Done() - err = r.cts.psc.Send(MakeRouteStartPacket(r.Id, r.ServerPeerOption, r.PeerAddr, r.PeerName, r.ServerPeerAddr, r.ServerPeerNet)) + err = r.cts.psc.Send(MakeRouteStartPacket(r.Id, r.ServerPeerOption, r.PeerAddr, r.PeerName, r.ReqServerPeerSvcAddr, r.ReqServerPeerSvcNet)) if err != nil { r.cts.C.log.Write(r.cts.Sid, LOG_DEBUG, - "Failed to send route_start for route(%d,%s,%v,%v) to %s", - r.Id, r.PeerAddr, r.ServerPeerOption, r.ServerPeerNet, r.cts.remote_addr) + "Failed to send route_start for route(%d,%s,%v,%s,%s) to %s - %s", + r.Id, r.PeerAddr, r.ServerPeerOption, r.ReqServerPeerSvcAddr, r.ReqServerPeerSvcNet, r.cts.remote_addr_p, err.Error()) goto done } else { r.cts.C.log.Write(r.cts.Sid, LOG_DEBUG, - "Sent route_start for route(%d,%s,%v,%v) to %s", - r.Id, r.PeerAddr, r.ServerPeerOption, r.ServerPeerNet, r.cts.remote_addr) + "Sent route_start for route(%d,%s,%v,%s,%s) to %s", + r.Id, r.PeerAddr, r.ServerPeerOption, r.ReqServerPeerSvcAddr, r.ReqServerPeerSvcNet, r.cts.remote_addr_p) } r.lifetime_mtx.Lock() @@ -375,8 +388,8 @@ main_loop: break main_loop case <-r.lifetime_timer.C: - r.cts.C.log.Write(r.cts.Sid, LOG_INFO, "route(%d,%s,%v,%v) reached end of lifetime(%v)", - r.Id, r.PeerAddr, r.ServerPeerOption, r.ServerPeerNet, r.Lifetime) + r.cts.C.log.Write(r.cts.Sid, LOG_INFO, "route(%d,%s) reached end of lifetime(%v)", + r.Id, r.PeerAddr, r.Lifetime) break main_loop } } else { @@ -398,15 +411,15 @@ done: r.ReqStop() r.ptc_wg.Wait() // wait for all peer tasks are finished - err = r.cts.psc.Send(MakeRouteStopPacket(r.Id, r.ServerPeerOption, r.PeerAddr, r.PeerName, r.ServerPeerAddr, r.ServerPeerNet)) + err = r.cts.psc.Send(MakeRouteStopPacket(r.Id, r.ServerPeerOption, r.PeerAddr, r.PeerName, r.ReqServerPeerSvcAddr, r.ReqServerPeerSvcNet)) if err != nil { r.cts.C.log.Write(r.cts.Sid, LOG_DEBUG, - "Failed to route_stop for route(%d,%s,%v,%v) to %s - %s", - r.Id, r.PeerAddr, r.ServerPeerOption, r.ServerPeerNet, r.cts.remote_addr, err.Error()) + "Failed to route_stop for route(%d,%s) to %s - %s", + r.Id, r.PeerAddr, r.cts.remote_addr_p, err.Error()) } else { r.cts.C.log.Write(r.cts.Sid, LOG_DEBUG, - "Sent route_stop for route(%d,%s,%v,%v) to %s", - r.Id, r.PeerAddr, r.ServerPeerOption, r.ServerPeerNet, r.cts.remote_addr) + "Sent route_stop for route(%d,%s) to %s", + r.Id, r.PeerAddr, r.cts.remote_addr_p) } r.cts.RemoveClientRoute(r) @@ -537,17 +550,26 @@ func (r *ClientRoute) ReportPacket(pts_id PeerId, packet_type PACKET_KIND, event var rd *RouteDesc rd, ok = event_data.(*RouteDesc) if !ok { - r.cts.C.log.Write(r.cts.Sid, LOG_ERROR, "Protocol error - invalid data in route_started event(%d)", r.Id) + r.cts.C.log.Write(r.cts.Sid, LOG_ERROR, + "Protocol error - invalid data in route_started event(%d)", r.Id) r.ReqStop() } else { var addr *net.TCPAddr addr, err = net.ResolveTCPAddr(TcpAddrStrClass(rd.TargetAddrStr), rd.TargetAddrStr) if err != nil { - r.cts.C.log.Write(r.cts.Sid, LOG_ERROR, "Protocol error - invalid service address(%s) for server peer in route_started event(%d)", rd.TargetAddrStr, r.Id) + r.cts.C.log.Write(r.cts.Sid,LOG_ERROR, + "Protocol error - invalid service address(%s) for server peer in route_started event(%d)", rd.TargetAddrStr, r.Id) r.ReqStop() } else { - r.server_peer_listen_addr = addr - r.ServerPeerNet = rd.ServiceNetStr + // received the server-side addresses + r.ServerPeerListenAddr = addr + r.ServerPeerSvcAddr = rd.TargetAddrStr + r.ServerPeerSvcNet = rd.ServiceNetStr + + r.cts.C.log.Write(r.cts.Sid, LOG_INFO, + "Ingested route_started(%d,%s,%s) for route(%d,%s,%v,%s,%s)", + rd.RouteId, rd.TargetAddrStr, rd.ServiceNetStr, + r.Id, r.PeerAddr, r.ServerPeerOption, r.ReqServerPeerSvcAddr, r.ReqServerPeerSvcNet); } } @@ -557,8 +579,17 @@ func (r *ClientRoute) ReportPacket(pts_id PeerId, packet_type PACKET_KIND, event // in case of the failed ROUTE_START, r.ReqStop() may trigger another ROUTE_STOP sent to the server. // but the server must be able to handle this case as invalid route. var ok bool - _, ok = event_data.(*RouteDesc) - if !ok { r.cts.C.log.Write(r.cts.Sid, LOG_ERROR, "Protocol error - invalid data in route_started event(%d)", r.Id) } + var rd *RouteDesc + + rd, ok = event_data.(*RouteDesc) + if !ok { + r.cts.C.log.Write(r.cts.Sid, LOG_ERROR, "Protocol error - invalid data in route_started event(%d)", r.Id) + } else { + r.cts.C.log.Write(r.cts.Sid, LOG_INFO, + "Ingested route_stopped(%d,%s,%s) for route(%d,%s,%v,%s,%s)", + rd.RouteId, rd.TargetAddrStr, rd.ServiceNetStr, + r.Id, r.PeerAddr, r.ServerPeerOption, r.ReqServerPeerSvcAddr, r.ReqServerPeerSvcNet); + } r.ReqStop() case PACKET_KIND_PEER_STARTED: @@ -753,7 +784,7 @@ func (cts *ClientConn) AddNewClientRoute(rc *ClientRouteConfig) (*ClientRoute, e if cts.C.route_persister != nil { cts.C.route_persister.Save(cts, r) } cts.route_mtx.Unlock() - cts.C.log.Write(cts.Sid, LOG_INFO, "Added route(%d,%d) %s", cts.Id, r.Id, r.PeerAddr) + cts.C.log.Write(cts.Sid, LOG_INFO, "Added route(%d,%s)", r.Id, r.PeerAddr) cts.route_wg.Add(1) go r.RunTask(&cts.route_wg) @@ -797,9 +828,7 @@ func (cts *ClientConn) RemoveClientRoute(route *ClientRoute) error { } delete(cts.route_map, route.Id) cts.C.stats.routes.Add(-1) - if cts.C.route_persister != nil { - cts.C.route_persister.Delete(cts, r) - } + if cts.C.route_persister != nil { cts.C.route_persister.Delete(cts, r) } cts.route_mtx.Unlock() cts.C.log.Write(cts.Sid, LOG_INFO, "Removed route(%d,%s)", route.Id, route.PeerAddr) @@ -837,7 +866,7 @@ func (cts *ClientConn) RemoveClientRouteByServerPeerSvcPortId(port_id PortId) er // use the actual route id for finding. cts.route_mtx.Lock() for _, r = range cts.route_map { - if r.server_peer_listen_addr.Port == int(port_id) { + if r.ServerPeerListenAddr.Port == int(port_id) { delete(cts.route_map, r.Id) cts.C.stats.routes.Add(-1) if cts.C.route_persister != nil { cts.C.route_persister.Delete(cts, r) } @@ -875,7 +904,7 @@ func (cts *ClientConn) FindClientRouteByServerPeerSvcPortId(port_id PortId) *Cli // use the actual route id for finding. cts.route_mtx.Lock() for _, r = range cts.route_map { - if r.server_peer_listen_addr.Port == int(port_id) { + if r.ServerPeerListenAddr.Port == int(port_id) { cts.route_mtx.Unlock() return r; // return the first match } @@ -899,15 +928,23 @@ func (cts *ClientConn) add_client_routes(routes []ClientRouteConfig) error { return nil } -func (cts *ClientConn) disconnect_from_server() { +func (cts *ClientConn) disconnect_from_server(logmsg bool) { if cts.conn != nil { var r *ClientRoute + cts.discon_mtx.Lock() + + if (logmsg) { + cts.C.log.Write(cts.Sid, LOG_INFO, "Disconnecting from server[%d] %s", cts.cfg.Index, cts.cfg.ServerAddrs[cts.cfg.Index]) + } + cts.route_mtx.Lock() for _, r = range cts.route_map { r.ReqStop() } cts.route_mtx.Unlock() + // don't care about double closes when this function is called from both RunTask() and ReqStop() cts.conn.Close() + // don't reset cts.conn to nil here // if this function is called from RunTask() // for reconnection, it will be set to a new value @@ -915,18 +952,29 @@ func (cts *ClientConn) disconnect_from_server() { // if it's called from ReqStop(), we don't really // need to care about it. + cts.addr_mtx.Lock() cts.local_addr = "" cts.remote_addr = "" + // don't reset cts.local_addr_p and cts.remote_addr_p + cts.addr_mtx.Unlock() + + cts.discon_mtx.Unlock() } } func (cts *ClientConn) ReqStop() { if cts.stop_req.CompareAndSwap(false, true) { - cts.disconnect_from_server() + cts.disconnect_from_server(false) cts.stop_chan <- true } } +func (cts *ClientConn) GetAddrInfo() (string, string) { + cts.addr_mtx.Lock() + defer cts.addr_mtx.Unlock() + return cts.local_addr, cts.remote_addr +} + func timed_interceptor(tmout time.Duration) grpc.UnaryClientInterceptor { // The client calls GetSeed() as the first call to the server. // To simulate a kind of connect timeout to the server and find out an unresponsive server, @@ -955,7 +1003,7 @@ func (cts *ClientConn) RunTask(wg *sync.WaitGroup) { defer wg.Done() // arrange to call at the end of this function start_over: - cts.State = CLIENT_CONN_CONNECTING + cts.State.Store(CLIENT_CONN_CONNECTING) cts.cfg.Index = (cts.cfg.Index + 1) % len(cts.cfg.ServerAddrs) cts.C.log.Write(cts.Sid, LOG_INFO, "Connecting to server[%d] %s", cts.cfg.Index, cts.cfg.ServerAddrs[cts.cfg.Index]) if cts.C.rpc_tls == nil { @@ -1004,13 +1052,17 @@ start_over: p, ok = peer.FromContext(psc.Context()) if ok { + cts.addr_mtx.Lock() cts.remote_addr = p.Addr.String() cts.local_addr = p.LocalAddr.String() + cts.remote_addr_p = cts.remote_addr + cts.local_addr_p = cts.local_addr + cts.addr_mtx.Unlock() } cts.C.log.Write(cts.Sid, LOG_INFO, "Got packet stream from server[%d] %s", cts.cfg.Index, cts.cfg.ServerAddrs[cts.cfg.Index]) - cts.State = CLIENT_CONN_CONNECTED + cts.State.Store(CLIENT_CONN_CONNECTED) cts.Token = cts.cfg.ClientToken if cts.Token == "" { cts.Token = cts.C.token } @@ -1062,7 +1114,7 @@ start_over: if status.Code(err) == codes.Canceled || errors.Is(err, net.ErrClosed) { goto reconnect_to_server } else { - cts.C.log.Write(cts.Sid, LOG_INFO, "Failed to receive packet from %s - %s", cts.remote_addr, err.Error()) + cts.C.log.Write(cts.Sid, LOG_INFO, "Failed to receive packet from %s - %s", cts.remote_addr_p, err.Error()) goto reconnect_to_server } } @@ -1078,14 +1130,14 @@ start_over: if err != nil { cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to handle route_started event(%d,%s) from %s - %s", - x.Route.RouteId, x.Route.TargetAddrStr, cts.remote_addr, err.Error()) + x.Route.RouteId, x.Route.TargetAddrStr, cts.remote_addr_p, err.Error()) } else { cts.C.log.Write(cts.Sid, LOG_DEBUG, "Handled route_started event(%d,%s) from %s", - x.Route.RouteId, x.Route.TargetAddrStr, cts.remote_addr) + x.Route.RouteId, x.Route.TargetAddrStr, cts.remote_addr_p) } } else { - cts.C.log.Write(cts.Sid, LOG_ERROR, "Invalid route_started event from %s", cts.remote_addr) + cts.C.log.Write(cts.Sid, LOG_ERROR, "Invalid route_started event from %s", cts.remote_addr_p) } case PACKET_KIND_ROUTE_STOPPED: @@ -1097,14 +1149,14 @@ start_over: if err != nil { cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to handle route_stopped event(%d,%s) from %s - %s", - x.Route.RouteId, x.Route.TargetAddrStr, cts.remote_addr, err.Error()) + x.Route.RouteId, x.Route.TargetAddrStr, cts.remote_addr_p, err.Error()) } else { cts.C.log.Write(cts.Sid, LOG_DEBUG, "Handled route_stopped event(%d,%s) from %s", - x.Route.RouteId, x.Route.TargetAddrStr, cts.remote_addr) + x.Route.RouteId, x.Route.TargetAddrStr, cts.remote_addr_p) } } else { - cts.C.log.Write(cts.Sid, LOG_ERROR, "Invalid route_stopped event from %s", cts.remote_addr) + cts.C.log.Write(cts.Sid, LOG_ERROR, "Invalid route_stopped event from %s", cts.remote_addr_p) } case PACKET_KIND_PEER_STARTED: @@ -1117,14 +1169,14 @@ start_over: if err != nil { cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to handle peer_started event from %s for peer(%d,%d,%s,%s) - %s", - cts.remote_addr, x.Peer.RouteId, x.Peer.PeerId, x.Peer.LocalAddrStr, x.Peer.RemoteAddrStr, err.Error()) + cts.remote_addr_p, x.Peer.RouteId, x.Peer.PeerId, x.Peer.LocalAddrStr, x.Peer.RemoteAddrStr, err.Error()) } else { cts.C.log.Write(cts.Sid, LOG_DEBUG, "Handled peer_started event from %s for peer(%d,%d,%s,%s)", - cts.remote_addr, x.Peer.RouteId, x.Peer.PeerId, x.Peer.LocalAddrStr, x.Peer.RemoteAddrStr) + cts.remote_addr_p, x.Peer.RouteId, x.Peer.PeerId, x.Peer.LocalAddrStr, x.Peer.RemoteAddrStr) } } else { - cts.C.log.Write(cts.Sid, LOG_ERROR, "Invalid peer_started event from %s", cts.remote_addr) + cts.C.log.Write(cts.Sid, LOG_ERROR, "Invalid peer_started event from %s", cts.remote_addr_p) } // PACKET_KIND_PEER_ABORTED is never sent by server to client. @@ -1140,14 +1192,14 @@ start_over: if err != nil { cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to handle peer_stopped event from %s for peer(%d,%d,%s,%s) - %s", - cts.remote_addr, x.Peer.RouteId, x.Peer.PeerId, x.Peer.LocalAddrStr, x.Peer.RemoteAddrStr, err.Error()) + cts.remote_addr_p, x.Peer.RouteId, x.Peer.PeerId, x.Peer.LocalAddrStr, x.Peer.RemoteAddrStr, err.Error()) } else { cts.C.log.Write(cts.Sid, LOG_DEBUG, "Handled peer_stopped event from %s for peer(%d,%d,%s,%s)", - cts.remote_addr, x.Peer.RouteId, x.Peer.PeerId, x.Peer.LocalAddrStr, x.Peer.RemoteAddrStr) + cts.remote_addr_p, x.Peer.RouteId, x.Peer.PeerId, x.Peer.LocalAddrStr, x.Peer.RemoteAddrStr) } } else { - cts.C.log.Write(cts.Sid, LOG_ERROR, "Invalid peer_stopped event from %s", cts.remote_addr) + cts.C.log.Write(cts.Sid, LOG_ERROR, "Invalid peer_stopped event from %s", cts.remote_addr_p) } case PACKET_KIND_PEER_EOF: @@ -1159,14 +1211,14 @@ start_over: if err != nil { cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to handle peer_eof event from %s for peer(%d,%d,%s,%s) - %s", - cts.remote_addr, x.Peer.RouteId, x.Peer.PeerId, x.Peer.LocalAddrStr, x.Peer.RemoteAddrStr, err.Error()) + cts.remote_addr_p, x.Peer.RouteId, x.Peer.PeerId, x.Peer.LocalAddrStr, x.Peer.RemoteAddrStr, err.Error()) } else { cts.C.log.Write(cts.Sid, LOG_DEBUG, "Handled peer_eof event from %s for peer(%d,%d,%s,%s)", - cts.remote_addr, x.Peer.RouteId, x.Peer.PeerId, x.Peer.LocalAddrStr, x.Peer.RemoteAddrStr) + cts.remote_addr_p, x.Peer.RouteId, x.Peer.PeerId, x.Peer.LocalAddrStr, x.Peer.RemoteAddrStr) } } else { - cts.C.log.Write(cts.Sid, LOG_ERROR, "Invalid peer_eof event from %s", cts.remote_addr) + cts.C.log.Write(cts.Sid, LOG_ERROR, "Invalid peer_eof event from %s", cts.remote_addr_p) } case PACKET_KIND_PEER_DATA: @@ -1179,14 +1231,14 @@ start_over: if err != nil { cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to handle peer_data event from %s for peer(%d,%d) - %s", - cts.remote_addr, x.Data.RouteId, x.Data.PeerId, err.Error()) + cts.remote_addr_p, x.Data.RouteId, x.Data.PeerId, err.Error()) } else { cts.C.log.Write(cts.Sid, LOG_DEBUG, "Handled peer_data event from %s for peer(%d,%d)", - cts.remote_addr, x.Data.RouteId, x.Data.PeerId) + cts.remote_addr_p, x.Data.RouteId, x.Data.PeerId) } } else { - cts.C.log.Write(cts.Sid, LOG_ERROR, "Invalid peer_data event from %s", cts.remote_addr) + cts.C.log.Write(cts.Sid, LOG_ERROR, "Invalid peer_data event from %s", cts.remote_addr_p) } case PACKET_KIND_CONN_ERROR: @@ -1194,10 +1246,10 @@ start_over: var ok bool x, ok = pkt.U.(*Packet_ConnErr) if ok { - cts.C.log.Write(cts.Sid, LOG_ERROR, "Received conn_error(%d, %s) event from %s", x.ConnErr.ErrorId, x.ConnErr.Text, cts.remote_addr) + cts.C.log.Write(cts.Sid, LOG_ERROR, "Received conn_error(%d, %s) event from %s", x.ConnErr.ErrorId, x.ConnErr.Text, cts.remote_addr_p) if cts.cfg.CloseOnConnErrorEvent { goto done } } else { - cts.C.log.Write(cts.Sid, LOG_ERROR, "Invalid conn_error event from %s", cts.remote_addr) + cts.C.log.Write(cts.Sid, LOG_ERROR, "Invalid conn_error event from %s", cts.remote_addr_p) } case PACKET_KIND_CONN_NOTICE: @@ -1206,12 +1258,12 @@ start_over: var ok bool x, ok = pkt.U.(*Packet_ConnNoti) if ok { - cts.C.log.Write(cts.Sid, LOG_DEBUG, "conn_notice message '%s' received from %s", x.ConnNoti.Text, cts.remote_addr) + cts.C.log.Write(cts.Sid, LOG_DEBUG, "conn_notice message '%s' received from %s", x.ConnNoti.Text, cts.remote_addr_p) if cts.C.conn_notice != nil { cts.C.conn_notice.Handle(cts, x.ConnNoti.Text) } } else { - cts.C.log.Write(cts.Sid, LOG_ERROR, "Invalid conn_notice packet from %s", cts.remote_addr) + cts.C.log.Write(cts.Sid, LOG_ERROR, "Invalid conn_notice packet from %s", cts.remote_addr_p) } default: @@ -1221,7 +1273,7 @@ start_over: done: cts.C.log.Write(cts.Sid, LOG_INFO, "Disconnected from server[%d] %s", cts.cfg.Index, cts.cfg.ServerAddrs[cts.cfg.Index]) - cts.State = CLIENT_CONN_DISCONNECTED + cts.State.Store(CLIENT_CONN_DISCONNECTED) req_stop_and_wait_for_termination: //cts.RemoveClientRoutes() // this isn't needed as each task removes itself from cts upon its termination @@ -1233,13 +1285,11 @@ wait_for_termination: return reconnect_to_server: - cts.State = CLIENT_CONN_DISCONNECTING - - if cts.conn != nil { - cts.C.log.Write(cts.Sid, LOG_INFO, "Disconnecting from server[%d] %s", cts.cfg.Index, cts.cfg.ServerAddrs[cts.cfg.Index]) - } - cts.disconnect_from_server() - cts.State = CLIENT_CONN_DISCONNECTED + // this is active disconnect. if the connection is closed from the server side, + // CONNECTING will never be set before DISCONNECTED. See the code after the done: label above + cts.State.Store(CLIENT_CONN_DISCONNECTING) + cts.disconnect_from_server(true) + cts.State.Store(CLIENT_CONN_DISCONNECTED) // wait for 2 seconds slpctx, cancel_sleep = context.WithTimeout(cts.C.Ctx, 2 * time.Second) diff --git a/hodu.proto b/hodu.proto index 55c2804..85a7ef3 100644 --- a/hodu.proto +++ b/hodu.proto @@ -32,26 +32,26 @@ enum ROUTE_OPTION { message RouteDesc { uint32 RouteId = 1; - // C->S(ROUTE_START): client-side peer address - // S->C(ROUTE_STARTED): server-side listening address + // C->S(ROUTE_START/STOP): client-side peer address + // S->C(ROUTE_STARTED/STOPPED): server-side listening address string TargetAddrStr = 2; - // C->S(ROUTE_START): human-readable name of client-side peer - // S->C(ROUTE_STARTED): clone as sent by C + // C->S(ROUTE_START/STOPPED): human-readable name of client-side peer + // S->C(ROUTE_STARTED/STOPPED): clone as sent by C string TargetName= 3; - // C->S(ROUTE_START): desired listening option on the server-side(e.g. tcp, tcp4, tcp6) + + // C->S(ROUTE_START): requested listening option on the server-side(e.g. tcp, tcp4, tcp6) + // hint to the service-side peer(e.g. local) + // hint to the client-side peer(e.g. tty, http, https) // S->C(ROUTE_STARTED): cloned as sent by C. uint32 ServiceOption = 4; - // C->S(ROUTE_START): desired lisening address on the service-side + // C->S(ROUTE_START): requested lisening address on the service-side // S->C(ROUTE_STARTED): cloned as sent by C string ServiceAddrStr = 5; - // C->S(ROUTE_START): permitted network of server-side peers. - // S->C(ROUTE_STARTED): cloned as sent by C. + // C->S(ROUTE_START): requested permitted network of server-side peers. + // S->C(ROUTE_STARTED): actual permitted network of server-side peers string ServiceNetStr = 6; }; diff --git a/server-ctl.go b/server-ctl.go index de80161..e5c47ad 100644 --- a/server-ctl.go +++ b/server-ctl.go @@ -28,8 +28,8 @@ type json_out_server_route struct { ClientPeerAddr string `json:"client-peer-addr"` ClientPeerName string `json:"client-peer-name"` ServerPeerOption string `json:"server-peer-option"` - ServerPeerServiceAddr string `json:"server-peer-service-addr"` // actual listening address - ServerPeerServiceNet string `json:"server-peer-service-net"` + ServerPeerServiceAddr string `json:"server-peer-svc-addr"` // actual listening address + ServerPeerServiceNet string `json:"server-peer-svc-net"` } type json_out_server_peer struct { diff --git a/server.go b/server.go index 96bd2d8..c4fd798 100644 --- a/server.go +++ b/server.go @@ -108,6 +108,7 @@ type Server struct { ext_mtx sync.Mutex ext_svcs []Service + ext_closed bool pxy_ws *server_proxy_ssh_ws pxy_mux *http.ServeMux @@ -1220,7 +1221,7 @@ func NewServer(ctx context.Context, name string, logger Logger, cfg *ServerConfi s.svc_port_map = make(ServerSvcPortMap) s.stop_chan = make(chan bool, 8) s.stop_req.Store(false) - s.bulletin = NewBulletin[ServerEvent]() + s.bulletin = NewBulletin[ServerEvent](1000) /* creds, err := credentials.NewServerTLSFromFile(data.Path("x509/server_cert.pem"), data.Path("x509/server_key.pem")) @@ -1954,6 +1955,11 @@ func (s *Server) StartService(cfg interface{}) { func (s *Server) StartExtService(svc Service, data interface{}) { s.ext_mtx.Lock() + if s.ext_closed { + // don't start it if it's already closed + s.ext_mtx.Unlock() + return + } s.ext_svcs = append(s.ext_svcs, svc) s.ext_mtx.Unlock() s.wg.Add(1) @@ -1978,9 +1984,12 @@ func (s *Server) StartWpxService() { func (s *Server) StopServices() { var ext_svc Service s.ReqStop() + s.ext_mtx.Lock() for _, ext_svc = range s.ext_svcs { ext_svc.StopServices() } + s.ext_closed = true + s.ext_mtx.Unlock() } func (s *Server) FixServices() {