|
|
""" |
|
|
Model architecture for Bengali Memes Classification using Multimodal Attention Fusion (MAF) |
|
|
With Political Lexicon Boosting Support |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from transformers import AutoModel, AutoTokenizer |
|
|
|
|
|
|
|
|
MODEL_NAME = "xlm-roberta-large" |
|
|
LEXICON_BOOST_STRENGTH = 1.5 |
|
|
LEXICON_MIN_MATCHES = 1 |
|
|
|
|
|
|
|
|
POLITICAL_LEXICON = { |
|
|
|
|
|
'বিজেপি', 'কংগ্রেস', 'নেতা', 'মোদী', 'সাধারণ মানুষ', 'তৃণমূল', 'সরকার', 'দুর্নীতি', |
|
|
'রাজনীতি', 'অনশন', 'ভোটের', 'ভোট', 'ছাত্রলীগে', 'ছাত্রলীগ', 'অবরোধ', 'আওয়ামিলীগ', |
|
|
'নেতাকর্মী', 'বঙ্গভবন', 'জয় শ্রী রাম', 'প্রধানমন্ত্রী', 'শেখ', 'হাসিনা', 'জননেত্রী', |
|
|
'ওবায়দুল', 'কাদের', 'মির্জা', 'ফখরুল', 'কর্মীদের', 'কর্মী', 'সাম্রাজ্যবাদ', 'বঙ্গবন্ধু', |
|
|
'মুজিবুর', 'মুজিব', 'দেশবাসী', 'মন্ত্রী', 'পাকিস্তান', 'ভারত', 'পদের', 'পদ', 'ফ্যাসিবাদী', |
|
|
'ফ্যাসিস্ট', 'এনসিপি', 'হাসনাত', 'সার্জিস', 'নাহিদ', 'আসিফ', 'মাক্সবাদী', 'মমতা', 'জাতীয়', |
|
|
'দল', 'প্রার্থীর', 'প্রার্থী', 'হেফাজত', 'স্বজন হারানোর বেদনা', 'জয় বাংলা', 'জয়', 'মুখ্যমন্ত্রী', |
|
|
'কমিউনিস্টরা', 'কমিউনিস্ট', 'দালাল', 'ছাত্রদল', 'শিবির', 'ছাত্রশিবির', 'জুলাই এর চেতনা', |
|
|
'খালেদা', 'জিয়া', 'তারেক', 'আমির', 'দেশদ্রোহী', 'আওয়ামী', 'বিএনপি', 'দলীয়', 'শিক্ষামন্ত্রী', |
|
|
'পলক', 'পুজিবাদী', 'রহমান', 'রাষ্ট্র', 'জনতা', 'শাসন', 'মিছিল', 'মুখপাত্র', 'হিটলার', 'ট্রাম্প', |
|
|
'সরকারি', 'উন্নয়ন', 'আন্দোলন', 'নির্বাচন', 'নির্বাচনের', 'ইলেকশন', 'জনগণ', 'পার্টির', 'পার্টি', |
|
|
'চুপ্পু', 'প্রধানমন্ত্রী', 'বিরোধী', 'ওয়েস্টিন', 'জাতীয়তাবাদ', 'শিক্ষামন্ত্রী', 'সরকার', |
|
|
'নির্বাচনী', 'প্রচারণা', 'প্রার্থী', 'ডাকসু', 'দেশনেতা', 'দশ পার্সেন্ট', 'ইন্টেরিম', 'ইউনুস', |
|
|
'অন্তর্বর্তকালীন', 'বাবর', 'শত্রুজ', 'ইনকিলাব', 'রাজাকার', 'জিন্দাবাদ', 'এরশাদ', 'জামাত', |
|
|
'পুঁজিবাদী', 'গদি', 'পদত্যাগ', 'কমিশন', 'কারচুপি', 'স্বৈরাচার', 'বামপন্থী', 'বাম', 'ডানপন্থী', |
|
|
'আপা', 'নমিনেশন', 'তত্ত্বাবধায়ক', |
|
|
|
|
|
'bjp', 'congress', 'sheikh', 'hasina', 'mujib', 'mujibur', 'rahman', 'tareq','tarek', 'zia', 'khaleda', |
|
|
'hasnat', 'sarjis', 'voter', 'vote', 'voting', 'nahid', 'asif', 'mirza', 'fakhrul', 'polok', |
|
|
'jamaat', 'ncp', 'awami', 'league', 'bnp', 'election', 'bcl', 'chatro', 'league', 'chatroleague', |
|
|
'shibir', 'chatrodol', 'dol', 'hefajot', 'trump', 'modi', 'govt', 'yunus', 'interim', 'bangabandhu', |
|
|
'fascist', 'minister', 'president', 'communist', 'commie', 'leftist', 'rightist', 'marxist', |
|
|
'hitler', 'putin', 'netanyahu', 'biden', 'july', 'joy', 'bangla', 'chuppu', 'politics', 'political', |
|
|
'mamdani', 'nomination', 'apa', 'zohran', 'amlig', 'Hasina' |
|
|
} |
|
|
POLITICAL_LEXICON = {word.lower() for word in POLITICAL_LEXICON} |
|
|
|
|
|
def contains_political_keywords(text, lexicon=POLITICAL_LEXICON, min_matches=LEXICON_MIN_MATCHES): |
|
|
"""Count political keyword matches in text""" |
|
|
import pandas as pd |
|
|
if pd.isna(text) or not isinstance(text, str): |
|
|
return 0 |
|
|
text_lower = text.lower() |
|
|
matches = sum(1 for keyword in lexicon if keyword in text_lower) |
|
|
return matches |
|
|
|
|
|
def apply_lexicon_boost(logits, lexicon_matches, boost_strength=LEXICON_BOOST_STRENGTH): |
|
|
"""Apply logit adjustment for political class""" |
|
|
boost_mask = (lexicon_matches > 0).float().unsqueeze(1) |
|
|
adjustment = torch.zeros_like(logits) |
|
|
adjustment[:, 1] = boost_strength * lexicon_matches.float() |
|
|
adjusted_logits = logits + (adjustment * boost_mask) |
|
|
return adjusted_logits |
|
|
|
|
|
|
|
|
class MultiheadAttention(nn.Module): |
|
|
"""Multi-head attention mechanism for cross-modal fusion""" |
|
|
|
|
|
def __init__(self, d_model, nhead, dropout=0.15): |
|
|
super(MultiheadAttention, self).__init__() |
|
|
self.attention = nn.MultiheadAttention(d_model, nhead, dropout=dropout) |
|
|
|
|
|
def forward(self, query, key, value, mask=None): |
|
|
output, _ = self.attention(query, key, value, attn_mask=mask) |
|
|
return output |
|
|
|
|
|
|
|
|
class MAF(nn.Module): |
|
|
"""Multimodal Attention Fusion (MAF) Model""" |
|
|
|
|
|
def __init__(self, clip_model, num_classes, num_heads, use_lexicon_boost=True): |
|
|
super(MAF, self).__init__() |
|
|
|
|
|
|
|
|
self.clip = clip_model |
|
|
self.visual_linear = nn.Linear(512, 1024) |
|
|
|
|
|
|
|
|
self.tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) |
|
|
self.bert = AutoModel.from_pretrained(MODEL_NAME) |
|
|
|
|
|
|
|
|
for param in self.bert.parameters(): |
|
|
param.requires_grad = True |
|
|
|
|
|
print(f"✓ XLM-RoBERTa loaded with all layers trainable") |
|
|
|
|
|
|
|
|
self.attention = MultiheadAttention(d_model=1024, nhead=num_heads) |
|
|
|
|
|
|
|
|
self.fc = nn.Sequential( |
|
|
nn.Linear(1024 + 1024 + 1024, 128), |
|
|
nn.ReLU(), |
|
|
nn.Dropout(0.2), |
|
|
nn.Linear(128, num_classes), |
|
|
) |
|
|
|
|
|
|
|
|
self.use_lexicon_boost = use_lexicon_boost |
|
|
if use_lexicon_boost: |
|
|
print(f"✓ Lexicon boosting ENABLED (strength: {LEXICON_BOOST_STRENGTH})") |
|
|
|
|
|
def forward(self, image_input, input_ids, attention_mask, lexicon_matches=None): |
|
|
|
|
|
image_features = self.clip(image_input) |
|
|
image_features = self.visual_linear(image_features) |
|
|
image_features = image_features.unsqueeze(1) |
|
|
|
|
|
image_features = F.adaptive_avg_pool1d(image_features.permute(0, 2, 1), 70).permute(0, 2, 1) |
|
|
|
|
|
|
|
|
bert_outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask) |
|
|
bert_output = bert_outputs.last_hidden_state |
|
|
|
|
|
|
|
|
attention_output = self.attention( |
|
|
query=image_features.permute(1, 0, 2), |
|
|
key=bert_output.permute(1, 0, 2), |
|
|
value=image_features.permute(1, 0, 2), |
|
|
mask=None |
|
|
) |
|
|
|
|
|
|
|
|
attention_output = attention_output.permute(1, 0, 2) |
|
|
|
|
|
|
|
|
fusion_input = torch.cat([attention_output, image_features, bert_output], dim=2) |
|
|
|
|
|
|
|
|
output = self.fc(fusion_input.mean(1)) |
|
|
|
|
|
if self.use_lexicon_boost and lexicon_matches is not None: |
|
|
output = apply_lexicon_boost(output, lexicon_matches) |
|
|
|
|
|
return output |
|
|
|