Correct page range handling

topic_extraction.py (+8 −75)
@@ -14,41 +14,24 @@ import torch
 import cv2
 import numpy as np
 
-
-try:
-    from google import genai
-    from google.genai import types
-except ImportError:
-    genai = None
-    types = None
-
-# magic-pdf imports
+from google import genai
+from google.genai import types
+
 from magic_pdf.data.dataset import PymuDocDataset
 from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
 
-# table extraction logic
 from table_row_extraction import TableExtractor
 
-###############################################################################
-# Logging Setup
-###############################################################################
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
 
-###############################################################################
-# PDF Utility Functions
-###############################################################################
 def unify_whitespace(text: str) -> str:
-    """
-    Replace runs of whitespace with a single space, strip leading/trailing, then lowercase.
-    """
     return re.sub(r"\s+", " ", text).strip().lower()
 
 def create_subset_pdf(original_pdf_bytes: bytes, page_indices: List[int]) -> bytes:
     """
-    Creates a new PDF (in memory) containing only pages
-    Raises ValueError if page_indices is empty or out of range.
+    Creates a new PDF (in memory) containing only pages from page_indices (0-based).
     """
     if not page_indices:
         raise ValueError("No page indices provided for subset creation.")
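For reference, the in-memory subset step can be done with PyMuPDF's `insert_pdf`. A minimal sketch of what `create_subset_pdf` plausibly does under that assumption (its body sits outside the visible hunks):

```python
import fitz  # PyMuPDF

def create_subset_pdf_sketch(original_pdf_bytes: bytes, page_indices: list) -> bytes:
    # Open the source document directly from memory.
    src = fitz.open(stream=original_pdf_bytes, filetype="pdf")
    if any(i < 0 or i >= src.page_count for i in page_indices):
        raise ValueError("Page index out of range.")
    # Copy the selected 0-based pages into a fresh in-memory document.
    out = fitz.open()
    for i in page_indices:
        out.insert_pdf(src, from_page=i, to_page=i)
    subset_bytes = out.tobytes()
    out.close()
    src.close()
    return subset_bytes
```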
@@ -66,9 +49,6 @@ def create_subset_pdf(original_pdf_bytes: bytes, page_indices: List[int]) -> bytes:
     doc.close()
     return subset_bytes
 
-###############################################################################
-# Searching in PDF
-###############################################################################
 def find_all_occurrences(pdf_bytes: bytes, search_text: str) -> List[int]:
     """
     Return a sorted list of 0-based pages in which `search_text` (normalized) appears,
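The search builds on `unify_whitespace` so that layout whitespace and casing do not break matches. A hedged sketch of the approach, assuming PyMuPDF's `page.get_text()` supplies the page text:

```python
import fitz  # PyMuPDF

def find_all_occurrences_sketch(pdf_bytes: bytes, search_text: str) -> list:
    needle = unify_whitespace(search_text)  # normalized, lowercased
    doc = fitz.open(stream=pdf_bytes, filetype="pdf")
    found = set()
    for i, page in enumerate(doc):
        # Normalize page text the same way before substring matching.
        if needle in unify_whitespace(page.get_text()):
            found.add(i)
    doc.close()
    return sorted(found)
```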
@@ -85,46 +65,20 @@ def find_all_occurrences(pdf_bytes: bytes, search_text: str) -> List[int]:
     doc.close()
     return sorted(found)
 
-###############################################################################
-# Gemini LLM for Subtopic Extraction
-###############################################################################
 class GeminiTopicExtractor:
-    """
-    Extract subtopics from the PDF by reading the first `num_pages` pages, calling Gemini.
-    We expect a structure like:
-    {
-      "2 Subject content and assessment information": {
-         "Paper 1 and Paper 2: Pure Mathematics": [11, 29],
-         "Paper 3: Statistics and Mechanics": [30, 42]
-      }
-    }
-    or sometimes just a flat dict:
-    {
-      "Paper 1 and Paper 2: Pure Mathematics": [15, 33],
-      "Paper 3: Statistics and Mechanics": [34, 46]
-    }
-    We'll parse both forms.
-    """
     def __init__(self, api_key: str = None, num_pages: int = 10):
         self.api_key = api_key or os.getenv("GEMINI_API_KEY", "")
-        if not self.api_key:
-            logger.warning("No Gemini API key for subtopic extraction.")
         self.num_pages = num_pages
 
     def extract_subtopics(self, pdf_path: str) -> Dict[str, List[int]]:
         """
         Return a dict of subtopics => [start_page, end_page].
-        Could be empty if parsing fails or the LLM can't find subtopics.
         """
         first_pages_text = self._read_first_pages_raw(pdf_path, self.num_pages)
         if not first_pages_text.strip():
             logger.error("No text from first pages => cannot extract subtopics.")
             return {}
 
-        if genai is None or types is None:
-            logger.warning("google.genai not installed. Returning empty subtopics.")
-            return {}
-
         prompt = f"""
 You have the first pages of a PDF specification, including a table of contents.
 
@@ -229,15 +183,14 @@ Now, extract topics from this text:
             return {}
 
         raw_json = response.text.strip()
-        # Clean up triple backticks
         cleaned = raw_json.replace("```json", "").replace("```", "")
 
         # Attempt to parse
         data = json.loads(cleaned)
         # data might be nested or flat
-        # if nested,
-        # if flat,
-        #
+        # if nested, example {"2 Subject content": {"Paper 1...": [11,29]}}
+        # if flat, example {"Paper 1...": [11,29]}
+        # so we unify it to a single dict of subname => [start,end].
         final_dict = {}
 
         # If the top-level is a dict of dict
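The restored comments summarize the normalization. A small sketch of the nested-vs-flat unification they describe (an illustration; only fragments of the real parsing loop appear in these hunks):

```python
def unify_subtopics_sketch(data: dict) -> dict:
    # Flat form:   {"Paper 1 ...": [11, 29]} -- values are [start, end] pairs.
    # Nested form: {"2 Subject content ...": {"Paper 1 ...": [11, 29]}}.
    final_dict = {}
    for key, value in data.items():
        if isinstance(value, dict):
            # Nested: pull the ranges out of the inner dict.
            for subk, rng in value.items():
                if isinstance(rng, list) and len(rng) == 2:
                    final_dict[subk] = rng
        elif isinstance(value, list) and len(value) == 2:
            # Already flat.
            final_dict[key] = value
    return final_dict
```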
@@ -254,7 +207,6 @@ Now, extract topics from this text:
                     break
 
         if found_sub_dict is not None:
-            # parse found_sub_dict
             for subk, rng in found_sub_dict.items():
                 if isinstance(rng, list) and len(rng) == 2:
                     final_dict[subk] = rng
@@ -283,12 +235,9 @@ Now, extract topics from this text:
             logger.error(f"Could not open PDF: {e}")
         return "\n".join(text_parts)
 
-###############################################################################
-# Concurrency for Table Classification
-###############################################################################
 def call_gemini_for_table_classification(image_data: bytes, api_key: str) -> str:
     """
-    Classify an image as TWO_COLUMN, THREE_COLUMN, or NO_TABLE
+    Classify an image as TWO_COLUMN, THREE_COLUMN, or NO_TABLE
     """
     if not api_key:
         logger.warning("No Gemini API key => NO_TABLE.")
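With `google.genai` now imported unconditionally, a classification call might look like the sketch below. The model name and prompt are placeholders, not values taken from this file:

```python
from google import genai
from google.genai import types

def classify_table_sketch(image_data: bytes, api_key: str) -> str:
    client = genai.Client(api_key=api_key)
    prompt = "Reply with exactly one of: TWO_COLUMN, THREE_COLUMN, NO_TABLE."
    response = client.models.generate_content(
        model="gemini-2.0-flash",  # placeholder model name
        contents=[
            types.Part.from_bytes(data=image_data, mime_type="image/png"),
            prompt,
        ],
    )
    answer = response.text.strip().upper()
    # Fall back to NO_TABLE on anything unexpected.
    return answer if answer in {"TWO_COLUMN", "THREE_COLUMN", "NO_TABLE"} else "NO_TABLE"
```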
@@ -354,9 +303,6 @@ NO_TABLE
         logger.error(f"Gemini table classification error: {e}")
         return "NO_TABLE"
 
-###############################################################################
-# LocalImageWriter
-###############################################################################
 class LocalImageWriter:
     """
     Writes extracted images, then does concurrency-based table classification calls.
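The docstring mentions concurrency-based classification calls, but the executor code sits outside the visible hunks; one plausible shape, assuming a thread pool over `call_gemini_for_table_classification`:

```python
from concurrent.futures import ThreadPoolExecutor

def classify_all_sketch(images: dict, api_key: str, max_workers: int = 8) -> dict:
    # images: {relative_path: png_bytes}. The Gemini calls are I/O bound,
    # so threads are a reasonable fit for issuing them in parallel.
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        futures = {
            path: pool.submit(call_gemini_for_table_classification, data, api_key)
            for path, data in images.items()
        }
        return {path: fut.result() for path, fut in futures.items()}
```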
@@ -476,17 +422,7 @@ class LocalImageWriter:
 
         return md_content
 
-###############################################################################
-# MineruNoTextProcessor
-###############################################################################
 class MineruNoTextProcessor:
-    """
-    1) Use Gemini to get subtopics => e.g. {"Paper 1 and Paper 2: Pure Mathematics": [11,29], ...}
-    2) For each subtopic name => find real occurrence in PDF at or after (start_page-1).
-    3) offset = occurrence_page - (start_page-1). clamp offset >= 0
-    4) Flatten final pages, subset PDF, run magic-pdf => concurrency => final MD
-    5) If no subtopics found, process entire PDF as fallback.
-    """
     def __init__(self, output_folder: str, gemini_api_key: str = None):
         self.output_folder = output_folder
         os.makedirs(self.output_folder, exist_ok=True)
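The removed docstring still documents the page-range logic this commit is correcting: find the heading's real occurrence at or after the TOC's 1-based start page, then shift the whole range by a clamped offset. A worked sketch of steps 2-3 under that reading (a hypothetical helper; the updated method body is not in the visible hunks):

```python
def resolve_page_range_sketch(pdf_bytes: bytes, name: str,
                              start_page: int, end_page: int) -> list:
    # TOC page numbers are 1-based; convert to a 0-based guess.
    guess = start_page - 1
    # Find where the subtopic heading actually occurs at or after the guess.
    occurrences = [p for p in find_all_occurrences(pdf_bytes, name) if p >= guess]
    # offset = occurrence_page - (start_page - 1), clamped at zero.
    offset = max(0, occurrences[0] - guess) if occurrences else 0
    # Shift the whole 0-based range by the clamped offset.
    return list(range(guess + offset, (end_page - 1) + offset + 1))
```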
@@ -604,9 +540,6 @@ class MineruNoTextProcessor:
         finally:
             self.cleanup_gpu()
 
-###############################################################################
-# Example Main
-###############################################################################
 if __name__ == "__main__":
     input_pdf = "/home/user/app/input_output/ocr-specification-economics.pdf"
     output_dir = "/home/user/app/outputs"