Files

1645 lines
50 KiB
Go

package main
import "flag"
//import "io"
import "context"
import "crypto/sha256"
import "crypto/tls"
import "crypto/x509"
import "encoding/hex"
import "errors"
import "fmt"
import "log"
import "net"
import "net/http"
import "os"
import "path/filepath"
import "sync"
import "strings"
import "time"
import "codit/internal/auth"
import "codit/internal/config"
import "codit/internal/db"
import "codit/internal/docker"
import "codit/internal/git"
import "codit/internal/handlers"
import httpx "codit/internal/http"
import "codit/internal/middleware"
import "codit/internal/models"
import "codit/internal/rpm"
import "codit/internal/storage"
import "codit/internal/util"
import _ "modernc.org/sqlite"
// gitPathRewriteHandler translates slug/name based git URLs into the
// storage-ID based directory layout expected by next; store is used to
// resolve project slugs and repository names to those IDs.
type gitPathRewriteHandler struct {
	next  http.Handler
	store *db.Store
}
// server_http_log_writer adapts a *util.Logger to io.Writer so it can back
// the *log.Logger that http.Server requires for its ErrorLog field.
type server_http_log_writer struct {
	l     *util.Logger // destination logger; a nil logger discards output
	id    string       // log id passed through to the logger
	depth int          // call-depth adjustment for this extra wrapper layer
}

// Write forwards one log.Logger record to hlw.l at LOG_INFO level. It always
// reports the whole input as consumed so log.Logger never sees a short write.
func (hlw *server_http_log_writer) Write(p []byte) (n int, err error) {
	// the standard http.Server always requires *log.Logger
	// use this iowriter to create a logger to pass it to the http server.
	// since this is another log write wrapper, give adjustment value
	if hlw == nil || hlw.l == nil {
		return len(p), nil
	}
	// The record is data, not a format string: http.Server error text can
	// contain request paths with '%' escapes that would otherwise be parsed
	// as verbs. log.Logger always appends '\n'; other call sites pass
	// messages without one, so presumably the util logger terminates records
	// itself — strip it to avoid blank lines.
	hlw.l.WriteWithCallDepth(hlw.id, util.LOG_INFO, hlw.depth, "%s", strings.TrimSuffix(string(p), "\n"))
	return len(p), nil
}
// ServeHTTP rewrites "/<project-slug>/<repo>[.git][/rest]" into
// "/<project-storage-id>/<repo-storage-id>.git[/rest]" and delegates to
// h.next. Empty or single-segment paths are passed through unchanged; an
// unknown project, repo or storage mapping answers 404.
func (h *gitPathRewriteHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	var path string
	var parts []string
	var slug string
	var repoSegment string
	var rest string
	var repoName string
	var project models.Project
	var repo models.Repo
	var projectStorageID int64
	var repoStorageID int64
	var err error
	var newPath string
	var rewritten http.Request
	path = strings.TrimPrefix(r.URL.Path, "/")
	if path == "" {
		h.next.ServeHTTP(w, r)
		return
	}
	parts = strings.SplitN(path, "/", 3)
	if len(parts) < 2 {
		h.next.ServeHTTP(w, r)
		return
	}
	slug = parts[0]
	repoSegment = parts[1]
	if len(parts) == 3 {
		rest = "/" + parts[2]
	} else {
		rest = ""
	}
	// The ".git" suffix on the repo segment is optional.
	repoName = strings.TrimSuffix(repoSegment, ".git")
	project, err = h.store.GetProjectBySlug(slug)
	if err != nil {
		http.NotFound(w, r)
		return
	}
	repo, err = h.store.GetRepoByProjectNameType(project.ID, repoName, "git")
	if err != nil {
		http.NotFound(w, r)
		return
	}
	projectStorageID, repoStorageID, err = h.store.GetRepoStorageIDs(repo.ID)
	if err != nil {
		http.NotFound(w, r)
		return
	}
	newPath = "/" + storageIDSegment(projectStorageID) + "/" + storageIDSegment(repoStorageID) + ".git" + rest
	// Handlers must not modify the caller's request, so rewrite a shallow
	// copy (as http.StripPrefix does). RawPath is cleared because it still
	// encodes the original slug path, not newPath.
	var rewrittenURL = *r.URL
	rewrittenURL.Path = newPath
	rewrittenURL.RawPath = ""
	rewritten = *r
	rewritten.URL = &rewrittenURL
	h.next.ServeHTTP(w, &rewritten)
}
// gitIDPathRewriteHandler translates repo-ID based git URLs into the
// storage-ID based directory layout expected by next.
type gitIDPathRewriteHandler struct {
	next  http.Handler
	store *db.Store
}

// ServeHTTP rewrites "/<repo-id>[.git][/rest]" into
// "/<project-storage-id>/<repo-storage-id>.git[/rest]" and delegates to
// h.next. An empty path is passed through; an unknown id, a repo whose type
// is not "git", or a missing storage mapping answers 404.
func (h *gitIDPathRewriteHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	var path string
	var parts []string
	var rest string
	var repo models.Repo
	var projectStorageID int64
	var repoStorageID int64
	var err error
	var newPath string
	var repoID string
	var rewritten http.Request
	path = strings.TrimPrefix(r.URL.Path, "/")
	if path == "" {
		h.next.ServeHTTP(w, r)
		return
	}
	parts = strings.SplitN(path, "/", 3)
	repoID = strings.TrimSuffix(parts[0], ".git")
	repo, err = h.store.GetRepo(repoID)
	if err != nil || repo.Type != "git" {
		http.NotFound(w, r)
		return
	}
	projectStorageID, repoStorageID, err = h.store.GetRepoStorageIDs(repo.ID)
	if err != nil {
		http.NotFound(w, r)
		return
	}
	// Everything after the first segment is carried over verbatim.
	if len(parts) == 3 {
		rest = "/" + parts[1] + "/" + parts[2]
	} else if len(parts) == 2 {
		rest = "/" + parts[1]
	} else {
		rest = ""
	}
	newPath = "/" + storageIDSegment(projectStorageID) + "/" + storageIDSegment(repoStorageID) + ".git" + rest
	// Rewrite a shallow copy rather than the caller's request, and drop the
	// stale RawPath so the escaped form is derived from newPath.
	var rewrittenURL = *r.URL
	rewrittenURL.Path = newPath
	rewrittenURL.RawPath = ""
	rewritten = *r
	rewritten.URL = &rewrittenURL
	h.next.ServeHTTP(w, &rewritten)
}
// rpmPathRewriteHandler translates slug/name based RPM repository URLs into
// the storage-ID based directory layout expected by next.
type rpmPathRewriteHandler struct {
	next  http.Handler
	store *db.Store
}

// ServeHTTP rewrites "/<project-slug>/<repo>[/rest]" into
// "/<project-storage-id>/<repo-storage-id>[/rest]" and delegates to h.next.
// Empty or single-segment paths are passed through unchanged; an unknown
// project, rpm repo or storage mapping answers 404.
func (h *rpmPathRewriteHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	var path string
	var parts []string
	var slug string
	var repoName string
	var rest string
	var project models.Project
	var repo models.Repo
	var projectStorageID int64
	var repoStorageID int64
	var err error
	var newPath string
	var rewritten http.Request
	path = strings.TrimPrefix(r.URL.Path, "/")
	if path == "" {
		h.next.ServeHTTP(w, r)
		return
	}
	parts = strings.SplitN(path, "/", 3)
	if len(parts) < 2 {
		h.next.ServeHTTP(w, r)
		return
	}
	slug = parts[0]
	repoName = parts[1]
	if len(parts) == 3 {
		rest = "/" + parts[2]
	} else {
		rest = ""
	}
	project, err = h.store.GetProjectBySlug(slug)
	if err != nil {
		http.NotFound(w, r)
		return
	}
	repo, err = h.store.GetRepoByProjectNameType(project.ID, repoName, "rpm")
	if err != nil {
		http.NotFound(w, r)
		return
	}
	projectStorageID, repoStorageID, err = h.store.GetRepoStorageIDs(repo.ID)
	if err != nil {
		http.NotFound(w, r)
		return
	}
	newPath = "/" + storageIDSegment(projectStorageID) + "/" + storageIDSegment(repoStorageID) + rest
	// Rewrite a shallow copy rather than the caller's request, and drop the
	// stale RawPath so the escaped form is derived from newPath.
	var rewrittenURL = *r.URL
	rewrittenURL.Path = newPath
	rewrittenURL.RawPath = ""
	rewritten = *r
	rewritten.URL = &rewrittenURL
	h.next.ServeHTTP(w, &rewritten)
}
// rpmIDPathRewriteHandler translates repo-ID based RPM repository URLs into
// the storage-ID based directory layout expected by next.
type rpmIDPathRewriteHandler struct {
	next  http.Handler
	store *db.Store
}

