package handlers import codit_logger "codit/logger" import "database/sql" import "encoding/base64" import "errors" import "fmt" import "io" import "net" import "net/http" import "sort" import "strconv" import "strings" import "sync" import "time" import "golang.org/x/crypto/ssh" import "golang.org/x/net/websocket" import "codit/internal/db" import "codit/internal/middleware" import "codit/internal/models" const logIDSSHBroker string = "ssh-broker" const sshWebsocketWriteTimeout time.Duration = 5 * time.Second const sshSessionInputBufferSize int = 64 type sshServerRequest struct { Name string `json:"name"` Host string `json:"host"` Port int `json:"port"` Description string `json:"description"` Tags []string `json:"tags"` Enabled bool `json:"enabled"` HostKeyPolicy string `json:"host_key_policy"` } type sshCredentialRequest struct { Name string `json:"name"` Description string `json:"description"` PrivateKeyPEM string `json:"private_key_pem"` Enabled bool `json:"enabled"` } type sshAccessProfileTargetRequest struct { TargetType string `json:"target_type"` TargetID string `json:"target_id"` } type sshAccessProfileRequest struct { ServerID string `json:"server_id"` ServerTargetType string `json:"server_target_type"` ServerGroupID string `json:"server_group_id"` Name string `json:"name"` Description string `json:"description"` RemoteUsername string `json:"remote_username"` AuthMethod string `json:"auth_method"` SecondFactorMode string `json:"second_factor_mode"` Enabled bool `json:"enabled"` PrivateKeyPEM string `json:"private_key_pem"` SSHCredentialID string `json:"ssh_credential_id"` PasswordText string `json:"password_text"` SSHUserCAID string `json:"ssh_user_ca_id"` SSHPrincipalGrantIDs []string `json:"ssh_principal_grant_ids"` DefaultValidSeconds int64 `json:"default_valid_seconds"` MaxValidSeconds int64 `json:"max_valid_seconds"` Targets []sshAccessProfileTargetRequest `json:"targets"` } type sshServerGroupRequest struct { Name string `json:"name"` Description 
string `json:"description"` Enabled bool `json:"enabled"` } type sshServerGroupMemberRequest struct { ServerID string `json:"server_id"` } type sshAccessProfileNormalizeOptions struct { OwnerScope string OwnerUserID string AllowUserEdit bool RequireTargets bool SelfService bool } type sshServerHostKeyRequest struct { PublicKey string `json:"public_key"` } type sshSessionConnectRequest struct { Cols int `json:"cols"` Rows int `json:"rows"` ValidSeconds int64 `json:"valid_seconds"` Term string `json:"term"` Password string `json:"password"` OTPCode string `json:"otp_code"` ServerID string `json:"server_id"` } type sshSessionConnectResponse struct { SessionID string `json:"session_id"` Status string `json:"status"` ServerName string `json:"server_name"` Host string `json:"host"` Port int `json:"port"` RemoteUsername string `json:"remote_username"` HostKeyFingerprint string `json:"host_key_fingerprint"` } type sshSessionListResponse struct { Items []models.SSHSession `json:"items"` Limit int `json:"limit"` Offset int `json:"offset"` HasMore bool `json:"has_more"` } type sshSessionStreamMessage struct { SessionID string `json:"session_id"` Type string `json:"type"` Data string `json:"data"` Cols int `json:"cols"` Rows int `json:"rows"` Status string `json:"status"` Message string `json:"message"` } type sshSessionPrepared struct { Session models.SSHSession Profile models.SSHAccessProfile HostKeys []models.SSHServerHostKey AuthMethods []ssh.AuthMethod PinHostKeyOnFirstUse bool } type sshWorkspaceAttachment struct { InputCh chan sshSessionStreamMessage InputErrCh chan error } type sshWorkspaceAttachmentDone struct { SessionID string Attachment *sshWorkspaceAttachment } const sshAuthMethodManagedSSHCert string = "managed_ssh_cert" const sshAuthMethodPromptedPassword string = "prompted_password" const sshAuthMethodStoredPassword string = "stored_password" const sshAuthMethodStoredPrivateKey string = "stored_private_key" const sshSecondFactorNone string = "none" const 
sshSecondFactorPromptedTOTP string = "prompted_totp" func (api *API) logSSHSessionConnectStart(sessionItem *models.SSHSession, profile *models.SSHAccessProfile, user *models.User) { api.Logger.Write(logIDSSHBroker, codit_logger.LOG_INFO, "connect start session=%s requester=%s user=%s profile=%s profile_name=%q server=%s server_name=%q target=%s:%d remote_user=%s auth=%s", sessionItem.ID, sessionItem.RemoteAddr, user.Username, profile.ID, profile.Name, profile.ServerID, profile.Server.Name, profile.Server.Host, profile.Server.Port, profile.RemoteUsername, profile.AuthMethod) } func (api *API) logSSHSessionPrepare(sessionItem *models.SSHSession, user *models.User, profile *models.SSHAccessProfile) { api.Logger.Write(logIDSSHBroker, codit_logger.LOG_INFO, "connect prepare build session=%s requester=%s user=%s profile=%s profile_name=%q server=%s server_name=%q target=%s:%d remote_user=%s auth=%s", sessionItem.ID, sessionItem.RemoteAddr, user.Username, profile.ID, profile.Name, profile.ServerID, profile.Server.Name, profile.Server.Host, profile.Server.Port, profile.RemoteUsername, profile.AuthMethod) } func (api *API) logSSHSessionPreparedUse(sessionItem *models.SSHSession, user *models.User, profile *models.SSHAccessProfile) { api.Logger.Write(logIDSSHBroker, codit_logger.LOG_INFO, "connect prepare use session=%s requester=%s user=%s profile=%s profile_name=%q server=%s server_name=%q target=%s:%d remote_user=%s auth=%s", sessionItem.ID, sessionItem.RemoteAddr, user.Username, profile.ID, profile.Name, profile.ServerID, profile.Server.Name, profile.Server.Host, profile.Server.Port, profile.RemoteUsername, profile.AuthMethod) } func (api *API) logSSHSessionConnectSuccess(sessionItem *models.SSHSession, profile *models.SSHAccessProfile, user *models.User, hostKeyFingerprint string) { api.Logger.Write(logIDSSHBroker, codit_logger.LOG_INFO, "connect success session=%s requester=%s user=%s profile=%s profile_name=%q server=%s server_name=%q target=%s:%d remote_user=%s 
auth=%s host_key=%s", sessionItem.ID, sessionItem.RemoteAddr, user.Username, profile.ID, profile.Name, profile.ServerID, profile.Server.Name, profile.Server.Host, profile.Server.Port, profile.RemoteUsername, profile.AuthMethod, hostKeyFingerprint) } func (api *API) logSSHSessionConnectFailure(sessionItem *models.SSHSession, profile *models.SSHAccessProfile, user *models.User, stage string, err error) { api.Logger.Write(logIDSSHBroker, codit_logger.LOG_WARN, "connect failed session=%s requester=%s user=%s profile=%s profile_name=%q server=%s server_name=%q target=%s:%d remote_user=%s auth=%s stage=%s err=%v", sessionItem.ID, sessionItem.RemoteAddr, user.Username, profile.ID, profile.Name, profile.ServerID, profile.Server.Name, profile.Server.Host, profile.Server.Port, profile.RemoteUsername, profile.AuthMethod, stage, err) } func (api *API) logSSHSessionPrepareFailure(sessionItem *models.SSHSession, user *models.User, stage string, err error) { api.Logger.Write(logIDSSHBroker, codit_logger.LOG_WARN, "connect failed session=%s requester=%s user=%s profile=%s stage=%s err=%v", sessionItem.ID, sessionItem.RemoteAddr, user.Username, sessionItem.ProfileID, stage, err) } func (api *API) logSSHSessionClosed(sessionItem *models.SSHSession, profile *models.SSHAccessProfile, user *models.User, status string, reason string, connectedAt int64, endedAt int64) { var durationSeconds int64 durationSeconds = 0 if connectedAt > 0 && endedAt >= connectedAt { durationSeconds = endedAt - connectedAt } api.Logger.Write(logIDSSHBroker, codit_logger.LOG_INFO, "session closed session=%s requester=%s user=%s profile=%s profile_name=%q server=%s server_name=%q target=%s:%d remote_user=%s auth=%s status=%s dur_s=%d reason=%q", sessionItem.ID, sessionItem.RemoteAddr, user.Username, profile.ID, profile.Name, profile.ServerID, profile.Server.Name, profile.Server.Host, profile.Server.Port, profile.RemoteUsername, profile.AuthMethod, status, durationSeconds, reason) } func (api *API) 
ListSSHServersAdmin(w http.ResponseWriter, r *http.Request, _ map[string]string) { var items []models.SSHServer var err error if !api.requireAdmin(w, r) { return } items, err = api.store(r).ListSSHServers() if err != nil { WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) return } WriteJSON(w, http.StatusOK, items) } func (api *API) GetSSHServerAdmin(w http.ResponseWriter, r *http.Request, params map[string]string) { var item models.SSHServer var err error if !api.requireAdmin(w, r) { return } item, err = api.store(r).GetSSHServer(params["id"]) if err != nil { if err == sql.ErrNoRows { WriteJSON(w, http.StatusNotFound, map[string]string{"error": "ssh server not found"}) return } WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) return } WriteJSON(w, http.StatusOK, item) } func (api *API) CreateSSHServerAdmin(w http.ResponseWriter, r *http.Request, _ map[string]string) { var req sshServerRequest var item models.SSHServer var actorKind string var actorID string var actorName string var err error if !api.requireAdmin(w, r) { return } err = DecodeJSON(r, &req) if err != nil { WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid request"}) return } req.Name = strings.TrimSpace(req.Name) req.Host = strings.TrimSpace(req.Host) req.Description = strings.TrimSpace(req.Description) if req.Name == "" { WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "name is required"}) return } if req.Host == "" { WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "host is required"}) return } if req.Port <= 0 || req.Port > 65535 { WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "port must be between 1 and 65535"}) return } actorKind, actorID, actorName = currentSSHBrokerActor(r) item = models.SSHServer{ Name: req.Name, Host: req.Host, Port: req.Port, Description: req.Description, Tags: normalizeSSHBrokerTags(req.Tags), Enabled: req.Enabled, HostKeyPolicy: 
normalizeSSHServerHostKeyPolicy(req.HostKeyPolicy), CreatedByKind: actorKind, CreatedBySubjectID: actorID, CreatedBySubjectName: actorName, } item, err = api.store(r).CreateSSHServer(item) if err != nil { WriteJSON(w, http.StatusBadRequest, map[string]string{"error": err.Error()}) return } WriteJSON(w, http.StatusCreated, item) } func (api *API) UpdateSSHServerAdmin(w http.ResponseWriter, r *http.Request, params map[string]string) { var req sshServerRequest var item models.SSHServer var err error if !api.requireAdmin(w, r) { return } item, err = api.store(r).GetSSHServer(params["id"]) if err != nil { if err == sql.ErrNoRows { WriteJSON(w, http.StatusNotFound, map[string]string{"error": "ssh server not found"}) return } WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) return } err = DecodeJSON(r, &req) if err != nil { WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid request"}) return } req.Name = strings.TrimSpace(req.Name) req.Host = strings.TrimSpace(req.Host) req.Description = strings.TrimSpace(req.Description) if req.Name == "" { WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "name is required"}) return } if req.Host == "" { WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "host is required"}) return } if req.Port <= 0 || req.Port > 65535 { WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "port must be between 1 and 65535"}) return } item.Name = req.Name item.Host = req.Host item.Port = req.Port item.Description = req.Description item.Tags = normalizeSSHBrokerTags(req.Tags) item.Enabled = req.Enabled item.HostKeyPolicy = normalizeSSHServerHostKeyPolicy(req.HostKeyPolicy) item, err = api.store(r).UpdateSSHServer(item) if err != nil { WriteJSON(w, http.StatusBadRequest, map[string]string{"error": err.Error()}) return } WriteJSON(w, http.StatusOK, item) } func (api *API) DeleteSSHServerAdmin(w http.ResponseWriter, r *http.Request, params map[string]string) { var 
err error if !api.requireAdmin(w, r) { return } err = api.store(r).DeleteSSHServer(params["id"]) if err != nil { if err == db.ErrSSHServerHasSessionHistory { WriteJSON(w, http.StatusConflict, map[string]string{"error": "cannot delete ssh server while session history exists; disable it instead"}) return } WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) return } w.WriteHeader(http.StatusNoContent) } func (api *API) ListSSHServerGroupsAdmin(w http.ResponseWriter, r *http.Request, _ map[string]string) { var items []models.SSHServerGroup var err error if !api.requireAdmin(w, r) { return } items, err = api.store(r).ListSSHServerGroups() if err != nil { WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) return } WriteJSON(w, http.StatusOK, items) } func (api *API) CreateSSHServerGroupAdmin(w http.ResponseWriter, r *http.Request, _ map[string]string) { var req sshServerGroupRequest var user models.User var ok bool var item models.SSHServerGroup var err error if !api.requireAdmin(w, r) { return } err = DecodeJSON(r, &req) if err != nil { WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid json"}) return } user, ok = middleware.UserFromContext(r.Context()) if !ok { w.WriteHeader(http.StatusUnauthorized) return } item = models.SSHServerGroup{Name: strings.TrimSpace(req.Name), Description: strings.TrimSpace(req.Description), Enabled: req.Enabled, CreatedByKind: "user", CreatedBySubjectID: user.ID, CreatedBySubjectName: user.Username} item, err = api.store(r).CreateSSHServerGroup(item) if err != nil { WriteJSON(w, http.StatusBadRequest, map[string]string{"error": err.Error()}) return } WriteJSON(w, http.StatusCreated, item) } func (api *API) UpdateSSHServerGroupAdmin(w http.ResponseWriter, r *http.Request, params map[string]string) { var req sshServerGroupRequest var item models.SSHServerGroup var err error if !api.requireAdmin(w, r) { return } item, err = 
api.store(r).GetSSHServerGroup(params["id"]) if err != nil { WriteJSON(w, http.StatusNotFound, map[string]string{"error": "ssh server group not found"}) return } err = DecodeJSON(r, &req) if err != nil { WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid json"}) return } item.Name = strings.TrimSpace(req.Name) item.Description = strings.TrimSpace(req.Description) item.Enabled = req.Enabled item, err = api.store(r).UpdateSSHServerGroup(item) if err != nil { WriteJSON(w, http.StatusBadRequest, map[string]string{"error": err.Error()}) return } WriteJSON(w, http.StatusOK, item) } func (api *API) DeleteSSHServerGroupAdmin(w http.ResponseWriter, r *http.Request, params map[string]string) { var err error if !api.requireAdmin(w, r) { return } err = api.store(r).DeleteSSHServerGroup(params["id"]) if err != nil { WriteJSON(w, http.StatusBadRequest, map[string]string{"error": err.Error()}) return } w.WriteHeader(http.StatusNoContent) } func (api *API) ListSSHServerGroupMembersAdmin(w http.ResponseWriter, r *http.Request, params map[string]string) { var items []models.SSHServer var err error if !api.requireAdmin(w, r) { return } _, err = api.store(r).GetSSHServerGroup(params["id"]) if err != nil { WriteJSON(w, http.StatusNotFound, map[string]string{"error": "ssh server group not found"}) return } items, err = api.store(r).ListSSHServersForGroup(params["id"]) if err != nil { WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) return } WriteJSON(w, http.StatusOK, items) } func (api *API) AddSSHServerGroupMemberAdmin(w http.ResponseWriter, r *http.Request, params map[string]string) { var req sshServerGroupMemberRequest var err error if !api.requireAdmin(w, r) { return } err = DecodeJSON(r, &req) if err != nil { WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid json"}) return } req.ServerID = strings.TrimSpace(req.ServerID) if req.ServerID == "" { WriteJSON(w, http.StatusBadRequest, 
map[string]string{"error": "server_id is required"}) return } _, err = api.store(r).GetSSHServerGroup(params["id"]) if err != nil { WriteJSON(w, http.StatusNotFound, map[string]string{"error": "ssh server group not found"}) return } _, err = api.store(r).GetSSHServer(req.ServerID) if err != nil { WriteJSON(w, http.StatusNotFound, map[string]string{"error": "ssh server not found"}) return } err = api.store(r).AddSSHServerGroupMember(params["id"], req.ServerID) if err != nil { WriteJSON(w, http.StatusBadRequest, map[string]string{"error": err.Error()}) return } w.WriteHeader(http.StatusNoContent) } func (api *API) DeleteSSHServerGroupMemberAdmin(w http.ResponseWriter, r *http.Request, params map[string]string) { var err error if !api.requireAdmin(w, r) { return } err = api.store(r).DeleteSSHServerGroupMember(params["id"], params["serverId"]) if err != nil { WriteJSON(w, http.StatusBadRequest, map[string]string{"error": err.Error()}) return } w.WriteHeader(http.StatusNoContent) } func (api *API) ListSSHCredentialsAdmin(w http.ResponseWriter, r *http.Request, _ map[string]string) { var items []models.SSHCredential var err error if !api.requireAdmin(w, r) { return } items, err = api.store(r).ListSSHCredentials() if err != nil { WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) return } WriteJSON(w, http.StatusOK, items) } func (api *API) CreateSSHCredentialAdmin(w http.ResponseWriter, r *http.Request, _ map[string]string) { var req sshCredentialRequest var user models.User var ok bool var item models.SSHCredential var secret models.SSHSecret var signer ssh.Signer var encrypted string var err error if !api.requireAdmin(w, r) { return } err = DecodeJSON(r, &req) if err != nil { WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid json"}) return } req.Name = strings.TrimSpace(req.Name) req.PrivateKeyPEM = strings.TrimSpace(req.PrivateKeyPEM) if req.Name == "" || req.PrivateKeyPEM == "" { WriteJSON(w, 
http.StatusBadRequest, map[string]string{"error": "name and private_key_pem are required"}) return } signer, err = parseSSHSignerFromPEM(req.PrivateKeyPEM) if err != nil { WriteJSON(w, http.StatusBadRequest, map[string]string{"error": err.Error()}) return } encrypted, err = api.encryptSSHSecretPayload(req.PrivateKeyPEM) if err != nil { WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) return } user, ok = middleware.UserFromContext(r.Context()) if !ok { w.WriteHeader(http.StatusUnauthorized) return } item = models.SSHCredential{ Name: req.Name, Description: strings.TrimSpace(req.Description), Type: "private_key", PublicKey: strings.TrimSpace(string(ssh.MarshalAuthorizedKey(signer.PublicKey()))), Fingerprint: strings.TrimSpace(ssh.FingerprintSHA256(signer.PublicKey())), Enabled: req.Enabled, OwnerScope: "admin", CreatedByKind: "user", CreatedBySubjectID: user.ID, CreatedBySubjectName: user.Username, } secret = models.SSHSecret{Payload: encrypted} item, err = api.store(r).CreateSSHCredential(item, secret) if err != nil { WriteJSON(w, http.StatusBadRequest, map[string]string{"error": err.Error()}) return } WriteJSON(w, http.StatusCreated, item) } func (api *API) UpdateSSHCredentialAdmin(w http.ResponseWriter, r *http.Request, params map[string]string) { var req sshCredentialRequest var item models.SSHCredential var err error if !api.requireAdmin(w, r) { return } item, err = api.store(r).GetSSHCredential(params["id"]) if err != nil { WriteJSON(w, http.StatusNotFound, map[string]string{"error": "ssh credential not found"}) return } err = DecodeJSON(r, &req) if err != nil { WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid json"}) return } item.Name = strings.TrimSpace(req.Name) item.Description = strings.TrimSpace(req.Description) item.Enabled = req.Enabled item, err = api.store(r).UpdateSSHCredential(item) if err != nil { WriteJSON(w, http.StatusBadRequest, map[string]string{"error": err.Error()}) return } 
WriteJSON(w, http.StatusOK, item) } func (api *API) DeleteSSHCredentialAdmin(w http.ResponseWriter, r *http.Request, params map[string]string) { var err error if !api.requireAdmin(w, r) { return } err = api.store(r).DeleteSSHCredential(params["id"]) if err != nil { WriteJSON(w, http.StatusBadRequest, map[string]string{"error": err.Error()}) return } w.WriteHeader(http.StatusNoContent) } func (api *API) ListSSHCredentialsForSelf(w http.ResponseWriter, r *http.Request, _ map[string]string) { var user models.User var items []models.SSHCredential var ok bool var err error user, ok = middleware.UserFromContext(r.Context()) if !ok || user.Disabled { w.WriteHeader(http.StatusUnauthorized) return } items, err = api.store(r).ListSSHCredentialsForUser(user.ID) if err != nil { WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) return } WriteJSON(w, http.StatusOK, items) } func (api *API) CreateSSHCredentialForSelf(w http.ResponseWriter, r *http.Request, _ map[string]string) { var req sshCredentialRequest var user models.User var ok bool var item models.SSHCredential var secret models.SSHSecret var signer ssh.Signer var encrypted string var err error user, ok = middleware.UserFromContext(r.Context()) if !ok || user.Disabled { w.WriteHeader(http.StatusUnauthorized) return } err = DecodeJSON(r, &req) if err != nil { WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid json"}) return } req.Name = strings.TrimSpace(req.Name) req.PrivateKeyPEM = strings.TrimSpace(req.PrivateKeyPEM) if req.Name == "" || req.PrivateKeyPEM == "" { WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "name and private_key_pem are required"}) return } signer, err = parseSSHSignerFromPEM(req.PrivateKeyPEM) if err != nil { WriteJSON(w, http.StatusBadRequest, map[string]string{"error": err.Error()}) return } encrypted, err = api.encryptSSHSecretPayload(req.PrivateKeyPEM) if err != nil { WriteJSON(w, http.StatusInternalServerError, 
map[string]string{"error": err.Error()}) return } item = models.SSHCredential{ Name: req.Name, Description: strings.TrimSpace(req.Description), Type: "private_key", PublicKey: strings.TrimSpace(string(ssh.MarshalAuthorizedKey(signer.PublicKey()))), Fingerprint: strings.TrimSpace(ssh.FingerprintSHA256(signer.PublicKey())), Enabled: req.Enabled, OwnerScope: "user", OwnerUserID: user.ID, CreatedByKind: "user", CreatedBySubjectID: user.ID, CreatedBySubjectName: user.Username, } secret = models.SSHSecret{Payload: encrypted} item, err = api.store(r).CreateSSHCredential(item, secret) if err != nil { WriteJSON(w, http.StatusBadRequest, map[string]string{"error": err.Error()}) return } WriteJSON(w, http.StatusCreated, item) } func (api *API) UpdateSSHCredentialForSelf(w http.ResponseWriter, r *http.Request, params map[string]string) { var req sshCredentialRequest var user models.User var item models.SSHCredential var ok bool var err error user, ok = middleware.UserFromContext(r.Context()) if !ok || user.Disabled { w.WriteHeader(http.StatusUnauthorized) return } item, err = api.store(r).GetSSHCredentialForUser(user.ID, params["id"]) if err != nil { WriteJSON(w, http.StatusNotFound, map[string]string{"error": "ssh credential not found"}) return } err = DecodeJSON(r, &req) if err != nil { WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid json"}) return } item.Name = strings.TrimSpace(req.Name) item.Description = strings.TrimSpace(req.Description) item.Enabled = req.Enabled item, err = api.store(r).UpdateSSHCredential(item) if err != nil { WriteJSON(w, http.StatusBadRequest, map[string]string{"error": err.Error()}) return } WriteJSON(w, http.StatusOK, item) } func (api *API) DeleteSSHCredentialForSelf(w http.ResponseWriter, r *http.Request, params map[string]string) { var user models.User var ok bool var err error user, ok = middleware.UserFromContext(r.Context()) if !ok || user.Disabled { w.WriteHeader(http.StatusUnauthorized) return } _, err = 
api.store(r).GetSSHCredentialForUser(user.ID, params["id"]) if err != nil { WriteJSON(w, http.StatusNotFound, map[string]string{"error": "ssh credential not found"}) return } err = api.store(r).DeleteSSHCredential(params["id"]) if err != nil { WriteJSON(w, http.StatusBadRequest, map[string]string{"error": err.Error()}) return } w.WriteHeader(http.StatusNoContent) } func (api *API) ListSSHAccessProfilesAdmin(w http.ResponseWriter, r *http.Request, _ map[string]string) { var items []models.SSHAccessProfile var err error if !api.requireAdmin(w, r) { return } items, err = api.store(r).ListSSHAccessProfiles() if err != nil { WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) return } WriteJSON(w, http.StatusOK, items) } func (api *API) GetSSHAccessProfileAdmin(w http.ResponseWriter, r *http.Request, params map[string]string) { var item models.SSHAccessProfile var err error if !api.requireAdmin(w, r) { return } item, err = api.store(r).GetSSHAccessProfile(params["id"]) if err != nil { if err == sql.ErrNoRows { WriteJSON(w, http.StatusNotFound, map[string]string{"error": "ssh access profile not found"}) return } WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) return } WriteJSON(w, http.StatusOK, item) } func (api *API) CreateSSHAccessProfileAdmin(w http.ResponseWriter, r *http.Request, _ map[string]string) { var req sshAccessProfileRequest var item models.SSHAccessProfile var actorKind string var actorID string var actorName string var options sshAccessProfileNormalizeOptions var err error if !api.requireAdmin(w, r) { return } err = DecodeJSON(r, &req) if err != nil { WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid request"}) return } actorKind, actorID, actorName = currentSSHBrokerActor(r) options = sshAccessProfileNormalizeOptions{ OwnerScope: "admin_shared", OwnerUserID: "", AllowUserEdit: false, RequireTargets: true, SelfService: false, } item, err = 
api.normalizeSSHAccessProfileRequest(r, req, models.SSHAccessProfile{}, actorKind, actorID, actorName, false, options) if err != nil { WriteJSON(w, http.StatusBadRequest, map[string]string{"error": err.Error()}) return } item, err = api.store(r).CreateSSHAccessProfile(item) if err != nil { WriteJSON(w, http.StatusBadRequest, map[string]string{"error": err.Error()}) return } WriteJSON(w, http.StatusCreated, item) } func (api *API) UpdateSSHAccessProfileAdmin(w http.ResponseWriter, r *http.Request, params map[string]string) { var req sshAccessProfileRequest var existing models.SSHAccessProfile var item models.SSHAccessProfile var actorKind string var actorID string var actorName string var options sshAccessProfileNormalizeOptions var err error if !api.requireAdmin(w, r) { return } existing, err = api.store(r).GetSSHAccessProfile(params["id"]) if err != nil { if err == sql.ErrNoRows { WriteJSON(w, http.StatusNotFound, map[string]string{"error": "ssh access profile not found"}) return } WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) return } err = DecodeJSON(r, &req) if err != nil { WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid request"}) return } actorKind, actorID, actorName = currentSSHBrokerActor(r) options = sshAccessProfileNormalizeOptions{ OwnerScope: "admin_shared", OwnerUserID: "", AllowUserEdit: false, RequireTargets: true, SelfService: false, } item, err = api.normalizeSSHAccessProfileRequest(r, req, existing, actorKind, actorID, actorName, true, options) if err != nil { WriteJSON(w, http.StatusBadRequest, map[string]string{"error": err.Error()}) return } item, err = api.store(r).UpdateSSHAccessProfile(item) if err != nil { WriteJSON(w, http.StatusBadRequest, map[string]string{"error": err.Error()}) return } WriteJSON(w, http.StatusOK, item) } func (api *API) DeleteSSHAccessProfileAdmin(w http.ResponseWriter, r *http.Request, params map[string]string) { var err error if !api.requireAdmin(w, r) 
{ return } err = api.store(r).DeleteSSHAccessProfile(params["id"]) if err != nil { if err == sql.ErrNoRows { WriteJSON(w, http.StatusNotFound, map[string]string{"error": "ssh access profile not found"}) return } WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) return } w.WriteHeader(http.StatusNoContent) } func (api *API) ListSSHAccessProfilesForSelf(w http.ResponseWriter, r *http.Request, _ map[string]string) { var user models.User var ok bool var items []models.SSHAccessProfile var err error user, ok = middleware.UserFromContext(r.Context()) if !ok || user.Disabled { w.WriteHeader(http.StatusUnauthorized) return } items, err = api.store(r).ListSSHAccessProfilesForUser(user.ID) if err != nil { WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) return } WriteJSON(w, http.StatusOK, items) } func (api *API) GetSSHAccessProfileForSelf(w http.ResponseWriter, r *http.Request, params map[string]string) { var user models.User var ok bool var item models.SSHAccessProfile var err error user, ok = middleware.UserFromContext(r.Context()) if !ok || user.Disabled { w.WriteHeader(http.StatusUnauthorized) return } item, err = api.store(r).GetSSHAccessProfileForUser(user.ID, params["id"]) if err != nil { if err == sql.ErrNoRows { WriteJSON(w, http.StatusNotFound, map[string]string{"error": "ssh access profile not found"}) return } WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) return } WriteJSON(w, http.StatusOK, item) } func (api *API) GetSSHServerForSelf(w http.ResponseWriter, r *http.Request, params map[string]string) { var user models.User var ok bool var item models.SSHServer var err error user, ok = middleware.UserFromContext(r.Context()) if !ok || user.Disabled { w.WriteHeader(http.StatusUnauthorized) return } item, err = api.store(r).GetOwnedSSHServerForUser(user.ID, params["id"]) if err != nil { WriteJSON(w, http.StatusInternalServerError, 
map[string]string{"error": err.Error()}) return } WriteJSON(w, http.StatusOK, item) } func (api *API) GetSSHAccessProfileCredentialForSelf(w http.ResponseWriter, r *http.Request, params map[string]string) { var user models.User var profile models.SSHAccessProfile var credential models.SSHCredential var ok bool var err error user, ok = middleware.UserFromContext(r.Context()) if !ok || user.Disabled { w.WriteHeader(http.StatusUnauthorized) return } profile, err = api.store(r).GetSSHAccessProfileForUser(user.ID, params["id"]) if err != nil { WriteJSON(w, http.StatusNotFound, map[string]string{"error": "ssh access profile not found"}) return } if profile.AuthMethod != sshAuthMethodStoredPrivateKey || strings.TrimSpace(profile.SSHCredentialID) == "" { WriteJSON(w, http.StatusNotFound, map[string]string{"error": "ssh credential not found"}) return } credential, err = api.store(r).GetSSHCredential(profile.SSHCredentialID) if err != nil { WriteJSON(w, http.StatusNotFound, map[string]string{"error": "ssh credential not found"}) return } WriteJSON(w, http.StatusOK, credential) } func (api *API) ListSSHAccessProfileServersForSelf(w http.ResponseWriter, r *http.Request, params map[string]string) { var user models.User var profile models.SSHAccessProfile var items []models.SSHServer var ok bool var err error user, ok = middleware.UserFromContext(r.Context()) if !ok || user.Disabled { w.WriteHeader(http.StatusUnauthorized) return } profile, err = api.store(r).GetSSHAccessProfileForUser(user.ID, params["id"]) if err != nil { WriteJSON(w, http.StatusNotFound, map[string]string{"error": "ssh access profile not found"}) return } if profile.ServerTargetType == "group" { items, err = api.store(r).ListSSHServersForGroup(profile.ServerGroupID) } else { items = []models.SSHServer{profile.Server} } if err != nil { WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) return } WriteJSON(w, http.StatusOK, items) } func (api *API) ListSSHServersForSelf(w 
http.ResponseWriter, r *http.Request, _ map[string]string) { var user models.User var ok bool var items []models.SSHServer var err error user, ok = middleware.UserFromContext(r.Context()) if !ok || user.Disabled { w.WriteHeader(http.StatusUnauthorized) return } items, err = api.store(r).ListSSHServersForUser(user.ID) if err != nil { WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) return } WriteJSON(w, http.StatusOK, items) } func (api *API) CreateSSHServerForSelf(w http.ResponseWriter, r *http.Request, _ map[string]string) { var user models.User var ok bool var req sshServerRequest var item models.SSHServer var err error user, ok = middleware.UserFromContext(r.Context()) if !ok || user.Disabled { w.WriteHeader(http.StatusUnauthorized) return } err = DecodeJSON(r, &req) if err != nil { WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid request"}) return } req.Name = strings.TrimSpace(req.Name) req.Host = strings.TrimSpace(req.Host) req.Description = strings.TrimSpace(req.Description) if req.Name == "" { WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "name is required"}) return } if req.Host == "" { WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "host is required"}) return } if req.Port <= 0 || req.Port > 65535 { WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "port must be between 1 and 65535"}) return } item = models.SSHServer{ Name: req.Name, Host: req.Host, Port: req.Port, Description: req.Description, Tags: normalizeSSHBrokerTags(req.Tags), Enabled: req.Enabled, HostKeyPolicy: normalizeSSHServerHostKeyPolicy(req.HostKeyPolicy), CreatedByKind: "user", CreatedBySubjectID: user.ID, CreatedBySubjectName: user.Username, Editable: true, } item, err = api.store(r).CreateSSHServer(item) if err != nil { WriteJSON(w, http.StatusBadRequest, map[string]string{"error": err.Error()}) return } item.Editable = true WriteJSON(w, http.StatusCreated, item) } func (api *API) 
UpdateSSHServerForSelf(w http.ResponseWriter, r *http.Request, params map[string]string) {
	// Ownership is enforced by GetOwnedSSHServerForUser: servers the caller
	// does not own answer 404 (sql.ErrNoRows), other store failures 500.
	var user models.User
	var ok bool
	var req sshServerRequest
	var item models.SSHServer
	var err error
	user, ok = middleware.UserFromContext(r.Context())
	if !ok || user.Disabled {
		w.WriteHeader(http.StatusUnauthorized)
		return
	}
	item, err = api.store(r).GetOwnedSSHServerForUser(user.ID, params["id"])
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			WriteJSON(w, http.StatusNotFound, map[string]string{"error": "ssh server not found"})
			return
		}
		WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
		return
	}
	err = DecodeJSON(r, &req)
	if err != nil {
		WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid request"})
		return
	}
	req.Name = strings.TrimSpace(req.Name)
	req.Host = strings.TrimSpace(req.Host)
	req.Description = strings.TrimSpace(req.Description)
	if req.Name == "" {
		WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "name is required"})
		return
	}
	if req.Host == "" {
		WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "host is required"})
		return
	}
	if req.Port <= 0 || req.Port > 65535 {
		WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "port must be between 1 and 65535"})
		return
	}
	// Fields not listed here (ID, creator attribution, ...) are kept from the
	// stored server.
	item.Name = req.Name
	item.Host = req.Host
	item.Port = req.Port
	item.Description = req.Description
	item.Tags = normalizeSSHBrokerTags(req.Tags)
	item.Enabled = req.Enabled
	item.HostKeyPolicy = normalizeSSHServerHostKeyPolicy(req.HostKeyPolicy)
	item, err = api.store(r).UpdateSSHServer(item)
	if err != nil {
		WriteJSON(w, http.StatusBadRequest, map[string]string{"error": err.Error()})
		return
	}
	item.Editable = true
	WriteJSON(w, http.StatusOK, item)
}

