Update common/models.py
Browse files- common/models.py +74 -58
common/models.py
CHANGED
|
@@ -1,58 +1,74 @@
|
|
| 1 |
-
# common/models.py
|
| 2 |
-
import torch.nn as nn
|
| 3 |
-
import torch.nn.functional as F
|
| 4 |
-
from transformers import AutoModel
|
| 5 |
-
|
| 6 |
-
# Base settings; keep these identical to the values used at training time.
|
| 7 |
-
BASE_MODEL_NAME = "airesearch/wangchanberta-base-att-spm-uncased"
|
| 8 |
-
POOLING_AFTER_LSTM = "masked_mean"
|
| 9 |
-
|
| 10 |
-
class BaseHead(nn.Module):
    """Bidirectional LSTM, sequence pooling and a linear classifier.

    Args:
        hidden_in: Feature size of every input timestep (LSTM input size).
        hidden_lstm: Hidden size per LSTM direction. The classifier receives
            ``2 * hidden_lstm`` features because the LSTM is bidirectional.
        num_classes: Number of output logits.
        dropout: Dropout probability applied to the pooled vector.
        pooling: One of ``'cls'``, ``'masked_mean'`` or ``'masked_max'``.

    Raises:
        ValueError: If ``pooling`` is not a supported mode.
    """

    POOLING_MODES = ('cls', 'masked_mean', 'masked_max')

    def __init__(self, hidden_in, hidden_lstm=128, num_classes=2, dropout=0.3, pooling='masked_mean'):
        super().__init__()
        # Explicit raise instead of `assert`: asserts vanish under `python -O`.
        if pooling not in self.POOLING_MODES:
            raise ValueError(
                f"pooling must be one of {self.POOLING_MODES}, got {pooling!r}"
            )
        self.lstm = nn.LSTM(hidden_in, hidden_lstm, bidirectional=True, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_lstm * 2, num_classes)
        self.pooling = pooling

    def pool(self, x, mask):
        """Reduce ``x`` over dim 1 according to ``self.pooling``.

        Args:
            x: 3-D tensor (batch-first LSTM output); dim 1 is reduced.
            mask: Tensor broadcastable to ``x`` after a trailing dim is
                added; positions equal to 0 are ignored by the masked modes.

        Returns:
            ``x[:, 0, :]`` for ``'cls'``, the masked mean over dim 1 for
            ``'masked_mean'``, otherwise the masked max over dim 1.
        """
        if self.pooling == 'cls':
            return x[:, 0, :]
        keep = mask.unsqueeze(-1)
        if self.pooling == 'masked_mean':
            # Cast so the count is floating point before clamping with 1e-6;
            # an integer count cannot represent that lower bound.
            weights = keep.to(x.dtype)
            total = (x * weights).sum(1)
            count = weights.sum(1).clamp(min=1e-6)
            return total / count
        # Masked positions get a large negative sentinel so they never win
        # the max. NOTE(review): -1e9 is outside the float16 range.
        return x.masked_fill(keep == 0, -1e9).max(1).values

    def forward_after_bert(self, seq, mask):
        """Run LSTM -> pooling -> dropout -> linear and return the logits."""
        lstm_out, _ = self.lstm(seq)
        pooled = self.pool(lstm_out, mask)
        return self.fc(self.dropout(pooled))
|
| 28 |
-
|
| 29 |
-
class Model1Baseline(nn.Module):
    """Pretrained encoder whose token states feed a ``BaseHead`` directly.

    Args:
        name: Checkpoint identifier handed to ``AutoModel.from_pretrained``.
        hidden: LSTM hidden size per direction inside the head.
        dropout: Dropout probability used by the head.
        classes: Number of output logits.
        pooling: Pooling mode forwarded to ``BaseHead``.
    """

    def __init__(self, name=BASE_MODEL_NAME, hidden=128, dropout=0.3, classes=2, pooling=POOLING_AFTER_LSTM):
        super().__init__()
        self.bert = AutoModel.from_pretrained(name)
        encoder_width = self.bert.config.hidden_size
        self.head = BaseHead(
            hidden_in=encoder_width,
            hidden_lstm=hidden,
            num_classes=classes,
            dropout=dropout,
            pooling=pooling,
        )

    def forward(self, ids, mask):
        """Encode ``ids`` under attention ``mask`` and return the head's logits."""
        encoded = self.bert(input_ids=ids, attention_mask=mask)
        token_states = encoded.last_hidden_state
        return self.head.forward_after_bert(token_states, mask)
