| | |
| | import warnings |
| |
|
| | |
# Suppress noisy FutureWarnings emitted by the `spaces` package at import time.
warnings.filterwarnings("ignore", category=FutureWarning, module="spaces")
| |
|
| | import base64 |
| | import os |
| | import re |
| | import subprocess |
| | import sys |
| | import threading |
| | import time |
| | from collections import OrderedDict |
| | from io import BytesIO |
| |
|
| | import gradio as gr |
| | import pypdfium2 as pdfium |
| | import spaces |
| | import torch |
| | from openai import OpenAI |
| | from PIL import Image |
| | from transformers import ( |
| | LightOnOcrForConditionalGeneration, |
| | LightOnOcrProcessor, |
| | TextIteratorStreamer, |
| | ) |
| |
|
| | |
# Optional remote vLLM endpoints. When set, the matching model is served
# remotely via an OpenAI-compatible API instead of being loaded locally.
VLLM_ENDPOINT_OCR = os.environ.get("VLLM_ENDPOINT_OCR")
VLLM_ENDPOINT_BBOX = os.environ.get("VLLM_ENDPOINT_BBOX")

# Minimum seconds between partial yields while streaming generated text
# (throttles UI updates during streaming).
STREAM_YIELD_INTERVAL = 0.5

# Registry of selectable model variants. Each entry maps the UI display name
# to: the Hugging Face model id, whether the variant emits bounding boxes
# (`has_bbox`), a short description, and optionally a remote vLLM endpoint.
MODEL_REGISTRY = {
    "LightOnOCR-2-1B (Best OCR)": {
        "model_id": "lightonai/LightOnOCR-2-1B",
        "has_bbox": False,
        "description": "Best overall OCR performance",
        "vllm_endpoint": VLLM_ENDPOINT_OCR,
    },
    "LightOnOCR-2-1B-bbox (Best Bbox)": {
        "model_id": "lightonai/LightOnOCR-2-1B-bbox",
        "has_bbox": True,
        "description": "Best bounding box detection",
        "vllm_endpoint": VLLM_ENDPOINT_BBOX,
    },
    "LightOnOCR-2-1B-base": {
        "model_id": "lightonai/LightOnOCR-2-1B-base",
        "has_bbox": False,
        "description": "Base OCR model",
    },
    "LightOnOCR-2-1B-bbox-base": {
        "model_id": "lightonai/LightOnOCR-2-1B-bbox-base",
        "has_bbox": True,
        "description": "Base bounding box model",
    },
    "LightOnOCR-2-1B-ocr-soup": {
        "model_id": "lightonai/LightOnOCR-2-1B-ocr-soup",
        "has_bbox": False,
        "description": "OCR soup variant",
    },
    "LightOnOCR-2-1B-bbox-soup": {
        "model_id": "lightonai/LightOnOCR-2-1B-bbox-soup",
        "has_bbox": True,
        "description": "Bounding box soup variant",
    },
}

# Model pre-selected in the UI dropdown.
DEFAULT_MODEL = "LightOnOCR-2-1B (Best OCR)"
| |
|
# Pick the compute device, then the matching attention implementation and
# parameter dtype: SDPA + bfloat16 on GPU, eager + float32 on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

if device == "cuda":
    attn_implementation, dtype = "sdpa", torch.bfloat16
    print("Using sdpa for GPU")
else:
    attn_implementation, dtype = "eager", torch.float32
    print("Using eager attention for CPU")