// DeleteSSHServerForSelf deletes an SSH server owned by the calling user.
//
// Answers 404 when the caller owns no such server, and 409 when the store
// refuses because session history references the server
// (db.ErrSSHServerHasSessionHistory); the client is told to disable it
// instead. Success answers 204.
func (api *API) DeleteSSHServerForSelf(w http.ResponseWriter, r *http.Request, params map[string]string) {
	var user models.User
	var ok bool
	var err error
	user, ok = middleware.UserFromContext(r.Context())
	if !ok || user.Disabled {
		w.WriteHeader(http.StatusUnauthorized)
		return
	}
	_, err = api.store(r).GetOwnedSSHServerForUser(user.ID, params["id"])
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			WriteJSON(w, http.StatusNotFound, map[string]string{"error": "ssh server not found"})
			return
		}
		WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
		return
	}
	err = api.store(r).DeleteSSHServer(params["id"])
	if err != nil {
		if errors.Is(err, db.ErrSSHServerHasSessionHistory) {
			WriteJSON(w, http.StatusConflict, map[string]string{"error": "cannot delete ssh server while session history exists; disable it instead"})
			return
		}
		WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
		return
	}
	w.WriteHeader(http.StatusNoContent)
}

// ListSSHServerHostKeysForSelf lists the recorded host keys of an SSH server
// owned by the calling user.
func (api *API) ListSSHServerHostKeysForSelf(w http.ResponseWriter, r *http.Request, params map[string]string) {
	var user models.User
	var ok bool
	var item models.SSHServer
	var items []models.SSHServerHostKey
	var err error
	user, ok = middleware.UserFromContext(r.Context())
	if !ok || user.Disabled {
		w.WriteHeader(http.StatusUnauthorized)
		return
	}
	item, err = api.store(r).GetOwnedSSHServerForUser(user.ID, params["id"])
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			WriteJSON(w, http.StatusNotFound, map[string]string{"error": "ssh server not found"})
			return
		}
		WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
		return
	}
	items, err = api.store(r).ListSSHServerHostKeys(item.ID)
	if err != nil {
		WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
		return
	}
	WriteJSON(w, http.StatusOK, items)
}