// ServeHTTP rewrites "/<repo-id>[/rest]" into
// "/<project-storage-id>/<repo-storage-id>[/rest]" and delegates to h.next.
// An empty path is passed through; an unknown id, a repo whose type is not
// "rpm", or a missing storage mapping answers 404.
func (h *rpmIDPathRewriteHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	var path string
	var parts []string
	var rest string
	var repo models.Repo
	var projectStorageID int64
	var repoStorageID int64
	var err error
	var newPath string
	var repoID string
	var rewritten http.Request
	path = strings.TrimPrefix(r.URL.Path, "/")
	if path == "" {
		h.next.ServeHTTP(w, r)
		return
	}
	parts = strings.SplitN(path, "/", 3)
	repoID = parts[0]
	repo, err = h.store.GetRepo(repoID)
	if err != nil || repo.Type != "rpm" {
		http.NotFound(w, r)
		return
	}
	projectStorageID, repoStorageID, err = h.store.GetRepoStorageIDs(repo.ID)
	if err != nil {
		http.NotFound(w, r)
		return
	}
	// Everything after the first segment is carried over verbatim.
	if len(parts) == 3 {
		rest = "/" + parts[1] + "/" + parts[2]
	} else if len(parts) == 2 {
		rest = "/" + parts[1]
	} else {
		rest = ""
	}
	newPath = "/" + storageIDSegment(projectStorageID) + "/" + storageIDSegment(repoStorageID) + rest
	// Rewrite a shallow copy rather than the caller's request, and drop the
	// stale RawPath so the escaped form is derived from newPath.
	var rewrittenURL = *r.URL
	rewrittenURL.Path = newPath
	rewrittenURL.RawPath = ""
	rewritten = *r
	rewritten.URL = &rewrittenURL
	h.next.ServeHTTP(w, &rewritten)
}
// storageIDSegment renders a storage ID as the fixed-width (16 character,
// zero-padded, lowercase hexadecimal) directory name used on disk.
func storageIDSegment(id int64) string {
	const segmentWidth = 16
	return fmt.Sprintf("%0*x", segmentWidth, id)
}
func main() {
var configPath string
var cfg config.Config
var store *db.Store
var err error
var gitRewrite http.Handler
var gitIDRewrite http.Handler
var rpmRewrite http.Handler
var rpmIDRewrite http.Handler
var logger *util.Logger
var mainEndpoints []listenerEndpoint
var extraListenerManager *additionalListenerManager
flag.StringVar(&configPath, "config", "", "path to config json")
flag.Parse()
cfg, err = config.Load(configPath)
if err != nil {
log.Fatalf("config error: %v", err)
}
// log.SetOutput(io.Discard)
//if cfg.APP.LogFile == "" {
logger = util.NewLogger("codit", os.Stderr, util.LogStrsToMask(nil))
//} else {
// logger, err = util.NewLoggerToFile("codit", cfg.APP.LogFile, cfg.APP.LogMaxSize, cfg.APP.LogRotate, util.LogStrsToMask(cfg.APP.LogMask))
// if err != nil {
// log.Fatalf("failed to initialize logger - %s", err.Error())
// }
//}
err = os.MkdirAll(cfg.DataDir, 0o755)
if err != nil {
//log.Fatalf("data dir error: %v", err)
logger.Write("", util.LOG_ERROR, "data dir error: %v", err)
}
store, err = db.Open(cfg.DBDriver, cfg.DBDSN)
if err != nil {
log.Fatalf("db open error: %v", err)
}
defer store.Close()
err = store.ApplyMigrations(filepath.Join("migrations"))
if err != nil {
log.Fatalf("migrations error: %v", err)
}
err = mergeTLSSettingsFromDB(&cfg, store)
if err != nil {
log.Fatalf("tls settings load error: %v", err)
}
err = bootstrapAdmin(store)
if err != nil {
log.Fatalf("bootstrap admin error: %v", err)
}
var repoManager git.RepoManager
repoManager = git.RepoManager{BaseDir: filepath.Join(cfg.DataDir, "git")}
var rpmBase string
var dockerBase string
rpmBase = filepath.Join(cfg.DataDir, "rpm")
dockerBase = filepath.Join(cfg.DataDir, "docker")
var rpmMeta *rpm.MetaManager
var rpmMirror *rpm.MirrorManager
rpmMeta = rpm.NewMetaManager()
rpmMirror = rpm.NewMirrorManager(store, logger, rpmMeta)
var uploadStore storage.FileStore
uploadStore = storage.FileStore{BaseDir: filepath.Join(cfg.DataDir, "uploads")}
err = os.MkdirAll(repoManager.BaseDir, 0o755)
if err != nil {
log.Fatalf("git dir error: %v", err)
}
err = os.MkdirAll(rpmBase, 0o755)
if err != nil {
log.Fatalf("rpm dir error: %v", err)
}
err = os.MkdirAll(dockerBase, 0o755)
if err != nil {
log.Fatalf("docker dir error: %v", err)
}
var api *handlers.API
api = &handlers.API{
Cfg: cfg,
Store: store,
Repos: repoManager,
RpmBase: rpmBase,
RpmMeta: rpmMeta,
RpmMirror: rpmMirror,
DockerBase: dockerBase,
Uploads: uploadStore,
Logger: logger,
}
rpmMirror.Start()
var graphqlHandler http.Handler
graphqlHandler, err = handlers.NewGraphQL(store)
if err != nil {
log.Fatalf("graphql init error: %v", err)
}
var gitServer http.Handler
var authFunc func(username string, password string) (bool, error)
authFunc = func(username string, password string) (bool, error) {
var user models.User
var hash string
var err error
var key string
user, hash, err = store.GetUserByUsername(username)
if err != nil || hash == "" {
if password != "" {
key = password
user, err = store.GetUserByAPIKeyHash(util.HashToken(key))
if err == nil {
return true, nil
}
}
if password == "" && username != "" {
key = username
user, err = store.GetUserByAPIKeyHash(util.HashToken(key))
if err == nil {
return true, nil
}
}
return false, nil
}
err = auth.ComparePassword(hash, password)
if err != nil {
if password != "" {
key = password
user, err = store.GetUserByAPIKeyHash(util.HashToken(key))
if err == nil {
return true, nil
}
}
return false, nil
}
if user.Disabled {
return false, nil
}
return user.ID != "", nil
}
gitServer, err = git.NewHTTPServer(repoManager.BaseDir, nil, logger)
if err != nil {
log.Fatalf("git server error: %v", err)
}
var rpmServer http.Handler
rpmServer = rpm.NewHTTPServer(rpmBase, nil, logger)
var dockerServer http.Handler
dockerServer = docker.NewHTTPServer(store, dockerBase, nil, logger)
var router *httpx.Router
router = httpx.NewRouter()
router.Handle("GET", "/api/health", api.Health)
router.Handle("POST", "/api/login", api.Login)
router.Handle("POST", "/api/logout", api.Logout)
router.Handle("GET", "/api/auth/oidc/enabled", api.OIDCEnabled)
router.Handle("GET", "/api/auth/oidc/login", api.OIDCLogin)
router.Handle("GET", "/api/auth/oidc/callback", api.OIDCCallback)
router.Handle("GET", "/api/me", api.Me)
router.Handle("PATCH", "/api/me", api.UpdateMe)
router.Handle("GET", "/api/users", api.ListUsers)
router.Handle("POST", "/api/users", api.CreateUser)
router.Handle("PATCH", "/api/users/:id", api.UpdateUser)
router.Handle("DELETE", "/api/users/:id", api.DeleteUser)
router.Handle("POST", "/api/users/:id/disable", api.DisableUser)
router.Handle("POST", "/api/users/:id/enable", api.EnableUser)
router.Handle("GET", "/api/admin/auth", api.GetAuthSettings)
router.Handle("PATCH", "/api/admin/auth", api.UpdateAuthSettings)
router.Handle("POST", "/api/admin/auth/test", api.TestLDAPSettings)
router.Handle("GET", "/api/admin/tls", api.GetTLSSettings)
router.Handle("PATCH", "/api/admin/tls", api.UpdateTLSSettings)
router.Handle("GET", "/api/admin/tls/listeners", api.ListTLSListeners)
router.Handle("GET", "/api/admin/tls/listeners/runtime", api.GetTLSListenerRuntimeStatus)
router.Handle("POST", "/api/admin/tls/listeners", api.CreateTLSListener)
router.Handle("PATCH", "/api/admin/tls/listeners/:id", api.UpdateTLSListener)
router.Handle("DELETE", "/api/admin/tls/listeners/:id", api.DeleteTLSListener)
router.Handle("GET", "/api/admin/service-principals", api.ListServicePrincipals)
router.Handle("POST", "/api/admin/service-principals", api.CreateServicePrincipal)
router.Handle("PATCH", "/api/admin/service-principals/:id", api.UpdateServicePrincipal)
router.Handle("DELETE", "/api/admin/service-principals/:id", api.DeleteServicePrincipal)
router.Handle("GET", "/api/admin/service-principals/:id/roles", api.ListPrincipalProjectRoles)
router.Handle("POST", "/api/admin/service-principals/:id/roles", api.UpsertPrincipalProjectRole)
router.Handle("DELETE", "/api/admin/service-principals/:id/roles/:projectId", api.DeletePrincipalProjectRole)
router.Handle("GET", "/api/admin/cert-principal-bindings", api.ListCertPrincipalBindings)
router.Handle("POST", "/api/admin/cert-principal-bindings", api.UpsertCertPrincipalBinding)
router.Handle("DELETE", "/api/admin/cert-principal-bindings/:fingerprint", api.DeleteCertPrincipalBinding)
router.Handle("GET", "/api/admin/auth/ldap", api.GetAuthSettings)
router.Handle("PATCH", "/api/admin/auth/ldap", api.UpdateAuthSettings)
router.Handle("POST", "/api/admin/auth/ldap/test", api.TestLDAPSettings)
router.Handle("GET", "/api/admin/pki/cas", api.ListPKICAs)
router.Handle("GET", "/api/admin/pki/cas/:id", api.GetPKICA)
router.Handle("PATCH", "/api/admin/pki/cas/:id", api.UpdatePKICA)
router.Handle("GET", "/api/admin/pki/cas/:id/bundle", api.DownloadPKICABundle)
router.Handle("POST", "/api/admin/pki/cas/root", api.CreatePKIRootCA)
router.Handle("POST", "/api/admin/pki/cas/intermediate", api.CreatePKIIntermediateCA)
router.Handle("GET", "/api/admin/pki/cas/:id/crl", api.GetPKICRL)
router.Handle("DELETE", "/api/admin/pki/cas/:id", api.DeletePKICA)
router.Handle("GET", "/api/admin/pki/certs", api.ListPKICerts)
router.Handle("GET", "/api/admin/pki/certs/:id", api.GetPKICert)
router.Handle("GET", "/api/admin/pki/certs/:id/bundle", api.DownloadPKICertBundle)
router.Handle("POST", "/api/admin/pki/certs", api.IssuePKICert)
router.Handle("POST", "/api/admin/pki/certs/import", api.ImportPKICert)
router.Handle("POST", "/api/admin/pki/certs/:id/revoke", api.RevokePKICert)
router.Handle("DELETE", "/api/admin/pki/certs/:id", api.DeletePKICert)
router.Handle("GET", "/api/projects", api.ListProjects)
router.Handle("POST", "/api/projects", api.CreateProject)
router.Handle("GET", "/api/projects/:id", api.GetProject)
router.Handle("PATCH", "/api/projects/:id", api.UpdateProject)
router.Handle("DELETE", "/api/projects/:id", api.DeleteProject)
router.Handle("GET", "/api/projects/:projectId/members", api.ListProjectMembers)
router.Handle("GET", "/api/projects/:projectId/member-candidates", api.ListProjectMemberCandidates)
router.Handle("POST", "/api/projects/:projectId/members", api.AddProjectMember)
router.Handle("PATCH", "/api/projects/:projectId/members", api.UpdateProjectMember)
router.Handle("DELETE", "/api/projects/:projectId/members/:userId", api.RemoveProjectMember)
router.Handle("GET", "/api/projects/:projectId/repos", api.ListRepos)
router.Handle("POST", "/api/projects/:projectId/repos", api.CreateRepo)
router.Handle("GET", "/api/repos/types", api.RepoTypes)
router.Handle("GET", "/api/repos", api.ListAllRepos)
router.Handle("GET", "/api/projects/:projectId/foreign-repos/available", api.ListAvailableRepos)
router.Handle("POST", "/api/projects/:projectId/foreign-repos", api.AttachForeignRepo)
router.Handle("DELETE", "/api/projects/:projectId/foreign-repos/:repoId", api.DetachForeignRepo)
router.Handle("GET", "/api/repos/:id", api.GetRepo)
router.Handle("PATCH", "/api/repos/:id", api.UpdateRepo)
router.Handle("DELETE", "/api/repos/:id", api.DeleteRepo)
router.Handle("GET", "/api/repos/:id/branches", api.RepoBranches)
router.Handle("GET", "/api/repos/:id/branches/info", api.RepoBranchesInfo)
router.Handle("PUT", "/api/repos/:id/branches/default", api.RepoSetDefaultBranch)
router.Handle("POST", "/api/repos/:id/branches/rename", api.RepoRenameBranch)
router.Handle("POST", "/api/repos/:id/branches/delete", api.RepoDeleteBranch)
router.Handle("POST", "/api/repos/:id/branches/create", api.RepoCreateBranch)
router.Handle("GET", "/api/repos/:id/commits", api.RepoCommits)
router.Handle("GET", "/api/repos/:id/tree", api.RepoTree)
router.Handle("GET", "/api/repos/:id/blob", api.RepoBlob)
router.Handle("GET", "/api/repos/:id/blob/raw", api.RepoBlobRaw)
router.Handle("GET", "/api/repos/:id/history", api.RepoFileHistory)
router.Handle("GET", "/api/repos/:id/diff", api.RepoFileDiff)
router.Handle("GET", "/api/repos/:id/commit", api.RepoCommitDetail)
router.Handle("GET", "/api/repos/:id/commit/diff", api.RepoCommitDiff)
router.Handle("GET", "/api/repos/:id/compare", api.RepoCompare)
router.Handle("GET", "/api/repos/:id/stats", api.RepoStats)
router.Handle("GET", "/api/repos/:id/rpm/packages", api.RepoRPMPackages)
router.Handle("GET", "/api/repos/:id/rpm/package", api.RepoRPMPackage)
router.Handle("POST", "/api/repos/:id/rpm/subdirs", api.RepoRPMCreateSubdir)
router.Handle("GET", "/api/repos/:id/rpm/subdir", api.RepoRPMGetSubdir)
router.Handle("POST", "/api/repos/:id/rpm/subdir/update", api.RepoRPMRenameSubdir)
router.Handle("POST", "/api/repos/:id/rpm/subdir/rename", api.RepoRPMRenameSubdir)
router.Handle("POST", "/api/repos/:id/rpm/subdir/sync", api.RepoRPMSyncSubdir)
router.Handle("POST", "/api/repos/:id/rpm/subdir/suspend", api.RepoRPMSuspendSubdir)
router.Handle("POST", "/api/repos/:id/rpm/subdir/resume", api.RepoRPMResumeSubdir)
router.Handle("POST", "/api/repos/:id/rpm/subdir/rebuild-metadata", api.RepoRPMRebuildSubdirMetadata)
router.Handle("POST", "/api/repos/:id/rpm/subdir/cancel", api.RepoRPMCancelSubdirSync)
router.Handle("GET", "/api/repos/:id/rpm/subdir/runs", api.RepoRPMMirrorRuns)
router.Handle("DELETE", "/api/repos/:id/rpm/subdir/runs", api.RepoRPMClearMirrorRuns)
router.Handle("DELETE", "/api/repos/:id/rpm/subdir", api.RepoRPMDeleteSubdir)
router.Handle("DELETE", "/api/repos/:id/rpm/file", api.RepoRPMDeleteFile)
router.Handle("GET", "/api/repos/:id/rpm/file", api.RepoRPMFile)
router.Handle("GET", "/api/repos/:id/rpm/tree", api.RepoRPMTree)
router.Handle("POST", "/api/repos/:id/rpm/upload", api.RepoRPMUpload)
router.Handle("GET", "/api/repos/:id/docker/images", api.RepoDockerImages)
router.Handle("GET", "/api/repos/:id/docker/tags", api.RepoDockerTags)
router.Handle("GET", "/api/repos/:id/docker/manifest", api.RepoDockerManifest)
router.Handle("DELETE", "/api/repos/:id/docker/tag", api.RepoDockerDeleteTag)
router.Handle("DELETE", "/api/repos/:id/docker/image", api.RepoDockerDeleteImage)
router.Handle("POST", "/api/repos/:id/docker/tag/rename", api.RepoDockerRenameTag)
router.Handle("POST", "/api/repos/:id/docker/image/rename", api.RepoDockerRenameImage)
router.Handle("GET", "/api/me/keys", api.ListAPIKeys)
router.Handle("POST", "/api/me/keys", api.CreateAPIKey)
router.Handle("DELETE", "/api/me/keys/:id", api.DeleteAPIKey)
router.Handle("POST", "/api/me/keys/:id/disable", api.DisableAPIKey)
router.Handle("POST", "/api/me/keys/:id/enable", api.EnableAPIKey)
router.Handle("GET", "/api/admin/api-keys", api.ListAdminAPIKeys)
router.Handle("DELETE", "/api/admin/api-keys/:id", api.DeleteAdminAPIKey)
router.Handle("POST", "/api/admin/api-keys/:id/disable", api.DisableAdminAPIKey)
router.Handle("POST", "/api/admin/api-keys/:id/enable", api.EnableAdminAPIKey)
router.Handle("GET", "/api/projects/:projectId/issues", api.ListIssues)
router.Handle("POST", "/api/projects/:projectId/issues", api.CreateIssue)
router.Handle("PATCH", "/api/issues/:id", api.UpdateIssue)
router.Handle("POST", "/api/issues/:id/comments", api.AddIssueComment)
router.Handle("GET", "/api/projects/:projectId/wiki/pages", api.ListWikiPages)
router.Handle("POST", "/api/projects/:projectId/wiki/pages", api.CreateWikiPage)
router.Handle("PATCH", "/api/wiki/pages/:id", api.UpdateWikiPage)
router.Handle("POST", "/api/projects/:projectId/uploads", api.UploadFile)
router.Handle("GET", "/api/projects/:projectId/uploads", api.ListUploads)
router.Handle("GET", "/api/uploads/:id", api.DownloadFile)
var mux *http.ServeMux
mux = http.NewServeMux()
gitRewrite = &gitPathRewriteHandler{next: gitServer, store: store}
mux.Handle(cfg.GitHTTPPrefix+"/", withServiceAuth(serviceGit, http.StripPrefix(cfg.GitHTTPPrefix, gitRewrite), authFunc, store, logger))
gitIDRewrite = &gitIDPathRewriteHandler{next: gitServer, store: store}
mux.Handle(cfg.GitHTTPPrefix+"-id/", withServiceAuth(serviceGit, http.StripPrefix(cfg.GitHTTPPrefix+"-id", gitIDRewrite), authFunc, store, logger))
rpmRewrite = &rpmPathRewriteHandler{next: rpmServer, store: store}
mux.Handle(cfg.RPMHTTPPrefix+"/", withServiceAuth(serviceRPM, http.StripPrefix(cfg.RPMHTTPPrefix, rpmRewrite), authFunc, store, logger))
rpmIDRewrite = &rpmIDPathRewriteHandler{next: rpmServer, store: store}
mux.Handle(cfg.RPMHTTPPrefix+"-id/", withServiceAuth(serviceRPM, http.StripPrefix(cfg.RPMHTTPPrefix+"-id", rpmIDRewrite), authFunc, store, logger))
mux.Handle("/v2", withServiceAuth(serviceV2, dockerServer, authFunc, store, logger))
mux.Handle("/v2/", withServiceAuth(serviceV2, dockerServer, authFunc, store, logger))
mux.HandleFunc("/pki/crl/", api.ServePKICRL)
mux.Handle("/api/graphql", middleware.WithUser(store, middleware.RequireAuth(graphqlHandler)))
mux.Handle("/api/", middleware.WithUser(store, middleware.AccessLog(logger, withAPIAuth(router, authFunc, store, logger))))
mux.Handle("/api/login", middleware.WithUser(store, middleware.AccessLog(logger, router)))
mux.Handle("/api/logout", middleware.WithUser(store, middleware.AccessLog(logger, router)))
mux.Handle("/api/auth/oidc/enabled", middleware.WithUser(store, middleware.AccessLog(logger, router)))
mux.Handle("/api/auth/oidc/login", middleware.WithUser(store, middleware.AccessLog(logger, router)))
mux.Handle("/api/auth/oidc/callback", middleware.WithUser(store, middleware.AccessLog(logger, router)))
mux.Handle("/api/health", middleware.AccessLog(logger, router))
mux.Handle("/", middleware.WithUser(store, spaHandler(cfg.FrontendDir)))
extraListenerManager = newAdditionalListenerManager(store, mux, logger)
api.OnTLSListenersChanged = extraListenerManager.NotifyReload
api.OnTLSListenerRuntimeStatus = extraListenerManager.ListenerEndpointCounts
err = extraListenerManager.Start()
if err != nil {
logger.Write("", util.LOG_ERROR, "additional listener manager error: %v", err)
os.Exit(1)
}
mainEndpoints, err = buildListenerEndpoints("main", tlsSettingsFromConfig(cfg), defaultListenerPolicy(), store)
if err != nil {
log.Fatalf("main listener config error: %v", err)
}
err = serveListeners(mainEndpoints, mux, logger)
if err != nil {
log.Fatalf("server error: %v", err)
}
logger.Close()
}
// listenerEndpoint describes one concrete socket to serve: its display name,
// the unique key it is tracked under, the address to bind, whether it speaks
// TLS (and with which config), and the auth policy applied to its requests.
type listenerEndpoint struct {
	Name      string
	Key       string
	Addr      string
	IsHTTPS   bool
	TLSConfig *tls.Config
	Policy    listenerAuthPolicy
}
// Enumerations used by the listener auth policy machinery.
type (
	// ctxKey is a private context-key type so keys set here cannot collide
	// with keys from other packages.
	ctxKey string
	// serviceKind identifies which front door a request came through.
	serviceKind string
	// opKind classifies a request as a read or a write.
	opKind string
	// authOutcome is what the listener policy demands for a request.
	authOutcome string
)

