|
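"""Topic extraction pipeline for curriculum-specification PDFs.

High-level flow: Gemini locates the page ranges of the main subtopics from the table of
contents, a subset PDF is built from those pages, MinerU (magic_pdf) performs layout
analysis and image extraction, the extracted images are classified as two-/three-column
tables, table images are split into row/column cells, Gemini identifies topic and
subtopic headings per cell, and the result is written out as markdown plus a nested
subtopics JSON file.
"""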
|
|
|
|
import os |
|
|
import re |
|
|
import gc |
|
|
import json |
|
|
import logging |
|
|
import fitz |
|
|
import boto3 |
|
|
import base64 |
|
|
import time |
|
|
import asyncio |
|
|
import tempfile |
|
|
import requests |
|
|
from io import BytesIO |
|
|
from typing import List, Dict, Any |
|
|
|
|
|
import torch |
|
|
import cv2 |
|
|
import numpy as np |
|
|
|
|
|
from google import genai |
|
|
from google.genai import types |
|
|
|
|
|
from magic_pdf.data.dataset import PymuDocDataset |
|
|
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze |
|
|
from magic_pdf.data.data_reader_writer.base import DataWriter |
|
|
from table_row_extraction import TableExtractor |
|
|
|
|
|
logging.basicConfig(level=logging.INFO) |
|
|
logger = logging.getLogger(__name__) |
|
|
logger.setLevel(logging.INFO) |
|
|
file_handler = logging.FileHandler("topic_extraction.log") |
|
|
file_handler.setFormatter(logging.Formatter("%(asctime)s [%(levelname)s] %(name)s - %(message)s")) |
|
|
logger.addHandler(file_handler) |
|
|
|
|
|
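# Module-level Gemini client, created lazily on first use and reused across all API calls.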
_GEMINI_CLIENT = None |
|
|
|
|
|
def unify_whitespace(text: str) -> str: |
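    """Collapse every run of whitespace to a single space and strip the ends."""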
|
|
return re.sub(r"\s+", " ", text).strip() |
|
|
|
|
|
def find_all_occurrences(pdf_bytes: bytes, search_text: str) -> List[int]: |
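    """Return the sorted 0-based indices of pages whose (whitespace-normalised) text contains search_text."""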
|
|
doc = fitz.open(stream=pdf_bytes, filetype="pdf") |
|
|
st_norm = unify_whitespace(search_text) |
|
|
found = [] |
|
|
for i in range(doc.page_count): |
|
|
raw = doc[i].get_text("raw") |
|
|
norm = unify_whitespace(raw) |
|
|
if st_norm in norm: |
|
|
found.append(i) |
|
|
doc.close() |
|
|
return sorted(found) |
|
|
|
|
|
def create_subset_pdf(original_pdf_bytes: bytes, page_indices: List[int]) -> bytes: |
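    """Build a new PDF containing only the given 0-based pages, in ascending order; raises ValueError on empty or out-of-range indices."""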
|
|
if not page_indices: |
|
|
raise ValueError("No page indices provided for subset creation.") |
|
|
doc = fitz.open(stream=original_pdf_bytes, filetype="pdf") |
|
|
new_doc = fitz.open() |
|
|
for p in sorted(set(page_indices)): |
|
|
if 0 <= p < doc.page_count: |
|
|
new_doc.insert_pdf(doc, from_page=p, to_page=p) |
|
|
else: |
|
|
logger.error(f"Page index {p} out of range (0..{doc.page_count - 1}).") |
|
|
raise ValueError(f"Page index {p} out of range.") |
|
|
subset_bytes = new_doc.tobytes() |
|
|
new_doc.close() |
|
|
doc.close() |
|
|
return subset_bytes |
|
|
|
|
|
class s3Writer: |
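    """Minimal boto3 wrapper for uploading in-memory bytes to an S3-compatible bucket."""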
|
|
def __init__(self, ak: str, sk: str, bucket: str, endpoint_url: str): |
|
|
self.bucket = bucket |
|
|
self.client = boto3.client( |
|
|
's3', |
|
|
aws_access_key_id=ak, |
|
|
aws_secret_access_key=sk, |
|
|
endpoint_url=endpoint_url |
|
|
) |
|
|
|
|
|
def write(self, path: str, data: bytes) -> None: |
|
|
try: |
|
|
file_obj = BytesIO(data) |
|
|
self.client.upload_fileobj( |
|
|
file_obj, |
|
|
self.bucket, |
|
|
path |
|
|
) |
|
|
logger.info(f"Uploaded to S3: {path}") |
|
|
except Exception as e: |
|
|
logger.error(f"Failed to upload to S3: {str(e)}") |
|
|
raise |
|
|
|
|
|
def preprocess_image(image_data: bytes, max_dim: int = 600, quality: int = 60) -> bytes: |
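    """Downscale and JPEG-recompress an image to shrink the Gemini payload; falls back to the original bytes if decoding or encoding fails."""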
|
|
arr = np.frombuffer(image_data, np.uint8) |
|
|
img = cv2.imdecode(arr, cv2.IMREAD_COLOR) |
|
|
if img is not None: |
|
|
h, w, _ = img.shape |
|
|
if max(h, w) > max_dim: |
|
|
scale = max_dim / float(max(h, w)) |
|
|
new_w = int(w * scale) |
|
|
new_h = int(h * scale) |
|
|
img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA) |
|
|
encode_params = [int(cv2.IMWRITE_JPEG_QUALITY), quality] |
|
|
success, enc = cv2.imencode(".jpg", img, encode_params) |
|
|
if success: |
|
|
return enc.tobytes() |
|
|
return image_data |
|
|
|
|
|
def call_gemini_for_table_classification(image_data: bytes, api_key: str, max_retries: int = 1) -> str: |
|
|
""" |
|
|
    Call Gemini to classify an image as TWO_COLUMN, THREE_COLUMN, or NO_TABLE.
|
|
""" |
|
|
for attempt in range(max_retries + 1): |
|
|
try: |
|
|
prompt = """You are given an image. Determine if it shows a table that has exactly 2 or 3 columns. |
|
|
A three-column table image has the following key features:
|
|
- A three-column header
|
|
- Headers like 'Topics', 'Content', 'Guidelines' |
|
|
- Possibly sections (e.g. 8.4, 9.1) |
|
|
A two-column table image has the following key features:
|
|
- Two columns |
|
|
- Headers like 'Subject content' and 'Additional information' |
|
|
- Possibly sections (e.g. 2.1, 3.4) |
|
|
If the image is a relevant table with 2 columns, respond with 'TWO_COLUMN'. |
|
|
If the image is a relevant table with 3 columns, respond with 'THREE_COLUMN'. |
|
|
If the image does not show a table at all, respond with 'NO_TABLE'. |
|
|
Return only one of these exact labels. |
|
|
""" |
|
|
global _GEMINI_CLIENT |
|
|
if _GEMINI_CLIENT is None: |
|
|
_GEMINI_CLIENT = genai.Client(api_key=api_key) |
|
|
client = _GEMINI_CLIENT |
|
|
|
|
|
resp = client.models.generate_content( |
|
|
model="gemini-2.0-flash", |
|
|
contents=[ |
|
|
{ |
|
|
"parts": [ |
|
|
{"text": prompt}, |
|
|
{ |
|
|
"inline_data": { |
|
|
"mime_type": "image/jpeg", |
|
|
"data": base64.b64encode(image_data).decode('utf-8') |
|
|
} |
|
|
} |
|
|
] |
|
|
} |
|
|
], |
|
|
config=types.GenerateContentConfig(temperature=0.0) |
|
|
) |
|
|
if resp and resp.text: |
|
|
classification = resp.text.strip().upper() |
|
|
if "THREE" in classification: |
|
|
return "THREE_COLUMN" |
|
|
elif "TWO" in classification: |
|
|
return "TWO_COLUMN" |
|
|
return "NO_TABLE" |
|
|
except Exception as e: |
|
|
logger.error(f"Gemini table classification error: {e}") |
|
|
if "503" in str(e): |
|
|
return "NO_TABLE" |
|
|
            if attempt < max_retries:
                time.sleep(0.5)
            else:
                return "NO_TABLE"
    # Fallback in case Gemini returns an empty response on every attempt.
    return "NO_TABLE"
|
|
|
|
|
async def classify_image_async(image_data: bytes, api_key: str, max_retries: int = 1) -> str: |
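    """Preprocess the image and run the blocking Gemini classification call in the default thread-pool executor."""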
|
|
    loop = asyncio.get_running_loop()
|
|
preprocessed = preprocess_image(image_data) |
|
|
return await loop.run_in_executor(None, call_gemini_for_table_classification, preprocessed, api_key, max_retries) |
|
|
|
|
|
def call_gemini_for_subtopic_identification_image(image_data: bytes, api_key: str, max_retries: int = 1) -> dict: |
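    """Ask Gemini whether a table-cell image shows a main topic heading or subtopic numbers; returns {"title": str, "subtopics": [str]}."""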
|
|
for attempt in range(max_retries + 1): |
|
|
try: |
|
|
prompt = """ |
|
|
You are given an image from an educational curriculum specification. The image may contain one of the following:
|
|
1) A main topic heading in the format: "<number> <Topic Name>", for example "2 Algebra and functions continued". |
|
|
2) A subtopic heading in the format "<number>.<number>", for example "2.5", "2.6", or "3.4". |
|
|
3) Possibly no relevant text at all. |
|
|
|
|
|
Your task: |
|
|
1. If the cell shows a main topic, extract the topic name (e.g. "2 Algebra and functions") and place it in the JSON key "title". |
|
|
2. If the cell shows one or more subtopic numbers (e.g. "2.5", "2.6"), collect them in the JSON key "subtopics" as an array of strings. |
|
|
3. If neither a main topic nor subtopic is detected, return empty values. |
|
|
|
|
|
Output only valid JSON in this exact structure, with no extra text or explanation: |
|
|
|
|
|
|
{ |
|
|
"title": "...", |
|
|
"subtopics": [...] |
|
|
} |
|
|
|
|
|
Where: |
|
|
- "title" is the recognized main topic (if any). Otherwise, an empty string. |
|
|
- "subtopics" is an array of recognized subtopic numbers (e.g. ["2.5", "2.6"]). Otherwise, an empty array. |
|
|
|
|
|
Examples: |
|
|
1. If the image text is "2 Algebra and functions continued", return: |
|
|
{ |
|
|
"title": "2 Algebra and functions continued", |
|
|
"subtopics": [] |
|
|
} |
|
|
|
|
|
2. If the image text is "2.5 Solve linear and quadratic inequalities ...", return: |
|
|
{ |
|
|
"title": "", |
|
|
"subtopics": ["2.5"] |
|
|
} |
|
|
|
|
|
3. If the image text is "2.6 Manipulate polynomials algebraically ...", return: |
|
|
{ |
|
|
"title": "", |
|
|
"subtopics": ["2.6"] |
|
|
} |
|
|
|
|
|
If you cannot recognize any text matching these patterns, or if nothing is found, return: |
|
|
{ |
|
|
"title": "", |
|
|
"subtopics": [] |
|
|
} |
|
|
""" |
|
|
global _GEMINI_CLIENT |
|
|
if _GEMINI_CLIENT is None: |
|
|
_GEMINI_CLIENT = genai.Client(api_key=api_key) |
|
|
client = _GEMINI_CLIENT |
|
|
|
|
|
resp = client.models.generate_content( |
|
|
model="gemini-2.0-flash", |
|
|
contents=[ |
|
|
{ |
|
|
"parts": [ |
|
|
{"text": prompt}, |
|
|
{ |
|
|
"inline_data": { |
|
|
"mime_type": "image/jpeg", |
|
|
"data": base64.b64encode(image_data).decode("utf-8") |
|
|
} |
|
|
} |
|
|
] |
|
|
} |
|
|
], |
|
|
config=types.GenerateContentConfig(temperature=0.0) |
|
|
) |
|
|
|
|
|
|
|
|
if not resp or not resp.text: |
|
|
logger.warning("Gemini returned an empty response for subtopic extraction.") |
|
|
return {"title": "", "subtopics": []} |
|
|
|
|
|
raw = resp.text.strip() |
|
|
raw = raw.replace("```json", "").replace("```", "").strip() |
|
|
data = json.loads(raw) |
|
|
|
|
|
title = data.get("title", "") |
|
|
subtopics = data.get("subtopics", []) |
|
|
if not isinstance(subtopics, list): |
|
|
subtopics = [] |
|
|
return {"title": title, "subtopics": subtopics} |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"Gemini subtopic identification error on attempt {attempt}: {e}") |
|
|
if attempt < max_retries: |
|
|
time.sleep(0.5) |
|
|
else: |
|
|
return {"title": "", "subtopics": []} |
|
|
|
|
|
return {"title": "", "subtopics": []} |
|
|
|
|
|
class S3ImageWriter(DataWriter): |
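    """MinerU image writer backed by S3: uploads each extracted image, classifies it as a
    two-/three-column table or not, splits table images into cells, and accumulates the
    recognised topic/subtopic structure in `extracted_subtopics`."""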
|
|
def __init__(self, s3_writer: s3Writer, base_path: str, gemini_api_key: str): |
|
|
self.s3_writer = s3_writer |
|
|
self.base_path = base_path if base_path.endswith("/") else base_path + "/" |
|
|
self.gemini_api_key = gemini_api_key |
|
|
self.descriptions = {} |
|
|
self._img_count = 0 |
|
|
self.extracted_tables = {} |
|
|
|
|
|
self.extracted_subtopics = {} |
|
|
|
|
|
def write(self, path: str, data: bytes) -> None: |
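        """Upload the image bytes to S3 under a sequential key and record them for post-processing."""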
|
|
self._img_count += 1 |
|
|
unique_id = f"img_{self._img_count}.jpg" |
|
|
s3_key = f"{self.base_path}{unique_id}" |
|
|
self.s3_writer.write(s3_key, data) |
|
|
self.descriptions[path] = { |
|
|
"data": data, |
|
|
"s3_path": s3_key, |
|
|
"table_classification": "NO_TABLE", |
|
|
"final_alt": "" |
|
|
} |
|
|
|
|
|
async def post_process_async(self, key: str, md_content: str) -> str: |
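        """Classify uploaded images, tag table images, expand them into per-cell references, and return a markdown containing only image lines."""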
|
|
logger.info("Classifying images to detect tables.") |
|
|
tasks = { |
|
|
p: asyncio.create_task(classify_image_async(info["data"], self.gemini_api_key)) |
|
|
for p, info in self.descriptions.items() |
|
|
} |
|
|
results = await asyncio.gather(*tasks.values(), return_exceptions=True) |
|
|
for p, result in zip(tasks.keys(), results): |
|
|
if isinstance(result, Exception): |
|
|
logger.error(f"Table classification error for {p}: {result}") |
|
|
self.descriptions[p]['table_classification'] = "NO_TABLE" |
|
|
else: |
|
|
self.descriptions[p]['table_classification'] = result |
|
|
|
|
|
for p, info in self.descriptions.items(): |
|
|
cls = info['table_classification'] |
|
|
if cls == "TWO_COLUMN": |
|
|
info['final_alt'] = "HAS TO BE PROCESSED - two column table" |
|
|
elif cls == "THREE_COLUMN": |
|
|
info['final_alt'] = "HAS TO BE PROCESSED - three column table" |
|
|
else: |
|
|
info['final_alt'] = "NO_TABLE image" |
|
|
            # Reconstructed replacement target (the original literal was lost); this assumes
            # the MinerU markdown references each image as "![](<write path>)".
            md_content = md_content.replace(f"![]({p})", f"![{info['final_alt']}]({info['s3_path']})")
|
|
|
|
|
md_content = await self._process_table_images_in_markdown(key, md_content) |
|
|
|
|
|
|
|
|
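        # Keep only the image-reference lines; any other OCR text in the MinerU markdown is dropped.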
final_lines = [ |
|
|
line.strip() for line in md_content.split("\n") |
|
|
if re.match(r"^\!\[.*\]\(.*\)", line.strip()) |
|
|
] |
|
|
return "\n".join(final_lines) |
|
|
|
|
|
async def _process_table_images_in_markdown(self, key: str, md_content: str) -> str: |
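        """Split each tagged table image into cells, upload the cells, run subtopic identification on each, and rewrite the markdown to point at the cell images."""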
|
|
pat = r"!\[HAS TO BE PROCESSED - (two|three) column table\]\(([^)]+)\)" |
|
|
matches = re.findall(pat, md_content, flags=re.IGNORECASE) |
|
|
if not matches: |
|
|
return md_content |
|
|
|
|
|
for (col_type, s3_key) in matches: |
|
|
logger.info(f"Processing table image: {s3_key}, columns={col_type}") |
|
|
img_data = None |
|
|
for desc in self.descriptions.values(): |
|
|
if desc.get("s3_path") == s3_key: |
|
|
img_data = desc.get("data") |
|
|
break |
|
|
if img_data is None: |
|
|
logger.warning(f"No image data found for S3 key {s3_key}. Skipping.") |
|
|
continue |
|
|
|
|
|
|
|
|
with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file: |
|
|
temp_file.write(img_data) |
|
|
temp_path = temp_file.name |
|
|
|
|
|
try: |
|
|
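                # Two-column tables use the row-merging and subtopic-merge heuristics;
                # three-column tables are split strictly into rows and cells.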
if col_type.lower() == 'two': |
|
|
extractor = TableExtractor( |
|
|
skip_header=True, |
|
|
merge_two_col_rows=True, |
|
|
enable_subtopic_merge=True, |
|
|
subtopic_threshold=0.2 |
|
|
) |
|
|
else: |
|
|
extractor = TableExtractor( |
|
|
skip_header=True, |
|
|
merge_two_col_rows=False, |
|
|
enable_subtopic_merge=False, |
|
|
subtopic_threshold=0.2 |
|
|
) |
|
|
row_boxes = extractor.process_image(temp_path) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
out_folder = temp_path + "_rows" |
|
|
os.makedirs(out_folder, exist_ok=True) |
|
|
|
|
|
|
|
|
|
|
|
extractor.save_extracted_cells(temp_path, row_boxes, out_folder) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
recognized_main_topic = "" |
|
|
main_topic_image_key = None |
|
|
recognized_subtopics = [] |
|
|
|
|
|
|
|
|
for i, row in enumerate(row_boxes): |
|
|
row_dir = os.path.join(out_folder, f"row_{i}") |
|
|
for j, _ in enumerate(row): |
|
|
cell_path = os.path.join(row_dir, f"col_{j}.png") |
|
|
if not os.path.isfile(cell_path): |
|
|
alternative_path = os.path.join(row_dir, f"col_{j}.jpg") |
|
|
if os.path.isfile(alternative_path): |
|
|
cell_path = alternative_path |
|
|
else: |
|
|
logger.warning(f"Cell image not found: {cell_path}") |
|
|
continue |
|
|
|
|
|
with open(cell_path, "rb") as cf: |
|
|
cell_image_data = cf.read() |
|
|
|
|
|
|
|
|
cell_key = f"{self.base_path}cells/{os.path.basename(s3_key)}_r{i}_c{j}.png" |
|
|
self.s3_writer.write(cell_key, cell_image_data) |
|
|
|
|
|
info = call_gemini_for_subtopic_identification_image(cell_image_data, self.gemini_api_key) |
|
|
|
|
|
|
|
|
if info["title"] and not recognized_main_topic: |
|
|
recognized_main_topic = info["title"] |
|
|
main_topic_image_key = cell_key |
|
|
|
|
|
for st in info["subtopics"]: |
|
|
recognized_subtopics.append({ |
|
|
"title": st, |
|
|
"contents": [{"type": "image", "key": cell_key}], |
|
|
"children": [] |
|
|
}) |
|
|
|
|
|
final_json = { |
|
|
"title": recognized_main_topic, |
|
|
"contents": [], |
|
|
"children": recognized_subtopics |
|
|
} |
|
|
if main_topic_image_key: |
|
|
final_json["contents"].append({"type": "image", "key": main_topic_image_key}) |
|
|
|
|
|
|
|
|
self.extracted_subtopics[s3_key] = final_json |
|
|
|
|
|
|
|
|
snippet = ["**Extracted table cells:**"] |
|
|
for i, row in enumerate(row_boxes): |
|
|
for j, _ in enumerate(row): |
|
|
                        # Reconstructed cell reference (original literal was lost); assumes the same
                        # key layout as cell_key above (.png cells under "<base_path>cells/").
                        snippet.append(f"![Table cell r{i} c{j}]({self.base_path}cells/{os.path.basename(s3_key)}_r{i}_c{j}.png)")
|
|
new_snip = "\n".join(snippet) |
|
|
                # Reconstructed: the placeholder line emitted during classification in post_process_async.
                old_line = f"![HAS TO BE PROCESSED - {col_type} column table]({s3_key})"
|
|
md_content = md_content.replace(old_line, new_snip) |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"Error processing table image {s3_key}: {e}") |
|
|
finally: |
|
|
os.remove(temp_path) |
|
|
|
|
|
return md_content |
|
|
|
|
|
def post_process(self, key: str, md_content: str) -> str: |
|
|
return asyncio.run(self.post_process_async(key, md_content)) |
|
|
|
|
|
class LocalImageWriter(DataWriter): |
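    """Local-filesystem counterpart of S3ImageWriter: saves extracted images to a folder and
    records extracted table cells, without the per-cell subtopic identification step."""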
|
|
def __init__(self, output_folder: str, gemini_api_key: str): |
|
|
self.output_folder = output_folder |
|
|
os.makedirs(self.output_folder, exist_ok=True) |
|
|
self.descriptions = {} |
|
|
self._img_count = 0 |
|
|
self.gemini_api_key = gemini_api_key |
|
|
self.extracted_tables = {} |
|
|
|
|
|
def write(self, path: str, data: bytes) -> None: |
|
|
self._img_count += 1 |
|
|
unique_id = f"img_{self._img_count}.jpg" |
|
|
self.descriptions[path] = { |
|
|
"data": data, |
|
|
"relative_path": unique_id, |
|
|
"table_classification": "NO_TABLE", |
|
|
"final_alt": "" |
|
|
} |
|
|
image_path = os.path.join(self.output_folder, unique_id) |
|
|
with open(image_path, "wb") as f: |
|
|
f.write(data) |
|
|
|
|
|
async def post_process_async(self, key: str, md_content: str) -> str: |
|
|
logger.info("Classifying images to detect tables.") |
|
|
tasks = [] |
|
|
for p, info in self.descriptions.items(): |
|
|
tasks.append((p, classify_image_async(info["data"], self.gemini_api_key))) |
|
|
for p, task in tasks: |
|
|
try: |
|
|
classification = await task |
|
|
self.descriptions[p]['table_classification'] = classification |
|
|
except Exception as e: |
|
|
logger.error(f"Table classification error: {e}") |
|
|
self.descriptions[p]['table_classification'] = "NO_TABLE" |
|
|
for p, info in self.descriptions.items(): |
|
|
cls = info['table_classification'] |
|
|
if cls == "TWO_COLUMN": |
|
|
info['final_alt'] = "HAS TO BE PROCESSED - two column table" |
|
|
elif cls == "THREE_COLUMN": |
|
|
info['final_alt'] = "HAS TO BE PROCESSED - three column table" |
|
|
else: |
|
|
info['final_alt'] = "NO_TABLE image" |
|
|
            # Reconstructed replacement target (the original literal was lost); this assumes
            # the MinerU markdown references each image as "![](<write path>)".
            md_content = md_content.replace(f"![]({p})", f"![{info['final_alt']}]({info['relative_path']})")
|
|
md_content = self._process_table_images_in_markdown(md_content) |
|
|
final_lines = [] |
|
|
for line in md_content.split("\n"): |
|
|
if re.match(r"^\!\[.*\]\(.*\)", line.strip()): |
|
|
final_lines.append(line.strip()) |
|
|
return "\n".join(final_lines) |
|
|
|
|
|
def _process_table_images_in_markdown(self, md_content: str) -> str: |
|
|
pat = r"!\[HAS TO BE PROCESSED - (two|three) column table\]\(([^)]+)\)" |
|
|
matches = re.findall(pat, md_content, flags=re.IGNORECASE) |
|
|
if not matches: |
|
|
return md_content |
|
|
for (col_type, image_id) in matches: |
|
|
logger.info(f"Processing table image => {image_id}, columns={col_type}") |
|
|
temp_path = os.path.join(self.output_folder, image_id) |
|
|
desc_item = None |
|
|
for k, val in self.descriptions.items(): |
|
|
if val["relative_path"] == image_id: |
|
|
desc_item = val |
|
|
break |
|
|
if not desc_item: |
|
|
logger.warning(f"No matching image data for {image_id}, skipping extraction.") |
|
|
continue |
|
|
if not os.path.exists(temp_path): |
|
|
with open(temp_path, "wb") as f: |
|
|
f.write(desc_item["data"]) |
|
|
try: |
|
|
if col_type.lower() == 'two': |
|
|
extractor = TableExtractor( |
|
|
skip_header=True, |
|
|
merge_two_col_rows=True, |
|
|
enable_subtopic_merge=True, |
|
|
subtopic_threshold=0.2 |
|
|
) |
|
|
else: |
|
|
extractor = TableExtractor( |
|
|
skip_header=True, |
|
|
merge_two_col_rows=False, |
|
|
enable_subtopic_merge=False, |
|
|
subtopic_threshold=0.2 |
|
|
) |
|
|
row_boxes = extractor.process_image(temp_path) |
|
|
out_folder = temp_path + "_rows" |
|
|
os.makedirs(out_folder, exist_ok=True) |
|
|
extractor.save_extracted_cells(temp_path, row_boxes, out_folder) |
|
|
|
|
|
extracted_cells = [] |
|
|
for root, dirs, files in os.walk(out_folder): |
|
|
for file in files: |
|
|
rel_path = os.path.relpath(os.path.join(root, file), self.output_folder) |
|
|
extracted_cells.append(rel_path) |
|
|
|
|
|
self.extracted_tables[image_id] = extracted_cells |
|
|
snippet = ["**Extracted table cells:**"] |
|
|
for i, row in enumerate(row_boxes): |
|
|
row_dir = os.path.join(out_folder, f"row_{i}") |
|
|
for j, _ in enumerate(row): |
|
|
cell_file = f"col_{j}.jpg" |
|
|
cell_path = os.path.join(row_dir, cell_file) |
|
|
relp = os.path.relpath(cell_path, self.output_folder) |
|
|
                        # Reconstructed cell reference (original literal was lost).
                        snippet.append(f"![Table cell r{i} c{j}]({relp})")
|
|
new_snip = "\n".join(snippet) |
|
|
                # Reconstructed: the placeholder line emitted during classification in post_process_async.
                old_line = f"![HAS TO BE PROCESSED - {col_type} column table]({image_id})"
|
|
md_content = md_content.replace(old_line, new_snip) |
|
|
except Exception as e: |
|
|
logger.error(f"Error processing table image {image_id}: {e}") |
|
|
finally: |
|
|
if os.path.exists(temp_path): |
|
|
os.remove(temp_path) |
|
|
return md_content |
|
|
|
|
|
def post_process(self, key: str, md_content: str) -> str: |
|
|
return asyncio.run(self.post_process_async(key, md_content)) |
|
|
|
|
|
class GeminiTopicExtractor: |
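    """Reads the first pages of a specification PDF and asks Gemini to map each major
    subtopic to its 1-based [start_page, end_page] range from the table of contents."""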
|
|
def __init__(self, api_key: str = None, num_pages: int = 14): |
|
|
self.api_key = api_key or os.getenv("GEMINI_API_KEY", "") |
|
|
self.num_pages = num_pages |
|
|
|
|
|
def extract_subtopics(self, pdf_path: str) -> Dict[str, List[int]]: |
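        """Return a mapping of subtopic title -> [start_page, end_page] (1-based, as printed in the table of contents)."""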
|
|
first_pages_text = self._read_first_pages_raw(pdf_path, self.num_pages) |
|
|
if not first_pages_text.strip(): |
|
|
logger.error("No text from first pages => cannot extract subtopics.") |
|
|
return {} |
|
|
prompt = f""" |
|
|
You have the first pages of a PDF specification, including a table of contents. |
|
|
Instructions: |
|
|
1. Identify the 'Contents' section listing all topics, subtopics, and their corresponding pages. |
|
|
    2. Identify the major academic subtopics (typically headings such as "Paper X", "Theme X", "Content of X", "AS Unit X", "A2 Unit X", or similar).
|
|
3. For each subtopic, give the range of pages [start_page, end_page] (1-based) from the table of contents. |
|
|
4. Output only valid JSON of the form: |
|
|
{{ |
|
|
"Subtopic A": [start_page, end_page], |
|
|
"Subtopic B": [start_page, end_page] |
|
|
}} |
|
|
    5. If you cannot find any subtopics, return an empty JSON object.
|
|
Important notes: |
|
|
- The correct "end_page" must be the page number of the next topic or subtopic minus 1. |
|
|
- The final output must be valid JSON only, with no extra text or code blocks. |
|
|
Examples: |
|
|
1. Given this table of contents: |
|
|
1 Introduction – 2 |
|
|
Why choose Edexcel A Level Mathematics? - 2 |
|
|
Supporting you in planning and implementing this qualification - 3 |
|
|
Qualification at a glance - 5 |
|
|
2 Subject content and assessment information – 7 |
|
|
Paper 1 and Paper 2: Pure Mathematics - 11 |
|
|
Paper 3: Statistics and Mechanics - 30 |
|
|
Assessment Objectives - 40 |
|
|
3 Administration and general information – 42 |
|
|
Entries - 42 |
|
|
Access arrangements, reasonable adjustments, special consideration and malpractice - 42 |
|
|
Student recruitment and progression - 45 |
|
|
Appendix 1: Formulae – 49 |
|
|
Appendix 2: Notation – 53 |
|
|
Appendix 3: Use of calculators – 59 |
|
|
Appendix 4: Assessment Objectives – 60 |
|
|
Appendix 5: The context for the development of this qualification – 62 |
|
|
Appendix 6: Transferable skills – 64 |
|
|
Appendix 7: Level 3 Extended Project qualification – 65 |
|
|
Appendix 8: Codes – 67 |
|
|
The correct output should be: |
|
|
{{ |
|
|
"Paper 1 and Paper 2: Pure Mathematics": [11, 29], |
|
|
"Paper 3: Statistics and Mechanics": [30, 42] |
|
|
}} |
|
|
2. Given this table of contents: |
|
|
Qualification at a glance – 1 |
|
|
Assessment Objectives and weightings - 4 |
|
|
Knowledge, skills and understanding – 5 |
|
|
Theme 1: Introduction to markets and market failure - 5 |
|
|
Theme 2: The UK economy – performance and policies - 11 |
|
|
Theme 3: Business behaviour and the labour market - 21 |
|
|
Theme 4: A global perspective - 29 |
|
|
Assessment – 39 |
|
|
Assessment summary - 39 |
|
|
Assessment objectives - 41 |
|
|
Assessment overview - 42 |
|
|
Breakdown of assessment objectives - 42 |
|
|
Synoptic assessment - 43 |
|
|
Discount code and performance tables - 43 |
|
|
Access arrangements, reasonable adjustments and special consideration - 44 |
|
|
Malpractice - 45 |
|
|
Equality Act 2010 and Pearson equality policy - 45 |
|
|
Synoptic assessment - 46 |
|
|
Awarding and reporting - 47 |
|
|
Other information – 49 |
|
|
Student recruitment -49 |
|
|
Prior learning and other requirements -49 |
|
|
Progression - 49 |
|
|
Appendix 1: Transferable skills – 53 |
|
|
Appendix 2: Level 3 Extended Project qualification – 55 |
|
|
Appendix 3: Quantitative skills – 59 |
|
|
Appendix 4: Codes – 61 |
|
|
Appendix 5: Index – 63 |
|
|
The correct output should be: |
|
|
{{ |
|
|
"Theme 1: Introduction to markets and market failure": [5, 10], |
|
|
"Theme 2: The UK economy – performance and policies": [11, 20], |
|
|
"Theme 3: Business behaviour and the labour market": [21, 28], |
|
|
"Theme 4: A global perspective": [29, 38] |
|
|
}} |
|
|
3. You might also see sections like: |
|
|
2.1 AS Unit 1 11 |
|
|
2.2 AS Unit 2 18 |
|
|
2.3 A2 Unit 3 24 |
|
|
2.4 A2 Unit 4 31 |
|
|
In that scenario, your output might look like: |
|
|
{{ |
|
|
"2.1 AS Unit 1": [11, 17], |
|
|
"2.2 AS Unit 2": [18, 23], |
|
|
"2.3 A2 Unit 3": [24, 30], |
|
|
"2.4 A2 Unit 4": [31, 35] |
|
|
}} |
|
|
4. Another example might list subtopics: |
|
|
3.1 Overarching themes 11 |
|
|
3.2 A: Proof 12 |
|
|
3.3 B: Algebra and functions 13 |
|
|
3.4 C: Coordinate geometry in the ( x , y ) plane 14 |
|
|
3.5 D: Sequences and series 15 |
|
|
3.6 E: Trigonometry 16 |
|
|
3.7 F: Exponentials and logarithms 17 |
|
|
3.8 G: Differentiation 18 |
|
|
3.9 H: Integration 19 |
|
|
3.10 I: Numerical methods 20 |
|
|
3.11 J: Vectors 20 |
|
|
3.12 K: Statistical sampling 21 |
|
|
3.13 L: Data presentation and interpretation 21 |
|
|
3.14 M: Probability 22 |
|
|
3.15 N: Statistical distributions 23 |
|
|
3.16 O: Statistical hypothesis testing 23 |
|
|
3.17 P: Quantities and units in mechanics 24 |
|
|
3.18 Q: Kinematics 24 |
|
|
3.19 R: Forces and Newton’s laws 24 |
|
|
3.20 S: Moments 25 |
|
|
3.21 Use of data in statistics 26 |
|
|
Here the correct output might look like: |
|
|
{{ |
|
|
"A: Proof": [12, 12], |
|
|
"B: Algebra and functions": [13, 13], |
|
|
... |
|
|
}} |
|
|
Now, extract topics from this text: |
|
|
{first_pages_text} |
|
|
""" |
|
|
global _GEMINI_CLIENT |
|
|
if _GEMINI_CLIENT is None: |
|
|
_GEMINI_CLIENT = genai.Client(api_key=self.api_key) |
|
|
client = _GEMINI_CLIENT |
|
|
try: |
|
|
response = client.models.generate_content( |
|
|
model="gemini-2.0-flash", |
|
|
contents=[prompt], |
|
|
config=types.GenerateContentConfig(temperature=0.0) |
|
|
) |
|
|
if not response or not response.text: |
|
|
logger.warning("No text from LLM => returning empty subtopics.") |
|
|
return {} |
|
|
raw_json = response.text.strip() |
|
|
cleaned = raw_json.replace("```json", "").replace("```", "") |
|
|
try: |
|
|
data = json.loads(cleaned) |
|
|
except Exception as json_err: |
|
|
logger.error(f"JSON parsing error: {json_err}") |
|
|
return {} |
|
|
final_dict = {} |
|
|
found_sub_dict = None |
|
|
for k, v in data.items(): |
|
|
if isinstance(v, dict): |
|
|
found_sub_dict = v |
|
|
break |
|
|
if found_sub_dict is not None: |
|
|
for subk, rng in found_sub_dict.items(): |
|
|
if isinstance(rng, list) and len(rng) == 2: |
|
|
final_dict[subk] = rng |
|
|
else: |
|
|
for subk, rng in data.items(): |
|
|
if isinstance(rng, list) and len(rng) == 2: |
|
|
final_dict[subk] = rng |
|
|
return final_dict |
|
|
except Exception as e: |
|
|
logger.error(f"Gemini subtopic extraction error: {e}") |
|
|
return {} |
|
|
|
|
|
def _read_first_pages_raw(self, pdf_path: str, num_pages: int) -> str: |
|
|
text_parts = [] |
|
|
try: |
|
|
if pdf_path.startswith("http://") or pdf_path.startswith("https://"): |
|
|
response = requests.get(pdf_path) |
|
|
if response.status_code != 200: |
|
|
logger.error("Failed to download PDF from %s. Status code: %d", pdf_path, response.status_code) |
|
|
return "" |
|
|
pdf_bytes = response.content |
|
|
else: |
|
|
with open(pdf_path, "rb") as f: |
|
|
pdf_bytes = f.read() |
|
|
doc = fitz.open(stream=pdf_bytes, filetype="pdf") |
|
|
pages_to_read = min(num_pages, doc.page_count) |
|
|
for i in range(pages_to_read): |
|
|
raw_text = doc[i].get_text("raw") |
|
|
text_parts.append(raw_text) |
|
|
doc.close() |
|
|
except Exception as e: |
|
|
logger.error(f"Could not open PDF: {e}") |
|
|
return "\n".join(text_parts) |
|
|
|
|
|
class MineruNoTextProcessor: |
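    """Orchestrates the full pipeline: subtopic page ranges -> page subset -> MinerU OCR
    analysis -> table-image post-processing -> final markdown and subtopics JSON."""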
|
|
def __init__(self, output_folder: str, gemini_api_key: str): |
|
|
self.output_folder = output_folder |
|
|
os.makedirs(self.output_folder, exist_ok=True) |
|
|
self.layout_model = "doclayout_yolo" |
|
|
self.formula_enable = True |
|
|
self.table_enable = False |
|
|
self.language = "en" |
|
|
|
|
|
self.subtopic_extractor = GeminiTopicExtractor(api_key=gemini_api_key, num_pages=20) |
|
|
self.gemini_api_key = gemini_api_key or os.getenv("GEMINI_API_KEY", "") |
|
|
|
|
|
self.use_s3 = True |
|
|
self.s3_writer = s3Writer( |
|
|
ak=os.getenv("S3_ACCESS_KEY"), |
|
|
sk=os.getenv("S3_SECRET_KEY"), |
|
|
bucket="quextro-resources", |
|
|
endpoint_url=os.getenv("S3_ENDPOINT") |
|
|
) |
|
|
|
|
|
def cleanup_gpu(self): |
|
|
try: |
|
|
gc.collect() |
|
|
torch.cuda.empty_cache() |
|
|
logger.info("GPU memory cleaned up.") |
|
|
except Exception as e: |
|
|
logger.error(f"Error during GPU cleanup: {e}") |
|
|
|
|
|
def process(self, pdf_path: str) -> Dict[str, Any]: |
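        """Run the full pipeline on a local path or URL; returns the final markdown and the extracted subtopic structures."""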
|
|
logger.info(f"Processing PDF: {pdf_path}") |
|
|
try: |
|
|
|
|
|
subtopics = self.subtopic_extractor.extract_subtopics(pdf_path) |
|
|
logger.info(f"Gemini returned subtopics: {subtopics}") |
|
|
|
|
|
if pdf_path.startswith("http://") or pdf_path.startswith("https://"): |
|
|
response = requests.get(pdf_path) |
|
|
if response.status_code != 200: |
|
|
logger.error("Failed to download PDF from %s. Status code: %d", pdf_path, response.status_code) |
|
|
raise Exception(f"Failed to download PDF: {pdf_path}") |
|
|
pdf_bytes = response.content |
|
|
logger.info("Downloaded %d bytes for pdf_url='%s'", len(pdf_bytes), pdf_path) |
|
|
else: |
|
|
with open(pdf_path, "rb") as f: |
|
|
pdf_bytes = f.read() |
|
|
logger.info("Loaded %d bytes from local file '%s'", len(pdf_bytes), pdf_path) |
|
|
|
|
|
doc = fitz.open(stream=pdf_bytes, filetype="pdf") |
|
|
total_pages = doc.page_count |
|
|
doc.close() |
|
|
|
|
|
|
|
|
final_pages = set() |
|
|
if not subtopics: |
|
|
|
|
|
final_pages = set(range(total_pages)) |
|
|
else: |
|
|
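                # The TOC lists printed (1-based) page numbers, which usually differ from the
                # PDF's physical page indices because of cover and front-matter pages. For each
                # subtopic title, find the physical pages containing it and record the candidate
                # offset (physical index - (printed start page - 1)); the most common candidate
                # becomes the global printed-to-physical offset.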
offset_candidates = [] |
|
|
for subname, rng in subtopics.items(): |
|
|
start_p, _ = rng |
|
|
occs = find_all_occurrences(pdf_bytes, subname) |
|
|
for p in occs: |
|
|
candidate = p - (start_p - 1) |
|
|
if candidate > 0: |
|
|
offset_candidates.append(candidate) |
|
|
if offset_candidates: |
|
|
try: |
|
|
from statistics import mode |
|
|
global_offset = mode(offset_candidates) |
|
|
                    except Exception:
|
|
from statistics import median |
|
|
global_offset = int(median(offset_candidates)) |
|
|
else: |
|
|
global_offset = 0 |
|
|
|
|
|
logger.info(f"Computed global offset: {global_offset}") |
|
|
for subname, rng in subtopics.items(): |
|
|
if not (isinstance(rng, list) and len(rng) == 2): |
|
|
continue |
|
|
start_p, end_p = rng |
|
|
if start_p > end_p: |
|
|
continue |
|
|
s0 = (start_p - 1) + global_offset |
|
|
e0 = (end_p - 1) + global_offset |
|
|
for pp in range(s0, e0 + 1): |
|
|
final_pages.add(pp) |
|
|
|
|
|
if not final_pages: |
|
|
final_pages = set(range(total_pages)) |
|
|
|
|
|
logger.info(f"Processing pages (0-based): {sorted(final_pages)}") |
|
|
subset_pdf_bytes = create_subset_pdf(pdf_bytes, sorted(final_pages)) |
|
|
|
|
|
|
|
|
dataset = PymuDocDataset(subset_pdf_bytes) |
|
|
inference = doc_analyze( |
|
|
dataset, |
|
|
ocr=True, |
|
|
lang=self.language, |
|
|
layout_model=self.layout_model, |
|
|
formula_enable=self.formula_enable, |
|
|
table_enable=self.table_enable |
|
|
) |
|
|
|
|
|
writer = S3ImageWriter(self.s3_writer, "/topic-extraction", self.gemini_api_key) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
md_prefix = "/topic-extraction/" |
|
|
pipe_result = inference.pipe_ocr_mode(writer, lang=self.language) |
|
|
md_content = pipe_result.get_markdown(md_prefix) |
|
|
final_markdown = writer.post_process(md_prefix, md_content) |
|
|
|
|
|
subtopic_list = list(writer.extracted_subtopics.values()) |
|
|
|
|
|
out_path = os.path.join(self.output_folder, "final_subtopics.json") |
|
|
with open(out_path, "w", encoding="utf-8") as f: |
|
|
json.dump(subtopic_list, f, indent=2) |
|
|
logger.info(f"Final subtopics JSON saved locally at {out_path}") |
|
|
|
|
|
return { |
|
|
"final_markdown": final_markdown, |
|
|
"subtopics_extracted": subtopic_list |
|
|
} |
|
|
finally: |
|
|
self.cleanup_gpu() |
|
|
|
|
|
if __name__ == "__main__": |
|
|
input_pdf = "/home/user/app/input_output/a-level-pearson-mathematics-specification.pdf" |
|
|
output_dir = "/home/user/app/pearson_json" |
|
|
    # The API key must be supplied via the environment; never hard-code credentials.
    gemini_key = os.getenv("GEMINI_API_KEY", "")
|
|
try: |
|
|
processor = MineruNoTextProcessor(output_folder=output_dir, gemini_api_key=gemini_key) |
|
|
result = processor.process(input_pdf) |
|
|
logger.info("Processing completed successfully.") |
|
|
except Exception as e: |
|
|
logger.error(f"Processing failed: {e}") |