SkyNait committed on
Commit
da9ad0b
·
verified ·
1 Parent(s): ee099e1

add comments for understanding

Browse files
Files changed (1) hide show
  1. topic_extraction.py +41 -66
topic_extraction.py CHANGED
@@ -26,16 +26,15 @@ logging.basicConfig(level=logging.INFO)
26
  logger = logging.getLogger(__name__)
27
  logger.setLevel(logging.INFO)
28
 
 
 
 
29
  def unify_whitespace(text: str) -> str:
30
  return re.sub(r"\s+", " ", text).strip().lower()
31
 
32
  def create_subset_pdf(original_pdf_bytes: bytes, page_indices: List[int]) -> bytes:
33
- """
34
- Creates a new PDF (in memory) containing only pages from page_indices (0-based).
35
- """
36
  if not page_indices:
37
  raise ValueError("No page indices provided for subset creation.")
38
-
39
  doc = fitz.open(stream=original_pdf_bytes, filetype="pdf")
40
  new_doc = fitz.open()
41
  for p in sorted(set(page_indices)):
@@ -78,7 +77,6 @@ class GeminiTopicExtractor:
78
  if not first_pages_text.strip():
79
  logger.error("No text from first pages => cannot extract subtopics.")
80
  return {}
81
-
82
  prompt = f"""
83
  You have the first pages of a PDF specification, including a table of contents.
84
 
@@ -162,17 +160,20 @@ Appendix 5: Index – 63
162
  The correct output should be:
163
 
164
  {{
165
- "Theme 1: Introduction to markets and market failure": [5, 10]
166
- "Theme 2: The UK economy – performance and policies": [11, 20]
167
- "Theme 3: Business behaviour and the labour market": [21, 28]
168
  "Theme 4: A global perspective": [29, 38]
169
  }}
170
 
171
  Now, extract topics from this text:
172
  {first_pages_text}
173
  """
 
 
 
 
174
  try:
