Dusit-P commited on
Commit
cf3fcd2
·
verified ·
1 Parent(s): c179922

Update common/models.py

Browse files
Files changed (1) hide show
  1. common/models.py +74 -58
common/models.py CHANGED
@@ -1,58 +1,74 @@
1
- # common/models.py
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
- from transformers import AutoModel
5
-
6
- # ตั้งค่าพื้นฐานให้ตรงกับตอนเทรน
7
- BASE_MODEL_NAME = "airesearch/wangchanberta-base-att-spm-uncased"
8
- POOLING_AFTER_LSTM = "masked_mean"
9
-
10
- class BaseHead(nn.Module):
11
- def __init__(self, hidden_in, hidden_lstm=128, num_classes=2, dropout=0.3, pooling='masked_mean'):
12
- super().__init__()
13
- self.lstm = nn.LSTM(hidden_in, hidden_lstm, bidirectional=True, batch_first=True)
14
- self.dropout = nn.Dropout(dropout)
15
- self.fc = nn.Linear(hidden_lstm*2, num_classes)
16
- assert pooling in ['cls','masked_mean','masked_max']
17
- self.pooling = pooling
18
- def pool(self, x, mask):
19
- if self.pooling=='cls': return x[:,0,:]
20
- mask = mask.unsqueeze(-1)
21
- if self.pooling=='masked_mean':
22
- s=(x*mask).sum(1); d=mask.sum(1).clamp(min=1e-6); return s/d
23
- x=x.masked_fill(mask==0,-1e9); return x.max(1).values
24
- def forward_after_bert(self, seq, mask):
25
- x, _ = self.lstm(seq)
26
- x = self.pool(x, mask)
27
- return self.fc(self.dropout(x))
28
-
29
- class Model1Baseline(nn.Module):
30
- def __init__(self, name=BASE_MODEL_NAME, hidden=128, dropout=0.3, classes=2, pooling=POOLING_AFTER_LSTM):
31
- super().__init__()
32
- self.bert = AutoModel.from_pretrained(name)
33
- self.head = BaseHead(self.bert.config.hidden_size, hidden, classes, dropout, pooling)
34
- def forward(self, ids, mask):
35
- out = self.bert(input_ids=ids, attention_mask=mask)
36
- return self.head.forward_after_bert(out.last_hidden_state, mask)
37
-
38
- class Model2CNNBiLSTM(nn.Module):
39
- def __init__(self, name=BASE_MODEL_NAME, hidden=128, dropout=0.3, classes=2, pooling=POOLING_AFTER_LSTM):
40
- super().__init__()
41
- self.bert = AutoModel.from_pretrained(name)
42
- H = self.bert.config.hidden_size
43
- self.c1 = nn.Conv1d(H,128,3,padding=1)
44
- self.c2 = nn.Conv1d(128,128,5,padding=2)
45
- self.head = BaseHead(128, hidden, classes, dropout, pooling)
46
- def forward(self, ids, mask):
47
- out = self.bert(input_ids=ids, attention_mask=mask).last_hidden_state
48
- x = F.relu(self.c1(out.transpose(1,2)))
49
- x = F.relu(self.c2(x)).transpose(1,2)
50
- return self.head.forward_after_bert(x, mask)
51
-
52
- def create_model_by_name(model_name):
53
- if model_name == "Model1_Baseline":
54
- return Model1Baseline()
55
- elif model_name == "Model2_CNN_BiLSTM":
56
- return Model2CNNBiLSTM()
57
- else:
58
- raise ValueError(f"Unknown model name: {model_name}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # common/models.py
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from transformers import AutoModel
5
+
6
+ # ตั้งค่าพื้นฐานให้ตรงกับตอนเทรน
7
+ BASE_MODEL_NAME = "airesearch/wangchanberta-base-att-spm-uncased"
8
+ POOLING_AFTER_LSTM = "masked_mean"
9
+
10
+ class BaseHead(nn.Module):
11
+ def __init__(self, hidden_in, hidden_lstm=128, num_classes=2, dropout=0.3, pooling='masked_mean'):
12
+ super().__init__()
13
+ self.lstm = nn.LSTM(hidden_in, hidden_lstm, bidirectional=True, batch_first=True)
14
+ self.dropout = nn.Dropout(dropout)
15
+ self.fc = nn.Linear(hidden_lstm*2, num_classes)
16
+ assert pooling in ['cls','masked_mean','masked_max']
17
+ self.pooling = pooling
18
+ def pool(self, x, mask):
19
+ if self.pooling=='cls': return x[:,0,:]
20
+ mask = mask.unsqueeze(-1)
21
+ if self.pooling=='masked_mean':
22
+ s=(x*mask).sum(1); d=mask.sum(1).clamp(min=1e-6); return s/d
23
+ x=x.masked_fill(mask==0,-1e9); return x.max(1).values
24
+ def forward_after_bert(self, seq, mask):
25
+ x, _ = self.lstm(seq)
26
+ x = self.pool(x, mask)
27
+ return self.fc(self.dropout(x))
28
+
29
+ class Model1Baseline(nn.Module):
30
+ def __init__(self, name=BASE_MODEL_NAME, hidden=128, dropout=0.3, classes=2, pooling=POOLING_AFTER_LSTM):
31
+ super().__init__()
32
+ self.bert = AutoModel.from_pretrained(name)
33
+ self.head = BaseHead(self.bert.config.hidden_size, hidden, classes, dropout, pooling)
34
+ def forward(self, ids, mask):
35
+ out = self.bert(input_ids=ids, attention_mask=mask)
36
+ return self.head.forward_after_bert(out.last_hidden_state, mask)
37
+
38
+ class Model2CNNBiLSTM(nn.Module):
39
+ def __init__(self, name=BASE_MODEL_NAME, hidden=128, dropout=0.3, classes=2, pooling=POOLING_AFTER_LSTM):
40
+ super().__init__()
41
+ self.bert = AutoModel.from_pretrained(name)
42
+ H = self.bert.config.hidden_size
43
+ self.c1 = nn.Conv1d(H,128,3,padding=1)
44
+ self.c2 = nn.Conv1d(128,128,5,padding=2)
45
+ self.head = BaseHead(128, hidden, classes, dropout, pooling)
46
+ def forward(self, ids, mask):
47
+ out = self.bert(input_ids=ids, attention_mask=mask).last_hidden_state
48
+ x = F.relu(self.c1(out.transpose(1,2)))
49
+ x = F.relu(self.c2(x)).transpose(1,2)
50
+ return self.head.forward_after_bert(x, mask)
51
+
52
+ class Model3PureLast4(nn.Module):
53
+ def __init__(self, name=BASE_MODEL_NAME, hidden=128, dropout=0.3, classes=2, pooling=POOLING_AFTER_LSTM):
54
+ super().__init__()
55
+ from transformers import AutoModel
56
+ import torch.nn.functional as F
57
+ self.bert = AutoModel.from_pretrained(name)
58
+ self.w = nn.Parameter(torch.ones(4))
59
+ H = self.bert.config.hidden_size
60
+ self.head = BaseHead(H, hidden, classes, dropout, pooling)
61
+ def forward(self, ids, mask):
62
+ out = self.bert(input_ids=ids, attention_mask=mask, output_hidden_states=True)
63
+ last4 = out.hidden_states[-4:]; w = F.softmax(self.w, dim=0)
64
+ seq = sum(w[i]*last4[i] for i in range(4))
65
+ return self.head.forward_after_bert(seq, mask)
66
+
67
+ def create_model_by_name(model_name):
68
+ if model_name == "Model1_Baseline": return Model1Baseline()
69
+ elif model_name == "Model2_CNN_BiLSTM": return Model2CNNBiLSTM()
70
+ elif model_name == "Model3_Pure_Last4Weighted": #in ["Model3_Pure_Last4Weighted","last4weighted_pure","last4_pure"]:
71
+ return Model3PureLast4()
72
+
73
+ else:
74
+ raise ValueError(f"Unknown model name: {model_name}")