import os

import gradio as gr
import numpy as np
import torch
from PIL import Image
from torch import nn
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor

# --- 1. CONFIGURATION ---
MODEL_PATH = "modelo_mejorado.pth"

# Class names; the list index is the class id predicted by the model.
LABELS = [
    "fondo",
    "wheat leaf rust",
    "wheat powdery mildew",
    "wheat septoria blotch",
    "wheat stem rust",
    "wheat stripe rust",
]

# Palette (R, G, B), indexed by class id in the same order as LABELS.
PALETA_COLORES = [
    [0, 0, 0],
    [220, 38, 38],
    [22, 163, 74],
    [37, 99, 235],
    [234, 179, 8],
    [219, 39, 119],
]

# On Docker CPU (free tier) we force CPU to avoid memory errors.
device = torch.device("cpu")
print(f"Usando dispositivo: {device}")

# --- 2. LOAD MODEL ---
checkpoint_name = "nvidia/segformer-b4-finetuned-ade-512-512"

# Pre-bind both names so that a failed load leaves them as None instead of
# undefined; predecir_enfermedad checks for None and reports the problem
# rather than raising NameError on the first request.
model_inference = None
image_processor = None

try:
    model_inference = SegformerForSemanticSegmentation.from_pretrained(
        checkpoint_name,
        num_labels=len(LABELS),
        id2label={i: label for i, label in enumerate(LABELS)},
        label2id={label: i for i, label in enumerate(LABELS)},
        ignore_mismatched_sizes=True,
    )

    # Load the fine-tuned weights when the file is present; otherwise the
    # app keeps running with whatever from_pretrained produced.
    if os.path.exists(MODEL_PATH):
        # NOTE(review): torch.load unpickles the file; only ship a trusted
        # checkpoint at MODEL_PATH.
        state_dict = torch.load(MODEL_PATH, map_location=device)
        model_inference.load_state_dict(state_dict)
        print("✅ Modelo cargado correctamente.")
    else:
        print("⚠️ NO se encontró el archivo .pth")

    model_inference.to(device)
    model_inference.eval()

    image_processor = SegformerImageProcessor.from_pretrained(checkpoint_name)
    # The image is resized by hand to 512x512 in predecir_enfermedad, so the
    # processor's own resize is turned off; rescaling stays enabled.
    image_processor.do_resize = False
    image_processor.do_rescale = True
except Exception as e:
    # Best effort: keep the UI starting even if the model cannot be loaded.
    print(f"Error fatal cargando modelo: {e}")


# --- 3. PREDICTION ---
def predecir_enfermedad(image):
    """Segment a wheat image and overlay the predicted disease classes.

    Args:
        image: PIL image supplied by the Gradio input, or None when nothing
            was uploaded.

    Returns:
        A ``(overlay, diagnosis)`` tuple. ``overlay`` is the input image
        blended with the colour-coded mask at the original size, or None
        when no image was given or the model is unavailable. ``diagnosis``
        is the text listing the non-background classes found.
    """
    if image is None:
        return None, "⚠️ Sube una imagen."
    if model_inference is None or image_processor is None:
        # Model loading failed at start-up; report it instead of crashing.
        return None, "⚠️ El modelo no está disponible. Revisa los registros del servidor."

    # Work in RGB throughout so the processor and the final blend both see
    # three channels, whatever mode the caller's image is in.
    image_rgb = image.convert("RGB")
    original_size = image_rgb.size
    img_resized = image_rgb.resize((512, 512))

    inputs = image_processor(images=img_resized, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model_inference(**inputs)

    # Bring the logits to 512x512 before taking the per-pixel argmax.
    logits = outputs.logits
    logits_upsampled = nn.functional.interpolate(
        logits, size=(512, 512), mode="bilinear", align_corners=False
    )
    pred_mask = logits_upsampled.argmax(dim=1).squeeze().cpu().numpy()

    # Paint every non-background class with its palette colour.
    color_mask = np.zeros((512, 512, 3), dtype=np.uint8)
    classes_found = []
    for class_id in np.unique(pred_mask):
        class_id = int(class_id)
        if class_id == 0:
            continue
        classes_found.append(LABELS[class_id])
        color_mask[pred_mask == class_id] = PALETA_COLORES[class_id]

    # Nearest-neighbour keeps the class colours crisp when scaling the mask
    # back to the size of the uploaded image.
    mask_pil = Image.fromarray(color_mask).resize(
        original_size, resample=Image.NEAREST
    )
    final_image = Image.blend(image_rgb, mask_pil.convert("RGB"), alpha=0.45)

    if classes_found:
        diagnosis = "Enfermedades:\n" + "\n".join(f"- {c}" for c in classes_found)
    else:
        diagnosis = "Planta Sana (Solo fondo)."

    return final_image, diagnosis


# --- 4. INTERFACE ---
# Plain Blocks layout, compatible with stable Gradio versions.
with gr.Blocks(title="Wheat AI") as demo:
    gr.Markdown("# 🌾 Detector de Enfermedades en Trigo")
    with gr.Row():
        img_in = gr.Image(type="pil", label="Imagen")
        img_out = gr.Image(type="pil", label="Resultado")
    txt_out = gr.Textbox(label="Diagnóstico")
    btn = gr.Button("Analizar", variant="primary")
    btn.click(predecir_enfermedad, img_in, [img_out, txt_out])

# Critical settings for Docker on Hugging Face: listen on all interfaces,
# port 7860.
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)