SkyNait committed on
Commit
bc4eaf5
·
1 Parent(s): a127a50

Correct page range handling

Browse files
Files changed (1) hide show
  1. topic_extraction.py +8 -75
topic_extraction.py CHANGED
@@ -14,41 +14,24 @@ import torch
14
  import cv2
15
  import numpy as np
16
 
17
- # Attempt top-level import of google.genai
18
- try:
19
- from google import genai
20
- from google.genai import types
21
- except ImportError:
22
- genai = None
23
- types = None
24
-
25
- # magic-pdf imports
26
  from magic_pdf.data.dataset import PymuDocDataset
27
  from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
28
 
29
- # table extraction logic
30
  from table_row_extraction import TableExtractor
31
 
32
- ###############################################################################
33
- # Logging Setup
34
- ###############################################################################
35
  logging.basicConfig(level=logging.INFO)
36
  logger = logging.getLogger(__name__)
37
  logger.setLevel(logging.INFO)
38
 
39
- ###############################################################################
40
- # PDF Utility Functions
41
- ###############################################################################
42
  def unify_whitespace(text: str) -> str:
43
- """
44
- Replace runs of whitespace with a single space, strip leading/trailing, then lowercase.
45
- """
46
  return re.sub(r"\s+", " ", text).strip().lower()
47
 
48
  def create_subset_pdf(original_pdf_bytes: bytes, page_indices: List[int]) -> bytes:
49
  """
50
- Creates a new PDF (in memory) containing only pages in page_indices (0-based).
51
- Raises ValueError if page_indices is empty or out of range.
52
  """
53
  if not page_indices:
54
  raise ValueError("No page indices provided for subset creation.")
@@ -66,9 +49,6 @@ def create_subset_pdf(original_pdf_bytes: bytes, page_indices: List[int]) -> byt
66
  doc.close()
67
  return subset_bytes
68
 
69
- ###############################################################################
70
- # Searching in PDF
71
- ###############################################################################
72
  def find_all_occurrences(pdf_bytes: bytes, search_text: str) -> List[int]:
73
  """
74
  Return a sorted list of 0-based pages in which `search_text` (normalized) appears,
@@ -85,46 +65,20 @@ def find_all_occurrences(pdf_bytes: bytes, search_text: str) -> List[int]:
85
  doc.close()
86
  return sorted(found)
87
 
88
- ###############################################################################
89
- # Gemini LLM for Subtopic Extraction
90
- ###############################################################################
91
  class GeminiTopicExtractor:
92
- """
93
- Extract subtopics from the PDF by reading the first `num_pages` pages, calling Gemini.
94
- We expect a structure like:
95
- {
96
- "2 Subject content and assessment information": {
97
- "Paper 1 and Paper 2: Pure Mathematics": [11, 29],
98
- "Paper 3: Statistics and Mechanics": [30, 42]
99
- }
100
- }
101
- or sometimes just a flat dict:
102
- {
103
- "Paper 1 and Paper 2: Pure Mathematics": [15, 33],
104
- "Paper 3: Statistics and Mechanics": [34, 46]
105
- }
106
- We'll parse both forms.
107
- """
108
  def __init__(self, api_key: str = None, num_pages: int = 10):
109
  self.api_key = api_key or os.getenv("GEMINI_API_KEY", "")
110
- if not self.api_key:
111
- logger.warning("No Gemini API key for subtopic extraction.")
112
  self.num_pages = num_pages
113
 
114
  def extract_subtopics(self, pdf_path: str) -> Dict[str, List[int]]:
115
  """
116
  Return a dict of subtopics => [start_page, end_page].
117
- Could be empty if parsing fails or the LLM can't find subtopics.
118
  """
119
  first_pages_text = self._read_first_pages_raw(pdf_path, self.num_pages)
120
  if not first_pages_text.strip():
121
  logger.error("No text from first pages => cannot extract subtopics.")
122
  return {}
123
 
124
- if genai is None or types is None:
125
- logger.warning("google.genai not installed. Returning empty subtopics.")
126
- return {}
127
-
128
  prompt = f"""
129
  You have the first pages of a PDF specification, including a table of contents.
130
 
@@ -229,15 +183,14 @@ Now, extract topics from this text:
229
  return {}
230
 
231
  raw_json = response.text.strip()
232
- # Clean up triple backticks
233
  cleaned = raw_json.replace("```json", "").replace("```", "")
234
 
235
  # Attempt to parse
236
  data = json.loads(cleaned)
237
  # data might be nested or flat
238
- # if nested, e.g. {"2 Subject content": {"Paper 1...": [11,29]}}
239
- # if flat, e.g. {"Paper 1...": [11,29]}
240
- # We'll unify it to a single dict of subname => [start,end].
241
  final_dict = {}
242
 
243
  # If the top-level is a dict of dict
@@ -254,7 +207,6 @@ Now, extract topics from this text:
254
  break
255
 
256
  if found_sub_dict is not None:
257
- # parse found_sub_dict
258
  for subk, rng in found_sub_dict.items():
259
  if isinstance(rng, list) and len(rng) == 2:
260
  final_dict[subk] = rng
@@ -283,12 +235,9 @@ Now, extract topics from this text:
283
  logger.error(f"Could not open PDF: {e}")
284
  return "\n".join(text_parts)
285
 
286
- ###############################################################################
287
- # Concurrency for Table Classification
288
- ###############################################################################
289
  def call_gemini_for_table_classification(image_data: bytes, api_key: str) -> str:
290
  """
291
- Classify an image as TWO_COLUMN, THREE_COLUMN, or NO_TABLE using Gemini.
292
  """
293
  if not api_key:
294
  logger.warning("No Gemini API key => NO_TABLE.")
