improved code to break unix.Poll() using a seperate pipe for rpty
This commit is contained in:
68
client.go
68
client.go
@ -185,9 +185,11 @@ const (
|
||||
)
|
||||
|
||||
type ClientRpty struct {
|
||||
cts *ClientConn
|
||||
id uint64
|
||||
cmd *exec.Cmd
|
||||
tty *os.File
|
||||
pfd [2]int
|
||||
}
|
||||
|
||||
type ClientRpx struct {
|
||||
@ -323,8 +325,9 @@ func (g *GuardedPacketStreamClient) Context() context.Context {
|
||||
// --------------------------------------------------------------------
|
||||
|
||||
func (rpty *ClientRpty) ReqStop() {
|
||||
rpty.tty.Close()
|
||||
rpty.cmd.Process.Kill()
|
||||
// don't check for a write error. the os pipe's buffer must be large enough
|
||||
unix.Write(rpty.pfd[1], []byte{0})
|
||||
}
|
||||
|
||||
func (rpx *ClientRpx) ReqStop() {
|
||||
@ -1077,7 +1080,7 @@ func (cts *ClientConn) disconnect_from_server(logmsg bool) {
|
||||
cts.discon_mtx.Lock()
|
||||
|
||||
if (logmsg) {
|
||||
cts.C.log.Write(cts.Sid, LOG_INFO, "Disconnecting from server[%d] %s", cts.cfg.Index, cts.cfg.ServerAddrs[cts.cfg.Index])
|
||||
cts.C.log.Write(cts.Sid, LOG_INFO, "Preparing to disconnect from server[%d] %s", cts.cfg.Index, cts.cfg.ServerAddrs[cts.cfg.Index])
|
||||
}
|
||||
|
||||
cts.route_mtx.Lock()
|
||||
@ -1088,7 +1091,7 @@ func (cts *ClientConn) disconnect_from_server(logmsg bool) {
|
||||
cts.rpty_mtx.Lock()
|
||||
for _, rpty = range cts.rpty_map {
|
||||
rpty.ReqStop()
|
||||
// the loop in ReadRptyLoop() is supposed to be broken.
|
||||
// the loop in RptyLoop() is supposed to be broken.
|
||||
// let's not inform the server of this connection.
|
||||
// the server should clean up itself upon connection error
|
||||
}
|
||||
@ -1114,6 +1117,9 @@ func (cts *ClientConn) disconnect_from_server(logmsg bool) {
|
||||
cts.remote_addr.Set("")
|
||||
// don't reset cts.local_addr_p and cts.remote_addr_p
|
||||
|
||||
if (logmsg) {
|
||||
cts.C.log.Write(cts.Sid, LOG_INFO, "Prepared to disconnect from server[%d] %s", cts.cfg.Index, cts.cfg.ServerAddrs[cts.cfg.Index])
|
||||
}
|
||||
cts.discon_mtx.Unlock()
|
||||
}
|
||||
}
|
||||
@ -1433,7 +1439,9 @@ req_stop_and_wait_for_termination:
|
||||
cts.ReqStop()
|
||||
|
||||
wait_for_termination:
|
||||
cts.C.log.Write(cts.Sid, LOG_DEBUG, "Waiting for route tasks to stop on connection to server[%d] %s", cts.cfg.Index, cts.cfg.ServerAddrs[cts.cfg.Index])
|
||||
cts.route_wg.Wait() // wait until all route tasks are finished
|
||||
cts.C.log.Write(cts.Sid, LOG_DEBUG, "Removing connection to server[%d] %s", cts.cfg.Index, cts.cfg.ServerAddrs[cts.cfg.Index])
|
||||
cts.C.RemoveClientConn(cts)
|
||||
|
||||
cts.C.FireConnEvent(CLIENT_EVENT_CONN_STOPPED, cts)
|
||||
@ -1505,7 +1513,7 @@ func (cts *ClientConn) FindClientRptyById(id uint64) *ClientRpty {
|
||||
return crp
|
||||
}
|
||||
|
||||
func (cts *ClientConn) ReadRptyLoop(crp *ClientRpty, wg *sync.WaitGroup) {
|
||||
func (cts *ClientConn) RptyLoop(crp *ClientRpty, wg *sync.WaitGroup) {
|
||||
|
||||
var poll_fds []unix.PollFd
|
||||
var buf [2048]byte
|
||||
@ -1514,10 +1522,13 @@ func (cts *ClientConn) ReadRptyLoop(crp *ClientRpty, wg *sync.WaitGroup) {
|
||||
|
||||
defer wg.Done()
|
||||
|
||||
cts.C.log.Write(cts.Sid, LOG_INFO, "Started rpty(%d) for %s(%s)", crp.id, cts.C.pty_shell, cts.C.pty_user)
|
||||
|
||||
cts.C.stats.rpty_sessions.Add(1)
|
||||
|
||||
poll_fds = []unix.PollFd{
|
||||
unix.PollFd{Fd: int32(crp.tty.Fd()), Events: unix.POLLIN},
|
||||
unix.PollFd{Fd: int32(crp.pfd[0]), Events: unix.POLLIN},
|
||||
}
|
||||
|
||||
for {
|
||||
@ -1532,7 +1543,11 @@ func (cts *ClientConn) ReadRptyLoop(crp *ClientRpty, wg *sync.WaitGroup) {
|
||||
}
|
||||
|
||||
if (poll_fds[0].Revents & (unix.POLLERR | unix.POLLHUP | unix.POLLNVAL)) != 0 {
|
||||
cts.C.log.Write(cts.Sid, LOG_DEBUG, "EOF detected on rpty(%d) stdout", crp.id)
|
||||
cts.C.log.Write(cts.Sid, LOG_ERROR, "EOF detected on rpty(%d) stdout", crp.id)
|
||||
break
|
||||
}
|
||||
if (poll_fds[1].Revents & (unix.POLLERR | unix.POLLHUP | unix.POLLNVAL)) != 0 {
|
||||
cts.C.log.Write(cts.Sid, LOG_ERROR, "EOF detected on rpty(%d) event pipe", crp.id)
|
||||
break
|
||||
}
|
||||
|
||||
@ -1542,35 +1557,46 @@ func (cts *ClientConn) ReadRptyLoop(crp *ClientRpty, wg *sync.WaitGroup) {
|
||||
var err2 error
|
||||
err2 = cts.psc.Send(MakeRptyDataPacket(crp.id, buf[:n]))
|
||||
if err2 != nil {
|
||||
cts.C.log.Write(cts.Sid, LOG_DEBUG, "Failed to send rpty(%d) stdout to server - %s", crp.id, err2.Error())
|
||||
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to send rpty(%d) stdout to server - %s", crp.id, err2.Error())
|
||||
break
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
if !errors.Is(err, io.EOF) {
|
||||
cts.C.log.Write(cts.Sid, LOG_DEBUG, "Failed to read rpty(%d) stdout - %s", crp.id, err.Error())
|
||||
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to read rpty(%d) stdout - %s", crp.id, err.Error())
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
if (poll_fds[1].Revents & unix.POLLIN) != 0 {
|
||||
// don't care to read the pipe as it is closed after the loop
|
||||
//unix.Read(crp.pfd[0], )
|
||||
cts.C.log.Write(cts.Sid, LOG_DEBUG, "Stop request noticed on rpty(%d) event pipe", crp.id)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
cts.C.log.Write(cts.Sid, LOG_DEBUG, "Ending rpty(%d) loop", crp.id)
|
||||
cts.psc.Send(MakeRptyStopPacket(crp.id, ""))
|
||||
cts.C.log.Write(cts.Sid, LOG_INFO, "Ending rpty(%d) read loop", crp.id)
|
||||
|
||||
crp.ReqStop()
|
||||
crp.cmd.Wait()
|
||||
crp.tty.Close()
|
||||
unix.Close(crp.pfd[0])
|
||||
unix.Close(crp.pfd[1])
|
||||
|
||||
cts.rpty_mtx.Lock()
|
||||
delete(cts.rpty_map, crp.id)
|
||||
cts.rpty_mtx.Unlock()
|
||||
|
||||
cts.C.stats.rpty_sessions.Add(-1)
|
||||
cts.C.log.Write(cts.Sid, LOG_DEBUG, "Ended rpty(%d) loop", crp.id)
|
||||
}
|
||||
|
||||
func (cts *ClientConn) StartRpty(id uint64, wg *sync.WaitGroup) error {
|
||||
var crp *ClientRpty
|
||||
var ok bool
|
||||
var i int
|
||||
var err error
|
||||
|
||||
cts.rpty_mtx.Lock()
|
||||
@ -1580,22 +1606,37 @@ func (cts *ClientConn) StartRpty(id uint64, wg *sync.WaitGroup) error {
|
||||
return fmt.Errorf("multiple start on rpty id %d", id)
|
||||
}
|
||||
|
||||
crp = &ClientRpty{ id: id }
|
||||
crp = &ClientRpty{ cts: cts, id: id }
|
||||
err = unix.Pipe(crp.pfd[:])
|
||||
if err != nil {
|
||||
cts.rpty_mtx.Unlock()
|
||||
cts.psc.Send(MakeRptyStopPacket(id, err.Error()))
|
||||
return fmt.Errorf("unable to create rpty(%d) event fd for %s(%s) - %s", id, cts.C.pty_shell, cts.C.pty_user, err.Error())
|
||||
}
|
||||
crp.cmd, crp.tty, err = connect_pty(cts.C.pty_shell, cts.C.pty_user)
|
||||
if err != nil {
|
||||
cts.rpty_mtx.Unlock()
|
||||
cts.psc.Send(MakeRptyStopPacket(id, err.Error()))
|
||||
unix.Close(crp.pfd[0])
|
||||
unix.Close(crp.pfd[1])
|
||||
return fmt.Errorf("unable to start rpty(%d) for %s(%s) - %s", id, cts.C.pty_shell, cts.C.pty_user, err.Error())
|
||||
}
|
||||
|
||||
for i = 0; i < 2; i++ {
|
||||
var flags int
|
||||
flags, err = unix.FcntlInt(uintptr(crp.pfd[i]), unix.F_GETFL, 0)
|
||||
if err != nil {
|
||||
unix.FcntlInt(uintptr(crp.pfd[i]), unix.F_SETFL, flags | unix.O_NONBLOCK)
|
||||
}
|
||||
}
|
||||
|
||||
cts.rpty_map[id] = crp
|
||||
|
||||
wg.Add(1)
|
||||
go cts.ReadRptyLoop(crp, wg)
|
||||
go cts.RptyLoop(crp, wg)
|
||||
|
||||
cts.rpty_mtx.Unlock()
|
||||
|
||||
cts.C.log.Write(cts.Sid, LOG_INFO, "Started rpty(%d) for %s(%s)", id, cts.C.pty_shell, cts.C.pty_user)
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -1853,6 +1894,8 @@ func (cts *ClientConn) RpxLoop(crpx *ClientRpx, data []byte, wg *sync.WaitGroup)
|
||||
|
||||
defer wg.Done()
|
||||
|
||||
cts.C.log.Write(cts.Sid, LOG_INFO, "Starting rpx(%d) loop", crpx.id)
|
||||
|
||||
start_time = time.Now()
|
||||
|
||||
sc = bufio.NewScanner(bytes.NewReader(data))
|
||||
@ -1937,7 +1980,7 @@ done:
|
||||
if err != nil {
|
||||
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to send rpx(%d) stp to server - %s", crpx.id, err.Error())
|
||||
}
|
||||
cts.C.log.Write(cts.Sid, LOG_INFO, "Ending rpx(%d) read loop", crpx.id)
|
||||
cts.C.log.Write(cts.Sid, LOG_INFO, "Ending rpx(%d) loop", crpx.id)
|
||||
|
||||
crpx.ReqStop()
|
||||
|
||||
@ -1945,6 +1988,7 @@ done:
|
||||
delete(cts.rpx_map, crpx.id)
|
||||
cts.rpx_mtx.Unlock()
|
||||
cts.C.stats.rpx_sessions.Add(-1)
|
||||
cts.C.log.Write(cts.Sid, LOG_INFO, "Ended rpx(%d) loop", crpx.id)
|
||||
}
|
||||
|
||||
func (cts *ClientConn) StartRpx(id uint64, data []byte, wg *sync.WaitGroup) error {
|
||||
|
Reference in New Issue
Block a user