MinerU / table_row_extraction.py
SkyNait's picture
test
8966134
raw
history blame
17 kB
import cv2
import numpy as np
import logging
from pathlib import Path
from typing import List, Tuple
# Configure the root logger at INFO level as soon as this module is imported or run.
logging.basicConfig(level=logging.INFO)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# For 3-column tables, set `merge_two_col_rows` and `enable_subtopic_merge` to False.
# For 2-column tables, set both to True (both are currently hard-coded to True in the
# `__main__` block at the bottom of this file, which is only a test harness).
class TableExtractor:
    """Split an image of a ruled table into per-row, per-column cell crops.

    Pipeline (see ``process_image`` / ``save_extracted_cells``):

    1. ``preprocess``: grayscale -> denoise -> CLAHE -> sharpen -> inverted
       adaptive threshold (dark content becomes 255 in the binary image).
    2. ``detect_full_rows``: locate long horizontal rules and use them as
       row boundaries.
    3. ``detect_columns_in_row``: inside each row, locate up to two vertical
       rules and turn them into one to three ``(x, y, w, h)`` boxes.
    4. ``save_extracted_cells``: crop each box from the original image and
       write it to disk unless ``is_grey_artifact`` rejects it.
    """

    def __init__(
        self,
        # Preprocessing parameters
        denoise_h: int = 10,
        clahe_clip: float = 3.0,
        clahe_grid: int = 8,
        sharpen_kernel: np.ndarray = np.array([[-1, -1, -1],
                                               [-1, 9, -1],
                                               [-1, -1, -1]]),
        thresh_block_size: int = 21,
        thresh_C: int = 7,
        # Row detection parameters
        horizontal_scale: int = 20,
        row_morph_iterations: int = 2,
        min_row_height: int = 30,
        min_row_density: float = 0.01,
        # Column detection parameters
        vertical_scale: int = 20,
        col_morph_iterations: int = 2,
        min_col_height_ratio: float = 0.5,
        min_col_density: float = 0.01,
        # Bounding box extraction
        padding: int = 0,
        skip_header: bool = True,
        # Two-column & subtopic merges
        merge_two_col_rows: bool = False,
        enable_subtopic_merge: bool = False,
        subtopic_threshold: float = 0.2,
        # Gray artifact filter
        std_threshold_for_artifacts: float = 5.0,
        # Parameters for the line-removal check
        line_removal_scale: int = 15,
        line_removal_iterations: int = 1,
        min_text_ratio_after_line_removal: float = 0.001
    ):
        """
        :param denoise_h: ``h`` argument passed to ``cv2.fastNlMeansDenoising``.
        :param clahe_clip: ``clipLimit`` passed to ``cv2.createCLAHE``.
        :param clahe_grid: CLAHE tile grid is ``(clahe_grid, clahe_grid)``.
        :param sharpen_kernel: Kernel passed to ``cv2.filter2D``. A private copy is stored.
        :param thresh_block_size: Block size passed to ``cv2.adaptiveThreshold``
            (OpenCV expects an odd value such as 3, 5, 7, ...).
        :param thresh_C: Constant ``C`` passed to ``cv2.adaptiveThreshold``.
        :param horizontal_scale: Horizontal-line kernel width is ``image_width // horizontal_scale``.
        :param row_morph_iterations: Morphological-open iterations for horizontal lines.
        :param min_row_height: Gaps between consecutive horizontal lines shorter than this are not rows.
        :param min_row_density: Rows whose fraction of 255-pixels is below this are dropped.
        :param vertical_scale: Vertical-line kernel height is ``row_height // vertical_scale``.
        :param col_morph_iterations: Morphological-open iterations for vertical lines.
        :param min_col_height_ratio: A vertical contour must span at least this fraction of the
            row height to count as a column divider.
        :param min_col_density: Boxes whose fraction of 255-pixels is below this are dropped.
        :param padding: Pixels added around each box when cropping (clamped to the image).
        :param skip_header: If True and more than one row survives filtering, drop the first row.
        :param merge_two_col_rows: If True, a row with exactly 1 vertical line => merges into 1 bounding box.
        :param enable_subtopic_merge: If True, a row with 2 vertical lines => 3 columns can become 2 if left is narrow.
        :param subtopic_threshold: Fraction of row width for subtopic detection.
        :param std_threshold_for_artifacts: Grayscale std dev < this => skip as artifact.
        :param line_removal_scale: Larger => more aggressive line detection inside the cell.
        :param line_removal_iterations: Morphological iterations for line removal.
        :param min_text_ratio_after_line_removal: If fraction of text after removing lines < this => skip cell.
        """
        # Preprocessing
        self.denoise_h = denoise_h
        self.clahe_clip = clahe_clip
        self.clahe_grid = clahe_grid
        # The default kernel is created once, when the function is defined. Store a
        # copy so that mutating one instance's kernel cannot leak into other
        # instances (or into the caller's own array).
        self.sharpen_kernel = np.array(sharpen_kernel, copy=True)
        self.thresh_block_size = thresh_block_size
        self.thresh_C = thresh_C
        # Row detection
        self.horizontal_scale = horizontal_scale
        self.row_morph_iterations = row_morph_iterations
        self.min_row_height = min_row_height
        self.min_row_density = min_row_density
        # Column detection
        self.vertical_scale = vertical_scale
        self.col_morph_iterations = col_morph_iterations
        self.min_col_height_ratio = min_col_height_ratio
        self.min_col_density = min_col_density
        # Bbox extraction
        self.padding = padding
        self.skip_header = skip_header
        # Two-column / subtopic merges
        self.merge_two_col_rows = merge_two_col_rows
        self.enable_subtopic_merge = enable_subtopic_merge
        self.subtopic_threshold = subtopic_threshold
        # Artifact filtering (gray headers, purple, etc.) / currently not working well
        self.std_threshold_for_artifacts = std_threshold_for_artifacts
        # Line removal inside a cell
        self.line_removal_scale = line_removal_scale
        self.line_removal_iterations = line_removal_iterations
        self.min_text_ratio_after_line_removal = min_text_ratio_after_line_removal

    def preprocess(self, img: np.ndarray) -> np.ndarray:
        """Grayscale, denoise, CLAHE, sharpen, adaptive threshold (binary_inv).

        :param img: Image as loaded by ``cv2.imread`` (3-D, converted with
            BGR2GRAY) or an already single-channel 2-D image.
        :return: Binary image in which pixels darker than their local
            threshold are 255 and everything else is 0.
        """
        if img.ndim == 3:
            gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        else:
            gray = img.copy()
        denoised = cv2.fastNlMeansDenoising(gray, h=self.denoise_h)
        clahe = cv2.createCLAHE(
            clipLimit=self.clahe_clip,
            tileGridSize=(self.clahe_grid, self.clahe_grid),
        )
        enhanced = clahe.apply(denoised)
        sharpened = cv2.filter2D(enhanced, -1, self.sharpen_kernel)
        binarized = cv2.adaptiveThreshold(
            sharpened, 255,
            cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
            cv2.THRESH_BINARY_INV,
            self.thresh_block_size,
            self.thresh_C
        )
        return binarized

    def detect_full_rows(self, bin_img: np.ndarray) -> List[Tuple[int, int]]:
        """Find horizontal row boundaries in the binarized image.

        :param bin_img: Binary image as produced by ``preprocess``.
        :return: List of ``(y_top, y_bottom)`` pairs between consecutive
            horizontal lines that are at least ``min_row_height`` apart.
            Falls back to a single ``(0, image_height)`` row when no usable
            lines or rows are found.
        """
        whole_image = [(0, bin_img.shape[0])]

        # Opening with a wide, 1-pixel-tall kernel keeps only long horizontal runs.
        h_kernel_size = max(1, bin_img.shape[1] // self.horizontal_scale)
        horizontal_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (h_kernel_size, 1))
        horizontal_lines = cv2.morphologyEx(
            bin_img, cv2.MORPH_OPEN, horizontal_kernel,
            iterations=self.row_morph_iterations,
        )

        row_projection = np.sum(horizontal_lines, axis=1)
        max_val = np.max(row_projection) if len(row_projection) else 0
        # If no lines, treat entire image as one row
        if max_val < 1e-5:
            return whole_image

        # Image rows whose line mass exceeds 30% of the strongest row count as "on a line".
        line_indices = np.where(row_projection > 0.3 * max_val)[0]
        if len(line_indices) < 2:
            return whole_image

        # Group indices that are at most 2 px apart into one line and represent
        # each group by its (truncated) mean y.
        group_starts = np.where(np.diff(line_indices) > 2)[0] + 1
        lines = [int(np.mean(group)) for group in np.split(line_indices, group_starts)]

        row_bounds = [
            (top, bottom)
            for top, bottom in zip(lines, lines[1:])
            if (bottom - top) >= self.min_row_height
        ]
        return row_bounds if row_bounds else whole_image

    def detect_columns_in_row(self, row_img: np.ndarray, y1: int, y2: int) -> List[Tuple[int, int, int, int]]:
        """
        Detect up to two vertical lines => up to 3 bounding boxes.
        - 0 lines => 1 bounding box
        - 1 line => 2 bounding boxes (unless merge_two_col_rows => 1)
        - 2 lines => 3 bounding boxes by default
        if enable_subtopic_merge => check left box < subtopic_threshold => 2 boxes

        :param row_img: Binary sub-image covering the row.
        :param y1: Top y of the row; used as the ``y`` of every returned box.
        :param y2: Bottom y of the row; ``y2 - y1`` is the box height.
        :return: ``(x, y, w, h)`` boxes, left to right, keeping only boxes
            with positive width and a 255-pixel density >= ``min_col_density``.
        """
        row_height = (y2 - y1)
        row_width = row_img.shape[1]

        # Morph kernel for vertical lines
        v_kernel_size = max(1, row_height // self.vertical_scale)
        vertical_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, v_kernel_size))
        vertical_lines = cv2.morphologyEx(
            row_img, cv2.MORPH_OPEN, vertical_kernel,
            iterations=self.col_morph_iterations,
        )
        vertical_lines = cv2.dilate(vertical_lines, np.ones((3, 3), np.uint8), iterations=1)

        # Find contours => x positions
        contours, _ = cv2.findContours(vertical_lines, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        min_divider_height = self.min_col_height_ratio * row_height
        x_positions = set()
        for contour in contours:
            x, _y, _w, h = cv2.boundingRect(contour)
            # Must span enough of the row height to be considered a real column divider
            if h >= min_divider_height:
                x_positions.add(x)
        # Keep at most 2 vertical lines (the two left-most ones)
        x_positions = sorted(x_positions)[:2]

        # Decide which x positions actually split the row.
        if not x_positions:
            # 0 lines => single bounding box
            cuts = []
        elif len(x_positions) == 1:
            # 1 line => 2 bounding boxes, or 1 when two-column rows are merged
            cuts = [] if self.merge_two_col_rows else x_positions
        else:
            # 2 lines => normally 3 bounding boxes
            x1, x2 = x_positions
            narrow_left = x1 < (self.subtopic_threshold * row_width)
            if self.enable_subtopic_merge and narrow_left:
                # Very narrow left box => treat as subtopic => 2 bounding boxes
                cuts = [x1]
            else:
                cuts = [x1, x2]

        # Build bounding boxes between consecutive edges.
        edges = [0, *cuts, row_width]
        boxes = [
            (left, y1, right - left, row_height)
            for left, right in zip(edges, edges[1:])
        ]

        # Filter out columns with insufficient density
        filtered = []
        for (x, y, w, h) in boxes:
            if w <= 0:
                continue
            subregion = row_img[:, x: x + w]
            total_pixels = subregion.size
            if total_pixels == 0:
                continue
            density = np.sum(subregion == 255) / total_pixels
            if density >= self.min_col_density:
                filtered.append((x, y, w, h))
        return filtered

    def process_image(self, image_path: str) -> List[List[Tuple[int, int, int, int]]]:
        """
        1) Preprocess => bin_img
        2) Detect row segments
        3) Filter out rows by density
        4) Optionally skip first row (header)
        5) For each row => detect columns => bounding boxes

        :param image_path: Path of the image to read with ``cv2.imread``.
        :return: One list of ``(x, y, w, h)`` boxes per row; rows that yield
            no boxes are omitted.
        :raises ValueError: If the image cannot be read.
        """
        img = cv2.imread(image_path)
        if img is None:
            raise ValueError(f"Could not read image: {image_path}")
        bin_img = self.preprocess(img)
        row_segments = self.detect_full_rows(bin_img)

        # Filter out rows with insufficient density
        valid_rows = []
        for (y1, y2) in row_segments:
            row_region = bin_img[y1:y2, :]
            area = row_region.size
            if area == 0:
                continue
            density = np.sum(row_region == 255) / area
            if density >= self.min_row_density:
                valid_rows.append((y1, y2))

        # Possibly skip header row (never drops the only remaining row)
        if self.skip_header and len(valid_rows) > 1:
            valid_rows = valid_rows[1:]

        # Detect columns in each row
        all_rows_boxes = []
        for (y1, y2) in valid_rows:
            col_boxes = self.detect_columns_in_row(bin_img[y1:y2, :], y1, y2)
            if col_boxes:
                all_rows_boxes.append(col_boxes)
        return all_rows_boxes

    def extract_box_image(self, original: np.ndarray, box: Tuple[int, int, int, int]) -> np.ndarray:
        """Crop bounding box from original with optional padding.

        :param original: Image to crop from (rows first, then columns).
        :param box: ``(x, y, w, h)`` box.
        :return: Slice of ``original`` grown by ``self.padding`` on every
            side and clamped to the image bounds.
        """
        x, y, w, h = box
        top = max(0, y - self.padding)
        bottom = min(original.shape[0], y + h + self.padding)
        left = max(0, x - self.padding)
        right = min(original.shape[1], x + w + self.padding)
        return original[top:bottom, left:right]

    def _remove_lines_in_cell(self, gray_bin: np.ndarray) -> np.ndarray:
        """
        Remove horizontal + vertical lines from a binarized subregion
        and return the 'text-only' mask.

        :param gray_bin: Binary (0/255) image of one cell.
        :return: ``gray_bin`` with every pixel that survives the horizontal
            or vertical opening cleared.
        """
        # 1) horizontal line detection
        horiz_kernel_size = max(1, gray_bin.shape[1] // self.line_removal_scale)
        horiz_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (horiz_kernel_size, 1))
        horizontal = cv2.morphologyEx(
            gray_bin, cv2.MORPH_OPEN, horiz_kernel,
            iterations=self.line_removal_iterations,
        )
        # 2) vertical line detection
        vert_kernel_size = max(1, gray_bin.shape[0] // self.line_removal_scale)
        vert_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, vert_kernel_size))
        vertical = cv2.morphologyEx(
            gray_bin, cv2.MORPH_OPEN, vert_kernel,
            iterations=self.line_removal_iterations,
        )
        # Combine lines
        lines = cv2.bitwise_or(horizontal, vertical)
        # Subtract from the original => text-only
        return cv2.bitwise_and(gray_bin, cv2.bitwise_not(lines))

    def is_grey_artifact(self, cell_img: np.ndarray) -> bool:
        """
        1) If grayscale std dev < std_threshold_for_artifacts => skip as uniform.
        2) Otherwise, remove lines from an Otsu-binarized version of the cell
        and check if there's enough text left. If not, skip as artifact.

        :param cell_img: Cell crop; either a 3-D colour image (converted with
            BGR2GRAY) or an already single-channel 2-D image.
        :return: True if the cell is empty, near-uniform, or contains almost
            nothing but ruling lines.
        """
        if cell_img.size == 0:
            return True
        # Mirror preprocess(): only convert when there is a channel axis, so
        # grayscale crops are accepted as well.
        if cell_img.ndim == 3:
            gray = cv2.cvtColor(cell_img, cv2.COLOR_BGR2GRAY)
        else:
            gray = cell_img
        if np.std(gray) < self.std_threshold_for_artifacts:
            return True
        # 2) Binarize => remove lines => check leftover text
        # Use Otsu threshold for the local cell
        _, cell_bin = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV | cv2.THRESH_OTSU)
        text_only = self._remove_lines_in_cell(cell_bin)
        ratio = cv2.countNonZero(text_only) / float(cell_bin.size)
        # Hardly any text remains => artifact
        return bool(ratio < self.min_text_ratio_after_line_removal)

    def save_extracted_cells(
        self, image_path: str, row_boxes: List[List[Tuple[int, int, int, int]]], output_dir: str
    ) -> None:
        """Save each cell from the original image, skipping uniform/gray artifacts.

        Cells are written to ``<output_dir>/row_<i>/col_<j>.png``.

        :param image_path: Path of the original image (re-read with ``cv2.imread``).
        :param row_boxes: Boxes per row, as returned by ``process_image``.
        :param output_dir: Destination directory; created if missing.
        :raises ValueError: If the image cannot be read.
        """
        out_path = Path(output_dir)
        out_path.mkdir(exist_ok=True, parents=True)
        original = cv2.imread(image_path)
        if original is None:
            raise ValueError(f"Could not read original image: {image_path}")
        for i, row in enumerate(row_boxes):
            row_dir = out_path / f"row_{i}"
            row_dir.mkdir(exist_ok=True)
            for j, box in enumerate(row):
                cell_img = self.extract_box_image(original, box)
                # Skip if uniform or line-only artifact
                if self.is_grey_artifact(cell_img):
                    logger.info("Skipping artifact cell at row=%s, col=%s. (uniform/grey/line-only)", i, j)
                    continue
                out_file = row_dir / f"col_{j}.png"
                # cv2.imwrite reports failure through its return value, not an
                # exception, so check it instead of claiming success blindly.
                if cv2.imwrite(str(out_file), cell_img):
                    logger.info("Saved cell image row=%s, col=%s -> %s", i, j, out_file)
                else:
                    logger.warning("Failed to write cell image row=%s, col=%s -> %s", i, j, out_file)