// listenerPolicyKey carries the accepting listener's policy in the context.
const listenerPolicyKey ctxKey = "listener_policy"

const (
	serviceAPI serviceKind = "api"
	serviceGit serviceKind = "git"
	serviceRPM serviceKind = "rpm"
	serviceV2  serviceKind = "v2"
)

const (
	opRead  opKind = "read"
	opWrite opKind = "write"
)

const (
	authRequireAuth authOutcome = "require_auth"
	authAllow       authOutcome = "allow"
	authRequireCert authOutcome = "require_cert"
	authCertOrAuth  authOutcome = "require_cert_or_auth"
	authDeny        authOutcome = "deny"
)
// listenerAuthPolicy is the per-listener authentication policy. Mode selects
// the rule set; the Apply* flags choose which services it governs; and
// CertAllowlist holds the client-certificate fingerprints it recognizes.
type listenerAuthPolicy struct {
	Mode          string
	ApplyAPI      bool
	ApplyGit      bool
	ApplyRPM      bool
	ApplyV2       bool
	CertAllowlist map[string]bool
}

// runningListener pairs an endpoint with the http.Server serving it.
type runningListener struct {
	Endpoint listenerEndpoint
	Server   *http.Server
}

// additionalListenerManager keeps the database-configured extra listeners in
// sync with what is actually being served. mu guards Running; Reload carries
// reload requests to the background reconcile goroutine.
type additionalListenerManager struct {
	Store   *db.Store
	Handler http.Handler
	Logger  *util.Logger
	mu      sync.Mutex
	Running map[string]runningListener
	Reload  chan struct{}
}
// newAdditionalListenerManager builds a manager with no running listeners and
// a reload channel of capacity one. Call Start to perform the first reconcile.
func newAdditionalListenerManager(store *db.Store, handler http.Handler, logger *util.Logger) *additionalListenerManager {
	return &additionalListenerManager{
		Store:   store,
		Handler: handler,
		Logger:  logger,
		Running: map[string]runningListener{},
		Reload:  make(chan struct{}, 1),
	}
}

