From 7ec138713204ac968963c7a961d64355be4e65c5 Mon Sep 17 00:00:00 2001
From: hyung-hwan
Date: Thu, 21 Aug 2025 20:40:51 +0900
Subject: [PATCH] improved code to break unix.Poll() using a separate pipe for rpty

Previously, ClientRpty.ReqStop() closed the pty device to break the
unix.Poll() call in the read loop. Each rpty session now carries a
dedicated event pipe instead: ReqStop() writes a single byte to the
write end to wake the poll, and the loop itself closes the pty and
both pipe ends after it exits.

---
 client.go | 68 +++++++++++++++++++++++++++++++++++++++++++++----------
 1 file changed, 56 insertions(+), 12 deletions(-)

diff --git a/client.go b/client.go
index 0a42b22..e9fc4f6 100644
--- a/client.go
+++ b/client.go
@@ -185,9 +185,11 @@ const (
 )
 
 type ClientRpty struct {
+	cts *ClientConn
 	id uint64
 	cmd *exec.Cmd
 	tty *os.File
+	pfd [2]int
 }
 
 type ClientRpx struct {
@@ -323,8 +325,9 @@ func (g *GuardedPacketStreamClient) Context() context.Context {
 // --------------------------------------------------------------------
 
 func (rpty *ClientRpty) ReqStop() {
-	rpty.tty.Close()
 	rpty.cmd.Process.Kill()
+	// don't check for a write error. the os pipe's buffer is large enough for a single byte
+	unix.Write(rpty.pfd[1], []byte{0})
 }
 
 func (rpx *ClientRpx) ReqStop() {
@@ -1077,7 +1080,7 @@ func (cts *ClientConn) disconnect_from_server(logmsg bool) {
 	cts.discon_mtx.Lock()
 
 	if (logmsg) {
-		cts.C.log.Write(cts.Sid, LOG_INFO, "Disconnecting from server[%d] %s", cts.cfg.Index, cts.cfg.ServerAddrs[cts.cfg.Index])
+		cts.C.log.Write(cts.Sid, LOG_INFO, "Preparing to disconnect from server[%d] %s", cts.cfg.Index, cts.cfg.ServerAddrs[cts.cfg.Index])
 	}
 
 	cts.route_mtx.Lock()
@@ -1088,7 +1091,7 @@ func (cts *ClientConn) disconnect_from_server(logmsg bool) {
 	cts.rpty_mtx.Lock()
 	for _, rpty = range cts.rpty_map {
 		rpty.ReqStop()
-		// the loop in ReadRptyLoop() is supposed to be broken.
+		// the loop in RptyLoop() is supposed to be broken.
 		// let's not inform the server of this connection.
 		// the server should clean up itself upon connection error
 	}
@@ -1114,6 +1117,9 @@ func (cts *ClientConn) disconnect_from_server(logmsg bool) {
 		cts.remote_addr.Set("")
 		// don't reset cts.local_addr_p and cts.remote_addr_p
 
+		if (logmsg) {
+			cts.C.log.Write(cts.Sid, LOG_INFO, "Prepared to disconnect from server[%d] %s", cts.cfg.Index, cts.cfg.ServerAddrs[cts.cfg.Index])
+		}
 		cts.discon_mtx.Unlock()
 	}
 }
@@ -1433,7 +1439,9 @@ req_stop_and_wait_for_termination:
 	cts.ReqStop()
 
 wait_for_termination:
+	cts.C.log.Write(cts.Sid, LOG_DEBUG, "Waiting for route tasks to stop on connection to server[%d] %s", cts.cfg.Index, cts.cfg.ServerAddrs[cts.cfg.Index])
 	cts.route_wg.Wait() // wait until all route tasks are finished
+	cts.C.log.Write(cts.Sid, LOG_DEBUG, "Removing connection to server[%d] %s", cts.cfg.Index, cts.cfg.ServerAddrs[cts.cfg.Index])
 	cts.C.RemoveClientConn(cts)
 
 	cts.C.FireConnEvent(CLIENT_EVENT_CONN_STOPPED, cts)
@@ -1505,7 +1513,7 @@ func (cts *ClientConn) FindClientRptyById(id uint64) *ClientRpty {
 	return crp
 }
 
-func (cts *ClientConn) ReadRptyLoop(crp *ClientRpty, wg *sync.WaitGroup) {
+func (cts *ClientConn) RptyLoop(crp *ClientRpty, wg *sync.WaitGroup) {
 	var poll_fds []unix.PollFd
 	var buf [2048]byte
 
@@ -1514,10 +1522,13 @@ func (cts *ClientConn) ReadRptyLoop(crp *ClientRpty, wg *sync.WaitGroup) {
 
 	defer wg.Done()
 
+	cts.C.log.Write(cts.Sid, LOG_INFO, "Started rpty(%d) for %s(%s)", crp.id, cts.C.pty_shell, cts.C.pty_user)
+
 	cts.C.stats.rpty_sessions.Add(1)
 
 	poll_fds = []unix.PollFd{
 		unix.PollFd{Fd: int32(crp.tty.Fd()), Events: unix.POLLIN},
+		unix.PollFd{Fd: int32(crp.pfd[0]), Events: unix.POLLIN},
 	}
 
 	for {
@@ -1532,7 +1543,11 @@ func (cts *ClientConn) ReadRptyLoop(crp *ClientRpty, wg *sync.WaitGroup) {
 		}
 
 		if (poll_fds[0].Revents & (unix.POLLERR | unix.POLLHUP | unix.POLLNVAL)) != 0 {
rpty(%d) stdout", crp.id) + cts.C.log.Write(cts.Sid, LOG_ERROR, "EOF detected on rpty(%d) stdout", crp.id) + break + } + if (poll_fds[1].Revents & (unix.POLLERR | unix.POLLHUP | unix.POLLNVAL)) != 0 { + cts.C.log.Write(cts.Sid, LOG_ERROR, "EOF detected on rpty(%d) event pipe", crp.id) break } @@ -1542,35 +1557,46 @@ func (cts *ClientConn) ReadRptyLoop(crp *ClientRpty, wg *sync.WaitGroup) { var err2 error err2 = cts.psc.Send(MakeRptyDataPacket(crp.id, buf[:n])) if err2 != nil { - cts.C.log.Write(cts.Sid, LOG_DEBUG, "Failed to send rpty(%d) stdout to server - %s", crp.id, err2.Error()) + cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to send rpty(%d) stdout to server - %s", crp.id, err2.Error()) break } } if err != nil { if !errors.Is(err, io.EOF) { - cts.C.log.Write(cts.Sid, LOG_DEBUG, "Failed to read rpty(%d) stdout - %s", crp.id, err.Error()) + cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to read rpty(%d) stdout - %s", crp.id, err.Error()) } break } } + if (poll_fds[1].Revents & unix.POLLIN) != 0 { + // don't care to read the pipe as it is closed after the loop + //unix.Read(crp.pfd[0], ) + cts.C.log.Write(cts.Sid, LOG_DEBUG, "Stop request noticed on rpty(%d) event pipe", crp.id) + break + } } + cts.C.log.Write(cts.Sid, LOG_DEBUG, "Ending rpty(%d) loop", crp.id) cts.psc.Send(MakeRptyStopPacket(crp.id, "")) - cts.C.log.Write(cts.Sid, LOG_INFO, "Ending rpty(%d) read loop", crp.id) crp.ReqStop() crp.cmd.Wait() + crp.tty.Close() + unix.Close(crp.pfd[0]) + unix.Close(crp.pfd[1]) cts.rpty_mtx.Lock() delete(cts.rpty_map, crp.id) cts.rpty_mtx.Unlock() cts.C.stats.rpty_sessions.Add(-1) + cts.C.log.Write(cts.Sid, LOG_DEBUG, "Ended rpty(%d) loop", crp.id) } func (cts *ClientConn) StartRpty(id uint64, wg *sync.WaitGroup) error { var crp *ClientRpty var ok bool + var i int var err error cts.rpty_mtx.Lock() @@ -1580,22 +1606,37 @@ func (cts *ClientConn) StartRpty(id uint64, wg *sync.WaitGroup) error { return fmt.Errorf("multiple start on rpty id %d", id) } - crp = &ClientRpty{ id: id } + crp = &ClientRpty{ cts: cts, id: id } + err = unix.Pipe(crp.pfd[:]) + if err != nil { + cts.rpty_mtx.Unlock() + cts.psc.Send(MakeRptyStopPacket(id, err.Error())) + return fmt.Errorf("unable to create rpty(%d) event fd for %s(%s) - %s", id, cts.C.pty_shell, cts.C.pty_user, err.Error()) + } crp.cmd, crp.tty, err = connect_pty(cts.C.pty_shell, cts.C.pty_user) if err != nil { cts.rpty_mtx.Unlock() cts.psc.Send(MakeRptyStopPacket(id, err.Error())) + unix.Close(crp.pfd[0]) + unix.Close(crp.pfd[1]) return fmt.Errorf("unable to start rpty(%d) for %s(%s) - %s", id, cts.C.pty_shell, cts.C.pty_user, err.Error()) } + for i = 0; i < 2; i++ { + var flags int + flags, err = unix.FcntlInt(uintptr(crp.pfd[i]), unix.F_GETFL, 0) + if err != nil { + unix.FcntlInt(uintptr(crp.pfd[i]), unix.F_SETFL, flags | unix.O_NONBLOCK) + } + } + cts.rpty_map[id] = crp wg.Add(1) - go cts.ReadRptyLoop(crp, wg) + go cts.RptyLoop(crp, wg) cts.rpty_mtx.Unlock() - cts.C.log.Write(cts.Sid, LOG_INFO, "Started rpty(%d) for %s(%s)", id, cts.C.pty_shell, cts.C.pty_user) return nil } @@ -1853,6 +1894,8 @@ func (cts *ClientConn) RpxLoop(crpx *ClientRpx, data []byte, wg *sync.WaitGroup) defer wg.Done() + cts.C.log.Write(cts.Sid, LOG_INFO, "Starting rpx(%d) loop", crpx.id) + start_time = time.Now() sc = bufio.NewScanner(bytes.NewReader(data)) @@ -1937,7 +1980,7 @@ done: if err != nil { cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to send rpx(%d) stp to server - %s", crpx.id, err.Error()) } - cts.C.log.Write(cts.Sid, LOG_INFO, "Ending rpx(%d) read loop", crpx.id) + 
+	cts.C.log.Write(cts.Sid, LOG_INFO, "Ending rpx(%d) loop", crpx.id)
 
 	crpx.ReqStop()
 
@@ -1945,6 +1988,7 @@ done:
 	delete(cts.rpx_map, crpx.id)
 	cts.rpx_mtx.Unlock()
 	cts.C.stats.rpx_sessions.Add(-1)
+	cts.C.log.Write(cts.Sid, LOG_INFO, "Ended rpx(%d) loop", crpx.id)
 }
 
 func (cts *ClientConn) StartRpx(id uint64, data []byte, wg *sync.WaitGroup) error {
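
Note on the pattern: the change above is an instance of the classic self-pipe
trick. Closing a polled fd from another goroutine does not reliably interrupt
unix.Poll(), so each rpty session now polls a second fd, the read end of a
private pipe, and a stop request writes one byte to the write end to wake the
poll. The following standalone sketch distills that idea; WakePoller and every
other identifier in it are illustrative names for this note only, not symbols
from the patch.

// wakepoll.go - minimal sketch of the wake-pipe pattern, assuming
// golang.org/x/sys/unix on a unix-like platform.
package main

import (
	"fmt"
	"os"
	"sync"
	"time"

	"golang.org/x/sys/unix"
)

type WakePoller struct {
	data *os.File // the fd whose readability we care about (the pty in the patch)
	pfd  [2]int   // pfd[0] is polled; pfd[1] is written to on stop
}

func NewWakePoller(data *os.File) (*WakePoller, error) {
	var wp = &WakePoller{data: data}
	// the patch additionally sets O_NONBLOCK on both ends; omitted here for brevity
	if err := unix.Pipe(wp.pfd[:]); err != nil {
		return nil, err
	}
	return wp, nil
}

// Stop wakes Loop(). As in the patch, the write error is ignored:
// a single byte cannot overflow an empty pipe buffer.
func (wp *WakePoller) Stop() {
	unix.Write(wp.pfd[1], []byte{0})
}

func (wp *WakePoller) Loop(wg *sync.WaitGroup) {
	defer wg.Done()
	var buf [2048]byte
	var poll_fds = []unix.PollFd{
		unix.PollFd{Fd: int32(wp.data.Fd()), Events: unix.POLLIN},
		unix.PollFd{Fd: int32(wp.pfd[0]), Events: unix.POLLIN},
	}
	for {
		_, err := unix.Poll(poll_fds, -1) // -1 = no timeout; only data or Stop() wakes this
		if err != nil {
			if err == unix.EINTR {
				continue // interrupted by a signal; poll again
			}
			break
		}
		if (poll_fds[1].Revents & (unix.POLLIN | unix.POLLERR | unix.POLLHUP)) != 0 {
			break // Stop() was called; no need to drain the pipe, it is closed below
		}
		if (poll_fds[0].Revents & unix.POLLIN) != 0 {
			n, err := wp.data.Read(buf[:])
			if n > 0 {
				fmt.Printf("read %q\n", buf[:n])
			}
			if err != nil {
				break
			}
		}
	}
	// the loop owns the fds and closes them only after polling has ended,
	// mirroring the patch's move of tty.Close() out of ReqStop()
	wp.data.Close()
	unix.Close(wp.pfd[0])
	unix.Close(wp.pfd[1])
}

func main() {
	rd, wr, err := os.Pipe() // stands in for the pty in this sketch
	if err != nil {
		panic(err)
	}
	wp, err := NewWakePoller(rd)
	if err != nil {
		panic(err)
	}
	var wg sync.WaitGroup
	wg.Add(1)
	go wp.Loop(&wg)
	wr.Write([]byte("hello"))
	time.Sleep(100 * time.Millisecond) // let Loop read the data first
	wp.Stop()                          // breaks unix.Poll() without touching the data fd
	wg.Wait()
	wr.Close()
}

The sketch leaves out the O_NONBLOCK fcntl step the patch performs on both
pipe ends; for a single wake-up byte it is not strictly required, but it keeps
a stray extra Stop() from ever blocking on a full pipe.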