Files
codit/backend/internal/db/groups.go

331 lines
8.8 KiB
Go

package db
import "database/sql"
import "errors"
import "strings"
import "time"
import "codit/internal/models"
import "codit/internal/util"
// ListUserGroups returns every user group, ordered by name.
// The result is nil (not an empty slice) when no groups exist.
func (s *Store) ListUserGroups() ([]models.UserGroup, error) {
	rows, err := s.Query(`SELECT public_id, name, description, disabled, created_at, updated_at FROM user_groups ORDER BY name`)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var groups []models.UserGroup
	for rows.Next() {
		var g models.UserGroup
		if scanErr := rows.Scan(&g.ID, &g.Name, &g.Description, &g.Disabled, &g.CreatedAt, &g.UpdatedAt); scanErr != nil {
			return nil, scanErr
		}
		groups = append(groups, g)
	}
	// Surface any error that terminated iteration early.
	if iterErr := rows.Err(); iterErr != nil {
		return nil, iterErr
	}
	return groups, nil
}
// GetUserGroup loads one user group by its public ID; surrounding whitespace
// in groupID is ignored. When no row matches, the error is the one produced
// by Scan on the returned row (sql.ErrNoRows for a *sql.Row).
func (s *Store) GetUserGroup(groupID string) (models.UserGroup, error) {
	var group models.UserGroup
	err := s.QueryRow(`SELECT public_id, name, description, disabled, created_at, updated_at
FROM user_groups WHERE public_id = ?`, strings.TrimSpace(groupID)).
		Scan(&group.ID, &group.Name, &group.Description, &group.Disabled, &group.CreatedAt, &group.UpdatedAt)
	return group, err
}
// CreateUserGroup inserts a new user group and returns it as stored.
//
// ID, Name and Description are whitespace-trimmed before the insert. The
// returned value carries those trimmed values, so it matches the persisted row.
// When the trimmed ID is empty, a fresh public ID is generated with
// util.NewID. CreatedAt and UpdatedAt are both set to the current UTC Unix
// time. Errors from ID generation or the INSERT are returned unchanged.
func (s *Store) CreateUserGroup(item models.UserGroup) (models.UserGroup, error) {
	// Normalize once so the row written, the value returned and later lookups
	// all agree. GetUserGroup trims its argument, so an ID stored with
	// surrounding whitespace could never be found again.
	item.ID = strings.TrimSpace(item.ID)
	item.Name = strings.TrimSpace(item.Name)
	item.Description = strings.TrimSpace(item.Description)
	if item.ID == "" {
		id, err := util.NewID()
		if err != nil {
			return item, err
		}
		item.ID = id
	}
	now := time.Now().UTC().Unix()
	item.CreatedAt = now
	item.UpdatedAt = now
	_, err := s.Exec(`INSERT INTO user_groups (public_id, name, description, disabled, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?)`,
		item.ID,
		item.Name,
		item.Description,
		item.Disabled,
		item.CreatedAt,
		item.UpdatedAt,
	)
	if err != nil {
		return item, err
	}
	return item, nil
}
// UpdateUserGroup overwrites the name, description and disabled flag of the
// group whose public ID is item.ID. It also stamps updated_at with the current
// UTC Unix time. Name, description and ID are whitespace-trimmed for the
// UPDATE.
//
// On success the group is re-read through GetUserGroup and that result is
// returned. If no row matches item.ID, the error comes from that lookup.
func (s *Store) UpdateUserGroup(item models.UserGroup) (models.UserGroup, error) {
	item.UpdatedAt = time.Now().UTC().Unix()
	if _, execErr := s.Exec(`UPDATE user_groups
SET name = ?, description = ?, disabled = ?, updated_at = ?
WHERE public_id = ?`,
		strings.TrimSpace(item.Name),
		strings.TrimSpace(item.Description),
		item.Disabled,
		item.UpdatedAt,
		strings.TrimSpace(item.ID),
	); execErr != nil {
		return item, execErr
	}
	return s.GetUserGroup(item.ID)
}
// DeleteUserGroup removes the group identified by groupID (whitespace-trimmed)
// together with the rows that reference it:
//   - its memberships
//   - project role bindings
//   - SSH principal grant targets
//   - subject permissions
//
// The deletes run on the transaction returned by s.begin. On the first
// failing statement the transaction is passed to rollbackIfOwned and that
// error is returned. Otherwise commitIfOwned's result is returned. Deleting a
// group that does not exist is not an error, because affected row counts are
// not checked.
func (s *Store) DeleteUserGroup(groupID string) error {
	id := strings.TrimSpace(groupID)
	// Dependent rows go first and the group row goes last. Memberships are
	// keyed by the group's internal id, so they must be removed while the
	// user_groups row still exists for the subquery to resolve. Left behind,
	// they would point at an id the database may hand to a later group (e.g.
	// SQLite can reuse rowids without AUTOINCREMENT).
	// NOTE(review): if the schema already cascades this, the extra DELETE is a
	// harmless no-op.
	statements := []string{
		`DELETE FROM user_group_members WHERE group_id = (SELECT id FROM user_groups WHERE public_id = ?)`,
		`DELETE FROM project_role_bindings WHERE subject_type = 'group' AND subject_public_id = ?`,
		`DELETE FROM ssh_principal_grant_targets WHERE target_type = 'group' AND target_public_id = ?`,
		`DELETE FROM subject_permissions WHERE subject_type = 'group' AND subject_public_id = ?`,
		`DELETE FROM user_groups WHERE public_id = ?`,
	}
	tx, owned, err := s.begin()
	if err != nil {
		return err
	}
	for _, stmt := range statements {
		if _, execErr := tx.Exec(stmt, id); execErr != nil {
			rollbackIfOwned(tx, owned)
			return execErr
		}
	}
	return commitIfOwned(tx, owned)
}
// ListUserGroupMembers returns the members of the group identified by groupID
// (whitespace-trimmed), ordered by username. Each entry carries:
//   - the group's public ID
//   - the user's public ID
//   - the username
//   - the user's display_name
//   - the membership's created_at
//
// The result is nil when the group has no members or does not exist.
func (s *Store) ListUserGroupMembers(groupID string) ([]models.UserGroupMember, error) {
	const query = `SELECT g.public_id, u.public_id, u.username, u.display_name, m.created_at
FROM user_group_members m
JOIN user_groups g ON g.id = m.group_id
JOIN users u ON u.id = m.user_id
WHERE g.public_id = ?
ORDER BY u.username`
	rows, err := s.Query(query, strings.TrimSpace(groupID))
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var members []models.UserGroupMember
	for rows.Next() {
		var m models.UserGroupMember
		if scanErr := rows.Scan(&m.GroupID, &m.UserID, &m.Username, &m.UserRealName, &m.CreatedAt); scanErr != nil {
			return nil, scanErr
		}
		members = append(members, m)
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, iterErr
	}
	return members, nil
}
// AddUserGroupMember makes the user identified by userID a member of the group
// identified by groupID (both public IDs, whitespace-trimmed). The membership
// is stamped with the current UTC Unix time.
//
// The INSERT uses OR IGNORE, so re-adding an existing membership does not
// produce an error. Naming a group or user that does not exist is also a
// silent no-op.
func (s *Store) AddUserGroupMember(groupID string, userID string) error {
	now := time.Now().UTC().Unix()
	// Resolving both internal ids through INSERT ... SELECT produces no row at
	// all when either lookup fails. Scalar subqueries in VALUES would instead
	// yield NULL ids and depend on the schema rejecting them.
	_, err := s.Exec(`INSERT OR IGNORE INTO user_group_members (group_id, user_id, created_at)
SELECT g.id, u.id, ?
FROM user_groups g, users u
WHERE g.public_id = ? AND u.public_id = ?`,
		now, strings.TrimSpace(groupID), strings.TrimSpace(userID))
	return err
}
// RemoveUserGroupMember deletes the membership linking the user userID to the
// group groupID (both public IDs, whitespace-trimmed). Removing a membership
// that does not exist is not an error.
func (s *Store) RemoveUserGroupMember(groupID string, userID string) error {
	group := strings.TrimSpace(groupID)
	user := strings.TrimSpace(userID)
	_, err := s.Exec(`DELETE FROM user_group_members
WHERE group_id = (SELECT id FROM user_groups WHERE public_id = ?)
AND user_id = (SELECT id FROM users WHERE public_id = ?)`, group, user)
	return err
}
// ListProjectGroupRoles returns every group-subject role binding on the
// project identified by projectID (whitespace-trimmed), ordered by group name.
// Only bindings whose subject_public_id matches an existing user group are
// included. The result is nil when there are none.
func (s *Store) ListProjectGroupRoles(projectID string) ([]models.ProjectGroupRole, error) {
	const query = `SELECT p.public_id, g.public_id, b.role, b.created_at
FROM project_role_bindings b
JOIN projects p ON p.id = b.project_id
JOIN user_groups g ON g.public_id = b.subject_public_id
WHERE p.public_id = ? AND b.subject_type = 'group'
ORDER BY g.name`
	rows, err := s.Query(query, strings.TrimSpace(projectID))
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var bindings []models.ProjectGroupRole
	for rows.Next() {
		var b models.ProjectGroupRole
		if scanErr := rows.Scan(&b.ProjectID, &b.GroupID, &b.Role, &b.CreatedAt); scanErr != nil {
			return nil, scanErr
		}
		bindings = append(bindings, b)
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, iterErr
	}
	return bindings, nil
}
// UpsertProjectGroupRole grants the group groupID the given role on the
// project projectID; all three values are whitespace-trimmed.
//
// If a binding for this project and group already exists, its role and
// updated_at are replaced. Otherwise a new binding is inserted with
// created_at = updated_at = the current UTC Unix time.
//
// Errors:
//   - "project not found" when projectID matches no project
//   - "group not found" when groupID matches no user group
//   - any database error, returned as-is
//
// The returned binding's CreatedAt is read back from the database. When an
// existing binding was updated, it is therefore the original creation time,
// not the time of this call.
func (s *Store) UpsertProjectGroupRole(projectID string, groupID string, role string) (models.ProjectGroupRole, error) {
	project := strings.TrimSpace(projectID)
	group := strings.TrimSpace(groupID)
	role = strings.TrimSpace(role)
	now := time.Now().UTC().Unix()
	// Joining projects (rather than a scalar subquery) means an unknown
	// project yields no row, so a NULL project_id is never attempted. The
	// WHERE clause also resolves SQLite's parse ambiguity between
	// INSERT ... SELECT and the ON CONFLICT upsert clause.
	result, err := s.Exec(`INSERT INTO project_role_bindings (project_id, subject_type, subject_public_id, role, created_at, updated_at)
SELECT p.id, 'group', g.public_id, ?, ?, ?
FROM projects p, user_groups g
WHERE p.public_id = ? AND g.public_id = ?
ON CONFLICT(project_id, subject_type, subject_public_id) DO UPDATE SET role = excluded.role, updated_at = excluded.updated_at`,
		role, now, now, project, group)
	if err != nil {
		return models.ProjectGroupRole{}, err
	}
	affected, err := result.RowsAffected()
	if err != nil {
		return models.ProjectGroupRole{}, err
	}
	if affected <= 0 {
		// Nothing was written, so at least one lookup failed; report which.
		var one int
		err = s.QueryRow(`SELECT 1 FROM projects WHERE public_id = ?`, project).Scan(&one)
		if errors.Is(err, sql.ErrNoRows) {
			return models.ProjectGroupRole{}, errors.New("project not found")
		}
		if err != nil {
			return models.ProjectGroupRole{}, err
		}
		return models.ProjectGroupRole{}, errors.New("group not found")
	}
	item := models.ProjectGroupRole{ProjectID: project, GroupID: group, Role: role}
	// The upsert's update path leaves created_at untouched, so read the
	// stored value instead of assuming it equals now.
	err = s.QueryRow(`SELECT b.created_at
FROM project_role_bindings b
JOIN projects p ON p.id = b.project_id
WHERE p.public_id = ? AND b.subject_type = 'group' AND b.subject_public_id = ?`, project, group).Scan(&item.CreatedAt)
	if err != nil {
		return models.ProjectGroupRole{}, err
	}
	return item, nil
}
// RemoveProjectGroupRole deletes the group-subject role binding for groupID on
// the project projectID (both whitespace-trimmed). Removing a binding that
// does not exist is not an error.
func (s *Store) RemoveProjectGroupRole(projectID string, groupID string) error {
	project := strings.TrimSpace(projectID)
	group := strings.TrimSpace(groupID)
	_, err := s.Exec(`DELETE FROM project_role_bindings
WHERE project_id = (SELECT id FROM projects WHERE public_id = ?)
AND subject_type = 'group'
AND subject_public_id = ?`, project, group)
	return err
}
// GetProjectRoleForUser returns the effective role that the user userID holds
// on the project projectID (both whitespace-trimmed).
//
// Two kinds of binding are considered:
//   - bindings made directly to the user (subject_type 'user')
//   - bindings made to any group the user belongs to whose disabled flag is 0
//
// The highest-ranked role according to projectRoleRank wins; ties keep the
// first row seen. Roles that projectRoleRank does not recognise (rank 0) never
// win. sql.ErrNoRows is returned when no recognised role applies.
func (s *Store) GetProjectRoleForUser(projectID string, userID string) (string, error) {
	const query = `SELECT b.role
FROM project_role_bindings b
WHERE b.project_id = (SELECT id FROM projects WHERE public_id = ?)
AND (
(b.subject_type = 'user' AND b.subject_public_id = ?)
OR
(b.subject_type = 'group' AND b.subject_public_id IN (
SELECT g.public_id
FROM user_group_members gm
JOIN users ux ON ux.id = gm.user_id
JOIN user_groups g ON g.id = gm.group_id
WHERE ux.public_id = ? AND g.disabled = 0
))
)`
	user := strings.TrimSpace(userID)
	rows, err := s.Query(query, strings.TrimSpace(projectID), user, user)
	if err != nil {
		return "", err
	}
	defer rows.Close()

	best, bestRank := "", 0
	for rows.Next() {
		var candidate string
		if scanErr := rows.Scan(&candidate); scanErr != nil {
			return "", scanErr
		}
		if rank := projectRoleRank(candidate); rank > bestRank {
			best, bestRank = candidate, rank
		}
	}
	if iterErr := rows.Err(); iterErr != nil {
		return "", iterErr
	}
	if best == "" {
		return "", sql.ErrNoRows
	}
	return best, nil
}
// GetProjectRoleForPrincipal returns the role bound directly to principalID
// (subject_type 'principal') on the project projectID. Both IDs are
// whitespace-trimmed. When no binding matches, the error is the one produced
// by Scan on the returned row (sql.ErrNoRows for a *sql.Row).
func (s *Store) GetProjectRoleForPrincipal(projectID string, principalID string) (string, error) {
	var role string
	err := s.QueryRow(`SELECT role
FROM project_role_bindings
WHERE project_id = (SELECT id FROM projects WHERE public_id = ?)
AND subject_type = 'principal'
AND subject_public_id = ?`,
		strings.TrimSpace(projectID),
		strings.TrimSpace(principalID),
	).Scan(&role)
	return role, err
}
// projectRoleRank maps a project role name to a comparable privilege level:
// admin (3) > writer (2) > viewer (1). Matching ignores letter case and
// surrounding whitespace. Any other value, including "", ranks 0.
func projectRoleRank(value string) int {
	switch strings.ToLower(strings.TrimSpace(value)) {
	case "admin":
		return 3
	case "writer":
		return 2
	case "viewer":
		return 1
	default:
		return 0
	}
}