// NotifyReload asks for a reconcile without ever blocking the caller. When a
// request is already queued, this one is dropped (coalesced into it).
func (m *additionalListenerManager) NotifyReload() {
	var signal struct{}
	select {
	case m.Reload <- signal:
		// queued
	default:
		// a reload is already pending
	}
}
// ListenerEndpointCounts reports how many endpoints are currently tracked as
// running for each configured listener ID. Running keys have the form
// "extra:<listener-id>:<updated-at>:<scheme>:<addr>"; keys that do not
// parse, or carry a blank ID, are skipped.
func (m *additionalListenerManager) ListenerEndpointCounts() map[string]int {
	counts := make(map[string]int)
	m.mu.Lock()
	defer m.mu.Unlock()
	for runningKey := range m.Running {
		fields := strings.SplitN(runningKey, ":", 5)
		if len(fields) < 3 {
			continue
		}
		id := strings.TrimSpace(fields[1])
		if id != "" {
			counts[id]++
		}
	}
	return counts
}
// Start performs the initial reconcile, retrying up to 30 times (100ms
// apart) while the error looks like SQLite contention. Any other error, or
// persistent contention, is returned. On success it launches a goroutine
// that reconciles once per value received from m.Reload, logging failures.
func (m *additionalListenerManager) Start() error {
	const maxAttempts = 30
	var err error
	for attempt := 0; attempt < maxAttempts; attempt++ {
		if err = m.reconcile(); err == nil {
			break
		}
		if !isSQLiteBusyError(err) {
			return err
		}
		time.Sleep(100 * time.Millisecond)
	}
	if err != nil {
		return err
	}
	go func() {
		for range m.Reload {
			if loopErr := m.reconcile(); loopErr != nil && m.Logger != nil {
				m.Logger.Write("", util.LOG_ERROR, "additional listener reconcile error: %v", loopErr)
			}
		}
	}()
	return nil
}
// isSQLiteBusyError reports whether err's message (case-insensitively)
// indicates a locked/busy SQLite database, i.e. a condition worth retrying.
func isSQLiteBusyError(err error) bool {
	if err == nil {
		return false
	}
	lowered := strings.ToLower(err.Error())
	for _, marker := range []string{"database is locked", "sqlite_busy"} {
		if strings.Contains(lowered, marker) {
			return true
		}
	}
	return false
}
// reconcile brings the running additional listeners in line with the enabled
// TLS listeners stored in the database.
//
// Each listener is expanded into endpoints keyed by
// "extra:<id>:<updated-at>:<scheme>:<addr>", so editing a listener changes
// its keys and forces a restart. Running endpoints whose key is no longer
// desired are shut down first (freeing their ports), then missing ones are
// started. An error from ListTLSListeners or buildListenerEndpoints is
// returned before any running listener is touched. An endpoint that fails to
// bind is logged and left out of m.Running, so it is not reported as running
// and is retried on the next reconcile.
func (m *additionalListenerManager) reconcile() error {
	var listeners []models.TLSListener
	var desired map[string]listenerEndpoint
	var err error
	var i int
	var key string
	var endpoint listenerEndpoint
	var running runningListener
	var shutdownCtx context.Context
	var cancel context.CancelFunc
	var more []listenerEndpoint
	var j int
	var settings models.TLSSettings
	var policy listenerAuthPolicy
	var keyPrefix string
	var startErr error
	listeners, err = m.Store.ListTLSListeners()
	if err != nil {
		return err
	}
	desired = make(map[string]listenerEndpoint)
	for i = 0; i < len(listeners); i++ {
		if !listeners[i].Enabled {
			continue
		}
		settings = models.TLSSettings{
			HTTPAddrs:           listeners[i].HTTPAddrs,
			HTTPSAddrs:          listeners[i].HTTPSAddrs,
			TLSServerCertSource: listeners[i].TLSServerCertSource,
			TLSCertFile:         listeners[i].TLSCertFile,
			TLSKeyFile:          listeners[i].TLSKeyFile,
			TLSPKIServerCertID:  listeners[i].TLSPKIServerCertID,
			TLSClientAuth:       listeners[i].TLSClientAuth,
			TLSClientCAFile:     listeners[i].TLSClientCAFile,
			TLSPKIClientCAID:    listeners[i].TLSPKIClientCAID,
			TLSMinVersion:       listeners[i].TLSMinVersion,
		}
		policy = listenerPolicyFromTLSListener(listeners[i])
		more, err = buildListenerEndpoints(listeners[i].Name, settings, policy, m.Store)
		if err != nil {
			return err
		}
		keyPrefix = fmt.Sprintf("extra:%s:%d", listeners[i].ID, listeners[i].UpdatedAt)
		for j = 0; j < len(more); j++ {
			key = keyPrefix + ":" + endpointScheme(more[j]) + ":" + more[j].Addr
			more[j].Key = key
			desired[key] = more[j]
		}
	}
	m.mu.Lock()
	defer m.mu.Unlock()
	// Among the running listeners, stop all disabled/changed/removed ones.
	// This runs before starting new ones so a restarted listener can rebind
	// the same address.
	for key, running = range m.Running {
		var exists bool
		_, exists = desired[key]
		if exists {
			continue // still configured
		}
		shutdownCtx, cancel = context.WithTimeout(context.Background(), 5*time.Second)
		_ = running.Server.Shutdown(shutdownCtx)
		cancel()
		delete(m.Running, key)
		if m.Logger != nil {
			m.Logger.Write("", util.LOG_INFO, "additional listener stopped name=%s addr=%s", running.Endpoint.Name, running.Endpoint.Addr)
		}
	}
	// Start the listeners that are not already running.
	for key, endpoint = range desired {
		var exists bool
		_, exists = m.Running[key]
		if exists {
			continue // already running
		}
		running, startErr = m.startEndpoint(endpoint)
		if startErr != nil {
			// Not recorded in m.Running: it is not serving, and the next
			// reconcile will try again.
			if m.Logger != nil {
				m.Logger.Write(endpoint.Name, util.LOG_ERROR, "additional listener failed addr=%s err=%v", endpoint.Addr, startErr)
			}
			continue
		}
		m.Running[key] = running
	}
	return nil
}