|
| 37 |
-
|
| 38 |
-
class Model2CNNBiLSTM(nn.Module):
    """Pretrained encoder -> two ReLU-activated 1-D convolutions -> ``BaseHead``.

    Args:
        name: Checkpoint identifier handed to ``AutoModel.from_pretrained``.
        hidden: LSTM hidden size per direction inside the head.
        dropout: Dropout probability used by the head.
        classes: Number of output logits.
        pooling: Pooling mode forwarded to ``BaseHead``.
    """

    def __init__(self, name=BASE_MODEL_NAME, hidden=128, dropout=0.3, classes=2, pooling=POOLING_AFTER_LSTM):
        super().__init__()
        self.bert = AutoModel.from_pretrained(name)
        encoder_width = self.bert.config.hidden_size
        # Kernel 3 / padding 1 and kernel 5 / padding 2 both keep the
        # sequence length, so the attention mask still lines up afterwards.
        self.c1 = nn.Conv1d(in_channels=encoder_width, out_channels=128, kernel_size=3, padding=1)
        self.c2 = nn.Conv1d(in_channels=128, out_channels=128, kernel_size=5, padding=2)
        self.head = BaseHead(
            hidden_in=128,
            hidden_lstm=hidden,
            num_classes=classes,
            dropout=dropout,
            pooling=pooling,
        )

    def forward(self, ids, mask):
        """Encode, convolve along the sequence, and return the head's logits."""
        token_states = self.bert(input_ids=ids, attention_mask=mask).last_hidden_state
        # Conv1d convolves over the last dim, so features move to dim 1.
        channels_first = token_states.transpose(1, 2)
        feats = F.relu(self.c1(channels_first))
        feats = F.relu(self.c2(feats))
        # Back to batch-first sequence layout for the LSTM head.
        return self.head.forward_after_bert(feats.transpose(1, 2), mask)
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# common/models.py
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from transformers import AutoModel
|
| 5 |
+
|
| 6 |
+
# Base settings; keep these identical to the values used at training time.
|
| 7 |
+
BASE_MODEL_NAME = "airesearch/wangchanberta-base-att-spm-uncased"
|
| 8 |
+
POOLING_AFTER_LSTM = "masked_mean"
|
| 9 |
+
|
| 10 |
+
class BaseHead(nn.Module):
    """Bidirectional LSTM, sequence pooling and a linear classifier.

    Args:
        hidden_in: Feature size of every input timestep (LSTM input size).
        hidden_lstm: Hidden size per LSTM direction. The classifier receives
            ``2 * hidden_lstm`` features because the LSTM is bidirectional.
        num_classes: Number of output logits.
        dropout: Dropout probability applied to the pooled vector.
        pooling: One of ``'cls'``, ``'masked_mean'`` or ``'masked_max'``.

    Raises:
        ValueError: If ``pooling`` is not a supported mode.
    """

    POOLING_MODES = ('cls', 'masked_mean', 'masked_max')

    def __init__(self, hidden_in, hidden_lstm=128, num_classes=2, dropout=0.3, pooling='masked_mean'):
        super().__init__()
        # Explicit raise instead of `assert`: asserts vanish under `python -O`.
        if pooling not in self.POOLING_MODES:
            raise ValueError(
                f"pooling must be one of {self.POOLING_MODES}, got {pooling!r}"
            )
        self.lstm = nn.LSTM(hidden_in, hidden_lstm, bidirectional=True, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_lstm * 2, num_classes)
        self.pooling = pooling

    def pool(self, x, mask):
        """Reduce ``x`` over dim 1 according to ``self.pooling``.

        Args:
            x: 3-D tensor (batch-first LSTM output); dim 1 is reduced.
            mask: Tensor broadcastable to ``x`` after a trailing dim is
                added; positions equal to 0 are ignored by the masked modes.

        Returns:
            ``x[:, 0, :]`` for ``'cls'``, the masked mean over dim 1 for
            ``'masked_mean'``, otherwise the masked max over dim 1.
        """
        if self.pooling == 'cls':
            return x[:, 0, :]
        keep = mask.unsqueeze(-1)
        if self.pooling == 'masked_mean':
            # Cast so the count is floating point before clamping with 1e-6;
            # an integer count cannot represent that lower bound.
            weights = keep.to(x.dtype)
            total = (x * weights).sum(1)
            count = weights.sum(1).clamp(min=1e-6)
            return total / count
        # Masked positions get a large negative sentinel so they never win
        # the max. NOTE(review): -1e9 is outside the float16 range.
        return x.masked_fill(keep == 0, -1e9).max(1).values

    def forward_after_bert(self, seq, mask):
        """Run LSTM -> pooling -> dropout -> linear and return the logits."""
        lstm_out, _ = self.lstm(seq)
        pooled = self.pool(lstm_out, mask)
        return self.fc(self.dropout(pooled))
