464 lines
14 KiB
Go
464 lines
14 KiB
Go
package db
|
|
|
|
import "context"
|
|
import "database/sql"
|
|
import "encoding/json"
|
|
import "fmt"
|
|
import "strings"
|
|
import "time"
|
|
|
|
import "codit/internal/models"
|
|
import "codit/internal/util"
|
|
|
|
// Public IDs of the built-in TLS auth policies listed by
// tlsBuiltinAuthPolicies. Only the default policy ID is exported.
const (
	// TLSAuthPolicyIDDefault is the policy ID used for any endpoint whose
	// binding is missing or blank.
	TLSAuthPolicyIDDefault string = "tls-auth-default"

	tlsAuthPolicyIDReadOpenWriteCert       string = "tls-auth-read-open-write-cert"
	tlsAuthPolicyIDReadOpenWriteCertOrAuth string = "tls-auth-read-open-write-cert-or-auth"
	tlsAuthPolicyIDCertOnly                string = "tls-auth-cert-only"
	tlsAuthPolicyIDReadOnlyPublic          string = "tls-auth-read-only-public"
	tlsAuthPolicyIDDenyAccess              string = "tls-auth-deny-all"
)
|
|
|
|
func tlsBuiltinAuthPolicies() []models.TLSAuthPolicy {
|
|
var items []models.TLSAuthPolicy
|
|
|
|
items = append(items, models.TLSAuthPolicy{
|
|
ID: TLSAuthPolicyIDDefault,
|
|
Name: "Default",
|
|
Description: "Require normal authentication for read and write operations.",
|
|
ReadMode: "auth",
|
|
WriteMode: "auth",
|
|
RequirePrincipalForWrite: true,
|
|
PermissionMatch: "all",
|
|
})
|
|
items = append(items, models.TLSAuthPolicy{
|
|
ID: tlsAuthPolicyIDReadOpenWriteCert,
|
|
Name: "Read Open / Write Cert",
|
|
Description: "Allow reads without auth and require a client certificate for writes.",
|
|
ReadMode: "public",
|
|
WriteMode: "cert",
|
|
RequirePrincipalForWrite: true,
|
|
PermissionMatch: "all",
|
|
})
|
|
items = append(items, models.TLSAuthPolicy{
|
|
ID: tlsAuthPolicyIDReadOpenWriteCertOrAuth,
|
|
Name: "Read Open / Write Cert Or Auth",
|
|
Description: "Allow reads without auth and allow writes by client certificate or normal auth.",
|
|
ReadMode: "public",
|
|
WriteMode: "cert_or_auth",
|
|
RequirePrincipalForWrite: true,
|
|
PermissionMatch: "all",
|
|
})
|
|
items = append(items, models.TLSAuthPolicy{
|
|
ID: tlsAuthPolicyIDCertOnly,
|
|
Name: "Cert Only",
|
|
Description: "Require a client certificate for reads and writes.",
|
|
ReadMode: "cert",
|
|
WriteMode: "cert",
|
|
RequirePrincipalForWrite: true,
|
|
PermissionMatch: "all",
|
|
})
|
|
items = append(items, models.TLSAuthPolicy{
|
|
ID: tlsAuthPolicyIDReadOnlyPublic,
|
|
Name: "Read Only Public",
|
|
Description: "Allow reads without auth and deny writes.",
|
|
ReadMode: "public",
|
|
WriteMode: "deny",
|
|
RequirePrincipalForWrite: true,
|
|
PermissionMatch: "all",
|
|
})
|
|
items = append(items, models.TLSAuthPolicy{
|
|
ID: tlsAuthPolicyIDDenyAccess,
|
|
Name: "Deny Access",
|
|
Description: "Deny reads and deny writes.",
|
|
ReadMode: "deny",
|
|
WriteMode: "deny",
|
|
RequirePrincipalForWrite: true,
|
|
PermissionMatch: "all",
|
|
})
|
|
return items
|
|
}
|
|
|
|
// tlsEndpointServices lists the services that carry a per-endpoint TLS auth
// policy binding, in the canonical order used for normalized output. A new
// slice is returned on every call.
func tlsEndpointServices() []string {
	return []string{"api", "git", "rpm", "v2"}
}
|
|
|
|
func DefaultTLSEndpointPolicies() []models.TLSEndpointPolicy {
|
|
var services []string
|
|
var items []models.TLSEndpointPolicy
|
|
var i int
|
|
|
|
services = tlsEndpointServices()
|
|
for i = 0; i < len(services); i++ {
|
|
items = append(items, models.TLSEndpointPolicy{
|
|
Service: services[i],
|
|
PolicyID: TLSAuthPolicyIDDefault,
|
|
})
|
|
}
|
|
return items
|
|
}
|
|
|
|
func normalizeTLSEndpointPolicies(items []models.TLSEndpointPolicy) []models.TLSEndpointPolicy {
|
|
var out []models.TLSEndpointPolicy
|
|
var current map[string]string
|
|
var services []string
|
|
var i int
|
|
var service string
|
|
var policyID string
|
|
|
|
current = make(map[string]string)
|
|
for i = 0; i < len(items); i++ {
|
|
service = strings.ToLower(strings.TrimSpace(items[i].Service))
|
|
if service == "" {
|
|
continue
|
|
}
|
|
policyID = strings.TrimSpace(items[i].PolicyID)
|
|
if policyID == "" {
|
|
policyID = TLSAuthPolicyIDDefault
|
|
}
|
|
current[service] = policyID
|
|
}
|
|
services = tlsEndpointServices()
|
|
for i = 0; i < len(services); i++ {
|
|
policyID = current[services[i]]
|
|
if policyID == "" {
|
|
policyID = TLSAuthPolicyIDDefault
|
|
}
|
|
out = append(out, models.TLSEndpointPolicy{
|
|
Service: services[i],
|
|
PolicyID: policyID,
|
|
})
|
|
}
|
|
return out
|
|
}
|
|
|
|
func encodeTLSEndpointPolicies(items []models.TLSEndpointPolicy) (string, error) {
|
|
var normalized []models.TLSEndpointPolicy
|
|
var data []byte
|
|
var err error
|
|
|
|
normalized = normalizeTLSEndpointPolicies(items)
|
|
data, err = json.Marshal(normalized)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return string(data), nil
|
|
}
|
|
|
|
func decodeTLSEndpointPolicies(value string) []models.TLSEndpointPolicy {
|
|
var items []models.TLSEndpointPolicy
|
|
var err error
|
|
|
|
value = strings.TrimSpace(value)
|
|
if value == "" {
|
|
return nil
|
|
}
|
|
err = json.Unmarshal([]byte(value), &items)
|
|
if err != nil {
|
|
return nil
|
|
}
|
|
return normalizeTLSEndpointPolicies(items)
|
|
}
|
|
|
|
func (s *Store) ListTLSAuthPolicies() ([]models.TLSAuthPolicy, error) {
|
|
var rows *sql.Rows
|
|
var err error
|
|
var items []models.TLSAuthPolicy
|
|
var item models.TLSAuthPolicy
|
|
var allowedPKICertIDs string
|
|
var allowedFingerprints string
|
|
var requiredPermissions string
|
|
|
|
rows, err = s.Query(`SELECT p.public_id, p.name, p.description, p.read_mode, p.write_mode, p.require_principal_for_write, COALESCE((SELECT GROUP_CONCAT(c.public_id, ',') FROM tls_auth_policy_allowed_pki_certs apc JOIN pki_certs c ON c.id = apc.cert_id WHERE apc.policy_id = p.id), '') AS allowed_pki_client_cert_ids, p.allowed_cert_fingerprints, p.required_permissions, p.permission_match, p.required_scope, p.created_at, p.updated_at FROM tls_auth_policies p ORDER BY p.name`)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
for rows.Next() {
|
|
err = rows.Scan(&item.ID, &item.Name, &item.Description, &item.ReadMode, &item.WriteMode, &item.RequirePrincipalForWrite, &allowedPKICertIDs, &allowedFingerprints, &requiredPermissions, &item.PermissionMatch, &item.RequiredScope, &item.CreatedAt, &item.UpdatedAt)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
item.AllowedPKIClientCertIDs = splitCSVValue(allowedPKICertIDs)
|
|
item.AllowedCertFingerprints = splitCSVValue(allowedFingerprints)
|
|
item.RequiredPermissions = splitCSVValue(requiredPermissions)
|
|
items = append(items, item)
|
|
}
|
|
err = rows.Err()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return items, nil
|
|
}
|
|
|
|
func (s *Store) GetTLSAuthPolicy(id string) (models.TLSAuthPolicy, error) {
|
|
var row *sql.Row
|
|
var item models.TLSAuthPolicy
|
|
var err error
|
|
var allowedPKICertIDs string
|
|
var allowedFingerprints string
|
|
var requiredPermissions string
|
|
|
|
row = s.QueryRow(`SELECT p.public_id, p.name, p.description, p.read_mode, p.write_mode, p.require_principal_for_write, COALESCE((SELECT GROUP_CONCAT(c.public_id, ',') FROM tls_auth_policy_allowed_pki_certs apc JOIN pki_certs c ON c.id = apc.cert_id WHERE apc.policy_id = p.id), '') AS allowed_pki_client_cert_ids, p.allowed_cert_fingerprints, p.required_permissions, p.permission_match, p.required_scope, p.created_at, p.updated_at FROM tls_auth_policies p WHERE p.public_id = ?`, id)
|
|
err = row.Scan(&item.ID, &item.Name, &item.Description, &item.ReadMode, &item.WriteMode, &item.RequirePrincipalForWrite, &allowedPKICertIDs, &allowedFingerprints, &requiredPermissions, &item.PermissionMatch, &item.RequiredScope, &item.CreatedAt, &item.UpdatedAt)
|
|
if err != nil {
|
|
return item, err
|
|
}
|
|
item.AllowedPKIClientCertIDs = splitCSVValue(allowedPKICertIDs)
|
|
item.AllowedCertFingerprints = splitCSVValue(allowedFingerprints)
|
|
item.RequiredPermissions = splitCSVValue(requiredPermissions)
|
|
return item, nil
|
|
}
|
|
|
|
func insertTLSAuthPolicyAllowedPKICerts(execer sqlExecutor, policyID int64, certPublicIDs []string) error {
|
|
var seen map[string]bool
|
|
var i int
|
|
var certPublicID string
|
|
var result sql.Result
|
|
var affected int64
|
|
var err error
|
|
|
|
seen = make(map[string]bool)
|
|
for i = 0; i < len(certPublicIDs); i++ {
|
|
certPublicID = strings.TrimSpace(certPublicIDs[i])
|
|
if certPublicID == "" {
|
|
continue
|
|
}
|
|
if seen[certPublicID] {
|
|
continue
|
|
}
|
|
seen[certPublicID] = true
|
|
result, err = execer.Exec(`INSERT INTO tls_auth_policy_allowed_pki_certs (policy_id, cert_id) SELECT ?, id FROM pki_certs WHERE public_id = ?`, policyID, certPublicID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
affected, err = result.RowsAffected()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if affected == 0 {
|
|
return fmt.Errorf("pki cert not found: %s", certPublicID)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *Store) CreateTLSAuthPolicy(ctx context.Context, item models.TLSAuthPolicy) (models.TLSAuthPolicy, error) {
|
|
var id string
|
|
var now int64
|
|
var err error
|
|
var result sql.Result
|
|
var internalID int64
|
|
var tx txExecutor
|
|
var owned bool
|
|
|
|
if item.ID == "" {
|
|
id, err = util.NewID()
|
|
if err != nil {
|
|
return item, err
|
|
}
|
|
item.ID = id
|
|
}
|
|
now = time.Now().UTC().Unix()
|
|
item.CreatedAt = now
|
|
item.UpdatedAt = now
|
|
tx, owned, err = s.beginImmediateContext(ctx)
|
|
if err != nil { return item, err }
|
|
|
|
result, err = tx.Exec(`INSERT INTO tls_auth_policies (public_id, name, description, read_mode, write_mode, require_principal_for_write, allowed_cert_fingerprints, required_permissions, permission_match, required_scope, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
|
item.ID, item.Name, item.Description, item.ReadMode, item.WriteMode, item.RequirePrincipalForWrite, strings.Join(item.AllowedCertFingerprints, ","), strings.Join(item.RequiredPermissions, ","), item.PermissionMatch, item.RequiredScope, item.CreatedAt, item.UpdatedAt)
|
|
if err != nil {
|
|
rollbackIfOwned(tx, owned)
|
|
return item, err
|
|
}
|
|
internalID, err = result.LastInsertId()
|
|
if err != nil {
|
|
rollbackIfOwned(tx, owned)
|
|
return item, err
|
|
}
|
|
err = insertTLSAuthPolicyAllowedPKICerts(tx, internalID, item.AllowedPKIClientCertIDs)
|
|
if err != nil {
|
|
rollbackIfOwned(tx, owned)
|
|
return item, err
|
|
}
|
|
err = commitIfOwned(tx, owned)
|
|
if err != nil { return item, err }
|
|
return item, nil
|
|
}
|
|
|
|
func (s *Store) UpdateTLSAuthPolicy(ctx context.Context, item models.TLSAuthPolicy) error {
|
|
var now int64
|
|
var err error
|
|
var internalID int64
|
|
var tx txExecutor
|
|
var owned bool
|
|
|
|
tx, owned, err = s.beginImmediateContext(ctx)
|
|
if err != nil { return err }
|
|
|
|
err = tx.QueryRow(`SELECT id FROM tls_auth_policies WHERE public_id = ?`, item.ID).Scan(&internalID)
|
|
if err != nil { rollbackIfOwned(tx, owned); return err }
|
|
now = time.Now().UTC().Unix()
|
|
item.UpdatedAt = now
|
|
_, err = tx.Exec(`UPDATE tls_auth_policies SET name = ?, description = ?, read_mode = ?, write_mode = ?, require_principal_for_write = ?, allowed_cert_fingerprints = ?, required_permissions = ?, permission_match = ?, required_scope = ?, updated_at = ? WHERE public_id = ?`,
|
|
item.Name, item.Description, item.ReadMode, item.WriteMode, item.RequirePrincipalForWrite, strings.Join(item.AllowedCertFingerprints, ","), strings.Join(item.RequiredPermissions, ","), item.PermissionMatch, item.RequiredScope, item.UpdatedAt, item.ID)
|
|
if err != nil {
|
|
rollbackIfOwned(tx, owned)
|
|
return err
|
|
}
|
|
_, err = tx.Exec(`DELETE FROM tls_auth_policy_allowed_pki_certs WHERE policy_id = ?`, internalID)
|
|
if err != nil {
|
|
rollbackIfOwned(tx, owned)
|
|
return err
|
|
}
|
|
err = insertTLSAuthPolicyAllowedPKICerts(tx, internalID, item.AllowedPKIClientCertIDs)
|
|
if err != nil {
|
|
rollbackIfOwned(tx, owned)
|
|
return err
|
|
}
|
|
return commitIfOwned(tx, owned)
|
|
}
|
|
|
|
func (s *Store) DeleteTLSAuthPolicy(id string) error {
|
|
var err error
|
|
|
|
_, err = s.Exec(`DELETE FROM tls_auth_policies WHERE public_id = ?`, id)
|
|
return err
|
|
}
|
|
|
|
func (s *Store) CountTLSAuthPolicyUsages(id string) (int, error) {
|
|
var count int
|
|
var settings models.TLSSettings
|
|
var listeners []models.TLSListener
|
|
var i int
|
|
var j int
|
|
var err error
|
|
|
|
settings, err = s.GetTLSSettings()
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
for i = 0; i < len(settings.EndpointPolicies); i++ {
|
|
if settings.EndpointPolicies[i].PolicyID == id {
|
|
count++
|
|
}
|
|
}
|
|
listeners, err = s.ListTLSListeners()
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
for i = 0; i < len(listeners); i++ {
|
|
for j = 0; j < len(listeners[i].EndpointPolicies); j++ {
|
|
if listeners[i].EndpointPolicies[j].PolicyID == id {
|
|
count++
|
|
}
|
|
}
|
|
}
|
|
return count, nil
|
|
}
|
|
|
|
func (s *Store) EnsureDefaultTLSAuthPolicies(ctx context.Context) error {
|
|
var items []models.TLSAuthPolicy
|
|
var i int
|
|
var item models.TLSAuthPolicy
|
|
var existing models.TLSAuthPolicy
|
|
var err error
|
|
|
|
items = tlsBuiltinAuthPolicies()
|
|
for i = 0; i < len(items); i++ {
|
|
item = items[i]
|
|
existing, err = s.GetTLSAuthPolicy(item.ID)
|
|
if err == nil {
|
|
if existing.Name == item.Name {
|
|
continue
|
|
}
|
|
continue
|
|
}
|
|
if err != sql.ErrNoRows {
|
|
return err
|
|
}
|
|
_, err = s.CreateTLSAuthPolicy(ctx, item)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *Store) EnsureTLSAuthPolicyBindings(ctx context.Context) error {
|
|
var settings models.TLSSettings
|
|
var listeners []models.TLSListener
|
|
var i int
|
|
var err error
|
|
|
|
settings, err = s.GetTLSSettings()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if len(settings.EndpointPolicies) == 0 {
|
|
settings.EndpointPolicies = DefaultTLSEndpointPolicies()
|
|
err = s.SetTLSSettings(ctx, settings)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
listeners, err = s.ListTLSListeners()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
for i = 0; i < len(listeners); i++ {
|
|
if len(listeners[i].EndpointPolicies) == 0 {
|
|
listeners[i].EndpointPolicies = DefaultTLSEndpointPolicies()
|
|
err = s.UpdateTLSListener(listeners[i])
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *Store) ResolveTLSAuthPolicy(listenerKind string, listenerID string, service string) (models.TLSAuthPolicy, error) {
|
|
var settings models.TLSSettings
|
|
var listener models.TLSListener
|
|
var policyID string
|
|
var err error
|
|
|
|
if strings.TrimSpace(listenerKind) == "extra" {
|
|
listener, err = s.GetTLSListener(strings.TrimSpace(listenerID))
|
|
if err != nil {
|
|
return models.TLSAuthPolicy{}, err
|
|
}
|
|
policyID = tlsEndpointPolicyIDForService(listener.EndpointPolicies, service)
|
|
} else {
|
|
settings, err = s.GetTLSSettings()
|
|
if err != nil {
|
|
return models.TLSAuthPolicy{}, err
|
|
}
|
|
policyID = tlsEndpointPolicyIDForService(settings.EndpointPolicies, service)
|
|
}
|
|
if policyID == "" {
|
|
policyID = TLSAuthPolicyIDDefault
|
|
}
|
|
return s.GetTLSAuthPolicy(policyID)
|
|
}
|
|
|
|
func tlsEndpointPolicyIDForService(items []models.TLSEndpointPolicy, service string) string {
|
|
var i int
|
|
var normalizedService string
|
|
|
|
normalizedService = strings.ToLower(strings.TrimSpace(service))
|
|
for i = 0; i < len(items); i++ {
|
|
if strings.ToLower(strings.TrimSpace(items[i].Service)) != normalizedService {
|
|
continue
|
|
}
|
|
return strings.TrimSpace(items[i].PolicyID)
|
|
}
|
|
return ""
|
|
}
|