atlus/backend/routers/updates.py
2026-03-14 23:51:53 -05:00

372 lines
13 KiB
Python

"""Self-update — check Gitea repo for new commits and apply updates."""
import asyncio
import logging
import os
import re
import shutil
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel
from backend.auth import get_current_user
from backend.config import BASE_DIR
log = logging.getLogger("atlus.updates")
router = APIRouter(prefix="/api/updates", tags=["updates"])
# ---------------------------------------------------------------------------
# Guard
# ---------------------------------------------------------------------------
# True when the install directory contains a .git directory; every git-backed
# endpoint below refuses to run otherwise (see _require_git).
_IS_GIT = (BASE_DIR / ".git").is_dir()
def _find_git() -> str | None:
"""Find the git binary, even with systemd's minimal PATH."""
found = shutil.which("git")
if found:
return found
for p in ("/usr/bin/git", "/usr/local/bin/git", "/bin/git"):
if os.path.isfile(p) and os.access(p, os.X_OK):
return p
return None
_GIT_BIN = _find_git()
_HAS_GIT = _GIT_BIN is not None


def _require_git():
    """Abort the request with HTTP 503 unless self-update can use git.

    The git binary is checked before the checkout, so a host lacking both
    reports the missing binary.
    """
    problem = None
    if not _HAS_GIT:
        problem = "git is not installed"
    elif not _IS_GIT:
        problem = "Atlus was not installed via git — updates unavailable"
    if problem is not None:
        raise HTTPException(503, problem)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _safe_env() -> dict[str, str]:
    """Build a subprocess environment with a usable PATH and HOME.

    Starts from a copy of ``os.environ``, forces ``LC_ALL=C`` so tool output
    is parseable, and makes sure the standard system bin directories are on
    PATH. Those directories are appended, so existing entries keep their
    precedence. HOME is filled in from the passwd database when missing,
    falling back to ``/root``.

    Returns:
        A new environment dict; ``os.environ`` itself is not modified.
    """
    env = {**os.environ, "LC_ALL": "C"}
    # Compare whole PATH entries rather than substrings ("/bin" is a substring
    # of "/usr/bin"). Drop empty entries, which exec lookup treats as the
    # current working directory.
    entries = [e for e in env.get("PATH", "").split(os.pathsep) if e]
    for sys_dir in ("/usr/bin", "/usr/sbin", "/bin", "/sbin"):
        if sys_dir not in entries:
            entries.append(sys_dir)
    env["PATH"] = os.pathsep.join(entries)
    # Ensure HOME is set: systemd may strip it, and git needs it for SSH keys.
    if "HOME" not in env:
        import pwd
        try:
            env["HOME"] = pwd.getpwuid(os.getuid()).pw_dir
        except KeyError:
            env["HOME"] = "/root"
    return env
async def _git(*args: str, timeout: float = 30) -> str:
    """Run a git command in the Atlus install directory and return its stdout.

    Args:
        *args: git arguments, e.g. ``("rev-parse", "HEAD")``.
        timeout: Seconds to wait before the process is killed.

    Returns:
        The command's stdout with surrounding whitespace stripped.

    Raises:
        HTTPException: 503 from ``_require_git`` when git or the checkout is
            unavailable, 504 on timeout, 500 on a non-zero exit status.
    """
    _require_git()
    cmd = [_GIT_BIN, "-C", str(BASE_DIR), *args]
    log.debug("Running: %s", " ".join(cmd))
    proc = await asyncio.create_subprocess_exec(
        *cmd,
        stdin=asyncio.subprocess.DEVNULL,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
        # Fail fast instead of waiting on an interactive credential prompt.
        env={**_safe_env(), "GIT_TERMINAL_PROMPT": "0"},
    )
    try:
        stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=timeout)
    except asyncio.TimeoutError:
        try:
            proc.kill()
        except ProcessLookupError:
            pass  # exited between the timeout firing and the kill
        await proc.wait()
        raise HTTPException(504, "git operation timed out") from None
    # git can emit non-UTF-8 bytes (paths, commit messages); never crash on them.
    out = stdout.decode(errors="replace").strip()
    if proc.returncode != 0:
        msg = stderr.decode(errors="replace").strip() or out
        log.warning("git %s failed (rc=%d): %s", args[0] if args else "?", proc.returncode, msg)
        raise HTTPException(500, f"git error: {msg}")
    return out
