added code for rpx handling

This commit is contained in:
2025-08-19 20:20:18 +09:00
parent 31a4223aab
commit 10c139e837
19 changed files with 1518 additions and 427 deletions

View File

@ -110,6 +110,7 @@ type json_out_client_stats struct {
ClientRoutes int64 `json:"client-routes"`
ClientPeers int64 `json:"client-peers"`
ClientPtySessions int64 `json:"client-pty-sessions"`
ClientRptySessions int64 `json:"client-rpty-sessions"`
}
// ------------------------------------
@ -1138,6 +1139,7 @@ func (ctl *client_ctl_stats) ServeHTTP(w http.ResponseWriter, req *http.Request)
stats.ClientRoutes = c.stats.routes.Load()
stats.ClientPeers = c.stats.peers.Load()
stats.ClientPtySessions = c.stats.pty_sessions.Load()
stats.ClientRptySessions = c.stats.rpty_sessions.Load()
status_code = WriteJsonRespHeader(w, http.StatusOK)
if err = je.Encode(stats); err != nil { goto oops }

View File

@ -11,6 +11,7 @@ type ClientCollector struct {
ClientRoutes *prometheus.Desc
ClientPeers *prometheus.Desc
PtySessions *prometheus.Desc
RptySessions *prometheus.Desc
}
// NewClientCollector returns a new ClientCollector with all prometheus.Desc initialized
@ -52,6 +53,11 @@ func NewClientCollector(client *Client) ClientCollector {
"Number of pty sessions",
nil, nil,
),
RptySessions: prometheus.NewDesc(
prefix + "rpty_sessions",
"Number of rpty sessions",
nil, nil,
),
}
}
@ -61,6 +67,7 @@ func (c ClientCollector) Describe(ch chan<- *prometheus.Desc) {
ch <- c.ClientRoutes
ch <- c.ClientPeers
ch <- c.PtySessions
ch <- c.RptySessions
}
func (c ClientCollector) Collect(ch chan<- prometheus.Metric) {
@ -97,4 +104,10 @@ func (c ClientCollector) Collect(ch chan<- prometheus.Metric) {
prometheus.GaugeValue,
float64(c.client.stats.pty_sessions.Load()),
)
ch <- prometheus.MustNewConstMetric(
c.RptySessions,
prometheus.GaugeValue,
float64(c.client.stats.rpty_sessions.Load()),
)
}

View File

