Files
codit/backend/internal/middleware/request_state.go

429 lines
9.5 KiB
Go

package middleware
import "bufio"
import "bytes"
import "context"
import "fmt"
import "io"
import "net"
import "net/http"
import "strings"
import "codit/internal/db"
import "codit/internal/models"
// Context keys under which the per-request state of this package is stored.
// ctxKey is the package's own key type (declared elsewhere in the package);
// because it is a distinct named type, these keys cannot collide with string
// keys set by other packages.
const (
	storeKey            ctxKey = "store"
	userHolderKey       ctxKey = "user_holder"
	principalHolderKey  ctxKey = "principal_holder"
	afterCommitKey      ctxKey = "after_commit_holder"
	authReasonHolderKey ctxKey = "auth_reason_holder"
)
// Mutable identity slots. WithRequestStore puts a pointer to a fresh holder
// into the request context; later code fills it in place (rememberUser,
// rememberPrincipal) instead of deriving a new context.
type (
	// userHolder carries the user remembered for this request. OK stays
	// false until a user has been stored.
	userHolder struct {
		User models.User
		OK   bool
	}

	// principalHolder carries the service principal remembered for this
	// request. OK stays false until a principal has been stored.
	principalHolder struct {
		Principal models.ServicePrincipal
		OK        bool
	}
)
type (
	// afterCommitHolder queues callbacks added with RegisterAfterCommit;
	// runAfterCommit drains the queue and invokes them.
	afterCommitHolder struct {
		Fns []func()
	}

	// authReasonHolder keeps the last reason given to RememberAuthReason.
	// OK is true only when that reason is not blank.
	authReasonHolder struct {
		Reason string
		OK     bool
	}
)
// bufferedResponseWriter is an http.ResponseWriter that holds the complete
// response (headers, status and body) in memory instead of sending it. The
// buffered transaction middleware hands it to the handler, decides whether to
// commit or roll back, and only then copies the response to the real writer
// with FlushTo.
type bufferedResponseWriter struct {
// header collects the headers the handler sets; FlushTo adds them to the
// destination writer's headers.
header http.Header
// body accumulates every byte passed to Write.
body bytes.Buffer
// status is the status from the first WriteHeader call (or the implicit
// http.StatusOK of the first Write). newBufferedResponseWriter starts it at
// http.StatusOK for handlers that write nothing at all.
status int
// wroteHeader is set by the first WriteHeader call so later calls are
// ignored and the first status is kept.
wroteHeader bool
}
// passThroughStatusWriter forwards every write straight to the wrapped
// http.ResponseWriter. It also remembers the status code the handler chose, so
// the unbuffered transaction middleware can decide between commit and rollback
// after the handler has already streamed its response.
type passThroughStatusWriter struct {
	http.ResponseWriter

	// status is the status from the first WriteHeader call (or the implicit
	// http.StatusOK of the first Write / ReadFrom). newPassThroughStatusWriter
	// starts it at http.StatusOK.
	status int

	// wroteHeader reports whether a status has already been sent downstream.
	wroteHeader bool
}

// Unwrap returns the wrapped http.ResponseWriter.
//
// http.ResponseController (Go 1.20+) follows Unwrap to reach capabilities
// this wrapper does not implement itself, such as SetReadDeadline and
// SetWriteDeadline. Without Unwrap those calls report http.ErrNotSupported,
// which is especially limiting for the long-running, streaming handlers the
// unbuffered middleware exists for.
func (w *passThroughStatusWriter) Unwrap() http.ResponseWriter {
	return w.ResponseWriter
}
// WithRequestStore prepares the request context for the rest of the chain. It
// attaches store and empty user, principal, after-commit and auth-reason
// holders. Values already present in the context are kept, so wrapping a
// handler twice leaves the outer values in place.
func WithRequestStore(store *db.Store, next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		ctx := r.Context()

		if _, present := StoreFromContext(ctx); !present {
			ctx = context.WithValue(ctx, storeKey, store)
		}
		if _, present := ctx.Value(userHolderKey).(*userHolder); !present {
			ctx = context.WithValue(ctx, userHolderKey, &userHolder{})
		}
		if _, present := ctx.Value(principalHolderKey).(*principalHolder); !present {
			ctx = context.WithValue(ctx, principalHolderKey, &principalHolder{})
		}
		if _, present := ctx.Value(afterCommitKey).(*afterCommitHolder); !present {
			ctx = context.WithValue(ctx, afterCommitKey, &afterCommitHolder{})
		}
		if _, present := ctx.Value(authReasonHolderKey).(*authReasonHolder); !present {
			ctx = context.WithValue(ctx, authReasonHolderKey, &authReasonHolder{})
		}

		next.ServeHTTP(w, r.WithContext(ctx))
	})
}
// WithStore returns a shallow copy of r whose context carries store. It
// shadows any store attached earlier. A nil request yields nil.
func WithStore(r *http.Request, store *db.Store) *http.Request {
	if r != nil {
		return r.WithContext(context.WithValue(r.Context(), storeKey, store))
	}
	return nil
}
// StoreFromContext returns the *db.Store attached to ctx (by WithRequestStore
// or WithStore) and reports whether one was found.
func StoreFromContext(ctx context.Context) (*db.Store, bool) {
	found, ok := ctx.Value(storeKey).(*db.Store)
	return found, ok
}
// WithStoreTransaction runs every request (except WebSocket upgrades) inside a
// transaction opened on store. For methods other than GET, HEAD and OPTIONS the
// handler's response is held in memory until the transaction is resolved. If
// the commit fails, the client therefore gets a 500 instead of the handler's
// response.
func WithStoreTransaction(store *db.Store, next http.Handler) http.Handler {
	const buffered = true
	return withStoreTransaction(store, next, buffered)
}

