import logging

from logging_config import *  # noqa: F401,F403 -- project-local logging setup

import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import EarlyStopping
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, T5ForConditionalGeneration


class T5Tuner(pl.LightningModule):
    """Fine-tunes a T5 model, with early stopping on a validation metric."""

    def __init__(self, batchsize, t5model, t5tokenizer,
                 early_stopping_patience=5, early_stopping_metric="val_loss"):
        super().__init__()
        self.batch_size = batchsize
        self.model = t5model
        self.tokenizer = t5tokenizer
        self.early_stopping_patience = early_stopping_patience
        self.early_stopping_metric = early_stopping_metric

    def configure_callbacks(self):
        # Stop training once the monitored metric has not improved for
        # `early_stopping_patience` validation checks.
        early_stopping = EarlyStopping(
            monitor=self.early_stopping_metric,
            patience=self.early_stopping_patience,
            verbose=True,
            mode="min",
        )
        return [early_stopping]

    def forward(self, input_ids, attention_mask=None, decoder_attention_mask=None, lm_labels=None):
        # T5ForConditionalGeneration takes `labels`; the older `lm_labels`
        # name is kept in this signature and mapped through.
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=decoder_attention_mask,
            labels=lm_labels,
        )
        return outputs

    def training_step(self, batch, batch_idx):
        outputs = self.forward(
            input_ids=batch["source_ids"],
            attention_mask=batch["source_mask"],
            decoder_attention_mask=batch["target_mask"],
            lm_labels=batch["labels"],
        )
        # When labels are supplied, the model output carries the LM loss.
        loss = outputs.loss
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        outputs = self.forward(
            input_ids=batch["source_ids"],
            attention_mask=batch["source_mask"],
            decoder_attention_mask=batch["target_mask"],
            lm_labels=batch["labels"],
        )
        loss = outputs.loss
        self.log("val_loss", loss)
        return loss

    def train_dataloader(self):
        # `custom_train_dataset` / `custom_validation_dataset` are assumed to be
        # defined elsewhere in the project and to yield dicts with the keys used
        # above: source_ids, source_mask, target_mask, labels.
        return DataLoader(custom_train_dataset, batch_size=self.batch_size, num_workers=4)

    def val_dataloader(self):
        return DataLoader(custom_validation_dataset, batch_size=self.batch_size, num_workers=4)

    def configure_optimizers(self):
        # AdamW lives in torch.optim; lr and eps are commonly used values for
        # T5 fine-tuning.
        optimizer = AdamW(self.parameters(), lr=3e-4, eps=1e-8)
        return optimizer
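
# Illustrative usage (a minimal sketch): the "t5-small" checkpoint and the
# trainer settings below are assumptions, and the datasets referenced above
# must be in scope before `trainer.fit` is called.
#
# if __name__ == "__main__":
#     model = T5ForConditionalGeneration.from_pretrained("t5-small")
#     tokenizer = AutoTokenizer.from_pretrained("t5-small")
#     tuner = T5Tuner(batchsize=8, t5model=model, t5tokenizer=tokenizer)
#     trainer = pl.Trainer(max_epochs=3)
#     trainer.fit(tuner)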