Files
codit/backend/internal/db/ssh-server-groups.go

229 lines
8.7 KiB
Go

package db
import "context"
import "database/sql"
import "errors"
import "strings"
import "time"
import "codit/internal/models"
import "codit/internal/util"
// scanSSHServerGroup copies one SSH server group result row into item.
//
// The row must yield these columns in order:
//   - public ID
//   - name
//   - description
//   - enabled flag
//   - tags JSON
//   - creator kind
//   - creator subject ID
//   - creator subject name
//   - created_at
//   - updated_at
//
// The tags JSON column is decoded into item.Tags, and item.ServerIDs is left
// untouched. Both a single-row result and a multi-row cursor can be passed,
// since only Scan is required.
func scanSSHServerGroup(row interface{ Scan(dest ...any) error }, item *models.SSHServerGroup) error {
	var rawTags string
	targets := []any{
		&item.ID,
		&item.Name,
		&item.Description,
		&item.Enabled,
		&rawTags,
		&item.CreatedByKind,
		&item.CreatedBySubjectID,
		&item.CreatedBySubjectName,
		&item.CreatedAt,
		&item.UpdatedAt,
	}
	if scanErr := row.Scan(targets...); scanErr != nil {
		return scanErr
	}
	decoded, decodeErr := decodeStringList(rawTags)
	item.Tags = decoded
	return decodeErr
}
// listSSHServerGroupMemberIDs returns the public IDs of the servers that belong
// to the group whose public ID is groupID. Surrounding whitespace in groupID is
// ignored, and the IDs come back sorted by server public ID.
//
// An unknown group, or a group with no members, yields a nil slice and a nil
// error.
func (s *Store) listSSHServerGroupMemberIDs(groupID string) ([]string, error) {
	const memberIDsQuery = `SELECT s.public_id
FROM ssh_server_group_members gm
JOIN ssh_servers s ON s.id = gm.server_id
JOIN ssh_server_groups g ON g.id = gm.group_id
WHERE g.public_id = ?
ORDER BY s.public_id`

	cursor, queryErr := s.Query(memberIDsQuery, strings.TrimSpace(groupID))
	if queryErr != nil {
		return nil, queryErr
	}
	defer cursor.Close()

	var memberIDs []string
	for cursor.Next() {
		var serverID string
		if scanErr := cursor.Scan(&serverID); scanErr != nil {
			return nil, scanErr
		}
		memberIDs = append(memberIDs, serverID)
	}
	if iterErr := cursor.Err(); iterErr != nil {
		return nil, iterErr
	}
	return memberIDs, nil
}
// ListSSHServerGroups returns every SSH server group ordered by name, with each
// group's ServerIDs populated with its member servers' public IDs.
//
// All group rows are read and the cursor is closed before any member IDs are
// loaded. The original version ran the per-group member query while the outer
// cursor was still open, so two connections were needed at once. That can
// deadlock when the pool is capped at a single connection, which is a common
// SQLite setup. Otherwise it holds a read cursor open across N extra queries.
//
// On error the returned slice is nil.
func (s *Store) ListSSHServerGroups() ([]models.SSHServerGroup, error) {
	var rows *sql.Rows
	var items []models.SSHServerGroup
	var err error
	rows, err = s.Query(`SELECT g.public_id, g.name, g.description, g.enabled, g.tags_json, g.created_by_kind, COALESCE(cbu.public_id, cbp.public_id, ''), g.created_by_subject_name, g.created_at, g.updated_at
FROM ssh_server_groups g
LEFT JOIN users cbu ON cbu.id = g.created_by_user_id
LEFT JOIN service_principals cbp ON cbp.id = g.created_by_principal_id
ORDER BY g.name`)
	if err != nil {
		return nil, err
	}
	// The deferred Close covers early returns inside the scan loop. The
	// explicit Close below releases the connection before the member queries
	// run. (*sql.Rows).Close is idempotent, so calling it twice is safe.
	defer rows.Close()
	for rows.Next() {
		var item models.SSHServerGroup
		if err = scanSSHServerGroup(rows, &item); err != nil {
			return nil, err
		}
		items = append(items, item)
	}
	if err = rows.Err(); err != nil {
		return nil, err
	}
	if err = rows.Close(); err != nil {
		return nil, err
	}
	for i := range items {
		items[i].ServerIDs, err = s.listSSHServerGroupMemberIDs(items[i].ID)
		if err != nil {
			return nil, err
		}
	}
	return items, nil
}
// GetSSHServerGroup loads one SSH server group by public ID, including its
// member server IDs. Surrounding whitespace in id is ignored.
//
// When no group matches, the error from scanning the *sql.Row is returned
// unchanged (sql.ErrNoRows). On any error the returned value may be partially
// populated and should be ignored.
func (s *Store) GetSSHServerGroup(id string) (models.SSHServerGroup, error) {
	const groupByIDQuery = `SELECT g.public_id, g.name, g.description, g.enabled, g.tags_json, g.created_by_kind, COALESCE(cbu.public_id, cbp.public_id, ''), g.created_by_subject_name, g.created_at, g.updated_at
FROM ssh_server_groups g
LEFT JOIN users cbu ON cbu.id = g.created_by_user_id
LEFT JOIN service_principals cbp ON cbp.id = g.created_by_principal_id
WHERE g.public_id = ?`

	var group models.SSHServerGroup
	var row *sql.Row
	row = s.QueryRow(groupByIDQuery, strings.TrimSpace(id))
	if scanErr := scanSSHServerGroup(row, &group); scanErr != nil {
		return group, scanErr
	}
	members, memberErr := s.listSSHServerGroupMemberIDs(group.ID)
	if memberErr != nil {
		return group, memberErr
	}
	group.ServerIDs = members
	return group, nil
}
// ListSSHServersForGroup returns the SSH servers that are members of the group
// whose public ID is groupID, ordered by server name. Surrounding whitespace in
// groupID is ignored.
//
// The owner and creator public IDs are returned as "" when the referenced
// user or service principal row is absent (COALESCE over LEFT JOINs). Tags
// are decoded from the stored JSON. An unknown or empty group yields a nil
// slice, and on error the result is nil.
func (s *Store) ListSSHServersForGroup(groupID string) ([]models.SSHServer, error) {
	const serversInGroupQuery = `SELECT s.public_id, s.name, s.host, s.port, s.description, s.tags_json, s.enabled, s.host_key_policy, s.owner_scope, COALESCE(u.public_id, ''), s.created_by_kind, COALESCE(sbu.public_id, sbp.public_id, ''), s.created_by_subject_name, s.created_at, s.updated_at
FROM ssh_server_group_members gm
JOIN ssh_servers s ON s.id = gm.server_id
JOIN ssh_server_groups g ON g.id = gm.group_id
LEFT JOIN users u ON u.id = s.owner_user_id
LEFT JOIN users sbu ON sbu.id = s.created_by_user_id
LEFT JOIN service_principals sbp ON sbp.id = s.created_by_principal_id
WHERE g.public_id = ?
ORDER BY s.name`

	cursor, queryErr := s.Query(serversInGroupQuery, strings.TrimSpace(groupID))
	if queryErr != nil {
		return nil, queryErr
	}
	defer cursor.Close()

	var servers []models.SSHServer
	for cursor.Next() {
		var srv models.SSHServer
		var rawTags string
		scanErr := cursor.Scan(
			&srv.ID, &srv.Name, &srv.Host, &srv.Port, &srv.Description,
			&rawTags, &srv.Enabled, &srv.HostKeyPolicy, &srv.OwnerScope,
			&srv.OwnerUserID, &srv.CreatedByKind, &srv.CreatedBySubjectID,
			&srv.CreatedBySubjectName, &srv.CreatedAt, &srv.UpdatedAt,
		)
		if scanErr != nil {
			return nil, scanErr
		}
		tags, decodeErr := decodeStringList(rawTags)
		if decodeErr != nil {
			return nil, decodeErr
		}
		srv.Tags = tags
		servers = append(servers, srv)
	}
	if iterErr := cursor.Err(); iterErr != nil {
		return nil, iterErr
	}
	return servers, nil
}
// CreateSSHServerGroup inserts a new SSH server group and returns the stored row
// as read back by GetSSHServerGroup.
//
// Validation and normalization:
//   - Name must be non-blank; otherwise "name is required" is returned.
//   - item.ID is trimmed. If it is then empty, a fresh ID is generated with
//     util.NewID. Trimming keeps the stored public_id consistent with
//     GetSSHServerGroup, which trims its argument before the lookup. Without
//     it, an ID such as " g1 " would be inserted verbatim and the read-back
//     would find no row.
//   - Tags are normalized and stored as JSON. Name, description and creator
//     fields are stored trimmed.
//
// CreatedAt and UpdatedAt are set to the current UTC Unix time. The creator is
// resolved inside the INSERT:
//   - created_by_user_id is looked up only when CreatedByKind is "user".
//   - created_by_principal_id is looked up only when CreatedByKind is
//     "service_principal".
//   - In every other case the scalar subselect yields NULL.
//
// On error the returned item may be partially populated and should be ignored.
//
// NOTE(review): when beginContext joins a caller's transaction (owned ==
// false), nothing is committed here, yet the read-back goes through s rather
// than tx. Confirm that it can observe the uncommitted row.
func (s *Store) CreateSSHServerGroup(ctx context.Context, item models.SSHServerGroup) (models.SSHServerGroup, error) {
	var tx txExecutor
	var owned bool
	var err error
	var now int64
	var tagsJSON string
	if strings.TrimSpace(item.Name) == "" {
		return item, errors.New("name is required")
	}
	item.ID = strings.TrimSpace(item.ID)
	if item.ID == "" {
		item.ID, err = util.NewID()
		if err != nil {
			return item, err
		}
	}
	item.Tags = normalizeStringList(item.Tags)
	tagsJSON, err = encodeStringList(item.Tags)
	if err != nil {
		return item, err
	}
	now = time.Now().UTC().Unix()
	item.CreatedAt = now
	item.UpdatedAt = now
	// Kind and subject ID each appear twice in the statement below: once per
	// conditional creator subselect.
	createdByKind := strings.TrimSpace(item.CreatedByKind)
	createdBySubjectID := strings.TrimSpace(item.CreatedBySubjectID)
	tx, owned, err = s.beginContext(ctx)
	if err != nil {
		return item, err
	}
	_, err = tx.Exec(`INSERT INTO ssh_server_groups (public_id, name, description, enabled, tags_json, created_by_kind, created_by_user_id, created_by_principal_id, created_by_subject_name, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, (SELECT id FROM users WHERE public_id = ? AND ? = 'user'), (SELECT id FROM service_principals WHERE public_id = ? AND ? = 'service_principal'), ?, ?, ?)`,
		item.ID,
		strings.TrimSpace(item.Name),
		strings.TrimSpace(item.Description),
		item.Enabled,
		tagsJSON,
		createdByKind,
		createdBySubjectID, createdByKind,
		createdBySubjectID, createdByKind,
		strings.TrimSpace(item.CreatedBySubjectName),
		item.CreatedAt,
		item.UpdatedAt,
	)
	if err != nil {
		rollbackIfOwned(tx, owned)
		return item, err
	}
	err = commitIfOwned(tx, owned)
	if err != nil {
		return item, err
	}
	return s.GetSSHServerGroup(item.ID)
}
// UpdateSSHServerGroup overwrites the name, description, enabled flag and tags
// of the group identified by item.ID (surrounding whitespace ignored). It stamps
// updated_at with the current UTC Unix time and returns the group as re-read by
// GetSSHServerGroup.
//
// Creator fields, created_at and membership are not written. Name must be
// non-blank. Tags are normalized before being stored as JSON. The UPDATE does
// not check how many rows it touched; an unknown ID surfaces as the error from
// the read-back.
//
// NOTE(review): as with creation, the read-back uses s rather than tx. Confirm
// this is intended when the transaction is not owned here.
func (s *Store) UpdateSSHServerGroup(ctx context.Context, item models.SSHServerGroup) (models.SSHServerGroup, error) {
	var (
		tx    txExecutor
		owned bool
	)
	trimmedName := strings.TrimSpace(item.Name)
	if trimmedName == "" {
		return item, errors.New("name is required")
	}
	item.Tags = normalizeStringList(item.Tags)
	encodedTags, err := encodeStringList(item.Tags)
	if err != nil {
		return item, err
	}
	updatedAt := time.Now().UTC().Unix()

	if tx, owned, err = s.beginContext(ctx); err != nil {
		return item, err
	}
	_, err = tx.Exec(`UPDATE ssh_server_groups SET name = ?, description = ?, enabled = ?, tags_json = ?, updated_at = ? WHERE public_id = ?`,
		trimmedName,
		strings.TrimSpace(item.Description),
		item.Enabled,
		encodedTags,
		updatedAt,
		strings.TrimSpace(item.ID),
	)
	if err != nil {
		rollbackIfOwned(tx, owned)
		return item, err
	}
	if err = commitIfOwned(tx, owned); err != nil {
		return item, err
	}
	return s.GetSSHServerGroup(item.ID)
}
// DeleteSSHServerGroup removes the group whose public ID is id (surrounding
// whitespace ignored).
//
// It returns an error and deletes nothing when any ssh_access_profiles row
// still references the group. The reference check and the delete run in one
// transaction opened with beginImmediateContext. Presumably this is so that
// no profile can be attached between the two statements; confirm against
// beginImmediateContext. Deleting an unknown ID is not an error.
func (s *Store) DeleteSSHServerGroup(ctx context.Context, id string) error {
	groupID := strings.TrimSpace(id)

	var tx txExecutor
	var owned bool
	var err error
	tx, owned, err = s.beginImmediateContext(ctx)
	if err != nil {
		return err
	}
	// abort rolls back (when this call owns the transaction) and passes the
	// cause through, so every failure path below is one line.
	abort := func(cause error) error {
		rollbackIfOwned(tx, owned)
		return cause
	}

	var profileRefs int
	err = tx.QueryRow(`SELECT COUNT(*)
FROM ssh_access_profiles p
JOIN ssh_server_groups g ON g.id = p.server_group_id
WHERE g.public_id = ?`, groupID).Scan(&profileRefs)
	switch {
	case err != nil:
		return abort(err)
	case profileRefs > 0:
		return abort(errors.New("ssh server group is used by access profiles"))
	}

	if _, err = tx.Exec(`DELETE FROM ssh_server_groups WHERE public_id = ?`, groupID); err != nil {
		return abort(err)
	}
	return commitIfOwned(tx, owned)
}
// AddSSHServerGroupMember records that the server with public ID serverID
// belongs to the group with public ID groupID. Surrounding whitespace in both
// IDs is ignored.
//
// An existing membership is left as-is (INSERT OR IGNORE). If either public ID
// matches no row, the SELECT feeding the INSERT produces nothing, so no row is
// written; this is not reported as an error. The membership's created_at is the
// current UTC Unix time.
func (s *Store) AddSSHServerGroupMember(groupID string, serverID string) error {
	createdAt := time.Now().UTC().Unix()
	group := strings.TrimSpace(groupID)
	server := strings.TrimSpace(serverID)
	if _, execErr := s.Exec(`INSERT OR IGNORE INTO ssh_server_group_members (group_id, server_id, created_at)
SELECT g.id, srv.id, ?
FROM ssh_server_groups g
JOIN ssh_servers srv ON srv.public_id = ?
WHERE g.public_id = ?`, createdAt, server, group); execErr != nil {
		return execErr
	}
	return nil
}
// DeleteSSHServerGroupMember removes the membership linking the server with
// public ID serverID to the group with public ID groupID. Surrounding
// whitespace in both IDs is ignored.
//
// Removing a membership that does not exist (or that names unknown IDs)
// deletes nothing and returns only whatever error Exec reports.
func (s *Store) DeleteSSHServerGroupMember(groupID string, serverID string) error {
	group, server := strings.TrimSpace(groupID), strings.TrimSpace(serverID)
	_, execErr := s.Exec(`DELETE FROM ssh_server_group_members
WHERE group_id = (SELECT id FROM ssh_server_groups WHERE public_id = ?)
AND server_id = (SELECT id FROM ssh_servers WHERE public_id = ?)`, group, server)
	return execErr
}
// SSHServerBelongsToGroup reports whether the server with public ID serverID is
// a member of the group with public ID groupID. Surrounding whitespace in both
// IDs is ignored.
//
// Unknown IDs simply yield false with a nil error. A query or scan failure
// yields false together with that error.
func (s *Store) SSHServerBelongsToGroup(groupID string, serverID string) (bool, error) {
	const membershipCountQuery = `SELECT COUNT(*) FROM ssh_server_group_members
WHERE group_id = (SELECT id FROM ssh_server_groups WHERE public_id = ?)
AND server_id = (SELECT id FROM ssh_servers WHERE public_id = ?)`

	var (
		row     *sql.Row
		matches int
	)
	row = s.QueryRow(membershipCountQuery, strings.TrimSpace(groupID), strings.TrimSpace(serverID))
	if scanErr := row.Scan(&matches); scanErr != nil {
		return false, scanErr
	}
	return 0 < matches, nil
}