Files
codit/backend/internal/handlers/ssh-broker-files.go

850 lines
26 KiB
Go

package handlers
import codit_logger "codit/logger"
import "archive/zip"
import "context"
import "database/sql"
import "errors"
import "fmt"
import "io"
import "mime"
import "mime/multipart"
import "net"
import "net/http"
import "path"
import "strings"
import "os"
import "time"
import "github.com/pkg/sftp"
import "golang.org/x/crypto/ssh"
import "codit/internal/db"
import "codit/internal/middleware"
import "codit/internal/models"
// Limits applied to SSH session file transfers.
//
// TODO: these are compile-time constants; consider making them configurable.
const (
	// sshFileTransferMaxUploadBytes caps the whole multipart upload request
	// body (5 GiB).
	sshFileTransferMaxUploadBytes int64 = 5 << 30
	// sshFileTransferMaxDownloadBytes caps the combined stat size of the
	// files in a single download or copy request (10 GiB).
	sshFileTransferMaxDownloadBytes int64 = 10 << 30
	// sshFileTransferMaxFiles caps the number of distinct paths per request.
	sshFileTransferMaxFiles int = 256
	// sshFileTransferCopyBufferSize is the buffer size used by the generic
	// copy fallback in copyWithContext (1 MiB).
	sshFileTransferCopyBufferSize int = 1 << 20
)
type (
	// sshSessionFileDownloadRequest is the JSON body accepted by
	// DownloadSSHSessionFilesForSelf.
	sshSessionFileDownloadRequest struct {
		Paths    []string `json:"paths"`
		Password string   `json:"password"`
		OTPCode  string   `json:"otp_code"`
	}

	// sshSessionFileCopyRequest is the JSON body accepted by
	// CopySSHSessionFilesForSelf. Source credentials are used for the
	// session in the URL, and target credentials for TargetSessionID.
	sshSessionFileCopyRequest struct {
		TargetSessionID string   `json:"target_session_id"`
		Paths           []string `json:"paths"`
		TargetDir       string   `json:"target_dir"`
		Overwrite       bool     `json:"overwrite"`
		SourcePassword  string   `json:"source_password"`
		SourceOTPCode   string   `json:"source_otp_code"`
		TargetPassword  string   `json:"target_password"`
		TargetOTPCode   string   `json:"target_otp_code"`
	}

	// sshSessionFileUploadItem is the per-file result of an upload.
	sshSessionFileUploadItem struct {
		Name  string `json:"name"`
		Path  string `json:"path"`
		Size  int64  `json:"size"`
		Error string `json:"error,omitempty"`
	}

	// sshSessionFileUploadResponse is the body returned by
	// UploadSSHSessionFilesForSelf.
	sshSessionFileUploadResponse struct {
		Items    []sshSessionFileUploadItem `json:"items"`
		Uploaded int                        `json:"uploaded"`
		Failed   int                        `json:"failed"`
	}

	// sshSessionFileDownloadItem describes a remote file resolved by
	// sftpDownloadMetadata. It is internal only and has no JSON tags.
	sshSessionFileDownloadItem struct {
		RemotePath string
		Name       string
		Size       int64
	}

	// sshSessionFileCopyItem is the per-path result of a session-to-session copy.
	sshSessionFileCopyItem struct {
		SourcePath string `json:"source_path"`
		TargetPath string `json:"target_path"`
		Size       int64  `json:"size"`
		Error      string `json:"error,omitempty"`
	}

	// sshSessionFileCopyResponse is the body returned by
	// CopySSHSessionFilesForSelf.
	sshSessionFileCopyResponse struct {
		Items  []sshSessionFileCopyItem `json:"items"`
		Copied int                      `json:"copied"`
		Failed int                      `json:"failed"`
	}
)
// sshFileTransferContextReader wraps Reader so that every Read first checks
// Context and fails with the context's error once it is done.
type sshFileTransferContextReader struct {
	Context context.Context
	Reader  io.Reader
}

// Read returns (0, Context.Err()) without touching Reader when the context
// is done; otherwise it delegates to Reader.Read.
func (r *sshFileTransferContextReader) Read(p []byte) (int, error) {
	if ctxErr := r.Context.Err(); ctxErr != nil {
		return 0, ctxErr
	}
	return r.Reader.Read(p)
}

// sshFileTransferContextWriter wraps Writer so that every Write first checks
// Context and fails with the context's error once it is done.
type sshFileTransferContextWriter struct {
	Context context.Context
	Writer  io.Writer
}