// WithStoreTransactionUnbuffered is like WithStoreTransaction, but the
// handler's response is streamed straight to the client. The commit happens
// after the response has been written, so a commit failure can only be
// reported when the handler wrote nothing.
func WithStoreTransactionUnbuffered(store *db.Store, next http.Handler) http.Handler {
	const buffered = false
	return withStoreTransaction(store, next, buffered)
}
// withStoreTransaction is the shared implementation of WithStoreTransaction
// and WithStoreTransactionUnbuffered. For each request it:
//
//   - passes WebSocket upgrade requests straight to next, without a transaction;
//   - otherwise begins a transaction with store.BeginStore and exposes it to
//     next through the request context (WithStore);
//   - for GET, HEAD and OPTIONS, lets next write directly to w and commits
//     afterwards;
//   - for other methods, commits only when next's status is below 400 and
//     rolls back otherwise. With bufferWrites the response is held in memory
//     until that decision is made, so a commit failure becomes a 500. Without
//     it the response is streamed, and a commit failure is reported only if
//     next wrote nothing.
//
// Callbacks queued with RegisterAfterCommit run only after a successful
// commit. If next panics, or ends its goroutine with runtime.Goexit, a deferred
// guard rolls the transaction back. The panic is not recovered, so it keeps
// propagating unchanged.
func withStoreTransaction(store *db.Store, next http.Handler, bufferWrites bool) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if isWebSocketUpgrade(r) {
			next.ServeHTTP(w, r)
			return
		}
		txStore, err := store.BeginStore(r.Context())
		if err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		r = WithStore(r, txStore)

		// settled becomes true once the transaction has been explicitly
		// committed or rolled back. The guard deliberately does not use
		// recover(): recover returns nil for runtime.Goexit (and, before Go
		// 1.21, for panic(nil)), so a recover-based guard would skip the
		// rollback and leak the transaction in exactly those cases.
		settled := false
		defer func() {
			if !settled {
				_ = txStore.Rollback()
			}
		}()
		rollback := func() {
			settled = true
			_ = txStore.Rollback()
		}
		commit := func() error {
			commitErr := txStore.Commit()
			settled = true
			if commitErr != nil {
				// Release the transaction even though Commit failed; a
				// Rollback error adds nothing at this point.
				_ = txStore.Rollback()
			}
			return commitErr
		}

		switch {
		case isReadOnlyMethod(r.Method):
			next.ServeHTTP(w, r)
			// The response is already on the wire, so a commit failure
			// cannot be reported to the client.
			if commit() != nil {
				return
			}
			runAfterCommit(r.Context())

		case !bufferWrites:
			statusWriter := newPassThroughStatusWriter(w)
			next.ServeHTTP(statusWriter, r)
			if statusWriter.status >= http.StatusBadRequest {
				rollback()
				return
			}
			if err = commit(); err != nil {
				if !statusWriter.wroteHeader {
					http.Error(w, err.Error(), http.StatusInternalServerError)
				}
				return
			}
			runAfterCommit(r.Context())

		default:
			recorder := newBufferedResponseWriter()
			next.ServeHTTP(recorder, r)
			if recorder.status >= http.StatusBadRequest {
				rollback()
				recorder.FlushTo(w)
				return
			}
			if err = commit(); err != nil {
				http.Error(w, err.Error(), http.StatusInternalServerError)
				return
			}
			runAfterCommit(r.Context())
			recorder.FlushTo(w)
		}
	})
}
// isWebSocketUpgrade reports whether r asks to switch the connection to the
// WebSocket protocol. The Connection header must carry the "upgrade" token and
// the Upgrade header must carry the "websocket" token.
//
// Both headers are comma-separated token lists and may be split across several
// header lines. Every value is therefore inspected token by token, ignoring
// ASCII case and surrounding whitespace. Header.Get alone would see only the
// first line, and substring matching would accept unrelated tokens that merely
// contain "upgrade".
func isWebSocketUpgrade(r *http.Request) bool {
	if r == nil {
		return false
	}
	hasToken := func(key, token string) bool {
		for _, value := range r.Header.Values(key) {
			for _, part := range strings.Split(value, ",") {
				if strings.EqualFold(strings.TrimSpace(part), token) {
					return true
				}
			}
		}
		return false
	}
	return hasToken("Connection", "upgrade") && hasToken("Upgrade", "websocket")
}
func RegisterAfterCommit(ctx context.Context, fn func()) {
var holder *afterCommitHolder
var ok bool
if ctx == nil || fn == nil {
return
}
holder, ok = ctx.Value(afterCommitKey).(*afterCommitHolder)
if !ok || holder == nil {
return
}
holder.Fns = append(holder.Fns, fn)
}
func runAfterCommit(ctx context.Context) {
var holder *afterCommitHolder
var ok bool
var fns []func()
var i int
if ctx == nil {
return
}
holder, ok = ctx.Value(afterCommitKey).(*afterCommitHolder)
if !ok || holder == nil || len(holder.Fns) == 0 {
return
}
fns = append(fns, holder.Fns...)
holder.Fns = nil
for i = 0; i < len(fns); i++ {
if fns[i] != nil {
fns[i]()
}
}
}
// rememberUser stores user in the request's userHolder, where userFromHolder
// can find it later. It does nothing when ctx carries no holder.
func rememberUser(ctx context.Context, user models.User) {
	slot, found := ctx.Value(userHolderKey).(*userHolder)
	if !found || slot == nil {
		return
	}
	slot.User, slot.OK = user, true
}
// userFromHolder returns the user saved by rememberUser. The boolean is false
// (and the user is the zero value) when no holder exists or nothing was saved.
func userFromHolder(ctx context.Context) (models.User, bool) {
	if slot, found := ctx.Value(userHolderKey).(*userHolder); found && slot != nil && slot.OK {
		return slot.User, true
	}
	var none models.User
	return none, false
}
// rememberPrincipal stores principal in the request's principalHolder, where
// principalFromHolder can find it later. It does nothing when ctx carries no
// holder.
func rememberPrincipal(ctx context.Context, principal models.ServicePrincipal) {
	slot, found := ctx.Value(principalHolderKey).(*principalHolder)
	if !found || slot == nil {
		return
	}
	slot.Principal, slot.OK = principal, true
}
// principalFromHolder returns the service principal saved by
// rememberPrincipal. The boolean is false (and the principal is the zero
// value) when no holder exists or nothing was saved.
func principalFromHolder(ctx context.Context) (models.ServicePrincipal, bool) {
	if slot, found := ctx.Value(principalHolderKey).(*principalHolder); found && slot != nil && slot.OK {
		return slot.Principal, true
	}
	var none models.ServicePrincipal
	return none, false
}
func RememberAuthReason(ctx context.Context, reason string) {
var holder *authReasonHolder
var ok bool
if ctx == nil {
return
}
holder, ok = ctx.Value(authReasonHolderKey).(*authReasonHolder)
if ok && holder != nil {
holder.Reason = reason
holder.OK = strings.TrimSpace(reason) != ""
}
}
func AuthReasonFromContext(ctx context.Context) (string, bool) {
var holder *authReasonHolder
var ok bool
if ctx == nil {
return "", false
}
holder, ok = ctx.Value(authReasonHolderKey).(*authReasonHolder)
if ok && holder != nil && holder.OK {
return holder.Reason, true
}
return "", false
}
// isReadOnlyMethod reports whether method is GET, HEAD or OPTIONS. The
// comparison is exact and case-sensitive. The transaction middleware neither
// buffers these requests nor rolls them back based on their status.
func isReadOnlyMethod(method string) bool {
	switch method {
	case http.MethodGet, http.MethodHead, http.MethodOptions:
		return true
	default:
		return false
	}
}
// newBufferedResponseWriter returns an empty in-memory response writer with a
// non-nil header map. Its status is http.StatusOK, so a handler that writes
// nothing still produces a 200.
func newBufferedResponseWriter() *bufferedResponseWriter {
	return &bufferedResponseWriter{status: http.StatusOK, header: http.Header{}}
}
// newPassThroughStatusWriter wraps w for status tracking. The recorded status
// starts at http.StatusOK, so a handler that writes nothing is treated as
// successful.
func newPassThroughStatusWriter(w http.ResponseWriter) *passThroughStatusWriter {
	return &passThroughStatusWriter{ResponseWriter: w, status: http.StatusOK}
}
// Header returns the in-memory header map the handler writes into.
func (w *bufferedResponseWriter) Header() http.Header { return w.header }

