# main.py (merged)
import threading
import logging
import shutil
import subprocess
import traceback
import uuid
import shlex
import yaml
import os
import httpx
import glob
import cv2
import numpy as np
from contextlib import asynccontextmanager
from datetime import datetime, timezone
from pathlib import Path
from typing import Dict, List, Any, Optional
import resource
from threading import Semaphore
from logging.handlers import RotatingFileHandler
from urllib.parse import urljoin, urlparse
from io import BytesIO
import zipfile
import sys
import re
import importlib
import collections.abc
import time
import ocrmypdf
import pypdf
import pytesseract
from fastapi.middleware.cors import CORSMiddleware
from pytesseract import TesseractNotFoundError
from PIL import Image, UnidentifiedImageError
from faster_whisper import WhisperModel
from fastapi import (Depends, FastAPI, File, Form, HTTPException, Request,
UploadFile, status, Body)
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import FileResponse, JSONResponse, RedirectResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from huey import SqliteHuey, crontab
from pydantic import BaseModel, ConfigDict, field_serializer
from sqlalchemy import (Column, DateTime, Integer, String, Text,
create_engine, delete, event)
from sqlalchemy.orm import Session, declarative_base, sessionmaker
from sqlalchemy.pool import NullPool
from sqlalchemy.exc import OperationalError
from string import Formatter
from werkzeug.utils import secure_filename
from typing import List as TypingList
from starlette.middleware.sessions import SessionMiddleware
from authlib.integrations.starlette_client import OAuth
from dotenv import load_dotenv
try:
    from piper import PiperVoice
except ImportError:
    # Piper is optional; the code below checks `PiperVoice is None` and falls back.
    PiperVoice = None
import wave
import io
load_dotenv()
# --- Optional Dependency Handling for Piper TTS ---
try:
from piper.synthesis import SynthesisConfig
# download helpers: some piper versions export download_voice, others expose ensure_voice_exists/find_voice
try:
# prefer the more explicit helpers if present
from piper.download import get_voices, ensure_voice_exists, find_voice, VoiceNotFoundError
except Exception:
# fall back to older API if available
try:
from piper.download import get_voices, download_voice, VoiceNotFoundError
ensure_voice_exists = None
find_voice = None
except Exception:
# partial import failed -> treat as piper-not-installed for download helpers
get_voices = None
download_voice = None
ensure_voice_exists = None
find_voice = None
VoiceNotFoundError = None
except ImportError:
SynthesisConfig = None
get_voices = None
download_voice = None
ensure_voice_exists = None
find_voice = None
VoiceNotFoundError = None
try:
from PyPDF2 import PdfMerger
_HAS_PYPDF2 = True
except Exception:
_HAS_PYPDF2 = False
# Shared OAuth client used for OIDC authentication (configured from auth_settings).
oauth = OAuth()
# --------------------------------------------------------------------------------
# --- 1. CONFIGURATION & SECURITY HELPERS
# --------------------------------------------------------------------------------
# --- Path Safety ---
UPLOADS_BASE = Path(os.environ.get("UPLOADS_DIR", "/app/uploads")).resolve()
PROCESSED_BASE = Path(os.environ.get("PROCESSED_DIR", "/app/processed")).resolve()
CHUNK_TMP_BASE = Path(os.environ.get("CHUNK_TMP_DIR", str(UPLOADS_BASE / "tmp"))).resolve()
def ensure_path_is_safe(p: Path, allowed_bases: List[Path]):
"""Ensure a path resolves to a location within one of the allowed base directories."""
try:
resolved_p = p.resolve()
if not any(resolved_p.is_relative_to(base) for base in allowed_bases):
raise ValueError(f"Path {resolved_p} is outside of allowed directories.")
return resolved_p
except Exception as e:
logger = logging.getLogger(__name__)
logger.error(f"Path safety check failed for {p}: {e}")
raise ValueError("Invalid or unsafe path specified.")
# --- Resource Limiting ---
def _limit_resources_preexec():
"""Set resource limits for child processes to prevent DoS attacks."""
try:
# 6000s CPU, 4GB address space
resource.setrlimit(resource.RLIMIT_CPU, (6000, 6000))
resource.setrlimit(resource.RLIMIT_AS, (4 * 1024 * 1024 * 1024, 4 * 1024 * 1024 * 1024))
    except Exception as e:
        # This may fail in some environments (e.g. Windows, some containers).
        logging.getLogger(__name__).warning(f"Could not set resource limits: {e}")
# --- Model concurrency semaphore (lazily initialized) ---
_model_semaphore: Optional[Semaphore] = None
def get_model_semaphore() -> Semaphore:
"""Lazily initializes and returns the global model semaphore."""
global _model_semaphore
if _model_semaphore is None:
# Read from app config, fall back to env var, then to a hardcoded default of 1
model_concurrency_from_env = int(os.environ.get("MODEL_CONCURRENCY", "1"))
model_concurrency = APP_CONFIG.get("app_settings", {}).get("model_concurrency", model_concurrency_from_env)
_model_semaphore = Semaphore(model_concurrency)
logger.info(f"Model concurrency semaphore initialized with limit: {model_concurrency}")
return _model_semaphore
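# Usage sketch: heavy model work acquires the semaphore, e.g.
#   with get_model_semaphore():
#       voice = PiperVoice.load(...)
# so at most `model_concurrency` models are loaded or run concurrently.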
# --- Logging Setup ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
_log_handler = RotatingFileHandler("app.log", maxBytes=10*1024*1024, backupCount=1)
_log_formatter = logging.Formatter('%(asctime)s %(levelname)s %(name)s %(message)s')
_log_handler.setFormatter(_log_formatter)
logging.getLogger().addHandler(_log_handler)
logger = logging.getLogger(__name__)
# --- Environment Mode ---
LOCAL_ONLY_MODE = os.getenv('LOCAL_ONLY', 'True').lower() in ('true', '1', 't')
if LOCAL_ONLY_MODE:
logger.warning("Authentication is DISABLED. Running in LOCAL_ONLY mode.")
class AppPaths(BaseModel):
BASE_DIR: Path = Path(__file__).resolve().parent
UPLOADS_DIR: Path = UPLOADS_BASE
PROCESSED_DIR: Path = PROCESSED_BASE
CHUNK_TMP_DIR: Path = CHUNK_TMP_BASE
TTS_MODELS_DIR: Path = BASE_DIR / "models" / "tts"
KOKORO_TTS_MODELS_DIR: Path = BASE_DIR / "models" / "tts" / "kokoro"
KOKORO_MODEL_FILE: Path = KOKORO_TTS_MODELS_DIR / "kokoro-v1.0.onnx"
KOKORO_VOICES_FILE: Path = KOKORO_TTS_MODELS_DIR / "voices-v1.0.bin"
DATABASE_URL: str = f"sqlite:///{BASE_DIR / 'jobs.db'}"
HUEY_DB_PATH: str = str(BASE_DIR / "huey.db")
CONFIG_DIR: Path = BASE_DIR / "config"
SETTINGS_FILE: Path = CONFIG_DIR / "settings.yml"
DEFAULT_SETTINGS_FILE: Path = BASE_DIR / "settings.default.yml"
PATHS = AppPaths()
APP_CONFIG: Dict[str, Any] = {}
PATHS.UPLOADS_DIR.mkdir(exist_ok=True, parents=True)
PATHS.PROCESSED_DIR.mkdir(exist_ok=True, parents=True)
PATHS.CHUNK_TMP_DIR.mkdir(exist_ok=True, parents=True)
PATHS.CONFIG_DIR.mkdir(exist_ok=True, parents=True)
PATHS.TTS_MODELS_DIR.mkdir(exist_ok=True, parents=True)
PATHS.KOKORO_TTS_MODELS_DIR.mkdir(exist_ok=True, parents=True)
def load_app_config():
"""
Loads configuration from settings.yml, with a fallback to settings.default.yml,
and finally to hardcoded defaults if both files are missing.
Reads environment variables for transcription device settings.
"""
global APP_CONFIG
try:
# --- Primary Method: Attempt to load settings.yml ---
with open(PATHS.SETTINGS_FILE, 'r', encoding='utf8') as f:
cfg_raw = yaml.safe_load(f) or {}
# Read transcription settings from environment variables, providing smart defaults.
transcription_device = os.environ.get("TRANSCRIPTION_DEVICE", "cpu")
# Default to float16 for CUDA for better performance, otherwise int8 for CPU.
default_compute_type = "float16" if transcription_device == "cuda" else "int8"
transcription_compute_type = os.environ.get("TRANSCRIPTION_COMPUTE_TYPE", default_compute_type)
transcription_device_index_str = os.environ.get("TRANSCRIPTION_DEVICE_INDEX", "0")
# Handle multiple device indexes (e.g., "0,1")
try:
if ',' in transcription_device_index_str:
transcription_device_index = [int(i.strip()) for i in transcription_device_index_str.split(',')]
else:
transcription_device_index = int(transcription_device_index_str)
except ValueError:
logger.warning(f"Invalid TRANSCRIPTION_DEVICE_INDEX value: '{transcription_device_index_str}'. Defaulting to 0.")
transcription_device_index = 0
defaults = {
"app_settings": {"max_file_size_mb": 100, "allowed_all_extensions": [], "app_public_url": ""},
"transcription_settings": {
"whisper": {
"allowed_models": ["tiny", "base", "small", "medium", "large-v1", "large-v2", "large-v3"],
"compute_type": transcription_compute_type,
"device": transcription_device,
"device_index": transcription_device_index
}
},
"tts_settings": {
"piper": {"model_dir": str(PATHS.TTS_MODELS_DIR), "use_cuda": False, "synthesis_config": {"length_scale": 1.0, "noise_scale": 0.667, "noise_w": 0.8}},
"kokoro": {"model_dir": str(PATHS.KOKORO_TTS_MODELS_DIR), "command_template": "kokoro-tts {input} {output} --model {model_path} --voices {voices_path} --lang {lang} --voice {model_name}"}
},
"conversion_tools": {},
"ocr_settings": {"ocrmypdf": {}},
"auth_settings": {"oidc_client_id": "", "oidc_client_secret": "", "oidc_server_metadata_url": "", "admin_users": []},
"webhook_settings": {"enabled": False, "allow_chunked_api_uploads": False, "allowed_callback_urls": [], "callback_bearer_token": ""}
}
cfg = defaults.copy()
        cfg.update(cfg_raw)  # Shallow merge: top-level sections in settings.yml replace the defaults wholesale.
app_settings = cfg.get("app_settings", {})
max_mb = app_settings.get("max_file_size_mb", 100)
app_settings["max_file_size_bytes"] = int(max_mb) * 1024 * 1024
allowed = app_settings.get("allowed_all_extensions", [])
if not isinstance(allowed, (list, set)):
allowed = list(allowed)
app_settings["allowed_all_extensions"] = set(allowed)
cfg["app_settings"] = app_settings
APP_CONFIG = cfg
logger.info("Successfully loaded settings from settings.yml")
except (FileNotFoundError, yaml.YAMLError) as e:
logger.warning(f"Could not load settings.yml: {e}. Falling back to settings.default.yml...")
try:
# --- Fallback Method: Attempt to load settings.default.yml ---
with open(PATHS.DEFAULT_SETTINGS_FILE, 'r', encoding='utf8') as f:
cfg_raw = yaml.safe_load(f) or {}
defaults = {
"app_settings": {"max_file_size_mb": 100, "allowed_all_extensions": [], "app_public_url": ""},
"transcription_settings": {"whisper": {"allowed_models": ["tiny", "base", "small"], "compute_type": "int8", "device": "cpu"}},
"tts_settings": {
"piper": {"model_dir": str(PATHS.TTS_MODELS_DIR), "use_cuda": False, "synthesis_config": {"length_scale": 1.0, "noise_scale": 0.667, "noise_w": 0.8}},
"kokoro": {"model_dir": str(PATHS.KOKORO_TTS_MODELS_DIR), "command_template": "kokoro-tts {input} {output} --model {model_path} --voices {voices_path} --lang {lang} --voice {model_name}"}
},
"conversion_tools": {},
"ocr_settings": {"ocrmypdf": {}},
"auth_settings": {"oidc_client_id": "", "oidc_client_secret": "", "oidc_server_metadata_url": "", "admin_users": []},
"webhook_settings": {"enabled": False, "allow_chunked_api_uploads": False, "allowed_callback_urls": [], "callback_bearer_token": ""}
}
cfg = defaults.copy()
            cfg.update(cfg_raw)  # Shallow merge: top-level sections replace the defaults wholesale.
app_settings = cfg.get("app_settings", {})
max_mb = app_settings.get("max_file_size_mb", 100)
app_settings["max_file_size_bytes"] = int(max_mb) * 1024 * 1024
allowed = app_settings.get("allowed_all_extensions", [])
if not isinstance(allowed, (list, set)):
allowed = list(allowed)
app_settings["allowed_all_extensions"] = set(allowed)
cfg["app_settings"] = app_settings
APP_CONFIG = cfg
logger.info("Successfully loaded settings from settings.default.yml")
except (FileNotFoundError, yaml.YAMLError) as e_fallback:
# --- Final Failsafe: Use hardcoded defaults ---
logger.error(f"CRITICAL: Fallback file settings.default.yml also failed: {e_fallback}. Using hardcoded defaults.")
APP_CONFIG = {
"app_settings": {"max_file_size_mb": 100, "max_file_size_bytes": 100 * 1024 * 1024, "allowed_all_extensions": set(), "app_public_url": ""},
"transcription_settings": {"whisper": {"allowed_models": ["tiny", "base", "small"], "compute_type": "int8", "device": "cpu"}},
"tts_settings": {
"piper": {"model_dir": str(PATHS.TTS_MODELS_DIR), "use_cuda": False, "synthesis_config": {"length_scale": 1.0, "noise_scale": 0.667, "noise_w": 0.8}},
"kokoro": {"model_dir": str(PATHS.KOKORO_TTS_MODELS_DIR), "command_template": "kokoro-tts {input} {output} --model {model_path} --voices {voices_path} --lang {lang} --voice {model_name}"}
},
"conversion_tools": {},
"ocr_settings": {"ocrmypdf": {}},
"auth_settings": {"oidc_client_id": "", "oidc_client_secret": "", "oidc_server_metadata_url": "", "admin_users": []},
"webhook_settings": {"enabled": False, "allow_chunked_api_uploads": False, "allowed_callback_urls": [], "callback_bearer_token": ""}
}
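# A minimal settings.yml accepted by this loader (illustrative; any omitted
# section falls back to the defaults above):
#   app_settings:
#     max_file_size_mb: 250
#     app_public_url: "https://files.example.com"
#   webhook_settings:
#     enabled: true
#     callback_bearer_token: "change-me"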
# --------------------------------------------------------------------------------
# --- 2. DATABASE & Schemas
# --------------------------------------------------------------------------------
engine = create_engine(
PATHS.DATABASE_URL,
connect_args={"check_same_thread": False, "timeout": 30},
poolclass=NullPool,
)
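# NullPool opens a fresh SQLite connection per session instead of sharing
# pooled connections across Huey worker threads; the WAL pragma set below
# keeps these short-lived connections cheap.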
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base()
@event.listens_for(engine, "connect")
def _set_sqlite_pragmas(dbapi_connection, connection_record):
c = dbapi_connection.cursor()
try:
c.execute("PRAGMA journal_mode=WAL;")
c.execute("PRAGMA synchronous=NORMAL;")
finally:
c.close()
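# WAL lets readers proceed while a writer commits; synchronous=NORMAL trades a
# little durability (fsync at checkpoints rather than every commit) for
# throughput, a reasonable fit for a job-status database.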
class Job(Base):
__tablename__ = "jobs"
id = Column(String, primary_key=True, index=True)
user_id = Column(String, index=True, nullable=True)
parent_job_id = Column(String, index=True, nullable=True)
task_type = Column(String, index=True)
status = Column(String, default="pending")
progress = Column(Integer, default=0)
original_filename = Column(String)
input_filepath = Column(String)
input_filesize = Column(Integer, nullable=True)
processed_filepath = Column(String, nullable=True)
output_filesize = Column(Integer, nullable=True)
result_preview = Column(Text, nullable=True)
error_message = Column(Text, nullable=True)
callback_url = Column(String, nullable=True)
created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
updated_at = Column(DateTime, default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()
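# Typical FastAPI usage (illustrative route; the real endpoints appear later):
#   @app.get("/api/jobs")
#   def list_jobs_endpoint(db: Session = Depends(get_db)):
#       return get_jobs(db)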
class JobCreate(BaseModel):
id: str
user_id: str | None = None
parent_job_id: str | None = None
task_type: str
original_filename: str
input_filepath: str
input_filesize: int | None = None
callback_url: str | None = None
processed_filepath: str | None = None
class JobSchema(BaseModel):
id: str
parent_job_id: str | None = None
task_type: str
status: str
progress: int
original_filename: str
input_filesize: int | None = None
output_filesize: int | None = None
processed_filepath: str | None = None
result_preview: str | None = None
error_message: str | None = None
created_at: datetime
updated_at: datetime
model_config = ConfigDict(from_attributes=True)
@field_serializer('created_at', 'updated_at')
def serialize_dt(self, dt: datetime, _info):
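        # Values loaded from SQLite come back naive but were written as UTC by
        # the column defaults, so a trailing "Z" marks them correctly.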
return dt.isoformat() + "Z"
class FinalizeUploadPayload(BaseModel):
upload_id: str
original_filename: str
total_chunks: int
task_type: str
model_size: str = ""
model_name: str = ""
output_format: str = ""
callback_url: Optional[str] = None # For API chunked uploads
class JobSelection(BaseModel):
job_ids: List[str]
# --------------------------------------------------------------------------------
# --- 3. CRUD OPERATIONS & WEBHOOKS
# --------------------------------------------------------------------------------
def get_job(db: Session, job_id: str):
return db.query(Job).filter(Job.id == job_id).first()
def get_jobs(db: Session, user_id: str | None = None, skip: int = 0, limit: int = 100):
query = db.query(Job)
if user_id:
query = query.filter(Job.user_id == user_id)
return query.order_by(Job.created_at.desc()).offset(skip).limit(limit).all()
def create_job(db: Session, job: JobCreate):
db_job = Job(**job.model_dump())
db.add(db_job)
db.commit()
db.refresh(db_job)
return db_job
def update_job_status(db: Session, job_id: str, status: str, progress: int | None = None, error: str | None = None):
db_job = get_job(db, job_id)
if db_job:
db_job.status = status
if progress is not None:
db_job.progress = progress
if error:
db_job.error_message = error
db.commit()
db.refresh(db_job)
return db_job
def mark_job_as_completed(db: Session, job_id: str, output_filepath_str: str | None = None, preview: str | None = None):
db_job = get_job(db, job_id)
if db_job and db_job.status != 'cancelled':
db_job.status = "completed"
db_job.progress = 100
if preview:
db_job.result_preview = preview.strip()[:2000]
if output_filepath_str:
try:
output_path = Path(output_filepath_str)
if output_path.exists():
db_job.output_filesize = output_path.stat().st_size
except Exception:
logger.exception(f"Could not stat output file {output_filepath_str} for job {job_id}")
db.commit()
db.refresh(db_job)
return db_job
def send_webhook_notification(job_id: str, app_config: Dict[str, Any], base_url: str):
"""Sends a notification to the callback URL if one is configured for the job."""
webhook_config = app_config.get("webhook_settings", {})
if not webhook_config.get("enabled", False):
return
db = SessionLocal()
try:
job = get_job(db, job_id)
if not job or not job.callback_url:
return
download_url = None
if job.status == "completed" and job.processed_filepath:
filename = Path(job.processed_filepath).name
public_url = app_config.get("app_settings", {}).get("app_public_url", base_url)
if not public_url:
logger.warning(f"app_public_url is not set. Cannot generate a full download URL for job {job_id}.")
download_url = f"/download/{filename}" # Relative URL as fallback
else:
download_url = urljoin(public_url, f"/download/{filename}")
payload = {
"job_id": job.id,
"status": job.status,
"original_filename": job.original_filename,
"download_url": download_url,
"error_message": job.error_message,
"created_at": job.created_at.isoformat() + "Z",
"updated_at": job.updated_at.isoformat() + "Z",
}
headers = {"Content-Type": "application/json", "User-Agent": "FileProcessor-Webhook/1.0"}
token = webhook_config.get("callback_bearer_token")
if token:
headers["Authorization"] = f"Bearer {token}"
try:
with httpx.Client() as client:
response = client.post(job.callback_url, json=payload, headers=headers, timeout=15)
response.raise_for_status()
logger.info(f"Sent webhook notification for job {job_id} to {job.callback_url} (Status: {response.status_code})")
except httpx.RequestError as e:
logger.error(f"Failed to send webhook for job {job_id} to {job.callback_url}: {e}")
except httpx.HTTPStatusError as e:
logger.error(f"Webhook for job {job_id} received non-2xx response {e.response.status_code} from {job.callback_url}")
except Exception as e:
logger.exception(f"An unexpected error occurred in send_webhook_notification for job {job_id}: {e}")
finally:
db.close()
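# Payload shape sent to the callback URL (illustrative values):
#   {"job_id": "7f3c0a", "status": "completed", "original_filename": "talk.mp3",
#    "download_url": "https://files.example.com/download/talk.txt",
#    "error_message": null, "created_at": "2025-01-01T12:00:00Z",
#    "updated_at": "2025-01-01T12:05:00Z"}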
# --------------------------------------------------------------------------------
# --- 4. BACKGROUND TASK SETUP
# --------------------------------------------------------------------------------
huey = SqliteHuey(filename=PATHS.HUEY_DB_PATH)
WHISPER_MODELS_CACHE: Dict[str, WhisperModel] = {}
PIPER_VOICES_CACHE: Dict[str, "PiperVoice"] = {}
AVAILABLE_TTS_VOICES_CACHE: Dict[str, Any] | None = None
WHISPER_MODELS_LAST_USED: Dict[str, float] = {}
# --- Cache Eviction Settings ---
_cache_cleanup_thread: Optional[threading.Thread] = None
_cache_lock = threading.Lock() # Global lock for modifying cache dictionaries
_model_locks: Dict[str, threading.Lock] = {}
_global_lock = threading.Lock() # Lock for initializing model-specific locks
def _whisper_cache_cleanup_worker():
"""
Periodically checks for and unloads Whisper models that have been inactive.
The timeout and check interval are configured in the application settings.
"""
while True:
# Read settings within the loop to allow for live changes
app_settings = APP_CONFIG.get("app_settings", {})
check_interval = app_settings.get("cache_check_interval", 300)
inactivity_timeout = app_settings.get("model_inactivity_timeout", 1800)
time.sleep(check_interval)
with _cache_lock:
# Create a copy of items to avoid issues with modifying dict while iterating
expired_models = []
for model_size, last_used in WHISPER_MODELS_LAST_USED.items():
if time.time() - last_used > inactivity_timeout:
expired_models.append(model_size)
if not expired_models:
continue
logger.info(f"Found {len(expired_models)} inactive Whisper models to unload: {expired_models}")
for model_size in expired_models:
# Acquire the specific model lock before removing to prevent race conditions
model_lock = _get_or_create_model_lock(model_size)
with model_lock:
# Check if the model is still in the cache (it should be)
if model_size in WHISPER_MODELS_CACHE:
logger.info(f"Unloading inactive Whisper model: {model_size}")
# Remove from caches
model_to_unload = WHISPER_MODELS_CACHE.pop(model_size, None)
WHISPER_MODELS_LAST_USED.pop(model_size, None)
# Explicitly delete the object to encourage garbage collection
if model_to_unload:
del model_to_unload
# Explicitly run garbage collection outside the main lock
import gc
gc.collect()
def get_whisper_model(model_size: str, whisper_settings: dict) -> Any:
# Fast path: check cache. If hit, update timestamp and return.
with _cache_lock:
if model_size in WHISPER_MODELS_CACHE:
logger.debug(f"Cache hit for model '{model_size}'")
WHISPER_MODELS_LAST_USED[model_size] = time.time()
return WHISPER_MODELS_CACHE[model_size]
# Model not in cache, prepare for loading.
model_lock = _get_or_create_model_lock(model_size)
with model_lock:
# Re-check cache inside lock in case another thread loaded it
with _cache_lock:
if model_size in WHISPER_MODELS_CACHE:
WHISPER_MODELS_LAST_USED[model_size] = time.time()
return WHISPER_MODELS_CACHE[model_size]
logger.info(f"Loading Whisper model '{model_size}'...")
try:
device = whisper_settings.get("device", "cpu")
compute_type = whisper_settings.get("compute_type", "int8")
device_index = whisper_settings.get("device_index", 0)
model = WhisperModel(
model_size,
device=device,
device_index=device_index,
compute_type=compute_type,
                cpu_threads=max(1, (os.cpu_count() or 2) // 2),  # cpu_count() can return None
num_workers=1
)
# Add the new model to the cache under lock
with _cache_lock:
WHISPER_MODELS_CACHE[model_size] = model
WHISPER_MODELS_LAST_USED[model_size] = time.time()
logger.info(f"Model '{model_size}' loaded (device={device}, compute={compute_type})")
return model
except Exception as e:
logger.error(f"Model '{model_size}' failed to load: {str(e)}", exc_info=True)
raise RuntimeError(f"Whisper model initialization failed: {e}") from e
def _get_or_create_model_lock(model_size: str) -> threading.Lock:
"""Thread-safe lock acquisition with minimal global contention"""
# Fast path: lock already exists
if model_size in _model_locks:
return _model_locks[model_size]
# Slow path: create lock under global lock
with _global_lock:
return _model_locks.setdefault(model_size, threading.Lock())
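# Double-checked locking: the unlocked membership test is a benign race (dict
# reads are effectively atomic under CPython's GIL), and setdefault() under
# _global_lock guarantees exactly one Lock object ever exists per model_size.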
def get_piper_voice(model_name: str, tts_settings: dict | None) -> "PiperVoice":
"""
Load (or download + load) a Piper voice in a robust way:
- Try Python API helpers (get_voices, ensure_voice_exists/find_voice, download_voice)
- On any failure, try CLI fallback (download_voice_cli)
- Attempt to locate model files after download (search subdirs)
- Try re-importing piper if bindings were previously unavailable
"""
# ----- Defensive normalization -----
if tts_settings is None or not isinstance(tts_settings, dict):
logger.debug("get_piper_voice: normalizing tts_settings (was %r)", tts_settings)
tts_settings = {}
model_dir_val = tts_settings.get("model_dir", None)
if model_dir_val is None:
model_dir = Path(str(PATHS.TTS_MODELS_DIR))
else:
try:
model_dir = Path(model_dir_val)
except Exception:
logger.warning("Could not coerce tts_settings['model_dir']=%r to Path; using default.", model_dir_val)
model_dir = Path(str(PATHS.TTS_MODELS_DIR))
model_dir.mkdir(parents=True, exist_ok=True)
# If PiperVoice already cached, reuse
if model_name in PIPER_VOICES_CACHE:
logger.info("Reusing cached Piper voice '%s'.", model_name)
return PIPER_VOICES_CACHE[model_name]
with get_model_semaphore():
if model_name in PIPER_VOICES_CACHE:
return PIPER_VOICES_CACHE[model_name]
# If Python bindings are missing, attempt CLI download first (and try re-import)
if PiperVoice is None:
logger.info("Piper Python bindings missing; attempting CLI download fallback for '%s' before failing import.", model_name)
cli_ok = False
try:
cli_ok = download_voice_cli(model_name, model_dir)
except Exception as e:
logger.warning("CLI download attempt raised: %s", e)
cli_ok = False
if cli_ok:
# attempt to re-import piper package (maybe import issue was transient)
try:
importlib.invalidate_caches()
from piper import PiperVoice as _PiperVoice # noqa: F401
from piper.synthesis import SynthesisConfig as _SynthesisConfig # noqa: F401
globals().update({"PiperVoice": _PiperVoice, "SynthesisConfig": _SynthesisConfig})
logger.info("Successfully re-imported piper after CLI download.")
except Exception:
logger.warning("Could not import piper after CLI download; bindings still unavailable.")
# If bindings still absent, we cannot load models; raise helpful error
if PiperVoice is None:
raise RuntimeError(
"Piper Python bindings are not installed or failed to import. "
"Tried CLI download fallback but python bindings are still unavailable. "
"Please install 'piper-tts' in the runtime used by this process."
)
# Now we have Piper bindings (or they were present to begin with). Attempt Python helpers.
onnx_path = None
config_path = None
# Prefer using get_voices to update the index if available
voices_info = None
try:
if get_voices:
try:
voices_info = get_voices(str(model_dir), update_voices=True)
except TypeError:
# some versions may not support update_voices kwarg
voices_info = get_voices(str(model_dir))
except Exception as e:
logger.debug("get_voices failed or unavailable: %s", e)
voices_info = None
try:
# Preferred modern helpers
if ensure_voice_exists and find_voice:
try:
ensure_voice_exists(model_name, [model_dir], model_dir, voices_info)
onnx_path, config_path = find_voice(model_name, [model_dir])
except Exception as e:
# Could be VoiceNotFoundError or other download error
logger.warning("ensure/find voice failed for %s: %s", model_name, e)
raise
elif download_voice:
# older API: call download helper directly
try:
download_voice(model_name, model_dir)
# attempt to locate files
onnx_path = model_dir / f"{model_name}.onnx"
config_path = model_dir / f"{model_name}.onnx.json"
except Exception:
logger.warning("download_voice failed for %s", model_name)
raise
else:
# No python download helper available
raise RuntimeError("No Python download helper available in installed piper package.")
except Exception as py_exc:
# Python helper route failed; try CLI fallback BEFORE giving up
logger.info("Python download route failed for '%s' (%s). Trying CLI fallback...", model_name, py_exc)
try:
cli_ok = download_voice_cli(model_name, model_dir)
except Exception as e:
logger.warning("CLI fallback attempt raised: %s", e)
cli_ok = False
if not cli_ok:
# If CLI also failed, re-raise the original python exception to preserve context
logger.error("Both Python download helpers and CLI fallback failed for '%s'.", model_name)
raise
            # CLI succeeded (or at least claimed success); try to find the files on disk.
            # _find_model_files already searches model_dir recursively, so one pass suffices.
            onnx_path, config_path = _find_model_files(model_name, model_dir)
            if not (onnx_path and config_path):
                logger.error("Model files still missing after CLI fallback for '%s'.", model_name)
                raise RuntimeError(f"Piper voice files for '{model_name}' missing after CLI fallback.")
            # continue to loading below
# Final safety check and last-resort search
if not (onnx_path and config_path):
onnx_path, config_path = _find_model_files(model_name, model_dir)
if not (onnx_path and config_path):
raise RuntimeError(f"Piper voice files for '{model_name}' are missing after attempts to download.")
# Load the PiperVoice
try:
use_cuda = bool(tts_settings.get("use_cuda", False))
voice = PiperVoice.load(str(onnx_path), config_path=str(config_path), use_cuda=use_cuda)
PIPER_VOICES_CACHE[model_name] = voice
logger.info("Loaded Piper voice '%s' from %s", model_name, onnx_path)
return voice
except Exception as e:
logger.exception("Failed to load Piper voice '%s' from files (%s, %s): %s", model_name, onnx_path, config_path, e)
raise
def _find_model_files(model_name: str, model_dir: Path):
"""
Try multiple strategies to find onnx and config files for a given model_name under model_dir.
Returns (onnx_path, config_path) or (None, None).
"""
# direct files in model_dir
onnx = model_dir / f"{model_name}.onnx"
cfg = model_dir / f"{model_name}.onnx.json"
if onnx.exists() and cfg.exists():
return onnx, cfg
# possible alternative names or nested directories: search recursively
matches_onnx = list(model_dir.rglob(f"{model_name}*.onnx"))
matches_cfg = list(model_dir.rglob(f"{model_name}*.onnx.json"))
if matches_onnx and matches_cfg:
# prefer same directory match
for o in matches_onnx:
for c in matches_cfg:
if o.parent == c.parent:
return o, c
# otherwise return first matches
return matches_onnx[0], matches_cfg[0]
# last-resort: any onnx + any json in same subdir that contain model name token
for o in model_dir.rglob("*.onnx"):
if model_name in o.name:
# try find any matching json in same dir
cands = list(o.parent.glob("*.onnx.json"))
if cands:
return o, cands[0]
return None, None
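# Layouts this resolves (illustrative):
#   <model_dir>/en_US-lessac-medium.onnx + .onnx.json    -> direct hit
#   <model_dir>/en/en_US/en_US-lessac-medium.onnx        -> found via rglob
#   any *.onnx whose name contains the model token, paired with an
#   *.onnx.json from the same directory                  -> last resort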
# ---------------------------
# CLI: list available voices
# ---------------------------
def list_voices_cli(timeout: int = 30, python_executables: Optional[List[str]] = None) -> List[str]:
"""
Run `python -m piper.download_voices` (no args) and parse output into a list of voice IDs.
Returns [] on failure.
"""
if python_executables is None:
python_executables = [sys.executable, "python3", "python"]
# Regex: voice ids look like en_US-lessac-medium (letters/digits/._-)
voice_regex = re.compile(r'^([A-Za-z0-9_\-\.]+)')
for py in python_executables:
cmd = [py, "-m", "piper.download_voices"]
try:
logger.debug("Trying Piper CLI list: %s", shlex.join(cmd))
cp = subprocess.run(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
check=True,
timeout=timeout,
)
out = cp.stdout.strip()
# If stdout empty, sometimes the script writes to stderr
if not out:
out = cp.stderr.strip()
if not out:
logger.debug("Piper CLI listed nothing (empty output) for %s", py)
continue
voices = []
for line in out.splitlines():
line = line.strip()
if not line:
continue
# Try to extract first token that matches voice id pattern
m = voice_regex.match(line)
if m:
                    # Accept the matched token as a voice id; headings are rare
                    # and harmless enough that no further filtering is done.
                    voices.append(m.group(1))
else:
# Also handle lines like " - en_US-lessac-medium: description"
parts = re.split(r'[:\s]+', line)
if parts:
candidate = parts[0].lstrip('-').strip()
if candidate:
voices.append(candidate)
# Dedupe while preserving order
seen = set()
dedup = []
for v in voices:
if v not in seen:
seen.add(v)
dedup.append(v)
logger.info("Piper CLI list returned %d voices via %s", len(dedup), py)
return dedup
except subprocess.CalledProcessError as e:
logger.debug("Piper CLI list (%s) non-zero exit. stdout=%s stderr=%s", py, e.stdout, e.stderr)
except FileNotFoundError:
logger.debug("Python executable not found: %s", py)
except subprocess.TimeoutExpired:
logger.warning("Piper CLI list timed out for %s", py)
except Exception as e:
logger.exception("Unexpected error running Piper CLI list with %s: %s", py, e)
logger.error("All Piper CLI list attempts failed.")
return []
# ---------------------------
# CLI: download a voice
# ---------------------------
def download_voice_cli(model_name: str, model_dir: Path, python_executables: Optional[List[str]] = None, timeout: int = 300) -> bool:
"""
Try to download a Piper voice using CLI:
python -m piper.download_voices <model_name> --data-dir <model_dir>
Returns True if the CLI ran and expected files exist afterwards (best effort).
"""
if python_executables is None:
python_executables = [sys.executable, "python3", "python"]
model_dir = Path(model_dir)
model_dir.mkdir(parents=True, exist_ok=True)
for py in python_executables:
cmd = [py, "-m", "piper.download_voices", model_name, "--data-dir", str(model_dir)]
try:
logger.info("Trying Piper CLI download: %s", shlex.join(cmd))
cp = subprocess.run(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
check=True,
timeout=timeout,
)
logger.debug("Piper CLI download stdout: %s", cp.stdout)
logger.debug("Piper CLI download stderr: %s", cp.stderr)
# Heuristic success check
onnx = model_dir / f"{model_name}.onnx"
cfg = model_dir / f"{model_name}.onnx.json"
if onnx.exists() and cfg.exists():
logger.info("Piper CLI created expected files for %s", model_name)
return True
# Some versions might create nested dirs; treat non-error CLI execution as success (caller will re-check)
return True
except subprocess.CalledProcessError as e:
logger.warning("Piper CLI (%s) returned non-zero exit. stdout: %s; stderr: %s", py, e.stdout, e.stderr)
except FileNotFoundError:
logger.debug("Python executable %s not found.", py)
except subprocess.TimeoutExpired:
logger.warning("Piper CLI call timed out for python %s", py)
except Exception as e:
logger.exception("Unexpected error running Piper CLI download with %s: %s", py, e)
logger.error("All Piper CLI attempts failed for model %s", model_name)
return False
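# Usage sketch (illustrative):
#   download_voice_cli("en_US-lessac-medium", PATHS.TTS_MODELS_DIR)
# runs `python -m piper.download_voices en_US-lessac-medium --data-dir ...`
# and returns True when the CLI exits cleanly; callers re-check files on disk.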
# ---------------------------
# Safe get_voices wrapper
# ---------------------------
def safe_get_voices(model_dir: Path) -> List[Dict]:
"""
Try to call the in-Python get_voices(..., update_voices=True) and return a list of dicts.
If that fails, fall back to list_voices_cli() and return a list of simple dicts:
[{"id": "en_US-lessac-medium", "name": "...", "local": False}, ...]
    Keeps the shape flexible so callers can consume it with minimal changes.
"""
# Prefer Python API if available
try:
        if get_voices:  # imported (or set to None) at module top
# Ensure up-to-date index (like CLI)
raw = get_voices(str(model_dir), update_voices=True)
# get_voices may already return the desired structure; normalise to a list of dicts
if isinstance(raw, dict):
# some versions return mapping id->meta
items = []
for vid, meta in raw.items():
d = {"id": vid}
if isinstance(meta, dict):
d.update(meta)
items.append(d)
return items
elif isinstance(raw, list):
return raw
else:
# unknown format -> fall back to CLI
logger.debug("get_voices returned unexpected type; falling back to CLI list.")
except Exception as e:
logger.warning("In-Python get_voices failed: %s. Falling back to CLI listing.", e)
# CLI fallback: parse voice ids and create simple dicts
cli_list = list_voices_cli()
results = [{"id": vid, "name": vid, "local": False} for vid in cli_list]
return results
def list_kokoro_voices_cli(timeout: int = 60) -> List[str]:
"""
Run `kokoro-tts --help-voices` and parse the output for available models.
Returns [] on failure.
"""
model_path = PATHS.KOKORO_MODEL_FILE
voices_path = PATHS.KOKORO_VOICES_FILE
if not (model_path.exists() and voices_path.exists()):
logger.warning("Cannot list Kokoro TTS voices because model/voices files are missing.")
return []
cmd = ["kokoro-tts", "--help-voices", "--model", str(model_path), "--voices", str(voices_path)]
try:
logger.debug("Trying Kokoro TTS CLI list: %s", shlex.join(cmd))
cp = subprocess.run(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
check=True,
timeout=timeout,
)
out = cp.stdout.strip()
if not out:
out = cp.stderr.strip()
if not out:
logger.warning("Kokoro TTS CLI list returned no output.")
return []
voices = []
voice_pattern = re.compile(r'^\s*\d+\.\s+([a-z]{2,3}_[a-zA-Z0-9]+)$')
for line in out.splitlines():
line = line.strip()
match = voice_pattern.match(line)
if match:
voices.append(match.group(1))
logger.info("Kokoro TTS CLI list returned %d voices.", len(voices))
return sorted(list(set(voices)))
except FileNotFoundError:
logger.info("Kokoro TTS ('kokoro-tts' command) not found in PATH. Kokoro TTS support disabled.")
return []
except subprocess.CalledProcessError as e:
logger.error("Kokoro TTS CLI list command failed. stderr: %s", e.stderr[:1000])
return []
except subprocess.TimeoutExpired:
logger.warning("Kokoro TTS CLI list command timed out.")
return []
except Exception as e:
logger.exception("Unexpected error running Kokoro TTS CLI list: %s", e)
return []
def list_kokoro_languages_cli(timeout: int = 60) -> List[str]:
"""
Run `kokoro-tts --help-languages` and parse the output for available languages.
Returns [] on failure.
"""
model_path = PATHS.KOKORO_MODEL_FILE
voices_path = PATHS.KOKORO_VOICES_FILE
if not (model_path.exists() and voices_path.exists()):
logger.warning("Cannot list Kokoro TTS languages because model/voices files are missing.")
return []
cmd = ["kokoro-tts", "--help-languages", "--model", str(model_path), "--voices", str(voices_path)]
try:
logger.debug("Trying Kokoro TTS language list: %s", shlex.join(cmd))
cp = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=True, timeout=timeout)
out = cp.stdout.strip()
if not out:
out = cp.stderr.strip()
if not out:
logger.warning("Kokoro TTS language list returned no output.")
return []
languages = []
lang_pattern = re.compile(r'^\s*([a-z]{2,3}(?:-[a-z]{2,3})?)$')
for line in out.splitlines():
line = line.strip()
if line.lower().startswith("supported languages"):
continue
match = lang_pattern.match(line)
if match:
languages.append(match.group(1))
logger.info("Kokoro TTS language list returned %d languages.", len(languages))
return sorted(list(set(languages)))
except FileNotFoundError:
logger.info("Kokoro TTS ('kokoro-tts' command) not found in PATH. Kokoro TTS support disabled.")
return []
except subprocess.CalledProcessError as e:
logger.error("Kokoro TTS language list command failed. stderr: %s", e.stderr[:1000])
return []
except subprocess.TimeoutExpired:
logger.warning("Kokoro TTS language list command timed out.")
return []
except Exception as e:
logger.exception("Unexpected error running Kokoro TTS language list: %s", e)
return []
def run_command(
argv: List[str],
timeout: int = 300
) -> subprocess.CompletedProcess:
"""
Executes a command, captures its output, and handles timeouts and errors.
Uses resource limits for child processes. This is a simplified, more robust
implementation using subprocess.run.
"""
logger.debug("Executing command: %s with timeout=%ss", " ".join(shlex.quote(s) for s in argv), timeout)
preexec = globals().get("_limit_resources_preexec", None)
try:
# subprocess.run handles timeout, output capturing, and error checking.
result = subprocess.run(
argv,
capture_output=True,
text=True,
timeout=timeout,
check=True, # Raises CalledProcessError on non-zero exit
preexec_fn=preexec
)
logger.debug("Command completed successfully: %s", " ".join(shlex.quote(s) for s in argv))
return result
except FileNotFoundError:
msg = f"Command not found: {argv[0]}"
logger.error(msg)
raise Exception(msg) from None
except subprocess.TimeoutExpired as e:
msg = f"Command timed out after {timeout}s: {' '.join(shlex.quote(s) for s in argv)}"
logger.error(msg)
raise Exception(msg) from e
except subprocess.CalledProcessError as e:
snippet = (e.stderr or "")[:1000]
msg = f"Command failed with exit code {e.returncode}. Stderr: {snippet}"
logger.error(msg)
raise Exception(msg) from e
except Exception as e:
msg = f"Unexpected error launching command: {e}"
logger.exception(msg)
raise Exception(msg) from e
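# Note: preexec_fn runs in the forked child before exec; the stdlib warns it
# can deadlock in multi-threaded programs, so _limit_resources_preexec is kept
# minimal (it only calls setrlimit).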
def validate_and_build_command(template_str: str, mapping: Dict[str, str]) -> TypingList[str]:
fmt = Formatter()
used = {fname for _, fname, _, _ in fmt.parse(template_str) if fname}
ALLOWED_VARS = {
"input", "output", "output_dir", "output_ext", "quality", "speed", "preset",
"device", "dpi", "samplerate", "bitdepth", "filter", "model_name",
"model_path", "voices_path", "lang"
}
bad = used - ALLOWED_VARS
if bad:
raise ValueError(f"Command template contains disallowed placeholders: {bad}")
safe_mapping = dict(mapping)
for name in used:
if name not in safe_mapping:
safe_mapping[name] = safe_mapping.get("output_ext", "") if name == "filter" else ""
# Securely build the command by splitting the template BEFORE formatting.
# This prevents argument injection if a value in the mapping (e.g. a filename)
# contains spaces or other shell-special characters.
command_parts = shlex.split(template_str)
formatted_command = [part.format(**safe_mapping) for part in command_parts]
# Filter out any empty strings that result from empty optional placeholders
return [part for part in formatted_command if part]
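# Example (illustrative): for template 'gs -sDEVICE={device} -o {output} {input}'
# and mapping {"device": "pdfwrite", "output": "/app/processed/out.pdf",
#              "input": "/app/uploads/my file; rm -rf.pdf"}, the result is
#   ['gs', '-sDEVICE=pdfwrite', '-o', '/app/processed/out.pdf',
#    '/app/uploads/my file; rm -rf.pdf']
# Splitting the template before formatting keeps the hostile filename a single
# argv element, and no shell ever parses it.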
# --- TASK RUNNERS ---
@huey.task()
def run_transcription_task(job_id: str, input_path_str: str, output_path_str: str, model_size: str, whisper_settings: dict, app_config: dict, base_url: str):
db = SessionLocal()
input_path = Path(input_path_str)
output_path = Path(output_path_str)
# --- Constants ---
# Time in seconds between database checks for progress updates and cancellations.
# This avoids hammering the database on every single segment.
DB_POLL_INTERVAL_SECONDS = 5
try:
job = get_job(db, job_id)
if not job:
logger.warning(f"Job {job_id} not found. Aborting task.")
return
if job.status == 'cancelled':
logger.info(f"Job {job_id} was already cancelled before starting. Aborting.")
return
update_job_status(db, job_id, "processing", progress=0)
model = get_whisper_model(model_size, whisper_settings)
logger.info(f"Starting transcription for job {job_id} with model '{model_size}'")
segments_generator, info = model.transcribe(str(input_path), beam_size=5)
logger.info(f"Detected language: {info.language} with probability {info.language_probability:.2f} for a duration of {info.duration:.2f}s")
last_update_time = time.time()
# Use a temporary file to ensure atomic writes. The final file will only appear
# once the transcription is fully and successfully written.
tmp_output_path = output_path.with_name(f"{output_path.stem}.tmp-{uuid.uuid4().hex}{output_path.suffix}")
# Store a small preview in memory for the final update without holding the whole transcript.
preview_segments = []
PREVIEW_MAX_LENGTH = 1000 # characters
current_preview_length = 0
with tmp_output_path.open("w", encoding="utf-8") as f:
for segment in segments_generator:
segment_text = segment.text.strip()
f.write(segment_text + "\n")
# Build a small preview
if current_preview_length < PREVIEW_MAX_LENGTH:
preview_segments.append(segment_text)
current_preview_length += len(segment_text)
current_time = time.time()
if current_time - last_update_time > DB_POLL_INTERVAL_SECONDS:
last_update_time = current_time
# Check for cancellation without overwhelming the DB
job_check = get_job(db, job_id)
if job_check and job_check.status == 'cancelled':
logger.info(f"Job {job_id} cancelled during transcription. Stopping.")
# The temporary file will be cleaned up in the finally block
return
# Update progress
if info.duration > 0:
progress = int((segment.end / info.duration) * 100)
update_job_status(db, job_id, "processing", progress=progress)
tmp_output_path.replace(output_path)
transcript_preview = " ".join(preview_segments)
if len(transcript_preview) > PREVIEW_MAX_LENGTH:
transcript_preview = transcript_preview[:PREVIEW_MAX_LENGTH] + "..."
mark_job_as_completed(db, job_id, output_filepath_str=output_path_str, preview=transcript_preview)
logger.info(f"Transcription for job {job_id} completed successfully.")
except Exception as e:
logger.exception(f"An unexpected error occurred during transcription for job {job_id}")
update_job_status(db, job_id, "failed", error=str(e))
finally:
# This block executes whether the task succeeded, failed, or was cancelled and returned.
logger.debug(f"Performing cleanup for job {job_id}")
# Clean up the temporary file if it still exists (e.g., due to cancellation)
if 'tmp_output_path' in locals() and tmp_output_path.exists():
try:
tmp_output_path.unlink()
logger.debug(f"Removed temporary file: {tmp_output_path}")
except OSError as e:
logger.error(f"Error removing temporary file {tmp_output_path}: {e}")
# Clean up the original input file
try:
# First, ensure we are not deleting from an unexpected directory
ensure_path_is_safe(input_path, [PATHS.UPLOADS_DIR, PATHS.CHUNK_TMP_DIR])
input_path.unlink(missing_ok=True)
logger.debug(f"Removed input file: {input_path}")
except Exception as e:
logger.exception(f"Failed to cleanup input file {input_path} for job {job_id}: {e}")
if db:
db.close()
# If this was a sub-job, trigger a progress update for its parent.
db_for_check = SessionLocal()
try:
job = get_job(db_for_check, job_id)
if job and job.parent_job_id:
_update_parent_zip_job_progress(job.parent_job_id)
finally:
db_for_check.close()
# Send notification last, after all state has been finalized.
send_webhook_notification(job_id, app_config, base_url)
@huey.task()
def run_tts_task(job_id: str, input_path_str: str, output_path_str: str, model_name: str, tts_settings: dict, app_config: dict, base_url: str):
db = SessionLocal()
input_path = Path(input_path_str)
try:
job = get_job(db, job_id)
if not job or job.status == 'cancelled':
return
update_job_status(db, job_id, "processing")
engine, actual_model_name = "piper", model_name
if '/' in model_name:
parts = model_name.split('/', 1)
engine = parts[0]
actual_model_name = parts[1]
logger.info(f"Starting TTS for job {job_id} using engine '{engine}' with model '{actual_model_name}'")
out_path = Path(output_path_str)
tmp_out = out_path.with_name(f"{out_path.stem}.tmp-{uuid.uuid4().hex}{out_path.suffix}")
if engine == "piper":
piper_settings = tts_settings.get("piper", {})
voice = get_piper_voice(actual_model_name, piper_settings)
with open(input_path, 'r', encoding='utf-8') as f:
text_to_speak = f.read()
synthesis_params = piper_settings.get("synthesis_config", {})
synthesis_config = SynthesisConfig(**synthesis_params) if SynthesisConfig else None
with wave.open(str(tmp_out), "wb") as wav_file:
wav_file.setnchannels(1)
wav_file.setsampwidth(2)
wav_file.setframerate(voice.config.sample_rate)
voice.synthesize_wav(text_to_speak, wav_file, synthesis_config)
elif engine == "kokoro":
kokoro_settings = tts_settings.get("kokoro", {})
command_template_str = kokoro_settings.get("command_template")
if not command_template_str:
raise ValueError("Kokoro TTS command_template is not defined in settings.")
try:
lang, voice_name = actual_model_name.split('/', 1)
except ValueError:
raise ValueError(f"Invalid Kokoro model format. Expected 'lang/voice', but got '{actual_model_name}'.")
mapping = {
"input": str(input_path),
"output": str(tmp_out),
"lang": lang,
"model_name": voice_name,
"model_path": str(PATHS.KOKORO_MODEL_FILE),
"voices_path": str(PATHS.KOKORO_VOICES_FILE),
}
command = validate_and_build_command(command_template_str, mapping)
logger.info(f"Executing Kokoro TTS command: {' '.join(command)}")
run_command(command, timeout=kokoro_settings.get("timeout", 300))
if not tmp_out.exists():
raise FileNotFoundError("Kokoro TTS command did not produce an output file.")
else:
raise ValueError(f"Unsupported TTS engine: {engine}")
tmp_out.replace(out_path)
mark_job_as_completed(db, job_id, output_filepath_str=output_path_str, preview="Successfully generated audio.")
logger.info(f"TTS for job {job_id} completed.")
except Exception as e:
logger.exception(f"ERROR during TTS for job {job_id}")
update_job_status(db, job_id, "failed", error=f"TTS failed: {e}")
finally:
try:
ensure_path_is_safe(input_path, [PATHS.UPLOADS_DIR, PATHS.CHUNK_TMP_DIR])
input_path.unlink(missing_ok=True)
except Exception:
logger.exception("Failed to cleanup input file after TTS.")
db.close()
# If this was a sub-job, trigger a progress update for its parent.
db_for_check = SessionLocal()
try:
job = get_job(db_for_check, job_id)
if job and job.parent_job_id:
_update_parent_zip_job_progress(job.parent_job_id)
finally:
db_for_check.close()
send_webhook_notification(job_id, app_config, base_url)
@huey.task()
def run_pdf_ocr_task(job_id: str, input_path_str: str, output_path_str: str, ocr_settings: dict, app_config: dict, base_url: str):
db = SessionLocal()
input_path = Path(input_path_str)
try:
job = get_job(db, job_id)
if not job or job.status == 'cancelled':
return
update_job_status(db, job_id, "processing")
logger.info(f"Starting PDF OCR for job {job_id}")
ocrmypdf.ocr(str(input_path), str(output_path_str),
deskew=ocr_settings.get('deskew', True),
force_ocr=ocr_settings.get('force_ocr', True),
clean=ocr_settings.get('clean', True),
optimize=ocr_settings.get('optimize', 1),
progress_bar=False)
with open(output_path_str, "rb") as f:
reader = pypdf.PdfReader(f)
preview = "\n".join(page.extract_text() or "" for page in reader.pages)
mark_job_as_completed(db, job_id, output_filepath_str=output_path_str, preview=preview)
logger.info(f"PDF OCR for job {job_id} completed.")
except Exception as e:
logger.exception(f"ERROR during PDF OCR for job {job_id}")
update_job_status(db, job_id, "failed", error=f"PDF OCR failed: {e}")
finally:
try:
ensure_path_is_safe(input_path, [PATHS.UPLOADS_DIR])
input_path.unlink(missing_ok=True)
except Exception:
logger.exception("Failed to cleanup input file after PDF OCR.")
db.close()
# If this was a sub-job, trigger a progress update for its parent.
db_for_check = SessionLocal()
try:
job = get_job(db_for_check, job_id)
if job and job.parent_job_id:
_update_parent_zip_job_progress(job.parent_job_id)
finally:
db_for_check.close()
send_webhook_notification(job_id, app_config, base_url)
@huey.task()
def run_image_ocr_task(job_id: str, input_path_str: str, output_path_str: str, app_config: dict, base_url: str):
db = SessionLocal()
input_path = Path(input_path_str)
out_path = Path(output_path_str)
try:
job = get_job(db, job_id)
if not job or job.status == "cancelled":
return
update_job_status(db, job_id, "processing", progress=10)
logger.info(f"Starting Image OCR for job {job_id} - {input_path}")
# open image and gather frames (support multi-frame TIFF)
try:
pil_img = Image.open(str(input_path))
except UnidentifiedImageError as e:
raise RuntimeError(f"Cannot identify/open input image: {e}")
frames = []
try:
# some images support n_frames (multi-page TIFF); iterate safely
n_frames = getattr(pil_img, "n_frames", 1)
for i in range(n_frames):
pil_img.seek(i)
# copy the frame to avoid problems when the original image object is closed
frames.append(pil_img.convert("RGB").copy())
except Exception:
# fallback: single frame
frames = [pil_img.convert("RGB")]
update_job_status(db, job_id, "processing", progress=30)
pdf_bytes_list = []
text_parts = []
for idx, frame in enumerate(frames):
# produce searchable PDF bytes for the frame and plain text as well
try:
pdf_bytes = pytesseract.image_to_pdf_or_hocr(frame, extension="pdf")
except TesseractNotFoundError as e:
raise RuntimeError("Tesseract not found. Ensure Tesseract OCR is installed and in PATH.") from e
except Exception as e:
raise RuntimeError(f"Failed to run Tesseract on frame {idx}: {e}") from e
pdf_bytes_list.append(pdf_bytes)
# also extract plain text for preview and possible fallback
try:
page_text = pytesseract.image_to_string(frame)
except Exception:
page_text = ""
text_parts.append(page_text)
# update progress incrementally
prog = 30 + int((idx + 1) / max(1, len(frames)) * 50)
update_job_status(db, job_id, "processing", progress=min(prog, 80))
# merge per-page pdfs if multiple frames
final_pdf_bytes = None
if len(pdf_bytes_list) == 1:
final_pdf_bytes = pdf_bytes_list[0]
else:
if _HAS_PYPDF2:
merger = PdfMerger()
for b in pdf_bytes_list:
merger.append(io.BytesIO(b))
out_buffer = io.BytesIO()
merger.write(out_buffer)
merger.close()
final_pdf_bytes = out_buffer.getvalue()
            else:
                # PyPDF2 is not installed, so the per-page PDFs cannot be merged;
                # fall back to writing only the first page and flag the loss in the preview.
                logger.warning("PyPDF2 not available; only the first frame will be written to output PDF.")
final_pdf_bytes = pdf_bytes_list[0]
text_parts.insert(0, "[WARNING] Multiple frames detected but PyPDF2 not available; only first page saved.\n")
# write out atomically
tmp_out = out_path.with_name(f"{out_path.stem}.tmp-{uuid.uuid4().hex}{out_path.suffix or '.pdf'}")
try:
tmp_out.parent.mkdir(parents=True, exist_ok=True)
with tmp_out.open("wb") as f:
f.write(final_pdf_bytes)
tmp_out.replace(out_path)
except Exception as e:
raise RuntimeError(f"Failed writing output PDF to {out_path}: {e}") from e
# create a preview from the recognized text (limit length)
full_text = "\n\n".join(text_parts).strip()
        preview = full_text[:1000] + ("..." if len(full_text) > 1000 else "")
mark_job_as_completed(db, job_id, output_filepath_str=str(out_path), preview=preview)
update_job_status(db, job_id, "completed", progress=100)
logger.info(f"Image OCR for job {job_id} completed. Output: {out_path}")
except TesseractNotFoundError:
logger.exception(f"Tesseract not found for job {job_id}")
update_job_status(db, job_id, "failed", error="Image OCR failed: Tesseract not found on server.")
except Exception as e:
logger.exception(f"ERROR during Image OCR for job {job_id}: {e}")
update_job_status(db, job_id, "failed", error=f"Image OCR failed: {e}")
finally:
# cleanup input file (but only if it lives in allowed uploads dir)
try:
ensure_path_is_safe(input_path, [PATHS.UPLOADS_DIR])
input_path.unlink(missing_ok=True)
except Exception:
logger.exception("Failed to cleanup input file after Image OCR.")
try:
db.close()
except Exception:
logger.exception("Failed to close DB session after Image OCR.")
# If this was a sub-job, trigger a progress update for its parent.
db_for_check = SessionLocal()
try:
job = get_job(db_for_check, job_id)
if job and job.parent_job_id:
_update_parent_zip_job_progress(job.parent_job_id)
finally:
db_for_check.close()
# send webhook regardless of success/failure (keeps original behavior)
try:
send_webhook_notification(job_id, app_config, base_url)
except Exception:
logger.exception("Failed to send webhook notification after Image OCR.")
@huey.task()
def run_conversion_task(job_id: str,
input_path_str: str,
output_path_str: str,
tool: str,
task_key: str,
conversion_tools_config: dict,
app_config: dict,
base_url: str):
"""
Drop-in replacement for conversion task.
- Uses improved run_command for short operations (resource-limited).
- Uses cancellable Popen runner for long-running conversion to respond to DB cancellations.
"""
db = SessionLocal()
input_path = Path(input_path_str)
output_path = Path(output_path_str)
# localize helpers for speed
_get_job = get_job
_update_job_status = update_job_status
_validate_build = validate_and_build_command
_mark_completed = mark_job_as_completed
_ensure_safe = ensure_path_is_safe
_send_webhook = send_webhook_notification
temp_input_file: Optional[Path] = None
temp_output_file: Optional[Path] = None
POLL_INTERVAL = 1.0
STDERR_SNIPPET = 4000
def _parse_task_key(tool_name: str, tk: str, tool_cfg: dict, mapping: dict):
try:
if tool_name.startswith("ghostscript"):
parts = tk.split("_", 1)
device = parts[0] if parts and parts[0] else ""
setting = parts[1] if len(parts) > 1 else ""
mapping.update({"device": device, "dpi": setting, "preset": setting})
elif tool_name == "pngquant":
parts = tk.split("_", 1)
quality_key = parts[1] if len(parts) > 1 else (parts[0] if parts else "mq")
quality_map = {"hq": "80-95", "mq": "65-80", "fast": "65-80"}
speed_map = {"hq": "1", "mq": "3", "fast": "11"}
mapping.update({"quality": quality_map.get(quality_key, "65-80"),
"speed": speed_map.get(quality_key, "3")})
elif tool_name == "sox":
parts = tk.split("_")
if len(parts) >= 3:
rate_token = parts[-2]
depth_token = parts[-1]
elif len(parts) == 2:
rate_token = parts[-1]
depth_token = ""
else:
rate_token = ""
depth_token = ""
rate_val = rate_token.replace("k", "000") if rate_token else ""
if depth_token:
depth_val = ('-b' + depth_token.replace('b', '')) if 'b' in depth_token else depth_token
else:
depth_val = ''
mapping.update({"samplerate": rate_val, "bitdepth": depth_val})
elif tool_name == "mozjpeg":
parts = tk.split("_", 1)
quality_token = parts[1] if len(parts) > 1 else (parts[0] if parts else "")
quality = quality_token.replace("q", "") if quality_token else ""
mapping.update({"quality": quality})
elif tool_name == "libreoffice":
target_ext = output_path.suffix.lstrip('.')
filter_val = tool_cfg.get("filters", {}).get(target_ext, target_ext)
mapping["filter"] = filter_val
except Exception:
logger.exception("Failed to parse task_key for tool %s; continuing with defaults.", tool_name)
def _run_cancellable_command(command: List[str], timeout: int):
"""
Run command with Popen and poll the DB for cancellation. Enforce timeout.
Returns CompletedProcess-like on success. Raises Exception on failure/timeout/cancel.
"""
preexec = globals().get("_limit_resources_preexec", None)
logger.debug("Launching conversion subprocess: %s", " ".join(shlex.quote(c) for c in command))
proc = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, preexec_fn=preexec)
start = time.monotonic()
stderr_accum = []
stderr_len = 0
STDERR_LIMIT = STDERR_SNIPPET
try:
while True:
ret = proc.poll()
# Check job status
job_check = _get_job(db, job_id)
if job_check is None:
logger.warning("Job %s disappeared; killing conversion process.", job_id)
try:
proc.kill()
except Exception:
pass
raise Exception("Job disappeared during conversion")
if job_check.status == "cancelled":
logger.info("Job %s cancelled; terminating conversion process.", job_id)
try:
proc.kill()
except Exception:
pass
raise Exception("Conversion cancelled")
if ret is not None:
# process done - read remaining stderr/stdout safely
try:
out, err = proc.communicate(timeout=2)
except Exception:
out, err = "", ""
try:
if proc.stderr:
err = proc.stderr.read(STDERR_LIMIT)
except Exception:
pass
if err and len(err) > STDERR_LIMIT:
err = err[-STDERR_LIMIT:]
if ret != 0:
msg = (err or "")[:STDERR_LIMIT]
raise Exception(f"Conversion command failed (rc={ret}): {msg}")
return subprocess.CompletedProcess(args=command, returncode=ret, stdout=out, stderr=err)
# timeout check
elapsed = time.monotonic() - start
if timeout and elapsed > timeout:
logger.warning("Conversion command timed out after %ss; terminating.", timeout)
try:
proc.kill()
except Exception:
pass
raise Exception("Conversion command timed out")
time.sleep(POLL_INTERVAL)
finally:
try:
if proc.stdout:
proc.stdout.close()
if proc.stderr:
proc.stderr.close()
except Exception:
pass
try:
job = _get_job(db, job_id)
if not job:
logger.warning("Job %s not found; aborting conversion.", job_id)
return
if job.status == "cancelled":
logger.info("Job %s already cancelled; aborting conversion.", job_id)
return
_update_job_status(db, job_id, "processing", progress=25)
logger.info("Starting conversion for job %s using %s with task %s", job_id, tool, task_key)
tool_config = conversion_tools_config.get(tool)
if not tool_config:
raise ValueError(f"Unknown conversion tool: {tool}")
current_input_path = input_path
# Pre-conversion step for mozjpeg uses improved run_command (resource-limited)
if tool == "mozjpeg":
temp_input_file = input_path.with_suffix('.temp.ppm')
logger.info("Pre-converting for MozJPEG: %s -> %s", input_path, temp_input_file)
vips_bin = shutil.which("vips") or "vips"
pre_conv_cmd = [vips_bin, "copy", str(input_path), str(temp_input_file)]
try:
run_command(pre_conv_cmd, timeout=int(tool_config.get("timeout", 300)))
except Exception as ex:
err_msg = str(ex)
short_err = (err_msg or "")[:STDERR_SNIPPET]
logger.exception("MozJPEG pre-conversion failed: %s", short_err)
raise Exception(f"MozJPEG pre-conversion to PPM failed: {short_err}")
current_input_path = temp_input_file
_update_job_status(db, job_id, "processing", progress=50)
# Prepare atomic temp output on same FS
output_path.parent.mkdir(parents=True, exist_ok=True)
temp_output_file = output_path.with_name(f"{output_path.stem}.tmp-{uuid.uuid4().hex}{output_path.suffix}")
mapping = {
"input": str(current_input_path),
"output": str(temp_output_file),
"output_dir": str(output_path.parent),
"output_ext": output_path.suffix.lstrip('.'),
}
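# These placeholders (plus any tool-specific keys added by the _parse_task_key call below)
# are substituted into the tool's command_template from the configuration; e.g. a template
# such as "convert {input} {output}" (illustrative) expands to the concrete file paths.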
_parse_task_key(tool, task_key, tool_config, mapping)
command_template_str = tool_config.get("command_template")
if not command_template_str:
raise ValueError(f"Tool '{tool}' missing 'command_template' in configuration.")
command = _validate_build(command_template_str, mapping)
if not isinstance(command, (list, tuple)) or not command:
raise ValueError("validate_and_build_command must return a non-empty list/tuple command.")
command = [str(x) for x in command]
logger.info("Executing command: %s", " ".join(shlex.quote(c) for c in command))
# Run main conversion in cancellable manner
timeout_val = int(tool_config.get("timeout", 300))
# call the cancellable runner above
result = _run_cancellable_command(command, timeout=timeout_val)
# If successful and temp output exists, move it into place atomically
if temp_output_file and temp_output_file.exists():
temp_output_file.replace(output_path)
_mark_completed(db, job_id, output_filepath_str=str(output_path), preview="Successfully converted file.")
logger.info("Conversion for job %s completed.", job_id)
except Exception as e:
logger.exception("ERROR during conversion for job %s: %s", job_id, e)
try:
_update_job_status(db, job_id, "failed", error=f"Conversion failed: {e}")
except Exception:
logger.exception("Failed to update job status to failed after conversion error.")
finally:
# clean main input
try:
_ensure_safe(input_path, [PATHS.UPLOADS_DIR, PATHS.CHUNK_TMP_DIR])
input_path.unlink(missing_ok=True)
except Exception:
logger.exception("Failed to cleanup main input file after conversion.")
# cleanup temp input
if temp_input_file:
try:
temp_input_file_path = Path(temp_input_file)
_ensure_safe(temp_input_file_path, [PATHS.UPLOADS_DIR, PATHS.PROCESSED_DIR])
temp_input_file_path.unlink(missing_ok=True)
except Exception:
logger.exception("Failed to cleanup temp input file after conversion.")
if temp_output_file:
try:
temp_output_file_path = Path(temp_output_file)
_ensure_safe(temp_output_file_path, [PATHS.UPLOADS_DIR, PATHS.PROCESSED_DIR])
temp_output_file_path.unlink(missing_ok=True)
except Exception:
logger.exception("Failed to cleanup temp output file after conversion.")
try:
db.close()
except Exception:
logger.exception("Failed to close DB session after conversion.")
# If this was a sub-job, trigger a progress update for its parent.
db_for_check = SessionLocal()
try:
job = get_job(db_for_check, job_id)
if job and job.parent_job_id:
_update_parent_zip_job_progress(job.parent_job_id)
finally:
db_for_check.close()
try:
gc.collect()
except Exception:
pass
try:
_send_webhook(job_id, app_config, base_url)
except Exception:
logger.exception("Failed to send webhook notification after conversion.")
def dispatch_single_file_job(original_filename: str, input_filepath: str, task_type: str, user: dict, db: Session, app_config: Dict, base_url: str, job_id: str | None = None, options: Dict = None, parent_job_id: str | None = None):
"""Helper to create and dispatch a job for a single file."""
if options is None:
options = {}
# If no job_id is passed, generate one. This is for sub-tasks from zips.
if job_id is None:
job_id = uuid.uuid4().hex
safe_filename = secure_filename(original_filename)
final_path = Path(input_filepath)
# Ensure the input file exists before creating a job
if not final_path.exists():
logger.error(f"Input file does not exist, cannot dispatch job: {input_filepath}")
return
job_data = JobCreate(
id=job_id, user_id=user['sub'], task_type=task_type,
original_filename=original_filename, input_filepath=str(final_path),
input_filesize=final_path.stat().st_size,
parent_job_id=parent_job_id
)
if task_type == "transcription":
stem = Path(safe_filename).stem
processed_path = PATHS.PROCESSED_DIR / f"{stem}_{job_id}.txt"
job_data.processed_filepath = str(processed_path)
create_job(db=db, job=job_data)
run_transcription_task(job_data.id, str(final_path), str(processed_path), options.get("model_size", "base"), app_config.get("transcription_settings", {}).get("whisper", {}), app_config, base_url)
elif task_type == "tts":
tts_config = app_config.get("tts_settings", {})
stem = Path(safe_filename).stem
processed_path = PATHS.PROCESSED_DIR / f"{stem}_{job_id}.wav"
job_data.processed_filepath = str(processed_path)
create_job(db=db, job=job_data)
run_tts_task(job_data.id, str(final_path), str(processed_path), options.get("model_name"), tts_config, app_config, base_url)
elif task_type == "ocr":
stem, suffix = Path(safe_filename).stem, Path(safe_filename).suffix.lower()
IMAGE_EXTENSIONS = {'.png', '.jpg', '.jpeg', '.tiff', '.tif', '.bmp', '.webp'}
if suffix not in IMAGE_EXTENSIONS and suffix != '.pdf':
logger.warning(f"Skipping unsupported file type for OCR: {original_filename}")
# Clean up the orphaned file from the zip extraction
final_path.unlink(missing_ok=True)
return
processed_path = PATHS.PROCESSED_DIR / f"{stem}_{job_id}.pdf"
job_data.processed_filepath = str(processed_path)
create_job(db=db, job=job_data)
if suffix in IMAGE_EXTENSIONS:
run_image_ocr_task(job_data.id, str(final_path), str(processed_path), app_config, base_url)
else:
run_pdf_ocr_task(job_data.id, str(final_path), str(processed_path), app_config.get("ocr_settings", {}).get("ocrmypdf", {}), app_config, base_url)
elif task_type == "conversion":
try:
logger.info(f"Preparing to dispatch conversion job for file '{original_filename}' with requested format '{options.get('output_format')}'")
all_tools = app_config.get("conversion_tools", {}).keys()
logger.info(f"Available conversion tools: {', '.join(all_tools)}")
tool, task_key = _parse_tool_and_task_key(options.get("output_format"), all_tools)
logger.info(f"Dispatching conversion job using tool '{tool}' with task key '{task_key}' for file '{original_filename}'")
except (AttributeError, ValueError):
if parent_job_id:
logger.warning(f"Skipping file '{original_filename}' from batch job '{parent_job_id}' as it is not applicable for the selected conversion format '{options.get('output_format')}'.")
final_path.unlink(missing_ok=True)
return
else:
logger.error(f"Invalid or missing output_format for conversion of {original_filename}")
final_path.unlink(missing_ok=True)
return
original_stem = Path(safe_filename).stem
if tool == 'pandoc_academic':
processed_path = PATHS.PROCESSED_DIR / f"{original_stem}_{job_id}.pdf"
job_data.processed_filepath = str(processed_path)
job_data.task_type = 'academic_pandoc' # Use a more specific task type for the DB
create_job(db=db, job=job_data)
run_academic_pandoc_task(job_data.id, str(final_path), str(processed_path), task_key, APP_CONFIG, base_url)
else:
target_ext = task_key.split('_')[0]
if tool == "ghostscript_pdf": target_ext = "pdf"
processed_path = PATHS.PROCESSED_DIR / f"{original_stem}_{job_id}.{target_ext}"
job_data.processed_filepath = str(processed_path)
create_job(db=db, job=job_data)
run_conversion_task(job_data.id, str(final_path), str(processed_path), tool, task_key, app_config.get("conversion_tools", {}), app_config, base_url)
else:
logger.error(f"Invalid task type '{task_type}' for file {original_filename}")
final_path.unlink(missing_ok=True)
@huey.task()
def run_academic_pandoc_task(job_id: str, input_path_str: str, output_path_str: str, task_key: str, app_config: dict, base_url: str):
"""
Runs a Pandoc conversion for a zipped academic project (e.g., markdown + bibliography).
"""
db = SessionLocal()
input_path = Path(input_path_str)
output_path = Path(output_path_str)
unzip_dir = PATHS.UPLOADS_DIR / f"unzipped_{job_id}"
def find_first_file_with_ext(directory: Path, extensions: List[str]) -> Optional[Path]:
for ext in extensions:
try:
return next(directory.rglob(f"*{ext}"))
except StopIteration:
continue
return None
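# Example: find_first_file_with_ext(unzip_dir, ['.md', '.tex', '.txt']) returns the first
# *.md found anywhere under unzip_dir; only if no .md exists does it fall back to .tex,
# then .txt, and finally None.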
try:
job = get_job(db, job_id)
if not job or job.status == 'cancelled':
return
update_job_status(db, job_id, "processing", progress=10)
logger.info(f"Starting academic Pandoc task for job {job_id}")
# 1. Unzip the project
if not zipfile.is_zipfile(input_path):
raise ValueError("Input is not a valid ZIP archive.")
unzip_dir.mkdir()
with zipfile.ZipFile(input_path, 'r') as zip_ref:
zip_ref.extractall(unzip_dir)
update_job_status(db, job_id, "processing", progress=25)
# 2. Find required files
main_doc = find_first_file_with_ext(unzip_dir, ['.md', '.tex', '.txt'])
bib_file = find_first_file_with_ext(unzip_dir, ['.bib'])
csl_file = find_first_file_with_ext(unzip_dir, ['.csl'])
if not main_doc:
raise FileNotFoundError("No main document (.md, .tex, .txt) found in the ZIP archive.")
if not bib_file:
raise FileNotFoundError("No bibliography file (.bib) found in the ZIP archive.")
update_job_status(db, job_id, "processing", progress=40)
# 3. Build Pandoc command
command = ['pandoc', str(main_doc), '-o', str(output_path)]
command.extend(['--bibliography', str(bib_file)])
command.append('--citeproc') # Use the citation processor
# Handle CSL style
style_key = task_key.split('_')[-1] # e.g., 'apa' from 'pdf_apa'
csl_path_or_url = None
if csl_file:
logger.info(f"Using CSL file found in ZIP: {csl_file.name}")
csl_path_or_url = str(csl_file)
else:
# Look up CSL from config
try:
csl_path_or_url = app_config['academic_settings']['pandoc']['csl_files'][style_key]
logger.info(f"Using CSL style '{style_key}' from configuration.")
except KeyError:
logger.warning(f"No CSL style found for key '{style_key}'. Pandoc will use its default.")
if csl_path_or_url:
command.extend(['--csl', csl_path_or_url])
command.extend(['--pdf-engine', 'xelatex'])
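# At this point the command resembles (illustrative filenames):
#   pandoc main.md -o <output>.pdf --bibliography refs.bib --citeproc [--csl style.csl] --pdf-engine xelatex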
update_job_status(db, job_id, "processing", progress=50)
logger.info(f"Executing Pandoc command for job {job_id}: {' '.join(command)}")
# 4. Execute command directly to control working directory and error capture
try:
process = subprocess.run(
command,
capture_output=True,
text=True,
timeout=300,
check=True, # Raise CalledProcessError on non-zero exit
cwd=unzip_dir, # Run pandoc in the unzipped directory
preexec_fn=globals().get("_limit_resources_preexec", None)
)
except subprocess.CalledProcessError as e:
# Capture the full, detailed error log from pandoc/latex
error_log = e.stderr or "No stderr output."
logger.error(f"Pandoc compilation failed. Full log:\n{error_log}")
# Raise a more informative exception for the user
raise Exception(f"Pandoc compilation failed. Please check your document for errors. Log: {error_log[:2000]}") from e
# 5. Verify output
if not output_path.exists() or output_path.stat().st_size == 0:
raise Exception("Pandoc conversion failed: The tool produced an empty or missing output file.")
mark_job_as_completed(db, job_id, output_filepath_str=str(output_path), preview="Successfully created academic PDF.")
logger.info(f"Academic Pandoc task for job {job_id} completed.")
except Exception as e:
logger.exception(f"ERROR during academic Pandoc task for job {job_id}")
update_job_status(db, job_id, "failed", error=f"Pandoc task failed: {e}")
finally:
# 6. Cleanup
if unzip_dir.exists():
shutil.rmtree(unzip_dir, ignore_errors=True)
try:
ensure_path_is_safe(input_path, [PATHS.UPLOADS_DIR, PATHS.CHUNK_TMP_DIR])
input_path.unlink(missing_ok=True)
except Exception:
logger.exception("Failed to cleanup input ZIP file after Pandoc task.")
if db:
db.close()
# If this was a sub-job, trigger a progress update for its parent.
db_for_check = SessionLocal()
try:
job = get_job(db_for_check, job_id)
if job and job.parent_job_id:
_update_parent_zip_job_progress(job.parent_job_id)
finally:
db_for_check.close()
try:
send_webhook_notification(job_id, app_config, base_url)
except Exception:
logger.exception("Failed to send webhook notification after Pandoc task.")
@huey.task()
def _update_parent_zip_job_progress(parent_job_id: str):
"""Checks and updates the progress of a parent 'unzip' job."""
db = SessionLocal()
try:
parent_job = get_job(db, parent_job_id)
if not parent_job or parent_job.status not in ['processing', 'pending']:
return # Job is already finalized or doesn't exist
child_jobs = db.query(Job).filter(Job.parent_job_id == parent_job.id).all()
total_children = len(child_jobs)
if total_children == 0:
return # Should not happen if dispatched correctly, but safeguard.
finished_children = 0
for child in child_jobs:
if child.status in ['completed', 'failed', 'cancelled']:
finished_children += 1
progress = int((finished_children / total_children) * 100) if total_children > 0 else 100
if finished_children == total_children:
failed_count = sum(1 for child in child_jobs if child.status == 'failed')
preview = f"Batch processing complete. {total_children - failed_count}/{total_children} tasks succeeded."
if failed_count > 0:
preview += f" ({failed_count} failed)."
mark_job_as_completed(db, parent_job.id, preview=preview)
logger.info(f"Batch job {parent_job.id} marked as completed.")
else:
if parent_job.progress != progress:
update_job_status(db, parent_job.id, 'processing', progress=progress)
except Exception as e:
logger.exception(f"Error in _update_parent_zip_job_progress for parent {parent_job_id}: {e}")
finally:
db.close()
@huey.task()
def unzip_and_dispatch_task(job_id: str, input_path_str: str, sub_task_type: str, sub_task_options: dict, user: dict, app_config: dict, base_url: str):
db = SessionLocal()
input_path = Path(input_path_str)
unzip_dir = PATHS.UPLOADS_DIR / f"unzipped_{job_id}"
logger.info(f"Starting unzip and dispatch task for job {job_id} into {sub_task_type} jobs. ")
try:
if not zipfile.is_zipfile(input_path):
raise ValueError("Uploaded file is not a valid ZIP archive.")
unzip_dir.mkdir()
with zipfile.ZipFile(input_path, 'r') as zip_ref:
zip_ref.extractall(unzip_dir)
file_count = 0
for extracted_file_path in unzip_dir.rglob('*'):
if extracted_file_path.is_file():
file_count += 1
dispatch_single_file_job(
original_filename=extracted_file_path.name,
input_filepath=str(extracted_file_path),
task_type=sub_task_type,
options=sub_task_options,
user=user,
db=db,
app_config=app_config,
base_url=base_url,
parent_job_id=job_id
)
if file_count > 0:
# Mark parent job as processing, to be completed by the periodic task
update_job_status(db, job_id, "processing", progress=0)
else:
# No files found, mark as completed with a note
mark_job_as_completed(db, job_id, preview="ZIP archive was empty. No sub-jobs created.")
except Exception as e:
logger.exception(f"ERROR during ZIP processing for job {job_id}")
update_job_status(db, job_id, "failed", error=f"Failed to process ZIP file: {e}")
# If unzipping fails, clean up the directory
if unzip_dir.exists():
shutil.rmtree(unzip_dir)
finally:
try:
# CRITICAL FIX: Only delete the original ZIP file.
# Do NOT delete the unzip_dir here, as the sub-tasks need the files.
ensure_path_is_safe(input_path, [PATHS.UPLOADS_DIR, PATHS.CHUNK_TMP_DIR])
input_path.unlink(missing_ok=True)
except Exception:
logger.exception("Failed to cleanup original ZIP file.")
db.close()
# --------------------------------------------------------------------------------
# --- 5. FASTAPI APPLICATION
# --------------------------------------------------------------------------------
# --- SSE Broadcaster for real-time UI updates ---
import asyncio
import json
# --------------------------------------------------------------------------------
# --- 2. DATABASE & Schemas
async def download_kokoro_models_if_missing():
"""Checks for Kokoro TTS model files and downloads them if they don't exist."""
files_to_download = {
"model": {"path": PATHS.KOKORO_MODEL_FILE, "url": "https://github.com/nazdridoy/kokoro-tts/releases/download/v1.0.0/kokoro-v1.0.onnx"},
"voices": {"path": PATHS.KOKORO_VOICES_FILE, "url": "https://github.com/nazdridoy/kokoro-tts/releases/download/v1.0.0/voices-v1.0.bin"}
}
async with httpx.AsyncClient() as client:
for name, details in files_to_download.items():
path, url = details["path"], details["url"]
if not path.exists():
logger.info(f"Kokoro TTS {name} file missing. Downloading from {url}...")
try:
with path.open("wb") as f:
async with client.stream("GET", url, follow_redirects=True, timeout=300) as response:
response.raise_for_status()
async for chunk in response.aiter_bytes():
f.write(chunk)
logger.info(f"Successfully downloaded Kokoro TTS {name} file to {path}.")
except Exception as e:
logger.error(f"Failed to download Kokoro TTS {name} file: {e}")
if path.exists(): path.unlink(missing_ok=True)
else:
logger.info(f"Found existing Kokoro TTS {name} file at {path}.")
@asynccontextmanager
async def lifespan(app: FastAPI):
logger.info("Application starting up...")
create_attempts = 3
for attempt in range(1, create_attempts + 1):
try:
# use engine.begin() to ensure the DDL runs in a connection/transaction context
with engine.begin() as conn:
Base.metadata.create_all(bind=conn)
logger.info("Database tables ensured (create_all succeeded).")
break
except OperationalError as oe:
# Some SQLite drivers raise an OperationalError when two processes try to create the same table at once.
msg = str(oe).lower()
# If we see "already exists" we treat this as a race and retry briefly.
if "already exists" in msg or ("table" in msg and "already exists" in msg):
logger.warning(
"Database table creation race detected (attempt %d/%d): %s. Retrying...",
attempt,
create_attempts,
oe,
)
time.sleep(0.5)
continue
else:
logger.exception("Database initialization failed with OperationalError.")
raise
except Exception:
logger.exception("Unexpected error during DB initialization.")
raise
load_app_config()
# Start the cache cleanup thread
global _cache_cleanup_thread
if _cache_cleanup_thread is None:
_cache_cleanup_thread = threading.Thread(target=_whisper_cache_cleanup_worker, daemon=True)
_cache_cleanup_thread.start()
logger.info("Whisper model cache cleanup thread started.")
# Download required models on startup
if shutil.which("kokoro-tts"):
await download_kokoro_models_if_missing()
if PiperVoice is None:
logger.warning("piper-tts is not installed. Piper TTS features will be disabled. Install with: pip install piper-tts")
if not shutil.which("kokoro-tts"):
logger.warning("kokoro-tts command not found in PATH. Kokoro TTS features will be disabled.")
ENV = os.environ.get('ENV', 'dev').lower()
ALLOW_LOCAL_ONLY = os.environ.get('ALLOW_LOCAL_ONLY', 'false').lower() == 'true'
if LOCAL_ONLY_MODE and ENV != 'dev' and not ALLOW_LOCAL_ONLY:
raise RuntimeError('LOCAL_ONLY_MODE may only be enabled in dev or when ALLOW_LOCAL_ONLY=true is set.')
if not LOCAL_ONLY_MODE:
oidc_cfg = APP_CONFIG.get('auth_settings', {})
if not all(oidc_cfg.get(k) for k in ['oidc_client_id', 'oidc_client_secret', 'oidc_server_metadata_url']):
logger.error("OIDC auth settings are incomplete. Auth will be disabled if not in LOCAL_ONLY_MODE.")
else:
oauth.register(
name='oidc',
client_id=oidc_cfg.get('oidc_client_id'),
client_secret=oidc_cfg.get('oidc_client_secret'),
server_metadata_url=oidc_cfg.get('oidc_server_metadata_url'),
client_kwargs={'scope': 'openid email profile'},
userinfo_endpoint=oidc_cfg.get('oidc_userinfo_endpoint'),
end_session_endpoint=oidc_cfg.get('oidc_end_session_endpoint')
)
logger.info('OAuth registered.')
yield
logger.info('Application shutting down...')
app = FastAPI(lifespan=lifespan)
ENV = os.environ.get('ENV', 'dev').lower()
SECRET_KEY = os.environ.get('SECRET_KEY')
if not SECRET_KEY and not LOCAL_ONLY_MODE and ENV != 'dev':
raise RuntimeError('SECRET_KEY must be set in production when authentication is enabled.')
if not SECRET_KEY:
logger.warning('SECRET_KEY is not set. Generating a temporary key. Sessions will not persist across restarts.')
SECRET_KEY = os.urandom(24).hex()
app.add_middleware(
SessionMiddleware,
secret_key=SECRET_KEY,
https_only=False, # Set to True if behind HTTPS proxy
same_site='lax',
max_age=14 * 24 * 60 * 60 # 14 days
)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # Allows all origins
allow_credentials=True,
allow_methods=["*"], # Allows all methods
allow_headers=["*"], # Allows all headers
)
# Static / templates
app.mount("/static", StaticFiles(directory=str(PATHS.BASE_DIR / "static")), name="static")
templates = Jinja2Templates(directory=str(PATHS.BASE_DIR / "templates"))
# --- AUTH & USER HELPERS ---
http_bearer = HTTPBearer()
def get_current_user(request: Request):
if LOCAL_ONLY_MODE:
return {'sub': 'local_user', 'email': 'local@user.com', 'name': 'Local User'}
return request.session.get('user')
async def require_api_user(request: Request, creds: HTTPAuthorizationCredentials = Depends(http_bearer)):
"""Dependency for API routes requiring OIDC bearer token authentication."""
if LOCAL_ONLY_MODE:
return {'sub': 'local_api_user', 'email': 'local@api.user.com', 'name': 'Local API User'}
if not creds:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Not authenticated")
token = creds.credentials
try:
user = await oauth.oidc.userinfo(token={'access_token': token})
return dict(user)
except Exception as e:
logger.error(f"API token validation failed: {e}")
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid or expired token")
def is_admin(request: Request) -> bool:
if LOCAL_ONLY_MODE: return True
user = get_current_user(request)
if not user: return False
admin_users = APP_CONFIG.get("auth_settings", {}).get("admin_users", [])
return user.get('email') in admin_users
def require_user(request: Request):
user = get_current_user(request)
if not user: raise HTTPException(status_code=401, detail="Not authenticated")
return user
def require_admin(request: Request):
if not is_admin(request): raise HTTPException(status_code=403, detail="Administrator privileges required.")
return True
# --- FILE SAVING UTILITY ---
async def save_upload_file(upload_file: UploadFile, destination: Path) -> int:
"""
Saves an uploaded file to a destination, handling size limits.
This function is used by both the simple API and the legacy direct-upload routes.
"""
max_size = APP_CONFIG.get("app_settings", {}).get("max_file_size_bytes", 100 * 1024 * 1024)
tmp_path = destination.with_name(f"{destination.stem}.tmp-{uuid.uuid4().hex}{destination.suffix}")
size = 0
try:
with tmp_path.open("wb") as buffer:
while True:
chunk = await upload_file.read(1024 * 1024)
if not chunk:
break
size += len(chunk)
if size > max_size:
raise HTTPException(status_code=413, detail=f"File exceeds {max_size / 1024 / 1024} MB limit")
buffer.write(chunk)
tmp_path.replace(destination)
return size
except Exception as e:
try:
# Ensure temp file is cleaned up on error
ensure_path_is_safe(tmp_path, [PATHS.UPLOADS_DIR, PATHS.CHUNK_TMP_DIR])
tmp_path.unlink(missing_ok=True)
except Exception:
logger.exception("Failed to remove temp upload file after error.")
# Re-raise the original exception
raise e
finally:
await upload_file.close()
def is_allowed_file(filename: str, allowed_extensions: set) -> bool:
if not allowed_extensions: # If set is empty, allow all
return True
return Path(filename).suffix.lower() in allowed_extensions
# --- CHUNKED UPLOADS (for UI) ---
@app.post("/upload/chunk")
async def upload_chunk(
chunk: UploadFile = File(...),
upload_id: str = Form(...),
chunk_number: int = Form(...),
user: dict = Depends(require_user)
):
safe_upload_id = secure_filename(upload_id)
temp_dir = ensure_path_is_safe(PATHS.CHUNK_TMP_DIR / safe_upload_id, [PATHS.CHUNK_TMP_DIR])
temp_dir.mkdir(exist_ok=True)
chunk_path = temp_dir / f"{chunk_number}.chunk"
def save_chunk_sync():
try:
with open(chunk_path, "wb") as buffer:
shutil.copyfileobj(chunk.file, buffer)
finally:
chunk.file.close()
await run_in_threadpool(save_chunk_sync)
return JSONResponse({"message": f"Chunk {chunk_number} for {safe_upload_id} uploaded."})
async def _stitch_chunks(temp_dir: Path, final_path: Path, total_chunks: int):
"""Stitches chunks together memory-efficiently and cleans up."""
ensure_path_is_safe(temp_dir, [PATHS.CHUNK_TMP_DIR])
ensure_path_is_safe(final_path, [PATHS.UPLOADS_DIR])
# This is a blocking function that will be run in a threadpool
def do_stitch():
with open(final_path, "wb") as final_file:
for i in range(total_chunks):
chunk_path = temp_dir / f"{i}.chunk"
if not chunk_path.exists():
# Raise an exception that can be caught and handled
raise FileNotFoundError(f"Upload failed: missing chunk {i}")
with open(chunk_path, "rb") as chunk_file:
# Use copyfileobj for memory efficiency
shutil.copyfileobj(chunk_file, final_file)
try:
await run_in_threadpool(do_stitch)
except FileNotFoundError as e:
# If a chunk was missing, clean up and re-raise as HTTPException
shutil.rmtree(temp_dir, ignore_errors=True)
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
# For any other error during stitching, clean up and re-raise
shutil.rmtree(temp_dir, ignore_errors=True)
raise e # Re-raise the original exception
else:
# If successful, clean up the temp directory
shutil.rmtree(temp_dir, ignore_errors=True)
def _parse_tool_and_task_key(output_format: str, all_tool_keys: list) -> tuple[str, str]:
"""Robustly parses an output_format string to find the matching tool and task key."""
# Sort keys by length descending to match longest prefix first (e.g., 'ghostscript_image' before 'ghostscript')
for tool_key in sorted(all_tool_keys, key=len, reverse=True):
if output_format.startswith(tool_key + '_'):
task_key = output_format[len(tool_key) + 1:]
return tool_key, task_key
raise ValueError(f"Could not determine tool from output_format: {output_format}")
@app.post("/upload/finalize", response_model=JobSchema, status_code=status.HTTP_202_ACCEPTED)
async def finalize_upload(request: Request, payload: FinalizeUploadPayload, user: dict = Depends(require_user), db: Session = Depends(get_db)):
safe_upload_id = secure_filename(payload.upload_id)
temp_dir = ensure_path_is_safe(PATHS.CHUNK_TMP_DIR / safe_upload_id, [PATHS.CHUNK_TMP_DIR])
if not temp_dir.is_dir():
raise HTTPException(status_code=404, detail="Upload session not found or already finalized.")
webhook_config = APP_CONFIG.get("webhook_settings", {})
if payload.callback_url and not is_allowed_callback_url(payload.callback_url, webhook_config.get("allowed_callback_urls", [])):
raise HTTPException(status_code=400, detail="Provided callback_url is not allowed.")
job_id = uuid.uuid4().hex
safe_filename = secure_filename(payload.original_filename)
final_path = PATHS.UPLOADS_DIR / f"{Path(safe_filename).stem}_{job_id}{Path(safe_filename).suffix}"
await _stitch_chunks(temp_dir, final_path, payload.total_chunks)
base_url = str(request.base_url)
# Check if the selected conversion is the new academic pandoc task
tool, task_key = None, None
if payload.task_type == 'conversion':
try:
all_tools = APP_CONFIG.get("conversion_tools", {}).keys()
tool, task_key = _parse_tool_and_task_key(payload.output_format, all_tools)
except ValueError:
raise HTTPException(status_code=400, detail="Invalid or missing output_format for conversion.")
if tool == 'pandoc_academic':
# This is a single job that processes a ZIP file as a project.
options = {"output_format": payload.output_format}
dispatch_single_file_job(payload.original_filename, str(final_path), "conversion", user, db, APP_CONFIG, base_url, job_id=job_id, options=options)
elif Path(safe_filename).suffix.lower() == '.zip':
# This is the original batch processing logic for ZIP files.
job_data = JobCreate(
id=job_id, user_id=user['sub'], task_type="unzip",
original_filename=payload.original_filename, input_filepath=str(final_path),
input_filesize=final_path.stat().st_size
)
create_job(db=db, job=job_data)
sub_task_options = {
"model_size": payload.model_size,
"model_name": payload.model_name,
"output_format": payload.output_format
}
unzip_and_dispatch_task(job_id, str(final_path), payload.task_type, sub_task_options, user, APP_CONFIG, base_url)
else:
# This is the logic for all other single-file uploads.
options = {"model_size": payload.model_size, "model_name": payload.model_name, "output_format": payload.output_format}
dispatch_single_file_job(payload.original_filename, str(final_path), payload.task_type, user, db, APP_CONFIG, base_url, job_id=job_id, options=options)
# Instead of returning a minimal object, fetch the newly created job
# from the database and return the full serialized object. This ensures
# the frontend has all the data it needs to correctly update the UI row.
db.flush() # Ensure the job is available to be queried
db_job = get_job(db, job_id)
if not db_job:
# This is an unlikely race condition but we handle it just in case.
# The SSE event will still create the row correctly.
raise HTTPException(status_code=500, detail="Job was created but could not be retrieved for an immediate response.")
return db_job
# --- LEGACY DIRECT-UPLOAD ROUTES (kept for compatibility) ---
@app.post("/transcribe-audio", status_code=status.HTTP_202_ACCEPTED)
async def submit_audio_transcription(
request: Request, file: UploadFile = File(...), model_size: str = Form("base"),
db: Session = Depends(get_db), user: dict = Depends(require_user)
):
allowed_audio_exts = {".mp3", ".wav", ".m4a", ".flac", ".ogg", ".opus"}
if not is_allowed_file(file.filename, allowed_audio_exts):
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid audio file type.")
whisper_config = APP_CONFIG.get("transcription_settings", {}).get("whisper", {})
if model_size not in whisper_config.get("allowed_models", []):
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid model size: {model_size}.")
job_id, safe_basename = uuid.uuid4().hex, secure_filename(file.filename)
stem, suffix = Path(safe_basename).stem, Path(safe_basename).suffix
upload_path = PATHS.UPLOADS_DIR / f"{stem}_{job_id}{suffix}"
processed_path = PATHS.PROCESSED_DIR / f"{stem}_{job_id}.txt"
input_size = await save_upload_file(file, upload_path)
base_url = str(request.base_url)
job_data = JobCreate(id=job_id, user_id=user['sub'], task_type="transcription", original_filename=file.filename,
input_filepath=str(upload_path), input_filesize=input_size, processed_filepath=str(processed_path))
new_job = create_job(db=db, job=job_data)
run_transcription_task(new_job.id, str(upload_path), str(processed_path), model_size, whisper_settings=whisper_config, app_config=APP_CONFIG, base_url=base_url)
return {"job_id": new_job.id, "status": new_job.status, "status_url": f"/job/{new_job.id}"}
@app.post("/convert-file", status_code=status.HTTP_202_ACCEPTED)
async def submit_file_conversion(request: Request, file: UploadFile = File(...), output_format: str = Form(...), db: Session = Depends(get_db), user: dict = Depends(require_user)):
allowed_exts = APP_CONFIG.get("app_settings", {}).get("allowed_all_extensions", set())
if not is_allowed_file(file.filename, allowed_exts):
raise HTTPException(status_code=400, detail=f"File type '{Path(file.filename).suffix}' not allowed.")
conversion_tools = APP_CONFIG.get("conversion_tools", {})
try:
# Use the same robust parser as the chunked-upload path so compound tool keys
# (e.g. "ghostscript_pdf") are handled consistently.
tool, task_key = _parse_tool_and_task_key(output_format, list(conversion_tools.keys()))
except ValueError:
raise HTTPException(status_code=400, detail="Invalid output format selected.")
job_id, safe_basename = uuid.uuid4().hex, secure_filename(file.filename)
original_stem = Path(safe_basename).stem
target_ext = task_key.split('_')[0]
if tool == "ghostscript_pdf": target_ext = "pdf"
upload_path = PATHS.UPLOADS_DIR / f"{original_stem}_{job_id}{Path(safe_basename).suffix}"
processed_path = PATHS.PROCESSED_DIR / f"{original_stem}_{job_id}.{target_ext}"
input_size = await save_upload_file(file, upload_path)
base_url = str(request.base_url)
job_data = JobCreate(id=job_id, user_id=user['sub'], task_type="conversion", original_filename=file.filename,
input_filepath=str(upload_path), input_filesize=input_size, processed_filepath=str(processed_path))
new_job = create_job(db=db, job=job_data)
run_conversion_task(new_job.id, str(upload_path), str(processed_path), tool, task_key, conversion_tools, APP_CONFIG, base_url)
return {"job_id": new_job.id, "status": new_job.status, "status_url": f"/job/{new_job.id}"}
@app.post("/ocr-pdf", status_code=status.HTTP_202_ACCEPTED)
async def submit_pdf_ocr(request: Request, file: UploadFile = File(...), db: Session = Depends(get_db), user: dict = Depends(require_user)):
if not is_allowed_file(file.filename, {".pdf"}):
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid file type. Please upload a PDF.")
job_id, safe_basename = uuid.uuid4().hex, secure_filename(file.filename)
unique_filename = f"{Path(safe_basename).stem}_{job_id}{Path(safe_basename).suffix}"
upload_path = PATHS.UPLOADS_DIR / unique_filename
processed_path = PATHS.PROCESSED_DIR / unique_filename
input_size = await save_upload_file(file, upload_path)
base_url = str(request.base_url)
job_data = JobCreate(id=job_id, user_id=user['sub'], task_type="ocr", original_filename=file.filename,
input_filepath=str(upload_path), input_filesize=input_size, processed_filepath=str(processed_path))
new_job = create_job(db=db, job=job_data)
ocr_settings = APP_CONFIG.get("ocr_settings", {}).get("ocrmypdf", {})
run_pdf_ocr_task(new_job.id, str(upload_path), str(processed_path), ocr_settings, APP_CONFIG, base_url)
return {"job_id": new_job.id, "status": new_job.status, "status_url": f"/job/{new_job.id}"}
@app.post("/ocr-image", status_code=status.HTTP_202_ACCEPTED)
async def submit_image_ocr(request: Request, file: UploadFile = File(...), db: Session = Depends(get_db), user: dict = Depends(require_user)):
allowed_exts = {".png", ".jpg", ".jpeg", ".tiff", ".tif"}
if not is_allowed_file(file.filename, allowed_exts):
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid file type. Please upload a PNG, JPG, or TIFF.")
job_id, safe_basename = uuid.uuid4().hex, secure_filename(file.filename)
file_ext = Path(safe_basename).suffix
unique_filename = f"{Path(safe_basename).stem}_{job_id}{file_ext}"
upload_path = PATHS.UPLOADS_DIR / unique_filename
processed_path = PATHS.PROCESSED_DIR / f"{Path(safe_basename).stem}_{job_id}.pdf"
input_size = await save_upload_file(file, upload_path)
base_url = str(request.base_url)
job_data = JobCreate(id=job_id, user_id=user['sub'], task_type="ocr-image", original_filename=file.filename,
input_filepath=str(upload_path), input_filesize=input_size, processed_filepath=str(processed_path))
new_job = create_job(db=db, job=job_data)
run_image_ocr_task(new_job.id, str(upload_path), str(processed_path), APP_CONFIG, base_url)
return {"job_id": new_job.id, "status": new_job.status, "status_url": f"/job/{new_job.id}"}
# --------------------------------------------------------------------------------
# --- API V1 ROUTES (for programmatic access)
# --------------------------------------------------------------------------------
def is_allowed_callback_url(url: str, allowed: List[str]) -> bool:
if not allowed:
return False
try:
parsed = urlparse(url)
if not parsed.scheme or not parsed.netloc:
return False
for a in allowed:
ap = urlparse(a)
if ap.scheme and ap.netloc:
if parsed.scheme == ap.scheme and parsed.netloc == ap.netloc:
return True
else:
# support legacy prefix entries - keep fallback
if url.startswith(a):
return True
return False
except Exception:
return False
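# Example (illustrative): with allowed = ["https://hooks.example.com"], a callback_url of
# "https://hooks.example.com/n8n/job-done" is accepted because scheme and host match, while
# entries that lack a scheme fall back to simple prefix matching for legacy configs.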
@app.get("/api/v1/tts-voices")
async def get_tts_voices_list(user: dict = Depends(require_user)):
global AVAILABLE_TTS_VOICES_CACHE
kokoro_available = shutil.which("kokoro-tts") is not None
piper_available = PiperVoice is not None
if not piper_available and not kokoro_available:
return JSONResponse(content={"error": "TTS feature not configured on server (no TTS engines found)."}, status_code=501)
if AVAILABLE_TTS_VOICES_CACHE:
return AVAILABLE_TTS_VOICES_CACHE
all_voices = []
try:
if piper_available:
logger.info("Fetching available Piper voices list...")
piper_voices = safe_get_voices(PATHS.TTS_MODELS_DIR)
for voice in piper_voices:
voice['id'] = f"piper/{voice.get('id')}"
voice['name'] = f"Piper: {voice.get('name', voice.get('id'))}"
all_voices.extend(piper_voices)
if kokoro_available:
logger.info("Fetching available Kokoro TTS voices and languages...")
kokoro_voices = list_kokoro_voices_cli()
kokoro_langs = list_kokoro_languages_cli()
for lang in kokoro_langs:
for voice in kokoro_voices:
all_voices.append({
"id": f"kokoro/{lang}/{voice}",
"name": f"Kokoro ({lang}): {voice}",
"local": False
})
AVAILABLE_TTS_VOICES_CACHE = sorted(all_voices, key=lambda x: x['name'])
return AVAILABLE_TTS_VOICES_CACHE
except Exception as e:
logger.exception("Could not fetch list of TTS voices.")
raise HTTPException(status_code=500, detail=f"Could not retrieve voices list: {e}")
# --- Standard API endpoint (non-chunked) ---
@app.post("/api/v1/process", status_code=status.HTTP_202_ACCEPTED, tags=["Webhook API"])
async def api_process_file(
request: Request, file: UploadFile = File(...), task_type: str = Form(...), callback_url: str = Form(...),
model_size: Optional[str] = Form("base"), model_name: Optional[str] = Form(None),
output_format: Optional[str] = Form(None),
db: Session = Depends(get_db), user: dict = Depends(require_api_user)
):
"""
Programmatically submit a file for processing via a single HTTP request.
This is the recommended endpoint for services like n8n.
Requires bearer token authentication unless in LOCAL_ONLY_MODE.
"""
webhook_config = APP_CONFIG.get("webhook_settings", {})
if not webhook_config.get("enabled", False):
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Webhook processing is disabled on the server.")
if not is_allowed_callback_url(callback_url, webhook_config.get("allowed_callback_urls", [])):
logger.warning(f"Rejected webhook from user '{user.get('email')}' with disallowed callback URL: {callback_url}")
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Provided callback_url is not in the list of allowed URLs.")
job_id = uuid.uuid4().hex
safe_basename = secure_filename(file.filename)
stem, suffix = Path(safe_basename).stem, Path(safe_basename).suffix
upload_filename = f"{stem}_{job_id}{suffix}"
upload_path = PATHS.UPLOADS_DIR / upload_filename
try:
input_size = await save_upload_file(file, upload_path)
except HTTPException as e:
raise e # Re-raise exceptions from save_upload_file (e.g., file too large)
except Exception as e:
logger.exception("Failed to save uploaded file for webhook processing.")
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to save file: {e}")
base_url = str(request.base_url)
job_data_args = {
"id": job_id, "user_id": user['sub'], "original_filename": file.filename,
"input_filepath": str(upload_path), "input_filesize": input_size,
"callback_url": callback_url, "task_type": task_type,
}
# --- API Task Dispatching Logic ---
if task_type == "transcription":
whisper_config = APP_CONFIG.get("transcription_settings", {}).get("whisper", {})
if model_size not in whisper_config.get("allowed_models", []):
raise HTTPException(status_code=400, detail=f"Invalid model_size '{model_size}'")
processed_path = PATHS.PROCESSED_DIR / f"{stem}_{job_id}.txt"
job_data_args["processed_filepath"] = str(processed_path)
create_job(db=db, job=JobCreate(**job_data_args))
run_transcription_task(job_id, str(upload_path), str(processed_path), model_size, whisper_config, APP_CONFIG, base_url)
elif task_type == "tts":
if not is_allowed_file(file.filename, {".txt"}):
raise HTTPException(status_code=400, detail="Invalid file type for TTS, requires .txt")
if not model_name:
raise HTTPException(status_code=400, detail="model_name is required for TTS task.")
tts_config = APP_CONFIG.get("tts_settings", {})
processed_path = PATHS.PROCESSED_DIR / f"{stem}_{job_id}.wav"
job_data_args["processed_filepath"] = str(processed_path)
create_job(db=db, job=JobCreate(**job_data_args))
run_tts_task(job_id, str(upload_path), str(processed_path), model_name, tts_config, APP_CONFIG, base_url)
elif task_type == "conversion":
if not output_format:
raise HTTPException(status_code=400, detail="output_format is required for conversion task.")
conversion_tools = APP_CONFIG.get("conversion_tools", {})
try:
# Use the same robust parser as the chunked-upload path so compound tool keys are handled consistently.
tool, task_key = _parse_tool_and_task_key(output_format, list(conversion_tools.keys()))
except ValueError:
raise HTTPException(status_code=400, detail="Invalid output_format selected.")
target_ext = task_key.split('_')[0]
if tool == "ghostscript_pdf": target_ext = "pdf"
processed_path = PATHS.PROCESSED_DIR / f"{stem}_{job_id}.{target_ext}"
job_data_args["processed_filepath"] = str(processed_path)
create_job(db=db, job=JobCreate(**job_data_args))
run_conversion_task(job_id, str(upload_path), str(processed_path), tool, task_key, conversion_tools, APP_CONFIG, base_url)
elif task_type == "ocr":
if not is_allowed_file(file.filename, {".pdf"}):
raise HTTPException(status_code=400, detail="Invalid file type for ocr, requires .pdf")
processed_path = PATHS.PROCESSED_DIR / f"{stem}_{job_id}{suffix}"
job_data_args["processed_filepath"] = str(processed_path)
create_job(db=db, job=JobCreate(**job_data_args))
run_pdf_ocr_task(job_id, str(upload_path), str(processed_path), APP_CONFIG.get("ocr_settings", {}).get("ocrmypdf", {}), APP_CONFIG, base_url)
elif task_type == "ocr-image":
if not is_allowed_file(file.filename, {".png", ".jpg", ".jpeg", ".tiff", ".tif"}):
raise HTTPException(status_code=400, detail="Invalid file type for ocr-image.")
processed_path = PATHS.PROCESSED_DIR / f"{stem}_{job_id}.txt"
job_data_args["processed_filepath"] = str(processed_path)
create_job(db=db, job=JobCreate(**job_data_args))
run_image_ocr_task(job_id, str(upload_path), str(processed_path), APP_CONFIG, base_url)
else:
upload_path.unlink(missing_ok=True) # Cleanup orphaned file
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid task_type: '{task_type}'")
return {"job_id": job_id, "status": "pending"}
# --- Chunked API endpoints (optional) ---
@app.post("/api/v1/upload/chunk", tags=["Webhook API"])
async def api_upload_chunk(
chunk: UploadFile = File(...), upload_id: str = Form(...), chunk_number: int = Form(...),
user: dict = Depends(require_api_user)
):
"""API endpoint for uploading a single file chunk."""
webhook_config = APP_CONFIG.get("webhook_settings", {})
if not webhook_config.get("enabled", False) or not webhook_config.get("allow_chunked_api_uploads", False):
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Chunked API uploads are disabled.")
return await upload_chunk(chunk, upload_id, chunk_number, user)
@app.post("/api/v1/upload/finalize", status_code=status.HTTP_202_ACCEPTED, tags=["Webhook API"])
async def api_finalize_upload(
request: Request, payload: FinalizeUploadPayload, user: dict = Depends(require_api_user), db: Session = Depends(get_db)
):
"""API endpoint to finalize a chunked upload and start a processing job."""
webhook_config = APP_CONFIG.get("webhook_settings", {})
if not webhook_config.get("enabled", False) or not webhook_config.get("allow_chunked_api_uploads", False):
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Chunked API uploads are disabled.")
# Validate callback URL if provided for a webhook job
if payload.callback_url and not is_allowed_callback_url(payload.callback_url, webhook_config.get("allowed_callback_urls", [])):
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Provided callback_url is not allowed.")
# Re-use the main finalization logic, but with API user context
return await finalize_upload(request, payload, user, db)
# --------------------------------------------------------------------------------
# --- AUTH & PAGE ROUTES
# --------------------------------------------------------------------------------
if not LOCAL_ONLY_MODE:
@app.get('/login')
async def login(request: Request):
redirect_uri = request.url_for('auth')
return await oauth.oidc.authorize_redirect(request, redirect_uri)
@app.get('/auth')
async def auth(request: Request):
try:
token = await oauth.oidc.authorize_access_token(request)
user = await oauth.oidc.userinfo(token=token)
request.session['user'] = dict(user)
request.session['id_token'] = token.get('id_token')
except Exception as e:
logger.error(f"Authentication failed: {e}")
raise HTTPException(status_code=401, detail="Authentication failed")
return RedirectResponse(url='/')
@app.get("/logout")
async def logout(request: Request):
logout_endpoint = oauth.oidc.server_metadata.get("end_session_endpoint")
if not logout_endpoint:
request.session.clear()
logger.warning("OIDC 'end_session_endpoint' not found. Performing local-only logout.")
return RedirectResponse(url="/", status_code=302)
post_logout_redirect_uri = str(request.url_for("get_index"))
logout_url = f"{logout_endpoint}?post_logout_redirect_uri={post_logout_redirect_uri}"
request.session.clear()
return RedirectResponse(url=logout_url, status_code=302)
# This is for reverse proxies that use forward auth
@app.get("/api/authz/forward-auth")
async def forward_auth(request: Request):
redirect_uri = request.url_for('auth')
return await oauth.oidc.authorize_redirect(request, redirect_uri)
@app.get("/")
async def get_index(request: Request):
user = get_current_user(request)
admin_status = is_admin(request)
whisper_models = APP_CONFIG.get("transcription_settings", {}).get("whisper", {}).get("allowed_models", [])
conversion_tools = APP_CONFIG.get("conversion_tools", {})
return templates.TemplateResponse("index.html", {
"request": request, "user": user, "is_admin": admin_status,
"whisper_models": sorted(list(whisper_models)),
"conversion_tools": conversion_tools, "local_only_mode": LOCAL_ONLY_MODE
})
@app.get("/settings")
async def get_settings_page(request: Request):
"""Displays the contents of the currently active configuration."""
user = get_current_user(request)
admin_status = is_admin(request)
# Use the globally loaded and merged APP_CONFIG for consistency
# This ensures all default keys are present before rendering.
current_config = APP_CONFIG
# Determine the source file for display purposes
config_source = "none"
if PATHS.SETTINGS_FILE.exists():
config_source = str(PATHS.SETTINGS_FILE.name)
elif PATHS.DEFAULT_SETTINGS_FILE.exists():
config_source = str(PATHS.DEFAULT_SETTINGS_FILE.name)
return templates.TemplateResponse(
"settings.html",
{"request": request, "config": current_config, "config_source": config_source,
"user": user, "is_admin": admin_status, "local_only_mode": LOCAL_ONLY_MODE}
)
def deep_merge(source: dict, destination: dict) -> dict:
"""Recursively merges dicts."""
for key, value in source.items():
if isinstance(value, collections.abc.Mapping):
node = destination.setdefault(key, {})
deep_merge(value, node)
else:
destination[key] = value
return destination
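# Example: deep_merge({"a": {"x": 1}}, {"a": {"y": 2}, "b": 3}) updates and returns the
# destination as {"a": {"y": 2, "x": 1}, "b": 3}; nested mappings are merged key-by-key
# rather than replaced wholesale.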
@app.post("/settings/save")
async def save_settings(
request: Request, new_config_from_ui: Dict = Body(...), admin: bool = Depends(require_admin)
):
"""Safely updates settings.yml by merging UI changes with the existing file."""
tmp_path = PATHS.SETTINGS_FILE.with_suffix(".tmp")
user = get_current_user(request)
try:
if not new_config_from_ui:
if PATHS.SETTINGS_FILE.exists():
PATHS.SETTINGS_FILE.unlink()
logger.info(f"Admin '{user.get('email')}' reverted to default settings.")
load_app_config()
return JSONResponse({"message": "Settings reverted to default."})
try:
with PATHS.SETTINGS_FILE.open("r", encoding="utf8") as f:
current_config_on_disk = yaml.safe_load(f) or {}
except FileNotFoundError:
current_config_on_disk = {}
merged_config = deep_merge(source=new_config_from_ui, destination=current_config_on_disk)
with tmp_path.open("w", encoding="utf8") as f:
yaml.safe_dump(merged_config, f, default_flow_style=False, sort_keys=False)
tmp_path.replace(PATHS.SETTINGS_FILE)
logger.info(f"Admin '{user.get('email')}' updated settings.yml.")
load_app_config()
return JSONResponse({"message": "Settings saved successfully."})
except Exception as e:
logger.exception(f"Failed to update settings for admin '{user.get('email')}'")
if tmp_path.exists(): tmp_path.unlink()
raise HTTPException(status_code=500, detail=f"Could not save settings.yml: {e}")
# --------------------------------------------------------------------------------
# --- JOB MANAGEMENT & UTILITY ROUTES
# --------------------------------------------------------------------------------
@app.post("/settings/clear-history")
async def clear_job_history(db: Session = Depends(get_db), user: dict = Depends(require_user)):
try:
num_deleted = db.query(Job).filter(Job.user_id == user['sub']).delete()
db.commit()
logger.info(f"Cleared {num_deleted} jobs for user {user['sub']}.")
return {"deleted_count": num_deleted}
except Exception:
db.rollback()
logger.exception("Failed to clear job history")
raise HTTPException(status_code=500, detail="Database error while clearing history.")
@app.post("/settings/delete-files")
async def delete_processed_files(db: Session = Depends(get_db), user: dict = Depends(require_user)):
deleted_count, errors = 0, []
for job in get_jobs(db, user_id=user['sub']):
if job.processed_filepath:
try:
p = ensure_path_is_safe(Path(job.processed_filepath), [PATHS.PROCESSED_DIR])
if p.is_file():
p.unlink()
deleted_count += 1
except Exception:
errors.append(Path(job.processed_filepath).name)
logger.exception(f"Could not delete file {Path(job.processed_filepath).name}")
if errors:
raise HTTPException(status_code=500, detail=f"Could not delete some files: {', '.join(errors)}")
logger.info(f"Deleted {deleted_count} files for user {user['sub']}.")
return {"deleted_count": deleted_count}
@app.post("/job/{job_id}/cancel", status_code=status.HTTP_202_ACCEPTED)
async def cancel_job(job_id: str, db: Session = Depends(get_db), user: dict = Depends(require_user)):
job = get_job(db, job_id)
if not job or job.user_id != user['sub']:
raise HTTPException(status_code=404, detail="Job not found.")
if job.status in ["pending", "processing"]:
update_job_status(db, job_id, status="cancelled")
return {"message": "Job cancellation requested."}
raise HTTPException(status_code=400, detail=f"Job is already in a final state ({job.status}).")
@app.get("/jobs", response_model=List[JobSchema])
async def get_all_jobs(db: Session = Depends(get_db), user: dict = Depends(require_user)):
return get_jobs(db, user_id=user['sub'])
@app.get("/job/{job_id}", response_model=JobSchema)
async def get_job_status(job_id: str, db: Session = Depends(get_db), user: dict = Depends(require_user)):
job = get_job(db, job_id)
if not job or job.user_id != user['sub']:
raise HTTPException(status_code=404, detail="Job not found.")
return job
class JobStatusRequest(BaseModel):
job_ids: TypingList[str]
@app.post("/api/v1/jobs/status", response_model=TypingList[JobSchema])
async def get_jobs_status(payload: JobStatusRequest, db: Session = Depends(get_db), user: dict = Depends(require_user)):
"""
Accepts a list of job IDs and returns their current status.
This is used by the frontend for polling active jobs.
"""
if not payload.job_ids:
return []
# Fetch all requested jobs from the database in a single query
jobs = db.query(Job).filter(Job.id.in_(payload.job_ids), Job.user_id == user['sub']).all()
return jobs
@app.get("/download/{filename}")
async def download_file(filename: str, db: Session = Depends(get_db), user: dict = Depends(require_user)):
safe_filename = secure_filename(filename)
file_path = ensure_path_is_safe(PATHS.PROCESSED_DIR / safe_filename, [PATHS.PROCESSED_DIR])
if not file_path.is_file():
raise HTTPException(status_code=404, detail="File not found.")
# API users can download files they own via webhook URL. UI users need session.
job_owner_id = user.get('sub') if user else None
job = db.query(Job).filter(Job.processed_filepath == str(file_path), Job.user_id == job_owner_id).first()
if not job:
raise HTTPException(status_code=403, detail="You do not have permission to download this file.")
download_filename = Path(job.original_filename).stem + Path(job.processed_filepath).suffix
return FileResponse(path=file_path, filename=download_filename, media_type="application/octet-stream")
@app.post("/download/batch", response_class=StreamingResponse)
async def download_batch(payload: JobSelection, db: Session = Depends(get_db), user: dict = Depends(require_user)):
job_ids = payload.job_ids
if not job_ids:
raise HTTPException(status_code=400, detail="No job IDs provided.")
zip_buffer = BytesIO()
with zipfile.ZipFile(zip_buffer, "a", zipfile.ZIP_DEFLATED, False) as zip_file:
for job_id in job_ids:
job = get_job(db, job_id)
if job and job.user_id == user['sub'] and job.status == 'completed' and job.processed_filepath:
file_path = ensure_path_is_safe(Path(job.processed_filepath), [PATHS.PROCESSED_DIR])
if file_path.exists():
download_filename = f"{Path(job.original_filename).stem}_{job_id}{file_path.suffix}"
zip_file.write(file_path, arcname=download_filename)
zip_buffer.seek(0)
return StreamingResponse(zip_buffer, media_type="application/x-zip-compressed", headers={
'Content-Disposition': f'attachment; filename="file-wizard-batch-{uuid.uuid4().hex[:8]}.zip"'
})
@app.get("/download/zip-batch/{job_id}", response_class=StreamingResponse)
async def download_zip_batch(job_id: str, db: Session = Depends(get_db), user: dict = Depends(require_user)):
"""Downloads all processed files from a ZIP upload batch as a new ZIP file."""
parent_job = get_job(db, job_id)
if not parent_job or parent_job.user_id != user['sub']:
raise HTTPException(status_code=404, detail="Parent job not found.")
if parent_job.task_type != 'unzip':
raise HTTPException(status_code=400, detail="This job is not a batch upload.")
child_jobs = db.query(Job).filter(Job.parent_job_id == job_id, Job.status == 'completed').all()
if not child_jobs:
raise HTTPException(status_code=404, detail="No completed sub-jobs found for this batch.")
zip_buffer = BytesIO()
with zipfile.ZipFile(zip_buffer, "a", zipfile.ZIP_DEFLATED, False) as zip_file:
files_added = 0
for job in child_jobs:
if job.processed_filepath:
file_path = ensure_path_is_safe(Path(job.processed_filepath), [PATHS.PROCESSED_DIR])
if file_path.exists():
# Create a more user-friendly name inside the zip
download_filename = f"{Path(job.original_filename).stem}{file_path.suffix}"
zip_file.write(file_path, arcname=download_filename)
files_added += 1
if files_added == 0:
raise HTTPException(status_code=404, detail="No processed files found for the completed sub-jobs.")
zip_buffer.seek(0)
# Generate a filename for the download
batch_filename = f"{Path(parent_job.original_filename).stem}_processed.zip"
return StreamingResponse(zip_buffer, media_type="application/x-zip-compressed", headers={
'Content-Disposition': f'attachment; filename="{batch_filename}"'
})
@app.get("/health")
async def health():
try:
with engine.connect() as conn:
# Wrap the raw SQL in text() so the probe works on SQLAlchemy 1.4 and 2.x alike.
from sqlalchemy import text
conn.execution_options(isolation_level="AUTOCOMMIT").execute(text("SELECT 1"))
except Exception:
logger.exception("Health check failed")
return JSONResponse({"ok": False}, status_code=500)
return {"ok": True}
@app.get('/favicon.ico', include_in_schema=False)
async def favicon():
return FileResponse(str(PATHS.BASE_DIR / 'static' / 'favicon.png'))