Files
codit/backend/internal/handlers/ssh_broker.go

1999 lines
62 KiB
Go

package handlers
import "database/sql"
import "errors"
import "fmt"
import "io"
import "net"
import "net/http"
import "sort"
import "strings"
import "sync"
import "time"
import "golang.org/x/crypto/ssh"
import "golang.org/x/net/websocket"
import "codit/internal/db"
import "codit/internal/middleware"
import "codit/internal/models"
import "codit/internal/util"
// logIDSSHBroker is the log source identifier passed to api.Logger.Write by
// the SSH broker logging helpers.
const logIDSSHBroker string = "ssh-broker"
// sshServerRequest is the JSON body accepted by the admin and self-service
// SSH server create/update handlers. Those handlers trim Name, Host and
// Description, require a non-empty Name and Host, and require Port to be in
// the range 1..65535.
type sshServerRequest struct {
Name string `json:"name"`
Host string `json:"host"`
Port int `json:"port"`
Description string `json:"description"`
// Tags are run through normalizeSSHBrokerTags before being stored.
Tags []string `json:"tags"`
Enabled bool `json:"enabled"`
}
// sshAccessProfileTargetRequest identifies one entry of the Targets list in
// an sshAccessProfileRequest.
// NOTE(review): the accepted TargetType values and how TargetID is resolved
// are not visible here — confirm in normalizeSSHAccessProfileRequest.
type sshAccessProfileTargetRequest struct {
TargetType string `json:"target_type"`
TargetID string `json:"target_id"`
}
// sshAccessProfileRequest is the JSON body accepted by the admin and
// self-service access profile create/update handlers. Those handlers convert
// it into a models.SSHAccessProfile via normalizeSSHAccessProfileRequest.
type sshAccessProfileRequest struct {
ServerID string `json:"server_id"`
Name string `json:"name"`
Description string `json:"description"`
RemoteUsername string `json:"remote_username"`
AuthMethod string `json:"auth_method"`
Enabled bool `json:"enabled"`
// NOTE(review): presumably the private key used for key-based auth; this is
// secret material — confirm it is never echoed back in responses or logs.
PrivateKeyPEM string `json:"private_key_pem"`
// NOTE(review): the SSH* fields below presumably configure certificate-based
// auth (user CA, principal selection); their validation is not visible here.
SSHUserCAID string `json:"ssh_user_ca_id"`
SSHPrincipalMode string `json:"ssh_principal_mode"`
SSHPrincipals []string `json:"ssh_principals"`
SSHPrincipalGrantIDs []string `json:"ssh_principal_grant_ids"`
// NOTE(review): presumably certificate validity bounds in seconds — confirm.
DefaultValidSeconds int64 `json:"default_valid_seconds"`
MaxValidSeconds int64 `json:"max_valid_seconds"`
// Admin handlers normalize with RequireTargets: true and self-service
// handlers with RequireTargets: false.
Targets []sshAccessProfileTargetRequest `json:"targets"`
}
// sshAccessProfileNormalizeOptions tells normalizeSSHAccessProfileRequest how
// to build a profile. The admin handlers pass OwnerScope "admin_shared" with
// RequireTargets set. The self-service handlers pass OwnerScope "user", the
// caller's ID as OwnerUserID, and set AllowUserEdit and SelfService.
type sshAccessProfileNormalizeOptions struct {
OwnerScope string
OwnerUserID string
AllowUserEdit bool
RequireTargets bool
SelfService bool
}
// sshServerHostKeyRequest is the JSON body for pinning a host key on a
// server. The handlers trim PublicKey and parse it with
// parseSSHServerHostKey.
type sshServerHostKeyRequest struct {
PublicKey string `json:"public_key"`
}
// sshSessionConnectRequest is the JSON body for CreateSSHSessionForSelf.
// Term, Cols and Rows are passed through normalizeSSHSessionTerm,
// normalizeSSHSessionCols and normalizeSSHSessionRows before being stored on
// the session.
type sshSessionConnectRequest struct {
Cols int `json:"cols"`
Rows int `json:"rows"`
// NOTE(review): ValidSeconds is decoded but not read by
// CreateSSHSessionForSelf — confirm whether it is meant to be honored.
ValidSeconds int64 `json:"valid_seconds"`
Term string `json:"term"`
}
// sshSessionConnectResponse is the 201 body returned by
// CreateSSHSessionForSelf. WebSocketPath has the form
// "/api/ssh/sessions/<session id>/stream". HostKeyFingerprint is the
// fingerprint of the first pinned host key listed for the target server.
type sshSessionConnectResponse struct {
SessionID string `json:"session_id"`
Status string `json:"status"`
WebSocketPath string `json:"websocket_path"`
ServerName string `json:"server_name"`
Host string `json:"host"`
Port int `json:"port"`
RemoteUsername string `json:"remote_username"`
HostKeyFingerprint string `json:"host_key_fingerprint"`
}
// sshSessionStreamMessage is a JSON message type for the SSH session stream.
// NOTE(review): presumably the frame exchanged over the session WebSocket,
// with Type selecting which of the other fields are meaningful — confirm
// against the stream handler.
type sshSessionStreamMessage struct {
Type string `json:"type"`
Data string `json:"data"`
Cols int `json:"cols"`
Rows int `json:"rows"`
Status string `json:"status"`
Message string `json:"message"`
}
// logSSHSessionConnectStart writes an info-level "connect start" line.
// The line covers the session, requester address, acting user, access
// profile, target server and remote login settings.
func (api *API) logSSHSessionConnectStart(sessionItem models.SSHSession, profile models.SSHAccessProfile, user models.User) {
	const format = "connect start session=%s requester=%s user=%s profile=%s profile_name=%q server=%s server_name=%q target=%s:%d remote_user=%s auth=%s"
	srv := profile.Server
	api.Logger.Write(logIDSSHBroker, util.LOG_INFO, format,
		sessionItem.ID, sessionItem.RemoteAddr, user.Username,
		profile.ID, profile.Name, profile.ServerID,
		srv.Name, srv.Host, srv.Port,
		profile.RemoteUsername, profile.AuthMethod)
}
// logSSHSessionPrepare writes an info-level "connect prepare" line.
// It takes a bare session ID and requester address rather than a session
// record.
func (api *API) logSSHSessionPrepare(sessionID string, remoteAddr string, user models.User, profile models.SSHAccessProfile) {
	const format = "connect prepare session=%s requester=%s user=%s profile=%s profile_name=%q server=%s server_name=%q target=%s:%d remote_user=%s auth=%s"
	srv := profile.Server
	api.Logger.Write(logIDSSHBroker, util.LOG_INFO, format,
		sessionID, remoteAddr, user.Username,
		profile.ID, profile.Name, profile.ServerID,
		srv.Name, srv.Host, srv.Port,
		profile.RemoteUsername, profile.AuthMethod)
}
// logSSHSessionConnectSuccess writes an info-level "connect success" line.
// It includes the host key fingerprint supplied by the caller.
func (api *API) logSSHSessionConnectSuccess(sessionItem models.SSHSession, profile models.SSHAccessProfile, user models.User, hostKeyFingerprint string) {
	const format = "connect success session=%s requester=%s user=%s profile=%s profile_name=%q server=%s server_name=%q target=%s:%d remote_user=%s auth=%s host_key=%s"
	srv := profile.Server
	api.Logger.Write(logIDSSHBroker, util.LOG_INFO, format,
		sessionItem.ID, sessionItem.RemoteAddr, user.Username,
		profile.ID, profile.Name, profile.ServerID,
		srv.Name, srv.Host, srv.Port,
		profile.RemoteUsername, profile.AuthMethod,
		hostKeyFingerprint)
}
// logSSHSessionConnectFailure writes a warn-level "connect failed" line
// with full session and profile context. It also records the stage that
// failed and the error.
func (api *API) logSSHSessionConnectFailure(sessionItem models.SSHSession, profile models.SSHAccessProfile, user models.User, stage string, err error) {
	const format = "connect failed session=%s requester=%s user=%s profile=%s profile_name=%q server=%s server_name=%q target=%s:%d remote_user=%s auth=%s stage=%s err=%v"
	srv := profile.Server
	api.Logger.Write(logIDSSHBroker, util.LOG_WARN, format,
		sessionItem.ID, sessionItem.RemoteAddr, user.Username,
		profile.ID, profile.Name, profile.ServerID,
		srv.Name, srv.Host, srv.Port,
		profile.RemoteUsername, profile.AuthMethod,
		stage, err)
}
// logSSHSessionPrepareFailure writes a warn-level "connect failed" line for
// cases where only identifiers are available. It omits the full profile and
// server details.
func (api *API) logSSHSessionPrepareFailure(sessionID string, remoteAddr string, username string, profileID string, stage string, err error) {
	const format = "connect failed session=%s requester=%s user=%s profile=%s stage=%s err=%v"
	api.Logger.Write(logIDSSHBroker, util.LOG_WARN, format,
		sessionID, remoteAddr, username, profileID, stage, err)
}
// logSSHSessionClosed writes an info-level "session closed" line.
// It includes the final status, the close reason and a duration in seconds.
// The duration is endedAt-connectedAt when connectedAt is positive and
// endedAt is not earlier than it, and 0 otherwise.
func (api *API) logSSHSessionClosed(sessionItem models.SSHSession, profile models.SSHAccessProfile, user models.User, status string, reason string, connectedAt int64, endedAt int64) {
	const format = "session closed session=%s requester=%s user=%s profile=%s profile_name=%q server=%s server_name=%q target=%s:%d remote_user=%s auth=%s status=%s dur_s=%d reason=%q"
	var elapsed int64
	if connectedAt > 0 && endedAt >= connectedAt {
		elapsed = endedAt - connectedAt
	}
	srv := profile.Server
	api.Logger.Write(logIDSSHBroker, util.LOG_INFO, format,
		sessionItem.ID, sessionItem.RemoteAddr, user.Username,
		profile.ID, profile.Name, profile.ServerID,
		srv.Name, srv.Host, srv.Port,
		profile.RemoteUsername, profile.AuthMethod,
		status, elapsed, reason)
}
// ListSSHServersAdmin responds with every SSH server returned by
// ListSSHServers. It returns early if requireAdmin rejects the caller.
//
// Responses: 200 with the servers, 500 on a store error.
func (api *API) ListSSHServersAdmin(w http.ResponseWriter, r *http.Request, _ map[string]string) {
	if !api.requireAdmin(w, r) {
		return
	}
	servers, err := api.store(r).ListSSHServers()
	if err != nil {
		WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
		return
	}
	WriteJSON(w, http.StatusOK, servers)
}
// GetSSHServerAdmin writes the SSH server identified by the "id" route
// parameter as JSON. It returns early if requireAdmin rejects the caller.
//
// Responses: 200 with the server, 404 when the store reports sql.ErrNoRows,
// 500 for any other store error.
func (api *API) GetSSHServerAdmin(w http.ResponseWriter, r *http.Request, params map[string]string) {
	var item models.SSHServer
	var err error
	if !api.requireAdmin(w, r) {
		return
	}
	item, err = api.store(r).GetSSHServer(params["id"])
	if err != nil {
		// errors.Is also recognises sql.ErrNoRows when the store wraps it.
		if errors.Is(err, sql.ErrNoRows) {
			WriteJSON(w, http.StatusNotFound, map[string]string{"error": "ssh server not found"})
			return
		}
		WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
		return
	}
	WriteJSON(w, http.StatusOK, item)
}
// CreateSSHServerAdmin registers a new SSH server from an sshServerRequest
// body. The acting subject reported by currentSSHBrokerActor is recorded as
// the creator. It returns early if requireAdmin rejects the caller.
//
// Responses: 201 with the stored server; 400 for a malformed body, an empty
// name or host, an out-of-range port, or a store error.
func (api *API) CreateSSHServerAdmin(w http.ResponseWriter, r *http.Request, _ map[string]string) {
	if !api.requireAdmin(w, r) {
		return
	}
	var req sshServerRequest
	if err := DecodeJSON(r, &req); err != nil {
		WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid request"})
		return
	}
	name := strings.TrimSpace(req.Name)
	host := strings.TrimSpace(req.Host)
	description := strings.TrimSpace(req.Description)

	// The first failing rule, in this order, is the one reported.
	var problem string
	switch {
	case name == "":
		problem = "name is required"
	case host == "":
		problem = "host is required"
	case req.Port <= 0 || req.Port > 65535:
		problem = "port must be between 1 and 65535"
	}
	if problem != "" {
		WriteJSON(w, http.StatusBadRequest, map[string]string{"error": problem})
		return
	}

	kind, subjectID, subjectName := currentSSHBrokerActor(r)
	created, err := api.store(r).CreateSSHServer(models.SSHServer{
		Name:                 name,
		Host:                 host,
		Port:                 req.Port,
		Description:          description,
		Tags:                 normalizeSSHBrokerTags(req.Tags),
		Enabled:              req.Enabled,
		CreatedByKind:        kind,
		CreatedBySubjectID:   subjectID,
		CreatedBySubjectName: subjectName,
	})
	if err != nil {
		WriteJSON(w, http.StatusBadRequest, map[string]string{"error": err.Error()})
		return
	}
	WriteJSON(w, http.StatusCreated, created)
}
// UpdateSSHServerAdmin overwrites the server identified by the "id" route
// parameter. It replaces name, host, port, description, tags and enabled
// from an sshServerRequest body. Fields not carried by the request keep
// their stored values. It returns early if requireAdmin rejects the caller.
//
// Responses: 200 with the updated server; 400 for a malformed body, failed
// validation, or a store update error; 404 when the server does not exist;
// 500 when loading the server fails for another reason.
func (api *API) UpdateSSHServerAdmin(w http.ResponseWriter, r *http.Request, params map[string]string) {
	var req sshServerRequest
	var item models.SSHServer
	var err error
	if !api.requireAdmin(w, r) {
		return
	}
	item, err = api.store(r).GetSSHServer(params["id"])
	if err != nil {
		// errors.Is also recognises sql.ErrNoRows when the store wraps it.
		if errors.Is(err, sql.ErrNoRows) {
			WriteJSON(w, http.StatusNotFound, map[string]string{"error": "ssh server not found"})
			return
		}
		WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
		return
	}
	err = DecodeJSON(r, &req)
	if err != nil {
		WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid request"})
		return
	}
	req.Name = strings.TrimSpace(req.Name)
	req.Host = strings.TrimSpace(req.Host)
	req.Description = strings.TrimSpace(req.Description)
	if req.Name == "" {
		WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "name is required"})
		return
	}
	if req.Host == "" {
		WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "host is required"})
		return
	}
	if req.Port <= 0 || req.Port > 65535 {
		WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "port must be between 1 and 65535"})
		return
	}
	item.Name = req.Name
	item.Host = req.Host
	item.Port = req.Port
	item.Description = req.Description
	item.Tags = normalizeSSHBrokerTags(req.Tags)
	item.Enabled = req.Enabled
	item, err = api.store(r).UpdateSSHServer(item)
	if err != nil {
		WriteJSON(w, http.StatusBadRequest, map[string]string{"error": err.Error()})
		return
	}
	WriteJSON(w, http.StatusOK, item)
}
// DeleteSSHServerAdmin removes the SSH server identified by the "id" route
// parameter. It returns early if requireAdmin rejects the caller.
//
// Responses: 204 on success, 404 when the store reports sql.ErrNoRows
// (matching how access profile deletes are reported), 500 for any other
// store error.
func (api *API) DeleteSSHServerAdmin(w http.ResponseWriter, r *http.Request, params map[string]string) {
	var err error
	if !api.requireAdmin(w, r) {
		return
	}
	err = api.store(r).DeleteSSHServer(params["id"])
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			WriteJSON(w, http.StatusNotFound, map[string]string{"error": "ssh server not found"})
			return
		}
		WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
		return
	}
	w.WriteHeader(http.StatusNoContent)
}
// ListSSHAccessProfilesAdmin responds with every access profile returned by
// ListSSHAccessProfiles. It returns early if requireAdmin rejects the caller.
//
// Responses: 200 with the profiles, 500 on a store error.
func (api *API) ListSSHAccessProfilesAdmin(w http.ResponseWriter, r *http.Request, _ map[string]string) {
	if !api.requireAdmin(w, r) {
		return
	}
	profiles, err := api.store(r).ListSSHAccessProfiles()
	if err != nil {
		WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
		return
	}
	WriteJSON(w, http.StatusOK, profiles)
}
// GetSSHAccessProfileAdmin writes the access profile identified by the "id"
// route parameter as JSON. It returns early if requireAdmin rejects the
// caller.
//
// Responses: 200 with the profile, 404 when the store reports
// sql.ErrNoRows, 500 for any other store error.
func (api *API) GetSSHAccessProfileAdmin(w http.ResponseWriter, r *http.Request, params map[string]string) {
	var item models.SSHAccessProfile
	var err error
	if !api.requireAdmin(w, r) {
		return
	}
	item, err = api.store(r).GetSSHAccessProfile(params["id"])
	if err != nil {
		// errors.Is also recognises sql.ErrNoRows when the store wraps it.
		if errors.Is(err, sql.ErrNoRows) {
			WriteJSON(w, http.StatusNotFound, map[string]string{"error": "ssh access profile not found"})
			return
		}
		WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
		return
	}
	WriteJSON(w, http.StatusOK, item)
}
// CreateSSHAccessProfileAdmin creates an admin-shared access profile from an
// sshAccessProfileRequest body. The acting subject comes from
// currentSSHBrokerActor. It returns early if requireAdmin rejects the caller.
//
// Responses: 201 with the stored profile; 400 for a malformed body, a
// normalization failure, or a store error.
func (api *API) CreateSSHAccessProfileAdmin(w http.ResponseWriter, r *http.Request, _ map[string]string) {
	if !api.requireAdmin(w, r) {
		return
	}
	var req sshAccessProfileRequest
	if err := DecodeJSON(r, &req); err != nil {
		WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid request"})
		return
	}
	kind, subjectID, subjectName := currentSSHBrokerActor(r)
	// Unlisted options (OwnerUserID, AllowUserEdit, SelfService) stay at their
	// zero values.
	opts := sshAccessProfileNormalizeOptions{
		OwnerScope:     "admin_shared",
		RequireTargets: true,
	}
	profile, err := api.normalizeSSHAccessProfileRequest(r, req, models.SSHAccessProfile{}, kind, subjectID, subjectName, false, opts)
	if err != nil {
		WriteJSON(w, http.StatusBadRequest, map[string]string{"error": err.Error()})
		return
	}
	if profile, err = api.store(r).CreateSSHAccessProfile(profile); err != nil {
		WriteJSON(w, http.StatusBadRequest, map[string]string{"error": err.Error()})
		return
	}
	WriteJSON(w, http.StatusCreated, profile)
}
// UpdateSSHAccessProfileAdmin rewrites the access profile identified by the
// "id" route parameter. The profile is rebuilt from the stored profile plus
// an sshAccessProfileRequest body, using admin-shared normalization options.
// It returns early if requireAdmin rejects the caller.
//
// Responses: 200 with the updated profile; 400 for a malformed body, a
// normalization failure, or a store update error; 404 when the profile does
// not exist; 500 when loading it fails for another reason.
func (api *API) UpdateSSHAccessProfileAdmin(w http.ResponseWriter, r *http.Request, params map[string]string) {
	var req sshAccessProfileRequest
	var existing models.SSHAccessProfile
	var item models.SSHAccessProfile
	var actorKind string
	var actorID string
	var actorName string
	var options sshAccessProfileNormalizeOptions
	var err error
	if !api.requireAdmin(w, r) {
		return
	}
	existing, err = api.store(r).GetSSHAccessProfile(params["id"])
	if err != nil {
		// errors.Is also recognises sql.ErrNoRows when the store wraps it.
		if errors.Is(err, sql.ErrNoRows) {
			WriteJSON(w, http.StatusNotFound, map[string]string{"error": "ssh access profile not found"})
			return
		}
		WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
		return
	}
	err = DecodeJSON(r, &req)
	if err != nil {
		WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid request"})
		return
	}
	actorKind, actorID, actorName = currentSSHBrokerActor(r)
	options = sshAccessProfileNormalizeOptions{
		OwnerScope:     "admin_shared",
		OwnerUserID:    "",
		AllowUserEdit:  false,
		RequireTargets: true,
		SelfService:    false,
	}
	item, err = api.normalizeSSHAccessProfileRequest(r, req, existing, actorKind, actorID, actorName, true, options)
	if err != nil {
		WriteJSON(w, http.StatusBadRequest, map[string]string{"error": err.Error()})
		return
	}
	item, err = api.store(r).UpdateSSHAccessProfile(item)
	if err != nil {
		WriteJSON(w, http.StatusBadRequest, map[string]string{"error": err.Error()})
		return
	}
	WriteJSON(w, http.StatusOK, item)
}
// DeleteSSHAccessProfileAdmin removes the access profile identified by the
// "id" route parameter. It returns early if requireAdmin rejects the caller.
//
// Responses: 204 on success, 404 when the store reports sql.ErrNoRows, 500
// for any other store error.
func (api *API) DeleteSSHAccessProfileAdmin(w http.ResponseWriter, r *http.Request, params map[string]string) {
	var err error
	if !api.requireAdmin(w, r) {
		return
	}
	err = api.store(r).DeleteSSHAccessProfile(params["id"])
	if err != nil {
		// errors.Is also recognises sql.ErrNoRows when the store wraps it.
		if errors.Is(err, sql.ErrNoRows) {
			WriteJSON(w, http.StatusNotFound, map[string]string{"error": "ssh access profile not found"})
			return
		}
		WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
		return
	}
	w.WriteHeader(http.StatusNoContent)
}
// ListSSHAccessProfilesForSelf responds with the access profiles that
// ListSSHAccessProfilesForUser returns for the authenticated user.
//
// Responses: 401 for a missing or disabled user, 500 on a store error, 200
// with the profiles.
func (api *API) ListSSHAccessProfilesForSelf(w http.ResponseWriter, r *http.Request, _ map[string]string) {
	user, ok := middleware.UserFromContext(r.Context())
	if !ok || user.Disabled {
		w.WriteHeader(http.StatusUnauthorized)
		return
	}
	profiles, err := api.store(r).ListSSHAccessProfilesForUser(user.ID)
	if err != nil {
		WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
		return
	}
	WriteJSON(w, http.StatusOK, profiles)
}
// GetSSHAccessProfileForSelf writes, as JSON, the access profile named by
// the "id" route parameter, provided GetSSHAccessProfileForUser returns it
// for the authenticated user.
//
// Responses: 401 for a missing or disabled user, 404 when the store reports
// sql.ErrNoRows, 500 for other store errors, 200 with the profile.
func (api *API) GetSSHAccessProfileForSelf(w http.ResponseWriter, r *http.Request, params map[string]string) {
	var user models.User
	var ok bool
	var item models.SSHAccessProfile
	var err error
	user, ok = middleware.UserFromContext(r.Context())
	if !ok || user.Disabled {
		w.WriteHeader(http.StatusUnauthorized)
		return
	}
	item, err = api.store(r).GetSSHAccessProfileForUser(user.ID, params["id"])
	if err != nil {
		// errors.Is also recognises sql.ErrNoRows when the store wraps it.
		if errors.Is(err, sql.ErrNoRows) {
			WriteJSON(w, http.StatusNotFound, map[string]string{"error": "ssh access profile not found"})
			return
		}
		WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
		return
	}
	WriteJSON(w, http.StatusOK, item)
}
// ListSSHServersForSelf responds with the SSH servers that
// ListSSHServersForUser returns for the authenticated user.
//
// Responses: 401 for a missing or disabled user, 500 on a store error, 200
// with the servers.
func (api *API) ListSSHServersForSelf(w http.ResponseWriter, r *http.Request, _ map[string]string) {
	user, ok := middleware.UserFromContext(r.Context())
	if !ok || user.Disabled {
		w.WriteHeader(http.StatusUnauthorized)
		return
	}
	servers, err := api.store(r).ListSSHServersForUser(user.ID)
	if err != nil {
		WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
		return
	}
	WriteJSON(w, http.StatusOK, servers)
}
// CreateSSHServerForSelf registers an SSH server owned by the authenticated
// user, recording the user as creator (kind "user"). The response is marked
// Editable.
//
// Responses: 401 for a missing or disabled user; 400 for a malformed body,
// an empty name or host, an out-of-range port, or a store error; 201 with
// the stored server.
func (api *API) CreateSSHServerForSelf(w http.ResponseWriter, r *http.Request, _ map[string]string) {
	user, ok := middleware.UserFromContext(r.Context())
	if !ok || user.Disabled {
		w.WriteHeader(http.StatusUnauthorized)
		return
	}
	var req sshServerRequest
	if err := DecodeJSON(r, &req); err != nil {
		WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid request"})
		return
	}
	name := strings.TrimSpace(req.Name)
	host := strings.TrimSpace(req.Host)
	description := strings.TrimSpace(req.Description)

	// The first failing rule, in this order, is the one reported.
	var problem string
	switch {
	case name == "":
		problem = "name is required"
	case host == "":
		problem = "host is required"
	case req.Port <= 0 || req.Port > 65535:
		problem = "port must be between 1 and 65535"
	}
	if problem != "" {
		WriteJSON(w, http.StatusBadRequest, map[string]string{"error": problem})
		return
	}

	created, err := api.store(r).CreateSSHServer(models.SSHServer{
		Name:                 name,
		Host:                 host,
		Port:                 req.Port,
		Description:          description,
		Tags:                 normalizeSSHBrokerTags(req.Tags),
		Enabled:              req.Enabled,
		CreatedByKind:        "user",
		CreatedBySubjectID:   user.ID,
		CreatedBySubjectName: user.Username,
		Editable:             true,
	})
	if err != nil {
		WriteJSON(w, http.StatusBadRequest, map[string]string{"error": err.Error()})
		return
	}
	// Set Editable again on the returned value, in case the stored record
	// does not carry the flag.
	created.Editable = true
	WriteJSON(w, http.StatusCreated, created)
}
// UpdateSSHServerForSelf overwrites an SSH server owned by the authenticated
// user. Ownership is checked with GetOwnedSSHServerForUser. It replaces
// name, host, port, description, tags and enabled from an sshServerRequest
// body, and marks the response Editable.
//
// Responses: 401 for a missing or disabled user; 404 when the server is not
// owned by the user; 500 when loading fails otherwise; 400 for a malformed
// body, failed validation, or a store update error; 200 with the server.
func (api *API) UpdateSSHServerForSelf(w http.ResponseWriter, r *http.Request, params map[string]string) {
	var user models.User
	var ok bool
	var req sshServerRequest
	var item models.SSHServer
	var err error
	user, ok = middleware.UserFromContext(r.Context())
	if !ok || user.Disabled {
		w.WriteHeader(http.StatusUnauthorized)
		return
	}
	item, err = api.store(r).GetOwnedSSHServerForUser(user.ID, params["id"])
	if err != nil {
		// errors.Is also recognises sql.ErrNoRows when the store wraps it.
		if errors.Is(err, sql.ErrNoRows) {
			WriteJSON(w, http.StatusNotFound, map[string]string{"error": "ssh server not found"})
			return
		}
		WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
		return
	}
	err = DecodeJSON(r, &req)
	if err != nil {
		WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid request"})
		return
	}
	req.Name = strings.TrimSpace(req.Name)
	req.Host = strings.TrimSpace(req.Host)
	req.Description = strings.TrimSpace(req.Description)
	if req.Name == "" {
		WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "name is required"})
		return
	}
	if req.Host == "" {
		WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "host is required"})
		return
	}
	if req.Port <= 0 || req.Port > 65535 {
		WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "port must be between 1 and 65535"})
		return
	}
	item.Name = req.Name
	item.Host = req.Host
	item.Port = req.Port
	item.Description = req.Description
	item.Tags = normalizeSSHBrokerTags(req.Tags)
	item.Enabled = req.Enabled
	item, err = api.store(r).UpdateSSHServer(item)
	if err != nil {
		WriteJSON(w, http.StatusBadRequest, map[string]string{"error": err.Error()})
		return
	}
	item.Editable = true
	WriteJSON(w, http.StatusOK, item)
}
// DeleteSSHServerForSelf deletes an SSH server owned by the authenticated
// user. Ownership is checked with GetOwnedSSHServerForUser before the
// delete.
//
// Responses: 401 for a missing or disabled user; 404 when the server is not
// owned by the user, or when the delete itself reports sql.ErrNoRows; 500
// for other store errors; 204 on success.
func (api *API) DeleteSSHServerForSelf(w http.ResponseWriter, r *http.Request, params map[string]string) {
	var user models.User
	var ok bool
	var item models.SSHServer
	var err error
	user, ok = middleware.UserFromContext(r.Context())
	if !ok || user.Disabled {
		w.WriteHeader(http.StatusUnauthorized)
		return
	}
	item, err = api.store(r).GetOwnedSSHServerForUser(user.ID, params["id"])
	if err != nil {
		// errors.Is also recognises sql.ErrNoRows when the store wraps it.
		if errors.Is(err, sql.ErrNoRows) {
			WriteJSON(w, http.StatusNotFound, map[string]string{"error": "ssh server not found"})
			return
		}
		WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
		return
	}
	// Delete the row whose ownership was just verified.
	err = api.store(r).DeleteSSHServer(item.ID)
	if err != nil {
		// The row may have disappeared between the ownership check and here.
		if errors.Is(err, sql.ErrNoRows) {
			WriteJSON(w, http.StatusNotFound, map[string]string{"error": "ssh server not found"})
			return
		}
		WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
		return
	}
	w.WriteHeader(http.StatusNoContent)
}
// ListSSHServerHostKeysForSelf lists the pinned host keys of an SSH server
// owned by the authenticated user. Ownership is checked with
// GetOwnedSSHServerForUser.
//
// Responses: 401 for a missing or disabled user, 404 when the server is not
// owned by the user, 500 on store errors, 200 with the host keys.
func (api *API) ListSSHServerHostKeysForSelf(w http.ResponseWriter, r *http.Request, params map[string]string) {
	var user models.User
	var ok bool
	var item models.SSHServer
	var items []models.SSHServerHostKey
	var err error
	user, ok = middleware.UserFromContext(r.Context())
	if !ok || user.Disabled {
		w.WriteHeader(http.StatusUnauthorized)
		return
	}
	item, err = api.store(r).GetOwnedSSHServerForUser(user.ID, params["id"])
	if err != nil {
		// errors.Is also recognises sql.ErrNoRows when the store wraps it.
		if errors.Is(err, sql.ErrNoRows) {
			WriteJSON(w, http.StatusNotFound, map[string]string{"error": "ssh server not found"})
			return
		}
		WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
		return
	}
	items, err = api.store(r).ListSSHServerHostKeys(item.ID)
	if err != nil {
		WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
		return
	}
	WriteJSON(w, http.StatusOK, items)
}
// DiscoverSSHServerHostKeyForSelf asks discoverSSHServerHostKey for the
// host key at the Host:Port of an SSH server owned by the authenticated
// user. It returns the key with ServerID filled in and does not store it.
//
// Responses: 401 for a missing or disabled user, 404 when the server is not
// owned by the user, 500 on store errors, 502 when discovery fails, 200 with
// the discovered key.
func (api *API) DiscoverSSHServerHostKeyForSelf(w http.ResponseWriter, r *http.Request, params map[string]string) {
	var user models.User
	var ok bool
	var item models.SSHServer
	var discovered models.SSHServerHostKey
	var err error
	user, ok = middleware.UserFromContext(r.Context())
	if !ok || user.Disabled {
		w.WriteHeader(http.StatusUnauthorized)
		return
	}
	item, err = api.store(r).GetOwnedSSHServerForUser(user.ID, params["id"])
	if err != nil {
		// errors.Is also recognises sql.ErrNoRows when the store wraps it.
		if errors.Is(err, sql.ErrNoRows) {
			WriteJSON(w, http.StatusNotFound, map[string]string{"error": "ssh server not found"})
			return
		}
		WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
		return
	}
	discovered, err = discoverSSHServerHostKey(item.Host, item.Port)
	if err != nil {
		WriteJSON(w, http.StatusBadGateway, map[string]string{"error": err.Error()})
		return
	}
	discovered.ServerID = item.ID
	WriteJSON(w, http.StatusOK, discovered)
}
// CreateSSHServerHostKeyForSelf pins a host key on an SSH server owned by
// the authenticated user. The body's public_key is trimmed, parsed with
// parseSSHServerHostKey, bound to the server and stored.
//
// Responses: 401 for a missing or disabled user; 404 when the server is not
// owned by the user; 500 when loading the server fails otherwise; 400 for a
// malformed body, an unparsable key, or a store error; 201 with the key.
func (api *API) CreateSSHServerHostKeyForSelf(w http.ResponseWriter, r *http.Request, params map[string]string) {
	var user models.User
	var ok bool
	var item models.SSHServer
	var req sshServerHostKeyRequest
	var hostKey models.SSHServerHostKey
	var err error
	user, ok = middleware.UserFromContext(r.Context())
	if !ok || user.Disabled {
		w.WriteHeader(http.StatusUnauthorized)
		return
	}
	item, err = api.store(r).GetOwnedSSHServerForUser(user.ID, params["id"])
	if err != nil {
		// errors.Is also recognises sql.ErrNoRows when the store wraps it.
		if errors.Is(err, sql.ErrNoRows) {
			WriteJSON(w, http.StatusNotFound, map[string]string{"error": "ssh server not found"})
			return
		}
		WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
		return
	}
	err = DecodeJSON(r, &req)
	if err != nil {
		WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid request"})
		return
	}
	hostKey, err = parseSSHServerHostKey(strings.TrimSpace(req.PublicKey))
	if err != nil {
		WriteJSON(w, http.StatusBadRequest, map[string]string{"error": err.Error()})
		return
	}
	hostKey.ServerID = item.ID
	hostKey, err = api.store(r).CreateSSHServerHostKey(hostKey)
	if err != nil {
		WriteJSON(w, http.StatusBadRequest, map[string]string{"error": err.Error()})
		return
	}
	WriteJSON(w, http.StatusCreated, hostKey)
}
// DeleteSSHServerHostKeyForSelf removes the pinned host key "hostKeyId"
// from an SSH server owned by the authenticated user. Ownership is checked
// with GetOwnedSSHServerForUser.
//
// Responses: 401 for a missing or disabled user; 404 when the server is not
// owned by the user, or when the delete reports sql.ErrNoRows; 500 for other
// store errors; 204 on success.
func (api *API) DeleteSSHServerHostKeyForSelf(w http.ResponseWriter, r *http.Request, params map[string]string) {
	var user models.User
	var ok bool
	var item models.SSHServer
	var err error
	user, ok = middleware.UserFromContext(r.Context())
	if !ok || user.Disabled {
		w.WriteHeader(http.StatusUnauthorized)
		return
	}
	item, err = api.store(r).GetOwnedSSHServerForUser(user.ID, params["id"])
	if err != nil {
		// errors.Is also recognises sql.ErrNoRows when the store wraps it.
		if errors.Is(err, sql.ErrNoRows) {
			WriteJSON(w, http.StatusNotFound, map[string]string{"error": "ssh server not found"})
			return
		}
		WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
		return
	}
	err = api.store(r).DeleteSSHServerHostKey(item.ID, params["hostKeyId"])
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			WriteJSON(w, http.StatusNotFound, map[string]string{"error": "ssh server host key not found"})
			return
		}
		WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
		return
	}
	w.WriteHeader(http.StatusNoContent)
}
// CreateSSHAccessProfileForSelf creates an access profile owned by the
// authenticated user from an sshAccessProfileRequest body. It uses the
// self-service normalization options.
//
// Responses: 401 for a missing or disabled user; 400 for a malformed body,
// a normalization failure, or a store error; 201 with the stored profile.
func (api *API) CreateSSHAccessProfileForSelf(w http.ResponseWriter, r *http.Request, _ map[string]string) {
	user, ok := middleware.UserFromContext(r.Context())
	if !ok || user.Disabled {
		w.WriteHeader(http.StatusUnauthorized)
		return
	}
	var req sshAccessProfileRequest
	if err := DecodeJSON(r, &req); err != nil {
		WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid request"})
		return
	}
	// Profiles created here are scoped to the caller. RequireTargets stays
	// at its zero value (false).
	opts := sshAccessProfileNormalizeOptions{
		OwnerScope:    "user",
		OwnerUserID:   user.ID,
		AllowUserEdit: true,
		SelfService:   true,
	}
	profile, err := api.normalizeSSHAccessProfileRequest(r, req, models.SSHAccessProfile{}, "user", user.ID, user.Username, false, opts)
	if err != nil {
		WriteJSON(w, http.StatusBadRequest, map[string]string{"error": err.Error()})
		return
	}
	if profile, err = api.store(r).CreateSSHAccessProfile(profile); err != nil {
		WriteJSON(w, http.StatusBadRequest, map[string]string{"error": err.Error()})
		return
	}
	WriteJSON(w, http.StatusCreated, profile)
}
// UpdateSSHAccessProfileForSelf rewrites an access profile that the
// authenticated user owns. Profiles visible to the user but not owned by
// them (OwnerScope other than "user", or a different OwnerUserID) are
// refused with 403.
//
// Responses: 401 for a missing or disabled user; 404 when the profile is not
// visible to the user; 500 when loading fails otherwise; 403 when not owned;
// 400 for a malformed body, a normalization failure, or a store update
// error; 200 with the updated profile.
func (api *API) UpdateSSHAccessProfileForSelf(w http.ResponseWriter, r *http.Request, params map[string]string) {
	var user models.User
	var ok bool
	var req sshAccessProfileRequest
	var existing models.SSHAccessProfile
	var item models.SSHAccessProfile
	var options sshAccessProfileNormalizeOptions
	var err error
	user, ok = middleware.UserFromContext(r.Context())
	if !ok || user.Disabled {
		w.WriteHeader(http.StatusUnauthorized)
		return
	}
	existing, err = api.store(r).GetSSHAccessProfileForUser(user.ID, params["id"])
	if err != nil {
		// errors.Is also recognises sql.ErrNoRows when the store wraps it.
		if errors.Is(err, sql.ErrNoRows) {
			WriteJSON(w, http.StatusNotFound, map[string]string{"error": "ssh access profile not found"})
			return
		}
		WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
		return
	}
	// Visible is not the same as owned; only the owner may edit.
	if existing.OwnerScope != "user" || existing.OwnerUserID != user.ID {
		w.WriteHeader(http.StatusForbidden)
		return
	}
	err = DecodeJSON(r, &req)
	if err != nil {
		WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid request"})
		return
	}
	options = sshAccessProfileNormalizeOptions{
		OwnerScope:     "user",
		OwnerUserID:    user.ID,
		AllowUserEdit:  true,
		RequireTargets: false,
		SelfService:    true,
	}
	item, err = api.normalizeSSHAccessProfileRequest(r, req, existing, "user", user.ID, user.Username, true, options)
	if err != nil {
		WriteJSON(w, http.StatusBadRequest, map[string]string{"error": err.Error()})
		return
	}
	item, err = api.store(r).UpdateSSHAccessProfile(item)
	if err != nil {
		WriteJSON(w, http.StatusBadRequest, map[string]string{"error": err.Error()})
		return
	}
	WriteJSON(w, http.StatusOK, item)
}
// DeleteSSHAccessProfileForSelf deletes an access profile that the
// authenticated user owns. A profile that is visible to the user but not
// owned by them is refused with 403.
//
// Responses: 401 for a missing or disabled user; 404 when the profile is not
// visible, or when the delete reports sql.ErrNoRows; 403 when not owned;
// 500 for other store errors; 204 on success.
func (api *API) DeleteSSHAccessProfileForSelf(w http.ResponseWriter, r *http.Request, params map[string]string) {
	var user models.User
	var ok bool
	var item models.SSHAccessProfile
	var err error
	user, ok = middleware.UserFromContext(r.Context())
	if !ok || user.Disabled {
		w.WriteHeader(http.StatusUnauthorized)
		return
	}
	item, err = api.store(r).GetSSHAccessProfileForUser(user.ID, params["id"])
	if err != nil {
		// errors.Is also recognises sql.ErrNoRows when the store wraps it.
		if errors.Is(err, sql.ErrNoRows) {
			WriteJSON(w, http.StatusNotFound, map[string]string{"error": "ssh access profile not found"})
			return
		}
		WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
		return
	}
	if item.OwnerScope != "user" || item.OwnerUserID != user.ID {
		w.WriteHeader(http.StatusForbidden)
		return
	}
	err = api.store(r).DeleteSSHAccessProfile(item.ID)
	if err != nil {
		// Same not-found mapping as DeleteSSHAccessProfileAdmin uses for
		// this store call.
		if errors.Is(err, sql.ErrNoRows) {
			WriteJSON(w, http.StatusNotFound, map[string]string{"error": "ssh access profile not found"})
			return
		}
		WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
		return
	}
	w.WriteHeader(http.StatusNoContent)
}
// ListSSHServerHostKeysAdmin lists the pinned host keys of the SSH server
// identified by the "id" route parameter. It returns early if requireAdmin
// rejects the caller.
//
// Responses: 404 when the server does not exist, 500 on store errors, 200
// with the host keys.
func (api *API) ListSSHServerHostKeysAdmin(w http.ResponseWriter, r *http.Request, params map[string]string) {
	var item models.SSHServer
	var items []models.SSHServerHostKey
	var err error
	if !api.requireAdmin(w, r) {
		return
	}
	item, err = api.store(r).GetSSHServer(params["id"])
	if err != nil {
		// errors.Is also recognises sql.ErrNoRows when the store wraps it.
		if errors.Is(err, sql.ErrNoRows) {
			WriteJSON(w, http.StatusNotFound, map[string]string{"error": "ssh server not found"})
			return
		}
		WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
		return
	}
	items, err = api.store(r).ListSSHServerHostKeys(item.ID)
	if err != nil {
		WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
		return
	}
	WriteJSON(w, http.StatusOK, items)
}
// DiscoverSSHServerHostKeyAdmin asks discoverSSHServerHostKey for the host
// key at the Host:Port of the server identified by the "id" route
// parameter. It returns the key with ServerID filled in and does not store
// it. It returns early if requireAdmin rejects the caller.
//
// Responses: 404 when the server does not exist, 500 on store errors, 502
// when discovery fails, 200 with the discovered key.
func (api *API) DiscoverSSHServerHostKeyAdmin(w http.ResponseWriter, r *http.Request, params map[string]string) {
	var item models.SSHServer
	var discovered models.SSHServerHostKey
	var err error
	if !api.requireAdmin(w, r) {
		return
	}
	item, err = api.store(r).GetSSHServer(params["id"])
	if err != nil {
		// errors.Is also recognises sql.ErrNoRows when the store wraps it.
		if errors.Is(err, sql.ErrNoRows) {
			WriteJSON(w, http.StatusNotFound, map[string]string{"error": "ssh server not found"})
			return
		}
		WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
		return
	}
	discovered, err = discoverSSHServerHostKey(item.Host, item.Port)
	if err != nil {
		WriteJSON(w, http.StatusBadGateway, map[string]string{"error": err.Error()})
		return
	}
	discovered.ServerID = item.ID
	WriteJSON(w, http.StatusOK, discovered)
}
// CreateSSHServerHostKeyAdmin pins a host key on the server identified by
// the "id" route parameter. The body's public_key is trimmed, parsed with
// parseSSHServerHostKey, bound to the server and stored. It returns early
// if requireAdmin rejects the caller.
//
// Responses: 404 when the server does not exist; 500 when loading fails
// otherwise; 400 for a malformed body, an unparsable key, or a store error;
// 201 with the key.
func (api *API) CreateSSHServerHostKeyAdmin(w http.ResponseWriter, r *http.Request, params map[string]string) {
	var item models.SSHServer
	var req sshServerHostKeyRequest
	var hostKey models.SSHServerHostKey
	var err error
	if !api.requireAdmin(w, r) {
		return
	}
	item, err = api.store(r).GetSSHServer(params["id"])
	if err != nil {
		// errors.Is also recognises sql.ErrNoRows when the store wraps it.
		if errors.Is(err, sql.ErrNoRows) {
			WriteJSON(w, http.StatusNotFound, map[string]string{"error": "ssh server not found"})
			return
		}
		WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
		return
	}
	err = DecodeJSON(r, &req)
	if err != nil {
		WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid request"})
		return
	}
	hostKey, err = parseSSHServerHostKey(strings.TrimSpace(req.PublicKey))
	if err != nil {
		WriteJSON(w, http.StatusBadRequest, map[string]string{"error": err.Error()})
		return
	}
	hostKey.ServerID = item.ID
	hostKey, err = api.store(r).CreateSSHServerHostKey(hostKey)
	if err != nil {
		WriteJSON(w, http.StatusBadRequest, map[string]string{"error": err.Error()})
		return
	}
	WriteJSON(w, http.StatusCreated, hostKey)
}
// DeleteSSHServerHostKeyAdmin removes the pinned host key "hostKeyId" from
// the server "id". It returns early if requireAdmin rejects the caller.
//
// Responses: 204 on success, 404 when the store reports sql.ErrNoRows
// (consistent with the other delete handlers), 500 for other store errors.
func (api *API) DeleteSSHServerHostKeyAdmin(w http.ResponseWriter, r *http.Request, params map[string]string) {
	var err error
	if !api.requireAdmin(w, r) {
		return
	}
	err = api.store(r).DeleteSSHServerHostKey(params["id"], params["hostKeyId"])
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			WriteJSON(w, http.StatusNotFound, map[string]string{"error": "ssh server host key not found"})
			return
		}
		WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
		return
	}
	w.WriteHeader(http.StatusNoContent)
}
// CreateSSHSessionForSelf records a "pending" SSH session for the
// authenticated user against the access profile named by the "id" route
// parameter. It replies with the WebSocket path used to drive the session.
//
// The sshSessionConnectRequest body supplies the requested terminal type
// and size, which are normalized before being stored on the session. The
// target server must have at least one pinned host key. The fingerprint of
// the first key returned by ListSSHServerHostKeys is echoed in the response.
//
// Responses: 401 for a missing or disabled user; 400 for a malformed body
// or a server without a pinned host key; 404 when the profile is not
// visible to the user; 500 on store errors; 201 with an
// sshSessionConnectResponse on success.
func (api *API) CreateSSHSessionForSelf(w http.ResponseWriter, r *http.Request, params map[string]string) {
	var user models.User
	var ok bool
	var req sshSessionConnectRequest
	var profile models.SSHAccessProfile
	var item models.SSHSession
	var response sshSessionConnectResponse
	var hostKeys []models.SSHServerHostKey
	var err error
	user, ok = middleware.UserFromContext(r.Context())
	if !ok || user.Disabled {
		w.WriteHeader(http.StatusUnauthorized)
		return
	}
	err = DecodeJSON(r, &req)
	if err != nil {
		WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid request"})
		return
	}
	profile, err = api.store(r).GetSSHAccessProfileForUser(user.ID, params["id"])
	if err != nil {
		// errors.Is also recognises sql.ErrNoRows when the store wraps it.
		if errors.Is(err, sql.ErrNoRows) {
			WriteJSON(w, http.StatusNotFound, map[string]string{"error": "ssh access profile not found"})
			return
		}
		WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
		return
	}
	// NOTE(review): neither the profile's nor the server's enabled state is
	// checked here, and the same getter is used to load profiles for editing.
	// Confirm that the stream handler refuses disabled profiles and servers.
	hostKeys, err = api.store(r).ListSSHServerHostKeys(profile.ServerID)
	if err != nil {
		WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
		return
	}
	// A session is only allowed when the server's host key can be verified.
	if len(hostKeys) == 0 {
		WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "no pinned host key for this server"})
		return
	}
	// NOTE(review): req.ValidSeconds is decoded but not recorded on the
	// session here; confirm whether requested validity should be honored.
	item = models.SSHSession{
		ProfileID:      profile.ID,
		ServerID:       profile.ServerID,
		UserID:         user.ID,
		Username:       user.Username,
		RemoteUsername: profile.RemoteUsername,
		AuthMethod:     profile.AuthMethod,
		Host:           profile.Server.Host,
		Port:           profile.Server.Port,
		Status:         "pending",
		RequestedTerm:  normalizeSSHSessionTerm(req.Term),
		RequestedCols:  normalizeSSHSessionCols(req.Cols),
		RequestedRows:  normalizeSSHSessionRows(req.Rows),
		RemoteAddr:     requestRemoteAddr(r),
		UserAgent:      strings.TrimSpace(r.UserAgent()),
	}
	item, err = api.store(r).CreateSSHSession(item)
	if err != nil {
		WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
		return
	}
	response = sshSessionConnectResponse{
		SessionID:          item.ID,
		Status:             item.Status,
		WebSocketPath:      "/api/ssh/sessions/" + item.ID + "/stream",
		ServerName:         profile.Server.Name,
		Host:               profile.Server.Host,
		Port:               profile.Server.Port,
		RemoteUsername:     profile.RemoteUsername,
		HostKeyFingerprint: hostKeys[0].Fingerprint,
	}
	WriteJSON(w, http.StatusCreated, response)
}
// GetSSHSessionForSelf writes the SSH session identified by the "id" route
// parameter, provided it belongs to the authenticated user.
//
// Responses: 401 for a missing or disabled user, 404 when the session does
// not exist, 500 on other store errors, 403 when the session belongs to
// another user, 200 with the session.
func (api *API) GetSSHSessionForSelf(w http.ResponseWriter, r *http.Request, params map[string]string) {
	var user models.User
	var ok bool
	var item models.SSHSession
	var err error
	user, ok = middleware.UserFromContext(r.Context())
	if !ok || user.Disabled {
		w.WriteHeader(http.StatusUnauthorized)
		return
	}
	item, err = api.store(r).GetSSHSession(params["id"])
	if err != nil {
		// errors.Is also recognises sql.ErrNoRows when the store wraps it.
		if errors.Is(err, sql.ErrNoRows) {
			WriteJSON(w, http.StatusNotFound, map[string]string{"error": "ssh session not found"})
			return
		}
		WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
		return
	}
	if item.UserID != user.ID {
		w.WriteHeader(http.StatusForbidden)
		return
	}
	WriteJSON(w, http.StatusOK, item)
}
// ListSSHSessionsForSelf returns up to 50 SSH session records belonging to the
// caller, in whatever order ListSSHSessionsForUser produces.
//
// Responses: 401 when there is no enabled user on the request context, 500 on
// store errors, otherwise 200 with a JSON array. The array is always encoded
// as `[]` (never `null`) when the user has no sessions.
func (api *API) ListSSHSessionsForSelf(w http.ResponseWriter, r *http.Request, _ map[string]string) {
	var user models.User
	var ok bool
	var items []models.SSHSession
	var err error
	user, ok = middleware.UserFromContext(r.Context())
	if !ok || user.Disabled {
		w.WriteHeader(http.StatusUnauthorized)
		return
	}
	items, err = api.store(r).ListSSHSessionsForUser(user.ID, 50)
	if err != nil {
		WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
		return
	}
	if len(items) == 0 {
		// encoding/json renders a nil slice as null; clients iterate the result.
		items = []models.SSHSession{}
	}
	WriteJSON(w, http.StatusOK, items)
}
// DisconnectSSHSessionForSelf ends one of the caller's SSH sessions.
//
// When SSHSessionRegistry reports the session as live (RequestClose returns
// true), the streaming handler is expected to record the final status itself
// (serveSSHSessionStream checks CloseReason before writing its final status).
// Otherwise the stored record is marked "closed" directly. This only happens
// while the record is still non-terminal; a session already in "closed" or
// "error" keeps its original end time and reason.
//
// Responses: 401 when there is no enabled user on the request context, 404
// when the id is unknown, 403 when the session belongs to another user, 500
// on store errors, otherwise 204.
func (api *API) DisconnectSSHSessionForSelf(w http.ResponseWriter, r *http.Request, params map[string]string) {
	var user models.User
	var ok bool
	var item models.SSHSession
	var closed bool
	var err error
	var now int64
	user, ok = middleware.UserFromContext(r.Context())
	if !ok || user.Disabled {
		w.WriteHeader(http.StatusUnauthorized)
		return
	}
	item, err = api.store(r).GetSSHSession(params["id"])
	if err != nil {
		if err == sql.ErrNoRows {
			WriteJSON(w, http.StatusNotFound, map[string]string{"error": "ssh session not found"})
			return
		}
		WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
		return
	}
	if item.UserID != user.ID {
		w.WriteHeader(http.StatusForbidden)
		return
	}
	if api.SSHSessionRegistry != nil {
		closed = api.SSHSessionRegistry.RequestClose(item.ID, "disconnected by user")
	}
	if closed || item.Status == "closed" || item.Status == "error" {
		// Either the live stream will record the close, or the session already
		// ended; rewriting a finished record would lose its real end time/reason.
		w.WriteHeader(http.StatusNoContent)
		return
	}
	// Not live in this process but still non-terminal (e.g. pending, or a
	// stream that went away without recording its end): close it here.
	now = time.Now().UTC().Unix()
	err = api.store(r).UpdateSSHSessionStatus(item.ID, "closed", item.HostKeyFingerprint, item.ConnectedAt, now, "disconnected by user")
	if err != nil {
		WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
		return
	}
	w.WriteHeader(http.StatusNoContent)
}
// StreamSSHSessionForSelf upgrades the request to a websocket and hands it to
// serveSSHSessionStream for one of the caller's SSH sessions.
//
// Before upgrading it answers with a bare status code: 401 for a missing or
// disabled user, 404 for an unknown session id, 500 for other store errors,
// 403 when the session belongs to someone else.
func (api *API) StreamSSHSessionForSelf(w http.ResponseWriter, r *http.Request, params map[string]string) {
	var caller models.User
	var authenticated bool
	var session models.SSHSession
	var lookupErr error
	var rejectStatus int
	caller, authenticated = middleware.UserFromContext(r.Context())
	if !authenticated || caller.Disabled {
		w.WriteHeader(http.StatusUnauthorized)
		return
	}
	session, lookupErr = api.store(r).GetSSHSession(params["id"])
	switch {
	case lookupErr == sql.ErrNoRows:
		rejectStatus = http.StatusNotFound
	case lookupErr != nil:
		rejectStatus = http.StatusInternalServerError
	case session.UserID != caller.ID:
		rejectStatus = http.StatusForbidden
	}
	if rejectStatus != 0 {
		w.WriteHeader(rejectStatus)
		return
	}
	// NOTE(review): x/net/websocket.Handler only rejects a missing/null Origin;
	// it does not compare Origin to Host. If this endpoint is authenticated by
	// cookie, consider websocket.Server with a Handshake that enforces the
	// expected origin to prevent cross-site websocket hijacking.
	websocket.Handler(func(ws *websocket.Conn) {
		api.serveSSHSessionStream(ws, r, caller, session)
	}).ServeHTTP(w, r)
}
// normalizeSSHAccessProfileRequest validates req and merges it onto existing,
// returning the SSHAccessProfile that should be persisted.
//
// String fields are trimmed, and auth_method and ssh_principal_mode are
// normalized. Every referenced object (server, SSH user CA, principal grants,
// user/group targets) must exist. In self-service mode the server must be owned
// by options.OwnerUserID, and grant IDs must be among that user's currently
// active grants.
//
// Key material: a supplied private_key_pem is parsed, and its public key and
// fingerprint are recorded. A new managed_ssh_cert profile without a key gets
// a generated ed25519 key pair. On update with no key, the existing secret is
// kept. Fields that do not apply to the chosen auth method or principal mode
// are cleared so stale client input is never persisted.
//
// default_valid_seconds defaults to 3600, and max_valid_seconds defaults to
// the default; max must not be below default. When isUpdate is false, the
// created-by fields are stamped from actorKind/actorID/actorName.
func (api *API) normalizeSSHAccessProfileRequest(r *http.Request, req sshAccessProfileRequest, existing models.SSHAccessProfile, actorKind string, actorID string, actorName string, isUpdate bool, options sshAccessProfileNormalizeOptions) (models.SSHAccessProfile, error) {
	var item models.SSHAccessProfile
	var err error
	var targets []models.SSHAccessProfileTarget
	var privateKeyPEM string
	var signer ssh.Signer
	var publicKey string
	var fingerprint string
	var generatedKeyPEM string
	var generatedPublicKey string
	var generatedFingerprint string
	var activeGrants []models.SSHPrincipalGrant
	var now int64
	var allowedGrantIDs map[string]bool
	var i int
	item = existing
	req.ServerID = strings.TrimSpace(req.ServerID)
	req.Name = strings.TrimSpace(req.Name)
	req.Description = strings.TrimSpace(req.Description)
	req.RemoteUsername = strings.TrimSpace(req.RemoteUsername)
	req.AuthMethod = normalizeSSHAccessAuthMethod(req.AuthMethod)
	req.SSHUserCAID = strings.TrimSpace(req.SSHUserCAID)
	req.SSHPrincipalMode = normalizeSSHAccessPrincipalMode(req.SSHPrincipalMode)
	privateKeyPEM = strings.TrimSpace(req.PrivateKeyPEM)
	targets, err = normalizeSSHAccessProfileTargets(req.Targets, !options.RequireTargets)
	if err != nil {
		return item, err
	}
	if req.Name == "" {
		return item, errors.New("name is required")
	}
	if req.ServerID == "" {
		return item, errors.New("server_id is required")
	}
	if req.RemoteUsername == "" {
		return item, errors.New("remote_username is required")
	}
	if req.AuthMethod == "" {
		return item, errors.New("auth_method must be managed_ssh_cert or stored_private_key")
	}
	if req.AuthMethod == "managed_ssh_cert" {
		if req.SSHUserCAID == "" {
			return item, errors.New("ssh_user_ca_id is required for managed_ssh_cert")
		}
		// NOTE(review): GetSSHUserCAForUser receives only the CA id, not
		// options.OwnerUserID — confirm it does not need the owner to scope the lookup.
		if options.SelfService {
			_, err = api.store(r).GetSSHUserCAForUser(req.SSHUserCAID)
		} else {
			_, err = api.store(r).GetSSHUserCA(req.SSHUserCAID)
		}
		if err != nil {
			if err == sql.ErrNoRows {
				return item, errors.New("ssh_user_ca_id not found")
			}
			return item, err
		}
		if req.SSHPrincipalMode == "grant" {
			req.SSHPrincipalGrantIDs = normalizeSSHBrokerNames(req.SSHPrincipalGrantIDs)
			if len(req.SSHPrincipalGrantIDs) == 0 {
				return item, errors.New("at least one ssh_principal_grant_id is required in grant mode")
			}
			if options.SelfService {
				// Self-service callers may only reference grants currently active for them.
				now = time.Now().UTC().Unix()
				activeGrants, err = api.store(r).ListActiveSSHPrincipalGrantsForUser(options.OwnerUserID, now)
				if err != nil {
					return item, err
				}
				allowedGrantIDs = map[string]bool{}
				for i = 0; i < len(activeGrants); i++ {
					allowedGrantIDs[activeGrants[i].ID] = true
				}
				for i = 0; i < len(req.SSHPrincipalGrantIDs); i++ {
					if !allowedGrantIDs[req.SSHPrincipalGrantIDs[i]] {
						return item, errors.New("ssh_principal_grant_id not available to user: " + req.SSHPrincipalGrantIDs[i])
					}
				}
			} else {
				// Admins may reference any grant, but it must exist — otherwise
				// the profile would only fail later, at connect time.
				for i = 0; i < len(req.SSHPrincipalGrantIDs); i++ {
					_, err = api.store(r).GetSSHPrincipalGrant(req.SSHPrincipalGrantIDs[i])
					if err != nil {
						if err == sql.ErrNoRows {
							return item, errors.New("ssh_principal_grant_id not found: " + req.SSHPrincipalGrantIDs[i])
						}
						return item, err
					}
				}
			}
			// Explicit principals are unused in grant mode; do not persist raw client input.
			req.SSHPrincipals = []string{}
		} else {
			req.SSHPrincipals = normalizeSSHBrokerNames(req.SSHPrincipals)
			if len(req.SSHPrincipals) == 0 {
				return item, errors.New("at least one ssh_principal is required in explicit mode")
			}
			// Grant IDs are unused in explicit mode and were never checked
			// against the caller's grants; do not persist them.
			req.SSHPrincipalGrantIDs = []string{}
		}
		if privateKeyPEM == "" && !isUpdate {
			// New certificate profiles get their own key pair; the CA signs its public half at connect time.
			generatedKeyPEM, generatedPublicKey, generatedFingerprint, err = generateSSHUserCAKeyPair("ed25519")
			if err != nil {
				return item, err
			}
			privateKeyPEM = strings.TrimSpace(generatedKeyPEM)
			publicKey = strings.TrimSpace(generatedPublicKey)
			fingerprint = strings.TrimSpace(generatedFingerprint)
		}
	} else {
		req.SSHUserCAID = ""
		req.SSHPrincipalMode = "explicit"
		req.SSHPrincipals = []string{}
		req.SSHPrincipalGrantIDs = []string{}
		if privateKeyPEM == "" && !isUpdate {
			return item, errors.New("private_key_pem is required for stored_private_key")
		}
	}
	// A caller-supplied key still needs its public half derived (generated keys already have it).
	if privateKeyPEM != "" && publicKey == "" {
		signer, err = parseSSHSignerFromPEM(privateKeyPEM)
		if err != nil {
			return item, err
		}
		publicKey = strings.TrimSpace(string(ssh.MarshalAuthorizedKey(signer.PublicKey())))
		fingerprint = strings.TrimSpace(ssh.FingerprintSHA256(signer.PublicKey()))
	}
	if req.DefaultValidSeconds <= 0 {
		req.DefaultValidSeconds = 3600
	}
	if req.MaxValidSeconds <= 0 {
		req.MaxValidSeconds = req.DefaultValidSeconds
	}
	if req.MaxValidSeconds < req.DefaultValidSeconds {
		return item, errors.New("max_valid_seconds must be greater than or equal to default_valid_seconds")
	}
	if options.SelfService {
		_, err = api.store(r).GetOwnedSSHServerForUser(options.OwnerUserID, req.ServerID)
		if err != nil {
			if err == sql.ErrNoRows {
				return item, errors.New("server_id not found")
			}
			return item, err
		}
	} else {
		_, err = api.store(r).GetSSHServer(req.ServerID)
		if err != nil {
			if err == sql.ErrNoRows {
				return item, errors.New("server_id not found")
			}
			return item, err
		}
	}
	err = validateSSHAccessProfileTargets(api.store(r), targets)
	if err != nil {
		return item, err
	}
	item.ServerID = req.ServerID
	item.Name = req.Name
	item.Description = req.Description
	item.RemoteUsername = req.RemoteUsername
	item.AuthMethod = req.AuthMethod
	item.OwnerScope = options.OwnerScope
	item.OwnerUserID = options.OwnerUserID
	item.AllowUserEdit = options.AllowUserEdit
	item.Enabled = req.Enabled
	item.SSHUserCAID = req.SSHUserCAID
	item.SSHPrincipalMode = req.SSHPrincipalMode
	item.SSHPrincipals = req.SSHPrincipals
	item.SSHPrincipalGrantIDs = req.SSHPrincipalGrantIDs
	item.DefaultValidSeconds = req.DefaultValidSeconds
	item.MaxValidSeconds = req.MaxValidSeconds
	item.Targets = targets
	if !isUpdate {
		item.CreatedByKind = actorKind
		item.CreatedBySubjectID = actorID
		item.CreatedBySubjectName = actorName
	}
	// Only replace the stored secret when new key material was supplied or generated.
	if privateKeyPEM != "" {
		item.SecretPayload, err = api.encryptSSHSecretPayload(privateKeyPEM)
		if err != nil {
			return item, err
		}
		item.AuthPublicKey = publicKey
		item.AuthPublicKeyFingerprint = fingerprint
	}
	return item, nil
}
// normalizeSSHAccessAuthMethod lower-cases and trims raw and returns it when it
// names a supported auth method ("managed_ssh_cert" or "stored_private_key").
// Any other value yields "".
func normalizeSSHAccessAuthMethod(raw string) string {
	var method string
	method = strings.ToLower(strings.TrimSpace(raw))
	switch method {
	case "managed_ssh_cert", "stored_private_key":
		return method
	default:
		return ""
	}
}
// normalizeSSHAccessPrincipalMode maps raw to "grant" when it is "grant"
// (case-insensitive, surrounding whitespace ignored) and to "explicit"
// otherwise, including for the empty string.
func normalizeSSHAccessPrincipalMode(raw string) string {
	if strings.ToLower(strings.TrimSpace(raw)) != "grant" {
		return "explicit"
	}
	return "grant"
}
// normalizeSSHBrokerTags normalizes server tags by delegating to
// db.NormalizeSSHServerTags, so handlers and the store share one set of tag rules.
func normalizeSSHBrokerTags(raw []string) []string {
	return db.NormalizeSSHServerTags(raw)
}
// normalizeSSHBrokerNames trims every entry of raw, drops blanks and
// duplicates, and returns the survivors sorted ascending. It returns nil when
// nothing survives.
func normalizeSSHBrokerNames(raw []string) []string {
	var names []string
	var present map[string]struct{}
	var entry string
	var dup bool
	present = make(map[string]struct{}, len(raw))
	for _, entry = range raw {
		entry = strings.TrimSpace(entry)
		if entry == "" {
			continue
		}
		if _, dup = present[entry]; dup {
			continue
		}
		present[entry] = struct{}{}
		names = append(names, entry)
	}
	sort.Strings(names)
	return names
}
// normalizeSSHAccessProfileTargets converts raw request targets into model
// targets. target_type is trimmed and lower-cased and must be "user" or
// "group". target_id is trimmed and must be non-empty. Duplicate (type, id)
// pairs are kept once, in first-seen order. An empty result is an error unless
// allowEmpty is true.
func normalizeSSHAccessProfileTargets(raw []sshAccessProfileTargetRequest, allowEmpty bool) ([]models.SSHAccessProfileTarget, error) {
	var result []models.SSHAccessProfileTarget
	var seenKeys map[string]bool
	var entry sshAccessProfileTargetRequest
	var kind string
	var id string
	seenKeys = make(map[string]bool, len(raw))
	for _, entry = range raw {
		kind = strings.ToLower(strings.TrimSpace(entry.TargetType))
		id = strings.TrimSpace(entry.TargetID)
		switch {
		case kind != "user" && kind != "group":
			return nil, errors.New("target_type must be user or group")
		case id == "":
			return nil, errors.New("target_id is required")
		case seenKeys[kind+":"+id]:
			continue
		}
		seenKeys[kind+":"+id] = true
		result = append(result, models.SSHAccessProfileTarget{TargetType: kind, TargetID: id})
	}
	if !allowEmpty && len(result) == 0 {
		return nil, errors.New("at least one target is required")
	}
	return result, nil
}
// validateSSHAccessProfileTargets checks that every target refers to an
// existing user (TargetType "user") or user group (any other TargetType). It
// stops at the first failure. A missing row is reported as
// "<type> target not found: <id>"; other store errors are returned unchanged.
func validateSSHAccessProfileTargets(store *db.Store, targets []models.SSHAccessProfileTarget) error {
	var idx int
	var lookupErr error
	for idx = range targets {
		switch targets[idx].TargetType {
		case "user":
			_, lookupErr = store.GetUserByID(targets[idx].TargetID)
		default:
			_, lookupErr = store.GetUserGroup(targets[idx].TargetID)
		}
		if lookupErr == sql.ErrNoRows {
			return errors.New(targets[idx].TargetType + " target not found: " + targets[idx].TargetID)
		}
		if lookupErr != nil {
			return lookupErr
		}
	}
	return nil
}
// currentSSHBrokerActor identifies who is acting on the request and returns
// (kind, id, display name). A user on the context wins, then a service
// principal. With neither present it returns ("system", "", "system").
// The user's Disabled flag is not consulted here.
func currentSSHBrokerActor(r *http.Request) (string, string, string) {
	var ctxUser models.User
	var ctxPrincipal models.ServicePrincipal
	var found bool
	if ctxUser, found = middleware.UserFromContext(r.Context()); found {
		return "user", ctxUser.ID, ctxUser.Username
	}
	if ctxPrincipal, found = middleware.PrincipalFromContext(r.Context()); found {
		return "service_principal", ctxPrincipal.ID, ctxPrincipal.Name
	}
	return "system", "", "system"
}
// normalizeSSHSessionTerm returns the trimmed TERM value requested by the
// client, falling back to "xterm-256color" when none was given.
func normalizeSSHSessionTerm(raw string) string {
	var term string
	if term = strings.TrimSpace(raw); term != "" {
		return term
	}
	return "xterm-256color"
}
// normalizeSSHSessionCols clamps a requested terminal width. Non-positive
// values become 120, and values above 400 are capped at 400.
func normalizeSSHSessionCols(value int) int {
	switch {
	case value <= 0:
		return 120
	case value > 400:
		return 400
	default:
		return value
	}
}
// normalizeSSHSessionRows clamps a requested terminal height. Non-positive
// values become 36, and values above 200 are capped at 200.
func normalizeSSHSessionRows(value int) int {
	switch {
	case value <= 0:
		return 36
	case value > 200:
		return 200
	default:
		return value
	}
}
// parseSSHServerHostKey parses raw as an authorized_keys-format public key and
// returns a host-key model holding its algorithm, re-serialized public key and
// SHA256 fingerprint. Any parse failure becomes "invalid ssh public key".
func parseSSHServerHostKey(raw string) (models.SSHServerHostKey, error) {
	var parsed ssh.PublicKey
	var parseErr error
	parsed, _, _, _, parseErr = ssh.ParseAuthorizedKey([]byte(strings.TrimSpace(raw)))
	if parseErr != nil {
		return models.SSHServerHostKey{}, errors.New("invalid ssh public key")
	}
	return models.SSHServerHostKey{
		Algorithm:   parsed.Type(),
		PublicKey:   strings.TrimSpace(string(ssh.MarshalAuthorizedKey(parsed))),
		Fingerprint: strings.TrimSpace(ssh.FingerprintSHA256(parsed)),
	}, nil
}
// discoverSSHServerHostKey connects to host:port and returns the host key the
// server presents during the SSH key exchange, without authenticating.
//
// The key is only an observation: nothing verifies that this server is the
// intended one, so callers should treat the result as trust-on-first-use
// (e.g. show the fingerprint for confirmation before pinning it).
//
// The TCP connect and the whole SSH handshake are each bounded by 10 seconds.
// When no key could be observed, the returned error wraps the underlying
// handshake error, if there was one.
func discoverSSHServerHostKey(host string, port int) (models.SSHServerHostKey, error) {
	var item models.SSHServerHostKey
	var addr string
	var conn net.Conn
	var clientConn ssh.Conn
	var reqs <-chan *ssh.Request
	var config *ssh.ClientConfig
	var discovered ssh.PublicKey
	var errKeyCaptured error
	var err error
	addr = net.JoinHostPort(strings.TrimSpace(host), fmt.Sprintf("%d", port))
	// Returned by the host key callback to stop the handshake right after the
	// key exchange, so no authentication attempt is ever sent to the server.
	errKeyCaptured = errors.New("ssh host key captured")
	config = &ssh.ClientConfig{
		User: "codit",
		Auth: []ssh.AuthMethod{},
		HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error {
			discovered = key
			return errKeyCaptured
		},
		Timeout: 10 * time.Second,
	}
	conn, err = net.DialTimeout("tcp", addr, 10*time.Second)
	if err != nil {
		return item, err
	}
	defer conn.Close()
	// ClientConfig.Timeout only bounds the TCP dial inside ssh.Dial; it does not
	// apply to NewClientConn on an existing conn. Without a deadline, a peer that
	// accepts TCP but never speaks SSH would block this call indefinitely.
	err = conn.SetDeadline(time.Now().Add(10 * time.Second))
	if err != nil {
		return item, err
	}
	clientConn, _, reqs, err = ssh.NewClientConn(conn, addr, config)
	if clientConn != nil {
		clientConn.Close()
	}
	if reqs != nil {
		go ssh.DiscardRequests(reqs)
	}
	if discovered == nil {
		if err != nil {
			return item, fmt.Errorf("failed to discover ssh host key: %w", err)
		}
		return item, errors.New("failed to discover ssh host key")
	}
	// The key was delivered by the key exchange; a handshake error at this
	// point (normally errKeyCaptured) does not invalidate it.
	item.Algorithm = discovered.Type()
	item.PublicKey = strings.TrimSpace(string(ssh.MarshalAuthorizedKey(discovered)))
	item.Fingerprint = strings.TrimSpace(ssh.FingerprintSHA256(discovered))
	return item, nil
}
// serveSSHSessionStream bridges an upgraded websocket to an interactive SSH
// shell for sessionItem.
//
// Setup steps:
//   - register the session with SSHSessionRegistry (when configured) so it can
//     be closed remotely;
//   - re-load the access profile for user and require at least one pinned host
//     key for the server;
//   - build auth methods, then dial with host-key pinning, request a PTY and
//     start a shell.
//
// Each outcome is written to the session record via UpdateSSHSessionStatus and
// reported to the client as a "status" message.
//
// While connected, "input" messages are written to the remote stdin and
// "resize" messages become window-change requests. Remote stdout/stderr are
// forwarded as binary frames. A websocket read failure ends the session as a
// normal close. A shell exit or stdin write failure other than io.EOF is
// recorded as "error".
func (api *API) serveSSHSessionStream(ws *websocket.Conn, r *http.Request, user models.User, sessionItem models.SSHSession) {
	var profile models.SSHAccessProfile
	var authMethods []ssh.AuthMethod
	var hostKeys []models.SSHServerHostKey
	var sshConfig *ssh.ClientConfig
	var sshClient *ssh.Client
	var sshSession *ssh.Session
	var stdin io.WriteCloser
	var stdout io.Reader
	var stderr io.Reader
	var runtimeSession *sshActiveSession
	var closeReason string
	var err error
	var connectedFingerprint string
	var connectedAt int64
	var now int64
	var sendMu sync.Mutex
	var inputCh chan sshSessionStreamMessage
	var inputErrCh chan error
	var shellErrCh chan error
	var msg sshSessionStreamMessage
	var done bool
	var readerDone bool
	var waitErr error
	var failConnect func(stage string, cause error)
	defer ws.Close()
	runtimeSession = &sshActiveSession{}
	runtimeSession.SetResources(ws, nil, nil, nil)
	if api.SSHSessionRegistry != nil {
		err = api.SSHSessionRegistry.Register(sessionItem.ID, runtimeSession)
		if err != nil {
			api.sendSSHSessionWS(&sendMu, ws, sshSessionStreamMessage{Type: "status", Status: "error", Message: err.Error()})
			now = time.Now().UTC().Unix()
			_ = api.store(r).UpdateSSHSessionStatus(sessionItem.ID, "error", sessionItem.HostKeyFingerprint, sessionItem.ConnectedAt, now, err.Error())
			return
		}
		defer api.SSHSessionRegistry.Unregister(sessionItem.ID, runtimeSession)
	}

	// Preparation failures are recorded with the fingerprint/connect time
	// already stored on the session, since nothing has been dialed yet.
	profile, err = api.store(r).GetSSHAccessProfileForUser(user.ID, sessionItem.ProfileID)
	if err != nil {
		api.logSSHSessionPrepareFailure(sessionItem.ID, sessionItem.RemoteAddr, user.Username, sessionItem.ProfileID, "load_profile", err)
		api.sendSSHSessionWS(&sendMu, ws, sshSessionStreamMessage{Type: "status", Status: "error", Message: "ssh access profile not found"})
		now = time.Now().UTC().Unix()
		_ = api.store(r).UpdateSSHSessionStatus(sessionItem.ID, "error", sessionItem.HostKeyFingerprint, sessionItem.ConnectedAt, now, "ssh access profile not found")
		return
	}
	api.logSSHSessionPrepare(sessionItem.ID, sessionItem.RemoteAddr, user, profile)
	hostKeys, err = api.store(r).ListSSHServerHostKeys(profile.ServerID)
	if err != nil {
		api.logSSHSessionConnectFailure(sessionItem, profile, user, "load_host_keys", err)
		api.sendSSHSessionWS(&sendMu, ws, sshSessionStreamMessage{Type: "status", Status: "error", Message: err.Error()})
		now = time.Now().UTC().Unix()
		_ = api.store(r).UpdateSSHSessionStatus(sessionItem.ID, "error", sessionItem.HostKeyFingerprint, sessionItem.ConnectedAt, now, err.Error())
		return
	}
	if len(hostKeys) == 0 {
		api.logSSHSessionConnectFailure(sessionItem, profile, user, "host_key_missing", errors.New("no pinned host key for this server"))
		api.sendSSHSessionWS(&sendMu, ws, sshSessionStreamMessage{Type: "status", Status: "error", Message: "no pinned host key for this server"})
		now = time.Now().UTC().Unix()
		_ = api.store(r).UpdateSSHSessionStatus(sessionItem.ID, "error", sessionItem.HostKeyFingerprint, sessionItem.ConnectedAt, now, "no pinned host key")
		return
	}
	authMethods, err = api.buildSSHSessionAuth(profile)
	if err != nil {
		api.logSSHSessionConnectFailure(sessionItem, profile, user, "build_auth", err)
		api.sendSSHSessionWS(&sendMu, ws, sshSessionStreamMessage{Type: "status", Status: "error", Message: err.Error()})
		now = time.Now().UTC().Unix()
		_ = api.store(r).UpdateSSHSessionStatus(sessionItem.ID, "error", sessionItem.HostKeyFingerprint, sessionItem.ConnectedAt, now, err.Error())
		return
	}
	closeReason = runtimeSession.CloseReason()
	if closeReason != "" {
		now = time.Now().UTC().Unix()
		_ = api.store(r).UpdateSSHSessionStatus(sessionItem.ID, "closed", "", 0, now, closeReason)
		api.logSSHSessionClosed(sessionItem, profile, user, "closed", closeReason, 0, now)
		api.sendSSHSessionWS(&sendMu, ws, sshSessionStreamMessage{Type: "status", Status: "closed", Message: closeReason})
		return
	}

	// failConnect handles every failure from dialing onward: log the stage,
	// notify the client, and mark the session "error" with whatever host key
	// fingerprint the server presented (read at call time, possibly empty).
	failConnect = func(stage string, cause error) {
		var failedAt int64
		api.logSSHSessionConnectFailure(sessionItem, profile, user, stage, cause)
		api.sendSSHSessionWS(&sendMu, ws, sshSessionStreamMessage{Type: "status", Status: "error", Message: cause.Error()})
		failedAt = time.Now().UTC().Unix()
		_ = api.store(r).UpdateSSHSessionStatus(sessionItem.ID, "error", connectedFingerprint, 0, failedAt, cause.Error())
	}

	api.logSSHSessionConnectStart(sessionItem, profile, user)
	_ = api.store(r).UpdateSSHSessionStatus(sessionItem.ID, "connecting", "", 0, 0, "")
	api.sendSSHSessionWS(&sendMu, ws, sshSessionStreamMessage{Type: "status", Status: "connecting", Message: "connecting"})
	sshConfig = &ssh.ClientConfig{
		User:            profile.RemoteUsername,
		Auth:            authMethods,
		HostKeyCallback: sshHostKeyCallback(hostKeys, &connectedFingerprint),
		Timeout:         15 * time.Second,
	}
	sshClient, err = ssh.Dial("tcp", net.JoinHostPort(profile.Server.Host, fmt.Sprintf("%d", profile.Server.Port)), sshConfig)
	if err != nil {
		failConnect("dial", err)
		return
	}
	defer sshClient.Close()
	runtimeSession.SetResources(ws, sshClient, nil, nil)
	sshSession, err = sshClient.NewSession()
	if err != nil {
		failConnect("new_session", err)
		return
	}
	defer sshSession.Close()
	runtimeSession.SetResources(ws, sshClient, sshSession, nil)
	stdin, err = sshSession.StdinPipe()
	if err != nil {
		failConnect("stdin_pipe", err)
		return
	}
	stdout, err = sshSession.StdoutPipe()
	if err != nil {
		failConnect("stdout_pipe", err)
		return
	}
	stderr, err = sshSession.StderrPipe()
	if err != nil {
		failConnect("stderr_pipe", err)
		return
	}
	runtimeSession.SetResources(ws, sshClient, sshSession, stdin)
	// A close may have been requested while we were dialing.
	closeReason = runtimeSession.CloseReason()
	if closeReason != "" {
		now = time.Now().UTC().Unix()
		_ = api.store(r).UpdateSSHSessionStatus(sessionItem.ID, "closed", connectedFingerprint, 0, now, closeReason)
		api.logSSHSessionClosed(sessionItem, profile, user, "closed", closeReason, 0, now)
		api.sendSSHSessionWS(&sendMu, ws, sshSessionStreamMessage{Type: "status", Status: "closed", Message: closeReason})
		return
	}
	err = sshSession.RequestPty(sessionItem.RequestedTerm, sessionItem.RequestedRows, sessionItem.RequestedCols, ssh.TerminalModes{
		ssh.ECHO:          1,
		ssh.TTY_OP_ISPEED: 14400,
		ssh.TTY_OP_OSPEED: 14400,
	})
	if err != nil {
		failConnect("request_pty", err)
		return
	}
	err = sshSession.Shell()
	if err != nil {
		failConnect("shell", err)
		return
	}
	connectedAt = time.Now().UTC().Unix()
	_ = api.store(r).UpdateSSHSessionStatus(sessionItem.ID, "connected", connectedFingerprint, connectedAt, 0, "")
	api.logSSHSessionConnectSuccess(sessionItem, profile, user, connectedFingerprint)
	api.sendSSHSessionWS(&sendMu, ws, sshSessionStreamMessage{Type: "status", Status: "connected", Message: connectedFingerprint})
	inputCh = make(chan sshSessionStreamMessage)
	inputErrCh = make(chan error, 1)
	shellErrCh = make(chan error, 1)
	go api.readSSHSessionWebsocket(ws, inputCh, inputErrCh)
	go api.pipeSSHSessionOutput(stdout, ws, &sendMu)
	go api.pipeSSHSessionOutput(stderr, ws, &sendMu)
	go func() {
		var shellErr error
		shellErr = sshSession.Wait()
		shellErrCh <- shellErr
	}()
	done = false
	waitErr = nil
	for !done {
		select {
		case <-inputErrCh:
			// The websocket reader stopped (client left or sent an unreadable
			// frame). This is treated as a normal close, so waitErr stays nil.
			readerDone = true
			done = true
		case waitErr = <-shellErrCh:
			done = true
		case msg = <-inputCh:
			switch msg.Type {
			case "input":
				_, err = io.WriteString(stdin, msg.Data)
				if err != nil {
					waitErr = err
					done = true
				}
			case "resize":
				_ = sshSession.WindowChange(normalizeSSHSessionRows(msg.Rows), normalizeSSHSessionCols(msg.Cols))
			}
		}
	}
	if !readerDone {
		// The reader may be blocked handing over a message on the unbuffered
		// inputCh. Keep draining until it reports its terminal error, which
		// happens at the latest once the deferred ws.Close unblocks its Receive;
		// without this the reader goroutine would leak.
		go func() {
			for {
				select {
				case <-inputCh:
				case <-inputErrCh:
					return
				}
			}
		}()
	}
	_ = stdin.Close()
	now = time.Now().UTC().Unix()
	closeReason = runtimeSession.CloseReason()
	if closeReason != "" {
		_ = api.store(r).UpdateSSHSessionStatus(sessionItem.ID, "closed", connectedFingerprint, connectedAt, now, closeReason)
		api.logSSHSessionClosed(sessionItem, profile, user, "closed", closeReason, connectedAt, now)
		api.sendSSHSessionWS(&sendMu, ws, sshSessionStreamMessage{Type: "status", Status: "closed", Message: closeReason})
		return
	}
	if waitErr != nil && !errors.Is(waitErr, io.EOF) {
		_ = api.store(r).UpdateSSHSessionStatus(sessionItem.ID, "error", connectedFingerprint, connectedAt, now, waitErr.Error())
		api.Logger.Write(logIDSSHBroker, util.LOG_WARN,
			"session error session=%s requester=%s user=%s profile=%s profile_name=%q server=%s server_name=%q target=%s:%d remote_user=%s auth=%s dur_s=%d err=%v",
			sessionItem.ID,
			sessionItem.RemoteAddr,
			user.Username,
			profile.ID,
			profile.Name,
			profile.ServerID,
			profile.Server.Name,
			profile.Server.Host,
			profile.Server.Port,
			profile.RemoteUsername,
			profile.AuthMethod,
			now-connectedAt,
			waitErr)
		api.sendSSHSessionWS(&sendMu, ws, sshSessionStreamMessage{Type: "status", Status: "error", Message: waitErr.Error()})
		return
	}
	_ = api.store(r).UpdateSSHSessionStatus(sessionItem.ID, "closed", connectedFingerprint, connectedAt, now, "")
	api.logSSHSessionClosed(sessionItem, profile, user, "closed", "", connectedAt, now)
	api.sendSSHSessionWS(&sendMu, ws, sshSessionStreamMessage{Type: "status", Status: "closed", Message: ""})
}
// buildSSHSessionAuth returns the SSH auth methods for profile.
//
// The profile's stored secret is loaded, decrypted and parsed as a private
// key. For "stored_private_key" that key is used directly. For any other auth
// method, the key's public half (profile.AuthPublicKey) is signed by the
// profile's SSH user CA for the resolved principals. Authentication then uses
// the resulting certificate together with the private key.
// NOTE(review): auth methods other than the two known values fall into the
// certificate path; normalization presumably prevents that — confirm.
func (api *API) buildSSHSessionAuth(profile models.SSHAccessProfile) ([]ssh.AuthMethod, error) {
	var storedSecret models.SSHSecret
	var keyPEM string
	var keySigner ssh.Signer
	var userCA models.SSHUserCA
	var signResult sshSignUserKeyResponse
	var parsedCert ssh.PublicKey
	var userCert *ssh.Certificate
	var withCert ssh.Signer
	var principals []string
	var err error
	if strings.TrimSpace(profile.SecretID) == "" {
		return nil, errors.New("ssh access profile has no secret")
	}
	if storedSecret, err = api.Store.GetSSHSecret(profile.SecretID); err != nil {
		return nil, err
	}
	if keyPEM, err = api.decryptSSHSecretPayload(storedSecret.Payload); err != nil {
		return nil, err
	}
	if keySigner, err = parseSSHSignerFromPEM(keyPEM); err != nil {
		return nil, err
	}
	if profile.AuthMethod == "stored_private_key" {
		return []ssh.AuthMethod{ssh.PublicKeys(keySigner)}, nil
	}

	// Certificate path: have the CA sign the profile key for the allowed principals.
	if userCA, err = api.Store.GetSSHUserCA(profile.SSHUserCAID); err != nil {
		return nil, err
	}
	if principals, err = api.resolveSSHAccessProfilePrincipals(profile); err != nil {
		return nil, err
	}
	signResult, err = api.signSSHUserKeyWithCA(api.Store, userCA, profile.AuthPublicKey, "", principals, profile.DefaultValidSeconds, profile.MaxValidSeconds)
	if err != nil {
		return nil, err
	}
	if parsedCert, _, _, _, err = ssh.ParseAuthorizedKey([]byte(strings.TrimSpace(signResult.Certificate))); err != nil {
		return nil, err
	}
	userCert, _ = parsedCert.(*ssh.Certificate)
	if userCert == nil {
		return nil, errors.New("signed ssh certificate is invalid")
	}
	if withCert, err = ssh.NewCertSigner(userCert, keySigner); err != nil {
		return nil, err
	}
	return []ssh.AuthMethod{ssh.PublicKeys(withCert)}, nil
}
// resolveSSHAccessProfilePrincipals returns the principals to embed in the
// certificate signed for profile.
//
// In "grant" mode it unions the principals of every referenced grant; in any
// other mode it uses profile.SSHPrincipals. Either way the list is trimmed,
// de-duplicated and sorted via normalizeSSHBrokerNames.
//
// An empty result is always an error. A certificate with an empty principal
// list is valid for ANY principal (OpenSSH certkeys semantics, and
// x/crypto/ssh CertChecker), so an empty list must never be signed. A grant
// lookup failure is returned unchanged.
// NOTE(review): grants are fetched by id only; expiry/revocation is not
// re-checked here — confirm whether GetSSHPrincipalGrant or the signer enforces it.
func (api *API) resolveSSHAccessProfilePrincipals(profile models.SSHAccessProfile) ([]string, error) {
	var collected []string
	var principals []string
	var grant models.SSHPrincipalGrant
	var i int
	var j int
	var err error
	if profile.SSHPrincipalMode != "grant" {
		principals = normalizeSSHBrokerNames(profile.SSHPrincipals)
	} else {
		for i = 0; i < len(profile.SSHPrincipalGrantIDs); i++ {
			grant, err = api.Store.GetSSHPrincipalGrant(profile.SSHPrincipalGrantIDs[i])
			if err != nil {
				return nil, err
			}
			for j = 0; j < len(grant.Principals); j++ {
				collected = append(collected, grant.Principals[j])
			}
		}
		principals = normalizeSSHBrokerNames(collected)
	}
	if len(principals) == 0 {
		return nil, errors.New("no ssh principals resolved for access profile")
	}
	return principals, nil
}
// sshHostKeyCallback builds a HostKeyCallback that accepts only keys whose
// SHA256 fingerprint matches one of the pinned items.
//
// The presented fingerprint is written to *connectedFingerprint (when non-nil)
// before the check, so callers can record it even on a mismatch. A mismatch
// fails with "host key mismatch: <fingerprint>".
func sshHostKeyCallback(items []models.SSHServerHostKey, connectedFingerprint *string) ssh.HostKeyCallback {
	var pinned map[string]bool
	var idx int
	pinned = make(map[string]bool, len(items))
	for idx = range items {
		pinned[strings.TrimSpace(items[idx].Fingerprint)] = true
	}
	return func(hostname string, remote net.Addr, key ssh.PublicKey) error {
		var presented string
		presented = strings.TrimSpace(ssh.FingerprintSHA256(key))
		if connectedFingerprint != nil {
			*connectedFingerprint = presented
		}
		if !pinned[presented] {
			return errors.New("host key mismatch: " + presented)
		}
		return nil
	}
}
// pipeSSHSessionOutput forwards everything read from reader to ws as binary
// websocket frames until reader returns an error (including io.EOF).
//
// Send failures are ignored, so the pipe keeps being drained even after the
// client is gone. Writes go through sendSSHSessionBinary, which serializes
// them with other senders via sendMu.
func (api *API) pipeSSHSessionOutput(reader io.Reader, ws *websocket.Conn, sendMu *sync.Mutex) {
	var buffer []byte
	var n int
	var err error
	buffer = make([]byte, 4096)
	for {
		n, err = reader.Read(buffer)
		if n > 0 {
			// sendSSHSessionBinary writes the frame synchronously before
			// returning, so the buffer can be reused for the next Read without
			// a per-chunk copy.
			api.sendSSHSessionBinary(sendMu, ws, buffer[:n])
		}
		if err != nil {
			return
		}
	}
}
// readSSHSessionWebsocket decodes JSON messages from ws and forwards each one
// on inputCh until a receive fails. It then sends that error on errCh once and
// returns.
//
// inputCh sends block until the consumer reads, so the consumer must keep
// draining inputCh until errCh fires or this goroutine cannot exit.
func (api *API) readSSHSessionWebsocket(ws *websocket.Conn, inputCh chan<- sshSessionStreamMessage, errCh chan<- error) {
	var incoming sshSessionStreamMessage
	var recvErr error
	for {
		incoming = sshSessionStreamMessage{}
		if recvErr = websocket.JSON.Receive(ws, &incoming); recvErr != nil {
			break
		}
		inputCh <- incoming
	}
	errCh <- recvErr
}
// sendSSHSessionWS writes msg to ws as a JSON frame while holding mu, so it
// never interleaves with other writers sharing the same mutex. Nothing is
// sent when mu is nil. Delivery is best-effort: the send error is dropped
// and nothing is returned to the caller.
func (api *API) sendSSHSessionWS(mu *sync.Mutex, ws *websocket.Conn, msg sshSessionStreamMessage) {
	if mu != nil {
		mu.Lock()
		defer mu.Unlock()
		_ = websocket.JSON.Send(ws, msg)
	}
}
// sendSSHSessionBinary writes data to ws as a single binary frame while
// holding mu, serializing it with other writers sharing the mutex. Nothing is
// sent when mu is nil. Delivery is best-effort: the send error is dropped.
func (api *API) sendSSHSessionBinary(mu *sync.Mutex, ws *websocket.Conn, data []byte) {
	if mu != nil {
		mu.Lock()
		defer mu.Unlock()
		_ = websocket.Message.Send(ws, data)
	}
}