# Uploaded by gauravsahu1990 via huggingface_hub ("Upload folder using huggingface_hub").
# Source commit: a7a7bbb (verified).
import base64
import io
import logging
import os
import threading
from urllib.parse import quote_plus

import torch
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from flask import Flask, request, jsonify
from sqlalchemy import create_engine, inspect

from model_loader import load_model
# -------------------------------------------------------
# 🧠 Flask App Setup
# -------------------------------------------------------
# A fixed label (rather than __name__) is passed as Flask's import_name.
app = Flask(import_name="ChatBot-Backend")
# -------------------------------------------------------
# 🧾 Logging Configuration
# -------------------------------------------------------
# Level name from the environment (default INFO), normalised to upper case.
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
# basicConfig raises ValueError on an unknown level name, which would abort
# the whole service at import time; getLevelName maps a known name to its
# int value, so a non-int result means the name is not recognised.
_log_level_valid = isinstance(logging.getLevelName(LOG_LEVEL), int)
logging.basicConfig(
    level=LOG_LEVEL if _log_level_valid else logging.INFO,
    format="[%(asctime)s] [%(levelname)s] %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger("ChatBot")
if not _log_level_valid:
    logger.warning("Unknown LOG_LEVEL %r; falling back to INFO.", LOG_LEVEL)
logger.info("🚀 Starting ChatBot backend service...")
# -------------------------------------------------------
# ⚙️ Database Configuration
# -------------------------------------------------------
# NOTE(security): the fallbacks below embed a default password and a specific
# RDS hostname in source control. They are kept only so existing deployments
# keep working; supply every DB_* value via the environment and rotate this
# password.
DB_USER = os.getenv("DB_USER", "root")
DB_PASSWORD = os.getenv("DB_PASSWORD", "root1234")
DB_HOST = os.getenv("DB_HOST", "database-1.chks4awear3o.eu-north-1.rds.amazonaws.com")
DB_PORT = os.getenv("DB_PORT", "3306")
DB_NAME = os.getenv("DB_NAME", "chatbot_db")
# Make the silent fallback to the hard-coded password visible in the logs.
if "DB_PASSWORD" not in os.environ:
    logger.warning("⚠️ DB_PASSWORD is not set; using the built-in default credentials.")
# -------------------------------------------------------
# 🧩 Database Engine Setup
# -------------------------------------------------------
# Defined up front so the name exists even when the connection attempt fails.
insp = None
try:
    # Credentials are URL-escaped: an '@', ':' or '/' in the user or password
    # would otherwise corrupt the connection URL.
    engine = create_engine(
        f"mysql+pymysql://{quote_plus(DB_USER)}:{quote_plus(DB_PASSWORD)}"
        f"@{DB_HOST}:{DB_PORT}/{DB_NAME}",
        # Test pooled connections before use, so connections dropped by the
        # server while idle are replaced instead of failing a request.
        pool_pre_ping=True,
    )
    insp = inspect(engine)
    logger.info("✅ Connected to MySQL successfully.")
except Exception as e:
    logger.error(f"❌ Database connection failed: {e}")
    engine = None
# -------------------------------------------------------
# 🧠 Model and Schema
# -------------------------------------------------------
# Filled in by init_model(); generate_sql() raises while either is still None.
tokenizer = None
model = None
# Text listing every table and its columns, embedded into the SQL prompt;
# filled in by build_schema_description().
schema_description = ""
def build_schema_description():
    """Rebuild the module-level ``schema_description`` from the database.

    For every table reported by ``insp`` a ``Table: <name>`` line is written,
    followed by one ``- <column> (<type>)`` line per column and a blank line.
    If there is no engine, or introspection raises, ``schema_description`` is
    set to a warning string instead; the exception is logged, not propagated.
    """
    global schema_description
    if engine is None:
        schema_description = "⚠️ Database connection unavailable."
        return
    try:
        # Build into a list and publish once: avoids quadratic string
        # concatenation and never exposes a half-built description.
        parts = []
        for table in insp.get_table_names():
            parts.append(f"Table: {table}\n")
            parts.extend(
                f" - {col['name']} ({col['type']})\n" for col in insp.get_columns(table)
            )
            parts.append("\n")
        schema_description = "".join(parts)
        logger.info("📘 Schema description built successfully.")
    except Exception as e:
        logger.error(f"⚠️ Error while building schema: {e}")
        schema_description = f"⚠️ Schema fetch error: {e}"
def generate_sql(question: str) -> str:
    """Translate a natural-language *question* into SQL using the loaded model.

    The prompt embeds the module-level ``schema_description``. Decoding is
    greedy and capped at 200 new tokens. If the generated text contains
    "SELECT" (case-insensitive), everything before its first occurrence is
    dropped; the result is whitespace-stripped.

    Raises:
        RuntimeError: if the tokenizer/model have not been loaded yet.
    """
    if tokenizer is None or model is None:
        raise RuntimeError("Model not loaded yet.")
    logger.info(f"🧩 Generating SQL for: {question}")
    prompt = (
        "You are a professional SQL generator.\n"
        "Convert the following question into a valid SQL query based on this schema:\n\n"
        f"{schema_description}\n"
        f"Question: {question}\n\nSQL:"
    )
    logger.info(f"🧠 Generated prompt: {prompt}")
    # NOTE(review): truncation=True trims to the model's max length; with the
    # tokenizer's default right-side truncation an oversized schema would cut
    # off the question itself — confirm the prompt fits.
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(model.device)
    with torch.no_grad():
        # Greedy decoding. `temperature` has no effect without sampling (it only
        # triggers a warning), so it is not passed.
        outputs = model.generate(**inputs, max_new_tokens=200, do_sample=False)
    generated = outputs[0]
    # Decoder-only models return prompt + continuation. If the prompt is kept,
    # a "select" in the question or in a schema column name (e.g. selected_at)
    # would be taken as the start of the generated SQL.
    if not getattr(getattr(model, "config", None), "is_encoder_decoder", False):
        generated = generated[inputs["input_ids"].shape[-1]:]
    sql = tokenizer.decode(generated, skip_special_tokens=True)
    select_at = sql.upper().find("SELECT")
    if select_at != -1:
        sql = sql[select_at:]
    sql = sql.strip()
    logger.info(f"🧠 Generated SQL: {sql}")
    return sql
# Guards the one-time initialisation below against concurrent first requests.
_init_lock = threading.Lock()
_initialized = False


def init_model():
    """Load the model and build the schema description once, before serving.

    ``app.before_first_request`` was removed in Flask 2.3, so this runs as a
    ``before_request`` hook with a lock plus flag that make the work happen only
    once. If loading raises, the flag stays unset and the next request retries.
    Returns None, so Flask continues to the requested view.
    """
    global tokenizer, model, _initialized
    if _initialized:
        return
    with _init_lock:
        # Re-check: another request may have finished initialising while we waited.
        if _initialized:
            return
        logger.info("🪄 Initializing model on first request...")
        loaded_tokenizer, loaded_model = load_model()
        loaded_model.eval()
        # Publish the globals only once the model is fully ready.
        tokenizer, model = loaded_tokenizer, loaded_model
        build_schema_description()
        _initialized = True
        logger.info("✅ Model loaded and schema ready.")


app.before_request(init_model)
# -------------------------------------------------------
# 🌐 Routes
# -------------------------------------------------------
@app.route("/")
def home():
    """Return a fixed JSON message confirming the backend is running."""
    return jsonify(message="Chatbot backend is running!")
# Leading keywords of statements that only read data. This is a coarse guard
# on model-generated SQL, not a parser; the real safeguard should be a
# read-only database account.
_READ_ONLY_SQL_PREFIXES = ("SELECT", "SHOW", "DESC")


def _is_read_only_sql(sql: str) -> bool:
    """Return True if *sql* begins with an allow-listed read-only keyword."""
    return sql.lstrip().upper().startswith(_READ_ONLY_SQL_PREFIXES)


def _render_bar_chart(df: pd.DataFrame, title: str) -> str:
    """Plot column 0 vs column 1 of *df* as a bar chart; return base64 PNG text."""
    # A standalone Figure holds no pyplot global state: concurrent requests
    # cannot draw into each other's figure and nothing is left open to leak.
    fig = Figure(figsize=(6, 4))
    ax = fig.add_subplot(111)
    df.plot(x=df.columns[0], y=df.columns[1], kind="bar", ax=ax)
    ax.set_title(title)
    fig.tight_layout()
    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    return base64.b64encode(buf.getvalue()).decode("utf-8")


@app.route("/api/ask", methods=["POST"])
def ask():
    """Answer a natural-language question by generating and running SQL.

    Expects a JSON object body with a non-empty string ``question``; a
    malformed payload yields HTTP 400 with ``{"error": ...}``. Otherwise the
    response is ``{"answer": <html table>, "chart": <base64 PNG or None>}``,
    ``{"answer": <no-data message>}`` when no rows match, or
    ``{"answer": "⚠️ Error: ..."}`` if generation or execution fails.
    """
    try:
        data = request.get_json(force=True)
    except Exception as e:
        logger.error(f"❌ Invalid JSON received: {e}")
        return jsonify({"error": "Invalid JSON payload"}), 400
    # Valid JSON that is not an object (list, string, number) has no .get().
    if not isinstance(data, dict):
        logger.error("❌ JSON payload is not an object.")
        return jsonify({"error": "Invalid JSON payload"}), 400
    question = data.get("question") or ""
    if not isinstance(question, str):
        return jsonify({"error": "Question must be a string"}), 400
    question = question.strip()
    if not question:
        return jsonify({"error": "Empty question"}), 400
    logger.info(f"🗨️ Received question: {question}")
    try:
        # Fail fast instead of running the model when there is nowhere to query.
        if engine is None:
            raise RuntimeError("Database connection unavailable.")
        sql = generate_sql(question)
        # The SQL derives from untrusted user text; never run writes or DDL.
        if not _is_read_only_sql(sql):
            raise ValueError(f"Refusing to execute non-read-only SQL: {sql}")
        df = pd.read_sql(sql, engine)
        logger.info(f"✅ SQL executed successfully, {len(df)} rows fetched.")
        if df.empty:
            return jsonify({"answer": "No relevant data found in the database."})
        html_table = df.to_html(index=False, classes="table table-striped")
        # The chart is best-effort: a plotting failure still returns the table.
        chart_base64 = None
        if len(df.columns) >= 2:
            try:
                chart_base64 = _render_bar_chart(df, question)
                logger.info("📈 Chart generated successfully.")
            except Exception as plot_err:
                logger.warning(f"⚠️ Chart generation failed: {plot_err}")
        return jsonify({
            "answer": f"Here’s what I found:<br>{html_table}",
            "chart": chart_base64,
        })
    except Exception as e:
        logger.exception(f"❌ Error processing request: {e}")
        return jsonify({"answer": f"⚠️ Error: {str(e)}"})
# -----------------------
# Run Flask app
# -----------------------
if __name__ == '__main__':
    # Debug mode enables the Werkzeug interactive debugger, which allows code
    # execution by anyone who can reach it. It stays on by default for
    # backward compatibility; set FLASK_DEBUG=0 (or false/no) to disable it.
    debug_enabled = os.getenv("FLASK_DEBUG", "1").lower() not in ("0", "false", "no")
    app.run(debug=debug_enabled)