@@ -354,9 +303,6 @@ NO_TABLE
354
  logger.error(f"Gemini table classification error: {e}")
355
  return "NO_TABLE"
356
 
357
- ###############################################################################
358
- # LocalImageWriter
359
- ###############################################################################
360
  class LocalImageWriter:
361
  """
362
  Writes extracted images, then does concurrency-based table classification calls.
@@ -476,17 +422,7 @@ class LocalImageWriter:
476
 
477
  return md_content
478
 
479
- ###############################################################################
480
- # MineruNoTextProcessor
481
- ###############################################################################
482
  class MineruNoTextProcessor:
483
- """
484
- 1) Use Gemini to get subtopics => e.g. {"Paper 1 and Paper 2: Pure Mathematics": [11,29], ...}
485
- 2) For each subtopic name => find real occurrence in PDF at or after (start_page-1).
486
- 3) offset = occurrence_page - (start_page-1). clamp offset >= 0
487
- 4) Flatten final pages, subset PDF, run magic-pdf => concurrency => final MD
488
- 5) If no subtopics found, process entire PDF as fallback.
489
- """
490
  def __init__(self, output_folder: str, gemini_api_key: str = None):
491
  self.output_folder = output_folder
492
  os.makedirs(self.output_folder, exist_ok=True)
@@ -604,9 +540,6 @@ class MineruNoTextProcessor:
604
  finally:
605
  self.cleanup_gpu()
606
 
607
- ###############################################################################
608
- # Example Main
609
- ###############################################################################
610
  if __name__ == "__main__":
611
  input_pdf = "/home/user/app/input_output/ocr-specification-economics.pdf"
612
  output_dir = "/home/user/app/outputs"
 
14
  import cv2
15
  import numpy as np
16
 
17
+ from google import genai
18
+ from google.genai import types
19
+
 
 
 
 
 
 
20
  from magic_pdf.data.dataset import PymuDocDataset
21
  from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
22
 
 
23
  from table_row_extraction import TableExtractor
24
 
 
 
 
25
  logging.basicConfig(level=logging.INFO)
26
  logger = logging.getLogger(__name__)
27
  logger.setLevel(logging.INFO)
28
 
 
 
 
29
  def unify_whitespace(text: str) -> str:
 
 
 
30
  return re.sub(r"\s+", " ", text).strip().lower()
31
 
32
  def create_subset_pdf(original_pdf_bytes: bytes, page_indices: List[int]) -> bytes:
33
  """
34
+ Creates a new PDF (in memory) containing only pages from page_indices (0-based).
 
35
  """
36
  if not page_indices:
37
  raise ValueError("No page indices provided for subset creation.")
 
49
  doc.close()
50
  return subset_bytes
51
 
 
 
 
52
  def find_all_occurrences(pdf_bytes: bytes, search_text: str) -> List[int]:
53
  """
54
  Return a sorted list of 0-based pages in which `search_text` (normalized) appears,
 
65
  doc.close()
66
  return sorted(found)
67
 
 
 
 
68
  class GeminiTopicExtractor:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  def __init__(self, api_key: str = None, num_pages: int = 10):
70
  self.api_key = api_key or os.getenv("GEMINI_API_KEY", "")
 
 
71
  self.num_pages = num_pages
72
 
73
  def extract_subtopics(self, pdf_path: str) -> Dict[str, List[int]]:
74
  """
75
  Return a dict of subtopics => [start_page, end_page].
 
76
  """
77
  first_pages_text = self._read_first_pages_raw(pdf_path, self.num_pages)
78
  if not first_pages_text.strip():
79
  logger.error("No text from first pages => cannot extract subtopics.")
80
  return {}
81
 
 
 
 
 
82
  prompt = f"""
83
  You have the first pages of a PDF specification, including a table of contents.
84
 
 
183
  return {}
184
 
185
  raw_json = response.text.strip()
 
186
  cleaned = raw_json.replace("```json", "").replace("```", "")
187
 
188
  # Attempt to parse
189
  data = json.loads(cleaned)
190
  # data might be nested or flat
191
+ # if nested, example {"2 Subject content": {"Paper 1...": [11,29]}}
192
+ # if flat, example {"Paper 1...": [11,29]}
193
+ # so we unify it to a single dict of subname => [start,end].
194
  final_dict = {}
195
 
196
  # If the top-level is a dict of dict
 
207
  break
208
 
209
  if found_sub_dict is not None:
 
210
  for subk, rng in found_sub_dict.items():
211
  if isinstance(rng, list) and len(rng) == 2:
212
  final_dict[subk] = rng
 
235
  logger.error(f"Could not open PDF: {e}")
236
  return "\n".join(text_parts)
237
 
 
 
 
238
  def call_gemini_for_table_classification(image_data: bytes, api_key: str) -> str:
239
  """
240
+ Classify an image as TWO_COLUMN, THREE_COLUMN, or NO_TABLE
241
  """
242
  if not api_key:
243
  logger.warning("No Gemini API key => NO_TABLE.")
 
303
  logger.error(f"Gemini table classification error: {e}")
304
  return "NO_TABLE"
305
 
 
 
 
306
  class LocalImageWriter:
307
  """
308
  Writes extracted images, then does concurrency-based table classification calls.
 
422
 
423
  return md_content
424
 
 
 
 
425
  class MineruNoTextProcessor:
 
 
 
 
 
 
 
426
  def __init__(self, output_folder: str, gemini_api_key: str = None):
427
  self.output_folder = output_folder
428
  os.makedirs(self.output_folder, exist_ok=True)
 
540
  finally:
541
  self.cleanup_gpu()
542
 
 
 
 
543
  if __name__ == "__main__":
544
  input_pdf = "/home/user/app/input_output/ocr-specification-economics.pdf"
545
  output_dir = "/home/user/app/outputs"