From b17f3af7ada06e621b1231d633093c628c1e0c7d Mon Sep 17 00:00:00 2001 From: hyung-hwan Date: Tue, 3 Dec 2024 00:55:19 +0900 Subject: [PATCH] added the id field to ServerConn --- client-ctl.go | 29 ++++++------- client.go | 2 +- server-ctl.go | 66 +++++++++++++++++++++++++++- server.go | 118 ++++++++++++++++++++++++++++++++++++-------------- 4 files changed, 164 insertions(+), 51 deletions(-) diff --git a/client-ctl.go b/client-ctl.go index 7b6da65..1aa0278 100644 --- a/client-ctl.go +++ b/client-ctl.go @@ -149,10 +149,11 @@ func (ctl *client_ctl_client_conns) ServeHTTP(w http.ResponseWriter, req *http.R var cts *ClientConn err = json.NewDecoder(req.Body).Decode(&s) - if err != nil { + if err != nil || s.ServerAddr == "" { status_code = http.StatusBadRequest; w.WriteHeader(status_code) goto done } + cc.ServerAddr = s.ServerAddr //cc.PeerAddrs = s.PeerAddrs cts, err = c.start_service(&cc) // TODO: this can be blocking. do we have to resolve addresses before calling this? also not good because resolution succeed or fail at each attempt. however ok as ServeHTTP itself is in a goroutine? @@ -196,6 +197,7 @@ func (ctl *client_ctl_client_conns_id) ServeHTTP(w http.ResponseWriter, req *htt var conn_id string var conn_nid uint64 var je *json.Encoder + var cts *ClientConn c = ctl.c je = json.NewEncoder(w) @@ -209,19 +211,18 @@ func (ctl *client_ctl_client_conns_id) ServeHTTP(w http.ResponseWriter, req *htt goto done } + cts = c.FindClientConnById(uint32(conn_nid)) + if cts == nil { + status_code = http.StatusNotFound; w.WriteHeader(status_code) + if err = je.Encode(json_errmsg{Text: "non-existent connection id - " + conn_id}); err != nil { goto oops } + goto done + } + switch req.Method { case http.MethodGet: var r *ClientRoute var jsp []json_out_client_route var js *json_out_client_conn - var cts *ClientConn - - cts = c.FindClientConnById(uint32(conn_nid)) - if cts == nil { - status_code = http.StatusNotFound; w.WriteHeader(status_code) - if err = je.Encode(json_errmsg{Text: "non-existent connection id - " + conn_id}); err != nil { goto oops } - goto done - } jsp = make([]json_out_client_route, 0) cts.route_mtx.Lock() @@ -239,13 +240,9 @@ func (ctl *client_ctl_client_conns_id) ServeHTTP(w http.ResponseWriter, req *htt if err = je.Encode(js); err != nil { goto oops } case http.MethodDelete: - err = c.RemoveClientConnById(uint32(conn_nid)) - if err != nil { - status_code = http.StatusNotFound; w.WriteHeader(status_code) - if err = je.Encode(json_errmsg{Text: err.Error()}); err != nil { goto oops } - } else { - status_code = http.StatusNoContent; w.WriteHeader(status_code) - } + //c.RemoveClientConn(cts) + cts.ReqStop() + status_code = http.StatusNoContent; w.WriteHeader(status_code) default: status_code = http.StatusBadRequest; w.WriteHeader(status_code) diff --git a/client.go b/client.go index 9de3535..f1b2bde 100644 --- a/client.go +++ b/client.go @@ -951,7 +951,7 @@ func (c *Client) RemoveClientConn(cts *ClientConn) error { fmt.Printf("REMOVEDDDDDD CONNECTION FROM %s total servers %d\n", cts.cfg.ServerAddr, len(c.cts_map_by_addr)) c.cts_mtx.Unlock() - c.ReqStop() + cts.ReqStop() return nil } diff --git a/server-ctl.go b/server-ctl.go index 30a1649..ad0c6fb 100644 --- a/server-ctl.go +++ b/server-ctl.go @@ -2,6 +2,7 @@ package hodu import "encoding/json" import "net/http" +import "strconv" type json_out_server_conn struct { @@ -86,5 +87,68 @@ oops: // ------------------------------------ func (ctl *server_ctl_server_conns_id) ServeHTTP(w http.ResponseWriter, req *http.Request) { - // TODO: + var s *Server + var status_code int + var err error + var je *json.Encoder + var conn_id string + var conn_nid uint64 + var cts *ServerConn + + s = ctl.s + je = json.NewEncoder(w) + + conn_id = req.PathValue("conn_id") + + conn_nid, err = strconv.ParseUint(conn_id, 10, 32) + if err != nil { + status_code = http.StatusBadRequest; w.WriteHeader(status_code) + if err = je.Encode(json_errmsg{Text: "wrong connection id - " + conn_id}); err != nil { goto oops } + goto done + } + + cts = s.FindServerConnById(uint32(conn_nid)) + if cts == nil { + status_code = http.StatusNotFound; w.WriteHeader(status_code) + if err = je.Encode(json_errmsg{Text: "non-existent connection id - " + conn_id}); err != nil { goto oops } + goto done + } + + switch req.Method { + case http.MethodGet: + var r *ServerRoute + var jsp []json_out_server_route + var js *json_out_server_conn + + jsp = make([]json_out_server_route, 0) + cts.route_mtx.Lock() + for _, r = range cts.route_map { + jsp = append(jsp, json_out_server_route{ + Id: r.id, + ClientPeerAddr: r.ptc_addr, + ServerPeerListenAddr: r.laddr.String(), + }) + } + js = &json_out_server_conn{Id: cts.id, ClientAddr: cts.caddr.String(), ServerAddr: cts.local_addr.String(), Routes: jsp} + cts.route_mtx.Unlock() + + status_code = http.StatusOK; w.WriteHeader(status_code) + if err = je.Encode(js); err != nil { goto oops } + + case http.MethodDelete: + //s.RemoveServerConn(cts) + cts.ReqStop() + status_code = http.StatusNoContent; w.WriteHeader(status_code) + + default: + status_code = http.StatusBadRequest; w.WriteHeader(status_code) + } + +done: + s.log.Write("", LOG_DEBUG, "[%s] %s %s %d", req.RemoteAddr, req.Method, req.URL.String(), status_code) // TODO: time taken + return + +oops: + s.log.Write("", LOG_ERROR, "[%s] %s %s - %s", req.RemoteAddr, req.Method, req.URL.String(), err.Error()) + return } diff --git a/server.go b/server.go index 77b36a0..35c3dfe 100644 --- a/server.go +++ b/server.go @@ -19,34 +19,36 @@ import "google.golang.org/grpc/stats" const PTS_LIMIT = 8192 -type ServerConnMap = map[net.Addr]*ServerConn +type ServerConnMapByAddr = map[net.Addr]*ServerConn +type ServerConnMap = map[uint32]*ServerConn type ServerPeerConnMap = map[uint32]*ServerPeerConn type ServerRouteMap = map[uint32]*ServerRoute type Server struct { - ctx context.Context - ctx_cancel context.CancelFunc - tlscfg *tls.Config + ctx context.Context + ctx_cancel context.CancelFunc + tlscfg *tls.Config - wg sync.WaitGroup - ext_mtx sync.Mutex - ext_svcs []Service - stop_req atomic.Bool - stop_chan chan bool + wg sync.WaitGroup + ext_mtx sync.Mutex + ext_svcs []Service + stop_req atomic.Bool + stop_chan chan bool - ctl_prefix string - ctl_mux *http.ServeMux - ctl *http.Server // control server + ctl_prefix string + ctl_mux *http.ServeMux + ctl *http.Server // control server - l []*net.TCPListener // main listener for grpc - l_wg sync.WaitGroup + l []*net.TCPListener // main listener for grpc + l_wg sync.WaitGroup - cts_mtx sync.Mutex - cts_map ServerConnMap - cts_wg sync.WaitGroup + cts_mtx sync.Mutex + cts_map ServerConnMap + cts_map_by_addr ServerConnMapByAddr + cts_wg sync.WaitGroup - gs *grpc.Server - log Logger + gs *grpc.Server + log Logger UnimplementedHoduServer } @@ -56,6 +58,7 @@ type Server struct { type ServerConn struct { svr *Server id uint32 + lid string // for logging caddr net.Addr // client address that created this structure local_addr net.Addr pss *GuardedPacketStreamServer @@ -390,7 +393,7 @@ func (cts *ServerConn) receive_from_stream(wg *sync.WaitGroup) { if err != nil { cts.svr.log.Write("", LOG_ERROR, "Failed to add server route for client %s peer %s", cts.caddr, x.Route.AddrStr) } else { - cts.svr.log.Write("", LOG_INFO, "Added server route(id=%d) for client %s peer %s", r.id, cts.caddr, x.Route.AddrStr) + cts.svr.log.Write("", LOG_INFO, "Added server route(id=%d) for client %s peer %s to cts(id=%d)", r.id, cts.caddr, x.Route.AddrStr, cts.id) err = cts.pss.Send(MakeRouteStartedPacket(r.id, x.Route.Proto, r.laddr.String())) if err != nil { r.ReqStop() @@ -728,7 +731,8 @@ func NewServer(ctx context.Context, ctl_addr string, laddrs []string, logger Log s.tlscfg = tlscfg s.ext_svcs = make([]Service, 0, 1) - s.cts_map = make(ServerConnMap) // TODO: make it configurable... + s.cts_map = make(ServerConnMap) + s.cts_map_by_addr = make(ServerConnMapByAddr) s.stop_chan = make(chan bool, 8) s.stop_req.Store(false) /* @@ -750,9 +754,9 @@ func NewServer(ctx context.Context, ctl_addr string, laddrs []string, logger Log s.ctl_mux = http.NewServeMux() cwd, _ = os.Getwd() s.ctl_mux.Handle(s.ctl_prefix + "/ui/", http.StripPrefix(s.ctl_prefix, http.FileServer(http.Dir(cwd)))) // TODO: proper directory. it must not use the current working directory... - //s.ctl_mux.HandleFunc(s.ctl_prefix + "/ws/tty", websocket.Handler(server_ws_tty).ServeHTTP) s.ctl_mux.Handle(s.ctl_prefix + "/ws/tty", new_server_ctl_ws_tty(&s)) s.ctl_mux.Handle(s.ctl_prefix + "/server-conns", &server_ctl_server_conns{s: &s}) + s.ctl_mux.Handle(s.ctl_prefix + "/server-conns/{conn_id}", &server_ctl_server_conns_id{s: &s}) s.ctl = &http.Server{ Addr: ctl_addr, @@ -872,6 +876,7 @@ func (s *Server) ReqStop() { func (s *Server) AddNewServerConn(remote_addr *net.Addr, local_addr *net.Addr, pss Hodu_PacketStreamServer) (*ServerConn, error) { var cts ServerConn + var id uint32 var ok bool cts.svr = s @@ -886,11 +891,21 @@ func (s *Server) AddNewServerConn(remote_addr *net.Addr, local_addr *net.Addr, p s.cts_mtx.Lock() defer s.cts_mtx.Unlock() - _, ok = s.cts_map[cts.caddr] + id = rand.Uint32() + for { + _, ok = s.cts_map[id] + if !ok { break } + id++ + } + cts.id = id + cts.lid = fmt.Sprintf("%d", id) // id in string used for logging + + _, ok = s.cts_map_by_addr[cts.caddr] if ok { return nil, fmt.Errorf("existing client - %s", cts.caddr.String()) } - s.cts_map[cts.caddr] = &cts + s.cts_map_by_addr[cts.caddr] = &cts + s.cts_map[id] = &cts; s.log.Write("", LOG_DEBUG, "Added client connection from %s", cts.caddr.String()) return &cts, nil } @@ -906,25 +921,62 @@ func (s *Server) ReqStopAllServerConns() { } } -func (s *Server) RemoveServerConn(cts *ServerConn) { +func (s *Server) RemoveServerConn(cts *ServerConn) error { + var conn *ServerConn + var ok bool + s.cts_mtx.Lock() - delete(s.cts_map, cts.caddr) - s.log.Write("", LOG_DEBUG, "Removed client connection from %s", cts.caddr.String()) + + conn, ok = s.cts_map[cts.id] + if !ok { + s.cts_mtx.Unlock() + return fmt.Errorf("non-existent connection id - %d", cts.id) + } + if conn != cts { + s.cts_mtx.Unlock() + return fmt.Errorf("non-existent connection id - %d", cts.id) + } + + delete(s.cts_map, cts.id) + delete(s.cts_map_by_addr, cts.caddr) s.cts_mtx.Unlock() + + cts.ReqStop() + return nil } -func (s *Server) RemoveServerConnByAddr(addr net.Addr) { +func (s *Server) RemoveServerConnByAddr(addr net.Addr) error { + var cts *ServerConn + var ok bool + + s.cts_mtx.Lock() + + cts, ok = s.cts_map_by_addr[addr] + if !ok { + s.cts_mtx.Unlock() + return fmt.Errorf("non-existent connection address - %s", addr.String()) + } + delete(s.cts_map, cts.id) + delete(s.cts_map_by_addr, cts.caddr) + s.cts_mtx.Unlock() + + cts.ReqStop() + return nil +} + +func (s* Server) FindServerConnById(id uint32) *ServerConn { var cts *ServerConn var ok bool s.cts_mtx.Lock() defer s.cts_mtx.Unlock() - cts, ok = s.cts_map[addr] - if ok { - delete(s.cts_map, cts.caddr) - cts.ReqStop() + cts, ok = s.cts_map[id] + if !ok { + return nil } + + return cts } func (s *Server) FindServerConnByAddr(addr net.Addr) *ServerConn { @@ -934,7 +986,7 @@ func (s *Server) FindServerConnByAddr(addr net.Addr) *ServerConn { s.cts_mtx.Lock() defer s.cts_mtx.Unlock() - cts, ok = s.cts_map[addr] + cts, ok = s.cts_map_by_addr[addr] if !ok { return nil }