// DiscoverSSHServerHostKeyForSelf asks discoverSSHServerHostKey for the host
// key presented at the host and port of a server owned by the caller, and
// returns it tagged with the server ID. The key is not stored here; discovery
// failures answer 502.
func (api *API) DiscoverSSHServerHostKeyForSelf(w http.ResponseWriter, r *http.Request, params map[string]string) {
	var user models.User
	var ok bool
	var item models.SSHServer
	var discovered models.SSHServerHostKey
	var err error
	user, ok = middleware.UserFromContext(r.Context())
	if !ok || user.Disabled {
		w.WriteHeader(http.StatusUnauthorized)
		return
	}
	item, err = api.store(r).GetOwnedSSHServerForUser(user.ID, params["id"])
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			WriteJSON(w, http.StatusNotFound, map[string]string{"error": "ssh server not found"})
			return
		}
		WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
		return
	}
	discovered, err = discoverSSHServerHostKey(item.Host, item.Port, api.serverId())
	if err != nil {
		WriteJSON(w, http.StatusBadGateway, map[string]string{"error": err.Error()})
		return
	}
	discovered.ServerID = item.ID
	WriteJSON(w, http.StatusOK, discovered)
}

// CreateSSHServerHostKeyForSelf pins an authorized_keys-format public key as
// a trusted host key of a server owned by the calling user. Unparseable keys
// and store errors answer 400; success answers 201.
func (api *API) CreateSSHServerHostKeyForSelf(w http.ResponseWriter, r *http.Request, params map[string]string) {
	var user models.User
	var ok bool
	var item models.SSHServer
	var req sshServerHostKeyRequest
	var hostKey models.SSHServerHostKey
	var err error
	user, ok = middleware.UserFromContext(r.Context())
	if !ok || user.Disabled {
		w.WriteHeader(http.StatusUnauthorized)
		return
	}
	item, err = api.store(r).GetOwnedSSHServerForUser(user.ID, params["id"])
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			WriteJSON(w, http.StatusNotFound, map[string]string{"error": "ssh server not found"})
			return
		}
		WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
		return
	}
	err = DecodeJSON(r, &req)
	if err != nil {
		WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid request"})
		return
	}
	hostKey, err = parseSSHServerHostKey(strings.TrimSpace(req.PublicKey))
	if err != nil {
		WriteJSON(w, http.StatusBadRequest, map[string]string{"error": err.Error()})
		return
	}
	hostKey.ServerID = item.ID
	hostKey, err = api.store(r).CreateSSHServerHostKey(hostKey)
	if err != nil {
		WriteJSON(w, http.StatusBadRequest, map[string]string{"error": err.Error()})
		return
	}
	WriteJSON(w, http.StatusCreated, hostKey)
}

// DeleteSSHServerHostKeyForSelf removes one pinned host key (path parameter
// hostKeyId) from a server owned by the calling user.
func (api *API) DeleteSSHServerHostKeyForSelf(w http.ResponseWriter, r *http.Request, params map[string]string) {
	var user models.User
	var ok bool
	var item models.SSHServer
	var err error
	user, ok = middleware.UserFromContext(r.Context())
	if !ok || user.Disabled {
		w.WriteHeader(http.StatusUnauthorized)
		return
	}
	item, err = api.store(r).GetOwnedSSHServerForUser(user.ID, params["id"])
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			WriteJSON(w, http.StatusNotFound, map[string]string{"error": "ssh server not found"})
			return
		}
		WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
		return
	}
	err = api.store(r).DeleteSSHServerHostKey(item.ID, params["hostKeyId"])
	if err != nil {
		WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
		return
	}
	w.WriteHeader(http.StatusNoContent)
}

// CreateSSHAccessProfileForSelf creates a personal access profile owned by the
// calling user. The request is validated by normalizeSSHAccessProfileRequest
// in self-service mode (owner scope "user", targets optional); validation and
// store errors answer 400.
func (api *API) CreateSSHAccessProfileForSelf(w http.ResponseWriter, r *http.Request, _ map[string]string) {
	var user models.User
	var ok bool
	var req sshAccessProfileRequest
	var item models.SSHAccessProfile
	var options sshAccessProfileNormalizeOptions
	var err error
	user, ok = middleware.UserFromContext(r.Context())
	if !ok || user.Disabled {
		w.WriteHeader(http.StatusUnauthorized)
		return
	}
	err = DecodeJSON(r, &req)
	if err != nil {
		WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid request"})
		return
	}
	options = sshAccessProfileNormalizeOptions{
		OwnerScope:     "user",
		OwnerUserID:    user.ID,
		AllowUserEdit:  true,
		RequireTargets: false,
		SelfService:    true,
	}
	item, err = api.normalizeSSHAccessProfileRequest(r, req, models.SSHAccessProfile{}, "user", user.ID, user.Username, false, options)
	if err != nil {
		WriteJSON(w, http.StatusBadRequest, map[string]string{"error": err.Error()})
		return
	}
	item, err = api.store(r).CreateSSHAccessProfile(item)
	if err != nil {
		WriteJSON(w, http.StatusBadRequest, map[string]string{"error": err.Error()})
		return
	}
	WriteJSON(w, http.StatusCreated, item)
}

// UpdateSSHAccessProfileForSelf updates a personal access profile. Profiles
// the caller can see but does not own (owner scope other than "user", or a
// different owner) answer 403.
func (api *API) UpdateSSHAccessProfileForSelf(w http.ResponseWriter, r *http.Request, params map[string]string) {
	var user models.User
	var ok bool
	var req sshAccessProfileRequest
	var existing models.SSHAccessProfile
	var item models.SSHAccessProfile
	var options sshAccessProfileNormalizeOptions
	var err error
	user, ok = middleware.UserFromContext(r.Context())
	if !ok || user.Disabled {
		w.WriteHeader(http.StatusUnauthorized)
		return
	}
	existing, err = api.store(r).GetSSHAccessProfileForUser(user.ID, params["id"])
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			WriteJSON(w, http.StatusNotFound, map[string]string{"error": "ssh access profile not found"})
			return
		}
		WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
		return
	}
	if existing.OwnerScope != "user" || existing.OwnerUserID != user.ID {
		w.WriteHeader(http.StatusForbidden)
		return
	}
	err = DecodeJSON(r, &req)
	if err != nil {
		WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid request"})
		return
	}
	options = sshAccessProfileNormalizeOptions{
		OwnerScope:     "user",
		OwnerUserID:    user.ID,
		AllowUserEdit:  true,
		RequireTargets: false,
		SelfService:    true,
	}
	item, err = api.normalizeSSHAccessProfileRequest(r, req, existing, "user", user.ID, user.Username, true, options)
	if err != nil {
		WriteJSON(w, http.StatusBadRequest, map[string]string{"error": err.Error()})
		return
	}
	item, err = api.store(r).UpdateSSHAccessProfile(item)
	if err != nil {
		WriteJSON(w, http.StatusBadRequest, map[string]string{"error": err.Error()})
		return
	}
	WriteJSON(w, http.StatusOK, item)
}

// DeleteSSHAccessProfileForSelf deletes a personal access profile owned by the
// calling user; profiles the caller can see but does not own answer 403.
func (api *API) DeleteSSHAccessProfileForSelf(w http.ResponseWriter, r *http.Request, params map[string]string) {
	var user models.User
	var ok bool
	var item models.SSHAccessProfile
	var err error
	user, ok = middleware.UserFromContext(r.Context())
	if !ok || user.Disabled {
		w.WriteHeader(http.StatusUnauthorized)
		return
	}
	item, err = api.store(r).GetSSHAccessProfileForUser(user.ID, params["id"])
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			WriteJSON(w, http.StatusNotFound, map[string]string{"error": "ssh access profile not found"})
			return
		}
		WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
		return
	}
	if item.OwnerScope != "user" || item.OwnerUserID != user.ID {
		w.WriteHeader(http.StatusForbidden)
		return
	}
	err = api.store(r).DeleteSSHAccessProfile(item.ID)
	if err != nil
{
		WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
		return
	}
	w.WriteHeader(http.StatusNoContent)
}

// ListSSHServerHostKeysAdmin lists the recorded host keys of any SSH server.
// Admin only; an unknown server (sql.ErrNoRows) answers 404.
func (api *API) ListSSHServerHostKeysAdmin(w http.ResponseWriter, r *http.Request, params map[string]string) {
	var item models.SSHServer
	var items []models.SSHServerHostKey
	var err error
	if !api.requireAdmin(w, r) {
		return
	}
	item, err = api.store(r).GetSSHServer(params["id"])
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			WriteJSON(w, http.StatusNotFound, map[string]string{"error": "ssh server not found"})
			return
		}
		WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
		return
	}
	items, err = api.store(r).ListSSHServerHostKeys(item.ID)
	if err != nil {
		WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
		return
	}
	WriteJSON(w, http.StatusOK, items)
}

