rpty at least working

This commit is contained in:
2025-08-12 02:50:10 +09:00
parent 05cb0823b4
commit d818acc53d
16 changed files with 568 additions and 403 deletions

210
client.go
View File

@ -5,9 +5,12 @@ import "context"
import "crypto/tls"
import "errors"
import "fmt"
import "io"
import "log"
import "net"
import "net/http"
import "os"
import "os/exec"
import "slices"
import "strconv"
import "strings"
@ -17,12 +20,16 @@ import "time"
import "unsafe"
import "golang.org/x/net/websocket"
import "golang.org/x/sys/unix"
import "google.golang.org/grpc"
import "google.golang.org/grpc/codes"
import "google.golang.org/grpc/credentials"
import "google.golang.org/grpc/credentials/insecure"
import "google.golang.org/grpc/peer"
import "google.golang.org/grpc/status"
import pts "github.com/creack/pty"
import "github.com/prometheus/client_golang/prometheus"
import "github.com/prometheus/client_golang/prometheus/promhttp"
@ -32,6 +39,7 @@ type ClientConnMap map[ConnId]*ClientConn
type ClientRouteMap map[RouteId]*ClientRoute
type ClientPeerConnMap map[PeerId]*ClientPeerConn
type ClientPeerCancelFuncMap map[PeerId]context.CancelFunc
type ClientRptyMap map[uint64]*ClientRpty
// --------------------------------------------------------------------
type ClientRouteConfig struct {
@ -162,6 +170,12 @@ const (
CLIENT_CONN_DISCONNECTED
)
type ClientRpty struct {
id uint64
cmd *exec.Cmd
tty *os.File
}
// client connection to server
type ClientConn struct {
C *Client
@ -192,6 +206,9 @@ type ClientConn struct {
ptc_mtx sync.Mutex
ptc_list *list.List
rpty_mtx sync.Mutex
rpty_map ClientRptyMap
stop_req atomic.Bool
stop_chan chan bool
@ -527,10 +544,6 @@ func (r *ClientRoute) ConnectToPeer(pts_id PeerId, route_option RouteOption, pts
var tmout time.Duration
var ok bool
// TODO: handle TTY
// if route_option & RouteOption(ROUTE_OPTION_TTY) it must create a pseudo-tty insteaad of connecting to tcp address
//
defer wg.Done()
tmout = time.Duration(r.cts.C.ptc_tmout)
@ -811,6 +824,7 @@ func NewClientConn(c *Client, cfg *ClientConnConfig) *ClientConn {
cts.stop_req.Store(false)
cts.stop_chan = make(chan bool, 8)
cts.ptc_list = list.New()
cts.rpty_map = make(ClientRptyMap)
for i, _ = range cts.cfg.Routes {
// override it to static regardless of the value passed in
@ -1017,6 +1031,7 @@ func (cts *ClientConn) add_client_routes(routes []ClientRouteConfig) error {
func (cts *ClientConn) disconnect_from_server(logmsg bool) {
if cts.conn != nil {
var r *ClientRoute
var crp *ClientRpty
cts.discon_mtx.Lock()
@ -1028,6 +1043,17 @@ func (cts *ClientConn) disconnect_from_server(logmsg bool) {
for _, r = range cts.route_map { r.ReqStop() }
cts.route_mtx.Unlock()
// arrange to clean up all rpty objects
cts.rpty_mtx.Lock()
for _, crp = range cts.rpty_map {
crp.tty.Close()
crp.cmd.Process.Kill()
// the loop in ReadRptyLoop() is supposed to be broken.
// let's not inform the server of this connection.
// the server should clean up itself upon connection error
}
cts.rpty_mtx.Unlock()
// don't care about double closes when this function is called from both RunTask() and ReqStop()
cts.conn.Close()
@ -1303,7 +1329,7 @@ start_over:
fallthrough
case PACKET_KIND_RPTY_DATA:
fallthrough
case PACKET_KIND_RPTY_EOF:
case PACKET_KIND_RPTY_SIZE:
var x *Packet_RptyEvt
var ok bool
x, ok = pkt.U.(*Packet_RptyEvt)
@ -1311,12 +1337,12 @@ start_over:
err = cts.HandleRptyEvent(pkt.Kind, x.RptyEvt)
if err != nil {
cts.C.log.Write(cts.Sid, LOG_ERROR,
"Failed to handle %s event from %s - %s",
pkt.Kind.String(), cts.remote_addr_p, err.Error())
"Failed to handle %s event for rpty(%d) from %s - %s",
pkt.Kind.String(), x.RptyEvt.Id, cts.remote_addr_p, err.Error())
} else {
cts.C.log.Write(cts.Sid, LOG_DEBUG,
"Handled %s event from %s",
pkt.Kind.String(), cts.remote_addr_p)
"Handled %s event for rpty(%d) from %s",
pkt.Kind.String(), x.RptyEvt.Id, cts.remote_addr_p)
}
} else {
cts.C.log.Write(cts.Sid, LOG_ERROR, "Invalid %s event from %s", pkt.Kind.String(), cts.remote_addr_p)
@ -1394,8 +1420,172 @@ func (cts *ClientConn) ReportPacket(route_id RouteId, pts_id PeerId, packet_type
return r.ReportPacket(pts_id, packet_type, event_data)
}
func (cts *ClientConn) ReadRptyLoop(crp *ClientRpty, wg *sync.WaitGroup) {
var poll_fds []unix.PollFd;
var buf []byte
var n int
var err error
defer wg.Done()
poll_fds = []unix.PollFd{
unix.PollFd{Fd: int32(crp.tty.Fd()), Events: unix.POLLIN},
}
buf = make([]byte, 2048)
for {
n, err = unix.Poll(poll_fds, -1) // -1 means wait indefinitely
if err != nil {
if errors.Is(err, unix.EINTR) { continue }
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to poll rpty(%d) stdout - %s", crp.id, err.Error())
break
}
if n == 0 { // timed out
continue
}
if (poll_fds[0].Revents & (unix.POLLERR | unix.POLLHUP | unix.POLLNVAL)) != 0 {
cts.C.log.Write(cts.Sid, LOG_DEBUG, "EOF detected on rpty(%d) stdout", crp.id)
break
}
if (poll_fds[0].Revents & unix.POLLIN) != 0 {
n, err = crp.tty.Read(buf)
if err != nil {
if !errors.Is(err, io.EOF) {
cts.C.log.Write(cts.Sid, LOG_DEBUG, "Failed to read rpty(%d) stdout - %s", crp.id, err.Error())
}
break
}
if n > 0 {
err = cts.psc.Send(MakeRptyDataPacket(crp.id, buf[:n]))
if err != nil {
cts.C.log.Write(cts.Sid, LOG_DEBUG, "Failed to send rpty(%d) stdout to server - %s", crp.id, err.Error())
break
}
}
}
}
cts.psc.Send(MakeRptyStopPacket(crp.id, ""))
cts.C.log.Write(cts.Sid, LOG_INFO, "Ending rpty(%d) read loop", crp.id)
crp.tty.Close() // don't care about multiple closes
crp.cmd.Process.Kill()
crp.cmd.Wait()
cts.rpty_mtx.Lock()
delete(cts.rpty_map, crp.id)
cts.rpty_mtx.Unlock()
}
func (cts *ClientConn) StartRpty(id uint64, wg *sync.WaitGroup) error {
var crp *ClientRpty
var ok bool
var err error
cts.rpty_mtx.Lock()
_, ok = cts.rpty_map[id]
if ok {
cts.rpty_mtx.Unlock()
return fmt.Errorf("multiple start on rpty id %d", id)
}
crp = &ClientRpty{ id: id }
crp.cmd, crp.tty, err = connect_pty(cts.C.pty_shell, cts.C.pty_user)
if err != nil {
cts.rpty_mtx.Unlock()
cts.psc.Send(MakeRptyStopPacket(id, err.Error()))
return fmt.Errorf("unable to start rpty(%d) for %s(%s) - %s", id, cts.C.pty_shell, cts.C.pty_user, err.Error())
}
cts.rpty_map[id] = crp
wg.Add(1)
go cts.ReadRptyLoop(crp, wg)
cts.rpty_mtx.Unlock()
cts.C.log.Write(cts.Sid, LOG_INFO, "Started rpty(%d) for %s(%s)", id, cts.C.pty_shell, cts.C.pty_user)
return nil
}
func (cts *ClientConn) StopRpty(id uint64) error {
var crp *ClientRpty
var ok bool
cts.rpty_mtx.Lock()
crp, ok = cts.rpty_map[id]
if !ok {
cts.rpty_mtx.Unlock()
return fmt.Errorf("unknown rpty id %d", id)
}
crp.tty.Close() // to break ReadRptyLoop()
crp.cmd.Process.Kill() // to process wait to be done by ReadRptyLoop()
cts.rpty_mtx.Unlock()
return nil
}
func (cts *ClientConn) WriteRpty(id uint64, data []byte) error {
var crp *ClientRpty
var ok bool
cts.rpty_mtx.Lock()
crp, ok = cts.rpty_map[id]
if !ok {
cts.rpty_mtx.Unlock()
return fmt.Errorf("unknown rpty id %d", id)
}
crp.tty.Write(data)
cts.rpty_mtx.Unlock()
return nil
}
func (cts *ClientConn) WriteRptySize(id uint64, data []byte) error {
var crp *ClientRpty
var ok bool
var flds []string
cts.rpty_mtx.Lock()
crp, ok = cts.rpty_map[id]
if !ok {
cts.rpty_mtx.Unlock()
return fmt.Errorf("unknown rpty id %d", id)
}
flds = strings.Split(string(data), " ")
if len(flds) == 2 {
var rows int
var cols int
rows, _ = strconv.Atoi(flds[0])
cols, _ = strconv.Atoi(flds[1])
pts.Setsize(crp.tty, &pts.Winsize{Rows: uint16(rows), Cols: uint16(cols)})
}
cts.rpty_mtx.Unlock()
return nil
}
func (cts *ClientConn) HandleRptyEvent(packet_type PACKET_KIND, evt *RptyEvent) error {
// TODO:
switch packet_type {
case PACKET_KIND_RPTY_START:
return cts.StartRpty(evt.Id, &cts.C.wg)
case PACKET_KIND_RPTY_STOP:
return cts.StopRpty(evt.Id)
case PACKET_KIND_RPTY_DATA:
return cts.WriteRpty(evt.Id, evt.Data)
case PACKET_KIND_RPTY_SIZE:
return cts.WriteRptySize(evt.Id, evt.Data)
}
// ignore other packet types
return nil
}