551 lines
14 KiB
Go
551 lines
14 KiB
Go
package db
|
|
|
|
import "database/sql"
|
|
import "encoding/json"
|
|
import "errors"
|
|
import "strings"
|
|
import "time"
|
|
|
|
import "codit/internal/models"
|
|
import "codit/internal/util"
|
|
|
|
func (s *Store) ListSSHUserCAs() ([]models.SSHUserCA, error) {
|
|
var rows *sql.Rows
|
|
var err error
|
|
var items []models.SSHUserCA
|
|
var item models.SSHUserCA
|
|
rows, err = s.Query(`SELECT public_id, name, algorithm, public_key, private_key_pem, fingerprint, serial_counter, enabled, allow_user_sign, max_user_valid_seconds, created_at, updated_at
|
|
FROM ssh_user_cas
|
|
ORDER BY name`)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
for rows.Next() {
|
|
err = rows.Scan(&item.ID, &item.Name, &item.Algorithm, &item.PublicKey, &item.PrivateKeyPEM, &item.Fingerprint, &item.SerialCounter, &item.Enabled, &item.AllowUserSign, &item.MaxUserValidSeconds, &item.CreatedAt, &item.UpdatedAt)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
items = append(items, item)
|
|
}
|
|
err = rows.Err()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return items, nil
|
|
}
|
|
|
|
func (s *Store) GetSSHUserCA(id string) (models.SSHUserCA, error) {
|
|
var row *sql.Row
|
|
var item models.SSHUserCA
|
|
var err error
|
|
row = s.QueryRow(`SELECT public_id, name, algorithm, public_key, private_key_pem, fingerprint, serial_counter, enabled, allow_user_sign, max_user_valid_seconds, created_at, updated_at
|
|
FROM ssh_user_cas
|
|
WHERE public_id = ?`, strings.TrimSpace(id))
|
|
err = row.Scan(&item.ID, &item.Name, &item.Algorithm, &item.PublicKey, &item.PrivateKeyPEM, &item.Fingerprint, &item.SerialCounter, &item.Enabled, &item.AllowUserSign, &item.MaxUserValidSeconds, &item.CreatedAt, &item.UpdatedAt)
|
|
if err != nil {
|
|
return item, err
|
|
}
|
|
return item, nil
|
|
}
|
|
|
|
func (s *Store) CreateSSHUserCA(item models.SSHUserCA) (models.SSHUserCA, error) {
|
|
var err error
|
|
var now int64
|
|
if strings.TrimSpace(item.ID) == "" {
|
|
item.ID, err = util.NewID()
|
|
if err != nil {
|
|
return item, err
|
|
}
|
|
}
|
|
if strings.TrimSpace(item.Algorithm) == "" {
|
|
item.Algorithm = "ssh-ed25519"
|
|
}
|
|
if item.SerialCounter == 0 {
|
|
item.SerialCounter = 1
|
|
}
|
|
if item.MaxUserValidSeconds <= 0 {
|
|
item.MaxUserValidSeconds = 1800
|
|
}
|
|
now = time.Now().UTC().Unix()
|
|
item.CreatedAt = now
|
|
item.UpdatedAt = now
|
|
_, err = s.Exec(`INSERT INTO ssh_user_cas (public_id, name, algorithm, public_key, private_key_pem, fingerprint, serial_counter, enabled, allow_user_sign, max_user_valid_seconds, created_at, updated_at)
|
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
|
item.ID, strings.TrimSpace(item.Name), item.Algorithm, strings.TrimSpace(item.PublicKey), strings.TrimSpace(item.PrivateKeyPEM), strings.TrimSpace(item.Fingerprint), item.SerialCounter, item.Enabled, item.AllowUserSign, item.MaxUserValidSeconds, item.CreatedAt, item.UpdatedAt)
|
|
if err != nil {
|
|
return item, err
|
|
}
|
|
return item, nil
|
|
}
|
|
|
|
func (s *Store) UpdateSSHUserCA(item models.SSHUserCA) (models.SSHUserCA, error) {
|
|
var err error
|
|
if item.MaxUserValidSeconds <= 0 {
|
|
item.MaxUserValidSeconds = 1800
|
|
}
|
|
item.UpdatedAt = time.Now().UTC().Unix()
|
|
_, err = s.Exec(`UPDATE ssh_user_cas
|
|
SET name = ?, enabled = ?, allow_user_sign = ?, max_user_valid_seconds = ?, updated_at = ?
|
|
WHERE public_id = ?`, strings.TrimSpace(item.Name), item.Enabled, item.AllowUserSign, item.MaxUserValidSeconds, item.UpdatedAt, strings.TrimSpace(item.ID))
|
|
if err != nil {
|
|
return item, err
|
|
}
|
|
item, err = s.GetSSHUserCA(item.ID)
|
|
if err != nil {
|
|
return item, err
|
|
}
|
|
return item, nil
|
|
}
|
|
|
|
func (s *Store) DeleteSSHUserCA(id string) error {
|
|
var err error
|
|
_, err = s.Exec(`DELETE FROM ssh_user_cas WHERE public_id = ?`, strings.TrimSpace(id))
|
|
return err
|
|
}
|
|
|
|
func (s *Store) ListSSHUserCAsForUser() ([]models.SSHUserCA, error) {
|
|
var rows *sql.Rows
|
|
var err error
|
|
var items []models.SSHUserCA
|
|
var item models.SSHUserCA
|
|
rows, err = s.Query(`SELECT public_id, name, algorithm, public_key, private_key_pem, fingerprint, serial_counter, enabled, allow_user_sign, max_user_valid_seconds, created_at, updated_at
|
|
FROM ssh_user_cas
|
|
WHERE enabled = 1 AND allow_user_sign = 1
|
|
ORDER BY name`)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
for rows.Next() {
|
|
err = rows.Scan(&item.ID, &item.Name, &item.Algorithm, &item.PublicKey, &item.PrivateKeyPEM, &item.Fingerprint, &item.SerialCounter, &item.Enabled, &item.AllowUserSign, &item.MaxUserValidSeconds, &item.CreatedAt, &item.UpdatedAt)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
items = append(items, item)
|
|
}
|
|
err = rows.Err()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return items, nil
|
|
}
|
|
|
|
func (s *Store) GetSSHUserCAForUser(id string) (models.SSHUserCA, error) {
|
|
var row *sql.Row
|
|
var item models.SSHUserCA
|
|
var err error
|
|
row = s.QueryRow(`SELECT public_id, name, algorithm, public_key, private_key_pem, fingerprint, serial_counter, enabled, allow_user_sign, max_user_valid_seconds, created_at, updated_at
|
|
FROM ssh_user_cas
|
|
WHERE public_id = ? AND enabled = 1 AND allow_user_sign = 1`, strings.TrimSpace(id))
|
|
err = row.Scan(&item.ID, &item.Name, &item.Algorithm, &item.PublicKey, &item.PrivateKeyPEM, &item.Fingerprint, &item.SerialCounter, &item.Enabled, &item.AllowUserSign, &item.MaxUserValidSeconds, &item.CreatedAt, &item.UpdatedAt)
|
|
if err != nil {
|
|
return item, err
|
|
}
|
|
return item, nil
|
|
}
|
|
|
|
func (s *Store) NextSSHUserCASerial(id string) (uint64, error) {
|
|
var tx *sql.Tx
|
|
var owned bool
|
|
var row *sql.Row
|
|
var serial int64
|
|
var err error
|
|
tx, owned, err = s.begin()
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
row = tx.QueryRow(`SELECT serial_counter FROM ssh_user_cas WHERE public_id = ?`, strings.TrimSpace(id))
|
|
err = row.Scan(&serial)
|
|
if err != nil {
|
|
rollbackIfOwned(tx, owned)
|
|
return 0, err
|
|
}
|
|
if serial <= 0 {
|
|
serial = 1
|
|
}
|
|
_, err = tx.Exec(`UPDATE ssh_user_cas SET serial_counter = ?, updated_at = ? WHERE public_id = ?`, serial+1, time.Now().UTC().Unix(), strings.TrimSpace(id))
|
|
if err != nil {
|
|
rollbackIfOwned(tx, owned)
|
|
return 0, err
|
|
}
|
|
err = commitIfOwned(tx, owned)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
return uint64(serial), nil
|
|
}
|
|
|
|
func (s *Store) CreateSSHUserCAIssuance(item models.SSHUserCAIssuance) (models.SSHUserCAIssuance, error) {
|
|
var err error
|
|
var now int64
|
|
var principalsJSON string
|
|
if strings.TrimSpace(item.CAID) == "" {
|
|
return item, errors.New("ca id is required")
|
|
}
|
|
if item.Serial == 0 {
|
|
return item, errors.New("serial is required")
|
|
}
|
|
if strings.TrimSpace(item.ID) == "" {
|
|
item.ID, err = util.NewID()
|
|
if err != nil {
|
|
return item, err
|
|
}
|
|
}
|
|
if strings.TrimSpace(item.IssuerKind) == "" {
|
|
item.IssuerKind = "unknown"
|
|
}
|
|
principalsJSON, err = encodeSSHUserCAIssuancePrincipals(item.Principals)
|
|
if err != nil {
|
|
return item, err
|
|
}
|
|
now = time.Now().UTC().Unix()
|
|
if item.CreatedAt <= 0 {
|
|
item.CreatedAt = now
|
|
}
|
|
_, err = s.Exec(`INSERT INTO ssh_user_ca_issuances (
|
|
public_id,
|
|
ca_public_id,
|
|
issuer_user_public_id,
|
|
issuer_username,
|
|
issuer_kind,
|
|
source_public_key,
|
|
source_public_key_fingerprint,
|
|
certificate,
|
|
key_id,
|
|
principals_json,
|
|
valid_after,
|
|
valid_before,
|
|
serial,
|
|
remote_addr,
|
|
user_agent,
|
|
created_at
|
|
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
|
strings.TrimSpace(item.ID),
|
|
strings.TrimSpace(item.CAID),
|
|
nullableTrimmedString(item.IssuerUserID),
|
|
strings.TrimSpace(item.IssuerUsername),
|
|
strings.TrimSpace(item.IssuerKind),
|
|
strings.TrimSpace(item.SourcePublicKey),
|
|
strings.TrimSpace(item.SourcePublicKeyFingerprint),
|
|
strings.TrimSpace(item.Certificate),
|
|
strings.TrimSpace(item.KeyID),
|
|
principalsJSON,
|
|
item.ValidAfter,
|
|
item.ValidBefore,
|
|
int64(item.Serial),
|
|
strings.TrimSpace(item.RemoteAddr),
|
|
strings.TrimSpace(item.UserAgent),
|
|
item.CreatedAt,
|
|
)
|
|
if err != nil {
|
|
return item, err
|
|
}
|
|
return item, nil
|
|
}
|
|
|
|
func (s *Store) ListSSHUserCAIssuancesByCA(caID string, limit int) ([]models.SSHUserCAIssuance, error) {
|
|
var rows *sql.Rows
|
|
var err error
|
|
var items []models.SSHUserCAIssuance
|
|
var item models.SSHUserCAIssuance
|
|
var principalsJSON string
|
|
limit = normalizeSSHUserCAIssuanceLimit(limit)
|
|
rows, err = s.Query(`SELECT
|
|
public_id,
|
|
ca_public_id,
|
|
COALESCE(issuer_user_public_id, ''),
|
|
issuer_username,
|
|
issuer_kind,
|
|
source_public_key,
|
|
source_public_key_fingerprint,
|
|
certificate,
|
|
key_id,
|
|
principals_json,
|
|
valid_after,
|
|
valid_before,
|
|
serial,
|
|
remote_addr,
|
|
user_agent,
|
|
created_at
|
|
FROM ssh_user_ca_issuances
|
|
WHERE ca_public_id = ?
|
|
ORDER BY created_at DESC
|
|
LIMIT ?`,
|
|
strings.TrimSpace(caID),
|
|
limit,
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
for rows.Next() {
|
|
err = rows.Scan(
|
|
&item.ID,
|
|
&item.CAID,
|
|
&item.IssuerUserID,
|
|
&item.IssuerUsername,
|
|
&item.IssuerKind,
|
|
&item.SourcePublicKey,
|
|
&item.SourcePublicKeyFingerprint,
|
|
&item.Certificate,
|
|
&item.KeyID,
|
|
&principalsJSON,
|
|
&item.ValidAfter,
|
|
&item.ValidBefore,
|
|
&item.Serial,
|
|
&item.RemoteAddr,
|
|
&item.UserAgent,
|
|
&item.CreatedAt,
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
item.Principals, err = decodeSSHUserCAIssuancePrincipals(principalsJSON)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
items = append(items, item)
|
|
}
|
|
err = rows.Err()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return items, nil
|
|
}
|
|
|
|
func (s *Store) ListSSHUserCAIssuances(limit int, caID string) ([]models.SSHUserCAIssuance, error) {
|
|
var rows *sql.Rows
|
|
var err error
|
|
var items []models.SSHUserCAIssuance
|
|
var item models.SSHUserCAIssuance
|
|
var principalsJSON string
|
|
limit = normalizeSSHUserCAIssuanceLimit(limit)
|
|
caID = strings.TrimSpace(caID)
|
|
if caID == "" {
|
|
rows, err = s.Query(`SELECT
|
|
public_id,
|
|
ca_public_id,
|
|
COALESCE(issuer_user_public_id, ''),
|
|
issuer_username,
|
|
issuer_kind,
|
|
source_public_key,
|
|
source_public_key_fingerprint,
|
|
certificate,
|
|
key_id,
|
|
principals_json,
|
|
valid_after,
|
|
valid_before,
|
|
serial,
|
|
remote_addr,
|
|
user_agent,
|
|
created_at
|
|
FROM ssh_user_ca_issuances
|
|
ORDER BY created_at DESC
|
|
LIMIT ?`,
|
|
limit,
|
|
)
|
|
} else {
|
|
rows, err = s.Query(`SELECT
|
|
public_id,
|
|
ca_public_id,
|
|
COALESCE(issuer_user_public_id, ''),
|
|
issuer_username,
|
|
issuer_kind,
|
|
source_public_key,
|
|
source_public_key_fingerprint,
|
|
certificate,
|
|
key_id,
|
|
principals_json,
|
|
valid_after,
|
|
valid_before,
|
|
serial,
|
|
remote_addr,
|
|
user_agent,
|
|
created_at
|
|
FROM ssh_user_ca_issuances
|
|
WHERE ca_public_id = ?
|
|
ORDER BY created_at DESC
|
|
LIMIT ?`,
|
|
caID,
|
|
limit,
|
|
)
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
for rows.Next() {
|
|
err = rows.Scan(
|
|
&item.ID,
|
|
&item.CAID,
|
|
&item.IssuerUserID,
|
|
&item.IssuerUsername,
|
|
&item.IssuerKind,
|
|
&item.SourcePublicKey,
|
|
&item.SourcePublicKeyFingerprint,
|
|
&item.Certificate,
|
|
&item.KeyID,
|
|
&principalsJSON,
|
|
&item.ValidAfter,
|
|
&item.ValidBefore,
|
|
&item.Serial,
|
|
&item.RemoteAddr,
|
|
&item.UserAgent,
|
|
&item.CreatedAt,
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
item.Principals, err = decodeSSHUserCAIssuancePrincipals(principalsJSON)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
items = append(items, item)
|
|
}
|
|
err = rows.Err()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return items, nil
|
|
}
|
|
|
|
func (s *Store) ListSSHUserCAIssuancesForSelf(userID string, limit int) ([]models.SSHUserCAIssuance, error) {
|
|
var rows *sql.Rows
|
|
var err error
|
|
var items []models.SSHUserCAIssuance
|
|
var item models.SSHUserCAIssuance
|
|
var principalsJSON string
|
|
limit = normalizeSSHUserCAIssuanceLimit(limit)
|
|
rows, err = s.Query(`SELECT
|
|
public_id,
|
|
ca_public_id,
|
|
COALESCE(issuer_user_public_id, ''),
|
|
issuer_username,
|
|
issuer_kind,
|
|
source_public_key,
|
|
source_public_key_fingerprint,
|
|
certificate,
|
|
key_id,
|
|
principals_json,
|
|
valid_after,
|
|
valid_before,
|
|
serial,
|
|
remote_addr,
|
|
user_agent,
|
|
created_at
|
|
FROM ssh_user_ca_issuances
|
|
WHERE issuer_user_public_id = ? AND issuer_kind = 'self'
|
|
ORDER BY created_at DESC
|
|
LIMIT ?`,
|
|
strings.TrimSpace(userID),
|
|
limit,
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
for rows.Next() {
|
|
err = rows.Scan(
|
|
&item.ID,
|
|
&item.CAID,
|
|
&item.IssuerUserID,
|
|
&item.IssuerUsername,
|
|
&item.IssuerKind,
|
|
&item.SourcePublicKey,
|
|
&item.SourcePublicKeyFingerprint,
|
|
&item.Certificate,
|
|
&item.KeyID,
|
|
&principalsJSON,
|
|
&item.ValidAfter,
|
|
&item.ValidBefore,
|
|
&item.Serial,
|
|
&item.RemoteAddr,
|
|
&item.UserAgent,
|
|
&item.CreatedAt,
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
item.Principals, err = decodeSSHUserCAIssuancePrincipals(principalsJSON)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
items = append(items, item)
|
|
}
|
|
err = rows.Err()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return items, nil
|
|
}
|
|
|
|
// nullableTrimmedString trims value and wraps it for a nullable column. A
// result that is empty after trimming becomes SQL NULL (Valid == false,
// String == ""). Otherwise the trimmed text is returned as a valid value.
func nullableTrimmedString(value string) sql.NullString {
	trimmed := strings.TrimSpace(value)
	return sql.NullString{
		String: trimmed,
		Valid:  trimmed != "",
	}
}
|
|
|
|
// encodeSSHUserCAIssuancePrincipals serializes principals as a JSON array for
// the principals_json column. Each entry is whitespace-trimmed, and entries
// that end up empty are dropped.
//
// The output is always a JSON array. A nil or all-blank input encodes as "[]"
// rather than "null", because json.Marshal renders a nil slice as null.
// decodeSSHUserCAIssuancePrincipals accepts both forms.
func encodeSSHUserCAIssuancePrincipals(principals []string) (string, error) {
	// Non-nil even when empty, so Marshal produces "[]" instead of "null".
	normalized := make([]string, 0, len(principals))
	for _, principal := range principals {
		principal = strings.TrimSpace(principal)
		if principal == "" {
			continue
		}
		normalized = append(normalized, principal)
	}

	raw, err := json.Marshal(normalized)
	if err != nil {
		return "", err
	}
	return string(raw), nil
}
|
|
|
|
// decodeSSHUserCAIssuancePrincipals parses a principals_json value back into a
// slice. Blank input yields (nil, nil). Malformed JSON, or JSON that is not an
// array of strings, returns the json.Unmarshal error. Entries are trimmed and
// blanks dropped. If nothing remains (including for "[]" and "null"), the
// result is nil.
func decodeSSHUserCAIssuancePrincipals(raw string) ([]string, error) {
	payload := strings.TrimSpace(raw)
	if payload == "" {
		return nil, nil
	}

	var decoded []string
	if err := json.Unmarshal([]byte(payload), &decoded); err != nil {
		return nil, err
	}

	var principals []string
	for _, entry := range decoded {
		if entry = strings.TrimSpace(entry); entry != "" {
			principals = append(principals, entry)
		}
	}
	return principals, nil
}
|
|
|
|
// normalizeSSHUserCAIssuanceLimit clamps a requested page size for the
// issuance listings. Non-positive values fall back to 50, values above 500 are
// capped at 500, and anything in between is returned unchanged.
func normalizeSSHUserCAIssuanceLimit(limit int) int {
	switch {
	case limit <= 0:
		return 50
	case limit > 500:
		return 500
	default:
		return limit
	}
}
|