# Provenance (Hugging Face file-viewer residue from the original upload):
#   uploader: zibaatak — "Upload 17 files", commit fdfb2b8, 2.47 kB
import logging

import torch
from torch.optim import AdamW  # fixes bare `AdamW` used in configure_optimizers
from torch.utils.data import DataLoader

import pytorch_lightning as pl  # fixes undefined `pl` in the class definition
from pytorch_lightning.callbacks import EarlyStopping
from transformers import AutoTokenizer, T5ForConditionalGeneration

from logging_config import *  # Import logging settings from logging_config.py

# NOTE(review): `import optim` was removed — there is no top-level `optim`
# module (it raised ImportError); the optimizer comes from torch.optim above.
class T5Tuner(pl.LightningModule):
    """LightningModule for fine-tuning a T5 conditional-generation model.

    The wrapped model computes the LM loss itself when ``labels`` are
    supplied; training/validation steps log that loss, and an
    EarlyStopping callback monitors the configured validation metric.
    """

    def __init__(self, batchsize, t5model, t5tokenizer, early_stopping_patience=5, early_stopping_metric='val_loss'):
        """
        Args:
            batchsize: batch size used by both dataloaders.
            t5model: a T5ForConditionalGeneration (or API-compatible) model.
            t5tokenizer: tokenizer paired with the model (stored for later
                decoding/generation; not used in the steps below).
            early_stopping_patience: epochs without improvement before stopping.
            early_stopping_metric: metric name EarlyStopping monitors.
        """
        super().__init__()  # plain super() — Python 3 idiom
        self.batch_size = batchsize
        self.model = t5model
        self.tokenizer = t5tokenizer
        self.early_stopping_patience = early_stopping_patience
        self.early_stopping_metric = early_stopping_metric

    def configure_callbacks(self):
        """Return the EarlyStopping callback for the monitored metric."""
        early_stopping_callback = EarlyStopping(
            monitor=self.early_stopping_metric,
            patience=self.early_stopping_patience,
            verbose=True,
            mode='min'  # 'min' because we want to minimize the validation loss
        )
        return [early_stopping_callback]

    def forward(self, input_ids, attention_mask=None, decoder_attention_mask=None, lm_labels=None):
        """Run the underlying model.

        ``lm_labels`` is forwarded as ``labels`` so the model computes the
        LM loss internally (first element / ``.loss`` of the output).
        """
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=decoder_attention_mask,
            labels=lm_labels,
        )
        return outputs

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch.

        Assumes batch keys: source_ids, source_mask, target_mask, labels
        — TODO confirm against the dataset's collate function.
        """
        outputs = self.forward(
            input_ids=batch["source_ids"],
            attention_mask=batch["source_mask"],
            decoder_attention_mask=batch['target_mask'],
            lm_labels=batch['labels']
        )
        loss = outputs.loss  # equivalent to outputs[0] when labels are given
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss for one batch."""
        outputs = self.forward(
            input_ids=batch["source_ids"],
            attention_mask=batch["source_mask"],
            decoder_attention_mask=batch['target_mask'],
            lm_labels=batch['labels']
        )
        loss = outputs.loss
        self.log("val_loss", loss)
        return loss

    def train_dataloader(self):
        # NOTE(review): `custom_train_dataset` must be defined at module level
        # before training starts — verify it is created elsewhere in the project.
        # shuffle=True added: training batches should not be seen in a fixed order.
        return DataLoader(custom_train_dataset, batch_size=self.batch_size,
                          shuffle=True, num_workers=4)

    def val_dataloader(self):
        # NOTE(review): `custom_validation_dataset` is assumed to be a module-level
        # dataset — confirm it exists at runtime. Validation is not shuffled.
        return DataLoader(custom_validation_dataset, batch_size=self.batch_size,
                          num_workers=4)

    def configure_optimizers(self):
        """Return the AdamW optimizer over all module parameters.

        Qualified as torch.optim.AdamW: the original bare ``AdamW`` was an
        unresolved name (the broken ``import optim`` never provided it).
        """
        optimizer = torch.optim.AdamW(self.parameters(), lr=3e-4, eps=1e-8)
        return optimizer