# app.py
"""Gradio app serving a multi-label emotion classifier.

A transformer encoder (``model_name``) with a linear head is restored from the
fine-tuned ``state_dict`` in ``model_path``; each emotion in ``target_cols`` is
predicted independently by thresholding a sigmoid over the head's logits.
"""
import os

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer

# -------------------------------
# Config
# -------------------------------
model_name = "microsoft/deberta-v3-base"
model_path = "best_model.pt"  # Upload this to your HF Space folder
target_cols = ["anger", "fear", "joy", "sadness", "surprise"]
max_len = 256
pred_threshold = 0.5  # default sigmoid cut-off for a label to count as present
device = "cuda" if torch.cuda.is_available() else "cpu"

# -------------------------------
# Tokenizer & Dataset
# -------------------------------
tokenizer = AutoTokenizer.from_pretrained(model_name)


def encode_texts(texts):
    """Tokenize one text or a batch of texts for ``MultiLabelBERT``.

    Args:
        texts: A single string, or a list of strings.

    Returns:
        ``(input_ids, attention_mask)`` tensors of shape ``(batch, max_len)``
        placed on ``device``.
    """
    # Wrap only a bare string; a list is already a batch and wrapping it
    # again would hand the tokenizer a nested list.
    if isinstance(texts, str):
        texts = [texts]
    enc = tokenizer(
        list(texts),
        padding="max_length",  # every row is exactly max_len tokens
        truncation=True,
        max_length=max_len,
        return_tensors="pt",
    )
    return enc["input_ids"].to(device), enc["attention_mask"].to(device)


# -------------------------------
# Model Definition
# -------------------------------
class MultiLabelBERT(nn.Module):
    """Transformer encoder + dropout + linear head over the first token.

    ``forward`` returns one raw logit per class; apply a sigmoid to get
    independent per-label probabilities.
    """

    def __init__(self, model_name, num_classes):
        super().__init__()
        # Attribute names (bert / dropout / fc) determine the state_dict keys,
        # so they must stay in sync with the saved checkpoint.
        self.bert = AutoModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(0.3)
        self.fc = nn.Linear(self.bert.config.hidden_size, num_classes)

    def forward(self, input_ids, attention_mask):
        """Return logits of shape ``(batch, num_classes)``."""
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.last_hidden_state[:, 0, :]  # CLS token
        x = self.dropout(pooled_output)
        return self.fc(x)


# -------------------------------
# Load Model
# -------------------------------
model = MultiLabelBERT(model_name, len(target_cols)).to(device)
# Only tensors are needed for load_state_dict, so restrict unpickling to
# weights (weights_only=True, torch >= 1.13) instead of arbitrary objects.
model.load_state_dict(
    torch.load(model_path, map_location=device, weights_only=True)
)
model.eval()  # turns off dropout for inference


# -------------------------------
# Prediction Function
# -------------------------------
def predict_emotions(text, threshold=pred_threshold):
    """Predict which emotions are present in ``text``.

    Args:
        text: Input string; ``None`` is treated as an empty string.
        threshold: Sigmoid probability a label must exceed to be set to 1.

    Returns:
        Dict mapping every name in ``target_cols`` to ``0`` or ``1``.
    """
    if text is None:
        text = ""
    input_ids, attention_mask = encode_texts(text)
    with torch.no_grad():
        logits = model(input_ids, attention_mask)
    # Batch of one -> take row 0; sigmoid because labels are independent.
    probs = torch.sigmoid(logits).cpu().numpy()[0]
    preds = (probs > threshold).astype(int)
    return {col: int(pred) for col, pred in zip(target_cols, preds)}


# -------------------------------
# Gradio Interface
# -------------------------------
# NOTE: gr.Label displays the returned 0/1 values as confidences; return the
# sigmoid probabilities instead if graded scores are wanted in the UI.
iface = gr.Interface(
    fn=predict_emotions,
    inputs=gr.Textbox(lines=4, placeholder="Enter your text here..."),
    outputs=[gr.Label(num_top_classes=len(target_cols))],
    title="Multi-Label Emotion Classifier",
    description="Predicts emotions: " + ", ".join(target_cols),
)

if __name__ == "__main__":
    iface.launch()