Spaces:
Running
Running
| import os | |
| import re | |
| import csv | |
| import secrets | |
| import unicodedata | |
| from datetime import datetime, timedelta | |
| from io import BytesIO | |
| from flask import ( | |
| Flask, | |
| render_template, | |
| request, | |
| redirect, | |
| url_for, | |
| send_file, | |
| flash, | |
| session, | |
| abort, | |
| ) | |
| from flask_sqlalchemy import SQLAlchemy | |
| from flask_login import ( | |
| LoginManager, | |
| UserMixin, | |
| login_user, | |
| login_required, | |
| current_user, | |
| logout_user, | |
| ) | |
| from werkzeug.security import generate_password_hash, check_password_hash | |
| from reportlab.pdfgen import canvas | |
| from reportlab.lib.pagesizes import letter | |
| from sendgrid import SendGridAPIClient | |
| from sendgrid.helpers.mail import Mail | |
| # your HF-based classifier | |
| from model import classify_tone_rich | |
| # ========================================================= | |
| # APP CONFIG | |
| # ========================================================= | |
app = Flask(__name__)

# SECRET_KEY signs the session cookie, which carries the logged-in user id
# and the CSRF token. A hard-coded fallback is readable by anyone who can see
# this source and would let them forge sessions, so fall back to a random
# per-process key instead. Set SECRET_KEY explicitly in production so that
# sessions survive restarts and are shared across worker processes.
_secret_key = os.getenv("SECRET_KEY")
if not _secret_key:
    print("[WARN] SECRET_KEY not set. Using a random key; sessions will not survive restarts.")
    _secret_key = secrets.token_hex(32)
app.config["SECRET_KEY"] = _secret_key
app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///data.db"
app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False

# Detect if running on Hugging Face Spaces (they set SPACE_ID)
IS_HF = os.getenv("SPACE_ID") is not None

# cookie security
app.config["SESSION_COOKIE_HTTPONLY"] = True
if IS_HF:
    # Needed because the app runs inside an iframe on huggingface.co: the
    # cookie is then cross-site, which requires SameSite=None, and browsers
    # only accept SameSite=None together with Secure.
    app.config["SESSION_COOKIE_SAMESITE"] = "None"
    app.config["SESSION_COOKIE_SECURE"] = True
else:
    # Local dev / normal hosting
    app.config["SESSION_COOKIE_SAMESITE"] = "Lax"
    app.config["SESSION_COOKIE_SECURE"] = False

db = SQLAlchemy(app)
login_manager = LoginManager(app)
login_manager.login_view = "login"

# Email (SendGrid). Without an API key, send_email only prints messages.
SENDGRID_API_KEY = os.getenv("SENDGRID_API_KEY")
SENDER_EMAIL = os.getenv("SENDER_EMAIL", "[email protected]")

# Deliberately loose "local@domain.tld" shape check, not RFC 5322 validation.
EMAIL_RE = re.compile(r"^[^@\s]+@[^@\s]+\.[^@\s]+$")

os.makedirs("exports", exist_ok=True)
| # ========================================================= | |
# HELPER FUNCTIONS — SANITIZATION, CSRF, PASSWORDS
| # ========================================================= | |
def normalize_text(value: str) -> str:
    """Canonicalise user-supplied text.

    Applies Unicode NFKC normalisation, deletes the zero-width space,
    zero-width non-joiner and zero-width joiner (U+200B, U+200C, U+200D),
    then trims surrounding whitespace. Falsy input (``None`` or ``""``)
    yields ``""``.
    """
    if value:
        folded = unicodedata.normalize("NFKC", value)
        visible = folded.translate({0x200B: None, 0x200C: None, 0x200D: None})
        return visible.strip()
    return ""
def sanitize_string(value: str, max_len: int = 255) -> str:
    """Normalise a short form field with ``normalize_text`` and cap it at
    *max_len* characters (default 255)."""
    cleaned = normalize_text(value)
    return cleaned[:max_len] if len(cleaned) > max_len else cleaned
def sanitize_long_text(value: str, max_len: int = 4000) -> str:
    """Normalise a free-form text field with ``normalize_text`` and cap it
    at *max_len* characters (default 4000)."""
    cleaned = normalize_text(value)
    too_long = len(cleaned) > max_len
    return cleaned[:max_len] if too_long else cleaned
