package db import "database/sql" import "encoding/json" import "strings" import "time" import "codit/internal/models" import "codit/internal/util" const TLSAuthPolicyIDDefault string = "tls-auth-default" const tlsAuthPolicyIDReadOpenWriteCert string = "tls-auth-read-open-write-cert" const tlsAuthPolicyIDReadOpenWriteCertOrAuth string = "tls-auth-read-open-write-cert-or-auth" const tlsAuthPolicyIDCertOnly string = "tls-auth-cert-only" const tlsAuthPolicyIDReadOnlyPublic string = "tls-auth-read-only-public" const tlsAuthPolicyIDDenyAccess string = "tls-auth-deny-all" func tlsBuiltinAuthPolicies() []models.TLSAuthPolicy { var items []models.TLSAuthPolicy items = append(items, models.TLSAuthPolicy{ ID: TLSAuthPolicyIDDefault, Name: "Default", Description: "Require normal authentication for read and write operations.", ReadMode: "auth", WriteMode: "auth", RequirePrincipalForWrite: true, PermissionMatch: "all", }) items = append(items, models.TLSAuthPolicy{ ID: tlsAuthPolicyIDReadOpenWriteCert, Name: "Read Open / Write Cert", Description: "Allow reads without auth and require a client certificate for writes.", ReadMode: "public", WriteMode: "cert", RequirePrincipalForWrite: true, PermissionMatch: "all", }) items = append(items, models.TLSAuthPolicy{ ID: tlsAuthPolicyIDReadOpenWriteCertOrAuth, Name: "Read Open / Write Cert Or Auth", Description: "Allow reads without auth and allow writes by client certificate or normal auth.", ReadMode: "public", WriteMode: "cert_or_auth", RequirePrincipalForWrite: true, PermissionMatch: "all", }) items = append(items, models.TLSAuthPolicy{ ID: tlsAuthPolicyIDCertOnly, Name: "Cert Only", Description: "Require a client certificate for reads and writes.", ReadMode: "cert", WriteMode: "cert", RequirePrincipalForWrite: true, PermissionMatch: "all", }) items = append(items, models.TLSAuthPolicy{ ID: tlsAuthPolicyIDReadOnlyPublic, Name: "Read Only Public", Description: "Allow reads without auth and deny writes.", ReadMode: "public", WriteMode: "deny", RequirePrincipalForWrite: true, PermissionMatch: "all", }) items = append(items, models.TLSAuthPolicy{ ID: tlsAuthPolicyIDDenyAccess, Name: "Deny Access", Description: "Deny reads and deny writes.", ReadMode: "deny", WriteMode: "deny", RequirePrincipalForWrite: true, PermissionMatch: "all", }) return items } func tlsEndpointServices() []string { var items []string items = append(items, "api") items = append(items, "git") items = append(items, "rpm") items = append(items, "v2") return items } func DefaultTLSEndpointPolicies() []models.TLSEndpointPolicy { var services []string var items []models.TLSEndpointPolicy var i int services = tlsEndpointServices() for i = 0; i < len(services); i++ { items = append(items, models.TLSEndpointPolicy{ Service: services[i], PolicyID: TLSAuthPolicyIDDefault, }) } return items } func normalizeTLSEndpointPolicies(items []models.TLSEndpointPolicy) []models.TLSEndpointPolicy { var out []models.TLSEndpointPolicy var current map[string]string var services []string var i int var service string var policyID string current = make(map[string]string) for i = 0; i < len(items); i++ { service = strings.ToLower(strings.TrimSpace(items[i].Service)) if service == "" { continue } policyID = strings.TrimSpace(items[i].PolicyID) if policyID == "" { policyID = TLSAuthPolicyIDDefault } current[service] = policyID } services = tlsEndpointServices() for i = 0; i < len(services); i++ { policyID = current[services[i]] if policyID == "" { policyID = TLSAuthPolicyIDDefault } out = append(out, models.TLSEndpointPolicy{ Service: services[i], PolicyID: policyID, }) } return out } func encodeTLSEndpointPolicies(items []models.TLSEndpointPolicy) (string, error) { var normalized []models.TLSEndpointPolicy var data []byte var err error normalized = normalizeTLSEndpointPolicies(items) data, err = json.Marshal(normalized) if err != nil { return "", err } return string(data), nil } func decodeTLSEndpointPolicies(value string) []models.TLSEndpointPolicy { var items []models.TLSEndpointPolicy var err error value = strings.TrimSpace(value) if value == "" { return nil } err = json.Unmarshal([]byte(value), &items) if err != nil { return nil } return normalizeTLSEndpointPolicies(items) } func (s *Store) ListTLSAuthPolicies() ([]models.TLSAuthPolicy, error) { var rows *sql.Rows var err error var items []models.TLSAuthPolicy var item models.TLSAuthPolicy var allowedPKICertIDs string var allowedFingerprints string var requiredPermissions string rows, err = s.Query(`SELECT public_id, name, description, read_mode, write_mode, require_principal_for_write, allowed_pki_client_cert_ids, allowed_cert_fingerprints, required_permissions, permission_match, required_scope, created_at, updated_at FROM tls_auth_policies ORDER BY name`) if err != nil { return nil, err } defer rows.Close() for rows.Next() { err = rows.Scan(&item.ID, &item.Name, &item.Description, &item.ReadMode, &item.WriteMode, &item.RequirePrincipalForWrite, &allowedPKICertIDs, &allowedFingerprints, &requiredPermissions, &item.PermissionMatch, &item.RequiredScope, &item.CreatedAt, &item.UpdatedAt) if err != nil { return nil, err } item.AllowedPKIClientCertIDs = splitCSVValue(allowedPKICertIDs) item.AllowedCertFingerprints = splitCSVValue(allowedFingerprints) item.RequiredPermissions = splitCSVValue(requiredPermissions) items = append(items, item) } err = rows.Err() if err != nil { return nil, err } return items, nil } func (s *Store) GetTLSAuthPolicy(id string) (models.TLSAuthPolicy, error) { var row *sql.Row var item models.TLSAuthPolicy var err error var allowedPKICertIDs string var allowedFingerprints string var requiredPermissions string row = s.QueryRow(`SELECT public_id, name, description, read_mode, write_mode, require_principal_for_write, allowed_pki_client_cert_ids, allowed_cert_fingerprints, required_permissions, permission_match, required_scope, created_at, updated_at FROM tls_auth_policies WHERE public_id = ?`, id) err = row.Scan(&item.ID, &item.Name, &item.Description, &item.ReadMode, &item.WriteMode, &item.RequirePrincipalForWrite, &allowedPKICertIDs, &allowedFingerprints, &requiredPermissions, &item.PermissionMatch, &item.RequiredScope, &item.CreatedAt, &item.UpdatedAt) if err != nil { return item, err } item.AllowedPKIClientCertIDs = splitCSVValue(allowedPKICertIDs) item.AllowedCertFingerprints = splitCSVValue(allowedFingerprints) item.RequiredPermissions = splitCSVValue(requiredPermissions) return item, nil } func (s *Store) CreateTLSAuthPolicy(item models.TLSAuthPolicy) (models.TLSAuthPolicy, error) { var id string var now int64 var err error if item.ID == "" { id, err = util.NewID() if err != nil { return item, err } item.ID = id } now = time.Now().UTC().Unix() item.CreatedAt = now item.UpdatedAt = now _, err = s.Exec(`INSERT INTO tls_auth_policies (public_id, name, description, read_mode, write_mode, require_principal_for_write, allowed_pki_client_cert_ids, allowed_cert_fingerprints, required_permissions, permission_match, required_scope, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, item.ID, item.Name, item.Description, item.ReadMode, item.WriteMode, item.RequirePrincipalForWrite, strings.Join(item.AllowedPKIClientCertIDs, ","), strings.Join(item.AllowedCertFingerprints, ","), strings.Join(item.RequiredPermissions, ","), item.PermissionMatch, item.RequiredScope, item.CreatedAt, item.UpdatedAt) if err != nil { return item, err } return item, nil } func (s *Store) UpdateTLSAuthPolicy(item models.TLSAuthPolicy) error { var now int64 var err error now = time.Now().UTC().Unix() item.UpdatedAt = now _, err = s.Exec(`UPDATE tls_auth_policies SET name = ?, description = ?, read_mode = ?, write_mode = ?, require_principal_for_write = ?, allowed_pki_client_cert_ids = ?, allowed_cert_fingerprints = ?, required_permissions = ?, permission_match = ?, required_scope = ?, updated_at = ? WHERE public_id = ?`, item.Name, item.Description, item.ReadMode, item.WriteMode, item.RequirePrincipalForWrite, strings.Join(item.AllowedPKIClientCertIDs, ","), strings.Join(item.AllowedCertFingerprints, ","), strings.Join(item.RequiredPermissions, ","), item.PermissionMatch, item.RequiredScope, item.UpdatedAt, item.ID) return err } func (s *Store) DeleteTLSAuthPolicy(id string) error { var err error _, err = s.Exec(`DELETE FROM tls_auth_policies WHERE public_id = ?`, id) return err } func (s *Store) CountTLSAuthPolicyUsages(id string) (int, error) { var count int var settings models.TLSSettings var listeners []models.TLSListener var i int var j int var err error settings, err = s.GetTLSSettings() if err != nil { return 0, err } for i = 0; i < len(settings.EndpointPolicies); i++ { if settings.EndpointPolicies[i].PolicyID == id { count++ } } listeners, err = s.ListTLSListeners() if err != nil { return 0, err } for i = 0; i < len(listeners); i++ { for j = 0; j < len(listeners[i].EndpointPolicies); j++ { if listeners[i].EndpointPolicies[j].PolicyID == id { count++ } } } return count, nil } func (s *Store) EnsureDefaultTLSAuthPolicies() error { var items []models.TLSAuthPolicy var i int var item models.TLSAuthPolicy var existing models.TLSAuthPolicy var err error items = tlsBuiltinAuthPolicies() for i = 0; i < len(items); i++ { item = items[i] existing, err = s.GetTLSAuthPolicy(item.ID) if err == nil { if existing.Name == item.Name { continue } continue } if err != sql.ErrNoRows { return err } _, err = s.CreateTLSAuthPolicy(item) if err != nil { return err } } return nil } func (s *Store) EnsureTLSAuthPolicyBindings() error { var settings models.TLSSettings var listeners []models.TLSListener var i int var err error settings, err = s.GetTLSSettings() if err != nil { return err } if len(settings.EndpointPolicies) == 0 { settings.EndpointPolicies = DefaultTLSEndpointPolicies() err = s.SetTLSSettings(settings) if err != nil { return err } } listeners, err = s.ListTLSListeners() if err != nil { return err } for i = 0; i < len(listeners); i++ { if len(listeners[i].EndpointPolicies) == 0 { listeners[i].EndpointPolicies = DefaultTLSEndpointPolicies() err = s.UpdateTLSListener(listeners[i]) if err != nil { return err } } } return nil } func (s *Store) ResolveTLSAuthPolicy(listenerKind string, listenerID string, service string) (models.TLSAuthPolicy, error) { var settings models.TLSSettings var listener models.TLSListener var policyID string var err error if strings.TrimSpace(listenerKind) == "extra" { listener, err = s.GetTLSListener(strings.TrimSpace(listenerID)) if err != nil { return models.TLSAuthPolicy{}, err } policyID = tlsEndpointPolicyIDForService(listener.EndpointPolicies, service) } else { settings, err = s.GetTLSSettings() if err != nil { return models.TLSAuthPolicy{}, err } policyID = tlsEndpointPolicyIDForService(settings.EndpointPolicies, service) } if policyID == "" { policyID = TLSAuthPolicyIDDefault } return s.GetTLSAuthPolicy(policyID) } func tlsEndpointPolicyIDForService(items []models.TLSEndpointPolicy, service string) string { var i int var normalizedService string normalizedService = strings.ToLower(strings.TrimSpace(service)) for i = 0; i < len(items); i++ { if strings.ToLower(strings.TrimSpace(items[i].Service)) != normalizedService { continue } return strings.TrimSpace(items[i].PolicyID) } return "" }