package handlers

import (
	"database/sql"
	"errors"
	"fmt"
	"io"
	"net"
	"net/http"
	"sort"
	"strings"
	"sync"
	"time"

	"golang.org/x/crypto/ssh"
	"golang.org/x/net/websocket"

	"codit/internal/db"
	"codit/internal/middleware"
	"codit/internal/models"
	"codit/internal/util"
)

// logIDSSHBroker is the logger ID under which all SSH broker entries are written.
const logIDSSHBroker string = "ssh-broker"

// sshServerRequest is the JSON body accepted by the SSH server create/update endpoints.
type sshServerRequest struct {
	Name        string   `json:"name"`
	Host        string   `json:"host"`
	Port        int      `json:"port"`
	Description string   `json:"description"`
	Tags        []string `json:"tags"`
	Enabled     bool     `json:"enabled"`
}

// sshAccessProfileTargetRequest names one target attached to an access profile.
type sshAccessProfileTargetRequest struct {
	TargetType string `json:"target_type"`
	TargetID   string `json:"target_id"`
}

// sshAccessProfileRequest is the JSON body accepted by the access profile
// create/update endpoints; it is validated by normalizeSSHAccessProfileRequest.
type sshAccessProfileRequest struct {
	ServerID             string                          `json:"server_id"`
	Name                 string                          `json:"name"`
	Description          string                          `json:"description"`
	RemoteUsername       string                          `json:"remote_username"`
	AuthMethod           string                          `json:"auth_method"`
	Enabled              bool                            `json:"enabled"`
	PrivateKeyPEM        string                          `json:"private_key_pem"`
	SSHUserCAID          string                          `json:"ssh_user_ca_id"`
	SSHPrincipalMode     string                          `json:"ssh_principal_mode"`
	SSHPrincipals        []string                        `json:"ssh_principals"`
	SSHPrincipalGrantIDs []string                        `json:"ssh_principal_grant_ids"`
	DefaultValidSeconds  int64                           `json:"default_valid_seconds"`
	MaxValidSeconds      int64                           `json:"max_valid_seconds"`
	Targets              []sshAccessProfileTargetRequest `json:"targets"`
}

// sshAccessProfileNormalizeOptions tells normalizeSSHAccessProfileRequest which
// ownership model applies (admin-shared vs. user-owned self-service).
type sshAccessProfileNormalizeOptions struct {
	OwnerScope     string
	OwnerUserID    string
	AllowUserEdit  bool
	RequireTargets bool
	SelfService    bool
}

// sshServerHostKeyRequest is the JSON body for pinning a host public key.
type sshServerHostKeyRequest struct {
	PublicKey string `json:"public_key"`
}

// sshSessionConnectRequest is the JSON body for opening a brokered SSH session.
type sshSessionConnectRequest struct {
	Cols         int    `json:"cols"`
	Rows         int    `json:"rows"`
	ValidSeconds int64  `json:"valid_seconds"`
	Term         string `json:"term"`
}

// sshSessionConnectResponse is returned when a session record has been created;
// the client then attaches to WebSocketPath.
type sshSessionConnectResponse struct {
	SessionID          string `json:"session_id"`
	Status             string `json:"status"`
	WebSocketPath      string `json:"websocket_path"`
	ServerName         string `json:"server_name"`
	Host               string `json:"host"`
	Port               int    `json:"port"`
	RemoteUsername     string `json:"remote_username"`
	HostKeyFingerprint string `json:"host_key_fingerprint"`
}

// sshSessionStreamMessage is a message frame for the session stream.
// NOTE(review): its use lies outside this section — presumably the JSON frame
// exchanged over the session WebSocket; confirm.
type sshSessionStreamMessage struct {
	Type    string `json:"type"`
	Data    string `json:"data"`
	Cols    int    `json:"cols"`
	Rows    int    `json:"rows"`
	Status  string `json:"status"`
	Message string `json:"message"`
}

// logSSHSessionConnectStart records an INFO "connect start" entry for a session.
func (api *API) logSSHSessionConnectStart(sessionItem models.SSHSession, profile models.SSHAccessProfile, user models.User) {
	api.Logger.Write(logIDSSHBroker, util.LOG_INFO,
		"connect start session=%s requester=%s user=%s profile=%s profile_name=%q server=%s server_name=%q target=%s:%d remote_user=%s auth=%s",
		sessionItem.ID, sessionItem.RemoteAddr, user.Username,
		profile.ID, profile.Name, profile.ServerID, profile.Server.Name,
		profile.Server.Host, profile.Server.Port, profile.RemoteUsername, profile.AuthMethod)
}

// logSSHSessionPrepare records an INFO "connect prepare" entry; it takes the
// session ID and requester address directly rather than a session record.
func (api *API) logSSHSessionPrepare(sessionID string, remoteAddr string, user models.User, profile models.SSHAccessProfile) {
	api.Logger.Write(logIDSSHBroker, util.LOG_INFO,
		"connect prepare session=%s requester=%s user=%s profile=%s profile_name=%q server=%s server_name=%q target=%s:%d remote_user=%s auth=%s",
		sessionID, remoteAddr, user.Username,
		profile.ID, profile.Name, profile.ServerID, profile.Server.Name,
		profile.Server.Host, profile.Server.Port, profile.RemoteUsername, profile.AuthMethod)
}

// logSSHSessionConnectSuccess records an INFO "connect success" entry including
// the host key fingerprint that was accepted.
func (api *API) logSSHSessionConnectSuccess(sessionItem models.SSHSession, profile models.SSHAccessProfile, user models.User, hostKeyFingerprint string) {
	api.Logger.Write(logIDSSHBroker, util.LOG_INFO,
		"connect success session=%s requester=%s user=%s profile=%s profile_name=%q server=%s server_name=%q target=%s:%d remote_user=%s auth=%s host_key=%s",
		sessionItem.ID, sessionItem.RemoteAddr, user.Username,
		profile.ID, profile.Name, profile.ServerID, profile.Server.Name,
		profile.Server.Host, profile.Server.Port, profile.RemoteUsername, profile.AuthMethod,
		hostKeyFingerprint)
}

// logSSHSessionConnectFailure records a WARN "connect failed" entry with the
// failing stage and error, for a session whose profile is already loaded.
func (api *API) logSSHSessionConnectFailure(sessionItem models.SSHSession, profile models.SSHAccessProfile, user models.User, stage string, err error) {
	api.Logger.Write(logIDSSHBroker, util.LOG_WARN,
		"connect failed session=%s requester=%s user=%s profile=%s profile_name=%q server=%s server_name=%q target=%s:%d remote_user=%s auth=%s stage=%s err=%v",
		sessionItem.ID, sessionItem.RemoteAddr, user.Username,
		profile.ID, profile.Name, profile.ServerID, profile.Server.Name,
		profile.Server.Host, profile.Server.Port, profile.RemoteUsername, profile.AuthMethod,
		stage, err)
}

// logSSHSessionPrepareFailure records a WARN "connect failed" entry using only
// identifiers, for failures that occur before a full profile is available.
func (api *API) logSSHSessionPrepareFailure(sessionID string, remoteAddr string, username string, profileID string, stage string, err error) {
	api.Logger.Write(logIDSSHBroker, util.LOG_WARN,
		"connect failed session=%s requester=%s user=%s profile=%s stage=%s err=%v",
		sessionID, remoteAddr, username, profileID, stage, err)
}

// logSSHSessionClosed records an INFO "session closed" entry. The logged
// duration is endedAt-connectedAt, or 0 when connectedAt is unset (<= 0) or
// endedAt precedes it. NOTE(review): both timestamps are presumably Unix
// seconds, matching the dur_s field name.
func (api *API) logSSHSessionClosed(sessionItem models.SSHSession, profile models.SSHAccessProfile, user models.User, status string, reason string, connectedAt int64, endedAt int64) {
	var durationSeconds int64
	if connectedAt > 0 && endedAt >= connectedAt {
		durationSeconds = endedAt - connectedAt
	}
	api.Logger.Write(logIDSSHBroker, util.LOG_INFO,
		"session closed session=%s requester=%s user=%s profile=%s profile_name=%q server=%s server_name=%q target=%s:%d remote_user=%s auth=%s status=%s dur_s=%d reason=%q",
		sessionItem.ID, sessionItem.RemoteAddr, user.Username,
		profile.ID, profile.Name, profile.ServerID, profile.Server.Name,
		profile.Server.Host, profile.Server.Port, profile.RemoteUsername, profile.AuthMethod,
		status, durationSeconds, reason)
}

// writeSSHBrokerError writes a JSON body of the form {"error": message}.
func writeSSHBrokerError(w http.ResponseWriter, status int, message string) {
	WriteJSON(w, status, map[string]string{"error": message})
}