| |
|
| |
|
class ModelManager:
    """Manages model loading with LRU caching and GPU memory management."""

    def __init__(self, max_cached=2):
        # OrderedDict keyed by model_id; insertion order doubles as LRU order.
        self._cache = OrderedDict()
        self._max_cached = max_cached

    def get_model(self, model_name):
        """Get model and processor, loading if necessary."""
        config = MODEL_REGISTRY.get(model_name)
        if config is None:
            raise ValueError(f"Unknown model: {model_name}")

        model_id = config["model_id"]

        cached = self._cache.get(model_id)
        if cached is not None:
            # Cache hit: mark as most recently used.
            self._cache.move_to_end(model_id)
            print(f"Using cached model: {model_name}")
            return cached

        # Make room: evict least-recently-used entries until under the cap,
        # releasing GPU memory as we go.
        while len(self._cache) >= self._max_cached:
            evicted_id, (evicted_model, _) = self._cache.popitem(last=False)
            print(f"Evicting model from cache: {evicted_id}")
            del evicted_model
            if device == "cuda":
                torch.cuda.empty_cache()

        print(f"Loading model: {model_name} ({model_id})...")
        model = LightOnOcrForConditionalGeneration.from_pretrained(
            model_id,
            attn_implementation=attn_implementation,
            torch_dtype=dtype,
            trust_remote_code=True,
        )
        model = model.to(device).eval()

        processor = LightOnOcrProcessor.from_pretrained(
            model_id, trust_remote_code=True
        )

        self._cache[model_id] = (model, processor)
        print(f"Model loaded successfully: {model_name}")

        return model, processor

    def get_model_info(self, model_name):
        """Get model info without loading."""
        return MODEL_REGISTRY.get(model_name)
| |
|
| |
|
| | |
# Singleton manager; models are loaded lazily on first request, with at most
# two kept resident at once.
model_manager = ModelManager(max_cached=2)
print("Model manager initialized. Models will be loaded on first use.")
| |
|
| |
|
def render_pdf_page(page, max_resolution=1540, scale=2.77):
    """Render a PDF page to PIL Image."""
    page_w, page_h = page.get_size()
    # Shrink the render scale (never enlarge) so that neither rendered side
    # exceeds max_resolution pixels.
    shrink = min(
        1,
        max_resolution / (page_w * scale),
        max_resolution / (page_h * scale),
    )
    bitmap = page.render(scale=scale * shrink, rev_byteorder=True)
    return bitmap.to_pil()
| |
|
| |
|
def process_pdf(pdf_path, page_num=1):
    """Extract a specific page from a PDF as a PIL image.

    Args:
        pdf_path: Path to the PDF file.
        page_num: 1-based page number; clamped into [1, total_pages].

    Returns:
        Tuple of (rendered PIL image, total page count, actual 1-based page
        number used after clamping).
    """
    pdf = pdfium.PdfDocument(pdf_path)
    try:
        total_pages = len(pdf)
        # Clamp the requested page into the valid range (1-based -> 0-based).
        page_idx = min(max(int(page_num) - 1, 0), total_pages - 1)
        img = render_pdf_page(pdf[page_idx])
    finally:
        # Close the document even if rendering fails, so the file handle is
        # not leaked on malformed pages.
        pdf.close()
    return img, total_pages, page_idx + 1
| |
|
| |
|
def clean_output_text(text):
    """Strip chat-template role markers from decoded model output.

    If a standalone "assistant" marker line is present, everything up to and
    including that line (the template preamble) is discarded. Otherwise any
    standalone "system"/"user"/"assistant" lines are removed. Occurrences of
    these words inside normal sentences are left untouched — the previous
    blanket ``text.split("assistant", 1)`` truncated legitimate OCR content
    containing the word "assistant" (and its guard was case-insensitive while
    the split was not).
    """
    markers = {"system", "user", "assistant"}
    lines = text.split("\n")

    # Preamble case: keep only what follows a standalone "assistant" line.
    for i, line in enumerate(lines):
        if line.strip().lower() == "assistant":
            return "\n".join(lines[i + 1:]).strip()

    # Otherwise drop any standalone role-marker lines.
    kept = [line for line in lines if line.strip().lower() not in markers]
    return "\n".join(kept).strip()
| |
|
| |
|
| | |
# Matches `![image](image_N.png) x1,y1,x2,y2` detections emitted by the
# bbox model variants; group 1 is the image reference, groups 2-5 the coords.
BBOX_PATTERN = r"!\[image\]\((image_\d+\.png)\)\s*(\d+),(\d+),(\d+),(\d+)"


def parse_bbox_output(text):
    """Parse bbox output and return cleaned text with list of detections."""
    detections = [
        {
            "ref": m.group(1),
            "coords": tuple(int(c) for c in m.groups()[1:]),
        }
        for m in re.finditer(BBOX_PATTERN, text)
    ]
    # Strip the detection placeholders from the text itself.
    cleaned = re.sub(BBOX_PATTERN, "", text)
    return cleaned, detections