// DiscoverSSHServerHostKeyAdmin asks discoverSSHServerHostKey for the host key
// presented at any server's host and port and returns it tagged with the
// server ID, without storing it. Admin only; discovery failures answer 502.
func (api *API) DiscoverSSHServerHostKeyAdmin(w http.ResponseWriter, r *http.Request, params map[string]string) {
	var item models.SSHServer
	var discovered models.SSHServerHostKey
	var err error
	if !api.requireAdmin(w, r) {
		return
	}
	item, err = api.store(r).GetSSHServer(params["id"])
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			WriteJSON(w, http.StatusNotFound, map[string]string{"error": "ssh server not found"})
			return
		}
		WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
		return
	}
	discovered, err = discoverSSHServerHostKey(item.Host, item.Port, api.serverId())
	if err != nil {
		WriteJSON(w, http.StatusBadGateway, map[string]string{"error": err.Error()})
		return
	}
	discovered.ServerID = item.ID
	WriteJSON(w, http.StatusOK, discovered)
}

// CreateSSHServerHostKeyAdmin pins an authorized_keys-format public key as a
// trusted host key of any server. Admin only; unparseable keys and store
// errors answer 400, success answers 201.
func (api *API) CreateSSHServerHostKeyAdmin(w http.ResponseWriter, r *http.Request, params map[string]string) {
	var item models.SSHServer
	var req sshServerHostKeyRequest
	var hostKey models.SSHServerHostKey
	var err error
	if !api.requireAdmin(w, r) {
		return
	}
	item, err = api.store(r).GetSSHServer(params["id"])
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			WriteJSON(w, http.StatusNotFound, map[string]string{"error": "ssh server not found"})
			return
		}
		WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
		return
	}
	err = DecodeJSON(r, &req)
	if err != nil {
		WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid request"})
		return
	}
	hostKey, err = parseSSHServerHostKey(strings.TrimSpace(req.PublicKey))
	if err != nil {
		WriteJSON(w, http.StatusBadRequest, map[string]string{"error": err.Error()})
		return
	}
	hostKey.ServerID = item.ID
	hostKey, err = api.store(r).CreateSSHServerHostKey(hostKey)
	if err != nil {
		WriteJSON(w, http.StatusBadRequest, map[string]string{"error": err.Error()})
		return
	}
	WriteJSON(w, http.StatusCreated, hostKey)
}

// DeleteSSHServerHostKeyAdmin removes one pinned host key (path parameters id
// and hostKeyId). Admin only. Unlike DeleteSSHServerHostKeyForSelf, no server
// lookup happens first; the outcome depends solely on the store call.
func (api *API) DeleteSSHServerHostKeyAdmin(w http.ResponseWriter, r *http.Request, params map[string]string) {
	var err error
	if !api.requireAdmin(w, r) {
		return
	}
	err = api.store(r).DeleteSSHServerHostKey(params["id"], params["hostKeyId"])
	if err != nil {
		WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
		return
	}
	w.WriteHeader(http.StatusNoContent)
}

// CreateSSHSessionForSelf prepares a new SSH session for the calling user
// through one of the access profiles visible to them.
//
// Steps:
//  1. Expire pending sessions older than sshSessionPreparationTTL, then answer
//     429 when the caller already has sshSessionPendingPerUserLimit or more
//     pending sessions.
//  2. Load the profile and require a password / OTP code in the request when
//     the profile's auth method / second-factor mode prompts for them.
//  3. Group profiles require server_id to name an enabled member of the
//     profile's group; single-server profiles reject a server_id that
//     differs from the profile's server.
//  4. Persist a "pending" session row and run prepareSSHSession. Failures in
//     the "host_key_missing" or "build_auth" stage answer 400, others 500.
//  5. Register an after-commit hook (middleware.RegisterAfterCommit) that
//     stashes the prompted credentials and the prepared session under the
//     session ID in their respective stores.
//
// Answers 201 with a sshSessionConnectResponse.
//
// NOTE(review): the session row is created before prepareSSHSession runs;
// whether it is rolled back when preparation fails depends on the request
// transaction middleware — confirm it does not linger as "pending".
func (api *API) CreateSSHSessionForSelf(w http.ResponseWriter, r *http.Request, params map[string]string) {
	var user models.User
	var ok bool
	var req sshSessionConnectRequest
	var profile models.SSHAccessProfile
	var selectedServer models.SSHServer
	var item models.SSHSession
	var response sshSessionConnectResponse
	var prepared sshSessionPrepared
	var prepareStage string
	var hostKeyFingerprint string
	var belongs bool
	var pendingCount int
	var pendingCutoff int64
	var err error
	user, ok = middleware.UserFromContext(r.Context())
	if !ok || user.Disabled {
		w.WriteHeader(http.StatusUnauthorized)
		return
	}
	err = DecodeJSON(r, &req)
	if err != nil {
		WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid request"})
		return
	}
	// Expire stale pending sessions first so they do not count against the
	// per-user pending limit below.
	pendingCutoff = time.Now().UTC().Add(-sshSessionPreparationTTL).Unix()
	err = api.store(r).ExpireStalePendingSSHSessions(pendingCutoff, "session preparation expired")
	if err != nil {
		WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
		return
	}
	pendingCount, err = api.store(r).CountPendingSSHSessionsForUser(user.ID)
	if err != nil {
		WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
		return
	}
	if pendingCount >= sshSessionPendingPerUserLimit {
		WriteJSON(w, http.StatusTooManyRequests, map[string]string{"error": "too many pending ssh sessions; attach or wait for pending sessions to expire"})
		return
	}
	profile, err = api.store(r).GetSSHAccessProfileForUser(user.ID, params["id"])
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			WriteJSON(w, http.StatusNotFound, map[string]string{"error": "ssh access profile not found"})
			return
		}
		WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
		return
	}
	if profile.AuthMethod == sshAuthMethodPromptedPassword && req.Password == "" {
		WriteJSON(w, http.StatusBadRequest, map[string]string{"error": fmt.Sprintf("password is required for %s", sshAuthMethodPromptedPassword)})
		return
	}
	if profile.SecondFactorMode == sshSecondFactorPromptedTOTP && strings.TrimSpace(req.OTPCode) == "" {
		WriteJSON(w, http.StatusBadRequest, map[string]string{"error": fmt.Sprintf("otp_code is required for %s", sshSecondFactorPromptedTOTP)})
		return
	}
	if profile.ServerTargetType == "group" {
		req.ServerID = strings.TrimSpace(req.ServerID)
		if req.ServerID == "" {
			WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "server_id is required for server group profiles"})
			return
		}
		belongs, err = api.store(r).SSHServerBelongsToGroup(profile.ServerGroupID, req.ServerID)
		if err != nil {
			WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
			return
		}
		if !belongs {
			WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "server_id is not a member of the profile server group"})
			return
		}
		selectedServer, err = api.store(r).GetSSHServer(req.ServerID)
		if err != nil {
			// Only a missing row is "not found"; other store failures are
			// reported as internal errors instead of being masked as 404.
			if errors.Is(err, sql.ErrNoRows) {
				WriteJSON(w, http.StatusNotFound, map[string]string{"error": "ssh server not found"})
				return
			}
			WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
			return
		}
		if !selectedServer.Enabled {
			WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "ssh server is disabled"})
			return
		}
		// The chosen member replaces the profile's default server for this
		// session only.
		profile.ServerID = selectedServer.ID
		profile.Server = selectedServer
	} else if strings.TrimSpace(req.ServerID) != "" && strings.TrimSpace(req.ServerID) != profile.ServerID {
		// NOTE(review): unlike the group path, profile.Server.Enabled is not
		// checked here; confirm prepareSSHSession enforces it.
		WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "server_id does not match access profile"})
		return
	}
	item = models.SSHSession{
		ProfileID:        profile.ID,
		ServerID:         profile.ServerID,
		UserID:           user.ID,
		Username:         user.Username,
		RemoteUsername:   profile.RemoteUsername,
		AuthMethod:       profile.AuthMethod,
		SecondFactorMode: profile.SecondFactorMode,
		Host:             profile.Server.Host,
		Port:             profile.Server.Port,
		Status:           "pending",
		RequestedTerm:    normalizeSSHSessionTerm(req.Term),
		RequestedCols:    normalizeSSHSessionCols(req.Cols),
		RequestedRows:    normalizeSSHSessionRows(req.Rows),
		RemoteAddr:       requestRemoteAddr(r),
		UserAgent:        strings.TrimSpace(r.UserAgent()),
	}
	item, err = api.store(r).CreateSSHSession(item)
	if err != nil {
		WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
		return
	}
	prepared, prepareStage, err = api.prepareSSHSession(api.store(r), &user, &item, req.Password, req.OTPCode)
	if err != nil {
		// These stages fail because of the request or profile setup, so they
		// are reported as client errors.
		if prepareStage == "host_key_missing" || prepareStage == "build_auth" {
			WriteJSON(w, http.StatusBadRequest, map[string]string{"error": err.Error()})
			return
		}
		WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
		return
	}
	middleware.RegisterAfterCommit(r.Context(), func() {
		if api.SSHPromptedAuthStore != nil {
			api.SSHPromptedAuthStore.Put(item.ID, SSHPromptedAuthInput{
				Password: req.Password,
				OTPCode:  req.OTPCode,
			})
		}
		if api.SSHPreparedSessionStore != nil {
			api.SSHPreparedSessionStore.Put(item.ID, prepared)
		}
	})
	if len(prepared.HostKeys) > 0 {
		hostKeyFingerprint = prepared.HostKeys[0].Fingerprint
	}
	response = sshSessionConnectResponse{
		SessionID:          item.ID,
		Status:             item.Status,
		ServerName:         profile.Server.Name,
		Host:               profile.Server.Host,
		Port:               profile.Server.Port,
		RemoteUsername:     profile.RemoteUsername,
		HostKeyFingerprint: hostKeyFingerprint,
	}
	WriteJSON(w, http.StatusCreated, response)
}

// GetSSHSessionForSelf returns one SSH session of the calling user, annotated
// with transcript availability. Sessions of other users answer 403.
func (api *API) GetSSHSessionForSelf(w http.ResponseWriter, r *http.Request, params map[string]string) {
	var user models.User
	var ok bool
	var item models.SSHSession
	var err error
	user, ok = middleware.UserFromContext(r.Context())
	if !ok || user.Disabled {
		w.WriteHeader(http.StatusUnauthorized)
		return
	}
	item, err = api.store(r).GetSSHSession(params["id"])
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			WriteJSON(w, http.StatusNotFound, map[string]string{"error": "ssh session not found"})
			return
		}
		WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
		return
	}
	if item.UserID != user.ID {
		w.WriteHeader(http.StatusForbidden)
		return
	}
	api.setSSHSessionTranscriptAvailability(&item)
	WriteJSON(w, http.StatusOK, item)
}

// parseSSHSessionListLimit parses the "limit" query value: 20 when empty,
// non-numeric or not positive, capped at 200.
func parseSSHSessionListLimit(raw string) int {
	var value int
	var err error
	value = 20
	raw = strings.TrimSpace(raw)
	if raw == "" {
		return value
	}
	value, err = strconv.Atoi(raw)
	if err != nil || value <= 0 {
		return 20
	}
	if value > 200 {
		return 200
	}
	return value
}

// parseSSHSessionListOffset parses the "offset" query value: 0 when empty,
// non-numeric or negative.
func parseSSHSessionListOffset(raw string) int {
	var value int
	var err error
	value = 0
	raw = strings.TrimSpace(raw)
	if raw == "" {
		return value
	}
	value, err = strconv.Atoi(raw)
	if err != nil || value < 0 {
		return 0
	}
	return value
}

// ListSSHSessionsForSelf returns one page of the calling user's SSH sessions.
// Query parameters: limit (see parseSSHSessionListLimit), offset (see
// parseSSHSessionListOffset), and the trimmed q and status filters passed to
// ListSSHSessionsForUserFiltered.
func (api *API) ListSSHSessionsForSelf(w http.ResponseWriter, r *http.Request, _ map[string]string) {
	var user models.User
	var ok bool
	var limit int
	var offset int
	var query string
	var status string
	var items []models.SSHSession
	var hasMore bool
	var err error
	user, ok = middleware.UserFromContext(r.Context())
	if !ok || user.Disabled {
		w.WriteHeader(http.StatusUnauthorized)
		return
	}
	limit = parseSSHSessionListLimit(r.URL.Query().Get("limit"))
	offset = parseSSHSessionListOffset(r.URL.Query().Get("offset"))
	query =
strings.TrimSpace(r.URL.Query().Get("q"))
	status = strings.TrimSpace(r.URL.Query().Get("status"))
	items, hasMore, err = api.store(r).ListSSHSessionsForUserFiltered(user.ID, limit, offset, query, status)
	if err != nil {
		WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
		return
	}
	api.setSSHSessionsTranscriptAvailability(items)
	WriteJSON(w, http.StatusOK, sshSessionListResponse{Items: items, Limit: limit, Offset: offset, HasMore: hasMore})
}

// ListSSHSessionsAdmin returns one page of all users' SSH sessions. Admin
// only; accepts the same limit, offset, q and status query parameters as
// ListSSHSessionsForSelf.
func (api *API) ListSSHSessionsAdmin(w http.ResponseWriter, r *http.Request, _ map[string]string) {
	var limit int
	var offset int
	var query string
	var status string
	var items []models.SSHSession
	var hasMore bool
	var err error
	if !api.requireAdmin(w, r) {
		return
	}
	limit = parseSSHSessionListLimit(r.URL.Query().Get("limit"))
	offset = parseSSHSessionListOffset(r.URL.Query().Get("offset"))
	query = strings.TrimSpace(r.URL.Query().Get("q"))
	status = strings.TrimSpace(r.URL.Query().Get("status"))
	items, hasMore, err = api.store(r).ListSSHSessionsFiltered(limit, offset, query, status)
	if err != nil {
		WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
		return
	}
	api.setSSHSessionsTranscriptAvailability(items)
	WriteJSON(w, http.StatusOK, sshSessionListResponse{Items: items, Limit: limit, Offset: offset, HasMore: hasMore})
}

// DisconnectSSHSessionForSelf ends one of the calling user's SSH sessions.
//
// Any stashed prompted credentials and prepared-session state for the session
// are dropped, and a close is requested through SSHSessionRegistry. When the
// registry did not accept the close request, the stored row is marked
// "closed" here — unless it is already "closed", in which case its recorded
// end time and close message are left intact so repeated disconnects are
// idempotent. Sessions of other users answer 403; success answers 204.
//
// NOTE(review): other terminal statuses (e.g. whatever
// ExpireStalePendingSSHSessions assigns) are still overwritten with
// "closed"; confirm whether they should be preserved as well.
func (api *API) DisconnectSSHSessionForSelf(w http.ResponseWriter, r *http.Request, params map[string]string) {
	var user models.User
	var ok bool
	var item models.SSHSession
	var closed bool
	var err error
	var now int64
	user, ok = middleware.UserFromContext(r.Context())
	if !ok || user.Disabled {
		w.WriteHeader(http.StatusUnauthorized)
		return
	}
	item, err = api.store(r).GetSSHSession(params["id"])
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			WriteJSON(w, http.StatusNotFound, map[string]string{"error": "ssh session not found"})
			return
		}
		WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
		return
	}
	if item.UserID != user.ID {
		w.WriteHeader(http.StatusForbidden)
		return
	}
	if api.SSHPromptedAuthStore != nil {
		api.SSHPromptedAuthStore.Delete(item.ID)
	}
	if api.SSHPreparedSessionStore != nil {
		api.SSHPreparedSessionStore.Delete(item.ID)
	}
	if api.SSHSessionRegistry != nil {
		closed = api.SSHSessionRegistry.RequestClose(item.ID, SSHSessionCloseUserDisconnect)
	}
	now = time.Now().UTC().Unix()
	// Rewriting an already-closed row would replace its original end time
	// and close reason with this disconnect's, so skip it.
	if !closed && item.Status != "closed" {
		err = api.store(r).UpdateSSHSessionStatus(item.ID, "closed", item.HostKeyFingerprint, item.ConnectedAt, now, "disconnected by user")
		if err != nil {
			WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
			return
		}
	}
	w.WriteHeader(http.StatusNoContent)
}