// Write returns (0, Context.Err()) without touching Writer when the context
// is done; otherwise it delegates to Writer.Write.
func (w *sshFileTransferContextWriter) Write(p []byte) (int, error) {
	if ctxErr := w.Context.Err(); ctxErr != nil {
		return 0, ctxErr
	}
	return w.Writer.Write(p)
}
// UploadSSHSessionFilesForSelf streams the "files" parts of a multipart
// request to the remote host of the caller's SSH session over SFTP.
//
// Non-file multipart fields are read as small text values (at most 64 KiB):
//   - target_dir: remote directory to upload into (blank means ".")
//   - overwrite:  parsed with parseSSHFileTransferBool
//   - password, otp_code: passed to openSSHSessionFileTransferClient
//
// Parts are processed in request order, so a field only affects the file
// parts that follow it. The SSH/SFTP connection is opened lazily on the first
// accepted file, using whatever password/otp_code have been seen by then.
//
// Individual file failures are reported per item in a 200 response.
// Request-level problems (invalid multipart body, too many files, no files,
// session lookup or connection failures) produce a JSON error with a 4xx
// status. If the request context is cancelled, the handler returns without
// writing a body.
func (api *API) UploadSSHSessionFilesForSelf(w http.ResponseWriter, r *http.Request, params map[string]string) {
	var user models.User
	var ok bool
	var store *db.Store
	var sessionItem models.SSHSession
	var multipartReader *multipart.Reader
	var part *multipart.Part
	var sshClient *ssh.Client
	var sftpClient *sftp.Client
	var stopContextClose func()
	var targetDir string
	var overwrite bool
	var response sshSessionFileUploadResponse
	var item sshSessionFileUploadItem
	var remotePath string
	var fieldName string
	var fieldValue string
	var password string
	var otpCode string
	var seenUploadPaths map[string]bool
	var uploadCount int
	var opened bool
	var copied int64
	var err error
	user, ok = middleware.UserFromContext(r.Context())
	if !ok || user.Disabled {
		w.WriteHeader(http.StatusUnauthorized)
		return
	}
	// The limit covers the whole request body, i.e. all parts combined.
	r.Body = http.MaxBytesReader(w, r.Body, sshFileTransferMaxUploadBytes)
	multipartReader, err = r.MultipartReader()
	if err != nil {
		WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid multipart upload"})
		return
	}
	targetDir = "."
	seenUploadPaths = make(map[string]bool)
	for {
		if r.Context().Err() != nil {
			return
		}
		part, err = multipartReader.NextPart()
		if err == io.EOF {
			break
		}
		if err != nil {
			WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid multipart upload"})
			return
		}
		fieldName = part.FormName()
		if fieldName != "files" {
			fieldValue, err = readMultipartFieldValue(part, 64*1024)
			_ = part.Close()
			if err != nil {
				WriteJSON(w, http.StatusBadRequest, map[string]string{"error": err.Error()})
				return
			}
			switch fieldName {
			case "target_dir":
				targetDir = strings.TrimSpace(fieldValue)
				if targetDir == "" {
					targetDir = "."
				}
			case "overwrite":
				overwrite = parseSSHFileTransferBool(fieldValue)
			case "password":
				password = fieldValue
			case "otp_code":
				otpCode = fieldValue
			}
			continue
		}
		item = sshSessionFileUploadItem{Name: path.Base(part.FileName())}
		// Accept only plain entry names inside targetDir. "." and "/" come from
		// empty or root-only file names, ".." would resolve to targetDir's
		// parent, and NUL cannot appear in a POSIX path.
		if item.Name == "." || item.Name == ".." || item.Name == "/" || strings.Contains(item.Name, "\x00") {
			item.Error = "invalid file name"
			response.Failed += 1
			response.Items = append(response.Items, item)
			_ = part.Close()
			continue
		}
		remotePath = path.Join(targetDir, item.Name)
		if seenUploadPaths[remotePath] {
			// Duplicate destination within this request: skip it.
			// Close discards the remainder of the part.
			_ = part.Close()
			continue
		}
		seenUploadPaths[remotePath] = true
		uploadCount += 1
		if uploadCount > sshFileTransferMaxFiles {
			_ = part.Close()
			WriteJSON(w, http.StatusBadRequest, map[string]string{"error": fmt.Sprintf("too many files; max is %d", sshFileTransferMaxFiles)})
			return
		}
		if !opened {
			// Connect lazily, so requests with no acceptable file never dial the
			// remote host. This also means password/otp_code must precede the
			// first file part.
			store = api.store(r)
			sessionItem, err = api.getSSHSessionFileTransferItem(store, user, params["id"])
			if err != nil {
				_ = part.Close()
				api.writeSSHSessionFileTransferError(w, err)
				return
			}
			sshClient, err = api.openSSHSessionFileTransferClient(store, user, sessionItem, password, otpCode)
			if err != nil {
				_ = part.Close()
				WriteJSON(w, http.StatusBadRequest, map[string]string{"error": err.Error()})
				return
			}
			stopContextClose = closeSSHClientOnContext(r.Context(), sshClient)
			defer stopContextClose()
			defer sshClient.Close()
			sftpClient, err = sftp.NewClient(sshClient)
			if err != nil {
				_ = part.Close()
				WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "sftp subsystem unavailable: " + err.Error()})
				return
			}
			defer sftpClient.Close()
			opened = true
		}
		item.Path = remotePath
		copied, err = api.sftpUploadFile(r.Context(), sftpClient, targetDir, item.Name, part, overwrite)
		_ = part.Close()
		item.Size = copied
		if err != nil {
			if r.Context().Err() != nil {
				return
			}
			item.Error = err.Error()
			response.Failed += 1
		} else {
			response.Uploaded += 1
		}
		response.Items = append(response.Items, item)
	}
	if uploadCount == 0 {
		WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "at least one file is required"})
		return
	}
	WriteJSON(w, http.StatusOK, response)
}
// DownloadSSHSessionFilesForSelf streams remote files from the host of the
// caller's SSH session to the HTTP client over SFTP. It opens its own SSH
// connection via openSSHSessionFileTransferClient.
//
// The JSON body is sshSessionFileDownloadRequest. "paths" lists the remote
// files; "password" and "otp_code" are passed to
// openSSHSessionFileTransferClient.
//
// Every path is validated and stat'ed before anything is streamed. Any of the
// following fails the whole request with a 400 JSON error:
//   - a blank or NUL-containing path
//   - more than sshFileTransferMaxFiles distinct paths
//   - a stat/metadata failure, including directories
//   - a combined stat size above sshFileTransferMaxDownloadBytes
// Duplicate paths are ignored.
//
// A single file is sent as application/octet-stream, with Content-Length
// taken from its stat size. Several files are sent as a zip named by
// sshSessionFilesArchiveName. Once streaming has started the status can no
// longer change, so transfer errors are only logged.
func (api *API) DownloadSSHSessionFilesForSelf(w http.ResponseWriter, r *http.Request, params map[string]string) {
var user models.User
var ok bool
var store *db.Store
var req sshSessionFileDownloadRequest
var sessionItem models.SSHSession
var sshClient *ssh.Client
var sftpClient *sftp.Client
var stopContextClose func()
var items []sshSessionFileDownloadItem
var item sshSessionFileDownloadItem
var remotePath string
var seenPaths map[string]bool
var pathCount int
var totalSize int64
var contentDisposition string
var archiveName string
var err error
var i int
user, ok = middleware.UserFromContext(r.Context())
if !ok || user.Disabled {
w.WriteHeader(http.StatusUnauthorized)
return
}
err = DecodeJSON(r, &req)
if err != nil {
WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid request"})
return
}
if len(req.Paths) == 0 {
WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "at least one path is required"})
return
}
seenPaths = make(map[string]bool)
store = api.store(r)
sessionItem, err = api.getSSHSessionFileTransferItem(store, user, params["id"])
if err != nil {
api.writeSSHSessionFileTransferError(w, err)
return
}
sshClient, err = api.openSSHSessionFileTransferClient(store, user, sessionItem, req.Password, req.OTPCode)
if err != nil {
WriteJSON(w, http.StatusBadRequest, map[string]string{"error": err.Error()})
return
}
// Close the SSH client when the request context is done, so a client
// disconnect aborts in-flight SFTP operations.
stopContextClose = closeSSHClientOnContext(r.Context(), sshClient)
defer stopContextClose()
defer sshClient.Close()
sftpClient, err = sftp.NewClient(sshClient)
if err != nil {
WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "sftp subsystem unavailable: " + err.Error()})
return
}
defer sftpClient.Close()
// Validate and size every path up front, so limit violations are reported
// before any response body bytes are written.
for i = 0; i < len(req.Paths); i++ {
if r.Context().Err() != nil {
return
}
remotePath = strings.TrimSpace(req.Paths[i])
if remotePath == "" || strings.Contains(remotePath, "\x00") {
WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid download path"})
return
}
// Dedup is by trimmed string, so equivalent spellings (e.g. "a" and "./a")
// count as distinct paths.
if seenPaths[remotePath] {
continue
}
seenPaths[remotePath] = true
pathCount += 1
if pathCount > sshFileTransferMaxFiles {
WriteJSON(w, http.StatusBadRequest, map[string]string{"error": fmt.Sprintf("too many files; max is %d", sshFileTransferMaxFiles)})
return
}
item, err = api.sftpDownloadMetadata(sftpClient, remotePath)
if err != nil {
WriteJSON(w, http.StatusBadRequest, map[string]string{"error": fmt.Sprintf("%s: %s", remotePath, err.Error())})
return
}
totalSize += item.Size
if totalSize > sshFileTransferMaxDownloadBytes {
WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "download size exceeds limit"})
return
}
items = append(items, item)
}
// req.Paths is non-empty and every path either returned early or was
// appended (duplicates repeat an earlier appended path), so items is non-empty.
if len(items) == 1 {
// NOTE(review): these headers are set before the remote file is opened. If
// sftpDownloadFile fails before writing any bytes, the implicit 200 response
// falls short of the declared Content-Length. Consider opening the file first.
contentDisposition = mime.FormatMediaType("attachment", map[string]string{"filename": items[0].Name})
w.Header().Set("Content-Type", "application/octet-stream")
w.Header().Set("Content-Disposition", contentDisposition)
w.Header().Set("Content-Length", fmt.Sprintf("%d", items[0].Size))
err = api.sftpDownloadFile(r.Context(), sftpClient, items[0].RemotePath, w)
if err != nil {
api.Logger.Write(logIDSSHBroker, codit_logger.LOG_WARN, "ssh file download failed session=%s path=%q err=%s", sessionItem.ID, items[0].RemotePath, err.Error())
}
return
}
// The zip is streamed without a Content-Length, since its final size is not
// known in advance.
archiveName = sshSessionFilesArchiveName(store, sessionItem)
contentDisposition = mime.FormatMediaType("attachment", map[string]string{"filename": archiveName})
w.Header().Set("Content-Type", "application/zip")
w.Header().Set("Content-Disposition", contentDisposition)
err = api.writeSSHSessionFileZip(r.Context(), sftpClient, items, w)
if err != nil {
api.Logger.Write(logIDSSHBroker, codit_logger.LOG_WARN, "ssh file zip download failed session=%s err=%s", sessionItem.ID, err.Error())
}
}
// CopySSHSessionFilesForSelf copies files from one of the caller's SSH
// sessions (params["id"], the source) to another (target_session_id). Each
// file is streamed through the broker over two SFTP connections, one per
// session.
//
// The JSON body is sshSessionFileCopyRequest. Request-level problems fail the
// whole request with a 4xx JSON error:
//   - missing or identical target session, or no paths
//   - NUL in target_dir
//   - session lookup failures
//   - SSH or SFTP setup failures
//
// Per-path problems are reported in a 200 response instead. Each handled
// path gets an sshSessionFileCopyItem and increments Copied or Failed.
// Exceeding sshFileTransferMaxFiles or the cumulative
// sshFileTransferMaxDownloadBytes records one failed item and stops
// processing the remaining paths.
//
// Files are written to target_dir/<base name>, with overwrite handled by
// sftpUploadFile. Each copy is logged when it starts and when it completes or
// fails, along with its duration. If the request context is cancelled, the
// handler returns without writing a body.
func (api *API) CopySSHSessionFilesForSelf(w http.ResponseWriter, r *http.Request, params map[string]string) {
var user models.User
var ok bool
var store *db.Store
var req sshSessionFileCopyRequest
var sourceSession models.SSHSession
var targetSession models.SSHSession
var sourceSSHClient *ssh.Client
var targetSSHClient *ssh.Client
var sourceSFTPClient *sftp.Client
var targetSFTPClient *sftp.Client
var stopSourceClose func()
var stopTargetClose func()
var response sshSessionFileCopyResponse
var item sshSessionFileCopyItem
var metadata sshSessionFileDownloadItem
var remotePath string
var targetDir string
var seenPaths map[string]bool
var pathCount int
var totalSize int64
var startedAt time.Time
var duration time.Duration
var err error
var i int
user, ok = middleware.UserFromContext(r.Context())
if !ok || user.Disabled {
w.WriteHeader(http.StatusUnauthorized)
return
}
err = DecodeJSON(r, &req)
if err != nil {
WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid request"})
return
}
if strings.TrimSpace(req.TargetSessionID) == "" {
WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "target_session_id is required"})
return
}
// NOTE(review): this check compares trimmed IDs, but the lookup below uses
// req.TargetSessionID exactly as sent.
if strings.TrimSpace(req.TargetSessionID) == strings.TrimSpace(params["id"]) {
WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "target session must be different from source session"})
return
}
if len(req.Paths) == 0 {
WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "at least one path is required"})
return
}
targetDir = strings.TrimSpace(req.TargetDir)
if targetDir == "" {
targetDir = "."
}
if strings.Contains(targetDir, "\x00") {
WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid target directory"})
return
}
seenPaths = make(map[string]bool)
store = api.store(r)
sourceSession, err = api.getSSHSessionFileTransferItem(store, user, params["id"])
if err != nil {
api.writeSSHSessionFileTransferError(w, err)
return
}
targetSession, err = api.getSSHSessionFileTransferItem(store, user, req.TargetSessionID)
if err != nil {
api.writeSSHSessionFileTransferError(w, err)
return
}
sourceSSHClient, err = api.openSSHSessionFileTransferClient(store, user, sourceSession, req.SourcePassword, req.SourceOTPCode)
if err != nil {
WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "source session: " + err.Error()})
return
}
// Both SSH clients are closed if the request context is done, aborting any
// in-flight transfer.
stopSourceClose = closeSSHClientOnContext(r.Context(), sourceSSHClient)
defer stopSourceClose()
defer sourceSSHClient.Close()
targetSSHClient, err = api.openSSHSessionFileTransferClient(store, user, targetSession, req.TargetPassword, req.TargetOTPCode)
if err != nil {
WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "target session: " + err.Error()})
return
}
stopTargetClose = closeSSHClientOnContext(r.Context(), targetSSHClient)
defer stopTargetClose()
defer targetSSHClient.Close()
sourceSFTPClient, err = sftp.NewClient(sourceSSHClient)
if err != nil {
WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "source sftp subsystem unavailable: " + err.Error()})
return
}
defer sourceSFTPClient.Close()
targetSFTPClient, err = sftp.NewClient(targetSSHClient)
if err != nil {
WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "target sftp subsystem unavailable: " + err.Error()})
return
}
defer targetSFTPClient.Close()
for i = 0; i < len(req.Paths); i++ {
if r.Context().Err() != nil {
return
}
remotePath = strings.TrimSpace(req.Paths[i])
item = sshSessionFileCopyItem{SourcePath: remotePath}
// Invalid paths are reported before dedup, so each occurrence gets its own
// failed item.
if remotePath == "" || strings.Contains(remotePath, "\x00") {
item.Error = "invalid source path"
response.Failed += 1
response.Items = append(response.Items, item)
continue
}
if seenPaths[remotePath] {
continue
}
seenPaths[remotePath] = true
pathCount += 1
if pathCount > sshFileTransferMaxFiles {
item.Error = fmt.Sprintf("too many files; max is %d", sshFileTransferMaxFiles)
response.Failed += 1
response.Items = append(response.Items, item)
break
}
metadata, err = api.sftpDownloadMetadata(sourceSFTPClient, remotePath)
if err != nil {
item.Error = err.Error()
response.Failed += 1
response.Items = append(response.Items, item)
continue
}
// The budget is charged with the stat size before copying, so files whose
// copy later fails still count against it.
totalSize += metadata.Size
if totalSize > sshFileTransferMaxDownloadBytes {
item.Error = "copy size exceeds limit"
response.Failed += 1
response.Items = append(response.Items, item)
break
}
item.TargetPath = path.Join(targetDir, metadata.Name)
startedAt = time.Now()
api.Logger.Write(logIDSSHBroker, codit_logger.LOG_INFO,
"ssh file copy start source_session=%s source_server=%s target_session=%s target_server=%s source_path=%q target_path=%q size=%d",
sourceSession.ID,
sourceSession.ServerName,
targetSession.ID,
targetSession.ServerName,
metadata.RemotePath,
item.TargetPath,
metadata.Size)
item.Size, err = api.sftpCopyFile(r.Context(), sourceSFTPClient, targetSFTPClient, metadata.RemotePath, targetDir, metadata.Name, req.Overwrite)
duration = time.Since(startedAt)
if err != nil {
// On cancellation the clients have been closed; stop without a response.
if r.Context().Err() != nil {
return
}
api.Logger.Write(logIDSSHBroker, codit_logger.LOG_WARN,
"ssh file copy failed source_session=%s source_server=%s target_session=%s target_server=%s source_path=%q target_path=%q size=%d dur_ms=%d err=%s",
sourceSession.ID,
sourceSession.ServerName,
targetSession.ID,
targetSession.ServerName,
metadata.RemotePath,
item.TargetPath,
item.Size,
duration.Milliseconds(),
err.Error())
item.Error = err.Error()
response.Failed += 1
} else {
api.Logger.Write(logIDSSHBroker, codit_logger.LOG_INFO,
"ssh file copy done source_session=%s source_server=%s target_session=%s target_server=%s source_path=%q target_path=%q size=%d dur_ms=%d",
sourceSession.ID,
sourceSession.ServerName,
targetSession.ID,
targetSession.ServerName,
metadata.RemotePath,
item.TargetPath,
item.Size,
duration.Milliseconds())
response.Copied += 1
}
response.Items = append(response.Items, item)
}
WriteJSON(w, http.StatusOK, response)
}
// getSSHSessionFileTransferItem loads the SSH session sessionID and checks
// that it belongs to user and is currently connected.
//
// A missing session and a session owned by another user both yield
// "ssh session not found", so the response does not reveal whether another
// user's session exists. A session that exists but is not connected yields
// "ssh session is not connected". Other store errors are returned unchanged.
// On error the returned session is the zero value.
func (api *API) getSSHSessionFileTransferItem(store *db.Store, user models.User, sessionID string) (models.SSHSession, error) {
	var item models.SSHSession
	var err error
	item, err = store.GetSSHSession(sessionID)
	if err != nil {
		// errors.Is also matches an sql.ErrNoRows wrapped by the store layer,
		// which a plain == comparison would miss.
		if errors.Is(err, sql.ErrNoRows) {
			return models.SSHSession{}, errors.New("ssh session not found")
		}
		return models.SSHSession{}, err
	}
	if item.UserID != user.ID {
		return models.SSHSession{}, errors.New("ssh session not found")
	}
	if item.Status != "connected" {
		return models.SSHSession{}, errors.New("ssh session is not connected")
	}
	return item, nil
}
// writeSSHSessionFileTransferError writes err from a session lookup as a JSON
// error body. It uses 404 when the message is exactly
// "ssh session not found" and 400 for everything else. The error text is sent
// to the client verbatim.
func (api *API) writeSSHSessionFileTransferError(w http.ResponseWriter, err error) {
	message := err.Error()
	status := http.StatusBadRequest
	if message == "ssh session not found" {
		status = http.StatusNotFound
	}
	WriteJSON(w, status, map[string]string{"error": message})
}
// openSSHSessionFileTransferClient dials a fresh SSH connection for file
// transfers on sessionItem.
//
// prepareSSHSession supplies the access profile, auth methods and known host
// keys; promptedPassword and otpCode are passed to it. A failure at its
// "load_profile" stage is reported as "ssh access profile not found". The
// connection uses a 15-second dial timeout. When prepareSSHSession asks for
// trust-on-first-use pinning, the host key captured by sshHostKeyCallback is
// stored via pinSSHSessionFileTransferHostKey; if that fails, the new
// connection is closed and the error returned.
//
// The caller owns the returned client and must Close it.
func (api *API) openSSHSessionFileTransferClient(store *db.Store, user models.User, sessionItem models.SSHSession, promptedPassword string, otpCode string) (*ssh.Client, error) {
	prepared, stage, err := api.prepareSSHSession(store, &user, &sessionItem, promptedPassword, otpCode)
	if err != nil {
		if stage == "load_profile" {
			return nil, errors.New("ssh access profile not found")
		}
		return nil, err
	}

	// Filled in by the host key callback during the handshake.
	var (
		fingerprint string
		hostKey     ssh.PublicKey
	)
	config := &ssh.ClientConfig{
		User:            prepared.Profile.RemoteUsername,
		Auth:            prepared.AuthMethods,
		HostKeyCallback: sshHostKeyCallback(prepared.HostKeys, &fingerprint, &hostKey, prepared.PinHostKeyOnFirstUse),
		Timeout:         15 * time.Second,
	}
	address := net.JoinHostPort(prepared.Profile.Server.Host, fmt.Sprintf("%d", prepared.Profile.Server.Port))
	client, err := ssh.Dial("tcp", address, config)
	if err != nil {
		return nil, err
	}
	if !prepared.PinHostKeyOnFirstUse {
		return client, nil
	}
	if err = api.pinSSHSessionFileTransferHostKey(store, sessionItem, prepared.Profile, fingerprint, hostKey); err != nil {
		_ = client.Close()
		return nil, err
	}
	return client, nil
}
// pinSSHSessionFileTransferHostKey stores hostKey, with its fingerprint, as a
// known key for sessionItem's server. openSSHSessionFileTransferClient uses it
// for trust-on-first-use pinning.
//
// If creating the record fails, the server's stored keys are listed again. A
// stored key with the same fingerprint counts as success — presumably
// because another request pinned it first. Otherwise the creation error is
// returned, wrapped with the server name. An error listing the keys is
// returned as-is.
func (api *API) pinSSHSessionFileTransferHostKey(store *db.Store, sessionItem models.SSHSession, profile models.SSHAccessProfile, fingerprint string, hostKey ssh.PublicKey) error {
	var item models.SSHServerHostKey
	var reloaded []models.SSHServerHostKey
	var createErr error
	var err error
	if hostKey == nil {
		return errors.New("failed to capture ssh host key")
	}
	item = models.SSHServerHostKey{
		ServerID:    sessionItem.ServerID,
		Algorithm:   hostKey.Type(),
		PublicKey:   strings.TrimSpace(string(ssh.MarshalAuthorizedKey(hostKey))),
		Fingerprint: fingerprint,
	}
	// Keep the creation error in its own variable: err is reused for the
	// listing below. Reporting err there (nil on a successful listing) used to
	// panic with a nil dereference.
	_, createErr = store.CreateSSHServerHostKey(item)
	if createErr == nil {
		return nil
	}
	reloaded, err = store.ListSSHServerHostKeys(sessionItem.ServerID)
	if err != nil {
		return err
	}
	for _, existing := range reloaded {
		if strings.TrimSpace(existing.Fingerprint) == fingerprint {
			return nil
		}
	}
	return fmt.Errorf("failed to pin host key for %s: %w", profile.Server.Name, createErr)
}
// closeSSHClientOnContext starts a goroutine that closes client as soon as
// ctx is done, so work on the connection fails once the request goes away.
//
// The returned function stops that goroutine without closing client. It must
// be called exactly once: a second call panics because it closes an already
// closed channel.
func closeSSHClientOnContext(ctx context.Context, client *ssh.Client) func() {
	stopped := make(chan struct{})
	go func() {
		select {
		case <-stopped:
		case <-ctx.Done():
			_ = client.Close()
		}
	}()
	return func() { close(stopped) }
}
// sftpUploadFile writes the contents of reader to targetDir/name on the
// remote host, first creating targetDir (and any parents) with MkdirAll.
//
// With overwrite=false an existing file is an error
// ("remote file already exists: <path>"). This is checked with Stat for a
// clear message and enforced again by O_EXCL when opening. With
// overwrite=true an existing file is truncated and replaced.
//
// It returns the number of bytes copied. The copy stops with ctx's error once
// ctx is done. The remote file's Close error is returned when the copy itself
// succeeded: a failed close on a write handle may be the only sign that the
// data was not stored intact.
func (api *API) sftpUploadFile(ctx context.Context, client *sftp.Client, targetDir string, name string, reader io.Reader, overwrite bool) (int64, error) {
	var remotePath string
	var flags int
	var remoteFile *sftp.File
	var copied int64
	var statErr error
	var closeErr error
	var err error
	if ctx.Err() != nil {
		return 0, ctx.Err()
	}
	err = client.MkdirAll(targetDir)
	if err != nil {
		return 0, err
	}
	if ctx.Err() != nil {
		return 0, ctx.Err()
	}
	remotePath = path.Join(targetDir, name)
	if !overwrite {
		_, statErr = client.Stat(remotePath)
		if statErr == nil {
			return 0, fmt.Errorf("remote file already exists: %s", remotePath)
		}
		if !isSFTPNoSuchFile(statErr) {
			return 0, statErr
		}
	}
	flags = os.O_WRONLY | os.O_CREATE
	if overwrite {
		flags |= os.O_TRUNC
	} else {
		// O_EXCL closes the race between the Stat above and this open.
		flags |= os.O_EXCL
	}
	remoteFile, err = client.OpenFile(remotePath, flags)
	if err != nil {
		return 0, err
	}
	copied, err = copyWithContext(ctx, remoteFile, reader)
	closeErr = remoteFile.Close()
	if err == nil && closeErr != nil {
		err = closeErr
	}
	return copied, err
}
// isSFTPNoSuchFile reports whether err means "the remote path does not exist".
// That is the case when os.ErrNotExist appears anywhere in err's chain, or
// when the chain contains an *sftp.StatusError with code
// sftp.ErrSSHFxNoSuchFile. A nil err is never a match.
func isSFTPNoSuchFile(err error) bool {
	if err == nil {
		return false
	}
	if errors.Is(err, os.ErrNotExist) {
		return true
	}
	var status *sftp.StatusError
	return errors.As(err, &status) && status.FxCode() == sftp.ErrSSHFxNoSuchFile
}
// sftpDownloadMetadata stats remotePath and describes it for a download or
// copy. client.Stat follows symlinks, so a link resolves to its target.
//
// Only regular files are accepted. Directories are rejected with
// "directory downloads are not supported". Other non-regular files (devices,
// FIFOs, sockets, ...) are rejected too: their stat size does not bound how
// much can be read from them (e.g. /dev/zero), and callers enforce
// sshFileTransferMaxDownloadBytes using the Size returned here.
//
// Name is the base name of remotePath, falling back to the server-reported
// name when the base name is unusable.
func (api *API) sftpDownloadMetadata(client *sftp.Client, remotePath string) (sshSessionFileDownloadItem, error) {
	var info os.FileInfo
	var item sshSessionFileDownloadItem
	var err error
	info, err = client.Stat(remotePath)
	if err != nil {
		return item, err
	}
	if info.IsDir() {
		return item, errors.New("directory downloads are not supported")
	}
	if !info.Mode().IsRegular() {
		return item, errors.New("only regular files can be transferred")
	}
	item = sshSessionFileDownloadItem{
		RemotePath: remotePath,
		Name:       path.Base(remotePath),
		Size:       info.Size(),
	}
	if item.Name == "." || item.Name == "/" || strings.TrimSpace(item.Name) == "" {
		item.Name = info.Name()
	}
	return item, nil
}
// sshSessionFilesArchiveName builds the file name of a multi-file zip
// download: "<profile>-<server>-<YYYYMMDD-HHMMSS>-files.zip".
//
// The profile part is the stored profile's Name if it can be loaded and is
// non-blank, otherwise the session's ProfileID, otherwise "ssh-profile".
//
// The server part is the stored server's Name if it can be loaded and is
// non-blank, otherwise the session's "host-port" (or just host), otherwise
// "ssh-server".
//
// Both parts go through sanitizeBundleName. The timestamp uses local time.
func sshSessionFilesArchiveName(store *db.Store, sessionItem models.SSHSession) string {
	profileName := strings.TrimSpace(sessionItem.ProfileID)
	serverName := strings.TrimSpace(sessionItem.Host)
	if serverName != "" && sessionItem.Port > 0 {
		serverName = fmt.Sprintf("%s-%d", serverName, sessionItem.Port)
	}

	if profile, err := store.GetSSHAccessProfile(sessionItem.ProfileID); err == nil {
		if name := strings.TrimSpace(profile.Name); name != "" {
			profileName = name
		}
	}
	if server, err := store.GetSSHServer(sessionItem.ServerID); err == nil {
		if name := strings.TrimSpace(server.Name); name != "" {
			serverName = name
		}
	}

	if profileName == "" {
		profileName = "ssh-profile"
	}
	if serverName == "" {
		serverName = "ssh-server"
	}

	stamp := time.Now().Format("20060102-150405")
	return fmt.Sprintf("%s-%s-%s-files.zip", sanitizeBundleName(profileName), sanitizeBundleName(serverName), stamp)
}
// sftpDownloadFile streams the remote file at remotePath into writer. It
// stops with ctx's error once ctx is done. The remote handle is always closed
// before returning.
func (api *API) sftpDownloadFile(ctx context.Context, client *sftp.Client, remotePath string, writer io.Writer) error {
	if ctxErr := ctx.Err(); ctxErr != nil {
		return ctxErr
	}
	source, err := client.Open(remotePath)
	if err != nil {
		return err
	}
	defer source.Close()
	if _, err = copyWithContext(ctx, writer, source); err != nil {
		return err
	}
	return nil
}
// sftpCopyFile opens sourcePath on sourceClient and streams it to
// targetDir/name on targetClient through sftpUploadFile, so the target
// directory creation and overwrite rules are the same as for uploads. It
// returns the number of bytes written to the target.
func (api *API) sftpCopyFile(ctx context.Context, sourceClient *sftp.Client, targetClient *sftp.Client, sourcePath string, targetDir string, name string, overwrite bool) (int64, error) {
	if ctxErr := ctx.Err(); ctxErr != nil {
		return 0, ctxErr
	}
	source, err := sourceClient.Open(sourcePath)
	if err != nil {
		return 0, err
	}
	defer source.Close()
	return api.sftpUploadFile(ctx, targetClient, targetDir, name, source, overwrite)
}
// writeSSHSessionFileZip streams items into writer as a zip archive, one
// entry per item, reading each file from client in order.
//
// Entry names come from sshDownloadEntryName, which is only a base name. Two
// remote files with the same base name (e.g. /a/log.txt and /b/log.txt)
// would otherwise produce duplicate entries that extractors overwrite or
// reject. Repeats therefore get " (2)", " (3)", ... inserted before the
// extension.
//
// Processing stops at the first error, including ctx cancellation. The zip
// writer is still closed, which writes the central directory. The first error
// takes precedence over any Close error.
func (api *API) writeSSHSessionFileZip(ctx context.Context, client *sftp.Client, items []sshSessionFileDownloadItem, writer io.Writer) error {
	var zipWriter *zip.Writer
	var zipFile io.Writer
	var usedNames map[string]bool
	var entryName string
	var closeErr error
	var err error
	var i int
	zipWriter = zip.NewWriter(writer)
	usedNames = make(map[string]bool, len(items))
	for i = 0; i < len(items); i++ {
		if ctx.Err() != nil {
			err = ctx.Err()
			break
		}
		entryName = sshUniqueZipEntryName(sshDownloadEntryName(items[i], i), usedNames)
		zipFile, err = zipWriter.Create(entryName)
		if err != nil {
			break
		}
		err = api.sftpDownloadFile(ctx, client, items[i].RemotePath, zipFile)
		if err != nil {
			break
		}
	}
	closeErr = zipWriter.Close()
	if err == nil && closeErr != nil {
		err = closeErr
	}
	return err
}

