7f3a8b24ff
Defined some more constants and placed them in models/constannts.go
424 lines
13 KiB
Go
424 lines
13 KiB
Go
package db
|
|
|
|
import "context"
|
|
import "database/sql"
|
|
import "errors"
|
|
import "fmt"
|
|
import "strings"
|
|
import "time"
|
|
|
|
import "codit/internal/models"
|
|
import "codit/internal/util"
|
|
|
|
|
|
const (
	// BuiltinDBProviderPublicID is the public ID reserved for the built-in auth
	// provider. DeleteAuthProvider refuses to delete the provider with this ID.
	BuiltinDBProviderPublicID = "db"
)
|
|
|
|
|
|
func scanAuthProvider(row interface {
|
|
Scan(dest ...any) error
|
|
}, item *models.AuthProvider) error {
|
|
var defaultUserGroupID sql.NullString
|
|
var err error
|
|
err = row.Scan(
|
|
&item.ID,
|
|
&item.Name,
|
|
&item.Type,
|
|
&item.Enabled,
|
|
&item.LDAPUrl,
|
|
&item.LDAPBindDN,
|
|
&item.LDAPBindPassword,
|
|
&item.LDAPUserBaseDN,
|
|
&item.LDAPUserFilter,
|
|
&item.LDAPTLSInsecureSkipVerify,
|
|
&item.OIDCClientID,
|
|
&item.OIDCClientSecret,
|
|
&item.OIDCAuthorizeURL,
|
|
&item.OIDCTokenURL,
|
|
&item.OIDCUserInfoURL,
|
|
&item.OIDCRedirectURL,
|
|
&item.OIDCScopes,
|
|
&item.OIDCTLSInsecureSkipVerify,
|
|
&item.OIDCGroupsClaim,
|
|
&item.OIDCAdmissionExpr,
|
|
&item.OIDCEndSessionURL,
|
|
&item.OIDCPostLogoutRedirect,
|
|
&item.GroupSyncMode,
|
|
&defaultUserGroupID,
|
|
&item.UserCount,
|
|
&item.CreatedAt,
|
|
&item.UpdatedAt,
|
|
)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
item.DefaultUserGroupID = defaultUserGroupID.String
|
|
return nil
|
|
}
|
|
|
|
const (
	// authProviderSelectCols is the column list shared by every auth-provider
	// SELECT in this file. It starts with a newline so it can be appended
	// directly after the SELECT keyword. Its order must match the Scan
	// destinations in scanAuthProvider. The default user group is exposed as
	// its public ID (empty string when unset via COALESCE). The user count is
	// computed with a correlated subquery on users.auth_provider_id.
	authProviderSelectCols = `
p.public_id, p.name, p.type, p.enabled,
p.ldap_url, p.ldap_bind_dn, p.ldap_bind_password, p.ldap_user_base_dn, p.ldap_user_filter,
p.ldap_tls_insecure_skip_verify,
p.oidc_client_id, p.oidc_client_secret, p.oidc_authorize_url, p.oidc_token_url,
p.oidc_userinfo_url, p.oidc_redirect_url, p.oidc_scopes, p.oidc_tls_insecure_skip_verify,
p.oidc_groups_claim, p.oidc_admission_expr, p.oidc_end_session_url, p.oidc_post_logout_redirect, p.group_sync_mode,
COALESCE(ug.public_id, ''),
(SELECT COUNT(*) FROM users WHERE auth_provider_id = p.id),
p.created_at, p.updated_at`

	// authProviderFromClause names the auth_providers table (alias p). It
	// left-joins the optional default user group (alias ug) so that providers
	// without one are still returned.
	authProviderFromClause = `
FROM auth_providers p
LEFT JOIN user_groups ug ON ug.id = p.default_user_group_id`
)
|
|
|
|
func normalizeAuthProviderGroupSyncMode(value string) (string, error) {
|
|
var mode string
|
|
|
|
mode = strings.TrimSpace(value)
|
|
if mode == "" {
|
|
return models.GroupSyncModeOff, nil
|
|
}
|
|
if mode == models.GroupSyncModeOff || mode == models.GroupSyncModeFirstLogin || mode == models.GroupSyncModeSync {
|
|
return mode, nil
|
|
}
|
|
return "", errors.New("invalid group_sync_mode")
|
|
}
|
|
|
|
func scanAuthProviderGroupMappings(rows *sql.Rows) ([]models.AuthProviderGroupMapping, error) {
|
|
var items []models.AuthProviderGroupMapping
|
|
var item models.AuthProviderGroupMapping
|
|
var err error
|
|
|
|
for rows.Next() {
|
|
err = rows.Scan(&item.ClaimValue, &item.GroupID, &item.GroupName, &item.CreatedAt, &item.UpdatedAt)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
items = append(items, item)
|
|
}
|
|
return items, rows.Err()
|
|
}
|
|
|
|
func (s *Store) ListAuthProviderGroupMappings(providerID string) ([]models.AuthProviderGroupMapping, error) {
|
|
var rows *sql.Rows
|
|
var err error
|
|
|
|
rows, err = s.Query(`
|
|
SELECT m.claim_value, g.public_id, g.name, m.created_at, m.updated_at
|
|
FROM auth_provider_group_mappings m
|
|
JOIN auth_providers p ON p.id = m.auth_provider_id
|
|
JOIN user_groups g ON g.id = m.group_id
|
|
WHERE p.public_id = ?
|
|
ORDER BY m.claim_value, g.name`, strings.TrimSpace(providerID))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
return scanAuthProviderGroupMappings(rows)
|
|
}
|
|
|
|
func (s *Store) fillAuthProviderGroupMappings(items []models.AuthProvider) ([]models.AuthProvider, error) {
|
|
var err error
|
|
var i int
|
|
|
|
for i = 0; i < len(items); i++ {
|
|
items[i].GroupMappings, err = s.ListAuthProviderGroupMappings(items[i].ID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
return items, nil
|
|
}
|
|
|
|
func insertAuthProviderGroupMappings(exec sqlExecutor, providerID string, mappings []models.AuthProviderGroupMapping, now int64) error {
|
|
var result sql.Result
|
|
var err error
|
|
var mapping models.AuthProviderGroupMapping
|
|
var claimValue string
|
|
var groupID string
|
|
var affected int64
|
|
var i int
|
|
|
|
for i = 0; i < len(mappings); i++ {
|
|
mapping = mappings[i]
|
|
claimValue = strings.TrimSpace(mapping.ClaimValue)
|
|
groupID = strings.TrimSpace(mapping.GroupID)
|
|
if claimValue == "" {
|
|
return errors.New("group mapping claim_value is required")
|
|
}
|
|
if groupID == "" {
|
|
return errors.New("group mapping group_id is required")
|
|
}
|
|
result, err = exec.Exec(`
|
|
INSERT INTO auth_provider_group_mappings (auth_provider_id, group_id, claim_value, created_at, updated_at)
|
|
SELECT p.id, g.id, ?, ?, ?
|
|
FROM auth_providers p
|
|
JOIN user_groups g ON g.public_id = ?
|
|
WHERE p.public_id = ? AND g.scope = ?`,
|
|
claimValue, now, now, groupID, strings.TrimSpace(providerID), models.UserGroupScopeExplicit)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
affected, err = result.RowsAffected()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if affected == 0 {
|
|
return fmt.Errorf("explicit user group not found for group mapping: %s", groupID)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *Store) ListAuthProviders() ([]models.AuthProvider, error) {
|
|
var rows *sql.Rows
|
|
var items []models.AuthProvider
|
|
var item models.AuthProvider
|
|
var err error
|
|
|
|
rows, err = s.Query(`SELECT` + authProviderSelectCols + authProviderFromClause + ` ORDER BY p.name`)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
for rows.Next() {
|
|
err = scanAuthProvider(rows, &item)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
items = append(items, item)
|
|
}
|
|
err = rows.Err()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return items, nil
|
|
}
|
|
|
|
func (s *Store) ListAuthProvidersWithGroupMappings() ([]models.AuthProvider, error) {
|
|
var items []models.AuthProvider
|
|
var err error
|
|
|
|
items, err = s.ListAuthProviders()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return s.fillAuthProviderGroupMappings(items)
|
|
}
|
|
|
|
func (s *Store) ListEnabledAuthProviders() ([]models.AuthProvider, error) {
|
|
var rows *sql.Rows
|
|
var items []models.AuthProvider
|
|
var item models.AuthProvider
|
|
var err error
|
|
|
|
rows, err = s.Query(`SELECT` + authProviderSelectCols + authProviderFromClause + ` WHERE p.enabled = 1 ORDER BY p.created_at`)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
for rows.Next() {
|
|
err = scanAuthProvider(rows, &item)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
items = append(items, item)
|
|
}
|
|
err = rows.Err()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return items, nil
|
|
}
|
|
|
|
func (s *Store) GetAuthProvider(id string) (models.AuthProvider, error) {
|
|
var row *sql.Row
|
|
var item models.AuthProvider
|
|
var err error
|
|
|
|
row = s.QueryRow(`SELECT`+authProviderSelectCols+authProviderFromClause+` WHERE p.public_id = ?`, strings.TrimSpace(id))
|
|
err = scanAuthProvider(row, &item)
|
|
if err != nil {
|
|
return item, err
|
|
}
|
|
return item, nil
|
|
}
|
|
|
|
func (s *Store) GetAuthProviderWithGroupMappings(id string) (models.AuthProvider, error) {
|
|
var item models.AuthProvider
|
|
var err error
|
|
|
|
item, err = s.GetAuthProvider(id)
|
|
if err != nil {
|
|
return item, err
|
|
}
|
|
item.GroupMappings, err = s.ListAuthProviderGroupMappings(item.ID)
|
|
return item, err
|
|
}
|
|
|
|
func (s *Store) CountAuthProviderUsers(id string) (int, error) {
|
|
return countAuthProviderUsers(s, id)
|
|
}
|
|
|
|
func countAuthProviderUsers(execer sqlExecutor, id string) (int, error) {
|
|
var count int
|
|
var err error
|
|
|
|
err = execer.QueryRow(`SELECT COUNT(*) FROM users WHERE auth_provider_id = (SELECT id FROM auth_providers WHERE public_id = ?)`, strings.TrimSpace(id)).Scan(&count)
|
|
return count, err
|
|
}
|
|
|
|
func (s *Store) CreateAuthProvider(ctx context.Context, item models.AuthProvider) (models.AuthProvider, error) {
|
|
var tx txExecutor
|
|
var err error
|
|
var owned bool
|
|
var now int64
|
|
|
|
now = time.Now().Unix()
|
|
item.CreatedAt = now
|
|
item.UpdatedAt = now
|
|
|
|
if strings.TrimSpace(item.ID) == "" {
|
|
item.ID, err = util.NewID()
|
|
if err != nil { return item, err }
|
|
}
|
|
if item.Type == models.AuthProviderTypeLDAP && strings.TrimSpace(item.LDAPUserFilter) == "" {
|
|
item.LDAPUserFilter = "(uid={username})"
|
|
}
|
|
if item.Type == models.AuthProviderTypeOIDC && strings.TrimSpace(item.OIDCScopes) == "" {
|
|
item.OIDCScopes = "openid profile email"
|
|
}
|
|
item.GroupSyncMode, err = normalizeAuthProviderGroupSyncMode(item.GroupSyncMode)
|
|
if err != nil { return item, err }
|
|
|
|
tx, owned, err = s.beginContext(ctx)
|
|
if err != nil { return item, err }
|
|
|
|
_, err = tx.Exec(`INSERT INTO auth_providers (
|
|
public_id, name, type, enabled,
|
|
ldap_url, ldap_bind_dn, ldap_bind_password, ldap_user_base_dn, ldap_user_filter,
|
|
ldap_tls_insecure_skip_verify,
|
|
oidc_client_id, oidc_client_secret, oidc_authorize_url, oidc_token_url,
|
|
oidc_userinfo_url, oidc_redirect_url, oidc_scopes, oidc_tls_insecure_skip_verify,
|
|
oidc_groups_claim, oidc_admission_expr, oidc_end_session_url, oidc_post_logout_redirect, group_sync_mode,
|
|
default_user_group_id,
|
|
created_at, updated_at
|
|
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, CASE WHEN ? = '' THEN NULL ELSE (SELECT id FROM user_groups WHERE public_id = ?) END, ?, ?)`,
|
|
item.ID, item.Name, item.Type, item.Enabled,
|
|
item.LDAPUrl, item.LDAPBindDN, item.LDAPBindPassword, item.LDAPUserBaseDN, item.LDAPUserFilter,
|
|
item.LDAPTLSInsecureSkipVerify,
|
|
item.OIDCClientID, item.OIDCClientSecret, item.OIDCAuthorizeURL, item.OIDCTokenURL,
|
|
item.OIDCUserInfoURL, item.OIDCRedirectURL, item.OIDCScopes, item.OIDCTLSInsecureSkipVerify,
|
|
item.OIDCGroupsClaim, item.OIDCAdmissionExpr, item.OIDCEndSessionURL, item.OIDCPostLogoutRedirect, item.GroupSyncMode,
|
|
item.DefaultUserGroupID, item.DefaultUserGroupID,
|
|
item.CreatedAt, item.UpdatedAt,
|
|
)
|
|
if err != nil {
|
|
rollbackIfOwned(tx, owned)
|
|
return item, err
|
|
}
|
|
err = insertAuthProviderGroupMappings(tx, item.ID, item.GroupMappings, now)
|
|
if err != nil {
|
|
rollbackIfOwned(tx, owned)
|
|
return item, err
|
|
}
|
|
err = commitIfOwned(tx, owned)
|
|
if err != nil {
|
|
return item, err
|
|
}
|
|
return s.GetAuthProviderWithGroupMappings(item.ID)
|
|
}
|
|
|
|
func (s *Store) UpdateAuthProvider(ctx context.Context, item models.AuthProvider) (models.AuthProvider, error) {
|
|
var tx txExecutor
|
|
var err error
|
|
var owned bool
|
|
var now int64
|
|
|
|
now = time.Now().Unix()
|
|
item.UpdatedAt = now
|
|
|
|
if item.Type == models.AuthProviderTypeLDAP && strings.TrimSpace(item.LDAPUserFilter) == "" {
|
|
item.LDAPUserFilter = "(uid={username})"
|
|
}
|
|
if item.Type == models.AuthProviderTypeOIDC && strings.TrimSpace(item.OIDCScopes) == "" {
|
|
item.OIDCScopes = "openid profile email"
|
|
}
|
|
item.GroupSyncMode, err = normalizeAuthProviderGroupSyncMode(item.GroupSyncMode)
|
|
if err != nil {
|
|
return item, err
|
|
}
|
|
|
|
tx, owned, err = s.beginContext(ctx)
|
|
if err != nil { return item, err }
|
|
|
|
_, err = tx.Exec(`UPDATE auth_providers SET
|
|
name = ?, enabled = ?,
|
|
ldap_url = ?, ldap_bind_dn = ?, ldap_bind_password = ?, ldap_user_base_dn = ?, ldap_user_filter = ?,
|
|
ldap_tls_insecure_skip_verify = ?,
|
|
oidc_client_id = ?, oidc_client_secret = ?, oidc_authorize_url = ?, oidc_token_url = ?,
|
|
oidc_userinfo_url = ?, oidc_redirect_url = ?, oidc_scopes = ?, oidc_tls_insecure_skip_verify = ?,
|
|
oidc_groups_claim = ?, oidc_admission_expr = ?, oidc_end_session_url = ?, oidc_post_logout_redirect = ?, group_sync_mode = ?,
|
|
default_user_group_id = CASE WHEN ? = '' THEN NULL ELSE (SELECT id FROM user_groups WHERE public_id = ?) END,
|
|
updated_at = ?
|
|
WHERE public_id = ?`,
|
|
item.Name, item.Enabled,
|
|
item.LDAPUrl, item.LDAPBindDN, item.LDAPBindPassword, item.LDAPUserBaseDN, item.LDAPUserFilter,
|
|
item.LDAPTLSInsecureSkipVerify,
|
|
item.OIDCClientID, item.OIDCClientSecret, item.OIDCAuthorizeURL, item.OIDCTokenURL,
|
|
item.OIDCUserInfoURL, item.OIDCRedirectURL, item.OIDCScopes, item.OIDCTLSInsecureSkipVerify,
|
|
item.OIDCGroupsClaim, item.OIDCAdmissionExpr, item.OIDCEndSessionURL, item.OIDCPostLogoutRedirect, item.GroupSyncMode,
|
|
item.DefaultUserGroupID, item.DefaultUserGroupID,
|
|
item.UpdatedAt,
|
|
strings.TrimSpace(item.ID),
|
|
)
|
|
if err != nil {
|
|
rollbackIfOwned(tx, owned)
|
|
return item, err
|
|
}
|
|
_, err = tx.Exec(`DELETE FROM auth_provider_group_mappings WHERE auth_provider_id = (SELECT id FROM auth_providers WHERE public_id = ?)`, strings.TrimSpace(item.ID))
|
|
if err != nil {
|
|
rollbackIfOwned(tx, owned)
|
|
return item, err
|
|
}
|
|
err = insertAuthProviderGroupMappings(tx, item.ID, item.GroupMappings, now)
|
|
if err != nil {
|
|
rollbackIfOwned(tx, owned)
|
|
return item, err
|
|
}
|
|
err = commitIfOwned(tx, owned)
|
|
if err != nil {
|
|
return item, err
|
|
}
|
|
return s.GetAuthProviderWithGroupMappings(item.ID)
|
|
}
|
|
|
|
func (s *Store) DeleteAuthProvider(ctx context.Context, id string, force bool) error {
|
|
var count int
|
|
var err error
|
|
var tx txExecutor
|
|
var owned bool
|
|
|
|
if strings.TrimSpace(id) == BuiltinDBProviderPublicID {
|
|
return errors.New("builtin provider cannot be deleted")
|
|
}
|
|
tx, owned, err = s.beginImmediateContext(ctx)
|
|
if err != nil { return err }
|
|
if !force {
|
|
count, err = countAuthProviderUsers(tx, id)
|
|
if err != nil {
|
|
rollbackIfOwned(tx, owned)
|
|
return err
|
|
}
|
|
if count > 0 {
|
|
rollbackIfOwned(tx, owned)
|
|
return errors.New("provider has associated users")
|
|
}
|
|
}
|
|
_, err = tx.Exec(`DELETE FROM auth_providers WHERE public_id = ?`, strings.TrimSpace(id))
|
|
if err != nil {
|
|
rollbackIfOwned(tx, owned)
|
|
return err
|
|
}
|
|
return commitIfOwned(tx, owned)
|
|
}
|