Files
codit/backend/internal/handlers/ssh_broker_transcript.go

210 lines
5.0 KiB
Go

package handlers
import "database/sql"
import "fmt"
import "io"
import "net/http"
import "os"
import "path/filepath"
import "strings"
import "sync"
import "codit/internal/middleware"
import "codit/internal/models"
import "codit/internal/util"
// sshSessionTranscriptRecorder appends raw SSH session bytes to a single
// transcript file. All access to file and failed is serialized by mu.
type sshSessionTranscriptRecorder struct {
	// mu guards file and failed; Write holds it for the whole write and
	// Close holds it while detaching the file handle.
	mu sync.Mutex
	// file is the open transcript handle; Close sets it to nil.
	file *os.File
	// failed latches after the first write error so later writes are skipped.
	failed bool
}
// sshSessionTranscriptDir returns the directory that holds per-session
// transcript files: the "ssh-transcripts" subdirectory of api.Cfg.DataDir.
func (api *API) sshSessionTranscriptDir() string {
	base := api.Cfg.DataDir
	return filepath.Join(base, "ssh-transcripts")
}
// sshSessionTranscriptPath returns the transcript file path for sessionID:
// "<transcript dir>/<trimmed sessionID>.log". Surrounding whitespace in the
// ID is stripped before it is used as a file name.
func (api *API) sshSessionTranscriptPath(sessionID string) string {
	fileName := strings.TrimSpace(sessionID) + ".log"
	return filepath.Join(api.sshSessionTranscriptDir(), fileName)
}
// sshSessionHasTranscript reports whether a non-empty regular transcript file
// exists for sessionID. Any Stat error (including not-exist), a directory at
// the path, or a zero-length file all count as "no transcript".
func (api *API) sshSessionHasTranscript(sessionID string) bool {
	fi, statErr := os.Stat(api.sshSessionTranscriptPath(sessionID))
	switch {
	case statErr != nil:
		return false
	case fi.IsDir():
		return false
	default:
		return fi.Size() > 0
	}
}
// setSSHSessionTranscriptAvailability sets item.TranscriptAvailable from the
// on-disk transcript state for item.ID. A nil item is ignored.
func (api *API) setSSHSessionTranscriptAvailability(item *models.SSHSession) {
	if item != nil {
		item.TranscriptAvailable = api.sshSessionHasTranscript(item.ID)
	}
}
// setSSHSessionsTranscriptAvailability updates TranscriptAvailable in place
// for every element of items. Elements are addressed by index so the slice's
// own entries are modified, not loop-variable copies.
func (api *API) setSSHSessionsTranscriptAvailability(items []models.SSHSession) {
	for idx := range items {
		api.setSSHSessionTranscriptAvailability(&items[idx])
	}
}
// openSSHSessionTranscriptRecorder ensures the transcript directory exists
// (mode 0755) and opens the session's transcript file for appending. The file
// is created with mode 0600 if it does not exist yet.
//
// The returned recorder owns the file handle; callers release it with Close.
// Errors from MkdirAll or OpenFile are returned unchanged.
func (api *API) openSSHSessionTranscriptRecorder(sessionID string) (*sshSessionTranscriptRecorder, error) {
	if mkErr := os.MkdirAll(api.sshSessionTranscriptDir(), 0o755); mkErr != nil {
		return nil, mkErr
	}
	const openFlags = os.O_CREATE | os.O_WRONLY | os.O_APPEND
	handle, openErr := os.OpenFile(api.sshSessionTranscriptPath(sessionID), openFlags, 0o600)
	if openErr != nil {
		return nil, openErr
	}
	return &sshSessionTranscriptRecorder{file: handle}, nil
}
// Write appends data to the transcript file.
//
// It is a no-op returning nil for a nil recorder, empty data, a recorder that
// has been closed, or one whose earlier write failed. The first write error
// is returned and latches the recorder into the failed state, so each
// recorder reports at most one error.
func (r *sshSessionTranscriptRecorder) Write(data []byte) error {
	if r == nil || len(data) == 0 {
		return nil
	}
	r.mu.Lock()
	defer r.mu.Unlock()
	writable := r.file != nil && !r.failed
	if !writable {
		return nil
	}
	_, writeErr := r.file.Write(data)
	r.failed = writeErr != nil
	return writeErr
}
// Close detaches the transcript file from the recorder and closes it.
//
// The handle is swapped out under the mutex, so any Write that runs afterwards
// sees a nil file and does nothing. The OS-level close happens after the lock
// is released. Closing a nil or already-closed recorder returns nil.
func (r *sshSessionTranscriptRecorder) Close() error {
	if r == nil {
		return nil
	}
	detached := func() *os.File {
		r.mu.Lock()
		defer r.mu.Unlock()
		taken := r.file
		r.file = nil
		return taken
	}()
	if detached != nil {
		return detached.Close()
	}
	return nil
}
// writeSSHSessionTranscript streams the stored transcript for sessionID to w
// as inline plain text named "ssh-session-<id>.txt".
//
// The body length is fixed by a Stat taken before any header is written, and
// exactly that many bytes are copied. A session that is still recording keeps
// appending to the same file (it is opened O_APPEND), so copying until EOF
// could exceed the declared Content-Length, which net/http rejects
// mid-response with http.ErrContentLength.
//
// It returns an error satisfying os.IsNotExist when there is no transcript
// file. A directory at the transcript path is treated the same way, matching
// sshSessionHasTranscript. If copying fails before any byte reaches w, the
// transcript headers set here are removed so the caller can still send its
// own error response. Once bytes have been written the response is already
// committed, and the returned error is informational only.
func (api *API) writeSSHSessionTranscript(w http.ResponseWriter, sessionID string) error {
	path := api.sshSessionTranscriptPath(sessionID)
	file, err := os.Open(path)
	if err != nil {
		return err
	}
	defer file.Close()
	info, err := file.Stat()
	if err != nil {
		return err
	}
	if info.IsDir() {
		return &os.PathError{Op: "open", Path: path, Err: os.ErrNotExist}
	}
	size := info.Size()
	header := w.Header()
	header.Set("Content-Type", "text/plain; charset=utf-8")
	header.Set("Content-Length", fmt.Sprintf("%d", size))
	header.Set("Content-Disposition", fmt.Sprintf("inline; filename=%q", "ssh-session-"+strings.TrimSpace(sessionID)+".txt"))
	// Send a consistent snapshot: never more than the size announced above.
	written, err := io.Copy(w, io.LimitReader(file, size))
	if err != nil && written == 0 {
		// Nothing has reached the client, so drop the transcript headers; they
		// would otherwise describe the caller's error body incorrectly. If the
		// headers were already flushed, these deletions are harmless no-ops.
		header.Del("Content-Type")
		header.Del("Content-Length")
		header.Del("Content-Disposition")
	}
	return err
}
// GetSSHSessionTranscriptForSelf serves the transcript of the SSH session named
// by params["id"], provided it belongs to the authenticated user.
//
// Responses:
//   - 401 (empty body) when there is no user in the request context or the
//     user is disabled.
//   - 404 when the session lookup returns sql.ErrNoRows or the transcript
//     file does not exist.
//   - 403 (empty body) when the session belongs to a different user.
//   - 500 with the error text for any other failure.
func (api *API) GetSSHSessionTranscriptForSelf(w http.ResponseWriter, r *http.Request, params map[string]string) {
	caller, authed := middleware.UserFromContext(r.Context())
	if !authed || caller.Disabled {
		w.WriteHeader(http.StatusUnauthorized)
		return
	}
	session, lookupErr := api.store(r).GetSSHSession(params["id"])
	switch {
	case lookupErr == sql.ErrNoRows:
		WriteJSON(w, http.StatusNotFound, map[string]string{"error": "ssh session not found"})
		return
	case lookupErr != nil:
		WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": lookupErr.Error()})
		return
	case session.UserID != caller.ID:
		w.WriteHeader(http.StatusForbidden)
		return
	}
	serveErr := api.writeSSHSessionTranscript(w, session.ID)
	switch {
	case serveErr == nil:
		return
	case os.IsNotExist(serveErr):
		WriteJSON(w, http.StatusNotFound, map[string]string{"error": "ssh session transcript not found"})
	default:
		WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": serveErr.Error()})
	}
}
// GetSSHSessionTranscriptAdmin serves the transcript of any SSH session named
// by params["id"] to an administrator.
//
// When api.requireAdmin returns false the handler returns immediately;
// presumably requireAdmin has already written the rejection response (not
// visible here). Otherwise it responds 404 when the session lookup returns
// sql.ErrNoRows or the transcript file does not exist, and 500 with the error
// text for any other failure.
func (api *API) GetSSHSessionTranscriptAdmin(w http.ResponseWriter, r *http.Request, params map[string]string) {
	if !api.requireAdmin(w, r) {
		return
	}
	session, lookupErr := api.store(r).GetSSHSession(params["id"])
	switch {
	case lookupErr == sql.ErrNoRows:
		WriteJSON(w, http.StatusNotFound, map[string]string{"error": "ssh session not found"})
		return
	case lookupErr != nil:
		WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": lookupErr.Error()})
		return
	}
	serveErr := api.writeSSHSessionTranscript(w, session.ID)
	switch {
	case serveErr == nil:
		return
	case os.IsNotExist(serveErr):
		WriteJSON(w, http.StatusNotFound, map[string]string{"error": "ssh session transcript not found"})
	default:
		WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": serveErr.Error()})
	}
}
// writeSSHSessionTranscriptChunk appends data to the session transcript via
// recorder, if one is present. A write error is logged as a warning through
// api.Logger and is not propagated. Because the recorder latches after its
// first failure, at most one warning is logged per recorder.
func (api *API) writeSSHSessionTranscriptChunk(sessionID string, data []byte, recorder *sshSessionTranscriptRecorder) {
	if recorder == nil {
		return
	}
	if writeErr := recorder.Write(data); writeErr != nil {
		api.Logger.Write(logIDSSHBroker, util.LOG_WARN, "transcript write failed session=%s err=%v", sessionID, writeErr)
	}
}