package db

import (
	"database/sql"
	"errors"
	"strings"
	"time"

	"codit/internal/models"
	"codit/internal/util"
)

// sshCredentialColumns is the ssh_credentials column list shared by every
// SELECT and the INSERT in this file. Its order must match the destination
// order in scanSSHCredential exactly.
const sshCredentialColumns = `public_id, name, description, type, secret_public_id, public_key, fingerprint, enabled, owner_scope, owner_user_public_id, created_by_kind, created_by_subject_id, created_by_subject_name, created_at, updated_at`

// scanSSHCredential scans one row selected with sshCredentialColumns into item.
// row may be a *sql.Row or *sql.Rows (anything with a matching Scan method).
func scanSSHCredential(row interface{ Scan(dest ...any) error }, item *models.SSHCredential) error {
	return row.Scan(
		&item.ID,
		&item.Name,
		&item.Description,
		&item.Type,
		&item.SecretID,
		&item.PublicKey,
		&item.Fingerprint,
		&item.Enabled,
		&item.OwnerScope,
		&item.OwnerUserID,
		&item.CreatedByKind,
		&item.CreatedBySubjectID,
		&item.CreatedBySubjectName,
		&item.CreatedAt,
		&item.UpdatedAt,
	)
}

// normalizeSSHCredentialOwner trims the ownership fields and coerces them into
// one of two shapes: OwnerScope "user" with the (trimmed) OwnerUserID kept, or
// OwnerScope "admin" with OwnerUserID cleared. Any scope other than exactly
// "user" after trimming — including the empty string — becomes "admin".
func normalizeSSHCredentialOwner(item models.SSHCredential) models.SSHCredential {
	item.OwnerScope = strings.TrimSpace(item.OwnerScope)
	item.OwnerUserID = strings.TrimSpace(item.OwnerUserID)
	if item.OwnerScope != "user" {
		item.OwnerScope = "admin"
		item.OwnerUserID = ""
	}
	return item
}