175
- client = genai.Client(api_key=self.api_key)
176
  response = client.models.generate_content(
177
  model="gemini-2.0-flash",
178
  contents=[prompt],
@@ -181,7 +182,7 @@ Now, extract topics from this text:
181
  if not response or not response.text:
182
  logger.warning("No text from LLM => returning empty subtopics.")
183
  return {}
184
-
185
  raw_json = response.text.strip()
186
  cleaned = raw_json.replace("```json", "").replace("```", "")
187
 
@@ -205,7 +206,6 @@ Now, extract topics from this text:
205
  # might be the sub-sub dict
206
  found_sub_dict = v
207
  break
208
-
209
  if found_sub_dict is not None:
210
  for subk, rng in found_sub_dict.items():
211
  if isinstance(rng, list) and len(rng) == 2:
@@ -216,7 +216,6 @@ Now, extract topics from this text:
216
  for subk, rng in data.items():
217
  if isinstance(rng, list) and len(rng) == 2:
218
  final_dict[subk] = rng
219
-
220
  return final_dict
221
  except Exception as e:
222
  logger.error(f"Gemini subtopic extraction error: {e}")
@@ -239,14 +238,7 @@ def call_gemini_for_table_classification(image_data: bytes, api_key: str) -> str
239
  """
240
  Classify an image as TWO_COLUMN, THREE_COLUMN, or NO_TABLE
241
  """
242
- if not api_key:
243
- logger.warning("No Gemini API key => NO_TABLE.")
244
- return "NO_TABLE"
245
- if genai is None or types is None:
246
- logger.warning("google.genai not installed => NO_TABLE.")
247
- return "NO_TABLE"
248
-
249
- # Attempt to shrink
250
  try:
251
  arr = np.frombuffer(image_data, np.uint8)
252
  img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
@@ -266,7 +258,7 @@ def call_gemini_for_table_classification(image_data: bytes, api_key: str) -> str
266
  image_data = enc.tobytes()
267
  except Exception as e:
268
  logger.warning(f"shrink_image_to_jpeg error: {e}")
269
-
270
  prompt = """You are given an image. Determine if it shows a table that has exactly 2 or 3 columns.
271
  The three-column 'table' image include such key features:
272
  - Three columns header columns
@@ -287,8 +279,11 @@ TWO_COLUMN
287
  THREE_COLUMN
288
  NO_TABLE
289
  """
 
 
 
 
290
  try:
291
- client = genai.Client(api_key=api_key)
292
  resp = client.models.generate_content(
293
  model="gemini-2.0-flash",
294
  contents=[
@@ -324,10 +319,8 @@ class LocalImageWriter:
324
  def __init__(self, output_folder: str, gemini_api_key: str):
325
  self.output_folder = output_folder
326
  os.makedirs(self.output_folder, exist_ok=True)
327
-
328
  self.images_dir = os.path.join(self.output_folder, "images")
329
  os.makedirs(self.images_dir, exist_ok=True)
330
-
331
  self.descriptions = {}
332
  self._img_count = 0
333
  self.gemini_api_key = gemini_api_key
@@ -349,11 +342,7 @@ class LocalImageWriter:
349
  def post_process(self, key: str, md_content: str) -> str:
350
  logger.info("Classifying images to detect tables (concurrent)...")
351
  with concurrent.futures.ThreadPoolExecutor(max_workers=6) as exe:
352
- fut_map = {}
353
- for p, info in self.descriptions.items():
354
- fut = exe.submit(call_gemini_for_table_classification, info["data"], self.gemini_api_key)
355
- fut_map[fut] = p
356
-
357
  for fut in concurrent.futures.as_completed(fut_map):
358
  path = fut_map[fut]
359
  try:
@@ -362,8 +351,7 @@ class LocalImageWriter:
362
  except Exception as e:
363
  logger.error(f"Table classification error: {e}")
364
  self.descriptions[path]['table_classification'] = "NO_TABLE"
365
-
366
- # 2) Set final alt text
367
  for p, info in self.descriptions.items():
368
  cls = info['table_classification']
369
  if cls == "TWO_COLUMN":
@@ -372,22 +360,21 @@ class LocalImageWriter:
372
  info['final_alt'] = "HAS TO BE PROCESSED - three column table"
373
  else:
374
  info['final_alt'] = "NO_TABLE image"
375
-
376
- # 3) Replace placeholders in the Markdown
377
  for p, info in self.descriptions.items():
378
  old_md = f"![]({key}{p})"
379
  new_md = f"![{info['final_alt']}]({info['relative_path']})"
380
  md_content = md_content.replace(old_md, new_md)
381
-
382
- # 4) If any table images => extract rows
383
  md_content = self._process_table_images_in_markdown(md_content)
384
-
385
- # 5) Keep only lines that are image references
386
  final_lines = []
387
  for line in md_content.split("\n"):
388
  if re.match(r"^\!\[.*\]\(.*\)", line.strip()):
389
  final_lines.append(line.strip())
390
-
391
  return "\n".join(final_lines)
392
 
393
  def _process_table_images_in_markdown(self, md_content: str) -> str:
@@ -395,7 +382,7 @@ class LocalImageWriter:
395
  matches = re.findall(pat, md_content, flags=re.IGNORECASE)
396
  if not matches:
397
  return md_content
398
-
399
  for (col_type, image_path) in matches:
400
  logger.info(f"Processing table image => {image_path}, columns={col_type}")
401
  abs_image_path = os.path.join(self.output_folder, image_path)
@@ -418,7 +405,7 @@ class LocalImageWriter:
418
  out_folder = abs_image_path + "_rows"
419
  os.makedirs(out_folder, exist_ok=True)
420
  extractor.save_extracted_cells(abs_image_path, row_boxes, out_folder)
421
-
422
  snippet = ["**Extracted table cells:**"]
423
  for i, row in enumerate(row_boxes):
424
  row_dir = os.path.join(out_folder, f"row_{i}")
@@ -427,26 +414,22 @@ class LocalImageWriter:
427
  cell_path = os.path.join(row_dir, cell_file)
428
  relp = os.path.relpath(cell_path, self.output_folder)
429
  snippet.append(f"![Row {i} Col {j}]({relp})")
430
-
431
  new_snip = "\n".join(snippet)
432
  old_line = f"![HAS TO BE PROCESSED - {col_type} column table]({image_path})"
433
  md_content = md_content.replace(old_line, new_snip)
434
  except Exception as e:
435
  logger.error(f"Error processing table image {image_path}: {e}")
436
-
437
  return md_content
438
 
439
  class MineruNoTextProcessor:
440
  def __init__(self, output_folder: str, gemini_api_key: str = None):
441
  self.output_folder = output_folder
442
  os.makedirs(self.output_folder, exist_ok=True)
443
-
444
  self.layout_model = "doclayout_yolo"
445
  self.formula_enable = True
446
  self.table_enable = False
447
  self.language = "en"
448
-
449
- # Use our new flexible approach
450
  self.subtopic_extractor = GeminiTopicExtractor(api_key=gemini_api_key, num_pages=10)
451
  self.gemini_api_key = gemini_api_key or os.getenv("GEMINI_API_KEY", "")
452
 
@@ -459,19 +442,18 @@ class MineruNoTextProcessor:
459
  logger.error(f"Error during GPU cleanup: {e}")
460
 
461
  def process(self, pdf_path: str) -> str:
462
- logger.info(f"Processing PDF: {pdf_path}")
463
  try:
464
- # 1) Extract subtopics from Gemini
465
  subtopics = self.subtopic_extractor.extract_subtopics(pdf_path)
466
  logger.info(f"Gemini returned subtopics: {subtopics}")
467
-
468
- # 2) Read entire PDF
469
  with open(pdf_path, "rb") as f:
470
  pdf_bytes = f.read()
471
  doc = fitz.open(stream=pdf_bytes, filetype="pdf")
472
  total_pages = doc.page_count
473
  doc.close()
474
-
475
  final_pages = set()
476
  if not subtopics:
477
  logger.warning("No subtopics found. We'll process the entire PDF as fallback.")
@@ -490,7 +472,6 @@ class MineruNoTextProcessor:
490
  # find occurrences
491
  occs = find_all_occurrences(pdf_bytes, subname)
492
  logger.info(f"Occurrences of subtopic '{subname}': {occs}")
493
-
494
  doc_start_0 = start_p - 1
495
  chosen_page = None
496
  for p in occs:
@@ -505,27 +486,25 @@ class MineruNoTextProcessor:
505
  else:
506
  chosen_page = 0
507
  logger.warning(f"No occurrences for '{subname}'. Using page 0.")
508
-
509
  raw_offset = chosen_page - doc_start_0
510
  offset = max(0, raw_offset)
511
  logger.info(f"Subtopic '{subname}': doc_start={start_p}, chosen_page={chosen_page}, raw_offset={raw_offset}, offset={offset}")
512
-
513
  s0 = (start_p - 1) + offset
514
  e0 = (end_p - 1) + offset
515
  s0 = max(0, min(total_pages - 1, s0))
516
  e0 = max(0, min(total_pages - 1, e0))
517
  for pp in range(s0, e0 + 1):
518
  final_pages.add(pp)
519
-
520
- # 3) If final_pages is empty => fallback entire PDF
521
  if not final_pages:
522
  logger.warning("No valid pages after offset. We'll process entire PDF.")
523
  final_pages = set(range(total_pages))
524
-
525
  logger.info(f"Processing pages (0-based): {sorted(final_pages)}")
526
  subset_pdf_bytes = create_subset_pdf(pdf_bytes, sorted(final_pages))
527
-
528
- # 4) doc_analyze => concurrency => final MD
529
  dataset = PymuDocDataset(subset_pdf_bytes)
530
  inference = doc_analyze(
531
  dataset,
@@ -535,22 +514,19 @@ class MineruNoTextProcessor:
535
  formula_enable=self.formula_enable,
536
  table_enable=self.table_enable
537
  )
538
- logger.info("doc_analyze complete. Extracting images...")
539
-
540
  writer = LocalImageWriter(self.output_folder, self.gemini_api_key)
541
  pipe_result = inference.pipe_ocr_mode(writer, lang=self.language)
542
  md_content = pipe_result.get_markdown("local-unique-prefix/")
543
-
544
  final_markdown = writer.post_process("local-unique-prefix/", md_content)
545
-
546
- # 5) Save
547
  out_path = os.path.join(self.output_folder, "final_output.md")
548
  with open(out_path, "w", encoding="utf-8") as f:
549
  f.write(final_markdown)
550
-
551
  logger.info(f"Markdown saved to: {out_path}")
552
  return final_markdown
553
-
554
  finally:
555
  self.cleanup_gpu()
556
 
@@ -558,7 +534,6 @@ if __name__ == "__main__":
558
  input_pdf = "/home/user/app/input_output/ocr-specification-economics.pdf"
559
  output_dir = "/home/user/app/outputs"
560
  gemini_key = os.getenv("GEMINI_API_KEY", "AIzaSyDtoakpXa2pjJwcQB6TJ5QaXHNSA5JxcrU")
561
-
562
  try:
563
  processor = MineruNoTextProcessor(output_folder=output_dir, gemini_api_key=gemini_key)
564
  md_output = processor.process(input_pdf)
 
26
  logger = logging.getLogger(__name__)
27
  logger.setLevel(logging.INFO)
28
 
29
+ _GEMINI_CLIENT = None
30
+ # Cache a single Gemini client instance so it is created once and reused, instead of being reinitialized on every call
31
+
32
  def unify_whitespace(text: str) -> str:
33
  return re.sub(r"\s+", " ", text).strip().lower()
34
 
35
  def create_subset_pdf(original_pdf_bytes: bytes, page_indices: List[int]) -> bytes:
 
 
 
36
  if not page_indices:
37
  raise ValueError("No page indices provided for subset creation.")
 
38
  doc = fitz.open(stream=original_pdf_bytes, filetype="pdf")
39
  new_doc = fitz.open()
40
  for p in sorted(set(page_indices)):
 
77
  if not first_pages_text.strip():
78
  logger.error("No text from first pages => cannot extract subtopics.")
79
  return {}
 
80
  prompt = f"""
81
  You have the first pages of a PDF specification, including a table of contents.
82
 
 
160
  The correct output should be:
161
 
162
  {{
163
+ "Theme 1: Introduction to markets and market failure": [5, 10],
164
+ "Theme 2: The UK economy – performance and policies": [11, 20],
165
+ "Theme 3: Business behaviour and the labour market": [21, 28],
166
  "Theme 4: A global perspective": [29, 38]
167
  }}
168
 
169
  Now, extract topics from this text:
170
  {first_pages_text}
171
  """
172
+ global _GEMINI_CLIENT
173
+ if _GEMINI_CLIENT is None:
174
+ _GEMINI_CLIENT = genai.Client(api_key=self.api_key)
175
+ client = _GEMINI_CLIENT
176
  try:
 
177
  response = client.models.generate_content(
178
  model="gemini-2.0-flash",
179
  contents=[prompt],
 
182
  if not response or not response.text:
183
  logger.warning("No text from LLM => returning empty subtopics.")
184
  return {}
185
+
186
  raw_json = response.text.strip()
187
  cleaned = raw_json.replace("```json", "").replace("```", "")
188
 
 
206
  # might be the sub-sub dict
207
  found_sub_dict = v
208
  break
 
209
  if found_sub_dict is not None:
210
  for subk, rng in found_sub_dict.items():
211
  if isinstance(rng, list) and len(rng) == 2:
 
216
  for subk, rng in data.items():
217
  if isinstance(rng, list) and len(rng) == 2:
218
  final_dict[subk] = rng
 
219
  return final_dict
220
  except Exception as e:
221
  logger.error(f"Gemini subtopic extraction error: {e}")
 
238
  """
239
  Classify an image as TWO_COLUMN, THREE_COLUMN, or NO_TABLE
240
  """
241
+ # Shrink the image to reduce its payload size before classification
 
 
 
 
 
 
 
242
  try:
243
  arr = np.frombuffer(image_data, np.uint8)
244
  img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
 
258
  image_data = enc.tobytes()
259
  except Exception as e:
260
  logger.warning(f"shrink_image_to_jpeg error: {e}")
261
+
262
  prompt = """You are given an image. Determine if it shows a table that has exactly 2 or 3 columns.
263
  The three-column 'table' image include such key features:
264
  - Three columns header columns
 
279
  THREE_COLUMN
280
  NO_TABLE
281
  """
282
+ global _GEMINI_CLIENT
283
+ if _GEMINI_CLIENT is None:
284
+ _GEMINI_CLIENT = genai.Client(api_key=api_key)
285
+ client = _GEMINI_CLIENT
286
  try:
 
287
  resp = client.models.generate_content(
288
  model="gemini-2.0-flash",
289
  contents=[
 
319
  def __init__(self, output_folder: str, gemini_api_key: str):
320
  self.output_folder = output_folder
321
  os.makedirs(self.output_folder, exist_ok=True)
 
322
  self.images_dir = os.path.join(self.output_folder, "images")
323
  os.makedirs(self.images_dir, exist_ok=True)
 
324
  self.descriptions = {}
325
  self._img_count = 0
326
  self.gemini_api_key = gemini_api_key
 
342
  def post_process(self, key: str, md_content: str) -> str:
343
  logger.info("Classifying images to detect tables (concurrent)...")
344
  with concurrent.futures.ThreadPoolExecutor(max_workers=6) as exe:
345
+ fut_map = {exe.submit(call_gemini_for_table_classification, info["data"], self.gemini_api_key): p for p, info in self.descriptions.items()}
 
 
 
 
346
  for fut in concurrent.futures.as_completed(fut_map):
347
  path = fut_map[fut]
348
  try:
 
351
  except Exception as e:
352
  logger.error(f"Table classification error: {e}")
353
  self.descriptions[path]['table_classification'] = "NO_TABLE"
354
+
 
355
  for p, info in self.descriptions.items():
356
  cls = info['table_classification']
357
  if cls == "TWO_COLUMN":
 
360
  info['final_alt'] = "HAS TO BE PROCESSED - three column table"
361
  else:
362
  info['final_alt'] = "NO_TABLE image"
363
+
364
+ # Replace placeholders in the Markdown
365
  for p, info in self.descriptions.items():
366
  old_md = f"![]({key}{p})"
367
  new_md = f"![{info['final_alt']}]({info['relative_path']})"
368
  md_content = md_content.replace(old_md, new_md)
369
+
370
+ # If any table images => extract rows
371
  md_content = self._process_table_images_in_markdown(md_content)
372
+
373
+ # Keep only lines that are image references
374
  final_lines = []
375
  for line in md_content.split("\n"):
376
  if re.match(r"^\!\[.*\]\(.*\)", line.strip()):
377
  final_lines.append(line.strip())
 
378
  return "\n".join(final_lines)
379
 
380
  def _process_table_images_in_markdown(self, md_content: str) -> str:
 
382
  matches = re.findall(pat, md_content, flags=re.IGNORECASE)
383
  if not matches:
384
  return md_content
385
+
386
  for (col_type, image_path) in matches:
387
  logger.info(f"Processing table image => {image_path}, columns={col_type}")
388
  abs_image_path = os.path.join(self.output_folder, image_path)
 
405
  out_folder = abs_image_path + "_rows"
406
  os.makedirs(out_folder, exist_ok=True)
407
  extractor.save_extracted_cells(abs_image_path, row_boxes, out_folder)
408
+
409
  snippet = ["**Extracted table cells:**"]
410
  for i, row in enumerate(row_boxes):
411
  row_dir = os.path.join(out_folder, f"row_{i}")
 
414
  cell_path = os.path.join(row_dir, cell_file)
415
  relp = os.path.relpath(cell_path, self.output_folder)
416
  snippet.append(f"![Row {i} Col {j}]({relp})")
417
+
418
  new_snip = "\n".join(snippet)
419
  old_line = f"![HAS TO BE PROCESSED - {col_type} column table]({image_path})"
420
  md_content = md_content.replace(old_line, new_snip)
421
  except Exception as e:
422
  logger.error(f"Error processing table image {image_path}: {e}")
 
423
  return md_content
424
 
425
  class MineruNoTextProcessor:
426
  def __init__(self, output_folder: str, gemini_api_key: str = None):
427
  self.output_folder = output_folder
428
  os.makedirs(self.output_folder, exist_ok=True)
 
429
  self.layout_model = "doclayout_yolo"
430
  self.formula_enable = True
431
  self.table_enable = False
432
  self.language = "en"
 
 
433
  self.subtopic_extractor = GeminiTopicExtractor(api_key=gemini_api_key, num_pages=10)
434
  self.gemini_api_key = gemini_api_key or os.getenv("GEMINI_API_KEY", "")
435
 
 
442
  logger.error(f"Error during GPU cleanup: {e}")
443
 
444
  def process(self, pdf_path: str) -> str:
 
445
  try:
446
+ # Extract subtopics from Gemini
447
  subtopics = self.subtopic_extractor.extract_subtopics(pdf_path)
448
  logger.info(f"Gemini returned subtopics: {subtopics}")
449
+
450
+ # Read the whole PDF
451
  with open(pdf_path, "rb") as f:
452
  pdf_bytes = f.read()
453
  doc = fitz.open(stream=pdf_bytes, filetype="pdf")
454
  total_pages = doc.page_count
455
  doc.close()
456
+
457
  final_pages = set()
458
  if not subtopics:
459
  logger.warning("No subtopics found. We'll process the entire PDF as fallback.")
 
472
  # find occurrences
473
  occs = find_all_occurrences(pdf_bytes, subname)
474
  logger.info(f"Occurrences of subtopic '{subname}': {occs}")
 
475
  doc_start_0 = start_p - 1
476
  chosen_page = None
477
  for p in occs:
 
486
  else:
487
  chosen_page = 0
488
  logger.warning(f"No occurrences for '{subname}'. Using page 0.")
489
+
490
  raw_offset = chosen_page - doc_start_0
491
  offset = max(0, raw_offset)
492
  logger.info(f"Subtopic '{subname}': doc_start={start_p}, chosen_page={chosen_page}, raw_offset={raw_offset}, offset={offset}")
 
493
  s0 = (start_p - 1) + offset
494
  e0 = (end_p - 1) + offset
495
  s0 = max(0, min(total_pages - 1, s0))
496
  e0 = max(0, min(total_pages - 1, e0))
497
  for pp in range(s0, e0 + 1):
498
  final_pages.add(pp)
499
+
 
500
  if not final_pages:
501
  logger.warning("No valid pages after offset. We'll process entire PDF.")
502
  final_pages = set(range(total_pages))
503
+
504
  logger.info(f"Processing pages (0-based): {sorted(final_pages)}")
505
  subset_pdf_bytes = create_subset_pdf(pdf_bytes, sorted(final_pages))
506
+
507
+ # doc_analyze => concurrency => final MD
508
  dataset = PymuDocDataset(subset_pdf_bytes)
509
  inference = doc_analyze(
510
  dataset,
 
514
  formula_enable=self.formula_enable,
515
  table_enable=self.table_enable
516
  )
517
+ logger.info("doc_analyze complete. Extracting images.")
518
+
519
  writer = LocalImageWriter(self.output_folder, self.gemini_api_key)
520
  pipe_result = inference.pipe_ocr_mode(writer, lang=self.language)
521
  md_content = pipe_result.get_markdown("local-unique-prefix/")
522
+
523
  final_markdown = writer.post_process("local-unique-prefix/", md_content)
524
+
 
525
  out_path = os.path.join(self.output_folder, "final_output.md")
526
  with open(out_path, "w", encoding="utf-8") as f:
527
  f.write(final_markdown)
 
528
  logger.info(f"Markdown saved to: {out_path}")
529
  return final_markdown
 
530
  finally:
531
  self.cleanup_gpu()
532
 
 
534
  input_pdf = "/home/user/app/input_output/ocr-specification-economics.pdf"
535
  output_dir = "/home/user/app/outputs"
536
  gemini_key = os.getenv("GEMINI_API_KEY", "AIzaSyDtoakpXa2pjJwcQB6TJ5QaXHNSA5JxcrU")
 
537
  try:
538
  processor = MineruNoTextProcessor(output_folder=output_dir, gemini_api_key=gemini_key)
539
  md_output = processor.process(input_pdf)