"""Atlus — Persistent desktop session manager.

Each user gets one desktop session that survives browser refreshes/closes.
PTY processes are owned by the SessionManager, not by WebSocket connections.
"""

import asyncio
import codecs
import errno
import json
import logging
import os
import pwd
import signal
import time
import uuid
from collections import deque
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional

from ptyprocess import PtyProcess

from backend.config import DATA_DIR

log = logging.getLogger("atlus.sessions")

SESSIONS_DIR = DATA_DIR / "sessions"
# Maximum number of output *chunks* kept per terminal. Each entry is the text of
# one PTY read (up to READ_CHUNK_SIZE bytes); entries are not split on newlines.
SCROLLBACK_MAXLEN = 5000
READ_CHUNK_SIZE = 4096  # max bytes consumed from the PTY per read
KILL_GRACE_SECONDS = 0.5  # how long kill() waits after SIGHUP before SIGKILL


def _new_utf8_decoder() -> codecs.IncrementalDecoder:
    """Return a UTF-8 decoder that buffers byte sequences split across reads."""
    return codecs.getincrementaldecoder("utf-8")(errors="replace")


async def _wait_readable(loop: asyncio.AbstractEventLoop, fd: int) -> None:
    """Suspend until *fd* is readable (data available, EOF or hang-up).

    Uses the event loop's selector rather than a worker thread, so any number
    of terminals can wait at once without occupying the default executor.
    """
    ready = loop.create_future()

    def _on_readable() -> None:
        # The selector may report readiness again before the waiting task
        # resumes; only the first notification matters.
        if not ready.done():
            ready.set_result(None)

    loop.add_reader(fd, _on_readable)
    try:
        await ready
    finally:
        loop.remove_reader(fd)


# ---------------------------------------------------------------------------
# Terminal — a long-lived PTY with background reader
# ---------------------------------------------------------------------------

@dataclass
class ManagedTerminal:
    """A PTY process that lives independently of any WebSocket.

    A background reader task (started by ``start_reader``) appends PTY output
    to ``scrollback`` and forwards it to every attached WebSocket.
    """

    terminal_id: str
    pty: PtyProcess
    scrollback: deque = field(default_factory=lambda: deque(maxlen=SCROLLBACK_MAXLEN))
    title: str = "Shell"
    created_at: float = field(default_factory=time.time)
    _reader_task: Optional[asyncio.Task] = field(default=None, repr=False)
    _websockets: list = field(default_factory=list, repr=False)
    # Stateful decoder, so a multi-byte UTF-8 character split across two reads
    # is reassembled instead of being replaced with U+FFFD.
    _decoder: codecs.IncrementalDecoder = field(
        default_factory=_new_utf8_decoder, repr=False, compare=False
    )

    @property
    def alive(self) -> bool:
        """True while the child process is running; False if the check fails."""
        try:
            return self.pty.isalive()
        except Exception:
            return False

    def attach_ws(self, ws):
        """Register a WebSocket to receive live output (idempotent)."""
        if ws not in self._websockets:
            self._websockets.append(ws)

    def detach_ws(self, ws):
        """Unregister a WebSocket; unknown sockets are ignored."""
        try:
            self._websockets.remove(ws)
        except ValueError:
            pass

    async def _broadcast(self, payload: dict) -> None:
        """Send *payload* to every attached WebSocket, detaching any that fail."""
        # Iterate over a snapshot: attach_ws/detach_ws may run while we await.
        for ws in list(self._websockets):
            try:
                await ws.send_json(payload)
            except Exception:
                self.detach_ws(ws)

    async def _read_loop(self):
        """Background task: read PTY output → scrollback + attached WebSockets.

        Reads until the PTY reports EOF (EIO on Linux, an empty read on BSD),
        so output written just before the shell exits is not lost. Then it
        tells attached clients the process exited. Cancellation (see kill())
        propagates immediately and skips the exit notice.
        """
        loop = asyncio.get_running_loop()
        fd = self.pty.fd
        while True:
            try:
                await _wait_readable(loop, fd)
                # The fd is readable, so this returns without blocking the loop.
                raw = os.read(fd, READ_CHUNK_SIZE)
            except OSError as exc:
                # EIO is how Linux reports EOF on a PTY master; anything else
                # is unexpected but still ends the reader.
                if exc.errno != errno.EIO:
                    log.warning("Terminal %s read failed: %s", self.terminal_id, exc)
                break
            except Exception:
                log.exception("Terminal %s reader failed", self.terminal_id)
                break
            if not raw:
                break  # BSD-style EOF
            text = self._decoder.decode(raw)
            if not text:
                continue  # only a partial multi-byte sequence so far
            self.scrollback.append(text)
            await self._broadcast({"type": "output", "data": text})

        # Flush any trailing incomplete sequence (emitted as U+FFFD).
        tail = self._decoder.decode(b"", final=True)
        if tail:
            self.scrollback.append(tail)
            await self._broadcast({"type": "output", "data": tail})

        # PTY died — notify attached clients
        await self._broadcast(
            {"type": "output", "data": "\r\n\x1b[31m[Process exited]\x1b[0m\r\n"}
        )
        log.info("Terminal %s reader loop ended (pid %s)", self.terminal_id, self.pty.pid)

    def start_reader(self):
        """Start the background reader task unless one is already running.

        Must be called with a running event loop (uses asyncio.create_task).
        """
        if self._reader_task is None or self._reader_task.done():
            self._reader_task = asyncio.create_task(self._read_loop())

    def kill(self):
        """Cancel the reader task and terminate the PTY's child process.

        Sends SIGHUP first, as closing a real terminal would. It then polls for
        up to KILL_GRACE_SECONDS and escalates to SIGKILL, so a shell that
        ignores SIGHUP cannot block the event loop indefinitely. Errors are
        logged at debug level and otherwise ignored.
        """
        if self._reader_task and not self._reader_task.done():
            self._reader_task.cancel()
        if not self.alive:
            return
        try:
            self.pty.kill(signal.SIGHUP)
            deadline = time.monotonic() + KILL_GRACE_SECONDS
            while self.alive and time.monotonic() < deadline:
                time.sleep(0.01)
            if self.alive:
                self.pty.kill(signal.SIGKILL)
                self.pty.wait()  # SIGKILL cannot be caught or ignored
        except Exception:
            log.debug("Error while killing terminal %s", self.terminal_id, exc_info=True)

    def to_dict(self) -> dict:
        """Return a JSON-serializable summary of this terminal."""
        return {
            "terminal_id": self.terminal_id,
            "title": self.title,
            "alive": self.alive,
            "pid": self.pty.pid,
            "created_at": self.created_at,
        }


