227 lines
5.5 KiB
Go
227 lines
5.5 KiB
Go
package hodu
|
|
|
|
import "errors"
|
|
import "fmt"
|
|
import "io"
|
|
import "strconv"
|
|
import "strings"
|
|
import "sync"
|
|
|
|
import pts "github.com/creack/pty"
|
|
import "golang.org/x/sys/unix"
|
|
|
|
// rpty
|
|
func (cts *ClientConn) FindClientRptyById(id uint64) *ClientRpty {
|
|
var crp *ClientRpty
|
|
var ok bool
|
|
|
|
cts.rpty_mtx.Lock()
|
|
crp, ok = cts.rpty_map[id]
|
|
cts.rpty_mtx.Unlock()
|
|
|
|
if !ok { crp = nil }
|
|
return crp
|
|
}
|
|
|
|
func (cts *ClientConn) RptyLoop(crp *ClientRpty, wg *sync.WaitGroup) {
|
|
|
|
var poll_fds []unix.PollFd
|
|
var buf [2048]byte
|
|
var n int
|
|
var out_revents int16
|
|
var sig_revents int16
|
|
var err error
|
|
|
|
defer wg.Done()
|
|
|
|
cts.C.log.Write(cts.Sid, LOG_INFO, "Started rpty(%d) for %s(%s)", crp.id, cts.C.pty_shell, cts.C.pty_user)
|
|
|
|
cts.C.stats.rpty_sessions.Add(1)
|
|
|
|
poll_fds = []unix.PollFd{
|
|
unix.PollFd{Fd: int32(crp.tty.Fd()), Events: unix.POLLIN},
|
|
unix.PollFd{Fd: int32(crp.pfd[0]), Events: unix.POLLIN},
|
|
}
|
|
|
|
for {
|
|
n, err = unix.Poll(poll_fds, -1) // -1 means wait indefinitely
|
|
if err != nil {
|
|
if errors.Is(err, unix.EINTR) { continue }
|
|
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to poll rpty(%d) stdout - %s", crp.id, err.Error())
|
|
break
|
|
}
|
|
if n == 0 { // timed out
|
|
continue
|
|
}
|
|
|
|
out_revents = poll_fds[0].Revents
|
|
sig_revents = poll_fds[1].Revents
|
|
|
|
if (out_revents & unix.POLLIN) != 0 {
|
|
n, err = crp.tty.Read(buf[:])
|
|
if n > 0 {
|
|
var err2 error
|
|
err2 = cts.psc.Send(MakeRptyDataPacket(crp.id, buf[:n]))
|
|
if err2 != nil {
|
|
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to send %s from rpty(%d) stdout to server - %s", PACKET_KIND_RPTY_DATA.String(), crp.id, err2.Error())
|
|
break
|
|
}
|
|
}
|
|
if err != nil {
|
|
if !errors.Is(err, io.EOF) {
|
|
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to read rpty(%d) stdout - %s", crp.id, err.Error())
|
|
}
|
|
break
|
|
}
|
|
}
|
|
if (sig_revents & unix.POLLIN) != 0 {
|
|
// don't care to read the pipe as it is closed after the loop
|
|
//unix.Read(crp.pfd[0], )
|
|
cts.C.log.Write(cts.Sid, LOG_DEBUG, "Stop request noticed on rpty(%d) signal pipe", crp.id)
|
|
break
|
|
}
|
|
if (out_revents & (unix.POLLERR | unix.POLLNVAL)) != 0 {
|
|
cts.C.log.Write(cts.Sid, LOG_DEBUG, "Error detected on rpty(%d) stdout", crp.id)
|
|
break
|
|
}
|
|
if (sig_revents & (unix.POLLERR | unix.POLLHUP | unix.POLLNVAL)) != 0 {
|
|
cts.C.log.Write(cts.Sid, LOG_DEBUG, "EOF detected on rpty(%d) signal pipe", crp.id)
|
|
break
|
|
}
|
|
if (out_revents & unix.POLLHUP) != 0 && (out_revents & unix.POLLIN) == 0 {
|
|
cts.C.log.Write(cts.Sid, LOG_DEBUG, "EOF detected on rpty(%d) stdout", crp.id)
|
|
break
|
|
}
|
|
}
|
|
|
|
cts.C.log.Write(cts.Sid, LOG_DEBUG, "Ending rpty(%d) loop", crp.id)
|
|
err = cts.psc.Send(MakeRptyStopPacket(crp.id, ""))
|
|
if err != nil {
|
|
cts.C.log.Write(cts.Sid, LOG_WARN, "Failed to send %s from rpty(%d) to server - %s", PACKET_KIND_RPTY_STOP.String(), crp.id, err.Error())
|
|
}
|
|
|
|
crp.ReqStop()
|
|
crp.cmd.Wait()
|
|
crp.tty.Close()
|
|
unix.Close(crp.pfd[0])
|
|
unix.Close(crp.pfd[1])
|
|
|
|
cts.rpty_mtx.Lock()
|
|
delete(cts.rpty_map, crp.id)
|
|
cts.rpty_mtx.Unlock()
|
|
|
|
cts.C.stats.rpty_sessions.Add(-1)
|
|
cts.C.log.Write(cts.Sid, LOG_DEBUG, "Ended rpty(%d) loop", crp.id)
|
|
}
|
|
|
|
func (cts *ClientConn) StartRpty(id uint64, wg *sync.WaitGroup) error {
|
|
var crp *ClientRpty
|
|
var ok bool
|
|
var i int
|
|
var err error
|
|
|
|
cts.rpty_mtx.Lock()
|
|
_, ok = cts.rpty_map[id]
|
|
if ok {
|
|
cts.rpty_mtx.Unlock()
|
|
return fmt.Errorf("multiple start on rpty id %d", id)
|
|
}
|
|
|
|
crp = &ClientRpty{ cts: cts, id: id }
|
|
err = unix.Pipe(crp.pfd[:])
|
|
if err != nil {
|
|
cts.rpty_mtx.Unlock()
|
|
cts.psc.Send(MakeRptyStopPacket(id, err.Error()))
|
|
return fmt.Errorf("unable to create rpty(%d) event fd for %s(%s) - %s", id, cts.C.pty_shell, cts.C.pty_user, err.Error())
|
|
}
|
|
crp.cmd, crp.tty, err = connect_pty(cts.C.pty_shell, cts.C.pty_user)
|
|
if err != nil {
|
|
cts.rpty_mtx.Unlock()
|
|
cts.psc.Send(MakeRptyStopPacket(id, err.Error()))
|
|
unix.Close(crp.pfd[0])
|
|
unix.Close(crp.pfd[1])
|
|
return fmt.Errorf("unable to start rpty(%d) for %s(%s) - %s", id, cts.C.pty_shell, cts.C.pty_user, err.Error())
|
|
}
|
|
|
|
for i = 0; i < 2; i++ {
|
|
var flags int
|
|
flags, err = unix.FcntlInt(uintptr(crp.pfd[i]), unix.F_GETFL, 0)
|
|
if err == nil {
|
|
unix.FcntlInt(uintptr(crp.pfd[i]), unix.F_SETFL, flags | unix.O_NONBLOCK)
|
|
}
|
|
}
|
|
|
|
cts.rpty_map[id] = crp
|
|
|
|
wg.Add(1)
|
|
go cts.RptyLoop(crp, wg)
|
|
|
|
cts.rpty_mtx.Unlock()
|
|
|
|
return nil
|
|
}
|
|
|
|
func (cts *ClientConn) StopRpty(id uint64) error {
|
|
var crp *ClientRpty
|
|
|
|
crp = cts.FindClientRptyById(id)
|
|
if crp == nil {
|
|
return fmt.Errorf("unknown rpty id %d", id)
|
|
}
|
|
|
|
crp.ReqStop()
|
|
return nil
|
|
}
|
|
|
|
func (cts *ClientConn) WriteRpty(id uint64, data []byte) error {
|
|
var crp *ClientRpty
|
|
|
|
crp = cts.FindClientRptyById(id)
|
|
if crp == nil {
|
|
return fmt.Errorf("unknown rpty id %d", id)
|
|
}
|
|
|
|
crp.tty.Write(data)
|
|
return nil
|
|
}
|
|
|
|
func (cts *ClientConn) WriteRptySize(id uint64, data []byte) error {
|
|
var crp *ClientRpty
|
|
var flds []string
|
|
|
|
crp = cts.FindClientRptyById(id)
|
|
if crp == nil {
|
|
return fmt.Errorf("unknown rpty id %d", id)
|
|
}
|
|
|
|
flds = strings.Split(string(data), " ")
|
|
if len(flds) == 2 {
|
|
var rows int
|
|
var cols int
|
|
rows, _ = strconv.Atoi(flds[0])
|
|
cols, _ = strconv.Atoi(flds[1])
|
|
pts.Setsize(crp.tty, &pts.Winsize{Rows: uint16(rows), Cols: uint16(cols)})
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (cts *ClientConn) HandleRptyEvent(packet_type PACKET_KIND, evt *RptyEvent) error {
|
|
|
|
switch packet_type {
|
|
case PACKET_KIND_RPTY_START:
|
|
return cts.StartRpty(evt.Id, &cts.C.wg)
|
|
|
|
case PACKET_KIND_RPTY_STOP:
|
|
return cts.StopRpty(evt.Id)
|
|
|
|
case PACKET_KIND_RPTY_DATA:
|
|
return cts.WriteRpty(evt.Id, evt.Data)
|
|
|
|
case PACKET_KIND_RPTY_SIZE:
|
|
return cts.WriteRptySize(evt.Id, evt.Data)
|
|
}
|
|
|
|
// ignore other packet types
|
|
return nil
|
|
}
|