Files
hyung-hwan 7f3a8b24ff no request-wide transaction for api.RepoRPMUpload and api.RepoRPMRebuildSubdirMetadata
defined some more constants and placed them in models/constants.go
2026-06-20 21:46:58 +09:00

424 lines
13 KiB
Go

package db
import "context"
import "database/sql"
import "errors"
import "fmt"
import "strings"
import "time"
import "codit/internal/models"
import "codit/internal/util"
const (
	// BuiltinDBProviderPublicID is the public ID of the built-in database
	// authentication provider. DeleteAuthProvider refuses to delete it.
	BuiltinDBProviderPublicID = "db"
)
// scanAuthProvider copies one result row into item. The row must be laid out
// in the column order of authProviderSelectCols.
//
// The default user group column is read through sql.NullString and stored as
// its String value, so a NULL becomes "". On a scan error, the error is
// returned and DefaultUserGroupID is left untouched.
func scanAuthProvider(row interface {
	Scan(dest ...any) error
}, item *models.AuthProvider) error {
	var groupPublicID sql.NullString
	targets := []any{
		&item.ID, &item.Name, &item.Type, &item.Enabled,
		&item.LDAPUrl, &item.LDAPBindDN, &item.LDAPBindPassword, &item.LDAPUserBaseDN, &item.LDAPUserFilter,
		&item.LDAPTLSInsecureSkipVerify,
		&item.OIDCClientID, &item.OIDCClientSecret, &item.OIDCAuthorizeURL, &item.OIDCTokenURL,
		&item.OIDCUserInfoURL, &item.OIDCRedirectURL, &item.OIDCScopes, &item.OIDCTLSInsecureSkipVerify,
		&item.OIDCGroupsClaim, &item.OIDCAdmissionExpr, &item.OIDCEndSessionURL, &item.OIDCPostLogoutRedirect, &item.GroupSyncMode,
		&groupPublicID,
		&item.UserCount,
		&item.CreatedAt, &item.UpdatedAt,
	}
	if err := row.Scan(targets...); err != nil {
		return err
	}
	item.DefaultUserGroupID = groupPublicID.String
	return nil
}
// SQL fragments shared by the auth provider read queries.
//
// authProviderSelectCols lists the columns in exactly the order that
// scanAuthProvider scans them. The default user group's public ID is
// COALESCEd to '' when the provider has none. The user count is a correlated
// subquery over users.auth_provider_id.
//
// authProviderFromClause supplies the "p" and "ug" aliases that
// authProviderSelectCols refers to. The LEFT JOIN keeps providers that have no
// default user group.
const (
	authProviderSelectCols = `
p.public_id, p.name, p.type, p.enabled,
p.ldap_url, p.ldap_bind_dn, p.ldap_bind_password, p.ldap_user_base_dn, p.ldap_user_filter,
p.ldap_tls_insecure_skip_verify,
p.oidc_client_id, p.oidc_client_secret, p.oidc_authorize_url, p.oidc_token_url,
p.oidc_userinfo_url, p.oidc_redirect_url, p.oidc_scopes, p.oidc_tls_insecure_skip_verify,
p.oidc_groups_claim, p.oidc_admission_expr, p.oidc_end_session_url, p.oidc_post_logout_redirect, p.group_sync_mode,
COALESCE(ug.public_id, ''),
(SELECT COUNT(*) FROM users WHERE auth_provider_id = p.id),
p.created_at, p.updated_at`

	authProviderFromClause = `
FROM auth_providers p
LEFT JOIN user_groups ug ON ug.id = p.default_user_group_id`
)
// normalizeAuthProviderGroupSyncMode trims value and validates it against the
// supported group sync modes.
//
// An empty or all-whitespace value defaults to models.GroupSyncModeOff.
// Any other value must match one of models.GroupSyncModeOff,
// models.GroupSyncModeFirstLogin or models.GroupSyncModeSync exactly;
// anything else is rejected with an error.
func normalizeAuthProviderGroupSyncMode(value string) (string, error) {
	switch trimmed := strings.TrimSpace(value); trimmed {
	case "":
		return models.GroupSyncModeOff, nil
	case models.GroupSyncModeOff, models.GroupSyncModeFirstLogin, models.GroupSyncModeSync:
		return trimmed, nil
	default:
		return "", errors.New("invalid group_sync_mode")
	}
}
// scanAuthProviderGroupMappings drains rows into a slice of group mappings.
//
// Each row must carry, in order: claim_value, group public ID, group name,
// created_at and updated_at. A scan error aborts with a nil slice. Otherwise
// the collected mappings are returned together with rows.Err(). The caller
// remains responsible for closing rows.
func scanAuthProviderGroupMappings(rows *sql.Rows) ([]models.AuthProviderGroupMapping, error) {
	var collected []models.AuthProviderGroupMapping
	for rows.Next() {
		var m models.AuthProviderGroupMapping
		if err := rows.Scan(&m.ClaimValue, &m.GroupID, &m.GroupName, &m.CreatedAt, &m.UpdatedAt); err != nil {
			return nil, err
		}
		collected = append(collected, m)
	}
	return collected, rows.Err()
}
// ListAuthProviderGroupMappings returns the claim-to-group mappings of the
// provider with the given public ID. Surrounding whitespace in the ID is
// ignored. Results are ordered by claim value, then group name.
//
// An unknown provider matches no rows and yields a nil slice with no error.
func (s *Store) ListAuthProviderGroupMappings(providerID string) ([]models.AuthProviderGroupMapping, error) {
	rows, err := s.Query(`
SELECT m.claim_value, g.public_id, g.name, m.created_at, m.updated_at
FROM auth_provider_group_mappings m
JOIN auth_providers p ON p.id = m.auth_provider_id
JOIN user_groups g ON g.id = m.group_id
WHERE p.public_id = ?
ORDER BY m.claim_value, g.name`, strings.TrimSpace(providerID))
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	return scanAuthProviderGroupMappings(rows)
}
// fillAuthProviderGroupMappings sets GroupMappings on every element of items,
// in place, and returns the same slice. It issues one
// ListAuthProviderGroupMappings query per provider.
//
// On the first failure it returns nil and the error. The failing element's
// GroupMappings has already been overwritten with whatever that call returned.
func (s *Store) fillAuthProviderGroupMappings(items []models.AuthProvider) ([]models.AuthProvider, error) {
	var err error
	for idx := range items {
		if items[idx].GroupMappings, err = s.ListAuthProviderGroupMappings(items[idx].ID); err != nil {
			return nil, err
		}
	}
	return items, nil
}
// insertAuthProviderGroupMappings inserts each mapping for the provider whose
// public ID is providerID. Both created_at and updated_at are set to now.
//
// ClaimValue and GroupID are trimmed, and each must be non-empty. The
// INSERT ... SELECT only produces a row when the provider exists and a group
// with that public ID has scope models.UserGroupScopeExplicit. Inserting zero
// rows is reported as an error naming the group.
//
// Processing stops at the first error. The callers in this file pass a
// transaction, so they can roll back mappings that were already inserted.
func insertAuthProviderGroupMappings(exec sqlExecutor, providerID string, mappings []models.AuthProviderGroupMapping, now int64) error {
	provider := strings.TrimSpace(providerID)
	for _, m := range mappings {
		claim := strings.TrimSpace(m.ClaimValue)
		if claim == "" {
			return errors.New("group mapping claim_value is required")
		}
		group := strings.TrimSpace(m.GroupID)
		if group == "" {
			return errors.New("group mapping group_id is required")
		}

		res, err := exec.Exec(`
INSERT INTO auth_provider_group_mappings (auth_provider_id, group_id, claim_value, created_at, updated_at)
SELECT p.id, g.id, ?, ?, ?
FROM auth_providers p
JOIN user_groups g ON g.public_id = ?
WHERE p.public_id = ? AND g.scope = ?`,
			claim, now, now, group, provider, models.UserGroupScopeExplicit)
		if err != nil {
			return err
		}
		inserted, err := res.RowsAffected()
		if err != nil {
			return err
		}
		if inserted == 0 {
			return fmt.Errorf("explicit user group not found for group mapping: %s", group)
		}
	}
	return nil
}
// ListAuthProviders returns every configured authentication provider, ordered
// by name. GroupMappings is left empty; see ListAuthProvidersWithGroupMappings
// for the populated form.
func (s *Store) ListAuthProviders() ([]models.AuthProvider, error) {
	rows, err := s.Query(`SELECT` + authProviderSelectCols + authProviderFromClause + ` ORDER BY p.name`)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var providers []models.AuthProvider
	for rows.Next() {
		var p models.AuthProvider
		if err = scanAuthProvider(rows, &p); err != nil {
			return nil, err
		}
		providers = append(providers, p)
	}
	if err = rows.Err(); err != nil {
		return nil, err
	}
	return providers, nil
}
// ListAuthProvidersWithGroupMappings behaves like ListAuthProviders, but also
// loads each provider's GroupMappings.
func (s *Store) ListAuthProvidersWithGroupMappings() ([]models.AuthProvider, error) {
	providers, err := s.ListAuthProviders()
	if err != nil {
		return nil, err
	}
	return s.fillAuthProviderGroupMappings(providers)
}
// ListEnabledAuthProviders returns only the providers with enabled = 1,
// oldest first (by created_at). GroupMappings is not loaded.
func (s *Store) ListEnabledAuthProviders() ([]models.AuthProvider, error) {
	query := `SELECT` + authProviderSelectCols + authProviderFromClause + ` WHERE p.enabled = 1 ORDER BY p.created_at`
	rows, err := s.Query(query)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var enabled []models.AuthProvider
	for rows.Next() {
		var p models.AuthProvider
		if err = scanAuthProvider(rows, &p); err != nil {
			return nil, err
		}
		enabled = append(enabled, p)
	}
	if err = rows.Err(); err != nil {
		return nil, err
	}
	return enabled, nil
}
// GetAuthProvider loads one provider by public ID, without group mappings.
// Surrounding whitespace in the ID is ignored.
//
// A missing provider surfaces the row scan's error. For a *sql.Row that is
// sql.ErrNoRows. Whatever was scanned is returned alongside any error.
func (s *Store) GetAuthProvider(id string) (models.AuthProvider, error) {
	var provider models.AuthProvider
	err := scanAuthProvider(
		s.QueryRow(`SELECT`+authProviderSelectCols+authProviderFromClause+` WHERE p.public_id = ?`, strings.TrimSpace(id)),
		&provider,
	)
	return provider, err
}
// GetAuthProviderWithGroupMappings behaves like GetAuthProvider, but also
// loads the provider's GroupMappings. If loading the mappings fails, the
// provider is still returned together with that error.
func (s *Store) GetAuthProviderWithGroupMappings(id string) (models.AuthProvider, error) {
	provider, err := s.GetAuthProvider(id)
	if err == nil {
		provider.GroupMappings, err = s.ListAuthProviderGroupMappings(provider.ID)
	}
	return provider, err
}
// CountAuthProviderUsers reports how many users reference the provider with
// the given public ID. The query runs directly on the store, outside any
// transaction.
func (s *Store) CountAuthProviderUsers(id string) (int, error) {
	users, err := countAuthProviderUsers(s, id)
	return users, err
}
// countAuthProviderUsers counts the users whose auth_provider_id points at the
// provider with the given public ID. Surrounding whitespace in the ID is
// ignored.
//
// execer may be the store itself or a transaction; this file uses both. An
// unknown provider makes the subquery NULL, so the count is 0.
func countAuthProviderUsers(execer sqlExecutor, id string) (int, error) {
	const query = `SELECT COUNT(*) FROM users WHERE auth_provider_id = (SELECT id FROM auth_providers WHERE public_id = ?)`
	var users int
	if err := execer.QueryRow(query, strings.TrimSpace(id)).Scan(&users); err != nil {
		return users, err
	}
	return users, nil
}
// CreateAuthProvider inserts a new authentication provider and its group
// mappings, then returns the stored record as re-read from the database.
//
// Before inserting, it:
//   - sets CreatedAt and UpdatedAt to the current Unix time;
//   - trims item.ID and generates a fresh ID with util.NewID when it is empty;
//   - trims item.DefaultUserGroupID;
//   - defaults LDAPUserFilter / OIDCScopes for LDAP / OIDC providers when blank;
//   - normalizes GroupSyncMode (invalid values are rejected before any write).
//
// The provider row and its mappings are written through one transaction from
// s.beginContext(ctx). It is rolled back or committed here only when this call
// owns it. A DefaultUserGroupID that matches no group resolves to NULL in the
// insert rather than producing an error.
//
// NOTE(review): the final read-back goes through s rather than tx. When ctx
// carries an outer transaction (owned == false), the not-yet-committed row may
// not be visible to that read. Confirm against beginContext.
func (s *Store) CreateAuthProvider(ctx context.Context, item models.AuthProvider) (models.AuthProvider, error) {
	var tx txExecutor
	var owned bool
	var err error
	var now int64

	now = time.Now().Unix()
	item.CreatedAt = now
	item.UpdatedAt = now

	// Every later lookup trims the public ID: the mapping insert and the
	// read-back below both do. Store it trimmed too, otherwise a padded ID
	// would be inserted verbatim and never found again.
	item.ID = strings.TrimSpace(item.ID)
	if item.ID == "" {
		item.ID, err = util.NewID()
		if err != nil {
			return item, err
		}
	}
	// Mapping group IDs are trimmed before lookup; treat the default group the same way.
	item.DefaultUserGroupID = strings.TrimSpace(item.DefaultUserGroupID)

	if item.Type == models.AuthProviderTypeLDAP && strings.TrimSpace(item.LDAPUserFilter) == "" {
		item.LDAPUserFilter = "(uid={username})"
	}
	if item.Type == models.AuthProviderTypeOIDC && strings.TrimSpace(item.OIDCScopes) == "" {
		item.OIDCScopes = "openid profile email"
	}
	item.GroupSyncMode, err = normalizeAuthProviderGroupSyncMode(item.GroupSyncMode)
	if err != nil {
		return item, err
	}

	tx, owned, err = s.beginContext(ctx)
	if err != nil {
		return item, err
	}
	_, err = tx.Exec(`INSERT INTO auth_providers (
public_id, name, type, enabled,
ldap_url, ldap_bind_dn, ldap_bind_password, ldap_user_base_dn, ldap_user_filter,
ldap_tls_insecure_skip_verify,
oidc_client_id, oidc_client_secret, oidc_authorize_url, oidc_token_url,
oidc_userinfo_url, oidc_redirect_url, oidc_scopes, oidc_tls_insecure_skip_verify,
oidc_groups_claim, oidc_admission_expr, oidc_end_session_url, oidc_post_logout_redirect, group_sync_mode,
default_user_group_id,
created_at, updated_at
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, CASE WHEN ? = '' THEN NULL ELSE (SELECT id FROM user_groups WHERE public_id = ?) END, ?, ?)`,
		item.ID, item.Name, item.Type, item.Enabled,
		item.LDAPUrl, item.LDAPBindDN, item.LDAPBindPassword, item.LDAPUserBaseDN, item.LDAPUserFilter,
		item.LDAPTLSInsecureSkipVerify,
		item.OIDCClientID, item.OIDCClientSecret, item.OIDCAuthorizeURL, item.OIDCTokenURL,
		item.OIDCUserInfoURL, item.OIDCRedirectURL, item.OIDCScopes, item.OIDCTLSInsecureSkipVerify,
		item.OIDCGroupsClaim, item.OIDCAdmissionExpr, item.OIDCEndSessionURL, item.OIDCPostLogoutRedirect, item.GroupSyncMode,
		item.DefaultUserGroupID, item.DefaultUserGroupID,
		item.CreatedAt, item.UpdatedAt,
	)
	if err != nil {
		rollbackIfOwned(tx, owned)
		return item, err
	}
	err = insertAuthProviderGroupMappings(tx, item.ID, item.GroupMappings, now)
	if err != nil {
		rollbackIfOwned(tx, owned)
		return item, err
	}
	err = commitIfOwned(tx, owned)
	if err != nil {
		return item, err
	}
	return s.GetAuthProviderWithGroupMappings(item.ID)
}
// UpdateAuthProvider overwrites the mutable columns of an existing provider
// (everything except public_id, type and created_at) and replaces its group
// mappings wholesale. It returns the record as re-read from the database.
//
// item.ID and item.DefaultUserGroupID are trimmed. LDAP / OIDC defaults are
// filled when blank, and GroupSyncMode is normalized; invalid values are
// rejected before any write. All writes share one transaction from
// s.beginContext(ctx). It is rolled back or committed here only when owned.
//
// If no provider with item.ID exists, the transaction is rolled back and
// sql.ErrNoRows is returned. That is the same error GetAuthProvider produces
// for a missing provider.
//
// NOTE(review): the final read-back goes through s rather than tx. When ctx
// carries an outer transaction (owned == false), uncommitted changes may not be
// visible to that read. Confirm against beginContext.
func (s *Store) UpdateAuthProvider(ctx context.Context, item models.AuthProvider) (models.AuthProvider, error) {
	var tx txExecutor
	var owned bool
	var exists bool
	var err error
	var now int64

	now = time.Now().Unix()
	item.UpdatedAt = now
	// Every statement below looks the provider up by its trimmed public ID, so trim once.
	item.ID = strings.TrimSpace(item.ID)
	// Mapping group IDs are trimmed before lookup; treat the default group the same way.
	item.DefaultUserGroupID = strings.TrimSpace(item.DefaultUserGroupID)

	if item.Type == models.AuthProviderTypeLDAP && strings.TrimSpace(item.LDAPUserFilter) == "" {
		item.LDAPUserFilter = "(uid={username})"
	}
	if item.Type == models.AuthProviderTypeOIDC && strings.TrimSpace(item.OIDCScopes) == "" {
		item.OIDCScopes = "openid profile email"
	}
	item.GroupSyncMode, err = normalizeAuthProviderGroupSyncMode(item.GroupSyncMode)
	if err != nil {
		return item, err
	}

	tx, owned, err = s.beginContext(ctx)
	if err != nil {
		return item, err
	}
	_, err = tx.Exec(`UPDATE auth_providers SET
name = ?, enabled = ?,
ldap_url = ?, ldap_bind_dn = ?, ldap_bind_password = ?, ldap_user_base_dn = ?, ldap_user_filter = ?,
ldap_tls_insecure_skip_verify = ?,
oidc_client_id = ?, oidc_client_secret = ?, oidc_authorize_url = ?, oidc_token_url = ?,
oidc_userinfo_url = ?, oidc_redirect_url = ?, oidc_scopes = ?, oidc_tls_insecure_skip_verify = ?,
oidc_groups_claim = ?, oidc_admission_expr = ?, oidc_end_session_url = ?, oidc_post_logout_redirect = ?, group_sync_mode = ?,
default_user_group_id = CASE WHEN ? = '' THEN NULL ELSE (SELECT id FROM user_groups WHERE public_id = ?) END,
updated_at = ?
WHERE public_id = ?`,
		item.Name, item.Enabled,
		item.LDAPUrl, item.LDAPBindDN, item.LDAPBindPassword, item.LDAPUserBaseDN, item.LDAPUserFilter,
		item.LDAPTLSInsecureSkipVerify,
		item.OIDCClientID, item.OIDCClientSecret, item.OIDCAuthorizeURL, item.OIDCTokenURL,
		item.OIDCUserInfoURL, item.OIDCRedirectURL, item.OIDCScopes, item.OIDCTLSInsecureSkipVerify,
		item.OIDCGroupsClaim, item.OIDCAdmissionExpr, item.OIDCEndSessionURL, item.OIDCPostLogoutRedirect, item.GroupSyncMode,
		item.DefaultUserGroupID, item.DefaultUserGroupID,
		item.UpdatedAt,
		item.ID,
	)
	if err != nil {
		rollbackIfOwned(tx, owned)
		return item, err
	}
	// The UPDATE above is a silent no-op for an unknown ID. Without this check,
	// the mapping insert would then fail with a misleading "group not found"
	// error, or an empty mapping list would commit nothing and only fail at the
	// read-back. The check runs after the first write, so the transaction
	// presumably already holds its write lock and the read cannot force a
	// read-to-write lock upgrade.
	exists, err = authProviderPublicIDExists(tx, item.ID)
	if err != nil {
		rollbackIfOwned(tx, owned)
		return item, err
	}
	if !exists {
		rollbackIfOwned(tx, owned)
		return item, sql.ErrNoRows
	}
	_, err = tx.Exec(`DELETE FROM auth_provider_group_mappings WHERE auth_provider_id = (SELECT id FROM auth_providers WHERE public_id = ?)`, item.ID)
	if err != nil {
		rollbackIfOwned(tx, owned)
		return item, err
	}
	err = insertAuthProviderGroupMappings(tx, item.ID, item.GroupMappings, now)
	if err != nil {
		rollbackIfOwned(tx, owned)
		return item, err
	}
	err = commitIfOwned(tx, owned)
	if err != nil {
		return item, err
	}
	return s.GetAuthProviderWithGroupMappings(item.ID)
}

