juancho2112 commited on
Commit
fbff75c
·
verified ·
1 Parent(s): 5678216

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -107
app.py DELETED
@@ -1,107 +0,0 @@
1
- import gradio as gr
2
- import torch
3
- import numpy as np
4
- from PIL import Image
5
- from torch import nn
6
- from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor
7
- import os
8
-
9
- # --- 1. CONFIGURACIÓN ---
10
- MODEL_PATH = "modelo_mejorado.pth"
11
-
12
- LABELS = [
13
- "fondo", "wheat leaf rust", "wheat powdery mildew",
14
- "wheat septoria blotch", "wheat stem rust", "wheat stripe rust"
15
- ]
16
-
17
- # Paleta (R, G, B)
18
- PALETA_COLORES = [
19
- [0, 0, 0], [220, 38, 38], [22, 163, 74],
20
- [37, 99, 235], [234, 179, 8], [219, 39, 119]
21
- ]
22
-
23
- # En Docker CPU (Free Tier) forzamos CPU para evitar errores de memoria
24
- device = torch.device("cpu")
25
- print(f"Usando dispositivo: {device}")
26
-
27
- # --- 2. CARGAR MODELO ---
28
- checkpoint_name = "nvidia/segformer-b4-finetuned-ade-512-512"
29
-
30
- try:
31
- model_inference = SegformerForSemanticSegmentation.from_pretrained(
32
- checkpoint_name,
33
- num_labels=len(LABELS),
34
- id2label={i: label for i, label in enumerate(LABELS)},
35
- label2id={label: i for i, label in enumerate(LABELS)},
36
- ignore_mismatched_sizes=True
37
- )
38
- # Cargar pesos
39
- if os.path.exists(MODEL_PATH):
40
- state_dict = torch.load(MODEL_PATH, map_location=device)
41
- model_inference.load_state_dict(state_dict)
42
- print("✅ Modelo cargado correctamente.")
43
- else:
44
- print("⚠️ NO se encontró el archivo .pth")
45
-
46
- model_inference.to(device)
47
- model_inference.eval()
48
-
49
- image_processor = SegformerImageProcessor.from_pretrained(checkpoint_name)
50
- image_processor.do_resize = False
51
- image_processor.do_rescale = True
52
-
53
- except Exception as e:
54
- print(f"Error fatal cargando modelo: {e}")
55
-
56
- # --- 3. PREDICCIÓN ---
57
- def predecir_enfermedad(image):
58
- if image is None: return None, "⚠️ Sube una imagen."
59
-
60
- original_size = image.size
61
- img_resized = image.resize((512, 512))
62
-
63
- inputs = image_processor(images=img_resized, return_tensors="pt")
64
- inputs = {k: v.to(device) for k, v in inputs.items()}
65
-
66
- with torch.no_grad():
67
- outputs = model_inference(**inputs)
68
- logits = outputs.logits
69
- logits_upsampled = nn.functional.interpolate(logits, size=(512, 512), mode="bilinear", align_corners=False)
70
- pred_mask = logits_upsampled.argmax(dim=1).squeeze().cpu().numpy()
71
-
72
- color_mask = np.zeros((512, 512, 3), dtype=np.uint8)
73
- classes_found = []
74
- unique_classes = np.unique(pred_mask)
75
-
76
- for class_id in unique_classes:
77
- if class_id == 0: continue
78
- classes_found.append(LABELS[class_id])
79
- color_mask[pred_mask == class_id] = PALETA_COLORES[class_id]
80
-
81
- mask_pil = Image.fromarray(color_mask).resize(original_size, resample=Image.NEAREST)
82
- final_image = Image.blend(image.convert("RGB"), mask_pil.convert("RGB"), alpha=0.45)
83
-
84
- if len(classes_found) > 0:
85
- diagnosis = "Enfermedades:\n" + "\n".join(f"- {c}" for c in classes_found)
86
- else:
87
- diagnosis = "Planta Sana (Solo fondo)."
88
-
89
- return final_image, diagnosis
90
-
91
- # --- 4. INTERFAZ ---
92
- # Usamos Blocks simples compatibles con versiones estables
93
- with gr.Blocks(title="Wheat AI") as demo:
94
- gr.Markdown("# 🌾 Detector de Enfermedades en Trigo")
95
-
96
- with gr.Row():
97
- img_in = gr.Image(type="pil", label="Imagen")
98
- img_out = gr.Image(type="pil", label="Resultado")
99
-
100
- txt_out = gr.Textbox(label="Diagnóstico")
101
- btn = gr.Button("Analizar", variant="primary")
102
-
103
- btn.click(predecir_enfermedad, img_in, [img_out, txt_out])
104
-
105
- # Configuración crítica para Docker en Hugging Face
106
- if __name__ == "__main__":
107
- demo.launch(server_name="0.0.0.0", server_port=7860)