atlus/backend/sessions.py
roberts 68fe9b4435 Fix terminal session: better error handling, reliable state save
- Fix asyncio.get_event_loop() → get_running_loop() in PTY reader
- Add error logging for PTY spawn failures
- Add POST /api/session/state endpoint for sendBeacon (beforeunload)
- Use navigator.sendBeacon for reliable state save on page close
- Improve frontend error reporting when terminal creation fails

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-14 21:46:08 -05:00

319 lines
11 KiB
Python

"""Atlus — Persistent desktop session manager.
Each user gets one desktop session that survives browser refreshes/closes.
PTY processes are owned by the SessionManager, not by WebSocket connections.
"""
import asyncio
import codecs
import json
import logging
import os
import pwd
import signal
import time
import uuid
from collections import deque
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional

from ptyprocess import PtyProcess

from backend.config import DATA_DIR
log = logging.getLogger("atlus.sessions")
# Directory holding each user's persisted desktop state as <username>.json.
SESSIONS_DIR = DATA_DIR / "sessions"
# Max scrollback entries kept per terminal. Each entry is the decoded text of
# one PTY read (up to 4096 bytes), not one line, so this does not bound lines.
SCROLLBACK_MAXLEN = 5000
# ---------------------------------------------------------------------------
# Terminal — a long-lived PTY with background reader
# ---------------------------------------------------------------------------
@dataclass
class ManagedTerminal:
    """A PTY process that lives independently of any WebSocket.

    A background asyncio task (started by ``start_reader``) drains PTY output
    into the bounded ``scrollback`` deque and forwards it to every attached
    WebSocket as ``{"type": "output", "data": <text>}`` messages.
    """

    terminal_id: str
    pty: PtyProcess
    # Decoded output chunks, one per PTY read (not per line); the oldest entry
    # is dropped once SCROLLBACK_MAXLEN entries are held.
    scrollback: deque = field(default_factory=lambda: deque(maxlen=SCROLLBACK_MAXLEN))
    title: str = "Shell"
    created_at: float = field(default_factory=time.time)
    _reader_task: Optional[asyncio.Task] = field(default=None, repr=False)
    _websockets: list = field(default_factory=list, repr=False)

    @property
    def alive(self) -> bool:
        """True while the PTY child is running; False if the check itself fails."""
        try:
            return self.pty.isalive()
        except Exception:
            return False

    def attach_ws(self, ws):
        """Register a WebSocket to receive live output."""
        if ws not in self._websockets:
            self._websockets.append(ws)

    def detach_ws(self, ws):
        """Unregister a WebSocket (no-op if it is not attached)."""
        try:
            self._websockets.remove(ws)
        except ValueError:
            pass

    async def _broadcast(self, text: str):
        """Send *text* as an output message to every attached WebSocket.

        Iterates over a snapshot because attach_ws/detach_ws may run while this
        coroutine is suspended in send_json; sockets whose send fails are
        detached.
        """
        for ws in list(self._websockets):
            try:
                await ws.send_json({"type": "output", "data": text})
            except Exception:
                self.detach_ws(ws)

    async def _read_loop(self):
        """Background task: read PTY output → scrollback + attached WebSockets.

        Runs until the child exits (EOFError from the PTY, or ``alive`` turning
        False) or a read fails unexpectedly, then tells attached clients the
        process exited. Task cancellation propagates without that notice.
        """
        loop = asyncio.get_running_loop()
        # Decode incrementally so a multi-byte UTF-8 character split across two
        # reads is reassembled instead of becoming two U+FFFD replacements.
        decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
        while self.alive:
            try:
                # PtyProcess.read blocks, so run it off the event loop thread.
                raw = await loop.run_in_executor(None, self.pty.read, 4096)
            except EOFError:
                break
            except Exception:
                log.exception("Terminal %s: PTY read failed", self.terminal_id)
                break
            if not raw:
                continue
            text = decoder.decode(raw) if isinstance(raw, bytes) else raw
            if not text:
                # Only the leading bytes of a multi-byte character so far.
                continue
            self.scrollback.append(text)
            await self._broadcast(text)
        # Flush a dangling partial character (rendered as U+FFFD), if any.
        tail = decoder.decode(b"", final=True)
        if tail:
            self.scrollback.append(tail)
            await self._broadcast(tail)
        # PTY died — notify attached clients
        await self._broadcast("\r\n\x1b[31m[Process exited]\x1b[0m\r\n")
        log.info("Terminal %s reader loop ended (pid %s)", self.terminal_id, self.pty.pid)

    def start_reader(self):
        """Start the background reader task (requires a running event loop)."""
        if self._reader_task is None or self._reader_task.done():
            self._reader_task = asyncio.create_task(self._read_loop())

    def kill(self, timeout: float = 1.0):
        """Terminate the PTY process.

        Cancels the reader task, sends SIGHUP and polls for up to *timeout*
        seconds for the child to exit. If it is still alive (e.g. it ignores
        SIGHUP) it is sent SIGKILL and reaped. This is synchronous, but unlike
        an unconditional ``wait()`` it cannot hang forever on a child that
        ignores SIGHUP. Best-effort: errors are logged at debug level, never
        raised.
        """
        if self._reader_task and not self._reader_task.done():
            self._reader_task.cancel()
        if not self.alive:
            return
        try:
            self.pty.kill(signal.SIGHUP)
            deadline = time.monotonic() + timeout
            # isalive() reaps the child once it has exited, so no zombie is left.
            while self.alive and time.monotonic() < deadline:
                time.sleep(0.01)
            if self.alive:
                # SIGKILL cannot be caught or ignored by the child.
                self.pty.kill(signal.SIGKILL)
                self.pty.wait()
        except Exception:
            log.debug("Error while killing terminal %s", self.terminal_id, exc_info=True)

    def to_dict(self) -> dict:
        """Return a JSON-serializable summary of this terminal."""
        return {
            "terminal_id": self.terminal_id,
            "title": self.title,
            "alive": self.alive,
            "pid": self.pty.pid,
            "created_at": self.created_at,
        }
