Files
codit/backend/internal/db/ssh_server_groups.go

174 lines
6.6 KiB
Go

package db
import "database/sql"
import "errors"
import "strings"
import "time"
import "codit/internal/models"
import "codit/internal/util"
// scanSSHServerGroup copies a single ssh_server_groups result row into item.
//
// row must yield, in order: public_id, name, description, enabled,
// created_by_kind, created_by_subject_id, created_by_subject_name,
// created_at, updated_at (the column list used by the SELECTs in this file).
// item.ServerIDs is not touched; callers load it separately. Any Scan error
// is returned unchanged.
func scanSSHServerGroup(row interface{ Scan(dest ...any) error }, item *models.SSHServerGroup) error {
	columns := []any{
		&item.ID,
		&item.Name,
		&item.Description,
		&item.Enabled,
		&item.CreatedByKind,
		&item.CreatedBySubjectID,
		&item.CreatedBySubjectName,
		&item.CreatedAt,
		&item.UpdatedAt,
	}
	return row.Scan(columns...)
}
// listSSHServerGroupMemberIDs returns the public IDs of all servers that
// belong to the group groupID (surrounding whitespace is trimmed), sorted by
// server public ID. A group with no members yields a nil slice and a nil
// error. Query, scan and iteration errors are returned unchanged.
func (s *Store) listSSHServerGroupMemberIDs(groupID string) ([]string, error) {
	rows, err := s.Query(`SELECT server_public_id FROM ssh_server_group_members WHERE group_public_id = ? ORDER BY server_public_id`, strings.TrimSpace(groupID))
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var memberIDs []string
	for rows.Next() {
		var serverID string
		if scanErr := rows.Scan(&serverID); scanErr != nil {
			return nil, scanErr
		}
		memberIDs = append(memberIDs, serverID)
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, iterErr
	}
	return memberIDs, nil
}
// ListSSHServerGroups returns every SSH server group ordered by name, each
// with ServerIDs populated from ssh_server_group_members.
//
// All group rows are read and the result set is closed before any member
// query is issued. Running the per-group member queries while the outer
// cursor is still open holds one pooled connection for that cursor and needs
// a second for each inner query. If the pool is limited to a single
// connection, the nested query blocks forever. That limit is common for
// SQLite — NOTE(review): confirm how Store configures its pool.
func (s *Store) ListSSHServerGroups() ([]models.SSHServerGroup, error) {
	var rows *sql.Rows
	var items []models.SSHServerGroup
	var err error
	rows, err = s.Query(`SELECT public_id, name, description, enabled, created_by_kind, created_by_subject_id, created_by_subject_name, created_at, updated_at FROM ssh_server_groups ORDER BY name`)
	if err != nil {
		return nil, err
	}
	// Safety net for the early returns below; Rows.Close is idempotent, so the
	// explicit Close after the loop is harmless.
	defer rows.Close()
	for rows.Next() {
		var item models.SSHServerGroup
		if err = scanSSHServerGroup(rows, &item); err != nil {
			return nil, err
		}
		items = append(items, item)
	}
	if err = rows.Err(); err != nil {
		return nil, err
	}
	// Release the cursor (and its connection) before the member lookups.
	rows.Close()

	for i := range items {
		items[i].ServerIDs, err = s.listSSHServerGroupMemberIDs(items[i].ID)
		if err != nil {
			return nil, err
		}
	}
	return items, nil
}
// GetSSHServerGroup loads the group whose public ID equals id (surrounding
// whitespace trimmed), together with its member server IDs.
//
// Errors from the row scan (e.g. a no-rows error when the group does not
// exist) and from the member lookup are returned unchanged. On error the
// returned value may be partially populated and should be ignored.
func (s *Store) GetSSHServerGroup(id string) (models.SSHServerGroup, error) {
	var group models.SSHServerGroup
	row := s.QueryRow(`SELECT public_id, name, description, enabled, created_by_kind, created_by_subject_id, created_by_subject_name, created_at, updated_at FROM ssh_server_groups WHERE public_id = ?`, strings.TrimSpace(id))
	if scanErr := scanSSHServerGroup(row, &group); scanErr != nil {
		return group, scanErr
	}
	memberIDs, err := s.listSSHServerGroupMemberIDs(group.ID)
	group.ServerIDs = memberIDs
	if err != nil {
		return group, err
	}
	return group, nil
}
// ListSSHServersForGroup returns the SSH servers that are members of the
// group groupID (surrounding whitespace trimmed), ordered by server name.
//
// Each server's Tags are decoded from its tags_json column with
// decodeStringList. Any query, scan, decode or iteration error aborts the
// listing and is returned unchanged with a nil slice.
func (s *Store) ListSSHServersForGroup(groupID string) ([]models.SSHServer, error) {
	rows, err := s.Query(`SELECT s.public_id, s.name, s.host, s.port, s.description, s.tags_json, s.enabled, s.host_key_policy, s.created_by_kind, s.created_by_subject_id, s.created_by_subject_name, s.created_at, s.updated_at
FROM ssh_server_group_members gm
JOIN ssh_servers s ON s.public_id = gm.server_public_id
WHERE gm.group_public_id = ?
ORDER BY s.name`, strings.TrimSpace(groupID))
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var servers []models.SSHServer
	for rows.Next() {
		var server models.SSHServer
		var encodedTags string
		scanErr := rows.Scan(
			&server.ID,
			&server.Name,
			&server.Host,
			&server.Port,
			&server.Description,
			&encodedTags,
			&server.Enabled,
			&server.HostKeyPolicy,
			&server.CreatedByKind,
			&server.CreatedBySubjectID,
			&server.CreatedBySubjectName,
			&server.CreatedAt,
			&server.UpdatedAt,
		)
		if scanErr != nil {
			return nil, scanErr
		}
		tags, decodeErr := decodeStringList(encodedTags)
		if decodeErr != nil {
			return nil, decodeErr
		}
		server.Tags = tags
		servers = append(servers, server)
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, iterErr
	}
	return servers, nil
}
// CreateSSHServerGroup inserts a new SSH server group and returns it as
// re-read through GetSSHServerGroup.
//
// Name must be non-blank, otherwise "name is required" is returned. ID is
// trimmed; if it is then empty, a fresh ID is generated with util.NewID.
// Name, Description and the CreatedBy* fields are stored trimmed. CreatedAt
// and UpdatedAt are both set to the current UTC Unix time in seconds.
//
// The insert runs on the transaction returned by s.begin. It is committed or
// rolled back only when s.begin reports it as owned (commitIfOwned /
// rollbackIfOwned).
func (s *Store) CreateSSHServerGroup(item models.SSHServerGroup) (models.SSHServerGroup, error) {
	var tx *sql.Tx
	var owned bool
	var err error
	var now int64
	if strings.TrimSpace(item.Name) == "" {
		return item, errors.New("name is required")
	}
	// Persist the ID trimmed. Every lookup in this store (GetSSHServerGroup,
	// membership queries, deletes) trims its ID argument, so an untrimmed ID
	// would be stored but could never be found again. The final
	// GetSSHServerGroup below would then fail.
	item.ID = strings.TrimSpace(item.ID)
	if item.ID == "" {
		item.ID, err = util.NewID()
		if err != nil {
			return item, err
		}
	}
	now = time.Now().UTC().Unix()
	item.CreatedAt = now
	item.UpdatedAt = now
	tx, owned, err = s.begin()
	if err != nil {
		return item, err
	}
	_, err = tx.Exec(`INSERT INTO ssh_server_groups (public_id, name, description, enabled, created_by_kind, created_by_subject_id, created_by_subject_name, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`, item.ID, strings.TrimSpace(item.Name), strings.TrimSpace(item.Description), item.Enabled, strings.TrimSpace(item.CreatedByKind), strings.TrimSpace(item.CreatedBySubjectID), strings.TrimSpace(item.CreatedBySubjectName), item.CreatedAt, item.UpdatedAt)
	if err != nil {
		rollbackIfOwned(tx, owned)
		return item, err
	}
	err = commitIfOwned(tx, owned)
	if err != nil {
		return item, err
	}
	return s.GetSSHServerGroup(item.ID)
}
// UpdateSSHServerGroup overwrites the name, description and enabled flag of
// the group identified by item.ID (trimmed). It sets updated_at to the
// current UTC Unix time in seconds and returns the group as re-read through
// GetSSHServerGroup. Creator fields and created_at are not modified.
//
// Name must be non-blank, otherwise "name is required" is returned. Name and
// Description are stored trimmed. If no group matches, the UPDATE changes
// nothing and the error surfaces from the final GetSSHServerGroup lookup.
// The statement runs on the transaction from s.begin, committed or rolled
// back only when that transaction is owned.
func (s *Store) UpdateSSHServerGroup(item models.SSHServerGroup) (models.SSHServerGroup, error) {
	if strings.TrimSpace(item.Name) == "" {
		return item, errors.New("name is required")
	}
	updatedAt := time.Now().UTC().Unix()

	var (
		tx    *sql.Tx
		owned bool
		err   error
	)
	if tx, owned, err = s.begin(); err != nil {
		return item, err
	}
	if _, err = tx.Exec(`UPDATE ssh_server_groups SET name = ?, description = ?, enabled = ?, updated_at = ? WHERE public_id = ?`, strings.TrimSpace(item.Name), strings.TrimSpace(item.Description), item.Enabled, updatedAt, strings.TrimSpace(item.ID)); err != nil {
		rollbackIfOwned(tx, owned)
		return item, err
	}
	if err = commitIfOwned(tx, owned); err != nil {
		return item, err
	}
	return s.GetSSHServerGroup(item.ID)
}
// DeleteSSHServerGroup removes the group whose public ID equals id
// (surrounding whitespace trimmed). It returns "ssh server group is used by
// access profiles" while any ssh_access_profiles row still references the
// group. Deleting an ID that does not exist is not an error.
//
// The reference check and the DELETE run on the same transaction from
// s.begin, so the check cannot go stale between the two statements.
// Previously they ran as independent statements, and a profile could be
// attached in the gap. The transaction is committed or rolled back only when
// s.begin reports it as owned.
//
// NOTE(review): membership rows in ssh_server_group_members are not removed
// here. Confirm that the schema cascades them (or that they are cleaned up
// elsewhere).
func (s *Store) DeleteSSHServerGroup(id string) error {
	var tx *sql.Tx
	var owned bool
	var count int
	var err error
	id = strings.TrimSpace(id)
	tx, owned, err = s.begin()
	if err != nil {
		return err
	}
	err = tx.QueryRow(`SELECT COUNT(*) FROM ssh_access_profiles WHERE server_group_public_id = ?`, id).Scan(&count)
	if err != nil {
		rollbackIfOwned(tx, owned)
		return err
	}
	if count > 0 {
		rollbackIfOwned(tx, owned)
		return errors.New("ssh server group is used by access profiles")
	}
	if _, err = tx.Exec(`DELETE FROM ssh_server_groups WHERE public_id = ?`, id); err != nil {
		rollbackIfOwned(tx, owned)
		return err
	}
	return commitIfOwned(tx, owned)
}
// AddSSHServerGroupMember records serverID as a member of group groupID
// (both trimmed) and stamps created_at with the current UTC Unix time in
// seconds. The statement is INSERT OR IGNORE, so a row that hits a
// uniqueness conflict is skipped silently instead of returning an error.
// Presumably that happens when the membership already exists — NOTE(review):
// confirm the unique key in the schema.
func (s *Store) AddSSHServerGroupMember(groupID string, serverID string) error {
	createdAt := time.Now().UTC().Unix()
	group := strings.TrimSpace(groupID)
	server := strings.TrimSpace(serverID)
	_, err := s.Exec(`INSERT OR IGNORE INTO ssh_server_group_members (group_public_id, server_public_id, created_at) VALUES (?, ?, ?)`, group, server, createdAt)
	return err
}
// DeleteSSHServerGroupMember removes serverID from group groupID (both
// trimmed). Removing a membership that does not exist is not an error; only
// errors from the DELETE itself are returned.
func (s *Store) DeleteSSHServerGroupMember(groupID string, serverID string) error {
	if _, err := s.Exec(`DELETE FROM ssh_server_group_members WHERE group_public_id = ? AND server_public_id = ?`, strings.TrimSpace(groupID), strings.TrimSpace(serverID)); err != nil {
		return err
	}
	return nil
}
// SSHServerBelongsToGroup reports whether serverID is a member of group
// groupID (both trimmed). It returns false together with any error from the
// count query.
func (s *Store) SSHServerBelongsToGroup(groupID string, serverID string) (bool, error) {
	var matches int
	row := s.QueryRow(`SELECT COUNT(*) FROM ssh_server_group_members WHERE group_public_id = ? AND server_public_id = ?`, strings.TrimSpace(groupID), strings.TrimSpace(serverID))
	if err := row.Scan(&matches); err != nil {
		return false, err
	}
	return matches > 0, nil
}