# dlgenai / app.py
# Uploaded by ShreyaAgr ("Upload 2 files"), commit 3a27687 (verified)
# app.py
import gradio as gr
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel
import numpy as np
import os
# -------------------------------
# Config
# -------------------------------
# Hub id of the pretrained encoder; also used to load the matching tokenizer.
model_name = "microsoft/deberta-v3-base"
# Fine-tuned weights loaded with load_state_dict further down.
model_path = "best_model.pt"  # Upload this to your HF Space folder
# Label names; index i of the classifier output is reported as target_cols[i].
target_cols = ["anger", "fear", "joy", "sadness", "surprise"]
# Every input is padded/truncated to exactly this many tokens.
max_len = 256
# Prefer the GPU when one is visible to torch.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# -------------------------------
# Tokenizer
# -------------------------------
# Loaded from the same checkpoint id as the encoder so token ids match its
# vocabulary.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_name)
def encode_texts(texts):
    """Tokenize one text or a batch of texts into model-ready tensors.

    Args:
        texts: A single string, or a list/tuple of strings. Anything that is
            not a list/tuple is treated as one text (a batch of one).

    Returns:
        ``(input_ids, attention_mask)`` PyTorch tensors moved to ``device``,
        one row per input text, each padded/truncated to ``max_len`` tokens.
    """
    # Wrapping unconditionally turned a list input into a nested list, which
    # is not a batch of plain texts; only wrap non-sequence inputs.
    if isinstance(texts, (list, tuple)):
        batch = list(texts)
    else:
        batch = [texts]
    enc = tokenizer(
        batch,
        padding="max_length",
        truncation=True,
        max_length=max_len,
        return_tensors="pt",
    )
    return enc["input_ids"].to(device), enc["attention_mask"].to(device)
# -------------------------------
# Model Definition
# -------------------------------
class MultiLabelBERT(nn.Module):
    """Pretrained transformer encoder with a linear multi-label head.

    The head reads the first token's final hidden state and returns raw,
    unnormalized logits of shape (batch, num_classes); no activation is
    applied here.
    """

    def __init__(self, model_name, num_classes):
        super().__init__()
        # The attribute names bert / dropout / fc become the state_dict key
        # prefixes, so they must match the checkpoint passed to load_state_dict.
        self.bert = AutoModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(p=0.3)
        hidden_size = self.bert.config.hidden_size
        self.fc = nn.Linear(in_features=hidden_size, out_features=num_classes)

    def forward(self, input_ids, attention_mask):
        hidden_states = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
        ).last_hidden_state
        # Pool by taking position 0 of every sequence (the CLS token).
        first_token = hidden_states[:, 0]
        return self.fc(self.dropout(first_token))
# -------------------------------
# Load Model
# -------------------------------
# Build the architecture (pretrained backbone + fresh head), then overwrite all
# weights with the fine-tuned checkpoint. eval() turns off dropout so
# inference is deterministic.
model = MultiLabelBERT(model_name, num_classes=len(target_cols))
model = model.to(device)
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust
# (consider weights_only=True if the installed torch version supports it).
model.load_state_dict(
    torch.load(model_path, map_location=device),
    strict=True,
)
model.eval()
# -------------------------------
# Prediction Function
# -------------------------------
def predict_emotions(text, threshold=0.5):
    """Predict which emotions are present in ``text``.

    Args:
        text: Input string (supplied by the Gradio textbox). ``None`` is
            treated as an empty string instead of crashing the tokenizer.
        threshold: A label counts as present when its sigmoid probability is
            strictly greater than this value. Defaults to 0.5, the previously
            hard-coded cut-off.

    Returns:
        Dict mapping every name in ``target_cols`` to ``1`` (present) or
        ``0`` (absent).
    """
    if text is None:
        text = ""
    input_ids, attention_mask = encode_texts(text)
    with torch.no_grad():
        logits = model(input_ids, attention_mask)
    # Multi-label: an independent sigmoid per class, not a softmax across
    # classes. [0] selects the single row for this one input text.
    probs = torch.sigmoid(logits).cpu().numpy()[0]
    return {col: int(p > threshold) for col, p in zip(target_cols, probs)}
# -------------------------------
# Gradio Interface
# -------------------------------
# The number of displayed classes and the description are derived from
# target_cols so the UI stays in sync if the label set changes. With the
# current labels, the values are identical to the previous literals
# (5, "Predicts emotions: anger, fear, joy, sadness, surprise").
iface = gr.Interface(
    fn=predict_emotions,
    inputs=gr.Textbox(lines=4, placeholder="Enter your text here..."),
    outputs=[gr.Label(num_top_classes=len(target_cols))],
    title="Multi-Label Emotion Classifier",
    description="Predicts emotions: " + ", ".join(target_cols),
)
iface.launch()