// Source: codit/backend/internal/db/ssh_principal_grant.go
// (listing metadata: 577 lines, 14 KiB, Go)

package db
import "database/sql"
import "encoding/json"
import "errors"
import "strings"
import "time"
import "codit/internal/models"
import "codit/internal/util"
// ListSSHPrincipalGrants returns SSH principal grants ordered by name (stored
// in the principal column) and then creation time. Each grant is returned with
// its decoded principal list and its target rows.
//
// When both targetType and targetID are non-blank, only grants with a target
// row matching that (type, public ID) pair are returned; otherwise every grant
// is listed. targetType is trimmed and lowercased so it matches the lowercase
// values insertSSHPrincipalGrantTargetTx writes ("user" / "group").
//
// An empty result is returned as a nil slice.
func (s *Store) ListSSHPrincipalGrants(targetType string, targetID string) ([]models.SSHPrincipalGrant, error) {
	var rows *sql.Rows
	var err error
	var items []models.SSHPrincipalGrant
	targetType = strings.ToLower(strings.TrimSpace(targetType))
	targetID = strings.TrimSpace(targetID)
	if targetType != "" && targetID != "" {
		rows, err = s.Query(`SELECT DISTINCT
g.public_id,
g.principal,
g.principals_json,
g.valid_after,
g.valid_before,
g.max_cert_valid_seconds,
g.max_uses,
g.used_count,
g.disabled,
g.reason,
COALESCE(g.created_by_user_public_id, ''),
g.created_at,
g.updated_at,
g.last_used_at
FROM ssh_principal_grants g
JOIN ssh_principal_grant_targets t ON t.grant_public_id = g.public_id
WHERE t.target_type = ? AND t.target_public_id = ?
ORDER BY g.principal, g.created_at`, targetType, targetID)
	} else {
		rows, err = s.Query(`SELECT
g.public_id,
g.principal,
g.principals_json,
g.valid_after,
g.valid_before,
g.max_cert_valid_seconds,
g.max_uses,
g.used_count,
g.disabled,
g.reason,
COALESCE(g.created_by_user_public_id, ''),
g.created_at,
g.updated_at,
g.last_used_at
FROM ssh_principal_grants g
ORDER BY g.principal, g.created_at`)
	}
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	for rows.Next() {
		var item models.SSHPrincipalGrant
		var principalsJSON string
		err = rows.Scan(
			&item.ID,
			&item.Name,
			&principalsJSON,
			&item.ValidAfter,
			&item.ValidBefore,
			&item.MaxCertValidSeconds,
			&item.MaxUses,
			&item.UsedCount,
			&item.Disabled,
			&item.Reason,
			&item.CreatedByUserID,
			&item.CreatedAt,
			&item.UpdatedAt,
			&item.LastUsedAt,
		)
		if err != nil {
			return nil, err
		}
		item.Principals, err = decodeSSHPrincipalGrantPrincipals(principalsJSON)
		if err != nil {
			return nil, err
		}
		items = append(items, item)
	}
	err = rows.Err()
	if err != nil {
		return nil, err
	}
	// Release the result set before running the per-grant target queries.
	// Issuing them while rows is still open keeps its connection busy for the
	// whole loop, so a nested query can block or fail when only one connection
	// is available. NOTE(review): confirm Store's pool settings and whether
	// Store can be bound to a single transaction.
	err = rows.Close()
	if err != nil {
		return nil, err
	}
	for i := range items {
		items[i].Targets, err = s.listSSHPrincipalGrantTargets(items[i].ID)
		if err != nil {
			return nil, err
		}
	}
	return items, nil
}
// ListActiveSSHPrincipalGrantsForUser returns the grants that currently apply
// to the user whose public ID is userID (surrounding whitespace ignored). Each
// grant is returned with its decoded principals and its target rows.
//
// A grant applies when all of these hold:
//   - it is not disabled;
//   - now lies inside its validity window, where a bound of 0 is open;
//   - its use limit is not exhausted (max_uses 0 means unlimited);
//   - it targets the user directly, or targets a non-disabled group that has
//     the user as a member.
//
// now is compared against valid_after/valid_before. It is presumably Unix
// seconds, like the timestamps this file writes; confirm with callers.
// DISTINCT collapses a grant reachable through several targets into one row.
// The result is ordered by name, then creation time. An empty result is a nil
// slice.
func (s *Store) ListActiveSSHPrincipalGrantsForUser(userID string, now int64) ([]models.SSHPrincipalGrant, error) {
	var rows *sql.Rows
	var err error
	var items []models.SSHPrincipalGrant
	userID = strings.TrimSpace(userID)
	// NOTE(review): the direct 'user' branch does not check the user's own
	// disabled flag (only group targets check ug.disabled). Confirm callers
	// reject disabled users before calling this.
	rows, err = s.Query(`SELECT DISTINCT
g.public_id,
g.principal,
g.principals_json,
g.valid_after,
g.valid_before,
g.max_cert_valid_seconds,
g.max_uses,
g.used_count,
g.disabled,
g.reason,
COALESCE(g.created_by_user_public_id, ''),
g.created_at,
g.updated_at,
g.last_used_at
FROM ssh_principal_grants g
JOIN ssh_principal_grant_targets t ON t.grant_public_id = g.public_id
LEFT JOIN user_groups ug ON t.target_type = 'group' AND ug.public_id = t.target_public_id
LEFT JOIN user_group_members gm ON t.target_type = 'group' AND gm.group_id = ug.id
LEFT JOIN users gu ON gm.user_id = gu.id
WHERE g.disabled = 0
AND (g.valid_after = 0 OR g.valid_after <= ?)
AND (g.valid_before = 0 OR g.valid_before >= ?)
AND (g.max_uses = 0 OR g.used_count < g.max_uses)
AND (
(t.target_type = 'user' AND t.target_public_id = ?)
OR
(t.target_type = 'group' AND ug.disabled = 0 AND gu.public_id = ?)
)
ORDER BY g.principal, g.created_at`,
		now,
		now,
		userID,
		userID,
	)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	for rows.Next() {
		var item models.SSHPrincipalGrant
		var principalsJSON string
		err = rows.Scan(
			&item.ID,
			&item.Name,
			&principalsJSON,
			&item.ValidAfter,
			&item.ValidBefore,
			&item.MaxCertValidSeconds,
			&item.MaxUses,
			&item.UsedCount,
			&item.Disabled,
			&item.Reason,
			&item.CreatedByUserID,
			&item.CreatedAt,
			&item.UpdatedAt,
			&item.LastUsedAt,
		)
		if err != nil {
			return nil, err
		}
		item.Principals, err = decodeSSHPrincipalGrantPrincipals(principalsJSON)
		if err != nil {
			return nil, err
		}
		items = append(items, item)
	}
	err = rows.Err()
	if err != nil {
		return nil, err
	}
	// Close the result set before the per-grant target lookups. Running them
	// while rows is open keeps its connection occupied, so a nested query can
	// block or fail when only one connection is available.
	err = rows.Close()
	if err != nil {
		return nil, err
	}
	for i := range items {
		items[i].Targets, err = s.listSSHPrincipalGrantTargets(items[i].ID)
		if err != nil {
			return nil, err
		}
	}
	return items, nil
}
// GetSSHPrincipalGrant loads the grant whose public ID equals id (surrounding
// whitespace ignored), including its decoded principals and its target rows.
//
// If no row matches, the error from Row.Scan is returned unchanged
// (sql.ErrNoRows per database/sql), so callers can test it with errors.Is.
// On any error the grant is returned as far as it was populated.
func (s *Store) GetSSHPrincipalGrant(id string) (models.SSHPrincipalGrant, error) {
	var grant models.SSHPrincipalGrant
	var rawPrincipals string

	scanErr := s.QueryRow(`SELECT
g.public_id,
g.principal,
g.principals_json,
g.valid_after,
g.valid_before,
g.max_cert_valid_seconds,
g.max_uses,
g.used_count,
g.disabled,
g.reason,
COALESCE(g.created_by_user_public_id, ''),
g.created_at,
g.updated_at,
g.last_used_at
FROM ssh_principal_grants g
WHERE g.public_id = ?`, strings.TrimSpace(id)).Scan(
		&grant.ID,
		&grant.Name,
		&rawPrincipals,
		&grant.ValidAfter,
		&grant.ValidBefore,
		&grant.MaxCertValidSeconds,
		&grant.MaxUses,
		&grant.UsedCount,
		&grant.Disabled,
		&grant.Reason,
		&grant.CreatedByUserID,
		&grant.CreatedAt,
		&grant.UpdatedAt,
		&grant.LastUsedAt,
	)
	if scanErr != nil {
		return grant, scanErr
	}

	principals, decodeErr := decodeSSHPrincipalGrantPrincipals(rawPrincipals)
	grant.Principals = principals
	if decodeErr != nil {
		return grant, decodeErr
	}

	targets, targetsErr := s.listSSHPrincipalGrantTargets(grant.ID)
	grant.Targets = targets
	if targetsErr != nil {
		return grant, targetsErr
	}
	return grant, nil
}
// CreateSSHPrincipalGrant validates a new grant, inserts it together with all
// of its targets, and then re-reads it with GetSSHPrincipalGrant so the
// returned value reflects what was stored.
//
// The writes share the transaction obtained from s.begin; it is committed or
// rolled back here only when this call owns it (commitIfOwned /
// rollbackIfOwned).
//
// Validation and defaults:
//   - a name, at least one principal and at least one target are required;
//   - a blank ID is replaced with util.NewID();
//   - CreatedAt and UpdatedAt are set to the current UTC Unix time;
//   - UsedCount and LastUsedAt start at 0;
//   - negative MaxCertValidSeconds and MaxUses are clamped to 0.
//
// Each target must name an existing user or group, otherwise
// insertSSHPrincipalGrantTargetTx returns an error and nothing is committed.
func (s *Store) CreateSSHPrincipalGrant(item models.SSHPrincipalGrant) (models.SSHPrincipalGrant, error) {
	var tx *sql.Tx
	var owned bool
	var err error
	var now int64
	var principalsJSON string
	if strings.TrimSpace(item.Name) == "" {
		return item, errors.New("name is required")
	}
	if len(item.Principals) == 0 {
		return item, errors.New("at least one principal is required")
	}
	if len(item.Targets) == 0 {
		return item, errors.New("at least one target is required")
	}
	if strings.TrimSpace(item.ID) == "" {
		item.ID, err = util.NewID()
		if err != nil {
			return item, err
		}
	}
	now = time.Now().UTC().Unix()
	item.CreatedAt = now
	item.UpdatedAt = now
	// A new grant has never been used. UsedCount is reset together with
	// LastUsedAt so the two stay consistent. A caller-supplied count,
	// especially a negative one, would otherwise stretch the MaxUses budget,
	// which the usage checks enforce as used_count < max_uses.
	item.UsedCount = 0
	item.LastUsedAt = 0
	if item.MaxCertValidSeconds < 0 {
		item.MaxCertValidSeconds = 0
	}
	if item.MaxUses < 0 {
		item.MaxUses = 0
	}
	principalsJSON, err = encodeSSHPrincipalGrantPrincipals(item.Principals)
	if err != nil {
		return item, err
	}
	tx, owned, err = s.begin()
	if err != nil {
		return item, err
	}
	_, err = tx.Exec(`INSERT INTO ssh_principal_grants (
public_id,
principal,
principals_json,
valid_after,
valid_before,
max_cert_valid_seconds,
max_uses,
used_count,
disabled,
reason,
created_by_user_public_id,
created_at,
updated_at,
last_used_at
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
		strings.TrimSpace(item.ID),
		strings.TrimSpace(item.Name),
		principalsJSON,
		item.ValidAfter,
		item.ValidBefore,
		item.MaxCertValidSeconds,
		item.MaxUses,
		item.UsedCount,
		item.Disabled,
		strings.TrimSpace(item.Reason),
		nullableTrimmedString(item.CreatedByUserID),
		item.CreatedAt,
		item.UpdatedAt,
		item.LastUsedAt,
	)
	if err != nil {
		rollbackIfOwned(tx, owned)
		return item, err
	}
	for _, target := range item.Targets {
		err = insertSSHPrincipalGrantTargetTx(tx, item.ID, target.TargetType, target.TargetID, now)
		if err != nil {
			rollbackIfOwned(tx, owned)
			return item, err
		}
	}
	err = commitIfOwned(tx, owned)
	if err != nil {
		return item, err
	}
	return s.GetSSHPrincipalGrant(item.ID)
}
// UpdateSSHPrincipalGrant replaces the editable fields of an existing grant
// and its complete target list, then re-reads it with GetSSHPrincipalGrant.
//
// The editable fields are name, principals, validity window, limits,
// disabled and reason. The writes share the transaction obtained from
// s.begin; it is committed or rolled back here only when this call owns it.
//
// An ID, a name, at least one principal and at least one target are
// required. Negative MaxCertValidSeconds and MaxUses are clamped to 0.
// used_count, last_used_at, created_at and created_by are not modified.
//
// If no grant with the given ID exists, sql.ErrNoRows is returned and
// nothing is written.
func (s *Store) UpdateSSHPrincipalGrant(item models.SSHPrincipalGrant) (models.SSHPrincipalGrant, error) {
	var tx *sql.Tx
	var owned bool
	var err error
	var now int64
	var id string
	var exists int
	var principalsJSON string
	id = strings.TrimSpace(item.ID)
	if id == "" {
		return item, errors.New("id is required")
	}
	if strings.TrimSpace(item.Name) == "" {
		return item, errors.New("name is required")
	}
	if len(item.Principals) == 0 {
		return item, errors.New("at least one principal is required")
	}
	if len(item.Targets) == 0 {
		return item, errors.New("at least one target is required")
	}
	if item.MaxCertValidSeconds < 0 {
		item.MaxCertValidSeconds = 0
	}
	if item.MaxUses < 0 {
		item.MaxUses = 0
	}
	principalsJSON, err = encodeSSHPrincipalGrantPrincipals(item.Principals)
	if err != nil {
		return item, err
	}
	now = time.Now().UTC().Unix()
	item.UpdatedAt = now
	tx, owned, err = s.begin()
	if err != nil {
		return item, err
	}
	// Confirm the grant exists before touching its targets. The UPDATE below
	// succeeds even when it matches no row, and the delete and re-insert of
	// targets would then write rows for a grant that does not exist, unless a
	// foreign key happens to reject them. An explicit SELECT is used instead
	// of RowsAffected because some drivers report changed rather than matched
	// rows.
	err = tx.QueryRow(`SELECT 1 FROM ssh_principal_grants WHERE public_id = ?`, id).Scan(&exists)
	if err != nil {
		rollbackIfOwned(tx, owned)
		return item, err
	}
	_, err = tx.Exec(`UPDATE ssh_principal_grants
SET principal = ?,
principals_json = ?,
valid_after = ?,
valid_before = ?,
max_cert_valid_seconds = ?,
max_uses = ?,
disabled = ?,
reason = ?,
updated_at = ?
WHERE public_id = ?`,
		strings.TrimSpace(item.Name),
		principalsJSON,
		item.ValidAfter,
		item.ValidBefore,
		item.MaxCertValidSeconds,
		item.MaxUses,
		item.Disabled,
		strings.TrimSpace(item.Reason),
		item.UpdatedAt,
		id,
	)
	if err != nil {
		rollbackIfOwned(tx, owned)
		return item, err
	}
	// Targets are replaced wholesale: drop the old set, then insert the new one.
	_, err = tx.Exec(`DELETE FROM ssh_principal_grant_targets WHERE grant_public_id = ?`, id)
	if err != nil {
		rollbackIfOwned(tx, owned)
		return item, err
	}
	for _, target := range item.Targets {
		err = insertSSHPrincipalGrantTargetTx(tx, id, target.TargetType, target.TargetID, now)
		if err != nil {
			rollbackIfOwned(tx, owned)
			return item, err
		}
	}
	err = commitIfOwned(tx, owned)
	if err != nil {
		return item, err
	}
	return s.GetSSHPrincipalGrant(id)
}
// DeleteSSHPrincipalGrant removes the grant with the given public ID
// (surrounding whitespace ignored) together with its target rows. Both
// deletes run in the transaction obtained from s.begin.
//
// Target rows are deleted explicitly instead of relying on a foreign-key
// cascade. NOTE(review): the schema is not visible here; if ON DELETE CASCADE
// is configured, the first DELETE is redundant but harmless. Without it,
// orphaned target rows would be attached to any later grant created with the
// same public ID, which CreateSSHPrincipalGrant lets callers supply.
//
// Deleting an ID that does not exist is not an error.
func (s *Store) DeleteSSHPrincipalGrant(id string) error {
	var tx *sql.Tx
	var owned bool
	var err error
	id = strings.TrimSpace(id)
	tx, owned, err = s.begin()
	if err != nil {
		return err
	}
	_, err = tx.Exec(`DELETE FROM ssh_principal_grant_targets WHERE grant_public_id = ?`, id)
	if err != nil {
		rollbackIfOwned(tx, owned)
		return err
	}
	_, err = tx.Exec(`DELETE FROM ssh_principal_grants WHERE public_id = ?`, id)
	if err != nil {
		rollbackIfOwned(tx, owned)
		return err
	}
	return commitIfOwned(tx, owned)
}
// MarkSSHPrincipalGrantUsed records one use of a grant with a single
// conditional UPDATE. It increments used_count and sets last_used_at and
// updated_at to the current UTC Unix time.
//
// The UPDATE only matches a grant that:
//   - has the given public ID;
//   - is enabled;
//   - is still under its use limit (max_uses 0 means unlimited);
//   - is inside its validity window (0 bounds are open).
//
// If nothing matches, whether the grant is missing or no longer active, the
// error "grant is not active" is returned.
func (s *Store) MarkSSHPrincipalGrantUsed(id string) error {
	ts := time.Now().UTC().Unix()
	res, err := s.Exec(`UPDATE ssh_principal_grants
SET used_count = used_count + 1,
last_used_at = ?,
updated_at = ?
WHERE public_id = ?
AND disabled = 0
AND (max_uses = 0 OR used_count < max_uses)
AND (valid_after = 0 OR valid_after <= ?)
AND (valid_before = 0 OR valid_before >= ?)`,
		ts,
		ts,
		strings.TrimSpace(id),
		ts,
		ts,
	)
	if err != nil {
		return err
	}
	n, err := res.RowsAffected()
	if err != nil {
		return err
	}
	if n > 0 {
		return nil
	}
	return errors.New("grant is not active")
}
// encodeSSHPrincipalGrantPrincipals serialises a principal list as a JSON
// array of strings for the principals_json column.
//
// Each entry is trimmed of surrounding whitespace. An entry that is blank
// after trimming is rejected, so a grant can never carry an empty principal.
// The result is always a JSON array: a nil or empty input encodes as "[]",
// never "null".
func encodeSSHPrincipalGrantPrincipals(items []string) (string, error) {
	cleaned := make([]string, 0, len(items))
	for _, principal := range items {
		principal = strings.TrimSpace(principal)
		if principal == "" {
			return "", errors.New("principal must not be blank")
		}
		cleaned = append(cleaned, principal)
	}
	data, err := json.Marshal(cleaned)
	if err != nil {
		return "", err
	}
	return string(data), nil
}
// decodeSSHPrincipalGrantPrincipals parses the principals_json column back
// into a slice of principals.
//
// A blank column and a JSON null both yield an empty, non-nil slice.
// Callers therefore see the same value however the row was written, and an
// encoding/json serialisation of it is [] rather than null. Malformed JSON is
// returned as an error with a nil slice.
func decodeSSHPrincipalGrantPrincipals(raw string) ([]string, error) {
	raw = strings.TrimSpace(raw)
	if raw == "" {
		return []string{}, nil
	}
	var items []string
	if err := json.Unmarshal([]byte(raw), &items); err != nil {
		return nil, err
	}
	// json.Unmarshal leaves the slice nil for the literal "null".
	if items == nil {
		return []string{}, nil
	}
	return items, nil
}
// listSSHPrincipalGrantTargets returns the target rows of one grant (grantID
// trimmed), enriched from the referenced user or group.
//
// Each row gets two derived values:
//   - a display name: the username for user targets, the group name for
//     group targets;
//   - an active flag: set when the referenced row's disabled column is 0.
//
// If the referenced user or group row is missing, the LEFT JOIN yields NULLs,
// so the name is empty and the target is reported inactive. Rows are ordered
// by target type, display name, then target public ID. An empty result is a
// nil slice.
func (s *Store) listSSHPrincipalGrantTargets(grantID string) ([]models.SSHPrincipalGrantTarget, error) {
	rows, err := s.Query(`SELECT
t.grant_public_id,
t.target_type,
t.target_public_id,
CASE
WHEN t.target_type = 'user' THEN COALESCE(u.username, '')
WHEN t.target_type = 'group' THEN COALESCE(g.name, '')
ELSE ''
END AS target_name,
CASE
WHEN t.target_type = 'user' THEN CASE WHEN u.disabled = 0 THEN 1 ELSE 0 END
WHEN t.target_type = 'group' THEN CASE WHEN g.disabled = 0 THEN 1 ELSE 0 END
ELSE 1
END AS target_active,
t.created_at
FROM ssh_principal_grant_targets t
LEFT JOIN users u ON t.target_type = 'user' AND u.public_id = t.target_public_id
LEFT JOIN user_groups g ON t.target_type = 'group' AND g.public_id = t.target_public_id
WHERE t.grant_public_id = ?
ORDER BY t.target_type, target_name, t.target_public_id`,
		strings.TrimSpace(grantID),
	)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var targets []models.SSHPrincipalGrantTarget
	for rows.Next() {
		var target models.SSHPrincipalGrantTarget
		scanErr := rows.Scan(
			&target.GrantID,
			&target.TargetType,
			&target.TargetID,
			&target.TargetName,
			&target.TargetActive,
			&target.CreatedAt,
		)
		if scanErr != nil {
			return nil, scanErr
		}
		targets = append(targets, target)
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, iterErr
	}
	return targets, nil
}
// insertSSHPrincipalGrantTargetTx links grant grantID to one user or group
// target inside tx.
//
// targetType is trimmed and matched case-insensitively, and must be "user" or
// "group". targetID must be non-blank. The row is written with
// INSERT ... SELECT from users or user_groups, so nothing is inserted when the
// referenced user or group does not exist. That case is reported as a
// "not found" error.
func insertSSHPrincipalGrantTargetTx(tx *sql.Tx, grantID string, targetType string, targetID string, createdAt int64) error {
	var query, notFound string
	switch strings.ToLower(strings.TrimSpace(targetType)) {
	case "user":
		query = `INSERT INTO ssh_principal_grant_targets (grant_public_id, target_type, target_public_id, created_at)
SELECT ?, 'user', u.public_id, ?
FROM users u
WHERE u.public_id = ?`
		notFound = "target user not found"
	case "group":
		query = `INSERT INTO ssh_principal_grant_targets (grant_public_id, target_type, target_public_id, created_at)
SELECT ?, 'group', g.public_id, ?
FROM user_groups g
WHERE g.public_id = ?`
		notFound = "target group not found"
	default:
		return errors.New("target_type must be user or group")
	}

	id := strings.TrimSpace(targetID)
	if id == "" {
		return errors.New("target id is required")
	}

	res, err := tx.Exec(query, strings.TrimSpace(grantID), createdAt, id)
	if err != nil {
		return err
	}
	affected, err := res.RowsAffected()
	if err != nil {
		return err
	}
	if affected <= 0 {
		return errors.New(notFound)
	}
	return nil
}