def is_valid_email(email: str) -> bool:
    """Return True if *email* has the loose shape ``local@domain.tld``.

    Uses ``EMAIL_RE.fullmatch`` rather than ``match``. With ``match``, the
    pattern's ``$`` anchor also matches just before a trailing newline, so
    an address followed by a newline would be accepted.
    Empty or ``None`` input returns False.
    """
    return bool(email) and EMAIL_RE.fullmatch(email) is not None
def is_strong_password(pw: str) -> bool:
    """Return True when *pw* is at least 8 characters long and contains at
    least one letter and at least one digit (Unicode-aware ``str.isalpha`` /
    ``str.isdigit``). Empty or ``None`` input returns False."""
    if not pw or len(pw) < 8:
        return False
    return any(map(str.isalpha, pw)) and any(map(str.isdigit, pw))
def generate_code() -> str:
    """Return a cryptographically random, zero-padded 6-digit numeric code.

    Used for both email verification and password reset.
    """
    return str(secrets.randbelow(10**6)).zfill(6)
def generate_csrf_token() -> str:
    """Return the session's CSRF token, creating and storing one if absent.

    New tokens are 32 hex characters from ``secrets.token_hex(16)``.
    """
    existing = session.get("csrf_token")
    if existing:
        return existing
    fresh = secrets.token_hex(16)
    session["csrf_token"] = fresh
    return fresh
def csrf_protect():
    """Reject POST requests whose form ``csrf_token`` does not match the session.

    Always ensures the session holds a token first, so GET pages can embed
    one. On POST, a missing or mismatching token aborts with HTTP 400.
    NOTE(review): presumably registered as a ``before_request`` hook; no
    registration is visible in this view. Confirm.
    """
    generate_csrf_token()
    if request.method != "POST":
        return None

    form_token = request.form.get("csrf_token", "")
    sess_token = session.get("csrf_token", "")
    # Constant-time comparison, so response timing does not leak how much of
    # a guessed token was correct. Compare bytes: compare_digest only accepts
    # ASCII-only str, and the form value is attacker-controlled.
    if not form_token or not secrets.compare_digest(
        form_token.encode("utf-8"), str(sess_token).encode("utf-8")
    ):
        abort(400, description="Invalid CSRF token")
    return None
def inject_csrf():
    """Expose the session's CSRF token to templates as ``csrf_token``.

    Returns an empty string when no token has been issued yet.
    NOTE(review): presumably registered as a template context processor;
    not visible in this view. Confirm.
    """
    token = session.get("csrf_token", "")
    return dict(csrf_token=token)
| # ========================================================= | |
| # MODELS | |
| # ========================================================= | |
class User(UserMixin, db.Model):
    """A registered account.

    Holds credentials plus the state used by the auth routes:
    - a failed-login counter and lockout deadline,
    - an email-verification code,
    - a password-reset code,
    - a last-activity timestamp.

    Datetimes are naive UTC values produced by ``datetime.utcnow``. The
    ``entries`` backref (this user's ``Entry`` rows) is declared on
    ``Entry.user``.
    """

    id = db.Column(db.Integer, primary_key=True)
    # Unique login identifier; the routes store and look it up lowercased.
    email = db.Column(db.String(255), unique=True, nullable=False)
    # Output of werkzeug's generate_password_hash(); never the plain password.
    password_hash = db.Column(db.String(255), nullable=False)
    created_at = db.Column(db.DateTime, default=datetime.utcnow)
    # login security
    # Consecutive wrong passwords; reset to 0 on a successful login and when a lock is applied.
    failed_logins = db.Column(db.Integer, default=0)
    # While this is in the future, login is refused (set 10 minutes ahead after 5 misses).
    lock_until = db.Column(db.DateTime, nullable=True)
    # email verification
    is_verified = db.Column(db.Boolean, default=False)
    # 6-digit code, stored in plain text; cleared once the email is verified.
    verification_code = db.Column(db.String(6), nullable=True)
    verification_expires = db.Column(db.DateTime, nullable=True)
    # password reset
    # 6-digit code, stored in plain text; cleared after a successful reset.
    reset_code = db.Column(db.String(6), nullable=True)
    reset_expires = db.Column(db.DateTime, nullable=True)
    # activity (for possible retention rules)
    # Updated on registration, successful password check at login, and verification.
    last_active_at = db.Column(db.DateTime, nullable=True)
