cesparzaf commited on
Commit
3c116f2
·
1 Parent(s): 3fac025

Descarga CSV y gráfica pro (estilo limpio, año corto, anotación)

Browse files
Files changed (1) hide show
  1. app.py +101 -35
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import gradio as gr
2
  import pandas as pd
3
  import numpy as np
@@ -6,15 +7,39 @@ import matplotlib.pyplot as plt
6
  import matplotlib.dates as mdates
7
  from chronos import ChronosPipeline
8
 
9
- # Modelo compatible con chronos-forecasting==1.5.3 (free tier)
10
- MODEL_ID = "amazon/chronos-t5-base"
 
 
11
 
12
  PIPELINE = ChronosPipeline.from_pretrained(
13
  MODEL_ID,
14
  device_map="auto",
15
- dtype=torch.float32, # usar dtype (no torch_dtype)
16
  )
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  def _prepare_series(df: pd.DataFrame, freq: str | None):
19
  """
20
  Espera columnas: date,value
@@ -24,6 +49,7 @@ def _prepare_series(df: pd.DataFrame, freq: str | None):
24
  """
25
  if "date" not in df.columns or "value" not in df.columns:
26
  raise gr.Error("El CSV debe tener columnas: date,value")
 
27
  df = df.copy()
28
  df["date"] = pd.to_datetime(df["date"])
29
  df = df.sort_values("date")
@@ -53,15 +79,54 @@ def _filter_by_sku(df: pd.DataFrame, sku: str | None):
53
  sdf = df[df["sku"].astype(str) == str(sku)].copy()
54
  if sdf.empty:
55
  raise gr.Error(f"No hay datos para el SKU: {sku}")
56
- # Mantén solo las columnas mínimas para el pipeline
57
- sdf = sdf[["date", "value"]]
58
- return sdf
59
  return df[["date", "value"]].copy()
60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  def forecast_fn(file, sku: str, horizon: int = 12, freq: str = "MS"):
62
  """
63
- Función principal: lee CSV, filtra por SKU (si hay), prepara serie,
64
- ejecuta pronóstico y devuelve tabla + gráfica.
65
  """
66
  if file is None:
67
  raise gr.Error("Sube un CSV con columnas: (sku,) date, value")
@@ -93,28 +158,17 @@ def forecast_fn(file, sku: str, horizon: int = 12, freq: str = "MS"):
93
  "p90": np.round(p90, 4),
94
  })
95
 
96
- # Gráfica con AÑO CORTO (yy-mm)
97
- fig = plt.figure(figsize=(8, 4))
98
- ax = plt.gca()
99
- ax.plot(df["date"], df["value"], label="Histórico")
100
- ax.plot(out["date"], out["p50"], label="Pronóstico (P50)")
101
- ax.fill_between(out["date"], out["p10"], out["p90"], alpha=0.3, label="Banda P10–P90")
102
- ax.set_title("Pronóstico con Chronos-T5 (P10 / P50 / P90)")
103
- ax.set_xlabel("Fecha")
104
- ax.set_ylabel("Valor")
105
- # Formato de fechas: año corto + mes (yy-mm), p. ej. 24-07
106
- ax.xaxis.set_major_locator(mdates.AutoDateLocator())
107
- ax.xaxis.set_major_formatter(mdates.DateFormatter("%y-%m"))
108
- plt.xticks(rotation=45)
109
- ax.legend()
110
 
111
- return out, fig
 
112
 
113
  def list_skus(file):
114
- """
115
- Al cargar un archivo, detecta los SKUs disponibles (si la columna existe)
116
- y llena el dropdown dinámicamente.
117
- """
118
  if file is None:
119
  return gr.update(choices=[], value=None), None
120
  df = pd.read_csv(file.name)
@@ -123,28 +177,40 @@ def list_skus(file):
123
  if not skus:
124
  return gr.update(choices=[], value=None), df.head(10)
125
  return gr.update(choices=skus, value=skus[0]), df.head(10)
126
- # Sin columna 'sku'
127
  return gr.update(choices=[], value=None), df.head(10)
128
 
129
  with gr.Blocks(title="Pronóstico por SKU (Chronos-T5)") as demo:
130
- gr.Markdown("## Pronóstico de Demanda por **SKU** (Hugging Face + Chronos-T5)\nCSV con columnas: **sku (opcional)**, **date**, **value**. Selecciona el SKU y genera el forecast.")
 
 
 
 
131
  with gr.Row():
132
  file = gr.File(label="CSV (sku,date,value o date,value)", file_types=[".csv"])
