From 0696f4f5600f30c25eb06e83ff61ea7a7c7b13de Mon Sep 17 00:00:00 2001 From: hyung-hwan Date: Wed, 20 Aug 2025 02:04:06 +0900 Subject: [PATCH] updated code to add x-forwarded-host and x-forwarded-proto for rpx --- client.go | 18 ++++++++++++++---- hodu.go | 17 ++++++++++++++++- server-rpx.go | 1 - server.go | 20 +++++++++++--------- 4 files changed, 41 insertions(+), 15 deletions(-) diff --git a/client.go b/client.go index 0b1dec5..e576ae1 100644 --- a/client.go +++ b/client.go @@ -1705,7 +1705,6 @@ func (cts *ClientConn) RpxLoop(crpx *ClientRpx, data []byte, wg *sync.WaitGroup) //req_proto = flds[2] // create a request assuming it's a normal http request - req, err = http.NewRequestWithContext(crpx.ctx, req_meth, cts.C.rpx_target_url + req_path, crpx.pr) if err != nil { cts.C.log.Write(cts.Sid, LOG_ERROR, "failed to create request for rpx(%d) - %s", crpx.id, err.Error()) @@ -1717,7 +1716,12 @@ func (cts *ClientConn) RpxLoop(crpx *ClientRpx, data []byte, wg *sync.WaitGroup) if line == "" { break } flds = strings.SplitN(line, ":", 2) if len(flds) == 2 { - req.Header.Add(strings.TrimSpace(flds[0]), strings.TrimSpace(flds[1])) + var k string + var v string + k = strings.TrimSpace(flds[0]) + v = strings.TrimSpace(flds[1]) + req.Header.Add(k, v) +//fmt.Printf ("ADDING HEADER %s: %v\n", k, v) } } err = sc.Err() @@ -1845,9 +1849,12 @@ func (cts *ClientConn) RpxLoop(crpx *ClientRpx, data []byte, wg *sync.WaitGroup) var resp *http.Response tr = &http.Transport { - TLSClientConfig: cts.C.rpx_target_tls, + DisableKeepAlives: true, // this implementation can't support keepalive.. } - + if cts.C.rpx_target_tls != nil { + tr.TLSClientConfig = cts.C.rpx_target_tls + } +//fmt.Printf("%+v\n", req) resp, err = tr.RoundTrip(req) if err != nil { cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to send rpx(%d) request - %s", crpx.id, err.Error()) @@ -1864,6 +1871,7 @@ func (cts *ClientConn) RpxLoop(crpx *ClientRpx, data []byte, wg *sync.WaitGroup) for { n, err = resp.Body.Read(buf[:]) +//fmt.Printf ("READ RESPONSE [%s], %d, %v\n", string(buf[:n]), n, err) if n > 0 { var err2 error err2 = cts.psc.Send(MakeRpxDataPacket(crpx.id, buf[:n])) @@ -1880,6 +1888,7 @@ func (cts *ClientConn) RpxLoop(crpx *ClientRpx, data []byte, wg *sync.WaitGroup) break } } +//fmt.Printf ("READ RESPONSE LOOP IS OVER\n") } done: @@ -1946,6 +1955,7 @@ func (cts *ClientConn) WriteRpx(id uint64, data []byte) error { } // TODO: may have to write it in a goroutine to avoid blocking? +//fmt.Printf("UPLOADED DATA [%s]\n", string(data)) _, err = crpx.pw.Write(data) if err != nil { cts.C.log.Write(cts.Sid, LOG_ERROR, "Failed to write rpx(%d) data - %s", id, err.Error()) diff --git a/hodu.go b/hodu.go index 40d3096..af68dba 100644 --- a/hodu.go +++ b/hodu.go @@ -473,6 +473,8 @@ func get_http_req_line_and_headers(r *http.Request, force_host bool) []byte { var value string var values []string var host_found bool + var x_forwarded_host_found bool + var x_forwarded_proto_found bool fmt.Fprintf(&buf, "%s %s %s\r\n", r.Method, r.RequestURI, r.Proto) @@ -485,7 +487,12 @@ func get_http_req_line_and_headers(r *http.Request, force_host bool) []byte { continue } else if strings.EqualFold(name, "Host") { host_found = true + } else if strings.EqualFold(name, "X-Forwarded-Host") { + x_forwarded_host_found = true + } else if strings.EqualFold(name, "X-Forwarded-Proto") { + x_forwarded_proto_found = true } + for _, value = range values { fmt.Fprintf(&buf, "%s: %s\r\n", name, value) } @@ -494,7 +501,15 @@ func get_http_req_line_and_headers(r *http.Request, force_host bool) []byte { if force_host && !host_found && r.Host != "" { fmt.Fprintf(&buf, "Host: %s\r\n", r.Host) } -// TODO: host and x-forwarded-for, x-forwarded-proto, etc??? + if !x_forwarded_host_found && r.Host != "" { + fmt.Fprintf(&buf, "X-Forwarded-Host: %s\r\n", r.Host) + } + if !x_forwarded_proto_found && r.Host != "" { + var proto string + if r.TLS != nil { proto = "https" } else { proto = "http" } + fmt.Fprintf(&buf, "X-Forwarded-Proto: %s\r\n", proto) + } +// TODO: host and x-forwarded-for, etc??? buf.WriteString("\r\n") // End of headers return buf.Bytes() diff --git a/server-rpx.go b/server-rpx.go index 1ec569e..20bc3cc 100644 --- a/server-rpx.go +++ b/server-rpx.go @@ -283,7 +283,6 @@ func (rpx *server_rpx) ServeHTTP(w http.ResponseWriter, req *http.Request) (int, for { var n int - n, err = srpx.br.Read(buf[:]) if n > 0 { var err2 error diff --git a/server.go b/server.go index b41fabc..63397b0 100644 --- a/server.go +++ b/server.go @@ -293,9 +293,10 @@ func (rpty *ServerRpty) ReqStop() { rpty.ws.Close() } -func (rpx *ServerRpx) ReqStop() { +func (rpx *ServerRpx) ReqStop(close_web bool) { rpx.done_chan <- true rpx.pw.Close() + if close_web { rpx.br.Close() } } // ------------------------------------ @@ -894,7 +895,9 @@ func (cts *ServerConn) StartRpxWebById(srpx* ServerRpx, id uint64, data []byte) } func (cts *ServerConn) StopRpxWebById(srpx* ServerRpx, id uint64) error { - srpx.ReqStop() + cts.S.log.Write(cts.Sid, LOG_DEBUG, "Requesting to stop rpx(%d)", srpx.id) + srpx.ReqStop(true) + cts.S.log.Write(cts.Sid, LOG_DEBUG, "Requested to stop rpx(%d)", srpx.id) return nil } @@ -903,13 +906,13 @@ func (cts *ServerConn) WroteRpxWebById(srpx* ServerRpx, id uint64, data []byte) _, err = srpx.pw.Write(data) if err != nil { cts.S.log.Write(cts.Sid, LOG_ERROR, "Failed to write rpx data(%d) to rpx pipe - %s", id, err.Error()) - srpx.ReqStop() + srpx.ReqStop(true) } return err } func (cts *ServerConn) EofRpxWebById(srpx* ServerRpx, id uint64) error { - srpx.ReqStop() + srpx.ReqStop(false) return nil } @@ -1163,7 +1166,7 @@ func (cts *ServerConn) receive_from_stream(wg *sync.WaitGroup) { if err != nil { cts.S.log.Write(cts.Sid, LOG_ERROR, "Failed to handle %s event for rpty(%d) from %s - %s", pkt.Kind.String(), x.RptyEvt.Id, cts.RemoteAddr, err.Error()) } else { - cts.S.log.Write(cts.Sid, LOG_ERROR, "Handled %s event for rpty(%d) from %s", pkt.Kind.String(), x.RptyEvt.Id, cts.RemoteAddr) + cts.S.log.Write(cts.Sid, LOG_DEBUG, "Handled %s event for rpty(%d) from %s", pkt.Kind.String(), x.RptyEvt.Id, cts.RemoteAddr) } } else { cts.S.log.Write(cts.Sid, LOG_ERROR, "Invalid %s packet from %s", pkt.Kind.String(), cts.RemoteAddr) @@ -1184,7 +1187,7 @@ func (cts *ServerConn) receive_from_stream(wg *sync.WaitGroup) { if err != nil { cts.S.log.Write(cts.Sid, LOG_ERROR, "Failed to handle %s event for rpx(%d) from %s - %s", pkt.Kind.String(), x.RpxEvt.Id, cts.RemoteAddr, err.Error()) } else { - cts.S.log.Write(cts.Sid, LOG_ERROR, "Handled %s event for rpx(%d) from %s", pkt.Kind.String(), x.RpxEvt.Id, cts.RemoteAddr) + cts.S.log.Write(cts.Sid, LOG_DEBUG, "Handled %s event for rpx(%d) from %s", pkt.Kind.String(), x.RpxEvt.Id, cts.RemoteAddr) } } else { cts.S.log.Write(cts.Sid, LOG_ERROR, "Invalid %s packet from %s", pkt.Kind.String(), cts.RemoteAddr) @@ -1208,7 +1211,7 @@ done: if len(cts.rpx_map) > 0 { var rpx *ServerRpx for _, rpx = range cts.rpx_map { - rpx.ReqStop() + rpx.ReqStop(false) } } cts.rpx_mtx.Unlock() @@ -1293,10 +1296,9 @@ func (cts *ServerConn) ReqStop() { cts.rpty_mtx.Unlock() cts.rpx_mtx.Lock() - for _, srpx = range cts.rpx_map { srpx.ReqStop() } + for _, srpx = range cts.rpx_map { srpx.ReqStop(true) } cts.rpx_mtx.Unlock() - // there is no good way to break a specific connection client to // the grpc server. while the global grpc server is closed in // ReqStop() for Server, the individuation connection is closed