# book-buddy-question-generator / scripts / CustomQuestionGenerationDataset.py
import logging
from logging_config import * # Import logging settings from logging_config.py
from torch.utils.data import Dataset
class CustomQuestionGenerationDataset(Dataset):
    """Torch Dataset pairing passage contexts with target questions for
    seq2seq question generation.

    Each row of ``data_frame`` must provide "passage" and "question"
    columns. Inputs are formatted as ``"context: <passage>"`` and targets
    as ``"question: <question>"``, then tokenized once up front to fixed
    lengths so ``__getitem__`` is cheap.
    """

    def __init__(self, tokenizer, data_frame, max_len_inp=512, max_len_out=96):
        """
        Args:
            tokenizer: Hugging Face-style tokenizer exposing
                ``batch_encode_plus`` (and, optionally, ``pad_token_id``).
            data_frame: pandas DataFrame with "passage" and "question" columns.
            max_len_inp: fixed token length for encoder inputs.
            max_len_out: fixed token length for decoder targets.
        """
        self.data = data_frame
        self.max_len_input = max_len_inp
        self.max_len_output = max_len_out
        self.tokenizer = tokenizer
        self.inputs = []
        self.targets = []
        self._build()

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, index):
        """Return one example as a dict of 1-D tensors.

        Keys: "source_ids", "source_mask", "target_ids", "target_mask",
        "labels" (target ids with pad positions replaced by -100).
        """
        source_ids = self.inputs[index]["input_ids"].squeeze()
        target_ids = self.targets[index]["input_ids"].squeeze()
        src_mask = self.inputs[index]["attention_mask"].squeeze()
        target_mask = self.targets[index]["attention_mask"].squeeze()

        # Mask pad positions in the labels so the loss ignores them
        # (-100 is the default ignore_index of torch.nn.CrossEntropyLoss).
        # Use the tokenizer's declared pad id when available; fall back to
        # 0 (the T5 pad id the original code hard-coded) otherwise.
        pad_id = getattr(self.tokenizer, "pad_token_id", 0)
        if pad_id is None:
            pad_id = 0
        labels = target_ids.clone()
        labels[labels == pad_id] = -100
        return {
            "source_ids": source_ids,
            "source_mask": src_mask,
            "target_ids": target_ids,
            "target_mask": target_mask,
            "labels": labels,
        }

    def _build(self):
        """Pre-tokenize every DataFrame row into self.inputs / self.targets."""
        for _, row in self.data.iterrows():
            passage, question = row["passage"], row["question"]
            input_ = f"context: {passage}"
            target = f"question: {str(question)}"

            # truncation=True caps sequences at exactly max_length;
            # without it, over-long passages exceed the fixed length that
            # padding='max_length' is supposed to guarantee.
            tokenized_inputs = self.tokenizer.batch_encode_plus(
                [input_],
                max_length=self.max_len_input,
                padding='max_length',
                truncation=True,
                return_tensors="pt",  # PyTorch tensors
            )
            tokenized_targets = self.tokenizer.batch_encode_plus(
                [target],
                max_length=self.max_len_output,
                padding='max_length',
                truncation=True,
                return_tensors="pt",
            )
            self.inputs.append(tokenized_inputs)
            self.targets.append(tokenized_targets)