added code for rpx handling
This commit is contained in:
615
client.go
615
client.go
@@ -1,5 +1,7 @@
|
||||
package hodu
|
||||
|
||||
import "bufio"
|
||||
import "bytes"
|
||||
import "container/list"
|
||||
import "context"
|
||||
import "crypto/tls"
|
||||
@@ -40,6 +42,7 @@ type ClientRouteMap map[RouteId]*ClientRoute
|
||||
type ClientPeerConnMap map[PeerId]*ClientPeerConn
|
||||
type ClientPeerCancelFuncMap map[PeerId]context.CancelFunc
|
||||
type ClientRptyMap map[uint64]*ClientRpty
|
||||
type ClientRpxMap map[uint64]*ClientRpx
|
||||
|
||||
// --------------------------------------------------------------------
|
||||
type ClientRouteConfig struct {
|
||||
@@ -84,6 +87,10 @@ type ClientConfig struct {
|
||||
PeerConnTmout time.Duration
|
||||
|
||||
Token string // to send to the server for identification
|
||||
|
||||
// default target for rpx
|
||||
RpxTargetAddr string
|
||||
RpxTargetTls *tls.Config
|
||||
}
|
||||
|
||||
type ClientEventKind int
|
||||
@@ -122,6 +129,10 @@ type Client struct {
|
||||
ext_svcs []Service
|
||||
|
||||
rpc_tls *tls.Config
|
||||
rpx_target_addr string
|
||||
rpx_target_url string
|
||||
rpx_target_tls *tls.Config
|
||||
|
||||
ctl_tls *tls.Config
|
||||
ctl_addr []string
|
||||
ctl_prefix string
|
||||
@@ -155,6 +166,7 @@ type Client struct {
|
||||
routes atomic.Int64
|
||||
peers atomic.Int64
|
||||
pty_sessions atomic.Int64
|
||||
rpty_sessions atomic.Int64
|
||||
}
|
||||
|
||||
pty_user string
|
||||
@@ -176,6 +188,15 @@ type ClientRpty struct {
|
||||
tty *os.File
|
||||
}
|
||||
|
||||
type ClientRpx struct {
|
||||
id uint64
|
||||
pr *io.PipeReader
|
||||
pw *io.PipeWriter
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
ws_conn net.Conn
|
||||
}
|
||||
|
||||
// client connection to server
|
||||
type ClientConn struct {
|
||||
C *Client
|
||||
@@ -209,6 +230,9 @@ type ClientConn struct {
|
||||
rpty_mtx sync.Mutex
|
||||
rpty_map ClientRptyMap
|
||||
|
||||
rpx_mtx sync.Mutex
|
||||
rpx_map ClientRpxMap
|
||||
|
||||
stop_req atomic.Bool
|
||||
stop_chan chan bool
|
||||
|
||||
@@ -294,6 +318,22 @@ func (g *GuardedPacketStreamClient) Context() context.Context {
|
||||
return g.psc.Context()
|
||||
}*/
|
||||
|
||||
// --------------------------------------------------------------------
|
||||
|
||||
func (rpty *ClientRpty) ReqStop() {
|
||||
rpty.tty.Close()
|
||||
rpty.cmd.Process.Kill()
|
||||
}
|
||||
|
||||
func (rpx *ClientRpx) ReqStop() {
|
||||
rpx.pr.Close()
|
||||
rpx.pw.Close()
|
||||
rpx.cancel()
|
||||
if rpx.ws_conn != nil {
|
||||
rpx.ws_conn.SetDeadline(time.Now()) // to make Read return immediately
|
||||
}
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------
|
||||
func NewClientRoute(cts *ClientConn, id RouteId, static bool, client_peer_addr string, client_peer_name string, server_peer_svc_addr string, server_peer_svc_net string, server_peer_option RouteOption, lifetime time.Duration) *ClientRoute {
|
||||
var r ClientRoute
|
||||
@@ -397,7 +437,7 @@ func (r *ClientRoute) ExtendLifetime(lifetime time.Duration) error {
|
||||
r.lifetime_timer.Stop()
|
||||
r.Lifetime = r.Lifetime + lifetime
|
||||
expiry = r.LifetimeStart.Add(r.Lifetime)
|
||||
r.lifetime_timer.Reset(expiry.Sub(time.Now()))
|
||||
r.lifetime_timer.Reset(time.Until(expiry)) // expiry.Sub(time.Now())
|
||||
if r.cts.C.route_persister != nil { r.cts.C.route_persister.Save(r.cts, r) }
|
||||
r.lifetime_mtx.Unlock()
|
||||
|
||||
@@ -477,10 +517,8 @@ func (r *ClientRoute) RunTask(wg *sync.WaitGroup) {
|
||||
break waiting_loop
|
||||
}
|
||||
} else {
|
||||
select {
|
||||
case <-r.stop_chan:
|
||||
break waiting_loop
|
||||
}
|
||||
<-r.stop_chan
|
||||
break waiting_loop
|
||||
}
|
||||
}
|
||||
|
||||
@@ -547,7 +585,7 @@ func (r *ClientRoute) ConnectToPeer(pts_id PeerId, route_option RouteOption, pts
|
||||
defer wg.Done()
|
||||
|
||||
tmout = time.Duration(r.cts.C.ptc_tmout)
|
||||
if tmout <= 0 { tmout = 5 * time.Second} // TODO: make this configurable...
|
||||
if tmout <= 0 { tmout = 5 * time.Second } // TODO: make this configurable...
|
||||
waitctx, cancel_wait = context.WithTimeout(r.cts.C.Ctx, tmout)
|
||||
r.ptc_mtx.Lock()
|
||||
r.ptc_cancel_map[pts_id] = cancel_wait
|
||||
@@ -571,8 +609,8 @@ func (r *ClientRoute) ConnectToPeer(pts_id PeerId, route_option RouteOption, pts
|
||||
real_conn, ok = conn.(*net.TCPConn)
|
||||
if !ok {
|
||||
r.cts.C.log.Write(r.cts.Sid, LOG_ERROR,
|
||||
"Failed to get connection information to %s for route(%d,%d,%s,%s) - %s",
|
||||
r.PeerAddr, r.Id, pts_id, pts_raddr, pts_laddr, err.Error())
|
||||
"Failed to get connection information to %s for route(%d,%d,%s,%s)",
|
||||
r.PeerAddr, r.Id, pts_id, pts_raddr, pts_laddr)
|
||||
goto peer_aborted
|
||||
}
|
||||
|
||||
@@ -825,6 +863,7 @@ func NewClientConn(c *Client, cfg *ClientConnConfig) *ClientConn {
|
||||
cts.stop_chan = make(chan bool, 8)
|
||||
cts.ptc_list = list.New()
|
||||
cts.rpty_map = make(ClientRptyMap)
|
||||
cts.rpx_map = make(ClientRpxMap)
|
||||
|
||||
for i, _ = range cts.cfg.Routes {
|
||||
// override it to static regardless of the value passed in
|
||||
@@ -833,7 +872,6 @@ func NewClientConn(c *Client, cfg *ClientConnConfig) *ClientConn {
|
||||
|
||||
// the actual connection to the server is established in the main task function
|
||||
// The cts.conn, cts.hdc, cts.psc fields are left unassigned here.
|
||||
|
||||
return &cts
|
||||
}
|
||||
|
||||
@@ -1031,7 +1069,8 @@ func (cts *ClientConn) add_client_routes(routes []ClientRouteConfig) error {
|
||||
func (cts *ClientConn) disconnect_from_server(logmsg bool) {
|
||||
if cts.conn != nil {
|
||||
var r *ClientRoute
|
||||
var crp *ClientRpty
|
||||
var rpty *ClientRpty
|
||||
var rpx *ClientRpx
|
||||
|
||||
cts.discon_mtx.Lock()
|
||||
|
||||
@@ -1045,15 +1084,20 @@ func (cts *ClientConn) disconnect_from_server(logmsg bool) {
|
||||
|
||||
// arrange to clean up all rpty objects
|
||||
cts.rpty_mtx.Lock()
|
||||
for _, crp = range cts.rpty_map {
|
||||
crp.tty.Close()
|
||||
crp.cmd.Process.Kill()
|
||||
for _, rpty = range cts.rpty_map {
|
||||
rpty.ReqStop()
|
||||
// the loop in ReadRptyLoop() is supposed to be broken.
|
||||
// let's not inform the server of this connection.
|
||||
// the server should clean up itself upon connection error
|
||||
}
|
||||
cts.rpty_mtx.Unlock()
|
||||
|
||||
cts.rpx_mtx.Lock()
|
||||
for _, rpx = range cts.rpx_map {
|
||||
rpx.ReqStop()
|
||||
}
|
||||
cts.rpx_mtx.Unlock()
|
||||
|
||||
// don't care about double closes when this function is called from both RunTask() and ReqStop()
|
||||
cts.conn.Close()
|
||||
|
||||
@@ -1217,7 +1261,7 @@ start_over:
|
||||
|
||||
pkt, err = psc.Recv()
|
||||
if err != nil {
|
||||
if status.Code(err) == codes.Canceled || errors.Is(err, net.ErrClosed) {
|
||||
if status.Code(err) == codes.Canceled || errors.Is(err, net.ErrClosed) || errors.Is(err, io.EOF) {
|
||||
goto reconnect_to_server
|
||||
} else {
|
||||
cts.C.log.Write(cts.Sid, LOG_INFO, "Failed to receive packet from %s - %s", cts.remote_addr_p, err.Error())
|
||||
@@ -1348,6 +1392,31 @@ start_over:
|
||||
cts.C.log.Write(cts.Sid, LOG_ERROR, "Invalid %s event from %s", pkt.Kind.String(), cts.remote_addr_p)
|
||||
}
|
||||
|
||||
case PACKET_KIND_RPX_START:
|
||||
fallthrough
|
||||
case PACKET_KIND_RPX_STOP:
|
||||
fallthrough
|
||||
case PACKET_KIND_RPX_DATA:
|
||||
fallthrough
|
||||
case PACKET_KIND_RPX_EOF:
|
||||
var x *Packet_RpxEvt
|
||||
var ok bool
|
||||
x, ok = pkt.U.(*Packet_RpxEvt)
|
||||
if ok {
|
||||
err = cts.HandleRpxEvent(pkt.Kind, x.RpxEvt)
|
||||
if err != nil {
|
||||
cts.C.log.Write(cts.Sid, LOG_ERROR,
|
||||
"Failed to handle %s event for rpx(%d) from %s - %s",
|
||||
pkt.Kind.String(), x.RpxEvt.Id, cts.remote_addr_p, err.Error())
|
||||
} else {
|
||||
cts.C.log.Write(cts.Sid, LOG_DEBUG,
|
||||
"Handled %s event for rpx(%d) from %s",
|
||||
pkt.Kind.String(), x.RpxEvt.Id, cts.remote_addr_p)
|
||||
}
|
||||
} else {
|
||||
cts.C.log.Write(cts.Sid, LOG_ERROR, "Invalid %s event from %s", pkt.Kind.String(), cts.remote_addr_p)
|
||||
}
|
||||
|
||||
default:
|
||||
// do nothing. ignore the rest
|
||||
}
|
||||
@@ -1390,7 +1459,7 @@ reconnect_to_server:
|
||||
goto wait_for_termination
|
||||
case <-slpctx.Done():
|
||||
select {
|
||||
case <- cts.C.Ctx.Done():
|
||||
case <-cts.C.Ctx.Done():
|
||||
// non-blocking check if the parent context of the sleep context is
|
||||
// terminated too. if so, this is normal termination case.
|
||||
// this check seem redundant but the go-runtime doesn't seem to guarantee
|
||||
@@ -1421,20 +1490,34 @@ func (cts *ClientConn) ReportPacket(route_id RouteId, pts_id PeerId, packet_type
|
||||
}
|
||||
|
||||
|
||||
// rpty
|
||||
func (cts *ClientConn) FindClientRptyById(id uint64) *ClientRpty {
|
||||
var crp *ClientRpty
|
||||
var ok bool
|
||||
|
||||
cts.rpty_mtx.Lock()
|
||||
crp, ok = cts.rpty_map[id]
|
||||
cts.rpty_mtx.Unlock()
|
||||
|
||||
if !ok { crp = nil }
|
||||
return crp
|
||||
}
|
||||
|
||||
func (cts *ClientConn) ReadRptyLoop(crp *ClientRpty, wg *sync.WaitGroup) {
|
||||
|
||||
var poll_fds []unix.PollFd
|
||||
var buf []byte
|
||||
var buf [2048]byte
|
||||
var n int
|
||||
var err error
|
||||
|
||||
defer wg.Done()
|
||||
|
||||
cts.C.stats.rpty_sessions.Add(1)
|
||||
|
||||
poll_fds = []unix.PollFd{
|
||||
unix.PollFd{Fd: int32(crp.tty.Fd()), Events: unix.POLLIN},
|
||||
}
|
||||
|
||||
buf = make([]byte, 2048)
|
||||
for {
|
||||
n, err = unix.Poll(poll_fds, -1) // -1 means wait indefinitely
|
||||
if err != nil {
|
||||
@@ -1452,34 +1535,35 @@ func (cts *ClientConn) ReadRptyLoop(crp *ClientRpty, wg *sync.WaitGroup) {
|
||||
}
|
||||
|
||||
if (poll_fds[0].Revents & unix.POLLIN) != 0 {
|
||||
n, err = crp.tty.Read(buf)
|
||||
n, err = crp.tty.Read(buf[:])
|
||||
if n > 0 {
|
||||
var err2 error
|
||||
err2 = cts.psc.Send(MakeRptyDataPacket(crp.id, buf[:n]))
|
||||
if err2 != nil {
|
||||
cts.C.log.Write(cts.Sid, LOG_DEBUG, "Failed to send rpty(%d) stdout to server - %s", crp.id, err2.Error())
|
||||
break
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
if !errors.Is(err, io.EOF) {
|
||||
cts.C.log.Write(cts.Sid, LOG_DEBUG, "Failed to read rpty(%d) stdout - %s", crp.id, err.Error())
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
if n > 0 {
|
||||
err = cts.psc.Send(MakeRptyDataPacket(crp.id, buf[:n]))
|
||||
if err != nil {
|
||||
cts.C.log.Write(cts.Sid, LOG_DEBUG, "Failed to send rpty(%d) stdout to server - %s", crp.id, err.Error())
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cts.psc.Send(MakeRptyStopPacket(crp.id, ""))
|
||||
cts.C.log.Write(cts.Sid, LOG_INFO, "Ending rpty(%d) read loop", crp.id)
|
||||
|
||||
crp.tty.Close() // don't care about multiple closes
|
||||
crp.cmd.Process.Kill()
|
||||
crp.ReqStop()
|
||||
crp.cmd.Wait()
|
||||
|
||||
cts.rpty_mtx.Lock()
|
||||
delete(cts.rpty_map, crp.id)
|
||||
cts.rpty_mtx.Unlock()
|
||||
|
||||
cts.C.stats.rpty_sessions.Add(-1)
|
||||
}
|
||||
|
||||
func (cts *ClientConn) StartRpty(id uint64, wg *sync.WaitGroup) error {
|
||||
@@ -1515,46 +1599,34 @@ func (cts *ClientConn) StartRpty(id uint64, wg *sync.WaitGroup) error {
|
||||
|
||||
func (cts *ClientConn) StopRpty(id uint64) error {
|
||||
var crp *ClientRpty
|
||||
var ok bool
|
||||
|
||||
cts.rpty_mtx.Lock()
|
||||
crp, ok = cts.rpty_map[id]
|
||||
if !ok {
|
||||
cts.rpty_mtx.Unlock()
|
||||
crp = cts.FindClientRptyById(id)
|
||||
if crp == nil {
|
||||
return fmt.Errorf("unknown rpty id %d", id)
|
||||
}
|
||||
|
||||
crp.tty.Close() // to break ReadRptyLoop()
|
||||
crp.cmd.Process.Kill() // to process wait to be done by ReadRptyLoop()
|
||||
cts.rpty_mtx.Unlock()
|
||||
crp.ReqStop()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cts *ClientConn) WriteRpty(id uint64, data []byte) error {
|
||||
var crp *ClientRpty
|
||||
var ok bool
|
||||
|
||||
cts.rpty_mtx.Lock()
|
||||
crp, ok = cts.rpty_map[id]
|
||||
if !ok {
|
||||
cts.rpty_mtx.Unlock()
|
||||
crp = cts.FindClientRptyById(id)
|
||||
if crp == nil {
|
||||
return fmt.Errorf("unknown rpty id %d", id)
|
||||
}
|
||||
|
||||
crp.tty.Write(data)
|
||||
cts.rpty_mtx.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cts *ClientConn) WriteRptySize(id uint64, data []byte) error {
|
||||
var crp *ClientRpty
|
||||
var ok bool
|
||||
var flds []string
|
||||
|
||||
cts.rpty_mtx.Lock()
|
||||
crp, ok = cts.rpty_map[id]
|
||||
if !ok {
|
||||
cts.rpty_mtx.Unlock()
|
||||
crp = cts.FindClientRptyById(id)
|
||||
if crp == nil {
|
||||
return fmt.Errorf("unknown rpty id %d", id)
|
||||
}
|
||||
|
||||
@@ -1566,11 +1638,11 @@ func (cts *ClientConn) WriteRptySize(id uint64, data []byte) error {
|
||||
cols, _ = strconv.Atoi(flds[1])
|
||||
pts.Setsize(crp.tty, &pts.Winsize{Rows: uint16(rows), Cols: uint16(cols)})
|
||||
}
|
||||
cts.rpty_mtx.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cts *ClientConn) HandleRptyEvent(packet_type PACKET_KIND, evt *RptyEvent) error {
|
||||
|
||||
switch packet_type {
|
||||
case PACKET_KIND_RPTY_START:
|
||||
return cts.StartRpty(evt.Id, &cts.C.wg)
|
||||
@@ -1589,6 +1661,329 @@ func (cts *ClientConn) HandleRptyEvent(packet_type PACKET_KIND, evt *RptyEvent)
|
||||
return nil
|
||||
}
|
||||
|
||||
// rpx
|
||||
func (cts *ClientConn) FindClientRpxById(id uint64) *ClientRpx {
|
||||
var crpx *ClientRpx
|
||||
var ok bool
|
||||
|
||||
cts.rpx_mtx.Lock()
|
||||
crpx, ok = cts.rpx_map[id]
|
||||
cts.rpx_mtx.Unlock()
|
||||
|
||||
if !ok { crpx = nil }
|
||||
return crpx
|
||||
}
|
||||
|
||||
func (cts *ClientConn) RpxLoop(crpx *ClientRpx, data []byte, wg *sync.WaitGroup) {
|
||||
var sc *bufio.Scanner
|
||||
var line string
|
||||
var flds []string
|
||||
var buf [4096]byte
|
||||
var req_meth string
|
||||
var req_path string
|
||||
//var req_proto string
|
||||
var req *http.Request
|
||||
var n int
|
||||
var err error
|
||||
|
||||
defer wg.Done()
|
||||
|
||||
sc = bufio.NewScanner(bytes.NewReader(data))
|
||||
sc.Scan()
|
||||
line = sc.Text()
|
||||
|
||||
flds = strings.Fields(line)
|
||||
if (len(flds) < 3) {
|
||||
cts.C.log.Write(cts.Sid, LOG_ERROR, "Invalid request line for rpx(%d) - %s", crpx.id, line)
|
||||
goto done
|
||||
}
|
||||
|
||||
// TODO: handle trailers...
|
||||
req_meth = flds[0]
|
||||
req_path = flds[1]
|
||||
//req_proto = flds[2]
|
||||
|
||||
// create a request assuming it's a normal http request
|
||||
|
||||
req, err = http.NewRequestWithContext(crpx.ctx, req_meth, cts.C.rpx_target_url + req_path, crpx.pr)
|
||||
if err != nil {
|
||||
cts.C.log.Write(cts.Sid, LOG_ERROR, "failed to create request for rpx(%d) - %s", crpx.id, err.Error())
|
||||
goto done
|
||||
}
|
||||
|
||||
for sc.Scan() {
|
||||
line = sc.Text()
|
||||
if line == "" { break }
|
||||
flds = strings.SplitN(line, ":", 2)
|
||||
if len(flds) == 2 {
|
||||
req.Header.Add(strings.TrimSpace(flds[0]), strings.TrimSpace(flds[1]))
|
||||
}
|
||||
}
|
||||
err = sc.Err()
|
||||
if err != nil {
|
||||
cts.C.log.Write(cts.Sid, LOG_ERROR, "failed to parse request for rpx(%d) - %s", crpx.id, err.Error())
|
||||
goto done
|
||||
}
|
||||
|
||||
if strings.EqualFold(req.Header.Get("Upgrade"), "websocket") && strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade") {
|
||||
// websocket
|
||||
var done_chan chan struct{}
|
||||
|
||||
var conn net.Conn
|
||||
var resp *http.Response
|
||||
var r *bufio.Reader
|
||||
|
||||
if cts.C.rpx_target_tls != nil {
|
||||
var dialer *tls.Dialer
|
||||
dialer = &tls.Dialer{
|
||||
NetDialer: &net.Dialer{},
|
||||
Config: cts.C.rpx_target_tls,
|
||||
}
|
||||
conn, err = dialer.DialContext(crpx.ctx, "tcp", cts.C.rpx_target_addr) // TODO: no hard coding
|
||||
} else {
|
||||
var dialer *net.Dialer
|
||||
dialer = &net.Dialer{}
|
||||
conn, err = dialer.DialContext(crpx.ctx, "tcp", cts.C.rpx_target_addr) // TODO: no hard coding
|
||||
}
|
||||
if err != nil {
|
||||
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to dial websocket for rpx(%d) - %s", crpx.id, err.Error())
|
||||
goto done
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// TODO: make this atomic?
|
||||
crpx.ws_conn = conn
|
||||
|
||||
// write the raw request line and headers as sent by the server.
|
||||
// for the upgrade request, i assume no payload.
|
||||
_, err = conn.Write(data)
|
||||
if err != nil {
|
||||
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to write websocket request for rpx(%d) - %s", crpx.id, err.Error())
|
||||
goto done
|
||||
}
|
||||
|
||||
r = bufio.NewReader(conn)
|
||||
resp, err = http.ReadResponse(r, req)
|
||||
if err != nil {
|
||||
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to write websocket response for rpx(%d) - %s", crpx.id, err.Error())
|
||||
goto done
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
err = cts.psc.Send(MakeRpxStartPacket(crpx.id, get_http_resp_line_and_headers(resp)))
|
||||
if err != nil {
|
||||
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to send rpx(%d) WebSocket headers to server - %s", crpx.id, err.Error())
|
||||
goto done
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusSwitchingProtocols {
|
||||
// websock upgrade failed. let the code jump to the done
|
||||
// label to skip reading from the pipe. the server side
|
||||
// has the code to ensure no content-length. and the upgrade
|
||||
// fails, the pipe below will be pending forever as the server
|
||||
// side doesn't send data and there's no feeding to the pipe.
|
||||
cts.C.log.Write(cts.Sid, LOG_INFO, "Protocol switching failed for rpx(%d)", crpx.id)
|
||||
goto done
|
||||
}
|
||||
|
||||
// unlike with the normal request, the actual pipe is not read
|
||||
// until the initial switching protocol response is received.
|
||||
|
||||
wg.Add(1)
|
||||
done_chan = make(chan struct{}, 5)
|
||||
go func() {
|
||||
var buf [4096]byte
|
||||
var n int
|
||||
var err error
|
||||
|
||||
defer wg.Done()
|
||||
for {
|
||||
n, err = crpx.pr.Read(buf[:])
|
||||
if n > 0 {
|
||||
var err2 error
|
||||
_, err2 = conn.Write(buf[:n])
|
||||
if err2 != nil {
|
||||
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to write websocket for rpx(%d) - %s", crpx.id, err2.Error())
|
||||
break
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) { break }
|
||||
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to read pipe for rpx(%d) - %s", crpx.id, err.Error())
|
||||
break
|
||||
}
|
||||
}
|
||||
done_chan <- struct{}{}
|
||||
}()
|
||||
|
||||
for {
|
||||
n, err = conn.Read(buf[:])
|
||||
if n > 0 {
|
||||
var err2 error
|
||||
err2 = cts.psc.Send(MakeRpxDataPacket(crpx.id, buf[:n]))
|
||||
if err2 != nil {
|
||||
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to send rpx(%d) data to server - %s", crpx.id, err2.Error())
|
||||
break
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
cts.psc.Send(MakeRpxEofPacket(crpx.id))
|
||||
cts.C.log.Write(cts.Sid, LOG_DEBUG, "WebSocket rpx(%d) closed by server", crpx.id)
|
||||
break
|
||||
}
|
||||
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to read WebSocket rpx(%d) - %s", crpx.id, err.Error())
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// wait until the pipe reading(from the server side) goroutine is over
|
||||
<-done_chan
|
||||
} else {
|
||||
var tr *http.Transport
|
||||
var resp *http.Response
|
||||
|
||||
tr = &http.Transport {
|
||||
TLSClientConfig: cts.C.rpx_target_tls,
|
||||
}
|
||||
|
||||
resp, err = tr.RoundTrip(req)
|
||||
if err != nil {
|
||||
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to send rpx(%d) request - %s", crpx.id, err.Error())
|
||||
goto done
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
|
||||
err = cts.psc.Send(MakeRpxStartPacket(crpx.id, get_http_resp_line_and_headers(resp)))
|
||||
if err != nil {
|
||||
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to send rpx(%d) status and headers to server - %s", crpx.id, err.Error())
|
||||
goto done
|
||||
}
|
||||
|
||||
for {
|
||||
n, err = resp.Body.Read(buf[:])
|
||||
if n > 0 {
|
||||
var err2 error
|
||||
err2 = cts.psc.Send(MakeRpxDataPacket(crpx.id, buf[:n]))
|
||||
if err2 != nil {
|
||||
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to send rpx(%d) data to server - %s", crpx.id, err2.Error())
|
||||
break
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
}
|
||||
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to read response body for rpx(%d) - %s", crpx.id, err.Error())
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
done:
|
||||
err = cts.psc.Send(MakeRpxStopPacket(crpx.id))
|
||||
if err != nil {
|
||||
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to send rpx(%d) stp to server - %s", crpx.id, err.Error())
|
||||
}
|
||||
cts.C.log.Write(cts.Sid, LOG_INFO, "Ending rpx(%d) read loop", crpx.id)
|
||||
|
||||
crpx.ReqStop()
|
||||
cts.rpx_mtx.Lock()
|
||||
delete(cts.rpx_map, crpx.id)
|
||||
cts.rpx_mtx.Unlock()
|
||||
}
|
||||
|
||||
func (cts *ClientConn) StartRpx(id uint64, data []byte, wg *sync.WaitGroup) error {
|
||||
var crpx *ClientRpx
|
||||
var ok bool
|
||||
|
||||
cts.rpx_mtx.Lock()
|
||||
_, ok = cts.rpx_map[id]
|
||||
if ok {
|
||||
cts.rpx_mtx.Unlock()
|
||||
return fmt.Errorf("multiple start on rpx id %d", id)
|
||||
}
|
||||
crpx = &ClientRpx{ id: id }
|
||||
cts.rpx_map[id] = crpx
|
||||
|
||||
// i want the pipe to be created before the goroutine is started
|
||||
// so that the WriteRpx() can write to the pipe. i protect pipe creation
|
||||
// and context creation with a mutex
|
||||
crpx.pr, crpx.pw = io.Pipe()
|
||||
crpx.ctx, crpx.cancel = context.WithCancel(cts.C.Ctx)
|
||||
|
||||
cts.rpx_mtx.Unlock()
|
||||
|
||||
wg.Add(1)
|
||||
go cts.RpxLoop(crpx, data, wg)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cts *ClientConn) StopRpx(id uint64) error {
|
||||
var crpx *ClientRpx
|
||||
|
||||
crpx = cts.FindClientRpxById(id)
|
||||
if crpx == nil {
|
||||
return fmt.Errorf("unknown rpx id %d", id)
|
||||
}
|
||||
|
||||
crpx.ReqStop()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cts *ClientConn) WriteRpx(id uint64, data []byte) error {
|
||||
var crpx *ClientRpx
|
||||
var err error
|
||||
|
||||
crpx = cts.FindClientRpxById(id)
|
||||
if crpx == nil {
|
||||
return fmt.Errorf("unknown rpx id %d", id)
|
||||
}
|
||||
|
||||
// TODO: may have to write it in a goroutine to avoid blocking?
|
||||
_, err = crpx.pw.Write(data)
|
||||
if err != nil {
|
||||
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to write rpx(%d) data - %s", id, err.Error())
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cts *ClientConn) EofRpx(id uint64, data []byte) error {
|
||||
var crpx *ClientRpx
|
||||
|
||||
crpx = cts.FindClientRpxById(id)
|
||||
if crpx == nil {
|
||||
return fmt.Errorf("unknown rpx id %d", id)
|
||||
}
|
||||
|
||||
// close the writing end only. leave the reading end untouched
|
||||
crpx.pw.Close()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cts *ClientConn) HandleRpxEvent(packet_type PACKET_KIND, evt *RpxEvent) error {
|
||||
switch packet_type {
|
||||
case PACKET_KIND_RPX_START:
|
||||
return cts.StartRpx(evt.Id, evt.Data, &cts.C.wg)
|
||||
|
||||
case PACKET_KIND_RPX_STOP:
|
||||
return cts.StopRpx(evt.Id)
|
||||
|
||||
case PACKET_KIND_RPX_DATA:
|
||||
return cts.WriteRpx(evt.Id, evt.Data)
|
||||
|
||||
case PACKET_KIND_RPX_EOF:
|
||||
return cts.EofRpx(evt.Id, evt.Data)
|
||||
}
|
||||
|
||||
// ignore other packet types
|
||||
return nil
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------
|
||||
|
||||
func (m ClientPeerConnMap) get_sorted_keys() []PeerId {
|
||||
@@ -1631,12 +2026,13 @@ func (m ClientConnMap) get_sorted_keys() []ConnId {
|
||||
|
||||
type client_ctl_log_writer struct {
|
||||
cli *Client
|
||||
depth int
|
||||
}
|
||||
|
||||
func (hlw *client_ctl_log_writer) Write(p []byte) (n int, err error) {
|
||||
// the standard http.Server always requires *log.Logger
|
||||
// use this iowriter to create a logger to pass it to the http server.
|
||||
hlw.cli.log.Write("", LOG_INFO, string(p))
|
||||
hlw.cli.log.WriteWithCallDepth("", LOG_INFO, hlw.depth, string(p))
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
@@ -1696,13 +2092,13 @@ func (c *Client) WrapHttpHandler(handler ClientHttpHandler) http.Handler {
|
||||
}
|
||||
|
||||
// TODO: statistics by status_code and end point types.
|
||||
time_taken = time.Now().Sub(start_time)
|
||||
time_taken = time.Since(start_time) //time.Now().Sub(start_time)
|
||||
|
||||
if status_code > 0 {
|
||||
if err != nil {
|
||||
c.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s %d %.9f - %s", req.RemoteAddr, req.Method, get_raw_url_path(req), status_code, time_taken.Seconds(), err.Error())
|
||||
c.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s %d %.9f - %s", req.RemoteAddr, req.Method, req.RequestURI, status_code, time_taken.Seconds(), err.Error())
|
||||
} else {
|
||||
c.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s %d %.9f", req.RemoteAddr, req.Method, get_raw_url_path(req), status_code, time_taken.Seconds())
|
||||
c.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s %d %.9f", req.RemoteAddr, req.Method, req.RequestURI, status_code, time_taken.Seconds())
|
||||
}
|
||||
}
|
||||
})
|
||||
@@ -1714,7 +2110,7 @@ func (c *Client) SafeWrapWebsocketHandler(handler websocket.Handler) http.Handle
|
||||
!strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade") {
|
||||
var status_code int
|
||||
status_code = WriteEmptyRespHeader(w, http.StatusBadRequest)
|
||||
c.log.Write("", LOG_INFO, "[%s] %s %s %d[non-websocket]", req.RemoteAddr, req.Method, get_raw_url_path(req), status_code)
|
||||
c.log.Write("", LOG_INFO, "[%s] %s %s %d[non-websocket]", req.RemoteAddr, req.Method, req.RequestURI, status_code)
|
||||
return
|
||||
}
|
||||
handler.ServeHTTP(w, req)
|
||||
@@ -1728,21 +2124,19 @@ func (c *Client) WrapWebsocketHandler(handler ClientWebsocketHandler) websocket.
|
||||
var start_time time.Time
|
||||
var time_taken time.Duration
|
||||
var req *http.Request
|
||||
var raw_url_path string
|
||||
|
||||
req = ws.Request()
|
||||
raw_url_path = get_raw_url_path(req)
|
||||
c.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s [ws]", req.RemoteAddr, req.Method, raw_url_path)
|
||||
c.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s [ws]", req.RemoteAddr, req.Method, req.RequestURI)
|
||||
|
||||
start_time = time.Now()
|
||||
status_code, err = handler.ServeWebsocket(ws)
|
||||
time_taken = time.Now().Sub(start_time)
|
||||
time_taken = time.Since(start_time) // time.Now().Sub(start_time)
|
||||
|
||||
if status_code > 0 {
|
||||
if err != nil {
|
||||
c.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s [ws] %d %.9f - %s", req.RemoteAddr, req.Method, raw_url_path, status_code, time_taken.Seconds(), err.Error())
|
||||
c.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s [ws] %d %.9f - %s", req.RemoteAddr, req.Method, req.RequestURI, status_code, time_taken.Seconds(), err.Error())
|
||||
} else {
|
||||
c.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s [ws] %d %.9f", req.RemoteAddr, req.Method, raw_url_path, status_code, time_taken.Seconds())
|
||||
c.log.Write(handler.Identity(), LOG_INFO, "[%s] %s %s [ws] %d %.9f", req.RemoteAddr, req.Method, req.RequestURI, status_code, time_taken.Seconds())
|
||||
}
|
||||
}
|
||||
})
|
||||
@@ -1770,6 +2164,14 @@ func NewClient(ctx context.Context, name string, logger Logger, cfg *ClientConfi
|
||||
c.bulletin = NewBulletin[*ClientEvent](&c, 1024)
|
||||
|
||||
c.rpc_tls = cfg.RpcTls
|
||||
c.rpx_target_addr = cfg.RpxTargetAddr
|
||||
c.rpx_target_tls = cfg.RpxTargetTls
|
||||
if c.rpx_target_tls != nil {
|
||||
c.rpx_target_url = "https://" + c.rpx_target_addr
|
||||
} else {
|
||||
c.rpx_target_url = "http://" + c.rpx_target_addr
|
||||
}
|
||||
|
||||
c.ctl_auth = cfg.CtlAuth
|
||||
c.ctl_tls = cfg.CtlTls
|
||||
c.ctl_prefix = cfg.CtlPrefix
|
||||
@@ -1843,13 +2245,13 @@ func NewClient(ctx context.Context, name string, logger Logger, cfg *ClientConfi
|
||||
c.ctl = make([]*http.Server, len(cfg.CtlAddrs))
|
||||
copy(c.ctl_addr, cfg.CtlAddrs)
|
||||
|
||||
hs_log = log.New(&client_ctl_log_writer{cli: &c}, "", 0)
|
||||
hs_log = log.New(&client_ctl_log_writer{cli: &c, depth: 0}, "", 0)
|
||||
|
||||
for i = 0; i < len(cfg.CtlAddrs); i++ {
|
||||
c.ctl[i] = &http.Server{
|
||||
Addr: cfg.CtlAddrs[i],
|
||||
Handler: c.ctl_mux,
|
||||
TLSConfig: c.ctl_tls,
|
||||
TLSConfig: c.ctl_tls.Clone(),
|
||||
ErrorLog: hs_log,
|
||||
// TODO: more settings
|
||||
}
|
||||
@@ -1859,6 +2261,7 @@ func NewClient(ctx context.Context, name string, logger Logger, cfg *ClientConfi
|
||||
c.stats.routes.Store(0)
|
||||
c.stats.peers.Store(0)
|
||||
c.stats.pty_sessions.Store(0)
|
||||
c.stats.rpty_sessions.Store(0)
|
||||
|
||||
return &c
|
||||
}
|
||||
@@ -2171,8 +2574,52 @@ func (c *Client) GetPtyShell() string {
|
||||
return c.pty_shell
|
||||
}
|
||||
|
||||
func (c *Client) RunCtlTask(wg *sync.WaitGroup) {
|
||||
func (c *Client) run_single_ctl_server(i int, cs *http.Server, wg *sync.WaitGroup) {
|
||||
var l net.Listener
|
||||
var err error
|
||||
|
||||
defer wg.Done()
|
||||
|
||||
c.log.Write("", LOG_INFO, "Control channel[%d] started on %s", i, c.ctl_addr[i])
|
||||
|
||||
// defeat hard-coded "tcp" in ListenAndServe() and ListenAndServeTLS()
|
||||
// by creating the listener explicitly.
|
||||
// err = cs.ListenAndServe()
|
||||
// err = cs.ListenAndServeTLS("", "") // c.tlscfg must provide a certificate and a key
|
||||
|
||||
//cs.shuttingDown(), as the name indicates, is not expoosed by the net/http
|
||||
//so I have to use my own indicator to check if it's been shutdown..
|
||||
//
|
||||
if c.stop_req.Load() == false {
|
||||
// this guard has a flaw in that the stop request can be made
|
||||
// between the check above and net.Listen() below.
|
||||
l, err = net.Listen(TcpAddrStrClass(cs.Addr), cs.Addr)
|
||||
if err == nil {
|
||||
if c.stop_req.Load() == false {
|
||||
// check it again to make the guard slightly more stable
|
||||
// although it's still possible that the stop request is made
|
||||
// after Listen()
|
||||
if c.ctl_tls == nil {
|
||||
err = cs.Serve(l)
|
||||
} else {
|
||||
err = cs.ServeTLS(l, "", "") // c.ctl_tls must provide a certificate and a key
|
||||
}
|
||||
} else {
|
||||
err = fmt.Errorf("stop requested")
|
||||
}
|
||||
l.Close()
|
||||
}
|
||||
} else {
|
||||
err = fmt.Errorf("stop requested")
|
||||
}
|
||||
if errors.Is(err, http.ErrServerClosed) {
|
||||
c.log.Write("", LOG_INFO, "Control channel[%d] ended", i)
|
||||
} else {
|
||||
c.log.Write("", LOG_ERROR, "Control channel[%d] error - %s", i, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) RunCtlTask(wg *sync.WaitGroup) {
|
||||
var ctl *http.Server
|
||||
var idx int
|
||||
var l_wg sync.WaitGroup
|
||||
@@ -2181,49 +2628,7 @@ func (c *Client) RunCtlTask(wg *sync.WaitGroup) {
|
||||
|
||||
for idx, ctl = range c.ctl {
|
||||
l_wg.Add(1)
|
||||
go func(i int, cs *http.Server) {
|
||||
var l net.Listener
|
||||
|
||||
c.log.Write("", LOG_INFO, "Control channel[%d] started on %s", i, c.ctl_addr[i])
|
||||
|
||||
// defeat hard-coded "tcp" in ListenAndServe() and ListenAndServeTLS()
|
||||
// by creating the listener explicitly.
|
||||
// err = cs.ListenAndServe()
|
||||
// err = cs.ListenAndServeTLS("", "") // c.tlscfg must provide a certificate and a key
|
||||
|
||||
//cs.shuttingDown(), as the name indicates, is not expoosed by the net/http
|
||||
//so I have to use my own indicator to check if it's been shutdown..
|
||||
//
|
||||
if c.stop_req.Load() == false {
|
||||
// this guard has a flaw in that the stop request can be made
|
||||
// between the check above and net.Listen() below.
|
||||
l, err = net.Listen(TcpAddrStrClass(cs.Addr), cs.Addr)
|
||||
if err == nil {
|
||||
if c.stop_req.Load() == false {
|
||||
// check it again to make the guard slightly more stable
|
||||
// although it's still possible that the stop request is made
|
||||
// after Listen()
|
||||
if c.ctl_tls == nil {
|
||||
err = cs.Serve(l)
|
||||
} else {
|
||||
err = cs.ServeTLS(l, "", "") // c.ctl_tls must provide a certificate and a key
|
||||
}
|
||||
} else {
|
||||
err = fmt.Errorf("stop requested")
|
||||
}
|
||||
l.Close()
|
||||
}
|
||||
} else {
|
||||
err = fmt.Errorf("stop requested")
|
||||
}
|
||||
if errors.Is(err, http.ErrServerClosed) {
|
||||
c.log.Write("", LOG_INFO, "Control channel[%d] ended", i)
|
||||
} else {
|
||||
c.log.Write("", LOG_ERROR, "Control channel[%d] error - %s", i, err.Error())
|
||||
}
|
||||
|
||||
l_wg.Done()
|
||||
}(idx, ctl)
|
||||
go c.run_single_ctl_server(idx, ctl, &l_wg)
|
||||
}
|
||||
l_wg.Wait()
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user