// startEndpoint binds endpoint.Addr synchronously and then serves it on a
// new goroutine. Binding up front means failures such as "address already in
// use" are returned to the caller instead of only being logged after the
// listener has been recorded as running. An empty Addr falls back to ":http"
// or ":https", as http.Server.ListenAndServe{,TLS} would.
func (m *additionalListenerManager) startEndpoint(endpoint listenerEndpoint) (runningListener, error) {
	var server *http.Server
	var running runningListener
	var ln net.Listener
	var bindAddr string
	var err error
	bindAddr = endpoint.Addr
	if bindAddr == "" {
		if endpoint.IsHTTPS {
			bindAddr = ":https"
		} else {
			bindAddr = ":http"
		}
	}
	ln, err = net.Listen("tcp", bindAddr)
	if err != nil {
		return runningListener{}, err
	}
	server = &http.Server{
		Addr:        endpoint.Addr,
		Handler:     m.Handler,
		ConnContext: connContextWithListenerPolicy(endpoint.Policy),
	}
	// Only route http.Server errors through the util logger when one exists;
	// otherwise http.Server falls back to the standard log package.
	if m.Logger != nil {
		server.ErrorLog = log.New(&server_http_log_writer{l: m.Logger, id: endpoint.Name, depth: +2}, "", 0)
	}
	if endpoint.IsHTTPS {
		server.TLSConfig = endpoint.TLSConfig
	}
	running = runningListener{
		Endpoint: endpoint,
		Server:   server,
	}
	if m.Logger != nil {
		if endpoint.IsHTTPS {
			m.Logger.Write(endpoint.Name, util.LOG_INFO, "additional listener started https://%s", endpoint.Addr)
		} else {
			m.Logger.Write(endpoint.Name, util.LOG_INFO, "additional listener started %s", endpoint.Addr)
		}
	}
	go func(ep listenerEndpoint, srv *http.Server, l net.Listener) {
		var serveErr error
		// ServeTLS can fail before Serve takes ownership of l (e.g. no usable
		// certificate); close it so the port is not held open. This mirrors
		// ListenAndServeTLS, and a second Close after Serve is harmless.
		defer l.Close()
		if ep.IsHTTPS {
			serveErr = srv.ServeTLS(l, "", "")
		} else {
			serveErr = srv.Serve(l)
		}
		if serveErr != nil && !errors.Is(serveErr, http.ErrServerClosed) && m.Logger != nil {
			m.Logger.Write(ep.Name, util.LOG_ERROR, "additional listener failed addr=%s err=%v", ep.Addr, serveErr)
		}
	}(endpoint, server, ln)
	return running, nil
}
// endpointScheme returns "https" for TLS endpoints and "http" otherwise.
func endpointScheme(endpoint listenerEndpoint) string {
	scheme := "http"
	if endpoint.IsHTTPS {
		scheme = "https"
	}
	return scheme
}

// connContextWithListenerPolicy builds an http.Server ConnContext hook that
// stores policy in every accepted connection's context under
// listenerPolicyKey, so request handlers can tell which listener they came in on.
func connContextWithListenerPolicy(policy listenerAuthPolicy) func(ctx context.Context, conn net.Conn) context.Context {
	attach := func(base context.Context, _ net.Conn) context.Context {
		return context.WithValue(base, listenerPolicyKey, policy)
	}
	return attach
}
// listenerPolicyFromRequest returns the listener policy stored in r's context
// by connContextWithListenerPolicy. It falls back to defaultListenerPolicy()
// when r is nil or carries no policy.
func listenerPolicyFromRequest(r *http.Request) listenerAuthPolicy {
	if r != nil {
		if stored, found := r.Context().Value(listenerPolicyKey).(listenerAuthPolicy); found {
			return stored
		}
	}
	return defaultListenerPolicy()
}

