package db

import (
	"context"
	"database/sql"
	"encoding/json"
	"errors"
	"fmt"
	"sort"
	"strings"
	"time"

	"codit/internal/models"
	"codit/internal/util"
)

// ErrSSHServerHasSessionHistory is a sentinel error with the message
// "ssh server has session history".
//
// NOTE(review): DeleteSSHServer below reports session history through
// *SSHServerDeleteReferenceError instead of this value; confirm whether code
// elsewhere in the package still returns or checks it.
var ErrSSHServerHasSessionHistory = errors.New("ssh server has session history")

// Values for the owner_scope column of ssh_servers and ssh_access_profiles.
const (
	SSHOwnerScopeUser        = "user"
	SSHOwnerScopeAdmin       = "admin"
	SSHOwnerScopeAdminShared = "admin_shared"
)

// Values for ssh_access_profiles.server_target_type. Readers treat a NULL
// column as "server".
const (
	SSHAccessProfileServerTargetTypeServer = "server"
	SSHAccessProfileServerTargetTypeGroup  = "group"
)

// Values for ssh_access_profile_targets.target_type: a profile is granted to an
// individual user ("user") or to a user group ("group").
const (
	SSHAccessProfileTargetTypeUser  = "user"
	SSHAccessProfileTargetTypeGroup = "group"
)

// SSHServerDeleteReferenceError is returned by DeleteSSHServer when rows in
// another table still reference the server. Kind is a human-readable name of
// the referencing record type and Count is the number of such rows.
type SSHServerDeleteReferenceError struct {
	Kind  string
	Count int64
}

// Error implements the error interface, pluralising Kind when Count != 1.
func (e *SSHServerDeleteReferenceError) Error() string {
	suffix := ""
	if e.Count != 1 {
		suffix = "s"
	}
	return fmt.Sprintf("cannot delete ssh server because it is used by %d %s%s", e.Count, e.Kind, suffix)
}

// ListSSHServers returns every SSH server ordered by name, host and port.
// Owner and creator references are resolved to public IDs ("" when unset) and
// tags are decoded from tags_json. Editable is never set and stays false.
func (s *Store) ListSSHServers() ([]models.SSHServer, error) {
	rows, err := s.Query(`SELECT srv.public_id, srv.name, srv.host, srv.port, srv.description, srv.tags_json,
			srv.enabled, srv.host_key_policy, srv.owner_scope, COALESCE(ou.public_id, ''),
			srv.created_by_kind, COALESCE(srv_cbu.public_id, srv_cbp.public_id, ''),
			srv.created_by_subject_name, srv.created_at, srv.updated_at
		FROM ssh_servers srv
		LEFT JOIN users ou ON ou.id = srv.owner_user_id
		LEFT JOIN users srv_cbu ON srv_cbu.id = srv.created_by_user_id
		LEFT JOIN service_principals srv_cbp ON srv_cbp.id = srv.created_by_principal_id
		ORDER BY srv.name, srv.host, srv.port`)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var items []models.SSHServer
	for rows.Next() {
		// Fresh values per row so no field can carry over from the previous row.
		var item models.SSHServer
		var tagsJSON string
		err = rows.Scan(&item.ID, &item.Name, &item.Host, &item.Port, &item.Description, &tagsJSON,
			&item.Enabled, &item.HostKeyPolicy, &item.OwnerScope, &item.OwnerUserID,
			&item.CreatedByKind, &item.CreatedBySubjectID, &item.CreatedBySubjectName,
			&item.CreatedAt, &item.UpdatedAt)
		if err != nil {
			return nil, err
		}
		item.Tags, err = decodeStringList(tagsJSON)
		if err != nil {
			return nil, err
		}
		items = append(items, item)
	}
	if err = rows.Err(); err != nil {
		return nil, err
	}
	return items, nil
}

// GetSSHServer returns the server whose public ID equals the trimmed id.
// A missing server yields the sql.ErrNoRows reported by Row.Scan.
func (s *Store) GetSSHServer(id string) (models.SSHServer, error) {
	var item models.SSHServer
	var tagsJSON string
	err := s.QueryRow(`SELECT srv.public_id, srv.name, srv.host, srv.port, srv.description, srv.tags_json,
			srv.enabled, srv.host_key_policy, srv.owner_scope, COALESCE(ou.public_id, ''),
			srv.created_by_kind, COALESCE(srv_cbu.public_id, srv_cbp.public_id, ''),
			srv.created_by_subject_name, srv.created_at, srv.updated_at
		FROM ssh_servers srv
		LEFT JOIN users ou ON ou.id = srv.owner_user_id
		LEFT JOIN users srv_cbu ON srv_cbu.id = srv.created_by_user_id
		LEFT JOIN service_principals srv_cbp ON srv_cbp.id = srv.created_by_principal_id
		WHERE srv.public_id = ?`, strings.TrimSpace(id)).Scan(
		&item.ID, &item.Name, &item.Host, &item.Port, &item.Description, &tagsJSON,
		&item.Enabled, &item.HostKeyPolicy, &item.OwnerScope, &item.OwnerUserID,
		&item.CreatedByKind, &item.CreatedBySubjectID, &item.CreatedBySubjectName,
		&item.CreatedAt, &item.UpdatedAt)
	if err != nil {
		return item, err
	}
	item.Tags, err = decodeStringList(tagsJSON)
	if err != nil {
		return item, err
	}
	return item, nil
}

