# main.py (merged) import threading import logging import shutil import subprocess import traceback import uuid import shlex import yaml import os import httpx import glob import cv2 import numpy as np from contextlib import asynccontextmanager from datetime import datetime, timezone from pathlib import Path from typing import Dict, List, Any, Optional import resource from threading import Semaphore from logging.handlers import RotatingFileHandler from urllib.parse import urljoin, urlparse from io import BytesIO import zipfile import sys import re import importlib import collections.abc import time import ocrmypdf import pypdf import pytesseract from pytesseract import TesseractNotFoundError from PIL import Image, UnidentifiedImageError from faster_whisper import WhisperModel from fastapi import (Depends, FastAPI, File, Form, HTTPException, Request, UploadFile, status, Body) from fastapi.responses import FileResponse, JSONResponse, RedirectResponse, StreamingResponse from fastapi.staticfiles import StaticFiles from fastapi.templating import Jinja2Templates from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from huey import SqliteHuey, crontab from pydantic import BaseModel, ConfigDict, field_serializer from sqlalchemy import (Column, DateTime, Integer, String, Text, create_engine, delete, event) from sqlalchemy.orm import Session, declarative_base, sessionmaker from sqlalchemy.pool import NullPool from sqlalchemy.exc import OperationalError from string import Formatter from werkzeug.utils import secure_filename from typing import List as TypingList from starlette.middleware.sessions import SessionMiddleware from authlib.integrations.starlette_client import OAuth from dotenv import load_dotenv from piper import PiperVoice import wave import io load_dotenv() # --- Optional Dependency Handling for Piper TTS --- try: from piper.synthesis import SynthesisConfig # download helpers: some piper versions export download_voice, others expose ensure_voice_exists/find_voice try: # prefer the more explicit helpers if present from piper.download import get_voices, ensure_voice_exists, find_voice, VoiceNotFoundError except Exception: # fall back to older API if available try: from piper.download import get_voices, download_voice, VoiceNotFoundError ensure_voice_exists = None find_voice = None except Exception: # partial import failed -> treat as piper-not-installed for download helpers get_voices = None download_voice = None ensure_voice_exists = None find_voice = None VoiceNotFoundError = None except ImportError: SynthesisConfig = None get_voices = None download_voice = None ensure_voice_exists = None find_voice = None VoiceNotFoundError = None try: from PyPDF2 import PdfMerger _HAS_PYPDF2 = True except Exception: _HAS_PYPDF2 = False # Instantiate OAuth object (was referenced in code) oauth = OAuth() # -------------------------------------------------------------------------------- # --- 1. CONFIGURATION & SECURITY HELPERS # -------------------------------------------------------------------------------- # --- Path Safety --- UPLOADS_BASE = Path(os.environ.get("UPLOADS_DIR", "/app/uploads")).resolve() PROCESSED_BASE = Path(os.environ.get("PROCESSED_DIR", "/app/processed")).resolve() CHUNK_TMP_BASE = Path(os.environ.get("CHUNK_TMP_DIR", str(UPLOADS_BASE / "tmp"))).resolve() def ensure_path_is_safe(p: Path, allowed_bases: List[Path]): """Ensure a path resolves to a location within one of the allowed base directories.""" try: resolved_p = p.resolve() if not any(resolved_p.is_relative_to(base) for base in allowed_bases): raise ValueError(f"Path {resolved_p} is outside of allowed directories.") return resolved_p except Exception as e: logger = logging.getLogger(__name__) logger.error(f"Path safety check failed for {p}: {e}") raise ValueError("Invalid or unsafe path specified.") # --- Resource Limiting --- def _limit_resources_preexec(): """Set resource limits for child processes to prevent DoS attacks.""" try: # 6000s CPU, 4GB address space resource.setrlimit(resource.RLIMIT_CPU, (6000, 6000)) resource.setrlimit(resource.RLIMIT_AS, (4 * 1024 * 1024 * 1024, 4 * 1024 * 1024 * 1024)) except Exception as e: # This may fail in some environments (e.g. Windows, some containers) logging.getLogger(__name__).warning(f"Could not set resource limits: {e}") pass # --- Model concurrency semaphore --- MODEL_CONCURRENCY = int(os.environ.get("MODEL_CONCURRENCY", "1")) _model_semaphore = Semaphore(MODEL_CONCURRENCY) # --- Logging Setup --- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') _log_handler = RotatingFileHandler("app.log", maxBytes=10*1024*1024, backupCount=1) _log_formatter = logging.Formatter('%(asctime)s %(levelname)s %(name)s %(message)s') _log_handler.setFormatter(_log_formatter) logging.getLogger().addHandler(_log_handler) logger = logging.getLogger(__name__) # --- Environment Mode --- LOCAL_ONLY_MODE = os.getenv('LOCAL_ONLY', 'True').lower() in ('true', '1', 't') if LOCAL_ONLY_MODE: logger.warning("Authentication is DISABLED. Running in LOCAL_ONLY mode.") class AppPaths(BaseModel): BASE_DIR: Path = Path(__file__).resolve().parent UPLOADS_DIR: Path = UPLOADS_BASE PROCESSED_DIR: Path = PROCESSED_BASE CHUNK_TMP_DIR: Path = CHUNK_TMP_BASE TTS_MODELS_DIR: Path = BASE_DIR / "models" / "tts" KOKORO_TTS_MODELS_DIR: Path = BASE_DIR / "models" / "tts" / "kokoro" KOKORO_MODEL_FILE: Path = KOKORO_TTS_MODELS_DIR / "kokoro-v1.0.onnx" KOKORO_VOICES_FILE: Path = KOKORO_TTS_MODELS_DIR / "voices-v1.0.bin" DATABASE_URL: str = f"sqlite:///{BASE_DIR / 'jobs.db'}" HUEY_DB_PATH: str = str(BASE_DIR / "huey.db") CONFIG_DIR: Path = BASE_DIR / "config" SETTINGS_FILE: Path = CONFIG_DIR / "settings.yml" DEFAULT_SETTINGS_FILE: Path = BASE_DIR / "settings.default.yml" PATHS = AppPaths() APP_CONFIG: Dict[str, Any] = {} PATHS.UPLOADS_DIR.mkdir(exist_ok=True, parents=True) PATHS.PROCESSED_DIR.mkdir(exist_ok=True, parents=True) PATHS.CHUNK_TMP_DIR.mkdir(exist_ok=True, parents=True) PATHS.CONFIG_DIR.mkdir(exist_ok=True, parents=True) PATHS.TTS_MODELS_DIR.mkdir(exist_ok=True, parents=True) PATHS.KOKORO_TTS_MODELS_DIR.mkdir(exist_ok=True, parents=True) def load_app_config(): """ Loads configuration from settings.yml, with a fallback to settings.default.yml, and finally to hardcoded defaults if both files are missing. """ global APP_CONFIG try: # --- Primary Method: Attempt to load settings.yml --- with open(PATHS.SETTINGS_FILE, 'r', encoding='utf8') as f: cfg_raw = yaml.safe_load(f) or {} defaults = { "app_settings": {"max_file_size_mb": 100, "allowed_all_extensions": [], "app_public_url": ""}, "transcription_settings": {"whisper": {"allowed_models": ["tiny", "base", "small"], "compute_type": "int8", "device": "cpu"}}, "tts_settings": { "piper": {"model_dir": str(PATHS.TTS_MODELS_DIR), "use_cuda": False, "synthesis_config": {"length_scale": 1.0, "noise_scale": 0.667, "noise_w": 0.8}}, "kokoro": {"model_dir": str(PATHS.KOKORO_TTS_MODELS_DIR), "command_template": "kokoro-tts {input} {output} --model {model_path} --voices {voices_path} --lang {lang} --voice {model_name}"} }, "conversion_tools": {}, "ocr_settings": {"ocrmypdf": {}}, "auth_settings": {"oidc_client_id": "", "oidc_client_secret": "", "oidc_server_metadata_url": "", "admin_users": []}, "webhook_settings": {"enabled": False, "allow_chunked_api_uploads": False, "allowed_callback_urls": [], "callback_bearer_token": ""} } cfg = defaults.copy() cfg.update(cfg_raw) # Merge loaded settings into defaults app_settings = cfg.get("app_settings", {}) max_mb = app_settings.get("max_file_size_mb", 100) app_settings["max_file_size_bytes"] = int(max_mb) * 1024 * 1024 allowed = app_settings.get("allowed_all_extensions", []) if not isinstance(allowed, (list, set)): allowed = list(allowed) app_settings["allowed_all_extensions"] = set(allowed) cfg["app_settings"] = app_settings APP_CONFIG = cfg logger.info("Successfully loaded settings from settings.yml") except (FileNotFoundError, yaml.YAMLError) as e: logger.warning(f"Could not load settings.yml: {e}. Falling back to settings.default.yml...") try: # --- Fallback Method: Attempt to load settings.default.yml --- with open(PATHS.DEFAULT_SETTINGS_FILE, 'r', encoding='utf8') as f: cfg_raw = yaml.safe_load(f) or {} defaults = { "app_settings": {"max_file_size_mb": 100, "allowed_all_extensions": [], "app_public_url": ""}, "transcription_settings": {"whisper": {"allowed_models": ["tiny", "base", "small"], "compute_type": "int8", "device": "cpu"}}, "tts_settings": { "piper": {"model_dir": str(PATHS.TTS_MODELS_DIR), "use_cuda": False, "synthesis_config": {"length_scale": 1.0, "noise_scale": 0.667, "noise_w": 0.8}}, "kokoro": {"model_dir": str(PATHS.KOKORO_TTS_MODELS_DIR), "command_template": "kokoro-tts {input} {output} --model {model_path} --voices {voices_path} --lang {lang} --voice {model_name}"} }, "conversion_tools": {}, "ocr_settings": {"ocrmypdf": {}}, "auth_settings": {"oidc_client_id": "", "oidc_client_secret": "", "oidc_server_metadata_url": "", "admin_users": []}, "webhook_settings": {"enabled": False, "allow_chunked_api_uploads": False, "allowed_callback_urls": [], "callback_bearer_token": ""} } cfg = defaults.copy() cfg.update(cfg_raw) # Merge loaded settings into defaults app_settings = cfg.get("app_settings", {}) max_mb = app_settings.get("max_file_size_mb", 100) app_settings["max_file_size_bytes"] = int(max_mb) * 1024 * 1024 allowed = app_settings.get("allowed_all_extensions", []) if not isinstance(allowed, (list, set)): allowed = list(allowed) app_settings["allowed_all_extensions"] = set(allowed) cfg["app_settings"] = app_settings APP_CONFIG = cfg logger.info("Successfully loaded settings from settings.default.yml") except (FileNotFoundError, yaml.YAMLError) as e_fallback: # --- Final Failsafe: Use hardcoded defaults --- logger.error(f"CRITICAL: Fallback file settings.default.yml also failed: {e_fallback}. Using hardcoded defaults.") APP_CONFIG = { "app_settings": {"max_file_size_mb": 100, "max_file_size_bytes": 100 * 1024 * 1024, "allowed_all_extensions": set(), "app_public_url": ""}, "transcription_settings": {"whisper": {"allowed_models": ["tiny", "base", "small"], "compute_type": "int8", "device": "cpu"}}, "tts_settings": { "piper": {"model_dir": str(PATHS.TTS_MODELS_DIR), "use_cuda": False, "synthesis_config": {"length_scale": 1.0, "noise_scale": 0.667, "noise_w": 0.8}}, "kokoro": {"model_dir": str(PATHS.KOKORO_TTS_MODELS_DIR), "command_template": "kokoro-tts {input} {output} --model {model_path} --voices {voices_path} --lang {lang} --voice {model_name}"} }, "conversion_tools": {}, "ocr_settings": {"ocrmypdf": {}}, "auth_settings": {"oidc_client_id": "", "oidc_client_secret": "", "oidc_server_metadata_url": "", "admin_users": []}, "webhook_settings": {"enabled": False, "allow_chunked_api_uploads": False, "allowed_callback_urls": [], "callback_bearer_token": ""} } # -------------------------------------------------------------------------------- # --- 2. DATABASE & Schemas # -------------------------------------------------------------------------------- engine = create_engine( PATHS.DATABASE_URL, connect_args={"check_same_thread": False, "timeout": 30}, poolclass=NullPool, ) SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) Base = declarative_base() @event.listens_for(engine, "connect") def _set_sqlite_pragmas(dbapi_connection, connection_record): c = dbapi_connection.cursor() try: c.execute("PRAGMA journal_mode=WAL;") c.execute("PRAGMA synchronous=NORMAL;") finally: c.close() class Job(Base): __tablename__ = "jobs" id = Column(String, primary_key=True, index=True) user_id = Column(String, index=True, nullable=True) parent_job_id = Column(String, index=True, nullable=True) task_type = Column(String, index=True) status = Column(String, default="pending") progress = Column(Integer, default=0) original_filename = Column(String) input_filepath = Column(String) input_filesize = Column(Integer, nullable=True) processed_filepath = Column(String, nullable=True) output_filesize = Column(Integer, nullable=True) result_preview = Column(Text, nullable=True) error_message = Column(Text, nullable=True) callback_url = Column(String, nullable=True) created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc)) updated_at = Column(DateTime, default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc)) def get_db(): db = SessionLocal() try: yield db finally: db.close() class JobCreate(BaseModel): id: str user_id: str | None = None parent_job_id: str | None = None task_type: str original_filename: str input_filepath: str input_filesize: int | None = None callback_url: str | None = None processed_filepath: str | None = None class JobSchema(BaseModel): id: str parent_job_id: str | None = None task_type: str status: str progress: int original_filename: str input_filesize: int | None = None output_filesize: int | None = None processed_filepath: str | None = None result_preview: str | None = None error_message: str | None = None created_at: datetime updated_at: datetime model_config = ConfigDict(from_attributes=True) @field_serializer('created_at', 'updated_at') def serialize_dt(self, dt: datetime, _info): return dt.isoformat() + "Z" class FinalizeUploadPayload(BaseModel): upload_id: str original_filename: str total_chunks: int task_type: str model_size: str = "" model_name: str = "" output_format: str = "" callback_url: Optional[str] = None # For API chunked uploads class JobSelection(BaseModel): job_ids: List[str] # -------------------------------------------------------------------------------- # --- 3. CRUD OPERATIONS & WEBHOOKS # -------------------------------------------------------------------------------- def get_job(db: Session, job_id: str): # return db.query(Job).filter(Job.id == job_id).first() return db.query(Job).filter(Job.id == job_id).first() def get_jobs(db: Session, user_id: str | None = None, skip: int = 0, limit: int = 100): query = db.query(Job) if user_id: query = query.filter(Job.user_id == user_id) return query.order_by(Job.created_at.desc()).offset(skip).limit(limit).all() def create_job(db: Session, job: JobCreate): db_job = Job(**job.model_dump()) db.add(db_job) db.commit() db.refresh(db_job) return db_job def update_job_status(db: Session, job_id: str, status: str, progress: int = None, error: str = None): db_job = get_job(db, job_id) if db_job: db_job.status = status if progress is not None: db_job.progress = progress if error: db_job.error_message = error db.commit() db.refresh(db_job) return db_job def mark_job_as_completed(db: Session, job_id: str, output_filepath_str: str | None = None, preview: str | None = None): db_job = get_job(db, job_id) if db_job and db_job.status != 'cancelled': db_job.status = "completed" db_job.progress = 100 if preview: db_job.result_preview = preview.strip()[:2000] if output_filepath_str: try: output_path = Path(output_filepath_str) if output_path.exists(): db_job.output_filesize = output_path.stat().st_size except Exception: logger.exception(f"Could not stat output file {output_filepath_str} for job {job_id}") db.commit() return db_job def send_webhook_notification(job_id: str, app_config: Dict[str, Any], base_url: str): """Sends a notification to the callback URL if one is configured for the job.""" webhook_config = app_config.get("webhook_settings", {}) if not webhook_config.get("enabled", False): return db = SessionLocal() try: job = get_job(db, job_id) if not job or not job.callback_url: return download_url = None if job.status == "completed" and job.processed_filepath: filename = Path(job.processed_filepath).name public_url = app_config.get("app_settings", {}).get("app_public_url", base_url) if not public_url: logger.warning(f"app_public_url is not set. Cannot generate a full download URL for job {job_id}.") download_url = f"/download/{filename}" # Relative URL as fallback else: download_url = urljoin(public_url, f"/download/{filename}") payload = { "job_id": job.id, "status": job.status, "original_filename": job.original_filename, "download_url": download_url, "error_message": job.error_message, "created_at": job.created_at.isoformat() + "Z", "updated_at": job.updated_at.isoformat() + "Z", } headers = {"Content-Type": "application/json", "User-Agent": "FileProcessor-Webhook/1.0"} token = webhook_config.get("callback_bearer_token") if token: headers["Authorization"] = f"Bearer {token}" try: with httpx.Client() as client: response = client.post(job.callback_url, json=payload, headers=headers, timeout=15) response.raise_for_status() logger.info(f"Sent webhook notification for job {job_id} to {job.callback_url} (Status: {response.status_code})") except httpx.RequestError as e: logger.error(f"Failed to send webhook for job {job_id} to {job.callback_url}: {e}") except httpx.HTTPStatusError as e: logger.error(f"Webhook for job {job_id} received non-2xx response {e.response.status_code} from {job.callback_url}") except Exception as e: logger.exception(f"An unexpected error occurred in send_webhook_notification for job {job_id}: {e}") finally: db.close() # -------------------------------------------------------------------------------- # --- 4. BACKGROUND TASK SETUP # -------------------------------------------------------------------------------- huey = SqliteHuey(filename=PATHS.HUEY_DB_PATH) WHISPER_MODELS_CACHE: Dict[str, WhisperModel] = {} PIPER_VOICES_CACHE: Dict[str, "PiperVoice"] = {} AVAILABLE_TTS_VOICES_CACHE: Dict[str, Any] | None = None _model_locks: Dict[str, threading.Lock] = {} _global_lock = threading.Lock() def get_whisper_model(model_size: str, whisper_settings: dict) -> Any: # Fast path: cache hit without any locking if model_size in WHISPER_MODELS_CACHE: logger.debug(f"Cache hit for model '{model_size}'") return WHISPER_MODELS_CACHE[model_size] # Prepare for potential load with minimal contention model_lock = _get_or_create_model_lock(model_size) # Critical section: check cache again under model-specific lock with model_lock: if model_size in WHISPER_MODELS_CACHE: return WHISPER_MODELS_CACHE[model_size] logger.info(f"Loading Whisper model '{model_size}'...") try: # Optimized initialization with validated settings device = whisper_settings.get("device", "cpu") compute_type = whisper_settings.get("compute_type", "int8") # fast_whisper-specific optimizations model = WhisperModel( model_size, device=device, compute_type=compute_type, cpu_threads=max(1, os.cpu_count() // 2), # Prevent CPU oversubscription num_workers=1 # Optimal for most transcription workloads ) # Atomic cache update WHISPER_MODELS_CACHE[model_size] = model logger.info(f"Model '{model_size}' loaded (device={device}, compute={compute_type})") return model except Exception as e: logger.error(f"Model '{model_size}' failed to load: {str(e)}", exc_info=True) raise RuntimeError(f"Whisper model initialization failed: {e}") from e def _get_or_create_model_lock(model_size: str) -> threading.Lock: """Thread-safe lock acquisition with minimal global contention""" # Fast path: lock already exists if model_size in _model_locks: return _model_locks[model_size] # Slow path: create lock under global lock with _global_lock: return _model_locks.setdefault(model_size, threading.Lock()) def get_piper_voice(model_name: str, tts_settings: dict | None) -> "PiperVoice": """ Load (or download + load) a Piper voice in a robust way: - Try Python API helpers (get_voices, ensure_voice_exists/find_voice, download_voice) - On any failure, try CLI fallback (download_voice_cli) - Attempt to locate model files after download (search subdirs) - Try re-importing piper if bindings were previously unavailable """ # ----- Defensive normalization ----- if tts_settings is None or not isinstance(tts_settings, dict): logger.debug("get_piper_voice: normalizing tts_settings (was %r)", tts_settings) tts_settings = {} model_dir_val = tts_settings.get("model_dir", None) if model_dir_val is None: model_dir = Path(str(PATHS.TTS_MODELS_DIR)) else: try: model_dir = Path(model_dir_val) except Exception: logger.warning("Could not coerce tts_settings['model_dir']=%r to Path; using default.", model_dir_val) model_dir = Path(str(PATHS.TTS_MODELS_DIR)) model_dir.mkdir(parents=True, exist_ok=True) # If PiperVoice already cached, reuse if model_name in PIPER_VOICES_CACHE: logger.info("Reusing cached Piper voice '%s'.", model_name) return PIPER_VOICES_CACHE[model_name] with _model_semaphore: if model_name in PIPER_VOICES_CACHE: return PIPER_VOICES_CACHE[model_name] # If Python bindings are missing, attempt CLI download first (and try re-import) if PiperVoice is None: logger.info("Piper Python bindings missing; attempting CLI download fallback for '%s' before failing import.", model_name) cli_ok = False try: cli_ok = download_voice_cli(model_name, model_dir) except Exception as e: logger.warning("CLI download attempt raised: %s", e) cli_ok = False if cli_ok: # attempt to re-import piper package (maybe import issue was transient) try: importlib.invalidate_caches() piper_mod = importlib.import_module("piper") from piper import PiperVoice as _PiperVoice # noqa: F401 from piper.synthesis import SynthesisConfig as _SynthesisConfig # noqa: F401 globals().update({"PiperVoice": _PiperVoice, "SynthesisConfig": _SynthesisConfig}) logger.info("Successfully re-imported piper after CLI download.") except Exception: logger.warning("Could not import piper after CLI download; bindings still unavailable.") # If bindings still absent, we cannot load models; raise helpful error if PiperVoice is None: raise RuntimeError( "Piper Python bindings are not installed or failed to import. " "Tried CLI download fallback but python bindings are still unavailable. " "Please install 'piper-tts' in the runtime used by this process." ) # Now we have Piper bindings (or they were present to begin with). Attempt Python helpers. onnx_path = None config_path = None # Prefer using get_voices to update the index if available voices_info = None try: if get_voices: try: voices_info = get_voices(str(model_dir), update_voices=True) except TypeError: # some versions may not support update_voices kwarg voices_info = get_voices(str(model_dir)) except Exception as e: logger.debug("get_voices failed or unavailable: %s", e) voices_info = None try: # Preferred modern helpers if ensure_voice_exists and find_voice: try: ensure_voice_exists(model_name, [model_dir], model_dir, voices_info) onnx_path, config_path = find_voice(model_name, [model_dir]) except Exception as e: # Could be VoiceNotFoundError or other download error logger.warning("ensure/find voice failed for %s: %s", model_name, e) raise elif download_voice: # older API: call download helper directly try: download_voice(model_name, model_dir) # attempt to locate files onnx_path = model_dir / f"{model_name}.onnx" config_path = model_dir / f"{model_name}.onnx.json" except Exception: logger.warning("download_voice failed for %s", model_name) raise else: # No python download helper available raise RuntimeError("No Python download helper available in installed piper package.") except Exception as py_exc: # Python helper route failed; try CLI fallback BEFORE giving up logger.info("Python download route failed for '%s' (%s). Trying CLI fallback...", model_name, py_exc) try: cli_ok = download_voice_cli(model_name, model_dir) except Exception as e: logger.warning("CLI fallback attempt raised: %s", e) cli_ok = False if not cli_ok: # If CLI also failed, re-raise the original python exception to preserve context logger.error("Both Python download helpers and CLI fallback failed for '%s'.", model_name) raise # CLI succeeded (or at least returned success) — try to find files on disk onnx_path, config_path = _find_model_files(model_name, model_dir) if not (onnx_path and config_path): # maybe CLI wrote into a nested dir or different name; try to search broadly logger.info("Could not find model files after CLI download in %s; attempting broader search...", model_dir) onnx_path, config_path = _find_model_files(model_name, model_dir) if not (onnx_path and config_path): logger.error("Model files still missing after CLI fallback for '%s'.", model_name) raise RuntimeError(f"Piper voice files for '{model_name}' missing after CLI fallback.") # continue to loading below # Final safety check and last-resort search if not (onnx_path and config_path): onnx_path, config_path = _find_model_files(model_name, model_dir) if not (onnx_path and config_path): raise RuntimeError(f"Piper voice files for '{model_name}' are missing after attempts to download.") # Load the PiperVoice try: use_cuda = bool(tts_settings.get("use_cuda", False)) voice = PiperVoice.load(str(onnx_path), config_path=str(config_path), use_cuda=use_cuda) PIPER_VOICES_CACHE[model_name] = voice logger.info("Loaded Piper voice '%s' from %s", model_name, onnx_path) return voice except Exception as e: logger.exception("Failed to load Piper voice '%s' from files (%s, %s): %s", model_name, onnx_path, config_path, e) raise def _find_model_files(model_name: str, model_dir: Path): """ Try multiple strategies to find onnx and config files for a given model_name under model_dir. Returns (onnx_path, config_path) or (None, None). """ # direct files in model_dir onnx = model_dir / f"{model_name}.onnx" cfg = model_dir / f"{model_name}.onnx.json" if onnx.exists() and cfg.exists(): return onnx, cfg # possible alternative names or nested directories: search recursively matches_onnx = list(model_dir.rglob(f"{model_name}*.onnx")) matches_cfg = list(model_dir.rglob(f"{model_name}*.onnx.json")) if matches_onnx and matches_cfg: # prefer same directory match for o in matches_onnx: for c in matches_cfg: if o.parent == c.parent: return o, c # otherwise return first matches return matches_onnx[0], matches_cfg[0] # last-resort: any onnx + any json in same subdir that contain model name token for o in model_dir.rglob("*.onnx"): if model_name in o.name: # try find any matching json in same dir cands = list(o.parent.glob("*.onnx.json")) if cands: return o, cands[0] return None, None # --------------------------- # CLI: list available voices # --------------------------- def list_voices_cli(timeout: int = 30, python_executables: Optional[List[str]] = None) -> List[str]: """ Run `python -m piper.download_voices` (no args) and parse output into a list of voice IDs. Returns [] on failure. """ if python_executables is None: python_executables = [sys.executable, "python3", "python"] # Regex: voice ids look like en_US-lessac-medium (letters/digits/._-) voice_regex = re.compile(r'^([A-Za-z0-9_\-\.]+)') for py in python_executables: cmd = [py, "-m", "piper.download_voices"] try: logger.debug("Trying Piper CLI list: %s", shlex.join(cmd)) cp = subprocess.run( cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=True, timeout=timeout, ) out = cp.stdout.strip() # If stdout empty, sometimes the script writes to stderr if not out: out = cp.stderr.strip() if not out: logger.debug("Piper CLI listed nothing (empty output) for %s", py) continue voices = [] for line in out.splitlines(): line = line.strip() if not line: continue # Try to extract first token that matches voice id pattern m = voice_regex.match(line) if m: v = m.group(1) # basic sanity: avoid capturing words like 'Available' or headings if re.search(r'\d', v) or '-' in v or '_' in v or '.' in v: voices.append(v) else: # allow alphabetic tokens too (defensive) voices.append(v) else: # Also handle lines like " - en_US-lessac-medium: description" parts = re.split(r'[:\s]+', line) if parts: candidate = parts[0].lstrip('-').strip() if candidate: voices.append(candidate) # Dedupe while preserving order seen = set() dedup = [] for v in voices: if v not in seen: seen.add(v) dedup.append(v) logger.info("Piper CLI list returned %d voices via %s", len(dedup), py) return dedup except subprocess.CalledProcessError as e: logger.debug("Piper CLI list (%s) non-zero exit. stdout=%s stderr=%s", py, e.stdout, e.stderr) except FileNotFoundError: logger.debug("Python executable not found: %s", py) except subprocess.TimeoutExpired: logger.warning("Piper CLI list timed out for %s", py) except Exception as e: logger.exception("Unexpected error running Piper CLI list with %s: %s", py, e) logger.error("All Piper CLI list attempts failed.") return [] # --------------------------- # CLI: download a voice # --------------------------- def download_voice_cli(model_name: str, model_dir: Path, python_executables: Optional[List[str]] = None, timeout: int = 300) -> bool: """ Try to download a Piper voice using CLI: python -m piper.download_voices --data-dir Returns True if the CLI ran and expected files exist afterwards (best effort). """ if python_executables is None: python_executables = [sys.executable, "python3", "python"] model_dir = Path(model_dir) model_dir.mkdir(parents=True, exist_ok=True) for py in python_executables: cmd = [py, "-m", "piper.download_voices", model_name, "--data-dir", str(model_dir)] try: logger.info("Trying Piper CLI download: %s", shlex.join(cmd)) cp = subprocess.run( cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=True, timeout=timeout, ) logger.debug("Piper CLI download stdout: %s", cp.stdout) logger.debug("Piper CLI download stderr: %s", cp.stderr) # Heuristic success check onnx = model_dir / f"{model_name}.onnx" cfg = model_dir / f"{model_name}.onnx.json" if onnx.exists() and cfg.exists(): logger.info("Piper CLI created expected files for %s", model_name) return True # Some versions might create nested dirs; treat non-error CLI execution as success (caller will re-check) return True except subprocess.CalledProcessError as e: logger.warning("Piper CLI (%s) returned non-zero exit. stdout: %s; stderr: %s", py, e.stdout, e.stderr) except FileNotFoundError: logger.debug("Python executable %s not found.", py) except subprocess.TimeoutExpired: logger.warning("Piper CLI call timed out for python %s", py) except Exception as e: logger.exception("Unexpected error running Piper CLI download with %s: %s", py, e) logger.error("All Piper CLI attempts failed for model %s", model_name) return False # --------------------------- # Safe get_voices wrapper # --------------------------- def safe_get_voices(model_dir: Path) -> List[Dict]: """ Try to call the in-Python get_voices(..., update_voices=True) and return a list of dicts. If that fails, fall back to list_voices_cli() and return a list of simple dicts: [{"id": "en_US-lessac-medium", "name": "...", "local": False}, ...] Keeps the shape flexible so your existing endpoint can use it with minimal changes. """ # Prefer Python API if available try: if get_voices: # get_voices imported earlier in your file # Ensure up-to-date index (like CLI) raw = get_voices(str(model_dir), update_voices=True) # get_voices may already return the desired structure; normalise to a list of dicts if isinstance(raw, dict): # some versions return mapping id->meta items = [] for vid, meta in raw.items(): d = {"id": vid} if isinstance(meta, dict): d.update(meta) items.append(d) return items elif isinstance(raw, list): return raw else: # unknown format -> fall back to CLI logger.debug("get_voices returned unexpected type; falling back to CLI list.") except Exception as e: logger.warning("In-Python get_voices failed: %s. Falling back to CLI listing.", e) # CLI fallback: parse voice ids and create simple dicts cli_list = list_voices_cli() results = [{"id": vid, "name": vid, "local": False} for vid in cli_list] return results def list_kokoro_voices_cli(timeout: int = 60) -> List[str]: """ Run `kokoro-tts --help-voices` and parse the output for available models. Returns [] on failure. """ model_path = PATHS.KOKORO_MODEL_FILE voices_path = PATHS.KOKORO_VOICES_FILE if not (model_path.exists() and voices_path.exists()): logger.warning("Cannot list Kokoro TTS voices because model/voices files are missing.") return [] cmd = ["kokoro-tts", "--help-voices", "--model", str(model_path), "--voices", str(voices_path)] try: logger.debug("Trying Kokoro TTS CLI list: %s", shlex.join(cmd)) cp = subprocess.run( cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=True, timeout=timeout, ) out = cp.stdout.strip() if not out: out = cp.stderr.strip() if not out: logger.warning("Kokoro TTS CLI list returned no output.") return [] voices = [] voice_pattern = re.compile(r'^\s*\d+\.\s+([a-z]{2,3}_[a-zA-Z0-9]+)$') for line in out.splitlines(): line = line.strip() match = voice_pattern.match(line) if match: voices.append(match.group(1)) logger.info("Kokoro TTS CLI list returned %d voices.", len(voices)) return sorted(list(set(voices))) except FileNotFoundError: logger.info("Kokoro TTS ('kokoro-tts' command) not found in PATH. Kokoro TTS support disabled.") return [] except subprocess.CalledProcessError as e: logger.error("Kokoro TTS CLI list command failed. stderr: %s", e.stderr[:1000]) return [] except subprocess.TimeoutExpired: logger.warning("Kokoro TTS CLI list command timed out.") return [] except Exception as e: logger.exception("Unexpected error running Kokoro TTS CLI list: %s", e) return [] def list_kokoro_languages_cli(timeout: int = 60) -> List[str]: """ Run `kokoro-tts --help-languages` and parse the output for available languages. Returns [] on failure. """ model_path = PATHS.KOKORO_MODEL_FILE voices_path = PATHS.KOKORO_VOICES_FILE if not (model_path.exists() and voices_path.exists()): logger.warning("Cannot list Kokoro TTS languages because model/voices files are missing.") return [] cmd = ["kokoro-tts", "--help-languages", "--model", str(model_path), "--voices", str(voices_path)] try: logger.debug("Trying Kokoro TTS language list: %s", shlex.join(cmd)) cp = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=True, timeout=timeout) out = cp.stdout.strip() if not out: out = cp.stderr.strip() if not out: logger.warning("Kokoro TTS language list returned no output.") return [] languages = [] lang_pattern = re.compile(r'^\s*([a-z]{2,3}(?:-[a-z]{2,3})?)$') for line in out.splitlines(): line = line.strip() if line.lower().startswith("supported languages"): continue match = lang_pattern.match(line) if match: languages.append(match.group(1)) logger.info("Kokoro TTS language list returned %d languages.", len(languages)) return sorted(list(set(languages))) except FileNotFoundError: logger.info("Kokoro TTS ('kokoro-tts' command) not found in PATH. Kokoro TTS support disabled.") return [] except subprocess.CalledProcessError as e: logger.error("Kokoro TTS language list command failed. stderr: %s", e.stderr[:1000]) return [] except subprocess.TimeoutExpired: logger.warning("Kokoro TTS language list command timed out.") return [] except Exception as e: logger.exception("Unexpected error running Kokoro TTS language list: %s", e) return [] def run_command( argv: List[str], timeout: int = 300, max_output_size: int = 5 * 1024 * 1024 # 5MB ) -> subprocess.CompletedProcess: """ Drop-in replacement for your run_command. - Incrementally reads stdout/stderr in separate threads to avoid unbounded memory growth. - Keeps at most `max_output_size` characters per stream (first N chars). - Enforces a timeout (graceful terminate then kill). - Uses optional preexec function `_limit_resources_preexec` if present in globals. - Raises Exception on non-zero exit or timeout; returns CompletedProcess on success. """ logger.debug("Executing command: %s with timeout=%ss", " ".join(argv), timeout) # quick sanity: ensure there's a program to execute (improves error clarity) try: exe = argv[0] except Exception: raise Exception("Invalid argv passed to run_command") preexec = globals().get("_limit_resources_preexec", None) # Buffers and state for threads stdout_chunks = [] stderr_chunks = [] stdout_len = 0 stderr_len = 0 stdout_lock = threading.Lock() stderr_lock = threading.Lock() stdout_truncated = False stderr_truncated = False def _reader(stream, chunks, lock, name): nonlocal stdout_len, stderr_len, stdout_truncated, stderr_truncated try: while True: data = stream.read(4096) if not data: break with lock: # choose which counters to use by stream identity if name == "stdout": if stdout_len < max_output_size: # append as much as fits remaining = max_output_size - stdout_len to_append = data[:remaining] chunks.append(to_append) stdout_len += len(to_append) if len(data) > remaining: stdout_truncated = True else: stdout_truncated = True else: if stderr_len < max_output_size: remaining = max_output_size - stderr_len to_append = data[:remaining] chunks.append(to_append) stderr_len += len(to_append) if len(data) > remaining: stderr_truncated = True else: stderr_truncated = True except Exception: logger.exception("Reader thread for %s failed", name) finally: try: stream.close() except Exception: pass # Start process try: proc = subprocess.Popen( argv, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, preexec_fn=preexec ) except FileNotFoundError: msg = f"Command not found: {argv[0]}" logger.error(msg) raise Exception(msg) except Exception as e: msg = f"Unexpected error launching command: {e}" logger.exception(msg) raise Exception(msg) # Start reader threads t_stdout = threading.Thread(target=_reader, args=(proc.stdout, stdout_chunks, stdout_lock, "stdout"), daemon=True) t_stderr = threading.Thread(target=_reader, args=(proc.stderr, stderr_chunks, stderr_lock, "stderr"), daemon=True) t_stdout.start() t_stderr.start() # Wait loop with timeout start = time.monotonic() try: while True: ret = proc.poll() if ret is not None: break elapsed = time.monotonic() - start if timeout and elapsed > timeout: # Timeout -> try terminate -> kill if needed logger.error("Command timed out after %ss: %s", timeout, " ".join(argv)) try: proc.terminate() # give it a short while to exit waited = 0.0 while proc.poll() is None and waited < 2.0: time.sleep(0.1) waited += 0.1 except Exception: logger.exception("Failed to terminate process on timeout; attempting kill.") if proc.poll() is None: try: proc.kill() except Exception: logger.exception("Failed to kill process after timeout.") # ensure threads finish reading leftover data try: t_stdout.join(timeout=1.0) t_stderr.join(timeout=1.0) except Exception: pass raise Exception(f"Command timed out after {timeout}s: {' '.join(argv)}") # sleep a little to avoid busy loop time.sleep(0.1) # process finished normally; allow readers to finish t_stdout.join(timeout=2.0) t_stderr.join(timeout=2.0) # build strings from chunks with stdout_lock: stdout_str = "".join(stdout_chunks) if stdout_chunks else "" if stdout_truncated: truncated_amount = " (truncated to max_output_size)" stdout_str += f"\n[TRUNCATED - output larger than {max_output_size} bytes]{truncated_amount}" with stderr_lock: stderr_str = "".join(stderr_chunks) if stderr_chunks else "" if stderr_truncated: truncated_amount = " (truncated to max_output_size)" stderr_str += f"\n[TRUNCATED - output larger than {max_output_size} bytes]{truncated_amount}" # Check return code rc = proc.returncode if rc != 0: # include limited stderr snippet for diagnostics (like your original) snippet = (stderr_str or "")[:1000] msg = f"Command failed with exit code {rc}. Stderr: {snippet}" logger.error(msg) raise Exception(msg) logger.debug("Command completed successfully: %s", " ".join(argv)) return subprocess.CompletedProcess(args=argv, returncode=rc, stdout=stdout_str, stderr=stderr_str) finally: # ensure no resource leaks try: if proc.stdout: try: proc.stdout.close() except Exception: pass if proc.stderr: try: proc.stderr.close() except Exception: pass except Exception: pass def validate_and_build_command(template_str: str, mapping: Dict[str, str]) -> TypingList[str]: fmt = Formatter() used = {fname for _, fname, _, _ in fmt.parse(template_str) if fname} ALLOWED_VARS = { "input", "output", "output_dir", "output_ext", "quality", "speed", "preset", "device", "dpi", "samplerate", "bitdepth", "filter", "model_name", "model_path", "voices_path", "lang" } bad = used - ALLOWED_VARS if bad: raise ValueError(f"Command template contains disallowed placeholders: {bad}") safe_mapping = dict(mapping) for name in used: if name not in safe_mapping: safe_mapping[name] = safe_mapping.get("output_ext", "") if name == "filter" else "" formatted = template_str.format(**safe_mapping) return shlex.split(formatted) # --- TASK RUNNERS --- @huey.task() def run_transcription_task(job_id: str, input_path_str: str, output_path_str: str, model_size: str, whisper_settings: dict, app_config: dict, base_url: str): db = SessionLocal() input_path = Path(input_path_str) output_path = Path(output_path_str) # --- Constants --- # Time in seconds between database checks for progress updates and cancellations. # This avoids hammering the database on every single segment. DB_POLL_INTERVAL_SECONDS = 5 try: job = get_job(db, job_id) if not job: logger.warning(f"Job {job_id} not found. Aborting task.") return if job.status == 'cancelled': logger.info(f"Job {job_id} was already cancelled before starting. Aborting.") return update_job_status(db, job_id, "processing", progress=0) model = get_whisper_model(model_size, whisper_settings) logger.info(f"Starting transcription for job {job_id} with model '{model_size}'") segments_generator, info = model.transcribe(str(input_path), beam_size=5) logger.info(f"Detected language: {info.language} with probability {info.language_probability:.2f} for a duration of {info.duration:.2f}s") last_update_time = time.time() # Use a temporary file to ensure atomic writes. The final file will only appear # once the transcription is fully and successfully written. tmp_output_path = output_path.with_name(f"{output_path.stem}.tmp-{uuid.uuid4().hex}{output_path.suffix}") # Store a small preview in memory for the final update without holding the whole transcript. preview_segments = [] PREVIEW_MAX_LENGTH = 1000 # characters current_preview_length = 0 with tmp_output_path.open("w", encoding="utf-8") as f: for segment in segments_generator: segment_text = segment.text.strip() f.write(segment_text + "\n") # Build a small preview if current_preview_length < PREVIEW_MAX_LENGTH: preview_segments.append(segment_text) current_preview_length += len(segment_text) current_time = time.time() if current_time - last_update_time > DB_POLL_INTERVAL_SECONDS: last_update_time = current_time # Check for cancellation without overwhelming the DB job_check = get_job(db, job_id) if job_check and job_check.status == 'cancelled': logger.info(f"Job {job_id} cancelled during transcription. Stopping.") # The temporary file will be cleaned up in the finally block return # Update progress if info.duration > 0: progress = int((segment.end / info.duration) * 100) update_job_status(db, job_id, "processing", progress=progress) tmp_output_path.replace(output_path) transcript_preview = " ".join(preview_segments) if len(transcript_preview) > PREVIEW_MAX_LENGTH: transcript_preview = transcript_preview[:PREVIEW_MAX_LENGTH] + "..." mark_job_as_completed(db, job_id, output_filepath_str=output_path_str, preview=transcript_preview) logger.info(f"Transcription for job {job_id} completed successfully.") except Exception as e: logger.exception(f"An unexpected error occurred during transcription for job {job_id}") update_job_status(db, job_id, "failed", error=str(e)) finally: # This block executes whether the task succeeded, failed, or was cancelled and returned. logger.debug(f"Performing cleanup for job {job_id}") # Clean up the temporary file if it still exists (e.g., due to cancellation) if 'tmp_output_path' in locals() and tmp_output_path.exists(): try: tmp_output_path.unlink() logger.debug(f"Removed temporary file: {tmp_output_path}") except OSError as e: logger.error(f"Error removing temporary file {tmp_output_path}: {e}") # Clean up the original input file try: # First, ensure we are not deleting from an unexpected directory ensure_path_is_safe(input_path, [PATHS.UPLOADS_DIR, PATHS.CHUNK_TMP_DIR]) input_path.unlink(missing_ok=True) logger.debug(f"Removed input file: {input_path}") except Exception as e: logger.exception(f"Failed to cleanup input file {input_path} for job {job_id}: {e}") if db: db.close() # Send notification last, after all state has been finalized. send_webhook_notification(job_id, app_config, base_url) @huey.task() def run_tts_task(job_id: str, input_path_str: str, output_path_str: str, model_name: str, tts_settings: dict, app_config: dict, base_url: str): db = SessionLocal() input_path = Path(input_path_str) try: job = get_job(db, job_id) if not job or job.status == 'cancelled': return update_job_status(db, job_id, "processing") engine, actual_model_name = "piper", model_name if '/' in model_name: parts = model_name.split('/', 1) engine = parts[0] actual_model_name = parts[1] logger.info(f"Starting TTS for job {job_id} using engine '{engine}' with model '{actual_model_name}'") out_path = Path(output_path_str) tmp_out = out_path.with_name(f"{out_path.stem}.tmp-{uuid.uuid4().hex}{out_path.suffix}") if engine == "piper": piper_settings = tts_settings.get("piper", {}) voice = get_piper_voice(actual_model_name, piper_settings) with open(input_path, 'r', encoding='utf-8') as f: text_to_speak = f.read() synthesis_params = piper_settings.get("synthesis_config", {}) synthesis_config = SynthesisConfig(**synthesis_params) if SynthesisConfig else None with wave.open(str(tmp_out), "wb") as wav_file: wav_file.setnchannels(1) wav_file.setsampwidth(2) wav_file.setframerate(voice.config.sample_rate) voice.synthesize_wav(text_to_speak, wav_file, synthesis_config) elif engine == "kokoro": kokoro_settings = tts_settings.get("kokoro", {}) command_template_str = kokoro_settings.get("command_template") if not command_template_str: raise ValueError("Kokoro TTS command_template is not defined in settings.") try: lang, voice_name = actual_model_name.split('/', 1) except ValueError: raise ValueError(f"Invalid Kokoro model format. Expected 'lang/voice', but got '{actual_model_name}'.") mapping = { "input": str(input_path), "output": str(tmp_out), "lang": lang, "model_name": voice_name, "model_path": str(PATHS.KOKORO_MODEL_FILE), "voices_path": str(PATHS.KOKORO_VOICES_FILE), } command = validate_and_build_command(command_template_str, mapping) logger.info(f"Executing Kokoro TTS command: {' '.join(command)}") run_command(command, timeout=kokoro_settings.get("timeout", 300)) if not tmp_out.exists(): raise FileNotFoundError("Kokoro TTS command did not produce an output file.") else: raise ValueError(f"Unsupported TTS engine: {engine}") tmp_out.replace(out_path) mark_job_as_completed(db, job_id, output_filepath_str=output_path_str, preview="Successfully generated audio.") logger.info(f"TTS for job {job_id} completed.") except Exception as e: logger.exception(f"ERROR during TTS for job {job_id}") update_job_status(db, job_id, "failed", error=f"TTS failed: {e}") finally: try: ensure_path_is_safe(input_path, [PATHS.UPLOADS_DIR, PATHS.CHUNK_TMP_DIR]) input_path.unlink(missing_ok=True) except Exception: logger.exception("Failed to cleanup input file after TTS.") db.close() send_webhook_notification(job_id, app_config, base_url) @huey.task() def run_pdf_ocr_task(job_id: str, input_path_str: str, output_path_str: str, ocr_settings: dict, app_config: dict, base_url: str): db = SessionLocal() input_path = Path(input_path_str) try: job = get_job(db, job_id) if not job or job.status == 'cancelled': return update_job_status(db, job_id, "processing") logger.info(f"Starting PDF OCR for job {job_id}") ocrmypdf.ocr(str(input_path), str(output_path_str), deskew=ocr_settings.get('deskew', True), force_ocr=ocr_settings.get('force_ocr', True), clean=ocr_settings.get('clean', True), optimize=ocr_settings.get('optimize', 1), progress_bar=False) with open(output_path_str, "rb") as f: reader = pypdf.PdfReader(f) preview = "\n".join(page.extract_text() or "" for page in reader.pages) mark_job_as_completed(db, job_id, output_filepath_str=output_path_str, preview=preview) logger.info(f"PDF OCR for job {job_id} completed.") except Exception as e: logger.exception(f"ERROR during PDF OCR for job {job_id}") update_job_status(db, job_id, "failed", error=f"PDF OCR failed: {e}") finally: try: ensure_path_is_safe(input_path, [PATHS.UPLOADS_DIR]) input_path.unlink(missing_ok=True) except Exception: logger.exception("Failed to cleanup input file after PDF OCR.") db.close() send_webhook_notification(job_id, app_config, base_url) @huey.task() def run_image_ocr_task(job_id: str, input_path_str: str, output_path_str: str, app_config: dict, base_url: str): db = SessionLocal() input_path = Path(input_path_str) out_path = Path(output_path_str) try: job = get_job(db, job_id) if not job or job.status == "cancelled": return update_job_status(db, job_id, "processing", progress=10) logger.info(f"Starting Image OCR for job {job_id} - {input_path}") # open image and gather frames (support multi-frame TIFF) try: pil_img = Image.open(str(input_path)) except UnidentifiedImageError as e: raise RuntimeError(f"Cannot identify/open input image: {e}") frames = [] try: # some images support n_frames (multi-page TIFF); iterate safely n_frames = getattr(pil_img, "n_frames", 1) for i in range(n_frames): pil_img.seek(i) # copy the frame to avoid problems when the original image object is closed frames.append(pil_img.convert("RGB").copy()) except Exception: # fallback: single frame frames = [pil_img.convert("RGB")] update_job_status(db, job_id, "processing", progress=30) pdf_bytes_list = [] text_parts = [] for idx, frame in enumerate(frames): # produce searchable PDF bytes for the frame and plain text as well try: pdf_bytes = pytesseract.image_to_pdf_or_hocr(frame, extension="pdf") except TesseractNotFoundError as e: raise RuntimeError("Tesseract not found. Ensure Tesseract OCR is installed and in PATH.") from e except Exception as e: raise RuntimeError(f"Failed to run Tesseract on frame {idx}: {e}") from e pdf_bytes_list.append(pdf_bytes) # also extract plain text for preview and possible fallback try: page_text = pytesseract.image_to_string(frame) except Exception: page_text = "" text_parts.append(page_text) # update progress incrementally prog = 30 + int((idx + 1) / max(1, len(frames)) * 50) update_job_status(db, job_id, "processing", progress=min(prog, 80)) # merge per-page pdfs if multiple frames final_pdf_bytes = None if len(pdf_bytes_list) == 1: final_pdf_bytes = pdf_bytes_list[0] else: if _HAS_PYPDF2: merger = PdfMerger() for b in pdf_bytes_list: merger.append(io.BytesIO(b)) out_buffer = io.BytesIO() merger.write(out_buffer) merger.close() final_pdf_bytes = out_buffer.getvalue() else: # PyPDF2 not installed — try a simple concatenation (not valid PDF merge), # better to fail loudly so user can install PyPDF2; but as a fallback # write the first page only and include a warning in job preview. logger.warning("PyPDF2 not available; only the first frame will be written to output PDF.") final_pdf_bytes = pdf_bytes_list[0] text_parts.insert(0, "[WARNING] Multiple frames detected but PyPDF2 not available; only first page saved.\n") # write out atomically tmp_out = out_path.with_name(f"{out_path.stem}.tmp-{uuid.uuid4().hex}{out_path.suffix or '.pdf'}") try: tmp_out.parent.mkdir(parents=True, exist_ok=True) with tmp_out.open("wb") as f: f.write(final_pdf_bytes) tmp_out.replace(out_path) except Exception as e: raise RuntimeError(f"Failed writing output PDF to {out_path}: {e}") from e # create a preview from the recognized text (limit length) full_text = "\n\n".join(text_parts).strip() preview = full_text[:1000] + ("…" if len(full_text) > 1000 else "") mark_job_as_completed(db, job_id, output_filepath_str=str(out_path), preview=preview) update_job_status(db, job_id, "completed", progress=100) logger.info(f"Image OCR for job {job_id} completed. Output: {out_path}") except TesseractNotFoundError: logger.exception(f"Tesseract not found for job {job_id}") update_job_status(db, job_id, "failed", error="Image OCR failed: Tesseract not found on server.") except Exception as e: logger.exception(f"ERROR during Image OCR for job {job_id}: {e}") update_job_status(db, job_id, "failed", error=f"Image OCR failed: {e}") finally: # cleanup input file (but only if it lives in allowed uploads dir) try: ensure_path_is_safe(input_path, [PATHS.UPLOADS_DIR]) input_path.unlink(missing_ok=True) except Exception: logger.exception("Failed to cleanup input file after Image OCR.") try: db.close() except Exception: logger.exception("Failed to close DB session after Image OCR.") # send webhook regardless of success/failure (keeps original behavior) try: send_webhook_notification(job_id, app_config, base_url) except Exception: logger.exception("Failed to send webhook notification after Image OCR.") @huey.task() def run_conversion_task(job_id: str, input_path_str: str, output_path_str: str, tool: str, task_key: str, conversion_tools_config: dict, app_config: dict, base_url: str): """ Drop-in replacement for conversion task. - Uses improved run_command for short operations (resource-limited). - Uses cancellable Popen runner for long-running conversion to respond to DB cancellations. """ db = SessionLocal() input_path = Path(input_path_str) output_path = Path(output_path_str) # localize helpers for speed _get_job = get_job _update_job_status = update_job_status _validate_build = validate_and_build_command _mark_completed = mark_job_as_completed _ensure_safe = ensure_path_is_safe _send_webhook = send_webhook_notification temp_input_file: Optional[Path] = None temp_output_file: Optional[Path] = None POLL_INTERVAL = 1.0 STDERR_SNIPPET = 4000 def _parse_task_key(tool_name: str, tk: str, tool_cfg: dict, mapping: dict): try: if tool_name.startswith("ghostscript"): parts = tk.split("_", 1) device = parts[0] if parts and parts[0] else "" setting = parts[1] if len(parts) > 1 else "" mapping.update({"device": device, "dpi": setting, "preset": setting}) elif tool_name == "pngquant": parts = tk.split("_", 1) quality_key = parts[1] if len(parts) > 1 else (parts[0] if parts else "mq") quality_map = {"hq": "80-95", "mq": "65-80", "fast": "65-80"} speed_map = {"hq": "1", "mq": "3", "fast": "11"} mapping.update({"quality": quality_map.get(quality_key, "65-80"), "speed": speed_map.get(quality_key, "3")}) elif tool_name == "sox": parts = tk.split("_") if len(parts) >= 3: rate_token = parts[-2] depth_token = parts[-1] elif len(parts) == 2: rate_token = parts[-1] depth_token = "" else: rate_token = "" depth_token = "" rate_val = rate_token.replace("k", "000") if rate_token else "" if depth_token: depth_val = ('-b' + depth_token.replace('b', '')) if 'b' in depth_token else depth_token else: depth_val = '' mapping.update({"samplerate": rate_val, "bitdepth": depth_val}) elif tool_name == "mozjpeg": parts = tk.split("_", 1) quality_token = parts[1] if len(parts) > 1 else (parts[0] if parts else "") quality = quality_token.replace("q", "") if quality_token else "" mapping.update({"quality": quality}) elif tool_name == "libreoffice": target_ext = output_path.suffix.lstrip('.') filter_val = tool_cfg.get("filters", {}).get(target_ext, target_ext) mapping["filter"] = filter_val except Exception: logger.exception("Failed to parse task_key for tool %s; continuing with defaults.", tool_name) def _run_cancellable_command(command: List[str], timeout: int): """ Run command with Popen and poll the DB for cancellation. Enforce timeout. Returns CompletedProcess-like on success. Raises Exception on failure/timeout/cancel. """ preexec = globals().get("_limit_resources_preexec", None) logger.debug("Launching conversion subprocess: %s", " ".join(shlex.quote(c) for c in command)) proc = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, preexec_fn=preexec) start = time.monotonic() stderr_accum = [] stderr_len = 0 STDERR_LIMIT = STDERR_SNIPPET try: while True: ret = proc.poll() # Check job status job_check = _get_job(db, job_id) if job_check is None: logger.warning("Job %s disappeared; killing conversion process.", job_id) try: proc.kill() except Exception: pass raise Exception("Job disappeared during conversion") if job_check.status == "cancelled": logger.info("Job %s cancelled; terminating conversion process.", job_id) try: proc.kill() except Exception: pass raise Exception("Conversion cancelled") if ret is not None: # process done - read remaining stderr/stdout safely try: out, err = proc.communicate(timeout=2) except Exception: out, err = "", "" try: if proc.stderr: err = proc.stderr.read(STDERR_LIMIT) except Exception: pass if err and len(err) > STDERR_LIMIT: err = err[-STDERR_LIMIT:] if ret != 0: msg = (err or "")[:STDERR_LIMIT] raise Exception(f"Conversion command failed (rc={ret}): {msg}") return subprocess.CompletedProcess(args=command, returncode=ret, stdout=out, stderr=err) # timeout check elapsed = time.monotonic() - start if timeout and elapsed > timeout: logger.warning("Conversion command timed out after %ss; terminating.", timeout) try: proc.kill() except Exception: pass raise Exception("Conversion command timed out") time.sleep(POLL_INTERVAL) finally: try: if proc.stdout: proc.stdout.close() if proc.stderr: proc.stderr.close() except Exception: pass try: job = _get_job(db, job_id) if not job: logger.warning("Job %s not found; aborting conversion.", job_id) return if job.status == "cancelled": logger.info("Job %s already cancelled; aborting conversion.", job_id) return _update_job_status(db, job_id, "processing", progress=25) logger.info("Starting conversion for job %s using %s with task %s", job_id, tool, task_key) tool_config = conversion_tools_config.get(tool) if not tool_config: raise ValueError(f"Unknown conversion tool: {tool}") current_input_path = input_path # Pre-conversion step for mozjpeg uses improved run_command (resource-limited) if tool == "mozjpeg": temp_input_file = input_path.with_suffix('.temp.ppm') logger.info("Pre-converting for MozJPEG: %s -> %s", input_path, temp_input_file) vips_bin = shutil.which("vips") or "vips" pre_conv_cmd = [vips_bin, "copy", str(input_path), str(temp_input_file)] try: run_command(pre_conv_cmd, timeout=int(tool_config.get("timeout", 300))) except Exception as ex: err_msg = str(ex) short_err = (err_msg or "")[:STDERR_SNIPPET] logger.exception("MozJPEG pre-conversion failed: %s", short_err) raise Exception(f"MozJPEG pre-conversion to PPM failed: {short_err}") current_input_path = temp_input_file _update_job_status(db, job_id, "processing", progress=50) # Prepare atomic temp output on same FS output_path.parent.mkdir(parents=True, exist_ok=True) temp_output_file = output_path.with_name(f"{output_path.stem}.tmp-{uuid.uuid4().hex}{output_path.suffix}") mapping = { "input": str(current_input_path), "output": str(temp_output_file), "output_dir": str(output_path.parent), "output_ext": output_path.suffix.lstrip('.'), } _parse_task_key(tool, task_key, tool_config, mapping) command_template_str = tool_config.get("command_template") if not command_template_str: raise ValueError(f"Tool '{tool}' missing 'command_template' in configuration.") command = _validate_build(command_template_str, mapping) if not isinstance(command, (list, tuple)) or not command: raise ValueError("validate_and_build_command must return a non-empty list/tuple command.") command = [str(x) for x in command] logger.info("Executing command: %s", " ".join(shlex.quote(c) for c in command)) # Run main conversion in cancellable manner timeout_val = int(tool_config.get("timeout", 300)) # call the cancellable runner above result = _run_cancellable_command(command, timeout=timeout_val) if False else None # the above is replaced with a direct call to the actual function: result = _run_cancellable_command(command, timeout=timeout_val) # If successful and temp output exists, move it into place atomically if temp_output_file and temp_output_file.exists(): temp_output_file.replace(output_path) _mark_completed(db, job_id, output_filepath_str=str(output_path), preview="Successfully converted file.") logger.info("Conversion for job %s completed.", job_id) except Exception as e: logger.exception("ERROR during conversion for job %s: %s", job_id, e) try: _update_job_status(db, job_id, "failed", error=f"Conversion failed: {e}") except Exception: logger.exception("Failed to update job status to failed after conversion error.") finally: # clean main input try: _ensure_safe(input_path, [PATHS.UPLOADS_DIR, PATHS.CHUNK_TMP_DIR]) input_path.unlink(missing_ok=True) except Exception: logger.exception("Failed to cleanup main input file after conversion.") # cleanup temp input if temp_input_file: try: temp_input_file_path = Path(temp_input_file) _ensure_safe(temp_input_file_path, [PATHS.UPLOADS_DIR, PATHS.PROCESSED_DIR]) temp_input_file_path.unlink(missing_ok=True) except Exception: logger.exception("Failed to cleanup temp input file after conversion.") if temp_output_file: try: temp_output_file_path = Path(temp_output_file) _ensure_safe(temp_output_file_path, [PATHS.UPLOADS_DIR, PATHS.PROCESSED_DIR]) temp_output_file_path.unlink(missing_ok=True) except Exception: logger.exception("Failed to cleanup temp output file after conversion.") try: db.close() except Exception: logger.exception("Failed to close DB session after conversion.") try: gc.collect() except Exception: pass try: _send_webhook(job_id, app_config, base_url) except Exception: logger.exception("Failed to send webhook notification after conversion.") def dispatch_single_file_job(original_filename: str, input_filepath: str, task_type: str, user: dict, db: Session, app_config: Dict, base_url: str, job_id: str | None = None, options: Dict = None, parent_job_id: str | None = None): """Helper to create and dispatch a job for a single file.""" if options is None: options = {} # If no job_id is passed, generate one. This is for sub-tasks from zips. if job_id is None: job_id = uuid.uuid4().hex safe_filename = secure_filename(original_filename) final_path = Path(input_filepath) # Ensure the input file exists before creating a job if not final_path.exists(): logger.error(f"Input file does not exist, cannot dispatch job: {input_filepath}") return job_data = JobCreate( id=job_id, user_id=user['sub'], task_type=task_type, original_filename=original_filename, input_filepath=str(final_path), input_filesize=final_path.stat().st_size, parent_job_id=parent_job_id ) if task_type == "transcription": stem = Path(safe_filename).stem processed_path = PATHS.PROCESSED_DIR / f"{stem}_{job_id}.txt" job_data.processed_filepath = str(processed_path) create_job(db=db, job=job_data) run_transcription_task(job_data.id, str(final_path), str(processed_path), options.get("model_size", "base"), app_config.get("transcription_settings", {}).get("whisper", {}), app_config, base_url) elif task_type == "tts": tts_config = app_config.get("tts_settings", {}) stem = Path(safe_filename).stem processed_path = PATHS.PROCESSED_DIR / f"{stem}_{job_id}.wav" job_data.processed_filepath = str(processed_path) create_job(db=db, job=job_data) run_tts_task(job_data.id, str(final_path), str(processed_path), options.get("model_name"), tts_config, app_config, base_url) elif task_type == "ocr": stem, suffix = Path(safe_filename).stem, Path(safe_filename).suffix.lower() IMAGE_EXTENSIONS = {'.png', '.jpg', '.jpeg', '.tiff', '.tif', '.bmp', '.webp'} if suffix not in IMAGE_EXTENSIONS and suffix != '.pdf': logger.warning(f"Skipping unsupported file type for OCR: {original_filename}") # Clean up the orphaned file from the zip extraction final_path.unlink(missing_ok=True) return processed_path = PATHS.PROCESSED_DIR / f"{stem}_{job_id}.pdf" job_data.processed_filepath = str(processed_path) create_job(db=db, job=job_data) if suffix in IMAGE_EXTENSIONS: run_image_ocr_task(job_data.id, str(final_path), str(processed_path), app_config, base_url) else: run_pdf_ocr_task(job_data.id, str(final_path), str(processed_path), app_config.get("ocr_settings", {}).get("ocrmypdf", {}), app_config, base_url) elif task_type == "conversion": try: tool, task_key = options.get("output_format").split('_', 1) except (AttributeError, ValueError): logger.error(f"Invalid or missing output_format for conversion of {original_filename}") final_path.unlink(missing_ok=True) return original_stem = Path(safe_filename).stem target_ext = task_key.split('_')[0] if tool == "ghostscript_pdf": target_ext = "pdf" processed_path = PATHS.PROCESSED_DIR / f"{original_stem}_{job_id}.{target_ext}" job_data.processed_filepath = str(processed_path) create_job(db=db, job=job_data) run_conversion_task(job_data.id, str(final_path), str(processed_path), tool, task_key, app_config.get("conversion_tools", {}), app_config, base_url) else: logger.error(f"Invalid task type '{task_type}' for file {original_filename}") final_path.unlink(missing_ok=True) @huey.task() def unzip_and_dispatch_task(job_id: str, input_path_str: str, sub_task_type: str, sub_task_options: dict, user: dict, app_config: dict, base_url: str): db = SessionLocal() input_path = Path(input_path_str) unzip_dir = PATHS.UPLOADS_DIR / f"unzipped_{job_id}" try: if not zipfile.is_zipfile(input_path): raise ValueError("Uploaded file is not a valid ZIP archive.") unzip_dir.mkdir() with zipfile.ZipFile(input_path, 'r') as zip_ref: zip_ref.extractall(unzip_dir) file_count = 0 for extracted_file_path in unzip_dir.rglob('*'): if extracted_file_path.is_file(): file_count += 1 dispatch_single_file_job( original_filename=extracted_file_path.name, input_filepath=str(extracted_file_path), task_type=sub_task_type, options=sub_task_options, user=user, db=db, app_config=app_config, base_url=base_url, parent_job_id=job_id ) if file_count > 0: # Mark parent job as processing, to be completed by the periodic task update_job_status(db, job_id, "processing", progress=0) else: # No files found, mark as completed with a note mark_job_as_completed(db, job_id, preview="ZIP archive was empty. No sub-jobs created.") except Exception as e: logger.exception(f"ERROR during ZIP processing for job {job_id}") update_job_status(db, job_id, "failed", error=f"Failed to process ZIP file: {e}") # If unzipping fails, clean up the directory if unzip_dir.exists(): shutil.rmtree(unzip_dir) finally: try: # CRITICAL FIX: Only delete the original ZIP file. # Do NOT delete the unzip_dir here, as the sub-tasks need the files. ensure_path_is_safe(input_path, [PATHS.UPLOADS_DIR, PATHS.CHUNK_TMP_DIR]) input_path.unlink(missing_ok=True) except Exception: logger.exception("Failed to cleanup original ZIP file.") db.close() @huey.periodic_task(crontab(minute='*/1')) # Runs every 1 minutes def update_unzip_job_progress(): """Periodically checks and updates the progress of parent 'unzip' jobs.""" db = SessionLocal() try: # Find all 'unzip' jobs that are still marked as 'processing' parent_jobs_to_check = db.query(Job).filter( Job.task_type == 'unzip', Job.status == 'processing' ).all() if not parent_jobs_to_check: return # Nothing to do logger.info(f"Checking progress for {len(parent_jobs_to_check)} active batch jobs.") for parent_job in parent_jobs_to_check: # Find all children of this parent job child_jobs = db.query(Job).filter(Job.parent_job_id == parent_job.id).all() total_children = len(child_jobs) if total_children == 0: # This case shouldn't happen if unzip_and_dispatch_task works, but as a safeguard: mark_job_as_completed(db, parent_job.id, preview="Batch job completed with no sub-tasks.") continue finished_children = 0 for child in child_jobs: if child.status in ['completed', 'failed', 'cancelled']: finished_children += 1 # Calculate and update progress progress = int((finished_children / total_children) * 100) if total_children > 0 else 100 if finished_children == total_children: # All children are done, mark the parent as completed failed_count = sum(1 for child in child_jobs if child.status == 'failed') preview = f"Batch processing complete. {total_children - failed_count}/{total_children} tasks succeeded." if failed_count > 0: preview += f" ({failed_count} failed)." mark_job_as_completed(db, parent_job.id, preview=preview) logger.info(f"Batch job {parent_job.id} marked as completed.") else: # Update the progress if it has changed if parent_job.progress != progress: update_job_status(db, parent_job.id, 'processing', progress=progress) except Exception as e: logger.exception(f"Error in periodic task update_unzip_job_progress: {e}") finally: db.close() # -------------------------------------------------------------------------------- # --- 5. FASTAPI APPLICATION # -------------------------------------------------------------------------------- async def download_kokoro_models_if_missing(): """Checks for Kokoro TTS model files and downloads them if they don't exist.""" files_to_download = { "model": {"path": PATHS.KOKORO_MODEL_FILE, "url": "https://github.com/nazdridoy/kokoro-tts/releases/download/v1.0.0/kokoro-v1.0.onnx"}, "voices": {"path": PATHS.KOKORO_VOICES_FILE, "url": "https://github.com/nazdridoy/kokoro-tts/releases/download/v1.0.0/voices-v1.0.bin"} } async with httpx.AsyncClient() as client: for name, details in files_to_download.items(): path, url = details["path"], details["url"] if not path.exists(): logger.info(f"Kokoro TTS {name} file missing. Downloading from {url}...") try: with path.open("wb") as f: async with client.stream("GET", url, follow_redirects=True, timeout=300) as response: response.raise_for_status() async for chunk in response.aiter_bytes(): f.write(chunk) logger.info(f"Successfully downloaded Kokoro TTS {name} file to {path}.") except Exception as e: logger.error(f"Failed to download Kokoro TTS {name} file: {e}") if path.exists(): path.unlink(missing_ok=True) else: logger.info(f"Found existing Kokoro TTS {name} file at {path}.") @asynccontextmanager async def lifespan(app: FastAPI): logger.info("Application starting up...") # Base.metadata.create_all(bind=engine) create_attempts = 3 for attempt in range(1, create_attempts + 1): try: # use engine.begin() to ensure the DDL runs in a connection/transaction context with engine.begin() as conn: Base.metadata.create_all(bind=conn) logger.info("Database tables ensured (create_all succeeded).") break except OperationalError as oe: # Some SQLite drivers raise an OperationalError when two processes try to create the same table at once. msg = str(oe).lower() # If we see "already exists" we treat this as a race and retry briefly. if "already exists" in msg or ("table" in msg and "already exists" in msg): logger.warning( "Database table creation race detected (attempt %d/%d): %s. Retrying...", attempt, create_attempts, oe, ) time.sleep(0.5) continue else: logger.exception("Database initialization failed with OperationalError.") raise except Exception: logger.exception("Unexpected error during DB initialization.") raise load_app_config() # Download required models on startup if shutil.which("kokoro-tts"): await download_kokoro_models_if_missing() if PiperVoice is None: logger.warning("piper-tts is not installed. Piper TTS features will be disabled. Install with: pip install piper-tts") if not shutil.which("kokoro-tts"): logger.warning("kokoro-tts command not found in PATH. Kokoro TTS features will be disabled.") ENV = os.environ.get('ENV', 'dev').lower() ALLOW_LOCAL_ONLY = os.environ.get('ALLOW_LOCAL_ONLY', 'false').lower() == 'true' if LOCAL_ONLY_MODE and ENV != 'dev' and not ALLOW_LOCAL_ONLY: raise RuntimeError('LOCAL_ONLY_MODE may only be enabled in dev or when ALLOW_LOCAL_ONLY=true is set.') if not LOCAL_ONLY_MODE: oidc_cfg = APP_CONFIG.get('auth_settings', {}) if not all(oidc_cfg.get(k) for k in ['oidc_client_id', 'oidc_client_secret', 'oidc_server_metadata_url']): logger.error("OIDC auth settings are incomplete. Auth will be disabled if not in LOCAL_ONLY_MODE.") else: oauth.register( name='oidc', client_id=oidc_cfg.get('oidc_client_id'), client_secret=oidc_cfg.get('oidc_client_secret'), server_metadata_url=oidc_cfg.get('oidc_server_metadata_url'), client_kwargs={'scope': 'openid email profile'}, userinfo_endpoint=oidc_cfg.get('oidc_userinfo_endpoint'), end_session_endpoint=oidc_cfg.get('oidc_end_session_endpoint') ) logger.info('OAuth registered.') yield logger.info('Application shutting down...') app = FastAPI(lifespan=lifespan) ENV = os.environ.get('ENV', 'dev').lower() SECRET_KEY = os.environ.get('SECRET_KEY') if not SECRET_KEY and not LOCAL_ONLY_MODE and ENV != 'dev': raise RuntimeError('SECRET_KEY must be set in production when authentication is enabled.') if not SECRET_KEY: logger.warning('SECRET_KEY is not set. Generating a temporary key. Sessions will not persist across restarts.') SECRET_KEY = os.urandom(24).hex() app.add_middleware( SessionMiddleware, secret_key=SECRET_KEY, https_only=False, # Set to True if behind HTTPS proxy same_site='lax', max_age=14 * 24 * 60 * 60 # 14 days ) # Static / templates app.mount("/static", StaticFiles(directory=str(PATHS.BASE_DIR / "static")), name="static") templates = Jinja2Templates(directory=str(PATHS.BASE_DIR / "templates")) # --- AUTH & USER HELPERS --- http_bearer = HTTPBearer() def get_current_user(request: Request): if LOCAL_ONLY_MODE: return {'sub': 'local_user', 'email': 'local@user.com', 'name': 'Local User'} return request.session.get('user') async def require_api_user(request: Request, creds: HTTPAuthorizationCredentials = Depends(http_bearer)): """Dependency for API routes requiring OIDC bearer token authentication.""" if LOCAL_ONLY_MODE: return {'sub': 'local_api_user', 'email': 'local@api.user.com', 'name': 'Local API User'} if not creds: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Not authenticated") token = creds.credentials try: user = await oauth.oidc.userinfo(token={'access_token': token}) return dict(user) except Exception as e: logger.error(f"API token validation failed: {e}") raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid or expired token") def is_admin(request: Request) -> bool: if LOCAL_ONLY_MODE: return True user = get_current_user(request) if not user: return False admin_users = APP_CONFIG.get("auth_settings", {}).get("admin_users", []) return user.get('email') in admin_users def require_user(request: Request): user = get_current_user(request) if not user: raise HTTPException(status_code=401, detail="Not authenticated") return user def require_admin(request: Request): if not is_admin(request): raise HTTPException(status_code=403, detail="Administrator privileges required.") return True # --- FILE SAVING UTILITY --- async def save_upload_file(upload_file: UploadFile, destination: Path) -> int: """ Saves an uploaded file to a destination, handling size limits. This function is used by both the simple API and the legacy direct-upload routes. """ max_size = APP_CONFIG.get("app_settings", {}).get("max_file_size_bytes", 100 * 1024 * 1024) tmp_path = destination.with_name(f"{destination.stem}.tmp-{uuid.uuid4().hex}{destination.suffix}") size = 0 try: with tmp_path.open("wb") as buffer: while True: chunk = await upload_file.read(1024 * 1024) if not chunk: break size += len(chunk) if size > max_size: raise HTTPException(status_code=413, detail=f"File exceeds {max_size / 1024 / 1024} MB limit") buffer.write(chunk) tmp_path.replace(destination) return size except Exception as e: try: # Ensure temp file is cleaned up on error ensure_path_is_safe(tmp_path, [PATHS.UPLOADS_DIR, PATHS.CHUNK_TMP_DIR]) tmp_path.unlink(missing_ok=True) except Exception: logger.exception("Failed to remove temp upload file after error.") # Re-raise the original exception raise e finally: await upload_file.close() def is_allowed_file(filename: str, allowed_extensions: set) -> bool: if not allowed_extensions: # If set is empty, allow all return True return Path(filename).suffix.lower() in allowed_extensions # --- CHUNKED UPLOADS (for UI) --- @app.post("/upload/chunk") async def upload_chunk( chunk: UploadFile = File(...), upload_id: str = Form(...), chunk_number: int = Form(...), user: dict = Depends(require_user) ): safe_upload_id = secure_filename(upload_id) temp_dir = ensure_path_is_safe(PATHS.CHUNK_TMP_DIR / safe_upload_id, [PATHS.CHUNK_TMP_DIR]) temp_dir.mkdir(exist_ok=True) chunk_path = temp_dir / f"{chunk_number}.chunk" try: with open(chunk_path, "wb") as buffer: shutil.copyfileobj(chunk.file, buffer) finally: chunk.file.close() return JSONResponse({"message": f"Chunk {chunk_number} for {safe_upload_id} uploaded."}) async def _stitch_chunks(temp_dir: Path, final_path: Path, total_chunks: int): """Stitches chunks together and cleans up.""" ensure_path_is_safe(temp_dir, [PATHS.CHUNK_TMP_DIR]) ensure_path_is_safe(final_path, [PATHS.UPLOADS_DIR]) with open(final_path, "wb") as final_file: for i in range(total_chunks): chunk_path = temp_dir / f"{i}.chunk" if not chunk_path.exists(): shutil.rmtree(temp_dir, ignore_errors=True) raise HTTPException(status_code=400, detail=f"Upload failed: missing chunk {i}") with open(chunk_path, "rb") as chunk_file: final_file.write(chunk_file.read()) shutil.rmtree(temp_dir, ignore_errors=True) @app.post("/upload/finalize", status_code=status.HTTP_202_ACCEPTED) async def finalize_upload(request: Request, payload: FinalizeUploadPayload, user: dict = Depends(require_user), db: Session = Depends(get_db)): safe_upload_id = secure_filename(payload.upload_id) temp_dir = ensure_path_is_safe(PATHS.CHUNK_TMP_DIR / safe_upload_id, [PATHS.CHUNK_TMP_DIR]) if not temp_dir.is_dir(): raise HTTPException(status_code=404, detail="Upload session not found or already finalized.") webhook_config = APP_CONFIG.get("webhook_settings", {}) if payload.callback_url and not is_allowed_callback_url(payload.callback_url, webhook_config.get("allowed_callback_urls", [])): raise HTTPException(status_code=400, detail="Provided callback_url is not allowed.") job_id = uuid.uuid4().hex safe_filename = secure_filename(payload.original_filename) final_path = PATHS.UPLOADS_DIR / f"{Path(safe_filename).stem}_{job_id}{Path(safe_filename).suffix}" await _stitch_chunks(temp_dir, final_path, payload.total_chunks) base_url = str(request.base_url) if Path(safe_filename).suffix.lower() == '.zip': job_data = JobCreate( id=job_id, user_id=user['sub'], task_type="unzip", original_filename=payload.original_filename, input_filepath=str(final_path), input_filesize=final_path.stat().st_size ) create_job(db=db, job=job_data) sub_task_options = { "model_size": payload.model_size, "model_name": payload.model_name, "output_format": payload.output_format } unzip_and_dispatch_task(job_id, str(final_path), payload.task_type, sub_task_options, user, APP_CONFIG, base_url) else: options = {"model_size": payload.model_size, "model_name": payload.model_name, "output_format": payload.output_format} dispatch_single_file_job(payload.original_filename, str(final_path), payload.task_type, user, db, APP_CONFIG, base_url, job_id=job_id, options=options) return {"job_id": job_id, "status": "pending"} # --- LEGACY DIRECT-UPLOAD ROUTES (kept for compatibility) --- @app.post("/transcribe-audio", status_code=status.HTTP_202_ACCEPTED) async def submit_audio_transcription( request: Request, file: UploadFile = File(...), model_size: str = Form("base"), db: Session = Depends(get_db), user: dict = Depends(require_user) ): allowed_audio_exts = {".mp3", ".wav", ".m4a", ".flac", ".ogg", ".opus"} if not is_allowed_file(file.filename, allowed_audio_exts): raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid audio file type.") whisper_config = APP_CONFIG.get("transcription_settings", {}).get("whisper", {}) if model_size not in whisper_config.get("allowed_models", []): raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid model size: {model_size}.") job_id, safe_basename = uuid.uuid4().hex, secure_filename(file.filename) stem, suffix = Path(safe_basename).stem, Path(safe_basename).suffix upload_path = PATHS.UPLOADS_DIR / f"{stem}_{job_id}{suffix}" processed_path = PATHS.PROCESSED_DIR / f"{stem}_{job_id}.txt" input_size = await save_upload_file(file, upload_path) base_url = str(request.base_url) job_data = JobCreate(id=job_id, user_id=user['sub'], task_type="transcription", original_filename=file.filename, input_filepath=str(upload_path), input_filesize=input_size, processed_filepath=str(processed_path)) new_job = create_job(db=db, job=job_data) run_transcription_task(new_job.id, str(upload_path), str(processed_path), model_size, whisper_settings=whisper_config, app_config=APP_CONFIG, base_url=base_url) return {"job_id": new_job.id, "status": new_job.status, "status_url": f"/job/{new_job.id}"} @app.post("/convert-file", status_code=status.HTTP_202_ACCEPTED) async def submit_file_conversion(request: Request, file: UploadFile = File(...), output_format: str = Form(...), db: Session = Depends(get_db), user: dict = Depends(require_user)): allowed_exts = APP_CONFIG.get("app_settings", {}).get("allowed_all_extensions", set()) if not is_allowed_file(file.filename, allowed_exts): raise HTTPException(status_code=400, detail=f"File type '{Path(file.filename).suffix}' not allowed.") conversion_tools = APP_CONFIG.get("conversion_tools", {}) try: tool, task_key = output_format.split('_', 1) if tool not in conversion_tools: raise ValueError() except ValueError: raise HTTPException(status_code=400, detail="Invalid output format selected.") job_id, safe_basename = uuid.uuid4().hex, secure_filename(file.filename) original_stem = Path(safe_basename).stem target_ext = task_key.split('_')[0] if tool == "ghostscript_pdf": target_ext = "pdf" upload_path = PATHS.UPLOADS_DIR / f"{original_stem}_{job_id}{Path(safe_basename).suffix}" processed_path = PATHS.PROCESSED_DIR / f"{original_stem}_{job_id}.{target_ext}" input_size = await save_upload_file(file, upload_path) base_url = str(request.base_url) job_data = JobCreate(id=job_id, user_id=user['sub'], task_type="conversion", original_filename=file.filename, input_filepath=str(upload_path), input_filesize=input_size, processed_filepath=str(processed_path)) new_job = create_job(db=db, job=job_data) run_conversion_task(new_job.id, str(upload_path), str(processed_path), tool, task_key, conversion_tools, APP_CONFIG, base_url) return {"job_id": new_job.id, "status": new_job.status, "status_url": f"/job/{new_job.id}"} @app.post("/ocr-pdf", status_code=status.HTTP_202_ACCEPTED) async def submit_pdf_ocr(request: Request, file: UploadFile = File(...), db: Session = Depends(get_db), user: dict = Depends(require_user)): if not is_allowed_file(file.filename, {".pdf"}): raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid file type. Please upload a PDF.") job_id, safe_basename = uuid.uuid4().hex, secure_filename(file.filename) unique_filename = f"{Path(safe_basename).stem}_{job_id}{Path(safe_basename).suffix}" upload_path = PATHS.UPLOADS_DIR / unique_filename processed_path = PATHS.PROCESSED_DIR / unique_filename input_size = await save_upload_file(file, upload_path) base_url = str(request.base_url) job_data = JobCreate(id=job_id, user_id=user['sub'], task_type="ocr", original_filename=file.filename, input_filepath=str(upload_path), input_filesize=input_size, processed_filepath=str(processed_path)) new_job = create_job(db=db, job=job_data) ocr_settings = APP_CONFIG.get("ocr_settings", {}).get("ocrmypdf", {}) run_pdf_ocr_task(new_job.id, str(upload_path), str(processed_path), ocr_settings, APP_CONFIG, base_url) return {"job_id": new_job.id, "status": new_job.status, "status_url": f"/job/{new_job.id}"} @app.post("/ocr-image", status_code=status.HTTP_202_ACCEPTED) async def submit_image_ocr(request: Request, file: UploadFile = File(...), db: Session = Depends(get_db), user: dict = Depends(require_user)): allowed_exts = {".png", ".jpg", ".jpeg", ".tiff", ".tif"} if not is_allowed_file(file.filename, allowed_exts): raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid file type. Please upload a PNG, JPG, or TIFF.") job_id, safe_basename = uuid.uuid4().hex, secure_filename(file.filename) file_ext = Path(safe_basename).suffix unique_filename = f"{Path(safe_basename).stem}_{job_id}{file_ext}" upload_path = PATHS.UPLOADS_DIR / unique_filename processed_path = PATHS.PROCESSED_DIR / f"{Path(safe_basename).stem}_{job_id}.pdf" input_size = await save_upload_file(file, upload_path) base_url = str(request.base_url) job_data = JobCreate(id=job_id, user_id=user['sub'], task_type="ocr-image", original_filename=file.filename, input_filepath=str(upload_path), input_filesize=input_size, processed_filepath=str(processed_path)) new_job = create_job(db=db, job=job_data) run_image_ocr_task(new_job.id, str(upload_path), str(processed_path), APP_CONFIG, base_url) return {"job_id": new_job.id, "status": new_job.status, "status_url": f"/job/{new_job.id}"} # -------------------------------------------------------------------------------- # --- API V1 ROUTES (for programmatic access) # -------------------------------------------------------------------------------- def is_allowed_callback_url(url: str, allowed: List[str]) -> bool: if not allowed: return False try: parsed = urlparse(url) if not parsed.scheme or not parsed.netloc: return False for a in allowed: ap = urlparse(a) if ap.scheme and ap.netloc: if parsed.scheme == ap.scheme and parsed.netloc == ap.netloc: return True else: # support legacy prefix entries - keep fallback if url.startswith(a): return True return False except Exception: return False @app.get("/api/v1/tts-voices") async def get_tts_voices_list(user: dict = Depends(require_user)): global AVAILABLE_TTS_VOICES_CACHE kokoro_available = shutil.which("kokoro-tts") is not None piper_available = PiperVoice is not None if not piper_available and not kokoro_available: return JSONResponse(content={"error": "TTS feature not configured on server (no TTS engines found)."}, status_code=501) if AVAILABLE_TTS_VOICES_CACHE: return AVAILABLE_TTS_VOICES_CACHE all_voices = [] try: if piper_available: logger.info("Fetching available Piper voices list...") piper_voices = safe_get_voices(PATHS.TTS_MODELS_DIR) for voice in piper_voices: voice['id'] = f"piper/{voice.get('id')}" voice['name'] = f"Piper: {voice.get('name', voice.get('id'))}" all_voices.extend(piper_voices) if kokoro_available: logger.info("Fetching available Kokoro TTS voices and languages...") kokoro_voices = list_kokoro_voices_cli() kokoro_langs = list_kokoro_languages_cli() for lang in kokoro_langs: for voice in kokoro_voices: all_voices.append({ "id": f"kokoro/{lang}/{voice}", "name": f"Kokoro ({lang}): {voice}", "local": False }) AVAILABLE_TTS_VOICES_CACHE = sorted(all_voices, key=lambda x: x['name']) return AVAILABLE_TTS_VOICES_CACHE except Exception as e: logger.exception("Could not fetch list of TTS voices.") raise HTTPException(status_code=500, detail=f"Could not retrieve voices list: {e}") # --- Standard API endpoint (non-chunked) --- @app.post("/api/v1/process", status_code=status.HTTP_202_ACCEPTED, tags=["Webhook API"]) async def api_process_file( request: Request, file: UploadFile = File(...), task_type: str = Form(...), callback_url: str = Form(...), model_size: Optional[str] = Form("base"), model_name: Optional[str] = Form(None), output_format: Optional[str] = Form(None), db: Session = Depends(get_db), user: dict = Depends(require_api_user) ): """ Programmatically submit a file for processing via a single HTTP request. This is the recommended endpoint for services like n8n. Requires bearer token authentication unless in LOCAL_ONLY_MODE. """ webhook_config = APP_CONFIG.get("webhook_settings", {}) if not webhook_config.get("enabled", False): raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Webhook processing is disabled on the server.") if not is_allowed_callback_url(callback_url, webhook_config.get("allowed_callback_urls", [])): logger.warning(f"Rejected webhook from user '{user.get('email')}' with disallowed callback URL: {callback_url}") raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Provided callback_url is not in the list of allowed URLs.") job_id = uuid.uuid4().hex safe_basename = secure_filename(file.filename) stem, suffix = Path(safe_basename).stem, Path(safe_basename).suffix upload_filename = f"{stem}_{job_id}{suffix}" upload_path = PATHS.UPLOADS_DIR / upload_filename try: input_size = await save_upload_file(file, upload_path) except HTTPException as e: raise e # Re-raise exceptions from save_upload_file (e.g., file too large) except Exception as e: logger.exception("Failed to save uploaded file for webhook processing.") raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to save file: {e}") base_url = str(request.base_url) job_data_args = { "id": job_id, "user_id": user['sub'], "original_filename": file.filename, "input_filepath": str(upload_path), "input_filesize": input_size, "callback_url": callback_url, "task_type": task_type, } # --- API Task Dispatching Logic --- if task_type == "transcription": whisper_config = APP_CONFIG.get("transcription_settings", {}).get("whisper", {}) if model_size not in whisper_config.get("allowed_models", []): raise HTTPException(status_code=400, detail=f"Invalid model_size '{model_size}'") processed_path = PATHS.PROCESSED_DIR / f"{stem}_{job_id}.txt" job_data_args["processed_filepath"] = str(processed_path) create_job(db=db, job=JobCreate(**job_data_args)) run_transcription_task(job_id, str(upload_path), str(processed_path), model_size, whisper_config, APP_CONFIG, base_url) elif task_type == "tts": if not is_allowed_file(file.filename, {".txt"}): raise HTTPException(status_code=400, detail="Invalid file type for TTS, requires .txt") if not model_name: raise HTTPException(status_code=400, detail="model_name is required for TTS task.") tts_config = APP_CONFIG.get("tts_settings", {}) processed_path = PATHS.PROCESSED_DIR / f"{stem}_{job_id}.wav" job_data_args["processed_filepath"] = str(processed_path) create_job(db=db, job=JobCreate(**job_data_args)) run_tts_task(job_id, str(upload_path), str(processed_path), model_name, tts_config, APP_CONFIG, base_url) elif task_type == "conversion": if not output_format: raise HTTPException(status_code=400, detail="output_format is required for conversion task.") conversion_tools = APP_CONFIG.get("conversion_tools", {}) try: tool, task_key = output_format.split('_', 1) if tool not in conversion_tools: raise ValueError("Invalid tool") except ValueError: raise HTTPException(status_code=400, detail="Invalid output_format selected.") target_ext = task_key.split('_')[0] if tool == "ghostscript_pdf": target_ext = "pdf" processed_path = PATHS.PROCESSED_DIR / f"{stem}_{job_id}.{target_ext}" job_data_args["processed_filepath"] = str(processed_path) create_job(db=db, job=JobCreate(**job_data_args)) run_conversion_task(job_id, str(upload_path), str(processed_path), tool, task_key, conversion_tools, APP_CONFIG, base_url) elif task_type == "ocr": if not is_allowed_file(file.filename, {".pdf"}): raise HTTPException(status_code=400, detail="Invalid file type for ocr, requires .pdf") processed_path = PATHS.PROCESSED_DIR / f"{stem}_{job_id}{suffix}" job_data_args["processed_filepath"] = str(processed_path) create_job(db=db, job=JobCreate(**job_data_args)) run_pdf_ocr_task(job_id, str(upload_path), str(processed_path), APP_CONFIG.get("ocr_settings", {}).get("ocrmypdf", {}), APP_CONFIG, base_url) elif task_type == "ocr-image": if not is_allowed_file(file.filename, {".png", ".jpg", ".jpeg", ".tiff", ".tif"}): raise HTTPException(status_code=400, detail="Invalid file type for ocr-image.") processed_path = PATHS.PROCESSED_DIR / f"{stem}_{job_id}.txt" job_data_args["processed_filepath"] = str(processed_path) create_job(db=db, job=JobCreate(**job_data_args)) run_image_ocr_task(job_id, str(upload_path), str(processed_path), APP_CONFIG, base_url) else: upload_path.unlink(missing_ok=True) # Cleanup orphaned file raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid task_type: '{task_type}'") return {"job_id": job_id, "status": "pending"} # --- Chunked API endpoints (optional) --- @app.post("/api/v1/upload/chunk", tags=["Webhook API"]) async def api_upload_chunk( chunk: UploadFile = File(...), upload_id: str = Form(...), chunk_number: int = Form(...), user: dict = Depends(require_api_user) ): """API endpoint for uploading a single file chunk.""" webhook_config = APP_CONFIG.get("webhook_settings", {}) if not webhook_config.get("enabled", False) or not webhook_config.get("allow_chunked_api_uploads", False): raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Chunked API uploads are disabled.") return await upload_chunk(chunk, upload_id, chunk_number, user) @app.post("/api/v1/upload/finalize", status_code=status.HTTP_202_ACCEPTED, tags=["Webhook API"]) async def api_finalize_upload( request: Request, payload: FinalizeUploadPayload, user: dict = Depends(require_api_user), db: Session = Depends(get_db) ): """API endpoint to finalize a chunked upload and start a processing job.""" webhook_config = APP_CONFIG.get("webhook_settings", {}) if not webhook_config.get("enabled", False) or not webhook_config.get("allow_chunked_api_uploads", False): raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Chunked API uploads are disabled.") # Validate callback URL if provided for a webhook job if payload.callback_url and not is_allowed_callback_url(payload.callback_url, webhook_config.get("allowed_callback_urls", [])): raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Provided callback_url is not allowed.") # Re-use the main finalization logic, but with API user context return await finalize_upload(request, payload, user, db) # -------------------------------------------------------------------------------- # --- AUTH & PAGE ROUTES # -------------------------------------------------------------------------------- if not LOCAL_ONLY_MODE: @app.get('/login') async def login(request: Request): redirect_uri = request.url_for('auth') return await oauth.oidc.authorize_redirect(request, redirect_uri) @app.get('/auth') async def auth(request: Request): try: token = await oauth.oidc.authorize_access_token(request) user = await oauth.oidc.userinfo(token=token) request.session['user'] = dict(user) request.session['id_token'] = token.get('id_token') except Exception as e: logger.error(f"Authentication failed: {e}") raise HTTPException(status_code=401, detail="Authentication failed") return RedirectResponse(url='/') @app.get("/logout") async def logout(request: Request): logout_endpoint = oauth.oidc.server_metadata.get("end_session_endpoint") if not logout_endpoint: request.session.clear() logger.warning("OIDC 'end_session_endpoint' not found. Performing local-only logout.") return RedirectResponse(url="/", status_code=302) post_logout_redirect_uri = str(request.url_for("get_index")) logout_url = f"{logout_endpoint}?post_logout_redirect_uri={post_logout_redirect_uri}" request.session.clear() return RedirectResponse(url=logout_url, status_code=302) # This is for reverse proxies that use forward auth @app.get("/api/authz/forward-auth") async def forward_auth(request: Request): redirect_uri = request.url_for('auth') return await oauth.oidc.authorize_redirect(request, redirect_uri) @app.get("/") async def get_index(request: Request): user = get_current_user(request) admin_status = is_admin(request) whisper_models = APP_CONFIG.get("transcription_settings", {}).get("whisper", {}).get("allowed_models", []) conversion_tools = APP_CONFIG.get("conversion_tools", {}) return templates.TemplateResponse("index.html", { "request": request, "user": user, "is_admin": admin_status, "whisper_models": sorted(list(whisper_models)), "conversion_tools": conversion_tools, "local_only_mode": LOCAL_ONLY_MODE }) @app.get("/settings") async def get_settings_page(request: Request): """Displays the contents of the currently active configuration file.""" user = get_current_user(request) admin_status = is_admin(request) current_config, config_source = {}, "none" try: with open(PATHS.SETTINGS_FILE, 'r', encoding='utf8') as f: current_config = yaml.safe_load(f) or {} config_source = str(PATHS.SETTINGS_FILE.name) except FileNotFoundError: try: with open(PATHS.DEFAULT_SETTINGS_FILE, 'r', encoding='utf8') as f: current_config = yaml.safe_load(f) or {} config_source = str(PATHS.DEFAULT_SETTINGS_FILE.name) except Exception as e: logger.exception(f"CRITICAL: Could not load fallback config: {e}") config_source = "error" except Exception as e: logger.exception(f"Could not load primary config: {e}") config_source = "error" return templates.TemplateResponse( "settings.html", {"request": request, "config": current_config, "config_source": config_source, "user": user, "is_admin": admin_status, "local_only_mode": LOCAL_ONLY_MODE} ) def deep_merge(source: dict, destination: dict) -> dict: """Recursively merges dicts.""" for key, value in source.items(): if isinstance(value, collections.abc.Mapping): node = destination.setdefault(key, {}) deep_merge(value, node) else: destination[key] = value return destination @app.post("/settings/save") async def save_settings( request: Request, new_config_from_ui: Dict = Body(...), admin: bool = Depends(require_admin) ): """Safely updates settings.yml by merging UI changes with the existing file.""" tmp_path = PATHS.SETTINGS_FILE.with_suffix(".tmp") user = get_current_user(request) try: if not new_config_from_ui: if PATHS.SETTINGS_FILE.exists(): PATHS.SETTINGS_FILE.unlink() logger.info(f"Admin '{user.get('email')}' reverted to default settings.") load_app_config() return JSONResponse({"message": "Settings reverted to default."}) try: with PATHS.SETTINGS_FILE.open("r", encoding="utf8") as f: current_config_on_disk = yaml.safe_load(f) or {} except FileNotFoundError: current_config_on_disk = {} merged_config = deep_merge(source=new_config_from_ui, destination=current_config_on_disk) with tmp_path.open("w", encoding="utf8") as f: yaml.safe_dump(merged_config, f, default_flow_style=False, sort_keys=False) tmp_path.replace(PATHS.SETTINGS_FILE) logger.info(f"Admin '{user.get('email')}' updated settings.yml.") load_app_config() return JSONResponse({"message": "Settings saved successfully."}) except Exception as e: logger.exception(f"Failed to update settings for admin '{user.get('email')}'") if tmp_path.exists(): tmp_path.unlink() raise HTTPException(status_code=500, detail=f"Could not save settings.yml: {e}") # -------------------------------------------------------------------------------- # --- JOB MANAGEMENT & UTILITY ROUTES # -------------------------------------------------------------------------------- @app.post("/settings/clear-history") async def clear_job_history(db: Session = Depends(get_db), user: dict = Depends(require_user)): try: num_deleted = db.query(Job).filter(Job.user_id == user['sub']).delete() db.commit() logger.info(f"Cleared {num_deleted} jobs for user {user['sub']}.") return {"deleted_count": num_deleted} except Exception: db.rollback() logger.exception("Failed to clear job history") raise HTTPException(status_code=500, detail="Database error while clearing history.") @app.post("/settings/delete-files") async def delete_processed_files(db: Session = Depends(get_db), user: dict = Depends(require_user)): deleted_count, errors = 0, [] for job in get_jobs(db, user_id=user['sub']): if job.processed_filepath: try: p = ensure_path_is_safe(Path(job.processed_filepath), [PATHS.PROCESSED_DIR]) if p.is_file(): p.unlink() deleted_count += 1 except Exception: errors.append(Path(job.processed_filepath).name) logger.exception(f"Could not delete file {Path(job.processed_filepath).name}") if errors: raise HTTPException(status_code=500, detail=f"Could not delete some files: {', '.join(errors)}") logger.info(f"Deleted {deleted_count} files for user {user['sub']}.") return {"deleted_count": deleted_count} @app.post("/job/{job_id}/cancel", status_code=status.HTTP_202_ACCEPTED) async def cancel_job(job_id: str, db: Session = Depends(get_db), user: dict = Depends(require_user)): job = get_job(db, job_id) if not job or job.user_id != user['sub']: raise HTTPException(status_code=404, detail="Job not found.") if job.status in ["pending", "processing"]: update_job_status(db, job_id, status="cancelled") return {"message": "Job cancellation requested."} raise HTTPException(status_code=400, detail=f"Job is already in a final state ({job.status}).") @app.get("/jobs", response_model=List[JobSchema]) async def get_all_jobs(db: Session = Depends(get_db), user: dict = Depends(require_user)): return get_jobs(db, user_id=user['sub']) @app.get("/job/{job_id}", response_model=JobSchema) async def get_job_status(job_id: str, db: Session = Depends(get_db), user: dict = Depends(require_user)): job = get_job(db, job_id) if not job or job.user_id != user['sub']: raise HTTPException(status_code=404, detail="Job not found.") return job @app.get("/download/{filename}") async def download_file(filename: str, db: Session = Depends(get_db), user: dict = Depends(require_user)): safe_filename = secure_filename(filename) file_path = ensure_path_is_safe(PATHS.PROCESSED_DIR / safe_filename, [PATHS.PROCESSED_DIR]) if not file_path.is_file(): raise HTTPException(status_code=404, detail="File not found.") # API users can download files they own via webhook URL. UI users need session. job_owner_id = user.get('sub') if user else None job = db.query(Job).filter(Job.processed_filepath == str(file_path), Job.user_id == job_owner_id).first() if not job: raise HTTPException(status_code=403, detail="You do not have permission to download this file.") download_filename = Path(job.original_filename).stem + Path(job.processed_filepath).suffix return FileResponse(path=file_path, filename=download_filename, media_type="application/octet-stream") @app.post("/download/batch", response_class=StreamingResponse) async def download_batch(payload: JobSelection, db: Session = Depends(get_db), user: dict = Depends(require_user)): job_ids = payload.job_ids if not job_ids: raise HTTPException(status_code=400, detail="No job IDs provided.") zip_buffer = BytesIO() with zipfile.ZipFile(zip_buffer, "a", zipfile.ZIP_DEFLATED, False) as zip_file: for job_id in job_ids: job = get_job(db, job_id) if job and job.user_id == user['sub'] and job.status == 'completed' and job.processed_filepath: file_path = ensure_path_is_safe(Path(job.processed_filepath), [PATHS.PROCESSED_DIR]) if file_path.exists(): download_filename = f"{Path(job.original_filename).stem}_{job_id}{file_path.suffix}" zip_file.write(file_path, arcname=download_filename) zip_buffer.seek(0) return StreamingResponse(zip_buffer, media_type="application/x-zip-compressed", headers={ 'Content-Disposition': f'attachment; filename="file-wizard-batch-{uuid.uuid4().hex[:8]}.zip"' }) @app.get("/download/zip-batch/{job_id}", response_class=StreamingResponse) async def download_zip_batch(job_id: str, db: Session = Depends(get_db), user: dict = Depends(require_user)): """Downloads all processed files from a ZIP upload batch as a new ZIP file.""" parent_job = get_job(db, job_id) if not parent_job or parent_job.user_id != user['sub']: raise HTTPException(status_code=404, detail="Parent job not found.") if parent_job.task_type != 'unzip': raise HTTPException(status_code=400, detail="This job is not a batch upload.") child_jobs = db.query(Job).filter(Job.parent_job_id == job_id, Job.status == 'completed').all() if not child_jobs: raise HTTPException(status_code=404, detail="No completed sub-jobs found for this batch.") zip_buffer = BytesIO() with zipfile.ZipFile(zip_buffer, "a", zipfile.ZIP_DEFLATED, False) as zip_file: files_added = 0 for job in child_jobs: if job.processed_filepath: file_path = ensure_path_is_safe(Path(job.processed_filepath), [PATHS.PROCESSED_DIR]) if file_path.exists(): # Create a more user-friendly name inside the zip download_filename = f"{Path(job.original_filename).stem}{file_path.suffix}" zip_file.write(file_path, arcname=download_filename) files_added += 1 if files_added == 0: raise HTTPException(status_code=404, detail="No processed files found for the completed sub-jobs.") zip_buffer.seek(0) # Generate a filename for the download batch_filename = f"{Path(parent_job.original_filename).stem}_processed.zip" return StreamingResponse(zip_buffer, media_type="application/x-zip-compressed", headers={ 'Content-Disposition': f'attachment; filename="{batch_filename}"' }) @app.get("/health") async def health(): try: with engine.connect() as conn: conn.execution_options(isolation_level="AUTOCOMMIT").execute("SELECT 1") except Exception: logger.exception("Health check failed") return JSONResponse({"ok": False}, status_code=500) return {"ok": True} @app.get('/favicon.ico', include_in_schema=False) async def favicon(): return FileResponse(str(PATHS.BASE_DIR / 'static' / 'favicon.png'))