// isPolicyApplied reports whether policy governs the given service; unknown
// services are never governed.
func isPolicyApplied(policy listenerAuthPolicy, service serviceKind) bool {
	switch service {
	case serviceAPI:
		return policy.ApplyAPI
	case serviceGit:
		return policy.ApplyGit
	case serviceRPM:
		return policy.ApplyRPM
	case serviceV2:
		return policy.ApplyV2
	default:
		return false
	}
}
// evaluateAuthOutcome maps a listener policy, a service and an operation to
// the authentication the request must satisfy. If the policy does not
// govern the service, or its Mode (case- and space-insensitive) is
// unrecognized, the result is authRequireAuth. "cert_only" always demands a
// client certificate. The read-open modes allow reads, and for writes demand
// a cert ("read_open_write_cert"), a cert or credentials
// ("read_open_write_cert_or_auth"), or deny them ("read_only_public").
func evaluateAuthOutcome(policy listenerAuthPolicy, service serviceKind, operation opKind) authOutcome {
	if !isPolicyApplied(policy, service) {
		return authRequireAuth
	}
	var writeOutcome authOutcome
	switch mode := strings.ToLower(strings.TrimSpace(policy.Mode)); mode {
	case "cert_only":
		return authRequireCert
	case "read_open_write_cert":
		writeOutcome = authRequireCert
	case "read_open_write_cert_or_auth":
		writeOutcome = authCertOrAuth
	case "read_only_public":
		writeOutcome = authDeny
	default:
		return authRequireAuth
	}
	if operation == opRead {
		return authAllow
	}
	return writeOutcome
}
// requestClientCertFingerprint returns the lowercase hex SHA-256 of the raw
// DER bytes of the client's leaf certificate, or "" when the request is nil,
// not TLS, or presented no certificate.
func requestClientCertFingerprint(r *http.Request) string {
	if r == nil {
		return ""
	}
	state := r.TLS
	if state == nil || len(state.PeerCertificates) == 0 {
		return ""
	}
	digest := sha256.Sum256(state.PeerCertificates[0].Raw)
	// hex.EncodeToString already emits lowercase digits.
	return hex.EncodeToString(digest[:])
}
// isClientCertAllowed reports whether the request's client certificate
// fingerprint is present (and true) in policy.CertAllowlist. An empty
// allowlist or a request without a client certificate is never allowed.
func isClientCertAllowed(r *http.Request, policy listenerAuthPolicy) bool {
	if len(policy.CertAllowlist) == 0 {
		return false
	}
	if fingerprint := requestClientCertFingerprint(r); fingerprint != "" {
		return policy.CertAllowlist[fingerprint]
	}
	return false
}
// authenticateByBasic checks the request's HTTP Basic credentials with auth.
// It returns false when auth is nil, the request has no Basic credentials,
// or auth returns an error; otherwise it returns auth's verdict.
func authenticateByBasic(r *http.Request, auth func(username string, password string) (bool, error)) bool {
	if auth == nil {
		return false
	}
	user, pass, present := r.BasicAuth()
	if !present {
		return false
	}
	granted, checkErr := auth(user, pass)
	return checkErr == nil && granted
}
// withServiceAuth guards a git, rpm or docker (/v2) backend according to the
// accepting listener's policy (see evaluateAuthOutcome):
//   - deny: 403.
//   - allow: pass through unauthenticated.
//   - require cert: authorizeByClientCert must succeed, else 403.
//   - cert or auth: the client certificate is tried first, then HTTP Basic.
//   - otherwise: HTTP Basic via auth.
//
// An authorizeByClientCert error yields 500. When the cert check yields a
// service principal, it is attached with middleware.WithPrincipal. A failed
// Basic check sends a per-service WWW-Authenticate challenge, logs the
// denial and answers 401.
func withServiceAuth(service serviceKind, next http.Handler, auth func(username string, password string) (bool, error), store *db.Store, logger *util.Logger) http.Handler {
	var challenges map[serviceKind]string
	challenges = map[serviceKind]string{
		serviceGit: `Basic realm="git"`,
		serviceRPM: `Basic realm="rpm"`,
		serviceV2:  `Basic realm="docker"`,
	}
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		policy := listenerPolicyFromRequest(r)
		op := classifyServiceOperation(service, r)
		outcome := evaluateAuthOutcome(policy, service, op)
		switch outcome {
		case authDeny:
			w.WriteHeader(http.StatusForbidden)
			return
		case authAllow:
			next.ServeHTTP(w, r)
			return
		case authRequireCert, authCertOrAuth:
			certOK, principal, hasPrincipal, certErr := authorizeByClientCert(service, r, policy, op, store)
			if certErr != nil {
				w.WriteHeader(http.StatusInternalServerError)
				return
			}
			if certOK {
				if hasPrincipal {
					r = middleware.WithPrincipal(r, principal)
				}
				next.ServeHTTP(w, r)
				return
			}
			if outcome == authRequireCert {
				w.WriteHeader(http.StatusForbidden)
				return
			}
			// cert-or-auth without a usable cert: try Basic below.
		}
		if authenticateByBasic(r, auth) {
			next.ServeHTTP(w, r)
			return
		}
		if challenge, found := challenges[service]; found {
			w.Header().Set("WWW-Authenticate", challenge)
		}
		if logger != nil {
			logger.Write(string(service), util.LOG_INFO, "auth denied policy=%s method=%s path=%s remote=%s", policy.Mode, r.Method, r.URL.Path, r.RemoteAddr)
		}
		w.WriteHeader(http.StatusUnauthorized)
	})
}
// withAPIAuth guards the REST API according to the accepting listener's
// policy (see evaluateAuthOutcome). A session user from the request context
// (set by middleware.WithUser) counts as authenticated unless that user is
// disabled.
//   - deny: 403.
//   - allow: pass through.
//   - require cert: authorizeByClientCert must succeed, else 403.
//   - cert or auth: cert, session user, or HTTP Basic.
//   - otherwise: session user or HTTP Basic.
//
// An authorizeByClientCert error yields 500. Every 401 is logged as
// "auth denied".
func withAPIAuth(next http.Handler, auth func(username string, password string) (bool, error), store *db.Store, logger *util.Logger) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		var policy listenerAuthPolicy
		var operation opKind
		var outcome authOutcome
		var user models.User
		var hasUser bool
		var ok bool
		var principal models.ServicePrincipal
		var principalOK bool
		var err error
		policy = listenerPolicyFromRequest(r)
		operation = classifyServiceOperation(serviceAPI, r)
		outcome = evaluateAuthOutcome(policy, serviceAPI, operation)
		user, hasUser = middleware.UserFromContext(r.Context())
		if user.Disabled {
			hasUser = false
		}
		if outcome == authDeny {
			w.WriteHeader(http.StatusForbidden)
			return
		}
		if outcome == authAllow {
			next.ServeHTTP(w, r)
			return
		}
		if outcome == authRequireCert {
			ok, principal, principalOK, err = authorizeByClientCert(serviceAPI, r, policy, operation, store)
			if err != nil {
				w.WriteHeader(http.StatusInternalServerError)
				return
			}
			if ok {
				if principalOK {
					r = middleware.WithPrincipal(r, principal)
				}
				next.ServeHTTP(w, r)
				return
			}
			w.WriteHeader(http.StatusForbidden)
			return
		}
		if outcome == authCertOrAuth {
			ok, principal, principalOK, err = authorizeByClientCert(serviceAPI, r, policy, operation, store)
			if err != nil {
				w.WriteHeader(http.StatusInternalServerError)
				return
			}
			if ok || hasUser {
				if principalOK {
					r = middleware.WithPrincipal(r, principal)
				}
				next.ServeHTTP(w, r)
				return
			}
			// No usable cert and no session user: fall through to the shared
			// Basic-auth check so a rejection is logged like every other 401.
		}
		if hasUser || authenticateByBasic(r, auth) {
			next.ServeHTTP(w, r)
			return
		}
		if logger != nil {
			logger.Write("api", util.LOG_INFO, "auth denied policy=%s method=%s path=%s remote=%s", policy.Mode, r.Method, r.URL.Path, r.RemoteAddr)
		}
		w.WriteHeader(http.StatusUnauthorized)
	})
}
// classifyServiceOperation labels a request to the given service as a read or
// a write. withServiceAuth and withAPIAuth pass the result to
// evaluateAuthOutcome and authorizeByClientCert.
//
//   - git: a push (git-receive-pack in the path, or ?service=git-receive-pack)
//     is a write; everything else is a read.
//   - api, rpm: GET, HEAD and OPTIONS are reads; any other method is a write.
//   - v2: GET and HEAD are reads; POST, PUT, PATCH and DELETE are writes; any
//     other method falls through to read.
//   - any other service: read.
//
// A nil request is classified as a write, so the stricter policy applies.
func classifyServiceOperation(service serviceKind, r *http.Request) opKind {
	var serviceName string
	if r == nil {
		// Every branch below dereferences r, so guard once up front.
		return opWrite
	}
	serviceName = r.URL.Query().Get("service")
	switch service {
	case serviceGit:
		if strings.Contains(r.URL.Path, "git-receive-pack") || serviceName == "git-receive-pack" {
			return opWrite
		}
		return opRead
	case serviceAPI, serviceRPM:
		if r.Method == http.MethodGet || r.Method == http.MethodHead || r.Method == http.MethodOptions {
			return opRead
		}
		return opWrite
	case serviceV2:
		switch r.Method {
		case http.MethodGet, http.MethodHead:
			return opRead
		case http.MethodPost, http.MethodPut, http.MethodPatch, http.MethodDelete:
			return opWrite
		}
	}
	return opRead
}
// authorizeByClientCert decides whether the request's TLS client certificate
// alone authorizes the operation.
//
// Returns (accepted, principal, principalBound, err):
//   - The certificate must first pass isClientCertAllowed for this listener
//     policy. If it does not, the request is not accepted.
//   - Reads are accepted on that check alone, with no principal lookup.
//   - Writes require the certificate fingerprint to be bound to a service
//     principal in store. API writes accept any bound principal. Other
//     services also require principalCanWriteService to grant access.
//
// principalBound is true only when accepted is true and principal is populated.
// Store errors are returned as err with accepted=false.
func authorizeByClientCert(service serviceKind, r *http.Request, policy listenerAuthPolicy, operation opKind, store *db.Store) (bool, models.ServicePrincipal, bool, error) {
	var none models.ServicePrincipal
	if !isClientCertAllowed(r, policy) {
		return false, none, false, nil
	}
	if operation == opRead {
		return true, none, false, nil
	}
	fingerprint := requestClientCertFingerprint(r)
	if fingerprint == "" {
		return false, none, false, nil
	}
	principal, bound, lookupErr := store.GetPrincipalByCertFingerprint(fingerprint)
	switch {
	case lookupErr != nil:
		return false, principal, false, lookupErr
	case !bound:
		return false, principal, false, nil
	case service == serviceAPI:
		return true, principal, true, nil
	}
	canWrite, writeErr := principalCanWriteService(service, principal, r, store)
	if writeErr != nil {
		return false, principal, false, writeErr
	}
	if !canWrite {
		return false, principal, false, nil
	}
	return true, principal, true, nil
}
// principalCanWriteService reports whether a service principal may perform a
// write against the repository addressed by r.
//
// Admin principals are always allowed. Otherwise the owning project is
// resolved from the request path, and the principal's role in that project
// must be "writer" or "admin" (case-insensitive).
//
// Resolution and role-lookup failures deliberately deny access instead of
// surfacing an error, so the returned error is currently always nil.
func principalCanWriteService(service serviceKind, principal models.ServicePrincipal, r *http.Request, store *db.Store) (bool, error) {
	if principal.IsAdmin {
		return true, nil
	}
	projectID, resolveErr := resolveProjectIDForServiceWrite(service, r, store)
	if resolveErr != nil || strings.TrimSpace(projectID) == "" {
		return false, nil
	}
	role, roleErr := store.GetPrincipalProjectRole(principal.ID, projectID)
	return roleErr == nil && roleAllowsWrite(role), nil
}
// roleAllowsWrite reports whether a project role grants write access.
// Only "writer" and "admin" qualify. Matching ignores case and surrounding
// whitespace.
func roleAllowsWrite(role string) bool {
	switch strings.ToLower(strings.TrimSpace(role)) {
	case "writer", "admin":
		return true
	default:
		return false
	}
}
// resolveProjectIDForServiceWrite maps a git, rpm or v2 request to the ID of
// the project owning the target repository. The caller uses it to look up the
// principal's project role.
//
// Path shapes (leading and trailing slashes ignored):
//   - v2:       v2/<project-slug>/<repo>/...  (docker repo looked up by name)
//   - git, rpm: <repo-key>[.git]/...           (tried first via store.GetRepo)
//     <project-slug>/<repo>[.git]/...          (fallback)
//
// Returns an error for other services, for a path too short to name a
// repository, or when any store lookup fails.
func resolveProjectIDForServiceWrite(service serviceKind, r *http.Request, store *db.Store) (string, error) {
	var path string
	var parts []string
	var first string
	var second string
	var repo models.Repo
	var project models.Project
	var err error
	path = strings.Trim(strings.TrimSpace(r.URL.Path), "/")
	if service == serviceV2 {
		// Only the first two segments after the v2 prefix name the repository.
		path = strings.TrimPrefix(path, "v2/")
		parts = strings.Split(path, "/")
		if len(parts) < 2 {
			return "", errors.New("invalid v2 path")
		}
		project, err = store.GetProjectBySlug(parts[0])
		if err != nil {
			return "", err
		}
		repo, err = store.GetRepoByProjectNameType(project.ID, parts[1], "docker")
		if err != nil {
			return "", err
		}
		return repo.ProjectID, nil
	}
	if service != serviceGit && service != serviceRPM {
		return "", errors.New("unsupported service")
	}
	// strings.Split never returns an empty slice, so an empty path must be
	// rejected explicitly instead of being looked up as repo "".
	if path == "" {
		return "", errors.New("invalid path")
	}
	parts = strings.Split(path, "/")
	// First form: the leading segment (minus .git) is resolvable by GetRepo.
	first = strings.TrimSuffix(parts[0], ".git")
	repo, err = store.GetRepo(first)
	if err == nil {
		return repo.ProjectID, nil
	}
	// Second form: <project-slug>/<repo-name>[.git]/...
	if len(parts) < 2 {
		return "", errors.New("invalid path")
	}
	second = strings.TrimSuffix(parts[1], ".git")
	project, err = store.GetProjectBySlug(parts[0])
	if err != nil {
		return "", err
	}
	repo, err = store.GetRepoByProjectNameType(project.ID, second, serviceRepoType(service, parts[1]))
	if err != nil {
		return "", err
	}
	return repo.ProjectID, nil
}
// serviceRepoType returns the repository type used for lookups. It is "git"
// for the git service or any segment ending in ".git", and "rpm" otherwise.
func serviceRepoType(service serviceKind, segment string) string {
	isGit := service == serviceGit || strings.HasSuffix(segment, ".git")
	if !isGit {
		return "rpm"
	}
	return "git"
}
// tlsSettingsFromConfig copies the listener/TLS fields of cfg into a
// models.TLSSettings value. Address lists are passed through
// normalizeAddrList; every scalar field is whitespace-trimmed.
func tlsSettingsFromConfig(cfg config.Config) models.TLSSettings {
	trim := strings.TrimSpace
	return models.TLSSettings{
		HTTPAddrs:           normalizeAddrList(cfg.HTTPAddrs),
		HTTPSAddrs:          normalizeAddrList(cfg.HTTPSAddrs),
		TLSServerCertSource: trim(cfg.TLSServerCertSource),
		TLSCertFile:         trim(cfg.TLSCertFile),
		TLSKeyFile:          trim(cfg.TLSKeyFile),
		TLSPKIServerCertID:  trim(cfg.TLSPKIServerCertID),
		TLSClientAuth:       trim(cfg.TLSClientAuth),
		TLSClientCAFile:     trim(cfg.TLSClientCAFile),
		TLSPKIClientCAID:    trim(cfg.TLSPKIClientCAID),
		TLSMinVersion:       trim(cfg.TLSMinVersion),
	}
}
// defaultListenerPolicy returns the baseline policy: mode "default", applied to
// the API, git, rpm and v2 services, with a fresh, empty certificate allowlist.
func defaultListenerPolicy() listenerAuthPolicy {
	return listenerAuthPolicy{
		Mode:          "default",
		ApplyAPI:      true,
		ApplyGit:      true,
		ApplyRPM:      true,
		ApplyV2:       true,
		CertAllowlist: map[string]bool{},
	}
}
// listenerPolicyFromTLSListener builds the auth policy for a configured
// listener, starting from defaultListenerPolicy.
//
// A non-blank AuthPolicy (trimmed) replaces the mode. The per-service apply
// flags are copied as-is. Non-blank allowlist entries are trimmed, lower-cased
// and added to CertAllowlist.
func listenerPolicyFromTLSListener(item models.TLSListener) listenerAuthPolicy {
	policy := defaultListenerPolicy()
	if mode := strings.TrimSpace(item.AuthPolicy); mode != "" {
		policy.Mode = mode
	}
	policy.ApplyAPI, policy.ApplyGit, policy.ApplyRPM, policy.ApplyV2 =
		item.ApplyPolicyAPI, item.ApplyPolicyGit, item.ApplyPolicyRPM, item.ApplyPolicyV2
	for _, entry := range item.ClientCertAllowlist {
		fingerprint := strings.ToLower(strings.TrimSpace(entry))
		if fingerprint != "" {
			policy.CertAllowlist[fingerprint] = true
		}
	}
	return policy
}
// normalizeAddrList trims each address and drops blank entries, preserving
// order. It returns nil when no non-blank address remains.
func normalizeAddrList(values []string) []string {
	var cleaned []string
	for _, raw := range values {
		if addr := strings.TrimSpace(raw); addr != "" {
			cleaned = append(cleaned, addr)
		}
	}
	return cleaned
}
// serveListeners runs one http.Server per endpoint, all sharing handler, and
// blocks until the first server stops. It returns that server's error.
//
// The remaining servers are then closed and waited for, so no listener keeps
// serving (and holding its port) after this function returns. With no
// endpoints it returns http.ErrServerClosed immediately.
//
// Each server injects the endpoint's listener policy into connection contexts
// via connContextWithListenerPolicy. HTTPS endpoints use the endpoint's
// prebuilt TLSConfig; its certificates are why empty file names are passed to
// ListenAndServeTLS. Server error logs go to logger when it is non-nil;
// otherwise net/http uses the standard log package.
func serveListeners(endpoints []listenerEndpoint, handler http.Handler, logger *util.Logger) error {
	// headerReadTimeout bounds how long a client may take to send request
	// headers, so slow-drip connections cannot pin server goroutines forever.
	// Body reads and response writes stay unbounded on purpose: git pushes,
	// rpm uploads and registry blob transfers can legitimately stream for a
	// long time.
	const headerReadTimeout = 30 * time.Second
	var wg sync.WaitGroup
	var errs chan error
	var servers []*http.Server
	var errorLog *log.Logger
	var i int
	var ep listenerEndpoint
	var srv *http.Server
	var firstErr error
	if len(endpoints) == 0 {
		return http.ErrServerClosed
	}
	if logger != nil {
		// http.Server only accepts a *log.Logger, so adapt util.Logger.
		// A *log.Logger serializes writes and can be shared by all servers.
		errorLog = log.New(&server_http_log_writer{l: logger, id: "", depth: +2}, "", 0)
	}
	// Buffered so every server goroutine can report its exit without a reader.
	errs = make(chan error, len(endpoints))
	servers = make([]*http.Server, 0, len(endpoints))
	for i = 0; i < len(endpoints); i++ {
		ep = endpoints[i]
		if logger != nil {
			if ep.IsHTTPS {
				logger.Write("", util.LOG_INFO, "codit server listener=%s https://%s", ep.Name, ep.Addr)
			} else {
				logger.Write("", util.LOG_INFO, "codit server listener=%s %s", ep.Name, ep.Addr)
			}
		}
		srv = &http.Server{
			Addr:              ep.Addr,
			Handler:           handler,
			ReadHeaderTimeout: headerReadTimeout,
			ErrorLog:          errorLog,
			ConnContext:       connContextWithListenerPolicy(ep.Policy),
		}
		if ep.IsHTTPS {
			srv.TLSConfig = ep.TLSConfig
		}
		servers = append(servers, srv)
		wg.Add(1)
		go func(s *http.Server, useTLS bool) {
			defer wg.Done()
			if useTLS {
				errs <- s.ListenAndServeTLS("", "")
				return
			}
			errs <- s.ListenAndServe()
		}(srv, ep.IsHTTPS)
	}
	firstErr = <-errs
	for i = 0; i < len(servers); i++ {
		// Best effort: Close makes a pending or running ListenAndServe return
		// ErrServerClosed; the Close error itself is not actionable here.
		_ = servers[i].Close()
	}
	wg.Wait()
	return firstErr
}
// buildListenerEndpoints expands one listener definition into an endpoint per
// address. Plain HTTP endpoints come first, then HTTPS endpoints. All
// endpoints share name and policy.
//
// Address lists are normalized first. If both are empty, a nil slice is
// returned. A server TLS config is built only when at least one HTTPS address
// exists; that build error is returned.
func buildListenerEndpoints(name string, settings models.TLSSettings, policy listenerAuthPolicy, store *db.Store) ([]listenerEndpoint, error) {
	var endpoints []listenerEndpoint
	settings.HTTPAddrs = normalizeAddrList(settings.HTTPAddrs)
	settings.HTTPSAddrs = normalizeAddrList(settings.HTTPSAddrs)
	plainCount, secureCount := len(settings.HTTPAddrs), len(settings.HTTPSAddrs)
	if plainCount+secureCount == 0 {
		return endpoints, nil
	}
	var serverTLS *tls.Config
	if secureCount > 0 {
		built, buildErr := buildServerTLSConfig(settings, store)
		if buildErr != nil {
			return nil, buildErr
		}
		serverTLS = built
	}
	for _, addr := range settings.HTTPAddrs {
		endpoints = append(endpoints, listenerEndpoint{Name: name, Addr: addr, IsHTTPS: false, TLSConfig: nil, Policy: policy})
	}
	for _, addr := range settings.HTTPSAddrs {
		endpoints = append(endpoints, listenerEndpoint{Name: name, Addr: addr, IsHTTPS: true, TLSConfig: serverTLS, Policy: policy})
	}
	return endpoints, nil
}
// buildServerTLSConfig assembles the server-side tls.Config from settings.
//
// It loads the server certificate (loadServerCertificate) and the client-auth
// mode (parseTLSClientAuth), then the optional client CA pool
// (loadClientCAPool). Modes that verify client certificates
// (RequireAndVerifyClientCert, VerifyClientCertIfGiven) are rejected when no
// CA pool is configured. The minimum TLS version comes from
// parseTLSMinVersion.
func buildServerTLSConfig(settings models.TLSSettings, store *db.Store) (*tls.Config, error) {
	serverCert, certErr := loadServerCertificate(settings, store)
	if certErr != nil {
		return nil, certErr
	}
	mode := parseTLSClientAuth(settings.TLSClientAuth)
	clientCAs, poolErr := loadClientCAPool(settings, store)
	if poolErr != nil {
		return nil, poolErr
	}
	verifiesClients := mode == tls.RequireAndVerifyClientCert || mode == tls.VerifyClientCertIfGiven
	if verifiesClients && clientCAs == nil {
		return nil, errors.New("client auth is enabled but no client CA configured")
	}
	serverTLS := &tls.Config{
		MinVersion:   parseTLSMinVersion(settings.TLSMinVersion),
		Certificates: []tls.Certificate{serverCert},
		ClientAuth:   mode,
		ClientCAs:    clientCAs,
	}
	return serverTLS, nil
}
// loadServerCertificate returns the HTTPS server certificate and key pair
// stored in the PKI table under settings.TLSPKIServerCertID (trimmed).
// An empty ID, a failed lookup or an unparsable PEM pair is an error.
func loadServerCertificate(settings models.TLSSettings, store *db.Store) (tls.Certificate, error) {
	certID := strings.TrimSpace(settings.TLSPKIServerCertID)
	if certID == "" {
		return tls.Certificate{}, errors.New("tls_pki_server_cert_id is required for https listener")
	}
	stored, lookupErr := store.GetPKICert(certID)
	if lookupErr != nil {
		return tls.Certificate{}, lookupErr
	}
	return tls.X509KeyPair([]byte(stored.CertPEM), []byte(stored.KeyPEM))
}
// loadClientCAPool returns the CA pool used to verify client certificates.
//
// It is built from the PKI CA stored under settings.TLSPKIClientCAID
// (trimmed). When no ID is configured it returns (nil, nil), meaning no
// client CA. A failed lookup or a CA PEM that yields no certificate is an
// error.
func loadClientCAPool(settings models.TLSSettings, store *db.Store) (*x509.CertPool, error) {
	caID := strings.TrimSpace(settings.TLSPKIClientCAID)
	if caID == "" {
		return nil, nil
	}
	ca, lookupErr := store.GetPKICA(caID)
	if lookupErr != nil {
		return nil, lookupErr
	}
	pool := x509.NewCertPool()
	if !pool.AppendCertsFromPEM([]byte(ca.CertPEM)) {
		return nil, errors.New("failed to parse tls_pki_client_ca_id certificate")
	}
	return pool, nil
}
// parseTLSClientAuth maps a configured client-auth keyword (case-insensitive,
// surrounding whitespace ignored) to a tls.ClientAuthType:
//
//	request            -> RequestClientCert
//	require            -> RequireAnyClientCert
//	verify_if_given    -> VerifyClientCertIfGiven
//	require_and_verify -> RequireAndVerifyClientCert
//
// Any other value, including empty, yields NoClientCert.
func parseTLSClientAuth(value string) tls.ClientAuthType {
	mode := tls.NoClientCert
	switch strings.ToLower(strings.TrimSpace(value)) {
	case "request":
		mode = tls.RequestClientCert
	case "require":
		mode = tls.RequireAnyClientCert
	case "verify_if_given":
		mode = tls.VerifyClientCertIfGiven
	case "require_and_verify":
		mode = tls.RequireAndVerifyClientCert
	}
	return mode
}
// parseTLSMinVersion maps a configured minimum TLS version to a crypto/tls
// constant. Matching is case-insensitive and ignores surrounding whitespace.
// "1.0"/"tls1.0", "1.1"/"tls1.1" and "1.3"/"tls1.3" are recognized; anything
// else, including empty, yields TLS 1.2.
func parseTLSMinVersion(value string) uint16 {
	version := uint16(tls.VersionTLS12)
	switch strings.ToLower(strings.TrimSpace(value)) {
	case "1.0", "tls1.0":
		version = tls.VersionTLS10
	case "1.1", "tls1.1":
		version = tls.VersionTLS11
	case "1.3", "tls1.3":
		version = tls.VersionTLS13
	}
	return version
}
// mergeTLSSettingsFromDB overlays listener/TLS settings persisted in the
// database onto cfg.
//
// Precedence is evaluated per field:
//   - If the matching CODIT_* environment variable is non-blank, cfg keeps its
//     current value.
//   - Otherwise a non-empty stored value replaces cfg's value. Stored strings
//     are copied verbatim, not trimmed.
//   - Otherwise cfg is left untouched.
//
// The only error returned is from store.GetTLSSettings.
func mergeTLSSettingsFromDB(cfg *config.Config, store *db.Store) error {
	stored, err := store.GetTLSSettings()
	if err != nil {
		return err
	}
	// envOverrides reports whether the environment pins this field.
	envOverrides := func(key string) bool {
		return strings.TrimSpace(os.Getenv(key)) != ""
	}
	// adopt copies a non-blank stored string into *dst unless env overrides it.
	adopt := func(dst *string, value string, envKey string) {
		if strings.TrimSpace(value) != "" && !envOverrides(envKey) {
			*dst = value
		}
	}
	if len(stored.HTTPAddrs) > 0 && !envOverrides("CODIT_HTTP_ADDRS") {
		cfg.HTTPAddrs = stored.HTTPAddrs
	}
	if len(stored.HTTPSAddrs) > 0 && !envOverrides("CODIT_HTTPS_ADDRS") {
		cfg.HTTPSAddrs = stored.HTTPSAddrs
	}
	adopt(&cfg.TLSServerCertSource, stored.TLSServerCertSource, "CODIT_TLS_SERVER_CERT_SOURCE")
	adopt(&cfg.TLSCertFile, stored.TLSCertFile, "CODIT_TLS_CERT_FILE")
	adopt(&cfg.TLSKeyFile, stored.TLSKeyFile, "CODIT_TLS_KEY_FILE")
	adopt(&cfg.TLSPKIServerCertID, stored.TLSPKIServerCertID, "CODIT_TLS_PKI_SERVER_CERT_ID")
	adopt(&cfg.TLSClientAuth, stored.TLSClientAuth, "CODIT_TLS_CLIENT_AUTH")
	adopt(&cfg.TLSClientCAFile, stored.TLSClientCAFile, "CODIT_TLS_CLIENT_CA_FILE")
	adopt(&cfg.TLSPKIClientCAID, stored.TLSPKIClientCAID, "CODIT_TLS_PKI_CLIENT_CA_ID")
	adopt(&cfg.TLSMinVersion, stored.TLSMinVersion, "CODIT_TLS_MIN_VERSION")
	return nil
}
// spaHandler serves a single-page application rooted at root.
//
// Routing rules:
//   - A request naming an existing regular file under root serves that file.
//   - A cleaned path that starts with ".." is answered with 404.
//   - A missing path with a file extension (e.g. /missing.js) is a 404.
//   - Any other path, including directories, falls back to root/index.html so
//     client-side routes work. It is a 404 if that file is absent.
func spaHandler(root string) http.HandlerFunc {
	indexPath := filepath.Join(root, "index.html")
	return func(w http.ResponseWriter, r *http.Request) {
		rel := filepath.Clean(strings.TrimPrefix(r.URL.Path, "/"))
		if rel == "." {
			rel = ""
		}
		if strings.HasPrefix(rel, "..") {
			w.WriteHeader(http.StatusNotFound)
			return
		}
		target := filepath.Join(root, rel)
		if info, statErr := os.Stat(target); statErr == nil && !info.IsDir() {
			http.ServeFile(w, r, target)
			return
		}
		if filepath.Ext(rel) != "" {
			w.WriteHeader(http.StatusNotFound)
			return
		}
		if _, statErr := os.Stat(indexPath); statErr != nil {
			w.WriteHeader(http.StatusNotFound)
			return
		}
		http.ServeFile(w, r, indexPath)
	}
}
// bootstrapAdmin creates an initial admin account from CODIT_BOOTSTRAP_ADMIN,
// formatted as "username:password". The value is split at the first colon,
// so the password may itself contain colons, and both parts are trimmed.
//
// It does nothing and returns nil in these cases:
//   - the variable is unset or has no colon;
//   - either part is blank;
//   - GetUserByUsername finds the user.
//
// Otherwise it hashes the password and creates a DB-backed admin user whose
// display name is the username and whose email is "<username>@local".
// NOTE(review): any GetUserByUsername error (not only "not found") leads to a
// create attempt; confirm that matches the store's error semantics.
func bootstrapAdmin(store *db.Store) error {
	raw := os.Getenv("CODIT_BOOTSTRAP_ADMIN")
	if raw == "" {
		return nil
	}
	name, secret, found := strings.Cut(raw, ":")
	if !found {
		return nil
	}
	name = strings.TrimSpace(name)
	secret = strings.TrimSpace(secret)
	if name == "" || secret == "" {
		return nil
	}
	if _, _, lookupErr := store.GetUserByUsername(name); lookupErr == nil {
		return nil
	}
	hashed, hashErr := auth.HashPassword(secret)
	if hashErr != nil {
		return hashErr
	}
	admin := models.User{
		Username:    name,
		DisplayName: name,
		Email:       name + "@local",
		IsAdmin:     true,
		AuthSource:  "db",
	}
	_, createErr := store.CreateUser(admin, hashed)
	return createErr
}