leoeric committed on
Commit 5db3c1f · 1 Parent(s): 0194c79

Add dataset.py and fix OpenMP warning

Files changed (2)
  1. app.py +3 -0
  2. dataset.py +929 -0
app.py CHANGED
@@ -12,6 +12,9 @@ import subprocess
import pathlib
from pathlib import Path

+ # Fix OpenMP warning
+ os.environ['OMP_NUM_THREADS'] = '1'
+
# Try to import huggingface_hub for downloading checkpoints
try:
    from huggingface_hub import hf_hub_download
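
Setting `OMP_NUM_THREADS` only takes effect if it happens before the OpenMP runtime initializes, which is why the assignment sits near the top of app.py, ahead of any torch-backed import. A minimal sketch of the pattern, assuming torch is the library emitting the warning:

    import os

    # Must run before any OpenMP-backed import (torch, MKL-linked numpy, ...);
    # once the runtime has initialized, changing the variable has no effect.
    os.environ['OMP_NUM_THREADS'] = '1'

    import torch  # OpenMP now starts with a single thread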
dataset.py ADDED
@@ -0,0 +1,929 @@
+ #
+ # For licensing see accompanying LICENSE file.
+ # Copyright (C) 2025 Apple Inc. All Rights Reserved.
+ #
+ import io
+ import os
+ import csv
+ import json
+ import random
+ import torch
+ import numpy as np
+ import math
+ import time
+ import contextlib
+ import zlib  # used for a run-stable seed in create_dummy_dataloader
+ from typing import Optional, Union
+ from PIL import Image
+ from collections import defaultdict
+ from torch.utils.data import Dataset, DataLoader
+ from torchvision import transforms
+ from torch.utils.data import default_collate, get_worker_info
+ import tarfile
+ import tqdm
+ import gc
+ import threading
+ import psutil
+ import tempfile
+ import decord
+ from decord import VideoReader
+ import concurrent.futures
+ from concurrent.futures import ThreadPoolExecutor, TimeoutError
+ from misc import print, xprint
+ from misc.condition_utils import get_camera_condition, get_point_condition, get_wind_condition
+
+ # Initialize multiprocessing manager
+ manager = torch.multiprocessing.Manager()
+
+ # ==== helpers ==== #
+
+ @contextlib.contextmanager
+ def ram_temp_file(data, suffix=".mp4"):
+     available_ram = psutil.virtual_memory().available
+     video_size = len(data)
+
+     # Use RAM if available (keeping a 500 MB safety margin), otherwise fall back to disk
+     if video_size < available_ram - (500 * 1024 * 1024):
+         temp_dir = "/dev/shm"  # RAM disk (Linux)
+     else:
+         temp_dir = None  # default system temp (disk)
+
+     with tempfile.NamedTemporaryFile(dir=temp_dir, suffix=suffix, delete=True) as temp_file:
+         temp_file.write(data)
+         temp_file.flush()
+         yield temp_file.name
+
+
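A quick usage sketch for `ram_temp_file` (archive and member names are hypothetical): decord needs a real file path, so the helper materializes in-memory bytes as a short-lived file, preferably on the `/dev/shm` RAM disk:

    # Decode an mp4 stored inside a tar archive via a RAM-backed temp file.
    with tarfile.open("videos.tar") as tar:
        data = tar.extractfile("clips/clip0001.mp4").read()
        with ram_temp_file(data) as path:
            vr = VideoReader(path)  # decord only accepts filesystem paths
            print(len(vr), vr.get_avg_fps())
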
+ def _nearest_multiple(x: float, base: int = 8) -> int:
+     """Round x to the nearest multiple of `base`."""
+     return int(round(x / base)) * base
+
+
+ def aspect_ratio_to_image_size(target_size, R, multiple=8):
+     """Map a target area (target_size**2) and an aspect ratio R to (H, W)."""
+     if R is None:
+         return target_size, target_size
+     if isinstance(R, str):
+         rw, rh = map(int, R.split(':'))
+         R = rw / rh
+     area = target_size ** 2
+     out_h = _nearest_multiple(math.sqrt(area / R), multiple)
+     out_w = _nearest_multiple(math.sqrt(area * R), multiple)
+     return out_h, out_w
+
+
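For concreteness, with `target_size=256` and ratio `"16:9"` the helper keeps the pixel area near 256² while snapping both sides to multiples of 8:

    h, w = aspect_ratio_to_image_size(256, "16:9")
    print(h, w)  # 192 344 -> 192 * 344 = 66048, close to 256**2 = 65536
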
+ def read_tsv(filename):
+     # Open the TSV file for reading
+     with open(filename, 'r', newline='') as tsvfile:
+         reader = csv.reader(tsvfile, delimiter='\t')
+         rows = []
+         while True:
+             try:
+                 r = next(reader)
+                 rows.append(r)
+             except csv.Error as e:
+                 print(f'TSV parse error in {filename}: {e}')  # skip the bad row and keep going
+             except StopIteration:
+                 break
+     return rows
+
+
+ def sample_clip(
+     video_path: str,
+     num_frames: int = 8,
+     out_fps: Optional[float] = None,  # ← pass an fps here
+ ):
+     vr = VideoReader(video_path)
+     src_fps = vr.get_avg_fps()  # native fps
+     total = len(vr)
+
+     if out_fps is None or out_fps >= src_fps:
+         step = 1  # keep the native rate or up-sample later
+     else:
+         target_duration = (num_frames - 1) / out_fps  # duration in seconds
+         frame_span = target_duration * src_fps  # frames needed for this duration
+         step = max(frame_span / (num_frames - 1), 1)
+
+     max_start = total - step * (num_frames - 1)
+     if max_start <= 1:  # video too short for the requested clip
+         indices = np.linspace(0, total - 1, num_frames, dtype=int)
+         return vr.get_batch(indices.tolist()), indices
+
+     max_start = int(np.floor(max_start - 1))
+     start = random.randint(0, max_start) if max_start > 0 else 0
+     idxs = [int(np.round(start + i * step)) for i in range(num_frames)]
+     return vr.get_batch(idxs), idxs
+
+
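As a sketch of the sampling arithmetic, assuming a 30 fps source that is long enough: requesting 8 frames at 10 fps gives step = ((8-1)/10 * 30) / (8-1) = 3.0, so the clip spans 22 consecutive source frames from a random start:

    frames, idxs = sample_clip("clip.mp4", num_frames=8, out_fps=10)
    print(idxs)  # e.g. [42, 45, 48, 51, 54, 57, 60, 63]
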
+ class InfiniteDataLoader(torch.utils.data.DataLoader):
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         # Initialize an iterator over the dataset.
+         self.dataset_iterator = super().__iter__()
+
+     def __iter__(self):
+         return self
+
+     def __next__(self):
+         try:
+             batch = next(self.dataset_iterator)
+         except StopIteration:
+             # Dataset exhausted, use a new fresh iterator.
+             print('Another Loop over the dataset', flush=True)
+             self.dataset_iterator = super().__iter__()
+             batch = next(self.dataset_iterator)
+         return batch
+
+
+ class DataLoaderWrapper(InfiniteDataLoader):
+     def __iter__(self):
+         return IterWrapper(super().__iter__())
+
+
+ class IterWrapper:
+     def __init__(self, obj):
+         self.obj = obj
+
+     def __iter__(self):
+         return self
+
+     def __next__(self):
+         return self.next()
+
+     def next(self):
+         return next(self.obj)
+
+
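A minimal sketch of how these wrappers are consumed, using the DummyImageCaptionDataset defined later in this file; the loader never raises StopIteration, so the training loop decides when to stop:

    dataset = DummyImageCaptionDataset(num_samples=64, image_size=64)
    loader = DataLoaderWrapper(dataset, batch_size=8, collate_fn=dataset.collate_fn)
    it = iter(loader)
    for step in range(100):                 # more steps than one pass (64/8) provides
        images, captions, meta = it.next()  # silently wraps around the dataset
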
+ # ==== Dataset Implementation, Load your own data ==== #
+
+ class ImageTarDataset(Dataset):
+     def __init__(self, dataset_tsv, image_size, temporal_size=None, rank=0, world_size=1,
+                  use_image_bucket=False, multiple=8, no_flip=False, edit=False):
+         all_lines = []
+
+         # get all data lines
+         self.buckets = {}
+         self.weights = {}
+         self.image_buckets = defaultdict(lambda: 0)
+         self.image_buckets['1:1'] = 0  # default bucket
+
+         skipped = 0
+         for line in tqdm.tqdm(read_tsv(dataset_tsv)[1:]):
+             tsv_file = line[0]
+             bucket = line[1] if len(line) > 1 else 'mlx'
+             caption = line[2] if len(line) > 2 else 'caption'
+             weights = float(line[3] if len(line) > 3 else "1")
+             all_data = read_tsv(tsv_file)
+             all_maps = {all_data[0][i]: i for i in range(len(all_data[0]))}
+             self.weights[all_data[1][0]] = weights
+             for line in all_data[1:]:
+                 try:
+                     if 'width' in all_maps:  # filter out images that are too small
+                         width, height = int(line[all_maps['width']]), int(line[all_maps['height']])
+                         if width * height < (image_size * image_size) / 2:  # image has less than half the target area
+                             skipped += 1
+                             continue
+
+                     if caption != 'folder':  # an explicit caption spec has higher priority
+                         captions = caption.split('|')[0].split(':')
+                         operation = caption.split('|')[1] if len(caption.split('|')) > 1 else "none"
+                         caption_line = ([line[all_maps[c]] for c in captions], operation)
+                     else:
+                         caption_line = (line[all_maps['file']].split('/')[-2], "none")  # use the folder name as caption
+
+                     items = {'tar': line[all_maps['tar']], 'file': line[all_maps['file']], 'caption': caption_line,
+                              'image_bucket': line[all_maps['image_bucket']] if 'image_bucket' in all_maps else "1:1"}
+
+                     if "camera_file" in all_maps:  # dl3dv data
+                         items["camera_file"] = line[all_maps["camera_file"]]
+
+                     if "force_caption" in all_maps:  # force dataset
+                         items["force_caption"] = line[all_maps["force_caption"]]
+                         if "wind_speed" in all_maps:  # wind force
+                             items["wind_speed"] = line[all_maps["wind_speed"]]
+                             items["wind_angle"] = line[all_maps["wind_angle"]]
+                         elif "force" in all_maps:  # point-wise force
+                             items["force"] = line[all_maps["force"]]
+                             items["angle"] = line[all_maps["angle"]]
+                             items["coordx"] = line[all_maps["coordx"]]
+                             items["coordy"] = line[all_maps["coordy"]]
+
+                     if edit:
+                         if line[all_maps['visual_file']] != 'none':
+                             continue  # TODO: for now, we only support one image, no visual clue
+                         items['edit_instruction'] = line[all_maps['edit_instruction']]
+                         items['edited_file'] = line[all_maps['edited_file']]
+                     all_lines.append(items)
+
+                 except Exception as e:
+                     skipped += 1
+                     continue
+
+                 image_bucket = all_lines[-1]['image_bucket']
+                 self.image_buckets[image_bucket] += 1
+                 if all_lines[-1]['tar'] not in self.buckets:
+                     self.buckets[all_lines[-1]['tar']] = bucket
+
+         if "force_caption" in all_lines[0]:
+             # cast to float so min/max compare numerically (TSV fields are strings)
+             wind_forces = [float(l["wind_speed"]) for l in all_lines] if "wind_speed" in all_lines[0] else [float(l["force"]) for l in all_lines]
+             self.min_wind_force = min(wind_forces)
+             self.max_wind_force = max(wind_forces)
+
+         self.use_image_bucket = use_image_bucket
+         self.all_lines = all_lines[rank:][::world_size]  # all_lines is sorted by tar file
+         self.num_samples_per_rank = None
+         self.image_size = image_size
+         self.multiple = multiple
+         self.temporal_size = tuple(map(int, temporal_size.split(':'))) if isinstance(temporal_size, str) else None
+         self.edit_mode = edit
+
+         def center_crop_resize(img, ratio="1:1", target_size: int = 256, multiple: int = 8):
+             """
+             1. Center crop `img` to the largest window with aspect ratio = ratio.
+             2. Resize so HxW ≈ target_size² (each side a multiple of `multiple`).
+
+             Args
+             ----
+             img         : PIL Image or torch tensor (CHW/HWC)
+             ratio       : "3:2", (3,2), "1:1", etc.
+             target_size : reference side length (area = target_size²)
+             multiple    : force each output side to be a multiple of this number
+             """
+             # --- parse ratio ----------------------------------------------------------
+             if isinstance(ratio, str):
+                 rw, rh = map(int, ratio.split(':'))
+             else:  # already a tuple/list
+                 rw, rh = ratio
+             R = rw / rh  # width / height
+
+             # --- crop to that aspect ratio -------------------------------------------
+             # note: tensors also expose a (bound-method) `.size`, so test for PIL explicitly
+             w, h = img.size if isinstance(img, Image.Image) else (img.shape[-1], img.shape[-2])
+             if w / h > R:  # image too wide → trim width
+                 crop_h, crop_w = h, int(round(h * R))
+             else:  # image too tall → trim height
+                 crop_w, crop_h = w, int(round(w / R))
+             img = transforms.functional.center_crop(img, (crop_h, crop_w))
+
+             # --- compute output dimensions -------------------------------------------
+             area = target_size ** 2
+             out_h = _nearest_multiple(math.sqrt(area / R), multiple)
+             out_w = _nearest_multiple(math.sqrt(area * R), multiple)
+
+             # --- resize & return ------------------------------------------------------
+             return transforms.functional.resize(img, (out_h, out_w), antialias=True)
+
+         self.transforms = {}
+         self.size_bucket_maps = {}
+         self.bucket_size_maps = {}
+         for bucket in self.image_buckets:
+             trans = [transforms.Lambda(lambda img, r=bucket: center_crop_resize(img, ratio=r, target_size=image_size, multiple=multiple))]
+             if not no_flip:
+                 trans.append(transforms.RandomHorizontalFlip())
+             trans.extend([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
+             self.transforms[bucket] = transforms.Compose(trans)
+
+             w, h = map(int, bucket.split(':'))
+             out_h, out_w = aspect_ratio_to_image_size(image_size, w / h, multiple=multiple)
+             self.size_bucket_maps[(out_h, out_w)] = bucket
+             self.bucket_size_maps[bucket] = (out_h, out_w)
+
+         self.transform = self.transforms['1:1']  # default transform
+         print(f"Rank0 -- Loading {len(self.all_lines)} lines of data | {skipped} lines are skipped due to size or error")
+
+     def __len__(self):
+         if self.num_samples_per_rank is not None:
+             return self.num_samples_per_rank
+         return len(self.all_lines)
+
+     def __getitem__(self, idx):
+         image_item = self.all_lines[idx]
+         tar_file = image_item['tar']
+         img_file = image_item['file']
+         img_bucket = image_item['image_bucket']
+         try:
+             with tarfile.open(tar_file, mode='r') as tar:
+                 # _read_image returns (image, fps, frame_inds)
+                 img, _, _ = self._read_image(tar, img_file, img_bucket)
+             H0, W0 = img.shape[-2], img.shape[-1]
+             scale = self.image_size / min(H0, W0)
+             state = np.array([scale, H0, W0])
+         except Exception as e:
+             print(f'Reading data error {e}')
+             raise  # without an image there is nothing sensible to return
+         sample = image_item.copy()
+         sample.update(image=img, state=state)
+         return sample
+
+     def _read_image(self, tar, img_file, img_bucket):
+
+         def _transform(img):
+             if not self.use_image_bucket:
+                 return self.transform(img)
+             else:
+                 return self.transforms[img_bucket](img)
+
+         x_shape = aspect_ratio_to_image_size(self.image_size, img_bucket, multiple=self.multiple)
+         if self.temporal_size is not None:  # read video
+             num_frames, out_fps = self.temporal_size[0], self.temporal_size[1:]
+             if len(out_fps) == 1:
+                 out_fps = out_fps[0]
+             else:
+                 out_fps = random.choice(out_fps)  # randomly choose one fps from the list
+             assert img_file.endswith('.mp4'), "Only support mp4 video for now"
+             try:
+                 with tar.extractfile(img_file) as video_data:
+                     with ram_temp_file(video_data.read()) as tmp_path:
+                         frames, frame_inds = sample_clip(tmp_path, num_frames=num_frames, out_fps=out_fps)
+                         frames = frames.asnumpy()
+             except Exception as e:
+                 print(f'Reading data error {e} {img_file}')
+                 frames = np.zeros((num_frames, x_shape[0], x_shape[1], 3), dtype=np.uint8)
+                 frame_inds = None  # no real frames were decoded
+             return torch.stack([_transform(Image.fromarray(frame)) for frame in frames]), out_fps, frame_inds
+
+         try:
+             original_img = Image.open(tar.extractfile(img_file)).convert('RGB')
+         except Exception as e:
+             print(f'Reading data error {e} {img_file}')
+             original_img = Image.new('RGB', (x_shape[1], x_shape[0]), (0, 0, 0))  # PIL size is (W, H)
+         return _transform(original_img), 0, None
+
+     def collate_fn(self, batch):
+         batch = default_collate(batch)
+         return batch
+
+     def get_batch_modes(self, x):
+         x_aspect = self.size_bucket_maps.get(x.size()[-2:], "1:1")
+         video_mode = self.temporal_size is not None
+         return x_aspect, video_mode
+
+
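For reference, a sketch of the two-level TSV layout the constructor expects (all names and values below are illustrative; columns are tab-separated). The top-level file lists one dataset per row as `tsv_file  bucket  caption_spec  weight`, and each inner TSV starts with a header row whose column names are resolved through `all_maps`:

    # datasets.tsv (top level; the first row is skipped as a header)
    path                bucket        caption    weight
    /data/photos.tsv    my-s3-bucket  caption    1.0

    # /data/photos.tsv (inner)
    tar            file               width  height  caption        image_bucket
    shard-000.tar  images/000001.jpg  512    384     a red bicycle  4:3
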
+ class OnlineImageTarDataset(ImageTarDataset):
+     max_retry_n = 20
+     max_read = 4096
+     tar_keys_lock = manager.Lock() if manager is not None else None
+
+     def __init__(self, dataset_tsv, image_size, batch_size=None, **kwargs):
+         super().__init__(dataset_tsv, image_size, **kwargs)
+
+         self.tar_lists = defaultdict(lambda: [])
+         self.tar_image_buckets = defaultdict(lambda: defaultdict(lambda: 0))
+         for i, line in enumerate(self.all_lines):
+             tar_file = line['tar']
+             image_bucket = line['image_bucket']
+             self.tar_lists[tar_file] += [i]
+             self.tar_image_buckets[tar_file][image_bucket] += 1
+         self.reset_tar_keys = []
+         for key in self.tar_lists.keys():
+             repeat = int(self.weights.get(key, 1))
+             self.reset_tar_keys.extend([key] * repeat)
+         self.tar_keys = manager.list(self.reset_tar_keys) if manager is not None else list(self.reset_tar_keys)
+
+         # Use more workers for better prefetching, but limit to a reasonable number
+         self.worker_executors = {}
+         self.worker_caches = {}  # each entry: {active: {tar, key, cnt, inner_idx}, prefetch: Future}
+         self.worker_caches_lock = threading.Lock()  # protect worker_caches access
+         self.shuffle_everything()
+         if self.use_image_bucket:
+             assert batch_size, "batch_size should be set when use_image_bucket is True"
+             self.batch_size = batch_size
+         if self.temporal_size is not None:
+             assert self.temporal_size[0] > 1, "temporal_size should be greater than 1 for video data"
+             self.max_read = 512
+
+     def cleanup_worker_cache(self, wid):
+         """Clean up a worker cache entry and its associated resources."""
+         with self.worker_caches_lock:
+             if wid in self.worker_caches:
+                 cache_entry = self.worker_caches[wid]
+                 # Cancel the prefetch future if it is still pending
+                 if 'prefetch' in cache_entry and hasattr(cache_entry['prefetch'], 'cancel'):
+                     cache_entry['prefetch'].cancel()
+
+                 # the open tar handle lives under the 'active' sub-dict
+                 active = cache_entry.get('active') or {}
+                 if active.get('tar') is not None:
+                     self._close_tar(active['tar'])
+                     active['tar'] = None
+                 # Remove the entire cache entry
+                 del self.worker_caches[wid]
+         gc.collect()
+
+     def _s3(self):
+         raise NotImplementedError("Please implement your own _s3() method to return a boto3 session/client")
+
+     def shuffle_everything(self):
+         for key in tqdm.tqdm(self.tar_keys):
+             random.shuffle(self.tar_lists[key])
+         random.shuffle(self.tar_keys)
+         print("shuffle everything done!")
+
+     def download_tar(self, prefetch=True, wid=None):
+         i = 0
+         file_stream = None
+         tar_file = None
+         download = f'prefetch {wid}' if prefetch else 'just download'
+         while True:
+             if i % self.max_retry_n == 0:  # retry with a different tar file
+                 tar_file = self._get_next_key()  # get the next tar file key
+                 file_stream = None
+             try:
+                 file_stream = io.BytesIO()
+                 self._s3().download_fileobj(self.buckets[tar_file], tar_file, file_stream)  # hard-coded
+                 file_stream.seek(0)
+                 tar = tarfile.open(fileobj=file_stream, mode='r')
+                 # Store the file_stream reference so it can be closed later
+                 tar._file_stream = file_stream
+                 xprint(f'[INFO] {download} tar file: {tar_file}')
+                 return tar, tar_file
+             except Exception as e:
+                 xprint(f'[ERROR] {download} tar file {tar_file} failed: {e}')
+                 i += 1
+                 if file_stream:
+                     file_stream.close()
+                     file_stream = None
+                 time.sleep(min(i * 0.1, 5))  # linear backoff, capped at 5 seconds
+
+     def _get_next_key(self):
+         with self.tar_keys_lock:
+             if not self.tar_keys or len(self.tar_keys) == 0:
+                 xprint(f'[WARN] all tar keys exhausted... this should not usually happen')
+                 self.tar_keys.extend(list(self.reset_tar_keys))  # reset
+                 random.shuffle(self.tar_keys)
+             return self.tar_keys.pop(0)  # remove and return the first key
+
+     def _start_prefetch(self, wid):
+         """Start prefetching the next tar file for this worker."""
+         # Create an executor per worker process if it doesn't exist yet
+         if wid not in self.worker_executors:
+             self.worker_executors[wid] = ThreadPoolExecutor(max_workers=1)
+         future = self.worker_executors[wid].submit(self.download_tar, prefetch=True, wid=wid)  # download the tar file in a separate thread
+         self.worker_caches[wid]['prefetch'] = future
+
+     def _close_tar(self, tar):
+         # Properly close both the tar and the underlying file stream
+         if hasattr(tar, '_file_stream') and tar._file_stream:
+             tar._file_stream.close()
+             tar._file_stream = None
+         tar.close()
+         del tar
+         gc.collect()
+
+     def __getitem__(self, idx):
+         try:
+             wid = get_worker_info().id
+         except Exception as e:
+             wid = -1  # main process (no worker info available)
+
+         # ─── first time this worker is used ─── #
+         if wid not in self.worker_caches:
+             tar, key = self.download_tar(prefetch=False)  # download a tar file
+             with self.worker_caches_lock:
+                 self.worker_caches[wid] = dict(
+                     active=dict(tar=tar, key=key, cnt=0, inner_idx=0),  # active cache
+                 )
+             self._start_prefetch(wid)  # start prefetching the next tar file
+
+         cache = self.worker_caches[wid]
+         active = cache['active']
+         tar = active['tar']
+         key = active['key']
+         cnt = active['cnt']
+         inner_idx = active['inner_idx']
+
+         # handle image bucketing
+         if self.use_image_bucket:
+             if inner_idx % self.batch_size == 0:
+                 # sample based on local tar file statistics, in case some dataset only has one image bucket
+                 tar_buckets = self.tar_image_buckets[key]
+                 target_image_bucket = random.choices(
+                     list(tar_buckets.keys()), weights=list(tar_buckets.values()), k=1)[0]
+                 self.worker_caches[wid]['target_image_bucket'] = target_image_bucket
+
+             # scan the list to find the nearest entry in the target image bucket
+             target_image_bucket, t_cnt = self.worker_caches[wid]['target_image_bucket'], cnt
+             while self.all_lines[self.tar_lists[key][t_cnt]]['image_bucket'] != target_image_bucket:
+                 t_cnt += 1
+                 if t_cnt >= len(self.tar_lists[key]):
+                     t_cnt = 0
+             # swap the image locations
+             if cnt != t_cnt:
+                 self.tar_lists[key][cnt], self.tar_lists[key][t_cnt] = self.tar_lists[key][t_cnt], self.tar_lists[key][cnt]
+
+         img_id = self.tar_lists[key][cnt]
+         image_item = self.all_lines[img_id]
+         sample = {k: image_item[k] for k in image_item}
+         image, fps, frame_inds = self._read_image(tar, image_item['file'], image_item['image_bucket'])
+         sample.update(image=image, fps=fps, local_idx=img_id, inner_idx=inner_idx)
+         if self.edit_mode:
+             image, fps, _ = self._read_image(tar, image_item['edited_file'], image_item['image_bucket'])
+             sample.update(edited_image=image, fps=fps, edit_instruction=image_item['edit_instruction'])
+
+         if "camera_file" in image_item:  # dl3dv data
+             sample["condition"] = get_camera_condition(tar, image_item["camera_file"], width=image.shape[3], height=image.shape[2], factor=self.multiple, frame_inds=frame_inds)
+
+         if "force_caption" in image_item:  # force dataset
+             if "wind_speed" in image_item:  # wind force
+                 sample["condition"] = get_wind_condition(image_item["wind_speed"], image_item["wind_angle"], min_force=self.min_wind_force, max_force=self.max_wind_force, num_frames=image.shape[1], width=image.shape[3], height=image.shape[2])
+             elif "force" in image_item:  # point-wise force
+                 sample["condition"] = get_point_condition(image_item["force"], image_item["angle"], image_item["coordx"], image_item["coordy"], min_force=self.min_wind_force, max_force=self.max_wind_force, num_frames=image.shape[1], width=image.shape[3], height=image.shape[2])
+
+         # update cnt
+         cnt, inner_idx = cnt + 1, inner_idx + 1
+         if (cnt == len(self.tar_lists[key])) or (cnt == self.max_read):
+             # -- active tar finished, switch to the prefetched tar -- #
+             self._close_tar(tar)  # close the current tar file
+
+             try:
+                 # Wait for the prefetch result, with a 5 minute timeout
+                 new_tar, new_key = cache['prefetch'].result(timeout=300)
+             except Exception as e:
+                 xprint(f'[WARN] Prefetch failed, downloading new tar synchronously: {e}')
+                 new_tar, new_key = self.download_tar(prefetch=False)
+
+             cache['active'] = dict(tar=new_tar, key=new_key, cnt=0, inner_idx=inner_idx)  # update active cache
+
+             # shuffle the image list
+             random.shuffle(self.tar_lists[key])
+             with self.tar_keys_lock:
+                 self.tar_keys.append(key)  # return the key to the list so other workers can use it
+
+             self._start_prefetch(wid)  # start prefetching the next tar file
+         else:
+             cache['active']['cnt'] = cnt
+
+         # always update inner_idx (IMPORTANT)
+         cache['active']['inner_idx'] = inner_idx
+         return sample
+
+
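`_s3()` is left abstract on purpose; a minimal subclass sketch assuming standard boto3 credentials (the client call matches the positional `download_fileobj(Bucket, Key, Fileobj)` usage above):

    import boto3

    class MyOnlineDataset(OnlineImageTarDataset):
        def _s3(self):
            # Lazily create one client per worker process; boto3 clients
            # should not be shared across forked processes.
            if not hasattr(self, '_client'):
                self._client = boto3.client('s3')
            return self._client
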
+ class OnlineImageCaptionDataset(OnlineImageTarDataset):
+     def __getitem__(self, idx):
+         sample = super().__getitem__(idx)
+         captions, caption_op = sample['caption']
+         if caption_op == 'none':
+             sample['caption'] = captions[0] if isinstance(captions, list) else captions
+         elif ':' in caption_op:
+             sample['caption'] = random.choices(captions, weights=[float(a) for a in caption_op.split(':')])[0]
+         else:
+             raise NotImplementedError(f"Unknown caption operation: {caption_op}")
+         return sample
+
+     def collate_fn(self, batch):
+         batch = super().collate_fn(batch)
+         image = batch['image']
+         caption = batch['caption']
+         if self.edit_mode:
+             image = torch.cat([image, batch['edited_image']], dim=0)
+             caption.extend(batch['edit_instruction'])
+
+         meta = {key: batch[key] for key in batch if key not in
+                 ['image', 'caption', 'edited_image', 'edit_instruction']}
+         return image, caption, meta
+
+
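To illustrate the caption operation: a top-level spec such as `short_cap:long_cap|0.3:0.7` (hypothetical column names) stores `(['a bike', 'a red bicycle leaning on a wall'], '0.3:0.7')` in the sample, and `__getitem__` above then draws:

    caption = random.choices(['a bike', 'a red bicycle leaning on a wall'],
                             weights=[0.3, 0.7])[0]
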
+ # ==== Dummy Dataset Implementation for Open Source Release ====
+
+ class DummyImageCaptionDataset(Dataset):
+     """
+     Dummy dataset that generates synthetic image-caption pairs for training/testing.
+     Supports mixed aspect ratios and batch-wise aspect ratio consistency.
+     """
+
+     def __init__(
+         self,
+         num_samples: int = 10000,
+         image_size: int = 256,
+         temporal_size: Optional[str] = None,
+         use_image_bucket: bool = False,
+         batch_size: Optional[int] = None,
+         multiple: int = 8,
+         no_flip: bool = False,
+         edit: bool = False
+     ):
+         """
+         Args:
+             num_samples: Number of samples in the dataset
+             image_size: Base image size for generation
+             temporal_size: Video size specification (e.g., "16:8" for frames:fps)
+             use_image_bucket: Whether to use aspect ratio bucketing
+             batch_size: Batch size for bucketing (required if use_image_bucket=True)
+             multiple: Multiple for dimension rounding
+             no_flip: Whether to disable horizontal flipping
+             edit: Whether this is an editing dataset
+         """
+         self.num_samples = num_samples
+         self.image_size = image_size
+         self.temporal_size = temporal_size
+         self.use_image_bucket = use_image_bucket
+         self.batch_size = batch_size
+         self.multiple = multiple
+         self.no_flip = no_flip
+         self.edit_mode = edit
+
+         # Parse video parameters
+         self.is_video = temporal_size is not None
+         if self.is_video:
+             frames, fps = map(int, temporal_size.split(':'))
+             self.num_frames = frames
+             self.fps = fps
+         else:
+             self.num_frames = 1
+             self.fps = None
+
+         # Aspect ratios for mixed aspect ratio training
+         self.aspect_ratios = [
+             "1:1", "2:3", "3:2", "16:9", "9:16",
+             "4:5", "5:4", "21:9", "9:21"
+         ] if use_image_bucket else ["1:1"]
+
+         # Generate image buckets for the aspect ratios
+         self.image_buckets = {}
+         for i, ar in enumerate(self.aspect_ratios):
+             h, w = aspect_ratio_to_image_size(image_size, ar, multiple)
+             self.image_buckets[ar] = (h, w, ar)
+
+         # Sample captions for dummy data
+         self.sample_captions = [
+             "A beautiful landscape with mountains and trees",
+             "A cute cat sitting on a wooden table",
+             "A modern city skyline at sunset",
+             "A vintage car parked on a street",
+             "A delicious meal on a white plate",
+             "A person walking in a park",
+             "A colorful flower garden in bloom",
+             "A cozy living room with furniture",
+             "A stormy ocean with large waves",
+             "A peaceful forest path in autumn",
+             "A group of friends laughing together",
+             "A majestic eagle flying in the sky",
+             "A busy marketplace with vendors",
+             "A snow-covered mountain peak",
+             "A child playing with toys",
+             "A romantic candlelit dinner",
+             "A train traveling through countryside",
+             "A lighthouse on a rocky coast",
+             "A field of sunflowers under blue sky",
+             "A family having a picnic outdoors"
+         ]
+
+         # Create the transform pipeline
+         def center_crop_resize(img, ratio="1:1", target_size: int = 256, multiple: int = 8):
+             """
+             1. Center crop `img` to the largest window with aspect ratio = ratio.
+             2. Resize so HxW ≈ target_size² (each side a multiple of `multiple`).
+
+             Args
+             ----
+             img         : PIL Image or torch tensor (CHW/HWC)
+             ratio       : "3:2", (3,2), "1:1", etc.
+             target_size : reference side length (area = target_size²)
+             multiple    : force each output side to be a multiple of this number
+             """
+             # --- parse ratio ----------------------------------------------------------
+             if isinstance(ratio, str):
+                 rw, rh = map(int, ratio.split(':'))
+             else:  # already a tuple/list
+                 rw, rh = ratio
+             R = rw / rh  # width / height
+
+             # --- crop to that aspect ratio -------------------------------------------
+             # note: tensors also expose a (bound-method) `.size`, so test for PIL explicitly
+             w, h = img.size if isinstance(img, Image.Image) else (img.shape[-1], img.shape[-2])
+             if w / h > R:  # image too wide → trim width
+                 crop_h, crop_w = h, int(round(h * R))
+             else:  # image too tall → trim height
+                 crop_w, crop_h = w, int(round(w / R))
+             img = transforms.functional.center_crop(img, (crop_h, crop_w))
+
+             # --- compute output dimensions -------------------------------------------
+             area = target_size ** 2
+             out_h = _nearest_multiple(math.sqrt(area / R), multiple)
+             out_w = _nearest_multiple(math.sqrt(area * R), multiple)
+
+             # --- resize & return ------------------------------------------------------
+             return transforms.functional.resize(img, (out_h, out_w), antialias=True)
+
+         self.transforms = {}
+         self.size_bucket_maps = {}
+         self.bucket_size_maps = {}
+         for bucket in self.image_buckets:
+             trans = [transforms.Lambda(lambda img, r=bucket: center_crop_resize(img, ratio=r, target_size=image_size, multiple=multiple))]
+             if not no_flip:
+                 trans.append(transforms.RandomHorizontalFlip())
+             trans.extend([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
+             self.transforms[bucket] = transforms.Compose(trans)
+
+             w, h = map(int, bucket.split(':'))
+             out_h, out_w = aspect_ratio_to_image_size(image_size, w / h, multiple=multiple)
+             self.size_bucket_maps[(out_h, out_w)] = bucket
+             self.bucket_size_maps[bucket] = (out_h, out_w)
+
+         self.transform = self.transforms['1:1']  # default transform
+
+     def __len__(self) -> int:
+         return self.num_samples
+
+     def __getitem__(self, idx: int) -> dict:
+         """Get a single sample from the dataset."""
+         # Choose an aspect ratio
+         if self.use_image_bucket:
+             bucket_name = random.choice(list(self.image_buckets.keys()))
+             h, w, aspect_ratio = self.image_buckets[bucket_name]
+         else:
+             h, w, aspect_ratio = self.image_size, self.image_size, "1:1"
+             bucket_name = aspect_ratio
+
+         # Generate a dummy image
+         if self.is_video:
+             # Generate a video tensor (T, C, H, W) already at the target size;
+             # the PIL transform pipeline does not apply to tensors, so skip it
+             image = torch.randn(self.num_frames, 3, h, w)
+             # Normalize to the [-1, 1] range
+             image = torch.tanh(image)
+         else:
+             # Generate an RGB image
+             image = Image.new('RGB', (w, h), color=(
+                 random.randint(50, 200),
+                 random.randint(50, 200),
+                 random.randint(50, 200)
+             ))
+
+             # Add some random patterns for variety
+             if random.random() > 0.5:
+                 # Add a gradient
+                 pixels = []
+                 for y in range(h):
+                     for x in range(w):
+                         r = int(255 * x / w)
+                         g = int(255 * y / h)
+                         b = int(255 * (x + y) / (w + h))
+                         pixels.append((r, g, b))
+                 image.putdata(pixels)
+
+             # use the transform that matches the chosen bucket, so the output
+             # size agrees with the 'image_bucket' metadata below
+             image = self.transforms[bucket_name](image)
+
+         # Generate a caption
+         caption = random.choice(self.sample_captions)
+
+         # Add some variation to the captions
+         if random.random() > 0.7:
+             adjectives = ["beautiful", "stunning", "amazing", "incredible", "magnificent"]
+             caption = f"{random.choice(adjectives)} {caption.lower()}"
+
+         sample = {
+             'image': image,
+             'caption': caption,
+             'image_bucket': bucket_name,
+             'aspect_ratio': aspect_ratio,
+             'idx': idx
+         }
+
+         # Add video-specific metadata
+         if self.is_video:
+             sample.update({
+                 'num_frames': self.num_frames,
+                 'fps': self.fps,
+                 'temporal_size': self.temporal_size
+             })
+
+         # Add editing data if needed
+         if self.edit_mode:
+             # Generate a slightly modified image for editing tasks
+             edited_image = image + torch.randn_like(image) * 0.1
+             edited_image = torch.clamp(edited_image, -1, 1)
+             sample.update({
+                 'edited_image': edited_image,
+                 'edit_instruction': f"Edit this image to make it more {random.choice(['colorful', 'bright', 'artistic', 'realistic'])}"
+             })
+
+         return sample
+
+     def collate_fn(self, batch: list) -> tuple:
+         """Collate function for batching samples."""
+         # Group by aspect ratio if using image buckets
+         if self.use_image_bucket:
+             # Sort the batch by image bucket for consistency (note: torch.stack
+             # below still requires all samples in a batch to share one bucket)
+             batch = sorted(batch, key=lambda x: x['image_bucket'])
+
+         # Standard collation
+         collated = {}
+         images = torch.stack([item['image'] for item in batch], dim=0)
+         captions = [item['caption'] for item in batch]
+
+         # Collect metadata
+         for key in ['image_bucket', 'aspect_ratio', 'idx']:
+             if key in batch[0]:
+                 collated[key] = [item[key] for item in batch]
+
+         # Handle video metadata
+         if self.is_video:
+             for key in ['num_frames', 'fps', 'temporal_size']:
+                 if key in batch[0]:
+                     collated[key] = [item[key] for item in batch]
+
+         # Handle editing data
+         if self.edit_mode and 'edited_image' in batch[0]:
+             edited_images = torch.stack([item['edited_image'] for item in batch], dim=0)
+             collated['edited_image'] = edited_images
+             collated['edit_instruction'] = [item['edit_instruction'] for item in batch]
+
+         return images, captions, collated
+
+     def get_batch_modes(self, x):
+         x_aspect = self.size_bucket_maps.get(x.size()[-2:], "1:1")
+         video_mode = self.temporal_size is not None
+         return x_aspect, video_mode
+
+
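A single-sample smoke test (the bucket choice is random, so the shape varies per call):

    ds = DummyImageCaptionDataset(num_samples=4, image_size=128, use_image_bucket=True, batch_size=2)
    s = ds[0]
    print(s['image'].shape, s['image_bucket'])
    # e.g. torch.Size([3, 96, 168]) 16:9  -> 96 * 168 ≈ 128**2, both sides multiples of 8
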
+ class DummyDataLoaderWrapper:
+     """
+     Wrapper that mimics the DataLoaderWrapper functionality.
+     Provides infinite iteration over the dataset.
+     """
+
+     def __init__(self, dataset, batch_size=1, num_workers=0, **kwargs):
+         self.dataset = dataset
+         self.batch_size = batch_size
+         self.dataloader = DataLoader(
+             dataset,
+             batch_size=batch_size,
+             num_workers=num_workers,
+             collate_fn=dataset.collate_fn,
+             shuffle=True,
+             drop_last=True,
+             **kwargs
+         )
+         self.iterator = None
+         self.secondary_loader = None
+
+     def __iter__(self):
+         """Infinite iteration over the dataset."""
+         while True:
+             if self.iterator is None:
+                 self.iterator = iter(self.dataloader)
+             try:
+                 yield next(self.iterator)
+             except StopIteration:
+                 self.iterator = iter(self.dataloader)
+                 yield next(self.iterator)
+
+     def __len__(self):
+         return len(self.dataloader)
+
+
+ def create_dummy_dataloader(
+     dataset_name: str,
+     img_size: int,
+     vid_size: Optional[str] = None,
+     batch_size: int = 16,
+     use_mixed_aspect: bool = False,
+     multiple: int = 8,
+     num_samples: int = 10000,
+     infinite: bool = False
+ ) -> Union[DataLoader, DummyDataLoaderWrapper]:
+     """
+     Create a dummy dataloader that mimics the original functionality.
+
+     Args:
+         dataset_name: Name of the dataset (used for deterministic seeding)
+         img_size: Base image size
+         vid_size: Video specification (e.g., "16:8")
+         batch_size: Batch size
+         use_mixed_aspect: Whether to use mixed aspect ratio training
+         multiple: Multiple for dimension rounding
+         num_samples: Number of samples in the dataset
+         infinite: Whether to create an infinite dataloader
+
+     Returns:
+         DataLoader or DummyDataLoaderWrapper
+     """
+     # Seed from the dataset name for reproducibility; zlib.crc32 is stable
+     # across interpreter runs, unlike the built-in hash()
+     seed = zlib.crc32(dataset_name.encode()) % (2**32 - 1)
+     random.seed(seed)
+     np.random.seed(seed)
+
+     # Create the dataset
+     dataset = DummyImageCaptionDataset(
+         num_samples=num_samples,
+         image_size=img_size,
+         temporal_size=vid_size,
+         use_image_bucket=use_mixed_aspect,
+         batch_size=batch_size,
+         multiple=multiple,
+         edit='edit' in dataset_name.lower()
+     )
+
+     # Set the dataset attributes expected by the training code
+     dataset.total_num_samples = num_samples
+     dataset.num_samples_per_rank = num_samples
+
+     # Create the dataloader
+     if infinite:
+         # shuffle and drop_last are already set inside the wrapper's DataLoader,
+         # so they are not repeated here (passing them twice raises a TypeError)
+         return DummyDataLoaderWrapper(
+             dataset,
+             batch_size=batch_size,
+             num_workers=2,
+             pin_memory=True,
+             persistent_workers=True
+         )
+     else:
+         return DataLoader(
+             dataset,
+             batch_size=batch_size,
+             num_workers=2,
+             pin_memory=True,
+             drop_last=True,
+             shuffle=True,
+             collate_fn=dataset.collate_fn,
+             persistent_workers=True
+         )
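
Finally, a quick smoke test of the factory (the dataset name is arbitrary and only seeds the RNG):

    loader = create_dummy_dataloader("dummy_t2i", img_size=128, batch_size=4, infinite=True)
    images, captions, meta = next(iter(loader))
    print(images.shape)  # torch.Size([4, 3, 128, 128])
    print(captions[0])   # e.g. "A cute cat sitting on a wooden table"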