class Entry(db.Model):
    """One classified email.

    Stores the submitted text plus the fields returned by
    ``classify_tone_rich`` at analysis time, and is owned by a ``User``.
    """

    id = db.Column(db.Integer, primary_key=True)
    # Naive UTC timestamp (datetime.utcnow).
    created_at = db.Column(db.DateTime, default=datetime.utcnow)
    # Email text as submitted; the index route normalises it and caps it at 4000 chars.
    text = db.Column(db.Text, nullable=False)
    # Tone label from the classifier; the history page filters it case-insensitively.
    label = db.Column(db.String(32))
    # Classifier confidence; exports print it with one decimal place.
    # NOTE(review): scale (0-1 vs percent) is defined by the model module, not visible here.
    confidence = db.Column(db.Float)
    # Integer scores from the classifier; their ranges are defined by the model module.
    severity = db.Column(db.Integer)
    threat_score = db.Column(db.Integer)
    politeness_score = db.Column(db.Integer)
    friendly_score = db.Column(db.Integer)
    # Boolean flags from the classifier.
    has_threat = db.Column(db.Boolean, default=False)
    has_profanity = db.Column(db.Boolean, default=False)
    has_sarcasm = db.Column(db.Boolean, default=False)
    user_id = db.Column(db.Integer, db.ForeignKey("user.id"), nullable=False)
    # Owning account; also adds ``User.entries`` as a backref.
    user = db.relationship("User", backref="entries")
def load_user(user_id):
    """Flask-Login user loader: map the id stored in the session to a User.

    Returns None (treated by Flask-Login as "not logged in") when the id is
    not an integer or the lookup fails for any reason.
    NOTE(review): presumably registered with ``@login_manager.user_loader``;
    no registration is visible in this view. Confirm.
    """
    try:
        pk = int(user_id)
        found = User.query.get(pk)
    except Exception:
        found = None
    return found
| # ========================================================= | |
| # EMAIL HELPERS | |
| # ========================================================= | |
def send_email(to_email: str, subject: str, html: str):
    """Send an HTML email through SendGrid on a best-effort basis.

    Without ``SENDGRID_API_KEY`` the message is printed to stdout instead.
    This is a dev fallback, and it prints verification/reset codes to the
    logs. Errors raised while sending are caught and printed, never
    re-raised. Errors while building the ``Mail`` object propagate.
    """
    if not SENDGRID_API_KEY:
        print("[WARN] SENDGRID_API_KEY not set. Skipping email send.")
        print(f"Subject: {subject}\nTo: {to_email}\n{html}")
        return

    outgoing = Mail(
        from_email=SENDER_EMAIL,
        to_emails=to_email,
        subject=subject,
        html_content=html,
    )
    try:
        client = SendGridAPIClient(SENDGRID_API_KEY)
        client.send(outgoing)
        print(f"[INFO] Sent email to {to_email}: {subject}")
    except Exception as e:
        print(f"[ERROR] Failed to send email to {to_email}: {e}")
def send_verification_email(to_email: str, code: str):
    """Email *to_email* its account verification *code*.

    The message text states the code expires in 15 minutes.
    """
    subject = "Verify your email"
    html = f"""
<p>Thanks for signing up for the AI Email Tone Classifier.</p>
<p>Your verification code is: <strong>{code}</strong></p>
<p>This code will expire in 15 minutes.</p>
"""
    send_email(to_email, subject, html)
def send_password_reset_email(to_email: str, code: str):
    """Email *to_email* a password-reset *code*.

    The message text states a 15-minute expiry and tells the recipient to
    ignore the email if they did not request it.
    """
    subject = "Password reset code"
    html = f"""
<p>You requested to reset your password for the AI Email Tone Classifier.</p>
<p>Your password reset code is: <strong>{code}</strong></p>
<p>This code will expire in 15 minutes.</p>
<p>If you did not request this, you can ignore this email.</p>
"""
    send_email(to_email, subject, html)