// StreamSSHWorkspaceForSelf upgrades the request to a websocket and hands the
// connection to serveSSHWorkspaceStream together with the request's store,
// the calling user and the remote address.
//
// NOTE(review): websocket.Handler's default handshake only checks that the
// Origin header parses as a URL; it does not compare it with the request
// host. If this endpoint is authenticated by cookies, consider a
// websocket.Server whose Handshake enforces an allowed origin.
func (api *API) StreamSSHWorkspaceForSelf(w http.ResponseWriter, r *http.Request, params map[string]string) {
	var user models.User
	var ok bool
	var store *db.Store
	// One websocket carries multiple ssh sessions.
	user, ok = middleware.UserFromContext(r.Context())
	if !ok || user.Disabled {
		w.WriteHeader(http.StatusUnauthorized)
		return
	}
	store = api.store(r)
	websocket.Handler(func(ws *websocket.Conn) {
		api.serveSSHWorkspaceStream(ws, store, user, r.RemoteAddr)
	}).ServeHTTP(w, r)
}

// normalizeSSHAccessProfileRequest validates an access-profile request and
// merges it onto existing (the zero value on create), returning the profile
// to persist.
//
// It trims and normalizes the request fields and validates:
//   - the server target (a single server, or a server group),
//   - the auth method and its secret material,
//   - the certificate validity window,
//   - the optional user/group targets.
//
// options sets ownership (OwnerScope, OwnerUserID, AllowUserEdit) and whether
// targets are required. It also sets self-service restrictions: personal
// profiles may not use managed_ssh_cert and may only reference the owner's
// credentials and servers. actorKind/actorID/actorName are recorded as the
// creator only when isUpdate is false. Handlers in this file answer 400 with
// the returned error's text, including any store error passed through.
func (api *API) normalizeSSHAccessProfileRequest(r *http.Request, req sshAccessProfileRequest, existing models.SSHAccessProfile, actorKind string, actorID string, actorName string, isUpdate bool, options sshAccessProfileNormalizeOptions) (models.SSHAccessProfile, error) {
	var item models.SSHAccessProfile
	var err error
	var targets []models.SSHAccessProfileTarget
	var privateKeyPEM string
	var passwordText string
	var signer ssh.Signer
	var publicKey string
	var fingerprint string
	var generatedKeyPEM string
	var generatedPublicKey string
	var generatedFingerprint string
	var activeGrants []models.SSHPrincipalGrant
	var credential models.SSHCredential
	var now int64
	var allowedGrantIDs map[string]bool
	var i int
	// Start from the existing profile so fields this function does not
	// overwrite carry over on update.
	item = existing
	req.ServerID = strings.TrimSpace(req.ServerID)
	req.ServerTargetType = strings.TrimSpace(req.ServerTargetType)
	req.ServerGroupID =
strings.TrimSpace(req.ServerGroupID)
	// Default to a single-server target when none is given.
	if req.ServerTargetType == "" {
		req.ServerTargetType = "server"
	}
	req.Name = strings.TrimSpace(req.Name)
	req.Description = strings.TrimSpace(req.Description)
	req.RemoteUsername = strings.TrimSpace(req.RemoteUsername)
	// Unknown auth methods normalize to "" and are rejected below; unknown
	// second-factor modes fall back to sshSecondFactorNone.
	req.AuthMethod = normalizeSSHAccessAuthMethod(req.AuthMethod)
	req.SecondFactorMode = normalizeSSHAccessSecondFactorMode(req.SecondFactorMode)
	req.SSHUserCAID = strings.TrimSpace(req.SSHUserCAID)
	privateKeyPEM = strings.TrimSpace(req.PrivateKeyPEM)
	// NOTE(review): the stored password is whitespace-trimmed, so a password
	// with leading/trailing spaces is saved altered, while the prompted
	// password in CreateSSHSessionForSelf is used untrimmed. Confirm intended.
	passwordText = strings.TrimSpace(req.PasswordText)
	targets, err = normalizeSSHAccessProfileTargets(req.Targets, !options.RequireTargets)
	if err != nil {
		return item, err
	}
	if req.Name == "" {
		return item, errors.New("name is required")
	}
	if req.ServerTargetType != "server" && req.ServerTargetType != "group" {
		return item, errors.New("server_target_type must be server or group")
	}
	if req.ServerTargetType == "server" && req.ServerID == "" {
		return item, errors.New("server_id is required")
	}
	if req.ServerTargetType == "group" && req.ServerGroupID == "" {
		return item, errors.New("server_group_id is required")
	}
	if req.RemoteUsername == "" {
		return item, errors.New("remote_username is required")
	}
	if req.AuthMethod == "" {
		return item, fmt.Errorf("auth_method must be one of %v", [4]string{
			sshAuthMethodManagedSSHCert,
			sshAuthMethodPromptedPassword,
			sshAuthMethodStoredPassword,
			sshAuthMethodStoredPrivateKey})
	}
	// Personal (self-service) profiles may not use managed SSH certificates.
	if options.SelfService && req.AuthMethod == sshAuthMethodManagedSSHCert {
		return item, fmt.Errorf("%s is not allowed for personal ssh access profiles", sshAuthMethodManagedSSHCert)
	}
	if req.AuthMethod == sshAuthMethodManagedSSHCert {
		// Managed certificates need a user CA and at least one principal
		// grant; any credential reference is cleared.
		// NOTE(review): because of the early return above, options.SelfService
		// is always false inside this branch, so the self-service lookups below
		// are currently unreachable.
		item.SSHCredentialID = ""
		if req.SSHUserCAID == "" {
			return item, fmt.Errorf("ssh_user_ca_id is required for %s", sshAuthMethodManagedSSHCert)
		}
		if options.SelfService {
			_, err = api.store(r).GetSSHUserCAForUser(req.SSHUserCAID)
		} else {
			_, err = api.store(r).GetSSHUserCA(req.SSHUserCAID)
		}
		if err != nil {
			if err == sql.ErrNoRows {
				return item, errors.New("ssh_user_ca_id not found")
			}
			return item, err
		}
		// Grant IDs are trimmed, de-duplicated and sorted.
		req.SSHPrincipalGrantIDs = normalizeSSHBrokerNames(req.SSHPrincipalGrantIDs)
		if len(req.SSHPrincipalGrantIDs) == 0 {
			return item, errors.New("at least one ssh_principal_grant_id is required")
		}
		if options.SelfService {
			// Every requested grant must be among the owner's currently
			// active grants.
			now = time.Now().UTC().Unix()
			activeGrants, err = api.store(r).ListActiveSSHPrincipalGrantsForUser(options.OwnerUserID, now)
			if err != nil {
				return item, err
			}
			allowedGrantIDs = map[string]bool{}
			for i = 0; i < len(activeGrants); i++ {
				allowedGrantIDs[activeGrants[i].ID] = true
			}
			for i = 0; i < len(req.SSHPrincipalGrantIDs); i++ {
				if !allowedGrantIDs[req.SSHPrincipalGrantIDs[i]] {
					return item, errors.New("ssh_principal_grant_id not available to user: " + req.SSHPrincipalGrantIDs[i])
				}
			}
		}
		// On create without a supplied key, generate an ed25519 key pair. On
		// update an empty key leaves privateKeyPEM empty, so SecretPayload is
		// not replaced further below.
		if privateKeyPEM == "" && !isUpdate {
			generatedKeyPEM, generatedPublicKey, generatedFingerprint, err = generateSSHUserCAKeyPair("ed25519")
			if err != nil {
				return item, err
			}
			privateKeyPEM = strings.TrimSpace(generatedKeyPEM)
			publicKey = strings.TrimSpace(generatedPublicKey)
			fingerprint = strings.TrimSpace(generatedFingerprint)
		}
	} else {
		// Non-certificate auth methods carry no CA or principal grants.
		req.SSHUserCAID = ""
		req.SSHPrincipalGrantIDs = []string{}
		if req.AuthMethod == sshAuthMethodStoredPrivateKey {
			// Either reference an existing credential or supply a key inline
			// (one of the two is required on create).
			req.SSHCredentialID = strings.TrimSpace(req.SSHCredentialID)
			if req.SSHCredentialID != "" {
				credential, err = api.store(r).GetSSHCredential(req.SSHCredentialID)
				if err != nil {
					if err == sql.ErrNoRows {
						return item, errors.New("ssh_credential_id not found")
					}
					return item, err
				}
				if !credential.Enabled {
					return item, errors.New("ssh_credential_id is disabled")
				}
				// Personal profiles may only use the owner's own credentials;
				// admin profiles may not use any user-owned credential.
				if options.SelfService {
					if credential.OwnerScope != "user" || credential.OwnerUserID != options.OwnerUserID {
						return item, errors.New("ssh_credential_id not available to user")
					}
				} else if credential.OwnerScope == "user" {
					return item, errors.New("user-owned ssh credential is not allowed for admin access profiles")
				}
				item.SSHCredentialID = credential.ID
				item.SecretID = credential.SecretID
				item.AuthPublicKey = credential.PublicKey
				item.AuthPublicKeyFingerprint = credential.Fingerprint
				// A referenced credential takes precedence over an inline key.
				privateKeyPEM = ""
				publicKey = credential.PublicKey
				fingerprint = credential.Fingerprint
			} else if privateKeyPEM == "" && !isUpdate {
				return item, fmt.Errorf("private_key_pem or ssh_credential_id is required for %s", sshAuthMethodStoredPrivateKey)
			} else if privateKeyPEM != "" {
				// An inline key detaches the profile from any credential.
				item.SSHCredentialID = ""
			}
		} else if req.AuthMethod == sshAuthMethodStoredPassword {
			// Password auth drops all key material; a password is required
			// on create.
			item.SSHCredentialID = ""
			privateKeyPEM = ""
			publicKey = ""
			fingerprint = ""
			item.AuthPublicKey = ""
			item.AuthPublicKeyFingerprint = ""
			if passwordText == "" && !isUpdate {
				return item, fmt.Errorf("password_text is required for %s", sshAuthMethodStoredPassword)
			}
		} else if req.AuthMethod == sshAuthMethodPromptedPassword {
			// Nothing secret is stored: the password is supplied at connect
			// time.
			item.SSHCredentialID = ""
			privateKeyPEM = ""
			passwordText = ""
			publicKey = ""
			fingerprint = ""
			item.SecretID = ""
			item.AuthPublicKey = ""
			item.AuthPublicKeyFingerprint = ""
		} else {
			return item, fmt.Errorf("unsupported auth_method: %s", req.AuthMethod)
		}
	}
	// Derive the public key and fingerprint from an inline private key.
	if privateKeyPEM != "" && publicKey == "" {
		signer, err = parseSSHSignerFromPEM(privateKeyPEM)
		if err != nil {
			return item, err
		}
		publicKey = strings.TrimSpace(string(ssh.MarshalAuthorizedKey(signer.PublicKey())))
		fingerprint = strings.TrimSpace(ssh.FingerprintSHA256(signer.PublicKey()))
	}
	// Validity window: the default falls back to 3600 seconds, the maximum
	// falls back to the default and may not be smaller than it.
	if req.DefaultValidSeconds <= 0 {
		req.DefaultValidSeconds = 3600
	}
	if req.MaxValidSeconds <= 0 {
		req.MaxValidSeconds = req.DefaultValidSeconds
	}
	if req.MaxValidSeconds < req.DefaultValidSeconds {
		return item, errors.New("max_valid_seconds must be greater than or equal to default_valid_seconds")
	}
	if req.ServerTargetType == "group" {
		var group models.SSHServerGroup
		group, err = api.store(r).GetSSHServerGroup(req.ServerGroupID)
		if err != nil {
			if err == sql.ErrNoRows {
				return item, errors.New("server_group_id not found")
			}
			return item, err
		}
		if !group.Enabled {
			return item, errors.New("server_group_id is disabled")
		}
		if len(group.ServerIDs) == 0 {
			return item, errors.New("server group has no servers")
		}
		// The group's first member becomes the profile's ServerID.
		// NOTE(review): the group is looked up without an ownership scope, and
		// in self-service mode only this first member is checked against the
		// owner's servers below; the other members are not checked.
		req.ServerID = group.ServerIDs[0]
		item.ServerGroupID = group.ID
	} else {
		item.ServerGroupID = ""
	}
	// Self-service profiles may only target servers owned by the owner.
	if options.SelfService {
		_, err = api.store(r).GetOwnedSSHServerForUser(options.OwnerUserID, req.ServerID)
		if err != nil {
			if err == sql.ErrNoRows {
				return item, errors.New("server_id not found")
			}
			return item, err
		}
	} else {
		_, err = api.store(r).GetSSHServer(req.ServerID)
		if err != nil {
			if err == sql.ErrNoRows {
				return item, errors.New("server_id not found")
			}
			return item, err
		}
	}
	err = validateSSHAccessProfileTargets(api.store(r), targets)
	if err != nil {
		return item, err
	}
	item.ServerID = req.ServerID
	item.ServerTargetType = req.ServerTargetType
	item.Name = req.Name
	item.Description = req.Description
	item.RemoteUsername = req.RemoteUsername
	item.AuthMethod = req.AuthMethod
	item.SecondFactorMode = req.SecondFactorMode
	item.OwnerScope = options.OwnerScope
	item.OwnerUserID = options.OwnerUserID
	item.AllowUserEdit = options.AllowUserEdit
	item.Enabled = req.Enabled
	item.SSHUserCAID = req.SSHUserCAID
	item.SSHPrincipalGrantIDs = req.SSHPrincipalGrantIDs
	item.DefaultValidSeconds = req.DefaultValidSeconds
	item.MaxValidSeconds = req.MaxValidSeconds
	item.Targets = targets
	// Creator attribution is recorded only on create.
	if !isUpdate {
		item.CreatedByKind = actorKind
		item.CreatedBySubjectID = actorID
		item.CreatedBySubjectName = actorName
	}
	// A new private key is encrypted before storage; when privateKeyPEM is
	// empty, SecretPayload keeps the value copied from existing.
	if privateKeyPEM != "" {
		item.SecretPayload, err = api.encryptSSHSecretPayload(privateKeyPEM)
		if err != nil {
			return item, err
		}
		item.AuthPublicKey = publicKey
		item.AuthPublicKeyFingerprint = fingerprint
	}
	// NOTE(review): assigned unconditionally, so an update without
	// password_text sets SecretPassword to ""; presumably the store treats ""
	// as "keep existing" — confirm.
	item.SecretPassword = passwordText
	return item, nil
}

// normalizeSSHAccessAuthMethod trims and lower-cases raw and returns it when it
// names one of the four supported auth methods, or "" otherwise.
func normalizeSSHAccessAuthMethod(raw string) string {
	var value string
	value = strings.ToLower(strings.TrimSpace(raw))
	if value == sshAuthMethodManagedSSHCert || value == sshAuthMethodPromptedPassword || value == sshAuthMethodStoredPassword || value == sshAuthMethodStoredPrivateKey {
		return value
	}
	return ""
}

// normalizeSSHAccessSecondFactorMode trims and lower-cases raw and returns
// sshSecondFactorPromptedTOTP when it matches, sshSecondFactorNone otherwise.
func normalizeSSHAccessSecondFactorMode(raw string) string {
	var value
string
	value = strings.ToLower(strings.TrimSpace(raw))
	if value == sshSecondFactorPromptedTOTP {
		return value
	}
	return sshSecondFactorNone
}

