From 05cb0823b427106f599df7e29f1db17f66905a77 Mon Sep 17 00:00:00 2001 From: hyung-hwan Date: Sun, 10 Aug 2025 17:23:01 +0900 Subject: [PATCH] some code clean-up in handling grpc packets --- client.go | 137 +++++++++++++-------------------- hodu.pb.go | 40 +++++----- hodu.proto | 4 +- packet.go | 12 +-- server-pty.go | 208 +++++++++++++++++++++++++++++++++++++++++++++++++- server.go | 138 ++++++++++++++++++++------------- 6 files changed, 374 insertions(+), 165 deletions(-) diff --git a/client.go b/client.go index 04ccf12..16f68d0 100644 --- a/client.go +++ b/client.go @@ -790,6 +790,7 @@ func (r *ClientRoute) ReportPacket(pts_id PeerId, packet_type PACKET_KIND, event } } + default: // ignore all others } @@ -1200,6 +1201,8 @@ start_over: switch pkt.Kind { case PACKET_KIND_ROUTE_STARTED: + fallthrough + case PACKET_KIND_ROUTE_STOPPED: // the server side managed to set up the route the client requested var x *Packet_Route var ok bool @@ -1208,96 +1211,41 @@ start_over: err = cts.ReportPacket(RouteId(x.Route.RouteId), 0, pkt.Kind, x.Route) if err != nil { cts.C.log.Write(cts.Sid, LOG_ERROR, - "Failed to handle route_started event(%d,%s) from %s - %s", - x.Route.RouteId, x.Route.TargetAddrStr, cts.remote_addr_p, err.Error()) + "Failed to handle %s event(%d,%s) from %s - %s", + pkt.Kind.String(), x.Route.RouteId, x.Route.TargetAddrStr, cts.remote_addr_p, err.Error()) } else { cts.C.log.Write(cts.Sid, LOG_DEBUG, - "Handled route_started event(%d,%s) from %s", - x.Route.RouteId, x.Route.TargetAddrStr, cts.remote_addr_p) + "Handled %s event(%d,%s) from %s", + pkt.Kind.String(), x.Route.RouteId, x.Route.TargetAddrStr, cts.remote_addr_p) } } else { - cts.C.log.Write(cts.Sid, LOG_ERROR, "Invalid route_started event from %s", cts.remote_addr_p) - } - - case PACKET_KIND_ROUTE_STOPPED: - var x *Packet_Route - var ok bool - x, ok = pkt.U.(*Packet_Route) - if ok { - err = cts.ReportPacket(RouteId(x.Route.RouteId), 0, pkt.Kind, x.Route) - if err != nil { - cts.C.log.Write(cts.Sid, LOG_ERROR, - "Failed to handle route_stopped event(%d,%s) from %s - %s", - x.Route.RouteId, x.Route.TargetAddrStr, cts.remote_addr_p, err.Error()) - } else { - cts.C.log.Write(cts.Sid, LOG_DEBUG, - "Handled route_stopped event(%d,%s) from %s", - x.Route.RouteId, x.Route.TargetAddrStr, cts.remote_addr_p) - } - } else { - cts.C.log.Write(cts.Sid, LOG_ERROR, "Invalid route_stopped event from %s", cts.remote_addr_p) - } - - case PACKET_KIND_PEER_STARTED: - // the connection from the client to a peer has been established - var x *Packet_Peer - var ok bool - x, ok = pkt.U.(*Packet_Peer) - if ok { - err = cts.ReportPacket(RouteId(x.Peer.RouteId), PeerId(x.Peer.PeerId), PACKET_KIND_PEER_STARTED, x.Peer) - if err != nil { - cts.C.log.Write(cts.Sid, LOG_ERROR, - "Failed to handle peer_started event from %s for peer(%d,%d,%s,%s) - %s", - cts.remote_addr_p, x.Peer.RouteId, x.Peer.PeerId, x.Peer.LocalAddrStr, x.Peer.RemoteAddrStr, err.Error()) - } else { - cts.C.log.Write(cts.Sid, LOG_DEBUG, - "Handled peer_started event from %s for peer(%d,%d,%s,%s)", - cts.remote_addr_p, x.Peer.RouteId, x.Peer.PeerId, x.Peer.LocalAddrStr, x.Peer.RemoteAddrStr) - } - } else { - cts.C.log.Write(cts.Sid, LOG_ERROR, "Invalid peer_started event from %s", cts.remote_addr_p) + cts.C.log.Write(cts.Sid, LOG_ERROR, "Invalid %s event from %s", pkt.Kind.String(), cts.remote_addr_p) } // PACKET_KIND_PEER_ABORTED is never sent by server to client. // the code here doesn't handle the event. - + case PACKET_KIND_PEER_STARTED: + fallthrough case PACKET_KIND_PEER_STOPPED: + fallthrough + case PACKET_KIND_PEER_EOF: // the connection from the client to a peer has been established var x *Packet_Peer var ok bool x, ok = pkt.U.(*Packet_Peer) if ok { - err = cts.ReportPacket(RouteId(x.Peer.RouteId), PeerId(x.Peer.PeerId), PACKET_KIND_PEER_STOPPED, x.Peer) + err = cts.ReportPacket(RouteId(x.Peer.RouteId), PeerId(x.Peer.PeerId), pkt.Kind, x.Peer) if err != nil { cts.C.log.Write(cts.Sid, LOG_ERROR, - "Failed to handle peer_stopped event from %s for peer(%d,%d,%s,%s) - %s", - cts.remote_addr_p, x.Peer.RouteId, x.Peer.PeerId, x.Peer.LocalAddrStr, x.Peer.RemoteAddrStr, err.Error()) + "Failed to handle %s event from %s for peer(%d,%d,%s,%s) - %s", + pkt.Kind.String(), cts.remote_addr_p, x.Peer.RouteId, x.Peer.PeerId, x.Peer.LocalAddrStr, x.Peer.RemoteAddrStr, err.Error()) } else { cts.C.log.Write(cts.Sid, LOG_DEBUG, - "Handled peer_stopped event from %s for peer(%d,%d,%s,%s)", - cts.remote_addr_p, x.Peer.RouteId, x.Peer.PeerId, x.Peer.LocalAddrStr, x.Peer.RemoteAddrStr) + "Handled %s event from %s for peer(%d,%d,%s,%s)", + pkt.Kind.String(), cts.remote_addr_p, x.Peer.RouteId, x.Peer.PeerId, x.Peer.LocalAddrStr, x.Peer.RemoteAddrStr) } } else { - cts.C.log.Write(cts.Sid, LOG_ERROR, "Invalid peer_stopped event from %s", cts.remote_addr_p) - } - - case PACKET_KIND_PEER_EOF: - var x *Packet_Peer - var ok bool - x, ok = pkt.U.(*Packet_Peer) - if ok { - err = cts.ReportPacket(RouteId(x.Peer.RouteId), PeerId(x.Peer.PeerId), PACKET_KIND_PEER_EOF, x.Peer) - if err != nil { - cts.C.log.Write(cts.Sid, LOG_ERROR, - "Failed to handle peer_eof event from %s for peer(%d,%d,%s,%s) - %s", - cts.remote_addr_p, x.Peer.RouteId, x.Peer.PeerId, x.Peer.LocalAddrStr, x.Peer.RemoteAddrStr, err.Error()) - } else { - cts.C.log.Write(cts.Sid, LOG_DEBUG, - "Handled peer_eof event from %s for peer(%d,%d,%s,%s)", - cts.remote_addr_p, x.Peer.RouteId, x.Peer.PeerId, x.Peer.LocalAddrStr, x.Peer.RemoteAddrStr) - } - } else { - cts.C.log.Write(cts.Sid, LOG_ERROR, "Invalid peer_eof event from %s", cts.remote_addr_p) + cts.C.log.Write(cts.Sid, LOG_ERROR, "Invalid %s event from %s", pkt.Kind.String(), cts.remote_addr_p) } case PACKET_KIND_PEER_DATA: @@ -1306,18 +1254,18 @@ start_over: var ok bool x, ok = pkt.U.(*Packet_Data) if ok { - err = cts.ReportPacket(RouteId(x.Data.RouteId), PeerId(x.Data.PeerId), PACKET_KIND_PEER_DATA, x.Data.Data) + err = cts.ReportPacket(RouteId(x.Data.RouteId), PeerId(x.Data.PeerId), pkt.Kind, x.Data.Data) if err != nil { cts.C.log.Write(cts.Sid, LOG_ERROR, - "Failed to handle peer_data event from %s for peer(%d,%d) - %s", - cts.remote_addr_p, x.Data.RouteId, x.Data.PeerId, err.Error()) + "Failed to handle %s event from %s for peer(%d,%d) - %s", + pkt.Kind.String(), cts.remote_addr_p, x.Data.RouteId, x.Data.PeerId, err.Error()) } else { cts.C.log.Write(cts.Sid, LOG_DEBUG, - "Handled peer_data event from %s for peer(%d,%d)", - cts.remote_addr_p, x.Data.RouteId, x.Data.PeerId) + "Handled %s event from %s for peer(%d,%d)", + pkt.Kind.String(), cts.remote_addr_p, x.Data.RouteId, x.Data.PeerId) } } else { - cts.C.log.Write(cts.Sid, LOG_ERROR, "Invalid peer_data event from %s", cts.remote_addr_p) + cts.C.log.Write(cts.Sid, LOG_ERROR, "Invalid %s event from %s", pkt.Kind.String(), cts.remote_addr_p) } case PACKET_KIND_CONN_ERROR: @@ -1325,10 +1273,10 @@ start_over: var ok bool x, ok = pkt.U.(*Packet_ConnErr) if ok { - cts.C.log.Write(cts.Sid, LOG_ERROR, "Received conn_error(%d, %s) event from %s", x.ConnErr.ErrorId, x.ConnErr.Text, cts.remote_addr_p) + cts.C.log.Write(cts.Sid, LOG_ERROR, "Received %s(%d, %s) event from %s", pkt.Kind.String(), x.ConnErr.ErrorId, x.ConnErr.Text, cts.remote_addr_p) if cts.cfg.CloseOnConnErrorEvent { goto done } } else { - cts.C.log.Write(cts.Sid, LOG_ERROR, "Invalid conn_error event from %s", cts.remote_addr_p) + cts.C.log.Write(cts.Sid, LOG_ERROR, "Invalid %sevent from %s", pkt.Kind.String(), cts.remote_addr_p) } case PACKET_KIND_CONN_NOTICE: @@ -1337,7 +1285,7 @@ start_over: var ok bool x, ok = pkt.U.(*Packet_ConnNoti) if ok { - cts.C.log.Write(cts.Sid, LOG_DEBUG, "conn_notice message '%s' received from %s", x.ConnNoti.Text, cts.remote_addr_p) + cts.C.log.Write(cts.Sid, LOG_DEBUG, "%s message '%s' received from %s", pkt.Kind.String(), x.ConnNoti.Text, cts.remote_addr_p) if cts.C.conn_notice_handlers != nil { var handler ClientConnNoticeHandler for _, handler = range cts.C.conn_notice_handlers { @@ -1345,18 +1293,34 @@ start_over: } } } else { - cts.C.log.Write(cts.Sid, LOG_ERROR, "Invalid conn_notice packet from %s", cts.remote_addr_p) + cts.C.log.Write(cts.Sid, LOG_ERROR, "Invalid %s packet from %s", pkt.Kind.String(), cts.remote_addr_p) } case PACKET_KIND_RPTY_START: - // TODO: + fallthrough case PACKET_KIND_RPTY_STOP: - // TODO: + fallthrough case PACKET_KIND_RPTY_DATA: - // TODO: + fallthrough case PACKET_KIND_RPTY_EOF: - // TODO: + var x *Packet_RptyEvt + var ok bool + x, ok = pkt.U.(*Packet_RptyEvt) + if ok { + err = cts.HandleRptyEvent(pkt.Kind, x.RptyEvt) + if err != nil { + cts.C.log.Write(cts.Sid, LOG_ERROR, + "Failed to handle %s event from %s - %s", + pkt.Kind.String(), cts.remote_addr_p, err.Error()) + } else { + cts.C.log.Write(cts.Sid, LOG_DEBUG, + "Handled %s event from %s", + pkt.Kind.String(), cts.remote_addr_p) + } + } else { + cts.C.log.Write(cts.Sid, LOG_ERROR, "Invalid %s event from %s", pkt.Kind.String(), cts.remote_addr_p) + } default: // do nothing. ignore the rest @@ -1430,6 +1394,11 @@ func (cts *ClientConn) ReportPacket(route_id RouteId, pts_id PeerId, packet_type return r.ReportPacket(pts_id, packet_type, event_data) } +func (cts *ClientConn) HandleRptyEvent(packet_type PACKET_KIND, evt *RptyEvent) error { +// TODO: + return nil +} + // -------------------------------------------------------------------- func (m ClientPeerConnMap) get_sorted_keys() []PeerId { diff --git a/hodu.pb.go b/hodu.pb.go index 0b61889..c89f329 100644 --- a/hodu.pb.go +++ b/hodu.pb.go @@ -605,7 +605,7 @@ func (x *ConnNotice) GetText() string { type RptyEvent struct { state protoimpl.MessageState `protogen:"open.v1"` - Token string `protobuf:"bytes,1,opt,name=Token,proto3" json:"Token,omitempty"` + Id uint64 `protobuf:"varint,1,opt,name=Id,proto3" json:"Id,omitempty"` Data []byte `protobuf:"bytes,2,opt,name=Data,proto3" json:"Data,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache @@ -641,11 +641,11 @@ func (*RptyEvent) Descriptor() ([]byte, []int) { return file_hodu_proto_rawDescGZIP(), []int{7} } -func (x *RptyEvent) GetToken() string { +func (x *RptyEvent) GetId() uint64 { if x != nil { - return x.Token + return x.Id } - return "" + return 0 } func (x *RptyEvent) GetData() []byte { @@ -666,7 +666,7 @@ type Packet struct { // *Packet_Conn // *Packet_ConnErr // *Packet_ConnNoti - // *Packet_Rpty + // *Packet_RptyEvt U isPacket_U `protobuf_oneof:"U"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache @@ -770,10 +770,10 @@ func (x *Packet) GetConnNoti() *ConnNotice { return nil } -func (x *Packet) GetRpty() *RptyEvent { +func (x *Packet) GetRptyEvt() *RptyEvent { if x != nil { - if x, ok := x.U.(*Packet_Rpty); ok { - return x.Rpty + if x, ok := x.U.(*Packet_RptyEvt); ok { + return x.RptyEvt } } return nil @@ -807,8 +807,8 @@ type Packet_ConnNoti struct { ConnNoti *ConnNotice `protobuf:"bytes,7,opt,name=ConnNoti,proto3,oneof"` } -type Packet_Rpty struct { - Rpty *RptyEvent `protobuf:"bytes,8,opt,name=Rpty,proto3,oneof"` +type Packet_RptyEvt struct { + RptyEvt *RptyEvent `protobuf:"bytes,8,opt,name=RptyEvt,proto3,oneof"` } func (*Packet_Route) isPacket_U() {} @@ -823,7 +823,7 @@ func (*Packet_ConnErr) isPacket_U() {} func (*Packet_ConnNoti) isPacket_U() {} -func (*Packet_Rpty) isPacket_U() {} +func (*Packet_RptyEvt) isPacket_U() {} var File_hodu_proto protoreflect.FileDescriptor @@ -859,10 +859,10 @@ const file_hodu_proto_rawDesc = "" + "\x04Text\x18\x02 \x01(\tR\x04Text\" \n" + "\n" + "ConnNotice\x12\x12\n" + - "\x04Text\x18\x01 \x01(\tR\x04Text\"5\n" + - "\tRptyEvent\x12\x14\n" + - "\x05Token\x18\x01 \x01(\tR\x05Token\x12\x12\n" + - "\x04Data\x18\x02 \x01(\fR\x04Data\"\xab\x02\n" + + "\x04Text\x18\x01 \x01(\tR\x04Text\"/\n" + + "\tRptyEvent\x12\x0e\n" + + "\x02Id\x18\x01 \x01(\x04R\x02Id\x12\x12\n" + + "\x04Data\x18\x02 \x01(\fR\x04Data\"\xb1\x02\n" + "\x06Packet\x12 \n" + "\x04Kind\x18\x01 \x01(\x0e2\f.PACKET_KINDR\x04Kind\x12\"\n" + "\x05Route\x18\x02 \x01(\v2\n" + @@ -872,9 +872,9 @@ const file_hodu_proto_rawDesc = "" + "\x04Conn\x18\x05 \x01(\v2\t.ConnDescH\x00R\x04Conn\x12&\n" + "\aConnErr\x18\x06 \x01(\v2\n" + ".ConnErrorH\x00R\aConnErr\x12)\n" + - "\bConnNoti\x18\a \x01(\v2\v.ConnNoticeH\x00R\bConnNoti\x12 \n" + - "\x04Rpty\x18\b \x01(\v2\n" + - ".RptyEventH\x00R\x04RptyB\x03\n" + + "\bConnNoti\x18\a \x01(\v2\v.ConnNoticeH\x00R\bConnNoti\x12&\n" + + "\aRptyEvt\x18\b \x01(\v2\n" + + ".RptyEventH\x00R\aRptyEvtB\x03\n" + "\x01U*^\n" + "\fROUTE_OPTION\x12\n" + "\n" + @@ -949,7 +949,7 @@ var file_hodu_proto_depIdxs = []int32{ 6, // 4: Packet.Conn:type_name -> ConnDesc 7, // 5: Packet.ConnErr:type_name -> ConnError 8, // 6: Packet.ConnNoti:type_name -> ConnNotice - 9, // 7: Packet.Rpty:type_name -> RptyEvent + 9, // 7: Packet.RptyEvt:type_name -> RptyEvent 2, // 8: Hodu.GetSeed:input_type -> Seed 10, // 9: Hodu.PacketStream:input_type -> Packet 2, // 10: Hodu.GetSeed:output_type -> Seed @@ -973,7 +973,7 @@ func file_hodu_proto_init() { (*Packet_Conn)(nil), (*Packet_ConnErr)(nil), (*Packet_ConnNoti)(nil), - (*Packet_Rpty)(nil), + (*Packet_RptyEvt)(nil), } type x struct{} out := protoimpl.TypeBuilder{ diff --git a/hodu.proto b/hodu.proto index d116e35..c4d33c1 100644 --- a/hodu.proto +++ b/hodu.proto @@ -82,7 +82,7 @@ message ConnNotice { }; message RptyEvent { - string Token = 1; + uint64 Id = 1; bytes Data = 2; }; @@ -120,6 +120,6 @@ message Packet { ConnDesc Conn = 5; ConnError ConnErr = 6; ConnNotice ConnNoti = 7; - RptyEvent Rpty = 8; + RptyEvent RptyEvt = 8; }; } diff --git a/packet.go b/packet.go index 2db66c0..c658091 100644 --- a/packet.go +++ b/packet.go @@ -75,14 +75,14 @@ func MakeConnNoticePacket(msg string) *Packet { return &Packet{Kind: PACKET_KIND_CONN_NOTICE, U: &Packet_ConnNoti{ConnNoti: &ConnNotice{Text: msg}}} } -func MakeRptyStartPacket(token string) *Packet { - return &Packet{Kind: PACKET_KIND_RPTY_START, U: &Packet_Rpty{Rpty: &RptyEvent{Token: token}}} +func MakeRptyStartPacket(id uint64) *Packet { + return &Packet{Kind: PACKET_KIND_RPTY_START, U: &Packet_RptyEvt{RptyEvt: &RptyEvent{Id: id}}} } -func MakeRptyStopPacket(token string) *Packet { - return &Packet{Kind: PACKET_KIND_RPTY_START, U: &Packet_Rpty{Rpty: &RptyEvent{Token: token}}} +func MakeRptyStopPacket(id uint64) *Packet { + return &Packet{Kind: PACKET_KIND_RPTY_START, U: &Packet_RptyEvt{RptyEvt: &RptyEvent{Id: id}}} } -func MakeRptyDataPacket(token string, data []byte) *Packet { - return &Packet{Kind: PACKET_KIND_RPTY_START, U: &Packet_Rpty{Rpty: &RptyEvent{Token: token, Data: data}}} +func MakeRptyDataPacket(id uint64, data []byte) *Packet { + return &Packet{Kind: PACKET_KIND_RPTY_START, U: &Packet_RptyEvt{RptyEvt: &RptyEvent{Id: id, Data: data}}} } diff --git a/server-pty.go b/server-pty.go index 422c21f..32380f6 100644 --- a/server-pty.go +++ b/server-pty.go @@ -207,7 +207,7 @@ ws_recv_loop: } else { err = pty.send_ws_data(ws, "status", "opened") if err != nil { - s.log.Write(pty.Id, LOG_ERROR, "[%s] Failed to write opened event to websocket - %s", req.RemoteAddr, err.Error()) + s.log.Write(pty.Id, LOG_ERROR, "[%s] Failed to write 'opened' event to websocket - %s", req.RemoteAddr, err.Error()) ws.Close() // dirty way to flag out the error } else { s.log.Write(pty.Id, LOG_DEBUG, "[%s] Opened pty session", req.RemoteAddr) @@ -270,6 +270,212 @@ done: return http.StatusOK, err } + +// ------------------------------------------------------ +func (rpty *server_rpty_ws) Identity() string { + return rpty.Id +} + +func (rpty *server_rpty_ws) send_ws_data(ws *websocket.Conn, type_val string, data string) error { + var msg []byte + var err error + + msg, err = json.Marshal(json_xterm_ws_event{Type: type_val, Data: []string{ data } }) + if err == nil { err = websocket.Message.Send(ws, msg) } + return err +} + +func (rpty *server_rpty_ws) ServeWebsocket(ws *websocket.Conn) (int, error) { + var s *Server + var req *http.Request + var token string + var cts *ServerConn + //var username string + //var password string + var in *os.File + //var out *os.File + var tty *os.File + var rp *ServerRpty + var cmd *exec.Cmd + var wg sync.WaitGroup + var conn_ready_chan chan bool + var err error + + s = rpty.S + req = ws.Request() + conn_ready_chan = make(chan bool, 3) + token = req.FormValue("token") + if token != "" { + ws.Close() + return http.StatusBadRequest, fmt.Errorf("no client token specified") + } + + cts = s.FindServerConnByClientToken(token) + if cts == nil { + ws.Close() + return http.StatusBadRequest, fmt.Errorf("invalid client token - %s", token) + } + +// TODO: how to get notified of broken connection.... + + + wg.Add(1) + go func() { + var conn_ready bool + + defer wg.Done() + defer ws.Close() // dirty way to break the main loop + + conn_ready = <-conn_ready_chan + if conn_ready { // connected +/* + var poll_fds []unix.PollFd; + var buf []byte + var n int + var err error + + poll_fds = []unix.PollFd{ + unix.PollFd{Fd: int32(out.Fd()), Events: unix.POLLIN}, + } + + s.stats.pty_sessions.Add(1) + buf = make([]byte, 2048) + for { + n, err = unix.Poll(poll_fds, -1) // -1 means wait indefinitely + if err != nil { + if errors.Is(err, unix.EINTR) { continue } + s.log.Write("", LOG_ERROR, "[%s] Failed to poll pty stdout - %s", req.RemoteAddr, err.Error()) + break + } + if n == 0 { // timed out + continue + } + + if (poll_fds[0].Revents & (unix.POLLERR | unix.POLLHUP | unix.POLLNVAL)) != 0 { + s.log.Write(pty.Id, LOG_DEBUG, "[%s] EOF detected on pty stdout", req.RemoteAddr) + break; + } + + if (poll_fds[0].Revents & unix.POLLIN) != 0 { + n, err = out.Read(buf) + if err != nil { + if !errors.Is(err, io.EOF) { + s.log.Write(pty.Id, LOG_ERROR, "[%s] Failed to read pty stdout - %s", req.RemoteAddr, err.Error()) + } + break + } + if n > 0 { + err = pty.send_ws_data(ws, "iov", string(buf[:n])) + if err != nil { + s.log.Write(pty.Id, LOG_ERROR, "[%s] Failed to send to websocket - %s", req.RemoteAddr, err.Error()) + break + } + } + } + } + s.stats.pty_sessions.Add(-1) +*/ + } + }() + +ws_recv_loop: + for { + var msg []byte + err = websocket.Message.Receive(ws, &msg) + if err != nil { goto done } + + if len(msg) > 0 { + var ev json_xterm_ws_event + err = json.Unmarshal(msg, &ev) + if err == nil { + switch ev.Type { + case "open": + if rp == nil && len(ev.Data) == 2 { + //username = string(ev.Data[0]) + //password = string(ev.Data[1]) + + wg.Add(1) + go func() { + var err error + + defer wg.Done() + + rp, err = cts.StartRpty(ws) + // cmd, tty, err = pty.connect_pty(username, password) + if err != nil { + s.log.Write(rpty.Id, LOG_ERROR, "[%s] Failed to connect pty - %s", req.RemoteAddr, err.Error()) + rpty.send_ws_data(ws, "error", err.Error()) + ws.Close() // dirty way to flag out the error + } else { + err = rpty.send_ws_data(ws, "status", "opened") + if err != nil { + s.log.Write(rpty.Id, LOG_ERROR, "[%s] Failed to write 'opened' event to websocket - %s", req.RemoteAddr, err.Error()) + ws.Close() // dirty way to flag out the error + } else { + s.log.Write(rpty.Id, LOG_DEBUG, "[%s] Opened pty session", req.RemoteAddr) + // out = tty + // in = tty + conn_ready_chan <- true + } + } + }() + } + + case "close": + if tty != nil { + // cts.StopRpty() + tty.Close() + tty = nil + } + break ws_recv_loop + + case "iov": + if tty != nil { + var i int + for i, _ = range ev.Data { + in.Write([]byte(ev.Data[i])) + } + } + + case "size": + if tty != nil && len(ev.Data) == 2 { + var rows int + var cols int + rows, _ = strconv.Atoi(ev.Data[0]) + cols, _ = strconv.Atoi(ev.Data[1]) + pts.Setsize(tty, &pts.Winsize{Rows: uint16(rows), Cols: uint16(cols)}) + s.log.Write(rpty.Id, LOG_DEBUG, "[%s] Resized terminal to %d,%d", req.RemoteAddr, rows, cols) + // ignore error + } + } + } + } + } + + if tty != nil { + //err = pty.send_ws_data(ws, "status", "closed") + /* + err = s.SendRpty() + if err != nil { goto done } + */ + } + +done: + conn_ready_chan <- false + ws.Close() + if cmd != nil { + // kill the child process underneath to close ptym(the master pty). + //cmd.Process.Signal(syscall.SIGTERM) + cmd.Process.Kill() + } + if tty != nil { tty.Close() } + if cmd != nil { cmd.Wait() } + wg.Wait() + s.log.Write(rpty.Id, LOG_DEBUG, "[%s] Ended rpty session for %s", req.RemoteAddr, token) + + return http.StatusOK, err +} + // ------------------------------------------------------ func (pty *server_pty_xterm_file) ServeHTTP(w http.ResponseWriter, req *http.Request) (int, error) { diff --git a/server.go b/server.go index 43a6970..4dbecad 100644 --- a/server.go +++ b/server.go @@ -185,6 +185,11 @@ type ServerConn struct { pts_mtx sync.Mutex pts_list *list.List + rpty_next_id uint64 + rpty_mtx sync.Mutex + rpty_map map[uint64]*ServerRpty + rpty_map_by_ws map[*websocket.Conn]*ServerRpty + wg sync.WaitGroup stop_req atomic.Bool stop_chan chan bool @@ -214,6 +219,11 @@ type ServerRoute struct { stop_req atomic.Bool } +type ServerRpty struct { + id uint64 + ws *websocket.Conn +} + type GuardedPacketStreamServer struct { mtx sync.Mutex //pss Hodu_PacketStreamServer @@ -449,6 +459,12 @@ func (r *ServerRoute) ReportPacket(pts_id PeerId, packet_type PACKET_KIND, event return spc.ReportPacket(packet_type, event_data) } + +// ------------------------------------ + +func (rpty *ServerRpty) Stop() { +} + // ------------------------------------ func (cts *ServerConn) make_route_listener(id RouteId, option RouteOption, svc_requested_addr string) (*net.TCPListener, *net.TCPAddr, error) { @@ -636,7 +652,57 @@ func (cts *ServerConn) ReqStopAllServerRoutes() { } -func (cts *ServerConn) StartRpts() { +func (cts *ServerConn) StartRpty(ws *websocket.Conn) (*ServerRpty, error) { + var ok bool + var start_id uint64 + var assigned_id uint64 + var rpty *ServerRpty + var err error + + cts.rpty_mtx.Lock() + start_id = cts.rpty_next_id + for { + _, ok = cts.rpty_map[cts.rpty_next_id] + if !ok { + assigned_id = cts.rpty_next_id + cts.rpty_next_id++ + if cts.rpty_next_id == 0 { cts.rpty_next_id++ } + break + } + cts.rpty_next_id++ + if cts.rpty_next_id == 0 { cts.rpty_next_id++ } + if cts.rpty_next_id == start_id { + cts.rpty_mtx.Unlock() + return nil, fmt.Errorf("unable to assign id") + } + } + + _, ok = cts.rpty_map_by_ws[ws] + if ok { + cts.rpty_mtx.Unlock() + return nil, fmt.Errorf("connection already associated with rpty. possibly internal error") + } + + rpty = &ServerRpty{ + id: assigned_id, + ws: ws, + } + + cts.rpty_map[assigned_id] = rpty + cts.rpty_map_by_ws[ws] = rpty + cts.rpty_mtx.Unlock() + + err = cts.pss.Send(MakeRptyStartPacket(assigned_id)) + if err != nil { + cts.rpty_mtx.Lock() + delete(cts.rpty_map, assigned_id) + delete(cts.rpty_map_by_ws, ws) + cts.rpty_mtx.Unlock() + return nil , err + } + +// TODO: send request... + return rpty, nil } func (cts *ServerConn) ReportPacket(route_id RouteId, pts_id PeerId, packet_type PACKET_KIND, event_data interface{}) error { @@ -745,65 +811,28 @@ func (cts *ServerConn) receive_from_stream(wg *sync.WaitGroup) { } case PACKET_KIND_PEER_STARTED: - // the connection from the client to a peer has been established - var x *Packet_Peer - var ok bool - x, ok = pkt.U.(*Packet_Peer) - if ok { - err = cts.ReportPacket(RouteId(x.Peer.RouteId), PeerId(x.Peer.PeerId), PACKET_KIND_PEER_STARTED, x.Peer) - if err != nil { - cts.S.log.Write(cts.Sid, LOG_ERROR, - "Failed to handle peer_started event from %s for peer(%d,%d,%s,%s) - %s", - cts.RemoteAddr, x.Peer.RouteId, x.Peer.PeerId, x.Peer.LocalAddrStr, x.Peer.RemoteAddrStr, err.Error()) - } else { - cts.S.log.Write(cts.Sid, LOG_DEBUG, - "Handled peer_started event from %s for peer(%d,%d,%s,%s)", - cts.RemoteAddr, x.Peer.RouteId, x.Peer.PeerId, x.Peer.LocalAddrStr, x.Peer.RemoteAddrStr) - } - } else { - // invalid event data - cts.S.log.Write(cts.Sid, LOG_ERROR, "Invalid peer_started event from %s", cts.RemoteAddr) - } - + fallthrough case PACKET_KIND_PEER_ABORTED: - var x *Packet_Peer - var ok bool - x, ok = pkt.U.(*Packet_Peer) - if ok { - err = cts.ReportPacket(RouteId(x.Peer.RouteId), PeerId(x.Peer.PeerId), PACKET_KIND_PEER_ABORTED, x.Peer) - if err != nil { - cts.S.log.Write(cts.Sid, LOG_ERROR, - "Failed to handle peer_aborted event from %s for peer(%d,%d,%s,%s) - %s", - cts.RemoteAddr, x.Peer.RouteId, x.Peer.PeerId, x.Peer.LocalAddrStr, x.Peer.RemoteAddrStr, err.Error()) - } else { - cts.S.log.Write(cts.Sid, LOG_DEBUG, - "Handled peer_aborted event from %s for peer(%d,%d,%s,%s)", - cts.RemoteAddr, x.Peer.RouteId, x.Peer.PeerId, x.Peer.LocalAddrStr, x.Peer.RemoteAddrStr) - } - } else { - // invalid event data - cts.S.log.Write(cts.Sid, LOG_ERROR, "Invalid peer_aborted event from %s", cts.RemoteAddr) - } - + fallthrough case PACKET_KIND_PEER_STOPPED: // the connection from the client to a peer has been established var x *Packet_Peer var ok bool x, ok = pkt.U.(*Packet_Peer) if ok { - err = cts.ReportPacket(RouteId(x.Peer.RouteId), PeerId(x.Peer.PeerId), PACKET_KIND_PEER_STOPPED, x.Peer) + err = cts.ReportPacket(RouteId(x.Peer.RouteId), PeerId(x.Peer.PeerId), pkt.Kind, x.Peer) if err != nil { cts.S.log.Write(cts.Sid, LOG_ERROR, - "Failed to handle peer_stopped event from %s for peer(%d,%d,%s,%s) - %s", - cts.RemoteAddr, x.Peer.RouteId, x.Peer.PeerId, x.Peer.LocalAddrStr, x.Peer.RemoteAddrStr, err.Error()) + "Failed to handle %s event from %s for peer(%d,%d,%s,%s) - %s", + pkt.Kind.String(), cts.RemoteAddr, x.Peer.RouteId, x.Peer.PeerId, x.Peer.LocalAddrStr, x.Peer.RemoteAddrStr, err.Error()) } else { cts.S.log.Write(cts.Sid, LOG_DEBUG, - "Handled peer_stopped event from %s for peer(%d,%d,%s,%s)", - cts.RemoteAddr, x.Peer.RouteId, x.Peer.PeerId, x.Peer.LocalAddrStr, x.Peer.RemoteAddrStr) + "Handled %s event from %s for peer(%d,%d,%s,%s)", + pkt.Kind.String(), cts.RemoteAddr, x.Peer.RouteId, x.Peer.PeerId, x.Peer.LocalAddrStr, x.Peer.RemoteAddrStr) } } else { // invalid event data - cts.S.log.Write(cts.Sid, LOG_ERROR, "Invalid peer_stopped event from %s", cts.RemoteAddr) + cts.S.log.Write(cts.Sid, LOG_ERROR, "Invalid %s event from %s", pkt.Kind.String(), cts.RemoteAddr) } case PACKET_KIND_PEER_DATA: @@ -812,19 +841,19 @@ func (cts *ServerConn) receive_from_stream(wg *sync.WaitGroup) { var ok bool x, ok = pkt.U.(*Packet_Data) if ok { - err = cts.ReportPacket(RouteId(x.Data.RouteId), PeerId(x.Data.PeerId), PACKET_KIND_PEER_DATA, x.Data.Data) + err = cts.ReportPacket(RouteId(x.Data.RouteId), PeerId(x.Data.PeerId), pkt.Kind, x.Data.Data) if err != nil { cts.S.log.Write(cts.Sid, LOG_ERROR, - "Failed to handle peer_data event from %s for peer(%d,%d) - %s", - cts.RemoteAddr, x.Data.RouteId, x.Data.PeerId, err.Error()) + "Failed to handle %s event from %s for peer(%d,%d) - %s", + pkt.Kind.String(), cts.RemoteAddr, x.Data.RouteId, x.Data.PeerId, err.Error()) } else { cts.S.log.Write(cts.Sid, LOG_DEBUG, - "Handled peer_data event from %s for peer(%d,%d)", - cts.RemoteAddr, x.Data.RouteId, x.Data.PeerId) + "Handled %s event from %s for peer(%d,%d)", + pkt.Kind.String(), cts.RemoteAddr, x.Data.RouteId, x.Data.PeerId) } } else { // invalid event data - cts.S.log.Write(cts.Sid, LOG_ERROR, "Invalid peer_data event from %s", cts.RemoteAddr) + cts.S.log.Write(cts.Sid, LOG_ERROR, "Invalid %s event from %s", pkt.Kind.String(), cts.RemoteAddr) } case PACKET_KIND_CONN_DESC: @@ -885,9 +914,13 @@ func (cts *ServerConn) receive_from_stream(wg *sync.WaitGroup) { /*case PACKET_KIND_RPTY_START: case PACKET_KIND_RPTY_STOP:*/ case PACKET_KIND_RPTY_STARTED: + fallthrough case PACKET_KIND_RPTY_STOPPED: + fallthrough case PACKET_KIND_RPTY_ABORTED: + fallthrough case PACKET_KIND_RPTY_EOF: + fallthrough case PACKET_KIND_RPTY_DATA: // inspect the token // find the right websocket handler... @@ -1892,6 +1925,7 @@ func (s *Server) AddNewServerConn(remote_addr *net.Addr, local_addr *net.Addr, p } cts.Id = assigned_id cts.Sid = fmt.Sprintf("%d", cts.Id) // id in string used for logging + cts.rpty_next_id = 1 _, ok = s.cts_map_by_addr[cts.RemoteAddr] if ok {