Drag and Drop
This commit is contained in:
401
main.py
401
main.py
@@ -6,7 +6,7 @@ import uuid
|
||||
import shlex
|
||||
import yaml
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Any
|
||||
|
||||
@@ -21,17 +21,21 @@ from fastapi.responses import FileResponse, JSONResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.templating import Jinja2Templates
|
||||
from huey import SqliteHuey
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from pydantic import BaseModel, ConfigDict, field_serializer # MODIFIED: Import field_serializer
|
||||
from sqlalchemy import (Column, DateTime, Integer, String, Text,
|
||||
create_engine, delete, event)
|
||||
from sqlalchemy.orm import Session, declarative_base, sessionmaker
|
||||
from sqlalchemy.pool import NullPool
|
||||
from string import Formatter
|
||||
from sqlalchemy.orm import Session, declarative_base, sessionmaker
|
||||
from werkzeug.utils import secure_filename
|
||||
from typing import List as TypingList
|
||||
|
||||
# --------------------------------------------------------------------------------
|
||||
# --- 1. CONFIGURATION
|
||||
# --------------------------------------------------------------------------------
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AppPaths(BaseModel):
|
||||
BASE_DIR: Path = Path(__file__).resolve().parent
|
||||
@@ -43,30 +47,46 @@ class AppPaths(BaseModel):
|
||||
|
||||
PATHS = AppPaths()
|
||||
APP_CONFIG: Dict[str, Any] = {}
|
||||
PATHS.UPLOADS_DIR.mkdir(exist_ok=True)
|
||||
PATHS.PROCESSED_DIR.mkdir(exist_ok=True)
|
||||
|
||||
def load_app_config():
|
||||
global APP_CONFIG
|
||||
try:
|
||||
with open(PATHS.SETTINGS_FILE, 'r') as f:
|
||||
APP_CONFIG = yaml.safe_load(f)
|
||||
APP_CONFIG['app_settings']['max_file_size_bytes'] = APP_CONFIG['app_settings']['max_file_size_mb'] * 1024 * 1024
|
||||
allowed_extensions = {
|
||||
".pdf", ".ps", ".eps", ".png", ".jpg", ".jpeg", ".tiff", ".tif", ".gif",
|
||||
".bmp", ".webp", ".svg", ".jxl", ".avif", ".ppm", ".mp3", ".m4a", ".ogg",
|
||||
".flac", ".opus", ".wav", ".aac", ".mp4", ".mkv", ".mov", ".webm", ".avi",
|
||||
".flv", ".md", ".txt", ".html", ".docx", ".odt", ".rst", ".epub", ".mobi",
|
||||
".azw3", ".pptx", ".xlsx"
|
||||
with open(PATHS.SETTINGS_FILE, 'r', encoding='utf8') as f:
|
||||
cfg_raw = yaml.safe_load(f) or {}
|
||||
# basic defaults
|
||||
defaults = {
|
||||
"app_settings": {"max_file_size_mb": 100, "allowed_all_extensions": []},
|
||||
"transcription_settings": {"whisper": {"allowed_models": ["tiny", "base", "small"], "compute_type": "int8"}},
|
||||
"conversion_tools": {},
|
||||
"ocr_settings": {"ocrmypdf": {}}
|
||||
}
|
||||
APP_CONFIG['app_settings']['allowed_all_extensions'] = allowed_extensions
|
||||
# shallow merge (safe for top-level keys)
|
||||
cfg = defaults.copy()
|
||||
cfg.update(cfg_raw)
|
||||
# normalize app settings
|
||||
app_settings = cfg.get("app_settings", {})
|
||||
max_mb = app_settings.get("max_file_size_mb", 100)
|
||||
app_settings["max_file_size_bytes"] = int(max_mb) * 1024 * 1024
|
||||
allowed = app_settings.get("allowed_all_extensions", [])
|
||||
if not isinstance(allowed, (list, set)):
|
||||
allowed = list(allowed)
|
||||
app_settings["allowed_all_extensions"] = set(allowed)
|
||||
cfg["app_settings"] = app_settings
|
||||
APP_CONFIG = cfg
|
||||
logger.info("Successfully loaded settings from settings.yml")
|
||||
except (FileNotFoundError, yaml.YAMLError) as e:
|
||||
logger.error(f"Could not load settings.yml: {e}. App may not function correctly.")
|
||||
APP_CONFIG = {}
|
||||
logging.getLogger(__name__).exception(f"Could not load settings.yml: {e}. Using defaults.")
|
||||
|
||||
APP_CONFIG = {
|
||||
"app_settings": {"max_file_size_mb": 100, "max_file_size_bytes": 100 * 1024 * 1024, "allowed_all_extensions": set()},
|
||||
"transcription_settings": {"whisper": {"allowed_models": ["tiny", "base", "small"], "compute_type": "int8"}},
|
||||
"conversion_tools": {},
|
||||
"ocr_settings": {"ocrmypdf": {}}
|
||||
}
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
PATHS.UPLOADS_DIR.mkdir(exist_ok=True)
|
||||
PATHS.PROCESSED_DIR.mkdir(exist_ok=True)
|
||||
|
||||
# --------------------------------------------------------------------------------
|
||||
# --- 2. DATABASE & Schemas
|
||||
@@ -77,8 +97,6 @@ engine = create_engine(
|
||||
poolclass=NullPool,
|
||||
)
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
# THIS IS THE CRITICAL FIX
|
||||
Base = declarative_base()
|
||||
|
||||
@event.listens_for(engine, "connect")
|
||||
@@ -102,11 +120,13 @@ class Job(Base):
|
||||
progress = Column(Integer, default=0)
|
||||
original_filename = Column(String)
|
||||
input_filepath = Column(String)
|
||||
input_filesize = Column(Integer, nullable=True)
|
||||
processed_filepath = Column(String, nullable=True)
|
||||
output_filesize = Column(Integer, nullable=True)
|
||||
result_preview = Column(Text, nullable=True)
|
||||
error_message = Column(Text, nullable=True)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
|
||||
updated_at = Column(DateTime, default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
|
||||
|
||||
def get_db():
|
||||
db = SessionLocal()
|
||||
@@ -120,6 +140,7 @@ class JobCreate(BaseModel):
|
||||
task_type: str
|
||||
original_filename: str
|
||||
input_filepath: str
|
||||
input_filesize: int | None = None
|
||||
processed_filepath: str | None = None
|
||||
|
||||
class JobSchema(BaseModel):
|
||||
@@ -128,6 +149,8 @@ class JobSchema(BaseModel):
|
||||
status: str
|
||||
progress: int
|
||||
original_filename: str
|
||||
input_filesize: int | None = None
|
||||
output_filesize: int | None = None
|
||||
processed_filepath: str | None = None
|
||||
result_preview: str | None = None
|
||||
error_message: str | None = None
|
||||
@@ -135,8 +158,14 @@ class JobSchema(BaseModel):
|
||||
updated_at: datetime
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
# NEW: This serializer ensures the datetime string sent to the frontend ALWAYS
|
||||
# includes the 'Z' UTC indicator, fixing the timezone bug.
|
||||
@field_serializer('created_at', 'updated_at')
|
||||
def serialize_dt(self, dt: datetime, _info):
|
||||
return dt.isoformat() + "Z"
|
||||
|
||||
# --------------------------------------------------------------------------------
|
||||
# --- 3. CRUD OPERATIONS (No Changes)
|
||||
# --- 3. CRUD OPERATIONS
|
||||
# --------------------------------------------------------------------------------
|
||||
def get_job(db: Session, job_id: str):
|
||||
return db.query(Job).filter(Job.id == job_id).first()
|
||||
@@ -163,80 +192,120 @@ def update_job_status(db: Session, job_id: str, status: str, progress: int = Non
|
||||
db.refresh(db_job)
|
||||
return db_job
|
||||
|
||||
def mark_job_as_completed(db: Session, job_id: str, preview: str | None = None):
|
||||
def mark_job_as_completed(db: Session, job_id: str, output_filepath_str: str | None = None, preview: str | None = None):
|
||||
db_job = get_job(db, job_id)
|
||||
if db_job and db_job.status != 'cancelled':
|
||||
db_job.status = "completed"
|
||||
db_job.progress = 100
|
||||
if preview:
|
||||
db_job.result_preview = preview.strip()[:2000]
|
||||
if output_filepath_str:
|
||||
try:
|
||||
output_path = Path(output_filepath_str)
|
||||
if output_path.exists():
|
||||
db_job.output_filesize = output_path.stat().st_size
|
||||
except Exception:
|
||||
logger.exception(f"Could not stat output file {output_filepath_str} for job {job_id}")
|
||||
db.commit()
|
||||
return db_job
|
||||
|
||||
# ... (The rest of the file is unchanged and remains the same) ...
|
||||
|
||||
# --------------------------------------------------------------------------------
|
||||
# --- 4. BACKGROUND TASK SETUP
|
||||
# --------------------------------------------------------------------------------
|
||||
huey = SqliteHuey(filename=PATHS.HUEY_DB_PATH)
|
||||
|
||||
# --- START: NEW WHISPER MODEL CACHING ---
|
||||
# This dictionary will live in the memory of the Huey worker process,
|
||||
# allowing us to reuse loaded models across tasks.
|
||||
# Whisper model cache per worker process
|
||||
WHISPER_MODELS_CACHE: Dict[str, WhisperModel] = {}
|
||||
|
||||
def get_whisper_model(model_size: str, whisper_settings: dict) -> WhisperModel:
|
||||
"""
|
||||
Loads a Whisper model into the cache if not present, and returns the model.
|
||||
This ensures a model is only loaded into memory once per worker process.
|
||||
"""
|
||||
if model_size not in WHISPER_MODELS_CACHE:
|
||||
compute_type = whisper_settings.get('compute_type', 'int8')
|
||||
logger.info(f"Whisper model '{model_size}' not in cache. Loading into memory...")
|
||||
model = WhisperModel(model_size, device="cpu", compute_type=compute_type)
|
||||
WHISPER_MODELS_CACHE[model_size] = model
|
||||
logger.info(f"Model '{model_size}' loaded successfully.")
|
||||
else:
|
||||
if model_size in WHISPER_MODELS_CACHE:
|
||||
logger.info(f"Found model '{model_size}' in cache. Reusing.")
|
||||
return WHISPER_MODELS_CACHE[model_size]
|
||||
# --- END: NEW WHISPER MODEL CACHING ---
|
||||
return WHISPER_MODELS_CACHE[model_size]
|
||||
device = whisper_settings.get("device", "cpu")
|
||||
compute_type = whisper_settings.get('compute_type', 'int8')
|
||||
logger.info(f"Whisper model '{model_size}' not in cache. Loading into memory on device={device}...")
|
||||
try:
|
||||
model = WhisperModel(model_size, device=device, compute_type=compute_type)
|
||||
except Exception:
|
||||
logger.exception("Failed to load whisper model")
|
||||
raise
|
||||
WHISPER_MODELS_CACHE[model_size] = model
|
||||
logger.info(f"Model '{model_size}' loaded successfully.")
|
||||
return model
|
||||
|
||||
# Helper: safe run_command (trimmed logs + timeout)
|
||||
def run_command(argv: TypingList[str], timeout: int = 300):
|
||||
try:
|
||||
res = subprocess.run(argv, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, timeout=timeout)
|
||||
except subprocess.TimeoutExpired:
|
||||
raise Exception(f"Command timed out after {timeout}s")
|
||||
if res.returncode != 0:
|
||||
stderr = (res.stderr or "")[:4000]
|
||||
stdout = (res.stdout or "")[:4000]
|
||||
raise Exception(f"Command failed exit {res.returncode}. stderr: {stderr}; stdout: {stdout}")
|
||||
return res
|
||||
|
||||
# Helper: validate and build command from template with allowlist
|
||||
ALLOWED_VARS = {"input", "output", "output_dir", "output_ext", "quality", "speed", "preset", "device", "dpi", "samplerate", "bitdepth", "filter"}
|
||||
|
||||
def validate_and_build_command(template_str: str, mapping: Dict[str, str]) -> TypingList[str]:
|
||||
"""
|
||||
Validate placeholders against ALLOWED_VARS and build a safe argv list.
|
||||
If a template uses allowed placeholders that are missing from `mapping`,
|
||||
auto-fill sensible defaults:
|
||||
- 'filter' -> mapping.get('output_ext', '')
|
||||
- others -> empty string
|
||||
This prevents KeyError while preserving the allowlist security check.
|
||||
"""
|
||||
fmt = Formatter()
|
||||
used = {fname for _, fname, _, _ in fmt.parse(template_str) if fname}
|
||||
bad = used - ALLOWED_VARS
|
||||
if bad:
|
||||
raise ValueError(f"Command template contains disallowed placeholders: {bad}")
|
||||
|
||||
# auto-fill missing allowed placeholders with safe defaults
|
||||
safe_mapping = dict(mapping) # shallow copy to avoid mutating caller mapping
|
||||
for name in used:
|
||||
if name not in safe_mapping:
|
||||
if name == "filter":
|
||||
safe_mapping[name] = safe_mapping.get("output_ext", "")
|
||||
else:
|
||||
safe_mapping[name] = ""
|
||||
|
||||
formatted = template_str.format(**safe_mapping)
|
||||
return shlex.split(formatted)
|
||||
|
||||
@huey.task()
|
||||
def run_transcription_task(job_id: str, input_path_str: str, output_path_str: str, model_size: str, whisper_settings: dict):
|
||||
db = SessionLocal()
|
||||
try:
|
||||
job = get_job(db, job_id)
|
||||
if not job or job.status == 'cancelled': return
|
||||
|
||||
if not job or job.status == 'cancelled':
|
||||
return
|
||||
update_job_status(db, job_id, "processing")
|
||||
|
||||
# --- MODIFIED: Use the caching function to get the model ---
|
||||
model = get_whisper_model(model_size, whisper_settings)
|
||||
|
||||
logger.info(f"Starting transcription for job {job_id}")
|
||||
segments, info = model.transcribe(input_path_str, beam_size=5)
|
||||
|
||||
full_transcript = []
|
||||
for segment in segments:
|
||||
job_check = get_job(db, job_id) # Check for cancellation during long tasks
|
||||
job_check = get_job(db, job_id) # Check for cancellation during long tasks
|
||||
if job_check.status == 'cancelled':
|
||||
logger.info(f"Job {job_id} cancelled during transcription.")
|
||||
return
|
||||
|
||||
if info.duration > 0:
|
||||
progress = int((segment.end / info.duration) * 100)
|
||||
update_job_status(db, job_id, "processing", progress=progress)
|
||||
|
||||
full_transcript.append(segment.text.strip())
|
||||
|
||||
transcript_text = "\n".join(full_transcript)
|
||||
# write atomically to avoid partial files
|
||||
# atomic write of transcript — keep the real extension and mark tmp in the name
|
||||
out_path = Path(output_path_str)
|
||||
tmp_out = out_path.with_suffix(out_path.suffix + f".{uuid.uuid4().hex}.tmp")
|
||||
tmp_out = out_path.with_name(f"{out_path.stem}.tmp-{uuid.uuid4().hex}{out_path.suffix}")
|
||||
with tmp_out.open("w", encoding="utf-8") as f:
|
||||
f.write(transcript_text)
|
||||
tmp_out.replace(out_path)
|
||||
|
||||
mark_job_as_completed(db, job_id, preview=transcript_text)
|
||||
mark_job_as_completed(db, job_id, output_filepath_str=output_path_str, preview=transcript_text)
|
||||
logger.info(f"Transcription for job {job_id} completed.")
|
||||
except Exception:
|
||||
logger.exception(f"ERROR during transcription for job {job_id}")
|
||||
@@ -245,13 +314,13 @@ def run_transcription_task(job_id: str, input_path_str: str, output_path_str: st
|
||||
Path(input_path_str).unlink(missing_ok=True)
|
||||
db.close()
|
||||
|
||||
# Other tasks remain unchanged
|
||||
@huey.task()
|
||||
def run_pdf_ocr_task(job_id: str, input_path_str: str, output_path_str: str, ocr_settings: dict):
|
||||
db = SessionLocal()
|
||||
try:
|
||||
job = get_job(db, job_id)
|
||||
if not job or job.status == 'cancelled': return
|
||||
if not job or job.status == 'cancelled':
|
||||
return
|
||||
update_job_status(db, job_id, "processing")
|
||||
logger.info(f"Starting PDF OCR for job {job_id}")
|
||||
ocrmypdf.ocr(input_path_str, output_path_str,
|
||||
@@ -263,7 +332,7 @@ def run_pdf_ocr_task(job_id: str, input_path_str: str, output_path_str: str, ocr
|
||||
with open(output_path_str, "rb") as f:
|
||||
reader = pypdf.PdfReader(f)
|
||||
preview = "\n".join(page.extract_text() or "" for page in reader.pages)
|
||||
mark_job_as_completed(db, job_id, preview=preview)
|
||||
mark_job_as_completed(db, job_id, output_filepath_str=output_path_str, preview=preview)
|
||||
logger.info(f"PDF OCR for job {job_id} completed.")
|
||||
except Exception:
|
||||
logger.exception(f"ERROR during PDF OCR for job {job_id}")
|
||||
@@ -277,13 +346,18 @@ def run_image_ocr_task(job_id: str, input_path_str: str, output_path_str: str):
|
||||
db = SessionLocal()
|
||||
try:
|
||||
job = get_job(db, job_id)
|
||||
if not job or job.status == 'cancelled': return
|
||||
if not job or job.status == 'cancelled':
|
||||
return
|
||||
update_job_status(db, job_id, "processing", progress=50)
|
||||
logger.info(f"Starting Image OCR for job {job_id}")
|
||||
text = pytesseract.image_to_string(Image.open(input_path_str))
|
||||
with open(output_path_str, "w", encoding="utf-8") as f:
|
||||
# atomic write of OCR text
|
||||
out_path = Path(output_path_str)
|
||||
tmp_out = out_path.with_name(f"{out_path.stem}.tmp-{uuid.uuid4().hex}{out_path.suffix}")
|
||||
with tmp_out.open("w", encoding="utf-8") as f:
|
||||
f.write(text)
|
||||
mark_job_as_completed(db, job_id, preview=text)
|
||||
tmp_out.replace(out_path)
|
||||
mark_job_as_completed(db, job_id, output_filepath_str=output_path_str, preview=text)
|
||||
logger.info(f"Image OCR for job {job_id} completed.")
|
||||
except Exception:
|
||||
logger.exception(f"ERROR during Image OCR for job {job_id}")
|
||||
@@ -300,14 +374,18 @@ def run_conversion_task(job_id: str, input_path_str: str, output_path_str: str,
|
||||
temp_output_file = None
|
||||
try:
|
||||
job = get_job(db, job_id)
|
||||
if not job or job.status == 'cancelled': return
|
||||
if not job or job.status == 'cancelled':
|
||||
return
|
||||
update_job_status(db, job_id, "processing", progress=25)
|
||||
logger.info(f"Starting conversion for job {job_id} using {tool} with task {task_key}")
|
||||
tool_config = conversion_tools_config.get(tool)
|
||||
if not tool_config: raise ValueError(f"Unknown conversion tool: {tool}")
|
||||
if not tool_config:
|
||||
raise ValueError(f"Unknown conversion tool: {tool}")
|
||||
input_path = Path(input_path_str)
|
||||
output_path = Path(output_path_str)
|
||||
current_input_path = input_path
|
||||
|
||||
# Pre-processing for specific tools
|
||||
if tool == "mozjpeg":
|
||||
temp_input_file = input_path.with_suffix('.temp.ppm')
|
||||
logger.info(f"Pre-converting for MozJPEG: {input_path} -> {temp_input_file}")
|
||||
@@ -317,22 +395,12 @@ def run_conversion_task(job_id: str, input_path_str: str, output_path_str: str,
|
||||
err = (pre_conv_result.stderr or "")[:4000]
|
||||
raise Exception(f"MozJPEG pre-conversion to PPM failed: {err}")
|
||||
current_input_path = temp_input_file
|
||||
|
||||
update_job_status(db, job_id, "processing", progress=50)
|
||||
# Build safe mapping for formatting and validate placeholders
|
||||
ALLOWED_VARS = {"input", "output", "output_dir", "output_ext", "quality", "speed", "preset", "device", "dpi", "samplerate", "bitdepth"}
|
||||
def validate_and_build_command(template_str: str, mapping: dict):
|
||||
fmt = Formatter()
|
||||
used = {fname for _, fname, _, _ in fmt.parse(template_str) if fname}
|
||||
bad = used - ALLOWED_VARS
|
||||
if bad:
|
||||
raise ValueError(f"Command template contains disallowed placeholders: {bad}")
|
||||
formatted = template_str.format(**mapping)
|
||||
return shlex.split(formatted)
|
||||
|
||||
# Use a temporary output path and atomically move into place after success
|
||||
temp_output_file = output_path.with_suffix(output_path.suffix + f".{uuid.uuid4().hex}.tmp")
|
||||
|
||||
# Prepare mapping
|
||||
# prepare temporary output and mapping
|
||||
# use a temp filename that preserves the real extension, e.g. file.tmp-<uuid>.pdf
|
||||
temp_output_file = output_path.with_name(f"{output_path.stem}.tmp-{uuid.uuid4().hex}{output_path.suffix}")
|
||||
mapping = {
|
||||
"input": str(current_input_path),
|
||||
"output": str(temp_output_file),
|
||||
@@ -340,7 +408,7 @@ def run_conversion_task(job_id: str, input_path_str: str, output_path_str: str,
|
||||
"output_ext": output_path.suffix.lstrip('.'),
|
||||
}
|
||||
|
||||
# Allow tool-specific adjustments to mapping
|
||||
# tool specific mapping adjustments
|
||||
if tool.startswith("ghostscript"):
|
||||
device, setting = task_key.split('_')
|
||||
mapping.update({"device": device, "dpi": setting, "preset": setting})
|
||||
@@ -358,38 +426,30 @@ def run_conversion_task(job_id: str, input_path_str: str, output_path_str: str,
|
||||
_, quality = task_key.split('_')
|
||||
quality = quality.replace('q', '')
|
||||
mapping.update({"quality": quality})
|
||||
elif tool == "libreoffice":
|
||||
target_ext = output_path.suffix.lstrip('.')
|
||||
# tool_config may include a 'filters' mapping (see settings.yml example)
|
||||
filter_val = tool_config.get("filters", {}).get(target_ext, target_ext)
|
||||
mapping["filter"] = filter_val
|
||||
|
||||
command_template_str = tool_config["command_template"]
|
||||
command = validate_and_build_command(command_template_str, mapping)
|
||||
logger.info(f"Executing command: {' '.join(command)}")
|
||||
# run with timeout and capture output; run_command helper ensures trimmed logs on failure
|
||||
def run_command(argv: List[str], timeout: int = 300):
|
||||
try:
|
||||
res = subprocess.run(argv, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, timeout=timeout)
|
||||
except subprocess.TimeoutExpired:
|
||||
raise Exception(f"Command timed out after {timeout}s")
|
||||
if res.returncode != 0:
|
||||
stderr = (res.stderr or "")[:4000]
|
||||
stdout = (res.stdout or "")[:4000]
|
||||
raise Exception(f"Command failed exit {res.returncode}. stderr: {stderr}; stdout: {stdout}")
|
||||
return res
|
||||
|
||||
# execute command with timeout and trimmed logs on error
|
||||
result = run_command(command, timeout=tool_config.get("timeout", 300))
|
||||
if tool == "libreoffice":
|
||||
expected_output_filename = input_path.with_suffix(output_path.suffix).name
|
||||
generated_file = output_path.parent / expected_output_filename
|
||||
if generated_file.exists():
|
||||
# move generated file into place
|
||||
generated_file.replace(output_path)
|
||||
else:
|
||||
raise Exception(f"LibreOffice did not create the expected file: {expected_output_filename}")
|
||||
|
||||
# handle LibreOffice special case: sometimes it writes differently
|
||||
# Special-case LibreOffice: support per-format export filters via settings.yml
|
||||
|
||||
|
||||
# move temp output into final location atomically
|
||||
if temp_output_file and temp_output_file.exists():
|
||||
temp_output_file.replace(output_path)
|
||||
|
||||
mark_job_as_completed(db, job_id, preview=f"Successfully converted file.")
|
||||
mark_job_as_completed(db, job_id, output_filepath_str=output_path_str, preview=f"Successfully converted file.")
|
||||
logger.info(f"Conversion for job {job_id} completed.")
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
logger.exception(f"ERROR during conversion for job {job_id}")
|
||||
update_job_status(db, job_id, "failed", error="See server logs for details.")
|
||||
finally:
|
||||
@@ -415,13 +475,14 @@ app = FastAPI(lifespan=lifespan)
|
||||
app.mount("/static", StaticFiles(directory=PATHS.BASE_DIR / "static"), name="static")
|
||||
templates = Jinja2Templates(directory=PATHS.BASE_DIR / "templates")
|
||||
|
||||
async def save_upload_file_chunked(upload_file: UploadFile, destination: Path):
|
||||
async def save_upload_file_chunked(upload_file: UploadFile, destination: Path) -> int:
|
||||
"""
|
||||
Streams the uploaded file in chunks directly to a file on disk.
|
||||
This is memory-efficient and reliable for large files.
|
||||
Write upload to a tmp file in chunks, then atomically move to final destination.
|
||||
Returns the final size of the file in bytes.
|
||||
"""
|
||||
max_size = APP_CONFIG.get("app_settings", {}).get("max_file_size_bytes", 100 * 1024 * 1024)
|
||||
tmp = destination.with_suffix(destination.suffix + f".{uuid.uuid4().hex}.tmp")
|
||||
# make a temp filename that keeps the real extension, e.g. file.tmp-<uuid>.pdf
|
||||
tmp = destination.with_name(f"{destination.stem}.tmp-{uuid.uuid4().hex}{destination.suffix}")
|
||||
size = 0
|
||||
try:
|
||||
with tmp.open("wb") as buffer:
|
||||
@@ -433,17 +494,16 @@ async def save_upload_file_chunked(upload_file: UploadFile, destination: Path):
|
||||
if size > max_size:
|
||||
raise HTTPException(status_code=413, detail=f"File exceeds {max_size / 1024 / 1024} MB limit")
|
||||
buffer.write(chunk)
|
||||
# atomic move into place
|
||||
tmp.replace(destination)
|
||||
return size
|
||||
except Exception:
|
||||
tmp.unlink(missing_ok=True)
|
||||
raise
|
||||
|
||||
|
||||
def is_allowed_file(filename: str, allowed_extensions: set) -> bool:
|
||||
return Path(filename).suffix.lower() in allowed_extensions
|
||||
|
||||
# --- Routes (only transcription route is modified) ---
|
||||
# --- Routes (transcription route uses Huey task enqueuing) ---
|
||||
|
||||
@app.post("/transcribe-audio", status_code=status.HTTP_202_ACCEPTED)
|
||||
async def submit_audio_transcription(
|
||||
@@ -453,7 +513,7 @@ async def submit_audio_transcription(
|
||||
):
|
||||
if not is_allowed_file(file.filename, {".mp3", ".wav", ".m4a", ".flac", ".ogg", ".opus"}):
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid audio file type.")
|
||||
|
||||
|
||||
whisper_config = APP_CONFIG.get("transcription_settings", {}).get("whisper", {})
|
||||
if model_size not in whisper_config.get("allowed_models", []):
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid model size: {model_size}.")
|
||||
@@ -461,24 +521,29 @@ async def submit_audio_transcription(
|
||||
job_id = uuid.uuid4().hex
|
||||
safe_basename = secure_filename(file.filename)
|
||||
stem, suffix = Path(safe_basename).stem, Path(safe_basename).suffix
|
||||
|
||||
|
||||
audio_filename = f"{stem}_{job_id}{suffix}"
|
||||
transcript_filename = f"{stem}_{job_id}.txt"
|
||||
upload_path = PATHS.UPLOADS_DIR / audio_filename
|
||||
processed_path = PATHS.PROCESSED_DIR / transcript_filename
|
||||
|
||||
await save_upload_file_chunked(file, upload_path)
|
||||
|
||||
job_data = JobCreate(id=job_id, task_type="transcription", original_filename=file.filename, input_filepath=str(upload_path), processed_filepath=str(processed_path))
|
||||
input_size = await save_upload_file_chunked(file, upload_path)
|
||||
|
||||
job_data = JobCreate(
|
||||
id=job_id,
|
||||
task_type="transcription",
|
||||
original_filename=file.filename,
|
||||
input_filepath=str(upload_path),
|
||||
input_filesize=input_size,
|
||||
processed_filepath=str(processed_path)
|
||||
)
|
||||
new_job = create_job(db=db, job=job_data)
|
||||
|
||||
# --- MODIFIED: Pass whisper_config to the task ---
|
||||
|
||||
# enqueue the Huey task (decorated function call enqueues when using huey)
|
||||
run_transcription_task(new_job.id, str(upload_path), str(processed_path), model_size=model_size, whisper_settings=whisper_config)
|
||||
|
||||
return {"job_id": new_job.id, "status": new_job.status, "status_url": f"/job/{new_job.id}"}
|
||||
|
||||
return {"job_id": new_job.id, "status": new_job.status}
|
||||
|
||||
|
||||
# --- Other routes remain unchanged ---
|
||||
|
||||
@app.get("/")
|
||||
async def get_index(request: Request):
|
||||
@@ -493,23 +558,55 @@ async def get_index(request: Request):
|
||||
@app.get("/settings")
|
||||
async def get_settings_page(request: Request):
|
||||
try:
|
||||
with open(PATHS.SETTINGS_FILE, 'r') as f:
|
||||
current_config = yaml.safe_load(f)
|
||||
except Exception as e:
|
||||
logger.error(f"Could not load settings.yml for settings page: {e}")
|
||||
with open(PATHS.SETTINGS_FILE, 'r', encoding='utf8') as f:
|
||||
current_config = yaml.safe_load(f) or {}
|
||||
except Exception:
|
||||
logger.exception("Could not load settings.yml for settings page")
|
||||
current_config = {}
|
||||
return templates.TemplateResponse("settings.html", {"request": request, "config": current_config})
|
||||
|
||||
def deep_merge(base: dict, updates: dict) -> dict:
|
||||
"""
|
||||
Recursively merge `updates` into `base`. Lists and scalars are replaced.
|
||||
"""
|
||||
for key, value in updates.items():
|
||||
if (
|
||||
key in base
|
||||
and isinstance(base[key], dict)
|
||||
and isinstance(value, dict)
|
||||
):
|
||||
base[key] = deep_merge(base[key], value)
|
||||
else:
|
||||
base[key] = value
|
||||
return base
|
||||
|
||||
|
||||
@app.post("/settings/save")
|
||||
async def save_settings(new_config: Dict = Body(...)):
|
||||
tmp = PATHS.SETTINGS_FILE.with_suffix(".tmp")
|
||||
try:
|
||||
with open(PATHS.SETTINGS_FILE, 'w') as f:
|
||||
yaml.dump(new_config, f, default_flow_style=False, sort_keys=False)
|
||||
# load existing config if present
|
||||
try:
|
||||
with PATHS.SETTINGS_FILE.open("r", encoding="utf8") as f:
|
||||
current_config = yaml.safe_load(f) or {}
|
||||
except FileNotFoundError:
|
||||
current_config = {}
|
||||
|
||||
# deep merge new values
|
||||
merged = deep_merge(current_config, new_config)
|
||||
|
||||
# atomic write back
|
||||
with tmp.open("w", encoding="utf8") as f:
|
||||
yaml.safe_dump(merged, f, default_flow_style=False, sort_keys=False)
|
||||
tmp.replace(PATHS.SETTINGS_FILE)
|
||||
|
||||
load_app_config()
|
||||
return JSONResponse({"message": "Settings saved successfully."})
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save settings: {e}")
|
||||
raise HTTPException(status_code=500, detail="Could not write to settings.yml.")
|
||||
return JSONResponse({"message": "Settings updated successfully."})
|
||||
except Exception:
|
||||
logger.exception("Failed to update settings")
|
||||
tmp.unlink(missing_ok=True)
|
||||
raise HTTPException(status_code=500, detail="Could not update settings.yml.")
|
||||
|
||||
|
||||
@app.post("/settings/clear-history")
|
||||
async def clear_job_history(db: Session = Depends(get_db)):
|
||||
@@ -518,9 +615,9 @@ async def clear_job_history(db: Session = Depends(get_db)):
|
||||
db.commit()
|
||||
logger.info(f"Cleared {num_deleted} jobs from history.")
|
||||
return {"deleted_count": num_deleted}
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
db.rollback()
|
||||
logger.error(f"Failed to clear job history: {e}")
|
||||
logger.exception("Failed to clear job history")
|
||||
raise HTTPException(status_code=500, detail="Database error while clearing history.")
|
||||
|
||||
@app.post("/settings/delete-files")
|
||||
@@ -532,9 +629,9 @@ async def delete_processed_files():
|
||||
if f.is_file():
|
||||
f.unlink()
|
||||
deleted_count += 1
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
errors.append(f.name)
|
||||
logger.error(f"Could not delete processed file {f.name}: {e}")
|
||||
logger.exception(f"Could not delete processed file {f.name}")
|
||||
if errors:
|
||||
raise HTTPException(status_code=500, detail=f"Could not delete some files: {', '.join(errors)}")
|
||||
logger.info(f"Deleted {deleted_count} files from processed directory.")
|
||||
@@ -562,12 +659,14 @@ async def submit_file_conversion(file: UploadFile = File(...), output_format: st
|
||||
processed_filename = f"{original_stem}_{job_id}.{target_ext}"
|
||||
upload_path = PATHS.UPLOADS_DIR / upload_filename
|
||||
processed_path = PATHS.PROCESSED_DIR / processed_filename
|
||||
await save_upload_file_chunked(file, upload_path)
|
||||
input_size = await save_upload_file_chunked(file, upload_path)
|
||||
job_data = JobCreate(id=job_id, task_type="conversion", original_filename=file.filename,
|
||||
input_filepath=str(upload_path), processed_filepath=str(processed_path))
|
||||
input_filepath=str(upload_path),
|
||||
input_filesize=input_size,
|
||||
processed_filepath=str(processed_path))
|
||||
new_job = create_job(db=db, job=job_data)
|
||||
run_conversion_task(new_job.id, str(upload_path), str(processed_path), tool, task_key, conversion_tools)
|
||||
return {"job_id": new_job.id, "status": new_job.status}
|
||||
return {"job_id": new_job.id, "status": new_job.status, "status_url": f"/job/{new_job.id}"}
|
||||
|
||||
@app.post("/ocr-pdf", status_code=status.HTTP_202_ACCEPTED)
|
||||
async def submit_pdf_ocr(file: UploadFile = File(...), db: Session = Depends(get_db)):
|
||||
@@ -578,12 +677,15 @@ async def submit_pdf_ocr(file: UploadFile = File(...), db: Session = Depends(get
|
||||
unique_filename = f"{Path(safe_basename).stem}_{job_id}{Path(safe_basename).suffix}"
|
||||
upload_path = PATHS.UPLOADS_DIR / unique_filename
|
||||
processed_path = PATHS.PROCESSED_DIR / unique_filename
|
||||
await save_upload_file_chunked(file, upload_path)
|
||||
job_data = JobCreate(id=job_id, task_type="ocr", original_filename=file.filename, input_filepath=str(upload_path), processed_filepath=str(processed_path))
|
||||
input_size = await save_upload_file_chunked(file, upload_path)
|
||||
job_data = JobCreate(id=job_id, task_type="ocr", original_filename=file.filename,
|
||||
input_filepath=str(upload_path),
|
||||
input_filesize=input_size,
|
||||
processed_filepath=str(processed_path))
|
||||
new_job = create_job(db=db, job=job_data)
|
||||
ocr_settings = APP_CONFIG.get("ocr_settings", {}).get("ocrmypdf", {})
|
||||
run_pdf_ocr_task(new_job.id, str(upload_path), str(processed_path), ocr_settings)
|
||||
return {"job_id": new_job.id, "status": new_job.status}
|
||||
return {"job_id": new_job.id, "status": new_job.status, "status_url": f"/job/{new_job.id}"}
|
||||
|
||||
@app.post("/ocr-image", status_code=status.HTTP_202_ACCEPTED)
|
||||
async def submit_image_ocr(file: UploadFile = File(...), db: Session = Depends(get_db)):
|
||||
@@ -596,11 +698,14 @@ async def submit_image_ocr(file: UploadFile = File(...), db: Session = Depends(g
|
||||
unique_filename = f"{Path(safe_basename).stem}_{job_id}{file_ext}"
|
||||
upload_path = PATHS.UPLOADS_DIR / unique_filename
|
||||
processed_path = PATHS.PROCESSED_DIR / f"{Path(safe_basename).stem}_{job_id}.txt"
|
||||
await save_upload_file_chunked(file, upload_path)
|
||||
job_data = JobCreate(id=job_id, task_type="ocr-image", original_filename=file.filename, input_filepath=str(upload_path), processed_filepath=str(processed_path))
|
||||
input_size = await save_upload_file_chunked(file, upload_path)
|
||||
job_data = JobCreate(id=job_id, task_type="ocr-image", original_filename=file.filename,
|
||||
input_filepath=str(upload_path),
|
||||
input_filesize=input_size,
|
||||
processed_filepath=str(processed_path))
|
||||
new_job = create_job(db=db, job=job_data)
|
||||
run_image_ocr_task(new_job.id, str(upload_path), str(processed_path))
|
||||
return {"job_id": new_job.id, "status": new_job.status}
|
||||
return {"job_id": new_job.id, "status": new_job.status, "status_url": f"/job/{new_job.id}"}
|
||||
|
||||
@app.post("/job/{job_id}/cancel", status_code=status.HTTP_202_ACCEPTED)
|
||||
async def cancel_job(job_id: str, db: Session = Depends(get_db)):
|
||||
@@ -626,8 +731,7 @@ async def get_job_status(job_id: str, db: Session = Depends(get_db)):
|
||||
@app.get("/download/{filename}")
|
||||
async def download_file(filename: str):
|
||||
safe_filename = secure_filename(filename)
|
||||
file_path = PATHS.PROCESSED_DIR / safe_filename
|
||||
file_path = file_path.resolve()
|
||||
file_path = (PATHS.PROCESSED_DIR / safe_filename).resolve()
|
||||
base = PATHS.PROCESSED_DIR.resolve()
|
||||
try:
|
||||
file_path.relative_to(base)
|
||||
@@ -635,4 +739,15 @@ async def download_file(filename: str):
|
||||
raise HTTPException(status_code=403, detail="Access denied.")
|
||||
if not file_path.is_file():
|
||||
raise HTTPException(status_code=404, detail="File not found.")
|
||||
return FileResponse(path=file_path, filename=safe_filename, media_type="application/octet-stream")
|
||||
return FileResponse(path=file_path, filename=safe_filename, media_type="application/octet-stream")
|
||||
|
||||
# Small health endpoint
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
try:
|
||||
with engine.connect() as conn:
|
||||
conn.execute("SELECT 1")
|
||||
except Exception:
|
||||
logger.exception("Health check failed")
|
||||
return JSONResponse({"ok": False}, status_code=500)
|
||||
return {"ok": True}
|
||||
Reference in New Issue
Block a user