package db

import (
	"database/sql"
	"encoding/json"
	"errors"
	"strings"
	"time"

	"codit/internal/models"
	"codit/internal/util"
)

// sshPrincipalGrantColumns is the column list shared by every grant query.
// Its order must match the destinations in scanSSHPrincipalGrant.
const sshPrincipalGrantColumns = `g.public_id, g.principal, g.principals_json, g.valid_after, g.valid_before,
	g.max_cert_valid_seconds, g.max_uses, g.used_count, g.disabled, g.reason,
	COALESCE(g.created_by_user_public_id, ''), g.created_at, g.updated_at, g.last_used_at`

// sshPrincipalGrantScanner is the subset of *sql.Row / *sql.Rows needed to
// read one grant row.
type sshPrincipalGrantScanner interface {
	Scan(dest ...interface{}) error
}

// ListSSHPrincipalGrants returns grants ordered by principal and creation time.
// When both targetType and targetID are non-blank, only grants with a matching
// target are returned; otherwise every grant is returned. targetType is
// normalised (trimmed, lower-cased) the same way targets are stored.
func (s *Store) ListSSHPrincipalGrants(targetType string, targetID string) ([]models.SSHPrincipalGrant, error) {
	const allQuery = `SELECT ` + sshPrincipalGrantColumns + `
		FROM ssh_principal_grants g
		ORDER BY g.principal, g.created_at`
	const byTargetQuery = `SELECT DISTINCT ` + sshPrincipalGrantColumns + `
		FROM ssh_principal_grants g
		JOIN ssh_principal_grant_targets t ON t.grant_public_id = g.public_id
		WHERE t.target_type = ? AND t.target_public_id = ?
		ORDER BY g.principal, g.created_at`

	targetType = strings.ToLower(strings.TrimSpace(targetType))
	targetID = strings.TrimSpace(targetID)
	if targetType == "" || targetID == "" {
		return s.querySSHPrincipalGrants(allQuery)
	}
	return s.querySSHPrincipalGrants(byTargetQuery, targetType, targetID)
}

// ListActiveSSHPrincipalGrantsForUser returns the grants usable by userID at
// time now: not disabled, inside the optional validity window (0 = unbounded),
// below the optional use limit (0 = unlimited), and targeting the user either
// directly or through membership in a non-disabled group.
func (s *Store) ListActiveSSHPrincipalGrantsForUser(userID string, now int64) ([]models.SSHPrincipalGrant, error) {
	const query = `SELECT DISTINCT ` + sshPrincipalGrantColumns + `
		FROM ssh_principal_grants g
		JOIN ssh_principal_grant_targets t ON t.grant_public_id = g.public_id
		LEFT JOIN user_groups ug ON t.target_type = 'group' AND ug.public_id = t.target_public_id
		LEFT JOIN user_group_members gm ON t.target_type = 'group' AND gm.group_id = ug.id
		LEFT JOIN users gu ON gm.user_id = gu.id
		WHERE g.disabled = 0
			AND (g.valid_after = 0 OR g.valid_after <= ?)
			AND (g.valid_before = 0 OR g.valid_before >= ?)
			AND (g.max_uses = 0 OR g.used_count < g.max_uses)
			AND (
				(t.target_type = 'user' AND t.target_public_id = ?)
				OR (t.target_type = 'group' AND ug.disabled = 0 AND gu.public_id = ?)
			)
		ORDER BY g.principal, g.created_at`

	userID = strings.TrimSpace(userID)
	return s.querySSHPrincipalGrants(query, now, now, userID, userID)
}

// GetSSHPrincipalGrant loads one grant, with its targets, by public ID.
// It returns sql.ErrNoRows when no grant has that ID.
func (s *Store) GetSSHPrincipalGrant(id string) (models.SSHPrincipalGrant, error) {
	const query = `SELECT ` + sshPrincipalGrantColumns + `
		FROM ssh_principal_grants g
		WHERE g.public_id = ?`

	item, err := scanSSHPrincipalGrant(s.QueryRow(query, strings.TrimSpace(id)))
	if err != nil {
		return item, err
	}
	item.Targets, err = s.listSSHPrincipalGrantTargets(item.ID)
	if err != nil {
		return item, err
	}
	return item, nil
}

