package db import "context" import "database/sql" import "errors" import "fmt" import "strings" import "time" import "codit/internal/models" import "codit/internal/util" const BuiltinDBProviderPublicID = "db" func scanAuthProvider(row interface { Scan(dest ...any) error }, item *models.AuthProvider) error { var defaultUserGroupID sql.NullString var err error err = row.Scan( &item.ID, &item.Name, &item.Type, &item.Enabled, &item.LDAPUrl, &item.LDAPBindDN, &item.LDAPBindPassword, &item.LDAPUserBaseDN, &item.LDAPUserFilter, &item.LDAPTLSInsecureSkipVerify, &item.OIDCClientID, &item.OIDCClientSecret, &item.OIDCAuthorizeURL, &item.OIDCTokenURL, &item.OIDCUserInfoURL, &item.OIDCRedirectURL, &item.OIDCScopes, &item.OIDCTLSInsecureSkipVerify, &item.OIDCGroupsClaim, &item.OIDCAdmissionExpr, &item.OIDCEndSessionURL, &item.OIDCPostLogoutRedirect, &item.GroupSyncMode, &defaultUserGroupID, &item.UserCount, &item.CreatedAt, &item.UpdatedAt, ) if err != nil { return err } item.DefaultUserGroupID = defaultUserGroupID.String return nil } const authProviderSelectCols = ` p.public_id, p.name, p.type, p.enabled, p.ldap_url, p.ldap_bind_dn, p.ldap_bind_password, p.ldap_user_base_dn, p.ldap_user_filter, p.ldap_tls_insecure_skip_verify, p.oidc_client_id, p.oidc_client_secret, p.oidc_authorize_url, p.oidc_token_url, p.oidc_userinfo_url, p.oidc_redirect_url, p.oidc_scopes, p.oidc_tls_insecure_skip_verify, p.oidc_groups_claim, p.oidc_admission_expr, p.oidc_end_session_url, p.oidc_post_logout_redirect, p.group_sync_mode, COALESCE(ug.public_id, ''), (SELECT COUNT(*) FROM users WHERE auth_provider_id = p.id), p.created_at, p.updated_at` const authProviderFromClause = ` FROM auth_providers p LEFT JOIN user_groups ug ON ug.id = p.default_user_group_id` func normalizeAuthProviderGroupSyncMode(value string) (string, error) { var mode string mode = strings.TrimSpace(value) if mode == "" { return models.GroupSyncModeOff, nil } if mode == models.GroupSyncModeOff || mode == models.GroupSyncModeFirstLogin || mode == models.GroupSyncModeSync { return mode, nil } return "", errors.New("invalid group_sync_mode") } func scanAuthProviderGroupMappings(rows *sql.Rows) ([]models.AuthProviderGroupMapping, error) { var items []models.AuthProviderGroupMapping var item models.AuthProviderGroupMapping var err error for rows.Next() { err = rows.Scan(&item.ClaimValue, &item.GroupID, &item.GroupName, &item.CreatedAt, &item.UpdatedAt) if err != nil { return nil, err } items = append(items, item) } return items, rows.Err() } func (s *Store) ListAuthProviderGroupMappings(providerID string) ([]models.AuthProviderGroupMapping, error) { var rows *sql.Rows var err error rows, err = s.Query(` SELECT m.claim_value, g.public_id, g.name, m.created_at, m.updated_at FROM auth_provider_group_mappings m JOIN auth_providers p ON p.id = m.auth_provider_id JOIN user_groups g ON g.id = m.group_id WHERE p.public_id = ? ORDER BY m.claim_value, g.name`, strings.TrimSpace(providerID)) if err != nil { return nil, err } defer rows.Close() return scanAuthProviderGroupMappings(rows) } func (s *Store) fillAuthProviderGroupMappings(items []models.AuthProvider) ([]models.AuthProvider, error) { var err error var i int for i = 0; i < len(items); i++ { items[i].GroupMappings, err = s.ListAuthProviderGroupMappings(items[i].ID) if err != nil { return nil, err } } return items, nil } func insertAuthProviderGroupMappings(exec sqlExecutor, providerID string, mappings []models.AuthProviderGroupMapping, now int64) error { var result sql.Result var err error var mapping models.AuthProviderGroupMapping var claimValue string var groupID string var affected int64 var i int for i = 0; i < len(mappings); i++ { mapping = mappings[i] claimValue = strings.TrimSpace(mapping.ClaimValue) groupID = strings.TrimSpace(mapping.GroupID) if claimValue == "" { return errors.New("group mapping claim_value is required") } if groupID == "" { return errors.New("group mapping group_id is required") } result, err = exec.Exec(` INSERT INTO auth_provider_group_mappings (auth_provider_id, group_id, claim_value, created_at, updated_at) SELECT p.id, g.id, ?, ?, ? FROM auth_providers p JOIN user_groups g ON g.public_id = ? WHERE p.public_id = ? AND g.scope = ?`, claimValue, now, now, groupID, strings.TrimSpace(providerID), models.UserGroupScopeExplicit) if err != nil { return err } affected, err = result.RowsAffected() if err != nil { return err } if affected == 0 { return fmt.Errorf("explicit user group not found for group mapping: %s", groupID) } } return nil } func (s *Store) ListAuthProviders() ([]models.AuthProvider, error) { var rows *sql.Rows var items []models.AuthProvider var item models.AuthProvider var err error rows, err = s.Query(`SELECT` + authProviderSelectCols + authProviderFromClause + ` ORDER BY p.name`) if err != nil { return nil, err } defer rows.Close() for rows.Next() { err = scanAuthProvider(rows, &item) if err != nil { return nil, err } items = append(items, item) } err = rows.Err() if err != nil { return nil, err } return items, nil } func (s *Store) ListAuthProvidersWithGroupMappings() ([]models.AuthProvider, error) { var items []models.AuthProvider var err error items, err = s.ListAuthProviders() if err != nil { return nil, err } return s.fillAuthProviderGroupMappings(items) } func (s *Store) ListEnabledAuthProviders() ([]models.AuthProvider, error) { var rows *sql.Rows var items []models.AuthProvider var item models.AuthProvider var err error rows, err = s.Query(`SELECT` + authProviderSelectCols + authProviderFromClause + ` WHERE p.enabled = 1 ORDER BY p.created_at`) if err != nil { return nil, err } defer rows.Close() for rows.Next() { err = scanAuthProvider(rows, &item) if err != nil { return nil, err } items = append(items, item) } err = rows.Err() if err != nil { return nil, err } return items, nil } func (s *Store) GetAuthProvider(id string) (models.AuthProvider, error) { var row *sql.Row var item models.AuthProvider var err error row = s.QueryRow(`SELECT`+authProviderSelectCols+authProviderFromClause+` WHERE p.public_id = ?`, strings.TrimSpace(id)) err = scanAuthProvider(row, &item) if err != nil { return item, err } return item, nil } func (s *Store) GetAuthProviderWithGroupMappings(id string) (models.AuthProvider, error) { var item models.AuthProvider var err error item, err = s.GetAuthProvider(id) if err != nil { return item, err } item.GroupMappings, err = s.ListAuthProviderGroupMappings(item.ID) return item, err } func (s *Store) CountAuthProviderUsers(id string) (int, error) { return countAuthProviderUsers(s, id) } func countAuthProviderUsers(execer sqlExecutor, id string) (int, error) { var count int var err error err = execer.QueryRow(`SELECT COUNT(*) FROM users WHERE auth_provider_id = (SELECT id FROM auth_providers WHERE public_id = ?)`, strings.TrimSpace(id)).Scan(&count) return count, err } func (s *Store) CreateAuthProvider(ctx context.Context, item models.AuthProvider) (models.AuthProvider, error) { var tx txExecutor var err error var owned bool var now int64 now = time.Now().Unix() item.CreatedAt = now item.UpdatedAt = now if strings.TrimSpace(item.ID) == "" { item.ID, err = util.NewID() if err != nil { return item, err } } if item.Type == models.AuthProviderTypeLDAP && strings.TrimSpace(item.LDAPUserFilter) == "" { item.LDAPUserFilter = "(uid={username})" } if item.Type == models.AuthProviderTypeOIDC && strings.TrimSpace(item.OIDCScopes) == "" { item.OIDCScopes = "openid profile email" } item.GroupSyncMode, err = normalizeAuthProviderGroupSyncMode(item.GroupSyncMode) if err != nil { return item, err } tx, owned, err = s.beginContext(ctx) if err != nil { return item, err } _, err = tx.Exec(`INSERT INTO auth_providers ( public_id, name, type, enabled, ldap_url, ldap_bind_dn, ldap_bind_password, ldap_user_base_dn, ldap_user_filter, ldap_tls_insecure_skip_verify, oidc_client_id, oidc_client_secret, oidc_authorize_url, oidc_token_url, oidc_userinfo_url, oidc_redirect_url, oidc_scopes, oidc_tls_insecure_skip_verify, oidc_groups_claim, oidc_admission_expr, oidc_end_session_url, oidc_post_logout_redirect, group_sync_mode, default_user_group_id, created_at, updated_at ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, CASE WHEN ? = '' THEN NULL ELSE (SELECT id FROM user_groups WHERE public_id = ?) END, ?, ?)`, item.ID, item.Name, item.Type, item.Enabled, item.LDAPUrl, item.LDAPBindDN, item.LDAPBindPassword, item.LDAPUserBaseDN, item.LDAPUserFilter, item.LDAPTLSInsecureSkipVerify, item.OIDCClientID, item.OIDCClientSecret, item.OIDCAuthorizeURL, item.OIDCTokenURL, item.OIDCUserInfoURL, item.OIDCRedirectURL, item.OIDCScopes, item.OIDCTLSInsecureSkipVerify, item.OIDCGroupsClaim, item.OIDCAdmissionExpr, item.OIDCEndSessionURL, item.OIDCPostLogoutRedirect, item.GroupSyncMode, item.DefaultUserGroupID, item.DefaultUserGroupID, item.CreatedAt, item.UpdatedAt, ) if err != nil { rollbackIfOwned(tx, owned) return item, err } err = insertAuthProviderGroupMappings(tx, item.ID, item.GroupMappings, now) if err != nil { rollbackIfOwned(tx, owned) return item, err } err = commitIfOwned(tx, owned) if err != nil { return item, err } return s.GetAuthProviderWithGroupMappings(item.ID) } func (s *Store) UpdateAuthProvider(ctx context.Context, item models.AuthProvider) (models.AuthProvider, error) { var tx txExecutor var err error var owned bool var now int64 now = time.Now().Unix() item.UpdatedAt = now if item.Type == models.AuthProviderTypeLDAP && strings.TrimSpace(item.LDAPUserFilter) == "" { item.LDAPUserFilter = "(uid={username})" } if item.Type == models.AuthProviderTypeOIDC && strings.TrimSpace(item.OIDCScopes) == "" { item.OIDCScopes = "openid profile email" } item.GroupSyncMode, err = normalizeAuthProviderGroupSyncMode(item.GroupSyncMode) if err != nil { return item, err } tx, owned, err = s.beginContext(ctx) if err != nil { return item, err } _, err = tx.Exec(`UPDATE auth_providers SET name = ?, enabled = ?, ldap_url = ?, ldap_bind_dn = ?, ldap_bind_password = ?, ldap_user_base_dn = ?, ldap_user_filter = ?, ldap_tls_insecure_skip_verify = ?, oidc_client_id = ?, oidc_client_secret = ?, oidc_authorize_url = ?, oidc_token_url = ?, oidc_userinfo_url = ?, oidc_redirect_url = ?, oidc_scopes = ?, oidc_tls_insecure_skip_verify = ?, oidc_groups_claim = ?, oidc_admission_expr = ?, oidc_end_session_url = ?, oidc_post_logout_redirect = ?, group_sync_mode = ?, default_user_group_id = CASE WHEN ? = '' THEN NULL ELSE (SELECT id FROM user_groups WHERE public_id = ?) END, updated_at = ? WHERE public_id = ?`, item.Name, item.Enabled, item.LDAPUrl, item.LDAPBindDN, item.LDAPBindPassword, item.LDAPUserBaseDN, item.LDAPUserFilter, item.LDAPTLSInsecureSkipVerify, item.OIDCClientID, item.OIDCClientSecret, item.OIDCAuthorizeURL, item.OIDCTokenURL, item.OIDCUserInfoURL, item.OIDCRedirectURL, item.OIDCScopes, item.OIDCTLSInsecureSkipVerify, item.OIDCGroupsClaim, item.OIDCAdmissionExpr, item.OIDCEndSessionURL, item.OIDCPostLogoutRedirect, item.GroupSyncMode, item.DefaultUserGroupID, item.DefaultUserGroupID, item.UpdatedAt, strings.TrimSpace(item.ID), ) if err != nil { rollbackIfOwned(tx, owned) return item, err } _, err = tx.Exec(`DELETE FROM auth_provider_group_mappings WHERE auth_provider_id = (SELECT id FROM auth_providers WHERE public_id = ?)`, strings.TrimSpace(item.ID)) if err != nil { rollbackIfOwned(tx, owned) return item, err } err = insertAuthProviderGroupMappings(tx, item.ID, item.GroupMappings, now) if err != nil { rollbackIfOwned(tx, owned) return item, err } err = commitIfOwned(tx, owned) if err != nil { return item, err } return s.GetAuthProviderWithGroupMappings(item.ID) } func (s *Store) DeleteAuthProvider(ctx context.Context, id string, force bool) error { var count int var err error var tx txExecutor var owned bool if strings.TrimSpace(id) == BuiltinDBProviderPublicID { return errors.New("builtin provider cannot be deleted") } tx, owned, err = s.beginImmediateContext(ctx) if err != nil { return err } if !force { count, err = countAuthProviderUsers(tx, id) if err != nil { rollbackIfOwned(tx, owned) return err } if count > 0 { rollbackIfOwned(tx, owned) return errors.New("provider has associated users") } } _, err = tx.Exec(`DELETE FROM auth_providers WHERE public_id = ?`, strings.TrimSpace(id)) if err != nil { rollbackIfOwned(tx, owned) return err } return commitIfOwned(tx, owned) }