async def _git_nofail(*args: str, timeout: float = 30) -> tuple[int, str]:
    """Run a git command in the install directory without raising on failure.

    Args:
        *args: git arguments.
        timeout: Seconds to wait before the process is killed.

    Returns:
        ``(returncode, stdout)`` with stdout stripped. A timeout is reported
        as ``(1, "")`` once the process has been killed and reaped.

    Raises:
        HTTPException: 503 from ``_require_git`` when git or the checkout is
            unavailable; git failures themselves never raise.
    """
    _require_git()
    cmd = [_GIT_BIN, "-C", str(BASE_DIR), *args]
    proc = await asyncio.create_subprocess_exec(
        *cmd,
        stdin=asyncio.subprocess.DEVNULL,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
        # Fail fast instead of waiting on an interactive credential prompt.
        env={**_safe_env(), "GIT_TERMINAL_PROMPT": "0"},
    )
    try:
        stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=timeout)
    except asyncio.TimeoutError:
        try:
            proc.kill()
        except ProcessLookupError:
            pass  # exited between the timeout firing and the kill
        await proc.wait()
        log.debug("git %s timed out after %ss", args[0] if args else "?", timeout)
        return 1, ""
    if proc.returncode != 0:
        log.debug("git %s returned %d: %s", args[0] if args else "?",
                  proc.returncode, stderr.decode(errors="replace").strip())
    return proc.returncode, stdout.decode(errors="replace").strip()
# ---------------------------------------------------------------------------
# Package dependency detection
# ---------------------------------------------------------------------------
# Start of an ``apt-get install`` invocation; what follows are its arguments.
_APT_INSTALL_RE = re.compile(r"\bapt-get\s+install\s+")
# Accepted package-name shape: alphanumeric start, at least two characters.
_APT_PKG_RE = re.compile(r"[a-zA-Z0-9][a-zA-Z0-9.+\-]+")
# A token that begins a redirection: ">/dev/null", ">>log", "2>&1", "<in".
_SHELL_REDIRECT_RE = re.compile(r"\d*[<>]")
# Shell control/comment characters that end the argument list inside a token,
# e.g. the ";" in "curl;" or the "&&" in "git&&".
_SHELL_STOP_RE = re.compile(r"[;&|<>#]")
# apt-get options whose value is the following, separate token.
_APT_VALUE_OPTS = frozenset({
    "-o", "-t", "-c",
    "--option", "--target-release", "--default-release", "--config-file",
})


def _parse_apt_packages(install_sh_content: str) -> set[str]:
    """Extract package names from ``apt-get install`` commands in a shell script.

    Handles backslash line continuations (LF or CRLF) and several installs
    chained on one line. Skips option flags, including the values of
    ``-o``/``-t``/``-c``, and fully commented-out lines. The package list of
    a command ends at a redirection, a shell control operator (``;``, ``&&``,
    ``||``, ``|``) or a trailing ``#`` comment. Tokens that do not look like
    package names, such as ``"$VAR"``, are ignored.

    Args:
        install_sh_content: Full text of the script.

    Returns:
        The set of package names found; empty when there are none.
    """
    packages: set[str] = set()
    joined = re.sub(r"\\\r?\n", " ", install_sh_content)
    for line in joined.splitlines():
        stripped = line.strip()
        if stripped.startswith("#"):
            continue  # commented-out command
        # Element 0 precedes the first install; each later element holds the
        # arguments of one install command.
        for args in _APT_INSTALL_RE.split(stripped)[1:]:
            skip_next = False
            for tok in args.split():
                if skip_next:
                    skip_next = False
                    continue
                if _SHELL_REDIRECT_RE.match(tok):
                    break
                stop = _SHELL_STOP_RE.search(tok)
                word = tok[:stop.start()] if stop else tok
                if word.startswith("-"):
                    skip_next = word in _APT_VALUE_OPTS
                elif _APT_PKG_RE.fullmatch(word):
                    packages.add(word)
                if stop:
                    break
    return packages