# ---------------------------------------------------------------------------
# Desktop Session — one per user
# ---------------------------------------------------------------------------

@dataclass
class DesktopSession:
    """Persistent desktop session for a single user."""

    username: str
    terminals: dict[str, ManagedTerminal] = field(default_factory=dict)
    desktop_state: dict = field(default_factory=dict)  # open apps, active app, etc.
    created_at: float = field(default_factory=time.time)
    last_active: float = field(default_factory=time.time)

    def touch(self):
        """Mark the session as active now."""
        self.last_active = time.time()

    def to_dict(self) -> dict:
        """Return a JSON-serializable summary; only live terminals are listed."""
        return {
            "username": self.username,
            "terminals": {tid: t.to_dict() for tid, t in self.terminals.items() if t.alive},
            "desktop_state": self.desktop_state,
            "created_at": self.created_at,
            "last_active": self.last_active,
        }


# ---------------------------------------------------------------------------
# Session Manager — singleton
# ---------------------------------------------------------------------------

class SessionManager:
    """Manages all user desktop sessions, keyed by username."""

    def __init__(self):
        self._sessions: dict[str, DesktopSession] = {}
        SESSIONS_DIR.mkdir(parents=True, exist_ok=True)

    @staticmethod
    def _state_file(username: str) -> Path:
        """Return the JSON file holding *username*'s persisted desktop state."""
        # NOTE(review): username is interpolated into a filesystem path; this
        # assumes it is an authenticated account name with no path separators
        # or "..". Confirm at the call sites.
        return SESSIONS_DIR / f"{username}.json"

    # ---- Session lifecycle ----

    def get_or_create(self, username: str) -> DesktopSession:
        """Return the user's session, creating it if needed.

        A new session has its desktop state restored from disk when a saved
        file exists. Every call refreshes ``last_active`` and prunes terminals
        whose process has exited.
        """
        session = self._sessions.get(username)
        if session is None:
            session = DesktopSession(username=username)
            session.desktop_state = self._load_desktop_state(username)
            self._sessions[username] = session
            log.info("Created session for %s", username)
        session.touch()
        # Prune dead terminals
        self._prune_dead(session)
        return session

    def _load_desktop_state(self, username: str) -> dict:
        """Read persisted desktop state; return ``{}`` if missing or unusable."""
        state_file = self._state_file(username)
        if not state_file.exists():
            return {}
        try:
            data = json.loads(state_file.read_text(encoding="utf-8"))
        except (ValueError, OSError) as e:  # ValueError covers JSON and UTF-8 errors
            log.warning("Failed to restore session for %s: %s", username, e)
            return {}
        if not isinstance(data, dict) or not isinstance(data.get("desktop_state", {}), dict):
            log.warning("Failed to restore session for %s: %s", username, "malformed file")
            return {}
        log.info("Restored desktop state for %s", username)
        return data.get("desktop_state", {})

    def _prune_dead(self, session: DesktopSession):
        """Remove (and clean up) terminals whose PTY has exited."""
        dead = [tid for tid, t in session.terminals.items() if not t.alive]
        for tid in dead:
            t = session.terminals.pop(tid)
            t.kill()
            log.info("Pruned dead terminal %s for %s", tid, session.username)

    def save_state(self, username: str, desktop_state: dict):
        """Store *desktop_state* on the user's session and persist it to disk.

        Does nothing when the user has no session. The JSON is written to a
        temporary sibling file and atomically renamed over the real one, so a
        crash mid-write cannot leave a truncated file behind. OSErrors are
        logged, not raised.
        """
        session = self._sessions.get(username)
        if session is None:
            return
        session.desktop_state = desktop_state
        session.touch()
        state_file = self._state_file(username)
        tmp_file = state_file.with_name(state_file.name + ".tmp")
        try:
            data = {
                "desktop_state": desktop_state,
                "terminal_ids": [
                    {"terminal_id": t.terminal_id, "title": t.title}
                    for t in session.terminals.values()
                    if t.alive
                ],
                "saved_at": time.time(),
            }
            tmp_file.write_text(json.dumps(data, indent=2), encoding="utf-8")
            os.replace(tmp_file, state_file)
        except OSError as e:
            log.warning("Failed to save session for %s: %s", username, e)

    # ---- Terminal lifecycle ----

    def create_terminal(self, username: str, cols: int = 120, rows: int = 30) -> ManagedTerminal:
        """Spawn a login shell in a new PTY and register it in the user's session.

        The shell and working directory come from the user's passwd entry.
        Unknown users fall back to /bin/bash in "/". Must be called with a
        running event loop, because it starts the terminal's reader task.

        NOTE(review): the shell runs with this server process's own uid/gid;
        nothing here switches to *username*'s account. Confirm this is intended.
        """
        session = self.get_or_create(username)

        try:
            pw = pwd.getpwnam(username)
        except KeyError:
            pw = None
        if pw is not None:
            # An empty login-shell field conventionally means /bin/sh.
            shell = pw.pw_shell or "/bin/sh"
            home = pw.pw_dir or "/"
        else:
            shell, home = "/bin/bash", "/"

        env = {
            "TERM": "xterm-256color",
            "HOME": home,
            "USER": username,
            "SHELL": shell,
            "LANG": os.environ.get("LANG", "en_US.UTF-8"),
            "PATH": "/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin",
        }
        pty = PtyProcess.spawn(
            [shell, "-l"],
            dimensions=(rows, cols),
            env=env,
            cwd=home,
        )

        # 8 hex chars (32 bits): collisions are unlikely but cheap to rule out.
        terminal_id = uuid.uuid4().hex[:8]
        while terminal_id in session.terminals:
            terminal_id = uuid.uuid4().hex[:8]

        terminal = ManagedTerminal(
            terminal_id=terminal_id,
            pty=pty,
            title=f"Shell {len(session.terminals) + 1}",
        )
        try:
            terminal.start_reader()
        except RuntimeError:
            # No running event loop: don't leave the just-spawned shell orphaned.
            terminal.kill()
            raise
        session.terminals[terminal_id] = terminal
        session.touch()
        log.info("Created terminal %s for %s (pid %d)", terminal_id, username, pty.pid)
        return terminal

    def get_terminal(self, username: str, terminal_id: str) -> Optional[ManagedTerminal]:
        """Return the user's terminal with *terminal_id*, or None."""
        session = self._sessions.get(username)
        if not session:
            return None
        return session.terminals.get(terminal_id)

    def close_terminal(self, username: str, terminal_id: str) -> bool:
        """Kill and remove a terminal; return False if it was not found."""
        session = self._sessions.get(username)
        if not session:
            return False
        terminal = session.terminals.pop(terminal_id, None)
        if not terminal:
            return False
        terminal.kill()
        session.touch()
        log.info("Closed terminal %s for %s", terminal_id, username)
        return True

    def list_terminals(self, username: str) -> list[dict]:
        """List summaries of all live terminals for a user (empty if no session)."""
        session = self._sessions.get(username)
        if not session:
            return []
        self._prune_dead(session)
        return [t.to_dict() for t in session.terminals.values() if t.alive]

    # ---- Cleanup ----

    async def cleanup_stale(self, max_idle_hours: float = 24):
        """Drop sessions whose ``last_active`` is older than *max_idle_hours*.

        All terminals in a dropped session are killed.
        """
        cutoff = time.time() - (max_idle_hours * 3600)
        stale = [u for u, s in self._sessions.items() if s.last_active < cutoff]
        for username in stale:
            session = self._sessions.pop(username)
            for terminal in session.terminals.values():
                terminal.kill()
            log.info("Cleaned up stale session for %s (idle since %s)",
                     username, time.ctime(session.last_active))

    def shutdown_all(self):
        """Kill all PTYs — called on server shutdown."""
        for session in self._sessions.values():
            for terminal in session.terminals.values():
                terminal.kill()
        log.info("All sessions shut down")


# ---------------------------------------------------------------------------
# Singleton instance
# ---------------------------------------------------------------------------

manager = SessionManager()