From 1e6fbed19d7fed780cf0a0400baaba75e24fa9d6 Mon Sep 17 00:00:00 2001 From: hyung-hwan Date: Wed, 12 Mar 2025 12:08:56 +0900 Subject: [PATCH] fixed wrong queue implementation in bulletin.go --- Makefile | 2 +- bulletin.go | 242 ++++++++++++++++++++++++++++++++++------------- bulletin_test.go | 78 ++++++++++++++- server-ctl.go | 89 +++++++++++++++++ server.go | 74 ++++++++++----- 5 files changed, 391 insertions(+), 94 deletions(-) diff --git a/Makefile b/Makefile index d819e12..11d94b1 100644 --- a/Makefile +++ b/Makefile @@ -23,7 +23,7 @@ SRCS=\ server-ctl.go \ server-metrics.go \ server-peer.go \ - server-proxy.go \ + server-pxy.go \ system.go \ transform.go \ diff --git a/bulletin.go b/bulletin.go index f691967..e17765c 100644 --- a/bulletin.go +++ b/bulletin.go @@ -4,6 +4,7 @@ import "container/list" import "container/ring" import "errors" import "sync" +import "time" type BulletinSubscription[T interface{}] struct { C chan T @@ -17,22 +18,50 @@ type BulletinSubscriptionList = *list.List type BulletinSubscriptionMap map[string]BulletinSubscriptionList type Bulletin[T interface{}] struct { + svc Service + sbsc_map BulletinSubscriptionMap + sbsc_list *list.List sbsc_mtx sync.RWMutex - closed bool + blocked bool r_mtx sync.RWMutex r *ring.Ring - r_capa int - r_full bool + r_head *ring.Ring + r_tail *ring.Ring + r_len int + r_cap int + r_chan chan struct{} + stop_chan chan struct{} } -func NewBulletin[T interface{}](capa int) *Bulletin[T] { +func NewBulletin[T interface{}](svc Service, capa int) *Bulletin[T] { + var r *ring.Ring + + r = ring.New(capa) return &Bulletin[T]{ sbsc_map: make(BulletinSubscriptionMap, 0), - r: ring.New(capa), - r_capa: capa, - r_full: false, + sbsc_list: list.New(), + r: r, + r_head: r, + r_tail: r, + r_cap: capa, + r_len: 0, + r_chan: make(chan struct{}, 1), + stop_chan: make(chan struct{}, 1), + } +} + +func (b *Bulletin[T]) unsubscribe_list_nolock(sl BulletinSubscriptionList) { + var sbsc *BulletinSubscription[T] + var e *list.Element 
+ + for e = sl.Front(); e != nil; e = e.Next() { + sbsc = e.Value.(*BulletinSubscription[T]) + sl.Remove(sbsc.node) + close(sbsc.C) + sbsc.b = nil + sbsc.node = nil } } @@ -41,20 +70,12 @@ func (b *Bulletin[T]) unsubscribe_all_nolock() { var sl BulletinSubscriptionList for topic, sl = range b.sbsc_map { - var sbsc *BulletinSubscription[T] - var e *list.Element - - for e = sl.Front(); e != nil; e = e.Next() { - sbsc = e.Value.(*BulletinSubscription[T]) - close(sbsc.C) - sbsc.b = nil - sbsc.node = nil - } - + b.unsubscribe_list_nolock(sl) delete(b.sbsc_map, topic) } - b.closed = true + b.unsubscribe_list_nolock(b.sbsc_list) + b.blocked = true } func (b *Bulletin[T]) UnsubscribeAll() { @@ -63,33 +84,44 @@ func (b *Bulletin[T]) UnsubscribeAll() { b.sbsc_mtx.Unlock() } -func (b *Bulletin[T]) Close() { +func (b *Bulletin[T]) Block() { b.sbsc_mtx.Lock() - if !b.closed { - b.unsubscribe_all_nolock() - b.closed = true - } + b.blocked = true + b.sbsc_mtx.Unlock() +} + +func (b *Bulletin[T]) Unblock() { + b.sbsc_mtx.Lock() + b.blocked = false b.sbsc_mtx.Unlock() } func (b *Bulletin[T]) Subscribe(topic string) (*BulletinSubscription[T], error) { var sbsc BulletinSubscription[T] - var sbsc_list BulletinSubscriptionList - var ok bool - if b.closed { return nil, errors.New("closed bulletin") } + b.sbsc_mtx.Lock() + if b.blocked { + b.sbsc_mtx.Unlock() + return nil, errors.New("blocked") + } sbsc.C = make(chan T, 128) // TODO: size? 
sbsc.b = b sbsc.topic = topic - b.sbsc_mtx.Lock() - sbsc_list, ok = b.sbsc_map[topic] - if !ok { - sbsc_list = list.New() - b.sbsc_map[topic] = sbsc_list + if topic == "" { + sbsc.node = b.sbsc_list.PushBack(&sbsc) + } else { + var sbsc_list BulletinSubscriptionList + var ok bool + + sbsc_list, ok = b.sbsc_map[topic] + if !ok { + sbsc_list = list.New() + b.sbsc_map[topic] = sbsc_list + } + sbsc.node = sbsc_list.PushBack(&sbsc) } - sbsc.node = sbsc_list.PushBack(&sbsc) b.sbsc_mtx.Unlock() return &sbsc, nil } @@ -100,12 +132,19 @@ func (b *Bulletin[T]) Unsubscribe(sbsc *BulletinSubscription[T]) { var ok bool b.sbsc_mtx.Lock() - sl, ok = b.sbsc_map[sbsc.topic] - if ok { - sl.Remove(sbsc.node) + if sbsc.topic == "" { + b.sbsc_list.Remove(sbsc.node) close(sbsc.C) sbsc.node = nil sbsc.b = nil + } else { + sl, ok = b.sbsc_map[sbsc.topic] + if ok { + sl.Remove(sbsc.node) + close(sbsc.C) + sbsc.node = nil + sbsc.b = nil + } } b.sbsc_mtx.Unlock() } @@ -114,59 +153,130 @@ func (b *Bulletin[T]) Unsubscribe(sbsc *BulletinSubscription[T]) { func (b *Bulletin[T]) Publish(topic string, data T) { var sl BulletinSubscriptionList var ok bool - var topics [2]string - var t string - - if b.closed { return } if topic == "" { return } - topics[0] = topic - topics[1] = "" - b.sbsc_mtx.Lock() - for _, t = range topics { - sl, ok = b.sbsc_map[t] - if ok { - var sbsc *BulletinSubscription[T] - var e *list.Element - for e = sl.Front(); e != nil; e = e.Next() { - sbsc = e.Value.(*BulletinSubscription[T]) - select { - case sbsc.C <- data: - // ok. could be written. - default: - // channel full. discard it - } + if b.blocked { + b.sbsc_mtx.Unlock() + return + } + + sl, ok = b.sbsc_map[topic] + if ok { + var sbsc *BulletinSubscription[T] + var e *list.Element + for e = sl.Front(); e != nil; e = e.Next() { + sbsc = e.Value.(*BulletinSubscription[T]) + select { + case sbsc.C <- data: + // ok. could be written. + default: + // channel full. 
discard it } } } b.sbsc_mtx.Unlock() } -func (b *Bulletin[T]) Enqueue(topic string, data T) { +func (b *Bulletin[T]) Enqueue(data T) { + // hopefully, it's faster to use a single mutex, a ring buffer, and a notification channel than + // to use a channel to pass messages. TODO: performance verification b.r_mtx.Lock() - b.r.Value = data // update the value at the current position - b.r = b.r.Next() // move the current position + if b.blocked { + b.r_mtx.Unlock() + return + } + + if b.r_len < b.r_cap { + b.r_len++ + } else { + b.r_head = b.r_head.Next() + } + b.r_tail.Value = data // update the value at the current position + b.r_tail = b.r_tail.Next() // move the current position + select { + case b.r_chan <- struct{}{}: + // write success + default: + // don't care if not writable + } b.r_mtx.Unlock() } -func (b *Bulletin[T]) Dequeue() { +func (b *Bulletin[T]) Dequeue() (T, bool) { + var v T + var ok bool + b.r_mtx.Lock() + + if b.r_len > 0 { + v = b.r_head.Value.(T) // store the value for returning + b.r_head.Value = nil // nullify the value + b.r_head = b.r_head.Next() // advance the head position + b.r_len-- + ok = true + } + b.r_mtx.Unlock() + return v, ok } -/* func (b *Bulletin[T]) RunTask(wg *sync.WaitGroup) { var done bool - var msg T - var ok bool + var tmr *time.Timer defer wg.Done() + tmr = time.NewTimer(3 * time.Second) for !done { - select { - case msg, ok = <- b.C: - if !ok { done = true } + var msg T + var ok bool + + msg, ok = b.Dequeue() + if !ok { + select { + case <-b.stop_chan: + // this may break the loop prematurely while there + // are messages to read as it uses two different channels: + // one for stop, another for notification + done = true + case <-b.r_chan: + // notification received. + tmr.Stop() + tmr.Reset(3 * time.Second) + case <-tmr.C: + // try to dequeue again + tmr.Reset(3 * time.Second) + } + } else { + // forward msg to all subscribers... 
+ var e *list.Element + var sbsc *BulletinSubscription[T] + + tmr.Stop() + + b.sbsc_mtx.Lock() + for e = b.sbsc_list.Front(); e != nil; e = e.Next() { + sbsc = e.Value.(*BulletinSubscription[T]) + select { + case sbsc.C <- msg: + // ok. could be written. + default: + // channel full. discard it + } + } + b.sbsc_mtx.Unlock() } } -}*/ + + tmr.Stop() +} + +func (b *Bulletin[T]) ReqStop() { + select { + case b.stop_chan <- struct{}{}: + // write success + default: + // ignore failure + } +} diff --git a/bulletin_test.go b/bulletin_test.go index d2d316e..5522c45 100644 --- a/bulletin_test.go +++ b/bulletin_test.go @@ -6,7 +6,7 @@ import "sync" import "testing" import "time" -func TestBulletin(t *testing.T) { +func TestBulletin1(t *testing.T) { var b *hodu.Bulletin[string] var s1 *hodu.BulletinSubscription[string] var s2 *hodu.BulletinSubscription[string] @@ -14,7 +14,7 @@ func TestBulletin(t *testing.T) { var nmsgs1 int var nmsgs2 int - b = hodu.NewBulletin[string](100) + b = hodu.NewBulletin[string](nil, 100) s1, _ = b.Subscribe("t1") s2, _ = b.Subscribe("t2") @@ -38,9 +38,7 @@ func TestBulletin(t *testing.T) { case m, ok = <-c2: if ok { fmt.Printf ("s2: %+v\n", m); nmsgs2++ } else { c2 = nil; fmt.Printf ("s2 closed\n") } } - } - }() b.Publish("t1", "donkey") @@ -61,7 +59,8 @@ func TestBulletin(t *testing.T) { b.Publish("t2", "fly to the skyp") time.Sleep(100 * time.Millisecond) - b.Close() + b.Block() + b.UnsubscribeAll() wg.Wait() fmt.Printf ("---------------------\n") @@ -69,3 +68,72 @@ func TestBulletin(t *testing.T) { if nmsgs2 != 5 { t.Errorf("number of messages for s2 received must be 5, but got %d\n", nmsgs2) } } +func TestBulletin2(t *testing.T) { + var b *hodu.Bulletin[string] + var s1 *hodu.BulletinSubscription[string] + var s2 *hodu.BulletinSubscription[string] + var wg sync.WaitGroup + var nmsgs1 int + var nmsgs2 int + + b = hodu.NewBulletin[string](nil, 13) // if the size is too small, some messages are lost + + wg.Add(1) + go b.RunTask(&wg) + + s1, _ = 
b.Subscribe("") + s2, _ = b.Subscribe("") + + wg.Add(1) + go func() { + var m string + var ok bool + var c1 chan string + var c2 chan string + + c1 = s1.C + c2 = s2.C + + defer wg.Done() + for c1 != nil || c2 != nil { + select { + case m, ok = <-c1: + if ok { fmt.Printf ("s1: %+v\n", m); nmsgs1++ } else { c1 = nil; fmt.Printf ("s1 closed\n")} + + case m, ok = <-c2: + if ok { fmt.Printf ("s2: %+v\n", m); nmsgs2++ } else { c2 = nil; fmt.Printf ("s2 closed\n") } + } + } + }() + + b.Enqueue("donkey") + b.Enqueue("monkey") + b.Enqueue("donkey kong") + b.Enqueue("monkey hong") + b.Enqueue("home") + b.Enqueue("fire") + b.Enqueue("sunflower") + b.Enqueue("itsy bitsy spider") + b.Enqueue("marigold") + b.Enqueue("parrot") + b.Enqueue("tiger") + b.Enqueue("walrus") + b.Enqueue("donkey runs") + // without this, unsubscription may happen before s2.C can receive messages + // 100 milliseconds must be more than enough for all messages to be received + time.Sleep(100 * time.Millisecond) + b.Unsubscribe(s2) + b.Enqueue("lion king") + b.Enqueue("fly to the ground") + b.Enqueue("dig it") + b.Enqueue("dig it dawg") + time.Sleep(100 * time.Millisecond) + + b.UnsubscribeAll() + b.ReqStop() + wg.Wait() + fmt.Printf ("---------------------\n") + + if nmsgs1 != 17 { t.Errorf("number of messages for s1 received must be 17, but got %d\n", nmsgs1) } + if nmsgs2 != 13 { t.Errorf("number of messages for s2 received must be 13, but got %d\n", nmsgs2) } +} diff --git a/server-ctl.go b/server-ctl.go index e5c47ad..743699d 100644 --- a/server-ctl.go +++ b/server-ctl.go @@ -3,8 +3,11 @@ package hodu import "encoding/json" import "fmt" import "net/http" +import "sync" import "time" +import "golang.org/x/net/websocket" + type ServerTokenClaim struct { ExpiresAt int64 `json:"exp"` IssuedAt int64 `json:"iat"` @@ -104,6 +107,10 @@ type server_ctl_stats struct { server_ctl } +type server_ctl_ws struct { + server_ctl +} + // ------------------------------------ func (ctl *server_ctl) Id() string { @@ 
-690,3 +697,85 @@ func (ctl *server_ctl_stats) ServeHTTP(w http.ResponseWriter, req *http.Request) oops: return status_code, err } + +// ------------------------------------ +type json_ctl_ws_event struct { + Type string `json:"type"` + Data []string `json:"data"` +} + +func (pxy *server_ctl_ws) send_ws_data(ws *websocket.Conn, type_val string, data string) error { + var msg []byte + var err error + + msg, err = json.Marshal(json_ssh_ws_event{Type: type_val, Data: []string{ data } }) + if err == nil { err = websocket.Message.Send(ws, msg) } + return err +} + +func (ctl *server_ctl_ws) ServeWebsocket(ws *websocket.Conn) (int, error) { + var s *Server + var wg sync.WaitGroup + var sbsc *ServerEventSubscription + var err error + + s = ctl.s + sbsc, err = s.bulletin.Subscribe("") + if err != nil { goto done } + + wg.Add(1) + go func() { + var c chan *ServerEvent + var err error + + defer wg.Done() + c = sbsc.C + + for c != nil { + var e *ServerEvent + var ok bool + + e, ok = <- c + if ok { + fmt.Printf ("s1: %+v\n", e) + err = ctl.send_ws_data(ws, "server", fmt.Sprintf("%d,%d,%d", e.Desc.Conn, e.Desc.Route, e.Desc.Peer)) + if err != nil { + // TODO: logging... + c = nil + } + } else { + // most likely sbsc.C is closed. if not readable, break the loop + c = nil + } + } + + ws.Close() // hack to break the recv loop. don't care about double closes + }() + +ws_recv_loop: + for { + var msg []byte + err = websocket.Message.Receive(ws, &msg) + if err != nil { goto done } + + if len(msg) > 0 { + var ev json_ssh_ws_event + err = json.Unmarshal(msg, &ev) + if err == nil { + switch ev.Type { + case "open": + case "close": + break ws_recv_loop + } + } + } + } + +done: + // Unsubscribe() to break the internal event reception + // goroutine as well as for cleanup + s.bulletin.Unsubscribe(sbsc) + ws.Close() + wg.Wait() + return http.StatusOK, err // TODO: change code... 
+} diff --git a/server.go b/server.go index c4fd798..c2663dd 100644 --- a/server.go +++ b/server.go @@ -92,7 +92,8 @@ type ServerEvent struct { Desc ServerEventDesc } -type ServerEventBulletin = Bulletin[ServerEvent] +type ServerEventBulletin = Bulletin[*ServerEvent] +type ServerEventSubscription = BulletinSubscription[*ServerEvent] type Server struct { UnimplementedHoduServer @@ -118,6 +119,7 @@ type Server struct { wpx_mux *http.ServeMux wpx []*http.Server // proxy server than handles http/https only + ctl_ws *server_ctl_ws ctl_mux *http.ServeMux ctl []*http.Server // control server @@ -877,14 +879,12 @@ func (cts *ServerConn) RunTask(wg *sync.WaitGroup) { done: cts.ReqStop() // just in case cts.route_wg.Wait() - /* - cts.S.bulletin.Publish( - SERVER_EVENT_TOPIC_CONN, - ServerEvent{ + cts.S.bulletin.Enqueue( + &ServerEvent{ Kind: SERVER_EVENT_CONN_DELETED, Desc: ServerEventDesc{ Conn: cts.Id }, }, - )*/ + ) // Don't detached the cts task as a go-routine as this function cts.S.log.Write(cts.Sid, LOG_INFO, "End of connection task") } @@ -942,14 +942,12 @@ func (s *Server) PacketStream(strm Hodu_PacketStreamServer) error { return fmt.Errorf("unable to add client %s - %s", p.Addr.String(), err.Error()) } - /* - s.bulletin.Publish( - SERVER_EVENT_TOPIC_CONN, - ServerEvent{ + s.bulletin.Enqueue( + &ServerEvent{ Kind: SERVER_EVENT_CONN_ADDED, Desc: ServerEventDesc{ Conn: cts.Id }, }, - )*/ + ) // Don't detached the cts task as a go-routine as this function // is invoked as a go-routine by the grpc server. 
@@ -1122,7 +1120,12 @@ type ServerHttpHandler interface { Id() string Cors(req *http.Request) bool Authenticate(req *http.Request) (int, string) - ServeHTTP (w http.ResponseWriter, req *http.Request) (int, error) + ServeHTTP(w http.ResponseWriter, req *http.Request) (int, error) +} + +type ServerWebsocketHandler interface { + Id() string + ServeWebsocket(ws *websocket.Conn) (int, error) } func (s *Server) WrapHttpHandler(handler ServerHttpHandler) http.Handler { @@ -1177,6 +1180,33 @@ func (s *Server) WrapHttpHandler(handler ServerHttpHandler) http.Handler { }) } +func (s *Server) WrapWebsocketHandler(handler ServerWebsocketHandler) websocket.Handler { + return websocket.Handler(func(ws *websocket.Conn) { + var status_code int + var err error + var start_time time.Time + var time_taken time.Duration + var req *http.Request + + req = ws.Request() + start_time = time.Now() + s.log.Write(handler.Id(), LOG_INFO, "[%s] %s %s [ws]", req.RemoteAddr, req.Method, req.URL.String()) + + status_code, err = handler.ServeWebsocket(ws) + // it looks like the websocket handler never comes down here... 
+ + time_taken = time.Now().Sub(start_time) + + if status_code > 0 { + if err != nil { + s.log.Write(handler.Id(), LOG_INFO, "[%s] %s %s [ws] %d %.9f - %s", req.RemoteAddr, req.Method, req.URL.String(), status_code, time_taken.Seconds(), err.Error()) + } else { + s.log.Write(handler.Id(), LOG_INFO, "[%s] %s %s [ws] %d %.9f", req.RemoteAddr, req.Method, req.URL.String(), status_code, time_taken.Seconds()) + } + } + }) +} + func NewServer(ctx context.Context, name string, logger Logger, cfg *ServerConfig) (*Server, error) { var s Server var l *net.TCPListener @@ -1221,15 +1251,7 @@ func NewServer(ctx context.Context, name string, logger Logger, cfg *ServerConfi s.svc_port_map = make(ServerSvcPortMap) s.stop_chan = make(chan bool, 8) s.stop_req.Store(false) - s.bulletin = NewBulletin[ServerEvent](1000) - -/* - creds, err := credentials.NewServerTLSFromFile(data.Path("x509/server_cert.pem"), data.Path("x509/server_key.pem")) - if err != nil { - log.Fatalf("failed to create credentials: %v", err) - } - gs = grpc.NewServer(grpc.Creds(creds)) -*/ + s.bulletin = NewBulletin[*ServerEvent](&s, 1024) opts = append(opts, grpc.StatsHandler(&ConnCatcher{server: &s})) if s.Cfg.RpcTls != nil { opts = append(opts, grpc.Creds(credentials.NewTLS(s.Cfg.RpcTls))) } @@ -1274,6 +1296,10 @@ func NewServer(ctx context.Context, name string, logger Logger, cfg *ServerConfi s.ctl_mux.Handle(s.Cfg.CtlPrefix + "/_ctl/metrics", promhttp.HandlerFor(s.promreg, promhttp.HandlerOpts{ EnableOpenMetrics: true })) + //s.ctl_ws = &server_ctl_ws{server_ctl{s: &s, id: HS_ID_CTL}} + s.ctl_mux.Handle("/_ctl/events", + s.WrapWebsocketHandler(&server_ctl_ws{server_ctl{s: &s, id: HS_ID_CTL}})) + s.ctl = make([]*http.Server, len(cfg.CtlAddrs)) for i = 0; i < len(cfg.CtlAddrs); i++ { s.ctl[i] = &http.Server{ @@ -1450,7 +1476,6 @@ task_loop: } } - s.bulletin.Close() s.ReqStop() s.rpc_wg.Wait() @@ -1601,6 +1626,8 @@ func (s *Server) ReqStop() { var cts *ServerConn var hs *http.Server + s.bulletin.Block() + // call 
cancellation function before anything else // to break sub-tasks relying on this server context. // for example, http.Client in server_proxy_http_main @@ -1949,6 +1976,8 @@ func (s *Server) FindServerConnByIdStr(conn_id string) (*ServerConn, error) { } func (s *Server) StartService(cfg interface{}) { + s.wg.Add(1) + go s.bulletin.RunTask(&s.wg) s.wg.Add(1) go s.RunTask(&s.wg) } @@ -1984,6 +2013,7 @@ func (s *Server) StartWpxService() { func (s *Server) StopServices() { var ext_svc Service s.ReqStop() + s.bulletin.ReqStop() s.ext_mtx.Lock() for _, ext_svc = range s.ext_svcs { ext_svc.StopServices()