def _dpkg_status_installed(status: str) -> bool:
    """Return True if a dpkg ``Status`` value denotes a cleanly installed package.

    The value has three words, ``<want> <error> <state>``. Only the last two
    matter: a held package (``"hold ok installed"``) is as installed as
    ``"install ok installed"``.
    """
    words = status.split()
    return len(words) == 3 and words[1:] == ["ok", "installed"]


def _parse_dpkg_query(output: str) -> set[str]:
    """Return the installed package names from ``${Package} ${Status}`` lines."""
    installed: set[str] = set()
    for line in output.splitlines():
        name, _, status = line.strip().partition(" ")
        if name and _dpkg_status_installed(status):
            installed.add(name)
    return installed


async def _run_dpkg_tool(*cmd: str, timeout: float) -> tuple[int, str]:
    """Run a dpkg tool and return ``(returncode, stdout)``.

    Raises:
        asyncio.TimeoutError: After killing and reaping the process.
        OSError: If the binary cannot be executed (e.g. not a Debian system).
    """
    proc = await asyncio.create_subprocess_exec(
        *cmd,
        stdin=asyncio.subprocess.DEVNULL,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
        env=_safe_env(),
    )
    try:
        stdout, _ = await asyncio.wait_for(proc.communicate(), timeout=timeout)
    except asyncio.TimeoutError:
        try:
            proc.kill()
        except ProcessLookupError:
            pass  # exited between the timeout firing and the kill
        await proc.wait()
        raise
    return proc.returncode, stdout.decode(errors="replace")


async def _get_installed_packages(packages: set[str]) -> set[str]:
    """Return the subset of *packages* that dpkg reports as installed.

    A single batched ``dpkg-query -W`` is tried first. It exits non-zero when
    any name is unknown but still prints status lines for known packages, so
    its return code is ignored. If it times out or cannot run, each package
    is checked with ``dpkg -s`` instead, and per-package timeouts are skipped.

    Raises:
        OSError: If neither ``dpkg-query`` nor ``dpkg`` can be executed. This
            keeps the original behavior; the caller treats it as "detection
            failed".
    """
    if not packages:
        return set()
    try:
        _, out = await _run_dpkg_tool(
            "dpkg-query", "-W", "-f", "${Package} ${Status}\n", *sorted(packages),
            timeout=10,
        )
        return _parse_dpkg_query(out) & packages
    except (asyncio.TimeoutError, OSError) as e:
        log.debug("dpkg-query failed, falling back to individual checks: %s", e)

    installed: set[str] = set()
    for pkg in sorted(packages):
        try:
            rc, out = await _run_dpkg_tool("dpkg", "-s", pkg, timeout=5)
        except asyncio.TimeoutError:
            continue
        if rc != 0:
            continue
        match = re.search(r"^Status:\s*(.*)$", out, re.MULTILINE)
        if match and _dpkg_status_installed(match.group(1)):
            installed.add(pkg)
    return installed
async def _detect_new_packages(remote_ref: str) -> list[str]:
    """List apt packages that *remote_ref*'s install.sh needs but are not installed.

    Best effort: any failure (no install.sh at that ref, git or dpkg errors)
    is logged at debug level and yields an empty list.

    Args:
        remote_ref: A git revision whose ``install.sh`` is inspected.

    Returns:
        Sorted names of the missing packages.
    """
    try:
        rc, script = await _git_nofail("show", f"{remote_ref}:install.sh", timeout=10)
        wanted = _parse_apt_packages(script) if rc == 0 else set()
        if not wanted:
            return []
        already_present = await _get_installed_packages(wanted)
        return sorted(wanted - already_present)
    except Exception as e:
        log.debug("Package detection failed: %s", e)
        return []
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@router.get("/check")
async def check_for_updates(_user: str = Depends(get_current_user)):
    """Fetch from origin and report whether the checkout is behind its remote.

    Every response carries ``available``, ``local_hash``, ``remote_hash``,
    ``behind_count`` and ``new_packages``. ``error`` is added when the remote
    cannot be fetched or no remote branch can be resolved.

    Raises:
        HTTPException: From ``_git`` when git is unavailable or ``HEAD``
            cannot be resolved.
    """
    local_hash = await _git("rev-parse", "HEAD")
    short_local = local_hash[:8]

    # Fetching may take a few seconds and fails without network or credentials.
    try:
        await _git("fetch", "origin", timeout=30)
    except HTTPException as e:
        error_msg = str(e.detail)
        log.warning("git fetch failed: %s", error_msg)
        return {
            "available": False,
            "local_hash": short_local,
            "remote_hash": short_local,
            "behind_count": 0,
            "new_packages": [],
            "error": f"Could not reach remote: {error_msg}",
        }

    # Prefer the current branch's upstream, which is what a plain `git pull`
    # (used by /apply) merges. Fall back to the conventional default branches,
    # e.g. when HEAD is detached or no upstream is configured.
    remote_hash = ""
    for ref in ("@{upstream}", "origin/main", "origin/master"):
        rc, out = await _git_nofail("rev-parse", ref)
        if rc == 0 and out:
            remote_hash = out
            break
    if not remote_hash:
        return {
            "available": False,
            "local_hash": short_local,
            "remote_hash": "",
            "behind_count": 0,
            "new_packages": [],
            "error": "Could not determine remote branch",
        }

    # Number of commits reachable from the remote head but not from HEAD.
    rc, count_str = await _git_nofail("rev-list", "--count", f"HEAD..{remote_hash}")
    behind_count = int(count_str) if rc == 0 and count_str.isdigit() else 0

    # Detect new system packages needed by the update.
    new_packages = await _detect_new_packages(remote_hash) if behind_count > 0 else []

    return {
        "available": behind_count > 0,
        "local_hash": short_local,
        "remote_hash": remote_hash[:8],
        "behind_count": behind_count,
        "new_packages": new_packages,
    }
