import os, io, base64, torch, logging
import pandas as pd
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so chart rendering works without a display
import matplotlib.pyplot as plt
from flask import Flask, request, jsonify
from sqlalchemy import create_engine, inspect
from model_loader import load_model
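# model_loader is a project-local module (not shown here). Based on how it is used
# below, load_model() is assumed to return a (tokenizer, model) pair, e.g. a Hugging
# Face tokenizer plus a text-generation model exposing .generate() and .device.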
# -------------------------------------------------------
# 🔧 Flask App Setup
# -------------------------------------------------------
app = Flask("ChatBot-Backend")
# -------------------------------------------------------
# 🧾 Logging Configuration
# -------------------------------------------------------
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
logging.basicConfig(
    level=LOG_LEVEL,
    format="[%(asctime)s] [%(levelname)s] %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger("ChatBot")
logger.info("🚀 Starting ChatBot backend service...")
# -------------------------------------------------------
# ⚙️ Database Configuration
# -------------------------------------------------------
DB_USER = os.getenv("DB_USER", "root")
DB_PASSWORD = os.getenv("DB_PASSWORD", "root1234")
DB_HOST = os.getenv("DB_HOST", "database-1.chks4awear3o.eu-north-1.rds.amazonaws.com")
DB_PORT = os.getenv("DB_PORT", "3306")
DB_NAME = os.getenv("DB_NAME", "chatbot_db")
# -------------------------------------------------------
# 🧩 Database Engine Setup
# -------------------------------------------------------
try:
    engine = create_engine(f"mysql+pymysql://{DB_USER}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/{DB_NAME}")
    # create_engine() is lazy, so open a short-lived connection to verify the credentials now
    with engine.connect():
        pass
    insp = inspect(engine)
    logger.info("✅ Connected to MySQL successfully.")
except Exception as e:
    logger.error(f"❌ Database connection failed: {e}")
    engine = None
# -------------------------------------------------------
# 🧠 Model and Schema
# -------------------------------------------------------
tokenizer, model = None, None
schema_description = ""
def build_schema_description():
    """Builds schema text dynamically from MySQL tables."""
    global schema_description
    if not engine:
        schema_description = "⚠️ Database connection unavailable."
        return
    try:
        schema_description = ""
        for table in insp.get_table_names():
            schema_description += f"Table: {table}\n"
            for col in insp.get_columns(table):
                schema_description += f"  - {col['name']} ({col['type']})\n"
            schema_description += "\n"
        logger.info("📘 Schema description built successfully.")
    except Exception as e:
        logger.error(f"⚠️ Error while building schema: {e}")
        schema_description = f"⚠️ Schema fetch error: {e}"
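# For illustration, the generated schema_description takes this shape
# (table and column names here are hypothetical):
#   Table: customers
#     - id (INTEGER)
#     - name (VARCHAR(255))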
def generate_sql(question: str) -> str:
    """Generates SQL query from user question using the model."""
    if tokenizer is None or model is None:
        raise RuntimeError("Model not loaded yet.")
    logger.info(f"🧩 Generating SQL for: {question}")
    prompt = (
        "You are a professional SQL generator.\n"
        "Convert the following question into a valid SQL query based on this schema:\n\n"
        f"{schema_description}\n"
        f"Question: {question}\n\nSQL:"
    )
    logger.info(f"🧠 Generated prompt: {prompt}")
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(model.device)
    with torch.no_grad():
        # Greedy decoding; temperature is omitted because it has no effect when do_sample=False
        outputs = model.generate(**inputs, max_new_tokens=200, do_sample=False)
    sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Keep only the text from the first SELECT onward, dropping any echoed prompt
    if "SELECT" in sql.upper():
        sql = sql[sql.upper().find("SELECT"):]
    sql = sql.strip()
    logger.info(f"🧠 Generated SQL: {sql}")
    return sql
def init_model():
    """Loads the model and builds schema once before first API call."""
    global tokenizer, model
    logger.info("🪄 Initializing model on first request...")
    tokenizer, model = load_model()
    model.eval()
    build_schema_description()
    logger.info("✅ Model loaded and schema ready.")
# -------------------------------------------------------
# 🌐 Routes
# -------------------------------------------------------
@app.route("/")  # assumed root path for the health-check endpoint
def home():
    return jsonify({"message": "Chatbot backend is running!"})
@app.route("/ask", methods=["POST"])  # assumed endpoint path; adjust to match the frontend
def ask():
    """Main API endpoint for answering user queries."""
    # Lazily load the model and build the schema on the first API call
    if model is None:
        init_model()
    try:
        data = request.get_json(force=True)
    except Exception as e:
        logger.error(f"❌ Invalid JSON received: {e}")
        return jsonify({"error": "Invalid JSON payload"}), 400
    question = data.get("question", "").strip()
    if not question:
        return jsonify({"error": "Empty question"}), 400
    logger.info(f"🗨️ Received question: {question}")
    try:
        sql = generate_sql(question)
        df = pd.read_sql(sql, engine)
        logger.info(f"✅ SQL executed successfully, {len(df)} rows fetched.")
        if df.empty:
            return jsonify({"answer": "No relevant data found in the database."})
        html_table = df.to_html(index=False, classes="table table-striped")
        # Plot graph
        chart_base64 = None
        try:
            if len(df.columns) >= 2:
                # Let pandas create the figure so the intended 6x4 size is actually applied
                ax = df.plot(x=df.columns[0], y=df.columns[1], kind="bar", figsize=(6, 4))
                ax.set_title(question)
                plt.tight_layout()
                buf = io.BytesIO()
                plt.savefig(buf, format="png")
                buf.seek(0)
                chart_base64 = base64.b64encode(buf.read()).decode("utf-8")
                plt.close()
                logger.info("📊 Chart generated successfully.")
        except Exception as plot_err:
            logger.warning(f"⚠️ Chart generation failed: {plot_err}")
        return jsonify({
            "answer": f"Here's what I found:<br>{html_table}",
            "chart": chart_base64
        })
    except Exception as e:
        logger.exception(f"❌ Error processing request: {e}")
        return jsonify({"answer": f"⚠️ Error: {str(e)}"})
# -----------------------
# Run Flask app
# -----------------------
if __name__ == '__main__':
    app.run(debug=True)
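# Example client call (illustrative; assumes the /ask route registered above and the
# default Flask port 5000):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:5000/ask",
#       json={"question": "How many rows does each table hold?"},
#   )
#   payload = resp.json()
#   # payload["answer"] contains an HTML table; payload["chart"] is an optional
#   # base64-encoded PNG bar chart.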