Trainer very slow to initialize training when setting group_by_length to True

I observe a very long delay before training actually starts once Trainer.train() is called.
It appears to come from the LengthGroupedSampler that is used when group_by_length is set to True.

Is there a way to use multiple workers to accelerate this process?
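
For context, the setup is roughly like this (a minimal sketch; the model and dataset here are placeholders, not the actual ones):

from datasets import load_dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)

model_name = "bert-base-uncased"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)

ds = load_dataset("imdb", split="train")  # placeholder
ds = ds.map(lambda b: tokenizer(b["text"], truncation=True), batched=True)

args = TrainingArguments(
    output_dir="out",
    per_device_train_batch_size=16,
    group_by_length=True,  # startup becomes very slow with this enabled
)

trainer = Trainer(model=model, args=args, train_dataset=ds, tokenizer=tokenizer)
trainer.train()  # long pause here before the first logged step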


Hey, did you manage to solve this?


Did anyone solve this issue?


I think it’s possible to mitigate this to some extent by pre-processing the dataset (tokenizing and computing lengths ahead of time), but it feels like half of it is just how the Trainer is designed… See the sketch and the longer script below.
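
If you stay on the built-in group_by_length path, the main cost seems to be that LengthGroupedSampler measures every example at startup when no length column is available. As far as I can tell, if the training set already has a column matching TrainingArguments.length_column_name (default "length"), the Trainer hands those values to the sampler and skips that pass, and datasets.map(num_proc=...) lets you compute them on multiple workers and cache the result. A minimal sketch (model and dataset are placeholders):

import os

from datasets import load_dataset
from transformers import AutoTokenizer, TrainingArguments

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # placeholder
ds = load_dataset("imdb", split="train")                        # placeholder

def tok_and_len(batch):
    enc = tokenizer(batch["text"], truncation=True)
    # Precompute per-example lengths so the sampler doesn't have to.
    enc["length"] = [len(ids) for ids in enc["input_ids"]]
    return enc

# num_proc parallelizes tokenization + length computation across CPU workers,
# and the result is cached on disk, so later runs skip it entirely.
ds = ds.map(tok_and_len, batched=True, num_proc=max(1, (os.cpu_count() or 2) // 2))

args = TrainingArguments(
    output_dir="out",
    group_by_length=True,
    length_column_name="length",  # default value, shown here for clarity
)

The longer script below goes further and drops group_by_length entirely, doing the length bucketing in a custom batch sampler instead.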

from __future__ import annotations

import math
import os
from typing import Iterator, List, Optional, Sequence

import torch
from torch.utils.data import BatchSampler

from datasets import Dataset, DatasetDict, load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

# ---- Config (edit as needed) ----
MODEL_NAME = "gpt2"
TEXT_COLUMN = "text"
DATASET = ("wikitext", "wikitext-2-raw-v1")  # (name, subset) or local files
CACHE_DIR = os.path.join(os.getcwd(), ".cache_ds")
MAX_LENGTH = 1024
BATCH_SIZE_PER_DEVICE = 4
NUM_WORKERS = max(1, os.cpu_count() // 2)


# ---- Utilities ----
def add_length_column(ds: Dataset) -> Dataset:
    def _len_fn(batch):
        # Prefer attention_mask so padding is excluded even when the pad token
        # doubles as eos (as with GPT-2); fall back to the raw sequence length
        # for unpadded tokenization.
        if "attention_mask" in batch:
            return {"length": [int(sum(m)) for m in batch["attention_mask"]]}
        return {"length": [len(ids) for ids in batch["input_ids"]]}

    # Map once and cache; batched for speed; num_proc for parallelism
    return ds.map(_len_fn, batched=True, num_proc=max(1, os.cpu_count() // 2))


def build_buckets_from_quantiles(lengths: Sequence[int], num_buckets: int = 20) -> List[int]:
    # Why: stable bucket boundaries that avoid global sort at runtime
    if num_buckets < 2:
        return [max(lengths)]
    qs = [i / num_buckets for i in range(1, num_buckets)]
    sorted_lens = sorted(lengths)
    edges = [sorted_lens[math.floor(q * (len(sorted_lens) - 1))] for q in qs]
    return edges + [max(sorted_lens)]


def bucket_index(length: int, edges: Sequence[int]) -> int:
    lo, hi = 0, len(edges) - 1
    while lo < hi:
        mid = (lo + hi) // 2
        if length <= edges[mid]:
            hi = mid
        else:
            lo = mid + 1
    return lo


# ---- Samplers ----
class BucketBatchSampler(BatchSampler):
    """
    Epoch-wise shuffling inside buckets only; no global sort.
    DDP-aware if `process_rank`/`num_processes` are set (BucketedTrainer below
    passes them from TrainingArguments).
    """
    def __init__(
        self,
        lengths: Sequence[int],
        batch_size: int,
        bucket_edges: Sequence[int],
        drop_last: bool = False,
        process_rank: int = 0,
        num_processes: int = 1,
        generator: Optional[torch.Generator] = None,
    ):
        self.lengths = lengths
        self.batch_size = batch_size
        self.bucket_edges = list(bucket_edges)
        self.drop_last = drop_last
        self.process_rank = process_rank
        self.num_processes = max(1, num_processes)
        self.generator = generator

        # Pre-partition indices into buckets once
        self._buckets: List[List[int]] = [[] for _ in range(len(self.bucket_edges))]
        for idx, L in enumerate(self.lengths):
            b = bucket_index(L, self.bucket_edges)
            self._buckets[b].append(idx)

    def __iter__(self) -> Iterator[List[int]]:
        g = self.generator
        if g is None:
            # Make per-epoch randomness deterministic under Trainer seed
            g = torch.Generator()
            g.manual_seed(torch.randint(0, 2**31 - 1, (1,)).item())

        # Shuffle within each bucket and then interleave buckets
        local_batches: List[List[int]] = []
        for bucket in self._buckets:
            # Per-bucket shuffle
            bucket_indices = bucket[:]  # copy
            if bucket_indices:
                perm = torch.randperm(len(bucket_indices), generator=g).tolist()
                bucket_indices = [bucket_indices[i] for i in perm]

            # Make batches
            for i in range(0, len(bucket_indices), self.batch_size):
                batch = bucket_indices[i : i + self.batch_size]
                if len(batch) == self.batch_size or (not self.drop_last and len(batch) > 0):
                    local_batches.append(batch)

        # Shuffle batches themselves for extra randomness
        if local_batches:
            perm = torch.randperm(len(local_batches), generator=g).tolist()
            local_batches = [local_batches[i] for i in perm]

        # DDP sharding across processes at batch granularity.
        # Note: with uneven totals, ranks can end up with different batch counts,
        # which may stall DDP; pad or drop trailing batches if that happens.
        if self.num_processes > 1:
            local_batches = local_batches[self.process_rank :: self.num_processes]

        for b in local_batches:
            yield b

    def __len__(self) -> int:
        total = 0
        for bucket in self._buckets:
            n = len(bucket) // self.batch_size
            if not self.drop_last and (len(bucket) % self.batch_size):
                n += 1
            total += n
        if self.num_processes > 1:
            # Round up to include uneven splits
            return (total + self.num_processes - 1) // self.num_processes
        return total


# ---- Trainer subclass to use custom batch sampler ----
class BucketedTrainer(Trainer):
    def get_train_dataloader(self):
        if self.train_dataset is None:
            return super().get_train_dataloader()

        # Read the precomputed lengths, then drop the helper column so it
        # never reaches the model's forward() as an unexpected kwarg.
        lengths = [int(l) for l in self.train_dataset["length"]]
        train_dataset = self.train_dataset.remove_columns(["length"])

        # Ensure torch-format fast path
        if hasattr(train_dataset, "with_format"):
            train_dataset = train_dataset.with_format("torch")

        # Derive bucket edges once
        edges = build_buckets_from_quantiles(lengths, num_buckets=24)

        # Resolve DDP parameters from Trainer/Accelerate
        process_rank = self.args.process_index
        num_processes = self.args.world_size

        batch_sampler = BucketBatchSampler(
            lengths=lengths,
            batch_size=self.args.per_device_train_batch_size,
            bucket_edges=edges,
            drop_last=self.args.dataloader_drop_last,
            process_rank=process_rank,
            num_processes=num_processes,
            generator=None,
        )

        num_workers = self.args.dataloader_num_workers
        return torch.utils.data.DataLoader(
            train_dataset,
            batch_sampler=batch_sampler,
            collate_fn=self.data_collator,
            num_workers=num_workers,
            pin_memory=self.args.dataloader_pin_memory,
            persistent_workers=num_workers > 0,  # True with 0 workers raises an error
            prefetch_factor=4 if num_workers > 0 else None,
        )


# ---- Example: wire everything together ----
def main():
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Load dataset
    ds: DatasetDict
    if isinstance(DATASET, tuple):
        ds = load_dataset(DATASET[0], DATASET[1])
    else:
        ds = load_dataset("text", data_files=DATASET)

    # Tokenize
    def tok_fn(examples):
        return tokenizer(
            examples[TEXT_COLUMN],
            truncation=True,
            max_length=MAX_LENGTH,
            # Let DataCollator pad dynamically; faster + less memory
            padding=False,
        )

    ds = ds.map(tok_fn, batched=True, num_proc=max(1, os.cpu_count() // 2), remove_columns=[TEXT_COLUMN])

    # Drop empty rows (wikitext contains blank lines that tokenize to length 0)
    ds = ds.filter(lambda x: len(x["input_ids"]) > 0, num_proc=max(1, os.cpu_count() // 2))

    # Add cached lengths once and save to disk for instant reuse
    ds["train"] = add_length_column(ds["train"])
    if "validation" in ds:
        ds["validation"] = add_length_column(ds["validation"])

    os.makedirs(CACHE_DIR, exist_ok=True)
    ds.save_to_disk(CACHE_DIR)

    # Collator: dynamic padding + causal-LM labels; pad_to_multiple_of=8 aligns for Tensor Cores
    collator = DataCollatorForLanguageModeling(tokenizer, mlm=False, pad_to_multiple_of=8)

    # Training args: turn OFF built-in group_by_length; rely on our sampler
    args = TrainingArguments(
        output_dir="./out",
        per_device_train_batch_size=BATCH_SIZE_PER_DEVICE,
        per_device_eval_batch_size=BATCH_SIZE_PER_DEVICE,
        learning_rate=5e-5,
        num_train_epochs=1,
        evaluation_strategy="no",
        logging_steps=50,
        save_strategy="no",
        seed=42,
        dataloader_num_workers=NUM_WORKERS,
        dataloader_pin_memory=True,
        dataloader_drop_last=False,
        group_by_length=False,  # crucial
        fp16=torch.cuda.is_available() and not torch.cuda.is_bf16_supported(),
        bf16=torch.cuda.is_available() and torch.cuda.is_bf16_supported(),  # fp16/bf16 are mutually exclusive
        report_to=[],
    )

    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)

    trainer = BucketedTrainer(
        model=model,
        args=args,
        train_dataset=ds["train"],
        eval_dataset=ds.get("validation"),
        tokenizer=tokenizer,
        data_collator=collator,
    )

    # Sanity: print quick batch to show it works without long init
    it = iter(trainer.get_train_dataloader())
    first_batch = next(it)
    print(
        "Batch shapes:",
        {k: tuple(v.shape) for k, v in first_batch.items() if isinstance(v, torch.Tensor)},
    )

    # Start training
    # trainer.train()  # uncomment to actually train


if __name__ == "__main__":
    main()

Script generated by TD Ai