// normalizeSSHBrokerTags normalizes server tags via db.NormalizeSSHServerTags.
func normalizeSSHBrokerTags(raw []string) []string {
	return db.NormalizeSSHServerTags(raw)
}

// normalizeSSHServerHostKeyPolicy maps a requested host key policy onto a
// supported value. Only "trust_on_first_use" is kept; it is matched after
// trimming and lower-casing, like the other request normalizers in this file.
// Anything else, including "", becomes "strict".
func normalizeSSHServerHostKeyPolicy(value string) string {
	value = strings.ToLower(strings.TrimSpace(value))
	if value == "trust_on_first_use" {
		return value
	}
	return "strict"
}

// normalizeSSHBrokerNames trims each entry, drops empty and duplicate values,
// and returns the rest sorted. An input with no usable entries yields nil.
func normalizeSSHBrokerNames(raw []string) []string {
	var out []string
	seen := map[string]bool{}
	for _, entry := range raw {
		value := strings.TrimSpace(entry)
		if value == "" || seen[value] {
			continue
		}
		seen[value] = true
		out = append(out, value)
	}
	sort.Strings(out)
	return out
}

// normalizeSSHAccessProfileTargets trims and lower-cases target types, trims
// target IDs and drops duplicate (type, id) pairs while keeping input order.
//
// Each target type must be "user" or "group" and each ID non-empty. When the
// result is empty and allowEmpty is false, an error is returned.
func normalizeSSHAccessProfileTargets(raw []sshAccessProfileTargetRequest, allowEmpty bool) ([]models.SSHAccessProfileTarget, error) {
	var out []models.SSHAccessProfileTarget
	dedupe := map[string]bool{}
	for _, entry := range raw {
		targetType := strings.ToLower(strings.TrimSpace(entry.TargetType))
		targetID := strings.TrimSpace(entry.TargetID)
		if targetType != "user" && targetType != "group" {
			return nil, errors.New("target_type must be user or group")
		}
		if targetID == "" {
			return nil, errors.New("target_id is required")
		}
		key := targetType + ":" + targetID
		if dedupe[key] {
			continue
		}
		dedupe[key] = true
		out = append(out, models.SSHAccessProfileTarget{TargetType: targetType, TargetID: targetID})
	}
	if len(out) == 0 && !allowEmpty {
		return nil, errors.New("at least one target is required")
	}
	return out, nil
}

// validateSSHAccessProfileTargets checks that every target exists: "user"
// targets via GetUserByID, everything else via GetUserGroup. A missing row
// (sql.ErrNoRows, also when wrapped) yields a "target not found" error; other
// store errors are returned unchanged.
func validateSSHAccessProfileTargets(store *db.Store, targets []models.SSHAccessProfileTarget) error {
	var i int
	var err error
	if len(targets) == 0 {
		return nil
	}
	for i = 0; i < len(targets); i++ {
		if targets[i].TargetType == "user" {
			_, err = store.GetUserByID(targets[i].TargetID)
		} else {
			_, err = store.GetUserGroup(targets[i].TargetID)
		}
		if err != nil {
			if errors.Is(err, sql.ErrNoRows) {
				return errors.New(targets[i].TargetType + " target not found: " + targets[i].TargetID)
			}
			return err
		}
	}
	return nil
}

// currentSSHBrokerActor identifies who is making the request as
// (kind, id, name). A user in the context wins, then a service principal;
// otherwise ("system", "", "system") is returned.
func currentSSHBrokerActor(r *http.Request) (string, string, string) {
	var user models.User
	var principal models.ServicePrincipal
	var ok bool
	user, ok = middleware.UserFromContext(r.Context())
	if ok {
		return "user", user.ID, user.Username
	}
	principal, ok = middleware.PrincipalFromContext(r.Context())
	if ok {
		return "service_principal", principal.ID, principal.Name
	}
	return "system", "", "system"
}

// normalizeSSHSessionTerm trims the requested TERM value, defaulting to
// "xterm-256color" when empty.
func normalizeSSHSessionTerm(raw string) string {
	var value string
	value = strings.TrimSpace(raw)
	if value == "" {
		return "xterm-256color"
	}
	return value
}

// normalizeSSHSessionCols defaults non-positive column counts to 120 and caps
// them at 400.
func normalizeSSHSessionCols(value int) int {
	if value <= 0 {
		return 120
	}
	if value > 400 {
		return 400
	}
	return value
}

// normalizeSSHSessionRows defaults non-positive row counts to 36 and caps them
// at 200.
func normalizeSSHSessionRows(value int) int {
	if value <= 0 {
		return 36
	}
	if value > 200 {
		return 200
	}
	return value
}

// parseSSHServerHostKey parses an authorized_keys-format public key and
// returns its algorithm, canonical authorized_keys encoding and SHA256
// fingerprint. Parse failures return a generic "invalid ssh public key"
// error.
func parseSSHServerHostKey(raw string) (models.SSHServerHostKey, error) {
	var item models.SSHServerHostKey
	var key ssh.PublicKey
	var err error
	key, _, _, _, err = ssh.ParseAuthorizedKey([]byte(strings.TrimSpace(raw)))
	if err != nil {
		return item, errors.New("invalid ssh public key")
	}
	item.Algorithm = key.Type()
	item.PublicKey = strings.TrimSpace(string(ssh.MarshalAuthorizedKey(key)))
	item.Fingerprint = strings.TrimSpace(ssh.FingerprintSHA256(key))
	return item, nil
}

