added code for rpx handling
This commit is contained in:
@ -110,6 +110,7 @@ type json_out_client_stats struct {
|
||||
ClientRoutes int64 `json:"client-routes"`
|
||||
ClientPeers int64 `json:"client-peers"`
|
||||
ClientPtySessions int64 `json:"client-pty-sessions"`
|
||||
ClientRptySessions int64 `json:"client-rpty-sessions"`
|
||||
}
|
||||
// ------------------------------------
|
||||
|
||||
@ -1138,6 +1139,7 @@ func (ctl *client_ctl_stats) ServeHTTP(w http.ResponseWriter, req *http.Request)
|
||||
stats.ClientRoutes = c.stats.routes.Load()
|
||||
stats.ClientPeers = c.stats.peers.Load()
|
||||
stats.ClientPtySessions = c.stats.pty_sessions.Load()
|
||||
stats.ClientRptySessions = c.stats.rpty_sessions.Load()
|
||||
status_code = WriteJsonRespHeader(w, http.StatusOK)
|
||||
if err = je.Encode(stats); err != nil { goto oops }
|
||||
|
||||
|
@ -11,6 +11,7 @@ type ClientCollector struct {
|
||||
ClientRoutes *prometheus.Desc
|
||||
ClientPeers *prometheus.Desc
|
||||
PtySessions *prometheus.Desc
|
||||
RptySessions *prometheus.Desc
|
||||
}
|
||||
|
||||
// NewClientCollector returns a new ClientCollector with all prometheus.Desc initialized
|
||||
@ -52,6 +53,11 @@ func NewClientCollector(client *Client) ClientCollector {
|
||||
"Number of pty sessions",
|
||||
nil, nil,
|
||||
),
|
||||
RptySessions: prometheus.NewDesc(
|
||||
prefix + "rpty_sessions",
|
||||
"Number of rpty sessions",
|
||||
nil, nil,
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
@ -61,6 +67,7 @@ func (c ClientCollector) Describe(ch chan<- *prometheus.Desc) {
|
||||
ch <- c.ClientRoutes
|
||||
ch <- c.ClientPeers
|
||||
ch <- c.PtySessions
|
||||
ch <- c.RptySessions
|
||||
}
|
||||
|
||||
func (c ClientCollector) Collect(ch chan<- prometheus.Metric) {
|
||||
@ -97,4 +104,10 @@ func (c ClientCollector) Collect(ch chan<- prometheus.Metric) {
|
||||
prometheus.GaugeValue,
|
||||
float64(c.client.stats.pty_sessions.Load()),
|
||||
)
|
||||
|
||||
ch <- prometheus.MustNewConstMetric(
|
||||
c.RptySessions,
|
||||
prometheus.GaugeValue,
|
||||
float64(c.client.stats.rpty_sessions.Load()),
|
||||
)
|
||||
}
|
||||
|
@ -30,6 +30,16 @@ func (cpc *ClientPeerConn) RunTask(wg *sync.WaitGroup) error {
|
||||
|
||||
for {
|
||||
n, err = cpc.conn.Read(buf[:])
|
||||
if n > 0 {
|
||||
var err2 error
|
||||
err2 = cpc.route.cts.psc.Send(MakePeerDataPacket(cpc.route.Id, cpc.conn_id, buf[0:n]))
|
||||
if err2 != nil {
|
||||
cpc.route.cts.C.log.Write(cpc.route.cts.Sid, LOG_ERROR,
|
||||
"Failed to write peer(%d,%d,%s,%s) data to server - %s",
|
||||
cpc.route.Id, cpc.conn_id, cpc.conn.RemoteAddr().String(), cpc.conn.LocalAddr().String(), err2.Error())
|
||||
break
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) || strings.Contains(err.Error(), "use of closed network connection") { // i hate checking this condition with strings.Contains()
|
||||
cpc.route.cts.C.log.Write(cpc.route.cts.Sid, LOG_INFO,
|
||||
@ -42,14 +52,6 @@ func (cpc *ClientPeerConn) RunTask(wg *sync.WaitGroup) error {
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
err = cpc.route.cts.psc.Send(MakePeerDataPacket(cpc.route.Id, cpc.conn_id, buf[0:n]))
|
||||
if err != nil {
|
||||
cpc.route.cts.C.log.Write(cpc.route.cts.Sid, LOG_ERROR,
|
||||
"Failed to write peer(%d,%d,%s,%s) data to server - %s",
|
||||
cpc.route.Id, cpc.conn_id, cpc.conn.RemoteAddr().String(), cpc.conn.LocalAddr().String(), err.Error())
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
cpc.route.cts.psc.Send(MakePeerStoppedPacket(cpc.route.Id, cpc.conn_id, cpc.conn.RemoteAddr().String(), cpc.conn.LocalAddr().String())) // nothing much to do upon failure. no error check here
|
||||
|
@ -59,7 +59,7 @@ func (pty *client_pty_ws) ServeWebsocket(ws *websocket.Conn) (int, error) {
|
||||
conn_ready = <-conn_ready_chan
|
||||
if conn_ready { // connected
|
||||
var poll_fds []unix.PollFd
|
||||
var buf []byte
|
||||
var buf [2048]byte
|
||||
var n int
|
||||
var err error
|
||||
|
||||
@ -69,7 +69,6 @@ func (pty *client_pty_ws) ServeWebsocket(ws *websocket.Conn) (int, error) {
|
||||
}
|
||||
|
||||
c.stats.pty_sessions.Add(1)
|
||||
buf = make([]byte, 2048)
|
||||
for {
|
||||
n, err = unix.Poll(poll_fds, -1) // -1 means wait indefinitely
|
||||
if err != nil {
|
||||
@ -87,20 +86,21 @@ func (pty *client_pty_ws) ServeWebsocket(ws *websocket.Conn) (int, error) {
|
||||
}
|
||||
|
||||
if (poll_fds[0].Revents & unix.POLLIN) != 0 {
|
||||
n, err = out.Read(buf)
|
||||
n, err = out.Read(buf[:])
|
||||
if n > 0 {
|
||||
var err2 error
|
||||
err2 = send_ws_data_for_xterm(ws, "iov", string(buf[:n]))
|
||||
if err2 != nil {
|
||||
c.log.Write(pty.Id, LOG_ERROR, "[%s] Failed to send to websocket - %s", req.RemoteAddr, err2.Error())
|
||||
break
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
if !errors.Is(err, io.EOF) {
|
||||
c.log.Write(pty.Id, LOG_ERROR, "[%s] Failed to read pty stdout - %s", req.RemoteAddr, err.Error())
|
||||
}
|
||||
break
|
||||
}
|
||||
if n > 0 {
|
||||
err = send_ws_data_for_xterm(ws, "iov", string(buf[:n]))
|
||||
if err != nil {
|
||||
c.log.Write(pty.Id, LOG_ERROR, "[%s] Failed to send to websocket - %s", req.RemoteAddr, err.Error())
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
c.stats.pty_sessions.Add(-1)
|
||||
|
615
client.go
615
client.go
@ -1,5 +1,7 @@
|
||||
package hodu
|
||||
|
||||
import "bufio"
|
||||
import "bytes"
|
||||
import "container/list"
|
||||
import "context"
|
||||
import "crypto/tls"
|
||||
@ -40,6 +42,7 @@ type ClientRouteMap map[RouteId]*ClientRoute
|
||||
type ClientPeerConnMap map[PeerId]*ClientPeerConn
|
||||
type ClientPeerCancelFuncMap map[PeerId]context.CancelFunc
|
||||
type ClientRptyMap map[uint64]*ClientRpty
|
||||
type ClientRpxMap map[uint64]*ClientRpx
|
||||
|
||||
// --------------------------------------------------------------------
|
||||
type ClientRouteConfig struct {
|
||||
@ -84,6 +87,10 @@ type ClientConfig struct {
|
||||
PeerConnTmout time.Duration
|
||||
|
||||
Token string // to send to the server for identification
|
||||
|
||||
// default target for rpx
|
||||
RpxTargetAddr string
|
||||
RpxTargetTls *tls.Config
|
||||
}
|
||||
|
||||
type ClientEventKind int
|
||||
@ -122,6 +129,10 @@ type Client struct {
|
||||
ext_svcs []Service
|
||||
|
||||
rpc_tls *tls.Config
|
||||
rpx_target_addr string
|
||||
rpx_target_url string
|
||||
rpx_target_tls *tls.Config
|
||||
|
||||
ctl_tls *tls.Config
|
||||
ctl_addr []string
|
||||
ctl_prefix string
|
||||
@ -155,6 +166,7 @@ type Client struct {
|
||||
routes atomic.Int64
|
||||
peers atomic.Int64
|
||||
pty_sessions atomic.Int64
|
||||
rpty_sessions atomic.Int64
|
||||
}
|
||||
|
||||
pty_user string
|
||||
@ -176,6 +188,15 @@ type ClientRpty struct {
|
||||
tty *os.File
|
||||
}
|
||||
|
||||
type ClientRpx struct {
|
||||
id uint64
|
||||
pr *io.PipeReader
|
||||
pw *io.PipeWriter
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
ws_conn net.Conn
|
||||
}
|
||||
|
||||
// client connection to server
|
||||
type ClientConn struct {
|
||||
C *Client
|
||||
@ -209,6 +230,9 @@ type ClientConn struct {
|
||||
rpty_mtx sync.Mutex
|
||||
rpty_map ClientRptyMap
|
||||
|
||||
rpx_mtx sync.Mutex
|
||||
rpx_map ClientRpxMap
|
||||
|
||||
stop_req atomic.Bool
|
||||
stop_chan chan bool
|
||||
|
||||
@ -294,6 +318,22 @@ func (g *GuardedPacketStreamClient) Context() context.Context {
|
||||
return g.psc.Context()
|
||||
}*/
|
||||
|
||||
// --------------------------------------------------------------------
|
||||
|
||||
func (rpty *ClientRpty) ReqStop() {
|
||||
rpty.tty.Close()
|
||||
rpty.cmd.Process.Kill()
|
||||
}
|
||||
|
||||
func (rpx *ClientRpx) ReqStop() {
|
||||
rpx.pr.Close()
|
||||
rpx.pw.Close()
|
||||
rpx.cancel()
|
||||
if rpx.ws_conn != nil {
|
||||
rpx.ws_conn.SetDeadline(time.Now()) // to make Read return immediately
|
||||
}
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------
|
||||
func NewClientRoute(cts *ClientConn, id RouteId, static bool, client_peer_addr string, client_peer_name string, server_peer_svc_addr string, server_peer_svc_net string, server_peer_option RouteOption, lifetime time.Duration) *ClientRoute {
|
||||
var r ClientRoute
|
||||
@ -397,7 +437,7 @@ func (r *ClientRoute) ExtendLifetime(lifetime time.Duration) error {
|
||||
r.lifetime_timer.Stop()
|
||||
r.Lifetime = r.Lifetime + lifetime
|
||||
expiry = r.LifetimeStart.Add(r.Lifetime)
|
||||
r.lifetime_timer.Reset(expiry.Sub(time.Now()))
|
||||
r.lifetime_timer.Reset(time.Until(expiry)) // expiry.Sub(time.Now())
|
||||
if r.cts.C.route_persister != nil { r.cts.C.route_persister.Save(r.cts, r) }
|
||||
r.lifetime_mtx.Unlock()
|
||||
|
||||
@ -477,10 +517,8 @@ func (r *ClientRoute) RunTask(wg *sync.WaitGroup) {
|
||||
break waiting_loop
|
||||
}
|
||||
} else {
|
||||
select {
|
||||
case <-r.stop_chan:
|
||||
break waiting_loop
|
||||
}
|
||||
<-r.stop_chan
|
||||
break waiting_loop
|
||||
}
|
||||
}
|
||||
|
||||
@ -547,7 +585,7 @@ func (r *ClientRoute) ConnectToPeer(pts_id PeerId, route_option RouteOption, pts
|
||||
defer wg.Done()
|
||||
|
||||
tmout = time.Duration(r.cts.C.ptc_tmout)
|
||||
if tmout <= 0 { tmout = 5 * time.Second} // TODO: make this configurable...
|
||||
if tmout <= 0 { tmout = 5 * time.Second } // TODO: make this configurable...
|
||||
waitctx, cancel_wait = context.WithTimeout(r.cts.C.Ctx, tmout)
|
||||
r.ptc_mtx.Lock()
|
||||
r.ptc_cancel_map[pts_id] = cancel_wait
|
||||
@ -571,8 +609,8 @@ func (r *ClientRoute) ConnectToPeer(pts_id PeerId, route_option RouteOption, pts
|
||||
real_conn, ok = conn.(*net.TCPConn)
|
||||
if !ok {
|
||||
r.cts.C.log.Write(r.cts.Sid, LOG_ERROR,
|
||||
"Failed to get connection information to %s for route(%d,%d,%s,%s) - %s",
|
||||
r.PeerAddr, r.Id, pts_id, pts_raddr, pts_laddr, err.Error())
|
||||
"Failed to get connection information to %s for route(%d,%d,%s,%s)",
|
||||
r.PeerAddr, r.Id, pts_id, pts_raddr, pts_laddr)
|
||||
goto peer_aborted
|
||||
}
|
||||
|
||||
@ -825,6 +863,7 @@ func NewClientConn(c *Client, cfg *ClientConnConfig) *ClientConn {
|
||||
cts.stop_chan = make(chan bool, 8)
|
||||
cts.ptc_list = list.New()
|
||||
cts.rpty_map = make(ClientRptyMap)
|
||||
cts.rpx_map = make(ClientRpxMap)
|
||||
|
||||
for i, _ = range cts.cfg.Routes {
|
||||
// override it to static regardless of the value passed in
|
||||
@ -833,7 +872,6 @@ func NewClientConn(c *Client, cfg *ClientConnConfig) *ClientConn {
|
||||
|
||||
// the actual connection to the server is established in the main task function
|
||||
// The cts.conn, cts.hdc, cts.psc fields are left unassigned here.
|
||||
|
||||
return &cts
|
||||
}
|
||||
|
||||
@ -1031,7 +1069,8 @@ func (cts *ClientConn) add_client_routes(routes []ClientRouteConfig) error {
|
||||
func (cts *ClientConn) disconnect_from_server(logmsg bool) {
|
||||
if cts.conn != nil {
|
||||
var r *ClientRoute
|
||||
var crp *ClientRpty
|
||||
var rpty *ClientRpty
|
||||
var rpx *ClientRpx
|
||||
|
||||
cts.discon_mtx.Lock()
|
||||
|
||||
@ -1045,15 +1084,20 @@ func (cts *ClientConn) disconnect_from_server(logmsg bool) {
|
||||
|
||||
// arrange to clean up all rpty objects
|
||||
cts.rpty_mtx.Lock()
|
||||
for _, crp = range cts.rpty_map {
|
||||
crp.tty.Close()
|
||||
crp.cmd.Process.Kill()
|
||||
for _, rpty = range cts.rpty_map {
|
||||
rpty.ReqStop()
|
||||
// the loop in ReadRptyLoop() is supposed to be broken.
|
||||
// let's not inform the server of this connection.
|
||||
// the server should clean up itself upon connection error
|
||||
}
|
||||
cts.rpty_mtx.Unlock()
|
||||
|
||||
cts.rpx_mtx.Lock()
|
||||
for _, rpx = range cts.rpx_map {
|
||||
rpx.ReqStop()
|
||||
}
|
||||
cts.rpx_mtx.Unlock()
|
||||
|
||||
// don't care about double closes when this function is called from both RunTask() and ReqStop()
|
||||
cts.conn.Close()
|
||||
|
||||
@ -1217,7 +1261,7 @@ start_over:
|
||||
|
||||
pkt, err = psc.Recv()
|
||||
if err != nil {
|
||||
if status.Code(err) == codes.Canceled || errors.Is(err, net.ErrClosed) {
|
||||
if status.Code(err) == codes.Canceled || errors.Is(err, net.ErrClosed) || errors.Is(err, io.EOF) {
|
||||
goto reconnect_to_server
|
||||
} else {
|
||||
cts.C.log.Write(cts.Sid, LOG_INFO, "Failed to receive packet from %s - %s", cts.remote_addr_p, err.Error())
|
||||
@ -1348,6 +1392,31 @@ start_over:
|
||||
cts.C.log.Write(cts.Sid, LOG_ERROR, "Invalid %s event from %s", pkt.Kind.String(), cts.remote_addr_p)
|
||||
}
|
||||
|
||||
case PACKET_KIND_RPX_START:
|
||||
fallthrough
|
||||
case PACKET_KIND_RPX_STOP:
|
||||
fallthrough
|
||||
case PACKET_KIND_RPX_DATA:
|
||||
fallthrough
|
||||
case PACKET_KIND_RPX_EOF:
|
||||
var x *Packet_RpxEvt
|
||||
var ok bool
|
||||
x, ok = pkt.U.(*Packet_RpxEvt)
|
||||
if ok {
|
||||
err = cts.HandleRpxEvent(pkt.Kind, x.RpxEvt)
|
||||
if err != nil {
|
||||
cts.C.log.Write(cts.Sid, LOG_ERROR,
|
||||
"Failed to handle %s event for rpx(%d) from %s - %s",
|
||||
pkt.Kind.String(), x.RpxEvt.Id, cts.remote_addr_p, err.Error())
|
||||
} else {
|
||||
cts.C.log.Write(cts.Sid, LOG_DEBUG,
|
||||
"Handled %s event for rpx(%d) from %s",
|
||||
pkt.Kind.String(), x.RpxEvt.Id, cts.remote_addr_p)
|
||||
}
|
||||
} else {
|
||||
cts.C.log.Write(cts.Sid, LOG_ERROR, "Invalid %s event from %s", pkt.Kind.String(), cts.remote_addr_p)
|
||||
}
|
||||
|
||||
default:
|
||||
// do nothing. ignore the rest
|
||||
}
|
||||
@ -1390,7 +1459,7 @@ reconnect_to_server:
|
||||
goto wait_for_termination
|
||||
case <-slpctx.Done():
|
||||
select {
|
||||
case <- cts.C.Ctx.Done():
|
||||
case <-cts.C.Ctx.Done():
|
||||
// non-blocking check if the parent context of the sleep context is
|
||||
// terminated too. if so, this is normal termination case.
|
||||
// this check seem redundant but the go-runtime doesn't seem to guarantee
|
||||
@ -1421,20 +1490,34 @@ func (cts *ClientConn) ReportPacket(route_id RouteId, pts_id PeerId, packet_type
|
||||
}
|
||||
|
||||
|
||||
// rpty
|
||||
func (cts *ClientConn) FindClientRptyById(id uint64) *ClientRpty {
|
||||
var crp *ClientRpty
|
||||
var ok bool
|
||||
|
||||
cts.rpty_mtx.Lock()
|
||||
crp, ok = cts.rpty_map[id]
|
||||
cts.rpty_mtx.Unlock()
|
||||
|
||||
if !ok { crp = nil }
|
||||
return crp
|
||||
}
|
||||
|
||||
func (cts *ClientConn) ReadRptyLoop(crp *ClientRpty, wg *sync.WaitGroup) {
|
||||
|
||||
var poll_fds []unix.PollFd
|
||||
var buf []byte
|
||||
var buf [2048]byte
|
||||
var n int
|
||||
var err error
|
||||
|
||||
defer wg.Done()
|
||||
|
||||
cts.C.stats.rpty_sessions.Add(1)
|
||||
|
||||
poll_fds = []unix.PollFd{
|
||||
unix.PollFd{Fd: int32(crp.tty.Fd()), Events: unix.POLLIN},
|
||||
}
|
||||
|
||||
buf = make([]byte, 2048)
|
||||
for {
|
||||
n, err = unix.Poll(poll_fds, -1) // -1 means wait indefinitely
|
||||
if err != nil {
|
||||
@ -1452,34 +1535,35 @@ func (cts *ClientConn) ReadRptyLoop(crp *ClientRpty, wg *sync.WaitGroup) {
|
||||
}
|
||||
|
||||
if (poll_fds[0].Revents & unix.POLLIN) != 0 {
|
||||
n, err = crp.tty.Read(buf)
|
||||
n, err = crp.tty.Read(buf[:])
|
||||
if n > 0 {
|
||||
var err2 error
|
||||
err2 = cts.psc.Send(MakeRptyDataPacket(crp.id, buf[:n]))
|
||||
if err2 != nil {
|
||||
cts.C.log.Write(cts.Sid, LOG_DEBUG, "Failed to send rpty(%d) stdout to server - %s", crp.id, err2.Error())
|
||||
break
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
if !errors.Is(err, io.EOF) {
|
||||
cts.C.log.Write(cts.Sid, LOG_DEBUG, "Failed to read rpty(%d) stdout - %s", crp.id, err.Error())
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
if n > 0 {
|
||||
err = cts.psc.Send(MakeRptyDataPacket(crp.id, buf[:n]))
|
||||
if err != nil {
|
||||
cts.C.log.Write(cts.Sid, LOG_DEBUG, "Failed to send rpty(%d) stdout to server - %s", crp.id, err.Error())
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cts.psc.Send(MakeRptyStopPacket(crp.id, ""))
|
||||
cts.C.log.Write(cts.Sid, LOG_INFO, "Ending rpty(%d) read loop", crp.id)
|
||||
|
||||
crp.tty.Close() // don't care about multiple closes
|
||||
crp.cmd.Process.Kill()
|
||||
crp.ReqStop()
|
||||
crp.cmd.Wait()
|
||||
|
||||
cts.rpty_mtx.Lock()
|
||||
delete(cts.rpty_map, crp.id)
|
||||
cts.rpty_mtx.Unlock()
|
||||
|
||||
cts.C.stats.rpty_sessions.Add(-1)
|
||||
}
|
||||
|
||||
func (cts *ClientConn) StartRpty(id uint64, wg *sync.WaitGroup) error {
|
||||
@ -1515,46 +1599,34 @@ func (cts *ClientConn) StartRpty(id uint64, wg *sync.WaitGroup) error {
|
||||
|
||||
func (cts *ClientConn) StopRpty(id uint64) error {
|
||||
var crp *ClientRpty
|
||||
var ok bool
|
||||
|
||||
cts.rpty_mtx.Lock()
|
||||
crp, ok = cts.rpty_map[id]
|
||||
if !ok {
|
||||
cts.rpty_mtx.Unlock()
|
||||
crp = cts.FindClientRptyById(id)
|
||||
if crp == nil {
|
||||
return fmt.Errorf("unknown rpty id %d", id)
|
||||
}
|
||||
|
||||
crp.tty.Close() // to break ReadRptyLoop()
|
||||
crp.cmd.Process.Kill() // to process wait to be done by ReadRptyLoop()
|
||||
cts.rpty_mtx.Unlock()
|
||||
crp.ReqStop()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cts *ClientConn) WriteRpty(id uint64, data []byte) error {
|
||||
var crp *ClientRpty
|
||||
var ok bool
|
||||
|
||||
cts.rpty_mtx.Lock()
|
||||
crp, ok = cts.rpty_map[id]
|
||||
if !ok {
|
||||
cts.rpty_mtx.Unlock()
|
||||
crp = cts.FindClientRptyById(id)
|
||||
if crp == nil {
|
||||
return fmt.Errorf("unknown rpty id %d", id)
|
||||
}
|
||||
|
||||
crp.tty.Write(data)
|
||||
cts.rpty_mtx.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cts *ClientConn) WriteRptySize(id uint64, data []byte) error {
|
||||
var crp *ClientRpty
|
||||
var ok bool
|
||||
var flds []string
|
||||
|
||||
cts.rpty_mtx.Lock()
|
||||
crp, ok = cts.rpty_map[id]
|
||||
if !ok {
|
||||
cts.rpty_mtx.Unlock()
|
||||
crp = cts.FindClientRptyById(id)
|
||||
if crp == nil {
|
||||
return fmt.Errorf("unknown rpty id %d", id)
|
||||
}
|
||||
|
||||
@ -1566,11 +1638,11 @@ func (cts *ClientConn) WriteRptySize(id uint64, data []byte) error {
|
||||
cols, _ = strconv.Atoi(flds[1])
|
||||
pts.Setsize(crp.tty, &pts.Winsize{Rows: uint16(rows), Cols: uint16(cols)})
|
||||
}
|
||||
cts.rpty_mtx.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cts *ClientConn) HandleRptyEvent(packet_type PACKET_KIND, evt *RptyEvent) error {
|
||||
|
||||
switch packet_type {
|
||||
case PACKET_KIND_RPTY_START:
|
||||
return cts.StartRpty(evt.Id, &cts.C.wg)
|
||||
@ -1589,6 +1661,329 @@ func (cts *ClientConn) HandleRptyEvent(packet_type PACKET_KIND, evt *RptyEvent)
|
||||
return nil
|
||||
}
|
||||
|
||||
// rpx
|
||||
func (cts *ClientConn) FindClientRpxById(id uint64) *ClientRpx {
|
||||
var crpx *ClientRpx
|
||||
var ok bool
|
||||
|
||||
cts.rpx_mtx.Lock()
|
||||
crpx, ok = cts.rpx_map[id]
|
||||
cts.rpx_mtx.Unlock()
|
||||
|
||||
if !ok { crpx = nil }
|
||||
return crpx
|
||||
}
|
||||
|
||||
func (cts *ClientConn) RpxLoop(crpx *ClientRpx, data []byte, wg *sync.WaitGroup) {
|
||||
var sc *bufio.Scanner
|
||||
var line string
|
||||
var flds []string
|
||||
var buf [4096]byte
|
||||
var req_meth string
|
||||
var req_path string
|
||||
//var req_proto string
|
||||
var req *http.Request
|
||||
var n int
|
||||
var err error
|
||||
|
||||
defer wg.Done()
|
||||
|
||||
sc = bufio.NewScanner(bytes.NewReader(data))
|
||||
sc.Scan()
|
||||
line = sc.Text()
|
||||
|
||||
flds = strings.Fields(line)
|
||||
if (len(flds) < 3) {
|
||||
cts.C.log.Write(cts.Sid, LOG_ERROR, "Invalid request line for rpx(%d) - %s", crpx.id, line)
|
||||
goto done
|
||||
}
|
||||
|
||||
// TODO: handle trailers...
|
||||
req_meth = flds[0]
|
||||
req_path = flds[1]
|
||||
//req_proto = flds[2]
|
||||
|
||||
// create a request assuming it's a normal http request
|
||||
|
||||
req, err = http.NewRequestWithContext(crpx.ctx, req_meth, cts.C.rpx_target_url + req_path, crpx.pr)
|
||||
if err != nil {
|
||||
cts.C.log.Write(cts.Sid, LOG_ERROR, "failed to create request for rpx(%d) - %s", crpx.id, err.Error())
|
||||
goto done
|
||||
}
|
||||
|
||||
for sc.Scan() {
|
||||
line = sc.Text()
|
||||
if line == "" { break }
|
||||
flds = strings.SplitN(line, ":", 2)
|
||||
if len(flds) == 2 {
|
||||
req.Header.Add(strings.TrimSpace(flds[0]), strings.TrimSpace(flds[1]))
|
||||
}
|
||||
}
|
||||
err = sc.Err()
|
||||
if err != nil {
|
||||
cts.C.log.Write(cts.Sid, LOG_ERROR, "failed to parse request for rpx(%d) - %s", crpx.id, err.Error())
|
||||
goto done
|
||||
}
|
||||
|
||||
if strings.EqualFold(req.Header.Get("Upgrade"), "websocket") && strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade") {
|
||||
// websocket
|
||||
var done_chan chan struct{}
|
||||
|
||||
var conn net.Conn
|
||||
var resp *http.Response
|
||||
var r *bufio.Reader
|
||||
|
||||
if cts.C.rpx_target_tls != nil {
|
||||
var dialer *tls.Dialer
|
||||
dialer = &tls.Dialer{
|
||||
NetDialer: &net.Dialer{},
|
||||
Config: cts.C.rpx_target_tls,
|
||||
}
|
||||
conn, err = dialer.DialContext(crpx.ctx, "tcp", cts.C.rpx_target_addr) // TODO: no hard coding
|
||||
} else {
|
||||
var dialer *net.Dialer
|
||||
dialer = &net.Dialer{}
|
||||
conn, err = dialer.DialContext(crpx.ctx, "tcp", cts.C.rpx_target_addr) // TODO: no hard coding
|
||||
}
|
||||
if err != nil {
|
||||
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to dial websocket for rpx(%d) - %s", crpx.id, err.Error())
|
||||
goto done
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// TODO: make this atomic?
|
||||
crpx.ws_conn = conn
|
||||
|
||||
// write the raw request line and headers as sent by the server.
|
||||
// for the upgrade request, i assume no payload.
|
||||
_, err = conn.Write(data)
|
||||
if err != nil {
|
||||
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to write websocket request for rpx(%d) - %s", crpx.id, err.Error())
|
||||
goto done
|
||||
}
|
||||
|
||||
r = bufio.NewReader(conn)
|
||||
resp, err = http.ReadResponse(r, req)
|
||||
if err != nil {
|
||||
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to write websocket response for rpx(%d) - %s", crpx.id, err.Error())
|
||||
goto done
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
err = cts.psc.Send(MakeRpxStartPacket(crpx.id, get_http_resp_line_and_headers(resp)))
|
||||
if err != nil {
|
||||
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to send rpx(%d) WebSocket headers to server - %s", crpx.id, err.Error())
|
||||
goto done
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusSwitchingProtocols {
|
||||
// websock upgrade failed. let the code jump to the done
|
||||
// label to skip reading from the pipe. the server side
|
||||
// has the code to ensure no content-length. and the upgrade
|
||||
// fails, the pipe below will be pending forever as the server
|
||||
// side doesn't send data and there's no feeding to the pipe.
|
||||
cts.C.log.Write(cts.Sid, LOG_INFO, "Protocol switching failed for rpx(%d)", crpx.id)
|
||||
goto done
|
||||
}
|
||||
|
||||
// unlike with the normal request, the actual pipe is not read
|
||||
// until the initial switching protocol response is received.
|
||||
|
||||
wg.Add(1)
|
||||
done_chan = make(chan struct{}, 5)
|
||||
go func() {
|
||||
var buf [4096]byte
|
||||
var n int
|
||||
var err error
|
||||
|
||||
defer wg.Done()
|
||||
for {
|
||||
n, err = crpx.pr.Read(buf[:])
|
||||
if n > 0 {
|
||||
var err2 error
|
||||
_, err2 = conn.Write(buf[:n])
|
||||
if err2 != nil {
|
||||
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to write websocket for rpx(%d) - %s", crpx.id, err2.Error())
|
||||
break
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) { break }
|
||||
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to read pipe for rpx(%d) - %s", crpx.id, err.Error())
|
||||
break
|
||||
}
|
||||
}
|
||||
done_chan <- struct{}{}
|
||||
}()
|
||||
|
||||
for {
|
||||
n, err = conn.Read(buf[:])
|
||||
if n > 0 {
|
||||
var err2 error
|
||||
err2 = cts.psc.Send(MakeRpxDataPacket(crpx.id, buf[:n]))
|
||||
if err2 != nil {
|
||||
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to send rpx(%d) data to server - %s", crpx.id, err2.Error())
|
||||
break
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
cts.psc.Send(MakeRpxEofPacket(crpx.id))
|
||||
cts.C.log.Write(cts.Sid, LOG_DEBUG, "WebSocket rpx(%d) closed by server", crpx.id)
|
||||
break
|
||||
}
|
||||
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to read WebSocket rpx(%d) - %s", crpx.id, err.Error())
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// wait until the pipe reading(from the server side) goroutine is over
|
||||
<-done_chan
|
||||
} else {
|
||||
var tr *http.Transport
|
||||
var resp *http.Response
|
||||
|
||||
tr = &http.Transport {
|
||||
TLSClientConfig: cts.C.rpx_target_tls,
|
||||
}
|
||||
|
||||
resp, err = tr.RoundTrip(req)
|
||||
if err != nil {
|
||||
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to send rpx(%d) request - %s", crpx.id, err.Error())
|
||||
goto done
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
|
||||
err = cts.psc.Send(MakeRpxStartPacket(crpx.id, get_http_resp_line_and_headers(resp)))
|
||||
if err != nil {
|
||||
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to send rpx(%d) status and headers to server - %s", crpx.id, err.Error())
|
||||
goto done
|
||||
}
|
||||
|
||||
for {
|
||||
n, err = resp.Body.Read(buf[:])
|
||||
if n > 0 {
|
||||
var err2 error
|
||||
err2 = cts.psc.Send(MakeRpxDataPacket(crpx.id, buf[:n]))
|
||||
if err2 != nil {
|
||||
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to send rpx(%d) data to server - %s", crpx.id, err2.Error())
|
||||
break
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
}
|
||||
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to read response body for rpx(%d) - %s", crpx.id, err.Error())
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
done:
|
||||
err = cts.psc.Send(MakeRpxStopPacket(crpx.id))
|
||||
if err != nil {
|
||||
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to send rpx(%d) stp to server - %s", crpx.id, err.Error())
|
||||
}
|
||||
cts.C.log.Write(cts.Sid, LOG_INFO, "Ending rpx(%d) read loop", crpx.id)
|
||||
|
||||
crpx.ReqStop()
|
||||
cts.rpx_mtx.Lock()
|
||||
delete(cts.rpx_map, crpx.id)
|
||||
cts.rpx_mtx.Unlock()
|
||||
}
|
||||
|
||||
func (cts *ClientConn) StartRpx(id uint64, data []byte, wg *sync.WaitGroup) error {
|
||||
var crpx *ClientRpx
|
||||
var ok bool
|
||||
|
||||
cts.rpx_mtx.Lock()
|
||||
_, ok = cts.rpx_map[id]
|
||||
if ok {
|
||||
cts.rpx_mtx.Unlock()
|
||||
return fmt.Errorf("multiple start on rpx id %d", id)
|
||||
}
|
||||
crpx = &ClientRpx{ id: id }
|
||||
cts.rpx_map[id] = crpx
|
||||
|
||||
// i want the pipe to be created before the goroutine is started
|
||||
// so that the WriteRpx() can write to the pipe. i protect pipe creation
|
||||
// and context creation with a mutex
|
||||
crpx.pr, crpx.pw = io.Pipe()
|
||||
crpx.ctx, crpx.cancel = context.WithCancel(cts.C.Ctx)
|
||||
|
||||
cts.rpx_mtx.Unlock()
|
||||
|
||||
wg.Add(1)
|
||||
go cts.RpxLoop(crpx, data, wg)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cts *ClientConn) StopRpx(id uint64) error {
|
||||
var crpx *ClientRpx
|
||||
|
||||
crpx = cts.FindClientRpxById(id)
|
||||
if crpx == nil {
|
||||
return fmt.Errorf("unknown rpx id %d", id)
|
||||
}
|
||||
|
||||
crpx.ReqStop()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cts *ClientConn) WriteRpx(id uint64, data []byte) error {
|
||||
var crpx *ClientRpx
|
||||
var err error
|
||||
|
||||
crpx = cts.FindClientRpxById(id)
|
||||
if crpx == nil {
|
||||
return fmt.Errorf("unknown rpx id %d", id)
|
||||
}
|
||||
|
||||
// TODO: may have to write it in a goroutine to avoid blocking?
|
||||
_, err = crpx.pw.Write(data)
|
||||
if err != nil {
|
||||
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to write rpx(%d) data - %s", id, err.Error())
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cts *ClientConn) EofRpx(id uint64, data []byte) error {
|
||||
var crpx *ClientRpx
|
||||
|
||||
crpx = cts.FindClientRpxById(id)
|
||||
if crpx == nil {
|
||||
return fmt.Errorf("unknown rpx id %d", id)
|
||||
}
|
||||
|
||||
// close the writing end only. leave the reading end untouched
|
||||
crpx.pw.Close()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cts *ClientConn) HandleRpxEvent(packet_type PACKET_KIND, evt *RpxEvent) error {
|
||||
switch packet_type {
|
||||
case PACKET_KIND_RPX_START:
|
||||
return cts.StartRpx(evt.Id, evt.Data, &cts.C.wg)
|
||||
|
||||
case PACKET_KIND_RPX_STOP:
|
||||
return cts.StopRpx(evt.Id)
|
||||
|
||||
case PACKET_KIND_RPX_DATA:
|
||||
return cts.WriteRpx(evt.Id, evt.Data)
|
||||
|
||||
case PACKET_KIND_RPX_EOF:
|
||||
return cts.EofRpx(evt.Id, evt.Data)
|
||||
}
|
||||
|
||||
// ignore other packet types
|
||||
return nil
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------
|
||||
|
||||
func (m ClientPeerConnMap) get_sorted_keys() []PeerId {
|
||||
@ -1631,12 +2026,13 @@ func (m ClientConnMap) get_sorted_keys() []ConnId {
|
||||
|
||||
type client_ctl_log_writer struct {
|
||||
cli *Client
|
||||
depth int
|
||||
}
|
||||
|
||||
func (hlw *client_ctl_log_writer) Write(p []byte) (n int, err error) {
|
||||
// the standard http.Server always requires *log.Logger
|
||||
// use this iowriter to create a logger to pass it to the http server.
|
||||
hlw.cli.log.Write("", LOG_INFO, string(p))
|
||||
hlw.cli.log.WriteWithCallDepth("", LOG_INFO, hlw.depth, string(p))
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
@ -1696,13 +2092,13 @@ func (c *Client) WrapHttpHandler(handler ClientHttpHandler) http.Handler {
|
||||
}
|
||||
|
||||
// TODO: statistics by status_code and end point types.
|
||||
time_taken = time.Now().Sub(start_time)
|
||||
time_taken = time.Since(start_time) //time.Now().Sub(start_time)
|
||||
|
||||
if status_code > 0 {
|
||||
if err != nil {
|
||||
c.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s %d %.9f - %s", req.RemoteAddr, req.Method, get_raw_url_path(req), status_code, time_taken.Seconds(), err.Error())
|
||||
c.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s %d %.9f - %s", req.RemoteAddr, req.Method, req.RequestURI, status_code, time_taken.Seconds(), err.Error())
|
||||
} else {
|
||||
c.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s %d %.9f", req.RemoteAddr, req.Method, get_raw_url_path(req), status_code, time_taken.Seconds())
|
||||
c.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s %d %.9f", req.RemoteAddr, req.Method, req.RequestURI, status_code, time_taken.Seconds())
|
||||
}
|
||||
}
|
||||
})
|
||||
@ -1714,7 +2110,7 @@ func (c *Client) SafeWrapWebsocketHandler(handler websocket.Handler) http.Handle
|
||||
!strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade") {
|
||||
var status_code int
|
||||
status_code = WriteEmptyRespHeader(w, http.StatusBadRequest)
|
||||
c.log.Write("", LOG_INFO, "[%s] %s %s %d[non-websocket]", req.RemoteAddr, req.Method, get_raw_url_path(req), status_code)
|
||||
c.log.Write("", LOG_INFO, "[%s] %s %s %d[non-websocket]", req.RemoteAddr, req.Method, req.RequestURI, status_code)
|
||||
return
|
||||
}
|
||||
handler.ServeHTTP(w, req)
|
||||
@ -1728,21 +2124,19 @@ func (c *Client) WrapWebsocketHandler(handler ClientWebsocketHandler) websocket.
|
||||
var start_time time.Time
|
||||
var time_taken time.Duration
|
||||
var req *http.Request
|
||||
var raw_url_path string
|
||||
|
||||
req = ws.Request()
|
||||
raw_url_path = get_raw_url_path(req)
|
||||
c.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s [ws]", req.RemoteAddr, req.Method, raw_url_path)
|
||||
c.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s [ws]", req.RemoteAddr, req.Method, req.RequestURI)
|
||||
|
||||
start_time = time.Now()
|
||||
status_code, err = handler.ServeWebsocket(ws)
|
||||
time_taken = time.Now().Sub(start_time)
|
||||
time_taken = time.Since(start_time) // time.Now().Sub(start_time)
|
||||
|
||||
if status_code > 0 {
|
||||
if err != nil {
|
||||
c.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s [ws] %d %.9f - %s", req.RemoteAddr, req.Method, raw_url_path, status_code, time_taken.Seconds(), err.Error())
|
||||
c.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s [ws] %d %.9f - %s", req.RemoteAddr, req.Method, req.RequestURI, status_code, time_taken.Seconds(), err.Error())
|
||||
} else {
|
||||
c.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s [ws] %d %.9f", req.RemoteAddr, req.Method, raw_url_path, status_code, time_taken.Seconds())
|
||||
c.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s [ws] %d %.9f", req.RemoteAddr, req.Method, req.RequestURI, status_code, time_taken.Seconds())
|
||||
}
|
||||
}
|
||||
})
|
||||
@ -1770,6 +2164,14 @@ func NewClient(ctx context.Context, name string, logger Logger, cfg *ClientConfi
|
||||
c.bulletin = NewBulletin[*ClientEvent](&c, 1024)
|
||||
|
||||
c.rpc_tls = cfg.RpcTls
|
||||
c.rpx_target_addr = cfg.RpxTargetAddr
|
||||
c.rpx_target_tls = cfg.RpxTargetTls
|
||||
if c.rpx_target_tls != nil {
|
||||
c.rpx_target_url = "https://" + c.rpx_target_addr
|
||||
} else {
|
||||
c.rpx_target_url = "http://" + c.rpx_target_addr
|
||||
}
|
||||
|
||||
c.ctl_auth = cfg.CtlAuth
|
||||
c.ctl_tls = cfg.CtlTls
|
||||
c.ctl_prefix = cfg.CtlPrefix
|
||||
@ -1843,13 +2245,13 @@ func NewClient(ctx context.Context, name string, logger Logger, cfg *ClientConfi
|
||||
c.ctl = make([]*http.Server, len(cfg.CtlAddrs))
|
||||
copy(c.ctl_addr, cfg.CtlAddrs)
|
||||
|
||||
hs_log = log.New(&client_ctl_log_writer{cli: &c}, "", 0)
|
||||
hs_log = log.New(&client_ctl_log_writer{cli: &c, depth: 0}, "", 0)
|
||||
|
||||
for i = 0; i < len(cfg.CtlAddrs); i++ {
|
||||
c.ctl[i] = &http.Server{
|
||||
Addr: cfg.CtlAddrs[i],
|
||||
Handler: c.ctl_mux,
|
||||
TLSConfig: c.ctl_tls,
|
||||
TLSConfig: c.ctl_tls.Clone(),
|
||||
ErrorLog: hs_log,
|
||||
// TODO: more settings
|
||||
}
|
||||
@ -1859,6 +2261,7 @@ func NewClient(ctx context.Context, name string, logger Logger, cfg *ClientConfi
|
||||
c.stats.routes.Store(0)
|
||||
c.stats.peers.Store(0)
|
||||
c.stats.pty_sessions.Store(0)
|
||||
c.stats.rpty_sessions.Store(0)
|
||||
|
||||
return &c
|
||||
}
|
||||
@ -2171,8 +2574,52 @@ func (c *Client) GetPtyShell() string {
|
||||
return c.pty_shell
|
||||
}
|
||||
|
||||
func (c *Client) RunCtlTask(wg *sync.WaitGroup) {
|
||||
func (c *Client) run_single_ctl_server(i int, cs *http.Server, wg *sync.WaitGroup) {
|
||||
var l net.Listener
|
||||
var err error
|
||||
|
||||
defer wg.Done()
|
||||
|
||||
c.log.Write("", LOG_INFO, "Control channel[%d] started on %s", i, c.ctl_addr[i])
|
||||
|
||||
// defeat hard-coded "tcp" in ListenAndServe() and ListenAndServeTLS()
|
||||
// by creating the listener explicitly.
|
||||
// err = cs.ListenAndServe()
|
||||
// err = cs.ListenAndServeTLS("", "") // c.tlscfg must provide a certificate and a key
|
||||
|
||||
//cs.shuttingDown(), as the name indicates, is not expoosed by the net/http
|
||||
//so I have to use my own indicator to check if it's been shutdown..
|
||||
//
|
||||
if c.stop_req.Load() == false {
|
||||
// this guard has a flaw in that the stop request can be made
|
||||
// between the check above and net.Listen() below.
|
||||
l, err = net.Listen(TcpAddrStrClass(cs.Addr), cs.Addr)
|
||||
if err == nil {
|
||||
if c.stop_req.Load() == false {
|
||||
// check it again to make the guard slightly more stable
|
||||
// although it's still possible that the stop request is made
|
||||
// after Listen()
|
||||
if c.ctl_tls == nil {
|
||||
err = cs.Serve(l)
|
||||
} else {
|
||||
err = cs.ServeTLS(l, "", "") // c.ctl_tls must provide a certificate and a key
|
||||
}
|
||||
} else {
|
||||
err = fmt.Errorf("stop requested")
|
||||
}
|
||||
l.Close()
|
||||
}
|
||||
} else {
|
||||
err = fmt.Errorf("stop requested")
|
||||
}
|
||||
if errors.Is(err, http.ErrServerClosed) {
|
||||
c.log.Write("", LOG_INFO, "Control channel[%d] ended", i)
|
||||
} else {
|
||||
c.log.Write("", LOG_ERROR, "Control channel[%d] error - %s", i, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) RunCtlTask(wg *sync.WaitGroup) {
|
||||
var ctl *http.Server
|
||||
var idx int
|
||||
var l_wg sync.WaitGroup
|
||||
@ -2181,49 +2628,7 @@ func (c *Client) RunCtlTask(wg *sync.WaitGroup) {
|
||||
|
||||
for idx, ctl = range c.ctl {
|
||||
l_wg.Add(1)
|
||||
go func(i int, cs *http.Server) {
|
||||
var l net.Listener
|
||||
|
||||
c.log.Write("", LOG_INFO, "Control channel[%d] started on %s", i, c.ctl_addr[i])
|
||||
|
||||
// defeat hard-coded "tcp" in ListenAndServe() and ListenAndServeTLS()
|
||||
// by creating the listener explicitly.
|
||||
// err = cs.ListenAndServe()
|
||||
// err = cs.ListenAndServeTLS("", "") // c.tlscfg must provide a certificate and a key
|
||||
|
||||
//cs.shuttingDown(), as the name indicates, is not expoosed by the net/http
|
||||
//so I have to use my own indicator to check if it's been shutdown..
|
||||
//
|
||||
if c.stop_req.Load() == false {
|
||||
// this guard has a flaw in that the stop request can be made
|
||||
// between the check above and net.Listen() below.
|
||||
l, err = net.Listen(TcpAddrStrClass(cs.Addr), cs.Addr)
|
||||
if err == nil {
|
||||
if c.stop_req.Load() == false {
|
||||
// check it again to make the guard slightly more stable
|
||||
// although it's still possible that the stop request is made
|
||||
// after Listen()
|
||||
if c.ctl_tls == nil {
|
||||
err = cs.Serve(l)
|
||||
} else {
|
||||
err = cs.ServeTLS(l, "", "") // c.ctl_tls must provide a certificate and a key
|
||||
}
|
||||
} else {
|
||||
err = fmt.Errorf("stop requested")
|
||||
}
|
||||
l.Close()
|
||||
}
|
||||
} else {
|
||||
err = fmt.Errorf("stop requested")
|
||||
}
|
||||
if errors.Is(err, http.ErrServerClosed) {
|
||||
c.log.Write("", LOG_INFO, "Control channel[%d] ended", i)
|
||||
} else {
|
||||
c.log.Write("", LOG_ERROR, "Control channel[%d] error - %s", i, err.Error())
|
||||
}
|
||||
|
||||
l_wg.Done()
|
||||
}(idx, ctl)
|
||||
go c.run_single_ctl_server(idx, ctl, &l_wg)
|
||||
}
|
||||
l_wg.Wait()
|
||||
}
|
||||
|
@ -73,6 +73,12 @@ type RPXServiceConfig struct {
|
||||
Addrs []string `yaml:"addresses"`
|
||||
}
|
||||
|
||||
type RPXClientTokenConfig struct {
|
||||
AttrName string `yaml:"attr-name"`
|
||||
Regex string `yaml:"regex"`
|
||||
SubmatchIndex int `yaml:"submatch-index"`
|
||||
}
|
||||
|
||||
type PXYServiceConfig struct {
|
||||
Addrs []string `yaml:"addresses"`
|
||||
}
|
||||
@ -127,8 +133,9 @@ type ServerConfig struct {
|
||||
} `yaml:"ctl"`
|
||||
|
||||
RPX struct {
|
||||
Service RPXServiceConfig `yaml:"service"`
|
||||
TLS ServerTLSConfig `yaml:"tls"`
|
||||
Service RPXServiceConfig `yaml:"service"`
|
||||
TLS ServerTLSConfig `yaml:"tls"`
|
||||
ClientToken RPXClientTokenConfig `yaml:"client-token"`
|
||||
} `yaml:"rpx"`
|
||||
|
||||
PXY struct {
|
||||
@ -158,6 +165,12 @@ type ClientConfig struct {
|
||||
Endpoint RPCEndpointConfig `yaml:"endpoint"`
|
||||
TLS ClientTLSConfig `yaml:"tls"`
|
||||
} `yaml:"rpc"`
|
||||
RPX struct {
|
||||
Target struct {
|
||||
Addr string `yaml:"address"`
|
||||
TLS ClientTLSConfig `yaml:"tls"`
|
||||
} `yaml:"target"`
|
||||
}
|
||||
}
|
||||
|
||||
func load_server_config_to(cfgfile string, cfg *ServerConfig) error {
|
||||
|
@ -121,7 +121,6 @@ main_loop:
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
func (l *AppLogger) Write(id string, level hodu.LogLevel, fmtstr string, args ...interface{}) {
|
||||
if l.mask & hodu.LogMask(level) == 0 { return }
|
||||
l.write(id, level, 1, fmtstr, args...)
|
||||
|
11
cmd/main.go
11
cmd/main.go
@ -10,6 +10,7 @@ import "net"
|
||||
import "os"
|
||||
import "os/signal"
|
||||
import "path/filepath"
|
||||
import "regexp"
|
||||
import "strings"
|
||||
import "sync"
|
||||
import "syscall"
|
||||
@ -131,6 +132,13 @@ func server_main(ctl_addrs []string, rpc_addrs []string, rpx_addrs[] string, pxy
|
||||
if len(config.PxyAddrs) <= 0 { config.PxyAddrs = cfg.PXY.Service.Addrs }
|
||||
if len(config.WpxAddrs) <= 0 { config.WpxAddrs = cfg.WPX.Service.Addrs }
|
||||
|
||||
config.RpxClientTokenAttrName = cfg.RPX.ClientToken.AttrName
|
||||
if cfg.RPX.ClientToken.Regex != "" {
|
||||
config.RpxClientTokenRegex, err = regexp.Compile(cfg.RPX.ClientToken.Regex)
|
||||
if err != nil { return err }
|
||||
}
|
||||
config.RpxClientTokenSubmatchIndex = cfg.RPX.ClientToken.SubmatchIndex
|
||||
|
||||
config.CtlCors = cfg.CTL.Service.Cors
|
||||
config.CtlAuth, err = make_http_auth_config(&cfg.CTL.Service.Auth)
|
||||
if err != nil { return err }
|
||||
@ -282,10 +290,13 @@ func client_main(ctl_addrs []string, rpc_addrs []string, route_configs []string,
|
||||
if err != nil { return err }
|
||||
config.RpcTls, err = make_tls_client_config(&cfg.RPC.TLS)
|
||||
if err != nil { return err }
|
||||
config.RpxTargetTls, err = make_tls_client_config(&cfg.RPX.Target.TLS)
|
||||
if err != nil { return err }
|
||||
|
||||
if len(rpc_addrs) <= 0 { rpc_addrs = cfg.RPC.Endpoint.Addrs }
|
||||
if len(config.CtlAddrs) <= 0 { config.CtlAddrs = cfg.CTL.Service.Addrs }
|
||||
|
||||
config.RpxTargetAddr = cfg.RPX.Target.Addr
|
||||
config.CtlPrefix = cfg.CTL.Service.Prefix
|
||||
config.CtlCors = cfg.CTL.Service.Cors
|
||||
config.CtlAuth, err = make_http_auth_config(&cfg.CTL.Service.Auth)
|
||||
|
87
hodu.go
87
hodu.go
@ -1,5 +1,6 @@
|
||||
package hodu
|
||||
|
||||
import "bytes"
|
||||
import "crypto/rsa"
|
||||
import _ "embed"
|
||||
import "encoding/base64"
|
||||
@ -8,6 +9,7 @@ import "net"
|
||||
import "net/http"
|
||||
import "net/netip"
|
||||
import "os"
|
||||
import "regexp"
|
||||
import "runtime"
|
||||
import "strings"
|
||||
import "sync"
|
||||
@ -79,11 +81,6 @@ type JsonErrmsg struct {
|
||||
Text string `json:"error-text"`
|
||||
}
|
||||
|
||||
type json_in_cred struct {
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
type json_in_notice struct {
|
||||
Text string `json:"text"`
|
||||
}
|
||||
@ -224,8 +221,9 @@ func (option RouteOption) String() string {
|
||||
|
||||
func dump_call_frame_and_exit(log Logger, req *http.Request, err interface{}) {
|
||||
var buf []byte
|
||||
buf = make([]byte, 65536); buf = buf[:min(65536, runtime.Stack(buf, false))]
|
||||
log.Write("", LOG_ERROR, "[%s] %s %s - %v\n%s", req.RemoteAddr, req.Method, get_raw_url_path(req), err, string(buf))
|
||||
buf = make([]byte, 65536)
|
||||
buf = buf[:min(65536, runtime.Stack(buf, false))]
|
||||
log.Write("", LOG_ERROR, "[%s] %s %s - %v\n%s", req.RemoteAddr, req.Method, req.RequestURI, err, string(buf))
|
||||
log.Close()
|
||||
os.Exit(99) // fatal error. treat panic() as a fatal runtime error
|
||||
}
|
||||
@ -467,12 +465,75 @@ func (auth *HttpAuthConfig) Authenticate(req *http.Request) (int, string) {
|
||||
return http.StatusOK, ""
|
||||
}
|
||||
|
||||
|
||||
// ------------------------------------
|
||||
|
||||
func get_raw_url_path(req *http.Request) string {
|
||||
var path string
|
||||
path = req.URL.Path
|
||||
if req.URL.RawQuery != "" { path += "?" + req.URL.RawQuery }
|
||||
return path
|
||||
func get_http_req_line_and_headers(r *http.Request, force_host bool) []byte {
|
||||
var buf bytes.Buffer
|
||||
var name string
|
||||
var value string
|
||||
var values []string
|
||||
var host_found bool
|
||||
|
||||
fmt.Fprintf(&buf, "%s %s %s\r\n", r.Method, r.RequestURI, r.Proto)
|
||||
|
||||
for name, values = range r.Header {
|
||||
if strings.EqualFold(name, "Accept-Encoding") { // TODO: make it generic. parameterize it??
|
||||
// skip Accept-Encoding as the go client side
|
||||
// doesn't function properly when a certain enconding
|
||||
// is specified. resp.Body.Read() returned EOF when
|
||||
// not working
|
||||
continue
|
||||
} else if strings.EqualFold(name, "Host") {
|
||||
host_found = true
|
||||
}
|
||||
for _, value = range values {
|
||||
fmt.Fprintf(&buf, "%s: %s\r\n", name, value)
|
||||
}
|
||||
}
|
||||
|
||||
if force_host && !host_found && r.Host != "" {
|
||||
fmt.Fprintf(&buf, "Host: %s\r\n", r.Host)
|
||||
}
|
||||
// TODO: host and x-forwarded-for, x-forwarded-proto, etc???
|
||||
|
||||
buf.WriteString("\r\n") // End of headers
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
func get_http_resp_line_and_headers(r *http.Response) []byte {
|
||||
var buf bytes.Buffer
|
||||
var name string
|
||||
var value string
|
||||
var values []string
|
||||
|
||||
fmt.Fprintf(&buf, "%s %s\r\n", r.Proto, r.Status)
|
||||
|
||||
for name, values = range r.Header {
|
||||
for _, value = range values {
|
||||
fmt.Fprintf(&buf, "%s: %s\r\n", name, value)
|
||||
}
|
||||
}
|
||||
|
||||
buf.WriteString("\r\n") // End of headers
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
func get_regex_submatch(re *regexp.Regexp, str string, n int) string {
|
||||
var idxs []int
|
||||
var pos int
|
||||
var start int
|
||||
var end int
|
||||
|
||||
idxs = re.FindStringSubmatchIndex(str)
|
||||
if idxs == nil { return "" }
|
||||
|
||||
pos = n * 2
|
||||
if pos + 1 >= len(idxs) { return "" }
|
||||
|
||||
start, end = idxs[pos], idxs[pos + 1]
|
||||
if start == -1 || end == -1 {
|
||||
return ""
|
||||
}
|
||||
|
||||
return str[start:end]
|
||||
}
|
133
hodu.pb.go
133
hodu.pb.go
@ -101,7 +101,11 @@ const (
|
||||
PACKET_KIND_RPTY_START PACKET_KIND = 14
|
||||
PACKET_KIND_RPTY_STOP PACKET_KIND = 15
|
||||
PACKET_KIND_RPTY_DATA PACKET_KIND = 16
|
||||
PACKET_KIND_RPTY_SIZE PACKET_KIND = 17
|
||||
PACKET_KIND_RPTY_SIZE PACKET_KIND = 17 // terminal size
|
||||
PACKET_KIND_RPX_START PACKET_KIND = 18
|
||||
PACKET_KIND_RPX_STOP PACKET_KIND = 19
|
||||
PACKET_KIND_RPX_DATA PACKET_KIND = 20
|
||||
PACKET_KIND_RPX_EOF PACKET_KIND = 21
|
||||
)
|
||||
|
||||
// Enum value maps for PACKET_KIND.
|
||||
@ -124,6 +128,10 @@ var (
|
||||
15: "RPTY_STOP",
|
||||
16: "RPTY_DATA",
|
||||
17: "RPTY_SIZE",
|
||||
18: "RPX_START",
|
||||
19: "RPX_STOP",
|
||||
20: "RPX_DATA",
|
||||
21: "RPX_EOF",
|
||||
}
|
||||
PACKET_KIND_value = map[string]int32{
|
||||
"RESERVED": 0,
|
||||
@ -143,6 +151,10 @@ var (
|
||||
"RPTY_STOP": 15,
|
||||
"RPTY_DATA": 16,
|
||||
"RPTY_SIZE": 17,
|
||||
"RPX_START": 18,
|
||||
"RPX_STOP": 19,
|
||||
"RPX_DATA": 20,
|
||||
"RPX_EOF": 21,
|
||||
}
|
||||
)
|
||||
|
||||
@ -643,6 +655,58 @@ func (x *RptyEvent) GetData() []byte {
|
||||
return nil
|
||||
}
|
||||
|
||||
type RpxEvent struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
Id uint64 `protobuf:"varint,1,opt,name=Id,proto3" json:"Id,omitempty"`
|
||||
Data []byte `protobuf:"bytes,2,opt,name=Data,proto3" json:"Data,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *RpxEvent) Reset() {
|
||||
*x = RpxEvent{}
|
||||
mi := &file_hodu_proto_msgTypes[8]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
|
||||
func (x *RpxEvent) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*RpxEvent) ProtoMessage() {}
|
||||
|
||||
func (x *RpxEvent) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_hodu_proto_msgTypes[8]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use RpxEvent.ProtoReflect.Descriptor instead.
|
||||
func (*RpxEvent) Descriptor() ([]byte, []int) {
|
||||
return file_hodu_proto_rawDescGZIP(), []int{8}
|
||||
}
|
||||
|
||||
func (x *RpxEvent) GetId() uint64 {
|
||||
if x != nil {
|
||||
return x.Id
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (x *RpxEvent) GetData() []byte {
|
||||
if x != nil {
|
||||
return x.Data
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type Packet struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
Kind PACKET_KIND `protobuf:"varint,1,opt,name=Kind,proto3,enum=PACKET_KIND" json:"Kind,omitempty"`
|
||||
@ -655,6 +719,7 @@ type Packet struct {
|
||||
// *Packet_ConnErr
|
||||
// *Packet_ConnNoti
|
||||
// *Packet_RptyEvt
|
||||
// *Packet_RpxEvt
|
||||
U isPacket_U `protobuf_oneof:"U"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
@ -662,7 +727,7 @@ type Packet struct {
|
||||
|
||||
func (x *Packet) Reset() {
|
||||
*x = Packet{}
|
||||
mi := &file_hodu_proto_msgTypes[8]
|
||||
mi := &file_hodu_proto_msgTypes[9]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
@ -674,7 +739,7 @@ func (x *Packet) String() string {
|
||||
func (*Packet) ProtoMessage() {}
|
||||
|
||||
func (x *Packet) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_hodu_proto_msgTypes[8]
|
||||
mi := &file_hodu_proto_msgTypes[9]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
@ -687,7 +752,7 @@ func (x *Packet) ProtoReflect() protoreflect.Message {
|
||||
|
||||
// Deprecated: Use Packet.ProtoReflect.Descriptor instead.
|
||||
func (*Packet) Descriptor() ([]byte, []int) {
|
||||
return file_hodu_proto_rawDescGZIP(), []int{8}
|
||||
return file_hodu_proto_rawDescGZIP(), []int{9}
|
||||
}
|
||||
|
||||
func (x *Packet) GetKind() PACKET_KIND {
|
||||
@ -767,6 +832,15 @@ func (x *Packet) GetRptyEvt() *RptyEvent {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *Packet) GetRpxEvt() *RpxEvent {
|
||||
if x != nil {
|
||||
if x, ok := x.U.(*Packet_RpxEvt); ok {
|
||||
return x.RpxEvt
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type isPacket_U interface {
|
||||
isPacket_U()
|
||||
}
|
||||
@ -799,6 +873,10 @@ type Packet_RptyEvt struct {
|
||||
RptyEvt *RptyEvent `protobuf:"bytes,8,opt,name=RptyEvt,proto3,oneof"`
|
||||
}
|
||||
|
||||
type Packet_RpxEvt struct {
|
||||
RpxEvt *RpxEvent `protobuf:"bytes,9,opt,name=RpxEvt,proto3,oneof"`
|
||||
}
|
||||
|
||||
func (*Packet_Route) isPacket_U() {}
|
||||
|
||||
func (*Packet_Peer) isPacket_U() {}
|
||||
@ -813,6 +891,8 @@ func (*Packet_ConnNoti) isPacket_U() {}
|
||||
|
||||
func (*Packet_RptyEvt) isPacket_U() {}
|
||||
|
||||
func (*Packet_RpxEvt) isPacket_U() {}
|
||||
|
||||
var File_hodu_proto protoreflect.FileDescriptor
|
||||
|
||||
const file_hodu_proto_rawDesc = "" +
|
||||
@ -850,7 +930,10 @@ const file_hodu_proto_rawDesc = "" +
|
||||
"\x04Text\x18\x01 \x01(\tR\x04Text\"/\n" +
|
||||
"\tRptyEvent\x12\x0e\n" +
|
||||
"\x02Id\x18\x01 \x01(\x04R\x02Id\x12\x12\n" +
|
||||
"\x04Data\x18\x02 \x01(\fR\x04Data\"\xb1\x02\n" +
|
||||
"\x04Data\x18\x02 \x01(\fR\x04Data\".\n" +
|
||||
"\bRpxEvent\x12\x0e\n" +
|
||||
"\x02Id\x18\x01 \x01(\x04R\x02Id\x12\x12\n" +
|
||||
"\x04Data\x18\x02 \x01(\fR\x04Data\"\xd6\x02\n" +
|
||||
"\x06Packet\x12 \n" +
|
||||
"\x04Kind\x18\x01 \x01(\x0e2\f.PACKET_KINDR\x04Kind\x12\"\n" +
|
||||
"\x05Route\x18\x02 \x01(\v2\n" +
|
||||
@ -862,7 +945,8 @@ const file_hodu_proto_rawDesc = "" +
|
||||
".ConnErrorH\x00R\aConnErr\x12)\n" +
|
||||
"\bConnNoti\x18\a \x01(\v2\v.ConnNoticeH\x00R\bConnNoti\x12&\n" +
|
||||
"\aRptyEvt\x18\b \x01(\v2\n" +
|
||||
".RptyEventH\x00R\aRptyEvtB\x03\n" +
|
||||
".RptyEventH\x00R\aRptyEvt\x12#\n" +
|
||||
"\x06RpxEvt\x18\t \x01(\v2\t.RpxEventH\x00R\x06RpxEvtB\x03\n" +
|
||||
"\x01U*U\n" +
|
||||
"\fROUTE_OPTION\x12\n" +
|
||||
"\n" +
|
||||
@ -872,7 +956,7 @@ const file_hodu_proto_rawDesc = "" +
|
||||
"\x04TCP6\x10\x04\x12\b\n" +
|
||||
"\x04HTTP\x10\b\x12\t\n" +
|
||||
"\x05HTTPS\x10\x10\x12\a\n" +
|
||||
"\x03SSH\x10 *\xa2\x02\n" +
|
||||
"\x03SSH\x10 *\xda\x02\n" +
|
||||
"\vPACKET_KIND\x12\f\n" +
|
||||
"\bRESERVED\x10\x00\x12\x0f\n" +
|
||||
"\vROUTE_START\x10\x01\x12\x0e\n" +
|
||||
@ -893,7 +977,11 @@ const file_hodu_proto_rawDesc = "" +
|
||||
"RPTY_START\x10\x0e\x12\r\n" +
|
||||
"\tRPTY_STOP\x10\x0f\x12\r\n" +
|
||||
"\tRPTY_DATA\x10\x10\x12\r\n" +
|
||||
"\tRPTY_SIZE\x10\x112I\n" +
|
||||
"\tRPTY_SIZE\x10\x11\x12\r\n" +
|
||||
"\tRPX_START\x10\x12\x12\f\n" +
|
||||
"\bRPX_STOP\x10\x13\x12\f\n" +
|
||||
"\bRPX_DATA\x10\x14\x12\v\n" +
|
||||
"\aRPX_EOF\x10\x152I\n" +
|
||||
"\x04Hodu\x12\x19\n" +
|
||||
"\aGetSeed\x12\x05.Seed\x1a\x05.Seed\"\x00\x12&\n" +
|
||||
"\fPacketStream\x12\a.Packet\x1a\a.Packet\"\x00(\x010\x01B\bZ\x06./hodub\x06proto3"
|
||||
@ -911,7 +999,7 @@ func file_hodu_proto_rawDescGZIP() []byte {
|
||||
}
|
||||
|
||||
var file_hodu_proto_enumTypes = make([]protoimpl.EnumInfo, 2)
|
||||
var file_hodu_proto_msgTypes = make([]protoimpl.MessageInfo, 9)
|
||||
var file_hodu_proto_msgTypes = make([]protoimpl.MessageInfo, 10)
|
||||
var file_hodu_proto_goTypes = []any{
|
||||
(ROUTE_OPTION)(0), // 0: ROUTE_OPTION
|
||||
(PACKET_KIND)(0), // 1: PACKET_KIND
|
||||
@ -923,7 +1011,8 @@ var file_hodu_proto_goTypes = []any{
|
||||
(*ConnError)(nil), // 7: ConnError
|
||||
(*ConnNotice)(nil), // 8: ConnNotice
|
||||
(*RptyEvent)(nil), // 9: RptyEvent
|
||||
(*Packet)(nil), // 10: Packet
|
||||
(*RpxEvent)(nil), // 10: RpxEvent
|
||||
(*Packet)(nil), // 11: Packet
|
||||
}
|
||||
var file_hodu_proto_depIdxs = []int32{
|
||||
1, // 0: Packet.Kind:type_name -> PACKET_KIND
|
||||
@ -934,15 +1023,16 @@ var file_hodu_proto_depIdxs = []int32{
|
||||
7, // 5: Packet.ConnErr:type_name -> ConnError
|
||||
8, // 6: Packet.ConnNoti:type_name -> ConnNotice
|
||||
9, // 7: Packet.RptyEvt:type_name -> RptyEvent
|
||||
2, // 8: Hodu.GetSeed:input_type -> Seed
|
||||
10, // 9: Hodu.PacketStream:input_type -> Packet
|
||||
2, // 10: Hodu.GetSeed:output_type -> Seed
|
||||
10, // 11: Hodu.PacketStream:output_type -> Packet
|
||||
10, // [10:12] is the sub-list for method output_type
|
||||
8, // [8:10] is the sub-list for method input_type
|
||||
8, // [8:8] is the sub-list for extension type_name
|
||||
8, // [8:8] is the sub-list for extension extendee
|
||||
0, // [0:8] is the sub-list for field type_name
|
||||
10, // 8: Packet.RpxEvt:type_name -> RpxEvent
|
||||
2, // 9: Hodu.GetSeed:input_type -> Seed
|
||||
11, // 10: Hodu.PacketStream:input_type -> Packet
|
||||
2, // 11: Hodu.GetSeed:output_type -> Seed
|
||||
11, // 12: Hodu.PacketStream:output_type -> Packet
|
||||
11, // [11:13] is the sub-list for method output_type
|
||||
9, // [9:11] is the sub-list for method input_type
|
||||
9, // [9:9] is the sub-list for extension type_name
|
||||
9, // [9:9] is the sub-list for extension extendee
|
||||
0, // [0:9] is the sub-list for field type_name
|
||||
}
|
||||
|
||||
func init() { file_hodu_proto_init() }
|
||||
@ -950,7 +1040,7 @@ func file_hodu_proto_init() {
|
||||
if File_hodu_proto != nil {
|
||||
return
|
||||
}
|
||||
file_hodu_proto_msgTypes[8].OneofWrappers = []any{
|
||||
file_hodu_proto_msgTypes[9].OneofWrappers = []any{
|
||||
(*Packet_Route)(nil),
|
||||
(*Packet_Peer)(nil),
|
||||
(*Packet_Data)(nil),
|
||||
@ -958,6 +1048,7 @@ func file_hodu_proto_init() {
|
||||
(*Packet_ConnErr)(nil),
|
||||
(*Packet_ConnNoti)(nil),
|
||||
(*Packet_RptyEvt)(nil),
|
||||
(*Packet_RpxEvt)(nil),
|
||||
}
|
||||
type x struct{}
|
||||
out := protoimpl.TypeBuilder{
|
||||
@ -965,7 +1056,7 @@ func file_hodu_proto_init() {
|
||||
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
|
||||
RawDescriptor: unsafe.Slice(unsafe.StringData(file_hodu_proto_rawDesc), len(file_hodu_proto_rawDesc)),
|
||||
NumEnums: 2,
|
||||
NumMessages: 9,
|
||||
NumMessages: 10,
|
||||
NumExtensions: 0,
|
||||
NumServices: 1,
|
||||
},
|
||||
|
13
hodu.proto
13
hodu.proto
@ -85,6 +85,11 @@ message RptyEvent {
|
||||
bytes Data = 2;
|
||||
};
|
||||
|
||||
message RpxEvent {
|
||||
uint64 Id = 1;
|
||||
bytes Data = 2;
|
||||
};
|
||||
|
||||
enum PACKET_KIND {
|
||||
RESERVED = 0; // not used
|
||||
ROUTE_START = 1;
|
||||
@ -103,7 +108,12 @@ enum PACKET_KIND {
|
||||
RPTY_START = 14;
|
||||
RPTY_STOP = 15;
|
||||
RPTY_DATA = 16;
|
||||
RPTY_SIZE = 17;
|
||||
RPTY_SIZE = 17; // terminal size
|
||||
|
||||
RPX_START = 18;
|
||||
RPX_STOP = 19;
|
||||
RPX_DATA = 20;
|
||||
RPX_EOF = 21;
|
||||
};
|
||||
|
||||
message Packet {
|
||||
@ -117,5 +127,6 @@ message Packet {
|
||||
ConnError ConnErr = 6;
|
||||
ConnNotice ConnNoti = 7;
|
||||
RptyEvent RptyEvt = 8;
|
||||
RpxEvent RpxEvt = 9;
|
||||
};
|
||||
}
|
||||
|
19
packet.go
19
packet.go
@ -80,6 +80,7 @@ func MakeRptyStartPacket(id uint64) *Packet {
|
||||
}
|
||||
|
||||
func MakeRptyStopPacket(id uint64, msg string) *Packet {
|
||||
// the rpty stop conveys an error/info message
|
||||
return &Packet{Kind: PACKET_KIND_RPTY_STOP, U: &Packet_RptyEvt{RptyEvt: &RptyEvent{Id: id, Data: []byte(msg)}}}
|
||||
}
|
||||
|
||||
@ -90,3 +91,21 @@ func MakeRptyDataPacket(id uint64, data []byte) *Packet {
|
||||
func MakeRptySizePacket(id uint64, data []byte) *Packet {
|
||||
return &Packet{Kind: PACKET_KIND_RPTY_SIZE, U: &Packet_RptyEvt{RptyEvt: &RptyEvent{Id: id, Data: data}}}
|
||||
}
|
||||
|
||||
func MakeRpxStartPacket(id uint64, hdr_part []byte) *Packet {
|
||||
// the rpx start conveys the data unlike other Start packets...
|
||||
return &Packet{Kind: PACKET_KIND_RPX_START, U: &Packet_RpxEvt{RpxEvt: &RpxEvent{Id: id, Data: hdr_part}}}
|
||||
}
|
||||
|
||||
func MakeRpxStopPacket(id uint64) *Packet {
|
||||
// the rpx start conveys the data unlike other Start packets...
|
||||
return &Packet{Kind: PACKET_KIND_RPX_STOP, U: &Packet_RpxEvt{RpxEvt: &RpxEvent{Id: id}}}
|
||||
}
|
||||
|
||||
func MakeRpxDataPacket(id uint64, data_part []byte) *Packet {
|
||||
return &Packet{Kind: PACKET_KIND_RPX_DATA, U: &Packet_RpxEvt{RpxEvt: &RpxEvent{Id: id, Data: data_part}}}
|
||||
}
|
||||
|
||||
func MakeRpxEofPacket(id uint64) *Packet {
|
||||
return &Packet{Kind: PACKET_KIND_RPX_EOF, U: &Packet_RpxEvt{RpxEvt: &RpxEvent{Id: id}}}
|
||||
}
|
||||
|
@ -76,7 +76,8 @@ type json_out_server_stats struct {
|
||||
ServerPeers int64 `json:"server-peers"`
|
||||
|
||||
SshProxySessions int64 `json:"pxy-ssh-sessions"`
|
||||
ServerPtySessions int64 `json:"server-pty-sessions"`
|
||||
ServerPtySessions int64 `json:"server-pty-sessions"`
|
||||
ServerRptySessions int64 `json:"server-rpty-sessions"`
|
||||
}
|
||||
|
||||
// this is a more specialized variant of json_in_notice
|
||||
@ -921,6 +922,7 @@ func (ctl *server_ctl_stats) ServeHTTP(w http.ResponseWriter, req *http.Request)
|
||||
stats.ServerPeers = s.stats.peers.Load()
|
||||
stats.SshProxySessions = s.stats.ssh_proxy_sessions.Load()
|
||||
stats.ServerPtySessions = s.stats.pty_sessions.Load()
|
||||
stats.ServerRptySessions = s.stats.rpty_sessions.Load()
|
||||
status_code = WriteJsonRespHeader(w, http.StatusOK)
|
||||
if err = je.Encode(stats); err != nil { goto oops }
|
||||
|
||||
|
@ -12,6 +12,7 @@ type ServerCollector struct {
|
||||
ServerPeers *prometheus.Desc
|
||||
SshProxySessions *prometheus.Desc
|
||||
PtySessions *prometheus.Desc
|
||||
RptySessions *prometheus.Desc
|
||||
}
|
||||
|
||||
// NewServerCollector returns a new ServerCollector with all prometheus.Desc initialized
|
||||
@ -58,6 +59,11 @@ func NewServerCollector(server *Server) ServerCollector {
|
||||
"Number of pty session",
|
||||
nil, nil,
|
||||
),
|
||||
RptySessions: prometheus.NewDesc(
|
||||
prefix + "rpty_sessions",
|
||||
"Number of rpty session",
|
||||
nil, nil,
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
@ -68,6 +74,7 @@ func (c ServerCollector) Describe(ch chan<- *prometheus.Desc) {
|
||||
ch <- c.ServerPeers
|
||||
ch <- c.SshProxySessions
|
||||
ch <- c.PtySessions
|
||||
ch <- c.RptySessions
|
||||
}
|
||||
|
||||
func (c ServerCollector) Collect(ch chan<- prometheus.Metric) {
|
||||
@ -110,4 +117,10 @@ func (c ServerCollector) Collect(ch chan<- prometheus.Metric) {
|
||||
prometheus.GaugeValue,
|
||||
float64(c.server.stats.pty_sessions.Load()),
|
||||
)
|
||||
|
||||
ch <- prometheus.MustNewConstMetric(
|
||||
c.RptySessions,
|
||||
prometheus.GaugeValue,
|
||||
float64(c.server.stats.rpty_sessions.Load()),
|
||||
)
|
||||
}
|
||||
|
@ -102,6 +102,16 @@ wait_for_started:
|
||||
|
||||
for {
|
||||
n, err = spc.conn.Read(buf[:])
|
||||
if n > 0 {
|
||||
var err2 error
|
||||
err2 = pss.Send(MakePeerDataPacket(spc.route.Id, spc.conn_id, buf[:n]))
|
||||
if err2 != nil {
|
||||
spc.route.Cts.S.log.Write(spc.route.Cts.Sid, LOG_ERROR,
|
||||
"Failed to send data from peer(%d,%d,%s,%s) to client - %s",
|
||||
spc.route.Id, spc.conn_id, conn_raddr, conn_laddr, err2.Error())
|
||||
goto done
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) || strings.Contains(err.Error(), "use of closed network connection") { // i don't like this way to check this error.
|
||||
err = pss.Send(MakePeerEofPacket(spc.route.Id, spc.conn_id))
|
||||
@ -119,14 +129,6 @@ wait_for_started:
|
||||
goto done
|
||||
}
|
||||
}
|
||||
|
||||
err = pss.Send(MakePeerDataPacket(spc.route.Id, spc.conn_id, buf[:n]))
|
||||
if err != nil {
|
||||
spc.route.Cts.S.log.Write(spc.route.Cts.Sid, LOG_ERROR,
|
||||
"Failed to send data from peer(%d,%d,%s,%s) to client - %s",
|
||||
spc.route.Id, spc.conn_id, conn_raddr, conn_laddr, err.Error())
|
||||
goto done
|
||||
}
|
||||
}
|
||||
|
||||
wait_for_stopped:
|
||||
|
@ -67,7 +67,7 @@ func (pty *server_pty_ws) ServeWebsocket(ws *websocket.Conn) (int, error) {
|
||||
conn_ready = <-conn_ready_chan
|
||||
if conn_ready { // connected
|
||||
var poll_fds []unix.PollFd
|
||||
var buf []byte
|
||||
var buf [2048]byte
|
||||
var n int
|
||||
var err error
|
||||
|
||||
@ -76,7 +76,6 @@ func (pty *server_pty_ws) ServeWebsocket(ws *websocket.Conn) (int, error) {
|
||||
}
|
||||
|
||||
s.stats.pty_sessions.Add(1)
|
||||
buf = make([]byte, 2048)
|
||||
for {
|
||||
n, err = unix.Poll(poll_fds, -1) // -1 means wait indefinitely
|
||||
if err != nil {
|
||||
@ -94,20 +93,21 @@ func (pty *server_pty_ws) ServeWebsocket(ws *websocket.Conn) (int, error) {
|
||||
}
|
||||
|
||||
if (poll_fds[0].Revents & unix.POLLIN) != 0 {
|
||||
n, err = out.Read(buf)
|
||||
n, err = out.Read(buf[:])
|
||||
if n > 0 {
|
||||
var err2 error
|
||||
err2 = send_ws_data_for_xterm(ws, "iov", string(buf[:n]))
|
||||
if err2 != nil {
|
||||
s.log.Write(pty.Id, LOG_ERROR, "[%s] Failed to send to websocket - %s", req.RemoteAddr, err2.Error())
|
||||
break
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
if !errors.Is(err, io.EOF) {
|
||||
s.log.Write(pty.Id, LOG_ERROR, "[%s] Failed to read pty stdout - %s", req.RemoteAddr, err.Error())
|
||||
}
|
||||
break
|
||||
}
|
||||
if n > 0 {
|
||||
err = send_ws_data_for_xterm(ws, "iov", string(buf[:n]))
|
||||
if err != nil {
|
||||
s.log.Write(pty.Id, LOG_ERROR, "[%s] Failed to send to websocket - %s", req.RemoteAddr, err.Error())
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
s.stats.pty_sessions.Add(-1)
|
||||
|
@ -375,7 +375,7 @@ func (pxy *server_pxy_http_main) ServeHTTP(w http.ResponseWriter, req *http.Requ
|
||||
}
|
||||
proxy_url = pxy.req_to_proxy_url(req, pi)
|
||||
|
||||
s.log.Write(pxy.Id, LOG_INFO, "[%s] %s %s -> %+v", req.RemoteAddr, req.Method, get_raw_url_path(req), proxy_url)
|
||||
s.log.Write(pxy.Id, LOG_INFO, "[%s] %s %s -> %+v", req.RemoteAddr, req.Method, req.RequestURI, proxy_url)
|
||||
|
||||
proxy_req, err = http.NewRequestWithContext(s.Ctx, req.Method, proxy_url.String(), req.Body)
|
||||
if err != nil {
|
||||
@ -401,7 +401,7 @@ func (pxy *server_pxy_http_main) ServeHTTP(w http.ResponseWriter, req *http.Requ
|
||||
} else {
|
||||
status_code = resp.StatusCode
|
||||
if upgrade_required && resp.StatusCode == http.StatusSwitchingProtocols {
|
||||
s.log.Write(pxy.Id, LOG_INFO, "[%s] %s %s %d", req.RemoteAddr, req.Method, get_raw_url_path(req), status_code)
|
||||
s.log.Write(pxy.Id, LOG_INFO, "[%s] %s %s %d", req.RemoteAddr, req.Method, req.RequestURI, status_code)
|
||||
err = pxy.serve_upgraded(w, req, resp)
|
||||
if err != nil { goto oops }
|
||||
return 0, nil// print the log mesage before calling serve_upgraded() and exit here
|
||||
@ -426,7 +426,7 @@ func (pxy *server_pxy_http_main) ServeHTTP(w http.ResponseWriter, req *http.Requ
|
||||
|
||||
_, err = io.Copy(w, resp_body)
|
||||
if err != nil {
|
||||
s.log.Write(pxy.Id, LOG_WARN, "[%s] %s %s %s", req.RemoteAddr, req.Method, get_raw_url_path(req), err.Error())
|
||||
s.log.Write(pxy.Id, LOG_WARN, "[%s] %s %s %s", req.RemoteAddr, req.Method, req.RequestURI, err.Error())
|
||||
}
|
||||
|
||||
// TODO: handle trailers
|
||||
@ -689,27 +689,27 @@ func (pxy *server_pxy_ssh_ws) ServeWebsocket(ws *websocket.Conn) (int, error) {
|
||||
|
||||
conn_ready = <-conn_ready_chan
|
||||
if conn_ready { // connected
|
||||
var buf []byte
|
||||
var buf [2048]byte
|
||||
var n int
|
||||
var err error
|
||||
|
||||
s.stats.ssh_proxy_sessions.Add(1)
|
||||
buf = make([]byte, 2048)
|
||||
for {
|
||||
n, err = out.Read(buf)
|
||||
n, err = out.Read(buf[:])
|
||||
if n > 0 {
|
||||
var err2 error
|
||||
err2 = send_ws_data_for_xterm(ws, "iov", string(buf[:n]))
|
||||
if err2 != nil {
|
||||
s.log.Write(pxy.Id, LOG_ERROR, "[%s] Failed to send to websocket - %s", req.RemoteAddr, err2.Error())
|
||||
break
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
if !errors.Is(err, io.EOF) {
|
||||
s.log.Write(pxy.Id, LOG_ERROR, "[%s] Failed to read from SSH stdout - %s", req.RemoteAddr, err.Error())
|
||||
}
|
||||
break
|
||||
}
|
||||
if n > 0 {
|
||||
err = send_ws_data_for_xterm(ws, "iov", string(buf[:n]))
|
||||
if err != nil {
|
||||
s.log.Write(pxy.Id, LOG_ERROR, "[%s] Failed to send to websocket - %s", req.RemoteAddr, err.Error())
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
s.stats.ssh_proxy_sessions.Add(-1)
|
||||
}
|
||||
|
318
server-rpx.go
318
server-rpx.go
@ -1,6 +1,15 @@
|
||||
package hodu
|
||||
|
||||
import "bufio"
|
||||
import "bytes"
|
||||
import "errors"
|
||||
import "fmt"
|
||||
import "io"
|
||||
import "net"
|
||||
import "net/http"
|
||||
import "strconv"
|
||||
import "strings"
|
||||
import "sync"
|
||||
|
||||
type server_rpx struct {
|
||||
S *Server
|
||||
@ -8,28 +17,313 @@ type server_rpx struct {
|
||||
}
|
||||
|
||||
// ------------------------------------
|
||||
func (pxy *server_rpx) Identity() string {
|
||||
return pxy.Id
|
||||
func (rpx *server_rpx) Identity() string {
|
||||
return rpx.Id
|
||||
}
|
||||
|
||||
func (pxy *server_rpx) Cors(req *http.Request) bool {
|
||||
func (rpx *server_rpx) Cors(req *http.Request) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (pxy *server_rpx) Authenticate(req *http.Request) (int, string) {
|
||||
func (rpx *server_rpx) Authenticate(req *http.Request) (int, string) {
|
||||
return http.StatusOK, ""
|
||||
}
|
||||
|
||||
func (pxy *server_rpx) ServeHTTP(w http.ResponseWriter, req *http.Request) (int, error) {
|
||||
var status_code int
|
||||
// var err error
|
||||
func (rpx *server_rpx) get_client_token(req *http.Request) string {
|
||||
var val string
|
||||
|
||||
status_code = http.StatusOK
|
||||
// TODO: enhance this client token extraction logic with some expression language?
|
||||
val = req.Header.Get(rpx.S.Cfg.RpxClientTokenAttrName)
|
||||
if val == "" { val = req.Host }
|
||||
|
||||
//done:
|
||||
return status_code, nil
|
||||
if rpx.S.Cfg.RpxClientTokenRegex != nil {
|
||||
val = get_regex_submatch(rpx.S.Cfg.RpxClientTokenRegex, val, rpx.S.Cfg.RpxClientTokenSubmatchIndex)
|
||||
}
|
||||
|
||||
//oops:
|
||||
// return status_code, err
|
||||
return val
|
||||
}
|
||||
|
||||
func (rpx* server_rpx) handle_header_data(rpx_id uint64, data []byte, w http.ResponseWriter) (int, error) {
|
||||
var sc *bufio.Scanner
|
||||
var line string
|
||||
var flds []string
|
||||
var status_code int
|
||||
var err error
|
||||
|
||||
sc = bufio.NewScanner(bytes.NewReader(data))
|
||||
sc.Scan()
|
||||
line = sc.Text()
|
||||
|
||||
flds = strings.Fields(line)
|
||||
if (len(flds) < 2) { // i care about the status code..
|
||||
return http.StatusBadGateway, fmt.Errorf("invalid response status for rpx(%d) - %s", rpx_id, line)
|
||||
}
|
||||
status_code, err = strconv.Atoi(flds[1])
|
||||
if err != nil {
|
||||
return http.StatusBadGateway, fmt.Errorf("invalid response code for rpx(%d) - %s", rpx_id, err.Error())
|
||||
}
|
||||
|
||||
for sc.Scan() {
|
||||
line = sc.Text()
|
||||
if line == "" { break }
|
||||
flds = strings.SplitN(line, ":", 2)
|
||||
if len(flds) == 2 {
|
||||
w.Header().Add(strings.TrimSpace(flds[0]), strings.TrimSpace(flds[1]))
|
||||
}
|
||||
}
|
||||
err = sc.Err()
|
||||
if err != nil {
|
||||
return http.StatusBadGateway, fmt.Errorf("failed to parse response for rpx(%d) - %s", rpx_id, err.Error())
|
||||
}
|
||||
|
||||
w.WriteHeader(status_code)
|
||||
return status_code, nil
|
||||
}
|
||||
|
||||
func (rpx *server_rpx) handle_response(srpx *ServerRpx, req *http.Request, w http.ResponseWriter, ws_upgrade bool, wg *sync.WaitGroup) {
|
||||
var start_resp []byte
|
||||
var status_code int
|
||||
var buf [4096]byte
|
||||
var n int
|
||||
var wr io.Writer
|
||||
var wrote_br_chan bool
|
||||
var err error
|
||||
|
||||
defer wg.Done()
|
||||
|
||||
select {
|
||||
case start_resp = <- srpx.start_chan:
|
||||
// received the header. ready to proceed to the body
|
||||
// do nothing. just continue
|
||||
status_code, err = rpx.handle_header_data(srpx.id, start_resp, w)
|
||||
if err != nil { goto done }
|
||||
|
||||
case <- srpx.done_chan:
|
||||
err = fmt.Errorf("rpx(%d) terminated before receiving header", srpx.id)
|
||||
status_code = http.StatusBadGateway
|
||||
goto done
|
||||
case <- req.Context().Done():
|
||||
err = fmt.Errorf("rpx(%d) terminated before receiving header - %s", srpx.id, req.Context().Err().Error())
|
||||
status_code = http.StatusBadGateway
|
||||
goto done
|
||||
|
||||
// no default. block
|
||||
}
|
||||
|
||||
if ws_upgrade && status_code == http.StatusSwitchingProtocols {
|
||||
var hijk http.Hijacker
|
||||
var conn net.Conn
|
||||
var ok bool
|
||||
|
||||
hijk, ok = w.(http.Hijacker)
|
||||
if !ok {
|
||||
err = fmt.Errorf("failed to upgrade rpx(%d) - not a hijacker", srpx.id)
|
||||
status_code = http.StatusInternalServerError
|
||||
goto done
|
||||
}
|
||||
|
||||
conn, _, err = hijk.Hijack()
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to upgrade rpx(%d) - %s", srpx.id, err.Error())
|
||||
status_code = http.StatusInternalServerError
|
||||
goto done
|
||||
}
|
||||
|
||||
// websocket upgrade is successful
|
||||
srpx.br = conn
|
||||
srpx.br_chan <- true // inform another goroutine that the protocol switching is completed.
|
||||
wrote_br_chan = true
|
||||
|
||||
wr = conn
|
||||
} else {
|
||||
if ws_upgrade {
|
||||
srpx.br_chan <- false
|
||||
wrote_br_chan = true
|
||||
} // indicate upgrade failure
|
||||
wr = w
|
||||
}
|
||||
|
||||
for {
|
||||
n, err = srpx.pr.Read(buf[:])
|
||||
if n > 0 {
|
||||
var err2 error
|
||||
_, err2 = wr.Write(buf[:n])
|
||||
if err2 != nil {
|
||||
err = err2
|
||||
status_code = http.StatusInternalServerError
|
||||
break
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
err = nil
|
||||
} else {
|
||||
status_code = http.StatusInternalServerError
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
done:
|
||||
// just send another in case the code got jump into this part for an error
|
||||
// may not be consumed but the channel is large enough for redundant data
|
||||
srpx.resp_status_code = status_code
|
||||
srpx.resp_error = err
|
||||
|
||||
if ws_upgrade && !wrote_br_chan {
|
||||
srpx.br_chan <- false
|
||||
}
|
||||
}
|
||||
|
||||
func (rpx *server_rpx) alloc_server_rpx(cts *ServerConn, req *http.Request) (*ServerRpx, error) {
|
||||
var srpx *ServerRpx
|
||||
var start_id uint64
|
||||
var assigned_id uint64
|
||||
var ok bool
|
||||
|
||||
cts.rpx_mtx.Lock()
|
||||
start_id = cts.rpx_next_id
|
||||
for {
|
||||
_, ok = cts.rpx_map[cts.rpx_next_id]
|
||||
if !ok {
|
||||
assigned_id = cts.rpx_next_id
|
||||
cts.rpx_next_id++
|
||||
if cts.rpx_next_id == 0 { cts.rpx_next_id++ }
|
||||
break
|
||||
}
|
||||
cts.rpx_next_id++
|
||||
if cts.rpx_next_id == 0 { cts.rpx_next_id++ }
|
||||
if cts.rpx_next_id == start_id {
|
||||
// unlikely to happen but it cycled through the whole range.
|
||||
cts.rpx_mtx.Unlock()
|
||||
return nil, fmt.Errorf("failed to assign id")
|
||||
}
|
||||
}
|
||||
|
||||
srpx = &ServerRpx{
|
||||
id: assigned_id,
|
||||
start_chan: make(chan []byte, 5),
|
||||
done_chan: make(chan bool, 5),
|
||||
br_chan: make(chan bool, 5),
|
||||
}
|
||||
srpx.br = req.Body
|
||||
srpx.pr, srpx.pw = io.Pipe()
|
||||
cts.rpx_map[assigned_id] = srpx
|
||||
|
||||
cts.rpx_mtx.Unlock()
|
||||
return srpx, nil
|
||||
}
|
||||
|
||||
func (rpx *server_rpx) ServeHTTP(w http.ResponseWriter, req *http.Request) (int, error) {
|
||||
var s *Server
|
||||
var client_token string
|
||||
var start_sent bool
|
||||
var cts *ServerConn
|
||||
var status_code int
|
||||
var srpx *ServerRpx
|
||||
var ws_upgrade bool
|
||||
var buf [4096]byte
|
||||
var wg sync.WaitGroup
|
||||
var err error
|
||||
|
||||
s = rpx.S
|
||||
client_token = rpx.get_client_token(req)
|
||||
cts = s.FindServerConnByClientToken(client_token)
|
||||
if cts == nil {
|
||||
status_code = WriteEmptyRespHeader(w, http.StatusNotFound)
|
||||
err = fmt.Errorf("unknown client token - %s", client_token)
|
||||
goto oops
|
||||
}
|
||||
|
||||
srpx, err = rpx.alloc_server_rpx(cts, req)
|
||||
if err != nil {
|
||||
status_code = WriteEmptyRespHeader(w, http.StatusServiceUnavailable)
|
||||
err = fmt.Errorf("unable to allocate rpx - %s", err.Error())
|
||||
goto oops
|
||||
}
|
||||
|
||||
// arrange to clear the rpx_map entry when this function exits
|
||||
defer func() {
|
||||
cts.rpx_mtx.Lock()
|
||||
delete(cts.rpx_map, srpx.id)
|
||||
cts.rpx_mtx.Unlock()
|
||||
}()
|
||||
|
||||
ws_upgrade = strings.EqualFold(req.Header.Get("Upgrade"), "websocket") && strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade");
|
||||
if ws_upgrade && req.ContentLength > 0 {
|
||||
// while other webservers are ok with upgrade request with body payload,
|
||||
// this program rejects such a request for impelementation limitation as
|
||||
// it's not dealing with a raw byte but is using the standard web server handler.
|
||||
status_code = WriteEmptyRespHeader(w, http.StatusBadRequest)
|
||||
err = fmt.Errorf("failed to assign id")
|
||||
goto oops
|
||||
}
|
||||
|
||||
err = cts.pss.Send(MakeRpxStartPacket(srpx.id, get_http_req_line_and_headers(req, true)))
|
||||
if err != nil {
|
||||
status_code = WriteEmptyRespHeader(w, http.StatusBadGateway)
|
||||
goto oops
|
||||
}
|
||||
start_sent = true
|
||||
|
||||
wg.Add(1)
|
||||
go rpx.handle_response(srpx, req, w, ws_upgrade, &wg)
|
||||
|
||||
if ws_upgrade {
|
||||
// wait until the protocol switching is done in rpx.handle_response()
|
||||
var upgraded bool
|
||||
upgraded = <- srpx.br_chan
|
||||
if upgraded {
|
||||
// arrange to close the hijacked connection inside rpx.handle_response()
|
||||
defer srpx.br.Close()
|
||||
}
|
||||
}
|
||||
|
||||
for {
|
||||
var n int
|
||||
|
||||
n, err = srpx.br.Read(buf[:])
|
||||
if n > 0 {
|
||||
var err2 error
|
||||
err2 = cts.pss.Send(MakeRpxDataPacket(srpx.id, buf[:n]))
|
||||
if err2 != nil {
|
||||
status_code = WriteEmptyRespHeader(w, http.StatusBadGateway)
|
||||
goto oops
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
err = cts.pss.Send(MakeRpxEofPacket(srpx.id))
|
||||
if err != nil {
|
||||
status_code = WriteEmptyRespHeader(w, http.StatusBadGateway)
|
||||
goto oops
|
||||
}
|
||||
break
|
||||
}
|
||||
status_code = WriteEmptyRespHeader(w, http.StatusInternalServerError)
|
||||
goto oops
|
||||
}
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
if srpx.resp_error != nil {
|
||||
status_code = WriteEmptyRespHeader(w, srpx.resp_status_code)
|
||||
err = srpx.resp_error
|
||||
goto oops
|
||||
}
|
||||
|
||||
select {
|
||||
case <- srpx.done_chan:
|
||||
// anything to do?
|
||||
case <- req.Context().Done():
|
||||
// anything to do?
|
||||
// no default. block
|
||||
}
|
||||
|
||||
cts.pss.Send(MakeRpxStopPacket(srpx.id))
|
||||
return srpx.resp_status_code, nil
|
||||
|
||||
oops:
|
||||
if srpx != nil && start_sent { cts.pss.Send(MakeRpxStopPacket(srpx.id)) }
|
||||
return status_code, err
|
||||
}
|
||||
|
597
server.go
597
server.go
@ -10,6 +10,7 @@ import "log"
|
||||
import "net"
|
||||
import "net/http"
|
||||
import "net/netip"
|
||||
import "regexp"
|
||||
import "slices"
|
||||
import "strconv"
|
||||
import "strings"
|
||||
@ -32,8 +33,8 @@ const CTS_LIMIT int = 16384
|
||||
type PortId uint16
|
||||
const PORT_ID_MARKER string = "_"
|
||||
const HS_ID_CTL string = "ctl"
|
||||
const HS_ID_RPX string = "pxy"
|
||||
const HS_ID_PXY string = "rpx"
|
||||
const HS_ID_RPX string = "rpx"
|
||||
const HS_ID_PXY string = "pxy"
|
||||
const HS_ID_WPX string = "wpx"
|
||||
|
||||
type ServerConnMapByAddr map[net.Addr]*ServerConn
|
||||
@ -45,6 +46,7 @@ type ServerSvcPortMap map[PortId]ConnRouteId
|
||||
|
||||
type ServerRptyMap map[uint64]*ServerRpty
|
||||
type ServerRptyMapByWs map[*websocket.Conn]*ServerRpty
|
||||
type ServerRpxMap map[uint64]*ServerRpx
|
||||
|
||||
type ServerWpxResponseTransformer func(r *ServerRouteProxyInfo, resp *http.Response) io.Reader
|
||||
type ServerWpxForeignPortProxyMaker func(wpx_type string, port_id string) (*ServerRouteProxyInfo, error)
|
||||
@ -67,6 +69,9 @@ type ServerConfig struct {
|
||||
|
||||
RpxAddrs []string
|
||||
RpxTls *tls.Config
|
||||
RpxClientTokenAttrName string
|
||||
RpxClientTokenRegex *regexp.Regexp
|
||||
RpxClientTokenSubmatchIndex int
|
||||
|
||||
PxyAddrs []string
|
||||
PxyTls *tls.Config
|
||||
@ -166,12 +171,12 @@ type Server struct {
|
||||
peers atomic.Int64
|
||||
ssh_proxy_sessions atomic.Int64
|
||||
pty_sessions atomic.Int64
|
||||
rpty_sessions atomic.Int64
|
||||
}
|
||||
|
||||
wpx_resp_tf ServerWpxResponseTransformer
|
||||
wpx_foreign_port_proxy_maker ServerWpxForeignPortProxyMaker
|
||||
|
||||
|
||||
pty_user string
|
||||
pty_shell string
|
||||
xterm_html string
|
||||
@ -202,6 +207,10 @@ type ServerConn struct {
|
||||
rpty_map ServerRptyMap
|
||||
rpty_map_by_ws ServerRptyMapByWs
|
||||
|
||||
rpx_next_id uint64
|
||||
rpx_mtx sync.Mutex
|
||||
rpx_map ServerRpxMap
|
||||
|
||||
wg sync.WaitGroup
|
||||
stop_req atomic.Bool
|
||||
stop_chan chan bool
|
||||
@ -236,6 +245,20 @@ type ServerRpty struct {
|
||||
ws *websocket.Conn
|
||||
}
|
||||
|
||||
type ServerRpx struct {
|
||||
id uint64
|
||||
pr *io.PipeReader
|
||||
pw *io.PipeWriter
|
||||
br io.ReadCloser // body reader
|
||||
start_chan chan []byte
|
||||
done_chan chan bool
|
||||
br_chan chan bool
|
||||
|
||||
resp_status_code int
|
||||
resp_error error
|
||||
resp_done_chan chan bool
|
||||
}
|
||||
|
||||
type GuardedPacketStreamServer struct {
|
||||
mtx sync.Mutex
|
||||
//pss Hodu_PacketStreamServer
|
||||
@ -265,6 +288,16 @@ func (g *GuardedPacketStreamServer) Context() context.Context {
|
||||
|
||||
// ------------------------------------
|
||||
|
||||
func (rpty *ServerRpty) ReqStop() {
|
||||
rpty.ws.Close()
|
||||
}
|
||||
|
||||
func (rpx *ServerRpx) ReqStop() {
|
||||
rpx.done_chan <- true
|
||||
rpx.pw.Close()
|
||||
}
|
||||
// ------------------------------------
|
||||
|
||||
func NewServerRoute(cts *ServerConn, id RouteId, option RouteOption, ptc_addr string, ptc_name string, svc_requested_addr string, svc_permitted_net string) (*ServerRoute, error) {
|
||||
var r ServerRoute
|
||||
var l *net.TCPListener
|
||||
@ -658,6 +691,7 @@ func (cts *ServerConn) ReqStopAllServerRoutes() {
|
||||
cts.route_mtx.Unlock()
|
||||
}
|
||||
|
||||
// Rpty
|
||||
func (cts *ServerConn) StartRpty(ws *websocket.Conn) (*ServerRpty, error) {
|
||||
var ok bool
|
||||
var start_id uint64
|
||||
@ -707,11 +741,11 @@ func (cts *ServerConn) StartRpty(ws *websocket.Conn) (*ServerRpty, error) {
|
||||
return nil , err
|
||||
}
|
||||
|
||||
cts.S.stats.rpty_sessions.Add(1)
|
||||
return rpty, nil
|
||||
}
|
||||
|
||||
func (cts *ServerConn) StopRpty(ws *websocket.Conn) error {
|
||||
|
||||
// called by the websocket handler.
|
||||
var rpty *ServerRpty
|
||||
var id uint64
|
||||
@ -721,7 +755,9 @@ func (cts *ServerConn) StopRpty(ws *websocket.Conn) error {
|
||||
cts.rpty_mtx.Lock()
|
||||
rpty, ok = cts.rpty_map_by_ws[ws]
|
||||
if !ok {
|
||||
return fmt.Errorf("unknown ws connection for rpty - %v", ws.RemoteAddr())
|
||||
cts.rpty_mtx.Unlock()
|
||||
cts.S.log.Write(cts.Sid, LOG_ERROR, "Unknown websocket connection for rpty - websocket %v", ws.RemoteAddr())
|
||||
return fmt.Errorf("unknown websocket connection for rpty - %v", ws.RemoteAddr())
|
||||
}
|
||||
|
||||
id = rpty.id
|
||||
@ -730,14 +766,24 @@ func (cts *ServerConn) StopRpty(ws *websocket.Conn) error {
|
||||
// send the stop request to the client side
|
||||
err = cts.pss.Send(MakeRptyStopPacket(id, ""))
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to send stop rpty request to client - %s", err.Error())
|
||||
cts.S.log.Write(cts.Sid, LOG_ERROR, "Failed to send RPTY_STOP(%d) for server %s websocket %v - %s", id, cts.RemoteAddr, ws.RemoteAddr(), err.Error())
|
||||
// carry on
|
||||
}
|
||||
|
||||
// delete the rpty entry from the maps as the websocket
|
||||
// handler is ending
|
||||
cts.rpty_mtx.Lock()
|
||||
delete(cts.rpty_map, id)
|
||||
delete(cts.rpty_map_by_ws, ws)
|
||||
cts.rpty_mtx.Unlock()
|
||||
cts.S.stats.rpty_sessions.Add(-1)
|
||||
|
||||
cts.S.log.Write(cts.Sid, LOG_INFO, "Stopped rpty(%d) for server %s websocket %vs", id, cts.RemoteAddr, ws.RemoteAddr())
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cts *ServerConn) StopRptyWsById(id uint64, msg string) error {
|
||||
// called this when the stop requested comes from the client
|
||||
// call this when the stop requested comes from the client.
|
||||
// abort the websocket side.
|
||||
|
||||
var rpty *ServerRpty
|
||||
@ -746,11 +792,12 @@ func (cts *ServerConn) StopRptyWsById(id uint64, msg string) error {
|
||||
cts.rpty_mtx.Lock()
|
||||
rpty, ok = cts.rpty_map[id]
|
||||
if !ok {
|
||||
cts.rpty_mtx.Unlock()
|
||||
return fmt.Errorf("unknown rpty id %d", id)
|
||||
}
|
||||
rpty.ws.Close()
|
||||
cts.rpty_mtx.Unlock()
|
||||
|
||||
rpty.ReqStop()
|
||||
cts.S.log.Write(cts.Sid, LOG_INFO, "Stopped rpty(%d) for %s - %s", id, cts.RemoteAddr, msg)
|
||||
return nil
|
||||
}
|
||||
@ -764,6 +811,7 @@ func (cts *ServerConn) WriteRpty(ws *websocket.Conn, data []byte) error {
|
||||
cts.rpty_mtx.Lock()
|
||||
rpty, ok = cts.rpty_map_by_ws[ws]
|
||||
if !ok {
|
||||
cts.rpty_mtx.Unlock()
|
||||
return fmt.Errorf("unknown ws connection for rpty - %v", ws.RemoteAddr())
|
||||
}
|
||||
|
||||
@ -787,6 +835,7 @@ func (cts *ServerConn) WriteRptySize(ws *websocket.Conn, data []byte) error {
|
||||
cts.rpty_mtx.Lock()
|
||||
rpty, ok = cts.rpty_map_by_ws[ws]
|
||||
if !ok {
|
||||
cts.rpty_mtx.Unlock()
|
||||
return fmt.Errorf("unknown ws connection for rpty size - %v", ws.RemoteAddr())
|
||||
}
|
||||
|
||||
@ -812,33 +861,16 @@ func (cts *ServerConn) ReadRptyAndWriteWs(id uint64, data []byte) error {
|
||||
cts.rpty_mtx.Unlock()
|
||||
return fmt.Errorf("unknown rpty id - %d", id)
|
||||
}
|
||||
cts.rpty_mtx.Unlock()
|
||||
|
||||
err = send_ws_data_for_xterm(rpty.ws, "iov", string(data))
|
||||
if err != nil {
|
||||
cts.rpty_mtx.Unlock()
|
||||
return fmt.Errorf("failed to write rpty data(%d) to ws - %s", id, err.Error())
|
||||
}
|
||||
|
||||
cts.rpty_mtx.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
func (cts *ServerConn) ReportPacket(route_id RouteId, pts_id PeerId, packet_type PACKET_KIND, event_data interface{}) error {
|
||||
var r *ServerRoute
|
||||
var ok bool
|
||||
|
||||
cts.route_mtx.Lock()
|
||||
r, ok = cts.route_map[route_id]
|
||||
if !ok {
|
||||
cts.route_mtx.Unlock()
|
||||
return fmt.Errorf("non-existent route id - %d", route_id)
|
||||
}
|
||||
cts.route_mtx.Unlock()
|
||||
|
||||
return r.ReportPacket(pts_id, packet_type, event_data)
|
||||
}
|
||||
|
||||
func (cts *ServerConn) HandleRptyEvent(packet_type PACKET_KIND, evt *RptyEvent) error {
|
||||
switch packet_type {
|
||||
case PACKET_KIND_RPTY_STOP:
|
||||
@ -853,6 +885,80 @@ func (cts *ServerConn) HandleRptyEvent(packet_type PACKET_KIND, evt *RptyEvent)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Rpx
|
||||
func (cts *ServerConn) StartRpxWebById(srpx* ServerRpx, id uint64, data []byte) error {
|
||||
// pass the initial response to code in server-rpx.go
|
||||
srpx.start_chan <- data
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cts *ServerConn) StopRpxWebById(srpx* ServerRpx, id uint64) error {
|
||||
srpx.ReqStop()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cts *ServerConn) WroteRpxWebById(srpx* ServerRpx, id uint64, data []byte) error {
|
||||
var err error
|
||||
_, err = srpx.pw.Write(data)
|
||||
if err != nil {
|
||||
cts.S.log.Write(cts.Sid, LOG_ERROR, "Failed to write rpx data(%d) to rpx pipe - %s", id, err.Error())
|
||||
srpx.ReqStop()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (cts *ServerConn) EofRpxWebById(srpx* ServerRpx, id uint64) error {
|
||||
srpx.ReqStop()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cts *ServerConn) HandleRpxEvent(packet_type PACKET_KIND, evt *RpxEvent) error {
|
||||
var ok bool
|
||||
var rpx* ServerRpx
|
||||
|
||||
cts.rpx_mtx.Lock()
|
||||
rpx, ok = cts.rpx_map[evt.Id]
|
||||
if !ok {
|
||||
cts.rpx_mtx.Unlock()
|
||||
return fmt.Errorf("unknown rpx id - %v", evt.Id)
|
||||
}
|
||||
cts.rpx_mtx.Unlock()
|
||||
|
||||
switch packet_type {
|
||||
case PACKET_KIND_RPX_START:
|
||||
return cts.StartRpxWebById(rpx, evt.Id, evt.Data)
|
||||
|
||||
case PACKET_KIND_RPX_STOP:
|
||||
// stop requested from the server
|
||||
return cts.StopRpxWebById(rpx, evt.Id)
|
||||
|
||||
case PACKET_KIND_RPX_EOF:
|
||||
return cts.EofRpxWebById(rpx, evt.Id)
|
||||
|
||||
case PACKET_KIND_RPX_DATA:
|
||||
return cts.WroteRpxWebById(rpx, evt.Id, evt.Data)
|
||||
}
|
||||
|
||||
// ignore other packet types
|
||||
return nil
|
||||
}
|
||||
|
||||
// Rpx
|
||||
func (cts *ServerConn) ReportPacket(route_id RouteId, pts_id PeerId, packet_type PACKET_KIND, event_data interface{}) error {
|
||||
var r *ServerRoute
|
||||
var ok bool
|
||||
|
||||
cts.route_mtx.Lock()
|
||||
r, ok = cts.route_map[route_id]
|
||||
if !ok {
|
||||
cts.route_mtx.Unlock()
|
||||
return fmt.Errorf("non-existent route id - %d", route_id)
|
||||
}
|
||||
cts.route_mtx.Unlock()
|
||||
|
||||
return r.ReportPacket(pts_id, packet_type, event_data)
|
||||
}
|
||||
|
||||
func (cts *ServerConn) receive_from_stream(wg *sync.WaitGroup) {
|
||||
var pkt *Packet
|
||||
var err error
|
||||
@ -1055,10 +1161,30 @@ func (cts *ServerConn) receive_from_stream(wg *sync.WaitGroup) {
|
||||
err = cts.HandleRptyEvent(pkt.Kind, x.RptyEvt)
|
||||
if err != nil {
|
||||
cts.S.log.Write(cts.Sid, LOG_ERROR, "Failed to handle %s event for rpty(%d) from %s - %s", pkt.Kind.String(), x.RptyEvt.Id, cts.RemoteAddr, err.Error())
|
||||
} else {
|
||||
} else {
|
||||
cts.S.log.Write(cts.Sid, LOG_ERROR, "Handled %s event for rpty(%d) from %s", pkt.Kind.String(), x.RptyEvt.Id, cts.RemoteAddr)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
cts.S.log.Write(cts.Sid, LOG_ERROR, "Invalid %s packet from %s", pkt.Kind.String(), cts.RemoteAddr)
|
||||
}
|
||||
|
||||
case PACKET_KIND_RPX_START: // the client sends the response header using START
|
||||
fallthrough
|
||||
case PACKET_KIND_RPX_STOP:
|
||||
fallthrough
|
||||
case PACKET_KIND_RPX_EOF:
|
||||
fallthrough
|
||||
case PACKET_KIND_RPX_DATA:
|
||||
var x *Packet_RpxEvt
|
||||
var ok bool
|
||||
x, ok = pkt.U.(*Packet_RpxEvt)
|
||||
if ok {
|
||||
err = cts.HandleRpxEvent(pkt.Kind, x.RpxEvt)
|
||||
if err != nil {
|
||||
cts.S.log.Write(cts.Sid, LOG_ERROR, "Failed to handle %s event for rpx(%d) from %s - %s", pkt.Kind.String(), x.RpxEvt.Id, cts.RemoteAddr, err.Error())
|
||||
} else {
|
||||
cts.S.log.Write(cts.Sid, LOG_ERROR, "Handled %s event for rpx(%d) from %s", pkt.Kind.String(), x.RpxEvt.Id, cts.RemoteAddr)
|
||||
}
|
||||
} else {
|
||||
cts.S.log.Write(cts.Sid, LOG_ERROR, "Invalid %s packet from %s", pkt.Kind.String(), cts.RemoteAddr)
|
||||
}
|
||||
@ -1066,17 +1192,26 @@ func (cts *ServerConn) receive_from_stream(wg *sync.WaitGroup) {
|
||||
}
|
||||
|
||||
done:
|
||||
|
||||
// arrange to break all rpty resources
|
||||
cts.rpty_mtx.Lock()
|
||||
if len(cts.rpty_map) > 0 {
|
||||
var rpty *ServerRpty
|
||||
for _, rpty = range cts.rpty_map {
|
||||
rpty.ws.Close()
|
||||
rpty.ReqStop()
|
||||
}
|
||||
}
|
||||
cts.rpty_mtx.Unlock()
|
||||
|
||||
// arrange to break all rpx resources
|
||||
cts.rpx_mtx.Lock()
|
||||
if len(cts.rpx_map) > 0 {
|
||||
var rpx *ServerRpx
|
||||
for _, rpx = range cts.rpx_map {
|
||||
rpx.ReqStop()
|
||||
}
|
||||
}
|
||||
cts.rpx_mtx.Unlock()
|
||||
|
||||
cts.S.log.Write(cts.Sid, LOG_INFO, "RPC stream receiver ended")
|
||||
}
|
||||
|
||||
@ -1109,7 +1244,7 @@ func (cts *ServerConn) RunTask(wg *sync.WaitGroup) {
|
||||
cts.route_wg.Add(1)
|
||||
|
||||
// start the loop inside a goroutine so that route_wg counter
|
||||
// is likely to be greater than 1 what Wait() is called.
|
||||
// is likely to be greater than 1 when Wait() is called.
|
||||
go func() {
|
||||
waiting_loop:
|
||||
for {
|
||||
@ -1145,11 +1280,22 @@ func (cts *ServerConn) RunTask(wg *sync.WaitGroup) {
|
||||
func (cts *ServerConn) ReqStop() {
|
||||
if cts.stop_req.CompareAndSwap(false, true) {
|
||||
var r *ServerRoute
|
||||
var rpty *ServerRpty
|
||||
var srpx *ServerRpx
|
||||
|
||||
cts.route_mtx.Lock()
|
||||
for _, r = range cts.route_map { r.ReqStop() }
|
||||
cts.route_mtx.Unlock()
|
||||
|
||||
cts.rpty_mtx.Lock()
|
||||
for _, rpty = range cts.rpty_map { rpty.ReqStop() }
|
||||
cts.rpty_mtx.Unlock()
|
||||
|
||||
cts.rpx_mtx.Lock()
|
||||
for _, srpx = range cts.rpx_map { srpx.ReqStop() }
|
||||
cts.rpx_mtx.Unlock()
|
||||
|
||||
|
||||
// there is no good way to break a specific connection client to
|
||||
// the grpc server. while the global grpc server is closed in
|
||||
// ReqStop() for Server, the individuation connection is closed
|
||||
@ -1359,13 +1505,14 @@ func unaryInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServ
|
||||
|
||||
type server_http_log_writer struct {
|
||||
svr *Server
|
||||
depth int
|
||||
}
|
||||
|
||||
func (hlw *server_http_log_writer) Write(p []byte) (n int, err error) {
|
||||
// the standard http.Server always requires *log.Logger
|
||||
// use this iowriter to create a logger to pass it to the http server.
|
||||
// since this is another log write wrapper, give adjustment value
|
||||
hlw.svr.log.WriteWithCallDepth("", LOG_INFO, +1, string(p))
|
||||
hlw.svr.log.WriteWithCallDepth("", LOG_INFO, hlw.depth, string(p))
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
@ -1422,13 +1569,13 @@ func (s *Server) WrapHttpHandler(handler ServerHttpHandler) http.Handler {
|
||||
WriteEmptyRespHeader(w, status_code)
|
||||
}
|
||||
}
|
||||
time_taken = time.Now().Sub(start_time)
|
||||
time_taken = time.Since(start_time) // time.Now().Sub(start_time)
|
||||
|
||||
if status_code > 0 {
|
||||
if err != nil {
|
||||
s.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s %d %.9f - %s", req.RemoteAddr, req.Method, get_raw_url_path(req), status_code, time_taken.Seconds(), err.Error())
|
||||
s.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s %d %.9f - %s", req.RemoteAddr, req.Method, req.RequestURI, status_code, time_taken.Seconds(), err.Error())
|
||||
} else {
|
||||
s.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s %d %.9f", req.RemoteAddr, req.Method, get_raw_url_path(req), status_code, time_taken.Seconds())
|
||||
s.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s %d %.9f", req.RemoteAddr, req.Method, req.RequestURI, status_code, time_taken.Seconds())
|
||||
}
|
||||
}
|
||||
})
|
||||
@ -1440,7 +1587,7 @@ func (s *Server) SafeWrapWebsocketHandler(handler websocket.Handler) http.Handle
|
||||
!strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade") {
|
||||
var status_code int
|
||||
status_code = WriteEmptyRespHeader(w, http.StatusBadRequest)
|
||||
s.log.Write("", LOG_INFO, "[%s] %s %s %d[non-websocket]", req.RemoteAddr, req.Method, get_raw_url_path(req), status_code)
|
||||
s.log.Write("", LOG_INFO, "[%s] %s %s %d[non-websocket]", req.RemoteAddr, req.Method, req.RequestURI, status_code)
|
||||
return
|
||||
}
|
||||
handler.ServeHTTP(w, req)
|
||||
@ -1454,21 +1601,19 @@ func (s *Server) WrapWebsocketHandler(handler ServerWebsocketHandler) websocket.
|
||||
var start_time time.Time
|
||||
var time_taken time.Duration
|
||||
var req *http.Request
|
||||
var raw_url_path string
|
||||
|
||||
req = ws.Request()
|
||||
raw_url_path = get_raw_url_path(req)
|
||||
s.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s [ws]", req.RemoteAddr, req.Method, raw_url_path)
|
||||
s.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s [ws]", req.RemoteAddr, req.Method, req.RequestURI)
|
||||
|
||||
start_time = time.Now()
|
||||
status_code, err = handler.ServeWebsocket(ws)
|
||||
time_taken = time.Now().Sub(start_time)
|
||||
time_taken = time.Since(start_time) // time.Now().Sub(start_time)
|
||||
|
||||
if status_code > 0 {
|
||||
if err != nil {
|
||||
s.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s [ws] %d %.9f - %s", req.RemoteAddr, req.Method, raw_url_path, status_code, time_taken.Seconds(), err.Error())
|
||||
s.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s [ws] %d %.9f - %s", req.RemoteAddr, req.Method, req.RequestURI, status_code, time_taken.Seconds(), err.Error())
|
||||
} else {
|
||||
s.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s [ws] %d %.9f", req.RemoteAddr, req.Method, raw_url_path, status_code, time_taken.Seconds())
|
||||
s.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s [ws] %d %.9f", req.RemoteAddr, req.Method, req.RequestURI, status_code, time_taken.Seconds())
|
||||
}
|
||||
}
|
||||
})
|
||||
@ -1479,7 +1624,6 @@ func NewServer(ctx context.Context, name string, logger Logger, cfg *ServerConfi
|
||||
var l *net.TCPListener
|
||||
var rpcaddr *net.TCPAddr
|
||||
var addr string
|
||||
var gl *net.TCPListener
|
||||
var i int
|
||||
var hs_log *log.Logger
|
||||
var opts []grpc.ServerOption
|
||||
@ -1536,7 +1680,7 @@ func NewServer(ctx context.Context, name string, logger Logger, cfg *ServerConfi
|
||||
|
||||
// ---------------------------------------------------------
|
||||
|
||||
hs_log = log.New(&server_http_log_writer{svr: &s}, "", 0)
|
||||
hs_log = log.New(&server_http_log_writer{svr: &s, depth: +2}, "", 0)
|
||||
|
||||
// ---------------------------------------------------------
|
||||
|
||||
@ -1626,7 +1770,8 @@ func NewServer(ctx context.Context, name string, logger Logger, cfg *ServerConfi
|
||||
s.ctl[i] = &http.Server{
|
||||
Addr: cfg.CtlAddrs[i],
|
||||
Handler: s.ctl_mux,
|
||||
TLSConfig: s.Cfg.CtlTls,
|
||||
// race condition issues without cloning. the http package modifies some fields in the configuration object
|
||||
TLSConfig: cfg.CtlTls.Clone(),
|
||||
ErrorLog: hs_log,
|
||||
// TODO: more settings
|
||||
}
|
||||
@ -1638,12 +1783,11 @@ func NewServer(ctx context.Context, name string, logger Logger, cfg *ServerConfi
|
||||
s.rpx_mux.Handle("/", s.WrapHttpHandler(&server_rpx{ S: &s, Id: HS_ID_RPX }))
|
||||
|
||||
s.rpx = make([]*http.Server, len(cfg.RpxAddrs))
|
||||
|
||||
for i = 0; i < len(cfg.RpxAddrs); i++ {
|
||||
s.rpx[i] = &http.Server{
|
||||
Addr: cfg.RpxAddrs[i],
|
||||
Handler: s.rpx_mux,
|
||||
TLSConfig: cfg.RpxTls,
|
||||
TLSConfig: cfg.RpxTls.Clone(),
|
||||
ErrorLog: hs_log,
|
||||
// TODO: more settings
|
||||
}
|
||||
@ -1696,7 +1840,7 @@ func NewServer(ctx context.Context, name string, logger Logger, cfg *ServerConfi
|
||||
s.pxy[i] = &http.Server{
|
||||
Addr: cfg.PxyAddrs[i],
|
||||
Handler: s.pxy_mux,
|
||||
TLSConfig: cfg.PxyTls,
|
||||
TLSConfig: cfg.PxyTls.Clone(),
|
||||
ErrorLog: hs_log,
|
||||
// TODO: more settings
|
||||
}
|
||||
@ -1746,7 +1890,7 @@ func NewServer(ctx context.Context, name string, logger Logger, cfg *ServerConfi
|
||||
s.wpx[i] = &http.Server{
|
||||
Addr: cfg.WpxAddrs[i],
|
||||
Handler: s.wpx_mux,
|
||||
TLSConfig: cfg.WpxTls,
|
||||
TLSConfig: cfg.WpxTls.Clone(),
|
||||
ErrorLog: hs_log,
|
||||
}
|
||||
}
|
||||
@ -1757,11 +1901,11 @@ func NewServer(ctx context.Context, name string, logger Logger, cfg *ServerConfi
|
||||
s.stats.peers.Store(0)
|
||||
s.stats.ssh_proxy_sessions.Store(0)
|
||||
s.stats.pty_sessions.Store(0)
|
||||
s.stats.rpty_sessions.Store(0)
|
||||
|
||||
return &s, nil
|
||||
|
||||
oops:
|
||||
if gl != nil { gl.Close() }
|
||||
for _, l = range s.rpc { l.Close() }
|
||||
s.rpc = make([]*net.TCPListener, 0)
|
||||
return nil, err
|
||||
@ -1840,16 +1984,10 @@ func (s *Server) RunTask(wg *sync.WaitGroup) {
|
||||
go s.run_grpc_server(idx, &s.rpc_wg)
|
||||
}
|
||||
|
||||
// most work is done by in separate goroutines (s.run_grp_server)
|
||||
// this loop serves as a placeholder to prevent the logic flow from
|
||||
// most work is done in separate goroutines (s.run_grp_server)
|
||||
// this read on the stop channel serves as a placeholder to prevent the logic flow from
|
||||
// descening down to s.ReqStop()
|
||||
task_loop:
|
||||
for {
|
||||
select {
|
||||
case <-s.stop_chan:
|
||||
break task_loop
|
||||
}
|
||||
}
|
||||
<-s.stop_chan
|
||||
|
||||
s.ReqStop()
|
||||
|
||||
@ -1863,8 +2001,52 @@ task_loop:
|
||||
s.rpc_svr.Stop()
|
||||
}
|
||||
|
||||
func (s *Server) RunCtlTask(wg *sync.WaitGroup) {
|
||||
func (s* Server) run_single_ctl_server(i int, cs *http.Server, wg* sync.WaitGroup) {
|
||||
var l net.Listener
|
||||
var err error
|
||||
|
||||
defer wg.Done()
|
||||
|
||||
s.log.Write("", LOG_INFO, "Control channel[%d] started on %s", i, cs.Addr)
|
||||
|
||||
if s.stop_req.Load() == false {
|
||||
// defeat hard-coded "tcp" in ListenAndServe() and ListenAndServeTLS()
|
||||
// err = cs.ListenAndServe()
|
||||
// err = cs.ListenAndServeTLS("", "")
|
||||
l, err = net.Listen(TcpAddrStrClass(cs.Addr), cs.Addr)
|
||||
if err == nil {
|
||||
if s.stop_req.Load() == false {
|
||||
var node *list.Element
|
||||
|
||||
s.ctl_addrs_mtx.Lock()
|
||||
node = s.ctl_addrs.PushBack(l.Addr().(*net.TCPAddr))
|
||||
s.ctl_addrs_mtx.Unlock()
|
||||
|
||||
if s.Cfg.CtlTls == nil {
|
||||
err = cs.Serve(l)
|
||||
} else {
|
||||
err = cs.ServeTLS(l, "", "") // s.Cfg.CtlTls must provide a certificate and a key
|
||||
}
|
||||
|
||||
s.ctl_addrs_mtx.Lock()
|
||||
s.ctl_addrs.Remove(node)
|
||||
s.ctl_addrs_mtx.Unlock()
|
||||
} else {
|
||||
err = fmt.Errorf("stop requested")
|
||||
}
|
||||
l.Close()
|
||||
}
|
||||
} else {
|
||||
err = fmt.Errorf("stop requested")
|
||||
}
|
||||
if errors.Is(err, http.ErrServerClosed) {
|
||||
s.log.Write("", LOG_INFO, "Control channel[%d] ended", i)
|
||||
} else {
|
||||
s.log.Write("", LOG_ERROR, "Control channel[%d] error - %s", i, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) RunCtlTask(wg *sync.WaitGroup) {
|
||||
var ctl *http.Server
|
||||
var idx int
|
||||
var l_wg sync.WaitGroup
|
||||
@ -1873,54 +2055,55 @@ func (s *Server) RunCtlTask(wg *sync.WaitGroup) {
|
||||
|
||||
for idx, ctl = range s.ctl {
|
||||
l_wg.Add(1)
|
||||
go func(i int, cs *http.Server) {
|
||||
var l net.Listener
|
||||
|
||||
s.log.Write("", LOG_INFO, "Control channel[%d] started on %s", i, cs.Addr)
|
||||
|
||||
if s.stop_req.Load() == false {
|
||||
// defeat hard-coded "tcp" in ListenAndServe() and ListenAndServeTLS()
|
||||
// err = cs.ListenAndServe()
|
||||
// err = cs.ListenAndServeTLS("", "")
|
||||
l, err = net.Listen(TcpAddrStrClass(cs.Addr), cs.Addr)
|
||||
if err == nil {
|
||||
if s.stop_req.Load() == false {
|
||||
var node *list.Element
|
||||
|
||||
s.ctl_addrs_mtx.Lock()
|
||||
node = s.ctl_addrs.PushBack(l.Addr().(*net.TCPAddr))
|
||||
s.ctl_addrs_mtx.Unlock()
|
||||
|
||||
if s.Cfg.CtlTls == nil {
|
||||
err = cs.Serve(l)
|
||||
} else {
|
||||
err = cs.ServeTLS(l, "", "") // s.Cfg.CtlTls must provide a certificate and a key
|
||||
}
|
||||
|
||||
s.ctl_addrs_mtx.Lock()
|
||||
s.ctl_addrs.Remove(node)
|
||||
s.ctl_addrs_mtx.Unlock()
|
||||
} else {
|
||||
err = fmt.Errorf("stop requested")
|
||||
}
|
||||
l.Close()
|
||||
}
|
||||
} else {
|
||||
err = fmt.Errorf("stop requested")
|
||||
}
|
||||
if errors.Is(err, http.ErrServerClosed) {
|
||||
s.log.Write("", LOG_INFO, "Control channel[%d] ended", i)
|
||||
} else {
|
||||
s.log.Write("", LOG_ERROR, "Control channel[%d] error - %s", i, err.Error())
|
||||
}
|
||||
l_wg.Done()
|
||||
}(idx, ctl)
|
||||
go s.run_single_ctl_server(idx, ctl, &l_wg);
|
||||
}
|
||||
l_wg.Wait()
|
||||
}
|
||||
|
||||
func (s *Server) RunRpxTask(wg *sync.WaitGroup) {
|
||||
func (s *Server) run_single_rpx_server(i int, cs *http.Server, wg* sync.WaitGroup) {
|
||||
var l net.Listener
|
||||
var err error
|
||||
|
||||
defer wg.Done()
|
||||
|
||||
s.log.Write("", LOG_INFO, "RPX channel[%d] started on %s", i, s.Cfg.RpxAddrs[i])
|
||||
|
||||
if s.stop_req.Load() == false {
|
||||
l, err = net.Listen(TcpAddrStrClass(cs.Addr), cs.Addr)
|
||||
if err == nil {
|
||||
if s.stop_req.Load() == false {
|
||||
var node *list.Element
|
||||
|
||||
s.rpx_addrs_mtx.Lock()
|
||||
node = s.rpx_addrs.PushBack(l.Addr().(*net.TCPAddr))
|
||||
s.rpx_addrs_mtx.Unlock()
|
||||
|
||||
if s.Cfg.RpxTls == nil { // TODO: change this
|
||||
err = cs.Serve(l)
|
||||
} else {
|
||||
err = cs.ServeTLS(l, "", "") // s.Cfg.RpxTls must provide a certificate and a key
|
||||
}
|
||||
|
||||
s.rpx_addrs_mtx.Lock()
|
||||
s.rpx_addrs.Remove(node)
|
||||
s.rpx_addrs_mtx.Unlock()
|
||||
} else {
|
||||
err = fmt.Errorf("stop requested")
|
||||
}
|
||||
l.Close()
|
||||
}
|
||||
} else {
|
||||
err = fmt.Errorf("stop requested")
|
||||
}
|
||||
if errors.Is(err, http.ErrServerClosed) {
|
||||
s.log.Write("", LOG_INFO, "RPX channel[%d] ended", i)
|
||||
} else {
|
||||
s.log.Write("", LOG_ERROR, "RPX channel[%d] error - %s", i, err.Error())
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (s *Server) RunRpxTask(wg *sync.WaitGroup) {
|
||||
var rpx *http.Server
|
||||
var idx int
|
||||
var l_wg sync.WaitGroup
|
||||
@ -1929,51 +2112,54 @@ func (s *Server) RunRpxTask(wg *sync.WaitGroup) {
|
||||
|
||||
for idx, rpx = range s.rpx {
|
||||
l_wg.Add(1)
|
||||
go func(i int, cs *http.Server) {
|
||||
var l net.Listener
|
||||
|
||||
s.log.Write("", LOG_INFO, "RPX channel[%d] started on %s", i, s.Cfg.RpxAddrs[i])
|
||||
|
||||
if s.stop_req.Load() == false {
|
||||
l, err = net.Listen(TcpAddrStrClass(cs.Addr), cs.Addr)
|
||||
if err == nil {
|
||||
if s.stop_req.Load() == false {
|
||||
var node *list.Element
|
||||
|
||||
s.rpx_addrs_mtx.Lock()
|
||||
node = s.rpx_addrs.PushBack(l.Addr().(*net.TCPAddr))
|
||||
s.rpx_addrs_mtx.Unlock()
|
||||
|
||||
if s.Cfg.RpxTls == nil { // TODO: change this
|
||||
err = cs.Serve(l)
|
||||
} else {
|
||||
err = cs.ServeTLS(l, "", "") // s.Cfg.RpxTls must provide a certificate and a key
|
||||
}
|
||||
|
||||
s.rpx_addrs_mtx.Lock()
|
||||
s.rpx_addrs.Remove(node)
|
||||
s.rpx_addrs_mtx.Unlock()
|
||||
} else {
|
||||
err = fmt.Errorf("stop requested")
|
||||
}
|
||||
l.Close()
|
||||
}
|
||||
} else {
|
||||
err = fmt.Errorf("stop requested")
|
||||
}
|
||||
if errors.Is(err, http.ErrServerClosed) {
|
||||
s.log.Write("", LOG_INFO, "RPX channel[%d] ended", i)
|
||||
} else {
|
||||
s.log.Write("", LOG_ERROR, "RPX channel[%d] error - %s", i, err.Error())
|
||||
}
|
||||
l_wg.Done()
|
||||
}(idx, rpx)
|
||||
go s.run_single_rpx_server(idx, rpx, &l_wg)
|
||||
}
|
||||
l_wg.Wait()
|
||||
}
|
||||
|
||||
func (s *Server) RunPxyTask(wg *sync.WaitGroup) {
|
||||
func (s *Server) run_single_pxy_server(i int, cs *http.Server, wg* sync.WaitGroup) {
|
||||
var l net.Listener
|
||||
var err error
|
||||
|
||||
defer wg.Done()
|
||||
|
||||
s.log.Write("", LOG_INFO, "Proxy channel[%d] started on %s", i, s.Cfg.PxyAddrs[i])
|
||||
|
||||
if s.stop_req.Load() == false {
|
||||
l, err = net.Listen(TcpAddrStrClass(cs.Addr), cs.Addr)
|
||||
if err == nil {
|
||||
if s.stop_req.Load() == false {
|
||||
var node *list.Element
|
||||
|
||||
s.pxy_addrs_mtx.Lock()
|
||||
node = s.pxy_addrs.PushBack(l.Addr().(*net.TCPAddr))
|
||||
s.pxy_addrs_mtx.Unlock()
|
||||
|
||||
if s.Cfg.PxyTls == nil { // TODO: change this
|
||||
err = cs.Serve(l)
|
||||
} else {
|
||||
err = cs.ServeTLS(l, "", "") // s.Cfg.PxyTls must provide a certificate and a key
|
||||
}
|
||||
|
||||
s.pxy_addrs_mtx.Lock()
|
||||
s.pxy_addrs.Remove(node)
|
||||
s.pxy_addrs_mtx.Unlock()
|
||||
} else {
|
||||
err = fmt.Errorf("stop requested")
|
||||
}
|
||||
l.Close()
|
||||
}
|
||||
} else {
|
||||
err = fmt.Errorf("stop requested")
|
||||
}
|
||||
if errors.Is(err, http.ErrServerClosed) {
|
||||
s.log.Write("", LOG_INFO, "Proxy channel[%d] ended", i)
|
||||
} else {
|
||||
s.log.Write("", LOG_ERROR, "Proxy channel[%d] error - %s", i, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) RunPxyTask(wg *sync.WaitGroup) {
|
||||
var pxy *http.Server
|
||||
var idx int
|
||||
var l_wg sync.WaitGroup
|
||||
@ -1982,51 +2168,55 @@ func (s *Server) RunPxyTask(wg *sync.WaitGroup) {
|
||||
|
||||
for idx, pxy = range s.pxy {
|
||||
l_wg.Add(1)
|
||||
go func(i int, cs *http.Server) {
|
||||
var l net.Listener
|
||||
|
||||
s.log.Write("", LOG_INFO, "Proxy channel[%d] started on %s", i, s.Cfg.PxyAddrs[i])
|
||||
|
||||
if s.stop_req.Load() == false {
|
||||
l, err = net.Listen(TcpAddrStrClass(cs.Addr), cs.Addr)
|
||||
if err == nil {
|
||||
if s.stop_req.Load() == false {
|
||||
var node *list.Element
|
||||
|
||||
s.pxy_addrs_mtx.Lock()
|
||||
node = s.pxy_addrs.PushBack(l.Addr().(*net.TCPAddr))
|
||||
s.pxy_addrs_mtx.Unlock()
|
||||
|
||||
if s.Cfg.PxyTls == nil { // TODO: change this
|
||||
err = cs.Serve(l)
|
||||
} else {
|
||||
err = cs.ServeTLS(l, "", "") // s.Cfg.PxyTls must provide a certificate and a key
|
||||
}
|
||||
|
||||
s.pxy_addrs_mtx.Lock()
|
||||
s.pxy_addrs.Remove(node)
|
||||
s.pxy_addrs_mtx.Unlock()
|
||||
} else {
|
||||
err = fmt.Errorf("stop requested")
|
||||
}
|
||||
l.Close()
|
||||
}
|
||||
} else {
|
||||
err = fmt.Errorf("stop requested")
|
||||
}
|
||||
if errors.Is(err, http.ErrServerClosed) {
|
||||
s.log.Write("", LOG_INFO, "Proxy channel[%d] ended", i)
|
||||
} else {
|
||||
s.log.Write("", LOG_ERROR, "Proxy channel[%d] error - %s", i, err.Error())
|
||||
}
|
||||
l_wg.Done()
|
||||
}(idx, pxy)
|
||||
go s.run_single_pxy_server(idx, pxy, &l_wg);
|
||||
}
|
||||
l_wg.Wait()
|
||||
}
|
||||
|
||||
func (s *Server) RunWpxTask(wg *sync.WaitGroup) {
|
||||
func (s *Server) run_single_wpx_server(i int, cs *http.Server, wg* sync.WaitGroup) {
|
||||
var l net.Listener
|
||||
var err error
|
||||
|
||||
defer wg.Done()
|
||||
|
||||
s.log.Write("", LOG_INFO, "Wpx channel[%d] started on %s", i, s.Cfg.WpxAddrs[i])
|
||||
|
||||
if s.stop_req.Load() == false {
|
||||
l, err = net.Listen(TcpAddrStrClass(cs.Addr), cs.Addr)
|
||||
if err == nil {
|
||||
if s.stop_req.Load() == false {
|
||||
var node *list.Element
|
||||
|
||||
s.wpx_addrs_mtx.Lock()
|
||||
node = s.wpx_addrs.PushBack(l.Addr().(*net.TCPAddr))
|
||||
s.wpx_addrs_mtx.Unlock()
|
||||
|
||||
if s.Cfg.WpxTls == nil {
|
||||
err = cs.Serve(l)
|
||||
} else {
|
||||
err = cs.ServeTLS(l, "", "") // s.Cfg.WpxTls must provide a certificate and a key
|
||||
}
|
||||
|
||||
s.wpx_addrs_mtx.Lock()
|
||||
s.wpx_addrs.Remove(node)
|
||||
s.wpx_addrs_mtx.Unlock()
|
||||
} else {
|
||||
err = fmt.Errorf("stop requested")
|
||||
}
|
||||
l.Close()
|
||||
}
|
||||
} else {
|
||||
err = fmt.Errorf("stop requested")
|
||||
}
|
||||
if errors.Is(err, http.ErrServerClosed) {
|
||||
s.log.Write("", LOG_INFO, "Wpx channel[%d] ended", i)
|
||||
} else {
|
||||
s.log.Write("", LOG_ERROR, "Wpx channel[%d] error - %s", i, err.Error())
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (s *Server) RunWpxTask(wg *sync.WaitGroup) {
|
||||
var wpx *http.Server
|
||||
var idx int
|
||||
var l_wg sync.WaitGroup
|
||||
@ -2035,45 +2225,7 @@ func (s *Server) RunWpxTask(wg *sync.WaitGroup) {
|
||||
|
||||
for idx, wpx = range s.wpx {
|
||||
l_wg.Add(1)
|
||||
go func(i int, cs *http.Server) {
|
||||
var l net.Listener
|
||||
|
||||
s.log.Write("", LOG_INFO, "Wpx channel[%d] started on %s", i, s.Cfg.WpxAddrs[i])
|
||||
|
||||
if s.stop_req.Load() == false {
|
||||
l, err = net.Listen(TcpAddrStrClass(cs.Addr), cs.Addr)
|
||||
if err == nil {
|
||||
if s.stop_req.Load() == false {
|
||||
var node *list.Element
|
||||
|
||||
s.wpx_addrs_mtx.Lock()
|
||||
node = s.wpx_addrs.PushBack(l.Addr().(*net.TCPAddr))
|
||||
s.wpx_addrs_mtx.Unlock()
|
||||
|
||||
if s.Cfg.WpxTls == nil {
|
||||
err = cs.Serve(l)
|
||||
} else {
|
||||
err = cs.ServeTLS(l, "", "") // s.Cfg.WpxTls must provide a certificate and a key
|
||||
}
|
||||
|
||||
s.wpx_addrs_mtx.Lock()
|
||||
s.wpx_addrs.Remove(node)
|
||||
s.wpx_addrs_mtx.Unlock()
|
||||
} else {
|
||||
err = fmt.Errorf("stop requested")
|
||||
}
|
||||
l.Close()
|
||||
}
|
||||
} else {
|
||||
err = fmt.Errorf("stop requested")
|
||||
}
|
||||
if errors.Is(err, http.ErrServerClosed) {
|
||||
s.log.Write("", LOG_INFO, "Wpx channel[%d] ended", i)
|
||||
} else {
|
||||
s.log.Write("", LOG_ERROR, "Wpx channel[%d] error - %s", i, err.Error())
|
||||
}
|
||||
l_wg.Done()
|
||||
}(idx, wpx)
|
||||
go s.run_single_wpx_server(idx, wpx, &l_wg)
|
||||
}
|
||||
l_wg.Wait()
|
||||
}
|
||||
@ -2141,6 +2293,7 @@ func (s *Server) AddNewServerConn(remote_addr *net.Addr, local_addr *net.Addr, p
|
||||
|
||||
cts.rpty_map = make(ServerRptyMap)
|
||||
cts.rpty_map_by_ws = make(ServerRptyMapByWs)
|
||||
cts.rpx_map = make(ServerRpxMap)
|
||||
|
||||
s.cts_mtx.Lock()
|
||||
|
||||
|
Reference in New Issue
Block a user