import os, io, base64, torch, logging
import pandas as pd
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so charts render safely in server threads
import matplotlib.pyplot as plt
from flask import Flask, request, jsonify
from sqlalchemy import create_engine, inspect
from model_loader import load_model

# -------------------------------------------------------
# 🧠 Flask App Setup
# -------------------------------------------------------
app = Flask("ChatBot-Backend")

# -------------------------------------------------------
# 🧾 Logging Configuration
# -------------------------------------------------------
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
logging.basicConfig(
    level=LOG_LEVEL,
    format="[%(asctime)s] [%(levelname)s] %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger("ChatBot")
logger.info("🚀 Starting ChatBot backend service...")

# -------------------------------------------------------
# ⚙️ Database Configuration
# -------------------------------------------------------
# Defaults are development fallbacks; override them via environment variables.
DB_USER = os.getenv("DB_USER", "root")
DB_PASSWORD = os.getenv("DB_PASSWORD", "root1234")
DB_HOST = os.getenv("DB_HOST", "database-1.chks4awear3o.eu-north-1.rds.amazonaws.com")
DB_PORT = os.getenv("DB_PORT", "3306")
DB_NAME = os.getenv("DB_NAME", "chatbot_db")

# -------------------------------------------------------
# 🧩 Database Engine Setup
# -------------------------------------------------------
try:
    engine = create_engine(
        f"mysql+pymysql://{DB_USER}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/{DB_NAME}"
    )
    insp = inspect(engine)
    logger.info("✅ Connected to MySQL successfully.")
except Exception as e:
    logger.error(f"❌ Database connection failed: {e}")
    engine = None

# -------------------------------------------------------
# 🧠 Model and Schema
# -------------------------------------------------------
tokenizer, model = None, None
schema_description = ""


def build_schema_description():
    """Builds schema text dynamically from MySQL tables."""
    global schema_description
    if not engine:
        schema_description = "⚠️ Database connection unavailable."
        return
    try:
        schema_description = ""
        for table in insp.get_table_names():
            schema_description += f"Table: {table}\n"
            for col in insp.get_columns(table):
                schema_description += f" - {col['name']} ({col['type']})\n"
            schema_description += "\n"
        logger.info("📘 Schema description built successfully.")
    except Exception as e:
        logger.error(f"⚠️ Error while building schema: {e}")
        schema_description = f"⚠️ Schema fetch error: {e}"


def generate_sql(question: str) -> str:
    """Generates a SQL query from the user question using the model."""
    if tokenizer is None or model is None:
        raise RuntimeError("Model not loaded yet.")

    logger.info(f"🧩 Generating SQL for: {question}")
    prompt = (
        "You are a professional SQL generator.\n"
        "Convert the following question into a valid SQL query based on this schema:\n\n"
        f"{schema_description}\n"
        f"Question: {question}\n\nSQL:"
    )
    logger.info(f"🧠 Generated prompt: {prompt}")

    inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(model.device)
    with torch.no_grad():
        # Note: temperature is ignored while do_sample=False (greedy decoding).
        outputs = model.generate(**inputs, max_new_tokens=200, temperature=0.2, do_sample=False)
    sql = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Keep only the generated query, starting at the first SELECT.
    if "SELECT" in sql.upper():
        sql = sql[sql.upper().find("SELECT"):]
    sql = sql.strip()
    logger.info(f"🧠 Generated SQL: {sql}")
    return sql


@app.before_first_request
def init_model():
    """Loads the model and builds the schema once before the first API call."""
    # Note: before_first_request was removed in Flask 2.3; on newer Flask versions,
    # call init_model() explicitly at startup (e.g. inside app.app_context()) instead.
    global tokenizer, model
    logger.info("🪄 Initializing model on first request...")
    tokenizer, model = load_model()
    model.eval()
    build_schema_description()
    logger.info("✅ Model loaded and schema ready.")


# -------------------------------------------------------
# 🌐 Routes
# -------------------------------------------------------
@app.route("/")
def home():
    return jsonify({"message": "Chatbot backend is running!"})


@app.route("/api/ask", methods=["POST"])
def ask():
    """Main API endpoint for answering user queries."""
    try:
        data = request.get_json(force=True)
    except Exception as e:
        logger.error(f"❌ Invalid JSON received: {e}")
        return jsonify({"error": "Invalid JSON payload"}), 400

    question = (data or {}).get("question", "").strip()
    if not question:
        return jsonify({"error": "Empty question"}), 400

    logger.info(f"🗨️ Received question: {question}")

    try:
        # The model-generated SQL is executed verbatim; keep the DB user read-only.
        sql = generate_sql(question)
        df = pd.read_sql(sql, engine)
        logger.info(f"✅ SQL executed successfully, {len(df)} rows fetched.")

        if df.empty:
            return jsonify({"answer": "No relevant data found in the database."})

        html_table = df.to_html(index=False, classes="table table-striped")

        # Plot graph from the first two columns, if available.
        chart_base64 = None
        try:
            if len(df.columns) >= 2:
                fig, ax = plt.subplots(figsize=(6, 4))
                df.plot(x=df.columns[0], y=df.columns[1], kind="bar", ax=ax)
                ax.set_title(question)
                fig.tight_layout()
                buf = io.BytesIO()
                fig.savefig(buf, format="png")
                buf.seek(0)
                chart_base64 = base64.b64encode(buf.read()).decode("utf-8")
                plt.close(fig)
                logger.info("📈 Chart generated successfully.")
        except Exception as plot_err:
            logger.warning(f"⚠️ Chart generation failed: {plot_err}")

        return jsonify({
            "answer": f"Here’s what I found:\n{html_table}",
            "chart": chart_base64
        })

    except Exception as e:
        logger.exception(f"❌ Error processing request: {e}")
        return jsonify({"answer": f"⚠️ Error: {str(e)}"})


# -----------------------
# Run Flask app
# -----------------------
if __name__ == '__main__':
    app.run(debug=True)
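
# -------------------------------------------------------
# Example request (illustrative sketch, not executed by this module)
# -------------------------------------------------------
# Once the server is running (Flask's development default is
# http://127.0.0.1:5000), the /api/ask endpoint expects a JSON body with a
# single "question" field; the sample question below is hypothetical.
#
#   curl -X POST http://127.0.0.1:5000/api/ask \
#        -H "Content-Type: application/json" \
#        -d '{"question": "How many rows are in each table?"}'
#
# Equivalent Python client call:
#
#   import requests
#   resp = requests.post(
#       "http://127.0.0.1:5000/api/ask",
#       json={"question": "How many rows are in each table?"},
#   )
#   print(resp.json().get("answer"))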