added code for rpx handling

2025-08-19 20:20:18 +09:00
parent 31a4223aab
commit 10c139e837
19 changed files with 1518 additions and 427 deletions

client.go (615 changed lines)

@@ -1,5 +1,7 @@
package hodu
import "bufio"
import "bytes"
import "container/list"
import "context"
import "crypto/tls"
@@ -40,6 +42,7 @@ type ClientRouteMap map[RouteId]*ClientRoute
type ClientPeerConnMap map[PeerId]*ClientPeerConn
type ClientPeerCancelFuncMap map[PeerId]context.CancelFunc
type ClientRptyMap map[uint64]*ClientRpty
type ClientRpxMap map[uint64]*ClientRpx
// --------------------------------------------------------------------
type ClientRouteConfig struct {
@@ -84,6 +87,10 @@ type ClientConfig struct {
PeerConnTmout time.Duration
Token string // to send to the server for identification
// default target for rpx
RpxTargetAddr string
RpxTargetTls *tls.Config
}
type ClientEventKind int
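For context, here is a minimal sketch of how a caller might fill in the two rpx fields added to ClientConfig above. The address and TLS values are illustrative, not part of this commit:

var cfg hodu.ClientConfig
cfg.RpxTargetAddr = "127.0.0.1:8080" // hypothetical default rpx backend
// a nil RpxTargetTls keeps the derived target URL on http://; a non-nil
// *tls.Config makes NewClient derive an https:// URL instead
// (see the rpx_target_url assignment further down in this diff)
cfg.RpxTargetTls = nil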
@@ -122,6 +129,10 @@ type Client struct {
ext_svcs []Service
rpc_tls *tls.Config
rpx_target_addr string
rpx_target_url string
rpx_target_tls *tls.Config
ctl_tls *tls.Config
ctl_addr []string
ctl_prefix string
@@ -155,6 +166,7 @@ type Client struct {
routes atomic.Int64
peers atomic.Int64
pty_sessions atomic.Int64
rpty_sessions atomic.Int64
}
pty_user string
@@ -176,6 +188,15 @@ type ClientRpty struct {
tty *os.File
}
type ClientRpx struct {
id uint64
pr *io.PipeReader
pw *io.PipeWriter
ctx context.Context
cancel context.CancelFunc
ws_conn net.Conn
}
// client connection to server
type ClientConn struct {
C *Client
@@ -209,6 +230,9 @@ type ClientConn struct {
rpty_mtx sync.Mutex
rpty_map ClientRptyMap
rpx_mtx sync.Mutex
rpx_map ClientRpxMap
stop_req atomic.Bool
stop_chan chan bool
@@ -294,6 +318,22 @@ func (g *GuardedPacketStreamClient) Context() context.Context {
return g.psc.Context()
}*/
// --------------------------------------------------------------------
func (rpty *ClientRpty) ReqStop() {
rpty.tty.Close()
rpty.cmd.Process.Kill()
}
func (rpx *ClientRpx) ReqStop() {
rpx.pr.Close()
rpx.pw.Close()
rpx.cancel()
if rpx.ws_conn != nil {
rpx.ws_conn.SetDeadline(time.Now()) // to make Read return immediately
}
}
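The SetDeadline(time.Now()) call above leans on a documented net.Conn behavior: a deadline that has already passed makes any pending or future Read/Write fail at once with an error wrapping os.ErrDeadlineExceeded. A standalone sketch with a hypothetical helper name:

// unblockConn forces a goroutine blocked in conn.Read to return
// immediately by installing an already-expired deadline.
func unblockConn(conn net.Conn) {
    conn.SetDeadline(time.Now()) // pending I/O fails with os.ErrDeadlineExceeded
}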
// --------------------------------------------------------------------
func NewClientRoute(cts *ClientConn, id RouteId, static bool, client_peer_addr string, client_peer_name string, server_peer_svc_addr string, server_peer_svc_net string, server_peer_option RouteOption, lifetime time.Duration) *ClientRoute {
var r ClientRoute
@@ -397,7 +437,7 @@ func (r *ClientRoute) ExtendLifetime(lifetime time.Duration) error {
r.lifetime_timer.Stop()
r.Lifetime = r.Lifetime + lifetime
expiry = r.LifetimeStart.Add(r.Lifetime)
-r.lifetime_timer.Reset(expiry.Sub(time.Now()))
+r.lifetime_timer.Reset(time.Until(expiry)) // expiry.Sub(time.Now())
if r.cts.C.route_persister != nil { r.cts.C.route_persister.Save(r.cts, r) }
r.lifetime_mtx.Unlock()
@@ -477,10 +517,8 @@ func (r *ClientRoute) RunTask(wg *sync.WaitGroup) {
break waiting_loop
}
} else {
-select {
-case <-r.stop_chan:
-break waiting_loop
-}
+<-r.stop_chan
+break waiting_loop
}
}
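The replacement above swaps a one-case select for a bare channel receive; with a single receive case and no default, the two forms block identically, so this is purely a simplification. As a sketch:

// a one-case select and a bare receive behave the same way
func waitForStop(stop_chan chan bool) {
    <-stop_chan // equivalent to: select { case <-stop_chan: }
}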
@@ -547,7 +585,7 @@ func (r *ClientRoute) ConnectToPeer(pts_id PeerId, route_option RouteOption, pts
defer wg.Done()
tmout = time.Duration(r.cts.C.ptc_tmout)
-if tmout <= 0 { tmout = 5 * time.Second} // TODO: make this configurable...
+if tmout <= 0 { tmout = 5 * time.Second } // TODO: make this configurable...
waitctx, cancel_wait = context.WithTimeout(r.cts.C.Ctx, tmout)
r.ptc_mtx.Lock()
r.ptc_cancel_map[pts_id] = cancel_wait
@@ -571,8 +609,8 @@ func (r *ClientRoute) ConnectToPeer(pts_id PeerId, route_option RouteOption, pts
real_conn, ok = conn.(*net.TCPConn)
if !ok {
r.cts.C.log.Write(r.cts.Sid, LOG_ERROR,
"Failed to get connection information to %s for route(%d,%d,%s,%s) - %s",
r.PeerAddr, r.Id, pts_id, pts_raddr, pts_laddr, err.Error())
"Failed to get connection information to %s for route(%d,%d,%s,%s)",
r.PeerAddr, r.Id, pts_id, pts_raddr, pts_laddr)
goto peer_aborted
}
@@ -825,6 +863,7 @@ func NewClientConn(c *Client, cfg *ClientConnConfig) *ClientConn {
cts.stop_chan = make(chan bool, 8)
cts.ptc_list = list.New()
cts.rpty_map = make(ClientRptyMap)
cts.rpx_map = make(ClientRpxMap)
for i, _ = range cts.cfg.Routes {
// override it to static regardless of the value passed in
@@ -833,7 +872,6 @@ func NewClientConn(c *Client, cfg *ClientConnConfig) *ClientConn {
// the actual connection to the server is established in the main task function
// The cts.conn, cts.hdc, cts.psc fields are left unassigned here.
return &cts
}
@@ -1031,7 +1069,8 @@ func (cts *ClientConn) add_client_routes(routes []ClientRouteConfig) error {
func (cts *ClientConn) disconnect_from_server(logmsg bool) {
if cts.conn != nil {
var r *ClientRoute
-var crp *ClientRpty
+var rpty *ClientRpty
+var rpx *ClientRpx
cts.discon_mtx.Lock()
@@ -1045,15 +1084,20 @@ func (cts *ClientConn) disconnect_from_server(logmsg bool) {
// arrange to clean up all rpty objects
cts.rpty_mtx.Lock()
-for _, crp = range cts.rpty_map {
-crp.tty.Close()
-crp.cmd.Process.Kill()
+for _, rpty = range cts.rpty_map {
+rpty.ReqStop()
// the loop in ReadRptyLoop() is expected to break on its own.
// let's not inform the server over this connection;
// the server should clean up by itself upon a connection error.
}
cts.rpty_mtx.Unlock()
cts.rpx_mtx.Lock()
for _, rpx = range cts.rpx_map {
rpx.ReqStop()
}
cts.rpx_mtx.Unlock()
// don't care about double closes when this function is called from both RunTask() and ReqStop()
cts.conn.Close()
@@ -1217,7 +1261,7 @@ start_over:
pkt, err = psc.Recv()
if err != nil {
-if status.Code(err) == codes.Canceled || errors.Is(err, net.ErrClosed) {
+if status.Code(err) == codes.Canceled || errors.Is(err, net.ErrClosed) || errors.Is(err, io.EOF) {
goto reconnect_to_server
} else {
cts.C.log.Write(cts.Sid, LOG_INFO, "Failed to receive packet from %s - %s", cts.remote_addr_p, err.Error())
@@ -1348,6 +1392,31 @@ start_over:
cts.C.log.Write(cts.Sid, LOG_ERROR, "Invalid %s event from %s", pkt.Kind.String(), cts.remote_addr_p)
}
case PACKET_KIND_RPX_START:
fallthrough
case PACKET_KIND_RPX_STOP:
fallthrough
case PACKET_KIND_RPX_DATA:
fallthrough
case PACKET_KIND_RPX_EOF:
var x *Packet_RpxEvt
var ok bool
x, ok = pkt.U.(*Packet_RpxEvt)
if ok {
err = cts.HandleRpxEvent(pkt.Kind, x.RpxEvt)
if err != nil {
cts.C.log.Write(cts.Sid, LOG_ERROR,
"Failed to handle %s event for rpx(%d) from %s - %s",
pkt.Kind.String(), x.RpxEvt.Id, cts.remote_addr_p, err.Error())
} else {
cts.C.log.Write(cts.Sid, LOG_DEBUG,
"Handled %s event for rpx(%d) from %s",
pkt.Kind.String(), x.RpxEvt.Id, cts.remote_addr_p)
}
} else {
cts.C.log.Write(cts.Sid, LOG_ERROR, "Invalid %s event from %s", pkt.Kind.String(), cts.remote_addr_p)
}
default:
// do nothing. ignore the rest
}
@@ -1390,7 +1459,7 @@ reconnect_to_server:
goto wait_for_termination
case <-slpctx.Done():
select {
-case <- cts.C.Ctx.Done():
+case <-cts.C.Ctx.Done():
// non-blocking check if the parent context of the sleep context is
// terminated too. if so, this is the normal termination case.
// this check seems redundant but the go runtime doesn't seem to guarantee
@@ -1421,20 +1490,34 @@ func (cts *ClientConn) ReportPacket(route_id RouteId, pts_id PeerId, packet_type
}
// rpty
func (cts *ClientConn) FindClientRptyById(id uint64) *ClientRpty {
var crp *ClientRpty
var ok bool
cts.rpty_mtx.Lock()
crp, ok = cts.rpty_map[id]
cts.rpty_mtx.Unlock()
if !ok { crp = nil }
return crp
}
func (cts *ClientConn) ReadRptyLoop(crp *ClientRpty, wg *sync.WaitGroup) {
var poll_fds []unix.PollFd
-var buf []byte
+var buf [2048]byte
var n int
var err error
defer wg.Done()
cts.C.stats.rpty_sessions.Add(1)
poll_fds = []unix.PollFd{
unix.PollFd{Fd: int32(crp.tty.Fd()), Events: unix.POLLIN},
}
-buf = make([]byte, 2048)
for {
n, err = unix.Poll(poll_fds, -1) // -1 means wait indefinitely
if err != nil {
@@ -1452,34 +1535,35 @@ func (cts *ClientConn) ReadRptyLoop(crp *ClientRpty, wg *sync.WaitGroup) {
}
if (poll_fds[0].Revents & unix.POLLIN) != 0 {
-n, err = crp.tty.Read(buf)
+n, err = crp.tty.Read(buf[:])
+if n > 0 {
+var err2 error
+err2 = cts.psc.Send(MakeRptyDataPacket(crp.id, buf[:n]))
+if err2 != nil {
+cts.C.log.Write(cts.Sid, LOG_DEBUG, "Failed to send rpty(%d) stdout to server - %s", crp.id, err2.Error())
+break
+}
+}
if err != nil {
if !errors.Is(err, io.EOF) {
cts.C.log.Write(cts.Sid, LOG_DEBUG, "Failed to read rpty(%d) stdout - %s", crp.id, err.Error())
}
break
}
-if n > 0 {
-err = cts.psc.Send(MakeRptyDataPacket(crp.id, buf[:n]))
-if err != nil {
-cts.C.log.Write(cts.Sid, LOG_DEBUG, "Failed to send rpty(%d) stdout to server - %s", crp.id, err.Error())
-break
-}
-}
}
}
cts.psc.Send(MakeRptyStopPacket(crp.id, ""))
cts.C.log.Write(cts.Sid, LOG_INFO, "Ending rpty(%d) read loop", crp.id)
-crp.tty.Close() // don't care about multiple closes
-crp.cmd.Process.Kill()
+crp.ReqStop()
crp.cmd.Wait()
cts.rpty_mtx.Lock()
delete(cts.rpty_map, crp.id)
cts.rpty_mtx.Unlock()
cts.C.stats.rpty_sessions.Add(-1)
}
func (cts *ClientConn) StartRpty(id uint64, wg *sync.WaitGroup) error {
@@ -1515,46 +1599,34 @@ func (cts *ClientConn) StartRpty(id uint64, wg *sync.WaitGroup) error {
func (cts *ClientConn) StopRpty(id uint64) error {
var crp *ClientRpty
-var ok bool
-cts.rpty_mtx.Lock()
-crp, ok = cts.rpty_map[id]
-if !ok {
-cts.rpty_mtx.Unlock()
+crp = cts.FindClientRptyById(id)
+if crp == nil {
return fmt.Errorf("unknown rpty id %d", id)
}
-crp.tty.Close() // to break ReadRptyLoop()
-crp.cmd.Process.Kill() // to process wait to be done by ReadRptyLoop()
-cts.rpty_mtx.Unlock()
+crp.ReqStop()
return nil
}
func (cts *ClientConn) WriteRpty(id uint64, data []byte) error {
var crp *ClientRpty
-var ok bool
-cts.rpty_mtx.Lock()
-crp, ok = cts.rpty_map[id]
-if !ok {
-cts.rpty_mtx.Unlock()
+crp = cts.FindClientRptyById(id)
+if crp == nil {
return fmt.Errorf("unknown rpty id %d", id)
}
crp.tty.Write(data)
-cts.rpty_mtx.Unlock()
return nil
}
func (cts *ClientConn) WriteRptySize(id uint64, data []byte) error {
var crp *ClientRpty
-var ok bool
var flds []string
-cts.rpty_mtx.Lock()
-crp, ok = cts.rpty_map[id]
-if !ok {
-cts.rpty_mtx.Unlock()
+crp = cts.FindClientRptyById(id)
+if crp == nil {
return fmt.Errorf("unknown rpty id %d", id)
}
@@ -1566,11 +1638,11 @@ func (cts *ClientConn) WriteRptySize(id uint64, data []byte) error {
cols, _ = strconv.Atoi(flds[1])
pts.Setsize(crp.tty, &pts.Winsize{Rows: uint16(rows), Cols: uint16(cols)})
}
-cts.rpty_mtx.Unlock()
return nil
}
func (cts *ClientConn) HandleRptyEvent(packet_type PACKET_KIND, evt *RptyEvent) error {
switch packet_type {
case PACKET_KIND_RPTY_START:
return cts.StartRpty(evt.Id, &cts.C.wg)
@@ -1589,6 +1661,329 @@ func (cts *ClientConn) HandleRptyEvent(packet_type PACKET_KIND, evt *RptyEvent)
return nil
}
// rpx
func (cts *ClientConn) FindClientRpxById(id uint64) *ClientRpx {
var crpx *ClientRpx
var ok bool
cts.rpx_mtx.Lock()
crpx, ok = cts.rpx_map[id]
cts.rpx_mtx.Unlock()
if !ok { crpx = nil }
return crpx
}
func (cts *ClientConn) RpxLoop(crpx *ClientRpx, data []byte, wg *sync.WaitGroup) {
var sc *bufio.Scanner
var line string
var flds []string
var buf [4096]byte
var req_meth string
var req_path string
//var req_proto string
var req *http.Request
var n int
var err error
defer wg.Done()
sc = bufio.NewScanner(bytes.NewReader(data))
sc.Scan()
line = sc.Text()
flds = strings.Fields(line)
if len(flds) < 3 {
cts.C.log.Write(cts.Sid, LOG_ERROR, "Invalid request line for rpx(%d) - %s", crpx.id, line)
goto done
}
// TODO: handle trailers...
req_meth = flds[0]
req_path = flds[1]
//req_proto = flds[2]
// create a request assuming it's a normal http request
req, err = http.NewRequestWithContext(crpx.ctx, req_meth, cts.C.rpx_target_url + req_path, crpx.pr)
if err != nil {
cts.C.log.Write(cts.Sid, LOG_ERROR, "failed to create request for rpx(%d) - %s", crpx.id, err.Error())
goto done
}
for sc.Scan() {
line = sc.Text()
if line == "" { break }
flds = strings.SplitN(line, ":", 2)
if len(flds) == 2 {
req.Header.Add(strings.TrimSpace(flds[0]), strings.TrimSpace(flds[1]))
}
}
err = sc.Err()
if err != nil {
cts.C.log.Write(cts.Sid, LOG_ERROR, "failed to parse request for rpx(%d) - %s", crpx.id, err.Error())
goto done
}
if strings.EqualFold(req.Header.Get("Upgrade"), "websocket") && strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade") {
// websocket
var done_chan chan struct{}
var conn net.Conn
var resp *http.Response
var r *bufio.Reader
if cts.C.rpx_target_tls != nil {
var dialer *tls.Dialer
dialer = &tls.Dialer{
NetDialer: &net.Dialer{},
Config: cts.C.rpx_target_tls,
}
conn, err = dialer.DialContext(crpx.ctx, "tcp", cts.C.rpx_target_addr) // TODO: no hard coding
} else {
var dialer *net.Dialer
dialer = &net.Dialer{}
conn, err = dialer.DialContext(crpx.ctx, "tcp", cts.C.rpx_target_addr) // TODO: no hard coding
}
if err != nil {
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to dial websocket for rpx(%d) - %s", crpx.id, err.Error())
goto done
}
defer conn.Close()
// TODO: make this atomic?
crpx.ws_conn = conn
// write the raw request line and headers as sent by the server.
// for the upgrade request, i assume no payload.
_, err = conn.Write(data)
if err != nil {
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to write websocket request for rpx(%d) - %s", crpx.id, err.Error())
goto done
}
r = bufio.NewReader(conn)
resp, err = http.ReadResponse(r, req)
if err != nil {
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to write websocket response for rpx(%d) - %s", crpx.id, err.Error())
goto done
}
defer resp.Body.Close()
err = cts.psc.Send(MakeRpxStartPacket(crpx.id, get_http_resp_line_and_headers(resp)))
if err != nil {
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to send rpx(%d) WebSocket headers to server - %s", crpx.id, err.Error())
goto done
}
if resp.StatusCode != http.StatusSwitchingProtocols {
// websocket upgrade failed. let the code jump to the done
// label to skip reading from the pipe. the server side
// has the code to ensure no content-length. if the upgrade
// fails, the pipe read below would be pending forever as the
// server side doesn't send data and nothing feeds the pipe.
cts.C.log.Write(cts.Sid, LOG_INFO, "Protocol switching failed for rpx(%d)", crpx.id)
goto done
}
// unlike in the normal request path, the pipe is not read
// until the initial protocol-switching response has been received.
wg.Add(1)
done_chan = make(chan struct{}, 5)
go func() {
var buf [4096]byte
var n int
var err error
defer wg.Done()
for {
n, err = crpx.pr.Read(buf[:])
if n > 0 {
var err2 error
_, err2 = conn.Write(buf[:n])
if err2 != nil {
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to write websocket for rpx(%d) - %s", crpx.id, err2.Error())
break
}
}
if err != nil {
if errors.Is(err, io.EOF) { break }
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to read pipe for rpx(%d) - %s", crpx.id, err.Error())
break
}
}
done_chan <- struct{}{}
}()
for {
n, err = conn.Read(buf[:])
if n > 0 {
var err2 error
err2 = cts.psc.Send(MakeRpxDataPacket(crpx.id, buf[:n]))
if err2 != nil {
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to send rpx(%d) data to server - %s", crpx.id, err2.Error())
break
}
}
if err != nil {
if errors.Is(err, io.EOF) {
cts.psc.Send(MakeRpxEofPacket(crpx.id))
cts.C.log.Write(cts.Sid, LOG_DEBUG, "WebSocket rpx(%d) closed by server", crpx.id)
break
}
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to read WebSocket rpx(%d) - %s", crpx.id, err.Error())
break
}
}
// wait until the pipe reading (from the server side) goroutine is over
<-done_chan
} else {
var tr *http.Transport
var resp *http.Response
tr = &http.Transport {
TLSClientConfig: cts.C.rpx_target_tls,
}
resp, err = tr.RoundTrip(req)
if err != nil {
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to send rpx(%d) request - %s", crpx.id, err.Error())
goto done
}
defer resp.Body.Close()
err = cts.psc.Send(MakeRpxStartPacket(crpx.id, get_http_resp_line_and_headers(resp)))
if err != nil {
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to send rpx(%d) status and headers to server - %s", crpx.id, err.Error())
goto done
}
for {
n, err = resp.Body.Read(buf[:])
if n > 0 {
var err2 error
err2 = cts.psc.Send(MakeRpxDataPacket(crpx.id, buf[:n]))
if err2 != nil {
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to send rpx(%d) data to server - %s", crpx.id, err2.Error())
break
}
}
if err != nil {
if errors.Is(err, io.EOF) {
break
}
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to read response body for rpx(%d) - %s", crpx.id, err.Error())
break
}
}
}
done:
err = cts.psc.Send(MakeRpxStopPacket(crpx.id))
if err != nil {
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to send rpx(%d) stp to server - %s", crpx.id, err.Error())
}
cts.C.log.Write(cts.Sid, LOG_INFO, "Ending rpx(%d) read loop", crpx.id)
crpx.ReqStop()
cts.rpx_mtx.Lock()
delete(cts.rpx_map, crpx.id)
cts.rpx_mtx.Unlock()
}
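The websocket branch of RpxLoop above is an instance of a common two-goroutine relay: one goroutine drains the request pipe into the backend connection while the calling loop forwards backend reads as packets, and a done channel joins the two before cleanup. Reduced to its bare shape (illustrative names, not the commit's code):

func relay(conn net.Conn, pr *io.PipeReader, send func([]byte) error) {
    done := make(chan struct{}, 1)
    go func() {
        io.Copy(conn, pr) // request side: pipe -> backend
        done <- struct{}{}
    }()
    var buf [4096]byte
    for {
        n, err := conn.Read(buf[:])
        if n > 0 && send(buf[:n]) != nil { break } // response side: backend -> packets
        if err != nil { break }
    }
    <-done // join the pipe-draining goroutine, as RpxLoop does with done_chan
}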
func (cts *ClientConn) StartRpx(id uint64, data []byte, wg *sync.WaitGroup) error {
var crpx *ClientRpx
var ok bool
cts.rpx_mtx.Lock()
_, ok = cts.rpx_map[id]
if ok {
cts.rpx_mtx.Unlock()
return fmt.Errorf("multiple start on rpx id %d", id)
}
crpx = &ClientRpx{ id: id }
cts.rpx_map[id] = crpx
// i want the pipe to be created before the goroutine is started
// so that WriteRpx() can write to the pipe. i protect pipe creation
// and context creation with a mutex.
crpx.pr, crpx.pw = io.Pipe()
crpx.ctx, crpx.cancel = context.WithCancel(cts.C.Ctx)
cts.rpx_mtx.Unlock()
wg.Add(1)
go cts.RpxLoop(crpx, data, wg)
return nil
}
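StartRpx wires WriteRpx and EofRpx to RpxLoop through an io.Pipe, whose synchronous semantics matter here: every pw.Write blocks until the reader consumes the bytes (which is what the TODO in WriteRpx below is about), and closing the write end delivers io.EOF to the reader. A self-contained sketch:

pr, pw := io.Pipe()
go func() {
    pw.Write([]byte("request body")) // blocks until pr.Read below consumes it
    pw.Close()                       // the next pr.Read returns io.EOF
}()
buf := make([]byte, 64)
for {
    n, err := pr.Read(buf)
    if n > 0 { fmt.Printf("%s", buf[:n]) }
    if err != nil { break } // io.EOF after pw.Close()
}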
func (cts *ClientConn) StopRpx(id uint64) error {
var crpx *ClientRpx
crpx = cts.FindClientRpxById(id)
if crpx == nil {
return fmt.Errorf("unknown rpx id %d", id)
}
crpx.ReqStop()
return nil
}
func (cts *ClientConn) WriteRpx(id uint64, data []byte) error {
var crpx *ClientRpx
var err error
crpx = cts.FindClientRpxById(id)
if crpx == nil {
return fmt.Errorf("unknown rpx id %d", id)
}
// TODO: may have to write it in a goroutine to avoid blocking?
_, err = crpx.pw.Write(data)
if err != nil {
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to write rpx(%d) data - %s", id, err.Error())
return err
}
return nil
}
func (cts *ClientConn) EofRpx(id uint64, data []byte) error {
var crpx *ClientRpx
crpx = cts.FindClientRpxById(id)
if crpx == nil {
return fmt.Errorf("unknown rpx id %d", id)
}
// close the writing end only. leave the reading end untouched
crpx.pw.Close()
return nil
}
func (cts *ClientConn) HandleRpxEvent(packet_type PACKET_KIND, evt *RpxEvent) error {
switch packet_type {
case PACKET_KIND_RPX_START:
return cts.StartRpx(evt.Id, evt.Data, &cts.C.wg)
case PACKET_KIND_RPX_STOP:
return cts.StopRpx(evt.Id)
case PACKET_KIND_RPX_DATA:
return cts.WriteRpx(evt.Id, evt.Data)
case PACKET_KIND_RPX_EOF:
return cts.EofRpx(evt.Id, evt.Data)
}
// ignore other packet types
return nil
}
// --------------------------------------------------------------------
func (m ClientPeerConnMap) get_sorted_keys() []PeerId {
@@ -1631,12 +2026,13 @@ func (m ClientConnMap) get_sorted_keys() []ConnId {
type client_ctl_log_writer struct {
cli *Client
depth int
}
func (hlw *client_ctl_log_writer) Write(p []byte) (n int, err error) {
// the standard http.Server always requires *log.Logger
// use this iowriter to create a logger to pass it to the http server.
-hlw.cli.log.Write("", LOG_INFO, string(p))
+hlw.cli.log.WriteWithCallDepth("", LOG_INFO, hlw.depth, string(p))
return len(p), nil
}
@@ -1696,13 +2092,13 @@ func (c *Client) WrapHttpHandler(handler ClientHttpHandler) http.Handler {
}
// TODO: statistics by status_code and end point types.
-time_taken = time.Now().Sub(start_time)
+time_taken = time.Since(start_time) //time.Now().Sub(start_time)
if status_code > 0 {
if err != nil {
-c.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s %d %.9f - %s", req.RemoteAddr, req.Method, get_raw_url_path(req), status_code, time_taken.Seconds(), err.Error())
+c.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s %d %.9f - %s", req.RemoteAddr, req.Method, req.RequestURI, status_code, time_taken.Seconds(), err.Error())
} else {
-c.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s %d %.9f", req.RemoteAddr, req.Method, get_raw_url_path(req), status_code, time_taken.Seconds())
+c.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s %d %.9f", req.RemoteAddr, req.Method, req.RequestURI, status_code, time_taken.Seconds())
}
}
})
@@ -1714,7 +2110,7 @@ func (c *Client) SafeWrapWebsocketHandler(handler websocket.Handler) http.Handle
!strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade") {
var status_code int
status_code = WriteEmptyRespHeader(w, http.StatusBadRequest)
c.log.Write("", LOG_INFO, "[%s] %s %s %d[non-websocket]", req.RemoteAddr, req.Method, get_raw_url_path(req), status_code)
c.log.Write("", LOG_INFO, "[%s] %s %s %d[non-websocket]", req.RemoteAddr, req.Method, req.RequestURI, status_code)
return
}
handler.ServeHTTP(w, req)
@@ -1728,21 +2124,19 @@ func (c *Client) WrapWebsocketHandler(handler ClientWebsocketHandler) websocket.
var start_time time.Time
var time_taken time.Duration
var req *http.Request
-var raw_url_path string
req = ws.Request()
-raw_url_path = get_raw_url_path(req)
-c.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s [ws]", req.RemoteAddr, req.Method, raw_url_path)
+c.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s [ws]", req.RemoteAddr, req.Method, req.RequestURI)
start_time = time.Now()
status_code, err = handler.ServeWebsocket(ws)
-time_taken = time.Now().Sub(start_time)
+time_taken = time.Since(start_time) // time.Now().Sub(start_time)
if status_code > 0 {
if err != nil {
-c.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s [ws] %d %.9f - %s", req.RemoteAddr, req.Method, raw_url_path, status_code, time_taken.Seconds(), err.Error())
+c.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s [ws] %d %.9f - %s", req.RemoteAddr, req.Method, req.RequestURI, status_code, time_taken.Seconds(), err.Error())
} else {
-c.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s [ws] %d %.9f", req.RemoteAddr, req.Method, raw_url_path, status_code, time_taken.Seconds())
+c.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s [ws] %d %.9f", req.RemoteAddr, req.Method, req.RequestURI, status_code, time_taken.Seconds())
}
}
})
@@ -1770,6 +2164,14 @@ func NewClient(ctx context.Context, name string, logger Logger, cfg *ClientConfi
c.bulletin = NewBulletin[*ClientEvent](&c, 1024)
c.rpc_tls = cfg.RpcTls
c.rpx_target_addr = cfg.RpxTargetAddr
c.rpx_target_tls = cfg.RpxTargetTls
if c.rpx_target_tls != nil {
c.rpx_target_url = "https://" + c.rpx_target_addr
} else {
c.rpx_target_url = "http://" + c.rpx_target_addr
}
c.ctl_auth = cfg.CtlAuth
c.ctl_tls = cfg.CtlTls
c.ctl_prefix = cfg.CtlPrefix
@@ -1843,13 +2245,13 @@ func NewClient(ctx context.Context, name string, logger Logger, cfg *ClientConfi
c.ctl = make([]*http.Server, len(cfg.CtlAddrs))
copy(c.ctl_addr, cfg.CtlAddrs)
hs_log = log.New(&client_ctl_log_writer{cli: &c}, "", 0)
hs_log = log.New(&client_ctl_log_writer{cli: &c, depth: 0}, "", 0)
for i = 0; i < len(cfg.CtlAddrs); i++ {
c.ctl[i] = &http.Server{
Addr: cfg.CtlAddrs[i],
Handler: c.ctl_mux,
-TLSConfig: c.ctl_tls,
+TLSConfig: c.ctl_tls.Clone(),
ErrorLog: hs_log,
// TODO: more settings
}
@@ -1859,6 +2261,7 @@ func NewClient(ctx context.Context, name string, logger Logger, cfg *ClientConfi
c.stats.routes.Store(0)
c.stats.peers.Store(0)
c.stats.pty_sessions.Store(0)
c.stats.rpty_sessions.Store(0)
return &c
}
@@ -2171,8 +2574,52 @@ func (c *Client) GetPtyShell() string {
return c.pty_shell
}
func (c *Client) run_single_ctl_server(i int, cs *http.Server, wg *sync.WaitGroup) {
var l net.Listener
var err error
defer wg.Done()
c.log.Write("", LOG_INFO, "Control channel[%d] started on %s", i, c.ctl_addr[i])
// defeat hard-coded "tcp" in ListenAndServe() and ListenAndServeTLS()
// by creating the listener explicitly.
// err = cs.ListenAndServe()
// err = cs.ListenAndServeTLS("", "") // c.tlscfg must provide a certificate and a key
// cs.shuttingDown(), as the name indicates, is not exposed by net/http,
// so I have to use my own indicator to check if it's been shut down.
if c.stop_req.Load() == false {
// this guard has a flaw in that the stop request can be made
// between the check above and net.Listen() below.
l, err = net.Listen(TcpAddrStrClass(cs.Addr), cs.Addr)
if err == nil {
if c.stop_req.Load() == false {
// check it again to make the guard slightly more stable
// although it's still possible that the stop request is made
// after Listen()
if c.ctl_tls == nil {
err = cs.Serve(l)
} else {
err = cs.ServeTLS(l, "", "") // c.ctl_tls must provide a certificate and a key
}
} else {
err = fmt.Errorf("stop requested")
}
l.Close()
}
} else {
err = fmt.Errorf("stop requested")
}
if errors.Is(err, http.ErrServerClosed) {
c.log.Write("", LOG_INFO, "Control channel[%d] ended", i)
} else {
c.log.Write("", LOG_ERROR, "Control channel[%d] error - %s", i, err.Error())
}
}
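The comments above concede that the stop_req check races with net.Listen(). One conceivable way to close that window, sketched here with hypothetical names rather than as part of this commit, is to guard the flag and the listeners under one mutex so a stop request can also close any listener it raced with:

type listenGuard struct {
    mtx     sync.Mutex
    stopped bool
    ls      []net.Listener
}

func (g *listenGuard) listen(netw, addr string) (net.Listener, error) {
    g.mtx.Lock()
    defer g.mtx.Unlock()
    if g.stopped { return nil, fmt.Errorf("stop requested") }
    l, err := net.Listen(netw, addr)
    if err == nil { g.ls = append(g.ls, l) } // remembered for stop()
    return l, err
}

func (g *listenGuard) stop() {
    g.mtx.Lock()
    defer g.mtx.Unlock()
    g.stopped = true
    for _, l := range g.ls { l.Close() } // unblocks Serve()/ServeTLS()
}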
func (c *Client) RunCtlTask(wg *sync.WaitGroup) {
var ctl *http.Server
var idx int
var l_wg sync.WaitGroup
@@ -2181,49 +2628,7 @@ func (c *Client) RunCtlTask(wg *sync.WaitGroup) {
for idx, ctl = range c.ctl {
l_wg.Add(1)
-go func(i int, cs *http.Server) {
-var l net.Listener
-c.log.Write("", LOG_INFO, "Control channel[%d] started on %s", i, c.ctl_addr[i])
-// defeat hard-coded "tcp" in ListenAndServe() and ListenAndServeTLS()
-// by creating the listener explicitly.
-// err = cs.ListenAndServe()
-// err = cs.ListenAndServeTLS("", "") // c.tlscfg must provide a certificate and a key
-//cs.shuttingDown(), as the name indicates, is not expoosed by the net/http
-//so I have to use my own indicator to check if it's been shutdown..
-//
-if c.stop_req.Load() == false {
-// this guard has a flaw in that the stop request can be made
-// between the check above and net.Listen() below.
-l, err = net.Listen(TcpAddrStrClass(cs.Addr), cs.Addr)
-if err == nil {
-if c.stop_req.Load() == false {
-// check it again to make the guard slightly more stable
-// although it's still possible that the stop request is made
-// after Listen()
-if c.ctl_tls == nil {
-err = cs.Serve(l)
-} else {
-err = cs.ServeTLS(l, "", "") // c.ctl_tls must provide a certificate and a key
-}
-} else {
-err = fmt.Errorf("stop requested")
-}
-l.Close()
-}
-} else {
-err = fmt.Errorf("stop requested")
-}
-if errors.Is(err, http.ErrServerClosed) {
-c.log.Write("", LOG_INFO, "Control channel[%d] ended", i)
-} else {
-c.log.Write("", LOG_ERROR, "Control channel[%d] error - %s", i, err.Error())
-}
-l_wg.Done()
-}(idx, ctl)
+go c.run_single_ctl_server(idx, ctl, &l_wg)
}
l_wg.Wait()
}