460 lines
10 KiB
Go
460 lines
10 KiB
Go
package handlers
|
|
|
|
import "errors"
|
|
import "io"
|
|
import "strings"
|
|
import "sync"
|
|
import "time"
|
|
|
|
import "golang.org/x/crypto/ssh"
|
|
import "golang.org/x/net/websocket"
|
|
|
|
// TODO: do these need to be configurable?
|
|
const (
	// sshSessionPreparationTTL is the lifetime given to entries by the
	// prompted-auth and prepared-session store constructors.
	sshSessionPreparationTTL time.Duration = 60 * time.Second

	// sshSessionPreparationMaxEntries is the entry cap given to both stores
	// by their constructors.
	sshSessionPreparationMaxEntries int = 16384

	// sshSessionPendingPerUserLimit is not referenced in this part of the
	// file. NOTE(review): presumably a per-user cap on pending sessions;
	// confirm against its callers.
	sshSessionPendingPerUserLimit int = 32
)
|
|
|
|
type sshActiveSession struct {
|
|
mu sync.Mutex
|
|
ws *websocket.Conn
|
|
client *ssh.Client
|
|
session *ssh.Session
|
|
stdin io.WriteCloser
|
|
closeCode SSHSessionCloseCode
|
|
closeReason string
|
|
closed bool
|
|
}
|
|
|
|
// SSHSessionCloseCode is a string code recorded on a session when it is asked
// to close. sshSessionCloseMessage maps the known codes to display text.
type SSHSessionCloseCode string

