400 lines
11 KiB
Go
400 lines
11 KiB
Go
package hodu
|
|
|
|
import "bufio"
|
|
import "bytes"
|
|
import "context"
|
|
import "crypto/tls"
|
|
import "errors"
|
|
import "fmt"
|
|
import "io"
|
|
import "net"
|
|
import "net/http"
|
|
import "sync"
|
|
import "strings"
|
|
import "time"
|
|
|
|
// rpx
|
|
func (cts *ClientConn) FindClientRpxById(id uint64) *ClientRpx {
|
|
var crpx *ClientRpx
|
|
var ok bool
|
|
|
|
cts.rpx_mtx.Lock()
|
|
crpx, ok = cts.rpx_map[id]
|
|
cts.rpx_mtx.Unlock()
|
|
|
|
if !ok { crpx = nil }
|
|
return crpx
|
|
}
|
|
|
|
func (cts *ClientConn) server_pipe_to_ws_target(crpx* ClientRpx, conn net.Conn, wg *sync.WaitGroup) {
|
|
var buf [4096]byte
|
|
var n int
|
|
var err error
|
|
|
|
defer wg.Done()
|
|
|
|
for {
|
|
n, err = crpx.pr.Read(buf[:])
|
|
if n > 0 {
|
|
var err2 error
|
|
_, err2 = conn.Write(buf[:n])
|
|
if err2 != nil {
|
|
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to write websocket for rpx(%d) - %s", crpx.id, err2.Error())
|
|
break
|
|
}
|
|
}
|
|
if err != nil {
|
|
if errors.Is(err, io.EOF) { break }
|
|
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to read pipe for rpx(%d) - %s", crpx.id, err.Error())
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
func (cts *ClientConn) proxy_ws(crpx *ClientRpx, raw_req []byte, req *http.Request) (int, error) {
|
|
var l_wg sync.WaitGroup
|
|
var conn net.Conn
|
|
var resp *http.Response
|
|
var r *bufio.Reader
|
|
var buf [4096]byte
|
|
var n int
|
|
var err error
|
|
|
|
if cts.C.rpx_target_tls != nil {
|
|
var dialer *tls.Dialer
|
|
dialer = &tls.Dialer{
|
|
NetDialer: &net.Dialer{},
|
|
Config: cts.C.rpx_target_tls,
|
|
}
|
|
conn, err = dialer.DialContext(crpx.ctx, "tcp", cts.C.rpx_target_addr) // TODO: no hard coding
|
|
} else {
|
|
var dialer *net.Dialer
|
|
dialer = &net.Dialer{}
|
|
conn, err = dialer.DialContext(crpx.ctx, "tcp", cts.C.rpx_target_addr) // TODO: no hard coding
|
|
}
|
|
if err != nil {
|
|
return http.StatusInternalServerError, fmt.Errorf("failed to dial websocket for rpx(%d) - %s", crpx.id, err.Error())
|
|
}
|
|
defer conn.Close()
|
|
|
|
// TODO: make this atomic?
|
|
crpx.ws_conn = conn
|
|
|
|
// write the raw request line and headers as sent by the server.
|
|
// for the upgrade request, i assume no payload.
|
|
_, err = conn.Write(raw_req)
|
|
if err != nil {
|
|
return http.StatusInternalServerError, fmt.Errorf("failed to write websocket request for rpx(%d) - %s", crpx.id, err.Error())
|
|
}
|
|
|
|
r = bufio.NewReader(conn)
|
|
resp, err = http.ReadResponse(r, req)
|
|
if err != nil {
|
|
return http.StatusInternalServerError, fmt.Errorf("failed to write websocket response for rpx(%d) - %s", crpx.id, err.Error())
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
err = cts.psc.Send(MakeRpxStartPacket(crpx.id, get_http_resp_line_and_headers(resp)))
|
|
if err != nil {
|
|
return http.StatusInternalServerError, fmt.Errorf("failed to send rpx(%d) WebSocket headers to server - %s", crpx.id, err.Error())
|
|
}
|
|
|
|
if resp.StatusCode != http.StatusSwitchingProtocols {
|
|
// websock upgrade failed. let the code jump to the done
|
|
// label to skip reading from the pipe. the server side
|
|
// has the code to ensure no content-length. and the upgrade
|
|
// fails, the pipe below will be pending forever as the server
|
|
// side doesn't send data and there's no feeding to the pipe.
|
|
return resp.StatusCode, fmt.Errorf("protocol switching failed for rpx(%d)", crpx.id)
|
|
}
|
|
|
|
// unlike with the normal request, the actual pipe is not read
|
|
// until the initial switching protocol response is received.
|
|
|
|
l_wg.Add(1)
|
|
go cts.server_pipe_to_ws_target(crpx, conn, &l_wg)
|
|
|
|
for {
|
|
n, err = conn.Read(buf[:])
|
|
if n > 0 {
|
|
var err2 error
|
|
err2 = cts.psc.Send(MakeRpxDataPacket(crpx.id, buf[:n]))
|
|
if err2 != nil {
|
|
crpx.ReqStop() // to break server_pipe_ws_target. don't care about multiple stops
|
|
return resp.StatusCode, fmt.Errorf("failed to send rpx(%d) data to server - %s", crpx.id, err2.Error())
|
|
}
|
|
}
|
|
if err != nil {
|
|
if errors.Is(err, io.EOF) {
|
|
cts.psc.Send(MakeRpxEofPacket(crpx.id))
|
|
cts.C.log.Write(cts.Sid, LOG_DEBUG, "WebSocket rpx(%d) closed by server", crpx.id)
|
|
break
|
|
}
|
|
|
|
crpx.ReqStop() // to break server_pipe_ws_target. don't care about multiple stops
|
|
return resp.StatusCode, fmt.Errorf("failed to read WebSocket rpx(%d) - %s", crpx.id, err.Error())
|
|
}
|
|
}
|
|
|
|
// wait until the pipe reading(from the server side) goroutine is over
|
|
l_wg.Wait()
|
|
return resp.StatusCode, nil
|
|
}
|
|
|
|
func (cts *ClientConn) proxy_http(crpx *ClientRpx, req *http.Request) (int, error) {
|
|
var tr *http.Transport
|
|
var resp *http.Response
|
|
var buf [4096]byte
|
|
var n int
|
|
var err error
|
|
|
|
tr = &http.Transport {
|
|
DisableKeepAlives: true, // this implementation can't support keepalive..
|
|
}
|
|
if cts.C.rpx_target_tls != nil {
|
|
tr.TLSClientConfig = cts.C.rpx_target_tls
|
|
}
|
|
|
|
resp, err = tr.RoundTrip(req)
|
|
if err != nil {
|
|
return http.StatusInternalServerError, fmt.Errorf("failed to send rpx(%d) request - %s", crpx.id, err.Error())
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
err = cts.psc.Send(MakeRpxStartPacket(crpx.id, get_http_resp_line_and_headers(resp)))
|
|
if err != nil {
|
|
return resp.StatusCode, fmt.Errorf("failed to send rpx(%d) status and headers to server - %s", crpx.id, err.Error())
|
|
}
|
|
|
|
for {
|
|
n, err = resp.Body.Read(buf[:])
|
|
if n > 0 {
|
|
var err2 error
|
|
err2 = cts.psc.Send(MakeRpxDataPacket(crpx.id, buf[:n]))
|
|
if err2 != nil {
|
|
return resp.StatusCode, fmt.Errorf("failed to send rpx(%d) data to server - %s", crpx.id, err2.Error())
|
|
}
|
|
}
|
|
if err != nil {
|
|
if errors.Is(err, io.EOF) {
|
|
break
|
|
}
|
|
return resp.StatusCode, fmt.Errorf("failed to read response body for rpx(%d) - %s", crpx.id, err.Error())
|
|
}
|
|
}
|
|
|
|
return resp.StatusCode, nil
|
|
}
|
|
|
|
func (cts *ClientConn) RpxLoop(crpx *ClientRpx, data []byte, wg *sync.WaitGroup) {
|
|
var start_time time.Time
|
|
var time_taken time.Duration
|
|
var r *bufio.Reader
|
|
var line string
|
|
var flds []string
|
|
var req_meth string
|
|
var req_path string
|
|
//var req_proto string
|
|
var x_forwarded_host string
|
|
var raw_req bytes.Buffer
|
|
var status_code int
|
|
var req *http.Request
|
|
var err error
|
|
|
|
defer wg.Done()
|
|
|
|
cts.C.log.Write(cts.Sid, LOG_INFO, "Starting rpx(%d) loop", crpx.id)
|
|
|
|
start_time = time.Now()
|
|
|
|
const rpx_header_line_max = 65535 // TODO: make this configurable
|
|
|
|
r = bufio.NewReader(bytes.NewReader(data))
|
|
line, err = read_line_limited(r, rpx_header_line_max)
|
|
if err != nil && !errors.Is(err, io.EOF) {
|
|
cts.C.log.Write(cts.Sid, LOG_ERROR, "failed to parse request for rpx(%d) - %s", crpx.id, err.Error())
|
|
goto done
|
|
}
|
|
line = strings.TrimRight(line, "\r\n")
|
|
|
|
flds = strings.Fields(line)
|
|
if len(flds) < 3 {
|
|
cts.C.log.Write(cts.Sid, LOG_ERROR, "Invalid request line for rpx(%d) - %s", crpx.id, line)
|
|
goto done
|
|
}
|
|
|
|
// TODO: handle trailers...
|
|
req_meth = flds[0]
|
|
req_path = flds[1]
|
|
//req_proto = flds[2]
|
|
|
|
raw_req.WriteString(line)
|
|
raw_req.WriteString("\r\n")
|
|
|
|
// create a request assuming it's a normal http request
|
|
req, err = http.NewRequestWithContext(crpx.ctx, req_meth, cts.C.rpx_target_url + req_path, crpx.pr)
|
|
if err != nil {
|
|
cts.C.log.Write(cts.Sid, LOG_ERROR, "failed to create request for rpx(%d) - %s", crpx.id, err.Error())
|
|
goto done
|
|
}
|
|
|
|
for {
|
|
line, err = read_line_limited(r, rpx_header_line_max)
|
|
if err != nil && !errors.Is(err, io.EOF) {
|
|
cts.C.log.Write(cts.Sid, LOG_ERROR, "failed to parse request for rpx(%d) - %s", crpx.id, err.Error())
|
|
goto done
|
|
}
|
|
line = strings.TrimRight(line, "\r\n")
|
|
if line == "" { break }
|
|
flds = strings.SplitN(line, ":", 2)
|
|
if len(flds) == 2 {
|
|
var k string
|
|
var v string
|
|
k = strings.TrimSpace(flds[0])
|
|
v = strings.TrimSpace(flds[1])
|
|
req.Header.Add(k, v)
|
|
|
|
if strings.EqualFold(k, "Host") {
|
|
// a normal http client would set HOst to be the target address.
|
|
// the raw header is coming from the server. so it's different
|
|
// from the host it's supposed to be. correct it to the right value.
|
|
fmt.Fprintf(&raw_req, "%s: %s\r\n", k, req.Host)
|
|
} else {
|
|
raw_req.WriteString(line)
|
|
raw_req.WriteString("\r\n")
|
|
|
|
if strings.EqualFold(k, "X-Forwarded-Host") {
|
|
x_forwarded_host = v
|
|
}
|
|
}
|
|
}
|
|
if errors.Is(err, io.EOF) { break }
|
|
}
|
|
raw_req.WriteString("\r\n")
|
|
if x_forwarded_host == "" {
|
|
x_forwarded_host = req.Host
|
|
}
|
|
|
|
if strings.EqualFold(req.Header.Get("Upgrade"), "websocket") && strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade") {
|
|
// websocket
|
|
status_code, err = cts.proxy_ws(crpx, raw_req.Bytes(), req)
|
|
} else {
|
|
// normal http
|
|
status_code, err = cts.proxy_http(crpx, req)
|
|
}
|
|
|
|
time_taken = time.Since(start_time)
|
|
if err != nil {
|
|
cts.C.log.Write(cts.Sid, LOG_ERROR, "rpx(%d) %s - %s %s %d %.9f - failed to proxy - %s", crpx.id, x_forwarded_host, req_meth, req_path, status_code, time_taken.Seconds(), err.Error())
|
|
goto done
|
|
} else {
|
|
cts.C.log.Write(cts.Sid, LOG_INFO, "rpx(%d) %s - %s %s %d %.9f", crpx.id, x_forwarded_host, req_meth, req_path, status_code, time_taken.Seconds())
|
|
}
|
|
|
|
done:
|
|
err = cts.psc.Send(MakeRpxStopPacket(crpx.id))
|
|
if err != nil {
|
|
cts.C.log.Write(cts.Sid, LOG_ERROR, "rpx(%d) Failed to send %s to server - %s", crpx.id, PACKET_KIND_RPX_STOP.String(), err.Error())
|
|
}
|
|
cts.C.log.Write(cts.Sid, LOG_INFO, "Ending rpx(%d) loop", crpx.id)
|
|
|
|
crpx.ReqStop()
|
|
|
|
cts.rpx_mtx.Lock()
|
|
delete(cts.rpx_map, crpx.id)
|
|
cts.rpx_mtx.Unlock()
|
|
cts.C.stats.rpx_sessions.Add(-1)
|
|
cts.C.log.Write(cts.Sid, LOG_INFO, "Ended rpx(%d) loop", crpx.id)
|
|
}
|
|
|
|
func (cts *ClientConn) StartRpx(id uint64, data []byte, wg *sync.WaitGroup) error {
|
|
var crpx *ClientRpx
|
|
var ok bool
|
|
|
|
cts.rpx_mtx.Lock()
|
|
_, ok = cts.rpx_map[id]
|
|
if ok {
|
|
cts.rpx_mtx.Unlock()
|
|
return fmt.Errorf("multiple start on rpx id %d", id)
|
|
}
|
|
crpx = &ClientRpx{ id: id }
|
|
cts.rpx_map[id] = crpx
|
|
|
|
// i want the pipe to be created before the goroutine is started
|
|
// so that the WriteRpx() can write to the pipe. i protect pipe creation
|
|
// and context creation with a mutex
|
|
crpx.pr, crpx.pw = io.Pipe()
|
|
crpx.ctx, crpx.cancel = context.WithCancel(cts.C.Ctx)
|
|
|
|
cts.rpx_mtx.Unlock()
|
|
cts.C.stats.rpx_sessions.Add(1)
|
|
|
|
wg.Add(1)
|
|
go cts.RpxLoop(crpx, data, wg)
|
|
|
|
return nil
|
|
}
|
|
|
|
func (cts *ClientConn) StopRpx(id uint64) error {
|
|
var crpx *ClientRpx
|
|
|
|
crpx = cts.FindClientRpxById(id)
|
|
if crpx == nil {
|
|
return fmt.Errorf("unknown rpx id %d", id)
|
|
}
|
|
|
|
crpx.ReqStop()
|
|
return nil
|
|
}
|
|
|
|
func (cts *ClientConn) WriteRpx(id uint64, data []byte) error {
|
|
var crpx *ClientRpx
|
|
var err error
|
|
|
|
crpx = cts.FindClientRpxById(id)
|
|
if crpx == nil {
|
|
return fmt.Errorf("unknown rpx id %d", id)
|
|
}
|
|
|
|
// TODO: may have to write it in a goroutine to avoid blocking?
|
|
_, err = crpx.pw.Write(data)
|
|
if err != nil {
|
|
cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to write rpx(%d) data - %s", id, err.Error())
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (cts *ClientConn) EofRpx(id uint64, data []byte) error {
|
|
var crpx *ClientRpx
|
|
|
|
crpx = cts.FindClientRpxById(id)
|
|
if crpx == nil {
|
|
return fmt.Errorf("unknown rpx id %d", id)
|
|
}
|
|
|
|
// close the writing end only. leave the reading end untouched
|
|
crpx.pw.Close()
|
|
|
|
return nil
|
|
}
|
|
|
|
func (cts *ClientConn) HandleRpxEvent(packet_type PACKET_KIND, evt *RpxEvent) error {
|
|
switch packet_type {
|
|
case PACKET_KIND_RPX_START:
|
|
return cts.StartRpx(evt.Id, evt.Data, &cts.C.wg)
|
|
|
|
case PACKET_KIND_RPX_STOP:
|
|
return cts.StopRpx(evt.Id)
|
|
|
|
case PACKET_KIND_RPX_DATA:
|
|
return cts.WriteRpx(evt.Id, evt.Data)
|
|
|
|
case PACKET_KIND_RPX_EOF:
|
|
return cts.EofRpx(evt.Id, evt.Data)
|
|
}
|
|
|
|
// ignore other packet types
|
|
return nil
|
|
}
|