diff --git a/client-ctl.go b/client-ctl.go index 6d0310a..f4dcf47 100644 --- a/client-ctl.go +++ b/client-ctl.go @@ -110,6 +110,7 @@ type json_out_client_stats struct { ClientRoutes int64 `json:"client-routes"` ClientPeers int64 `json:"client-peers"` ClientPtySessions int64 `json:"client-pty-sessions"` + ClientRptySessions int64 `json:"client-rpty-sessions"` } // ------------------------------------ @@ -1138,6 +1139,7 @@ func (ctl *client_ctl_stats) ServeHTTP(w http.ResponseWriter, req *http.Request) stats.ClientRoutes = c.stats.routes.Load() stats.ClientPeers = c.stats.peers.Load() stats.ClientPtySessions = c.stats.pty_sessions.Load() + stats.ClientRptySessions = c.stats.rpty_sessions.Load() status_code = WriteJsonRespHeader(w, http.StatusOK) if err = je.Encode(stats); err != nil { goto oops } diff --git a/client-metrics.go b/client-metrics.go index 51a1cc7..8233214 100644 --- a/client-metrics.go +++ b/client-metrics.go @@ -11,6 +11,7 @@ type ClientCollector struct { ClientRoutes *prometheus.Desc ClientPeers *prometheus.Desc PtySessions *prometheus.Desc + RptySessions *prometheus.Desc } // NewClientCollector returns a new ClientCollector with all prometheus.Desc initialized @@ -52,6 +53,11 @@ func NewClientCollector(client *Client) ClientCollector { "Number of pty sessions", nil, nil, ), + RptySessions: prometheus.NewDesc( + prefix + "rpty_sessions", + "Number of rpty sessions", + nil, nil, + ), } } @@ -61,6 +67,7 @@ func (c ClientCollector) Describe(ch chan<- *prometheus.Desc) { ch <- c.ClientRoutes ch <- c.ClientPeers ch <- c.PtySessions + ch <- c.RptySessions } func (c ClientCollector) Collect(ch chan<- prometheus.Metric) { @@ -97,4 +104,10 @@ func (c ClientCollector) Collect(ch chan<- prometheus.Metric) { prometheus.GaugeValue, float64(c.client.stats.pty_sessions.Load()), ) + + ch <- prometheus.MustNewConstMetric( + c.RptySessions, + prometheus.GaugeValue, + float64(c.client.stats.rpty_sessions.Load()), + ) } diff --git a/client-peer.go b/client-peer.go index f436ed7..6cc5eae 100644 --- a/client-peer.go +++ b/client-peer.go @@ -30,6 +30,16 @@ func (cpc *ClientPeerConn) RunTask(wg *sync.WaitGroup) error { for { n, err = cpc.conn.Read(buf[:]) + if n > 0 { + var err2 error + err2 = cpc.route.cts.psc.Send(MakePeerDataPacket(cpc.route.Id, cpc.conn_id, buf[0:n])) + if err2 != nil { + cpc.route.cts.C.log.Write(cpc.route.cts.Sid, LOG_ERROR, + "Failed to write peer(%d,%d,%s,%s) data to server - %s", + cpc.route.Id, cpc.conn_id, cpc.conn.RemoteAddr().String(), cpc.conn.LocalAddr().String(), err2.Error()) + break + } + } if err != nil { if errors.Is(err, io.EOF) || strings.Contains(err.Error(), "use of closed network connection") { // i hate checking this condition with strings.Contains() cpc.route.cts.C.log.Write(cpc.route.cts.Sid, LOG_INFO, @@ -42,14 +52,6 @@ func (cpc *ClientPeerConn) RunTask(wg *sync.WaitGroup) error { } break } - - err = cpc.route.cts.psc.Send(MakePeerDataPacket(cpc.route.Id, cpc.conn_id, buf[0:n])) - if err != nil { - cpc.route.cts.C.log.Write(cpc.route.cts.Sid, LOG_ERROR, - "Failed to write peer(%d,%d,%s,%s) data to server - %s", - cpc.route.Id, cpc.conn_id, cpc.conn.RemoteAddr().String(), cpc.conn.LocalAddr().String(), err.Error()) - break - } } cpc.route.cts.psc.Send(MakePeerStoppedPacket(cpc.route.Id, cpc.conn_id, cpc.conn.RemoteAddr().String(), cpc.conn.LocalAddr().String())) // nothing much to do upon failure. no error check here diff --git a/client-pty.go b/client-pty.go index 280604e..9aacdf0 100644 --- a/client-pty.go +++ b/client-pty.go @@ -59,7 +59,7 @@ func (pty *client_pty_ws) ServeWebsocket(ws *websocket.Conn) (int, error) { conn_ready = <-conn_ready_chan if conn_ready { // connected var poll_fds []unix.PollFd - var buf []byte + var buf [2048]byte var n int var err error @@ -69,7 +69,6 @@ func (pty *client_pty_ws) ServeWebsocket(ws *websocket.Conn) (int, error) { } c.stats.pty_sessions.Add(1) - buf = make([]byte, 2048) for { n, err = unix.Poll(poll_fds, -1) // -1 means wait indefinitely if err != nil { @@ -87,20 +86,21 @@ func (pty *client_pty_ws) ServeWebsocket(ws *websocket.Conn) (int, error) { } if (poll_fds[0].Revents & unix.POLLIN) != 0 { - n, err = out.Read(buf) + n, err = out.Read(buf[:]) + if n > 0 { + var err2 error + err2 = send_ws_data_for_xterm(ws, "iov", string(buf[:n])) + if err2 != nil { + c.log.Write(pty.Id, LOG_ERROR, "[%s] Failed to send to websocket - %s", req.RemoteAddr, err2.Error()) + break + } + } if err != nil { if !errors.Is(err, io.EOF) { c.log.Write(pty.Id, LOG_ERROR, "[%s] Failed to read pty stdout - %s", req.RemoteAddr, err.Error()) } break } - if n > 0 { - err = send_ws_data_for_xterm(ws, "iov", string(buf[:n])) - if err != nil { - c.log.Write(pty.Id, LOG_ERROR, "[%s] Failed to send to websocket - %s", req.RemoteAddr, err.Error()) - break - } - } } } c.stats.pty_sessions.Add(-1) diff --git a/client.go b/client.go index dfdb6b9..69b4b59 100644 --- a/client.go +++ b/client.go @@ -1,5 +1,7 @@ package hodu +import "bufio" +import "bytes" import "container/list" import "context" import "crypto/tls" @@ -40,6 +42,7 @@ type ClientRouteMap map[RouteId]*ClientRoute type ClientPeerConnMap map[PeerId]*ClientPeerConn type ClientPeerCancelFuncMap map[PeerId]context.CancelFunc type ClientRptyMap map[uint64]*ClientRpty +type ClientRpxMap map[uint64]*ClientRpx // -------------------------------------------------------------------- type ClientRouteConfig struct { @@ -84,6 +87,10 @@ type ClientConfig struct { PeerConnTmout time.Duration Token string // to send to the server for identification + + // default target for rpx + RpxTargetAddr string + RpxTargetTls *tls.Config } type ClientEventKind int @@ -122,6 +129,10 @@ type Client struct { ext_svcs []Service rpc_tls *tls.Config + rpx_target_addr string + rpx_target_url string + rpx_target_tls *tls.Config + ctl_tls *tls.Config ctl_addr []string ctl_prefix string @@ -155,6 +166,7 @@ type Client struct { routes atomic.Int64 peers atomic.Int64 pty_sessions atomic.Int64 + rpty_sessions atomic.Int64 } pty_user string @@ -176,6 +188,15 @@ type ClientRpty struct { tty *os.File } +type ClientRpx struct { + id uint64 + pr *io.PipeReader + pw *io.PipeWriter + ctx context.Context + cancel context.CancelFunc + ws_conn net.Conn +} + // client connection to server type ClientConn struct { C *Client @@ -209,6 +230,9 @@ type ClientConn struct { rpty_mtx sync.Mutex rpty_map ClientRptyMap + rpx_mtx sync.Mutex + rpx_map ClientRpxMap + stop_req atomic.Bool stop_chan chan bool @@ -294,6 +318,22 @@ func (g *GuardedPacketStreamClient) Context() context.Context { return g.psc.Context() }*/ +// -------------------------------------------------------------------- + +func (rpty *ClientRpty) ReqStop() { + rpty.tty.Close() + rpty.cmd.Process.Kill() +} + +func (rpx *ClientRpx) ReqStop() { + rpx.pr.Close() + rpx.pw.Close() + rpx.cancel() + if rpx.ws_conn != nil { + rpx.ws_conn.SetDeadline(time.Now()) // to make Read return immediately + } +} + // -------------------------------------------------------------------- func NewClientRoute(cts *ClientConn, id RouteId, static bool, client_peer_addr string, client_peer_name string, server_peer_svc_addr string, server_peer_svc_net string, server_peer_option RouteOption, lifetime time.Duration) *ClientRoute { var r ClientRoute @@ -397,7 +437,7 @@ func (r *ClientRoute) ExtendLifetime(lifetime time.Duration) error { r.lifetime_timer.Stop() r.Lifetime = r.Lifetime + lifetime expiry = r.LifetimeStart.Add(r.Lifetime) - r.lifetime_timer.Reset(expiry.Sub(time.Now())) + r.lifetime_timer.Reset(time.Until(expiry)) // expiry.Sub(time.Now()) if r.cts.C.route_persister != nil { r.cts.C.route_persister.Save(r.cts, r) } r.lifetime_mtx.Unlock() @@ -477,10 +517,8 @@ func (r *ClientRoute) RunTask(wg *sync.WaitGroup) { break waiting_loop } } else { - select { - case <-r.stop_chan: - break waiting_loop - } + <-r.stop_chan + break waiting_loop } } @@ -547,7 +585,7 @@ func (r *ClientRoute) ConnectToPeer(pts_id PeerId, route_option RouteOption, pts defer wg.Done() tmout = time.Duration(r.cts.C.ptc_tmout) - if tmout <= 0 { tmout = 5 * time.Second} // TODO: make this configurable... + if tmout <= 0 { tmout = 5 * time.Second } // TODO: make this configurable... waitctx, cancel_wait = context.WithTimeout(r.cts.C.Ctx, tmout) r.ptc_mtx.Lock() r.ptc_cancel_map[pts_id] = cancel_wait @@ -571,8 +609,8 @@ func (r *ClientRoute) ConnectToPeer(pts_id PeerId, route_option RouteOption, pts real_conn, ok = conn.(*net.TCPConn) if !ok { r.cts.C.log.Write(r.cts.Sid, LOG_ERROR, - "Failed to get connection information to %s for route(%d,%d,%s,%s) - %s", - r.PeerAddr, r.Id, pts_id, pts_raddr, pts_laddr, err.Error()) + "Failed to get connection information to %s for route(%d,%d,%s,%s)", + r.PeerAddr, r.Id, pts_id, pts_raddr, pts_laddr) goto peer_aborted } @@ -825,6 +863,7 @@ func NewClientConn(c *Client, cfg *ClientConnConfig) *ClientConn { cts.stop_chan = make(chan bool, 8) cts.ptc_list = list.New() cts.rpty_map = make(ClientRptyMap) + cts.rpx_map = make(ClientRpxMap) for i, _ = range cts.cfg.Routes { // override it to static regardless of the value passed in @@ -833,7 +872,6 @@ func NewClientConn(c *Client, cfg *ClientConnConfig) *ClientConn { // the actual connection to the server is established in the main task function // The cts.conn, cts.hdc, cts.psc fields are left unassigned here. - return &cts } @@ -1031,7 +1069,8 @@ func (cts *ClientConn) add_client_routes(routes []ClientRouteConfig) error { func (cts *ClientConn) disconnect_from_server(logmsg bool) { if cts.conn != nil { var r *ClientRoute - var crp *ClientRpty + var rpty *ClientRpty + var rpx *ClientRpx cts.discon_mtx.Lock() @@ -1045,15 +1084,20 @@ func (cts *ClientConn) disconnect_from_server(logmsg bool) { // arrange to clean up all rpty objects cts.rpty_mtx.Lock() - for _, crp = range cts.rpty_map { - crp.tty.Close() - crp.cmd.Process.Kill() + for _, rpty = range cts.rpty_map { + rpty.ReqStop() // the loop in ReadRptyLoop() is supposed to be broken. // let's not inform the server of this connection. // the server should clean up itself upon connection error } cts.rpty_mtx.Unlock() + cts.rpx_mtx.Lock() + for _, rpx = range cts.rpx_map { + rpx.ReqStop() + } + cts.rpx_mtx.Unlock() + // don't care about double closes when this function is called from both RunTask() and ReqStop() cts.conn.Close() @@ -1217,7 +1261,7 @@ start_over: pkt, err = psc.Recv() if err != nil { - if status.Code(err) == codes.Canceled || errors.Is(err, net.ErrClosed) { + if status.Code(err) == codes.Canceled || errors.Is(err, net.ErrClosed) || errors.Is(err, io.EOF) { goto reconnect_to_server } else { cts.C.log.Write(cts.Sid, LOG_INFO, "Failed to receive packet from %s - %s", cts.remote_addr_p, err.Error()) @@ -1348,6 +1392,31 @@ start_over: cts.C.log.Write(cts.Sid, LOG_ERROR, "Invalid %s event from %s", pkt.Kind.String(), cts.remote_addr_p) } + case PACKET_KIND_RPX_START: + fallthrough + case PACKET_KIND_RPX_STOP: + fallthrough + case PACKET_KIND_RPX_DATA: + fallthrough + case PACKET_KIND_RPX_EOF: + var x *Packet_RpxEvt + var ok bool + x, ok = pkt.U.(*Packet_RpxEvt) + if ok { + err = cts.HandleRpxEvent(pkt.Kind, x.RpxEvt) + if err != nil { + cts.C.log.Write(cts.Sid, LOG_ERROR, + "Failed to handle %s event for rpx(%d) from %s - %s", + pkt.Kind.String(), x.RpxEvt.Id, cts.remote_addr_p, err.Error()) + } else { + cts.C.log.Write(cts.Sid, LOG_DEBUG, + "Handled %s event for rpx(%d) from %s", + pkt.Kind.String(), x.RpxEvt.Id, cts.remote_addr_p) + } + } else { + cts.C.log.Write(cts.Sid, LOG_ERROR, "Invalid %s event from %s", pkt.Kind.String(), cts.remote_addr_p) + } + default: // do nothing. ignore the rest } @@ -1390,7 +1459,7 @@ reconnect_to_server: goto wait_for_termination case <-slpctx.Done(): select { - case <- cts.C.Ctx.Done(): + case <-cts.C.Ctx.Done(): // non-blocking check if the parent context of the sleep context is // terminated too. if so, this is normal termination case. // this check seem redundant but the go-runtime doesn't seem to guarantee @@ -1421,20 +1490,34 @@ func (cts *ClientConn) ReportPacket(route_id RouteId, pts_id PeerId, packet_type } +// rpty +func (cts *ClientConn) FindClientRptyById(id uint64) *ClientRpty { + var crp *ClientRpty + var ok bool + + cts.rpty_mtx.Lock() + crp, ok = cts.rpty_map[id] + cts.rpty_mtx.Unlock() + + if !ok { crp = nil } + return crp +} + func (cts *ClientConn) ReadRptyLoop(crp *ClientRpty, wg *sync.WaitGroup) { var poll_fds []unix.PollFd - var buf []byte + var buf [2048]byte var n int var err error defer wg.Done() + cts.C.stats.rpty_sessions.Add(1) + poll_fds = []unix.PollFd{ unix.PollFd{Fd: int32(crp.tty.Fd()), Events: unix.POLLIN}, } - buf = make([]byte, 2048) for { n, err = unix.Poll(poll_fds, -1) // -1 means wait indefinitely if err != nil { @@ -1452,34 +1535,35 @@ func (cts *ClientConn) ReadRptyLoop(crp *ClientRpty, wg *sync.WaitGroup) { } if (poll_fds[0].Revents & unix.POLLIN) != 0 { - n, err = crp.tty.Read(buf) + n, err = crp.tty.Read(buf[:]) + if n > 0 { + var err2 error + err2 = cts.psc.Send(MakeRptyDataPacket(crp.id, buf[:n])) + if err2 != nil { + cts.C.log.Write(cts.Sid, LOG_DEBUG, "Failed to send rpty(%d) stdout to server - %s", crp.id, err2.Error()) + break + } + } if err != nil { if !errors.Is(err, io.EOF) { cts.C.log.Write(cts.Sid, LOG_DEBUG, "Failed to read rpty(%d) stdout - %s", crp.id, err.Error()) } - break } - if n > 0 { - err = cts.psc.Send(MakeRptyDataPacket(crp.id, buf[:n])) - if err != nil { - cts.C.log.Write(cts.Sid, LOG_DEBUG, "Failed to send rpty(%d) stdout to server - %s", crp.id, err.Error()) - break - } - } } } cts.psc.Send(MakeRptyStopPacket(crp.id, "")) cts.C.log.Write(cts.Sid, LOG_INFO, "Ending rpty(%d) read loop", crp.id) - crp.tty.Close() // don't care about multiple closes - crp.cmd.Process.Kill() + crp.ReqStop() crp.cmd.Wait() cts.rpty_mtx.Lock() delete(cts.rpty_map, crp.id) cts.rpty_mtx.Unlock() + + cts.C.stats.rpty_sessions.Add(-1) } func (cts *ClientConn) StartRpty(id uint64, wg *sync.WaitGroup) error { @@ -1515,46 +1599,34 @@ func (cts *ClientConn) StartRpty(id uint64, wg *sync.WaitGroup) error { func (cts *ClientConn) StopRpty(id uint64) error { var crp *ClientRpty - var ok bool - cts.rpty_mtx.Lock() - crp, ok = cts.rpty_map[id] - if !ok { - cts.rpty_mtx.Unlock() + crp = cts.FindClientRptyById(id) + if crp == nil { return fmt.Errorf("unknown rpty id %d", id) } - crp.tty.Close() // to break ReadRptyLoop() - crp.cmd.Process.Kill() // to process wait to be done by ReadRptyLoop() - cts.rpty_mtx.Unlock() + crp.ReqStop() return nil } func (cts *ClientConn) WriteRpty(id uint64, data []byte) error { var crp *ClientRpty - var ok bool - cts.rpty_mtx.Lock() - crp, ok = cts.rpty_map[id] - if !ok { - cts.rpty_mtx.Unlock() + crp = cts.FindClientRptyById(id) + if crp == nil { return fmt.Errorf("unknown rpty id %d", id) } crp.tty.Write(data) - cts.rpty_mtx.Unlock() return nil } func (cts *ClientConn) WriteRptySize(id uint64, data []byte) error { var crp *ClientRpty - var ok bool var flds []string - cts.rpty_mtx.Lock() - crp, ok = cts.rpty_map[id] - if !ok { - cts.rpty_mtx.Unlock() + crp = cts.FindClientRptyById(id) + if crp == nil { return fmt.Errorf("unknown rpty id %d", id) } @@ -1566,11 +1638,11 @@ func (cts *ClientConn) WriteRptySize(id uint64, data []byte) error { cols, _ = strconv.Atoi(flds[1]) pts.Setsize(crp.tty, &pts.Winsize{Rows: uint16(rows), Cols: uint16(cols)}) } - cts.rpty_mtx.Unlock() return nil } func (cts *ClientConn) HandleRptyEvent(packet_type PACKET_KIND, evt *RptyEvent) error { + switch packet_type { case PACKET_KIND_RPTY_START: return cts.StartRpty(evt.Id, &cts.C.wg) @@ -1589,6 +1661,329 @@ func (cts *ClientConn) HandleRptyEvent(packet_type PACKET_KIND, evt *RptyEvent) return nil } +// rpx +func (cts *ClientConn) FindClientRpxById(id uint64) *ClientRpx { + var crpx *ClientRpx + var ok bool + + cts.rpx_mtx.Lock() + crpx, ok = cts.rpx_map[id] + cts.rpx_mtx.Unlock() + + if !ok { crpx = nil } + return crpx +} + +func (cts *ClientConn) RpxLoop(crpx *ClientRpx, data []byte, wg *sync.WaitGroup) { + var sc *bufio.Scanner + var line string + var flds []string + var buf [4096]byte + var req_meth string + var req_path string + //var req_proto string + var req *http.Request + var n int + var err error + + defer wg.Done() + + sc = bufio.NewScanner(bytes.NewReader(data)) + sc.Scan() + line = sc.Text() + + flds = strings.Fields(line) + if (len(flds) < 3) { + cts.C.log.Write(cts.Sid, LOG_ERROR, "Invalid request line for rpx(%d) - %s", crpx.id, line) + goto done + } + +// TODO: handle trailers... + req_meth = flds[0] + req_path = flds[1] + //req_proto = flds[2] + + // create a request assuming it's a normal http request + + req, err = http.NewRequestWithContext(crpx.ctx, req_meth, cts.C.rpx_target_url + req_path, crpx.pr) + if err != nil { + cts.C.log.Write(cts.Sid, LOG_ERROR, "failed to create request for rpx(%d) - %s", crpx.id, err.Error()) + goto done + } + + for sc.Scan() { + line = sc.Text() + if line == "" { break } + flds = strings.SplitN(line, ":", 2) + if len(flds) == 2 { + req.Header.Add(strings.TrimSpace(flds[0]), strings.TrimSpace(flds[1])) + } + } + err = sc.Err() + if err != nil { + cts.C.log.Write(cts.Sid, LOG_ERROR, "failed to parse request for rpx(%d) - %s", crpx.id, err.Error()) + goto done + } + + if strings.EqualFold(req.Header.Get("Upgrade"), "websocket") && strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade") { + // websocket + var done_chan chan struct{} + + var conn net.Conn + var resp *http.Response + var r *bufio.Reader + + if cts.C.rpx_target_tls != nil { + var dialer *tls.Dialer + dialer = &tls.Dialer{ + NetDialer: &net.Dialer{}, + Config: cts.C.rpx_target_tls, + } + conn, err = dialer.DialContext(crpx.ctx, "tcp", cts.C.rpx_target_addr) // TODO: no hard coding + } else { + var dialer *net.Dialer + dialer = &net.Dialer{} + conn, err = dialer.DialContext(crpx.ctx, "tcp", cts.C.rpx_target_addr) // TODO: no hard coding + } + if err != nil { + cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to dial websocket for rpx(%d) - %s", crpx.id, err.Error()) + goto done + } + defer conn.Close() + + // TODO: make this atomic? + crpx.ws_conn = conn + + // write the raw request line and headers as sent by the server. + // for the upgrade request, i assume no payload. + _, err = conn.Write(data) + if err != nil { + cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to write websocket request for rpx(%d) - %s", crpx.id, err.Error()) + goto done + } + + r = bufio.NewReader(conn) + resp, err = http.ReadResponse(r, req) + if err != nil { + cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to write websocket response for rpx(%d) - %s", crpx.id, err.Error()) + goto done + } + defer resp.Body.Close() + + err = cts.psc.Send(MakeRpxStartPacket(crpx.id, get_http_resp_line_and_headers(resp))) + if err != nil { + cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to send rpx(%d) WebSocket headers to server - %s", crpx.id, err.Error()) + goto done + } + + if resp.StatusCode != http.StatusSwitchingProtocols { + // websock upgrade failed. let the code jump to the done + // label to skip reading from the pipe. the server side + // has the code to ensure no content-length. and the upgrade + // fails, the pipe below will be pending forever as the server + // side doesn't send data and there's no feeding to the pipe. + cts.C.log.Write(cts.Sid, LOG_INFO, "Protocol switching failed for rpx(%d)", crpx.id) + goto done + } + + // unlike with the normal request, the actual pipe is not read + // until the initial switching protocol response is received. + + wg.Add(1) + done_chan = make(chan struct{}, 5) + go func() { + var buf [4096]byte + var n int + var err error + + defer wg.Done() + for { + n, err = crpx.pr.Read(buf[:]) + if n > 0 { + var err2 error + _, err2 = conn.Write(buf[:n]) + if err2 != nil { + cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to write websocket for rpx(%d) - %s", crpx.id, err2.Error()) + break + } + } + if err != nil { + if errors.Is(err, io.EOF) { break } + cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to read pipe for rpx(%d) - %s", crpx.id, err.Error()) + break + } + } + done_chan <- struct{}{} + }() + + for { + n, err = conn.Read(buf[:]) + if n > 0 { + var err2 error + err2 = cts.psc.Send(MakeRpxDataPacket(crpx.id, buf[:n])) + if err2 != nil { + cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to send rpx(%d) data to server - %s", crpx.id, err2.Error()) + break + } + } + if err != nil { + if errors.Is(err, io.EOF) { + cts.psc.Send(MakeRpxEofPacket(crpx.id)) + cts.C.log.Write(cts.Sid, LOG_DEBUG, "WebSocket rpx(%d) closed by server", crpx.id) + break + } + cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to read WebSocket rpx(%d) - %s", crpx.id, err.Error()) + break + } + } + + // wait until the pipe reading(from the server side) goroutine is over + <-done_chan + } else { + var tr *http.Transport + var resp *http.Response + + tr = &http.Transport { + TLSClientConfig: cts.C.rpx_target_tls, + } + + resp, err = tr.RoundTrip(req) + if err != nil { + cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to send rpx(%d) request - %s", crpx.id, err.Error()) + goto done + } + + defer resp.Body.Close() + + err = cts.psc.Send(MakeRpxStartPacket(crpx.id, get_http_resp_line_and_headers(resp))) + if err != nil { + cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to send rpx(%d) status and headers to server - %s", crpx.id, err.Error()) + goto done + } + + for { + n, err = resp.Body.Read(buf[:]) + if n > 0 { + var err2 error + err2 = cts.psc.Send(MakeRpxDataPacket(crpx.id, buf[:n])) + if err2 != nil { + cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to send rpx(%d) data to server - %s", crpx.id, err2.Error()) + break + } + } + if err != nil { + if errors.Is(err, io.EOF) { + break + } + cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to read response body for rpx(%d) - %s", crpx.id, err.Error()) + break + } + } + } + +done: + err = cts.psc.Send(MakeRpxStopPacket(crpx.id)) + if err != nil { + cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to send rpx(%d) stp to server - %s", crpx.id, err.Error()) + } + cts.C.log.Write(cts.Sid, LOG_INFO, "Ending rpx(%d) read loop", crpx.id) + + crpx.ReqStop() + cts.rpx_mtx.Lock() + delete(cts.rpx_map, crpx.id) + cts.rpx_mtx.Unlock() +} + +func (cts *ClientConn) StartRpx(id uint64, data []byte, wg *sync.WaitGroup) error { + var crpx *ClientRpx + var ok bool + + cts.rpx_mtx.Lock() + _, ok = cts.rpx_map[id] + if ok { + cts.rpx_mtx.Unlock() + return fmt.Errorf("multiple start on rpx id %d", id) + } + crpx = &ClientRpx{ id: id } + cts.rpx_map[id] = crpx + + // i want the pipe to be created before the goroutine is started + // so that the WriteRpx() can write to the pipe. i protect pipe creation + // and context creation with a mutex + crpx.pr, crpx.pw = io.Pipe() + crpx.ctx, crpx.cancel = context.WithCancel(cts.C.Ctx) + + cts.rpx_mtx.Unlock() + + wg.Add(1) + go cts.RpxLoop(crpx, data, wg) + + return nil +} + +func (cts *ClientConn) StopRpx(id uint64) error { + var crpx *ClientRpx + + crpx = cts.FindClientRpxById(id) + if crpx == nil { + return fmt.Errorf("unknown rpx id %d", id) + } + + crpx.ReqStop() + return nil +} + +func (cts *ClientConn) WriteRpx(id uint64, data []byte) error { + var crpx *ClientRpx + var err error + + crpx = cts.FindClientRpxById(id) + if crpx == nil { + return fmt.Errorf("unknown rpx id %d", id) + } + +// TODO: may have to write it in a goroutine to avoid blocking? + _, err = crpx.pw.Write(data) + if err != nil { + cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to write rpx(%d) data - %s", id, err.Error()) + return err + } + return nil +} + +func (cts *ClientConn) EofRpx(id uint64, data []byte) error { + var crpx *ClientRpx + + crpx = cts.FindClientRpxById(id) + if crpx == nil { + return fmt.Errorf("unknown rpx id %d", id) + } + + // close the writing end only. leave the reading end untouched + crpx.pw.Close() + + return nil +} + +func (cts *ClientConn) HandleRpxEvent(packet_type PACKET_KIND, evt *RpxEvent) error { + switch packet_type { + case PACKET_KIND_RPX_START: + return cts.StartRpx(evt.Id, evt.Data, &cts.C.wg) + + case PACKET_KIND_RPX_STOP: + return cts.StopRpx(evt.Id) + + case PACKET_KIND_RPX_DATA: + return cts.WriteRpx(evt.Id, evt.Data) + + case PACKET_KIND_RPX_EOF: + return cts.EofRpx(evt.Id, evt.Data) + } + + // ignore other packet types + return nil +} + // -------------------------------------------------------------------- func (m ClientPeerConnMap) get_sorted_keys() []PeerId { @@ -1631,12 +2026,13 @@ func (m ClientConnMap) get_sorted_keys() []ConnId { type client_ctl_log_writer struct { cli *Client + depth int } func (hlw *client_ctl_log_writer) Write(p []byte) (n int, err error) { // the standard http.Server always requires *log.Logger // use this iowriter to create a logger to pass it to the http server. - hlw.cli.log.Write("", LOG_INFO, string(p)) + hlw.cli.log.WriteWithCallDepth("", LOG_INFO, hlw.depth, string(p)) return len(p), nil } @@ -1696,13 +2092,13 @@ func (c *Client) WrapHttpHandler(handler ClientHttpHandler) http.Handler { } // TODO: statistics by status_code and end point types. - time_taken = time.Now().Sub(start_time) + time_taken = time.Since(start_time) //time.Now().Sub(start_time) if status_code > 0 { if err != nil { - c.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s %d %.9f - %s", req.RemoteAddr, req.Method, get_raw_url_path(req), status_code, time_taken.Seconds(), err.Error()) + c.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s %d %.9f - %s", req.RemoteAddr, req.Method, req.RequestURI, status_code, time_taken.Seconds(), err.Error()) } else { - c.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s %d %.9f", req.RemoteAddr, req.Method, get_raw_url_path(req), status_code, time_taken.Seconds()) + c.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s %d %.9f", req.RemoteAddr, req.Method, req.RequestURI, status_code, time_taken.Seconds()) } } }) @@ -1714,7 +2110,7 @@ func (c *Client) SafeWrapWebsocketHandler(handler websocket.Handler) http.Handle !strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade") { var status_code int status_code = WriteEmptyRespHeader(w, http.StatusBadRequest) - c.log.Write("", LOG_INFO, "[%s] %s %s %d[non-websocket]", req.RemoteAddr, req.Method, get_raw_url_path(req), status_code) + c.log.Write("", LOG_INFO, "[%s] %s %s %d[non-websocket]", req.RemoteAddr, req.Method, req.RequestURI, status_code) return } handler.ServeHTTP(w, req) @@ -1728,21 +2124,19 @@ func (c *Client) WrapWebsocketHandler(handler ClientWebsocketHandler) websocket. var start_time time.Time var time_taken time.Duration var req *http.Request - var raw_url_path string req = ws.Request() - raw_url_path = get_raw_url_path(req) - c.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s [ws]", req.RemoteAddr, req.Method, raw_url_path) + c.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s [ws]", req.RemoteAddr, req.Method, req.RequestURI) start_time = time.Now() status_code, err = handler.ServeWebsocket(ws) - time_taken = time.Now().Sub(start_time) + time_taken = time.Since(start_time) // time.Now().Sub(start_time) if status_code > 0 { if err != nil { - c.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s [ws] %d %.9f - %s", req.RemoteAddr, req.Method, raw_url_path, status_code, time_taken.Seconds(), err.Error()) + c.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s [ws] %d %.9f - %s", req.RemoteAddr, req.Method, req.RequestURI, status_code, time_taken.Seconds(), err.Error()) } else { - c.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s [ws] %d %.9f", req.RemoteAddr, req.Method, raw_url_path, status_code, time_taken.Seconds()) + c.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s [ws] %d %.9f", req.RemoteAddr, req.Method, req.RequestURI, status_code, time_taken.Seconds()) } } }) @@ -1770,6 +2164,14 @@ func NewClient(ctx context.Context, name string, logger Logger, cfg *ClientConfi c.bulletin = NewBulletin[*ClientEvent](&c, 1024) c.rpc_tls = cfg.RpcTls + c.rpx_target_addr = cfg.RpxTargetAddr + c.rpx_target_tls = cfg.RpxTargetTls + if c.rpx_target_tls != nil { + c.rpx_target_url = "https://" + c.rpx_target_addr + } else { + c.rpx_target_url = "http://" + c.rpx_target_addr + } + c.ctl_auth = cfg.CtlAuth c.ctl_tls = cfg.CtlTls c.ctl_prefix = cfg.CtlPrefix @@ -1843,13 +2245,13 @@ func NewClient(ctx context.Context, name string, logger Logger, cfg *ClientConfi c.ctl = make([]*http.Server, len(cfg.CtlAddrs)) copy(c.ctl_addr, cfg.CtlAddrs) - hs_log = log.New(&client_ctl_log_writer{cli: &c}, "", 0) + hs_log = log.New(&client_ctl_log_writer{cli: &c, depth: 0}, "", 0) for i = 0; i < len(cfg.CtlAddrs); i++ { c.ctl[i] = &http.Server{ Addr: cfg.CtlAddrs[i], Handler: c.ctl_mux, - TLSConfig: c.ctl_tls, + TLSConfig: c.ctl_tls.Clone(), ErrorLog: hs_log, // TODO: more settings } @@ -1859,6 +2261,7 @@ func NewClient(ctx context.Context, name string, logger Logger, cfg *ClientConfi c.stats.routes.Store(0) c.stats.peers.Store(0) c.stats.pty_sessions.Store(0) + c.stats.rpty_sessions.Store(0) return &c } @@ -2171,8 +2574,52 @@ func (c *Client) GetPtyShell() string { return c.pty_shell } -func (c *Client) RunCtlTask(wg *sync.WaitGroup) { +func (c *Client) run_single_ctl_server(i int, cs *http.Server, wg *sync.WaitGroup) { + var l net.Listener var err error + + defer wg.Done() + + c.log.Write("", LOG_INFO, "Control channel[%d] started on %s", i, c.ctl_addr[i]) + + // defeat hard-coded "tcp" in ListenAndServe() and ListenAndServeTLS() + // by creating the listener explicitly. + // err = cs.ListenAndServe() + // err = cs.ListenAndServeTLS("", "") // c.tlscfg must provide a certificate and a key + + //cs.shuttingDown(), as the name indicates, is not expoosed by the net/http + //so I have to use my own indicator to check if it's been shutdown.. + // + if c.stop_req.Load() == false { + // this guard has a flaw in that the stop request can be made + // between the check above and net.Listen() below. + l, err = net.Listen(TcpAddrStrClass(cs.Addr), cs.Addr) + if err == nil { + if c.stop_req.Load() == false { + // check it again to make the guard slightly more stable + // although it's still possible that the stop request is made + // after Listen() + if c.ctl_tls == nil { + err = cs.Serve(l) + } else { + err = cs.ServeTLS(l, "", "") // c.ctl_tls must provide a certificate and a key + } + } else { + err = fmt.Errorf("stop requested") + } + l.Close() + } + } else { + err = fmt.Errorf("stop requested") + } + if errors.Is(err, http.ErrServerClosed) { + c.log.Write("", LOG_INFO, "Control channel[%d] ended", i) + } else { + c.log.Write("", LOG_ERROR, "Control channel[%d] error - %s", i, err.Error()) + } +} + +func (c *Client) RunCtlTask(wg *sync.WaitGroup) { var ctl *http.Server var idx int var l_wg sync.WaitGroup @@ -2181,49 +2628,7 @@ func (c *Client) RunCtlTask(wg *sync.WaitGroup) { for idx, ctl = range c.ctl { l_wg.Add(1) - go func(i int, cs *http.Server) { - var l net.Listener - - c.log.Write("", LOG_INFO, "Control channel[%d] started on %s", i, c.ctl_addr[i]) - - // defeat hard-coded "tcp" in ListenAndServe() and ListenAndServeTLS() - // by creating the listener explicitly. - // err = cs.ListenAndServe() - // err = cs.ListenAndServeTLS("", "") // c.tlscfg must provide a certificate and a key - - //cs.shuttingDown(), as the name indicates, is not expoosed by the net/http - //so I have to use my own indicator to check if it's been shutdown.. - // - if c.stop_req.Load() == false { - // this guard has a flaw in that the stop request can be made - // between the check above and net.Listen() below. - l, err = net.Listen(TcpAddrStrClass(cs.Addr), cs.Addr) - if err == nil { - if c.stop_req.Load() == false { - // check it again to make the guard slightly more stable - // although it's still possible that the stop request is made - // after Listen() - if c.ctl_tls == nil { - err = cs.Serve(l) - } else { - err = cs.ServeTLS(l, "", "") // c.ctl_tls must provide a certificate and a key - } - } else { - err = fmt.Errorf("stop requested") - } - l.Close() - } - } else { - err = fmt.Errorf("stop requested") - } - if errors.Is(err, http.ErrServerClosed) { - c.log.Write("", LOG_INFO, "Control channel[%d] ended", i) - } else { - c.log.Write("", LOG_ERROR, "Control channel[%d] error - %s", i, err.Error()) - } - - l_wg.Done() - }(idx, ctl) + go c.run_single_ctl_server(idx, ctl, &l_wg) } l_wg.Wait() } diff --git a/cmd/config.go b/cmd/config.go index 90dab9b..f910396 100644 --- a/cmd/config.go +++ b/cmd/config.go @@ -73,6 +73,12 @@ type RPXServiceConfig struct { Addrs []string `yaml:"addresses"` } +type RPXClientTokenConfig struct { + AttrName string `yaml:"attr-name"` + Regex string `yaml:"regex"` + SubmatchIndex int `yaml:"submatch-index"` +} + type PXYServiceConfig struct { Addrs []string `yaml:"addresses"` } @@ -127,8 +133,9 @@ type ServerConfig struct { } `yaml:"ctl"` RPX struct { - Service RPXServiceConfig `yaml:"service"` - TLS ServerTLSConfig `yaml:"tls"` + Service RPXServiceConfig `yaml:"service"` + TLS ServerTLSConfig `yaml:"tls"` + ClientToken RPXClientTokenConfig `yaml:"client-token"` } `yaml:"rpx"` PXY struct { @@ -158,6 +165,12 @@ type ClientConfig struct { Endpoint RPCEndpointConfig `yaml:"endpoint"` TLS ClientTLSConfig `yaml:"tls"` } `yaml:"rpc"` + RPX struct { + Target struct { + Addr string `yaml:"address"` + TLS ClientTLSConfig `yaml:"tls"` + } `yaml:"target"` + } } func load_server_config_to(cfgfile string, cfg *ServerConfig) error { diff --git a/cmd/logger.go b/cmd/logger.go index 5813e24..826addc 100644 --- a/cmd/logger.go +++ b/cmd/logger.go @@ -121,7 +121,6 @@ main_loop: } } - func (l *AppLogger) Write(id string, level hodu.LogLevel, fmtstr string, args ...interface{}) { if l.mask & hodu.LogMask(level) == 0 { return } l.write(id, level, 1, fmtstr, args...) diff --git a/cmd/main.go b/cmd/main.go index 9459ff4..7354d4c 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -10,6 +10,7 @@ import "net" import "os" import "os/signal" import "path/filepath" +import "regexp" import "strings" import "sync" import "syscall" @@ -131,6 +132,13 @@ func server_main(ctl_addrs []string, rpc_addrs []string, rpx_addrs[] string, pxy if len(config.PxyAddrs) <= 0 { config.PxyAddrs = cfg.PXY.Service.Addrs } if len(config.WpxAddrs) <= 0 { config.WpxAddrs = cfg.WPX.Service.Addrs } + config.RpxClientTokenAttrName = cfg.RPX.ClientToken.AttrName + if cfg.RPX.ClientToken.Regex != "" { + config.RpxClientTokenRegex, err = regexp.Compile(cfg.RPX.ClientToken.Regex) + if err != nil { return err } + } + config.RpxClientTokenSubmatchIndex = cfg.RPX.ClientToken.SubmatchIndex + config.CtlCors = cfg.CTL.Service.Cors config.CtlAuth, err = make_http_auth_config(&cfg.CTL.Service.Auth) if err != nil { return err } @@ -282,10 +290,13 @@ func client_main(ctl_addrs []string, rpc_addrs []string, route_configs []string, if err != nil { return err } config.RpcTls, err = make_tls_client_config(&cfg.RPC.TLS) if err != nil { return err } + config.RpxTargetTls, err = make_tls_client_config(&cfg.RPX.Target.TLS) + if err != nil { return err } if len(rpc_addrs) <= 0 { rpc_addrs = cfg.RPC.Endpoint.Addrs } if len(config.CtlAddrs) <= 0 { config.CtlAddrs = cfg.CTL.Service.Addrs } + config.RpxTargetAddr = cfg.RPX.Target.Addr config.CtlPrefix = cfg.CTL.Service.Prefix config.CtlCors = cfg.CTL.Service.Cors config.CtlAuth, err = make_http_auth_config(&cfg.CTL.Service.Auth) diff --git a/hodu.go b/hodu.go index c6783b1..12c2c95 100644 --- a/hodu.go +++ b/hodu.go @@ -1,5 +1,6 @@ package hodu +import "bytes" import "crypto/rsa" import _ "embed" import "encoding/base64" @@ -8,6 +9,7 @@ import "net" import "net/http" import "net/netip" import "os" +import "regexp" import "runtime" import "strings" import "sync" @@ -79,11 +81,6 @@ type JsonErrmsg struct { Text string `json:"error-text"` } -type json_in_cred struct { - Username string `json:"username"` - Password string `json:"password"` -} - type json_in_notice struct { Text string `json:"text"` } @@ -224,8 +221,9 @@ func (option RouteOption) String() string { func dump_call_frame_and_exit(log Logger, req *http.Request, err interface{}) { var buf []byte - buf = make([]byte, 65536); buf = buf[:min(65536, runtime.Stack(buf, false))] - log.Write("", LOG_ERROR, "[%s] %s %s - %v\n%s", req.RemoteAddr, req.Method, get_raw_url_path(req), err, string(buf)) + buf = make([]byte, 65536) + buf = buf[:min(65536, runtime.Stack(buf, false))] + log.Write("", LOG_ERROR, "[%s] %s %s - %v\n%s", req.RemoteAddr, req.Method, req.RequestURI, err, string(buf)) log.Close() os.Exit(99) // fatal error. treat panic() as a fatal runtime error } @@ -467,12 +465,75 @@ func (auth *HttpAuthConfig) Authenticate(req *http.Request) (int, string) { return http.StatusOK, "" } - // ------------------------------------ -func get_raw_url_path(req *http.Request) string { - var path string - path = req.URL.Path - if req.URL.RawQuery != "" { path += "?" + req.URL.RawQuery } - return path +func get_http_req_line_and_headers(r *http.Request, force_host bool) []byte { + var buf bytes.Buffer + var name string + var value string + var values []string + var host_found bool + + fmt.Fprintf(&buf, "%s %s %s\r\n", r.Method, r.RequestURI, r.Proto) + + for name, values = range r.Header { + if strings.EqualFold(name, "Accept-Encoding") { // TODO: make it generic. parameterize it?? + // skip Accept-Encoding as the go client side + // doesn't function properly when a certain enconding + // is specified. resp.Body.Read() returned EOF when + // not working + continue + } else if strings.EqualFold(name, "Host") { + host_found = true + } + for _, value = range values { + fmt.Fprintf(&buf, "%s: %s\r\n", name, value) + } + } + + if force_host && !host_found && r.Host != "" { + fmt.Fprintf(&buf, "Host: %s\r\n", r.Host) + } +// TODO: host and x-forwarded-for, x-forwarded-proto, etc??? + + buf.WriteString("\r\n") // End of headers + return buf.Bytes() } + +func get_http_resp_line_and_headers(r *http.Response) []byte { + var buf bytes.Buffer + var name string + var value string + var values []string + + fmt.Fprintf(&buf, "%s %s\r\n", r.Proto, r.Status) + + for name, values = range r.Header { + for _, value = range values { + fmt.Fprintf(&buf, "%s: %s\r\n", name, value) + } + } + + buf.WriteString("\r\n") // End of headers + return buf.Bytes() +} + +func get_regex_submatch(re *regexp.Regexp, str string, n int) string { + var idxs []int + var pos int + var start int + var end int + + idxs = re.FindStringSubmatchIndex(str) + if idxs == nil { return "" } + + pos = n * 2 + if pos + 1 >= len(idxs) { return "" } + + start, end = idxs[pos], idxs[pos + 1] + if start == -1 || end == -1 { + return "" + } + + return str[start:end] +} \ No newline at end of file diff --git a/hodu.pb.go b/hodu.pb.go index 9d6bae3..c75c216 100644 --- a/hodu.pb.go +++ b/hodu.pb.go @@ -101,7 +101,11 @@ const ( PACKET_KIND_RPTY_START PACKET_KIND = 14 PACKET_KIND_RPTY_STOP PACKET_KIND = 15 PACKET_KIND_RPTY_DATA PACKET_KIND = 16 - PACKET_KIND_RPTY_SIZE PACKET_KIND = 17 + PACKET_KIND_RPTY_SIZE PACKET_KIND = 17 // terminal size + PACKET_KIND_RPX_START PACKET_KIND = 18 + PACKET_KIND_RPX_STOP PACKET_KIND = 19 + PACKET_KIND_RPX_DATA PACKET_KIND = 20 + PACKET_KIND_RPX_EOF PACKET_KIND = 21 ) // Enum value maps for PACKET_KIND. @@ -124,6 +128,10 @@ var ( 15: "RPTY_STOP", 16: "RPTY_DATA", 17: "RPTY_SIZE", + 18: "RPX_START", + 19: "RPX_STOP", + 20: "RPX_DATA", + 21: "RPX_EOF", } PACKET_KIND_value = map[string]int32{ "RESERVED": 0, @@ -143,6 +151,10 @@ var ( "RPTY_STOP": 15, "RPTY_DATA": 16, "RPTY_SIZE": 17, + "RPX_START": 18, + "RPX_STOP": 19, + "RPX_DATA": 20, + "RPX_EOF": 21, } ) @@ -643,6 +655,58 @@ func (x *RptyEvent) GetData() []byte { return nil } +type RpxEvent struct { + state protoimpl.MessageState `protogen:"open.v1"` + Id uint64 `protobuf:"varint,1,opt,name=Id,proto3" json:"Id,omitempty"` + Data []byte `protobuf:"bytes,2,opt,name=Data,proto3" json:"Data,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *RpxEvent) Reset() { + *x = RpxEvent{} + mi := &file_hodu_proto_msgTypes[8] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *RpxEvent) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*RpxEvent) ProtoMessage() {} + +func (x *RpxEvent) ProtoReflect() protoreflect.Message { + mi := &file_hodu_proto_msgTypes[8] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use RpxEvent.ProtoReflect.Descriptor instead. +func (*RpxEvent) Descriptor() ([]byte, []int) { + return file_hodu_proto_rawDescGZIP(), []int{8} +} + +func (x *RpxEvent) GetId() uint64 { + if x != nil { + return x.Id + } + return 0 +} + +func (x *RpxEvent) GetData() []byte { + if x != nil { + return x.Data + } + return nil +} + type Packet struct { state protoimpl.MessageState `protogen:"open.v1"` Kind PACKET_KIND `protobuf:"varint,1,opt,name=Kind,proto3,enum=PACKET_KIND" json:"Kind,omitempty"` @@ -655,6 +719,7 @@ type Packet struct { // *Packet_ConnErr // *Packet_ConnNoti // *Packet_RptyEvt + // *Packet_RpxEvt U isPacket_U `protobuf_oneof:"U"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache @@ -662,7 +727,7 @@ type Packet struct { func (x *Packet) Reset() { *x = Packet{} - mi := &file_hodu_proto_msgTypes[8] + mi := &file_hodu_proto_msgTypes[9] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -674,7 +739,7 @@ func (x *Packet) String() string { func (*Packet) ProtoMessage() {} func (x *Packet) ProtoReflect() protoreflect.Message { - mi := &file_hodu_proto_msgTypes[8] + mi := &file_hodu_proto_msgTypes[9] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -687,7 +752,7 @@ func (x *Packet) ProtoReflect() protoreflect.Message { // Deprecated: Use Packet.ProtoReflect.Descriptor instead. func (*Packet) Descriptor() ([]byte, []int) { - return file_hodu_proto_rawDescGZIP(), []int{8} + return file_hodu_proto_rawDescGZIP(), []int{9} } func (x *Packet) GetKind() PACKET_KIND { @@ -767,6 +832,15 @@ func (x *Packet) GetRptyEvt() *RptyEvent { return nil } +func (x *Packet) GetRpxEvt() *RpxEvent { + if x != nil { + if x, ok := x.U.(*Packet_RpxEvt); ok { + return x.RpxEvt + } + } + return nil +} + type isPacket_U interface { isPacket_U() } @@ -799,6 +873,10 @@ type Packet_RptyEvt struct { RptyEvt *RptyEvent `protobuf:"bytes,8,opt,name=RptyEvt,proto3,oneof"` } +type Packet_RpxEvt struct { + RpxEvt *RpxEvent `protobuf:"bytes,9,opt,name=RpxEvt,proto3,oneof"` +} + func (*Packet_Route) isPacket_U() {} func (*Packet_Peer) isPacket_U() {} @@ -813,6 +891,8 @@ func (*Packet_ConnNoti) isPacket_U() {} func (*Packet_RptyEvt) isPacket_U() {} +func (*Packet_RpxEvt) isPacket_U() {} + var File_hodu_proto protoreflect.FileDescriptor const file_hodu_proto_rawDesc = "" + @@ -850,7 +930,10 @@ const file_hodu_proto_rawDesc = "" + "\x04Text\x18\x01 \x01(\tR\x04Text\"/\n" + "\tRptyEvent\x12\x0e\n" + "\x02Id\x18\x01 \x01(\x04R\x02Id\x12\x12\n" + - "\x04Data\x18\x02 \x01(\fR\x04Data\"\xb1\x02\n" + + "\x04Data\x18\x02 \x01(\fR\x04Data\".\n" + + "\bRpxEvent\x12\x0e\n" + + "\x02Id\x18\x01 \x01(\x04R\x02Id\x12\x12\n" + + "\x04Data\x18\x02 \x01(\fR\x04Data\"\xd6\x02\n" + "\x06Packet\x12 \n" + "\x04Kind\x18\x01 \x01(\x0e2\f.PACKET_KINDR\x04Kind\x12\"\n" + "\x05Route\x18\x02 \x01(\v2\n" + @@ -862,7 +945,8 @@ const file_hodu_proto_rawDesc = "" + ".ConnErrorH\x00R\aConnErr\x12)\n" + "\bConnNoti\x18\a \x01(\v2\v.ConnNoticeH\x00R\bConnNoti\x12&\n" + "\aRptyEvt\x18\b \x01(\v2\n" + - ".RptyEventH\x00R\aRptyEvtB\x03\n" + + ".RptyEventH\x00R\aRptyEvt\x12#\n" + + "\x06RpxEvt\x18\t \x01(\v2\t.RpxEventH\x00R\x06RpxEvtB\x03\n" + "\x01U*U\n" + "\fROUTE_OPTION\x12\n" + "\n" + @@ -872,7 +956,7 @@ const file_hodu_proto_rawDesc = "" + "\x04TCP6\x10\x04\x12\b\n" + "\x04HTTP\x10\b\x12\t\n" + "\x05HTTPS\x10\x10\x12\a\n" + - "\x03SSH\x10 *\xa2\x02\n" + + "\x03SSH\x10 *\xda\x02\n" + "\vPACKET_KIND\x12\f\n" + "\bRESERVED\x10\x00\x12\x0f\n" + "\vROUTE_START\x10\x01\x12\x0e\n" + @@ -893,7 +977,11 @@ const file_hodu_proto_rawDesc = "" + "RPTY_START\x10\x0e\x12\r\n" + "\tRPTY_STOP\x10\x0f\x12\r\n" + "\tRPTY_DATA\x10\x10\x12\r\n" + - "\tRPTY_SIZE\x10\x112I\n" + + "\tRPTY_SIZE\x10\x11\x12\r\n" + + "\tRPX_START\x10\x12\x12\f\n" + + "\bRPX_STOP\x10\x13\x12\f\n" + + "\bRPX_DATA\x10\x14\x12\v\n" + + "\aRPX_EOF\x10\x152I\n" + "\x04Hodu\x12\x19\n" + "\aGetSeed\x12\x05.Seed\x1a\x05.Seed\"\x00\x12&\n" + "\fPacketStream\x12\a.Packet\x1a\a.Packet\"\x00(\x010\x01B\bZ\x06./hodub\x06proto3" @@ -911,7 +999,7 @@ func file_hodu_proto_rawDescGZIP() []byte { } var file_hodu_proto_enumTypes = make([]protoimpl.EnumInfo, 2) -var file_hodu_proto_msgTypes = make([]protoimpl.MessageInfo, 9) +var file_hodu_proto_msgTypes = make([]protoimpl.MessageInfo, 10) var file_hodu_proto_goTypes = []any{ (ROUTE_OPTION)(0), // 0: ROUTE_OPTION (PACKET_KIND)(0), // 1: PACKET_KIND @@ -923,7 +1011,8 @@ var file_hodu_proto_goTypes = []any{ (*ConnError)(nil), // 7: ConnError (*ConnNotice)(nil), // 8: ConnNotice (*RptyEvent)(nil), // 9: RptyEvent - (*Packet)(nil), // 10: Packet + (*RpxEvent)(nil), // 10: RpxEvent + (*Packet)(nil), // 11: Packet } var file_hodu_proto_depIdxs = []int32{ 1, // 0: Packet.Kind:type_name -> PACKET_KIND @@ -934,15 +1023,16 @@ var file_hodu_proto_depIdxs = []int32{ 7, // 5: Packet.ConnErr:type_name -> ConnError 8, // 6: Packet.ConnNoti:type_name -> ConnNotice 9, // 7: Packet.RptyEvt:type_name -> RptyEvent - 2, // 8: Hodu.GetSeed:input_type -> Seed - 10, // 9: Hodu.PacketStream:input_type -> Packet - 2, // 10: Hodu.GetSeed:output_type -> Seed - 10, // 11: Hodu.PacketStream:output_type -> Packet - 10, // [10:12] is the sub-list for method output_type - 8, // [8:10] is the sub-list for method input_type - 8, // [8:8] is the sub-list for extension type_name - 8, // [8:8] is the sub-list for extension extendee - 0, // [0:8] is the sub-list for field type_name + 10, // 8: Packet.RpxEvt:type_name -> RpxEvent + 2, // 9: Hodu.GetSeed:input_type -> Seed + 11, // 10: Hodu.PacketStream:input_type -> Packet + 2, // 11: Hodu.GetSeed:output_type -> Seed + 11, // 12: Hodu.PacketStream:output_type -> Packet + 11, // [11:13] is the sub-list for method output_type + 9, // [9:11] is the sub-list for method input_type + 9, // [9:9] is the sub-list for extension type_name + 9, // [9:9] is the sub-list for extension extendee + 0, // [0:9] is the sub-list for field type_name } func init() { file_hodu_proto_init() } @@ -950,7 +1040,7 @@ func file_hodu_proto_init() { if File_hodu_proto != nil { return } - file_hodu_proto_msgTypes[8].OneofWrappers = []any{ + file_hodu_proto_msgTypes[9].OneofWrappers = []any{ (*Packet_Route)(nil), (*Packet_Peer)(nil), (*Packet_Data)(nil), @@ -958,6 +1048,7 @@ func file_hodu_proto_init() { (*Packet_ConnErr)(nil), (*Packet_ConnNoti)(nil), (*Packet_RptyEvt)(nil), + (*Packet_RpxEvt)(nil), } type x struct{} out := protoimpl.TypeBuilder{ @@ -965,7 +1056,7 @@ func file_hodu_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_hodu_proto_rawDesc), len(file_hodu_proto_rawDesc)), NumEnums: 2, - NumMessages: 9, + NumMessages: 10, NumExtensions: 0, NumServices: 1, }, diff --git a/hodu.proto b/hodu.proto index 29f644c..7a744fc 100644 --- a/hodu.proto +++ b/hodu.proto @@ -85,6 +85,11 @@ message RptyEvent { bytes Data = 2; }; +message RpxEvent { + uint64 Id = 1; + bytes Data = 2; +}; + enum PACKET_KIND { RESERVED = 0; // not used ROUTE_START = 1; @@ -103,7 +108,12 @@ enum PACKET_KIND { RPTY_START = 14; RPTY_STOP = 15; RPTY_DATA = 16; - RPTY_SIZE = 17; + RPTY_SIZE = 17; // terminal size + + RPX_START = 18; + RPX_STOP = 19; + RPX_DATA = 20; + RPX_EOF = 21; }; message Packet { @@ -117,5 +127,6 @@ message Packet { ConnError ConnErr = 6; ConnNotice ConnNoti = 7; RptyEvent RptyEvt = 8; + RpxEvent RpxEvt = 9; }; } diff --git a/packet.go b/packet.go index 85c41a3..de8f305 100644 --- a/packet.go +++ b/packet.go @@ -80,6 +80,7 @@ func MakeRptyStartPacket(id uint64) *Packet { } func MakeRptyStopPacket(id uint64, msg string) *Packet { + // the rpty stop conveys an error/info message return &Packet{Kind: PACKET_KIND_RPTY_STOP, U: &Packet_RptyEvt{RptyEvt: &RptyEvent{Id: id, Data: []byte(msg)}}} } @@ -90,3 +91,21 @@ func MakeRptyDataPacket(id uint64, data []byte) *Packet { func MakeRptySizePacket(id uint64, data []byte) *Packet { return &Packet{Kind: PACKET_KIND_RPTY_SIZE, U: &Packet_RptyEvt{RptyEvt: &RptyEvent{Id: id, Data: data}}} } + +func MakeRpxStartPacket(id uint64, hdr_part []byte) *Packet { + // the rpx start conveys the data unlike other Start packets... + return &Packet{Kind: PACKET_KIND_RPX_START, U: &Packet_RpxEvt{RpxEvt: &RpxEvent{Id: id, Data: hdr_part}}} +} + +func MakeRpxStopPacket(id uint64) *Packet { + // the rpx start conveys the data unlike other Start packets... + return &Packet{Kind: PACKET_KIND_RPX_STOP, U: &Packet_RpxEvt{RpxEvt: &RpxEvent{Id: id}}} +} + +func MakeRpxDataPacket(id uint64, data_part []byte) *Packet { + return &Packet{Kind: PACKET_KIND_RPX_DATA, U: &Packet_RpxEvt{RpxEvt: &RpxEvent{Id: id, Data: data_part}}} +} + +func MakeRpxEofPacket(id uint64) *Packet { + return &Packet{Kind: PACKET_KIND_RPX_EOF, U: &Packet_RpxEvt{RpxEvt: &RpxEvent{Id: id}}} +} diff --git a/server-ctl.go b/server-ctl.go index ed4e73b..1f0e0ef 100644 --- a/server-ctl.go +++ b/server-ctl.go @@ -76,7 +76,8 @@ type json_out_server_stats struct { ServerPeers int64 `json:"server-peers"` SshProxySessions int64 `json:"pxy-ssh-sessions"` - ServerPtySessions int64 `json:"server-pty-sessions"` + ServerPtySessions int64 `json:"server-pty-sessions"` + ServerRptySessions int64 `json:"server-rpty-sessions"` } // this is a more specialized variant of json_in_notice @@ -921,6 +922,7 @@ func (ctl *server_ctl_stats) ServeHTTP(w http.ResponseWriter, req *http.Request) stats.ServerPeers = s.stats.peers.Load() stats.SshProxySessions = s.stats.ssh_proxy_sessions.Load() stats.ServerPtySessions = s.stats.pty_sessions.Load() + stats.ServerRptySessions = s.stats.rpty_sessions.Load() status_code = WriteJsonRespHeader(w, http.StatusOK) if err = je.Encode(stats); err != nil { goto oops } diff --git a/server-metrics.go b/server-metrics.go index a1e0c3f..16d9485 100644 --- a/server-metrics.go +++ b/server-metrics.go @@ -12,6 +12,7 @@ type ServerCollector struct { ServerPeers *prometheus.Desc SshProxySessions *prometheus.Desc PtySessions *prometheus.Desc + RptySessions *prometheus.Desc } // NewServerCollector returns a new ServerCollector with all prometheus.Desc initialized @@ -58,6 +59,11 @@ func NewServerCollector(server *Server) ServerCollector { "Number of pty session", nil, nil, ), + RptySessions: prometheus.NewDesc( + prefix + "rpty_sessions", + "Number of rpty session", + nil, nil, + ), } } @@ -68,6 +74,7 @@ func (c ServerCollector) Describe(ch chan<- *prometheus.Desc) { ch <- c.ServerPeers ch <- c.SshProxySessions ch <- c.PtySessions + ch <- c.RptySessions } func (c ServerCollector) Collect(ch chan<- prometheus.Metric) { @@ -110,4 +117,10 @@ func (c ServerCollector) Collect(ch chan<- prometheus.Metric) { prometheus.GaugeValue, float64(c.server.stats.pty_sessions.Load()), ) + + ch <- prometheus.MustNewConstMetric( + c.RptySessions, + prometheus.GaugeValue, + float64(c.server.stats.rpty_sessions.Load()), + ) } diff --git a/server-peer.go b/server-peer.go index cf38ffc..4f016ff 100644 --- a/server-peer.go +++ b/server-peer.go @@ -102,6 +102,16 @@ wait_for_started: for { n, err = spc.conn.Read(buf[:]) + if n > 0 { + var err2 error + err2 = pss.Send(MakePeerDataPacket(spc.route.Id, spc.conn_id, buf[:n])) + if err2 != nil { + spc.route.Cts.S.log.Write(spc.route.Cts.Sid, LOG_ERROR, + "Failed to send data from peer(%d,%d,%s,%s) to client - %s", + spc.route.Id, spc.conn_id, conn_raddr, conn_laddr, err2.Error()) + goto done + } + } if err != nil { if errors.Is(err, io.EOF) || strings.Contains(err.Error(), "use of closed network connection") { // i don't like this way to check this error. err = pss.Send(MakePeerEofPacket(spc.route.Id, spc.conn_id)) @@ -119,14 +129,6 @@ wait_for_started: goto done } } - - err = pss.Send(MakePeerDataPacket(spc.route.Id, spc.conn_id, buf[:n])) - if err != nil { - spc.route.Cts.S.log.Write(spc.route.Cts.Sid, LOG_ERROR, - "Failed to send data from peer(%d,%d,%s,%s) to client - %s", - spc.route.Id, spc.conn_id, conn_raddr, conn_laddr, err.Error()) - goto done - } } wait_for_stopped: diff --git a/server-pty.go b/server-pty.go index 9542161..dbe46dd 100644 --- a/server-pty.go +++ b/server-pty.go @@ -67,7 +67,7 @@ func (pty *server_pty_ws) ServeWebsocket(ws *websocket.Conn) (int, error) { conn_ready = <-conn_ready_chan if conn_ready { // connected var poll_fds []unix.PollFd - var buf []byte + var buf [2048]byte var n int var err error @@ -76,7 +76,6 @@ func (pty *server_pty_ws) ServeWebsocket(ws *websocket.Conn) (int, error) { } s.stats.pty_sessions.Add(1) - buf = make([]byte, 2048) for { n, err = unix.Poll(poll_fds, -1) // -1 means wait indefinitely if err != nil { @@ -94,20 +93,21 @@ func (pty *server_pty_ws) ServeWebsocket(ws *websocket.Conn) (int, error) { } if (poll_fds[0].Revents & unix.POLLIN) != 0 { - n, err = out.Read(buf) + n, err = out.Read(buf[:]) + if n > 0 { + var err2 error + err2 = send_ws_data_for_xterm(ws, "iov", string(buf[:n])) + if err2 != nil { + s.log.Write(pty.Id, LOG_ERROR, "[%s] Failed to send to websocket - %s", req.RemoteAddr, err2.Error()) + break + } + } if err != nil { if !errors.Is(err, io.EOF) { s.log.Write(pty.Id, LOG_ERROR, "[%s] Failed to read pty stdout - %s", req.RemoteAddr, err.Error()) } break } - if n > 0 { - err = send_ws_data_for_xterm(ws, "iov", string(buf[:n])) - if err != nil { - s.log.Write(pty.Id, LOG_ERROR, "[%s] Failed to send to websocket - %s", req.RemoteAddr, err.Error()) - break - } - } } } s.stats.pty_sessions.Add(-1) diff --git a/server-pxy.go b/server-pxy.go index ccb0293..a08fb1b 100644 --- a/server-pxy.go +++ b/server-pxy.go @@ -375,7 +375,7 @@ func (pxy *server_pxy_http_main) ServeHTTP(w http.ResponseWriter, req *http.Requ } proxy_url = pxy.req_to_proxy_url(req, pi) - s.log.Write(pxy.Id, LOG_INFO, "[%s] %s %s -> %+v", req.RemoteAddr, req.Method, get_raw_url_path(req), proxy_url) + s.log.Write(pxy.Id, LOG_INFO, "[%s] %s %s -> %+v", req.RemoteAddr, req.Method, req.RequestURI, proxy_url) proxy_req, err = http.NewRequestWithContext(s.Ctx, req.Method, proxy_url.String(), req.Body) if err != nil { @@ -401,7 +401,7 @@ func (pxy *server_pxy_http_main) ServeHTTP(w http.ResponseWriter, req *http.Requ } else { status_code = resp.StatusCode if upgrade_required && resp.StatusCode == http.StatusSwitchingProtocols { - s.log.Write(pxy.Id, LOG_INFO, "[%s] %s %s %d", req.RemoteAddr, req.Method, get_raw_url_path(req), status_code) + s.log.Write(pxy.Id, LOG_INFO, "[%s] %s %s %d", req.RemoteAddr, req.Method, req.RequestURI, status_code) err = pxy.serve_upgraded(w, req, resp) if err != nil { goto oops } return 0, nil// print the log mesage before calling serve_upgraded() and exit here @@ -426,7 +426,7 @@ func (pxy *server_pxy_http_main) ServeHTTP(w http.ResponseWriter, req *http.Requ _, err = io.Copy(w, resp_body) if err != nil { - s.log.Write(pxy.Id, LOG_WARN, "[%s] %s %s %s", req.RemoteAddr, req.Method, get_raw_url_path(req), err.Error()) + s.log.Write(pxy.Id, LOG_WARN, "[%s] %s %s %s", req.RemoteAddr, req.Method, req.RequestURI, err.Error()) } // TODO: handle trailers @@ -689,27 +689,27 @@ func (pxy *server_pxy_ssh_ws) ServeWebsocket(ws *websocket.Conn) (int, error) { conn_ready = <-conn_ready_chan if conn_ready { // connected - var buf []byte + var buf [2048]byte var n int var err error s.stats.ssh_proxy_sessions.Add(1) - buf = make([]byte, 2048) for { - n, err = out.Read(buf) + n, err = out.Read(buf[:]) + if n > 0 { + var err2 error + err2 = send_ws_data_for_xterm(ws, "iov", string(buf[:n])) + if err2 != nil { + s.log.Write(pxy.Id, LOG_ERROR, "[%s] Failed to send to websocket - %s", req.RemoteAddr, err2.Error()) + break + } + } if err != nil { if !errors.Is(err, io.EOF) { s.log.Write(pxy.Id, LOG_ERROR, "[%s] Failed to read from SSH stdout - %s", req.RemoteAddr, err.Error()) } break } - if n > 0 { - err = send_ws_data_for_xterm(ws, "iov", string(buf[:n])) - if err != nil { - s.log.Write(pxy.Id, LOG_ERROR, "[%s] Failed to send to websocket - %s", req.RemoteAddr, err.Error()) - break - } - } } s.stats.ssh_proxy_sessions.Add(-1) } diff --git a/server-rpx.go b/server-rpx.go index 39d054e..2530f5d 100644 --- a/server-rpx.go +++ b/server-rpx.go @@ -1,6 +1,15 @@ package hodu +import "bufio" +import "bytes" +import "errors" +import "fmt" +import "io" +import "net" import "net/http" +import "strconv" +import "strings" +import "sync" type server_rpx struct { S *Server @@ -8,28 +17,313 @@ type server_rpx struct { } // ------------------------------------ -func (pxy *server_rpx) Identity() string { - return pxy.Id +func (rpx *server_rpx) Identity() string { + return rpx.Id } -func (pxy *server_rpx) Cors(req *http.Request) bool { +func (rpx *server_rpx) Cors(req *http.Request) bool { return false } -func (pxy *server_rpx) Authenticate(req *http.Request) (int, string) { +func (rpx *server_rpx) Authenticate(req *http.Request) (int, string) { return http.StatusOK, "" } -func (pxy *server_rpx) ServeHTTP(w http.ResponseWriter, req *http.Request) (int, error) { - var status_code int -// var err error +func (rpx *server_rpx) get_client_token(req *http.Request) string { + var val string - status_code = http.StatusOK + // TODO: enhance this client token extraction logic with some expression language? + val = req.Header.Get(rpx.S.Cfg.RpxClientTokenAttrName) + if val == "" { val = req.Host } -//done: - return status_code, nil + if rpx.S.Cfg.RpxClientTokenRegex != nil { + val = get_regex_submatch(rpx.S.Cfg.RpxClientTokenRegex, val, rpx.S.Cfg.RpxClientTokenSubmatchIndex) + } -//oops: -// return status_code, err + return val } +func (rpx* server_rpx) handle_header_data(rpx_id uint64, data []byte, w http.ResponseWriter) (int, error) { + var sc *bufio.Scanner + var line string + var flds []string + var status_code int + var err error + + sc = bufio.NewScanner(bytes.NewReader(data)) + sc.Scan() + line = sc.Text() + + flds = strings.Fields(line) + if (len(flds) < 2) { // i care about the status code.. + return http.StatusBadGateway, fmt.Errorf("invalid response status for rpx(%d) - %s", rpx_id, line) + } + status_code, err = strconv.Atoi(flds[1]) + if err != nil { + return http.StatusBadGateway, fmt.Errorf("invalid response code for rpx(%d) - %s", rpx_id, err.Error()) + } + + for sc.Scan() { + line = sc.Text() + if line == "" { break } + flds = strings.SplitN(line, ":", 2) + if len(flds) == 2 { + w.Header().Add(strings.TrimSpace(flds[0]), strings.TrimSpace(flds[1])) + } + } + err = sc.Err() + if err != nil { + return http.StatusBadGateway, fmt.Errorf("failed to parse response for rpx(%d) - %s", rpx_id, err.Error()) + } + + w.WriteHeader(status_code) + return status_code, nil +} + +func (rpx *server_rpx) handle_response(srpx *ServerRpx, req *http.Request, w http.ResponseWriter, ws_upgrade bool, wg *sync.WaitGroup) { + var start_resp []byte + var status_code int + var buf [4096]byte + var n int + var wr io.Writer + var wrote_br_chan bool + var err error + + defer wg.Done() + + select { + case start_resp = <- srpx.start_chan: + // received the header. ready to proceed to the body + // do nothing. just continue + status_code, err = rpx.handle_header_data(srpx.id, start_resp, w) + if err != nil { goto done } + + case <- srpx.done_chan: + err = fmt.Errorf("rpx(%d) terminated before receiving header", srpx.id) + status_code = http.StatusBadGateway + goto done + case <- req.Context().Done(): + err = fmt.Errorf("rpx(%d) terminated before receiving header - %s", srpx.id, req.Context().Err().Error()) + status_code = http.StatusBadGateway + goto done + + // no default. block + } + + if ws_upgrade && status_code == http.StatusSwitchingProtocols { + var hijk http.Hijacker + var conn net.Conn + var ok bool + + hijk, ok = w.(http.Hijacker) + if !ok { + err = fmt.Errorf("failed to upgrade rpx(%d) - not a hijacker", srpx.id) + status_code = http.StatusInternalServerError + goto done + } + + conn, _, err = hijk.Hijack() + if err != nil { + err = fmt.Errorf("failed to upgrade rpx(%d) - %s", srpx.id, err.Error()) + status_code = http.StatusInternalServerError + goto done + } + + // websocket upgrade is successful + srpx.br = conn + srpx.br_chan <- true // inform another goroutine that the protocol switching is completed. + wrote_br_chan = true + + wr = conn + } else { + if ws_upgrade { + srpx.br_chan <- false + wrote_br_chan = true + } // indicate upgrade failure + wr = w + } + + for { + n, err = srpx.pr.Read(buf[:]) + if n > 0 { + var err2 error + _, err2 = wr.Write(buf[:n]) + if err2 != nil { + err = err2 + status_code = http.StatusInternalServerError + break + } + } + if err != nil { + if errors.Is(err, io.EOF) { + err = nil + } else { + status_code = http.StatusInternalServerError + } + break + } + } + +done: + // just send another in case the code got jump into this part for an error + // may not be consumed but the channel is large enough for redundant data + srpx.resp_status_code = status_code + srpx.resp_error = err + + if ws_upgrade && !wrote_br_chan { + srpx.br_chan <- false + } +} + +func (rpx *server_rpx) alloc_server_rpx(cts *ServerConn, req *http.Request) (*ServerRpx, error) { + var srpx *ServerRpx + var start_id uint64 + var assigned_id uint64 + var ok bool + + cts.rpx_mtx.Lock() + start_id = cts.rpx_next_id + for { + _, ok = cts.rpx_map[cts.rpx_next_id] + if !ok { + assigned_id = cts.rpx_next_id + cts.rpx_next_id++ + if cts.rpx_next_id == 0 { cts.rpx_next_id++ } + break + } + cts.rpx_next_id++ + if cts.rpx_next_id == 0 { cts.rpx_next_id++ } + if cts.rpx_next_id == start_id { + // unlikely to happen but it cycled through the whole range. + cts.rpx_mtx.Unlock() + return nil, fmt.Errorf("failed to assign id") + } + } + + srpx = &ServerRpx{ + id: assigned_id, + start_chan: make(chan []byte, 5), + done_chan: make(chan bool, 5), + br_chan: make(chan bool, 5), + } + srpx.br = req.Body + srpx.pr, srpx.pw = io.Pipe() + cts.rpx_map[assigned_id] = srpx + + cts.rpx_mtx.Unlock() + return srpx, nil +} + +func (rpx *server_rpx) ServeHTTP(w http.ResponseWriter, req *http.Request) (int, error) { + var s *Server + var client_token string + var start_sent bool + var cts *ServerConn + var status_code int + var srpx *ServerRpx + var ws_upgrade bool + var buf [4096]byte + var wg sync.WaitGroup + var err error + + s = rpx.S + client_token = rpx.get_client_token(req) + cts = s.FindServerConnByClientToken(client_token) + if cts == nil { + status_code = WriteEmptyRespHeader(w, http.StatusNotFound) + err = fmt.Errorf("unknown client token - %s", client_token) + goto oops + } + + srpx, err = rpx.alloc_server_rpx(cts, req) + if err != nil { + status_code = WriteEmptyRespHeader(w, http.StatusServiceUnavailable) + err = fmt.Errorf("unable to allocate rpx - %s", err.Error()) + goto oops + } + + // arrange to clear the rpx_map entry when this function exits + defer func() { + cts.rpx_mtx.Lock() + delete(cts.rpx_map, srpx.id) + cts.rpx_mtx.Unlock() + }() + + ws_upgrade = strings.EqualFold(req.Header.Get("Upgrade"), "websocket") && strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade"); + if ws_upgrade && req.ContentLength > 0 { + // while other webservers are ok with upgrade request with body payload, + // this program rejects such a request for impelementation limitation as + // it's not dealing with a raw byte but is using the standard web server handler. + status_code = WriteEmptyRespHeader(w, http.StatusBadRequest) + err = fmt.Errorf("failed to assign id") + goto oops + } + + err = cts.pss.Send(MakeRpxStartPacket(srpx.id, get_http_req_line_and_headers(req, true))) + if err != nil { + status_code = WriteEmptyRespHeader(w, http.StatusBadGateway) + goto oops + } + start_sent = true + + wg.Add(1) + go rpx.handle_response(srpx, req, w, ws_upgrade, &wg) + + if ws_upgrade { + // wait until the protocol switching is done in rpx.handle_response() + var upgraded bool + upgraded = <- srpx.br_chan + if upgraded { + // arrange to close the hijacked connection inside rpx.handle_response() + defer srpx.br.Close() + } + } + + for { + var n int + + n, err = srpx.br.Read(buf[:]) + if n > 0 { + var err2 error + err2 = cts.pss.Send(MakeRpxDataPacket(srpx.id, buf[:n])) + if err2 != nil { + status_code = WriteEmptyRespHeader(w, http.StatusBadGateway) + goto oops + } + } + if err != nil { + if errors.Is(err, io.EOF) { + err = cts.pss.Send(MakeRpxEofPacket(srpx.id)) + if err != nil { + status_code = WriteEmptyRespHeader(w, http.StatusBadGateway) + goto oops + } + break + } + status_code = WriteEmptyRespHeader(w, http.StatusInternalServerError) + goto oops + } + } + + wg.Wait() + if srpx.resp_error != nil { + status_code = WriteEmptyRespHeader(w, srpx.resp_status_code) + err = srpx.resp_error + goto oops + } + + select { + case <- srpx.done_chan: + // anything to do? + case <- req.Context().Done(): + // anything to do? + // no default. block + } + + cts.pss.Send(MakeRpxStopPacket(srpx.id)) + return srpx.resp_status_code, nil + +oops: + if srpx != nil && start_sent { cts.pss.Send(MakeRpxStopPacket(srpx.id)) } + return status_code, err +} diff --git a/server.go b/server.go index f26b6ed..0de1f35 100644 --- a/server.go +++ b/server.go @@ -10,6 +10,7 @@ import "log" import "net" import "net/http" import "net/netip" +import "regexp" import "slices" import "strconv" import "strings" @@ -32,8 +33,8 @@ const CTS_LIMIT int = 16384 type PortId uint16 const PORT_ID_MARKER string = "_" const HS_ID_CTL string = "ctl" -const HS_ID_RPX string = "pxy" -const HS_ID_PXY string = "rpx" +const HS_ID_RPX string = "rpx" +const HS_ID_PXY string = "pxy" const HS_ID_WPX string = "wpx" type ServerConnMapByAddr map[net.Addr]*ServerConn @@ -45,6 +46,7 @@ type ServerSvcPortMap map[PortId]ConnRouteId type ServerRptyMap map[uint64]*ServerRpty type ServerRptyMapByWs map[*websocket.Conn]*ServerRpty +type ServerRpxMap map[uint64]*ServerRpx type ServerWpxResponseTransformer func(r *ServerRouteProxyInfo, resp *http.Response) io.Reader type ServerWpxForeignPortProxyMaker func(wpx_type string, port_id string) (*ServerRouteProxyInfo, error) @@ -67,6 +69,9 @@ type ServerConfig struct { RpxAddrs []string RpxTls *tls.Config + RpxClientTokenAttrName string + RpxClientTokenRegex *regexp.Regexp + RpxClientTokenSubmatchIndex int PxyAddrs []string PxyTls *tls.Config @@ -166,12 +171,12 @@ type Server struct { peers atomic.Int64 ssh_proxy_sessions atomic.Int64 pty_sessions atomic.Int64 + rpty_sessions atomic.Int64 } wpx_resp_tf ServerWpxResponseTransformer wpx_foreign_port_proxy_maker ServerWpxForeignPortProxyMaker - pty_user string pty_shell string xterm_html string @@ -202,6 +207,10 @@ type ServerConn struct { rpty_map ServerRptyMap rpty_map_by_ws ServerRptyMapByWs + rpx_next_id uint64 + rpx_mtx sync.Mutex + rpx_map ServerRpxMap + wg sync.WaitGroup stop_req atomic.Bool stop_chan chan bool @@ -236,6 +245,20 @@ type ServerRpty struct { ws *websocket.Conn } +type ServerRpx struct { + id uint64 + pr *io.PipeReader + pw *io.PipeWriter + br io.ReadCloser // body reader + start_chan chan []byte + done_chan chan bool + br_chan chan bool + + resp_status_code int + resp_error error + resp_done_chan chan bool +} + type GuardedPacketStreamServer struct { mtx sync.Mutex //pss Hodu_PacketStreamServer @@ -265,6 +288,16 @@ func (g *GuardedPacketStreamServer) Context() context.Context { // ------------------------------------ +func (rpty *ServerRpty) ReqStop() { + rpty.ws.Close() +} + +func (rpx *ServerRpx) ReqStop() { + rpx.done_chan <- true + rpx.pw.Close() +} +// ------------------------------------ + func NewServerRoute(cts *ServerConn, id RouteId, option RouteOption, ptc_addr string, ptc_name string, svc_requested_addr string, svc_permitted_net string) (*ServerRoute, error) { var r ServerRoute var l *net.TCPListener @@ -658,6 +691,7 @@ func (cts *ServerConn) ReqStopAllServerRoutes() { cts.route_mtx.Unlock() } +// Rpty func (cts *ServerConn) StartRpty(ws *websocket.Conn) (*ServerRpty, error) { var ok bool var start_id uint64 @@ -707,11 +741,11 @@ func (cts *ServerConn) StartRpty(ws *websocket.Conn) (*ServerRpty, error) { return nil , err } + cts.S.stats.rpty_sessions.Add(1) return rpty, nil } func (cts *ServerConn) StopRpty(ws *websocket.Conn) error { - // called by the websocket handler. var rpty *ServerRpty var id uint64 @@ -721,7 +755,9 @@ func (cts *ServerConn) StopRpty(ws *websocket.Conn) error { cts.rpty_mtx.Lock() rpty, ok = cts.rpty_map_by_ws[ws] if !ok { - return fmt.Errorf("unknown ws connection for rpty - %v", ws.RemoteAddr()) + cts.rpty_mtx.Unlock() + cts.S.log.Write(cts.Sid, LOG_ERROR, "Unknown websocket connection for rpty - websocket %v", ws.RemoteAddr()) + return fmt.Errorf("unknown websocket connection for rpty - %v", ws.RemoteAddr()) } id = rpty.id @@ -730,14 +766,24 @@ func (cts *ServerConn) StopRpty(ws *websocket.Conn) error { // send the stop request to the client side err = cts.pss.Send(MakeRptyStopPacket(id, "")) if err != nil { - return fmt.Errorf("unable to send stop rpty request to client - %s", err.Error()) + cts.S.log.Write(cts.Sid, LOG_ERROR, "Failed to send RPTY_STOP(%d) for server %s websocket %v - %s", id, cts.RemoteAddr, ws.RemoteAddr(), err.Error()) + // carry on } + // delete the rpty entry from the maps as the websocket + // handler is ending + cts.rpty_mtx.Lock() + delete(cts.rpty_map, id) + delete(cts.rpty_map_by_ws, ws) + cts.rpty_mtx.Unlock() + cts.S.stats.rpty_sessions.Add(-1) + + cts.S.log.Write(cts.Sid, LOG_INFO, "Stopped rpty(%d) for server %s websocket %vs", id, cts.RemoteAddr, ws.RemoteAddr()) return nil } func (cts *ServerConn) StopRptyWsById(id uint64, msg string) error { - // called this when the stop requested comes from the client + // call this when the stop requested comes from the client. // abort the websocket side. var rpty *ServerRpty @@ -746,11 +792,12 @@ func (cts *ServerConn) StopRptyWsById(id uint64, msg string) error { cts.rpty_mtx.Lock() rpty, ok = cts.rpty_map[id] if !ok { + cts.rpty_mtx.Unlock() return fmt.Errorf("unknown rpty id %d", id) } - rpty.ws.Close() cts.rpty_mtx.Unlock() + rpty.ReqStop() cts.S.log.Write(cts.Sid, LOG_INFO, "Stopped rpty(%d) for %s - %s", id, cts.RemoteAddr, msg) return nil } @@ -764,6 +811,7 @@ func (cts *ServerConn) WriteRpty(ws *websocket.Conn, data []byte) error { cts.rpty_mtx.Lock() rpty, ok = cts.rpty_map_by_ws[ws] if !ok { + cts.rpty_mtx.Unlock() return fmt.Errorf("unknown ws connection for rpty - %v", ws.RemoteAddr()) } @@ -787,6 +835,7 @@ func (cts *ServerConn) WriteRptySize(ws *websocket.Conn, data []byte) error { cts.rpty_mtx.Lock() rpty, ok = cts.rpty_map_by_ws[ws] if !ok { + cts.rpty_mtx.Unlock() return fmt.Errorf("unknown ws connection for rpty size - %v", ws.RemoteAddr()) } @@ -812,33 +861,16 @@ func (cts *ServerConn) ReadRptyAndWriteWs(id uint64, data []byte) error { cts.rpty_mtx.Unlock() return fmt.Errorf("unknown rpty id - %d", id) } + cts.rpty_mtx.Unlock() err = send_ws_data_for_xterm(rpty.ws, "iov", string(data)) if err != nil { - cts.rpty_mtx.Unlock() return fmt.Errorf("failed to write rpty data(%d) to ws - %s", id, err.Error()) } - cts.rpty_mtx.Unlock() return nil } - -func (cts *ServerConn) ReportPacket(route_id RouteId, pts_id PeerId, packet_type PACKET_KIND, event_data interface{}) error { - var r *ServerRoute - var ok bool - - cts.route_mtx.Lock() - r, ok = cts.route_map[route_id] - if !ok { - cts.route_mtx.Unlock() - return fmt.Errorf("non-existent route id - %d", route_id) - } - cts.route_mtx.Unlock() - - return r.ReportPacket(pts_id, packet_type, event_data) -} - func (cts *ServerConn) HandleRptyEvent(packet_type PACKET_KIND, evt *RptyEvent) error { switch packet_type { case PACKET_KIND_RPTY_STOP: @@ -853,6 +885,80 @@ func (cts *ServerConn) HandleRptyEvent(packet_type PACKET_KIND, evt *RptyEvent) return nil } +// Rpx +func (cts *ServerConn) StartRpxWebById(srpx* ServerRpx, id uint64, data []byte) error { + // pass the initial response to code in server-rpx.go + srpx.start_chan <- data + return nil +} + +func (cts *ServerConn) StopRpxWebById(srpx* ServerRpx, id uint64) error { + srpx.ReqStop() + return nil +} + +func (cts *ServerConn) WroteRpxWebById(srpx* ServerRpx, id uint64, data []byte) error { + var err error + _, err = srpx.pw.Write(data) + if err != nil { + cts.S.log.Write(cts.Sid, LOG_ERROR, "Failed to write rpx data(%d) to rpx pipe - %s", id, err.Error()) + srpx.ReqStop() + } + return err +} + +func (cts *ServerConn) EofRpxWebById(srpx* ServerRpx, id uint64) error { + srpx.ReqStop() + return nil +} + +func (cts *ServerConn) HandleRpxEvent(packet_type PACKET_KIND, evt *RpxEvent) error { + var ok bool + var rpx* ServerRpx + + cts.rpx_mtx.Lock() + rpx, ok = cts.rpx_map[evt.Id] + if !ok { + cts.rpx_mtx.Unlock() + return fmt.Errorf("unknown rpx id - %v", evt.Id) + } + cts.rpx_mtx.Unlock() + + switch packet_type { + case PACKET_KIND_RPX_START: + return cts.StartRpxWebById(rpx, evt.Id, evt.Data) + + case PACKET_KIND_RPX_STOP: + // stop requested from the server + return cts.StopRpxWebById(rpx, evt.Id) + + case PACKET_KIND_RPX_EOF: + return cts.EofRpxWebById(rpx, evt.Id) + + case PACKET_KIND_RPX_DATA: + return cts.WroteRpxWebById(rpx, evt.Id, evt.Data) + } + + // ignore other packet types + return nil +} + +// Rpx +func (cts *ServerConn) ReportPacket(route_id RouteId, pts_id PeerId, packet_type PACKET_KIND, event_data interface{}) error { + var r *ServerRoute + var ok bool + + cts.route_mtx.Lock() + r, ok = cts.route_map[route_id] + if !ok { + cts.route_mtx.Unlock() + return fmt.Errorf("non-existent route id - %d", route_id) + } + cts.route_mtx.Unlock() + + return r.ReportPacket(pts_id, packet_type, event_data) +} + func (cts *ServerConn) receive_from_stream(wg *sync.WaitGroup) { var pkt *Packet var err error @@ -1055,10 +1161,30 @@ func (cts *ServerConn) receive_from_stream(wg *sync.WaitGroup) { err = cts.HandleRptyEvent(pkt.Kind, x.RptyEvt) if err != nil { cts.S.log.Write(cts.Sid, LOG_ERROR, "Failed to handle %s event for rpty(%d) from %s - %s", pkt.Kind.String(), x.RptyEvt.Id, cts.RemoteAddr, err.Error()) - } else { + } else { cts.S.log.Write(cts.Sid, LOG_ERROR, "Handled %s event for rpty(%d) from %s", pkt.Kind.String(), x.RptyEvt.Id, cts.RemoteAddr) - } + } + } else { + cts.S.log.Write(cts.Sid, LOG_ERROR, "Invalid %s packet from %s", pkt.Kind.String(), cts.RemoteAddr) + } + case PACKET_KIND_RPX_START: // the client sends the response header using START + fallthrough + case PACKET_KIND_RPX_STOP: + fallthrough + case PACKET_KIND_RPX_EOF: + fallthrough + case PACKET_KIND_RPX_DATA: + var x *Packet_RpxEvt + var ok bool + x, ok = pkt.U.(*Packet_RpxEvt) + if ok { + err = cts.HandleRpxEvent(pkt.Kind, x.RpxEvt) + if err != nil { + cts.S.log.Write(cts.Sid, LOG_ERROR, "Failed to handle %s event for rpx(%d) from %s - %s", pkt.Kind.String(), x.RpxEvt.Id, cts.RemoteAddr, err.Error()) + } else { + cts.S.log.Write(cts.Sid, LOG_ERROR, "Handled %s event for rpx(%d) from %s", pkt.Kind.String(), x.RpxEvt.Id, cts.RemoteAddr) + } } else { cts.S.log.Write(cts.Sid, LOG_ERROR, "Invalid %s packet from %s", pkt.Kind.String(), cts.RemoteAddr) } @@ -1066,17 +1192,26 @@ func (cts *ServerConn) receive_from_stream(wg *sync.WaitGroup) { } done: - // arrange to break all rpty resources cts.rpty_mtx.Lock() if len(cts.rpty_map) > 0 { var rpty *ServerRpty for _, rpty = range cts.rpty_map { - rpty.ws.Close() + rpty.ReqStop() } } cts.rpty_mtx.Unlock() + // arrange to break all rpx resources + cts.rpx_mtx.Lock() + if len(cts.rpx_map) > 0 { + var rpx *ServerRpx + for _, rpx = range cts.rpx_map { + rpx.ReqStop() + } + } + cts.rpx_mtx.Unlock() + cts.S.log.Write(cts.Sid, LOG_INFO, "RPC stream receiver ended") } @@ -1109,7 +1244,7 @@ func (cts *ServerConn) RunTask(wg *sync.WaitGroup) { cts.route_wg.Add(1) // start the loop inside a goroutine so that route_wg counter - // is likely to be greater than 1 what Wait() is called. + // is likely to be greater than 1 when Wait() is called. go func() { waiting_loop: for { @@ -1145,11 +1280,22 @@ func (cts *ServerConn) RunTask(wg *sync.WaitGroup) { func (cts *ServerConn) ReqStop() { if cts.stop_req.CompareAndSwap(false, true) { var r *ServerRoute + var rpty *ServerRpty + var srpx *ServerRpx cts.route_mtx.Lock() for _, r = range cts.route_map { r.ReqStop() } cts.route_mtx.Unlock() + cts.rpty_mtx.Lock() + for _, rpty = range cts.rpty_map { rpty.ReqStop() } + cts.rpty_mtx.Unlock() + + cts.rpx_mtx.Lock() + for _, srpx = range cts.rpx_map { srpx.ReqStop() } + cts.rpx_mtx.Unlock() + + // there is no good way to break a specific connection client to // the grpc server. while the global grpc server is closed in // ReqStop() for Server, the individuation connection is closed @@ -1359,13 +1505,14 @@ func unaryInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServ type server_http_log_writer struct { svr *Server + depth int } func (hlw *server_http_log_writer) Write(p []byte) (n int, err error) { // the standard http.Server always requires *log.Logger // use this iowriter to create a logger to pass it to the http server. // since this is another log write wrapper, give adjustment value - hlw.svr.log.WriteWithCallDepth("", LOG_INFO, +1, string(p)) + hlw.svr.log.WriteWithCallDepth("", LOG_INFO, hlw.depth, string(p)) return len(p), nil } @@ -1422,13 +1569,13 @@ func (s *Server) WrapHttpHandler(handler ServerHttpHandler) http.Handler { WriteEmptyRespHeader(w, status_code) } } - time_taken = time.Now().Sub(start_time) + time_taken = time.Since(start_time) // time.Now().Sub(start_time) if status_code > 0 { if err != nil { - s.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s %d %.9f - %s", req.RemoteAddr, req.Method, get_raw_url_path(req), status_code, time_taken.Seconds(), err.Error()) + s.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s %d %.9f - %s", req.RemoteAddr, req.Method, req.RequestURI, status_code, time_taken.Seconds(), err.Error()) } else { - s.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s %d %.9f", req.RemoteAddr, req.Method, get_raw_url_path(req), status_code, time_taken.Seconds()) + s.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s %d %.9f", req.RemoteAddr, req.Method, req.RequestURI, status_code, time_taken.Seconds()) } } }) @@ -1440,7 +1587,7 @@ func (s *Server) SafeWrapWebsocketHandler(handler websocket.Handler) http.Handle !strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade") { var status_code int status_code = WriteEmptyRespHeader(w, http.StatusBadRequest) - s.log.Write("", LOG_INFO, "[%s] %s %s %d[non-websocket]", req.RemoteAddr, req.Method, get_raw_url_path(req), status_code) + s.log.Write("", LOG_INFO, "[%s] %s %s %d[non-websocket]", req.RemoteAddr, req.Method, req.RequestURI, status_code) return } handler.ServeHTTP(w, req) @@ -1454,21 +1601,19 @@ func (s *Server) WrapWebsocketHandler(handler ServerWebsocketHandler) websocket. var start_time time.Time var time_taken time.Duration var req *http.Request - var raw_url_path string req = ws.Request() - raw_url_path = get_raw_url_path(req) - s.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s [ws]", req.RemoteAddr, req.Method, raw_url_path) + s.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s [ws]", req.RemoteAddr, req.Method, req.RequestURI) start_time = time.Now() status_code, err = handler.ServeWebsocket(ws) - time_taken = time.Now().Sub(start_time) + time_taken = time.Since(start_time) // time.Now().Sub(start_time) if status_code > 0 { if err != nil { - s.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s [ws] %d %.9f - %s", req.RemoteAddr, req.Method, raw_url_path, status_code, time_taken.Seconds(), err.Error()) + s.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s [ws] %d %.9f - %s", req.RemoteAddr, req.Method, req.RequestURI, status_code, time_taken.Seconds(), err.Error()) } else { - s.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s [ws] %d %.9f", req.RemoteAddr, req.Method, raw_url_path, status_code, time_taken.Seconds()) + s.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s [ws] %d %.9f", req.RemoteAddr, req.Method, req.RequestURI, status_code, time_taken.Seconds()) } } }) @@ -1479,7 +1624,6 @@ func NewServer(ctx context.Context, name string, logger Logger, cfg *ServerConfi var l *net.TCPListener var rpcaddr *net.TCPAddr var addr string - var gl *net.TCPListener var i int var hs_log *log.Logger var opts []grpc.ServerOption @@ -1536,7 +1680,7 @@ func NewServer(ctx context.Context, name string, logger Logger, cfg *ServerConfi // --------------------------------------------------------- - hs_log = log.New(&server_http_log_writer{svr: &s}, "", 0) + hs_log = log.New(&server_http_log_writer{svr: &s, depth: +2}, "", 0) // --------------------------------------------------------- @@ -1626,7 +1770,8 @@ func NewServer(ctx context.Context, name string, logger Logger, cfg *ServerConfi s.ctl[i] = &http.Server{ Addr: cfg.CtlAddrs[i], Handler: s.ctl_mux, - TLSConfig: s.Cfg.CtlTls, + // race condition issues without cloning. the http package modifies some fields in the configuration object + TLSConfig: cfg.CtlTls.Clone(), ErrorLog: hs_log, // TODO: more settings } @@ -1638,12 +1783,11 @@ func NewServer(ctx context.Context, name string, logger Logger, cfg *ServerConfi s.rpx_mux.Handle("/", s.WrapHttpHandler(&server_rpx{ S: &s, Id: HS_ID_RPX })) s.rpx = make([]*http.Server, len(cfg.RpxAddrs)) - for i = 0; i < len(cfg.RpxAddrs); i++ { s.rpx[i] = &http.Server{ Addr: cfg.RpxAddrs[i], Handler: s.rpx_mux, - TLSConfig: cfg.RpxTls, + TLSConfig: cfg.RpxTls.Clone(), ErrorLog: hs_log, // TODO: more settings } @@ -1696,7 +1840,7 @@ func NewServer(ctx context.Context, name string, logger Logger, cfg *ServerConfi s.pxy[i] = &http.Server{ Addr: cfg.PxyAddrs[i], Handler: s.pxy_mux, - TLSConfig: cfg.PxyTls, + TLSConfig: cfg.PxyTls.Clone(), ErrorLog: hs_log, // TODO: more settings } @@ -1746,7 +1890,7 @@ func NewServer(ctx context.Context, name string, logger Logger, cfg *ServerConfi s.wpx[i] = &http.Server{ Addr: cfg.WpxAddrs[i], Handler: s.wpx_mux, - TLSConfig: cfg.WpxTls, + TLSConfig: cfg.WpxTls.Clone(), ErrorLog: hs_log, } } @@ -1757,11 +1901,11 @@ func NewServer(ctx context.Context, name string, logger Logger, cfg *ServerConfi s.stats.peers.Store(0) s.stats.ssh_proxy_sessions.Store(0) s.stats.pty_sessions.Store(0) + s.stats.rpty_sessions.Store(0) return &s, nil oops: - if gl != nil { gl.Close() } for _, l = range s.rpc { l.Close() } s.rpc = make([]*net.TCPListener, 0) return nil, err @@ -1840,16 +1984,10 @@ func (s *Server) RunTask(wg *sync.WaitGroup) { go s.run_grpc_server(idx, &s.rpc_wg) } - // most work is done by in separate goroutines (s.run_grp_server) - // this loop serves as a placeholder to prevent the logic flow from + // most work is done in separate goroutines (s.run_grp_server) + // this read on the stop channel serves as a placeholder to prevent the logic flow from // descening down to s.ReqStop() -task_loop: - for { - select { - case <-s.stop_chan: - break task_loop - } - } + <-s.stop_chan s.ReqStop() @@ -1863,8 +2001,52 @@ task_loop: s.rpc_svr.Stop() } -func (s *Server) RunCtlTask(wg *sync.WaitGroup) { +func (s* Server) run_single_ctl_server(i int, cs *http.Server, wg* sync.WaitGroup) { + var l net.Listener var err error + + defer wg.Done() + + s.log.Write("", LOG_INFO, "Control channel[%d] started on %s", i, cs.Addr) + + if s.stop_req.Load() == false { + // defeat hard-coded "tcp" in ListenAndServe() and ListenAndServeTLS() + // err = cs.ListenAndServe() + // err = cs.ListenAndServeTLS("", "") + l, err = net.Listen(TcpAddrStrClass(cs.Addr), cs.Addr) + if err == nil { + if s.stop_req.Load() == false { + var node *list.Element + + s.ctl_addrs_mtx.Lock() + node = s.ctl_addrs.PushBack(l.Addr().(*net.TCPAddr)) + s.ctl_addrs_mtx.Unlock() + + if s.Cfg.CtlTls == nil { + err = cs.Serve(l) + } else { + err = cs.ServeTLS(l, "", "") // s.Cfg.CtlTls must provide a certificate and a key + } + + s.ctl_addrs_mtx.Lock() + s.ctl_addrs.Remove(node) + s.ctl_addrs_mtx.Unlock() + } else { + err = fmt.Errorf("stop requested") + } + l.Close() + } + } else { + err = fmt.Errorf("stop requested") + } + if errors.Is(err, http.ErrServerClosed) { + s.log.Write("", LOG_INFO, "Control channel[%d] ended", i) + } else { + s.log.Write("", LOG_ERROR, "Control channel[%d] error - %s", i, err.Error()) + } +} + +func (s *Server) RunCtlTask(wg *sync.WaitGroup) { var ctl *http.Server var idx int var l_wg sync.WaitGroup @@ -1873,54 +2055,55 @@ func (s *Server) RunCtlTask(wg *sync.WaitGroup) { for idx, ctl = range s.ctl { l_wg.Add(1) - go func(i int, cs *http.Server) { - var l net.Listener - - s.log.Write("", LOG_INFO, "Control channel[%d] started on %s", i, cs.Addr) - - if s.stop_req.Load() == false { - // defeat hard-coded "tcp" in ListenAndServe() and ListenAndServeTLS() - // err = cs.ListenAndServe() - // err = cs.ListenAndServeTLS("", "") - l, err = net.Listen(TcpAddrStrClass(cs.Addr), cs.Addr) - if err == nil { - if s.stop_req.Load() == false { - var node *list.Element - - s.ctl_addrs_mtx.Lock() - node = s.ctl_addrs.PushBack(l.Addr().(*net.TCPAddr)) - s.ctl_addrs_mtx.Unlock() - - if s.Cfg.CtlTls == nil { - err = cs.Serve(l) - } else { - err = cs.ServeTLS(l, "", "") // s.Cfg.CtlTls must provide a certificate and a key - } - - s.ctl_addrs_mtx.Lock() - s.ctl_addrs.Remove(node) - s.ctl_addrs_mtx.Unlock() - } else { - err = fmt.Errorf("stop requested") - } - l.Close() - } - } else { - err = fmt.Errorf("stop requested") - } - if errors.Is(err, http.ErrServerClosed) { - s.log.Write("", LOG_INFO, "Control channel[%d] ended", i) - } else { - s.log.Write("", LOG_ERROR, "Control channel[%d] error - %s", i, err.Error()) - } - l_wg.Done() - }(idx, ctl) + go s.run_single_ctl_server(idx, ctl, &l_wg); } l_wg.Wait() } -func (s *Server) RunRpxTask(wg *sync.WaitGroup) { +func (s *Server) run_single_rpx_server(i int, cs *http.Server, wg* sync.WaitGroup) { + var l net.Listener var err error + + defer wg.Done() + + s.log.Write("", LOG_INFO, "RPX channel[%d] started on %s", i, s.Cfg.RpxAddrs[i]) + + if s.stop_req.Load() == false { + l, err = net.Listen(TcpAddrStrClass(cs.Addr), cs.Addr) + if err == nil { + if s.stop_req.Load() == false { + var node *list.Element + + s.rpx_addrs_mtx.Lock() + node = s.rpx_addrs.PushBack(l.Addr().(*net.TCPAddr)) + s.rpx_addrs_mtx.Unlock() + + if s.Cfg.RpxTls == nil { // TODO: change this + err = cs.Serve(l) + } else { + err = cs.ServeTLS(l, "", "") // s.Cfg.RpxTls must provide a certificate and a key + } + + s.rpx_addrs_mtx.Lock() + s.rpx_addrs.Remove(node) + s.rpx_addrs_mtx.Unlock() + } else { + err = fmt.Errorf("stop requested") + } + l.Close() + } + } else { + err = fmt.Errorf("stop requested") + } + if errors.Is(err, http.ErrServerClosed) { + s.log.Write("", LOG_INFO, "RPX channel[%d] ended", i) + } else { + s.log.Write("", LOG_ERROR, "RPX channel[%d] error - %s", i, err.Error()) + } + +} + +func (s *Server) RunRpxTask(wg *sync.WaitGroup) { var rpx *http.Server var idx int var l_wg sync.WaitGroup @@ -1929,51 +2112,54 @@ func (s *Server) RunRpxTask(wg *sync.WaitGroup) { for idx, rpx = range s.rpx { l_wg.Add(1) - go func(i int, cs *http.Server) { - var l net.Listener - - s.log.Write("", LOG_INFO, "RPX channel[%d] started on %s", i, s.Cfg.RpxAddrs[i]) - - if s.stop_req.Load() == false { - l, err = net.Listen(TcpAddrStrClass(cs.Addr), cs.Addr) - if err == nil { - if s.stop_req.Load() == false { - var node *list.Element - - s.rpx_addrs_mtx.Lock() - node = s.rpx_addrs.PushBack(l.Addr().(*net.TCPAddr)) - s.rpx_addrs_mtx.Unlock() - - if s.Cfg.RpxTls == nil { // TODO: change this - err = cs.Serve(l) - } else { - err = cs.ServeTLS(l, "", "") // s.Cfg.RpxTls must provide a certificate and a key - } - - s.rpx_addrs_mtx.Lock() - s.rpx_addrs.Remove(node) - s.rpx_addrs_mtx.Unlock() - } else { - err = fmt.Errorf("stop requested") - } - l.Close() - } - } else { - err = fmt.Errorf("stop requested") - } - if errors.Is(err, http.ErrServerClosed) { - s.log.Write("", LOG_INFO, "RPX channel[%d] ended", i) - } else { - s.log.Write("", LOG_ERROR, "RPX channel[%d] error - %s", i, err.Error()) - } - l_wg.Done() - }(idx, rpx) + go s.run_single_rpx_server(idx, rpx, &l_wg) } l_wg.Wait() } -func (s *Server) RunPxyTask(wg *sync.WaitGroup) { +func (s *Server) run_single_pxy_server(i int, cs *http.Server, wg* sync.WaitGroup) { + var l net.Listener var err error + + defer wg.Done() + + s.log.Write("", LOG_INFO, "Proxy channel[%d] started on %s", i, s.Cfg.PxyAddrs[i]) + + if s.stop_req.Load() == false { + l, err = net.Listen(TcpAddrStrClass(cs.Addr), cs.Addr) + if err == nil { + if s.stop_req.Load() == false { + var node *list.Element + + s.pxy_addrs_mtx.Lock() + node = s.pxy_addrs.PushBack(l.Addr().(*net.TCPAddr)) + s.pxy_addrs_mtx.Unlock() + + if s.Cfg.PxyTls == nil { // TODO: change this + err = cs.Serve(l) + } else { + err = cs.ServeTLS(l, "", "") // s.Cfg.PxyTls must provide a certificate and a key + } + + s.pxy_addrs_mtx.Lock() + s.pxy_addrs.Remove(node) + s.pxy_addrs_mtx.Unlock() + } else { + err = fmt.Errorf("stop requested") + } + l.Close() + } + } else { + err = fmt.Errorf("stop requested") + } + if errors.Is(err, http.ErrServerClosed) { + s.log.Write("", LOG_INFO, "Proxy channel[%d] ended", i) + } else { + s.log.Write("", LOG_ERROR, "Proxy channel[%d] error - %s", i, err.Error()) + } +} + +func (s *Server) RunPxyTask(wg *sync.WaitGroup) { var pxy *http.Server var idx int var l_wg sync.WaitGroup @@ -1982,51 +2168,55 @@ func (s *Server) RunPxyTask(wg *sync.WaitGroup) { for idx, pxy = range s.pxy { l_wg.Add(1) - go func(i int, cs *http.Server) { - var l net.Listener - - s.log.Write("", LOG_INFO, "Proxy channel[%d] started on %s", i, s.Cfg.PxyAddrs[i]) - - if s.stop_req.Load() == false { - l, err = net.Listen(TcpAddrStrClass(cs.Addr), cs.Addr) - if err == nil { - if s.stop_req.Load() == false { - var node *list.Element - - s.pxy_addrs_mtx.Lock() - node = s.pxy_addrs.PushBack(l.Addr().(*net.TCPAddr)) - s.pxy_addrs_mtx.Unlock() - - if s.Cfg.PxyTls == nil { // TODO: change this - err = cs.Serve(l) - } else { - err = cs.ServeTLS(l, "", "") // s.Cfg.PxyTls must provide a certificate and a key - } - - s.pxy_addrs_mtx.Lock() - s.pxy_addrs.Remove(node) - s.pxy_addrs_mtx.Unlock() - } else { - err = fmt.Errorf("stop requested") - } - l.Close() - } - } else { - err = fmt.Errorf("stop requested") - } - if errors.Is(err, http.ErrServerClosed) { - s.log.Write("", LOG_INFO, "Proxy channel[%d] ended", i) - } else { - s.log.Write("", LOG_ERROR, "Proxy channel[%d] error - %s", i, err.Error()) - } - l_wg.Done() - }(idx, pxy) + go s.run_single_pxy_server(idx, pxy, &l_wg); } l_wg.Wait() } -func (s *Server) RunWpxTask(wg *sync.WaitGroup) { +func (s *Server) run_single_wpx_server(i int, cs *http.Server, wg* sync.WaitGroup) { + var l net.Listener var err error + + defer wg.Done() + + s.log.Write("", LOG_INFO, "Wpx channel[%d] started on %s", i, s.Cfg.WpxAddrs[i]) + + if s.stop_req.Load() == false { + l, err = net.Listen(TcpAddrStrClass(cs.Addr), cs.Addr) + if err == nil { + if s.stop_req.Load() == false { + var node *list.Element + + s.wpx_addrs_mtx.Lock() + node = s.wpx_addrs.PushBack(l.Addr().(*net.TCPAddr)) + s.wpx_addrs_mtx.Unlock() + + if s.Cfg.WpxTls == nil { + err = cs.Serve(l) + } else { + err = cs.ServeTLS(l, "", "") // s.Cfg.WpxTls must provide a certificate and a key + } + + s.wpx_addrs_mtx.Lock() + s.wpx_addrs.Remove(node) + s.wpx_addrs_mtx.Unlock() + } else { + err = fmt.Errorf("stop requested") + } + l.Close() + } + } else { + err = fmt.Errorf("stop requested") + } + if errors.Is(err, http.ErrServerClosed) { + s.log.Write("", LOG_INFO, "Wpx channel[%d] ended", i) + } else { + s.log.Write("", LOG_ERROR, "Wpx channel[%d] error - %s", i, err.Error()) + } + +} + +func (s *Server) RunWpxTask(wg *sync.WaitGroup) { var wpx *http.Server var idx int var l_wg sync.WaitGroup @@ -2035,45 +2225,7 @@ func (s *Server) RunWpxTask(wg *sync.WaitGroup) { for idx, wpx = range s.wpx { l_wg.Add(1) - go func(i int, cs *http.Server) { - var l net.Listener - - s.log.Write("", LOG_INFO, "Wpx channel[%d] started on %s", i, s.Cfg.WpxAddrs[i]) - - if s.stop_req.Load() == false { - l, err = net.Listen(TcpAddrStrClass(cs.Addr), cs.Addr) - if err == nil { - if s.stop_req.Load() == false { - var node *list.Element - - s.wpx_addrs_mtx.Lock() - node = s.wpx_addrs.PushBack(l.Addr().(*net.TCPAddr)) - s.wpx_addrs_mtx.Unlock() - - if s.Cfg.WpxTls == nil { - err = cs.Serve(l) - } else { - err = cs.ServeTLS(l, "", "") // s.Cfg.WpxTls must provide a certificate and a key - } - - s.wpx_addrs_mtx.Lock() - s.wpx_addrs.Remove(node) - s.wpx_addrs_mtx.Unlock() - } else { - err = fmt.Errorf("stop requested") - } - l.Close() - } - } else { - err = fmt.Errorf("stop requested") - } - if errors.Is(err, http.ErrServerClosed) { - s.log.Write("", LOG_INFO, "Wpx channel[%d] ended", i) - } else { - s.log.Write("", LOG_ERROR, "Wpx channel[%d] error - %s", i, err.Error()) - } - l_wg.Done() - }(idx, wpx) + go s.run_single_wpx_server(idx, wpx, &l_wg) } l_wg.Wait() } @@ -2141,6 +2293,7 @@ func (s *Server) AddNewServerConn(remote_addr *net.Addr, local_addr *net.Addr, p cts.rpty_map = make(ServerRptyMap) cts.rpty_map_by_ws = make(ServerRptyMapByWs) + cts.rpx_map = make(ServerRpxMap) s.cts_mtx.Lock()