atlus/backend/sessions.py
roberts 6a0c8757f8 Run terminals and GUI apps as the authenticated user, not root
Atlus runs as root (systemd) but user-facing processes must run under the
authenticated user's identity. Added privilege-dropping via preexec_fn
(os.setgid + os.initgroups + os.setuid) to both terminal PTY spawning
and GUI app launching. System admin operations (services, packages,
network, updates) intentionally remain root.

Autostart apps now support a configurable default_user; without one set,
autostart defers until the first user logs in and runs as that user.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-15 00:35:52 -05:00

322 lines
11 KiB
Python

"""Atlus — Persistent desktop session manager.
Each user gets one desktop session that survives browser refreshes/closes.
PTY processes are owned by the SessionManager, not by WebSocket connections.
"""
import asyncio
import json
import logging
import os
import pwd
import signal
import time
import uuid
from collections import deque
from backend.privdrop import make_preexec_fn
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional
from ptyprocess import PtyProcess
from backend.config import DATA_DIR
# Module-level logger for session management.
log = logging.getLogger("atlus.sessions")

# Directory holding one "<username>.json" desktop-state file per user.
SESSIONS_DIR = DATA_DIR / "sessions"

# Max entries kept in each terminal's scrollback deque. Each entry is one
# decoded PTY read (up to 4096 bytes), not one line of text.
SCROLLBACK_MAXLEN = 5_000
# ---------------------------------------------------------------------------
# Terminal — a long-lived PTY with background reader
# ---------------------------------------------------------------------------
@dataclass
class ManagedTerminal:
    """A PTY process that lives independently of any WebSocket.

    A background asyncio task reads PTY output. Each chunk is appended to
    ``scrollback`` and sent to every attached WebSocket, so clients can
    detach and re-attach without the process being affected.
    """

    terminal_id: str
    pty: PtyProcess
    # Each entry is one decoded read() chunk (up to 4096 bytes), not one line.
    scrollback: deque = field(default_factory=lambda: deque(maxlen=SCROLLBACK_MAXLEN))
    title: str = "Shell"
    created_at: float = field(default_factory=time.time)
    _reader_task: Optional[asyncio.Task] = field(default=None, repr=False)
    _websockets: list = field(default_factory=list, repr=False)

    @property
    def alive(self) -> bool:
        """True while the PTY child is running; False if the check itself fails."""
        try:
            return self.pty.isalive()
        except Exception:
            return False

    def attach_ws(self, ws):
        """Register a WebSocket to receive live output (no-op if already attached)."""
        if ws not in self._websockets:
            self._websockets.append(ws)

    def detach_ws(self, ws):
        """Unregister a WebSocket (no-op if it is not attached)."""
        try:
            self._websockets.remove(ws)
        except ValueError:
            pass

    async def _read_loop(self):
        """Background task: read PTY output → scrollback + attached WebSockets.

        Runs until the child exits (EOF) or a read fails. Then it sends a
        "[Process exited]" notice to every attached WebSocket.
        """
        loop = asyncio.get_running_loop()
        while self.alive:
            try:
                # pty.read() blocks, so it runs in the default thread pool.
                raw = await loop.run_in_executor(None, lambda: self.pty.read(4096))
            except EOFError:
                break
            except Exception:
                log.exception("Terminal %s: PTY read failed", self.terminal_id)
                break
            if not raw:
                continue
            text = raw.decode("utf-8", errors="replace") if isinstance(raw, bytes) else raw
            self.scrollback.append(text)
            # Iterate over a snapshot. attach_ws/detach_ws can run in other tasks
            # while we are suspended in send_json(). Mutating the live list
            # mid-iteration would silently skip sockets.
            dead = []
            for ws in list(self._websockets):
                try:
                    await ws.send_json({"type": "output", "data": text})
                except Exception:
                    dead.append(ws)
            for ws in dead:
                self.detach_ws(ws)
        # PTY died — notify attached clients
        for ws in list(self._websockets):
            try:
                await ws.send_json({"type": "output", "data": "\r\n\x1b[31m[Process exited]\x1b[0m\r\n"})
            except Exception:
                pass
        log.info("Terminal %s reader loop ended (pid %s)", self.terminal_id, self.pty.pid)

    def start_reader(self):
        """Start the background reader task unless one is already running."""
        if self._reader_task is None or self._reader_task.done():
            self._reader_task = asyncio.create_task(self._read_loop())

    def kill(self, grace: float = 1.0):
        """Terminate the PTY process.

        Cancel the reader task and send SIGHUP. Then wait at most *grace*
        seconds for the child to exit before escalating to SIGKILL.

        The escalation matters because the child may ignore SIGHUP (for
        example via ``trap '' HUP``) or be stopped. An unbounded ``wait()``
        would then block this synchronous call forever, along with any
        event loop it is called from.

        Args:
            grace: Seconds to wait after SIGHUP before sending SIGKILL.
        """
        if self._reader_task and not self._reader_task.done():
            self._reader_task.cancel()
        if not self.alive:
            return
        try:
            self.pty.kill(signal.SIGHUP)
            deadline = time.monotonic() + grace
            while self.pty.isalive() and time.monotonic() < deadline:
                time.sleep(0.02)
            if self.pty.isalive():
                log.warning("Terminal %s (pid %s) ignored SIGHUP; sending SIGKILL",
                            self.terminal_id, self.pty.pid)
                self.pty.kill(signal.SIGKILL)
            # Reap the child; after SIGKILL (or an observed exit) this returns promptly.
            self.pty.wait()
        except Exception:
            log.warning("Failed to kill terminal %s (pid %s)",
                        self.terminal_id, self.pty.pid, exc_info=True)

    def to_dict(self) -> dict:
        """Return a JSON-serializable summary of this terminal."""
        return {
            "terminal_id": self.terminal_id,
            "title": self.title,
            "alive": self.alive,
            "pid": self.pty.pid,
            "created_at": self.created_at,
        }
