import torch
from torch.utils.data import Dataset, DataLoader


class NewsDataset(Dataset):
    """Map-style dataset pairing news titles with their body texts.

    Each item is ``(title, text)``, or ``(title, text, label)`` when labels
    were supplied. Items are returned exactly as stored; no tokenization or
    tensor conversion happens here.

    Args:
        titles: Indexable, sized sequence of titles. Its length defines the
            dataset length.
        texts: Indexable, sized sequence of texts aligned with ``titles``.
        labels: Optional indexable, sized sequence of labels aligned with
            ``titles``; ``None`` for unlabeled data.

    Raises:
        ValueError: If ``texts`` (or ``labels``, when given) does not have
            the same length as ``titles``.
    """

    def __init__(self, titles, texts, labels=None):
        # Fail fast on misaligned inputs. Otherwise a longer `texts`/`labels`
        # is silently truncated (length comes from `titles` only) and a
        # shorter one raises IndexError mid-epoch, possibly inside a
        # DataLoader worker process where the traceback is harder to read.
        expected = len(titles)
        if len(texts) != expected:
            raise ValueError(
                f"texts has {len(texts)} items but titles has {expected}"
            )
        if labels is not None and len(labels) != expected:
            raise ValueError(
                f"labels has {len(labels)} items but titles has {expected}"
            )
        # NOTE(review): items are fetched with plain `[idx]` indexing, which
        # presumably receives positions 0..len-1. A pandas Series with a
        # non-default index would need `.reset_index(drop=True)` (or
        # `.tolist()`) before being passed in — confirm against callers.
        self.titles = titles
        self.texts = texts
        self.labels = labels

    def __len__(self):
        """Return the number of samples, i.e. the number of titles."""
        return len(self.titles)

    def __getitem__(self, idx):
        """Return the sample at ``idx`` as a tuple (label last, if present)."""
        if self.labels is None:
            return self.titles[idx], self.texts[idx]
        return self.titles[idx], self.texts[idx], self.labels[idx]


def create_data_loader(titles, texts, labels=None, batch_size=32, shuffle=False,
                       num_workers=6, pin_memory=None):
    """Wrap titles/texts/labels in a ``NewsDataset`` and return a DataLoader.

    Args:
        titles: Sequence of titles, forwarded to ``NewsDataset``.
        texts: Sequence of texts aligned with ``titles``.
        labels: Optional sequence of labels aligned with ``titles``; ``None``
            for unlabeled data (batches then carry no label field).
        batch_size: Number of samples per batch.
        shuffle: Whether to reshuffle the samples every epoch.
        num_workers: Number of loader worker processes; ``0`` loads data in
            the main process.
        pin_memory: Whether to copy batches into page-locked host memory.
            ``None`` (default) enables it only when CUDA is available.

    Returns:
        A ``torch.utils.data.DataLoader`` over the constructed dataset.
    """
    dataset = NewsDataset(titles, texts, labels)
    if pin_memory is None:
        # Pinned memory only benefits host-to-GPU transfers, so skip it (and
        # the associated warning) on machines without CUDA.
        pin_memory = torch.cuda.is_available()
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
        # DataLoader raises ValueError for persistent_workers=True with
        # num_workers == 0, so only keep workers alive when there are any.
        persistent_workers=num_workers > 0,
    )