131 lines
3.0 KiB
Go
131 lines
3.0 KiB
Go
package middleware
|
|
|
|
import "context"
|
|
import "net/http"
|
|
import "strings"
|
|
import "time"
|
|
|
|
import "codit/internal/db"
|
|
import "codit/internal/models"
|
|
import "codit/internal/util"
|
|
|
|
// ctxKey is the unexported type used for this package's context keys. Because
// other packages cannot name the type, their context values can never collide
// with the keys below, even if they use the same string.
type ctxKey string

const (
	// userKey is the context key under which the request's models.User is stored.
	userKey ctxKey = "user"
	// principalKey is the context key under which a models.ServicePrincipal is stored.
	principalKey ctxKey = "principal"
)
|
|
|
|
func WithUser(store *db.Store, next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
var cookie *http.Cookie
|
|
var err error
|
|
var user models.User
|
|
var expires time.Time
|
|
var ctx context.Context
|
|
var token string
|
|
var hash string
|
|
cookie, err = r.Cookie("codit_session")
|
|
if err != nil || cookie.Value == "" {
|
|
token = apiKeyFromRequest(r)
|
|
if token == "" {
|
|
next.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
hash = util.HashToken(token)
|
|
user, err = store.GetUserByAPIKeyHash(hash)
|
|
if err != nil {
|
|
next.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
ctx = context.WithValue(r.Context(), userKey, user)
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
return
|
|
}
|
|
user, expires, err = store.GetSessionUser(cookie.Value)
|
|
if err != nil || time.Now().UTC().After(expires) {
|
|
token = apiKeyFromRequest(r)
|
|
if token == "" {
|
|
next.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
hash = util.HashToken(token)
|
|
user, err = store.GetUserByAPIKeyHash(hash)
|
|
if err != nil {
|
|
next.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
ctx = context.WithValue(r.Context(), userKey, user)
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
return
|
|
}
|
|
ctx = context.WithValue(r.Context(), userKey, user)
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
})
|
|
}
|
|
|
|
func UserFromContext(ctx context.Context) (models.User, bool) {
|
|
var user models.User
|
|
var ok bool
|
|
user, ok = ctx.Value(userKey).(models.User)
|
|
return user, ok
|
|
}
|
|
|
|
func WithPrincipal(r *http.Request, principal models.ServicePrincipal) *http.Request {
|
|
var ctx context.Context
|
|
ctx = context.WithValue(r.Context(), principalKey, principal)
|
|
return r.WithContext(ctx)
|
|
}
|
|
|
|
func PrincipalFromContext(ctx context.Context) (models.ServicePrincipal, bool) {
|
|
var principal models.ServicePrincipal
|
|
var ok bool
|
|
principal, ok = ctx.Value(principalKey).(models.ServicePrincipal)
|
|
return principal, ok
|
|
}
|
|
|
|
func RequireAuth(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
var ok bool
|
|
_, ok = UserFromContext(r.Context())
|
|
if !ok {
|
|
w.WriteHeader(http.StatusUnauthorized)
|
|
return
|
|
}
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
|
|
func RequireAdmin(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
var user models.User
|
|
var ok bool
|
|
user, ok = UserFromContext(r.Context())
|
|
if !ok || !user.IsAdmin {
|
|
w.WriteHeader(http.StatusForbidden)
|
|
return
|
|
}
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
|
|
// apiKeyFromRequest extracts an API key from r's headers.
//
// A non-empty X-API-Key header takes precedence and is returned as-is.
// Otherwise an "Authorization: Bearer <token>" header is accepted. The scheme
// is matched case-insensitively, and whitespace around the token is trimmed.
// It returns "" when neither header yields a token.
func apiKeyFromRequest(r *http.Request) string {
	if key := r.Header.Get("X-API-Key"); key != "" {
		return key
	}

	// An absent header gives [""], which fails the length check below.
	parts := strings.SplitN(r.Header.Get("Authorization"), " ", 2)
	if len(parts) != 2 || !strings.EqualFold(parts[0], "bearer") {
		return ""
	}

	// The Authorization syntax allows one or more spaces between the scheme
	// and the credentials. Trim them so "Bearer  tok" yields "tok" instead of
	// " tok", which could never match a stored key.
	return strings.TrimSpace(parts[1])
}
|