// writeSSHBrokerStoreError maps a store error onto a response: sql.ErrNoRows
// (wrapped or not) becomes 404 with notFoundMessage; anything else is a 500
// carrying the error text.
func writeSSHBrokerStoreError(w http.ResponseWriter, err error, notFoundMessage string) {
	if errors.Is(err, sql.ErrNoRows) {
		writeSSHBrokerError(w, http.StatusNotFound, notFoundMessage)
		return
	}
	writeSSHBrokerError(w, http.StatusInternalServerError, err.Error())
}

// sshBrokerSelfUser returns the authenticated user for a self-service request.
// If there is no user in the context, or the user is disabled, it writes a bare
// 401 and returns false.
func sshBrokerSelfUser(w http.ResponseWriter, r *http.Request) (models.User, bool) {
	var user models.User
	var ok bool
	user, ok = middleware.UserFromContext(r.Context())
	if !ok || user.Disabled {
		w.WriteHeader(http.StatusUnauthorized)
		return models.User{}, false
	}
	return user, true
}

// validateSSHServerRequest trims req's text fields in place and returns a
// client-facing error message, or "" when the request is acceptable.
func validateSSHServerRequest(req *sshServerRequest) string {
	req.Name = strings.TrimSpace(req.Name)
	req.Host = strings.TrimSpace(req.Host)
	req.Description = strings.TrimSpace(req.Description)
	if req.Name == "" {
		return "name is required"
	}
	if req.Host == "" {
		return "host is required"
	}
	if req.Port <= 0 || req.Port > 65535 {
		return "port must be between 1 and 65535"
	}
	return ""
}

// applySSHServerRequest copies the editable fields of a validated request onto item.
func applySSHServerRequest(item *models.SSHServer, req sshServerRequest) {
	item.Name = req.Name
	item.Host = req.Host
	item.Port = req.Port
	item.Description = req.Description
	item.Tags = normalizeSSHBrokerTags(req.Tags)
	item.Enabled = req.Enabled
}

// sshAccessProfileAdminOptions returns the normalization options for
// admin-managed shared profiles (targets required, no user ownership).
func sshAccessProfileAdminOptions() sshAccessProfileNormalizeOptions {
	return sshAccessProfileNormalizeOptions{
		OwnerScope:     "admin_shared",
		OwnerUserID:    "",
		AllowUserEdit:  false,
		RequireTargets: true,
		SelfService:    false,
	}
}

// sshAccessProfileSelfOptions returns the normalization options for a profile
// owned by userID via the self-service endpoints (targets optional).
func sshAccessProfileSelfOptions(userID string) sshAccessProfileNormalizeOptions {
	return sshAccessProfileNormalizeOptions{
		OwnerScope:     "user",
		OwnerUserID:    userID,
		AllowUserEdit:  true,
		RequireTargets: false,
		SelfService:    true,
	}
}

// ListSSHServersAdmin returns all SSH servers as JSON. Admin only.
func (api *API) ListSSHServersAdmin(w http.ResponseWriter, r *http.Request, _ map[string]string) {
	var items []models.SSHServer
	var err error
	if !api.requireAdmin(w, r) {
		return
	}
	items, err = api.store(r).ListSSHServers()
	if err != nil {
		writeSSHBrokerError(w, http.StatusInternalServerError, err.Error())
		return
	}
	WriteJSON(w, http.StatusOK, items)
}

// GetSSHServerAdmin returns the server identified by params["id"]. Admin only.
func (api *API) GetSSHServerAdmin(w http.ResponseWriter, r *http.Request, params map[string]string) {
	var item models.SSHServer
	var err error
	if !api.requireAdmin(w, r) {
		return
	}
	item, err = api.store(r).GetSSHServer(params["id"])
	if err != nil {
		writeSSHBrokerStoreError(w, err, "ssh server not found")
		return
	}
	WriteJSON(w, http.StatusOK, item)
}

// CreateSSHServerAdmin validates the request body and creates a server,
// attributing it to the current broker actor. Admin only.
func (api *API) CreateSSHServerAdmin(w http.ResponseWriter, r *http.Request, _ map[string]string) {
	var req sshServerRequest
	var item models.SSHServer
	var actorKind string
	var actorID string
	var actorName string
	var message string
	var err error
	if !api.requireAdmin(w, r) {
		return
	}
	err = DecodeJSON(r, &req)
	if err != nil {
		writeSSHBrokerError(w, http.StatusBadRequest, "invalid request")
		return
	}
	message = validateSSHServerRequest(&req)
	if message != "" {
		writeSSHBrokerError(w, http.StatusBadRequest, message)
		return
	}
	actorKind, actorID, actorName = currentSSHBrokerActor(r)
	item = models.SSHServer{
		Name:                 req.Name,
		Host:                 req.Host,
		Port:                 req.Port,
		Description:          req.Description,
		Tags:                 normalizeSSHBrokerTags(req.Tags),
		Enabled:              req.Enabled,
		CreatedByKind:        actorKind,
		CreatedBySubjectID:   actorID,
		CreatedBySubjectName: actorName,
	}
	item, err = api.store(r).CreateSSHServer(item)
	if err != nil {
		writeSSHBrokerError(w, http.StatusBadRequest, err.Error())
		return
	}
	WriteJSON(w, http.StatusCreated, item)
}

// UpdateSSHServerAdmin replaces the editable fields of an existing server.
// Admin only.
func (api *API) UpdateSSHServerAdmin(w http.ResponseWriter, r *http.Request, params map[string]string) {
	var req sshServerRequest
	var item models.SSHServer
	var message string
	var err error
	if !api.requireAdmin(w, r) {
		return
	}
	item, err = api.store(r).GetSSHServer(params["id"])
	if err != nil {
		writeSSHBrokerStoreError(w, err, "ssh server not found")
		return
	}
	err = DecodeJSON(r, &req)
	if err != nil {
		writeSSHBrokerError(w, http.StatusBadRequest, "invalid request")
		return
	}
	message = validateSSHServerRequest(&req)
	if message != "" {
		writeSSHBrokerError(w, http.StatusBadRequest, message)
		return
	}
	applySSHServerRequest(&item, req)
	item, err = api.store(r).UpdateSSHServer(item)
	if err != nil {
		writeSSHBrokerError(w, http.StatusBadRequest, err.Error())
		return
	}
	WriteJSON(w, http.StatusOK, item)
}

// DeleteSSHServerAdmin deletes the server identified by params["id"] and
// responds 204. A missing row is reported as 404. Admin only.
func (api *API) DeleteSSHServerAdmin(w http.ResponseWriter, r *http.Request, params map[string]string) {
	var err error
	if !api.requireAdmin(w, r) {
		return
	}
	err = api.store(r).DeleteSSHServer(params["id"])
	if err != nil {
		writeSSHBrokerStoreError(w, err, "ssh server not found")
		return
	}
	w.WriteHeader(http.StatusNoContent)
}

// ListSSHAccessProfilesAdmin returns all access profiles. Admin only.
func (api *API) ListSSHAccessProfilesAdmin(w http.ResponseWriter, r *http.Request, _ map[string]string) {
	var items []models.SSHAccessProfile
	var err error
	if !api.requireAdmin(w, r) {
		return
	}
	items, err = api.store(r).ListSSHAccessProfiles()
	if err != nil {
		writeSSHBrokerError(w, http.StatusInternalServerError, err.Error())
		return
	}
	WriteJSON(w, http.StatusOK, items)
}

// GetSSHAccessProfileAdmin returns the profile identified by params["id"].
// Admin only.
func (api *API) GetSSHAccessProfileAdmin(w http.ResponseWriter, r *http.Request, params map[string]string) {
	var item models.SSHAccessProfile
	var err error
	if !api.requireAdmin(w, r) {
		return
	}
	item, err = api.store(r).GetSSHAccessProfile(params["id"])
	if err != nil {
		writeSSHBrokerStoreError(w, err, "ssh access profile not found")
		return
	}
	WriteJSON(w, http.StatusOK, item)
}

// CreateSSHAccessProfileAdmin creates an admin-shared access profile.
// Admin only.
func (api *API) CreateSSHAccessProfileAdmin(w http.ResponseWriter, r *http.Request, _ map[string]string) {
	var req sshAccessProfileRequest
	var item models.SSHAccessProfile
	var actorKind string
	var actorID string
	var actorName string
	var err error
	if !api.requireAdmin(w, r) {
		return
	}
	err = DecodeJSON(r, &req)
	if err != nil {
		writeSSHBrokerError(w, http.StatusBadRequest, "invalid request")
		return
	}
	actorKind, actorID, actorName = currentSSHBrokerActor(r)
	item, err = api.normalizeSSHAccessProfileRequest(r, req, models.SSHAccessProfile{}, actorKind, actorID, actorName, false, sshAccessProfileAdminOptions())
	if err != nil {
		writeSSHBrokerError(w, http.StatusBadRequest, err.Error())
		return
	}
	item, err = api.store(r).CreateSSHAccessProfile(item)
	if err != nil {
		writeSSHBrokerError(w, http.StatusBadRequest, err.Error())
		return
	}
	WriteJSON(w, http.StatusCreated, item)
}

