stability and settings

2025-09-17 12:36:24 +00:00
parent 4d586a46e9
commit 2115238217
9 changed files with 1271 additions and 265 deletions

main.py (571 lines changed)

@@ -1,64 +1,99 @@
import logging
import shutil
import subprocess
import traceback
import uuid
import shlex
import yaml
from contextlib import asynccontextmanager
from datetime import datetime
from pathlib import Path
from typing import List, Set
from typing import Dict, List, Any
import ocrmypdf
import pypdf
import pytesseract
from PIL import Image
from faster_whisper import WhisperModel
# MODIFICATION: Added Form for model selection
from fastapi import (Depends, FastAPI, File, Form, HTTPException, Request,
UploadFile, status)
from fastapi.responses import FileResponse
UploadFile, status, Body)
from fastapi.responses import FileResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from huey import SqliteHuey
from pydantic import BaseModel, ConfigDict
from pydantic_settings import BaseSettings
from sqlalchemy import (Column, DateTime, Integer, String, Text,
create_engine)
create_engine, delete, event)
from sqlalchemy.pool import NullPool
from string import Formatter
from sqlalchemy.orm import Session, declarative_base, sessionmaker
from werkzeug.utils import secure_filename
# --------------------------------------------------------------------------------
# --- 1. CONFIGURATION
# --------------------------------------------------------------------------------
class Settings(BaseSettings):
class AppPaths(BaseModel):
BASE_DIR: Path = Path(__file__).resolve().parent
UPLOADS_DIR: Path = BASE_DIR / "uploads"
PROCESSED_DIR: Path = BASE_DIR / "processed"
DATABASE_URL: str = f"sqlite:///{BASE_DIR / 'jobs.db'}"
HUEY_DB_PATH: str = str(BASE_DIR / "huey.db")
# MODIFICATION: Removed hardcoded model size, added a set of allowed models
WHISPER_COMPUTE_TYPE: str = "int8"
ALLOWED_WHISPER_MODELS: Set[str] = {"tiny", "base", "small", "medium", "large-v3", "distil-large-v2"}
MAX_FILE_SIZE_BYTES: int = 500 * 1024 * 1024 # 500 MB
ALLOWED_PDF_EXTENSIONS: set = {".pdf"}
ALLOWED_IMAGE_EXTENSIONS: set = {".png", ".jpg", ".jpeg", ".tiff", ".tif"}
ALLOWED_AUDIO_EXTENSIONS: set = {".mp3", "m4a", ".ogg", ".flac", ".opus"}
SETTINGS_FILE: Path = BASE_DIR / "settings.yml"
settings = Settings()
PATHS = AppPaths()
APP_CONFIG: Dict[str, Any] = {}
def load_app_config():
global APP_CONFIG
try:
with open(PATHS.SETTINGS_FILE, 'r') as f:
APP_CONFIG = yaml.safe_load(f)
APP_CONFIG['app_settings']['max_file_size_bytes'] = APP_CONFIG['app_settings']['max_file_size_mb'] * 1024 * 1024
allowed_extensions = {
".pdf", ".ps", ".eps", ".png", ".jpg", ".jpeg", ".tiff", ".tif", ".gif",
".bmp", ".webp", ".svg", ".jxl", ".avif", ".ppm", ".mp3", ".m4a", ".ogg",
".flac", ".opus", ".wav", ".aac", ".mp4", ".mkv", ".mov", ".webm", ".avi",
".flv", ".md", ".txt", ".html", ".docx", ".odt", ".rst", ".epub", ".mobi",
".azw3", ".pptx", ".xlsx"
}
APP_CONFIG['app_settings']['allowed_all_extensions'] = allowed_extensions
logger.info("Successfully loaded settings from settings.yml")
except (FileNotFoundError, yaml.YAMLError) as e:
logger.error(f"Could not load settings.yml: {e}. App may not function correctly.")
APP_CONFIG = {}
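For orientation, a minimal sketch of the settings.yml shape that load_app_config() and the routes below rely on. Only the keys actually read in this file (app_settings.max_file_size_mb, transcription_settings.whisper, ocr_settings.ocrmypdf, and conversion_tools with formats/command_template/timeout) come from the code; the concrete values and the pngquant entry are illustrative assumptions.
# Illustrative only -- not part of this commit: expected shape of APP_CONFIG
# after yaml.safe_load() of settings.yml. Values and the pngquant entry are assumed.
EXAMPLE_APP_CONFIG = {
    "app_settings": {"max_file_size_mb": 500},
    "transcription_settings": {
        "whisper": {
            "allowed_models": ["tiny", "base", "small", "medium"],  # assumed list
            "compute_type": "int8",
        },
    },
    "ocr_settings": {
        "ocrmypdf": {"deskew": True, "force_ocr": True, "clean": True, "optimize": 1},
    },
    "conversion_tools": {
        "pngquant": {  # hypothetical tool entry
            "formats": {"png_hq": "PNG (high quality)", "png_fast": "PNG (fast)"},
            "command_template": "pngquant --quality {quality} --speed {speed} --output {output} {input}",
            "timeout": 300,
        },
    },
}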
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
settings.UPLOADS_DIR.mkdir(exist_ok=True)
settings.PROCESSED_DIR.mkdir(exist_ok=True)
PATHS.UPLOADS_DIR.mkdir(exist_ok=True)
PATHS.PROCESSED_DIR.mkdir(exist_ok=True)
# --------------------------------------------------------------------------------
# --- 2. DATABASE (for Job Tracking) - NO CHANGES
# --- 2. DATABASE & Schemas
# --------------------------------------------------------------------------------
engine = create_engine(settings.DATABASE_URL, connect_args={"check_same_thread": False})
engine = create_engine(
PATHS.DATABASE_URL,
connect_args={"check_same_thread": False, "timeout": 30},
poolclass=NullPool,
)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# THIS IS THE CRITICAL FIX
Base = declarative_base()
@event.listens_for(engine, "connect")
def _set_sqlite_pragmas(dbapi_connection, connection_record):
"""
Enable WAL mode and set sane synchronous for better concurrency
between the FastAPI process and Huey worker processes.
"""
c = dbapi_connection.cursor()
try:
c.execute("PRAGMA journal_mode=WAL;")
c.execute("PRAGMA synchronous=NORMAL;")
finally:
c.close()
class Job(Base):
__tablename__ = "jobs"
id = Column(String, primary_key=True, index=True)
@@ -80,10 +115,6 @@ def get_db():
finally:
db.close()
# --------------------------------------------------------------------------------
# --- 3. PYDANTIC SCHEMAS (Data Validation) - NO CHANGES
# --------------------------------------------------------------------------------
class JobCreate(BaseModel):
id: str
task_type: str
@@ -104,9 +135,8 @@ class JobSchema(BaseModel):
updated_at: datetime
model_config = ConfigDict(from_attributes=True)
# --------------------------------------------------------------------------------
# --- 4. CRUD OPERATIONS (Database Interactions) - NO CHANGES
# --- 3. CRUD OPERATIONS (No Changes)
# --------------------------------------------------------------------------------
def get_job(db: Session, job_id: str):
return db.query(Job).filter(Job.id == job_id).first()
@@ -143,37 +173,101 @@ def mark_job_as_completed(db: Session, job_id: str, preview: str | None = None):
db.commit()
return db_job
# --------------------------------------------------------------------------------
# --- 5. BACKGROUND TASKS (Huey)
# --- 4. BACKGROUND TASK SETUP
# --------------------------------------------------------------------------------
huey = SqliteHuey(filename=settings.HUEY_DB_PATH)
huey = SqliteHuey(filename=PATHS.HUEY_DB_PATH)
# --- START: NEW WHISPER MODEL CACHING ---
# This dictionary will live in the memory of the Huey worker process,
# allowing us to reuse loaded models across tasks.
WHISPER_MODELS_CACHE: Dict[str, WhisperModel] = {}
def get_whisper_model(model_size: str, whisper_settings: dict) -> WhisperModel:
"""
Loads a Whisper model into the cache if not present, and returns the model.
This ensures a model is only loaded into memory once per worker process.
"""
if model_size not in WHISPER_MODELS_CACHE:
compute_type = whisper_settings.get('compute_type', 'int8')
logger.info(f"Whisper model '{model_size}' not in cache. Loading into memory...")
model = WhisperModel(model_size, device="cpu", compute_type=compute_type)
WHISPER_MODELS_CACHE[model_size] = model
logger.info(f"Model '{model_size}' loaded successfully.")
else:
logger.info(f"Found model '{model_size}' in cache. Reusing.")
return WHISPER_MODELS_CACHE[model_size]
# --- END: NEW WHISPER MODEL CACHING ---
# MODIFICATION: Removed global whisper model and lazy loader.
# The model will now be loaded inside the task itself based on user selection.
@huey.task()
def run_pdf_ocr_task(job_id: str, input_path_str: str, output_path_str: str):
def run_transcription_task(job_id: str, input_path_str: str, output_path_str: str, model_size: str, whisper_settings: dict):
db = SessionLocal()
try:
job = get_job(db, job_id)
if not job or job.status == 'cancelled':
logger.info(f"Job {job_id} was cancelled before starting.")
return
if not job or job.status == 'cancelled': return
update_job_status(db, job_id, "processing")
# --- MODIFIED: Use the caching function to get the model ---
model = get_whisper_model(model_size, whisper_settings)
logger.info(f"Starting transcription for job {job_id}")
segments, info = model.transcribe(input_path_str, beam_size=5)
full_transcript = []
for segment in segments:
job_check = get_job(db, job_id) # Check for cancellation during long tasks
if not job_check or job_check.status == 'cancelled':
logger.info(f"Job {job_id} cancelled during transcription.")
return
if info.duration > 0:
progress = int((segment.end / info.duration) * 100)
update_job_status(db, job_id, "processing", progress=progress)
full_transcript.append(segment.text.strip())
transcript_text = "\n".join(full_transcript)
# write atomically to avoid partial files
out_path = Path(output_path_str)
tmp_out = out_path.with_suffix(out_path.suffix + f".{uuid.uuid4().hex}.tmp")
with tmp_out.open("w", encoding="utf-8") as f:
f.write(transcript_text)
tmp_out.replace(out_path)
mark_job_as_completed(db, job_id, preview=transcript_text)
logger.info(f"Transcription for job {job_id} completed.")
except Exception:
logger.exception(f"ERROR during transcription for job {job_id}")
update_job_status(db, job_id, "failed", error="See server logs for details.")
finally:
Path(input_path_str).unlink(missing_ok=True)
db.close()
# Other tasks remain unchanged
@huey.task()
def run_pdf_ocr_task(job_id: str, input_path_str: str, output_path_str: str, ocr_settings: dict):
db = SessionLocal()
try:
job = get_job(db, job_id)
if not job or job.status == 'cancelled': return
update_job_status(db, job_id, "processing")
logger.info(f"Starting PDF OCR for job {job_id}")
ocrmypdf.ocr(input_path_str, output_path_str, deskew=True, force_ocr=True, clean=True, optimize=1, progress_bar=False)
ocrmypdf.ocr(input_path_str, output_path_str,
deskew=ocr_settings.get('deskew', True),
force_ocr=ocr_settings.get('force_ocr', True),
clean=ocr_settings.get('clean', True),
optimize=ocr_settings.get('optimize', 1),
progress_bar=False)
with open(output_path_str, "rb") as f:
reader = pypdf.PdfReader(f)
preview = "\n".join(page.extract_text() or "" for page in reader.pages)
mark_job_as_completed(db, job_id, preview=preview)
logger.info(f"PDF OCR for job {job_id} completed.")
except Exception as e:
logger.error(f"ERROR during PDF OCR for job {job_id}: {e}\n{traceback.format_exc()}")
update_job_status(db, job_id, "failed", error=str(e))
except Exception:
logger.exception(f"ERROR during PDF OCR for job {job_id}")
update_job_status(db, job_id, "failed", error="See server logs for details.")
finally:
Path(input_path_str).unlink(missing_ok=True)
db.close()
@@ -183,10 +277,7 @@ def run_image_ocr_task(job_id: str, input_path_str: str, output_path_str: str):
db = SessionLocal()
try:
job = get_job(db, job_id)
if not job or job.status == 'cancelled':
logger.info(f"Job {job_id} was cancelled before starting.")
return
if not job or job.status == 'cancelled': return
update_job_status(db, job_id, "processing", progress=50)
logger.info(f"Starting Image OCR for job {job_id}")
text = pytesseract.image_to_string(Image.open(input_path_str))
@@ -194,175 +285,324 @@ def run_image_ocr_task(job_id: str, input_path_str: str, output_path_str: str):
f.write(text)
mark_job_as_completed(db, job_id, preview=text)
logger.info(f"Image OCR for job {job_id} completed.")
except Exception as e:
logger.error(f"ERROR during Image OCR for job {job_id}: {e}\n{traceback.format_exc()}")
update_job_status(db, job_id, "failed", error=str(e))
except Exception:
logger.exception(f"ERROR during Image OCR for job {job_id}")
update_job_status(db, job_id, "failed", error="See server logs for details.")
finally:
Path(input_path_str).unlink(missing_ok=True)
db.close()
# MODIFICATION: The task now accepts `model_size` and loads the model dynamically.
@huey.task()
def run_transcription_task(job_id: str, input_path_str: str, output_path_str: str, model_size: str):
def run_conversion_task(job_id: str, input_path_str: str, output_path_str: str, tool: str, task_key: str, conversion_tools_config: dict):
db = SessionLocal()
temp_input_file = None
temp_output_file = None
try:
job = get_job(db, job_id)
if not job or job.status == 'cancelled':
logger.info(f"Job {job_id} was cancelled before starting.")
return
if not job or job.status == 'cancelled': return
update_job_status(db, job_id, "processing", progress=25)
logger.info(f"Starting conversion for job {job_id} using {tool} with task {task_key}")
tool_config = conversion_tools_config.get(tool)
if not tool_config: raise ValueError(f"Unknown conversion tool: {tool}")
input_path = Path(input_path_str)
output_path = Path(output_path_str)
current_input_path = input_path
if tool == "mozjpeg":
temp_input_file = input_path.with_suffix('.temp.ppm')
logger.info(f"Pre-converting for MozJPEG: {input_path} -> {temp_input_file}")
pre_conv_cmd = ["vips", "copy", str(input_path), str(temp_input_file)]
pre_conv_result = subprocess.run(pre_conv_cmd, capture_output=True, text=True, check=False, timeout=tool_config.get("timeout", 300))
if pre_conv_result.returncode != 0:
err = (pre_conv_result.stderr or "")[:4000]
raise Exception(f"MozJPEG pre-conversion to PPM failed: {err}")
current_input_path = temp_input_file
update_job_status(db, job_id, "processing", progress=50)
# Build safe mapping for formatting and validate placeholders
ALLOWED_VARS = {"input", "output", "output_dir", "output_ext", "quality", "speed", "preset", "device", "dpi", "samplerate", "bitdepth"}
def validate_and_build_command(template_str: str, mapping: dict):
fmt = Formatter()
used = {fname for _, fname, _, _ in fmt.parse(template_str) if fname}
bad = used - ALLOWED_VARS
if bad:
raise ValueError(f"Command template contains disallowed placeholders: {bad}")
formatted = template_str.format(**mapping)
return shlex.split(formatted)
update_job_status(db, job_id, "processing")
# Load the specified model for this task
logger.info(f"Loading faster-whisper model: {model_size} for job {job_id}...")
model = WhisperModel(
model_size,
device="cpu",
compute_type=settings.WHISPER_COMPUTE_TYPE
)
logger.info(f"Whisper model '{model_size}' loaded successfully.")
logger.info(f"Starting transcription for job {job_id}")
segments, info = model.transcribe(input_path_str, beam_size=5)
full_transcript = []
total_duration = info.duration
for segment in segments:
job_check = get_job(db, job_id)
if job_check.status == 'cancelled':
logger.info(f"Job {job_id} cancelled during transcription.")
return
# Use a temporary output path and atomically move into place after success
temp_output_file = output_path.with_suffix(output_path.suffix + f".{uuid.uuid4().hex}.tmp")
# Update progress based on the segment's end time
if total_duration > 0:
progress = int((segment.end / total_duration) * 100)
update_job_status(db, job_id, "processing", progress=progress)
full_transcript.append(segment.text.strip())
# Prepare mapping
mapping = {
"input": str(current_input_path),
"output": str(temp_output_file),
"output_dir": str(output_path.parent),
"output_ext": output_path.suffix.lstrip('.'),
}
transcript_text = "\n".join(full_transcript)
with open(output_path_str, "w", encoding="utf-8") as f:
f.write(transcript_text)
mark_job_as_completed(db, job_id, preview=transcript_text)
logger.info(f"Transcription for job {job_id} completed.")
# Allow tool-specific adjustments to mapping
if tool.startswith("ghostscript"):
device, setting = task_key.split('_')
mapping.update({"device": device, "dpi": setting, "preset": setting})
elif tool == "pngquant":
_, quality_key = task_key.split('_')
quality_map = {"hq": "80-95", "mq": "65-80", "fast": "65-80"}
speed_map = {"hq": "1", "mq": "3", "fast": "11"}
mapping.update({"quality": quality_map.get(quality_key, "65-80"), "speed": speed_map.get(quality_key, "3")})
elif tool == "sox":
_, rate, depth = task_key.split('_')
rate = rate.replace('k', '000') if 'k' in rate else rate
depth = depth.replace('b', '') if 'b' in depth else '16'
mapping.update({"samplerate": rate, "bitdepth": depth})
elif tool == "mozjpeg":
_, quality = task_key.split('_')
quality = quality.replace('q', '')
mapping.update({"quality": quality})
command_template_str = tool_config["command_template"]
command = validate_and_build_command(command_template_str, mapping)
logger.info(f"Executing command: {' '.join(command)}")
# run with timeout and capture output; run_command helper ensures trimmed logs on failure
def run_command(argv: List[str], timeout: int = 300):
try:
res = subprocess.run(argv, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, timeout=timeout)
except subprocess.TimeoutExpired:
raise Exception(f"Command timed out after {timeout}s")
if res.returncode != 0:
stderr = (res.stderr or "")[:4000]
stdout = (res.stdout or "")[:4000]
raise Exception(f"Command failed exit {res.returncode}. stderr: {stderr}; stdout: {stdout}")
return res
result = run_command(command, timeout=tool_config.get("timeout", 300))
if tool == "libreoffice":
expected_output_filename = input_path.with_suffix(output_path.suffix).name
generated_file = output_path.parent / expected_output_filename
if generated_file.exists():
# move generated file into place
generated_file.replace(output_path)
else:
raise Exception(f"LibreOffice did not create the expected file: {expected_output_filename}")
# move temp output into final location atomically
if temp_output_file and temp_output_file.exists():
temp_output_file.replace(output_path)
mark_job_as_completed(db, job_id, preview=f"Successfully converted file.")
logger.info(f"Conversion for job {job_id} completed.")
except Exception as e:
logger.error(f"ERROR during transcription for job {job_id}: {e}\n{traceback.format_exc()}")
update_job_status(db, job_id, "failed", error=str(e))
logger.exception(f"ERROR during conversion for job {job_id}")
update_job_status(db, job_id, "failed", error="See server logs for details.")
finally:
Path(input_path_str).unlink(missing_ok=True)
if temp_input_file:
temp_input_file.unlink(missing_ok=True)
if temp_output_file:
temp_output_file.unlink(missing_ok=True)
db.close()
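To make the template handling above concrete, here is a worked example of validate_and_build_command; the Ghostscript template string and the paths are illustrative, not copied from settings.yml.
# Illustrative walk-through (assumed template and paths):
#
#   template = "gs -sDEVICE={device} -r{dpi} -o {output} {input}"
#   mapping  = {"input": "/uploads/scan_ab12.pdf",
#               "output": "/processed/scan_ab12.pdf.1f3c.tmp",
#               "output_dir": "/processed", "output_ext": "pdf",
#               "device": "pdfwrite", "dpi": "150", "preset": "150"}
#
#   validate_and_build_command(template, mapping)
#   -> ["gs", "-sDEVICE=pdfwrite", "-r150", "-o",
#       "/processed/scan_ab12.pdf.1f3c.tmp", "/uploads/scan_ab12.pdf"]
#
# A template that references any name outside ALLOWED_VARS (for example
# "{__class__}") raises ValueError before formatting, and the resulting argv
# list is handed to subprocess.run without a shell.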
# --------------------------------------------------------------------------------
# --- 6. FASTAPI APPLICATION
# --- 5. FASTAPI APPLICATION
# --------------------------------------------------------------------------------
@asynccontextmanager
async def lifespan(app: FastAPI):
logger.info("Application starting up...")
Base.metadata.create_all(bind=engine)
load_app_config()
yield
logger.info("Application shutting down...")
app = FastAPI(lifespan=lifespan)
app.mount("/static", StaticFiles(directory=settings.BASE_DIR / "static"), name="static")
templates = Jinja2Templates(directory=settings.BASE_DIR / "templates")
app.mount("/static", StaticFiles(directory=PATHS.BASE_DIR / "static"), name="static")
templates = Jinja2Templates(directory=PATHS.BASE_DIR / "templates")
# --- Helper Functions ---
async def save_upload_file_chunked(upload_file: UploadFile, destination: Path):
"""
Streams the uploaded file in chunks directly to a file on disk.
This is memory-efficient and reliable for large files.
"""
max_size = APP_CONFIG.get("app_settings", {}).get("max_file_size_bytes", 100 * 1024 * 1024)
tmp = destination.with_suffix(destination.suffix + f".{uuid.uuid4().hex}.tmp")
size = 0
with open(destination, "wb") as buffer:
while chunk := await upload_file.read(1024 * 1024): # 1MB chunks
if size + len(chunk) > settings.MAX_FILE_SIZE_BYTES:
raise HTTPException(
status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
detail=f"File exceeds limit of {settings.MAX_FILE_SIZE_BYTES // 1024 // 1024} MB"
)
buffer.write(chunk)
size += len(chunk)
try:
with tmp.open("wb") as buffer:
while True:
chunk = await upload_file.read(1024 * 1024)
if not chunk:
break
size += len(chunk)
if size > max_size:
raise HTTPException(status_code=413, detail=f"File exceeds {max_size / 1024 / 1024} MB limit")
buffer.write(chunk)
# atomic move into place
tmp.replace(destination)
except Exception:
tmp.unlink(missing_ok=True)
raise
def is_allowed_file(filename: str, allowed_extensions: set) -> bool:
return Path(filename).suffix.lower() in allowed_extensions
# --- API Endpoints ---
@app.get("/")
async def get_index(request: Request):
# MODIFICATION: Pass available models to the template
return templates.TemplateResponse("index.html", {
"request": request,
"whisper_models": sorted(list(settings.ALLOWED_WHISPER_MODELS))
})
# --- Routes (only transcription route is modified) ---
@app.post("/ocr-pdf", status_code=status.HTTP_202_ACCEPTED)
async def submit_pdf_ocr(file: UploadFile = File(...), db: Session = Depends(get_db)):
if not is_allowed_file(file.filename, settings.ALLOWED_PDF_EXTENSIONS):
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid file type. Please upload a PDF.")
job_id = uuid.uuid4().hex
safe_basename = secure_filename(file.filename)
unique_filename = f"{Path(safe_basename).stem}_{job_id}{Path(safe_basename).suffix}"
upload_path = settings.UPLOADS_DIR / unique_filename
processed_path = settings.PROCESSED_DIR / unique_filename
await save_upload_file_chunked(file, upload_path)
job_data = JobCreate(id=job_id, task_type="ocr", original_filename=file.filename, input_filepath=str(upload_path), processed_filepath=str(processed_path))
new_job = create_job(db=db, job=job_data)
run_pdf_ocr_task(new_job.id, str(upload_path), str(processed_path))
return {"job_id": new_job.id, "status": new_job.status}
@app.post("/ocr-image", status_code=status.HTTP_202_ACCEPTED)
async def submit_image_ocr(file: UploadFile = File(...), db: Session = Depends(get_db)):
if not is_allowed_file(file.filename, settings.ALLOWED_IMAGE_EXTENSIONS):
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid file type. Please upload a PNG, JPG, or TIFF.")
job_id = uuid.uuid4().hex
safe_basename = secure_filename(file.filename)
file_ext = Path(safe_basename).suffix
unique_filename = f"{Path(safe_basename).stem}_{job_id}{file_ext}"
upload_path = settings.UPLOADS_DIR / unique_filename
processed_path = settings.PROCESSED_DIR / f"{Path(safe_basename).stem}_{job_id}.txt"
await save_upload_file_chunked(file, upload_path)
job_data = JobCreate(id=job_id, task_type="ocr-image", original_filename=file.filename, input_filepath=str(upload_path), processed_filepath=str(processed_path))
new_job = create_job(db=db, job=job_data)
run_image_ocr_task(new_job.id, str(upload_path), str(processed_path))
return {"job_id": new_job.id, "status": new_job.status}
# MODIFICATION: Endpoint now accepts `model_size` as form data.
@app.post("/transcribe-audio", status_code=status.HTTP_202_ACCEPTED)
async def submit_audio_transcription(
file: UploadFile = File(...),
model_size: str = Form("base"),
db: Session = Depends(get_db)
):
if not is_allowed_file(file.filename, settings.ALLOWED_AUDIO_EXTENSIONS):
if not is_allowed_file(file.filename, {".mp3", ".wav", ".m4a", ".flac", ".ogg", ".opus"}):
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid audio file type.")
# Validate the selected model size
if model_size not in settings.ALLOWED_WHISPER_MODELS:
whisper_config = APP_CONFIG.get("transcription_settings", {}).get("whisper", {})
if model_size not in whisper_config.get("allowed_models", []):
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid model size: {model_size}.")
job_id = uuid.uuid4().hex
safe_basename = secure_filename(file.filename)
stem, suffix = Path(safe_basename).stem, Path(safe_basename).suffix
audio_filename = f"{stem}_{job_id}{suffix}"
transcript_filename = f"{stem}_{job_id}.txt"
upload_path = settings.UPLOADS_DIR / audio_filename
processed_path = settings.PROCESSED_DIR / transcript_filename
upload_path = PATHS.UPLOADS_DIR / audio_filename
processed_path = PATHS.PROCESSED_DIR / transcript_filename
await save_upload_file_chunked(file, upload_path)
job_data = JobCreate(id=job_id, task_type="transcription", original_filename=file.filename, input_filepath=str(upload_path), processed_filepath=str(processed_path))
new_job = create_job(db=db, job=job_data)
# Pass the selected model size to the background task
run_transcription_task(new_job.id, str(upload_path), str(processed_path), model_size=model_size)
# --- MODIFIED: Pass whisper_config to the task ---
run_transcription_task(new_job.id, str(upload_path), str(processed_path), model_size=model_size, whisper_settings=whisper_config)
return {"job_id": new_job.id, "status": new_job.status}
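A minimal client-side sketch (not part of the service code) showing how the new model_size form field is supplied; the host, port, filename, and model choice are assumptions.
# Illustrative client call; localhost:8000 and "small" are placeholders.
import requests

with open("interview.mp3", "rb") as fh:
    resp = requests.post(
        "http://localhost:8000/transcribe-audio",
        files={"file": ("interview.mp3", fh, "audio/mpeg")},
        data={"model_size": "small"},  # must appear in whisper allowed_models
    )
resp.raise_for_status()
print(resp.json())  # {"job_id": "...", "status": "..."}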
@app.post("/job/{job_id}/cancel", status_code=status.HTTP_200_OK)
# --- Other routes remain unchanged ---
@app.get("/")
async def get_index(request: Request):
whisper_models = APP_CONFIG.get("transcription_settings", {}).get("whisper", {}).get("allowed_models", [])
conversion_tools = APP_CONFIG.get("conversion_tools", {})
return templates.TemplateResponse("index.html", {
"request": request,
"whisper_models": sorted(list(whisper_models)),
"conversion_tools": conversion_tools
})
@app.get("/settings")
async def get_settings_page(request: Request):
try:
with open(PATHS.SETTINGS_FILE, 'r') as f:
current_config = yaml.safe_load(f)
except Exception as e:
logger.error(f"Could not load settings.yml for settings page: {e}")
current_config = {}
return templates.TemplateResponse("settings.html", {"request": request, "config": current_config})
@app.post("/settings/save")
async def save_settings(new_config: Dict = Body(...)):
try:
with open(PATHS.SETTINGS_FILE, 'w') as f:
yaml.dump(new_config, f, default_flow_style=False, sort_keys=False)
load_app_config()
return JSONResponse({"message": "Settings saved successfully."})
except Exception as e:
logger.error(f"Failed to save settings: {e}")
raise HTTPException(status_code=500, detail="Could not write to settings.yml.")
@app.post("/settings/clear-history")
async def clear_job_history(db: Session = Depends(get_db)):
try:
num_deleted = db.query(Job).delete()
db.commit()
logger.info(f"Cleared {num_deleted} jobs from history.")
return {"deleted_count": num_deleted}
except Exception as e:
db.rollback()
logger.error(f"Failed to clear job history: {e}")
raise HTTPException(status_code=500, detail="Database error while clearing history.")
@app.post("/settings/delete-files")
async def delete_processed_files():
deleted_count = 0
errors = []
for f in PATHS.PROCESSED_DIR.glob('*'):
try:
if f.is_file():
f.unlink()
deleted_count += 1
except Exception as e:
errors.append(f.name)
logger.error(f"Could not delete processed file {f.name}: {e}")
if errors:
raise HTTPException(status_code=500, detail=f"Could not delete some files: {', '.join(errors)}")
logger.info(f"Deleted {deleted_count} files from processed directory.")
return {"deleted_count": deleted_count}
@app.post("/convert-file", status_code=status.HTTP_202_ACCEPTED)
async def submit_file_conversion(file: UploadFile = File(...), output_format: str = Form(...), db: Session = Depends(get_db)):
allowed_exts = APP_CONFIG.get("app_settings", {}).get("allowed_all_extensions", set())
if not is_allowed_file(file.filename, allowed_exts):
raise HTTPException(status_code=400, detail=f"File type '{Path(file.filename).suffix}' not allowed.")
conversion_tools = APP_CONFIG.get("conversion_tools", {})
try:
tool, task_key = output_format.split('_', 1)
if tool not in conversion_tools or task_key not in conversion_tools[tool]["formats"]:
raise ValueError()
except ValueError:
raise HTTPException(status_code=400, detail="Invalid output format selected.")
job_id = uuid.uuid4().hex
safe_basename = secure_filename(file.filename)
original_stem = Path(safe_basename).stem
target_ext = task_key.split('_')[0]
if tool == "ghostscript_pdf":
target_ext = "pdf"
upload_filename = f"{original_stem}_{job_id}{Path(safe_basename).suffix}"
processed_filename = f"{original_stem}_{job_id}.{target_ext}"
upload_path = PATHS.UPLOADS_DIR / upload_filename
processed_path = PATHS.PROCESSED_DIR / processed_filename
await save_upload_file_chunked(file, upload_path)
job_data = JobCreate(id=job_id, task_type="conversion", original_filename=file.filename,
input_filepath=str(upload_path), processed_filepath=str(processed_path))
new_job = create_job(db=db, job=job_data)
run_conversion_task(new_job.id, str(upload_path), str(processed_path), tool, task_key, conversion_tools)
return {"job_id": new_job.id, "status": new_job.status}
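Likewise, a hedged sketch of selecting a converter: output_format is the tool name and task key joined by an underscore and must match a formats entry under conversion_tools; the pngquant_png_hq value below is an assumed example.
# Illustrative client call; the valid output_format values depend on settings.yml.
import requests

with open("photo.png", "rb") as fh:
    resp = requests.post(
        "http://localhost:8000/convert-file",
        files={"file": fh},
        data={"output_format": "pngquant_png_hq"},  # "<tool>_<task_key>"
    )
resp.raise_for_status()
print(resp.json())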
@app.post("/ocr-pdf", status_code=status.HTTP_202_ACCEPTED)
async def submit_pdf_ocr(file: UploadFile = File(...), db: Session = Depends(get_db)):
if not is_allowed_file(file.filename, {".pdf"}):
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid file type. Please upload a PDF.")
job_id = uuid.uuid4().hex
safe_basename = secure_filename(file.filename)
unique_filename = f"{Path(safe_basename).stem}_{job_id}{Path(safe_basename).suffix}"
upload_path = PATHS.UPLOADS_DIR / unique_filename
processed_path = PATHS.PROCESSED_DIR / unique_filename
await save_upload_file_chunked(file, upload_path)
job_data = JobCreate(id=job_id, task_type="ocr", original_filename=file.filename, input_filepath=str(upload_path), processed_filepath=str(processed_path))
new_job = create_job(db=db, job=job_data)
ocr_settings = APP_CONFIG.get("ocr_settings", {}).get("ocrmypdf", {})
run_pdf_ocr_task(new_job.id, str(upload_path), str(processed_path), ocr_settings)
return {"job_id": new_job.id, "status": new_job.status}
@app.post("/ocr-image", status_code=status.HTTP_202_ACCEPTED)
async def submit_image_ocr(file: UploadFile = File(...), db: Session = Depends(get_db)):
allowed_exts = {".png", ".jpg", ".jpeg", ".tiff", ".tif"}
if not is_allowed_file(file.filename, allowed_exts):
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid file type. Please upload a PNG, JPG, or TIFF.")
job_id = uuid.uuid4().hex
safe_basename = secure_filename(file.filename)
file_ext = Path(safe_basename).suffix
unique_filename = f"{Path(safe_basename).stem}_{job_id}{file_ext}"
upload_path = PATHS.UPLOADS_DIR / unique_filename
processed_path = PATHS.PROCESSED_DIR / f"{Path(safe_basename).stem}_{job_id}.txt"
await save_upload_file_chunked(file, upload_path)
job_data = JobCreate(id=job_id, task_type="ocr-image", original_filename=file.filename, input_filepath=str(upload_path), processed_filepath=str(processed_path))
new_job = create_job(db=db, job=job_data)
run_image_ocr_task(new_job.id, str(upload_path), str(processed_path))
return {"job_id": new_job.id, "status": new_job.status}
@app.post("/job/{job_id}/cancel", status_code=status.HTTP_202_ACCEPTED)
async def cancel_job(job_id: str, db: Session = Depends(get_db)):
job = get_job(db, job_id)
if not job:
@@ -386,12 +626,13 @@ async def get_job_status(job_id: str, db: Session = Depends(get_db)):
@app.get("/download/{filename}")
async def download_file(filename: str):
safe_filename = secure_filename(filename)
file_path = settings.PROCESSED_DIR / safe_filename
if not file_path.resolve().is_relative_to(settings.PROCESSED_DIR.resolve()):
file_path = PATHS.PROCESSED_DIR / safe_filename
file_path = file_path.resolve()
base = PATHS.PROCESSED_DIR.resolve()
try:
file_path.relative_to(base)
except ValueError:
raise HTTPException(status_code=403, detail="Access denied.")
if not file_path.is_file():
raise HTTPException(status_code=404, detail="File not found.")
return FileResponse(path=file_path, filename=safe_filename, media_type="application/octet-stream")