331 lines
8.8 KiB
Go
331 lines
8.8 KiB
Go
package db
|
|
|
|
import "database/sql"
|
|
import "errors"
|
|
import "strings"
|
|
import "time"
|
|
|
|
import "codit/internal/models"
|
|
import "codit/internal/util"
|
|
|
|
func (s *Store) ListUserGroups() ([]models.UserGroup, error) {
|
|
var rows *sql.Rows
|
|
var err error
|
|
var items []models.UserGroup
|
|
var item models.UserGroup
|
|
rows, err = s.Query(`SELECT public_id, name, description, disabled, created_at, updated_at FROM user_groups ORDER BY name`)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
for rows.Next() {
|
|
err = rows.Scan(&item.ID, &item.Name, &item.Description, &item.Disabled, &item.CreatedAt, &item.UpdatedAt)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
items = append(items, item)
|
|
}
|
|
err = rows.Err()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return items, nil
|
|
}
|
|
|
|
func (s *Store) GetUserGroup(groupID string) (models.UserGroup, error) {
|
|
var row *sql.Row
|
|
var item models.UserGroup
|
|
var err error
|
|
row = s.QueryRow(`SELECT public_id, name, description, disabled, created_at, updated_at
|
|
FROM user_groups WHERE public_id = ?`, strings.TrimSpace(groupID))
|
|
err = row.Scan(&item.ID, &item.Name, &item.Description, &item.Disabled, &item.CreatedAt, &item.UpdatedAt)
|
|
return item, err
|
|
}
|
|
|
|
func (s *Store) CreateUserGroup(item models.UserGroup) (models.UserGroup, error) {
|
|
var id string
|
|
var now int64
|
|
var err error
|
|
if strings.TrimSpace(item.ID) == "" {
|
|
id, err = util.NewID()
|
|
if err != nil {
|
|
return item, err
|
|
}
|
|
item.ID = id
|
|
}
|
|
now = time.Now().UTC().Unix()
|
|
item.CreatedAt = now
|
|
item.UpdatedAt = now
|
|
_, err = s.Exec(`INSERT INTO user_groups (public_id, name, description, disabled, created_at, updated_at)
|
|
VALUES (?, ?, ?, ?, ?, ?)`,
|
|
item.ID,
|
|
strings.TrimSpace(item.Name),
|
|
strings.TrimSpace(item.Description),
|
|
item.Disabled,
|
|
item.CreatedAt,
|
|
item.UpdatedAt,
|
|
)
|
|
if err != nil {
|
|
return item, err
|
|
}
|
|
return item, nil
|
|
}
|
|
|
|
func (s *Store) UpdateUserGroup(item models.UserGroup) (models.UserGroup, error) {
|
|
var now int64
|
|
var err error
|
|
now = time.Now().UTC().Unix()
|
|
item.UpdatedAt = now
|
|
_, err = s.Exec(`UPDATE user_groups
|
|
SET name = ?, description = ?, disabled = ?, updated_at = ?
|
|
WHERE public_id = ?`,
|
|
strings.TrimSpace(item.Name),
|
|
strings.TrimSpace(item.Description),
|
|
item.Disabled,
|
|
item.UpdatedAt,
|
|
strings.TrimSpace(item.ID),
|
|
)
|
|
if err != nil {
|
|
return item, err
|
|
}
|
|
item, err = s.GetUserGroup(item.ID)
|
|
if err != nil {
|
|
return item, err
|
|
}
|
|
return item, nil
|
|
}
|
|
|
|
func (s *Store) DeleteUserGroup(groupID string) error {
|
|
var tx *sql.Tx
|
|
var owned bool
|
|
var err error
|
|
tx, owned, err = s.begin()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
_, err = tx.Exec(`DELETE FROM project_role_bindings WHERE subject_type = 'group' AND subject_public_id = ?`, strings.TrimSpace(groupID))
|
|
if err != nil {
|
|
rollbackIfOwned(tx, owned)
|
|
return err
|
|
}
|
|
_, err = tx.Exec(`DELETE FROM ssh_principal_grant_targets WHERE target_type = 'group' AND target_public_id = ?`, strings.TrimSpace(groupID))
|
|
if err != nil {
|
|
rollbackIfOwned(tx, owned)
|
|
return err
|
|
}
|
|
_, err = tx.Exec(`DELETE FROM subject_permissions WHERE subject_type = 'group' AND subject_public_id = ?`, strings.TrimSpace(groupID))
|
|
if err != nil {
|
|
rollbackIfOwned(tx, owned)
|
|
return err
|
|
}
|
|
_, err = tx.Exec(`DELETE FROM user_groups WHERE public_id = ?`, strings.TrimSpace(groupID))
|
|
if err != nil {
|
|
rollbackIfOwned(tx, owned)
|
|
return err
|
|
}
|
|
err = commitIfOwned(tx, owned)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return err
|
|
}
|
|
|
|
func (s *Store) ListUserGroupMembers(groupID string) ([]models.UserGroupMember, error) {
|
|
var rows *sql.Rows
|
|
var err error
|
|
var items []models.UserGroupMember
|
|
var item models.UserGroupMember
|
|
rows, err = s.Query(`SELECT g.public_id, u.public_id, u.username, u.display_name, m.created_at
|
|
FROM user_group_members m
|
|
JOIN user_groups g ON g.id = m.group_id
|
|
JOIN users u ON u.id = m.user_id
|
|
WHERE g.public_id = ?
|
|
ORDER BY u.username`, strings.TrimSpace(groupID))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
for rows.Next() {
|
|
err = rows.Scan(&item.GroupID, &item.UserID, &item.Username, &item.UserRealName, &item.CreatedAt)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
items = append(items, item)
|
|
}
|
|
err = rows.Err()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return items, nil
|
|
}
|
|
|
|
func (s *Store) AddUserGroupMember(groupID string, userID string) error {
|
|
var now int64
|
|
var err error
|
|
now = time.Now().UTC().Unix()
|
|
_, err = s.Exec(`INSERT OR IGNORE INTO user_group_members (group_id, user_id, created_at)
|
|
VALUES ((SELECT id FROM user_groups WHERE public_id = ?), (SELECT id FROM users WHERE public_id = ?), ?)`,
|
|
strings.TrimSpace(groupID), strings.TrimSpace(userID), now)
|
|
return err
|
|
}
|
|
|
|
func (s *Store) RemoveUserGroupMember(groupID string, userID string) error {
|
|
var err error
|
|
_, err = s.Exec(`DELETE FROM user_group_members
|
|
WHERE group_id = (SELECT id FROM user_groups WHERE public_id = ?)
|
|
AND user_id = (SELECT id FROM users WHERE public_id = ?)`,
|
|
strings.TrimSpace(groupID),
|
|
strings.TrimSpace(userID),
|
|
)
|
|
return err
|
|
}
|
|
|
|
func (s *Store) ListProjectGroupRoles(projectID string) ([]models.ProjectGroupRole, error) {
|
|
var rows *sql.Rows
|
|
var err error
|
|
var items []models.ProjectGroupRole
|
|
var item models.ProjectGroupRole
|
|
rows, err = s.Query(`SELECT p.public_id, g.public_id, b.role, b.created_at
|
|
FROM project_role_bindings b
|
|
JOIN projects p ON p.id = b.project_id
|
|
JOIN user_groups g ON g.public_id = b.subject_public_id
|
|
WHERE p.public_id = ? AND b.subject_type = 'group'
|
|
ORDER BY g.name`, strings.TrimSpace(projectID))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
for rows.Next() {
|
|
err = rows.Scan(&item.ProjectID, &item.GroupID, &item.Role, &item.CreatedAt)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
items = append(items, item)
|
|
}
|
|
err = rows.Err()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return items, nil
|
|
}
|
|
|
|
func (s *Store) UpsertProjectGroupRole(projectID string, groupID string, role string) (models.ProjectGroupRole, error) {
|
|
var now int64
|
|
var item models.ProjectGroupRole
|
|
var result sql.Result
|
|
var rowsAffected int64
|
|
var err error
|
|
now = time.Now().UTC().Unix()
|
|
result, err = s.Exec(`INSERT INTO project_role_bindings (project_id, subject_type, subject_public_id, role, created_at, updated_at)
|
|
SELECT (SELECT id FROM projects WHERE public_id = ?), 'group', g.public_id, ?, ?, ?
|
|
FROM user_groups g
|
|
WHERE g.public_id = ?
|
|
ON CONFLICT(project_id, subject_type, subject_public_id) DO UPDATE SET role = excluded.role, updated_at = excluded.updated_at`,
|
|
strings.TrimSpace(projectID), strings.TrimSpace(role), now, now, strings.TrimSpace(groupID))
|
|
if err != nil {
|
|
return item, err
|
|
}
|
|
rowsAffected, err = result.RowsAffected()
|
|
if err != nil {
|
|
return item, err
|
|
}
|
|
if rowsAffected <= 0 {
|
|
return item, errors.New("group not found")
|
|
}
|
|
item.ProjectID = strings.TrimSpace(projectID)
|
|
item.GroupID = strings.TrimSpace(groupID)
|
|
item.Role = strings.TrimSpace(role)
|
|
item.CreatedAt = now
|
|
return item, nil
|
|
}
|
|
|
|
func (s *Store) RemoveProjectGroupRole(projectID string, groupID string) error {
|
|
var err error
|
|
_, err = s.Exec(`DELETE FROM project_role_bindings
|
|
WHERE project_id = (SELECT id FROM projects WHERE public_id = ?)
|
|
AND subject_type = 'group'
|
|
AND subject_public_id = ?`,
|
|
strings.TrimSpace(projectID),
|
|
strings.TrimSpace(groupID),
|
|
)
|
|
return err
|
|
}
|
|
|
|
func (s *Store) GetProjectRoleForUser(projectID string, userID string) (string, error) {
|
|
var rows *sql.Rows
|
|
var role string
|
|
var best string
|
|
var err error
|
|
rows, err = s.Query(`SELECT b.role
|
|
FROM project_role_bindings b
|
|
WHERE b.project_id = (SELECT id FROM projects WHERE public_id = ?)
|
|
AND (
|
|
(b.subject_type = 'user' AND b.subject_public_id = ?)
|
|
OR
|
|
(b.subject_type = 'group' AND b.subject_public_id IN (
|
|
SELECT g.public_id
|
|
FROM user_group_members gm
|
|
JOIN users ux ON ux.id = gm.user_id
|
|
JOIN user_groups g ON g.id = gm.group_id
|
|
WHERE ux.public_id = ? AND g.disabled = 0
|
|
))
|
|
)`,
|
|
strings.TrimSpace(projectID),
|
|
strings.TrimSpace(userID),
|
|
strings.TrimSpace(userID),
|
|
)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
defer rows.Close()
|
|
best = ""
|
|
for rows.Next() {
|
|
err = rows.Scan(&role)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
if projectRoleRank(role) > projectRoleRank(best) {
|
|
best = role
|
|
}
|
|
}
|
|
err = rows.Err()
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
if best == "" {
|
|
return "", sql.ErrNoRows
|
|
}
|
|
return best, nil
|
|
}
|
|
|
|
func (s *Store) GetProjectRoleForPrincipal(projectID string, principalID string) (string, error) {
|
|
var role string
|
|
var row *sql.Row
|
|
var err error
|
|
row = s.QueryRow(`SELECT role
|
|
FROM project_role_bindings
|
|
WHERE project_id = (SELECT id FROM projects WHERE public_id = ?)
|
|
AND subject_type = 'principal'
|
|
AND subject_public_id = ?`,
|
|
strings.TrimSpace(projectID),
|
|
strings.TrimSpace(principalID),
|
|
)
|
|
err = row.Scan(&role)
|
|
return role, err
|
|
}
|
|
|
|
// projectRoleRank maps a project role name to its privilege level so that
// several grants can be compared: "admin" is 3, "writer" is 2 and "viewer" is
// 1. Matching ignores case and surrounding whitespace. Any other value,
// including the empty string, ranks 0.
func projectRoleRank(value string) int {
	switch strings.ToLower(strings.TrimSpace(value)) {
	case "admin":
		return 3
	case "writer":
		return 2
	case "viewer":
		return 1
	default:
		return 0
	}
}
|