package handlers import "context" import "crypto/rand" import "crypto/sha1" import "crypto/tls" import "database/sql" import "encoding/base64" import "encoding/hex" import "encoding/json" import "errors" import "fmt" import "io" import "net/http" import "net/url" import "strings" import "time" import "codit/internal/config" import "codit/internal/models" import "codit/internal/util" type oidcTokenResponse struct { AccessToken string `json:"access_token"` IDToken string `json:"id_token"` } type oidcUserClaims struct { Sub string `json:"sub"` PreferredUsername string `json:"preferred_username"` Name string `json:"name"` Email string `json:"email"` } type oidcErrorResponse struct { Error string `json:"error"` ErrorDescription string `json:"error_description"` } func (api *API) OIDCEnabled(w http.ResponseWriter, _ *http.Request, _ map[string]string) { var settings models.AuthSettings var err error var configured bool settings, err = api.getMergedAuthSettings() if err != nil { WriteJSON(w, http.StatusOK, map[string]any{"enabled": false, "configured": false, "auth_mode": "db"}) return } configured = api.oidcConfiguredFromSettings(settings) WriteJSON(w, http.StatusOK, map[string]any{ "enabled": settings.OIDCEnabled, "configured": configured, "auth_mode": strings.ToLower(strings.TrimSpace(settings.AuthMode)), }) } func (api *API) OIDCLogin(w http.ResponseWriter, r *http.Request, _ map[string]string) { var settings models.AuthSettings var cfg config.Config var state string var err error var authURL string settings, err = api.getMergedAuthSettings() if err != nil { WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": "failed to load auth settings"}) return } cfg, err = api.effectiveAuthConfig() if err != nil { WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": "failed to load auth settings"}) return } if !settings.OIDCEnabled { WriteJSON(w, http.StatusServiceUnavailable, map[string]string{"error": "oidc login is disabled"}) return } if !api.oidcConfiguredFromSettings(settings) { WriteJSON(w, http.StatusServiceUnavailable, map[string]string{"error": "oidc is not configured"}) return } state, err = newOIDCState() if err != nil { WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": "failed to create state"}) return } http.SetCookie(w, &http.Cookie{ Name: "codit_oidc_state", Value: state, HttpOnly: true, Path: "/api/auth/oidc", Expires: time.Now().UTC().Add(10 * time.Minute), SameSite: http.SameSiteLaxMode, }) authURL = api.buildOIDCAuthorizeURL(cfg, state) if api.Logger != nil { api.Logger.Write("auth", util.LOG_INFO, "oidc login redirect authorize_url=%s", cfg.OIDCAuthorizeURL) } http.Redirect(w, r, authURL, http.StatusFound) } func (api *API) OIDCCallback(w http.ResponseWriter, r *http.Request, _ map[string]string) { var settings models.AuthSettings var cfg config.Config var state string var code string var cookie *http.Cookie var err error var token oidcTokenResponse var claims oidcUserClaims var user models.User settings, err = api.getMergedAuthSettings() if err != nil { WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": "failed to load auth settings"}) return } cfg, err = api.effectiveAuthConfig() if err != nil { WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": "failed to load auth settings"}) return } if !settings.OIDCEnabled { WriteJSON(w, http.StatusServiceUnavailable, map[string]string{"error": "oidc login is disabled"}) return } if !api.oidcConfiguredFromSettings(settings) { WriteJSON(w, http.StatusServiceUnavailable, map[string]string{"error": "oidc is not configured"}) return } state = strings.TrimSpace(r.URL.Query().Get("state")) code = strings.TrimSpace(r.URL.Query().Get("code")) cookie, err = r.Cookie("codit_oidc_state") clearOIDCStateCookie(w) if err != nil || cookie.Value == "" || state == "" || state != cookie.Value { WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid oidc state"}) return } if code == "" { WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "missing authorization code"}) return } token, err = api.oidcExchangeCode(r.Context(), cfg, code) if err != nil { if api.Logger != nil { api.Logger.Write("auth", util.LOG_WARN, "oidc token exchange failed err=%v", err) } WriteJSON(w, http.StatusUnauthorized, map[string]string{"error": err.Error()}) return } claims, err = api.oidcResolveClaims(r.Context(), cfg, token) if err != nil { if api.Logger != nil { api.Logger.Write("auth", util.LOG_WARN, "oidc claims fetch failed err=%v", err) } WriteJSON(w, http.StatusUnauthorized, map[string]string{"error": "oidc claims fetch failed"}) return } user, err = api.oidcGetOrCreateUser(claims) if err != nil { if api.Logger != nil { api.Logger.Write("auth", util.LOG_WARN, "oidc user mapping failed err=%v", err) } WriteJSON(w, http.StatusUnauthorized, map[string]string{"error": "oidc user mapping failed"}) return } api.issueSession(w, user) if api.Logger != nil { api.Logger.Write("auth", util.LOG_INFO, "oidc login success username=%s", user.Username) } http.Redirect(w, r, "/", http.StatusFound) } func (api *API) oidcConfiguredFromSettings(settings models.AuthSettings) bool { if strings.TrimSpace(settings.OIDCClientID) == "" { return false } if strings.TrimSpace(settings.OIDCClientSecret) == "" { return false } if strings.TrimSpace(settings.OIDCAuthorizeURL) == "" { return false } if strings.TrimSpace(settings.OIDCTokenURL) == "" { return false } if strings.TrimSpace(settings.OIDCRedirectURL) == "" { return false } return true } func (api *API) buildOIDCAuthorizeURL(cfg config.Config, state string) string { var values url.Values var scopes string var endpoint string values = url.Values{} values.Set("response_type", "code") values.Set("client_id", cfg.OIDCClientID) values.Set("redirect_uri", cfg.OIDCRedirectURL) scopes = strings.TrimSpace(cfg.OIDCScopes) if scopes == "" { scopes = "openid profile email" } values.Set("scope", scopes) values.Set("state", state) endpoint = cfg.OIDCAuthorizeURL if strings.Contains(endpoint, "?") { return endpoint + "&" + values.Encode() } return endpoint + "?" + values.Encode() } func (api *API) oidcHTTPClient(cfg config.Config) *http.Client { var transport *http.Transport transport = &http.Transport{ TLSClientConfig: &tls.Config{InsecureSkipVerify: cfg.OIDCTLSInsecureSkipVerify}, } return &http.Client{Transport: transport, Timeout: 15 * time.Second} } func (api *API) oidcExchangeCode(ctx context.Context, cfg config.Config, code string) (oidcTokenResponse, error) { var client *http.Client var form url.Values var req *http.Request var res *http.Response var body []byte var token oidcTokenResponse var err error client = api.oidcHTTPClient(cfg) form = url.Values{} form.Set("grant_type", "authorization_code") form.Set("code", code) form.Set("redirect_uri", cfg.OIDCRedirectURL) form.Set("client_id", cfg.OIDCClientID) form.Set("client_secret", cfg.OIDCClientSecret) req, err = http.NewRequestWithContext(ctx, http.MethodPost, cfg.OIDCTokenURL, strings.NewReader(form.Encode())) if err != nil { return token, err } req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Accept", "application/json") res, err = client.Do(req) if err != nil { return token, err } defer res.Body.Close() body, err = io.ReadAll(io.LimitReader(res.Body, 1024*1024)) if err != nil { return token, err } if res.StatusCode < 200 || res.StatusCode >= 300 { return token, fmt.Errorf("oidc token exchange failed: status=%d detail=%s", res.StatusCode, oidcErrorDetail(body)) } err = json.Unmarshal(body, &token) if err != nil { return token, err } if strings.TrimSpace(token.AccessToken) == "" && strings.TrimSpace(token.IDToken) == "" { return token, errors.New("missing oidc token") } return token, nil } func oidcErrorDetail(body []byte) string { var payload oidcErrorResponse var err error var text string err = json.Unmarshal(body, &payload) if err == nil { if strings.TrimSpace(payload.ErrorDescription) != "" { return payload.ErrorDescription } if strings.TrimSpace(payload.Error) != "" { return payload.Error } } text = strings.TrimSpace(string(body)) if len(text) > 240 { text = text[:240] } if text == "" { text = "empty response body" } return text } func (api *API) oidcResolveClaims(ctx context.Context, cfg config.Config, token oidcTokenResponse) (oidcUserClaims, error) { var claims oidcUserClaims var err error if strings.TrimSpace(cfg.OIDCUserInfoURL) != "" && strings.TrimSpace(token.AccessToken) != "" { claims, err = api.oidcUserInfo(ctx, cfg, token.AccessToken) if err == nil { return claims, nil } } if strings.TrimSpace(token.IDToken) == "" { return claims, errors.New("missing id token and userinfo unavailable") } claims, err = oidcClaimsFromIDToken(token.IDToken) if err != nil { return claims, err } return claims, nil } func (api *API) oidcUserInfo(ctx context.Context, cfg config.Config, accessToken string) (oidcUserClaims, error) { var client *http.Client var req *http.Request var res *http.Response var body []byte var claims oidcUserClaims var err error client = api.oidcHTTPClient(cfg) req, err = http.NewRequestWithContext(ctx, http.MethodGet, cfg.OIDCUserInfoURL, nil) if err != nil { return claims, err } req.Header.Set("Authorization", "Bearer "+accessToken) req.Header.Set("Accept", "application/json") res, err = client.Do(req) if err != nil { return claims, err } defer res.Body.Close() body, err = io.ReadAll(io.LimitReader(res.Body, 1024*1024)) if err != nil { return claims, err } if res.StatusCode < 200 || res.StatusCode >= 300 { return claims, fmt.Errorf("userinfo status %d", res.StatusCode) } err = json.Unmarshal(body, &claims) if err != nil { return claims, err } if strings.TrimSpace(claims.Sub) == "" { return claims, errors.New("userinfo missing sub") } return claims, nil } func (api *API) oidcGetOrCreateUser(claims oidcUserClaims) (models.User, error) { var username string var displayName string var email string var user models.User var hash string var err error var created models.User username = oidcUsernameFromSub(claims.Sub) user, hash, err = api.Store.GetUserByUsername(username) _ = hash if err == nil { if user.Disabled { return user, errors.New("user disabled") } return user, nil } if !errors.Is(err, sql.ErrNoRows) { return user, err } displayName = strings.TrimSpace(claims.Name) if displayName == "" { displayName = strings.TrimSpace(claims.PreferredUsername) } if displayName == "" { displayName = username } email = strings.TrimSpace(claims.Email) if email == "" { email = username + "@oidc.local" } user = models.User{ Username: username, DisplayName: displayName, Email: email, AuthSource: "oidc", } created, err = api.Store.CreateUser(user, "") if err != nil { return created, err } return created, nil } func oidcClaimsFromIDToken(idToken string) (oidcUserClaims, error) { var claims oidcUserClaims var parts []string var payload []byte var err error parts = strings.Split(idToken, ".") if len(parts) < 2 { return claims, errors.New("invalid id token") } payload, err = base64.RawURLEncoding.DecodeString(parts[1]) if err != nil { return claims, err } err = json.Unmarshal(payload, &claims) if err != nil { return claims, err } if strings.TrimSpace(claims.Sub) == "" { return claims, errors.New("id token missing sub") } return claims, nil } func oidcUsernameFromSub(sub string) string { var sum [20]byte sum = sha1.Sum([]byte(strings.TrimSpace(sub))) return "oidc-" + hex.EncodeToString(sum[:6]) } func newOIDCState() (string, error) { var buf []byte var err error buf = make([]byte, 24) _, err = rand.Read(buf) if err != nil { return "", err } return hex.EncodeToString(buf), nil } func clearOIDCStateCookie(w http.ResponseWriter) { http.SetCookie(w, &http.Cookie{ Name: "codit_oidc_state", Value: "", Path: "/api/auth/oidc", Expires: time.Unix(0, 0), HttpOnly: true, SameSite: http.SameSiteLaxMode, }) }