init

main.py (new file, 397 lines)
@@ -0,0 +1,397 @@
import logging
import traceback
import uuid
from contextlib import asynccontextmanager
from datetime import datetime
from pathlib import Path
from typing import List, Set

import ocrmypdf
import pypdf
import pytesseract
from PIL import Image
from faster_whisper import WhisperModel
# MODIFICATION: Added Form for model selection
from fastapi import (Depends, FastAPI, File, Form, HTTPException, Request,
                     UploadFile, status)
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from huey import SqliteHuey
from pydantic import BaseModel, ConfigDict
from pydantic_settings import BaseSettings
from sqlalchemy import (Column, DateTime, Integer, String, Text,
                        create_engine)
from sqlalchemy.orm import Session, declarative_base, sessionmaker
from werkzeug.utils import secure_filename

# --------------------------------------------------------------------------------
# --- 1. CONFIGURATION
# --------------------------------------------------------------------------------
class Settings(BaseSettings):
    BASE_DIR: Path = Path(__file__).resolve().parent
    UPLOADS_DIR: Path = BASE_DIR / "uploads"
    PROCESSED_DIR: Path = BASE_DIR / "processed"
    DATABASE_URL: str = f"sqlite:///{BASE_DIR / 'jobs.db'}"
    HUEY_DB_PATH: str = str(BASE_DIR / "huey.db")
    # MODIFICATION: Removed hardcoded model size, added a set of allowed models
    WHISPER_COMPUTE_TYPE: str = "int8"
    ALLOWED_WHISPER_MODELS: Set[str] = {"tiny", "base", "small", "medium", "large-v3", "distil-large-v2"}
    MAX_FILE_SIZE_BYTES: int = 500 * 1024 * 1024  # 500 MB
    ALLOWED_PDF_EXTENSIONS: set = {".pdf"}
    ALLOWED_IMAGE_EXTENSIONS: set = {".png", ".jpg", ".jpeg", ".tiff", ".tif"}
    ALLOWED_AUDIO_EXTENSIONS: set = {".mp3", ".m4a", ".ogg", ".flac", ".opus"}

settings = Settings()
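
# Because Settings subclasses pydantic's BaseSettings, every field above can be
# overridden with an environment variable of the same name, e.g. (values here
# are illustrative, not this app's defaults):
#
#   MAX_FILE_SIZE_BYTES=104857600 WHISPER_COMPUTE_TYPE=float32 uvicorn main:app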

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

settings.UPLOADS_DIR.mkdir(exist_ok=True)
settings.PROCESSED_DIR.mkdir(exist_ok=True)


# --------------------------------------------------------------------------------
# --- 2. DATABASE (for Job Tracking) - NO CHANGES
# --------------------------------------------------------------------------------
engine = create_engine(settings.DATABASE_URL, connect_args={"check_same_thread": False})
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base()

class Job(Base):
    __tablename__ = "jobs"
    id = Column(String, primary_key=True, index=True)
    task_type = Column(String, index=True)
    status = Column(String, default="pending")
    progress = Column(Integer, default=0)
    original_filename = Column(String)
    input_filepath = Column(String)
    processed_filepath = Column(String, nullable=True)
    result_preview = Column(Text, nullable=True)
    error_message = Column(Text, nullable=True)
    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)

def get_db():
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()
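
# get_db is consumed as a FastAPI dependency (Depends(get_db) below): FastAPI
# runs the generator up to `yield` for each request, hands the session to the
# endpoint, and the `finally` block closes it after the response is sent.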


# --------------------------------------------------------------------------------
# --- 3. PYDANTIC SCHEMAS (Data Validation) - NO CHANGES
# --------------------------------------------------------------------------------
class JobCreate(BaseModel):
    id: str
    task_type: str
    original_filename: str
    input_filepath: str
    processed_filepath: str | None = None

class JobSchema(BaseModel):
    id: str
    task_type: str
    status: str
    progress: int
    original_filename: str
    processed_filepath: str | None = None
    result_preview: str | None = None
    error_message: str | None = None
    created_at: datetime
    updated_at: datetime
    model_config = ConfigDict(from_attributes=True)
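
# from_attributes=True lets pydantic build JobSchema straight from SQLAlchemy
# Job instances (attribute access instead of dict keys), which is what makes
# response_model=JobSchema work on the endpoints that return ORM objects.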


# --------------------------------------------------------------------------------
# --- 4. CRUD OPERATIONS (Database Interactions) - NO CHANGES
# --------------------------------------------------------------------------------
def get_job(db: Session, job_id: str):
    return db.query(Job).filter(Job.id == job_id).first()

def get_jobs(db: Session, skip: int = 0, limit: int = 100):
    return db.query(Job).order_by(Job.created_at.desc()).offset(skip).limit(limit).all()

def create_job(db: Session, job: JobCreate):
    db_job = Job(**job.model_dump())
    db.add(db_job)
    db.commit()
    db.refresh(db_job)
    return db_job

def update_job_status(db: Session, job_id: str, status: str, progress: int | None = None, error: str | None = None):
    db_job = get_job(db, job_id)
    if db_job:
        db_job.status = status
        if progress is not None:
            db_job.progress = progress
        if error:
            db_job.error_message = error
        db.commit()
        db.refresh(db_job)
    return db_job

def mark_job_as_completed(db: Session, job_id: str, preview: str | None = None):
    db_job = get_job(db, job_id)
    if db_job and db_job.status != 'cancelled':
        db_job.status = "completed"
        db_job.progress = 100
        if preview:
            db_job.result_preview = preview.strip()[:2000]
        db.commit()
    return db_job


# --------------------------------------------------------------------------------
# --- 5. BACKGROUND TASKS (Huey)
# --------------------------------------------------------------------------------
huey = SqliteHuey(filename=settings.HUEY_DB_PATH)

# MODIFICATION: Removed global whisper model and lazy loader.
# The model will now be loaded inside the task itself based on user selection.
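# Note: these tasks execute in a separate Huey consumer process, not in the
# FastAPI workers, so the Whisper model is loaded into the consumer's memory
# and the web process stays lightweight (run commands are sketched at the end
# of this file).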

@huey.task()
def run_pdf_ocr_task(job_id: str, input_path_str: str, output_path_str: str):
    db = SessionLocal()
    try:
        job = get_job(db, job_id)
        if not job or job.status == 'cancelled':
            logger.info(f"Job {job_id} was cancelled before starting.")
            return

        update_job_status(db, job_id, "processing")
        logger.info(f"Starting PDF OCR for job {job_id}")

        ocrmypdf.ocr(input_path_str, output_path_str, deskew=True, force_ocr=True, clean=True, optimize=1, progress_bar=False)

        with open(output_path_str, "rb") as f:
            reader = pypdf.PdfReader(f)
            preview = "\n".join(page.extract_text() or "" for page in reader.pages)
        mark_job_as_completed(db, job_id, preview=preview)
        logger.info(f"PDF OCR for job {job_id} completed.")
    except Exception as e:
        logger.error(f"ERROR during PDF OCR for job {job_id}: {e}\n{traceback.format_exc()}")
        update_job_status(db, job_id, "failed", error=str(e))
    finally:
        Path(input_path_str).unlink(missing_ok=True)
        db.close()

@huey.task()
def run_image_ocr_task(job_id: str, input_path_str: str, output_path_str: str):
    db = SessionLocal()
    try:
        job = get_job(db, job_id)
        if not job or job.status == 'cancelled':
            logger.info(f"Job {job_id} was cancelled before starting.")
            return

        update_job_status(db, job_id, "processing", progress=50)
        logger.info(f"Starting Image OCR for job {job_id}")
        text = pytesseract.image_to_string(Image.open(input_path_str))
        with open(output_path_str, "w", encoding="utf-8") as f:
            f.write(text)
        mark_job_as_completed(db, job_id, preview=text)
        logger.info(f"Image OCR for job {job_id} completed.")
    except Exception as e:
        logger.error(f"ERROR during Image OCR for job {job_id}: {e}\n{traceback.format_exc()}")
        update_job_status(db, job_id, "failed", error=str(e))
    finally:
        Path(input_path_str).unlink(missing_ok=True)
        db.close()

# MODIFICATION: The task now accepts `model_size` and loads the model dynamically.
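# Note: faster-whisper's transcribe() returns a lazy generator of segments, so
# decoding happens while the loop below iterates; that is what makes the
# per-segment progress updates and mid-run cancellation checks possible.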
@huey.task()
def run_transcription_task(job_id: str, input_path_str: str, output_path_str: str, model_size: str):
    db = SessionLocal()
    try:
        job = get_job(db, job_id)
        if not job or job.status == 'cancelled':
            logger.info(f"Job {job_id} was cancelled before starting.")
            return

        update_job_status(db, job_id, "processing")

        # Load the specified model for this task
        logger.info(f"Loading faster-whisper model: {model_size} for job {job_id}...")
        model = WhisperModel(
            model_size,
            device="cpu",
            compute_type=settings.WHISPER_COMPUTE_TYPE
        )
        logger.info(f"Whisper model '{model_size}' loaded successfully.")

        logger.info(f"Starting transcription for job {job_id}")
        segments, info = model.transcribe(input_path_str, beam_size=5)

        full_transcript = []
        total_duration = info.duration
        for segment in segments:
            job_check = get_job(db, job_id)
            if job_check.status == 'cancelled':
                logger.info(f"Job {job_id} cancelled during transcription.")
                return

            # Update progress based on the segment's end time
            if total_duration > 0:
                progress = int((segment.end / total_duration) * 100)
                update_job_status(db, job_id, "processing", progress=progress)
            full_transcript.append(segment.text.strip())

        transcript_text = "\n".join(full_transcript)
        with open(output_path_str, "w", encoding="utf-8") as f:
            f.write(transcript_text)
        mark_job_as_completed(db, job_id, preview=transcript_text)
        logger.info(f"Transcription for job {job_id} completed.")
    except Exception as e:
        logger.error(f"ERROR during transcription for job {job_id}: {e}\n{traceback.format_exc()}")
        update_job_status(db, job_id, "failed", error=str(e))
    finally:
        Path(input_path_str).unlink(missing_ok=True)
        db.close()


# --------------------------------------------------------------------------------
# --- 6. FASTAPI APPLICATION
# --------------------------------------------------------------------------------
@asynccontextmanager
async def lifespan(app: FastAPI):
    logger.info("Application starting up...")
    Base.metadata.create_all(bind=engine)
    yield
    logger.info("Application shutting down...")

app = FastAPI(lifespan=lifespan)
app.mount("/static", StaticFiles(directory=settings.BASE_DIR / "static"), name="static")
templates = Jinja2Templates(directory=settings.BASE_DIR / "templates")

# --- Helper Functions ---
async def save_upload_file_chunked(upload_file: UploadFile, destination: Path):
    size = 0
    try:
        with open(destination, "wb") as buffer:
            while chunk := await upload_file.read(1024 * 1024):  # 1MB chunks
                if size + len(chunk) > settings.MAX_FILE_SIZE_BYTES:
                    raise HTTPException(
                        status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
                        detail=f"File exceeds limit of {settings.MAX_FILE_SIZE_BYTES // 1024 // 1024} MB"
                    )
                buffer.write(chunk)
                size += len(chunk)
    except HTTPException:
        # Don't leave a partially written file behind when the upload is rejected.
        destination.unlink(missing_ok=True)
        raise

def is_allowed_file(filename: str, allowed_extensions: set) -> bool:
    return Path(filename).suffix.lower() in allowed_extensions

# --- API Endpoints ---
@app.get("/")
async def get_index(request: Request):
    # MODIFICATION: Pass available models to the template
    return templates.TemplateResponse("index.html", {
        "request": request,
        "whisper_models": sorted(list(settings.ALLOWED_WHISPER_MODELS))
    })

@app.post("/ocr-pdf", status_code=status.HTTP_202_ACCEPTED)
async def submit_pdf_ocr(file: UploadFile = File(...), db: Session = Depends(get_db)):
    if not is_allowed_file(file.filename, settings.ALLOWED_PDF_EXTENSIONS):
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid file type. Please upload a PDF.")

    job_id = uuid.uuid4().hex
    safe_basename = secure_filename(file.filename)
    unique_filename = f"{Path(safe_basename).stem}_{job_id}{Path(safe_basename).suffix}"
    upload_path = settings.UPLOADS_DIR / unique_filename
    processed_path = settings.PROCESSED_DIR / unique_filename

    await save_upload_file_chunked(file, upload_path)

    job_data = JobCreate(id=job_id, task_type="ocr", original_filename=file.filename, input_filepath=str(upload_path), processed_filepath=str(processed_path))
    new_job = create_job(db=db, job=job_data)

    run_pdf_ocr_task(new_job.id, str(upload_path), str(processed_path))
    return {"job_id": new_job.id, "status": new_job.status}

@app.post("/ocr-image", status_code=status.HTTP_202_ACCEPTED)
async def submit_image_ocr(file: UploadFile = File(...), db: Session = Depends(get_db)):
    if not is_allowed_file(file.filename, settings.ALLOWED_IMAGE_EXTENSIONS):
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid file type. Please upload a PNG, JPG, or TIFF.")

    job_id = uuid.uuid4().hex
    safe_basename = secure_filename(file.filename)
    file_ext = Path(safe_basename).suffix
    unique_filename = f"{Path(safe_basename).stem}_{job_id}{file_ext}"
    upload_path = settings.UPLOADS_DIR / unique_filename
    processed_path = settings.PROCESSED_DIR / f"{Path(safe_basename).stem}_{job_id}.txt"

    await save_upload_file_chunked(file, upload_path)

    job_data = JobCreate(id=job_id, task_type="ocr-image", original_filename=file.filename, input_filepath=str(upload_path), processed_filepath=str(processed_path))
    new_job = create_job(db=db, job=job_data)

    run_image_ocr_task(new_job.id, str(upload_path), str(processed_path))
    return {"job_id": new_job.id, "status": new_job.status}

# MODIFICATION: Endpoint now accepts `model_size` as form data.
@app.post("/transcribe-audio", status_code=status.HTTP_202_ACCEPTED)
async def submit_audio_transcription(
    file: UploadFile = File(...),
    model_size: str = Form("base"),
    db: Session = Depends(get_db)
):
    if not is_allowed_file(file.filename, settings.ALLOWED_AUDIO_EXTENSIONS):
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid audio file type.")

    # Validate the selected model size
    if model_size not in settings.ALLOWED_WHISPER_MODELS:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid model size: {model_size}.")

    job_id = uuid.uuid4().hex
    safe_basename = secure_filename(file.filename)
    stem, suffix = Path(safe_basename).stem, Path(safe_basename).suffix

    audio_filename = f"{stem}_{job_id}{suffix}"
    transcript_filename = f"{stem}_{job_id}.txt"
    upload_path = settings.UPLOADS_DIR / audio_filename
    processed_path = settings.PROCESSED_DIR / transcript_filename

    await save_upload_file_chunked(file, upload_path)

    job_data = JobCreate(id=job_id, task_type="transcription", original_filename=file.filename, input_filepath=str(upload_path), processed_filepath=str(processed_path))
    new_job = create_job(db=db, job=job_data)

    # Pass the selected model size to the background task
    run_transcription_task(new_job.id, str(upload_path), str(processed_path), model_size=model_size)
    return {"job_id": new_job.id, "status": new_job.status}

@app.post("/job/{job_id}/cancel", status_code=status.HTTP_200_OK)
async def cancel_job(job_id: str, db: Session = Depends(get_db)):
    job = get_job(db, job_id)
    if not job:
        raise HTTPException(status_code=404, detail="Job not found.")
    if job.status in ["pending", "processing"]:
        update_job_status(db, job_id, status="cancelled")
        return {"message": "Job cancellation requested."}
    raise HTTPException(status_code=400, detail=f"Job is already in a final state ({job.status}).")
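
# Cancellation is cooperative: this endpoint only flips the job's status; the
# Huey tasks notice the "cancelled" state before starting (and between
# transcription segments) rather than being forcibly killed.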

@app.get("/jobs", response_model=List[JobSchema])
async def get_all_jobs(db: Session = Depends(get_db)):
    return get_jobs(db)

@app.get("/job/{job_id}", response_model=JobSchema)
async def get_job_status(job_id: str, db: Session = Depends(get_db)):
    job = get_job(db, job_id)
    if not job:
        raise HTTPException(status_code=404, detail="Job not found.")
    return job

@app.get("/download/{filename}")
async def download_file(filename: str):
    safe_filename = secure_filename(filename)
    file_path = settings.PROCESSED_DIR / safe_filename

    if not file_path.resolve().is_relative_to(settings.PROCESSED_DIR.resolve()):
        raise HTTPException(status_code=403, detail="Access denied.")

    if not file_path.is_file():
        raise HTTPException(status_code=404, detail="File not found.")

    return FileResponse(path=file_path, filename=safe_filename, media_type="application/octet-stream")
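
# --------------------------------------------------------------------------------
# Running locally (a sketch; assumes this file is saved as main.py and that the
# external binaries used by ocrmypdf and pytesseract, such as Tesseract,
# Ghostscript, and unpaper for clean=True, are installed):
#
#   uvicorn main:app --reload        # web server
#   huey_consumer.py main.huey       # background worker, in a second shell
# --------------------------------------------------------------------------------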