Files
hodu/client-cts-rpty.go

227 lines
5.5 KiB
Go

package hodu
import "errors"
import "fmt"
import "io"
import "strconv"
import "strings"
import "sync"
import pts "github.com/creack/pty"
import "golang.org/x/sys/unix"
// rpty
// FindClientRptyById looks up the rpty session registered under id in
// cts.rpty_map while holding cts.rpty_mtx. It returns the session, or nil
// when no entry exists for id.
func (cts *ClientConn) FindClientRptyById(id uint64) *ClientRpty {
	cts.rpty_mtx.Lock()
	defer cts.rpty_mtx.Unlock()

	if found, exists := cts.rpty_map[id]; exists {
		return found
	}
	return nil
}
// RptyLoop relays the output of one rpty session to the server until the
// session ends, then tears the session down.
//
// It polls two descriptors: the descriptor behind crp.tty and crp.pfd[0],
// the read end of the pipe created in StartRpty. Whatever is read from
// crp.tty is sent as a PACKET_KIND_RPTY_DATA packet through cts.psc. The loop
// ends on a poll/read/send failure, on EOF, on activity on the pipe, or on an
// error/hang-up condition reported by poll.
//
// After the loop a stop packet is sent, crp.ReqStop() is called, the command
// is waited for, crp.tty and both pipe descriptors are closed, and the
// session is removed from cts.rpty_map. wg.Done() is called on return, so the
// caller must have done wg.Add(1) beforehand.
func (cts *ClientConn) RptyLoop(crp *ClientRpty, wg *sync.WaitGroup) {
	defer wg.Done()

	cts.C.log.Write(cts.Sid, LOG_INFO, "Started rpty(%d) for %s(%s)", crp.id, cts.C.pty_shell, cts.C.pty_user)
	cts.C.stats.rpty_sessions.Add(1)

	var chunk [2048]byte
	watch := []unix.PollFd{
		{Fd: int32(crp.tty.Fd()), Events: unix.POLLIN}, // session output
		{Fd: int32(crp.pfd[0]), Events: unix.POLLIN},   // stop signal pipe
	}

relay:
	for {
		// a negative timeout makes poll block until an event arrives
		ready, poll_err := unix.Poll(watch, -1)
		if poll_err != nil {
			if errors.Is(poll_err, unix.EINTR) {
				continue
			}
			cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to poll rpty(%d) stdout - %s", crp.id, poll_err.Error())
			break relay
		}
		if ready == 0 {
			// nothing reported; poll again
			continue
		}

		tty_ev := watch[0].Revents
		stop_ev := watch[1].Revents
		tty_readable := (tty_ev & unix.POLLIN) != 0

		if tty_readable {
			got, read_err := crp.tty.Read(chunk[:])
			if got > 0 {
				// forward the data first; a read may return data together with an error
				var send_err error = cts.psc.Send(MakeRptyDataPacket(crp.id, chunk[:got]))
				if send_err != nil {
					cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to send %s from rpty(%d) stdout to server - %s", PACKET_KIND_RPTY_DATA.String(), crp.id, send_err.Error())
					break relay
				}
			}
			if read_err != nil {
				// EOF ends the loop quietly; anything else is logged
				if !errors.Is(read_err, io.EOF) {
					cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to read rpty(%d) stdout - %s", crp.id, read_err.Error())
				}
				break relay
			}
		}

		// The cases are evaluated top to bottom and the first match ends the loop.
		switch {
		case (stop_ev & unix.POLLIN) != 0:
			// the pipe is not drained because it is closed right after the loop
			cts.C.log.Write(cts.Sid, LOG_DEBUG, "Stop request noticed on rpty(%d) signal pipe", crp.id)
			break relay

		case (tty_ev & (unix.POLLERR | unix.POLLNVAL)) != 0:
			cts.C.log.Write(cts.Sid, LOG_DEBUG, "Error detected on rpty(%d) stdout", crp.id)
			break relay

		case (stop_ev & (unix.POLLERR | unix.POLLHUP | unix.POLLNVAL)) != 0:
			cts.C.log.Write(cts.Sid, LOG_DEBUG, "EOF detected on rpty(%d) signal pipe", crp.id)
			break relay

		case (tty_ev&unix.POLLHUP) != 0 && !tty_readable:
			// a hang-up only counts as EOF once there is no more readable data
			cts.C.log.Write(cts.Sid, LOG_DEBUG, "EOF detected on rpty(%d) stdout", crp.id)
			break relay
		}
	}

	cts.C.log.Write(cts.Sid, LOG_DEBUG, "Ending rpty(%d) loop", crp.id)

	// tell the server that this session is over; failure is only logged
	var stop_err error = cts.psc.Send(MakeRptyStopPacket(crp.id, ""))
	if stop_err != nil {
		cts.C.log.Write(cts.Sid, LOG_WARN, "Failed to send %s from rpty(%d) to server - %s", PACKET_KIND_RPTY_STOP.String(), crp.id, stop_err.Error())
	}

	// release everything owned by the session
	crp.ReqStop()
	crp.cmd.Wait()
	crp.tty.Close()
	unix.Close(crp.pfd[0])
	unix.Close(crp.pfd[1])

	// unregister the session
	cts.rpty_mtx.Lock()
	delete(cts.rpty_map, crp.id)
	cts.rpty_mtx.Unlock()

	cts.C.stats.rpty_sessions.Add(-1)
	cts.C.log.Write(cts.Sid, LOG_DEBUG, "Ended rpty(%d) loop", crp.id)
}
// StartRpty creates the rpty session identified by id and starts RptyLoop for
// it in a new goroutine tracked by wg.
//
// It fails if a session with the same id is already registered, if the signal
// pipe cannot be created, or if connect_pty fails. On the latter two failures
// a stop packet carrying the error text is sent through cts.psc before the
// error is returned. cts.rpty_mtx is held from the duplicate check until the
// session is registered and the goroutine has been launched; on failure it is
// released before the stop packet is sent.
func (cts *ClientConn) StartRpty(id uint64, wg *sync.WaitGroup) error {
	cts.rpty_mtx.Lock()

	if _, dup := cts.rpty_map[id]; dup {
		cts.rpty_mtx.Unlock()
		return fmt.Errorf("multiple start on rpty id %d", id)
	}

	session := &ClientRpty{cts: cts, id: id}

	// pipe used to wake up the poll in RptyLoop
	err := unix.Pipe(session.pfd[:])
	if err != nil {
		cts.rpty_mtx.Unlock()
		cts.psc.Send(MakeRptyStopPacket(id, err.Error()))
		return fmt.Errorf("unable to create rpty(%d) event fd for %s(%s) - %s", id, cts.C.pty_shell, cts.C.pty_user, err.Error())
	}

	session.cmd, session.tty, err = connect_pty(cts.C.pty_shell, cts.C.pty_user)
	if err != nil {
		cts.rpty_mtx.Unlock()
		cts.psc.Send(MakeRptyStopPacket(id, err.Error()))
		// the pipe was created above, so give it back
		unix.Close(session.pfd[0])
		unix.Close(session.pfd[1])
		return fmt.Errorf("unable to start rpty(%d) for %s(%s) - %s", id, cts.C.pty_shell, cts.C.pty_user, err.Error())
	}

	// best effort: switch both pipe ends to non-blocking mode
	for _, fd := range session.pfd {
		if flags, e := unix.FcntlInt(uintptr(fd), unix.F_GETFL, 0); e == nil {
			unix.FcntlInt(uintptr(fd), unix.F_SETFL, flags|unix.O_NONBLOCK)
		}
	}

	cts.rpty_map[id] = session
	wg.Add(1) // matched by wg.Done() in RptyLoop
	go cts.RptyLoop(session, wg)

	cts.rpty_mtx.Unlock()
	return nil
}
// StopRpty calls ReqStop() on the rpty session registered under id.
// It returns an error when no session with that id is registered.
func (cts *ClientConn) StopRpty(id uint64) error {
	if target := cts.FindClientRptyById(id); target != nil {
		target.ReqStop()
		return nil
	}
	return fmt.Errorf("unknown rpty id %d", id)
}
// WriteRpty writes data to the tty of the rpty session registered under id.
//
// It returns an error when no session with that id is registered, or when the
// write to the tty fails. The write error used to be discarded, which made
// the function report success even though the data never reached the session.
func (cts *ClientConn) WriteRpty(id uint64, data []byte) error {
	var crp *ClientRpty
	var err error

	crp = cts.FindClientRptyById(id)
	if crp == nil {
		return fmt.Errorf("unknown rpty id %d", id)
	}

	// Write reports a non-nil error for short writes too, so checking the
	// error alone is sufficient.
	_, err = crp.tty.Write(data)
	if err != nil {
		return fmt.Errorf("unable to write to rpty(%d) - %w", id, err)
	}
	return nil
}
// WriteRptySize applies a window size to the tty of the rpty session
// registered under id.
//
// data must hold two decimal numbers separated by a single space, in the
// order "rows cols". Each number must fit in 16 bits because that is the
// width of the Rows and Cols fields of pts.Winsize.
//
// It returns an error when no session with that id is registered, when data
// is not in the expected form, when a number is not a valid unsigned 16-bit
// value, or when setting the size on the tty fails. Unparsable or
// out-of-range values used to be truncated silently into the 16-bit fields
// and applied anyway.
func (cts *ClientConn) WriteRptySize(id uint64, data []byte) error {
	var crp *ClientRpty
	var flds []string
	var rows uint64
	var cols uint64
	var err error

	crp = cts.FindClientRptyById(id)
	if crp == nil {
		return fmt.Errorf("unknown rpty id %d", id)
	}

	flds = strings.Split(string(data), " ")
	if len(flds) != 2 {
		return fmt.Errorf("invalid size %q for rpty(%d)", data, id)
	}

	// bitSize 16 makes ParseUint reject anything that would not fit in uint16
	rows, err = strconv.ParseUint(flds[0], 10, 16)
	if err != nil {
		return fmt.Errorf("invalid rows %q for rpty(%d) - %w", flds[0], id, err)
	}
	cols, err = strconv.ParseUint(flds[1], 10, 16)
	if err != nil {
		return fmt.Errorf("invalid columns %q for rpty(%d) - %w", flds[1], id, err)
	}

	err = pts.Setsize(crp.tty, &pts.Winsize{Rows: uint16(rows), Cols: uint16(cols)})
	if err != nil {
		return fmt.Errorf("unable to set size of rpty(%d) to %d rows and %d columns - %w", id, rows, cols, err)
	}
	return nil
}
// HandleRptyEvent dispatches an rpty event to the handler matching
// packet_type and returns that handler's result:
//
//	PACKET_KIND_RPTY_START -> StartRpty (goroutine tracked by cts.C.wg)
//	PACKET_KIND_RPTY_STOP  -> StopRpty
//	PACKET_KIND_RPTY_DATA  -> WriteRpty
//	PACKET_KIND_RPTY_SIZE  -> WriteRptySize
//
// Any other packet type is ignored and nil is returned.
func (cts *ClientConn) HandleRptyEvent(packet_type PACKET_KIND, evt *RptyEvent) error {
	var result error

	switch packet_type {
	case PACKET_KIND_RPTY_START:
		result = cts.StartRpty(evt.Id, &cts.C.wg)

	case PACKET_KIND_RPTY_STOP:
		result = cts.StopRpty(evt.Id)

	case PACKET_KIND_RPTY_DATA:
		result = cts.WriteRpty(evt.Id, evt.Data)

	case PACKET_KIND_RPTY_SIZE:
		result = cts.WriteRptySize(evt.Id, evt.Data)

	default:
		// not an rpty packet type handled here; evt is left untouched
	}

	return result
}