first version of CoRe^2
Browse files- .gitattributes +35 -0
- readme.md +83 -0
- sample_img.py +2 -2
- weights/sd35_noise_model.pth +3 -0
- weights/sdxl_noise_model.pth +3 -0
.gitattributes
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
readme.md
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# The Official Implementation of our Arxiv 2025 paper:
|
| 2 |
+
|
| 3 |
+
> **[CoRe^2: _Collect, Reflect and Refine_ to Generate Better and Faster](https://arxiv.org/abs/2503.09662)** <br>
|
| 4 |
+
|
| 5 |
+
Authors:
|
| 6 |
+
|
| 7 |
+
>**<em>Shitong Shao, Zikai Zhou, Dian Xie, Yuetong Fang, Tian Ye, Lichen Bai</em> and <em>Zeke Xie*</em>** <br>
|
| 8 |
+
> xLeaf Lab, HKUST (GZ) <br>
|
| 9 |
+
> *: Corresponding author
|
| 10 |
+
|
| 11 |
+
## New
|
| 12 |
+
|
| 13 |
+
- [x] Release the inference code of SD3.5 and SDXL.
|
| 14 |
+
|
| 15 |
+
- [ ] Release the inference code of FLUX.
|
| 16 |
+
|
| 17 |
+
- [ ] Release the inference code of LlamaGen.
|
| 18 |
+
|
| 19 |
+
- [ ] Release the implementation of the Collect phase.
|
| 20 |
+
|
| 21 |
+
- [ ] Release the implementation of the Reflect phase.
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
## Overview
|
| 25 |
+
|
| 26 |
+
This guide provides instructions on how to use the CoRe^2.
|
| 27 |
+
|
| 28 |
+
Here we provide the inference code which supports different models like ***Stable Diffusion XL, Stable Diffusion 3.5 Large.***
|
| 29 |
+
|
| 30 |
+
## Requirements
|
| 31 |
+
|
| 32 |
+
- `python version == 3.8`
|
| 33 |
+
- `pytorch with cuda version`
|
| 34 |
+
- `diffusers`
|
| 35 |
+
- `PIL`
|
| 36 |
+
- `bitsandbytes`
|
| 37 |
+
- `numpy`
|
| 38 |
+
- `timm`
|
| 39 |
+
- `argparse`
|
| 40 |
+
- `einops`
|
| 41 |
+
|
| 42 |
+
## Installation🚀️
|
| 43 |
+
|
| 44 |
+
Make sure you have successfully built `python` environment and installed `pytorch` with cuda version. Before running the script, ensure you have all the required packages installed. You can install them using:
|
| 45 |
+
|
| 46 |
+
```bash
|
| 47 |
+
pip install diffusers, PIL, numpy, timm, argparse, einops
|
| 48 |
+
```
|
| 49 |
+
|
| 50 |
+
## Usage👀️
|
| 51 |
+
|
| 52 |
+
To use the CoRe^2 pipeline, you need to run the `sample_img.py` script with appropriate command-line arguments. Below are the available options:
|
| 53 |
+
|
| 54 |
+
### Command-Line Arguments
|
| 55 |
+
|
| 56 |
+
- `--pipeline`: Select the model pipeline (`sdxl`, `sd35`). Default is `sdxl`.
|
| 57 |
+
- `--prompt`: The textual prompt based on which the image will be generated. Default is "Mickey Mouse painting by Frank Frazetta."
|
| 58 |
+
- `--inference-step`: Number of inference steps for the diffusion process. Default is 50.
|
| 59 |
+
- `--cfg`: Classifier-free guidance scale. Default is 5.5.
|
| 60 |
+
- `--pretrained-path`: Path to the pretrained model weights. Default is a specified path in the script.
|
| 61 |
+
- `--size`: The size (height and width) of the generated image. Default is 1024.
|
| 62 |
+
- `--method`: Select the inference method (`standard`, `core`, `zigzag`, `z-core`)
|
| 63 |
+
|
| 64 |
+
### Running the Script
|
| 65 |
+
|
| 66 |
+
Run the script from the command line by navigating to the directory containing `sample_img.py` and executing:
|
| 67 |
+
|
| 68 |
+
```
|
| 69 |
+
python sample_img.py --pipeline sdxl --prompt "A banana on the left of an apple." --size 1024
|
| 70 |
+
```
|
| 71 |
+
|
| 72 |
+
This command will generate an image based on the prompt using the Stable Diffusion XL model with an image size of 1024x1024 pixels.
|
| 73 |
+
|
| 74 |
+
### Output🎉️
|
| 75 |
+
|
| 76 |
+
The script will save one image:
|
| 77 |
+
|
| 78 |
+
## Pre-trained Weights Download❤️
|
| 79 |
+
|
| 80 |
+
We provide the pre-trained CoRe^2 weights of Stable Diffusion XL, and Stable Diffusion 3.5 Large with https://drive.google.com/drive/folders/1alJco6X3cFw4oHTD9SifvS7apc3AwG8I?usp=drive_link
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
|
sample_img.py
CHANGED
|
@@ -103,14 +103,14 @@ if __name__ == '__main__':
|
|
| 103 |
replace_linear_with_lora(refine_model, rank=64, alpha=1.0, number_of_lora=28)
|
| 104 |
lora_true(refine_model, lora_idx=0)
|
| 105 |
|
| 106 |
-
checkpoint = torch.load('./weights/
|
| 107 |
refine_model.load_state_dict(checkpoint)
|
| 108 |
elif args.model == 'sdxl':
|
| 109 |
refine_model = PromptSDXLNet()
|
| 110 |
replace_linear_with_lora(refine_model, rank=48, alpha=1.0, number_of_lora=50)
|
| 111 |
lora_true(refine_model, lora_idx=0)
|
| 112 |
|
| 113 |
-
checkpoint = torch.load('./weights/
|
| 114 |
refine_model.load_state_dict(checkpoint)
|
| 115 |
|
| 116 |
print("Load Lora Success")
|
|
|
|
| 103 |
replace_linear_with_lora(refine_model, rank=64, alpha=1.0, number_of_lora=28)
|
| 104 |
lora_true(refine_model, lora_idx=0)
|
| 105 |
|
| 106 |
+
checkpoint = torch.load('./weights/sd35_noise_model.pth', map_location='cpu')
|
| 107 |
refine_model.load_state_dict(checkpoint)
|
| 108 |
elif args.model == 'sdxl':
|
| 109 |
refine_model = PromptSDXLNet()
|
| 110 |
replace_linear_with_lora(refine_model, rank=48, alpha=1.0, number_of_lora=50)
|
| 111 |
lora_true(refine_model, lora_idx=0)
|
| 112 |
|
| 113 |
+
checkpoint = torch.load('./weights/sdxl_noise_model.pth', map_location='cpu')
|
| 114 |
refine_model.load_state_dict(checkpoint)
|
| 115 |
|
| 116 |
print("Load Lora Success")
|
weights/sd35_noise_model.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6f99a9b437fba4da9c3fb87516c6285bd9bac07f1969a4ba4d631734412edaf2
|
| 3 |
+
size 2881450254
|
weights/sdxl_noise_model.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ad620bd3a604908abfe8178e05f34a83434db246cd63f151755f26de14c5f241
|
| 3 |
+
size 2034660755
|