diff --git a/Makefile b/Makefile index fed4981..2460349 100644 --- a/Makefile +++ b/Makefile @@ -21,6 +21,7 @@ SRCS=\ hodu_grpc.pb.go \ jwt.go \ packet.go \ + pty.go \ server.go \ server-ctl.go \ server-metrics.go \ diff --git a/atom.go b/atom.go index f715daa..4c0c423 100644 --- a/atom.go +++ b/atom.go @@ -13,7 +13,7 @@ func (av* Atom[T]) Set(v T) { func (av* Atom[T]) Get() T { var v interface{} v = av.val.Load() - if v == nil { + if v == nil { var t T return t // return the zero-value } diff --git a/client-pty.go b/client-pty.go index eb35ea7..87219eb 100644 --- a/client-pty.go +++ b/client-pty.go @@ -2,16 +2,13 @@ package hodu import "encoding/json" import "errors" -import "fmt" import "io" import "net/http" import "os" import "os/exec" -import "os/user" import "strconv" import "strings" import "sync" -import "syscall" import "text/template" import pts "github.com/creack/pty" @@ -35,74 +32,11 @@ func (pty *client_pty_ws) Identity() string { return pty.Id } -func (pty *client_pty_ws) send_ws_data(ws *websocket.Conn, type_val string, data string) error { - var msg []byte - var err error - - msg, err = json.Marshal(json_xterm_ws_event{Type: type_val, Data: []string{ data } }) - if err == nil { err = websocket.Message.Send(ws, msg) } - return err -} - - -func (pty *client_pty_ws) connect_pty(username string, password string) (*exec.Cmd, *os.File, error) { - var c *Client - var cmd *exec.Cmd - var tty *os.File - var err error - - // username and password are not used yet. - c = pty.C - - if c.pty_shell == "" { - return nil, nil, fmt.Errorf("blank pty shell") - } - - cmd = exec.Command(c.pty_shell); - if c.pty_user != "" { - var uid int - var gid int - var u *user.User - - u, err = user.Lookup(c.pty_user) - if err != nil { return nil, nil, err } - - uid, _ = strconv.Atoi(u.Uid) - gid, _ = strconv.Atoi(u.Gid) - cmd.SysProcAttr = &syscall.SysProcAttr{ - Credential: &syscall.Credential{ - Uid: uint32(uid), - Gid: uint32(gid), - }, - Setsid: true, - } - cmd.Dir = u.HomeDir - cmd.Env = append(cmd.Env, - "HOME=" + u.HomeDir, - "LOGNAME=" + u.Username, - "PATH=" + os.Getenv("PATH"), - "SHELL=" + c.pty_shell, - "TERM=xterm", - "USER=" + u.Username, - ) - } - - tty, err = pts.Start(cmd) - if err != nil { - return nil, nil, err - } - - //syscall.SetNonblock(int(tty.Fd()), true); - unix.SetNonblock(int(tty.Fd()), true); - - return cmd, tty, nil -} - func (pty *client_pty_ws) ServeWebsocket(ws *websocket.Conn) (int, error) { var c *Client var req *http.Request - var username string - var password string + //var username string + //var password string var in *os.File var out *os.File var tty *os.File @@ -161,7 +95,7 @@ func (pty *client_pty_ws) ServeWebsocket(ws *websocket.Conn) (int, error) { break } if n > 0 { - err = pty.send_ws_data(ws, "iov", string(buf[:n])) + err = send_ws_data_for_xterm(ws, "iov", string(buf[:n])) if err != nil { c.log.Write(pty.Id, LOG_ERROR, "[%s] Failed to send to websocket - %s", req.RemoteAddr, err.Error()) break @@ -186,21 +120,21 @@ ws_recv_loop: switch ev.Type { case "open": if tty == nil && len(ev.Data) == 2 { - username = string(ev.Data[0]) - password = string(ev.Data[1]) + //username = string(ev.Data[0]) + //password = string(ev.Data[1]) wg.Add(1) go func() { var err error defer wg.Done() - cmd, tty, err = pty.connect_pty(username, password) + cmd, tty, err = connect_pty(c.pty_shell, c.pty_user) if err != nil { c.log.Write(pty.Id, LOG_ERROR, "[%s] Failed to connect pty - %s", req.RemoteAddr, err.Error()) - pty.send_ws_data(ws, "error", err.Error()) + send_ws_data_for_xterm(ws, "error", err.Error()) ws.Close() // dirty way to flag out the error } else { - err = pty.send_ws_data(ws, "status", "opened") + err = send_ws_data_for_xterm(ws, "status", "opened") if err != nil { c.log.Write(pty.Id, LOG_ERROR, "[%s] Failed to write opened event to websocket - %s", req.RemoteAddr, err.Error()) ws.Close() // dirty way to flag out the error @@ -245,7 +179,7 @@ ws_recv_loop: } if tty != nil { - err = pty.send_ws_data(ws, "status", "closed") + err = send_ws_data_for_xterm(ws, "status", "closed") if err != nil { goto done } } diff --git a/client.go b/client.go index 16f68d0..71cbfb7 100644 --- a/client.go +++ b/client.go @@ -5,9 +5,12 @@ import "context" import "crypto/tls" import "errors" import "fmt" +import "io" import "log" import "net" import "net/http" +import "os" +import "os/exec" import "slices" import "strconv" import "strings" @@ -17,12 +20,16 @@ import "time" import "unsafe" import "golang.org/x/net/websocket" +import "golang.org/x/sys/unix" + import "google.golang.org/grpc" import "google.golang.org/grpc/codes" import "google.golang.org/grpc/credentials" import "google.golang.org/grpc/credentials/insecure" import "google.golang.org/grpc/peer" import "google.golang.org/grpc/status" + +import pts "github.com/creack/pty" import "github.com/prometheus/client_golang/prometheus" import "github.com/prometheus/client_golang/prometheus/promhttp" @@ -32,6 +39,7 @@ type ClientConnMap map[ConnId]*ClientConn type ClientRouteMap map[RouteId]*ClientRoute type ClientPeerConnMap map[PeerId]*ClientPeerConn type ClientPeerCancelFuncMap map[PeerId]context.CancelFunc +type ClientRptyMap map[uint64]*ClientRpty // -------------------------------------------------------------------- type ClientRouteConfig struct { @@ -162,6 +170,12 @@ const ( CLIENT_CONN_DISCONNECTED ) +type ClientRpty struct { + id uint64 + cmd *exec.Cmd + tty *os.File +} + // client connection to server type ClientConn struct { C *Client @@ -192,6 +206,9 @@ type ClientConn struct { ptc_mtx sync.Mutex ptc_list *list.List + rpty_mtx sync.Mutex + rpty_map ClientRptyMap + stop_req atomic.Bool stop_chan chan bool @@ -527,10 +544,6 @@ func (r *ClientRoute) ConnectToPeer(pts_id PeerId, route_option RouteOption, pts var tmout time.Duration var ok bool -// TODO: handle TTY -// if route_option & RouteOption(ROUTE_OPTION_TTY) it must create a pseudo-tty insteaad of connecting to tcp address -// - defer wg.Done() tmout = time.Duration(r.cts.C.ptc_tmout) @@ -811,6 +824,7 @@ func NewClientConn(c *Client, cfg *ClientConnConfig) *ClientConn { cts.stop_req.Store(false) cts.stop_chan = make(chan bool, 8) cts.ptc_list = list.New() + cts.rpty_map = make(ClientRptyMap) for i, _ = range cts.cfg.Routes { // override it to static regardless of the value passed in @@ -1017,6 +1031,7 @@ func (cts *ClientConn) add_client_routes(routes []ClientRouteConfig) error { func (cts *ClientConn) disconnect_from_server(logmsg bool) { if cts.conn != nil { var r *ClientRoute + var crp *ClientRpty cts.discon_mtx.Lock() @@ -1028,6 +1043,17 @@ func (cts *ClientConn) disconnect_from_server(logmsg bool) { for _, r = range cts.route_map { r.ReqStop() } cts.route_mtx.Unlock() + // arrange to clean up all rpty objects + cts.rpty_mtx.Lock() + for _, crp = range cts.rpty_map { + crp.tty.Close() + crp.cmd.Process.Kill() + // the loop in ReadRptyLoop() is supposed to be broken. + // let's not inform the server of this connection. + // the server should clean up itself upon connection error + } + cts.rpty_mtx.Unlock() + // don't care about double closes when this function is called from both RunTask() and ReqStop() cts.conn.Close() @@ -1303,7 +1329,7 @@ start_over: fallthrough case PACKET_KIND_RPTY_DATA: fallthrough - case PACKET_KIND_RPTY_EOF: + case PACKET_KIND_RPTY_SIZE: var x *Packet_RptyEvt var ok bool x, ok = pkt.U.(*Packet_RptyEvt) @@ -1311,12 +1337,12 @@ start_over: err = cts.HandleRptyEvent(pkt.Kind, x.RptyEvt) if err != nil { cts.C.log.Write(cts.Sid, LOG_ERROR, - "Failed to handle %s event from %s - %s", - pkt.Kind.String(), cts.remote_addr_p, err.Error()) + "Failed to handle %s event for rpty(%d) from %s - %s", + pkt.Kind.String(), x.RptyEvt.Id, cts.remote_addr_p, err.Error()) } else { cts.C.log.Write(cts.Sid, LOG_DEBUG, - "Handled %s event from %s", - pkt.Kind.String(), cts.remote_addr_p) + "Handled %s event for rpty(%d) from %s", + pkt.Kind.String(), x.RptyEvt.Id, cts.remote_addr_p) } } else { cts.C.log.Write(cts.Sid, LOG_ERROR, "Invalid %s event from %s", pkt.Kind.String(), cts.remote_addr_p) @@ -1394,8 +1420,172 @@ func (cts *ClientConn) ReportPacket(route_id RouteId, pts_id PeerId, packet_type return r.ReportPacket(pts_id, packet_type, event_data) } + +func (cts *ClientConn) ReadRptyLoop(crp *ClientRpty, wg *sync.WaitGroup) { + + var poll_fds []unix.PollFd; + var buf []byte + var n int + var err error + + defer wg.Done() + + poll_fds = []unix.PollFd{ + unix.PollFd{Fd: int32(crp.tty.Fd()), Events: unix.POLLIN}, + } + + buf = make([]byte, 2048) + for { + n, err = unix.Poll(poll_fds, -1) // -1 means wait indefinitely + if err != nil { + if errors.Is(err, unix.EINTR) { continue } + cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to poll rpty(%d) stdout - %s", crp.id, err.Error()) + break + } + if n == 0 { // timed out + continue + } + + if (poll_fds[0].Revents & (unix.POLLERR | unix.POLLHUP | unix.POLLNVAL)) != 0 { + cts.C.log.Write(cts.Sid, LOG_DEBUG, "EOF detected on rpty(%d) stdout", crp.id) + break + } + + if (poll_fds[0].Revents & unix.POLLIN) != 0 { + n, err = crp.tty.Read(buf) + if err != nil { + if !errors.Is(err, io.EOF) { + cts.C.log.Write(cts.Sid, LOG_DEBUG, "Failed to read rpty(%d) stdout - %s", crp.id, err.Error()) + } + + break + } + if n > 0 { + err = cts.psc.Send(MakeRptyDataPacket(crp.id, buf[:n])) + if err != nil { + cts.C.log.Write(cts.Sid, LOG_DEBUG, "Failed to send rpty(%d) stdout to server - %s", crp.id, err.Error()) + break + } + } + } + } + + cts.psc.Send(MakeRptyStopPacket(crp.id, "")) + cts.C.log.Write(cts.Sid, LOG_INFO, "Ending rpty(%d) read loop", crp.id) + + crp.tty.Close() // don't care about multiple closes + crp.cmd.Process.Kill() + crp.cmd.Wait() + + cts.rpty_mtx.Lock() + delete(cts.rpty_map, crp.id) + cts.rpty_mtx.Unlock() +} + +func (cts *ClientConn) StartRpty(id uint64, wg *sync.WaitGroup) error { + var crp *ClientRpty + var ok bool + var err error + + cts.rpty_mtx.Lock() + _, ok = cts.rpty_map[id] + if ok { + cts.rpty_mtx.Unlock() + return fmt.Errorf("multiple start on rpty id %d", id) + } + + crp = &ClientRpty{ id: id } + crp.cmd, crp.tty, err = connect_pty(cts.C.pty_shell, cts.C.pty_user) + if err != nil { + cts.rpty_mtx.Unlock() + cts.psc.Send(MakeRptyStopPacket(id, err.Error())) + return fmt.Errorf("unable to start rpty(%d) for %s(%s) - %s", id, cts.C.pty_shell, cts.C.pty_user, err.Error()) + } + + cts.rpty_map[id] = crp + + wg.Add(1) + go cts.ReadRptyLoop(crp, wg) + + cts.rpty_mtx.Unlock() + + cts.C.log.Write(cts.Sid, LOG_INFO, "Started rpty(%d) for %s(%s)", id, cts.C.pty_shell, cts.C.pty_user) + return nil +} + +func (cts *ClientConn) StopRpty(id uint64) error { + var crp *ClientRpty + var ok bool + + cts.rpty_mtx.Lock() + crp, ok = cts.rpty_map[id] + if !ok { + cts.rpty_mtx.Unlock() + return fmt.Errorf("unknown rpty id %d", id) + } + + crp.tty.Close() // to break ReadRptyLoop() + crp.cmd.Process.Kill() // to process wait to be done by ReadRptyLoop() + cts.rpty_mtx.Unlock() + return nil +} + +func (cts *ClientConn) WriteRpty(id uint64, data []byte) error { + var crp *ClientRpty + var ok bool + + cts.rpty_mtx.Lock() + crp, ok = cts.rpty_map[id] + if !ok { + cts.rpty_mtx.Unlock() + return fmt.Errorf("unknown rpty id %d", id) + } + + crp.tty.Write(data) + cts.rpty_mtx.Unlock() + return nil +} + +func (cts *ClientConn) WriteRptySize(id uint64, data []byte) error { + var crp *ClientRpty + var ok bool + var flds []string + + cts.rpty_mtx.Lock() + crp, ok = cts.rpty_map[id] + if !ok { + cts.rpty_mtx.Unlock() + return fmt.Errorf("unknown rpty id %d", id) + } + + flds = strings.Split(string(data), " ") + if len(flds) == 2 { + var rows int + var cols int + rows, _ = strconv.Atoi(flds[0]) + cols, _ = strconv.Atoi(flds[1]) + pts.Setsize(crp.tty, &pts.Winsize{Rows: uint16(rows), Cols: uint16(cols)}) + } + cts.rpty_mtx.Unlock() + return nil +} + func (cts *ClientConn) HandleRptyEvent(packet_type PACKET_KIND, evt *RptyEvent) error { -// TODO: + switch packet_type { + case PACKET_KIND_RPTY_START: + return cts.StartRpty(evt.Id, &cts.C.wg) + + case PACKET_KIND_RPTY_STOP: + return cts.StopRpty(evt.Id) + + case PACKET_KIND_RPTY_DATA: + return cts.WriteRpty(evt.Id, evt.Data) + + case PACKET_KIND_RPTY_SIZE: + return cts.WriteRptySize(evt.Id, evt.Data) + } + + // ignore other packet types return nil } diff --git a/hodu.go b/hodu.go index 5fdd4f7..eed0605 100644 --- a/hodu.go +++ b/hodu.go @@ -184,8 +184,6 @@ func word_to_route_option(word string) RouteOption { return RouteOption(ROUTE_OPTION_TCP6) case "tcp": return RouteOption(ROUTE_OPTION_TCP) - case "tty": - return RouteOption(ROUTE_OPTION_TTY) case "http": return RouteOption(ROUTE_OPTION_HTTP) case "https": @@ -217,7 +215,6 @@ func (option RouteOption) String() string { if option & RouteOption(ROUTE_OPTION_TCP6) != 0 { str += " tcp6" } if option & RouteOption(ROUTE_OPTION_TCP4) != 0 { str += " tcp4" } if option & RouteOption(ROUTE_OPTION_TCP) != 0 { str += " tcp" } - if option & RouteOption(ROUTE_OPTION_TTY) != 0 { str += " tty" } if option & RouteOption(ROUTE_OPTION_HTTP) != 0 { str += " http" } if option & RouteOption(ROUTE_OPTION_HTTPS) != 0 { str += " https" } if option & RouteOption(ROUTE_OPTION_SSH) != 0 { str += " ssh" } diff --git a/hodu.pb.go b/hodu.pb.go index c89f329..7252802 100644 --- a/hodu.pb.go +++ b/hodu.pb.go @@ -28,7 +28,7 @@ const ( ROUTE_OPTION_TCP ROUTE_OPTION = 1 ROUTE_OPTION_TCP4 ROUTE_OPTION = 2 ROUTE_OPTION_TCP6 ROUTE_OPTION = 4 - ROUTE_OPTION_TTY ROUTE_OPTION = 8 + ROUTE_OPTION_UNUSED ROUTE_OPTION = 8 ROUTE_OPTION_HTTP ROUTE_OPTION = 16 ROUTE_OPTION_HTTPS ROUTE_OPTION = 32 ROUTE_OPTION_SSH ROUTE_OPTION = 64 @@ -41,7 +41,7 @@ var ( 1: "TCP", 2: "TCP4", 4: "TCP6", - 8: "TTY", + 8: "UNUSED", 16: "HTTP", 32: "HTTPS", 64: "SSH", @@ -51,7 +51,7 @@ var ( "TCP": 1, "TCP4": 2, "TCP6": 4, - "TTY": 8, + "UNUSED": 8, "HTTP": 16, "HTTPS": 32, "SSH": 64, @@ -103,11 +103,8 @@ const ( PACKET_KIND_CONN_NOTICE PACKET_KIND = 13 PACKET_KIND_RPTY_START PACKET_KIND = 14 PACKET_KIND_RPTY_STOP PACKET_KIND = 15 - PACKET_KIND_RPTY_STARTED PACKET_KIND = 16 - PACKET_KIND_RPTY_STOPPED PACKET_KIND = 17 - PACKET_KIND_RPTY_ABORTED PACKET_KIND = 18 - PACKET_KIND_RPTY_EOF PACKET_KIND = 19 - PACKET_KIND_RPTY_DATA PACKET_KIND = 20 + PACKET_KIND_RPTY_DATA PACKET_KIND = 16 + PACKET_KIND_RPTY_SIZE PACKET_KIND = 17 ) // Enum value maps for PACKET_KIND. @@ -128,11 +125,8 @@ var ( 13: "CONN_NOTICE", 14: "RPTY_START", 15: "RPTY_STOP", - 16: "RPTY_STARTED", - 17: "RPTY_STOPPED", - 18: "RPTY_ABORTED", - 19: "RPTY_EOF", - 20: "RPTY_DATA", + 16: "RPTY_DATA", + 17: "RPTY_SIZE", } PACKET_KIND_value = map[string]int32{ "RESERVED": 0, @@ -150,11 +144,8 @@ var ( "CONN_NOTICE": 13, "RPTY_START": 14, "RPTY_STOP": 15, - "RPTY_STARTED": 16, - "RPTY_STOPPED": 17, - "RPTY_ABORTED": 18, - "RPTY_EOF": 19, - "RPTY_DATA": 20, + "RPTY_DATA": 16, + "RPTY_SIZE": 17, } ) @@ -875,17 +866,18 @@ const file_hodu_proto_rawDesc = "" + "\bConnNoti\x18\a \x01(\v2\v.ConnNoticeH\x00R\bConnNoti\x12&\n" + "\aRptyEvt\x18\b \x01(\v2\n" + ".RptyEventH\x00R\aRptyEvtB\x03\n" + - "\x01U*^\n" + + "\x01U*a\n" + "\fROUTE_OPTION\x12\n" + "\n" + "\x06UNSPEC\x10\x00\x12\a\n" + "\x03TCP\x10\x01\x12\b\n" + "\x04TCP4\x10\x02\x12\b\n" + - "\x04TCP6\x10\x04\x12\a\n" + - "\x03TTY\x10\b\x12\b\n" + + "\x04TCP6\x10\x04\x12\n" + + "\n" + + "\x06UNUSED\x10\b\x12\b\n" + "\x04HTTP\x10\x10\x12\t\n" + "\x05HTTPS\x10 \x12\a\n" + - "\x03SSH\x10@*\xd7\x02\n" + + "\x03SSH\x10@*\xa2\x02\n" + "\vPACKET_KIND\x12\f\n" + "\bRESERVED\x10\x00\x12\x0f\n" + "\vROUTE_START\x10\x01\x12\x0e\n" + @@ -904,12 +896,9 @@ const file_hodu_proto_rawDesc = "" + "\vCONN_NOTICE\x10\r\x12\x0e\n" + "\n" + "RPTY_START\x10\x0e\x12\r\n" + - "\tRPTY_STOP\x10\x0f\x12\x10\n" + - "\fRPTY_STARTED\x10\x10\x12\x10\n" + - "\fRPTY_STOPPED\x10\x11\x12\x10\n" + - "\fRPTY_ABORTED\x10\x12\x12\f\n" + - "\bRPTY_EOF\x10\x13\x12\r\n" + - "\tRPTY_DATA\x10\x142I\n" + + "\tRPTY_STOP\x10\x0f\x12\r\n" + + "\tRPTY_DATA\x10\x10\x12\r\n" + + "\tRPTY_SIZE\x10\x112I\n" + "\x04Hodu\x12\x19\n" + "\aGetSeed\x12\x05.Seed\x1a\x05.Seed\"\x00\x12&\n" + "\fPacketStream\x12\a.Packet\x1a\a.Packet\"\x00(\x010\x01B\bZ\x06./hodub\x06proto3" diff --git a/hodu.proto b/hodu.proto index c4d33c1..870a4cf 100644 --- a/hodu.proto +++ b/hodu.proto @@ -23,7 +23,7 @@ enum ROUTE_OPTION { TCP = 1; TCP4 = 2; TCP6 = 4; - TTY = 8; + UNUSED = 8; HTTP = 16; HTTPS = 32; SSH = 64; @@ -103,11 +103,8 @@ enum PACKET_KIND { RPTY_START = 14; RPTY_STOP = 15; - RPTY_STARTED = 16; - RPTY_STOPPED = 17; - RPTY_ABORTED = 18; - RPTY_EOF = 19; - RPTY_DATA = 20; + RPTY_DATA = 16; + RPTY_SIZE = 17; }; message Packet { diff --git a/jwt.go b/jwt.go index afb6d97..0ddcfe1 100644 --- a/jwt.go +++ b/jwt.go @@ -23,7 +23,7 @@ func Sign(data []byte, privkey *rsa.PrivateKey) ([]byte, error) { func Verify(data []byte, pubkey *rsa.PublicKey, sig []byte) error { var h hash.Hash - + h = crypto.SHA512.New() h.Write(data) @@ -41,7 +41,7 @@ func SignHS512(data []byte, key string) ([]byte, error) { func VerifyHS512(data []byte, key string, sig []byte) error { var h hash.Hash - + h = crypto.SHA512.New() h.Write(data) @@ -78,7 +78,7 @@ func (j *JWT[T]) SignRS512() (string, error) { h.Algo = "RS512" h.Type = "JWT" - hb, err = json.Marshal(h) + hb, err = json.Marshal(h) if err != nil { return "", err } cb, err = json.Marshal(j.claims) diff --git a/jwt_test.go b/jwt_test.go index b86303e..218db52 100644 --- a/jwt_test.go +++ b/jwt_test.go @@ -8,16 +8,16 @@ import "testing" func TestJwt(t *testing.T) { var tok string var err error - + type JWTClaim struct { Abc string `json:"abc"` Donkey string `json:"donkey"` IssuedAt int `json:"iat"` - } - + } + var jc JWTClaim jc.Abc = "def" - jc.Donkey = "kong" + jc.Donkey = "kong" jc.IssuedAt = 111 var key *rsa.PrivateKey diff --git a/packet.go b/packet.go index c658091..85c41a3 100644 --- a/packet.go +++ b/packet.go @@ -79,10 +79,14 @@ func MakeRptyStartPacket(id uint64) *Packet { return &Packet{Kind: PACKET_KIND_RPTY_START, U: &Packet_RptyEvt{RptyEvt: &RptyEvent{Id: id}}} } -func MakeRptyStopPacket(id uint64) *Packet { - return &Packet{Kind: PACKET_KIND_RPTY_START, U: &Packet_RptyEvt{RptyEvt: &RptyEvent{Id: id}}} +func MakeRptyStopPacket(id uint64, msg string) *Packet { + return &Packet{Kind: PACKET_KIND_RPTY_STOP, U: &Packet_RptyEvt{RptyEvt: &RptyEvent{Id: id, Data: []byte(msg)}}} } func MakeRptyDataPacket(id uint64, data []byte) *Packet { - return &Packet{Kind: PACKET_KIND_RPTY_START, U: &Packet_RptyEvt{RptyEvt: &RptyEvent{Id: id, Data: data}}} + return &Packet{Kind: PACKET_KIND_RPTY_DATA, U: &Packet_RptyEvt{RptyEvt: &RptyEvent{Id: id, Data: data}}} +} + +func MakeRptySizePacket(id uint64, data []byte) *Packet { + return &Packet{Kind: PACKET_KIND_RPTY_SIZE, U: &Packet_RptyEvt{RptyEvt: &RptyEvent{Id: id, Data: data}}} } diff --git a/pty.go b/pty.go new file mode 100644 index 0000000..c713069 --- /dev/null +++ b/pty.go @@ -0,0 +1,71 @@ +package hodu + +import "encoding/json" +import "fmt" +import "os" +import "os/exec" +import "os/user" +import "strconv" +import "syscall" + +import pts "github.com/creack/pty" +import "golang.org/x/net/websocket" +import "golang.org/x/sys/unix" + +func connect_pty(pty_shell string, pty_user string) (*exec.Cmd, *os.File, error) { + var cmd *exec.Cmd + var tty *os.File + var err error + + if pty_shell == "" { + return nil, nil, fmt.Errorf("blank pty shell") + } + + cmd = exec.Command(pty_shell); + if pty_user != "" { + var uid int + var gid int + var u *user.User + + u, err = user.Lookup(pty_user) + if err != nil { return nil, nil, err } + + uid, _ = strconv.Atoi(u.Uid) + gid, _ = strconv.Atoi(u.Gid) + cmd.SysProcAttr = &syscall.SysProcAttr{ + Credential: &syscall.Credential{ + Uid: uint32(uid), + Gid: uint32(gid), + }, + Setsid: true, + } + cmd.Dir = u.HomeDir + cmd.Env = append(cmd.Env, + "HOME=" + u.HomeDir, + "LOGNAME=" + u.Username, + "PATH=" + os.Getenv("PATH"), + "SHELL=" + pty_shell, + "TERM=xterm", + "USER=" + u.Username, + ) + } + + tty, err = pts.Start(cmd) + if err != nil { + return nil, nil, err + } + + //syscall.SetNonblock(int(tty.Fd()), true); + unix.SetNonblock(int(tty.Fd()), true); + + return cmd, tty, nil +} + +func send_ws_data_for_xterm(ws *websocket.Conn, type_val string, data string) error { + var msg []byte + var err error + msg, err = json.Marshal(json_xterm_ws_event{Type: type_val, Data: []string{ data } }) + if err == nil { err = websocket.Message.Send(ws, msg) } + return err +} + diff --git a/server-pty.go b/server-pty.go index 32380f6..5efac75 100644 --- a/server-pty.go +++ b/server-pty.go @@ -7,11 +7,9 @@ import "io" import "net/http" import "os" import "os/exec" -import "os/user" import "strconv" import "strings" import "sync" -import "syscall" import "text/template" import pts "github.com/creack/pty" @@ -33,6 +31,7 @@ type server_rpty_ws struct { type server_pty_xterm_file struct { ServerCtl file string + mode string } // ------------------------------------------------------ @@ -41,74 +40,11 @@ func (pty *server_pty_ws) Identity() string { return pty.Id } -func (pty *server_pty_ws) send_ws_data(ws *websocket.Conn, type_val string, data string) error { - var msg []byte - var err error - - msg, err = json.Marshal(json_xterm_ws_event{Type: type_val, Data: []string{ data } }) - if err == nil { err = websocket.Message.Send(ws, msg) } - return err -} - - -func (pty *server_pty_ws) connect_pty(username string, password string) (*exec.Cmd, *os.File, error) { - var s *Server - var cmd *exec.Cmd - var tty *os.File - var err error - - // username and password are not used yet. - s = pty.S - - if s.pty_shell == "" { - return nil, nil, fmt.Errorf("blank pty shell") - } - - cmd = exec.Command(s.pty_shell); - if s.pty_user != "" { - var uid int - var gid int - var u *user.User - - u, err = user.Lookup(s.pty_user) - if err != nil { return nil, nil, err } - - uid, _ = strconv.Atoi(u.Uid) - gid, _ = strconv.Atoi(u.Gid) - cmd.SysProcAttr = &syscall.SysProcAttr{ - Credential: &syscall.Credential{ - Uid: uint32(uid), - Gid: uint32(gid), - }, - Setsid: true, - } - cmd.Dir = u.HomeDir - cmd.Env = append(cmd.Env, - "HOME=" + u.HomeDir, - "LOGNAME=" + u.Username, - "PATH=" + os.Getenv("PATH"), - "SHELL=" + s.pty_shell, - "TERM=xterm", - "USER=" + u.Username, - ) - } - - tty, err = pts.Start(cmd) - if err != nil { - return nil, nil, err - } - - //syscall.SetNonblock(int(tty.Fd()), true); - unix.SetNonblock(int(tty.Fd()), true); - - return cmd, tty, nil -} - func (pty *server_pty_ws) ServeWebsocket(ws *websocket.Conn) (int, error) { var s *Server var req *http.Request - var username string - var password string + //var username string + //var password string var in *os.File var out *os.File var tty *os.File @@ -154,7 +90,7 @@ func (pty *server_pty_ws) ServeWebsocket(ws *websocket.Conn) (int, error) { if (poll_fds[0].Revents & (unix.POLLERR | unix.POLLHUP | unix.POLLNVAL)) != 0 { s.log.Write(pty.Id, LOG_DEBUG, "[%s] EOF detected on pty stdout", req.RemoteAddr) - break; + break } if (poll_fds[0].Revents & unix.POLLIN) != 0 { @@ -166,7 +102,7 @@ func (pty *server_pty_ws) ServeWebsocket(ws *websocket.Conn) (int, error) { break } if n > 0 { - err = pty.send_ws_data(ws, "iov", string(buf[:n])) + err = send_ws_data_for_xterm(ws, "iov", string(buf[:n])) if err != nil { s.log.Write(pty.Id, LOG_ERROR, "[%s] Failed to send to websocket - %s", req.RemoteAddr, err.Error()) break @@ -191,21 +127,22 @@ ws_recv_loop: switch ev.Type { case "open": if tty == nil && len(ev.Data) == 2 { - username = string(ev.Data[0]) - password = string(ev.Data[1]) + // not using username and password for now... + //username = string(ev.Data[0]) + //password = string(ev.Data[1]) wg.Add(1) go func() { var err error defer wg.Done() - cmd, tty, err = pty.connect_pty(username, password) + cmd, tty, err = connect_pty(s.pty_shell, s.pty_user) if err != nil { s.log.Write(pty.Id, LOG_ERROR, "[%s] Failed to connect pty - %s", req.RemoteAddr, err.Error()) - pty.send_ws_data(ws, "error", err.Error()) + send_ws_data_for_xterm(ws, "error", err.Error()) ws.Close() // dirty way to flag out the error - this will make websocket.MessageReceive to fail } else { - err = pty.send_ws_data(ws, "status", "opened") + err = send_ws_data_for_xterm(ws, "status", "opened") if err != nil { s.log.Write(pty.Id, LOG_ERROR, "[%s] Failed to write 'opened' event to websocket - %s", req.RemoteAddr, err.Error()) ws.Close() // dirty way to flag out the error @@ -250,7 +187,7 @@ ws_recv_loop: } if tty != nil { - err = pty.send_ws_data(ws, "status", "closed") + err = send_ws_data_for_xterm(ws, "status", "closed") if err != nil { goto done } } @@ -276,15 +213,6 @@ func (rpty *server_rpty_ws) Identity() string { return rpty.Id } -func (rpty *server_rpty_ws) send_ws_data(ws *websocket.Conn, type_val string, data string) error { - var msg []byte - var err error - - msg, err = json.Marshal(json_xterm_ws_event{Type: type_val, Data: []string{ data } }) - if err == nil { err = websocket.Message.Send(ws, msg) } - return err -} - func (rpty *server_rpty_ws) ServeWebsocket(ws *websocket.Conn) (int, error) { var s *Server var req *http.Request @@ -292,20 +220,14 @@ func (rpty *server_rpty_ws) ServeWebsocket(ws *websocket.Conn) (int, error) { var cts *ServerConn //var username string //var password string - var in *os.File - //var out *os.File - var tty *os.File var rp *ServerRpty - var cmd *exec.Cmd var wg sync.WaitGroup - var conn_ready_chan chan bool var err error s = rpty.S req = ws.Request() - conn_ready_chan = make(chan bool, 3) - token = req.FormValue("token") - if token != "" { + token = req.FormValue("client-token") + if token == "" { ws.Close() return http.StatusBadRequest, fmt.Errorf("no client token specified") } @@ -316,68 +238,6 @@ func (rpty *server_rpty_ws) ServeWebsocket(ws *websocket.Conn) (int, error) { return http.StatusBadRequest, fmt.Errorf("invalid client token - %s", token) } -// TODO: how to get notified of broken connection.... - - - wg.Add(1) - go func() { - var conn_ready bool - - defer wg.Done() - defer ws.Close() // dirty way to break the main loop - - conn_ready = <-conn_ready_chan - if conn_ready { // connected -/* - var poll_fds []unix.PollFd; - var buf []byte - var n int - var err error - - poll_fds = []unix.PollFd{ - unix.PollFd{Fd: int32(out.Fd()), Events: unix.POLLIN}, - } - - s.stats.pty_sessions.Add(1) - buf = make([]byte, 2048) - for { - n, err = unix.Poll(poll_fds, -1) // -1 means wait indefinitely - if err != nil { - if errors.Is(err, unix.EINTR) { continue } - s.log.Write("", LOG_ERROR, "[%s] Failed to poll pty stdout - %s", req.RemoteAddr, err.Error()) - break - } - if n == 0 { // timed out - continue - } - - if (poll_fds[0].Revents & (unix.POLLERR | unix.POLLHUP | unix.POLLNVAL)) != 0 { - s.log.Write(pty.Id, LOG_DEBUG, "[%s] EOF detected on pty stdout", req.RemoteAddr) - break; - } - - if (poll_fds[0].Revents & unix.POLLIN) != 0 { - n, err = out.Read(buf) - if err != nil { - if !errors.Is(err, io.EOF) { - s.log.Write(pty.Id, LOG_ERROR, "[%s] Failed to read pty stdout - %s", req.RemoteAddr, err.Error()) - } - break - } - if n > 0 { - err = pty.send_ws_data(ws, "iov", string(buf[:n])) - if err != nil { - s.log.Write(pty.Id, LOG_ERROR, "[%s] Failed to send to websocket - %s", req.RemoteAddr, err.Error()) - break - } - } - } - } - s.stats.pty_sessions.Add(-1) -*/ - } - }() - ws_recv_loop: for { var msg []byte @@ -394,57 +254,37 @@ ws_recv_loop: //username = string(ev.Data[0]) //password = string(ev.Data[1]) - wg.Add(1) - go func() { - var err error - - defer wg.Done() - - rp, err = cts.StartRpty(ws) - // cmd, tty, err = pty.connect_pty(username, password) + rp, err = cts.StartRpty(ws) + if err != nil { + s.log.Write(rpty.Id, LOG_ERROR, "[%s] Failed to connect pty - %s", req.RemoteAddr, err.Error()) + send_ws_data_for_xterm(ws, "error", err.Error()) + ws.Close() // dirty way to flag out the error by making websocket.Message.Receive() fail + } else { + err = send_ws_data_for_xterm(ws, "status", "opened") if err != nil { - s.log.Write(rpty.Id, LOG_ERROR, "[%s] Failed to connect pty - %s", req.RemoteAddr, err.Error()) - rpty.send_ws_data(ws, "error", err.Error()) + s.log.Write(rpty.Id, LOG_ERROR, "[%s] Failed to write 'opened' event to websocket - %s", req.RemoteAddr, err.Error()) ws.Close() // dirty way to flag out the error } else { - err = rpty.send_ws_data(ws, "status", "opened") - if err != nil { - s.log.Write(rpty.Id, LOG_ERROR, "[%s] Failed to write 'opened' event to websocket - %s", req.RemoteAddr, err.Error()) - ws.Close() // dirty way to flag out the error - } else { - s.log.Write(rpty.Id, LOG_DEBUG, "[%s] Opened pty session", req.RemoteAddr) - // out = tty - // in = tty - conn_ready_chan <- true - } + s.log.Write(rpty.Id, LOG_DEBUG, "[%s] Opened pty session", req.RemoteAddr) } - }() - } - - case "close": - if tty != nil { - // cts.StopRpty() - tty.Close() - tty = nil - } - break ws_recv_loop - - case "iov": - if tty != nil { - var i int - for i, _ = range ev.Data { - in.Write([]byte(ev.Data[i])) } } + case "close": + // just break out of the loop and let the remainder to close resources + break ws_recv_loop + + case "iov": + var i int + for i, _ = range ev.Data { + cts.WriteRpty(ws, []byte(ev.Data[i])) + // ignore error for now + } + case "size": - if tty != nil && len(ev.Data) == 2 { - var rows int - var cols int - rows, _ = strconv.Atoi(ev.Data[0]) - cols, _ = strconv.Atoi(ev.Data[1]) - pts.Setsize(tty, &pts.Winsize{Rows: uint16(rows), Cols: uint16(cols)}) - s.log.Write(rpty.Id, LOG_DEBUG, "[%s] Resized terminal to %d,%d", req.RemoteAddr, rows, cols) + if len(ev.Data) == 2 { + cts.WriteRptySize(ws, []byte(fmt.Sprintf("%s %s", ev.Data[0], ev.Data[1]))) + s.log.Write(rpty.Id, LOG_DEBUG, "[%s] Requested to resize rpty terminal to %s,%s", req.RemoteAddr, ev.Data[0], ev.Data[1]) // ignore error } } @@ -452,24 +292,10 @@ ws_recv_loop: } } - if tty != nil { - //err = pty.send_ws_data(ws, "status", "closed") - /* - err = s.SendRpty() - if err != nil { goto done } - */ - } - done: - conn_ready_chan <- false - ws.Close() - if cmd != nil { - // kill the child process underneath to close ptym(the master pty). - //cmd.Process.Signal(syscall.SIGTERM) - cmd.Process.Kill() - } - if tty != nil { tty.Close() } - if cmd != nil { cmd.Wait() } + cts.StopRpty(ws) + ws.Close() // don't care about multiple closes + wg.Wait() s.log.Write(rpty.Id, LOG_DEBUG, "[%s] Ended rpty session for %s", req.RemoteAddr, token) @@ -511,7 +337,7 @@ func (pty *server_pty_xterm_file) ServeHTTP(w http.ResponseWriter, req *http.Req status_code = WriteHtmlRespHeader(w, http.StatusOK) tmpl.Execute(w, &xterm_session_info{ - Mode: "pty", + Mode: pty.mode, ConnId: "-1", RouteId: "-1", }) diff --git a/server-pxy.go b/server-pxy.go index 34fca19..ce86806 100644 --- a/server-pxy.go +++ b/server-pxy.go @@ -562,15 +562,6 @@ func (pxy *server_pxy_ssh_ws) Identity() string { // TODO: put this task to sync group. // TODO: put the above proxy task to sync group too. -func (pxy *server_pxy_ssh_ws) send_ws_data(ws *websocket.Conn, type_val string, data string) error { - var msg []byte - var err error - - msg, err = json.Marshal(json_xterm_ws_event{Type: type_val, Data: []string{ data } }) - if err == nil { err = websocket.Message.Send(ws, msg) } - return err -} - func (pxy *server_pxy_ssh_ws) connect_ssh (ctx context.Context, username string, password string, r *ServerRoute) (*ssh.Client, *ssh.Session, io.Writer, io.Reader, error) { var cc *ssh.ClientConfig var addr *net.TCPAddr @@ -673,7 +664,7 @@ func (pxy *server_pxy_ssh_ws) ServeWebsocket(ws *websocket.Conn) (int, error) { var pi *ServerRouteProxyInfo pi, err = s.wpx_foreign_port_proxy_maker("ssh", conn_id) if err != nil { - pxy.send_ws_data(ws, "error", err.Error()) + send_ws_data_for_xterm(ws, "error", err.Error()) goto done } @@ -685,7 +676,7 @@ func (pxy *server_pxy_ssh_ws) ServeWebsocket(ws *websocket.Conn) (int, error) { r = proxy_info_to_server_route(pi) } if err != nil { - pxy.send_ws_data(ws, "error", err.Error()) + send_ws_data_for_xterm(ws, "error", err.Error()) goto done } @@ -713,7 +704,7 @@ func (pxy *server_pxy_ssh_ws) ServeWebsocket(ws *websocket.Conn) (int, error) { break } if n > 0 { - err = pxy.send_ws_data(ws, "iov", string(buf[:n])) + err = send_ws_data_for_xterm(ws, "iov", string(buf[:n])) if err != nil { s.log.Write(pxy.Id, LOG_ERROR, "[%s] Failed to send to websocket - %s", req.RemoteAddr, err.Error()) break @@ -753,10 +744,10 @@ ws_recv_loop: c, sess, in, out, err = pxy.connect_ssh(connect_ssh_ctx, username, password, r) if err != nil { s.log.Write(pxy.Id, LOG_ERROR, "[%s] Failed to connect ssh - %s", req.RemoteAddr, err.Error()) - pxy.send_ws_data(ws, "error", err.Error()) + send_ws_data_for_xterm(ws, "error", err.Error()) ws.Close() // dirty way to flag out the error } else { - err = pxy.send_ws_data(ws, "status", "opened") + err = send_ws_data_for_xterm(ws, "status", "opened") if err != nil { s.log.Write(pxy.Id, LOG_ERROR, "[%s] Failed to write opened event to websocket - %s", req.RemoteAddr, err.Error()) ws.Close() // dirty way to flag out the error @@ -800,7 +791,7 @@ ws_recv_loop: } if sess != nil { - err = pxy.send_ws_data(ws, "status", "closed") + err = send_ws_data_for_xterm(ws, "status", "closed") if err != nil { goto done } } diff --git a/server.go b/server.go index 4dbecad..e200ea3 100644 --- a/server.go +++ b/server.go @@ -42,6 +42,9 @@ type ServerRouteMap map[RouteId]*ServerRoute type ServerPeerConnMap map[PeerId]*ServerPeerConn type ServerSvcPortMap map[PortId]ConnRouteId +type ServerRptyMap map[uint64]*ServerRpty +type ServerRptyMapByWs map[*websocket.Conn]*ServerRpty + type ServerWpxResponseTransformer func(r *ServerRouteProxyInfo, resp *http.Response) io.Reader type ServerWpxForeignPortProxyMaker func(wpx_type string, port_id string) (*ServerRouteProxyInfo, error) @@ -185,14 +188,14 @@ type ServerConn struct { pts_mtx sync.Mutex pts_list *list.List - rpty_next_id uint64 + rpty_next_id uint64 rpty_mtx sync.Mutex - rpty_map map[uint64]*ServerRpty - rpty_map_by_ws map[*websocket.Conn]*ServerRpty + rpty_map ServerRptyMap + rpty_map_by_ws ServerRptyMapByWs - wg sync.WaitGroup - stop_req atomic.Bool - stop_chan chan bool + wg sync.WaitGroup + stop_req atomic.Bool + stop_chan chan bool } type ServerRoute struct { @@ -462,11 +465,6 @@ func (r *ServerRoute) ReportPacket(pts_id PeerId, packet_type PACKET_KIND, event // ------------------------------------ -func (rpty *ServerRpty) Stop() { -} - -// ------------------------------------ - func (cts *ServerConn) make_route_listener(id RouteId, option RouteOption, svc_requested_addr string) (*net.TCPListener, *net.TCPAddr, error) { var l *net.TCPListener var svcaddr *net.TCPAddr @@ -651,7 +649,6 @@ func (cts *ServerConn) ReqStopAllServerRoutes() { cts.route_mtx.Unlock() } - func (cts *ServerConn) StartRpty(ws *websocket.Conn) (*ServerRpty, error) { var ok bool var start_id uint64 @@ -701,10 +698,123 @@ func (cts *ServerConn) StartRpty(ws *websocket.Conn) (*ServerRpty, error) { return nil , err } -// TODO: send request... return rpty, nil } +func (cts *ServerConn) StopRpty(ws *websocket.Conn) error { + + // called by the websocket handler. + var rpty *ServerRpty + var id uint64 + var ok bool + var err error + + cts.rpty_mtx.Lock() + rpty, ok = cts.rpty_map_by_ws[ws] + if !ok { + return fmt.Errorf("unknown ws connection for rpty - %v", ws.RemoteAddr()) + } + + id = rpty.id + cts.rpty_mtx.Unlock() + + // send the stop request to the client side + err = cts.pss.Send(MakeRptyStopPacket(id, "")) + if err != nil { + return fmt.Errorf("unable to send stop rpty request to client - %s", err.Error()) + } + + return nil +} + +func (cts *ServerConn) StopRptyWsById(id uint64, msg string) error { + // called this when the stop requested comes from the client + // abort the websocket side. + + var rpty *ServerRpty + var ok bool + + cts.rpty_mtx.Lock() + rpty, ok = cts.rpty_map[id] + if !ok { + return fmt.Errorf("unknown rpty id %d", id) + } + rpty.ws.Close() + cts.rpty_mtx.Unlock() + + cts.S.log.Write(cts.Sid, LOG_INFO, "Stopped rpty(%d) for %s - %s", id, cts.RemoteAddr, msg) + return nil +} + +func (cts *ServerConn) WriteRpty(ws *websocket.Conn, data []byte) error { + var rpty *ServerRpty + var id uint64 + var ok bool + var err error + + cts.rpty_mtx.Lock() + rpty, ok = cts.rpty_map_by_ws[ws] + if !ok { + return fmt.Errorf("unknown ws connection for rpty - %v", ws.RemoteAddr()) + } + + id = rpty.id + cts.rpty_mtx.Unlock() + + err = cts.pss.Send(MakeRptyDataPacket(id, data)) + if err != nil { + return fmt.Errorf("unable to send rpty data to client - %s", err.Error()) + } + + return nil +} + +func (cts *ServerConn) WriteRptySize(ws *websocket.Conn, data []byte) error { + var rpty *ServerRpty + var id uint64 + var ok bool + var err error + + cts.rpty_mtx.Lock() + rpty, ok = cts.rpty_map_by_ws[ws] + if !ok { + return fmt.Errorf("unknown ws connection for rpty size - %v", ws.RemoteAddr()) + } + + id = rpty.id + cts.rpty_mtx.Unlock() + + err = cts.pss.Send(MakeRptySizePacket(id, data)) + if err != nil { + return fmt.Errorf("unable to send rpty size to client - %s", err.Error()) + } + + return nil +} + +func (cts *ServerConn) ReadRptyAndWriteWs(id uint64, data []byte) error { + var ok bool + var rpty *ServerRpty + var err error + + cts.rpty_mtx.Lock() + rpty, ok = cts.rpty_map[id] + if !ok { + cts.rpty_mtx.Unlock() + return fmt.Errorf("unknown rpty id - %d", id) + } + + err = send_ws_data_for_xterm(rpty.ws, "iov", string(data)) + if err != nil { + cts.rpty_mtx.Unlock() + return fmt.Errorf("failed to write rpty data(%d) to ws - %s", id, err.Error()) + } + + cts.rpty_mtx.Unlock() + return nil +} + + func (cts *ServerConn) ReportPacket(route_id RouteId, pts_id PeerId, packet_type PACKET_KIND, event_data interface{}) error { var r *ServerRoute var ok bool @@ -720,6 +830,20 @@ func (cts *ServerConn) ReportPacket(route_id RouteId, pts_id PeerId, packet_type return r.ReportPacket(pts_id, packet_type, event_data) } +func (cts *ServerConn) HandleRptyEvent(packet_type PACKET_KIND, evt *RptyEvent) error { + switch packet_type { + case PACKET_KIND_RPTY_STOP: + // stop requested from the server + return cts.StopRptyWsById(evt.Id, string(evt.Data)) + + case PACKET_KIND_RPTY_DATA: + return cts.ReadRptyAndWriteWs(evt.Id, evt.Data) + } + + // ignore other packet types + return nil +} + func (cts *ServerConn) receive_from_stream(wg *sync.WaitGroup) { var pkt *Packet var err error @@ -754,12 +878,12 @@ func (cts *ServerConn) receive_from_stream(wg *sync.WaitGroup) { err = cts.pss.Send(MakeRouteStoppedPacket(RouteId(x.Route.RouteId), RouteOption(x.Route.ServiceOption), x.Route.TargetAddrStr, x.Route.TargetName, x.Route.ServiceAddrStr, x.Route.ServiceNetStr)) if err != nil { cts.S.log.Write(cts.Sid, LOG_ERROR, - "Failed to send route_stopped event(%d,%s,%v,%s) to client %s - %s", + "Failed to send ROUTE_STOPPED event(%d,%s,%v,%s) to client %s - %s", x.Route.RouteId, x.Route.TargetAddrStr, x.Route.ServiceOption, x.Route.ServiceNetStr, cts.RemoteAddr, err.Error()) goto done } else { cts.S.log.Write(cts.Sid, LOG_DEBUG, - "Sent route_stopped event(%d,%s,%v,%s) to client %s", + "Sent ROUTE_STOPPED event(%d,%s,%v,%s) to client %s", x.Route.RouteId, x.Route.TargetAddrStr, x.Route.ServiceOption, x.Route.ServiceNetStr, cts.RemoteAddr) } @@ -771,7 +895,7 @@ func (cts *ServerConn) receive_from_stream(wg *sync.WaitGroup) { if err != nil { r.ReqStop() cts.S.log.Write(cts.Sid, LOG_ERROR, - "Failed to send route_started event(%d,%s,%s,%s%v,%v) to client %s - %s", + "Failed to send ROUTE_STARTED event(%d,%s,%s,%s%v,%v) to client %s - %s", r.Id, r.PtcAddr, r.SvcAddr.String(), r.SvcOption, r.SvcPermNet, cts.RemoteAddr, err.Error()) goto done } @@ -801,7 +925,7 @@ func (cts *ServerConn) receive_from_stream(wg *sync.WaitGroup) { if err != nil { r.ReqStop() cts.S.log.Write(cts.Sid, LOG_ERROR, - "Failed to send route_stopped event(%d,%s,%s,%v.%v) to client %s - %s", + "Failed to send ROUTE_STOPPED event(%d,%s,%s,%v.%v) to client %s - %s", r.Id, r.PtcAddr, r.SvcAddr.String(), r.SvcOption, r.SvcPermNet.String(), cts.RemoteAddr, err.Error()) goto done } @@ -862,13 +986,13 @@ func (cts *ServerConn) receive_from_stream(wg *sync.WaitGroup) { x, ok = pkt.U.(*Packet_Conn) if ok { if x.Conn.Token == "" { - cts.S.log.Write(cts.Sid, LOG_ERROR, "Invalid conn_desc packet from %s - blank token", cts.RemoteAddr) + cts.S.log.Write(cts.Sid, LOG_ERROR, "Invalid %s packet from %s - blank token", pkt.Kind.String(), cts.RemoteAddr) cts.pss.Send(MakeConnErrorPacket(1, "blank token refused")) cts.ReqStop() // TODO: is this desirable to disconnect? } else if x.Conn.Token != cts.ClientToken.Get() { _, err = strconv.ParseUint(x.Conn.Token, 10, int(unsafe.Sizeof(ConnId(0)) * 8)) if err == nil { // this is not != nil. this is to check if the token is numeric - cts.S.log.Write(cts.Sid, LOG_ERROR, "Invalid conn_desc packet from %s - numeric token '%s'", cts.RemoteAddr, x.Conn.Token) + cts.S.log.Write(cts.Sid, LOG_ERROR, "Invalid %s packet from %s - numeric token '%s'", pkt.Kind.String(), cts.RemoteAddr, x.Conn.Token) cts.pss.Send(MakeConnErrorPacket(1, "numeric token refused")) cts.ReqStop() // TODO: is this desirable to disconnect? } else { @@ -877,7 +1001,7 @@ func (cts *ServerConn) receive_from_stream(wg *sync.WaitGroup) { if ok { // error cts.S.cts_mtx.Unlock() - cts.S.log.Write(cts.Sid, LOG_ERROR, "Invalid conn_desc packet from %s - duplicate token '%s'", cts.RemoteAddr, x.Conn.Token) + cts.S.log.Write(cts.Sid, LOG_ERROR, "Invalid %s packet from %s - duplicate token '%s'", pkt.Kind.String(), cts.RemoteAddr, x.Conn.Token) cts.pss.Send(MakeConnErrorPacket(1, fmt.Sprintf("duplicate token refused - %s", x.Conn.Token))) cts.ReqStop() // TODO: is this desirable to disconnect? } else { @@ -892,7 +1016,7 @@ func (cts *ServerConn) receive_from_stream(wg *sync.WaitGroup) { } } } else { - cts.S.log.Write(cts.Sid, LOG_ERROR, "Invalid conn_desc packet from %s", cts.RemoteAddr) + cts.S.log.Write(cts.Sid, LOG_ERROR, "Invalid %s packet from %s", pkt.Kind.String(), cts.RemoteAddr) } case PACKET_KIND_CONN_NOTICE: @@ -908,27 +1032,42 @@ func (cts *ServerConn) receive_from_stream(wg *sync.WaitGroup) { } } } else { - cts.S.log.Write(cts.Sid, LOG_ERROR, "Invalid conn_notice packet from %s", cts.RemoteAddr) + cts.S.log.Write(cts.Sid, LOG_ERROR, "Invalid %s packet from %s", pkt.Kind.String(), cts.RemoteAddr) } - /*case PACKET_KIND_RPTY_START: - case PACKET_KIND_RPTY_STOP:*/ - case PACKET_KIND_RPTY_STARTED: - fallthrough - case PACKET_KIND_RPTY_STOPPED: - fallthrough - case PACKET_KIND_RPTY_ABORTED: - fallthrough - case PACKET_KIND_RPTY_EOF: + //case PACKET_KIND_RPTY_START: stop is never sent by the client to the server + case PACKET_KIND_RPTY_STOP: fallthrough case PACKET_KIND_RPTY_DATA: - // inspect the token - // find the right websocket handler... - // report it to the right websocket handler + var x *Packet_RptyEvt + var ok bool + x, ok = pkt.U.(*Packet_RptyEvt) + if ok { + err = cts.HandleRptyEvent(pkt.Kind, x.RptyEvt) + if err != nil { + cts.S.log.Write(cts.Sid, LOG_ERROR, "Failed to handle %s event for rpty(%d) from %s - %s", pkt.Kind.String(), x.RptyEvt.Id, cts.RemoteAddr, err.Error()) + } else { + cts.S.log.Write(cts.Sid, LOG_ERROR, "Handled %s event for rpty(%d) from %s", pkt.Kind.String(), x.RptyEvt.Id, cts.RemoteAddr) + } + + } else { + cts.S.log.Write(cts.Sid, LOG_ERROR, "Invalid %s packet from %s", pkt.Kind.String(), cts.RemoteAddr) + } } } done: + + // arrange to break all rpty resources + cts.rpty_mtx.Lock() + if len(cts.rpty_map) > 0 { + var rpty *ServerRpty + for _, rpty = range cts.rpty_map { + rpty.ws.Close() + } + } + cts.rpty_mtx.Unlock() + cts.S.log.Write(cts.Sid, LOG_INFO, "RPC stream receiver ended") } @@ -1443,16 +1582,32 @@ func NewServer(ctx context.Context, name string, logger Logger, cfg *ServerConfi s.ctl_mux.Handle("/_pty/xterm.css/", s.WrapHttpHandler(&server_pty_xterm_file{ServerCtl: ServerCtl{S: &s, Id: HS_ID_CTL}, file: "_forbidden"})) s.ctl_mux.Handle("/_pty/xterm.html", - s.WrapHttpHandler(&server_pty_xterm_file{ServerCtl: ServerCtl{S: &s, Id: HS_ID_CTL}, file: "xterm.html"})) + s.WrapHttpHandler(&server_pty_xterm_file{ServerCtl: ServerCtl{S: &s, Id: HS_ID_CTL}, file: "xterm.html", mode: "pty"})) s.ctl_mux.Handle("/_pty/xterm.html/", s.WrapHttpHandler(&server_pty_xterm_file{ServerCtl: ServerCtl{S: &s, Id: HS_ID_CTL}, file: "_forbidden"})) s.ctl_mux.Handle("/_pty/", s.WrapHttpHandler(&server_pty_xterm_file{ServerCtl: ServerCtl{S: &s, Id: HS_ID_CTL}, file: "_redir:xterm.html"})) -/* - s.ctl_mux.Handle("/_rpts/ws", - s.SafeWrapWebsocketHandler(s.WrapWebsocketHandler(&server_rpts_ws{S: &s, Id: HS_ID_CTL}))) -*/ + s.ctl_mux.Handle("/_rpty/ws", + s.SafeWrapWebsocketHandler(s.WrapWebsocketHandler(&server_rpty_ws{S: &s, Id: HS_ID_CTL}))) + s.ctl_mux.Handle("/_rpty/xterm.js", + s.WrapHttpHandler(&server_pty_xterm_file{ServerCtl: ServerCtl{S: &s, Id: HS_ID_CTL}, file: "xterm.js"})) + s.ctl_mux.Handle("/_rpty/xterm.js/", + s.WrapHttpHandler(&server_pty_xterm_file{ServerCtl: ServerCtl{S: &s, Id: HS_ID_CTL}, file: "_forbidden"})) + s.ctl_mux.Handle("/_rpty/xterm-addon-fit.js", + s.WrapHttpHandler(&server_pty_xterm_file{ServerCtl: ServerCtl{S: &s, Id: HS_ID_CTL}, file: "xterm-addon-fit.js"})) + s.ctl_mux.Handle("/_rpty/xterm-addon-fit.js/", + s.WrapHttpHandler(&server_pty_xterm_file{ServerCtl: ServerCtl{S: &s, Id: HS_ID_CTL}, file: "_forbidden"})) + s.ctl_mux.Handle("/_rpty/xterm.css", + s.WrapHttpHandler(&server_pty_xterm_file{ServerCtl: ServerCtl{S: &s, Id: HS_ID_CTL}, file: "xterm.css"})) + s.ctl_mux.Handle("/_rpty/xterm.css/", + s.WrapHttpHandler(&server_pty_xterm_file{ServerCtl: ServerCtl{S: &s, Id: HS_ID_CTL}, file: "_forbidden"})) + s.ctl_mux.Handle("/_rpty/xterm.html", + s.WrapHttpHandler(&server_pty_xterm_file{ServerCtl: ServerCtl{S: &s, Id: HS_ID_CTL}, file: "xterm.html", mode: "rpty"})) + s.ctl_mux.Handle("/_rpty/xterm.html/", + s.WrapHttpHandler(&server_pty_xterm_file{ServerCtl: ServerCtl{S: &s, Id: HS_ID_CTL}, file: "_forbidden"})) + s.ctl_mux.Handle("/_rpty/", + s.WrapHttpHandler(&server_pty_xterm_file{ServerCtl: ServerCtl{S: &s, Id: HS_ID_CTL}, file: "_redir:xterm.html"})) s.ctl = make([]*http.Server, len(cfg.CtlAddrs)) for i = 0; i < len(cfg.CtlAddrs); i++ { @@ -1898,6 +2053,9 @@ func (s *Server) AddNewServerConn(remote_addr *net.Addr, local_addr *net.Addr, p cts.stop_chan = make(chan bool, 8) cts.pts_list = list.New() + cts.rpty_map = make(ServerRptyMap) + cts.rpty_map_by_ws = make(ServerRptyMapByWs) + s.cts_mtx.Lock() if s.cts_limit > 0 && len(s.cts_map) >= s.cts_limit { diff --git a/transform.go b/transform.go index fda4a5a..36a3309 100644 --- a/transform.go +++ b/transform.go @@ -68,7 +68,7 @@ func (t *Transformer) Transform(dst []byte, src []byte, at_eof bool) (int, int, err = transform.ErrShortSrc done: - return ndst, nsrc, err + return ndst, nsrc, err } func (t *Transformer) copy_all(dst []byte, src []byte) (int, error) { diff --git a/xterm.html b/xterm.html index e7421cf..5c65766 100644 --- a/xterm.html +++ b/xterm.html @@ -109,6 +109,7 @@ window.onload = function(event) { const login_pty_part = document.getElementById('login-pty-part'); const username_field = document.getElementById('username'); const password_field= document.getElementById('password'); + const qparams = new URLSearchParams(window.location.search); if (xt_mode == 'ssh') { login_ssh_part.style.display = 'block'; @@ -208,6 +209,12 @@ window.onload = function(event) { pathname = pathname.substring(0, pathname.lastIndexOf('/')); let url = prefix + window.location.host + pathname + '/ws'; + if (xt_mode == 'rpty') { + // when accessing rpty, the server requires a client token + let client_token = qparams.get('client-token'); + if (client_token != null && client_token != '') url += '?client-token=' + client_token; + } + const socket = new WebSocket(url); socket.binaryType = 'arraybuffer';