class InstallDepsRequest(BaseModel):
    """Request body for ``POST /api/updates/install-deps``."""

    # apt package names to install; install_deps validates each one.
    packages: list[str]
@router.post("/install-deps")
async def install_deps(req: InstallDepsRequest, _user: str = Depends(get_current_user)):
    """Install system packages via apt-get. Used before applying updates.

    Runs ``apt-get update`` first (best effort, 120 s), then
    ``apt-get install -y`` (300 s).

    NOTE(review): any authenticated user can install any validly named
    package here; consider restricting to names reported by /check.

    Returns:
        ``{"success", "message", "output"}``, with ``output`` holding the last
        500 characters of apt-get's stdout.

    Raises:
        HTTPException: 400 for an invalid package name, 504 on install
            timeout, 500 when apt-get install fails.
    """
    if not req.packages:
        return {"success": True, "message": "No packages to install", "output": ""}

    # Validate package names. fullmatch is used because `$` in re.match would
    # also accept a trailing newline. The leading alphanumeric character keeps
    # names from being parsed as apt-get options.
    for pkg in req.packages:
        if not re.fullmatch(r"[a-zA-Z0-9][a-zA-Z0-9.+\-]+", pkg):
            raise HTTPException(400, f"Invalid package name: {pkg}")

    async def _kill_and_reap(p):
        try:
            p.kill()
        except ProcessLookupError:
            pass  # already exited
        await p.wait()

    apt_bin = shutil.which("apt-get") or "/usr/bin/apt-get"
    # No terminal is attached; a debconf prompt would stall until the timeout.
    env = {**_safe_env(), "DEBIAN_FRONTEND": "noninteractive"}

    # Update package lists first.
    log.info("Running apt-get update before install")
    update_proc = await asyncio.create_subprocess_exec(
        apt_bin, "update", "-qq",
        stdin=asyncio.subprocess.DEVNULL,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
        env=env,
    )
    try:
        await asyncio.wait_for(update_proc.communicate(), timeout=120)
    except asyncio.TimeoutError:
        # Non-fatal, but a lingering `apt-get update` would keep holding the
        # apt lock and make the install below fail.
        await _kill_and_reap(update_proc)
        log.warning("apt-get update timed out; continuing with install")
    else:
        if update_proc.returncode != 0:
            log.warning("apt-get update failed (rc=%d); continuing with install",
                        update_proc.returncode)

    cmd = [apt_bin, "install", "-y", *req.packages]
    log.info("Installing system packages: %s", " ".join(req.packages))
    proc = await asyncio.create_subprocess_exec(
        *cmd,
        stdin=asyncio.subprocess.DEVNULL,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
        env=env,
    )
    try:
        stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=300)
    except asyncio.TimeoutError:
        await _kill_and_reap(proc)
        raise HTTPException(504, "Package installation timed out") from None

    output = stdout.decode(errors="replace").strip()
    err_output = stderr.decode(errors="replace").strip()
    if proc.returncode != 0:
        detail = err_output or output  # some failures are reported on stdout only
        log.warning("apt-get install failed (rc=%d): %s", proc.returncode, detail)
        raise HTTPException(500, f"Package installation failed: {detail[-500:]}")
    log.info("System packages installed successfully")
    return {
        "success": True,
        "message": f"Installed {len(req.packages)} package(s)",
        "output": output[-500:],
    }
