850 lines
26 KiB
Go
850 lines
26 KiB
Go
package handlers
|
|
|
|
import codit_logger "codit/logger"
|
|
|
|
import "archive/zip"
|
|
import "context"
|
|
import "database/sql"
|
|
import "errors"
|
|
import "fmt"
|
|
import "io"
|
|
import "mime"
|
|
import "mime/multipart"
|
|
import "net"
|
|
import "net/http"
|
|
import "path"
|
|
import "strings"
|
|
import "os"
|
|
import "time"
|
|
|
|
import "github.com/pkg/sftp"
|
|
|
|
import "golang.org/x/crypto/ssh"
|
|
|
|
import "codit/internal/db"
|
|
import "codit/internal/middleware"
|
|
import "codit/internal/models"
|
|
|
|
// TODO: consider making these transfer limits configurable.
|
|
// Limits and tuning values shared by the SSH session file-transfer handlers.
const (
	// sshFileTransferMaxUploadBytes bounds the request body of an upload (5 GiB).
	sshFileTransferMaxUploadBytes int64 = 5 << 30

	// sshFileTransferMaxDownloadBytes bounds the summed size of the files in a
	// single download or copy request (10 GiB).
	sshFileTransferMaxDownloadBytes int64 = 10 << 30

	// sshFileTransferMaxFiles bounds the number of distinct files per request.
	sshFileTransferMaxFiles int = 256

	// sshFileTransferCopyBufferSize is the size of the fallback copy buffer (1 MiB).
	sshFileTransferCopyBufferSize int = 1 << 20
)
|
|
|
|
// sshSessionFileDownloadRequest is the JSON body accepted by the download
// handler.
type sshSessionFileDownloadRequest struct {
	// Paths are the remote file paths to download.
	Paths []string `json:"paths"`

	// Password and OTPCode are forwarded when the SSH connection for the
	// transfer is opened.
	Password string `json:"password"`
	OTPCode  string `json:"otp_code"`
}
|
|
|
|
// sshSessionFileCopyRequest is the JSON body accepted by the session-to-session
// copy handler. The source session comes from the request route; everything
// about the target is carried here.
type sshSessionFileCopyRequest struct {
	// TargetSessionID identifies the session whose host receives the files.
	TargetSessionID string `json:"target_session_id"`

	// Paths are the file paths to read on the source host.
	Paths []string `json:"paths"`

	// TargetDir is the destination directory on the target host; the handler
	// falls back to "." when it is blank.
	TargetDir string `json:"target_dir"`

	// Overwrite allows existing files on the target to be replaced.
	Overwrite bool `json:"overwrite"`

	// Credentials forwarded when opening the source connection.
	SourcePassword string `json:"source_password"`
	SourceOTPCode  string `json:"source_otp_code"`

	// Credentials forwarded when opening the target connection.
	TargetPassword string `json:"target_password"`
	TargetOTPCode  string `json:"target_otp_code"`
}
|
|
|
|
// sshSessionFileUploadItem reports the outcome of one uploaded file.
type sshSessionFileUploadItem struct {
	// Name is the base name taken from the multipart file part.
	Name string `json:"name"`

	// Path is the remote path the file was written to.
	Path string `json:"path"`

	// Size is the number of bytes copied to the remote file.
	Size int64 `json:"size"`

	// Error holds the failure message; it is omitted from JSON when empty.
	Error string `json:"error,omitempty"`
}
|
|
|
|
type sshSessionFileUploadResponse struct {
|
|
Items []sshSessionFileUploadItem `json:"items"`
|
|
Uploaded int `json:"uploaded"`
|
|
Failed int `json:"failed"`
|
|
}
|
|
|
|
// sshSessionFileDownloadItem describes a remote file that has been stat'ed and
// is ready to be streamed. It is internal and carries no JSON tags.
type sshSessionFileDownloadItem struct {
	RemotePath string // path as requested by the client
	Name       string // display / archive entry name for the file
	Size       int64  // size reported by the remote stat
}
|
|
|
|
// sshSessionFileCopyItem reports the outcome of copying one file between two
// sessions.
type sshSessionFileCopyItem struct {
	// SourcePath is the requested path on the source host.
	SourcePath string `json:"source_path"`

	// TargetPath is the destination path on the target host.
	TargetPath string `json:"target_path"`

	// Size is the number of bytes copied.
	Size int64 `json:"size"`

	// Error holds the failure message; it is omitted from JSON when empty.
	Error string `json:"error,omitempty"`
}
|
|
|
|
type sshSessionFileCopyResponse struct {
|
|
Items []sshSessionFileCopyItem `json:"items"`
|
|
Copied int `json:"copied"`
|
|
Failed int `json:"failed"`
|
|
}
|
|
|
|
// sshFileTransferContextReader pairs an io.Reader with a context so that reads
// can be refused once the context is done (see its Read method).
type sshFileTransferContextReader struct {
	Context context.Context // checked before every read
	Reader  io.Reader       // underlying source of bytes
}
|
|
|
|
// sshFileTransferContextWriter pairs an io.Writer with a context so that writes
// can be refused once the context is done (see its Write method).
type sshFileTransferContextWriter struct {
	Context context.Context // checked before every write
	Writer  io.Writer       // underlying sink for bytes
}
|
|
|
|
func (r *sshFileTransferContextReader) Read(p []byte) (int, error) {
|
|
if r.Context.Err() != nil {
|
|
return 0, r.Context.Err()
|
|
}
|
|
return r.Reader.Read(p)
|
|
}
|
|
|
|
func (w *sshFileTransferContextWriter) Write(p []byte) (int, error) {
|
|
if w.Context.Err() != nil {
|
|
return 0, w.Context.Err()
|
|
}
|
|
return w.Writer.Write(p)
|
|
}
|
|
|
|
func (api *API) UploadSSHSessionFilesForSelf(w http.ResponseWriter, r *http.Request, params map[string]string) {
|
|
var user models.User
|
|
var ok bool
|
|
var store *db.Store
|
|
var sessionItem models.SSHSession
|
|
var multipartReader *multipart.Reader
|
|
var part *multipart.Part
|
|
var sshClient *ssh.Client
|
|
var sftpClient *sftp.Client
|
|
var stopContextClose func()
|
|
var targetDir string
|
|
var overwrite bool
|
|
var response sshSessionFileUploadResponse
|
|
var item sshSessionFileUploadItem
|
|
var remotePath string
|
|
var fieldName string
|
|
var fieldValue string
|
|
var password string
|
|
var otpCode string
|
|
var seenUploadPaths map[string]bool
|
|
var uploadCount int
|
|
var opened bool
|
|
var copied int64
|
|
var err error
|
|
|
|
user, ok = middleware.UserFromContext(r.Context())
|
|
if !ok || user.Disabled {
|
|
w.WriteHeader(http.StatusUnauthorized)
|
|
return
|
|
}
|
|
r.Body = http.MaxBytesReader(w, r.Body, sshFileTransferMaxUploadBytes)
|
|
multipartReader, err = r.MultipartReader()
|
|
if err != nil {
|
|
WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid multipart upload"})
|
|
return
|
|
}
|
|
targetDir = "."
|
|
seenUploadPaths = make(map[string]bool)
|
|
for {
|
|
if r.Context().Err() != nil {
|
|
return
|
|
}
|
|
part, err = multipartReader.NextPart()
|
|
if err == io.EOF {
|
|
break
|
|
}
|
|
if err != nil {
|
|
WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid multipart upload"})
|
|
return
|
|
}
|
|
fieldName = part.FormName()
|
|
if fieldName != "files" {
|
|
fieldValue, err = readMultipartFieldValue(part, 64 * 1024)
|
|
_ = part.Close()
|
|
if err != nil {
|
|
WriteJSON(w, http.StatusBadRequest, map[string]string{"error": err.Error()})
|
|
return
|
|
}
|
|
if fieldName == "target_dir" {
|
|
targetDir = strings.TrimSpace(fieldValue)
|
|
if targetDir == "" {
|
|
targetDir = "."
|
|
}
|
|
} else if fieldName == "overwrite" {
|
|
overwrite = parseSSHFileTransferBool(fieldValue)
|
|
} else if fieldName == "password" {
|
|
password = fieldValue
|
|
} else if fieldName == "otp_code" {
|
|
otpCode = fieldValue
|
|
}
|
|
continue
|
|
}
|
|
item = sshSessionFileUploadItem{Name: path.Base(part.FileName())}
|
|
if item.Name == "." || item.Name == "/" || strings.Contains(item.Name, "\x00") {
|
|
item.Error = "invalid file name"
|
|
response.Failed += 1
|
|
response.Items = append(response.Items, item)
|
|
_ = part.Close()
|
|
continue
|
|
}
|
|
remotePath = path.Join(targetDir, item.Name)
|
|
if seenUploadPaths[remotePath] {
|
|
_ = part.Close()
|
|
continue
|
|
}
|
|
seenUploadPaths[remotePath] = true
|
|
uploadCount += 1
|
|
if uploadCount > sshFileTransferMaxFiles {
|
|
_ = part.Close()
|
|
WriteJSON(w, http.StatusBadRequest, map[string]string{"error": fmt.Sprintf("too many files; max is %d", sshFileTransferMaxFiles)})
|
|
return
|
|
}
|
|
if !opened {
|
|
store = api.store(r)
|
|
sessionItem, err = api.getSSHSessionFileTransferItem(store, user, params["id"])
|
|
if err != nil {
|
|
_ = part.Close()
|
|
api.writeSSHSessionFileTransferError(w, err)
|
|
return
|
|
}
|
|
sshClient, err = api.openSSHSessionFileTransferClient(store, user, sessionItem, password, otpCode)
|
|
if err != nil {
|
|
_ = part.Close()
|
|
WriteJSON(w, http.StatusBadRequest, map[string]string{"error": err.Error()})
|
|
return
|
|
}
|
|
stopContextClose = closeSSHClientOnContext(r.Context(), sshClient)
|
|
defer stopContextClose()
|
|
defer sshClient.Close()
|
|
sftpClient, err = sftp.NewClient(sshClient)
|
|
if err != nil {
|
|
_ = part.Close()
|
|
WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "sftp subsystem unavailable: " + err.Error()})
|
|
return
|
|
}
|
|
defer sftpClient.Close()
|
|
opened = true
|
|
}
|
|
item.Path = remotePath
|
|
copied, err = api.sftpUploadFile(r.Context(), sftpClient, targetDir, item.Name, part, overwrite)
|
|
_ = part.Close()
|
|
item.Size = copied
|
|
if err != nil {
|
|
if r.Context().Err() != nil {
|
|
return
|
|
}
|
|
item.Error = err.Error()
|
|
response.Failed += 1
|
|
} else {
|
|
response.Uploaded += 1
|
|
}
|
|
response.Items = append(response.Items, item)
|
|
}
|
|
if uploadCount == 0 {
|
|
WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "at least one file is required"})
|
|
return
|
|
}
|
|
WriteJSON(w, http.StatusOK, response)
|
|
}
|
|
|
|
func (api *API) DownloadSSHSessionFilesForSelf(w http.ResponseWriter, r *http.Request, params map[string]string) {
|
|
var user models.User
|
|
var ok bool
|
|
var store *db.Store
|
|
var req sshSessionFileDownloadRequest
|
|
var sessionItem models.SSHSession
|
|
var sshClient *ssh.Client
|
|
var sftpClient *sftp.Client
|
|
var stopContextClose func()
|
|
var items []sshSessionFileDownloadItem
|
|
var item sshSessionFileDownloadItem
|
|
var remotePath string
|
|
var seenPaths map[string]bool
|
|
var pathCount int
|
|
var totalSize int64
|
|
var contentDisposition string
|
|
var archiveName string
|
|
var err error
|
|
var i int
|
|
|
|
user, ok = middleware.UserFromContext(r.Context())
|
|
if !ok || user.Disabled {
|
|
w.WriteHeader(http.StatusUnauthorized)
|
|
return
|
|
}
|
|
err = DecodeJSON(r, &req)
|
|
if err != nil {
|
|
WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid request"})
|
|
return
|
|
}
|
|
if len(req.Paths) == 0 {
|
|
WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "at least one path is required"})
|
|
return
|
|
}
|
|
seenPaths = make(map[string]bool)
|
|
store = api.store(r)
|
|
sessionItem, err = api.getSSHSessionFileTransferItem(store, user, params["id"])
|
|
if err != nil {
|
|
api.writeSSHSessionFileTransferError(w, err)
|
|
return
|
|
}
|
|
sshClient, err = api.openSSHSessionFileTransferClient(store, user, sessionItem, req.Password, req.OTPCode)
|
|
if err != nil {
|
|
WriteJSON(w, http.StatusBadRequest, map[string]string{"error": err.Error()})
|
|
return
|
|
}
|
|
stopContextClose = closeSSHClientOnContext(r.Context(), sshClient)
|
|
defer stopContextClose()
|
|
defer sshClient.Close()
|
|
sftpClient, err = sftp.NewClient(sshClient)
|
|
if err != nil {
|
|
WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "sftp subsystem unavailable: " + err.Error()})
|
|
return
|
|
}
|
|
defer sftpClient.Close()
|
|
for i = 0; i < len(req.Paths); i++ {
|
|
if r.Context().Err() != nil {
|
|
return
|
|
}
|
|
remotePath = strings.TrimSpace(req.Paths[i])
|
|
if remotePath == "" || strings.Contains(remotePath, "\x00") {
|
|
WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid download path"})
|
|
return
|
|
}
|
|
if seenPaths[remotePath] {
|
|
continue
|
|
}
|
|
seenPaths[remotePath] = true
|
|
pathCount += 1
|
|
if pathCount > sshFileTransferMaxFiles {
|
|
WriteJSON(w, http.StatusBadRequest, map[string]string{"error": fmt.Sprintf("too many files; max is %d", sshFileTransferMaxFiles)})
|
|
return
|
|
}
|
|
item, err = api.sftpDownloadMetadata(sftpClient, remotePath)
|
|
if err != nil {
|
|
WriteJSON(w, http.StatusBadRequest, map[string]string{"error": fmt.Sprintf("%s: %s", remotePath, err.Error())})
|
|
return
|
|
}
|
|
totalSize += item.Size
|
|
if totalSize > sshFileTransferMaxDownloadBytes {
|
|
WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "download size exceeds limit"})
|
|
return
|
|
}
|
|
items = append(items, item)
|
|
}
|
|
if len(items) == 1 {
|
|
contentDisposition = mime.FormatMediaType("attachment", map[string]string{"filename": items[0].Name})
|
|
w.Header().Set("Content-Type", "application/octet-stream")
|
|
w.Header().Set("Content-Disposition", contentDisposition)
|
|
w.Header().Set("Content-Length", fmt.Sprintf("%d", items[0].Size))
|
|
err = api.sftpDownloadFile(r.Context(), sftpClient, items[0].RemotePath, w)
|
|
if err != nil {
|
|
api.Logger.Write(logIDSSHBroker, codit_logger.LOG_WARN, "ssh file download failed session=%s path=%q err=%s", sessionItem.ID, items[0].RemotePath, err.Error())
|
|
}
|
|
return
|
|
}
|
|
archiveName = sshSessionFilesArchiveName(store, sessionItem)
|
|
contentDisposition = mime.FormatMediaType("attachment", map[string]string{"filename": archiveName})
|
|
w.Header().Set("Content-Type", "application/zip")
|
|
w.Header().Set("Content-Disposition", contentDisposition)
|
|
err = api.writeSSHSessionFileZip(r.Context(), sftpClient, items, w)
|
|
if err != nil {
|
|
api.Logger.Write(logIDSSHBroker, codit_logger.LOG_WARN, "ssh file zip download failed session=%s err=%s", sessionItem.ID, err.Error())
|
|
}
|
|
}
|
|
|
|
func (api *API) CopySSHSessionFilesForSelf(w http.ResponseWriter, r *http.Request, params map[string]string) {
|
|
var user models.User
|
|
var ok bool
|
|
var store *db.Store
|
|
var req sshSessionFileCopyRequest
|
|
var sourceSession models.SSHSession
|
|
var targetSession models.SSHSession
|
|
var sourceSSHClient *ssh.Client
|
|
var targetSSHClient *ssh.Client
|
|
var sourceSFTPClient *sftp.Client
|
|
var targetSFTPClient *sftp.Client
|
|
var stopSourceClose func()
|
|
var stopTargetClose func()
|
|
var response sshSessionFileCopyResponse
|
|
var item sshSessionFileCopyItem
|
|
var metadata sshSessionFileDownloadItem
|
|
var remotePath string
|
|
var targetDir string
|
|
var seenPaths map[string]bool
|
|
var pathCount int
|
|
var totalSize int64
|
|
var startedAt time.Time
|
|
var duration time.Duration
|
|
var err error
|
|
var i int
|
|
|
|
user, ok = middleware.UserFromContext(r.Context())
|
|
if !ok || user.Disabled {
|
|
w.WriteHeader(http.StatusUnauthorized)
|
|
return
|
|
}
|
|
err = DecodeJSON(r, &req)
|
|
if err != nil {
|
|
WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid request"})
|
|
return
|
|
}
|
|
if strings.TrimSpace(req.TargetSessionID) == "" {
|
|
WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "target_session_id is required"})
|
|
return
|
|
}
|
|
if strings.TrimSpace(req.TargetSessionID) == strings.TrimSpace(params["id"]) {
|
|
WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "target session must be different from source session"})
|
|
return
|
|
}
|
|
if len(req.Paths) == 0 {
|
|
WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "at least one path is required"})
|
|
return
|
|
}
|
|
targetDir = strings.TrimSpace(req.TargetDir)
|
|
if targetDir == "" {
|
|
targetDir = "."
|
|
}
|
|
if strings.Contains(targetDir, "\x00") {
|
|
WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid target directory"})
|
|
return
|
|
}
|
|
seenPaths = make(map[string]bool)
|
|
store = api.store(r)
|
|
sourceSession, err = api.getSSHSessionFileTransferItem(store, user, params["id"])
|
|
if err != nil {
|
|
api.writeSSHSessionFileTransferError(w, err)
|
|
return
|
|
}
|
|
targetSession, err = api.getSSHSessionFileTransferItem(store, user, req.TargetSessionID)
|
|
if err != nil {
|
|
api.writeSSHSessionFileTransferError(w, err)
|
|
return
|
|
}
|
|
sourceSSHClient, err = api.openSSHSessionFileTransferClient(store, user, sourceSession, req.SourcePassword, req.SourceOTPCode)
|
|
if err != nil {
|
|
WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "source session: " + err.Error()})
|
|
return
|
|
}
|
|
stopSourceClose = closeSSHClientOnContext(r.Context(), sourceSSHClient)
|
|
defer stopSourceClose()
|
|
defer sourceSSHClient.Close()
|
|
targetSSHClient, err = api.openSSHSessionFileTransferClient(store, user, targetSession, req.TargetPassword, req.TargetOTPCode)
|
|
if err != nil {
|
|
WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "target session: " + err.Error()})
|
|
return
|
|
}
|
|
stopTargetClose = closeSSHClientOnContext(r.Context(), targetSSHClient)
|
|
defer stopTargetClose()
|
|
defer targetSSHClient.Close()
|
|
sourceSFTPClient, err = sftp.NewClient(sourceSSHClient)
|
|
if err != nil {
|
|
WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "source sftp subsystem unavailable: " + err.Error()})
|
|
return
|
|
}
|
|
defer sourceSFTPClient.Close()
|
|
targetSFTPClient, err = sftp.NewClient(targetSSHClient)
|
|
if err != nil {
|
|
WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "target sftp subsystem unavailable: " + err.Error()})
|
|
return
|
|
}
|
|
defer targetSFTPClient.Close()
|
|
for i = 0; i < len(req.Paths); i++ {
|
|
if r.Context().Err() != nil {
|
|
return
|
|
}
|
|
remotePath = strings.TrimSpace(req.Paths[i])
|
|
item = sshSessionFileCopyItem{SourcePath: remotePath}
|
|
if remotePath == "" || strings.Contains(remotePath, "\x00") {
|
|
item.Error = "invalid source path"
|
|
response.Failed += 1
|
|
response.Items = append(response.Items, item)
|
|
continue
|
|
}
|
|
if seenPaths[remotePath] {
|
|
continue
|
|
}
|
|
seenPaths[remotePath] = true
|
|
pathCount += 1
|
|
if pathCount > sshFileTransferMaxFiles {
|
|
item.Error = fmt.Sprintf("too many files; max is %d", sshFileTransferMaxFiles)
|
|
response.Failed += 1
|
|
response.Items = append(response.Items, item)
|
|
break
|
|
}
|
|
metadata, err = api.sftpDownloadMetadata(sourceSFTPClient, remotePath)
|
|
if err != nil {
|
|
item.Error = err.Error()
|
|
response.Failed += 1
|
|
response.Items = append(response.Items, item)
|
|
continue
|
|
}
|
|
totalSize += metadata.Size
|
|
if totalSize > sshFileTransferMaxDownloadBytes {
|
|
item.Error = "copy size exceeds limit"
|
|
response.Failed += 1
|
|
response.Items = append(response.Items, item)
|
|
break
|
|
}
|
|
item.TargetPath = path.Join(targetDir, metadata.Name)
|
|
startedAt = time.Now()
|
|
api.Logger.Write(logIDSSHBroker, codit_logger.LOG_INFO,
|
|
"ssh file copy start source_session=%s source_server=%s target_session=%s target_server=%s source_path=%q target_path=%q size=%d",
|
|
sourceSession.ID,
|
|
sourceSession.ServerName,
|
|
targetSession.ID,
|
|
targetSession.ServerName,
|
|
metadata.RemotePath,
|
|
item.TargetPath,
|
|
metadata.Size)
|
|
item.Size, err = api.sftpCopyFile(r.Context(), sourceSFTPClient, targetSFTPClient, metadata.RemotePath, targetDir, metadata.Name, req.Overwrite)
|
|
duration = time.Since(startedAt)
|
|
if err != nil {
|
|
if r.Context().Err() != nil {
|
|
return
|
|
}
|
|
api.Logger.Write(logIDSSHBroker, codit_logger.LOG_WARN,
|
|
"ssh file copy failed source_session=%s source_server=%s target_session=%s target_server=%s source_path=%q target_path=%q size=%d dur_ms=%d err=%s",
|
|
sourceSession.ID,
|
|
sourceSession.ServerName,
|
|
targetSession.ID,
|
|
targetSession.ServerName,
|
|
metadata.RemotePath,
|
|
item.TargetPath,
|
|
item.Size,
|
|
duration.Milliseconds(),
|
|
err.Error())
|
|
item.Error = err.Error()
|
|
response.Failed += 1
|
|
} else {
|
|
api.Logger.Write(logIDSSHBroker, codit_logger.LOG_INFO,
|
|
"ssh file copy done source_session=%s source_server=%s target_session=%s target_server=%s source_path=%q target_path=%q size=%d dur_ms=%d",
|
|
sourceSession.ID,
|
|
sourceSession.ServerName,
|
|
targetSession.ID,
|
|
targetSession.ServerName,
|
|
metadata.RemotePath,
|
|
item.TargetPath,
|
|
item.Size,
|
|
duration.Milliseconds())
|
|
response.Copied += 1
|
|
}
|
|
response.Items = append(response.Items, item)
|
|
}
|
|
WriteJSON(w, http.StatusOK, response)
|
|
}
|
|
|
|
func (api *API) getSSHSessionFileTransferItem(store *db.Store, user models.User, sessionID string) (models.SSHSession, error) {
|
|
var item models.SSHSession
|
|
var err error
|
|
|
|
item, err = store.GetSSHSession(sessionID)
|
|
if err != nil {
|
|
if err == sql.ErrNoRows {
|
|
return item, errors.New("ssh session not found")
|
|
}
|
|
return item, err
|
|
}
|
|
if item.UserID != user.ID {
|
|
return item, errors.New("ssh session not found")
|
|
}
|
|
if item.Status != "connected" {
|
|
return item, errors.New("ssh session is not connected")
|
|
}
|
|
return item, nil
|
|
}
|
|
|
|
func (api *API) writeSSHSessionFileTransferError(w http.ResponseWriter, err error) {
|
|
if err.Error() == "ssh session not found" {
|
|
WriteJSON(w, http.StatusNotFound, map[string]string{"error": err.Error()})
|
|
return
|
|
}
|
|
WriteJSON(w, http.StatusBadRequest, map[string]string{"error": err.Error()})
|
|
}
|
|
|
|
func (api *API) openSSHSessionFileTransferClient(store *db.Store, user models.User, sessionItem models.SSHSession, promptedPassword string, otpCode string) (*ssh.Client, error) {
|
|
var prepared sshSessionPrepared
|
|
var prepareStage string
|
|
var sshConfig *ssh.ClientConfig
|
|
var sshClient *ssh.Client
|
|
var connectedFingerprint string
|
|
var connectedHostKey ssh.PublicKey
|
|
var err error
|
|
|
|
prepared, prepareStage, err = api.prepareSSHSession(store, &user, &sessionItem, promptedPassword, otpCode)
|
|
if err != nil {
|
|
if prepareStage == "load_profile" {
|
|
return nil, errors.New("ssh access profile not found")
|
|
}
|
|
return nil, err
|
|
}
|
|
sshConfig = &ssh.ClientConfig{
|
|
User: prepared.Profile.RemoteUsername,
|
|
Auth: prepared.AuthMethods,
|
|
HostKeyCallback: sshHostKeyCallback(prepared.HostKeys, &connectedFingerprint, &connectedHostKey, prepared.PinHostKeyOnFirstUse),
|
|
Timeout: 15 * time.Second,
|
|
}
|
|
sshClient, err = ssh.Dial("tcp", net.JoinHostPort(prepared.Profile.Server.Host, fmt.Sprintf("%d", prepared.Profile.Server.Port)), sshConfig)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if prepared.PinHostKeyOnFirstUse {
|
|
err = api.pinSSHSessionFileTransferHostKey(store, sessionItem, prepared.Profile, connectedFingerprint, connectedHostKey)
|
|
if err != nil {
|
|
_ = sshClient.Close()
|
|
return nil, err
|
|
}
|
|
}
|
|
return sshClient, nil
|
|
}
|
|
|
|
func (api *API) pinSSHSessionFileTransferHostKey(store *db.Store, sessionItem models.SSHSession, profile models.SSHAccessProfile, fingerprint string, hostKey ssh.PublicKey) error {
|
|
var item models.SSHServerHostKey
|
|
var reloaded []models.SSHServerHostKey
|
|
var err error
|
|
var i int
|
|
|
|
if hostKey == nil {
|
|
return errors.New("failed to capture ssh host key")
|
|
}
|
|
item = models.SSHServerHostKey{
|
|
ServerID: sessionItem.ServerID,
|
|
Algorithm: hostKey.Type(),
|
|
PublicKey: strings.TrimSpace(string(ssh.MarshalAuthorizedKey(hostKey))),
|
|
Fingerprint: fingerprint,
|
|
}
|
|
item, err = store.CreateSSHServerHostKey(item)
|
|
if err == nil {
|
|
return nil
|
|
}
|
|
reloaded, err = store.ListSSHServerHostKeys(sessionItem.ServerID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
for i = 0; i < len(reloaded); i++ {
|
|
if strings.TrimSpace(reloaded[i].Fingerprint) == fingerprint {
|
|
return nil
|
|
}
|
|
}
|
|
return fmt.Errorf("failed to pin host key for %s: %s", profile.Server.Name, err.Error())
|
|
}
|
|
|
|
func closeSSHClientOnContext(ctx context.Context, client *ssh.Client) func() {
|
|
var done chan struct{}
|
|
var stop func()
|
|
|
|
done = make(chan struct{})
|
|
go func() {
|
|
select {
|
|
case <-ctx.Done():
|
|
_ = client.Close()
|
|
case <-done:
|
|
}
|
|
}()
|
|
stop = func() {
|
|
close(done)
|
|
}
|
|
return stop
|
|
}
|
|
|
|
func (api *API) sftpUploadFile(ctx context.Context, client *sftp.Client, targetDir string, name string, reader io.Reader, overwrite bool) (int64, error) {
|
|
var remotePath string
|
|
var flags int
|
|
var remoteFile *sftp.File
|
|
var copied int64
|
|
var statErr error
|
|
var err error
|
|
|
|
if ctx.Err() != nil { return 0, ctx.Err() }
|
|
err = client.MkdirAll(targetDir)
|
|
if err != nil { return 0, err }
|
|
if ctx.Err() != nil { return 0, ctx.Err() }
|
|
remotePath = path.Join(targetDir, name)
|
|
if !overwrite {
|
|
_, statErr = client.Stat(remotePath)
|
|
if statErr == nil {
|
|
return 0, fmt.Errorf("remote file already exists: %s", remotePath)
|
|
}
|
|
if !isSFTPNoSuchFile(statErr) {
|
|
return 0, statErr
|
|
}
|
|
}
|
|
flags = os.O_WRONLY | os.O_CREATE
|
|
if overwrite {
|
|
flags = flags | os.O_TRUNC
|
|
} else {
|
|
flags = flags | os.O_EXCL
|
|
}
|
|
remoteFile, err = client.OpenFile(remotePath, flags)
|
|
if err != nil { return 0, err }
|
|
defer remoteFile.Close()
|
|
copied, err = copyWithContext(ctx, remoteFile, reader)
|
|
return copied, err
|
|
}
|
|
|
|
func isSFTPNoSuchFile(err error) bool {
|
|
var statusErr *sftp.StatusError
|
|
|
|
if err == nil {
|
|
return false
|
|
}
|
|
if errors.Is(err, os.ErrNotExist) {
|
|
return true
|
|
}
|
|
statusErr = nil
|
|
if errors.As(err, &statusErr) && statusErr.FxCode() == sftp.ErrSSHFxNoSuchFile {
|
|
return true
|
|
}
|
|
return false
|
|
}
|
|
|
|
func (api *API) sftpDownloadMetadata(client *sftp.Client, remotePath string) (sshSessionFileDownloadItem, error) {
|
|
var info os.FileInfo
|
|
var item sshSessionFileDownloadItem
|
|
var err error
|
|
|
|
info, err = client.Stat(remotePath)
|
|
if err != nil { return item, err }
|
|
if info.IsDir() {
|
|
return item, errors.New("directory downloads are not supported")
|
|
}
|
|
item = sshSessionFileDownloadItem{
|
|
RemotePath: remotePath,
|
|
Name: path.Base(remotePath),
|
|
Size: info.Size(),
|
|
}
|
|
if item.Name == "." || item.Name == "/" || strings.TrimSpace(item.Name) == "" {
|
|
item.Name = info.Name()
|
|
}
|
|
return item, nil
|
|
}
|
|
|
|
func sshSessionFilesArchiveName(store *db.Store, sessionItem models.SSHSession) string {
|
|
var profile models.SSHAccessProfile
|
|
var server models.SSHServer
|
|
var profileName string
|
|
var serverName string
|
|
var timestamp string
|
|
var err error
|
|
|
|
profileName = strings.TrimSpace(sessionItem.ProfileID)
|
|
serverName = strings.TrimSpace(sessionItem.Host)
|
|
if sessionItem.Port > 0 && serverName != "" {
|
|
serverName = fmt.Sprintf("%s-%d", serverName, sessionItem.Port)
|
|
}
|
|
profile, err = store.GetSSHAccessProfile(sessionItem.ProfileID)
|
|
if err == nil && strings.TrimSpace(profile.Name) != "" {
|
|
profileName = strings.TrimSpace(profile.Name)
|
|
}
|
|
server, err = store.GetSSHServer(sessionItem.ServerID)
|
|
if err == nil && strings.TrimSpace(server.Name) != "" {
|
|
serverName = strings.TrimSpace(server.Name)
|
|
}
|
|
if profileName == "" {
|
|
profileName = "ssh-profile"
|
|
}
|
|
if serverName == "" {
|
|
serverName = "ssh-server"
|
|
}
|
|
timestamp = time.Now().Format("20060102-150405")
|
|
return fmt.Sprintf("%s-%s-%s-files.zip", sanitizeBundleName(profileName), sanitizeBundleName(serverName), timestamp)
|
|
}
|
|
|
|
func (api *API) sftpDownloadFile(ctx context.Context, client *sftp.Client, remotePath string, writer io.Writer) error {
|
|
var remoteFile *sftp.File
|
|
var err error
|
|
|
|
if ctx.Err() != nil { return ctx.Err() }
|
|
remoteFile, err = client.Open(remotePath)
|
|
if err != nil { return err }
|
|
defer remoteFile.Close()
|
|
_, err = copyWithContext(ctx, writer, remoteFile)
|
|
return err
|
|
}
|
|
|
|
func (api *API) sftpCopyFile(ctx context.Context, sourceClient *sftp.Client, targetClient *sftp.Client, sourcePath string, targetDir string, name string, overwrite bool) (int64, error) {
|
|
var sourceFile *sftp.File
|
|
var copied int64
|
|
var err error
|
|
|
|
if ctx.Err() != nil { return 0, ctx.Err() }
|
|
sourceFile, err = sourceClient.Open(sourcePath)
|
|
if err != nil { return 0, err }
|
|
defer sourceFile.Close()
|
|
copied, err = api.sftpUploadFile(ctx, targetClient, targetDir, name, sourceFile, overwrite)
|
|
return copied, err
|
|
}
|
|
|
|
func (api *API) writeSSHSessionFileZip(ctx context.Context, client *sftp.Client, items []sshSessionFileDownloadItem, writer io.Writer) error {
|
|
var zipWriter *zip.Writer
|
|
var zipFile io.Writer
|
|
var closeErr error
|
|
var err error
|
|
var i int
|
|
|
|
zipWriter = zip.NewWriter(writer)
|
|
for i = 0; i < len(items); i++ {
|
|
if ctx.Err() != nil {
|
|
err = ctx.Err()
|
|
break
|
|
}
|
|
zipFile, err = zipWriter.Create(sshDownloadEntryName(items[i], i))
|
|
if err != nil { break }
|
|
err = api.sftpDownloadFile(ctx, client, items[i].RemotePath, zipFile)
|
|
if err != nil { break }
|
|
}
|
|
closeErr = zipWriter.Close()
|
|
if err == nil && closeErr != nil {
|
|
err = closeErr
|
|
}
|
|
return err
|
|
}
|
|
|
|
func copyWithContext(ctx context.Context, writer io.Writer, reader io.Reader) (int64, error) {
|
|
var contextWriter *sshFileTransferContextWriter
|
|
var contextReader *sshFileTransferContextReader
|
|
var buffer []byte
|
|
var writerTo io.WriterTo
|
|
var readerFrom io.ReaderFrom
|
|
|
|
if ctx.Err() != nil {
|
|
return 0, ctx.Err()
|
|
}
|
|
contextWriter = &sshFileTransferContextWriter{Context: ctx, Writer: writer}
|
|
contextReader = &sshFileTransferContextReader{Context: ctx, Reader: reader}
|
|
writerTo, _ = reader.(io.WriterTo)
|
|
if writerTo != nil {
|
|
return writerTo.WriteTo(contextWriter)
|
|
}
|
|
readerFrom, _ = writer.(io.ReaderFrom)
|
|
if readerFrom != nil {
|
|
return readerFrom.ReadFrom(contextReader)
|
|
}
|
|
buffer = make([]byte, sshFileTransferCopyBufferSize)
|
|
return io.CopyBuffer(contextWriter, contextReader, buffer)
|
|
}
|
|
|
|
// readMultipartFieldValue reads the whole body of a multipart form field and
// returns it as a string. Fields longer than maxBytes are rejected with
// "multipart field is too large".
func readMultipartFieldValue(part *multipart.Part, maxBytes int64) (string, error) {
	// Allow one byte past the limit so an oversized field can be told apart
	// from one that is exactly maxBytes long.
	limited := &io.LimitedReader{R: part, N: maxBytes + 1}

	raw, readErr := io.ReadAll(limited)
	switch {
	case readErr != nil:
		return "", readErr
	case int64(len(raw)) > maxBytes:
		return "", errors.New("multipart field is too large")
	}
	return string(raw), nil
}
|
|
|
|
func sshDownloadEntryName(item sshSessionFileDownloadItem, index int) string {
|
|
var name string
|
|
|
|
name = strings.TrimSpace(item.Name)
|
|
if name == "" || name == "." || name == "/" {
|
|
name = fmt.Sprintf("file-%d", index + 1)
|
|
}
|
|
return name
|
|
}
|
|
|
|
// parseSSHFileTransferBool interprets a form value as a boolean. After
// trimming whitespace and lower-casing, "1", "true", "yes" and "on" are true;
// everything else, including the empty string, is false.
func parseSSHFileTransferBool(value string) bool {
	switch strings.ToLower(strings.TrimSpace(value)) {
	case "1", "true", "yes", "on":
		return true
	default:
		return false
	}
}
|