|
| 28 |
+
|
| 29 |
+
class Model1Baseline(nn.Module):
    """Pretrained encoder whose token states feed a ``BaseHead`` directly.

    Args:
        name: Checkpoint identifier handed to ``AutoModel.from_pretrained``.
        hidden: LSTM hidden size per direction inside the head.
        dropout: Dropout probability used by the head.
        classes: Number of output logits.
        pooling: Pooling mode forwarded to ``BaseHead``.
    """

    def __init__(self, name=BASE_MODEL_NAME, hidden=128, dropout=0.3, classes=2, pooling=POOLING_AFTER_LSTM):
        super().__init__()
        self.bert = AutoModel.from_pretrained(name)
        encoder_width = self.bert.config.hidden_size
        self.head = BaseHead(
            hidden_in=encoder_width,
            hidden_lstm=hidden,
            num_classes=classes,
            dropout=dropout,
            pooling=pooling,
        )

    def forward(self, ids, mask):
        """Encode ``ids`` under attention ``mask`` and return the head's logits."""
        encoded = self.bert(input_ids=ids, attention_mask=mask)
        token_states = encoded.last_hidden_state
        return self.head.forward_after_bert(token_states, mask)
|
| 37 |
+
|
| 38 |
+
class Model2CNNBiLSTM(nn.Module):
    """Pretrained encoder -> two ReLU-activated 1-D convolutions -> ``BaseHead``.

    Args:
        name: Checkpoint identifier handed to ``AutoModel.from_pretrained``.
        hidden: LSTM hidden size per direction inside the head.
        dropout: Dropout probability used by the head.
        classes: Number of output logits.
        pooling: Pooling mode forwarded to ``BaseHead``.
    """

    def __init__(self, name=BASE_MODEL_NAME, hidden=128, dropout=0.3, classes=2, pooling=POOLING_AFTER_LSTM):
        super().__init__()
        self.bert = AutoModel.from_pretrained(name)
        encoder_width = self.bert.config.hidden_size
        # Kernel 3 / padding 1 and kernel 5 / padding 2 both keep the
        # sequence length, so the attention mask still lines up afterwards.
        self.c1 = nn.Conv1d(in_channels=encoder_width, out_channels=128, kernel_size=3, padding=1)
        self.c2 = nn.Conv1d(in_channels=128, out_channels=128, kernel_size=5, padding=2)
        self.head = BaseHead(
            hidden_in=128,
            hidden_lstm=hidden,
            num_classes=classes,
            dropout=dropout,
            pooling=pooling,
        )

    def forward(self, ids, mask):
        """Encode, convolve along the sequence, and return the head's logits."""
        token_states = self.bert(input_ids=ids, attention_mask=mask).last_hidden_state
        # Conv1d convolves over the last dim, so features move to dim 1.
        channels_first = token_states.transpose(1, 2)
        feats = F.relu(self.c1(channels_first))
        feats = F.relu(self.c2(feats))
        # Back to batch-first sequence layout for the LSTM head.
        return self.head.forward_after_bert(feats.transpose(1, 2), mask)
|
| 51 |
+
|
| 52 |
+
class Model3PureLast4(nn.Module):
    """Encoder whose last four hidden-state layers are mixed, then a ``BaseHead``.

    The mixing weights are a learned 4-element parameter passed through a
    softmax, so the combined sequence is a convex combination of the last
    four entries of the encoder's ``hidden_states``.

    Args:
        name: Checkpoint identifier handed to ``AutoModel.from_pretrained``.
        hidden: LSTM hidden size per direction inside the head.
        dropout: Dropout probability used by the head.
        classes: Number of output logits.
        pooling: Pooling mode forwarded to ``BaseHead``.
    """

    def __init__(self, name=BASE_MODEL_NAME, hidden=128, dropout=0.3, classes=2, pooling=POOLING_AFTER_LSTM):
        super().__init__()
        # The file-level `import torch.nn as nn` binds only `nn`, not
        # `torch`, so `torch.ones` below raised NameError; import it here.
        import torch

        self.bert = AutoModel.from_pretrained(name)
        # One logit per mixed layer; equal values give a uniform softmax.
        self.w = nn.Parameter(torch.ones(4))
        H = self.bert.config.hidden_size
        self.head = BaseHead(H, hidden, classes, dropout, pooling)

    def forward(self, ids, mask):
        """Return logits computed from the weighted mix of the last 4 layers."""
        out = self.bert(input_ids=ids, attention_mask=mask, output_hidden_states=True)
        last4 = out.hidden_states[-4:]
        weights = F.softmax(self.w, dim=0)
        # Weighted sum across layers; each term keeps the per-layer shape.
        seq = sum(weight * layer for weight, layer in zip(weights, last4))
        return self.head.forward_after_bert(seq, mask)
|
| 66 |
+
|
| 67 |
+
def create_model_by_name(model_name):
    """Build a default-configured model for the given registry name.

    Args:
        model_name: ``"Model1_Baseline"``, ``"Model2_CNN_BiLSTM"`` or
            ``"Model3_Pure_Last4Weighted"``.

    Returns:
        A new instance of the matching model class, built with its defaults.

    Raises:
        ValueError: If ``model_name`` is not one of the names above.
    """
    if model_name == "Model1_Baseline":
        return Model1Baseline()
    if model_name == "Model2_CNN_BiLSTM":
        return Model2CNNBiLSTM()
    if model_name == "Model3_Pure_Last4Weighted":
        return Model3PureLast4()
    raise ValueError(f"Unknown model name: {model_name}")
|