import os
import cv2
import numpy as np
import torch
import albumentations as album
from albumentations.pytorch import ToTensorV2
import matplotlib.pyplot as plt
import segmentation_models_pytorch as smp
import warnings
# Silence every Python warning for the whole process (e.g. third-party
# deprecation noise from torch / albumentations / smp).
# NOTE(review): this also hides genuinely useful warnings from every library;
# consider narrowing with category=/module= filters.
warnings.filterwarnings("ignore")
class Pipeline:
    """Polyp-segmentation inference wrapper around a pickled PyTorch model.

    Loads a full ``torch.nn.Module`` checkpoint, applies the
    segmentation_models_pytorch ImageNet preprocessing for the chosen encoder,
    runs inference, and converts the per-class output into a binary mask, a
    colour-coded class mask and a polyp heatmap at the input image's size.
    """

    def __init__(self, model_path, device=None, encoder_name='efficientnet-b3',
                 img_size=(384, 288)):
        """Load the model and set up preprocessing.

        Args:
            model_path: Checkpoint written with ``torch.save(model, ...)`` (the
                whole module, not a ``state_dict``).
            device: Torch device (string or ``torch.device``). Defaults to CUDA
                when available, otherwise CPU.
            encoder_name: Encoder whose ImageNet preprocessing is applied to
                inputs; must match the encoder the model was trained with.
            img_size: Network input size as ``(width, height)``, the order
                ``cv2.resize`` expects.

        Raises:
            Exception: Whatever ``torch.load`` raises, or ``TypeError`` when the
                checkpoint does not contain an ``nn.Module``. The error is
                printed and re-raised.
        """
        # Stored as a tuple so the ``self.img_size[::-1]`` comparison in
        # predict() also works when a list is passed in.
        self.img_size = tuple(img_size)
        self.classes = ['background', 'polyp']
        # One RGB colour per entry of ``self.classes``, in the same order.
        self.class_rgb_values = [
            [0, 0, 0],
            [255, 0, 0],
        ]
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)
        print(f"Using device: {self.device}")
        try:
            self.model = self._load_model(model_path)
        except Exception as e:
            print(f"Failed to load model: {e}")
            raise
        self.preprocessing_fn = smp.encoders.get_preprocessing_fn(encoder_name, 'imagenet')
        # Built once here rather than on every preprocess_image() call.
        self._preprocessing = self._get_preprocessing(self.preprocessing_fn)

    def _load_model(self, model_path):
        """Unpickle the full model onto ``self.device`` and put it in eval mode."""
        # SECURITY: unpickling can execute arbitrary code -- only load
        # checkpoints from trusted sources.
        try:
            # torch >= 2.6 defaults to weights_only=True, which refuses to
            # unpickle a whole nn.Module; opt out explicitly.
            model = torch.load(model_path, map_location=self.device, weights_only=False)
        except TypeError:
            # torch < 1.13 has no ``weights_only`` keyword.
            model = torch.load(model_path, map_location=self.device)
        if not isinstance(model, torch.nn.Module):
            raise TypeError(
                f"Expected {model_path} to contain a pickled torch.nn.Module, got "
                f"{type(model).__name__}; a state_dict must be loaded into an "
                f"instantiated model instead"
            )
        model.eval()
        return model

    @staticmethod
    def _ensure_rgb(image):
        """Return *image* as an ``(H, W, 3)`` array.

        Grayscale ``(H, W)`` and single-channel ``(H, W, 1)`` inputs are
        replicated to three channels; a fourth (alpha) channel is dropped.

        Raises:
            ValueError: For any other rank or channel count.
        """
        image = np.asarray(image)
        if image.ndim == 2:
            image = image[:, :, None]
        if image.ndim != 3 or image.shape[2] not in (1, 3, 4):
            raise ValueError(f"Unsupported image shape {image.shape}; "
                             f"expected (H, W), (H, W, 1), (H, W, 3) or (H, W, 4)")
        if image.shape[2] == 1:
            return np.repeat(image, 3, axis=2)
        if image.shape[2] == 4:
            return image[:, :, :3]
        return image

    def preprocess_image(self, image):
        """Resize an RGB image and return a ``(1, 3, H, W)`` float32 tensor on ``self.device``."""
        image = cv2.resize(image, self.img_size, interpolation=cv2.INTER_AREA)
        sample = self._preprocessing(image=image)
        return torch.from_numpy(sample['image']).unsqueeze(0).to(self.device)

    def _get_preprocessing(self, preprocessing_fn=None):
        """Build the albumentations pipeline: optional encoder normalisation, then HWC->CHW."""
        _transform = []
        if preprocessing_fn:
            _transform.append(album.Lambda(image=preprocessing_fn))
        _transform.append(album.Lambda(image=self._to_tensor))
        return album.Compose(_transform)

    def _to_tensor(self, x, **kwargs):
        """Convert an ``(H, W, C)`` array to a C-contiguous float32 ``(C, H, W)`` array."""
        return np.ascontiguousarray(x.transpose(2, 0, 1), dtype=np.float32)

    def _reverse_one_hot(self, image):
        """Collapse the last (class) axis to per-pixel class indices via argmax."""
        return np.argmax(image, axis=-1)

    def _colour_code_segmentation(self, image):
        """Map a class-index map to an ``(..., 3)`` uint8 RGB image.

        The colour table is uint8 (not numpy's default int64) because
        ``cv2.resize`` in predict() does not accept int64 arrays.
        """
        colour_codes = np.asarray(self.class_rgb_values, dtype=np.uint8)
        return colour_codes[image.astype(int)]

    def predict(self, image, threshold=0.5):
        """Segment polyps in a single image.

        Args:
            image: ``(H, W)``, ``(H, W, 1)``, ``(H, W, 3)`` RGB or
                ``(H, W, 4)`` RGBA array (alpha is ignored).
            threshold: Cut-off applied to the polyp channel to build
                ``binary_mask``.

        Returns:
            dict with arrays at the input's ``(H, W)``:
              ``'binary_mask'``: uint8 0/1 mask of ``heatmap > threshold``;
              ``'colored_mask'``: uint8 ``(H, W, 3)`` argmax class colours;
              ``'heatmap'``: the model's polyp-channel output.
        """
        image = self._ensure_rgb(image)
        original_h, original_w = image.shape[:2]
        x_tensor = self.preprocess_image(image)
        with torch.no_grad():
            output = self.model(x_tensor)
        # Drop only the batch axis (a bare squeeze() would collapse any other
        # size-1 axis too), then (C, H, W) -> (H, W, C).
        pred_mask = output[0].cpu().numpy().transpose(1, 2, 0)
        # NOTE(review): treating this channel as a probability assumes the
        # model ends in a sigmoid/softmax activation -- confirm.
        polyp_heatmap = pred_mask[:, :, self.classes.index('polyp')]
        binary_mask = (polyp_heatmap > threshold).astype(np.uint8)
        colored_mask = self._colour_code_segmentation(self._reverse_one_hot(pred_mask))
        if (original_h, original_w) != self.img_size[::-1]:
            target = (original_w, original_h)  # cv2 wants (width, height)
            binary_mask = cv2.resize(binary_mask, target, interpolation=cv2.INTER_NEAREST)
            colored_mask = cv2.resize(colored_mask, target, interpolation=cv2.INTER_NEAREST)
            polyp_heatmap = cv2.resize(polyp_heatmap, target, interpolation=cv2.INTER_LINEAR)
        return {
            'binary_mask': binary_mask,
            'colored_mask': colored_mask,
            'heatmap': polyp_heatmap,
        }

    def visualize_prediction(self, image, prediction, save_path=None):
        """Plot image, heatmap, binary mask and a 50/50 colour overlay side by side.

        Args:
            image: The image passed to predict() (grey/RGBA are converted to RGB).
            prediction: The dict returned by predict() for *image*.
            save_path: If given, the figure is also saved there at 150 dpi.

        Returns:
            The RGB overlay array; the caller's *image* is not modified.
        """
        image = self._ensure_rgb(image)
        overlay = image.copy()
        mask = prediction['binary_mask'] > 0
        blended = 0.5 * overlay[mask] + 0.5 * prediction['colored_mask'][mask]
        overlay[mask] = blended.astype(overlay.dtype)

        fig, axs = plt.subplots(1, 4, figsize=(20, 5))
        panels = [
            (image, None, 'Original Image'),
            (prediction['heatmap'], 'jet', 'Polyp Probability Heatmap'),
            (prediction['binary_mask'], 'gray', 'Binary Polyp Mask'),
            (overlay, None, 'Polyp Overlay'),
        ]
        for ax, (data, cmap, title) in zip(axs, panels):
            ax.imshow(data, cmap=cmap)
            ax.set_title(title)
            ax.axis('off')
        plt.tight_layout()
        if save_path:
            plt.savefig(save_path, dpi=150, bbox_inches='tight')
            print(f"Visualization saved to {save_path}")
        plt.show()
        # Release the figure so repeated calls don't accumulate open figures.
        plt.close(fig)
        return overlay
def main(model_path='./best_model.pth', image_path='test.png'):
    """Segment one image, save/show the visualisation and print polyp coverage.

    Args:
        model_path: Checkpoint passed to ``Pipeline``.
        image_path: Image file to segment.
    """
    # Validate the input first so a missing file doesn't cost a model load.
    if not os.path.exists(image_path):
        print(f"Image not found at {image_path}")
        return
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread returns None (instead of raising) for unreadable files.
        print(f"Could not read image at {image_path}")
        return
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    pipeline = Pipeline(model_path)
    prediction = pipeline.predict(image)
    pipeline.visualize_prediction(image, prediction, save_path='prediction_result.png')
    polyp_percentage = np.mean(prediction['binary_mask']) * 100
    print(f"Polyp covers approximately {polyp_percentage:.2f}% of the image")


if __name__ == "__main__":
    main()