diff --git a/bulletin.go b/bulletin.go index 25c39b2..cd48b89 100644 --- a/bulletin.go +++ b/bulletin.go @@ -1,84 +1,136 @@ package hodu import "container/list" +import "errors" import "sync" -type BulletinChan = chan interface{} - -type BulletinSubscription struct { - c chan interface{} - b *Bulletin +type BulletinSubscription[T interface{}] struct { + C chan T + b *Bulletin[T] topic string node *list.Element } type BulletinSubscriptionList = *list.List + type BulletinSubscriptionMap map[string]BulletinSubscriptionList -type Bulletin struct { +type Bulletin[T interface{}] struct { sbsc_map BulletinSubscriptionMap sbsc_mtx sync.RWMutex + closed bool } -func NewBulletin() *Bulletin { - return &Bulletin{ - sbsc_map: make(BulletinSubscriptionMap, 0), +func NewBulletin[T interface{}]() *Bulletin[T] { + return &Bulletin[T]{ + sbsc_map: make(BulletinSubscriptionMap, 0), } } -func (b *Bulletin) Subscribe(topic string) *BulletinSubscription { - var sbsc BulletinSubscription +func (b *Bulletin[T]) unsubscribe_all_nolock() { + var topic string + var sl BulletinSubscriptionList + + for topic, sl = range b.sbsc_map { + var sbsc *BulletinSubscription[T] + var e *list.Element + + for e = sl.Front(); e != nil; e = e.Next() { + sbsc = e.Value.(*BulletinSubscription[T]) + close(sbsc.C) + sbsc.b = nil + sbsc.node = nil + } + + delete(b.sbsc_map, topic) + } + + b.closed = true +} + +func (b *Bulletin[T]) UnsubscribeAll() { + b.sbsc_mtx.Lock() + b.unsubscribe_all_nolock() + b.sbsc_mtx.Unlock() +} + +func (b *Bulletin[T]) Close() { + b.sbsc_mtx.Lock() + if !b.closed { + b.unsubscribe_all_nolock() + b.closed = true + } + b.sbsc_mtx.Unlock() +} + +func (b *Bulletin[T]) Subscribe(topic string) (*BulletinSubscription[T], error) { + var sbsc BulletinSubscription[T] var sbsc_list BulletinSubscriptionList var ok bool - sbsc.b = b - sbsc.c = make(chan interface{}) - sbsc.topic = topic - b.sbsc_mtx.Lock() + if b.closed { return nil, errors.New("closed bulletin") } + sbsc.C = make(chan T, 128) // TODO: size? + sbsc.b = b + sbsc.topic = topic + + b.sbsc_mtx.Lock() sbsc_list, ok = b.sbsc_map[topic] if !ok { sbsc_list = list.New() b.sbsc_map[topic] = sbsc_list } - - sbsc.node = sbsc_list.PushBack(&sbsc) b.sbsc_mtx.Unlock() - return &sbsc + return &sbsc, nil } -func (b *Bulletin) Unsbsccribe(sbsc *BulletinSubscription) { - if sbsc.b == b { +func (b *Bulletin[T]) Unsubscribe(sbsc *BulletinSubscription[T]) { + if sbsc.b == b && sbsc.node != nil { var sl BulletinSubscriptionList var ok bool b.sbsc_mtx.Lock() sl, ok = b.sbsc_map[sbsc.topic] - if ok { sl.Remove(sbsc.node) } + if ok { + sl.Remove(sbsc.node) + close(sbsc.C) + sbsc.node = nil + sbsc.b = nil + } b.sbsc_mtx.Unlock() } } -func (b *Bulletin) Publish(topic string, data interface{}) { +func (b *Bulletin[T]) Publish(topic string, data T) { var sl BulletinSubscriptionList var ok bool + var topics [2]string + var t string + + if b.closed { return } + if topic == "" { return } + + topics[0] = topic + topics[1] = "" b.sbsc_mtx.Lock() - sl, ok = b.sbsc_map[topic] - if ok { - var sbsc *BulletinSubscription - var e *list.Element - for e = sl.Front(); e != nil; e = e.Next() { - sbsc = e.Value.(*BulletinSubscription) - sbsc.c <- data + for _, t = range topics { + sl, ok = b.sbsc_map[t] + if ok { + var sbsc *BulletinSubscription[T] + var e *list.Element + for e = sl.Front(); e != nil; e = e.Next() { + sbsc = e.Value.(*BulletinSubscription[T]) + select { + case sbsc.C <- data: + // ok. could be written. + default: + // channel full. discard it + } + } } } b.sbsc_mtx.Unlock() } -func (s *BulletinSubscription) Receive() interface{} { - var x interface{} - x = <- s.c - return x -} diff --git a/bulletin_test.go b/bulletin_test.go index 1dfa7fa..7c89ed7 100644 --- a/bulletin_test.go +++ b/bulletin_test.go @@ -2,26 +2,70 @@ package hodu_test import "fmt" import "hodu" +import "sync" import "testing" +import "time" func TestBulletin(t *testing.T) { - var b *hodu.Bulletin - var s1 *hodu.BulletinSubscription - var s2 *hodu.BulletinSubscription + var b *hodu.Bulletin[string] + var s1 *hodu.BulletinSubscription[string] + var s2 *hodu.BulletinSubscription[string] + var wg sync.WaitGroup + var nmsgs1 int + var nmsgs2 int - b = hodu.NewBulletin() + b = hodu.NewBulletin[string]() - s1 = b.Subscribe("t1") - s2 = b.Subscribe("t2") + s1, _ = b.Subscribe("t1") + s2, _ = b.Subscribe("t2") + wg.Add(1) go func() { - fmt.Printf ("s1: %+v\n", s1.Receive()) - }() + var m string + var ok bool + var c1 chan string + var c2 chan string + + c1 = s1.C + c2 = s2.C + + defer wg.Done() + for c1 != nil || c2 != nil { + select { + case m, ok = <-c1: + if ok { fmt.Printf ("s1: %+v\n", m); nmsgs1++ } else { c1 = nil; fmt.Printf ("s1 closed\n")} + + case m, ok = <-c2: + if ok { fmt.Printf ("s2: %+v\n", m); nmsgs2++ } else { c2 = nil; fmt.Printf ("s2 closed\n") } + } + + } - go func() { - fmt.Printf ("s2: %+v\n", s2.Receive()) }() b.Publish("t1", "donkey") + b.Publish("t2", "monkey") + b.Publish("t1", "donkey kong") + b.Publish("t2", "monkey hong") + b.Publish("t3", "home") + b.Publish("t2", "fire") + b.Publish("t1", "sunflower") + b.Publish("t2", "itsy bitsy spider") + b.Publish("t3", "marigold") + b.Publish("t3", "parrot") + time.Sleep(100 * time.Millisecond) + b.Publish("t2", "tiger") + time.Sleep(100 * time.Millisecond) + b.Unsubscribe(s2) + b.Publish("t2", "lion king") + b.Publish("t2", "fly to the skyp") + time.Sleep(100 * time.Millisecond) + + b.Close() + wg.Wait() + fmt.Printf ("---------------------\n") + + if nmsgs1 != 3 { t.Errorf("number of messages for s1 received must be 3, but got %d\n", nmsgs1) } + if nmsgs2 != 5 { t.Errorf("number of messages for s2 received must be 5, but got %d\n", nmsgs2) } } diff --git a/server-proxy.go b/server-proxy.go index 3c88b59..2e126ff 100644 --- a/server-proxy.go +++ b/server-proxy.go @@ -535,9 +535,13 @@ func (pxy *server_proxy_xterm_file) ServeHTTP(w http.ResponseWriter, req *http.R w.Header().Set("Location", req.URL.Path + "_/") w.WriteHeader(status_code) + case "_forbidden": status_code = WriteEmptyRespHeader(w, http.StatusForbidden) + case "_notfound": + status_code = WriteEmptyRespHeader(w, http.StatusNotFound) + default: status_code = WriteEmptyRespHeader(w, http.StatusNotFound) } diff --git a/server.go b/server.go index f395b55..96bd2d8 100644 --- a/server.go +++ b/server.go @@ -67,6 +67,33 @@ type ServerConfig struct { WpxTls *tls.Config } +const SERVER_EVENT_TOPIC_CONN string = "conn" +const SERVER_EVENT_TOPIC_ROUTE string = "route" +const SERVER_EVENT_TOPIC_PEER string = "peer" + +type ServerEventKind int +const ( + SERVER_EVENT_CONN_ADDED = iota + SERVER_EVENT_CONN_DELETED + SERVER_EVENT_ROUTE_ADDED + SERVER_EVENT_ROUTE_DELETED + SERVER_EVENT_PEER_ADDED + SERVER_EVENT_PEER_DELETED +) + +type ServerEventDesc struct { + Conn ConnId + Route RouteId + Peer PeerId +} + +type ServerEvent struct { + Kind ServerEventKind + Desc ServerEventDesc +} + +type ServerEventBulletin = Bulletin[ServerEvent] + type Server struct { UnimplementedHoduServer Named @@ -112,6 +139,8 @@ type Server struct { svc_port_mtx sync.Mutex svc_port_map ServerSvcPortMap + bulletin *ServerEventBulletin + promreg *prometheus.Registry stats struct { conns atomic.Int64 @@ -847,6 +876,15 @@ func (cts *ServerConn) RunTask(wg *sync.WaitGroup) { done: cts.ReqStop() // just in case cts.route_wg.Wait() + /* + cts.S.bulletin.Publish( + SERVER_EVENT_TOPIC_CONN, + ServerEvent{ + Kind: SERVER_EVENT_CONN_DELETED, + Desc: ServerEventDesc{ Conn: cts.Id }, + }, + )*/ + // Don't detached the cts task as a go-routine as this function cts.S.log.Write(cts.Sid, LOG_INFO, "End of connection task") } @@ -903,6 +941,15 @@ func (s *Server) PacketStream(strm Hodu_PacketStreamServer) error { return fmt.Errorf("unable to add client %s - %s", p.Addr.String(), err.Error()) } + /* + s.bulletin.Publish( + SERVER_EVENT_TOPIC_CONN, + ServerEvent{ + Kind: SERVER_EVENT_CONN_ADDED, + Desc: ServerEventDesc{ Conn: cts.Id }, + }, + )*/ + // Don't detached the cts task as a go-routine as this function // is invoked as a go-routine by the grpc server. s.cts_wg.Add(1) @@ -1173,6 +1220,7 @@ func NewServer(ctx context.Context, name string, logger Logger, cfg *ServerConfi s.svc_port_map = make(ServerSvcPortMap) s.stop_chan = make(chan bool, 8) s.stop_req.Store(false) + s.bulletin = NewBulletin[ServerEvent]() /* creds, err := credentials.NewServerTLSFromFile(data.Path("x509/server_cert.pem"), data.Path("x509/server_key.pem")) @@ -1244,16 +1292,20 @@ func NewServer(ctx context.Context, name string, logger Logger, cfg *ServerConfi websocket.Handler(func(ws *websocket.Conn) { s.pxy_ws.ServeWebsocket(ws) })) s.pxy_mux.Handle("/_ssh/server-conns/{conn_id}/routes/{route_id}", s.WrapHttpHandler(&server_ctl_server_conns_id_routes_id{server_ctl{s: &s, id: HS_ID_PXY, noauth: true}})) + s.pxy_mux.Handle("/_ssh/xterm.js", + s.WrapHttpHandler(&server_proxy_xterm_file{server_proxy: server_proxy{s: &s, id: HS_ID_PXY}, file: "xterm.js"})) + s.pxy_mux.Handle("/_ssh/xterm.js.map", + s.WrapHttpHandler(&server_proxy_xterm_file{server_proxy: server_proxy{s: &s, id: HS_ID_PXY}, file: "_notfound"})) + s.pxy_mux.Handle("/_ssh/xterm-addon-fit.js", + s.WrapHttpHandler(&server_proxy_xterm_file{server_proxy: server_proxy{s: &s, id: HS_ID_PXY}, file: "xterm-addon-fit.js"})) + s.pxy_mux.Handle("/_ssh/xterm-addon-fit.js.map", + s.WrapHttpHandler(&server_proxy_xterm_file{server_proxy: server_proxy{s: &s, id: HS_ID_PXY}, file: "_notfound"})) + s.pxy_mux.Handle("/_ssh/xterm.css", + s.WrapHttpHandler(&server_proxy_xterm_file{server_proxy: server_proxy{s: &s, id: HS_ID_PXY}, file: "xterm.css"})) s.pxy_mux.Handle("/_ssh/{conn_id}/", s.WrapHttpHandler(&server_proxy_xterm_file{server_proxy: server_proxy{s: &s, id: HS_ID_PXY}, file: "_redirect"})) s.pxy_mux.Handle("/_ssh/{conn_id}/{route_id}/", s.WrapHttpHandler(&server_proxy_xterm_file{server_proxy: server_proxy{s: &s, id: HS_ID_PXY}, file: "xterm.html"})) - s.pxy_mux.Handle("/_ssh/xterm.js", - s.WrapHttpHandler(&server_proxy_xterm_file{server_proxy: server_proxy{s: &s, id: HS_ID_PXY}, file: "xterm.js"})) - s.pxy_mux.Handle("/_ssh/xterm-addon-fit.js", - s.WrapHttpHandler(&server_proxy_xterm_file{server_proxy: server_proxy{s: &s, id: HS_ID_PXY}, file: "xterm-addon-fit.js"})) - s.pxy_mux.Handle("/_ssh/xterm.css", - s.WrapHttpHandler(&server_proxy_xterm_file{server_proxy: server_proxy{s: &s, id: HS_ID_PXY}, file: "xterm.css"})) s.pxy_mux.Handle("/_ssh/", s.WrapHttpHandler(&server_proxy_xterm_file{server_proxy: server_proxy{s: &s, id: HS_ID_PXY}, file: "_forbidden"})) s.pxy_mux.Handle("/favicon.ico", @@ -1397,6 +1449,7 @@ task_loop: } } + s.bulletin.Close() s.ReqStop() s.rpc_wg.Wait()