func discoverSSHServerHostKey(host string, port int, defaultUser string) (models.SSHServerHostKey, error) {
	var item models.SSHServerHostKey
	var addr string
	var conn net.Conn
	var clientConn ssh.Conn
	var reqs <-chan *ssh.Request
	var config *ssh.ClientConfig
	var discovered ssh.PublicKey
	var err error
	addr = net.JoinHostPort(strings.TrimSpace(host), fmt.Sprintf("%d", port))
	defaultUser = strings.TrimSpace(defaultUser)
	config = &ssh.ClientConfig{
		User: defaultUser,
		Auth: []ssh.AuthMethod{},
		HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error {
discovered = key return nil }, Timeout: 10 * time.Second, } conn, err = net.DialTimeout("tcp", addr, 10*time.Second) if err != nil { return item, err } defer conn.Close() clientConn, _, reqs, err = ssh.NewClientConn(conn, addr, config) if clientConn != nil { clientConn.Close() } if reqs != nil { go ssh.DiscardRequests(reqs) } if discovered == nil { return item, errors.New("failed to discover ssh host key") } item.Algorithm = discovered.Type() item.PublicKey = strings.TrimSpace(string(ssh.MarshalAuthorizedKey(discovered))) item.Fingerprint = strings.TrimSpace(ssh.FingerprintSHA256(discovered)) if err != nil { if strings.Contains(strings.ToLower(err.Error()), "unable to authenticate") { return item, nil } if strings.Contains(strings.ToLower(err.Error()), "no auth passed yet") { return item, nil } } return item, nil } func (api *API) prepareSSHSession(store *db.Store, user *models.User, sessionItem *models.SSHSession, promptedPassword string, otpCode string) (sshSessionPrepared, string, error) { var prepared sshSessionPrepared var err error prepared.Session = *sessionItem prepared.Profile, err = store.GetSSHAccessProfileForUser(user.ID, sessionItem.ProfileID) if err != nil { return prepared, "load_profile", err } if strings.TrimSpace(sessionItem.ServerID) != "" && sessionItem.ServerID != prepared.Profile.ServerID { prepared.Profile.Server, err = store.GetSSHServer(sessionItem.ServerID) if err != nil { return prepared, "load_server", err } prepared.Profile.ServerID = prepared.Profile.Server.ID } api.logSSHSessionPrepare(sessionItem, user, &prepared.Profile) prepared.HostKeys, err = store.ListSSHServerHostKeys(sessionItem.ServerID) if err != nil { return prepared, "load_host_keys", err } if len(prepared.HostKeys) == 0 { if normalizeSSHServerHostKeyPolicy(prepared.Profile.Server.HostKeyPolicy) != "trust_on_first_use" { return prepared, "host_key_missing", errors.New("no pinned host key for " + prepared.Profile.Server.Name) } prepared.PinHostKeyOnFirstUse = true } 
prepared.AuthMethods, err = api.buildSSHSessionAuth(store, prepared.Profile, promptedPassword, otpCode) if err != nil { return prepared, "build_auth", err } return prepared, "", nil } func (api *API) validatePreparedSSHSessionForUse(sessionItem *models.SSHSession, prepared *sshSessionPrepared) error { var i int if strings.TrimSpace(prepared.Session.ID) == "" { return errors.New("prepared ssh session is missing session id") } if prepared.Session.ID != sessionItem.ID { return fmt.Errorf("prepared ssh session id mismatch: prepared=%s session=%s", prepared.Session.ID, sessionItem.ID) } if strings.TrimSpace(sessionItem.ServerID) == "" { return errors.New("ssh session server_id is empty") } if strings.TrimSpace(prepared.Session.ServerID) == "" { return errors.New("prepared ssh session server_id is empty") } if prepared.Session.ServerID != sessionItem.ServerID { return fmt.Errorf("prepared ssh session server mismatch: prepared=%s session=%s", prepared.Session.ServerID, sessionItem.ServerID) } if strings.TrimSpace(prepared.Profile.ID) == "" { return errors.New("prepared ssh profile is missing profile id") } if prepared.Profile.ID != sessionItem.ProfileID { return fmt.Errorf("prepared ssh profile mismatch: prepared=%s session=%s", prepared.Profile.ID, sessionItem.ProfileID) } if prepared.Profile.ServerID != sessionItem.ServerID { return fmt.Errorf("prepared ssh profile server mismatch: prepared=%s session=%s", prepared.Profile.ServerID, sessionItem.ServerID) } if prepared.Profile.Server.ID != sessionItem.ServerID { return fmt.Errorf("prepared ssh server mismatch: prepared=%s session=%s", prepared.Profile.Server.ID, sessionItem.ServerID) } if strings.TrimSpace(sessionItem.Host) != "" && prepared.Profile.Server.Host != sessionItem.Host { return fmt.Errorf("prepared ssh target host mismatch: prepared=%s session=%s", prepared.Profile.Server.Host, sessionItem.Host) } if sessionItem.Port > 0 && prepared.Profile.Server.Port != sessionItem.Port { return fmt.Errorf("prepared ssh 
target port mismatch: prepared=%d session=%d", prepared.Profile.Server.Port, sessionItem.Port) } for i = 0; i < len(prepared.HostKeys); i++ { if prepared.HostKeys[i].ServerID != sessionItem.ServerID { return fmt.Errorf("prepared ssh host key server mismatch: prepared=%s session=%s", prepared.HostKeys[i].ServerID, sessionItem.ServerID) } } return nil } func (api *API) serveSSHWorkspaceStream(ws *websocket.Conn, store *db.Store, user models.User, remoteAddr string) { var sendMu sync.Mutex var controlCh chan sshSessionStreamMessage var controlErrCh chan error var attachments map[string]*sshWorkspaceAttachment var msg sshSessionStreamMessage var item models.SSHSession var attachment *sshWorkspaceAttachment var prepared sshSessionPrepared var preparedOK bool var err error var ok bool var attachedID string var attachmentDoneCh chan sshWorkspaceAttachmentDone var attachmentDone sshWorkspaceAttachmentDone var websocketCloseCode SSHSessionCloseCode var closing bool var closeWaitCh <-chan time.Time var closingSessionIDs map[string]bool var closeReason string var closeID string var closedItem models.SSHSession var now int64 defer ws.Close() controlCh = make(chan sshSessionStreamMessage, sshSessionInputBufferSize) controlErrCh = make(chan error, 1) attachmentDoneCh = make(chan sshWorkspaceAttachmentDone, sshSessionInputBufferSize) attachments = map[string]*sshWorkspaceAttachment{} closingSessionIDs = map[string]bool{} go api.readSSHSessionWebsocket(ws, controlCh, controlErrCh) event_loop: for { select { case <-closeWaitCh: api.Logger.Write(logIDSSHBroker, codit_logger.LOG_WARN, "workspace stream close wait timeout requester=%s user=%s remaining_sessions=%d", remoteAddr, user.Username, len(attachments)) break event_loop case err = <-controlErrCh: if closing { continue } closing = true closeWaitCh = time.After(10 * time.Second) websocketCloseCode = sshSessionWebsocketCloseCode(err) closeReason = sshSessionCloseMessage(websocketCloseCode) api.Logger.Write(logIDSSHBroker, 
codit_logger.LOG_WARN, "workspace stream closed requester=%s user=%s err=%v", remoteAddr, user.Username, err) for attachedID, attachment = range attachments { closingSessionIDs[attachedID] = true if api.SSHSessionRegistry != nil { api.SSHSessionRegistry.RequestClose(attachedID, websocketCloseCode) } select { case attachment.InputErrCh <- err: default: // make it non-blocking } } if len(attachments) == 0 { break event_loop } continue case attachmentDone = <-attachmentDoneCh: if attachments[attachmentDone.SessionID] == attachmentDone.Attachment { delete(attachments, attachmentDone.SessionID) } if closing && len(attachments) == 0 { break event_loop } case msg = <-controlCh: if closing { continue } if msg.Type == "attach" { var attachedItem models.SSHSession var attachedSessionID string var attachedPrepared sshSessionPrepared var attachedStatus string var attachedMessage string if strings.TrimSpace(msg.SessionID) == "" { api.sendSSHSessionWS(&sendMu, ws, sshSessionStreamMessage{ SessionID: msg.SessionID, Type: "status", Status: "error", Message: "session_id is required", }) continue } if attachments[msg.SessionID] != nil { continue } // session exists item, err = store.GetSSHSession(msg.SessionID) if err != nil { api.sendSSHSessionWS(&sendMu, ws, sshSessionStreamMessage{ SessionID: msg.SessionID, Type: "status", Status: "error", Message: "ssh session not found", }) continue } if item.UserID != user.ID { api.sendSSHSessionWS(&sendMu, ws, sshSessionStreamMessage{ SessionID: msg.SessionID, Type: "status", Status: "error", Message: "ssh session not allowed", }) continue } prepared, preparedOK = api.SSHPreparedSessionStore.Get(msg.SessionID) if !preparedOK { attachedStatus = "closed" attachedMessage = item.Error if item.Status == "error" { attachedStatus = "error" } if attachedMessage == "" { if item.Status == "closed" { attachedMessage = "SSH session is closed" } else if item.Status == "error" { attachedMessage = "SSH session failed" } else { attachedMessage = "SSH session 
cannot be resumed. Connect again to start a new session." } } api.sendSSHSessionWS(&sendMu, ws, sshSessionStreamMessage{ SessionID: msg.SessionID, Type: "status", Status: attachedStatus, Message: attachedMessage, }) continue } prepared, preparedOK = api.SSHPreparedSessionStore.Take(msg.SessionID) if !preparedOK { api.sendSSHSessionWS(&sendMu, ws, sshSessionStreamMessage{ SessionID: msg.SessionID, Type: "status", Status: "error", Message: "ssh session preparation not found", }) continue } attachedPrepared = prepared attachment = &sshWorkspaceAttachment{ InputCh: make(chan sshSessionStreamMessage, sshSessionInputBufferSize), InputErrCh: make(chan error, 1), } attachments[msg.SessionID] = attachment api.Logger.Write(logIDSSHBroker, codit_logger.LOG_INFO, "workspace attach session=%s requester=%s user=%s", msg.SessionID, remoteAddr, user.Username) attachedItem = item attachedSessionID = attachedItem.ID go func(runUser models.User, runItem models.SSHSession, runPrepared sshSessionPrepared, runSessionID string, runAttachment *sshWorkspaceAttachment) { // accept runUser, runItem, runPrepared by value because the values can mutate // in the main event loop if we reference them directly(closure or pointer) defer func() { select { case attachmentDoneCh <- sshWorkspaceAttachmentDone{SessionID: runSessionID, Attachment: runAttachment}: default: // non-block } }() api.runSSHSessionStream(store, &runUser, &runItem, nil, runAttachment.InputCh, runAttachment.InputErrCh, &runPrepared, func(status sshSessionStreamMessage) { var sendErr error status.SessionID = runSessionID sendErr = api.sendSSHSessionWS(&sendMu, ws, status) if sendErr != nil && api.SSHSessionRegistry != nil { api.SSHSessionRegistry.RequestClose(runSessionID, SSHSessionCloseWebsocketSendFailed) } }, func(data []byte) { var sendErr error sendErr = api.sendSSHWorkspaceOutput(&sendMu, ws, runSessionID, data) if sendErr != nil && api.SSHSessionRegistry != nil { api.SSHSessionRegistry.RequestClose(runSessionID, 
SSHSessionCloseWebsocketSendFailed) } }) }(user, attachedItem, attachedPrepared, attachedSessionID, attachment) continue } attachment, ok = attachments[msg.SessionID] if !ok { if msg.Type == "input" || msg.Type == "resize" || msg.Type == "detach" { continue } api.sendSSHSessionWS(&sendMu, ws, sshSessionStreamMessage{ SessionID: msg.SessionID, Type: "status", Status: "error", Message: "ssh session not attached", }) continue } if msg.Type == "detach" { if api.SSHSessionRegistry != nil { api.SSHSessionRegistry.RequestClose(msg.SessionID, SSHSessionCloseDetached) } select { case attachment.InputErrCh <- io.EOF: default: // non-block } delete(attachments, msg.SessionID) continue } /* other message types. write to the channel */ select { case attachment.InputCh <- msg: default: api.Logger.Write(logIDSSHBroker, codit_logger.LOG_WARN, "workspace input queue full session=%s requester=%s user=%s type=%s", msg.SessionID, remoteAddr, user.Username, msg.Type) if api.SSHSessionRegistry != nil { api.SSHSessionRegistry.RequestClose(msg.SessionID, SSHSessionCloseWebsocketInputQueueFull) } select { case attachment.InputErrCh <- errors.New("websocket input queue full"): default: // non-block } } } } if closing { now = time.Now().UTC().Unix() for closeID = range closingSessionIDs { closedItem, err = api.Store.GetSSHSession(closeID) if err != nil { api.Logger.Write(logIDSSHBroker, codit_logger.LOG_WARN, "workspace stream close status check failed session=%s requester=%s user=%s err=%v", closeID, remoteAddr, user.Username, err) continue } if closedItem.Status == "connected" || closedItem.Status == "connecting" { err = api.Store.UpdateSSHSessionStatus(closedItem.ID, "closed", closedItem.HostKeyFingerprint, closedItem.ConnectedAt, now, closeReason) if err != nil { api.Logger.Write(logIDSSHBroker, codit_logger.LOG_WARN, "workspace stream close status update failed session=%s requester=%s user=%s err=%v", closedItem.ID, remoteAddr, user.Username, err) } } } } } func (api *API) 
runSSHSessionStream(store *db.Store, user *models.User, sessionItem *models.SSHSession, transportWS *websocket.Conn, inputCh <-chan sshSessionStreamMessage, inputErrCh <-chan error, prepared *sshSessionPrepared, sendStatus func(sshSessionStreamMessage), sendBinary func([]byte)) {
	// ssh.Dial's ClientConfig.Timeout only bounds the TCP connect; the SSH
	// handshake (key exchange + auth) gets its own deadline below.
	const sshSessionHandshakeTimeout = 30 * time.Second
	var profile models.SSHAccessProfile
	var authMethods []ssh.AuthMethod
	var hostKeys []models.SSHServerHostKey
	var sshConfig *ssh.ClientConfig
	var targetAddr string
	var tcpConn net.Conn
	var clientConn ssh.Conn
	var clientChans <-chan ssh.NewChannel
	var clientReqs <-chan *ssh.Request
	var sshClient *ssh.Client
	var sshSession *ssh.Session
	var stdin io.WriteCloser
	var stdout io.Reader
	var stderr io.Reader
	var runtimeSession *sshActiveSession
	var err error
	var createErr error
	var connectedFingerprint string
	var connectedHostKey ssh.PublicKey
	var connectedAt int64
	var now int64
	var pinnedHostKey models.SSHServerHostKey
	var reloadedHostKeys []models.SSHServerHostKey
	var i int
	var shellErrCh chan error
	var done bool
	var waitErr error
	var msg sshSessionStreamMessage
	var inputOk bool
	var loopExitSource string
	var promptedPassword string
	var promptedOTPCode string
	var prepareStage string
	var prepareFailMessage string
	var transcriptRecorder *sshSessionTranscriptRecorder
	var transcriptSendBinary func([]byte)
	var promptedAuthInput SSHPromptedAuthInput
	var runtimeStartedAt time.Time
	var stopIfCloseRequested func() bool
	var failConnect func(stage string, stageErr error, message string)

	// stopIfCloseRequested checks whether a close was requested through the
	// session registry (detach, websocket failure, ...). If so it persists the
	// "closed" status, logs it, tells the client unless the transport itself is
	// gone, and returns true so the caller returns immediately.
	stopIfCloseRequested = func() bool {
		var closeReason string
		var closeCode SSHSessionCloseCode
		closeReason = runtimeSession.CloseReason()
		closeCode = runtimeSession.CloseCode()
		if closeReason == "" {
			return false
		}
		now = time.Now().UTC().Unix()
		_ = store.UpdateSSHSessionStatus(sessionItem.ID, "closed", connectedFingerprint, connectedAt, now, closeReason)
		api.logSSHSessionClosed(sessionItem, &profile, user, "closed", closeReason, connectedAt, now)
		if sendStatus != nil && !isSSHSessionTransportCloseCode(closeCode) {
			sendStatus(sshSessionStreamMessage{Type: "status", Status: "closed", Message: closeReason})
		}
		return true
	}
	// failConnect records a failure of a connection-setup stage: it logs the
	// stage, reports message to the client and persists the "error" status.
	failConnect = func(stage string, stageErr error, message string) {
		api.logSSHSessionConnectFailure(sessionItem, &profile, user, stage, stageErr)
		if sendStatus != nil {
			sendStatus(sshSessionStreamMessage{Type: "status", Status: "error", Message: message})
		}
		now = time.Now().UTC().Unix()
		_ = store.UpdateSSHSessionStatus(sessionItem.ID, "error", connectedFingerprint, 0, now, message)
	}

	runtimeStartedAt = time.Now()
	defer func() {
		api.Logger.Write(logIDSSHBroker, codit_logger.LOG_INFO, "session runtime exited session=%s requester=%s user=%s dur_ms=%d", sessionItem.ID, sessionItem.RemoteAddr, user.Username, time.Since(runtimeStartedAt).Milliseconds())
	}()
	runtimeSession = &sshActiveSession{}
	runtimeSession.SetResources(transportWS, nil, nil, nil)
	if api.SSHSessionRegistry != nil {
		err = api.SSHSessionRegistry.Register(sessionItem.ID, runtimeSession)
		if err != nil {
			if sendStatus != nil {
				sendStatus(sshSessionStreamMessage{Type: "status", Status: "error", Message: err.Error()})
			}
			now = time.Now().UTC().Unix()
			_ = store.UpdateSSHSessionStatus(sessionItem.ID, "error", sessionItem.HostKeyFingerprint, sessionItem.ConnectedAt, now, err.Error())
			return
		}
		defer api.SSHSessionRegistry.Unregister(sessionItem.ID, runtimeSession)
	}
	// Prompted credentials are single-use: consume and discard them up front.
	if api.SSHPromptedAuthStore != nil {
		promptedAuthInput = api.SSHPromptedAuthStore.Take(sessionItem.ID)
		api.SSHPromptedAuthStore.Delete(sessionItem.ID)
		promptedPassword = promptedAuthInput.Password
		promptedOTPCode = promptedAuthInput.OTPCode
	}
	if prepared != nil {
		err = api.validatePreparedSSHSessionForUse(sessionItem, prepared)
		if err != nil {
			api.Logger.Write(logIDSSHBroker, codit_logger.LOG_WARN, "prepared session rejected session=%s requester=%s user=%s err=%v", sessionItem.ID, sessionItem.RemoteAddr, user.Username, err)
			if sendStatus != nil {
				sendStatus(sshSessionStreamMessage{Type: "status", Status: "error", Message: err.Error()})
			}
			now = time.Now().UTC().Unix()
			_ = store.UpdateSSHSessionStatus(sessionItem.ID, "error", sessionItem.HostKeyFingerprint, sessionItem.ConnectedAt, now, err.Error())
			return
		}
		profile = prepared.Profile
		hostKeys = prepared.HostKeys
		authMethods = prepared.AuthMethods
		api.logSSHSessionPreparedUse(sessionItem, user, &profile)
	} else {
		prepared = &sshSessionPrepared{}
		*prepared, prepareStage, err = api.prepareSSHSession(store, user, sessionItem, promptedPassword, promptedOTPCode)
		if err != nil {
			if prepareStage != "load_profile" && prepared.Profile.ID != "" {
				api.logSSHSessionConnectFailure(sessionItem, &prepared.Profile, user, prepareStage, err)
			} else {
				api.logSSHSessionPrepareFailure(sessionItem, user, prepareStage, err)
			}
			// Do not leak lookup details when the profile itself is not visible.
			prepareFailMessage = err.Error()
			if prepareStage == "load_profile" {
				prepareFailMessage = "ssh access profile not found"
			}
			if sendStatus != nil {
				sendStatus(sshSessionStreamMessage{Type: "status", Status: "error", Message: prepareFailMessage})
			}
			now = time.Now().UTC().Unix()
			_ = store.UpdateSSHSessionStatus(sessionItem.ID, "error", sessionItem.HostKeyFingerprint, sessionItem.ConnectedAt, now, prepareFailMessage)
			return
		}
		profile = prepared.Profile
		hostKeys = prepared.HostKeys
		authMethods = prepared.AuthMethods
	}
	if stopIfCloseRequested() {
		return
	}
	api.logSSHSessionConnectStart(sessionItem, &profile, user)
	_ = store.UpdateSSHSessionStatus(sessionItem.ID, "connecting", "", 0, 0, "")
	if sendStatus != nil {
		sendStatus(sshSessionStreamMessage{Type: "status", Status: "connecting", Message: "connecting"})
	}
	sshConfig = &ssh.ClientConfig{
		User:            profile.RemoteUsername,
		Auth:            authMethods,
		HostKeyCallback: sshHostKeyCallback(hostKeys, &connectedFingerprint, &connectedHostKey, prepared.PinHostKeyOnFirstUse),
		Timeout:         15 * time.Second,
	}
	// Same steps as ssh.Dial, plus a deadline so a peer that accepts TCP but
	// stalls the handshake cannot block this goroutine forever.
	targetAddr = net.JoinHostPort(profile.Server.Host, fmt.Sprintf("%d", profile.Server.Port))
	tcpConn, err = net.DialTimeout("tcp", targetAddr, sshConfig.Timeout)
	if err != nil {
		failConnect("dial", err, err.Error())
		return
	}
	_ = tcpConn.SetDeadline(time.Now().Add(sshSessionHandshakeTimeout))
	clientConn, clientChans, clientReqs, err = ssh.NewClientConn(tcpConn, targetAddr, sshConfig)
	if err != nil {
		_ = tcpConn.Close()
		failConnect("dial", err, err.Error())
		return
	}
	_ = tcpConn.SetDeadline(time.Time{})
	sshClient = ssh.NewClient(clientConn, clientChans, clientReqs)
	defer sshClient.Close()
	if prepared.PinHostKeyOnFirstUse {
		if connectedHostKey == nil {
			err = errors.New("failed to capture ssh host key")
			failConnect("pin_host_key", err, err.Error())
			return
		}
		// Pin under the server that was actually dialed.
		pinnedHostKey = models.SSHServerHostKey{
			ServerID:    profile.ServerID,
			Algorithm:   connectedHostKey.Type(),
			PublicKey:   strings.TrimSpace(string(ssh.MarshalAuthorizedKey(connectedHostKey))),
			Fingerprint: connectedFingerprint,
		}
		pinnedHostKey, createErr = store.CreateSSHServerHostKey(pinnedHostKey)
		if createErr != nil {
			// Presumably another session pinned the same key first; accept the
			// existing row if its fingerprint matches. createErr is kept separate
			// so the reload's (possibly nil) error cannot mask it.
			pinnedHostKey = models.SSHServerHostKey{}
			reloadedHostKeys, err = store.ListSSHServerHostKeys(profile.ServerID)
			if err == nil {
				for i = 0; i < len(reloadedHostKeys); i++ {
					if strings.TrimSpace(reloadedHostKeys[i].Fingerprint) == connectedFingerprint {
						pinnedHostKey = reloadedHostKeys[i]
						break
					}
				}
			}
			if pinnedHostKey.ID == "" {
				failConnect("pin_host_key", createErr, "failed to pin host key: "+createErr.Error())
				return
			}
		}
		api.Logger.Write(logIDSSHBroker, codit_logger.LOG_INFO, "host key pinned session=%s server=%s server_name=%q fingerprint=%s algorithm=%s", sessionItem.ID, profile.Server.ID, profile.Server.Name, pinnedHostKey.Fingerprint, pinnedHostKey.Algorithm)
	}
	// Each SetResources hands the registry something it can close; re-check for
	// a close request after every step so a detach during setup is honoured.
	runtimeSession.SetResources(transportWS, sshClient, nil, nil)
	if stopIfCloseRequested() {
		return
	}
	sshSession, err = sshClient.NewSession()
	if err != nil {
		if !stopIfCloseRequested() {
			failConnect("new_session", err, err.Error())
		}
		return
	}
	defer sshSession.Close()
	runtimeSession.SetResources(transportWS, sshClient, sshSession, nil)
	if stopIfCloseRequested() {
		return
	}
	stdin, err = sshSession.StdinPipe()
	if err != nil {
		if !stopIfCloseRequested() {
			failConnect("stdin_pipe", err, err.Error())
		}
		return
	}
	stdout, err = sshSession.StdoutPipe()
	if err != nil {
		if !stopIfCloseRequested() {
			failConnect("stdout_pipe", err, err.Error())
		}
		return
	}
	stderr, err = sshSession.StderrPipe()
	if err != nil {
		if !stopIfCloseRequested() {
			failConnect("stderr_pipe", err, err.Error())
		}
		return
	}
	runtimeSession.SetResources(transportWS, sshClient, sshSession, stdin)
	if stopIfCloseRequested() {
		return
	}
	err = sshSession.RequestPty(sessionItem.RequestedTerm, sessionItem.RequestedRows, sessionItem.RequestedCols, ssh.TerminalModes{
		ssh.ECHO:          1,
		ssh.TTY_OP_ISPEED: 14400,
		ssh.TTY_OP_OSPEED: 14400,
	})
	if err != nil {
		if !stopIfCloseRequested() {
			failConnect("request_pty", err, err.Error())
		}
		return
	}
	if stopIfCloseRequested() {
		return
	}
	err = sshSession.Shell()
	if err != nil {
		if !stopIfCloseRequested() {
			failConnect("shell", err, err.Error())
		}
		return
	}
	if stopIfCloseRequested() {
		return
	}
	connectedAt = time.Now().UTC().Unix()
	_ = store.UpdateSSHSessionStatus(sessionItem.ID, "connected", connectedFingerprint, connectedAt, 0, "")
	api.logSSHSessionConnectSuccess(sessionItem, &profile, user, connectedFingerprint)
	if sendStatus != nil {
		sendStatus(sshSessionStreamMessage{Type: "status", Status: "connected", Message: connectedFingerprint})
	}
	// Transcript recording is best-effort: a failure to open it is logged and
	// the session continues.
	transcriptRecorder, err = api.openSSHSessionTranscriptRecorder(sessionItem.ID)
	if err != nil {
		api.Logger.Write(logIDSSHBroker, codit_logger.LOG_WARN, "transcript open failed session=%s err=%v", sessionItem.ID, err)
	}
	if transcriptRecorder != nil {
		defer transcriptRecorder.Close()
	}
	// NOTE(review): called concurrently from the stdout and stderr pipes below;
	// assumes writeSSHSessionTranscriptChunk and sendBinary are safe for that — confirm.
	transcriptSendBinary = func(data []byte) {
		api.writeSSHSessionTranscriptChunk(sessionItem.ID, data, transcriptRecorder)
		if sendBinary != nil {
			sendBinary(data)
		}
	}
	shellErrCh = make(chan error, 1)
	go api.pipeSSHSessionOutput(stdout, transcriptSendBinary) // read stdout and forward it
	go api.pipeSSHSessionOutput(stderr, transcriptSendBinary) // read stderr and forward it
	go func() {
		var shellErr error
		shellErr = sshSession.Wait()
		shellErrCh <- shellErr
	}()
	done = false
	waitErr = nil
	loopExitSource = ""
	for !done {
		select {
		case err = <-inputErrCh:
			waitErr = err
			loopExitSource = "input_err"
			done = true
		case waitErr = <-shellErrCh:
			loopExitSource = "shell_wait"
			done = true
		case msg, inputOk = <-inputCh:
			if !inputOk {
				loopExitSource = "input_closed"
				done = true
				break
			}
			if msg.Type == "input" {
				_, err = io.WriteString(stdin, msg.Data)
				if err != nil {
					waitErr = err
					loopExitSource = "stdin_write"
					done = true
				}
			} else if msg.Type == "resize" {
				_ = sshSession.WindowChange(normalizeSSHSessionRows(msg.Rows), normalizeSSHSessionCols(msg.Cols))
			}
		}
	}
	api.Logger.Write(logIDSSHBroker, codit_logger.LOG_INFO, "session loop ended session=%s requester=%s user=%s profile=%s profile_name=%q source=%s wait_err=%v", sessionItem.ID, sessionItem.RemoteAddr, user.Username, profile.ID, profile.Name, loopExitSource, waitErr)
	_ = stdin.Close()
	now = time.Now().UTC().Unix()
	if stopIfCloseRequested() {
		return
	}
	if waitErr != nil && !isSSHSessionTerminalExit(waitErr) {
		_ = store.UpdateSSHSessionStatus(sessionItem.ID, "error", connectedFingerprint, connectedAt, now, waitErr.Error())
		api.Logger.Write(logIDSSHBroker, codit_logger.LOG_WARN, "session error session=%s requester=%s user=%s profile=%s profile_name=%q server=%s server_name=%q target=%s:%d remote_user=%s auth=%s dur_s=%d err=%v", sessionItem.ID, sessionItem.RemoteAddr, user.Username, profile.ID, profile.Name, profile.ServerID, profile.Server.Name, profile.Server.Host, profile.Server.Port, profile.RemoteUsername, profile.AuthMethod, now-connectedAt, waitErr)
		if sendStatus != nil {
			sendStatus(sshSessionStreamMessage{Type: "status", Status: "error", Message: waitErr.Error()})
		}
		return
	}
	if isSSHSessionTerminalExit(waitErr) {
		api.Logger.Write(logIDSSHBroker, codit_logger.LOG_INFO, "session ended session=%s requester=%s user=%s profile=%s profile_name=%q err=%v", sessionItem.ID, sessionItem.RemoteAddr, user.Username, profile.ID, profile.Name, waitErr)
	}
	_ = store.UpdateSSHSessionStatus(sessionItem.ID, "closed", connectedFingerprint, connectedAt, now, "")
	api.logSSHSessionClosed(sessionItem, &profile, user, "closed", "", connectedAt, now)
	if sendStatus != nil {
		sendStatus(sshSessionStreamMessage{Type: "status", Status: "closed", Message: ""})
	}
}

// buildSSHKeyboardInteractiveAnswers answers keyboard-interactive prompts by
// keyword: prompts mentioning "password" get password; prompts mentioning
// "verification", "authenticator", "otp", "token" or "code" get otpCode. A
// single unrecognized prompt gets otpCode if set, else password if set; every
// other unrecognized prompt gets "". One answer is produced per question.
func buildSSHKeyboardInteractiveAnswers(password string, otpCode string, questions []string) []string {
	var answers []string
	var prompt string
	var i int
	for i = 0; i < len(questions); i++ {
		prompt = strings.ToLower(strings.TrimSpace(questions[i]))
		if strings.Contains(prompt, "password") {
			answers = append(answers, password)
			continue
		}
		if strings.Contains(prompt, "verification") || strings.Contains(prompt, "authenticator") || strings.Contains(prompt, "otp") || strings.Contains(prompt, "token") || strings.Contains(prompt, "code") {
			answers = append(answers, otpCode)
			continue
		}
		if otpCode != "" && len(questions) == 1 {
			answers = append(answers, otpCode)
			continue
		}
		if password != "" && len(questions) == 1 {
			answers = append(answers, password)
			continue
		}
		answers = append(answers, "")
	}
	return answers
}

// isSSHSessionTerminalExit reports whether err is an ordinary end of a remote
// shell: io.EOF, an *ssh.ExitError (exit status) or an *ssh.ExitMissingError.
// A nil error is not considered a terminal exit.
func isSSHSessionTerminalExit(err error) bool {
	var exitErr *ssh.ExitError
	var missingErr *ssh.ExitMissingError
	if err == nil {
		return false
	}
	if errors.Is(err, io.EOF) {
		return true
	}
	if errors.As(err, &exitErr) {
		return true
	}
	if errors.As(err, &missingErr) {
		return true
	}
	return false
}

// buildSSHPromptedTOTPAuthMethod returns a keyboard-interactive auth method
// that answers server prompts via buildSSHKeyboardInteractiveAnswers.
func buildSSHPromptedTOTPAuthMethod(password string, otpCode string) ssh.AuthMethod {
	var challenge ssh.KeyboardInteractiveChallenge
	challenge = func(user string, instruction string, questions []string, echos []bool) ([]string, error) {
		return buildSSHKeyboardInteractiveAnswers(password, otpCode, questions), nil
	}
	return ssh.KeyboardInteractive(challenge)
}

// appendSSHSecondFactorAuthMethods appends the profile's second-factor auth
// method to methods. sshSecondFactorNone leaves methods unchanged;
// sshSecondFactorPromptedTOTP requires a non-blank otpCode and appends a
// keyboard-interactive method; any other mode is an error.
func appendSSHSecondFactorAuthMethods(methods []ssh.AuthMethod, profile models.SSHAccessProfile, password string, otpCode string) ([]ssh.AuthMethod, error) {
	if profile.SecondFactorMode == sshSecondFactorNone {
		return methods, nil
	}
	if profile.SecondFactorMode != sshSecondFactorPromptedTOTP {
		return nil, fmt.Errorf("unsupported second_factor_mode: %s", profile.SecondFactorMode)
	}
	if strings.TrimSpace(otpCode) == "" {
		return nil, fmt.Errorf("otp_code is required for %s", sshSecondFactorPromptedTOTP)
	}
	methods = append(methods, buildSSHPromptedTOTPAuthMethod(password, otpCode))
	return methods, nil
}

func (api *API) buildSSHSessionAuth(store *db.Store, profile models.SSHAccessProfile, promptedPassword string, otpCode string) ([]ssh.AuthMethod, error) {
	var methods []ssh.AuthMethod
	var secret models.SSHSecret
	var credential models.SSHCredential
	var privateKeyPEM string
	var signer ssh.Signer
	var ca models.SSHUserCA
	var signed sshSignUserKeyResponse
	var certKey ssh.PublicKey
	var cert *ssh.Certificate
	var certSigner ssh.Signer
	var principals []string
	var err error
	if profile.AuthMethod == sshAuthMethodPromptedPassword {
		if promptedPassword == "" {
			return nil, fmt.Errorf("password is required for %s", sshAuthMethodPromptedPassword)
		}
		methods = []ssh.AuthMethod{ssh.Password(promptedPassword)}
		return appendSSHSecondFactorAuthMethods(methods, profile, promptedPassword, otpCode)
	}
	if strings.TrimSpace(profile.SSHCredentialID) != "" {
		credential, err = store.GetSSHCredential(profile.SSHCredentialID)
		if err != nil {
			return nil, err
		}
		if !credential.Enabled {
			return nil, errors.New("ssh credential is disabled")
		}
		profile.SecretID = credential.SecretID
	}
	if strings.TrimSpace(profile.SecretID) == "" {
		return nil, errors.New("ssh access profile has no secret")
	}
	secret, err = store.GetSSHSecret(profile.SecretID)
	if err != nil {
		return nil, err
	}
	if profile.AuthMethod == sshAuthMethodStoredPassword {
		methods = []ssh.AuthMethod{ssh.Password(secret.Password)}
		return appendSSHSecondFactorAuthMethods(methods, profile, secret.Password, otpCode)
	}
	privateKeyPEM, err = api.decryptSSHSecretPayload(secret.Payload)
	if err != nil {
		return nil, err
	}
	signer, err = parseSSHSignerFromPEM(privateKeyPEM)
	if err != nil {
		return nil, err
	}
	if profile.AuthMethod == sshAuthMethodStoredPrivateKey {
		methods = []ssh.AuthMethod{ssh.PublicKeys(signer)}
		return appendSSHSecondFactorAuthMethods(methods, profile, promptedPassword, otpCode)
	}
	ca, err = store.GetSSHUserCA(profile.SSHUserCAID)
	if err != nil {
		return nil, err
	}
	principals, err = api.resolveSSHAccessProfilePrincipals(store, profile)
	if err != nil {
		return nil, err
	}
	signed, err = api.signSSHUserKeyWithCA(store, ca, profile.AuthPublicKey, "", principals, profile.DefaultValidSeconds, profile.MaxValidSeconds)
	if err != nil {
		return nil, err
	}
	certKey, _, _, _, err = ssh.ParseAuthorizedKey([]byte(strings.TrimSpace(signed.Certificate)))
	if err != nil {
		return nil, err
	}
	cert, _ = certKey.(*ssh.Certificate)
	if cert == nil {
		return nil, errors.New("signed ssh certificate is invalid")
	}
	certSigner, err =
ssh.NewCertSigner(cert, signer) if err != nil { return nil, err } methods = []ssh.AuthMethod{ssh.PublicKeys(certSigner)} return appendSSHSecondFactorAuthMethods(methods, profile, promptedPassword, otpCode) } func (api *API) resolveSSHAccessProfilePrincipals(store *db.Store, profile models.SSHAccessProfile) ([]string, error) { var principals []string var grant models.SSHPrincipalGrant var seen map[string]bool var i int var j int var err error seen = map[string]bool{} for i = 0; i < len(profile.SSHPrincipalGrantIDs); i++ { grant, err = store.GetSSHPrincipalGrant(profile.SSHPrincipalGrantIDs[i]) if err != nil { return nil, err } for j = 0; j < len(grant.Principals); j++ { if seen[grant.Principals[j]] { continue} seen[grant.Principals[j]] = true principals = append(principals, grant.Principals[j]) } } sort.Strings(principals) if len(principals) == 0 { return nil, errors.New("no ssh principals resolved for access profile") } return principals, nil } func sshHostKeyCallback(items []models.SSHServerHostKey, connectedFingerprint *string, connectedHostKey *ssh.PublicKey, trustOnFirstUse bool) ssh.HostKeyCallback { var allowed map[string]bool var i int allowed = map[string]bool{} for i = 0; i < len(items); i++ { allowed[strings.TrimSpace(items[i].Fingerprint)] = true } return func(hostname string, remote net.Addr, key ssh.PublicKey) error { var fingerprint string fingerprint = strings.TrimSpace(ssh.FingerprintSHA256(key)) if connectedFingerprint != nil { *connectedFingerprint = fingerprint } if connectedHostKey != nil { *connectedHostKey = key } if allowed[fingerprint] { return nil} if trustOnFirstUse && len(items) == 0 { return nil } return errors.New("host key mismatch: " + fingerprint) } } func (api *API) pipeSSHSessionOutput(reader io.Reader, sendBinary func([]byte)) { var buffer []byte var chunk []byte var n int var err error buffer = make([]byte, 4096) for { n, err = reader.Read(buffer) if n > 0 { chunk = make([]byte, n) copy(chunk, buffer[:n]) if sendBinary != nil { 
sendBinary(chunk) } } if err != nil { return } } } func (api *API) readSSHSessionWebsocket(ws *websocket.Conn, inputCh chan<- sshSessionStreamMessage, errCh chan<- error) { var msg sshSessionStreamMessage var err error for { msg = sshSessionStreamMessage{} err = websocket.JSON.Receive(ws, &msg) if err != nil { errCh <- err return } select { case inputCh <- msg: default: errCh <- errors.New("websocket input queue full") return } } } func sshSessionWebsocketCloseCode(err error) SSHSessionCloseCode { if err != nil && strings.Contains(err.Error(), "input queue full") { return SSHSessionCloseWebsocketInputQueueFull } return SSHSessionCloseWebsocketClosed } func isSSHSessionTransportCloseCode(code SSHSessionCloseCode) bool { if code == SSHSessionCloseWebsocketClosed { return true } if code == SSHSessionCloseWebsocketSendFailed { return true } if code == SSHSessionCloseWebsocketInputQueueFull { return true } if code == SSHSessionCloseDetached { return true } return false } func (api *API) sendSSHSessionWS(mu *sync.Mutex, ws *websocket.Conn, msg sshSessionStreamMessage) error { var err error if mu == nil || ws == nil { return nil } mu.Lock() defer mu.Unlock() _ = ws.SetWriteDeadline(time.Now().Add(sshWebsocketWriteTimeout)) err = websocket.JSON.Send(ws, msg) _ = ws.SetWriteDeadline(time.Time{}) if err != nil { api.Logger.Write(logIDSSHBroker, codit_logger.LOG_WARN, "websocket json send failed session=%s type=%s err=%v", msg.SessionID, msg.Type, err) } return err } func (api *API) sendSSHSessionBinary(mu *sync.Mutex, ws *websocket.Conn, data []byte) error { var err error if mu == nil || ws == nil { return nil } mu.Lock() defer mu.Unlock() _ = ws.SetWriteDeadline(time.Now().Add(sshWebsocketWriteTimeout)) err = websocket.Message.Send(ws, data) _ = ws.SetWriteDeadline(time.Time{}) if err != nil { api.Logger.Write(logIDSSHBroker, codit_logger.LOG_WARN, "websocket binary send failed bytes=%d err=%v", len(data), err) } return err } func (api *API) sendSSHWorkspaceOutput(mu 
*sync.Mutex, ws *websocket.Conn, sessionID string, data []byte) error { var err error err = api.sendSSHSessionWS(mu, ws, sshSessionStreamMessage{ SessionID: sessionID, Type: "output", Data: base64.StdEncoding.EncodeToString(data), }) return err }