fix long gemini calls (async)
Browse files- topic_extraction.py +223 -178
topic_extraction.py
CHANGED
|
@@ -6,7 +6,8 @@ import json
|
|
| 6 |
import logging
|
| 7 |
import fitz
|
| 8 |
import base64
|
| 9 |
-
import
|
|
|
|
| 10 |
from io import BytesIO
|
| 11 |
from typing import List, Dict, Any
|
| 12 |
|
|
@@ -19,19 +20,123 @@ from google.genai import types
|
|
| 19 |
|
| 20 |
from magic_pdf.data.dataset import PymuDocDataset
|
| 21 |
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
|
| 22 |
-
|
| 23 |
from table_row_extraction import TableExtractor
|
| 24 |
|
| 25 |
logging.basicConfig(level=logging.INFO)
|
| 26 |
logger = logging.getLogger(__name__)
|
| 27 |
logger.setLevel(logging.INFO)
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
_GEMINI_CLIENT = None
|
| 30 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
|
| 32 |
def unify_whitespace(text: str) -> str:
|
| 33 |
return re.sub(r"\s+", " ", text).strip().lower()
|
| 34 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
def create_subset_pdf(original_pdf_bytes: bytes, page_indices: List[int]) -> bytes:
|
| 36 |
if not page_indices:
|
| 37 |
raise ValueError("No page indices provided for subset creation.")
|
|
@@ -48,41 +153,23 @@ def create_subset_pdf(original_pdf_bytes: bytes, page_indices: List[int]) -> byt
|
|
| 48 |
doc.close()
|
| 49 |
return subset_bytes
|
| 50 |
|
| 51 |
-
def find_all_occurrences(pdf_bytes: bytes, search_text: str) -> List[int]:
|
| 52 |
-
"""
|
| 53 |
-
Return a sorted list of 0-based pages in which `search_text` (normalized) appears,
|
| 54 |
-
scanning the entire PDF in RAW mode.
|
| 55 |
-
"""
|
| 56 |
-
doc = fitz.open(stream=pdf_bytes, filetype="pdf")
|
| 57 |
-
st_norm = unify_whitespace(search_text)
|
| 58 |
-
found = []
|
| 59 |
-
for i in range(doc.page_count):
|
| 60 |
-
raw = doc[i].get_text("raw")
|
| 61 |
-
norm = unify_whitespace(raw)
|
| 62 |
-
if st_norm in norm:
|
| 63 |
-
found.append(i)
|
| 64 |
-
doc.close()
|
| 65 |
-
return sorted(found)
|
| 66 |
-
|
| 67 |
class GeminiTopicExtractor:
|
| 68 |
def __init__(self, api_key: str = None, num_pages: int = 10):
|
| 69 |
self.api_key = api_key or os.getenv("GEMINI_API_KEY", "")
|
| 70 |
self.num_pages = num_pages
|
| 71 |
|
| 72 |
def extract_subtopics(self, pdf_path: str) -> Dict[str, List[int]]:
|
| 73 |
-
"""
|
| 74 |
-
Return a dict of subtopics => [start_page, end_page].
|
| 75 |
-
"""
|
| 76 |
first_pages_text = self._read_first_pages_raw(pdf_path, self.num_pages)
|
| 77 |
if not first_pages_text.strip():
|
| 78 |
logger.error("No text from first pages => cannot extract subtopics.")
|
| 79 |
return {}
|
|
|
|
| 80 |
prompt = f"""
|
| 81 |
You have the first pages of a PDF specification, including a table of contents.
|
| 82 |
|
| 83 |
Instructions:
|
| 84 |
1. Identify the 'Contents' section listing all topics, subtopics, and their corresponding pages.
|
| 85 |
-
2. Identify the major academic subtopics (common desired topic names "Paper X", "Theme X", "Content of X").
|
| 86 |
3. For each subtopic, give the range of pages [start_page, end_page] (1-based) from the table of contents.
|
| 87 |
4. Output only valid JSON of the form:
|
| 88 |
{{
|
|
@@ -92,7 +179,8 @@ Instructions:
|
|
| 92 |
5. If you can't find any subtopics, return an empty JSON.
|
| 93 |
|
| 94 |
Important notes:
|
| 95 |
-
-
|
|
|
|
| 96 |
|
| 97 |
Examples:
|
| 98 |
|
|
@@ -166,6 +254,54 @@ The correct output should be:
|
|
| 166 |
"Theme 4: A global perspective": [29, 38]
|
| 167 |
}}
|
| 168 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 169 |
Now, extract topics from this text:
|
| 170 |
{first_pages_text}
|
| 171 |
"""
|
|
@@ -185,25 +321,15 @@ Now, extract topics from this text:
|
|
| 185 |
|
| 186 |
raw_json = response.text.strip()
|
| 187 |
cleaned = raw_json.replace("```json", "").replace("```", "")
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
# if flat, example {"Paper 1...": [11,29]}
|
| 194 |
-
# so we unify it to a single dict of subname => [start,end].
|
| 195 |
final_dict = {}
|
| 196 |
-
|
| 197 |
-
# If the top-level is a dict of dict
|
| 198 |
-
# We look for a dict whose values are themselves subtopics
|
| 199 |
-
# Or it might be a direct subtopic dict
|
| 200 |
-
# We'll try a quick approach:
|
| 201 |
-
# - If any top-level value is a dict with numeric arrays, use that
|
| 202 |
-
# - else assume data is the direct subtopic dict
|
| 203 |
found_sub_dict = None
|
| 204 |
for k, v in data.items():
|
| 205 |
if isinstance(v, dict):
|
| 206 |
-
# might be the sub-sub dict
|
| 207 |
found_sub_dict = v
|
| 208 |
break
|
| 209 |
if found_sub_dict is not None:
|
|
@@ -211,8 +337,6 @@ Now, extract topics from this text:
|
|
| 211 |
if isinstance(rng, list) and len(rng) == 2:
|
| 212 |
final_dict[subk] = rng
|
| 213 |
else:
|
| 214 |
-
# maybe data is the direct subtopic dict
|
| 215 |
-
# parse data
|
| 216 |
for subk, rng in data.items():
|
| 217 |
if isinstance(rng, list) and len(rng) == 2:
|
| 218 |
final_dict[subk] = rng
|
|
@@ -234,124 +358,36 @@ Now, extract topics from this text:
|
|
| 234 |
logger.error(f"Could not open PDF: {e}")
|
| 235 |
return "\n".join(text_parts)
|
| 236 |
|
| 237 |
-
|
| 238 |
-
"""
|
| 239 |
-
Classify an image as TWO_COLUMN, THREE_COLUMN, or NO_TABLE
|
| 240 |
-
"""
|
| 241 |
-
#shrink image to reduce size
|
| 242 |
-
try:
|
| 243 |
-
arr = np.frombuffer(image_data, np.uint8)
|
| 244 |
-
img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
|
| 245 |
-
if img is not None:
|
| 246 |
-
h, w, _ = img.shape
|
| 247 |
-
max_dim = 800
|
| 248 |
-
scale = 1.0
|
| 249 |
-
if max(h, w) > max_dim:
|
| 250 |
-
scale = max_dim / float(max(h, w))
|
| 251 |
-
if scale < 1.0:
|
| 252 |
-
new_w = int(w * scale)
|
| 253 |
-
new_h = int(h * scale)
|
| 254 |
-
img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)
|
| 255 |
-
encode_params = [int(cv2.IMWRITE_JPEG_QUALITY), 70]
|
| 256 |
-
success, enc = cv2.imencode(".jpg", img, encode_params)
|
| 257 |
-
if success:
|
| 258 |
-
image_data = enc.tobytes()
|
| 259 |
-
except Exception as e:
|
| 260 |
-
logger.warning(f"shrink_image_to_jpeg error: {e}")
|
| 261 |
-
|
| 262 |
-
prompt = """You are given an image. Determine if it shows a table that has exactly 2 or 3 columns.
|
| 263 |
-
The three-column 'table' image include such key features:
|
| 264 |
-
- Three columns header columns
|
| 265 |
-
- Headers like 'Topics', 'Content', 'Guidelines'
|
| 266 |
-
- Numbered sections (e.g., 8.4, 9.1)
|
| 267 |
-
- Educational curriculum-style structure
|
| 268 |
-
The two-column 'table' image include such key features:
|
| 269 |
-
- Two columns header columns
|
| 270 |
-
- Headers like 'Subject content' and 'Additional information'
|
| 271 |
-
- Numbered sections (e.g., 2.1, 3.4)
|
| 272 |
-
- Educational curriculum-style structure
|
| 273 |
-
- Bullet description in 'Additional information'
|
| 274 |
-
If the image is a relevant table with 2 columns, respond with 'TWO_COLUMN'.
|
| 275 |
-
If the image is a relevant table with 3 columns, respond with 'THREE_COLUMN'.
|
| 276 |
-
If the image does not show a table at all, respond with 'NO_TABLE'.
|
| 277 |
-
Return only one of these exact labels as your entire response:
|
| 278 |
-
TWO_COLUMN
|
| 279 |
-
THREE_COLUMN
|
| 280 |
-
NO_TABLE
|
| 281 |
-
"""
|
| 282 |
-
global _GEMINI_CLIENT
|
| 283 |
-
if _GEMINI_CLIENT is None:
|
| 284 |
-
_GEMINI_CLIENT = genai.Client(api_key=api_key)
|
| 285 |
-
client = _GEMINI_CLIENT
|
| 286 |
-
try:
|
| 287 |
-
resp = client.models.generate_content(
|
| 288 |
-
model="gemini-2.0-flash",
|
| 289 |
-
contents=[
|
| 290 |
-
{
|
| 291 |
-
"parts": [
|
| 292 |
-
{"text": prompt},
|
| 293 |
-
{
|
| 294 |
-
"inline_data": {
|
| 295 |
-
"mime_type": "image/jpeg",
|
| 296 |
-
"data": base64.b64encode(image_data).decode('utf-8')
|
| 297 |
-
}
|
| 298 |
-
}
|
| 299 |
-
]
|
| 300 |
-
}
|
| 301 |
-
],
|
| 302 |
-
config=types.GenerateContentConfig(temperature=0.0)
|
| 303 |
-
)
|
| 304 |
-
if resp and resp.text:
|
| 305 |
-
classification = resp.text.strip().upper()
|
| 306 |
-
if "THREE" in classification:
|
| 307 |
-
return "THREE_COLUMN"
|
| 308 |
-
elif "TWO" in classification:
|
| 309 |
-
return "TWO_COLUMN"
|
| 310 |
-
return "NO_TABLE"
|
| 311 |
-
except Exception as e:
|
| 312 |
-
logger.error(f"Gemini table classification error: {e}")
|
| 313 |
-
return "NO_TABLE"
|
| 314 |
-
|
| 315 |
-
class LocalImageWriter:
|
| 316 |
-
"""
|
| 317 |
-
Writes extracted images, then does concurrency-based table classification calls.
|
| 318 |
-
"""
|
| 319 |
def __init__(self, output_folder: str, gemini_api_key: str):
|
| 320 |
self.output_folder = output_folder
|
| 321 |
os.makedirs(self.output_folder, exist_ok=True)
|
| 322 |
-
self.images_dir = os.path.join(self.output_folder, "images")
|
| 323 |
-
os.makedirs(self.images_dir, exist_ok=True)
|
| 324 |
self.descriptions = {}
|
| 325 |
self._img_count = 0
|
| 326 |
self.gemini_api_key = gemini_api_key
|
| 327 |
|
| 328 |
def write(self, path: str, data: bytes) -> None:
|
| 329 |
self._img_count += 1
|
| 330 |
-
|
| 331 |
-
fpath = os.path.join(self.images_dir, fname)
|
| 332 |
-
with open(fpath, "wb") as f:
|
| 333 |
-
f.write(data)
|
| 334 |
-
rel_path = os.path.relpath(fpath, self.output_folder)
|
| 335 |
self.descriptions[path] = {
|
| 336 |
"data": data,
|
| 337 |
-
"relative_path":
|
| 338 |
"table_classification": "NO_TABLE",
|
| 339 |
"final_alt": ""
|
| 340 |
}
|
| 341 |
|
| 342 |
-
def
|
| 343 |
-
logger.info("Classifying images to detect tables (
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
|
| 355 |
for p, info in self.descriptions.items():
|
| 356 |
cls = info['table_classification']
|
| 357 |
if cls == "TWO_COLUMN":
|
|
@@ -360,32 +396,43 @@ class LocalImageWriter:
|
|
| 360 |
info['final_alt'] = "HAS TO BE PROCESSED - three column table"
|
| 361 |
else:
|
| 362 |
info['final_alt'] = "NO_TABLE image"
|
| 363 |
-
|
| 364 |
-
#replace placeholders in the Markdown
|
| 365 |
for p, info in self.descriptions.items():
|
| 366 |
old_md = f""
|
| 367 |
new_md = f"![{info['final_alt']}]({info['relative_path']})"
|
| 368 |
md_content = md_content.replace(old_md, new_md)
|
| 369 |
|
| 370 |
-
# IF any table images => extract rows
|
| 371 |
md_content = self._process_table_images_in_markdown(md_content)
|
| 372 |
-
|
| 373 |
-
# Keep only lines that are image references
|
| 374 |
final_lines = []
|
| 375 |
for line in md_content.split("\n"):
|
| 376 |
if re.match(r"^\!\[.*\]\(.*\)", line.strip()):
|
| 377 |
final_lines.append(line.strip())
|
| 378 |
return "\n".join(final_lines)
|
| 379 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 380 |
def _process_table_images_in_markdown(self, md_content: str) -> str:
|
| 381 |
pat = r"!\[HAS TO BE PROCESSED - (two|three) column table\]\(([^)]+)\)"
|
| 382 |
matches = re.findall(pat, md_content, flags=re.IGNORECASE)
|
| 383 |
if not matches:
|
| 384 |
return md_content
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 389 |
try:
|
| 390 |
if col_type.lower() == 'two':
|
| 391 |
extractor = TableExtractor(
|
|
@@ -401,25 +448,31 @@ class LocalImageWriter:
|
|
| 401 |
enable_subtopic_merge=False,
|
| 402 |
subtopic_threshold=0.2
|
| 403 |
)
|
| 404 |
-
row_boxes = extractor.process_image(
|
| 405 |
-
|
|
|
|
| 406 |
os.makedirs(out_folder, exist_ok=True)
|
| 407 |
-
|
|
|
|
| 408 |
|
| 409 |
snippet = ["**Extracted table cells:**"]
|
| 410 |
for i, row in enumerate(row_boxes):
|
| 411 |
row_dir = os.path.join(out_folder, f"row_{i}")
|
| 412 |
for j, _ in enumerate(row):
|
| 413 |
-
cell_file = f"col_{j}.
|
| 414 |
cell_path = os.path.join(row_dir, cell_file)
|
| 415 |
relp = os.path.relpath(cell_path, self.output_folder)
|
| 416 |
snippet.append(f"")
|
| 417 |
-
|
| 418 |
new_snip = "\n".join(snippet)
|
| 419 |
-
|
|
|
|
|
|
|
| 420 |
md_content = md_content.replace(old_line, new_snip)
|
| 421 |
except Exception as e:
|
| 422 |
-
logger.error(f"Error processing table image {
|
|
|
|
|
|
|
|
|
|
| 423 |
return md_content
|
| 424 |
|
| 425 |
class MineruNoTextProcessor:
|
|
@@ -442,14 +495,14 @@ class MineruNoTextProcessor:
|
|
| 442 |
logger.error(f"Error during GPU cleanup: {e}")
|
| 443 |
|
| 444 |
def process(self, pdf_path: str) -> str:
|
|
|
|
| 445 |
try:
|
| 446 |
-
#Extract subtopics from Gemini
|
| 447 |
subtopics = self.subtopic_extractor.extract_subtopics(pdf_path)
|
| 448 |
-
logger.info(f"Gemini returned subtopics: {subtopics}")
|
| 449 |
|
| 450 |
-
|
| 451 |
with open(pdf_path, "rb") as f:
|
| 452 |
pdf_bytes = f.read()
|
|
|
|
| 453 |
doc = fitz.open(stream=pdf_bytes, filetype="pdf")
|
| 454 |
total_pages = doc.page_count
|
| 455 |
doc.close()
|
|
@@ -459,17 +512,14 @@ class MineruNoTextProcessor:
|
|
| 459 |
logger.warning("No subtopics found. We'll process the entire PDF as fallback.")
|
| 460 |
final_pages = set(range(total_pages))
|
| 461 |
else:
|
| 462 |
-
# For each subtopic, find occurrence >= (start_p-1)
|
| 463 |
for subname, rng in subtopics.items():
|
| 464 |
if not (isinstance(rng, list) and len(rng) == 2):
|
| 465 |
logger.warning(f"Skipping subtopic '{subname}' => invalid range {rng}")
|
| 466 |
continue
|
| 467 |
start_p, end_p = rng
|
| 468 |
if start_p > end_p:
|
| 469 |
-
logger.warning(f"Skipping subtopic '{subname}' => start> end {rng}")
|
| 470 |
continue
|
| 471 |
-
|
| 472 |
-
# find occurrences
|
| 473 |
occs = find_all_occurrences(pdf_bytes, subname)
|
| 474 |
logger.info(f"Occurrences of subtopic '{subname}': {occs}")
|
| 475 |
doc_start_0 = start_p - 1
|
|
@@ -479,7 +529,6 @@ class MineruNoTextProcessor:
|
|
| 479 |
chosen_page = p
|
| 480 |
break
|
| 481 |
if chosen_page is None:
|
| 482 |
-
# fallback to last or 0
|
| 483 |
if occs:
|
| 484 |
chosen_page = occs[-1]
|
| 485 |
logger.warning(f"No occurrence >= {doc_start_0} for '{subname}'. Using last => {chosen_page}")
|
|
@@ -496,15 +545,13 @@ class MineruNoTextProcessor:
|
|
| 496 |
e0 = max(0, min(total_pages - 1, e0))
|
| 497 |
for pp in range(s0, e0 + 1):
|
| 498 |
final_pages.add(pp)
|
| 499 |
-
|
| 500 |
if not final_pages:
|
| 501 |
logger.warning("No valid pages after offset. We'll process entire PDF.")
|
| 502 |
final_pages = set(range(total_pages))
|
| 503 |
-
|
| 504 |
logger.info(f"Processing pages (0-based): {sorted(final_pages)}")
|
|
|
|
| 505 |
subset_pdf_bytes = create_subset_pdf(pdf_bytes, sorted(final_pages))
|
| 506 |
-
|
| 507 |
-
# doc_analyze => concurrency => final MD
|
| 508 |
dataset = PymuDocDataset(subset_pdf_bytes)
|
| 509 |
inference = doc_analyze(
|
| 510 |
dataset,
|
|
@@ -515,18 +562,18 @@ class MineruNoTextProcessor:
|
|
| 515 |
table_enable=self.table_enable
|
| 516 |
)
|
| 517 |
logger.info("doc_analyze complete. Extracting images.")
|
| 518 |
-
|
| 519 |
writer = LocalImageWriter(self.output_folder, self.gemini_api_key)
|
| 520 |
pipe_result = inference.pipe_ocr_mode(writer, lang=self.language)
|
| 521 |
md_content = pipe_result.get_markdown("local-unique-prefix/")
|
| 522 |
-
|
| 523 |
final_markdown = writer.post_process("local-unique-prefix/", md_content)
|
| 524 |
-
|
| 525 |
out_path = os.path.join(self.output_folder, "final_output.md")
|
|
|
|
| 526 |
with open(out_path, "w", encoding="utf-8") as f:
|
| 527 |
f.write(final_markdown)
|
| 528 |
-
|
| 529 |
return final_markdown
|
|
|
|
| 530 |
finally:
|
| 531 |
self.cleanup_gpu()
|
| 532 |
|
|
@@ -537,7 +584,5 @@ if __name__ == "__main__":
|
|
| 537 |
try:
|
| 538 |
processor = MineruNoTextProcessor(output_folder=output_dir, gemini_api_key=gemini_key)
|
| 539 |
md_output = processor.process(input_pdf)
|
| 540 |
-
print("Final Markdown Output:")
|
| 541 |
-
print(md_output)
|
| 542 |
except Exception as e:
|
| 543 |
logger.error(f"Processing failed: {e}")
|
|
|
|
| 6 |
import logging
|
| 7 |
import fitz
|
| 8 |
import base64
|
| 9 |
+
import time
|
| 10 |
+
import asyncio
|
| 11 |
from io import BytesIO
|
| 12 |
from typing import List, Dict, Any
|
| 13 |
|
|
|
|
| 20 |
|
| 21 |
from magic_pdf.data.dataset import PymuDocDataset
|
| 22 |
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
|
| 23 |
+
from magic_pdf.data.data_reader_writer.base import DataWriter
|
| 24 |
from table_row_extraction import TableExtractor
|
| 25 |
|
| 26 |
logging.basicConfig(level=logging.INFO)
|
| 27 |
logger = logging.getLogger(__name__)
|
| 28 |
logger.setLevel(logging.INFO)
|
| 29 |
+
file_handler = logging.FileHandler("topic_extraction.log")
|
| 30 |
+
file_handler.setFormatter(logging.Formatter("%(asctime)s [%(levelname)s] %(name)s - %(message)s"))
|
| 31 |
+
logger.addHandler(file_handler)
|
| 32 |
|
| 33 |
_GEMINI_CLIENT = None
|
| 34 |
+
|
| 35 |
+
def preprocess_image(image_data: bytes, max_dim: int = 600, quality: int = 60) -> bytes:
|
| 36 |
+
"""
|
| 37 |
+
Downscale the image to reduce payload size.
|
| 38 |
+
"""
|
| 39 |
+
arr = np.frombuffer(image_data, np.uint8)
|
| 40 |
+
img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
|
| 41 |
+
if img is not None:
|
| 42 |
+
h, w, _ = img.shape
|
| 43 |
+
if max(h, w) > max_dim:
|
| 44 |
+
scale = max_dim / float(max(h, w))
|
| 45 |
+
new_w = int(w * scale)
|
| 46 |
+
new_h = int(h * scale)
|
| 47 |
+
img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)
|
| 48 |
+
encode_params = [int(cv2.IMWRITE_JPEG_QUALITY), quality]
|
| 49 |
+
success, enc = cv2.imencode(".jpg", img, encode_params)
|
| 50 |
+
if success:
|
| 51 |
+
return enc.tobytes()
|
| 52 |
+
return image_data
|
| 53 |
+
|
| 54 |
+
def call_gemini_for_table_classification(image_data: bytes, api_key: str, max_retries: int = 1) -> str:
|
| 55 |
+
"""
|
| 56 |
+
Synchronously call the Gemini API to classify a table image.
|
| 57 |
+
"""
|
| 58 |
+
for attempt in range(max_retries + 1):
|
| 59 |
+
try:
|
| 60 |
+
prompt = """You are given an image. Determine if it shows a table that has exactly 2 or 3 columns.
|
| 61 |
+
The three-column 'table' image include such key features:
|
| 62 |
+
- Three columns header columns
|
| 63 |
+
- Headers like 'Topics', 'Content', 'Guidelines'
|
| 64 |
+
- Numbered sections (e.g., 8.4, 9.1)
|
| 65 |
+
- Educational curriculum-style structure
|
| 66 |
+
The two-column 'table' image include such key features:
|
| 67 |
+
- Two columns header columns
|
| 68 |
+
- Headers like 'Subject content' and 'Additional information'
|
| 69 |
+
- Numbered sections (e.g., 2.1, 3.4)
|
| 70 |
+
- Educational curriculum-style structure
|
| 71 |
+
- Bullet description in 'Additional information'
|
| 72 |
+
If the image is a relevant table with 2 columns, respond with 'TWO_COLUMN'.
|
| 73 |
+
If the image is a relevant table with 3 columns, respond with 'THREE_COLUMN'.
|
| 74 |
+
If the image does not show a table at all, respond with 'NO_TABLE'.
|
| 75 |
+
Return only one of these exact labels as your entire response:
|
| 76 |
+
TWO_COLUMN
|
| 77 |
+
THREE_COLUMN
|
| 78 |
+
NO_TABLE
|
| 79 |
+
"""
|
| 80 |
+
global _GEMINI_CLIENT
|
| 81 |
+
client = _GEMINI_CLIENT
|
| 82 |
+
resp = client.models.generate_content(
|
| 83 |
+
model="gemini-2.0-flash",
|
| 84 |
+
contents=[
|
| 85 |
+
{
|
| 86 |
+
"parts": [
|
| 87 |
+
{"text": prompt},
|
| 88 |
+
{
|
| 89 |
+
"inline_data": {
|
| 90 |
+
"mime_type": "image/jpeg",
|
| 91 |
+
"data": base64.b64encode(image_data).decode('utf-8')
|
| 92 |
+
}
|
| 93 |
+
}
|
| 94 |
+
]
|
| 95 |
+
}
|
| 96 |
+
],
|
| 97 |
+
config=types.GenerateContentConfig(temperature=0.0)
|
| 98 |
+
)
|
| 99 |
+
if resp and resp.text:
|
| 100 |
+
classification = resp.text.strip().upper()
|
| 101 |
+
if "THREE" in classification:
|
| 102 |
+
return "THREE_COLUMN"
|
| 103 |
+
elif "TWO" in classification:
|
| 104 |
+
return "TWO_COLUMN"
|
| 105 |
+
return "NO_TABLE"
|
| 106 |
+
except Exception as e:
|
| 107 |
+
error_msg = str(e)
|
| 108 |
+
logger.error(f"Gemini table classification error: {error_msg}")
|
| 109 |
+
if "503" in error_msg:
|
| 110 |
+
return "NO_TABLE"
|
| 111 |
+
if attempt < max_retries:
|
| 112 |
+
logger.warning("Retrying classification due to error... attempt %d", attempt + 1)
|
| 113 |
+
time.sleep(0.5)
|
| 114 |
+
else:
|
| 115 |
+
return "NO_TABLE"
|
| 116 |
+
|
| 117 |
+
async def classify_image_async(image_data: bytes, api_key: str, max_retries: int = 1) -> str:
|
| 118 |
+
"""
|
| 119 |
+
Asynchronous wrapper for image classification.
|
| 120 |
+
"""
|
| 121 |
+
loop = asyncio.get_event_loop()
|
| 122 |
+
preprocessed = preprocess_image(image_data)
|
| 123 |
+
return await loop.run_in_executor(None, call_gemini_for_table_classification, preprocessed, api_key, max_retries)
|
| 124 |
|
| 125 |
def unify_whitespace(text: str) -> str:
|
| 126 |
return re.sub(r"\s+", " ", text).strip().lower()
|
| 127 |
|
| 128 |
+
def find_all_occurrences(pdf_bytes: bytes, search_text: str) -> List[int]:
|
| 129 |
+
doc = fitz.open(stream=pdf_bytes, filetype="pdf")
|
| 130 |
+
st_norm = unify_whitespace(search_text)
|
| 131 |
+
found = []
|
| 132 |
+
for i in range(doc.page_count):
|
| 133 |
+
raw = doc[i].get_text("raw")
|
| 134 |
+
norm = unify_whitespace(raw)
|
| 135 |
+
if st_norm in norm:
|
| 136 |
+
found.append(i)
|
| 137 |
+
doc.close()
|
| 138 |
+
return sorted(found)
|
| 139 |
+
|
| 140 |
def create_subset_pdf(original_pdf_bytes: bytes, page_indices: List[int]) -> bytes:
|
| 141 |
if not page_indices:
|
| 142 |
raise ValueError("No page indices provided for subset creation.")
|
|
|
|
| 153 |
doc.close()
|
| 154 |
return subset_bytes
|
| 155 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
class GeminiTopicExtractor:
|
| 157 |
def __init__(self, api_key: str = None, num_pages: int = 10):
|
| 158 |
self.api_key = api_key or os.getenv("GEMINI_API_KEY", "")
|
| 159 |
self.num_pages = num_pages
|
| 160 |
|
| 161 |
def extract_subtopics(self, pdf_path: str) -> Dict[str, List[int]]:
|
|
|
|
|
|
|
|
|
|
| 162 |
first_pages_text = self._read_first_pages_raw(pdf_path, self.num_pages)
|
| 163 |
if not first_pages_text.strip():
|
| 164 |
logger.error("No text from first pages => cannot extract subtopics.")
|
| 165 |
return {}
|
| 166 |
+
|
| 167 |
prompt = f"""
|
| 168 |
You have the first pages of a PDF specification, including a table of contents.
|
| 169 |
|
| 170 |
Instructions:
|
| 171 |
1. Identify the 'Contents' section listing all topics, subtopics, and their corresponding pages.
|
| 172 |
+
2. Identify the major academic subtopics (common desired topic names "Paper X", "Theme X", "Content of X", "AS Unit X", "A2 Unit X", or similar headings).
|
| 173 |
3. For each subtopic, give the range of pages [start_page, end_page] (1-based) from the table of contents.
|
| 174 |
4. Output only valid JSON of the form:
|
| 175 |
{{
|
|
|
|
| 179 |
5. If you can't find any subtopics, return an empty JSON.
|
| 180 |
|
| 181 |
Important notes:
|
| 182 |
+
- The correct "end_page" must be the page number of the next topic or subtopic minus 1.
|
| 183 |
+
- The final output must be valid JSON only, with no extra text or code blocks.
|
| 184 |
|
| 185 |
Examples:
|
| 186 |
|
|
|
|
| 254 |
"Theme 4: A global perspective": [29, 38]
|
| 255 |
}}
|
| 256 |
|
| 257 |
+
3. You might also see sections like:
|
| 258 |
+
|
| 259 |
+
2.1 AS Unit 1 11
|
| 260 |
+
2.2 AS Unit 2 18
|
| 261 |
+
2.3 A2 Unit 3 24
|
| 262 |
+
2.4 A2 Unit 4 31
|
| 263 |
+
|
| 264 |
+
In that scenario, your output might look like:
|
| 265 |
+
|
| 266 |
+
{{
|
| 267 |
+
"AS Unit 1": [11, 17],
|
| 268 |
+
"AS Unit 2": [18, 23],
|
| 269 |
+
"A2 Unit 3": [24, 30],
|
| 270 |
+
"A2 Unit 4": [31, 35]
|
| 271 |
+
}}
|
| 272 |
+
|
| 273 |
+
4. Another example might list subtopics:
|
| 274 |
+
|
| 275 |
+
3.1 Overarching themes 11
|
| 276 |
+
3.2 A: Proof 12
|
| 277 |
+
3.3 B: Algebra and functions 13
|
| 278 |
+
3.4 C: Coordinate geometry in the ( x , y ) plane 14
|
| 279 |
+
3.5 D: Sequences and series 15
|
| 280 |
+
3.6 E: Trigonometry 16
|
| 281 |
+
3.7 F: Exponentials and logarithms 17
|
| 282 |
+
3.8 G: Differentiation 18
|
| 283 |
+
3.9 H: Integration 19
|
| 284 |
+
3.10 I: Numerical methods 20
|
| 285 |
+
3.11 J: Vectors 20
|
| 286 |
+
3.12 K: Statistical sampling 21
|
| 287 |
+
3.13 L: Data presentation and interpretation 21
|
| 288 |
+
3.14 M: Probability 22
|
| 289 |
+
3.15 N: Statistical distributions 23
|
| 290 |
+
3.16 O: Statistical hypothesis testing 23
|
| 291 |
+
3.17 P: Quantities and units in mechanics 24
|
| 292 |
+
3.18 Q: Kinematics 24
|
| 293 |
+
3.19 R: Forces and Newton’s laws 24
|
| 294 |
+
3.20 S: Moments 25
|
| 295 |
+
3.21 Use of data in statistics 26
|
| 296 |
+
|
| 297 |
+
Here the correct output might look like:
|
| 298 |
+
|
| 299 |
+
{{
|
| 300 |
+
"A: Proof": [12, 12],
|
| 301 |
+
"B: Algebra and functions": [13, 13],
|
| 302 |
+
...
|
| 303 |
+
}}
|
| 304 |
+
|
| 305 |
Now, extract topics from this text:
|
| 306 |
{first_pages_text}
|
| 307 |
"""
|
|
|
|
| 321 |
|
| 322 |
raw_json = response.text.strip()
|
| 323 |
cleaned = raw_json.replace("```json", "").replace("```", "")
|
| 324 |
+
try:
|
| 325 |
+
data = json.loads(cleaned)
|
| 326 |
+
except Exception as json_err:
|
| 327 |
+
logger.error(f"JSON parsing error: {json_err}")
|
| 328 |
+
return {}
|
|
|
|
|
|
|
| 329 |
final_dict = {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 330 |
found_sub_dict = None
|
| 331 |
for k, v in data.items():
|
| 332 |
if isinstance(v, dict):
|
|
|
|
| 333 |
found_sub_dict = v
|
| 334 |
break
|
| 335 |
if found_sub_dict is not None:
|
|
|
|
| 337 |
if isinstance(rng, list) and len(rng) == 2:
|
| 338 |
final_dict[subk] = rng
|
| 339 |
else:
|
|
|
|
|
|
|
| 340 |
for subk, rng in data.items():
|
| 341 |
if isinstance(rng, list) and len(rng) == 2:
|
| 342 |
final_dict[subk] = rng
|
|
|
|
| 358 |
logger.error(f"Could not open PDF: {e}")
|
| 359 |
return "\n".join(text_parts)
|
| 360 |
|
| 361 |
+
class LocalImageWriter(DataWriter):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 362 |
def __init__(self, output_folder: str, gemini_api_key: str):
|
| 363 |
self.output_folder = output_folder
|
| 364 |
os.makedirs(self.output_folder, exist_ok=True)
|
|
|
|
|
|
|
| 365 |
self.descriptions = {}
|
| 366 |
self._img_count = 0
|
| 367 |
self.gemini_api_key = gemini_api_key
|
| 368 |
|
| 369 |
def write(self, path: str, data: bytes) -> None:
|
| 370 |
self._img_count += 1
|
| 371 |
+
unique_id = f"img_{self._img_count}.jpg"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 372 |
self.descriptions[path] = {
|
| 373 |
"data": data,
|
| 374 |
+
"relative_path": unique_id,
|
| 375 |
"table_classification": "NO_TABLE",
|
| 376 |
"final_alt": ""
|
| 377 |
}
|
| 378 |
|
| 379 |
+
async def post_process_async(self, key: str, md_content: str) -> str:
|
| 380 |
+
logger.info("Classifying images to detect tables (async).")
|
| 381 |
+
tasks = []
|
| 382 |
+
for p, info in self.descriptions.items():
|
| 383 |
+
tasks.append((p, classify_image_async(info["data"], self.gemini_api_key)))
|
| 384 |
+
for p, task in tasks:
|
| 385 |
+
try:
|
| 386 |
+
classification = await task
|
| 387 |
+
self.descriptions[p]['table_classification'] = classification
|
| 388 |
+
except Exception as e:
|
| 389 |
+
logger.error(f"Table classification error: {e}")
|
| 390 |
+
self.descriptions[p]['table_classification'] = "NO_TABLE"
|
|
|
|
| 391 |
for p, info in self.descriptions.items():
|
| 392 |
cls = info['table_classification']
|
| 393 |
if cls == "TWO_COLUMN":
|
|
|
|
| 396 |
info['final_alt'] = "HAS TO BE PROCESSED - three column table"
|
| 397 |
else:
|
| 398 |
info['final_alt'] = "NO_TABLE image"
|
|
|
|
|
|
|
| 399 |
for p, info in self.descriptions.items():
|
| 400 |
old_md = f""
|
| 401 |
new_md = f"![{info['final_alt']}]({info['relative_path']})"
|
| 402 |
md_content = md_content.replace(old_md, new_md)
|
| 403 |
|
|
|
|
| 404 |
md_content = self._process_table_images_in_markdown(md_content)
|
|
|
|
|
|
|
| 405 |
final_lines = []
|
| 406 |
for line in md_content.split("\n"):
|
| 407 |
if re.match(r"^\!\[.*\]\(.*\)", line.strip()):
|
| 408 |
final_lines.append(line.strip())
|
| 409 |
return "\n".join(final_lines)
|
| 410 |
|
| 411 |
+
def post_process(self, key: str, md_content: str) -> str:
|
| 412 |
+
"""
|
| 413 |
+
Synchronous wrapper around the asynchronous post_process_async.
|
| 414 |
+
"""
|
| 415 |
+
return asyncio.run(self.post_process_async(key, md_content))
|
| 416 |
+
|
| 417 |
def _process_table_images_in_markdown(self, md_content: str) -> str:
|
| 418 |
pat = r"!\[HAS TO BE PROCESSED - (two|three) column table\]\(([^)]+)\)"
|
| 419 |
matches = re.findall(pat, md_content, flags=re.IGNORECASE)
|
| 420 |
if not matches:
|
| 421 |
return md_content
|
| 422 |
+
for (col_type, image_id) in matches:
|
| 423 |
+
logger.info(f"Processing table image => {image_id}, columns={col_type}")
|
| 424 |
+
temp_path = os.path.join(self.output_folder, image_id)
|
| 425 |
+
desc_item = None
|
| 426 |
+
for k, val in self.descriptions.items():
|
| 427 |
+
if val["relative_path"] == image_id:
|
| 428 |
+
desc_item = val
|
| 429 |
+
break
|
| 430 |
+
if not desc_item:
|
| 431 |
+
logger.warning(f"No matching image data for {image_id}, skipping extraction.")
|
| 432 |
+
continue
|
| 433 |
+
if not os.path.exists(temp_path):
|
| 434 |
+
with open(temp_path, "wb") as f:
|
| 435 |
+
f.write(desc_item["data"])
|
| 436 |
try:
|
| 437 |
if col_type.lower() == 'two':
|
| 438 |
extractor = TableExtractor(
|
|
|
|
| 448 |
enable_subtopic_merge=False,
|
| 449 |
subtopic_threshold=0.2
|
| 450 |
)
|
| 451 |
+
row_boxes = extractor.process_image(temp_path)
|
| 452 |
+
|
| 453 |
+
out_folder = temp_path + "_rows"
|
| 454 |
os.makedirs(out_folder, exist_ok=True)
|
| 455 |
+
|
| 456 |
+
extractor.save_extracted_cells(temp_path, row_boxes, out_folder)
|
| 457 |
|
| 458 |
snippet = ["**Extracted table cells:**"]
|
| 459 |
for i, row in enumerate(row_boxes):
|
| 460 |
row_dir = os.path.join(out_folder, f"row_{i}")
|
| 461 |
for j, _ in enumerate(row):
|
| 462 |
+
cell_file = f"col_{j}.jpg"
|
| 463 |
cell_path = os.path.join(row_dir, cell_file)
|
| 464 |
relp = os.path.relpath(cell_path, self.output_folder)
|
| 465 |
snippet.append(f"")
|
|
|
|
| 466 |
new_snip = "\n".join(snippet)
|
| 467 |
+
|
| 468 |
+
old_line = f""
|
| 469 |
+
|
| 470 |
md_content = md_content.replace(old_line, new_snip)
|
| 471 |
except Exception as e:
|
| 472 |
+
logger.error(f"Error processing table image {image_id}: {e}")
|
| 473 |
+
finally:
|
| 474 |
+
if os.path.exists(temp_path):
|
| 475 |
+
os.remove(temp_path)
|
| 476 |
return md_content
|
| 477 |
|
| 478 |
class MineruNoTextProcessor:
|
|
|
|
| 495 |
logger.error(f"Error during GPU cleanup: {e}")
|
| 496 |
|
| 497 |
def process(self, pdf_path: str) -> str:
|
| 498 |
+
logger.info(f"Processing PDF: {pdf_path}")
|
| 499 |
try:
|
|
|
|
| 500 |
subtopics = self.subtopic_extractor.extract_subtopics(pdf_path)
|
|
|
|
| 501 |
|
| 502 |
+
logger.info(f"Gemini returned subtopics: {subtopics}")
|
| 503 |
with open(pdf_path, "rb") as f:
|
| 504 |
pdf_bytes = f.read()
|
| 505 |
+
|
| 506 |
doc = fitz.open(stream=pdf_bytes, filetype="pdf")
|
| 507 |
total_pages = doc.page_count
|
| 508 |
doc.close()
|
|
|
|
| 512 |
logger.warning("No subtopics found. We'll process the entire PDF as fallback.")
|
| 513 |
final_pages = set(range(total_pages))
|
| 514 |
else:
|
|
|
|
| 515 |
for subname, rng in subtopics.items():
|
| 516 |
if not (isinstance(rng, list) and len(rng) == 2):
|
| 517 |
logger.warning(f"Skipping subtopic '{subname}' => invalid range {rng}")
|
| 518 |
continue
|
| 519 |
start_p, end_p = rng
|
| 520 |
if start_p > end_p:
|
| 521 |
+
logger.warning(f"Skipping subtopic '{subname}' => start > end {rng}")
|
| 522 |
continue
|
|
|
|
|
|
|
| 523 |
occs = find_all_occurrences(pdf_bytes, subname)
|
| 524 |
logger.info(f"Occurrences of subtopic '{subname}': {occs}")
|
| 525 |
doc_start_0 = start_p - 1
|
|
|
|
| 529 |
chosen_page = p
|
| 530 |
break
|
| 531 |
if chosen_page is None:
|
|
|
|
| 532 |
if occs:
|
| 533 |
chosen_page = occs[-1]
|
| 534 |
logger.warning(f"No occurrence >= {doc_start_0} for '{subname}'. Using last => {chosen_page}")
|
|
|
|
| 545 |
e0 = max(0, min(total_pages - 1, e0))
|
| 546 |
for pp in range(s0, e0 + 1):
|
| 547 |
final_pages.add(pp)
|
|
|
|
| 548 |
if not final_pages:
|
| 549 |
logger.warning("No valid pages after offset. We'll process entire PDF.")
|
| 550 |
final_pages = set(range(total_pages))
|
|
|
|
| 551 |
logger.info(f"Processing pages (0-based): {sorted(final_pages)}")
|
| 552 |
+
|
| 553 |
subset_pdf_bytes = create_subset_pdf(pdf_bytes, sorted(final_pages))
|
| 554 |
+
|
|
|
|
| 555 |
dataset = PymuDocDataset(subset_pdf_bytes)
|
| 556 |
inference = doc_analyze(
|
| 557 |
dataset,
|
|
|
|
| 562 |
table_enable=self.table_enable
|
| 563 |
)
|
| 564 |
logger.info("doc_analyze complete. Extracting images.")
|
|
|
|
| 565 |
writer = LocalImageWriter(self.output_folder, self.gemini_api_key)
|
| 566 |
pipe_result = inference.pipe_ocr_mode(writer, lang=self.language)
|
| 567 |
md_content = pipe_result.get_markdown("local-unique-prefix/")
|
| 568 |
+
|
| 569 |
final_markdown = writer.post_process("local-unique-prefix/", md_content)
|
|
|
|
| 570 |
out_path = os.path.join(self.output_folder, "final_output.md")
|
| 571 |
+
|
| 572 |
with open(out_path, "w", encoding="utf-8") as f:
|
| 573 |
f.write(final_markdown)
|
| 574 |
+
|
| 575 |
return final_markdown
|
| 576 |
+
|
| 577 |
finally:
|
| 578 |
self.cleanup_gpu()
|
| 579 |
|
|
|
|
| 584 |
try:
|
| 585 |
processor = MineruNoTextProcessor(output_folder=output_dir, gemini_api_key=gemini_key)
|
| 586 |
md_output = processor.process(input_pdf)
|
|
|
|
|
|
|
| 587 |
except Exception as e:
|
| 588 |
logger.error(f"Processing failed: {e}")
|