moved some goroutines to separate functions for improved readability
This commit is contained in:
328
client.go
328
client.go
@ -1675,16 +1675,175 @@ func (cts *ClientConn) FindClientRpxById(id uint64) *ClientRpx {
|
||||
return crpx
|
||||
}
|
||||
|
||||
func (cts *ClientConn) server_pipe_to_ws_target(crpx* ClientRpx, conn net.Conn, wg *sync.WaitGroup) {
|
||||
var buf [4096]byte
|
||||
var n int
|
||||
var err error
|
||||
|
||||
defer wg.Done()
|
||||
|
||||
for {
|
||||
n, err = crpx.pr.Read(buf[:])
|
||||
if n > 0 {
|
||||
var err2 error
|
||||
_, err2 = conn.Write(buf[:n])
|
||||
if err2 != nil {
|
||||
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to write websocket for rpx(%d) - %s", crpx.id, err2.Error())
|
||||
break
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) { break }
|
||||
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to read pipe for rpx(%d) - %s", crpx.id, err.Error())
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (cts *ClientConn) proxy_ws(crpx *ClientRpx, raw_req []byte, req *http.Request) error {
|
||||
var l_wg sync.WaitGroup
|
||||
var conn net.Conn
|
||||
var resp *http.Response
|
||||
var r *bufio.Reader
|
||||
var buf [4096]byte
|
||||
var n int
|
||||
var err error
|
||||
|
||||
if cts.C.rpx_target_tls != nil {
|
||||
var dialer *tls.Dialer
|
||||
dialer = &tls.Dialer{
|
||||
NetDialer: &net.Dialer{},
|
||||
Config: cts.C.rpx_target_tls,
|
||||
}
|
||||
conn, err = dialer.DialContext(crpx.ctx, "tcp", cts.C.rpx_target_addr) // TODO: no hard coding
|
||||
} else {
|
||||
var dialer *net.Dialer
|
||||
dialer = &net.Dialer{}
|
||||
conn, err = dialer.DialContext(crpx.ctx, "tcp", cts.C.rpx_target_addr) // TODO: no hard coding
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to dial websocket for rpx(%d) - %s", crpx.id, err.Error())
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// TODO: make this atomic?
|
||||
crpx.ws_conn = conn
|
||||
|
||||
// write the raw request line and headers as sent by the server.
|
||||
// for the upgrade request, i assume no payload.
|
||||
_, err = conn.Write(raw_req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write websocket request for rpx(%d) - %s", crpx.id, err.Error())
|
||||
}
|
||||
|
||||
r = bufio.NewReader(conn)
|
||||
resp, err = http.ReadResponse(r, req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write websocket response for rpx(%d) - %s", crpx.id, err.Error())
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
err = cts.psc.Send(MakeRpxStartPacket(crpx.id, get_http_resp_line_and_headers(resp)))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send rpx(%d) WebSocket headers to server - %s", crpx.id, err.Error())
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusSwitchingProtocols {
|
||||
// websock upgrade failed. let the code jump to the done
|
||||
// label to skip reading from the pipe. the server side
|
||||
// has the code to ensure no content-length. and the upgrade
|
||||
// fails, the pipe below will be pending forever as the server
|
||||
// side doesn't send data and there's no feeding to the pipe.
|
||||
return fmt.Errorf("protocol switching failed for rpx(%d)", crpx.id)
|
||||
}
|
||||
|
||||
// unlike with the normal request, the actual pipe is not read
|
||||
// until the initial switching protocol response is received.
|
||||
|
||||
l_wg.Add(1)
|
||||
go cts.server_pipe_to_ws_target(crpx, conn, &l_wg)
|
||||
|
||||
for {
|
||||
n, err = conn.Read(buf[:])
|
||||
if n > 0 {
|
||||
var err2 error
|
||||
err2 = cts.psc.Send(MakeRpxDataPacket(crpx.id, buf[:n]))
|
||||
if err2 != nil {
|
||||
crpx.ReqStop() // to break server_pipe_ws_target. don't care about multiple stops
|
||||
return fmt.Errorf("failed to send rpx(%d) data to server - %s", crpx.id, err2.Error())
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
cts.psc.Send(MakeRpxEofPacket(crpx.id))
|
||||
cts.C.log.Write(cts.Sid, LOG_DEBUG, "WebSocket rpx(%d) closed by server", crpx.id)
|
||||
break
|
||||
}
|
||||
|
||||
crpx.ReqStop() // to break server_pipe_ws_target. don't care about multiple stops
|
||||
return fmt.Errorf("failed to read WebSocket rpx(%d) - %s", crpx.id, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// wait until the pipe reading(from the server side) goroutine is over
|
||||
l_wg.Wait()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cts *ClientConn) proxy_http(crpx *ClientRpx, req *http.Request) error {
|
||||
var tr *http.Transport
|
||||
var resp *http.Response
|
||||
var buf [4096]byte
|
||||
var n int
|
||||
var err error
|
||||
|
||||
tr = &http.Transport {
|
||||
DisableKeepAlives: true, // this implementation can't support keepalive..
|
||||
}
|
||||
if cts.C.rpx_target_tls != nil {
|
||||
tr.TLSClientConfig = cts.C.rpx_target_tls
|
||||
}
|
||||
|
||||
resp, err = tr.RoundTrip(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send rpx(%d) request - %s", crpx.id, err.Error())
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
|
||||
err = cts.psc.Send(MakeRpxStartPacket(crpx.id, get_http_resp_line_and_headers(resp)))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send rpx(%d) status and headers to server - %s", crpx.id, err.Error())
|
||||
}
|
||||
|
||||
for {
|
||||
n, err = resp.Body.Read(buf[:])
|
||||
if n > 0 {
|
||||
var err2 error
|
||||
err2 = cts.psc.Send(MakeRpxDataPacket(crpx.id, buf[:n]))
|
||||
if err2 != nil {
|
||||
return fmt.Errorf("failed to send rpx(%d) data to server - %s", crpx.id, err2.Error())
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
}
|
||||
return fmt.Errorf("failed to read response body for rpx(%d) - %s", crpx.id, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cts *ClientConn) RpxLoop(crpx *ClientRpx, data []byte, wg *sync.WaitGroup) {
|
||||
var sc *bufio.Scanner
|
||||
var line string
|
||||
var flds []string
|
||||
var buf [4096]byte
|
||||
var req_meth string
|
||||
var req_path string
|
||||
//var req_proto string
|
||||
var req *http.Request
|
||||
var n int
|
||||
var err error
|
||||
|
||||
defer wg.Done()
|
||||
@ -1721,7 +1880,6 @@ func (cts *ClientConn) RpxLoop(crpx *ClientRpx, data []byte, wg *sync.WaitGroup)
|
||||
k = strings.TrimSpace(flds[0])
|
||||
v = strings.TrimSpace(flds[1])
|
||||
req.Header.Add(k, v)
|
||||
//fmt.Printf ("ADDING HEADER %s: %v\n", k, v)
|
||||
}
|
||||
}
|
||||
err = sc.Err()
|
||||
@ -1732,163 +1890,14 @@ func (cts *ClientConn) RpxLoop(crpx *ClientRpx, data []byte, wg *sync.WaitGroup)
|
||||
|
||||
if strings.EqualFold(req.Header.Get("Upgrade"), "websocket") && strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade") {
|
||||
// websocket
|
||||
var done_chan chan struct{}
|
||||
|
||||
var conn net.Conn
|
||||
var resp *http.Response
|
||||
var r *bufio.Reader
|
||||
|
||||
if cts.C.rpx_target_tls != nil {
|
||||
var dialer *tls.Dialer
|
||||
dialer = &tls.Dialer{
|
||||
NetDialer: &net.Dialer{},
|
||||
Config: cts.C.rpx_target_tls,
|
||||
}
|
||||
conn, err = dialer.DialContext(crpx.ctx, "tcp", cts.C.rpx_target_addr) // TODO: no hard coding
|
||||
} else {
|
||||
var dialer *net.Dialer
|
||||
dialer = &net.Dialer{}
|
||||
conn, err = dialer.DialContext(crpx.ctx, "tcp", cts.C.rpx_target_addr) // TODO: no hard coding
|
||||
}
|
||||
if err != nil {
|
||||
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to dial websocket for rpx(%d) - %s", crpx.id, err.Error())
|
||||
goto done
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// TODO: make this atomic?
|
||||
crpx.ws_conn = conn
|
||||
|
||||
// write the raw request line and headers as sent by the server.
|
||||
// for the upgrade request, i assume no payload.
|
||||
_, err = conn.Write(data)
|
||||
if err != nil {
|
||||
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to write websocket request for rpx(%d) - %s", crpx.id, err.Error())
|
||||
goto done
|
||||
}
|
||||
|
||||
r = bufio.NewReader(conn)
|
||||
resp, err = http.ReadResponse(r, req)
|
||||
if err != nil {
|
||||
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to write websocket response for rpx(%d) - %s", crpx.id, err.Error())
|
||||
goto done
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
err = cts.psc.Send(MakeRpxStartPacket(crpx.id, get_http_resp_line_and_headers(resp)))
|
||||
if err != nil {
|
||||
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to send rpx(%d) WebSocket headers to server - %s", crpx.id, err.Error())
|
||||
goto done
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusSwitchingProtocols {
|
||||
// websock upgrade failed. let the code jump to the done
|
||||
// label to skip reading from the pipe. the server side
|
||||
// has the code to ensure no content-length. and the upgrade
|
||||
// fails, the pipe below will be pending forever as the server
|
||||
// side doesn't send data and there's no feeding to the pipe.
|
||||
cts.C.log.Write(cts.Sid, LOG_INFO, "Protocol switching failed for rpx(%d)", crpx.id)
|
||||
goto done
|
||||
}
|
||||
|
||||
// unlike with the normal request, the actual pipe is not read
|
||||
// until the initial switching protocol response is received.
|
||||
|
||||
wg.Add(1)
|
||||
done_chan = make(chan struct{}, 5)
|
||||
go func() {
|
||||
var buf [4096]byte
|
||||
var n int
|
||||
var err error
|
||||
|
||||
defer wg.Done()
|
||||
for {
|
||||
n, err = crpx.pr.Read(buf[:])
|
||||
if n > 0 {
|
||||
var err2 error
|
||||
_, err2 = conn.Write(buf[:n])
|
||||
if err2 != nil {
|
||||
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to write websocket for rpx(%d) - %s", crpx.id, err2.Error())
|
||||
break
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) { break }
|
||||
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to read pipe for rpx(%d) - %s", crpx.id, err.Error())
|
||||
break
|
||||
}
|
||||
}
|
||||
done_chan <- struct{}{}
|
||||
}()
|
||||
|
||||
for {
|
||||
n, err = conn.Read(buf[:])
|
||||
if n > 0 {
|
||||
var err2 error
|
||||
err2 = cts.psc.Send(MakeRpxDataPacket(crpx.id, buf[:n]))
|
||||
if err2 != nil {
|
||||
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to send rpx(%d) data to server - %s", crpx.id, err2.Error())
|
||||
break
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
cts.psc.Send(MakeRpxEofPacket(crpx.id))
|
||||
cts.C.log.Write(cts.Sid, LOG_DEBUG, "WebSocket rpx(%d) closed by server", crpx.id)
|
||||
break
|
||||
}
|
||||
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to read WebSocket rpx(%d) - %s", crpx.id, err.Error())
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// wait until the pipe reading(from the server side) goroutine is over
|
||||
<-done_chan
|
||||
err = cts.proxy_ws(crpx, data, req)
|
||||
} else {
|
||||
var tr *http.Transport
|
||||
var resp *http.Response
|
||||
|
||||
tr = &http.Transport {
|
||||
DisableKeepAlives: true, // this implementation can't support keepalive..
|
||||
}
|
||||
if cts.C.rpx_target_tls != nil {
|
||||
tr.TLSClientConfig = cts.C.rpx_target_tls
|
||||
}
|
||||
//fmt.Printf("%+v\n", req)
|
||||
resp, err = tr.RoundTrip(req)
|
||||
if err != nil {
|
||||
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to send rpx(%d) request - %s", crpx.id, err.Error())
|
||||
goto done
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
|
||||
err = cts.psc.Send(MakeRpxStartPacket(crpx.id, get_http_resp_line_and_headers(resp)))
|
||||
if err != nil {
|
||||
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to send rpx(%d) status and headers to server - %s", crpx.id, err.Error())
|
||||
goto done
|
||||
}
|
||||
|
||||
for {
|
||||
n, err = resp.Body.Read(buf[:])
|
||||
//fmt.Printf ("READ RESPONSE [%s], %d, %v\n", string(buf[:n]), n, err)
|
||||
if n > 0 {
|
||||
var err2 error
|
||||
err2 = cts.psc.Send(MakeRpxDataPacket(crpx.id, buf[:n]))
|
||||
if err2 != nil {
|
||||
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to send rpx(%d) data to server - %s", crpx.id, err2.Error())
|
||||
break
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
}
|
||||
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to read response body for rpx(%d) - %s", crpx.id, err.Error())
|
||||
break
|
||||
}
|
||||
}
|
||||
//fmt.Printf ("READ RESPONSE LOOP IS OVER\n")
|
||||
// normal http
|
||||
err = cts.proxy_http(crpx, req)
|
||||
}
|
||||
if err != nil {
|
||||
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to proxy rpx(%d) - %s", crpx.id, err.Error())
|
||||
goto done
|
||||
}
|
||||
|
||||
done:
|
||||
@ -1899,6 +1908,7 @@ done:
|
||||
cts.C.log.Write(cts.Sid, LOG_INFO, "Ending rpx(%d) read loop", crpx.id)
|
||||
|
||||
crpx.ReqStop()
|
||||
|
||||
cts.rpx_mtx.Lock()
|
||||
delete(cts.rpx_map, crpx.id)
|
||||
cts.rpx_mtx.Unlock()
|
||||
|
Reference in New Issue
Block a user