Files
codit/backend/internal/db/ssh_user_ca.go

551 lines
14 KiB
Go

package db
import "database/sql"
import "encoding/json"
import "errors"
import "strings"
import "time"
import "codit/internal/models"
import "codit/internal/util"
// ListSSHUserCAs returns every row of ssh_user_cas ordered by name, including
// the stored private key PEM. The result is nil when the table is empty; any
// query, scan, or iteration error aborts the listing and returns nil.
func (s *Store) ListSSHUserCAs() ([]models.SSHUserCA, error) {
	rows, err := s.Query(`SELECT public_id, name, algorithm, public_key, private_key_pem, fingerprint, serial_counter, enabled, allow_user_sign, max_user_valid_seconds, created_at, updated_at
FROM ssh_user_cas
ORDER BY name`)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var cas []models.SSHUserCA
	for rows.Next() {
		var ca models.SSHUserCA
		if scanErr := rows.Scan(&ca.ID, &ca.Name, &ca.Algorithm, &ca.PublicKey, &ca.PrivateKeyPEM, &ca.Fingerprint, &ca.SerialCounter, &ca.Enabled, &ca.AllowUserSign, &ca.MaxUserValidSeconds, &ca.CreatedAt, &ca.UpdatedAt); scanErr != nil {
			return nil, scanErr
		}
		cas = append(cas, ca)
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, iterErr
	}
	return cas, nil
}
// GetSSHUserCA loads the SSH user CA whose public_id equals id (surrounding
// whitespace ignored). When no row matches, the error is the one returned by
// (*sql.Row).Scan, i.e. sql.ErrNoRows.
func (s *Store) GetSSHUserCA(id string) (models.SSHUserCA, error) {
	var ca models.SSHUserCA
	scanErr := s.QueryRow(`SELECT public_id, name, algorithm, public_key, private_key_pem, fingerprint, serial_counter, enabled, allow_user_sign, max_user_valid_seconds, created_at, updated_at
FROM ssh_user_cas
WHERE public_id = ?`, strings.TrimSpace(id)).Scan(&ca.ID, &ca.Name, &ca.Algorithm, &ca.PublicKey, &ca.PrivateKeyPEM, &ca.Fingerprint, &ca.SerialCounter, &ca.Enabled, &ca.AllowUserSign, &ca.MaxUserValidSeconds, &ca.CreatedAt, &ca.UpdatedAt)
	return ca, scanErr
}
// CreateSSHUserCA inserts a new SSH user CA row and returns the item exactly as
// it was persisted.
//
// Before inserting, it applies these defaults:
//   - ID: generated with util.NewID when blank.
//   - Algorithm: "ssh-ed25519" when blank.
//   - SerialCounter: 1 when it is 0.
//   - MaxUserValidSeconds: 1800 when <= 0.
//   - CreatedAt and UpdatedAt: the current UTC Unix time.
//
// ID, Name, Algorithm, PublicKey, PrivateKeyPEM and Fingerprint are trimmed on
// the returned item as well as in the row. ID trimming matters because every
// lookup in this store trims the id it is given, so an untrimmed stored ID
// could never be found again.
func (s *Store) CreateSSHUserCA(item models.SSHUserCA) (models.SSHUserCA, error) {
	var err error
	var now int64
	item.ID = strings.TrimSpace(item.ID)
	if item.ID == "" {
		item.ID, err = util.NewID()
		if err != nil {
			return item, err
		}
	}
	item.Name = strings.TrimSpace(item.Name)
	item.Algorithm = strings.TrimSpace(item.Algorithm)
	if item.Algorithm == "" {
		item.Algorithm = "ssh-ed25519"
	}
	item.PublicKey = strings.TrimSpace(item.PublicKey)
	item.PrivateKeyPEM = strings.TrimSpace(item.PrivateKeyPEM)
	item.Fingerprint = strings.TrimSpace(item.Fingerprint)
	if item.SerialCounter == 0 {
		item.SerialCounter = 1
	}
	if item.MaxUserValidSeconds <= 0 {
		item.MaxUserValidSeconds = 1800
	}
	now = time.Now().UTC().Unix()
	item.CreatedAt = now
	item.UpdatedAt = now
	_, err = s.Exec(`INSERT INTO ssh_user_cas (public_id, name, algorithm, public_key, private_key_pem, fingerprint, serial_counter, enabled, allow_user_sign, max_user_valid_seconds, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
		item.ID, item.Name, item.Algorithm, item.PublicKey, item.PrivateKeyPEM, item.Fingerprint, item.SerialCounter, item.Enabled, item.AllowUserSign, item.MaxUserValidSeconds, item.CreatedAt, item.UpdatedAt)
	if err != nil {
		return item, err
	}
	return item, nil
}
// UpdateSSHUserCA writes the mutable settings of an existing CA and returns the
// row as re-read from the database.
//
// Only name, enabled, allow_user_sign, max_user_valid_seconds and updated_at
// are written. A MaxUserValidSeconds of 0 or less is replaced with 1800. The
// row is matched on the trimmed item.ID. When no row matches, the error comes
// from the follow-up GetSSHUserCA call (sql.ErrNoRows).
func (s *Store) UpdateSSHUserCA(item models.SSHUserCA) (models.SSHUserCA, error) {
	if item.MaxUserValidSeconds <= 0 {
		item.MaxUserValidSeconds = 1800
	}
	item.UpdatedAt = time.Now().UTC().Unix()
	if _, execErr := s.Exec(`UPDATE ssh_user_cas
SET name = ?, enabled = ?, allow_user_sign = ?, max_user_valid_seconds = ?, updated_at = ?
WHERE public_id = ?`, strings.TrimSpace(item.Name), item.Enabled, item.AllowUserSign, item.MaxUserValidSeconds, item.UpdatedAt, strings.TrimSpace(item.ID)); execErr != nil {
		return item, execErr
	}
	// Re-read so the caller sees every persisted column, not just the edited ones.
	return s.GetSSHUserCA(item.ID)
}
// DeleteSSHUserCA removes the CA whose public_id equals id (surrounding
// whitespace ignored). Deleting an id that does not exist is not reported as an
// error; only a failing Exec is.
func (s *Store) DeleteSSHUserCA(id string) error {
	publicID := strings.TrimSpace(id)
	if _, execErr := s.Exec(`DELETE FROM ssh_user_cas WHERE public_id = ?`, publicID); execErr != nil {
		return execErr
	}
	return nil
}
// ListSSHUserCAsForUser returns the CAs that end users may sign with: rows with
// enabled = 1 and allow_user_sign = 1, ordered by name.
//
// The rows include private_key_pem, as in ListSSHUserCAs. The result is nil
// when no CA qualifies; any query, scan, or iteration error returns nil.
func (s *Store) ListSSHUserCAsForUser() ([]models.SSHUserCA, error) {
	rows, queryErr := s.Query(`SELECT public_id, name, algorithm, public_key, private_key_pem, fingerprint, serial_counter, enabled, allow_user_sign, max_user_valid_seconds, created_at, updated_at
FROM ssh_user_cas
WHERE enabled = 1 AND allow_user_sign = 1
ORDER BY name`)
	if queryErr != nil {
		return nil, queryErr
	}
	defer rows.Close()

	var usable []models.SSHUserCA
	for rows.Next() {
		var ca models.SSHUserCA
		scanErr := rows.Scan(&ca.ID, &ca.Name, &ca.Algorithm, &ca.PublicKey, &ca.PrivateKeyPEM, &ca.Fingerprint, &ca.SerialCounter, &ca.Enabled, &ca.AllowUserSign, &ca.MaxUserValidSeconds, &ca.CreatedAt, &ca.UpdatedAt)
		if scanErr != nil {
			return nil, scanErr
		}
		usable = append(usable, ca)
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, iterErr
	}
	return usable, nil
}
// GetSSHUserCAForUser loads the CA with the given public ID (surrounding
// whitespace ignored), but only if it is enabled and allows user signing.
// A missing, disabled, or non-user-signing CA yields sql.ErrNoRows from Scan.
func (s *Store) GetSSHUserCAForUser(id string) (models.SSHUserCA, error) {
	var ca models.SSHUserCA
	scanErr := s.QueryRow(`SELECT public_id, name, algorithm, public_key, private_key_pem, fingerprint, serial_counter, enabled, allow_user_sign, max_user_valid_seconds, created_at, updated_at
FROM ssh_user_cas
WHERE public_id = ? AND enabled = 1 AND allow_user_sign = 1`, strings.TrimSpace(id)).Scan(&ca.ID, &ca.Name, &ca.Algorithm, &ca.PublicKey, &ca.PrivateKeyPEM, &ca.Fingerprint, &ca.SerialCounter, &ca.Enabled, &ca.AllowUserSign, &ca.MaxUserValidSeconds, &ca.CreatedAt, &ca.UpdatedAt)
	return ca, scanErr
}
// NextSSHUserCASerial reserves and returns the next certificate serial for the
// CA whose public_id equals id (surrounding whitespace ignored).
//
// A stored counter of 0 or less is treated as 1. After the call, the stored
// counter is one greater than the serial returned. updated_at is refreshed.
// Returns sql.ErrNoRows when the CA does not exist.
//
// The work runs in a transaction from s.begin, which is committed or rolled
// back only when this call owns it.
func (s *Store) NextSSHUserCASerial(id string) (uint64, error) {
	var tx *sql.Tx
	var owned bool
	var row *sql.Row
	var serial int64
	var err error
	id = strings.TrimSpace(id)
	tx, owned, err = s.begin()
	if err != nil {
		return 0, err
	}
	// Increment first so the transaction holds the write lock before it reads.
	// The old order (SELECT, then UPDATE) let two concurrent callers both read
	// the same counter and hand out duplicate serials. Under SQLite deferred
	// transactions it could also fail on the read-to-write lock upgrade.
	// The CASE mirrors the old "serial <= 0 means 1" rule: it stores 2,
	// so serial 1 is returned.
	_, err = tx.Exec(`UPDATE ssh_user_cas
SET serial_counter = CASE WHEN serial_counter <= 0 THEN 2 ELSE serial_counter + 1 END, updated_at = ?
WHERE public_id = ?`, time.Now().UTC().Unix(), id)
	if err != nil {
		rollbackIfOwned(tx, owned)
		return 0, err
	}
	row = tx.QueryRow(`SELECT serial_counter FROM ssh_user_cas WHERE public_id = ?`, id)
	err = row.Scan(&serial)
	if err != nil {
		// sql.ErrNoRows here means the UPDATE matched nothing: unknown CA.
		rollbackIfOwned(tx, owned)
		return 0, err
	}
	err = commitIfOwned(tx, owned)
	if err != nil {
		return 0, err
	}
	// The stored counter is now the next unused serial; the reserved one precedes it.
	return uint64(serial - 1), nil
}
// CreateSSHUserCAIssuance records one issued certificate in
// ssh_user_ca_issuances and returns the item as persisted.
//
// Validation:
//   - CAID must be non-blank.
//   - Serial must be non-zero and must fit in the int64 column.
//
// Defaults:
//   - ID: generated when blank.
//   - IssuerKind: "unknown" when blank.
//   - CreatedAt: the current UTC Unix time when it is <= 0.
//
// All string fields are trimmed on the returned item as well as in the row. A
// blank IssuerUserID is stored as SQL NULL. Principals are stored as a JSON
// array with blank entries dropped; item.Principals itself is returned as given.
func (s *Store) CreateSSHUserCAIssuance(item models.SSHUserCAIssuance) (models.SSHUserCAIssuance, error) {
	var err error
	var principalsJSON string
	item.CAID = strings.TrimSpace(item.CAID)
	if item.CAID == "" {
		return item, errors.New("ca id is required")
	}
	if item.Serial == 0 {
		return item, errors.New("serial is required")
	}
	// The serial column is written as int64; refuse values that would wrap to
	// a negative number instead of silently storing a different serial.
	if int64(item.Serial) < 0 {
		return item, errors.New("serial is out of range")
	}
	item.ID = strings.TrimSpace(item.ID)
	if item.ID == "" {
		item.ID, err = util.NewID()
		if err != nil {
			return item, err
		}
	}
	item.IssuerKind = strings.TrimSpace(item.IssuerKind)
	if item.IssuerKind == "" {
		item.IssuerKind = "unknown"
	}
	item.IssuerUserID = strings.TrimSpace(item.IssuerUserID)
	item.IssuerUsername = strings.TrimSpace(item.IssuerUsername)
	item.SourcePublicKey = strings.TrimSpace(item.SourcePublicKey)
	item.SourcePublicKeyFingerprint = strings.TrimSpace(item.SourcePublicKeyFingerprint)
	item.Certificate = strings.TrimSpace(item.Certificate)
	item.KeyID = strings.TrimSpace(item.KeyID)
	item.RemoteAddr = strings.TrimSpace(item.RemoteAddr)
	item.UserAgent = strings.TrimSpace(item.UserAgent)
	principalsJSON, err = encodeSSHUserCAIssuancePrincipals(item.Principals)
	if err != nil {
		return item, err
	}
	if item.CreatedAt <= 0 {
		item.CreatedAt = time.Now().UTC().Unix()
	}
	_, err = s.Exec(`INSERT INTO ssh_user_ca_issuances (
public_id,
ca_public_id,
issuer_user_public_id,
issuer_username,
issuer_kind,
source_public_key,
source_public_key_fingerprint,
certificate,
key_id,
principals_json,
valid_after,
valid_before,
serial,
remote_addr,
user_agent,
created_at
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
		item.ID,
		item.CAID,
		nullableTrimmedString(item.IssuerUserID),
		item.IssuerUsername,
		item.IssuerKind,
		item.SourcePublicKey,
		item.SourcePublicKeyFingerprint,
		item.Certificate,
		item.KeyID,
		principalsJSON,
		item.ValidAfter,
		item.ValidBefore,
		int64(item.Serial),
		item.RemoteAddr,
		item.UserAgent,
		item.CreatedAt,
	)
	if err != nil {
		return item, err
	}
	return item, nil
}
// ListSSHUserCAIssuancesByCA returns up to limit issuance records for the CA
// whose public ID is caID (surrounding whitespace ignored), newest first.
//
// limit is clamped by normalizeSSHUserCAIssuanceLimit. A NULL issuer user ID is
// reported as "". principals_json is decoded per row, and a decode failure
// aborts the whole listing. The result is nil when nothing matches.
func (s *Store) ListSSHUserCAIssuancesByCA(caID string, limit int) ([]models.SSHUserCAIssuance, error) {
	rows, err := s.Query(`SELECT
public_id,
ca_public_id,
COALESCE(issuer_user_public_id, ''),
issuer_username,
issuer_kind,
source_public_key,
source_public_key_fingerprint,
certificate,
key_id,
principals_json,
valid_after,
valid_before,
serial,
remote_addr,
user_agent,
created_at
FROM ssh_user_ca_issuances
WHERE ca_public_id = ?
ORDER BY created_at DESC
LIMIT ?`,
		strings.TrimSpace(caID),
		normalizeSSHUserCAIssuanceLimit(limit),
	)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var (
		issuances []models.SSHUserCAIssuance
		current   models.SSHUserCAIssuance
		encoded   string
	)
	// Scan targets are bound once; every scanned field is overwritten per row
	// before the record is copied into the result slice.
	targets := []interface{}{
		&current.ID,
		&current.CAID,
		&current.IssuerUserID,
		&current.IssuerUsername,
		&current.IssuerKind,
		&current.SourcePublicKey,
		&current.SourcePublicKeyFingerprint,
		&current.Certificate,
		&current.KeyID,
		&encoded,
		&current.ValidAfter,
		&current.ValidBefore,
		&current.Serial,
		&current.RemoteAddr,
		&current.UserAgent,
		&current.CreatedAt,
	}
	for rows.Next() {
		if err = rows.Scan(targets...); err != nil {
			return nil, err
		}
		if current.Principals, err = decodeSSHUserCAIssuancePrincipals(encoded); err != nil {
			return nil, err
		}
		issuances = append(issuances, current)
	}
	if err = rows.Err(); err != nil {
		return nil, err
	}
	return issuances, nil
}
// ListSSHUserCAIssuances returns up to limit issuance records, newest first.
// When caID is non-blank (after trimming), only records for that CA are
// returned; otherwise records for all CAs are listed.
//
// limit is clamped by normalizeSSHUserCAIssuanceLimit. A NULL issuer user ID is
// reported as "". principals_json is decoded per row, and a decode failure
// aborts the listing. The result is nil when nothing matches.
func (s *Store) ListSSHUserCAIssuances(limit int, caID string) ([]models.SSHUserCAIssuance, error) {
	const selectFrom = `SELECT
public_id,
ca_public_id,
COALESCE(issuer_user_public_id, ''),
issuer_username,
issuer_kind,
source_public_key,
source_public_key_fingerprint,
certificate,
key_id,
principals_json,
valid_after,
valid_before,
serial,
remote_addr,
user_agent,
created_at
FROM ssh_user_ca_issuances
`
	limit = normalizeSSHUserCAIssuanceLimit(limit)
	caID = strings.TrimSpace(caID)

	// Build the statement once; the optional CA filter slots in before ORDER BY.
	query := selectFrom
	args := make([]interface{}, 0, 2)
	if caID != "" {
		query += "WHERE ca_public_id = ?\n"
		args = append(args, caID)
	}
	query += "ORDER BY created_at DESC\nLIMIT ?"
	args = append(args, limit)

	rows, err := s.Query(query, args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var issuances []models.SSHUserCAIssuance
	for rows.Next() {
		var rec models.SSHUserCAIssuance
		var encoded string
		err = rows.Scan(
			&rec.ID,
			&rec.CAID,
			&rec.IssuerUserID,
			&rec.IssuerUsername,
			&rec.IssuerKind,
			&rec.SourcePublicKey,
			&rec.SourcePublicKeyFingerprint,
			&rec.Certificate,
			&rec.KeyID,
			&encoded,
			&rec.ValidAfter,
			&rec.ValidBefore,
			&rec.Serial,
			&rec.RemoteAddr,
			&rec.UserAgent,
			&rec.CreatedAt,
		)
		if err != nil {
			return nil, err
		}
		if rec.Principals, err = decodeSSHUserCAIssuancePrincipals(encoded); err != nil {
			return nil, err
		}
		issuances = append(issuances, rec)
	}
	if err = rows.Err(); err != nil {
		return nil, err
	}
	return issuances, nil
}
// ListSSHUserCAIssuancesForSelf returns up to limit issuance records, newest
// first, that the user with public ID userID (surrounding whitespace ignored)
// issued with issuer_kind = 'self'.
//
// limit is clamped by normalizeSSHUserCAIssuanceLimit. principals_json is
// decoded per row, and a decode failure aborts the listing. The result is nil
// when nothing matches.
func (s *Store) ListSSHUserCAIssuancesForSelf(userID string, limit int) ([]models.SSHUserCAIssuance, error) {
	effectiveLimit := normalizeSSHUserCAIssuanceLimit(limit)
	rows, queryErr := s.Query(`SELECT
public_id,
ca_public_id,
COALESCE(issuer_user_public_id, ''),
issuer_username,
issuer_kind,
source_public_key,
source_public_key_fingerprint,
certificate,
key_id,
principals_json,
valid_after,
valid_before,
serial,
remote_addr,
user_agent,
created_at
FROM ssh_user_ca_issuances
WHERE issuer_user_public_id = ? AND issuer_kind = 'self'
ORDER BY created_at DESC
LIMIT ?`,
		strings.TrimSpace(userID),
		effectiveLimit,
	)
	if queryErr != nil {
		return nil, queryErr
	}
	defer rows.Close()

	var mine []models.SSHUserCAIssuance
	for rows.Next() {
		var (
			entry   models.SSHUserCAIssuance
			rawJSON string
		)
		if scanErr := rows.Scan(
			&entry.ID,
			&entry.CAID,
			&entry.IssuerUserID,
			&entry.IssuerUsername,
			&entry.IssuerKind,
			&entry.SourcePublicKey,
			&entry.SourcePublicKeyFingerprint,
			&entry.Certificate,
			&entry.KeyID,
			&rawJSON,
			&entry.ValidAfter,
			&entry.ValidBefore,
			&entry.Serial,
			&entry.RemoteAddr,
			&entry.UserAgent,
			&entry.CreatedAt,
		); scanErr != nil {
			return nil, scanErr
		}
		principals, decodeErr := decodeSSHUserCAIssuancePrincipals(rawJSON)
		if decodeErr != nil {
			return nil, decodeErr
		}
		entry.Principals = principals
		mine = append(mine, entry)
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, iterErr
	}
	return mine, nil
}
// nullableTrimmedString trims surrounding whitespace from value and wraps it as
// a sql.NullString. A value that is blank after trimming becomes SQL NULL
// (Valid == false, String == "").
func nullableTrimmedString(value string) sql.NullString {
	trimmed := strings.TrimSpace(value)
	return sql.NullString{String: trimmed, Valid: trimmed != ""}
}
// encodeSSHUserCAIssuancePrincipals serializes principals to a JSON array for
// the principals_json column. Each entry is trimmed and blank entries are
// dropped.
//
// The output is always a JSON array. An empty or all-blank input encodes as
// "[]", not "null" as a nil slice would. Both forms decode to the same value.
func encodeSSHUserCAIssuancePrincipals(principals []string) (string, error) {
	var raw []byte
	var err error
	// Non-nil (and pre-sized) so json.Marshal emits [] rather than null.
	normalized := make([]string, 0, len(principals))
	for _, principal := range principals {
		principal = strings.TrimSpace(principal)
		if principal == "" {
			continue
		}
		normalized = append(normalized, principal)
	}
	raw, err = json.Marshal(normalized)
	if err != nil {
		return "", err
	}
	return string(raw), nil
}
// decodeSSHUserCAIssuancePrincipals parses a principals_json value back into a
// slice.
//
// Blank input returns (nil, nil); invalid JSON returns the json.Unmarshal
// error. Decoded entries are trimmed and blank ones dropped. When no non-blank
// principal remains (including for "[]" and "null"), the result is nil.
func decodeSSHUserCAIssuancePrincipals(raw string) ([]string, error) {
	payload := strings.TrimSpace(raw)
	if len(payload) == 0 {
		return nil, nil
	}
	var decoded []string
	if err := json.Unmarshal([]byte(payload), &decoded); err != nil {
		return nil, err
	}
	var kept []string
	for _, principal := range decoded {
		if principal = strings.TrimSpace(principal); principal != "" {
			kept = append(kept, principal)
		}
	}
	return kept, nil
}
// normalizeSSHUserCAIssuanceLimit clamps a requested page size for issuance
// listings. Values of 0 or less select the default of 50, and values above 500
// are capped at 500.
func normalizeSSHUserCAIssuanceLimit(limit int) int {
	switch {
	case limit <= 0:
		return 50
	case limit > 500:
		return 500
	default:
		return limit
	}
}