387 lines
12 KiB
Go
387 lines
12 KiB
Go
package db
|
|
|
|
import "database/sql"
|
|
import "encoding/json"
|
|
import "strings"
|
|
import "time"
|
|
|
|
import "codit/internal/models"
|
|
import "codit/internal/util"
|
|
|
|
// Public IDs of the built-in TLS auth policies returned by
// tlsBuiltinAuthPolicies. EnsureDefaultTLSAuthPolicies looks rows up by these
// values, so changing one would seed a new row instead of finding the existing
// one. Note that tlsAuthPolicyIDDenyAccess is stored as "tls-auth-deny-all".
const (
	// TLSAuthPolicyIDDefault requires normal authentication for reads and
	// writes. It is also the fallback policy whenever an endpoint has no
	// explicit (or only a blank) policy binding.
	TLSAuthPolicyIDDefault string = "tls-auth-default"

	tlsAuthPolicyIDReadOpenWriteCert       string = "tls-auth-read-open-write-cert"
	tlsAuthPolicyIDReadOpenWriteCertOrAuth string = "tls-auth-read-open-write-cert-or-auth"
	tlsAuthPolicyIDCertOnly                string = "tls-auth-cert-only"
	tlsAuthPolicyIDReadOnlyPublic          string = "tls-auth-read-only-public"
	tlsAuthPolicyIDDenyAccess              string = "tls-auth-deny-all"
)
|
|
|
|
func tlsBuiltinAuthPolicies() []models.TLSAuthPolicy {
|
|
var items []models.TLSAuthPolicy
|
|
|
|
items = append(items, models.TLSAuthPolicy{
|
|
ID: TLSAuthPolicyIDDefault,
|
|
Name: "Default",
|
|
Description: "Require normal authentication for read and write operations.",
|
|
ReadMode: "auth",
|
|
WriteMode: "auth",
|
|
RequirePrincipalForWrite: true,
|
|
PermissionMatch: "all",
|
|
})
|
|
items = append(items, models.TLSAuthPolicy{
|
|
ID: tlsAuthPolicyIDReadOpenWriteCert,
|
|
Name: "Read Open / Write Cert",
|
|
Description: "Allow reads without auth and require a client certificate for writes.",
|
|
ReadMode: "public",
|
|
WriteMode: "cert",
|
|
RequirePrincipalForWrite: true,
|
|
PermissionMatch: "all",
|
|
})
|
|
items = append(items, models.TLSAuthPolicy{
|
|
ID: tlsAuthPolicyIDReadOpenWriteCertOrAuth,
|
|
Name: "Read Open / Write Cert Or Auth",
|
|
Description: "Allow reads without auth and allow writes by client certificate or normal auth.",
|
|
ReadMode: "public",
|
|
WriteMode: "cert_or_auth",
|
|
RequirePrincipalForWrite: true,
|
|
PermissionMatch: "all",
|
|
})
|
|
items = append(items, models.TLSAuthPolicy{
|
|
ID: tlsAuthPolicyIDCertOnly,
|
|
Name: "Cert Only",
|
|
Description: "Require a client certificate for reads and writes.",
|
|
ReadMode: "cert",
|
|
WriteMode: "cert",
|
|
RequirePrincipalForWrite: true,
|
|
PermissionMatch: "all",
|
|
})
|
|
items = append(items, models.TLSAuthPolicy{
|
|
ID: tlsAuthPolicyIDReadOnlyPublic,
|
|
Name: "Read Only Public",
|
|
Description: "Allow reads without auth and deny writes.",
|
|
ReadMode: "public",
|
|
WriteMode: "deny",
|
|
RequirePrincipalForWrite: true,
|
|
PermissionMatch: "all",
|
|
})
|
|
items = append(items, models.TLSAuthPolicy{
|
|
ID: tlsAuthPolicyIDDenyAccess,
|
|
Name: "Deny Access",
|
|
Description: "Deny reads and deny writes.",
|
|
ReadMode: "deny",
|
|
WriteMode: "deny",
|
|
RequirePrincipalForWrite: true,
|
|
PermissionMatch: "all",
|
|
})
|
|
return items
|
|
}
|
|
|
|
// tlsEndpointServices lists, in canonical order, the services that can carry
// their own TLS auth policy binding. A fresh slice is returned on each call,
// so callers may mutate it without affecting later calls.
func tlsEndpointServices() []string {
	return []string{"api", "git", "rpm", "v2"}
}
|
|
|
|
func DefaultTLSEndpointPolicies() []models.TLSEndpointPolicy {
|
|
var services []string
|
|
var items []models.TLSEndpointPolicy
|
|
var i int
|
|
|
|
services = tlsEndpointServices()
|
|
for i = 0; i < len(services); i++ {
|
|
items = append(items, models.TLSEndpointPolicy{
|
|
Service: services[i],
|
|
PolicyID: TLSAuthPolicyIDDefault,
|
|
})
|
|
}
|
|
return items
|
|
}
|
|
|
|
func normalizeTLSEndpointPolicies(items []models.TLSEndpointPolicy) []models.TLSEndpointPolicy {
|
|
var out []models.TLSEndpointPolicy
|
|
var current map[string]string
|
|
var services []string
|
|
var i int
|
|
var service string
|
|
var policyID string
|
|
|
|
current = make(map[string]string)
|
|
for i = 0; i < len(items); i++ {
|
|
service = strings.ToLower(strings.TrimSpace(items[i].Service))
|
|
if service == "" {
|
|
continue
|
|
}
|
|
policyID = strings.TrimSpace(items[i].PolicyID)
|
|
if policyID == "" {
|
|
policyID = TLSAuthPolicyIDDefault
|
|
}
|
|
current[service] = policyID
|
|
}
|
|
services = tlsEndpointServices()
|
|
for i = 0; i < len(services); i++ {
|
|
policyID = current[services[i]]
|
|
if policyID == "" {
|
|
policyID = TLSAuthPolicyIDDefault
|
|
}
|
|
out = append(out, models.TLSEndpointPolicy{
|
|
Service: services[i],
|
|
PolicyID: policyID,
|
|
})
|
|
}
|
|
return out
|
|
}
|
|
|
|
func encodeTLSEndpointPolicies(items []models.TLSEndpointPolicy) (string, error) {
|
|
var normalized []models.TLSEndpointPolicy
|
|
var data []byte
|
|
var err error
|
|
|
|
normalized = normalizeTLSEndpointPolicies(items)
|
|
data, err = json.Marshal(normalized)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return string(data), nil
|
|
}
|
|
|
|
func decodeTLSEndpointPolicies(value string) []models.TLSEndpointPolicy {
|
|
var items []models.TLSEndpointPolicy
|
|
var err error
|
|
|
|
value = strings.TrimSpace(value)
|
|
if value == "" {
|
|
return nil
|
|
}
|
|
err = json.Unmarshal([]byte(value), &items)
|
|
if err != nil {
|
|
return nil
|
|
}
|
|
return normalizeTLSEndpointPolicies(items)
|
|
}
|
|
|
|
func (s *Store) ListTLSAuthPolicies() ([]models.TLSAuthPolicy, error) {
|
|
var rows *sql.Rows
|
|
var err error
|
|
var items []models.TLSAuthPolicy
|
|
var item models.TLSAuthPolicy
|
|
var allowedPKICertIDs string
|
|
var allowedFingerprints string
|
|
var requiredPermissions string
|
|
|
|
rows, err = s.Query(`SELECT public_id, name, description, read_mode, write_mode, require_principal_for_write, allowed_pki_client_cert_ids, allowed_cert_fingerprints, required_permissions, permission_match, required_scope, created_at, updated_at FROM tls_auth_policies ORDER BY name`)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
for rows.Next() {
|
|
err = rows.Scan(&item.ID, &item.Name, &item.Description, &item.ReadMode, &item.WriteMode, &item.RequirePrincipalForWrite, &allowedPKICertIDs, &allowedFingerprints, &requiredPermissions, &item.PermissionMatch, &item.RequiredScope, &item.CreatedAt, &item.UpdatedAt)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
item.AllowedPKIClientCertIDs = splitCSVValue(allowedPKICertIDs)
|
|
item.AllowedCertFingerprints = splitCSVValue(allowedFingerprints)
|
|
item.RequiredPermissions = splitCSVValue(requiredPermissions)
|
|
items = append(items, item)
|
|
}
|
|
err = rows.Err()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return items, nil
|
|
}
|
|
|
|
func (s *Store) GetTLSAuthPolicy(id string) (models.TLSAuthPolicy, error) {
|
|
var row *sql.Row
|
|
var item models.TLSAuthPolicy
|
|
var err error
|
|
var allowedPKICertIDs string
|
|
var allowedFingerprints string
|
|
var requiredPermissions string
|
|
|
|
row = s.QueryRow(`SELECT public_id, name, description, read_mode, write_mode, require_principal_for_write, allowed_pki_client_cert_ids, allowed_cert_fingerprints, required_permissions, permission_match, required_scope, created_at, updated_at FROM tls_auth_policies WHERE public_id = ?`, id)
|
|
err = row.Scan(&item.ID, &item.Name, &item.Description, &item.ReadMode, &item.WriteMode, &item.RequirePrincipalForWrite, &allowedPKICertIDs, &allowedFingerprints, &requiredPermissions, &item.PermissionMatch, &item.RequiredScope, &item.CreatedAt, &item.UpdatedAt)
|
|
if err != nil {
|
|
return item, err
|
|
}
|
|
item.AllowedPKIClientCertIDs = splitCSVValue(allowedPKICertIDs)
|
|
item.AllowedCertFingerprints = splitCSVValue(allowedFingerprints)
|
|
item.RequiredPermissions = splitCSVValue(requiredPermissions)
|
|
return item, nil
|
|
}
|
|
|
|
func (s *Store) CreateTLSAuthPolicy(item models.TLSAuthPolicy) (models.TLSAuthPolicy, error) {
|
|
var id string
|
|
var now int64
|
|
var err error
|
|
|
|
if item.ID == "" {
|
|
id, err = util.NewID()
|
|
if err != nil {
|
|
return item, err
|
|
}
|
|
item.ID = id
|
|
}
|
|
now = time.Now().UTC().Unix()
|
|
item.CreatedAt = now
|
|
item.UpdatedAt = now
|
|
_, err = s.Exec(`INSERT INTO tls_auth_policies (public_id, name, description, read_mode, write_mode, require_principal_for_write, allowed_pki_client_cert_ids, allowed_cert_fingerprints, required_permissions, permission_match, required_scope, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
|
item.ID, item.Name, item.Description, item.ReadMode, item.WriteMode, item.RequirePrincipalForWrite, strings.Join(item.AllowedPKIClientCertIDs, ","), strings.Join(item.AllowedCertFingerprints, ","), strings.Join(item.RequiredPermissions, ","), item.PermissionMatch, item.RequiredScope, item.CreatedAt, item.UpdatedAt)
|
|
if err != nil {
|
|
return item, err
|
|
}
|
|
return item, nil
|
|
}
|
|
|
|
func (s *Store) UpdateTLSAuthPolicy(item models.TLSAuthPolicy) error {
|
|
var now int64
|
|
var err error
|
|
|
|
now = time.Now().UTC().Unix()
|
|
item.UpdatedAt = now
|
|
_, err = s.Exec(`UPDATE tls_auth_policies SET name = ?, description = ?, read_mode = ?, write_mode = ?, require_principal_for_write = ?, allowed_pki_client_cert_ids = ?, allowed_cert_fingerprints = ?, required_permissions = ?, permission_match = ?, required_scope = ?, updated_at = ? WHERE public_id = ?`,
|
|
item.Name, item.Description, item.ReadMode, item.WriteMode, item.RequirePrincipalForWrite, strings.Join(item.AllowedPKIClientCertIDs, ","), strings.Join(item.AllowedCertFingerprints, ","), strings.Join(item.RequiredPermissions, ","), item.PermissionMatch, item.RequiredScope, item.UpdatedAt, item.ID)
|
|
return err
|
|
}
|
|
|
|
func (s *Store) DeleteTLSAuthPolicy(id string) error {
|
|
var err error
|
|
|
|
_, err = s.Exec(`DELETE FROM tls_auth_policies WHERE public_id = ?`, id)
|
|
return err
|
|
}
|
|
|
|
func (s *Store) CountTLSAuthPolicyUsages(id string) (int, error) {
|
|
var count int
|
|
var settings models.TLSSettings
|
|
var listeners []models.TLSListener
|
|
var i int
|
|
var j int
|
|
var err error
|
|
|
|
settings, err = s.GetTLSSettings()
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
for i = 0; i < len(settings.EndpointPolicies); i++ {
|
|
if settings.EndpointPolicies[i].PolicyID == id {
|
|
count++
|
|
}
|
|
}
|
|
listeners, err = s.ListTLSListeners()
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
for i = 0; i < len(listeners); i++ {
|
|
for j = 0; j < len(listeners[i].EndpointPolicies); j++ {
|
|
if listeners[i].EndpointPolicies[j].PolicyID == id {
|
|
count++
|
|
}
|
|
}
|
|
}
|
|
return count, nil
|
|
}
|
|
|
|
func (s *Store) EnsureDefaultTLSAuthPolicies() error {
|
|
var items []models.TLSAuthPolicy
|
|
var i int
|
|
var item models.TLSAuthPolicy
|
|
var existing models.TLSAuthPolicy
|
|
var err error
|
|
|
|
items = tlsBuiltinAuthPolicies()
|
|
for i = 0; i < len(items); i++ {
|
|
item = items[i]
|
|
existing, err = s.GetTLSAuthPolicy(item.ID)
|
|
if err == nil {
|
|
if existing.Name == item.Name {
|
|
continue
|
|
}
|
|
continue
|
|
}
|
|
if err != sql.ErrNoRows {
|
|
return err
|
|
}
|
|
_, err = s.CreateTLSAuthPolicy(item)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *Store) EnsureTLSAuthPolicyBindings() error {
|
|
var settings models.TLSSettings
|
|
var listeners []models.TLSListener
|
|
var i int
|
|
var err error
|
|
|
|
settings, err = s.GetTLSSettings()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if len(settings.EndpointPolicies) == 0 {
|
|
settings.EndpointPolicies = DefaultTLSEndpointPolicies()
|
|
err = s.SetTLSSettings(settings)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
listeners, err = s.ListTLSListeners()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
for i = 0; i < len(listeners); i++ {
|
|
if len(listeners[i].EndpointPolicies) == 0 {
|
|
listeners[i].EndpointPolicies = DefaultTLSEndpointPolicies()
|
|
err = s.UpdateTLSListener(listeners[i])
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *Store) ResolveTLSAuthPolicy(listenerKind string, listenerID string, service string) (models.TLSAuthPolicy, error) {
|
|
var settings models.TLSSettings
|
|
var listener models.TLSListener
|
|
var policyID string
|
|
var err error
|
|
|
|
if strings.TrimSpace(listenerKind) == "extra" {
|
|
listener, err = s.GetTLSListener(strings.TrimSpace(listenerID))
|
|
if err != nil {
|
|
return models.TLSAuthPolicy{}, err
|
|
}
|
|
policyID = tlsEndpointPolicyIDForService(listener.EndpointPolicies, service)
|
|
} else {
|
|
settings, err = s.GetTLSSettings()
|
|
if err != nil {
|
|
return models.TLSAuthPolicy{}, err
|
|
}
|
|
policyID = tlsEndpointPolicyIDForService(settings.EndpointPolicies, service)
|
|
}
|
|
if policyID == "" {
|
|
policyID = TLSAuthPolicyIDDefault
|
|
}
|
|
return s.GetTLSAuthPolicy(policyID)
|
|
}
|
|
|
|
func tlsEndpointPolicyIDForService(items []models.TLSEndpointPolicy, service string) string {
|
|
var i int
|
|
var normalizedService string
|
|
|
|
normalizedService = strings.ToLower(strings.TrimSpace(service))
|
|
for i = 0; i < len(items); i++ {
|
|
if strings.ToLower(strings.TrimSpace(items[i].Service)) != normalizedService {
|
|
continue
|
|
}
|
|
return strings.TrimSpace(items[i].PolicyID)
|
|
}
|
|
return ""
|
|
}
|