Klayand commited on
Commit
54dd802
·
1 Parent(s): ef33138

first version of CoRe^2

Browse files
.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
readme.md ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # The Official Implementation of our Arxiv 2025 paper:
2
+
3
+ > **[CoRe^2: _Collect, Reflect and Refine_ to Generate Better and Faster](https://arxiv.org/abs/2503.09662)** <br>
4
+
5
+ Authors:
6
+
7
+ >**<em>Shitong Shao, Zikai Zhou, Dian Xie, Yuetong Fang, Tian Ye, Lichen Bai</em> and <em>Zeke Xie*</em>** <br>
8
+ > xLeaf Lab, HKUST (GZ) <br>
9
+ > *: Corresponding author
10
+
11
+ ## New
12
+
13
+ - [x] Release the inference code of SD3.5 and SDXL.
14
+
15
+ - [ ] Release the inference code of FLUX.
16
+
17
+ - [ ] Release the inference code of LlamaGen.
18
+
19
+ - [ ] Release the implementation of the Collect phase.
20
+
21
+ - [ ] Release the implementation of the Reflect phase.
22
+
23
+
24
+ ## Overview
25
+
26
+ This guide provides instructions on how to use the CoRe^2.
27
+
28
+ Here we provide the inference code which supports different models like ***Stable Diffusion XL, Stable Diffusion 3.5 Large.***
29
+
30
+ ## Requirements
31
+
32
+ - `python version == 3.8`
33
+ - `pytorch with cuda version`
34
+ - `diffusers`
35
+ - `PIL`
36
+ - `bitsandbytes`
37
+ - `numpy`
38
+ - `timm`
39
+ - `argparse`
40
+ - `einops`
41
+
42
+ ## Installation🚀️
43
+
44
+ Make sure you have successfully built `python` environment and installed `pytorch` with cuda version. Before running the script, ensure you have all the required packages installed. You can install them using:
45
+
46
+ ```bash
47
+ pip install diffusers, PIL, numpy, timm, argparse, einops
48
+ ```
49
+
50
+ ## Usage👀️
51
+
52
+ To use the CoRe^2 pipeline, you need to run the `sample_img.py` script with appropriate command-line arguments. Below are the available options:
53
+
54
+ ### Command-Line Arguments
55
+
56
+ - `--pipeline`: Select the model pipeline (`sdxl`, `sd35`). Default is `sdxl`.
57
+ - `--prompt`: The textual prompt based on which the image will be generated. Default is "Mickey Mouse painting by Frank Frazetta."
58
+ - `--inference-step`: Number of inference steps for the diffusion process. Default is 50.
59
+ - `--cfg`: Classifier-free guidance scale. Default is 5.5.
60
+ - `--pretrained-path`: Path to the pretrained model weights. Default is a specified path in the script.
61
+ - `--size`: The size (height and width) of the generated image. Default is 1024.
62
+ - `--method`: Select the inference method (`standard`, `core`, `zigzag`, `z-core`)
63
+
64
+ ### Running the Script
65
+
66
+ Run the script from the command line by navigating to the directory containing `sample_img.py` and executing:
67
+
68
+ ```
69
+ python sample_img.py --pipeline sdxl --prompt "A banana on the left of an apple." --size 1024
70
+ ```
71
+
72
+ This command will generate an image based on the prompt using the Stable Diffusion XL model with an image size of 1024x1024 pixels.
73
+
74
+ ### Output🎉️
75
+
76
+ The script will save one image:
77
+
78
+ ## Pre-trained Weights Download❤️
79
+
80
+ We provide the pre-trained CoRe^2 weights of Stable Diffusion XL, and Stable Diffusion 3.5 Large with https://drive.google.com/drive/folders/1alJco6X3cFw4oHTD9SifvS7apc3AwG8I?usp=drive_link
81
+
82
+
83
+
sample_img.py CHANGED
@@ -103,14 +103,14 @@ if __name__ == '__main__':
103
  replace_linear_with_lora(refine_model, rank=64, alpha=1.0, number_of_lora=28)
104
  lora_true(refine_model, lora_idx=0)
105
 
106
- checkpoint = torch.load('./weights/sd35_ckpt_v9.pth', map_location='cpu')
107
  refine_model.load_state_dict(checkpoint)
108
  elif args.model == 'sdxl':
109
  refine_model = PromptSDXLNet()
110
  replace_linear_with_lora(refine_model, rank=48, alpha=1.0, number_of_lora=50)
111
  lora_true(refine_model, lora_idx=0)
112
 
113
- checkpoint = torch.load('./weights/sdxl_ckpt_v9.pth', map_location='cpu')
114
  refine_model.load_state_dict(checkpoint)
115
 
116
  print("Load Lora Success")
 
103
  replace_linear_with_lora(refine_model, rank=64, alpha=1.0, number_of_lora=28)
104
  lora_true(refine_model, lora_idx=0)
105
 
106
+ checkpoint = torch.load('./weights/sd35_noise_model.pth', map_location='cpu')
107
  refine_model.load_state_dict(checkpoint)
108
  elif args.model == 'sdxl':
109
  refine_model = PromptSDXLNet()
110
  replace_linear_with_lora(refine_model, rank=48, alpha=1.0, number_of_lora=50)
111
  lora_true(refine_model, lora_idx=0)
112
 
113
+ checkpoint = torch.load('./weights/sdxl_noise_model.pth', map_location='cpu')
114
  refine_model.load_state_dict(checkpoint)
115
 
116
  print("Load Lora Success")
weights/sd35_noise_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6f99a9b437fba4da9c3fb87516c6285bd9bac07f1969a4ba4d631734412edaf2
3
+ size 2881450254
weights/sdxl_noise_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ad620bd3a604908abfe8178e05f34a83434db246cd63f151755f26de14c5f241
3
+ size 2034660755