From 6078a41504c6e4880e10c9ae0dd974ed2605c50d Mon Sep 17 00:00:00 2001 From: hyung-hwan Date: Wed, 20 Aug 2025 14:11:02 +0900 Subject: [PATCH] moved some goroutines to separate functions for improved readability --- client.go | 328 ++++++++++++++++++++++++++++-------------------------- 1 file changed, 169 insertions(+), 159 deletions(-) diff --git a/client.go b/client.go index e576ae1..7e800a5 100644 --- a/client.go +++ b/client.go @@ -1675,16 +1675,175 @@ func (cts *ClientConn) FindClientRpxById(id uint64) *ClientRpx { return crpx } +func (cts *ClientConn) server_pipe_to_ws_target(crpx* ClientRpx, conn net.Conn, wg *sync.WaitGroup) { + var buf [4096]byte + var n int + var err error + + defer wg.Done() + + for { + n, err = crpx.pr.Read(buf[:]) + if n > 0 { + var err2 error + _, err2 = conn.Write(buf[:n]) + if err2 != nil { + cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to write websocket for rpx(%d) - %s", crpx.id, err2.Error()) + break + } + } + if err != nil { + if errors.Is(err, io.EOF) { break } + cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to read pipe for rpx(%d) - %s", crpx.id, err.Error()) + break + } + } +} + +func (cts *ClientConn) proxy_ws(crpx *ClientRpx, raw_req []byte, req *http.Request) error { + var l_wg sync.WaitGroup + var conn net.Conn + var resp *http.Response + var r *bufio.Reader + var buf [4096]byte + var n int + var err error + + if cts.C.rpx_target_tls != nil { + var dialer *tls.Dialer + dialer = &tls.Dialer{ + NetDialer: &net.Dialer{}, + Config: cts.C.rpx_target_tls, + } + conn, err = dialer.DialContext(crpx.ctx, "tcp", cts.C.rpx_target_addr) // TODO: no hard coding + } else { + var dialer *net.Dialer + dialer = &net.Dialer{} + conn, err = dialer.DialContext(crpx.ctx, "tcp", cts.C.rpx_target_addr) // TODO: no hard coding + } + if err != nil { + return fmt.Errorf("failed to dial websocket for rpx(%d) - %s", crpx.id, err.Error()) + } + defer conn.Close() + + // TODO: make this atomic? + crpx.ws_conn = conn + + // write the raw request line and headers as sent by the server. + // for the upgrade request, i assume no payload. + _, err = conn.Write(raw_req) + if err != nil { + return fmt.Errorf("failed to write websocket request for rpx(%d) - %s", crpx.id, err.Error()) + } + + r = bufio.NewReader(conn) + resp, err = http.ReadResponse(r, req) + if err != nil { + return fmt.Errorf("failed to write websocket response for rpx(%d) - %s", crpx.id, err.Error()) + } + defer resp.Body.Close() + + err = cts.psc.Send(MakeRpxStartPacket(crpx.id, get_http_resp_line_and_headers(resp))) + if err != nil { + return fmt.Errorf("failed to send rpx(%d) WebSocket headers to server - %s", crpx.id, err.Error()) + } + + if resp.StatusCode != http.StatusSwitchingProtocols { + // websock upgrade failed. let the code jump to the done + // label to skip reading from the pipe. the server side + // has the code to ensure no content-length. and the upgrade + // fails, the pipe below will be pending forever as the server + // side doesn't send data and there's no feeding to the pipe. + return fmt.Errorf("protocol switching failed for rpx(%d)", crpx.id) + } + + // unlike with the normal request, the actual pipe is not read + // until the initial switching protocol response is received. + + l_wg.Add(1) + go cts.server_pipe_to_ws_target(crpx, conn, &l_wg) + + for { + n, err = conn.Read(buf[:]) + if n > 0 { + var err2 error + err2 = cts.psc.Send(MakeRpxDataPacket(crpx.id, buf[:n])) + if err2 != nil { + crpx.ReqStop() // to break server_pipe_ws_target. don't care about multiple stops + return fmt.Errorf("failed to send rpx(%d) data to server - %s", crpx.id, err2.Error()) + } + } + if err != nil { + if errors.Is(err, io.EOF) { + cts.psc.Send(MakeRpxEofPacket(crpx.id)) + cts.C.log.Write(cts.Sid, LOG_DEBUG, "WebSocket rpx(%d) closed by server", crpx.id) + break + } + + crpx.ReqStop() // to break server_pipe_ws_target. don't care about multiple stops + return fmt.Errorf("failed to read WebSocket rpx(%d) - %s", crpx.id, err.Error()) + } + } + + // wait until the pipe reading(from the server side) goroutine is over + l_wg.Wait() + return nil +} + +func (cts *ClientConn) proxy_http(crpx *ClientRpx, req *http.Request) error { + var tr *http.Transport + var resp *http.Response + var buf [4096]byte + var n int + var err error + + tr = &http.Transport { + DisableKeepAlives: true, // this implementation can't support keepalive.. + } + if cts.C.rpx_target_tls != nil { + tr.TLSClientConfig = cts.C.rpx_target_tls + } + + resp, err = tr.RoundTrip(req) + if err != nil { + return fmt.Errorf("failed to send rpx(%d) request - %s", crpx.id, err.Error()) + } + + defer resp.Body.Close() + + err = cts.psc.Send(MakeRpxStartPacket(crpx.id, get_http_resp_line_and_headers(resp))) + if err != nil { + return fmt.Errorf("failed to send rpx(%d) status and headers to server - %s", crpx.id, err.Error()) + } + + for { + n, err = resp.Body.Read(buf[:]) + if n > 0 { + var err2 error + err2 = cts.psc.Send(MakeRpxDataPacket(crpx.id, buf[:n])) + if err2 != nil { + return fmt.Errorf("failed to send rpx(%d) data to server - %s", crpx.id, err2.Error()) + } + } + if err != nil { + if errors.Is(err, io.EOF) { + break + } + return fmt.Errorf("failed to read response body for rpx(%d) - %s", crpx.id, err.Error()) + } + } + + return nil +} + func (cts *ClientConn) RpxLoop(crpx *ClientRpx, data []byte, wg *sync.WaitGroup) { var sc *bufio.Scanner var line string var flds []string - var buf [4096]byte var req_meth string var req_path string //var req_proto string var req *http.Request - var n int var err error defer wg.Done() @@ -1721,7 +1880,6 @@ func (cts *ClientConn) RpxLoop(crpx *ClientRpx, data []byte, wg *sync.WaitGroup) k = strings.TrimSpace(flds[0]) v = strings.TrimSpace(flds[1]) req.Header.Add(k, v) -//fmt.Printf ("ADDING HEADER %s: %v\n", k, v) } } err = sc.Err() @@ -1732,163 +1890,14 @@ func (cts *ClientConn) RpxLoop(crpx *ClientRpx, data []byte, wg *sync.WaitGroup) if strings.EqualFold(req.Header.Get("Upgrade"), "websocket") && strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade") { // websocket - var done_chan chan struct{} - - var conn net.Conn - var resp *http.Response - var r *bufio.Reader - - if cts.C.rpx_target_tls != nil { - var dialer *tls.Dialer - dialer = &tls.Dialer{ - NetDialer: &net.Dialer{}, - Config: cts.C.rpx_target_tls, - } - conn, err = dialer.DialContext(crpx.ctx, "tcp", cts.C.rpx_target_addr) // TODO: no hard coding - } else { - var dialer *net.Dialer - dialer = &net.Dialer{} - conn, err = dialer.DialContext(crpx.ctx, "tcp", cts.C.rpx_target_addr) // TODO: no hard coding - } - if err != nil { - cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to dial websocket for rpx(%d) - %s", crpx.id, err.Error()) - goto done - } - defer conn.Close() - - // TODO: make this atomic? - crpx.ws_conn = conn - - // write the raw request line and headers as sent by the server. - // for the upgrade request, i assume no payload. - _, err = conn.Write(data) - if err != nil { - cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to write websocket request for rpx(%d) - %s", crpx.id, err.Error()) - goto done - } - - r = bufio.NewReader(conn) - resp, err = http.ReadResponse(r, req) - if err != nil { - cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to write websocket response for rpx(%d) - %s", crpx.id, err.Error()) - goto done - } - defer resp.Body.Close() - - err = cts.psc.Send(MakeRpxStartPacket(crpx.id, get_http_resp_line_and_headers(resp))) - if err != nil { - cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to send rpx(%d) WebSocket headers to server - %s", crpx.id, err.Error()) - goto done - } - - if resp.StatusCode != http.StatusSwitchingProtocols { - // websock upgrade failed. let the code jump to the done - // label to skip reading from the pipe. the server side - // has the code to ensure no content-length. and the upgrade - // fails, the pipe below will be pending forever as the server - // side doesn't send data and there's no feeding to the pipe. - cts.C.log.Write(cts.Sid, LOG_INFO, "Protocol switching failed for rpx(%d)", crpx.id) - goto done - } - - // unlike with the normal request, the actual pipe is not read - // until the initial switching protocol response is received. - - wg.Add(1) - done_chan = make(chan struct{}, 5) - go func() { - var buf [4096]byte - var n int - var err error - - defer wg.Done() - for { - n, err = crpx.pr.Read(buf[:]) - if n > 0 { - var err2 error - _, err2 = conn.Write(buf[:n]) - if err2 != nil { - cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to write websocket for rpx(%d) - %s", crpx.id, err2.Error()) - break - } - } - if err != nil { - if errors.Is(err, io.EOF) { break } - cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to read pipe for rpx(%d) - %s", crpx.id, err.Error()) - break - } - } - done_chan <- struct{}{} - }() - - for { - n, err = conn.Read(buf[:]) - if n > 0 { - var err2 error - err2 = cts.psc.Send(MakeRpxDataPacket(crpx.id, buf[:n])) - if err2 != nil { - cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to send rpx(%d) data to server - %s", crpx.id, err2.Error()) - break - } - } - if err != nil { - if errors.Is(err, io.EOF) { - cts.psc.Send(MakeRpxEofPacket(crpx.id)) - cts.C.log.Write(cts.Sid, LOG_DEBUG, "WebSocket rpx(%d) closed by server", crpx.id) - break - } - cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to read WebSocket rpx(%d) - %s", crpx.id, err.Error()) - break - } - } - - // wait until the pipe reading(from the server side) goroutine is over - <-done_chan + err = cts.proxy_ws(crpx, data, req) } else { - var tr *http.Transport - var resp *http.Response - - tr = &http.Transport { - DisableKeepAlives: true, // this implementation can't support keepalive.. - } - if cts.C.rpx_target_tls != nil { - tr.TLSClientConfig = cts.C.rpx_target_tls - } -//fmt.Printf("%+v\n", req) - resp, err = tr.RoundTrip(req) - if err != nil { - cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to send rpx(%d) request - %s", crpx.id, err.Error()) - goto done - } - - defer resp.Body.Close() - - err = cts.psc.Send(MakeRpxStartPacket(crpx.id, get_http_resp_line_and_headers(resp))) - if err != nil { - cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to send rpx(%d) status and headers to server - %s", crpx.id, err.Error()) - goto done - } - - for { - n, err = resp.Body.Read(buf[:]) -//fmt.Printf ("READ RESPONSE [%s], %d, %v\n", string(buf[:n]), n, err) - if n > 0 { - var err2 error - err2 = cts.psc.Send(MakeRpxDataPacket(crpx.id, buf[:n])) - if err2 != nil { - cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to send rpx(%d) data to server - %s", crpx.id, err2.Error()) - break - } - } - if err != nil { - if errors.Is(err, io.EOF) { - break - } - cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to read response body for rpx(%d) - %s", crpx.id, err.Error()) - break - } - } -//fmt.Printf ("READ RESPONSE LOOP IS OVER\n") + // normal http + err = cts.proxy_http(crpx, req) + } + if err != nil { + cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to proxy rpx(%d) - %s", crpx.id, err.Error()) + goto done } done: @@ -1899,6 +1908,7 @@ done: cts.C.log.Write(cts.Sid, LOG_INFO, "Ending rpx(%d) read loop", crpx.id) crpx.ReqStop() + cts.rpx_mtx.Lock() delete(cts.rpx_map, crpx.id) cts.rpx_mtx.Unlock()