package handlers

import (
	"database/sql"
	"errors"
	"fmt"
	"io"
	"net/http"
	"os"
	"path/filepath"
	"strconv"
	"strings"
	"sync"

	"codit/internal/middleware"
	"codit/internal/models"
	"codit/internal/util"
)

// sshSessionTranscriptRecorder appends SSH session data to a transcript file.
// mu guards file and failed so Write and Close may be called concurrently.
type sshSessionTranscriptRecorder struct {
	mu     sync.Mutex
	file   *os.File
	failed bool
}

// sshSessionTranscriptStreamError wraps a failure that occurred while copying
// the transcript body to the client. By then the transcript response headers
// have been set (and possibly sent), so no error body may be written anymore.
type sshSessionTranscriptStreamError struct {
	err error
}

// Error implements the error interface.
func (e *sshSessionTranscriptStreamError) Error() string {
	return "streaming ssh session transcript: " + e.err.Error()
}

// Unwrap exposes the underlying copy error to errors.Is and errors.As.
func (e *sshSessionTranscriptStreamError) Unwrap() error {
	return e.err
}

// sshSessionTranscriptDir returns the directory under the configured data
// directory in which transcript files are stored.
func (api *API) sshSessionTranscriptDir() string {
	return filepath.Join(api.Cfg.DataDir, "ssh-transcripts")
}

// sshSessionTranscriptPath returns the transcript file path for sessionID.
// Surrounding whitespace in sessionID is ignored.
func (api *API) sshSessionTranscriptPath(sessionID string) string {
	return filepath.Join(api.sshSessionTranscriptDir(), strings.TrimSpace(sessionID)+".log")
}

// sshSessionHasTranscript reports whether a non-empty regular transcript entry
// exists for sessionID. Any stat error is treated as "no transcript".
func (api *API) sshSessionHasTranscript(sessionID string) bool {
	var info os.FileInfo
	var err error
	info, err = os.Stat(api.sshSessionTranscriptPath(sessionID))
	if err != nil {
		return false
	}
	return !info.IsDir() && info.Size() > 0
}

// setSSHSessionTranscriptAvailability fills item.TranscriptAvailable from the
// on-disk state. A nil item is ignored.
func (api *API) setSSHSessionTranscriptAvailability(item *models.SSHSession) {
	if item == nil {
		return
	}
	item.TranscriptAvailable = api.sshSessionHasTranscript(item.ID)
}

// setSSHSessionsTranscriptAvailability fills TranscriptAvailable for every
// element of items in place.
func (api *API) setSSHSessionsTranscriptAvailability(items []models.SSHSession) {
	var i int
	// Index into the slice so the elements themselves are updated, not copies.
	for i = range items {
		api.setSSHSessionTranscriptAvailability(&items[i])
	}
}

// openSSHSessionTranscriptRecorder creates the transcript directory if needed
// and opens the transcript file for sessionID in append mode, creating it with
// owner-only permissions when it does not exist yet.
func (api *API) openSSHSessionTranscriptRecorder(sessionID string) (*sshSessionTranscriptRecorder, error) {
	var dir string
	var path string
	var file *os.File
	var err error
	dir = api.sshSessionTranscriptDir()
	err = os.MkdirAll(dir, 0o755)
	if err != nil {
		return nil, err
	}
	path = api.sshSessionTranscriptPath(sessionID)
	file, err = os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o600)
	if err != nil {
		return nil, err
	}
	return &sshSessionTranscriptRecorder{file: file}, nil
}

// Write appends data to the transcript file. It is a no-op for a nil
// recorder, empty data, a closed recorder, or a recorder whose earlier write
// failed. The first write failure is returned and disables further writes.
func (r *sshSessionTranscriptRecorder) Write(data []byte) error {
	var err error
	if r == nil || len(data) == 0 {
		return nil
	}
	r.mu.Lock()
	defer r.mu.Unlock()
	if r.file == nil || r.failed {
		return nil
	}
	_, err = r.file.Write(data)
	if err != nil {
		// Remember the failure so later chunks are dropped instead of
		// retrying (and reporting) on every write.
		r.failed = true
		return err
	}
	return nil
}

// Close closes the transcript file. It is safe to call on a nil recorder and
// more than once; only the first call closes the file.
func (r *sshSessionTranscriptRecorder) Close() error {
	var file *os.File
	if r == nil {
		return nil
	}
	// Detach the file under the lock, then close it outside the lock.
	r.mu.Lock()
	file = r.file
	r.file = nil
	r.mu.Unlock()
	if file == nil {
		return nil
	}
	return file.Close()
}