| |
|
| |
|
def crop_from_bbox(source_image, bbox, padding=5):
    """Crop region from image based on normalized [0,1000] coords."""
    width, height = source_image.size
    nx1, ny1, nx2, ny2 = bbox["coords"]

    # Map normalized [0, 1000] coordinates onto actual pixel positions.
    left = int(nx1 * width / 1000)
    top = int(ny1 * height / 1000)
    right = int(nx2 * width / 1000)
    bottom = int(ny2 * height / 1000)

    # Expand by `padding` pixels, clamped to the image bounds.
    box = (
        max(0, left - padding),
        max(0, top - padding),
        min(width, right + padding),
        min(height, bottom + padding),
    )
    return source_image.crop(box)
| |
|
| |
|
def image_to_data_uri(image):
    """Convert PIL image to base64 data URI for markdown embedding."""
    # Encode the image as PNG in memory, then base64 the raw bytes.
    raw = BytesIO()
    image.save(raw, format="PNG")
    encoded = base64.b64encode(raw.getvalue()).decode()
    return f"data:image/png;base64,{encoded}"
| |
|
| |
|
def extract_text_via_vllm(image, model_name, temperature=0.2, stream=False, max_tokens=2048):
    """Extract text from image using a remote vLLM endpoint.

    Generator: yields the cleaned extracted text. With ``stream=True`` it
    yields progressively larger partial transcriptions, throttled by
    STREAM_YIELD_INTERVAL; otherwise it yields once with the final text.

    Raises:
        ValueError: if the model is unknown or has no configured endpoint.
    """
    config = MODEL_REGISTRY.get(model_name)
    if config is None:
        raise ValueError(f"Unknown model: {model_name}")

    endpoint = config.get("vllm_endpoint")
    if endpoint is None:
        raise ValueError(f"Model {model_name} does not have a vLLM endpoint")

    model_id = config["model_id"]

    # PIL images are inlined as base64 data URIs; strings pass through as-is
    # (assumed to already be a URL or data URI).
    if isinstance(image, Image.Image):
        image_uri = image_to_data_uri(image)
    else:
        image_uri = image

    client = OpenAI(base_url=endpoint, api_key="not-needed")

    # Image-only user turn; the server-side chat template supplies the prompt.
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": image_uri}},
            ],
        }
    ]

    # Shared request parameters for both streaming and non-streaming modes.
    request_kwargs = dict(
        model=model_id,
        messages=messages,
        max_tokens=max_tokens,
        temperature=temperature if temperature > 0 else 0.0,
        top_p=0.9,
    )

    if stream:
        response = client.chat.completions.create(stream=True, **request_kwargs)

        full_text = ""
        last_yield_time = time.time()
        for chunk in response:
            if chunk.choices and chunk.choices[0].delta.content:
                full_text += chunk.choices[0].delta.content
                # Throttle partial yields to limit UI update frequency.
                if time.time() - last_yield_time > STREAM_YIELD_INTERVAL:
                    yield clean_output_text(full_text)
                    last_yield_time = time.time()
        # Final yield with the complete text.
        yield clean_output_text(full_text)
    else:
        response = client.chat.completions.create(stream=False, **request_kwargs)

        # `content` can be None (e.g. empty/filtered completion); guard so
        # clean_output_text doesn't crash on None.
        output_text = response.choices[0].message.content or ""
        yield clean_output_text(output_text)
| |
|
| |
|
def render_bbox_with_crops(raw_output, source_image):
    """Replace bbox detection placeholders with inline cropped images.

    Each ``![image](image_N.png) x1,y1,x2,y2`` detection in the model output
    is replaced with a markdown image whose source is a base64 data URI of the
    corresponding crop of *source_image*, so detected figures render inline.
    A detection whose crop fails is removed from the output (best-effort),
    matching the previous behavior of stripping placeholders.
    """
    def _inline_crop(match):
        bbox = {
            "ref": match.group(1),
            "coords": tuple(int(c) for c in match.groups()[1:]),
        }
        try:
            cropped = crop_from_bbox(source_image, bbox)
            data_uri = image_to_data_uri(cropped)
            return f"![image]({data_uri})"
        except Exception as e:
            print(f"Error cropping bbox {bbox}: {e}")
            # Drop the placeholder so broken references never reach the UI.
            return ""

    return re.sub(BBOX_PATTERN, _inline_crop, raw_output)