# ---------------------------------------------------------------------------
# Desktop Session — one per user
# ---------------------------------------------------------------------------
@dataclass
class DesktopSession:
    """Per-user desktop session that outlives individual browser connections.

    Holds the user's terminals keyed by terminal id, a frontend-supplied
    ``desktop_state`` dict, and creation / last-activity timestamps taken
    from ``time.time()``.
    """

    username: str
    terminals: dict[str, ManagedTerminal] = field(default_factory=dict)
    # Frontend UI state: open apps, active app, etc.
    desktop_state: dict = field(default_factory=dict)
    created_at: float = field(default_factory=time.time)
    last_active: float = field(default_factory=time.time)

    def touch(self):
        """Mark the session as active right now."""
        now = time.time()
        self.last_active = now

    def to_dict(self) -> dict:
        """Serialize the session, listing only terminals that are still alive."""
        live_terminals = {}
        for terminal_id, terminal in self.terminals.items():
            if terminal.alive:
                live_terminals[terminal_id] = terminal.to_dict()
        return {
            "username": self.username,
            "terminals": live_terminals,
            "desktop_state": self.desktop_state,
            "created_at": self.created_at,
            "last_active": self.last_active,
        }
# ---------------------------------------------------------------------------
# Session Manager — singleton
# ---------------------------------------------------------------------------
class SessionManager:
    """Manages all user desktop sessions.

    Sessions live in memory, one ``DesktopSession`` per username. Desktop state
    plus a summary of live terminals is written to
    ``SESSIONS_DIR/<username>.json``; only ``desktop_state`` is read back when
    a session is recreated.
    """

    def __init__(self):
        self._sessions: dict[str, DesktopSession] = {}
        SESSIONS_DIR.mkdir(parents=True, exist_ok=True)

    @staticmethod
    def _state_file(username: str) -> Path:
        """Return the path of *username*'s persisted state file."""
        # NOTE(review): username is embedded in the path verbatim; this assumes
        # callers pass an authenticated username containing no "/" or "..".
        return SESSIONS_DIR / f"{username}.json"

    # ---- Session lifecycle ----
    def get_or_create(self, username: str) -> DesktopSession:
        """Get existing session or create a new one for the user.

        A newly created session has its desktop_state restored from disk when
        a valid state file exists; unreadable or malformed files are logged and
        ignored. The session is touched and its dead terminals pruned.
        """
        if username not in self._sessions:
            session = DesktopSession(username=username)
            # Try to restore desktop state from disk
            state_file = self._state_file(username)
            if state_file.exists():
                try:
                    data = json.loads(state_file.read_text())
                    if not isinstance(data, dict):
                        raise ValueError("state file is not a JSON object")
                    session.desktop_state = data.get("desktop_state", {})
                    log.info("Restored desktop state for %s", username)
                # ValueError covers JSONDecodeError and UnicodeDecodeError.
                except (ValueError, OSError) as e:
                    log.warning("Failed to restore session for %s: %s", username, e)
            self._sessions[username] = session
            log.info("Created session for %s", username)
        session = self._sessions[username]
        session.touch()
        # Prune dead terminals
        self._prune_dead(session)
        return session

    def _prune_dead(self, session: DesktopSession):
        """Remove terminals whose PTY has exited."""
        dead = [tid for tid, t in session.terminals.items() if not t.alive]
        for tid in dead:
            t = session.terminals.pop(tid)
            t.kill()
            log.info("Pruned dead terminal %s for %s", tid, session.username)

    def save_state(self, username: str, desktop_state: dict):
        """Persist desktop state to disk.

        Does nothing if *username* has no in-memory session. The JSON is first
        written to a sibling ``.tmp`` file and then moved over the real file
        with ``os.replace``, so a crash mid-write cannot leave a truncated state
        file. OSErrors are logged, not raised.
        """
        session = self._sessions.get(username)
        if not session:
            return
        session.desktop_state = desktop_state
        session.touch()
        state_file = self._state_file(username)
        tmp_file = state_file.with_name(state_file.name + ".tmp")
        try:
            data = {
                "desktop_state": desktop_state,
                "terminal_ids": [
                    {"terminal_id": t.terminal_id, "title": t.title}
                    for t in session.terminals.values() if t.alive
                ],
                "saved_at": time.time(),
            }
            tmp_file.write_text(json.dumps(data, indent=2))
            os.replace(tmp_file, state_file)
        except OSError as e:
            log.warning("Failed to save session for %s: %s", username, e)
            try:
                tmp_file.unlink(missing_ok=True)
            except OSError:
                pass

    # ---- Terminal lifecycle ----
    def create_terminal(self, username: str, cols: int = 120, rows: int = 30) -> ManagedTerminal:
        """Spawn a new PTY terminal for the user.

        Runs ``<shell> -l`` with a minimal environment. Falls back to /bin/bash
        when the user has no passwd entry or an empty shell field, and to "/"
        when the home directory is unknown or missing. Must be called with a
        running event loop because the reader task is started here.

        Raises:
            Exception: whatever ``PtyProcess.spawn`` raises (logged first), or
            the error from starting the reader task (the PTY is killed first).
        """
        session = self.get_or_create(username)
        # Spawn PTY
        try:
            pw = pwd.getpwnam(username)
        except KeyError:
            pw = None
        # An empty passwd shell field gets the same fallback as a missing entry.
        shell = (pw.pw_shell if pw else "") or "/bin/bash"
        home = pw.pw_dir if pw else "/"
        if not os.path.isdir(home):
            # Similar to login(1): no home directory -> start in "/" with HOME=/.
            home = "/"
        # NOTE(review): nothing here switches to *username*'s uid; the shell
        # presumably runs as the server's own user — confirm this is intended.
        env = {
            "TERM": "xterm-256color",
            "HOME": home,
            "USER": username,
            "SHELL": shell,
            "LANG": os.environ.get("LANG", "en_US.UTF-8"),
            "PATH": "/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin",
        }
        log.info("Spawning PTY for %s: shell=%s home=%s", username, shell, home)
        try:
            pty = PtyProcess.spawn(
                [shell, "-l"],
                dimensions=(rows, cols),
                env=env,
                cwd=home,
            )
        except Exception:
            log.exception("Failed to spawn PTY for %s", username)
            raise
        terminal_id = str(uuid.uuid4())[:8]
        while terminal_id in session.terminals:  # 32-bit ids can (rarely) collide
            terminal_id = str(uuid.uuid4())[:8]
        terminal = ManagedTerminal(
            terminal_id=terminal_id,
            pty=pty,
            title=f"Shell {len(session.terminals) + 1}",
        )
        try:
            terminal.start_reader()
        except Exception:
            # Don't leak an untracked child process if the reader can't start.
            terminal.kill()
            raise
        session.terminals[terminal_id] = terminal
        session.touch()
        log.info("Created terminal %s for %s (pid %d)", terminal_id, username, pty.pid)
        return terminal

    def get_terminal(self, username: str, terminal_id: str) -> Optional[ManagedTerminal]:
        """Get a terminal by ID, or None if the user or terminal is unknown."""
        session = self._sessions.get(username)
        if not session:
            return None
        return session.terminals.get(terminal_id)

    def close_terminal(self, username: str, terminal_id: str) -> bool:
        """Kill and remove a terminal; return False if it was not found."""
        session = self._sessions.get(username)
        if not session:
            return False
        terminal = session.terminals.pop(terminal_id, None)
        if not terminal:
            return False
        terminal.kill()
        session.touch()
        log.info("Closed terminal %s for %s", terminal_id, username)
        return True

    def list_terminals(self, username: str) -> list[dict]:
        """List all alive terminals for a user (empty if no session)."""
        session = self._sessions.get(username)
        if not session:
            return []
        self._prune_dead(session)
        return [t.to_dict() for t in session.terminals.values() if t.alive]

    # ---- Cleanup ----
    async def cleanup_stale(self, max_idle_hours: float = 24):
        """Remove sessions idle for longer than *max_idle_hours*, killing their PTYs."""
        cutoff = time.time() - (max_idle_hours * 3600)
        stale = [u for u, s in self._sessions.items() if s.last_active < cutoff]
        for username in stale:
            session = self._sessions.pop(username)
            for terminal in session.terminals.values():
                terminal.kill()
            log.info("Cleaned up stale session for %s (idle since %s)",
                     username, time.ctime(session.last_active))

    def shutdown_all(self):
        """Kill all PTYs — called on server shutdown."""
        for session in self._sessions.values():
            for terminal in session.terminals.values():
                terminal.kill()
        log.info("All sessions shut down")
# ---------------------------------------------------------------------------
# Singleton instance
# ---------------------------------------------------------------------------
# Module-wide shared SessionManager. Constructing it at import time also
# creates SESSIONS_DIR on disk (SessionManager.__init__ calls mkdir).
manager = SessionManager()