@ -30,6 +30,16 @@ func (cpc *ClientPeerConn) RunTask(wg *sync.WaitGroup) error {
for {
n, err = cpc.conn.Read(buf[:])
if n > 0 {
var err2 error
err2 = cpc.route.cts.psc.Send(MakePeerDataPacket(cpc.route.Id, cpc.conn_id, buf[0:n]))
if err2 != nil {
cpc.route.cts.C.log.Write(cpc.route.cts.Sid, LOG_ERROR,
"Failed to write peer(%d,%d,%s,%s) data to server - %s",
cpc.route.Id, cpc.conn_id, cpc.conn.RemoteAddr().String(), cpc.conn.LocalAddr().String(), err2.Error())
break
}
}
if err != nil {
if errors.Is(err, io.EOF) || strings.Contains(err.Error(), "use of closed network connection") { // i hate checking this condition with strings.Contains()
cpc.route.cts.C.log.Write(cpc.route.cts.Sid, LOG_INFO,
@ -42,14 +52,6 @@ func (cpc *ClientPeerConn) RunTask(wg *sync.WaitGroup) error {
}
break
}
err = cpc.route.cts.psc.Send(MakePeerDataPacket(cpc.route.Id, cpc.conn_id, buf[0:n]))
if err != nil {
cpc.route.cts.C.log.Write(cpc.route.cts.Sid, LOG_ERROR,
"Failed to write peer(%d,%d,%s,%s) data to server - %s",
cpc.route.Id, cpc.conn_id, cpc.conn.RemoteAddr().String(), cpc.conn.LocalAddr().String(), err.Error())
break
}
}
cpc.route.cts.psc.Send(MakePeerStoppedPacket(cpc.route.Id, cpc.conn_id, cpc.conn.RemoteAddr().String(), cpc.conn.LocalAddr().String())) // nothing much to do upon failure. no error check here

View File

@ -59,7 +59,7 @@ func (pty *client_pty_ws) ServeWebsocket(ws *websocket.Conn) (int, error) {
conn_ready = <-conn_ready_chan
if conn_ready { // connected
var poll_fds []unix.PollFd
var buf []byte
var buf [2048]byte
var n int
var err error
@ -69,7 +69,6 @@ func (pty *client_pty_ws) ServeWebsocket(ws *websocket.Conn) (int, error) {
}
c.stats.pty_sessions.Add(1)
buf = make([]byte, 2048)
for {
n, err = unix.Poll(poll_fds, -1) // -1 means wait indefinitely
if err != nil {
@ -87,20 +86,21 @@ func (pty *client_pty_ws) ServeWebsocket(ws *websocket.Conn) (int, error) {
}
if (poll_fds[0].Revents & unix.POLLIN) != 0 {
n, err = out.Read(buf)
n, err = out.Read(buf[:])
if n > 0 {
var err2 error
err2 = send_ws_data_for_xterm(ws, "iov", string(buf[:n]))
if err2 != nil {
c.log.Write(pty.Id, LOG_ERROR, "[%s] Failed to send to websocket - %s", req.RemoteAddr, err2.Error())
break
}
}
if err != nil {
if !errors.Is(err, io.EOF) {
c.log.Write(pty.Id, LOG_ERROR, "[%s] Failed to read pty stdout - %s", req.RemoteAddr, err.Error())
}
break
}
if n > 0 {
err = send_ws_data_for_xterm(ws, "iov", string(buf[:n]))
if err != nil {
c.log.Write(pty.Id, LOG_ERROR, "[%s] Failed to send to websocket - %s", req.RemoteAddr, err.Error())
break
}
}
}
}
c.stats.pty_sessions.Add(-1)

615
client.go
View File

@ -1,5 +1,7 @@
package hodu
import "bufio"
import "bytes"
import "container/list"
import "context"
import "crypto/tls"
@ -40,6 +42,7 @@ type ClientRouteMap map[RouteId]*ClientRoute
type ClientPeerConnMap map[PeerId]*ClientPeerConn
type ClientPeerCancelFuncMap map[PeerId]context.CancelFunc
type ClientRptyMap map[uint64]*ClientRpty
type ClientRpxMap map[uint64]*ClientRpx
// --------------------------------------------------------------------
type ClientRouteConfig struct {
@ -84,6 +87,10 @@ type ClientConfig struct {
PeerConnTmout time.Duration
Token string // to send to the server for identification
// default target for rpx
RpxTargetAddr string
RpxTargetTls *tls.Config
}
type ClientEventKind int
@ -122,6 +129,10 @@ type Client struct {
ext_svcs []Service
rpc_tls *tls.Config
rpx_target_addr string
rpx_target_url string
rpx_target_tls *tls.Config
ctl_tls *tls.Config
ctl_addr []string
ctl_prefix string
@ -155,6 +166,7 @@ type Client struct {
routes atomic.Int64
peers atomic.Int64
pty_sessions atomic.Int64
rpty_sessions atomic.Int64
}
pty_user string
@ -176,6 +188,15 @@ type ClientRpty struct {
tty *os.File
}
type ClientRpx struct {
id uint64
pr *io.PipeReader
pw *io.PipeWriter
ctx context.Context
cancel context.CancelFunc
ws_conn net.Conn
}
// client connection to server
type ClientConn struct {
C *Client
@ -209,6 +230,9 @@ type ClientConn struct {
rpty_mtx sync.Mutex
rpty_map ClientRptyMap
rpx_mtx sync.Mutex
rpx_map ClientRpxMap
stop_req atomic.Bool
stop_chan chan bool
@ -294,6 +318,22 @@ func (g *GuardedPacketStreamClient) Context() context.Context {
return g.psc.Context()
}*/
// --------------------------------------------------------------------
func (rpty *ClientRpty) ReqStop() {
rpty.tty.Close()
rpty.cmd.Process.Kill()
}
func (rpx *ClientRpx) ReqStop() {
rpx.pr.Close()
rpx.pw.Close()
rpx.cancel()
if rpx.ws_conn != nil {
rpx.ws_conn.SetDeadline(time.Now()) // to make Read return immediately
}
}
// --------------------------------------------------------------------
func NewClientRoute(cts *ClientConn, id RouteId, static bool, client_peer_addr string, client_peer_name string, server_peer_svc_addr string, server_peer_svc_net string, server_peer_option RouteOption, lifetime time.Duration) *ClientRoute {
var r ClientRoute
@ -397,7 +437,7 @@ func (r *ClientRoute) ExtendLifetime(lifetime time.Duration) error {
r.lifetime_timer.Stop()
r.Lifetime = r.Lifetime + lifetime
expiry = r.LifetimeStart.Add(r.Lifetime)
r.lifetime_timer.Reset(expiry.Sub(time.Now()))
r.lifetime_timer.Reset(time.Until(expiry)) // expiry.Sub(time.Now())
if r.cts.C.route_persister != nil { r.cts.C.route_persister.Save(r.cts, r) }
r.lifetime_mtx.Unlock()
@ -477,10 +517,8 @@ func (r *ClientRoute) RunTask(wg *sync.WaitGroup) {
break waiting_loop
}
} else {
select {
case <-r.stop_chan:
break waiting_loop
}
<-r.stop_chan
break waiting_loop
}
}
@ -547,7 +585,7 @@ func (r *ClientRoute) ConnectToPeer(pts_id PeerId, route_option RouteOption, pts
defer wg.Done()
tmout = time.Duration(r.cts.C.ptc_tmout)
if tmout <= 0 { tmout = 5 * time.Second} // TODO: make this configurable...
if tmout <= 0 { tmout = 5 * time.Second } // TODO: make this configurable...
waitctx, cancel_wait = context.WithTimeout(r.cts.C.Ctx, tmout)
r.ptc_mtx.Lock()
r.ptc_cancel_map[pts_id] = cancel_wait
@ -571,8 +609,8 @@ func (r *ClientRoute) ConnectToPeer(pts_id PeerId, route_option RouteOption, pts
real_conn, ok = conn.(*net.TCPConn)
if !ok {
r.cts.C.log.Write(r.cts.Sid, LOG_ERROR,
"Failed to get connection information to %s for route(%d,%d,%s,%s) - %s",
r.PeerAddr, r.Id, pts_id, pts_raddr, pts_laddr, err.Error())
"Failed to get connection information to %s for route(%d,%d,%s,%s)",
r.PeerAddr, r.Id, pts_id, pts_raddr, pts_laddr)
goto peer_aborted
}
@ -825,6 +863,7 @@ func NewClientConn(c *Client, cfg *ClientConnConfig) *ClientConn {
cts.stop_chan = make(chan bool, 8)
cts.ptc_list = list.New()
cts.rpty_map = make(ClientRptyMap)
cts.rpx_map = make(ClientRpxMap)
for i, _ = range cts.cfg.Routes {
// override it to static regardless of the value passed in
@ -833,7 +872,6 @@ func NewClientConn(c *Client, cfg *ClientConnConfig) *ClientConn {
// the actual connection to the server is established in the main task function
// The cts.conn, cts.hdc, cts.psc fields are left unassigned here.
return &cts
}
@ -1031,7 +1069,8 @@ func (cts *ClientConn) add_client_routes(routes []ClientRouteConfig) error {
func (cts *ClientConn) disconnect_from_server(logmsg bool) {
if cts.conn != nil {
var r *ClientRoute
var crp *ClientRpty
var rpty *ClientRpty
var rpx *ClientRpx
cts.discon_mtx.Lock()
@ -1045,15 +1084,20 @@ func (cts *ClientConn) disconnect_from_server(logmsg bool) {
// arrange to clean up all rpty objects
cts.rpty_mtx.Lock()
for _, crp = range cts.rpty_map {
crp.tty.Close()
crp.cmd.Process.Kill()
for _, rpty = range cts.rpty_map {
rpty.ReqStop()
// the loop in ReadRptyLoop() is supposed to be broken.
// let's not inform the server of this connection.
// the server should clean up itself upon connection error
}
cts.rpty_mtx.Unlock()
cts.rpx_mtx.Lock()
for _, rpx = range cts.rpx_map {
rpx.ReqStop()
}
cts.rpx_mtx.Unlock()
// don't care about double closes when this function is called from both RunTask() and ReqStop()
cts.conn.Close()
@ -1217,7 +1261,7 @@ start_over:
pkt, err = psc.Recv()
if err != nil {
if status.Code(err) == codes.Canceled || errors.Is(err, net.ErrClosed) {
if status.Code(err) == codes.Canceled || errors.Is(err, net.ErrClosed) || errors.Is(err, io.EOF) {
goto reconnect_to_server
} else {
cts.C.log.Write(cts.Sid, LOG_INFO, "Failed to receive packet from %s - %s", cts.remote_addr_p, err.Error())
@ -1348,6 +1392,31 @@ start_over:
cts.C.log.Write(cts.Sid, LOG_ERROR, "Invalid %s event from %s", pkt.Kind.String(), cts.remote_addr_p)
}
case PACKET_KIND_RPX_START:
fallthrough
case PACKET_KIND_RPX_STOP:
fallthrough
case PACKET_KIND_RPX_DATA:
fallthrough
case PACKET_KIND_RPX_EOF:
var x *Packet_RpxEvt
var ok bool
x, ok = pkt.U.(*Packet_RpxEvt)
if ok {
err = cts.HandleRpxEvent(pkt.Kind, x.RpxEvt)
if err != nil {
cts.C.log.Write(cts.Sid, LOG_ERROR,
"Failed to handle %s event for rpx(%d) from %s - %s",
pkt.Kind.String(), x.RpxEvt.Id, cts.remote_addr_p, err.Error())
} else {
cts.C.log.Write(cts.Sid, LOG_DEBUG,
"Handled %s event for rpx(%d) from %s",
pkt.Kind.String(), x.RpxEvt.Id, cts.remote_addr_p)
}
} else {
cts.C.log.Write(cts.Sid, LOG_ERROR, "Invalid %s event from %s", pkt.Kind.String(), cts.remote_addr_p)
}
default:
// do nothing. ignore the rest
}
@ -1390,7 +1459,7 @@ reconnect_to_server:
goto wait_for_termination
case <-slpctx.Done():
select {
case <- cts.C.Ctx.Done():
case <-cts.C.Ctx.Done():
// non-blocking check if the parent context of the sleep context is
// terminated too. if so, this is normal termination case.
// this check seem redundant but the go-runtime doesn't seem to guarantee
@ -1421,20 +1490,34 @@ func (cts *ClientConn) ReportPacket(route_id RouteId, pts_id PeerId, packet_type
}
// rpty
func (cts *ClientConn) FindClientRptyById(id uint64) *ClientRpty {
var crp *ClientRpty
var ok bool
cts.rpty_mtx.Lock()
crp, ok = cts.rpty_map[id]
cts.rpty_mtx.Unlock()
if !ok { crp = nil }
return crp
}
func (cts *ClientConn) ReadRptyLoop(crp *ClientRpty, wg *sync.WaitGroup) {
var poll_fds []unix.PollFd
var buf []byte
var buf [2048]byte
var n int
var err error
defer wg.Done()
cts.C.stats.rpty_sessions.Add(1)
poll_fds = []unix.PollFd{
unix.PollFd{Fd: int32(crp.tty.Fd()), Events: unix.POLLIN},
}
buf = make([]byte, 2048)
for {
n, err = unix.Poll(poll_fds, -1) // -1 means wait indefinitely
if err != nil {
@ -1452,34 +1535,35 @@ func (cts *ClientConn) ReadRptyLoop(crp *ClientRpty, wg *sync.WaitGroup) {
}
if (poll_fds[0].Revents & unix.POLLIN) != 0 {
n, err = crp.tty.Read(buf)
n, err = crp.tty.Read(buf[:])
if n > 0 {
var err2 error
err2 = cts.psc.Send(MakeRptyDataPacket(crp.id, buf[:n]))
if err2 != nil {
cts.C.log.Write(cts.Sid, LOG_DEBUG, "Failed to send rpty(%d) stdout to server - %s", crp.id, err2.Error())
break
}
}
if err != nil {
if !errors.Is(err, io.EOF) {
cts.C.log.Write(cts.Sid, LOG_DEBUG, "Failed to read rpty(%d) stdout - %s", crp.id, err.Error())
}
break
}
if n > 0 {
err = cts.psc.Send(MakeRptyDataPacket(crp.id, buf[:n]))
if err != nil {
cts.C.log.Write(cts.Sid, LOG_DEBUG, "Failed to send rpty(%d) stdout to server - %s", crp.id, err.Error())
break
}
}
}
}
cts.psc.Send(MakeRptyStopPacket(crp.id, ""))
cts.C.log.Write(cts.Sid, LOG_INFO, "Ending rpty(%d) read loop", crp.id)
crp.tty.Close() // don't care about multiple closes
crp.cmd.Process.Kill()
crp.ReqStop()
crp.cmd.Wait()
cts.rpty_mtx.Lock()
delete(cts.rpty_map, crp.id)
cts.rpty_mtx.Unlock()
cts.C.stats.rpty_sessions.Add(-1)
}
func (cts *ClientConn) StartRpty(id uint64, wg *sync.WaitGroup) error {
@ -1515,46 +1599,34 @@ func (cts *ClientConn) StartRpty(id uint64, wg *sync.WaitGroup) error {
func (cts *ClientConn) StopRpty(id uint64) error {
var crp *ClientRpty
var ok bool
cts.rpty_mtx.Lock()
crp, ok = cts.rpty_map[id]
if !ok {
cts.rpty_mtx.Unlock()
crp = cts.FindClientRptyById(id)
if crp == nil {
return fmt.Errorf("unknown rpty id %d", id)
}
crp.tty.Close() // to break ReadRptyLoop()
crp.cmd.Process.Kill() // to process wait to be done by ReadRptyLoop()
cts.rpty_mtx.Unlock()
crp.ReqStop()
return nil
}
func (cts *ClientConn) WriteRpty(id uint64, data []byte) error {
var crp *ClientRpty
var ok bool
cts.rpty_mtx.Lock()
crp, ok = cts.rpty_map[id]
if !ok {
cts.rpty_mtx.Unlock()
crp = cts.FindClientRptyById(id)
if crp == nil {
return fmt.Errorf("unknown rpty id %d", id)
}
crp.tty.Write(data)
cts.rpty_mtx.Unlock()
return nil
}
func (cts *ClientConn) WriteRptySize(id uint64, data []byte) error {
var crp *ClientRpty
var ok bool
var flds []string
cts.rpty_mtx.Lock()
crp, ok = cts.rpty_map[id]
if !ok {
cts.rpty_mtx.Unlock()
crp = cts.FindClientRptyById(id)
if crp == nil {
return fmt.Errorf("unknown rpty id %d", id)
}
@ -1566,11 +1638,11 @@ func (cts *ClientConn) WriteRptySize(id uint64, data []byte) error {
cols, _ = strconv.Atoi(flds[1])
pts.Setsize(crp.tty, &pts.Winsize{Rows: uint16(rows), Cols: uint16(cols)})
}
cts.rpty_mtx.Unlock()
return nil
}
func (cts *ClientConn) HandleRptyEvent(packet_type PACKET_KIND, evt *RptyEvent) error {
switch packet_type {
case PACKET_KIND_RPTY_START:
return cts.StartRpty(evt.Id, &cts.C.wg)
@ -1589,6 +1661,329 @@ func (cts *ClientConn) HandleRptyEvent(packet_type PACKET_KIND, evt *RptyEvent)
return nil
}
// rpx
func (cts *ClientConn) FindClientRpxById(id uint64) *ClientRpx {
var crpx *ClientRpx
var ok bool
cts.rpx_mtx.Lock()
crpx, ok = cts.rpx_map[id]
cts.rpx_mtx.Unlock()
if !ok { crpx = nil }
return crpx
}
func (cts *ClientConn) RpxLoop(crpx *ClientRpx, data []byte, wg *sync.WaitGroup) {
var sc *bufio.Scanner
var line string
var flds []string
var buf [4096]byte
var req_meth string
var req_path string
//var req_proto string
var req *http.Request
var n int
var err error
defer wg.Done()
sc = bufio.NewScanner(bytes.NewReader(data))
sc.Scan()
line = sc.Text()
flds = strings.Fields(line)
if (len(flds) < 3) {
cts.C.log.Write(cts.Sid, LOG_ERROR, "Invalid request line for rpx(%d) - %s", crpx.id, line)
goto done
}
// TODO: handle trailers...
req_meth = flds[0]
req_path = flds[1]
//req_proto = flds[2]
// create a request assuming it's a normal http request
req, err = http.NewRequestWithContext(crpx.ctx, req_meth, cts.C.rpx_target_url + req_path, crpx.pr)
if err != nil {
cts.C.log.Write(cts.Sid, LOG_ERROR, "failed to create request for rpx(%d) - %s", crpx.id, err.Error())
goto done
}
for sc.Scan() {
line = sc.Text()
if line == "" { break }
flds = strings.SplitN(line, ":", 2)
if len(flds) == 2 {
req.Header.Add(strings.TrimSpace(flds[0]), strings.TrimSpace(flds[1]))
}
}
err = sc.Err()
if err != nil {
cts.C.log.Write(cts.Sid, LOG_ERROR, "failed to parse request for rpx(%d) - %s", crpx.id, err.Error())
goto done
}
if strings.EqualFold(req.Header.Get("Upgrade"), "websocket") && strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade") {
// websocket
var done_chan chan struct{}
var conn net.Conn
var resp *http.Response
var r *bufio.Reader
if cts.C.rpx_target_tls != nil {
var dialer *tls.Dialer
dialer = &tls.Dialer{
NetDialer: &net.Dialer{},
Config: cts.C.rpx_target_tls,
}
conn, err = dialer.DialContext(crpx.ctx, "tcp", cts.C.rpx_target_addr) // TODO: no hard coding
} else {
var dialer *net.Dialer
dialer = &net.Dialer{}
conn, err = dialer.DialContext(crpx.ctx, "tcp", cts.C.rpx_target_addr) // TODO: no hard coding
}
if err != nil {
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to dial websocket for rpx(%d) - %s", crpx.id, err.Error())
goto done
}
defer conn.Close()
// TODO: make this atomic?
crpx.ws_conn = conn
// write the raw request line and headers as sent by the server.
// for the upgrade request, i assume no payload.
_, err = conn.Write(data)
if err != nil {
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to write websocket request for rpx(%d) - %s", crpx.id, err.Error())
goto done
}
r = bufio.NewReader(conn)
resp, err = http.ReadResponse(r, req)
if err != nil {
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to write websocket response for rpx(%d) - %s", crpx.id, err.Error())
goto done
}
defer resp.Body.Close()
err = cts.psc.Send(MakeRpxStartPacket(crpx.id, get_http_resp_line_and_headers(resp)))
if err != nil {
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to send rpx(%d) WebSocket headers to server - %s", crpx.id, err.Error())
goto done
}
if resp.StatusCode != http.StatusSwitchingProtocols {
// websock upgrade failed. let the code jump to the done
// label to skip reading from the pipe. the server side
// has the code to ensure no content-length. and the upgrade
// fails, the pipe below will be pending forever as the server
// side doesn't send data and there's no feeding to the pipe.
cts.C.log.Write(cts.Sid, LOG_INFO, "Protocol switching failed for rpx(%d)", crpx.id)
goto done
}
// unlike with the normal request, the actual pipe is not read
// until the initial switching protocol response is received.
wg.Add(1)
done_chan = make(chan struct{}, 5)
go func() {
var buf [4096]byte
var n int
var err error
defer wg.Done()
for {
n, err = crpx.pr.Read(buf[:])
if n > 0 {
var err2 error
_, err2 = conn.Write(buf[:n])
if err2 != nil {
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to write websocket for rpx(%d) - %s", crpx.id, err2.Error())
break
}
}
if err != nil {
if errors.Is(err, io.EOF) { break }
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to read pipe for rpx(%d) - %s", crpx.id, err.Error())
break
}
}
done_chan <- struct{}{}
}()
for {
n, err = conn.Read(buf[:])
if n > 0 {
var err2 error
err2 = cts.psc.Send(MakeRpxDataPacket(crpx.id, buf[:n]))
if err2 != nil {
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to send rpx(%d) data to server - %s", crpx.id, err2.Error())
break
}
}
if err != nil {
if errors.Is(err, io.EOF) {
cts.psc.Send(MakeRpxEofPacket(crpx.id))
cts.C.log.Write(cts.Sid, LOG_DEBUG, "WebSocket rpx(%d) closed by server", crpx.id)
break
}
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to read WebSocket rpx(%d) - %s", crpx.id, err.Error())
break
}
}
// wait until the pipe reading(from the server side) goroutine is over
<-done_chan
} else {
var tr *http.Transport
var resp *http.Response
tr = &http.Transport {
TLSClientConfig: cts.C.rpx_target_tls,
}
resp, err = tr.RoundTrip(req)
if err != nil {
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to send rpx(%d) request - %s", crpx.id, err.Error())
goto done
}
defer resp.Body.Close()
err = cts.psc.Send(MakeRpxStartPacket(crpx.id, get_http_resp_line_and_headers(resp)))
if err != nil {
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to send rpx(%d) status and headers to server - %s", crpx.id, err.Error())
goto done
}
for {
n, err = resp.Body.Read(buf[:])
if n > 0 {
var err2 error
err2 = cts.psc.Send(MakeRpxDataPacket(crpx.id, buf[:n]))
if err2 != nil {
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to send rpx(%d) data to server - %s", crpx.id, err2.Error())
break
}
}
if err != nil {
if errors.Is(err, io.EOF) {
break
}
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to read response body for rpx(%d) - %s", crpx.id, err.Error())
break
}
}
}
done:
err = cts.psc.Send(MakeRpxStopPacket(crpx.id))
if err != nil {
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to send rpx(%d) stp to server - %s", crpx.id, err.Error())
}
cts.C.log.Write(cts.Sid, LOG_INFO, "Ending rpx(%d) read loop", crpx.id)
crpx.ReqStop()
cts.rpx_mtx.Lock()
delete(cts.rpx_map, crpx.id)
cts.rpx_mtx.Unlock()
}
func (cts *ClientConn) StartRpx(id uint64, data []byte, wg *sync.WaitGroup) error {
var crpx *ClientRpx
var ok bool
cts.rpx_mtx.Lock()
_, ok = cts.rpx_map[id]
if ok {
cts.rpx_mtx.Unlock()
return fmt.Errorf("multiple start on rpx id %d", id)
}
crpx = &ClientRpx{ id: id }
cts.rpx_map[id] = crpx
// i want the pipe to be created before the goroutine is started
// so that the WriteRpx() can write to the pipe. i protect pipe creation
// and context creation with a mutex
crpx.pr, crpx.pw = io.Pipe()
crpx.ctx, crpx.cancel = context.WithCancel(cts.C.Ctx)
cts.rpx_mtx.Unlock()
wg.Add(1)
go cts.RpxLoop(crpx, data, wg)
return nil
}
func (cts *ClientConn) StopRpx(id uint64) error {
var crpx *ClientRpx
crpx = cts.FindClientRpxById(id)
if crpx == nil {
return fmt.Errorf("unknown rpx id %d", id)
}
crpx.ReqStop()
return nil
}
func (cts *ClientConn) WriteRpx(id uint64, data []byte) error {
var crpx *ClientRpx
var err error
crpx = cts.FindClientRpxById(id)
if crpx == nil {
return fmt.Errorf("unknown rpx id %d", id)
}
// TODO: may have to write it in a goroutine to avoid blocking?
_, err = crpx.pw.Write(data)
if err != nil {
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to write rpx(%d) data - %s", id, err.Error())
return err
}
return nil
}
func (cts *ClientConn) EofRpx(id uint64, data []byte) error {
var crpx *ClientRpx
crpx = cts.FindClientRpxById(id)
if crpx == nil {
return fmt.Errorf("unknown rpx id %d", id)
}
// close the writing end only. leave the reading end untouched
crpx.pw.Close()
return nil
}
func (cts *ClientConn) HandleRpxEvent(packet_type PACKET_KIND, evt *RpxEvent) error {
switch packet_type {
case PACKET_KIND_RPX_START:
return cts.StartRpx(evt.Id, evt.Data, &cts.C.wg)
case PACKET_KIND_RPX_STOP:
return cts.StopRpx(evt.Id)
case PACKET_KIND_RPX_DATA:
return cts.WriteRpx(evt.Id, evt.Data)
case PACKET_KIND_RPX_EOF:
return cts.EofRpx(evt.Id, evt.Data)
}
// ignore other packet types
return nil
}
// --------------------------------------------------------------------
func (m ClientPeerConnMap) get_sorted_keys() []PeerId {
@ -1631,12 +2026,13 @@ func (m ClientConnMap) get_sorted_keys() []ConnId {
type client_ctl_log_writer struct {
cli *Client
depth int
}
func (hlw *client_ctl_log_writer) Write(p []byte) (n int, err error) {
// the standard http.Server always requires *log.Logger
// use this iowriter to create a logger to pass it to the http server.
hlw.cli.log.Write("", LOG_INFO, string(p))
hlw.cli.log.WriteWithCallDepth("", LOG_INFO, hlw.depth, string(p))
return len(p), nil
}
@ -1696,13 +2092,13 @@ func (c *Client) WrapHttpHandler(handler ClientHttpHandler) http.Handler {
}
// TODO: statistics by status_code and end point types.
time_taken = time.Now().Sub(start_time)
time_taken = time.Since(start_time) //time.Now().Sub(start_time)
if status_code > 0 {
if err != nil {
c.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s %d %.9f - %s", req.RemoteAddr, req.Method, get_raw_url_path(req), status_code, time_taken.Seconds(), err.Error())
c.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s %d %.9f - %s", req.RemoteAddr, req.Method, req.RequestURI, status_code, time_taken.Seconds(), err.Error())
} else {
c.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s %d %.9f", req.RemoteAddr, req.Method, get_raw_url_path(req), status_code, time_taken.Seconds())
c.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s %d %.9f", req.RemoteAddr, req.Method, req.RequestURI, status_code, time_taken.Seconds())
}
}
})
@ -1714,7 +2110,7 @@ func (c *Client) SafeWrapWebsocketHandler(handler websocket.Handler) http.Handle
!strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade") {
var status_code int
status_code = WriteEmptyRespHeader(w, http.StatusBadRequest)
c.log.Write("", LOG_INFO, "[%s] %s %s %d[non-websocket]", req.RemoteAddr, req.Method, get_raw_url_path(req), status_code)
c.log.Write("", LOG_INFO, "[%s] %s %s %d[non-websocket]", req.RemoteAddr, req.Method, req.RequestURI, status_code)
return
}
handler.ServeHTTP(w, req)
@ -1728,21 +2124,19 @@ func (c *Client) WrapWebsocketHandler(handler ClientWebsocketHandler) websocket.
var start_time time.Time
var time_taken time.Duration
var req *http.Request
var raw_url_path string
req = ws.Request()
raw_url_path = get_raw_url_path(req)
c.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s [ws]", req.RemoteAddr, req.Method, raw_url_path)
c.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s [ws]", req.RemoteAddr, req.Method, req.RequestURI)
start_time = time.Now()
status_code, err = handler.ServeWebsocket(ws)
time_taken = time.Now().Sub(start_time)
time_taken = time.Since(start_time) // time.Now().Sub(start_time)
if status_code > 0 {
if err != nil {
c.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s [ws] %d %.9f - %s", req.RemoteAddr, req.Method, raw_url_path, status_code, time_taken.Seconds(), err.Error())
c.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s [ws] %d %.9f - %s", req.RemoteAddr, req.Method, req.RequestURI, status_code, time_taken.Seconds(), err.Error())
} else {
c.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s [ws] %d %.9f", req.RemoteAddr, req.Method, raw_url_path, status_code, time_taken.Seconds())
c.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s [ws] %d %.9f", req.RemoteAddr, req.Method, req.RequestURI, status_code, time_taken.Seconds())
}
}
})
@ -1770,6 +2164,14 @@ func NewClient(ctx context.Context, name string, logger Logger, cfg *ClientConfi
c.bulletin = NewBulletin[*ClientEvent](&c, 1024)
c.rpc_tls = cfg.RpcTls
c.rpx_target_addr = cfg.RpxTargetAddr
c.rpx_target_tls = cfg.RpxTargetTls
if c.rpx_target_tls != nil {
c.rpx_target_url = "https://" + c.rpx_target_addr
} else {
c.rpx_target_url = "http://" + c.rpx_target_addr
}
c.ctl_auth = cfg.CtlAuth
c.ctl_tls = cfg.CtlTls
c.ctl_prefix = cfg.CtlPrefix
@ -1843,13 +2245,13 @@ func NewClient(ctx context.Context, name string, logger Logger, cfg *ClientConfi
c.ctl = make([]*http.Server, len(cfg.CtlAddrs))
copy(c.ctl_addr, cfg.CtlAddrs)
hs_log = log.New(&client_ctl_log_writer{cli: &c}, "", 0)
hs_log = log.New(&client_ctl_log_writer{cli: &c, depth: 0}, "", 0)
for i = 0; i < len(cfg.CtlAddrs); i++ {
c.ctl[i] = &http.Server{
Addr: cfg.CtlAddrs[i],
Handler: c.ctl_mux,
TLSConfig: c.ctl_tls,
TLSConfig: c.ctl_tls.Clone(),
ErrorLog: hs_log,
// TODO: more settings
}
@ -1859,6 +2261,7 @@ func NewClient(ctx context.Context, name string, logger Logger, cfg *ClientConfi
c.stats.routes.Store(0)
c.stats.peers.Store(0)
c.stats.pty_sessions.Store(0)
c.stats.rpty_sessions.Store(0)
return &c
}
@ -2171,8 +2574,52 @@ func (c *Client) GetPtyShell() string {
return c.pty_shell
}
func (c *Client) RunCtlTask(wg *sync.WaitGroup) {
func (c *Client) run_single_ctl_server(i int, cs *http.Server, wg *sync.WaitGroup) {
var l net.Listener
var err error
defer wg.Done()
c.log.Write("", LOG_INFO, "Control channel[%d] started on %s", i, c.ctl_addr[i])
// defeat hard-coded "tcp" in ListenAndServe() and ListenAndServeTLS()
// by creating the listener explicitly.
// err = cs.ListenAndServe()
// err = cs.ListenAndServeTLS("", "") // c.tlscfg must provide a certificate and a key
//cs.shuttingDown(), as the name indicates, is not expoosed by the net/http
//so I have to use my own indicator to check if it's been shutdown..
//
if c.stop_req.Load() == false {
// this guard has a flaw in that the stop request can be made
// between the check above and net.Listen() below.
l, err = net.Listen(TcpAddrStrClass(cs.Addr), cs.Addr)
if err == nil {
if c.stop_req.Load() == false {
// check it again to make the guard slightly more stable
// although it's still possible that the stop request is made
// after Listen()
if c.ctl_tls == nil {
err = cs.Serve(l)
} else {
err = cs.ServeTLS(l, "", "") // c.ctl_tls must provide a certificate and a key
}
} else {
err = fmt.Errorf("stop requested")
}
l.Close()
}
} else {
err = fmt.Errorf("stop requested")
}
if errors.Is(err, http.ErrServerClosed) {
c.log.Write("", LOG_INFO, "Control channel[%d] ended", i)
} else {
c.log.Write("", LOG_ERROR, "Control channel[%d] error - %s", i, err.Error())
}
}
func (c *Client) RunCtlTask(wg *sync.WaitGroup) {
var ctl *http.Server
var idx int
var l_wg sync.WaitGroup
@ -2181,49 +2628,7 @@ func (c *Client) RunCtlTask(wg *sync.WaitGroup) {
for idx, ctl = range c.ctl {
l_wg.Add(1)
go func(i int, cs *http.Server) {
var l net.Listener
c.log.Write("", LOG_INFO, "Control channel[%d] started on %s", i, c.ctl_addr[i])
// defeat hard-coded "tcp" in ListenAndServe() and ListenAndServeTLS()
// by creating the listener explicitly.
// err = cs.ListenAndServe()
// err = cs.ListenAndServeTLS("", "") // c.tlscfg must provide a certificate and a key
//cs.shuttingDown(), as the name indicates, is not expoosed by the net/http
//so I have to use my own indicator to check if it's been shutdown..
//
if c.stop_req.Load() == false {
// this guard has a flaw in that the stop request can be made
// between the check above and net.Listen() below.
l, err = net.Listen(TcpAddrStrClass(cs.Addr), cs.Addr)
if err == nil {
if c.stop_req.Load() == false {
// check it again to make the guard slightly more stable
// although it's still possible that the stop request is made
// after Listen()
if c.ctl_tls == nil {
err = cs.Serve(l)
} else {
err = cs.ServeTLS(l, "", "") // c.ctl_tls must provide a certificate and a key
}
} else {
err = fmt.Errorf("stop requested")
}
l.Close()
}
} else {
err = fmt.Errorf("stop requested")
}
if errors.Is(err, http.ErrServerClosed) {
c.log.Write("", LOG_INFO, "Control channel[%d] ended", i)
} else {
c.log.Write("", LOG_ERROR, "Control channel[%d] error - %s", i, err.Error())
}
l_wg.Done()
}(idx, ctl)
go c.run_single_ctl_server(idx, ctl, &l_wg)
}
l_wg.Wait()
}

View File

@ -73,6 +73,12 @@ type RPXServiceConfig struct {
Addrs []string `yaml:"addresses"`
}
type RPXClientTokenConfig struct {
AttrName string `yaml:"attr-name"`
Regex string `yaml:"regex"`
SubmatchIndex int `yaml:"submatch-index"`
}
type PXYServiceConfig struct {
Addrs []string `yaml:"addresses"`
}
@ -127,8 +133,9 @@ type ServerConfig struct {
} `yaml:"ctl"`
RPX struct {
Service RPXServiceConfig `yaml:"service"`
TLS ServerTLSConfig `yaml:"tls"`
Service RPXServiceConfig `yaml:"service"`
TLS ServerTLSConfig `yaml:"tls"`
ClientToken RPXClientTokenConfig `yaml:"client-token"`
} `yaml:"rpx"`
PXY struct {
@ -158,6 +165,12 @@ type ClientConfig struct {
Endpoint RPCEndpointConfig `yaml:"endpoint"`
TLS ClientTLSConfig `yaml:"tls"`
} `yaml:"rpc"`
RPX struct {
Target struct {
Addr string `yaml:"address"`
TLS ClientTLSConfig `yaml:"tls"`
} `yaml:"target"`
}
}
func load_server_config_to(cfgfile string, cfg *ServerConfig) error {

View File

@ -121,7 +121,6 @@ main_loop:
}
}
func (l *AppLogger) Write(id string, level hodu.LogLevel, fmtstr string, args ...interface{}) {
if l.mask & hodu.LogMask(level) == 0 { return }
l.write(id, level, 1, fmtstr, args...)

View File

@ -10,6 +10,7 @@ import "net"
import "os"
import "os/signal"
import "path/filepath"
import "regexp"
import "strings"
import "sync"
import "syscall"
@ -131,6 +132,13 @@ func server_main(ctl_addrs []string, rpc_addrs []string, rpx_addrs[] string, pxy
if len(config.PxyAddrs) <= 0 { config.PxyAddrs = cfg.PXY.Service.Addrs }
if len(config.WpxAddrs) <= 0 { config.WpxAddrs = cfg.WPX.Service.Addrs }
config.RpxClientTokenAttrName = cfg.RPX.ClientToken.AttrName
if cfg.RPX.ClientToken.Regex != "" {
config.RpxClientTokenRegex, err = regexp.Compile(cfg.RPX.ClientToken.Regex)
if err != nil { return err }
}
config.RpxClientTokenSubmatchIndex = cfg.RPX.ClientToken.SubmatchIndex
config.CtlCors = cfg.CTL.Service.Cors
config.CtlAuth, err = make_http_auth_config(&cfg.CTL.Service.Auth)
if err != nil { return err }
@ -282,10 +290,13 @@ func client_main(ctl_addrs []string, rpc_addrs []string, route_configs []string,
if err != nil { return err }
config.RpcTls, err = make_tls_client_config(&cfg.RPC.TLS)
if err != nil { return err }
config.RpxTargetTls, err = make_tls_client_config(&cfg.RPX.Target.TLS)
if err != nil { return err }
if len(rpc_addrs) <= 0 { rpc_addrs = cfg.RPC.Endpoint.Addrs }
if len(config.CtlAddrs) <= 0 { config.CtlAddrs = cfg.CTL.Service.Addrs }
config.RpxTargetAddr = cfg.RPX.Target.Addr
config.CtlPrefix = cfg.CTL.Service.Prefix
config.CtlCors = cfg.CTL.Service.Cors
config.CtlAuth, err = make_http_auth_config(&cfg.CTL.Service.Auth)

87
hodu.go
View File

@ -1,5 +1,6 @@
package hodu
import "bytes"
import "crypto/rsa"
import _ "embed"
import "encoding/base64"
@ -8,6 +9,7 @@ import "net"
import "net/http"
import "net/netip"
import "os"
import "regexp"
import "runtime"
import "strings"
import "sync"
@ -79,11 +81,6 @@ type JsonErrmsg struct {
Text string `json:"error-text"`
}
type json_in_cred struct {
Username string `json:"username"`
Password string `json:"password"`
}
type json_in_notice struct {
Text string `json:"text"`
}
@ -224,8 +221,9 @@ func (option RouteOption) String() string {
func dump_call_frame_and_exit(log Logger, req *http.Request, err interface{}) {
var buf []byte
buf = make([]byte, 65536); buf = buf[:min(65536, runtime.Stack(buf, false))]
log.Write("", LOG_ERROR, "[%s] %s %s - %v\n%s", req.RemoteAddr, req.Method, get_raw_url_path(req), err, string(buf))
buf = make([]byte, 65536)
buf = buf[:min(65536, runtime.Stack(buf, false))]
log.Write("", LOG_ERROR, "[%s] %s %s - %v\n%s", req.RemoteAddr, req.Method, req.RequestURI, err, string(buf))
log.Close()
os.Exit(99) // fatal error. treat panic() as a fatal runtime error
}
@ -467,12 +465,75 @@ func (auth *HttpAuthConfig) Authenticate(req *http.Request) (int, string) {
return http.StatusOK, ""
}
// ------------------------------------
func get_raw_url_path(req *http.Request) string {
var path string
path = req.URL.Path
if req.URL.RawQuery != "" { path += "?" + req.URL.RawQuery }
return path
func get_http_req_line_and_headers(r *http.Request, force_host bool) []byte {
var buf bytes.Buffer
var name string
var value string
var values []string
var host_found bool
fmt.Fprintf(&buf, "%s %s %s\r\n", r.Method, r.RequestURI, r.Proto)
for name, values = range r.Header {
if strings.EqualFold(name, "Accept-Encoding") { // TODO: make it generic. parameterize it??
// skip Accept-Encoding as the go client side
// doesn't function properly when a certain enconding
// is specified. resp.Body.Read() returned EOF when
// not working
continue
} else if strings.EqualFold(name, "Host") {
host_found = true
}
for _, value = range values {
fmt.Fprintf(&buf, "%s: %s\r\n", name, value)
}
}
if force_host && !host_found && r.Host != "" {
fmt.Fprintf(&buf, "Host: %s\r\n", r.Host)
}
// TODO: host and x-forwarded-for, x-forwarded-proto, etc???
buf.WriteString("\r\n") // End of headers
return buf.Bytes()
}
func get_http_resp_line_and_headers(r *http.Response) []byte {
var buf bytes.Buffer
var name string
var value string
var values []string
fmt.Fprintf(&buf, "%s %s\r\n", r.Proto, r.Status)
for name, values = range r.Header {
for _, value = range values {
fmt.Fprintf(&buf, "%s: %s\r\n", name, value)
}
}
buf.WriteString("\r\n") // End of headers
return buf.Bytes()
}
func get_regex_submatch(re *regexp.Regexp, str string, n int) string {
var idxs []int
var pos int
var start int
var end int
idxs = re.FindStringSubmatchIndex(str)
if idxs == nil { return "" }
pos = n * 2
if pos + 1 >= len(idxs) { return "" }
start, end = idxs[pos], idxs[pos + 1]
if start == -1 || end == -1 {
return ""
}
return str[start:end]
}

View File

@ -101,7 +101,11 @@ const (
PACKET_KIND_RPTY_START PACKET_KIND = 14
PACKET_KIND_RPTY_STOP PACKET_KIND = 15
PACKET_KIND_RPTY_DATA PACKET_KIND = 16
PACKET_KIND_RPTY_SIZE PACKET_KIND = 17
PACKET_KIND_RPTY_SIZE PACKET_KIND = 17 // terminal size
PACKET_KIND_RPX_START PACKET_KIND = 18
PACKET_KIND_RPX_STOP PACKET_KIND = 19
PACKET_KIND_RPX_DATA PACKET_KIND = 20
PACKET_KIND_RPX_EOF PACKET_KIND = 21
)
// Enum value maps for PACKET_KIND.
@ -124,6 +128,10 @@ var (
15: "RPTY_STOP",
16: "RPTY_DATA",
17: "RPTY_SIZE",
18: "RPX_START",
19: "RPX_STOP",
20: "RPX_DATA",
21: "RPX_EOF",
}
PACKET_KIND_value = map[string]int32{
"RESERVED": 0,
@ -143,6 +151,10 @@ var (
"RPTY_STOP": 15,
"RPTY_DATA": 16,
"RPTY_SIZE": 17,
"RPX_START": 18,
"RPX_STOP": 19,
"RPX_DATA": 20,
"RPX_EOF": 21,
}
)
@ -643,6 +655,58 @@ func (x *RptyEvent) GetData() []byte {
return nil
}
type RpxEvent struct {
state protoimpl.MessageState `protogen:"open.v1"`
Id uint64 `protobuf:"varint,1,opt,name=Id,proto3" json:"Id,omitempty"`
Data []byte `protobuf:"bytes,2,opt,name=Data,proto3" json:"Data,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *RpxEvent) Reset() {
*x = RpxEvent{}
mi := &file_hodu_proto_msgTypes[8]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *RpxEvent) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*RpxEvent) ProtoMessage() {}
func (x *RpxEvent) ProtoReflect() protoreflect.Message {
mi := &file_hodu_proto_msgTypes[8]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use RpxEvent.ProtoReflect.Descriptor instead.
func (*RpxEvent) Descriptor() ([]byte, []int) {
return file_hodu_proto_rawDescGZIP(), []int{8}
}
func (x *RpxEvent) GetId() uint64 {
if x != nil {
return x.Id
}
return 0
}
func (x *RpxEvent) GetData() []byte {
if x != nil {
return x.Data
}
return nil
}
type Packet struct {
state protoimpl.MessageState `protogen:"open.v1"`
Kind PACKET_KIND `protobuf:"varint,1,opt,name=Kind,proto3,enum=PACKET_KIND" json:"Kind,omitempty"`
@ -655,6 +719,7 @@ type Packet struct {
// *Packet_ConnErr
// *Packet_ConnNoti
// *Packet_RptyEvt
// *Packet_RpxEvt
U isPacket_U `protobuf_oneof:"U"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
@ -662,7 +727,7 @@ type Packet struct {
func (x *Packet) Reset() {
*x = Packet{}
mi := &file_hodu_proto_msgTypes[8]
mi := &file_hodu_proto_msgTypes[9]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@ -674,7 +739,7 @@ func (x *Packet) String() string {
func (*Packet) ProtoMessage() {}
func (x *Packet) ProtoReflect() protoreflect.Message {
mi := &file_hodu_proto_msgTypes[8]
mi := &file_hodu_proto_msgTypes[9]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@ -687,7 +752,7 @@ func (x *Packet) ProtoReflect() protoreflect.Message {
// Deprecated: Use Packet.ProtoReflect.Descriptor instead.
func (*Packet) Descriptor() ([]byte, []int) {
return file_hodu_proto_rawDescGZIP(), []int{8}
return file_hodu_proto_rawDescGZIP(), []int{9}
}
func (x *Packet) GetKind() PACKET_KIND {
@ -767,6 +832,15 @@ func (x *Packet) GetRptyEvt() *RptyEvent {
return nil
}
func (x *Packet) GetRpxEvt() *RpxEvent {
if x != nil {
if x, ok := x.U.(*Packet_RpxEvt); ok {
return x.RpxEvt
}
}
return nil
}
type isPacket_U interface {
isPacket_U()
}
@ -799,6 +873,10 @@ type Packet_RptyEvt struct {
RptyEvt *RptyEvent `protobuf:"bytes,8,opt,name=RptyEvt,proto3,oneof"`
}
type Packet_RpxEvt struct {
RpxEvt *RpxEvent `protobuf:"bytes,9,opt,name=RpxEvt,proto3,oneof"`
}
func (*Packet_Route) isPacket_U() {}
func (*Packet_Peer) isPacket_U() {}
@ -813,6 +891,8 @@ func (*Packet_ConnNoti) isPacket_U() {}
func (*Packet_RptyEvt) isPacket_U() {}
func (*Packet_RpxEvt) isPacket_U() {}
var File_hodu_proto protoreflect.FileDescriptor
const file_hodu_proto_rawDesc = "" +
@ -850,7 +930,10 @@ const file_hodu_proto_rawDesc = "" +
"\x04Text\x18\x01 \x01(\tR\x04Text\"/\n" +
"\tRptyEvent\x12\x0e\n" +
"\x02Id\x18\x01 \x01(\x04R\x02Id\x12\x12\n" +
"\x04Data\x18\x02 \x01(\fR\x04Data\"\xb1\x02\n" +
"\x04Data\x18\x02 \x01(\fR\x04Data\".\n" +
"\bRpxEvent\x12\x0e\n" +
"\x02Id\x18\x01 \x01(\x04R\x02Id\x12\x12\n" +
"\x04Data\x18\x02 \x01(\fR\x04Data\"\xd6\x02\n" +
"\x06Packet\x12 \n" +
"\x04Kind\x18\x01 \x01(\x0e2\f.PACKET_KINDR\x04Kind\x12\"\n" +
"\x05Route\x18\x02 \x01(\v2\n" +
@ -862,7 +945,8 @@ const file_hodu_proto_rawDesc = "" +
".ConnErrorH\x00R\aConnErr\x12)\n" +
"\bConnNoti\x18\a \x01(\v2\v.ConnNoticeH\x00R\bConnNoti\x12&\n" +
"\aRptyEvt\x18\b \x01(\v2\n" +
".RptyEventH\x00R\aRptyEvtB\x03\n" +
".RptyEventH\x00R\aRptyEvt\x12#\n" +
"\x06RpxEvt\x18\t \x01(\v2\t.RpxEventH\x00R\x06RpxEvtB\x03\n" +
"\x01U*U\n" +
"\fROUTE_OPTION\x12\n" +
"\n" +
@ -872,7 +956,7 @@ const file_hodu_proto_rawDesc = "" +
"\x04TCP6\x10\x04\x12\b\n" +
"\x04HTTP\x10\b\x12\t\n" +
"\x05HTTPS\x10\x10\x12\a\n" +
"\x03SSH\x10 *\xa2\x02\n" +
"\x03SSH\x10 *\xda\x02\n" +
"\vPACKET_KIND\x12\f\n" +
"\bRESERVED\x10\x00\x12\x0f\n" +
"\vROUTE_START\x10\x01\x12\x0e\n" +
@ -893,7 +977,11 @@ const file_hodu_proto_rawDesc = "" +
"RPTY_START\x10\x0e\x12\r\n" +
"\tRPTY_STOP\x10\x0f\x12\r\n" +
"\tRPTY_DATA\x10\x10\x12\r\n" +
"\tRPTY_SIZE\x10\x112I\n" +
"\tRPTY_SIZE\x10\x11\x12\r\n" +
"\tRPX_START\x10\x12\x12\f\n" +
"\bRPX_STOP\x10\x13\x12\f\n" +
"\bRPX_DATA\x10\x14\x12\v\n" +
"\aRPX_EOF\x10\x152I\n" +
"\x04Hodu\x12\x19\n" +
"\aGetSeed\x12\x05.Seed\x1a\x05.Seed\"\x00\x12&\n" +
"\fPacketStream\x12\a.Packet\x1a\a.Packet\"\x00(\x010\x01B\bZ\x06./hodub\x06proto3"
@ -911,7 +999,7 @@ func file_hodu_proto_rawDescGZIP() []byte {
}
var file_hodu_proto_enumTypes = make([]protoimpl.EnumInfo, 2)
var file_hodu_proto_msgTypes = make([]protoimpl.MessageInfo, 9)
var file_hodu_proto_msgTypes = make([]protoimpl.MessageInfo, 10)
var file_hodu_proto_goTypes = []any{
(ROUTE_OPTION)(0), // 0: ROUTE_OPTION
(PACKET_KIND)(0), // 1: PACKET_KIND
@ -923,7 +1011,8 @@ var file_hodu_proto_goTypes = []any{
(*ConnError)(nil), // 7: ConnError
(*ConnNotice)(nil), // 8: ConnNotice
(*RptyEvent)(nil), // 9: RptyEvent
(*Packet)(nil), // 10: Packet
(*RpxEvent)(nil), // 10: RpxEvent
(*Packet)(nil), // 11: Packet
}
var file_hodu_proto_depIdxs = []int32{
1, // 0: Packet.Kind:type_name -> PACKET_KIND
@ -934,15 +1023,16 @@ var file_hodu_proto_depIdxs = []int32{
7, // 5: Packet.ConnErr:type_name -> ConnError
8, // 6: Packet.ConnNoti:type_name -> ConnNotice
9, // 7: Packet.RptyEvt:type_name -> RptyEvent
2, // 8: Hodu.GetSeed:input_type -> Seed
10, // 9: Hodu.PacketStream:input_type -> Packet
2, // 10: Hodu.GetSeed:output_type -> Seed
10, // 11: Hodu.PacketStream:output_type -> Packet
10, // [10:12] is the sub-list for method output_type
8, // [8:10] is the sub-list for method input_type
8, // [8:8] is the sub-list for extension type_name
8, // [8:8] is the sub-list for extension extendee
0, // [0:8] is the sub-list for field type_name
10, // 8: Packet.RpxEvt:type_name -> RpxEvent
2, // 9: Hodu.GetSeed:input_type -> Seed
11, // 10: Hodu.PacketStream:input_type -> Packet
2, // 11: Hodu.GetSeed:output_type -> Seed
11, // 12: Hodu.PacketStream:output_type -> Packet
11, // [11:13] is the sub-list for method output_type
9, // [9:11] is the sub-list for method input_type
9, // [9:9] is the sub-list for extension type_name
9, // [9:9] is the sub-list for extension extendee
0, // [0:9] is the sub-list for field type_name
}
func init() { file_hodu_proto_init() }
@ -950,7 +1040,7 @@ func file_hodu_proto_init() {
if File_hodu_proto != nil {
return
}
file_hodu_proto_msgTypes[8].OneofWrappers = []any{
file_hodu_proto_msgTypes[9].OneofWrappers = []any{
(*Packet_Route)(nil),
(*Packet_Peer)(nil),
(*Packet_Data)(nil),
@ -958,6 +1048,7 @@ func file_hodu_proto_init() {
(*Packet_ConnErr)(nil),
(*Packet_ConnNoti)(nil),
(*Packet_RptyEvt)(nil),
(*Packet_RpxEvt)(nil),
}
type x struct{}
out := protoimpl.TypeBuilder{
@ -965,7 +1056,7 @@ func file_hodu_proto_init() {
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: unsafe.Slice(unsafe.StringData(file_hodu_proto_rawDesc), len(file_hodu_proto_rawDesc)),
NumEnums: 2,
NumMessages: 9,
NumMessages: 10,
NumExtensions: 0,
NumServices: 1,
},

View File

@ -85,6 +85,11 @@ message RptyEvent {
bytes Data = 2;
};
message RpxEvent {
uint64 Id = 1;
bytes Data = 2;
};
enum PACKET_KIND {
RESERVED = 0; // not used
ROUTE_START = 1;
@ -103,7 +108,12 @@ enum PACKET_KIND {
RPTY_START = 14;
RPTY_STOP = 15;
RPTY_DATA = 16;
RPTY_SIZE = 17;
RPTY_SIZE = 17; // terminal size
RPX_START = 18;
RPX_STOP = 19;
RPX_DATA = 20;
RPX_EOF = 21;
};
message Packet {
@ -117,5 +127,6 @@ message Packet {
ConnError ConnErr = 6;
ConnNotice ConnNoti = 7;
RptyEvent RptyEvt = 8;
RpxEvent RpxEvt = 9;
};
}

View File

@ -80,6 +80,7 @@ func MakeRptyStartPacket(id uint64) *Packet {
}
func MakeRptyStopPacket(id uint64, msg string) *Packet {
// the rpty stop conveys an error/info message
return &Packet{Kind: PACKET_KIND_RPTY_STOP, U: &Packet_RptyEvt{RptyEvt: &RptyEvent{Id: id, Data: []byte(msg)}}}
}
@ -90,3 +91,21 @@ func MakeRptyDataPacket(id uint64, data []byte) *Packet {
func MakeRptySizePacket(id uint64, data []byte) *Packet {
return &Packet{Kind: PACKET_KIND_RPTY_SIZE, U: &Packet_RptyEvt{RptyEvt: &RptyEvent{Id: id, Data: data}}}
}
func MakeRpxStartPacket(id uint64, hdr_part []byte) *Packet {
// the rpx start conveys the data unlike other Start packets...
return &Packet{Kind: PACKET_KIND_RPX_START, U: &Packet_RpxEvt{RpxEvt: &RpxEvent{Id: id, Data: hdr_part}}}
}
func MakeRpxStopPacket(id uint64) *Packet {
// the rpx start conveys the data unlike other Start packets...
return &Packet{Kind: PACKET_KIND_RPX_STOP, U: &Packet_RpxEvt{RpxEvt: &RpxEvent{Id: id}}}
}
func MakeRpxDataPacket(id uint64, data_part []byte) *Packet {
return &Packet{Kind: PACKET_KIND_RPX_DATA, U: &Packet_RpxEvt{RpxEvt: &RpxEvent{Id: id, Data: data_part}}}
}
func MakeRpxEofPacket(id uint64) *Packet {
return &Packet{Kind: PACKET_KIND_RPX_EOF, U: &Packet_RpxEvt{RpxEvt: &RpxEvent{Id: id}}}
}

View File

@ -76,7 +76,8 @@ type json_out_server_stats struct {
ServerPeers int64 `json:"server-peers"`
SshProxySessions int64 `json:"pxy-ssh-sessions"`
ServerPtySessions int64 `json:"server-pty-sessions"`
ServerPtySessions int64 `json:"server-pty-sessions"`
ServerRptySessions int64 `json:"server-rpty-sessions"`
}
// this is a more specialized variant of json_in_notice
@ -921,6 +922,7 @@ func (ctl *server_ctl_stats) ServeHTTP(w http.ResponseWriter, req *http.Request)
stats.ServerPeers = s.stats.peers.Load()
stats.SshProxySessions = s.stats.ssh_proxy_sessions.Load()
stats.ServerPtySessions = s.stats.pty_sessions.Load()
stats.ServerRptySessions = s.stats.rpty_sessions.Load()
status_code = WriteJsonRespHeader(w, http.StatusOK)
if err = je.Encode(stats); err != nil { goto oops }

View File

@ -12,6 +12,7 @@ type ServerCollector struct {
ServerPeers *prometheus.Desc
SshProxySessions *prometheus.Desc
PtySessions *prometheus.Desc
RptySessions *prometheus.Desc
}
// NewServerCollector returns a new ServerCollector with all prometheus.Desc initialized
@ -58,6 +59,11 @@ func NewServerCollector(server *Server) ServerCollector {
"Number of pty session",
nil, nil,
),
RptySessions: prometheus.NewDesc(
prefix + "rpty_sessions",
"Number of rpty session",
nil, nil,
),
}
}
@ -68,6 +74,7 @@ func (c ServerCollector) Describe(ch chan<- *prometheus.Desc) {
ch <- c.ServerPeers
ch <- c.SshProxySessions
ch <- c.PtySessions
ch <- c.RptySessions
}
func (c ServerCollector) Collect(ch chan<- prometheus.Metric) {
@ -110,4 +117,10 @@ func (c ServerCollector) Collect(ch chan<- prometheus.Metric) {
prometheus.GaugeValue,
float64(c.server.stats.pty_sessions.Load()),
)
ch <- prometheus.MustNewConstMetric(
c.RptySessions,
prometheus.GaugeValue,
float64(c.server.stats.rpty_sessions.Load()),
)
}

View File

@ -102,6 +102,16 @@ wait_for_started:
for {
n, err = spc.conn.Read(buf[:])
if n > 0 {
var err2 error
err2 = pss.Send(MakePeerDataPacket(spc.route.Id, spc.conn_id, buf[:n]))
if err2 != nil {
spc.route.Cts.S.log.Write(spc.route.Cts.Sid, LOG_ERROR,
"Failed to send data from peer(%d,%d,%s,%s) to client - %s",
spc.route.Id, spc.conn_id, conn_raddr, conn_laddr, err2.Error())
goto done
}
}
if err != nil {
if errors.Is(err, io.EOF) || strings.Contains(err.Error(), "use of closed network connection") { // i don't like this way to check this error.
err = pss.Send(MakePeerEofPacket(spc.route.Id, spc.conn_id))
@ -119,14 +129,6 @@ wait_for_started:
goto done
}
}
err = pss.Send(MakePeerDataPacket(spc.route.Id, spc.conn_id, buf[:n]))
if err != nil {
spc.route.Cts.S.log.Write(spc.route.Cts.Sid, LOG_ERROR,
"Failed to send data from peer(%d,%d,%s,%s) to client - %s",
spc.route.Id, spc.conn_id, conn_raddr, conn_laddr, err.Error())
goto done
}
}
wait_for_stopped:

View File

@ -67,7 +67,7 @@ func (pty *server_pty_ws) ServeWebsocket(ws *websocket.Conn) (int, error) {
conn_ready = <-conn_ready_chan
if conn_ready { // connected
var poll_fds []unix.PollFd
var buf []byte
var buf [2048]byte
var n int
var err error
@ -76,7 +76,6 @@ func (pty *server_pty_ws) ServeWebsocket(ws *websocket.Conn) (int, error) {
}
s.stats.pty_sessions.Add(1)
buf = make([]byte, 2048)
for {
n, err = unix.Poll(poll_fds, -1) // -1 means wait indefinitely
if err != nil {
@ -94,20 +93,21 @@ func (pty *server_pty_ws) ServeWebsocket(ws *websocket.Conn) (int, error) {
}
if (poll_fds[0].Revents & unix.POLLIN) != 0 {
n, err = out.Read(buf)
n, err = out.Read(buf[:])
if n > 0 {
var err2 error
err2 = send_ws_data_for_xterm(ws, "iov", string(buf[:n]))
if err2 != nil {
s.log.Write(pty.Id, LOG_ERROR, "[%s] Failed to send to websocket - %s", req.RemoteAddr, err2.Error())
break
}
}
if err != nil {
if !errors.Is(err, io.EOF) {
s.log.Write(pty.Id, LOG_ERROR, "[%s] Failed to read pty stdout - %s", req.RemoteAddr, err.Error())
}
break
}
if n > 0 {
err = send_ws_data_for_xterm(ws, "iov", string(buf[:n]))
if err != nil {
s.log.Write(pty.Id, LOG_ERROR, "[%s] Failed to send to websocket - %s", req.RemoteAddr, err.Error())
break
}
}
}
}
s.stats.pty_sessions.Add(-1)

View File

@ -375,7 +375,7 @@ func (pxy *server_pxy_http_main) ServeHTTP(w http.ResponseWriter, req *http.Requ
}
proxy_url = pxy.req_to_proxy_url(req, pi)
s.log.Write(pxy.Id, LOG_INFO, "[%s] %s %s -> %+v", req.RemoteAddr, req.Method, get_raw_url_path(req), proxy_url)
s.log.Write(pxy.Id, LOG_INFO, "[%s] %s %s -> %+v", req.RemoteAddr, req.Method, req.RequestURI, proxy_url)
proxy_req, err = http.NewRequestWithContext(s.Ctx, req.Method, proxy_url.String(), req.Body)
if err != nil {
@ -401,7 +401,7 @@ func (pxy *server_pxy_http_main) ServeHTTP(w http.ResponseWriter, req *http.Requ
} else {
status_code = resp.StatusCode
if upgrade_required && resp.StatusCode == http.StatusSwitchingProtocols {
s.log.Write(pxy.Id, LOG_INFO, "[%s] %s %s %d", req.RemoteAddr, req.Method, get_raw_url_path(req), status_code)
s.log.Write(pxy.Id, LOG_INFO, "[%s] %s %s %d", req.RemoteAddr, req.Method, req.RequestURI, status_code)
err = pxy.serve_upgraded(w, req, resp)
if err != nil { goto oops }
return 0, nil// print the log mesage before calling serve_upgraded() and exit here
@ -426,7 +426,7 @@ func (pxy *server_pxy_http_main) ServeHTTP(w http.ResponseWriter, req *http.Requ
_, err = io.Copy(w, resp_body)
if err != nil {
s.log.Write(pxy.Id, LOG_WARN, "[%s] %s %s %s", req.RemoteAddr, req.Method, get_raw_url_path(req), err.Error())
s.log.Write(pxy.Id, LOG_WARN, "[%s] %s %s %s", req.RemoteAddr, req.Method, req.RequestURI, err.Error())
}
// TODO: handle trailers
@ -689,27 +689,27 @@ func (pxy *server_pxy_ssh_ws) ServeWebsocket(ws *websocket.Conn) (int, error) {
conn_ready = <-conn_ready_chan
if conn_ready { // connected
var buf []byte
var buf [2048]byte
var n int
var err error
s.stats.ssh_proxy_sessions.Add(1)
buf = make([]byte, 2048)
for {
n, err = out.Read(buf)
n, err = out.Read(buf[:])
if n > 0 {
var err2 error
err2 = send_ws_data_for_xterm(ws, "iov", string(buf[:n]))
if err2 != nil {
s.log.Write(pxy.Id, LOG_ERROR, "[%s] Failed to send to websocket - %s", req.RemoteAddr, err2.Error())
break
}
}
if err != nil {
if !errors.Is(err, io.EOF) {
s.log.Write(pxy.Id, LOG_ERROR, "[%s] Failed to read from SSH stdout - %s", req.RemoteAddr, err.Error())
}
break
}
if n > 0 {
err = send_ws_data_for_xterm(ws, "iov", string(buf[:n]))
if err != nil {
s.log.Write(pxy.Id, LOG_ERROR, "[%s] Failed to send to websocket - %s", req.RemoteAddr, err.Error())
break
}
}
}
s.stats.ssh_proxy_sessions.Add(-1)
}

View File

@ -1,6 +1,15 @@
package hodu
import "bufio"
import "bytes"
import "errors"
import "fmt"
import "io"
import "net"
import "net/http"
import "strconv"
import "strings"
import "sync"
type server_rpx struct {
S *Server
@ -8,28 +17,313 @@ type server_rpx struct {
}
// ------------------------------------
func (pxy *server_rpx) Identity() string {
return pxy.Id
func (rpx *server_rpx) Identity() string {
return rpx.Id
}
func (pxy *server_rpx) Cors(req *http.Request) bool {
func (rpx *server_rpx) Cors(req *http.Request) bool {
return false
}
func (pxy *server_rpx) Authenticate(req *http.Request) (int, string) {
func (rpx *server_rpx) Authenticate(req *http.Request) (int, string) {
return http.StatusOK, ""
}
func (pxy *server_rpx) ServeHTTP(w http.ResponseWriter, req *http.Request) (int, error) {
var status_code int
// var err error
func (rpx *server_rpx) get_client_token(req *http.Request) string {
var val string
status_code = http.StatusOK
// TODO: enhance this client token extraction logic with some expression language?
val = req.Header.Get(rpx.S.Cfg.RpxClientTokenAttrName)
if val == "" { val = req.Host }
//done:
return status_code, nil
if rpx.S.Cfg.RpxClientTokenRegex != nil {
val = get_regex_submatch(rpx.S.Cfg.RpxClientTokenRegex, val, rpx.S.Cfg.RpxClientTokenSubmatchIndex)
}
//oops:
// return status_code, err
return val
}
func (rpx* server_rpx) handle_header_data(rpx_id uint64, data []byte, w http.ResponseWriter) (int, error) {
var sc *bufio.Scanner
var line string
var flds []string
var status_code int
var err error
sc = bufio.NewScanner(bytes.NewReader(data))
sc.Scan()
line = sc.Text()
flds = strings.Fields(line)
if (len(flds) < 2) { // i care about the status code..
return http.StatusBadGateway, fmt.Errorf("invalid response status for rpx(%d) - %s", rpx_id, line)
}
status_code, err = strconv.Atoi(flds[1])
if err != nil {
return http.StatusBadGateway, fmt.Errorf("invalid response code for rpx(%d) - %s", rpx_id, err.Error())
}
for sc.Scan() {
line = sc.Text()
if line == "" { break }
flds = strings.SplitN(line, ":", 2)
if len(flds) == 2 {
w.Header().Add(strings.TrimSpace(flds[0]), strings.TrimSpace(flds[1]))
}
}
err = sc.Err()
if err != nil {
return http.StatusBadGateway, fmt.Errorf("failed to parse response for rpx(%d) - %s", rpx_id, err.Error())
}
w.WriteHeader(status_code)
return status_code, nil
}
func (rpx *server_rpx) handle_response(srpx *ServerRpx, req *http.Request, w http.ResponseWriter, ws_upgrade bool, wg *sync.WaitGroup) {
var start_resp []byte
var status_code int
var buf [4096]byte
var n int
var wr io.Writer
var wrote_br_chan bool
var err error
defer wg.Done()
select {
case start_resp = <- srpx.start_chan:
// received the header. ready to proceed to the body
// do nothing. just continue
status_code, err = rpx.handle_header_data(srpx.id, start_resp, w)
if err != nil { goto done }
case <- srpx.done_chan:
err = fmt.Errorf("rpx(%d) terminated before receiving header", srpx.id)
status_code = http.StatusBadGateway
goto done
case <- req.Context().Done():
err = fmt.Errorf("rpx(%d) terminated before receiving header - %s", srpx.id, req.Context().Err().Error())
status_code = http.StatusBadGateway
goto done
// no default. block
}
if ws_upgrade && status_code == http.StatusSwitchingProtocols {
var hijk http.Hijacker
var conn net.Conn
var ok bool
hijk, ok = w.(http.Hijacker)
if !ok {
err = fmt.Errorf("failed to upgrade rpx(%d) - not a hijacker", srpx.id)
status_code = http.StatusInternalServerError
goto done
}
conn, _, err = hijk.Hijack()
if err != nil {
err = fmt.Errorf("failed to upgrade rpx(%d) - %s", srpx.id, err.Error())
status_code = http.StatusInternalServerError
goto done
}
// websocket upgrade is successful
srpx.br = conn
srpx.br_chan <- true // inform another goroutine that the protocol switching is completed.
wrote_br_chan = true
wr = conn
} else {
if ws_upgrade {
srpx.br_chan <- false
wrote_br_chan = true
} // indicate upgrade failure
wr = w
}
for {
n, err = srpx.pr.Read(buf[:])
if n > 0 {
var err2 error
_, err2 = wr.Write(buf[:n])
if err2 != nil {
err = err2
status_code = http.StatusInternalServerError
break
}
}
if err != nil {
if errors.Is(err, io.EOF) {
err = nil
} else {
status_code = http.StatusInternalServerError
}
break
}
}
done:
// just send another in case the code got jump into this part for an error
// may not be consumed but the channel is large enough for redundant data
srpx.resp_status_code = status_code
srpx.resp_error = err
if ws_upgrade && !wrote_br_chan {
srpx.br_chan <- false
}
}
func (rpx *server_rpx) alloc_server_rpx(cts *ServerConn, req *http.Request) (*ServerRpx, error) {
var srpx *ServerRpx
var start_id uint64
var assigned_id uint64
var ok bool
cts.rpx_mtx.Lock()
start_id = cts.rpx_next_id
for {
_, ok = cts.rpx_map[cts.rpx_next_id]
if !ok {
assigned_id = cts.rpx_next_id
cts.rpx_next_id++
if cts.rpx_next_id == 0 { cts.rpx_next_id++ }
break
}
cts.rpx_next_id++
if cts.rpx_next_id == 0 { cts.rpx_next_id++ }
if cts.rpx_next_id == start_id {
// unlikely to happen but it cycled through the whole range.
cts.rpx_mtx.Unlock()
return nil, fmt.Errorf("failed to assign id")
}
}
srpx = &ServerRpx{
id: assigned_id,
start_chan: make(chan []byte, 5),
done_chan: make(chan bool, 5),
br_chan: make(chan bool, 5),
}
srpx.br = req.Body
srpx.pr, srpx.pw = io.Pipe()
cts.rpx_map[assigned_id] = srpx
cts.rpx_mtx.Unlock()
return srpx, nil
}
func (rpx *server_rpx) ServeHTTP(w http.ResponseWriter, req *http.Request) (int, error) {
var s *Server
var client_token string
var start_sent bool
var cts *ServerConn
var status_code int
var srpx *ServerRpx
var ws_upgrade bool
var buf [4096]byte
var wg sync.WaitGroup
var err error
s = rpx.S
client_token = rpx.get_client_token(req)
cts = s.FindServerConnByClientToken(client_token)
if cts == nil {
status_code = WriteEmptyRespHeader(w, http.StatusNotFound)
err = fmt.Errorf("unknown client token - %s", client_token)
goto oops
}
srpx, err = rpx.alloc_server_rpx(cts, req)
if err != nil {
status_code = WriteEmptyRespHeader(w, http.StatusServiceUnavailable)
err = fmt.Errorf("unable to allocate rpx - %s", err.Error())
goto oops
}
// arrange to clear the rpx_map entry when this function exits
defer func() {
cts.rpx_mtx.Lock()
delete(cts.rpx_map, srpx.id)
cts.rpx_mtx.Unlock()
}()
ws_upgrade = strings.EqualFold(req.Header.Get("Upgrade"), "websocket") && strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade");
if ws_upgrade && req.ContentLength > 0 {
// while other webservers are ok with upgrade request with body payload,
// this program rejects such a request for impelementation limitation as
// it's not dealing with a raw byte but is using the standard web server handler.
status_code = WriteEmptyRespHeader(w, http.StatusBadRequest)
err = fmt.Errorf("failed to assign id")
goto oops
}
err = cts.pss.Send(MakeRpxStartPacket(srpx.id, get_http_req_line_and_headers(req, true)))
if err != nil {
status_code = WriteEmptyRespHeader(w, http.StatusBadGateway)
goto oops
}
start_sent = true
wg.Add(1)
go rpx.handle_response(srpx, req, w, ws_upgrade, &wg)
if ws_upgrade {
// wait until the protocol switching is done in rpx.handle_response()
var upgraded bool
upgraded = <- srpx.br_chan
if upgraded {
// arrange to close the hijacked connection inside rpx.handle_response()
defer srpx.br.Close()
}
}
for {
var n int
n, err = srpx.br.Read(buf[:])
if n > 0 {
var err2 error
err2 = cts.pss.Send(MakeRpxDataPacket(srpx.id, buf[:n]))
if err2 != nil {
status_code = WriteEmptyRespHeader(w, http.StatusBadGateway)
goto oops
}
}
if err != nil {
if errors.Is(err, io.EOF) {
err = cts.pss.Send(MakeRpxEofPacket(srpx.id))
if err != nil {
status_code = WriteEmptyRespHeader(w, http.StatusBadGateway)
goto oops
}
break
}
status_code = WriteEmptyRespHeader(w, http.StatusInternalServerError)
goto oops
}
}
wg.Wait()
if srpx.resp_error != nil {
status_code = WriteEmptyRespHeader(w, srpx.resp_status_code)
err = srpx.resp_error
goto oops
}
select {
case <- srpx.done_chan:
// anything to do?
case <- req.Context().Done():
// anything to do?
// no default. block
}
cts.pss.Send(MakeRpxStopPacket(srpx.id))
return srpx.resp_status_code, nil
oops:
if srpx != nil && start_sent { cts.pss.Send(MakeRpxStopPacket(srpx.id)) }
return status_code, err
}

597
server.go
View File

@ -10,6 +10,7 @@ import "log"
import "net"
import "net/http"
import "net/netip"
import "regexp"
import "slices"
import "strconv"
import "strings"
@ -32,8 +33,8 @@ const CTS_LIMIT int = 16384
type PortId uint16
const PORT_ID_MARKER string = "_"
const HS_ID_CTL string = "ctl"
const HS_ID_RPX string = "pxy"
const HS_ID_PXY string = "rpx"
const HS_ID_RPX string = "rpx"
const HS_ID_PXY string = "pxy"
const HS_ID_WPX string = "wpx"
type ServerConnMapByAddr map[net.Addr]*ServerConn
@ -45,6 +46,7 @@ type ServerSvcPortMap map[PortId]ConnRouteId
type ServerRptyMap map[uint64]*ServerRpty
type ServerRptyMapByWs map[*websocket.Conn]*ServerRpty
type ServerRpxMap map[uint64]*ServerRpx
type ServerWpxResponseTransformer func(r *ServerRouteProxyInfo, resp *http.Response) io.Reader
type ServerWpxForeignPortProxyMaker func(wpx_type string, port_id string) (*ServerRouteProxyInfo, error)
@ -67,6 +69,9 @@ type ServerConfig struct {
RpxAddrs []string
RpxTls *tls.Config
RpxClientTokenAttrName string
RpxClientTokenRegex *regexp.Regexp
RpxClientTokenSubmatchIndex int
PxyAddrs []string
PxyTls *tls.Config
@ -166,12 +171,12 @@ type Server struct {
peers atomic.Int64
ssh_proxy_sessions atomic.Int64
pty_sessions atomic.Int64
rpty_sessions atomic.Int64
}
wpx_resp_tf ServerWpxResponseTransformer
wpx_foreign_port_proxy_maker ServerWpxForeignPortProxyMaker
pty_user string
pty_shell string
xterm_html string
@ -202,6 +207,10 @@ type ServerConn struct {
rpty_map ServerRptyMap
rpty_map_by_ws ServerRptyMapByWs
rpx_next_id uint64
rpx_mtx sync.Mutex
rpx_map ServerRpxMap
wg sync.WaitGroup
stop_req atomic.Bool
stop_chan chan bool
@ -236,6 +245,20 @@ type ServerRpty struct {
ws *websocket.Conn
}
type ServerRpx struct {
id uint64
pr *io.PipeReader
pw *io.PipeWriter
br io.ReadCloser // body reader
start_chan chan []byte
done_chan chan bool
br_chan chan bool
resp_status_code int
resp_error error
resp_done_chan chan bool
}
type GuardedPacketStreamServer struct {
mtx sync.Mutex
//pss Hodu_PacketStreamServer
@ -265,6 +288,16 @@ func (g *GuardedPacketStreamServer) Context() context.Context {
// ------------------------------------
func (rpty *ServerRpty) ReqStop() {
rpty.ws.Close()
}
func (rpx *ServerRpx) ReqStop() {
rpx.done_chan <- true
rpx.pw.Close()
}
// ------------------------------------
func NewServerRoute(cts *ServerConn, id RouteId, option RouteOption, ptc_addr string, ptc_name string, svc_requested_addr string, svc_permitted_net string) (*ServerRoute, error) {
var r ServerRoute
var l *net.TCPListener
@ -658,6 +691,7 @@ func (cts *ServerConn) ReqStopAllServerRoutes() {
cts.route_mtx.Unlock()
}
// Rpty
func (cts *ServerConn) StartRpty(ws *websocket.Conn) (*ServerRpty, error) {
var ok bool
var start_id uint64
@ -707,11 +741,11 @@ func (cts *ServerConn) StartRpty(ws *websocket.Conn) (*ServerRpty, error) {
return nil , err
}
cts.S.stats.rpty_sessions.Add(1)
return rpty, nil
}
func (cts *ServerConn) StopRpty(ws *websocket.Conn) error {
// called by the websocket handler.
var rpty *ServerRpty
var id uint64
@ -721,7 +755,9 @@ func (cts *ServerConn) StopRpty(ws *websocket.Conn) error {
cts.rpty_mtx.Lock()
rpty, ok = cts.rpty_map_by_ws[ws]
if !ok {
return fmt.Errorf("unknown ws connection for rpty - %v", ws.RemoteAddr())
cts.rpty_mtx.Unlock()
cts.S.log.Write(cts.Sid, LOG_ERROR, "Unknown websocket connection for rpty - websocket %v", ws.RemoteAddr())
return fmt.Errorf("unknown websocket connection for rpty - %v", ws.RemoteAddr())
}
id = rpty.id
@ -730,14 +766,24 @@ func (cts *ServerConn) StopRpty(ws *websocket.Conn) error {
// send the stop request to the client side
err = cts.pss.Send(MakeRptyStopPacket(id, ""))
if err != nil {
return fmt.Errorf("unable to send stop rpty request to client - %s", err.Error())
cts.S.log.Write(cts.Sid, LOG_ERROR, "Failed to send RPTY_STOP(%d) for server %s websocket %v - %s", id, cts.RemoteAddr, ws.RemoteAddr(), err.Error())
// carry on
}
// delete the rpty entry from the maps as the websocket
// handler is ending
cts.rpty_mtx.Lock()
delete(cts.rpty_map, id)
delete(cts.rpty_map_by_ws, ws)
cts.rpty_mtx.Unlock()
cts.S.stats.rpty_sessions.Add(-1)
cts.S.log.Write(cts.Sid, LOG_INFO, "Stopped rpty(%d) for server %s websocket %vs", id, cts.RemoteAddr, ws.RemoteAddr())
return nil
}
func (cts *ServerConn) StopRptyWsById(id uint64, msg string) error {
// called this when the stop requested comes from the client
// call this when the stop requested comes from the client.
// abort the websocket side.
var rpty *ServerRpty
@ -746,11 +792,12 @@ func (cts *ServerConn) StopRptyWsById(id uint64, msg string) error {
cts.rpty_mtx.Lock()
rpty, ok = cts.rpty_map[id]
if !ok {
cts.rpty_mtx.Unlock()
return fmt.Errorf("unknown rpty id %d", id)
}
rpty.ws.Close()
cts.rpty_mtx.Unlock()
rpty.ReqStop()
cts.S.log.Write(cts.Sid, LOG_INFO, "Stopped rpty(%d) for %s - %s", id, cts.RemoteAddr, msg)
return nil
}
@ -764,6 +811,7 @@ func (cts *ServerConn) WriteRpty(ws *websocket.Conn, data []byte) error {
cts.rpty_mtx.Lock()
rpty, ok = cts.rpty_map_by_ws[ws]
if !ok {
cts.rpty_mtx.Unlock()
return fmt.Errorf("unknown ws connection for rpty - %v", ws.RemoteAddr())
}
@ -787,6 +835,7 @@ func (cts *ServerConn) WriteRptySize(ws *websocket.Conn, data []byte) error {
cts.rpty_mtx.Lock()
rpty, ok = cts.rpty_map_by_ws[ws]
if !ok {
cts.rpty_mtx.Unlock()
return fmt.Errorf("unknown ws connection for rpty size - %v", ws.RemoteAddr())
}
@ -812,33 +861,16 @@ func (cts *ServerConn) ReadRptyAndWriteWs(id uint64, data []byte) error {
cts.rpty_mtx.Unlock()
return fmt.Errorf("unknown rpty id - %d", id)
}
cts.rpty_mtx.Unlock()
err = send_ws_data_for_xterm(rpty.ws, "iov", string(data))
if err != nil {
cts.rpty_mtx.Unlock()
return fmt.Errorf("failed to write rpty data(%d) to ws - %s", id, err.Error())
}
cts.rpty_mtx.Unlock()
return nil
}
func (cts *ServerConn) ReportPacket(route_id RouteId, pts_id PeerId, packet_type PACKET_KIND, event_data interface{}) error {
var r *ServerRoute
var ok bool
cts.route_mtx.Lock()
r, ok = cts.route_map[route_id]
if !ok {
cts.route_mtx.Unlock()
return fmt.Errorf("non-existent route id - %d", route_id)
}
cts.route_mtx.Unlock()
return r.ReportPacket(pts_id, packet_type, event_data)
}
func (cts *ServerConn) HandleRptyEvent(packet_type PACKET_KIND, evt *RptyEvent) error {
switch packet_type {
case PACKET_KIND_RPTY_STOP:
@ -853,6 +885,80 @@ func (cts *ServerConn) HandleRptyEvent(packet_type PACKET_KIND, evt *RptyEvent)
return nil
}
// Rpx
func (cts *ServerConn) StartRpxWebById(srpx* ServerRpx, id uint64, data []byte) error {
// pass the initial response to code in server-rpx.go
srpx.start_chan <- data
return nil
}
func (cts *ServerConn) StopRpxWebById(srpx* ServerRpx, id uint64) error {
srpx.ReqStop()
return nil
}
func (cts *ServerConn) WroteRpxWebById(srpx* ServerRpx, id uint64, data []byte) error {
var err error
_, err = srpx.pw.Write(data)
if err != nil {
cts.S.log.Write(cts.Sid, LOG_ERROR, "Failed to write rpx data(%d) to rpx pipe - %s", id, err.Error())
srpx.ReqStop()
}
return err
}
func (cts *ServerConn) EofRpxWebById(srpx* ServerRpx, id uint64) error {
srpx.ReqStop()
return nil
}
func (cts *ServerConn) HandleRpxEvent(packet_type PACKET_KIND, evt *RpxEvent) error {
var ok bool
var rpx* ServerRpx
cts.rpx_mtx.Lock()
rpx, ok = cts.rpx_map[evt.Id]
if !ok {
cts.rpx_mtx.Unlock()
return fmt.Errorf("unknown rpx id - %v", evt.Id)
}
cts.rpx_mtx.Unlock()
switch packet_type {
case PACKET_KIND_RPX_START:
return cts.StartRpxWebById(rpx, evt.Id, evt.Data)
case PACKET_KIND_RPX_STOP:
// stop requested from the server
return cts.StopRpxWebById(rpx, evt.Id)
case PACKET_KIND_RPX_EOF:
return cts.EofRpxWebById(rpx, evt.Id)
case PACKET_KIND_RPX_DATA:
return cts.WroteRpxWebById(rpx, evt.Id, evt.Data)
}
// ignore other packet types
return nil
}
// Rpx
func (cts *ServerConn) ReportPacket(route_id RouteId, pts_id PeerId, packet_type PACKET_KIND, event_data interface{}) error {
var r *ServerRoute
var ok bool
cts.route_mtx.Lock()
r, ok = cts.route_map[route_id]
if !ok {
cts.route_mtx.Unlock()
return fmt.Errorf("non-existent route id - %d", route_id)
}
cts.route_mtx.Unlock()
return r.ReportPacket(pts_id, packet_type, event_data)
}
func (cts *ServerConn) receive_from_stream(wg *sync.WaitGroup) {
var pkt *Packet
var err error
@ -1055,10 +1161,30 @@ func (cts *ServerConn) receive_from_stream(wg *sync.WaitGroup) {
err = cts.HandleRptyEvent(pkt.Kind, x.RptyEvt)
if err != nil {
cts.S.log.Write(cts.Sid, LOG_ERROR, "Failed to handle %s event for rpty(%d) from %s - %s", pkt.Kind.String(), x.RptyEvt.Id, cts.RemoteAddr, err.Error())
} else {
} else {
cts.S.log.Write(cts.Sid, LOG_ERROR, "Handled %s event for rpty(%d) from %s", pkt.Kind.String(), x.RptyEvt.Id, cts.RemoteAddr)
}
}
} else {
cts.S.log.Write(cts.Sid, LOG_ERROR, "Invalid %s packet from %s", pkt.Kind.String(), cts.RemoteAddr)
}
case PACKET_KIND_RPX_START: // the client sends the response header using START
fallthrough
case PACKET_KIND_RPX_STOP:
fallthrough
case PACKET_KIND_RPX_EOF:
fallthrough
case PACKET_KIND_RPX_DATA:
var x *Packet_RpxEvt
var ok bool
x, ok = pkt.U.(*Packet_RpxEvt)
if ok {
err = cts.HandleRpxEvent(pkt.Kind, x.RpxEvt)
if err != nil {
cts.S.log.Write(cts.Sid, LOG_ERROR, "Failed to handle %s event for rpx(%d) from %s - %s", pkt.Kind.String(), x.RpxEvt.Id, cts.RemoteAddr, err.Error())
} else {
cts.S.log.Write(cts.Sid, LOG_ERROR, "Handled %s event for rpx(%d) from %s", pkt.Kind.String(), x.RpxEvt.Id, cts.RemoteAddr)
}
} else {
cts.S.log.Write(cts.Sid, LOG_ERROR, "Invalid %s packet from %s", pkt.Kind.String(), cts.RemoteAddr)
}
@ -1066,17 +1192,26 @@ func (cts *ServerConn) receive_from_stream(wg *sync.WaitGroup) {
}
done:
// arrange to break all rpty resources
cts.rpty_mtx.Lock()
if len(cts.rpty_map) > 0 {
var rpty *ServerRpty
for _, rpty = range cts.rpty_map {
rpty.ws.Close()
rpty.ReqStop()
}
}
cts.rpty_mtx.Unlock()
// arrange to break all rpx resources
cts.rpx_mtx.Lock()
if len(cts.rpx_map) > 0 {
var rpx *ServerRpx
for _, rpx = range cts.rpx_map {
rpx.ReqStop()
}
}
cts.rpx_mtx.Unlock()
cts.S.log.Write(cts.Sid, LOG_INFO, "RPC stream receiver ended")
}
@ -1109,7 +1244,7 @@ func (cts *ServerConn) RunTask(wg *sync.WaitGroup) {
cts.route_wg.Add(1)
// start the loop inside a goroutine so that route_wg counter
// is likely to be greater than 1 what Wait() is called.
// is likely to be greater than 1 when Wait() is called.
go func() {
waiting_loop:
for {
@ -1145,11 +1280,22 @@ func (cts *ServerConn) RunTask(wg *sync.WaitGroup) {
func (cts *ServerConn) ReqStop() {
if cts.stop_req.CompareAndSwap(false, true) {
var r *ServerRoute
var rpty *ServerRpty
var srpx *ServerRpx
cts.route_mtx.Lock()
for _, r = range cts.route_map { r.ReqStop() }
cts.route_mtx.Unlock()
cts.rpty_mtx.Lock()
for _, rpty = range cts.rpty_map { rpty.ReqStop() }
cts.rpty_mtx.Unlock()
cts.rpx_mtx.Lock()
for _, srpx = range cts.rpx_map { srpx.ReqStop() }
cts.rpx_mtx.Unlock()
// there is no good way to break a specific connection client to
// the grpc server. while the global grpc server is closed in
// ReqStop() for Server, the individuation connection is closed
@ -1359,13 +1505,14 @@ func unaryInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServ
type server_http_log_writer struct {
svr *Server
depth int
}
func (hlw *server_http_log_writer) Write(p []byte) (n int, err error) {
// the standard http.Server always requires *log.Logger
// use this iowriter to create a logger to pass it to the http server.
// since this is another log write wrapper, give adjustment value
hlw.svr.log.WriteWithCallDepth("", LOG_INFO, +1, string(p))
hlw.svr.log.WriteWithCallDepth("", LOG_INFO, hlw.depth, string(p))
return len(p), nil
}
@ -1422,13 +1569,13 @@ func (s *Server) WrapHttpHandler(handler ServerHttpHandler) http.Handler {
WriteEmptyRespHeader(w, status_code)
}
}
time_taken = time.Now().Sub(start_time)
time_taken = time.Since(start_time) // time.Now().Sub(start_time)
if status_code > 0 {
if err != nil {
s.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s %d %.9f - %s", req.RemoteAddr, req.Method, get_raw_url_path(req), status_code, time_taken.Seconds(), err.Error())
s.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s %d %.9f - %s", req.RemoteAddr, req.Method, req.RequestURI, status_code, time_taken.Seconds(), err.Error())
} else {
s.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s %d %.9f", req.RemoteAddr, req.Method, get_raw_url_path(req), status_code, time_taken.Seconds())
s.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s %d %.9f", req.RemoteAddr, req.Method, req.RequestURI, status_code, time_taken.Seconds())
}
}
})
@ -1440,7 +1587,7 @@ func (s *Server) SafeWrapWebsocketHandler(handler websocket.Handler) http.Handle
!strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade") {
var status_code int
status_code = WriteEmptyRespHeader(w, http.StatusBadRequest)
s.log.Write("", LOG_INFO, "[%s] %s %s %d[non-websocket]", req.RemoteAddr, req.Method, get_raw_url_path(req), status_code)
s.log.Write("", LOG_INFO, "[%s] %s %s %d[non-websocket]", req.RemoteAddr, req.Method, req.RequestURI, status_code)
return
}
handler.ServeHTTP(w, req)
@ -1454,21 +1601,19 @@ func (s *Server) WrapWebsocketHandler(handler ServerWebsocketHandler) websocket.
var start_time time.Time
var time_taken time.Duration
var req *http.Request
var raw_url_path string
req = ws.Request()
raw_url_path = get_raw_url_path(req)
s.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s [ws]", req.RemoteAddr, req.Method, raw_url_path)
s.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s [ws]", req.RemoteAddr, req.Method, req.RequestURI)
start_time = time.Now()
status_code, err = handler.ServeWebsocket(ws)
time_taken = time.Now().Sub(start_time)
time_taken = time.Since(start_time) // time.Now().Sub(start_time)
if status_code > 0 {
if err != nil {
s.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s [ws] %d %.9f - %s", req.RemoteAddr, req.Method, raw_url_path, status_code, time_taken.Seconds(), err.Error())
s.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s [ws] %d %.9f - %s", req.RemoteAddr, req.Method, req.RequestURI, status_code, time_taken.Seconds(), err.Error())
} else {
s.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s [ws] %d %.9f", req.RemoteAddr, req.Method, raw_url_path, status_code, time_taken.Seconds())
s.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s [ws] %d %.9f", req.RemoteAddr, req.Method, req.RequestURI, status_code, time_taken.Seconds())
}
}
})
@ -1479,7 +1624,6 @@ func NewServer(ctx context.Context, name string, logger Logger, cfg *ServerConfi
var l *net.TCPListener
var rpcaddr *net.TCPAddr
var addr string
var gl *net.TCPListener
var i int
var hs_log *log.Logger
var opts []grpc.ServerOption
@ -1536,7 +1680,7 @@ func NewServer(ctx context.Context, name string, logger Logger, cfg *ServerConfi
// ---------------------------------------------------------
hs_log = log.New(&server_http_log_writer{svr: &s}, "", 0)
hs_log = log.New(&server_http_log_writer{svr: &s, depth: +2}, "", 0)
// ---------------------------------------------------------
@ -1626,7 +1770,8 @@ func NewServer(ctx context.Context, name string, logger Logger, cfg *ServerConfi
s.ctl[i] = &http.Server{
Addr: cfg.CtlAddrs[i],
Handler: s.ctl_mux,
TLSConfig: s.Cfg.CtlTls,
// race condition issues without cloning. the http package modifies some fields in the configuration object
TLSConfig: cfg.CtlTls.Clone(),
ErrorLog: hs_log,
// TODO: more settings
}
@ -1638,12 +1783,11 @@ func NewServer(ctx context.Context, name string, logger Logger, cfg *ServerConfi
s.rpx_mux.Handle("/", s.WrapHttpHandler(&server_rpx{ S: &s, Id: HS_ID_RPX }))
s.rpx = make([]*http.Server, len(cfg.RpxAddrs))
for i = 0; i < len(cfg.RpxAddrs); i++ {
s.rpx[i] = &http.Server{
Addr: cfg.RpxAddrs[i],
Handler: s.rpx_mux,
TLSConfig: cfg.RpxTls,
TLSConfig: cfg.RpxTls.Clone(),
ErrorLog: hs_log,
// TODO: more settings
}
@ -1696,7 +1840,7 @@ func NewServer(ctx context.Context, name string, logger Logger, cfg *ServerConfi
s.pxy[i] = &http.Server{
Addr: cfg.PxyAddrs[i],
Handler: s.pxy_mux,
TLSConfig: cfg.PxyTls,
TLSConfig: cfg.PxyTls.Clone(),
ErrorLog: hs_log,
// TODO: more settings
}
@ -1746,7 +1890,7 @@ func NewServer(ctx context.Context, name string, logger Logger, cfg *ServerConfi
s.wpx[i] = &http.Server{
Addr: cfg.WpxAddrs[i],
Handler: s.wpx_mux,
TLSConfig: cfg.WpxTls,
TLSConfig: cfg.WpxTls.Clone(),
ErrorLog: hs_log,
}
}
@ -1757,11 +1901,11 @@ func NewServer(ctx context.Context, name string, logger Logger, cfg *ServerConfi
s.stats.peers.Store(0)
s.stats.ssh_proxy_sessions.Store(0)
s.stats.pty_sessions.Store(0)
s.stats.rpty_sessions.Store(0)
return &s, nil
oops:
if gl != nil { gl.Close() }
for _, l = range s.rpc { l.Close() }
s.rpc = make([]*net.TCPListener, 0)
return nil, err
@ -1840,16 +1984,10 @@ func (s *Server) RunTask(wg *sync.WaitGroup) {
go s.run_grpc_server(idx, &s.rpc_wg)
}
// most work is done by in separate goroutines (s.run_grp_server)
// this loop serves as a placeholder to prevent the logic flow from
// most work is done in separate goroutines (s.run_grp_server)
// this read on the stop channel serves as a placeholder to prevent the logic flow from
// descening down to s.ReqStop()
task_loop:
for {
select {
case <-s.stop_chan:
break task_loop
}
}
<-s.stop_chan
s.ReqStop()
@ -1863,8 +2001,52 @@ task_loop:
s.rpc_svr.Stop()
}
func (s *Server) RunCtlTask(wg *sync.WaitGroup) {
func (s* Server) run_single_ctl_server(i int, cs *http.Server, wg* sync.WaitGroup) {
var l net.Listener
var err error
defer wg.Done()
s.log.Write("", LOG_INFO, "Control channel[%d] started on %s", i, cs.Addr)
if s.stop_req.Load() == false {
// defeat hard-coded "tcp" in ListenAndServe() and ListenAndServeTLS()
// err = cs.ListenAndServe()
// err = cs.ListenAndServeTLS("", "")
l, err = net.Listen(TcpAddrStrClass(cs.Addr), cs.Addr)
if err == nil {
if s.stop_req.Load() == false {
var node *list.Element
s.ctl_addrs_mtx.Lock()
node = s.ctl_addrs.PushBack(l.Addr().(*net.TCPAddr))
s.ctl_addrs_mtx.Unlock()
if s.Cfg.CtlTls == nil {
err = cs.Serve(l)
} else {
err = cs.ServeTLS(l, "", "") // s.Cfg.CtlTls must provide a certificate and a key
}
s.ctl_addrs_mtx.Lock()
s.ctl_addrs.Remove(node)
s.ctl_addrs_mtx.Unlock()
} else {
err = fmt.Errorf("stop requested")
}
l.Close()
}
} else {
err = fmt.Errorf("stop requested")
}
if errors.Is(err, http.ErrServerClosed) {
s.log.Write("", LOG_INFO, "Control channel[%d] ended", i)
} else {
s.log.Write("", LOG_ERROR, "Control channel[%d] error - %s", i, err.Error())
}
}
func (s *Server) RunCtlTask(wg *sync.WaitGroup) {
var ctl *http.Server
var idx int
var l_wg sync.WaitGroup
@ -1873,54 +2055,55 @@ func (s *Server) RunCtlTask(wg *sync.WaitGroup) {
for idx, ctl = range s.ctl {
l_wg.Add(1)
go func(i int, cs *http.Server) {
var l net.Listener
s.log.Write("", LOG_INFO, "Control channel[%d] started on %s", i, cs.Addr)
if s.stop_req.Load() == false {
// defeat hard-coded "tcp" in ListenAndServe() and ListenAndServeTLS()
// err = cs.ListenAndServe()
// err = cs.ListenAndServeTLS("", "")
l, err = net.Listen(TcpAddrStrClass(cs.Addr), cs.Addr)
if err == nil {
if s.stop_req.Load() == false {
var node *list.Element
s.ctl_addrs_mtx.Lock()
node = s.ctl_addrs.PushBack(l.Addr().(*net.TCPAddr))
s.ctl_addrs_mtx.Unlock()
if s.Cfg.CtlTls == nil {
err = cs.Serve(l)
} else {
err = cs.ServeTLS(l, "", "") // s.Cfg.CtlTls must provide a certificate and a key
}
s.ctl_addrs_mtx.Lock()
s.ctl_addrs.Remove(node)
s.ctl_addrs_mtx.Unlock()
} else {
err = fmt.Errorf("stop requested")
}
l.Close()
}
} else {
err = fmt.Errorf("stop requested")
}
if errors.Is(err, http.ErrServerClosed) {
s.log.Write("", LOG_INFO, "Control channel[%d] ended", i)
} else {
s.log.Write("", LOG_ERROR, "Control channel[%d] error - %s", i, err.Error())
}
l_wg.Done()
}(idx, ctl)
go s.run_single_ctl_server(idx, ctl, &l_wg);
}
l_wg.Wait()
}
func (s *Server) RunRpxTask(wg *sync.WaitGroup) {
func (s *Server) run_single_rpx_server(i int, cs *http.Server, wg* sync.WaitGroup) {
var l net.Listener
var err error
defer wg.Done()
s.log.Write("", LOG_INFO, "RPX channel[%d] started on %s", i, s.Cfg.RpxAddrs[i])
if s.stop_req.Load() == false {
l, err = net.Listen(TcpAddrStrClass(cs.Addr), cs.Addr)
if err == nil {
if s.stop_req.Load() == false {
var node *list.Element
s.rpx_addrs_mtx.Lock()
node = s.rpx_addrs.PushBack(l.Addr().(*net.TCPAddr))
s.rpx_addrs_mtx.Unlock()
if s.Cfg.RpxTls == nil { // TODO: change this
err = cs.Serve(l)
} else {
err = cs.ServeTLS(l, "", "") // s.Cfg.RpxTls must provide a certificate and a key
}
s.rpx_addrs_mtx.Lock()
s.rpx_addrs.Remove(node)
s.rpx_addrs_mtx.Unlock()
} else {
err = fmt.Errorf("stop requested")
}
l.Close()
}
} else {
err = fmt.Errorf("stop requested")
}
if errors.Is(err, http.ErrServerClosed) {
s.log.Write("", LOG_INFO, "RPX channel[%d] ended", i)
} else {
s.log.Write("", LOG_ERROR, "RPX channel[%d] error - %s", i, err.Error())
}
}
func (s *Server) RunRpxTask(wg *sync.WaitGroup) {
var rpx *http.Server
var idx int
var l_wg sync.WaitGroup
@ -1929,51 +2112,54 @@ func (s *Server) RunRpxTask(wg *sync.WaitGroup) {
for idx, rpx = range s.rpx {
l_wg.Add(1)
go func(i int, cs *http.Server) {
var l net.Listener
s.log.Write("", LOG_INFO, "RPX channel[%d] started on %s", i, s.Cfg.RpxAddrs[i])
if s.stop_req.Load() == false {
l, err = net.Listen(TcpAddrStrClass(cs.Addr), cs.Addr)
if err == nil {
if s.stop_req.Load() == false {
var node *list.Element
s.rpx_addrs_mtx.Lock()
node = s.rpx_addrs.PushBack(l.Addr().(*net.TCPAddr))
s.rpx_addrs_mtx.Unlock()
if s.Cfg.RpxTls == nil { // TODO: change this
err = cs.Serve(l)
} else {
err = cs.ServeTLS(l, "", "") // s.Cfg.RpxTls must provide a certificate and a key
}
s.rpx_addrs_mtx.Lock()
s.rpx_addrs.Remove(node)
s.rpx_addrs_mtx.Unlock()
} else {
err = fmt.Errorf("stop requested")
}
l.Close()
}
} else {
err = fmt.Errorf("stop requested")
}
if errors.Is(err, http.ErrServerClosed) {
s.log.Write("", LOG_INFO, "RPX channel[%d] ended", i)
} else {
s.log.Write("", LOG_ERROR, "RPX channel[%d] error - %s", i, err.Error())
}
l_wg.Done()
}(idx, rpx)
go s.run_single_rpx_server(idx, rpx, &l_wg)
}
l_wg.Wait()
}
func (s *Server) RunPxyTask(wg *sync.WaitGroup) {
func (s *Server) run_single_pxy_server(i int, cs *http.Server, wg* sync.WaitGroup) {
var l net.Listener
var err error
defer wg.Done()
s.log.Write("", LOG_INFO, "Proxy channel[%d] started on %s", i, s.Cfg.PxyAddrs[i])
if s.stop_req.Load() == false {
l, err = net.Listen(TcpAddrStrClass(cs.Addr), cs.Addr)
if err == nil {
if s.stop_req.Load() == false {
var node *list.Element
s.pxy_addrs_mtx.Lock()
node = s.pxy_addrs.PushBack(l.Addr().(*net.TCPAddr))
s.pxy_addrs_mtx.Unlock()
if s.Cfg.PxyTls == nil { // TODO: change this
err = cs.Serve(l)
} else {
err = cs.ServeTLS(l, "", "") // s.Cfg.PxyTls must provide a certificate and a key
}
s.pxy_addrs_mtx.Lock()
s.pxy_addrs.Remove(node)
s.pxy_addrs_mtx.Unlock()
} else {
err = fmt.Errorf("stop requested")
}
l.Close()
}
} else {
err = fmt.Errorf("stop requested")
}
if errors.Is(err, http.ErrServerClosed) {
s.log.Write("", LOG_INFO, "Proxy channel[%d] ended", i)
} else {
s.log.Write("", LOG_ERROR, "Proxy channel[%d] error - %s", i, err.Error())
}
}
func (s *Server) RunPxyTask(wg *sync.WaitGroup) {
var pxy *http.Server
var idx int
var l_wg sync.WaitGroup
@ -1982,51 +2168,55 @@ func (s *Server) RunPxyTask(wg *sync.WaitGroup) {
for idx, pxy = range s.pxy {
l_wg.Add(1)
go func(i int, cs *http.Server) {
var l net.Listener
s.log.Write("", LOG_INFO, "Proxy channel[%d] started on %s", i, s.Cfg.PxyAddrs[i])
if s.stop_req.Load() == false {
l, err = net.Listen(TcpAddrStrClass(cs.Addr), cs.Addr)
if err == nil {
if s.stop_req.Load() == false {
var node *list.Element
s.pxy_addrs_mtx.Lock()
node = s.pxy_addrs.PushBack(l.Addr().(*net.TCPAddr))
s.pxy_addrs_mtx.Unlock()
if s.Cfg.PxyTls == nil { // TODO: change this
err = cs.Serve(l)
} else {
err = cs.ServeTLS(l, "", "") // s.Cfg.PxyTls must provide a certificate and a key
}
s.pxy_addrs_mtx.Lock()
s.pxy_addrs.Remove(node)
s.pxy_addrs_mtx.Unlock()
} else {
err = fmt.Errorf("stop requested")
}
l.Close()
}
} else {
err = fmt.Errorf("stop requested")
}
if errors.Is(err, http.ErrServerClosed) {
s.log.Write("", LOG_INFO, "Proxy channel[%d] ended", i)
} else {
s.log.Write("", LOG_ERROR, "Proxy channel[%d] error - %s", i, err.Error())
}
l_wg.Done()
}(idx, pxy)
go s.run_single_pxy_server(idx, pxy, &l_wg);
}
l_wg.Wait()
}
func (s *Server) RunWpxTask(wg *sync.WaitGroup) {
func (s *Server) run_single_wpx_server(i int, cs *http.Server, wg* sync.WaitGroup) {
var l net.Listener
var err error
defer wg.Done()
s.log.Write("", LOG_INFO, "Wpx channel[%d] started on %s", i, s.Cfg.WpxAddrs[i])
if s.stop_req.Load() == false {
l, err = net.Listen(TcpAddrStrClass(cs.Addr), cs.Addr)
if err == nil {
if s.stop_req.Load() == false {
var node *list.Element
s.wpx_addrs_mtx.Lock()
node = s.wpx_addrs.PushBack(l.Addr().(*net.TCPAddr))
s.wpx_addrs_mtx.Unlock()
if s.Cfg.WpxTls == nil {
err = cs.Serve(l)
} else {
err = cs.ServeTLS(l, "", "") // s.Cfg.WpxTls must provide a certificate and a key
}
s.wpx_addrs_mtx.Lock()
s.wpx_addrs.Remove(node)
s.wpx_addrs_mtx.Unlock()
} else {
err = fmt.Errorf("stop requested")
}
l.Close()
}
} else {
err = fmt.Errorf("stop requested")
}
if errors.Is(err, http.ErrServerClosed) {
s.log.Write("", LOG_INFO, "Wpx channel[%d] ended", i)
} else {
s.log.Write("", LOG_ERROR, "Wpx channel[%d] error - %s", i, err.Error())
}
}
func (s *Server) RunWpxTask(wg *sync.WaitGroup) {
var wpx *http.Server
var idx int
var l_wg sync.WaitGroup
@ -2035,45 +2225,7 @@ func (s *Server) RunWpxTask(wg *sync.WaitGroup) {
for idx, wpx = range s.wpx {
l_wg.Add(1)
go func(i int, cs *http.Server) {
var l net.Listener
s.log.Write("", LOG_INFO, "Wpx channel[%d] started on %s", i, s.Cfg.WpxAddrs[i])
if s.stop_req.Load() == false {
l, err = net.Listen(TcpAddrStrClass(cs.Addr), cs.Addr)
if err == nil {
if s.stop_req.Load() == false {
var node *list.Element
s.wpx_addrs_mtx.Lock()
node = s.wpx_addrs.PushBack(l.Addr().(*net.TCPAddr))
s.wpx_addrs_mtx.Unlock()
if s.Cfg.WpxTls == nil {
err = cs.Serve(l)
} else {
err = cs.ServeTLS(l, "", "") // s.Cfg.WpxTls must provide a certificate and a key
}
s.wpx_addrs_mtx.Lock()
s.wpx_addrs.Remove(node)
s.wpx_addrs_mtx.Unlock()
} else {
err = fmt.Errorf("stop requested")
}
l.Close()
}
} else {
err = fmt.Errorf("stop requested")
}
if errors.Is(err, http.ErrServerClosed) {
s.log.Write("", LOG_INFO, "Wpx channel[%d] ended", i)
} else {
s.log.Write("", LOG_ERROR, "Wpx channel[%d] error - %s", i, err.Error())
}
l_wg.Done()
}(idx, wpx)
go s.run_single_wpx_server(idx, wpx, &l_wg)
}
l_wg.Wait()
}
@ -2141,6 +2293,7 @@ func (s *Server) AddNewServerConn(remote_addr *net.Addr, local_addr *net.Addr, p
cts.rpty_map = make(ServerRptyMap)
cts.rpty_map_by_ws = make(ServerRptyMapByWs)
cts.rpx_map = make(ServerRpxMap)
s.cts_mtx.Lock()