Hardware:
CPU: Xeon® E5-2630 v2, with only 16 GB of system RAM (that is what the vast.ai instance provides).
GPU: 4x A40 (48 GB each) → 192 GB total
OS: Linux
Python: 3.10
CUDA: 12.2
Packages:
torch==2.3.1
transformers==4.41.2
peft==0.11.1
datasets==2.20.0
accelerate==0.31.0
evaluate==0.4.1
bitsandbytes==0.43.1
huggingface_hub==0.23.4
trl==0.9.4
Issue
Introduction
Hi!
I’m trying to fine-tune Llama 3 8B on a summarization dataset of about 1,500 instances. The dataset contains long documents, often over 8K tokens, and I want to use FSDP + QLoRA for the fine-tuning. After following this guide I was very hopeful this would be possible on my setup, since I’m fine-tuning the 8B version instead of the 70B version.
I’m following these two guides as inspiration:
bitsandbytes Guide
Phil Schmid Guide
Phil Schmid’s guide mentions the following:
Expected memory usage:
Full fine-tuning with FSDP needs ~16x 80GB GPUs
FSDP + LoRA needs ~8x 80GB GPUs
FSDP + Q-Lora needs ~2x 40GB GPUs
FSDP + Q-Lora + CPU offloading needs 4x 24GB GPUs, with 22 GB/GPU and 127 GB CPU RAM with a sequence length of 3072 and a batch size of 1.
Note: to turn CPU offloading off you need to change the value of fsdp and remove offload. This only works on >40GB GPUs since it requires more memory.
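In my accelerate-style config below, my understanding is that the switch corresponding to “remove offload” is fsdp_offload_params; a minimal sketch of the non-offloading variant (only the relevant key shown, everything else as in my config) would be:

fsdp_config:
  fsdp_offload_params: false  # no CPU offloading; per the note above this needs >40GB GPUs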
Accelerate config setup:
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_cpu_ram_efficient_loading: false  # was true before
  fsdp_forward_prefetch: false
  fsdp_offload_params: true
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_use_orig_params: false
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
Code
import os

import torch
from peft import LoraConfig, get_peft_model
from transformers import (AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig,
                          EarlyStoppingCallback, TrainingArguments)
from trl import SFTTrainer

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_storage=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
    'meta-llama/Meta-Llama-3-8B',
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
    use_cache=False,
)
tokenizer = AutoTokenizer.from_pretrained('meta-llama/Meta-Llama-3-8B')
tokenizer.pad_token = tokenizer.eos_token
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.1,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    task_type='CAUSAL_LM',
    bias='none',
)
model = get_peft_model(model, lora_config)
training_args = TrainingArguments(
    output_dir=os.path.join('results', model_id, 'output'),
    num_train_epochs=40,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=1,  # integer number of accumulation steps
    warmup_ratio=args.warmup_ratio,
    weight_decay=args.weight_decay,
    logging_dir=os.path.join('results', model_id, 'logs'),
    remove_unused_columns=False,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",  # metric name used to pick the best checkpoint
    save_strategy="epoch",
    save_total_limit=2,
    evaluation_strategy="epoch",
    label_names=["labels"],
    report_to="wandb",
    logging_strategy="epoch",
    run_name=model_id,
    eval_accumulation_steps=1,
    hub_model_id=f"{model_id}",
    gradient_checkpointing=True,
    fp16=args.fp16,
    bf16=args.bf16,
    ddp_find_unused_parameters=True,
    gradient_checkpointing_kwargs={'use_reentrant': False},
)
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    max_seq_length=context_length_abstractive_model,  # 8192
    callbacks=[EarlyStoppingCallback(early_stopping_patience=args.early_stopping_patience)],
    peft_config=lora_config,
    packing=True,
)
trainer.train()
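For reference, a couple of sanity checks I run right before trainer.train() (just a sketch using the objects defined above):

# Confirm that only the LoRA adapters are trainable; print_trainable_parameters()
# is provided by the PEFT-wrapped model.
model.print_trainable_parameters()

# Rough in-memory size of the quantized base model plus adapters.
print(f"model footprint: {model.get_memory_footprint() / 1024**3:.2f} GiB")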
Start training
accelerate launch training.py --bf16
Errors:
First I followed the guides exactly and set fsdp_cpu_ram_efficient_loading to true. But when I do this, the OS sometimes kills the process with SIGKILL(9):

This makes sense, as Phil Schmid also recommends a pretty hefty amount of CPU memory: 127 GB of CPU RAM with a sequence length of 3072 and a batch size of 1.
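To check whether the SIGKILL(9) really comes from the host running out of its 16 GB of RAM, a small monitor like the sketch below can be run in a separate terminal while the model loads (this assumes psutil is installed; it is not part of my training script):

import time

import psutil

# Print the available host RAM once per second; if it drops towards zero right
# before the SIGKILL, the kernel OOM killer is the likely culprit.
while True:
    mem = psutil.virtual_memory()
    print(f"available RAM: {mem.available / 1024**3:.1f} GiB ({mem.percent}% used)")
    time.sleep(1)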
But oddly enough, I can currently run the script with fsdp_cpu_ram_efficient_loading set to either true or false without receiving the SIGKILL(9) error. However, in both situations I get the following OOM error:
[rank1]: Traceback (most recent call last):
[rank1]: File "/workspace/Thesis/training.py", line 703, in <module>
[rank1]: trainer.train()
[rank1]: File "/workspace/Thesis/venv/lib/python3.10/site-packages/trl/trainer/sft_trainer.py", line 440, in train
[rank1]: output = super().train(*args, **kwargs)
[rank1]: File "/workspace/Thesis/venv/lib/python3.10/site-packages/transformers/trainer.py", line 1885, in train
[rank1]: return inner_training_loop(
[rank1]: File "/workspace/Thesis/venv/lib/python3.10/site-packages/transformers/trainer.py", line 2216, in _inner_training_loop
[rank1]: tr_loss_step = self.training_step(model, inputs)
[rank1]: File "/workspace/Thesis/venv/lib/python3.10/site-packages/transformers/trainer.py", line 3250, in training_step
[rank1]: self.accelerator.backward(loss)
[rank1]: File "/workspace/Thesis/venv/lib/python3.10/site-packages/accelerate/accelerator.py", line 2134, in backward
[rank1]: loss.backward(**kwargs)
[rank1]: File "/workspace/Thesis/venv/lib/python3.10/site-packages/torch/_tensor.py", line 525, in backward
[rank1]: torch.autograd.backward(
[rank1]: File "/workspace/Thesis/venv/lib/python3.10/site-packages/torch/autograd/__init__.py", line 267, in backward
[rank1]: _engine_run_backward(
[rank1]: File "/workspace/Thesis/venv/lib/python3.10/site-packages/torch/autograd/graph.py", line 744, in _engine_run_backward
[rank1]: return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
[rank1]: torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 32.13 GiB. GPU has a total capacity of 44.35 GiB of which 20.85 GiB is free. Process 787350 has 23.49 GiB memory in use. Of the allocated memory 18.22 GiB is allocated by PyTorch, and 4.84 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
W0617 09:10:40.805000 140644428781376 torch/distributed/elastic/multiprocessing/api.py:851] Sending process 3244 closing signal SIGTERM
As you can see, the model runs out of memory during the backward pass. I find this pretty odd, as I should (probably) have enough GPU memory to accommodate the 8B model with the FSDP + QLoRA setup.
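The error message itself suggests trying the expandable_segments allocator setting; a minimal way to apply that suggestion would be at the very top of training.py, before torch is imported (or exported in the shell before accelerate launch):

import os

# Suggested by the OOM message above to reduce fragmentation; it has to be set
# before CUDA is initialised, hence before importing torch.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")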
Possible limitations
The CPU has too little RAM. Offloading isn’t possible because we only have 16 GB of CPU RAM. But following Phil Schmid’s guide, not offloading to the CPU should still suffice, since we use 4 A40s. This is even more odd when you consider that I’m using the 8B version instead of the 70B version used in both guides.
Not using Flash Attention 2 could also be an issue, but as shown in Phil Schmid’s guide, SDPA can also be used.
The sequence length is too long, causing the OOM. I tried setting max_seq_length to 512, but this had no impact on the OOM issue (see the sketch after this list for how I check the document lengths).
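To double-check the sequence-length theory, the sketch below is roughly how I inspect the token lengths of the training documents (I’m assuming the raw text lives in a "text" column; adjust the field name to whatever the dataset actually uses). As far as I understand, with packing=True the SFTTrainer packs examples into fixed blocks of max_seq_length anyway.

# Tokenize every training document without truncation and summarise the lengths.
lengths = [
    len(tokenizer(example["text"], truncation=False)["input_ids"])
    for example in dataset["train"]
]
print(f"max length: {max(lengths)} tokens")
print(f"documents over 8192 tokens: {sum(l > 8192 for l in lengths)} / {len(lengths)}")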
Caveat
When I first dove into the rabbit hole of FSDP and QLoRA, I started out simple and just used the following code:
quantization_config = BitsAndBytesConfig(
    load_in_8bit=True,
)
model = AutoModelForCausalLM.from_pretrained(
    'meta-llama/Meta-Llama-3-8B',
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
    device_map='auto',
    use_cache=False if args.gradient_checkpointing else True,
)
I launched the code with:
python3 training.py
This didn’t result in an OOM error and I was able to train for 100 steps. However, this took quite long and would become too expensive for me, as the training would probably last over 200 hours… I could see that GPU memory was utilized pretty well: all GPUs were used up to about 40 GB. Because this took so long, I wanted to switch to QLoRA, but I couldn’t use QLoRA and device_map='auto' together. That’s why I resorted to FSDP in combination with QLoRA.
I don’t really know why using QLoRA in combination with FSDP would then result in an OOM again, which makes me even more confused.
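To compare the two setups more concretely, a small helper along these lines could log peak GPU memory after a few steps (a sketch: under device_map='auto' one process sees all four GPUs, while under accelerate/FSDP each rank only reports its own device):

import torch

def report_peak_gpu_memory():
    # Peak memory allocated by PyTorch on each visible GPU since the start of
    # the process (reserved memory can be higher, as the OOM message shows).
    for i in range(torch.cuda.device_count()):
        peak = torch.cuda.max_memory_allocated(i) / 1024**3
        print(f"cuda:{i} peak allocated: {peak:.2f} GiB")

report_peak_gpu_memory()  # e.g. call after trainer.train() or after a few steps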
If you have any ideas, please let me know as I’m getting a bit frustrated after being stuck on this for a few days!