"""PAM authentication and JWT token management for Atlus.""" import logging import platform import uuid from datetime import datetime, timedelta, timezone from typing import Annotated import jwt from fastapi import Depends, HTTPException, WebSocket, status from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from backend.config import ( JWT_ALGORITHM, JWT_EXPIRY_HOURS, get_jwt_secret, is_token_revoked, revoke_token, ) log = logging.getLogger("atlus.auth") _bearer = HTTPBearer() # PAM auth is only used on Linux — fall back to dev mode on macOS/other _use_pam = False _pam_instance = None if platform.system() == "Linux": try: import pam _pam_instance = pam.pam() _use_pam = True except (ImportError, OSError): log.warning("PAM module not available — running in dev mode") else: log.warning("Non-Linux platform (%s) — running in dev mode (any credentials accepted)", platform.system()) # --------------------------------------------------------------------------- # PAM # --------------------------------------------------------------------------- def authenticate_user(username: str, password: str) -> bool: """Validate credentials against Linux PAM. On non-Linux systems (dev mode), accepts any non-empty credentials. """ if not username or not password: return False if _use_pam and _pam_instance: return _pam_instance.authenticate(username, password, service="login") # Dev mode — accept anything return True # --------------------------------------------------------------------------- # JWT helpers # --------------------------------------------------------------------------- def create_token(username: str) -> tuple[str, str]: """Issue a signed JWT. Returns (token, jti).""" jti = uuid.uuid4().hex now = datetime.now(timezone.utc) payload = { "sub": username, "jti": jti, "iat": now, "exp": now + timedelta(hours=JWT_EXPIRY_HOURS), } token = jwt.encode(payload, get_jwt_secret(), algorithm=JWT_ALGORITHM) return token, jti def decode_token(token: str) -> dict: """Decode and validate a JWT. Raises on any failure.""" try: payload = jwt.decode( token, get_jwt_secret(), algorithms=[JWT_ALGORITHM] ) except jwt.ExpiredSignatureError: raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Token expired") except jwt.InvalidTokenError: raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid token") if is_token_revoked(payload.get("jti", "")): raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Token revoked") return payload def logout(token: str) -> None: """Revoke a token so it cannot be reused.""" try: payload = jwt.decode( token, get_jwt_secret(), algorithms=[JWT_ALGORITHM], options={"verify_exp": False}, ) revoke_token(payload.get("jti", "")) except jwt.InvalidTokenError: pass # --------------------------------------------------------------------------- # FastAPI dependencies # --------------------------------------------------------------------------- def get_current_user( creds: Annotated[HTTPAuthorizationCredentials, Depends(_bearer)], ) -> str: """Dependency — extracts and validates the bearer token, returns username.""" payload = decode_token(creds.credentials) return payload["sub"] async def ws_authenticate(websocket: WebSocket) -> str: """Authenticate a WebSocket connection via token query param. Usage in route: @router.websocket("/ws/something") async def ws(websocket: WebSocket): username = await ws_authenticate(websocket) ... """ token = websocket.query_params.get("token") if not token: await websocket.close(code=4001, reason="Missing token") raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Missing token") try: payload = decode_token(token) except HTTPException: await websocket.close(code=4001, reason="Invalid token") raise return payload["sub"]