Files

2382 lines
78 KiB
Go

package db
import "context"
import "database/sql"
import "encoding/json"
import "errors"
import "fmt"
import "sort"
import "strings"
import "time"
import "codit/internal/models"
import "codit/internal/util"
// ErrSSHServerHasSessionHistory is a sentinel error whose message states that
// an SSH server has session history.
var ErrSSHServerHasSessionHistory = errors.New("ssh server has session history")

// Owner scope values. SSHOwnerScopeUser and SSHOwnerScopeAdmin are the two
// values written to ssh_servers.owner_scope by CreateSSHServer.
const (
	SSHOwnerScopeUser        = "user"
	SSHOwnerScopeAdmin       = "admin"
	SSHOwnerScopeAdminShared = "admin_shared"
)

// Server target types of an SSH access profile (a single server or a group).
const (
	SSHAccessProfileServerTargetTypeServer = "server"
	SSHAccessProfileServerTargetTypeGroup  = "group"
)

// Target types of an SSH access profile target entry.
const (
	SSHAccessProfileTargetTypeUser  = "user"
	SSHAccessProfileTargetTypeGroup = "group"
)
// SSHServerDeleteReferenceError reports that an SSH server cannot be deleted
// because other records still reference it.
type SSHServerDeleteReferenceError struct {
	// Kind is the singular, human-readable name of the referencing record type.
	Kind string
	// Count is the number of referencing records.
	Count int64
}

