SFT with vLLM Downstream Evaluation: A VRAM-Efficient Pipeline (arm64)
This project implements a Supervised Fine-Tuning (SFT) script designed to maximize training efficiency while enabling high-throughput downstream evaluation.
The Problem: In limited VRAM environments (or when maximizing context length), the training process consumes nearly all available GPU memory. This leaves no room to run generation-based evaluation, often forcing users to rely solely on loss metrics or perform slow, single-batch inference.
The Solution: This script implements a "Swap-and-Eval" strategy:
- Offload: Pauses training and moves the model/optimizer states to system RAM (CPU).
- Serve: Spins up a temporary vLLM server via Docker on the freed GPU.
- Evaluate: Performs high-speed batch inference for accurate downstream metrics.
- Resume: Kills the container, reloads the model to VRAM, and continues training.
!!!!! Attention: This Setup is for arm64.
It can be used to train on a GH200 (Grace Hopper). Adjustments are needed for amd64/x86_64 (for the vLLM Docker command, see the comments). The bitsandbytes installation should also be simpler there; `pip install bitsandbytes` might be enough.
Also optimized for single GPU, never tested on multi GPU setup. At most have to adjust "--tensor-parallel-size", "1" for vllm spin up.
https://cloud.lambda.ai Does have GH200 (96GB VRAM) for $1.49/h
I did not put too much effort into cleaning the code or explaining every part of it. I am happy to answer questions if something is unclear.
Example for x86_64 vllm Docker command
# Example vLLM Docker command for x86_64 (the script itself uses the arm64 variant below).
# Assumes checkpoint_path, container_model_path and max_seq_length are in scope.
docker_cmd = [
    "sudo", "docker", "run", "-d",
    "--ipc", "host",
    "-p", "8000:8000",
    "-v", "/home/ubuntu/.cache/huggingface:/root/.cache/huggingface/",
    "-v", f"{checkpoint_path}:{container_model_path}",  # Mount checkpoint directory
    "--gpus", "all",
    "-e", "HF_HOME=/root/.cache/huggingface/",
    "-e", "TRANSFORMERS_OFFLINE=0",
    "-e", "HF_DATASET_OFFLINE=1",
    "-e", "NCCL_P2P_LEVEL=NVL",
    "-e", "NCCL_SHM_DISABLE=1",
    "-e", "NCCL_SOCKET_IFNAME=eth0",
    #"-e", "CUDA_LAUNCH_BLOCKING=1",
    "-e", "TORCH_CUDA_ARCH_LIST=10.0",
    "-e", "HF_TOKEN=XXX",  # replace with a real token
    "--platform", "linux/amd64",
    "vllm/vllm-openai",
    # Everything from here on is passed to the vLLM entrypoint, not to docker.
    "--model",
    container_model_path,
    "--trust-remote-code",
    "--host", "0.0.0.0",
    "--port", "8000",
    "--max-model-len", f"{max_seq_length}",
    "--tensor-parallel-size", "1",
    "--gpu_memory_utilization", "0.60",
    "--tokenizer", container_model_path,  # Use mounted path for tokenizer
]
Dataset Structure
expecting items like:
{
"id":XX,
"prompt":XX,
"res": XX, // Assistant Response can be adjusted via --res_key res on programm start
}
Single Turn. Obviously the Script can simply be adjusted to match other Dataset Structures or Models.
Dataset preparation
Adjust for non Qwen Models or different Dataset Structure
# Dataset preparation example — adjust for non-Qwen models or a different dataset structure.
def formatting_prompts_func(batch):
    """Render each batch row into the Qwen chat template (system + user, optional assistant)."""
    output_texts = []
    for i in range(len(batch['prompt'])):
        messages = [
            {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."},
            {"role": "user", "content": batch['prompt'][i]},
        ]
        # Only append the assistant turn when a non-empty response exists;
        # the key is configurable via --res_key on program start.
        if res_key in batch and batch[res_key][i] is not None and batch[res_key][i] != '':
            messages.append({"role": "assistant", "content": batch[res_key][i]})
        text = tokenizer.apply_chat_template(messages, tokenize=False)
        output_texts.append(text)
    return output_texts

# Completion-only SFT: loss is computed only on tokens after the assistant marker.
collator = DataCollatorForCompletionOnlyLM('<|im_start|>assistant\n', tokenizer=tokenizer)
Same goes for generate_with_vllm ASSISTANT PREFIX and IM_END have to be adjusted for different Models to properly split.
def generate_with_vllm(self, eval_dataloader, checkpoint_dir):
    """Decode prompts/references from the eval dataloader and batch-query the vLLM server.

    Returns (predictions, references) as parallel lists of strings.
    ASSISTANT_PREFIX / IM_END are Qwen chat-template markers — adjust for other models.
    """
    from openai import OpenAI
    client = OpenAI(base_url="http://localhost:8000/v1", api_key="XXX", max_retries=3)
    tokenizer = self.processing_class  # alias
    ASSISTANT_PREFIX = "<|im_start|>assistant\n"
    IM_END = "<|im_end|>"
    assistant_prefix_ids = tokenizer.encode(ASSISTANT_PREFIX, add_special_tokens=False)

    def find_subseq(seq, subseq):
        # Return index of first occurrence of subseq in seq, or -1 if absent.
        L, M = len(seq), len(subseq)
        for i in range(L - M + 1):
            if seq[i:i + M] == subseq:
                return i
        return -1

    prompts, references = [], []
    for batch in eval_dataloader:
        input_ids = batch["input_ids"]
        labels = batch["labels"]
        attn = batch.get("attention_mask", None)
        B = input_ids.size(0)
        for i in range(B):
            ids = input_ids[i].tolist()
            if attn is not None:
                ids = ids[:int(attn[i].sum().item())]  # trim pads
            # prompt = everything up to and including "<|im_start|>assistant\n"
            j = find_subseq(ids, assistant_prefix_ids)
            assert j != -1, "assistant prefix not found in eval row"
            prompt_ids = ids[: j + len(assistant_prefix_ids)]
            prompt_text = tokenizer.decode(prompt_ids, skip_special_tokens=False)
            prompts.append(prompt_text)
            # reference = only label tokens (non -100, i.e. the unmasked completion)
            label_ids = [tid for tid, lab in zip(input_ids[i].tolist(), labels[i].tolist()) if lab != -100]
            ref_text = tokenizer.decode(label_ids, skip_special_tokens=False)
            references.append(ref_text)
    # Run vLLM (batch the prompts; stop at <|im_end|>)
    predictions = []
    batch_size = 5000
    for s in range(0, len(prompts), batch_size):
        batch = prompts[s:s + batch_size]
        try:
            resp = client.completions.create(
                model="/model",  # must match the container mount path
                prompt=batch,
                max_tokens=8000,  # Adjust if necessary
                temperature=0.0,  # greedy decoding for reproducible eval
                stop=[IM_END],
                timeout=18000,
            )
            predictions.extend([c.text for c in resp.choices])
        except Exception as e:
            # Best effort: keep lists aligned by padding with empty predictions.
            print(f"VLLM error at batch {s}: {e}")
            predictions.extend([""] * len(batch))
    return predictions, references
requirements.txt
torch
transformers[torch]
datasets
accelerate
hf_transfer
huggingface_hub
tqdm
trl==0.14.0
openai
prepare
python3 -m venv venv; source venv/bin/activate; pip install -r requirements.txt; pip uninstall torch; pip install torch --index-url https://download.pytorch.org/whl/cu126;
pip install bitsandbytes; git clone https://github.com/bitsandbytes-foundation/bitsandbytes.git; cd bitsandbytes; git checkout 7aec4a88465440b6466a526fe9bbb30930a04ba4; python setup.py install ;cmake -DCOMPUTE_BACKEND=cuda -S .; make CUDA_VERSION=126; cp /home/ubuntu/bitsandbytes/bitsandbytes/libbitsandbytes_cuda128.so /home/ubuntu/venv/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda126.so
# NOTE: THE LAST COMMAND CAN DIFFER !
# Check the Locations, when in doubt try running sft.py -> will get message some_path/xxx.so is missing
# Now you know where to place the file; check the bitsandbytes installation and copy the cudaXX.so to the desired location
e.g:
cp /home/ubuntu/bitsandbytes/bitsandbytes/libbitsandbytes_cuda128.so /home/ubuntu/venv/lib/python3.10/site-packages/UNKNOWN-0.49.1.dev0-py3.10-linux-aarch64.egg/bitsandbytes/libbitsandbytes_cuda126.so
OR
cp /home/ubuntu/bitsandbytes/bitsandbytes/libbitsandbytes_cuda128.so /home/ubuntu/venv/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda126.so
eval
train_and_evaluate(args.model, args.dataset,eval_positions.eval_timestamps , args.res_key)
Eval Function has to be passed which expects
def eval_timestamps(predictions: List[str], references: List[str]) -> dict:
return {"final":X}
If your eval function returns a different metric name, update metric_for_best_model in TrainingArguments so that they match.
usage
python sft.py --model Qwen/Qwen3-4B-Instruct-2507 --dataset local_path/or_repo_id --res_key res
sft.py (Full Example)
import argparse
import gc
import json
import logging
import os
import subprocess
import time
import numpy as np
import openai
import torch
from datasets import Dataset, DatasetDict, load_dataset
from huggingface_hub import create_repo
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments,AutoConfig
from trl import DataCollatorForCompletionOnlyLM, SFTTrainer, SFTConfig
from datetime import datetime
import eval_positions as eval_positions
os.environ["TOKENIZERS_PARALLELISM"] = "false"  # avoid tokenizers fork warnings in dataloader workers
logging.basicConfig(level=logging.INFO)
# Placeholders — the script raises NameError until these are filled in.
os.environ['HF_TOKEN'] = REPLACE_ME # Adjust
repo_owner = REPLACE_ME # Adjust
max_seq_length = REPLACE_ME  # training max_seq_length; also passed to vLLM as --max-model-len
# Custom SFTTrainer with VLLM evaluation
class SFTTrainerWithVLLMEval(SFTTrainer):
    """SFTTrainer that runs generation-based evaluation through a temporary vLLM server.

    Each evaluate() call offloads the model and optimizer state to CPU RAM,
    serves the current checkpoint from a vLLM Docker container on the freed
    GPU, scores the batch generations with ``score_func``, then restores the
    training state and resumes ("Swap-and-Eval").
    """

    def __init__(self, *args, score_func=None, **kwargs):
        super().__init__(*args, **kwargs)
        # callable(predictions: List[str], references: List[str]) -> dict of metrics
        self.score_func = score_func
        if score_func is None:
            raise ValueError("score_func must be provided for evaluation.")

    @staticmethod
    def _state_to_cpu(obj):
        """Recursively copy every tensor in a (nested) state dict to CPU.

        ``state_dict()`` alone keeps references to the live CUDA tensors, so
        deleting the optimizer would NOT release VRAM — the references held by
        the snapshot keep the memory alive through ``torch.cuda.empty_cache()``.
        """
        if torch.is_tensor(obj):
            return obj.detach().to('cpu')
        if isinstance(obj, dict):
            return {k: SFTTrainerWithVLLMEval._state_to_cpu(v) for k, v in obj.items()}
        if isinstance(obj, (list, tuple)):
            return type(obj)(SFTTrainerWithVLLMEval._state_to_cpu(v) for v in obj)
        return obj

    def evaluate(self, eval_dataset=None, ignore_keys=None, metric_key_prefix="eval"):
        """Swap-and-eval: offload training state, serve checkpoint via vLLM, score, resume."""
        eval_dataloader = self.get_eval_dataloader(eval_dataset)
        # Step 1: Free GPU memory. Snapshot optimizer/scheduler state onto the
        # CPU first so the CUDA tensors can actually be released (see _state_to_cpu).
        optimizer_state = self._state_to_cpu(self.optimizer.state_dict()) if self.optimizer else None
        lr_scheduler_state = self.lr_scheduler.state_dict() if self.lr_scheduler else None
        self.model.to('cpu')
        torch.cuda.empty_cache()
        if hasattr(self, "optimizer"):
            del self.optimizer
            self.optimizer = None
        if hasattr(self, "lr_scheduler"):
            del self.lr_scheduler
            self.lr_scheduler = None
        gc.collect()
        time.sleep(20)  # give the CUDA driver a moment to actually release memory
        #self._save(self.args.output_dir,self.state)
        # Step 2: Save current checkpoint temporarily (mounted into the container)
        checkpoint_dir = os.path.join(f"temp_checkpoint_{self.state.global_step or 0}")
        self.model.save_pretrained(checkpoint_dir)
        self.processing_class.save_pretrained(checkpoint_dir)
        # Step 3: Start VLLM server with the checkpoint and perform evaluation
        process = self.start_vllm_server(checkpoint_dir)
        try:
            predictions, references = self.generate_with_vllm(eval_dataloader, checkpoint_dir)
        finally:
            self.stop_vllm_server(process)
        print('sleep 30')  # message now matches the actual sleep below
        time.sleep(30)
        print('load model to gpu')
        # Step 4: Restore training state.
        # maybe have to reload from checkpoint instead
        self.model.to('cuda')
        self.create_optimizer_and_scheduler(num_training_steps=self.state.max_steps)
        if optimizer_state:
            # load_state_dict casts the CPU copies back to the parameters' device.
            self.optimizer.load_state_dict(optimizer_state)
        if lr_scheduler_state:
            self.lr_scheduler.load_state_dict(lr_scheduler_state)
        torch.cuda.empty_cache()
        gc.collect()
        # Step 5: Compute metrics
        print(f'{datetime.now().strftime("%H:%M:%S")}: calc score')
        field_averages = self.score_func(predictions, references)
        print(f'{datetime.now().strftime("%H:%M:%S")}: done calc score')
        # Save raw eval results for later inspection
        output_dir = os.path.join(self.args.output_dir, 'eval_steps')
        os.makedirs(output_dir, exist_ok=True)
        with open(os.path.join(output_dir, f'eval_{self.state.global_step or 0}.json'), "w", encoding='utf-8') as f:
            eval_step = {
                "predictions": predictions,
                "reference": references,
            }
            json.dump(eval_step, f, indent=4)
        # Prefix metric names so metric_for_best_model="eval_<name>" matches.
        updated_field_averages = {f"eval_{key}": value for key, value in field_averages.items()}
        self.log(updated_field_averages)
        print(updated_field_averages)
        return updated_field_averages

    def start_vllm_server(self, checkpoint_dir):
        """Start the VLLM server using Docker with the checkpoint path mounted.

        Returns the Docker container id. Raises RuntimeError if the container
        fails to start or the OpenAI-compatible API does not become reachable
        within the timeout (the container is stopped in that case, not leaked).
        """
        # Absolute path for Docker volume mounting
        checkpoint_path = os.path.abspath(checkpoint_dir)
        container_model_path = "/model"  # Path inside the container where checkpoint is mounted
        # Docker run command with volume mapping for checkpoint
        docker_cmd = [
            "sudo", "docker", "run", "-d",
            "--ipc", "host",
            "-p", "8000:8000",
            "-v", "/home/ubuntu/.cache/huggingface:/root/.cache/huggingface/",
            "-v", f"{checkpoint_path}:{container_model_path}",  # Mount checkpoint directory
            "--gpus", "all",
            "-e", "HF_HOME=/root/.cache/huggingface/",
            "-e", "TRANSFORMERS_OFFLINE=0",
            "-e", "HF_DATASET_OFFLINE=1",
            "-e", "NCCL_P2P_LEVEL=NVL",
            "-e", "NCCL_SHM_DISABLE=1",
            "-e", "NCCL_SOCKET_IFNAME=eth0",
            "-e", "CUDA_LAUNCH_BLOCKING=1",
            "-e", "TORCH_USE_CUDA_DSA=1",
            "-e", "HF_TOKEN=XXX",
            #"--platform", "linux/amd64", # Non Arm Platform
            "--platform", "linux/arm64",  # Arm Platform
            #"vllm/vllm-openai:v0.10.2",
            "ghcr.io/lambdalabsml/vllm-builder:latest",  # arm64 build; a different image is needed otherwise
            #"--model", # Syntax on non arm
            container_model_path,  # Use mounted path as model
            "--trust-remote-code",
            "--host", "0.0.0.0",
            "--port", "8000",
            "--max-model-len", f"{max_seq_length}",
            "--tensor-parallel-size", "1",
            "--gpu_memory_utilization", "0.60",
            "--tokenizer", container_model_path,  # Use mounted path for tokenizer
        ]
        # Start the Docker container
        process = subprocess.Popen(docker_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        # Wait briefly and get container ID (docker run -d prints it on stdout)
        time.sleep(5)
        container_id = process.stdout.read().decode().strip()
        if not container_id:
            raise RuntimeError(f"Failed to start VLLM server: {process.stderr.read().decode()}")
        # Wait until the server is ready
        client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="XXX")
        max_attempts = 120
        for _ in range(max_attempts):
            try:
                client.models.list()
                print("VLLM server started successfully.")
                return container_id
            except Exception:
                time.sleep(2)
        # Don't leak the running container when the API never becomes reachable.
        self.stop_vllm_server(container_id)
        raise RuntimeError("VLLM server failed to start within timeout.")

    def stop_vllm_server(self, container_id):
        """Stop and remove the VLLM Docker container (best effort)."""
        try:
            subprocess.run(["sudo","docker", "stop", container_id], check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            subprocess.run(["sudo","docker", "rm", container_id], check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            print("VLLM server stopped and container removed.")
        except subprocess.CalledProcessError as e:
            # Best effort: a failed stop should not abort training.
            print(f"Failed to stop VLLM server: {e}")

    def generate_with_vllm(self, eval_dataloader, checkpoint_dir):
        """Decode prompts/references from the eval dataloader and batch-query the vLLM server.

        Returns (predictions, references) as parallel lists of strings.
        ASSISTANT_PREFIX / IM_END are Qwen chat-template markers — adjust for other models.
        """
        from openai import OpenAI
        client = OpenAI(base_url="http://localhost:8000/v1", api_key="XXX", max_retries=3)
        tokenizer = self.processing_class  # alias
        ASSISTANT_PREFIX = "<|im_start|>assistant\n"
        IM_END = "<|im_end|>"
        assistant_prefix_ids = tokenizer.encode(ASSISTANT_PREFIX, add_special_tokens=False)

        def find_subseq(seq, subseq):
            # Return index of first occurrence of subseq in seq, or -1 if absent.
            L, M = len(seq), len(subseq)
            for i in range(L - M + 1):
                if seq[i:i + M] == subseq:
                    return i
            return -1

        prompts, references = [], []
        for batch in eval_dataloader:
            input_ids = batch["input_ids"]
            labels = batch["labels"]
            attn = batch.get("attention_mask", None)
            B = input_ids.size(0)
            for i in range(B):
                ids = input_ids[i].tolist()
                if attn is not None:
                    ids = ids[:int(attn[i].sum().item())]  # trim pads
                # prompt = everything up to and including "<|im_start|>assistant\n"
                j = find_subseq(ids, assistant_prefix_ids)
                assert j != -1, "assistant prefix not found in eval row"
                prompt_ids = ids[: j + len(assistant_prefix_ids)]
                prompt_text = tokenizer.decode(prompt_ids, skip_special_tokens=False)
                prompts.append(prompt_text)
                # reference = only label tokens (non -100, i.e. the unmasked completion)
                label_ids = [tid for tid, lab in zip(input_ids[i].tolist(), labels[i].tolist()) if lab != -100]
                ref_text = tokenizer.decode(label_ids, skip_special_tokens=False)
                references.append(ref_text)
        # Run vLLM (batch the prompts; stop at <|im_end|>)
        predictions = []
        batch_size = 5000
        for s in range(0, len(prompts), batch_size):
            batch = prompts[s:s + batch_size]
            try:
                resp = client.completions.create(
                    model="/model",  # must match the container mount path
                    prompt=batch,
                    max_tokens=8000,  # Adjust if necessary
                    temperature=0.0,  # greedy decoding for reproducible eval
                    stop=[IM_END],
                    timeout=18000,
                )
                predictions.extend([c.text for c in resp.choices])
            except Exception as e:
                # Best effort: keep lists aligned by padding with empty predictions.
                print(f"VLLM error at batch {s}: {e}")
                predictions.extend([""] * len(batch))
        return predictions, references
def adjust_steps_ez(train_dataset_size,effective_batch_size,percentage):
    """Return the step interval equal to `percentage` percent of one epoch (minimum 1).

    One epoch is `train_dataset_size // effective_batch_size` optimizer steps;
    all arithmetic is integer, so tiny datasets floor to an interval of 1.
    """
    whole_epoch = train_dataset_size // effective_batch_size
    fraction = whole_epoch * percentage // 100
    return fraction if fraction > 0 else 1
def train_and_evaluate(model_name, dataset_path, score_func, res_key):
    """Fine-tune `model_name` on `dataset_path` with vLLM-based downstream evaluation.

    Args:
        model_name: Hugging Face model id or local path (Qwen chat template assumed).
        dataset_path: local JSON file path or Hugging Face dataset repo id.
        score_func: callable(predictions, references) -> metrics dict; must
            include the key matched by metric_for_best_model (here "final").
        res_key: dataset field holding the assistant response.
    """
    # Simplify model naming: "<dataset-stem>_<model-basename>"
    new_model_name = f'{dataset_path.split("/")[-1].split(".")[0]}_{model_name.split("/")[-1]}'
    os.makedirs(new_model_name, exist_ok=True)
    logging.info(f'NEW MODEL NAME IS {new_model_name}')
    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        trust_remote_code=True,
    )

    def model_init():
        """Build the model with YaRN rope scaling and gradient checkpointing enabled."""
        # Load the base configuration
        config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
        # Apply YaRN scaling
        # Remove if not needed
        config.rope_scaling = {
            "rope_type": "yarn",
            "factor": 4.0,
            "original_max_position_embeddings": 32768
        }
        # Set the new max position embeddings (32768 * 4)
        config.max_position_embeddings = 131072
        config.use_cache = False  # because of gradient checkpointing
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            config=config,
            torch_dtype="auto",
            trust_remote_code=True,
        ).to('cuda')
        model.gradient_checkpointing_enable()
        model.config.eos_token_id = tokenizer.eos_token_id
        model.config.pad_token_id = tokenizer.pad_token_id
        return model

    def get_dataset(dataset_path):
        """Load a local JSON file or an HF Hub dataset and normalize to train/test splits."""
        seed = 3123
        # Set seed for reproducibility (Numpy)
        np.random.seed(seed)
        # 1. Check if path exists locally
        if os.path.exists(dataset_path):
            print(f"Loading local file: {dataset_path}")
            with open(dataset_path, "r", encoding="utf-8") as f:
                data = json.loads(f.read())
            # Check if the expected keys exist in the JSON
            if isinstance(data, dict) and "train" in data and "test" in data:
                # If pre-split in JSON, convert to Dataset directly
                dataset = DatasetDict({
                    "train": Dataset.from_list(data["train"]),
                    "test": Dataset.from_list(data["test"]),
                })
            else:
                # Flatten and shuffle (in place, seeded above)
                full_pool = data if isinstance(data, list) else list(data.values())
                np.random.shuffle(full_pool)
                # Calculate split index (0.9 for train)
                split_idx = int(len(full_pool) * 0.9)
                train_data = full_pool[:split_idx]
                test_data = full_pool[split_idx:]
                dataset = DatasetDict({
                    "train": Dataset.from_list(train_data),
                    "test": Dataset.from_list(test_data),
                })
        # 2. If not local, assume Hugging Face Repo ID
        else:
            print(f"Loading from Hugging Face Hub: {dataset_path}")
            # Load dataset (downloads config if necessary)
            raw_dataset = load_dataset(dataset_path)  # token from env
            # Check structure and standardize to DatasetDict with train/test
            if isinstance(raw_dataset, DatasetDict):
                if "test" not in raw_dataset:
                    # If there is a validation set but no test, map validation to test
                    # OR split train if only train exists.
                    if "validation" in raw_dataset:
                        dataset = DatasetDict({
                            "train": raw_dataset["train"],
                            "test": raw_dataset["validation"]
                        })
                    elif "train" in raw_dataset:
                        # Split train 90/10 to match your manual logic
                        dataset = raw_dataset["train"].train_test_split(test_size=0.1, seed=seed)
                    else:
                        # Fallback for weird structures
                        dataset = raw_dataset
                else:
                    dataset = raw_dataset
            else:
                # If raw_dataset is a single Dataset object (no splits yet)
                dataset = raw_dataset.train_test_split(test_size=0.1, seed=seed)
        # Final Access
        train_set = dataset["train"]
        test_set = dataset["test"]
        return dataset,train_set,test_set

    dataset,train_set,test_set = get_dataset(dataset_path)

    # Formatting function
    # Has to be adjusted for other Models
    def formatting_prompts_func(batch):
        """Render each batch row into the Qwen chat template (assistant turn only if a response exists)."""
        output_texts = []
        for i in range(len(batch['prompt'])):
            messages = [
                {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."},
                {"role": "user", "content": batch['prompt'][i]},
            ]
            if res_key in batch and batch[res_key][i] is not None and batch[res_key][i] != '':
                messages.append({"role": "assistant", "content": batch[res_key][i]})
            text = tokenizer.apply_chat_template(messages, tokenize=False)
            output_texts.append(text)
        return output_texts

    # Completion-only SFT: loss only on tokens after the assistant marker.
    collator = DataCollatorForCompletionOnlyLM('<|im_start|>assistant\n', tokenizer=tokenizer)
    bs = 1
    ga = 4  # 16
    epochs = 1
    eval_percentage = 10  # run eval/save every 10% of an epoch
    # Static training arguments
    train_args = TrainingArguments(
        output_dir=new_model_name,
        learning_rate=5e-5,
        num_train_epochs=epochs,
        per_device_train_batch_size=bs,
        gradient_accumulation_steps=ga,
        per_device_eval_batch_size=16,
        eval_steps=adjust_steps_ez(len(train_set),bs*ga,eval_percentage),
        save_steps=adjust_steps_ez(len(train_set),bs*ga,eval_percentage),
        logging_steps=5,
        warmup_ratio=0.1,
        optim="adamw_8bit",  # requires the bitsandbytes build described above
        eval_strategy='steps',
        save_strategy='steps',
        log_level="debug",
        lr_scheduler_type="cosine",
        hub_strategy="all_checkpoints",
        metric_for_best_model="eval_final",  # must match a key returned by score_func (prefixed "eval_")
        greater_is_better=True,
        load_best_model_at_end=True,
        save_total_limit=None,
        gradient_checkpointing=True,
        bf16=True,
    )
    # Initialize trainer (TrainingArguments dict forwarded into SFTConfig)
    trainer = SFTTrainerWithVLLMEval(
        model=model_init(),
        args=SFTConfig(**train_args.to_dict(), max_seq_length=max_seq_length),
        train_dataset=train_set,
        eval_dataset=test_set,
        processing_class=tokenizer,
        formatting_func=formatting_prompts_func,
        data_collator=collator,
        score_func=score_func,
    )
    trainer.train()
    # Push to repo
    # NOTE(review): this only creates the repo; the actual upload relies on the
    # Trainer's hub settings — confirm push_to_hub is enabled if a push is expected.
    repo_name = f'{repo_owner}/{new_model_name}'
    try:
        create_repo(repo_name, private=True)
    except Exception as e:
        logging.info(f"Repo push failed: {e}")
    # Cleanup
    del trainer
    del dataset
    del train_set
    del test_set
    torch.cuda.empty_cache()
    gc.collect()
def main():
    """CLI entry point: report CUDA status, parse arguments, and start training."""
    has_cuda = torch.cuda.is_available()
    print("PyTorch version:", torch.__version__)
    print("CUDA available:", has_cuda)
    print("CUDA version:", torch.version.cuda if has_cuda else "No CUDA")
    print("Device name:", torch.cuda.get_device_name(0) if has_cuda else "No GPU detected")
    if not has_cuda:
        print("NO CUDA EXIT!")
        exit()
    parser = argparse.ArgumentParser(description="Train a Qwen Model on the Markdown set")
    for flag, help_text in (
        ("--model", "Model"),
        ("--dataset", "Dataset"),
        ("--res_key", "Result key to extract prediction"),
    ):
        parser.add_argument(flag, type=str, required=True, help=help_text)
    args = parser.parse_args()
    train_and_evaluate(args.model, args.dataset, eval_positions.eval_timestamps, args.res_key)


if __name__ == "__main__":
    main()