// UpdateSSHAccessProfileAdmin updates an existing profile using the
// admin-shared normalization rules. Admin only.
func (api *API) UpdateSSHAccessProfileAdmin(w http.ResponseWriter, r *http.Request, params map[string]string) {
	var req sshAccessProfileRequest
	var existing models.SSHAccessProfile
	var item models.SSHAccessProfile
	var actorKind string
	var actorID string
	var actorName string
	var err error
	if !api.requireAdmin(w, r) {
		return
	}
	existing, err = api.store(r).GetSSHAccessProfile(params["id"])
	if err != nil {
		writeSSHBrokerStoreError(w, err, "ssh access profile not found")
		return
	}
	err = DecodeJSON(r, &req)
	if err != nil {
		writeSSHBrokerError(w, http.StatusBadRequest, "invalid request")
		return
	}
	actorKind, actorID, actorName = currentSSHBrokerActor(r)
	item, err = api.normalizeSSHAccessProfileRequest(r, req, existing, actorKind, actorID, actorName, true, sshAccessProfileAdminOptions())
	if err != nil {
		writeSSHBrokerError(w, http.StatusBadRequest, err.Error())
		return
	}
	item, err = api.store(r).UpdateSSHAccessProfile(item)
	if err != nil {
		writeSSHBrokerError(w, http.StatusBadRequest, err.Error())
		return
	}
	WriteJSON(w, http.StatusOK, item)
}

// DeleteSSHAccessProfileAdmin deletes the profile identified by params["id"]
// and responds 204. Admin only.
func (api *API) DeleteSSHAccessProfileAdmin(w http.ResponseWriter, r *http.Request, params map[string]string) {
	var err error
	if !api.requireAdmin(w, r) {
		return
	}
	err = api.store(r).DeleteSSHAccessProfile(params["id"])
	if err != nil {
		writeSSHBrokerStoreError(w, err, "ssh access profile not found")
		return
	}
	w.WriteHeader(http.StatusNoContent)
}

// ListSSHAccessProfilesForSelf returns the profiles the store reports as
// available to the calling user.
func (api *API) ListSSHAccessProfilesForSelf(w http.ResponseWriter, r *http.Request, _ map[string]string) {
	var user models.User
	var ok bool
	var items []models.SSHAccessProfile
	var err error
	user, ok = sshBrokerSelfUser(w, r)
	if !ok {
		return
	}
	items, err = api.store(r).ListSSHAccessProfilesForUser(user.ID)
	if err != nil {
		writeSSHBrokerError(w, http.StatusInternalServerError, err.Error())
		return
	}
	WriteJSON(w, http.StatusOK, items)
}

// GetSSHAccessProfileForSelf returns one profile visible to the calling user.
func (api *API) GetSSHAccessProfileForSelf(w http.ResponseWriter, r *http.Request, params map[string]string) {
	var user models.User
	var ok bool
	var item models.SSHAccessProfile
	var err error
	user, ok = sshBrokerSelfUser(w, r)
	if !ok {
		return
	}
	item, err = api.store(r).GetSSHAccessProfileForUser(user.ID, params["id"])
	if err != nil {
		writeSSHBrokerStoreError(w, err, "ssh access profile not found")
		return
	}
	WriteJSON(w, http.StatusOK, item)
}

// ListSSHServersForSelf returns the servers the store reports for the calling user.
func (api *API) ListSSHServersForSelf(w http.ResponseWriter, r *http.Request, _ map[string]string) {
	var user models.User
	var ok bool
	var items []models.SSHServer
	var err error
	user, ok = sshBrokerSelfUser(w, r)
	if !ok {
		return
	}
	items, err = api.store(r).ListSSHServersForUser(user.ID)
	if err != nil {
		writeSSHBrokerError(w, http.StatusInternalServerError, err.Error())
		return
	}
	WriteJSON(w, http.StatusOK, items)
}

// CreateSSHServerForSelf creates a server attributed to the calling user
// (CreatedByKind "user") and marks it Editable in the response.
func (api *API) CreateSSHServerForSelf(w http.ResponseWriter, r *http.Request, _ map[string]string) {
	var user models.User
	var ok bool
	var req sshServerRequest
	var item models.SSHServer
	var message string
	var err error
	user, ok = sshBrokerSelfUser(w, r)
	if !ok {
		return
	}
	err = DecodeJSON(r, &req)
	if err != nil {
		writeSSHBrokerError(w, http.StatusBadRequest, "invalid request")
		return
	}
	message = validateSSHServerRequest(&req)
	if message != "" {
		writeSSHBrokerError(w, http.StatusBadRequest, message)
		return
	}
	item = models.SSHServer{
		Name:                 req.Name,
		Host:                 req.Host,
		Port:                 req.Port,
		Description:          req.Description,
		Tags:                 normalizeSSHBrokerTags(req.Tags),
		Enabled:              req.Enabled,
		CreatedByKind:        "user",
		CreatedBySubjectID:   user.ID,
		CreatedBySubjectName: user.Username,
		Editable:             true,
	}
	item, err = api.store(r).CreateSSHServer(item)
	if err != nil {
		writeSSHBrokerError(w, http.StatusBadRequest, err.Error())
		return
	}
	// Re-set after the store round trip so the response always reports the
	// caller's ownership, whatever the store returned for this field.
	item.Editable = true
	WriteJSON(w, http.StatusCreated, item)
}

// UpdateSSHServerForSelf updates a server owned by the calling user.
func (api *API) UpdateSSHServerForSelf(w http.ResponseWriter, r *http.Request, params map[string]string) {
	var user models.User
	var ok bool
	var req sshServerRequest
	var item models.SSHServer
	var message string
	var err error
	user, ok = sshBrokerSelfUser(w, r)
	if !ok {
		return
	}
	item, err = api.store(r).GetOwnedSSHServerForUser(user.ID, params["id"])
	if err != nil {
		writeSSHBrokerStoreError(w, err, "ssh server not found")
		return
	}
	err = DecodeJSON(r, &req)
	if err != nil {
		writeSSHBrokerError(w, http.StatusBadRequest, "invalid request")
		return
	}
	message = validateSSHServerRequest(&req)
	if message != "" {
		writeSSHBrokerError(w, http.StatusBadRequest, message)
		return
	}
	applySSHServerRequest(&item, req)
	item, err = api.store(r).UpdateSSHServer(item)
	if err != nil {
		writeSSHBrokerError(w, http.StatusBadRequest, err.Error())
		return
	}
	item.Editable = true
	WriteJSON(w, http.StatusOK, item)
}

// DeleteSSHServerForSelf deletes a server owned by the calling user. It deletes
// the ID of the row returned by the ownership lookup, so only the authorized
// record can be removed.
func (api *API) DeleteSSHServerForSelf(w http.ResponseWriter, r *http.Request, params map[string]string) {
	var user models.User
	var ok bool
	var item models.SSHServer
	var err error
	user, ok = sshBrokerSelfUser(w, r)
	if !ok {
		return
	}
	item, err = api.store(r).GetOwnedSSHServerForUser(user.ID, params["id"])
	if err != nil {
		writeSSHBrokerStoreError(w, err, "ssh server not found")
		return
	}
	err = api.store(r).DeleteSSHServer(item.ID)
	if err != nil {
		writeSSHBrokerError(w, http.StatusInternalServerError, err.Error())
		return
	}
	w.WriteHeader(http.StatusNoContent)
}

// ListSSHServerHostKeysForSelf lists pinned host keys of a server owned by the
// calling user.
func (api *API) ListSSHServerHostKeysForSelf(w http.ResponseWriter, r *http.Request, params map[string]string) {
	var user models.User
	var ok bool
	var item models.SSHServer
	var items []models.SSHServerHostKey
	var err error
	user, ok = sshBrokerSelfUser(w, r)
	if !ok {
		return
	}
	item, err = api.store(r).GetOwnedSSHServerForUser(user.ID, params["id"])
	if err != nil {
		writeSSHBrokerStoreError(w, err, "ssh server not found")
		return
	}
	items, err = api.store(r).ListSSHServerHostKeys(item.ID)
	if err != nil {
		writeSSHBrokerError(w, http.StatusInternalServerError, err.Error())
		return
	}
	WriteJSON(w, http.StatusOK, items)
}

// DiscoverSSHServerHostKeyForSelf fetches the live host key of a server owned
// by the calling user, without pinning it. Discovery failures map to 502.
func (api *API) DiscoverSSHServerHostKeyForSelf(w http.ResponseWriter, r *http.Request, params map[string]string) {
	var user models.User
	var ok bool
	var item models.SSHServer
	var discovered models.SSHServerHostKey
	var err error
	user, ok = sshBrokerSelfUser(w, r)
	if !ok {
		return
	}
	item, err = api.store(r).GetOwnedSSHServerForUser(user.ID, params["id"])
	if err != nil {
		writeSSHBrokerStoreError(w, err, "ssh server not found")
		return
	}
	discovered, err = discoverSSHServerHostKey(item.Host, item.Port)
	if err != nil {
		writeSSHBrokerError(w, http.StatusBadGateway, err.Error())
		return
	}
	discovered.ServerID = item.ID
	WriteJSON(w, http.StatusOK, discovered)
}

