188 lines
6.8 KiB
Go
188 lines
6.8 KiB
Go
package db
|
|
|
|
import "database/sql"
|
|
import "errors"
|
|
import "strings"
|
|
import "time"
|
|
|
|
import "codit/internal/models"
|
|
import "codit/internal/util"
|
|
|
|
func scanSSHCredential(row interface{ Scan(dest ...any) error }, item *models.SSHCredential) error {
|
|
return row.Scan(&item.ID, &item.Name, &item.Description, &item.Type, &item.SecretID, &item.PublicKey, &item.Fingerprint, &item.Enabled, &item.OwnerScope, &item.OwnerUserID, &item.CreatedByKind, &item.CreatedBySubjectID, &item.CreatedBySubjectName, &item.CreatedAt, &item.UpdatedAt)
|
|
}
|
|
|
|
func normalizeSSHCredentialOwner(item models.SSHCredential) models.SSHCredential {
|
|
item.OwnerScope = strings.TrimSpace(item.OwnerScope)
|
|
item.OwnerUserID = strings.TrimSpace(item.OwnerUserID)
|
|
if item.OwnerScope == "" {
|
|
item.OwnerScope = "admin"
|
|
}
|
|
if item.OwnerScope != "user" {
|
|
item.OwnerScope = "admin"
|
|
item.OwnerUserID = ""
|
|
}
|
|
return item
|
|
}
|
|
|
|
func (s *Store) ListSSHCredentials() ([]models.SSHCredential, error) {
|
|
var rows *sql.Rows
|
|
var items []models.SSHCredential
|
|
var item models.SSHCredential
|
|
var err error
|
|
|
|
rows, err = s.Query(`SELECT public_id, name, description, type, secret_public_id, public_key, fingerprint, enabled, owner_scope, owner_user_public_id, created_by_kind, created_by_subject_id, created_by_subject_name, created_at, updated_at
|
|
FROM ssh_credentials
|
|
WHERE owner_scope = 'admin'
|
|
ORDER BY name`)
|
|
if err != nil { return nil, err }
|
|
defer rows.Close()
|
|
for rows.Next() {
|
|
err = scanSSHCredential(rows, &item)
|
|
if err != nil { return nil, err }
|
|
items = append(items, item)
|
|
}
|
|
err = rows.Err()
|
|
if err != nil { return nil, err }
|
|
return items, nil
|
|
}
|
|
|
|
func (s *Store) ListSSHCredentialsForUser(userID string) ([]models.SSHCredential, error) {
|
|
var rows *sql.Rows
|
|
var items []models.SSHCredential
|
|
var item models.SSHCredential
|
|
var err error
|
|
|
|
rows, err = s.Query(`SELECT public_id, name, description, type, secret_public_id, public_key, fingerprint, enabled, owner_scope, owner_user_public_id, created_by_kind, created_by_subject_id, created_by_subject_name, created_at, updated_at
|
|
FROM ssh_credentials
|
|
WHERE owner_scope = 'user' AND owner_user_public_id = ?
|
|
ORDER BY name`, strings.TrimSpace(userID))
|
|
if err != nil { return nil, err }
|
|
defer rows.Close()
|
|
for rows.Next() {
|
|
err = scanSSHCredential(rows, &item)
|
|
if err != nil { return nil, err }
|
|
items = append(items, item)
|
|
}
|
|
err = rows.Err()
|
|
if err != nil { return nil, err }
|
|
return items, nil
|
|
}
|
|
|
|
func (s *Store) GetSSHCredential(id string) (models.SSHCredential, error) {
|
|
var row *sql.Row
|
|
var item models.SSHCredential
|
|
var err error
|
|
|
|
row = s.QueryRow(`SELECT public_id, name, description, type, secret_public_id, public_key, fingerprint, enabled, owner_scope, owner_user_public_id, created_by_kind, created_by_subject_id, created_by_subject_name, created_at, updated_at
|
|
FROM ssh_credentials
|
|
WHERE public_id = ?`, strings.TrimSpace(id))
|
|
err = scanSSHCredential(row, &item)
|
|
return item, err
|
|
}
|
|
|
|
func (s *Store) GetSSHCredentialForUser(userID string, id string) (models.SSHCredential, error) {
|
|
var row *sql.Row
|
|
var item models.SSHCredential
|
|
var err error
|
|
|
|
row = s.QueryRow(`SELECT public_id, name, description, type, secret_public_id, public_key, fingerprint, enabled, owner_scope, owner_user_public_id, created_by_kind, created_by_subject_id, created_by_subject_name, created_at, updated_at
|
|
FROM ssh_credentials
|
|
WHERE public_id = ? AND owner_scope = 'user' AND owner_user_public_id = ?`, strings.TrimSpace(id), strings.TrimSpace(userID))
|
|
err = scanSSHCredential(row, &item)
|
|
return item, err
|
|
}
|
|
|
|
func (s *Store) CreateSSHCredential(item models.SSHCredential, secret models.SSHSecret) (models.SSHCredential, error) {
|
|
var tx *sql.Tx
|
|
var owned bool
|
|
var err error
|
|
var now int64
|
|
|
|
if strings.TrimSpace(item.Name) == "" { return item, errors.New("name is required") }
|
|
if strings.TrimSpace(item.Type) == "" { item.Type = "private_key" }
|
|
if strings.TrimSpace(item.Type) != "private_key" { return item, errors.New("unsupported credential type") }
|
|
item = normalizeSSHCredentialOwner(item)
|
|
if item.OwnerScope == "user" && item.OwnerUserID == "" { return item, errors.New("owner_user_id is required") }
|
|
if strings.TrimSpace(item.ID) == "" {
|
|
item.ID, err = util.NewID()
|
|
if err != nil { return item, err }
|
|
}
|
|
tx, owned, err = s.begin()
|
|
if err != nil { return item, err }
|
|
secret.Kind = "private_key"
|
|
secret.CreatedByKind = item.CreatedByKind
|
|
secret.CreatedBySubjectID = item.CreatedBySubjectID
|
|
secret.CreatedBySubjectName = item.CreatedBySubjectName
|
|
secret, err = createSSHSecretTx(tx, secret)
|
|
if err != nil {
|
|
rollbackIfOwned(tx, owned)
|
|
return item, err
|
|
}
|
|
now = time.Now().UTC().Unix()
|
|
item.CreatedAt = now
|
|
item.UpdatedAt = now
|
|
item.SecretID = secret.ID
|
|
_, err = tx.Exec(`INSERT INTO ssh_credentials (public_id, name, description, type, secret_public_id, public_key, fingerprint, enabled, owner_scope, owner_user_public_id, created_by_kind, created_by_subject_id, created_by_subject_name, created_at, updated_at)
|
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
|
item.ID,
|
|
strings.TrimSpace(item.Name),
|
|
strings.TrimSpace(item.Description),
|
|
strings.TrimSpace(item.Type),
|
|
strings.TrimSpace(item.SecretID),
|
|
strings.TrimSpace(item.PublicKey),
|
|
strings.TrimSpace(item.Fingerprint),
|
|
item.Enabled,
|
|
strings.TrimSpace(item.OwnerScope),
|
|
strings.TrimSpace(item.OwnerUserID),
|
|
strings.TrimSpace(item.CreatedByKind),
|
|
strings.TrimSpace(item.CreatedBySubjectID),
|
|
strings.TrimSpace(item.CreatedBySubjectName),
|
|
item.CreatedAt,
|
|
item.UpdatedAt,
|
|
)
|
|
if err != nil {
|
|
rollbackIfOwned(tx, owned)
|
|
return item, err
|
|
}
|
|
err = commitIfOwned(tx, owned)
|
|
if err != nil { return item, err }
|
|
return s.GetSSHCredential(item.ID)
|
|
}
|
|
|
|
func (s *Store) UpdateSSHCredential(item models.SSHCredential) (models.SSHCredential, error) {
|
|
var err error
|
|
var now int64
|
|
|
|
if strings.TrimSpace(item.Name) == "" { return item, errors.New("name is required") }
|
|
now = time.Now().UTC().Unix()
|
|
_, err = s.Exec(`UPDATE ssh_credentials SET name = ?, description = ?, enabled = ?, updated_at = ? WHERE public_id = ?`,
|
|
strings.TrimSpace(item.Name),
|
|
strings.TrimSpace(item.Description),
|
|
item.Enabled,
|
|
now,
|
|
strings.TrimSpace(item.ID),
|
|
)
|
|
if err != nil { return item, err }
|
|
return s.GetSSHCredential(item.ID)
|
|
}
|
|
|
|
func (s *Store) DeleteSSHCredential(id string) error {
|
|
var row *sql.Row
|
|
var count int
|
|
var item models.SSHCredential
|
|
var err error
|
|
|
|
id = strings.TrimSpace(id)
|
|
row = s.QueryRow(`SELECT COUNT(*) FROM ssh_access_profiles WHERE ssh_credential_public_id = ?`, id)
|
|
err = row.Scan(&count)
|
|
if err != nil { return err }
|
|
if count > 0 { return errors.New("ssh credential is used by access profiles") }
|
|
item, err = s.GetSSHCredential(id)
|
|
if err != nil { return err }
|
|
_, err = s.Exec(`DELETE FROM ssh_credentials WHERE public_id = ?`, id)
|
|
if err != nil { return err }
|
|
_, err = s.Exec(`DELETE FROM ssh_secrets WHERE public_id = ?`, strings.TrimSpace(item.SecretID))
|
|
return err
|
|
}
|