| # ========================================================= | |
| # AUTH ROUTES: REGISTER / LOGIN / LOGOUT / VERIFY | |
| # ========================================================= | |
def register():
    """Sign up a new account and start email verification.

    Already-authenticated users are redirected to the classifier. GET
    renders the shared login template in "register" mode. POST does the
    following:
    - validates the form (see ``_registration_form_error``),
    - refuses duplicate emails,
    - stores an unverified user with a 15-minute verification code,
    - emails that code and redirects to the verify page, remembering the
      address in the session as ``pending_email``.
    """
    if current_user.is_authenticated:
        return redirect(url_for("index"))
    if request.method != "POST":
        return render_template("login.html", mode="register", title="Register")

    form = request.form
    email = sanitize_string(form.get("email", ""), 255).lower()
    password = normalize_text(form.get("password", ""))
    consent = form.get("consent_privacy") == "on"

    # The duplicate check hits the database, so only run it once the
    # cheaper form checks have passed.
    error = _registration_form_error(email, password, consent)
    if error is None and User.query.filter_by(email=email).first():
        error = "An account with that email already exists."
    if error is not None:
        flash(error, "error")
        return redirect(url_for("register"))

    user = User(
        email=email,
        password_hash=generate_password_hash(password),
        last_active_at=datetime.utcnow(),
    )
    code = generate_code()
    user.verification_code = code
    user.verification_expires = datetime.utcnow() + timedelta(minutes=15)
    user.is_verified = False
    db.session.add(user)
    db.session.commit()

    send_verification_email(email, code)
    session["pending_email"] = email
    flash("Account created. Check your email for the verification code.", "success")
    return redirect(url_for("verify"))


def _registration_form_error(email, password, consent):
    """Return the first failing validation message for the sign-up form, or None."""
    if not email or not password:
        return "Email and password are required."
    if not is_valid_email(email):
        return "Please enter a valid email address."
    if not is_strong_password(password):
        return "Password must be at least 8 characters and contain letters and numbers."
    if not consent:
        return "You must agree to the Privacy Policy to create an account."
    return None
