added code for rpx handling
This commit is contained in:
318
server-rpx.go
318
server-rpx.go
@ -1,6 +1,15 @@
|
||||
package hodu
|
||||
|
||||
import "bufio"
|
||||
import "bytes"
|
||||
import "errors"
|
||||
import "fmt"
|
||||
import "io"
|
||||
import "net"
|
||||
import "net/http"
|
||||
import "strconv"
|
||||
import "strings"
|
||||
import "sync"
|
||||
|
||||
type server_rpx struct {
|
||||
S *Server
|
||||
@ -8,28 +17,313 @@ type server_rpx struct {
|
||||
}
|
||||
|
||||
// ------------------------------------
|
||||
func (pxy *server_rpx) Identity() string {
|
||||
return pxy.Id
|
||||
func (rpx *server_rpx) Identity() string {
|
||||
return rpx.Id
|
||||
}
|
||||
|
||||
func (pxy *server_rpx) Cors(req *http.Request) bool {
|
||||
func (rpx *server_rpx) Cors(req *http.Request) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (pxy *server_rpx) Authenticate(req *http.Request) (int, string) {
|
||||
func (rpx *server_rpx) Authenticate(req *http.Request) (int, string) {
|
||||
return http.StatusOK, ""
|
||||
}
|
||||
|
||||
func (pxy *server_rpx) ServeHTTP(w http.ResponseWriter, req *http.Request) (int, error) {
|
||||
var status_code int
|
||||
// var err error
|
||||
func (rpx *server_rpx) get_client_token(req *http.Request) string {
|
||||
var val string
|
||||
|
||||
status_code = http.StatusOK
|
||||
// TODO: enhance this client token extraction logic with some expression language?
|
||||
val = req.Header.Get(rpx.S.Cfg.RpxClientTokenAttrName)
|
||||
if val == "" { val = req.Host }
|
||||
|
||||
//done:
|
||||
return status_code, nil
|
||||
if rpx.S.Cfg.RpxClientTokenRegex != nil {
|
||||
val = get_regex_submatch(rpx.S.Cfg.RpxClientTokenRegex, val, rpx.S.Cfg.RpxClientTokenSubmatchIndex)
|
||||
}
|
||||
|
||||
//oops:
|
||||
// return status_code, err
|
||||
return val
|
||||
}
|
||||
|
||||
func (rpx* server_rpx) handle_header_data(rpx_id uint64, data []byte, w http.ResponseWriter) (int, error) {
|
||||
var sc *bufio.Scanner
|
||||
var line string
|
||||
var flds []string
|
||||
var status_code int
|
||||
var err error
|
||||
|
||||
sc = bufio.NewScanner(bytes.NewReader(data))
|
||||
sc.Scan()
|
||||
line = sc.Text()
|
||||
|
||||
flds = strings.Fields(line)
|
||||
if (len(flds) < 2) { // i care about the status code..
|
||||
return http.StatusBadGateway, fmt.Errorf("invalid response status for rpx(%d) - %s", rpx_id, line)
|
||||
}
|
||||
status_code, err = strconv.Atoi(flds[1])
|
||||
if err != nil {
|
||||
return http.StatusBadGateway, fmt.Errorf("invalid response code for rpx(%d) - %s", rpx_id, err.Error())
|
||||
}
|
||||
|
||||
for sc.Scan() {
|
||||
line = sc.Text()
|
||||
if line == "" { break }
|
||||
flds = strings.SplitN(line, ":", 2)
|
||||
if len(flds) == 2 {
|
||||
w.Header().Add(strings.TrimSpace(flds[0]), strings.TrimSpace(flds[1]))
|
||||
}
|
||||
}
|
||||
err = sc.Err()
|
||||
if err != nil {
|
||||
return http.StatusBadGateway, fmt.Errorf("failed to parse response for rpx(%d) - %s", rpx_id, err.Error())
|
||||
}
|
||||
|
||||
w.WriteHeader(status_code)
|
||||
return status_code, nil
|
||||
}
|
||||
|
||||
func (rpx *server_rpx) handle_response(srpx *ServerRpx, req *http.Request, w http.ResponseWriter, ws_upgrade bool, wg *sync.WaitGroup) {
|
||||
var start_resp []byte
|
||||
var status_code int
|
||||
var buf [4096]byte
|
||||
var n int
|
||||
var wr io.Writer
|
||||
var wrote_br_chan bool
|
||||
var err error
|
||||
|
||||
defer wg.Done()
|
||||
|
||||
select {
|
||||
case start_resp = <- srpx.start_chan:
|
||||
// received the header. ready to proceed to the body
|
||||
// do nothing. just continue
|
||||
status_code, err = rpx.handle_header_data(srpx.id, start_resp, w)
|
||||
if err != nil { goto done }
|
||||
|
||||
case <- srpx.done_chan:
|
||||
err = fmt.Errorf("rpx(%d) terminated before receiving header", srpx.id)
|
||||
status_code = http.StatusBadGateway
|
||||
goto done
|
||||
case <- req.Context().Done():
|
||||
err = fmt.Errorf("rpx(%d) terminated before receiving header - %s", srpx.id, req.Context().Err().Error())
|
||||
status_code = http.StatusBadGateway
|
||||
goto done
|
||||
|
||||
// no default. block
|
||||
}
|
||||
|
||||
if ws_upgrade && status_code == http.StatusSwitchingProtocols {
|
||||
var hijk http.Hijacker
|
||||
var conn net.Conn
|
||||
var ok bool
|
||||
|
||||
hijk, ok = w.(http.Hijacker)
|
||||
if !ok {
|
||||
err = fmt.Errorf("failed to upgrade rpx(%d) - not a hijacker", srpx.id)
|
||||
status_code = http.StatusInternalServerError
|
||||
goto done
|
||||
}
|
||||
|
||||
conn, _, err = hijk.Hijack()
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to upgrade rpx(%d) - %s", srpx.id, err.Error())
|
||||
status_code = http.StatusInternalServerError
|
||||
goto done
|
||||
}
|
||||
|
||||
// websocket upgrade is successful
|
||||
srpx.br = conn
|
||||
srpx.br_chan <- true // inform another goroutine that the protocol switching is completed.
|
||||
wrote_br_chan = true
|
||||
|
||||
wr = conn
|
||||
} else {
|
||||
if ws_upgrade {
|
||||
srpx.br_chan <- false
|
||||
wrote_br_chan = true
|
||||
} // indicate upgrade failure
|
||||
wr = w
|
||||
}
|
||||
|
||||
for {
|
||||
n, err = srpx.pr.Read(buf[:])
|
||||
if n > 0 {
|
||||
var err2 error
|
||||
_, err2 = wr.Write(buf[:n])
|
||||
if err2 != nil {
|
||||
err = err2
|
||||
status_code = http.StatusInternalServerError
|
||||
break
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
err = nil
|
||||
} else {
|
||||
status_code = http.StatusInternalServerError
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
done:
|
||||
// just send another in case the code got jump into this part for an error
|
||||
// may not be consumed but the channel is large enough for redundant data
|
||||
srpx.resp_status_code = status_code
|
||||
srpx.resp_error = err
|
||||
|
||||
if ws_upgrade && !wrote_br_chan {
|
||||
srpx.br_chan <- false
|
||||
}
|
||||
}
|
||||
|
||||
func (rpx *server_rpx) alloc_server_rpx(cts *ServerConn, req *http.Request) (*ServerRpx, error) {
|
||||
var srpx *ServerRpx
|
||||
var start_id uint64
|
||||
var assigned_id uint64
|
||||
var ok bool
|
||||
|
||||
cts.rpx_mtx.Lock()
|
||||
start_id = cts.rpx_next_id
|
||||
for {
|
||||
_, ok = cts.rpx_map[cts.rpx_next_id]
|
||||
if !ok {
|
||||
assigned_id = cts.rpx_next_id
|
||||
cts.rpx_next_id++
|
||||
if cts.rpx_next_id == 0 { cts.rpx_next_id++ }
|
||||
break
|
||||
}
|
||||
cts.rpx_next_id++
|
||||
if cts.rpx_next_id == 0 { cts.rpx_next_id++ }
|
||||
if cts.rpx_next_id == start_id {
|
||||
// unlikely to happen but it cycled through the whole range.
|
||||
cts.rpx_mtx.Unlock()
|
||||
return nil, fmt.Errorf("failed to assign id")
|
||||
}
|
||||
}
|
||||
|
||||
srpx = &ServerRpx{
|
||||
id: assigned_id,
|
||||
start_chan: make(chan []byte, 5),
|
||||
done_chan: make(chan bool, 5),
|
||||
br_chan: make(chan bool, 5),
|
||||
}
|
||||
srpx.br = req.Body
|
||||
srpx.pr, srpx.pw = io.Pipe()
|
||||
cts.rpx_map[assigned_id] = srpx
|
||||
|
||||
cts.rpx_mtx.Unlock()
|
||||
return srpx, nil
|
||||
}
|
||||
|
||||
func (rpx *server_rpx) ServeHTTP(w http.ResponseWriter, req *http.Request) (int, error) {
|
||||
var s *Server
|
||||
var client_token string
|
||||
var start_sent bool
|
||||
var cts *ServerConn
|
||||
var status_code int
|
||||
var srpx *ServerRpx
|
||||
var ws_upgrade bool
|
||||
var buf [4096]byte
|
||||
var wg sync.WaitGroup
|
||||
var err error
|
||||
|
||||
s = rpx.S
|
||||
client_token = rpx.get_client_token(req)
|
||||
cts = s.FindServerConnByClientToken(client_token)
|
||||
if cts == nil {
|
||||
status_code = WriteEmptyRespHeader(w, http.StatusNotFound)
|
||||
err = fmt.Errorf("unknown client token - %s", client_token)
|
||||
goto oops
|
||||
}
|
||||
|
||||
srpx, err = rpx.alloc_server_rpx(cts, req)
|
||||
if err != nil {
|
||||
status_code = WriteEmptyRespHeader(w, http.StatusServiceUnavailable)
|
||||
err = fmt.Errorf("unable to allocate rpx - %s", err.Error())
|
||||
goto oops
|
||||
}
|
||||
|
||||
// arrange to clear the rpx_map entry when this function exits
|
||||
defer func() {
|
||||
cts.rpx_mtx.Lock()
|
||||
delete(cts.rpx_map, srpx.id)
|
||||
cts.rpx_mtx.Unlock()
|
||||
}()
|
||||
|
||||
ws_upgrade = strings.EqualFold(req.Header.Get("Upgrade"), "websocket") && strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade");
|
||||
if ws_upgrade && req.ContentLength > 0 {
|
||||
// while other webservers are ok with upgrade request with body payload,
|
||||
// this program rejects such a request for impelementation limitation as
|
||||
// it's not dealing with a raw byte but is using the standard web server handler.
|
||||
status_code = WriteEmptyRespHeader(w, http.StatusBadRequest)
|
||||
err = fmt.Errorf("failed to assign id")
|
||||
goto oops
|
||||
}
|
||||
|
||||
err = cts.pss.Send(MakeRpxStartPacket(srpx.id, get_http_req_line_and_headers(req, true)))
|
||||
if err != nil {
|
||||
status_code = WriteEmptyRespHeader(w, http.StatusBadGateway)
|
||||
goto oops
|
||||
}
|
||||
start_sent = true
|
||||
|
||||
wg.Add(1)
|
||||
go rpx.handle_response(srpx, req, w, ws_upgrade, &wg)
|
||||
|
||||
if ws_upgrade {
|
||||
// wait until the protocol switching is done in rpx.handle_response()
|
||||
var upgraded bool
|
||||
upgraded = <- srpx.br_chan
|
||||
if upgraded {
|
||||
// arrange to close the hijacked connection inside rpx.handle_response()
|
||||
defer srpx.br.Close()
|
||||
}
|
||||
}
|
||||
|
||||
for {
|
||||
var n int
|
||||
|
||||
n, err = srpx.br.Read(buf[:])
|
||||
if n > 0 {
|
||||
var err2 error
|
||||
err2 = cts.pss.Send(MakeRpxDataPacket(srpx.id, buf[:n]))
|
||||
if err2 != nil {
|
||||
status_code = WriteEmptyRespHeader(w, http.StatusBadGateway)
|
||||
goto oops
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
err = cts.pss.Send(MakeRpxEofPacket(srpx.id))
|
||||
if err != nil {
|
||||
status_code = WriteEmptyRespHeader(w, http.StatusBadGateway)
|
||||
goto oops
|
||||
}
|
||||
break
|
||||
}
|
||||
status_code = WriteEmptyRespHeader(w, http.StatusInternalServerError)
|
||||
goto oops
|
||||
}
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
if srpx.resp_error != nil {
|
||||
status_code = WriteEmptyRespHeader(w, srpx.resp_status_code)
|
||||
err = srpx.resp_error
|
||||
goto oops
|
||||
}
|
||||
|
||||
select {
|
||||
case <- srpx.done_chan:
|
||||
// anything to do?
|
||||
case <- req.Context().Done():
|
||||
// anything to do?
|
||||
// no default. block
|
||||
}
|
||||
|
||||
cts.pss.Send(MakeRpxStopPacket(srpx.id))
|
||||
return srpx.resp_status_code, nil
|
||||
|
||||
oops:
|
||||
if srpx != nil && start_sent { cts.pss.Send(MakeRpxStopPacket(srpx.id)) }
|
||||
return status_code, err
|
||||
}
|
||||
|
Reference in New Issue
Block a user