| |
|
| |
|
@spaces.GPU
def extract_text_from_image(image, model_name, temperature=0.2, stream=False, max_tokens=2048):
    """Extract text from image using LightOnOCR model.

    Generator: yields the cleaned extracted text. With ``stream=True`` it
    yields progressively larger partial transcriptions (throttled by
    STREAM_YIELD_INTERVAL); otherwise it yields once with the final text.
    Delegates to the remote vLLM path when the model has an endpoint
    configured; otherwise runs local inference.
    """
    # Models configured with a vLLM endpoint are served remotely instead of
    # being loaded into local GPU memory.
    config = MODEL_REGISTRY.get(model_name, {})
    if config.get("vllm_endpoint"):
        yield from extract_text_via_vllm(image, model_name, temperature, stream, max_tokens)
        return

    # Load (or fetch from the LRU cache) the local model + processor.
    model, processor = model_manager.get_model(model_name)

    # Image-only user turn; the processor's chat template supplies the prompt.
    chat = [
        {
            "role": "user",
            "content": [
                {"type": "image", "url": image},
            ],
        }
    ]

    inputs = processor.apply_chat_template(
        chat,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    )

    # Move tensors to the target device; cast only floating-point tensors to
    # the model dtype (integer tensors such as input_ids keep their dtype).
    inputs = {
        k: v.to(device=device, dtype=dtype)
        if isinstance(v, torch.Tensor)
        and v.dtype in [torch.float32, torch.float16, torch.bfloat16]
        else v.to(device)
        if isinstance(v, torch.Tensor)
        else v
        for k, v in inputs.items()
    }

    generation_kwargs = dict(
        **inputs,
        max_new_tokens=max_tokens,
        temperature=temperature if temperature > 0 else 0.0,
        top_p=0.9,
        top_k=0,
        use_cache=True,
        # Greedy decoding when temperature == 0, sampling otherwise.
        do_sample=temperature > 0,
    )

    if stream:
        # Stream decoded tokens from a background generation thread.
        streamer = TextIteratorStreamer(
            processor.tokenizer, skip_prompt=True, skip_special_tokens=True
        )
        generation_kwargs["streamer"] = streamer

        thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()

        # Accumulate decoded pieces, yielding at most once per
        # STREAM_YIELD_INTERVAL seconds to limit UI updates.
        full_text = ""
        last_yield_time = time.time()
        for new_text in streamer:
            full_text += new_text
            if time.time() - last_yield_time > STREAM_YIELD_INTERVAL:
                yield clean_output_text(full_text)
                last_yield_time = time.time()

        thread.join()
        # Final yield with the complete text.
        yield clean_output_text(full_text)
    else:
        with torch.no_grad():
            outputs = model.generate(**generation_kwargs)

        # skip_prompt is unavailable here, so role markers from the template
        # are stripped afterwards by clean_output_text.
        output_text = processor.decode(outputs[0], skip_special_tokens=True)

        cleaned_text = clean_output_text(output_text)

        yield cleaned_text
