package middleware import "bufio" import "bytes" import "context" import "fmt" import "io" import "net" import "net/http" import "strings" import "codit/internal/db" import "codit/internal/models" const storeKey ctxKey = "store" const userHolderKey ctxKey = "user_holder" const principalHolderKey ctxKey = "principal_holder" const afterCommitKey ctxKey = "after_commit_holder" const authReasonHolderKey ctxKey = "auth_reason_holder" type userHolder struct { User models.User OK bool } type principalHolder struct { Principal models.ServicePrincipal OK bool } type afterCommitHolder struct { Fns []func() } type authReasonHolder struct { Reason string OK bool } type bufferedResponseWriter struct { header http.Header body bytes.Buffer status int wroteHeader bool } type passThroughStatusWriter struct { http.ResponseWriter status int wroteHeader bool } func WithRequestStore(store *db.Store, next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { var ctx context.Context var holder *userHolder var principal *principalHolder var afterHolder *afterCommitHolder var authHolder *authReasonHolder ctx = r.Context() if _, ok := StoreFromContext(ctx); !ok { ctx = context.WithValue(ctx, storeKey, store) } if _, ok := ctx.Value(userHolderKey).(*userHolder); !ok { holder = &userHolder{} ctx = context.WithValue(ctx, userHolderKey, holder) } if _, ok := ctx.Value(principalHolderKey).(*principalHolder); !ok { principal = &principalHolder{} ctx = context.WithValue(ctx, principalHolderKey, principal) } if _, ok := ctx.Value(afterCommitKey).(*afterCommitHolder); !ok { afterHolder = &afterCommitHolder{} ctx = context.WithValue(ctx, afterCommitKey, afterHolder) } if _, ok := ctx.Value(authReasonHolderKey).(*authReasonHolder); !ok { authHolder = &authReasonHolder{} ctx = context.WithValue(ctx, authReasonHolderKey, authHolder) } next.ServeHTTP(w, r.WithContext(ctx)) }) } func WithStore(r *http.Request, store *db.Store) *http.Request { var ctx context.Context if r == nil { return nil } ctx = context.WithValue(r.Context(), storeKey, store) return r.WithContext(ctx) } func StoreFromContext(ctx context.Context) (*db.Store, bool) { var store *db.Store var ok bool store, ok = ctx.Value(storeKey).(*db.Store) return store, ok } func WithStoreTransaction(store *db.Store, next http.Handler) http.Handler { return withStoreTransaction(store, next, true) } func WithStoreTransactionUnbuffered(store *db.Store, next http.Handler) http.Handler { return withStoreTransaction(store, next, false) } func withStoreTransaction(store *db.Store, next http.Handler, bufferWrites bool) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { var txStore *db.Store var recorder *bufferedResponseWriter var statusWriter *passThroughStatusWriter var err error if isWebSocketUpgrade(r) { next.ServeHTTP(w, r) return } txStore, err = store.BeginStore(r.Context()) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } r = WithStore(r, txStore) if isReadOnlyMethod(r.Method) { defer func() { if rec := recover(); rec != nil { _ = txStore.Rollback() panic(rec) } }() next.ServeHTTP(w, r) err = txStore.Commit() if err != nil { _ = txStore.Rollback() return } runAfterCommit(r.Context()) return } if !bufferWrites { statusWriter = newPassThroughStatusWriter(w) defer func() { if rec := recover(); rec != nil { _ = txStore.Rollback() panic(rec) } }() next.ServeHTTP(statusWriter, r) if statusWriter.status >= 400 { _ = txStore.Rollback() return } err = txStore.Commit() if err != nil { _ = txStore.Rollback() if !statusWriter.wroteHeader { http.Error(w, err.Error(), http.StatusInternalServerError) } return } runAfterCommit(r.Context()) return } recorder = newBufferedResponseWriter() defer func() { if rec := recover(); rec != nil { _ = txStore.Rollback() panic(rec) } }() next.ServeHTTP(recorder, r) if recorder.status >= 400 { _ = txStore.Rollback() recorder.FlushTo(w) return } err = txStore.Commit() if err != nil { _ = txStore.Rollback() http.Error(w, err.Error(), http.StatusInternalServerError) return } runAfterCommit(r.Context()) recorder.FlushTo(w) }) } func isWebSocketUpgrade(r *http.Request) bool { if r == nil { return false } return strings.Contains(strings.ToLower(strings.TrimSpace(r.Header.Get("Connection"))), "upgrade") && strings.EqualFold(strings.TrimSpace(r.Header.Get("Upgrade")), "websocket") } func RegisterAfterCommit(ctx context.Context, fn func()) { var holder *afterCommitHolder var ok bool if ctx == nil || fn == nil { return } holder, ok = ctx.Value(afterCommitKey).(*afterCommitHolder) if !ok || holder == nil { return } holder.Fns = append(holder.Fns, fn) } func runAfterCommit(ctx context.Context) { var holder *afterCommitHolder var ok bool var fns []func() var i int if ctx == nil { return } holder, ok = ctx.Value(afterCommitKey).(*afterCommitHolder) if !ok || holder == nil || len(holder.Fns) == 0 { return } fns = append(fns, holder.Fns...) holder.Fns = nil for i = 0; i < len(fns); i++ { if fns[i] != nil { fns[i]() } } } func rememberUser(ctx context.Context, user models.User) { var holder *userHolder var ok bool holder, ok = ctx.Value(userHolderKey).(*userHolder) if ok && holder != nil { holder.User = user holder.OK = true } } func userFromHolder(ctx context.Context) (models.User, bool) { var holder *userHolder var ok bool holder, ok = ctx.Value(userHolderKey).(*userHolder) if ok && holder != nil && holder.OK { return holder.User, true } return models.User{}, false } func rememberPrincipal(ctx context.Context, principal models.ServicePrincipal) { var holder *principalHolder var ok bool holder, ok = ctx.Value(principalHolderKey).(*principalHolder) if ok && holder != nil { holder.Principal = principal holder.OK = true } } func principalFromHolder(ctx context.Context) (models.ServicePrincipal, bool) { var holder *principalHolder var ok bool holder, ok = ctx.Value(principalHolderKey).(*principalHolder) if ok && holder != nil && holder.OK { return holder.Principal, true } return models.ServicePrincipal{}, false } func RememberAuthReason(ctx context.Context, reason string) { var holder *authReasonHolder var ok bool if ctx == nil { return } holder, ok = ctx.Value(authReasonHolderKey).(*authReasonHolder) if ok && holder != nil { holder.Reason = reason holder.OK = strings.TrimSpace(reason) != "" } } func AuthReasonFromContext(ctx context.Context) (string, bool) { var holder *authReasonHolder var ok bool if ctx == nil { return "", false } holder, ok = ctx.Value(authReasonHolderKey).(*authReasonHolder) if ok && holder != nil && holder.OK { return holder.Reason, true } return "", false } func isReadOnlyMethod(method string) bool { return method == http.MethodGet || method == http.MethodHead || method == http.MethodOptions } func newBufferedResponseWriter() *bufferedResponseWriter { var writer *bufferedResponseWriter writer = &bufferedResponseWriter{ header: make(http.Header), status: http.StatusOK, } return writer } func newPassThroughStatusWriter(w http.ResponseWriter) *passThroughStatusWriter { var writer *passThroughStatusWriter writer = &passThroughStatusWriter{ ResponseWriter: w, status: http.StatusOK, } return writer } func (w *bufferedResponseWriter) Header() http.Header { return w.header } func (w *bufferedResponseWriter) WriteHeader(status int) { if w.wroteHeader { return } w.status = status w.wroteHeader = true } func (w *bufferedResponseWriter) Write(data []byte) (int, error) { if !w.wroteHeader { w.WriteHeader(http.StatusOK) } return w.body.Write(data) } func (w *bufferedResponseWriter) Flush() { } func (w *passThroughStatusWriter) WriteHeader(status int) { if w.wroteHeader { return } w.status = status w.wroteHeader = true w.ResponseWriter.WriteHeader(status) } func (w *passThroughStatusWriter) Write(data []byte) (int, error) { if !w.wroteHeader { w.WriteHeader(http.StatusOK) } return w.ResponseWriter.Write(data) } func (w *passThroughStatusWriter) Flush() { var flusher http.Flusher var ok bool flusher, ok = w.ResponseWriter.(http.Flusher) if ok { flusher.Flush() } } func (w *passThroughStatusWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { var hijacker http.Hijacker var ok bool hijacker, ok = w.ResponseWriter.(http.Hijacker) if !ok { return nil, nil, fmt.Errorf("response writer does not support hijacking") } return hijacker.Hijack() } func (w *passThroughStatusWriter) Push(target string, opts *http.PushOptions) error { var pusher http.Pusher var ok bool pusher, ok = w.ResponseWriter.(http.Pusher) if !ok { return http.ErrNotSupported } return pusher.Push(target, opts) } func (w *passThroughStatusWriter) ReadFrom(r io.Reader) (int64, error) { var readerFrom io.ReaderFrom var ok bool if !w.wroteHeader { w.WriteHeader(http.StatusOK) } readerFrom, ok = w.ResponseWriter.(io.ReaderFrom) if ok { return readerFrom.ReadFrom(r) } return io.Copy(w.ResponseWriter, r) } func (w *bufferedResponseWriter) FlushTo(dst http.ResponseWriter) { var key string var values []string var i int for key, values = range w.header { for i = 0; i < len(values); i++ { dst.Header().Add(key, values[i]) } } dst.WriteHeader(w.status) _, _ = dst.Write(w.body.Bytes()) }