nroggendorff committed on
Commit d87825e · verified · 1 Parent(s): e0a0c1d

Create train.py

Files changed (1)
  1. train.py +1335 -0
train.py ADDED
@@ -0,0 +1,1335 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """Fine-tuning script for Stable Diffusion XL for text2image."""
17
+
18
+ import argparse
19
+ import functools
20
+ import gc
21
+ import logging
22
+ import math
23
+ import os
24
+ import random
25
+ import shutil
26
+ from contextlib import nullcontext
27
+ from pathlib import Path
28
+
29
+ import accelerate
30
+ import datasets
31
+ import numpy as np
32
+ import torch
33
+ import torch.nn.functional as F
34
+ import torch.utils.checkpoint
35
+ import transformers
36
+ from accelerate import Accelerator
37
+ from accelerate.logging import get_logger
38
+ from accelerate.utils import ProjectConfiguration, set_seed
39
+ from datasets import concatenate_datasets, load_dataset
40
+ from huggingface_hub import create_repo, upload_folder
41
+ from packaging import version
42
+ from torchvision import transforms
43
+ from torchvision.transforms.functional import crop
44
+ from tqdm.auto import tqdm
45
+ from transformers import AutoTokenizer, PretrainedConfig
46
+
47
+ import diffusers
48
+ from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionXLPipeline, UNet2DConditionModel
49
+ from diffusers.optimization import get_scheduler
50
+ from diffusers.training_utils import EMAModel, compute_snr
51
+ from diffusers.utils import check_min_version, is_wandb_available
52
+ from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
53
+ from diffusers.utils.import_utils import is_xformers_available
54
+ from diffusers.utils.torch_utils import is_compiled_module
55
+
56
+
57
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risk.
58
+ check_min_version("0.28.0.dev0")
59
+
60
+ logger = get_logger(__name__)
61
+
62
+
63
+ DATASET_NAME_MAPPING = {
64
+ "lambdalabs/pokemon-blip-captions": ("image", "text"),
65
+ }
66
+
67
+
68
+ def save_model_card(
69
+ repo_id: str,
70
+ images: list = None,
71
+ validation_prompt: str = None,
72
+ base_model: str = None,
73
+ dataset_name: str = None,
74
+ repo_folder: str = None,
75
+ vae_path: str = None,
76
+ ):
77
+ img_str = ""
78
+ if images is not None:
79
+ for i, image in enumerate(images):
80
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
81
+ img_str += f"![img_{i}](./image_{i}.png)\n"
82
+
83
+ model_description = f"""
84
+ # Text-to-image finetuning - {repo_id}
85
+
86
+ This pipeline was finetuned from **{base_model}** on the **{dataset_name}** dataset. Below are some example images generated with the finetuned pipeline using the following prompt: {validation_prompt}.\n
87
+ {img_str}
88
+
89
+ Special VAE used for training: {vae_path}.
90
+ """
91
+
92
+ model_card = load_or_create_model_card(
93
+ repo_id_or_path=repo_id,
94
+ from_training=True,
95
+ license="creativeml-openrail-m",
96
+ base_model=base_model,
97
+ model_description=model_description,
98
+ inference=True,
99
+ )
100
+
101
+ tags = [
102
+ "stable-diffusion-xl",
103
+ "stable-diffusion-xl-diffusers",
104
+ "text-to-image",
105
+ "diffusers-training",
106
+ "diffusers",
107
+ ]
108
+ model_card = populate_model_card(model_card, tags=tags)
109
+
110
+ model_card.save(os.path.join(repo_folder, "README.md"))
111
+
112
+
113
+ def import_model_class_from_model_name_or_path(
114
+ pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
115
+ ):
116
+ text_encoder_config = PretrainedConfig.from_pretrained(
117
+ pretrained_model_name_or_path, subfolder=subfolder, revision=revision
118
+ )
119
+ model_class = text_encoder_config.architectures[0]
120
+
121
+ if model_class == "CLIPTextModel":
122
+ from transformers import CLIPTextModel
123
+
124
+ return CLIPTextModel
125
+ elif model_class == "CLIPTextModelWithProjection":
126
+ from transformers import CLIPTextModelWithProjection
127
+
128
+ return CLIPTextModelWithProjection
129
+ else:
130
+ raise ValueError(f"{model_class} is not supported.")
131
+
132
+
133
+ def parse_args(input_args=None):
134
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
135
+ parser.add_argument(
136
+ "--pretrained_model_name_or_path",
137
+ type=str,
138
+ default=None,
139
+ required=True,
140
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
141
+ )
142
+ parser.add_argument(
143
+ "--pretrained_vae_model_name_or_path",
144
+ type=str,
145
+ default=None,
146
+ help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.",
147
+ )
148
+ parser.add_argument(
149
+ "--revision",
150
+ type=str,
151
+ default=None,
152
+ required=False,
153
+ help="Revision of pretrained model identifier from huggingface.co/models.",
154
+ )
155
+ parser.add_argument(
156
+ "--variant",
157
+ type=str,
158
+ default=None,
159
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, e.g. fp16",
160
+ )
161
+ parser.add_argument(
162
+ "--dataset_name",
163
+ type=str,
164
+ default=None,
165
+ help=(
166
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
167
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
168
+ " or to a folder containing files that 🤗 Datasets can understand."
169
+ ),
170
+ )
171
+ parser.add_argument(
172
+ "--dataset_config_name",
173
+ type=str,
174
+ default=None,
175
+ help="The config of the Dataset, leave as None if there's only one config.",
176
+ )
177
+ parser.add_argument(
178
+ "--train_data_dir",
179
+ type=str,
180
+ default=None,
181
+ help=(
182
+ "A folder containing the training data. Folder contents must follow the structure described in"
183
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
184
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
185
+ ),
186
+ )
187
+ parser.add_argument(
188
+ "--image_column", type=str, default="image", help="The column of the dataset containing an image."
189
+ )
190
+ parser.add_argument(
191
+ "--caption_column",
192
+ type=str,
193
+ default="text",
194
+ help="The column of the dataset containing a caption or a list of captions.",
195
+ )
196
+ parser.add_argument(
197
+ "--validation_prompt",
198
+ type=str,
199
+ default=None,
200
+ help="A prompt that is used during validation to verify that the model is learning.",
201
+ )
202
+ parser.add_argument(
203
+ "--num_validation_images",
204
+ type=int,
205
+ default=4,
206
+ help="Number of images that should be generated during validation with `validation_prompt`.",
207
+ )
208
+ parser.add_argument(
209
+ "--validation_epochs",
210
+ type=int,
211
+ default=1,
212
+ help=(
213
+ "Run fine-tuning validation every X epochs. The validation process consists of running the prompt"
214
+ " `args.validation_prompt` multiple times: `args.num_validation_images`."
215
+ ),
216
+ )
217
+ parser.add_argument(
218
+ "--max_train_samples",
219
+ type=int,
220
+ default=None,
221
+ help=(
222
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
223
+ "value if set."
224
+ ),
225
+ )
226
+ parser.add_argument(
227
+ "--proportion_empty_prompts",
228
+ type=float,
229
+ default=0,
230
+ help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
231
+ )
232
+ parser.add_argument(
233
+ "--output_dir",
234
+ type=str,
235
+ default="sdxl-model-finetuned",
236
+ help="The output directory where the model predictions and checkpoints will be written.",
237
+ )
238
+ parser.add_argument(
239
+ "--cache_dir",
240
+ type=str,
241
+ default=None,
242
+ help="The directory where the downloaded models and datasets will be stored.",
243
+ )
244
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
245
+ parser.add_argument(
246
+ "--resolution",
247
+ type=int,
248
+ default=1024,
249
+ help=(
250
+ "The resolution for input images; all the images in the train/validation dataset will be resized to this"
251
+ " resolution"
252
+ ),
253
+ )
254
+ parser.add_argument(
255
+ "--center_crop",
256
+ default=False,
257
+ action="store_true",
258
+ help=(
259
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
260
+ " cropped. The images will be resized to the resolution first before cropping."
261
+ ),
262
+ )
263
+ parser.add_argument(
264
+ "--random_flip",
265
+ action="store_true",
266
+ help="Whether to randomly flip images horizontally.",
267
+ )
268
+ parser.add_argument(
269
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
270
+ )
271
+ parser.add_argument("--num_train_epochs", type=int, default=100)
272
+ parser.add_argument(
273
+ "--max_train_steps",
274
+ type=int,
275
+ default=None,
276
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
277
+ )
278
+ parser.add_argument(
279
+ "--checkpointing_steps",
280
+ type=int,
281
+ default=500,
282
+ help=(
283
+ "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
284
+ " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
285
+ " training using `--resume_from_checkpoint`."
286
+ ),
287
+ )
288
+ parser.add_argument(
289
+ "--checkpoints_total_limit",
290
+ type=int,
291
+ default=None,
292
+ help=("Max number of checkpoints to store."),
293
+ )
294
+ parser.add_argument(
295
+ "--resume_from_checkpoint",
296
+ type=str,
297
+ default=None,
298
+ help=(
299
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
300
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
301
+ ),
302
+ )
303
+ parser.add_argument(
304
+ "--gradient_accumulation_steps",
305
+ type=int,
306
+ default=1,
307
+ help="Number of update steps to accumulate before performing a backward/update pass.",
308
+ )
309
+ parser.add_argument(
310
+ "--gradient_checkpointing",
311
+ action="store_true",
312
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
313
+ )
314
+ parser.add_argument(
315
+ "--learning_rate",
316
+ type=float,
317
+ default=1e-4,
318
+ help="Initial learning rate (after the potential warmup period) to use.",
319
+ )
320
+ parser.add_argument(
321
+ "--scale_lr",
322
+ action="store_true",
323
+ default=False,
324
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
325
+ )
326
+ parser.add_argument(
327
+ "--lr_scheduler",
328
+ type=str,
329
+ default="constant",
330
+ help=(
331
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
332
+ ' "constant", "constant_with_warmup"]'
333
+ ),
334
+ )
335
+ parser.add_argument(
336
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
337
+ )
338
+ parser.add_argument(
339
+ "--timestep_bias_strategy",
340
+ type=str,
341
+ default="none",
342
+ choices=["earlier", "later", "range", "none"],
343
+ help=(
344
+ "The timestep bias strategy, which may help direct the model toward learning low or high frequency details."
345
+ " Choices: ['earlier', 'later', 'range', 'none']."
346
+ " The default is 'none', which means no bias is applied, and training proceeds normally."
347
+ " The value of 'later' will increase the frequency of the model's final training timesteps."
348
+ ),
349
+ )
350
+ parser.add_argument(
351
+ "--timestep_bias_multiplier",
352
+ type=float,
353
+ default=1.0,
354
+ help=(
355
+ "The multiplier for the bias. Defaults to 1.0, which means no bias is applied."
356
+ " A value of 2.0 will double the weight of the bias, and a value of 0.5 will halve it."
357
+ ),
358
+ )
359
+ parser.add_argument(
360
+ "--timestep_bias_begin",
361
+ type=int,
362
+ default=0,
363
+ help=(
364
+ "When using `--timestep_bias_strategy=range`, the beginning (inclusive) timestep to bias."
365
+ " Defaults to zero, which equates to having no specific bias."
366
+ ),
367
+ )
368
+ parser.add_argument(
369
+ "--timestep_bias_end",
370
+ type=int,
371
+ default=1000,
372
+ help=(
373
+ "When using `--timestep_bias_strategy=range`, the final timestep (inclusive) to bias."
374
+ " Defaults to 1000, which is the number of timesteps that Stable Diffusion is trained on."
375
+ ),
376
+ )
377
+ parser.add_argument(
378
+ "--timestep_bias_portion",
379
+ type=float,
380
+ default=0.25,
381
+ help=(
382
+ "The portion of timesteps to bias. Defaults to 0.25, which means 25% of timesteps will be biased."
383
+ " A value of 0.5 will bias one half of the timesteps. The value provided for `--timestep_bias_strategy` determines"
384
+ " whether the biased portions are in the earlier or later timesteps."
385
+ ),
386
+ )
387
+ parser.add_argument(
388
+ "--snr_gamma",
389
+ type=float,
390
+ default=None,
391
+ help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
392
+ "More details here: https://arxiv.org/abs/2303.09556.",
393
+ )
394
+ parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
395
+ parser.add_argument(
396
+ "--allow_tf32",
397
+ action="store_true",
398
+ help=(
399
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
400
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
401
+ ),
402
+ )
403
+ parser.add_argument(
404
+ "--dataloader_num_workers",
405
+ type=int,
406
+ default=0,
407
+ help=(
408
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
409
+ ),
410
+ )
411
+ parser.add_argument(
412
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
413
+ )
414
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
415
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
416
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
417
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
418
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
419
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
420
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
421
+ parser.add_argument(
422
+ "--prediction_type",
423
+ type=str,
424
+ default=None,
425
+ help="The prediction_type to use for training. Choose between 'epsilon' and 'v_prediction', or leave it as `None`. If `None`, the scheduler's default prediction type (`noise_scheduler.config.prediction_type`) is used.",
426
+ )
427
+ parser.add_argument(
428
+ "--hub_model_id",
429
+ type=str,
430
+ default=None,
431
+ help="The name of the repository to keep in sync with the local `output_dir`.",
432
+ )
433
+ parser.add_argument(
434
+ "--logging_dir",
435
+ type=str,
436
+ default="logs",
437
+ help=(
438
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
439
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
440
+ ),
441
+ )
442
+ parser.add_argument(
443
+ "--report_to",
444
+ type=str,
445
+ default="tensorboard",
446
+ help=(
447
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
448
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
449
+ ),
450
+ )
451
+ parser.add_argument(
452
+ "--mixed_precision",
453
+ type=str,
454
+ default=None,
455
+ choices=["no", "fp16", "bf16"],
456
+ help=(
457
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
458
+ " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
459
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
460
+ ),
461
+ )
462
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
463
+ parser.add_argument(
464
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
465
+ )
466
+ parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
467
+
468
+ if input_args is not None:
469
+ args = parser.parse_args(input_args)
470
+ else:
471
+ args = parser.parse_args()
472
+
473
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
474
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
475
+ args.local_rank = env_local_rank
476
+
477
+ # Sanity checks
478
+ if args.dataset_name is None and args.train_data_dir is None:
479
+ raise ValueError("Need either a dataset name or a training folder.")
480
+
481
+ if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
482
+ raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].")
483
+
484
+ return args
485
+
486
+
487
+ # Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
488
+ def encode_prompt(batch, text_encoders, tokenizers, proportion_empty_prompts, caption_column, is_train=True):
489
+ prompt_embeds_list = []
490
+ prompt_batch = batch[caption_column]
491
+
492
+ captions = []
493
+ for caption in prompt_batch:
494
+ if random.random() < proportion_empty_prompts:
495
+ captions.append("")
496
+ elif isinstance(caption, str):
497
+ captions.append(caption)
498
+ elif isinstance(caption, (list, np.ndarray)):
499
+ # take a random caption if there are multiple
500
+ captions.append(random.choice(caption) if is_train else caption[0])
501
+
502
+ with torch.no_grad():
503
+ for tokenizer, text_encoder in zip(tokenizers, text_encoders):
504
+ text_inputs = tokenizer(
505
+ captions,
506
+ padding="max_length",
507
+ max_length=tokenizer.model_max_length,
508
+ truncation=True,
509
+ return_tensors="pt",
510
+ )
511
+ text_input_ids = text_inputs.input_ids
512
+ prompt_embeds = text_encoder(
513
+ text_input_ids.to(text_encoder.device),
514
+ output_hidden_states=True,
515
+ return_dict=False,
516
+ )
517
+
518
+ # We are only interested in the pooled output of the final text encoder
519
+ pooled_prompt_embeds = prompt_embeds[0]
520
+ prompt_embeds = prompt_embeds[-1][-2]
521
+ bs_embed, seq_len, _ = prompt_embeds.shape
522
+ prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
523
+ prompt_embeds_list.append(prompt_embeds)
524
+
525
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
526
+ pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
527
+ return {"prompt_embeds": prompt_embeds.cpu(), "pooled_prompt_embeds": pooled_prompt_embeds.cpu()}
528
+
529
+
530
+ def compute_vae_encodings(batch, vae):
531
+ images = batch.pop("pixel_values")
532
+ pixel_values = torch.stack(list(images))
533
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
534
+ pixel_values = pixel_values.to(vae.device, dtype=vae.dtype)
535
+
536
+ with torch.no_grad():
537
+ model_input = vae.encode(pixel_values).latent_dist.sample()
538
+ model_input = model_input * vae.config.scaling_factor
539
+ return {"model_input": model_input.cpu()}
540
+
541
+
542
+ def generate_timestep_weights(args, num_timesteps):
543
+ weights = torch.ones(num_timesteps)
544
+
545
+ # Determine the indices to bias
546
+ num_to_bias = int(args.timestep_bias_portion * num_timesteps)
547
+
548
+ if args.timestep_bias_strategy == "later":
549
+ bias_indices = slice(-num_to_bias, None)
550
+ elif args.timestep_bias_strategy == "earlier":
551
+ bias_indices = slice(0, num_to_bias)
552
+ elif args.timestep_bias_strategy == "range":
553
+ # Out of the possible 1000 timesteps, we might want to focus on e.g. 200-500.
554
+ range_begin = args.timestep_bias_begin
555
+ range_end = args.timestep_bias_end
556
+ if range_begin < 0:
557
+ raise ValueError(
558
+ "When using the range strategy for timestep bias, you must provide a beginning timestep greater or equal to zero."
559
+ )
560
+ if range_end > num_timesteps:
561
+ raise ValueError(
562
+ "When using the range strategy for timestep bias, you must provide an ending timestep smaller than the number of timesteps."
563
+ )
564
+ bias_indices = slice(range_begin, range_end)
565
+ else: # 'none' or any other string
566
+ return weights
567
+ if args.timestep_bias_multiplier <= 0:
568
+ raise ValueError(
569
+ "The parameter --timestep_bias_multiplier is not intended to be used to disable the training of specific timesteps."
570
+ " If it was intended to disable timestep bias, use `--timestep_bias_strategy none` instead."
571
+ " A timestep bias multiplier less than or equal to 0 is not allowed."
572
+ )
573
+
574
+ # Apply the bias
575
+ weights[bias_indices] *= args.timestep_bias_multiplier
576
+
577
+ # Normalize
578
+ weights /= weights.sum()
579
+
580
+ return weights
581
+
582
+
583
+ def main(args):
584
+ if args.report_to == "wandb" and args.hub_token is not None:
585
+ raise ValueError(
586
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
587
+ " Please use `huggingface-cli login` to authenticate with the Hub."
588
+ )
589
+
590
+ logging_dir = Path(args.output_dir, args.logging_dir)
591
+
592
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
593
+
594
+ if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
595
+ # due to pytorch#99272, MPS does not yet support bfloat16.
596
+ raise ValueError(
597
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
598
+ )
599
+
600
+ accelerator = Accelerator(
601
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
602
+ mixed_precision=args.mixed_precision,
603
+ log_with=args.report_to,
604
+ project_config=accelerator_project_config,
605
+ )
606
+
607
+ # Disable AMP for MPS.
608
+ if torch.backends.mps.is_available():
609
+ accelerator.native_amp = False
610
+
611
+ if args.report_to == "wandb":
612
+ if not is_wandb_available():
613
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
614
+ import wandb
615
+
616
+ # Make one log on every process with the configuration for debugging.
617
+ logging.basicConfig(
618
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
619
+ datefmt="%m/%d/%Y %H:%M:%S",
620
+ level=logging.INFO,
621
+ )
622
+ logger.info(accelerator.state, main_process_only=False)
623
+ if accelerator.is_local_main_process:
624
+ datasets.utils.logging.set_verbosity_warning()
625
+ transformers.utils.logging.set_verbosity_warning()
626
+ diffusers.utils.logging.set_verbosity_info()
627
+ else:
628
+ datasets.utils.logging.set_verbosity_error()
629
+ transformers.utils.logging.set_verbosity_error()
630
+ diffusers.utils.logging.set_verbosity_error()
631
+
632
+ # If passed along, set the training seed now.
633
+ if args.seed is not None:
634
+ set_seed(args.seed)
635
+
636
+ # Handle the repository creation
637
+ if accelerator.is_main_process:
638
+ if args.output_dir is not None:
639
+ os.makedirs(args.output_dir, exist_ok=True)
640
+
641
+ if args.push_to_hub:
642
+ repo_id = create_repo(
643
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
644
+ ).repo_id
645
+
646
+ # Load the tokenizers
647
+ tokenizer_one = AutoTokenizer.from_pretrained(
648
+ args.pretrained_model_name_or_path,
649
+ subfolder="tokenizer",
650
+ revision=args.revision,
651
+ use_fast=False,
652
+ )
653
+ tokenizer_two = AutoTokenizer.from_pretrained(
654
+ args.pretrained_model_name_or_path,
655
+ subfolder="tokenizer_2",
656
+ revision=args.revision,
657
+ use_fast=False,
658
+ )
659
+
660
+ # import correct text encoder classes
661
+ text_encoder_cls_one = import_model_class_from_model_name_or_path(
662
+ args.pretrained_model_name_or_path, args.revision
663
+ )
664
+ text_encoder_cls_two = import_model_class_from_model_name_or_path(
665
+ args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
666
+ )
667
+
668
+ # Load scheduler and models
669
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
670
+ # Check for terminal SNR in combination with SNR Gamma
671
+ text_encoder_one = text_encoder_cls_one.from_pretrained(
672
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
673
+ )
674
+ text_encoder_two = text_encoder_cls_two.from_pretrained(
675
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
676
+ )
677
+ vae_path = (
678
+ args.pretrained_model_name_or_path
679
+ if args.pretrained_vae_model_name_or_path is None
680
+ else args.pretrained_vae_model_name_or_path
681
+ )
682
+ vae = AutoencoderKL.from_pretrained(
683
+ vae_path,
684
+ subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
685
+ revision=args.revision,
686
+ variant=args.variant,
687
+ )
688
+ unet = UNet2DConditionModel.from_pretrained(
689
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
690
+ )
691
+
692
+ # Freeze vae and text encoders.
693
+ vae.requires_grad_(False)
694
+ text_encoder_one.requires_grad_(False)
695
+ text_encoder_two.requires_grad_(False)
696
+ # Set unet as trainable.
697
+ unet.train()
698
+
699
+ # For mixed precision training we cast all non-trainable weights to half-precision
700
+ # as these weights are only used for inference, keeping weights in full precision is not required.
701
+ weight_dtype = torch.float32
702
+ if accelerator.mixed_precision == "fp16":
703
+ weight_dtype = torch.float16
704
+ elif accelerator.mixed_precision == "bf16":
705
+ weight_dtype = torch.bfloat16
706
+
707
+ # Move unet, vae and text_encoder to device and cast to weight_dtype
708
+ # The VAE is in float32 to avoid NaN losses.
709
+ vae.to(accelerator.device, dtype=torch.float32)
710
+ text_encoder_one.to(accelerator.device, dtype=weight_dtype)
711
+ text_encoder_two.to(accelerator.device, dtype=weight_dtype)
712
+
713
+ # Create EMA for the unet.
714
+ if args.use_ema:
715
+ ema_unet = UNet2DConditionModel.from_pretrained(
716
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
717
+ )
718
+ ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config)
719
+
720
+ if args.enable_xformers_memory_efficient_attention:
721
+ if is_xformers_available():
722
+ import xformers
723
+
724
+ xformers_version = version.parse(xformers.__version__)
725
+ if xformers_version == version.parse("0.0.16"):
726
+ logger.warning(
727
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
728
+ )
729
+ unet.enable_xformers_memory_efficient_attention()
730
+ else:
731
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
732
+
733
+ # `accelerate` 0.16.0 will have better support for customized saving
734
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
735
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
736
+ def save_model_hook(models, weights, output_dir):
737
+ if accelerator.is_main_process:
738
+ if args.use_ema:
739
+ ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
740
+
741
+ for i, model in enumerate(models):
742
+ model.save_pretrained(os.path.join(output_dir, "unet"))
743
+
744
+ # make sure to pop the weight so that the corresponding model is not saved again
745
+ weights.pop()
746
+
747
+ def load_model_hook(models, input_dir):
748
+ if args.use_ema:
749
+ load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel)
750
+ ema_unet.load_state_dict(load_model.state_dict())
751
+ ema_unet.to(accelerator.device)
752
+ del load_model
753
+
754
+ for _ in range(len(models)):
755
+ # pop models so that they are not loaded again
756
+ model = models.pop()
757
+
758
+ # load diffusers style into model
759
+ load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
760
+ model.register_to_config(**load_model.config)
761
+
762
+ model.load_state_dict(load_model.state_dict())
763
+ del load_model
764
+
765
+ accelerator.register_save_state_pre_hook(save_model_hook)
766
+ accelerator.register_load_state_pre_hook(load_model_hook)
767
+
768
+ if args.gradient_checkpointing:
769
+ unet.enable_gradient_checkpointing()
770
+
771
+ # Enable TF32 for faster training on Ampere GPUs,
772
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
773
+ if args.allow_tf32:
774
+ torch.backends.cuda.matmul.allow_tf32 = True
775
+
776
+ if args.scale_lr:
777
+ args.learning_rate = (
778
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
779
+ )
780
+
781
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
782
+ if args.use_8bit_adam:
783
+ try:
784
+ import bitsandbytes as bnb
785
+ except ImportError:
786
+ raise ImportError(
787
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
788
+ )
789
+
790
+ optimizer_class = bnb.optim.AdamW8bit
791
+ else:
792
+ optimizer_class = torch.optim.AdamW
793
+
794
+ # Optimizer creation
795
+ params_to_optimize = unet.parameters()
796
+ optimizer = optimizer_class(
797
+ params_to_optimize,
798
+ lr=args.learning_rate,
799
+ betas=(args.adam_beta1, args.adam_beta2),
800
+ weight_decay=args.adam_weight_decay,
801
+ eps=args.adam_epsilon,
802
+ )
803
+
804
+ # Get the datasets: you can either provide your own training and evaluation files (see below)
805
+ # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
806
+
807
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
808
+ # download the dataset.
809
+ if args.dataset_name is not None:
810
+ # Downloading and loading a dataset from the hub.
811
+ dataset = load_dataset(
812
+ args.dataset_name,
813
+ args.dataset_config_name,
814
+ cache_dir=args.cache_dir,
815
+ )
816
+ else:
817
+ data_files = {}
818
+ if args.train_data_dir is not None:
819
+ data_files["train"] = os.path.join(args.train_data_dir, "**")
820
+ dataset = load_dataset(
821
+ "imagefolder",
822
+ data_files=data_files,
823
+ cache_dir=args.cache_dir,
824
+ )
825
+ # See more about loading custom images at
826
+ # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
827
+
828
+ # Preprocessing the datasets.
829
+ # We need to tokenize inputs and targets.
830
+ column_names = dataset["train"].column_names
831
+
832
+ # 6. Get the column names for input/target.
833
+ dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
834
+ if args.image_column is None:
835
+ image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
836
+ else:
837
+ image_column = args.image_column
838
+ if image_column not in column_names:
839
+ raise ValueError(
840
+ f"'--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
841
+ )
842
+ if args.caption_column is None:
843
+ caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
844
+ else:
845
+ caption_column = args.caption_column
846
+ if caption_column not in column_names:
847
+ raise ValueError(
848
+ f"'--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
849
+ )
850
+
851
+ # Preprocessing the datasets.
852
+ train_resize = transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR)
853
+ train_crop = transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution)
854
+ train_flip = transforms.RandomHorizontalFlip(p=1.0)
855
+ train_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
856
+
857
+ def preprocess_train(examples):
858
+ images = [image.convert("RGB") for image in examples[image_column]]
859
+ # image aug
860
+ original_sizes = []
861
+ all_images = []
862
+ crop_top_lefts = []
863
+ for image in images:
864
+ original_sizes.append((image.height, image.width))
865
+ image = train_resize(image)
866
+ if args.random_flip and random.random() < 0.5:
867
+ # flip
868
+ image = train_flip(image)
869
+ if args.center_crop:
870
+ y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
871
+ x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
872
+ image = train_crop(image)
873
+ else:
874
+ y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
875
+ image = crop(image, y1, x1, h, w)
876
+ crop_top_left = (y1, x1)
877
+ crop_top_lefts.append(crop_top_left)
878
+ image = train_transforms(image)
879
+ all_images.append(image)
880
+
881
+ examples["original_sizes"] = original_sizes
882
+ examples["crop_top_lefts"] = crop_top_lefts
883
+ examples["pixel_values"] = all_images
884
+ return examples
885
+
886
+ with accelerator.main_process_first():
887
+ if args.max_train_samples is not None:
888
+ dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
889
+ # Set the training transforms
890
+ train_dataset = dataset["train"].with_transform(preprocess_train)
891
+
892
+ # Let's first compute all the embeddings so that we can free up the text encoders
893
+ # from memory. We will pre-compute the VAE encodings too.
894
+ text_encoders = [text_encoder_one, text_encoder_two]
895
+ tokenizers = [tokenizer_one, tokenizer_two]
896
+ compute_embeddings_fn = functools.partial(
897
+ encode_prompt,
898
+ text_encoders=text_encoders,
899
+ tokenizers=tokenizers,
900
+ proportion_empty_prompts=args.proportion_empty_prompts,
901
+ caption_column=args.caption_column,
902
+ )
903
+ compute_vae_encodings_fn = functools.partial(compute_vae_encodings, vae=vae)
904
+ with accelerator.main_process_first():
905
+ from datasets.fingerprint import Hasher
906
+
907
+ # fingerprint used by the cache for the other processes to load the result
908
+ # details: https://github.com/huggingface/diffusers/pull/4038#discussion_r1266078401
909
+ new_fingerprint = Hasher.hash(args)
910
+ new_fingerprint_for_vae = Hasher.hash(vae_path)
911
+ train_dataset_with_embeddings = train_dataset.map(
912
+ compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint
913
+ )
914
+ train_dataset_with_vae = train_dataset.map(
915
+ compute_vae_encodings_fn,
916
+ batched=True,
917
+ batch_size=args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps,
918
+ new_fingerprint=new_fingerprint_for_vae,
919
+ )
920
+ precomputed_dataset = concatenate_datasets(
921
+ [train_dataset_with_embeddings, train_dataset_with_vae.remove_columns(["image", "text"])], axis=1
922
+ )
923
+ precomputed_dataset = precomputed_dataset.with_transform(preprocess_train)
924
+
925
+ del compute_vae_encodings_fn, compute_embeddings_fn, text_encoder_one, text_encoder_two
926
+ del text_encoders, tokenizers, vae
927
+ gc.collect()
928
+ torch.cuda.empty_cache()
929
+
930
+ def collate_fn(examples):
931
+ model_input = torch.stack([torch.tensor(example["model_input"]) for example in examples])
932
+ original_sizes = [example["original_sizes"] for example in examples]
933
+ crop_top_lefts = [example["crop_top_lefts"] for example in examples]
934
+ prompt_embeds = torch.stack([torch.tensor(example["prompt_embeds"]) for example in examples])
935
+ pooled_prompt_embeds = torch.stack([torch.tensor(example["pooled_prompt_embeds"]) for example in examples])
936
+
937
+ return {
938
+ "model_input": model_input,
939
+ "prompt_embeds": prompt_embeds,
940
+ "pooled_prompt_embeds": pooled_prompt_embeds,
941
+ "original_sizes": original_sizes,
942
+ "crop_top_lefts": crop_top_lefts,
943
+ }
944
+
945
+ # DataLoaders creation:
946
+ train_dataloader = torch.utils.data.DataLoader(
947
+ precomputed_dataset,
948
+ shuffle=True,
949
+ collate_fn=collate_fn,
950
+ batch_size=args.train_batch_size,
951
+ num_workers=args.dataloader_num_workers,
952
+ )
953
+
954
+ # Scheduler and math around the number of training steps.
955
+ overrode_max_train_steps = False
956
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
957
+ if args.max_train_steps is None:
958
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
959
+ overrode_max_train_steps = True
960
+
961
+ lr_scheduler = get_scheduler(
962
+ args.lr_scheduler,
963
+ optimizer=optimizer,
964
+ num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
965
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
966
+ )
967
+
968
+ # Prepare everything with our `accelerator`.
969
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
970
+ unet, optimizer, train_dataloader, lr_scheduler
971
+ )
972
+
973
+ if args.use_ema:
974
+ ema_unet.to(accelerator.device)
975
+
976
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
977
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
978
+ if overrode_max_train_steps:
979
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
980
+ # Afterwards we recalculate our number of training epochs
981
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
982
+
983
+ # We need to initialize the trackers we use, and also store our configuration.
984
+ # The trackers initialize automatically on the main process.
985
+ if accelerator.is_main_process:
986
+ accelerator.init_trackers("text2image-fine-tune-sdxl", config=vars(args))
987
+
988
+ # Function for unwrapping if torch.compile() was used in accelerate.
989
+ def unwrap_model(model):
990
+ model = accelerator.unwrap_model(model)
991
+ model = model._orig_mod if is_compiled_module(model) else model
992
+ return model
993
+
994
+ if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path:
995
+ autocast_ctx = nullcontext()
996
+ else:
997
+ autocast_ctx = torch.autocast(accelerator.device.type)
998
+
999
+ # Train!
1000
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
1001
+
1002
+ logger.info("***** Running training *****")
1003
+ logger.info(f" Num examples = {len(precomputed_dataset)}")
1004
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
1005
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
1006
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
1007
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
1008
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
1009
+ global_step = 0
1010
+ first_epoch = 0
1011
+
1012
+ # Potentially load in the weights and states from a previous save
1013
+ if args.resume_from_checkpoint:
1014
+ if args.resume_from_checkpoint != "latest":
1015
+ path = os.path.basename(args.resume_from_checkpoint)
1016
+ else:
1017
+ # Get the most recent checkpoint
1018
+ dirs = os.listdir(args.output_dir)
1019
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
1020
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
1021
+ path = dirs[-1] if len(dirs) > 0 else None
1022
+
1023
+ if path is None:
1024
+ accelerator.print(
1025
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
1026
+ )
1027
+ args.resume_from_checkpoint = None
1028
+ initial_global_step = 0
1029
+ else:
1030
+ accelerator.print(f"Resuming from checkpoint {path}")
1031
+ accelerator.load_state(os.path.join(args.output_dir, path))
1032
+ global_step = int(path.split("-")[1])
1033
+
1034
+ initial_global_step = global_step
1035
+ first_epoch = global_step // num_update_steps_per_epoch
1036
+
1037
+ else:
1038
+ initial_global_step = 0
1039
+
1040
+ progress_bar = tqdm(
1041
+ range(0, args.max_train_steps),
1042
+ initial=initial_global_step,
1043
+ desc="Steps",
1044
+ # Only show the progress bar once on each machine.
1045
+ disable=not accelerator.is_local_main_process,
1046
+ )
1047
+
1048
+ for epoch in range(first_epoch, args.num_train_epochs):
1049
+ train_loss = 0.0
1050
+ for step, batch in enumerate(train_dataloader):
1051
+ with accelerator.accumulate(unet):
1052
+ # Sample noise that we'll add to the latents
1053
+ model_input = batch["model_input"].to(accelerator.device)
1054
+ noise = torch.randn_like(model_input)
1055
+ if args.noise_offset:
1056
+ # https://www.crosslabs.org//blog/diffusion-with-offset-noise
1057
+ noise += args.noise_offset * torch.randn(
1058
+ (model_input.shape[0], model_input.shape[1], 1, 1), device=model_input.device
1059
+ )
1060
+
1061
+ bsz = model_input.shape[0]
1062
+ if args.timestep_bias_strategy == "none":
1063
+ # Sample a random timestep for each image without bias.
1064
+ timesteps = torch.randint(
1065
+ 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
1066
+ )
1067
+ else:
1068
+ # Sample a random timestep for each image, potentially biased by the timestep weights.
1069
+ # Biasing the timestep weights allows us to spend less time training irrelevant timesteps.
1070
+ weights = generate_timestep_weights(args, noise_scheduler.config.num_train_timesteps).to(
1071
+ model_input.device
1072
+ )
1073
+ timesteps = torch.multinomial(weights, bsz, replacement=True).long()
1074
+
1075
+ # Add noise to the model input according to the noise magnitude at each timestep
1076
+ # (this is the forward diffusion process)
1077
+ noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
1078
+
1079
+ # time ids
1080
+ def compute_time_ids(original_size, crops_coords_top_left):
1081
+ # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
1082
+ target_size = (args.resolution, args.resolution)
1083
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
1084
+ add_time_ids = torch.tensor([add_time_ids])
1085
+ add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
1086
+ return add_time_ids
1087
+
1088
+ add_time_ids = torch.cat(
1089
+ [compute_time_ids(s, c) for s, c in zip(batch["original_sizes"], batch["crop_top_lefts"])]
1090
+ )
1091
+
1092
+ # Predict the noise residual
1093
+ unet_added_conditions = {"time_ids": add_time_ids}
1094
+ prompt_embeds = batch["prompt_embeds"].to(accelerator.device)
1095
+ pooled_prompt_embeds = batch["pooled_prompt_embeds"].to(accelerator.device)
1096
+ unet_added_conditions.update({"text_embeds": pooled_prompt_embeds})
1097
+ model_pred = unet(
1098
+ noisy_model_input,
1099
+ timesteps,
1100
+ prompt_embeds,
1101
+ added_cond_kwargs=unet_added_conditions,
1102
+ return_dict=False,
1103
+ )[0]
1104
+
1105
+ # Get the target for loss depending on the prediction type
1106
+ if args.prediction_type is not None:
1107
+ # set prediction_type of scheduler if defined
1108
+ noise_scheduler.register_to_config(prediction_type=args.prediction_type)
1109
+
1110
+ if noise_scheduler.config.prediction_type == "epsilon":
1111
+ target = noise
1112
+ elif noise_scheduler.config.prediction_type == "v_prediction":
1113
+ target = noise_scheduler.get_velocity(model_input, noise, timesteps)
1114
+ elif noise_scheduler.config.prediction_type == "sample":
1115
+ # We set the target to latents here, but the model_pred will return the noise sample prediction.
1116
+ target = model_input
1117
+ # We will have to subtract the noise residual from the prediction to get the target sample.
1118
+ model_pred = model_pred - noise
1119
+ else:
1120
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
1121
+
1122
+ if args.snr_gamma is None:
1123
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
1124
+ else:
1125
+ # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
1126
+ # Since we predict the noise instead of x_0, the original formulation is slightly changed.
1127
+ # This is discussed in Section 4.2 of the same paper.
1128
+ snr = compute_snr(noise_scheduler, timesteps)
1129
+ mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
1130
+ dim=1
1131
+ )[0]
1132
+ if noise_scheduler.config.prediction_type == "epsilon":
1133
+ mse_loss_weights = mse_loss_weights / snr
1134
+ elif noise_scheduler.config.prediction_type == "v_prediction":
1135
+ mse_loss_weights = mse_loss_weights / (snr + 1)
1136
+
1137
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
1138
+ loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
1139
+ loss = loss.mean()
1140
+
1141
+ # Gather the losses across all processes for logging (if we use distributed training).
1142
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
1143
+ train_loss += avg_loss.item() / args.gradient_accumulation_steps
1144
+
1145
+ # Backpropagate
1146
+ accelerator.backward(loss)
1147
+ if accelerator.sync_gradients:
1148
+ params_to_clip = unet.parameters()
1149
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
1150
+ optimizer.step()
1151
+ lr_scheduler.step()
1152
+ optimizer.zero_grad()
1153
+
1154
+ # Checks if the accelerator has performed an optimization step behind the scenes
1155
+ if accelerator.sync_gradients:
1156
+ if args.use_ema:
1157
+ ema_unet.step(unet.parameters())
1158
+ progress_bar.update(1)
1159
+ global_step += 1
1160
+ accelerator.log({"train_loss": train_loss}, step=global_step)
1161
+ train_loss = 0.0
1162
+
1163
+ if accelerator.is_main_process:
1164
+ if global_step % args.checkpointing_steps == 0:
1165
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
1166
+ if args.checkpoints_total_limit is not None:
1167
+ checkpoints = os.listdir(args.output_dir)
1168
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
1169
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
1170
+
1171
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
1172
+ if len(checkpoints) >= args.checkpoints_total_limit:
1173
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
1174
+ removing_checkpoints = checkpoints[0:num_to_remove]
1175
+
1176
+ logger.info(
1177
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
1178
+ )
1179
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
1180
+
1181
+ for removing_checkpoint in removing_checkpoints:
1182
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
1183
+ shutil.rmtree(removing_checkpoint)
1184
+
1185
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
1186
+ accelerator.save_state(save_path)
1187
+ logger.info(f"Saved state to {save_path}")
1188
+
1189
+ logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
1190
+ progress_bar.set_postfix(**logs)
1191
+
1192
+ if global_step >= args.max_train_steps:
1193
+ break
1194
+
1195
+ if accelerator.is_main_process:
1196
+ if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
1197
+ logger.info(
1198
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
1199
+ f" {args.validation_prompt}."
1200
+ )
1201
+ if args.use_ema:
1202
+ # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
1203
+ ema_unet.store(unet.parameters())
1204
+ ema_unet.copy_to(unet.parameters())
1205
+
1206
+ # create pipeline
1207
+ vae = AutoencoderKL.from_pretrained(
1208
+ vae_path,
1209
+ subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
1210
+ revision=args.revision,
1211
+ variant=args.variant,
1212
+ )
1213
+ pipeline = StableDiffusionXLPipeline.from_pretrained(
1214
+ args.pretrained_model_name_or_path,
1215
+ vae=vae,
1216
+ unet=accelerator.unwrap_model(unet),
1217
+ revision=args.revision,
1218
+ variant=args.variant,
1219
+ torch_dtype=weight_dtype,
1220
+ )
1221
+ if args.prediction_type is not None:
1222
+ scheduler_args = {"prediction_type": args.prediction_type}
1223
+ pipeline.scheduler = pipeline.scheduler.from_config(pipeline.scheduler.config, **scheduler_args)
1224
+
1225
+ pipeline = pipeline.to(accelerator.device)
1226
+ pipeline.set_progress_bar_config(disable=True)
1227
+
1228
+ # run inference
1229
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
1230
+ pipeline_args = {"prompt": args.validation_prompt}
1231
+
1232
+ with autocast_ctx:
1233
+ images = [
1234
+ pipeline(**pipeline_args, generator=generator, num_inference_steps=25).images[0]
1235
+ for _ in range(args.num_validation_images)
1236
+ ]
1237
+
1238
+ for tracker in accelerator.trackers:
1239
+ if tracker.name == "tensorboard":
1240
+ np_images = np.stack([np.asarray(img) for img in images])
1241
+ tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
1242
+ if tracker.name == "wandb":
1243
+ tracker.log(
1244
+ {
1245
+ "validation": [
1246
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
1247
+ for i, image in enumerate(images)
1248
+ ]
1249
+ }
1250
+ )
1251
+
1252
+ del pipeline
1253
+ torch.cuda.empty_cache()
1254
+
1255
+ if args.use_ema:
1256
+ # Switch back to the original UNet parameters.
1257
+ ema_unet.restore(unet.parameters())
1258
+
1259
+ accelerator.wait_for_everyone()
1260
+ if accelerator.is_main_process:
1261
+ unet = unwrap_model(unet)
1262
+ if args.use_ema:
1263
+ ema_unet.copy_to(unet.parameters())
1264
+
1265
+ # Serialize pipeline.
1266
+ vae = AutoencoderKL.from_pretrained(
1267
+ vae_path,
1268
+ subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
1269
+ revision=args.revision,
1270
+ variant=args.variant,
1271
+ torch_dtype=weight_dtype,
1272
+ )
1273
+ pipeline = StableDiffusionXLPipeline.from_pretrained(
1274
+ args.pretrained_model_name_or_path,
1275
+ unet=unet,
1276
+ vae=vae,
1277
+ revision=args.revision,
1278
+ variant=args.variant,
1279
+ torch_dtype=weight_dtype,
1280
+ )
1281
+ if args.prediction_type is not None:
1282
+ scheduler_args = {"prediction_type": args.prediction_type}
1283
+ pipeline.scheduler = pipeline.scheduler.from_config(pipeline.scheduler.config, **scheduler_args)
1284
+ pipeline.save_pretrained(args.output_dir)
1285
+
1286
+ # run inference
1287
+ images = []
1288
+ if args.validation_prompt and args.num_validation_images > 0:
1289
+ pipeline = pipeline.to(accelerator.device)
1290
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
1291
+
1292
+ with autocast_ctx:
1293
+ images = [
1294
+ pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
1295
+ for _ in range(args.num_validation_images)
1296
+ ]
1297
+
1298
+ for tracker in accelerator.trackers:
1299
+ if tracker.name == "tensorboard":
1300
+ np_images = np.stack([np.asarray(img) for img in images])
1301
+ tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
1302
+ if tracker.name == "wandb":
1303
+ tracker.log(
1304
+ {
1305
+ "test": [
1306
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
1307
+ for i, image in enumerate(images)
1308
+ ]
1309
+ }
1310
+ )
1311
+
1312
+ if args.push_to_hub:
1313
+ save_model_card(
1314
+ repo_id=repo_id,
1315
+ images=images,
1316
+ validation_prompt=args.validation_prompt,
1317
+ base_model=args.pretrained_model_name_or_path,
1318
+ dataset_name=args.dataset_name,
1319
+ repo_folder=args.output_dir,
1320
+ vae_path=args.pretrained_vae_model_name_or_path,
1321
+ )
1322
+ upload_folder(
1323
+ repo_id=repo_id,
1324
+ folder_path=args.output_dir,
1325
+ commit_message="End of training",
1326
+ ignore_patterns=["step_*", "epoch_*"],
1327
+ )
1328
+
1329
+ accelerator.end_training()
1330
+
1331
+
1332
+ if __name__ == "__main__":
1333
+ args = parse_args()
1334
+ main(args)
1335
+ raise RuntimeError("The script is finished.")
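
For reference, a minimal launch sketch (not part of the commit): the script is normally started with `accelerate launch train.py ...`, and the programmatic equivalent below simply calls parse_args/main directly. It uses only flags defined in the script; the base model, dataset, and step counts are placeholder choices rather than values taken from this commit.

from train import parse_args, main

args = parse_args([
    "--pretrained_model_name_or_path", "stabilityai/stable-diffusion-xl-base-1.0",  # placeholder base model
    "--dataset_name", "lambdalabs/pokemon-blip-captions",  # example dataset from DATASET_NAME_MAPPING
    "--resolution", "1024",
    "--center_crop",
    "--random_flip",
    "--train_batch_size", "1",
    "--gradient_accumulation_steps", "4",
    "--max_train_steps", "100",  # deliberately small, for a smoke test
    "--learning_rate", "1e-5",
    "--mixed_precision", "fp16",
    "--output_dir", "sdxl-model-finetuned",
])
main(args)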
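
After training, the full SDXL pipeline is serialized to `--output_dir` (and uploaded when `--push_to_hub` is set). A minimal inference sketch, assuming the default output directory and a CUDA device; the prompt is a placeholder:

import torch
from diffusers import StableDiffusionXLPipeline

# Load the pipeline saved by train.py (default --output_dir, or a Hub repo id if it was pushed).
pipeline = StableDiffusionXLPipeline.from_pretrained(
    "sdxl-model-finetuned",
    torch_dtype=torch.float16,
).to("cuda")

image = pipeline("a sample prompt in the style of the finetuning dataset", num_inference_steps=25).images[0]
image.save("sample.png")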