// WriteHeader records status. Only the first call counts; later calls are
// ignored.
func (w *bufferedResponseWriter) WriteHeader(status int) {
	if !w.wroteHeader {
		w.status, w.wroteHeader = status, true
	}
}

// Write appends data to the in-memory body. As with a real ResponseWriter,
// writing before WriteHeader implies http.StatusOK.
func (w *bufferedResponseWriter) Write(data []byte) (int, error) {
	w.WriteHeader(http.StatusOK) // no-op once a status has been recorded
	return w.body.Write(data)
}

// Flush does nothing: output is held until FlushTo is called. The method
// presumably exists so handlers that type-assert http.Flusher keep working.
func (w *bufferedResponseWriter) Flush() {}
// WriteHeader records status and forwards it to the wrapped writer. Only the
// first call is recorded and forwarded.
func (w *passThroughStatusWriter) WriteHeader(status int) {
	if !w.wroteHeader {
		w.wroteHeader = true
		w.status = status
		w.ResponseWriter.WriteHeader(status)
	}
}

// Write forwards data to the wrapped writer. If no status has been sent yet,
// http.StatusOK is sent and recorded first.
func (w *passThroughStatusWriter) Write(data []byte) (int, error) {
	w.WriteHeader(http.StatusOK) // no-op once a status has been sent
	return w.ResponseWriter.Write(data)
}
// Flush forwards to the wrapped writer when it implements http.Flusher and
// does nothing otherwise.
func (w *passThroughStatusWriter) Flush() {
	if f, ok := w.ResponseWriter.(http.Flusher); ok {
		f.Flush()
	}
}

