Files
codit/backend/internal/db/ssh_credentials.go

188 lines
6.8 KiB
Go

package db
import "database/sql"
import "errors"
import "strings"
import "time"
import "codit/internal/models"
import "codit/internal/util"
// scanSSHCredential copies one ssh_credentials row into item.
//
// row is anything with a database/sql-style Scan method, so the same helper
// serves both single-row and multi-row queries. The destination order below
// mirrors the column order used by the SELECT statements in this file:
// public_id, name, description, type, secret_public_id, public_key,
// fingerprint, enabled, owner_scope, owner_user_public_id, created_by_kind,
// created_by_subject_id, created_by_subject_name, created_at, updated_at.
//
// The error returned by row.Scan is passed through unchanged.
func scanSSHCredential(row interface{ Scan(dest ...any) error }, item *models.SSHCredential) error {
	dest := []any{
		&item.ID,
		&item.Name,
		&item.Description,
		&item.Type,
		&item.SecretID,
		&item.PublicKey,
		&item.Fingerprint,
		&item.Enabled,
		&item.OwnerScope,
		&item.OwnerUserID,
		&item.CreatedByKind,
		&item.CreatedBySubjectID,
		&item.CreatedBySubjectName,
		&item.CreatedAt,
		&item.UpdatedAt,
	}
	return row.Scan(dest...)
}
// normalizeSSHCredentialOwner returns a copy of item with canonical ownership
// fields.
//
// If the whitespace-trimmed OwnerScope is exactly "user", the scope is kept
// and OwnerUserID is trimmed. Any other scope, including an empty one,
// becomes "admin" and OwnerUserID is cleared.
func normalizeSSHCredentialOwner(item models.SSHCredential) models.SSHCredential {
	scope := strings.TrimSpace(item.OwnerScope)
	if scope == "user" {
		item.OwnerScope = scope
		item.OwnerUserID = strings.TrimSpace(item.OwnerUserID)
		return item
	}

	// Everything that is not explicitly user-owned is treated as admin-owned.
	item.OwnerScope = "admin"
	item.OwnerUserID = ""
	return item
}
// ListSSHCredentials returns every credential whose owner_scope is 'admin',
// ordered by name.
//
// The result is a nil slice when no rows match. Any query, scan, or row
// iteration error is returned as-is together with a nil slice.
func (s *Store) ListSSHCredentials() ([]models.SSHCredential, error) {
	rows, err := s.Query(`SELECT public_id, name, description, type, secret_public_id, public_key, fingerprint, enabled, owner_scope, owner_user_public_id, created_by_kind, created_by_subject_id, created_by_subject_name, created_at, updated_at
FROM ssh_credentials
WHERE owner_scope = 'admin'
ORDER BY name`)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var items []models.SSHCredential
	for rows.Next() {
		var cred models.SSHCredential
		if err := scanSSHCredential(rows, &cred); err != nil {
			return nil, err
		}
		items = append(items, cred)
	}
	// rows.Next returning false can also mean iteration failed; check for that.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return items, nil
}
// ListSSHCredentialsForUser returns the credentials whose owner_scope is
// 'user' and whose owner_user_public_id equals userID (whitespace-trimmed),
// ordered by name.
//
// The result is a nil slice when no rows match. Any query, scan, or row
// iteration error is returned as-is together with a nil slice.
func (s *Store) ListSSHCredentialsForUser(userID string) ([]models.SSHCredential, error) {
	owner := strings.TrimSpace(userID)
	rows, err := s.Query(`SELECT public_id, name, description, type, secret_public_id, public_key, fingerprint, enabled, owner_scope, owner_user_public_id, created_by_kind, created_by_subject_id, created_by_subject_name, created_at, updated_at
FROM ssh_credentials
WHERE owner_scope = 'user' AND owner_user_public_id = ?
ORDER BY name`, owner)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var items []models.SSHCredential
	for rows.Next() {
		var cred models.SSHCredential
		if err := scanSSHCredential(rows, &cred); err != nil {
			return nil, err
		}
		items = append(items, cred)
	}
	// rows.Next returning false can also mean iteration failed; check for that.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return items, nil
}
// GetSSHCredential loads the credential whose public_id equals id
// (whitespace-trimmed), regardless of owner scope.
//
// The scan error is returned unchanged alongside whatever was scanned into the
// result; when no row matches, that is the error Scan reports for an empty
// result (sql.ErrNoRows for a *sql.Row).
func (s *Store) GetSSHCredential(id string) (models.SSHCredential, error) {
	const query = `SELECT public_id, name, description, type, secret_public_id, public_key, fingerprint, enabled, owner_scope, owner_user_public_id, created_by_kind, created_by_subject_id, created_by_subject_name, created_at, updated_at
FROM ssh_credentials
WHERE public_id = ?`

	var cred models.SSHCredential
	err := scanSSHCredential(s.QueryRow(query, strings.TrimSpace(id)), &cred)
	return cred, err
}
// GetSSHCredentialForUser loads the credential whose public_id equals id, but
// only if it is user-scoped and owned by userID. Both arguments are
// whitespace-trimmed before being bound to the query.
//
// The scan error is returned unchanged alongside whatever was scanned into the
// result; when no row matches (wrong id or wrong owner), that is the error
// Scan reports for an empty result (sql.ErrNoRows for a *sql.Row).
func (s *Store) GetSSHCredentialForUser(userID string, id string) (models.SSHCredential, error) {
	const query = `SELECT public_id, name, description, type, secret_public_id, public_key, fingerprint, enabled, owner_scope, owner_user_public_id, created_by_kind, created_by_subject_id, created_by_subject_name, created_at, updated_at
FROM ssh_credentials
WHERE public_id = ? AND owner_scope = 'user' AND owner_user_public_id = ?`

	var cred models.SSHCredential
	err := scanSSHCredential(s.QueryRow(query, strings.TrimSpace(id), strings.TrimSpace(userID)), &cred)
	return cred, err
}
// CreateSSHCredential stores a new SSH credential together with its secret and
// returns the credential as re-read from the database.
//
// Validation and normalization, in order:
//   - item.Name must be non-blank ("name is required").
//   - A blank item.Type defaults to "private_key"; any other type is rejected
//     ("unsupported credential type").
//   - Ownership is normalized with normalizeSSHCredentialOwner; a user-scoped
//     credential must carry an owner id ("owner_user_id is required").
//   - item.ID is whitespace-trimmed; if it is then empty a new id is generated
//     with util.NewID.
//
// The secret is created through createSSHSecretTx and the credential row is
// inserted in the same transaction. secret.Kind is forced to "private_key" and
// its created-by fields are copied from item. CreatedAt and UpdatedAt are set
// to the current UTC Unix time.
//
// On failure the (possibly partially updated) item is returned with the error.
func (s *Store) CreateSSHCredential(item models.SSHCredential, secret models.SSHSecret) (models.SSHCredential, error) {
	var err error

	if strings.TrimSpace(item.Name) == "" {
		return item, errors.New("name is required")
	}
	if strings.TrimSpace(item.Type) == "" {
		item.Type = "private_key"
	}
	if strings.TrimSpace(item.Type) != "private_key" {
		return item, errors.New("unsupported credential type")
	}
	item = normalizeSSHCredentialOwner(item)
	if item.OwnerScope == "user" && item.OwnerUserID == "" {
		return item, errors.New("owner_user_id is required")
	}

	// Every lookup in this file trims the id before matching on public_id, so
	// the stored id must be trimmed as well. Otherwise a caller-supplied id
	// with surrounding whitespace would create a row that the re-read below
	// (and any later Get/Update/Delete) could never find.
	item.ID = strings.TrimSpace(item.ID)
	if item.ID == "" {
		item.ID, err = util.NewID()
		if err != nil {
			return item, err
		}
	}

	tx, owned, err := s.begin()
	if err != nil {
		return item, err
	}

	secret.Kind = "private_key"
	secret.CreatedByKind = item.CreatedByKind
	secret.CreatedBySubjectID = item.CreatedBySubjectID
	secret.CreatedBySubjectName = item.CreatedBySubjectName
	secret, err = createSSHSecretTx(tx, secret)
	if err != nil {
		rollbackIfOwned(tx, owned)
		return item, err
	}

	now := time.Now().UTC().Unix()
	item.CreatedAt = now
	item.UpdatedAt = now
	item.SecretID = secret.ID

	_, err = tx.Exec(`INSERT INTO ssh_credentials (public_id, name, description, type, secret_public_id, public_key, fingerprint, enabled, owner_scope, owner_user_public_id, created_by_kind, created_by_subject_id, created_by_subject_name, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
		item.ID,
		strings.TrimSpace(item.Name),
		strings.TrimSpace(item.Description),
		strings.TrimSpace(item.Type),
		strings.TrimSpace(item.SecretID),
		strings.TrimSpace(item.PublicKey),
		strings.TrimSpace(item.Fingerprint),
		item.Enabled,
		strings.TrimSpace(item.OwnerScope),
		strings.TrimSpace(item.OwnerUserID),
		strings.TrimSpace(item.CreatedByKind),
		strings.TrimSpace(item.CreatedBySubjectID),
		strings.TrimSpace(item.CreatedBySubjectName),
		item.CreatedAt,
		item.UpdatedAt,
	)
	if err != nil {
		rollbackIfOwned(tx, owned)
		return item, err
	}

	if err = commitIfOwned(tx, owned); err != nil {
		return item, err
	}

	// Re-read so the caller sees exactly what was persisted (trimmed values).
	return s.GetSSHCredential(item.ID)
}
// UpdateSSHCredential updates the editable fields of an existing credential:
// name, description, and enabled. updated_at is set to the current UTC Unix
// time. The row is selected by item.ID (whitespace-trimmed); no other columns
// are touched.
//
// A blank name is rejected with "name is required". On success the credential
// is re-read with GetSSHCredential and returned; if the id matches no row, the
// error comes from that re-read.
func (s *Store) UpdateSSHCredential(item models.SSHCredential) (models.SSHCredential, error) {
	name := strings.TrimSpace(item.Name)
	if name == "" {
		return item, errors.New("name is required")
	}

	updatedAt := time.Now().UTC().Unix()
	if _, err := s.Exec(`UPDATE ssh_credentials SET name = ?, description = ?, enabled = ?, updated_at = ? WHERE public_id = ?`,
		name,
		strings.TrimSpace(item.Description),
		item.Enabled,
		updatedAt,
		strings.TrimSpace(item.ID),
	); err != nil {
		return item, err
	}

	return s.GetSSHCredential(item.ID)
}
// DeleteSSHCredential removes the credential identified by id
// (whitespace-trimmed) together with the secret it references.
//
// The deletion is refused with "ssh credential is used by access profiles"
// while any ssh_access_profiles row still points at the credential. If the
// credential does not exist, the error from scanning the empty result is
// returned (sql.ErrNoRows).
//
// The usage check, the credential delete, and the secret delete all run inside
// one transaction, following the same begin / rollbackIfOwned / commitIfOwned
// pattern as CreateSSHCredential. This keeps a failed secret delete from
// leaving the credential removed with its secret orphaned, and keeps the usage
// check and the delete consistent with each other.
func (s *Store) DeleteSSHCredential(id string) error {
	var count int
	var secretID string

	id = strings.TrimSpace(id)

	tx, owned, err := s.begin()
	if err != nil {
		return err
	}

	err = tx.QueryRow(`SELECT COUNT(*) FROM ssh_access_profiles WHERE ssh_credential_public_id = ?`, id).Scan(&count)
	if err != nil {
		rollbackIfOwned(tx, owned)
		return err
	}
	if count > 0 {
		rollbackIfOwned(tx, owned)
		return errors.New("ssh credential is used by access profiles")
	}

	// Only the secret reference is needed here; reading it through the
	// transaction also confirms that the credential exists.
	err = tx.QueryRow(`SELECT secret_public_id FROM ssh_credentials WHERE public_id = ?`, id).Scan(&secretID)
	if err != nil {
		rollbackIfOwned(tx, owned)
		return err
	}

	_, err = tx.Exec(`DELETE FROM ssh_credentials WHERE public_id = ?`, id)
	if err != nil {
		rollbackIfOwned(tx, owned)
		return err
	}

	_, err = tx.Exec(`DELETE FROM ssh_secrets WHERE public_id = ?`, strings.TrimSpace(secretID))
	if err != nil {
		rollbackIfOwned(tx, owned)
		return err
	}

	return commitIfOwned(tx, owned)
}