// CreateSSHPrincipalGrant inserts a grant and its targets in one transaction
// and returns the stored grant. A blank ID is replaced by util.NewID; the
// timestamps are set to now and LastUsedAt to 0; negative limits are clamped
// to 0. Every target must reference an existing user or group.
func (s *Store) CreateSSHPrincipalGrant(item models.SSHPrincipalGrant) (models.SSHPrincipalGrant, error) {
	err := validateSSHPrincipalGrantFields(item)
	if err != nil {
		return item, err
	}
	if strings.TrimSpace(item.ID) == "" {
		item.ID, err = util.NewID()
		if err != nil {
			return item, err
		}
	}
	now := time.Now().UTC().Unix()
	item.CreatedAt = now
	item.UpdatedAt = now
	item.LastUsedAt = 0
	clampSSHPrincipalGrantLimits(&item)
	principalsJSON, err := encodeSSHPrincipalGrantPrincipals(item.Principals)
	if err != nil {
		return item, err
	}

	tx, owned, err := s.begin()
	if err != nil {
		return item, err
	}
	_, err = tx.Exec(`INSERT INTO ssh_principal_grants (
		public_id, principal, principals_json, valid_after, valid_before, max_cert_valid_seconds,
		max_uses, used_count, disabled, reason, created_by_user_public_id, created_at, updated_at, last_used_at
	) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
		strings.TrimSpace(item.ID),
		strings.TrimSpace(item.Name),
		principalsJSON,
		item.ValidAfter,
		item.ValidBefore,
		item.MaxCertValidSeconds,
		item.MaxUses,
		item.UsedCount,
		item.Disabled,
		strings.TrimSpace(item.Reason),
		nullableTrimmedString(item.CreatedByUserID),
		item.CreatedAt,
		item.UpdatedAt,
		item.LastUsedAt,
	)
	if err != nil {
		rollbackIfOwned(tx, owned)
		return item, err
	}
	err = insertSSHPrincipalGrantTargetsTx(tx, item.ID, item.Targets, now)
	if err != nil {
		rollbackIfOwned(tx, owned)
		return item, err
	}
	err = commitIfOwned(tx, owned)
	if err != nil {
		return item, err
	}
	return s.GetSSHPrincipalGrant(item.ID)
}

// UpdateSSHPrincipalGrant replaces a grant's editable fields and its full
// target list in one transaction and returns the stored grant. UsedCount,
// CreatedAt, CreatedByUserID and LastUsedAt are not modified. It returns
// sql.ErrNoRows, without writing anything, when the grant does not exist.
func (s *Store) UpdateSSHPrincipalGrant(item models.SSHPrincipalGrant) (models.SSHPrincipalGrant, error) {
	if strings.TrimSpace(item.ID) == "" {
		return item, errors.New("id is required")
	}
	if err := validateSSHPrincipalGrantFields(item); err != nil {
		return item, err
	}
	clampSSHPrincipalGrantLimits(&item)
	principalsJSON, err := encodeSSHPrincipalGrantPrincipals(item.Principals)
	if err != nil {
		return item, err
	}
	now := time.Now().UTC().Unix()
	item.UpdatedAt = now
	id := strings.TrimSpace(item.ID)

	tx, owned, err := s.begin()
	if err != nil {
		return item, err
	}
	// Check existence first: otherwise the UPDATE silently matches nothing and
	// the target inserts below would commit rows for a grant that is not there.
	var exists int
	err = tx.QueryRow(`SELECT 1 FROM ssh_principal_grants WHERE public_id = ?`, id).Scan(&exists)
	if err != nil {
		rollbackIfOwned(tx, owned)
		return item, err
	}
	_, err = tx.Exec(`UPDATE ssh_principal_grants SET
		principal = ?, principals_json = ?, valid_after = ?, valid_before = ?,
		max_cert_valid_seconds = ?, max_uses = ?, disabled = ?, reason = ?, updated_at = ?
	WHERE public_id = ?`,
		strings.TrimSpace(item.Name),
		principalsJSON,
		item.ValidAfter,
		item.ValidBefore,
		item.MaxCertValidSeconds,
		item.MaxUses,
		item.Disabled,
		strings.TrimSpace(item.Reason),
		item.UpdatedAt,
		id,
	)
	if err != nil {
		rollbackIfOwned(tx, owned)
		return item, err
	}
	_, err = tx.Exec(`DELETE FROM ssh_principal_grant_targets WHERE grant_public_id = ?`, id)
	if err != nil {
		rollbackIfOwned(tx, owned)
		return item, err
	}
	err = insertSSHPrincipalGrantTargetsTx(tx, item.ID, item.Targets, now)
	if err != nil {
		rollbackIfOwned(tx, owned)
		return item, err
	}
	err = commitIfOwned(tx, owned)
	if err != nil {
		return item, err
	}
	return s.GetSSHPrincipalGrant(item.ID)
}

// DeleteSSHPrincipalGrant removes a grant and its target rows in one
// transaction. Deleting an unknown ID is not an error.
func (s *Store) DeleteSSHPrincipalGrant(id string) error {
	id = strings.TrimSpace(id)
	tx, owned, err := s.begin()
	if err != nil {
		return err
	}
	// Targets are removed explicitly (as UpdateSSHPrincipalGrant does) so no
	// orphan rows remain even when the schema does not cascade deletes.
	_, err = tx.Exec(`DELETE FROM ssh_principal_grant_targets WHERE grant_public_id = ?`, id)
	if err != nil {
		rollbackIfOwned(tx, owned)
		return err
	}
	_, err = tx.Exec(`DELETE FROM ssh_principal_grants WHERE public_id = ?`, id)
	if err != nil {
		rollbackIfOwned(tx, owned)
		return err
	}
	return commitIfOwned(tx, owned)
}

// MarkSSHPrincipalGrantUsed increments the grant's use counter and stamps
// last_used_at/updated_at, but only if the grant is currently active under the
// same rules as ListActiveSSHPrincipalGrantsForUser. The check and increment
// happen in a single UPDATE; when no row qualifies it returns an error.
func (s *Store) MarkSSHPrincipalGrantUsed(id string) error {
	now := time.Now().UTC().Unix()
	result, err := s.Exec(`UPDATE ssh_principal_grants
		SET used_count = used_count + 1, last_used_at = ?, updated_at = ?
		WHERE public_id = ?
			AND disabled = 0
			AND (max_uses = 0 OR used_count < max_uses)
			AND (valid_after = 0 OR valid_after <= ?)
			AND (valid_before = 0 OR valid_before >= ?)`,
		now,
		now,
		strings.TrimSpace(id),
		now,
		now,
	)
	if err != nil {
		return err
	}
	rowsAffected, err := result.RowsAffected()
	if err != nil {
		return err
	}
	if rowsAffected <= 0 {
		return errors.New("grant is not active")
	}
	return nil
}

// encodeSSHPrincipalGrantPrincipals serialises the principal list as a JSON
// array for the principals_json column.
func encodeSSHPrincipalGrantPrincipals(items []string) (string, error) {
	data, err := json.Marshal(items)
	if err != nil {
		return "", err
	}
	return string(data), nil
}

// decodeSSHPrincipalGrantPrincipals parses the principals_json column. A
// blank value yields an empty (non-nil) slice.
func decodeSSHPrincipalGrantPrincipals(raw string) ([]string, error) {
	raw = strings.TrimSpace(raw)
	if raw == "" {
		return []string{}, nil
	}
	var items []string
	if err := json.Unmarshal([]byte(raw), &items); err != nil {
		return nil, err
	}
	return items, nil
}

// validateSSHPrincipalGrantFields checks the fields every create or update
// must supply: a non-blank name, at least one principal and one target.
func validateSSHPrincipalGrantFields(item models.SSHPrincipalGrant) error {
	switch {
	case strings.TrimSpace(item.Name) == "":
		return errors.New("name is required")
	case len(item.Principals) == 0:
		return errors.New("at least one principal is required")
	case len(item.Targets) == 0:
		return errors.New("at least one target is required")
	}
	return nil
}

// clampSSHPrincipalGrantLimits turns negative limits into 0, which the
// queries treat as "no limit".
func clampSSHPrincipalGrantLimits(item *models.SSHPrincipalGrant) {
	if item.MaxCertValidSeconds < 0 {
		item.MaxCertValidSeconds = 0
	}
	if item.MaxUses < 0 {
		item.MaxUses = 0
	}
}

// scanSSHPrincipalGrant reads one row selected with sshPrincipalGrantColumns
// and decodes its principal list. Targets are not loaded.
func scanSSHPrincipalGrant(scanner sshPrincipalGrantScanner) (models.SSHPrincipalGrant, error) {
	var item models.SSHPrincipalGrant
	var principalsJSON string
	err := scanner.Scan(
		&item.ID,
		&item.Name,
		&principalsJSON,
		&item.ValidAfter,
		&item.ValidBefore,
		&item.MaxCertValidSeconds,
		&item.MaxUses,
		&item.UsedCount,
		&item.Disabled,
		&item.Reason,
		&item.CreatedByUserID,
		&item.CreatedAt,
		&item.UpdatedAt,
		&item.LastUsedAt,
	)
	if err != nil {
		return models.SSHPrincipalGrant{}, err
	}
	item.Principals, err = decodeSSHPrincipalGrantPrincipals(principalsJSON)
	if err != nil {
		return models.SSHPrincipalGrant{}, err
	}
	return item, nil
}

// querySSHPrincipalGrants runs a grant query and returns the grants with their
// targets attached. All rows are read and the result set is closed before any
// target query runs, so no second query is issued on a connection that still
// has an open result set. That would deadlock a single-connection pool and is
// rejected by drivers that cannot interleave result sets.
func (s *Store) querySSHPrincipalGrants(query string, args ...interface{}) ([]models.SSHPrincipalGrant, error) {
	rows, err := s.Query(query, args...)
	if err != nil {
		return nil, err
	}
	items, err := collectSSHPrincipalGrants(rows)
	if err != nil {
		return nil, err
	}
	for i := range items {
		items[i].Targets, err = s.listSSHPrincipalGrantTargets(items[i].ID)
		if err != nil {
			return nil, err
		}
	}
	return items, nil
}

// collectSSHPrincipalGrants scans every row and closes rows before returning.
// It returns nil when there are no rows.
func collectSSHPrincipalGrants(rows *sql.Rows) ([]models.SSHPrincipalGrant, error) {
	defer rows.Close()
	var items []models.SSHPrincipalGrant
	for rows.Next() {
		item, err := scanSSHPrincipalGrant(rows)
		if err != nil {
			return nil, err
		}
		items = append(items, item)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return items, nil
}

// listSSHPrincipalGrantTargets returns a grant's targets with display data.
// TargetName is the user's username or the group's name ("" if the row is
// missing). TargetActive is true only when the referenced user/group exists
// and has disabled = 0; a missing row makes the CASE fall through to 0.
func (s *Store) listSSHPrincipalGrantTargets(grantID string) ([]models.SSHPrincipalGrantTarget, error) {
	const query = `SELECT t.grant_public_id, t.target_type, t.target_public_id,
		CASE
			WHEN t.target_type = 'user' THEN COALESCE(u.username, '')
			WHEN t.target_type = 'group' THEN COALESCE(g.name, '')
			ELSE ''
		END AS target_name,
		CASE
			WHEN t.target_type = 'user' THEN CASE WHEN u.disabled = 0 THEN 1 ELSE 0 END
			WHEN t.target_type = 'group' THEN CASE WHEN g.disabled = 0 THEN 1 ELSE 0 END
			ELSE 1
		END AS target_active,
		t.created_at
	FROM ssh_principal_grant_targets t
	LEFT JOIN users u ON t.target_type = 'user' AND u.public_id = t.target_public_id
	LEFT JOIN user_groups g ON t.target_type = 'group' AND g.public_id = t.target_public_id
	WHERE t.grant_public_id = ?
	ORDER BY t.target_type, target_name, t.target_public_id`

	rows, err := s.Query(query, strings.TrimSpace(grantID))
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var items []models.SSHPrincipalGrantTarget
	for rows.Next() {
		var item models.SSHPrincipalGrantTarget
		err = rows.Scan(
			&item.GrantID,
			&item.TargetType,
			&item.TargetID,
			&item.TargetName,
			&item.TargetActive,
			&item.CreatedAt,
		)
		if err != nil {
			return nil, err
		}
		items = append(items, item)
	}
	if err = rows.Err(); err != nil {
		return nil, err
	}
	return items, nil
}

// insertSSHPrincipalGrantTargetsTx inserts every target for grantID inside tx.
// Entries that normalise to the same (type, id) pair are inserted once.
// It stops at the first failure, leaving rollback to the caller.
func insertSSHPrincipalGrantTargetsTx(tx *sql.Tx, grantID string, targets []models.SSHPrincipalGrantTarget, createdAt int64) error {
	seen := make(map[string]struct{}, len(targets))
	for _, target := range targets {
		key := strings.ToLower(strings.TrimSpace(target.TargetType)) + "\x00" + strings.TrimSpace(target.TargetID)
		if _, dup := seen[key]; dup {
			continue
		}
		seen[key] = struct{}{}
		if err := insertSSHPrincipalGrantTargetTx(tx, grantID, target.TargetType, target.TargetID, createdAt); err != nil {
			return err
		}
	}
	return nil
}

// insertSSHPrincipalGrantTargetTx links grantID to a single user or group
// inside tx. targetType is case-insensitive ("user" or "group"). The
// INSERT ... SELECT copies the public ID from the referenced table, so an
// unknown user or group inserts nothing and is reported as "not found".
func insertSSHPrincipalGrantTargetTx(tx *sql.Tx, grantID string, targetType string, targetID string, createdAt int64) error {
	normalizedType := strings.ToLower(strings.TrimSpace(targetType))
	if normalizedType != "user" && normalizedType != "group" {
		return errors.New("target_type must be user or group")
	}
	if strings.TrimSpace(targetID) == "" {
		return errors.New("target id is required")
	}
	var result sql.Result
	var err error
	if normalizedType == "user" {
		result, err = tx.Exec(`INSERT INTO ssh_principal_grant_targets (grant_public_id, target_type, target_public_id, created_at)
			SELECT ?, 'user', u.public_id, ? FROM users u WHERE u.public_id = ?`,
			strings.TrimSpace(grantID),
			createdAt,
			strings.TrimSpace(targetID),
		)
	} else {
		result, err = tx.Exec(`INSERT INTO ssh_principal_grant_targets (grant_public_id, target_type, target_public_id, created_at)
			SELECT ?, 'group', g.public_id, ? FROM user_groups g WHERE g.public_id = ?`,
			strings.TrimSpace(grantID),
			createdAt,
			strings.TrimSpace(targetID),
		)
	}
	if err != nil {
		return err
	}
	rowsAffected, err := result.RowsAffected()
	if err != nil {
		return err
	}
	if rowsAffected <= 0 {
		if normalizedType == "user" {
			return errors.New("target user not found")
		}
		return errors.New("target group not found")
	}
	return nil
}