// Hijack forwards to the wrapped writer's http.Hijacker. It returns an error
// when the wrapped writer cannot be hijacked.
func (w *passThroughStatusWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	if h, ok := w.ResponseWriter.(http.Hijacker); ok {
		return h.Hijack()
	}
	return nil, nil, fmt.Errorf("response writer does not support hijacking")
}

// Push forwards an HTTP/2 server push to the wrapped writer. It returns
// http.ErrNotSupported when the wrapped writer is not an http.Pusher.
func (w *passThroughStatusWriter) Push(target string, opts *http.PushOptions) error {
	if p, ok := w.ResponseWriter.(http.Pusher); ok {
		return p.Push(target, opts)
	}
	return http.ErrNotSupported
}

// ReadFrom copies r into the response. It first sends (and records)
// http.StatusOK if no status was sent yet. The wrapped writer's io.ReaderFrom
// is used when available; otherwise io.Copy is used.
func (w *passThroughStatusWriter) ReadFrom(r io.Reader) (int64, error) {
	w.WriteHeader(http.StatusOK) // no-op once a status has been sent
	if rf, ok := w.ResponseWriter.(io.ReaderFrom); ok {
		return rf.ReadFrom(r)
	}
	return io.Copy(w.ResponseWriter, r)
}
// FlushTo replays the buffered response onto dst, in this order:
//
//  1. every buffered header value is added to dst's headers (with Add, so
//     values already present in dst are kept);
//  2. the recorded status is written;
//  3. the body is written.
//
// The result of the final Write is discarded.
func (w *bufferedResponseWriter) FlushTo(dst http.ResponseWriter) {
	for name, vals := range w.header {
		for _, v := range vals {
			dst.Header().Add(name, v)
		}
	}
	dst.WriteHeader(w.status)
	_, _ = dst.Write(w.body.Bytes())
}