From 8331fdc1a22f129ebd1cbb9d8a791de0fc6f60ca Mon Sep 17 00:00:00 2001 From: hyung-hwan Date: Mon, 23 Jun 2025 21:09:24 +0900 Subject: [PATCH] implemented the pts feature in the server side as well --- Makefile | 1 + client-metrics.go | 1 + client-pts.go | 11 +- cmd/config.go | 2 + cmd/main.go | 7 + hodu.go | 5 + server-ctl.go | 2 + server-metrics.go | 13 ++ server-pts.go | 324 ++++++++++++++++++++++++++++++++++++++++++++++ server-pxy.go | 9 +- server.go | 45 ++++++- 11 files changed, 403 insertions(+), 17 deletions(-) create mode 100644 server-pts.go diff --git a/Makefile b/Makefile index e299718..093c686 100644 --- a/Makefile +++ b/Makefile @@ -25,6 +25,7 @@ SRCS=\ server-ctl.go \ server-metrics.go \ server-peer.go \ + server-pts.go \ server-pxy.go \ system.go \ transform.go \ diff --git a/client-metrics.go b/client-metrics.go index b2e4ec3..f867b78 100644 --- a/client-metrics.go +++ b/client-metrics.go @@ -60,6 +60,7 @@ func (c ClientCollector) Describe(ch chan<- *prometheus.Desc) { ch <- c.ClientConns ch <- c.ClientRoutes ch <- c.ClientPeers + ch <- c.PtsSessions } func (c ClientCollector) Collect(ch chan<- prometheus.Metric) { diff --git a/client-pts.go b/client-pts.go index ca3186d..e39851e 100644 --- a/client-pts.go +++ b/client-pts.go @@ -28,13 +28,6 @@ type client_pts_xterm_file struct { file string } -/* -type json_ssh_ws_event struct { - Type string `json:"type"` - Data []string `json:"data"` -} -*/ - // ------------------------------------------------------ func (pts *client_pts_ws) Identity() string { @@ -45,7 +38,7 @@ func (pts *client_pts_ws) send_ws_data(ws *websocket.Conn, type_val string, data var msg []byte var err error - msg, err = json.Marshal(json_ssh_ws_event{Type: type_val, Data: []string{ data } }) + msg, err = json.Marshal(json_xterm_ws_event{Type: type_val, Data: []string{ data } }) if err == nil { err = websocket.Message.Send(ws, msg) } return err } @@ -186,7 +179,7 @@ ws_recv_loop: if err != nil { goto done } if len(msg) > 0 { - var ev json_ssh_ws_event + var ev json_xterm_ws_event err = json.Unmarshal(msg, &ev) if err == nil { switch ev.Type { diff --git a/cmd/config.go b/cmd/config.go index 94019b0..557779b 100644 --- a/cmd/config.go +++ b/cmd/config.go @@ -94,6 +94,8 @@ type ServerAppConfig struct { LogRotate int `yaml:"log-rotate"` MaxPeers int `yaml:"max-peer-conns"` // maximum number of connections from peers MaxRpcConns int `yaml:"max-rpc-conns"` // maximum number of rpc connections + PtsUser string `yaml:"pts-user"` + PtsShell string `yaml:"pts-shell"` XtermHtmlFile string `yaml:"xterm-html-file"` } diff --git a/cmd/main.go b/cmd/main.go index cbe7730..773e06e 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -97,6 +97,8 @@ func server_main(ctl_addrs []string, rpc_addrs []string, pxy_addrs []string, wpx var logmask hodu.LogMask var logfile_maxsize int64 var logfile_rotate int + var pts_user string + var pts_shell string var xterm_html_file string var xterm_html string var err error @@ -132,6 +134,9 @@ func server_main(ctl_addrs []string, rpc_addrs []string, pxy_addrs []string, wpx config.CtlPrefix = cfg.CTL.Service.Prefix config.RpcMaxConns = cfg.APP.MaxRpcConns config.MaxPeers = cfg.APP.MaxPeers + + pts_user = cfg.APP.PtsUser + pts_shell = cfg.APP.PtsShell xterm_html_file = cfg.APP.XtermHtmlFile logmask = log_strings_to_mask(cfg.APP.LogMask) @@ -167,6 +172,8 @@ func server_main(ctl_addrs []string, rpc_addrs []string, pxy_addrs []string, wpx return fmt.Errorf("failed to create server - %s", err.Error()) } + if pts_user != "" { s.SetPtsUser(pts_user) } + if pts_shell != "" { s.SetPtsShell(pts_shell) } if xterm_html != "" { s.SetXtermHtml(xterm_html) } s.StartService(nil) diff --git a/hodu.go b/hodu.go index eec3c60..5fdd4f7 100644 --- a/hodu.go +++ b/hodu.go @@ -118,6 +118,11 @@ type json_out_go_stats struct { } +type json_xterm_ws_event struct { + Type string `json:"type"` + Data []string `json:"data"` +} + // --------------------------------------------------------- //go:embed xterm.js diff --git a/server-ctl.go b/server-ctl.go index 3f0927b..c81fb81 100644 --- a/server-ctl.go +++ b/server-ctl.go @@ -76,6 +76,7 @@ type json_out_server_stats struct { ServerPeers int64 `json:"server-peers"` SshProxySessions int64 `json:"pxy-ssh-sessions"` + ServerPtsSessions int64 `json:"server-pts-sessions"` } // this is a more specialized variant of json_in_notice @@ -911,6 +912,7 @@ func (ctl *server_ctl_stats) ServeHTTP(w http.ResponseWriter, req *http.Request) stats.ServerRoutes = s.stats.routes.Load() stats.ServerPeers = s.stats.peers.Load() stats.SshProxySessions = s.stats.ssh_proxy_sessions.Load() + stats.ServerPtsSessions = s.stats.pts_sessions.Load() status_code = WriteJsonRespHeader(w, http.StatusOK) if err = je.Encode(stats); err != nil { goto oops } diff --git a/server-metrics.go b/server-metrics.go index 851a1b3..bba4b95 100644 --- a/server-metrics.go +++ b/server-metrics.go @@ -11,6 +11,7 @@ type ServerCollector struct { ServerRoutes *prometheus.Desc ServerPeers *prometheus.Desc SshProxySessions *prometheus.Desc + PtsSessions *prometheus.Desc } // NewServerCollector returns a new ServerCollector with all prometheus.Desc initialized @@ -52,6 +53,11 @@ func NewServerCollector(server *Server) ServerCollector { "Number of SSH proxy sessions", nil, nil, ), + PtsSessions: prometheus.NewDesc( + prefix + "pts_sessions", + "Number of pts session", + nil, nil, + ), } } @@ -61,6 +67,7 @@ func (c ServerCollector) Describe(ch chan<- *prometheus.Desc) { ch <- c.ServerRoutes ch <- c.ServerPeers ch <- c.SshProxySessions + ch <- c.PtsSessions } func (c ServerCollector) Collect(ch chan<- prometheus.Metric) { @@ -97,4 +104,10 @@ func (c ServerCollector) Collect(ch chan<- prometheus.Metric) { prometheus.GaugeValue, float64(c.server.stats.ssh_proxy_sessions.Load()), ) + + ch <- prometheus.MustNewConstMetric( + c.PtsSessions, + prometheus.GaugeValue, + float64(c.server.stats.pts_sessions.Load()), + ) } diff --git a/server-pts.go b/server-pts.go new file mode 100644 index 0000000..6187138 --- /dev/null +++ b/server-pts.go @@ -0,0 +1,324 @@ +package hodu + +import "encoding/json" +import "errors" +import "fmt" +import "io" +import "net/http" +import "os" +import "os/exec" +import "os/user" +import "strconv" +import "sync" +import "syscall" +import "text/template" + +import "github.com/creack/pty" +import "golang.org/x/net/websocket" +import "golang.org/x/sys/unix" + +type server_pts_ws struct { + S *Server + Id string + ws *websocket.Conn +} + +type server_pts_xterm_file struct { + ServerCtl + file string +} + +// ------------------------------------------------------ + +func (pts *server_pts_ws) Identity() string { + return pts.Id +} + +func (pts *server_pts_ws) send_ws_data(ws *websocket.Conn, type_val string, data string) error { + var msg []byte + var err error + + msg, err = json.Marshal(json_xterm_ws_event{Type: type_val, Data: []string{ data } }) + if err == nil { err = websocket.Message.Send(ws, msg) } + return err +} + + +func (pts *server_pts_ws) connect_pts(username string, password string) (*exec.Cmd, *os.File, error) { + var s *Server + var cmd *exec.Cmd + var tty *os.File + var err error + + // username and password are not used yet. + s = pts.S + + if s.pts_shell == "" { + return nil, nil, fmt.Errorf("blank pts shell") + } + + cmd = exec.Command(s.pts_shell); + if s.pts_user != "" { + var uid int + var gid int + var u *user.User + + u, err = user.Lookup(s.pts_user) + if err != nil { return nil, nil, err } + + uid, _ = strconv.Atoi(u.Uid) + gid, _ = strconv.Atoi(u.Gid) + cmd.SysProcAttr = &syscall.SysProcAttr{ + Credential: &syscall.Credential{ + Uid: uint32(uid), + Gid: uint32(gid), + }, + Setsid: true, + } + cmd.Dir = u.HomeDir + cmd.Env = append(cmd.Env, + "HOME=" + u.HomeDir, + "LOGNAME=" + u.Username, + "PATH=" + os.Getenv("PATH"), + "SHELL=" + s.pts_shell, + "TERM=xterm", + "USER=" + u.Username, + ) + } + + tty, err = pty.Start(cmd) + if err != nil { + return nil, nil, err + } + + //syscall.SetNonblock(int(tty.Fd()), true); + unix.SetNonblock(int(tty.Fd()), true); + + return cmd, tty, nil +} + +func (pts *server_pts_ws) ServeWebsocket(ws *websocket.Conn) (int, error) { + var s *Server + var req *http.Request + var username string + var password string + var in *os.File + var out *os.File + var tty *os.File + var cmd *exec.Cmd + var wg sync.WaitGroup + var conn_ready_chan chan bool + var err error + + s = pts.S + req = ws.Request() + conn_ready_chan = make(chan bool, 3) + + wg.Add(1) + go func() { + var conn_ready bool + + defer wg.Done() + defer ws.Close() // dirty way to break the main loop + + conn_ready = <-conn_ready_chan + if conn_ready { // connected + var poll_fds []unix.PollFd; + var buf []byte + var n int + var err error + + + poll_fds = []unix.PollFd{ + unix.PollFd{Fd: int32(out.Fd()), Events: unix.POLLIN}, + } + + s.stats.pts_sessions.Add(1) + buf = make([]byte, 2048) + for { + n, err = unix.Poll(poll_fds, -1) // -1 means wait indefinitely + if err != nil { + if errors.Is(err, unix.EINTR) { continue } + s.log.Write("", LOG_ERROR, "[%s] Failed to poll pts stdout - %s", req.RemoteAddr, err.Error()) + break + } + if n == 0 { // timed out + continue + } + + if (poll_fds[0].Revents & (unix.POLLERR | unix.POLLHUP | unix.POLLNVAL)) != 0 { + s.log.Write(pts.Id, LOG_DEBUG, "[%s] EOF detected on pts stdout", req.RemoteAddr) + break; + } + + if (poll_fds[0].Revents & unix.POLLIN) != 0 { + n, err = out.Read(buf) + if err != nil { + if !errors.Is(err, io.EOF) { + s.log.Write(pts.Id, LOG_ERROR, "[%s] Failed to read pts stdout - %s", req.RemoteAddr, err.Error()) + } + break + } + if n > 0 { + err = pts.send_ws_data(ws, "iov", string(buf[:n])) + if err != nil { + s.log.Write(pts.Id, LOG_ERROR, "[%s] Failed to send to websocket - %s", req.RemoteAddr, err.Error()) + break + } + } + } + } + s.stats.pts_sessions.Add(-1) + } + }() + +ws_recv_loop: + for { + var msg []byte + err = websocket.Message.Receive(ws, &msg) + if err != nil { goto done } + + if len(msg) > 0 { + var ev json_xterm_ws_event + err = json.Unmarshal(msg, &ev) + if err == nil { + switch ev.Type { + case "open": + if tty == nil && len(ev.Data) == 2 { + username = string(ev.Data[0]) + password = string(ev.Data[1]) + + wg.Add(1) + go func() { + var err error + + defer wg.Done() + cmd, tty, err = pts.connect_pts(username, password) + if err != nil { + s.log.Write(pts.Id, LOG_ERROR, "[%s] Failed to connect pts - %s", req.RemoteAddr, err.Error()) + pts.send_ws_data(ws, "error", err.Error()) + ws.Close() // dirty way to flag out the error + } else { + err = pts.send_ws_data(ws, "status", "opened") + if err != nil { + s.log.Write(pts.Id, LOG_ERROR, "[%s] Failed to write opened event to websocket - %s", req.RemoteAddr, err.Error()) + ws.Close() // dirty way to flag out the error + } else { + s.log.Write(pts.Id, LOG_DEBUG, "[%s] Opened pts session", req.RemoteAddr) + out = tty + in = tty + conn_ready_chan <- true + } + } + }() + } + + case "close": + if tty != nil { + tty.Close() + tty = nil + } + break ws_recv_loop + + case "iov": + if tty != nil { + var i int + for i, _ = range ev.Data { + in.Write([]byte(ev.Data[i])) + } + } + + case "size": + if tty != nil && len(ev.Data) == 2 { + var rows int + var cols int + rows, _ = strconv.Atoi(ev.Data[0]) + cols, _ = strconv.Atoi(ev.Data[1]) + pty.Setsize(tty, &pty.Winsize{Rows: uint16(rows), Cols: uint16(cols)}) + s.log.Write(pts.Id, LOG_DEBUG, "[%s] Resized terminal to %d,%d", req.RemoteAddr, rows, cols) + // ignore error + } + } + } + } + } + + if tty != nil { + err = pts.send_ws_data(ws, "status", "closed") + if err != nil { goto done } + } + +done: + conn_ready_chan <- false + ws.Close() + if cmd != nil { + // kill the child process underneath to close ptym(the master pty). + //cmd.Process.Signal(syscall.SIGTERM) + cmd.Process.Kill() + } + if tty != nil { tty.Close() } + if cmd != nil { cmd.Wait() } + wg.Wait() + s.log.Write(pts.Id, LOG_DEBUG, "[%s] Ended pts session", req.RemoteAddr) + + return http.StatusOK, err +} + + +// ------------------------------------------------------ + +func (pts *server_pts_xterm_file) ServeHTTP(w http.ResponseWriter, req *http.Request) (int, error) { + var s *Server + var status_code int + var err error + + s = pts.S + + switch pts.file { + case "xterm.js": + status_code = WriteJsRespHeader(w, http.StatusOK) + w.Write(xterm_js) + case "xterm-addon-fit.js": + status_code = WriteJsRespHeader(w, http.StatusOK) + w.Write(xterm_addon_fit_js) + case "xterm.css": + status_code = WriteCssRespHeader(w, http.StatusOK) + w.Write(xterm_css) + case "xterm.html": + var tmpl *template.Template + + tmpl = template.New("") + if s.xterm_html != "" { + _, err = tmpl.Parse(s.xterm_html) + } else { + _, err = tmpl.Parse(xterm_html) + } + if err != nil { + status_code = WriteEmptyRespHeader(w, http.StatusInternalServerError) + goto oops + } else { + status_code = WriteHtmlRespHeader(w, http.StatusOK) + tmpl.Execute(w, + &xterm_session_info{ + Mode: "pts", + ConnId: "-1", + RouteId: "-1", + }) + } + + case "_forbidden": + status_code = WriteEmptyRespHeader(w, http.StatusForbidden) + + case "_notfound": + status_code = WriteEmptyRespHeader(w, http.StatusNotFound) + + default: + status_code = WriteEmptyRespHeader(w, http.StatusNotFound) + } + +//done: + return status_code, nil + +oops: + return status_code, err +} diff --git a/server-pxy.go b/server-pxy.go index ac2c2c8..798a293 100644 --- a/server-pxy.go +++ b/server-pxy.go @@ -550,11 +550,6 @@ type server_pxy_ssh_ws struct { Id string } -type json_ssh_ws_event struct { - Type string `json:"type"` - Data []string `json:"data"` -} - func (pxy *server_pxy_ssh_ws) Identity() string { return pxy.Id } @@ -566,7 +561,7 @@ func (pxy *server_pxy_ssh_ws) send_ws_data(ws *websocket.Conn, type_val string, var msg []byte var err error - msg, err = json.Marshal(json_ssh_ws_event{Type: type_val, Data: []string{ data } }) + msg, err = json.Marshal(json_xterm_ws_event{Type: type_val, Data: []string{ data } }) if err == nil { err = websocket.Message.Send(ws, msg) } return err } @@ -723,7 +718,7 @@ ws_recv_loop: if err != nil { goto done } if len(msg) > 0 { - var ev json_ssh_ws_event + var ev json_xterm_ws_event err = json.Unmarshal(msg, &ev) if err == nil { switch ev.Type { diff --git a/server.go b/server.go index 9bd8883..dd18b35 100644 --- a/server.go +++ b/server.go @@ -154,11 +154,16 @@ type Server struct { routes atomic.Int64 peers atomic.Int64 ssh_proxy_sessions atomic.Int64 + pts_sessions atomic.Int64 } wpx_resp_tf ServerWpxResponseTransformer wpx_foreign_port_proxy_maker ServerWpxForeignPortProxyMaker - xterm_html string + + + pts_user string + pts_shell string + xterm_html string } // connection from client. @@ -1376,6 +1381,27 @@ func NewServer(ctx context.Context, name string, logger Logger, cfg *ServerConfi s.ctl_mux.Handle("/_ctl/events", s.SafeWrapWebsocketHandler(s.WrapWebsocketHandler(&server_ctl_ws{ServerCtl{S: &s, Id: HS_ID_CTL}}))) + s.ctl_mux.Handle("/_pts/ws", + s.SafeWrapWebsocketHandler(s.WrapWebsocketHandler(&server_pts_ws{S: &s, Id: HS_ID_CTL}))) + + s.ctl_mux.Handle("/_pts/xterm.js", + s.WrapHttpHandler(&server_pts_xterm_file{ServerCtl: ServerCtl{S: &s, Id: HS_ID_CTL}, file: "xterm.js"})) + s.ctl_mux.Handle("/_pts/xterm.js.map", + s.WrapHttpHandler(&server_pts_xterm_file{ServerCtl: ServerCtl{S: &s, Id: HS_ID_CTL}, file: "_notfound"})) + s.ctl_mux.Handle("/_pts/xterm-addon-fit.js", + s.WrapHttpHandler(&server_pts_xterm_file{ServerCtl: ServerCtl{S: &s, Id: HS_ID_CTL}, file: "xterm-addon-fit.js"})) + s.ctl_mux.Handle("/_pts/xterm-addon-fit.js.map", + s.WrapHttpHandler(&server_pts_xterm_file{ServerCtl: ServerCtl{S: &s, Id: HS_ID_CTL}, file: "_notfound"})) + s.ctl_mux.Handle("/_pts/xterm.css", + s.WrapHttpHandler(&server_pts_xterm_file{ServerCtl: ServerCtl{S: &s, Id: HS_ID_CTL}, file: "xterm.css"})) + s.ctl_mux.Handle("/_pts/xterm.html", + s.WrapHttpHandler(&server_pts_xterm_file{ServerCtl: ServerCtl{S: &s, Id: HS_ID_CTL}, file: "xterm.html"})) + s.ctl_mux.Handle("/_pts/", + s.WrapHttpHandler(&server_pts_xterm_file{ServerCtl: ServerCtl{S: &s, Id: HS_ID_CTL}, file: "_forbidden"})) + s.ctl_mux.Handle("/_pts/favicon.ico", + s.WrapHttpHandler(&server_pts_xterm_file{ServerCtl: ServerCtl{S: &s, Id: HS_ID_CTL}, file: "_forbidden"})) + s.ctl_mux.Handle("/_pts/favicon.ico/", + s.WrapHttpHandler(&server_pts_xterm_file{ServerCtl: ServerCtl{S: &s, Id: HS_ID_CTL}, file: "_forbidden"})) /* // this part is duplcate of pxy_mux. s.ctl_mux.Handle("/_ssh/ws/{conn_id}/{route_id}", @@ -1503,6 +1529,7 @@ func NewServer(ctx context.Context, name string, logger Logger, cfg *ServerConfi s.stats.routes.Store(0) s.stats.peers.Store(0) s.stats.ssh_proxy_sessions.Store(0) + s.stats.pts_sessions.Store(0) return &s, nil @@ -1538,6 +1565,22 @@ func (s *Server) GetXtermHtml() string { return s.xterm_html } +func (s *Server) SetPtsUser(user string) { + s.pts_user = user +} + +func (s *Server) GetPtsUser() string { + return s.pts_user +} + +func (s *Server) SetPtsShell(user string) { + s.pts_shell = user +} + +func (s *Server) GetPtsShell() string { + return s.pts_shell +} + func (s *Server) run_grpc_server(idx int, wg *sync.WaitGroup) error { var l *net.TCPListener var err error