# Strong references to fire-and-forget tasks. The event loop keeps only weak
# references, so an unreferenced task can be garbage-collected mid-run.
_background_tasks: set[asyncio.Task] = set()


async def _pip_install_requirements() -> str | None:
    """Reinstall backend/requirements.txt with the venv pip, or system pip as a fallback.

    Returns:
        None on success or when there is no requirements file; otherwise a
        short description of the failure. pip problems never raise.
    """
    req_file = str(BASE_DIR / "backend" / "requirements.txt")
    if not os.path.exists(req_file):
        return None
    pip_bin = str(BASE_DIR / "venv" / "bin" / "pip")
    if not os.path.exists(pip_bin):
        pip_bin = shutil.which("pip3") or shutil.which("pip") or "pip3"
    try:
        proc = await asyncio.create_subprocess_exec(
            pip_bin, "install", "-r", req_file, "-q",
            stdin=asyncio.subprocess.DEVNULL,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
    except OSError as e:
        log.warning("Could not run pip (%s): %s", pip_bin, e)
        return f"Could not run pip: {e}"
    try:
        _, stderr = await asyncio.wait_for(proc.communicate(), timeout=120)
    except asyncio.TimeoutError:
        try:
            proc.kill()
        except ProcessLookupError:
            pass  # exited between the timeout firing and the kill
        await proc.wait()
        log.warning("pip install timed out")
        return "pip install timed out"
    if proc.returncode != 0:
        msg = stderr.decode(errors="replace").strip()
        log.warning("pip install failed (rc=%d): %s", proc.returncode, msg)
        return f"pip install failed: {msg[-500:]}"
    return None


@router.post("/apply")
async def apply_update(_user: str = Depends(get_current_user)):
    """Fast-forward pull, reinstall Python deps, and schedule a service restart.

    The restart is scheduled only when ``systemctl`` is available. A pip
    failure does not abort the update (the pull has already happened); it is
    reported in ``pip_error``.

    Returns:
        ``success``, ``message``, ``pull_output`` (last 500 characters),
        ``restarting`` (whether a restart was scheduled) and ``pip_error``
        (None on success).

    Raises:
        HTTPException: From ``_git`` when the pull fails or times out.
    """
    pull_output = await _git("pull", "--ff-only", timeout=60)

    pip_error = await _pip_install_requirements()

    # Schedule the restart so it runs after the response has been sent.
    restarting = shutil.which("systemctl") is not None
    if restarting:
        task = asyncio.create_task(_delayed_restart())
        _background_tasks.add(task)
        task.add_done_callback(_background_tasks.discard)

    return {
        "success": True,
        "message": (
            "Update applied. Restarting service..."
            if restarting
            else "Update applied. Restart the Atlus service to finish."
        ),
        "pull_output": pull_output[-500:],
        "restarting": restarting,
        "pip_error": pip_error,
    }
async def _delayed_restart(delay: float = 2.0):
    """Wait *delay* seconds, then run ``systemctl restart atlus``.

    Intended to run as a fire-and-forget task, so failures are logged instead
    of raised; an escaping exception would otherwise only surface as an
    unretrieved-task warning.
    """
    await asyncio.sleep(delay)
    try:
        proc = await asyncio.create_subprocess_exec(
            "systemctl", "restart", "atlus",
            stdin=asyncio.subprocess.DEVNULL,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        # communicate(), not wait(): wait() with piped output can deadlock
        # once the pipe buffer fills.
        _, stderr = await proc.communicate()
    except OSError as e:
        log.error("Could not run systemctl restart: %s", e)
        return
    if proc.returncode != 0:
        log.error("systemctl restart atlus failed (rc=%d): %s",
                  proc.returncode, stderr.decode(errors="replace").strip())