// CreateSSHServerHostKeyForSelf parses and pins a host public key on a server
// owned by the calling user.
func (api *API) CreateSSHServerHostKeyForSelf(w http.ResponseWriter, r *http.Request, params map[string]string) {
	var user models.User
	var ok bool
	var item models.SSHServer
	var req sshServerHostKeyRequest
	var hostKey models.SSHServerHostKey
	var err error
	user, ok = sshBrokerSelfUser(w, r)
	if !ok {
		return
	}
	item, err = api.store(r).GetOwnedSSHServerForUser(user.ID, params["id"])
	if err != nil {
		writeSSHBrokerStoreError(w, err, "ssh server not found")
		return
	}
	err = DecodeJSON(r, &req)
	if err != nil {
		writeSSHBrokerError(w, http.StatusBadRequest, "invalid request")
		return
	}
	hostKey, err = parseSSHServerHostKey(strings.TrimSpace(req.PublicKey))
	if err != nil {
		writeSSHBrokerError(w, http.StatusBadRequest, err.Error())
		return
	}
	hostKey.ServerID = item.ID
	hostKey, err = api.store(r).CreateSSHServerHostKey(hostKey)
	if err != nil {
		writeSSHBrokerError(w, http.StatusBadRequest, err.Error())
		return
	}
	WriteJSON(w, http.StatusCreated, hostKey)
}

// DeleteSSHServerHostKeyForSelf removes pinned key params["hostKeyId"] from a
// server owned by the calling user.
func (api *API) DeleteSSHServerHostKeyForSelf(w http.ResponseWriter, r *http.Request, params map[string]string) {
	var user models.User
	var ok bool
	var item models.SSHServer
	var err error
	user, ok = sshBrokerSelfUser(w, r)
	if !ok {
		return
	}
	item, err = api.store(r).GetOwnedSSHServerForUser(user.ID, params["id"])
	if err != nil {
		writeSSHBrokerStoreError(w, err, "ssh server not found")
		return
	}
	err = api.store(r).DeleteSSHServerHostKey(item.ID, params["hostKeyId"])
	if err != nil {
		writeSSHBrokerStoreError(w, err, "ssh server host key not found")
		return
	}
	w.WriteHeader(http.StatusNoContent)
}

// CreateSSHAccessProfileForSelf creates a profile owned by the calling user.
func (api *API) CreateSSHAccessProfileForSelf(w http.ResponseWriter, r *http.Request, _ map[string]string) {
	var user models.User
	var ok bool
	var req sshAccessProfileRequest
	var item models.SSHAccessProfile
	var err error
	user, ok = sshBrokerSelfUser(w, r)
	if !ok {
		return
	}
	err = DecodeJSON(r, &req)
	if err != nil {
		writeSSHBrokerError(w, http.StatusBadRequest, "invalid request")
		return
	}
	item, err = api.normalizeSSHAccessProfileRequest(r, req, models.SSHAccessProfile{}, "user", user.ID, user.Username, false, sshAccessProfileSelfOptions(user.ID))
	if err != nil {
		writeSSHBrokerError(w, http.StatusBadRequest, err.Error())
		return
	}
	item, err = api.store(r).CreateSSHAccessProfile(item)
	if err != nil {
		writeSSHBrokerError(w, http.StatusBadRequest, err.Error())
		return
	}
	WriteJSON(w, http.StatusCreated, item)
}

// UpdateSSHAccessProfileForSelf updates a profile the calling user owns.
// Profiles that are merely visible to the user (e.g. admin-shared) yield a
// bare 403, checked before the body is read.
func (api *API) UpdateSSHAccessProfileForSelf(w http.ResponseWriter, r *http.Request, params map[string]string) {
	var user models.User
	var ok bool
	var req sshAccessProfileRequest
	var existing models.SSHAccessProfile
	var item models.SSHAccessProfile
	var err error
	user, ok = sshBrokerSelfUser(w, r)
	if !ok {
		return
	}
	existing, err = api.store(r).GetSSHAccessProfileForUser(user.ID, params["id"])
	if err != nil {
		writeSSHBrokerStoreError(w, err, "ssh access profile not found")
		return
	}
	if existing.OwnerScope != "user" || existing.OwnerUserID != user.ID {
		w.WriteHeader(http.StatusForbidden)
		return
	}
	err = DecodeJSON(r, &req)
	if err != nil {
		writeSSHBrokerError(w, http.StatusBadRequest, "invalid request")
		return
	}
	item, err = api.normalizeSSHAccessProfileRequest(r, req, existing, "user", user.ID, user.Username, true, sshAccessProfileSelfOptions(user.ID))
	if err != nil {
		writeSSHBrokerError(w, http.StatusBadRequest, err.Error())
		return
	}
	item, err = api.store(r).UpdateSSHAccessProfile(item)
	if err != nil {
		writeSSHBrokerError(w, http.StatusBadRequest, err.Error())
		return
	}
	WriteJSON(w, http.StatusOK, item)
}

// DeleteSSHAccessProfileForSelf deletes a profile the calling user owns;
// visible-but-not-owned profiles yield a bare 403.
func (api *API) DeleteSSHAccessProfileForSelf(w http.ResponseWriter, r *http.Request, params map[string]string) {
	var user models.User
	var ok bool
	var item models.SSHAccessProfile
	var err error
	user, ok = sshBrokerSelfUser(w, r)
	if !ok {
		return
	}
	item, err = api.store(r).GetSSHAccessProfileForUser(user.ID, params["id"])
	if err != nil {
		writeSSHBrokerStoreError(w, err, "ssh access profile not found")
		return
	}
	if item.OwnerScope != "user" || item.OwnerUserID != user.ID {
		w.WriteHeader(http.StatusForbidden)
		return
	}
	err = api.store(r).DeleteSSHAccessProfile(item.ID)
	if err != nil {
		writeSSHBrokerError(w, http.StatusInternalServerError, err.Error())
		return
	}
	w.WriteHeader(http.StatusNoContent)
}

// ListSSHServerHostKeysAdmin lists pinned host keys of any server. Admin only.
func (api *API) ListSSHServerHostKeysAdmin(w http.ResponseWriter, r *http.Request, params map[string]string) {
	var item models.SSHServer
	var items []models.SSHServerHostKey
	var err error
	if !api.requireAdmin(w, r) {
		return
	}
	item, err = api.store(r).GetSSHServer(params["id"])
	if err != nil {
		writeSSHBrokerStoreError(w, err, "ssh server not found")
		return
	}
	items, err = api.store(r).ListSSHServerHostKeys(item.ID)
	if err != nil {
		writeSSHBrokerError(w, http.StatusInternalServerError, err.Error())
		return
	}
	WriteJSON(w, http.StatusOK, items)
}

// DiscoverSSHServerHostKeyAdmin fetches the live host key of any server without
// pinning it. Discovery failures map to 502. Admin only.
func (api *API) DiscoverSSHServerHostKeyAdmin(w http.ResponseWriter, r *http.Request, params map[string]string) {
	var item models.SSHServer
	var discovered models.SSHServerHostKey
	var err error
	if !api.requireAdmin(w, r) {
		return
	}
	item, err = api.store(r).GetSSHServer(params["id"])
	if err != nil {
		writeSSHBrokerStoreError(w, err, "ssh server not found")
		return
	}
	discovered, err = discoverSSHServerHostKey(item.Host, item.Port)
	if err != nil {
		writeSSHBrokerError(w, http.StatusBadGateway, err.Error())
		return
	}
	discovered.ServerID = item.ID
	WriteJSON(w, http.StatusOK, discovered)
}

// CreateSSHServerHostKeyAdmin parses and pins a host public key on any server.
// Admin only.
func (api *API) CreateSSHServerHostKeyAdmin(w http.ResponseWriter, r *http.Request, params map[string]string) {
	var item models.SSHServer
	var req sshServerHostKeyRequest
	var hostKey models.SSHServerHostKey
	var err error
	if !api.requireAdmin(w, r) {
		return
	}
	item, err = api.store(r).GetSSHServer(params["id"])
	if err != nil {
		writeSSHBrokerStoreError(w, err, "ssh server not found")
		return
	}
	err = DecodeJSON(r, &req)
	if err != nil {
		writeSSHBrokerError(w, http.StatusBadRequest, "invalid request")
		return
	}
	hostKey, err = parseSSHServerHostKey(strings.TrimSpace(req.PublicKey))
	if err != nil {
		writeSSHBrokerError(w, http.StatusBadRequest, err.Error())
		return
	}
	hostKey.ServerID = item.ID
	hostKey, err = api.store(r).CreateSSHServerHostKey(hostKey)
	if err != nil {
		writeSSHBrokerError(w, http.StatusBadRequest, err.Error())
		return
	}
	WriteJSON(w, http.StatusCreated, hostKey)
}

