435 lines
12 KiB
Go
435 lines
12 KiB
Go
package handlers
|
|
|
|
import "context"
|
|
import "crypto/rand"
|
|
import "crypto/sha1"
|
|
import "crypto/tls"
|
|
import "database/sql"
|
|
import "encoding/base64"
|
|
import "encoding/hex"
|
|
import "encoding/json"
|
|
import "errors"
|
|
import "fmt"
|
|
import "io"
|
|
import "net/http"
|
|
import "net/url"
|
|
import "strings"
|
|
import "time"
|
|
|
|
import "codit/internal/config"
|
|
import "codit/internal/models"
|
|
import "codit/internal/util"
|
|
|
|
// Wire types for JSON documents exchanged with the OIDC provider.
type (
	// oidcTokenResponse holds the token-endpoint fields this package reads.
	// Either token may be absent depending on the provider and requested scopes.
	oidcTokenResponse struct {
		AccessToken string `json:"access_token"`
		IDToken     string `json:"id_token"`
	}

	// oidcUserClaims holds the identity claims decoded from either the
	// userinfo response or the ID token payload.
	oidcUserClaims struct {
		Sub               string `json:"sub"`
		PreferredUsername string `json:"preferred_username"`
		Name              string `json:"name"`
		Email             string `json:"email"`
	}

	// oidcErrorResponse is the "error" / "error_description" pair that a
	// provider may return in a failed token-endpoint response body.
	oidcErrorResponse struct {
		Error            string `json:"error"`
		ErrorDescription string `json:"error_description"`
	}
)
|
|
|
|
func (api *API) OIDCEnabled(w http.ResponseWriter, _ *http.Request, _ map[string]string) {
|
|
var settings models.AuthSettings
|
|
var err error
|
|
var configured bool
|
|
settings, err = api.getMergedAuthSettings()
|
|
if err != nil {
|
|
WriteJSON(w, http.StatusOK, map[string]any{"enabled": false, "configured": false, "auth_mode": "db"})
|
|
return
|
|
}
|
|
configured = api.oidcConfiguredFromSettings(settings)
|
|
WriteJSON(w, http.StatusOK, map[string]any{
|
|
"enabled": settings.OIDCEnabled,
|
|
"configured": configured,
|
|
"auth_mode": strings.ToLower(strings.TrimSpace(settings.AuthMode)),
|
|
})
|
|
}
|
|
|
|
func (api *API) OIDCLogin(w http.ResponseWriter, r *http.Request, _ map[string]string) {
|
|
var settings models.AuthSettings
|
|
var cfg config.Config
|
|
var state string
|
|
var err error
|
|
var authURL string
|
|
settings, err = api.getMergedAuthSettings()
|
|
if err != nil {
|
|
WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": "failed to load auth settings"})
|
|
return
|
|
}
|
|
cfg, err = api.effectiveAuthConfig()
|
|
if err != nil {
|
|
WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": "failed to load auth settings"})
|
|
return
|
|
}
|
|
if !settings.OIDCEnabled {
|
|
WriteJSON(w, http.StatusServiceUnavailable, map[string]string{"error": "oidc login is disabled"})
|
|
return
|
|
}
|
|
if !api.oidcConfiguredFromSettings(settings) {
|
|
WriteJSON(w, http.StatusServiceUnavailable, map[string]string{"error": "oidc is not configured"})
|
|
return
|
|
}
|
|
state, err = newOIDCState()
|
|
if err != nil {
|
|
WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": "failed to create state"})
|
|
return
|
|
}
|
|
http.SetCookie(w, &http.Cookie{
|
|
Name: "codit_oidc_state",
|
|
Value: state,
|
|
HttpOnly: true,
|
|
Path: "/api/auth/oidc",
|
|
Expires: time.Now().UTC().Add(10 * time.Minute),
|
|
SameSite: http.SameSiteLaxMode,
|
|
})
|
|
authURL = api.buildOIDCAuthorizeURL(cfg, state)
|
|
if api.Logger != nil {
|
|
api.Logger.Write("auth", util.LOG_INFO, "oidc login redirect authorize_url=%s", cfg.OIDCAuthorizeURL)
|
|
}
|
|
http.Redirect(w, r, authURL, http.StatusFound)
|
|
}
|
|
|
|
func (api *API) OIDCCallback(w http.ResponseWriter, r *http.Request, _ map[string]string) {
|
|
var settings models.AuthSettings
|
|
var cfg config.Config
|
|
var state string
|
|
var code string
|
|
var cookie *http.Cookie
|
|
var err error
|
|
var token oidcTokenResponse
|
|
var claims oidcUserClaims
|
|
var user models.User
|
|
settings, err = api.getMergedAuthSettings()
|
|
if err != nil {
|
|
WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": "failed to load auth settings"})
|
|
return
|
|
}
|
|
cfg, err = api.effectiveAuthConfig()
|
|
if err != nil {
|
|
WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": "failed to load auth settings"})
|
|
return
|
|
}
|
|
if !settings.OIDCEnabled {
|
|
WriteJSON(w, http.StatusServiceUnavailable, map[string]string{"error": "oidc login is disabled"})
|
|
return
|
|
}
|
|
if !api.oidcConfiguredFromSettings(settings) {
|
|
WriteJSON(w, http.StatusServiceUnavailable, map[string]string{"error": "oidc is not configured"})
|
|
return
|
|
}
|
|
state = strings.TrimSpace(r.URL.Query().Get("state"))
|
|
code = strings.TrimSpace(r.URL.Query().Get("code"))
|
|
cookie, err = r.Cookie("codit_oidc_state")
|
|
clearOIDCStateCookie(w)
|
|
if err != nil || cookie.Value == "" || state == "" || state != cookie.Value {
|
|
WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid oidc state"})
|
|
return
|
|
}
|
|
if code == "" {
|
|
WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "missing authorization code"})
|
|
return
|
|
}
|
|
token, err = api.oidcExchangeCode(r.Context(), cfg, code)
|
|
if err != nil {
|
|
if api.Logger != nil {
|
|
api.Logger.Write("auth", util.LOG_WARN, "oidc token exchange failed err=%v", err)
|
|
}
|
|
WriteJSON(w, http.StatusUnauthorized, map[string]string{"error": err.Error()})
|
|
return
|
|
}
|
|
claims, err = api.oidcResolveClaims(r.Context(), cfg, token)
|
|
if err != nil {
|
|
if api.Logger != nil {
|
|
api.Logger.Write("auth", util.LOG_WARN, "oidc claims fetch failed err=%v", err)
|
|
}
|
|
WriteJSON(w, http.StatusUnauthorized, map[string]string{"error": "oidc claims fetch failed"})
|
|
return
|
|
}
|
|
user, err = api.oidcGetOrCreateUser(claims)
|
|
if err != nil {
|
|
if api.Logger != nil {
|
|
api.Logger.Write("auth", util.LOG_WARN, "oidc user mapping failed err=%v", err)
|
|
}
|
|
WriteJSON(w, http.StatusUnauthorized, map[string]string{"error": "oidc user mapping failed"})
|
|
return
|
|
}
|
|
api.issueSession(w, user)
|
|
if api.Logger != nil {
|
|
api.Logger.Write("auth", util.LOG_INFO, "oidc login success username=%s", user.Username)
|
|
}
|
|
http.Redirect(w, r, "/", http.StatusFound)
|
|
}
|
|
|
|
func (api *API) oidcConfiguredFromSettings(settings models.AuthSettings) bool {
|
|
if strings.TrimSpace(settings.OIDCClientID) == "" {
|
|
return false
|
|
}
|
|
if strings.TrimSpace(settings.OIDCClientSecret) == "" {
|
|
return false
|
|
}
|
|
if strings.TrimSpace(settings.OIDCAuthorizeURL) == "" {
|
|
return false
|
|
}
|
|
if strings.TrimSpace(settings.OIDCTokenURL) == "" {
|
|
return false
|
|
}
|
|
if strings.TrimSpace(settings.OIDCRedirectURL) == "" {
|
|
return false
|
|
}
|
|
return true
|
|
}
|
|
|
|
func (api *API) buildOIDCAuthorizeURL(cfg config.Config, state string) string {
|
|
var values url.Values
|
|
var scopes string
|
|
var endpoint string
|
|
values = url.Values{}
|
|
values.Set("response_type", "code")
|
|
values.Set("client_id", cfg.OIDCClientID)
|
|
values.Set("redirect_uri", cfg.OIDCRedirectURL)
|
|
scopes = strings.TrimSpace(cfg.OIDCScopes)
|
|
if scopes == "" {
|
|
scopes = "openid profile email"
|
|
}
|
|
values.Set("scope", scopes)
|
|
values.Set("state", state)
|
|
endpoint = cfg.OIDCAuthorizeURL
|
|
if strings.Contains(endpoint, "?") {
|
|
return endpoint + "&" + values.Encode()
|
|
}
|
|
return endpoint + "?" + values.Encode()
|
|
}
|
|
|
|
func (api *API) oidcHTTPClient(cfg config.Config) *http.Client {
|
|
var transport *http.Transport
|
|
transport = &http.Transport{
|
|
TLSClientConfig: &tls.Config{InsecureSkipVerify: cfg.OIDCTLSInsecureSkipVerify},
|
|
}
|
|
return &http.Client{Transport: transport, Timeout: 15 * time.Second}
|
|
}
|
|
|
|
func (api *API) oidcExchangeCode(ctx context.Context, cfg config.Config, code string) (oidcTokenResponse, error) {
|
|
var client *http.Client
|
|
var form url.Values
|
|
var req *http.Request
|
|
var res *http.Response
|
|
var body []byte
|
|
var token oidcTokenResponse
|
|
var err error
|
|
client = api.oidcHTTPClient(cfg)
|
|
form = url.Values{}
|
|
form.Set("grant_type", "authorization_code")
|
|
form.Set("code", code)
|
|
form.Set("redirect_uri", cfg.OIDCRedirectURL)
|
|
form.Set("client_id", cfg.OIDCClientID)
|
|
form.Set("client_secret", cfg.OIDCClientSecret)
|
|
req, err = http.NewRequestWithContext(ctx, http.MethodPost, cfg.OIDCTokenURL, strings.NewReader(form.Encode()))
|
|
if err != nil {
|
|
return token, err
|
|
}
|
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
req.Header.Set("Accept", "application/json")
|
|
res, err = client.Do(req)
|
|
if err != nil {
|
|
return token, err
|
|
}
|
|
defer res.Body.Close()
|
|
body, err = io.ReadAll(io.LimitReader(res.Body, 1024*1024))
|
|
if err != nil {
|
|
return token, err
|
|
}
|
|
if res.StatusCode < 200 || res.StatusCode >= 300 {
|
|
return token, fmt.Errorf("oidc token exchange failed: status=%d detail=%s", res.StatusCode, oidcErrorDetail(body))
|
|
}
|
|
err = json.Unmarshal(body, &token)
|
|
if err != nil {
|
|
return token, err
|
|
}
|
|
if strings.TrimSpace(token.AccessToken) == "" && strings.TrimSpace(token.IDToken) == "" {
|
|
return token, errors.New("missing oidc token")
|
|
}
|
|
return token, nil
|
|
}
|
|
|
|
func oidcErrorDetail(body []byte) string {
|
|
var payload oidcErrorResponse
|
|
var err error
|
|
var text string
|
|
err = json.Unmarshal(body, &payload)
|
|
if err == nil {
|
|
if strings.TrimSpace(payload.ErrorDescription) != "" {
|
|
return payload.ErrorDescription
|
|
}
|
|
if strings.TrimSpace(payload.Error) != "" {
|
|
return payload.Error
|
|
}
|
|
}
|
|
text = strings.TrimSpace(string(body))
|
|
if len(text) > 240 {
|
|
text = text[:240]
|
|
}
|
|
if text == "" {
|
|
text = "empty response body"
|
|
}
|
|
return text
|
|
}
|
|
|
|
func (api *API) oidcResolveClaims(ctx context.Context, cfg config.Config, token oidcTokenResponse) (oidcUserClaims, error) {
|
|
var claims oidcUserClaims
|
|
var err error
|
|
if strings.TrimSpace(cfg.OIDCUserInfoURL) != "" && strings.TrimSpace(token.AccessToken) != "" {
|
|
claims, err = api.oidcUserInfo(ctx, cfg, token.AccessToken)
|
|
if err == nil {
|
|
return claims, nil
|
|
}
|
|
}
|
|
if strings.TrimSpace(token.IDToken) == "" {
|
|
return claims, errors.New("missing id token and userinfo unavailable")
|
|
}
|
|
claims, err = oidcClaimsFromIDToken(token.IDToken)
|
|
if err != nil {
|
|
return claims, err
|
|
}
|
|
return claims, nil
|
|
}
|
|
|
|
func (api *API) oidcUserInfo(ctx context.Context, cfg config.Config, accessToken string) (oidcUserClaims, error) {
|
|
var client *http.Client
|
|
var req *http.Request
|
|
var res *http.Response
|
|
var body []byte
|
|
var claims oidcUserClaims
|
|
var err error
|
|
client = api.oidcHTTPClient(cfg)
|
|
req, err = http.NewRequestWithContext(ctx, http.MethodGet, cfg.OIDCUserInfoURL, nil)
|
|
if err != nil {
|
|
return claims, err
|
|
}
|
|
req.Header.Set("Authorization", "Bearer "+accessToken)
|
|
req.Header.Set("Accept", "application/json")
|
|
res, err = client.Do(req)
|
|
if err != nil {
|
|
return claims, err
|
|
}
|
|
defer res.Body.Close()
|
|
body, err = io.ReadAll(io.LimitReader(res.Body, 1024*1024))
|
|
if err != nil {
|
|
return claims, err
|
|
}
|
|
if res.StatusCode < 200 || res.StatusCode >= 300 {
|
|
return claims, fmt.Errorf("userinfo status %d", res.StatusCode)
|
|
}
|
|
err = json.Unmarshal(body, &claims)
|
|
if err != nil {
|
|
return claims, err
|
|
}
|
|
if strings.TrimSpace(claims.Sub) == "" {
|
|
return claims, errors.New("userinfo missing sub")
|
|
}
|
|
return claims, nil
|
|
}
|
|
|
|
func (api *API) oidcGetOrCreateUser(claims oidcUserClaims) (models.User, error) {
|
|
var username string
|
|
var displayName string
|
|
var email string
|
|
var user models.User
|
|
var hash string
|
|
var err error
|
|
var created models.User
|
|
username = oidcUsernameFromSub(claims.Sub)
|
|
user, hash, err = api.Store.GetUserByUsername(username)
|
|
_ = hash
|
|
if err == nil {
|
|
if user.Disabled {
|
|
return user, errors.New("user disabled")
|
|
}
|
|
return user, nil
|
|
}
|
|
if !errors.Is(err, sql.ErrNoRows) {
|
|
return user, err
|
|
}
|
|
displayName = strings.TrimSpace(claims.Name)
|
|
if displayName == "" {
|
|
displayName = strings.TrimSpace(claims.PreferredUsername)
|
|
}
|
|
if displayName == "" {
|
|
displayName = username
|
|
}
|
|
email = strings.TrimSpace(claims.Email)
|
|
if email == "" {
|
|
email = username + "@oidc.local"
|
|
}
|
|
user = models.User{
|
|
Username: username,
|
|
DisplayName: displayName,
|
|
Email: email,
|
|
AuthSource: "oidc",
|
|
}
|
|
created, err = api.Store.CreateUser(user, "")
|
|
if err != nil {
|
|
return created, err
|
|
}
|
|
return created, nil
|
|
}
|
|
|
|
func oidcClaimsFromIDToken(idToken string) (oidcUserClaims, error) {
|
|
var claims oidcUserClaims
|
|
var parts []string
|
|
var payload []byte
|
|
var err error
|
|
parts = strings.Split(idToken, ".")
|
|
if len(parts) < 2 {
|
|
return claims, errors.New("invalid id token")
|
|
}
|
|
payload, err = base64.RawURLEncoding.DecodeString(parts[1])
|
|
if err != nil {
|
|
return claims, err
|
|
}
|
|
err = json.Unmarshal(payload, &claims)
|
|
if err != nil {
|
|
return claims, err
|
|
}
|
|
if strings.TrimSpace(claims.Sub) == "" {
|
|
return claims, errors.New("id token missing sub")
|
|
}
|
|
return claims, nil
|
|
}
|
|
|
|
// oidcUsernameFromSub derives the stable local username for an OIDC subject.
// The result is "oidc-" followed by the lowercase hex of the first 6 bytes of
// SHA-1(TrimSpace(sub)).
//
// The exact derivation is load-bearing: existing accounts are looked up by
// this name, so changing it would orphan them.
func oidcUsernameFromSub(sub string) string {
	digest := sha1.Sum([]byte(strings.TrimSpace(sub)))
	return fmt.Sprintf("oidc-%x", digest[:6])
}
|
|
|
|
// newOIDCState returns an unguessable value for the OAuth "state" parameter:
// 24 bytes from crypto/rand, hex-encoded to 48 lowercase characters. It fails
// only if the system random source fails.
func newOIDCState() (string, error) {
	raw := make([]byte, 24)
	if _, err := rand.Read(raw); err != nil {
		return "", err
	}
	return hex.EncodeToString(raw), nil
}
|
|
|
|
// clearOIDCStateCookie tells the browser to drop the "codit_oidc_state"
// cookie. It overwrites the cookie with an empty value on the same path and an
// expiry at the Unix epoch.
func clearOIDCStateCookie(w http.ResponseWriter) {
	expired := &http.Cookie{Name: "codit_oidc_state", Path: "/api/auth/oidc"}
	expired.Value = ""
	expired.Expires = time.Unix(0, 0)
	expired.HttpOnly = true
	expired.SameSite = http.SameSiteLaxMode
	http.SetCookie(w, expired)
}
|