# main.py (merged)

import logging
import shutil
import subprocess
import traceback
import uuid
import shlex
import yaml
import os
from contextlib import asynccontextmanager
from datetime import datetime, timezone
from pathlib import Path
from typing import Dict, List, Any
import resource
from threading import Semaphore
from logging.handlers import RotatingFileHandler
from urllib.parse import urlencode

import ocrmypdf
import pypdf
import pytesseract
from PIL import Image
from faster_whisper import WhisperModel
from fastapi import (Depends, FastAPI, File, Form, HTTPException, Request,
                     UploadFile, status, Body)
from fastapi.responses import FileResponse, JSONResponse, RedirectResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from huey import SqliteHuey
from pydantic import BaseModel, ConfigDict, field_serializer
from sqlalchemy import (Column, DateTime, Integer, String, Text,
                        create_engine, delete, event, text)
from sqlalchemy.orm import Session, declarative_base, sessionmaker
from sqlalchemy.pool import NullPool
from string import Formatter
from werkzeug.utils import secure_filename
from typing import List as TypingList
from starlette.middleware.sessions import SessionMiddleware
from authlib.integrations.starlette_client import OAuth
from dotenv import load_dotenv

load_dotenv()
# OAuth client registry; the OIDC provider is registered during application startup (see lifespan).
oauth = OAuth()
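
# NOTE (assumption, not defined in this file): a typical deployment runs the API and the
# background worker as two processes, e.g.
#   uvicorn main:app --host 0.0.0.0 --port 8000
#   huey_consumer main.huey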


# --------------------------------------------------------------------------------
# --- 1. CONFIGURATION & SECURITY HELPERS
# --------------------------------------------------------------------------------
# --- Path Safety ---
UPLOADS_BASE = Path(os.environ.get("UPLOADS_DIR", "/app/uploads")).resolve()
PROCESSED_BASE = Path(os.environ.get("PROCESSED_DIR", "/app/processed")).resolve()
CHUNK_TMP_BASE = Path(os.environ.get("CHUNK_TMP_DIR", str(UPLOADS_BASE / "tmp"))).resolve()


def ensure_path_is_safe(p: Path, allowed_bases: List[Path]):
    """Ensure a path resolves to a location within one of the allowed base directories."""
    try:
        resolved_p = p.resolve()
        if not any(resolved_p.is_relative_to(base) for base in allowed_bases):
            raise ValueError(f"Path {resolved_p} is outside of allowed directories.")
        return resolved_p
    except Exception as e:
        logger = logging.getLogger(__name__)
        logger.error(f"Path safety check failed for {p}: {e}")
        raise ValueError("Invalid or unsafe path specified.")


# --- Resource Limiting ---
def _limit_resources_preexec():
    """Set resource limits for child processes to prevent DoS attacks."""
    try:
        # Cap child processes at 6000 s of CPU time and 2 GiB of address space
        # (soft and hard limits must match; the original soft limit exceeded the hard limit).
        resource.setrlimit(resource.RLIMIT_CPU, (6000, 6000))
        resource.setrlimit(resource.RLIMIT_AS, (2 * 1024 * 1024 * 1024, 2 * 1024 * 1024 * 1024))
    except Exception as e:
        # This may fail in some environments (e.g. Windows, some containers).
        logging.getLogger(__name__).warning(f"Could not set resource limits: {e}")


# --- Model concurrency semaphore ---
MODEL_CONCURRENCY = int(os.environ.get("MODEL_CONCURRENCY", "1"))
_model_semaphore = Semaphore(MODEL_CONCURRENCY)

# --- Logging Setup ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
_log_handler = RotatingFileHandler("app.log", maxBytes=10*1024*1024, backupCount=5)
_log_formatter = logging.Formatter('%(asctime)s %(levelname)s %(name)s %(message)s')
_log_handler.setFormatter(_log_formatter)
logging.getLogger().addHandler(_log_handler)
logger = logging.getLogger(__name__)

# --- Environment Mode ---
LOCAL_ONLY_MODE = os.getenv('LOCAL_ONLY', 'True').lower() in ('true', '1', 't')
if LOCAL_ONLY_MODE:
    logger.warning("Authentication is DISABLED. Running in LOCAL_ONLY mode.")


class AppPaths(BaseModel):
    BASE_DIR: Path = Path(__file__).resolve().parent
    UPLOADS_DIR: Path = UPLOADS_BASE
    PROCESSED_DIR: Path = PROCESSED_BASE
    CHUNK_TMP_DIR: Path = CHUNK_TMP_BASE
    DATABASE_URL: str = f"sqlite:///{BASE_DIR / 'jobs.db'}"
    HUEY_DB_PATH: str = str(BASE_DIR / "huey.db")
    CONFIG_DIR: Path = BASE_DIR / "config"
    SETTINGS_FILE: Path = CONFIG_DIR / "settings.yml"
    DEFAULT_SETTINGS_FILE: Path = BASE_DIR / "settings.default.yml"


PATHS = AppPaths()
APP_CONFIG: Dict[str, Any] = {}
PATHS.UPLOADS_DIR.mkdir(exist_ok=True, parents=True)
PATHS.PROCESSED_DIR.mkdir(exist_ok=True, parents=True)
PATHS.CHUNK_TMP_DIR.mkdir(exist_ok=True, parents=True)
PATHS.CONFIG_DIR.mkdir(exist_ok=True, parents=True)


def load_app_config():
    """
    Loads configuration from settings.yml, with a fallback to settings.default.yml,
    and finally to hardcoded defaults if both files are missing.
    """
    global APP_CONFIG
    try:
        # --- Primary Method: Attempt to load settings.yml ---
        with open(PATHS.SETTINGS_FILE, 'r', encoding='utf8') as f:
            cfg_raw = yaml.safe_load(f) or {}

        # This logic block is intentionally duplicated to maintain compatibility
        defaults = {
            "app_settings": {"max_file_size_mb": 100, "allowed_all_extensions": []},
            "transcription_settings": {"whisper": {"allowed_models": ["tiny", "base", "small"], "compute_type": "int8", "device": "cpu"}},
            "conversion_tools": {},
            "ocr_settings": {"ocrmypdf": {}},
            "auth_settings": {"oidc_client_id": "", "oidc_client_secret": "", "oidc_server_metadata_url": "", "admin_users": []}
        }
        cfg = defaults.copy()
        cfg.update(cfg_raw)  # Merge loaded settings into defaults
        app_settings = cfg.get("app_settings", {})
        max_mb = app_settings.get("max_file_size_mb", 100)
        app_settings["max_file_size_bytes"] = int(max_mb) * 1024 * 1024
        allowed = app_settings.get("allowed_all_extensions", [])
        if not isinstance(allowed, (list, set)):
            allowed = list(allowed)
        app_settings["allowed_all_extensions"] = set(allowed)
        cfg["app_settings"] = app_settings
        APP_CONFIG = cfg
        logger.info("Successfully loaded settings from settings.yml")

    except (FileNotFoundError, yaml.YAMLError) as e:
        logger.warning(f"Could not load settings.yml: {e}. Falling back to settings.default.yml...")
        try:
            # --- Fallback Method: Attempt to load settings.default.yml ---
            with open(PATHS.DEFAULT_SETTINGS_FILE, 'r', encoding='utf8') as f:
                cfg_raw = yaml.safe_load(f) or {}

            # The same processing logic is applied to the fallback file
            defaults = {
                "app_settings": {"max_file_size_mb": 100, "allowed_all_extensions": []},
                "transcription_settings": {"whisper": {"allowed_models": ["tiny", "base", "small"], "compute_type": "int8", "device": "cpu"}},
                "conversion_tools": {},
                "ocr_settings": {"ocrmypdf": {}},
                "auth_settings": {"oidc_client_id": "", "oidc_client_secret": "", "oidc_server_metadata_url": "", "admin_users": []}
            }
            cfg = defaults.copy()
            cfg.update(cfg_raw)  # Merge loaded settings into defaults
            app_settings = cfg.get("app_settings", {})
            max_mb = app_settings.get("max_file_size_mb", 100)
            app_settings["max_file_size_bytes"] = int(max_mb) * 1024 * 1024
            allowed = app_settings.get("allowed_all_extensions", [])
            if not isinstance(allowed, (list, set)):
                allowed = list(allowed)
            app_settings["allowed_all_extensions"] = set(allowed)
            cfg["app_settings"] = app_settings
            APP_CONFIG = cfg
            logger.info("Successfully loaded settings from settings.default.yml")

        except (FileNotFoundError, yaml.YAMLError) as e_fallback:
            # --- Final Failsafe: Use hardcoded defaults ---
            logger.error(f"CRITICAL: Fallback file settings.default.yml also failed: {e_fallback}. Using hardcoded defaults.")
            APP_CONFIG = {
                "app_settings": {"max_file_size_mb": 100, "max_file_size_bytes": 100 * 1024 * 1024, "allowed_all_extensions": set()},
                "transcription_settings": {"whisper": {"allowed_models": ["tiny", "base", "small"], "compute_type": "int8", "device": "cpu"}},
                "conversion_tools": {},
                "ocr_settings": {"ocrmypdf": {}},
                "auth_settings": {"oidc_client_id": "", "oidc_client_secret": "", "oidc_server_metadata_url": "", "admin_users": []}
            }


# --------------------------------------------------------------------------------
# --- 2. DATABASE & Schemas
# --------------------------------------------------------------------------------
engine = create_engine(
    PATHS.DATABASE_URL,
    connect_args={"check_same_thread": False, "timeout": 30},
    poolclass=NullPool,
)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base()
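

# SQLite tuning: WAL journal mode plus NORMAL synchronous lets the web process and the
# Huey worker read and write jobs.db concurrently with far fewer "database is locked" errors.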
@event.listens_for(engine, "connect")
def _set_sqlite_pragmas(dbapi_connection, connection_record):
    c = dbapi_connection.cursor()
    try:
        c.execute("PRAGMA journal_mode=WAL;")
        c.execute("PRAGMA synchronous=NORMAL;")
    finally:
        c.close()


class Job(Base):
    __tablename__ = "jobs"
    id = Column(String, primary_key=True, index=True)
    user_id = Column(String, index=True, nullable=True)
    task_type = Column(String, index=True)
    status = Column(String, default="pending")
    progress = Column(Integer, default=0)
    original_filename = Column(String)
    input_filepath = Column(String)
    input_filesize = Column(Integer, nullable=True)
    processed_filepath = Column(String, nullable=True)
    output_filesize = Column(Integer, nullable=True)
    result_preview = Column(Text, nullable=True)
    error_message = Column(Text, nullable=True)
    created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
    updated_at = Column(DateTime, default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))


def get_db():
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()


class JobCreate(BaseModel):
    id: str
    user_id: str | None = None
    task_type: str
    original_filename: str
    input_filepath: str
    input_filesize: int | None = None
    processed_filepath: str | None = None


class JobSchema(BaseModel):
    id: str
    task_type: str
    status: str
    progress: int
    original_filename: str
    input_filesize: int | None = None
    output_filesize: int | None = None
    processed_filepath: str | None = None
    result_preview: str | None = None
    error_message: str | None = None
    created_at: datetime
    updated_at: datetime
    model_config = ConfigDict(from_attributes=True)

    @field_serializer('created_at', 'updated_at')
    def serialize_dt(self, dt: datetime, _info):
        return dt.isoformat() + "Z"


class FinalizeUploadPayload(BaseModel):
    upload_id: str
    original_filename: str
    total_chunks: int
    task_type: str
    model_size: str = ""
    output_format: str = ""


# --------------------------------------------------------------------------------
# --- 3. CRUD OPERATIONS
# --------------------------------------------------------------------------------
def get_job(db: Session, job_id: str):
    return db.query(Job).filter(Job.id == job_id).first()


def get_jobs(db: Session, user_id: str | None = None, skip: int = 0, limit: int = 100):
    query = db.query(Job)
    if user_id:
        query = query.filter(Job.user_id == user_id)
    return query.order_by(Job.created_at.desc()).offset(skip).limit(limit).all()


def create_job(db: Session, job: JobCreate):
    db_job = Job(**job.model_dump())
    db.add(db_job)
    db.commit()
    db.refresh(db_job)
    return db_job


def update_job_status(db: Session, job_id: str, status: str, progress: int = None, error: str = None):
    db_job = get_job(db, job_id)
    if db_job:
        db_job.status = status
        if progress is not None:
            db_job.progress = progress
        if error:
            db_job.error_message = error
        db.commit()
        db.refresh(db_job)
    return db_job


def mark_job_as_completed(db: Session, job_id: str, output_filepath_str: str | None = None, preview: str | None = None):
    db_job = get_job(db, job_id)
    if db_job and db_job.status != 'cancelled':
        db_job.status = "completed"
        db_job.progress = 100
        if preview:
            db_job.result_preview = preview.strip()[:2000]
        if output_filepath_str:
            try:
                output_path = Path(output_filepath_str)
                if output_path.exists():
                    db_job.output_filesize = output_path.stat().st_size
            except Exception:
                logger.exception(f"Could not stat output file {output_filepath_str} for job {job_id}")
        db.commit()
    return db_job


# --------------------------------------------------------------------------------
# --- 4. BACKGROUND TASK SETUP
# --------------------------------------------------------------------------------
huey = SqliteHuey(filename=PATHS.HUEY_DB_PATH)
WHISPER_MODELS_CACHE: Dict[str, WhisperModel] = {}
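

# Whisper models are loaded lazily and kept in this process-wide cache. The semaphore
# defined above (MODEL_CONCURRENCY) bounds how many loads may run at once, and the cache
# is re-checked after acquiring it so concurrent workers do not load the same model twice.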
def get_whisper_model(model_size: str, whisper_settings: dict) -> WhisperModel:
    if model_size in WHISPER_MODELS_CACHE:
        logger.info(f"Reusing cached model '{model_size}'.")
        return WHISPER_MODELS_CACHE[model_size]
    with _model_semaphore:
        if model_size in WHISPER_MODELS_CACHE:
            return WHISPER_MODELS_CACHE[model_size]
        logger.info(f"Loading Whisper model '{model_size}'...")
        model = WhisperModel(model_size, device=whisper_settings.get("device", "cpu"), compute_type=whisper_settings.get('compute_type', 'int8'))
        WHISPER_MODELS_CACHE[model_size] = model
        logger.info(f"Model '{model_size}' loaded.")
        return model
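

# Thin wrapper around subprocess.run: applies the CPU/memory rlimits via preexec_fn,
# enforces a timeout, and surfaces a truncated stderr excerpt when the tool exits non-zero.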
def run_command(argv: TypingList[str], timeout: int = 300):
    try:
        res = subprocess.run(argv, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, timeout=timeout, preexec_fn=_limit_resources_preexec)
        if res.returncode != 0:
            raise Exception(f"Command failed with exit code {res.returncode}. Stderr: {res.stderr[:1000]}")
        return res
    except subprocess.TimeoutExpired:
        raise Exception(f"Command timed out after {timeout}s")
def validate_and_build_command(template_str: str, mapping: Dict[str, str]) -> TypingList[str]:
    fmt = Formatter()
    used = {fname for _, fname, _, _ in fmt.parse(template_str) if fname}
    ALLOWED_VARS = {"input", "output", "output_dir", "output_ext", "quality", "speed", "preset", "device", "dpi", "samplerate", "bitdepth", "filter"}
    bad = used - ALLOWED_VARS
    if bad:
        raise ValueError(f"Command template contains disallowed placeholders: {bad}")
    safe_mapping = dict(mapping)
    for name in used:
        if name not in safe_mapping:
            safe_mapping[name] = safe_mapping.get("output_ext", "") if name == "filter" else ""
    formatted = template_str.format(**safe_mapping)
    return shlex.split(formatted)
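

# Huey task: transcribe an uploaded audio file with faster-whisper, streaming segment-level
# progress into the jobs table, then atomically write the transcript and clean up the upload.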
@huey.task()
def run_transcription_task(job_id: str, input_path_str: str, output_path_str: str, model_size: str, whisper_settings: dict):
    db = SessionLocal()
    input_path = Path(input_path_str)
    try:
        job = get_job(db, job_id)
        if not job or job.status == 'cancelled': return
        update_job_status(db, job_id, "processing")
        model = get_whisper_model(model_size, whisper_settings)
        logger.info(f"Starting transcription for job {job_id}")
        segments, info = model.transcribe(str(input_path), beam_size=5)
        full_transcript = []
        for segment in segments:
            job_check = get_job(db, job_id)
            if job_check.status == 'cancelled':
                logger.info(f"Job {job_id} cancelled during transcription.")
                return
            if info.duration > 0:
                progress = int((segment.end / info.duration) * 100)
                update_job_status(db, job_id, "processing", progress=progress)
            full_transcript.append(segment.text.strip())
        transcript_text = "\n".join(full_transcript)
        out_path = Path(output_path_str)
        tmp_out = out_path.with_name(f"{out_path.stem}.tmp-{uuid.uuid4().hex}{out_path.suffix}")
        with tmp_out.open("w", encoding="utf-8") as f:
            f.write(transcript_text)
        tmp_out.replace(out_path)
        mark_job_as_completed(db, job_id, output_filepath_str=output_path_str, preview=transcript_text)
        logger.info(f"Transcription for job {job_id} completed.")
    except Exception as e:
        logger.exception(f"ERROR during transcription for job {job_id}")
        update_job_status(db, job_id, "failed", error=f"Transcription failed: {e}")
    finally:
        try:
            ensure_path_is_safe(input_path, [PATHS.UPLOADS_DIR, PATHS.CHUNK_TMP_DIR])
            input_path.unlink(missing_ok=True)
        except Exception:
            # swallow cleanup errors but log
            logger.exception("Failed to cleanup input file after transcription.")
        db.close()


@huey.task()
def run_pdf_ocr_task(job_id: str, input_path_str: str, output_path_str: str, ocr_settings: dict):
    db = SessionLocal()
    input_path = Path(input_path_str)
    try:
        job = get_job(db, job_id)
        if not job or job.status == 'cancelled':
            return
        update_job_status(db, job_id, "processing")
        logger.info(f"Starting PDF OCR for job {job_id}")
        ocrmypdf.ocr(str(input_path), str(output_path_str),
                     deskew=ocr_settings.get('deskew', True),
                     force_ocr=ocr_settings.get('force_ocr', True),
                     clean=ocr_settings.get('clean', True),
                     optimize=ocr_settings.get('optimize', 1),
                     progress_bar=False)
        with open(output_path_str, "rb") as f:
            reader = pypdf.PdfReader(f)
            preview = "\n".join(page.extract_text() or "" for page in reader.pages)
        mark_job_as_completed(db, job_id, output_filepath_str=output_path_str, preview=preview)
        logger.info(f"PDF OCR for job {job_id} completed.")
    except Exception as e:
        logger.exception(f"ERROR during PDF OCR for job {job_id}")
        update_job_status(db, job_id, "failed", error=f"PDF OCR failed: {e}")
    finally:
        try:
            ensure_path_is_safe(input_path, [PATHS.UPLOADS_DIR])
            input_path.unlink(missing_ok=True)
        except Exception:
            logger.exception("Failed to cleanup input file after PDF OCR.")
        db.close()


@huey.task()
def run_image_ocr_task(job_id: str, input_path_str: str, output_path_str: str):
    db = SessionLocal()
    input_path = Path(input_path_str)
    try:
        job = get_job(db, job_id)
        if not job or job.status == 'cancelled':
            return
        update_job_status(db, job_id, "processing", progress=50)
        logger.info(f"Starting Image OCR for job {job_id}")
        text = pytesseract.image_to_string(Image.open(str(input_path)))
        out_path = Path(output_path_str)
        tmp_out = out_path.with_name(f"{out_path.stem}.tmp-{uuid.uuid4().hex}{out_path.suffix}")
        with tmp_out.open("w", encoding="utf-8") as f:
            f.write(text)
        tmp_out.replace(out_path)
        mark_job_as_completed(db, job_id, output_filepath_str=output_path_str, preview=text)
        logger.info(f"Image OCR for job {job_id} completed.")
    except Exception as e:
        logger.exception(f"ERROR during Image OCR for job {job_id}")
        update_job_status(db, job_id, "failed", error=f"Image OCR failed: {e}")
    finally:
        try:
            ensure_path_is_safe(input_path, [PATHS.UPLOADS_DIR])
            input_path.unlink(missing_ok=True)
        except Exception:
            logger.exception("Failed to cleanup input file after Image OCR.")
        db.close()
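

# Huey task: run an external conversion tool (ghostscript, pngquant, sox, mozjpeg, libreoffice, ...)
# described in the conversion_tools config. The task_key encodes tool-specific options, which are
# translated into template placeholders before the command is validated and executed.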
@huey.task()
def run_conversion_task(job_id: str, input_path_str: str, output_path_str: str, tool: str, task_key: str, conversion_tools_config: dict):
    db = SessionLocal()
    input_path = Path(input_path_str)
    output_path = Path(output_path_str)
    temp_input_file = None
    temp_output_file = None
    try:
        job = get_job(db, job_id)
        if not job or job.status == 'cancelled':
            return
        update_job_status(db, job_id, "processing", progress=25)
        logger.info(f"Starting conversion for job {job_id} using {tool} with task {task_key}")
        tool_config = conversion_tools_config.get(tool)
        if not tool_config:
            raise ValueError(f"Unknown conversion tool: {tool}")

        current_input_path = input_path

        # Pre-processing for specific tools
        if tool == "mozjpeg":
            temp_input_file = input_path.with_suffix('.temp.ppm')
            logger.info(f"Pre-converting for MozJPEG: {input_path} -> {temp_input_file}")
            pre_conv_cmd = ["vips", "copy", str(input_path), str(temp_input_file)]
            pre_conv_result = subprocess.run(pre_conv_cmd, capture_output=True, text=True, check=False, timeout=tool_config.get("timeout", 300))
            if pre_conv_result.returncode != 0:
                err = (pre_conv_result.stderr or "")[:4000]
                raise Exception(f"MozJPEG pre-conversion to PPM failed: {err}")
            current_input_path = temp_input_file

        update_job_status(db, job_id, "processing", progress=50)

        # prepare temporary output and mapping
        temp_output_file = output_path.with_name(f"{output_path.stem}.tmp-{uuid.uuid4().hex}{output_path.suffix}")
        mapping = {
            "input": str(current_input_path),
            "output": str(temp_output_file),
            "output_dir": str(output_path.parent),
            "output_ext": output_path.suffix.lstrip('.'),
        }

        # tool specific mapping adjustments
        if tool.startswith("ghostscript"):
            # task_key form: "device_setting"
            parts = task_key.split('_', 1)
            device = parts[0] if parts else ""
            setting = parts[1] if len(parts) > 1 else ""
            mapping.update({"device": device, "dpi": setting, "preset": setting})
        elif tool == "pngquant":
            _, quality_key = task_key.split('_')
            quality_map = {"hq": "80-95", "mq": "65-80", "fast": "65-80"}
            speed_map = {"hq": "1", "mq": "3", "fast": "11"}
            mapping.update({"quality": quality_map.get(quality_key, "65-80"), "speed": speed_map.get(quality_key, "3")})
        elif tool == "sox":
            rate, depth = '', ''
            try:
                _, rate, depth = task_key.split('_')
                depth = ('-b' + depth.replace('b', '')) if 'b' in depth else '16b'
            except ValueError:
                _, rate = task_key.split('_')
                depth = ''

            rate = rate.replace('k', '000') if 'k' in rate else rate
            mapping.update({"samplerate": rate, "bitdepth": depth})
        elif tool == "mozjpeg":
            _, quality = task_key.split('_')
            quality = quality.replace('q', '')
            mapping.update({"quality": quality})
        elif tool == "libreoffice":
            target_ext = output_path.suffix.lstrip('.')
            filter_val = tool_config.get("filters", {}).get(target_ext, target_ext)
            mapping["filter"] = filter_val

        command_template_str = tool_config["command_template"]
        command = validate_and_build_command(command_template_str, mapping)
        logger.info(f"Executing command: {' '.join(command)}")

        result = run_command(command, timeout=tool_config.get("timeout", 300))

        if temp_output_file and temp_output_file.exists():
            temp_output_file.replace(output_path)

        mark_job_as_completed(db, job_id, output_filepath_str=str(output_path), preview="Successfully converted file.")
        logger.info(f"Conversion for job {job_id} completed.")
    except Exception as e:
        logger.exception(f"ERROR during conversion for job {job_id}")
        update_job_status(db, job_id, "failed", error=f"Conversion failed: {e}")
    finally:
        try:
            ensure_path_is_safe(input_path, [PATHS.UPLOADS_DIR])
            input_path.unlink(missing_ok=True)
        except Exception:
            logger.exception("Failed to cleanup main input file after conversion.")
        if temp_input_file:
            try:
                temp_input_file_path = Path(temp_input_file)
                ensure_path_is_safe(temp_input_file_path, [PATHS.UPLOADS_DIR, PATHS.PROCESSED_DIR])
                temp_input_file_path.unlink(missing_ok=True)
            except Exception:
                logger.exception("Failed to cleanup temp input file after conversion.")
        if temp_output_file:
            try:
                temp_output_file_path = Path(temp_output_file)
                ensure_path_is_safe(temp_output_file_path, [PATHS.UPLOADS_DIR, PATHS.PROCESSED_DIR])
                temp_output_file_path.unlink(missing_ok=True)
            except Exception:
                logger.exception("Failed to cleanup temp output file after conversion.")
        db.close()


# --------------------------------------------------------------------------------
# --- 5. FASTAPI APPLICATION
# --------------------------------------------------------------------------------
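# Startup: create the DB schema, load the YAML config, and (when auth is enabled)
# register the OIDC client with Authlib; shutdown only logs.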
@asynccontextmanager
async def lifespan(app: FastAPI):
    logger.info("Application starting up...")
    Base.metadata.create_all(bind=engine)
    load_app_config()
    ENV = os.environ.get('ENV', 'dev').lower()  # Redundant with load_dotenv() at import time, but kept as an explicit check here.
    ALLOW_LOCAL_ONLY = os.environ.get('ALLOW_LOCAL_ONLY', 'false').lower() == 'true'
    if LOCAL_ONLY_MODE and ENV != 'dev' and not ALLOW_LOCAL_ONLY:
        raise RuntimeError('LOCAL_ONLY_MODE may only be enabled in dev or when ALLOW_LOCAL_ONLY=true is set.')
    if not LOCAL_ONLY_MODE:
        oidc_cfg = APP_CONFIG.get('auth_settings', {})
        if not all(oidc_cfg.get(k) for k in ['oidc_client_id', 'oidc_client_secret', 'oidc_server_metadata_url']):
            logger.error("OIDC auth settings are incomplete. Auth will be disabled if not in LOCAL_ONLY_MODE.")
        else:
            oauth.register(
                name='oidc',
                client_id=oidc_cfg.get('oidc_client_id'),
                client_secret=oidc_cfg.get('oidc_client_secret'),
                server_metadata_url=oidc_cfg.get('oidc_server_metadata_url'),
                client_kwargs={'scope': 'openid email profile'},
                userinfo_endpoint=oidc_cfg.get('oidc_userinfo_endpoint'),
                end_session_endpoint=oidc_cfg.get('oidc_end_session_endpoint')
            )
            logger.info('OAuth registered.')
    yield
    logger.info('Application shutting down...')


app = FastAPI(lifespan=lifespan)
ENV = os.environ.get('ENV', 'dev').lower()
SECRET_KEY = os.environ.get('SECRET_KEY')
if not SECRET_KEY and not LOCAL_ONLY_MODE and ENV != 'dev':
    raise RuntimeError('SECRET_KEY must be set in production when authentication is enabled.')
if not SECRET_KEY:
    logger.warning('SECRET_KEY is not set. Generating a temporary key. Sessions will not persist across restarts.')
    SECRET_KEY = os.urandom(24).hex()

# NOTE: set https_only=True when the app is served over HTTPS in production.
app.add_middleware(
    SessionMiddleware,
    secret_key=SECRET_KEY,
    https_only=False,
    same_site='lax',
    max_age=14 * 24 * 60 * 60  # 14 days in seconds
)


# Static / templates
app.mount("/static", StaticFiles(directory=str(PATHS.BASE_DIR / "static")), name="static")
templates = Jinja2Templates(directory=str(PATHS.BASE_DIR / "templates"))


# --- AUTH & USER HELPERS ---
def get_current_user(request: Request):
    if LOCAL_ONLY_MODE:
        return {'sub': 'local_user', 'email': 'local@user.com', 'name': 'Local User'}
    return request.session.get('user')


def is_admin(request: Request) -> bool:
    if LOCAL_ONLY_MODE: return True
    user = get_current_user(request)
    if not user: return False
    admin_users = APP_CONFIG.get("auth_settings", {}).get("admin_users", [])
    return user.get('email') in admin_users


def require_user(request: Request):
    user = get_current_user(request)
    if not user: raise HTTPException(status_code=401, detail="Not authenticated")
    return user


def require_admin(request: Request):
    if not is_admin(request): raise HTTPException(status_code=403, detail="Administrator privileges required.")
    return True
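

# Chunked upload protocol: the client POSTs numbered chunks to /upload/chunk under a
# client-generated upload_id, then calls /upload/finalize, which stitches the chunks into a
# single file in UPLOADS_DIR and enqueues the requested background task.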
# --- CHUNKED UPLOADS ---
@app.post("/upload/chunk")
async def upload_chunk(
    chunk: UploadFile = File(...),
    upload_id: str = Form(...),
    chunk_number: int = Form(...),
    user: dict = Depends(require_user)  # AUTHENTICATION
):
    safe_upload_id = secure_filename(upload_id)
    temp_dir = PATHS.CHUNK_TMP_DIR / safe_upload_id
    temp_dir = ensure_path_is_safe(temp_dir, [PATHS.CHUNK_TMP_DIR])
    temp_dir.mkdir(exist_ok=True)
    chunk_path = temp_dir / f"{chunk_number}.chunk"

    try:
        with open(chunk_path, "wb") as buffer:
            shutil.copyfileobj(chunk.file, buffer)
    finally:
        chunk.file.close()
    return JSONResponse({"message": f"Chunk {chunk_number} for {safe_upload_id} uploaded."})


async def _stitch_chunks(temp_dir: Path, final_path: Path, total_chunks: int):
    """Stitches chunks together and cleans up."""
    ensure_path_is_safe(temp_dir, [PATHS.CHUNK_TMP_DIR])
    ensure_path_is_safe(final_path, [PATHS.UPLOADS_DIR])
    with open(final_path, "wb") as final_file:
        for i in range(total_chunks):
            chunk_path = temp_dir / f"{i}.chunk"
            if not chunk_path.exists():
                shutil.rmtree(temp_dir, ignore_errors=True)
                raise HTTPException(status_code=400, detail=f"Upload failed: missing chunk {i}")
            with open(chunk_path, "rb") as chunk_file:
                final_file.write(chunk_file.read())
    shutil.rmtree(temp_dir, ignore_errors=True)


@app.post("/upload/finalize", status_code=status.HTTP_202_ACCEPTED)
async def finalize_upload(payload: FinalizeUploadPayload, user: dict = Depends(require_user), db: Session = Depends(get_db)):
    safe_upload_id = secure_filename(payload.upload_id)
    temp_dir = PATHS.CHUNK_TMP_DIR / safe_upload_id
    temp_dir = ensure_path_is_safe(temp_dir, [PATHS.CHUNK_TMP_DIR])
    if not temp_dir.is_dir():
        raise HTTPException(status_code=404, detail="Upload session not found or already finalized.")

    job_id = uuid.uuid4().hex
    safe_filename = secure_filename(payload.original_filename)
    final_path = PATHS.UPLOADS_DIR / f"{Path(safe_filename).stem}_{job_id}{Path(safe_filename).suffix}"
    await _stitch_chunks(temp_dir, final_path, payload.total_chunks)

    job_data = JobCreate(
        id=job_id, user_id=user['sub'], task_type=payload.task_type,
        original_filename=payload.original_filename, input_filepath=str(final_path),
        input_filesize=final_path.stat().st_size
    )

    if payload.task_type == "transcription":
        stem = Path(safe_filename).stem
        processed_path = PATHS.PROCESSED_DIR / f"{stem}_{job_id}.txt"
        job_data.processed_filepath = str(processed_path)
        create_job(db=db, job=job_data)
        run_transcription_task(job_id, str(final_path), str(processed_path), payload.model_size, APP_CONFIG.get("transcription_settings", {}).get("whisper", {}))
    elif payload.task_type == "ocr":
        stem, suffix = Path(safe_filename).stem, Path(safe_filename).suffix
        processed_path = PATHS.PROCESSED_DIR / f"{stem}_{job_id}{suffix}"
        job_data.processed_filepath = str(processed_path)
        create_job(db=db, job=job_data)
        run_pdf_ocr_task(job_id, str(final_path), str(processed_path), APP_CONFIG.get("ocr_settings", {}).get("ocrmypdf", {}))
    elif payload.task_type == "conversion":
        try:
            tool, task_key = payload.output_format.split('_', 1)
        except Exception:
            final_path.unlink(missing_ok=True)
            raise HTTPException(status_code=400, detail="Invalid output_format for conversion.")
        original_stem = Path(safe_filename).stem
        target_ext = task_key.split('_')[0]
        if tool == "ghostscript_pdf": target_ext = "pdf"
        processed_path = PATHS.PROCESSED_DIR / f"{original_stem}_{job_id}.{target_ext}"
        job_data.processed_filepath = str(processed_path)
        create_job(db=db, job=job_data)
        run_conversion_task(job_id, str(final_path), str(processed_path), tool, task_key, APP_CONFIG.get("conversion_tools", {}))
    else:
        final_path.unlink(missing_ok=True)
        raise HTTPException(status_code=400, detail="Invalid task type.")

    return {"job_id": job_id, "status": "pending"}


# --- OLD DIRECT-UPLOAD ROUTES (kept for compatibility) ---
# These use the same task functions but accept direct file uploads (no chunking).
async def save_upload_file_chunked(upload_file: UploadFile, destination: Path) -> int:
    """
    Write upload to a tmp file in chunks, then atomically move to final destination.
    Returns the final size of the file in bytes.
    """
    max_size = APP_CONFIG.get("app_settings", {}).get("max_file_size_bytes", 100 * 1024 * 1024)
    tmp = destination.with_name(f"{destination.stem}.tmp-{uuid.uuid4().hex}{destination.suffix}")
    size = 0
    try:
        with tmp.open("wb") as buffer:
            while True:
                chunk = await upload_file.read(1024 * 1024)
                if not chunk:
                    break
                size += len(chunk)
                if size > max_size:
                    raise HTTPException(status_code=413, detail=f"File exceeds {max_size / 1024 / 1024} MB limit")
                buffer.write(chunk)
        tmp.replace(destination)
        return size
    except Exception:
        try:
            ensure_path_is_safe(tmp, [PATHS.UPLOADS_DIR, PATHS.CHUNK_TMP_DIR])
            tmp.unlink(missing_ok=True)
        except Exception:
            logger.exception("Failed to remove temp upload file after error.")
        raise


def is_allowed_file(filename: str, allowed_extensions: set) -> bool:
    return Path(filename).suffix.lower() in allowed_extensions


@app.post("/transcribe-audio", status_code=status.HTTP_202_ACCEPTED)
async def submit_audio_transcription(
    file: UploadFile = File(...),
    model_size: str = Form("base"),
    db: Session = Depends(get_db),
    user: dict = Depends(require_user)
):
    allowed_audio_exts = {".mp3", ".wav", ".m4a", ".flac", ".ogg", ".opus"}
    if not is_allowed_file(file.filename, allowed_audio_exts):
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid audio file type.")

    whisper_config = APP_CONFIG.get("transcription_settings", {}).get("whisper", {})
    if model_size not in whisper_config.get("allowed_models", []):
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid model size: {model_size}.")

    job_id = uuid.uuid4().hex
    safe_basename = secure_filename(file.filename)
    stem, suffix = Path(safe_basename).stem, Path(safe_basename).suffix

    audio_filename = f"{stem}_{job_id}{suffix}"
    transcript_filename = f"{stem}_{job_id}.txt"
    upload_path = PATHS.UPLOADS_DIR / audio_filename
    processed_path = PATHS.PROCESSED_DIR / transcript_filename

    input_size = await save_upload_file_chunked(file, upload_path)

    job_data = JobCreate(
        id=job_id,
        user_id=user['sub'],
        task_type="transcription",
        original_filename=file.filename,
        input_filepath=str(upload_path),
        input_filesize=input_size,
        processed_filepath=str(processed_path)
    )
    new_job = create_job(db=db, job=job_data)

    run_transcription_task(new_job.id, str(upload_path), str(processed_path), model_size, whisper_settings=whisper_config)

    return {"job_id": new_job.id, "status": new_job.status, "status_url": f"/job/{new_job.id}"}


@app.post("/convert-file", status_code=status.HTTP_202_ACCEPTED)
async def submit_file_conversion(file: UploadFile = File(...), output_format: str = Form(...), db: Session = Depends(get_db), user: dict = Depends(require_user)):
    allowed_exts = APP_CONFIG.get("app_settings", {}).get("allowed_all_extensions", set())
    if not is_allowed_file(file.filename, allowed_exts):
        raise HTTPException(status_code=400, detail=f"File type '{Path(file.filename).suffix}' not allowed.")
    conversion_tools = APP_CONFIG.get("conversion_tools", {})
    try:
        tool, task_key = output_format.split('_', 1)
        if tool not in conversion_tools or task_key not in conversion_tools[tool].get("formats", {}):
            # fallback: allow tasks that exist but may not be in formats map (some configs only have commands)
            if tool not in conversion_tools:
                raise ValueError()
    except ValueError:
        raise HTTPException(status_code=400, detail="Invalid output format selected.")
    job_id = uuid.uuid4().hex
    safe_basename = secure_filename(file.filename)
    original_stem = Path(safe_basename).stem
    target_ext = task_key.split('_')[0]
    if tool == "ghostscript_pdf":
        target_ext = "pdf"
    upload_filename = f"{original_stem}_{job_id}{Path(safe_basename).suffix}"
    processed_filename = f"{original_stem}_{job_id}.{target_ext}"
    upload_path = PATHS.UPLOADS_DIR / upload_filename
    processed_path = PATHS.PROCESSED_DIR / processed_filename
    input_size = await save_upload_file_chunked(file, upload_path)
    job_data = JobCreate(id=job_id, user_id=user['sub'], task_type="conversion", original_filename=file.filename,
                         input_filepath=str(upload_path),
                         input_filesize=input_size,
                         processed_filepath=str(processed_path))
    new_job = create_job(db=db, job=job_data)
    run_conversion_task(new_job.id, str(upload_path), str(processed_path), tool, task_key, conversion_tools)
    return {"job_id": new_job.id, "status": new_job.status, "status_url": f"/job/{new_job.id}"}


@app.post("/ocr-pdf", status_code=status.HTTP_202_ACCEPTED)
async def submit_pdf_ocr(file: UploadFile = File(...), db: Session = Depends(get_db), user: dict = Depends(require_user)):
    if not is_allowed_file(file.filename, {".pdf"}):
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid file type. Please upload a PDF.")
    job_id = uuid.uuid4().hex
    safe_basename = secure_filename(file.filename)
    unique_filename = f"{Path(safe_basename).stem}_{job_id}{Path(safe_basename).suffix}"
    upload_path = PATHS.UPLOADS_DIR / unique_filename
    processed_path = PATHS.PROCESSED_DIR / unique_filename
    input_size = await save_upload_file_chunked(file, upload_path)
    job_data = JobCreate(id=job_id, user_id=user['sub'], task_type="ocr", original_filename=file.filename,
                         input_filepath=str(upload_path),
                         input_filesize=input_size,
                         processed_filepath=str(processed_path))
    new_job = create_job(db=db, job=job_data)
    ocr_settings = APP_CONFIG.get("ocr_settings", {}).get("ocrmypdf", {})
    run_pdf_ocr_task(new_job.id, str(upload_path), str(processed_path), ocr_settings)
    return {"job_id": new_job.id, "status": new_job.status, "status_url": f"/job/{new_job.id}"}


@app.post("/ocr-image", status_code=status.HTTP_202_ACCEPTED)
async def submit_image_ocr(file: UploadFile = File(...), db: Session = Depends(get_db), user: dict = Depends(require_user)):
    allowed_exts = {".png", ".jpg", ".jpeg", ".tiff", ".tif"}
    if not is_allowed_file(file.filename, allowed_exts):
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid file type. Please upload a PNG, JPG, or TIFF.")
    job_id = uuid.uuid4().hex
    safe_basename = secure_filename(file.filename)
    file_ext = Path(safe_basename).suffix
    unique_filename = f"{Path(safe_basename).stem}_{job_id}{file_ext}"
    upload_path = PATHS.UPLOADS_DIR / unique_filename
    processed_path = PATHS.PROCESSED_DIR / f"{Path(safe_basename).stem}_{job_id}.txt"
    input_size = await save_upload_file_chunked(file, upload_path)
    job_data = JobCreate(id=job_id, user_id=user['sub'], task_type="ocr-image", original_filename=file.filename,
                         input_filepath=str(upload_path),
                         input_filesize=input_size,
                         processed_filepath=str(processed_path))
    new_job = create_job(db=db, job=job_data)
    run_image_ocr_task(new_job.id, str(upload_path), str(processed_path))
    return {"job_id": new_job.id, "status": new_job.status, "status_url": f"/job/{new_job.id}"}


# --- Routes for auth and pages ---
if not LOCAL_ONLY_MODE:
    @app.get('/login')
    async def login(request: Request):
        redirect_uri = request.url_for('auth')
        return await oauth.oidc.authorize_redirect(request, redirect_uri)

    @app.get('/auth')
    async def auth(request: Request):
        try:
            token = await oauth.oidc.authorize_access_token(request)
            user = await oauth.oidc.userinfo(token=token)
            request.session['user'] = dict(user)
            # Store id_token in session for logout
            request.session['id_token'] = token.get('id_token')
        except Exception as e:
            logger.error(f"Authentication failed: {e}")
            raise HTTPException(status_code=401, detail="Authentication failed")
        return RedirectResponse(url='/')

    @app.get("/logout")
    async def logout(request: Request):
        logout_endpoint = oauth.oidc.server_metadata.get("end_session_endpoint")
        logger.info(f"OIDC end_session_endpoint: {logout_endpoint}")

        # local-only logout if provider doesn't expose end_session_endpoint
        if not logout_endpoint:
            request.session.clear()
            logger.warning("OIDC 'end_session_endpoint' not found. Performing local-only logout.")
            return RedirectResponse(url="/", status_code=302)

        # Prefer a single canonical / registered post-logout redirect URI from config
        post_logout_redirect_uri = str(request.url_for("get_index"))
        logger.info(f"Post logout redirect URI: {post_logout_redirect_uri}")

        logout_url = f"{logout_endpoint}?{urlencode({'post_logout_redirect_uri': post_logout_redirect_uri})}"

        logger.info(f"Redirecting to provider logout URL: {logout_url}")

        request.session.clear()
        return RedirectResponse(url=logout_url, status_code=302)

    #### TODO: Remove this forward-auth endpoint; it is only needed when a reverse proxy performs forward auth.
    @app.get("/api/authz/forward-auth")
    async def forward_auth(request: Request):
        redirect_uri = request.url_for('auth')
        return await oauth.oidc.authorize_redirect(request, redirect_uri)


@app.get("/")
async def get_index(request: Request):
    user = get_current_user(request)
    admin_status = is_admin(request)
    whisper_models = APP_CONFIG.get("transcription_settings", {}).get("whisper", {}).get("allowed_models", [])
    conversion_tools = APP_CONFIG.get("conversion_tools", {})
    return templates.TemplateResponse("index.html", {
        "request": request,
        "user": user,
        "is_admin": admin_status,
        "whisper_models": sorted(list(whisper_models)),
        "conversion_tools": conversion_tools,
        "local_only_mode": LOCAL_ONLY_MODE
    })


@app.get("/settings")
async def get_settings_page(request: Request):
    """
    Displays the contents of the currently active configuration file.
    It prioritizes settings.yml and falls back to settings.default.yml.
    """
    user = get_current_user(request)
    admin_status = is_admin(request)
    current_config = {}
    config_source = "none"  # A helper variable to track which file was loaded

    try:
        # 1. Attempt to load the primary, user-provided settings.yml
        with open(PATHS.SETTINGS_FILE, 'r', encoding='utf8') as f:
            current_config = yaml.safe_load(f) or {}
        config_source = str(PATHS.SETTINGS_FILE.name)
        logger.info(f"Displaying configuration from '{config_source}' on settings page.")

    except FileNotFoundError:
        logger.warning(f"'{PATHS.SETTINGS_FILE.name}' not found. Attempting to display fallback configuration.")
        try:
            # 2. If it's not found, fall back to the default settings file
            with open(PATHS.DEFAULT_SETTINGS_FILE, 'r', encoding='utf8') as f:
                current_config = yaml.safe_load(f) or {}
            config_source = str(PATHS.DEFAULT_SETTINGS_FILE.name)
            logger.info(f"Displaying configuration from fallback '{config_source}' on settings page.")

        except Exception as e_fallback:
            # 3. If even the default file fails, log the error and use an empty config
            logger.exception(f"CRITICAL: Could not load fallback '{PATHS.DEFAULT_SETTINGS_FILE.name}' for settings page: {e_fallback}")
            current_config = {}  # Failsafe
            config_source = "error"

    except Exception as e_primary:
        # Handles other errors with the primary settings.yml (e.g., parsing errors, permissions)
        logger.exception(f"Could not load '{PATHS.SETTINGS_FILE.name}' for settings page: {e_primary}")
        current_config = {}  # Failsafe
        config_source = "error"

    return templates.TemplateResponse(
        "settings.html",
        {
            "request": request,
            "config": current_config,
            "config_source": config_source,  # You can use this in the template!
            "user": user,
            "is_admin": admin_status,
            "local_only_mode": LOCAL_ONLY_MODE,
        }
    )


import collections.abc


def deep_merge(source: dict, destination: dict) -> dict:
    """
    Recursively merges the `source` dictionary into the `destination` dictionary.

    Values from `source` will overwrite values in `destination`.
    """
    for key, value in source.items():
        if isinstance(value, collections.abc.Mapping):
            # If the value is a dictionary, recurse
            node = destination.setdefault(key, {})
            deep_merge(value, node)
        else:
            # Otherwise, overwrite the value
            destination[key] = value
    return destination


@app.post("/settings/save")
async def save_settings(
    request: Request,
    new_config_from_ui: Dict = Body(...),
    admin: bool = Depends(require_admin)
):
    """
    Safely updates settings.yml by merging UI changes with the existing file,
    preserving any settings not managed by the UI.
    """
    tmp_path = PATHS.SETTINGS_FILE.with_suffix(".tmp")
    user = get_current_user(request)

    try:
        # Handle the special case where the user wants to revert to defaults
        if not new_config_from_ui:
            if PATHS.SETTINGS_FILE.exists():
                PATHS.SETTINGS_FILE.unlink()
                logger.info(f"Admin '{user.get('email')}' reverted to default settings by deleting settings.yml.")
            load_app_config()
            return JSONResponse({"message": "Settings reverted to default."})

        # --- Read-Modify-Write Cycle ---

        # 1. READ: Load the current configuration from settings.yml on disk.
        #    If the file doesn't exist, start with an empty dictionary.
        try:
            with PATHS.SETTINGS_FILE.open("r", encoding="utf8") as f:
                current_config_on_disk = yaml.safe_load(f) or {}
        except FileNotFoundError:
            current_config_on_disk = {}

        # 2. MODIFY: Deep merge the changes from the UI into the config from the disk.
        #    The UI config (`source`) overwrites keys in the disk config (`destination`).
        merged_config = deep_merge(source=new_config_from_ui, destination=current_config_on_disk)

        # 3. WRITE: Save the fully merged configuration back to the file.
        with tmp_path.open("w", encoding="utf8") as f:
            yaml.safe_dump(merged_config, f, default_flow_style=False, sort_keys=False)

        tmp_path.replace(PATHS.SETTINGS_FILE)
        logger.info(f"Admin '{user.get('email')}' successfully updated settings.yml.")

        # Reload the app config to apply changes immediately
        load_app_config()

        return JSONResponse({"message": "Settings saved successfully. The new configuration is now active."})

    except Exception as e:
        logger.exception(f"Failed to update settings for admin '{user.get('email')}'")
        if tmp_path.exists():
            tmp_path.unlink()
        raise HTTPException(status_code=500, detail=f"Could not save settings.yml: {e}")


# job management endpoints

@app.post("/settings/clear-history")
async def clear_job_history(db: Session = Depends(get_db), user: dict = Depends(require_user)):
    try:
        num_deleted = db.query(Job).filter(Job.user_id == user['sub']).delete()
        db.commit()
        logger.info(f"Cleared {num_deleted} jobs from history for user {user['sub']}.")
        return {"deleted_count": num_deleted}
    except Exception:
        db.rollback()
        logger.exception("Failed to clear job history")
        raise HTTPException(status_code=500, detail="Database error while clearing history.")


@app.post("/settings/delete-files")
async def delete_processed_files(db: Session = Depends(get_db), user: dict = Depends(require_user)):
    deleted_count = 0
    errors = []
    user_jobs = get_jobs(db, user_id=user['sub'])
    for job in user_jobs:
        if job.processed_filepath:
            try:
                p = ensure_path_is_safe(Path(job.processed_filepath), [PATHS.PROCESSED_DIR])
                if p.is_file():
                    p.unlink()
                    deleted_count += 1
            except Exception as e:
                errors.append(Path(job.processed_filepath).name)
                logger.exception(f"Could not delete processed file {Path(job.processed_filepath).name}")
    if errors:
        raise HTTPException(status_code=500, detail=f"Could not delete some files: {', '.join(errors)}")
    logger.info(f"Deleted {deleted_count} files from processed directory for user {user['sub']}.")
    return {"deleted_count": deleted_count}


@app.post("/job/{job_id}/cancel", status_code=status.HTTP_202_ACCEPTED)
async def cancel_job(job_id: str, db: Session = Depends(get_db), user: dict = Depends(require_user)):
    job = get_job(db, job_id)
    if not job or job.user_id != user['sub']:
        raise HTTPException(status_code=404, detail="Job not found.")
    if job.status in ["pending", "processing"]:
        update_job_status(db, job_id, status="cancelled")
        return {"message": "Job cancellation requested."}
    raise HTTPException(status_code=400, detail=f"Job is already in a final state ({job.status}).")


@app.get("/jobs", response_model=List[JobSchema])
async def get_all_jobs(db: Session = Depends(get_db), user: dict = Depends(require_user)):
    return get_jobs(db, user_id=user['sub'])


@app.get("/job/{job_id}", response_model=JobSchema)
async def get_job_status(job_id: str, db: Session = Depends(get_db), user: dict = Depends(require_user)):
    job = get_job(db, job_id)
    if not job or job.user_id != user['sub']:
        raise HTTPException(status_code=404, detail="Job not found.")
    return job
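

# Downloads are gated on ownership: the requested filename must resolve inside PROCESSED_DIR
# and match a job whose user_id is the current user, otherwise the request is rejected.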
@app.get("/download/{filename}")
async def download_file(filename: str, db: Session = Depends(get_db), user: dict = Depends(require_user)):
    safe_filename = secure_filename(filename)
    file_path = ensure_path_is_safe(PATHS.PROCESSED_DIR / safe_filename, [PATHS.PROCESSED_DIR])
    if not file_path.is_file():
        raise HTTPException(status_code=404, detail="File not found.")
    job = db.query(Job).filter(Job.processed_filepath == str(file_path), Job.user_id == user['sub']).first()
    if not job:
        raise HTTPException(status_code=403, detail="You do not have permission to download this file.")
    download_filename = Path(job.original_filename).stem + Path(job.processed_filepath).suffix
    return FileResponse(path=file_path, filename=download_filename, media_type="application/octet-stream")


@app.get("/health")
async def health():
    try:
        with engine.connect() as conn:
            conn.execute(text("SELECT 1"))
    except Exception:
        logger.exception("Health check failed")
        return JSONResponse({"ok": False}, status_code=500)
    return {"ok": True}


@app.get('/favicon.ico', include_in_schema=False)
async def favicon():
    return FileResponse(str(PATHS.BASE_DIR / 'static' / 'favicon.png'))