import os
import re
import base64
import tempfile
from types import SimpleNamespace

import gradio as gr
import pandas as pd
import numpy as np
import torch
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from chronos import ChronosPipeline
# ----------------------------
# Model (lightweight enough for a free Space)
# ----------------------------
MODEL_ID = "amazon/chronos-t5-small"  # or "amazon/chronos-t5-mini"

PIPELINE = ChronosPipeline.from_pretrained(
    MODEL_ID,
    device_map="auto",
    dtype=torch.float32,  # use dtype (not torch_dtype)
)
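
# Expected input: a CSV with a "date" column, a numeric "value" column, and an
# optional "sku" column. A minimal illustrative example (hypothetical data):
#
#   sku,date,value
#   A-001,2023-01-01,120
#   A-001,2023-02-01,135
#   A-001,2023-03-01,128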
# ----------------------------
# "Pro" styling for the plots
# ----------------------------
plt.rcParams.update({
    "figure.figsize": (9, 4.8),
    "figure.facecolor": "#ffffff",
    "axes.facecolor": "#ffffff",
    "axes.grid": True,
    "grid.color": "#e6e6e6",
    "grid.linestyle": "-",
    "grid.linewidth": 0.6,
    "axes.spines.top": False,
    "axes.spines.right": False,
    "axes.spines.left": False,
    "axes.spines.bottom": False,
    "axes.titlesize": 16,
    "axes.titleweight": "semibold",
    "axes.labelsize": 12,
    "legend.frameon": True,
    "legend.framealpha": 0.9,
})
def _prepare_series(df: pd.DataFrame, freq: str | None):
    """Validate, sort, and regularize the series onto a fixed-frequency grid."""
    if "date" not in df.columns or "value" not in df.columns:
        raise gr.Error("The CSV must have columns: date,value")
    df = df.copy()
    df["date"] = pd.to_datetime(df["date"])
    df = df.sort_values("date")
    if freq and freq.strip():
        df = df.set_index("date").asfreq(freq).reset_index()
    else:
        inferred = pd.infer_freq(df["date"])
        if inferred is None:
            # Fall back to the median spacing in days (at least 1 day)
            step = max(int((df["date"].diff().median() / pd.Timedelta(days=1)) or 1), 1)
            df = df.set_index("date").asfreq(f"{step}D").reset_index()
        else:
            df = df.set_index("date").asfreq(inferred).reset_index()
    df["value"] = pd.to_numeric(df["value"], errors="coerce")
    # Fill gaps created by asfreq: interpolate, then pad the edges
    df["value"] = df["value"].interpolate("linear").bfill().ffill()
    return df
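
# For example (hypothetical values), a monthly series with one missing month:
#
#   date,value        -> _prepare_series(df, "MS") ->  date,value
#   2023-01-01,10                                      2023-01-01,10.0
#   2023-03-01,30                                      2023-02-01,20.0  (interpolated)
#                                                      2023-03-01,30.0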
def _filter_by_sku(df: pd.DataFrame, sku: str | None):
    if "sku" in df.columns:
        if not sku:
            raise gr.Error("Select a SKU from the list.")
        sdf = df[df["sku"].astype(str) == str(sku)].copy()
        if sdf.empty:
            raise gr.Error(f"No data for SKU: {sku}")
        return sdf[["date", "value"]]
    return df[["date", "value"]].copy()
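
# E.g. if the CSV has columns (sku, date, value), _filter_by_sku(df, "A-001")
# keeps only that SKU's rows; without a "sku" column the whole series is used
# as-is ("A-001" is a hypothetical identifier).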
def _nice_plot(df_hist: pd.DataFrame, df_fc: pd.DataFrame, std: np.ndarray) -> plt.Figure:
    """
    Plot with:
      - P10, P50, P90 lines
      - ±1σ band around P50
      - x-axis with short year
      - dashed visual connection from history to P50 (t to t+1)
    """
    fig, ax = plt.subplots()

    # History
    ax.plot(df_hist["date"], df_hist["value"], label="History", linewidth=2.2, color="C0")

    # Quantile lines
    ax.plot(df_fc["date"], df_fc["p10"], label="P10", linewidth=1.8, color="C2")
    ax.plot(df_fc["date"], df_fc["p50"], label="P50 (median)", linewidth=2.4, color="C1")
    ax.plot(df_fc["date"], df_fc["p90"], label="P90", linewidth=1.8, color="C3")

    # ±1σ band around P50
    lo = df_fc["p50"] - std
    hi = df_fc["p50"] + std
    ax.fill_between(df_fc["date"], lo, hi, alpha=0.18, label="±1σ around P50",
                    color="C1", edgecolor="none")

    # Visual connection from the last actual to the first P50 (t -> t+1)
    last_date = df_hist["date"].iloc[-1]
    last_val = df_hist["value"].iloc[-1]
    first_date = df_fc["date"].iloc[0]
    first_p50 = df_fc["p50"].iloc[0]
    ax.plot([last_date, first_date], [last_val, first_p50],
            linestyle="--", linewidth=1.6, color="C1", alpha=0.9,
            label="History→P50 connection")

    # Date formatting (short year)
    ax.xaxis.set_major_locator(mdates.AutoDateLocator())
    ax.xaxis.set_major_formatter(mdates.DateFormatter("%y-%m"))

    ax.set_title("Chronos-T5 Forecast (P10 / P50 / P90 + ±1σ)")
    ax.set_xlabel("Date")
    ax.set_ylabel("Value")
    ax.legend(loc="upper left")
    fig.tight_layout(pad=1.2)
    return fig
def _safe_name(text: str) -> str:
    # Keep only filename-safe characters; cap the length at 50
    return re.sub(r"[^A-Za-z0-9._-]+", "-", str(text))[:50]
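
# Illustrative example: _safe_name("café latte 2024") -> "caf-latte-2024"
# (each run of disallowed characters collapses to a single "-").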
def forecast_fn(file, sku: str, horizon: int = 12, freq: str = "MS"):
    if file is None:
        raise gr.Error("Upload a CSV with columns: (sku,) date, value")
    raw = pd.read_csv(file.name)
    df = _filter_by_sku(raw, sku)
    df = _prepare_series(df, freq.strip() or None)
    horizon = int(horizon)  # sliders/API calls may deliver a float

    # Series to tensor
    y = torch.tensor(df["value"].values, dtype=torch.float32)

    # Probabilistic prediction (multiple sample paths)
    samples = PIPELINE.predict(y, prediction_length=horizon, num_samples=200)  # [1, N, H]
    samples = samples[0].cpu().numpy()  # [N, H]; .cpu() is a no-op on CPU
    p10, p50, p90 = np.quantile(samples, [0.10, 0.50, 0.90], axis=0)
    std = samples.std(axis=0)  # standard deviation of the predictive distribution at each step

    # Future dates: generate horizon+1 points from the last observed date, drop the first
    inferred = pd.infer_freq(df["date"])
    if inferred is None:
        step = max(int((df["date"].diff().median() / pd.Timedelta(days=1)) or 1), 1)
        future_index = pd.date_range(df["date"].iloc[-1], periods=horizon+1, freq=f"{step}D")[1:]
    else:
        future_index = pd.date_range(df["date"].iloc[-1], periods=horizon+1, freq=inferred)[1:]

    out = pd.DataFrame({
        "date": future_index,
        "p10": np.round(p10, 4),
        "p50": np.round(p50, 4),
        "p90": np.round(p90, 4),
        "std": np.round(std, 4)  # <-- add σ to the output
    })

    # Plot
    fig = _nice_plot(df, out, std)

    # Save the figure as a temporary PNG
    tmp_plot_dir = tempfile.mkdtemp(prefix="plot_")
    plot_path = os.path.join(tmp_plot_dir, "forecast.png")
    fig.savefig(plot_path, dpi=150, bbox_inches="tight")
    plt.close(fig)  # free memory

    # Downloadable file
    sku_name = _safe_name(sku) if sku else "series"
    fname = f"forecast_{sku_name}.csv"
    tmp_dir = tempfile.mkdtemp(prefix="fcst_")
    tmp_path = os.path.join(tmp_dir, fname)
    out.to_csv(tmp_path, index=False, encoding="utf-8")

    # Markdown summary (first step + mean σ over the horizon)
    md = (
        f"**Summary (first step):**  \n"
        f"- P10: **{out['p10'].iloc[0]:.2f}**  \n"
        f"- P50: **{out['p50'].iloc[0]:.2f}**  \n"
        f"- P90: **{out['p90'].iloc[0]:.2f}**  \n"
        f"- σ (std. dev.): **{out['std'].iloc[0]:.2f}**  \n\n"
        f"**Mean σ over the horizon:** **{out['std'].mean():.2f}**"
    )
    return out, plot_path, tmp_path, md
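
# A minimal local smoke test (commented out; uses synthetic, hypothetical data
# and only the functions defined above):
#
#   dates = pd.date_range("2022-01-01", periods=24, freq="MS")
#   pd.DataFrame({"date": dates, "value": np.arange(24) + 100.0}).to_csv("demo.csv", index=False)
#   table, plot_path, csv_path, md = forecast_fn(SimpleNamespace(name="demo.csv"), "", 6, "MS")
#   print(table.head())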
def forecast_from_text(csv_text_or_b64: str, sku: str, horizon: int = 12, freq: str = "MS"):
    """
    Receives the CSV as plain text or base64 (without a 'data:' prefix), writes it
    to a temporary file, and calls forecast_fn, reusing all of the existing logic.
    Returns: (table, image filepath, csv filepath, markdown) in the same order.
    """
    txt = (csv_text_or_b64 or "").strip()

    # Looks like base64? Try to decode; on failure, treat it as plain text.
    try:
        if "\n" not in txt and all(c.isalnum() or c in "+/=\n\r" for c in txt):
            raw = base64.b64decode(txt)
            csv_text = raw.decode("utf-8", errors="replace")
        else:
            csv_text = txt
    except Exception:
        csv_text = txt

    # Write to a temporary .csv file
    tmp = tempfile.NamedTemporaryFile("w+", suffix=".csv", delete=False)
    tmp.write(csv_text)
    tmp.flush()
    tmp.close()

    # Wrap in an object with a .name attribute so forecast_fn reads it like a gr.File
    dummy_file = SimpleNamespace(name=tmp.name)

    # Reuse the original pipeline
    return forecast_fn(dummy_file, sku, horizon, freq)
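
# Illustrative round trip (hypothetical values): a plain-text CSV and its
# base64 form are both accepted by forecast_from_text:
#
#   csv_text = "date,value\n2023-01-01,10\n2023-02-01,12\n2023-03-01,14\n"
#   b64 = base64.b64encode(csv_text.encode("utf-8")).decode("ascii")
#   forecast_from_text(csv_text, "", 6, "MS")   # plain text
#   forecast_from_text(b64, "", 6, "MS")        # base64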
def list_skus(file):
    if file is None:
        return gr.update(choices=[], value=None), None
    df = pd.read_csv(file.name)
    if "sku" in df.columns:
        skus = sorted(df["sku"].dropna().astype(str).unique().tolist())
        if not skus:
            return gr.update(choices=[], value=None), df.head(10)
        return gr.update(choices=skus, value=skus[0]), df.head(10)
    return gr.update(choices=[], value=None), df.head(10)
with gr.Blocks(title="SKU Forecasting (Chronos-T5)") as demo:
    gr.Markdown(
        "## Demand Forecasting by **SKU**  \n"
        "CSV with columns: **sku (optional)**, **date**, **value**. "
        "Select the SKU and generate the forecast."
    )
    with gr.Row():
        file = gr.File(label="CSV (sku,date,value or date,value)", file_types=[".csv"])
        sku_dd = gr.Dropdown(
            choices=[], value=None,
            label="SKU (if the CSV has a 'sku' column)",
            allow_custom_value=True
        )
    horizon = gr.Slider(1, 36, value=12, step=1, label="Horizon (steps)")
    freq = gr.Dropdown(choices=["", "D", "W", "MS", "M"], value="MS",
                       label="Frequency: ''=infer, MS=month start")
    preview = gr.Dataframe(label="Preview (first rows)")
    file.change(list_skus, inputs=file, outputs=[sku_dd, preview])

    btn = gr.Button("Generate forecast", variant="primary")
    out_table = gr.Dataframe(label="Forecast table")
    out_plot = gr.Image(type="filepath", label="Plot")  # no height kwarg (compat)
    download_file = gr.File(label="⬇️ Download forecast (CSV)", interactive=False)
    stats_md = gr.Markdown(label="Summary")

    # API-only endpoint that receives the CSV as text/base64 (no file upload)
    api_csv_text = gr.Textbox(visible=False)
    gr.Button(visible=False).click(
        fn=forecast_from_text,
        inputs=[api_csv_text, sku_dd, horizon, freq],  # SAME order the Netlify client will use
        outputs=[out_table, out_plot, download_file, stats_md],
        api_name="/forecast_text"
    )

    btn.click(
        forecast_fn,
        inputs=[file, sku_dd, horizon, freq],
        outputs=[out_table, out_plot, download_file, stats_md],
        api_name="/forecast"
    )
if __name__ == "__main__":
    demo.queue().launch(server_name="0.0.0.0", server_port=7860, show_error=True)
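
# Calling the API endpoints remotely — a sketch using gradio_client, assuming
# the Space is published at "your-user/your-space" (hypothetical name):
#
#   from gradio_client import Client
#   client = Client("your-user/your-space")
#   table, plot, csv_file, md = client.predict(
#       "date,value\n2023-01-01,10\n2023-02-01,12\n2023-03-01,14\n",  # CSV as text
#       "",      # sku (empty if the CSV has no 'sku' column)
#       6,       # horizon
#       "MS",    # freq
#       api_name="/forecast_text",
#   )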