package db

import (
	"database/sql"
	"encoding/json"
	"errors"
	"strings"
	"time"

	"codit/internal/models"
	"codit/internal/util"
)

const (
	// defaultSSHUserCAMaxUserValidSeconds is stored whenever a CA is created
	// or updated with a non-positive MaxUserValidSeconds.
	defaultSSHUserCAMaxUserValidSeconds = 1800

	// defaultSSHUserCAAlgorithm is stored when a CA is created without an
	// algorithm.
	defaultSSHUserCAAlgorithm = "ssh-ed25519"

	// sshUserCAColumns is the column list read by every ssh_user_cas query.
	// Its order must match the destinations in scanSSHUserCA.
	sshUserCAColumns = "public_id, name, algorithm, public_key, private_key_pem, fingerprint, " +
		"serial_counter, enabled, allow_user_sign, max_user_valid_seconds, created_at, updated_at"

	// sshUserCAIssuanceColumns is the column list read by every
	// ssh_user_ca_issuances query. Its order must match scanSSHUserCAIssuance.
	// COALESCE turns the NULL issuer id (written when an issuance is created
	// with an empty IssuerUserID) back into "".
	sshUserCAIssuanceColumns = "public_id, ca_public_id, COALESCE(issuer_user_public_id, ''), " +
		"issuer_username, issuer_kind, source_public_key, source_public_key_fingerprint, " +
		"certificate, key_id, principals_json, valid_after, valid_before, serial, " +
		"remote_addr, user_agent, created_at"
)

// sshUserCARowScanner is satisfied by both *sql.Row and *sql.Rows, so the same
// scan helper serves single-row lookups and list queries.
type sshUserCARowScanner interface {
	Scan(dest ...interface{}) error
}

// scanSSHUserCA reads one row selected with sshUserCAColumns.
func scanSSHUserCA(row sshUserCARowScanner) (models.SSHUserCA, error) {
	var item models.SSHUserCA
	var err error
	err = row.Scan(
		&item.ID,
		&item.Name,
		&item.Algorithm,
		&item.PublicKey,
		&item.PrivateKeyPEM,
		&item.Fingerprint,
		&item.SerialCounter,
		&item.Enabled,
		&item.AllowUserSign,
		&item.MaxUserValidSeconds,
		&item.CreatedAt,
		&item.UpdatedAt,
	)
	return item, err
}