// Known close codes.
const (
	// SSHSessionCloseNone is the zero value: no close has been recorded.
	SSHSessionCloseNone SSHSessionCloseCode = ""

	SSHSessionCloseWebsocketClosed SSHSessionCloseCode = "websocket_closed"

	SSHSessionCloseWebsocketSendFailed SSHSessionCloseCode = "websocket_send_failed"

	SSHSessionCloseWebsocketInputQueueFull SSHSessionCloseCode = "websocket_input_queue_full"

	SSHSessionCloseDetached SSHSessionCloseCode = "detached"

	SSHSessionCloseUserDisconnect SSHSessionCloseCode = "user_disconnect"

	// SSHSessionCloseUnknown replaces SSHSessionCloseNone when a session is
	// closed without a code.
	SSHSessionCloseUnknown SSHSessionCloseCode = "unknown"
)
|
|
|
|
type SSHSessionRegistry struct {
|
|
mu sync.Mutex
|
|
items map[string]*sshActiveSession
|
|
}
|
|
|
|
// SSHPromptedAuthInput carries the secrets kept by SSHPromptedAuthStore.
// The store ignores an input whose Password and OTPCode are both empty.
type SSHPromptedAuthInput struct {
	// Password is the prompted password; may be empty.
	Password string

	// OTPCode is the prompted one-time code; may be empty.
	OTPCode string
}
|
|
|
|
type sshPromptedAuthStoreItem struct {
|
|
Input SSHPromptedAuthInput
|
|
CreatedAt time.Time
|
|
ExpiresAt time.Time
|
|
}
|
|
|
|
type sshPreparedSessionStoreItem struct {
|
|
Session sshSessionPrepared
|
|
CreatedAt time.Time
|
|
ExpiresAt time.Time
|
|
}
|
|
|
|
type SSHPromptedAuthStore struct {
|
|
mu sync.Mutex
|
|
items map[string]sshPromptedAuthStoreItem
|
|
ttl time.Duration
|
|
maxEntries int
|
|
}
|
|
|
|
type SSHPreparedSessionStore struct {
|
|
mu sync.Mutex
|
|
items map[string]sshPreparedSessionStoreItem
|
|
ttl time.Duration
|
|
maxEntries int
|
|
}
|
|
|
|
func NewSSHSessionRegistry() *SSHSessionRegistry {
|
|
var registry *SSHSessionRegistry
|
|
|
|
registry = &SSHSessionRegistry{
|
|
items: map[string]*sshActiveSession{},
|
|
}
|
|
return registry
|
|
}
|
|
|
|
func NewSSHPromptedAuthStore() *SSHPromptedAuthStore {
|
|
var store *SSHPromptedAuthStore
|
|
|
|
store = &SSHPromptedAuthStore{
|
|
items: map[string]sshPromptedAuthStoreItem{},
|
|
ttl: sshSessionPreparationTTL,
|
|
maxEntries: sshSessionPreparationMaxEntries,
|
|
}
|
|
return store
|
|
}
|
|
|
|
func NewSSHPreparedSessionStore() *SSHPreparedSessionStore {
|
|
var store *SSHPreparedSessionStore
|
|
|
|
store = &SSHPreparedSessionStore{
|
|
items: map[string]sshPreparedSessionStoreItem{},
|
|
ttl: sshSessionPreparationTTL,
|
|
maxEntries: sshSessionPreparationMaxEntries,
|
|
}
|
|
return store
|
|
}
|
|
|
|
// sshPreparationEntryExpired reports whether an entry expiring at expiresAt
// has expired as of now. A zero expiresAt never expires; otherwise the entry
// is expired once now is at or after expiresAt.
func sshPreparationEntryExpired(now time.Time, expiresAt time.Time) bool {
	if expiresAt.IsZero() {
		return false
	}
	return !expiresAt.After(now)
}
|
|
|
|
func (s *SSHPromptedAuthStore) purgeExpiredLocked(now time.Time) {
|
|
var id string
|
|
var item sshPromptedAuthStoreItem
|
|
|
|
for id, item = range s.items {
|
|
if sshPreparationEntryExpired(now, item.ExpiresAt) {
|
|
delete(s.items, id)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *SSHPromptedAuthStore) enforceMaxEntriesLocked() {
|
|
var id string
|
|
var oldestID string
|
|
var item sshPromptedAuthStoreItem
|
|
var oldestAt time.Time
|
|
|
|
if s.maxEntries <= 0 {
|
|
return
|
|
}
|
|
for len(s.items) > s.maxEntries {
|
|
oldestID = ""
|
|
oldestAt = time.Time{}
|
|
for id, item = range s.items {
|
|
if oldestID == "" || item.CreatedAt.Before(oldestAt) {
|
|
oldestID = id
|
|
oldestAt = item.CreatedAt
|
|
}
|
|
}
|
|
if oldestID == "" {
|
|
return
|
|
}
|
|
delete(s.items, oldestID)
|
|
}
|
|
}
|
|
|
|
func (s *SSHPromptedAuthStore) Put(id string, input SSHPromptedAuthInput) {
|
|
var trimmedID string
|
|
var now time.Time
|
|
|
|
if s == nil { return }
|
|
trimmedID = strings.TrimSpace(id)
|
|
if trimmedID == "" { return }
|
|
if input.Password == "" && input.OTPCode == "" { return }
|
|
now = time.Now().UTC()
|
|
s.mu.Lock()
|
|
s.purgeExpiredLocked(now)
|
|
s.items[trimmedID] = sshPromptedAuthStoreItem{
|
|
Input: input,
|
|
CreatedAt: now,
|
|
ExpiresAt: now.Add(s.ttl),
|
|
}
|
|
s.enforceMaxEntriesLocked()
|
|
s.mu.Unlock()
|
|
}
|
|
|
|
func (s *SSHPromptedAuthStore) Take(id string) SSHPromptedAuthInput {
|
|
var trimmedID string
|
|
var input SSHPromptedAuthInput
|
|
var item sshPromptedAuthStoreItem
|
|
var ok bool
|
|
var now time.Time
|
|
|
|
if s == nil { return input }
|
|
trimmedID = strings.TrimSpace(id)
|
|
if trimmedID == "" { return input }
|
|
now = time.Now().UTC()
|
|
s.mu.Lock()
|
|
s.purgeExpiredLocked(now)
|
|
item, ok = s.items[trimmedID]
|
|
if ok {
|
|
input = item.Input
|
|
delete(s.items, trimmedID)
|
|
}
|
|
s.mu.Unlock()
|
|
return input
|
|
}
|
|
|
|
func (s *SSHPromptedAuthStore) Delete(id string) {
|
|
var trimmedID string
|
|
|
|
if s == nil { return }
|
|
trimmedID = strings.TrimSpace(id)
|
|
if trimmedID == "" { return }
|
|
s.mu.Lock()
|
|
delete(s.items, trimmedID)
|
|
s.mu.Unlock()
|
|
}
|
|
|
|
func (s *SSHPreparedSessionStore) purgeExpiredLocked(now time.Time) {
|
|
var id string
|
|
var item sshPreparedSessionStoreItem
|
|
|
|
for id, item = range s.items {
|
|
if sshPreparationEntryExpired(now, item.ExpiresAt) {
|
|
delete(s.items, id)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *SSHPreparedSessionStore) enforceMaxEntriesLocked() {
|
|
var id string
|
|
var oldestID string
|
|
var item sshPreparedSessionStoreItem
|
|
var oldestAt time.Time
|
|
|
|
if s.maxEntries <= 0 {
|
|
return
|
|
}
|
|
for len(s.items) > s.maxEntries {
|
|
oldestID = ""
|
|
oldestAt = time.Time{}
|
|
for id, item = range s.items {
|
|
if oldestID == "" || item.CreatedAt.Before(oldestAt) {
|
|
oldestID = id
|
|
oldestAt = item.CreatedAt
|
|
}
|
|
}
|
|
if oldestID == "" {
|
|
return
|
|
}
|
|
delete(s.items, oldestID)
|
|
}
|
|
}
|
|
|
|
func (s *SSHPreparedSessionStore) Put(id string, item sshSessionPrepared) {
|
|
var trimmedID string
|
|
var now time.Time
|
|
|
|
if s == nil {
|
|
return
|
|
}
|
|
trimmedID = strings.TrimSpace(id)
|
|
if trimmedID == "" {
|
|
return
|
|
}
|
|
now = time.Now().UTC()
|
|
s.mu.Lock()
|
|
s.purgeExpiredLocked(now)
|
|
s.items[trimmedID] = sshPreparedSessionStoreItem{
|
|
Session: item,
|
|
CreatedAt: now,
|
|
ExpiresAt: now.Add(s.ttl),
|
|
}
|
|
s.enforceMaxEntriesLocked()
|
|
s.mu.Unlock()
|
|
}
|
|
|
|
func (s *SSHPreparedSessionStore) Get(id string) (sshSessionPrepared, bool) {
|
|
var trimmedID string
|
|
var item sshSessionPrepared
|
|
var storeItem sshPreparedSessionStoreItem
|
|
var ok bool
|
|
var now time.Time
|
|
|
|
if s == nil {
|
|
return item, false
|
|
}
|
|
trimmedID = strings.TrimSpace(id)
|
|
if trimmedID == "" {
|
|
return item, false
|
|
}
|
|
now = time.Now().UTC()
|
|
s.mu.Lock()
|
|
s.purgeExpiredLocked(now)
|
|
storeItem, ok = s.items[trimmedID]
|
|
if ok {
|
|
item = storeItem.Session
|
|
}
|
|
s.mu.Unlock()
|
|
return item, ok
|
|
}
|
|
|
|
func (s *SSHPreparedSessionStore) Take(id string) (sshSessionPrepared, bool) {
|
|
var trimmedID string
|
|
var item sshSessionPrepared
|
|
var storeItem sshPreparedSessionStoreItem
|
|
var ok bool
|
|
var now time.Time
|
|
|
|
if s == nil {
|
|
return item, false
|
|
}
|
|
trimmedID = strings.TrimSpace(id)
|
|
if trimmedID == "" {
|
|
return item, false
|
|
}
|
|
now = time.Now().UTC()
|
|
s.mu.Lock()
|
|
s.purgeExpiredLocked(now)
|
|
storeItem, ok = s.items[trimmedID]
|
|
if ok {
|
|
item = storeItem.Session
|
|
delete(s.items, trimmedID)
|
|
}
|
|
s.mu.Unlock()
|
|
return item, ok
|
|
}
|
|
|
|
func (s *SSHPreparedSessionStore) Delete(id string) {
|
|
var trimmedID string
|
|
|
|
if s == nil {
|
|
return
|
|
}
|
|
trimmedID = strings.TrimSpace(id)
|
|
if trimmedID == "" {
|
|
return
|
|
}
|
|
s.mu.Lock()
|
|
delete(s.items, trimmedID)
|
|
s.mu.Unlock()
|
|
}
|
|
|
|
func (r *SSHSessionRegistry) Register(id string, item *sshActiveSession) error {
|
|
if r == nil || item == nil {
|
|
return nil
|
|
}
|
|
r.mu.Lock()
|
|
defer r.mu.Unlock()
|
|
if r.items[id] != nil {
|
|
return errors.New("ssh session already active")
|
|
}
|
|
r.items[id] = item
|
|
return nil
|
|
}
|
|
|
|
func (r *SSHSessionRegistry) Unregister(id string, item *sshActiveSession) {
|
|
if r == nil || item == nil {
|
|
return
|
|
}
|
|
r.mu.Lock()
|
|
defer r.mu.Unlock()
|
|
if r.items[id] == item {
|
|
delete(r.items, id)
|
|
}
|
|
}
|
|
|
|
func sshSessionCloseMessage(code SSHSessionCloseCode) string {
|
|
if code == SSHSessionCloseWebsocketClosed { return "websocket closed" }
|
|
if code == SSHSessionCloseWebsocketSendFailed { return "websocket send failed" }
|
|
if code == SSHSessionCloseWebsocketInputQueueFull { return "websocket input queue full" }
|
|
if code == SSHSessionCloseDetached { return "detached" }
|
|
if code == SSHSessionCloseUserDisconnect { return "disconnected by user" }
|
|
return strings.TrimSpace(string(code))
|
|
}
|
|
|
|
func (r *SSHSessionRegistry) RequestClose(id string, code SSHSessionCloseCode) bool {
|
|
return r.RequestCloseWithReason(id, code, sshSessionCloseMessage(code))
|
|
}
|
|
|
|
func (r *SSHSessionRegistry) RequestCloseWithReason(id string, code SSHSessionCloseCode, reason string) bool {
|
|
var item *sshActiveSession
|
|
|
|
if r == nil { return false }
|
|
r.mu.Lock()
|
|
item = r.items[id]
|
|
r.mu.Unlock()
|
|
if item == nil { return false }
|
|
item.RequestCloseWithReason(code, reason)
|
|
return true
|
|
}
|
|
|
|
func (s *sshActiveSession) SetResources(ws *websocket.Conn, client *ssh.Client, session *ssh.Session, stdin io.WriteCloser) {
|
|
var closeWS *websocket.Conn
|
|
var closeClient *ssh.Client
|
|
var closeSession *ssh.Session
|
|
var closeStdin io.WriteCloser
|
|
|
|
if s == nil { return }
|
|
|
|
s.mu.Lock()
|
|
|
|
if s.closed {
|
|
closeWS = ws
|
|
closeClient = client
|
|
closeSession = session
|
|
closeStdin = stdin
|
|
s.mu.Unlock()
|
|
if closeStdin != nil { _ = closeStdin.Close() }
|
|
if closeSession != nil { _ = closeSession.Close() }
|
|
if closeClient != nil { _ = closeClient.Close() }
|
|
if closeWS != nil { _ = closeWS.Close() }
|
|
return
|
|
}
|
|
|
|
if ws != nil { s.ws = ws }
|
|
if client != nil { s.client = client }
|
|
if session != nil { s.session = session }
|
|
if stdin != nil { s.stdin = stdin }
|
|
|
|
s.mu.Unlock()
|
|
}
|
|
|
|
func (s *sshActiveSession) CloseReason() string {
|
|
var reason string
|
|
|
|
if s == nil { return "" }
|
|
|
|
s.mu.Lock()
|
|
reason = s.closeReason
|
|
s.mu.Unlock()
|
|
return reason
|
|
}
|
|
|
|
func (s *sshActiveSession) CloseCode() SSHSessionCloseCode {
|
|
var code SSHSessionCloseCode
|
|
|
|
if s == nil { return SSHSessionCloseNone }
|
|
|
|
s.mu.Lock()
|
|
code = s.closeCode
|
|
s.mu.Unlock()
|
|
return code
|
|
}
|
|
|
|
func (s *sshActiveSession) RequestCloseWithReason(code SSHSessionCloseCode, reason string) {
|
|
var ws *websocket.Conn
|
|
var client *ssh.Client
|
|
var session *ssh.Session
|
|
var stdin io.WriteCloser
|
|
|
|
if s == nil { return }
|
|
|
|
s.mu.Lock()
|
|
if code == SSHSessionCloseNone {
|
|
code = SSHSessionCloseUnknown
|
|
}
|
|
if strings.TrimSpace(reason) == "" {
|
|
reason = sshSessionCloseMessage(code)
|
|
}
|
|
if s.closeReason == "" {
|
|
s.closeCode = code
|
|
s.closeReason = reason
|
|
}
|
|
if s.closed {
|
|
s.mu.Unlock()
|
|
return
|
|
}
|
|
s.closed = true
|
|
ws = s.ws
|
|
client = s.client
|
|
session = s.session
|
|
stdin = s.stdin
|
|
s.mu.Unlock()
|
|
|
|
if stdin != nil { _ = stdin.Close() }
|
|
if session != nil { _ = session.Close() }
|
|
if client != nil { _ = client.Close() }
|
|
if ws != nil { _ = ws.Close() }
|
|
}
|