import os
import re
from typing import Dict, List

from huggingface_hub import InferenceClient

# =========================================================
# HUGGING FACE INFERENCE CLIENT
# =========================================================
HF_API_TOKEN = os.getenv("HF_API_TOKEN")  # optional, set in HF Space secrets

if HF_API_TOKEN:
    client = InferenceClient(token=HF_API_TOKEN)
else:
    client = InferenceClient()  # anonymous for public models (rate-limited)

# Model IDs
TOX_MODEL_ID = "unitary/toxic-bert"
OFF_MODEL_ID = "cardiffnlp/twitter-roberta-base-offensive"
EMO_MODEL_ID = "j-hartmann/emotion-english-distilroberta-base"
SENT_MODEL_ID = "distilbert-base-uncased-finetuned-sst-2-english"
# =========================================================
# RULE KEYWORDS / PATTERNS
# =========================================================
AGGRESSION_KEYWORDS = [
    "stupid", "idiot", "dumb", "incompetent", "useless",
    "trash", "garbage", "worthless", "pathetic", "clown",
    "moron", "failure", "shut up", "hate you",
]

THREAT_PHRASES = [
    "you will regret", "there will be consequences", "watch your back",
    "this is your last warning", "i'm coming for you",
    "or else", "i'll ruin you", "i'll make you pay",
    "i am gonna hurt you", "i'm going to hurt you",
    "im gonna hurt you",  # covers the contracted form without an apostrophe
]

PROFANITY = [
    "fuck", "shit", "bitch", "asshole", "bastard",
    "motherfucker", "prick", "dickhead",
]

POLITE_KEYWORDS = [
    "please", "thank you", "thanks", "would you mind",
    "if possible", "kindly", "when you have a chance",
    "if you don't mind",
]

FRIENDLY_KEYWORDS = [
    "awesome", "amazing", "great job", "fantastic",
    "love this", "appreciate you", "good vibes",
    "wonderful", "you're the best", "you are the best",
]

SARCASM_PATTERNS = [
    r"yeah right",
    r"sure you did",
    r"great job (idiot|genius)",
    r"nice work (moron|buddy)",
    r"well done.*not",
    r"nice job.*not",
]

# Generic threat regex: "gonna/going to/will hurt you"
THREAT_REGEX = re.compile(r"\b(gonna|going to|will)\s+hurt you\b")
# =========================================================
# HF INFERENCE HELPERS
# =========================================================
def _safe_text_classification(model_id: str, text: str) -> List[Dict]:
    """
    Wrapper around HF Inference API text classification.

    Returns a list of dicts like:
        [
            {"label": "POSITIVE", "score": 0.95},
            ...
        ]
    or [] on error.
    """
    try:
        out = client.text_classification(text, model=model_id)
        # Some clients may return a single dict; normalize to a list
        if isinstance(out, dict):
            return [out]
        return out or []
    except Exception as e:
        print(f"[WARN] HF Inference error for {model_id}: {e}")
        return []
def _get_sentiment(text: str):
    """
    Returns (pos, neg) based on DistilBERT SST-2 sentiment;
    defaults to (0.5, 0.5) if the API call fails.
    """
    results = _safe_text_classification(SENT_MODEL_ID, text)
    pos = 0.5
    neg = 0.5
    if results:
        scores = {r["label"].upper(): float(r["score"]) for r in results}
        # Typical labels: POSITIVE / NEGATIVE
        if "POSITIVE" in scores:
            pos = scores["POSITIVE"]
            neg = 1.0 - pos
        elif "NEGATIVE" in scores:
            neg = scores["NEGATIVE"]
            pos = 1.0 - neg
    return pos, neg
def _get_toxicity(text: str) -> float:
    """
    Return a toxicity-like score in [0, 1].
    For unitary/toxic-bert, we consider any 'toxic-like' label as signal.
    """
    results = _safe_text_classification(TOX_MODEL_ID, text)
    if not results:
        return 0.0
    toxic_score = 0.0
    for r in results:
        label = r["label"].lower()
        if any(key in label for key in ["toxic", "obscene", "insult", "hate", "threat"]):
            toxic_score = max(toxic_score, float(r["score"]))
    return toxic_score
def _get_offensive(text: str) -> float:
    """
    Return an offensive score in [0, 1].
    For cardiffnlp/twitter-roberta-base-offensive, look for OFFENSE-like labels.
    """
    results = _safe_text_classification(OFF_MODEL_ID, text)
    if not results:
        return 0.0
    off_score = 0.0
    for r in results:
        label = r["label"].lower()
        if "offense" in label or "offensive" in label:
            off_score = max(off_score, float(r["score"]))
    return off_score
def _get_emotions(text: str):
    """
    Returns a dict like {"anger": 0.3, "joy": 0.6}.
    """
    results = _safe_text_classification(EMO_MODEL_ID, text)
    if not results:
        return {"anger": 0.0, "joy": 0.0}
    emo = {}
    for r in results:
        emo[r["label"].lower()] = float(r["score"])
    anger = emo.get("anger", 0.0)
    joy = emo.get("joy", 0.0)
    return {"anger": anger, "joy": joy}
# =========================================================
# MAIN CLASSIFIER (STRICT, RULE-FIRST)
# =========================================================
def classify_tone_rich(text: str):
    lowered = text.lower()
    explanation = []

    # --- Model signals ---
    pos, neg = _get_sentiment(text)
    tox_score = _get_toxicity(text)
    off_score = _get_offensive(text)
    emo = _get_emotions(text)
    anger = emo.get("anger", 0.0)
    joy = emo.get("joy", 0.0)

    explanation.append(f"Sentiment pos={pos:.2f}, neg={neg:.2f}")
    explanation.append(f"Toxicity={tox_score:.2f}, Offensive={off_score:.2f}")
    explanation.append(f"Emotion anger={anger:.2f}, joy={joy:.2f}")
    # --- Rule flags ---
    has_insult = any(w in lowered for w in AGGRESSION_KEYWORDS)

    # THREATS: phrase list OR generic regex
    has_threat_phrase = any(p in lowered for p in THREAT_PHRASES)
    has_threat_regex = bool(THREAT_REGEX.search(lowered))
    has_threat = has_threat_phrase or has_threat_regex

    has_profanity = any(bad in lowered for bad in PROFANITY)
    has_polite = any(w in lowered for w in POLITE_KEYWORDS)
    has_friendly = any(w in lowered for w in FRIENDLY_KEYWORDS)
    has_sarcasm = any(re.search(p, lowered) for p in SARCASM_PATTERNS)

    if has_insult:
        explanation.append("Detected explicit insult keyword.")
    if has_threat_phrase:
        explanation.append("Detected explicit threat phrase.")
    if has_threat_regex:
        explanation.append("Matched generic threat pattern (gonna/going to/will hurt you).")
    if has_profanity:
        explanation.append("Detected profanity.")
    if has_polite:
        explanation.append("Detected polite phrasing.")
    if has_friendly:
        explanation.append("Detected friendly / appreciative wording.")
    if has_sarcasm:
        explanation.append("Matched a sarcasm pattern.")
    # =====================================================
    # STRICT AGGRESSIVE RULES
    # =====================================================
    # 1) Threats override everything
    if has_threat:
        return {
            "label": "Aggressive",
            "confidence": 95,
            "severity": 95,
            "threat_score": 95,
            "politeness_score": 0,
            "friendly_score": 0,
            "has_threat": True,
            "has_profanity": has_profanity,
            "has_sarcasm": has_sarcasm,
            "explanation": explanation,
        }

    # 2) Profanity → aggressive
    if has_profanity:
        sev = max(85, int((tox_score + off_score) / 2 * 100))
        return {
            "label": "Aggressive",
            "confidence": 90,
            "severity": sev,
            "threat_score": int(tox_score * 100),
            "politeness_score": 0,
            "friendly_score": 0,
            "has_threat": has_threat,
            "has_profanity": True,
            "has_sarcasm": has_sarcasm,
            "explanation": explanation,
        }

    # 3) Direct insults → aggressive
    if has_insult:
        sev = max(80, int((tox_score + off_score) / 2 * 100))
        return {
            "label": "Aggressive",
            "confidence": 88,
            "severity": sev,
            "threat_score": int(tox_score * 100),
            "politeness_score": 0,
            "friendly_score": 0,
            "has_threat": has_threat,
            "has_profanity": has_profanity,
            "has_sarcasm": has_sarcasm,
            "explanation": explanation,
        }

    # 4) Sarcasm + negative sentiment → aggressive
    if has_sarcasm and neg > 0.55:
        return {
            "label": "Aggressive",
            "confidence": 85,
            "severity": 85,
            "threat_score": int(tox_score * 100),
            "politeness_score": 0,
            "friendly_score": 0,
            "has_threat": has_threat,
            "has_profanity": has_profanity,
            "has_sarcasm": True,
            "explanation": explanation,
        }

    # 5) High anger + toxicity → aggressive
    if anger + tox_score > 1.1:
        return {
            "label": "Aggressive",
            "confidence": 80,
            "severity": 80,
            "threat_score": int(tox_score * 100),
            "politeness_score": 0,
            "friendly_score": 0,
            "has_threat": has_threat,
            "has_profanity": has_profanity,
            "has_sarcasm": has_sarcasm,
            "explanation": explanation,
        }
    # =====================================================
    # POSITIVE LABELS – FRIENDLY / POLITE
    # =====================================================
    if has_friendly and pos > 0.60:
        return {
            "label": "Friendly",
            "confidence": int(pos * 100),
            "severity": 0,
            "threat_score": int(tox_score * 100),
            "politeness_score": int(pos * 100),
            "friendly_score": int(pos * 100),
            "has_threat": has_threat,
            "has_profanity": has_profanity,
            "has_sarcasm": has_sarcasm,
            "explanation": explanation,
        }

    if has_polite and pos > 0.50:
        return {
            "label": "Polite",
            "confidence": int(pos * 100),
            "severity": 0,
            "threat_score": int(tox_score * 100),
            "politeness_score": int(pos * 100),
            "friendly_score": 0,
            "has_threat": has_threat,
            "has_profanity": has_profanity,
            "has_sarcasm": has_sarcasm,
            "explanation": explanation,
        }
    # =====================================================
    # NEUTRAL FALLBACK
    # =====================================================
    return {
        "label": "Neutral",
        "confidence": int((1 - neg) * 100),
        "severity": 0,
        "threat_score": int(tox_score * 100),
        "politeness_score": int(pos * 100),
        "friendly_score": int(pos * 100),
        "has_threat": has_threat,
        "has_profanity": has_profanity,
        "has_sarcasm": has_sarcasm,
        "explanation": explanation,
    }
# Optional wrapper for backwards compatibility
def classify_tone(text: str):
    r = classify_tone_rich(text)
    aggressive_prob = r["severity"] / 100.0
    positive_prob = r["friendly_score"] / 100.0
    return r["label"], r["confidence"], aggressive_prob, positive_prob
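

# =========================================================
# OPTIONAL LOCAL SANITY CHECK (illustrative sketch only)
# =========================================================
# Not part of the Space app itself. The sample sentences below are
# made-up examples; running this block calls the HF Inference API,
# so results depend on model availability, tokens, and rate limits.
if __name__ == "__main__":
    samples = [
        "Thanks so much, would you mind taking a look when you have a chance?",
        "You're the best, this is awesome!",
        "im gonna hurt you",
        "Yeah right, sure you did a great job.",
        "The meeting is at 3pm.",
    ]
    for s in samples:
        rich = classify_tone_rich(s)
        label, confidence, aggressive_prob, positive_prob = classify_tone(s)
        print(
            f"{s!r} -> {label} "
            f"(confidence={confidence}, severity={rich['severity']}, "
            f"aggressive={aggressive_prob:.2f}, positive={positive_prob:.2f})"
        )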