429 lines
9.5 KiB
Go
429 lines
9.5 KiB
Go
package middleware
|
|
|
|
import "bufio"
|
|
import "bytes"
|
|
import "context"
|
|
import "fmt"
|
|
import "io"
|
|
import "net"
|
|
import "net/http"
|
|
import "strings"
|
|
|
|
import "codit/internal/db"
|
|
import "codit/internal/models"
|
|
|
|
const storeKey ctxKey = "store"
|
|
const userHolderKey ctxKey = "user_holder"
|
|
const principalHolderKey ctxKey = "principal_holder"
|
|
const afterCommitKey ctxKey = "after_commit_holder"
|
|
const authReasonHolderKey ctxKey = "auth_reason_holder"
|
|
|
|
type userHolder struct {
|
|
User models.User
|
|
OK bool
|
|
}
|
|
|
|
type principalHolder struct {
|
|
Principal models.ServicePrincipal
|
|
OK bool
|
|
}
|
|
|
|
// afterCommitHolder is the per-request queue of callbacks that
// RegisterAfterCommit appends to and runAfterCommit drains.
type afterCommitHolder struct {
	Fns []func() // callbacks in registration order
}
|
|
|
|
// authReasonHolder is a mutable slot placed in the request context (under
// authReasonHolderKey) by WithRequestStore and written by
// RememberAuthReason.
type authReasonHolder struct {
	Reason string // most recently remembered reason, possibly blank
	OK     bool   // true when Reason is non-blank
}
|
|
|
|
// bufferedResponseWriter is an http.ResponseWriter that keeps the status,
// headers and body in memory instead of sending them. The transaction
// middleware uses it to decide whether to commit before anything reaches
// the client; FlushTo later replays the captured response onto a real
// writer.
type bufferedResponseWriter struct {
	header      http.Header  // headers set by the handler; copied out by FlushTo
	body        bytes.Buffer // accumulated response body
	status      int          // status to replay; newBufferedResponseWriter starts it at 200
	wroteHeader bool         // true once WriteHeader has recorded a status
}
|
|
|
|
// passThroughStatusWriter wraps an http.ResponseWriter and forwards every
// write immediately. It also remembers which status was sent, so the
// transaction middleware can roll back after a 4xx/5xx response.
type passThroughStatusWriter struct {
	http.ResponseWriter
	status      int  // recorded status; newPassThroughStatusWriter starts it at 200
	wroteHeader bool // true once WriteHeader has recorded a status
}
|
|
|
|
func WithRequestStore(store *db.Store, next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
var ctx context.Context
|
|
var holder *userHolder
|
|
var principal *principalHolder
|
|
var afterHolder *afterCommitHolder
|
|
var authHolder *authReasonHolder
|
|
|
|
ctx = r.Context()
|
|
if _, ok := StoreFromContext(ctx); !ok {
|
|
ctx = context.WithValue(ctx, storeKey, store)
|
|
}
|
|
if _, ok := ctx.Value(userHolderKey).(*userHolder); !ok {
|
|
holder = &userHolder{}
|
|
ctx = context.WithValue(ctx, userHolderKey, holder)
|
|
}
|
|
if _, ok := ctx.Value(principalHolderKey).(*principalHolder); !ok {
|
|
principal = &principalHolder{}
|
|
ctx = context.WithValue(ctx, principalHolderKey, principal)
|
|
}
|
|
if _, ok := ctx.Value(afterCommitKey).(*afterCommitHolder); !ok {
|
|
afterHolder = &afterCommitHolder{}
|
|
ctx = context.WithValue(ctx, afterCommitKey, afterHolder)
|
|
}
|
|
if _, ok := ctx.Value(authReasonHolderKey).(*authReasonHolder); !ok {
|
|
authHolder = &authReasonHolder{}
|
|
ctx = context.WithValue(ctx, authReasonHolderKey, authHolder)
|
|
}
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
})
|
|
}
|
|
|
|
func WithStore(r *http.Request, store *db.Store) *http.Request {
|
|
var ctx context.Context
|
|
|
|
if r == nil {
|
|
return nil
|
|
}
|
|
ctx = context.WithValue(r.Context(), storeKey, store)
|
|
return r.WithContext(ctx)
|
|
}
|
|
|
|
func StoreFromContext(ctx context.Context) (*db.Store, bool) {
|
|
var store *db.Store
|
|
var ok bool
|
|
|
|
store, ok = ctx.Value(storeKey).(*db.Store)
|
|
return store, ok
|
|
}
|
|
|
|
func WithStoreTransaction(store *db.Store, next http.Handler) http.Handler {
|
|
return withStoreTransaction(store, next, true)
|
|
}
|
|
|
|
func WithStoreTransactionUnbuffered(store *db.Store, next http.Handler) http.Handler {
|
|
return withStoreTransaction(store, next, false)
|
|
}
|
|
|
|
func withStoreTransaction(store *db.Store, next http.Handler, bufferWrites bool) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
var txStore *db.Store
|
|
var recorder *bufferedResponseWriter
|
|
var statusWriter *passThroughStatusWriter
|
|
var err error
|
|
|
|
if isWebSocketUpgrade(r) {
|
|
next.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
txStore, err = store.BeginStore(r.Context())
|
|
if err != nil {
|
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
|
return
|
|
}
|
|
r = WithStore(r, txStore)
|
|
if isReadOnlyMethod(r.Method) {
|
|
defer func() {
|
|
if rec := recover(); rec != nil {
|
|
_ = txStore.Rollback()
|
|
panic(rec)
|
|
}
|
|
}()
|
|
next.ServeHTTP(w, r)
|
|
err = txStore.Commit()
|
|
if err != nil {
|
|
_ = txStore.Rollback()
|
|
return
|
|
}
|
|
runAfterCommit(r.Context())
|
|
return
|
|
}
|
|
if !bufferWrites {
|
|
statusWriter = newPassThroughStatusWriter(w)
|
|
defer func() {
|
|
if rec := recover(); rec != nil {
|
|
_ = txStore.Rollback()
|
|
panic(rec)
|
|
}
|
|
}()
|
|
next.ServeHTTP(statusWriter, r)
|
|
if statusWriter.status >= 400 {
|
|
_ = txStore.Rollback()
|
|
return
|
|
}
|
|
err = txStore.Commit()
|
|
if err != nil {
|
|
_ = txStore.Rollback()
|
|
if !statusWriter.wroteHeader {
|
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
|
}
|
|
return
|
|
}
|
|
runAfterCommit(r.Context())
|
|
return
|
|
}
|
|
recorder = newBufferedResponseWriter()
|
|
defer func() {
|
|
if rec := recover(); rec != nil {
|
|
_ = txStore.Rollback()
|
|
panic(rec)
|
|
}
|
|
}()
|
|
next.ServeHTTP(recorder, r)
|
|
if recorder.status >= 400 {
|
|
_ = txStore.Rollback()
|
|
recorder.FlushTo(w)
|
|
return
|
|
}
|
|
err = txStore.Commit()
|
|
if err != nil {
|
|
_ = txStore.Rollback()
|
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
|
return
|
|
}
|
|
runAfterCommit(r.Context())
|
|
recorder.FlushTo(w)
|
|
})
|
|
}
|
|
|
|
// isWebSocketUpgrade reports whether r asks to switch to the WebSocket
// protocol. That requires a Connection header listing the "upgrade" token
// and an Upgrade header listing the "websocket" token, both compared
// case-insensitively.
//
// Both headers are comma-separated token lists, and each may be split
// across several header lines (RFC 9110 §7.6.1 and §7.8). Every line and
// every element is therefore checked: Header.Get would see only the first
// line. Whole tokens are compared so that values such as
// "Connection: x-upgrade-hint" do not match.
func isWebSocketUpgrade(r *http.Request) bool {
	if r == nil {
		return false
	}

	// hasToken reports whether any comma-separated element of any value of
	// header name equals token, ignoring case and surrounding whitespace.
	// name must be in canonical MIME form because the map is read directly.
	hasToken := func(name, token string) bool {
		for _, line := range r.Header[name] {
			for _, elem := range strings.Split(line, ",") {
				if strings.EqualFold(strings.TrimSpace(elem), token) {
					return true
				}
			}
		}
		return false
	}

	return hasToken("Connection", "upgrade") && hasToken("Upgrade", "websocket")
}
|
|
|
|
func RegisterAfterCommit(ctx context.Context, fn func()) {
|
|
var holder *afterCommitHolder
|
|
var ok bool
|
|
|
|
if ctx == nil || fn == nil {
|
|
return
|
|
}
|
|
holder, ok = ctx.Value(afterCommitKey).(*afterCommitHolder)
|
|
if !ok || holder == nil {
|
|
return
|
|
}
|
|
holder.Fns = append(holder.Fns, fn)
|
|
}
|
|
|
|
func runAfterCommit(ctx context.Context) {
|
|
var holder *afterCommitHolder
|
|
var ok bool
|
|
var fns []func()
|
|
var i int
|
|
|
|
if ctx == nil {
|
|
return
|
|
}
|
|
holder, ok = ctx.Value(afterCommitKey).(*afterCommitHolder)
|
|
if !ok || holder == nil || len(holder.Fns) == 0 {
|
|
return
|
|
}
|
|
fns = append(fns, holder.Fns...)
|
|
holder.Fns = nil
|
|
for i = 0; i < len(fns); i++ {
|
|
if fns[i] != nil {
|
|
fns[i]()
|
|
}
|
|
}
|
|
}
|
|
|
|
func rememberUser(ctx context.Context, user models.User) {
|
|
var holder *userHolder
|
|
var ok bool
|
|
|
|
holder, ok = ctx.Value(userHolderKey).(*userHolder)
|
|
if ok && holder != nil {
|
|
holder.User = user
|
|
holder.OK = true
|
|
}
|
|
}
|
|
|
|
func userFromHolder(ctx context.Context) (models.User, bool) {
|
|
var holder *userHolder
|
|
var ok bool
|
|
|
|
holder, ok = ctx.Value(userHolderKey).(*userHolder)
|
|
if ok && holder != nil && holder.OK {
|
|
return holder.User, true
|
|
}
|
|
return models.User{}, false
|
|
}
|
|
|
|
func rememberPrincipal(ctx context.Context, principal models.ServicePrincipal) {
|
|
var holder *principalHolder
|
|
var ok bool
|
|
|
|
holder, ok = ctx.Value(principalHolderKey).(*principalHolder)
|
|
if ok && holder != nil {
|
|
holder.Principal = principal
|
|
holder.OK = true
|
|
}
|
|
}
|
|
|
|
func principalFromHolder(ctx context.Context) (models.ServicePrincipal, bool) {
|
|
var holder *principalHolder
|
|
var ok bool
|
|
|
|
holder, ok = ctx.Value(principalHolderKey).(*principalHolder)
|
|
if ok && holder != nil && holder.OK {
|
|
return holder.Principal, true
|
|
}
|
|
return models.ServicePrincipal{}, false
|
|
}
|
|
|
|
func RememberAuthReason(ctx context.Context, reason string) {
|
|
var holder *authReasonHolder
|
|
var ok bool
|
|
|
|
if ctx == nil {
|
|
return
|
|
}
|
|
holder, ok = ctx.Value(authReasonHolderKey).(*authReasonHolder)
|
|
if ok && holder != nil {
|
|
holder.Reason = reason
|
|
holder.OK = strings.TrimSpace(reason) != ""
|
|
}
|
|
}
|
|
|
|
func AuthReasonFromContext(ctx context.Context) (string, bool) {
|
|
var holder *authReasonHolder
|
|
var ok bool
|
|
|
|
if ctx == nil {
|
|
return "", false
|
|
}
|
|
holder, ok = ctx.Value(authReasonHolderKey).(*authReasonHolder)
|
|
if ok && holder != nil && holder.OK {
|
|
return holder.Reason, true
|
|
}
|
|
return "", false
|
|
}
|
|
|
|
// isReadOnlyMethod reports whether method is exactly GET, HEAD or OPTIONS.
// These are the methods withStoreTransaction commits without inspecting
// the response status. The comparison is case-sensitive.
func isReadOnlyMethod(method string) bool {
	switch method {
	case http.MethodGet, http.MethodHead, http.MethodOptions:
		return true
	default:
		return false
	}
}
|
|
|
|
func newBufferedResponseWriter() *bufferedResponseWriter {
|
|
var writer *bufferedResponseWriter
|
|
|
|
writer = &bufferedResponseWriter{
|
|
header: make(http.Header),
|
|
status: http.StatusOK,
|
|
}
|
|
return writer
|
|
}
|
|
|
|
func newPassThroughStatusWriter(w http.ResponseWriter) *passThroughStatusWriter {
|
|
var writer *passThroughStatusWriter
|
|
|
|
writer = &passThroughStatusWriter{
|
|
ResponseWriter: w,
|
|
status: http.StatusOK,
|
|
}
|
|
return writer
|
|
}
|
|
|
|
func (w *bufferedResponseWriter) Header() http.Header {
|
|
return w.header
|
|
}
|
|
|
|
func (w *bufferedResponseWriter) WriteHeader(status int) {
|
|
if w.wroteHeader {
|
|
return
|
|
}
|
|
w.status = status
|
|
w.wroteHeader = true
|
|
}
|
|
|
|
func (w *bufferedResponseWriter) Write(data []byte) (int, error) {
|
|
if !w.wroteHeader {
|
|
w.WriteHeader(http.StatusOK)
|
|
}
|
|
return w.body.Write(data)
|
|
}
|
|
|
|
func (w *bufferedResponseWriter) Flush() {
|
|
}
|
|
|
|
func (w *passThroughStatusWriter) WriteHeader(status int) {
|
|
if w.wroteHeader {
|
|
return
|
|
}
|
|
w.status = status
|
|
w.wroteHeader = true
|
|
w.ResponseWriter.WriteHeader(status)
|
|
}
|
|
|
|
func (w *passThroughStatusWriter) Write(data []byte) (int, error) {
|
|
if !w.wroteHeader {
|
|
w.WriteHeader(http.StatusOK)
|
|
}
|
|
return w.ResponseWriter.Write(data)
|
|
}
|
|
|
|
func (w *passThroughStatusWriter) Flush() {
|
|
var flusher http.Flusher
|
|
var ok bool
|
|
|
|
flusher, ok = w.ResponseWriter.(http.Flusher)
|
|
if ok {
|
|
flusher.Flush()
|
|
}
|
|
}
|
|
|
|
func (w *passThroughStatusWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
|
var hijacker http.Hijacker
|
|
var ok bool
|
|
|
|
hijacker, ok = w.ResponseWriter.(http.Hijacker)
|
|
if !ok {
|
|
return nil, nil, fmt.Errorf("response writer does not support hijacking")
|
|
}
|
|
return hijacker.Hijack()
|
|
}
|
|
|
|
func (w *passThroughStatusWriter) Push(target string, opts *http.PushOptions) error {
|
|
var pusher http.Pusher
|
|
var ok bool
|
|
|
|
pusher, ok = w.ResponseWriter.(http.Pusher)
|
|
if !ok {
|
|
return http.ErrNotSupported
|
|
}
|
|
return pusher.Push(target, opts)
|
|
}
|
|
|
|
func (w *passThroughStatusWriter) ReadFrom(r io.Reader) (int64, error) {
|
|
var readerFrom io.ReaderFrom
|
|
var ok bool
|
|
|
|
if !w.wroteHeader {
|
|
w.WriteHeader(http.StatusOK)
|
|
}
|
|
readerFrom, ok = w.ResponseWriter.(io.ReaderFrom)
|
|
if ok {
|
|
return readerFrom.ReadFrom(r)
|
|
}
|
|
return io.Copy(w.ResponseWriter, r)
|
|
}
|
|
|
|
func (w *bufferedResponseWriter) FlushTo(dst http.ResponseWriter) {
|
|
var key string
|
|
var values []string
|
|
var i int
|
|
|
|
for key, values = range w.header {
|
|
for i = 0; i < len(values); i++ {
|
|
dst.Header().Add(key, values[i])
|
|
}
|
|
}
|
|
dst.WriteHeader(w.status)
|
|
_, _ = dst.Write(w.body.Bytes())
|
|
}
|