Files
hodu/server-rpx.go

332 lines
8.1 KiB
Go

package hodu
import "bufio"
import "bytes"
import "errors"
import "fmt"
import "io"
import "net"
import "net/http"
import "strconv"
import "strings"
import "sync"
type server_rpx struct {
S *Server
Id string
}
// ------------------------------------
func (rpx *server_rpx) Identity() string {
return rpx.Id
}
func (rpx *server_rpx) Cors(req *http.Request) bool {
return false
}
func (rpx *server_rpx) Authenticate(req *http.Request) (int, string) {
return http.StatusOK, ""
}
func (rpx *server_rpx) get_client_token(req *http.Request) string {
var val string
// TODO: enhance this client token extraction logic with some expression language?
val = req.Header.Get(rpx.S.Cfg.RpxClientTokenAttrName)
if val == "" { val = req.Host }
if rpx.S.Cfg.RpxClientTokenRegex != nil {
val = get_regex_submatch(rpx.S.Cfg.RpxClientTokenRegex, val, rpx.S.Cfg.RpxClientTokenSubmatchIndex)
}
return val
}
func (rpx* server_rpx) handle_header_data(rpx_id uint64, data []byte, w http.ResponseWriter) (int, error) {
var sc *bufio.Scanner
var line string
var flds []string
var status_code int
var err error
sc = bufio.NewScanner(bytes.NewReader(data))
sc.Scan()
line = sc.Text()
flds = strings.Fields(line)
if (len(flds) < 2) { // i care about the status code..
return http.StatusBadGateway, fmt.Errorf("invalid response status for rpx(%d) - %s", rpx_id, line)
}
status_code, err = strconv.Atoi(flds[1])
if err != nil {
return http.StatusBadGateway, fmt.Errorf("invalid response code for rpx(%d) - %s", rpx_id, err.Error())
}
for sc.Scan() {
line = sc.Text()
if line == "" { break }
flds = strings.SplitN(line, ":", 2)
if len(flds) == 2 {
w.Header().Add(strings.TrimSpace(flds[0]), strings.TrimSpace(flds[1]))
}
}
err = sc.Err()
if err != nil {
return http.StatusBadGateway, fmt.Errorf("failed to parse response for rpx(%d) - %s", rpx_id, err.Error())
}
w.WriteHeader(status_code)
return status_code, nil
}
func (rpx *server_rpx) handle_response(srpx *ServerRpx, req *http.Request, w http.ResponseWriter, ws_upgrade bool, wg *sync.WaitGroup) {
var start_resp []byte
var status_code int
var buf [4096]byte
var n int
var wr io.Writer
var wrote_br_chan bool
var err error
defer wg.Done()
select {
case start_resp = <- srpx.start_chan:
// received the header. ready to proceed to the body
// do nothing. just continue
status_code, err = rpx.handle_header_data(srpx.id, start_resp, w)
if err != nil { goto done }
case <- srpx.done_chan:
err = fmt.Errorf("rpx(%d) terminated before receiving header", srpx.id)
status_code = http.StatusBadGateway
goto done
case <- req.Context().Done():
err = fmt.Errorf("rpx(%d) terminated before receiving header - %s", srpx.id, req.Context().Err().Error())
status_code = http.StatusBadGateway
goto done
// no default. block
}
if ws_upgrade && status_code == http.StatusSwitchingProtocols {
var hijk http.Hijacker
var conn net.Conn
var ok bool
hijk, ok = w.(http.Hijacker)
if !ok {
err = fmt.Errorf("failed to upgrade rpx(%d) - not a hijacker", srpx.id)
status_code = http.StatusInternalServerError
goto done
}
conn, _, err = hijk.Hijack()
if err != nil {
err = fmt.Errorf("failed to upgrade rpx(%d) - %s", srpx.id, err.Error())
status_code = http.StatusInternalServerError
goto done
}
// websocket upgrade is successful
srpx.br = conn
srpx.br_chan <- true // inform another goroutine that the protocol switching is completed.
wrote_br_chan = true
wr = conn
} else {
if ws_upgrade {
srpx.br_chan <- false
wrote_br_chan = true
} // indicate upgrade failure
wr = w
}
for {
n, err = srpx.pr.Read(buf[:])
if n > 0 {
var err2 error
_, err2 = wr.Write(buf[:n])
if err2 != nil {
err = err2
status_code = http.StatusInternalServerError
break
}
}
if err != nil {
if errors.Is(err, io.EOF) {
err = nil
} else {
status_code = http.StatusInternalServerError
}
break
}
}
done:
// just send another in case the code got jump into this part for an error
// may not be consumed but the channel is large enough for redundant data
srpx.resp_status_code = status_code
srpx.resp_error = err
if ws_upgrade && !wrote_br_chan {
srpx.br_chan <- false
}
}
func (rpx *server_rpx) alloc_server_rpx(cts *ServerConn, req *http.Request) (*ServerRpx, error) {
var srpx *ServerRpx
var start_id uint64
var assigned_id uint64
var ok bool
cts.rpx_mtx.Lock()
start_id = cts.rpx_next_id
for {
_, ok = cts.rpx_map[cts.rpx_next_id]
if !ok {
assigned_id = cts.rpx_next_id
cts.rpx_next_id++
if cts.rpx_next_id == 0 { cts.rpx_next_id++ }
break
}
cts.rpx_next_id++
if cts.rpx_next_id == 0 { cts.rpx_next_id++ }
if cts.rpx_next_id == start_id {
// unlikely to happen but it cycled through the whole range.
cts.rpx_mtx.Unlock()
return nil, fmt.Errorf("failed to assign id")
}
}
srpx = &ServerRpx{
id: assigned_id,
start_chan: make(chan []byte, 5),
done_chan: make(chan bool, 5),
br_chan: make(chan bool, 5),
}
srpx.br = req.Body
srpx.pr, srpx.pw = io.Pipe()
cts.rpx_map[assigned_id] = srpx
cts.rpx_mtx.Unlock()
cts.S.stats.rpx_sessions.Add(1)
return srpx, nil
}
func (rpx *server_rpx) ServeHTTP(w http.ResponseWriter, req *http.Request) (int, error) {
var s *Server
var client_token string
var start_sent bool
var cts *ServerConn
var status_code int
var srpx *ServerRpx
var ws_upgrade bool
var buf [4096]byte
var wg sync.WaitGroup
var err error
s = rpx.S
client_token = rpx.get_client_token(req)
cts = s.FindServerConnByClientToken(client_token)
if cts == nil {
status_code = WriteEmptyRespHeader(w, http.StatusNotFound)
err = fmt.Errorf("unknown client token - %s", client_token)
goto oops
}
srpx, err = rpx.alloc_server_rpx(cts, req)
if err != nil {
status_code = WriteEmptyRespHeader(w, http.StatusServiceUnavailable)
err = fmt.Errorf("unable to allocate rpx - %s", err.Error())
goto oops
}
// arrange to clear the rpx_map entry when this function exits
defer func() {
cts.rpx_mtx.Lock()
delete(cts.rpx_map, srpx.id)
cts.rpx_mtx.Unlock()
cts.S.stats.rpx_sessions.Add(-1)
}()
ws_upgrade = strings.EqualFold(req.Header.Get("Upgrade"), "websocket") && strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade");
if ws_upgrade && req.ContentLength > 0 {
// while other webservers are ok with upgrade request with body payload,
// this program rejects such a request for impelementation limitation as
// it's not dealing with a raw byte but is using the standard web server handler.
status_code = WriteEmptyRespHeader(w, http.StatusBadRequest)
err = fmt.Errorf("failed to assign id")
goto oops
}
err = cts.pss.Send(MakeRpxStartPacket(srpx.id, get_http_req_line_and_headers(req, true)))
if err != nil {
status_code = WriteEmptyRespHeader(w, http.StatusBadGateway)
goto oops
}
start_sent = true
wg.Add(1)
go rpx.handle_response(srpx, req, w, ws_upgrade, &wg)
if ws_upgrade {
// wait until the protocol switching is done in rpx.handle_response()
var upgraded bool
upgraded = <- srpx.br_chan
if upgraded {
// arrange to close the hijacked connection inside rpx.handle_response()
defer srpx.br.Close()
}
}
for {
var n int
n, err = srpx.br.Read(buf[:])
if n > 0 {
var err2 error
err2 = cts.pss.Send(MakeRpxDataPacket(srpx.id, buf[:n]))
if err2 != nil {
status_code = WriteEmptyRespHeader(w, http.StatusBadGateway)
goto oops
}
}
if err != nil {
if errors.Is(err, io.EOF) {
err = cts.pss.Send(MakeRpxEofPacket(srpx.id))
if err != nil {
status_code = WriteEmptyRespHeader(w, http.StatusBadGateway)
goto oops
}
break
}
status_code = WriteEmptyRespHeader(w, http.StatusInternalServerError)
goto oops
}
}
wg.Wait()
if srpx.resp_error != nil {
status_code = WriteEmptyRespHeader(w, srpx.resp_status_code)
err = srpx.resp_error
goto oops
}
select {
case <- srpx.done_chan:
// anything to do?
case <- req.Context().Done():
// anything to do?
// no default. block
}
cts.pss.Send(MakeRpxStopPacket(srpx.id))
return srpx.resp_status_code, nil
oops:
if srpx != nil && start_sent { cts.pss.Send(MakeRpxStopPacket(srpx.id)) }
return status_code, err
}