Files
codit/backend/internal/db/tls_auth_policies.go

387 lines
12 KiB
Go

package db
import "database/sql"
import "encoding/json"
import "strings"
import "time"
import "codit/internal/models"
import "codit/internal/util"
// Public IDs of the built-in TLS auth policies returned by
// tlsBuiltinAuthPolicies. TLSAuthPolicyIDDefault is exported because it is the
// fallback policy for any endpoint without an explicit binding.
const (
	TLSAuthPolicyIDDefault                 string = "tls-auth-default"
	tlsAuthPolicyIDReadOpenWriteCert       string = "tls-auth-read-open-write-cert"
	tlsAuthPolicyIDReadOpenWriteCertOrAuth string = "tls-auth-read-open-write-cert-or-auth"
	tlsAuthPolicyIDCertOnly                string = "tls-auth-cert-only"
	tlsAuthPolicyIDReadOnlyPublic          string = "tls-auth-read-only-public"
	tlsAuthPolicyIDDenyAccess              string = "tls-auth-deny-all"
)
// tlsBuiltinAuthPolicies returns the fixed catalogue of built-in TLS auth
// policies, in display order. All of them require a principal for writes and
// use "all" permission matching; they differ only in their read and write
// modes ("auth", "public", "cert", "cert_or_auth", "deny").
func tlsBuiltinAuthPolicies() []models.TLSAuthPolicy {
	return []models.TLSAuthPolicy{
		{
			ID:                       TLSAuthPolicyIDDefault,
			Name:                     "Default",
			Description:              "Require normal authentication for read and write operations.",
			ReadMode:                 "auth",
			WriteMode:                "auth",
			RequirePrincipalForWrite: true,
			PermissionMatch:          "all",
		},
		{
			ID:                       tlsAuthPolicyIDReadOpenWriteCert,
			Name:                     "Read Open / Write Cert",
			Description:              "Allow reads without auth and require a client certificate for writes.",
			ReadMode:                 "public",
			WriteMode:                "cert",
			RequirePrincipalForWrite: true,
			PermissionMatch:          "all",
		},
		{
			ID:                       tlsAuthPolicyIDReadOpenWriteCertOrAuth,
			Name:                     "Read Open / Write Cert Or Auth",
			Description:              "Allow reads without auth and allow writes by client certificate or normal auth.",
			ReadMode:                 "public",
			WriteMode:                "cert_or_auth",
			RequirePrincipalForWrite: true,
			PermissionMatch:          "all",
		},
		{
			ID:                       tlsAuthPolicyIDCertOnly,
			Name:                     "Cert Only",
			Description:              "Require a client certificate for reads and writes.",
			ReadMode:                 "cert",
			WriteMode:                "cert",
			RequirePrincipalForWrite: true,
			PermissionMatch:          "all",
		},
		{
			ID:                       tlsAuthPolicyIDReadOnlyPublic,
			Name:                     "Read Only Public",
			Description:              "Allow reads without auth and deny writes.",
			ReadMode:                 "public",
			WriteMode:                "deny",
			RequirePrincipalForWrite: true,
			PermissionMatch:          "all",
		},
		{
			ID:                       tlsAuthPolicyIDDenyAccess,
			Name:                     "Deny Access",
			Description:              "Deny reads and deny writes.",
			ReadMode:                 "deny",
			WriteMode:                "deny",
			RequirePrincipalForWrite: true,
			PermissionMatch:          "all",
		},
	}
}
// tlsEndpointServices lists, in canonical order, the service names that can
// carry a TLS endpoint policy binding. A new slice is returned on every call,
// so callers may modify it freely.
func tlsEndpointServices() []string {
	return []string{"api", "git", "rpm", "v2"}
}
// DefaultTLSEndpointPolicies returns one endpoint binding for each service from
// tlsEndpointServices, in that order, with every binding pointing at
// TLSAuthPolicyIDDefault.
func DefaultTLSEndpointPolicies() []models.TLSEndpointPolicy {
	var bindings []models.TLSEndpointPolicy
	for _, name := range tlsEndpointServices() {
		bindings = append(bindings, models.TLSEndpointPolicy{
			Service:  name,
			PolicyID: TLSAuthPolicyIDDefault,
		})
	}
	return bindings
}
// normalizeTLSEndpointPolicies canonicalises a list of endpoint bindings.
//
// The result holds exactly one binding per service from tlsEndpointServices,
// in that order. Input service names are trimmed and lower-cased before
// matching. Entries with a blank service, and entries for services not in
// tlsEndpointServices, are dropped. When a service occurs more than once, the
// last occurrence wins. A blank or missing policy ID becomes
// TLSAuthPolicyIDDefault.
func normalizeTLSEndpointPolicies(items []models.TLSEndpointPolicy) []models.TLSEndpointPolicy {
	// Blank policy IDs are stored as "" and defaulted when building the output.
	// This matches defaulting them up front, because a later blank entry still
	// overrides an earlier explicit one.
	byService := make(map[string]string)
	for _, entry := range items {
		key := strings.ToLower(strings.TrimSpace(entry.Service))
		if key == "" {
			continue
		}
		byService[key] = strings.TrimSpace(entry.PolicyID)
	}

	var result []models.TLSEndpointPolicy
	for _, svc := range tlsEndpointServices() {
		id := byService[svc]
		if id == "" {
			id = TLSAuthPolicyIDDefault
		}
		result = append(result, models.TLSEndpointPolicy{
			Service:  svc,
			PolicyID: id,
		})
	}
	return result
}
func encodeTLSEndpointPolicies(items []models.TLSEndpointPolicy) (string, error) {
var normalized []models.TLSEndpointPolicy
var data []byte
var err error
normalized = normalizeTLSEndpointPolicies(items)
data, err = json.Marshal(normalized)
if err != nil {
return "", err
}
return string(data), nil
}
// decodeTLSEndpointPolicies parses a JSON-encoded list of endpoint bindings and
// passes it through normalizeTLSEndpointPolicies.
//
// Decoding is best-effort. A value that is blank after trimming, or that is not
// valid JSON for the binding list, yields nil instead of an error.
func decodeTLSEndpointPolicies(value string) []models.TLSEndpointPolicy {
	trimmed := strings.TrimSpace(value)
	if trimmed == "" {
		return nil
	}
	var parsed []models.TLSEndpointPolicy
	if json.Unmarshal([]byte(trimmed), &parsed) != nil {
		return nil
	}
	return normalizeTLSEndpointPolicies(parsed)
}
// ListTLSAuthPolicies returns every row of tls_auth_policies, ordered by name.
// The comma-joined list columns are expanded with splitCSVValue. When the
// table is empty, the returned slice is nil. Any query, scan or iteration error
// aborts the listing and is returned with a nil slice.
func (s *Store) ListTLSAuthPolicies() ([]models.TLSAuthPolicy, error) {
	rows, err := s.Query(`SELECT public_id, name, description, read_mode, write_mode, require_principal_for_write, allowed_pki_client_cert_ids, allowed_cert_fingerprints, required_permissions, permission_match, required_scope, created_at, updated_at FROM tls_auth_policies ORDER BY name`)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var policies []models.TLSAuthPolicy
	for rows.Next() {
		var (
			policy          models.TLSAuthPolicy
			certIDsCSV      string
			fingerprintsCSV string
			permissionsCSV  string
		)
		if err := rows.Scan(
			&policy.ID,
			&policy.Name,
			&policy.Description,
			&policy.ReadMode,
			&policy.WriteMode,
			&policy.RequirePrincipalForWrite,
			&certIDsCSV,
			&fingerprintsCSV,
			&permissionsCSV,
			&policy.PermissionMatch,
			&policy.RequiredScope,
			&policy.CreatedAt,
			&policy.UpdatedAt,
		); err != nil {
			return nil, err
		}
		policy.AllowedPKIClientCertIDs = splitCSVValue(certIDsCSV)
		policy.AllowedCertFingerprints = splitCSVValue(fingerprintsCSV)
		policy.RequiredPermissions = splitCSVValue(permissionsCSV)
		policies = append(policies, policy)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return policies, nil
}
// GetTLSAuthPolicy loads the TLS auth policy whose public_id equals id and
// expands its comma-joined list columns with splitCSVValue.
//
// Errors from (*sql.Row).Scan are returned unwrapped, including sql.ErrNoRows
// when no row matches. Callers such as EnsureDefaultTLSAuthPolicies compare
// the error directly. On any error the returned policy is the zero value,
// never a partially-scanned one.
func (s *Store) GetTLSAuthPolicy(id string) (models.TLSAuthPolicy, error) {
	var (
		item                models.TLSAuthPolicy
		allowedPKICertIDs   string
		allowedFingerprints string
		requiredPermissions string
	)
	row := s.QueryRow(`SELECT public_id, name, description, read_mode, write_mode, require_principal_for_write, allowed_pki_client_cert_ids, allowed_cert_fingerprints, required_permissions, permission_match, required_scope, created_at, updated_at FROM tls_auth_policies WHERE public_id = ?`, id)
	err := row.Scan(&item.ID, &item.Name, &item.Description, &item.ReadMode, &item.WriteMode, &item.RequirePrincipalForWrite, &allowedPKICertIDs, &allowedFingerprints, &requiredPermissions, &item.PermissionMatch, &item.RequiredScope, &item.CreatedAt, &item.UpdatedAt)
	if err != nil {
		// Scan fills destinations in column order and stops at the first
		// conversion failure, so item may be half-populated here. Returning
		// it could hand out an auth policy with, say, a ReadMode but no
		// WriteMode, so return the zero value instead.
		return models.TLSAuthPolicy{}, err
	}
	item.AllowedPKIClientCertIDs = splitCSVValue(allowedPKICertIDs)
	item.AllowedCertFingerprints = splitCSVValue(allowedFingerprints)
	item.RequiredPermissions = splitCSVValue(requiredPermissions)
	return item, nil
}
func (s *Store) CreateTLSAuthPolicy(item models.TLSAuthPolicy) (models.TLSAuthPolicy, error) {
var id string
var now int64
var err error
if item.ID == "" {
id, err = util.NewID()
if err != nil {
return item, err
}
item.ID = id
}
now = time.Now().UTC().Unix()
item.CreatedAt = now
item.UpdatedAt = now
_, err = s.Exec(`INSERT INTO tls_auth_policies (public_id, name, description, read_mode, write_mode, require_principal_for_write, allowed_pki_client_cert_ids, allowed_cert_fingerprints, required_permissions, permission_match, required_scope, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
item.ID, item.Name, item.Description, item.ReadMode, item.WriteMode, item.RequirePrincipalForWrite, strings.Join(item.AllowedPKIClientCertIDs, ","), strings.Join(item.AllowedCertFingerprints, ","), strings.Join(item.RequiredPermissions, ","), item.PermissionMatch, item.RequiredScope, item.CreatedAt, item.UpdatedAt)
if err != nil {
return item, err
}
return item, nil
}
func (s *Store) UpdateTLSAuthPolicy(item models.TLSAuthPolicy) error {
var now int64
var err error
now = time.Now().UTC().Unix()
item.UpdatedAt = now
_, err = s.Exec(`UPDATE tls_auth_policies SET name = ?, description = ?, read_mode = ?, write_mode = ?, require_principal_for_write = ?, allowed_pki_client_cert_ids = ?, allowed_cert_fingerprints = ?, required_permissions = ?, permission_match = ?, required_scope = ?, updated_at = ? WHERE public_id = ?`,
item.Name, item.Description, item.ReadMode, item.WriteMode, item.RequirePrincipalForWrite, strings.Join(item.AllowedPKIClientCertIDs, ","), strings.Join(item.AllowedCertFingerprints, ","), strings.Join(item.RequiredPermissions, ","), item.PermissionMatch, item.RequiredScope, item.UpdatedAt, item.ID)
return err
}
// DeleteTLSAuthPolicy removes the tls_auth_policies row whose public_id is id.
// It neither checks whether the policy is still bound to an endpoint nor
// whether a row was actually deleted.
func (s *Store) DeleteTLSAuthPolicy(id string) error {
	if _, err := s.Exec(`DELETE FROM tls_auth_policies WHERE public_id = ?`, id); err != nil {
		return err
	}
	return nil
}
// CountTLSAuthPolicyUsages returns how many endpoint bindings reference the
// policy id. It sums the bindings in the global TLS settings and in every TLS
// listener returned by ListTLSListeners. Any lookup error yields (0, err).
func (s *Store) CountTLSAuthPolicyUsages(id string) (int, error) {
	// countIn counts the bindings in one list whose PolicyID is exactly id.
	countIn := func(bindings []models.TLSEndpointPolicy) int {
		hits := 0
		for _, binding := range bindings {
			if binding.PolicyID == id {
				hits++
			}
		}
		return hits
	}

	settings, err := s.GetTLSSettings()
	if err != nil {
		return 0, err
	}
	total := countIn(settings.EndpointPolicies)

	listeners, err := s.ListTLSListeners()
	if err != nil {
		return 0, err
	}
	for idx := range listeners {
		total += countIn(listeners[idx].EndpointPolicies)
	}
	return total, nil
}
func (s *Store) EnsureDefaultTLSAuthPolicies() error {
var items []models.TLSAuthPolicy
var i int
var item models.TLSAuthPolicy
var existing models.TLSAuthPolicy
var err error
items = tlsBuiltinAuthPolicies()
for i = 0; i < len(items); i++ {
item = items[i]
existing, err = s.GetTLSAuthPolicy(item.ID)
if err == nil {
if existing.Name == item.Name {
continue
}
continue
}
if err != sql.ErrNoRows {
return err
}
_, err = s.CreateTLSAuthPolicy(item)
if err != nil {
return err
}
}
return nil
}
// EnsureTLSAuthPolicyBindings backfills missing endpoint bindings.
//
// If the global TLS settings have an empty EndpointPolicies list, it is set
// to DefaultTLSEndpointPolicies and saved with SetTLSSettings. Each listener
// from ListTLSListeners with an empty list gets the same treatment and is
// saved with UpdateTLSListener. Lists that are already non-empty are not
// touched. The first error encountered is returned.
func (s *Store) EnsureTLSAuthPolicyBindings() error {
	settings, err := s.GetTLSSettings()
	if err != nil {
		return err
	}
	if len(settings.EndpointPolicies) == 0 {
		settings.EndpointPolicies = DefaultTLSEndpointPolicies()
		if err = s.SetTLSSettings(settings); err != nil {
			return err
		}
	}

	listeners, err := s.ListTLSListeners()
	if err != nil {
		return err
	}
	for idx := range listeners {
		if len(listeners[idx].EndpointPolicies) != 0 {
			continue
		}
		listeners[idx].EndpointPolicies = DefaultTLSEndpointPolicies()
		if err = s.UpdateTLSListener(listeners[idx]); err != nil {
			return err
		}
	}
	return nil
}
// ResolveTLSAuthPolicy returns the TLS auth policy bound to service on a
// listener.
//
// When listenerKind, after trimming whitespace, equals "extra", the bindings
// of the listener with the trimmed listenerID are used. Any other kind uses
// the bindings in the global TLS settings. If no binding yields a policy ID,
// TLSAuthPolicyIDDefault is used. Lookup errors, including a bound policy ID
// that no longer exists, are returned unchanged.
func (s *Store) ResolveTLSAuthPolicy(listenerKind string, listenerID string, service string) (models.TLSAuthPolicy, error) {
	var bindings []models.TLSEndpointPolicy
	switch strings.TrimSpace(listenerKind) {
	case "extra":
		listener, err := s.GetTLSListener(strings.TrimSpace(listenerID))
		if err != nil {
			return models.TLSAuthPolicy{}, err
		}
		bindings = listener.EndpointPolicies
	default:
		settings, err := s.GetTLSSettings()
		if err != nil {
			return models.TLSAuthPolicy{}, err
		}
		bindings = settings.EndpointPolicies
	}

	policyID := tlsEndpointPolicyIDForService(bindings, service)
	if policyID == "" {
		policyID = TLSAuthPolicyIDDefault
	}
	return s.GetTLSAuthPolicy(policyID)
}
func tlsEndpointPolicyIDForService(items []models.TLSEndpointPolicy, service string) string {
var i int
var normalizedService string
normalizedService = strings.ToLower(strings.TrimSpace(service))
for i = 0; i < len(items); i++ {
if strings.ToLower(strings.TrimSpace(items[i].Service)) != normalizedService {
continue
}
return strings.TrimSpace(items[i].PolicyID)
}
return ""
}