// writeSSHSessionTranscript streams the transcript of sessionID to w as plain
// text.
//
// Errors from opening or stat-ing the file are returned unchanged and occur
// before any header is set, so the caller may still send an error response.
// Errors from copying the body are returned as *sshSessionTranscriptStreamError;
// at that point the transcript headers are already in place and the caller
// must not write another response.
func (api *API) writeSSHSessionTranscript(w http.ResponseWriter, sessionID string) error {
	var path string
	var file *os.File
	var info os.FileInfo
	var size int64
	var err error
	path = api.sshSessionTranscriptPath(sessionID)
	file, err = os.Open(path)
	if err != nil {
		return err
	}
	defer file.Close()
	info, err = file.Stat()
	if err != nil {
		return err
	}
	size = info.Size()
	w.Header().Set("Content-Type", "text/plain; charset=utf-8")
	w.Header().Set("Content-Length", strconv.FormatInt(size, 10))
	w.Header().Set("Content-Disposition", fmt.Sprintf("inline; filename=%q", "ssh-session-"+strings.TrimSpace(sessionID)+".txt"))
	// The transcript is opened for append elsewhere and may grow while it is
	// being served. Never send more than the declared Content-Length.
	_, err = io.Copy(w, io.LimitReader(file, size))
	if err != nil {
		return &sshSessionTranscriptStreamError{err: err}
	}
	return nil
}

// respondSSHSessionTranscript streams the transcript of sessionID to w and
// maps failures to responses: 404 when the transcript file does not exist,
// 500 for other failures that occur before streaming starts. A failure during
// streaming is only logged, because the response is already committed.
func (api *API) respondSSHSessionTranscript(w http.ResponseWriter, sessionID string) {
	var streamErr *sshSessionTranscriptStreamError
	var err error
	err = api.writeSSHSessionTranscript(w, sessionID)
	if err == nil {
		return
	}
	if errors.As(err, &streamErr) {
		api.Logger.Write(logIDSSHBroker, util.LOG_WARN, "transcript stream failed session=%s err=%v", sessionID, err)
		return
	}
	if errors.Is(err, os.ErrNotExist) {
		WriteJSON(w, http.StatusNotFound, map[string]string{"error": "ssh session transcript not found"})
		return
	}
	WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
}

// GetSSHSessionTranscriptForSelf serves the transcript of the SSH session
// named by params["id"] to the user it belongs to.
//
// It responds 401 when the request context carries no user or the user is
// disabled, 404 when the session or its transcript does not exist, 403 when
// the session belongs to a different user, and 500 on other errors.
func (api *API) GetSSHSessionTranscriptForSelf(w http.ResponseWriter, r *http.Request, params map[string]string) {
	var user models.User
	var ok bool
	var item models.SSHSession
	var err error
	user, ok = middleware.UserFromContext(r.Context())
	if !ok || user.Disabled {
		w.WriteHeader(http.StatusUnauthorized)
		return
	}
	item, err = api.store(r).GetSSHSession(params["id"])
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			WriteJSON(w, http.StatusNotFound, map[string]string{"error": "ssh session not found"})
			return
		}
		WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
		return
	}
	if item.UserID != user.ID {
		w.WriteHeader(http.StatusForbidden)
		return
	}
	api.respondSSHSessionTranscript(w, item.ID)
}

// GetSSHSessionTranscriptAdmin serves the transcript of the SSH session named
// by params["id"] to an administrator.
//
// Access is gated by requireAdmin, which handles the response itself when it
// returns false. Otherwise the handler responds 404 when the session or its
// transcript does not exist and 500 on other errors.
func (api *API) GetSSHSessionTranscriptAdmin(w http.ResponseWriter, r *http.Request, params map[string]string) {
	var item models.SSHSession
	var err error
	if !api.requireAdmin(w, r) {
		return
	}
	item, err = api.store(r).GetSSHSession(params["id"])
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			WriteJSON(w, http.StatusNotFound, map[string]string{"error": "ssh session not found"})
			return
		}
		WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
		return
	}
	api.respondSSHSessionTranscript(w, item.ID)
}

// writeSSHSessionTranscriptChunk appends data to recorder and logs a warning
// if the write fails. A nil recorder is ignored. The error is not returned,
// so a transcript failure never interrupts the caller.
func (api *API) writeSSHSessionTranscriptChunk(sessionID string, data []byte, recorder *sshSessionTranscriptRecorder) {
	var err error
	if recorder == nil {
		return
	}
	err = recorder.Write(data)
	if err != nil {
		api.Logger.Write(logIDSSHBroker, util.LOG_WARN, "transcript write failed session=%s err=%v", sessionID, err)
	}
}