// Error renders the message, pluralizing Kind with a trailing "s" whenever
// Count is anything other than exactly one.
func (e *SSHServerDeleteReferenceError) Error() string {
	plural := "s"
	if e.Count == 1 {
		plural = ""
	}
	return fmt.Sprintf("cannot delete ssh server because it is used by %d %s%s", e.Count, e.Kind, plural)
}
// ListSSHServers returns all rows of ssh_servers ordered by name, host and
// port. Owner and creator references are resolved to public IDs through the
// joined users / service_principals tables, and tags_json is decoded into
// Tags. The result is a nil slice when the table is empty.
func (s *Store) ListSSHServers() ([]models.SSHServer, error) {
	rows, err := s.Query(`SELECT srv.public_id, srv.name, srv.host, srv.port, srv.description, srv.tags_json, srv.enabled, srv.host_key_policy, srv.owner_scope, COALESCE(ou.public_id, ''), srv.created_by_kind, COALESCE(srv_cbu.public_id, srv_cbp.public_id, ''), srv.created_by_subject_name, srv.created_at, srv.updated_at
FROM ssh_servers srv
LEFT JOIN users ou ON ou.id = srv.owner_user_id
LEFT JOIN users srv_cbu ON srv_cbu.id = srv.created_by_user_id
LEFT JOIN service_principals srv_cbp ON srv_cbp.id = srv.created_by_principal_id
ORDER BY srv.name, srv.host, srv.port`)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var servers []models.SSHServer
	for rows.Next() {
		var (
			server  models.SSHServer
			rawTags string
		)
		if err = rows.Scan(
			&server.ID, &server.Name, &server.Host, &server.Port, &server.Description,
			&rawTags, &server.Enabled, &server.HostKeyPolicy, &server.OwnerScope, &server.OwnerUserID,
			&server.CreatedByKind, &server.CreatedBySubjectID, &server.CreatedBySubjectName,
			&server.CreatedAt, &server.UpdatedAt,
		); err != nil {
			return nil, err
		}
		if server.Tags, err = decodeStringList(rawTags); err != nil {
			return nil, err
		}
		servers = append(servers, server)
	}
	if err = rows.Err(); err != nil {
		return nil, err
	}
	return servers, nil
}
// GetSSHServer loads the SSH server whose public_id equals the trimmed id.
// The error from Row.Scan is returned unchanged, so a missing row surfaces as
// sql.ErrNoRows. On error the partially populated value is returned with it.
func (s *Store) GetSSHServer(id string) (models.SSHServer, error) {
	var (
		server  models.SSHServer
		rawTags string
	)
	err := s.QueryRow(`SELECT srv.public_id, srv.name, srv.host, srv.port, srv.description, srv.tags_json, srv.enabled, srv.host_key_policy, srv.owner_scope, COALESCE(ou.public_id, ''), srv.created_by_kind, COALESCE(srv_cbu.public_id, srv_cbp.public_id, ''), srv.created_by_subject_name, srv.created_at, srv.updated_at
FROM ssh_servers srv
LEFT JOIN users ou ON ou.id = srv.owner_user_id
LEFT JOIN users srv_cbu ON srv_cbu.id = srv.created_by_user_id
LEFT JOIN service_principals srv_cbp ON srv_cbp.id = srv.created_by_principal_id
WHERE srv.public_id = ?`, strings.TrimSpace(id)).Scan(
		&server.ID, &server.Name, &server.Host, &server.Port, &server.Description,
		&rawTags, &server.Enabled, &server.HostKeyPolicy, &server.OwnerScope, &server.OwnerUserID,
		&server.CreatedByKind, &server.CreatedBySubjectID, &server.CreatedBySubjectName,
		&server.CreatedAt, &server.UpdatedAt,
	)
	if err != nil {
		return server, err
	}
	// tags_json holds the tag list; decode it into the model's Tags field.
	server.Tags, err = decodeStringList(rawTags)
	return server, err
}
// ListSSHServersForUser returns the SSH servers whose owner_scope is
// SSHOwnerScopeUser and whose owning user's public ID equals the trimmed
// userID, ordered by name, host and port. Editable is set on each entry when
// the scanned owner scope and owner ID match that same user.
func (s *Store) ListSSHServersForUser(userID string) ([]models.SSHServer, error) {
	owner := strings.TrimSpace(userID)
	rows, err := s.Query(`SELECT srv.public_id, srv.name, srv.host, srv.port, srv.description, srv.tags_json, srv.enabled, srv.host_key_policy, srv.owner_scope, COALESCE(ou.public_id, ''), srv.created_by_kind, COALESCE(srv_cbu.public_id, srv_cbp.public_id, ''), srv.created_by_subject_name, srv.created_at, srv.updated_at
FROM ssh_servers srv
LEFT JOIN users ou ON ou.id = srv.owner_user_id
LEFT JOIN users srv_cbu ON srv_cbu.id = srv.created_by_user_id
LEFT JOIN service_principals srv_cbp ON srv_cbp.id = srv.created_by_principal_id
WHERE srv.owner_scope = ? AND ou.public_id = ?
ORDER BY srv.name, srv.host, srv.port`, SSHOwnerScopeUser, owner)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var servers []models.SSHServer
	for rows.Next() {
		var (
			server  models.SSHServer
			rawTags string
		)
		if err = rows.Scan(
			&server.ID, &server.Name, &server.Host, &server.Port, &server.Description,
			&rawTags, &server.Enabled, &server.HostKeyPolicy, &server.OwnerScope, &server.OwnerUserID,
			&server.CreatedByKind, &server.CreatedBySubjectID, &server.CreatedBySubjectName,
			&server.CreatedAt, &server.UpdatedAt,
		); err != nil {
			return nil, err
		}
		if server.Tags, err = decodeStringList(rawTags); err != nil {
			return nil, err
		}
		ownedByUser := server.OwnerScope == SSHOwnerScopeUser && server.OwnerUserID == owner
		server.Editable = ownedByUser
		servers = append(servers, server)
	}
	if err = rows.Err(); err != nil {
		return nil, err
	}
	return servers, nil
}
// GetSSHServerForUser returns the server with the given (trimmed) public ID
// from the set produced by ListSSHServersForUser for userID. It returns
// sql.ErrNoRows with a zero value when no listed server has that ID.
func (s *Store) GetSSHServerForUser(userID string, id string) (models.SSHServer, error) {
	wantID := strings.TrimSpace(id)
	servers, err := s.ListSSHServersForUser(userID)
	if err != nil {
		return models.SSHServer{}, err
	}
	for idx := range servers {
		if servers[idx].ID == wantID {
			return servers[idx], nil
		}
	}
	return models.SSHServer{}, sql.ErrNoRows
}
// GetOwnedSSHServerForUser loads the server by public ID via GetSSHServer and
// returns it only when its owner scope is SSHOwnerScopeUser and its owner is
// the trimmed userID. Otherwise it returns the loaded server together with
// sql.ErrNoRows.
func (s *Store) GetOwnedSSHServerForUser(userID string, id string) (models.SSHServer, error) {
	server, err := s.GetSSHServer(strings.TrimSpace(id))
	if err != nil {
		return server, err
	}
	owner := strings.TrimSpace(userID)
	server.Editable = server.OwnerScope == SSHOwnerScopeUser && server.OwnerUserID == owner
	if server.Editable {
		return server, nil
	}
	return server, sql.ErrNoRows
}
// CreateSSHServer inserts a new row into ssh_servers and returns the item as
// it was normalized for the insert.
//
// Normalization applied before the insert:
//   - a blank ID is replaced by util.NewID();
//   - a non-positive Port becomes 22;
//   - HostKeyPolicy and Tags go through their normalize helpers;
//   - any OwnerScope other than SSHOwnerScopeUser becomes SSHOwnerScopeAdmin
//     and the owner user ID is cleared;
//   - CreatedAt and UpdatedAt are set to the current UTC Unix time.
//
// A user-scoped server without an owner user ID is rejected with
// "owner_user_id is required".
func (s *Store) CreateSSHServer(item models.SSHServer) (models.SSHServer, error) {
	if strings.TrimSpace(item.ID) == "" {
		var idErr error
		if item.ID, idErr = util.NewID(); idErr != nil {
			return item, idErr
		}
	}
	if item.Port <= 0 {
		item.Port = 22
	}
	item.HostKeyPolicy = normalizeSSHServerHostKeyPolicy(item.HostKeyPolicy)
	item.Tags = normalizeStringList(item.Tags)

	item.OwnerScope = strings.TrimSpace(item.OwnerScope)
	if item.OwnerScope == SSHOwnerScopeUser {
		if strings.TrimSpace(item.OwnerUserID) == "" {
			return item, errors.New("owner_user_id is required")
		}
	} else {
		item.OwnerScope = SSHOwnerScopeAdmin
		item.OwnerUserID = ""
	}

	encodedTags, err := encodeStringList(item.Tags)
	if err != nil {
		return item, err
	}

	stamp := time.Now().UTC().Unix()
	item.CreatedAt, item.UpdatedAt = stamp, stamp

	// Several values are bound more than once because the statement repeats
	// their placeholder inside CASE / sub-select expressions.
	ownerUserID := strings.TrimSpace(item.OwnerUserID)
	creatorKind := strings.TrimSpace(item.CreatedByKind)
	creatorID := strings.TrimSpace(item.CreatedBySubjectID)

	_, err = s.Exec(`INSERT INTO ssh_servers (
public_id,
name,
host,
port,
description,
tags_json,
enabled,
host_key_policy,
owner_scope,
owner_user_id,
created_by_kind,
created_by_user_id,
created_by_principal_id,
created_by_subject_name,
created_at,
updated_at
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, CASE WHEN ? = '' THEN NULL ELSE (SELECT id FROM users WHERE public_id = ?) END, ?, (SELECT id FROM users WHERE public_id = ? AND ? = 'user'), (SELECT id FROM service_principals WHERE public_id = ? AND ? = 'service_principal'), ?, ?, ?)`,
		item.ID,
		strings.TrimSpace(item.Name),
		strings.TrimSpace(item.Host),
		item.Port,
		strings.TrimSpace(item.Description),
		encodedTags,
		item.Enabled,
		item.HostKeyPolicy,
		strings.TrimSpace(item.OwnerScope),
		ownerUserID,
		ownerUserID,
		creatorKind,
		creatorID,
		creatorKind,
		creatorID,
		creatorKind,
		strings.TrimSpace(item.CreatedBySubjectName),
		item.CreatedAt,
		item.UpdatedAt,
	)
	return item, err
}
// UpdateSSHServer overwrites the editable columns (name, host, port,
// description, tags, enabled, host key policy, updated_at) of the server whose
// public_id equals the trimmed item.ID, then returns the row as re-read via
// GetSSHServer. A non-positive Port is stored as 22; HostKeyPolicy and Tags
// are normalized before being written.
func (s *Store) UpdateSSHServer(item models.SSHServer) (models.SSHServer, error) {
	if item.Port <= 0 {
		item.Port = 22
	}
	item.HostKeyPolicy = normalizeSSHServerHostKeyPolicy(item.HostKeyPolicy)
	item.Tags = normalizeStringList(item.Tags)

	encodedTags, err := encodeStringList(item.Tags)
	if err != nil {
		return item, err
	}
	item.UpdatedAt = time.Now().UTC().Unix()

	// TODO: need to protect it with tx.
	if _, err = s.Exec(`UPDATE ssh_servers
SET name = ?, host = ?, port = ?, description = ?, tags_json = ?, enabled = ?, host_key_policy = ?, updated_at = ?
WHERE public_id = ?`,
		strings.TrimSpace(item.Name),
		strings.TrimSpace(item.Host),
		item.Port,
		strings.TrimSpace(item.Description),
		encodedTags,
		item.Enabled,
		item.HostKeyPolicy,
		item.UpdatedAt,
		strings.TrimSpace(item.ID),
	); err != nil {
		return item, err
	}
	return s.GetSSHServer(item.ID)
}
// DeleteSSHServer removes the server with the given (trimmed) public ID inside
// a transaction obtained from beginImmediateContext.
//
// Before deleting, it counts rows referencing the server in, in this order,
// ssh_access_profiles, ssh_server_group_members and ssh_sessions. The first
// non-zero count aborts the delete with a *SSHServerDeleteReferenceError
// naming the referencing kind and the count. The transaction is rolled back
// on every failure path and committed on success (both only when owned).
func (s *Store) DeleteSSHServer(ctx context.Context, id string) error {
	var (
		tx    txExecutor
		owned bool
		err   error
	)
	serverID := strings.TrimSpace(id)
	tx, owned, err = s.beginImmediateContext(ctx)
	if err != nil {
		return err
	}

	// Reference checks, evaluated in order; the first hit wins.
	blockers := []struct {
		kind       string
		countQuery string
	}{
		{"SSH access profile", `SELECT COUNT(*) FROM ssh_access_profiles WHERE server_id = (SELECT id FROM ssh_servers WHERE public_id = ?)`},
		{"SSH server group membership", `SELECT COUNT(*) FROM ssh_server_group_members WHERE server_id = (SELECT id FROM ssh_servers WHERE public_id = ?)`},
		{"SSH session history record", `SELECT COUNT(*) FROM ssh_sessions WHERE server_id = (SELECT id FROM ssh_servers WHERE public_id = ?)`},
	}
	for _, blocker := range blockers {
		var refs int64
		if err = tx.QueryRow(blocker.countQuery, serverID).Scan(&refs); err != nil {
			rollbackIfOwned(tx, owned)
			return err
		}
		if refs > 0 {
			rollbackIfOwned(tx, owned)
			return &SSHServerDeleteReferenceError{Kind: blocker.kind, Count: refs}
		}
	}

	if _, err = tx.Exec(`DELETE FROM ssh_servers WHERE public_id = ?`, serverID); err != nil {
		rollbackIfOwned(tx, owned)
		return err
	}
	return commitIfOwned(tx, owned)
}
// ListSSHAccessProfiles returns every SSH access profile joined with its
// server and, when present, its server group, ordered by server name and then
// profile name. JSON list columns (principal grant IDs, server tags, group
// tags) are decoded, and each profile's Targets are loaded through
// listSSHAccessProfileTargets. The result is a nil slice when there are no
// profiles.
func (s *Store) ListSSHAccessProfiles() ([]models.SSHAccessProfile, error) {
	rows, err := s.Query(`SELECT
p.public_id,
COALESCE(s.public_id, ''),
COALESCE(p.server_target_type, 'server'),
COALESCE(g.public_id, ''),
COALESCE(g.name, ''),
COALESCE(g.description, ''),
COALESCE(g.enabled, 0),
COALESCE(g.created_by_kind, ''),
COALESCE(g_cbu.public_id, g_cbp.public_id, ''),
COALESCE(g.created_by_subject_name, ''),
COALESCE(g.created_at, 0),
COALESCE(g.updated_at, 0),
COALESCE(g.tags_json, '[]'),
p.name,
p.description,
p.remote_username,
p.auth_method,
p.second_factor_mode,
p.owner_scope,
COALESCE(pou.public_id, ''),
p.allow_user_edit,
p.enabled,
COALESCE(sec.public_id, ''),
COALESCE(c.public_id, ''),
COALESCE(c.name, ''),
p.auth_public_key,
p.auth_public_key_fingerprint,
COALESCE(ca.public_id, ''),
p.ssh_principal_grant_ids_json,
p.default_cert_valid_seconds,
p.max_cert_valid_seconds,
p.valid_after,
p.valid_before,
p.created_by_kind,
COALESCE(p_cbu.public_id, p_cbp.public_id, ''),
p.created_by_subject_name,
p.created_at,
p.updated_at,
s.public_id,
s.name,
s.host,
s.port,
s.description,
s.tags_json,
s.enabled,
s.host_key_policy,
s.owner_scope,
COALESCE(sou.public_id, ''),
s.created_by_kind,
COALESCE(s_cbu.public_id, s_cbp.public_id, ''),
s.created_by_subject_name,
s.created_at,
s.updated_at
FROM ssh_access_profiles p
JOIN ssh_servers s ON s.id = p.server_id
LEFT JOIN users sou ON sou.id = s.owner_user_id
LEFT JOIN users s_cbu ON s_cbu.id = s.created_by_user_id
LEFT JOIN service_principals s_cbp ON s_cbp.id = s.created_by_principal_id
LEFT JOIN ssh_server_groups g ON g.id = p.server_group_id
LEFT JOIN users g_cbu ON g_cbu.id = g.created_by_user_id
LEFT JOIN service_principals g_cbp ON g_cbp.id = g.created_by_principal_id
LEFT JOIN users pou ON pou.id = p.owner_user_id
LEFT JOIN users p_cbu ON p_cbu.id = p.created_by_user_id
LEFT JOIN service_principals p_cbp ON p_cbp.id = p.created_by_principal_id
LEFT JOIN ssh_secrets sec ON sec.id = p.secret_id
LEFT JOIN ssh_credentials c ON c.id = p.ssh_credential_id
LEFT JOIN ssh_user_cas ca ON ca.id = p.ssh_user_ca_id
ORDER BY s.name, p.name`)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var items []models.SSHAccessProfile
	for rows.Next() {
		// Use a fresh value for every row. Reusing one value across rows let
		// fields that are not scanned (ServerGroup.ID) leak from the previous
		// profile into the next one.
		var (
			item           models.SSHAccessProfile
			grantIDsJSON   string
			serverTagsJSON string
			groupTagsJSON  string
		)
		err = rows.Scan(
			&item.ID,
			&item.ServerID,
			&item.ServerTargetType,
			&item.ServerGroupID,
			&item.ServerGroup.Name,
			&item.ServerGroup.Description,
			&item.ServerGroup.Enabled,
			&item.ServerGroup.CreatedByKind,
			&item.ServerGroup.CreatedBySubjectID,
			&item.ServerGroup.CreatedBySubjectName,
			&item.ServerGroup.CreatedAt,
			&item.ServerGroup.UpdatedAt,
			&groupTagsJSON,
			&item.Name,
			&item.Description,
			&item.RemoteUsername,
			&item.AuthMethod,
			&item.SecondFactorMode,
			&item.OwnerScope,
			&item.OwnerUserID,
			&item.AllowUserEdit,
			&item.Enabled,
			&item.SecretID,
			&item.SSHCredentialID,
			&item.SSHCredentialName,
			&item.AuthPublicKey,
			&item.AuthPublicKeyFingerprint,
			&item.SSHUserCAID,
			&grantIDsJSON,
			&item.DefaultCertValidSeconds,
			&item.MaxCertValidSeconds,
			&item.ValidAfter,
			&item.ValidBefore,
			&item.CreatedByKind,
			&item.CreatedBySubjectID,
			&item.CreatedBySubjectName,
			&item.CreatedAt,
			&item.UpdatedAt,
			&item.Server.ID,
			&item.Server.Name,
			&item.Server.Host,
			&item.Server.Port,
			&item.Server.Description,
			&serverTagsJSON,
			&item.Server.Enabled,
			&item.Server.HostKeyPolicy,
			&item.Server.OwnerScope,
			&item.Server.OwnerUserID,
			&item.Server.CreatedByKind,
			&item.Server.CreatedBySubjectID,
			&item.Server.CreatedBySubjectName,
			&item.Server.CreatedAt,
			&item.Server.UpdatedAt,
		)
		if err != nil {
			return nil, err
		}
		// Mirror the group's public ID onto the embedded group; it is empty
		// when the profile has no server group.
		item.ServerGroup.ID = item.ServerGroupID
		item.SSHPrincipalGrantIDs, err = decodeStringList(grantIDsJSON)
		if err != nil {
			return nil, err
		}
		item.Server.Tags, err = decodeStringList(serverTagsJSON)
		if err != nil {
			return nil, err
		}
		item.ServerGroup.Tags, err = decodeStringList(groupTagsJSON)
		if err != nil {
			return nil, err
		}
		items = append(items, item)
	}
	if err = rows.Err(); err != nil {
		return nil, err
	}

	// Load targets only after the cursor is exhausted (rows.Next returning
	// false closes it), so no second query runs while the first result set is
	// still holding its connection.
	for i := range items {
		items[i].Targets, err = s.listSSHAccessProfileTargets(items[i].ID)
		if err != nil {
			return nil, err
		}
	}
	return items, nil
}
// GetSSHAccessProfile loads the SSH access profile whose public_id equals the
// trimmed id, together with its server, its optional server group and its
// targets. The error from Row.Scan is returned unchanged, so a missing profile
// surfaces as sql.ErrNoRows. On any error the partially populated value is
// returned alongside it.
func (s *Store) GetSSHAccessProfile(id string) (models.SSHAccessProfile, error) {
	var (
		profile       models.SSHAccessProfile
		rawGrantIDs   string
		rawServerTags string
		rawGroupTags  string
	)
	err := s.QueryRow(`SELECT
p.public_id,
COALESCE(s.public_id, ''),
COALESCE(p.server_target_type, 'server'),
COALESCE(g.public_id, ''),
COALESCE(g.name, ''),
COALESCE(g.description, ''),
COALESCE(g.enabled, 0),
COALESCE(g.created_by_kind, ''),
COALESCE(g_cbu.public_id, g_cbp.public_id, ''),
COALESCE(g.created_by_subject_name, ''),
COALESCE(g.created_at, 0),
COALESCE(g.updated_at, 0),
COALESCE(g.tags_json, '[]'),
p.name,
p.description,
p.remote_username,
p.auth_method,
p.second_factor_mode,
p.owner_scope,
COALESCE(pou.public_id, ''),
p.allow_user_edit,
p.enabled,
COALESCE(sec.public_id, ''),
COALESCE(c.public_id, ''),
COALESCE(c.name, ''),
p.auth_public_key,
p.auth_public_key_fingerprint,
COALESCE(ca.public_id, ''),
p.ssh_principal_grant_ids_json,
p.default_cert_valid_seconds,
p.max_cert_valid_seconds,
p.valid_after,
p.valid_before,
p.created_by_kind,
COALESCE(p_cbu.public_id, p_cbp.public_id, ''),
p.created_by_subject_name,
p.created_at,
p.updated_at,
s.public_id,
s.name,
s.host,
s.port,
s.description,
s.tags_json,
s.enabled,
s.host_key_policy,
s.owner_scope,
COALESCE(sou.public_id, ''),
s.created_by_kind,
COALESCE(s_cbu.public_id, s_cbp.public_id, ''),
s.created_by_subject_name,
s.created_at,
s.updated_at
FROM ssh_access_profiles p
JOIN ssh_servers s ON s.id = p.server_id
LEFT JOIN users sou ON sou.id = s.owner_user_id
LEFT JOIN users s_cbu ON s_cbu.id = s.created_by_user_id
LEFT JOIN service_principals s_cbp ON s_cbp.id = s.created_by_principal_id
LEFT JOIN ssh_server_groups g ON g.id = p.server_group_id
LEFT JOIN users g_cbu ON g_cbu.id = g.created_by_user_id
LEFT JOIN service_principals g_cbp ON g_cbp.id = g.created_by_principal_id
LEFT JOIN users pou ON pou.id = p.owner_user_id
LEFT JOIN users p_cbu ON p_cbu.id = p.created_by_user_id
LEFT JOIN service_principals p_cbp ON p_cbp.id = p.created_by_principal_id
LEFT JOIN ssh_secrets sec ON sec.id = p.secret_id
LEFT JOIN ssh_credentials c ON c.id = p.ssh_credential_id
LEFT JOIN ssh_user_cas ca ON ca.id = p.ssh_user_ca_id
WHERE p.public_id = ?`, strings.TrimSpace(id)).Scan(
		// Profile identity and server group columns.
		&profile.ID,
		&profile.ServerID,
		&profile.ServerTargetType,
		&profile.ServerGroupID,
		&profile.ServerGroup.Name,
		&profile.ServerGroup.Description,
		&profile.ServerGroup.Enabled,
		&profile.ServerGroup.CreatedByKind,
		&profile.ServerGroup.CreatedBySubjectID,
		&profile.ServerGroup.CreatedBySubjectName,
		&profile.ServerGroup.CreatedAt,
		&profile.ServerGroup.UpdatedAt,
		&rawGroupTags,
		// Profile columns.
		&profile.Name,
		&profile.Description,
		&profile.RemoteUsername,
		&profile.AuthMethod,
		&profile.SecondFactorMode,
		&profile.OwnerScope,
		&profile.OwnerUserID,
		&profile.AllowUserEdit,
		&profile.Enabled,
		&profile.SecretID,
		&profile.SSHCredentialID,
		&profile.SSHCredentialName,
		&profile.AuthPublicKey,
		&profile.AuthPublicKeyFingerprint,
		&profile.SSHUserCAID,
		&rawGrantIDs,
		&profile.DefaultCertValidSeconds,
		&profile.MaxCertValidSeconds,
		&profile.ValidAfter,
		&profile.ValidBefore,
		&profile.CreatedByKind,
		&profile.CreatedBySubjectID,
		&profile.CreatedBySubjectName,
		&profile.CreatedAt,
		&profile.UpdatedAt,
		// Joined server columns.
		&profile.Server.ID,
		&profile.Server.Name,
		&profile.Server.Host,
		&profile.Server.Port,
		&profile.Server.Description,
		&rawServerTags,
		&profile.Server.Enabled,
		&profile.Server.HostKeyPolicy,
		&profile.Server.OwnerScope,
		&profile.Server.OwnerUserID,
		&profile.Server.CreatedByKind,
		&profile.Server.CreatedBySubjectID,
		&profile.Server.CreatedBySubjectName,
		&profile.Server.CreatedAt,
		&profile.Server.UpdatedAt,
	)
	if err != nil {
		return profile, err
	}

	// The value is freshly zeroed, so copying an empty group ID is a no-op.
	profile.ServerGroup.ID = profile.ServerGroupID

	if profile.SSHPrincipalGrantIDs, err = decodeStringList(rawGrantIDs); err != nil {
		return profile, err
	}
	if profile.Server.Tags, err = decodeStringList(rawServerTags); err != nil {
		return profile, err
	}
	if profile.ServerGroup.Tags, err = decodeStringList(rawGroupTags); err != nil {
		return profile, err
	}
	profile.Targets, err = s.listSSHAccessProfileTargets(profile.ID)
	return profile, err
}
// CreateSSHAccessProfile validates and inserts a new SSH access profile along
// with its targets, then returns the profile as re-read via
// GetSSHAccessProfile.
//
// Name, ServerID, RemoteUsername and AuthMethod must be non-blank. A blank ID
// is replaced by util.NewID(). DefaultCertValidSeconds defaults to 3600,
// MaxCertValidSeconds defaults to DefaultCertValidSeconds, and negative
// ValidAfter / ValidBefore are clamped to 0.
//
// The profile row, an optional new secret and the target rows are written
// through the transaction returned by s.beginContext; it is rolled back on any
// failure and committed at the end (both only when owned).
func (s *Store) CreateSSHAccessProfile(ctx context.Context, item models.SSHAccessProfile) (models.SSHAccessProfile, error) {
var tx txExecutor
var owned bool
var err error
var now int64
var grantIDsJSON string
var i int
var target models.SSHAccessProfileTarget
var secret models.SSHSecret
var credential models.SSHCredential
// Required-field validation; nothing is written before these pass.
if strings.TrimSpace(item.Name) == "" {
return item, errors.New("name is required")
}
if strings.TrimSpace(item.ServerID) == "" {
return item, errors.New("server_id is required")
}
if strings.TrimSpace(item.RemoteUsername) == "" {
return item, errors.New("remote_username is required")
}
if strings.TrimSpace(item.AuthMethod) == "" {
return item, errors.New("auth_method is required")
}
if strings.TrimSpace(item.ID) == "" {
item.ID, err = util.NewID()
if err != nil { return item, err }
}
// Defaults and clamping for certificate validity settings.
if item.DefaultCertValidSeconds <= 0 { item.DefaultCertValidSeconds = 3600 }
if item.MaxCertValidSeconds <= 0 { item.MaxCertValidSeconds = item.DefaultCertValidSeconds }
if item.ValidAfter < 0 { item.ValidAfter = 0 }
if item.ValidBefore < 0 { item.ValidBefore = 0 }
item.SSHPrincipalGrantIDs = normalizeStringList(item.SSHPrincipalGrantIDs)
grantIDsJSON, err = encodeStringList(item.SSHPrincipalGrantIDs)
if err != nil { return item, err }
now = time.Now().UTC().Unix()
item.CreatedAt = now
item.UpdatedAt = now
tx, owned, err = s.beginContext(ctx)
if err != nil { return item, err }
item.SecretPayload = strings.TrimSpace(item.SecretPayload)
item.SecretPassword = strings.TrimSpace(item.SecretPassword)
item.SSHCredentialID = strings.TrimSpace(item.SSHCredentialID)
// A referenced credential takes precedence over an inline secret: its
// secret ID, public key and fingerprint overwrite the profile's own values.
if item.SSHCredentialID != "" {
// NOTE(review): the credential is read through s, not tx, while the
// transaction is open — confirm this is intended.
credential, err = s.GetSSHCredential(item.SSHCredentialID)
if err != nil {
rollbackIfOwned(tx, owned)
return item, err
}
item.SecretID = credential.SecretID
item.AuthPublicKey = credential.PublicKey
item.AuthPublicKeyFingerprint = credential.Fingerprint
} else if item.SecretPayload != "" || item.SecretPassword != "" {
// Inline secret material: store it as a "private_key" secret in the same
// transaction and link the profile to the resulting secret ID.
secret = models.SSHSecret{
ID: item.SecretID,
Kind: "private_key",
Payload: item.SecretPayload,
Password: item.SecretPassword,
CreatedByKind: item.CreatedByKind,
CreatedBySubjectID: item.CreatedBySubjectID,
CreatedBySubjectName: item.CreatedBySubjectName,
}
secret, err = createSSHSecretTx(tx, secret)
if err != nil {
rollbackIfOwned(tx, owned)
return item, err
}
item.SecretID = secret.ID
}
// Arguments below are bound positionally. Every "CASE WHEN ? = '' ..." pair
// and every "(SELECT ... WHERE public_id = ? AND ? = ...)" sub-select consumes
// two consecutive arguments, which is why several values appear twice.
_, err = tx.Exec(`INSERT INTO ssh_access_profiles (
public_id,
server_id,
server_target_type,
server_group_id,
name,
description,
remote_username,
auth_method,
second_factor_mode,
owner_scope,
owner_user_id,
allow_user_edit,
enabled,
secret_id,
ssh_credential_id,
auth_public_key,
auth_public_key_fingerprint,
ssh_user_ca_id,
ssh_principal_grant_ids_json,
default_cert_valid_seconds,
max_cert_valid_seconds,
valid_after,
valid_before,
created_by_kind,
created_by_user_id,
created_by_principal_id,
created_by_subject_name,
created_at,
updated_at
) VALUES (?, (SELECT id FROM ssh_servers WHERE public_id = ?), ?, CASE WHEN ? = '' THEN NULL ELSE (SELECT id FROM ssh_server_groups WHERE public_id = ?) END, ?, ?, ?, ?, ?, ?, CASE WHEN ? = '' THEN NULL ELSE (SELECT id FROM users WHERE public_id = ?) END, ?, ?, CASE WHEN ? = '' THEN NULL ELSE (SELECT id FROM ssh_secrets WHERE public_id = ?) END, CASE WHEN ? = '' THEN NULL ELSE (SELECT id FROM ssh_credentials WHERE public_id = ?) END, ?, ?, CASE WHEN ? = '' THEN NULL ELSE (SELECT id FROM ssh_user_cas WHERE public_id = ?) END, ?, ?, ?, ?, ?, ?, (SELECT id FROM users WHERE public_id = ? AND ? = 'user'), (SELECT id FROM service_principals WHERE public_id = ? AND ? = 'service_principal'), ?, ?, ?)`,
item.ID,
strings.TrimSpace(item.ServerID),
strings.TrimSpace(item.ServerTargetType),
strings.TrimSpace(item.ServerGroupID),
emptyStringToNil(item.ServerGroupID),
strings.TrimSpace(item.Name),
strings.TrimSpace(item.Description),
strings.TrimSpace(item.RemoteUsername),
strings.TrimSpace(item.AuthMethod),
strings.TrimSpace(item.SecondFactorMode),
strings.TrimSpace(item.OwnerScope),
strings.TrimSpace(item.OwnerUserID),
strings.TrimSpace(item.OwnerUserID),
item.AllowUserEdit,
item.Enabled,
strings.TrimSpace(item.SecretID),
emptyStringToNil(item.SecretID),
strings.TrimSpace(item.SSHCredentialID),
emptyStringToNil(item.SSHCredentialID),
strings.TrimSpace(item.AuthPublicKey),
strings.TrimSpace(item.AuthPublicKeyFingerprint),
strings.TrimSpace(item.SSHUserCAID),
emptyStringToNil(item.SSHUserCAID),
grantIDsJSON,
item.DefaultCertValidSeconds,
item.MaxCertValidSeconds,
item.ValidAfter,
item.ValidBefore,
strings.TrimSpace(item.CreatedByKind),
strings.TrimSpace(item.CreatedBySubjectID),
strings.TrimSpace(item.CreatedByKind),
strings.TrimSpace(item.CreatedBySubjectID),
strings.TrimSpace(item.CreatedByKind),
strings.TrimSpace(item.CreatedBySubjectName),
item.CreatedAt,
item.UpdatedAt,
)
if err != nil {
rollbackIfOwned(tx, owned)
return item, err
}
// Insert one target row per entry, stamped with the profile's creation time.
for i = 0; i < len(item.Targets); i++ {
target = item.Targets[i]
err = insertSSHAccessProfileTargetTx(tx, item.ID, target.TargetType, target.TargetID, now)
if err != nil {
rollbackIfOwned(tx, owned)
return item, err
}
}
err = commitIfOwned(tx, owned)
if err != nil {
return item, err
}
// Re-read so the caller gets the joined server, group and target data.
item, err = s.GetSSHAccessProfile(item.ID)
if err != nil {
return item, err
}
return item, nil
}
// UpdateSSHAccessProfile validates and rewrites the SSH access profile whose
// public_id equals the trimmed item.ID, replaces all of its targets, and
// returns the profile as re-read via GetSSHAccessProfile.
//
// Name, ServerID, RemoteUsername and AuthMethod must be non-blank.
// DefaultCertValidSeconds defaults to 3600, MaxCertValidSeconds defaults to
// DefaultCertValidSeconds, and negative ValidAfter / ValidBefore are clamped
// to 0. UpdatedAt is set to the current UTC Unix time.
//
// All writes go through the transaction returned by s.beginContext; it is
// rolled back on any failure and committed at the end (both only when owned).
func (s *Store) UpdateSSHAccessProfile(ctx context.Context, item models.SSHAccessProfile) (models.SSHAccessProfile, error) {
var tx txExecutor
var owned bool
var err error
var grantIDsJSON string
var i int
var target models.SSHAccessProfileTarget
var secret models.SSHSecret
var credential models.SSHCredential
// Required-field validation; nothing is written before these pass.
if strings.TrimSpace(item.Name) == "" {
return item, errors.New("name is required")
}
if strings.TrimSpace(item.ServerID) == "" {
return item, errors.New("server_id is required")
}
if strings.TrimSpace(item.RemoteUsername) == "" {
return item, errors.New("remote_username is required")
}
if strings.TrimSpace(item.AuthMethod) == "" {
return item, errors.New("auth_method is required")
}
// Defaults and clamping for certificate validity settings.
if item.DefaultCertValidSeconds <= 0 { item.DefaultCertValidSeconds = 3600 }
if item.MaxCertValidSeconds <= 0 { item.MaxCertValidSeconds = item.DefaultCertValidSeconds }
if item.ValidAfter < 0 { item.ValidAfter = 0 }
if item.ValidBefore < 0 { item.ValidBefore = 0 }
item.SSHPrincipalGrantIDs = normalizeStringList(item.SSHPrincipalGrantIDs)
grantIDsJSON, err = encodeStringList(item.SSHPrincipalGrantIDs)
if err != nil { return item, err }
item.UpdatedAt = time.Now().UTC().Unix()
tx, owned, err = s.beginContext(ctx)
if err != nil { return item, err }
item.SecretPayload = strings.TrimSpace(item.SecretPayload)
item.SecretPassword = strings.TrimSpace(item.SecretPassword)
item.SSHCredentialID = strings.TrimSpace(item.SSHCredentialID)
// A referenced credential takes precedence over an inline secret: its
// secret ID, public key and fingerprint overwrite the profile's own values.
if item.SSHCredentialID != "" {
// NOTE(review): the credential is read through s, not tx, while the
// transaction is open — confirm this is intended.
credential, err = s.GetSSHCredential(item.SSHCredentialID)
if err != nil {
rollbackIfOwned(tx, owned)
return item, err
}
item.SecretID = credential.SecretID
item.AuthPublicKey = credential.PublicKey
item.AuthPublicKeyFingerprint = credential.Fingerprint
} else if item.SecretPayload != "" || item.SecretPassword != "" {
// Inline secret material: create a new "private_key" secret when the
// profile has no SecretID yet, otherwise update the secret with that ID.
secret = models.SSHSecret{
ID: item.SecretID,
Kind: "private_key",
Payload: item.SecretPayload,
Password: item.SecretPassword,
CreatedByKind: item.CreatedByKind,
CreatedBySubjectID: item.CreatedBySubjectID,
CreatedBySubjectName: item.CreatedBySubjectName,
}
if strings.TrimSpace(item.SecretID) == "" {
secret, err = createSSHSecretTx(tx, secret)
if err != nil {
rollbackIfOwned(tx, owned)
return item, err
}
} else {
// NOTE(review): if item.SecretID still points at a secret shared with an
// SSH credential, this updates that shared secret — verify against callers.
secret, err = updateSSHSecretTx(tx, secret)
if err != nil {
rollbackIfOwned(tx, owned)
return item, err
}
}
item.SecretID = secret.ID
}
// Arguments below are bound positionally. Every "CASE WHEN ? = '' ..."
// expression consumes two consecutive arguments, which is why several values
// appear twice.
_, err = tx.Exec(`UPDATE ssh_access_profiles
SET server_id = (SELECT id FROM ssh_servers WHERE public_id = ?), server_target_type = ?, server_group_id = CASE WHEN ? = '' THEN NULL ELSE (SELECT id FROM ssh_server_groups WHERE public_id = ?) END, name = ?, description = ?, remote_username = ?, auth_method = ?, second_factor_mode = ?, owner_scope = ?, owner_user_id = CASE WHEN ? = '' THEN NULL ELSE (SELECT id FROM users WHERE public_id = ?) END, allow_user_edit = ?, enabled = ?, secret_id = CASE WHEN ? = '' THEN NULL ELSE (SELECT id FROM ssh_secrets WHERE public_id = ?) END, ssh_credential_id = CASE WHEN ? = '' THEN NULL ELSE (SELECT id FROM ssh_credentials WHERE public_id = ?) END, auth_public_key = ?, auth_public_key_fingerprint = ?, ssh_user_ca_id = CASE WHEN ? = '' THEN NULL ELSE (SELECT id FROM ssh_user_cas WHERE public_id = ?) END, ssh_principal_grant_ids_json = ?, default_cert_valid_seconds = ?, max_cert_valid_seconds = ?, valid_after = ?, valid_before = ?, updated_at = ?
WHERE public_id = ?`,
strings.TrimSpace(item.ServerID),
strings.TrimSpace(item.ServerTargetType),
strings.TrimSpace(item.ServerGroupID),
emptyStringToNil(item.ServerGroupID),
strings.TrimSpace(item.Name),
strings.TrimSpace(item.Description),
strings.TrimSpace(item.RemoteUsername),
strings.TrimSpace(item.AuthMethod),
strings.TrimSpace(item.SecondFactorMode),
strings.TrimSpace(item.OwnerScope),
strings.TrimSpace(item.OwnerUserID),
strings.TrimSpace(item.OwnerUserID),
item.AllowUserEdit,
item.Enabled,
strings.TrimSpace(item.SecretID),
emptyStringToNil(item.SecretID),
strings.TrimSpace(item.SSHCredentialID),
emptyStringToNil(item.SSHCredentialID),
strings.TrimSpace(item.AuthPublicKey),
strings.TrimSpace(item.AuthPublicKeyFingerprint),
strings.TrimSpace(item.SSHUserCAID),
emptyStringToNil(item.SSHUserCAID),
grantIDsJSON,
item.DefaultCertValidSeconds,
item.MaxCertValidSeconds,
item.ValidAfter,
item.ValidBefore,
item.UpdatedAt,
strings.TrimSpace(item.ID),
)
if err != nil {
rollbackIfOwned(tx, owned)
return item, err
}
// Targets are replaced wholesale: delete every existing target row for the
// profile, then insert the supplied list stamped with the new UpdatedAt.
_, err = tx.Exec(`DELETE FROM ssh_access_profile_targets WHERE profile_id = (SELECT id FROM ssh_access_profiles WHERE public_id = ?)`, strings.TrimSpace(item.ID))
if err != nil {
rollbackIfOwned(tx, owned)
return item, err
}
for i = 0; i < len(item.Targets); i++ {
target = item.Targets[i]
err = insertSSHAccessProfileTargetTx(tx, item.ID, target.TargetType, target.TargetID, item.UpdatedAt)
if err != nil {
rollbackIfOwned(tx, owned)
return item, err
}
}
err = commitIfOwned(tx, owned)
if err != nil {
return item, err
}
// Re-read so the caller gets the joined server, group and target data.
item, err = s.GetSSHAccessProfile(item.ID)
if err != nil {
return item, err
}
return item, nil
}
// DeleteSSHAccessProfile removes the access profile with the given (trimmed)
// public ID inside a transaction obtained from beginContext.
//
// The profile's linked secret is deleted as well, but only when the profile
// has a secret and no linked SSH credential. The lookup's Scan error is
// returned unchanged, so a missing profile surfaces as sql.ErrNoRows. The
// transaction is rolled back on every failure path and committed on success
// (both only when owned).
func (s *Store) DeleteSSHAccessProfile(ctx context.Context, id string) error {
	var (
		tx           txExecutor
		owned        bool
		err          error
		secretID     sql.NullString
		credentialID sql.NullString
	)
	tx, owned, err = s.beginContext(ctx)
	if err != nil {
		return err
	}

	profileID := strings.TrimSpace(id)
	err = tx.QueryRow(`SELECT COALESCE(sec.public_id, ''), COALESCE(c.public_id, '')
FROM ssh_access_profiles p
LEFT JOIN ssh_secrets sec ON sec.id = p.secret_id
LEFT JOIN ssh_credentials c ON c.id = p.ssh_credential_id
WHERE p.public_id = ?`, profileID).Scan(&secretID, &credentialID)
	if err != nil {
		rollbackIfOwned(tx, owned)
		return err
	}

	if _, err = tx.Exec(`DELETE FROM ssh_access_profiles WHERE public_id = ?`, profileID); err != nil {
		rollbackIfOwned(tx, owned)
		return err
	}

	secretPublicID := strings.TrimSpace(secretID.String)
	hasSecret := secretID.Valid && secretPublicID != ""
	hasCredential := credentialID.Valid && strings.TrimSpace(credentialID.String) != ""
	if hasSecret && !hasCredential {
		if _, err = tx.Exec(`DELETE FROM ssh_secrets WHERE public_id = ?`, secretPublicID); err != nil {
			rollbackIfOwned(tx, owned)
			return err
		}
	}
	return commitIfOwned(tx, owned)
}
// ListSSHAccessProfilesForUser returns the access profiles that the user with
// public id userID can see, ordered by server name and then profile name.
//
// A profile is included when either:
//   - its owner scope is SSHOwnerScopeUser and the user owns it; or
//   - its owner scope is SSHOwnerScopeAdminShared, both the profile and its
//     server are enabled, the current time lies inside the profile's
//     valid_after / valid_before window (a bound of 0 disables that bound),
//     and one of its targets is the user directly or a non-disabled user group
//     that either has scope 'all_users' or contains the user as a member.
//
// Each returned profile carries its server, its server group (if any), its
// decoded tag and grant-id lists and its targets. Editable is true only for
// profiles owned by the user under SSHOwnerScopeUser.
func (s *Store) ListSSHAccessProfilesForUser(userID string) ([]models.SSHAccessProfile, error) {
	var rows *sql.Rows
	var err error
	var items []models.SSHAccessProfile
	var grantIDsJSON string
	var serverTagsJSON string
	var groupTagsJSON string
	var trimmedUserID string
	trimmedUserID = strings.TrimSpace(userID)
	rows, err = s.Query(`SELECT DISTINCT
p.public_id,
COALESCE(s.public_id, ''),
COALESCE(p.server_target_type, 'server'),
COALESCE(g.public_id, ''),
COALESCE(g.name, ''),
COALESCE(g.description, ''),
COALESCE(g.enabled, 0),
COALESCE(g.created_by_kind, ''),
COALESCE(g_cbu.public_id, g_cbp.public_id, ''),
COALESCE(g.created_by_subject_name, ''),
COALESCE(g.created_at, 0),
COALESCE(g.updated_at, 0),
COALESCE(g.tags_json, '[]'),
p.name,
p.description,
p.remote_username,
p.auth_method,
p.second_factor_mode,
p.owner_scope,
COALESCE(pou.public_id, ''),
p.allow_user_edit,
p.enabled,
COALESCE(sec.public_id, ''),
COALESCE(c.public_id, ''),
COALESCE(c.name, ''),
p.auth_public_key,
p.auth_public_key_fingerprint,
COALESCE(ca.public_id, ''),
p.ssh_principal_grant_ids_json,
p.default_cert_valid_seconds,
p.max_cert_valid_seconds,
p.valid_after,
p.valid_before,
p.created_by_kind,
COALESCE(p_cbu.public_id, p_cbp.public_id, ''),
p.created_by_subject_name,
p.created_at,
p.updated_at,
s.public_id,
s.name,
s.host,
s.port,
s.description,
s.tags_json,
s.enabled,
s.host_key_policy,
s.owner_scope,
COALESCE(sou.public_id, ''),
s.created_by_kind,
COALESCE(s_cbu.public_id, s_cbp.public_id, ''),
s.created_by_subject_name,
s.created_at,
s.updated_at
FROM ssh_access_profiles p
JOIN ssh_servers s ON s.id = p.server_id
LEFT JOIN users sou ON sou.id = s.owner_user_id
LEFT JOIN users s_cbu ON s_cbu.id = s.created_by_user_id
LEFT JOIN service_principals s_cbp ON s_cbp.id = s.created_by_principal_id
LEFT JOIN ssh_server_groups g ON g.id = p.server_group_id
LEFT JOIN users g_cbu ON g_cbu.id = g.created_by_user_id
LEFT JOIN service_principals g_cbp ON g_cbp.id = g.created_by_principal_id
LEFT JOIN users pou ON pou.id = p.owner_user_id
LEFT JOIN users p_cbu ON p_cbu.id = p.created_by_user_id
LEFT JOIN service_principals p_cbp ON p_cbp.id = p.created_by_principal_id
LEFT JOIN ssh_secrets sec ON sec.id = p.secret_id
LEFT JOIN ssh_credentials c ON c.id = p.ssh_credential_id
LEFT JOIN ssh_user_cas ca ON ca.id = p.ssh_user_ca_id
LEFT JOIN ssh_access_profile_targets t ON t.profile_id = p.id
LEFT JOIN user_groups ug ON t.target_type = 'group' AND ug.id = t.target_id
LEFT JOIN user_group_members gm ON t.target_type = 'group' AND gm.group_id = ug.id
LEFT JOIN users gu ON gm.user_id = gu.id
LEFT JOIN users tu ON t.target_type = 'user' AND tu.id = t.target_id
WHERE (
(p.owner_scope = ? AND pou.public_id = ?)
OR
(
p.owner_scope = ?
AND p.enabled = 1
AND s.enabled = 1
AND (p.valid_after = 0 OR p.valid_after <= CAST(strftime('%s','now') AS INTEGER))
AND (p.valid_before = 0 OR p.valid_before >= CAST(strftime('%s','now') AS INTEGER))
AND (
(t.target_type = 'user' AND tu.public_id = ?)
OR
(t.target_type = 'group' AND ug.disabled = 0 AND (ug.scope = 'all_users' OR gu.public_id = ?))
)
)
)
ORDER BY s.name, p.name`, SSHOwnerScopeUser, trimmedUserID, SSHOwnerScopeAdminShared, trimmedUserID, trimmedUserID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	for rows.Next() {
		// Use a fresh value for every row. ServerGroup.ID is only assigned
		// when the row has a server group, so reusing one struct across rows
		// would leak the previous row's group id into a group-less profile.
		var item models.SSHAccessProfile
		err = rows.Scan(
			&item.ID,
			&item.ServerID,
			&item.ServerTargetType,
			&item.ServerGroupID,
			&item.ServerGroup.Name,
			&item.ServerGroup.Description,
			&item.ServerGroup.Enabled,
			&item.ServerGroup.CreatedByKind,
			&item.ServerGroup.CreatedBySubjectID,
			&item.ServerGroup.CreatedBySubjectName,
			&item.ServerGroup.CreatedAt,
			&item.ServerGroup.UpdatedAt,
			&groupTagsJSON,
			&item.Name,
			&item.Description,
			&item.RemoteUsername,
			&item.AuthMethod,
			&item.SecondFactorMode,
			&item.OwnerScope,
			&item.OwnerUserID,
			&item.AllowUserEdit,
			&item.Enabled,
			&item.SecretID,
			&item.SSHCredentialID,
			&item.SSHCredentialName,
			&item.AuthPublicKey,
			&item.AuthPublicKeyFingerprint,
			&item.SSHUserCAID,
			&grantIDsJSON,
			&item.DefaultCertValidSeconds,
			&item.MaxCertValidSeconds,
			&item.ValidAfter,
			&item.ValidBefore,
			&item.CreatedByKind,
			&item.CreatedBySubjectID,
			&item.CreatedBySubjectName,
			&item.CreatedAt,
			&item.UpdatedAt,
			&item.Server.ID,
			&item.Server.Name,
			&item.Server.Host,
			&item.Server.Port,
			&item.Server.Description,
			&serverTagsJSON,
			&item.Server.Enabled,
			&item.Server.HostKeyPolicy,
			&item.Server.OwnerScope,
			&item.Server.OwnerUserID,
			&item.Server.CreatedByKind,
			&item.Server.CreatedBySubjectID,
			&item.Server.CreatedBySubjectName,
			&item.Server.CreatedAt,
			&item.Server.UpdatedAt,
		)
		if err != nil {
			return nil, err
		}
		if item.ServerGroupID != "" {
			item.ServerGroup.ID = item.ServerGroupID
		}
		item.SSHPrincipalGrantIDs, err = decodeStringList(grantIDsJSON)
		if err != nil {
			return nil, err
		}
		item.Server.Tags, err = decodeStringList(serverTagsJSON)
		if err != nil {
			return nil, err
		}
		item.ServerGroup.Tags, err = decodeStringList(groupTagsJSON)
		if err != nil {
			return nil, err
		}
		item.Targets, err = s.listSSHAccessProfileTargets(item.ID)
		if err != nil {
			return nil, err
		}
		item.Editable = item.OwnerScope == SSHOwnerScopeUser && item.OwnerUserID == trimmedUserID
		items = append(items, item)
	}
	err = rows.Err()
	if err != nil {
		return nil, err
	}
	return items, nil
}
// GetSSHAccessProfileForUser returns the access profile with public id id if
// it is visible to the user with public id userID.
//
// Visibility follows the same rule as ListSSHAccessProfilesForUser: the user
// owns the profile under SSHOwnerScopeUser, or the profile is
// SSHOwnerScopeAdminShared, enabled, inside its validity window, bound to an
// enabled server, and targeted at the user directly or through a non-disabled
// user group (scope 'all_users' or membership).
//
// When no row matches, the error from Scan is returned. The returned profile
// carries its server, server group (including decoded tags), grant ids and
// targets. Editable is true only for profiles the user owns.
func (s *Store) GetSSHAccessProfileForUser(userID string, id string) (models.SSHAccessProfile, error) {
	var row *sql.Row
	var item models.SSHAccessProfile
	var grantIDsJSON string
	var serverTagsJSON string
	var groupTagsJSON string
	var trimmedUserID string
	var trimmedID string
	var err error
	// The admin-shared branch of the query only matches profiles whose server
	// is enabled (s.enabled = 1). Remove that condition from the query if
	// profiles on disabled servers should be returned as well.
	trimmedUserID = strings.TrimSpace(userID)
	trimmedID = strings.TrimSpace(id)
	row = s.QueryRow(`SELECT DISTINCT
p.public_id,
COALESCE(s.public_id, ''),
COALESCE(p.server_target_type, 'server'),
COALESCE(g.public_id, ''),
COALESCE(g.name, ''),
COALESCE(g.description, ''),
COALESCE(g.enabled, 0),
COALESCE(g.created_by_kind, ''),
COALESCE(g_cbu.public_id, g_cbp.public_id, ''),
COALESCE(g.created_by_subject_name, ''),
COALESCE(g.created_at, 0),
COALESCE(g.updated_at, 0),
COALESCE(g.tags_json, '[]'),
p.name,
p.description,
p.remote_username,
p.auth_method,
p.second_factor_mode,
p.owner_scope,
COALESCE(pou.public_id, ''),
p.allow_user_edit,
p.enabled,
COALESCE(sec.public_id, ''),
COALESCE(c.public_id, ''),
COALESCE(c.name, ''),
p.auth_public_key,
p.auth_public_key_fingerprint,
COALESCE(ca.public_id, ''),
p.ssh_principal_grant_ids_json,
p.default_cert_valid_seconds,
p.max_cert_valid_seconds,
p.valid_after,
p.valid_before,
p.created_by_kind,
COALESCE(p_cbu.public_id, p_cbp.public_id, ''),
p.created_by_subject_name,
p.created_at,
p.updated_at,
s.public_id,
s.name,
s.host,
s.port,
s.description,
s.tags_json,
s.enabled,
s.host_key_policy,
s.owner_scope,
COALESCE(sou.public_id, ''),
s.created_by_kind,
COALESCE(s_cbu.public_id, s_cbp.public_id, ''),
s.created_by_subject_name,
s.created_at,
s.updated_at
FROM ssh_access_profiles p
JOIN ssh_servers s ON s.id = p.server_id
LEFT JOIN users sou ON sou.id = s.owner_user_id
LEFT JOIN users s_cbu ON s_cbu.id = s.created_by_user_id
LEFT JOIN service_principals s_cbp ON s_cbp.id = s.created_by_principal_id
LEFT JOIN ssh_server_groups g ON g.id = p.server_group_id
LEFT JOIN users g_cbu ON g_cbu.id = g.created_by_user_id
LEFT JOIN service_principals g_cbp ON g_cbp.id = g.created_by_principal_id
LEFT JOIN users pou ON pou.id = p.owner_user_id
LEFT JOIN users p_cbu ON p_cbu.id = p.created_by_user_id
LEFT JOIN service_principals p_cbp ON p_cbp.id = p.created_by_principal_id
LEFT JOIN ssh_secrets sec ON sec.id = p.secret_id
LEFT JOIN ssh_credentials c ON c.id = p.ssh_credential_id
LEFT JOIN ssh_user_cas ca ON ca.id = p.ssh_user_ca_id
LEFT JOIN ssh_access_profile_targets t ON t.profile_id = p.id
LEFT JOIN user_groups ug ON t.target_type = 'group' AND ug.id = t.target_id
LEFT JOIN user_group_members gm ON t.target_type = 'group' AND gm.group_id = ug.id
LEFT JOIN users gu ON gm.user_id = gu.id
LEFT JOIN users tu ON t.target_type = 'user' AND tu.id = t.target_id
WHERE p.public_id = ?
AND (
(p.owner_scope = ? AND pou.public_id = ?)
OR
(
p.owner_scope = ?
AND p.enabled = 1
AND s.enabled = 1
AND (p.valid_after = 0 OR p.valid_after <= CAST(strftime('%s','now') AS INTEGER))
AND (p.valid_before = 0 OR p.valid_before >= CAST(strftime('%s','now') AS INTEGER))
AND (
(t.target_type = 'user' AND tu.public_id = ?)
OR
(t.target_type = 'group' AND ug.disabled = 0 AND (ug.scope = 'all_users' OR gu.public_id = ?))
)
)
)
LIMIT 1`, trimmedID, SSHOwnerScopeUser, trimmedUserID, SSHOwnerScopeAdminShared, trimmedUserID, trimmedUserID)
	err = row.Scan(
		&item.ID,
		&item.ServerID,
		&item.ServerTargetType,
		&item.ServerGroupID,
		&item.ServerGroup.Name,
		&item.ServerGroup.Description,
		&item.ServerGroup.Enabled,
		&item.ServerGroup.CreatedByKind,
		&item.ServerGroup.CreatedBySubjectID,
		&item.ServerGroup.CreatedBySubjectName,
		&item.ServerGroup.CreatedAt,
		&item.ServerGroup.UpdatedAt,
		&groupTagsJSON,
		&item.Name,
		&item.Description,
		&item.RemoteUsername,
		&item.AuthMethod,
		&item.SecondFactorMode,
		&item.OwnerScope,
		&item.OwnerUserID,
		&item.AllowUserEdit,
		&item.Enabled,
		&item.SecretID,
		&item.SSHCredentialID,
		&item.SSHCredentialName,
		&item.AuthPublicKey,
		&item.AuthPublicKeyFingerprint,
		&item.SSHUserCAID,
		&grantIDsJSON,
		&item.DefaultCertValidSeconds,
		&item.MaxCertValidSeconds,
		&item.ValidAfter,
		&item.ValidBefore,
		&item.CreatedByKind,
		&item.CreatedBySubjectID,
		&item.CreatedBySubjectName,
		&item.CreatedAt,
		&item.UpdatedAt,
		&item.Server.ID,
		&item.Server.Name,
		&item.Server.Host,
		&item.Server.Port,
		&item.Server.Description,
		&serverTagsJSON,
		&item.Server.Enabled,
		&item.Server.HostKeyPolicy,
		&item.Server.OwnerScope,
		&item.Server.OwnerUserID,
		&item.Server.CreatedByKind,
		&item.Server.CreatedBySubjectID,
		&item.Server.CreatedBySubjectName,
		&item.Server.CreatedAt,
		&item.Server.UpdatedAt,
	)
	if err != nil {
		return item, err
	}
	if item.ServerGroupID != "" {
		item.ServerGroup.ID = item.ServerGroupID
	}
	item.SSHPrincipalGrantIDs, err = decodeStringList(grantIDsJSON)
	if err != nil {
		return item, err
	}
	item.Server.Tags, err = decodeStringList(serverTagsJSON)
	if err != nil {
		return item, err
	}
	// The group tags column is selected and scanned, so decode it as well;
	// otherwise ServerGroup.Tags is always empty here while the list variant
	// populates it.
	item.ServerGroup.Tags, err = decodeStringList(groupTagsJSON)
	if err != nil {
		return item, err
	}
	item.Targets, err = s.listSSHAccessProfileTargets(item.ID)
	if err != nil {
		return item, err
	}
	item.Editable = item.OwnerScope == SSHOwnerScopeUser && item.OwnerUserID == trimmedUserID
	return item, nil
}
// GetSSHAccessProfileServerForUser returns the server attached to access
// profile id when all of the following hold:
//   - GetSSHAccessProfileForUser resolves the profile for userID;
//   - the profile's server target type is SSHAccessProfileServerTargetTypeServer;
//   - the profile's server public id equals serverID (whitespace-trimmed).
//
// Lookup errors are returned unchanged; a target-type or server mismatch is
// reported as sql.ErrNoRows.
func (s *Store) GetSSHAccessProfileServerForUser(userID string, id string, serverID string) (models.SSHServer, error) {
	profile, err := s.GetSSHAccessProfileForUser(userID, id)
	if err != nil {
		return models.SSHServer{}, err
	}

	wantServerID := strings.TrimSpace(serverID)
	boundToServer := profile.ServerTargetType == SSHAccessProfileServerTargetTypeServer &&
		profile.ServerID == wantServerID
	if !boundToServer {
		return models.SSHServer{}, sql.ErrNoRows
	}
	return profile.Server, nil
}
// ListSSHServersForAccessProfileServerGroupForUser lists the servers of the
// server group attached to access profile id when all of the following hold:
//   - GetSSHAccessProfileForUser resolves the profile for userID;
//   - the profile's server target type is SSHAccessProfileServerTargetTypeGroup;
//   - the profile's server group public id equals serverGroupID (trimmed).
//
// Lookup errors are returned unchanged; a target-type or group mismatch is
// reported as sql.ErrNoRows. The listing itself is delegated to
// ListSSHServersForGroup.
func (s *Store) ListSSHServersForAccessProfileServerGroupForUser(userID string, id string, serverGroupID string) ([]models.SSHServer, error) {
	profile, err := s.GetSSHAccessProfileForUser(userID, id)
	if err != nil {
		return nil, err
	}

	wantGroupID := strings.TrimSpace(serverGroupID)
	boundToGroup := profile.ServerTargetType == SSHAccessProfileServerTargetTypeGroup &&
		profile.ServerGroupID == wantGroupID
	if !boundToGroup {
		return nil, sql.ErrNoRows
	}
	return s.ListSSHServersForGroup(wantGroupID)
}
// listSSHAccessProfileTargets returns the user and group targets of the
// access profile with public id profileID, ordered by target type and then
// target id.
//
// For each target the query resolves the public id and a display name of the
// referenced user or user group, and reports the target as active only when
// the referenced row exists and is not disabled.
func (s *Store) listSSHAccessProfileTargets(profileID string) ([]models.SSHAccessProfileTarget, error) {
	const selectSQL = `SELECT
p.public_id,
t.target_type,
CASE WHEN t.target_type = 'user' THEN COALESCE(u.public_id, '') ELSE COALESCE(g.public_id, '') END,
CASE
WHEN t.target_type = 'user' THEN COALESCE(CASE WHEN u.display_name != '' THEN u.display_name || ' (' || u.username || ')' ELSE u.username END, '')
ELSE COALESCE(g.name, '')
END,
CASE
WHEN t.target_type = 'user' THEN CASE WHEN u.public_id IS NOT NULL AND u.disabled = 0 THEN 1 ELSE 0 END
ELSE CASE WHEN g.public_id IS NOT NULL AND g.disabled = 0 THEN 1 ELSE 0 END
END,
t.created_at
FROM ssh_access_profile_targets t
JOIN ssh_access_profiles p ON p.id = t.profile_id
LEFT JOIN users u ON t.target_type = 'user' AND u.id = t.target_id
LEFT JOIN user_groups g ON t.target_type = 'group' AND g.id = t.target_id
WHERE p.public_id = ?
ORDER BY t.target_type, target_id`

	rows, err := s.Query(selectSQL, strings.TrimSpace(profileID))
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var targets []models.SSHAccessProfileTarget
	for rows.Next() {
		var target models.SSHAccessProfileTarget
		if err = rows.Scan(
			&target.ProfileID,
			&target.TargetType,
			&target.TargetID,
			&target.TargetName,
			&target.TargetActive,
			&target.CreatedAt,
		); err != nil {
			return nil, err
		}
		targets = append(targets, target)
	}
	if err = rows.Err(); err != nil {
		return nil, err
	}
	return targets, nil
}
// GetSSHSecret loads the SSH secret with the given public id, including its
// kind, payload, password, metadata JSON and creator information.
//
// When no row matches, the error from Scan is returned together with the
// (unpopulated) secret value.
func (s *Store) GetSSHSecret(id string) (models.SSHSecret, error) {
	const selectSQL = `SELECT sec.public_id, sec.kind, sec.payload, sec.password, sec.metadata_json, sec.created_by_kind, COALESCE(sec_cbu.public_id, sec_cbp.public_id, ''), sec.created_by_subject_name, sec.created_at, sec.updated_at
FROM ssh_secrets sec
LEFT JOIN users sec_cbu ON sec_cbu.id = sec.created_by_user_id
LEFT JOIN service_principals sec_cbp ON sec_cbp.id = sec.created_by_principal_id
WHERE sec.public_id = ?`

	var secret models.SSHSecret
	err := s.QueryRow(selectSQL, strings.TrimSpace(id)).Scan(
		&secret.ID,
		&secret.Kind,
		&secret.Payload,
		&secret.Password,
		&secret.MetadataJSON,
		&secret.CreatedByKind,
		&secret.CreatedBySubjectID,
		&secret.CreatedBySubjectName,
		&secret.CreatedAt,
		&secret.UpdatedAt,
	)
	return secret, err
}
// ListSSHServerHostKeys returns the host keys recorded for the server with
// public id serverID, ordered by creation time and then fingerprint.
func (s *Store) ListSSHServerHostKeys(serverID string) ([]models.SSHServerHostKey, error) {
	const selectSQL = `SELECT hk.public_id, s.public_id, hk.algorithm, hk.public_key, hk.fingerprint, hk.created_at
FROM ssh_server_host_keys hk
JOIN ssh_servers s ON s.id = hk.server_id
WHERE s.public_id = ?
ORDER BY hk.created_at, hk.fingerprint`

	rows, err := s.Query(selectSQL, strings.TrimSpace(serverID))
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var keys []models.SSHServerHostKey
	for rows.Next() {
		var key models.SSHServerHostKey
		if err = rows.Scan(
			&key.ID,
			&key.ServerID,
			&key.Algorithm,
			&key.PublicKey,
			&key.Fingerprint,
			&key.CreatedAt,
		); err != nil {
			return nil, err
		}
		keys = append(keys, key)
	}
	if err = rows.Err(); err != nil {
		return nil, err
	}
	return keys, nil
}
// CreateSSHServerHostKey stores a host key for the server referenced by
// item.ServerID (a server public id).
//
// A new id is generated with util.NewID when item.ID is blank, and CreatedAt
// is always overwritten with the current UTC Unix time. String fields are
// whitespace-trimmed for the insert only; the returned item keeps the
// caller's values apart from ID and CreatedAt.
func (s *Store) CreateSSHServerHostKey(item models.SSHServerHostKey) (models.SSHServerHostKey, error) {
	const insertSQL = `INSERT INTO ssh_server_host_keys (public_id, server_id, algorithm, public_key, fingerprint, created_at)
VALUES (?, (SELECT id FROM ssh_servers WHERE public_id = ?), ?, ?, ?, ?)`

	if strings.TrimSpace(item.ID) == "" {
		newID, err := util.NewID()
		item.ID = newID
		if err != nil {
			return item, err
		}
	}
	item.CreatedAt = time.Now().UTC().Unix()

	_, err := s.Exec(insertSQL,
		item.ID,
		strings.TrimSpace(item.ServerID),
		strings.TrimSpace(item.Algorithm),
		strings.TrimSpace(item.PublicKey),
		strings.TrimSpace(item.Fingerprint),
		item.CreatedAt,
	)
	return item, err
}
// DeleteSSHServerHostKey removes the host key with public id hostKeyID, but
// only if it belongs to the server with public id serverID. Both ids are
// whitespace-trimmed. Deleting a non-matching pair is not an error.
func (s *Store) DeleteSSHServerHostKey(serverID string, hostKeyID string) error {
	const deleteSQL = `DELETE FROM ssh_server_host_keys WHERE server_id = (SELECT id FROM ssh_servers WHERE public_id = ?) AND public_id = ?`

	server := strings.TrimSpace(serverID)
	hostKey := strings.TrimSpace(hostKeyID)
	_, err := s.Exec(deleteSQL, server, hostKey)
	return err
}
// CreateSSHSession records a new SSH session row and returns the item with
// the defaults that were applied.
//
// Defaults: a generated id when item.ID is blank, status "pending", terminal
// "xterm-256color", 80 columns and 24 rows. StartedAt is always set to the
// current UTC Unix time. Profile, server and user are given as public ids and
// resolved to internal ids inside the INSERT. String values are
// whitespace-trimmed for the insert only.
func (s *Store) CreateSSHSession(item models.SSHSession) (models.SSHSession, error) {
	const insertSQL = `INSERT INTO ssh_sessions (
public_id,
profile_id,
server_id,
user_id,
username,
remote_username,
auth_method,
second_factor_mode,
host,
port,
status,
host_key_fingerprint,
requested_term,
requested_cols,
requested_rows,
started_at,
connected_at,
ended_at,
remote_addr,
user_agent,
error
) VALUES (?, (SELECT id FROM ssh_access_profiles WHERE public_id = ?), (SELECT id FROM ssh_servers WHERE public_id = ?), (SELECT id FROM users WHERE public_id = ?), ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`

	if strings.TrimSpace(item.ID) == "" {
		newID, err := util.NewID()
		item.ID = newID
		if err != nil {
			return item, err
		}
	}

	startedAt := time.Now().UTC().Unix()

	// Fill in defaults for anything the caller left unset.
	if item.Status == "" {
		item.Status = "pending"
	}
	if item.RequestedTerm == "" {
		item.RequestedTerm = "xterm-256color"
	}
	if item.RequestedCols <= 0 {
		item.RequestedCols = 80
	}
	if item.RequestedRows <= 0 {
		item.RequestedRows = 24
	}
	item.StartedAt = startedAt

	params := []any{
		item.ID,
		strings.TrimSpace(item.ProfileID),
		strings.TrimSpace(item.ServerID),
		strings.TrimSpace(item.UserID),
		strings.TrimSpace(item.Username),
		strings.TrimSpace(item.RemoteUsername),
		strings.TrimSpace(item.AuthMethod),
		strings.TrimSpace(item.SecondFactorMode),
		strings.TrimSpace(item.Host),
		item.Port,
		strings.TrimSpace(item.Status),
		strings.TrimSpace(item.HostKeyFingerprint),
		strings.TrimSpace(item.RequestedTerm),
		item.RequestedCols,
		item.RequestedRows,
		item.StartedAt,
		item.ConnectedAt,
		item.EndedAt,
		strings.TrimSpace(item.RemoteAddr),
		strings.TrimSpace(item.UserAgent),
		strings.TrimSpace(item.Error),
	}
	_, err := s.Exec(insertSQL, params...)
	return item, err
}
// ExpireStalePendingSSHSessions marks every session that is still 'pending'
// and whose started_at is earlier than cutoff as 'closed'. The cutoff value is
// stored as ended_at and the trimmed reason is stored as the session error.
func (s *Store) ExpireStalePendingSSHSessions(cutoff int64, reason string) error {
	const updateSQL = `UPDATE ssh_sessions
SET status = 'closed', ended_at = ?, error = ?
WHERE status = 'pending' AND started_at < ?`

	_, err := s.Exec(updateSQL, cutoff, strings.TrimSpace(reason), cutoff)
	return err
}
// CountPendingSSHSessionsForUser returns how many sessions with status
// 'pending' belong to the user with public id userID. It returns 0 together
// with the error when the count cannot be read.
func (s *Store) CountPendingSSHSessionsForUser(userID string) (int, error) {
	const countSQL = `SELECT COUNT(*)
FROM ssh_sessions
WHERE user_id = (SELECT id FROM users WHERE public_id = ?) AND status = 'pending'`

	var pending int
	if err := s.QueryRow(countSQL, strings.TrimSpace(userID)).Scan(&pending); err != nil {
		return 0, err
	}
	return pending, nil
}
// GetSSHSession loads the SSH session with the given public id. Profile,
// server and user references are returned as public ids (empty when the
// referenced row no longer resolves), along with the profile and server
// names.
//
// When no row matches, the error from Scan is returned.
func (s *Store) GetSSHSession(id string) (models.SSHSession, error) {
	const selectSQL = `SELECT ss.public_id, COALESCE(p.public_id, ''), COALESCE(p.name, ''), COALESCE(s.public_id, ''), COALESCE(s.name, ''), COALESCE(u.public_id, ''), ss.username, ss.remote_username, ss.auth_method, ss.second_factor_mode, ss.host, ss.port, ss.status, ss.host_key_fingerprint, ss.requested_term, ss.requested_cols, ss.requested_rows, ss.started_at, ss.connected_at, ss.ended_at, ss.remote_addr, ss.user_agent, ss.error
FROM ssh_sessions ss
LEFT JOIN ssh_access_profiles p ON p.id = ss.profile_id
LEFT JOIN ssh_servers s ON s.id = ss.server_id
LEFT JOIN users u ON u.id = ss.user_id
WHERE ss.public_id = ?`

	var session models.SSHSession
	err := s.QueryRow(selectSQL, strings.TrimSpace(id)).Scan(
		&session.ID,
		&session.ProfileID,
		&session.ProfileName,
		&session.ServerID,
		&session.ServerName,
		&session.UserID,
		&session.Username,
		&session.RemoteUsername,
		&session.AuthMethod,
		&session.SecondFactorMode,
		&session.Host,
		&session.Port,
		&session.Status,
		&session.HostKeyFingerprint,
		&session.RequestedTerm,
		&session.RequestedCols,
		&session.RequestedRows,
		&session.StartedAt,
		&session.ConnectedAt,
		&session.EndedAt,
		&session.RemoteAddr,
		&session.UserAgent,
		&session.Error,
	)
	return session, err
}
// ListSSHSessionsForUser returns the first page of sessions for the user with
// public id userID. It delegates to ListSSHSessionsForUserFiltered with offset
// 0 and no search text or status filter, and discards the "has more" flag.
func (s *Store) ListSSHSessionsForUser(userID string, limit int) ([]models.SSHSession, error) {
	sessions, _, err := s.ListSSHSessionsForUserFiltered(userID, limit, 0, "", "")
	if err != nil {
		return nil, err
	}
	return sessions, nil
}
// ListSSHSessionsForUserFiltered returns one page of sessions belonging to
// the user with public id userID, newest first (by started_at).
//
// limit defaults to 20 when non-positive and is capped at 200; a negative
// offset is treated as 0. A non-empty query is matched as a substring (SQL
// LIKE) against the session's identifying and descriptive columns; a
// non-empty status (lower-cased) must equal the session status exactly.
//
// The boolean result reports whether more rows exist beyond the returned
// page; it is detected by fetching limit+1 rows and trimming the extra one.
func (s *Store) ListSSHSessionsForUserFiltered(userID string, limit int, offset int, query string, status string) ([]models.SSHSession, bool, error) {
	const searchClause = "(ss.public_id LIKE ? OR COALESCE(p.public_id, '') LIKE ? OR COALESCE(p.name, '') LIKE ? OR COALESCE(s.public_id, '') LIKE ? OR COALESCE(s.name, '') LIKE ? OR COALESCE(u.public_id, '') LIKE ? OR ss.username LIKE ? OR ss.remote_username LIKE ? OR ss.host LIKE ? OR ss.auth_method LIKE ? OR ss.status LIKE ? OR ss.host_key_fingerprint LIKE ? OR ss.error LIKE ?)"

	switch {
	case limit <= 0:
		limit = 20
	case limit > 200:
		limit = 200
	}
	if offset < 0 {
		offset = 0
	}
	search := strings.TrimSpace(query)
	wantStatus := strings.ToLower(strings.TrimSpace(status))

	conditions := []string{"u.public_id = ?"}
	params := []any{strings.TrimSpace(userID)}
	if search != "" {
		pattern := "%" + search + "%"
		conditions = append(conditions, searchClause)
		// One bound pattern per placeholder in the search clause.
		for i, n := 0, strings.Count(searchClause, "?"); i < n; i++ {
			params = append(params, pattern)
		}
	}
	if wantStatus != "" {
		conditions = append(conditions, "ss.status = ?")
		params = append(params, wantStatus)
	}

	statement := fmt.Sprintf(`SELECT ss.public_id, COALESCE(p.public_id, ''), COALESCE(p.name, ''), COALESCE(s.public_id, ''), COALESCE(s.name, ''), COALESCE(u.public_id, ''), ss.username, ss.remote_username, ss.auth_method, ss.second_factor_mode, ss.host, ss.port, ss.status, ss.host_key_fingerprint, ss.requested_term, ss.requested_cols, ss.requested_rows, ss.started_at, ss.connected_at, ss.ended_at, ss.remote_addr, ss.user_agent, ss.error
FROM ssh_sessions ss
LEFT JOIN ssh_access_profiles p ON p.id = ss.profile_id
LEFT JOIN ssh_servers s ON s.id = ss.server_id
LEFT JOIN users u ON u.id = ss.user_id
WHERE %s
ORDER BY ss.started_at DESC
LIMIT ? OFFSET ?`, strings.Join(conditions, " AND "))
	// Ask for one row more than the page size to learn whether another page exists.
	params = append(params, limit+1, offset)

	rows, err := s.Query(statement, params...)
	if err != nil {
		return nil, false, err
	}
	defer rows.Close()

	var sessions []models.SSHSession
	for rows.Next() {
		var session models.SSHSession
		if err = rows.Scan(
			&session.ID,
			&session.ProfileID,
			&session.ProfileName,
			&session.ServerID,
			&session.ServerName,
			&session.UserID,
			&session.Username,
			&session.RemoteUsername,
			&session.AuthMethod,
			&session.SecondFactorMode,
			&session.Host,
			&session.Port,
			&session.Status,
			&session.HostKeyFingerprint,
			&session.RequestedTerm,
			&session.RequestedCols,
			&session.RequestedRows,
			&session.StartedAt,
			&session.ConnectedAt,
			&session.EndedAt,
			&session.RemoteAddr,
			&session.UserAgent,
			&session.Error,
		); err != nil {
			return nil, false, err
		}
		sessions = append(sessions, session)
	}
	if err = rows.Err(); err != nil {
		return nil, false, err
	}

	hasMore := len(sessions) > limit
	if hasMore {
		sessions = sessions[:limit]
	}
	return sessions, hasMore, nil
}
// CountSSHSessionsForUserFiltered counts the sessions of the user with public
// id userID that match the optional search text and status filter. The
// filters are the same ones ListSSHSessionsForUserFiltered applies: a
// substring (SQL LIKE) match for query and an exact match on the lower-cased
// status. It returns 0 together with the error when the count cannot be read.
func (s *Store) CountSSHSessionsForUserFiltered(userID string, query string, status string) (int, error) {
	const searchClause = "(ss.public_id LIKE ? OR COALESCE(p.public_id, '') LIKE ? OR COALESCE(p.name, '') LIKE ? OR COALESCE(s.public_id, '') LIKE ? OR COALESCE(s.name, '') LIKE ? OR COALESCE(u.public_id, '') LIKE ? OR ss.username LIKE ? OR ss.remote_username LIKE ? OR ss.host LIKE ? OR ss.auth_method LIKE ? OR ss.status LIKE ? OR ss.host_key_fingerprint LIKE ? OR ss.error LIKE ?)"

	search := strings.TrimSpace(query)
	wantStatus := strings.ToLower(strings.TrimSpace(status))

	conditions := []string{"u.public_id = ?"}
	params := []any{strings.TrimSpace(userID)}
	if search != "" {
		pattern := "%" + search + "%"
		conditions = append(conditions, searchClause)
		// One bound pattern per placeholder in the search clause.
		for i, n := 0, strings.Count(searchClause, "?"); i < n; i++ {
			params = append(params, pattern)
		}
	}
	if wantStatus != "" {
		conditions = append(conditions, "ss.status = ?")
		params = append(params, wantStatus)
	}

	statement := fmt.Sprintf(`SELECT COUNT(*)
FROM ssh_sessions ss
LEFT JOIN ssh_access_profiles p ON p.id = ss.profile_id
LEFT JOIN ssh_servers s ON s.id = ss.server_id
LEFT JOIN users u ON u.id = ss.user_id
WHERE %s`, strings.Join(conditions, " AND "))

	var total int
	if err := s.QueryRow(statement, params...).Scan(&total); err != nil {
		return 0, err
	}
	return total, nil
}
// ListSSHSessionsFiltered returns one page of sessions across all users,
// newest first (by started_at).
//
// limit defaults to 20 when non-positive and is capped at 200; a negative
// offset is treated as 0. A non-empty query is matched as a substring (SQL
// LIKE) against the session's identifying and descriptive columns; a
// non-empty status (lower-cased) must equal the session status exactly. With
// neither filter set, no WHERE clause is emitted.
//
// The boolean result reports whether more rows exist beyond the returned
// page; it is detected by fetching limit+1 rows and trimming the extra one.
func (s *Store) ListSSHSessionsFiltered(limit int, offset int, query string, status string) ([]models.SSHSession, bool, error) {
	const searchClause = "(ss.public_id LIKE ? OR COALESCE(p.public_id, '') LIKE ? OR COALESCE(p.name, '') LIKE ? OR COALESCE(s.public_id, '') LIKE ? OR COALESCE(s.name, '') LIKE ? OR COALESCE(u.public_id, '') LIKE ? OR ss.username LIKE ? OR ss.remote_username LIKE ? OR ss.host LIKE ? OR ss.auth_method LIKE ? OR ss.status LIKE ? OR ss.host_key_fingerprint LIKE ? OR ss.error LIKE ?)"
	const selectSQL = `SELECT ss.public_id, COALESCE(p.public_id, ''), COALESCE(p.name, ''), COALESCE(s.public_id, ''), COALESCE(s.name, ''), COALESCE(u.public_id, ''), ss.username, ss.remote_username, ss.auth_method, ss.second_factor_mode, ss.host, ss.port, ss.status, ss.host_key_fingerprint, ss.requested_term, ss.requested_cols, ss.requested_rows, ss.started_at, ss.connected_at, ss.ended_at, ss.remote_addr, ss.user_agent, ss.error
FROM ssh_sessions ss
LEFT JOIN ssh_access_profiles p ON p.id = ss.profile_id
LEFT JOIN ssh_servers s ON s.id = ss.server_id
LEFT JOIN users u ON u.id = ss.user_id`

	switch {
	case limit <= 0:
		limit = 20
	case limit > 200:
		limit = 200
	}
	if offset < 0 {
		offset = 0
	}
	search := strings.TrimSpace(query)
	wantStatus := strings.ToLower(strings.TrimSpace(status))

	var conditions []string
	var params []any
	if search != "" {
		pattern := "%" + search + "%"
		conditions = append(conditions, searchClause)
		// One bound pattern per placeholder in the search clause.
		for i, n := 0, strings.Count(searchClause, "?"); i < n; i++ {
			params = append(params, pattern)
		}
	}
	if wantStatus != "" {
		conditions = append(conditions, "ss.status = ?")
		params = append(params, wantStatus)
	}

	statement := selectSQL
	if len(conditions) > 0 {
		statement += "\n\t\tWHERE " + strings.Join(conditions, " AND ")
	}
	statement += "\n\t\tORDER BY ss.started_at DESC\n\t\tLIMIT ? OFFSET ?"
	// Ask for one row more than the page size to learn whether another page exists.
	params = append(params, limit+1, offset)

	rows, err := s.Query(statement, params...)
	if err != nil {
		return nil, false, err
	}
	defer rows.Close()

	var sessions []models.SSHSession
	for rows.Next() {
		var session models.SSHSession
		if err = rows.Scan(
			&session.ID,
			&session.ProfileID,
			&session.ProfileName,
			&session.ServerID,
			&session.ServerName,
			&session.UserID,
			&session.Username,
			&session.RemoteUsername,
			&session.AuthMethod,
			&session.SecondFactorMode,
			&session.Host,
			&session.Port,
			&session.Status,
			&session.HostKeyFingerprint,
			&session.RequestedTerm,
			&session.RequestedCols,
			&session.RequestedRows,
			&session.StartedAt,
			&session.ConnectedAt,
			&session.EndedAt,
			&session.RemoteAddr,
			&session.UserAgent,
			&session.Error,
		); err != nil {
			return nil, false, err
		}
		sessions = append(sessions, session)
	}
	if err = rows.Err(); err != nil {
		return nil, false, err
	}

	hasMore := len(sessions) > limit
	if hasMore {
		sessions = sessions[:limit]
	}
	return sessions, hasMore, nil
}
// CountSSHSessionsFiltered counts sessions across all users that match the
// optional search text and status filter. The filters are the same ones
// ListSSHSessionsFiltered applies: a substring (SQL LIKE) match for query and
// an exact match on the lower-cased status. With neither filter set, every
// session is counted. It returns 0 together with the error when the count
// cannot be read.
func (s *Store) CountSSHSessionsFiltered(query string, status string) (int, error) {
	const searchClause = "(ss.public_id LIKE ? OR COALESCE(p.public_id, '') LIKE ? OR COALESCE(p.name, '') LIKE ? OR COALESCE(s.public_id, '') LIKE ? OR COALESCE(s.name, '') LIKE ? OR COALESCE(u.public_id, '') LIKE ? OR ss.username LIKE ? OR ss.remote_username LIKE ? OR ss.host LIKE ? OR ss.auth_method LIKE ? OR ss.status LIKE ? OR ss.host_key_fingerprint LIKE ? OR ss.error LIKE ?)"
	const countSQL = `SELECT COUNT(*)
FROM ssh_sessions ss
LEFT JOIN ssh_access_profiles p ON p.id = ss.profile_id
LEFT JOIN ssh_servers s ON s.id = ss.server_id
LEFT JOIN users u ON u.id = ss.user_id`

	search := strings.TrimSpace(query)
	wantStatus := strings.ToLower(strings.TrimSpace(status))

	var conditions []string
	var params []any
	if search != "" {
		pattern := "%" + search + "%"
		conditions = append(conditions, searchClause)
		// One bound pattern per placeholder in the search clause.
		for i, n := 0, strings.Count(searchClause, "?"); i < n; i++ {
			params = append(params, pattern)
		}
	}
	if wantStatus != "" {
		conditions = append(conditions, "ss.status = ?")
		params = append(params, wantStatus)
	}

	statement := countSQL
	if len(conditions) > 0 {
		statement += "\n\t\tWHERE " + strings.Join(conditions, " AND ")
	}

	var total int
	if err := s.QueryRow(statement, params...).Scan(&total); err != nil {
		return 0, err
	}
	return total, nil
}
// UpdateSSHSessionStatus overwrites the status, host key fingerprint,
// connected/ended timestamps and error text of the session with public id id.
// All string arguments are whitespace-trimmed before being stored. Updating a
// session id that does not exist is not an error.
func (s *Store) UpdateSSHSessionStatus(id string, status string, hostKeyFingerprint string, connectedAt int64, endedAt int64, errorText string) error {
	const updateSQL = `UPDATE ssh_sessions
SET status = ?, host_key_fingerprint = ?, connected_at = ?, ended_at = ?, error = ?
WHERE public_id = ?`

	params := []any{
		strings.TrimSpace(status),
		strings.TrimSpace(hostKeyFingerprint),
		connectedAt,
		endedAt,
		strings.TrimSpace(errorText),
		strings.TrimSpace(id),
	}
	_, err := s.Exec(updateSQL, params...)
	return err
}
// CreateSSHFileTransfer records a file transfer row and returns the item with
// the defaults that were applied.
//
// Defaults: a generated id when item.ID is blank, StartedAt set to the current
// UTC Unix time when it is not positive, and status "running" when blank.
// item.Paths is stored as an encoded string list. Session, user, profile and
// server references are given as public ids and resolved inside the INSERT;
// the optional source session, target session and target server references
// are stored as NULL when their public id is empty.
func (s *Store) CreateSSHFileTransfer(item models.SSHFileTransfer) (models.SSHFileTransfer, error) {
	const insertSQL = `INSERT INTO ssh_file_transfers (
public_id,
session_id,
user_id,
username,
profile_id,
server_id,
server_name,
remote_username,
operation,
source_session_id,
target_session_id,
target_server_id,
target_server_name,
target_dir,
paths_json,
bytes_transferred,
status,
error,
started_at,
finished_at,
duration_ms,
remote_addr,
user_agent
) VALUES (?, (SELECT id FROM ssh_sessions WHERE public_id = ?), (SELECT id FROM users WHERE public_id = ?), ?, (SELECT id FROM ssh_access_profiles WHERE public_id = ?), (SELECT id FROM ssh_servers WHERE public_id = ?), ?, ?, ?, CASE WHEN ? = '' THEN NULL ELSE (SELECT id FROM ssh_sessions WHERE public_id = ?) END, CASE WHEN ? = '' THEN NULL ELSE (SELECT id FROM ssh_sessions WHERE public_id = ?) END, CASE WHEN ? = '' THEN NULL ELSE (SELECT id FROM ssh_servers WHERE public_id = ?) END, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`

	if strings.TrimSpace(item.ID) == "" {
		newID, err := util.NewID()
		item.ID = newID
		if err != nil {
			return item, err
		}
	}

	encodedPaths, err := encodeStringList(item.Paths)
	if err != nil {
		return item, err
	}

	if item.StartedAt <= 0 {
		item.StartedAt = time.Now().UTC().Unix()
	}
	if strings.TrimSpace(item.Status) == "" {
		item.Status = "running"
	}

	// The optional references appear twice each: once for the emptiness test
	// in the CASE expression and once for the id lookup.
	sourceSession := strings.TrimSpace(item.SourceSessionID)
	targetSession := strings.TrimSpace(item.TargetSessionID)
	targetServer := strings.TrimSpace(item.TargetServerID)

	params := []any{
		strings.TrimSpace(item.ID),
		strings.TrimSpace(item.SessionID),
		strings.TrimSpace(item.UserID),
		strings.TrimSpace(item.Username),
		strings.TrimSpace(item.ProfileID),
		strings.TrimSpace(item.ServerID),
		strings.TrimSpace(item.ServerName),
		strings.TrimSpace(item.RemoteUsername),
		strings.TrimSpace(item.Operation),
		sourceSession,
		sourceSession,
		targetSession,
		targetSession,
		targetServer,
		targetServer,
		strings.TrimSpace(item.TargetServerName),
		strings.TrimSpace(item.TargetDir),
		encodedPaths,
		item.BytesTransferred,
		strings.TrimSpace(item.Status),
		strings.TrimSpace(item.Error),
		item.StartedAt,
		item.FinishedAt,
		item.DurationMS,
		strings.TrimSpace(item.RemoteAddr),
		strings.TrimSpace(item.UserAgent),
	}
	_, err = s.Exec(insertSQL, params...)
	return item, err
}
// FinishSSHFileTransfer finalizes the file transfer with public id id: it
// stores the final status, path list, byte count and error text, sets
// finished_at to the current UTC Unix time and derives duration_ms from the
// stored started_at.
//
// The duration is computed from whole-second timestamps, so it is always a
// multiple of 1000; it is 0 when started_at is not positive or lies after the
// finish time. The read and the update run on the executor returned by
// s.beginImmediateContext(ctx). If the transfer does not exist, the error from
// the started_at lookup is returned.
func (s *Store) FinishSSHFileTransfer(ctx context.Context, id string, status string, paths []string, bytesTransferred int64, errText string) error {
	const updateSQL = `UPDATE ssh_file_transfers
SET status = ?, paths_json = ?, bytes_transferred = ?, error = ?, finished_at = ?, duration_ms = ?
WHERE public_id = ?`

	var (
		tx    txExecutor
		owned bool
	)

	encodedPaths, err := encodeStringList(paths)
	if err != nil {
		return err
	}
	finishedAt := time.Now().UTC().Unix()

	tx, owned, err = s.beginImmediateContext(ctx)
	if err != nil {
		return err
	}

	var startedAt int64
	err = tx.QueryRow(`SELECT started_at FROM ssh_file_transfers WHERE public_id = ?`, strings.TrimSpace(id)).Scan(&startedAt)
	if err != nil {
		rollbackIfOwned(tx, owned)
		return err
	}

	var durationMS int64
	if startedAt > 0 && finishedAt >= startedAt {
		durationMS = (finishedAt - startedAt) * 1000
	}

	if _, err = tx.Exec(updateSQL,
		strings.TrimSpace(status),
		encodedPaths,
		bytesTransferred,
		strings.TrimSpace(errText),
		finishedAt,
		durationMS,
		strings.TrimSpace(id),
	); err != nil {
		rollbackIfOwned(tx, owned)
		return err
	}
	return commitIfOwned(tx, owned)
}
// scanSSHFileTransferRows reads every remaining row of rows into
// models.SSHFileTransfer values, decoding the paths column into Paths.
//
// The boolean result is true when more than limit rows were read; in that
// case only the first limit items are returned. The caller keeps ownership of
// rows: this function does not close it.
func scanSSHFileTransferRows(rows *sql.Rows, limit int) ([]models.SSHFileTransfer, bool, error) {
	var transfers []models.SSHFileTransfer

	for rows.Next() {
		var (
			transfer     models.SSHFileTransfer
			encodedPaths string
		)
		if err := rows.Scan(
			&transfer.ID,
			&transfer.SessionID,
			&transfer.UserID,
			&transfer.Username,
			&transfer.ProfileID,
			&transfer.ServerID,
			&transfer.ServerName,
			&transfer.RemoteUsername,
			&transfer.Operation,
			&transfer.SourceSessionID,
			&transfer.TargetSessionID,
			&transfer.TargetServerID,
			&transfer.TargetServerName,
			&transfer.TargetDir,
			&encodedPaths,
			&transfer.BytesTransferred,
			&transfer.Status,
			&transfer.Error,
			&transfer.StartedAt,
			&transfer.FinishedAt,
			&transfer.DurationMS,
			&transfer.RemoteAddr,
			&transfer.UserAgent,
		); err != nil {
			return nil, false, err
		}

		decoded, err := decodeStringList(encodedPaths)
		if err != nil {
			return nil, false, err
		}
		transfer.Paths = decoded
		transfers = append(transfers, transfer)
	}
	if err := rows.Err(); err != nil {
		return nil, false, err
	}

	hasMore := len(transfers) > limit
	if hasMore {
		transfers = transfers[:limit]
	}
	return transfers, hasMore, nil
}
// ListSSHFileTransfersForUserFiltered returns one page of file transfers that
// belong to the user with public id userID, newest first (started_at, then id,
// descending).
//
// limit defaults to 20 when not positive and is capped at 200; a negative
// offset is treated as 0. Optional filters, each ignored when blank after
// trimming:
//   - sessionID: exact match on the session public id
//   - query: substring (LIKE) match against the ids, names, operation, target
//     fields, paths JSON, status and error columns listed in the clause below
//   - status: exact match after lower-casing
//
// One row more than limit is requested so that scanSSHFileTransferRows can
// report, via the boolean result, whether further rows exist.
func (s *Store) ListSSHFileTransfersForUserFiltered(userID string, limit int, offset int, query string, status string, sessionID string) ([]models.SSHFileTransfer, bool, error) {
	switch {
	case limit <= 0:
		limit = 20
	case limit > 200:
		limit = 200
	}
	if offset < 0 {
		offset = 0
	}

	conditions := []string{"u.public_id = ?"}
	args := []any{strings.TrimSpace(userID)}

	if sessionFilter := strings.TrimSpace(sessionID); sessionFilter != "" {
		conditions = append(conditions, "sess.public_id = ?")
		args = append(args, sessionFilter)
	}
	if term := strings.TrimSpace(query); term != "" {
		const searchClause = "(ft.public_id LIKE ? OR COALESCE(sess.public_id, '') LIKE ? OR COALESCE(p.public_id, '') LIKE ? OR COALESCE(s.public_id, '') LIKE ? OR ft.server_name LIKE ? OR ft.username LIKE ? OR ft.remote_username LIKE ? OR ft.operation LIKE ? OR COALESCE(tsess.public_id, '') LIKE ? OR COALESCE(ts.public_id, '') LIKE ? OR ft.target_server_name LIKE ? OR ft.target_dir LIKE ? OR ft.paths_json LIKE ? OR ft.status LIKE ? OR ft.error LIKE ?)"
		pattern := "%" + term + "%"
		conditions = append(conditions, searchClause)
		// One bound pattern per placeholder in the search clause.
		for remaining := strings.Count(searchClause, "?"); remaining > 0; remaining-- {
			args = append(args, pattern)
		}
	}
	if statusFilter := strings.ToLower(strings.TrimSpace(status)); statusFilter != "" {
		conditions = append(conditions, "ft.status = ?")
		args = append(args, statusFilter)
	}

	statement := fmt.Sprintf(`SELECT ft.public_id, COALESCE(sess.public_id, ''), COALESCE(u.public_id, ''), ft.username, COALESCE(p.public_id, ''), COALESCE(s.public_id, ''), ft.server_name, ft.remote_username, ft.operation, COALESCE(ssess.public_id, ''), COALESCE(tsess.public_id, ''), COALESCE(ts.public_id, ''), ft.target_server_name, ft.target_dir, ft.paths_json, ft.bytes_transferred, ft.status, ft.error, ft.started_at, ft.finished_at, ft.duration_ms, ft.remote_addr, ft.user_agent
FROM ssh_file_transfers ft
LEFT JOIN ssh_sessions sess ON sess.id = ft.session_id
LEFT JOIN users u ON u.id = ft.user_id
LEFT JOIN ssh_access_profiles p ON p.id = ft.profile_id
LEFT JOIN ssh_servers s ON s.id = ft.server_id
LEFT JOIN ssh_sessions ssess ON ssess.id = ft.source_session_id
LEFT JOIN ssh_sessions tsess ON tsess.id = ft.target_session_id
LEFT JOIN ssh_servers ts ON ts.id = ft.target_server_id
WHERE %s
ORDER BY ft.started_at DESC, ft.id DESC
LIMIT ? OFFSET ?`, strings.Join(conditions, " AND "))
	args = append(args, limit+1, offset)

	rows, err := s.Query(statement, args...)
	if err != nil {
		return nil, false, err
	}
	defer rows.Close()
	return scanSSHFileTransferRows(rows, limit)
}
// CountSSHFileTransfersForUserFiltered returns the number of file transfers
// that belong to the user with public id userID and match the optional
// filters. The filters are built the same way as in
// ListSSHFileTransfersForUserFiltered: sessionID is an exact match, query is a
// substring (LIKE) match over the columns in the clause below, and status is
// an exact match after lower-casing. Blank filters (after trimming) are
// ignored.
func (s *Store) CountSSHFileTransfersForUserFiltered(userID string, query string, status string, sessionID string) (int, error) {
	conditions := []string{"u.public_id = ?"}
	args := []any{strings.TrimSpace(userID)}

	if sessionFilter := strings.TrimSpace(sessionID); sessionFilter != "" {
		conditions = append(conditions, "sess.public_id = ?")
		args = append(args, sessionFilter)
	}
	if term := strings.TrimSpace(query); term != "" {
		const searchClause = "(ft.public_id LIKE ? OR COALESCE(sess.public_id, '') LIKE ? OR COALESCE(p.public_id, '') LIKE ? OR COALESCE(s.public_id, '') LIKE ? OR ft.server_name LIKE ? OR ft.username LIKE ? OR ft.remote_username LIKE ? OR ft.operation LIKE ? OR COALESCE(tsess.public_id, '') LIKE ? OR COALESCE(ts.public_id, '') LIKE ? OR ft.target_server_name LIKE ? OR ft.target_dir LIKE ? OR ft.paths_json LIKE ? OR ft.status LIKE ? OR ft.error LIKE ?)"
		pattern := "%" + term + "%"
		conditions = append(conditions, searchClause)
		// One bound pattern per placeholder in the search clause.
		for remaining := strings.Count(searchClause, "?"); remaining > 0; remaining-- {
			args = append(args, pattern)
		}
	}
	if statusFilter := strings.ToLower(strings.TrimSpace(status)); statusFilter != "" {
		conditions = append(conditions, "ft.status = ?")
		args = append(args, statusFilter)
	}

	statement := fmt.Sprintf(`SELECT COUNT(*)
FROM ssh_file_transfers ft
LEFT JOIN ssh_sessions sess ON sess.id = ft.session_id
LEFT JOIN users u ON u.id = ft.user_id
LEFT JOIN ssh_access_profiles p ON p.id = ft.profile_id
LEFT JOIN ssh_servers s ON s.id = ft.server_id
LEFT JOIN ssh_sessions ssess ON ssess.id = ft.source_session_id
LEFT JOIN ssh_sessions tsess ON tsess.id = ft.target_session_id
LEFT JOIN ssh_servers ts ON ts.id = ft.target_server_id
WHERE %s`, strings.Join(conditions, " AND "))

	var total int
	if err := s.QueryRow(statement, args...).Scan(&total); err != nil {
		return 0, err
	}
	return total, nil
}
// ListSSHFileTransfersFiltered returns one page of file transfers across all
// users, newest first (started_at, then id, descending).
//
// limit defaults to 20 when not positive and is capped at 200; a negative
// offset is treated as 0. Optional filters, each ignored when blank after
// trimming:
//   - sessionID: exact match on the session public id
//   - userID: exact match on the user public id
//   - query: substring (LIKE) match over the columns in the clause below
//   - status: exact match after lower-casing
//
// No WHERE clause is emitted when every filter is blank. One row more than
// limit is requested so that scanSSHFileTransferRows can report, via the
// boolean result, whether further rows exist.
func (s *Store) ListSSHFileTransfersFiltered(limit int, offset int, query string, status string, sessionID string, userID string) ([]models.SSHFileTransfer, bool, error) {
	switch {
	case limit <= 0:
		limit = 20
	case limit > 200:
		limit = 200
	}
	if offset < 0 {
		offset = 0
	}

	var conditions []string
	var args []any

	if sessionFilter := strings.TrimSpace(sessionID); sessionFilter != "" {
		conditions = append(conditions, "sess.public_id = ?")
		args = append(args, sessionFilter)
	}
	if userFilter := strings.TrimSpace(userID); userFilter != "" {
		conditions = append(conditions, "u.public_id = ?")
		args = append(args, userFilter)
	}
	if term := strings.TrimSpace(query); term != "" {
		const searchClause = "(ft.public_id LIKE ? OR COALESCE(sess.public_id, '') LIKE ? OR COALESCE(u.public_id, '') LIKE ? OR ft.username LIKE ? OR COALESCE(p.public_id, '') LIKE ? OR COALESCE(s.public_id, '') LIKE ? OR ft.server_name LIKE ? OR ft.remote_username LIKE ? OR ft.operation LIKE ? OR COALESCE(tsess.public_id, '') LIKE ? OR COALESCE(ts.public_id, '') LIKE ? OR ft.target_server_name LIKE ? OR ft.target_dir LIKE ? OR ft.paths_json LIKE ? OR ft.status LIKE ? OR ft.error LIKE ?)"
		pattern := "%" + term + "%"
		conditions = append(conditions, searchClause)
		// One bound pattern per placeholder in the search clause.
		for remaining := strings.Count(searchClause, "?"); remaining > 0; remaining-- {
			args = append(args, pattern)
		}
	}
	if statusFilter := strings.ToLower(strings.TrimSpace(status)); statusFilter != "" {
		conditions = append(conditions, "ft.status = ?")
		args = append(args, statusFilter)
	}

	statement := `SELECT ft.public_id, COALESCE(sess.public_id, ''), COALESCE(u.public_id, ''), ft.username, COALESCE(p.public_id, ''), COALESCE(s.public_id, ''), ft.server_name, ft.remote_username, ft.operation, COALESCE(ssess.public_id, ''), COALESCE(tsess.public_id, ''), COALESCE(ts.public_id, ''), ft.target_server_name, ft.target_dir, ft.paths_json, ft.bytes_transferred, ft.status, ft.error, ft.started_at, ft.finished_at, ft.duration_ms, ft.remote_addr, ft.user_agent
FROM ssh_file_transfers ft
LEFT JOIN ssh_sessions sess ON sess.id = ft.session_id
LEFT JOIN users u ON u.id = ft.user_id
LEFT JOIN ssh_access_profiles p ON p.id = ft.profile_id
LEFT JOIN ssh_servers s ON s.id = ft.server_id
LEFT JOIN ssh_sessions ssess ON ssess.id = ft.source_session_id
LEFT JOIN ssh_sessions tsess ON tsess.id = ft.target_session_id
LEFT JOIN ssh_servers ts ON ts.id = ft.target_server_id`
	if len(conditions) != 0 {
		statement += "\n\t\tWHERE " + strings.Join(conditions, " AND ")
	}
	statement += "\n\t\tORDER BY ft.started_at DESC, ft.id DESC\n\t\tLIMIT ? OFFSET ?"
	args = append(args, limit+1, offset)

	rows, err := s.Query(statement, args...)
	if err != nil {
		return nil, false, err
	}
	defer rows.Close()
	return scanSSHFileTransferRows(rows, limit)
}
// CountSSHFileTransfersFiltered returns the number of file transfers across
// all users that match the optional filters. The filters are built the same
// way as in ListSSHFileTransfersFiltered: sessionID and userID are exact
// matches on public ids, query is a substring (LIKE) match over the columns in
// the clause below, and status is an exact match after lower-casing. Blank
// filters (after trimming) are ignored, and no WHERE clause is emitted when
// all of them are blank.
func (s *Store) CountSSHFileTransfersFiltered(query string, status string, sessionID string, userID string) (int, error) {
	var conditions []string
	var args []any

	if sessionFilter := strings.TrimSpace(sessionID); sessionFilter != "" {
		conditions = append(conditions, "sess.public_id = ?")
		args = append(args, sessionFilter)
	}
	if userFilter := strings.TrimSpace(userID); userFilter != "" {
		conditions = append(conditions, "u.public_id = ?")
		args = append(args, userFilter)
	}
	if term := strings.TrimSpace(query); term != "" {
		const searchClause = "(ft.public_id LIKE ? OR COALESCE(sess.public_id, '') LIKE ? OR COALESCE(u.public_id, '') LIKE ? OR ft.username LIKE ? OR COALESCE(p.public_id, '') LIKE ? OR COALESCE(s.public_id, '') LIKE ? OR ft.server_name LIKE ? OR ft.remote_username LIKE ? OR ft.operation LIKE ? OR COALESCE(tsess.public_id, '') LIKE ? OR COALESCE(ts.public_id, '') LIKE ? OR ft.target_server_name LIKE ? OR ft.target_dir LIKE ? OR ft.paths_json LIKE ? OR ft.status LIKE ? OR ft.error LIKE ?)"
		pattern := "%" + term + "%"
		conditions = append(conditions, searchClause)
		// One bound pattern per placeholder in the search clause.
		for remaining := strings.Count(searchClause, "?"); remaining > 0; remaining-- {
			args = append(args, pattern)
		}
	}
	if statusFilter := strings.ToLower(strings.TrimSpace(status)); statusFilter != "" {
		conditions = append(conditions, "ft.status = ?")
		args = append(args, statusFilter)
	}

	statement := `SELECT COUNT(*)
FROM ssh_file_transfers ft
LEFT JOIN ssh_sessions sess ON sess.id = ft.session_id
LEFT JOIN users u ON u.id = ft.user_id
LEFT JOIN ssh_access_profiles p ON p.id = ft.profile_id
LEFT JOIN ssh_servers s ON s.id = ft.server_id
LEFT JOIN ssh_sessions ssess ON ssess.id = ft.source_session_id
LEFT JOIN ssh_sessions tsess ON tsess.id = ft.target_session_id
LEFT JOIN ssh_servers ts ON ts.id = ft.target_server_id`
	if len(conditions) != 0 {
		statement += "\n\t\tWHERE " + strings.Join(conditions, " AND ")
	}

	var total int
	if err := s.QueryRow(statement, args...).Scan(&total); err != nil {
		return 0, err
	}
	return total, nil
}
// createSSHSecretTx inserts item into ssh_secrets using tx and returns the
// item as stored.
//
// A public id is generated with util.NewID when item.ID is blank, and
// CreatedAt/UpdatedAt are both set to the current UTC time in Unix seconds.
// Kind and the created-by fields are trimmed; MetadataJSON falls back to "{}"
// when blank. The creator reference is resolved inside the INSERT: the user
// subselect only matches when created_by_kind is 'user', the service principal
// subselect only when it is 'service_principal'.
//
// The (possibly updated) item is returned together with any error.
func createSSHSecretTx(tx txExecutor, item models.SSHSecret) (models.SSHSecret, error) {
	var err error
	if strings.TrimSpace(item.ID) == "" {
		if item.ID, err = util.NewID(); err != nil {
			return item, err
		}
	}

	stamp := time.Now().UTC().Unix()
	item.CreatedAt, item.UpdatedAt = stamp, stamp

	// The kind and subject id are each bound twice: once per creator subselect.
	creatorKind := strings.TrimSpace(item.CreatedByKind)
	creatorID := strings.TrimSpace(item.CreatedBySubjectID)

	_, err = tx.Exec(`INSERT INTO ssh_secrets (
public_id,
kind,
payload,
password,
metadata_json,
created_by_kind,
created_by_user_id,
created_by_principal_id,
created_by_subject_name,
created_at,
updated_at
) VALUES (?, ?, ?, ?, ?, ?, (SELECT id FROM users WHERE public_id = ? AND ? = 'user'), (SELECT id FROM service_principals WHERE public_id = ? AND ? = 'service_principal'), ?, ?, ?)`,
		item.ID,
		strings.TrimSpace(item.Kind),
		item.Payload,
		item.Password,
		normalizeJSONText(item.MetadataJSON, "{}"),
		creatorKind,
		creatorID,
		creatorKind,
		creatorID,
		creatorKind,
		strings.TrimSpace(item.CreatedBySubjectName),
		item.CreatedAt,
		item.UpdatedAt,
	)
	return item, err
}
// updateSSHSecretTx updates the ssh_secrets row whose public id is item.ID
// using tx and returns item with UpdatedAt set to the current UTC time in Unix
// seconds.
//
// kind, payload and metadata_json (falling back to "{}" when blank) are always
// written. The password column is written by a second statement, and only when
// item.Password is non-empty, so an empty password leaves the stored one
// untouched.
func updateSSHSecretTx(tx txExecutor, item models.SSHSecret) (models.SSHSecret, error) {
	secretID := strings.TrimSpace(item.ID)
	item.UpdatedAt = time.Now().UTC().Unix()

	if _, err := tx.Exec(`UPDATE ssh_secrets SET kind = ?, payload = ?, metadata_json = ?, updated_at = ? WHERE public_id = ?`,
		strings.TrimSpace(item.Kind),
		item.Payload,
		normalizeJSONText(item.MetadataJSON, "{}"),
		item.UpdatedAt,
		secretID,
	); err != nil {
		return item, err
	}

	if item.Password == "" {
		return item, nil
	}

	// TODO: find a way to update the password without risking a security
	// breach. For now the password is only overwritten when a non-blank one
	// is supplied.
	_, err := tx.Exec(`UPDATE ssh_secrets SET password = ?, updated_at = ? WHERE public_id = ?`,
		item.Password,
		item.UpdatedAt,
		secretID,
	)
	return item, err
}
// insertSSHAccessProfileTargetTx adds one target row for the access profile
// with public id profileID using tx.
//
// targetType is trimmed and lower-cased and must be the user or group target
// type; targetID is trimmed and must not be empty. Both checks fail with a
// plain error before anything is written. A fresh public id is generated with
// util.NewID. The profile and target public ids are resolved to internal ids
// inside the INSERT; the target is looked up in users or user_groups depending
// on targetType.
func insertSSHAccessProfileTargetTx(tx txExecutor, profileID string, targetType string, targetID string, createdAt int64) error {
	kind := strings.ToLower(strings.TrimSpace(targetType))
	switch kind {
	case SSHAccessProfileTargetTypeUser, SSHAccessProfileTargetTypeGroup:
		// supported target type
	default:
		return errors.New("target_type must be user or group")
	}

	targetRef := strings.TrimSpace(targetID)
	if targetRef == "" {
		return errors.New("target_id is required")
	}

	rowID, err := util.NewID()
	if err != nil {
		return err
	}

	// kind and targetRef are bound once per CASE branch.
	_, err = tx.Exec(`INSERT INTO ssh_access_profile_targets (
public_id,
profile_id,
target_type,
target_id,
created_at
) VALUES (?, (SELECT id FROM ssh_access_profiles WHERE public_id = ?), ?, CASE WHEN ? = 'user' THEN (SELECT id FROM users WHERE public_id = ?) WHEN ? = 'group' THEN (SELECT id FROM user_groups WHERE public_id = ?) ELSE NULL END, ?)`,
		rowID,
		strings.TrimSpace(profileID),
		kind,
		kind,
		targetRef,
		kind,
		targetRef,
		createdAt,
	)
	return err
}
// normalizeStringList returns the entries of items with surrounding whitespace
// trimmed, blank entries dropped and duplicates removed. The first occurrence
// of each value is kept, so the input order is preserved. The result is nil
// when nothing remains.
func normalizeStringList(items []string) []string {
	var cleaned []string
	seen := map[string]struct{}{}
	for _, entry := range items {
		trimmed := strings.TrimSpace(entry)
		if trimmed == "" {
			continue
		}
		if _, duplicate := seen[trimmed]; duplicate {
			continue
		}
		seen[trimmed] = struct{}{}
		cleaned = append(cleaned, trimmed)
	}
	return cleaned
}
// encodeStringList serializes items as a JSON array of strings and returns the
// JSON text.
//
// A nil slice is encoded as "[]" rather than the "null" that json.Marshal
// produces for nil slices, so the stored text is always a JSON array.
func encodeStringList(items []string) (string, error) {
	if items == nil {
		items = []string{}
	}
	raw, err := json.Marshal(items)
	if err != nil {
		return "", err
	}
	return string(raw), nil
}
// decodeStringList parses raw as a JSON array of strings and returns the
// values after normalizeStringList has been applied to them.
//
// Blank input yields an empty list. On success the result is always a non-nil
// slice, also for inputs such as "[]" or "null" that decode to no values, so
// every "no entries" case looks the same to callers. A JSON decoding error is
// returned unchanged with a nil slice.
func decodeStringList(raw string) ([]string, error) {
	raw = strings.TrimSpace(raw)
	if raw == "" {
		return []string{}, nil
	}
	var decoded []string
	if err := json.Unmarshal([]byte(raw), &decoded); err != nil {
		return nil, err
	}
	values := normalizeStringList(decoded)
	if values == nil {
		// Match the blank-input case instead of leaking a nil slice.
		return []string{}, nil
	}
	return values, nil
}
// normalizeJSONText returns raw with surrounding whitespace removed, or
// fallback when nothing is left after trimming. The text is not validated as
// JSON.
func normalizeJSONText(raw string, fallback string) string {
	if trimmed := strings.TrimSpace(raw); trimmed != "" {
		return trimmed
	}
	return fallback
}
// emptyStringToNil returns raw with surrounding whitespace removed as a
// string value, or an untyped nil when nothing is left after trimming.
func emptyStringToNil(raw string) any {
	if trimmed := strings.TrimSpace(raw); trimmed != "" {
		return trimmed
	}
	return nil
}
// normalizeSSHServerHostKeyPolicy maps value to a supported host key policy.
// After trimming, "trust_on_first_use" is returned as-is; every other input,
// including the empty string and differently-cased spellings, yields
// "strict".
func normalizeSSHServerHostKeyPolicy(value string) string {
	switch policy := strings.TrimSpace(value); policy {
	case "trust_on_first_use":
		return policy
	default:
		return "strict"
	}
}
// NormalizeSSHServerTags returns the tags in items cleaned up by
// normalizeStringList and sorted in ascending order.
func NormalizeSSHServerTags(items []string) []string {
	tags := normalizeStringList(items)
	sort.Strings(tags)
	return tags
}