import os
import cv2
import numpy as np
import torch
import albumentations as album
from albumentations.pytorch import ToTensorV2
import matplotlib.pyplot as plt
import segmentation_models_pytorch as smp
import warnings
warnings.filterwarnings("ignore")

class Pipeline:
    """Polyp-segmentation inference pipeline around a whole pickled PyTorch model.

    Images are resized to the network input size and normalized with the
    ``efficientnet-b3`` / ImageNet preprocessing function from
    segmentation_models_pytorch. The model's per-class output is then turned
    into a binary polyp mask, a colour-coded class mask and a polyp heatmap,
    all at the input image's resolution.
    """

    def __init__(self, model_path, device=None):
        """Load the model and build the preprocessing transform.

        Args:
            model_path: Path to a checkpoint holding a whole pickled model (not
                just a ``state_dict``); ``.eval()`` is called on it.
            device: Torch device spec such as ``"cpu"`` or ``"cuda:0"``. When
                ``None``, CUDA is used if available, otherwise CPU.

        Raises:
            Exception: Whatever loading the checkpoint raises; the error is
                printed and re-raised unchanged.
        """
        # cv2.resize takes (width, height): network input is 384 wide x 288 high.
        self.img_size = (384, 288)
        self.classes = ['background', 'polyp']
        # RGB colour painted for each entry of self.classes in the colour-coded mask.
        self.class_rgb_values = [
            [0, 0, 0],
            [255, 0, 0],
        ]

        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)

        print(f"Using device: {self.device}")

        try:
            self.model = self._load_model(model_path)
            self.model.eval()
        except Exception as e:
            print(f"Failed to load model: {e}")
            raise

        encoder_name = 'efficientnet-b3'
        self.preprocessing_fn = smp.encoders.get_preprocessing_fn(encoder_name, 'imagenet')
        # Build the albumentations pipeline once instead of once per image.
        self._preprocessing = self._get_preprocessing(self.preprocessing_fn)

    def _load_model(self, model_path):
        """Unpickle a whole model object from ``model_path`` onto ``self.device``."""
        import inspect

        # SECURITY: this unpickles arbitrary Python objects; only load trusted checkpoints.
        load_kwargs = {'map_location': self.device}
        # PyTorch >= 2.6 defaults to weights_only=True, which rejects whole-model
        # pickles; opt out explicitly. Older torch versions lack the keyword entirely.
        if 'weights_only' in inspect.signature(torch.load).parameters:
            load_kwargs['weights_only'] = False
        return torch.load(model_path, **load_kwargs)

    def preprocess_image(self, image):
        """Resize and normalize an image into a model-ready batch of one.

        Args:
            image: ``(H, W, 3)`` RGB array (``predict`` guarantees 3 channels).

        Returns:
            A ``(1, 3, 288, 384)`` float32 tensor on ``self.device``.
        """
        image = cv2.resize(image, self.img_size, interpolation=cv2.INTER_AREA)
        sample = self._preprocessing(image=image)
        return torch.from_numpy(sample['image']).unsqueeze(0).to(self.device)

    def _get_preprocessing(self, preprocessing_fn=None):
        """Compose the optional normalization function with the HWC->CHW conversion."""
        _transform = []
        if preprocessing_fn:
            _transform.append(album.Lambda(image=preprocessing_fn))
        _transform.append(album.Lambda(image=self._to_tensor))
        return album.Compose(_transform)

    def _to_tensor(self, x, **kwargs):
        """Transpose an ``(H, W, C)`` array to ``(C, H, W)`` and cast to float32."""
        return x.transpose(2, 0, 1).astype('float32')

    def _reverse_one_hot(self, image):
        """Collapse the last (class) axis into per-pixel class indices via argmax."""
        return np.argmax(image, axis=-1)

    def _colour_code_segmentation(self, image):
        """Map an array of class indices to an RGB image using ``class_rgb_values``.

        The result is uint8 so it can be passed to ``cv2.resize`` (OpenCV does
        not accept the int64 arrays a default ``np.array`` of ints produces).
        """
        colour_codes = np.array(self.class_rgb_values, dtype=np.uint8)
        return colour_codes[image.astype(int)]

    @staticmethod
    def _ensure_rgb(image):
        """Return ``image`` with exactly 3 channels: grayscale expanded, alpha dropped."""
        if image.ndim == 3 and image.shape[2] == 1:
            image = image[:, :, 0]
        if image.ndim == 2:
            return cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        if image.shape[2] == 4:
            return image[:, :, :3]
        return image

    def predict(self, image):
        """Segment polyps in a single image.

        Args:
            image: ``(H, W)`` grayscale, ``(H, W, 3)`` RGB or ``(H, W, 4)`` RGBA
                array. NOTE(review): presumably uint8 in 0-255, as produced by
                ``cv2.imread``; confirm for other sources.

        Returns:
            dict whose entries are all at the input's ``(H, W)`` resolution:
              ``'binary_mask'``: uint8 ``(H, W)``, 1 where the polyp channel > 0.5.
              ``'colored_mask'``: uint8 ``(H, W, 3)``, the per-pixel argmax class
              painted with ``class_rgb_values``.
              ``'heatmap'``: ``(H, W)`` polyp-channel output of the model.
        """
        image = self._ensure_rgb(image)
        original_h, original_w = image.shape[:2]

        x_tensor = self.preprocess_image(image)

        with torch.no_grad():
            output = self.model(x_tensor)
        # Drop only the batch axis, then (C, h, w) -> (h, w, C). Assumes one
        # output channel per entry in self.classes.
        pred_mask = np.transpose(output.squeeze(0).cpu().numpy(), (1, 2, 0))

        # Contiguous copy so OpenCV gets a plain 2-D buffer, not a strided view.
        polyp_heatmap = np.ascontiguousarray(pred_mask[:, :, self.classes.index('polyp')])

        # NOTE(review): the 0.5 cut-off assumes the model ends in a sigmoid/softmax
        # activation; raw logits would need a different threshold. Confirm.
        binary_mask = (polyp_heatmap > 0.5).astype(np.uint8)
        colored_mask = self._colour_code_segmentation(self._reverse_one_hot(pred_mask))

        if (original_h, original_w) != self.img_size[::-1]:
            dsize = (original_w, original_h)  # cv2 expects (width, height)
            binary_mask = cv2.resize(binary_mask, dsize, interpolation=cv2.INTER_NEAREST)
            colored_mask = cv2.resize(colored_mask, dsize, interpolation=cv2.INTER_NEAREST)
            polyp_heatmap = cv2.resize(polyp_heatmap, dsize, interpolation=cv2.INTER_LINEAR)

        return {
            'binary_mask': binary_mask,
            'colored_mask': colored_mask,
            'heatmap': polyp_heatmap
        }

    def visualize_prediction(self, image, prediction, save_path=None):
        """Plot the image, heatmap, binary mask and overlay side by side.

        Args:
            image: The image that was passed to ``predict`` (grayscale, RGB or RGBA).
            prediction: The dict ``predict`` returned for ``image``.
            save_path: If given, the figure is also saved there at 150 dpi.

        Returns:
            The overlay: ``image`` as 3-channel RGB, with pixels inside the binary
            mask blended 50/50 with ``colored_mask``.
        """
        # Same channel normalization as predict(), so the overlay shapes line up.
        image = self._ensure_rgb(image)

        fig, axs = plt.subplots(1, 4, figsize=(20, 5))

        axs[0].imshow(image)
        axs[0].set_title('Original Image')
        axs[0].axis('off')

        axs[1].imshow(prediction['heatmap'], cmap='jet')
        axs[1].set_title('Polyp Probability Heatmap')
        axs[1].axis('off')

        axs[2].imshow(prediction['binary_mask'], cmap='gray')
        axs[2].set_title('Binary Polyp Mask')
        axs[2].axis('off')

        overlay = image.copy()
        colored_mask = prediction['colored_mask']
        mask_condition = prediction['binary_mask'] > 0
        overlay[mask_condition] = overlay[mask_condition] * 0.5 + colored_mask[mask_condition] * 0.5

        axs[3].imshow(overlay)
        axs[3].set_title('Polyp Overlay')
        axs[3].axis('off')

        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, dpi=150, bbox_inches='tight')
            print(f"Visualization saved to {save_path}")

        plt.show()
        # Release the figure; otherwise repeated calls accumulate open figures.
        plt.close(fig)

        return overlay

def main(model_path='./best_model.pth', image_path='test.png',
         save_path='prediction_result.png'):
    """Segment one image, show/save the visualization and print polyp coverage.

    Args:
        model_path: Checkpoint passed to ``Pipeline``.
        image_path: Image to segment; read with OpenCV (BGR) and converted to RGB.
        save_path: Where the four-panel visualization figure is saved.
    """
    # Check the input before paying for model loading.
    if not os.path.exists(image_path):
        print(f"Image not found at {image_path}")
        return

    image = cv2.imread(image_path)
    # cv2.imread returns None (instead of raising) for unreadable/unsupported files.
    if image is None:
        print(f"Could not read image at {image_path}")
        return
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    pipeline = Pipeline(model_path)
    prediction = pipeline.predict(image)

    pipeline.visualize_prediction(image, prediction, save_path=save_path)

    # binary_mask is 0/1, so its mean is the fraction of polyp pixels.
    polyp_percentage = np.mean(prediction['binary_mask']) * 100
    print(f"Polyp covers approximately {polyp_percentage:.2f}% of the image")


if __name__ == "__main__":
    main()
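# NOTE: the lines below are residue copied from a Hugging Face model web page;
# they are not code and are kept only as comments so the file remains valid Python.
# Downloads last month
#
# -
#
# Downloads are not tracked for this model. How to track
# Inference Providers NEW
# This model isn't deployed by any Inference Provider. 🙋 Ask for provider support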