package db import "context" import "database/sql" import "errors" import "strings" import "time" import "codit/internal/models" import "codit/internal/util" func scanSSHServerGroup(row interface{ Scan(dest ...any) error }, item *models.SSHServerGroup) error { var tagsJSON string var err error err = row.Scan(&item.ID, &item.Name, &item.Description, &item.Enabled, &tagsJSON, &item.CreatedByKind, &item.CreatedBySubjectID, &item.CreatedBySubjectName, &item.CreatedAt, &item.UpdatedAt) if err != nil { return err } item.Tags, err = decodeStringList(tagsJSON) return err } func (s *Store) listSSHServerGroupMemberIDs(groupID string) ([]string, error) { var rows *sql.Rows var items []string var id string var err error rows, err = s.Query(`SELECT s.public_id FROM ssh_server_group_members gm JOIN ssh_servers s ON s.id = gm.server_id JOIN ssh_server_groups g ON g.id = gm.group_id WHERE g.public_id = ? ORDER BY s.public_id`, strings.TrimSpace(groupID)) if err != nil { return nil, err } defer rows.Close() for rows.Next() { err = rows.Scan(&id) if err != nil { return nil, err } items = append(items, id) } err = rows.Err() if err != nil { return nil, err } return items, nil } func (s *Store) ListSSHServerGroups() ([]models.SSHServerGroup, error) { var rows *sql.Rows var items []models.SSHServerGroup var item models.SSHServerGroup var err error rows, err = s.Query(`SELECT g.public_id, g.name, g.description, g.enabled, g.tags_json, g.created_by_kind, COALESCE(cbu.public_id, cbp.public_id, ''), g.created_by_subject_name, g.created_at, g.updated_at FROM ssh_server_groups g LEFT JOIN users cbu ON cbu.id = g.created_by_user_id LEFT JOIN service_principals cbp ON cbp.id = g.created_by_principal_id ORDER BY g.name`) if err != nil { return nil, err } defer rows.Close() for rows.Next() { err = scanSSHServerGroup(rows, &item) if err != nil { return nil, err } item.ServerIDs, err = s.listSSHServerGroupMemberIDs(item.ID) if err != nil { return nil, err } items = append(items, item) } err = rows.Err() if err != nil { return nil, err } return items, nil } func (s *Store) GetSSHServerGroup(id string) (models.SSHServerGroup, error) { var row *sql.Row var item models.SSHServerGroup var err error row = s.QueryRow(`SELECT g.public_id, g.name, g.description, g.enabled, g.tags_json, g.created_by_kind, COALESCE(cbu.public_id, cbp.public_id, ''), g.created_by_subject_name, g.created_at, g.updated_at FROM ssh_server_groups g LEFT JOIN users cbu ON cbu.id = g.created_by_user_id LEFT JOIN service_principals cbp ON cbp.id = g.created_by_principal_id WHERE g.public_id = ?`, strings.TrimSpace(id)) err = scanSSHServerGroup(row, &item) if err != nil { return item, err } item.ServerIDs, err = s.listSSHServerGroupMemberIDs(item.ID) if err != nil { return item, err } return item, nil } func (s *Store) ListSSHServersForGroup(groupID string) ([]models.SSHServer, error) { var rows *sql.Rows var items []models.SSHServer var item models.SSHServer var tagsJSON string var err error rows, err = s.Query(`SELECT s.public_id, s.name, s.host, s.port, s.description, s.tags_json, s.enabled, s.host_key_policy, s.owner_scope, COALESCE(u.public_id, ''), s.created_by_kind, COALESCE(sbu.public_id, sbp.public_id, ''), s.created_by_subject_name, s.created_at, s.updated_at FROM ssh_server_group_members gm JOIN ssh_servers s ON s.id = gm.server_id JOIN ssh_server_groups g ON g.id = gm.group_id LEFT JOIN users u ON u.id = s.owner_user_id LEFT JOIN users sbu ON sbu.id = s.created_by_user_id LEFT JOIN service_principals sbp ON sbp.id = s.created_by_principal_id WHERE g.public_id = ? ORDER BY s.name`, strings.TrimSpace(groupID)) if err != nil { return nil, err } defer rows.Close() for rows.Next() { err = rows.Scan(&item.ID, &item.Name, &item.Host, &item.Port, &item.Description, &tagsJSON, &item.Enabled, &item.HostKeyPolicy, &item.OwnerScope, &item.OwnerUserID, &item.CreatedByKind, &item.CreatedBySubjectID, &item.CreatedBySubjectName, &item.CreatedAt, &item.UpdatedAt) if err != nil { return nil, err } item.Tags, err = decodeStringList(tagsJSON) if err != nil { return nil, err } items = append(items, item) } err = rows.Err() if err != nil { return nil, err } return items, nil } func (s *Store) CreateSSHServerGroup(ctx context.Context, item models.SSHServerGroup) (models.SSHServerGroup, error) { var tx txExecutor var owned bool var err error var now int64 var tagsJSON string if strings.TrimSpace(item.Name) == "" { return item, errors.New("name is required") } if strings.TrimSpace(item.ID) == "" { item.ID, err = util.NewID() if err != nil { return item, err } } item.Tags = normalizeStringList(item.Tags) tagsJSON, err = encodeStringList(item.Tags) if err != nil { return item, err } now = time.Now().UTC().Unix() item.CreatedAt = now item.UpdatedAt = now tx, owned, err = s.beginContext(ctx) if err != nil { return item, err } _, err = tx.Exec(`INSERT INTO ssh_server_groups (public_id, name, description, enabled, tags_json, created_by_kind, created_by_user_id, created_by_principal_id, created_by_subject_name, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, (SELECT id FROM users WHERE public_id = ? AND ? = 'user'), (SELECT id FROM service_principals WHERE public_id = ? AND ? = 'service_principal'), ?, ?, ?)`, item.ID, strings.TrimSpace(item.Name), strings.TrimSpace(item.Description), item.Enabled, tagsJSON, strings.TrimSpace(item.CreatedByKind), strings.TrimSpace(item.CreatedBySubjectID), strings.TrimSpace(item.CreatedByKind), strings.TrimSpace(item.CreatedBySubjectID), strings.TrimSpace(item.CreatedByKind), strings.TrimSpace(item.CreatedBySubjectName), item.CreatedAt, item.UpdatedAt) if err != nil { rollbackIfOwned(tx, owned); return item, err } err = commitIfOwned(tx, owned) if err != nil { return item, err } return s.GetSSHServerGroup(item.ID) } func (s *Store) UpdateSSHServerGroup(ctx context.Context, item models.SSHServerGroup) (models.SSHServerGroup, error) { var tx txExecutor var owned bool var err error var now int64 var tagsJSON string if strings.TrimSpace(item.Name) == "" { return item, errors.New("name is required") } item.Tags = normalizeStringList(item.Tags) tagsJSON, err = encodeStringList(item.Tags) if err != nil { return item, err } now = time.Now().UTC().Unix() tx, owned, err = s.beginContext(ctx) if err != nil { return item, err } _, err = tx.Exec(`UPDATE ssh_server_groups SET name = ?, description = ?, enabled = ?, tags_json = ?, updated_at = ? WHERE public_id = ?`, strings.TrimSpace(item.Name), strings.TrimSpace(item.Description), item.Enabled, tagsJSON, now, strings.TrimSpace(item.ID)) if err != nil { rollbackIfOwned(tx, owned); return item, err } err = commitIfOwned(tx, owned) if err != nil { return item, err } return s.GetSSHServerGroup(item.ID) } func (s *Store) DeleteSSHServerGroup(ctx context.Context, id string) error { var tx txExecutor var owned bool var count int var err error id = strings.TrimSpace(id) tx, owned, err = s.beginImmediateContext(ctx) if err != nil { return err } err = tx.QueryRow(`SELECT COUNT(*) FROM ssh_access_profiles p JOIN ssh_server_groups g ON g.id = p.server_group_id WHERE g.public_id = ?`, id).Scan(&count) if err != nil { rollbackIfOwned(tx, owned) return err } if count > 0 { rollbackIfOwned(tx, owned) return errors.New("ssh server group is used by access profiles") } _, err = tx.Exec(`DELETE FROM ssh_server_groups WHERE public_id = ?`, id) if err != nil { rollbackIfOwned(tx, owned) return err } return commitIfOwned(tx, owned) } func (s *Store) AddSSHServerGroupMember(groupID string, serverID string) error { var now int64 var err error now = time.Now().UTC().Unix() _, err = s.Exec(`INSERT OR IGNORE INTO ssh_server_group_members (group_id, server_id, created_at) SELECT g.id, srv.id, ? FROM ssh_server_groups g JOIN ssh_servers srv ON srv.public_id = ? WHERE g.public_id = ?`, now, strings.TrimSpace(serverID), strings.TrimSpace(groupID)) return err } func (s *Store) DeleteSSHServerGroupMember(groupID string, serverID string) error { var err error _, err = s.Exec(`DELETE FROM ssh_server_group_members WHERE group_id = (SELECT id FROM ssh_server_groups WHERE public_id = ?) AND server_id = (SELECT id FROM ssh_servers WHERE public_id = ?)`, strings.TrimSpace(groupID), strings.TrimSpace(serverID)) return err } func (s *Store) SSHServerBelongsToGroup(groupID string, serverID string) (bool, error) { var row *sql.Row var count int var err error row = s.QueryRow(`SELECT COUNT(*) FROM ssh_server_group_members WHERE group_id = (SELECT id FROM ssh_server_groups WHERE public_id = ?) AND server_id = (SELECT id FROM ssh_servers WHERE public_id = ?)`, strings.TrimSpace(groupID), strings.TrimSpace(serverID)) err = row.Scan(&count) if err != nil { return false, err } return count > 0, nil }