package db

import (
	"database/sql"
	"errors"
	"strings"
	"time"

	"codit/internal/models"
	"codit/internal/util"
)

// ListUserGroups returns every user group ordered by name. When no groups
// exist it returns a nil slice and a nil error.
func (s *Store) ListUserGroups() ([]models.UserGroup, error) {
	rows, err := s.Query(`SELECT public_id, name, description, disabled, created_at, updated_at
		FROM user_groups
		ORDER BY name`)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var items []models.UserGroup
	for rows.Next() {
		var item models.UserGroup
		if err = rows.Scan(&item.ID, &item.Name, &item.Description, &item.Disabled, &item.CreatedAt, &item.UpdatedAt); err != nil {
			return nil, err
		}
		items = append(items, item)
	}
	if err = rows.Err(); err != nil {
		return nil, err
	}
	return items, nil
}

// GetUserGroup loads one group by its public ID. Surrounding whitespace in
// groupID is ignored. A missing group is reported as sql.ErrNoRows.
func (s *Store) GetUserGroup(groupID string) (models.UserGroup, error) {
	var item models.UserGroup
	err := s.QueryRow(`SELECT public_id, name, description, disabled, created_at, updated_at
		FROM user_groups
		WHERE public_id = ?`, strings.TrimSpace(groupID)).
		Scan(&item.ID, &item.Name, &item.Description, &item.Disabled, &item.CreatedAt, &item.UpdatedAt)
	return item, err
}

// CreateUserGroup inserts a new group and returns it as stored.
//
// ID, Name and Description are trimmed before the insert, so the returned
// value matches the row. A blank ID is replaced by one from util.NewID.
// CreatedAt and UpdatedAt are both set to the current Unix time (UTC seconds).
func (s *Store) CreateUserGroup(item models.UserGroup) (models.UserGroup, error) {
	// Lookups (GetUserGroup etc.) always trim the public ID, so an untrimmed
	// ID would be stored but never found again.
	item.ID = strings.TrimSpace(item.ID)
	if item.ID == "" {
		id, err := util.NewID()
		if err != nil {
			return item, err
		}
		item.ID = id
	}
	item.Name = strings.TrimSpace(item.Name)
	item.Description = strings.TrimSpace(item.Description)

	now := time.Now().UTC().Unix()
	item.CreatedAt = now
	item.UpdatedAt = now

	_, err := s.Exec(`INSERT INTO user_groups (public_id, name, description, disabled, created_at, updated_at)
		VALUES (?, ?, ?, ?, ?, ?)`,
		item.ID,
		item.Name,
		item.Description,
		item.Disabled,
		item.CreatedAt,
		item.UpdatedAt,
	)
	if err != nil {
		return item, err
	}
	return item, nil
}

// UpdateUserGroup overwrites the name, description and disabled flag of the
// group identified by item.ID, and bumps updated_at. It returns the row as
// re-read from the database. If no group has that ID, the re-read fails with
// sql.ErrNoRows.
func (s *Store) UpdateUserGroup(item models.UserGroup) (models.UserGroup, error) {
	item.UpdatedAt = time.Now().UTC().Unix()
	_, err := s.Exec(`UPDATE user_groups
		SET name = ?, description = ?, disabled = ?, updated_at = ?
		WHERE public_id = ?`,
		strings.TrimSpace(item.Name),
		strings.TrimSpace(item.Description),
		item.Disabled,
		item.UpdatedAt,
		strings.TrimSpace(item.ID),
	)
	if err != nil {
		return item, err
	}
	return s.GetUserGroup(item.ID)
}

// DeleteUserGroup removes a group and the rows that reference it: its
// memberships, its project role bindings, its SSH principal grant targets and
// its subject permissions.
//
// All statements run on the transaction returned by s.begin. Commit and
// rollback are delegated to commitIfOwned and rollbackIfOwned. Presumably
// these only act when the transaction was started here; confirm against
// s.begin.
func (s *Store) DeleteUserGroup(groupID string) error {
	groupID = strings.TrimSpace(groupID)

	tx, owned, err := s.begin()
	if err != nil {
		return err
	}

	statements := []string{
		// Memberships are deleted explicitly, before the group row they
		// resolve through, so cleanup does not depend on the schema declaring
		// ON DELETE CASCADE (and, in SQLite, on PRAGMA foreign_keys being on).
		// Orphaned rows could otherwise attach to a later group that reuses
		// the same internal id.
		`DELETE FROM user_group_members WHERE group_id = (SELECT id FROM user_groups WHERE public_id = ?)`,
		`DELETE FROM project_role_bindings WHERE subject_type = 'group' AND subject_public_id = ?`,
		`DELETE FROM ssh_principal_grant_targets WHERE target_type = 'group' AND target_public_id = ?`,
		`DELETE FROM subject_permissions WHERE subject_type = 'group' AND subject_public_id = ?`,
		`DELETE FROM user_groups WHERE public_id = ?`,
	}
	for _, stmt := range statements {
		if _, err = tx.Exec(stmt, groupID); err != nil {
			rollbackIfOwned(tx, owned)
			return err
		}
	}
	return commitIfOwned(tx, owned)
}

