gauravsahu1990's picture
Upload folder using huggingface_hub
a7a7bbb verified
raw
history blame
6.1 kB
import base64
import io
import logging
import os
import re
import threading
from urllib.parse import quote_plus

import matplotlib.pyplot as plt
import pandas as pd
import torch
from flask import Flask, request, jsonify
from sqlalchemy import create_engine, inspect

from model_loader import load_model
# -------------------------------------------------------
# 🧠 Flask App Setup
# -------------------------------------------------------
# Application object; the route decorators further down in this module
# register their handlers on it.
app = Flask(import_name="ChatBot-Backend")
# -------------------------------------------------------
# 🧾 Logging Configuration
# -------------------------------------------------------
def resolve_log_level(raw_level, default="INFO"):
    """Normalise a user-supplied logging level name.

    Args:
        raw_level: Level name as typed by the operator. Case and surrounding
            whitespace are ignored; ``None`` is treated as empty.
        default: Level name returned when ``raw_level`` is not a level name
            known to the ``logging`` module.

    Returns:
        The stripped, upper-cased level name when it is recognised,
        otherwise ``default``.
    """
    candidate = (raw_level or "").strip().upper()
    # getLevelName() maps a registered level *name* to its integer value and
    # returns a "Level <x>" string for anything it does not know.
    if isinstance(logging.getLevelName(candidate), int):
        return candidate
    return default


