nroggendorff committed on
Commit e1f7a3c · verified · 1 Parent(s): e6ad1f5

Update train.py

Files changed (1)
  1. train.py +209 -107
train.py CHANGED
@@ -1,121 +1,206 @@
  import torch
  from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig
  import datasets
  from datasets import Dataset
- from typing import cast
- import os
- import shutil
- import multiprocessing as mp
  from PIL import Image


- def load_model(model_name, device_id=0):
      bnb_config = BitsAndBytesConfig(
          load_in_4bit=True,
          bnb_4bit_compute_dtype=torch.bfloat16,
          bnb_4bit_quant_type="nf4",
-         bnb_4bit_use_double_quant=False,
      )

      processor = AutoProcessor.from_pretrained(model_name)
-     processor.tokenizer.padding_side = "left"

      model = AutoModelForImageTextToText.from_pretrained(
          model_name,
          quantization_config=bnb_config,
          dtype=torch.bfloat16,
-         device_map={"": device_id},
          attn_implementation="flash_attention_2",
      )

      return processor, model


- def caption_batch(batch, processor, model):
-     images = batch["image"]

      pil_images = []
-     for image in images:
-         if isinstance(image, Image.Image):
-             if image.mode != "RGB":
-                 image = image.convert("RGB")
-             pil_images.append(image)

-     msg = [
-         {
-             "role": "user",
-             "content": [
-                 {"type": "image"},
-                 {
-                     "type": "text",
-                     "text": "Describe the image concisely, and skip mentioning that it's illustrated or from anime.",
-                 },
-             ],
-         }
-     ]
-
-     text = processor.apply_chat_template(
-         msg, add_generation_prompt=True, tokenize=False
-     )
-     texts = [text] * len(pil_images)

      inputs = processor(text=texts, images=pil_images, return_tensors="pt", padding=True)

-     inputs = {k: v.to(model.device) for k, v in inputs.items()}

-     with torch.no_grad(), torch.amp.autocast('cuda', dtype=torch.bfloat16):
          generated = model.generate(
              **inputs,
              max_new_tokens=128,
              do_sample=False,
          )

-     decoded = processor.batch_decode(generated, skip_special_tokens=False)
-
-     captions = []
-     special_tokens = set(processor.tokenizer.all_special_tokens)
-     for d in decoded:
-         if "<|im_start|>assistant" in d:
-             d = d.split("<|im_start|>assistant")[-1]

-         for token in special_tokens:
-             d = d.replace(token, "")

-         d = d.strip()
-         captions.append(d)

-     return {
-         "text": captions,
-     }
-
-
- def process_shard(gpu_id, start, end, model_name, batch_size, input_dataset, output_file):
      try:
-         torch.cuda.set_device(gpu_id)
-
          print(f"[GPU {gpu_id}] Loading model...", flush=True)
-         processor, model = load_model(model_name, gpu_id)

-         print(f"[GPU {gpu_id}] Loading data shard [{start}:{end}]...", flush=True)
-         loaded = datasets.load_dataset(input_dataset, split=f"train[{start}:{end}]")

-         if isinstance(loaded, datasets.DatasetDict):
-             shard = cast(Dataset, loaded["train"])
-         else:
-             shard = cast(Dataset, loaded)

-         print(f"[GPU {gpu_id}] Processing {len(shard)} examples...", flush=True)
          result = shard.map(
-             lambda batch: caption_batch(batch, processor, model),
              batched=True,
-             batch_size=batch_size,
-             remove_columns=[col for col in shard.column_names if col != "image"],
          )

-         print(f"[GPU {gpu_id}] Saving results to {output_file}...", flush=True)
          result.save_to_disk(output_file)

-         print(f"[GPU {gpu_id}] Done!", flush=True)
          return output_file
      except Exception as e:
          print(f"[GPU {gpu_id}] Error: {e}", flush=True)
@@ -123,72 +208,89 @@ def process_shard(gpu_id, start, end, model_name, batch_size, input_dataset, out


  def main():
-     mp.set_start_method('spawn', force=True)
-
-     input_dataset = "none-yet/anime-captions"
-     output_dataset = "nroggendorff/anime-captions"
-     model_name = "datalab-to/chandra"
-     batch_size = 16
-
-     print("Loading dataset info...")
-     loaded = datasets.load_dataset(input_dataset, split="train")

-     if isinstance(loaded, datasets.DatasetDict):
-         ds = cast(Dataset, loaded["train"])
      else:
-         ds = cast(Dataset, loaded)

-     num_gpus = torch.cuda.device_count()
      total_size = len(ds)
      shard_size = total_size // num_gpus

-     print(f"Dataset size: {total_size}")
-     print(f"Using {num_gpus} GPUs")
-     print(f"Shard size: {shard_size}")

      processes = []
-     temp_files = []
-
      for i in range(num_gpus):
          start = i * shard_size
          end = start + shard_size if i < num_gpus - 1 else total_size
-         output_file = f"temp_shard_{i}"
-         temp_files.append(output_file)

          p = mp.Process(
              target=process_shard,
-             args=(i, start, end, model_name, batch_size, input_dataset, output_file),
          )
          p.start()
          processes.append(p)

      for p in processes:
          p.join()
          if p.exitcode != 0:
-             print(f"\nProcess failed with exit code {p.exitcode}", flush=True)
-             print("Terminating all processes...", flush=True)
-             for proc in processes:
-                 if proc.is_alive():
-                     proc.terminate()
-             for proc in processes:
-                 proc.join()
-             raise RuntimeError(f"At least one process failed")
-
-     print("\nAll processes completed. Loading and concatenating results...")
-
-     shards = [cast(Dataset, datasets.load_from_disk(f)) for f in temp_files]
      final_ds = datasets.concatenate_datasets(shards)

-     print(f"Final dataset size: {len(final_ds)}")
-     print("Pushing to hub...")
-     final_ds.push_to_hub(output_dataset, create_pr=False)

-     print("Cleaning up temporary files...")
-     for f in temp_files:
-         if os.path.exists(f):
-             shutil.rmtree(f)

-     print("Done!")


  if __name__ == "__main__":
 
+ # caption_pipeline_fast.py
+ import os
+ import shutil
+ import io
+ import multiprocessing as mp
+ from typing import Tuple, Dict, Any
+
  import torch
  from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig
  import datasets
  from datasets import Dataset
  from PIL import Image

+ # -------------------------
+ # CONFIG
+ # -------------------------
+ INPUT_DATASET = "none-yet/anime-captions"  # original dataset id / path
+ PREPROCESSED_DIR = "preprocessed_ds"  # temporary preprocessed dataset on disk
+ TEMP_SHARD_PREFIX = "temp_shard_"  # per-GPU result dirs
+ OUTPUT_DATASET = "nroggendorff/anime-captions"
+ MODEL_NAME = "datalab-to/chandra"
+ BATCH_SIZE = 32  # try 32 or 64 depending on VRAM
+ PREPROCESS_NUM_PROC = max(1, mp.cpu_count() - 2)
+ DEVICE_BATCH_PREPIN = True  # pin memory before to(device)
+ USE_BETTERTRANSFORMER = True  # try BetterTransformer if installed
+ # -------------------------
+
+ def preprocess_example(example: Dict[str, Any]) -> Dict[str, Any]:
+     """
+     Convert image to RGB bytes and store the prompt string once per example.
+     This is run in main process (once).
+     """
+     img = example["image"]
+     if not isinstance(img, Image.Image):
+         # datasets Image feature may already give PIL or path - handle both
+         try:
+             img = Image.open(io.BytesIO(img))  # if raw bytes
+         except Exception:
+             # fall back to the feature handling
+             img = img.convert("RGB")
+     if img.mode != "RGB":
+         img = img.convert("RGB")
+
+     bio = io.BytesIO()
+     img.save(bio, format="PNG")  # PNG keeps quality and is easy to decode later
+     example["image_bytes"] = bio.getvalue()
+
+     # keep the original image field for compatibility if you want
+     # but we'll use image_bytes in workers
+     return example
+
+
+ def prepare_and_save_dataset(input_dataset: str, processor_chat_prompt: str) -> None:
+     """
+     Loads dataset once, preprocesses images to bytes, writes a
+     new field 'image_bytes' and saves to PREPROCESSED_DIR.
+     """
+     print("[main] Loading dataset for preprocessing...")
+     loaded = datasets.load_dataset(input_dataset, split="train")
+     ds = loaded if not isinstance(loaded, datasets.DatasetDict) else loaded["train"]
+
+     # Remove any columns we don't need (keep image) to save space
+     # But keep other metadata if needed
+     cols_to_remove = [c for c in ds.column_names if c not in ("image",)]
+     if cols_to_remove:
+         ds = ds.remove_columns(cols_to_remove)
+
+     print(f"[main] Preprocessing images to bytes with {PREPROCESS_NUM_PROC} procs...")
+     ds = ds.map(preprocess_example, remove_columns=[], num_proc=PREPROCESS_NUM_PROC)
+
+     # store the constant chat template string in dataset (small redundancy) to avoid recomputing
+     print("[main] Storing prompt string per example (small overhead)...")
+     ds = ds.add_column("prompt", [processor_chat_prompt] * len(ds))
+
+     # save to disk for fast worker access (preprocessed once)
+     if os.path.exists(PREPROCESSED_DIR):
+         shutil.rmtree(PREPROCESSED_DIR)
+     print(f"[main] Saving preprocessed dataset to {PREPROCESSED_DIR} ...")
+     ds.save_to_disk(PREPROCESSED_DIR)
+     print("[main] Preprocessing complete.")
+
+
+ def load_model_for_gpu(model_name: str, gpu_id: int):
+     """
+     Load model + processor on the target GPU with 4-bit config (like your original)
+     """
+     torch.cuda.set_device(gpu_id)

      bnb_config = BitsAndBytesConfig(
          load_in_4bit=True,
          bnb_4bit_compute_dtype=torch.bfloat16,
          bnb_4bit_quant_type="nf4",
+         bnb_4bit_use_double_quant=True,
      )

      processor = AutoProcessor.from_pretrained(model_name)
+     # keep left padding as you had
+     try:
+         processor.tokenizer.padding_side = "left"
+     except Exception:
+         pass

      model = AutoModelForImageTextToText.from_pretrained(
          model_name,
          quantization_config=bnb_config,
          dtype=torch.bfloat16,
+         device_map={"": gpu_id},
          attn_implementation="flash_attention_2",
      )

+     # Try BetterTransformer if available
+     if USE_BETTERTRANSFORMER:
+         try:
+             from optimum.bettertransformer import BetterTransformer
+             model = BetterTransformer.transform(model)
+             print(f"[GPU {gpu_id}] Applied BetterTransformer.")
+         except Exception:
+             # not fatal
+             print(f"[GPU {gpu_id}] BetterTransformer unavailable or failed; continuing.")
+
+     model.eval()
      return processor, model


+ def caption_batch_from_bytes(batch: Dict[str, Any], processor, model) -> Dict[str, Any]:
+     """
+     Given a batch from the preprocessed dataset (contains 'image_bytes' and 'prompt'),
+     reconstruct PIL images, call processor, run generate, decode, and return texts.
+     """
+     image_bytes_list = batch["image_bytes"]
+     prompts = batch["prompt"]
+     assert len(image_bytes_list) == len(prompts)

      pil_images = []
+     for b in image_bytes_list:
+         img = Image.open(io.BytesIO(b))
+         if img.mode != "RGB":
+             img = img.convert("RGB")
+         pil_images.append(img)

+     # processor.apply_chat_template was already run on main, so prompts are ready strings
+     texts = list(prompts)

+     # Build inputs. This step will perform tokenizer + image feature extraction.
      inputs = processor(text=texts, images=pil_images, return_tensors="pt", padding=True)

+     # Pin memory for faster host->device copy if enabled
+     if DEVICE_BATCH_PREPIN:
+         for k, v in inputs.items():
+             if torch.is_tensor(v):
+                 inputs[k] = v.pin_memory()

+     # Move to device with non_blocking transfer (works with pinned memory)
+     device = model.device
+     inputs = {k: (v.to(device, non_blocking=True) if torch.is_tensor(v) else v) for k, v in inputs.items()}
+
+     with torch.no_grad(), torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
          generated = model.generate(
              **inputs,
              max_new_tokens=128,
              do_sample=False,
+             num_beams=1,
          )

+     # decode skipping special tokens to avoid expensive post-processing
+     decoded = processor.batch_decode(generated, skip_special_tokens=True)

+     # clean and return
+     return {"text": [d.strip() for d in decoded]}


+ def process_shard(gpu_id: int, start: int, end: int, output_file: str):
+     """
+     Worker process: loads the preprocessed dataset shard, loads the model on the GPU,
+     runs batched generation and saves the results to disk.
+     """
      try:
          print(f"[GPU {gpu_id}] Loading model...", flush=True)
+         processor, model = load_model_for_gpu(MODEL_NAME, gpu_id)

+         print(f"[GPU {gpu_id}] Loading preprocessed dataset from disk...", flush=True)
+         ds = datasets.load_from_disk(PREPROCESSED_DIR)
+         # slice with select for a true copy
+         indices = list(range(start, end))
+         shard = ds.select(indices)

+         print(f"[GPU {gpu_id}] Processing {len(shard)} examples (shard indices {start}:{end}) ...", flush=True)

+         # map with batched generator function (uses our caption_batch_from_bytes)
          result = shard.map(
+             lambda batch: caption_batch_from_bytes(batch, processor, model),
              batched=True,
+             batch_size=BATCH_SIZE,
+             remove_columns=[col for col in shard.column_names if col not in ("image_bytes", "prompt")],
+             num_proc=1,  # model inference must run in the GPU process (no multiproc here)
          )

+         print(f"[GPU {gpu_id}] Saving results to {output_file} ...", flush=True)
+         if os.path.exists(output_file):
+             shutil.rmtree(output_file)
          result.save_to_disk(output_file)

+         print(f"[GPU {gpu_id}] Done.", flush=True)
          return output_file
      except Exception as e:
          print(f"[GPU {gpu_id}] Error: {e}", flush=True)


  def main():
+     mp.set_start_method("spawn", force=True)

+     # 1) Load processor temporarily to build the chat prompt once
+     print("[main] Loading processor to create chat prompt...")
+     tmp_proc = AutoProcessor.from_pretrained(MODEL_NAME)
+     chat_msg = [
+         {
+             "role": "user",
+             "content": [
+                 {"type": "image"},
+                 {
+                     "type": "text",
+                     "text": "Describe the image concisely, and skip mentioning that it's illustrated or from anime.",
+                 },
+             ],
+         }
+     ]
+     # keep tokenize=False so we store the raw prompt and let processor tokenize in workers with padding semantics
+     prompt_str = tmp_proc.apply_chat_template(chat_msg, add_generation_prompt=True, tokenize=False)
+     del tmp_proc
+
+     # 2) Preprocess dataset once (images -> bytes, add prompt column)
+     if not os.path.exists(PREPROCESSED_DIR):
+         prepare_and_save_dataset(INPUT_DATASET, prompt_str)
      else:
+         print(f"[main] Preprocessed dataset found at {PREPROCESSED_DIR}, skipping preprocess.")

+     # 3) Load the preprocessed dataset to compute shard indices
+     ds = datasets.load_from_disk(PREPROCESSED_DIR)
      total_size = len(ds)
+     num_gpus = torch.cuda.device_count()
+     if num_gpus == 0:
+         raise RuntimeError("No GPUs found. This script requires GPUs.")
      shard_size = total_size // num_gpus

+     print(f"[main] Dataset size: {total_size}")
+     print(f"[main] Using {num_gpus} GPUs (shard size {shard_size})")

+     # 4) Spawn worker processes
      processes = []
+     temp_dirs = []
      for i in range(num_gpus):
          start = i * shard_size
          end = start + shard_size if i < num_gpus - 1 else total_size
+         out_dir = f"{TEMP_SHARD_PREFIX}{i}"
+         temp_dirs.append(out_dir)

          p = mp.Process(
              target=process_shard,
+             args=(i, start, end, out_dir),
+             daemon=False,
          )
          p.start()
          processes.append(p)

+     # 5) wait for processes
      for p in processes:
          p.join()
          if p.exitcode != 0:
+             print(f"[main] Process {p.pid} failed with exit code {p.exitcode}. Terminating others.", flush=True)
+             for q in processes:
+                 if q.is_alive():
+                     q.terminate()
+             for q in processes:
+                 q.join()
+             raise RuntimeError("At least one GPU worker failed.")
+
+     print("[main] All workers finished. Concatenating shards...")
+
+     shards = [datasets.load_from_disk(d) for d in temp_dirs]
      final_ds = datasets.concatenate_datasets(shards)

+     print(f"[main] Final dataset size: {len(final_ds)}. Pushing to hub as {OUTPUT_DATASET} ...")
+     final_ds.push_to_hub(OUTPUT_DATASET, create_pr=False)

+     print("[main] Cleaning up temporary files...")
+     for d in temp_dirs:
+         if os.path.exists(d):
+             shutil.rmtree(d)
+     # optionally keep PREPROCESSED_DIR for re-runs; comment out removal if you want to keep it
+     # shutil.rmtree(PREPROCESSED_DIR)

+     print("[main] Done.")


  if __name__ == "__main__":
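
A note on the DEVICE_BATCH_PREPIN path in caption_batch_from_bytes above: pinning the host tensors is what allows the later non_blocking copy to overlap with GPU work. Here is a minimal standalone sketch of that pattern; the helper name and dummy shapes are illustrative only and are not part of train.py:

    import torch

    def to_device_pinned(tensors: dict, device: torch.device) -> dict:
        # page-lock the host tensors so the host->device copy can run asynchronously
        pinned = {k: (v.pin_memory() if torch.is_tensor(v) else v) for k, v in tensors.items()}
        # non_blocking=True only overlaps with compute when the source tensor is pinned
        return {k: (v.to(device, non_blocking=True) if torch.is_tensor(v) else v)
                for k, v in pinned.items()}

    if torch.cuda.is_available():
        batch = {"input_ids": torch.randint(0, 1000, (4, 32)),
                 "pixel_values": torch.randn(4, 3, 224, 224)}
        batch = to_device_pinned(batch, torch.device("cuda:0"))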
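
Once the push_to_hub call at the end of main() has completed, the result could be spot-checked with something like the following, assuming the default "train" split and the "text" column produced by the map step; this snippet is not part of train.py:

    from datasets import load_dataset

    captions = load_dataset("nroggendorff/anime-captions", split="train")
    print(captions)             # expect columns such as image_bytes, prompt, text
    print(captions[0]["text"])  # first generated caption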