improved code to break out of unix.Poll() using a separate pipe for rpty

2025-08-21 20:40:51 +09:00
parent 7d3ce7147a
commit 7ec1387132


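For reference, the change applies the classic self-pipe trick: RptyLoop() polls both the pty master and the read end of an extra pipe, and ReqStop() writes a single byte to the write end so that a blocked unix.Poll() returns immediately. Below is a minimal, stand-alone sketch of that wake-up mechanism; it is independent of the hodu code base, and all names in it (pfd, poll_fds, the helper goroutine) are illustrative only.

package main

import (
	"fmt"
	"time"

	"golang.org/x/sys/unix"
)

func main() {
	var pfd [2]int
	var err error

	// create the wake-up pipe: pfd[0] is the read end, pfd[1] the write end
	err = unix.Pipe(pfd[:])
	if err != nil {
		panic(err)
	}
	defer unix.Close(pfd[0])
	defer unix.Close(pfd[1])

	// another goroutine requests a stop by writing one byte to the write end
	go func() {
		time.Sleep(1 * time.Second)
		unix.Write(pfd[1], []byte{0}) // wakes up the blocked Poll() below
	}()

	poll_fds := []unix.PollFd{
		unix.PollFd{Fd: int32(pfd[0]), Events: unix.POLLIN},
	}
	for {
		_, err = unix.Poll(poll_fds, -1) // block until an event arrives
		if err != nil {
			if err == unix.EINTR {
				continue
			}
			panic(err)
		}
		if (poll_fds[0].Revents & unix.POLLIN) != 0 {
			fmt.Println("stop request noticed - leaving the loop")
			break
		}
	}
}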
@@ -185,9 +185,11 @@ const (
)
type ClientRpty struct {
cts *ClientConn
id uint64
cmd *exec.Cmd
tty *os.File
pfd [2]int
}
type ClientRpx struct {
@@ -323,8 +325,9 @@ func (g *GuardedPacketStreamClient) Context() context.Context {
// --------------------------------------------------------------------
func (rpty *ClientRpty) ReqStop() {
rpty.tty.Close()
rpty.cmd.Process.Kill()
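// wake up the poll loop in RptyLoop() by writing a byte to the event pipe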
// don't check for a write error. the os pipe's buffer must be large enough
// don't check for a write error; the OS pipe buffer is large enough to hold this single wake-up byte
unix.Write(rpty.pfd[1], []byte{0})
}
func (rpx *ClientRpx) ReqStop() {
@@ -1077,7 +1080,7 @@ func (cts *ClientConn) disconnect_from_server(logmsg bool) {
cts.discon_mtx.Lock()
if (logmsg) {
cts.C.log.Write(cts.Sid, LOG_INFO, "Disconnecting from server[%d] %s", cts.cfg.Index, cts.cfg.ServerAddrs[cts.cfg.Index])
cts.C.log.Write(cts.Sid, LOG_INFO, "Preparing to disconnect from server[%d] %s", cts.cfg.Index, cts.cfg.ServerAddrs[cts.cfg.Index])
}
cts.route_mtx.Lock()
@@ -1088,7 +1091,7 @@ func (cts *ClientConn) disconnect_from_server(logmsg bool) {
cts.rpty_mtx.Lock()
for _, rpty = range cts.rpty_map {
rpty.ReqStop()
// the loop in ReadRptyLoop() is supposed to be broken.
// the loop in RptyLoop() is supposed to be broken.
// don't inform the server of the stop.
// the server is expected to clean up by itself upon the connection error
}
@@ -1114,6 +1117,9 @@ func (cts *ClientConn) disconnect_from_server(logmsg bool) {
cts.remote_addr.Set("")
// don't reset cts.local_addr_p and cts.remote_addr_p
if (logmsg) {
cts.C.log.Write(cts.Sid, LOG_INFO, "Prepared to disconnect from server[%d] %s", cts.cfg.Index, cts.cfg.ServerAddrs[cts.cfg.Index])
}
cts.discon_mtx.Unlock()
}
}
@@ -1433,7 +1439,9 @@ req_stop_and_wait_for_termination:
cts.ReqStop()
wait_for_termination:
cts.C.log.Write(cts.Sid, LOG_DEBUG, "Waiting for route tasks to stop on connection to server[%d] %s", cts.cfg.Index, cts.cfg.ServerAddrs[cts.cfg.Index])
cts.route_wg.Wait() // wait until all route tasks are finished
cts.C.log.Write(cts.Sid, LOG_DEBUG, "Removing connection to server[%d] %s", cts.cfg.Index, cts.cfg.ServerAddrs[cts.cfg.Index])
cts.C.RemoveClientConn(cts)
cts.C.FireConnEvent(CLIENT_EVENT_CONN_STOPPED, cts)
@@ -1505,7 +1513,7 @@ func (cts *ClientConn) FindClientRptyById(id uint64) *ClientRpty {
return crp
}
func (cts *ClientConn) ReadRptyLoop(crp *ClientRpty, wg *sync.WaitGroup) {
func (cts *ClientConn) RptyLoop(crp *ClientRpty, wg *sync.WaitGroup) {
var poll_fds []unix.PollFd
var buf [2048]byte
@@ -1514,10 +1522,13 @@ func (cts *ClientConn) ReadRptyLoop(crp *ClientRpty, wg *sync.WaitGroup) {
defer wg.Done()
cts.C.log.Write(cts.Sid, LOG_INFO, "Started rpty(%d) for %s(%s)", crp.id, cts.C.pty_shell, cts.C.pty_user)
cts.C.stats.rpty_sessions.Add(1)
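// watch both the pty master (index 0) and the read end of the event pipe (index 1)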
poll_fds = []unix.PollFd{
unix.PollFd{Fd: int32(crp.tty.Fd()), Events: unix.POLLIN},
unix.PollFd{Fd: int32(crp.pfd[0]), Events: unix.POLLIN},
}
for {
@@ -1532,7 +1543,11 @@ func (cts *ClientConn) ReadRptyLoop(crp *ClientRpty, wg *sync.WaitGroup) {
}
if (poll_fds[0].Revents & (unix.POLLERR | unix.POLLHUP | unix.POLLNVAL)) != 0 {
cts.C.log.Write(cts.Sid, LOG_DEBUG, "EOF detected on rpty(%d) stdout", crp.id)
cts.C.log.Write(cts.Sid, LOG_ERROR, "EOF detected on rpty(%d) stdout", crp.id)
break
}
if (poll_fds[1].Revents & (unix.POLLERR | unix.POLLHUP | unix.POLLNVAL)) != 0 {
cts.C.log.Write(cts.Sid, LOG_ERROR, "EOF detected on rpty(%d) event pipe", crp.id)
break
}
@@ -1542,35 +1557,46 @@ func (cts *ClientConn) ReadRptyLoop(crp *ClientRpty, wg *sync.WaitGroup) {
var err2 error
err2 = cts.psc.Send(MakeRptyDataPacket(crp.id, buf[:n]))
if err2 != nil {
cts.C.log.Write(cts.Sid, LOG_DEBUG, "Failed to send rpty(%d) stdout to server - %s", crp.id, err2.Error())
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to send rpty(%d) stdout to server - %s", crp.id, err2.Error())
break
}
}
if err != nil {
if !errors.Is(err, io.EOF) {
cts.C.log.Write(cts.Sid, LOG_DEBUG, "Failed to read rpty(%d) stdout - %s", crp.id, err.Error())
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to read rpty(%d) stdout - %s", crp.id, err.Error())
}
break
}
}
if (poll_fds[1].Revents & unix.POLLIN) != 0 {
// no need to drain the pipe here; it is closed right after the loop
//unix.Read(crp.pfd[0], )
cts.C.log.Write(cts.Sid, LOG_DEBUG, "Stop request noticed on rpty(%d) event pipe", crp.id)
break
}
}
cts.C.log.Write(cts.Sid, LOG_DEBUG, "Ending rpty(%d) loop", crp.id)
cts.psc.Send(MakeRptyStopPacket(crp.id, ""))
cts.C.log.Write(cts.Sid, LOG_INFO, "Ending rpty(%d) read loop", crp.id)
crp.ReqStop()
crp.cmd.Wait()
crp.tty.Close()
unix.Close(crp.pfd[0])
unix.Close(crp.pfd[1])
cts.rpty_mtx.Lock()
delete(cts.rpty_map, crp.id)
cts.rpty_mtx.Unlock()
cts.C.stats.rpty_sessions.Add(-1)
cts.C.log.Write(cts.Sid, LOG_DEBUG, "Ended rpty(%d) loop", crp.id)
}
func (cts *ClientConn) StartRpty(id uint64, wg *sync.WaitGroup) error {
var crp *ClientRpty
var ok bool
var i int
var err error
cts.rpty_mtx.Lock()
@@ -1580,22 +1606,37 @@ func (cts *ClientConn) StartRpty(id uint64, wg *sync.WaitGroup) error {
return fmt.Errorf("multiple start on rpty id %d", id)
}
crp = &ClientRpty{ id: id }
crp = &ClientRpty{ cts: cts, id: id }
err = unix.Pipe(crp.pfd[:])
if err != nil {
cts.rpty_mtx.Unlock()
cts.psc.Send(MakeRptyStopPacket(id, err.Error()))
return fmt.Errorf("unable to create rpty(%d) event fd for %s(%s) - %s", id, cts.C.pty_shell, cts.C.pty_user, err.Error())
}
crp.cmd, crp.tty, err = connect_pty(cts.C.pty_shell, cts.C.pty_user)
if err != nil {
cts.rpty_mtx.Unlock()
cts.psc.Send(MakeRptyStopPacket(id, err.Error()))
unix.Close(crp.pfd[0])
unix.Close(crp.pfd[1])
return fmt.Errorf("unable to start rpty(%d) for %s(%s) - %s", id, cts.C.pty_shell, cts.C.pty_user, err.Error())
}
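// make both ends of the event pipe non-blocking so that the wake-up write in ReqStop() never blocks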
for i = 0; i < 2; i++ {
var flags int
flags, err = unix.FcntlInt(uintptr(crp.pfd[i]), unix.F_GETFL, 0)
if err == nil {
unix.FcntlInt(uintptr(crp.pfd[i]), unix.F_SETFL, flags | unix.O_NONBLOCK)
}
}
cts.rpty_map[id] = crp
wg.Add(1)
go cts.ReadRptyLoop(crp, wg)
go cts.RptyLoop(crp, wg)
cts.rpty_mtx.Unlock()
cts.C.log.Write(cts.Sid, LOG_INFO, "Started rpty(%d) for %s(%s)", id, cts.C.pty_shell, cts.C.pty_user)
return nil
}
@@ -1853,6 +1894,8 @@ func (cts *ClientConn) RpxLoop(crpx *ClientRpx, data []byte, wg *sync.WaitGroup)
defer wg.Done()
cts.C.log.Write(cts.Sid, LOG_INFO, "Starting rpx(%d) loop", crpx.id)
start_time = time.Now()
sc = bufio.NewScanner(bytes.NewReader(data))
@@ -1937,7 +1980,7 @@ done:
if err != nil {
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to send rpx(%d) stp to server - %s", crpx.id, err.Error())
}
cts.C.log.Write(cts.Sid, LOG_INFO, "Ending rpx(%d) read loop", crpx.id)
cts.C.log.Write(cts.Sid, LOG_INFO, "Ending rpx(%d) loop", crpx.id)
crpx.ReqStop()
@@ -1945,6 +1988,7 @@ done:
delete(cts.rpx_map, crpx.id)
cts.rpx_mtx.Unlock()
cts.C.stats.rpx_sessions.Add(-1)
cts.C.log.Write(cts.Sid, LOG_INFO, "Ended rpx(%d) loop", crpx.id)
}
func (cts *ClientConn) StartRpx(id uint64, data []byte, wg *sync.WaitGroup) error {