Files
hodu/client-cts-rpx.go
hyung-hwan ae89fd6248 added a new client configuration item rxc-user
fixed multiplexer check order - POLLIN before POLLHUP.
2026-03-18 15:57:31 +09:00

400 lines
11 KiB
Go

package hodu
import "bufio"
import "bytes"
import "context"
import "crypto/tls"
import "errors"
import "fmt"
import "io"
import "net"
import "net/http"
import "sync"
import "strings"
import "time"
// rpx
// FindClientRpxById returns the rpx session registered under id, or nil when
// no session with that id exists. The map is read while holding cts.rpx_mtx.
func (cts *ClientConn) FindClientRpxById(id uint64) *ClientRpx {
	cts.rpx_mtx.Lock()
	defer cts.rpx_mtx.Unlock()
	// a missing key yields the nil zero value of *ClientRpx
	return cts.rpx_map[id]
}
// server_pipe_to_ws_target copies the data the server pushes into the rpx pipe
// (crpx.pr) onto the websocket target connection conn, in chunks of at most
// 4096 bytes. It returns when:
//   - writing to conn fails (logged),
//   - reading the pipe fails with anything other than io.EOF (logged), or
//   - the pipe reports io.EOF (not logged).
// wg.Done is called on return.
func (cts *ClientConn) server_pipe_to_ws_target(crpx *ClientRpx, conn net.Conn, wg *sync.WaitGroup) {
	var chunk [4096]byte

	defer wg.Done()

	for {
		var got int
		var rerr error
		var werr error

		got, rerr = crpx.pr.Read(chunk[:])
		if got > 0 {
			// forward whatever was read before looking at the read error
			_, werr = conn.Write(chunk[:got])
			if werr != nil {
				cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to write websocket for rpx(%d) - %s", crpx.id, werr.Error())
				return
			}
		}

		if rerr == nil {
			continue
		}
		if !errors.Is(rerr, io.EOF) {
			cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to read pipe for rpx(%d) - %s", crpx.id, rerr.Error())
		}
		return
	}
}
// proxy_ws relays a WebSocket upgrade request to the rpx target and then
// shuttles bytes in both directions until one side finishes.
//
// raw_req is the request line plus headers (ending with an empty line). It is
// written verbatim to the target. req is only used so that http.ReadResponse
// can interpret the target's reply. The target's status line and headers are
// forwarded to the server in an RPX start packet.
//
// After a 101 Switching Protocols reply:
//   - target -> server: read here and sent as RPX data packets. An EOF from
//     the target is reported with an RPX EOF packet.
//   - server -> target: copied from crpx.pr by server_pipe_to_ws_target.
//
// It returns the target's status code, or http.StatusInternalServerError when
// no response could be obtained. A non-101 status is reported as an error.
func (cts *ClientConn) proxy_ws(crpx *ClientRpx, raw_req []byte, req *http.Request) (int, error) {
	var l_wg sync.WaitGroup
	var conn net.Conn
	var resp *http.Response
	var r *bufio.Reader
	var buf [4096]byte
	var n int
	var err error

	if cts.C.rpx_target_tls != nil {
		var dialer *tls.Dialer
		dialer = &tls.Dialer{
			NetDialer: &net.Dialer{},
			Config: cts.C.rpx_target_tls,
		}
		conn, err = dialer.DialContext(crpx.ctx, "tcp", cts.C.rpx_target_addr) // TODO: no hard coding
	} else {
		var dialer *net.Dialer
		dialer = &net.Dialer{}
		conn, err = dialer.DialContext(crpx.ctx, "tcp", cts.C.rpx_target_addr) // TODO: no hard coding
	}
	if err != nil {
		return http.StatusInternalServerError, fmt.Errorf("failed to dial websocket for rpx(%d) - %s", crpx.id, err.Error())
	}
	defer conn.Close()

	// TODO: make this atomic?
	crpx.ws_conn = conn

	// write the raw request line and headers as sent by the server.
	// for the upgrade request, i assume no payload.
	_, err = conn.Write(raw_req)
	if err != nil {
		return http.StatusInternalServerError, fmt.Errorf("failed to write websocket request for rpx(%d) - %s", crpx.id, err.Error())
	}

	r = bufio.NewReader(conn)
	resp, err = http.ReadResponse(r, req)
	if err != nil {
		return http.StatusInternalServerError, fmt.Errorf("failed to read websocket response for rpx(%d) - %s", crpx.id, err.Error())
	}
	defer resp.Body.Close()

	err = cts.psc.Send(MakeRpxStartPacket(crpx.id, get_http_resp_line_and_headers(resp)))
	if err != nil {
		return http.StatusInternalServerError, fmt.Errorf("failed to send rpx(%d) WebSocket headers to server - %s", crpx.id, err.Error())
	}

	if resp.StatusCode != http.StatusSwitchingProtocols {
		// websocket upgrade failed. return without reading from the pipe.
		// the server side has the code to ensure no content-length. if the
		// upgrade fails, reading the pipe would be pending forever as the
		// server side doesn't send data and there's no feeding to the pipe.
		return resp.StatusCode, fmt.Errorf("protocol switching failed for rpx(%d)", crpx.id)
	}

	// unlike with the normal request, the actual pipe is not read
	// until the initial switching protocol response is received.
	l_wg.Add(1)
	go cts.server_pipe_to_ws_target(crpx, conn, &l_wg)

	for {
		// read through r, not conn: http.ReadResponse may already have pulled
		// bytes the target sent right after the 101 response (e.g. the first
		// websocket frame) into r's buffer. reading conn directly would drop
		// them. once the buffer is empty, r reads straight from conn.
		n, err = r.Read(buf[:])
		if n > 0 {
			var err2 error
			err2 = cts.psc.Send(MakeRpxDataPacket(crpx.id, buf[:n]))
			if err2 != nil {
				crpx.ReqStop() // to break server_pipe_to_ws_target. don't care about multiple stops
				return resp.StatusCode, fmt.Errorf("failed to send rpx(%d) data to server - %s", crpx.id, err2.Error())
			}
		}
		if err != nil {
			if errors.Is(err, io.EOF) {
				var err2 error
				err2 = cts.psc.Send(MakeRpxEofPacket(crpx.id))
				if err2 != nil {
					// the server won't learn about the eof, so waiting for the
					// pipe goroutine below may never end. stop like a data send failure.
					crpx.ReqStop() // to break server_pipe_to_ws_target. don't care about multiple stops
					return resp.StatusCode, fmt.Errorf("failed to send rpx(%d) eof to server - %s", crpx.id, err2.Error())
				}
				cts.C.log.Write(cts.Sid, LOG_DEBUG, "WebSocket rpx(%d) closed by server", crpx.id)
				break
			}
			crpx.ReqStop() // to break server_pipe_to_ws_target. don't care about multiple stops
			return resp.StatusCode, fmt.Errorf("failed to read WebSocket rpx(%d) - %s", crpx.id, err.Error())
		}
	}

	// wait until the pipe reading(from the server side) goroutine is over
	l_wg.Wait()
	return resp.StatusCode, nil
}
// proxy_http sends req to the rpx target as an ordinary HTTP request. It then
// streams the target's response back to the server:
//   - first the status line and headers, in an RPX start packet;
//   - then the body, in RPX data packets of at most 4096 bytes each.
//
// It returns the target's status code. When the round trip itself fails, it
// returns http.StatusInternalServerError together with an error.
func (cts *ClientConn) proxy_http(crpx *ClientRpx, req *http.Request) (int, error) {
	var transport *http.Transport
	var res *http.Response
	var chunk [4096]byte
	var err error

	// a fresh transport per request; this implementation can't support keepalive..
	transport = &http.Transport{DisableKeepAlives: true}
	if cts.C.rpx_target_tls != nil {
		transport.TLSClientConfig = cts.C.rpx_target_tls
	}

	res, err = transport.RoundTrip(req)
	if err != nil {
		return http.StatusInternalServerError, fmt.Errorf("failed to send rpx(%d) request - %s", crpx.id, err.Error())
	}
	defer res.Body.Close()

	err = cts.psc.Send(MakeRpxStartPacket(crpx.id, get_http_resp_line_and_headers(res)))
	if err != nil {
		return res.StatusCode, fmt.Errorf("failed to send rpx(%d) status and headers to server - %s", crpx.id, err.Error())
	}

	for {
		var got int
		var read_err error

		got, read_err = res.Body.Read(chunk[:])
		if got > 0 {
			err = cts.psc.Send(MakeRpxDataPacket(crpx.id, chunk[:got]))
			if err != nil {
				return res.StatusCode, fmt.Errorf("failed to send rpx(%d) data to server - %s", crpx.id, err.Error())
			}
		}

		switch {
		case read_err == nil:
			continue
		case errors.Is(read_err, io.EOF):
			return res.StatusCode, nil
		default:
			return res.StatusCode, fmt.Errorf("failed to read response body for rpx(%d) - %s", crpx.id, read_err.Error())
		}
	}
}
// RpxLoop serves one reverse-proxy (rpx) session.
//
// data is parsed as an HTTP request line followed by header lines, up to the
// first empty line; anything after that is ignored. The request body, if any,
// is streamed from crpx.pr, which WriteRpx/EofRpx feed. The request is replayed
// against the configured rpx target:
//   - as a WebSocket upgrade when Upgrade is "websocket" and Connection
//     contains "upgrade";
//   - as an ordinary HTTP request otherwise.
//
// Whatever the outcome, RpxLoop then:
//   - sends an RPX stop packet to the server;
//   - stops the session and removes it from cts.rpx_map;
//   - decrements the session counter and calls wg.Done.
func (cts *ClientConn) RpxLoop(crpx *ClientRpx, data []byte, wg *sync.WaitGroup) {
	var start_time time.Time
	var time_taken time.Duration
	var r *bufio.Reader
	var line string
	var flds []string
	var req_meth string
	var req_path string
	//var req_proto string
	var x_forwarded_host string
	var raw_req bytes.Buffer
	var status_code int
	var req *http.Request
	var err error

	defer wg.Done()

	cts.C.log.Write(cts.Sid, LOG_INFO, "Starting rpx(%d) loop", crpx.id)
	start_time = time.Now()

	const rpx_header_line_max = 65535 // TODO: make this configurable
	r = bufio.NewReader(bytes.NewReader(data))

	// request line: METHOD PATH PROTO
	line, err = read_line_limited(r, rpx_header_line_max)
	if err != nil && !errors.Is(err, io.EOF) {
		cts.C.log.Write(cts.Sid, LOG_ERROR, "failed to parse request for rpx(%d) - %s", crpx.id, err.Error())
		goto done
	}
	line = strings.TrimRight(line, "\r\n")
	flds = strings.Fields(line)
	if len(flds) < 3 {
		cts.C.log.Write(cts.Sid, LOG_ERROR, "Invalid request line for rpx(%d) - %s", crpx.id, line)
		goto done
	}

	// TODO: handle trailers...
	req_meth = flds[0]
	req_path = flds[1]
	//req_proto = flds[2]
	raw_req.WriteString(line)
	raw_req.WriteString("\r\n")

	// create a request assuming it's a normal http request
	req, err = http.NewRequestWithContext(crpx.ctx, req_meth, cts.C.rpx_target_url + req_path, crpx.pr)
	if err != nil {
		cts.C.log.Write(cts.Sid, LOG_ERROR, "failed to create request for rpx(%d) - %s", crpx.id, err.Error())
		goto done
	}

	// header lines up to the first empty line or the end of data
	for {
		line, err = read_line_limited(r, rpx_header_line_max)
		if err != nil && !errors.Is(err, io.EOF) {
			cts.C.log.Write(cts.Sid, LOG_ERROR, "failed to parse request for rpx(%d) - %s", crpx.id, err.Error())
			goto done
		}
		line = strings.TrimRight(line, "\r\n")
		if line == "" { break }

		flds = strings.SplitN(line, ":", 2)
		if len(flds) == 2 {
			var k string
			var v string
			k = strings.TrimSpace(flds[0])
			v = strings.TrimSpace(flds[1])
			req.Header.Add(k, v)
			if strings.EqualFold(k, "Host") {
				// a normal http client would set Host to be the target address.
				// the raw header is coming from the server. so it's different
				// from the host it's supposed to be. correct it to the right value.
				fmt.Fprintf(&raw_req, "%s: %s\r\n", k, req.Host)
			} else {
				raw_req.WriteString(line)
				raw_req.WriteString("\r\n")
				if strings.EqualFold(k, "X-Forwarded-Host") {
					x_forwarded_host = v
				}
			}
		}
		if errors.Is(err, io.EOF) { break }
	}
	raw_req.WriteString("\r\n")

	if x_forwarded_host == "" {
		x_forwarded_host = req.Host
	}

	if strings.EqualFold(req.Header.Get("Upgrade"), "websocket") && strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade") {
		// websocket
		status_code, err = cts.proxy_ws(crpx, raw_req.Bytes(), req)
	} else {
		// normal http.
		// net/http never writes a Content-Length taken from req.Header. The
		// body (crpx.pr) has an unknown size, so with ContentLength left at 0
		// net/http would send it chunked, or probe it first for GET-like
		// methods. Carry a valid forwarded length over so the target sees the
		// framing the original client used. Zero, invalid, or
		// Transfer-Encoding'd requests keep the previous behavior.
		if req.Header.Get("Transfer-Encoding") == "" {
			var cl int64
			var cl_ok bool
			cl, cl_ok = rpx_parse_content_length(req.Header.Get("Content-Length"))
			if cl_ok && cl > 0 {
				req.ContentLength = cl
			}
		}
		status_code, err = cts.proxy_http(crpx, req)
	}

	time_taken = time.Since(start_time)
	if err != nil {
		cts.C.log.Write(cts.Sid, LOG_ERROR, "rpx(%d) %s - %s %s %d %.9f - failed to proxy - %s", crpx.id, x_forwarded_host, req_meth, req_path, status_code, time_taken.Seconds(), err.Error())
		goto done
	} else {
		cts.C.log.Write(cts.Sid, LOG_INFO, "rpx(%d) %s - %s %s %d %.9f", crpx.id, x_forwarded_host, req_meth, req_path, status_code, time_taken.Seconds())
	}