// DeleteSSHServerHostKeyAdmin removes pinned key params["hostKeyId"] from server
// params["id"]. Admin only.
func (api *API) DeleteSSHServerHostKeyAdmin(w http.ResponseWriter, r *http.Request, params map[string]string) {
	var err error
	if !api.requireAdmin(w, r) {
		return
	}
	err = api.store(r).DeleteSSHServerHostKey(params["id"], params["hostKeyId"])
	if err != nil {
		writeSSHBrokerStoreError(w, err, "ssh server host key not found")
		return
	}
	w.WriteHeader(http.StatusNoContent)
}

// CreateSSHSessionForSelf records a "pending" session for a profile visible to
// the calling user and returns the WebSocket path to attach to. The server
// must have at least one pinned host key; the first one's fingerprint is
// reported in the response.
//
// NOTE(review): this does not check profile/server enabled flags; confirm the
// store lookup or the stream handler enforces that.
func (api *API) CreateSSHSessionForSelf(w http.ResponseWriter, r *http.Request, params map[string]string) {
	var user models.User
	var ok bool
	var req sshSessionConnectRequest
	var profile models.SSHAccessProfile
	var item models.SSHSession
	var response sshSessionConnectResponse
	var hostKeys []models.SSHServerHostKey
	var err error
	user, ok = sshBrokerSelfUser(w, r)
	if !ok {
		return
	}
	err = DecodeJSON(r, &req)
	if err != nil {
		writeSSHBrokerError(w, http.StatusBadRequest, "invalid request")
		return
	}
	profile, err = api.store(r).GetSSHAccessProfileForUser(user.ID, params["id"])
	if err != nil {
		writeSSHBrokerStoreError(w, err, "ssh access profile not found")
		return
	}
	hostKeys, err = api.store(r).ListSSHServerHostKeys(profile.ServerID)
	if err != nil {
		writeSSHBrokerError(w, http.StatusInternalServerError, err.Error())
		return
	}
	if len(hostKeys) == 0 {
		writeSSHBrokerError(w, http.StatusBadRequest, "no pinned host key for this server")
		return
	}
	item = models.SSHSession{
		ProfileID:      profile.ID,
		ServerID:       profile.ServerID,
		UserID:         user.ID,
		Username:       user.Username,
		RemoteUsername: profile.RemoteUsername,
		AuthMethod:     profile.AuthMethod,
		Host:           profile.Server.Host,
		Port:           profile.Server.Port,
		Status:         "pending",
		RequestedTerm:  normalizeSSHSessionTerm(req.Term),
		RequestedCols:  normalizeSSHSessionCols(req.Cols),
		RequestedRows:  normalizeSSHSessionRows(req.Rows),
		RemoteAddr:     requestRemoteAddr(r),
		UserAgent:      strings.TrimSpace(r.UserAgent()),
	}
	item, err = api.store(r).CreateSSHSession(item)
	if err != nil {
		writeSSHBrokerError(w, http.StatusInternalServerError, err.Error())
		return
	}
	response = sshSessionConnectResponse{
		SessionID:          item.ID,
		Status:             item.Status,
		WebSocketPath:      "/api/ssh/sessions/" + item.ID + "/stream",
		ServerName:         profile.Server.Name,
		Host:               profile.Server.Host,
		Port:               profile.Server.Port,
		RemoteUsername:     profile.RemoteUsername,
		HostKeyFingerprint: hostKeys[0].Fingerprint,
	}
	WriteJSON(w, http.StatusCreated, response)
}

// GetSSHSessionForSelf returns a session record belonging to the calling user;
// another user's session yields a bare 403.
func (api *API) GetSSHSessionForSelf(w http.ResponseWriter, r *http.Request, params map[string]string) {
	var user models.User
	var ok bool
	var item models.SSHSession
	var err error
	user, ok = sshBrokerSelfUser(w, r)
	if !ok {
		return
	}
	item, err = api.store(r).GetSSHSession(params["id"])
	if err != nil {
		writeSSHBrokerStoreError(w, err, "ssh session not found")
		return
	}
	if item.UserID != user.ID {
		w.WriteHeader(http.StatusForbidden)
		return
	}
	WriteJSON(w, http.StatusOK, item)
}

// ListSSHSessionsForSelf returns up to 50 of the calling user's sessions.
func (api *API) ListSSHSessionsForSelf(w http.ResponseWriter, r *http.Request, _ map[string]string) {
	var user models.User
	var ok bool
	var items []models.SSHSession
	var err error
	user, ok = sshBrokerSelfUser(w, r)
	if !ok {
		return
	}
	items, err = api.store(r).ListSSHSessionsForUser(user.ID, 50)
	if err != nil {
		writeSSHBrokerError(w, http.StatusInternalServerError, err.Error())
		return
	}
	WriteJSON(w, http.StatusOK, items)
}

// DisconnectSSHSessionForSelf asks the session registry to close one of the
// caller's sessions. If the registry is absent or RequestClose reports false
// (presumably: no live stream for this ID), the stored record is marked
// "closed" directly with the current UTC Unix time as its end.
//
// NOTE(review): the fallback also rewrites the end time/reason of sessions that
// already ended; confirm whether terminal sessions should be left untouched.
func (api *API) DisconnectSSHSessionForSelf(w http.ResponseWriter, r *http.Request, params map[string]string) {
	var user models.User
	var ok bool
	var item models.SSHSession
	var closed bool
	var now int64
	var err error
	user, ok = sshBrokerSelfUser(w, r)
	if !ok {
		return
	}
	item, err = api.store(r).GetSSHSession(params["id"])
	if err != nil {
		writeSSHBrokerStoreError(w, err, "ssh session not found")
		return
	}
	if item.UserID != user.ID {
		w.WriteHeader(http.StatusForbidden)
		return
	}
	if api.SSHSessionRegistry != nil {
		closed = api.SSHSessionRegistry.RequestClose(item.ID, "disconnected by user")
	}
	now = time.Now().UTC().Unix()
	if !closed {
		err = api.store(r).UpdateSSHSessionStatus(item.ID, "closed", item.HostKeyFingerprint, item.ConnectedAt, now, "disconnected by user")
		if err != nil {
			writeSSHBrokerError(w, http.StatusInternalServerError, err.Error())
			return
		}
	}
	w.WriteHeader(http.StatusNoContent)
}

// StreamSSHSessionForSelf upgrades the request to a WebSocket and hands it to
// serveSSHSessionStream for a session owned by the calling user. Errors before
// the upgrade are reported with bare status codes.
//
// NOTE(review): websocket.Handler only checks that the Origin header parses as
// a URL; it does not enforce same-origin. If auth is cookie-based, confirm
// cross-site WebSocket hijacking is prevented elsewhere.
func (api *API) StreamSSHSessionForSelf(w http.ResponseWriter, r *http.Request, params map[string]string) {
	var user models.User
	var ok bool
	var item models.SSHSession
	var err error
	user, ok = sshBrokerSelfUser(w, r)
	if !ok {
		return
	}
	item, err = api.store(r).GetSSHSession(params["id"])
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			w.WriteHeader(http.StatusNotFound)
			return
		}
		w.WriteHeader(http.StatusInternalServerError)
		return
	}
	if item.UserID != user.ID {
		w.WriteHeader(http.StatusForbidden)
		return
	}
	websocket.Handler(func(ws *websocket.Conn) {
		api.serveSSHSessionStream(ws, r, user, item)
	}).ServeHTTP(w, r)
}

