package db import "database/sql" import "errors" import "strings" import "time" import "codit/internal/models" import "codit/internal/util" func scanSSHServerGroup(row interface{ Scan(dest ...any) error }, item *models.SSHServerGroup) error { return row.Scan(&item.ID, &item.Name, &item.Description, &item.Enabled, &item.CreatedByKind, &item.CreatedBySubjectID, &item.CreatedBySubjectName, &item.CreatedAt, &item.UpdatedAt) } func (s *Store) listSSHServerGroupMemberIDs(groupID string) ([]string, error) { var rows *sql.Rows var items []string var id string var err error rows, err = s.Query(`SELECT server_public_id FROM ssh_server_group_members WHERE group_public_id = ? ORDER BY server_public_id`, strings.TrimSpace(groupID)) if err != nil { return nil, err } defer rows.Close() for rows.Next() { err = rows.Scan(&id) if err != nil { return nil, err } items = append(items, id) } err = rows.Err() if err != nil { return nil, err } return items, nil } func (s *Store) ListSSHServerGroups() ([]models.SSHServerGroup, error) { var rows *sql.Rows var items []models.SSHServerGroup var item models.SSHServerGroup var err error rows, err = s.Query(`SELECT public_id, name, description, enabled, created_by_kind, created_by_subject_id, created_by_subject_name, created_at, updated_at FROM ssh_server_groups ORDER BY name`) if err != nil { return nil, err } defer rows.Close() for rows.Next() { err = scanSSHServerGroup(rows, &item) if err != nil { return nil, err } item.ServerIDs, err = s.listSSHServerGroupMemberIDs(item.ID) if err != nil { return nil, err } items = append(items, item) } err = rows.Err() if err != nil { return nil, err } return items, nil } func (s *Store) GetSSHServerGroup(id string) (models.SSHServerGroup, error) { var row *sql.Row var item models.SSHServerGroup var err error row = s.QueryRow(`SELECT public_id, name, description, enabled, created_by_kind, created_by_subject_id, created_by_subject_name, created_at, updated_at FROM ssh_server_groups WHERE public_id = ?`, strings.TrimSpace(id)) err = scanSSHServerGroup(row, &item) if err != nil { return item, err } item.ServerIDs, err = s.listSSHServerGroupMemberIDs(item.ID) if err != nil { return item, err } return item, nil } func (s *Store) ListSSHServersForGroup(groupID string) ([]models.SSHServer, error) { var rows *sql.Rows var items []models.SSHServer var item models.SSHServer var tagsJSON string var err error rows, err = s.Query(`SELECT s.public_id, s.name, s.host, s.port, s.description, s.tags_json, s.enabled, s.host_key_policy, s.created_by_kind, s.created_by_subject_id, s.created_by_subject_name, s.created_at, s.updated_at FROM ssh_server_group_members gm JOIN ssh_servers s ON s.public_id = gm.server_public_id WHERE gm.group_public_id = ? ORDER BY s.name`, strings.TrimSpace(groupID)) if err != nil { return nil, err } defer rows.Close() for rows.Next() { err = rows.Scan(&item.ID, &item.Name, &item.Host, &item.Port, &item.Description, &tagsJSON, &item.Enabled, &item.HostKeyPolicy, &item.CreatedByKind, &item.CreatedBySubjectID, &item.CreatedBySubjectName, &item.CreatedAt, &item.UpdatedAt) if err != nil { return nil, err } item.Tags, err = decodeStringList(tagsJSON) if err != nil { return nil, err } items = append(items, item) } err = rows.Err() if err != nil { return nil, err } return items, nil } func (s *Store) CreateSSHServerGroup(item models.SSHServerGroup) (models.SSHServerGroup, error) { var tx *sql.Tx var owned bool var err error var now int64 if strings.TrimSpace(item.Name) == "" { return item, errors.New("name is required") } if strings.TrimSpace(item.ID) == "" { item.ID, err = util.NewID() if err != nil { return item, err } } now = time.Now().UTC().Unix() item.CreatedAt = now item.UpdatedAt = now tx, owned, err = s.begin() if err != nil { return item, err } _, err = tx.Exec(`INSERT INTO ssh_server_groups (public_id, name, description, enabled, created_by_kind, created_by_subject_id, created_by_subject_name, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`, item.ID, strings.TrimSpace(item.Name), strings.TrimSpace(item.Description), item.Enabled, strings.TrimSpace(item.CreatedByKind), strings.TrimSpace(item.CreatedBySubjectID), strings.TrimSpace(item.CreatedBySubjectName), item.CreatedAt, item.UpdatedAt) if err != nil { rollbackIfOwned(tx, owned); return item, err } err = commitIfOwned(tx, owned) if err != nil { return item, err } return s.GetSSHServerGroup(item.ID) } func (s *Store) UpdateSSHServerGroup(item models.SSHServerGroup) (models.SSHServerGroup, error) { var tx *sql.Tx var owned bool var err error var now int64 if strings.TrimSpace(item.Name) == "" { return item, errors.New("name is required") } now = time.Now().UTC().Unix() tx, owned, err = s.begin() if err != nil { return item, err } _, err = tx.Exec(`UPDATE ssh_server_groups SET name = ?, description = ?, enabled = ?, updated_at = ? WHERE public_id = ?`, strings.TrimSpace(item.Name), strings.TrimSpace(item.Description), item.Enabled, now, strings.TrimSpace(item.ID)) if err != nil { rollbackIfOwned(tx, owned); return item, err } err = commitIfOwned(tx, owned) if err != nil { return item, err } return s.GetSSHServerGroup(item.ID) } func (s *Store) DeleteSSHServerGroup(id string) error { var row *sql.Row var count int var err error id = strings.TrimSpace(id) row = s.QueryRow(`SELECT COUNT(*) FROM ssh_access_profiles WHERE server_group_public_id = ?`, id) err = row.Scan(&count) if err != nil { return err } if count > 0 { return errors.New("ssh server group is used by access profiles") } _, err = s.Exec(`DELETE FROM ssh_server_groups WHERE public_id = ?`, id) return err } func (s *Store) AddSSHServerGroupMember(groupID string, serverID string) error { var now int64 var err error now = time.Now().UTC().Unix() _, err = s.Exec(`INSERT OR IGNORE INTO ssh_server_group_members (group_public_id, server_public_id, created_at) VALUES (?, ?, ?)`, strings.TrimSpace(groupID), strings.TrimSpace(serverID), now) return err } func (s *Store) DeleteSSHServerGroupMember(groupID string, serverID string) error { var err error _, err = s.Exec(`DELETE FROM ssh_server_group_members WHERE group_public_id = ? AND server_public_id = ?`, strings.TrimSpace(groupID), strings.TrimSpace(serverID)) return err } func (s *Store) SSHServerBelongsToGroup(groupID string, serverID string) (bool, error) { var row *sql.Row var count int var err error row = s.QueryRow(`SELECT COUNT(*) FROM ssh_server_group_members WHERE group_public_id = ? AND server_public_id = ?`, strings.TrimSpace(groupID), strings.TrimSpace(serverID)) err = row.Scan(&count) if err != nil { return false, err } return count > 0, nil }