// ListSSHServersForUser returns the user-scoped servers owned by userID,
// ordered by name, host and port. Because the query only matches servers the
// user owns, Editable is true for every returned row.
func (s *Store) ListSSHServersForUser(userID string) ([]models.SSHServer, error) {
	trimmedUserID := strings.TrimSpace(userID)
	rows, err := s.Query(`SELECT srv.public_id, srv.name, srv.host, srv.port, srv.description, srv.tags_json,
			srv.enabled, srv.host_key_policy, srv.owner_scope, COALESCE(ou.public_id, ''),
			srv.created_by_kind, COALESCE(srv_cbu.public_id, srv_cbp.public_id, ''),
			srv.created_by_subject_name, srv.created_at, srv.updated_at
		FROM ssh_servers srv
		LEFT JOIN users ou ON ou.id = srv.owner_user_id
		LEFT JOIN users srv_cbu ON srv_cbu.id = srv.created_by_user_id
		LEFT JOIN service_principals srv_cbp ON srv_cbp.id = srv.created_by_principal_id
		WHERE srv.owner_scope = ? AND ou.public_id = ?
		ORDER BY srv.name, srv.host, srv.port`, SSHOwnerScopeUser, trimmedUserID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var items []models.SSHServer
	for rows.Next() {
		var item models.SSHServer
		var tagsJSON string
		err = rows.Scan(&item.ID, &item.Name, &item.Host, &item.Port, &item.Description, &tagsJSON,
			&item.Enabled, &item.HostKeyPolicy, &item.OwnerScope, &item.OwnerUserID,
			&item.CreatedByKind, &item.CreatedBySubjectID, &item.CreatedBySubjectName,
			&item.CreatedAt, &item.UpdatedAt)
		if err != nil {
			return nil, err
		}
		item.Tags, err = decodeStringList(tagsJSON)
		if err != nil {
			return nil, err
		}
		item.Editable = item.OwnerScope == SSHOwnerScopeUser && item.OwnerUserID == trimmedUserID
		items = append(items, item)
	}
	if err = rows.Err(); err != nil {
		return nil, err
	}
	return items, nil
}

// GetSSHServerForUser returns the server with the given public ID if it is
// among the servers returned by ListSSHServersForUser(userID); otherwise it
// returns sql.ErrNoRows. It loads the user's whole server list to find one row.
func (s *Store) GetSSHServerForUser(userID string, id string) (models.SSHServer, error) {
	trimmedID := strings.TrimSpace(id)
	items, err := s.ListSSHServersForUser(userID)
	if err != nil {
		return models.SSHServer{}, err
	}
	for _, item := range items {
		if item.ID == trimmedID {
			return item, nil
		}
	}
	return models.SSHServer{}, sql.ErrNoRows
}

// GetOwnedSSHServerForUser returns the server only when it is user-scoped and
// owned by userID (Editable is set to true in that case). A server owned by
// someone else is reported as sql.ErrNoRows, indistinguishable from a missing
// one; the fetched item is still returned alongside that error.
func (s *Store) GetOwnedSSHServerForUser(userID string, id string) (models.SSHServer, error) {
	trimmedUserID := strings.TrimSpace(userID)
	item, err := s.GetSSHServer(strings.TrimSpace(id))
	if err != nil {
		return item, err
	}
	item.Editable = item.OwnerScope == SSHOwnerScopeUser && item.OwnerUserID == trimmedUserID
	if !item.Editable {
		return item, sql.ErrNoRows
	}
	return item, nil
}

// CreateSSHServer inserts a new SSH server and returns it as stored.
//
// Normalisation: a blank ID is replaced with util.NewID(); text fields are
// trimmed; a non-positive port defaults to 22; any owner scope other than
// "user" becomes "admin" with no owner user. A "user"-scoped server requires
// OwnerUserID. The creator is resolved to created_by_user_id when
// CreatedByKind is "user" and to created_by_principal_id when it is
// "service_principal".
func (s *Store) CreateSSHServer(item models.SSHServer) (models.SSHServer, error) {
	var err error
	item.ID = strings.TrimSpace(item.ID)
	if item.ID == "" {
		if item.ID, err = util.NewID(); err != nil {
			return item, err
		}
	}
	// Trim on the item itself so the returned value matches what is persisted.
	item.Name = strings.TrimSpace(item.Name)
	item.Host = strings.TrimSpace(item.Host)
	item.Description = strings.TrimSpace(item.Description)
	if item.Port <= 0 {
		item.Port = 22
	}
	item.HostKeyPolicy = normalizeSSHServerHostKeyPolicy(item.HostKeyPolicy)
	item.Tags = normalizeStringList(item.Tags)
	item.OwnerScope = strings.TrimSpace(item.OwnerScope)
	item.OwnerUserID = strings.TrimSpace(item.OwnerUserID)
	if item.OwnerScope != SSHOwnerScopeUser {
		item.OwnerScope = SSHOwnerScopeAdmin
		item.OwnerUserID = ""
	}
	if item.OwnerScope == SSHOwnerScopeUser && item.OwnerUserID == "" {
		return item, errors.New("owner_user_id is required")
	}
	item.CreatedByKind = strings.TrimSpace(item.CreatedByKind)
	item.CreatedBySubjectID = strings.TrimSpace(item.CreatedBySubjectID)
	item.CreatedBySubjectName = strings.TrimSpace(item.CreatedBySubjectName)

	tagsJSON, err := encodeStringList(item.Tags)
	if err != nil {
		return item, err
	}
	now := time.Now().UTC().Unix()
	item.CreatedAt = now
	item.UpdatedAt = now

	_, err = s.Exec(`INSERT INTO ssh_servers (
			public_id, name, host, port, description, tags_json, enabled, host_key_policy, owner_scope, owner_user_id,
			created_by_kind, created_by_user_id, created_by_principal_id, created_by_subject_name, created_at, updated_at
		) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?,
			CASE WHEN ? = '' THEN NULL ELSE (SELECT id FROM users WHERE public_id = ?) END,
			?,
			(SELECT id FROM users WHERE public_id = ? AND ? = 'user'),
			(SELECT id FROM service_principals WHERE public_id = ? AND ? = 'service_principal'),
			?, ?, ?)`,
		item.ID, item.Name, item.Host, item.Port, item.Description, tagsJSON, item.Enabled, item.HostKeyPolicy,
		item.OwnerScope,
		item.OwnerUserID, item.OwnerUserID,
		item.CreatedByKind,
		item.CreatedBySubjectID, item.CreatedByKind,
		item.CreatedBySubjectID, item.CreatedByKind,
		item.CreatedBySubjectName, item.CreatedAt, item.UpdatedAt,
	)
	return item, err
}

// UpdateSSHServer overwrites the mutable attributes of the server identified by
// item.ID (name, host, port, description, tags, enabled, host key policy) and
// returns the stored row. Owner and creator columns are not touched. A
// non-positive port defaults to 22. If no server has that ID the UPDATE
// changes nothing and the read-back returns sql.ErrNoRows.
func (s *Store) UpdateSSHServer(item models.SSHServer) (models.SSHServer, error) {
	item.ID = strings.TrimSpace(item.ID)
	if item.Port <= 0 {
		item.Port = 22
	}
	item.HostKeyPolicy = normalizeSSHServerHostKeyPolicy(item.HostKeyPolicy)
	item.Tags = normalizeStringList(item.Tags)
	tagsJSON, err := encodeStringList(item.Tags)
	if err != nil {
		return item, err
	}
	item.UpdatedAt = time.Now().UTC().Unix()

	// TODO: need to protect it with tx. The UPDATE and the read-back below are
	// separate statements, so a concurrent writer can change the row in between.
	_, err = s.Exec(`UPDATE ssh_servers
		SET name = ?, host = ?, port = ?, description = ?, tags_json = ?, enabled = ?, host_key_policy = ?, updated_at = ?
		WHERE public_id = ?`,
		strings.TrimSpace(item.Name), strings.TrimSpace(item.Host), item.Port, strings.TrimSpace(item.Description),
		tagsJSON, item.Enabled, item.HostKeyPolicy, item.UpdatedAt,
		item.ID,
	)
	if err != nil {
		return item, err
	}
	return s.GetSSHServer(item.ID)
}

// DeleteSSHServer deletes the server with the given public ID inside a
// transaction opened with beginImmediateContext. It refuses, returning
// *SSHServerDeleteReferenceError for the first kind found, while any access
// profile, server-group membership or session record still references the
// server. An unknown ID is not an error: every count is 0 and the DELETE
// affects no rows.
func (s *Store) DeleteSSHServer(ctx context.Context, id string) error {
	var tx txExecutor
	var owned bool
	var err error

	trimmedID := strings.TrimSpace(id)
	tx, owned, err = s.beginImmediateContext(ctx)
	if err != nil {
		return err
	}
	done := false
	defer func() {
		if !done {
			rollbackIfOwned(tx, owned)
		}
	}()

	// Checked in order; the first table with references determines the error.
	referenceChecks := []struct {
		query string
		kind  string
	}{
		{`SELECT COUNT(*) FROM ssh_access_profiles WHERE server_id = (SELECT id FROM ssh_servers WHERE public_id = ?)`, "SSH access profile"},
		{`SELECT COUNT(*) FROM ssh_server_group_members WHERE server_id = (SELECT id FROM ssh_servers WHERE public_id = ?)`, "SSH server group membership"},
		{`SELECT COUNT(*) FROM ssh_sessions WHERE server_id = (SELECT id FROM ssh_servers WHERE public_id = ?)`, "SSH session history record"},
	}
	for _, check := range referenceChecks {
		var count int64
		if err = tx.QueryRow(check.query, trimmedID).Scan(&count); err != nil {
			return err
		}
		if count > 0 {
			return &SSHServerDeleteReferenceError{Kind: check.kind, Count: count}
		}
	}

	if _, err = tx.Exec(`DELETE FROM ssh_servers WHERE public_id = ?`, trimmedID); err != nil {
		return err
	}
	// Mark done before committing so a failed commit is not followed by a
	// rollback (same as before the deferred-rollback refactor).
	done = true
	return commitIfOwned(tx, owned)
}

// ListSSHAccessProfiles returns every access profile joined with its server,
// optional server group, owner, creator, secret, credential and user CA
// references, ordered by server name then profile name. Targets are loaded
// with one extra query per profile while the result set is still open.
func (s *Store) ListSSHAccessProfiles() ([]models.SSHAccessProfile, error) {
	rows, err := s.Query(`SELECT p.public_id, COALESCE(s.public_id, ''), COALESCE(p.server_target_type, 'server'),
			COALESCE(g.public_id, ''), COALESCE(g.name, ''), COALESCE(g.description, ''), COALESCE(g.enabled, 0),
			COALESCE(g.created_by_kind, ''), COALESCE(g_cbu.public_id, g_cbp.public_id, ''),
			COALESCE(g.created_by_subject_name, ''), COALESCE(g.created_at, 0), COALESCE(g.updated_at, 0),
			COALESCE(g.tags_json, '[]'),
			p.name, p.description, p.remote_username, p.auth_method, p.second_factor_mode, p.owner_scope,
			COALESCE(pou.public_id, ''), p.allow_user_edit, p.enabled, COALESCE(sec.public_id, ''),
			COALESCE(c.public_id, ''), COALESCE(c.name, ''), p.auth_public_key, p.auth_public_key_fingerprint,
			COALESCE(ca.public_id, ''), p.ssh_principal_grant_ids_json, p.default_cert_valid_seconds,
			p.max_cert_valid_seconds, p.valid_after, p.valid_before, p.created_by_kind,
			COALESCE(p_cbu.public_id, p_cbp.public_id, ''), p.created_by_subject_name, p.created_at, p.updated_at,
			s.public_id, s.name, s.host, s.port, s.description, s.tags_json, s.enabled, s.host_key_policy,
			s.owner_scope, COALESCE(sou.public_id, ''), s.created_by_kind,
			COALESCE(s_cbu.public_id, s_cbp.public_id, ''), s.created_by_subject_name, s.created_at, s.updated_at
		FROM ssh_access_profiles p
		JOIN ssh_servers s ON s.id = p.server_id
		LEFT JOIN users sou ON sou.id = s.owner_user_id
		LEFT JOIN users s_cbu ON s_cbu.id = s.created_by_user_id
		LEFT JOIN service_principals s_cbp ON s_cbp.id = s.created_by_principal_id
		LEFT JOIN ssh_server_groups g ON g.id = p.server_group_id
		LEFT JOIN users g_cbu ON g_cbu.id = g.created_by_user_id
		LEFT JOIN service_principals g_cbp ON g_cbp.id = g.created_by_principal_id
		LEFT JOIN users pou ON pou.id = p.owner_user_id
		LEFT JOIN users p_cbu ON p_cbu.id = p.created_by_user_id
		LEFT JOIN service_principals p_cbp ON p_cbp.id = p.created_by_principal_id
		LEFT JOIN ssh_secrets sec ON sec.id = p.secret_id
		LEFT JOIN ssh_credentials c ON c.id = p.ssh_credential_id
		LEFT JOIN ssh_user_cas ca ON ca.id = p.ssh_user_ca_id
		ORDER BY s.name, p.name`)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var items []models.SSHAccessProfile
	for rows.Next() {
		// A zero value per row: ServerGroup.ID is only assigned when the row has
		// a group, so reusing one struct leaked the previous row's group ID into
		// group-less profiles.
		var item models.SSHAccessProfile
		var grantIDsJSON, serverTagsJSON, groupTagsJSON string
		err = rows.Scan(
			&item.ID, &item.ServerID, &item.ServerTargetType, &item.ServerGroupID, &item.ServerGroup.Name,
			&item.ServerGroup.Description, &item.ServerGroup.Enabled, &item.ServerGroup.CreatedByKind,
			&item.ServerGroup.CreatedBySubjectID, &item.ServerGroup.CreatedBySubjectName,
			&item.ServerGroup.CreatedAt, &item.ServerGroup.UpdatedAt, &groupTagsJSON,
			&item.Name, &item.Description, &item.RemoteUsername, &item.AuthMethod, &item.SecondFactorMode,
			&item.OwnerScope, &item.OwnerUserID, &item.AllowUserEdit, &item.Enabled, &item.SecretID,
			&item.SSHCredentialID, &item.SSHCredentialName, &item.AuthPublicKey, &item.AuthPublicKeyFingerprint,
			&item.SSHUserCAID, &grantIDsJSON, &item.DefaultCertValidSeconds, &item.MaxCertValidSeconds,
			&item.ValidAfter, &item.ValidBefore, &item.CreatedByKind, &item.CreatedBySubjectID,
			&item.CreatedBySubjectName, &item.CreatedAt, &item.UpdatedAt,
			&item.Server.ID, &item.Server.Name, &item.Server.Host, &item.Server.Port, &item.Server.Description,
			&serverTagsJSON, &item.Server.Enabled, &item.Server.HostKeyPolicy, &item.Server.OwnerScope,
			&item.Server.OwnerUserID, &item.Server.CreatedByKind, &item.Server.CreatedBySubjectID,
			&item.Server.CreatedBySubjectName, &item.Server.CreatedAt, &item.Server.UpdatedAt,
		)
		if err != nil {
			return nil, err
		}
		if item.ServerGroupID != "" {
			item.ServerGroup.ID = item.ServerGroupID
		}
		if item.SSHPrincipalGrantIDs, err = decodeStringList(grantIDsJSON); err != nil {
			return nil, err
		}
		if item.Server.Tags, err = decodeStringList(serverTagsJSON); err != nil {
			return nil, err
		}
		if item.ServerGroup.Tags, err = decodeStringList(groupTagsJSON); err != nil {
			return nil, err
		}
		if item.Targets, err = s.listSSHAccessProfileTargets(item.ID); err != nil {
			return nil, err
		}
		items = append(items, item)
	}
	if err = rows.Err(); err != nil {
		return nil, err
	}
	return items, nil
}

// GetSSHAccessProfile returns the access profile whose public ID equals the
// trimmed id, with the same joins and decoding as ListSSHAccessProfiles plus
// its targets. A missing profile yields sql.ErrNoRows from Row.Scan.
func (s *Store) GetSSHAccessProfile(id string) (models.SSHAccessProfile, error) {
	var item models.SSHAccessProfile
	var grantIDsJSON, serverTagsJSON, groupTagsJSON string
	err := s.QueryRow(`SELECT p.public_id, COALESCE(s.public_id, ''), COALESCE(p.server_target_type, 'server'),
			COALESCE(g.public_id, ''), COALESCE(g.name, ''), COALESCE(g.description, ''), COALESCE(g.enabled, 0),
			COALESCE(g.created_by_kind, ''), COALESCE(g_cbu.public_id, g_cbp.public_id, ''),
			COALESCE(g.created_by_subject_name, ''), COALESCE(g.created_at, 0), COALESCE(g.updated_at, 0),
			COALESCE(g.tags_json, '[]'),
			p.name, p.description, p.remote_username, p.auth_method, p.second_factor_mode, p.owner_scope,
			COALESCE(pou.public_id, ''), p.allow_user_edit, p.enabled, COALESCE(sec.public_id, ''),
			COALESCE(c.public_id, ''), COALESCE(c.name, ''), p.auth_public_key, p.auth_public_key_fingerprint,
			COALESCE(ca.public_id, ''), p.ssh_principal_grant_ids_json, p.default_cert_valid_seconds,
			p.max_cert_valid_seconds, p.valid_after, p.valid_before, p.created_by_kind,
			COALESCE(p_cbu.public_id, p_cbp.public_id, ''), p.created_by_subject_name, p.created_at, p.updated_at,
			s.public_id, s.name, s.host, s.port, s.description, s.tags_json, s.enabled, s.host_key_policy,
			s.owner_scope, COALESCE(sou.public_id, ''), s.created_by_kind,
			COALESCE(s_cbu.public_id, s_cbp.public_id, ''), s.created_by_subject_name, s.created_at, s.updated_at
		FROM ssh_access_profiles p
		JOIN ssh_servers s ON s.id = p.server_id
		LEFT JOIN users sou ON sou.id = s.owner_user_id
		LEFT JOIN users s_cbu ON s_cbu.id = s.created_by_user_id
		LEFT JOIN service_principals s_cbp ON s_cbp.id = s.created_by_principal_id
		LEFT JOIN ssh_server_groups g ON g.id = p.server_group_id
		LEFT JOIN users g_cbu ON g_cbu.id = g.created_by_user_id
		LEFT JOIN service_principals g_cbp ON g_cbp.id = g.created_by_principal_id
		LEFT JOIN users pou ON pou.id = p.owner_user_id
		LEFT JOIN users p_cbu ON p_cbu.id = p.created_by_user_id
		LEFT JOIN service_principals p_cbp ON p_cbp.id = p.created_by_principal_id
		LEFT JOIN ssh_secrets sec ON sec.id = p.secret_id
		LEFT JOIN ssh_credentials c ON c.id = p.ssh_credential_id
		LEFT JOIN ssh_user_cas ca ON ca.id = p.ssh_user_ca_id
		WHERE p.public_id = ?`, strings.TrimSpace(id)).Scan(
		&item.ID, &item.ServerID, &item.ServerTargetType, &item.ServerGroupID, &item.ServerGroup.Name,
		&item.ServerGroup.Description, &item.ServerGroup.Enabled, &item.ServerGroup.CreatedByKind,
		&item.ServerGroup.CreatedBySubjectID, &item.ServerGroup.CreatedBySubjectName,
		&item.ServerGroup.CreatedAt, &item.ServerGroup.UpdatedAt, &groupTagsJSON,
		&item.Name, &item.Description, &item.RemoteUsername, &item.AuthMethod, &item.SecondFactorMode,
		&item.OwnerScope, &item.OwnerUserID, &item.AllowUserEdit, &item.Enabled, &item.SecretID,
		&item.SSHCredentialID, &item.SSHCredentialName, &item.AuthPublicKey, &item.AuthPublicKeyFingerprint,
		&item.SSHUserCAID, &grantIDsJSON, &item.DefaultCertValidSeconds, &item.MaxCertValidSeconds,
		&item.ValidAfter, &item.ValidBefore, &item.CreatedByKind, &item.CreatedBySubjectID,
		&item.CreatedBySubjectName, &item.CreatedAt, &item.UpdatedAt,
		&item.Server.ID, &item.Server.Name, &item.Server.Host, &item.Server.Port, &item.Server.Description,
		&serverTagsJSON, &item.Server.Enabled, &item.Server.HostKeyPolicy, &item.Server.OwnerScope,
		&item.Server.OwnerUserID, &item.Server.CreatedByKind, &item.Server.CreatedBySubjectID,
		&item.Server.CreatedBySubjectName, &item.Server.CreatedAt, &item.Server.UpdatedAt,
	)
	if err != nil {
		return item, err
	}
	if item.ServerGroupID != "" {
		item.ServerGroup.ID = item.ServerGroupID
	}
	if item.SSHPrincipalGrantIDs, err = decodeStringList(grantIDsJSON); err != nil {
		return item, err
	}
	if item.Server.Tags, err = decodeStringList(serverTagsJSON); err != nil {
		return item, err
	}
	if item.ServerGroup.Tags, err = decodeStringList(groupTagsJSON); err != nil {
		return item, err
	}
	if item.Targets, err = s.listSSHAccessProfileTargets(item.ID); err != nil {
		return item, err
	}
	return item, nil
}

// CreateSSHAccessProfile validates and inserts a new access profile together
// with its targets in a single transaction, then returns the stored profile
// via GetSSHAccessProfile.
//
// Name, ServerID, RemoteUsername and AuthMethod are required. Defaults: a blank
// ID is generated; DefaultCertValidSeconds <= 0 becomes 3600;
// MaxCertValidSeconds <= 0 becomes DefaultCertValidSeconds; negative
// ValidAfter/ValidBefore become 0; an empty ServerTargetType becomes "server".
//
// Key material: when SSHCredentialID is set, SecretID and the public key fields
// are copied from that credential; otherwise a non-empty SecretPayload or
// SecretPassword is stored as a new "private_key" secret.
func (s *Store) CreateSSHAccessProfile(ctx context.Context, item models.SSHAccessProfile) (models.SSHAccessProfile, error) {
	var tx txExecutor
	var owned bool
	var err error
	var grantIDsJSON string

	if strings.TrimSpace(item.Name) == "" {
		return item, errors.New("name is required")
	}
	if strings.TrimSpace(item.ServerID) == "" {
		return item, errors.New("server_id is required")
	}
	if strings.TrimSpace(item.RemoteUsername) == "" {
		return item, errors.New("remote_username is required")
	}
	if strings.TrimSpace(item.AuthMethod) == "" {
		return item, errors.New("auth_method is required")
	}
	// Trim so the inserted ID matches the trimmed lookups done afterwards.
	item.ID = strings.TrimSpace(item.ID)
	if item.ID == "" {
		if item.ID, err = util.NewID(); err != nil {
			return item, err
		}
	}
	if item.DefaultCertValidSeconds <= 0 {
		item.DefaultCertValidSeconds = 3600
	}
	if item.MaxCertValidSeconds <= 0 {
		item.MaxCertValidSeconds = item.DefaultCertValidSeconds
	}
	if item.ValidAfter < 0 {
		item.ValidAfter = 0
	}
	if item.ValidBefore < 0 {
		item.ValidBefore = 0
	}
	// Readers only default NULL to "server"; never persist an empty string.
	item.ServerTargetType = strings.TrimSpace(item.ServerTargetType)
	if item.ServerTargetType == "" {
		item.ServerTargetType = SSHAccessProfileServerTargetTypeServer
	}
	item.SSHPrincipalGrantIDs = normalizeStringList(item.SSHPrincipalGrantIDs)
	if grantIDsJSON, err = encodeStringList(item.SSHPrincipalGrantIDs); err != nil {
		return item, err
	}
	now := time.Now().UTC().Unix()
	item.CreatedAt = now
	item.UpdatedAt = now

	// NOTE(review): the password is trimmed too, which alters passphrases with
	// leading/trailing spaces; kept for compatibility with existing secrets.
	item.SecretPayload = strings.TrimSpace(item.SecretPayload)
	item.SecretPassword = strings.TrimSpace(item.SecretPassword)
	item.SSHCredentialID = strings.TrimSpace(item.SSHCredentialID)
	// GetSSHCredential reads through s, not tx, so do it before opening the
	// write transaction rather than while holding it.
	if item.SSHCredentialID != "" {
		var credential models.SSHCredential
		if credential, err = s.GetSSHCredential(item.SSHCredentialID); err != nil {
			return item, err
		}
		item.SecretID = credential.SecretID
		item.AuthPublicKey = credential.PublicKey
		item.AuthPublicKeyFingerprint = credential.Fingerprint
	}

	tx, owned, err = s.beginContext(ctx)
	if err != nil {
		return item, err
	}
	done := false
	defer func() {
		if !done {
			rollbackIfOwned(tx, owned)
		}
	}()

	if item.SSHCredentialID == "" && (item.SecretPayload != "" || item.SecretPassword != "") {
		var secret models.SSHSecret
		secret, err = createSSHSecretTx(tx, models.SSHSecret{
			ID:                   item.SecretID,
			Kind:                 "private_key",
			Payload:              item.SecretPayload,
			Password:             item.SecretPassword,
			CreatedByKind:        item.CreatedByKind,
			CreatedBySubjectID:   item.CreatedBySubjectID,
			CreatedBySubjectName: item.CreatedBySubjectName,
		})
		if err != nil {
			return item, err
		}
		item.SecretID = secret.ID
	}

	_, err = tx.Exec(`INSERT INTO ssh_access_profiles (
			public_id, server_id, server_target_type, server_group_id, name, description, remote_username, auth_method,
			second_factor_mode, owner_scope, owner_user_id, allow_user_edit, enabled, secret_id, ssh_credential_id,
			auth_public_key, auth_public_key_fingerprint, ssh_user_ca_id, ssh_principal_grant_ids_json,
			default_cert_valid_seconds, max_cert_valid_seconds, valid_after, valid_before, created_by_kind,
			created_by_user_id, created_by_principal_id, created_by_subject_name, created_at, updated_at
		) VALUES (?, (SELECT id FROM ssh_servers WHERE public_id = ?), ?,
			CASE WHEN ? = '' THEN NULL ELSE (SELECT id FROM ssh_server_groups WHERE public_id = ?) END,
			?, ?, ?, ?, ?, ?,
			CASE WHEN ? = '' THEN NULL ELSE (SELECT id FROM users WHERE public_id = ?) END,
			?, ?,
			CASE WHEN ? = '' THEN NULL ELSE (SELECT id FROM ssh_secrets WHERE public_id = ?) END,
			CASE WHEN ? = '' THEN NULL ELSE (SELECT id FROM ssh_credentials WHERE public_id = ?) END,
			?, ?,
			CASE WHEN ? = '' THEN NULL ELSE (SELECT id FROM ssh_user_cas WHERE public_id = ?) END,
			?, ?, ?, ?, ?, ?,
			(SELECT id FROM users WHERE public_id = ? AND ? = 'user'),
			(SELECT id FROM service_principals WHERE public_id = ? AND ? = 'service_principal'),
			?, ?, ?)`,
		item.ID,
		strings.TrimSpace(item.ServerID),
		item.ServerTargetType,
		strings.TrimSpace(item.ServerGroupID), emptyStringToNil(strings.TrimSpace(item.ServerGroupID)),
		strings.TrimSpace(item.Name),
		strings.TrimSpace(item.Description),
		strings.TrimSpace(item.RemoteUsername),
		strings.TrimSpace(item.AuthMethod),
		strings.TrimSpace(item.SecondFactorMode),
		strings.TrimSpace(item.OwnerScope),
		strings.TrimSpace(item.OwnerUserID), strings.TrimSpace(item.OwnerUserID),
		item.AllowUserEdit,
		item.Enabled,
		strings.TrimSpace(item.SecretID), emptyStringToNil(strings.TrimSpace(item.SecretID)),
		item.SSHCredentialID, emptyStringToNil(item.SSHCredentialID),
		strings.TrimSpace(item.AuthPublicKey),
		strings.TrimSpace(item.AuthPublicKeyFingerprint),
		strings.TrimSpace(item.SSHUserCAID), emptyStringToNil(strings.TrimSpace(item.SSHUserCAID)),
		grantIDsJSON,
		item.DefaultCertValidSeconds,
		item.MaxCertValidSeconds,
		item.ValidAfter,
		item.ValidBefore,
		strings.TrimSpace(item.CreatedByKind),
		strings.TrimSpace(item.CreatedBySubjectID), strings.TrimSpace(item.CreatedByKind),
		strings.TrimSpace(item.CreatedBySubjectID), strings.TrimSpace(item.CreatedByKind),
		strings.TrimSpace(item.CreatedBySubjectName),
		item.CreatedAt,
		item.UpdatedAt,
	)
	if err != nil {
		return item, err
	}
	for _, target := range item.Targets {
		if err = insertSSHAccessProfileTargetTx(tx, item.ID, target.TargetType, target.TargetID, now); err != nil {
			return item, err
		}
	}

	// Mark done before committing so a failed commit is not rolled back,
	// matching the original behaviour.
	done = true
	if err = commitIfOwned(tx, owned); err != nil {
		return item, err
	}
	return s.GetSSHAccessProfile(item.ID)
}

// UpdateSSHAccessProfile rewrites the access profile identified by item.ID
// and replaces its targets, all in one transaction, then returns the stored
// profile via GetSSHAccessProfile. Validation and defaults match
// CreateSSHAccessProfile. Returns sql.ErrNoRows, without writing anything,
// when no profile has that ID.
//
// Key material: when SSHCredentialID is set, SecretID and the public key
// fields are copied from that credential. Otherwise a non-empty SecretPayload
// or SecretPassword updates the profile's own stored inline secret in place,
// or creates a new secret when the profile has none or its stored secret came
// from a credential (that secret must not be overwritten).
//
// NOTE(review): when no payload is supplied, item.SecretID is written as-is;
// if a caller round-trips a credential's SecretID after detaching the
// credential, the profile keeps pointing at the credential's secret — confirm
// callers clear SecretID in that case.
func (s *Store) UpdateSSHAccessProfile(ctx context.Context, item models.SSHAccessProfile) (models.SSHAccessProfile, error) {
	var tx txExecutor
	var owned bool
	var err error
	var grantIDsJSON string
	var currentSecretID string
	var currentCredentialID string

	if strings.TrimSpace(item.Name) == "" {
		return item, errors.New("name is required")
	}
	if strings.TrimSpace(item.ServerID) == "" {
		return item, errors.New("server_id is required")
	}
	if strings.TrimSpace(item.RemoteUsername) == "" {
		return item, errors.New("remote_username is required")
	}
	if strings.TrimSpace(item.AuthMethod) == "" {
		return item, errors.New("auth_method is required")
	}
	// Trim once so the UPDATE, target inserts and read-back all use the same ID.
	item.ID = strings.TrimSpace(item.ID)
	if item.DefaultCertValidSeconds <= 0 {
		item.DefaultCertValidSeconds = 3600
	}
	if item.MaxCertValidSeconds <= 0 {
		item.MaxCertValidSeconds = item.DefaultCertValidSeconds
	}
	if item.ValidAfter < 0 {
		item.ValidAfter = 0
	}
	if item.ValidBefore < 0 {
		item.ValidBefore = 0
	}
	item.ServerTargetType = strings.TrimSpace(item.ServerTargetType)
	if item.ServerTargetType == "" {
		item.ServerTargetType = SSHAccessProfileServerTargetTypeServer
	}
	item.SSHPrincipalGrantIDs = normalizeStringList(item.SSHPrincipalGrantIDs)
	if grantIDsJSON, err = encodeStringList(item.SSHPrincipalGrantIDs); err != nil {
		return item, err
	}
	item.UpdatedAt = time.Now().UTC().Unix()

	// NOTE(review): the password is trimmed too; see CreateSSHAccessProfile.
	item.SecretPayload = strings.TrimSpace(item.SecretPayload)
	item.SecretPassword = strings.TrimSpace(item.SecretPassword)
	item.SSHCredentialID = strings.TrimSpace(item.SSHCredentialID)
	// GetSSHCredential reads through s, not tx; resolve it before the write tx.
	if item.SSHCredentialID != "" {
		var credential models.SSHCredential
		if credential, err = s.GetSSHCredential(item.SSHCredentialID); err != nil {
			return item, err
		}
		item.SecretID = credential.SecretID
		item.AuthPublicKey = credential.PublicKey
		item.AuthPublicKeyFingerprint = credential.Fingerprint
	}

	tx, owned, err = s.beginContext(ctx)
	if err != nil {
		return item, err
	}
	done := false
	defer func() {
		if !done {
			rollbackIfOwned(tx, owned)
		}
	}()

	// Confirm the profile exists (sql.ErrNoRows otherwise) before any write,
	// and learn which secret/credential it currently references.
	err = tx.QueryRow(`SELECT COALESCE(sec.public_id, ''), COALESCE(c.public_id, '')
		FROM ssh_access_profiles p
		LEFT JOIN ssh_secrets sec ON sec.id = p.secret_id
		LEFT JOIN ssh_credentials c ON c.id = p.ssh_credential_id
		WHERE p.public_id = ?`, item.ID).Scan(&currentSecretID, &currentCredentialID)
	if err != nil {
		return item, err
	}

	if item.SSHCredentialID == "" && (item.SecretPayload != "" || item.SecretPassword != "") {
		secret := models.SSHSecret{
			Kind:                 "private_key",
			Payload:              item.SecretPayload,
			Password:             item.SecretPassword,
			CreatedByKind:        item.CreatedByKind,
			CreatedBySubjectID:   item.CreatedBySubjectID,
			CreatedBySubjectName: item.CreatedBySubjectName,
		}
		if currentSecretID != "" && currentCredentialID == "" {
			// The stored secret is this profile's own inline key: update it.
			secret.ID = currentSecretID
			secret, err = updateSSHSecretTx(tx, secret)
		} else {
			// No inline secret yet, or the stored secret belongs to the
			// previously attached credential: store the new key separately.
			secret, err = createSSHSecretTx(tx, secret)
		}
		if err != nil {
			return item, err
		}
		item.SecretID = secret.ID
	}

	_, err = tx.Exec(`UPDATE ssh_access_profiles
		SET server_id = (SELECT id FROM ssh_servers WHERE public_id = ?),
			server_target_type = ?,
			server_group_id = CASE WHEN ? = '' THEN NULL ELSE (SELECT id FROM ssh_server_groups WHERE public_id = ?) END,
			name = ?, description = ?, remote_username = ?, auth_method = ?, second_factor_mode = ?, owner_scope = ?,
			owner_user_id = CASE WHEN ? = '' THEN NULL ELSE (SELECT id FROM users WHERE public_id = ?) END,
			allow_user_edit = ?, enabled = ?,
			secret_id = CASE WHEN ? = '' THEN NULL ELSE (SELECT id FROM ssh_secrets WHERE public_id = ?) END,
			ssh_credential_id = CASE WHEN ? = '' THEN NULL ELSE (SELECT id FROM ssh_credentials WHERE public_id = ?) END,
			auth_public_key = ?, auth_public_key_fingerprint = ?,
			ssh_user_ca_id = CASE WHEN ? = '' THEN NULL ELSE (SELECT id FROM ssh_user_cas WHERE public_id = ?) END,
			ssh_principal_grant_ids_json = ?, default_cert_valid_seconds = ?, max_cert_valid_seconds = ?,
			valid_after = ?, valid_before = ?, updated_at = ?
		WHERE public_id = ?`,
		strings.TrimSpace(item.ServerID),
		item.ServerTargetType,
		strings.TrimSpace(item.ServerGroupID), emptyStringToNil(strings.TrimSpace(item.ServerGroupID)),
		strings.TrimSpace(item.Name),
		strings.TrimSpace(item.Description),
		strings.TrimSpace(item.RemoteUsername),
		strings.TrimSpace(item.AuthMethod),
		strings.TrimSpace(item.SecondFactorMode),
		strings.TrimSpace(item.OwnerScope),
		strings.TrimSpace(item.OwnerUserID), strings.TrimSpace(item.OwnerUserID),
		item.AllowUserEdit,
		item.Enabled,
		strings.TrimSpace(item.SecretID), emptyStringToNil(strings.TrimSpace(item.SecretID)),
		item.SSHCredentialID, emptyStringToNil(item.SSHCredentialID),
		strings.TrimSpace(item.AuthPublicKey),
		strings.TrimSpace(item.AuthPublicKeyFingerprint),
		strings.TrimSpace(item.SSHUserCAID), emptyStringToNil(strings.TrimSpace(item.SSHUserCAID)),
		grantIDsJSON,
		item.DefaultCertValidSeconds,
		item.MaxCertValidSeconds,
		item.ValidAfter,
		item.ValidBefore,
		item.UpdatedAt,
		item.ID,
	)
	if err != nil {
		return item, err
	}

	// Targets are replaced wholesale rather than diffed.
	_, err = tx.Exec(`DELETE FROM ssh_access_profile_targets WHERE profile_id = (SELECT id FROM ssh_access_profiles WHERE public_id = ?)`, item.ID)
	if err != nil {
		return item, err
	}
	for _, target := range item.Targets {
		if err = insertSSHAccessProfileTargetTx(tx, item.ID, target.TargetType, target.TargetID, item.UpdatedAt); err != nil {
			return item, err
		}
	}

	done = true
	if err = commitIfOwned(tx, owned); err != nil {
		return item, err
	}
	return s.GetSSHAccessProfile(item.ID)
}

// DeleteSSHAccessProfile deletes the profile with the given public ID inside a
// transaction. If the profile references a secret but no credential (an
// inline key), that secret row is deleted too; a secret that came with an
// attached credential is left alone. Returns sql.ErrNoRows when no profile has
// that ID.
//
// NOTE(review): ssh_access_profile_targets rows are not deleted here —
// presumably an ON DELETE CASCADE removes them; confirm against the schema.
func (s *Store) DeleteSSHAccessProfile(ctx context.Context, id string) error {
	var tx txExecutor
	var owned bool
	var secretID string
	var credentialID string
	var err error

	trimmedID := strings.TrimSpace(id)
	tx, owned, err = s.beginContext(ctx)
	if err != nil {
		return err
	}
	done := false
	defer func() {
		if !done {
			rollbackIfOwned(tx, owned)
		}
	}()

	// Both columns are COALESCEd to '', so plain strings are sufficient.
	err = tx.QueryRow(`SELECT COALESCE(sec.public_id, ''), COALESCE(c.public_id, '')
		FROM ssh_access_profiles p
		LEFT JOIN ssh_secrets sec ON sec.id = p.secret_id
		LEFT JOIN ssh_credentials c ON c.id = p.ssh_credential_id
		WHERE p.public_id = ?`, trimmedID).Scan(&secretID, &credentialID)
	if err != nil {
		return err
	}
	if _, err = tx.Exec(`DELETE FROM ssh_access_profiles WHERE public_id = ?`, trimmedID); err != nil {
		return err
	}
	secretID = strings.TrimSpace(secretID)
	if secretID != "" && strings.TrimSpace(credentialID) == "" {
		if _, err = tx.Exec(`DELETE FROM ssh_secrets WHERE public_id = ?`, secretID); err != nil {
			return err
		}
	}

	done = true
	return commitIfOwned(tx, owned)
}

// ListSSHAccessProfilesForUser returns the access profiles visible to userID:
// profiles the user owns (owner_scope "user"), plus "admin_shared" profiles
// that are enabled, on an enabled server, inside their valid_after/valid_before
// window (0 meaning unbounded) and targeted at the user directly or through a
// non-disabled user group that has scope 'all_users' or lists the user as a
// member. Results are ordered by server name then profile name.
func (s *Store) ListSSHAccessProfilesForUser(userID string) ([]models.SSHAccessProfile, error) {
	var rows *sql.Rows
	var err error
	var items []models.SSHAccessProfile
	var item models.SSHAccessProfile
	var grantIDsJSON string
	var serverTagsJSON string
	var groupTagsJSON string
	var trimmedUserID string
	trimmedUserID = strings.TrimSpace(userID)
	rows, err = s.Query(`SELECT DISTINCT p.public_id, COALESCE(s.public_id, ''), COALESCE(p.server_target_type, 'server'),
			COALESCE(g.public_id, ''), COALESCE(g.name, ''), COALESCE(g.description, ''), COALESCE(g.enabled, 0),
			COALESCE(g.created_by_kind, ''), COALESCE(g_cbu.public_id, g_cbp.public_id, ''),
			COALESCE(g.created_by_subject_name, ''), COALESCE(g.created_at, 0), COALESCE(g.updated_at, 0),
			COALESCE(g.tags_json, '[]'),
			p.name, p.description, p.remote_username, p.auth_method, p.second_factor_mode, p.owner_scope,
			COALESCE(pou.public_id, ''), p.allow_user_edit, p.enabled, COALESCE(sec.public_id, ''),
			COALESCE(c.public_id, ''), COALESCE(c.name, ''), p.auth_public_key, p.auth_public_key_fingerprint,
			COALESCE(ca.public_id, ''), p.ssh_principal_grant_ids_json, p.default_cert_valid_seconds,
			p.max_cert_valid_seconds, p.valid_after, p.valid_before, p.created_by_kind,
			COALESCE(p_cbu.public_id, p_cbp.public_id, ''), p.created_by_subject_name, p.created_at, p.updated_at,
			s.public_id, s.name, s.host, s.port, s.description, s.tags_json, s.enabled, s.host_key_policy,
			s.owner_scope, COALESCE(sou.public_id, ''), s.created_by_kind,
			COALESCE(s_cbu.public_id, s_cbp.public_id, ''), s.created_by_subject_name, s.created_at, s.updated_at
		FROM ssh_access_profiles p
		JOIN ssh_servers s ON s.id = p.server_id
		LEFT JOIN users sou ON sou.id = s.owner_user_id
		LEFT JOIN users s_cbu ON s_cbu.id = s.created_by_user_id
		LEFT JOIN service_principals s_cbp ON s_cbp.id = s.created_by_principal_id
		LEFT JOIN ssh_server_groups g ON g.id = p.server_group_id
		LEFT JOIN users g_cbu ON g_cbu.id = g.created_by_user_id
		LEFT JOIN service_principals g_cbp ON g_cbp.id = g.created_by_principal_id
		LEFT JOIN users pou ON pou.id = p.owner_user_id
		LEFT JOIN users p_cbu ON p_cbu.id = p.created_by_user_id
		LEFT JOIN service_principals p_cbp ON p_cbp.id = p.created_by_principal_id
		LEFT JOIN ssh_secrets sec ON sec.id = p.secret_id
		LEFT JOIN ssh_credentials c ON c.id = p.ssh_credential_id
		LEFT JOIN ssh_user_cas ca ON ca.id = p.ssh_user_ca_id
		LEFT JOIN ssh_access_profile_targets t ON t.profile_id = p.id
		LEFT JOIN user_groups ug ON t.target_type = 'group' AND ug.id = t.target_id
		LEFT JOIN user_group_members gm ON t.target_type = 'group' AND gm.group_id = ug.id
		LEFT JOIN users gu ON gm.user_id = gu.id
		LEFT JOIN users tu ON t.target_type = 'user' AND tu.id = t.target_id
		WHERE (
			(p.owner_scope = ? AND pou.public_id = ?)
			OR (
				p.owner_scope = ? AND p.enabled = 1 AND s.enabled = 1
				AND (p.valid_after = 0 OR p.valid_after <= CAST(strftime('%s','now') AS INTEGER))
				AND (p.valid_before = 0 OR p.valid_before >= CAST(strftime('%s','now') AS INTEGER))
				AND (
					(t.target_type = 'user' AND tu.public_id = ?)
					OR (t.target_type = 'group' AND ug.disabled = 0 AND (ug.scope = 'all_users' OR gu.public_id = ?))
				)
			)
		)
		ORDER BY s.name, p.name`, SSHOwnerScopeUser, trimmedUserID, SSHOwnerScopeAdminShared, trimmedUserID, trimmedUserID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	for rows.Next() {
		// Reset per row: ServerGroup.ID is assigned only when the row has a
		// group, so a reused struct would carry the previous row's group ID.
		item = models.SSHAccessProfile{}
		err = rows.Scan(
			&item.ID, &item.ServerID, &item.ServerTargetType, &item.ServerGroupID, &item.ServerGroup.Name,
			&item.ServerGroup.Description, &item.ServerGroup.Enabled, &item.ServerGroup.CreatedByKind,
			&item.ServerGroup.CreatedBySubjectID, &item.ServerGroup.CreatedBySubjectName,
			&item.ServerGroup.CreatedAt, &item.ServerGroup.UpdatedAt, &groupTagsJSON,
			&item.Name, &item.Description, &item.RemoteUsername, &item.AuthMethod, &item.SecondFactorMode,
			&item.OwnerScope, &item.OwnerUserID, &item.AllowUserEdit, &item.Enabled, &item.SecretID,
			&item.SSHCredentialID, &item.SSHCredentialName, &item.AuthPublicKey, &item.AuthPublicKeyFingerprint,
			&item.SSHUserCAID, &grantIDsJSON, &item.DefaultCertValidSeconds, &item.MaxCertValidSeconds,
			&item.ValidAfter, &item.ValidBefore, &item.CreatedByKind, &item.CreatedBySubjectID,
			&item.CreatedBySubjectName, &item.CreatedAt, &item.UpdatedAt,
			&item.Server.ID, &item.Server.Name, &item.Server.Host, &item.Server.Port, &item.Server.Description,
			&serverTagsJSON, &item.Server.Enabled, &item.Server.HostKeyPolicy, &item.Server.OwnerScope,
			&item.Server.OwnerUserID, &item.Server.CreatedByKind, &item.Server.CreatedBySubjectID,
			&item.Server.CreatedBySubjectName, &item.Server.CreatedAt, &item.Server.UpdatedAt,
		)
		if err != nil {
			return nil, err
		}
		if item.ServerGroupID != "" {
			item.ServerGroup.ID = item.ServerGroupID
		}
		item.SSHPrincipalGrantIDs, err = decodeStringList(grantIDsJSON)
		if err != nil {
			return nil, err
		}
		item.Server.Tags, err = decodeStringList(serverTagsJSON)
		if err != nil {
			return nil, err
		}
		item.ServerGroup.Tags, err = decodeStringList(groupTagsJSON)
		if err != nil {
			return nil, err
		}
		item.Targets, err = s.listSSHAccessProfileTargets(item.ID)
		if err != nil {
return nil, err } item.Editable = item.OwnerScope == SSHOwnerScopeUser && item.OwnerUserID == trimmedUserID items = append(items, item) } err = rows.Err() if err != nil { return nil, err } return items, nil } func (s *Store) GetSSHAccessProfileForUser(userID string, id string) (models.SSHAccessProfile, error) { var row *sql.Row var item models.SSHAccessProfile var grantIDsJSON string var serverTagsJSON string var groupTagsJSON string var trimmedUserID string var trimmedID string var err error // this function gates the profiles with s.enabled=1. // if we want all profiles, we should remove the condition trimmedUserID = strings.TrimSpace(userID) trimmedID = strings.TrimSpace(id) row = s.QueryRow(`SELECT DISTINCT p.public_id, COALESCE(s.public_id, ''), COALESCE(p.server_target_type, 'server'), COALESCE(g.public_id, ''), COALESCE(g.name, ''), COALESCE(g.description, ''), COALESCE(g.enabled, 0), COALESCE(g.created_by_kind, ''), COALESCE(g_cbu.public_id, g_cbp.public_id, ''), COALESCE(g.created_by_subject_name, ''), COALESCE(g.created_at, 0), COALESCE(g.updated_at, 0), COALESCE(g.tags_json, '[]'), p.name, p.description, p.remote_username, p.auth_method, p.second_factor_mode, p.owner_scope, COALESCE(pou.public_id, ''), p.allow_user_edit, p.enabled, COALESCE(sec.public_id, ''), COALESCE(c.public_id, ''), COALESCE(c.name, ''), p.auth_public_key, p.auth_public_key_fingerprint, COALESCE(ca.public_id, ''), p.ssh_principal_grant_ids_json, p.default_cert_valid_seconds, p.max_cert_valid_seconds, p.valid_after, p.valid_before, p.created_by_kind, COALESCE(p_cbu.public_id, p_cbp.public_id, ''), p.created_by_subject_name, p.created_at, p.updated_at, s.public_id, s.name, s.host, s.port, s.description, s.tags_json, s.enabled, s.host_key_policy, s.owner_scope, COALESCE(sou.public_id, ''), s.created_by_kind, COALESCE(s_cbu.public_id, s_cbp.public_id, ''), s.created_by_subject_name, s.created_at, s.updated_at FROM ssh_access_profiles p JOIN ssh_servers s ON s.id = p.server_id LEFT JOIN 
users sou ON sou.id = s.owner_user_id LEFT JOIN users s_cbu ON s_cbu.id = s.created_by_user_id LEFT JOIN service_principals s_cbp ON s_cbp.id = s.created_by_principal_id LEFT JOIN ssh_server_groups g ON g.id = p.server_group_id LEFT JOIN users g_cbu ON g_cbu.id = g.created_by_user_id LEFT JOIN service_principals g_cbp ON g_cbp.id = g.created_by_principal_id LEFT JOIN users pou ON pou.id = p.owner_user_id LEFT JOIN users p_cbu ON p_cbu.id = p.created_by_user_id LEFT JOIN service_principals p_cbp ON p_cbp.id = p.created_by_principal_id LEFT JOIN ssh_secrets sec ON sec.id = p.secret_id LEFT JOIN ssh_credentials c ON c.id = p.ssh_credential_id LEFT JOIN ssh_user_cas ca ON ca.id = p.ssh_user_ca_id LEFT JOIN ssh_access_profile_targets t ON t.profile_id = p.id LEFT JOIN user_groups ug ON t.target_type = 'group' AND ug.id = t.target_id LEFT JOIN user_group_members gm ON t.target_type = 'group' AND gm.group_id = ug.id LEFT JOIN users gu ON gm.user_id = gu.id LEFT JOIN users tu ON t.target_type = 'user' AND tu.id = t.target_id WHERE p.public_id = ? AND ( (p.owner_scope = ? AND pou.public_id = ?) OR ( p.owner_scope = ? AND p.enabled = 1 AND s.enabled = 1 AND (p.valid_after = 0 OR p.valid_after <= CAST(strftime('%s','now') AS INTEGER)) AND (p.valid_before = 0 OR p.valid_before >= CAST(strftime('%s','now') AS INTEGER)) AND ( (t.target_type = 'user' AND tu.public_id = ?) 
OR (t.target_type = 'group' AND ug.disabled = 0 AND (ug.scope = 'all_users' OR gu.public_id = ?)) ) ) ) LIMIT 1`, trimmedID, SSHOwnerScopeUser, trimmedUserID, SSHOwnerScopeAdminShared, trimmedUserID, trimmedUserID) err = row.Scan( &item.ID, &item.ServerID, &item.ServerTargetType, &item.ServerGroupID, &item.ServerGroup.Name, &item.ServerGroup.Description, &item.ServerGroup.Enabled, &item.ServerGroup.CreatedByKind, &item.ServerGroup.CreatedBySubjectID, &item.ServerGroup.CreatedBySubjectName, &item.ServerGroup.CreatedAt, &item.ServerGroup.UpdatedAt, &groupTagsJSON, &item.Name, &item.Description, &item.RemoteUsername, &item.AuthMethod, &item.SecondFactorMode, &item.OwnerScope, &item.OwnerUserID, &item.AllowUserEdit, &item.Enabled, &item.SecretID, &item.SSHCredentialID, &item.SSHCredentialName, &item.AuthPublicKey, &item.AuthPublicKeyFingerprint, &item.SSHUserCAID, &grantIDsJSON, &item.DefaultCertValidSeconds, &item.MaxCertValidSeconds, &item.ValidAfter, &item.ValidBefore, &item.CreatedByKind, &item.CreatedBySubjectID, &item.CreatedBySubjectName, &item.CreatedAt, &item.UpdatedAt, &item.Server.ID, &item.Server.Name, &item.Server.Host, &item.Server.Port, &item.Server.Description, &serverTagsJSON, &item.Server.Enabled, &item.Server.HostKeyPolicy, &item.Server.OwnerScope, &item.Server.OwnerUserID, &item.Server.CreatedByKind, &item.Server.CreatedBySubjectID, &item.Server.CreatedBySubjectName, &item.Server.CreatedAt, &item.Server.UpdatedAt, ) if err != nil { return item, err } if item.ServerGroupID != "" { item.ServerGroup.ID = item.ServerGroupID } item.SSHPrincipalGrantIDs, err = decodeStringList(grantIDsJSON) if err != nil { return item, err } item.Server.Tags, err = decodeStringList(serverTagsJSON) if err != nil { return item, err } item.Targets, err = s.listSSHAccessProfileTargets(item.ID) if err != nil { return item, err } item.Editable = item.OwnerScope == SSHOwnerScopeUser && item.OwnerUserID == trimmedUserID return item, nil } func (s *Store) 
GetSSHAccessProfileServerForUser(userID string, id string, serverID string) (models.SSHServer, error) { var profile models.SSHAccessProfile var item models.SSHServer var trimmedServerID string var err error // the ssh access profile must be found(id) // the ssh access profile must be visible to the user(userID) // the access profile must be bound to a server (serverID) profile, err = s.GetSSHAccessProfileForUser(userID, id) if err != nil { return item, err } trimmedServerID = strings.TrimSpace(serverID) if profile.ServerTargetType != SSHAccessProfileServerTargetTypeServer || profile.ServerID != trimmedServerID { return item, sql.ErrNoRows } return profile.Server, nil } func (s *Store) ListSSHServersForAccessProfileServerGroupForUser(userID string, id string, serverGroupID string) ([]models.SSHServer, error) { var profile models.SSHAccessProfile var trimmedGroupID string var err error // the ssh access profile must be found(id) // the ssh access profile must be visible to the user(userID) // the access profile must be bound to a server group (serverGroupID) profile, err = s.GetSSHAccessProfileForUser(userID, id) if err != nil { return nil, err } trimmedGroupID = strings.TrimSpace(serverGroupID) if profile.ServerTargetType != SSHAccessProfileServerTargetTypeGroup || profile.ServerGroupID != trimmedGroupID { return nil, sql.ErrNoRows } return s.ListSSHServersForGroup(trimmedGroupID) } func (s *Store) listSSHAccessProfileTargets(profileID string) ([]models.SSHAccessProfileTarget, error) { var rows *sql.Rows var err error var items []models.SSHAccessProfileTarget var item models.SSHAccessProfileTarget rows, err = s.Query(`SELECT p.public_id, t.target_type, CASE WHEN t.target_type = 'user' THEN COALESCE(u.public_id, '') ELSE COALESCE(g.public_id, '') END, CASE WHEN t.target_type = 'user' THEN COALESCE(CASE WHEN u.display_name != '' THEN u.display_name || ' (' || u.username || ')' ELSE u.username END, '') ELSE COALESCE(g.name, '') END, CASE WHEN t.target_type = 'user' 
THEN CASE WHEN u.public_id IS NOT NULL AND u.disabled = 0 THEN 1 ELSE 0 END ELSE CASE WHEN g.public_id IS NOT NULL AND g.disabled = 0 THEN 1 ELSE 0 END END, t.created_at FROM ssh_access_profile_targets t JOIN ssh_access_profiles p ON p.id = t.profile_id LEFT JOIN users u ON t.target_type = 'user' AND u.id = t.target_id LEFT JOIN user_groups g ON t.target_type = 'group' AND g.id = t.target_id WHERE p.public_id = ? ORDER BY t.target_type, target_id`, strings.TrimSpace(profileID)) if err != nil { return nil, err } defer rows.Close() for rows.Next() { err = rows.Scan(&item.ProfileID, &item.TargetType, &item.TargetID, &item.TargetName, &item.TargetActive, &item.CreatedAt) if err != nil { return nil, err } items = append(items, item) } err = rows.Err() if err != nil { return nil, err } return items, nil } func (s *Store) GetSSHSecret(id string) (models.SSHSecret, error) { var row *sql.Row var item models.SSHSecret var err error row = s.QueryRow(`SELECT sec.public_id, sec.kind, sec.payload, sec.password, sec.metadata_json, sec.created_by_kind, COALESCE(sec_cbu.public_id, sec_cbp.public_id, ''), sec.created_by_subject_name, sec.created_at, sec.updated_at FROM ssh_secrets sec LEFT JOIN users sec_cbu ON sec_cbu.id = sec.created_by_user_id LEFT JOIN service_principals sec_cbp ON sec_cbp.id = sec.created_by_principal_id WHERE sec.public_id = ?`, strings.TrimSpace(id)) err = row.Scan(&item.ID, &item.Kind, &item.Payload, &item.Password, &item.MetadataJSON, &item.CreatedByKind, &item.CreatedBySubjectID, &item.CreatedBySubjectName, &item.CreatedAt, &item.UpdatedAt) if err != nil { return item, err } return item, nil } func (s *Store) ListSSHServerHostKeys(serverID string) ([]models.SSHServerHostKey, error) { var rows *sql.Rows var items []models.SSHServerHostKey var item models.SSHServerHostKey var err error rows, err = s.Query(`SELECT hk.public_id, s.public_id, hk.algorithm, hk.public_key, hk.fingerprint, hk.created_at FROM ssh_server_host_keys hk JOIN ssh_servers s ON s.id = 
hk.server_id WHERE s.public_id = ? ORDER BY hk.created_at, hk.fingerprint`, strings.TrimSpace(serverID)) if err != nil { return nil, err } defer rows.Close() for rows.Next() { err = rows.Scan(&item.ID, &item.ServerID, &item.Algorithm, &item.PublicKey, &item.Fingerprint, &item.CreatedAt) if err != nil { return nil, err } items = append(items, item) } err = rows.Err() if err != nil { return nil, err } return items, nil } func (s *Store) CreateSSHServerHostKey(item models.SSHServerHostKey) (models.SSHServerHostKey, error) { var err error var now int64 if strings.TrimSpace(item.ID) == "" { item.ID, err = util.NewID() if err != nil { return item, err } } now = time.Now().UTC().Unix() item.CreatedAt = now _, err = s.Exec(`INSERT INTO ssh_server_host_keys (public_id, server_id, algorithm, public_key, fingerprint, created_at) VALUES (?, (SELECT id FROM ssh_servers WHERE public_id = ?), ?, ?, ?, ?)`, item.ID, strings.TrimSpace(item.ServerID), strings.TrimSpace(item.Algorithm), strings.TrimSpace(item.PublicKey), strings.TrimSpace(item.Fingerprint), item.CreatedAt, ) if err != nil { return item, err } return item, nil } func (s *Store) DeleteSSHServerHostKey(serverID string, hostKeyID string) error { var err error _, err = s.Exec(`DELETE FROM ssh_server_host_keys WHERE server_id = (SELECT id FROM ssh_servers WHERE public_id = ?) 
AND public_id = ?`, strings.TrimSpace(serverID), strings.TrimSpace(hostKeyID)) return err } func (s *Store) CreateSSHSession(item models.SSHSession) (models.SSHSession, error) { var err error var now int64 if strings.TrimSpace(item.ID) == "" { item.ID, err = util.NewID() if err != nil { return item, err } } now = time.Now().UTC().Unix() if item.Status == "" { item.Status = "pending" } if item.RequestedTerm == "" { item.RequestedTerm = "xterm-256color" } if item.RequestedCols <= 0 { item.RequestedCols = 80 } if item.RequestedRows <= 0 { item.RequestedRows = 24 } item.StartedAt = now _, err = s.Exec(`INSERT INTO ssh_sessions ( public_id, profile_id, server_id, user_id, username, remote_username, auth_method, second_factor_mode, host, port, status, host_key_fingerprint, requested_term, requested_cols, requested_rows, started_at, connected_at, ended_at, remote_addr, user_agent, error ) VALUES (?, (SELECT id FROM ssh_access_profiles WHERE public_id = ?), (SELECT id FROM ssh_servers WHERE public_id = ?), (SELECT id FROM users WHERE public_id = ?), ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, item.ID, strings.TrimSpace(item.ProfileID), strings.TrimSpace(item.ServerID), strings.TrimSpace(item.UserID), strings.TrimSpace(item.Username), strings.TrimSpace(item.RemoteUsername), strings.TrimSpace(item.AuthMethod), strings.TrimSpace(item.SecondFactorMode), strings.TrimSpace(item.Host), item.Port, strings.TrimSpace(item.Status), strings.TrimSpace(item.HostKeyFingerprint), strings.TrimSpace(item.RequestedTerm), item.RequestedCols, item.RequestedRows, item.StartedAt, item.ConnectedAt, item.EndedAt, strings.TrimSpace(item.RemoteAddr), strings.TrimSpace(item.UserAgent), strings.TrimSpace(item.Error), ) if err != nil { return item, err } return item, nil } func (s *Store) ExpireStalePendingSSHSessions(cutoff int64, reason string) error { var err error _, err = s.Exec(`UPDATE ssh_sessions SET status = 'closed', ended_at = ?, error = ? 
WHERE status = 'pending' AND started_at < ?`, cutoff, strings.TrimSpace(reason), cutoff, ) return err } func (s *Store) CountPendingSSHSessionsForUser(userID string) (int, error) { var row *sql.Row var count int var err error row = s.QueryRow(`SELECT COUNT(*) FROM ssh_sessions WHERE user_id = (SELECT id FROM users WHERE public_id = ?) AND status = 'pending'`, strings.TrimSpace(userID)) err = row.Scan(&count) if err != nil { return 0, err } return count, nil } func (s *Store) GetSSHSession(id string) (models.SSHSession, error) { var row *sql.Row var item models.SSHSession var err error row = s.QueryRow(`SELECT ss.public_id, COALESCE(p.public_id, ''), COALESCE(p.name, ''), COALESCE(s.public_id, ''), COALESCE(s.name, ''), COALESCE(u.public_id, ''), ss.username, ss.remote_username, ss.auth_method, ss.second_factor_mode, ss.host, ss.port, ss.status, ss.host_key_fingerprint, ss.requested_term, ss.requested_cols, ss.requested_rows, ss.started_at, ss.connected_at, ss.ended_at, ss.remote_addr, ss.user_agent, ss.error FROM ssh_sessions ss LEFT JOIN ssh_access_profiles p ON p.id = ss.profile_id LEFT JOIN ssh_servers s ON s.id = ss.server_id LEFT JOIN users u ON u.id = ss.user_id WHERE ss.public_id = ?`, strings.TrimSpace(id)) err = row.Scan(&item.ID, &item.ProfileID, &item.ProfileName, &item.ServerID, &item.ServerName, &item.UserID, &item.Username, &item.RemoteUsername, &item.AuthMethod, &item.SecondFactorMode, &item.Host, &item.Port, &item.Status, &item.HostKeyFingerprint, &item.RequestedTerm, &item.RequestedCols, &item.RequestedRows, &item.StartedAt, &item.ConnectedAt, &item.EndedAt, &item.RemoteAddr, &item.UserAgent, &item.Error) if err != nil { return item, err } return item, nil } func (s *Store) ListSSHSessionsForUser(userID string, limit int) ([]models.SSHSession, error) { var items []models.SSHSession var hasMore bool var err error items, hasMore, err = s.ListSSHSessionsForUserFiltered(userID, limit, 0, "", "") if err != nil { return nil, err } _ = hasMore return 
items, nil } func (s *Store) ListSSHSessionsForUserFiltered(userID string, limit int, offset int, query string, status string) ([]models.SSHSession, bool, error) { var rows *sql.Rows var items []models.SSHSession var item models.SSHSession var whereParts []string var args []any var sqlQuery string var like string var err error if limit <= 0 { limit = 20 } if limit > 200 { limit = 200 } if offset < 0 { offset = 0 } query = strings.TrimSpace(query) status = strings.ToLower(strings.TrimSpace(status)) whereParts = append(whereParts, "u.public_id = ?") args = append(args, strings.TrimSpace(userID)) if query != "" { like = "%" + query + "%" whereParts = append(whereParts, "(ss.public_id LIKE ? OR COALESCE(p.public_id, '') LIKE ? OR COALESCE(p.name, '') LIKE ? OR COALESCE(s.public_id, '') LIKE ? OR COALESCE(s.name, '') LIKE ? OR COALESCE(u.public_id, '') LIKE ? OR ss.username LIKE ? OR ss.remote_username LIKE ? OR ss.host LIKE ? OR ss.auth_method LIKE ? OR ss.status LIKE ? OR ss.host_key_fingerprint LIKE ? OR ss.error LIKE ?)") args = append(args, like, like, like, like, like, like, like, like, like, like, like, like, like) } if status != "" { whereParts = append(whereParts, "ss.status = ?") args = append(args, status) } sqlQuery = fmt.Sprintf(`SELECT ss.public_id, COALESCE(p.public_id, ''), COALESCE(p.name, ''), COALESCE(s.public_id, ''), COALESCE(s.name, ''), COALESCE(u.public_id, ''), ss.username, ss.remote_username, ss.auth_method, ss.second_factor_mode, ss.host, ss.port, ss.status, ss.host_key_fingerprint, ss.requested_term, ss.requested_cols, ss.requested_rows, ss.started_at, ss.connected_at, ss.ended_at, ss.remote_addr, ss.user_agent, ss.error FROM ssh_sessions ss LEFT JOIN ssh_access_profiles p ON p.id = ss.profile_id LEFT JOIN ssh_servers s ON s.id = ss.server_id LEFT JOIN users u ON u.id = ss.user_id WHERE %s ORDER BY ss.started_at DESC LIMIT ? 
OFFSET ?`, strings.Join(whereParts, " AND ")) args = append(args, limit + 1, offset) rows, err = s.Query(sqlQuery, args...) if err != nil { return nil, false, err } defer rows.Close() for rows.Next() { err = rows.Scan(&item.ID, &item.ProfileID, &item.ProfileName, &item.ServerID, &item.ServerName, &item.UserID, &item.Username, &item.RemoteUsername, &item.AuthMethod, &item.SecondFactorMode, &item.Host, &item.Port, &item.Status, &item.HostKeyFingerprint, &item.RequestedTerm, &item.RequestedCols, &item.RequestedRows, &item.StartedAt, &item.ConnectedAt, &item.EndedAt, &item.RemoteAddr, &item.UserAgent, &item.Error) if err != nil { return nil, false, err } items = append(items, item) } err = rows.Err() if err != nil { return nil, false, err } if len(items) > limit { items = items[:limit] return items, true, nil } return items, false, nil } func (s *Store) CountSSHSessionsForUserFiltered(userID string, query string, status string) (int, error) { var whereParts []string var args []any var sqlQuery string var like string var total int var err error query = strings.TrimSpace(query) status = strings.ToLower(strings.TrimSpace(status)) whereParts = append(whereParts, "u.public_id = ?") args = append(args, strings.TrimSpace(userID)) if query != "" { like = "%" + query + "%" whereParts = append(whereParts, "(ss.public_id LIKE ? OR COALESCE(p.public_id, '') LIKE ? OR COALESCE(p.name, '') LIKE ? OR COALESCE(s.public_id, '') LIKE ? OR COALESCE(s.name, '') LIKE ? OR COALESCE(u.public_id, '') LIKE ? OR ss.username LIKE ? OR ss.remote_username LIKE ? OR ss.host LIKE ? OR ss.auth_method LIKE ? OR ss.status LIKE ? OR ss.host_key_fingerprint LIKE ? 
OR ss.error LIKE ?)") args = append(args, like, like, like, like, like, like, like, like, like, like, like, like, like) } if status != "" { whereParts = append(whereParts, "ss.status = ?") args = append(args, status) } sqlQuery = fmt.Sprintf(`SELECT COUNT(*) FROM ssh_sessions ss LEFT JOIN ssh_access_profiles p ON p.id = ss.profile_id LEFT JOIN ssh_servers s ON s.id = ss.server_id LEFT JOIN users u ON u.id = ss.user_id WHERE %s`, strings.Join(whereParts, " AND ")) err = s.QueryRow(sqlQuery, args...).Scan(&total) if err != nil { return 0, err } return total, nil } func (s *Store) ListSSHSessionsFiltered(limit int, offset int, query string, status string) ([]models.SSHSession, bool, error) { var rows *sql.Rows var items []models.SSHSession var item models.SSHSession var whereParts []string var args []any var sqlQuery string var like string var err error if limit <= 0 { limit = 20 } if limit > 200 { limit = 200 } if offset < 0 { offset = 0 } query = strings.TrimSpace(query) status = strings.ToLower(strings.TrimSpace(status)) if query != "" { like = "%" + query + "%" whereParts = append(whereParts, "(ss.public_id LIKE ? OR COALESCE(p.public_id, '') LIKE ? OR COALESCE(p.name, '') LIKE ? OR COALESCE(s.public_id, '') LIKE ? OR COALESCE(s.name, '') LIKE ? OR COALESCE(u.public_id, '') LIKE ? OR ss.username LIKE ? OR ss.remote_username LIKE ? OR ss.host LIKE ? OR ss.auth_method LIKE ? OR ss.status LIKE ? OR ss.host_key_fingerprint LIKE ? 
OR ss.error LIKE ?)") args = append(args, like, like, like, like, like, like, like, like, like, like, like, like, like) } if status != "" { whereParts = append(whereParts, "ss.status = ?") args = append(args, status) } sqlQuery = `SELECT ss.public_id, COALESCE(p.public_id, ''), COALESCE(p.name, ''), COALESCE(s.public_id, ''), COALESCE(s.name, ''), COALESCE(u.public_id, ''), ss.username, ss.remote_username, ss.auth_method, ss.second_factor_mode, ss.host, ss.port, ss.status, ss.host_key_fingerprint, ss.requested_term, ss.requested_cols, ss.requested_rows, ss.started_at, ss.connected_at, ss.ended_at, ss.remote_addr, ss.user_agent, ss.error FROM ssh_sessions ss LEFT JOIN ssh_access_profiles p ON p.id = ss.profile_id LEFT JOIN ssh_servers s ON s.id = ss.server_id LEFT JOIN users u ON u.id = ss.user_id` if len(whereParts) > 0 { sqlQuery += "\n\t\tWHERE " + strings.Join(whereParts, " AND ") } sqlQuery += "\n\t\tORDER BY ss.started_at DESC\n\t\tLIMIT ? OFFSET ?" args = append(args, limit + 1, offset) rows, err = s.Query(sqlQuery, args...) 
if err != nil { return nil, false, err } defer rows.Close() for rows.Next() { err = rows.Scan(&item.ID, &item.ProfileID, &item.ProfileName, &item.ServerID, &item.ServerName, &item.UserID, &item.Username, &item.RemoteUsername, &item.AuthMethod, &item.SecondFactorMode, &item.Host, &item.Port, &item.Status, &item.HostKeyFingerprint, &item.RequestedTerm, &item.RequestedCols, &item.RequestedRows, &item.StartedAt, &item.ConnectedAt, &item.EndedAt, &item.RemoteAddr, &item.UserAgent, &item.Error) if err != nil { return nil, false, err } items = append(items, item) } err = rows.Err() if err != nil { return nil, false, err } if len(items) > limit { items = items[:limit] return items, true, nil } return items, false, nil } func (s *Store) CountSSHSessionsFiltered(query string, status string) (int, error) { var whereParts []string var args []any var sqlQuery string var like string var total int var err error query = strings.TrimSpace(query) status = strings.ToLower(strings.TrimSpace(status)) if query != "" { like = "%" + query + "%" whereParts = append(whereParts, "(ss.public_id LIKE ? OR COALESCE(p.public_id, '') LIKE ? OR COALESCE(p.name, '') LIKE ? OR COALESCE(s.public_id, '') LIKE ? OR COALESCE(s.name, '') LIKE ? OR COALESCE(u.public_id, '') LIKE ? OR ss.username LIKE ? OR ss.remote_username LIKE ? OR ss.host LIKE ? OR ss.auth_method LIKE ? OR ss.status LIKE ? OR ss.host_key_fingerprint LIKE ? 
OR ss.error LIKE ?)") args = append(args, like, like, like, like, like, like, like, like, like, like, like, like, like) } if status != "" { whereParts = append(whereParts, "ss.status = ?") args = append(args, status) } sqlQuery = `SELECT COUNT(*) FROM ssh_sessions ss LEFT JOIN ssh_access_profiles p ON p.id = ss.profile_id LEFT JOIN ssh_servers s ON s.id = ss.server_id LEFT JOIN users u ON u.id = ss.user_id` if len(whereParts) > 0 { sqlQuery += "\n\t\tWHERE " + strings.Join(whereParts, " AND ") } err = s.QueryRow(sqlQuery, args...).Scan(&total) if err != nil { return 0, err } return total, nil } func (s *Store) UpdateSSHSessionStatus(id string, status string, hostKeyFingerprint string, connectedAt int64, endedAt int64, errorText string) error { var err error _, err = s.Exec(`UPDATE ssh_sessions SET status = ?, host_key_fingerprint = ?, connected_at = ?, ended_at = ?, error = ? WHERE public_id = ?`, strings.TrimSpace(status), strings.TrimSpace(hostKeyFingerprint), connectedAt, endedAt, strings.TrimSpace(errorText), strings.TrimSpace(id), ) return err } func (s *Store) CreateSSHFileTransfer(item models.SSHFileTransfer) (models.SSHFileTransfer, error) { var pathsJSON string var now int64 var err error if strings.TrimSpace(item.ID) == "" { item.ID, err = util.NewID() if err != nil { return item, err } } pathsJSON, err = encodeStringList(item.Paths) if err != nil { return item, err } now = time.Now().UTC().Unix() if item.StartedAt <= 0 { item.StartedAt = now } if strings.TrimSpace(item.Status) == "" { item.Status = "running" } _, err = s.Exec(`INSERT INTO ssh_file_transfers ( public_id, session_id, user_id, username, profile_id, server_id, server_name, remote_username, operation, source_session_id, target_session_id, target_server_id, target_server_name, target_dir, paths_json, bytes_transferred, status, error, started_at, finished_at, duration_ms, remote_addr, user_agent ) VALUES (?, (SELECT id FROM ssh_sessions WHERE public_id = ?), (SELECT id FROM users WHERE 
public_id = ?), ?, (SELECT id FROM ssh_access_profiles WHERE public_id = ?), (SELECT id FROM ssh_servers WHERE public_id = ?), ?, ?, ?, CASE WHEN ? = '' THEN NULL ELSE (SELECT id FROM ssh_sessions WHERE public_id = ?) END, CASE WHEN ? = '' THEN NULL ELSE (SELECT id FROM ssh_sessions WHERE public_id = ?) END, CASE WHEN ? = '' THEN NULL ELSE (SELECT id FROM ssh_servers WHERE public_id = ?) END, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, strings.TrimSpace(item.ID), strings.TrimSpace(item.SessionID), strings.TrimSpace(item.UserID), strings.TrimSpace(item.Username), strings.TrimSpace(item.ProfileID), strings.TrimSpace(item.ServerID), strings.TrimSpace(item.ServerName), strings.TrimSpace(item.RemoteUsername), strings.TrimSpace(item.Operation), strings.TrimSpace(item.SourceSessionID), strings.TrimSpace(item.SourceSessionID), strings.TrimSpace(item.TargetSessionID), strings.TrimSpace(item.TargetSessionID), strings.TrimSpace(item.TargetServerID), strings.TrimSpace(item.TargetServerID), strings.TrimSpace(item.TargetServerName), strings.TrimSpace(item.TargetDir), pathsJSON, item.BytesTransferred, strings.TrimSpace(item.Status), strings.TrimSpace(item.Error), item.StartedAt, item.FinishedAt, item.DurationMS, strings.TrimSpace(item.RemoteAddr), strings.TrimSpace(item.UserAgent), ) return item, err } func (s *Store) FinishSSHFileTransfer(ctx context.Context, id string, status string, paths []string, bytesTransferred int64, errText string) error { var startedAt int64 var finishedAt int64 var durationMS int64 var pathsJSON string var err error var tx txExecutor var owned bool pathsJSON, err = encodeStringList(paths) if err != nil { return err } finishedAt = time.Now().UTC().Unix() tx, owned, err = s.beginImmediateContext(ctx) if err != nil { return err } err = tx.QueryRow(`SELECT started_at FROM ssh_file_transfers WHERE public_id = ?`, strings.TrimSpace(id)).Scan(&startedAt) if err != nil { rollbackIfOwned(tx, owned) return err } durationMS = 0 if startedAt > 0 && finishedAt >= startedAt 
{ durationMS = (finishedAt - startedAt) * 1000 } _, err = tx.Exec(`UPDATE ssh_file_transfers SET status = ?, paths_json = ?, bytes_transferred = ?, error = ?, finished_at = ?, duration_ms = ? WHERE public_id = ?`, strings.TrimSpace(status), pathsJSON, bytesTransferred, strings.TrimSpace(errText), finishedAt, durationMS, strings.TrimSpace(id), ) if err != nil { rollbackIfOwned(tx, owned) return err } return commitIfOwned(tx, owned) } func scanSSHFileTransferRows(rows *sql.Rows, limit int) ([]models.SSHFileTransfer, bool, error) { var items []models.SSHFileTransfer var item models.SSHFileTransfer var pathsJSON string var err error for rows.Next() { err = rows.Scan( &item.ID, &item.SessionID, &item.UserID, &item.Username, &item.ProfileID, &item.ServerID, &item.ServerName, &item.RemoteUsername, &item.Operation, &item.SourceSessionID, &item.TargetSessionID, &item.TargetServerID, &item.TargetServerName, &item.TargetDir, &pathsJSON, &item.BytesTransferred, &item.Status, &item.Error, &item.StartedAt, &item.FinishedAt, &item.DurationMS, &item.RemoteAddr, &item.UserAgent, ) if err != nil { return nil, false, err } item.Paths, err = decodeStringList(pathsJSON) if err != nil { return nil, false, err } items = append(items, item) } err = rows.Err() if err != nil { return nil, false, err } if len(items) > limit { items = items[:limit] return items, true, nil } return items, false, nil } func (s *Store) ListSSHFileTransfersForUserFiltered(userID string, limit int, offset int, query string, status string, sessionID string) ([]models.SSHFileTransfer, bool, error) { var rows *sql.Rows var whereParts []string var args []any var sqlQuery string var like string var err error if limit <= 0 { limit = 20 } if limit > 200 { limit = 200 } if offset < 0 { offset = 0 } query = strings.TrimSpace(query) status = strings.ToLower(strings.TrimSpace(status)) sessionID = strings.TrimSpace(sessionID) whereParts = append(whereParts, "u.public_id = ?") args = append(args, strings.TrimSpace(userID)) if 
sessionID != "" { whereParts = append(whereParts, "sess.public_id = ?") args = append(args, sessionID) } if query != "" { like = "%" + query + "%" whereParts = append(whereParts, "(ft.public_id LIKE ? OR COALESCE(sess.public_id, '') LIKE ? OR COALESCE(p.public_id, '') LIKE ? OR COALESCE(s.public_id, '') LIKE ? OR ft.server_name LIKE ? OR ft.username LIKE ? OR ft.remote_username LIKE ? OR ft.operation LIKE ? OR COALESCE(tsess.public_id, '') LIKE ? OR COALESCE(ts.public_id, '') LIKE ? OR ft.target_server_name LIKE ? OR ft.target_dir LIKE ? OR ft.paths_json LIKE ? OR ft.status LIKE ? OR ft.error LIKE ?)") args = append(args, like, like, like, like, like, like, like, like, like, like, like, like, like, like, like) } if status != "" { whereParts = append(whereParts, "ft.status = ?") args = append(args, status) } sqlQuery = fmt.Sprintf(`SELECT ft.public_id, COALESCE(sess.public_id, ''), COALESCE(u.public_id, ''), ft.username, COALESCE(p.public_id, ''), COALESCE(s.public_id, ''), ft.server_name, ft.remote_username, ft.operation, COALESCE(ssess.public_id, ''), COALESCE(tsess.public_id, ''), COALESCE(ts.public_id, ''), ft.target_server_name, ft.target_dir, ft.paths_json, ft.bytes_transferred, ft.status, ft.error, ft.started_at, ft.finished_at, ft.duration_ms, ft.remote_addr, ft.user_agent FROM ssh_file_transfers ft LEFT JOIN ssh_sessions sess ON sess.id = ft.session_id LEFT JOIN users u ON u.id = ft.user_id LEFT JOIN ssh_access_profiles p ON p.id = ft.profile_id LEFT JOIN ssh_servers s ON s.id = ft.server_id LEFT JOIN ssh_sessions ssess ON ssess.id = ft.source_session_id LEFT JOIN ssh_sessions tsess ON tsess.id = ft.target_session_id LEFT JOIN ssh_servers ts ON ts.id = ft.target_server_id WHERE %s ORDER BY ft.started_at DESC, ft.id DESC LIMIT ? OFFSET ?`, strings.Join(whereParts, " AND ")) args = append(args, limit + 1, offset) rows, err = s.Query(sqlQuery, args...) 
if err != nil { return nil, false, err } defer rows.Close() return scanSSHFileTransferRows(rows, limit) } func (s *Store) CountSSHFileTransfersForUserFiltered(userID string, query string, status string, sessionID string) (int, error) { var whereParts []string var args []any var sqlQuery string var like string var total int var err error query = strings.TrimSpace(query) status = strings.ToLower(strings.TrimSpace(status)) sessionID = strings.TrimSpace(sessionID) whereParts = append(whereParts, "u.public_id = ?") args = append(args, strings.TrimSpace(userID)) if sessionID != "" { whereParts = append(whereParts, "sess.public_id = ?") args = append(args, sessionID) } if query != "" { like = "%" + query + "%" whereParts = append(whereParts, "(ft.public_id LIKE ? OR COALESCE(sess.public_id, '') LIKE ? OR COALESCE(p.public_id, '') LIKE ? OR COALESCE(s.public_id, '') LIKE ? OR ft.server_name LIKE ? OR ft.username LIKE ? OR ft.remote_username LIKE ? OR ft.operation LIKE ? OR COALESCE(tsess.public_id, '') LIKE ? OR COALESCE(ts.public_id, '') LIKE ? OR ft.target_server_name LIKE ? OR ft.target_dir LIKE ? OR ft.paths_json LIKE ? OR ft.status LIKE ? 
OR ft.error LIKE ?)") args = append(args, like, like, like, like, like, like, like, like, like, like, like, like, like, like, like) } if status != "" { whereParts = append(whereParts, "ft.status = ?") args = append(args, status) } sqlQuery = fmt.Sprintf(`SELECT COUNT(*) FROM ssh_file_transfers ft LEFT JOIN ssh_sessions sess ON sess.id = ft.session_id LEFT JOIN users u ON u.id = ft.user_id LEFT JOIN ssh_access_profiles p ON p.id = ft.profile_id LEFT JOIN ssh_servers s ON s.id = ft.server_id LEFT JOIN ssh_sessions ssess ON ssess.id = ft.source_session_id LEFT JOIN ssh_sessions tsess ON tsess.id = ft.target_session_id LEFT JOIN ssh_servers ts ON ts.id = ft.target_server_id WHERE %s`, strings.Join(whereParts, " AND ")) err = s.QueryRow(sqlQuery, args...).Scan(&total) if err != nil { return 0, err } return total, nil } func (s *Store) ListSSHFileTransfersFiltered(limit int, offset int, query string, status string, sessionID string, userID string) ([]models.SSHFileTransfer, bool, error) { var rows *sql.Rows var whereParts []string var args []any var sqlQuery string var like string var err error if limit <= 0 { limit = 20 } if limit > 200 { limit = 200 } if offset < 0 { offset = 0 } query = strings.TrimSpace(query) status = strings.ToLower(strings.TrimSpace(status)) sessionID = strings.TrimSpace(sessionID) userID = strings.TrimSpace(userID) if sessionID != "" { whereParts = append(whereParts, "sess.public_id = ?") args = append(args, sessionID) } if userID != "" { whereParts = append(whereParts, "u.public_id = ?") args = append(args, userID) } if query != "" { like = "%" + query + "%" whereParts = append(whereParts, "(ft.public_id LIKE ? OR COALESCE(sess.public_id, '') LIKE ? OR COALESCE(u.public_id, '') LIKE ? OR ft.username LIKE ? OR COALESCE(p.public_id, '') LIKE ? OR COALESCE(s.public_id, '') LIKE ? OR ft.server_name LIKE ? OR ft.remote_username LIKE ? OR ft.operation LIKE ? OR COALESCE(tsess.public_id, '') LIKE ? OR COALESCE(ts.public_id, '') LIKE ? 
OR ft.target_server_name LIKE ? OR ft.target_dir LIKE ? OR ft.paths_json LIKE ? OR ft.status LIKE ? OR ft.error LIKE ?)") args = append(args, like, like, like, like, like, like, like, like, like, like, like, like, like, like, like, like) } if status != "" { whereParts = append(whereParts, "ft.status = ?") args = append(args, status) } sqlQuery = `SELECT ft.public_id, COALESCE(sess.public_id, ''), COALESCE(u.public_id, ''), ft.username, COALESCE(p.public_id, ''), COALESCE(s.public_id, ''), ft.server_name, ft.remote_username, ft.operation, COALESCE(ssess.public_id, ''), COALESCE(tsess.public_id, ''), COALESCE(ts.public_id, ''), ft.target_server_name, ft.target_dir, ft.paths_json, ft.bytes_transferred, ft.status, ft.error, ft.started_at, ft.finished_at, ft.duration_ms, ft.remote_addr, ft.user_agent FROM ssh_file_transfers ft LEFT JOIN ssh_sessions sess ON sess.id = ft.session_id LEFT JOIN users u ON u.id = ft.user_id LEFT JOIN ssh_access_profiles p ON p.id = ft.profile_id LEFT JOIN ssh_servers s ON s.id = ft.server_id LEFT JOIN ssh_sessions ssess ON ssess.id = ft.source_session_id LEFT JOIN ssh_sessions tsess ON tsess.id = ft.target_session_id LEFT JOIN ssh_servers ts ON ts.id = ft.target_server_id` if len(whereParts) > 0 { sqlQuery += "\n\t\tWHERE " + strings.Join(whereParts, " AND ") } sqlQuery += "\n\t\tORDER BY ft.started_at DESC, ft.id DESC\n\t\tLIMIT ? OFFSET ?" args = append(args, limit + 1, offset) rows, err = s.Query(sqlQuery, args...) 
if err != nil { return nil, false, err } defer rows.Close() return scanSSHFileTransferRows(rows, limit) } func (s *Store) CountSSHFileTransfersFiltered(query string, status string, sessionID string, userID string) (int, error) { var whereParts []string var args []any var sqlQuery string var like string var total int var err error query = strings.TrimSpace(query) status = strings.ToLower(strings.TrimSpace(status)) sessionID = strings.TrimSpace(sessionID) userID = strings.TrimSpace(userID) if sessionID != "" { whereParts = append(whereParts, "sess.public_id = ?") args = append(args, sessionID) } if userID != "" { whereParts = append(whereParts, "u.public_id = ?") args = append(args, userID) } if query != "" { like = "%" + query + "%" whereParts = append(whereParts, "(ft.public_id LIKE ? OR COALESCE(sess.public_id, '') LIKE ? OR COALESCE(u.public_id, '') LIKE ? OR ft.username LIKE ? OR COALESCE(p.public_id, '') LIKE ? OR COALESCE(s.public_id, '') LIKE ? OR ft.server_name LIKE ? OR ft.remote_username LIKE ? OR ft.operation LIKE ? OR COALESCE(tsess.public_id, '') LIKE ? OR COALESCE(ts.public_id, '') LIKE ? OR ft.target_server_name LIKE ? OR ft.target_dir LIKE ? OR ft.paths_json LIKE ? OR ft.status LIKE ? 
OR ft.error LIKE ?)") args = append(args, like, like, like, like, like, like, like, like, like, like, like, like, like, like, like, like) } if status != "" { whereParts = append(whereParts, "ft.status = ?") args = append(args, status) } sqlQuery = `SELECT COUNT(*) FROM ssh_file_transfers ft LEFT JOIN ssh_sessions sess ON sess.id = ft.session_id LEFT JOIN users u ON u.id = ft.user_id LEFT JOIN ssh_access_profiles p ON p.id = ft.profile_id LEFT JOIN ssh_servers s ON s.id = ft.server_id LEFT JOIN ssh_sessions ssess ON ssess.id = ft.source_session_id LEFT JOIN ssh_sessions tsess ON tsess.id = ft.target_session_id LEFT JOIN ssh_servers ts ON ts.id = ft.target_server_id` if len(whereParts) > 0 { sqlQuery += "\n\t\tWHERE " + strings.Join(whereParts, " AND ") } err = s.QueryRow(sqlQuery, args...).Scan(&total) if err != nil { return 0, err } return total, nil } func createSSHSecretTx(tx txExecutor, item models.SSHSecret) (models.SSHSecret, error) { var err error var now int64 if strings.TrimSpace(item.ID) == "" { item.ID, err = util.NewID() if err != nil { return item, err } } now = time.Now().UTC().Unix() item.CreatedAt = now item.UpdatedAt = now _, err = tx.Exec(`INSERT INTO ssh_secrets ( public_id, kind, payload, password, metadata_json, created_by_kind, created_by_user_id, created_by_principal_id, created_by_subject_name, created_at, updated_at ) VALUES (?, ?, ?, ?, ?, ?, (SELECT id FROM users WHERE public_id = ? AND ? = 'user'), (SELECT id FROM service_principals WHERE public_id = ? AND ? 
= 'service_principal'), ?, ?, ?)`, item.ID, strings.TrimSpace(item.Kind), item.Payload, item.Password, normalizeJSONText(item.MetadataJSON, "{}"), strings.TrimSpace(item.CreatedByKind), strings.TrimSpace(item.CreatedBySubjectID), strings.TrimSpace(item.CreatedByKind), strings.TrimSpace(item.CreatedBySubjectID), strings.TrimSpace(item.CreatedByKind), strings.TrimSpace(item.CreatedBySubjectName), item.CreatedAt, item.UpdatedAt, ) return item, err } func updateSSHSecretTx(tx txExecutor, item models.SSHSecret) (models.SSHSecret, error) { var err error item.UpdatedAt = time.Now().UTC().Unix() _, err = tx.Exec(`UPDATE ssh_secrets SET kind = ?, payload = ?, metadata_json = ?, updated_at = ? WHERE public_id = ?`, strings.TrimSpace(item.Kind), item.Payload, normalizeJSONText(item.MetadataJSON, "{}"), item.UpdatedAt, strings.TrimSpace(item.ID), ) if err == nil && item.Password != "" { // TODO: how can i update the password without security breach? // currently, i perform the password update only if the given password is not blank _, err = tx.Exec(`UPDATE ssh_secrets SET password = ?, updated_at = ? WHERE public_id = ?`, item.Password, item.UpdatedAt, strings.TrimSpace(item.ID), ) } return item, err } func insertSSHAccessProfileTargetTx(tx txExecutor, profileID string, targetType string, targetID string, createdAt int64) error { var err error var id string targetType = strings.ToLower(strings.TrimSpace(targetType)) targetID = strings.TrimSpace(targetID) if targetType != SSHAccessProfileTargetTypeUser && targetType != SSHAccessProfileTargetTypeGroup { return errors.New("target_type must be user or group") } if targetID == "" { return errors.New("target_id is required") } id, err = util.NewID() if err != nil { return err } _, err = tx.Exec(`INSERT INTO ssh_access_profile_targets ( public_id, profile_id, target_type, target_id, created_at ) VALUES (?, (SELECT id FROM ssh_access_profiles WHERE public_id = ?), ?, CASE WHEN ? = 'user' THEN (SELECT id FROM users WHERE public_id = ?) 
WHEN ? = 'group' THEN (SELECT id FROM user_groups WHERE public_id = ?) ELSE NULL END, ?)`, id, strings.TrimSpace(profileID), targetType, targetType, targetID, targetType, targetID, createdAt, ) return err } func normalizeStringList(items []string) []string { var out []string var seen map[string]bool var value string var i int seen = map[string]bool{} for i = 0; i < len(items); i++ { value = strings.TrimSpace(items[i]) if value == "" { continue } if seen[value] { continue } seen[value] = true out = append(out, value) } return out } func encodeStringList(items []string) (string, error) { var raw []byte var err error raw, err = json.Marshal(items) if err != nil { return "", err } return string(raw), nil } func decodeStringList(raw string) ([]string, error) { var out []string var err error raw = strings.TrimSpace(raw) if raw == "" { return []string{}, nil } err = json.Unmarshal([]byte(raw), &out) if err != nil { return nil, err } return normalizeStringList(out), nil } func normalizeJSONText(raw string, fallback string) string { var value string value = strings.TrimSpace(raw) if value == "" { return fallback } return value } func emptyStringToNil(raw string) any { var value string value = strings.TrimSpace(raw) if value == "" { return nil } return value } func normalizeSSHServerHostKeyPolicy(value string) string { value = strings.TrimSpace(value) if value == "trust_on_first_use" { return value } return "strict" } func NormalizeSSHServerTags(items []string) []string { var out []string out = normalizeStringList(items) sort.Strings(out) return out }