// ListUserGroupMembers returns the members of the group with the given
// public ID, ordered by username. An unknown group yields a nil slice and a
// nil error.
func (s *Store) ListUserGroupMembers(groupID string) ([]models.UserGroupMember, error) {
	rows, err := s.Query(`SELECT g.public_id, u.public_id, u.username, u.display_name, m.created_at
		FROM user_group_members m
		JOIN user_groups g ON g.id = m.group_id
		JOIN users u ON u.id = m.user_id
		WHERE g.public_id = ?
		ORDER BY u.username`, strings.TrimSpace(groupID))
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var items []models.UserGroupMember
	for rows.Next() {
		var item models.UserGroupMember
		if err = rows.Scan(&item.GroupID, &item.UserID, &item.Username, &item.UserRealName, &item.CreatedAt); err != nil {
			return nil, err
		}
		items = append(items, item)
	}
	if err = rows.Err(); err != nil {
		return nil, err
	}
	return items, nil
}

// AddUserGroupMember adds a user to a group. Both are given by public ID.
// INSERT OR IGNORE makes a duplicate membership a silent no-op.
//
// NOTE(review): an ID that matches no row makes its subquery yield NULL. If
// group_id/user_id are NOT NULL, OR IGNORE skips the row and this returns nil
// rather than an error. Confirm whether callers validate the IDs first.
func (s *Store) AddUserGroupMember(groupID string, userID string) error {
	_, err := s.Exec(`INSERT OR IGNORE INTO user_group_members (group_id, user_id, created_at)
		VALUES (
			(SELECT id FROM user_groups WHERE public_id = ?),
			(SELECT id FROM users WHERE public_id = ?),
			?
		)`,
		strings.TrimSpace(groupID),
		strings.TrimSpace(userID),
		time.Now().UTC().Unix(),
	)
	return err
}

// RemoveUserGroupMember removes a user from a group. Both are given by public
// ID. Removing a membership that does not exist is not an error.
func (s *Store) RemoveUserGroupMember(groupID string, userID string) error {
	_, err := s.Exec(`DELETE FROM user_group_members
		WHERE group_id = (SELECT id FROM user_groups WHERE public_id = ?)
		AND user_id = (SELECT id FROM users WHERE public_id = ?)`,
		strings.TrimSpace(groupID),
		strings.TrimSpace(userID),
	)
	return err
}

// ListProjectGroupRoles returns the group role bindings of a project, ordered
// by group name. Bindings whose group no longer exists are excluded by the
// inner join.
func (s *Store) ListProjectGroupRoles(projectID string) ([]models.ProjectGroupRole, error) {
	rows, err := s.Query(`SELECT p.public_id, g.public_id, b.role, b.created_at
		FROM project_role_bindings b
		JOIN projects p ON p.id = b.project_id
		JOIN user_groups g ON g.public_id = b.subject_public_id
		WHERE p.public_id = ?
		AND b.subject_type = 'group'
		ORDER BY g.name`, strings.TrimSpace(projectID))
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var items []models.ProjectGroupRole
	for rows.Next() {
		var item models.ProjectGroupRole
		if err = rows.Scan(&item.ProjectID, &item.GroupID, &item.Role, &item.CreatedAt); err != nil {
			return nil, err
		}
		items = append(items, item)
	}
	if err = rows.Err(); err != nil {
		return nil, err
	}
	return items, nil
}