// querySSHUserCAs runs a query selecting sshUserCAColumns and collects every
// row. It returns a nil slice when the query matches nothing.
func (s *Store) querySSHUserCAs(query string, args ...interface{}) ([]models.SSHUserCA, error) {
	var rows *sql.Rows
	var items []models.SSHUserCA
	var item models.SSHUserCA
	var err error
	rows, err = s.Query(query, args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	for rows.Next() {
		item, err = scanSSHUserCA(rows)
		if err != nil {
			return nil, err
		}
		items = append(items, item)
	}
	err = rows.Err()
	if err != nil {
		return nil, err
	}
	return items, nil
}

// ListSSHUserCAs returns every SSH user CA, ordered by name, including
// disabled ones. The result is nil when no CA exists.
func (s *Store) ListSSHUserCAs() ([]models.SSHUserCA, error) {
	return s.querySSHUserCAs("SELECT " + sshUserCAColumns + " FROM ssh_user_cas ORDER BY name")
}

// GetSSHUserCA loads a single CA by its public id (surrounding whitespace is
// ignored). It returns sql.ErrNoRows when no CA has that id.
func (s *Store) GetSSHUserCA(id string) (models.SSHUserCA, error) {
	return scanSSHUserCA(s.QueryRow(
		"SELECT "+sshUserCAColumns+" FROM ssh_user_cas WHERE public_id = ?",
		strings.TrimSpace(id),
	))
}

// CreateSSHUserCA inserts a new CA and returns it as stored.
//
// String fields are trimmed before insertion. A missing ID is generated, a
// missing Algorithm defaults to defaultSSHUserCAAlgorithm, a zero
// SerialCounter starts at 1, and a non-positive MaxUserValidSeconds becomes
// defaultSSHUserCAMaxUserValidSeconds. CreatedAt and UpdatedAt are set to the
// current Unix time.
func (s *Store) CreateSSHUserCA(item models.SSHUserCA) (models.SSHUserCA, error) {
	var err error
	var now int64
	// Normalize first so the stored row, the returned value, and later
	// lookups (which all trim the id) agree with each other.
	item.ID = strings.TrimSpace(item.ID)
	item.Name = strings.TrimSpace(item.Name)
	item.Algorithm = strings.TrimSpace(item.Algorithm)
	item.PublicKey = strings.TrimSpace(item.PublicKey)
	item.PrivateKeyPEM = strings.TrimSpace(item.PrivateKeyPEM)
	item.Fingerprint = strings.TrimSpace(item.Fingerprint)
	if item.ID == "" {
		item.ID, err = util.NewID()
		if err != nil {
			return item, err
		}
	}
	if item.Algorithm == "" {
		item.Algorithm = defaultSSHUserCAAlgorithm
	}
	if item.SerialCounter == 0 {
		item.SerialCounter = 1
	}
	if item.MaxUserValidSeconds <= 0 {
		item.MaxUserValidSeconds = defaultSSHUserCAMaxUserValidSeconds
	}
	now = time.Now().UTC().Unix()
	item.CreatedAt = now
	item.UpdatedAt = now
	_, err = s.Exec(`INSERT INTO ssh_user_cas (public_id, name, algorithm, public_key, private_key_pem, fingerprint, serial_counter, enabled, allow_user_sign, max_user_valid_seconds, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
		item.ID,
		item.Name,
		item.Algorithm,
		item.PublicKey,
		item.PrivateKeyPEM,
		item.Fingerprint,
		item.SerialCounter,
		item.Enabled,
		item.AllowUserSign,
		item.MaxUserValidSeconds,
		item.CreatedAt,
		item.UpdatedAt,
	)
	if err != nil {
		return item, err
	}
	return item, nil
}

// UpdateSSHUserCA writes the editable settings of an existing CA (name,
// enabled, allow_user_sign, max_user_valid_seconds) and returns the row as
// re-read from the database.
//
// Key material, algorithm and serial counter are not modified. A non-positive
// MaxUserValidSeconds is replaced by defaultSSHUserCAMaxUserValidSeconds.
// When no CA has item.ID, the re-read fails with sql.ErrNoRows.
func (s *Store) UpdateSSHUserCA(item models.SSHUserCA) (models.SSHUserCA, error) {
	var err error
	if item.MaxUserValidSeconds <= 0 {
		item.MaxUserValidSeconds = defaultSSHUserCAMaxUserValidSeconds
	}
	item.UpdatedAt = time.Now().UTC().Unix()
	_, err = s.Exec(`UPDATE ssh_user_cas SET name = ?, enabled = ?, allow_user_sign = ?, max_user_valid_seconds = ?, updated_at = ? WHERE public_id = ?`,
		strings.TrimSpace(item.Name),
		item.Enabled,
		item.AllowUserSign,
		item.MaxUserValidSeconds,
		item.UpdatedAt,
		strings.TrimSpace(item.ID),
	)
	if err != nil {
		return item, err
	}
	item, err = s.GetSSHUserCA(item.ID)
	if err != nil {
		return item, err
	}
	return item, nil
}

// DeleteSSHUserCA removes the CA with the given public id. Deleting an id that
// does not exist is not reported as an error.
func (s *Store) DeleteSSHUserCA(id string) error {
	var err error
	_, err = s.Exec(`DELETE FROM ssh_user_cas WHERE public_id = ?`, strings.TrimSpace(id))
	return err
}

// ListSSHUserCAsForUser returns, ordered by name, the CAs that are enabled and
// allow user self-signing. The result is nil when there are none.
//
// NOTE(review): this also selects private_key_pem; confirm callers never
// serialize the result directly to end users.
func (s *Store) ListSSHUserCAsForUser() ([]models.SSHUserCA, error) {
	return s.querySSHUserCAs("SELECT " + sshUserCAColumns + " FROM ssh_user_cas WHERE enabled = 1 AND allow_user_sign = 1 ORDER BY name")
}

// GetSSHUserCAForUser loads a CA by public id only if it is enabled and allows
// user self-signing. Otherwise it returns sql.ErrNoRows.
func (s *Store) GetSSHUserCAForUser(id string) (models.SSHUserCA, error) {
	return scanSSHUserCA(s.QueryRow(
		"SELECT "+sshUserCAColumns+" FROM ssh_user_cas WHERE public_id = ? AND enabled = 1 AND allow_user_sign = 1",
		strings.TrimSpace(id),
	))
}

// NextSSHUserCASerial reserves and returns the next certificate serial for the
// CA, advancing its stored counter by one. A non-positive stored counter is
// treated as 1. It returns sql.ErrNoRows when the CA does not exist.
//
// The increment runs before the read. The UPDATE therefore takes the row's
// write lock first, and concurrent callers serialize on it instead of reading
// the same value and handing out a duplicate serial (a lost update under
// SELECT-then-UPDATE).
func (s *Store) NextSSHUserCASerial(id string) (uint64, error) {
	var tx *sql.Tx
	var owned bool
	var counter int64
	var err error
	id = strings.TrimSpace(id)
	tx, owned, err = s.begin()
	if err != nil {
		return 0, err
	}
	_, err = tx.Exec(`UPDATE ssh_user_cas SET serial_counter = CASE WHEN serial_counter <= 0 THEN 2 ELSE serial_counter + 1 END, updated_at = ? WHERE public_id = ?`,
		time.Now().UTC().Unix(),
		id,
	)
	if err != nil {
		rollbackIfOwned(tx, owned)
		return 0, err
	}
	// A missing CA matched no row above; the read reports it as sql.ErrNoRows.
	err = tx.QueryRow(`SELECT serial_counter FROM ssh_user_cas WHERE public_id = ?`, id).Scan(&counter)
	if err != nil {
		rollbackIfOwned(tx, owned)
		return 0, err
	}
	err = commitIfOwned(tx, owned)
	if err != nil {
		return 0, err
	}
	// The counter now holds the *next* serial; the reserved one is one less.
	return uint64(counter - 1), nil
}

// CreateSSHUserCAIssuance records an issued certificate and returns it as
// stored.
//
// CAID and a non-zero Serial are required. String fields are trimmed. A
// missing ID is generated, an empty IssuerKind becomes "unknown", an empty
// IssuerUserID is stored as NULL, and a non-positive CreatedAt is set to the
// current Unix time. Principals are trimmed, empty entries are dropped, and the
// rest is stored as a JSON array.
func (s *Store) CreateSSHUserCAIssuance(item models.SSHUserCAIssuance) (models.SSHUserCAIssuance, error) {
	var err error
	var principalsJSON string
	// Normalize first so the returned value matches what is written below.
	item.ID = strings.TrimSpace(item.ID)
	item.CAID = strings.TrimSpace(item.CAID)
	item.IssuerUserID = strings.TrimSpace(item.IssuerUserID)
	item.IssuerUsername = strings.TrimSpace(item.IssuerUsername)
	item.IssuerKind = strings.TrimSpace(item.IssuerKind)
	item.SourcePublicKey = strings.TrimSpace(item.SourcePublicKey)
	item.SourcePublicKeyFingerprint = strings.TrimSpace(item.SourcePublicKeyFingerprint)
	item.Certificate = strings.TrimSpace(item.Certificate)
	item.KeyID = strings.TrimSpace(item.KeyID)
	item.RemoteAddr = strings.TrimSpace(item.RemoteAddr)
	item.UserAgent = strings.TrimSpace(item.UserAgent)
	if item.CAID == "" {
		return item, errors.New("ca id is required")
	}
	if item.Serial == 0 {
		return item, errors.New("serial is required")
	}
	if item.ID == "" {
		item.ID, err = util.NewID()
		if err != nil {
			return item, err
		}
	}
	if item.IssuerKind == "" {
		item.IssuerKind = "unknown"
	}
	principalsJSON, err = encodeSSHUserCAIssuancePrincipals(item.Principals)
	if err != nil {
		return item, err
	}
	if item.CreatedAt <= 0 {
		item.CreatedAt = time.Now().UTC().Unix()
	}
	_, err = s.Exec(`INSERT INTO ssh_user_ca_issuances (
		public_id, ca_public_id, issuer_user_public_id, issuer_username, issuer_kind,
		source_public_key, source_public_key_fingerprint, certificate, key_id, principals_json,
		valid_after, valid_before, serial, remote_addr, user_agent, created_at
	) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
		item.ID,
		item.CAID,
		nullableTrimmedString(item.IssuerUserID),
		item.IssuerUsername,
		item.IssuerKind,
		item.SourcePublicKey,
		item.SourcePublicKeyFingerprint,
		item.Certificate,
		item.KeyID,
		principalsJSON,
		item.ValidAfter,
		item.ValidBefore,
		int64(item.Serial),
		item.RemoteAddr,
		item.UserAgent,
		item.CreatedAt,
	)
	if err != nil {
		return item, err
	}
	return item, nil
}

// scanSSHUserCAIssuance reads one row selected with sshUserCAIssuanceColumns
// and decodes its principals_json column.
func scanSSHUserCAIssuance(row sshUserCARowScanner) (models.SSHUserCAIssuance, error) {
	var item models.SSHUserCAIssuance
	var principalsJSON string
	var err error
	err = row.Scan(
		&item.ID,
		&item.CAID,
		&item.IssuerUserID,
		&item.IssuerUsername,
		&item.IssuerKind,
		&item.SourcePublicKey,
		&item.SourcePublicKeyFingerprint,
		&item.Certificate,
		&item.KeyID,
		&principalsJSON,
		&item.ValidAfter,
		&item.ValidBefore,
		&item.Serial,
		&item.RemoteAddr,
		&item.UserAgent,
		&item.CreatedAt,
	)
	if err != nil {
		return item, err
	}
	item.Principals, err = decodeSSHUserCAIssuancePrincipals(principalsJSON)
	if err != nil {
		return item, err
	}
	return item, nil
}

// querySSHUserCAIssuances runs a query selecting sshUserCAIssuanceColumns and
// collects every row. It returns a nil slice when the query matches nothing.
func (s *Store) querySSHUserCAIssuances(query string, args ...interface{}) ([]models.SSHUserCAIssuance, error) {
	var rows *sql.Rows
	var items []models.SSHUserCAIssuance
	var item models.SSHUserCAIssuance
	var err error
	rows, err = s.Query(query, args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	for rows.Next() {
		item, err = scanSSHUserCAIssuance(rows)
		if err != nil {
			return nil, err
		}
		items = append(items, item)
	}
	err = rows.Err()
	if err != nil {
		return nil, err
	}
	return items, nil
}

// ListSSHUserCAIssuancesByCA returns the most recent issuances of one CA, newest
// first. The limit is clamped by normalizeSSHUserCAIssuanceLimit.
func (s *Store) ListSSHUserCAIssuancesByCA(caID string, limit int) ([]models.SSHUserCAIssuance, error) {
	return s.querySSHUserCAIssuances(
		"SELECT "+sshUserCAIssuanceColumns+" FROM ssh_user_ca_issuances WHERE ca_public_id = ? ORDER BY created_at DESC LIMIT ?",
		strings.TrimSpace(caID),
		normalizeSSHUserCAIssuanceLimit(limit),
	)
}

// ListSSHUserCAIssuances returns the most recent issuances, newest first. It
// covers all CAs when caID is blank; otherwise it behaves like
// ListSSHUserCAIssuancesByCA. The limit is clamped by
// normalizeSSHUserCAIssuanceLimit.
func (s *Store) ListSSHUserCAIssuances(limit int, caID string) ([]models.SSHUserCAIssuance, error) {
	caID = strings.TrimSpace(caID)
	if caID != "" {
		return s.ListSSHUserCAIssuancesByCA(caID, limit)
	}
	return s.querySSHUserCAIssuances(
		"SELECT "+sshUserCAIssuanceColumns+" FROM ssh_user_ca_issuances ORDER BY created_at DESC LIMIT ?",
		normalizeSSHUserCAIssuanceLimit(limit),
	)
}

// ListSSHUserCAIssuancesForSelf returns, newest first, the issuances whose
// issuer is userID and whose issuer_kind is 'self'. The limit is clamped by
// normalizeSSHUserCAIssuanceLimit.
func (s *Store) ListSSHUserCAIssuancesForSelf(userID string, limit int) ([]models.SSHUserCAIssuance, error) {
	return s.querySSHUserCAIssuances(
		"SELECT "+sshUserCAIssuanceColumns+" FROM ssh_user_ca_issuances WHERE issuer_user_public_id = ? AND issuer_kind = 'self' ORDER BY created_at DESC LIMIT ?",
		strings.TrimSpace(userID),
		normalizeSSHUserCAIssuanceLimit(limit),
	)
}

// nullableTrimmedString trims value and returns it as a sql.NullString that is
// NULL (Valid == false) when the trimmed value is empty.
func nullableTrimmedString(value string) sql.NullString {
	var trimmed string
	trimmed = strings.TrimSpace(value)
	return sql.NullString{String: trimmed, Valid: trimmed != ""}
}

// encodeSSHUserCAIssuancePrincipals trims every principal, drops empty ones and
// JSON-encodes the rest. When nothing remains, the result is "null".
func encodeSSHUserCAIssuancePrincipals(principals []string) (string, error) {
	var normalized []string
	var i int
	var value string
	var raw []byte
	var err error
	for i = 0; i < len(principals); i++ {
		value = strings.TrimSpace(principals[i])
		if value == "" {
			continue
		}
		normalized = append(normalized, value)
	}
	raw, err = json.Marshal(normalized)
	if err != nil {
		return "", err
	}
	return string(raw), nil
}

// decodeSSHUserCAIssuancePrincipals parses a JSON string array and returns its
// trimmed, non-empty entries. A blank input yields nil, and so does "null".
func decodeSSHUserCAIssuancePrincipals(raw string) ([]string, error) {
	var values []string
	var out []string
	var i int
	var value string
	var err error
	raw = strings.TrimSpace(raw)
	if raw == "" {
		return nil, nil
	}
	err = json.Unmarshal([]byte(raw), &values)
	if err != nil {
		return nil, err
	}
	for i = 0; i < len(values); i++ {
		value = strings.TrimSpace(values[i])
		if value == "" {
			continue
		}
		out = append(out, value)
	}
	return out, nil
}

// normalizeSSHUserCAIssuanceLimit maps a requested page size into [1, 500]. A
// non-positive limit becomes 50.
func normalizeSSHUserCAIssuanceLimit(limit int) int {
	if limit <= 0 {
		return 50
	}
	if limit > 500 {
		return 500
	}
	return limit
}