# An unknown LOG_LEVEL used to make basicConfig() raise ValueError at import
# time and take the whole service down; validate it and fall back instead.
_REQUESTED_LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO")
LOG_LEVEL = resolve_log_level(_REQUESTED_LOG_LEVEL)
logging.basicConfig(
    level=LOG_LEVEL,
    format="[%(asctime)s] [%(levelname)s] %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger("ChatBot")
if _REQUESTED_LOG_LEVEL.strip().upper() != LOG_LEVEL:
    # Make the fallback visible so a typo in the environment is not silent.
    logger.warning(
        "Ignoring unknown LOG_LEVEL=%r; using %s instead.",
        _REQUESTED_LOG_LEVEL,
        LOG_LEVEL,
    )
logger.info("πŸš€ Starting ChatBot backend service...")
# -------------------------------------------------------
# ⚙️ Database Configuration
# -------------------------------------------------------
# MySQL connection settings. Each value comes from the environment variable of
# the same name and falls back to the literal shown when that variable is unset.
# NOTE(review): the fallback password and host are committed in source —
# consider requiring these variables to be set explicitly.
DB_USER = os.environ.get("DB_USER", "root")
DB_PASSWORD = os.environ.get("DB_PASSWORD", "root1234")
DB_HOST = os.environ.get(
    "DB_HOST",
    "database-1.chks4awear3o.eu-north-1.rds.amazonaws.com",
)
DB_PORT = os.environ.get("DB_PORT", "3306")
DB_NAME = os.environ.get("DB_NAME", "chatbot_db")
# -------------------------------------------------------
# 🧩 Database Engine Setup
# -------------------------------------------------------
# Both names are always bound after this block: `engine`/`insp` hold live
# objects on success and are None when the connection attempt fails.
insp = None
try:
    # Percent-encode the credentials: characters such as '@', ':' or '/' in a
    # user name or password would otherwise be parsed as URL structure.
    # NOTE(review): a password that was already percent-encoded by hand in the
    # environment would now be encoded twice — store the raw value instead.
    engine = create_engine(
        f"mysql+pymysql://{quote_plus(DB_USER)}:{quote_plus(DB_PASSWORD)}"
        f"@{DB_HOST}:{DB_PORT}/{DB_NAME}"
    )
    insp = inspect(engine)
    logger.info("βœ… Connected to MySQL successfully.")
except Exception as e:
    logger.error(f"❌ Database connection failed: {e}")
    engine = None
    insp = None
# -------------------------------------------------------
# 🧠 Model and Schema
# -------------------------------------------------------
# Assigned by init_model(); while either is still None, generate_sql() raises
# RuntimeError instead of running inference.
tokenizer = None
model = None
# Text rendering of the database schema that generate_sql() embeds in its
# prompt; (re)written by build_schema_description().
schema_description = ""
def build_schema_description():
    """Refresh the module-level ``schema_description`` from the database.

    For every table reported by the inspector, one ``Table: <name>`` line is
    emitted, followed by one `` - <column> (<type>)`` line per column and a
    blank separator line. When no engine is available, or reflection raises,
    ``schema_description`` is set to a warning message instead. The function
    returns nothing; its only output is the global it rewrites.
    """
    global schema_description
    if not engine:
        schema_description = "⚠️ Database connection unavailable."
        return
    try:
        pieces = []
        for table_name in insp.get_table_names():
            pieces.append(f"Table: {table_name}\n")
            pieces.extend(
                f" - {column['name']} ({column['type']})\n"
                for column in insp.get_columns(table_name)
            )
            # Blank line between consecutive tables.
            pieces.append("\n")
        schema_description = "".join(pieces)
        logger.info("πŸ“˜ Schema description built successfully.")
    except Exception as exc:
        logger.error(f"⚠️ Error while building schema: {exc}")
        schema_description = f"⚠️ Schema fetch error: {exc}"
_SELECT_KEYWORD = re.compile(r"select", re.IGNORECASE)


def extract_select_statement(text):
    """Trim ``text`` so that it starts at its first ``SELECT`` keyword.

    The keyword is located case-insensitively on the original string. The
    previous approach searched ``text.upper()`` and reused that index on
    ``text``; upper-casing can change the string length (``"ß"`` becomes
    ``"SS"``), which shifted the cut into the middle of the keyword.

    Args:
        text: Raw decoded model output.

    Returns:
        The text from the first ``SELECT`` onward, stripped of surrounding
        whitespace; the stripped text unchanged when no ``SELECT`` occurs.
    """
    match = _SELECT_KEYWORD.search(text)
    if match:
        text = text[match.start():]
    return text.strip()


def generate_sql(question: str) -> str:
    """Generate a SQL query for ``question`` using the loaded model.

    The prompt combines a fixed instruction, the current module-level
    ``schema_description`` and the question. The model output is decoded and
    trimmed to start at the first ``SELECT`` keyword, if there is one.

    Args:
        question: Natural-language question from the user.

    Returns:
        The generated SQL text, stripped of surrounding whitespace.

    Raises:
        RuntimeError: If the tokenizer or model has not been loaded yet.
    """
    if tokenizer is None or model is None:
        raise RuntimeError("Model not loaded yet.")
    logger.info(f"🧩 Generating SQL for: {question}")
    prompt = (
        "You are a professional SQL generator.\n"
        "Convert the following question into a valid SQL query based on this schema:\n\n"
        f"{schema_description}\n"
        f"Question: {question}\n\nSQL:"
    )
    logger.info(f"🧠 Generated prompt: {prompt}")
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(model.device)
    with torch.no_grad():
        # Greedy decoding (do_sample=False). `temperature` only affects
        # sampling, so the former temperature=0.2 was contradictory and had no
        # effect on the output — presumably a Hugging Face `generate`; confirm
        # against model_loader.
        outputs = model.generate(**inputs, max_new_tokens=200, do_sample=False)
    decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
    sql = extract_select_statement(decoded)
    logger.info(f"🧠 Generated SQL: {sql}")
    return sql
# One-time initialisation state. `before_first_request` was deprecated in
# Flask 2.2 and removed in Flask 2.3, so the same behaviour is implemented
# with a guarded `before_request` hook, which exists in every Flask version.
_model_init_lock = threading.Lock()
_model_init_done = False


@app.before_request
def init_model():
    """Load the model and build the schema text once, before the first request.

    Runs ahead of every request but returns immediately once initialisation
    has completed. The lock ensures concurrent first requests trigger a
    single load. If loading raises, the completion flag stays unset and the
    next request tries again. Always returns ``None`` so Flask continues with
    normal request handling.
    """
    global tokenizer, model, _model_init_done
    if _model_init_done:
        return
    with _model_init_lock:
        # Re-check under the lock: another thread may have finished meanwhile.
        if _model_init_done:
            return
        logger.info("πŸͺ„ Initializing model on first request...")
        loaded_tokenizer, loaded_model = load_model()
        loaded_model.eval()
        tokenizer, model = loaded_tokenizer, loaded_model
        build_schema_description()
        _model_init_done = True
        logger.info("βœ… Model loaded and schema ready.")
# -------------------------------------------------------
# 🌐 Routes
# -------------------------------------------------------
@app.route("/")
def home():
    """Return a small JSON message confirming the backend is reachable."""
    status_payload = {"message": "Chatbot backend is running!"}
    return jsonify(status_payload)
# Statement prefixes accepted for execution. This is a best-effort guard for
# model-generated SQL, not a substitute for a read-only database account.
_READ_ONLY_PREFIXES = ("SELECT", "SHOW", "DESCRIBE", "DESC")


def _render_chart(df, title):
    """Render the first two columns of ``df`` as a bar chart.

    Args:
        df: Result frame with at least two columns; the first is used for
            the x axis and the second for the bar heights.
        title: Text placed above the chart.

    Returns:
        The PNG image encoded as a base64 ASCII string.
    """
    # Draw on an explicitly created axes. Without `ax=`, DataFrame.plot opens
    # its own figure, which left the figure created beforehand empty and
    # never closed, and meant the requested size was not applied.
    fig, ax = plt.subplots(figsize=(6, 4))
    try:
        df.plot(x=df.columns[0], y=df.columns[1], kind="bar", ax=ax)
        ax.set_title(title)
        fig.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, format="png")
        return base64.b64encode(buf.getvalue()).decode("utf-8")
    finally:
        # Release the figure even when plotting fails.
        plt.close(fig)


