233 lines
8.9 KiB
Go
233 lines
8.9 KiB
Go
package db
|
|
|
|
import "context"
|
|
import "database/sql"
|
|
import "errors"
|
|
import "strings"
|
|
import "time"
|
|
|
|
import "codit/internal/models"
|
|
import "codit/internal/util"
|
|
|
|
func scanSSHCredential(row interface{ Scan(dest ...any) error }, item *models.SSHCredential) error {
|
|
return row.Scan(&item.ID, &item.Name, &item.Description, &item.Type, &item.SecretID, &item.PublicKey, &item.Fingerprint, &item.Enabled, &item.OwnerScope, &item.OwnerUserID, &item.CreatedByKind, &item.CreatedBySubjectID, &item.CreatedBySubjectName, &item.CreatedAt, &item.UpdatedAt)
|
|
}
|
|
|
|
func normalizeSSHCredentialOwner(item models.SSHCredential) models.SSHCredential {
|
|
item.OwnerScope = strings.TrimSpace(item.OwnerScope)
|
|
item.OwnerUserID = strings.TrimSpace(item.OwnerUserID)
|
|
if item.OwnerScope == "" {
|
|
item.OwnerScope = SSHOwnerScopeAdmin
|
|
}
|
|
if item.OwnerScope != SSHOwnerScopeUser {
|
|
item.OwnerScope = SSHOwnerScopeAdmin
|
|
item.OwnerUserID = ""
|
|
}
|
|
return item
|
|
}
|
|
|
|
func (s *Store) ListSSHCredentials() ([]models.SSHCredential, error) {
|
|
var rows *sql.Rows
|
|
var items []models.SSHCredential
|
|
var item models.SSHCredential
|
|
var err error
|
|
|
|
rows, err = s.Query(`SELECT c.public_id, c.name, c.description, c.type, COALESCE(sec.public_id, ''), c.public_key, c.fingerprint, c.enabled, c.owner_scope, COALESCE(u.public_id, ''), c.created_by_kind, COALESCE(cbu.public_id, cbp.public_id, ''), c.created_by_subject_name, c.created_at, c.updated_at
|
|
FROM ssh_credentials c
|
|
LEFT JOIN ssh_secrets sec ON sec.id = c.secret_id
|
|
LEFT JOIN users u ON u.id = c.owner_user_id
|
|
LEFT JOIN users cbu ON cbu.id = c.created_by_user_id
|
|
LEFT JOIN service_principals cbp ON cbp.id = c.created_by_principal_id
|
|
WHERE c.owner_scope = ?
|
|
ORDER BY c.name`, SSHOwnerScopeAdmin)
|
|
if err != nil { return nil, err }
|
|
defer rows.Close()
|
|
for rows.Next() {
|
|
err = scanSSHCredential(rows, &item)
|
|
if err != nil { return nil, err }
|
|
items = append(items, item)
|
|
}
|
|
err = rows.Err()
|
|
if err != nil { return nil, err }
|
|
return items, nil
|
|
}
|
|
|
|
func (s *Store) ListSSHCredentialsForUser(userID string) ([]models.SSHCredential, error) {
|
|
var rows *sql.Rows
|
|
var items []models.SSHCredential
|
|
var item models.SSHCredential
|
|
var err error
|
|
|
|
rows, err = s.Query(`SELECT c.public_id, c.name, c.description, c.type, COALESCE(sec.public_id, ''), c.public_key, c.fingerprint, c.enabled, c.owner_scope, COALESCE(u.public_id, ''), c.created_by_kind, COALESCE(cbu.public_id, cbp.public_id, ''), c.created_by_subject_name, c.created_at, c.updated_at
|
|
FROM ssh_credentials c
|
|
LEFT JOIN ssh_secrets sec ON sec.id = c.secret_id
|
|
LEFT JOIN users u ON u.id = c.owner_user_id
|
|
LEFT JOIN users cbu ON cbu.id = c.created_by_user_id
|
|
LEFT JOIN service_principals cbp ON cbp.id = c.created_by_principal_id
|
|
WHERE c.owner_scope = ? AND u.public_id = ?
|
|
ORDER BY c.name`, SSHOwnerScopeUser, strings.TrimSpace(userID))
|
|
if err != nil { return nil, err }
|
|
defer rows.Close()
|
|
for rows.Next() {
|
|
err = scanSSHCredential(rows, &item)
|
|
if err != nil { return nil, err }
|
|
items = append(items, item)
|
|
}
|
|
err = rows.Err()
|
|
if err != nil { return nil, err }
|
|
return items, nil
|
|
}
|
|
|
|
func (s *Store) GetSSHCredential(id string) (models.SSHCredential, error) {
|
|
var row *sql.Row
|
|
var item models.SSHCredential
|
|
var err error
|
|
|
|
row = s.QueryRow(`SELECT c.public_id, c.name, c.description, c.type, COALESCE(sec.public_id, ''), c.public_key, c.fingerprint, c.enabled, c.owner_scope, COALESCE(u.public_id, ''), c.created_by_kind, COALESCE(cbu.public_id, cbp.public_id, ''), c.created_by_subject_name, c.created_at, c.updated_at
|
|
FROM ssh_credentials c
|
|
LEFT JOIN ssh_secrets sec ON sec.id = c.secret_id
|
|
LEFT JOIN users u ON u.id = c.owner_user_id
|
|
LEFT JOIN users cbu ON cbu.id = c.created_by_user_id
|
|
LEFT JOIN service_principals cbp ON cbp.id = c.created_by_principal_id
|
|
WHERE c.public_id = ?`, strings.TrimSpace(id))
|
|
err = scanSSHCredential(row, &item)
|
|
return item, err
|
|
}
|
|
|
|
func (s *Store) GetSSHCredentialForUser(userID string, id string) (models.SSHCredential, error) {
|
|
var row *sql.Row
|
|
var item models.SSHCredential
|
|
var err error
|
|
|
|
row = s.QueryRow(`SELECT c.public_id, c.name, c.description, c.type, COALESCE(sec.public_id, ''), c.public_key, c.fingerprint, c.enabled, c.owner_scope, COALESCE(u.public_id, ''), c.created_by_kind, COALESCE(cbu.public_id, cbp.public_id, ''), c.created_by_subject_name, c.created_at, c.updated_at
|
|
FROM ssh_credentials c
|
|
LEFT JOIN ssh_secrets sec ON sec.id = c.secret_id
|
|
LEFT JOIN users u ON u.id = c.owner_user_id
|
|
LEFT JOIN users cbu ON cbu.id = c.created_by_user_id
|
|
LEFT JOIN service_principals cbp ON cbp.id = c.created_by_principal_id
|
|
WHERE c.public_id = ? AND c.owner_scope = ? AND u.public_id = ?`, strings.TrimSpace(id), SSHOwnerScopeUser, strings.TrimSpace(userID))
|
|
err = scanSSHCredential(row, &item)
|
|
return item, err
|
|
}
|
|
|
|
func (s *Store) CreateSSHCredential(ctx context.Context, item models.SSHCredential, secret models.SSHSecret) (models.SSHCredential, error) {
|
|
var tx txExecutor
|
|
var owned bool
|
|
var err error
|
|
var now int64
|
|
|
|
if strings.TrimSpace(item.Name) == "" { return item, errors.New("name is required") }
|
|
if strings.TrimSpace(item.Type) == "" { item.Type = "private_key" }
|
|
if strings.TrimSpace(item.Type) != "private_key" { return item, errors.New("unsupported credential type") }
|
|
item = normalizeSSHCredentialOwner(item)
|
|
if item.OwnerScope == SSHOwnerScopeUser && item.OwnerUserID == "" { return item, errors.New("owner_user_id is required") }
|
|
if strings.TrimSpace(item.ID) == "" {
|
|
item.ID, err = util.NewID()
|
|
if err != nil { return item, err }
|
|
}
|
|
tx, owned, err = s.beginContext(ctx)
|
|
if err != nil { return item, err }
|
|
secret.Kind = "private_key"
|
|
secret.CreatedByKind = item.CreatedByKind
|
|
secret.CreatedBySubjectID = item.CreatedBySubjectID
|
|
secret.CreatedBySubjectName = item.CreatedBySubjectName
|
|
secret, err = createSSHSecretTx(tx, secret)
|
|
if err != nil {
|
|
rollbackIfOwned(tx, owned)
|
|
return item, err
|
|
}
|
|
now = time.Now().UTC().Unix()
|
|
item.CreatedAt = now
|
|
item.UpdatedAt = now
|
|
item.SecretID = secret.ID
|
|
_, err = tx.Exec(`INSERT INTO ssh_credentials (public_id, name, description, type, secret_id, public_key, fingerprint, enabled, owner_scope, owner_user_id, created_by_kind, created_by_user_id, created_by_principal_id, created_by_subject_name, created_at, updated_at)
|
|
VALUES (?, ?, ?, ?, (SELECT id FROM ssh_secrets WHERE public_id = ?), ?, ?, ?, ?, (SELECT id FROM users WHERE public_id = ?), ?, (SELECT id FROM users WHERE public_id = ? AND ? = 'user'), (SELECT id FROM service_principals WHERE public_id = ? AND ? = 'service_principal'), ?, ?, ?)`,
|
|
item.ID,
|
|
strings.TrimSpace(item.Name),
|
|
strings.TrimSpace(item.Description),
|
|
strings.TrimSpace(item.Type),
|
|
strings.TrimSpace(item.SecretID),
|
|
strings.TrimSpace(item.PublicKey),
|
|
strings.TrimSpace(item.Fingerprint),
|
|
item.Enabled,
|
|
strings.TrimSpace(item.OwnerScope),
|
|
strings.TrimSpace(item.OwnerUserID),
|
|
strings.TrimSpace(item.CreatedByKind),
|
|
strings.TrimSpace(item.CreatedBySubjectID), strings.TrimSpace(item.CreatedByKind),
|
|
strings.TrimSpace(item.CreatedBySubjectID), strings.TrimSpace(item.CreatedByKind),
|
|
strings.TrimSpace(item.CreatedBySubjectName),
|
|
item.CreatedAt,
|
|
item.UpdatedAt,
|
|
)
|
|
if err != nil {
|
|
rollbackIfOwned(tx, owned)
|
|
return item, err
|
|
}
|
|
err = commitIfOwned(tx, owned)
|
|
if err != nil { return item, err }
|
|
return s.GetSSHCredential(item.ID)
|
|
}
|
|
|
|
func (s *Store) UpdateSSHCredential(item models.SSHCredential) (models.SSHCredential, error) {
|
|
var err error
|
|
var now int64
|
|
|
|
if strings.TrimSpace(item.Name) == "" { return item, errors.New("name is required") }
|
|
now = time.Now().UTC().Unix()
|
|
_, err = s.Exec(`UPDATE ssh_credentials SET name = ?, description = ?, enabled = ?, updated_at = ? WHERE public_id = ?`,
|
|
strings.TrimSpace(item.Name),
|
|
strings.TrimSpace(item.Description),
|
|
item.Enabled,
|
|
now,
|
|
strings.TrimSpace(item.ID),
|
|
)
|
|
if err != nil { return item, err }
|
|
return s.GetSSHCredential(item.ID)
|
|
}
|
|
|
|
func (s *Store) DeleteSSHCredential(ctx context.Context, id string) error {
|
|
var count int
|
|
var tx txExecutor
|
|
var owned bool
|
|
var secretID string
|
|
var err error
|
|
|
|
id = strings.TrimSpace(id)
|
|
tx, owned, err = s.beginImmediateContext(ctx)
|
|
if err != nil { return err }
|
|
|
|
err = tx.QueryRow(`SELECT COUNT(*)
|
|
FROM ssh_access_profiles p
|
|
JOIN ssh_credentials c ON c.id = p.ssh_credential_id
|
|
WHERE c.public_id = ?`, id).Scan(&count)
|
|
if err != nil {
|
|
rollbackIfOwned(tx, owned)
|
|
return err
|
|
}
|
|
if count > 0 {
|
|
rollbackIfOwned(tx, owned)
|
|
return errors.New("ssh credential is used by access profiles")
|
|
}
|
|
err = tx.QueryRow(`SELECT COALESCE(sec.public_id, '')
|
|
FROM ssh_credentials c
|
|
LEFT JOIN ssh_secrets sec ON sec.id = c.secret_id
|
|
WHERE c.public_id = ?`, id).Scan(&secretID)
|
|
if err != nil {
|
|
rollbackIfOwned(tx, owned)
|
|
return err
|
|
}
|
|
_, err = tx.Exec(`DELETE FROM ssh_credentials WHERE public_id = ?`, id)
|
|
if err != nil {
|
|
rollbackIfOwned(tx, owned)
|
|
return err
|
|
}
|
|
if strings.TrimSpace(secretID) != "" {
|
|
_, err = tx.Exec(`DELETE FROM ssh_secrets WHERE public_id = ?`, strings.TrimSpace(secretID))
|
|
if err != nil {
|
|
rollbackIfOwned(tx, owned)
|
|
return err
|
|
}
|
|
}
|
|
return commitIfOwned(tx, owned)
|
|
}
|