rpty at least working
This commit is contained in:
210
client.go
210
client.go
@ -5,9 +5,12 @@ import "context"
|
||||
import "crypto/tls"
|
||||
import "errors"
|
||||
import "fmt"
|
||||
import "io"
|
||||
import "log"
|
||||
import "net"
|
||||
import "net/http"
|
||||
import "os"
|
||||
import "os/exec"
|
||||
import "slices"
|
||||
import "strconv"
|
||||
import "strings"
|
||||
@ -17,12 +20,16 @@ import "time"
|
||||
import "unsafe"
|
||||
|
||||
import "golang.org/x/net/websocket"
|
||||
import "golang.org/x/sys/unix"
|
||||
|
||||
import "google.golang.org/grpc"
|
||||
import "google.golang.org/grpc/codes"
|
||||
import "google.golang.org/grpc/credentials"
|
||||
import "google.golang.org/grpc/credentials/insecure"
|
||||
import "google.golang.org/grpc/peer"
|
||||
import "google.golang.org/grpc/status"
|
||||
|
||||
import pts "github.com/creack/pty"
|
||||
import "github.com/prometheus/client_golang/prometheus"
|
||||
import "github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
|
||||
@ -32,6 +39,7 @@ type ClientConnMap map[ConnId]*ClientConn
|
||||
type ClientRouteMap map[RouteId]*ClientRoute
|
||||
type ClientPeerConnMap map[PeerId]*ClientPeerConn
|
||||
type ClientPeerCancelFuncMap map[PeerId]context.CancelFunc
|
||||
type ClientRptyMap map[uint64]*ClientRpty
|
||||
|
||||
// --------------------------------------------------------------------
|
||||
type ClientRouteConfig struct {
|
||||
@ -162,6 +170,12 @@ const (
|
||||
CLIENT_CONN_DISCONNECTED
|
||||
)
|
||||
|
||||
type ClientRpty struct {
|
||||
id uint64
|
||||
cmd *exec.Cmd
|
||||
tty *os.File
|
||||
}
|
||||
|
||||
// client connection to server
|
||||
type ClientConn struct {
|
||||
C *Client
|
||||
@ -192,6 +206,9 @@ type ClientConn struct {
|
||||
ptc_mtx sync.Mutex
|
||||
ptc_list *list.List
|
||||
|
||||
rpty_mtx sync.Mutex
|
||||
rpty_map ClientRptyMap
|
||||
|
||||
stop_req atomic.Bool
|
||||
stop_chan chan bool
|
||||
|
||||
@ -527,10 +544,6 @@ func (r *ClientRoute) ConnectToPeer(pts_id PeerId, route_option RouteOption, pts
|
||||
var tmout time.Duration
|
||||
var ok bool
|
||||
|
||||
// TODO: handle TTY
|
||||
// if route_option & RouteOption(ROUTE_OPTION_TTY) it must create a pseudo-tty insteaad of connecting to tcp address
|
||||
//
|
||||
|
||||
defer wg.Done()
|
||||
|
||||
tmout = time.Duration(r.cts.C.ptc_tmout)
|
||||
@ -811,6 +824,7 @@ func NewClientConn(c *Client, cfg *ClientConnConfig) *ClientConn {
|
||||
cts.stop_req.Store(false)
|
||||
cts.stop_chan = make(chan bool, 8)
|
||||
cts.ptc_list = list.New()
|
||||
cts.rpty_map = make(ClientRptyMap)
|
||||
|
||||
for i, _ = range cts.cfg.Routes {
|
||||
// override it to static regardless of the value passed in
|
||||
@ -1017,6 +1031,7 @@ func (cts *ClientConn) add_client_routes(routes []ClientRouteConfig) error {
|
||||
func (cts *ClientConn) disconnect_from_server(logmsg bool) {
|
||||
if cts.conn != nil {
|
||||
var r *ClientRoute
|
||||
var crp *ClientRpty
|
||||
|
||||
cts.discon_mtx.Lock()
|
||||
|
||||
@ -1028,6 +1043,17 @@ func (cts *ClientConn) disconnect_from_server(logmsg bool) {
|
||||
for _, r = range cts.route_map { r.ReqStop() }
|
||||
cts.route_mtx.Unlock()
|
||||
|
||||
// arrange to clean up all rpty objects
|
||||
cts.rpty_mtx.Lock()
|
||||
for _, crp = range cts.rpty_map {
|
||||
crp.tty.Close()
|
||||
crp.cmd.Process.Kill()
|
||||
// the loop in ReadRptyLoop() is supposed to be broken.
|
||||
// let's not inform the server of this connection.
|
||||
// the server should clean up itself upon connection error
|
||||
}
|
||||
cts.rpty_mtx.Unlock()
|
||||
|
||||
// don't care about double closes when this function is called from both RunTask() and ReqStop()
|
||||
cts.conn.Close()
|
||||
|
||||
@ -1303,7 +1329,7 @@ start_over:
|
||||
fallthrough
|
||||
case PACKET_KIND_RPTY_DATA:
|
||||
fallthrough
|
||||
case PACKET_KIND_RPTY_EOF:
|
||||
case PACKET_KIND_RPTY_SIZE:
|
||||
var x *Packet_RptyEvt
|
||||
var ok bool
|
||||
x, ok = pkt.U.(*Packet_RptyEvt)
|
||||
@ -1311,12 +1337,12 @@ start_over:
|
||||
err = cts.HandleRptyEvent(pkt.Kind, x.RptyEvt)
|
||||
if err != nil {
|
||||
cts.C.log.Write(cts.Sid, LOG_ERROR,
|
||||
"Failed to handle %s event from %s - %s",
|
||||
pkt.Kind.String(), cts.remote_addr_p, err.Error())
|
||||
"Failed to handle %s event for rpty(%d) from %s - %s",
|
||||
pkt.Kind.String(), x.RptyEvt.Id, cts.remote_addr_p, err.Error())
|
||||
} else {
|
||||
cts.C.log.Write(cts.Sid, LOG_DEBUG,
|
||||
"Handled %s event from %s",
|
||||
pkt.Kind.String(), cts.remote_addr_p)
|
||||
"Handled %s event for rpty(%d) from %s",
|
||||
pkt.Kind.String(), x.RptyEvt.Id, cts.remote_addr_p)
|
||||
}
|
||||
} else {
|
||||
cts.C.log.Write(cts.Sid, LOG_ERROR, "Invalid %s event from %s", pkt.Kind.String(), cts.remote_addr_p)
|
||||
@ -1394,8 +1420,172 @@ func (cts *ClientConn) ReportPacket(route_id RouteId, pts_id PeerId, packet_type
|
||||
return r.ReportPacket(pts_id, packet_type, event_data)
|
||||
}
|
||||
|
||||
|
||||
func (cts *ClientConn) ReadRptyLoop(crp *ClientRpty, wg *sync.WaitGroup) {
|
||||
|
||||
var poll_fds []unix.PollFd;
|
||||
var buf []byte
|
||||
var n int
|
||||
var err error
|
||||
|
||||
defer wg.Done()
|
||||
|
||||
poll_fds = []unix.PollFd{
|
||||
unix.PollFd{Fd: int32(crp.tty.Fd()), Events: unix.POLLIN},
|
||||
}
|
||||
|
||||
buf = make([]byte, 2048)
|
||||
for {
|
||||
n, err = unix.Poll(poll_fds, -1) // -1 means wait indefinitely
|
||||
if err != nil {
|
||||
if errors.Is(err, unix.EINTR) { continue }
|
||||
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to poll rpty(%d) stdout - %s", crp.id, err.Error())
|
||||
break
|
||||
}
|
||||
if n == 0 { // timed out
|
||||
continue
|
||||
}
|
||||
|
||||
if (poll_fds[0].Revents & (unix.POLLERR | unix.POLLHUP | unix.POLLNVAL)) != 0 {
|
||||
cts.C.log.Write(cts.Sid, LOG_DEBUG, "EOF detected on rpty(%d) stdout", crp.id)
|
||||
break
|
||||
}
|
||||
|
||||
if (poll_fds[0].Revents & unix.POLLIN) != 0 {
|
||||
n, err = crp.tty.Read(buf)
|
||||
if err != nil {
|
||||
if !errors.Is(err, io.EOF) {
|
||||
cts.C.log.Write(cts.Sid, LOG_DEBUG, "Failed to read rpty(%d) stdout - %s", crp.id, err.Error())
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
if n > 0 {
|
||||
err = cts.psc.Send(MakeRptyDataPacket(crp.id, buf[:n]))
|
||||
if err != nil {
|
||||
cts.C.log.Write(cts.Sid, LOG_DEBUG, "Failed to send rpty(%d) stdout to server - %s", crp.id, err.Error())
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cts.psc.Send(MakeRptyStopPacket(crp.id, ""))
|
||||
cts.C.log.Write(cts.Sid, LOG_INFO, "Ending rpty(%d) read loop", crp.id)
|
||||
|
||||
crp.tty.Close() // don't care about multiple closes
|
||||
crp.cmd.Process.Kill()
|
||||
crp.cmd.Wait()
|
||||
|
||||
cts.rpty_mtx.Lock()
|
||||
delete(cts.rpty_map, crp.id)
|
||||
cts.rpty_mtx.Unlock()
|
||||
}
|
||||
|
||||
func (cts *ClientConn) StartRpty(id uint64, wg *sync.WaitGroup) error {
|
||||
var crp *ClientRpty
|
||||
var ok bool
|
||||
var err error
|
||||
|
||||
cts.rpty_mtx.Lock()
|
||||
_, ok = cts.rpty_map[id]
|
||||
if ok {
|
||||
cts.rpty_mtx.Unlock()
|
||||
return fmt.Errorf("multiple start on rpty id %d", id)
|
||||
}
|
||||
|
||||
crp = &ClientRpty{ id: id }
|
||||
crp.cmd, crp.tty, err = connect_pty(cts.C.pty_shell, cts.C.pty_user)
|
||||
if err != nil {
|
||||
cts.rpty_mtx.Unlock()
|
||||
cts.psc.Send(MakeRptyStopPacket(id, err.Error()))
|
||||
return fmt.Errorf("unable to start rpty(%d) for %s(%s) - %s", id, cts.C.pty_shell, cts.C.pty_user, err.Error())
|
||||
}
|
||||
|
||||
cts.rpty_map[id] = crp
|
||||
|
||||
wg.Add(1)
|
||||
go cts.ReadRptyLoop(crp, wg)
|
||||
|
||||
cts.rpty_mtx.Unlock()
|
||||
|
||||
cts.C.log.Write(cts.Sid, LOG_INFO, "Started rpty(%d) for %s(%s)", id, cts.C.pty_shell, cts.C.pty_user)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cts *ClientConn) StopRpty(id uint64) error {
|
||||
var crp *ClientRpty
|
||||
var ok bool
|
||||
|
||||
cts.rpty_mtx.Lock()
|
||||
crp, ok = cts.rpty_map[id]
|
||||
if !ok {
|
||||
cts.rpty_mtx.Unlock()
|
||||
return fmt.Errorf("unknown rpty id %d", id)
|
||||
}
|
||||
|
||||
crp.tty.Close() // to break ReadRptyLoop()
|
||||
crp.cmd.Process.Kill() // to process wait to be done by ReadRptyLoop()
|
||||
cts.rpty_mtx.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cts *ClientConn) WriteRpty(id uint64, data []byte) error {
|
||||
var crp *ClientRpty
|
||||
var ok bool
|
||||
|
||||
cts.rpty_mtx.Lock()
|
||||
crp, ok = cts.rpty_map[id]
|
||||
if !ok {
|
||||
cts.rpty_mtx.Unlock()
|
||||
return fmt.Errorf("unknown rpty id %d", id)
|
||||
}
|
||||
|
||||
crp.tty.Write(data)
|
||||
cts.rpty_mtx.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cts *ClientConn) WriteRptySize(id uint64, data []byte) error {
|
||||
var crp *ClientRpty
|
||||
var ok bool
|
||||
var flds []string
|
||||
|
||||
cts.rpty_mtx.Lock()
|
||||
crp, ok = cts.rpty_map[id]
|
||||
if !ok {
|
||||
cts.rpty_mtx.Unlock()
|
||||
return fmt.Errorf("unknown rpty id %d", id)
|
||||
}
|
||||
|
||||
flds = strings.Split(string(data), " ")
|
||||
if len(flds) == 2 {
|
||||
var rows int
|
||||
var cols int
|
||||
rows, _ = strconv.Atoi(flds[0])
|
||||
cols, _ = strconv.Atoi(flds[1])
|
||||
pts.Setsize(crp.tty, &pts.Winsize{Rows: uint16(rows), Cols: uint16(cols)})
|
||||
}
|
||||
cts.rpty_mtx.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cts *ClientConn) HandleRptyEvent(packet_type PACKET_KIND, evt *RptyEvent) error {
|
||||
// TODO:
|
||||
switch packet_type {
|
||||
case PACKET_KIND_RPTY_START:
|
||||
return cts.StartRpty(evt.Id, &cts.C.wg)
|
||||
|
||||
case PACKET_KIND_RPTY_STOP:
|
||||
return cts.StopRpty(evt.Id)
|
||||
|
||||
case PACKET_KIND_RPTY_DATA:
|
||||
return cts.WriteRpty(evt.Id, evt.Data)
|
||||
|
||||
case PACKET_KIND_RPTY_SIZE:
|
||||
return cts.WriteRptySize(evt.Id, evt.Data)
|
||||
}
|
||||
|
||||
// ignore other packet types
|
||||
return nil
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user