from __future__ import annotations
import math
import os
from typing import Iterator, List, Optional, Sequence
import torch
from torch.utils.data import BatchSampler
from datasets import Dataset, DatasetDict, load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)
# ---- Config (edit as needed) ----
MODEL_NAME = "gpt2"
TEXT_COLUMN = "text"
DATASET = ("wikitext", "wikitext-2-raw-v1") # (name, subset) or local files
CACHE_DIR = os.path.join(os.getcwd(), ".cache_ds")
MAX_LENGTH = 1024
BATCH_SIZE_PER_DEVICE = 4
NUM_WORKERS = max(1, (os.cpu_count() or 2) // 2)  # os.cpu_count() can return None
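# For local text files instead of a Hub dataset, DATASET can be a data_files mapping
# consumed by the load_dataset("text", ...) branch in main(). Paths below are
# illustrative only:
#   DATASET = {"train": "data/train.txt", "validation": "data/valid.txt"}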
# ---- Utilities ----
def _seq_length(example_ids: List[int], pad_id: int) -> int:
# Why: robust for both padded and unpadded tokenization
if pad_id in example_ids:
# Count non-pad tokens
return int(sum(1 for t in example_ids if t != pad_id))
return len(example_ids)
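# Tiny illustration with made-up token ids and pad_id=0: padded input counts only the
# real tokens, unpadded input falls back to its full length.
assert _seq_length([11, 12, 0, 0], pad_id=0) == 2
assert _seq_length([11, 12, 13], pad_id=0) == 3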
def add_length_column(ds: Dataset, pad_id: int) -> Dataset:
def _len_fn(batch):
return {"length": [ _seq_length(ids, pad_id) for ids in batch["input_ids"] ]}
# Map once and cache; batched for speed; num_proc for parallelism
    ds = ds.map(_len_fn, batched=True, num_proc=NUM_WORKERS)
return ds
def build_buckets_from_quantiles(lengths: Sequence[int], num_buckets: int = 20) -> List[int]:
    # Why: one up-front sort yields stable quantile edges, so no per-epoch global sort is needed
if num_buckets < 2:
return [max(lengths)]
qs = [i / num_buckets for i in range(1, num_buckets)]
sorted_lens = sorted(lengths)
edges = [sorted_lens[math.floor(q * (len(sorted_lens) - 1))] for q in qs]
return edges + [max(sorted_lens)]
def bucket_index(length: int, edges: Sequence[int]) -> int:
lo, hi = 0, len(edges) - 1
while lo < hi:
mid = (lo + hi) // 2
if length <= edges[mid]:
hi = mid
else:
lo = mid + 1
return lo
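# Quick sanity check of the two helpers above with made-up lengths: 3 buckets over
# [3, 5, 8, 9, 20, 40] yields quantile edges [5, 9, 40], and a length-7 sequence
# lands in the middle bucket.
_edges_demo = build_buckets_from_quantiles([3, 5, 8, 9, 20, 40], num_buckets=3)
assert _edges_demo == [5, 9, 40]
assert bucket_index(7, _edges_demo) == 1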
# ---- Samplers ----
class BucketBatchSampler(BatchSampler):
"""
Epoch-wise shuffling inside buckets only; no global sort.
    DDP-aware if `process_rank`/`num_processes` are set (the BucketedTrainer below
    passes them from TrainingArguments).
"""
def __init__(
self,
lengths: Sequence[int],
batch_size: int,
bucket_edges: Sequence[int],
drop_last: bool = False,
process_rank: int = 0,
num_processes: int = 1,
generator: Optional[torch.Generator] = None,
):
self.lengths = lengths
self.batch_size = batch_size
self.bucket_edges = list(bucket_edges)
self.drop_last = drop_last
self.process_rank = process_rank
self.num_processes = max(1, num_processes)
self.generator = generator
# Pre-partition indices into buckets once
self._buckets: List[List[int]] = [[] for _ in range(len(self.bucket_edges))]
for idx, L in enumerate(self.lengths):
b = bucket_index(L, self.bucket_edges)
self._buckets[b].append(idx)
def __iter__(self) -> Iterator[List[int]]:
g = self.generator
if g is None:
            # Derive a per-epoch seed from the global RNG: reproducible under the
            # Trainer seed, different each epoch, identical across DDP ranks
g = torch.Generator()
g.manual_seed(torch.randint(0, 2**31 - 1, (1,)).item())
# Shuffle within each bucket and then interleave buckets
local_batches: List[List[int]] = []
for bucket in self._buckets:
# Per-bucket shuffle
bucket_indices = bucket[:] # copy
if bucket_indices:
perm = torch.randperm(len(bucket_indices), generator=g).tolist()
bucket_indices = [bucket_indices[i] for i in perm]
# Make batches
for i in range(0, len(bucket_indices), self.batch_size):
batch = bucket_indices[i : i + self.batch_size]
if len(batch) == self.batch_size or (not self.drop_last and len(batch) > 0):
local_batches.append(batch)
        # Shuffle the batches themselves so buckets are not visited in length order
if local_batches:
perm = torch.randperm(len(local_batches), generator=g).tolist()
local_batches = [local_batches[i] for i in perm]
# DDP sharding across processes at batch granularity
if self.num_processes > 1:
local_batches = local_batches[self.process_rank :: self.num_processes]
for b in local_batches:
yield b
def __len__(self) -> int:
total = 0
for bucket in self._buckets:
n = len(bucket) // self.batch_size
if not self.drop_last and (len(bucket) % self.batch_size):
n += 1
total += n
        if self.num_processes > 1:
            # Exact number of batches this rank receives from the strided shard in __iter__
            return len(range(self.process_rank, total, self.num_processes))
return total
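# Standalone usage sketch outside Trainer (names below are illustrative, not defined
# in this script):
#   toy_lengths = [len(ex["input_ids"]) for ex in toy_dataset]
#   edges = build_buckets_from_quantiles(toy_lengths, num_buckets=8)
#   sampler = BucketBatchSampler(toy_lengths, batch_size=4, bucket_edges=edges)
#   loader = torch.utils.data.DataLoader(toy_dataset, batch_sampler=sampler, collate_fn=my_collate)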
# ---- Trainer subclass to use custom batch sampler ----
class BucketedTrainer(Trainer):
    def get_train_dataloader(self):
        if self.train_dataset is None:
            return super().get_train_dataloader()
        # Read lengths before pruning columns: the sampler needs them, the model does not
        lengths = list(self.train_dataset["length"])
        # Drop columns the model's forward() will not accept (e.g. "length"),
        # mirroring what Trainer does when it builds its default dataloader
        train_dataset = self._remove_unused_columns(self.train_dataset, description="training")
        # Ensure torch-format fast path
        if hasattr(train_dataset, "with_format"):
            train_dataset = train_dataset.with_format("torch")
        # Derive bucket edges once per dataloader construction
        edges = build_buckets_from_quantiles(lengths, num_buckets=24)
        # Resolve DDP parameters from Trainer/Accelerate; batches are sharded manually
        # per rank inside the sampler, so the raw DataLoader is returned as-is
        process_rank = self.args.process_index
        num_processes = self.args.world_size
        batch_sampler = BucketBatchSampler(
            lengths=lengths,
            batch_size=self.args.per_device_train_batch_size,
            bucket_edges=edges,
            drop_last=self.args.dataloader_drop_last,
            process_rank=process_rank,
            num_processes=num_processes,
            generator=None,
        )
        num_workers = self.args.dataloader_num_workers
        return torch.utils.data.DataLoader(
            train_dataset,
            batch_sampler=batch_sampler,
            collate_fn=self.data_collator,
            num_workers=num_workers,
            pin_memory=self.args.dataloader_pin_memory,
            persistent_workers=num_workers > 0,
            prefetch_factor=4 if num_workers > 0 else None,
        )
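# Note: only the training dataloader is overridden; evaluation still uses Trainer's
# stock (sequential, non-bucketed) dataloader.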
# ---- Example: wire everything together ----
def main():
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
# Load dataset
ds: DatasetDict
if isinstance(DATASET, tuple):
ds = load_dataset(DATASET[0], DATASET[1])
else:
ds = load_dataset("text", data_files=DATASET)
# Tokenize
def tok_fn(examples):
return tokenizer(
examples[TEXT_COLUMN],
truncation=True,
max_length=MAX_LENGTH,
# Let DataCollator pad dynamically; faster + less memory
padding=False,
)
    ds = ds.map(tok_fn, batched=True, num_proc=NUM_WORKERS, remove_columns=[TEXT_COLUMN])
    # Drop empty rows (wikitext contains many blank lines); zero-length examples
    # would otherwise form degenerate batches
    ds = ds.filter(lambda ex: len(ex["input_ids"]) > 0)
    # Add cached lengths once and save to disk for instant reuse
    ds["train"] = add_length_column(ds["train"], tokenizer.pad_token_id)
    if "validation" in ds:
        ds["validation"] = add_length_column(ds["validation"], tokenizer.pad_token_id)
    os.makedirs(CACHE_DIR, exist_ok=True)
    ds.save_to_disk(CACHE_DIR)
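    # On later runs the tokenized, length-annotated dataset can be reloaded directly
    # (sketch, assuming the cache dir is intact):
    #   from datasets import load_from_disk
    #   ds = load_from_disk(CACHE_DIR)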
    # Collator: pads dynamically per batch and builds causal-LM labels from input_ids;
    # pad_to_multiple_of=8 keeps shapes aligned for Tensor Cores
    collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False, pad_to_multiple_of=8)
# Training args: turn OFF built-in group_by_length; rely on our sampler
args = TrainingArguments(
output_dir="./out",
per_device_train_batch_size=BATCH_SIZE_PER_DEVICE,
per_device_eval_batch_size=BATCH_SIZE_PER_DEVICE,
learning_rate=5e-5,
num_train_epochs=1,
evaluation_strategy="no",
logging_steps=50,
        save_strategy="no",
seed=42,
dataloader_num_workers=NUM_WORKERS,
dataloader_pin_memory=True,
dataloader_drop_last=False,
group_by_length=False, # crucial
        # Prefer bf16 when the GPU supports it; fp16 and bf16 must not both be set
        bf16=torch.cuda.is_available() and torch.cuda.is_bf16_supported(),
        fp16=torch.cuda.is_available() and not torch.cuda.is_bf16_supported(),
report_to=[],
)
    # Trainer takes an instantiated model, not a checkpoint name; load the causal LM here
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
    model.config.pad_token_id = tokenizer.pad_token_id
    trainer = BucketedTrainer(
        model=model,
        args=args,
        train_dataset=ds["train"],
        eval_dataset=ds.get("validation"),
        tokenizer=tokenizer,
        data_collator=collator,
    )
    # Sanity check: pull one batch to show the bucketed dataloader works end to end
it = iter(trainer.get_train_dataloader())
first_batch = next(it)
print(
"Batch shapes:",
{k: tuple(v.shape) for k, v in first_batch.items() if isinstance(v, torch.Tensor)},
)
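    # Because batches come from length buckets and are padded to a multiple of 8, the
    # sequence dimension should track the longest example in each batch rather than
    # always being MAX_LENGTH.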
# Start training
# trainer.train() # uncomment to actually train
if __name__ == "__main__":
main()