133
  sku_dd = gr.Dropdown(choices=[], value=None, label="SKU (si el CSV tiene columna 'sku')")
134
  horizon = gr.Slider(1, 36, value=12, step=1, label="Horizonte (pasos)")
135
  freq = gr.Dropdown(choices=["", "D", "W", "MS", "M"], value="MS",
136
  label="Frecuencia. ''=inferir, MS=mensual")
137
- # Vista previa de datos
138
  preview = gr.Dataframe(label="Vista previa (primeras filas)")
139
-
140
- # Cuando cambia el archivo, llenar dropdown de SKUs y mostrar preview
141
  file.change(list_skus, inputs=file, outputs=[sku_dd, preview])
142
 
143
- btn = gr.Button("Generar pronóstico")
144
  out_table = gr.Dataframe(label="Tabla de pronóstico")
145
  out_plot = gr.Plot(label="Gráfica")
146
- btn.click(forecast_fn, inputs=[file, sku_dd, horizon, freq], outputs=[out_table, out_plot], api_name="/forecast")
 
 
 
 
 
 
 
 
 
 
 
 
 
147
 
148
  if __name__ == "__main__":
149
  demo.queue().launch(server_name="0.0.0.0", server_port=7860, show_error=True)
150
-
 
1
+ import io
2
  import gradio as gr
3
  import pandas as pd
4
  import numpy as np
 
7
  import matplotlib.dates as mdates
8
  from chronos import ChronosPipeline
9
 
10
+ # ----------------------------
11
+ # Modelo (ligero para Space free)
12
+ # ----------------------------
13
+ MODEL_ID = "amazon/chronos-t5-small" # o "amazon/chronos-t5-mini"
14
 
15
  PIPELINE = ChronosPipeline.from_pretrained(
16
  MODEL_ID,
17
  device_map="auto",
18
+ dtype=torch.float32, # usar dtype (no torch_dtype)
19
  )
20
 
21
+ # ----------------------------
22
+ # Estilo "más pro" para las gráficas
23
+ # ----------------------------
24
+ plt.rcParams.update({
25
+ "figure.figsize": (9, 4.8),
26
+ "figure.facecolor": "#ffffff",
27
+ "axes.facecolor": "#ffffff",
28
+ "axes.grid": True,
29
+ "grid.color": "#e6e6e6",
30
+ "grid.linestyle": "-",
31
+ "grid.linewidth": 0.6,
32
+ "axes.spines.top": False,
33
+ "axes.spines.right": False,
34
+ "axes.spines.left": False,
35
+ "axes.spines.bottom": False,
36
+ "axes.titlesize": 16,
37
+ "axes.titleweight": "semibold",
38
+ "axes.labelsize": 12,
39
+ "legend.frameon": True,
40
+ "legend.framealpha": 0.9,
41
+ })
42
+
43
  def _prepare_series(df: pd.DataFrame, freq: str | None):
44
  """
45
  Espera columnas: date,value
 
49
  """
50
  if "date" not in df.columns or "value" not in df.columns:
51
  raise gr.Error("El CSV debe tener columnas: date,value")
52
+
53
  df = df.copy()
54
  df["date"] = pd.to_datetime(df["date"])
55
  df = df.sort_values("date")
 
79
  sdf = df[df["sku"].astype(str) == str(sku)].copy()
80
  if sdf.empty:
81
  raise gr.Error(f"No hay datos para el SKU: {sku}")
82
+ return sdf[["date", "value"]]
 
 
83
  return df[["date", "value"]].copy()
84
 
