moved some goroutines to separate functions for improved readability

This commit is contained in:
2025-08-20 14:11:02 +09:00
parent 0696f4f560
commit 6078a41504

328
client.go
View File

@ -1675,16 +1675,175 @@ func (cts *ClientConn) FindClientRpxById(id uint64) *ClientRpx {
return crpx
}
func (cts *ClientConn) server_pipe_to_ws_target(crpx* ClientRpx, conn net.Conn, wg *sync.WaitGroup) {
var buf [4096]byte
var n int
var err error
defer wg.Done()
for {
n, err = crpx.pr.Read(buf[:])
if n > 0 {
var err2 error
_, err2 = conn.Write(buf[:n])
if err2 != nil {
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to write websocket for rpx(%d) - %s", crpx.id, err2.Error())
break
}
}
if err != nil {
if errors.Is(err, io.EOF) { break }
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to read pipe for rpx(%d) - %s", crpx.id, err.Error())
break
}
}
}
func (cts *ClientConn) proxy_ws(crpx *ClientRpx, raw_req []byte, req *http.Request) error {
var l_wg sync.WaitGroup
var conn net.Conn
var resp *http.Response
var r *bufio.Reader
var buf [4096]byte
var n int
var err error
if cts.C.rpx_target_tls != nil {
var dialer *tls.Dialer
dialer = &tls.Dialer{
NetDialer: &net.Dialer{},
Config: cts.C.rpx_target_tls,
}
conn, err = dialer.DialContext(crpx.ctx, "tcp", cts.C.rpx_target_addr) // TODO: no hard coding
} else {
var dialer *net.Dialer
dialer = &net.Dialer{}
conn, err = dialer.DialContext(crpx.ctx, "tcp", cts.C.rpx_target_addr) // TODO: no hard coding
}
if err != nil {
return fmt.Errorf("failed to dial websocket for rpx(%d) - %s", crpx.id, err.Error())
}
defer conn.Close()
// TODO: make this atomic?
crpx.ws_conn = conn
// write the raw request line and headers as sent by the server.
// for the upgrade request, i assume no payload.
_, err = conn.Write(raw_req)
if err != nil {
return fmt.Errorf("failed to write websocket request for rpx(%d) - %s", crpx.id, err.Error())
}
r = bufio.NewReader(conn)
resp, err = http.ReadResponse(r, req)
if err != nil {
return fmt.Errorf("failed to write websocket response for rpx(%d) - %s", crpx.id, err.Error())
}
defer resp.Body.Close()
err = cts.psc.Send(MakeRpxStartPacket(crpx.id, get_http_resp_line_and_headers(resp)))
if err != nil {
return fmt.Errorf("failed to send rpx(%d) WebSocket headers to server - %s", crpx.id, err.Error())
}
if resp.StatusCode != http.StatusSwitchingProtocols {
// websock upgrade failed. let the code jump to the done
// label to skip reading from the pipe. the server side
// has the code to ensure no content-length. and the upgrade
// fails, the pipe below will be pending forever as the server
// side doesn't send data and there's no feeding to the pipe.
return fmt.Errorf("protocol switching failed for rpx(%d)", crpx.id)
}
// unlike with the normal request, the actual pipe is not read
// until the initial switching protocol response is received.
l_wg.Add(1)
go cts.server_pipe_to_ws_target(crpx, conn, &l_wg)
for {
n, err = conn.Read(buf[:])
if n > 0 {
var err2 error
err2 = cts.psc.Send(MakeRpxDataPacket(crpx.id, buf[:n]))
if err2 != nil {
crpx.ReqStop() // to break server_pipe_ws_target. don't care about multiple stops
return fmt.Errorf("failed to send rpx(%d) data to server - %s", crpx.id, err2.Error())
}
}
if err != nil {
if errors.Is(err, io.EOF) {
cts.psc.Send(MakeRpxEofPacket(crpx.id))
cts.C.log.Write(cts.Sid, LOG_DEBUG, "WebSocket rpx(%d) closed by server", crpx.id)
break
}
crpx.ReqStop() // to break server_pipe_ws_target. don't care about multiple stops
return fmt.Errorf("failed to read WebSocket rpx(%d) - %s", crpx.id, err.Error())
}
}
// wait until the pipe reading(from the server side) goroutine is over
l_wg.Wait()
return nil
}
func (cts *ClientConn) proxy_http(crpx *ClientRpx, req *http.Request) error {
var tr *http.Transport
var resp *http.Response
var buf [4096]byte
var n int
var err error
tr = &http.Transport {
DisableKeepAlives: true, // this implementation can't support keepalive..
}
if cts.C.rpx_target_tls != nil {
tr.TLSClientConfig = cts.C.rpx_target_tls
}
resp, err = tr.RoundTrip(req)
if err != nil {
return fmt.Errorf("failed to send rpx(%d) request - %s", crpx.id, err.Error())
}
defer resp.Body.Close()
err = cts.psc.Send(MakeRpxStartPacket(crpx.id, get_http_resp_line_and_headers(resp)))
if err != nil {
return fmt.Errorf("failed to send rpx(%d) status and headers to server - %s", crpx.id, err.Error())
}
for {
n, err = resp.Body.Read(buf[:])
if n > 0 {
var err2 error
err2 = cts.psc.Send(MakeRpxDataPacket(crpx.id, buf[:n]))
if err2 != nil {
return fmt.Errorf("failed to send rpx(%d) data to server - %s", crpx.id, err2.Error())
}
}
if err != nil {
if errors.Is(err, io.EOF) {
break
}
return fmt.Errorf("failed to read response body for rpx(%d) - %s", crpx.id, err.Error())
}
}
return nil
}
func (cts *ClientConn) RpxLoop(crpx *ClientRpx, data []byte, wg *sync.WaitGroup) {
var sc *bufio.Scanner
var line string
var flds []string
var buf [4096]byte
var req_meth string
var req_path string
//var req_proto string
var req *http.Request
var n int
var err error
defer wg.Done()
@ -1721,7 +1880,6 @@ func (cts *ClientConn) RpxLoop(crpx *ClientRpx, data []byte, wg *sync.WaitGroup)
k = strings.TrimSpace(flds[0])
v = strings.TrimSpace(flds[1])
req.Header.Add(k, v)
//fmt.Printf ("ADDING HEADER %s: %v\n", k, v)
}
}
err = sc.Err()
@ -1732,163 +1890,14 @@ func (cts *ClientConn) RpxLoop(crpx *ClientRpx, data []byte, wg *sync.WaitGroup)
if strings.EqualFold(req.Header.Get("Upgrade"), "websocket") && strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade") {
// websocket
var done_chan chan struct{}
var conn net.Conn
var resp *http.Response
var r *bufio.Reader
if cts.C.rpx_target_tls != nil {
var dialer *tls.Dialer
dialer = &tls.Dialer{
NetDialer: &net.Dialer{},
Config: cts.C.rpx_target_tls,
}
conn, err = dialer.DialContext(crpx.ctx, "tcp", cts.C.rpx_target_addr) // TODO: no hard coding
} else {
var dialer *net.Dialer
dialer = &net.Dialer{}
conn, err = dialer.DialContext(crpx.ctx, "tcp", cts.C.rpx_target_addr) // TODO: no hard coding
}
if err != nil {
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to dial websocket for rpx(%d) - %s", crpx.id, err.Error())
goto done
}
defer conn.Close()
// TODO: make this atomic?
crpx.ws_conn = conn
// write the raw request line and headers as sent by the server.
// for the upgrade request, i assume no payload.
_, err = conn.Write(data)
if err != nil {
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to write websocket request for rpx(%d) - %s", crpx.id, err.Error())
goto done
}
r = bufio.NewReader(conn)
resp, err = http.ReadResponse(r, req)
if err != nil {
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to write websocket response for rpx(%d) - %s", crpx.id, err.Error())
goto done
}
defer resp.Body.Close()
err = cts.psc.Send(MakeRpxStartPacket(crpx.id, get_http_resp_line_and_headers(resp)))
if err != nil {
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to send rpx(%d) WebSocket headers to server - %s", crpx.id, err.Error())
goto done
}
if resp.StatusCode != http.StatusSwitchingProtocols {
// websock upgrade failed. let the code jump to the done
// label to skip reading from the pipe. the server side
// has the code to ensure no content-length. and the upgrade
// fails, the pipe below will be pending forever as the server
// side doesn't send data and there's no feeding to the pipe.
cts.C.log.Write(cts.Sid, LOG_INFO, "Protocol switching failed for rpx(%d)", crpx.id)
goto done
}
// unlike with the normal request, the actual pipe is not read
// until the initial switching protocol response is received.
wg.Add(1)
done_chan = make(chan struct{}, 5)
go func() {
var buf [4096]byte
var n int
var err error
defer wg.Done()
for {
n, err = crpx.pr.Read(buf[:])
if n > 0 {
var err2 error
_, err2 = conn.Write(buf[:n])
if err2 != nil {
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to write websocket for rpx(%d) - %s", crpx.id, err2.Error())
break
}
}
if err != nil {
if errors.Is(err, io.EOF) { break }
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to read pipe for rpx(%d) - %s", crpx.id, err.Error())
break
}
}
done_chan <- struct{}{}
}()
for {
n, err = conn.Read(buf[:])
if n > 0 {
var err2 error
err2 = cts.psc.Send(MakeRpxDataPacket(crpx.id, buf[:n]))
if err2 != nil {
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to send rpx(%d) data to server - %s", crpx.id, err2.Error())
break
}
}
if err != nil {
if errors.Is(err, io.EOF) {
cts.psc.Send(MakeRpxEofPacket(crpx.id))
cts.C.log.Write(cts.Sid, LOG_DEBUG, "WebSocket rpx(%d) closed by server", crpx.id)
break
}
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to read WebSocket rpx(%d) - %s", crpx.id, err.Error())
break
}
}
// wait until the pipe reading(from the server side) goroutine is over
<-done_chan
err = cts.proxy_ws(crpx, data, req)
} else {
var tr *http.Transport
var resp *http.Response
tr = &http.Transport {
DisableKeepAlives: true, // this implementation can't support keepalive..
}
if cts.C.rpx_target_tls != nil {
tr.TLSClientConfig = cts.C.rpx_target_tls
}
//fmt.Printf("%+v\n", req)
resp, err = tr.RoundTrip(req)
if err != nil {
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to send rpx(%d) request - %s", crpx.id, err.Error())
goto done
}
defer resp.Body.Close()
err = cts.psc.Send(MakeRpxStartPacket(crpx.id, get_http_resp_line_and_headers(resp)))
if err != nil {
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to send rpx(%d) status and headers to server - %s", crpx.id, err.Error())
goto done
}
for {
n, err = resp.Body.Read(buf[:])
//fmt.Printf ("READ RESPONSE [%s], %d, %v\n", string(buf[:n]), n, err)
if n > 0 {
var err2 error
err2 = cts.psc.Send(MakeRpxDataPacket(crpx.id, buf[:n]))
if err2 != nil {
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to send rpx(%d) data to server - %s", crpx.id, err2.Error())
break
}
}
if err != nil {
if errors.Is(err, io.EOF) {
break
}
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to read response body for rpx(%d) - %s", crpx.id, err.Error())
break
}
}
//fmt.Printf ("READ RESPONSE LOOP IS OVER\n")
// normal http
err = cts.proxy_http(crpx, req)
}
if err != nil {
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to proxy rpx(%d) - %s", crpx.id, err.Error())
goto done
}
done:
@ -1899,6 +1908,7 @@ done:
cts.C.log.Write(cts.Sid, LOG_INFO, "Ending rpx(%d) read loop", crpx.id)
crpx.ReqStop()
cts.rpx_mtx.Lock()
delete(cts.rpx_map, crpx.id)
cts.rpx_mtx.Unlock()