// sshUniqueZipEntryName returns a name not already in used and records it
// there. That is name itself when unused; otherwise name with " (n)" inserted
// before its extension, for the smallest free n >= 2.
func sshUniqueZipEntryName(name string, used map[string]bool) string {
	var stem string
	var ext string
	var candidate string
	var n int
	candidate = name
	if used[candidate] {
		ext = path.Ext(name)
		stem = strings.TrimSuffix(name, ext)
		if stem == "" {
			// Dotfiles such as ".bashrc": treat the whole name as the stem.
			stem = name
			ext = ""
		}
		for n = 2; used[candidate]; n++ {
			candidate = fmt.Sprintf("%s (%d)%s", stem, n, ext)
		}
	}
	used[candidate] = true
	return candidate
}
// copyWithContext copies reader into writer like io.Copy, but gives up with
// ctx's error once ctx is done. It returns the number of bytes copied.
//
// ctx is checked once before anything starts. After that it is checked on
// every Write when reader implements io.WriterTo, on every Read when writer
// implements io.ReaderFrom, and on both otherwise. Preferring the source's
// WriteTo, then the sink's ReadFrom, keeps optimized transfer paths (such as
// those of *sftp.File) in use. The generic fallback uses a buffer of
// sshFileTransferCopyBufferSize bytes.
func copyWithContext(ctx context.Context, writer io.Writer, reader io.Reader) (int64, error) {
	if ctxErr := ctx.Err(); ctxErr != nil {
		return 0, ctxErr
	}
	guardedWriter := &sshFileTransferContextWriter{Context: ctx, Writer: writer}
	guardedReader := &sshFileTransferContextReader{Context: ctx, Reader: reader}
	if source, ok := reader.(io.WriterTo); ok {
		return source.WriteTo(guardedWriter)
	}
	if sink, ok := writer.(io.ReaderFrom); ok {
		return sink.ReadFrom(guardedReader)
	}
	return io.CopyBuffer(guardedWriter, guardedReader, make([]byte, sshFileTransferCopyBufferSize))
}
// readMultipartFieldValue reads all of part and returns it as a string. It
// fails with "multipart field is too large" when part holds more than
// maxBytes bytes. Read errors are returned unchanged.
func readMultipartFieldValue(part *multipart.Part, maxBytes int64) (string, error) {
	// Allow one byte past the limit so an oversize value is detectable.
	raw, readErr := io.ReadAll(io.LimitReader(part, maxBytes+1))
	switch {
	case readErr != nil:
		return "", readErr
	case int64(len(raw)) > maxBytes:
		return "", errors.New("multipart field is too large")
	default:
		return string(raw), nil
	}
}
// sshDownloadEntryName returns the zip entry name for item: its
// whitespace-trimmed Name, or "file-<index+1>" when that is empty, "." or "/".
func sshDownloadEntryName(item sshSessionFileDownloadItem, index int) string {
	switch trimmed := strings.TrimSpace(item.Name); trimmed {
	case "", ".", "/":
		return fmt.Sprintf("file-%d", index+1)
	default:
		return trimmed
	}
}
// parseSSHFileTransferBool interprets a form value as a flag. After trimming
// surrounding whitespace and lower-casing, "1", "true", "yes" and "on" are
// true; everything else, including the empty string, is false.
func parseSSHFileTransferBool(value string) bool {
	switch strings.ToLower(strings.TrimSpace(value)) {
	case "1", "true", "yes", "on":
		return true
	}
	return false
}