// UpsertProjectGroupRole grants a group a role on a project, or changes the
// role of an existing binding. It returns an error "group not found" when no
// group has the given public ID.
//
// On an update, the stored created_at is left unchanged. The returned
// CreatedAt, however, is always the current time.
//
// NOTE(review): an unknown projectID makes the project subquery yield NULL.
// Whether that insert fails depends on the project_id column constraint;
// confirm.
func (s *Store) UpsertProjectGroupRole(projectID string, groupID string, role string) (models.ProjectGroupRole, error) {
	var item models.ProjectGroupRole

	projectID = strings.TrimSpace(projectID)
	groupID = strings.TrimSpace(groupID)
	role = strings.TrimSpace(role)
	now := time.Now().UTC().Unix()

	// Selecting FROM user_groups makes the insert produce zero rows when the
	// group does not exist. That is how the "group not found" case is
	// detected below.
	result, err := s.Exec(`INSERT INTO project_role_bindings (project_id, subject_type, subject_public_id, role, created_at, updated_at)
		SELECT (SELECT id FROM projects WHERE public_id = ?), 'group', g.public_id, ?, ?, ?
		FROM user_groups g
		WHERE g.public_id = ?
		ON CONFLICT(project_id, subject_type, subject_public_id)
		DO UPDATE SET role = excluded.role, updated_at = excluded.updated_at`,
		projectID, role, now, now, groupID)
	if err != nil {
		return item, err
	}
	rowsAffected, err := result.RowsAffected()
	if err != nil {
		return item, err
	}
	if rowsAffected <= 0 {
		return item, errors.New("group not found")
	}

	item.ProjectID = projectID
	item.GroupID = groupID
	item.Role = role
	item.CreatedAt = now
	return item, nil
}

// RemoveProjectGroupRole deletes a group's role binding on a project. Removing
// a binding that does not exist is not an error.
func (s *Store) RemoveProjectGroupRole(projectID string, groupID string) error {
	_, err := s.Exec(`DELETE FROM project_role_bindings
		WHERE project_id = (SELECT id FROM projects WHERE public_id = ?)
		AND subject_type = 'group'
		AND subject_public_id = ?`,
		strings.TrimSpace(projectID),
		strings.TrimSpace(groupID),
	)
	return err
}

// GetProjectRoleForUser returns the highest-ranked role a user holds on a
// project. Two sources count:
//   - roles bound directly to the user;
//   - roles bound to any non-disabled group the user belongs to.
//
// Ranking follows projectRoleRank. The role is returned exactly as stored.
// When no binding has a recognized role, the result is sql.ErrNoRows.
func (s *Store) GetProjectRoleForUser(projectID string, userID string) (string, error) {
	userID = strings.TrimSpace(userID)

	rows, err := s.Query(`SELECT b.role
		FROM project_role_bindings b
		WHERE b.project_id = (SELECT id FROM projects WHERE public_id = ?)
		AND (
			(b.subject_type = 'user' AND b.subject_public_id = ?)
			OR (b.subject_type = 'group' AND b.subject_public_id IN (
				SELECT g.public_id
				FROM user_group_members gm
				JOIN users ux ON ux.id = gm.user_id
				JOIN user_groups g ON g.id = gm.group_id
				WHERE ux.public_id = ? AND g.disabled = 0
			))
		)`,
		strings.TrimSpace(projectID),
		userID,
		userID,
	)
	if err != nil {
		return "", err
	}
	defer rows.Close()

	best := ""
	for rows.Next() {
		var role string
		if err = rows.Scan(&role); err != nil {
			return "", err
		}
		if projectRoleRank(role) > projectRoleRank(best) {
			best = role
		}
	}
	if err = rows.Err(); err != nil {
		return "", err
	}
	if best == "" {
		return "", sql.ErrNoRows
	}
	return best, nil
}

// GetProjectRoleForPrincipal returns the role bound directly to a principal
// (subject_type 'principal') on a project. If there is no such binding, the
// result is sql.ErrNoRows.
func (s *Store) GetProjectRoleForPrincipal(projectID string, principalID string) (string, error) {
	var role string
	err := s.QueryRow(`SELECT role
		FROM project_role_bindings
		WHERE project_id = (SELECT id FROM projects WHERE public_id = ?)
		AND subject_type = 'principal'
		AND subject_public_id = ?`,
		strings.TrimSpace(projectID),
		strings.TrimSpace(principalID),
	).Scan(&role)
	return role, err
}

// projectRoleRank orders project roles for comparison: admin (3) > writer (2)
// > viewer (1). Any other value, including "", ranks 0. Matching ignores case
// and surrounding whitespace.
func projectRoleRank(value string) int {
	switch strings.ToLower(strings.TrimSpace(value)) {
	case "admin":
		return 3
	case "writer":
		return 2
	case "viewer":
		return 1
	default:
		return 0
	}
}