From cd32380425792256a82d2e8c4cf92e631c961358 Mon Sep 17 00:00:00 2001 From: hyung-hwan Date: Thu, 13 Mar 2025 21:24:59 +0900 Subject: [PATCH] added Atom[T] to have atomic manipulation of composite values --- Makefile | 1 + atom.go | 21 +++++++++++ client-ctl.go | 15 +++----- client.go | 31 +++++----------- server-ctl.go | 62 ++++++++++++------------------- server.go | 100 +++++++++++++++++++++++++++++++++++++------------- 6 files changed, 135 insertions(+), 95 deletions(-) create mode 100644 atom.go diff --git a/Makefile b/Makefile index 11d94b1..ee1e88c 100644 --- a/Makefile +++ b/Makefile @@ -9,6 +9,7 @@ NAME=hodu VERSION=1.0.0 SRCS=\ + atom.go \ bulletin.go \ client.go \ client-ctl.go \ diff --git a/atom.go b/atom.go new file mode 100644 index 0000000..f715daa --- /dev/null +++ b/atom.go @@ -0,0 +1,21 @@ +package hodu + +import "sync/atomic" + +type Atom[T any] struct { + val atomic.Value +} + +func (av* Atom[T]) Set(v T) { + av.val.Store(v) +} + +func (av* Atom[T]) Get() T { + var v interface{} + v = av.val.Load() + if v == nil { + var t T + return t // return the zero-value + } + return v.(T) +} diff --git a/client-ctl.go b/client-ctl.go index 5b18991..aa114d6 100644 --- a/client-ctl.go +++ b/client-ctl.go @@ -239,8 +239,6 @@ func (ctl *client_ctl_client_conns) ServeHTTP(w http.ResponseWriter, req *http.R var cts *ClientConn var jsp []json_out_client_route var ri RouteId - var local_addr string - var remote_addr string cts = c.cts_map[ci] jsp = make([]json_out_client_route, 0) @@ -266,13 +264,12 @@ func (ctl *client_ctl_client_conns) ServeHTTP(w http.ResponseWriter, req *http.R } cts.route_mtx.Unlock() - local_addr, remote_addr = cts.GetAddrInfo() js = append(js, json_out_client_conn{ Id: cts.Id, ReqServerAddrs: cts.cfg.ServerAddrs, CurrentServerIndex: cts.cfg.Index, - ServerAddr: remote_addr, - ClientAddr: local_addr, + ServerAddr: cts.remote_addr.Get(), + ClientAddr: cts.local_addr.Get(), ClientToken: cts.Token, Routes: jsp, }) @@ -358,8 +355,6 @@ func (ctl *client_ctl_client_conns_id) ServeHTTP(w http.ResponseWriter, req *htt var jsp []json_out_client_route var js *json_out_client_conn var ri RouteId - var local_addr string - var remote_addr string jsp = make([]json_out_client_route, 0) cts.route_mtx.Lock() @@ -383,13 +378,13 @@ func (ctl *client_ctl_client_conns_id) ServeHTTP(w http.ResponseWriter, req *htt }) } cts.route_mtx.Unlock() - local_addr, remote_addr = cts.GetAddrInfo() + js = &json_out_client_conn{ Id: cts.Id, ReqServerAddrs: cts.cfg.ServerAddrs, CurrentServerIndex: cts.cfg.Index, - ServerAddr: local_addr, - ClientAddr: remote_addr, + ServerAddr: cts.remote_addr.Get(), + ClientAddr: cts.local_addr.Get(), ClientToken: cts.Token, Routes: jsp, } diff --git a/client.go b/client.go index 89094bf..a97e7e6 100644 --- a/client.go +++ b/client.go @@ -134,11 +134,10 @@ type ClientConn struct { State atomic.Int32 // ClientConnState Token string - addr_mtx sync.Mutex // because the following fields are updated concurrently - local_addr string - remote_addr string - local_addr_p string // not reset when disconnected. used mostly in writing log messages without addr_mtx - remote_addr_p string // not reset when disconnected. used mostly in writing log messages without addr_mtx + local_addr Atom[string] + remote_addr Atom[string] + local_addr_p string + remote_addr_p string conn *grpc.ClientConn // grpc connection to the server hdc HoduClient @@ -952,11 +951,9 @@ func (cts *ClientConn) disconnect_from_server(logmsg bool) { // if it's called from ReqStop(), we don't really // need to care about it. - cts.addr_mtx.Lock() - cts.local_addr = "" - cts.remote_addr = "" + cts.local_addr.Set("") + cts.remote_addr.Set("") // don't reset cts.local_addr_p and cts.remote_addr_p - cts.addr_mtx.Unlock() cts.discon_mtx.Unlock() } @@ -969,12 +966,6 @@ func (cts *ClientConn) ReqStop() { } } -func (cts *ClientConn) GetAddrInfo() (string, string) { - cts.addr_mtx.Lock() - defer cts.addr_mtx.Unlock() - return cts.local_addr, cts.remote_addr -} - func timed_interceptor(tmout time.Duration) grpc.UnaryClientInterceptor { // The client calls GetSeed() as the first call to the server. // To simulate a kind of connect timeout to the server and find out an unresponsive server, @@ -1052,12 +1043,10 @@ start_over: p, ok = peer.FromContext(psc.Context()) if ok { - cts.addr_mtx.Lock() - cts.remote_addr = p.Addr.String() - cts.local_addr = p.LocalAddr.String() - cts.remote_addr_p = cts.remote_addr - cts.local_addr_p = cts.local_addr - cts.addr_mtx.Unlock() + cts.remote_addr.Set(p.Addr.String()) + cts.local_addr.Set(p.LocalAddr.String()) + cts.remote_addr_p = cts.remote_addr.Get() + cts.local_addr_p = cts.local_addr.Get() } cts.C.log.Write(cts.Sid, LOG_INFO, "Got packet stream from server[%d] %s", cts.cfg.Index, cts.cfg.ServerAddrs[cts.cfg.Index]) diff --git a/server-ctl.go b/server-ctl.go index b6fb446..5182098 100644 --- a/server-ctl.go +++ b/server-ctl.go @@ -31,8 +31,8 @@ type json_out_server_route struct { ClientPeerAddr string `json:"client-peer-addr"` ClientPeerName string `json:"client-peer-name"` ServerPeerOption string `json:"server-peer-option"` - ServerPeerServiceAddr string `json:"server-peer-svc-addr"` // actual listening address - ServerPeerServiceNet string `json:"server-peer-svc-net"` + ServerPeerSvcAddr string `json:"server-peer-svc-addr"` // actual listening address + ServerPeerSvcNet string `json:"server-peer-svc-net"` } type json_out_server_peer struct { @@ -210,8 +210,8 @@ func (ctl *server_ctl_server_conns) ServeHTTP(w http.ResponseWriter, req *http.R Id: r.Id, ClientPeerAddr: r.PtcAddr, ClientPeerName: r.PtcName, - ServerPeerServiceAddr: r.SvcAddr.String(), - ServerPeerServiceNet: r.SvcPermNet.String(), + ServerPeerSvcAddr: r.SvcAddr.String(), + ServerPeerSvcNet: r.SvcPermNet.String(), ServerPeerOption: r.SvcOption.String(), }) } @@ -220,7 +220,7 @@ func (ctl *server_ctl_server_conns) ServeHTTP(w http.ResponseWriter, req *http.R Id: cts.Id, ClientAddr: cts.RemoteAddr.String(), ServerAddr: cts.LocalAddr.String(), - ClientToken: cts.ClientToken, + ClientToken: cts.ClientToken.Get(), Routes: jsp, }) } @@ -281,8 +281,8 @@ func (ctl *server_ctl_server_conns_id) ServeHTTP(w http.ResponseWriter, req *htt Id: r.Id, ClientPeerAddr: r.PtcAddr, ClientPeerName: r.PtcName, - ServerPeerServiceAddr: r.SvcAddr.String(), - ServerPeerServiceNet: r.SvcPermNet.String(), + ServerPeerSvcAddr: r.SvcAddr.String(), + ServerPeerSvcNet: r.SvcPermNet.String(), ServerPeerOption: r.SvcOption.String(), }) } @@ -291,7 +291,7 @@ func (ctl *server_ctl_server_conns_id) ServeHTTP(w http.ResponseWriter, req *htt Id: cts.Id, ClientAddr: cts.RemoteAddr.String(), ServerAddr: cts.LocalAddr.String(), - ClientToken: cts.ClientToken, + ClientToken: cts.ClientToken.Get(), Routes: jsp, } @@ -350,8 +350,8 @@ func (ctl *server_ctl_server_conns_id_routes) ServeHTTP(w http.ResponseWriter, r Id: r.Id, ClientPeerAddr: r.PtcAddr, ClientPeerName: r.PtcName, - ServerPeerServiceAddr: r.SvcAddr.String(), - ServerPeerServiceNet: r.SvcPermNet.String(), + ServerPeerSvcAddr: r.SvcAddr.String(), + ServerPeerSvcNet: r.SvcPermNet.String(), ServerPeerOption: r.SvcOption.String(), }) } @@ -437,8 +437,8 @@ func (ctl *server_ctl_server_conns_id_routes_id) ServeHTTP(w http.ResponseWriter Id: r.Id, ClientPeerAddr: r.PtcAddr, ClientPeerName: r.PtcName, - ServerPeerServiceAddr: r.SvcAddr.String(), - ServerPeerServiceNet: r.SvcPermNet.String(), + ServerPeerSvcAddr: r.SvcAddr.String(), + ServerPeerSvcNet: r.SvcPermNet.String(), ServerPeerOption: r.SvcOption.String(), }) if err != nil { goto oops } @@ -699,19 +699,6 @@ oops: } // ------------------------------------ -type json_ctl_ws_event struct { - Type string `json:"type"` - Data []string `json:"data"` -} - -func (pxy *server_ctl_ws) send_ws_data(ws *websocket.Conn, type_val string, data string) error { - var msg []byte - var err error - - msg, err = json.Marshal(json_ssh_ws_event{Type: type_val, Data: []string{ data } }) - if err == nil { err = websocket.Message.Send(ws, msg) } - return err -} func (ctl *server_ctl_ws) ServeWebsocket(ws *websocket.Conn) (int, error) { var s *Server @@ -719,6 +706,7 @@ func (ctl *server_ctl_ws) ServeWebsocket(ws *websocket.Conn) (int, error) { var sbsc *ServerEventSubscription var status_code int var err error + var xerr error s = ctl.s @@ -740,7 +728,6 @@ func (ctl *server_ctl_ws) ServeWebsocket(ws *websocket.Conn) (int, error) { } status_code, _ = ctl.s.Cfg.CtlAuth.Authenticate(req) -fmt.Printf ("status code %d\n", status_code) if status_code != http.StatusOK { goto done } @@ -760,14 +747,20 @@ fmt.Printf ("status code %d\n", status_code) for c != nil { var e *ServerEvent var ok bool + var msg[] byte e, ok = <- c if ok { - // TODO: handle this part better - err = ctl.send_ws_data(ws, "server", fmt.Sprintf("%d,%d,%d", e.Desc.Conn, e.Desc.Route, e.Desc.Peer)) + msg, err = json.Marshal(e) if err != nil { - // TODO: logging... + xerr = fmt.Errorf("failed to marshal event - %+v - %s", e, err.Error()) c = nil + } else { + err = websocket.Message.Send(ws, msg) + if err != nil { + xerr = fmt.Errorf("failed to send message - %s", err.Error()) + c = nil + } } } else { // most likely sbcs.C is closed. if not readable, break the loop @@ -785,15 +778,7 @@ ws_recv_loop: if err != nil { break ws_recv_loop } if len(msg) > 0 { - var ev json_ssh_ws_event - err = json.Unmarshal(msg, &ev) - if err == nil { - switch ev.Type { - case "open": - case "close": - break ws_recv_loop - } - } + // do nothing. discard received messages } } @@ -804,5 +789,6 @@ ws_recv_loop: done: ws.Close() wg.Wait() + if err == nil && xerr != nil { err = xerr } return http.StatusOK, err } diff --git a/server.go b/server.go index 23b29eb..8da2100 100644 --- a/server.go +++ b/server.go @@ -67,10 +67,6 @@ type ServerConfig struct { WpxTls *tls.Config } -const SERVER_EVENT_TOPIC_CONN string = "conn" -const SERVER_EVENT_TOPIC_ROUTE string = "route" -const SERVER_EVENT_TOPIC_PEER string = "peer" - type ServerEventKind int const ( SERVER_EVENT_CONN_ADDED = iota @@ -81,15 +77,35 @@ const ( SERVER_EVENT_PEER_DELETED ) -type ServerEventDesc struct { - Conn ConnId - Route RouteId - Peer PeerId +type ServerEvent struct { + Kind ServerEventKind `json:"type"` + Data interface{} `json:"data"` } -type ServerEvent struct { - Kind ServerEventKind - Desc ServerEventDesc +type ServerEventConnAdded struct { + Conn ConnId `json:"conn-id"` + ServerAddr string `json:"server-addr"` + ClientAddr string `json:"client-addr"` + ClientToken string `json:"client-token"` +} + +type ServerEventConnDeleted struct { + Conn ConnId `json:"conn-id:"` +} + +type ServerEventRouteAdded struct { + Conn ConnId `json:"conn-id"` + Route RouteId `json:"route-id"` + ClientPeerAddr string `json:"client-peer-addr"` + ClientPeerName string `json:"client-peer-name"` + ServerPeerOption string `json:"server-peer-option"` + ServerPeerSvcAddr string `json:"server-peer-svc-addr"` + ServerPeerSvcNet string `json:"server-peer-svc-net"` +} + +type ServerEventRouteDeleted struct { + Conn ConnId `json:"conn-id"` + Route RouteId `json:"route-id"` } type ServerEventBulletin = Bulletin[*ServerEvent] @@ -160,7 +176,7 @@ type ServerConn struct { S *Server Id ConnId Sid string // for logging - ClientToken string // provided by client + ClientToken Atom[string] // provided by client RemoteAddr net.Addr // client address that created this structure LocalAddr net.Addr // local address that the client is connected to @@ -365,6 +381,15 @@ func (r *ServerRoute) RunTask(wg *sync.WaitGroup) { r.Cts.S.log.Write(r.Cts.Sid, LOG_DEBUG, "All service-side peer handlers ended on route(%d)", r.Id) r.Cts.RemoveServerRoute(r) // final phase... + + r.Cts.S.bulletin.Enqueue( + &ServerEvent{ + Kind: SERVER_EVENT_ROUTE_DELETED, + Data: &ServerEventRouteDeleted { + Conn: r.Cts.Id, Route: r.Id, + }, + }, + ) } func (r *ServerRoute) ReqStop() { @@ -501,6 +526,22 @@ func (cts *ServerConn) AddNewServerRoute(route_id RouteId, proto RouteOption, pt cts.S.stats.routes.Add(1) cts.route_mtx.Unlock() + cts.S.bulletin.Enqueue( + &ServerEvent{ + Kind: SERVER_EVENT_ROUTE_ADDED, + Data: &ServerEventRouteAdded{ + Conn: cts.Id, + Route: r.Id, + ClientPeerAddr: r.PtcAddr, + ClientPeerName: r.PtcName, + ServerPeerSvcAddr: r.SvcAddr.String(), + ServerPeerSvcNet: r.SvcPermNet.String(), + ServerPeerOption: r.SvcOption.String(), + }, + }, + ) + + // Don't detached the cts task as a go-routine as this function cts.route_wg.Add(1) go r.RunTask(&cts.route_wg) return r, nil @@ -786,7 +827,7 @@ func (cts *ServerConn) receive_from_stream(wg *sync.WaitGroup) { cts.S.log.Write(cts.Sid, LOG_ERROR, "Invalid conn_desc packet from %s - blank token", cts.RemoteAddr) cts.pss.Send(MakeConnErrorPacket(1, "blank token refused")) cts.ReqStop() // TODO: is this desirable to disconnect? - } else if x.Conn.Token != cts.ClientToken { + } else if x.Conn.Token != cts.ClientToken.Get() { _, err = strconv.ParseUint(x.Conn.Token, 10, int(unsafe.Sizeof(ConnId(0)) * 8)) if err == nil { // this is not != nil. this is to check if the token is numeric cts.S.log.Write(cts.Sid, LOG_ERROR, "Invalid conn_desc packet from %s - numeric token '%s'", cts.RemoteAddr, x.Conn.Token) @@ -802,8 +843,8 @@ func (cts *ServerConn) receive_from_stream(wg *sync.WaitGroup) { cts.pss.Send(MakeConnErrorPacket(1, fmt.Sprintf("duplicate token refused - %s", x.Conn.Token))) cts.ReqStop() // TODO: is this desirable to disconnect? } else { - if cts.ClientToken != "" { delete(cts.S.cts_map_by_token, cts.ClientToken) } - cts.ClientToken = x.Conn.Token + if cts.ClientToken.Get() != "" { delete(cts.S.cts_map_by_token, cts.ClientToken.Get()) } + cts.ClientToken.Set(x.Conn.Token) cts.S.cts_map_by_token[x.Conn.Token] = cts cts.S.cts_mtx.Unlock() cts.S.log.Write(cts.Sid, LOG_INFO, "client(%d) %s - token set to '%s'", cts.Id, cts.RemoteAddr, x.Conn.Token) @@ -850,8 +891,10 @@ func (cts *ServerConn) RunTask(wg *sync.WaitGroup) { // function be the channel waiter only. // increment on the wait group is for the caller to wait for // these detached goroutines to finish. - wg.Add(1) - go cts.receive_from_stream(wg) + //wg.Add(1) + //go cts.receive_from_stream(wg) + cts.route_wg.Add(1) + go cts.receive_from_stream(&cts.route_wg) for { // exit if context is done @@ -879,7 +922,7 @@ done: cts.S.bulletin.Enqueue( &ServerEvent{ Kind: SERVER_EVENT_CONN_DELETED, - Desc: ServerEventDesc{ Conn: cts.Id }, + Data: &ServerEventConnDeleted{ Conn: cts.Id }, }, ) // Don't detached the cts task as a go-routine as this function @@ -942,7 +985,12 @@ func (s *Server) PacketStream(strm Hodu_PacketStreamServer) error { s.bulletin.Enqueue( &ServerEvent{ Kind: SERVER_EVENT_CONN_ADDED, - Desc: ServerEventDesc{ Conn: cts.Id }, + Data: &ServerEventConnAdded{ + Conn: cts.Id, + ServerAddr: cts.LocalAddr.String(), + ClientAddr: cts.RemoteAddr.String(), + ClientToken: cts.ClientToken.Get(), + }, }, ) @@ -1700,15 +1748,15 @@ func (s *Server) AddNewServerConn(remote_addr *net.Addr, local_addr *net.Addr, p s.cts_mtx.Unlock() return nil, fmt.Errorf("existing client address - %s", cts.RemoteAddr.String()) } - if cts.ClientToken != "" { + if cts.ClientToken.Get() != "" { // this check is not needed as Token is never set at this phase // however leave statements here for completeness - _, ok = s.cts_map_by_token[cts.ClientToken] + _, ok = s.cts_map_by_token[cts.ClientToken.Get()] if ok { s.cts_mtx.Unlock() - return nil, fmt.Errorf("existing client token - %s", cts.ClientToken) + return nil, fmt.Errorf("existing client token - %s", cts.ClientToken.Get()) } - s.cts_map_by_token[cts.ClientToken] = &cts + s.cts_map_by_token[cts.ClientToken.Get()] = &cts } s.cts_map_by_addr[cts.RemoteAddr] = &cts s.cts_map[cts.Id] = &cts @@ -1744,7 +1792,7 @@ func (s *Server) RemoveServerConn(cts *ServerConn) error { delete(s.cts_map, cts.Id) delete(s.cts_map_by_addr, cts.RemoteAddr) - if cts.ClientToken != "" { delete(s.cts_map_by_token, cts.ClientToken) } + if cts.ClientToken.Get() != "" { delete(s.cts_map_by_token, cts.ClientToken.Get()) } s.stats.conns.Store(int64(len(s.cts_map))) s.cts_mtx.Unlock() @@ -1765,7 +1813,7 @@ func (s *Server) RemoveServerConnByAddr(addr net.Addr) (*ServerConn, error) { } delete(s.cts_map, cts.Id) delete(s.cts_map_by_addr, cts.RemoteAddr) - if cts.ClientToken != "" { delete(s.cts_map_by_token, cts.ClientToken) } + if cts.ClientToken.Get() != "" { delete(s.cts_map_by_token, cts.ClientToken.Get()) } s.stats.conns.Store(int64(len(s.cts_map))) s.cts_mtx.Unlock() @@ -1786,7 +1834,7 @@ func (s *Server) RemoveServerConnByClientToken(token string) (*ServerConn, error } delete(s.cts_map, cts.Id) delete(s.cts_map_by_addr, cts.RemoteAddr) - delete(s.cts_map_by_token, cts.ClientToken) // no Empty check becuase an empty token is never found in the map + delete(s.cts_map_by_token, cts.ClientToken.Get()) // no Empty check becuase an empty token is never found in the map s.stats.conns.Store(int64(len(s.cts_map))) s.cts_mtx.Unlock()