Jason-thingnario committed
Commit be89dda · 1 Parent(s): 26b9f7f

upload DDPM inference script
.gitattributes CHANGED
@@ -4,6 +4,7 @@
  *.bz2 filter=lfs diff=lfs merge=lfs -text
  *.ckpt filter=lfs diff=lfs merge=lfs -text
  *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gif filter=lfs diff=lfs merge=lfs -text
  *.gz filter=lfs diff=lfs merge=lfs -text
  *.h5 filter=lfs diff=lfs merge=lfs -text
  *.joblib filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1 @@
+ __pycache__/
README.md ADDED
@@ -0,0 +1,88 @@
+ ---
+ license: mit
+ tags:
+ - solar-radiation
+ - deep-learning
+ - nowcasting
+ - ddpm
+ - MCVD
+ ---
+
+ # DDPM Solar Radiation Model
+ A deep learning model for solar radiation nowcasting using a modified [MCVD](https://arxiv.org/pdf/2205.09853) model, a type of DDPM for video generation. The model predicts the clearsky index and converts it to solar radiation for up to 6 or 36 time steps ahead (1 or 6 hours).
+
+ ![Solar Prediction Example](docs/BT202504010900-ddpm.gif)
+
+ ## Overview
+
+ This repository contains two trained models (1hr & 6hr) for solar radiation forecasting:
+ - **1hr DDPM Model**: Predicts solar radiation up to 1 hour ahead (6 time steps).
+ - **6hr DDPM Model**: Predicts solar radiation up to 6 hours ahead (36 time steps).
+
+ The model uses multiple input sources (see the sketch after this list):
+ - **Himawari satellite data**: Clearsky index calculated from Himawari satellite imagery
+ - **WRF Prediction**: Clearsky index from WRF's solar irradiation forecast
+ - **Topography**: Static topographical features
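+
+ A minimal sketch of how `inference.py` assembles these inputs; the key names and shapes below are taken from the script and the bundled sample files, so treat it as illustrative rather than a separate API:
+
+ ```python
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+
+ # Keys read by inference.py: 'Himawari', 'WRF', 'topo', 'clearsky'
+ data = np.load('./sample_data/sample_202504131100.npz')
+ himawari = torch.from_numpy(data['Himawari']).squeeze(2)   # past clearsky-index frames
+ topo = torch.from_numpy(data['topo'])                      # static topography
+ cond = torch.cat([himawari, topo], dim=1)                  # conditional input, roughly (B, 5, 512, 512)
+
+ # The WRF clearsky index is upsampled 4x to the 512x512 grid; the 1hr model uses only the first 6 steps
+ wrf = F.interpolate(torch.from_numpy(data['WRF']).squeeze(2), scale_factor=4, mode='bilinear')
+ wrf_1hr = wrf[:, :6]
+ ```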
+
+ ## Installation
+
+ 1. Clone the repository & install Git LFS:
+ ```bash
+ git lfs install
+ git clone <repository-url>
+ cd Diffusion_SolRad
+ git lfs pull
+ git lfs ls-files  # confirm that model weights & sample data were downloaded
+ ```
+
+ 2. Install dependencies:
+ ```bash
+ pip install -r requirements.txt
+ ```
+
+ ## Requirements
+
+ - Python 3.x
+ - PyTorch 2.4.0
+ - NumPy 1.26.4
+ - einops 0.8.0
+
+ ## Usage
+
+ ### Basic Inference
+
+ Run solar radiation prediction using the pre-trained models:
+
+ ```bash
+ python inference.py --pred-hr [1hr/6hr] --pred-mode [DDPM/DDIM] --basetime 202504131100
+ ```
+
+ ### Command Line Arguments
+
+ - `--pred-mode`: Choose between the `DDPM` and `DDIM` sampling methods (default: `DDPM`)
+ - `--pred-hr`: Choose between the `1hr` and `6hr` prediction models (default: `1hr`)
+ - `--basetime`: Timestamp of the input data in `YYYYMMDDHHMM` format; it must correspond to one of the provided sample files (default: `202504131100`)
+
+ ### Example
+
+ ```bash
+ # DDIM sampling method for 1-hour prediction
+ python inference.py --pred-hr 1hr --pred-mode DDIM --basetime 202507151200
+ ```
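+
+ The script writes the result to `./pred_<basetime>_<pred-hr>_<pred-mode>.npy` (this filename pattern comes from `inference.py`). A minimal sketch for loading it, assuming the example above was run:
+
+ ```python
+ import numpy as np
+
+ pred = np.load('./pred_202507151200_1hr_DDIM.npy')
+ print(pred.shape)  # roughly (batch, 6, 512, 512) for the 1hr model; 36 frames for the 6hr model
+ ```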
+
+ ## Sample Data
+
+ The repository includes sample data files under `sample_data/`:
+ - `sample_202504131100.npz`
+ - `sample_202504161200.npz`
+ - `sample_202507151200.npz`
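+
+ A quick way to inspect what one of these bundles contains (the key names are the ones `inference.py` reads; shapes differ between the 1hr and 6hr settings):
+
+ ```python
+ import numpy as np
+
+ with np.load('./sample_data/sample_202504161200.npz') as data:
+     for key in data:                       # expected keys: Himawari, WRF, topo, clearsky
+         print(key, data[key].shape, data[key].dtype)
+ ```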
+
+ ## Model Weights
+
+ Pre-trained weights are available for both models:
+ - `model_weights/ft06_01hr/weights.ckpt` (1hr model)
+ - `model_weights/ft36_06hr/weights.ckpt` (6hr model)
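+
+ To load a checkpoint outside of `inference.py`, the sketch below mirrors what the script does; class names and the `state_dict` key come from this repository, and it assumes you run it from the repository root:
+
+ ```python
+ import torch
+ from inference import Config
+ from model_architect.UNet_DDPM import UNet_with_time, DDPM
+
+ config = Config()   # defaults correspond to the 1hr model (input_frame=12, output_frame=6)
+ # For the 6hr model, set config.input_frame = 72 and config.output_frame = 36 first
+ backbone = UNet_with_time(config)
+ model = DDPM(backbone, output_shape=(config.output_frame, 512, 512))
+
+ ckpt = torch.load('./model_weights/ft06_01hr/weights.ckpt', weights_only=True)
+ model.load_state_dict(ckpt['state_dict'])
+ model.eval()
+ ```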
+
+ ## License
+
+ This project is released under the MIT License.
docs/BT202504010900-ddpm.gif ADDED

Git LFS Details

  • SHA256: 6bee19570264f935772aad9c9f9f6a4624ee95a0e99277c2461efb72846ca42f
  • Pointer size: 132 Bytes
  • Size of remote file: 3.97 MB
inference.py ADDED
@@ -0,0 +1,127 @@
+ import time
+ import argparse
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import List, Sequence
+ import sys
+ from datetime import datetime, timedelta
+
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+
+ from model_architect.UNet_DDPM import UNet_with_time, DDPM
+
+ @dataclass
+ class Config:
+     input_frame: int = 12
+     output_frame: int = 6
+     cond_nc: int = 5
+     time_emb_dim: int = 128
+     base_chs: int = 32
+     chs_mult: tuple = (1, 2, 4, 8, 8)  ## channel multiplier at each resolution level
+     use_attn_list: tuple = (0, 0, 1, 1, 1)  # 0 means no attention, 1 means use attention
+     n_res_blocks: int = 2
+     n_steps: int = 1000
+     dropout: float = 0.1
+
+ def data_loading(BASETIME, device):
+     data_npz = np.load(f'./sample_data/sample_{BASETIME}.npz')
+
+     inputs = {}
+     for key in data_npz:
+         inputs[key] = torch.from_numpy(data_npz[key]).to(device)
+
+     return inputs
+
+ def arg_parse():
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         '--pred-hr',
+         type=str,
+         default='1hr',
+         choices=[
+             '1hr',
+             '6hr'
+         ]
+     )
+     parser.add_argument(
+         '--pred-mode',
+         type=str,
+         default='DDPM',
+         choices=[
+             'DDPM',
+             'DDIM'
+         ]
+     )
+     parser.add_argument('--basetime', type=str, default='202504131100')
+     args = parser.parse_args()
+     return args
+
+ if __name__ == "__main__":
+     config = Config()
+     args = arg_parse()
+     pred_hr = args.pred_hr
+     pred_mode = args.pred_mode
+
+     BASETIME = args.basetime
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+     inputs = data_loading(BASETIME, device)
+     model_config = Config()
+     if pred_hr == '6hr':
+         model_config.input_frame = 72
+         model_config.output_frame = 36
+     print("Prediction mode:", pred_mode)
+     print("Prediction horizon:", pred_hr)
+
+     ## preprocess inputs for the DDPM model
+     ## concat previous Himawari frames and topography as conditional input (B, 5, 512, 512)
+     ## WRF is upsampled to (B, 36, 512, 512); the 1hr model keeps only the first 6 frames
+     prev_himawari = inputs['Himawari'].squeeze(2)
+     topo = inputs['topo']
+     input_ = torch.cat([prev_himawari, topo], dim=1)
+     WRF = F.interpolate(inputs['WRF'].squeeze(2), scale_factor=4, mode='bilinear')
+
+     clearsky = inputs['clearsky']
+     if pred_hr == '1hr':
+         WRF = WRF[:, :6]
+         clearsky = clearsky[:, :6]
+
+     backbone = UNet_with_time(model_config)
+     model = DDPM(backbone, output_shape=(model_config.output_frame, 512, 512))
+
+     ## load model weights
+     if pred_hr == '1hr':
+         ckpt_path = './model_weights/ft06_01hr/weights.ckpt'
+     elif pred_hr == '6hr':
+         ckpt_path = './model_weights/ft36_06hr/weights.ckpt'
+
+     ckpt = torch.load(ckpt_path, weights_only=True)
+     model.load_state_dict(ckpt['state_dict'])
+     model.eval()
+     model = model.to(device)
+
+     if pred_mode == 'DDPM':
+         pred_clr_idx = model.sample_ddpm(
+             input_,
+             input_cond=WRF,
+             verbose="text"
+         )
+     elif pred_mode == 'DDIM':
+         pred_clr_idx = model.sample_ddim(
+             input_,
+             input_cond=WRF,
+             ddim_steps=100,
+             verbose="text"
+         )
+
+     pred_clr_idx = (pred_clr_idx + 1.0) / 2.0  ## map samples from [-1, 1] back to [0, 1]
+     pred_clr_idx = pred_clr_idx.clamp(0.0, 1.0)
+
+     ## transform clearsky index to solar radiation
+     pred_srad = pred_clr_idx * clearsky
+
+     ## save prediction
+     np.save(f'./pred_{BASETIME}_{pred_hr}_{pred_mode}.npy', pred_srad.cpu().numpy())
+     print('Done')
model_architect/UNet_DDPM.py ADDED
@@ -0,0 +1,372 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from .layers import ResidualBlock, AttnBlock
+ from .utils import get_named_beta_schedule
+
+ def sinusoidal_embedding(n, d):
+     """
+     n: iteration steps,
+     d: time embedding dimension
+     """
+     # Returns the standard positional embedding
+     embedding = torch.tensor([[i / 10000 ** (2 * j / d) for j in range(d)] for i in range(n)])
+     sin_mask = torch.arange(0, n, 2)
+
+     embedding[sin_mask] = torch.sin(embedding[sin_mask])
+     embedding[1 - sin_mask] = torch.cos(embedding[sin_mask])
+
+     return embedding
+
+ def _make_te(dim_in, dim_out):
+     return nn.Sequential(
+         nn.Linear(dim_in, dim_out),
+         nn.SiLU(),
+         nn.Linear(dim_out, dim_out)
+     )
+
+ class UNet_with_time(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+         input_frame = config.input_frame
+         output_frame = config.output_frame
+         n_steps = config.n_steps
+         time_emb_dim = config.time_emb_dim
+         cond_nc = config.cond_nc
+         chs_mult = config.chs_mult  ## e.g. (1, 2, 4, 8)
+         n_res_blocks = config.n_res_blocks
+         base_chs = config.base_chs
+         ## e.g. (0, 0, 1, 1) -> 0 means no attention
+         use_attn_list = config.use_attn_list
+
+         layer_depth = len(chs_mult)
+         assert len(use_attn_list) == layer_depth, "length of use_attn_list should be the same as chs_mult"
+         assert input_frame >= output_frame, "input_frame should be larger than or equal to output_frame"
+
+         self.filter_list = [base_chs * m for m in chs_mult]
+
+         ## time embedding
+         self.time_embed = nn.Embedding(n_steps, time_emb_dim)
+         self.time_embed.weight.data = sinusoidal_embedding(n_steps, time_emb_dim)
+         self.time_embed.requires_grad_(False)
+         self.time_embed_fc = _make_te(time_emb_dim, time_emb_dim)
+         ## end of time embedding
+
+         ## input conv
+         self.input_layer = nn.PixelUnshuffle(downscale_factor=2)
+
+         ## downsampling
+         self.down_blocks = nn.ModuleList()
+         in_c = input_frame * 4  ## after pixel unshuffle
+         for i in range(layer_depth):
+             out_c = self.filter_list[i]
+
+             for _ in range(n_res_blocks):
+                 self.down_blocks.append(
+                     ResidualBlock(in_c, in_c, cond_nc, time_emb_dim, down_flag=False, up_flag=False)
+                 )
+
+             if use_attn_list[i]:
+                 self.down_blocks.append(AttnBlock(in_c, 4))  ## num_head=4
+
+             self.down_blocks.append(
+                 ResidualBlock(in_c, out_c, cond_nc, time_emb_dim, down_flag=True, up_flag=False)
+             )
+             in_c = out_c
+         ## end of downsampling
+
+         ## middle
+         self.mid_block1 = ResidualBlock(in_c, in_c, cond_nc, time_emb_dim, down_flag=False, up_flag=False)
+         self.mid_attn = AttnBlock(in_c, 4)
+         self.mid_block2 = ResidualBlock(in_c, in_c, cond_nc, time_emb_dim, down_flag=False, up_flag=False)
+         ## end of middle
+
+         ## upsampling
+         self.up_blocks = nn.ModuleList()
+         self.filter_list = [input_frame * 4] + self.filter_list[:-1]
+         for i in reversed(range(layer_depth)):  ## i = layer_depth-1, ..., 0
+             out_c = self.filter_list[i]
+
+             self.up_blocks.append(
+                 ResidualBlock(in_c*2, out_c, cond_nc, time_emb_dim, down_flag=False, up_flag=True)
+             )
+             if use_attn_list[i]:
+                 self.up_blocks.append(AttnBlock(out_c))  ## num_head=1
+
+             for _ in range(n_res_blocks):
+                 self.up_blocks.append(
+                     ResidualBlock(out_c*2, out_c, cond_nc, time_emb_dim, down_flag=False, up_flag=False)
+                 )
+
+             in_c = out_c
+
+         ## end of upsampling
+         self.out_up = nn.PixelShuffle(upscale_factor=2)
+         self.out_conv = nn.Conv2d(input_frame, output_frame, 3, padding=1)
+
+     def forward(self, x, t, cond):
+         """
+         x: (b, in_c, h, w), noisy input (concatenated with some data)
+         t: (b,), time step
+         cond: (b, cond_nc, h, w), conditional input
+         """
+         # time embedding
+         t_emb = self.time_embed(t)  ## (b, time_emb_dim)
+         t_emb = self.time_embed_fc(t_emb)  ## (b, time_emb_dim)
+
+         # input conv
+         x = self.input_layer(x)
+
+         # downsampling
+         skip_x = []
+         for ii, down_layer in enumerate(self.down_blocks):
+             if isinstance(down_layer, ResidualBlock):
+                 x = down_layer(x, cond, t_emb)
+                 skip_x.append(x)
+             elif isinstance(down_layer, AttnBlock):
+                 x = down_layer(x)
+             else:
+                 raise ValueError("Wrong layer type in down_blocks")
+
+         # middle
+         x = self.mid_block1(x, cond, t_emb)
+         x = self.mid_attn(x)
+         x = self.mid_block2(x, cond, t_emb)
+
+         # upsampling
+         for up_layer in self.up_blocks:
+             if isinstance(up_layer, ResidualBlock):
+                 skip_feat = skip_x.pop()
+                 x = torch.cat([x, skip_feat], dim=1)  ## concat along channel dimension
+                 x = up_layer(x, cond, t_emb)
+             elif isinstance(up_layer, AttnBlock):
+                 x = up_layer(x)
+             else:
+                 raise ValueError("Wrong layer type in up_blocks")
+
+         # output
+         x = self.out_up(x)
+         x = self.out_conv(x)
+
+         return x
+
+ class DDPM(nn.Module):
+     def __init__(self, backbone, output_shape, n_steps=1000, min_beta=1e-4, max_beta=0.02, device='cuda'):
+         """
+         output_shape: dim(C, H, W)
+         """
+         super().__init__()
+         self.device = device
+         self.backbone_model = backbone
+         self.output_shape = output_shape
+
+         self.n_steps = n_steps
+
+         ## linear betas
+         betas = get_named_beta_schedule("linear", n_steps, min_beta, max_beta)
+         alphas = 1.0 - betas
+         alpha_bars = torch.cumprod(alphas, dim=0)
+
+         self.register_buffer('betas', betas)
+         self.register_buffer('alphas', alphas)
+         self.register_buffer('alpha_bars', alpha_bars)
+
+     def forward(self, x, t, cond):
+         """
+         x: (b, in_c, h, w), noisy input (concatenated with some data)
+         cond: (b, cond_nc, h, w), conditional input
+         t: (b,), time step
+         """
+         return self.backbone_model(x, t, cond)
+
+     @torch.no_grad()
+     def add_noise(self, x0, t, eta=None):
+         """
+         x0: (b, c, h, w), original data
+         t: (b,), time step (0 <= t < n_steps)
+         """
+         b, c, h, w = x0.shape
+         if eta is None:
+             eta = torch.randn(b, c, h, w, device=x0.device)
+
+         alpha_bar = self.alpha_bars[t]
+         noisy_x = alpha_bar.sqrt().reshape(b, 1, 1, 1) * x0 + (1 - alpha_bar).sqrt().reshape(b, 1, 1, 1) * eta
+
+         return noisy_x
+
+     def denoise(self, xt, t, cond):
+         """
+         xt: (b, in_c, h, w), noisy input (concatenated with some data)
+         cond: (b, cond_nc, h, w), conditional input
+         t: (b,), time step (0 <= t < n_steps)
+         """
+         pred_noise = self(xt, t, cond)
+         return pred_noise
+
+     @torch.no_grad()
+     def _build_progress_iter(self, iterable, total, mode: str):
+         """
+         Internal helper to create a progress iterator based on verbose mode.
+         """
+         mode = (mode or "none").lower()
+         if mode == "tqdm":
+             try:
+                 from tqdm import tqdm
+
+                 return tqdm(iterable, total=total, desc="DDPM sampling", leave=False), mode
+             except Exception:
+                 return iterable, "none"
+         return iterable, mode
+
+     @torch.no_grad()
+     def sample_ddpm(self, cond, input_cond=None, verbose: str = "none", store_intermediate: bool = False):
+         """
+         cond: (b, cond_nc, h, w), conditional input for the backbone (e.g. past Himawari frames + topography)
+         input_cond: optional frames concatenated with the noisy sample along the channel dim (e.g. the WRF forecast)
+         verbose: "none", "text", or "tqdm" for progress display
+         """
+         ## make sure the backbone is in eval mode
+         self.backbone_model.eval()
+
+         B, C, H, W = cond.shape
+         ## get cond device
+         device = cond.device
+
+         x = torch.randn(B, *self.output_shape, device=device)
+
+         progress_iter_raw = reversed(range(self.n_steps))
+         progress_iter, mode = self._build_progress_iter(progress_iter_raw, self.n_steps, verbose)
+         use_text = mode == "text"
+
+         text_interval = max(1, self.n_steps // 10)
+
+         frames = []
+         for idx, t in enumerate(progress_iter):
+             time_tensor = (torch.ones(B, device=device) * t).long()
+             if input_cond is not None:
+                 input_ = torch.cat((x, input_cond), dim=1)
+             else:
+                 input_ = x
+
+             eta_theta = self.denoise(input_, time_tensor, cond)
+
+             alpha_t = self.alphas[t]
+             alpha_t_bar = self.alpha_bars[t]
+
+             a = 1 / alpha_t.sqrt()
+             b = ((1 - alpha_t) / (1 - alpha_t_bar).sqrt()) * eta_theta
+
+             x = a * (x - b)
+             if t > 0:
+                 z = torch.randn(B, *self.output_shape, device=device)
+                 beta_t = self.betas[t]
+                 sigma_t = beta_t.sqrt()
+                 x = x + sigma_t * z
+
+             ## store intermediate frames for visualization
+             if (idx % 50 == 0) or (t == 0):
+                 out = x.clone()
+                 out = ((out + 1) / 2).clamp(0, 1)
+                 out = out.cpu().numpy()
+                 frames.append(out)
+
+             if use_text and (idx + 1) % text_interval == 0:
+                 print(f"DDPM sampling {idx + 1}/{self.n_steps}", flush=True)
+
+         if mode == "tqdm" and hasattr(progress_iter, "close"):
+             progress_iter.close()
+
+         if store_intermediate:
+             return x, frames
+         else:
+             return x
+
+     @torch.no_grad()
+     def sample_ddim(self, cond, input_cond=None, ddim_steps: int = 100, eta: float = 0.2, verbose: str = "none", store_intermediate: bool = False):
+         """
+         Deterministic/stochastic DDIM sampling.
+
+         cond: (b, cond_nc, h, w)
+         input_cond: optional conditional input concatenated with the predicted frames
+         ddim_steps: number of steps to sample (<= n_steps)
+         eta: 0 for deterministic DDIM, >0 adds noise controlled by eta
+         verbose: "none", "text", or "tqdm" for progress display
+         """
+         self.backbone_model.eval()
+
+         B, C, H, W = cond.shape
+         device = cond.device
+         ddim_steps = max(1, min(ddim_steps, self.n_steps))
+
+         # create evenly spaced timesteps
+         ddim_timesteps = torch.linspace(0, self.n_steps - 1, steps=ddim_steps, device=device).long()
+         ddim_timesteps = torch.unique(ddim_timesteps, sorted=True)  # safety against duplicates
+         ddim_t_reverse = list(reversed(ddim_timesteps.tolist()))
+
+         x = torch.randn(B, *self.output_shape, device=device)
+
+         progress_iter_raw = enumerate(ddim_t_reverse)
+         progress_iter, mode = self._build_progress_iter(progress_iter_raw, len(ddim_t_reverse), verbose)
+         use_text = mode == "text"
+         text_interval = max(1, len(ddim_t_reverse) // 10)
+
+         frames = []
+         for idx, (iter_idx, t) in enumerate(progress_iter):
+             time_tensor = torch.full((B,), t, device=device, dtype=torch.long)
+             if input_cond is not None:
+                 input_ = torch.cat((x, input_cond), dim=1)
+             else:
+                 input_ = x
+
+             eps = self.denoise(input_, time_tensor, cond)
+
+             alpha_bar_t = self.alpha_bars[t]
+             sqrt_alpha_bar_t = alpha_bar_t.sqrt()
+             sqrt_one_minus_alpha_bar_t = (1 - alpha_bar_t).sqrt()
+
+             x0_pred = (x - sqrt_one_minus_alpha_bar_t * eps) / sqrt_alpha_bar_t
+
+             if iter_idx + 1 < len(ddim_t_reverse):
+                 t_prev = ddim_t_reverse[iter_idx + 1]
+                 alpha_bar_prev = self.alpha_bars[t_prev]
+             else:
+                 alpha_bar_prev = torch.ones_like(alpha_bar_t, device=device)
+
+             sigma_t = 0.0
+             if eta > 0 and alpha_bar_prev < 1:
+                 sigma_t = eta * torch.sqrt(
+                     (1 - alpha_bar_prev) / (1 - alpha_bar_t) * (1 - alpha_bar_t / alpha_bar_prev)
+                 )
+
+             sigma_t = torch.as_tensor(sigma_t, device=device, dtype=x.dtype)
+             noise = torch.randn_like(x) if (eta > 0 and alpha_bar_prev < 1) else torch.zeros_like(x)
+
+             c_t = torch.sqrt(torch.clamp(1 - alpha_bar_prev - sigma_t ** 2, min=0.0))
+             x = (
+                 alpha_bar_prev.sqrt() * x0_pred
+                 + c_t * eps
+                 + sigma_t * noise
+             )
+
+             ## store intermediate frames for visualization
+             if (idx % 25 == 0) or (t == 0):
+                 out = x.clone()
+                 out = ((out + 1) / 2).clamp(0, 1)
+                 out = out.cpu().numpy()
+                 frames.append(out)
+
+             if use_text and (idx + 1) % text_interval == 0:
+                 print(f"DDIM sampling {idx + 1}/{len(ddim_t_reverse)}", flush=True)
+
+         if mode == "tqdm" and hasattr(progress_iter, "close"):
+             progress_iter.close()
+
+         if store_intermediate:
+             return x, frames
+         else:
+             return x
+
+     # Backward-compatible alias
+     sample = sample_ddpm
model_architect/layers.py ADDED
@@ -0,0 +1,223 @@
+ import math
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from .weight_init import default_init
+
+ class SPADE(nn.Module):
+     def __init__(self, norm_nc, cond_nc, spade_dim=128, param_free_norm_type='group'):
+         """
+         SPADE (Spatially Adaptive Normalization) layer.
+         norm_nc: number of channels of the normalized feature map
+         cond_nc: number of channels of the conditional map
+         """
+         super().__init__()
+
+         if param_free_norm_type == 'group':
+             num_groups = min(norm_nc // 4, 32)
+             while norm_nc % num_groups != 0:  # must find another value
+                 num_groups -= 1
+             self.param_free_norm = nn.GroupNorm(num_groups=num_groups, num_channels=norm_nc, affine=False, eps=1e-6)
+         elif param_free_norm_type == 'instance':
+             self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
+         elif param_free_norm_type == 'batch':
+             self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False)
+         else:
+             raise ValueError('%s is not a recognized param-free norm type in SPADE'
+                              % param_free_norm_type)
+
+         ks = 3
+         pw = ks // 2
+         self.mlp_shared = nn.Sequential(
+             nn.Conv2d(cond_nc, spade_dim, kernel_size=ks, padding=pw),
+             nn.ReLU()
+         )
+         self.mlp_gamma = nn.Conv2d(spade_dim, norm_nc, kernel_size=ks, padding=pw)
+         self.mlp_beta = nn.Conv2d(spade_dim, norm_nc, kernel_size=ks, padding=pw)
+
+     def forward(self, x, cond_map):
+         ## do param-free normalization (GroupNorm / InstanceNorm / BatchNorm)
+         normalized = self.param_free_norm(x)
+
+         # Part 2. produce scaling and bias conditioned on semantic map
+         cond_map = F.interpolate(cond_map, size=x.size()[2:], mode='nearest')
+         actv = self.mlp_shared(cond_map)
+         gamma = self.mlp_gamma(actv)
+         beta = self.mlp_beta(actv)
+
+         # apply scale and bias
+         out = normalized * (1 + gamma) + beta
+
+         return out
+
+ class ActNorm(nn.Module):
+     def __init__(self, emb_dim, out_dim):
+         super(ActNorm, self).__init__()
+
+         ## For Time embedding
+         chs = 2 * out_dim
+         self.fc = nn.Linear(emb_dim, chs)
+         self.fc.weight.data = default_init()(self.fc.weight.shape)
+         nn.init.zeros_(self.fc.bias)
+
+         self.activation = nn.SiLU()
+
+     def forward(self, x, t_emb):
+         """
+         x: dim(B, C, H, W) or dim(B, C*N, H, W) if 3D
+         t_emb: dim(B, emb_dim)
+         """
+         # ada-norm as in https://github.com/openai/guided-diffusion
+         emb = self.activation(t_emb)
+         emb_out = self.fc(emb)[:, :, None, None]  # Linear projection
+         scale, shift = torch.chunk(emb_out, 2, dim=1)
+
+         y = x * (1 + scale) + shift
+
+         return y
+
+ class Upsample_with_conv(nn.Module):
+     def __init__(self, in_c, out_c):
+         super().__init__()
+
+         self.up = nn.Upsample(scale_factor=2, mode="nearest")
+         self.conv = nn.Conv2d(in_c, out_c, 3, padding=1)
+
+     def forward(self, x):
+         y = self.up(x)
+         y = self.conv(y)
+
+         return y
+
+ class Downsample_with_conv(nn.Module):
+     def __init__(self, in_c, out_c):
+         super().__init__()
+         self.conv = nn.Conv2d(in_c, out_c, 3, stride=2, padding=1)
+
+     def forward(self, x):
+         y = self.conv(x)
+
+         return y
+
+
+ class ResidualBlock(nn.Module):
+     def __init__(
+         self,
+         in_c,
+         out_c,
+         cond_nc,
+         emb_dim,
+         spade_dim=128,
+         dropout=0.1,
+         param_free_norm_type='group',
+         up_flag=False,
+         down_flag=False
+     ):
+         super().__init__()
+         self.in_c = in_c
+         self.out_c = out_c
+         self.cond_nc = cond_nc
+         self.emb_dim = emb_dim
+         self.up_flag = up_flag
+         self.down_flag = down_flag
+
+         self.activation = nn.SiLU()
+
+         ## first
+         self.spade1 = SPADE(in_c, cond_nc, spade_dim, param_free_norm_type)
+         self.act_norm1 = ActNorm(emb_dim, in_c)
+         self.conv1 = nn.Conv2d(in_c, in_c, 3, padding=1)
+
+         ## downsampling or upsampling
+         if up_flag:
+             self.up_or_down_layer = Upsample_with_conv(in_c, out_c)
+             self.skip_layer = nn.Upsample(scale_factor=2, mode="nearest")
+         elif down_flag:
+             self.up_or_down_layer = Downsample_with_conv(in_c, out_c)
+             self.skip_layer = nn.AvgPool2d(2)
+         else:
+             self.conv_no_change = nn.Conv2d(in_c, out_c, 3, padding=1)
+
+         ## second
+         self.spade2 = SPADE(out_c, cond_nc, spade_dim, param_free_norm_type)
+         self.act_norm2 = ActNorm(emb_dim, out_c)
+         self.conv2 = nn.Conv2d(out_c, out_c, 3, padding=1)
+
+         self.dropout = nn.Dropout(dropout)
+         ## skip connection
+         if in_c != out_c:
+             self.conv1x1 = nn.Conv2d(in_c, out_c, 1)
+
+     def forward(self, x, cond, t_emb):
+         """
+         x: dim(B, C, H, W) or dim(B, C*N, H, W) if 3D
+         cond: dim(B, cond_nc, H_cond, W_cond)
+         t_emb: dim(B, emb_dim)
+         """
+         h = x
+         ## first
+         h = self.spade1(h, cond)
+         h = self.act_norm1(h, t_emb)
+         h = self.activation(h)
+         h = self.conv1(h)
+
+         ## up or down
+         if self.up_flag or self.down_flag:
+             x = self.skip_layer(x)
+             h = self.up_or_down_layer(h)
+         else:
+             h = self.conv_no_change(h)
+
+         ## second
+         h = self.spade2(h, cond)
+         h = self.act_norm2(h, t_emb)
+         h = self.activation(h)
+         h = self.dropout(h)
+         h = self.conv2(h)
+
+         ## skip connection
+         if self.in_c != self.out_c:
+             x = self.conv1x1(x)
+
+         return x + h
+
+ class AttnBlock(nn.Module):
+     def __init__(self, in_channel, n_head=1, norm_groups=32):
+         super().__init__()
+
+         self.n_head = n_head
+
+         self.norm = nn.GroupNorm(norm_groups, in_channel)
+         self.qkv = nn.Conv2d(in_channel, in_channel * 3, 1, bias=False)
+         self.output_layer = nn.Conv2d(in_channel, in_channel, 1)
+
+     def forward(self, x):
+         batch, channel, height, width = x.shape
+         n_head = self.n_head
+         head_dim = channel // n_head
+
+         norm = self.norm(x)
+         qkv = self.qkv(norm).view(batch, n_head, head_dim * 3, -1)
+         query, key, value = qkv.chunk(3, dim=2)  # b, n_head, head_dim, h*w
+
+         attn = torch.einsum(
+             "bndL, bndM -> bnLM", query, key
+         ).contiguous() / math.sqrt(head_dim)
+         attn = torch.softmax(attn, -1)
+         out = torch.einsum("bnLM, bndM -> bndL", attn, value).contiguous()
+         out = out.view(batch, channel, height, width)
+         out = self.output_layer(out)
+
+         return out + x
+
+ def CropNConcat(x1, x2):
+     row_diff = x2.shape[3] - x1.shape[3]
+     col_diff = x2.shape[2] - x1.shape[2]
+
+     x1 = F.pad(x1, [row_diff // 2, row_diff - row_diff // 2,
+                     col_diff // 2, col_diff - col_diff // 2])
+
+     out = torch.cat([x1, x2], dim=1)
+
+     return out
model_architect/utils.py ADDED
@@ -0,0 +1,56 @@
+ import torch
+ import math
+
+ # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+ def get_named_beta_schedule(
+     schedule_name,
+     num_diffusion_timesteps,
+     min_beta=1e-4,
+     max_beta=0.02,
+     s=0.008,
+ ):
+     """
+     Get a pre-defined beta schedule for the given name.
+
+     The beta schedule library consists of beta schedules which remain similar
+     in the limit of num_diffusion_timesteps.
+     Beta schedules may be added, but should not be removed or changed once
+     they are committed to maintain backwards compatibility.
+     """
+     if schedule_name == "linear":
+         # Linear schedule from Ho et al, extended to work for any number of
+         # diffusion steps.
+         # scale = 1000 / num_diffusion_timesteps
+         scale = 1.0
+         beta_start = scale * min_beta
+         beta_end = scale * max_beta
+         return torch.linspace(
+             beta_start, beta_end, num_diffusion_timesteps,
+         )
+
+     elif schedule_name == "cosine":
+         return betas_for_alpha_bar(
+             num_diffusion_timesteps,
+             lambda t: math.cos((t + s) / (1 + s) * math.pi / 2) ** 2,
+         )
+     else:
+         raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
+
+ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
+     """
+     Create a beta schedule that discretizes the given alpha_t_bar function,
+     which defines the cumulative product of (1-beta) over time from t = [0,1].
+
+     :param num_diffusion_timesteps: the number of betas to produce.
+     :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
+                       produces the cumulative product of (1-beta) up to that
+                       part of the diffusion process.
+     :param max_beta: the maximum beta to use; use values lower than 1 to
+                      prevent singularities.
+     """
+     betas = []
+     for i in range(num_diffusion_timesteps):
+         t1 = i / num_diffusion_timesteps
+         t2 = (i + 1) / num_diffusion_timesteps
+         betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+     return torch.tensor(betas, dtype=torch.float32)
model_architect/weight_init.py ADDED
@@ -0,0 +1,40 @@
+ import torch
+ import numpy as np
+
+ def variance_scaling(scale, mode, distribution,
+                      in_axis=1, out_axis=0,
+                      dtype=torch.float32,
+                      device='cpu'):
+     """Ported from JAX."""
+
+     def _compute_fans(shape, in_axis=1, out_axis=0):
+         receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis]
+         fan_in = shape[in_axis] * receptive_field_size
+         fan_out = shape[out_axis] * receptive_field_size
+         return fan_in, fan_out
+
+     def init(shape, dtype=dtype, device=device):
+         fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)
+         if mode == "fan_in":
+             denominator = fan_in
+         elif mode == "fan_out":
+             denominator = fan_out
+         elif mode == "fan_avg":
+             denominator = (fan_in + fan_out) / 2
+         else:
+             raise ValueError(
+                 "invalid mode for variance scaling initializer: {}".format(mode))
+         variance = scale / denominator
+         if distribution == "normal":
+             return torch.randn(*shape, dtype=dtype, device=device) * np.sqrt(variance)
+         elif distribution == "uniform":
+             return (torch.rand(*shape, dtype=dtype, device=device) * 2. - 1.) * np.sqrt(3 * variance)
+         else:
+             raise ValueError("invalid distribution for variance scaling initializer")
+
+     return init
+
+ def default_init(scale=1.):
+     """The same initialization used in DDPM."""
+     scale = 1e-10 if scale == 0 else scale
+     return variance_scaling(scale, 'fan_avg', 'uniform')
model_weights/ft06_01hr/weights.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0d1664224001817cb35f970ac9ff06d7e6ea66b0152385e849c6d7d1bf6bd01f
+ size 231196876
model_weights/ft36_06hr/weights.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4c7e10ccc11f79598d99bd77ed52a6c02c4ec0bbd567e0f423bafa9ff887622f
+ size 326923212
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ numpy==1.26.4
+ torch==2.4.0
+ einops==0.8.0
sample_data/sample_202504131100.npz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6117356ac780a530645e192cc85d647b103c915b63433441d810dedc7cdd4ec1
+ size 33002900
sample_data/sample_202504161200.npz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d7b9e6c7ed76f695f7c6f2f1a976f4c27128120e5c4328223809c27dc8feee52
+ size 33300209
sample_data/sample_202507151200.npz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e20eb69ea2bef6bb0074c376afa7e8398e7da8eb1edd9a1f11c343ffe711a299
+ size 33038261