done:
	err = cts.psc.Send(MakeRpxStopPacket(crpx.id))
	if err != nil {
		cts.C.log.Write(cts.Sid, LOG_ERROR, "rpx(%d) Failed to send %s to server - %s", crpx.id, PACKET_KIND_RPX_STOP.String(), err.Error())
	}
	cts.C.log.Write(cts.Sid, LOG_INFO, "Ending rpx(%d) loop", crpx.id)

	crpx.ReqStop()

	cts.rpx_mtx.Lock()
	delete(cts.rpx_map, crpx.id)
	cts.rpx_mtx.Unlock()
	cts.C.stats.rpx_sessions.Add(-1)
	cts.C.log.Write(cts.Sid, LOG_INFO, "Ended rpx(%d) loop", crpx.id)
}

// rpx_parse_content_length parses a Content-Length header value made of ASCII
// decimal digits only. It reports false for:
//   - an empty value;
//   - a signed or otherwise non-numeric value;
//   - a value longer than 18 digits, which could overflow int64.
func rpx_parse_content_length(v string) (int64, bool) {
	var n int64
	var i int
	if v == "" || len(v) > 18 { return 0, false }
	for i = 0; i < len(v); i++ {
		if v[i] < '0' || v[i] > '9' { return 0, false }
		n = n * 10 + int64(v[i] - '0')
	}
	return n, true
}
// StartRpx registers a new rpx session under id and runs RpxLoop for it in a
// goroutine tracked by wg. data is handed to RpxLoop unchanged. It returns an
// error if a session with the same id is already registered.
//
// The pipe and the cancellable context (derived from cts.C.Ctx) are created
// while cts.rpx_mtx is held and before the goroutine starts. This way WriteRpx
// can write to the pipe as soon as the session can be found in the map.
func (cts *ClientConn) StartRpx(id uint64, data []byte, wg *sync.WaitGroup) error {
	var session *ClientRpx
	var exists bool

	cts.rpx_mtx.Lock()
	if _, exists = cts.rpx_map[id]; exists {
		cts.rpx_mtx.Unlock()
		return fmt.Errorf("multiple start on rpx id %d", id)
	}
	session = &ClientRpx{id: id}
	session.pr, session.pw = io.Pipe()
	session.ctx, session.cancel = context.WithCancel(cts.C.Ctx)
	cts.rpx_map[id] = session
	cts.rpx_mtx.Unlock()

	cts.C.stats.rpx_sessions.Add(1)
	wg.Add(1)
	go cts.RpxLoop(session, data, wg)
	return nil
}
// StopRpx calls ReqStop on the rpx session identified by id. It does not
// itself wait for the session's RpxLoop to exit. It returns an error when no
// session with that id is registered.
func (cts *ClientConn) StopRpx(id uint64) error {
	var session *ClientRpx = cts.FindClientRpxById(id)
	if session != nil {
		session.ReqStop()
		return nil
	}
	return fmt.Errorf("unknown rpx id %d", id)
}
// WriteRpx writes data received from the server into the pipe of the rpx
// session identified by id. The reading end of that pipe is consumed either
// as the HTTP request body or by server_pipe_to_ws_target. Because the pipe
// comes from io.Pipe, the write blocks until the data is read or the pipe is
// closed. A write failure is logged and returned. An unknown id yields an
// error.
func (cts *ClientConn) WriteRpx(id uint64, data []byte) error {
	var session *ClientRpx
	var werr error

	if session = cts.FindClientRpxById(id); session == nil {
		return fmt.Errorf("unknown rpx id %d", id)
	}

	// TODO: may have to write it in a goroutine to avoid blocking?
	if _, werr = session.pw.Write(data); werr != nil {
		cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to write rpx(%d) data - %s", id, werr.Error())
		return werr
	}
	return nil
}
// EofRpx signals end of input for the rpx session identified by id. It closes
// only the write end of the session's pipe, so readers of the pipe then
// observe io.EOF. The read end is left open. data is not used. An unknown id
// yields an error.
func (cts *ClientConn) EofRpx(id uint64, data []byte) error {
	var session *ClientRpx = cts.FindClientRpxById(id)
	if session == nil {
		return fmt.Errorf("unknown rpx id %d", id)
	}
	session.pw.Close() // writing end only; the reading end stays untouched
	return nil
}
// HandleRpxEvent dispatches an rpx event from the server according to
// packet_type:
//   - START registers and launches a session tracked by cts.C.wg;
//   - STOP requests its termination;
//   - DATA feeds its pipe;
//   - EOF closes the pipe's write end.
// Any other packet kind is ignored and yields nil.
func (cts *ClientConn) HandleRpxEvent(packet_type PACKET_KIND, evt *RpxEvent) error {
	var err error

	switch packet_type {
	case PACKET_KIND_RPX_START:
		err = cts.StartRpx(evt.Id, evt.Data, &cts.C.wg)
	case PACKET_KIND_RPX_STOP:
		err = cts.StopRpx(evt.Id)
	case PACKET_KIND_RPX_DATA:
		err = cts.WriteRpx(evt.Id, evt.Data)
	case PACKET_KIND_RPX_EOF:
		err = cts.EofRpx(evt.Id, evt.Data)
	default:
		// ignore other packet types
	}

	return err
}