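# Web-view residue captured together with the file (not program code):
#   avatar caption: "cesparzaf's picture"
#   commit message: "Actualizo app.py y requirements"
#   commit id: 2828dc3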
import os, re, tempfile
import gradio as gr
import pandas as pd
import numpy as np
import torch
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from chronos import ChronosPipeline
import io, base64
from types import SimpleNamespace
# ----------------------------
# Model (lightweight, suitable for a free Space)
# ----------------------------
# Checkpoint identifier handed to ChronosPipeline.from_pretrained below.
MODEL_ID = "amazon/chronos-t5-small" # or "amazon/chronos-t5-mini"
# Built once at import time; forecast_fn reuses this module-level pipeline.
PIPELINE = ChronosPipeline.from_pretrained(
    MODEL_ID,
    device_map="auto",
    dtype=torch.float32, # use dtype (not torch_dtype)
)
# ----------------------------
# "Pro" look for the charts
# ----------------------------
# Global matplotlib defaults; every figure created afterwards inherits them.
plt.rcParams.update({
    "figure.figsize": (9, 4.8),
    # White canvas with a light grey grid.
    "figure.facecolor": "#ffffff",
    "axes.facecolor": "#ffffff",
    "axes.grid": True,
    "grid.color": "#e6e6e6",
    "grid.linestyle": "-",
    "grid.linewidth": 0.6,
    # All four axes spines are switched off.
    "axes.spines.top": False,
    "axes.spines.right": False,
    "axes.spines.left": False,
    "axes.spines.bottom": False,
    "axes.titlesize": 16,
    "axes.titleweight": "semibold",
    "axes.labelsize": 12,
    # Legend drawn inside a mostly opaque frame.
    "legend.frameon": True,
    "legend.framealpha": 0.9,
})
def _prepare_series(df: pd.DataFrame, freq: str | None):
if "date" not in df.columns or "value" not in df.columns:
raise gr.Error("El CSV debe tener columnas: date,value")
df = df.copy()
df["date"] = pd.to_datetime(df["date"])
df = df.sort_values("date")
if freq and freq.strip():
df = df.set_index("date").asfreq(freq).reset_index()
else:
inferred = pd.infer_freq(df["date"])
if inferred is None:
step = max(int((df["date"].diff().median() / pd.Timedelta(days=1)) or 1), 1)
df = df.set_index("date").asfreq(f"{step}D").reset_index()
else:
df = df.set_index("date").asfreq(inferred).reset_index()
df["value"] = pd.to_numeric(df["value"], errors="coerce")
df["value"] = df["value"].interpolate("linear").bfill().ffill()
return df
def _filter_by_sku(df: pd.DataFrame, sku: str | None):
if "sku" in df.columns:
if not sku:
raise gr.Error("Selecciona un SKU del listado.")
sdf = df[df["sku"].astype(str) == str(sku)].copy()
if sdf.empty:
raise gr.Error(f"No hay datos para el SKU: {sku}")
return sdf[["date", "value"]]
return df[["date", "value"]].copy()
def _nice_plot(df_hist: pd.DataFrame, df_fc: pd.DataFrame, std: np.ndarray) -> plt.Figure:
    """Draw the history and the probabilistic forecast on one figure.

    The chart contains:
    - the P10, P50 and P90 lines,
    - a ±1σ band around P50,
    - an X axis formatted as two-digit year and month,
    - a dashed connector from the last observation to the first P50 value.

    Args:
        df_hist: History with columns ``date`` and ``value``.
        df_fc: Forecast with columns ``date``, ``p10``, ``p50`` and ``p90``.
        std: Per-step standard deviation, one entry per forecast row.

    Returns:
        The matplotlib figure; the caller is responsible for closing it.
    """
    fig, ax = plt.subplots()

    # History
    ax.plot(df_hist["date"], df_hist["value"], label="Histórico", linewidth=2.2, color="C0")

    # Quantile lines
    ax.plot(df_fc["date"], df_fc["p10"], label="P10", linewidth=1.8, color="C2")
    ax.plot(df_fc["date"], df_fc["p50"], label="P50 (mediana)", linewidth=2.4, color="C1")
    ax.plot(df_fc["date"], df_fc["p90"], label="P90", linewidth=1.8, color="C3")

    # ±1σ band around P50. Work on plain arrays so the arithmetic is
    # positional even if *std* arrives as an indexed Series.
    median = df_fc["p50"].to_numpy()
    spread = np.asarray(std)
    ax.fill_between(df_fc["date"], median - spread, median + spread, alpha=0.18,
                    label="±1σ alrededor de P50", color="C1", edgecolor="none")

    # Visual connector from the last actual value to the first P50 (t -> t+1);
    # skipped when either side has no rows to connect.
    if len(df_hist) and len(df_fc):
        last_date = df_hist["date"].iloc[-1]
        last_val = df_hist["value"].iloc[-1]
        first_date = df_fc["date"].iloc[0]
        first_p50 = df_fc["p50"].iloc[0]
        ax.plot([last_date, first_date], [last_val, first_p50],
                linestyle="--", linewidth=1.6, color="C1", alpha=0.9,
                label="Conexión hist→P50")

    # Date formatting (short year)
    ax.xaxis.set_major_locator(mdates.AutoDateLocator())
    ax.xaxis.set_major_formatter(mdates.DateFormatter("%y-%m"))

    ax.set_title("Pronóstico con Chronos-T5 (P10 / P50 / P90 + ±1σ)")
    ax.set_xlabel("Fecha")
    ax.set_ylabel("Valor")
    ax.legend(loc="upper left")
    fig.tight_layout(pad=1.2)
    return fig
def _safe_name(text: str) -> str:
    """Turn *text* into a filename-safe token of at most 50 characters.

    Each run of characters other than ASCII letters, digits, ``.``, ``_`` and
    ``-`` collapses into a single ``-``; the result is then cut to 50 chars.
    """
    unsafe_run = re.compile(r"[^A-Za-z0-9._-]+")
    cleaned = unsafe_run.sub("-", str(text))
    return cleaned[:50]
def _forecast_step(dates: pd.Series, freq: str | None) -> str:
    """Return the frequency string used to extend *dates* into the future."""
    if freq:
        return freq
    try:
        inferred = pd.infer_freq(dates)
    except ValueError:
        # pandas cannot infer a frequency from fewer than three timestamps.
        inferred = None
    if inferred:
        return inferred
    gap = dates.diff().median()
    days = 0 if pd.isna(gap) else int(gap / pd.Timedelta(days=1))
    return f"{max(days, 1)}D"


def forecast_fn(file, sku: str, horizon: int = 12, freq: str = "MS"):
    """Forecast one series from an uploaded CSV and build every UI output.

    Args:
        file: Uploaded file object exposing its path as ``.name``; a plain
            path string is accepted as well.
        sku: SKU to forecast (only used when the CSV has a ``sku`` column).
        horizon: Number of future steps to predict (at least 1).
        freq: pandas frequency string; blank or ``None`` means "infer".

    Returns:
        A 4-tuple ``(table, plot_path, csv_path, markdown)``: the forecast
        DataFrame (``date``, ``p10``, ``p50``, ``p90``, ``std``), the path of
        the PNG chart, the path of the downloadable CSV and a Markdown summary.

    Raises:
        gr.Error: if no file was given, the horizon is below 1, or the
            prepared series holds no usable values.
    """
    if file is None:
        raise gr.Error("Sube un CSV con columnas: (sku,) date, value")

    # API callers may send the slider value as a float and the frequency as null.
    horizon = int(horizon)
    if horizon < 1:
        raise gr.Error("El horizonte debe ser al menos 1.")
    freq = (freq or "").strip() or None

    raw = pd.read_csv(getattr(file, "name", file))
    df = _filter_by_sku(raw, sku)
    df = _prepare_series(df, freq)
    # A frequency that matches none of the dates leaves nothing to forecast.
    if df.empty or df["value"].isna().all():
        raise gr.Error(
            "La serie no tiene datos utilizables; "
            "revisa las fechas, la frecuencia y la columna value."
        )

    # Series to tensor
    y = torch.tensor(df["value"].values, dtype=torch.float32)

    # Probabilistic prediction (multiple sample paths)
    samples = PIPELINE.predict(y, prediction_length=horizon, num_samples=200)  # [1, N, H]
    samples = samples[0].numpy()  # [N, H]
    p10, p50, p90 = np.quantile(samples, [0.10, 0.50, 0.90], axis=0)
    std = samples.std(axis=0)  # standard deviation of the predictive distribution at each step

    # Future dates: continue from the last observation with the same step the
    # series was regularised to; the first generated date is that observation
    # itself, so it is dropped.
    step = _forecast_step(df["date"], freq)
    future_index = pd.date_range(df["date"].iloc[-1], periods=horizon + 1, freq=step)[1:]

    out = pd.DataFrame({
        "date": future_index,
        "p10": np.round(p10, 4),
        "p50": np.round(p50, 4),
        "p90": np.round(p90, 4),
        "std": np.round(std, 4),  # σ is part of the result table
    })

    # The outputs are returned as paths, so they must outlive this call.
    out_dir = tempfile.mkdtemp(prefix="fcst_")

    # Chart, saved as a PNG
    fig = _nice_plot(df, out, std)
    plot_path = os.path.join(out_dir, "forecast.png")
    try:
        fig.savefig(plot_path, dpi=150, bbox_inches="tight")
    finally:
        plt.close(fig)  # release the figure even if saving fails

    # Downloadable file
    sku_name = _safe_name(sku) if sku else "serie"
    tmp_path = os.path.join(out_dir, f"forecast_{sku_name}.csv")
    out.to_csv(tmp_path, index=False, encoding="utf-8")

    # Markdown summary (first step + mean σ over the horizon)
    md = (
        f"**Resumen (primer paso):** \n"
        f"- P10: **{out['p10'].iloc[0]:.2f}** \n"
        f"- P50: **{out['p50'].iloc[0]:.2f}** \n"
        f"- P90: **{out['p90'].iloc[0]:.2f}** \n"
        f"- σ (desv. estándar): **{out['std'].iloc[0]:.2f}** \n\n"
        f"**σ promedio en el horizonte:** **{out['std'].mean():.2f}**"
    )
    return out, plot_path, tmp_path, md
def forecast_from_text(csv_text_or_b64: str, sku: str, horizon: int = 12, freq: str = "MS"):
    """Run ``forecast_fn`` on a CSV supplied as plain text or base64.

    The payload (base64 without a ``data:`` prefix, or the CSV text itself) is
    written to a temporary ``.csv`` file that is handed to ``forecast_fn`` and
    removed afterwards.

    Args:
        csv_text_or_b64: CSV content, either verbatim or base64-encoded
            (line-wrapped base64 is accepted).
        sku: Forwarded to ``forecast_fn``.
        horizon: Forwarded to ``forecast_fn``.
        freq: Forwarded to ``forecast_fn``.

    Returns:
        Whatever ``forecast_fn`` returns: (table, image filepath,
        csv filepath, markdown), in that order.

    Raises:
        gr.Error: if the payload is empty.
    """
    txt = (csv_text_or_b64 or "").strip()
    if not txt:
        raise gr.Error("Falta el contenido del CSV.")

    # Does it look like base64? Drop whitespace first, then require the
    # base64 alphabet only. Anything that fails strict decoding is treated
    # as plain text.
    csv_text = txt
    compact = "".join(txt.split())
    if all(c.isascii() and (c.isalnum() or c in "+/=") for c in compact):
        try:
            csv_text = base64.b64decode(compact, validate=True).decode("utf-8", errors="replace")
        except ValueError:  # binascii.Error (bad padding/characters) is a ValueError
            csv_text = txt

    # Write a temporary .csv; UTF-8 so the bytes do not depend on the locale.
    with tempfile.NamedTemporaryFile(
        "w", suffix=".csv", delete=False, encoding="utf-8", newline=""
    ) as tmp:
        tmp.write(csv_text)

    try:
        # Object with a .name attribute, which is how forecast_fn reads the path.
        dummy_file = SimpleNamespace(name=tmp.name)
        return forecast_fn(dummy_file, sku, horizon, freq)
    finally:
        # forecast_fn has already read the CSV into memory by the time it
        # returns, and its outputs are separate files.
        os.remove(tmp.name)
def list_skus(file):
    """Fill the SKU dropdown and the preview table from the uploaded CSV.

    Returns a pair ``(dropdown_update, preview)``. Without a file the dropdown
    is emptied and the preview is ``None``. Otherwise the preview is the first
    ten rows and the dropdown offers the sorted, distinct, non-null ``sku``
    values as text, with the first one selected; it stays empty when the CSV
    has no ``sku`` column or no SKU values.
    """
    if file is None:
        return gr.update(choices=[], value=None), None

    table = pd.read_csv(file.name)
    options = []
    if "sku" in table.columns:
        options = sorted(table["sku"].dropna().astype(str).unique().tolist())
    selected = options[0] if options else None
    return gr.update(choices=options, value=selected), table.head(10)
# UI definition. The user-facing Spanish strings are restored from mojibake
# (UTF-8 text that had been decoded with the wrong code page).
# NOTE(review): the source reached review with its indentation stripped; the
# components placed inside gr.Row() below are a reconstruction — confirm the
# intended layout.
with gr.Blocks(title="Pronóstico por SKU (Chronos-T5)") as demo:
    gr.Markdown(
        "## Pronóstico de Demanda por **SKU** \n"
        "CSV con columnas: **sku (opcional)**, **date**, **value**. "
        "Selecciona el SKU y genera el forecast."
    )
    with gr.Row():
        file = gr.File(label="CSV (sku,date,value o date,value)", file_types=[".csv"])
        sku_dd = gr.Dropdown(
            choices=[], value=None,
            label="SKU (si el CSV tiene columna 'sku')",
            allow_custom_value=True
        )
        horizon = gr.Slider(1, 36, value=12, step=1, label="Horizonte (pasos)")
        freq = gr.Dropdown(choices=["", "D", "W", "MS", "M"], value="MS",
                           label="Frecuencia. ''=inferir, MS=mensual")

    preview = gr.Dataframe(label="Vista previa (primeras filas)")
    # Refresh the SKU choices and the preview whenever the uploaded file changes.
    file.change(list_skus, inputs=file, outputs=[sku_dd, preview])

    btn = gr.Button("Generar pronóstico", variant="primary")
    out_table = gr.Dataframe(label="Tabla de pronóstico")
    out_plot = gr.Image(type="filepath", label="Gráfica")  # no height (compat)
    download_file = gr.File(label="⬇️ Descargar pronóstico (CSV)", interactive=False)
    stats_md = gr.Markdown(label="Resumen")

    # API-only endpoint that receives the CSV as text/base64 (no file upload);
    # both the textbox and the button that triggers it are hidden.
    api_csv_text = gr.Textbox(visible=False)
    gr.Button(visible=False).click(
        fn=forecast_from_text,
        inputs=[api_csv_text, sku_dd, horizon, freq],  # SAME order the Netlify client will use
        outputs=[out_table, out_plot, download_file, stats_md],
        api_name="/forecast_text"
    )

    # Visible button: forecast from the uploaded file.
    btn.click(
        forecast_fn,
        inputs=[file, sku_dd, horizon, freq],
        outputs=[out_table, out_plot, download_file, stats_md],
        api_name="/forecast"
    )
# Start the server only when this file is executed as a script.
if __name__ == "__main__":
    # Enable the request queue, then serve on all interfaces at port 7860
    # with errors shown to the client.
    demo.queue().launch(server_name="0.0.0.0", server_port=7860, show_error=True)