|
|
|
|
|
import gradio as gr
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
from transformers import AutoTokenizer, AutoModel
|
|
|
import numpy as np
|
|
|
import os
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Pretrained Hugging Face backbone wrapped by MultiLabelBERT.
model_name = "microsoft/deberta-v3-base"
# Fine-tuned weights, loaded into the model as a state_dict.
model_path = "best_model.pt"
# Output labels, in the same order as the classifier head's outputs.
target_cols = ["anger", "fear", "joy", "sadness", "surprise"]
# Token budget per input: longer texts are truncated, shorter ones padded.
max_len = 256
# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Tokenizer matching the pretrained backbone named by ``model_name``.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_name)
|
|
|
|
|
|
def encode_texts(texts):
    """Tokenize one text or a batch of texts into fixed-length model inputs.

    Args:
        texts: A single string, or a list/tuple of strings encoded as a batch.

    Returns:
        ``(input_ids, attention_mask)`` tensors, each padded/truncated to
        ``max_len`` tokens and moved to ``device``.
    """
    # Only a lone string is wrapped into a batch of one. Wrapping a list again
    # would hand the tokenizer ``[[a, b]]``, which it reads as a single
    # (a, b) sentence *pair* rather than two separate texts.
    batch = [texts] if isinstance(texts, str) else list(texts)
    enc = tokenizer(
        batch,
        padding="max_length",
        truncation=True,
        max_length=max_len,
        return_tensors="pt",
    )
    return enc["input_ids"].to(device), enc["attention_mask"].to(device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MultiLabelBERT(nn.Module):
    """Pretrained transformer encoder topped with a dropout + linear head.

    The hidden state at the first token position is used as the sequence
    summary and projected to one raw logit per class. No activation is
    applied here.

    Args:
        model_name: Identifier passed to ``AutoModel.from_pretrained``.
        num_classes: Number of output logits (one per label).
    """

    def __init__(self, model_name, num_classes):
        super().__init__()
        # Attribute names and registration order are kept as-is so that saved
        # state_dict keys ("bert.*", "fc.*") keep matching.
        self.bert = AutoModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(0.3)
        hidden_size = self.bert.config.hidden_size
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, input_ids, attention_mask):
        """Return per-class logits for a batch of token ids and attention masks."""
        hidden_states = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
        ).last_hidden_state
        # Take the first token of every sequence as its summary vector.
        first_token = hidden_states[:, 0, :]
        return self.fc(self.dropout(first_token))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Build the network on the target device, then overwrite its weights with the
# fine-tuned checkpoint and switch to inference mode (disables dropout).
model = MultiLabelBERT(model_name, num_classes=len(target_cols))
model = model.to(device)
# NOTE(review): torch.load unpickles the file. Only load trusted checkpoints,
# or pass weights_only=True on torch versions that support it.
# The loaded dict is passed straight through, so it is freed after the copy.
model.load_state_dict(
    torch.load(model_path, map_location=device)
)
model.eval()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def predict_emotions(text, threshold=0.5):
    """Flag which emotions in ``target_cols`` are present in ``text``.

    Args:
        text: Input string to classify.
        threshold: Probability an emotion must strictly exceed to count as
            present. Defaults to 0.5, the previously hard-coded cut-off.

    Returns:
        Dict mapping each name in ``target_cols`` to 1 (present) or 0 (absent).
    """
    input_ids, attention_mask = encode_texts(text)
    with torch.no_grad():
        logits = model(input_ids, attention_mask)
    # Multi-label: an independent sigmoid per class, not a softmax over classes.
    # [0] selects the single example of the batch built by encode_texts.
    probs = torch.sigmoid(logits).cpu().numpy()[0]
    return {col: int(p > threshold) for col, p in zip(target_cols, probs)}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Web UI: a text box in, one Label entry per emotion out.
iface = gr.Interface(
    fn=predict_emotions,
    inputs=gr.Textbox(lines=4, placeholder="Enter your text here..."),
    # Show every label rather than a hard-coded count that could drift.
    outputs=[gr.Label(num_top_classes=len(target_cols))],
    title="Multi-Label Emotion Classifier",
    # Derived from target_cols so the UI text always matches the model labels.
    description="Predicts emotions: " + ", ".join(target_cols),
)

# Start the server only when run as a script, not when the module is imported.
if __name__ == "__main__":
    iface.launch()
|
|
|
|