bengali-political-maf-v6 / model_architecture.py
lucius-40's picture
Upload model_architecture.py with huggingface_hub
c8860b0 verified
"""
Model architecture for Bengali Memes Classification using Multimodal Attention Fusion (MAF)
With Political Lexicon Boosting Support
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer
# Model configuration (must match training)
MODEL_NAME = "xlm-roberta-large"
LEXICON_BOOST_STRENGTH = 1.5
LEXICON_MIN_MATCHES = 1
# Political Lexicon (same as training)
POLITICAL_LEXICON = {
# Bengali political terms
'বিজেপি', 'কংগ্রেস', 'নেতা', 'মোদী', 'সাধারণ মানুষ', 'তৃণমূল', 'সরকার', 'দুর্নীতি',
'রাজনীতি', 'অনশন', 'ভোটের', 'ভোট', 'ছাত্রলীগে', 'ছাত্রলীগ', 'অবরোধ', 'আওয়ামিলীগ',
'নেতাকর্মী', 'বঙ্গভবন', 'জয় শ্রী রাম', 'প্রধানমন্ত্রী', 'শেখ', 'হাসিনা', 'জননেত্রী',
'ওবায়দুল', 'কাদের', 'মির্জা', 'ফখরুল', 'কর্মীদের', 'কর্মী', 'সাম্রাজ্যবাদ', 'বঙ্গবন্ধু',
'মুজিবুর', 'মুজিব', 'দেশবাসী', 'মন্ত্রী', 'পাকিস্তান', 'ভারত', 'পদের', 'পদ', 'ফ্যাসিবাদী',
'ফ্যাসিস্ট', 'এনসিপি', 'হাসনাত', 'সার্জিস', 'নাহিদ', 'আসিফ', 'মাক্সবাদী', 'মমতা', 'জাতীয়',
'দল', 'প্রার্থীর', 'প্রার্থী', 'হেফাজত', 'স্বজন হারানোর বেদনা', 'জয় বাংলা', 'জয়', 'মুখ্যমন্ত্রী',
'কমিউনিস্টরা', 'কমিউনিস্ট', 'দালাল', 'ছাত্রদল', 'শিবির', 'ছাত্রশিবির', 'জুলাই এর চেতনা',
'খালেদা', 'জিয়া', 'তারেক', 'আমির', 'দেশদ্রোহী', 'আওয়ামী', 'বিএনপি', 'দলীয়', 'শিক্ষামন্ত্রী',
'পলক', 'পুজিবাদী', 'রহমান', 'রাষ্ট্র', 'জনতা', 'শাসন', 'মিছিল', 'মুখপাত্র', 'হিটলার', 'ট্রাম্প',
'সরকারি', 'উন্নয়ন', 'আন্দোলন', 'নির্বাচন', 'নির্বাচনের', 'ইলেকশন', 'জনগণ', 'পার্টির', 'পার্টি',
'চুপ্পু', 'প্রধানমন্ত্রী', 'বিরোধী', 'ওয়েস্টিন', 'জাতীয়তাবাদ', 'শিক্ষামন্ত্রী', 'সরকার',
'নির্বাচনী', 'প্রচারণা', 'প্রার্থী', 'ডাকসু', 'দেশনেতা', 'দশ পার্সেন্ট', 'ইন্টেরিম', 'ইউনুস',
'অন্তর্বর্তকালীন', 'বাবর', 'শত্রুজ', 'ইনকিলাব', 'রাজাকার', 'জিন্দাবাদ', 'এরশাদ', 'জামাত',
'পুঁজিবাদী', 'গদি', 'পদত্যাগ', 'কমিশন', 'কারচুপি', 'স্বৈরাচার', 'বামপন্থী', 'বাম', 'ডানপন্থী',
'আপা', 'নমিনেশন', 'তত্ত্বাবধায়ক',
# English transliterations
'bjp', 'congress', 'sheikh', 'hasina', 'mujib', 'mujibur', 'rahman', 'tareq','tarek', 'zia', 'khaleda',
'hasnat', 'sarjis', 'voter', 'vote', 'voting', 'nahid', 'asif', 'mirza', 'fakhrul', 'polok',
'jamaat', 'ncp', 'awami', 'league', 'bnp', 'election', 'bcl', 'chatro', 'league', 'chatroleague',
'shibir', 'chatrodol', 'dol', 'hefajot', 'trump', 'modi', 'govt', 'yunus', 'interim', 'bangabandhu',
'fascist', 'minister', 'president', 'communist', 'commie', 'leftist', 'rightist', 'marxist',
'hitler', 'putin', 'netanyahu', 'biden', 'july', 'joy', 'bangla', 'chuppu', 'politics', 'political',
'mamdani', 'nomination', 'apa', 'zohran', 'amlig', 'Hasina'
}
POLITICAL_LEXICON = {word.lower() for word in POLITICAL_LEXICON}
def contains_political_keywords(text, lexicon=POLITICAL_LEXICON, min_matches=LEXICON_MIN_MATCHES):
"""Count political keyword matches in text"""
import pandas as pd
if pd.isna(text) or not isinstance(text, str):
return 0
text_lower = text.lower()
matches = sum(1 for keyword in lexicon if keyword in text_lower)
return matches
def apply_lexicon_boost(logits, lexicon_matches, boost_strength=LEXICON_BOOST_STRENGTH):
"""Apply logit adjustment for political class"""
boost_mask = (lexicon_matches > 0).float().unsqueeze(1)
adjustment = torch.zeros_like(logits)
adjustment[:, 1] = boost_strength * lexicon_matches.float()
adjusted_logits = logits + (adjustment * boost_mask)
return adjusted_logits
class MultiheadAttention(nn.Module):
"""Multi-head attention mechanism for cross-modal fusion"""
def __init__(self, d_model, nhead, dropout=0.15):
super(MultiheadAttention, self).__init__()
self.attention = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
def forward(self, query, key, value, mask=None):
output, _ = self.attention(query, key, value, attn_mask=mask)
return output
class MAF(nn.Module):
"""Multimodal Attention Fusion (MAF) Model"""
def __init__(self, clip_model, num_classes, num_heads, use_lexicon_boost=True):
super(MAF, self).__init__()
# Visual feature extractor (CLIP)
self.clip = clip_model
self.visual_linear = nn.Linear(512, 1024)
# Textual feature extractor (XLM)
self.tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
self.bert = AutoModel.from_pretrained(MODEL_NAME)
# All parameters are trainable by default
for param in self.bert.parameters():
param.requires_grad = True
print(f"✓ XLM-RoBERTa loaded with all layers trainable")
# Multihead attention
self.attention = MultiheadAttention(d_model=1024, nhead=num_heads)
# Fully connected layers
self.fc = nn.Sequential(
nn.Linear(1024 + 1024 + 1024, 128),
nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(128, num_classes),
)
# Lexicon boosting flag
self.use_lexicon_boost = use_lexicon_boost
if use_lexicon_boost:
print(f"✓ Lexicon boosting ENABLED (strength: {LEXICON_BOOST_STRENGTH})")
def forward(self, image_input, input_ids, attention_mask, lexicon_matches=None):
# Extract visual features using CLIP
image_features = self.clip(image_input)
image_features = self.visual_linear(image_features)
image_features = image_features.unsqueeze(1)
# Apply average pooling to reduce the sequence length to 70
image_features = F.adaptive_avg_pool1d(image_features.permute(0, 2, 1), 70).permute(0, 2, 1)
# Extract BERT embeddings
bert_outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
bert_output = bert_outputs.last_hidden_state
# Apply multihead attention between visual_features and BERT embeddings
attention_output = self.attention(
query=image_features.permute(1, 0, 2),
key=bert_output.permute(1, 0, 2),
value=image_features.permute(1, 0, 2),
mask=None
)
# Swap back the dimensions to (batch_size, seq_length, feature_size)
attention_output = attention_output.permute(1, 0, 2)
# Concatenate the context vector, visual features, BERT embeddings, and attention output
fusion_input = torch.cat([attention_output, image_features, bert_output], dim=2)
# Pool over the sequence dimension and pass through FC layers
output = self.fc(fusion_input.mean(1))
if self.use_lexicon_boost and lexicon_matches is not None:
output = apply_lexicon_boost(output, lexicon_matches)
return output