85
+ def _nice_plot(df_hist: pd.DataFrame, df_fc: pd.DataFrame) -> plt.Figure:
86
+ """Grafica con estilo pro + año corto en eje X y anotación del último P50."""
87
+ fig, ax = plt.subplots()
88
+
89
+ # Líneas
90
+ ax.plot(df_hist["date"], df_hist["value"], label="Histórico", linewidth=2.2, color="C0")
91
+ ax.plot(df_fc["date"], df_fc["p50"], label="Pronóstico (P50)", linewidth=2.4, color="C1")
92
+
93
+ # Banda de incertidumbre
94
+ ax.fill_between(df_fc["date"], df_fc["p10"], df_fc["p90"],
95
+ alpha=0.18, label="Banda P10–P90", color="C1", edgecolor="none")
96
+
97
+ # Formato de fechas: año corto (yy-mm)
98
+ ax.xaxis.set_major_locator(mdates.AutoDateLocator())
99
+ ax.xaxis.set_major_formatter(mdates.DateFormatter("%y-%m"))
100
+ plt.setp(ax.get_xticklabels(), rotation=0, ha="center")
101
+
102
+ # Etiquetas y título
103
+ ax.set_title("Pronóstico con Chronos-T5 (P10 / P50 / P90)")
104
+ ax.set_xlabel("Fecha")
105
+ ax.set_ylabel("Valor")
106
+
107
+ # Anotar último P50
108
+ x_last = df_fc["date"].iloc[-1]
109
+ y_last = df_fc["p50"].iloc[-1]
110
+ ax.scatter([x_last], [y_last], color="C1", s=30, zorder=3)
111
+ ax.annotate(f"P50={y_last:.1f}", xy=(x_last, y_last),
112
+ xytext=(10, 10), textcoords="offset points",
113
+ fontsize=10, bbox=dict(boxstyle="round,pad=0.25", fc="#f5f5f5", ec="#cccccc"))
114
+
115
+ # Leyenda fuera de la serie para no tapar
116
+ leg = ax.legend(loc="upper left")
117
+ for lh in leg.legendHandles:
118
+ try:
119
+ lh.set_alpha(1.0)
120
+ except Exception:
121
+ pass
122
+
123
+ fig.tight_layout(pad=1.2)
124
+ return fig
125
+
126
  def forecast_fn(file, sku: str, horizon: int = 12, freq: str = "MS"):
127
  """
128
+ Lee CSV, filtra por SKU (si hay), prepara serie,
129
+ ejecuta pronóstico y devuelve tabla + gráfica + archivo descargable.
130
  """
131
  if file is None:
132
  raise gr.Error("Sube un CSV con columnas: (sku,) date, value")
 
158
  "p90": np.round(p90, 4),
159
  })
160
 
161
+ # Gráfica pulida
162
+ fig = _nice_plot(df, out)
163
+
164
+ # --- Descarga: generar CSV en memoria y devolver bytes ---
165
+ csv_bytes = out.to_csv(index=False).encode("utf-8")
 
 
 
 
 
 
 
 
 
166
 
167
+ # Retornamos: tabla, gráfica, y contenido para DownloadButton
168
+ return out, fig, csv_bytes
169
 
170
  def list_skus(file):
171
+ """Detecta SKUs (si existe la columna) y llena el dropdown dinámicamente."""
 
 
 
172
  if file is None:
173
  return gr.update(choices=[], value=None), None
174
  df = pd.read_csv(file.name)
 
177
  if not skus:
178
  return gr.update(choices=[], value=None), df.head(10)
179
  return gr.update(choices=skus, value=skus[0]), df.head(10)
 
180
  return gr.update(choices=[], value=None), df.head(10)
181
 
182
  with gr.Blocks(title="Pronóstico por SKU (Chronos-T5)") as demo:
183
+ gr.Markdown(
184
+ "## Pronóstico de Demanda por **SKU**\n"
185
+ "CSV con columnas: **sku (opcional)**, **date**, **value**. "
186
+ "Selecciona el SKU y genera el forecast."
187
+ )
188
  with gr.Row():
189
  file = gr.File(label="CSV (sku,date,value o date,value)", file_types=[".csv"])
190
  sku_dd = gr.Dropdown(choices=[], value=None, label="SKU (si el CSV tiene columna 'sku')")
191
  horizon = gr.Slider(1, 36, value=12, step=1, label="Horizonte (pasos)")
192
  freq = gr.Dropdown(choices=["", "D", "W", "MS", "M"], value="MS",
193
  label="Frecuencia. ''=inferir, MS=mensual")
 
194
  preview = gr.Dataframe(label="Vista previa (primeras filas)")
 
 
195
  file.change(list_skus, inputs=file, outputs=[sku_dd, preview])
196
 
197
+ btn = gr.Button("Generar pronóstico", variant="primary")
198
  out_table = gr.Dataframe(label="Tabla de pronóstico")
199
  out_plot = gr.Plot(label="Gráfica")
200
+
201
+ # Botón de descarga (entrega bytes -> archivo "forecast.csv")
202
+ download_btn = gr.DownloadButton(
203
+ label="⬇️ Descargar pronóstico (CSV)",
204
+ value=None, # se asigna en tiempo de ejecución
205
+ file_name="forecast.csv" # nombre sugerido
206
+ )
207
+
208
+ btn.click(
209
+ forecast_fn,
210
+ inputs=[file, sku_dd, horizon, freq],
211
+ outputs=[out_table, out_plot, download_btn],
212
+ api_name="/forecast"
213
+ )
214
 
215
  if __name__ == "__main__":
216
  demo.queue().launch(server_name="0.0.0.0", server_port=7860, show_error=True)