@app.route("/api/ask", methods=["POST"])
def ask():
    """Answer a natural-language question with data from the database.

    Expects a JSON object with a string ``question`` field. The question is
    turned into SQL by ``generate_sql``, the query is executed, and the rows
    are returned as an HTML table together with an optional bar chart.

    Returns:
        * ``400`` with ``{"error": ...}`` when the body is not a JSON object,
          ``question`` is not a string, or the question is empty.
        * ``200`` with ``{"answer": ..., "chart": ...}`` on success, where
          ``chart`` is a base64 PNG or ``None``.
        * ``200`` with ``{"answer": ...}`` when the query returns no rows or
          when processing fails (the message then carries the error text).
    """
    try:
        data = request.get_json(force=True)
    except Exception as e:
        logger.error(f"❌ Invalid JSON received: {e}")
        return jsonify({"error": "Invalid JSON payload"}), 400
    # Well-formed JSON can still be a list, string or null, and `question`
    # can be a number or null; both used to surface as an unhandled
    # AttributeError instead of a client error.
    if not isinstance(data, dict) or not isinstance(data.get("question", ""), str):
        return jsonify({"error": "Invalid JSON payload"}), 400
    question = data.get("question", "").strip()
    if not question:
        return jsonify({"error": "Empty question"}), 400
    logger.info(f"πŸ—¨οΈ Received question: {question}")
    try:
        # Fail before running inference when there is nothing to query.
        if engine is None:
            raise RuntimeError("Database connection unavailable.")
        sql = generate_sql(question)
        # The SQL is produced by a model from untrusted text; refuse anything
        # that does not start like a read-only statement.
        if not sql.lstrip().upper().startswith(_READ_ONLY_PREFIXES):
            raise ValueError("Only read-only queries can be executed.")
        df = pd.read_sql(sql, engine)
        logger.info(f"βœ… SQL executed successfully, {len(df)} rows fetched.")
        if df.empty:
            return jsonify({"answer": "No relevant data found in the database."})
        html_table = df.to_html(index=False, classes="table table-striped")
        # The chart is optional: a failure is logged and the table is still returned.
        chart_base64 = None
        if len(df.columns) >= 2:
            try:
                chart_base64 = _render_chart(df, question)
                logger.info("πŸ“ˆ Chart generated successfully.")
            except Exception as plot_err:
                logger.warning(f"⚠️ Chart generation failed: {plot_err}")
        return jsonify({
            # User-facing text: the apostrophe was previously mis-encoded.
            "answer": f"Here’s what I found:<br>{html_table}",
            "chart": chart_base64
        })
    except Exception as e:
        logger.exception(f"❌ Error processing request: {e}")
        return jsonify({"answer": f"⚠️ Error: {str(e)}"})
# -----------------------
# Run Flask app
# -----------------------
if __name__ == '__main__':
    # Debug mode enables Werkzeug's interactive debugger, which lets anyone
    # who can reach the server execute code. Keep it off unless explicitly
    # requested through FLASK_DEBUG (1/true/yes/on).
    debug_enabled = os.getenv("FLASK_DEBUG", "").strip().lower() in {"1", "true", "yes", "on"}
    app.run(debug=debug_enabled)