Files
codit/backend/internal/db/tls-auth-policies.go

464 lines
14 KiB
Go

package db
import "context"
import "database/sql"
import "encoding/json"
import "fmt"
import "strings"
import "time"
import "codit/internal/models"
import "codit/internal/util"
// TLSAuthPolicyIDDefault is the public ID of the built-in "Default" TLS auth
// policy. It is also used as the fallback policy ID whenever an endpoint
// binding is missing or has a blank policy ID.
const TLSAuthPolicyIDDefault string = "tls-auth-default"

// Public IDs of the other built-in TLS auth policies returned by
// tlsBuiltinAuthPolicies.
const (
	tlsAuthPolicyIDReadOpenWriteCert       string = "tls-auth-read-open-write-cert"
	tlsAuthPolicyIDReadOpenWriteCertOrAuth string = "tls-auth-read-open-write-cert-or-auth"
	tlsAuthPolicyIDCertOnly                string = "tls-auth-cert-only"
	tlsAuthPolicyIDReadOnlyPublic          string = "tls-auth-read-only-public"
	tlsAuthPolicyIDDenyAccess              string = "tls-auth-deny-all"
)
// tlsBuiltinAuthPolicies returns the TLS auth policies that ship with the
// server, in a fixed order. All of them set RequirePrincipalForWrite to true
// and PermissionMatch to "all"; they differ only in ID, name, description and
// read/write modes. A fresh slice is built on every call.
func tlsBuiltinAuthPolicies() []models.TLSAuthPolicy {
	return []models.TLSAuthPolicy{
		{
			ID:                       TLSAuthPolicyIDDefault,
			Name:                     "Default",
			Description:              "Require normal authentication for read and write operations.",
			ReadMode:                 "auth",
			WriteMode:                "auth",
			RequirePrincipalForWrite: true,
			PermissionMatch:          "all",
		},
		{
			ID:                       tlsAuthPolicyIDReadOpenWriteCert,
			Name:                     "Read Open / Write Cert",
			Description:              "Allow reads without auth and require a client certificate for writes.",
			ReadMode:                 "public",
			WriteMode:                "cert",
			RequirePrincipalForWrite: true,
			PermissionMatch:          "all",
		},
		{
			ID:                       tlsAuthPolicyIDReadOpenWriteCertOrAuth,
			Name:                     "Read Open / Write Cert Or Auth",
			Description:              "Allow reads without auth and allow writes by client certificate or normal auth.",
			ReadMode:                 "public",
			WriteMode:                "cert_or_auth",
			RequirePrincipalForWrite: true,
			PermissionMatch:          "all",
		},
		{
			ID:                       tlsAuthPolicyIDCertOnly,
			Name:                     "Cert Only",
			Description:              "Require a client certificate for reads and writes.",
			ReadMode:                 "cert",
			WriteMode:                "cert",
			RequirePrincipalForWrite: true,
			PermissionMatch:          "all",
		},
		{
			ID:                       tlsAuthPolicyIDReadOnlyPublic,
			Name:                     "Read Only Public",
			Description:              "Allow reads without auth and deny writes.",
			ReadMode:                 "public",
			WriteMode:                "deny",
			RequirePrincipalForWrite: true,
			PermissionMatch:          "all",
		},
		{
			ID:                       tlsAuthPolicyIDDenyAccess,
			Name:                     "Deny Access",
			Description:              "Deny reads and deny writes.",
			ReadMode:                 "deny",
			WriteMode:                "deny",
			RequirePrincipalForWrite: true,
			PermissionMatch:          "all",
		},
	}
}
// tlsEndpointServices lists, in canonical order, the endpoint service names
// that can carry a TLS auth policy binding. Each call returns a new slice, so
// callers may modify the result freely.
func tlsEndpointServices() []string {
	return []string{"api", "git", "rpm", "v2"}
}
// DefaultTLSEndpointPolicies returns one endpoint binding per service from
// tlsEndpointServices, in that order, each bound to TLSAuthPolicyIDDefault.
func DefaultTLSEndpointPolicies() []models.TLSEndpointPolicy {
	services := tlsEndpointServices()
	defaults := make([]models.TLSEndpointPolicy, 0, len(services))
	for _, svc := range services {
		defaults = append(defaults, models.TLSEndpointPolicy{
			Service:  svc,
			PolicyID: TLSAuthPolicyIDDefault,
		})
	}
	return defaults
}
// normalizeTLSEndpointPolicies canonicalizes endpoint bindings into exactly one
// entry per service from tlsEndpointServices, in that order.
//
// Service names are trimmed and lower-cased, and entries whose service is
// blank are ignored. Policy IDs are trimmed, and a blank one becomes
// TLSAuthPolicyIDDefault. If a service appears more than once, the last entry
// wins. Known services missing from items are bound to
// TLSAuthPolicyIDDefault. Entries for services that are not in
// tlsEndpointServices are dropped.
func normalizeTLSEndpointPolicies(items []models.TLSEndpointPolicy) []models.TLSEndpointPolicy {
	byService := make(map[string]string, len(items))
	for _, entry := range items {
		svc := strings.ToLower(strings.TrimSpace(entry.Service))
		if svc == "" {
			continue
		}
		pid := strings.TrimSpace(entry.PolicyID)
		if pid == "" {
			pid = TLSAuthPolicyIDDefault
		}
		byService[svc] = pid
	}

	services := tlsEndpointServices()
	normalized := make([]models.TLSEndpointPolicy, 0, len(services))
	for _, svc := range services {
		pid := byService[svc]
		if pid == "" {
			pid = TLSAuthPolicyIDDefault
		}
		normalized = append(normalized, models.TLSEndpointPolicy{
			Service:  svc,
			PolicyID: pid,
		})
	}
	return normalized
}
// encodeTLSEndpointPolicies normalizes items (see normalizeTLSEndpointPolicies)
// and serializes the result as a JSON array string. On a marshal failure it
// returns "" and the error.
func encodeTLSEndpointPolicies(items []models.TLSEndpointPolicy) (string, error) {
	encoded, err := json.Marshal(normalizeTLSEndpointPolicies(items))
	if err != nil {
		return "", err
	}
	return string(encoded), nil
}
// decodeTLSEndpointPolicies parses a JSON array of endpoint bindings and
// normalizes it (see normalizeTLSEndpointPolicies).
//
// The function is deliberately lenient. A blank string or malformed JSON
// yields nil rather than an error. Any input that does parse (including "[]"
// or "null") is expanded to one binding per known service.
func decodeTLSEndpointPolicies(value string) []models.TLSEndpointPolicy {
	trimmed := strings.TrimSpace(value)
	if trimmed == "" {
		return nil
	}
	var parsed []models.TLSEndpointPolicy
	if json.Unmarshal([]byte(trimmed), &parsed) != nil {
		// Malformed stored data is treated like "unset" instead of failing.
		return nil
	}
	return normalizeTLSEndpointPolicies(parsed)
}
func (s *Store) ListTLSAuthPolicies() ([]models.TLSAuthPolicy, error) {
var rows *sql.Rows
var err error
var items []models.TLSAuthPolicy
var item models.TLSAuthPolicy
var allowedPKICertIDs string
var allowedFingerprints string
var requiredPermissions string
rows, err = s.Query(`SELECT p.public_id, p.name, p.description, p.read_mode, p.write_mode, p.require_principal_for_write, COALESCE((SELECT GROUP_CONCAT(c.public_id, ',') FROM tls_auth_policy_allowed_pki_certs apc JOIN pki_certs c ON c.id = apc.cert_id WHERE apc.policy_id = p.id), '') AS allowed_pki_client_cert_ids, p.allowed_cert_fingerprints, p.required_permissions, p.permission_match, p.required_scope, p.created_at, p.updated_at FROM tls_auth_policies p ORDER BY p.name`)
if err != nil {
return nil, err
}
defer rows.Close()
for rows.Next() {
err = rows.Scan(&item.ID, &item.Name, &item.Description, &item.ReadMode, &item.WriteMode, &item.RequirePrincipalForWrite, &allowedPKICertIDs, &allowedFingerprints, &requiredPermissions, &item.PermissionMatch, &item.RequiredScope, &item.CreatedAt, &item.UpdatedAt)
if err != nil {
return nil, err
}
item.AllowedPKIClientCertIDs = splitCSVValue(allowedPKICertIDs)
item.AllowedCertFingerprints = splitCSVValue(allowedFingerprints)
item.RequiredPermissions = splitCSVValue(requiredPermissions)
items = append(items, item)
}
err = rows.Err()
if err != nil {
return nil, err
}
return items, nil
}
func (s *Store) GetTLSAuthPolicy(id string) (models.TLSAuthPolicy, error) {
var row *sql.Row
var item models.TLSAuthPolicy
var err error
var allowedPKICertIDs string
var allowedFingerprints string
var requiredPermissions string
row = s.QueryRow(`SELECT p.public_id, p.name, p.description, p.read_mode, p.write_mode, p.require_principal_for_write, COALESCE((SELECT GROUP_CONCAT(c.public_id, ',') FROM tls_auth_policy_allowed_pki_certs apc JOIN pki_certs c ON c.id = apc.cert_id WHERE apc.policy_id = p.id), '') AS allowed_pki_client_cert_ids, p.allowed_cert_fingerprints, p.required_permissions, p.permission_match, p.required_scope, p.created_at, p.updated_at FROM tls_auth_policies p WHERE p.public_id = ?`, id)
err = row.Scan(&item.ID, &item.Name, &item.Description, &item.ReadMode, &item.WriteMode, &item.RequirePrincipalForWrite, &allowedPKICertIDs, &allowedFingerprints, &requiredPermissions, &item.PermissionMatch, &item.RequiredScope, &item.CreatedAt, &item.UpdatedAt)
if err != nil {
return item, err
}
item.AllowedPKIClientCertIDs = splitCSVValue(allowedPKICertIDs)
item.AllowedCertFingerprints = splitCSVValue(allowedFingerprints)
item.RequiredPermissions = splitCSVValue(requiredPermissions)
return item, nil
}
// insertTLSAuthPolicyAllowedPKICerts links the policy with internal row ID
// policyID to each PKI certificate named (by public ID) in certPublicIDs.
//
// Entries are trimmed, blank entries are skipped, and repeated IDs are
// inserted only once. The certificate lookup happens inside the
// INSERT ... SELECT, so an unknown public ID inserts zero rows. That is
// reported as a "pki cert not found" error and processing stops there.
// Rows inserted earlier are not undone by this function; both callers in this
// file run it inside a transaction and roll back on error.
func insertTLSAuthPolicyAllowedPKICerts(execer sqlExecutor, policyID int64, certPublicIDs []string) error {
	inserted := make(map[string]bool, len(certPublicIDs))
	for _, raw := range certPublicIDs {
		publicID := strings.TrimSpace(raw)
		if publicID == "" || inserted[publicID] {
			continue
		}
		inserted[publicID] = true

		res, err := execer.Exec(`INSERT INTO tls_auth_policy_allowed_pki_certs (policy_id, cert_id) SELECT ?, id FROM pki_certs WHERE public_id = ?`, policyID, publicID)
		if err != nil {
			return err
		}
		n, err := res.RowsAffected()
		if err != nil {
			return err
		}
		if n == 0 {
			return fmt.Errorf("pki cert not found: %s", publicID)
		}
	}
	return nil
}
// CreateTLSAuthPolicy inserts item as a new TLS auth policy, together with its
// allowed PKI client certificate links, and returns the stored policy.
//
// If item.ID is empty, a new ID is generated with util.NewID. CreatedAt and
// UpdatedAt are both set to the current time as UTC Unix seconds. The
// fingerprint and permission lists are stored comma-joined.
//
// The work runs in the transaction returned by beginImmediateContext. It is
// committed or rolled back here only when `owned` is true. NOTE(review):
// presumably an outer transaction carried by ctx is reused otherwise; confirm.
// The returned policy always carries the (possibly generated) ID and
// timestamps, even when an error is returned.
func (s *Store) CreateTLSAuthPolicy(ctx context.Context, item models.TLSAuthPolicy) (models.TLSAuthPolicy, error) {
	if item.ID == "" {
		newID, idErr := util.NewID()
		if idErr != nil {
			return item, idErr
		}
		item.ID = newID
	}
	stamp := time.Now().UTC().Unix()
	item.CreatedAt = stamp
	item.UpdatedAt = stamp

	var tx txExecutor
	var owned bool
	var err error
	tx, owned, err = s.beginImmediateContext(ctx)
	if err != nil {
		return item, err
	}
	// abort releases the transaction (when owned) and reports cause.
	abort := func(cause error) (models.TLSAuthPolicy, error) {
		rollbackIfOwned(tx, owned)
		return item, cause
	}

	res, err := tx.Exec(`INSERT INTO tls_auth_policies (public_id, name, description, read_mode, write_mode, require_principal_for_write, allowed_cert_fingerprints, required_permissions, permission_match, required_scope, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
		item.ID, item.Name, item.Description, item.ReadMode, item.WriteMode, item.RequirePrincipalForWrite,
		strings.Join(item.AllowedCertFingerprints, ","), strings.Join(item.RequiredPermissions, ","),
		item.PermissionMatch, item.RequiredScope, item.CreatedAt, item.UpdatedAt)
	if err != nil {
		return abort(err)
	}
	// The join table references the internal row ID, not the public ID.
	policyRowID, err := res.LastInsertId()
	if err != nil {
		return abort(err)
	}
	if err = insertTLSAuthPolicyAllowedPKICerts(tx, policyRowID, item.AllowedPKIClientCertIDs); err != nil {
		return abort(err)
	}
	if err = commitIfOwned(tx, owned); err != nil {
		return item, err
	}
	return item, nil
}
// UpdateTLSAuthPolicy overwrites the stored policy whose public ID is item.ID
// and replaces its allowed PKI client certificate links wholesale.
//
// item.CreatedAt is ignored. updated_at is set to the current time as UTC
// Unix seconds. If no policy has that public ID, the error from the ID lookup
// is returned and nothing is written. NOTE(review): presumably sql.ErrNoRows;
// confirm against txExecutor.QueryRow.
//
// Commit and rollback happen here only when beginImmediateContext reports
// `owned`.
func (s *Store) UpdateTLSAuthPolicy(ctx context.Context, item models.TLSAuthPolicy) error {
	var tx txExecutor
	var owned bool
	var err error
	tx, owned, err = s.beginImmediateContext(ctx)
	if err != nil {
		return err
	}
	// abort releases the transaction (when owned) and reports cause.
	abort := func(cause error) error {
		rollbackIfOwned(tx, owned)
		return cause
	}

	// Resolve the internal row ID; the join table is keyed on it.
	var policyRowID int64
	if err = tx.QueryRow(`SELECT id FROM tls_auth_policies WHERE public_id = ?`, item.ID).Scan(&policyRowID); err != nil {
		return abort(err)
	}

	item.UpdatedAt = time.Now().UTC().Unix()
	if _, err = tx.Exec(`UPDATE tls_auth_policies SET name = ?, description = ?, read_mode = ?, write_mode = ?, require_principal_for_write = ?, allowed_cert_fingerprints = ?, required_permissions = ?, permission_match = ?, required_scope = ?, updated_at = ? WHERE public_id = ?`,
		item.Name, item.Description, item.ReadMode, item.WriteMode, item.RequirePrincipalForWrite,
		strings.Join(item.AllowedCertFingerprints, ","), strings.Join(item.RequiredPermissions, ","),
		item.PermissionMatch, item.RequiredScope, item.UpdatedAt, item.ID); err != nil {
		return abort(err)
	}

	// Replace the certificate links rather than diffing them.
	if _, err = tx.Exec(`DELETE FROM tls_auth_policy_allowed_pki_certs WHERE policy_id = ?`, policyRowID); err != nil {
		return abort(err)
	}
	if err = insertTLSAuthPolicyAllowedPKICerts(tx, policyRowID, item.AllowedPKIClientCertIDs); err != nil {
		return abort(err)
	}
	return commitIfOwned(tx, owned)
}
// DeleteTLSAuthPolicy removes the TLS auth policy whose public ID is id.
// Deleting a non-existent ID is not an error, because affected rows are not
// checked.
//
// This function does not check whether the policy is a built-in or is still
// referenced by endpoint bindings; CountTLSAuthPolicyUsages covers the
// latter. NOTE(review): rows in tls_auth_policy_allowed_pki_certs are not
// removed here explicitly. Presumably a foreign-key cascade handles them;
// confirm against the schema.
func (s *Store) DeleteTLSAuthPolicy(id string) error {
	_, err := s.Exec(`DELETE FROM tls_auth_policies WHERE public_id = ?`, id)
	return err
}
// CountTLSAuthPolicyUsages returns how many endpoint bindings reference the
// policy with public ID id. It sums the bindings in the global TLS settings
// and those of every listener returned by ListTLSListeners. Policy IDs are
// compared exactly, with no trimming or case folding.
func (s *Store) CountTLSAuthPolicyUsages(id string) (int, error) {
	countIn := func(bindings []models.TLSEndpointPolicy) int {
		matches := 0
		for _, binding := range bindings {
			if binding.PolicyID == id {
				matches++
			}
		}
		return matches
	}

	settings, err := s.GetTLSSettings()
	if err != nil {
		return 0, err
	}
	total := countIn(settings.EndpointPolicies)

	listeners, err := s.ListTLSListeners()
	if err != nil {
		return 0, err
	}
	for idx := range listeners {
		total += countIn(listeners[idx].EndpointPolicies)
	}
	return total, nil
}
// EnsureDefaultTLSAuthPolicies creates each built-in TLS auth policy (from
// tlsBuiltinAuthPolicies) whose public ID is not yet stored.
//
// A policy that already exists under a built-in ID is left exactly as stored,
// even if its contents differ from the built-in definition; nothing is
// re-synced. A lookup error other than sql.ErrNoRows, or any create error,
// aborts the run. Policies created before the failure are kept.
func (s *Store) EnsureDefaultTLSAuthPolicies(ctx context.Context) error {
	for _, builtin := range tlsBuiltinAuthPolicies() {
		_, err := s.GetTLSAuthPolicy(builtin.ID)
		if err == nil {
			// Already present: keep whatever is stored.
			continue
		}
		// GetTLSAuthPolicy returns the Scan error unwrapped, so a direct
		// comparison identifies "not found".
		if err != sql.ErrNoRows {
			return err
		}
		if _, err = s.CreateTLSAuthPolicy(ctx, builtin); err != nil {
			return err
		}
	}
	return nil
}
// EnsureTLSAuthPolicyBindings backfills endpoint policy bindings that are
// missing entirely.
//
// If the global TLS settings have no bindings, they are set to
// DefaultTLSEndpointPolicies and saved via SetTLSSettings. Likewise, every TLS
// listener with no bindings gets the defaults and is saved via
// UpdateTLSListener. Non-empty binding lists are not modified. The first error
// aborts the run, and changes already saved are not reverted.
func (s *Store) EnsureTLSAuthPolicyBindings(ctx context.Context) error {
	settings, err := s.GetTLSSettings()
	if err != nil {
		return err
	}
	if len(settings.EndpointPolicies) == 0 {
		settings.EndpointPolicies = DefaultTLSEndpointPolicies()
		if err = s.SetTLSSettings(ctx, settings); err != nil {
			return err
		}
	}

	listeners, err := s.ListTLSListeners()
	if err != nil {
		return err
	}
	for idx := range listeners {
		if len(listeners[idx].EndpointPolicies) != 0 {
			continue
		}
		listeners[idx].EndpointPolicies = DefaultTLSEndpointPolicies()
		if err = s.UpdateTLSListener(listeners[idx]); err != nil {
			return err
		}
	}
	return nil
}
// ResolveTLSAuthPolicy returns the TLS auth policy that applies to service on
// the given listener.
//
// When listenerKind (trimmed) is "extra", the bindings of the TLS listener
// with ID listenerID (trimmed) are used. Any other kind uses the global TLS
// settings' bindings. If no binding matches service, or the match has a blank
// policy ID, TLSAuthPolicyIDDefault is used.
//
// A binding that points at a policy which cannot be loaded is not replaced by
// the default; GetTLSAuthPolicy's error is returned as-is.
func (s *Store) ResolveTLSAuthPolicy(listenerKind string, listenerID string, service string) (models.TLSAuthPolicy, error) {
	var bindings []models.TLSEndpointPolicy
	switch strings.TrimSpace(listenerKind) {
	case "extra":
		listener, err := s.GetTLSListener(strings.TrimSpace(listenerID))
		if err != nil {
			return models.TLSAuthPolicy{}, err
		}
		bindings = listener.EndpointPolicies
	default:
		settings, err := s.GetTLSSettings()
		if err != nil {
			return models.TLSAuthPolicy{}, err
		}
		bindings = settings.EndpointPolicies
	}

	policyID := tlsEndpointPolicyIDForService(bindings, service)
	if policyID == "" {
		policyID = TLSAuthPolicyIDDefault
	}
	return s.GetTLSAuthPolicy(policyID)
}
// tlsEndpointPolicyIDForService returns the trimmed policy ID of the first
// binding in items whose service name matches service. Both names are trimmed
// and lower-cased before comparing. It returns "" when nothing matches, and
// also when the first match has a blank policy ID.
func tlsEndpointPolicyIDForService(items []models.TLSEndpointPolicy, service string) string {
	want := strings.ToLower(strings.TrimSpace(service))
	for idx := range items {
		if strings.ToLower(strings.TrimSpace(items[idx].Service)) == want {
			return strings.TrimSpace(items[idx].PolicyID)
		}
	}
	return ""
}