// authProviderPublicIDExists reports whether an auth_providers row with the
// given (already trimmed) public ID is visible through execer.
func authProviderPublicIDExists(execer sqlExecutor, id string) (bool, error) {
	var n int
	err := execer.QueryRow(`SELECT COUNT(*) FROM auth_providers WHERE public_id = ?`, id).Scan(&n)
	return n > 0, err
}
// DeleteAuthProvider removes the provider with the given public ID. Surrounding
// whitespace in the ID is ignored.
//
// The built-in provider (BuiltinDBProviderPublicID) can never be deleted.
// Unless force is true, deletion is refused while any user still references
// the provider. The user count and the DELETE run in one transaction from
// s.beginImmediateContext(ctx), presumably so the count cannot go stale before
// the delete. Deleting an unknown ID is not an error; the DELETE simply
// matches nothing.
func (s *Store) DeleteAuthProvider(ctx context.Context, id string, force bool) error {
	if strings.TrimSpace(id) == BuiltinDBProviderPublicID {
		return errors.New("builtin provider cannot be deleted")
	}

	var tx txExecutor
	var owned bool
	var err error
	if tx, owned, err = s.beginImmediateContext(ctx); err != nil {
		return err
	}

	if !force {
		var users int
		if users, err = countAuthProviderUsers(tx, id); err != nil {
			rollbackIfOwned(tx, owned)
			return err
		}
		if users > 0 {
			rollbackIfOwned(tx, owned)
			return errors.New("provider has associated users")
		}
	}

	if _, err = tx.Exec(`DELETE FROM auth_providers WHERE public_id = ?`, strings.TrimSpace(id)); err != nil {
		rollbackIfOwned(tx, owned)
		return err
	}
	return commitIfOwned(tx, owned)
}