// querySSHCredentials runs a query whose result set is sshCredentialColumns
// and scans every row. It returns a nil slice when no rows match.
func (s *Store) querySSHCredentials(query string, args ...any) ([]models.SSHCredential, error) {
	rows, err := s.Query(query, args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var items []models.SSHCredential
	for rows.Next() {
		var item models.SSHCredential
		if err := scanSSHCredential(rows, &item); err != nil {
			return nil, err
		}
		items = append(items, item)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return items, nil
}

// ListSSHCredentials returns every admin-scoped SSH credential, ordered by
// name. The result is nil when there are none.
func (s *Store) ListSSHCredentials() ([]models.SSHCredential, error) {
	return s.querySSHCredentials(`SELECT ` + sshCredentialColumns + ` FROM ssh_credentials WHERE owner_scope = 'admin' ORDER BY name`)
}

// ListSSHCredentialsForUser returns the user-scoped SSH credentials owned by
// userID (surrounding whitespace ignored), ordered by name. The result is nil
// when there are none.
func (s *Store) ListSSHCredentialsForUser(userID string) ([]models.SSHCredential, error) {
	return s.querySSHCredentials(
		`SELECT `+sshCredentialColumns+` FROM ssh_credentials WHERE owner_scope = 'user' AND owner_user_public_id = ? ORDER BY name`,
		strings.TrimSpace(userID),
	)
}

// GetSSHCredential returns the credential with the given public ID regardless
// of owner scope. It returns sql.ErrNoRows when no such credential exists.
func (s *Store) GetSSHCredential(id string) (models.SSHCredential, error) {
	var item models.SSHCredential
	row := s.QueryRow(`SELECT `+sshCredentialColumns+` FROM ssh_credentials WHERE public_id = ?`, strings.TrimSpace(id))
	err := scanSSHCredential(row, &item)
	return item, err
}

// GetSSHCredentialForUser returns the credential with the given public ID only
// if it is user-scoped and owned by userID. Otherwise it returns sql.ErrNoRows.
func (s *Store) GetSSHCredentialForUser(userID string, id string) (models.SSHCredential, error) {
	var item models.SSHCredential
	row := s.QueryRow(
		`SELECT `+sshCredentialColumns+` FROM ssh_credentials WHERE public_id = ? AND owner_scope = 'user' AND owner_user_public_id = ?`,
		strings.TrimSpace(id),
		strings.TrimSpace(userID),
	)
	err := scanSSHCredential(row, &item)
	return item, err
}

// CreateSSHCredential validates item, stores secret as its backing secret, and
// inserts the credential row. It then returns the credential as re-read from
// the database.
//
// Validation and defaults:
//   - Name must be non-blank.
//   - Type defaults to "private_key" and is the only supported type.
//   - Ownership is normalized; user-scoped credentials require OwnerUserID.
//   - An ID is generated with util.NewID when none is supplied.
//
// The secret and the credential row are written through the same transaction
// obtained from s.begin. Commit and rollback only happen when s.begin reports
// that this call owns the transaction — presumably an enclosing transaction is
// in charge otherwise (confirm against s.begin).
func (s *Store) CreateSSHCredential(item models.SSHCredential, secret models.SSHSecret) (models.SSHCredential, error) {
	var err error

	if strings.TrimSpace(item.Name) == "" {
		return item, errors.New("name is required")
	}
	item.Type = strings.TrimSpace(item.Type)
	if item.Type == "" {
		item.Type = "private_key"
	}
	if item.Type != "private_key" {
		return item, errors.New("unsupported credential type")
	}
	item = normalizeSSHCredentialOwner(item)
	if item.OwnerScope == "user" && item.OwnerUserID == "" {
		return item, errors.New("owner_user_id is required")
	}

	// Store the ID trimmed. GetSSHCredential trims its argument, so an
	// untrimmed stored ID could never be read back below.
	item.ID = strings.TrimSpace(item.ID)
	if item.ID == "" {
		item.ID, err = util.NewID()
		if err != nil {
			return item, err
		}
	}

	tx, owned, err := s.begin()
	if err != nil {
		return item, err
	}

	secret.Kind = "private_key"
	secret.CreatedByKind = item.CreatedByKind
	secret.CreatedBySubjectID = item.CreatedBySubjectID
	secret.CreatedBySubjectName = item.CreatedBySubjectName
	secret, err = createSSHSecretTx(tx, secret)
	if err != nil {
		rollbackIfOwned(tx, owned)
		return item, err
	}

	now := time.Now().UTC().Unix()
	item.CreatedAt = now
	item.UpdatedAt = now
	item.SecretID = secret.ID

	_, err = tx.Exec(
		`INSERT INTO ssh_credentials (`+sshCredentialColumns+`) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
		item.ID,
		strings.TrimSpace(item.Name),
		strings.TrimSpace(item.Description),
		item.Type,
		strings.TrimSpace(item.SecretID),
		strings.TrimSpace(item.PublicKey),
		strings.TrimSpace(item.Fingerprint),
		item.Enabled,
		item.OwnerScope,
		item.OwnerUserID,
		strings.TrimSpace(item.CreatedByKind),
		strings.TrimSpace(item.CreatedBySubjectID),
		strings.TrimSpace(item.CreatedBySubjectName),
		item.CreatedAt,
		item.UpdatedAt,
	)
	if err != nil {
		rollbackIfOwned(tx, owned)
		return item, err
	}

	if err = commitIfOwned(tx, owned); err != nil {
		return item, err
	}
	return s.GetSSHCredential(item.ID)
}

// UpdateSSHCredential updates the name, description, and enabled flag of the
// credential identified by item.ID, and sets updated_at to the current time.
// Type, key material, secret, and ownership are never changed here. It returns
// the credential as re-read from the database. If no row matches item.ID, that
// re-read fails with sql.ErrNoRows.
func (s *Store) UpdateSSHCredential(item models.SSHCredential) (models.SSHCredential, error) {
	if strings.TrimSpace(item.Name) == "" {
		return item, errors.New("name is required")
	}
	now := time.Now().UTC().Unix()
	_, err := s.Exec(
		`UPDATE ssh_credentials SET name = ?, description = ?, enabled = ?, updated_at = ? WHERE public_id = ?`,
		strings.TrimSpace(item.Name),
		strings.TrimSpace(item.Description),
		item.Enabled,
		now,
		strings.TrimSpace(item.ID),
	)
	if err != nil {
		return item, err
	}
	return s.GetSSHCredential(item.ID)
}

// DeleteSSHCredential deletes the credential with the given public ID and its
// backing ssh_secrets row.
//
// Error cases:
//   - It fails if any ssh_access_profiles row still references the credential.
//   - It returns sql.ErrNoRows if the credential does not exist.
//
// The usage check and both deletes run in a single transaction from s.begin.
// This way a failure part-way (for example on the secret delete) does not leave
// an orphaned secret or a dangling credential.
func (s *Store) DeleteSSHCredential(id string) error {
	id = strings.TrimSpace(id)

	tx, owned, err := s.begin()
	if err != nil {
		return err
	}

	var count int
	err = tx.QueryRow(`SELECT COUNT(*) FROM ssh_access_profiles WHERE ssh_credential_public_id = ?`, id).Scan(&count)
	if err != nil {
		rollbackIfOwned(tx, owned)
		return err
	}
	if count > 0 {
		rollbackIfOwned(tx, owned)
		return errors.New("ssh credential is used by access profiles")
	}

	// Only the secret reference is needed. Reading it through tx (rather than
	// via s.GetSSHCredential) keeps every statement inside the transaction.
	var secretID string
	err = tx.QueryRow(`SELECT secret_public_id FROM ssh_credentials WHERE public_id = ?`, id).Scan(&secretID)
	if err != nil {
		rollbackIfOwned(tx, owned)
		return err
	}

	if _, err = tx.Exec(`DELETE FROM ssh_credentials WHERE public_id = ?`, id); err != nil {
		rollbackIfOwned(tx, owned)
		return err
	}
	if _, err = tx.Exec(`DELETE FROM ssh_secrets WHERE public_id = ?`, strings.TrimSpace(secretID)); err != nil {
		rollbackIfOwned(tx, owned)
		return err
	}
	return commitIfOwned(tx, owned)
}