174 lines
6.6 KiB
Go
174 lines
6.6 KiB
Go
package db
|
|
|
|
import "database/sql"
|
|
import "errors"
|
|
import "strings"
|
|
import "time"
|
|
|
|
import "codit/internal/models"
|
|
import "codit/internal/util"
|
|
|
|
func scanSSHServerGroup(row interface{ Scan(dest ...any) error }, item *models.SSHServerGroup) error {
|
|
return row.Scan(&item.ID, &item.Name, &item.Description, &item.Enabled, &item.CreatedByKind, &item.CreatedBySubjectID, &item.CreatedBySubjectName, &item.CreatedAt, &item.UpdatedAt)
|
|
}
|
|
|
|
func (s *Store) listSSHServerGroupMemberIDs(groupID string) ([]string, error) {
|
|
var rows *sql.Rows
|
|
var items []string
|
|
var id string
|
|
var err error
|
|
|
|
rows, err = s.Query(`SELECT server_public_id FROM ssh_server_group_members WHERE group_public_id = ? ORDER BY server_public_id`, strings.TrimSpace(groupID))
|
|
if err != nil { return nil, err }
|
|
defer rows.Close()
|
|
for rows.Next() {
|
|
err = rows.Scan(&id)
|
|
if err != nil { return nil, err }
|
|
items = append(items, id)
|
|
}
|
|
err = rows.Err()
|
|
if err != nil { return nil, err }
|
|
return items, nil
|
|
}
|
|
|
|
func (s *Store) ListSSHServerGroups() ([]models.SSHServerGroup, error) {
|
|
var rows *sql.Rows
|
|
var items []models.SSHServerGroup
|
|
var item models.SSHServerGroup
|
|
var err error
|
|
|
|
rows, err = s.Query(`SELECT public_id, name, description, enabled, created_by_kind, created_by_subject_id, created_by_subject_name, created_at, updated_at FROM ssh_server_groups ORDER BY name`)
|
|
if err != nil { return nil, err }
|
|
defer rows.Close()
|
|
for rows.Next() {
|
|
err = scanSSHServerGroup(rows, &item)
|
|
if err != nil { return nil, err }
|
|
item.ServerIDs, err = s.listSSHServerGroupMemberIDs(item.ID)
|
|
if err != nil { return nil, err }
|
|
items = append(items, item)
|
|
}
|
|
err = rows.Err()
|
|
if err != nil { return nil, err }
|
|
return items, nil
|
|
}
|
|
|
|
func (s *Store) GetSSHServerGroup(id string) (models.SSHServerGroup, error) {
|
|
var row *sql.Row
|
|
var item models.SSHServerGroup
|
|
var err error
|
|
|
|
row = s.QueryRow(`SELECT public_id, name, description, enabled, created_by_kind, created_by_subject_id, created_by_subject_name, created_at, updated_at FROM ssh_server_groups WHERE public_id = ?`, strings.TrimSpace(id))
|
|
err = scanSSHServerGroup(row, &item)
|
|
if err != nil { return item, err }
|
|
item.ServerIDs, err = s.listSSHServerGroupMemberIDs(item.ID)
|
|
if err != nil { return item, err }
|
|
return item, nil
|
|
}
|
|
|
|
func (s *Store) ListSSHServersForGroup(groupID string) ([]models.SSHServer, error) {
|
|
var rows *sql.Rows
|
|
var items []models.SSHServer
|
|
var item models.SSHServer
|
|
var tagsJSON string
|
|
var err error
|
|
|
|
rows, err = s.Query(`SELECT s.public_id, s.name, s.host, s.port, s.description, s.tags_json, s.enabled, s.host_key_policy, s.created_by_kind, s.created_by_subject_id, s.created_by_subject_name, s.created_at, s.updated_at
|
|
FROM ssh_server_group_members gm
|
|
JOIN ssh_servers s ON s.public_id = gm.server_public_id
|
|
WHERE gm.group_public_id = ?
|
|
ORDER BY s.name`, strings.TrimSpace(groupID))
|
|
if err != nil { return nil, err }
|
|
defer rows.Close()
|
|
for rows.Next() {
|
|
err = rows.Scan(&item.ID, &item.Name, &item.Host, &item.Port, &item.Description, &tagsJSON, &item.Enabled, &item.HostKeyPolicy, &item.CreatedByKind, &item.CreatedBySubjectID, &item.CreatedBySubjectName, &item.CreatedAt, &item.UpdatedAt)
|
|
if err != nil { return nil, err }
|
|
item.Tags, err = decodeStringList(tagsJSON)
|
|
if err != nil { return nil, err }
|
|
items = append(items, item)
|
|
}
|
|
err = rows.Err()
|
|
if err != nil { return nil, err }
|
|
return items, nil
|
|
}
|
|
|
|
func (s *Store) CreateSSHServerGroup(item models.SSHServerGroup) (models.SSHServerGroup, error) {
|
|
var tx *sql.Tx
|
|
var owned bool
|
|
var err error
|
|
var now int64
|
|
|
|
if strings.TrimSpace(item.Name) == "" { return item, errors.New("name is required") }
|
|
if strings.TrimSpace(item.ID) == "" {
|
|
item.ID, err = util.NewID()
|
|
if err != nil { return item, err }
|
|
}
|
|
now = time.Now().UTC().Unix()
|
|
item.CreatedAt = now
|
|
item.UpdatedAt = now
|
|
tx, owned, err = s.begin()
|
|
if err != nil { return item, err }
|
|
_, err = tx.Exec(`INSERT INTO ssh_server_groups (public_id, name, description, enabled, created_by_kind, created_by_subject_id, created_by_subject_name, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`, item.ID, strings.TrimSpace(item.Name), strings.TrimSpace(item.Description), item.Enabled, strings.TrimSpace(item.CreatedByKind), strings.TrimSpace(item.CreatedBySubjectID), strings.TrimSpace(item.CreatedBySubjectName), item.CreatedAt, item.UpdatedAt)
|
|
if err != nil { rollbackIfOwned(tx, owned); return item, err }
|
|
err = commitIfOwned(tx, owned)
|
|
if err != nil { return item, err }
|
|
return s.GetSSHServerGroup(item.ID)
|
|
}
|
|
|
|
func (s *Store) UpdateSSHServerGroup(item models.SSHServerGroup) (models.SSHServerGroup, error) {
|
|
var tx *sql.Tx
|
|
var owned bool
|
|
var err error
|
|
var now int64
|
|
|
|
if strings.TrimSpace(item.Name) == "" { return item, errors.New("name is required") }
|
|
now = time.Now().UTC().Unix()
|
|
tx, owned, err = s.begin()
|
|
if err != nil { return item, err }
|
|
_, err = tx.Exec(`UPDATE ssh_server_groups SET name = ?, description = ?, enabled = ?, updated_at = ? WHERE public_id = ?`, strings.TrimSpace(item.Name), strings.TrimSpace(item.Description), item.Enabled, now, strings.TrimSpace(item.ID))
|
|
if err != nil { rollbackIfOwned(tx, owned); return item, err }
|
|
err = commitIfOwned(tx, owned)
|
|
if err != nil { return item, err }
|
|
return s.GetSSHServerGroup(item.ID)
|
|
}
|
|
|
|
func (s *Store) DeleteSSHServerGroup(id string) error {
|
|
var row *sql.Row
|
|
var count int
|
|
var err error
|
|
|
|
id = strings.TrimSpace(id)
|
|
row = s.QueryRow(`SELECT COUNT(*) FROM ssh_access_profiles WHERE server_group_public_id = ?`, id)
|
|
err = row.Scan(&count)
|
|
if err != nil { return err }
|
|
if count > 0 { return errors.New("ssh server group is used by access profiles") }
|
|
_, err = s.Exec(`DELETE FROM ssh_server_groups WHERE public_id = ?`, id)
|
|
return err
|
|
}
|
|
|
|
func (s *Store) AddSSHServerGroupMember(groupID string, serverID string) error {
|
|
var now int64
|
|
var err error
|
|
|
|
now = time.Now().UTC().Unix()
|
|
_, err = s.Exec(`INSERT OR IGNORE INTO ssh_server_group_members (group_public_id, server_public_id, created_at) VALUES (?, ?, ?)`, strings.TrimSpace(groupID), strings.TrimSpace(serverID), now)
|
|
return err
|
|
}
|
|
|
|
func (s *Store) DeleteSSHServerGroupMember(groupID string, serverID string) error {
|
|
var err error
|
|
|
|
_, err = s.Exec(`DELETE FROM ssh_server_group_members WHERE group_public_id = ? AND server_public_id = ?`, strings.TrimSpace(groupID), strings.TrimSpace(serverID))
|
|
return err
|
|
}
|
|
|
|
func (s *Store) SSHServerBelongsToGroup(groupID string, serverID string) (bool, error) {
|
|
var row *sql.Row
|
|
var count int
|
|
var err error
|
|
|
|
row = s.QueryRow(`SELECT COUNT(*) FROM ssh_server_group_members WHERE group_public_id = ? AND server_public_id = ?`, strings.TrimSpace(groupID), strings.TrimSpace(serverID))
|
|
err = row.Scan(&count)
|
|
if err != nil { return false, err }
|
|
return count > 0, nil
|
|
}
|