class TableExtractorApp:
    """Small driver: detect the cell boxes of one image, then export the cells."""

    def __init__(self, extractor: TableExtractor):
        # The extractor does all the work; this class only sequences the calls.
        self.extractor = extractor

    def run(self, input_image: str, output_folder: str):
        """Process ``input_image`` and write its cell crops under ``output_folder``."""
        engine = self.extractor
        detected = engine.process_image(input_image)
        row_count = len(detected)
        logger.info(f"Detected {row_count} row(s).")
        engine.save_extracted_cells(input_image, detected, output_folder)
        logger.info("Done. Check the output folder for results.")
if __name__ == "__main__":
    # Test harness: fixed input image and output directory.
    input_image = "images/test/img_2.png"
    output_folder = "refined_outp"

    # Extractor settings, grouped by pipeline stage.
    extractor_settings = {
        # Preprocessing
        "denoise_h": 10,
        "clahe_clip": 3.0,
        "clahe_grid": 8,
        "thresh_block_size": 21,
        "thresh_C": 7,
        # Row detection
        "horizontal_scale": 20,
        "row_morph_iterations": 2,
        "min_row_height": 30,
        "min_row_density": 0.01,
        # Column detection
        "vertical_scale": 20,
        "col_morph_iterations": 2,
        "min_col_height_ratio": 0.5,
        "min_col_density": 0.01,
        # Cropping
        "padding": 1,
        "skip_header": True,
        # Two-column / subtopic merging
        "merge_two_col_rows": True,
        "enable_subtopic_merge": True,
        "subtopic_threshold": 0.2,
        # Artifact filtering
        "std_threshold_for_artifacts": 10.0,
        "line_removal_scale": 20,
        "line_removal_iterations": 1,
        "min_text_ratio_after_line_removal": 0.001,
    }
    extractor = TableExtractor(**extractor_settings)
    app = TableExtractorApp(extractor)
    app.run(input_image, output_folder)