# ---------------------------------------------------------------------------
# Desktop Session — one per user
# ---------------------------------------------------------------------------
@dataclass
class DesktopSession:
    """Persistent desktop session for a single user.

    Holds the user's live terminals keyed by terminal id, an opaque
    desktop-state dict supplied by the client, and creation/activity
    timestamps (seconds since the epoch, from ``time.time()``).
    """

    username: str
    terminals: dict[str, ManagedTerminal] = field(default_factory=dict)
    desktop_state: dict = field(default_factory=dict)  # open apps, active app, etc.
    created_at: float = field(default_factory=time.time)
    last_active: float = field(default_factory=time.time)

    def touch(self):
        """Record activity now by refreshing ``last_active``."""
        now = time.time()
        self.last_active = now

    def to_dict(self) -> dict:
        """Serialize the session; dead terminals are left out."""
        live_terminals = {}
        for term_id, term in self.terminals.items():
            if not term.alive:
                continue
            live_terminals[term_id] = term.to_dict()
        summary = {"username": self.username}
        summary["terminals"] = live_terminals
        summary["desktop_state"] = self.desktop_state
        summary["created_at"] = self.created_at
        summary["last_active"] = self.last_active
        return summary
# ---------------------------------------------------------------------------
# Session Manager — singleton
# ---------------------------------------------------------------------------
class SessionManager:
    """Manages all user desktop sessions.

    Sessions live in memory keyed by username. Only the desktop state (plus
    ids/titles of live terminals) is written to
    ``SESSIONS_DIR/<username>.json``, and only the desktop state is read
    back on restore.
    """

    def __init__(self):
        self._sessions: dict[str, DesktopSession] = {}
        SESSIONS_DIR.mkdir(parents=True, exist_ok=True)

    # ---- Session lifecycle ----

    @staticmethod
    def _state_file(username: str) -> Path:
        """Return the path of the persisted state file for *username*."""
        # NOTE(review): the username is used verbatim in the filename. This
        # assumes the auth layer only yields account names (no "/" or "..").
        # Confirm before exposing this to arbitrary names.
        return SESSIONS_DIR / f"{username}.json"

    def _load_desktop_state(self, username: str) -> dict:
        """Read the persisted desktop state for *username*.

        Return ``{}`` if the file is missing, unreadable or malformed.
        """
        state_file = self._state_file(username)
        if not state_file.exists():
            return {}
        try:
            data = json.loads(state_file.read_text(encoding="utf-8"))
        except (ValueError, OSError) as e:  # ValueError covers JSONDecodeError and bad UTF-8
            log.warning("Failed to restore session for %s: %s", username, e)
            return {}
        # Valid JSON that is not an object (or whose state is not an object)
        # previously raised AttributeError and broke every login for this user.
        state = data.get("desktop_state", {}) if isinstance(data, dict) else None
        if not isinstance(state, dict):
            log.warning("Failed to restore session for %s: %s", username,
                        "unexpected JSON structure")
            return {}
        log.info("Restored desktop state for %s", username)
        return state

    def get_or_create(self, username: str) -> DesktopSession:
        """Get the user's session, creating it (and restoring state) if needed.

        Also refreshes ``last_active`` and prunes terminals whose PTY has exited.
        """
        session = self._sessions.get(username)
        if session is None:
            session = DesktopSession(username=username)
            session.desktop_state = self._load_desktop_state(username)
            self._sessions[username] = session
            log.info("Created session for %s", username)
        session.touch()
        self._prune_dead(session)
        return session

    def _prune_dead(self, session: DesktopSession):
        """Remove terminals whose PTY has exited."""
        dead = [tid for tid, t in session.terminals.items() if not t.alive]
        for tid in dead:
            t = session.terminals.pop(tid)
            t.kill()
            log.info("Pruned dead terminal %s for %s", tid, session.username)

    def save_state(self, username: str, desktop_state: dict):
        """Persist desktop state to disk (no-op if the user has no session).

        The file is written to a temporary sibling and then renamed over the
        target. A crash mid-write therefore cannot leave a truncated state file.
        """
        session = self._sessions.get(username)
        if session is None:
            return
        session.desktop_state = desktop_state
        session.touch()
        state_file = self._state_file(username)
        data = {
            "desktop_state": desktop_state,
            "terminal_ids": [
                {"terminal_id": t.terminal_id, "title": t.title}
                for t in session.terminals.values() if t.alive
            ],
            "saved_at": time.time(),
        }
        payload = json.dumps(data, indent=2)
        tmp_file = state_file.with_name(state_file.name + ".tmp")
        try:
            tmp_file.write_text(payload, encoding="utf-8")
            tmp_file.replace(state_file)
        except OSError as e:
            log.warning("Failed to save session for %s: %s", username, e)
            try:
                tmp_file.unlink(missing_ok=True)
            except OSError:
                pass

    # ---- Terminal lifecycle ----

    def create_terminal(self, username: str, cols: int = 120, rows: int = 30) -> ManagedTerminal:
        """Spawn a new PTY login shell for the user and start its reader.

        Raises:
            Whatever ``PtyProcess.spawn`` raises; the failure is logged first.
        """
        session = self.get_or_create(username)
        try:
            pw = pwd.getpwnam(username)
        except KeyError:
            # NOTE(review): with no passwd entry we still spawn. Whether the
            # shell then runs as root depends on make_preexec_fn's handling of
            # unknown users. Verify that it refuses rather than silently no-ops.
            pw = None
        shell = pw.pw_shell if pw else "/bin/bash"
        home = pw.pw_dir if pw else "/"
        env = {
            "TERM": "xterm-256color",
            "HOME": home,
            "USER": username,
            "SHELL": shell,
            "LANG": os.environ.get("LANG", "en_US.UTF-8"),
            "PATH": "/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin",
        }
        log.info("Spawning PTY for %s: shell=%s home=%s", username, shell, home)
        try:
            pty = PtyProcess.spawn(
                [shell, "-l"],
                dimensions=(rows, cols),
                env=env,
                cwd=home,
                preexec_fn=make_preexec_fn(username),
            )
        except Exception:
            log.exception("Failed to spawn PTY for %s", username)
            raise
        # 8 hex chars (same format as before); re-roll on the unlikely collision
        # so an existing terminal is never silently replaced.
        terminal_id = uuid.uuid4().hex[:8]
        while terminal_id in session.terminals:
            terminal_id = uuid.uuid4().hex[:8]
        terminal = ManagedTerminal(
            terminal_id=terminal_id,
            pty=pty,
            title=f"Shell {len(session.terminals) + 1}",
        )
        terminal.start_reader()
        session.terminals[terminal_id] = terminal
        session.touch()
        log.info("Created terminal %s for %s (pid %d)", terminal_id, username, pty.pid)
        return terminal

    def get_terminal(self, username: str, terminal_id: str) -> Optional[ManagedTerminal]:
        """Get a terminal by ID, or None if the user or terminal is unknown."""
        session = self._sessions.get(username)
        if not session:
            return None
        return session.terminals.get(terminal_id)

    def close_terminal(self, username: str, terminal_id: str) -> bool:
        """Kill and remove a terminal; return False if it was not found."""
        session = self._sessions.get(username)
        if not session:
            return False
        terminal = session.terminals.pop(terminal_id, None)
        if not terminal:
            return False
        terminal.kill()
        session.touch()
        log.info("Closed terminal %s for %s", terminal_id, username)
        return True

    def list_terminals(self, username: str) -> list[dict]:
        """List all alive terminals for a user (pruning dead ones first)."""
        session = self._sessions.get(username)
        if not session:
            return []
        self._prune_dead(session)
        return [t.to_dict() for t in session.terminals.values() if t.alive]

    # ---- Cleanup ----

    async def cleanup_stale(self, max_idle_hours: float = 24):
        """Remove sessions idle longer than *max_idle_hours*, killing their PTYs.

        Persisted state files are left in place.
        """
        cutoff = time.time() - (max_idle_hours * 3600)
        stale = [u for u, s in self._sessions.items() if s.last_active < cutoff]
        for username in stale:
            session = self._sessions.pop(username)
            for terminal in session.terminals.values():
                terminal.kill()
            log.info("Cleaned up stale session for %s (idle since %s)",
                     username, time.ctime(session.last_active))

    def shutdown_all(self):
        """Kill all PTYs — called on server shutdown."""
        for session in self._sessions.values():
            for terminal in session.terminals.values():
                terminal.kill()
        log.info("All sessions shut down")
# ---------------------------------------------------------------------------
# Singleton instance
# ---------------------------------------------------------------------------
# Process-wide session manager. Constructing it creates SESSIONS_DIR at import time.
manager: SessionManager = SessionManager()