#!/usr/bin/env python3
import subprocess
import sys

import spaces
import torch
import gradio as gr
from PIL import Image
import numpy as np
import cv2
import pypdfium2 as pdfium
from transformers import (
    LightOnOCRForConditionalGeneration,
    LightOnOCRProcessor,
)
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
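# Note: importing LightOnOCRForConditionalGeneration / LightOnOCRProcessor directly assumes
# a transformers build that ships these classes; on older builds this import fails before
# the trust_remote_code path below is ever reached.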
| device = "cuda" if torch.cuda.is_available() else "cpu" | |
| if device == "cuda": | |
| attn_implementation = "sdpa" | |
| dtype = torch.bfloat16 | |
| else: | |
| attn_implementation = "eager" | |
| dtype = torch.float32 | |
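# bfloat16 keeps GPU memory use down; float32 is the safer default on CPU, where bf16
# kernels can be slow or unsupported depending on the hardware.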
ocr_model = LightOnOCRForConditionalGeneration.from_pretrained(
    "lightonai/LightOnOCR-1B-1025",
    attn_implementation=attn_implementation,
    torch_dtype=dtype,
    trust_remote_code=True,
).to(device).eval()

processor = LightOnOCRProcessor.from_pretrained(
    "lightonai/LightOnOCR-1B-1025",
    trust_remote_code=True,
)

ner_tokenizer = AutoTokenizer.from_pretrained("samrawal/bert-base-uncased_clinical-ner")
ner_model = AutoModelForTokenClassification.from_pretrained("samrawal/bert-base-uncased_clinical-ner")
ner_pipeline = pipeline(
    "ner",
    model=ner_model,
    tokenizer=ner_tokenizer,
    aggregation_strategy="simple",
)
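# With aggregation_strategy="simple" the pipeline returns grouped spans, e.g. (illustrative
# values, not real output):
#   [{"entity_group": "treatment", "word": "metformin", "score": 0.98, "start": 42, "end": 51}]
# The clinical NER model is assumed to follow the i2b2-style label set ("problem", "test",
# "treatment"); only "treatment" spans are kept downstream as medication candidates.
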
def render_pdf_page(page, max_resolution=1540, scale=2.77):
    """Render a PDF page to a PIL image, capping the longest side at max_resolution."""
    width, height = page.get_size()
    pixel_width = width * scale
    pixel_height = height * scale
    resize_factor = min(1, max_resolution / pixel_width, max_resolution / pixel_height)
    target_scale = scale * resize_factor
    return page.render(scale=target_scale, rev_byteorder=True).to_pil()
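# scale=2.77 renders at roughly 200 DPI (PDF points are 1/72 inch, and 200 / 72 ≈ 2.78),
# before the max_resolution cap is applied.
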
def process_pdf(pdf_path, page_num=1):
    """Extract a single page from a PDF as a PIL image (page_num is 1-based and clamped)."""
    pdf = pdfium.PdfDocument(pdf_path)
    total_pages = len(pdf)
    page_idx = min(max(int(page_num) - 1, 0), total_pages - 1)
    page = pdf[page_idx]
    img = render_pdf_page(page)
    pdf.close()
    return img, total_pages, page_idx + 1

def clean_output_text(text):
    """Remove chat-template role markers from the decoded model output."""
    markers_to_remove = ["system", "user", "assistant"]
    lines = text.split('\n')
    cleaned_lines = []
    for line in lines:
        stripped = line.strip()
        if stripped.lower() not in markers_to_remove:
            cleaned_lines.append(line)
    cleaned = '\n'.join(cleaned_lines).strip()
    if "assistant" in text.lower():
        parts = text.split("assistant", 1)
        if len(parts) > 1:
            cleaned = parts[1].strip()
    return cleaned
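# generate() returns prompt + completion for this model, so the decoded text typically still
# contains the chat-template role markers; everything after the first "assistant" marker is
# treated as the OCR result.
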
def preprocess_image_for_ocr(image):
    """Convert a PIL image to an adaptively thresholded image for OCR."""
    image_rgb = image.convert("RGB")
    img_np = np.array(image_rgb)
    gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
    adaptive_threshold = cv2.adaptiveThreshold(
        gray,
        255,
        cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
        cv2.THRESH_BINARY,
        85,
        11,
    )
    preprocessed_pil = Image.fromarray(adaptive_threshold)
    return preprocessed_pil
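# adaptiveThreshold uses an 85x85 Gaussian-weighted neighbourhood (blockSize must be odd)
# with C=11 subtracted from the weighted local mean; this tends to help low-contrast or
# unevenly lit scans, but can degrade already-clean digital renders.
# A simpler global alternative (sketch only, not wired into the app) would be Otsu:
#   _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
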
def extract_text_from_image(image, temperature=0.2):
    """OCR the image with LightOnOCR, then run clinical NER on the cleaned text."""
    processed_img = preprocess_image_for_ocr(image)
    chat = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": processed_img},
            ],
        }
    ]
    inputs = processor.apply_chat_template(
        chat,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    )
    # Move tensors to the target device; cast floating-point tensors (pixel values) to the
    # model dtype and leave integer tensors (input ids, attention mask) untouched.
    inputs = {
        k: (
            v.to(device=device, dtype=dtype)
            if isinstance(v, torch.Tensor) and v.dtype in [torch.float32, torch.float16, torch.bfloat16]
            else v.to(device)
            if isinstance(v, torch.Tensor)
            else v
        )
        for k, v in inputs.items()
    }
    generation_kwargs = dict(
        **inputs,
        max_new_tokens=2048,
        temperature=temperature if temperature > 0 else 0.0,
        use_cache=True,
        do_sample=temperature > 0,
    )
    with torch.no_grad():
        outputs = ocr_model.generate(**generation_kwargs)
    output_text = processor.decode(outputs[0], skip_special_tokens=True)
    cleaned_text = clean_output_text(output_text)
    entities = ner_pipeline(cleaned_text)
    medications = []
    for ent in entities:
        if ent["entity_group"] == "treatment":
            word = ent["word"]
            # Re-attach stray WordPiece continuations, in case aggregation leaves any.
            if word.startswith("##") and medications:
                medications[-1] += word[2:]
            else:
                medications.append(word)
    # De-duplicate while preserving first-seen order.
    medications_str = ", ".join(dict.fromkeys(medications)) if medications else "None detected"
    yield cleaned_text, medications_str, output_text, processed_img
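# Standalone usage sketch (outside Gradio), assuming a local file "sample_prescription.png"
# exists; the function is a generator with a single yield, so pull the result with next():
#   img = Image.open("sample_prescription.png")
#   text, meds, raw_text, processed = next(extract_text_from_image(img, temperature=0.0))
#   print(meds)
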
def process_input(file_input, temperature, page_num):
    """Process an uploaded image or PDF and yield OCR text plus extracted medications."""
    if file_input is None:
        yield "Please upload an image or PDF first.", "", "", "No file uploaded.", None, 1
        return
    image_to_process = None
    page_info = ""
    slider_value = page_num
    file_path = file_input if isinstance(file_input, str) else file_input.name
    if file_path.lower().endswith(".pdf"):
        try:
            image_to_process, total_pages, actual_page = process_pdf(file_path, int(page_num))
            page_info = f"Processing page {actual_page} of {total_pages}"
            slider_value = actual_page
        except Exception as e:
            msg = f"Error processing PDF: {str(e)}"
            yield msg, "", msg, "", None, slider_value
            return
    else:
        try:
            image_to_process = Image.open(file_path)
            page_info = "Processing image"
        except Exception as e:
            msg = f"Error opening image: {str(e)}"
            yield msg, "", msg, "", None, slider_value
            return
    try:
        for cleaned_text, medications, raw_md, processed_img in extract_text_from_image(
            image_to_process, temperature
        ):
            yield cleaned_text, medications, raw_md, page_info, processed_img, slider_value
    except Exception as e:
        error_msg = f"Error during text extraction: {str(e)}"
        yield error_msg, "", error_msg, page_info, image_to_process, slider_value
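# Each tuple yielded by process_input lines up positionally with the `outputs` list wired to
# submit_btn.click below: (output_text, medicines_output, raw_output, page_info,
# rendered_image, num_pages).
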
def update_slider(file_input):
    """Reset the page slider; for PDFs, set its maximum to the document's page count."""
    if file_input is None:
        return gr.update(maximum=20, value=1)
    file_path = file_input if isinstance(file_input, str) else file_input.name
    if file_path.lower().endswith('.pdf'):
        try:
            pdf = pdfium.PdfDocument(file_path)
            total_pages = len(pdf)
            pdf.close()
            return gr.update(maximum=total_pages, value=1)
        except Exception:
            return gr.update(maximum=20, value=1)
    else:
        return gr.update(maximum=1, value=1)

with gr.Blocks(title="💊 Medicine Extraction", theme=gr.themes.Soft()) as demo:
    file_input = gr.File(
        label="🖼️ Upload Image or PDF",
        file_types=[".pdf", ".png", ".jpg", ".jpeg"],
        type="filepath",
    )
    temperature = gr.Slider(
        minimum=0.0,
        maximum=1.0,
        value=0.2,
        step=0.05,
        label="Temperature",
    )
    page_slider = gr.Slider(
        minimum=1, maximum=20, value=1, step=1,
        label="Page Number (PDF only)",
        interactive=True,
    )
    output_text = gr.Textbox(
        label="📝 Extracted Text",
        lines=4,
        max_lines=10,
        interactive=False,
        show_copy_button=True,
    )
    medicines_output = gr.Textbox(
        label="💊 Extracted Medicines/Drugs",
        placeholder="Medicine/drug names will appear here...",
        lines=2,
        max_lines=5,
        interactive=False,
        show_copy_button=True,
    )
    raw_output = gr.Textbox(
        label="Raw Model Output",
        lines=2,
        max_lines=5,
        interactive=False,
    )
    page_info = gr.Markdown(
        value="",  # Info about the current PDF page
    )
    rendered_image = gr.Image(
        label="Processed Image (Thresholded for OCR)",
        interactive=False,
    )
    num_pages = gr.Number(
        value=1, label="Current Page (slider)", visible=False,
    )
    submit_btn = gr.Button("Extract Medicines", variant="primary")

    submit_btn.click(
        fn=process_input,
        inputs=[file_input, temperature, page_slider],
        outputs=[output_text, medicines_output, raw_output, page_info, rendered_image, num_pages],
    )
    file_input.change(
        fn=update_slider,
        inputs=[file_input],
        outputs=[page_slider],
    )

if __name__ == "__main__":
    demo.launch()

# Create Gradio interface
# with gr.Blocks(title="📖 Image/PDF OCR with LightOnOCR", theme=gr.themes.Soft()) as demo:
#     gr.Markdown(f"""
#     # 📖 Image/PDF to Text Extraction with LightOnOCR
#     **💡 How to use:**
#     1. Upload an image or PDF
#     2. For PDFs: select which page to extract (1-20)
#     3. Adjust temperature if needed
#     4. Click "Extract Text"
#     **Note:** The Markdown rendering for tables may not always be perfect. Check the raw output for complex tables!
#     **Model:** LightOnOCR-1B-1025 by LightOn AI
#     **Device:** {device.upper()}
#     **Attention:** {attn_implementation}
#     """)
#     with gr.Row():
#         with gr.Column(scale=1):
#             file_input = gr.File(
#                 label="🖼️ Upload Image or PDF",
#                 file_types=[".pdf", ".png", ".jpg", ".jpeg"],
#                 type="filepath"
#             )
#             rendered_image = gr.Image(
#                 label="📄 Preview",
#                 type="pil",
#                 height=400,
#                 interactive=False
#             )
#             num_pages = gr.Slider(
#                 minimum=1,
#                 maximum=20,
#                 value=1,
#                 step=1,
#                 label="PDF: Page Number",
#                 info="Select which page to extract"
#             )
#             page_info = gr.Textbox(
#                 label="Processing Info",
#                 value="",
#                 interactive=False
#             )
#             temperature = gr.Slider(
#                 minimum=0.0,
#                 maximum=1.0,
#                 value=0.2,
#                 step=0.05,
#                 label="Temperature",
#                 info="0.0 = deterministic, Higher = more varied"
#             )
#             submit_btn = gr.Button("Extract Text", variant="primary")
#             clear_btn = gr.Button("Clear", variant="secondary")
#         with gr.Column(scale=2):
#             output_text = gr.Markdown(
#                 label="📄 Extracted Text (Rendered)",
#                 value="*Extracted text will appear here...*"
#             )
#             medications_output = gr.Textbox(
#                 label="💊 Extracted Medicines/Drugs",
#                 placeholder="Medicine/drug names will appear here...",
#                 lines=2,
#                 max_lines=5,
#                 interactive=False,
#                 show_copy_button=True
#             )
#     with gr.Row():
#         with gr.Column():
#             raw_output = gr.Textbox(
#                 label="Raw Markdown Output",
#                 placeholder="Raw text will appear here...",
#                 lines=20,
#                 max_lines=30,
#                 show_copy_button=True
#             )
#     # Event handlers
#     submit_btn.click(
#         fn=process_input,
#         inputs=[file_input, temperature, num_pages],
#         outputs=[output_text, medications_output, raw_output, page_info, rendered_image, num_pages]
#     )
#################################### old code to be checked #############################################
# import sys
# import threading
# import spaces
# import torch
# import gradio as gr
# from PIL import Image
# from io import BytesIO
# import pypdfium2 as pdfium
# from transformers import (
#     LightOnOCRForConditionalGeneration,
#     LightOnOCRProcessor,
#     TextIteratorStreamer,
# )
# # ---- CLINICAL NER IMPORTS ----
# import spacy
#
# device = "cuda" if torch.cuda.is_available() else "cpu"
#
# # Choose best attention implementation based on device
# if device == "cuda":
#     attn_implementation = "sdpa"
#     dtype = torch.bfloat16
#     print("Using sdpa for GPU")
# else:
#     attn_implementation = "eager"  # Best for CPU
#     dtype = torch.float32
#     print("Using eager attention for CPU")
#
# # Initialize the LightOnOCR model and processor
# print(f"Loading model on {device} with {attn_implementation} attention...")
# model = LightOnOCRForConditionalGeneration.from_pretrained(
#     "lightonai/LightOnOCR-1B-1025",
#     attn_implementation=attn_implementation,
#     torch_dtype=dtype,
#     trust_remote_code=True
# ).to(device).eval()
# processor = LightOnOCRProcessor.from_pretrained(
#     "lightonai/LightOnOCR-1B-1025",
#     trust_remote_code=True
# )
# print("Model loaded successfully!")
#
# # ---- LOAD CLINICAL NER MODEL (BC5CDR) ----
# print("Loading clinical NER model (bc5cdr)...")
# nlp_ner = spacy.load("en_ner_bc5cdr_md")
# print("Clinical NER loaded.")
#
# def render_pdf_page(page, max_resolution=1540, scale=2.77):
#     """Render a PDF page to PIL Image."""
#     width, height = page.get_size()
#     pixel_width = width * scale
#     pixel_height = height * scale
#     resize_factor = min(1, max_resolution / pixel_width, max_resolution / pixel_height)
#     target_scale = scale * resize_factor
#     return page.render(scale=target_scale, rev_byteorder=True).to_pil()
#
# def process_pdf(pdf_path, page_num=1):
#     """Extract a specific page from PDF."""
#     pdf = pdfium.PdfDocument(pdf_path)
#     total_pages = len(pdf)
#     page_idx = min(max(int(page_num) - 1, 0), total_pages - 1)
#     page = pdf[page_idx]
#     img = render_pdf_page(page)
#     pdf.close()
#     return img, total_pages, page_idx + 1
#
# def clean_output_text(text):
#     """Remove chat template artifacts from output."""
#     markers_to_remove = ["system", "user", "assistant"]
#     lines = text.split('\n')
#     cleaned_lines = []
#     for line in lines:
#         stripped = line.strip()
#         # Skip lines that are just template markers
#         if stripped.lower() not in markers_to_remove:
#             cleaned_lines.append(line)
#     cleaned = '\n'.join(cleaned_lines).strip()
#     if "assistant" in text.lower():
#         parts = text.split("assistant", 1)
#         if len(parts) > 1:
#             cleaned = parts[1].strip()
#     return cleaned
#
# def extract_medication_names(text):
#     """Extract medication names using clinical NER (spacy: bc5cdr CHEMICAL)."""
#     doc = nlp_ner(text)
#     meds = [ent.text for ent in doc.ents if ent.label_ == "CHEMICAL"]
#     meds_unique = list(dict.fromkeys(meds))
#     return meds_unique
#
# @spaces.GPU
# def extract_text_from_image(image, temperature=0.2, stream=False):
#     """Extract text from image using LightOnOCR model."""
#     chat = [
#         {
#             "role": "user",
#             "content": [
#                 {"type": "image", "url": image},
#             ],
#         }
#     ]
#     inputs = processor.apply_chat_template(
#         chat,
#         add_generation_prompt=True,
#         tokenize=True,
#         return_dict=True,
#         return_tensors="pt"
#     )
#     inputs = {
#         k: v.to(device=device, dtype=dtype) if isinstance(v, torch.Tensor) and v.dtype in [torch.float32, torch.float16, torch.bfloat16]
#         else v.to(device) if isinstance(v, torch.Tensor)
#         else v
#         for k, v in inputs.items()
#     }
#     generation_kwargs = dict(
#         **inputs,
#         max_new_tokens=2048,
#         temperature=temperature if temperature > 0 else 0.0,
#         use_cache=True,
#         do_sample=temperature > 0,
#     )
#     if stream:
#         # Streaming generation
#         streamer = TextIteratorStreamer(
#             processor.tokenizer,
#             skip_prompt=True,
#             skip_special_tokens=True
#         )
#         generation_kwargs["streamer"] = streamer
#         thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
#         thread.start()
#         full_text = ""
#         for new_text in streamer:
#             full_text += new_text
#             cleaned_text = clean_output_text(full_text)
#             yield cleaned_text
#         thread.join()
#     else:
#         # Non-streaming generation
#         with torch.no_grad():
#             outputs = model.generate(**generation_kwargs)
#         output_text = processor.decode(outputs[0], skip_special_tokens=True)
#         cleaned_text = clean_output_text(output_text)
#         yield cleaned_text
#
# def process_input(file_input, temperature, page_num, enable_streaming):
#     """Process uploaded file (image or PDF) and extract medication names via OCR+NER."""
#     if file_input is None:
#         yield "Please upload an image or PDF first.", "", "", None, gr.update()
#         return
#     image_to_process = None
#     page_info = ""
#     file_path = file_input if isinstance(file_input, str) else file_input.name
#     # Handle PDF files
#     if file_path.lower().endswith('.pdf'):
#         try:
#             image_to_process, total_pages, actual_page = process_pdf(file_path, int(page_num))
#             page_info = f"Processing page {actual_page} of {total_pages}"
#         except Exception as e:
#             yield f"Error processing PDF: {str(e)}", "", "", None, gr.update()
#             return
#     # Handle image files
#     else:
#         try:
#             image_to_process = Image.open(file_path)
#             page_info = "Processing image"
#         except Exception as e:
#             yield f"Error opening image: {str(e)}", "", "", None, gr.update()
#             return
#     try:
#         for extracted_text in extract_text_from_image(image_to_process, temperature, stream=enable_streaming):
#             meds = extract_medication_names(extracted_text)
#             meds_str = "\n".join(meds) if meds else "No medications found."
#             yield meds_str, meds_str, page_info, image_to_process, gr.update()
#     except Exception as e:
#         error_msg = f"Error during text extraction: {str(e)}"
#         yield error_msg, error_msg, page_info, image_to_process, gr.update()
#
# def update_slider(file_input):
#     """Update page slider based on PDF page count."""
#     if file_input is None:
#         return gr.update(maximum=20, value=1)
#     file_path = file_input if isinstance(file_input, str) else file_input.name
#     if file_path.lower().endswith('.pdf'):
#         try:
#             pdf = pdfium.PdfDocument(file_path)
#             total_pages = len(pdf)
#             pdf.close()
#             return gr.update(maximum=total_pages, value=1)
#         except:
#             return gr.update(maximum=20, value=1)
#     else:
#         return gr.update(maximum=1, value=1)
#
# # ----- GRADIO UI -----
# with gr.Blocks(title="📖 Image/PDF OCR + Clinical NER", theme=gr.themes.Soft()) as demo:
#     gr.Markdown(f"""
#     # 📖 Medication Extraction from Image/PDF with LightOnOCR + Clinical NER
#     **💡 How to use:**
#     1. Upload an image or PDF
#     2. For PDFs: select which page to extract
#     3. Adjust temperature if needed
#     4. Click "Extract Medications"
#     **Output:** Only medication names found in text (via NER)
#     **Model:** LightOnOCR-1B-1025 by LightOn AI
#     **Device:** {device.upper()}
#     **Attention:** {attn_implementation}
#     """)
#     with gr.Row():
#         with gr.Column(scale=1):
#             file_input = gr.File(
#                 label="🖼️ Upload Image or PDF",
#                 file_types=[".pdf", ".png", ".jpg", ".jpeg"],
#                 type="filepath"
#             )
#             rendered_image = gr.Image(
#                 label="📄 Preview",
#                 type="pil",
#                 height=400,
#                 interactive=False
#             )
#             num_pages = gr.Slider(
#                 minimum=1,
#                 maximum=20,
#                 value=1,
#                 step=1,
#                 label="PDF: Page Number",
#                 info="Select which page to extract"
#             )
#             page_info = gr.Textbox(
#                 label="Processing Info",
#                 value="",
#                 interactive=False
#             )
#             temperature = gr.Slider(
#                 minimum=0.0,
#                 maximum=1.0,
#                 value=0.2,
#                 step=0.05,
#                 label="Temperature",
#                 info="0.0 = deterministic, Higher = more varied"
#             )
#             enable_streaming = gr.Checkbox(
#                 label="Enable Streaming",
#                 value=True,
#                 info="Show text progressively as it's generated"
#             )
#             submit_btn = gr.Button("Extract Medications", variant="primary")
#             clear_btn = gr.Button("Clear", variant="secondary")
#         with gr.Column(scale=2):
#             output_text = gr.Markdown(
#                 label="🩺 Extracted Medication Names",
#                 value="*Medication names will appear here...*"
#             )
#     with gr.Row():
#         with gr.Column():
#             raw_output = gr.Textbox(
#                 label="Extracted Medication Names (Raw)",
#                 placeholder="Medication list will appear here...",
#                 lines=20,
#                 max_lines=30,
#                 show_copy_button=True
#             )
#     # Event handlers
#     submit_btn.click(
#         fn=process_input,
#         inputs=[file_input, temperature, num_pages, enable_streaming],
#         outputs=[output_text, raw_output, page_info, rendered_image, num_pages]
#     )
#     file_input.change(
#         fn=update_slider,
#         inputs=[file_input],
#         outputs=[num_pages]
#     )
#     clear_btn.click(
#         fn=lambda: (None, "*Medication names will appear here...*", "", "", None, 1),
#         outputs=[file_input, output_text, raw_output, page_info, rendered_image, num_pages]
#     )
#
# if __name__ == "__main__":
#     demo.launch()