def login():
    """Authenticate with email and password, enforcing a lockout policy.

    Five consecutive wrong passwords lock the account for 10 minutes, and
    the counter restarts after each lock. A correct password clears the
    counter and lock and records activity. Unverified accounts are then
    sent to the verify page instead of being logged in.
    """
    if current_user.is_authenticated:
        return redirect(url_for("index"))
    if request.method != "POST":
        return render_template("login.html", mode="login", title="Login")

    email = sanitize_string(request.form.get("email", ""), 255).lower()
    password = normalize_text(request.form.get("password", ""))
    if not (email and password):
        flash("Email and password are required.", "error")
        return redirect(url_for("login"))

    user = User.query.filter_by(email=email).first()
    if user is None:
        flash("Invalid email or password.", "error")
        return redirect(url_for("login"))

    now = datetime.utcnow()
    still_locked = user.lock_until is not None and user.lock_until > now
    if still_locked:
        minutes_left = int((user.lock_until - now).total_seconds() // 60) + 1
        flash(f"Account locked due to too many failed attempts. Try again in ~{minutes_left} minutes.", "error")
        return redirect(url_for("login"))

    if not check_password_hash(user.password_hash, password):
        misses = (user.failed_logins or 0) + 1
        if misses >= 5:
            user.lock_until = now + timedelta(minutes=10)
            misses = 0
        user.failed_logins = misses
        db.session.commit()
        flash("Invalid email or password.", "error")
        return redirect(url_for("login"))

    # Correct password: clear lockout state and record activity.
    user.failed_logins = 0
    user.lock_until = None
    user.last_active_at = now
    db.session.commit()

    if not user.is_verified:
        session["pending_email"] = user.email
        flash("Please verify your email before logging in.", "error")
        return redirect(url_for("verify"))

    login_user(user)
    flash("Logged in successfully.", "success")
    return redirect(url_for("index"))
def logout():
    """End the current login session and return to the login page."""
    logout_user()
    flash("You have been logged out.", "success")
    login_page = url_for("login")
    return redirect(login_page)
def verify():
    """Show the email-verification form and process its submissions.

    The target address comes from ``?email=``, or failing that from the
    session's ``pending_email``. Missing or unknown addresses go back to
    registration, and already-verified accounts go to login. POST handling
    is in ``_handle_verify_post``. On GET the address is remembered in the
    session and the form is rendered.
    """
    raw_email = request.args.get("email", "") or session.get("pending_email", "")
    email = sanitize_string(raw_email, 255).lower()
    if not email:
        flash("No email specified for verification. Please register or log in again.", "error")
        return redirect(url_for("register"))

    user = User.query.filter_by(email=email).first()
    if not user:
        flash("Account not found. Please register again.", "error")
        return redirect(url_for("register"))
    if user.is_verified:
        flash("Your email is already verified. You can log in.", "success")
        return redirect(url_for("login"))

    if request.method == "POST":
        return _handle_verify_post(user)

    session["pending_email"] = email
    return render_template("verify.html", email=email, title="Verify Email")


def _handle_verify_post(user):
    """Process a POST to the verify page for *user*; always returns a redirect.

    ``action=resend`` stores and emails a fresh 15-minute code. Otherwise
    the submitted ``code`` must match the stored, unexpired code, and on a
    match the account is marked verified and the code is cleared.
    """

    def back_to_form():
        return redirect(url_for("verify", email=user.email))

    if request.form.get("action", "verify") == "resend":
        fresh_code = generate_code()
        user.verification_code = fresh_code
        user.verification_expires = datetime.utcnow() + timedelta(minutes=15)
        db.session.commit()
        send_verification_email(user.email, fresh_code)
        flash("A new verification code has been sent.", "success")
        return back_to_form()

    submitted = sanitize_string(request.form.get("code", ""), 6)
    if not submitted:
        flash("Please enter the verification code.", "error")
        return back_to_form()
    if not user.verification_code or not user.verification_expires:
        flash("No active verification code. Please resend.", "error")
        return back_to_form()
    if datetime.utcnow() > user.verification_expires:
        flash("Verification code expired. Please request a new one.", "error")
        return back_to_form()
    if submitted != user.verification_code:
        flash("Invalid verification code.", "error")
        return back_to_form()

    user.is_verified = True
    user.verification_code = None
    user.verification_expires = None
    user.last_active_at = datetime.utcnow()
    db.session.commit()
    flash("Email verified successfully. You can now log in.", "success")
    return redirect(url_for("login"))
| # ========================================================= | |
| # FORGOT PASSWORD + RESET | |
| # ========================================================= | |
def forgot_password():
    """Issue a password-reset code by email.

    GET renders the form. POST always flashes the same confirmation message
    and redirects to the reset page, whether or not the address is valid or
    registered, so the message does not reveal which accounts exist. For a
    registered address, a fresh 15-minute code is stored and emailed.
    """
    if request.method != "POST":
        return render_template("forgot.html", title="Forgot Password")

    notice = "If that email exists, a reset code has been sent. Check your email, then enter the code below."
    email = sanitize_string(request.form.get("email", ""), 255).lower()
    user = User.query.filter_by(email=email).first() if is_valid_email(email) else None
    if user:
        code = generate_code()
        user.reset_code = code
        user.reset_expires = datetime.utcnow() + timedelta(minutes=15)
        db.session.commit()
        send_password_reset_email(user.email, code)
    flash(notice, "success")
    return redirect(url_for("reset_password"))
def reset_password():
    """Render the reset form (GET) or apply a code-based password reset (POST).

    POST fields: ``email``, ``code``, ``password``, ``confirm_password``.
    The new password must match its confirmation and pass
    ``is_strong_password``. The emailed code must be unexpired and must
    match. Every bad-email/bad-code outcome uses the same message, so the
    form does not reveal which accounts exist.

    Brute-force guard: a 6-digit code has only 10**6 values, so unlimited
    guesses within the 15-minute window would allow account takeover.
    Wrong codes are counted, and after ``max_code_attempts`` misses the code
    is discarded and a new one must be requested. There is no dedicated
    column for this count, so the existing ``failed_logins`` counter is
    reused; adding a column would need a schema migration, because
    ``db.create_all`` does not alter existing tables.
    """
    max_code_attempts = 5

    if request.method != "POST":
        return render_template("reset_password.html", title="Reset Password")

    email = sanitize_string(request.form.get("email", ""), 255).lower()
    code_input = sanitize_string(request.form.get("code", ""), 6)
    new_pw = normalize_text(request.form.get("password", ""))
    confirm_pw = normalize_text(request.form.get("confirm_password", ""))

    if not email or not code_input or not new_pw or not confirm_pw:
        flash("All fields are required.", "error")
        return redirect(url_for("reset_password"))
    if new_pw != confirm_pw:
        flash("Passwords do not match.", "error")
        return redirect(url_for("reset_password"))
    if not is_strong_password(new_pw):
        flash("Password must be at least 8 characters and contain letters and numbers.", "error")
        return redirect(url_for("reset_password"))

    user = User.query.filter_by(email=email).first()
    if (
        not user
        or not user.reset_code
        or not user.reset_expires
        or datetime.utcnow() > user.reset_expires
    ):
        flash("Invalid or expired reset code.", "error")
        return redirect(url_for("reset_password"))

    # Constant-time comparison on bytes (compare_digest rejects non-ASCII str).
    if not secrets.compare_digest(
        code_input.encode("utf-8"), user.reset_code.encode("utf-8")
    ):
        misses = (user.failed_logins or 0) + 1
        if misses >= max_code_attempts:
            # Burn the code so it cannot be enumerated; a new one must be emailed.
            user.reset_code = None
            user.reset_expires = None
            misses = 0
        user.failed_logins = misses
        db.session.commit()
        flash("Invalid or expired reset code.", "error")
        return redirect(url_for("reset_password"))

    user.password_hash = generate_password_hash(new_pw)
    user.reset_code = None
    user.reset_expires = None
    # Clear any misses recorded against this (now consumed) code.
    user.failed_logins = 0
    db.session.commit()
    flash("Password reset successfully. You can now log in.", "success")
    return redirect(url_for("login"))
| # ========================================================= | |
| # MAIN CLASSIFIER | |
| # ========================================================= | |
def index():
    """Classify pasted email text and store the result in the user's history.

    Unverified users are redirected to the verify page. On POST the
    ``email_text`` field is normalised and capped by ``sanitize_long_text``.
    If anything remains, it is classified with ``classify_tone_rich`` and an
    ``Entry`` row is saved. The page is re-rendered with the text and result.
    NOTE(review): reads ``current_user.is_verified`` unconditionally, so this
    assumes the route is protected by ``login_required`` (not visible here).
    """
    if not current_user.is_verified:
        flash("Please verify your email to use the classifier.", "error")
        return redirect(url_for("verify", email=current_user.email))

    text = ""
    result = None  # dict from classify_tone_rich
    if request.method == "POST":
        text = sanitize_long_text(request.form.get("email_text", ""))

    if text:
        result = classify_tone_rich(text)
        score_fields = ("severity", "threat_score", "politeness_score", "friendly_score")
        flag_fields = ("has_threat", "has_profanity", "has_sarcasm")
        entry = Entry(
            text=text,
            label=result["label"],
            confidence=float(result["confidence"]),
            **{name: int(result[name]) for name in score_fields},
            **{name: bool(result[name]) for name in flag_fields},
            user_id=current_user.id,
        )
        db.session.add(entry)
        db.session.commit()

    return render_template(
        "index.html",
        title="Analyze Email",
        email_text=text,
        result=result,
    )
| # ========================================================= | |
| # HISTORY + EXPORTS | |
| # ========================================================= | |
def history_view():
    """List the current user's analysed emails, newest first.

    Query-string filters:
    - ``q``: case-insensitive substring of the email text.
    - ``label``: case-insensitive exact label.

    Percent, underscore and backslash characters in either value are escaped
    so they match literally instead of acting as SQL LIKE wildcards.
    Unverified users are redirected to the verify page.
    """
    if not current_user.is_verified:
        flash("Please verify your email to view history.", "error")
        return redirect(url_for("verify", email=current_user.email))

    q = sanitize_string(request.args.get("q", ""), 255).lower()
    filter_label = sanitize_string(request.args.get("label", ""), 32).lower()

    query = Entry.query.filter_by(user_id=current_user.id)
    if q:
        query = query.filter(Entry.text.ilike(f"%{escape_like_pattern(q)}%", escape="\\"))
    if filter_label:
        query = query.filter(Entry.label.ilike(escape_like_pattern(filter_label), escape="\\"))
    entries = query.order_by(Entry.created_at.desc()).all()

    return render_template(
        "history.html",
        title="History",
        history=entries,
        search=q,
        active_filter=filter_label,
    )


def escape_like_pattern(value: str, escape_char: str = "\\") -> str:
    """Escape LIKE metacharacters in *value* so it matches literally.

    Doubles *escape_char* first, then prefixes every percent sign and
    underscore with it. Pass the same *escape_char* as the ``escape=``
    argument of ``like``/``ilike``.
    """
    return (
        value.replace(escape_char, escape_char * 2)
        .replace("%", escape_char + "%")
        .replace("_", escape_char + "_")
    )
def export_csv():
    """Download the current user's history as a CSV attachment (oldest first).

    The CSV is built in memory rather than written under ``exports/``. No
    copy of the user's emails is left on disk, and concurrent exports
    cannot overwrite each other's file. Unverified users are redirected to
    the verify page.
    """
    if not current_user.is_verified:
        flash("Please verify your email to export data.", "error")
        return redirect(url_for("verify", email=current_user.email))

    entries = Entry.query.filter_by(user_id=current_user.id).order_by(Entry.created_at.asc())
    payload = BytesIO(history_csv_bytes(entries))
    return send_file(
        payload,
        as_attachment=True,
        download_name=f"history_{current_user.id}.csv",
        mimetype="text/csv",
    )


def history_csv_bytes(entries) -> bytes:
    """Serialise history *entries* to UTF-8 CSV bytes, header row first.

    Each entry must expose the ``Entry`` attributes read below. Rows end
    with the csv module's default CRLF terminator; confidence is printed
    with one decimal place and the boolean flags as 0/1.
    """
    from io import StringIO

    out = StringIO()
    writer = csv.writer(out)
    writer.writerow(
        [
            "Time UTC",
            "Label",
            "Confidence",
            "Severity",
            "ThreatScore",
            "PolitenessScore",
            "FriendlyScore",
            "HasThreat",
            "HasProfanity",
            "HasSarcasm",
            "Text",
        ]
    )
    for e in entries:
        writer.writerow(
            [
                e.created_at.isoformat(),
                e.label,
                f"{e.confidence:.1f}",
                e.severity,
                e.threat_score,
                e.politeness_score,
                e.friendly_score,
                int(e.has_threat),
                int(e.has_profanity),
                int(e.has_sarcasm),
                e.text,
            ]
        )
    return out.getvalue().encode("utf-8")
def export_pdf():
    """Download the current user's history as a PDF report, newest first.

    Each entry gets:
    - a bold header line (timestamp, label, severity),
    - a line of scores,
    - its text wrapped to about 90 characters (see ``wrap_pdf_text``).

    A new page is started whenever the next line would fall below the
    bottom margin, including in the middle of a long entry. The PDF is sent
    from memory rather than written under ``exports/``. Unverified users
    are redirected to the verify page.
    """
    if not current_user.is_verified:
        flash("Please verify your email to export data.", "error")
        return redirect(url_for("verify", email=current_user.email))

    buffer = BytesIO()
    c = canvas.Canvas(buffer, pagesize=letter)
    width, height = letter
    top_y = height - 60  # first baseline on continuation pages
    entry_min_y = 90  # start a new page before an entry header below this
    text_min_y = 40  # start a new page before a text line below this

    # Title bar on the first page.
    c.setFillColorRGB(0.12, 0.15, 0.20)
    c.rect(0, height - 60, width, 60, fill=1)
    c.setFillColorRGB(1, 1, 1)
    c.setFont("Helvetica-Bold", 18)
    c.drawString(40, height - 35, "Tone Classifier — History Report")

    entries = (
        Entry.query.filter_by(user_id=current_user.id)
        .order_by(Entry.created_at.desc())
        .all()
    )

    y = height - 80
    for e in entries:
        if y < entry_min_y:
            c.showPage()
            y = top_y
        c.setFont("Helvetica-Bold", 10)
        c.setFillColorRGB(0, 0, 0)
        c.drawString(40, y, f"{e.created_at.isoformat()} | {e.label} | Severity {e.severity}")
        y -= 12
        c.setFont("Helvetica", 9)
        c.drawString(40, y, f"Threat:{e.threat_score} Polite:{e.politeness_score} Friendly:{e.friendly_score}")
        y -= 12
        for line in wrap_pdf_text(e.text):
            if y < text_min_y:
                # showPage() resets the graphics state, so restore font and colour.
                c.showPage()
                y = top_y
                c.setFont("Helvetica", 9)
                c.setFillColorRGB(0, 0, 0)
            c.drawString(50, y, line)
            y -= 11
        y -= 9  # 20pt gap after an entry's last text line

    c.showPage()
    c.save()
    buffer.seek(0)
    return send_file(
        buffer,
        as_attachment=True,
        download_name=f"history_{current_user.id}.pdf",
        mimetype="application/pdf",
    )


def wrap_pdf_text(text: str, width: int = 90) -> list:
    """Split *text* into display lines of at most *width* characters.

    Each source line (``str.splitlines``) is wrapped on whitespace with
    ``textwrap.wrap``, and words longer than *width* are hard-split. Tabs
    and other whitespace become spaces, so no control characters reach the
    PDF. Blank source lines are kept as empty strings so paragraph breaks
    survive. Empty or ``None`` input returns ``[""]``.
    """
    import textwrap

    lines = []
    for paragraph in (text or "").splitlines() or [""]:
        lines.extend(textwrap.wrap(paragraph, width=width) or [""])
    return lines
def clear_history():
    """Delete every history entry owned by the current user.

    Unverified users are redirected to the verify page. Afterwards the user
    is sent back to the history view.
    """
    if not current_user.is_verified:
        flash("Please verify your email to clear history.", "error")
        return redirect(url_for("verify", email=current_user.email))

    owned_entries = Entry.query.filter_by(user_id=current_user.id)
    owned_entries.delete()
    db.session.commit()
    flash("History cleared.", "success")
    return redirect(url_for("history_view"))
| # ========================================================= | |
| # DELETE ACCOUNT + GDPR PAGES | |
| # ========================================================= | |
def delete_account():
    """Permanently delete the current user after re-checking their password.

    GET renders the confirmation page. POST verifies ``password`` and then:
    - deletes the user's entries and account row,
    - logs them out,
    - removes any per-user export files left under ``exports/``, so no copy
      of their history outlives the account.

    On a database error the transaction is rolled back and the user is sent
    back to this page.
    """
    if request.method != "POST":
        return render_template("delete_account.html", title="Delete Account")

    password = normalize_text(request.form.get("password", ""))
    if not check_password_hash(current_user.password_hash, password):
        flash("Incorrect password. Account not deleted.", "error")
        return redirect(url_for("delete_account"))

    try:
        uid = current_user.id
        Entry.query.filter_by(user_id=uid).delete()
        user = User.query.get(uid)
        logout_user()
        db.session.delete(user)
        db.session.commit()
    except Exception as e:
        db.session.rollback()
        flash("Error deleting account. Please try again.", "error")
        print(f"[ERROR] delete_account failed: {e}")
        return redirect(url_for("delete_account"))

    remove_user_exports(uid)
    flash("Your account and all data have been deleted.", "success")
    return redirect(url_for("register"))


def remove_user_exports(user_id):
    """Best-effort removal of ``exports/history_<user_id>.csv`` and ``.pdf``.

    Missing files are ignored. Other OS errors are printed, not raised,
    because the account itself is already gone by the time this runs.
    """
    for ext in ("csv", "pdf"):
        path = os.path.join("exports", f"history_{user_id}.{ext}")
        try:
            os.remove(path)
        except FileNotFoundError:
            pass
        except OSError as err:
            print(f"[WARN] Could not remove export file {path}: {err}")
def privacy():
    """Render the privacy policy page.

    The template receives the ``datetime`` class under the name ``datetime``.
    """
    context = {"title": "Privacy Policy", "datetime": datetime}
    return render_template("privacy.html", **context)
def do_not_sell():
    """Render the "Do Not Sell My Info" page."""
    page_title = "Do Not Sell My Info"
    return render_template("do_not_sell.html", title=page_title)
| # ========================================================= | |
| # INIT DB & RUN (LOCAL) | |
| # ========================================================= | |
# Create any missing tables at import time. create_all does not alter
# tables that already exist.
with app.app_context():
    db.create_all()

if __name__ == "__main__":
    # Werkzeug's interactive debugger can execute arbitrary Python from the
    # browser, so it must never be reachable on a public deployment; this
    # binds 0.0.0.0. Keep it off on Hugging Face Spaces. Locally it stays on
    # (the previous behaviour) unless FLASK_DEBUG=0 is set.
    use_debug = (not IS_HF) and os.getenv("FLASK_DEBUG", "1") != "0"
    app.run(debug=use_debug, host="0.0.0.0", port=7860)