moved some goroutines to separate functions for improved readability
This commit is contained in:
328
client.go
328
client.go
@ -1675,16 +1675,175 @@ func (cts *ClientConn) FindClientRpxById(id uint64) *ClientRpx {
|
|||||||
return crpx
|
return crpx
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (cts *ClientConn) server_pipe_to_ws_target(crpx* ClientRpx, conn net.Conn, wg *sync.WaitGroup) {
|
||||||
|
var buf [4096]byte
|
||||||
|
var n int
|
||||||
|
var err error
|
||||||
|
|
||||||
|
defer wg.Done()
|
||||||
|
|
||||||
|
for {
|
||||||
|
n, err = crpx.pr.Read(buf[:])
|
||||||
|
if n > 0 {
|
||||||
|
var err2 error
|
||||||
|
_, err2 = conn.Write(buf[:n])
|
||||||
|
if err2 != nil {
|
||||||
|
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to write websocket for rpx(%d) - %s", crpx.id, err2.Error())
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, io.EOF) { break }
|
||||||
|
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to read pipe for rpx(%d) - %s", crpx.id, err.Error())
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cts *ClientConn) proxy_ws(crpx *ClientRpx, raw_req []byte, req *http.Request) error {
|
||||||
|
var l_wg sync.WaitGroup
|
||||||
|
var conn net.Conn
|
||||||
|
var resp *http.Response
|
||||||
|
var r *bufio.Reader
|
||||||
|
var buf [4096]byte
|
||||||
|
var n int
|
||||||
|
var err error
|
||||||
|
|
||||||
|
if cts.C.rpx_target_tls != nil {
|
||||||
|
var dialer *tls.Dialer
|
||||||
|
dialer = &tls.Dialer{
|
||||||
|
NetDialer: &net.Dialer{},
|
||||||
|
Config: cts.C.rpx_target_tls,
|
||||||
|
}
|
||||||
|
conn, err = dialer.DialContext(crpx.ctx, "tcp", cts.C.rpx_target_addr) // TODO: no hard coding
|
||||||
|
} else {
|
||||||
|
var dialer *net.Dialer
|
||||||
|
dialer = &net.Dialer{}
|
||||||
|
conn, err = dialer.DialContext(crpx.ctx, "tcp", cts.C.rpx_target_addr) // TODO: no hard coding
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to dial websocket for rpx(%d) - %s", crpx.id, err.Error())
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
// TODO: make this atomic?
|
||||||
|
crpx.ws_conn = conn
|
||||||
|
|
||||||
|
// write the raw request line and headers as sent by the server.
|
||||||
|
// for the upgrade request, i assume no payload.
|
||||||
|
_, err = conn.Write(raw_req)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to write websocket request for rpx(%d) - %s", crpx.id, err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
r = bufio.NewReader(conn)
|
||||||
|
resp, err = http.ReadResponse(r, req)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to write websocket response for rpx(%d) - %s", crpx.id, err.Error())
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
err = cts.psc.Send(MakeRpxStartPacket(crpx.id, get_http_resp_line_and_headers(resp)))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to send rpx(%d) WebSocket headers to server - %s", crpx.id, err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusSwitchingProtocols {
|
||||||
|
// websock upgrade failed. let the code jump to the done
|
||||||
|
// label to skip reading from the pipe. the server side
|
||||||
|
// has the code to ensure no content-length. and the upgrade
|
||||||
|
// fails, the pipe below will be pending forever as the server
|
||||||
|
// side doesn't send data and there's no feeding to the pipe.
|
||||||
|
return fmt.Errorf("protocol switching failed for rpx(%d)", crpx.id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// unlike with the normal request, the actual pipe is not read
|
||||||
|
// until the initial switching protocol response is received.
|
||||||
|
|
||||||
|
l_wg.Add(1)
|
||||||
|
go cts.server_pipe_to_ws_target(crpx, conn, &l_wg)
|
||||||
|
|
||||||
|
for {
|
||||||
|
n, err = conn.Read(buf[:])
|
||||||
|
if n > 0 {
|
||||||
|
var err2 error
|
||||||
|
err2 = cts.psc.Send(MakeRpxDataPacket(crpx.id, buf[:n]))
|
||||||
|
if err2 != nil {
|
||||||
|
crpx.ReqStop() // to break server_pipe_ws_target. don't care about multiple stops
|
||||||
|
return fmt.Errorf("failed to send rpx(%d) data to server - %s", crpx.id, err2.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, io.EOF) {
|
||||||
|
cts.psc.Send(MakeRpxEofPacket(crpx.id))
|
||||||
|
cts.C.log.Write(cts.Sid, LOG_DEBUG, "WebSocket rpx(%d) closed by server", crpx.id)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
crpx.ReqStop() // to break server_pipe_ws_target. don't care about multiple stops
|
||||||
|
return fmt.Errorf("failed to read WebSocket rpx(%d) - %s", crpx.id, err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// wait until the pipe reading(from the server side) goroutine is over
|
||||||
|
l_wg.Wait()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cts *ClientConn) proxy_http(crpx *ClientRpx, req *http.Request) error {
|
||||||
|
var tr *http.Transport
|
||||||
|
var resp *http.Response
|
||||||
|
var buf [4096]byte
|
||||||
|
var n int
|
||||||
|
var err error
|
||||||
|
|
||||||
|
tr = &http.Transport {
|
||||||
|
DisableKeepAlives: true, // this implementation can't support keepalive..
|
||||||
|
}
|
||||||
|
if cts.C.rpx_target_tls != nil {
|
||||||
|
tr.TLSClientConfig = cts.C.rpx_target_tls
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err = tr.RoundTrip(req)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to send rpx(%d) request - %s", crpx.id, err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
err = cts.psc.Send(MakeRpxStartPacket(crpx.id, get_http_resp_line_and_headers(resp)))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to send rpx(%d) status and headers to server - %s", crpx.id, err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
n, err = resp.Body.Read(buf[:])
|
||||||
|
if n > 0 {
|
||||||
|
var err2 error
|
||||||
|
err2 = cts.psc.Send(MakeRpxDataPacket(crpx.id, buf[:n]))
|
||||||
|
if err2 != nil {
|
||||||
|
return fmt.Errorf("failed to send rpx(%d) data to server - %s", crpx.id, err2.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, io.EOF) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
return fmt.Errorf("failed to read response body for rpx(%d) - %s", crpx.id, err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (cts *ClientConn) RpxLoop(crpx *ClientRpx, data []byte, wg *sync.WaitGroup) {
|
func (cts *ClientConn) RpxLoop(crpx *ClientRpx, data []byte, wg *sync.WaitGroup) {
|
||||||
var sc *bufio.Scanner
|
var sc *bufio.Scanner
|
||||||
var line string
|
var line string
|
||||||
var flds []string
|
var flds []string
|
||||||
var buf [4096]byte
|
|
||||||
var req_meth string
|
var req_meth string
|
||||||
var req_path string
|
var req_path string
|
||||||
//var req_proto string
|
//var req_proto string
|
||||||
var req *http.Request
|
var req *http.Request
|
||||||
var n int
|
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
@ -1721,7 +1880,6 @@ func (cts *ClientConn) RpxLoop(crpx *ClientRpx, data []byte, wg *sync.WaitGroup)
|
|||||||
k = strings.TrimSpace(flds[0])
|
k = strings.TrimSpace(flds[0])
|
||||||
v = strings.TrimSpace(flds[1])
|
v = strings.TrimSpace(flds[1])
|
||||||
req.Header.Add(k, v)
|
req.Header.Add(k, v)
|
||||||
//fmt.Printf ("ADDING HEADER %s: %v\n", k, v)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
err = sc.Err()
|
err = sc.Err()
|
||||||
@ -1732,163 +1890,14 @@ func (cts *ClientConn) RpxLoop(crpx *ClientRpx, data []byte, wg *sync.WaitGroup)
|
|||||||
|
|
||||||
if strings.EqualFold(req.Header.Get("Upgrade"), "websocket") && strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade") {
|
if strings.EqualFold(req.Header.Get("Upgrade"), "websocket") && strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade") {
|
||||||
// websocket
|
// websocket
|
||||||
var done_chan chan struct{}
|
err = cts.proxy_ws(crpx, data, req)
|
||||||
|
|
||||||
var conn net.Conn
|
|
||||||
var resp *http.Response
|
|
||||||
var r *bufio.Reader
|
|
||||||
|
|
||||||
if cts.C.rpx_target_tls != nil {
|
|
||||||
var dialer *tls.Dialer
|
|
||||||
dialer = &tls.Dialer{
|
|
||||||
NetDialer: &net.Dialer{},
|
|
||||||
Config: cts.C.rpx_target_tls,
|
|
||||||
}
|
|
||||||
conn, err = dialer.DialContext(crpx.ctx, "tcp", cts.C.rpx_target_addr) // TODO: no hard coding
|
|
||||||
} else {
|
|
||||||
var dialer *net.Dialer
|
|
||||||
dialer = &net.Dialer{}
|
|
||||||
conn, err = dialer.DialContext(crpx.ctx, "tcp", cts.C.rpx_target_addr) // TODO: no hard coding
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to dial websocket for rpx(%d) - %s", crpx.id, err.Error())
|
|
||||||
goto done
|
|
||||||
}
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
// TODO: make this atomic?
|
|
||||||
crpx.ws_conn = conn
|
|
||||||
|
|
||||||
// write the raw request line and headers as sent by the server.
|
|
||||||
// for the upgrade request, i assume no payload.
|
|
||||||
_, err = conn.Write(data)
|
|
||||||
if err != nil {
|
|
||||||
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to write websocket request for rpx(%d) - %s", crpx.id, err.Error())
|
|
||||||
goto done
|
|
||||||
}
|
|
||||||
|
|
||||||
r = bufio.NewReader(conn)
|
|
||||||
resp, err = http.ReadResponse(r, req)
|
|
||||||
if err != nil {
|
|
||||||
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to write websocket response for rpx(%d) - %s", crpx.id, err.Error())
|
|
||||||
goto done
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
err = cts.psc.Send(MakeRpxStartPacket(crpx.id, get_http_resp_line_and_headers(resp)))
|
|
||||||
if err != nil {
|
|
||||||
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to send rpx(%d) WebSocket headers to server - %s", crpx.id, err.Error())
|
|
||||||
goto done
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusSwitchingProtocols {
|
|
||||||
// websock upgrade failed. let the code jump to the done
|
|
||||||
// label to skip reading from the pipe. the server side
|
|
||||||
// has the code to ensure no content-length. and the upgrade
|
|
||||||
// fails, the pipe below will be pending forever as the server
|
|
||||||
// side doesn't send data and there's no feeding to the pipe.
|
|
||||||
cts.C.log.Write(cts.Sid, LOG_INFO, "Protocol switching failed for rpx(%d)", crpx.id)
|
|
||||||
goto done
|
|
||||||
}
|
|
||||||
|
|
||||||
// unlike with the normal request, the actual pipe is not read
|
|
||||||
// until the initial switching protocol response is received.
|
|
||||||
|
|
||||||
wg.Add(1)
|
|
||||||
done_chan = make(chan struct{}, 5)
|
|
||||||
go func() {
|
|
||||||
var buf [4096]byte
|
|
||||||
var n int
|
|
||||||
var err error
|
|
||||||
|
|
||||||
defer wg.Done()
|
|
||||||
for {
|
|
||||||
n, err = crpx.pr.Read(buf[:])
|
|
||||||
if n > 0 {
|
|
||||||
var err2 error
|
|
||||||
_, err2 = conn.Write(buf[:n])
|
|
||||||
if err2 != nil {
|
|
||||||
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to write websocket for rpx(%d) - %s", crpx.id, err2.Error())
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, io.EOF) { break }
|
|
||||||
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to read pipe for rpx(%d) - %s", crpx.id, err.Error())
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
done_chan <- struct{}{}
|
|
||||||
}()
|
|
||||||
|
|
||||||
for {
|
|
||||||
n, err = conn.Read(buf[:])
|
|
||||||
if n > 0 {
|
|
||||||
var err2 error
|
|
||||||
err2 = cts.psc.Send(MakeRpxDataPacket(crpx.id, buf[:n]))
|
|
||||||
if err2 != nil {
|
|
||||||
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to send rpx(%d) data to server - %s", crpx.id, err2.Error())
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, io.EOF) {
|
|
||||||
cts.psc.Send(MakeRpxEofPacket(crpx.id))
|
|
||||||
cts.C.log.Write(cts.Sid, LOG_DEBUG, "WebSocket rpx(%d) closed by server", crpx.id)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to read WebSocket rpx(%d) - %s", crpx.id, err.Error())
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// wait until the pipe reading(from the server side) goroutine is over
|
|
||||||
<-done_chan
|
|
||||||
} else {
|
} else {
|
||||||
var tr *http.Transport
|
// normal http
|
||||||
var resp *http.Response
|
err = cts.proxy_http(crpx, req)
|
||||||
|
}
|
||||||
tr = &http.Transport {
|
if err != nil {
|
||||||
DisableKeepAlives: true, // this implementation can't support keepalive..
|
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to proxy rpx(%d) - %s", crpx.id, err.Error())
|
||||||
}
|
goto done
|
||||||
if cts.C.rpx_target_tls != nil {
|
|
||||||
tr.TLSClientConfig = cts.C.rpx_target_tls
|
|
||||||
}
|
|
||||||
//fmt.Printf("%+v\n", req)
|
|
||||||
resp, err = tr.RoundTrip(req)
|
|
||||||
if err != nil {
|
|
||||||
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to send rpx(%d) request - %s", crpx.id, err.Error())
|
|
||||||
goto done
|
|
||||||
}
|
|
||||||
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
err = cts.psc.Send(MakeRpxStartPacket(crpx.id, get_http_resp_line_and_headers(resp)))
|
|
||||||
if err != nil {
|
|
||||||
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to send rpx(%d) status and headers to server - %s", crpx.id, err.Error())
|
|
||||||
goto done
|
|
||||||
}
|
|
||||||
|
|
||||||
for {
|
|
||||||
n, err = resp.Body.Read(buf[:])
|
|
||||||
//fmt.Printf ("READ RESPONSE [%s], %d, %v\n", string(buf[:n]), n, err)
|
|
||||||
if n > 0 {
|
|
||||||
var err2 error
|
|
||||||
err2 = cts.psc.Send(MakeRpxDataPacket(crpx.id, buf[:n]))
|
|
||||||
if err2 != nil {
|
|
||||||
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to send rpx(%d) data to server - %s", crpx.id, err2.Error())
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, io.EOF) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to read response body for rpx(%d) - %s", crpx.id, err.Error())
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
//fmt.Printf ("READ RESPONSE LOOP IS OVER\n")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
done:
|
done:
|
||||||
@ -1899,6 +1908,7 @@ done:
|
|||||||
cts.C.log.Write(cts.Sid, LOG_INFO, "Ending rpx(%d) read loop", crpx.id)
|
cts.C.log.Write(cts.Sid, LOG_INFO, "Ending rpx(%d) read loop", crpx.id)
|
||||||
|
|
||||||
crpx.ReqStop()
|
crpx.ReqStop()
|
||||||
|
|
||||||
cts.rpx_mtx.Lock()
|
cts.rpx_mtx.Lock()
|
||||||
delete(cts.rpx_map, crpx.id)
|
delete(cts.rpx_map, crpx.id)
|
||||||
cts.rpx_mtx.Unlock()
|
cts.rpx_mtx.Unlock()
|
||||||
|
Reference in New Issue
Block a user