| |
|
| |
|
def process_input(file_input, model_name, temperature, page_num, enable_streaming, max_output_tokens):
    """Process uploaded file (image or PDF) and extract text with optional streaming."""
    if file_input is None:
        yield "Please upload an image or PDF first.", "", "", None, gr.update()
        return

    file_path = file_input if isinstance(file_input, str) else file_input.name

    # Load the page/image to OCR; bail out with an error message on failure.
    if file_path.lower().endswith(".pdf"):
        try:
            image_to_process, total_pages, actual_page = process_pdf(
                file_path, int(page_num)
            )
        except Exception as e:
            yield f"Error processing PDF: {str(e)}", "", "", None, gr.update()
            return
        page_info = f"Processing page {actual_page} of {total_pages}"
    else:
        try:
            image_to_process = Image.open(file_path)
        except Exception as e:
            yield f"Error opening image: {str(e)}", "", "", None, gr.update()
            return
        page_info = "Processing image"

    # Bbox variants get their detections rendered as inline crops.
    has_bbox = MODEL_REGISTRY.get(model_name, {}).get("has_bbox", False)

    try:
        text_iter = extract_text_from_image(
            image_to_process,
            model_name,
            temperature,
            stream=enable_streaming,
            max_tokens=max_output_tokens,
        )
        for extracted_text in text_iter:
            rendered_text = (
                render_bbox_with_crops(extracted_text, image_to_process)
                if has_bbox
                else extracted_text
            )
            yield (
                rendered_text,
                extracted_text,
                page_info,
                image_to_process,
                gr.update(),
            )
    except Exception as e:
        error_msg = f"Error during text extraction: {str(e)}"
        yield error_msg, error_msg, page_info, image_to_process, gr.update()
| |
|
| |
|
def update_slider_and_preview(file_input):
    """Update page slider and preview image based on uploaded file.

    Returns:
        Tuple of (gr.update for the page slider, preview PIL image or None).
    """
    if file_input is None:
        return gr.update(maximum=20, value=1), None

    file_path = file_input if isinstance(file_input, str) else file_input.name

    if file_path.lower().endswith(".pdf"):
        pdf = None
        try:
            pdf = pdfium.PdfDocument(file_path)
            total_pages = len(pdf)
            # Quick low-resolution render of the first page for the preview.
            preview_image = pdf[0].render(scale=2).to_pil()
            return gr.update(maximum=total_pages, value=1), preview_image
        except Exception:
            # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
            # are not swallowed; unreadable PDFs fall back to defaults.
            return gr.update(maximum=20, value=1), None
        finally:
            # Close the document even when rendering fails (no handle leak).
            if pdf is not None:
                pdf.close()
    else:
        try:
            preview_image = Image.open(file_path)
            return gr.update(maximum=1, value=1), preview_image
        except Exception:
            return gr.update(maximum=1, value=1), None
| |
|
| |
|
| | |
def get_model_info_text(model_name):
    """Return formatted model info string."""
    info = MODEL_REGISTRY.get(model_name, {})
    if info.get("has_bbox", False):
        bbox_support = "Yes - will show cropped regions inline"
    else:
        bbox_support = "No"
    description = info.get("description", "N/A")
    return f"**Description:** {description}\n**Bounding Box Detection:** {bbox_support}"