func (api *API) normalizeSSHAccessProfileRequest(r *http.Request, req sshAccessProfileRequest, existing models.SSHAccessProfile, actorKind string, actorID string, actorName string, isUpdate bool, options sshAccessProfileNormalizeOptions) (models.SSHAccessProfile, error) {
	var item models.SSHAccessProfile
	var err error
	var targets []models.SSHAccessProfileTarget
	var privateKeyPEM string
	var signer ssh.Signer
	var publicKey string
	var fingerprint string
	var generatedKeyPEM string
	var generatedPublicKey string
	var generatedFingerprint string
	var activeGrants []models.SSHPrincipalGrant
	var now int64
	var allowedGrantIDs map[string]bool
	var i int
	item = existing
	req.ServerID = strings.TrimSpace(req.ServerID)
	req.Name = strings.TrimSpace(req.Name)
	req.Description = strings.TrimSpace(req.Description)
	req.RemoteUsername = strings.TrimSpace(req.RemoteUsername)
	req.AuthMethod =
normalizeSSHAccessAuthMethod(req.AuthMethod) req.SSHUserCAID = strings.TrimSpace(req.SSHUserCAID) req.SSHPrincipalMode = normalizeSSHAccessPrincipalMode(req.SSHPrincipalMode) privateKeyPEM = strings.TrimSpace(req.PrivateKeyPEM) targets, err = normalizeSSHAccessProfileTargets(req.Targets, !options.RequireTargets) if err != nil { return item, err } if req.Name == "" { return item, errors.New("name is required") } if req.ServerID == "" { return item, errors.New("server_id is required") } if req.RemoteUsername == "" { return item, errors.New("remote_username is required") } if req.AuthMethod == "" { return item, errors.New("auth_method must be managed_ssh_cert or stored_private_key") } if req.AuthMethod == "managed_ssh_cert" { if req.SSHUserCAID == "" { return item, errors.New("ssh_user_ca_id is required for managed_ssh_cert") } if options.SelfService { _, err = api.store(r).GetSSHUserCAForUser(req.SSHUserCAID) } else { _, err = api.store(r).GetSSHUserCA(req.SSHUserCAID) } if err != nil { if err == sql.ErrNoRows { return item, errors.New("ssh_user_ca_id not found") } return item, err } if req.SSHPrincipalMode == "grant" { req.SSHPrincipalGrantIDs = normalizeSSHBrokerNames(req.SSHPrincipalGrantIDs) if len(req.SSHPrincipalGrantIDs) == 0 { return item, errors.New("at least one ssh_principal_grant_id is required in grant mode") } if options.SelfService { now = time.Now().UTC().Unix() activeGrants, err = api.store(r).ListActiveSSHPrincipalGrantsForUser(options.OwnerUserID, now) if err != nil { return item, err } allowedGrantIDs = map[string]bool{} for i = 0; i < len(activeGrants); i++ { allowedGrantIDs[activeGrants[i].ID] = true } for i = 0; i < len(req.SSHPrincipalGrantIDs); i++ { if !allowedGrantIDs[req.SSHPrincipalGrantIDs[i]] { return item, errors.New("ssh_principal_grant_id not available to user: " + req.SSHPrincipalGrantIDs[i]) } } } } else { req.SSHPrincipals = normalizeSSHBrokerNames(req.SSHPrincipals) if len(req.SSHPrincipals) == 0 { return item, errors.New("at 
least one ssh_principal is required in explicit mode") } } if privateKeyPEM == "" && !isUpdate { generatedKeyPEM, generatedPublicKey, generatedFingerprint, err = generateSSHUserCAKeyPair("ed25519") if err != nil { return item, err } privateKeyPEM = strings.TrimSpace(generatedKeyPEM) publicKey = strings.TrimSpace(generatedPublicKey) fingerprint = strings.TrimSpace(generatedFingerprint) } } else { req.SSHUserCAID = "" req.SSHPrincipalMode = "explicit" req.SSHPrincipals = []string{} req.SSHPrincipalGrantIDs = []string{} if privateKeyPEM == "" && !isUpdate { return item, errors.New("private_key_pem is required for stored_private_key") } } if privateKeyPEM != "" && publicKey == "" { signer, err = parseSSHSignerFromPEM(privateKeyPEM) if err != nil { return item, err } publicKey = strings.TrimSpace(string(ssh.MarshalAuthorizedKey(signer.PublicKey()))) fingerprint = strings.TrimSpace(ssh.FingerprintSHA256(signer.PublicKey())) } if req.DefaultValidSeconds <= 0 { req.DefaultValidSeconds = 3600 } if req.MaxValidSeconds <= 0 { req.MaxValidSeconds = req.DefaultValidSeconds } if req.MaxValidSeconds < req.DefaultValidSeconds { return item, errors.New("max_valid_seconds must be greater than or equal to default_valid_seconds") } if options.SelfService { _, err = api.store(r).GetOwnedSSHServerForUser(options.OwnerUserID, req.ServerID) if err != nil { if err == sql.ErrNoRows { return item, errors.New("server_id not found") } return item, err } } else { _, err = api.store(r).GetSSHServer(req.ServerID) if err != nil { if err == sql.ErrNoRows { return item, errors.New("server_id not found") } return item, err } } err = validateSSHAccessProfileTargets(api.store(r), targets) if err != nil { return item, err } item.ServerID = req.ServerID item.Name = req.Name item.Description = req.Description item.RemoteUsername = req.RemoteUsername item.AuthMethod = req.AuthMethod item.OwnerScope = options.OwnerScope item.OwnerUserID = options.OwnerUserID item.AllowUserEdit = options.AllowUserEdit 
item.Enabled = req.Enabled item.SSHUserCAID = req.SSHUserCAID item.SSHPrincipalMode = req.SSHPrincipalMode item.SSHPrincipals = req.SSHPrincipals item.SSHPrincipalGrantIDs = req.SSHPrincipalGrantIDs item.DefaultValidSeconds = req.DefaultValidSeconds item.MaxValidSeconds = req.MaxValidSeconds item.Targets = targets if !isUpdate { item.CreatedByKind = actorKind item.CreatedBySubjectID = actorID item.CreatedBySubjectName = actorName } if privateKeyPEM != "" { item.SecretPayload, err = api.encryptSSHSecretPayload(privateKeyPEM) if err != nil { return item, err } item.AuthPublicKey = publicKey item.AuthPublicKeyFingerprint = fingerprint } return item, nil } func normalizeSSHAccessAuthMethod(raw string) string { var value string value = strings.ToLower(strings.TrimSpace(raw)) if value == "managed_ssh_cert" || value == "stored_private_key" { return value } return "" } func normalizeSSHAccessPrincipalMode(raw string) string { var value string value = strings.ToLower(strings.TrimSpace(raw)) if value == "grant" { return "grant" } return "explicit" } func normalizeSSHBrokerTags(raw []string) []string { var out []string out = db.NormalizeSSHServerTags(raw) return out } func normalizeSSHBrokerNames(raw []string) []string { var out []string var seen map[string]bool var value string var i int seen = map[string]bool{} for i = 0; i < len(raw); i++ { value = strings.TrimSpace(raw[i]) if value == "" { continue } if seen[value] { continue } seen[value] = true out = append(out, value) } sort.Strings(out) return out } func normalizeSSHAccessProfileTargets(raw []sshAccessProfileTargetRequest, allowEmpty bool) ([]models.SSHAccessProfileTarget, error) { var out []models.SSHAccessProfileTarget var dedupe map[string]bool var targetType string var targetID string var key string var i int dedupe = map[string]bool{} for i = 0; i < len(raw); i++ { targetType = strings.ToLower(strings.TrimSpace(raw[i].TargetType)) targetID = strings.TrimSpace(raw[i].TargetID) if targetType != "user" && targetType 
!= "group" { return nil, errors.New("target_type must be user or group") } if targetID == "" { return nil, errors.New("target_id is required") } key = targetType + ":" + targetID if dedupe[key] { continue } dedupe[key] = true out = append(out, models.SSHAccessProfileTarget{TargetType: targetType, TargetID: targetID}) } if len(out) == 0 && !allowEmpty { return nil, errors.New("at least one target is required") } return out, nil } func validateSSHAccessProfileTargets(store *db.Store, targets []models.SSHAccessProfileTarget) error { var i int var err error if len(targets) == 0 { return nil } for i = 0; i < len(targets); i++ { if targets[i].TargetType == "user" { _, err = store.GetUserByID(targets[i].TargetID) } else { _, err = store.GetUserGroup(targets[i].TargetID) } if err != nil { if err == sql.ErrNoRows { return errors.New(targets[i].TargetType + " target not found: " + targets[i].TargetID) } return err } } return nil } func currentSSHBrokerActor(r *http.Request) (string, string, string) { var user models.User var principal models.ServicePrincipal var ok bool user, ok = middleware.UserFromContext(r.Context()) if ok { return "user", user.ID, user.Username } principal, ok = middleware.PrincipalFromContext(r.Context()) if ok { return "service_principal", principal.ID, principal.Name } return "system", "", "system" } func normalizeSSHSessionTerm(raw string) string { var value string value = strings.TrimSpace(raw) if value == "" { return "xterm-256color" } return value } func normalizeSSHSessionCols(value int) int { if value <= 0 { return 120 } if value > 400 { return 400 } return value } func normalizeSSHSessionRows(value int) int { if value <= 0 { return 36 } if value > 200 { return 200 } return value } func parseSSHServerHostKey(raw string) (models.SSHServerHostKey, error) { var item models.SSHServerHostKey var key ssh.PublicKey var err error key, _, _, _, err = ssh.ParseAuthorizedKey([]byte(strings.TrimSpace(raw))) if err != nil { return item, errors.New("invalid 
ssh public key") } item.Algorithm = key.Type() item.PublicKey = strings.TrimSpace(string(ssh.MarshalAuthorizedKey(key))) item.Fingerprint = strings.TrimSpace(ssh.FingerprintSHA256(key)) return item, nil } func discoverSSHServerHostKey(host string, port int) (models.SSHServerHostKey, error) { var item models.SSHServerHostKey var addr string var conn net.Conn var clientConn ssh.Conn var reqs <-chan *ssh.Request var config *ssh.ClientConfig var discovered ssh.PublicKey var err error addr = net.JoinHostPort(strings.TrimSpace(host), fmt.Sprintf("%d", port)) config = &ssh.ClientConfig{ User: "codit", Auth: []ssh.AuthMethod{}, HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error { discovered = key return nil }, Timeout: 10 * time.Second, } conn, err = net.DialTimeout("tcp", addr, 10*time.Second) if err != nil { return item, err } defer conn.Close() clientConn, _, reqs, err = ssh.NewClientConn(conn, addr, config) if clientConn != nil { clientConn.Close() } if reqs != nil { go ssh.DiscardRequests(reqs) } if discovered == nil { return item, errors.New("failed to discover ssh host key") } item.Algorithm = discovered.Type() item.PublicKey = strings.TrimSpace(string(ssh.MarshalAuthorizedKey(discovered))) item.Fingerprint = strings.TrimSpace(ssh.FingerprintSHA256(discovered)) if err != nil { if strings.Contains(strings.ToLower(err.Error()), "unable to authenticate") { return item, nil } if strings.Contains(strings.ToLower(err.Error()), "no auth passed yet") { return item, nil } } return item, nil } func (api *API) serveSSHSessionStream(ws *websocket.Conn, r *http.Request, user models.User, sessionItem models.SSHSession) { var profile models.SSHAccessProfile var authMethods []ssh.AuthMethod var hostKeys []models.SSHServerHostKey var sshConfig *ssh.ClientConfig var sshClient *ssh.Client var sshSession *ssh.Session var stdin io.WriteCloser var stdout io.Reader var stderr io.Reader var runtimeSession *sshActiveSession var closeReason string var err error 
var connectedFingerprint string var connectedAt int64 var now int64 var sendMu sync.Mutex var inputCh chan sshSessionStreamMessage var inputErrCh chan error var shellErrCh chan error var done bool var waitErr error defer ws.Close() runtimeSession = &sshActiveSession{} runtimeSession.SetResources(ws, nil, nil, nil) if api.SSHSessionRegistry != nil { err = api.SSHSessionRegistry.Register(sessionItem.ID, runtimeSession) if err != nil { api.sendSSHSessionWS(&sendMu, ws, sshSessionStreamMessage{Type: "status", Status: "error", Message: err.Error()}) now = time.Now().UTC().Unix() _ = api.store(r).UpdateSSHSessionStatus(sessionItem.ID, "error", sessionItem.HostKeyFingerprint, sessionItem.ConnectedAt, now, err.Error()) return } defer api.SSHSessionRegistry.Unregister(sessionItem.ID, runtimeSession) } profile, err = api.store(r).GetSSHAccessProfileForUser(user.ID, sessionItem.ProfileID) if err != nil { api.logSSHSessionPrepareFailure(sessionItem.ID, sessionItem.RemoteAddr, user.Username, sessionItem.ProfileID, "load_profile", err) api.sendSSHSessionWS(&sendMu, ws, sshSessionStreamMessage{Type: "status", Status: "error", Message: "ssh access profile not found"}) now = time.Now().UTC().Unix() _ = api.store(r).UpdateSSHSessionStatus(sessionItem.ID, "error", sessionItem.HostKeyFingerprint, sessionItem.ConnectedAt, now, "ssh access profile not found") return } api.logSSHSessionPrepare(sessionItem.ID, sessionItem.RemoteAddr, user, profile) hostKeys, err = api.store(r).ListSSHServerHostKeys(profile.ServerID) if err != nil { api.logSSHSessionConnectFailure(sessionItem, profile, user, "load_host_keys", err) api.sendSSHSessionWS(&sendMu, ws, sshSessionStreamMessage{Type: "status", Status: "error", Message: err.Error()}) now = time.Now().UTC().Unix() _ = api.store(r).UpdateSSHSessionStatus(sessionItem.ID, "error", sessionItem.HostKeyFingerprint, sessionItem.ConnectedAt, now, err.Error()) return } if len(hostKeys) == 0 { api.logSSHSessionConnectFailure(sessionItem, profile, user, 
"host_key_missing", errors.New("no pinned host key for this server")) api.sendSSHSessionWS(&sendMu, ws, sshSessionStreamMessage{Type: "status", Status: "error", Message: "no pinned host key for this server"}) now = time.Now().UTC().Unix() _ = api.store(r).UpdateSSHSessionStatus(sessionItem.ID, "error", sessionItem.HostKeyFingerprint, sessionItem.ConnectedAt, now, "no pinned host key") return } authMethods, err = api.buildSSHSessionAuth(profile) if err != nil { api.logSSHSessionConnectFailure(sessionItem, profile, user, "build_auth", err) api.sendSSHSessionWS(&sendMu, ws, sshSessionStreamMessage{Type: "status", Status: "error", Message: err.Error()}) now = time.Now().UTC().Unix() _ = api.store(r).UpdateSSHSessionStatus(sessionItem.ID, "error", sessionItem.HostKeyFingerprint, sessionItem.ConnectedAt, now, err.Error()) return } closeReason = runtimeSession.CloseReason() if closeReason != "" { now = time.Now().UTC().Unix() _ = api.store(r).UpdateSSHSessionStatus(sessionItem.ID, "closed", "", 0, now, closeReason) api.logSSHSessionClosed(sessionItem, profile, user, "closed", closeReason, 0, now) api.sendSSHSessionWS(&sendMu, ws, sshSessionStreamMessage{Type: "status", Status: "closed", Message: closeReason}) return } api.logSSHSessionConnectStart(sessionItem, profile, user) now = time.Now().UTC().Unix() _ = api.store(r).UpdateSSHSessionStatus(sessionItem.ID, "connecting", "", 0, 0, "") api.sendSSHSessionWS(&sendMu, ws, sshSessionStreamMessage{Type: "status", Status: "connecting", Message: "connecting"}) sshConfig = &ssh.ClientConfig{ User: profile.RemoteUsername, Auth: authMethods, HostKeyCallback: sshHostKeyCallback(hostKeys, &connectedFingerprint), Timeout: 15 * time.Second, } sshClient, err = ssh.Dial("tcp", net.JoinHostPort(profile.Server.Host, fmt.Sprintf("%d", profile.Server.Port)), sshConfig) if err != nil { api.logSSHSessionConnectFailure(sessionItem, profile, user, "dial", err) api.sendSSHSessionWS(&sendMu, ws, sshSessionStreamMessage{Type: "status", Status: 
"error", Message: err.Error()}) now = time.Now().UTC().Unix() _ = api.store(r).UpdateSSHSessionStatus(sessionItem.ID, "error", connectedFingerprint, 0, now, err.Error()) return } defer sshClient.Close() runtimeSession.SetResources(ws, sshClient, nil, nil) sshSession, err = sshClient.NewSession() if err != nil { api.logSSHSessionConnectFailure(sessionItem, profile, user, "new_session", err) api.sendSSHSessionWS(&sendMu, ws, sshSessionStreamMessage{Type: "status", Status: "error", Message: err.Error()}) now = time.Now().UTC().Unix() _ = api.store(r).UpdateSSHSessionStatus(sessionItem.ID, "error", connectedFingerprint, 0, now, err.Error()) return } defer sshSession.Close() runtimeSession.SetResources(ws, sshClient, sshSession, nil) stdin, err = sshSession.StdinPipe() if err != nil { api.logSSHSessionConnectFailure(sessionItem, profile, user, "stdin_pipe", err) api.sendSSHSessionWS(&sendMu, ws, sshSessionStreamMessage{Type: "status", Status: "error", Message: err.Error()}) now = time.Now().UTC().Unix() _ = api.store(r).UpdateSSHSessionStatus(sessionItem.ID, "error", connectedFingerprint, 0, now, err.Error()) return } stdout, err = sshSession.StdoutPipe() if err != nil { api.logSSHSessionConnectFailure(sessionItem, profile, user, "stdout_pipe", err) api.sendSSHSessionWS(&sendMu, ws, sshSessionStreamMessage{Type: "status", Status: "error", Message: err.Error()}) now = time.Now().UTC().Unix() _ = api.store(r).UpdateSSHSessionStatus(sessionItem.ID, "error", connectedFingerprint, 0, now, err.Error()) return } stderr, err = sshSession.StderrPipe() if err != nil { api.logSSHSessionConnectFailure(sessionItem, profile, user, "stderr_pipe", err) api.sendSSHSessionWS(&sendMu, ws, sshSessionStreamMessage{Type: "status", Status: "error", Message: err.Error()}) now = time.Now().UTC().Unix() _ = api.store(r).UpdateSSHSessionStatus(sessionItem.ID, "error", connectedFingerprint, 0, now, err.Error()) return } runtimeSession.SetResources(ws, sshClient, sshSession, stdin) closeReason = 
runtimeSession.CloseReason() if closeReason != "" { now = time.Now().UTC().Unix() _ = api.store(r).UpdateSSHSessionStatus(sessionItem.ID, "closed", connectedFingerprint, 0, now, closeReason) api.sendSSHSessionWS(&sendMu, ws, sshSessionStreamMessage{Type: "status", Status: "closed", Message: closeReason}) return } err = sshSession.RequestPty(sessionItem.RequestedTerm, sessionItem.RequestedRows, sessionItem.RequestedCols, ssh.TerminalModes{ ssh.ECHO: 1, ssh.TTY_OP_ISPEED: 14400, ssh.TTY_OP_OSPEED: 14400, }) if err != nil { api.logSSHSessionConnectFailure(sessionItem, profile, user, "request_pty", err) api.sendSSHSessionWS(&sendMu, ws, sshSessionStreamMessage{Type: "status", Status: "error", Message: err.Error()}) now = time.Now().UTC().Unix() _ = api.store(r).UpdateSSHSessionStatus(sessionItem.ID, "error", connectedFingerprint, 0, now, err.Error()) return } err = sshSession.Shell() if err != nil { api.logSSHSessionConnectFailure(sessionItem, profile, user, "shell", err) api.sendSSHSessionWS(&sendMu, ws, sshSessionStreamMessage{Type: "status", Status: "error", Message: err.Error()}) now = time.Now().UTC().Unix() _ = api.store(r).UpdateSSHSessionStatus(sessionItem.ID, "error", connectedFingerprint, 0, now, err.Error()) return } connectedAt = time.Now().UTC().Unix() _ = api.store(r).UpdateSSHSessionStatus(sessionItem.ID, "connected", connectedFingerprint, connectedAt, 0, "") api.logSSHSessionConnectSuccess(sessionItem, profile, user, connectedFingerprint) api.sendSSHSessionWS(&sendMu, ws, sshSessionStreamMessage{Type: "status", Status: "connected", Message: connectedFingerprint}) inputCh = make(chan sshSessionStreamMessage) inputErrCh = make(chan error, 1) shellErrCh = make(chan error, 1) go api.readSSHSessionWebsocket(ws, inputCh, inputErrCh) go api.pipeSSHSessionOutput(stdout, ws, &sendMu) go api.pipeSSHSessionOutput(stderr, ws, &sendMu) go func() { var shellErr error shellErr = sshSession.Wait() shellErrCh <- shellErr }() done = false waitErr = nil for !done { select 
{ case err = <-inputErrCh: if err != nil { waitErr = nil } done = true case waitErr = <-shellErrCh: done = true case msg := <-inputCh: if msg.Type == "input" { _, err = io.WriteString(stdin, msg.Data) if err != nil { waitErr = err done = true } } else if msg.Type == "resize" { _ = sshSession.WindowChange(normalizeSSHSessionRows(msg.Rows), normalizeSSHSessionCols(msg.Cols)) } } } _ = stdin.Close() now = time.Now().UTC().Unix() closeReason = runtimeSession.CloseReason() if closeReason != "" { _ = api.store(r).UpdateSSHSessionStatus(sessionItem.ID, "closed", connectedFingerprint, connectedAt, now, closeReason) api.logSSHSessionClosed(sessionItem, profile, user, "closed", closeReason, connectedAt, now) api.sendSSHSessionWS(&sendMu, ws, sshSessionStreamMessage{Type: "status", Status: "closed", Message: closeReason}) return } if waitErr != nil && !errors.Is(waitErr, io.EOF) { _ = api.store(r).UpdateSSHSessionStatus(sessionItem.ID, "error", connectedFingerprint, connectedAt, now, waitErr.Error()) api.Logger.Write(logIDSSHBroker, util.LOG_WARN, "session error session=%s requester=%s user=%s profile=%s profile_name=%q server=%s server_name=%q target=%s:%d remote_user=%s auth=%s dur_s=%d err=%v", sessionItem.ID, sessionItem.RemoteAddr, user.Username, profile.ID, profile.Name, profile.ServerID, profile.Server.Name, profile.Server.Host, profile.Server.Port, profile.RemoteUsername, profile.AuthMethod, now-connectedAt, waitErr) api.sendSSHSessionWS(&sendMu, ws, sshSessionStreamMessage{Type: "status", Status: "error", Message: waitErr.Error()}) return } _ = api.store(r).UpdateSSHSessionStatus(sessionItem.ID, "closed", connectedFingerprint, connectedAt, now, "") api.logSSHSessionClosed(sessionItem, profile, user, "closed", "", connectedAt, now) api.sendSSHSessionWS(&sendMu, ws, sshSessionStreamMessage{Type: "status", Status: "closed", Message: ""}) } func (api *API) buildSSHSessionAuth(profile models.SSHAccessProfile) ([]ssh.AuthMethod, error) { var secret models.SSHSecret var 
privateKeyPEM string var signer ssh.Signer var ca models.SSHUserCA var signed sshSignUserKeyResponse var certKey ssh.PublicKey var cert *ssh.Certificate var certSigner ssh.Signer var principals []string var err error if strings.TrimSpace(profile.SecretID) == "" { return nil, errors.New("ssh access profile has no secret") } secret, err = api.Store.GetSSHSecret(profile.SecretID) if err != nil { return nil, err } privateKeyPEM, err = api.decryptSSHSecretPayload(secret.Payload) if err != nil { return nil, err } signer, err = parseSSHSignerFromPEM(privateKeyPEM) if err != nil { return nil, err } if profile.AuthMethod == "stored_private_key" { return []ssh.AuthMethod{ssh.PublicKeys(signer)}, nil } ca, err = api.Store.GetSSHUserCA(profile.SSHUserCAID) if err != nil { return nil, err } principals, err = api.resolveSSHAccessProfilePrincipals(profile) if err != nil { return nil, err } signed, err = api.signSSHUserKeyWithCA(api.Store, ca, profile.AuthPublicKey, "", principals, profile.DefaultValidSeconds, profile.MaxValidSeconds) if err != nil { return nil, err } certKey, _, _, _, err = ssh.ParseAuthorizedKey([]byte(strings.TrimSpace(signed.Certificate))) if err != nil { return nil, err } cert, _ = certKey.(*ssh.Certificate) if cert == nil { return nil, errors.New("signed ssh certificate is invalid") } certSigner, err = ssh.NewCertSigner(cert, signer) if err != nil { return nil, err } return []ssh.AuthMethod{ssh.PublicKeys(certSigner)}, nil } func (api *API) resolveSSHAccessProfilePrincipals(profile models.SSHAccessProfile) ([]string, error) { var principals []string var grant models.SSHPrincipalGrant var seen map[string]bool var i int var j int var err error if profile.SSHPrincipalMode != "grant" { return normalizeSSHBrokerNames(profile.SSHPrincipals), nil } seen = map[string]bool{} for i = 0; i < len(profile.SSHPrincipalGrantIDs); i++ { grant, err = api.Store.GetSSHPrincipalGrant(profile.SSHPrincipalGrantIDs[i]) if err != nil { return nil, err } for j = 0; j < 
len(grant.Principals); j++ { if seen[grant.Principals[j]] { continue } seen[grant.Principals[j]] = true principals = append(principals, grant.Principals[j]) } } sort.Strings(principals) if len(principals) == 0 { return nil, errors.New("no ssh principals resolved for access profile") } return principals, nil } func sshHostKeyCallback(items []models.SSHServerHostKey, connectedFingerprint *string) ssh.HostKeyCallback { var allowed map[string]bool var i int allowed = map[string]bool{} for i = 0; i < len(items); i++ { allowed[strings.TrimSpace(items[i].Fingerprint)] = true } return func(hostname string, remote net.Addr, key ssh.PublicKey) error { var fingerprint string fingerprint = strings.TrimSpace(ssh.FingerprintSHA256(key)) if connectedFingerprint != nil { *connectedFingerprint = fingerprint } if allowed[fingerprint] { return nil } return errors.New("host key mismatch: " + fingerprint) } } func (api *API) pipeSSHSessionOutput(reader io.Reader, ws *websocket.Conn, sendMu *sync.Mutex) { var buffer []byte var chunk []byte var n int var err error buffer = make([]byte, 4096) for { n, err = reader.Read(buffer) if n > 0 { chunk = make([]byte, n) copy(chunk, buffer[:n]) api.sendSSHSessionBinary(sendMu, ws, chunk) } if err != nil { return } } } func (api *API) readSSHSessionWebsocket(ws *websocket.Conn, inputCh chan<- sshSessionStreamMessage, errCh chan<- error) { var msg sshSessionStreamMessage var err error for { msg = sshSessionStreamMessage{} err = websocket.JSON.Receive(ws, &msg) if err != nil { errCh <- err return } inputCh <- msg } } func (api *API) sendSSHSessionWS(mu *sync.Mutex, ws *websocket.Conn, msg sshSessionStreamMessage) { if mu == nil { return } mu.Lock() defer mu.Unlock() _ = websocket.JSON.Send(ws, msg) } func (api *API) sendSSHSessionBinary(mu *sync.Mutex, ws *websocket.Conn, data []byte) { if mu == nil { return } mu.Lock() defer mu.Unlock() _ = websocket.Message.Send(ws, data) }