Spaces:
Sleeping
Sleeping
File size: 6,097 Bytes
9bee602 a7a7bbb 9bee602 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 |
import os, io, base64, torch, logging
import pandas as pd
import matplotlib.pyplot as plt
from flask import Flask, request, jsonify
from sqlalchemy import create_engine, inspect
from model_loader import load_model
# -------------------------------------------------------
# Flask App Setup
# -------------------------------------------------------
app = Flask("ChatBot-Backend")

# -------------------------------------------------------
# Logging Configuration
# -------------------------------------------------------
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
logging.basicConfig(
    level=LOG_LEVEL,
    format="[%(asctime)s] [%(levelname)s] %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger("ChatBot")
logger.info("Starting ChatBot backend service...")

# -------------------------------------------------------
# Database Configuration
# -------------------------------------------------------
# SECURITY NOTE(review): committed credential/host defaults are a risk; in
# production these env vars should be required (fail fast when unset) rather
# than silently falling back to "root"/"root1234".
DB_USER = os.getenv("DB_USER", "root")
DB_PASSWORD = os.getenv("DB_PASSWORD", "root1234")
DB_HOST = os.getenv("DB_HOST", "database-1.chks4awear3o.eu-north-1.rds.amazonaws.com")
DB_PORT = os.getenv("DB_PORT", "3306")
DB_NAME = os.getenv("DB_NAME", "chatbot_db")

# -------------------------------------------------------
# Database Engine Setup
# -------------------------------------------------------
try:
    engine = create_engine(
        f"mysql+pymysql://{DB_USER}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/{DB_NAME}"
    )
    insp = inspect(engine)
    logger.info("Connected to MySQL successfully.")
except Exception as e:
    # Keep the app importable even without a DB; routes degrade gracefully.
    logger.error(f"Database connection failed: {e}")
    engine = None

# -------------------------------------------------------
# Model and Schema
# -------------------------------------------------------
# Populated lazily by init_model() before the first request is served.
tokenizer, model = None, None
schema_description = ""
def build_schema_description():
    """Build a plain-text description of every table and column in the DB.

    The result is stored in the module-level ``schema_description`` so it can
    be embedded into the SQL-generation prompt. On any failure a readable
    error marker is stored instead of raising, keeping request handling alive.
    """
    global schema_description
    if not engine:
        schema_description = "Database connection unavailable."
        return
    try:
        # Collect lines and join once instead of repeated string +=
        # (quadratic in the worst case, and harder to read).
        lines = []
        for table in insp.get_table_names():
            lines.append(f"Table: {table}")
            for col in insp.get_columns(table):
                lines.append(f"  - {col['name']} ({col['type']})")
            lines.append("")  # blank separator line between tables
        schema_description = "\n".join(lines) + ("\n" if lines else "")
        logger.info("Schema description built successfully.")
    except Exception as e:
        logger.error(f"Error while building schema: {e}")
        schema_description = f"Schema fetch error: {e}"
def generate_sql(question: str) -> str:
    """Generate a SQL query for *question* using the loaded language model.

    Builds a prompt containing the live database schema, runs greedy decoding,
    and trims the echoed prompt so only the query (from the first SELECT) is
    returned.

    Raises:
        RuntimeError: if the model has not been initialized yet.
    """
    if tokenizer is None or model is None:
        raise RuntimeError("Model not loaded yet.")
    logger.info("Generating SQL for: %s", question)
    prompt = (
        "You are a professional SQL generator.\n"
        "Convert the following question into a valid SQL query based on this schema:\n\n"
        f"{schema_description}\n"
        f"Question: {question}\n\nSQL:"
    )
    # Full prompt includes the whole schema — debug level to keep INFO logs sane.
    logger.debug("Generated prompt: %s", prompt)
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(model.device)
    with torch.no_grad():
        # do_sample=False means greedy decoding; a temperature argument would
        # be ignored (and triggers a warning in recent transformers), so the
        # original temperature=0.2 is dropped — output is unchanged.
        outputs = model.generate(**inputs, max_new_tokens=200, do_sample=False)
    sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # decode() echoes the prompt; keep only from the first SELECT onward.
    upper = sql.upper()
    if "SELECT" in upper:
        sql = sql[upper.find("SELECT"):]
    sql = sql.strip()
    logger.info("Generated SQL: %s", sql)
    return sql
# One-shot guard: @app.before_first_request was deprecated in Flask 2.2 and
# removed in Flask 2.3, so lazy init now runs via before_request with a flag.
_model_ready = False

@app.before_request
def init_model():
    """Load the model and build the schema once, before the first request.

    NOTE(review): not guarded by a lock — two concurrent first requests could
    both load the model. Harmless but wasteful; add a threading.Lock if the
    server runs multi-threaded. TODO confirm deployment model.
    """
    global tokenizer, model, _model_ready
    if _model_ready:
        return
    logger.info("Initializing model on first request...")
    tokenizer, model = load_model()
    model.eval()  # inference mode: disable dropout etc.
    build_schema_description()
    _model_ready = True
    logger.info("Model loaded and schema ready.")
# -------------------------------------------------------
# Routes
# -------------------------------------------------------
@app.route("/")
def home():
    """Health-check endpoint: confirms the backend is up."""
    payload = {"message": "Chatbot backend is running!"}
    return jsonify(payload)
@app.route("/api/ask", methods=["POST"])
def ask():
    """Answer a natural-language question by generating and executing SQL.

    Expects JSON ``{"question": "..."}``. Returns a JSON payload with an HTML
    table of results and, when the result has at least two columns, a
    base64-encoded PNG bar chart (``null`` otherwise).
    """
    try:
        data = request.get_json(force=True)
    except Exception as e:
        logger.error(f"Invalid JSON received: {e}")
        return jsonify({"error": "Invalid JSON payload"}), 400

    # Guard against a JSON body that is valid but not an object (e.g. null).
    question = (data or {}).get("question", "").strip()
    if not question:
        return jsonify({"error": "Empty question"}), 400
    logger.info(f"Received question: {question}")

    try:
        sql = generate_sql(question)
        # SECURITY NOTE(review): model-generated SQL is executed verbatim.
        # The DB account should be read-only at minimum; consider validating
        # that the statement is a single SELECT before running it.
        df = pd.read_sql(sql, engine)
        logger.info(f"SQL executed successfully, {len(df)} rows fetched.")
        if df.empty:
            return jsonify({"answer": "No relevant data found in the database."})

        html_table = df.to_html(index=False, classes="table table-striped")
        chart_base64 = _render_chart(df, question)
        return jsonify({
            "answer": f"Here's what I found:<br>{html_table}",
            "chart": chart_base64,
        })
    except Exception as e:
        logger.exception(f"Error processing request: {e}")
        return jsonify({"answer": f"Error: {str(e)}"})


def _render_chart(df, title):
    """Render a bar chart of the first two columns as a base64 PNG string.

    Returns None when the frame has fewer than two columns or plotting fails;
    chart failure never fails the request.
    """
    if len(df.columns) < 2:
        return None
    try:
        # df.plot creates its own figure, so size it via figsize= here.
        # (The original called plt.figure(figsize=...) first, which df.plot
        # ignored — the figsize never applied and the empty figure leaked
        # on every request.)
        ax = df.plot(x=df.columns[0], y=df.columns[1], kind="bar", figsize=(6, 4))
        fig = ax.get_figure()
        ax.set_title(title)
        fig.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, format="png")
        plt.close(fig)  # close the exact figure we made, not "current"
        buf.seek(0)
        encoded = base64.b64encode(buf.read()).decode("utf-8")
        logger.info("Chart generated successfully.")
        return encoded
    except Exception as plot_err:
        logger.warning(f"Chart generation failed: {plot_err}")
        return None
# -----------------------
# Run Flask app
# -----------------------
if __name__ == '__main__':
    # Debug defaults ON to preserve prior behavior; set FLASK_DEBUG=0 in
    # production — the Werkzeug debug server must never be internet-facing.
    app.run(debug=os.getenv("FLASK_DEBUG", "1") != "0")
|