| |
|
| |
|
| | |
# --- Gradio UI --------------------------------------------------------------
# `theme` is a gr.Blocks constructor argument (launch() does not accept it),
# so the Soft theme is applied here.
with gr.Blocks(title="LightOnOCR-2 Multi-Model OCR", theme=gr.themes.Soft()) as demo:
    gr.Markdown(f"""
    # LightOnOCR-2 — Efficient 1B VLM for OCR

    State-of-the-art OCR on OlmOCR-Bench, ~9× smaller and faster than competitors. Handles tables, forms, math, multi-column layouts.

    ⚡ **3.3× faster** than Chandra, **1.7× faster** than OlmOCR | 💸 **<$0.01/1k pages** | 🧠 End-to-end differentiable | 📍 Bbox variants for image detection

    📄 [Paper](https://huggingface.co/papers/lightonocr-2) | 📝 [Blog](https://huggingface.co/blog/lightonai/lightonocr-2) | 📊 [Dataset](https://huggingface.co/datasets/lightonai/LightOnOCR-mix-0126) | 📓 [Finetuning](https://colab.research.google.com/drive/1WjbsFJZ4vOAAlKtcCauFLn_evo5UBRNa?usp=sharing)

    ---

    **How to use:** Select a model → Upload image/PDF → Click "Extract Text" | **Device:** {device.upper()} | **Attention:** {attn_implementation}
    """)

    with gr.Row():
        # Left column: model selection, upload, preview, and generation knobs.
        with gr.Column(scale=1):
            model_selector = gr.Dropdown(
                choices=list(MODEL_REGISTRY.keys()),
                value=DEFAULT_MODEL,
                label="Model",
                info="Select OCR model variant",
            )
            model_info = gr.Markdown(
                value=get_model_info_text(DEFAULT_MODEL), label="Model Info"
            )
            file_input = gr.File(
                label="Upload Image or PDF",
                file_types=[".pdf", ".png", ".jpg", ".jpeg"],
                type="filepath",
            )
            rendered_image = gr.Image(
                label="Preview", type="pil", height=400, interactive=False
            )
            num_pages = gr.Slider(
                minimum=1,
                maximum=20,
                value=1,
                step=1,
                label="PDF: Page Number",
                info="Select which page to extract",
            )
            page_info = gr.Textbox(label="Processing Info", value="", interactive=False)
            temperature = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.2,
                step=0.05,
                label="Temperature",
                info="0.0 = deterministic, Higher = more varied",
            )
            enable_streaming = gr.Checkbox(
                label="Enable Streaming",
                value=True,
                info="Show text progressively as it's generated",
            )
            max_output_tokens = gr.Slider(
                minimum=256,
                maximum=8192,
                value=2048,
                step=256,
                label="Max Output Tokens",
                info="Maximum number of tokens to generate",
            )
            submit_btn = gr.Button("Extract Text", variant="primary")
            clear_btn = gr.Button("Clear", variant="secondary")

        # Right column: rendered markdown output with LaTeX support.
        with gr.Column(scale=2):
            output_text = gr.Markdown(
                label="📄 Extracted Text (Rendered)",
                value="*Extracted text will appear here...*",
                latex_delimiters=[
                    {"left": "$$", "right": "$$", "display": True},
                    {"left": "$", "right": "$", "display": False},
                ],
            )

    # Bundled example documents shown in the gallery below.
    EXAMPLE_IMAGES = [
        "examples/example_1.png",
        "examples/example_2.png",
        "examples/example_3.png",
        "examples/example_4.png",
        "examples/example_5.png",
        "examples/example_6.png",
        "examples/example_7.png",
        "examples/example_8.png",
        "examples/example_9.png",
    ]

    with gr.Accordion("📁 Example Documents (click an image to load)", open=True):
        example_gallery = gr.Gallery(
            value=EXAMPLE_IMAGES,
            columns=5,
            rows=2,
            height="auto",
            object_fit="contain",
            show_label=False,
            allow_preview=False,
        )

    def load_example_image(evt: gr.SelectData):
        """Load selected example image into file input."""
        return EXAMPLE_IMAGES[evt.index]

    example_gallery.select(
        fn=load_example_image,
        outputs=[file_input],
    )

    with gr.Row():
        with gr.Column():
            raw_output = gr.Textbox(
                label="Raw Markdown Output",
                placeholder="Raw text will appear here...",
                lines=20,
                max_lines=30,
            )

    # Event wiring.
    submit_btn.click(
        fn=process_input,
        inputs=[file_input, model_selector, temperature, num_pages, enable_streaming, max_output_tokens],
        outputs=[output_text, raw_output, page_info, rendered_image, num_pages],
    )

    file_input.change(
        fn=update_slider_and_preview,
        inputs=[file_input],
        outputs=[num_pages, rendered_image],
    )

    model_selector.change(
        fn=get_model_info_text, inputs=[model_selector], outputs=[model_info]
    )

    # Reset every input/output component back to its initial state.
    clear_btn.click(
        fn=lambda: (
            None,
            DEFAULT_MODEL,
            get_model_info_text(DEFAULT_MODEL),
            "*Extracted text will appear here...*",
            "",
            "",
            None,
            1,
            2048,
        ),
        outputs=[
            file_input,
            model_selector,
            model_info,
            output_text,
            raw_output,
            page_info,
            rendered_image,
            num_pages,
            max_output_tokens,
        ],
    )
| |
|
| |
|
if __name__ == "__main__":
    # `theme` is not a launch() parameter (it belongs on the gr.Blocks
    # constructor), so it is not passed here. share=True exposes a public
    # link when running locally; it is ignored on Spaces.
    demo.launch(ssr_mode=False, share=True)
| |
|