Commit be89dda
Parent(s): 26b9f7f

upload DDPM inference script
Files changed:

- .gitattributes +1 -0
- .gitignore +1 -0
- README.md +88 -0
- docs/BT202504010900-ddpm.gif +3 -0
- inference.py +127 -0
- model_architect/UNet_DDPM.py +372 -0
- model_architect/layers.py +223 -0
- model_architect/utils.py +56 -0
- model_architect/weight_init.py +40 -0
- model_weights/ft06_01hr/weights.ckpt +3 -0
- model_weights/ft36_06hr/weights.ckpt +3 -0
- requirements.txt +3 -0
- sample_data/sample_202504131100.npz +3 -0
- sample_data/sample_202504161200.npz +3 -0
- sample_data/sample_202507151200.npz +3 -0
.gitattributes
CHANGED
@@ -4,6 +4,7 @@
 *.bz2 filter=lfs diff=lfs merge=lfs -text
 *.ckpt filter=lfs diff=lfs merge=lfs -text
 *.ftz filter=lfs diff=lfs merge=lfs -text
+*.gif filter=lfs diff=lfs merge=lfs -text
 *.gz filter=lfs diff=lfs merge=lfs -text
 *.h5 filter=lfs diff=lfs merge=lfs -text
 *.joblib filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1 @@
+__pycache__/
README.md
ADDED
@@ -0,0 +1,88 @@
---
license: mit
tags:
- solar-radiation
- deep-learning
- nowcasting
- ddpm
- MCVD
---

# DDPM Solar Radiation model

A deep learning model for solar radiation nowcasting based on a modified [MCVD](https://arxiv.org/pdf/2205.09853) model, a DDPM-type model for video generation. The model predicts the clearsky index and converts it to solar radiation for up to 6 or 36 time steps ahead.

![DDPM sampling](docs/BT202504010900-ddpm.gif)

## Overview

This repository contains two trained models (1hr & 6hr) for solar radiation forecasting:
- **1hr DDPM Model**: Predicts solar radiation up to 1 hour ahead (6 time steps)
- **6hr DDPM Model**: Predicts solar radiation up to 6 hours ahead (36 time steps)

The model uses multiple input sources:
- **Himawari satellite data**: Clearsky index calculated from Himawari satellite imagery
- **WRF prediction**: Clearsky index derived from WRF's solar irradiation forecast
- **Topography**: Static topographical features

## Installation

1. Clone the repository & install Git LFS:
```bash
git lfs install
git clone <repository-url>
cd Diffusion_SolRad
git lfs pull
git lfs ls-files  # confirm that the model weights & sample data were downloaded
```

2. Install dependencies:
```bash
pip install -r requirements.txt
```

## Requirements

- Python 3.x
- PyTorch 2.4.0
- NumPy 1.26.4
- einops 0.8.0

## Usage

### Basic Inference

Run solar radiation prediction using the pre-trained models:

```bash
python inference.py --pred-hr [1hr/6hr] --pred-mode [DDPM/DDIM] --basetime 202504131100
```

### Command Line Arguments

- `--pred-mode`: Choose between `DDPM` and `DDIM` sampling methods (default: `DDPM`)
- `--pred-hr`: Choose between the `1hr` and `6hr` prediction models (default: `1hr`)
- `--basetime`: Timestamp of the input data in YYYYMMDDHHMM format (default: `202504131100`)

### Example

```bash
# DDIM sampling method for 1-hour prediction
python inference.py --pred-hr 1hr --pred-mode DDIM --basetime 202507151200
```

## Sample Data

The repository includes sample data files:
- `sample_data/sample_202504131100.npz`
- `sample_data/sample_202504161200.npz`
- `sample_data/sample_202507151200.npz`

## Model Weights

Pre-trained weights are available for both models:
- `model_weights/ft06_01hr/weights.ckpt` (1hr model)
- `model_weights/ft36_06hr/weights.ckpt` (6hr model)

## License

This project is released under the MIT License.
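For context, the conversion step mentioned above is the standard clearsky-index relation, exactly as implemented at the end of `inference.py` (with the predicted index clamped to [0, 1]):

```latex
% k_c is the model's predicted clearsky index, clamped to [0, 1]
\mathrm{SolarRad}(t) = k_c(t)\cdot \mathrm{ClearskyRad}(t), \qquad k_c(t)\in[0,1]
```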
docs/BT202504010900-ddpm.gif
ADDED
(binary GIF, tracked with Git LFS)
inference.py
ADDED
@@ -0,0 +1,127 @@
import time
import argparse
from dataclasses import dataclass
from pathlib import Path
from typing import List, Sequence
import sys
from datetime import datetime, timedelta

import numpy as np
import torch
import torch.nn.functional as F

from model_architect.UNet_DDPM import UNet_with_time, DDPM

@dataclass
class Config:
    input_frame: int = 12
    output_frame: int = 6
    cond_nc: int = 5
    time_emb_dim: int = 128
    base_chs: int = 32
    chs_mult: tuple = (1, 2, 4, 8, 8)       ## channel multiplier per resolution level
    use_attn_list: tuple = (0, 0, 1, 1, 1)  ## 0 means no attention, 1 means use attention
    n_res_blocks: int = 2
    n_steps: int = 1000
    dropout: float = 0.1

def data_loading(BASETIME, device):
    data_npz = np.load(f'./sample_data/sample_{BASETIME}.npz')

    inputs = {}
    for key in data_npz:
        inputs[key] = torch.from_numpy(data_npz[key]).to(device)

    return inputs

def arg_parse():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--pred-hr',
        type=str,
        default='1hr',
        choices=['1hr', '6hr']
    )
    parser.add_argument(
        '--pred-mode',
        type=str,
        default='DDPM',
        choices=['DDPM', 'DDIM']
    )
    parser.add_argument('--basetime', type=str, default='202504131100')
    args = parser.parse_args()
    return args

if __name__ == "__main__":
    config = Config()
    args = arg_parse()
    pred_hr = args.pred_hr
    pred_mode = args.pred_mode

    BASETIME = args.basetime
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    inputs = data_loading(BASETIME, device)
    model_config = Config()
    if pred_hr == '6hr':
        model_config.input_frame = 72
        model_config.output_frame = 36
    print("Prediction mode:", pred_mode)
    print("Prediction horizon:", pred_hr)

    ## preprocess inputs for the DDPM model
    ## concat previous Himawari frames and topo as conditional input (B, 5, 512, 512)
    ## WRF dim after interpolation: 1hr -> (B, 6, 512, 512), 6hr -> (B, 36, 512, 512)
    prev_himawari = inputs['Himawari'].squeeze(2)
    topo = inputs['topo']
    input_ = torch.cat([prev_himawari, topo], dim=1)
    WRF = F.interpolate(inputs['WRF'].squeeze(2), scale_factor=4, mode='bilinear')

    clearsky = inputs['clearsky']
    if pred_hr == '1hr':
        WRF = WRF[:, :6]
        clearsky = clearsky[:, :6]

    backbone = UNet_with_time(model_config)
    model = DDPM(backbone, output_shape=(model_config.output_frame, 512, 512))

    ## load model weights
    if pred_hr == '1hr':
        ckpt_path = './model_weights/ft06_01hr/weights.ckpt'
    elif pred_hr == '6hr':
        ckpt_path = './model_weights/ft36_06hr/weights.ckpt'

    ckpt = torch.load(ckpt_path, weights_only=True)
    model.load_state_dict(ckpt['state_dict'])
    model.eval()
    model = model.to(device)

    if pred_mode == 'DDPM':
        pred_clr_idx = model.sample_ddpm(
            input_,
            input_cond=WRF,
            verbose="text"
        )
    elif pred_mode == 'DDIM':
        pred_clr_idx = model.sample_ddim(
            input_,
            input_cond=WRF,
            ddim_steps=100,
            verbose="text"
        )

    ## map model output from [-1, 1] back to [0, 1]
    pred_clr_idx = (pred_clr_idx + 1.0) / 2.0
    pred_clr_idx = pred_clr_idx.clamp(0.0, 1.0)

    ## transform clearsky index to solar radiation
    pred_srad = pred_clr_idx * clearsky

    ## save prediction
    np.save(f'./pred_{BASETIME}_{pred_hr}_{pred_mode}.npy', pred_srad.cpu().numpy())
    print('Done')
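As a quick orientation aid, here is a minimal sketch for inspecting a sample input file and the saved output. The key names (`Himawari`, `WRF`, `topo`, `clearsky`) come from `inference.py`; the shapes noted in the comments are assumptions inferred from its preprocessing comments, not documented values.

```python
import numpy as np

# list the arrays inside one of the bundled sample files
data = np.load('./sample_data/sample_202504131100.npz')
for key in data:
    print(key, data[key].shape)  # expect keys: 'Himawari', 'WRF', 'topo', 'clearsky'

# after running inference.py, load the saved solar-radiation prediction
pred = np.load('./pred_202504131100_1hr_DDPM.npy')
print(pred.shape)  # presumably (B, 6, 512, 512) for the 1hr model
```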
model_architect/UNet_DDPM.py
ADDED
@@ -0,0 +1,372 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from .layers import ResidualBlock, AttnBlock
from .utils import get_named_beta_schedule

def sinusoidal_embedding(n, d):
    """
    n: number of diffusion steps,
    d: time embedding dimension
    """
    # Standard sinusoidal positional embedding: even rows get sin,
    # odd rows get cos of the same angles.
    embedding = torch.tensor([[i / 10000 ** (2 * j / d) for j in range(d)] for i in range(n)])
    sin_mask = torch.arange(0, n, 2)
    cos_mask = torch.arange(1, n, 2)

    embedding[cos_mask] = torch.cos(embedding[cos_mask])
    embedding[sin_mask] = torch.sin(embedding[sin_mask])

    return embedding

def _make_te(dim_in, dim_out):
    return nn.Sequential(
        nn.Linear(dim_in, dim_out),
        nn.SiLU(),
        nn.Linear(dim_out, dim_out)
    )

class UNet_with_time(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        input_frame = config.input_frame
        output_frame = config.output_frame
        n_steps = config.n_steps
        time_emb_dim = config.time_emb_dim
        cond_nc = config.cond_nc
        chs_mult = config.chs_mult  ## e.g. (1, 2, 4, 8)
        n_res_blocks = config.n_res_blocks
        base_chs = config.base_chs
        ## e.g. (0, 0, 1, 1) -> 0 means no attention
        use_attn_list = config.use_attn_list

        layer_depth = len(chs_mult)
        assert len(use_attn_list) == layer_depth, "length of use_attn_list should be the same as chs_mult"
        assert input_frame >= output_frame, "input_frame should be larger than or equal to output_frame"

        self.filter_list = [base_chs * m for m in chs_mult]

        ## time embedding
        self.time_embed = nn.Embedding(n_steps, time_emb_dim)
        self.time_embed.weight.data = sinusoidal_embedding(n_steps, time_emb_dim)
        self.time_embed.requires_grad_(False)
        self.time_embed_fc = _make_te(time_emb_dim, time_emb_dim)
        ## end of time embedding

        ## input conv
        self.input_layer = nn.PixelUnshuffle(downscale_factor=2)

        ## downsampling
        self.down_blocks = nn.ModuleList()
        in_c = input_frame * 4  ## after pixel unshuffle
        for i in range(layer_depth):
            out_c = self.filter_list[i]

            for _ in range(n_res_blocks):
                self.down_blocks.append(
                    ResidualBlock(in_c, in_c, cond_nc, time_emb_dim, down_flag=False, up_flag=False)
                )

            if use_attn_list[i]:
                self.down_blocks.append(AttnBlock(in_c, 4))  ## num_head=4

            self.down_blocks.append(
                ResidualBlock(in_c, out_c, cond_nc, time_emb_dim, down_flag=True, up_flag=False)
            )
            in_c = out_c
        ## end of downsampling

        ## middle
        self.mid_block1 = ResidualBlock(in_c, in_c, cond_nc, time_emb_dim, down_flag=False, up_flag=False)
        self.mid_attn = AttnBlock(in_c, 4)
        self.mid_block2 = ResidualBlock(in_c, in_c, cond_nc, time_emb_dim, down_flag=False, up_flag=False)
        ## end of middle

        ## upsampling
        self.up_blocks = nn.ModuleList()
        self.filter_list = [input_frame * 4] + self.filter_list[:-1]
        for i in reversed(range(layer_depth)):  ## i = layer_depth-1, ..., 0
            out_c = self.filter_list[i]

            self.up_blocks.append(
                ResidualBlock(in_c*2, out_c, cond_nc, time_emb_dim, down_flag=False, up_flag=True)
            )
            if use_attn_list[i]:
                self.up_blocks.append(AttnBlock(out_c))  ## num_head=1

            for _ in range(n_res_blocks):
                self.up_blocks.append(
                    ResidualBlock(out_c*2, out_c, cond_nc, time_emb_dim, down_flag=False, up_flag=False)
                )

            in_c = out_c

        ## end of upsampling
        self.out_up = nn.PixelShuffle(upscale_factor=2)
        self.out_conv = nn.Conv2d(input_frame, output_frame, 3, padding=1)

    def forward(self, x, t, cond):
        """
        x: (b, in_c, h, w), noisy input (concatenated with some data)
        t: (b,), time step
        cond: (b, cond_nc, h, w), conditional input
        """
        # time embedding
        t_emb = self.time_embed(t)         ## (b, time_emb_dim)
        t_emb = self.time_embed_fc(t_emb)  ## (b, time_emb_dim)

        # input conv
        x = self.input_layer(x)

        # downsampling
        skip_x = []
        for down_layer in self.down_blocks:
            if isinstance(down_layer, ResidualBlock):
                x = down_layer(x, cond, t_emb)
                skip_x.append(x)
            elif isinstance(down_layer, AttnBlock):
                x = down_layer(x)
            else:
                raise ValueError("Wrong layer type in down_blocks")

        # middle
        x = self.mid_block1(x, cond, t_emb)
        x = self.mid_attn(x)
        x = self.mid_block2(x, cond, t_emb)

        # upsampling
        for up_layer in self.up_blocks:
            if isinstance(up_layer, ResidualBlock):
                skip_feat = skip_x.pop()
                x = torch.cat([x, skip_feat], dim=1)  ## concat along channel dimension
                x = up_layer(x, cond, t_emb)
            elif isinstance(up_layer, AttnBlock):
                x = up_layer(x)
            else:
                raise ValueError("Wrong layer type in up_blocks")

        # output
        x = self.out_up(x)
        x = self.out_conv(x)

        return x

class DDPM(nn.Module):
    def __init__(self, backbone, output_shape, n_steps=1000, min_beta=1e-4, max_beta=0.02, device='cuda'):
        """
        output_shape: dim(C, H, W)
        """
        super().__init__()
        self.device = device
        self.backbone_model = backbone
        self.output_shape = output_shape

        self.n_steps = n_steps

        ## linear betas
        betas = get_named_beta_schedule("linear", n_steps, min_beta, max_beta)
        alphas = 1.0 - betas
        alpha_bars = torch.cumprod(alphas, dim=0)

        self.register_buffer('betas', betas)
        self.register_buffer('alphas', alphas)
        self.register_buffer('alpha_bars', alpha_bars)

    def forward(self, x, t, cond):
        """
        x: (b, in_c, h, w), noisy input (concatenated with some data)
        cond: (b, cond_nc, h, w), conditional input
        t: (b,), time step
        """
        return self.backbone_model(x, t, cond)

    @torch.no_grad()
    def add_noise(self, x0, t, eta=None):
        """
        x0: (b, c, h, w), original data
        t: (b,), time step (0 <= t < n_steps)
        """
        b, c, h, w = x0.shape
        if eta is None:
            eta = torch.randn(b, c, h, w, device=x0.device)

        alpha_bar = self.alpha_bars[t]
        noisy_x = alpha_bar.sqrt().reshape(b, 1, 1, 1) * x0 + (1 - alpha_bar).sqrt().reshape(b, 1, 1, 1) * eta

        return noisy_x

    def denoise(self, xt, t, cond):
        """
        xt: (b, in_c, h, w), noisy input (concatenated with some data)
        cond: (b, cond_nc, h, w), conditional input
        t: (b,), time step (0 <= t < n_steps)
        """
        pred_noise = self(xt, t, cond)
        return pred_noise

    @torch.no_grad()
    def _build_progress_iter(self, iterable, total, mode: str):
        """
        Internal helper to create a progress iterator based on verbose mode.
        """
        mode = (mode or "none").lower()
        if mode == "tqdm":
            try:
                from tqdm import tqdm

                return tqdm(iterable, total=total, desc="DDPM sampling", leave=False), mode
            except Exception:
                return iterable, "none"
        return iterable, mode

    @torch.no_grad()
    def sample_ddpm(self, cond, input_cond=None, verbose: str = "none", store_intermediate: bool = False):
        """
        cond: (b, cond_nc, h, w), conditional input (previous frames + topography)
        input_cond: optional frames concatenated with the noisy sample along the channel dim
        verbose: "none", "text", or "tqdm" for progress display
        """
        ## confirm that the model is in eval mode
        self.backbone_model.eval()

        B, C, H, W = cond.shape
        ## get cond device
        device = cond.device

        x = torch.randn(B, *self.output_shape, device=device)

        progress_iter_raw = reversed(range(self.n_steps))
        progress_iter, mode = self._build_progress_iter(progress_iter_raw, self.n_steps, verbose)
        use_text = mode == "text"

        text_interval = max(1, self.n_steps // 10)

        frames = []
        for idx, t in enumerate(progress_iter):
            time_tensor = (torch.ones(B, device=device) * t).long()
            if input_cond is not None:
                input_ = torch.cat((x, input_cond), dim=1)
            else:
                input_ = x

            eta_theta = self.denoise(input_, time_tensor, cond)

            alpha_t = self.alphas[t]
            alpha_t_bar = self.alpha_bars[t]

            a = 1 / alpha_t.sqrt()
            b = ((1 - alpha_t) / (1 - alpha_t_bar).sqrt()) * eta_theta

            x = a * (x - b)
            if t > 0:
                z = torch.randn(B, *self.output_shape, device=device)
                beta_t = self.betas[t]
                sigma_t = beta_t.sqrt()
                x = x + sigma_t * z

            ## store intermediate frames for visualization
            if (idx % 50 == 0) or (t == 0):
                out = x.clone()
                out = ((out + 1) / 2).clamp(0, 1)
                out = out.cpu().numpy()
                frames.append(out)

            if use_text and (idx + 1) % text_interval == 0:
                print(f"DDPM sampling {idx + 1}/{self.n_steps}", flush=True)

        if mode == "tqdm" and hasattr(progress_iter, "close"):
            progress_iter.close()

        if store_intermediate:
            return x, frames
        else:
            return x

    @torch.no_grad()
    def sample_ddim(self, cond, input_cond=None, ddim_steps: int = 100, eta: float = 0.2, verbose: str = "none", store_intermediate: bool = False):
        """
        Deterministic/stochastic DDIM sampling.

        cond: (b, cond_nc, h, w)
        input_cond: optional conditional input concatenated with the predicted frames
        ddim_steps: number of steps to sample (<= n_steps)
        eta: 0 for deterministic DDIM, >0 adds noise controlled by eta
        verbose: "none", "text", or "tqdm" for progress display
        """
        self.backbone_model.eval()

        B, C, H, W = cond.shape
        device = cond.device
        ddim_steps = max(1, min(ddim_steps, self.n_steps))

        # create evenly spaced timesteps
        ddim_timesteps = torch.linspace(0, self.n_steps - 1, steps=ddim_steps, device=device).long()
        ddim_timesteps = torch.unique(ddim_timesteps, sorted=True)  # safety against duplicates
        ddim_t_reverse = list(reversed(ddim_timesteps.tolist()))

        x = torch.randn(B, *self.output_shape, device=device)

        progress_iter_raw = enumerate(ddim_t_reverse)
        progress_iter, mode = self._build_progress_iter(progress_iter_raw, len(ddim_t_reverse), verbose)
        use_text = mode == "text"
        text_interval = max(1, len(ddim_t_reverse) // 10)

        frames = []
        for iter_idx, t in progress_iter:
            time_tensor = torch.full((B,), t, device=device, dtype=torch.long)
            if input_cond is not None:
                input_ = torch.cat((x, input_cond), dim=1)
            else:
                input_ = x

            eps = self.denoise(input_, time_tensor, cond)

            alpha_bar_t = self.alpha_bars[t]
            sqrt_alpha_bar_t = alpha_bar_t.sqrt()
            sqrt_one_minus_alpha_bar_t = (1 - alpha_bar_t).sqrt()

            x0_pred = (x - sqrt_one_minus_alpha_bar_t * eps) / sqrt_alpha_bar_t

            if iter_idx + 1 < len(ddim_t_reverse):
                t_prev = ddim_t_reverse[iter_idx + 1]
                alpha_bar_prev = self.alpha_bars[t_prev]
            else:
                alpha_bar_prev = torch.ones_like(alpha_bar_t)

            sigma_t = 0.0
            if eta > 0 and alpha_bar_prev < 1:
                sigma_t = eta * torch.sqrt(
                    (1 - alpha_bar_prev) / (1 - alpha_bar_t) * (1 - alpha_bar_t / alpha_bar_prev)
                )

            sigma_t = torch.as_tensor(sigma_t, device=device, dtype=x.dtype)
            noise = torch.randn_like(x) if (eta > 0 and alpha_bar_prev < 1) else torch.zeros_like(x)

            c_t = torch.sqrt(torch.clamp(1 - alpha_bar_prev - sigma_t ** 2, min=0.0))
            x = (
                alpha_bar_prev.sqrt() * x0_pred
                + c_t * eps
                + sigma_t * noise
            )

            ## store intermediate frames for visualization
            if (iter_idx % 25 == 0) or (t == 0):
                out = x.clone()
                out = ((out + 1) / 2).clamp(0, 1)
                out = out.cpu().numpy()
                frames.append(out)

            if use_text and (iter_idx + 1) % text_interval == 0:
                print(f"DDIM sampling {iter_idx + 1}/{len(ddim_t_reverse)}", flush=True)

        if mode == "tqdm" and hasattr(progress_iter, "close"):
            progress_iter.close()

        if store_intermediate:
            return x, frames
        else:
            return x

    # Backward-compatible alias
    sample = sample_ddpm
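For reference, the update rules these methods implement are the standard DDPM/DDIM equations; this is a restatement of the code above, not a new algorithm:

```latex
% add_noise: forward diffusion
x_t = \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon, \qquad \epsilon \sim \mathcal{N}(0, I)

% sample_ddpm: reverse step (z ~ N(0, I) for t > 0, else z = 0)
x_{t-1} = \frac{1}{\sqrt{\alpha_t}}\Bigl(x_t - \frac{1-\alpha_t}{\sqrt{1-\bar\alpha_t}}\,\epsilon_\theta\Bigr) + \sqrt{\beta_t}\,z

% sample_ddim: jump from t to the previous kept timestep t'
\hat{x}_0 = \frac{x_t - \sqrt{1-\bar\alpha_t}\,\epsilon_\theta}{\sqrt{\bar\alpha_t}}, \qquad
x_{t'} = \sqrt{\bar\alpha_{t'}}\,\hat{x}_0 + \sqrt{1-\bar\alpha_{t'}-\sigma_t^2}\,\epsilon_\theta + \sigma_t z

% with \sigma_t = \eta\,\sqrt{\tfrac{1-\bar\alpha_{t'}}{1-\bar\alpha_t}\Bigl(1-\tfrac{\bar\alpha_t}{\bar\alpha_{t'}}\Bigr)}
```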
model_architect/layers.py
ADDED
@@ -0,0 +1,223 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from .weight_init import default_init

class SPADE(nn.Module):
    def __init__(self, norm_nc, cond_nc, spade_dim=128, param_free_norm_type='group'):
        """
        SPADE (Spatially Adaptive Normalization) layer.
        norm_nc: number of channels of the normalized feature map
        cond_nc: number of channels of the conditional map
        """
        super().__init__()

        if param_free_norm_type == 'group':
            num_groups = min(norm_nc // 4, 32)
            while norm_nc % num_groups != 0:  # must find a divisor of norm_nc
                num_groups -= 1
            self.param_free_norm = nn.GroupNorm(num_groups=num_groups, num_channels=norm_nc, affine=False, eps=1e-6)
        elif param_free_norm_type == 'instance':
            self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
        elif param_free_norm_type == 'batch':
            self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False)
        else:
            raise ValueError('%s is not a recognized param-free norm type in SPADE'
                             % param_free_norm_type)

        ks = 3
        pw = ks // 2
        self.mlp_shared = nn.Sequential(
            nn.Conv2d(cond_nc, spade_dim, kernel_size=ks, padding=pw),
            nn.ReLU()
        )
        self.mlp_gamma = nn.Conv2d(spade_dim, norm_nc, kernel_size=ks, padding=pw)
        self.mlp_beta = nn.Conv2d(spade_dim, norm_nc, kernel_size=ks, padding=pw)

    def forward(self, x, cond_map):
        ## Part 1. param-free normalization (GroupNorm / InstanceNorm / BatchNorm)
        normalized = self.param_free_norm(x)

        ## Part 2. produce scale and bias conditioned on the semantic map
        cond_map = F.interpolate(cond_map, size=x.size()[2:], mode='nearest')
        actv = self.mlp_shared(cond_map)
        gamma = self.mlp_gamma(actv)
        beta = self.mlp_beta(actv)

        ## apply scale and bias
        out = normalized * (1 + gamma) + beta

        return out

class ActNorm(nn.Module):
    def __init__(self, emb_dim, out_dim):
        super(ActNorm, self).__init__()

        ## For time embedding
        chs = 2 * out_dim
        self.fc = nn.Linear(emb_dim, chs)
        self.fc.weight.data = default_init()(self.fc.weight.shape)
        nn.init.zeros_(self.fc.bias)

        self.activation = nn.SiLU()

    def forward(self, x, t_emb):
        """
        x: dim(B, C, H, W) or dim(B, C*N, H, W) if 3D
        t_emb: dim(B, emb_dim)
        """
        # ada-norm as in https://github.com/openai/guided-diffusion
        emb = self.activation(t_emb)
        emb_out = self.fc(emb)[:, :, None, None]  # linear projection, broadcast over H and W
        scale, shift = torch.chunk(emb_out, 2, dim=1)

        y = x * (1 + scale) + shift

        return y

class Upsample_with_conv(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()

        self.up = nn.Upsample(scale_factor=2, mode="nearest")
        self.conv = nn.Conv2d(in_c, out_c, 3, padding=1)

    def forward(self, x):
        y = self.up(x)
        y = self.conv(y)

        return y

class Downsample_with_conv(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()
        self.conv = nn.Conv2d(in_c, out_c, 3, stride=2, padding=1)

    def forward(self, x):
        y = self.conv(x)

        return y


class ResidualBlock(nn.Module):
    def __init__(
        self,
        in_c,
        out_c,
        cond_nc,
        emb_dim,
        spade_dim=128,
        dropout=0.1,
        param_free_norm_type='group',
        up_flag=False,
        down_flag=False
    ):
        super().__init__()
        self.in_c = in_c
        self.out_c = out_c
        self.cond_nc = cond_nc
        self.emb_dim = emb_dim
        self.up_flag = up_flag
        self.down_flag = down_flag

        self.activation = nn.SiLU()

        ## first
        self.spade1 = SPADE(in_c, cond_nc, spade_dim, param_free_norm_type)
        self.act_norm1 = ActNorm(emb_dim, in_c)
        self.conv1 = nn.Conv2d(in_c, in_c, 3, padding=1)

        ## downsampling or upsampling
        if up_flag:
            self.up_or_down_layer = Upsample_with_conv(in_c, out_c)
            self.skip_layer = nn.Upsample(scale_factor=2, mode="nearest")
        elif down_flag:
            self.up_or_down_layer = Downsample_with_conv(in_c, out_c)
            self.skip_layer = nn.AvgPool2d(2)
        else:
            self.conv_no_change = nn.Conv2d(in_c, out_c, 3, padding=1)

        ## second
        self.spade2 = SPADE(out_c, cond_nc, spade_dim, param_free_norm_type)
        self.act_norm2 = ActNorm(emb_dim, out_c)
        self.conv2 = nn.Conv2d(out_c, out_c, 3, padding=1)

        self.dropout = nn.Dropout(dropout)
        ## skip connection
        if in_c != out_c:
            self.conv1x1 = nn.Conv2d(in_c, out_c, 1)

    def forward(self, x, cond, t_emb):
        """
        x: dim(B, C, H, W) or dim(B, C*N, H, W) if 3D
        cond: dim(B, cond_nc, H_cond, W_cond)
        t_emb: dim(B, emb_dim)
        """
        h = x
        ## first
        h = self.spade1(h, cond)
        h = self.act_norm1(h, t_emb)
        h = self.activation(h)
        h = self.conv1(h)

        ## up or down
        if self.up_flag or self.down_flag:
            x = self.skip_layer(x)
            h = self.up_or_down_layer(h)
        else:
            h = self.conv_no_change(h)

        ## second
        h = self.spade2(h, cond)
        h = self.act_norm2(h, t_emb)
        h = self.activation(h)
        h = self.dropout(h)
        h = self.conv2(h)

        ## skip connection
        if self.in_c != self.out_c:
            x = self.conv1x1(x)

        return x + h

class AttnBlock(nn.Module):
    def __init__(self, in_channel, n_head=1, norm_groups=32):
        super().__init__()

        self.n_head = n_head

        self.norm = nn.GroupNorm(norm_groups, in_channel)
        self.qkv = nn.Conv2d(in_channel, in_channel * 3, 1, bias=False)
        self.output_layer = nn.Conv2d(in_channel, in_channel, 1)

    def forward(self, x):
        batch, channel, height, width = x.shape
        n_head = self.n_head
        head_dim = channel // n_head

        norm = self.norm(x)
        qkv = self.qkv(norm).view(batch, n_head, head_dim * 3, -1)
        query, key, value = qkv.chunk(3, dim=2)  # b, n_head, head_dim, h*w

        attn = torch.einsum(
            "bndL, bndM -> bnLM", query, key
        ).contiguous() / math.sqrt(head_dim)
        attn = torch.softmax(attn, -1)
        out = torch.einsum("bnLM, bndM -> bndL", attn, value).contiguous()
        out = out.view(batch, channel, height, width)
        out = self.output_layer(out)

        return out + x

def CropNConcat(x1, x2):
    ## despite the name, this pads x1 up to x2's spatial size,
    ## then concatenates the two along the channel dimension
    row_diff = x2.shape[3] - x1.shape[3]
    col_diff = x2.shape[2] - x1.shape[2]

    x1 = F.pad(x1, [row_diff // 2, row_diff - row_diff // 2,
                    col_diff // 2, col_diff - col_diff // 2])

    out = torch.cat([x1, x2], dim=1)

    return out
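A small, self-contained shape check for these blocks may help. The dimensions below are illustrative choices based on the defaults in `inference.py`'s `Config` (12 input frames × 4 channels after `PixelUnshuffle`, 5 conditional channels, 128-dim time embedding); the spatial size is kept small so the attention map stays cheap. This is a sketch, not a test taken from the repo.

```python
import torch
from model_architect.layers import ResidualBlock, AttnBlock

block = ResidualBlock(48, 64, cond_nc=5, emb_dim=128, down_flag=True)
attn = AttnBlock(64, n_head=4)

x = torch.randn(1, 48, 64, 64)      # feature map after PixelUnshuffle
cond = torch.randn(1, 5, 512, 512)  # SPADE resizes the conditional map internally
t_emb = torch.randn(1, 128)         # projected time embedding

h = block(x, cond, t_emb)  # downsampled by 2 -> (1, 64, 32, 32)
h = attn(h)                # self-attention keeps the shape
print(h.shape)             # torch.Size([1, 64, 32, 32])
```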
model_architect/utils.py
ADDED
@@ -0,0 +1,56 @@
import torch
import math

# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
def get_named_beta_schedule(
    schedule_name,
    num_diffusion_timesteps,
    min_beta=1e-4,
    max_beta=0.02,
    s=0.008,
):
    """
    Get a pre-defined beta schedule for the given name.

    The beta schedule library consists of beta schedules which remain similar
    in the limit of num_diffusion_timesteps.
    Beta schedules may be added, but should not be removed or changed once
    they are committed to maintain backwards compatibility.
    """
    if schedule_name == "linear":
        # Linear schedule from Ho et al, extended to work for any number of
        # diffusion steps.
        # scale = 1000 / num_diffusion_timesteps
        scale = 1.0
        beta_start = scale * min_beta
        beta_end = scale * max_beta
        return torch.linspace(
            beta_start, beta_end, num_diffusion_timesteps,
        )

    elif schedule_name == "cosine":
        return betas_for_alpha_bar(
            num_diffusion_timesteps,
            lambda t: math.cos((t + s) / (1 + s) * math.pi / 2) ** 2,
        )
    else:
        raise NotImplementedError(f"unknown beta schedule: {schedule_name}")

def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function,
    which defines the cumulative product of (1-beta) over time from t = [0,1].

    :param num_diffusion_timesteps: the number of betas to produce.
    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
                      produces the cumulative product of (1-beta) up to that
                      part of the diffusion process.
    :param max_beta: the maximum beta to use; use values lower than 1 to
                     prevent singularities.
    """
    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return torch.tensor(betas, dtype=torch.float32)
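A quick sanity check of the linear schedule; the endpoint values follow directly from the defaults above, and the cumulative product matches what `DDPM.__init__` registers as `alpha_bars`:

```python
import torch
from model_architect.utils import get_named_beta_schedule

betas = get_named_beta_schedule("linear", 1000)  # 1e-4 ... 0.02, linearly spaced
alpha_bars = torch.cumprod(1.0 - betas, dim=0)
print(betas[0].item(), betas[-1].item())         # 0.0001 0.02
print(alpha_bars[-1].item())                     # ~4e-5: x_T is essentially pure noise
```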
model_architect/weight_init.py
ADDED
@@ -0,0 +1,40 @@
import torch
import numpy as np

def variance_scaling(scale, mode, distribution,
                     in_axis=1, out_axis=0,
                     dtype=torch.float32,
                     device='cpu'):
    """Ported from JAX."""

    def _compute_fans(shape, in_axis=1, out_axis=0):
        receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis]
        fan_in = shape[in_axis] * receptive_field_size
        fan_out = shape[out_axis] * receptive_field_size
        return fan_in, fan_out

    def init(shape, dtype=dtype, device=device):
        fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)
        if mode == "fan_in":
            denominator = fan_in
        elif mode == "fan_out":
            denominator = fan_out
        elif mode == "fan_avg":
            denominator = (fan_in + fan_out) / 2
        else:
            raise ValueError(
                "invalid mode for variance scaling initializer: {}".format(mode))
        variance = scale / denominator
        if distribution == "normal":
            return torch.randn(*shape, dtype=dtype, device=device) * np.sqrt(variance)
        elif distribution == "uniform":
            return (torch.rand(*shape, dtype=dtype, device=device) * 2. - 1.) * np.sqrt(3 * variance)
        else:
            raise ValueError("invalid distribution for variance scaling initializer")

    return init

def default_init(scale=1.):
    """The same initialization used in DDPM."""
    scale = 1e-10 if scale == 0 else scale
    return variance_scaling(scale, 'fan_avg', 'uniform')
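Usage mirrors how `layers.py` applies this initializer to the time-embedding projection in `ActNorm`:

```python
import torch.nn as nn
from model_architect.weight_init import default_init

fc = nn.Linear(128, 64)
fc.weight.data = default_init()(fc.weight.shape)  # fan_avg / uniform, scale 1.0
nn.init.zeros_(fc.bias)
```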
model_weights/ft06_01hr/weights.ckpt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0d1664224001817cb35f970ac9ff06d7e6ea66b0152385e849c6d7d1bf6bd01f
size 231196876
model_weights/ft36_06hr/weights.ckpt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4c7e10ccc11f79598d99bd77ed52a6c02c4ec0bbd567e0f423bafa9ff887622f
size 326923212
requirements.txt
ADDED
@@ -0,0 +1,3 @@
numpy==1.26.4
torch==2.4.0
einops==0.8.0
sample_data/sample_202504131100.npz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6117356ac780a530645e192cc85d647b103c915b63433441d810dedc7cdd4ec1
size 33002900
sample_data/sample_202504161200.npz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d7b9e6c7ed76f695f7c6f2f1a976f4c27128120e5c4328223809c27dc8feee52
size 33300209
sample_data/sample_202507151200.npz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e20eb69ea2bef6bb0074c376afa7e8398e7da8eb1edd9a1f11c343ffe711a299
size 33038261