Files
2026-02-15 02:03:31 +09:00

435 lines
12 KiB
Go

package handlers
import "context"
import "crypto/rand"
import "crypto/sha1"
import "crypto/tls"
import "database/sql"
import "encoding/base64"
import "encoding/hex"
import "encoding/json"
import "errors"
import "fmt"
import "io"
import "net/http"
import "net/url"
import "strings"
import "time"
import "codit/internal/config"
import "codit/internal/models"
import "codit/internal/util"
// oidcTokenResponse holds the subset of the token endpoint's JSON reply that
// the login flow reads. Either field may be empty; the code exchange only
// fails when both are.
type oidcTokenResponse struct {
	AccessToken string `json:"access_token"`
	IDToken     string `json:"id_token"`
}
// oidcUserClaims is the identity information the login flow reads, decoded
// either from the userinfo endpoint or from the ID token payload. Sub is the
// subject identifier the local username is derived from and is required; the
// remaining fields are optional and only feed the display name and email.
type oidcUserClaims struct {
	Sub               string `json:"sub"`
	PreferredUsername string `json:"preferred_username"`
	Name              string `json:"name"`
	Email             string `json:"email"`
}
// oidcErrorResponse mirrors the OAuth 2.0 error body ("error" plus an optional
// "error_description") that an identity provider may send back.
type oidcErrorResponse struct {
	Error            string `json:"error"`
	ErrorDescription string `json:"error_description"`
}
// OIDCEnabled reports, always with status 200, whether OIDC login is switched
// on, whether the required OIDC settings are filled in, and the auth mode
// (trimmed and lower-cased). If the auth settings cannot be loaded it reports
// OIDC as disabled and unconfigured with auth mode "db".
func (api *API) OIDCEnabled(w http.ResponseWriter, _ *http.Request, _ map[string]string) {
	settings, err := api.getMergedAuthSettings()
	if err != nil {
		WriteJSON(w, http.StatusOK, map[string]any{"enabled": false, "configured": false, "auth_mode": "db"})
		return
	}
	payload := map[string]any{
		"enabled":    settings.OIDCEnabled,
		"configured": api.oidcConfiguredFromSettings(settings),
		"auth_mode":  strings.ToLower(strings.TrimSpace(settings.AuthMode)),
	}
	WriteJSON(w, http.StatusOK, payload)
}
// OIDCLogin starts the OIDC authorization-code flow. It answers 500 when the
// auth settings or effective config cannot be loaded, and 503 when OIDC login
// is disabled or not fully configured. Otherwise it stores a fresh random
// state value in the "codit_oidc_state" cookie (10-minute lifetime, scoped to
// /api/auth/oidc) and redirects (302) to the provider's authorize URL carrying
// that same state.
func (api *API) OIDCLogin(w http.ResponseWriter, r *http.Request, _ map[string]string) {
	var settings models.AuthSettings
	var cfg config.Config
	var state string
	var err error
	var authURL string
	var secureCookie bool
	settings, err = api.getMergedAuthSettings()
	if err != nil {
		WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": "failed to load auth settings"})
		return
	}
	cfg, err = api.effectiveAuthConfig()
	if err != nil {
		WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": "failed to load auth settings"})
		return
	}
	if !settings.OIDCEnabled {
		WriteJSON(w, http.StatusServiceUnavailable, map[string]string{"error": "oidc login is disabled"})
		return
	}
	if !api.oidcConfiguredFromSettings(settings) {
		WriteJSON(w, http.StatusServiceUnavailable, map[string]string{"error": "oidc is not configured"})
		return
	}
	state, err = newOIDCState()
	if err != nil {
		WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": "failed to create state"})
		return
	}
	// Mark the state cookie Secure only when this request arrived over TLS AND
	// the provider will send the browser back to an https callback; otherwise
	// the browser would refuse to set or to return it. Behind a TLS-terminating
	// proxy r.TLS is nil, so the cookie stays non-Secure as before.
	secureCookie = r.TLS != nil &&
		strings.HasPrefix(strings.ToLower(strings.TrimSpace(cfg.OIDCRedirectURL)), "https://")
	http.SetCookie(w, &http.Cookie{
		Name:     "codit_oidc_state",
		Value:    state,
		HttpOnly: true,
		Secure:   secureCookie,
		Path:     "/api/auth/oidc",
		Expires:  time.Now().UTC().Add(10 * time.Minute),
		SameSite: http.SameSiteLaxMode,
	})
	authURL = api.buildOIDCAuthorizeURL(cfg, state)
	if api.Logger != nil {
		// Logs the configured endpoint rather than authURL, which keeps the
		// per-request state value out of the log.
		api.Logger.Write("auth", util.LOG_INFO, "oidc login redirect authorize_url=%s", cfg.OIDCAuthorizeURL)
	}
	http.Redirect(w, r, authURL, http.StatusFound)
}
// OIDCCallback completes the OIDC authorization-code flow. It answers 500 when
// the auth settings or effective config cannot be loaded and 503 when OIDC is
// disabled or not configured. It then does the following, in order:
//
//   - always clears the state cookie;
//   - rejects a "state" query parameter that does not match that cookie (400);
//   - rejects an error response from the provider (401);
//   - requires an authorization code (400);
//   - exchanges the code for tokens, resolves the user's claims and maps them
//     to a local account, answering 401 if any of these steps fails;
//   - issues a session and redirects (302) to "/".
//
// The browser only ever sees generic error messages. The underlying causes,
// which can include provider response bodies, are written to the log.
func (api *API) OIDCCallback(w http.ResponseWriter, r *http.Request, _ map[string]string) {
	var settings models.AuthSettings
	var cfg config.Config
	var query url.Values
	var state string
	var code string
	var idpError string
	var cookie *http.Cookie
	var err error
	var token oidcTokenResponse
	var claims oidcUserClaims
	var user models.User
	settings, err = api.getMergedAuthSettings()
	if err != nil {
		WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": "failed to load auth settings"})
		return
	}
	cfg, err = api.effectiveAuthConfig()
	if err != nil {
		WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": "failed to load auth settings"})
		return
	}
	if !settings.OIDCEnabled {
		WriteJSON(w, http.StatusServiceUnavailable, map[string]string{"error": "oidc login is disabled"})
		return
	}
	if !api.oidcConfiguredFromSettings(settings) {
		WriteJSON(w, http.StatusServiceUnavailable, map[string]string{"error": "oidc is not configured"})
		return
	}
	query = r.URL.Query()
	state = strings.TrimSpace(query.Get("state"))
	code = strings.TrimSpace(query.Get("code"))
	idpError = strings.TrimSpace(query.Get("error"))
	cookie, err = r.Cookie("codit_oidc_state")
	// Clear the cookie before validating so each state value is usable at most
	// once, whether or not this attempt succeeds.
	clearOIDCStateCookie(w)
	if err != nil || cookie.Value == "" || state == "" || state != cookie.Value {
		WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid oidc state"})
		return
	}
	// When the user declines consent or the provider refuses the request, the
	// redirect carries "error" (and optionally "error_description") instead of
	// a code (RFC 6749 §4.1.2.1). %q keeps these caller-controlled values from
	// injecting line breaks into the log.
	if idpError != "" {
		if api.Logger != nil {
			api.Logger.Write("auth", util.LOG_WARN, "oidc authorization error error=%q description=%q", idpError, query.Get("error_description"))
		}
		WriteJSON(w, http.StatusUnauthorized, map[string]string{"error": "oidc authorization failed"})
		return
	}
	if code == "" {
		WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "missing authorization code"})
		return
	}
	token, err = api.oidcExchangeCode(r.Context(), cfg, code)
	if err != nil {
		if api.Logger != nil {
			api.Logger.Write("auth", util.LOG_WARN, "oidc token exchange failed err=%v", err)
		}
		// The detailed error can contain provider response bodies and network
		// details, so it is logged above and not echoed to the client.
		WriteJSON(w, http.StatusUnauthorized, map[string]string{"error": "oidc token exchange failed"})
		return
	}
	claims, err = api.oidcResolveClaims(r.Context(), cfg, token)
	if err != nil {
		if api.Logger != nil {
			api.Logger.Write("auth", util.LOG_WARN, "oidc claims fetch failed err=%v", err)
		}
		WriteJSON(w, http.StatusUnauthorized, map[string]string{"error": "oidc claims fetch failed"})
		return
	}
	user, err = api.oidcGetOrCreateUser(claims)
	if err != nil {
		if api.Logger != nil {
			api.Logger.Write("auth", util.LOG_WARN, "oidc user mapping failed err=%v", err)
		}
		WriteJSON(w, http.StatusUnauthorized, map[string]string{"error": "oidc user mapping failed"})
		return
	}
	api.issueSession(w, user)
	if api.Logger != nil {
		api.Logger.Write("auth", util.LOG_INFO, "oidc login success username=%s", user.Username)
	}
	http.Redirect(w, r, "/", http.StatusFound)
}
// oidcConfiguredFromSettings reports whether every setting the
// authorization-code flow cannot run without is non-blank. Those settings are
// the client ID, client secret, authorize URL, token URL and redirect URL. The
// userinfo URL and scopes are not checked here.
func (api *API) oidcConfiguredFromSettings(settings models.AuthSettings) bool {
	required := []string{
		settings.OIDCClientID,
		settings.OIDCClientSecret,
		settings.OIDCAuthorizeURL,
		settings.OIDCTokenURL,
		settings.OIDCRedirectURL,
	}
	for _, value := range required {
		if strings.TrimSpace(value) == "" {
			return false
		}
	}
	return true
}
// buildOIDCAuthorizeURL returns the provider's authorize endpoint with the
// authorization-request parameters appended: response_type=code, client_id,
// redirect_uri, scope and state. The parameters are URL-encoded and sorted by
// key. When no scopes are configured, scope defaults to "openid profile email".
// If the endpoint already contains a "?", the parameters are joined with "&".
func (api *API) buildOIDCAuthorizeURL(cfg config.Config, state string) string {
	scope := strings.TrimSpace(cfg.OIDCScopes)
	if scope == "" {
		scope = "openid profile email"
	}
	query := url.Values{
		"response_type": {"code"},
		"client_id":     {cfg.OIDCClientID},
		"redirect_uri":  {cfg.OIDCRedirectURL},
		"scope":         {scope},
		"state":         {state},
	}
	separator := "?"
	if strings.Contains(cfg.OIDCAuthorizeURL, "?") {
		separator = "&"
	}
	return cfg.OIDCAuthorizeURL + separator + query.Encode()
}
// oidcHTTPClient returns an HTTP client for the OIDC provider's token and
// userinfo endpoints. The client has a 15-second overall timeout. TLS
// certificate verification is skipped only when cfg.OIDCTLSInsecureSkipVerify
// is set.
//
// The transport is cloned from http.DefaultTransport. That keeps the standard
// proxy-from-environment setting and the dial and TLS-handshake timeouts,
// which a bare &http.Transport{} lacks. Callers build a new client for every
// request, so pooled connections could never be reused. Keep-alives are
// therefore disabled: otherwise idle connections would linger on each
// abandoned transport.
func (api *API) oidcHTTPClient(cfg config.Config) *http.Client {
	var transport *http.Transport
	var ok bool
	transport, ok = http.DefaultTransport.(*http.Transport)
	if ok {
		transport = transport.Clone()
	} else {
		// DefaultTransport was replaced with another RoundTripper; fall back to
		// a plain transport rather than panicking on the type assertion.
		transport = &http.Transport{}
	}
	transport.DisableKeepAlives = true
	transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: cfg.OIDCTLSInsecureSkipVerify}
	return &http.Client{Transport: transport, Timeout: 15 * time.Second}
}
// oidcExchangeCode redeems an authorization code at the provider's token
// endpoint. It sends grant_type=authorization_code with the client ID and
// secret in the form body and returns the decoded token response. At most
// 1 MiB of the reply is read.
//
// It fails in four cases:
//   - the request cannot be sent or its reply cannot be read;
//   - the status is not 2xx (the provider's error detail is included);
//   - the reply is not valid JSON;
//   - the reply carries neither an access token nor an ID token.
//
// The returned token response is only meaningful when err is nil.
func (api *API) oidcExchangeCode(ctx context.Context, cfg config.Config, code string) (oidcTokenResponse, error) {
	var client *http.Client
	var form url.Values
	var req *http.Request
	var res *http.Response
	var body []byte
	var token oidcTokenResponse
	var err error
	client = api.oidcHTTPClient(cfg)
	form = url.Values{}
	form.Set("grant_type", "authorization_code")
	form.Set("code", code)
	form.Set("redirect_uri", cfg.OIDCRedirectURL)
	form.Set("client_id", cfg.OIDCClientID)
	form.Set("client_secret", cfg.OIDCClientSecret)
	req, err = http.NewRequestWithContext(ctx, http.MethodPost, cfg.OIDCTokenURL, strings.NewReader(form.Encode()))
	if err != nil {
		return oidcTokenResponse{}, fmt.Errorf("build oidc token request: %w", err)
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	req.Header.Set("Accept", "application/json")
	res, err = client.Do(req)
	if err != nil {
		return oidcTokenResponse{}, fmt.Errorf("oidc token request: %w", err)
	}
	defer res.Body.Close()
	body, err = io.ReadAll(io.LimitReader(res.Body, 1024*1024))
	if err != nil {
		return oidcTokenResponse{}, fmt.Errorf("read oidc token response: %w", err)
	}
	if res.StatusCode < 200 || res.StatusCode >= 300 {
		return oidcTokenResponse{}, fmt.Errorf("oidc token exchange failed: status=%d detail=%s", res.StatusCode, oidcErrorDetail(body))
	}
	err = json.Unmarshal(body, &token)
	if err != nil {
		return oidcTokenResponse{}, fmt.Errorf("decode oidc token response: %w", err)
	}
	if strings.TrimSpace(token.AccessToken) == "" && strings.TrimSpace(token.IDToken) == "" {
		// Some providers answer 2xx with an OAuth error body instead of tokens;
		// include whatever detail the body carries.
		return oidcTokenResponse{}, fmt.Errorf("missing oidc token: %s", oidcErrorDetail(body))
	}
	return token, nil
}
// oidcErrorDetail extracts a short, human-readable reason from an OIDC
// provider error body. It uses the first non-blank value among the OAuth
// "error_description" field, the "error" field, and the raw body text.
//
// The result is whitespace-trimmed and capped at 240 bytes, cut on a rune
// boundary so it never ends in a split UTF-8 sequence. A body that yields
// nothing returns "empty response body".
func oidcErrorDetail(body []byte) string {
	const maxLen = 240
	var payload oidcErrorResponse
	var text string
	var cut int
	if json.Unmarshal(body, &payload) == nil {
		text = strings.TrimSpace(payload.ErrorDescription)
		if text == "" {
			text = strings.TrimSpace(payload.Error)
		}
	}
	if text == "" {
		text = strings.TrimSpace(string(body))
	}
	if text == "" {
		return "empty response body"
	}
	if len(text) > maxLen {
		// Ranging over a string yields the byte offset of each rune start. Keep
		// the last start at or before maxLen so the cut never lands inside a
		// multi-byte sequence.
		for i := range text {
			if i > maxLen {
				break
			}
			cut = i
		}
		text = text[:cut]
	}
	return text
}
// oidcResolveClaims determines the signed-in user's claims. It prefers the
// userinfo endpoint, used when a userinfo URL is configured and an access
// token was issued. It falls back to the ID token payload when userinfo is not
// used or the userinfo call fails.
//
// When both the userinfo reply and a readable ID token are available, their
// "sub" values must be identical (OpenID Connect Core §5.3.2). Otherwise
// userinfo data for a different subject would be accepted. A mismatch is an
// error. When both sources fail, the error keeps the userinfo cause instead of
// discarding it. The returned claims are only meaningful when err is nil.
func (api *API) oidcResolveClaims(ctx context.Context, cfg config.Config, token oidcTokenResponse) (oidcUserClaims, error) {
	var claims oidcUserClaims
	var idClaims oidcUserClaims
	var idErr error
	var userInfoErr error
	var haveIDToken bool
	haveIDToken = strings.TrimSpace(token.IDToken) != ""
	if haveIDToken {
		idClaims, idErr = oidcClaimsFromIDToken(token.IDToken)
	}
	if strings.TrimSpace(cfg.OIDCUserInfoURL) != "" && strings.TrimSpace(token.AccessToken) != "" {
		claims, userInfoErr = api.oidcUserInfo(ctx, cfg, token.AccessToken)
		if userInfoErr == nil {
			if haveIDToken && idErr == nil && claims.Sub != idClaims.Sub {
				return oidcUserClaims{}, errors.New("userinfo sub does not match id token sub")
			}
			return claims, nil
		}
	}
	if !haveIDToken {
		if userInfoErr != nil {
			return oidcUserClaims{}, fmt.Errorf("missing id token and userinfo unavailable: %w", userInfoErr)
		}
		return oidcUserClaims{}, errors.New("missing id token and userinfo unavailable")
	}
	if idErr != nil {
		if userInfoErr != nil {
			return oidcUserClaims{}, fmt.Errorf("id token claims: %w (userinfo: %v)", idErr, userInfoErr)
		}
		return oidcUserClaims{}, idErr
	}
	return idClaims, nil
}
// oidcUserInfo calls the provider's userinfo endpoint, authenticating with the
// access token as a Bearer credential, and decodes the JSON reply into claims.
// At most 1 MiB of the reply is read. A non-2xx status or a reply without a
// non-blank "sub" claim is reported as an error.
func (api *API) oidcUserInfo(ctx context.Context, cfg config.Config, accessToken string) (oidcUserClaims, error) {
	var claims oidcUserClaims
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, cfg.OIDCUserInfoURL, nil)
	if err != nil {
		return claims, err
	}
	req.Header.Set("Authorization", "Bearer "+accessToken)
	req.Header.Set("Accept", "application/json")
	resp, err := api.oidcHTTPClient(cfg).Do(req)
	if err != nil {
		return claims, err
	}
	defer resp.Body.Close()
	payload, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
	if err != nil {
		return claims, err
	}
	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		return claims, fmt.Errorf("userinfo status %d", resp.StatusCode)
	}
	if err = json.Unmarshal(payload, &claims); err != nil {
		return claims, err
	}
	if strings.TrimSpace(claims.Sub) == "" {
		return claims, errors.New("userinfo missing sub")
	}
	return claims, nil
}
// oidcGetOrCreateUser maps OIDC claims to a local account, using the username
// that oidcUsernameFromSub derives from the subject. If an account with that
// username exists, it is returned, or "user disabled" is reported when it is
// disabled. Any lookup error other than sql.ErrNoRows is returned as is.
//
// Otherwise a new account is created with AuthSource "oidc" and an empty
// password argument. Its display name is the first non-blank of name,
// preferred_username and the username. Its email falls back to
// "<username>@oidc.local" when the claims carry none.
//
// NOTE(review): an existing account is matched on username alone; neither its
// AuthSource nor its password hash is checked, so a non-OIDC account holding
// an "oidc-…" username would be signed into here. Confirm that local sign-up
// cannot claim that prefix.
func (api *API) oidcGetOrCreateUser(claims oidcUserClaims) (models.User, error) {
	username := oidcUsernameFromSub(claims.Sub)
	existing, _, err := api.Store.GetUserByUsername(username)
	switch {
	case err == nil && existing.Disabled:
		return existing, errors.New("user disabled")
	case err == nil:
		return existing, nil
	case !errors.Is(err, sql.ErrNoRows):
		return existing, err
	}
	displayName := strings.TrimSpace(claims.Name)
	if displayName == "" {
		displayName = strings.TrimSpace(claims.PreferredUsername)
	}
	if displayName == "" {
		displayName = username
	}
	email := strings.TrimSpace(claims.Email)
	if email == "" {
		email = username + "@oidc.local"
	}
	return api.Store.CreateUser(models.User{
		Username:    username,
		DisplayName: displayName,
		Email:       email,
		AuthSource:  "oidc",
	}, "")
}
// oidcClaimsFromIDToken decodes the payload segment of a compact-serialized
// ID token (header.payload.signature) into claims and requires a non-blank
// "sub". The base64url payload is accepted with or without "=" padding.
//
// The signature is NOT verified, and aud/iss/exp are not checked. In this file
// the token comes straight from the token endpoint's reply, and the flow relies
// on that channel. NOTE(review): revisit if ID tokens ever arrive another way.
// The returned claims are only meaningful when err is nil.
func oidcClaimsFromIDToken(idToken string) (oidcUserClaims, error) {
	var claims oidcUserClaims
	var parts []string
	var payload []byte
	var err error
	parts = strings.Split(idToken, ".")
	// A compact JWS has exactly three segments. Anything else, including a
	// five-segment encrypted JWE, cannot be read by decoding parts[1].
	if len(parts) != 3 {
		return oidcUserClaims{}, errors.New("invalid id token")
	}
	// RFC 7515 specifies unpadded base64url, but tolerate stray padding.
	payload, err = base64.RawURLEncoding.DecodeString(strings.TrimRight(parts[1], "="))
	if err != nil {
		return oidcUserClaims{}, fmt.Errorf("decode id token payload: %w", err)
	}
	err = json.Unmarshal(payload, &claims)
	if err != nil {
		return oidcUserClaims{}, fmt.Errorf("parse id token payload: %w", err)
	}
	if strings.TrimSpace(claims.Sub) == "" {
		return oidcUserClaims{}, errors.New("id token missing sub")
	}
	return claims, nil
}
// oidcUsernameFromSub derives the local username for an OIDC subject. The
// result is "oidc-" followed by the lowercase hex of the first 6 bytes
// (12 hex characters) of the SHA-1 of the whitespace-trimmed subject.
//
// NOTE(review): only 48 bits of the digest are kept and the issuer is not part
// of the input. Distinct subjects, or the same subject at a different
// provider, could therefore map to the same username. Changing the scheme
// would remap existing accounts, so it is left as is.
func oidcUsernameFromSub(sub string) string {
	const prefixBytes = 6
	digest := sha1.Sum([]byte(strings.TrimSpace(sub)))
	return "oidc-" + hex.EncodeToString(digest[:prefixBytes])
}
// newOIDCState returns 24 bytes read from crypto/rand, hex-encoded as 48
// lowercase characters, for use as the OAuth "state" value. If the random
// source fails, it returns an empty string and that error.
func newOIDCState() (string, error) {
	raw := make([]byte, 24)
	if _, err := rand.Read(raw); err != nil {
		return "", err
	}
	return hex.EncodeToString(raw), nil
}
// clearOIDCStateCookie overwrites the "codit_oidc_state" cookie on the
// /api/auth/oidc path with an empty value that expired at the Unix epoch. The
// browser therefore discards it. The HttpOnly and SameSite=Lax attributes
// match the ones used when the cookie is set.
func clearOIDCStateCookie(w http.ResponseWriter) {
	expired := http.Cookie{
		Name:     "codit_oidc_state",
		Path:     "/api/auth/oidc",
		HttpOnly: true,
		SameSite: http.SameSiteLaxMode,
		Expires:  time.Unix(0, 0),
	}
	http.SetCookie(w, &expired)
}