SkyNait committed
Commit 99cd3b7 · verified
1 Parent(s): da9ad0b

fix long gemini calls (async)
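Table-image classification now goes through an asyncio wrapper: each blocking Gemini call is pushed onto an executor thread, so one slow call no longer stalls the pipeline. A minimal sketch of the pattern (condensed from the classify_image_async introduced below; retry handling omitted):

    import asyncio

    async def classify_image_async(image_data: bytes, api_key: str) -> str:
        # Hand the blocking Gemini call to a worker thread and await the
        # result, so other classification tasks can make progress meanwhile.
        loop = asyncio.get_event_loop()
        return await loop.run_in_executor(
            None, call_gemini_for_table_classification, image_data, api_key
        )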

Files changed (1)
  1. topic_extraction.py +223 -178
topic_extraction.py CHANGED
@@ -6,7 +6,8 @@ import json
  import logging
  import fitz
  import base64
- import concurrent.futures
+ import time
+ import asyncio
  from io import BytesIO
  from typing import List, Dict, Any

@@ -19,19 +20,123 @@ from google.genai import types

  from magic_pdf.data.dataset import PymuDocDataset
  from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
-
+ from magic_pdf.data.data_reader_writer.base import DataWriter
  from table_row_extraction import TableExtractor

  logging.basicConfig(level=logging.INFO)
  logger = logging.getLogger(__name__)
  logger.setLevel(logging.INFO)
+ file_handler = logging.FileHandler("topic_extraction.log")
+ file_handler.setFormatter(logging.Formatter("%(asctime)s [%(levelname)s] %(name)s - %(message)s"))
+ logger.addHandler(file_handler)

  _GEMINI_CLIENT = None
- #to store and reuse a single Gemini client instance instead of reinitializing
+
+ def preprocess_image(image_data: bytes, max_dim: int = 600, quality: int = 60) -> bytes:
+     """
+     Downscale the image to reduce payload size.
+     """
+     arr = np.frombuffer(image_data, np.uint8)
+     img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
+     if img is not None:
+         h, w, _ = img.shape
+         if max(h, w) > max_dim:
+             scale = max_dim / float(max(h, w))
+             new_w = int(w * scale)
+             new_h = int(h * scale)
+             img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)
+         encode_params = [int(cv2.IMWRITE_JPEG_QUALITY), quality]
+         success, enc = cv2.imencode(".jpg", img, encode_params)
+         if success:
+             return enc.tobytes()
+     return image_data
+
+ def call_gemini_for_table_classification(image_data: bytes, api_key: str, max_retries: int = 1) -> str:
+     """
+     Synchronously call the Gemini API to classify a table image.
+     """
+     for attempt in range(max_retries + 1):
+         try:
+             prompt = """You are given an image. Determine if it shows a table that has exactly 2 or 3 columns.
+ The three-column 'table' image includes the following key features:
+ - Three header columns
+ - Headers like 'Topics', 'Content', 'Guidelines'
+ - Numbered sections (e.g., 8.4, 9.1)
+ - Educational curriculum-style structure
+ The two-column 'table' image includes the following key features:
+ - Two header columns
+ - Headers like 'Subject content' and 'Additional information'
+ - Numbered sections (e.g., 2.1, 3.4)
+ - Educational curriculum-style structure
+ - Bulleted descriptions in 'Additional information'
+ If the image is a relevant table with 2 columns, respond with 'TWO_COLUMN'.
+ If the image is a relevant table with 3 columns, respond with 'THREE_COLUMN'.
+ If the image does not show a table at all, respond with 'NO_TABLE'.
+ Return only one of these exact labels as your entire response:
+ TWO_COLUMN
+ THREE_COLUMN
+ NO_TABLE
+ """
+             global _GEMINI_CLIENT
+             if _GEMINI_CLIENT is None:
+                 _GEMINI_CLIENT = genai.Client(api_key=api_key)
+             client = _GEMINI_CLIENT
+             resp = client.models.generate_content(
+                 model="gemini-2.0-flash",
+                 contents=[
+                     {
+                         "parts": [
+                             {"text": prompt},
+                             {
+                                 "inline_data": {
+                                     "mime_type": "image/jpeg",
+                                     "data": base64.b64encode(image_data).decode('utf-8')
+                                 }
+                             }
+                         ]
+                     }
+                 ],
+                 config=types.GenerateContentConfig(temperature=0.0)
+             )
+             if resp and resp.text:
+                 classification = resp.text.strip().upper()
+                 if "THREE" in classification:
+                     return "THREE_COLUMN"
+                 elif "TWO" in classification:
+                     return "TWO_COLUMN"
+             return "NO_TABLE"
+         except Exception as e:
+             error_msg = str(e)
+             logger.error(f"Gemini table classification error: {error_msg}")
+             if "503" in error_msg:
+                 return "NO_TABLE"
+             if attempt < max_retries:
+                 logger.warning("Retrying classification due to error... attempt %d", attempt + 1)
+                 time.sleep(0.5)
+             else:
+                 return "NO_TABLE"
+
+ async def classify_image_async(image_data: bytes, api_key: str, max_retries: int = 1) -> str:
+     """
+     Asynchronous wrapper for image classification.
+     """
+     loop = asyncio.get_event_loop()
+     preprocessed = preprocess_image(image_data)
+     return await loop.run_in_executor(None, call_gemini_for_table_classification, preprocessed, api_key, max_retries)

  def unify_whitespace(text: str) -> str:
      return re.sub(r"\s+", " ", text).strip().lower()

+ def find_all_occurrences(pdf_bytes: bytes, search_text: str) -> List[int]:
+     doc = fitz.open(stream=pdf_bytes, filetype="pdf")
+     st_norm = unify_whitespace(search_text)
+     found = []
+     for i in range(doc.page_count):
+         raw = doc[i].get_text("raw")
+         norm = unify_whitespace(raw)
+         if st_norm in norm:
+             found.append(i)
+     doc.close()
+     return sorted(found)
+
  def create_subset_pdf(original_pdf_bytes: bytes, page_indices: List[int]) -> bytes:
      if not page_indices:
          raise ValueError("No page indices provided for subset creation.")
@@ -48,41 +153,23 @@ def create_subset_pdf(original_pdf_bytes: bytes, page_indices: List[int]) -> bytes:
      doc.close()
      return subset_bytes

- def find_all_occurrences(pdf_bytes: bytes, search_text: str) -> List[int]:
-     """
-     Return a sorted list of 0-based pages in which `search_text` (normalized) appears,
-     scanning the entire PDF in RAW mode.
-     """
-     doc = fitz.open(stream=pdf_bytes, filetype="pdf")
-     st_norm = unify_whitespace(search_text)
-     found = []
-     for i in range(doc.page_count):
-         raw = doc[i].get_text("raw")
-         norm = unify_whitespace(raw)
-         if st_norm in norm:
-             found.append(i)
-     doc.close()
-     return sorted(found)
-
  class GeminiTopicExtractor:
      def __init__(self, api_key: str = None, num_pages: int = 10):
          self.api_key = api_key or os.getenv("GEMINI_API_KEY", "")
          self.num_pages = num_pages

      def extract_subtopics(self, pdf_path: str) -> Dict[str, List[int]]:
-         """
-         Return a dict of subtopics => [start_page, end_page].
-         """
          first_pages_text = self._read_first_pages_raw(pdf_path, self.num_pages)
          if not first_pages_text.strip():
              logger.error("No text from first pages => cannot extract subtopics.")
              return {}
+
          prompt = f"""
  You have the first pages of a PDF specification, including a table of contents.

  Instructions:
  1. Identify the 'Contents' section listing all topics, subtopics, and their corresponding pages.
- 2. Identify the major academic subtopics (common desired topic names "Paper X", "Theme X", "Content of X").
+ 2. Identify the major academic subtopics (common desired topic names: "Paper X", "Theme X", "Content of X", "AS Unit X", "A2 Unit X", or similar headings).
  3. For each subtopic, give the range of pages [start_page, end_page] (1-based) from the table of contents.
  4. Output only valid JSON of the form:
  {{
@@ -92,7 +179,8 @@ Instructions:
  5. If you can't find any subtopics, return an empty JSON.

  Important notes:
- - the correct "end_page" must be the page number of the next topic or subtopic minus 1.
+ - The correct "end_page" must be the page number of the next topic or subtopic minus 1.
+ - The final output must be valid JSON only, with no extra text or code blocks.

  Examples:

@@ -166,6 +254,54 @@ The correct output should be:
  "Theme 4: A global perspective": [29, 38]
  }}

+ 3. You might also see sections like:
+
+ 2.1 AS Unit 1 11
+ 2.2 AS Unit 2 18
+ 2.3 A2 Unit 3 24
+ 2.4 A2 Unit 4 31
+
+ In that scenario, your output might look like:
+
+ {{
+ "AS Unit 1": [11, 17],
+ "AS Unit 2": [18, 23],
+ "A2 Unit 3": [24, 30],
+ "A2 Unit 4": [31, 35]
+ }}
+
+ 4. Another example might list subtopics:
+
+ 3.1 Overarching themes 11
+ 3.2 A: Proof 12
+ 3.3 B: Algebra and functions 13
+ 3.4 C: Coordinate geometry in the ( x , y ) plane 14
+ 3.5 D: Sequences and series 15
+ 3.6 E: Trigonometry 16
+ 3.7 F: Exponentials and logarithms 17
+ 3.8 G: Differentiation 18
+ 3.9 H: Integration 19
+ 3.10 I: Numerical methods 20
+ 3.11 J: Vectors 20
+ 3.12 K: Statistical sampling 21
+ 3.13 L: Data presentation and interpretation 21
+ 3.14 M: Probability 22
+ 3.15 N: Statistical distributions 23
+ 3.16 O: Statistical hypothesis testing 23
+ 3.17 P: Quantities and units in mechanics 24
+ 3.18 Q: Kinematics 24
+ 3.19 R: Forces and Newton’s laws 24
+ 3.20 S: Moments 25
+ 3.21 Use of data in statistics 26
+
+ Here the correct output might look like:
+
+ {{
+ "A: Proof": [12, 12],
+ "B: Algebra and functions": [13, 13],
+ ...
+ }}
+
  Now, extract topics from this text:
  {first_pages_text}
  """
@@ -185,25 +321,15 @@ Now, extract topics from this text:

          raw_json = response.text.strip()
          cleaned = raw_json.replace("```json", "").replace("```", "")
-
-         # Attempt to parse
-         data = json.loads(cleaned)
-         # data might be nested or flat
-         # if nested, example {"2 Subject content": {"Paper 1...": [11,29]}}
-         # if flat, example {"Paper 1...": [11,29]}
-         # so we unify it to a single dict of subname => [start,end].
+         try:
+             data = json.loads(cleaned)
+         except Exception as json_err:
+             logger.error(f"JSON parsing error: {json_err}")
+             return {}
          final_dict = {}
-
-         # If the top-level is a dict of dict
-         # We look for a dict whose values are themselves subtopics
-         # Or it might be a direct subtopic dict
-         # We'll try a quick approach:
-         #   - If any top-level value is a dict with numeric arrays, use that
-         #   - else assume data is the direct subtopic dict
          found_sub_dict = None
          for k, v in data.items():
              if isinstance(v, dict):
-                 # might be the sub-sub dict
                  found_sub_dict = v
                  break
          if found_sub_dict is not None:
@@ -211,8 +337,6 @@ Now, extract topics from this text:
                  if isinstance(rng, list) and len(rng) == 2:
                      final_dict[subk] = rng
          else:
-             # maybe data is the direct subtopic dict
-             # parse data
              for subk, rng in data.items():
                  if isinstance(rng, list) and len(rng) == 2:
                      final_dict[subk] = rng
@@ -234,124 +358,36 @@ Now, extract topics from this text:
              logger.error(f"Could not open PDF: {e}")
          return "\n".join(text_parts)

- def call_gemini_for_table_classification(image_data: bytes, api_key: str) -> str:
-     """
-     Classify an image as TWO_COLUMN, THREE_COLUMN, or NO_TABLE
-     """
-     #shrink image to reduce size
-     try:
-         arr = np.frombuffer(image_data, np.uint8)
-         img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
-         if img is not None:
-             h, w, _ = img.shape
-             max_dim = 800
-             scale = 1.0
-             if max(h, w) > max_dim:
-                 scale = max_dim / float(max(h, w))
-             if scale < 1.0:
-                 new_w = int(w * scale)
-                 new_h = int(h * scale)
-                 img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)
-             encode_params = [int(cv2.IMWRITE_JPEG_QUALITY), 70]
-             success, enc = cv2.imencode(".jpg", img, encode_params)
-             if success:
-                 image_data = enc.tobytes()
-     except Exception as e:
-         logger.warning(f"shrink_image_to_jpeg error: {e}")
-
-     prompt = """You are given an image. Determine if it shows a table that has exactly 2 or 3 columns.
- The three-column 'table' image includes the following key features:
- - Three header columns
- - Headers like 'Topics', 'Content', 'Guidelines'
- - Numbered sections (e.g., 8.4, 9.1)
- - Educational curriculum-style structure
- The two-column 'table' image includes the following key features:
- - Two header columns
- - Headers like 'Subject content' and 'Additional information'
- - Numbered sections (e.g., 2.1, 3.4)
- - Educational curriculum-style structure
- - Bulleted descriptions in 'Additional information'
- If the image is a relevant table with 2 columns, respond with 'TWO_COLUMN'.
- If the image is a relevant table with 3 columns, respond with 'THREE_COLUMN'.
- If the image does not show a table at all, respond with 'NO_TABLE'.
- Return only one of these exact labels as your entire response:
- TWO_COLUMN
- THREE_COLUMN
- NO_TABLE
- """
-     global _GEMINI_CLIENT
-     if _GEMINI_CLIENT is None:
-         _GEMINI_CLIENT = genai.Client(api_key=api_key)
-     client = _GEMINI_CLIENT
-     try:
-         resp = client.models.generate_content(
-             model="gemini-2.0-flash",
-             contents=[
-                 {
-                     "parts": [
-                         {"text": prompt},
-                         {
-                             "inline_data": {
-                                 "mime_type": "image/jpeg",
-                                 "data": base64.b64encode(image_data).decode('utf-8')
-                             }
-                         }
-                     ]
-                 }
-             ],
-             config=types.GenerateContentConfig(temperature=0.0)
-         )
-         if resp and resp.text:
-             classification = resp.text.strip().upper()
-             if "THREE" in classification:
-                 return "THREE_COLUMN"
-             elif "TWO" in classification:
-                 return "TWO_COLUMN"
-         return "NO_TABLE"
-     except Exception as e:
-         logger.error(f"Gemini table classification error: {e}")
-         return "NO_TABLE"
-
- class LocalImageWriter:
-     """
-     Writes extracted images, then does concurrency-based table classification calls.
-     """
+ class LocalImageWriter(DataWriter):
      def __init__(self, output_folder: str, gemini_api_key: str):
          self.output_folder = output_folder
          os.makedirs(self.output_folder, exist_ok=True)
-         self.images_dir = os.path.join(self.output_folder, "images")
-         os.makedirs(self.images_dir, exist_ok=True)
          self.descriptions = {}
          self._img_count = 0
          self.gemini_api_key = gemini_api_key

      def write(self, path: str, data: bytes) -> None:
          self._img_count += 1
-         fname = f"img_{self._img_count}.png"
-         fpath = os.path.join(self.images_dir, fname)
-         with open(fpath, "wb") as f:
-             f.write(data)
-         rel_path = os.path.relpath(fpath, self.output_folder)
+         unique_id = f"img_{self._img_count}.jpg"
          self.descriptions[path] = {
              "data": data,
-             "relative_path": rel_path,
+             "relative_path": unique_id,
              "table_classification": "NO_TABLE",
              "final_alt": ""
          }

-     def post_process(self, key: str, md_content: str) -> str:
-         logger.info("Classifying images to detect tables (concurrent)...")
-         with concurrent.futures.ThreadPoolExecutor(max_workers=6) as exe:
-             fut_map = {exe.submit(call_gemini_for_table_classification, info["data"], self.gemini_api_key): p for p, info in self.descriptions.items()}
-             for fut in concurrent.futures.as_completed(fut_map):
-                 path = fut_map[fut]
-                 try:
-                     classification = fut.result()
-                     self.descriptions[path]['table_classification'] = classification
-                 except Exception as e:
-                     logger.error(f"Table classification error: {e}")
-                     self.descriptions[path]['table_classification'] = "NO_TABLE"
-
+     async def post_process_async(self, key: str, md_content: str) -> str:
+         logger.info("Classifying images to detect tables (async).")
+         tasks = []
+         for p, info in self.descriptions.items():
+             # Schedule each coroutine as a task so the classifications run concurrently.
+             tasks.append((p, asyncio.ensure_future(classify_image_async(info["data"], self.gemini_api_key))))
+         for p, task in tasks:
+             try:
+                 classification = await task
+                 self.descriptions[p]['table_classification'] = classification
+             except Exception as e:
+                 logger.error(f"Table classification error: {e}")
+                 self.descriptions[p]['table_classification'] = "NO_TABLE"
          for p, info in self.descriptions.items():
              cls = info['table_classification']
              if cls == "TWO_COLUMN":
@@ -360,32 +396,43 @@ class LocalImageWriter:
                  info['final_alt'] = "HAS TO BE PROCESSED - three column table"
              else:
                  info['final_alt'] = "NO_TABLE image"
-
-         #replace placeholders in the Markdown
          for p, info in self.descriptions.items():
              old_md = f"![]({key}{p})"
              new_md = f"![{info['final_alt']}]({info['relative_path']})"
              md_content = md_content.replace(old_md, new_md)

-         # IF any table images => extract rows
          md_content = self._process_table_images_in_markdown(md_content)
-
-         # Keep only lines that are image references
          final_lines = []
         for line in md_content.split("\n"):
              if re.match(r"^\!\[.*\]\(.*\)", line.strip()):
                  final_lines.append(line.strip())
          return "\n".join(final_lines)

+     def post_process(self, key: str, md_content: str) -> str:
+         """
+         Synchronous wrapper around the asynchronous post_process_async.
+         """
+         return asyncio.run(self.post_process_async(key, md_content))
+
      def _process_table_images_in_markdown(self, md_content: str) -> str:
          pat = r"!\[HAS TO BE PROCESSED - (two|three) column table\]\(([^)]+)\)"
          matches = re.findall(pat, md_content, flags=re.IGNORECASE)
          if not matches:
              return md_content
-
-         for (col_type, image_path) in matches:
-             logger.info(f"Processing table image => {image_path}, columns={col_type}")
-             abs_image_path = os.path.join(self.output_folder, image_path)
+         for (col_type, image_id) in matches:
+             logger.info(f"Processing table image => {image_id}, columns={col_type}")
+             temp_path = os.path.join(self.output_folder, image_id)
+             desc_item = None
+             for k, val in self.descriptions.items():
+                 if val["relative_path"] == image_id:
+                     desc_item = val
+                     break
+             if not desc_item:
+                 logger.warning(f"No matching image data for {image_id}, skipping extraction.")
+                 continue
+             if not os.path.exists(temp_path):
+                 with open(temp_path, "wb") as f:
+                     f.write(desc_item["data"])
              try:
                  if col_type.lower() == 'two':
                      extractor = TableExtractor(
@@ -401,25 +448,31 @@
                          enable_subtopic_merge=False,
                          subtopic_threshold=0.2
                      )
-                 row_boxes = extractor.process_image(abs_image_path)
-                 out_folder = abs_image_path + "_rows"
+                 row_boxes = extractor.process_image(temp_path)
+
+                 out_folder = temp_path + "_rows"
                  os.makedirs(out_folder, exist_ok=True)
-                 extractor.save_extracted_cells(abs_image_path, row_boxes, out_folder)
+
+                 extractor.save_extracted_cells(temp_path, row_boxes, out_folder)

                  snippet = ["**Extracted table cells:**"]
                  for i, row in enumerate(row_boxes):
                      row_dir = os.path.join(out_folder, f"row_{i}")
                      for j, _ in enumerate(row):
-                         cell_file = f"col_{j}.png"
+                         cell_file = f"col_{j}.jpg"
                          cell_path = os.path.join(row_dir, cell_file)
                          relp = os.path.relpath(cell_path, self.output_folder)
                          snippet.append(f"![Row {i} Col {j}]({relp})")
-
                  new_snip = "\n".join(snippet)
-                 old_line = f"![HAS TO BE PROCESSED - {col_type} column table]({image_path})"
+
+                 old_line = f"![HAS TO BE PROCESSED - {col_type} column table]({image_id})"
+
                  md_content = md_content.replace(old_line, new_snip)
              except Exception as e:
-                 logger.error(f"Error processing table image {image_path}: {e}")
+                 logger.error(f"Error processing table image {image_id}: {e}")
+             finally:
+                 if os.path.exists(temp_path):
+                     os.remove(temp_path)
          return md_content

  class MineruNoTextProcessor:
@@ -442,14 +495,14 @@ class MineruNoTextProcessor:
              logger.error(f"Error during GPU cleanup: {e}")

      def process(self, pdf_path: str) -> str:
+         logger.info(f"Processing PDF: {pdf_path}")
          try:
-             #Extract subtopics from Gemini
              subtopics = self.subtopic_extractor.extract_subtopics(pdf_path)
-             logger.info(f"Gemini returned subtopics: {subtopics}")

-             #read whole pdf
+             logger.info(f"Gemini returned subtopics: {subtopics}")
              with open(pdf_path, "rb") as f:
                  pdf_bytes = f.read()
+
              doc = fitz.open(stream=pdf_bytes, filetype="pdf")
              total_pages = doc.page_count
              doc.close()
@@ -459,17 +512,14 @@
                  logger.warning("No subtopics found. We'll process the entire PDF as fallback.")
                  final_pages = set(range(total_pages))
              else:
-                 # For each subtopic, find occurrence >= (start_p-1)
                  for subname, rng in subtopics.items():
                      if not (isinstance(rng, list) and len(rng) == 2):
                          logger.warning(f"Skipping subtopic '{subname}' => invalid range {rng}")
                          continue
                      start_p, end_p = rng
                      if start_p > end_p:
-                         logger.warning(f"Skipping subtopic '{subname}' => start> end {rng}")
+                         logger.warning(f"Skipping subtopic '{subname}' => start > end {rng}")
                          continue
-
-                     # find occurrences
                      occs = find_all_occurrences(pdf_bytes, subname)
                      logger.info(f"Occurrences of subtopic '{subname}': {occs}")
                      doc_start_0 = start_p - 1
@@ -479,7 +529,6 @@
                              chosen_page = p
                              break
                      if chosen_page is None:
-                         # fallback to last or 0
                          if occs:
                              chosen_page = occs[-1]
                              logger.warning(f"No occurrence >= {doc_start_0} for '{subname}'. Using last => {chosen_page}")
@@ -496,15 +545,13 @@
                      e0 = max(0, min(total_pages - 1, e0))
                      for pp in range(s0, e0 + 1):
                          final_pages.add(pp)
-
              if not final_pages:
                  logger.warning("No valid pages after offset. We'll process entire PDF.")
                  final_pages = set(range(total_pages))
-
              logger.info(f"Processing pages (0-based): {sorted(final_pages)}")
+
              subset_pdf_bytes = create_subset_pdf(pdf_bytes, sorted(final_pages))
-
-             # doc_analyze => concurrency => final MD
+
              dataset = PymuDocDataset(subset_pdf_bytes)
              inference = doc_analyze(
                  dataset,
@@ -515,18 +562,18 @@
                  table_enable=self.table_enable
              )
              logger.info("doc_analyze complete. Extracting images.")
-
              writer = LocalImageWriter(self.output_folder, self.gemini_api_key)
              pipe_result = inference.pipe_ocr_mode(writer, lang=self.language)
              md_content = pipe_result.get_markdown("local-unique-prefix/")
-
+
              final_markdown = writer.post_process("local-unique-prefix/", md_content)
-
              out_path = os.path.join(self.output_folder, "final_output.md")
+
              with open(out_path, "w", encoding="utf-8") as f:
                  f.write(final_markdown)
-             logger.info(f"Markdown saved to: {out_path}")
+
              return final_markdown
+
          finally:
              self.cleanup_gpu()

@@ -537,7 +584,5 @@ if __name__ == "__main__":
      try:
          processor = MineruNoTextProcessor(output_folder=output_dir, gemini_api_key=gemini_key)
          md_output = processor.process(input_pdf)
-         print("Final Markdown Output:")
-         print(md_output)
      except Exception as e:
          logger.error(f"Processing failed: {e}")
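For reference, the per-image coroutines introduced above can also be fanned out with asyncio.gather (a sketch building on classify_image_async; classify_all and images are illustrative names, not part of this commit):

    import asyncio

    async def classify_all(images: dict, api_key: str) -> dict:
        # One task per image; gather awaits them concurrently, and
        # return_exceptions stops one failure from cancelling the rest.
        paths = list(images)
        results = await asyncio.gather(
            *(classify_image_async(images[p], api_key) for p in paths),
            return_exceptions=True,
        )
        return {p: (r if isinstance(r, str) else "NO_TABLE") for p, r in zip(paths, results)}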