Spaces:
Sleeping
Sleeping
feat: citydreamer inference (bugs to be fixed).
Browse files- app.py +58 -38
- citydreamer/__init__.py +0 -0
- citydreamer/extensions/__init__.py +8 -0
- citydreamer/extensions/extrude_tensor/__init__.py +40 -0
- citydreamer/extensions/extrude_tensor/bindings.cpp +39 -0
- citydreamer/extensions/extrude_tensor/extrude_tensor_ext.cu +67 -0
- citydreamer/extensions/extrude_tensor/setup.py +26 -0
- citydreamer/extensions/extrude_tensor/test.py +124 -0
- citydreamer/extensions/grid_encoder/__init__.py +193 -0
- citydreamer/extensions/grid_encoder/bindings.cpp +40 -0
- citydreamer/extensions/grid_encoder/grid_encoder_ext.cu +605 -0
- citydreamer/extensions/grid_encoder/setup.py +39 -0
- citydreamer/extensions/voxlib/__init__.py +5 -0
- citydreamer/extensions/voxlib/ray_voxel_intersection.cu +351 -0
- citydreamer/extensions/voxlib/setup.py +25 -0
- citydreamer/extensions/voxlib/voxlib.cpp +21 -0
- citydreamer/extensions/voxlib/voxlib_common.h +83 -0
- citydreamer/inference.py +537 -0
- citydreamer/model.py +1264 -0
- requirements.txt +4 -1
app.py
CHANGED
|
@@ -4,80 +4,90 @@
|
|
| 4 |
# @Author: Haozhe Xie
|
| 5 |
# @Date: 2024-03-02 16:30:00
|
| 6 |
# @Last Modified by: Haozhe Xie
|
| 7 |
-
# @Last Modified at: 2024-03-03
|
| 8 |
# @Email: [email protected]
|
| 9 |
|
|
|
|
| 10 |
import logging
|
|
|
|
| 11 |
import os
|
| 12 |
-
import torch
|
| 13 |
-
import gradio as gr
|
| 14 |
-
import subprocess
|
| 15 |
-
import urllib.request
|
| 16 |
import ssl
|
|
|
|
| 17 |
import sys
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
# Fix: ssl.SSLCertVerificationError: [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed
|
| 20 |
ssl._create_default_https_context = ssl._create_unverified_context
|
| 21 |
-
|
| 22 |
-
sys.path.append(os.path.join(os.path.dirname(__file__), "citydreamer"))
|
| 23 |
# Import CityDreamer modules
|
| 24 |
-
|
| 25 |
-
# import citydreamer.inference
|
| 26 |
|
| 27 |
|
| 28 |
def setup_runtime_env():
|
| 29 |
-
subprocess.
|
|
|
|
|
|
|
| 30 |
ext_dir = os.path.join(os.path.dirname(__file__), "citydreamer", "extensions")
|
| 31 |
for e in os.listdir(ext_dir):
|
| 32 |
-
if not os.path.isdir(e):
|
| 33 |
continue
|
| 34 |
-
subprocess.call(["pip", "install", "."], workdir=os.path.join(ext_dir, e))
|
| 35 |
|
|
|
|
| 36 |
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
)
|
| 43 |
-
if not os.path.exists("CityDreamer-Bgnd.pth"):
|
| 44 |
urllib.request.urlretrieve(
|
| 45 |
-
"https://huggingface.co/hzxie/city-dreamer/resolve/main
|
| 46 |
-
|
| 47 |
)
|
| 48 |
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
bgm = citydreamer.model.GanCraftGenerator(bgm_ckpt["cfg"])
|
| 52 |
-
fgm = citydreamer.model.GanCraftGenerator(fgm_ckpt["cfg"])
|
| 53 |
if torch.cuda.is_available():
|
| 54 |
-
|
| 55 |
-
|
|
|
|
| 56 |
|
| 57 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
|
| 59 |
|
| 60 |
def get_generated_city(radius, altitude, azimuth):
|
| 61 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
|
| 63 |
|
| 64 |
def main(debug):
|
| 65 |
title = "CityDreamer Demo 🏙️"
|
| 66 |
with open("README.md", "r") as f:
|
| 67 |
markdown = f.read()
|
| 68 |
-
desc = markdown[markdown.rfind("---") + 3:]
|
| 69 |
with open("ARTICLE.md", "r") as f:
|
| 70 |
arti = f.read()
|
| 71 |
|
| 72 |
app = gr.Interface(
|
| 73 |
get_generated_city,
|
| 74 |
[
|
| 75 |
-
gr.Slider(
|
| 76 |
-
|
| 77 |
-
),
|
| 78 |
-
gr.Slider(
|
| 79 |
-
256, 512, value=384, step=5, label="Camera Altitude (m)"
|
| 80 |
-
),
|
| 81 |
gr.Slider(0, 360, value=180, step=5, label="Camera Azimuth (°)"),
|
| 82 |
],
|
| 83 |
[gr.Image(type="numpy", label="Generated City")],
|
|
@@ -94,9 +104,19 @@ if __name__ == "__main__":
|
|
| 94 |
logging.basicConfig(
|
| 95 |
format="[%(levelname)s] %(asctime)s %(message)s", level=logging.INFO
|
| 96 |
)
|
| 97 |
-
logging.info("
|
| 98 |
# setup_runtime_env()
|
|
|
|
| 99 |
logging.info("Downloading pretrained models...")
|
| 100 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
logging.info("Starting the main application...")
|
| 102 |
main(os.getenv("DEBUG") == "1")
|
|
|
|
| 4 |
# @Author: Haozhe Xie
|
| 5 |
# @Date: 2024-03-02 16:30:00
|
| 6 |
# @Last Modified by: Haozhe Xie
|
| 7 |
+
# @Last Modified at: 2024-03-03 12:02:23
|
| 8 |
# @Email: [email protected]
|
| 9 |
|
| 10 |
+
import gradio as gr
|
| 11 |
import logging
|
| 12 |
+
import numpy as np
|
| 13 |
import os
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
import ssl
|
| 15 |
+
import subprocess
|
| 16 |
import sys
|
| 17 |
+
import torch
|
| 18 |
+
import urllib.request
|
| 19 |
+
|
| 20 |
+
from PIL import Image
|
| 21 |
|
| 22 |
# Fix: ssl.SSLCertVerificationError: [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed
|
| 23 |
ssl._create_default_https_context = ssl._create_unverified_context
|
|
|
|
|
|
|
| 24 |
# Import CityDreamer modules
|
| 25 |
+
sys.path.append(os.path.join(os.path.dirname(__file__), "citydreamer"))
|
|
|
|
| 26 |
|
| 27 |
|
| 28 |
def setup_runtime_env():
    """Log the toolchain versions and build the bundled CUDA extensions.

    Every sub-directory of ``citydreamer/extensions`` that ships a
    ``setup.py`` is installed with ``pip install .``, using the interpreter
    that is running this script.

    Raises:
        FileNotFoundError: If ``nvcc`` or ``g++`` cannot be found.
        subprocess.CalledProcessError: If a version query exits non-zero.
    """
    # check_output() returns bytes by default; text=True keeps the log
    # readable instead of printing a b'...' literal with escaped newlines.
    logging.info(
        "CUDA version is %s",
        subprocess.check_output(["nvcc", "--version"], text=True),
    )
    logging.info(
        "GCC version is %s",
        subprocess.check_output(["g++", "--version"], text=True),
    )

    # Compile CUDA extensions
    ext_dir = os.path.join(os.path.dirname(__file__), "citydreamer", "extensions")
    for name in sorted(os.listdir(ext_dir)):
        pkg_dir = os.path.join(ext_dir, name)
        # Skip plain files and directories that are not installable packages
        # (for example __pycache__), which would only make pip fail.
        if not os.path.isfile(os.path.join(pkg_dir, "setup.py")):
            continue

        # "python -m pip" guarantees the extension lands in the environment
        # of the running interpreter, whatever "pip" on PATH points to.
        status = subprocess.call(
            [sys.executable, "-m", "pip", "install", "."], cwd=pkg_dir
        )
        if status != 0:
            # Stay best-effort like before, but do not hide the failure.
            logging.error(
                "Failed to build extension %s (pip exit code %d)", name, status
            )
|
| 38 |
|
| 39 |
+
|
| 40 |
+
def get_models(file_name):
    """Download a checkpoint if needed and build its generator.

    Args:
        file_name: Checkpoint file name. It is used both as the local path
            and as the file name inside the ``hzxie/city-dreamer`` repository
            on the Hugging Face Hub.

    Returns:
        A ``GanCraftGenerator`` in evaluation mode. When CUDA is available it
        is wrapped in ``torch.nn.DataParallel`` and moved to the GPU.
    """
    # Imported lazily; presumably for the same reason as citydreamer.inference
    # (the CUDA extensions must be compiled first) -- TODO confirm.
    import citydreamer.model

    if not os.path.exists(file_name):
        # Download to a temporary name and rename afterwards, so that an
        # interrupted transfer never leaves a truncated file that the
        # existence check above would accept on the next start.
        partial_name = file_name + ".part"
        urllib.request.urlretrieve(
            "https://huggingface.co/hzxie/city-dreamer/resolve/main/%s" % file_name,
            partial_name,
        )
        os.replace(partial_name, file_name)

    # map_location="cpu" lets the checkpoint be read on machines without a GPU
    # even if it was saved from CUDA tensors.
    ckpt = torch.load(file_name, map_location="cpu")
    model = citydreamer.model.GanCraftGenerator(ckpt["cfg"])
    # NOTE(review): only ckpt["cfg"] is consumed here; the weights stored in
    # the checkpoint are never copied into the model by this function.
    # Confirm whether a load_state_dict() call is missing.
    if torch.cuda.is_available():
        model = torch.nn.DataParallel(model).cuda()

    # Inference only: switch to eval mode on the CPU path as well.
    return model.eval()
|
| 55 |
|
| 56 |
+
|
| 57 |
+
def get_city_layout():
    """Read the bundled New York city layout from the ``assets`` folder.

    Returns:
        A ``(height_field, seg_map)`` tuple of NumPy arrays. The segmentation
        image is converted to palette mode ("P") before it is turned into an
        array.
    """
    height_field = np.array(Image.open("assets/NYC-HghtFld.png"))
    seg_image = Image.open("assets/NYC-SegMap.png")
    seg_map = np.array(seg_image.convert("P"))
    return height_field, seg_map
|
| 61 |
|
| 62 |
|
| 63 |
def get_generated_city(radius, altitude, azimuth):
    """Render the city for the given camera settings (Gradio callback).

    The models and the city layout are read from attributes stored on this
    function object (``fgm``, ``bgm``, ``hf``, ``seg``), which the start-up
    code assigns before the interface is launched.

    Args:
        radius: Value of the "Camera Radius" slider.
        altitude: Value of the "Camera Altitude" slider.
        azimuth: Value of the "Camera Azimuth" slider.

    Returns:
        Whatever ``citydreamer.inference.generate_city`` returns.
    """
    # The import must be done after CUDA extension compilation
    import citydreamer.inference

    shared = get_generated_city
    call_args = (
        shared.fgm,
        shared.bgm,
        shared.hf,
        shared.seg,
        radius,
        altitude,
        azimuth,
    )
    return citydreamer.inference.generate_city(*call_args)
|
| 76 |
|
| 77 |
|
| 78 |
def main(debug):
|
| 79 |
title = "CityDreamer Demo 🏙️"
|
| 80 |
with open("README.md", "r") as f:
|
| 81 |
markdown = f.read()
|
| 82 |
+
desc = markdown[markdown.rfind("---") + 3 :]
|
| 83 |
with open("ARTICLE.md", "r") as f:
|
| 84 |
arti = f.read()
|
| 85 |
|
| 86 |
app = gr.Interface(
|
| 87 |
get_generated_city,
|
| 88 |
[
|
| 89 |
+
gr.Slider(128, 512, value=320, step=5, label="Camera Radius (m)"),
|
| 90 |
+
gr.Slider(256, 512, value=384, step=5, label="Camera Altitude (m)"),
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
gr.Slider(0, 360, value=180, step=5, label="Camera Azimuth (°)"),
|
| 92 |
],
|
| 93 |
[gr.Image(type="numpy", label="Generated City")],
|
|
|
|
| 104 |
logging.basicConfig(
|
| 105 |
format="[%(levelname)s] %(asctime)s %(message)s", level=logging.INFO
|
| 106 |
)
|
| 107 |
+
logging.info("Compiling CUDA extensions...")
|
| 108 |
# setup_runtime_env()
|
| 109 |
+
|
| 110 |
logging.info("Downloading pretrained models...")
|
| 111 |
+
fgm = get_models("CityDreamer-Fgnd.pth")
|
| 112 |
+
bgm = get_models("CityDreamer-Bgnd.pth")
|
| 113 |
+
get_generated_city.fgm = fgm
|
| 114 |
+
get_generated_city.bgm = bgm
|
| 115 |
+
|
| 116 |
+
logging.info("Loading New York city layout to RAM...")
|
| 117 |
+
hf, seg = get_city_layout()
|
| 118 |
+
get_generated_city.hf = hf
|
| 119 |
+
get_generated_city.seg = seg
|
| 120 |
+
|
| 121 |
logging.info("Starting the main application...")
|
| 122 |
main(os.getenv("DEBUG") == "1")
|
citydreamer/__init__.py
ADDED
|
File without changes
|
citydreamer/extensions/__init__.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
#
|
| 3 |
+
# @File: __init__.py
|
| 4 |
+
# @Author: Haozhe Xie
|
| 5 |
+
# @Date: 2023-03-24 20:23:53
|
| 6 |
+
# @Last Modified by: Haozhe Xie
|
| 7 |
+
# @Last Modified at: 2023-03-24 20:23:55
|
| 8 |
+
# @Email: [email protected]
|
citydreamer/extensions/extrude_tensor/__init__.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
#
|
| 3 |
+
# @File: __init__.py
|
| 4 |
+
# @Author: Haozhe Xie
|
| 5 |
+
# @Date: 2023-03-24 20:24:38
|
| 6 |
+
# @Last Modified by: Haozhe Xie
|
| 7 |
+
# @Last Modified at: 2023-06-16 09:55:58
|
| 8 |
+
# @Email: [email protected]
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
|
| 12 |
+
import extrude_tensor_ext
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class TensorExtruder(torch.nn.Module):
    """Module wrapper around ``ExtrudeTensorFunction``.

    Args:
        max_height: Exclusive upper bound for the values in the height field;
            it is also forwarded to the extension as the volume depth.
    """

    def __init__(self, max_height=256):
        super().__init__()
        self.max_height = max_height

    def forward(self, seg_map, height_field):
        """Extrude ``seg_map`` along the Z-axis according to ``height_field``.

        Raises:
            AssertionError: If any height is not below ``max_height``.
        """
        peak = torch.max(height_field)
        assert peak < self.max_height, "Max Value %d" % peak
        return ExtrudeTensorFunction.apply(seg_map, height_field, self.max_height)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class ExtrudeTensorFunction(torch.autograd.Function):
    """Autograd wrapper around the ``extrude_tensor_ext`` CUDA extension."""

    @staticmethod
    def forward(ctx, seg_map, height_field, max_height):
        """Build the extruded volume.

        Args:
            seg_map: Tensor of shape (B, C, H, W).
            height_field: Tensor of shape (B, C, H, W).
            max_height: Depth of the produced volume.

        Returns:
            The volume returned by ``extrude_tensor_ext.forward``.
        """
        # Remembered so that backward() can hand back gradients that have the
        # same shape as the inputs.
        ctx.input_shape = seg_map.shape
        return extrude_tensor_ext.forward(seg_map, height_field, max_height)

    @staticmethod
    def backward(ctx, grad_volume):
        """Collapse the volume gradient along the Z-axis.

        Returns:
            One entry per ``forward`` input, as autograd requires:
            ``(grad_seg_map, grad_height_field, None)``. ``max_height`` is a
            plain integer and therefore has no gradient.
        """
        # The Z-axis is the last axis. Indexing it as -1 works for both the
        # (B, C, H, W, D) layout and the (B, H, W, D) tensor that the bundled
        # CUDA source allocates; a hard-coded dim=4 is out of range for the
        # latter.
        grad_seg_map = torch.sum(grad_volume, dim=-1)

        # Restore the input layout when only a singleton axis is missing.
        input_shape = getattr(ctx, "input_shape", None)
        if (
            input_shape is not None
            and grad_seg_map.shape != input_shape
            and grad_seg_map.numel() == input_shape.numel()
        ):
            grad_seg_map = grad_seg_map.reshape(input_shape)

        grad_height_field = grad_seg_map
        return grad_seg_map, grad_height_field, None
|
citydreamer/extensions/extrude_tensor/bindings.cpp
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/**
|
| 2 |
+
* @File: extrude_tensor_ext_cuda.cpp
|
| 3 |
+
* @Author: Haozhe Xie
|
| 4 |
+
* @Date: 2023-03-26 11:06:13
|
| 5 |
+
* @Last Modified by: Haozhe Xie
|
| 6 |
+
* @Last Modified at: 2023-03-26 16:28:20
|
| 7 |
+
* @Email: [email protected]
|
| 8 |
+
*/
|
| 9 |
+
|
| 10 |
+
#include <ATen/cuda/CUDAContext.h>
|
| 11 |
+
#include <torch/extension.h>
|
| 12 |
+
|
| 13 |
+
// Input validation helpers.
// TORCH_CHECK replaces the deprecated AT_ASSERTM and raises a c10::Error
// (surfaced as a Python RuntimeError) carrying the message below.
#define CHECK_CUDA(x) TORCH_CHECK((x).is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x)                                                    \
  TORCH_CHECK((x).is_contiguous(), #x " must be contiguous")
// Wrapped in do { } while (0) so that both checks expand to a single
// statement and stay together when used in an unbraced if/else.
#define CHECK_INPUT(x)                                                         \
  do {                                                                         \
    CHECK_CUDA(x);                                                             \
    CHECK_CONTIGUOUS(x);                                                       \
  } while (0)
|
| 20 |
+
|
| 21 |
+
torch::Tensor extrude_tensor_ext_cuda_forward(torch::Tensor seg_map,
|
| 22 |
+
torch::Tensor height_field,
|
| 23 |
+
int max_height,
|
| 24 |
+
cudaStream_t stream);
|
| 25 |
+
|
| 26 |
+
/**
 * Validate the inputs and launch the CUDA implementation on the current
 * CUDA stream.
 *
 * @param seg_map      Contiguous CUDA int32 tensor of shape (B, C, H, W).
 * @param height_field Contiguous CUDA int32 tensor with the same shape.
 * @param max_height   Depth of the produced volume; must be positive.
 * @return The tensor produced by extrude_tensor_ext_cuda_forward.
 */
torch::Tensor extrude_tensor_ext_forward(torch::Tensor seg_map,
                                         torch::Tensor height_field,
                                         int max_height) {
  CHECK_INPUT(seg_map);
  CHECK_INPUT(height_field);
  // The CUDA side reads size(2)/size(3) and indexes both tensors through raw
  // int pointers with identical offsets, so reject anything that would turn
  // into an out-of-bounds access instead of a readable error.
  TORCH_CHECK(seg_map.dim() == 4, "seg_map must be a 4-D tensor (B, C, H, W)");
  TORCH_CHECK(seg_map.sizes() == height_field.sizes(),
              "seg_map and height_field must have the same shape");
  TORCH_CHECK(seg_map.scalar_type() == torch::kInt32 &&
                  height_field.scalar_type() == torch::kInt32,
              "seg_map and height_field must be int32 tensors");
  TORCH_CHECK(max_height > 0, "max_height must be positive");

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  return extrude_tensor_ext_cuda_forward(seg_map, height_field, max_height,
                                         stream);
}
|
| 35 |
+
|
| 36 |
+
// Python bindings: exposes extrude_tensor_ext_forward as `forward` on the
// extension module.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward",
        &extrude_tensor_ext_forward,
        "Extrude Tensor Ext. Forward (CUDA)");
}
|
citydreamer/extensions/extrude_tensor/extrude_tensor_ext.cu
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/**
|
| 2 |
+
* @File: extrude_tensor_ext.cu
|
| 3 |
+
* @Author: Haozhe Xie
|
| 4 |
+
* @Date: 2023-03-26 11:06:18
|
| 5 |
+
* @Last Modified by: Haozhe Xie
|
| 6 |
+
* @Last Modified at: 2023-05-03 14:55:01
|
| 7 |
+
* @Email: [email protected]
|
| 8 |
+
*/
|
| 9 |
+
|
| 10 |
+
#include <cmath>
|
| 11 |
+
#include <cstdio>
|
| 12 |
+
#include <cstdlib>
|
| 13 |
+
#include <torch/extension.h>
|
| 14 |
+
|
| 15 |
+
#define CUDA_NUM_THREADS 512

// Compute the number of threads to launch per block: the largest power of
// two that does not exceed n, clamped to the range [1, CUDA_NUM_THREADS].
// Integer arithmetic is used on purpose; a floating-point log2 can round an
// exact power of two down to the previous exponent and is undefined for
// n <= 0.
inline int get_n_threads(int n) {
  int n_threads = 1;
  while (n_threads * 2 <= n && n_threads * 2 <= CUDA_NUM_THREADS) {
    n_threads *= 2;
  }
  return n_threads;
}
|
| 22 |
+
|
| 23 |
+
// Fill `volume` so that, for every pixel (i, j), the cells 0..hf along the
// last axis hold the pixel's segmentation label.
//
// blockIdx.x selects the batch element. The threads of a block start at row
// threadIdx.x and advance by blockDim.x rows; each thread walks all columns
// of its rows.
//
// seg_map / height_field: (batch, height, width) ints, row-major.
// volume: (batch, height, width, depth) ints, row-major.
__global__ void extrude_tensor_ext_cuda_kernel(
    int height, int width, int depth, const int *__restrict__ seg_map,
    const int *__restrict__ height_field, int *__restrict__ volume) {
  int batch_index = blockIdx.x;
  int index = threadIdx.x;
  int stride = blockDim.x;

  // 64-bit offsets: batch * height * width * depth easily exceeds INT_MAX.
  const long long plane = static_cast<long long>(height) * width;
  seg_map += batch_index * plane;
  height_field += batch_index * plane;
  volume += batch_index * plane * depth;
  for (int i = index; i < height; i += stride) {
    const long long offset_2d_r = static_cast<long long>(i) * width;
    const long long offset_3d_r = offset_2d_r * depth;
    for (int j = 0; j < width; ++j) {
      const long long offset_2d_c = offset_2d_r + j;
      const long long offset_3d_c =
          offset_3d_r + static_cast<long long>(j) * depth;
      int seg = seg_map[offset_2d_c];
      int hf = height_field[offset_2d_c];
      // Fill cells 0..hf, but never write past the end of the column when a
      // height is >= depth.
      int n_filled = hf < depth ? hf + 1 : depth;
      for (int k = 0; k < n_filled; ++k) {
        volume[offset_3d_c + k] = seg;
      }
    }
  }
}
|
| 45 |
+
|
| 46 |
+
// Allocate the output volume and launch the extrusion kernel.
//
// seg_map / height_field: int32 CUDA tensors; size(0) is used as the batch
// size and size(2) / size(3) as height / width.
// max_height: size of the last axis of the returned volume.
// stream: CUDA stream on which the kernel is enqueued.
//
// Returns a zero-initialised int32 tensor of shape
// (batch, height, width, max_height) filled by the kernel.
// Throws c10::Error if the kernel launch fails.
torch::Tensor extrude_tensor_ext_cuda_forward(torch::Tensor seg_map,
                                              torch::Tensor height_field,
                                              int max_height,
                                              cudaStream_t stream) {
  int batch_size = seg_map.size(0);
  int height = seg_map.size(2);
  int width = seg_map.size(3);
  // Allocate on the same device as the input instead of the current default
  // CUDA device; this matters when the caller runs on more than one GPU.
  torch::Tensor volume =
      torch::zeros({batch_size, height, width, max_height},
                   seg_map.options().dtype(torch::kInt32));
  if (volume.numel() == 0) {
    // Nothing to fill, and a launch with an empty grid would be rejected.
    return volume;
  }

  // One block per batch element. The kernel strides over rows by blockDim.x,
  // so any thread count is valid; use get_n_threads() rather than a single
  // thread per block.
  extrude_tensor_ext_cuda_kernel<<<batch_size, get_n_threads(height), 0,
                                   stream>>>(
      height, width, max_height, seg_map.data_ptr<int>(),
      height_field.data_ptr<int>(), volume.data_ptr<int>());

  // Surface launch failures to Python instead of printing them and returning
  // a volume that was never written.
  cudaError_t err = cudaGetLastError();
  TORCH_CHECK(err == cudaSuccess,
              "Error in extrude_tensor_ext_cuda_forward: ",
              cudaGetErrorString(err));
  return volume;
}
|
citydreamer/extensions/extrude_tensor/setup.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
#
|
| 3 |
+
# @File: setup.py
|
| 4 |
+
# @Author: Haozhe Xie
|
| 5 |
+
# @Date: 2023-03-24 20:35:43
|
| 6 |
+
# @Last Modified by: Haozhe Xie
|
| 7 |
+
# @Last Modified at: 2023-04-29 10:47:30
|
| 8 |
+
# @Email: [email protected]
|
| 9 |
+
|
| 10 |
+
from setuptools import setup
|
| 11 |
+
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
| 12 |
+
|
| 13 |
+
# The single CUDA extension built by this package: the pybind11 bindings plus
# the kernel source.
_EXTRUDE_TENSOR_EXT = CUDAExtension(
    "extrude_tensor_ext",
    ["bindings.cpp", "extrude_tensor_ext.cu"],
)

setup(
    name="extrude_tensor",
    version="1.0.0",
    ext_modules=[_EXTRUDE_TENSOR_EXT],
    cmdclass={"build_ext": BuildExtension},
)
|
citydreamer/extensions/extrude_tensor/test.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
#
|
| 3 |
+
# @File: test.py
|
| 4 |
+
# @Author: Haozhe Xie
|
| 5 |
+
# @Date: 2023-03-26 19:23:26
|
| 6 |
+
# @Last Modified by: Haozhe Xie
|
| 7 |
+
# @Last Modified at: 2023-04-15 10:47:53
|
| 8 |
+
# @Email: [email protected]
|
| 9 |
+
|
| 10 |
+
# Mayavi off screen rendering
|
| 11 |
+
# Ref: https://github.com/enthought/mayavi/issues/477#issuecomment-477653210
|
| 12 |
+
from xvfbwrapper import Xvfb
|
| 13 |
+
|
| 14 |
+
vdisplay = Xvfb(width=1920, height=1080)
|
| 15 |
+
vdisplay.start()
|
| 16 |
+
|
| 17 |
+
import logging
|
| 18 |
+
import mayavi.mlab
|
| 19 |
+
import numpy as np
|
| 20 |
+
import os
|
| 21 |
+
import sys
|
| 22 |
+
import torch
|
| 23 |
+
import unittest
|
| 24 |
+
|
| 25 |
+
from PIL import Image
|
| 26 |
+
from torch.autograd import gradcheck
|
| 27 |
+
|
| 28 |
+
sys.path.append(
|
| 29 |
+
os.path.abspath(
|
| 30 |
+
os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir)
|
| 31 |
+
)
|
| 32 |
+
)
|
| 33 |
+
from extensions.extrude_tensor import ExtrudeTensorFunction
|
| 34 |
+
|
| 35 |
+
# Disable the warning message for PIL decompression bomb
|
| 36 |
+
# Ref: https://stackoverflow.com/questions/25705773/image-cropping-tool-python
|
| 37 |
+
Image.MAX_IMAGE_PIXELS = None
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class ExtrudeTensorTestCase(unittest.TestCase):
    """Tests for the ``extrude_tensor`` CUDA extension."""

    @unittest.skip("The CUDA extension is compiled with int types by default.")
    def test_extrude_tensor_grad(self):
        """Numerically check the gradients of ``ExtrudeTensorFunction``."""
        # To run this test, make sure that the int types are replaced by double types in CUDA
        SIZE = 16
        seg_map = (
            torch.randint(low=1, high=7, size=(SIZE, SIZE))
            .double()
            .unsqueeze(dim=0)
            .unsqueeze(dim=0)
        )
        height_field = (
            torch.randint(low=0, high=255, size=(SIZE, SIZE))
            .double()
            .unsqueeze(dim=0)
            .unsqueeze(dim=0)
        )
        logging.debug("SegMap Size: %s" % (seg_map.size(),))
        logging.debug("HeightField Size: %s" % (height_field.size(),))
        seg_map.requires_grad = True
        height_field.requires_grad = True
        passed = gradcheck(
            ExtrudeTensorFunction.apply, [seg_map.cuda(), height_field.cuda(), 256]
        )
        # The conditional is parenthesised on purpose: "%" binds tighter than
        # "if/else", so without the parentheses a failure would log the bare
        # string "Failed" without the prefix.
        logging.info("Gradient Check: %s" % ("OK" if passed else "Failed"))
        # Make the outcome count as a test result instead of only logging it.
        self.assertTrue(passed)

    def test_extrude_tensor_gen(self):
        """Extrude a 256x256 crop of the New York maps and save a 3D render."""
        MAX_HEIGHT = 256
        proj_home_dir = os.path.join(
            os.path.dirname(__file__), os.path.pardir, os.path.pardir
        )
        osm_data_dir = os.path.join(proj_home_dir, "data", "osm")
        osm_name = "US-NewYork"
        seg_map = Image.open(os.path.join(osm_data_dir, osm_name, "seg.png")).convert(
            "P"
        )
        height_field = Image.open(os.path.join(osm_data_dir, osm_name, "hf.png"))
        # Crop the maps
        seg_map = np.array(seg_map)[3840:4096, 3840:4096]
        height_field = np.array(height_field)[3840:4096, 3840:4096]
        # Convert to tensors
        seg_map_tnsr = (
            torch.from_numpy(seg_map).unsqueeze(dim=0).unsqueeze(dim=0).int().cuda()
        )
        height_field_tnsr = (
            torch.from_numpy(height_field)
            .unsqueeze(dim=0)
            .unsqueeze(dim=0)
            .int()
            .cuda()
        )
        volume = ExtrudeTensorFunction.apply(
            seg_map_tnsr, height_field_tnsr, MAX_HEIGHT
        )
        # 3D Visualization
        vol = volume.squeeze().cpu().numpy().astype(np.uint8)

        x, y, z = np.where(vol != 0)
        n_pts = len(x)
        colors = np.zeros((n_pts, 4), dtype=np.uint8)
        # Look the labels up once instead of once per class.
        labels = vol[x, y, z]
        # fmt: off
        palette = {
            1: [96, 0, 0, 255],       # highway -> red
            2: [96, 96, 0, 255],      # building -> yellow
            3: [0, 96, 0, 255],       # green lands -> green
            4: [0, 96, 96, 255],      # construction -> cyan
            5: [0, 0, 96, 255],       # water -> blue
            6: [128, 128, 128, 255],  # ground -> gray
        }
        # fmt: on
        for label, rgba in palette.items():
            colors[labels == label] = rgba

        mayavi.mlab.options.offscreen = True
        mayavi.mlab.figure(size=(1600, 900), bgcolor=(1, 1, 1))
        pts = mayavi.mlab.points3d(x, y, z, mode="cube", scale_factor=1)
        pts.glyph.scale_mode = "scale_by_vector"
        pts.mlab_source.dataset.point_data.scalars = colors
        mayavi.mlab.savefig(os.path.join(proj_home_dir, "logs", "%s-3d.jpg" % osm_name))
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
if __name__ == "__main__":
    # Timestamped INFO-level logging so that the messages emitted by the tests
    # are visible on the console.
    log_format = "[%(levelname)s] %(asctime)s %(message)s"
    logging.basicConfig(format=log_format, level=logging.INFO)
    unittest.main()
|
citydreamer/extensions/grid_encoder/__init__.py
ADDED
|
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
#
|
| 3 |
+
# @File: __init__.py
|
| 4 |
+
# @Author: Jiaxiang Tang (@ashawkey)
|
| 5 |
+
# @Date: 2023-04-15 10:39:28
|
| 6 |
+
# @Last Modified by: Haozhe Xie
|
| 7 |
+
# @Last Modified at: 2023-04-15 13:08:46
|
| 8 |
+
# @Email: [email protected]
|
| 9 |
+
# @Ref: https://github.com/ashawkey/torch-ngp
|
| 10 |
+
|
| 11 |
+
import math
|
| 12 |
+
import numpy as np
|
| 13 |
+
import torch
|
| 14 |
+
|
| 15 |
+
import grid_encoder_ext
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class GridEncoderFunction(torch.autograd.Function):
    """Autograd wrapper around the ``grid_encoder_ext`` CUDA extension."""

    @staticmethod
    def forward(
        ctx,
        inputs,
        embeddings,
        offsets,
        per_level_scale,
        base_resolution,
        calc_grad_inputs=False,
        gridtype=0,
        align_corners=False,
    ):
        """Encode ``inputs`` with the multi-level embedding table.

        Args:
            inputs: [B, D], float in [0, 1].
            embeddings: [sO, C], float.
            offsets: [L + 1], int.
            per_level_scale: Resolution multiplier between two levels.
            base_resolution: Resolution of the first level.
            calc_grad_inputs: Whether the extension should also fill the
                ``dy_dx`` buffer needed for the gradient w.r.t. ``inputs``.
            gridtype: Integer grid type forwarded to the extension.
            align_corners: Flag forwarded to the extension.

        Returns:
            Tensor of shape [B, L * C].
        """
        inputs = inputs.contiguous()
        n_points, n_dims = inputs.shape  # batch size, coord dim
        n_levels = offsets.shape[0] - 1
        n_feats = embeddings.shape[1]  # embedding dim for each level
        # the extension expects log2 of the scale (for its exp2f)
        log2_scale = math.log2(per_level_scale)
        base_res = base_resolution

        buffer_opts = {"device": inputs.device, "dtype": embeddings.dtype}
        # Level-major layout for the CUDA kernel; permuted back below.
        outputs = torch.empty(n_levels, n_points, n_feats, **buffer_opts)
        # Without input gradients a one-element placeholder is passed instead.
        dy_dx_shape = (
            (n_points, n_levels * n_dims * n_feats) if calc_grad_inputs else (1,)
        )
        dy_dx = torch.empty(*dy_dx_shape, **buffer_opts)

        grid_encoder_ext.forward(
            inputs,
            embeddings,
            offsets,
            outputs,
            n_points,
            n_dims,
            n_feats,
            n_levels,
            log2_scale,
            base_res,
            calc_grad_inputs,
            dy_dx,
            gridtype,
            align_corners,
        )
        # permute back to [B, L * C]
        outputs = outputs.permute(1, 0, 2).reshape(n_points, n_levels * n_feats)

        ctx.save_for_backward(inputs, embeddings, offsets, dy_dx)
        ctx.dims = [n_points, n_dims, n_feats, n_levels, log2_scale, base_res, gridtype]
        ctx.calc_grad_inputs = calc_grad_inputs
        ctx.align_corners = align_corners
        return outputs

    @staticmethod
    def backward(ctx, grad):
        """Back-propagate ``grad`` ([B, L * C]) through the encoder.

        Returns:
            One entry per ``forward`` input. Only ``embeddings`` and, when
            ``calc_grad_inputs`` was set, ``inputs`` receive a gradient.
        """
        inputs, embeddings, offsets, dy_dx = ctx.saved_tensors
        n_points, n_dims, n_feats, n_levels, log2_scale, base_res, gridtype = ctx.dims
        need_input_grad = ctx.calc_grad_inputs
        align_corners = ctx.align_corners

        # grad: [B, L * C] --> [L, B, C]
        grad = grad.view(n_points, n_levels, n_feats).permute(1, 0, 2).contiguous()
        grad_embeddings = torch.zeros_like(embeddings)
        grad_inputs = (
            torch.zeros_like(inputs, dtype=embeddings.dtype)
            if need_input_grad
            else torch.zeros(1, device=inputs.device, dtype=embeddings.dtype)
        )

        grid_encoder_ext.backward(
            grad,
            inputs,
            embeddings,
            offsets,
            grad_embeddings,
            n_points,
            n_dims,
            n_feats,
            n_levels,
            log2_scale,
            base_res,
            need_input_grad,
            dy_dx,
            grad_inputs,
            gridtype,
            align_corners,
        )

        if not need_input_grad:
            return None, grad_embeddings, None, None, None, None, None, None

        grad_inputs = grad_inputs.to(inputs.dtype)
        return grad_inputs, grad_embeddings, None, None, None, None, None, None
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
class GridEncoder(torch.nn.Module):
    """Multi-level grid encoder backed by ``GridEncoderFunction``.

    Args:
        in_channels: Size of the last axis of the inputs.
        n_levels: Number of levels.
        lvl_channels: Number of encoded channels per level.
        desired_resolution: Used together with ``base_resolution`` to derive
            the ``per_level_scale`` attribute.
        per_level_scale: Scale used to size the per-level embedding tables.
        base_resolution: Resolution of the first level.
        log2_hashmap_size: log2 of the maximum number of entries per level.
        gridtype: ``"hash"`` maps to grid type id 0, anything else to 1.
        align_corners: Flag forwarded to the extension.
    """

    def __init__(
        self,
        in_channels,
        n_levels,
        lvl_channels,
        desired_resolution,
        per_level_scale=2,
        base_resolution=16,
        log2_hashmap_size=19,
        gridtype="hash",
        align_corners=False,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.n_levels = n_levels  # num levels
        self.lvl_channels = lvl_channels  # encode channels per level
        # Scale passed to the extension at run time.
        # NOTE(review): the table sizes below are computed from the
        # `per_level_scale` ARGUMENT, not from this derived attribute; the two
        # only agree when desired_resolution matches. Confirm before changing,
        # since the table size determines the shape of `embeddings`.
        scale_exponent = math.log2(desired_resolution / base_resolution) / (
            n_levels - 1
        )
        self.per_level_scale = 2**scale_exponent
        self.log2_hashmap_size = log2_hashmap_size
        self.base_resolution = base_resolution
        self.output_dim = n_levels * lvl_channels
        self.gridtype = gridtype
        self.gridtype_id = 0 if gridtype == "hash" else 1
        self.align_corners = align_corners

        # allocate parameters
        self.max_params = 2**log2_hashmap_size
        # one extra sample per axis unless corners are aligned
        extra = 0 if align_corners else 1
        level_starts = [0]
        for lvl in range(n_levels):
            side = int(math.ceil(base_resolution * per_level_scale**lvl)) + extra
            # limit max number
            n_entries = min(self.max_params, side**in_channels)
            # make divisible by 8
            n_entries = int(math.ceil(n_entries / 8) * 8)
            level_starts.append(level_starts[-1] + n_entries)

        total_entries = level_starts[-1]
        offsets = torch.from_numpy(np.array(level_starts, dtype=np.int32))
        self.register_buffer("offsets", offsets)

        self.n_params = offsets[-1] * lvl_channels
        self.embeddings = torch.nn.Parameter(torch.empty(total_entries, lvl_channels))
        self._init_weights()

    def _init_weights(self):
        """Initialise the embeddings uniformly in [-1e-4, 1e-4]."""
        self.embeddings.data.uniform_(-1e-4, 1e-4)

    def forward(self, inputs, bound=1):
        """Encode positions.

        Args:
            inputs: [..., in_channels], normalized real world positions in
                [-bound, bound].
            bound: Half-width of the input range.

        Returns:
            Tensor of shape [..., n_levels * lvl_channels].
        """
        unit_inputs = (inputs + bound) / (2 * bound)  # map to [0, 1]
        leading_shape = list(unit_inputs.shape[:-1])
        flat_inputs = unit_inputs.view(-1, self.in_channels)
        encoded = GridEncoderFunction.apply(
            flat_inputs,
            self.embeddings,
            self.offsets,
            self.per_level_scale,
            self.base_resolution,
            flat_inputs.requires_grad,
            self.gridtype_id,
            self.align_corners,
        )
        return encoded.view(leading_shape + [self.output_dim])
|
citydreamer/extensions/grid_encoder/bindings.cpp
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/**
|
| 2 |
+
* @File: grid_encoder_ext_cuda.cpp
|
| 3 |
+
* @Author: Jiaxiang Tang (@ashawkey)
|
| 4 |
+
* @Date: 2023-04-15 10:39:17
|
| 5 |
+
* @Last Modified by: Haozhe Xie
|
| 6 |
+
* @Last Modified at: 2023-04-15 11:01:32
|
| 7 |
+
* @Email: [email protected]
|
| 8 |
+
* @Ref: https://github.com/ashawkey/torch-ngp
|
| 9 |
+
*/
|
| 10 |
+
|
| 11 |
+
#include <stdint.h>
|
| 12 |
+
#include <torch/extension.h>
|
| 13 |
+
#include <torch/torch.h>
|
| 14 |
+
|
| 15 |
+
// inputs: [B, D], float, in [0, 1]
|
| 16 |
+
// embeddings: [sO, C], float
|
| 17 |
+
// offsets: [L + 1], uint32_t
|
| 18 |
+
// outputs: [L, B, C], float (the Python wrapper permutes it to [B, L * C])
|
| 19 |
+
// H: base resolution
|
| 20 |
+
void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings,
|
| 21 |
+
const at::Tensor offsets, at::Tensor outputs,
|
| 22 |
+
const uint32_t B, const uint32_t D, const uint32_t C,
|
| 23 |
+
const uint32_t L, const float S, const uint32_t H,
|
| 24 |
+
const bool calc_grad_inputs, at::Tensor dy_dx,
|
| 25 |
+
const uint32_t gridtype, const bool align_corners);
|
| 26 |
+
void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs,
|
| 27 |
+
const at::Tensor embeddings, const at::Tensor offsets,
|
| 28 |
+
at::Tensor grad_embeddings, const uint32_t B,
|
| 29 |
+
const uint32_t D, const uint32_t C, const uint32_t L,
|
| 30 |
+
const float S, const uint32_t H,
|
| 31 |
+
const bool calc_grad_inputs, const at::Tensor dy_dx,
|
| 32 |
+
at::Tensor grad_inputs, const uint32_t gridtype,
|
| 33 |
+
const bool align_corners);
|
| 34 |
+
|
| 35 |
+
// Python bindings: registers the two host entry points declared in this file
// on the compiled extension module under the names `forward` and `backward`.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &grid_encode_forward,
        "grid_encode_forward (CUDA)");
  m.def("backward", &grid_encode_backward,
        "grid_encode_backward (CUDA)");
}
|
citydreamer/extensions/grid_encoder/grid_encoder_ext.cu
ADDED
|
@@ -0,0 +1,605 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/**
|
| 2 |
+
* @File: grid_encoder_ext.cu
|
| 3 |
+
* @Author: Jiaxiang Tang (@ashawkey)
|
| 4 |
+
* @Date: 2023-04-15 10:43:16
|
| 5 |
+
* @Last Modified by: Haozhe Xie
|
| 6 |
+
* @Last Modified at: 2023-04-29 11:47:54
|
| 7 |
+
* @Email: [email protected]
|
| 8 |
+
* @Ref: https://github.com/ashawkey/torch-ngp
|
| 9 |
+
*/
|
| 10 |
+
|
| 11 |
+
#include <cuda.h>
|
| 12 |
+
#include <cuda_fp16.h>
|
| 13 |
+
#include <cuda_runtime.h>
|
| 14 |
+
|
| 15 |
+
#include <ATen/cuda/CUDAContext.h>
|
| 16 |
+
#include <torch/torch.h>
|
| 17 |
+
|
| 18 |
+
#include <algorithm>
|
| 19 |
+
#include <stdexcept>
|
| 20 |
+
|
| 21 |
+
#include <cstdio>
|
| 22 |
+
#include <stdint.h>
|
| 23 |
+
|
| 24 |
+
// Argument-validation helpers. Each one goes through TORCH_CHECK and reports
// the name of the offending expression (#x) in the error message.

// Fails unless x lives on a CUDA device.
#define CHECK_CUDA(x)                                                          \
  TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
// Fails unless x is contiguous (the kernels index raw data pointers).
#define CHECK_CONTIGUOUS(x)                                                    \
  TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor")
// Fails unless x has scalar type Int (32-bit), matching data_ptr<int>().
#define CHECK_IS_INT(x)                                                        \
  TORCH_CHECK(x.scalar_type() == at::ScalarType::Int,                          \
              #x " must be an int tensor")
// Fails unless x is Float, Half or Double.
#define CHECK_IS_FLOATING(x)                                                   \
  TORCH_CHECK(x.scalar_type() == at::ScalarType::Float ||                      \
                  x.scalar_type() == at::ScalarType::Half ||                   \
                  x.scalar_type() == at::ScalarType::Double,                   \
              #x " must be a floating tensor")
|
| 36 |
+
|
| 37 |
+
// at::Half overload of atomicAdd, needed so that kernels templated on
// scalar_t compile when AT_DISPATCH_FLOATING_TYPES_AND_HALF instantiates them
// with at::Half. kernel_grid_backward reaches this overload for at::Half when
// N_C is odd (i.e. C == 1); even N_C goes through the __half2 path instead.
//
// The body used to be empty: flowing off the end of a non-void function is
// undefined behaviour, and the requested addition was silently dropped.
//
// Returns the value previously stored at `address` where a native half
// atomicAdd is available.
static inline __device__ at::Half atomicAdd(at::Half *address, at::Half val) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700)
  // Native __half atomicAdd requires CUDA >= 10 and ARCH >= 70. It is slow
  // compared to float or __half2, so hot paths should prefer those.
  return atomicAdd(reinterpret_cast<__half *>(address),
                   static_cast<__half>(val));
#else
  // No native half atomicAdd on this target (and in the host compilation
  // pass): keep the historical no-op semantics, but return a defined value.
  return val;
#endif
}
|
| 44 |
+
|
| 45 |
+
// Ceiling division: number of `divisor`-sized chunks needed to cover `val`.
// Intended for non-negative `val` and positive `divisor`; `val + divisor - 1`
// must not overflow T.
template <typename T>
static inline __host__ __device__ T div_round_up(T val, T divisor) {
  const T padded = val + divisor - 1;
  return padded / divisor;
}
|
| 49 |
+
|
| 50 |
+
// Spatial hash of a D-dimensional integer grid coordinate: XOR of each
// coordinate multiplied (mod 2^32) by a per-dimension constant.
template <uint32_t D>
__device__ uint32_t fast_hash(const uint32_t pos_grid[D]) {
  static_assert(D <= 7, "fast_hash can only hash up to 7 dimensions.");

  // The first multiplier is 1, which is not a prime and not a good hashing
  // constant, but it keeps neighbouring cells along the first axis close in
  // memory and is sufficient for obtaining a uniformly colliding index from
  // high-dimensional coordinates.
  constexpr uint32_t primes[7] = {1,          2654435761, 805459861,
                                  3674653429, 2097192037, 1434869437,
                                  2165219737};

  uint32_t hashed = 0;
#pragma unroll
  for (uint32_t dim = 0; dim < D; ++dim) {
    hashed = hashed ^ (pos_grid[dim] * primes[dim]);
  }
  return hashed;
}
|
| 68 |
+
|
| 69 |
+
// Maps an integer grid coordinate to the offset of channel `ch` inside one
// level's embedding table (C values per entry, `hashmap_size` entries).
//
// The coordinate is linearised axis by axis for as long as the running stride
// still fits in the table. If the dense grid does not fit and gridtype == 0
// (hash), the linear index is replaced by fast_hash(). Either way the index is
// wrapped with `% hashmap_size`.
template <uint32_t D, uint32_t C>
__device__ uint32_t get_grid_index(const uint32_t gridtype,
                                   const bool align_corners, const uint32_t ch,
                                   const uint32_t hashmap_size,
                                   const uint32_t resolution,
                                   const uint32_t pos_grid[D]) {
  // Number of grid vertices per axis used as the stride multiplier.
  const uint32_t extent = align_corners ? resolution : (resolution + 1);

  uint32_t linear = 0;
  uint32_t stride = 1;

#pragma unroll
  for (uint32_t axis = 0; axis < D && stride <= hashmap_size; ++axis) {
    linear += pos_grid[axis] * stride;
    stride *= extent;
  }

  // NOTE: for NeRF, the hash is in fact not necessary. Check
  // https://github.com/NVlabs/instant-ngp/issues/97.
  // gridtype: 0 == hash, 1 == tiled
  const bool dense_grid_overflows = stride > hashmap_size;
  if (gridtype == 0 && dense_grid_overflows) {
    linear = fast_hash<D>(pos_grid);
  }

  const uint32_t slot = linear % hashmap_size;
  return slot * C + ch;
}
|
| 93 |
+
|
| 94 |
+
// Forward kernel of the multi-resolution grid encoder.
//
// Launch layout: blockIdx.y selects the level; the x dimension enumerates
// samples, one thread per (sample b, level).
//
// For its sample/level the thread
//   1. writes zeros (and zero dy_dx) if any input coordinate is outside [0, 1];
//   2. otherwise scales the coordinate to the level's resolution, multilinearly
//      interpolates the C features of the 2^D surrounding grid vertices and
//      stores them at outputs[level][b][0..C) (layout [L, B, C]);
//   3. if calc_grad_inputs is set, also stores d(output)/d(input) for every
//      input axis at dy_dx[b][level][d][0..C) (layout [B, L, D, C]).
//
// Parameters:
//   inputs   [B, D] coordinates, read as float
//   grid     embedding table; level l starts at entry offsets[l]
//   offsets  [L + 1] entry offsets; offsets[l + 1] - offsets[l] is the table
//            size of level l
//   S, H     per-level scale exponent and base resolution:
//            scale = 2^(level * S) * H - 1
//   gridtype, align_corners  forwarded to get_grid_index()
template <typename scalar_t, uint32_t D, uint32_t C>
__global__ void
kernel_grid(const float *__restrict__ inputs, const scalar_t *__restrict__ grid,
            const int *__restrict__ offsets, scalar_t *__restrict__ outputs,
            const uint32_t B, const uint32_t L, const float S, const uint32_t H,
            const bool calc_grad_inputs, scalar_t *__restrict__ dy_dx,
            const uint32_t gridtype, const bool align_corners) {
  const uint32_t b = blockIdx.x * blockDim.x + threadIdx.x;

  // The last block along x may be only partially filled.
  if (b >= B)
    return;

  const uint32_t level = blockIdx.y;

  // locate: advance the pointers to this thread's level / sample
  grid += (uint32_t)offsets[level] * C;
  inputs += b * D;
  outputs += level * B * C + b * C;

  // check input range (should be in [0, 1])
  bool flag_oob = false;
#pragma unroll
  for (uint32_t d = 0; d < D; d++) {
    if (inputs[d] < 0 || inputs[d] > 1) {
      flag_oob = true;
    }
  }
  // if input out of bound, just set output to 0
  if (flag_oob) {
#pragma unroll
    for (uint32_t ch = 0; ch < C; ch++) {
      outputs[ch] = 0;
    }
    if (calc_grad_inputs) {
      dy_dx += b * D * L * C + level * D * C; // B L D C
#pragma unroll
      for (uint32_t d = 0; d < D; d++) {
#pragma unroll
        for (uint32_t ch = 0; ch < C; ch++) {
          dy_dx[d * C + ch] = 0;
        }
      }
    }
    return;
  }

  // Per-level table size and resolution.
  const uint32_t hashmap_size = offsets[level + 1] - offsets[level];
  const float scale = exp2f(level * S) * H - 1.0f;
  const uint32_t resolution = (uint32_t)ceil(scale) + 1;

  // calculate coordinate: pos_grid is the integer cell corner, pos the
  // fractional position inside the cell
  float pos[D];
  uint32_t pos_grid[D];

#pragma unroll
  for (uint32_t d = 0; d < D; d++) {
    pos[d] = inputs[d] * scale + (align_corners ? 0.0f : 0.5f);
    pos_grid[d] = floorf(pos[d]);
    pos[d] -= (float)pos_grid[d];
  }

  // printf("[b=%d, l=%d] pos=(%f, %f)+(%d, %d)\n", b, level, pos[0], pos[1],
  // pos_grid[0], pos_grid[1]);

  // interpolate
  scalar_t results[C] = {0}; // temp results in register

  // Bit d of idx selects the lower (0) or upper (1) vertex along axis d, so
  // the loop visits all 2^D corners of the cell.
#pragma unroll
  for (uint32_t idx = 0; idx < (1 << D); idx++) {
    float w = 1;
    uint32_t pos_grid_local[D];

#pragma unroll
    for (uint32_t d = 0; d < D; d++) {
      if ((idx & (1 << d)) == 0) {
        w *= 1 - pos[d];
        pos_grid_local[d] = pos_grid[d];
      } else {
        w *= pos[d];
        pos_grid_local[d] = pos_grid[d] + 1;
      }
    }

    uint32_t index = get_grid_index<D, C>(
        gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local);

    // writing to register (fast)
#pragma unroll
    for (uint32_t ch = 0; ch < C; ch++) {
      results[ch] += w * grid[index + ch];
    }

    // printf("[b=%d, l=%d] int %d, idx %d, w %f, val %f\n", b, level, idx,
    // index, w, grid[index]);
  }

  // writing to global memory (slow)
#pragma unroll
  for (uint32_t ch = 0; ch < C; ch++) {
    outputs[ch] = results[ch];
  }

  // prepare dy_dx for calc_grad_inputs
  // differentiable (soft) indexing:
  // https://discuss.pytorch.org/t/differentiable-indexing/17647/9
  if (calc_grad_inputs) {

    dy_dx += b * D * L * C + level * D * C; // B L D C

    // gd is the axis being differentiated.
#pragma unroll
    for (uint32_t gd = 0; gd < D; gd++) {

      scalar_t results_grad[C] = {0};

      // Enumerate the 2^(D-1) corner combinations of the remaining axes.
#pragma unroll
      for (uint32_t idx = 0; idx < (1 << (D - 1)); idx++) {
        // Starts at `scale` because pos = input * scale (chain rule).
        float w = scale;
        uint32_t pos_grid_local[D];

#pragma unroll
        for (uint32_t nd = 0; nd < D - 1; nd++) {
          // Map the compact index nd onto the real axis, skipping gd.
          const uint32_t d = (nd >= gd) ? (nd + 1) : nd;

          if ((idx & (1 << nd)) == 0) {
            w *= 1 - pos[d];
            pos_grid_local[d] = pos_grid[d];
          } else {
            w *= pos[d];
            pos_grid_local[d] = pos_grid[d] + 1;
          }
        }

        // Finite difference between the two vertices along gd.
        pos_grid_local[gd] = pos_grid[gd];
        uint32_t index_left =
            get_grid_index<D, C>(gridtype, align_corners, 0, hashmap_size,
                                 resolution, pos_grid_local);
        pos_grid_local[gd] = pos_grid[gd] + 1;
        uint32_t index_right =
            get_grid_index<D, C>(gridtype, align_corners, 0, hashmap_size,
                                 resolution, pos_grid_local);

#pragma unroll
        for (uint32_t ch = 0; ch < C; ch++) {
          results_grad[ch] +=
              w * (grid[index_right + ch] - grid[index_left + ch]);
        }
      }

#pragma unroll
      for (uint32_t ch = 0; ch < C; ch++) {
        dy_dx[gd * C + ch] = results_grad[ch];
      }
    }
  }
}
|
| 249 |
+
|
| 250 |
+
// Backward kernel w.r.t. the embedding table.
//
// Launch layout: blockIdx.y selects the level; along x each thread handles
// N_C consecutive channels of one sample (C / N_C threads per sample).
//
// The thread recomputes the interpolation weights of the forward pass and
// scatters w * grad into grad_grid at the 2^D surrounding vertices using
// atomicAdd, since several samples can touch the same table entry.
//
// Parameters:
//   grad       upstream gradient, layout [L, B, C]
//   inputs     [B, D] coordinates, read as float
//   grid       unused in the body; kept for signature symmetry with the
//              forward kernel
//   offsets    [L + 1] entry offsets per level
//   grad_grid  gradient of the embedding table, accumulated in place
//              (the early return for out-of-range inputs relies on it being
//              zero-initialised by the caller)
template <typename scalar_t, uint32_t D, uint32_t C, uint32_t N_C>
__global__ void kernel_grid_backward(
    const scalar_t *__restrict__ grad, const float *__restrict__ inputs,
    const scalar_t *__restrict__ grid, const int *__restrict__ offsets,
    scalar_t *__restrict__ grad_grid, const uint32_t B, const uint32_t L,
    const float S, const uint32_t H, const uint32_t gridtype,
    const bool align_corners) {
  // Sample index handled by this thread.
  const uint32_t b = (blockIdx.x * blockDim.x + threadIdx.x) * N_C / C;
  if (b >= B)
    return;

  const uint32_t level = blockIdx.y;
  // First channel handled by this thread within sample b.
  const uint32_t ch = (blockIdx.x * blockDim.x + threadIdx.x) * N_C - b * C;

  // locate
  grad_grid += offsets[level] * C;
  inputs += b * D;
  grad += level * B * C + b * C + ch; // L, B, C

  const uint32_t hashmap_size = offsets[level + 1] - offsets[level];
  const float scale = exp2f(level * S) * H - 1.0f;
  const uint32_t resolution = (uint32_t)ceil(scale) + 1;

  // check input range (should be in [0, 1])
#pragma unroll
  for (uint32_t d = 0; d < D; d++) {
    if (inputs[d] < 0 || inputs[d] > 1) {
      return; // grad is init as 0, so we simply return.
    }
  }

  // calculate coordinate (same mapping as the forward kernel)
  float pos[D];
  uint32_t pos_grid[D];

#pragma unroll
  for (uint32_t d = 0; d < D; d++) {
    pos[d] = inputs[d] * scale + (align_corners ? 0.0f : 0.5f);
    pos_grid[d] = floorf(pos[d]);
    pos[d] -= (float)pos_grid[d];
  }

  scalar_t grad_cur[N_C] = {0}; // fetch to register
#pragma unroll
  for (uint32_t c = 0; c < N_C; c++) {
    grad_cur[c] = grad[c];
  }

  // interpolate: visit all 2^D corners; bit d of idx picks lower/upper vertex
#pragma unroll
  for (uint32_t idx = 0; idx < (1 << D); idx++) {
    float w = 1;
    uint32_t pos_grid_local[D];

#pragma unroll
    for (uint32_t d = 0; d < D; d++) {
      if ((idx & (1 << d)) == 0) {
        w *= 1 - pos[d];
        pos_grid_local[d] = pos_grid[d];
      } else {
        w *= pos[d];
        pos_grid_local[d] = pos_grid[d] + 1;
      }
    }

    uint32_t index = get_grid_index<D, C>(
        gridtype, align_corners, ch, hashmap_size, resolution, pos_grid_local);

    // atomicAdd for __half is slow (especially for large values), so we use
    // __half2 if N_C % 2 == 0
    // TODO: use float which is better than __half, if N_C % 2 != 0
    if (std::is_same<scalar_t, at::Half>::value && N_C % 2 == 0) {
#pragma unroll
      for (uint32_t c = 0; c < N_C; c += 2) {
        // process two __half at once (by interpreting as a __half2)
        __half2 v = {(__half)(w * grad_cur[c]), (__half)(w * grad_cur[c + 1])};
        atomicAdd((__half2 *)&grad_grid[index + c], v);
      }
      // float, or __half when N_C % 2 != 0 (which means C == 1)
      // NOTE(review): for at::Half this resolves to the at::Half atomicAdd
      // overload defined near the top of this file.
    } else {
#pragma unroll
      for (uint32_t c = 0; c < N_C; c++) {
        atomicAdd(&grad_grid[index + c], w * grad_cur[c]);
      }
    }
  }
}
|
| 337 |
+
|
| 338 |
+
// Backward kernel w.r.t. the input coordinates.
//
// One thread per (sample, input axis) pair. It contracts the upstream
// gradient `grad` (layout [L, B, C]) with the Jacobian `dy_dx` stored by the
// forward pass (layout [B, L, D, C]) over all levels and channels, and writes
// the result to grad_inputs[sample * D + axis].
template <typename scalar_t, uint32_t D, uint32_t C>
__global__ void kernel_input_backward(const scalar_t *__restrict__ grad,
                                      const scalar_t *__restrict__ dy_dx,
                                      scalar_t *__restrict__ grad_inputs,
                                      uint32_t B, uint32_t L) {
  const uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid >= B * D) {
    return;
  }

  // Decompose the flat thread index into (sample, axis).
  const uint32_t sample = tid / D;
  const uint32_t axis = tid - sample * D;

  // Jacobian slice of this sample: [L, D, C].
  const scalar_t *jacobian = dy_dx + sample * L * D * C;

  scalar_t accum = 0;

#pragma unroll
  for (uint32_t level = 0; level < L; ++level) {
#pragma unroll
    for (uint32_t c = 0; c < C; ++c) {
      accum += grad[level * B * C + sample * C + c] *
               jacobian[level * D * C + axis * C + c];
    }
  }

  grad_inputs[tid] = accum;
}
|
| 364 |
+
|
| 365 |
+
// Host-side launcher for kernel_grid with a compile-time input dimension D.
// Dispatches the runtime channel count C onto the matching template
// instantiation and launches one thread per (sample, level):
// grid = (ceil(B / N_THREAD), L), block = N_THREAD.
//
// Kernels are enqueued on PyTorch's current CUDA stream so they are ordered
// with the surrounding tensor operations. Nothing is launched for an empty
// batch (B == 0) or zero levels (L == 0), because a zero-sized grid is an
// invalid launch configuration.
//
// Throws std::runtime_error if C is not 1, 2, 4 or 8.
template <typename scalar_t, uint32_t D>
void kernel_grid_wrapper(const float *inputs, const scalar_t *embeddings,
                         const int *offsets, scalar_t *outputs,
                         const uint32_t B, const uint32_t C, const uint32_t L,
                         const float S, const uint32_t H,
                         const bool calc_grad_inputs, scalar_t *dy_dx,
                         const uint32_t gridtype, const bool align_corners) {
  static constexpr uint32_t N_THREAD = 512;

  // Validate C first so an unsupported value is reported even for empty input.
  if (C != 1 && C != 2 && C != 4 && C != 8) {
    throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."};
  }
  if (B == 0 || L == 0) {
    return;
  }

  const dim3 blocks_hashgrid = {div_round_up(B, N_THREAD), L, 1};
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();

  switch (C) {
  case 1:
    kernel_grid<scalar_t, D, 1><<<blocks_hashgrid, N_THREAD, 0, stream>>>(
        inputs, embeddings, offsets, outputs, B, L, S, H, calc_grad_inputs,
        dy_dx, gridtype, align_corners);
    break;
  case 2:
    kernel_grid<scalar_t, D, 2><<<blocks_hashgrid, N_THREAD, 0, stream>>>(
        inputs, embeddings, offsets, outputs, B, L, S, H, calc_grad_inputs,
        dy_dx, gridtype, align_corners);
    break;
  case 4:
    kernel_grid<scalar_t, D, 4><<<blocks_hashgrid, N_THREAD, 0, stream>>>(
        inputs, embeddings, offsets, outputs, B, L, S, H, calc_grad_inputs,
        dy_dx, gridtype, align_corners);
    break;
  case 8:
    kernel_grid<scalar_t, D, 8><<<blocks_hashgrid, N_THREAD, 0, stream>>>(
        inputs, embeddings, offsets, outputs, B, L, S, H, calc_grad_inputs,
        dy_dx, gridtype, align_corners);
    break;
  default:
    break; // unreachable: C was validated above
  }
}
|
| 399 |
+
|
| 400 |
+
// Typed forward entry point: dispatches the runtime input dimension D onto
// kernel_grid_wrapper<scalar_t, D>.
//
// inputs:     [B, D], float, in [0, 1]
// embeddings: [sO, C], scalar_t
// offsets:    [L + 1], int32
// outputs:    [L, B, C], scalar_t (L first, so only one level of hashmap
//             needs to fit into cache at a time.)
// H:          base resolution
// dy_dx:      [B, L * D * C], scalar_t
//
// Throws std::runtime_error if D is not 2, 3, 4 or 5 (and, via the wrapper,
// if C is not 1, 2, 4 or 8).
template <typename scalar_t>
void grid_encode_forward_cuda(const float *inputs, const scalar_t *embeddings,
                              const int *offsets, scalar_t *outputs,
                              const uint32_t B, const uint32_t D,
                              const uint32_t C, const uint32_t L, const float S,
                              const uint32_t H, const bool calc_grad_inputs,
                              scalar_t *dy_dx, const uint32_t gridtype,
                              const bool align_corners) {
  switch (D) {
  case 2:
    kernel_grid_wrapper<scalar_t, 2>(inputs, embeddings, offsets, outputs, B, C,
                                     L, S, H, calc_grad_inputs, dy_dx, gridtype,
                                     align_corners);
    break;
  case 3:
    kernel_grid_wrapper<scalar_t, 3>(inputs, embeddings, offsets, outputs, B, C,
                                     L, S, H, calc_grad_inputs, dy_dx, gridtype,
                                     align_corners);
    break;
  case 4:
    kernel_grid_wrapper<scalar_t, 4>(inputs, embeddings, offsets, outputs, B, C,
                                     L, S, H, calc_grad_inputs, dy_dx, gridtype,
                                     align_corners);
    break;
  case 5:
    kernel_grid_wrapper<scalar_t, 5>(inputs, embeddings, offsets, outputs, B, C,
                                     L, S, H, calc_grad_inputs, dy_dx, gridtype,
                                     align_corners);
    break;
  default:
    // The offending parameter here is D, not C.
    throw std::runtime_error{"GridEncoding: D must be 2, 3, 4, or 5."};
  }
}
|
| 438 |
+
|
| 439 |
+
// Host-side launcher for the backward kernels with a compile-time input
// dimension D. Dispatches the runtime channel count C onto the matching
// instantiation of kernel_grid_backward (each thread handles N_C = min(2, C)
// channels) and, when calc_grad_inputs is set, also launches
// kernel_input_backward with one thread per (sample, axis).
//
// Kernels are enqueued on PyTorch's current CUDA stream so they are ordered
// with the surrounding tensor operations. Nothing is launched for an empty
// batch (B == 0) or zero levels (L == 0), because a zero-sized grid is an
// invalid launch configuration.
//
// Throws std::runtime_error if C is not 1, 2, 4 or 8.
template <typename scalar_t, uint32_t D>
void kernel_grid_backward_wrapper(
    const scalar_t *grad, const float *inputs, const scalar_t *embeddings,
    const int *offsets, scalar_t *grad_embeddings, const uint32_t B,
    const uint32_t C, const uint32_t L, const float S, const uint32_t H,
    const bool calc_grad_inputs, scalar_t *dy_dx, scalar_t *grad_inputs,
    const uint32_t gridtype, const bool align_corners) {
  static constexpr uint32_t N_THREAD = 256;

  // Validate C first so an unsupported value is reported even for empty input.
  if (C != 1 && C != 2 && C != 4 && C != 8) {
    throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."};
  }
  if (B == 0 || L == 0) {
    return;
  }

  const uint32_t N_C = std::min(2u, C); // n_features_per_thread
  const dim3 blocks_hashgrid = {div_round_up(B * C / N_C, N_THREAD), L, 1};
  const uint32_t blocks_input = div_round_up(B * D, N_THREAD);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();

  switch (C) {
  case 1:
    kernel_grid_backward<scalar_t, D, 1, 1>
        <<<blocks_hashgrid, N_THREAD, 0, stream>>>(
            grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H,
            gridtype, align_corners);
    if (calc_grad_inputs)
      kernel_input_backward<scalar_t, D, 1>
          <<<blocks_input, N_THREAD, 0, stream>>>(grad, dy_dx, grad_inputs, B,
                                                  L);
    break;
  case 2:
    kernel_grid_backward<scalar_t, D, 2, 2>
        <<<blocks_hashgrid, N_THREAD, 0, stream>>>(
            grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H,
            gridtype, align_corners);
    if (calc_grad_inputs)
      kernel_input_backward<scalar_t, D, 2>
          <<<blocks_input, N_THREAD, 0, stream>>>(grad, dy_dx, grad_inputs, B,
                                                  L);
    break;
  case 4:
    kernel_grid_backward<scalar_t, D, 4, 2>
        <<<blocks_hashgrid, N_THREAD, 0, stream>>>(
            grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H,
            gridtype, align_corners);
    if (calc_grad_inputs)
      kernel_input_backward<scalar_t, D, 4>
          <<<blocks_input, N_THREAD, 0, stream>>>(grad, dy_dx, grad_inputs, B,
                                                  L);
    break;
  case 8:
    kernel_grid_backward<scalar_t, D, 8, 2>
        <<<blocks_hashgrid, N_THREAD, 0, stream>>>(
            grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H,
            gridtype, align_corners);
    if (calc_grad_inputs)
      kernel_input_backward<scalar_t, D, 8>
          <<<blocks_input, N_THREAD, 0, stream>>>(grad, dy_dx, grad_inputs, B,
                                                  L);
    break;
  default:
    break; // unreachable: C was validated above
  }
}
|
| 490 |
+
|
| 491 |
+
// Typed backward entry point: dispatches the runtime input dimension D onto
// kernel_grid_backward_wrapper<scalar_t, D>.
//
// grad:            [L, B, C], scalar_t
// inputs:          [B, D], float, in [0, 1]
// embeddings:      [sO, C], scalar_t
// offsets:         [L + 1], int32
// grad_embeddings: [sO, C], scalar_t
// H:               base resolution
//
// Throws std::runtime_error if D is not 2, 3, 4 or 5 (and, via the wrapper,
// if C is not 1, 2, 4 or 8).
template <typename scalar_t>
void grid_encode_backward_cuda(
    const scalar_t *grad, const float *inputs, const scalar_t *embeddings,
    const int *offsets, scalar_t *grad_embeddings, const uint32_t B,
    const uint32_t D, const uint32_t C, const uint32_t L, const float S,
    const uint32_t H, const bool calc_grad_inputs, scalar_t *dy_dx,
    scalar_t *grad_inputs, const uint32_t gridtype, const bool align_corners) {
  switch (D) {
  case 2:
    kernel_grid_backward_wrapper<scalar_t, 2>(
        grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H,
        calc_grad_inputs, dy_dx, grad_inputs, gridtype, align_corners);
    break;
  case 3:
    kernel_grid_backward_wrapper<scalar_t, 3>(
        grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H,
        calc_grad_inputs, dy_dx, grad_inputs, gridtype, align_corners);
    break;
  case 4:
    kernel_grid_backward_wrapper<scalar_t, 4>(
        grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H,
        calc_grad_inputs, dy_dx, grad_inputs, gridtype, align_corners);
    break;
  case 5:
    kernel_grid_backward_wrapper<scalar_t, 5>(
        grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H,
        calc_grad_inputs, dy_dx, grad_inputs, gridtype, align_corners);
    break;
  default:
    // The offending parameter here is D, not C.
    throw std::runtime_error{"GridEncoding: D must be 2, 3, 4, or 5."};
  }
}
|
| 529 |
+
|
| 530 |
+
// Tensor-level forward entry point exposed to Python.
//
// Validates that every tensor is a contiguous CUDA tensor of the expected
// dtype, then dispatches on the dtype of `embeddings` and forwards the raw
// pointers to grid_encode_forward_cuda. `outputs` and `dy_dx` are written in
// place; shapes are NOT validated against B, D, C and L here, so the caller
// must allocate them consistently.
//
// Raises (through TORCH_CHECK) if a tensor is not on CUDA, not contiguous or
// has an unsupported dtype.
void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings,
                         const at::Tensor offsets, at::Tensor outputs,
                         const uint32_t B, const uint32_t D, const uint32_t C,
                         const uint32_t L, const float S, const uint32_t H,
                         const bool calc_grad_inputs, at::Tensor dy_dx,
                         const uint32_t gridtype, const bool align_corners) {
  CHECK_CUDA(inputs);
  CHECK_CUDA(embeddings);
  CHECK_CUDA(offsets);
  CHECK_CUDA(outputs);
  CHECK_CUDA(dy_dx);

  CHECK_CONTIGUOUS(inputs);
  CHECK_CONTIGUOUS(embeddings);
  CHECK_CONTIGUOUS(offsets);
  CHECK_CONTIGUOUS(outputs);
  CHECK_CONTIGUOUS(dy_dx);

  CHECK_IS_FLOATING(inputs);
  // The kernels read `inputs` through a `const float *`, so half/double
  // inputs cannot be handled even though they pass CHECK_IS_FLOATING.
  TORCH_CHECK(inputs.scalar_type() == at::ScalarType::Float,
              "inputs must be a float32 tensor");
  CHECK_IS_FLOATING(embeddings);
  CHECK_IS_INT(offsets);
  CHECK_IS_FLOATING(outputs);
  CHECK_IS_FLOATING(dy_dx);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      embeddings.scalar_type(), "grid_encode_forward", ([&] {
        grid_encode_forward_cuda<scalar_t>(
            inputs.data_ptr<float>(), embeddings.data_ptr<scalar_t>(),
            offsets.data_ptr<int>(), outputs.data_ptr<scalar_t>(), B, D, C, L,
            S, H, calc_grad_inputs, dy_dx.data_ptr<scalar_t>(), gridtype,
            align_corners);
      }));
}
|
| 563 |
+
|
| 564 |
+
// Tensor-level backward entry point exposed to Python.
//
// Validates that every tensor is a contiguous CUDA tensor of the expected
// dtype, then dispatches on the dtype of `grad` and forwards the raw pointers
// to grid_encode_backward_cuda. `grad_embeddings` is accumulated into and
// `grad_inputs` is written in place; shapes are NOT validated against B, D, C
// and L here, so the caller must allocate them consistently.
//
// Raises (through TORCH_CHECK) if a tensor is not on CUDA, not contiguous or
// has an unsupported dtype.
void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs,
                          const at::Tensor embeddings, const at::Tensor offsets,
                          at::Tensor grad_embeddings, const uint32_t B,
                          const uint32_t D, const uint32_t C, const uint32_t L,
                          const float S, const uint32_t H,
                          const bool calc_grad_inputs, const at::Tensor dy_dx,
                          at::Tensor grad_inputs, const uint32_t gridtype,
                          const bool align_corners) {
  CHECK_CUDA(grad);
  CHECK_CUDA(inputs);
  CHECK_CUDA(embeddings);
  CHECK_CUDA(offsets);
  CHECK_CUDA(grad_embeddings);
  CHECK_CUDA(dy_dx);
  CHECK_CUDA(grad_inputs);

  CHECK_CONTIGUOUS(grad);
  CHECK_CONTIGUOUS(inputs);
  CHECK_CONTIGUOUS(embeddings);
  CHECK_CONTIGUOUS(offsets);
  CHECK_CONTIGUOUS(grad_embeddings);
  CHECK_CONTIGUOUS(dy_dx);
  CHECK_CONTIGUOUS(grad_inputs);

  CHECK_IS_FLOATING(grad);
  CHECK_IS_FLOATING(inputs);
  // The kernels read `inputs` through a `const float *`, so half/double
  // inputs cannot be handled even though they pass CHECK_IS_FLOATING.
  TORCH_CHECK(inputs.scalar_type() == at::ScalarType::Float,
              "inputs must be a float32 tensor");
  CHECK_IS_FLOATING(embeddings);
  CHECK_IS_INT(offsets);
  CHECK_IS_FLOATING(grad_embeddings);
  CHECK_IS_FLOATING(dy_dx);
  CHECK_IS_FLOATING(grad_inputs);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      grad.scalar_type(), "grid_encode_backward", ([&] {
        grid_encode_backward_cuda<scalar_t>(
            grad.data_ptr<scalar_t>(), inputs.data_ptr<float>(),
            embeddings.data_ptr<scalar_t>(), offsets.data_ptr<int>(),
            grad_embeddings.data_ptr<scalar_t>(), B, D, C, L, S, H,
            calc_grad_inputs, dy_dx.data_ptr<scalar_t>(),
            grad_inputs.data_ptr<scalar_t>(), gridtype, align_corners);
      }));
}
|
citydreamer/extensions/grid_encoder/setup.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
#
|
| 3 |
+
# @File: setup.py
|
| 4 |
+
# @Author: Jiaxiang Tang (@ashawkey)
|
| 5 |
+
# @Date: 2023-04-15 10:33:32
|
| 6 |
+
# @Last Modified by: Haozhe Xie
|
| 7 |
+
# @Last Modified at: 2023-04-29 10:47:10
|
| 8 |
+
# @Email: [email protected]
|
| 9 |
+
# @Ref: https://github.com/ashawkey/torch-ngp
|
| 10 |
+
|
| 11 |
+
from setuptools import setup
|
| 12 |
+
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
| 13 |
+
|
| 14 |
+
# Flags passed to both the host C++ compiler and nvcc.
# NOTE(review): recent PyTorch releases may require a newer C++ standard than
# c++14 -- confirm against the torch version pinned in requirements.txt.
_COMMON_FLAGS = ["-O3", "-std=c++14"]

# nvcc additionally undefines the __CUDA_NO_HALF* macros, presumably so the
# kernels can use __half / __half2 operators and conversions -- confirm.
_NVCC_FLAGS = _COMMON_FLAGS + [
    "-U__CUDA_NO_HALF_OPERATORS__",
    "-U__CUDA_NO_HALF_CONVERSIONS__",
    "-U__CUDA_NO_HALF2_OPERATORS__",
]

# The compiled module is importable as `grid_encoder_ext`; source paths are
# relative to this file's directory.
_grid_encoder_ext = CUDAExtension(
    name="grid_encoder_ext",
    sources=["grid_encoder_ext.cu", "bindings.cpp"],
    extra_compile_args={"cxx": list(_COMMON_FLAGS), "nvcc": _NVCC_FLAGS},
)

setup(
    name="grid_encoder",
    version="1.0.0",
    ext_modules=[_grid_encoder_ext],
    cmdclass={"build_ext": BuildExtension},
)
|
citydreamer/extensions/voxlib/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# This work is made available under the Nvidia Source Code License-NC.
|
| 4 |
+
# To view a copy of this license, check out LICENSE.md
|
| 5 |
+
from voxlib import ray_voxel_intersection_perspective
|
citydreamer/extensions/voxlib/ray_voxel_intersection.cu
ADDED
|
@@ -0,0 +1,351 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
//
|
| 3 |
+
// This work is made available under the Nvidia Source Code License-NC.
|
| 4 |
+
// To view a copy of this license, check out LICENSE.md
|
| 5 |
+
//
|
| 6 |
+
// The ray marching algorithm used in this file is a variety of modified
|
| 7 |
+
// Bresenham method:
|
| 8 |
+
// http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.42.3443&rep=rep1&type=pdf
|
| 9 |
+
// Search for "voxel traversal algorithm" for related information
|
| 10 |
+
|
| 11 |
+
#include <torch/types.h>
|
| 12 |
+
|
| 13 |
+
#include <ATen/ATen.h>
|
| 14 |
+
#include <ATen/AccumulateType.h>
|
| 15 |
+
#include <ATen/cuda/CUDAApplyUtils.cuh>
|
| 16 |
+
#include <ATen/cuda/CUDAContext.h>
|
| 17 |
+
|
| 18 |
+
#include <cuda.h>
|
| 19 |
+
#include <cuda_runtime.h>
|
| 20 |
+
#include <curand.h>
|
| 21 |
+
#include <curand_kernel.h>
|
| 22 |
+
#include <time.h>
|
| 23 |
+
|
| 24 |
+
//#include <pybind11/numpy.h>
|
| 25 |
+
#include <pybind11/pybind11.h>
|
| 26 |
+
#include <pybind11/stl.h>
|
| 27 |
+
#include <vector>
|
| 28 |
+
|
| 29 |
+
#include "voxlib_common.h"
|
| 30 |
+
|
| 31 |
+
// Parameter bundle passed by value to
// ray_voxel_intersection_perspective_kernel.
struct RVIP_Params {
  // Extent of the voxel grid along each of its three axes.
  int voxel_dims[3];
  // Element strides of the voxel tensor along each of its three axes.
  int voxel_strides[3];
  // Number of intersection slots written per pixel.
  int max_samples;
  // Output image size; index 0 is the row (height) count, index 1 the column
  // (width) count.
  int img_dims[2];
  // Camera parameters
  // Ray origin shared by all pixels, in voxel-grid coordinates.
  float cam_ori[3];
  // Normalized viewing direction.
  float cam_fwd[3];
  // Normalized cross(cam_fwd, up vector).
  float cam_side[3];
  // Normalized cross(cam_side, cam_fwd).
  float cam_up[3];
  // Image-space center; [0] is subtracted against the row index, [1] against
  // the column index.
  float cam_c[2];
  // Scale applied to cam_fwd when building a ray direction (presumably the
  // focal length in pixels -- confirm with the caller).
  float cam_f;
  // unsigned long seed;
};
|
| 45 |
+
|
| 46 |
+
/*
  One thread per pixel: builds the pixel's camera ray, marches it through the
  voxel grid and records up to p.max_samples non-empty voxels it enters.

  out_voxel_id: torch CUDA int32 [   img_dims[0], img_dims[1], max_samples, 1]
  out_depth:    torch CUDA float [2, img_dims[0], img_dims[1], max_samples, 1]
  out_raydirs:  torch CUDA float [   img_dims[0], img_dims[1], 1,           3]

  Image coordinates refer to the center of the pixel.
  [0, 0, 0] at voxel coordinate is at the corner of the corner block (instead
  of at the center).

  Slots that are never filled keep voxel id 0 and NaN depths.
*/
template <int TILE_DIM>
static __global__ void ray_voxel_intersection_perspective_kernel(
    int32_t *__restrict__ out_voxel_id, float *__restrict__ out_depth,
    float *__restrict__ out_raydirs, const int32_t *__restrict__ in_voxel,
    const RVIP_Params p) {

  // img_coords[0] is the row, img_coords[1] the column of this thread's pixel.
  int img_coords[2];
  img_coords[1] = blockIdx.x * TILE_DIM + threadIdx.x;
  img_coords[0] = blockIdx.y * TILE_DIM + threadIdx.y;
  // Threads of partially covered tiles that fall outside the image do nothing.
  if (img_coords[0] >= p.img_dims[0] || img_coords[1] >= p.img_dims[1]) {
    return;
  }
  // Row-major linear pixel index.
  int pix_index = img_coords[0] * p.img_dims[1] + img_coords[1];

  // Calculate ray origin and direction
  float rayori[3], raydir[3];
  rayori[0] = p.cam_ori[0];
  rayori[1] = p.cam_ori[1];
  rayori[2] = p.cam_ori[2];

  // Camera intrinsics
  float ndc_imcoords[2];
  ndc_imcoords[0] = p.cam_c[0] - (float)img_coords[0]; // Flip height
  ndc_imcoords[1] = (float)img_coords[1] - p.cam_c[1];

  // Direction = up * dy + side * dx + forward * f, normalized below.
  raydir[0] = p.cam_up[0] * ndc_imcoords[0] + p.cam_side[0] * ndc_imcoords[1] +
              p.cam_fwd[0] * p.cam_f;
  raydir[1] = p.cam_up[1] * ndc_imcoords[0] + p.cam_side[1] * ndc_imcoords[1] +
              p.cam_fwd[1] * p.cam_f;
  raydir[2] = p.cam_up[2] * ndc_imcoords[0] + p.cam_side[2] * ndc_imcoords[1] +
              p.cam_fwd[2] * p.cam_f;
  normalize<float, 3>(raydir);

  // Save out_raydirs
  out_raydirs[pix_index * 3] = raydir[0];
  out_raydirs[pix_index * 3 + 1] = raydir[1];
  out_raydirs[pix_index * 3 + 2] = raydir[2];

  // axis_t[i]: ray parameter at which the next voxel boundary along axis i is
  // crossed. axis_int[i]: integer voxel coordinate currently occupied.
  float axis_t[3];
  int axis_int[3];
  // int axis_intbound[3];

  // Current voxel
  axis_int[0] = floorf(rayori[0]);
  axis_int[1] = floorf(rayori[1]);
  axis_int[2] = floorf(rayori[2]);

#pragma unroll
  for (int i = 0; i < 3; i++) {
    if (raydir[i] > 0) {
      // Initial t value
      // Handle boundary case where rayori[i] is a whole number. Always round Up
      // for the next block
      // axis_t[i] = (ceilf(nextafterf(rayori[i], HUGE_VALF)) - rayori[i]) /
      // raydir[i];
      axis_t[i] = ((float)(axis_int[i] + 1) - rayori[i]) / raydir[i];
    } else if (raydir[i] < 0) {
      axis_t[i] = ((float)axis_int[i] - rayori[i]) / raydir[i];
    } else {
      // Ray is parallel to this axis' boundaries: never crossed.
      axis_t[i] = HUGE_VALF;
    }
  }

  // Fused raymarching and sampling
  // `quit` persists across cur_plane iterations: once the ray has left the
  // grid, every remaining slot is written with the defaults set below.
  bool quit = false;
  for (int cur_plane = 0; cur_plane < p.max_samples;
       cur_plane++) { // Last cycle is for calculating p2
    float t = nanf("0");
    float t2 = nanf("0");
    int32_t blk_id = 0;
    // Find the next intersection
    while (!quit) {
      // Find the next smallest t
      float tnow;
      /*
      #pragma unroll
      for (int i=0; i<3; i++) {
        if (axis_t[i] <= axis_t[(i+1)%3] && axis_t[i] <= axis_t[(i+2)%3]) {
          // Update current t
          tnow = axis_t[i];
          // Update t candidates
          if (raydir[i] > 0) {
            axis_int[i] += 1;
            if (axis_int[i] >= p.voxel_dims[i]) {
              quit = true;
            }
            axis_t[i] = ((float)(axis_int[i]+1) - rayori[i]) / raydir[i];
          } else {
            axis_int[i] -= 1;
            if (axis_int[i] < 0) {
              quit = true;
            }
            axis_t[i] = ((float)axis_int[i] - rayori[i]) / raydir[i];
          }
          break; // Avoid advancing multiple steps as axis_t is updated
        }
      }
      */
      // Hand unroll
      // Step into the neighbouring voxel across whichever axis boundary is
      // reached first; ties resolve in axis order 0, 1, 2.
      if (axis_t[0] <= axis_t[1] && axis_t[0] <= axis_t[2]) {
        // Update current t
        tnow = axis_t[0];
        // Update t candidates
        if (raydir[0] > 0) {
          axis_int[0] += 1;
          if (axis_int[0] >= p.voxel_dims[0]) {
            quit = true;
          }
          axis_t[0] = ((float)(axis_int[0] + 1) - rayori[0]) / raydir[0];
        } else {
          axis_int[0] -= 1;
          if (axis_int[0] < 0) {
            quit = true;
          }
          axis_t[0] = ((float)axis_int[0] - rayori[0]) / raydir[0];
        }
      } else if (axis_t[1] <= axis_t[2]) {
        tnow = axis_t[1];
        if (raydir[1] > 0) {
          axis_int[1] += 1;
          if (axis_int[1] >= p.voxel_dims[1]) {
            quit = true;
          }
          axis_t[1] = ((float)(axis_int[1] + 1) - rayori[1]) / raydir[1];
        } else {
          axis_int[1] -= 1;
          if (axis_int[1] < 0) {
            quit = true;
          }
          axis_t[1] = ((float)axis_int[1] - rayori[1]) / raydir[1];
        }
      } else {
        tnow = axis_t[2];
        if (raydir[2] > 0) {
          axis_int[2] += 1;
          if (axis_int[2] >= p.voxel_dims[2]) {
            quit = true;
          }
          axis_t[2] = ((float)(axis_int[2] + 1) - rayori[2]) / raydir[2];
        } else {
          axis_int[2] -= 1;
          if (axis_int[2] < 0) {
            quit = true;
          }
          axis_t[2] = ((float)axis_int[2] - rayori[2]) / raydir[2];
        }
      }

      // The step just taken moved past the grid boundary in its travel
      // direction: stop marching for this pixel.
      if (quit) {
        break;
      }

      // Skip empty space
      // Could there be deadlock if the ray direction is away from the world?
      // (Applies while the ray is still outside the grid on some axis, e.g.
      // when the camera itself sits outside the volume.)
      if (axis_int[0] < 0 || axis_int[0] >= p.voxel_dims[0] ||
          axis_int[1] < 0 || axis_int[1] >= p.voxel_dims[1] ||
          axis_int[2] < 0 || axis_int[2] >= p.voxel_dims[2]) {
        continue;
      }

      // Test intersection using voxel grid
      blk_id = in_voxel[axis_int[0] * p.voxel_strides[0] +
                        axis_int[1] * p.voxel_strides[1] +
                        axis_int[2] * p.voxel_strides[2]];
      // Voxel id 0 is treated as empty.
      if (blk_id == 0) {
        continue;
      }

      // Now that there is an intersection
      // t is the parameter at which the ray entered this voxel.
      t = tnow;
      // Calculate t2
      // t2 is the smallest pending boundary crossing, i.e. where the ray will
      // leave this voxel.
      /*
      #pragma unroll
      for (int i=0; i<3; i++) {
        if (axis_t[i] <= axis_t[(i+1)%3] && axis_t[i] <= axis_t[(i+2)%3]) {
          t2 = axis_t[i];
          break;
        }
      }
      */
      // Hand unroll
      if (axis_t[0] <= axis_t[1] && axis_t[0] <= axis_t[2]) {
        t2 = axis_t[0];
      } else if (axis_t[1] <= axis_t[2]) {
        t2 = axis_t[1];
      } else {
        t2 = axis_t[2];
      }
      break;
    } // while !quit (ray marching loop)

    // Entry depths occupy the first half of out_depth, exit depths the second.
    out_depth[pix_index * p.max_samples + cur_plane] = t;
    out_depth[p.img_dims[0] * p.img_dims[1] * p.max_samples +
              pix_index * p.max_samples + cur_plane] = t2;
    out_voxel_id[pix_index * p.max_samples + cur_plane] = blk_id;
  } // cur_plane
}
|
| 250 |
+
|
| 251 |
+
/*
  Host-side entry point: validates the inputs, derives the camera frame,
  allocates the outputs on in_voxel's device and launches
  ray_voxel_intersection_perspective_kernel on the current CUDA stream.

  out:
    out_voxel_id: torch CUDA int32 [   img_dims[0], img_dims[1], max_samples, 1]
    out_depth:    torch CUDA float [2, img_dims[0], img_dims[1], max_samples, 1]
    out_raydirs:  torch CUDA float [   img_dims[0], img_dims[1], 1,           3]
  in:
    in_voxel:    torch CUDA int32 [X, Y, Z]   (e.g. [40, 512, 512])
    cam_ori:     torch float [3]
    cam_dir:     torch float [3]
    cam_up:      torch float [3]
    cam_f:       float
    cam_c:       float [2]
    img_dims:    int [2]
    max_samples: int

  Invalid arguments raise a c10::Error (a RuntimeError on the Python side).
*/
std::vector<torch::Tensor> ray_voxel_intersection_perspective_cuda(
    const torch::Tensor &in_voxel, const torch::Tensor &cam_ori,
    const torch::Tensor &cam_dir, const torch::Tensor &cam_up, float cam_f,
    const std::vector<float> &cam_c, const std::vector<int> &img_dims,
    int max_samples) {
  CHECK_CUDA(in_voxel);

  int curDevice = -1;
  AT_CUDA_CHECK(cudaGetDevice(&curDevice));
  cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
  torch::Device device = in_voxel.device();

  // TORCH_CHECK instead of assert(): assert() is compiled out when NDEBUG is
  // defined, which would let malformed inputs reach the raw pointer accesses
  // below.
  TORCH_CHECK(in_voxel.scalar_type() == torch::kInt32,
              "in_voxel must be an int32 tensor"); // Minecraft compatibility
  TORCH_CHECK(in_voxel.dim() == 3, "in_voxel must be a 3-D tensor");
  TORCH_CHECK(cam_ori.scalar_type() == torch::kFloat32,
              "cam_ori must be a float32 tensor");
  TORCH_CHECK(cam_ori.numel() == 3, "cam_ori must have exactly 3 elements");
  TORCH_CHECK(cam_dir.scalar_type() == torch::kFloat32,
              "cam_dir must be a float32 tensor");
  TORCH_CHECK(cam_dir.numel() == 3, "cam_dir must have exactly 3 elements");
  TORCH_CHECK(cam_up.scalar_type() == torch::kFloat32,
              "cam_up must be a float32 tensor");
  TORCH_CHECK(cam_up.numel() == 3, "cam_up must have exactly 3 elements");
  TORCH_CHECK(cam_c.size() == 2, "cam_c must have exactly 2 elements");
  TORCH_CHECK(img_dims.size() == 2, "img_dims must have exactly 2 elements");

  RVIP_Params p;

  // Calculate camera rays
  // The vectors are read through raw pointers on the host, so they must live
  // in CPU memory and be densely packed.
  const torch::Tensor cam_ori_c = cam_ori.cpu().contiguous();
  const torch::Tensor cam_dir_c = cam_dir.cpu().contiguous();
  const torch::Tensor cam_up_c = cam_up.cpu().contiguous();

  // Get the coordinate frame of camera space in world space
  normalize<float, 3>(p.cam_fwd, cam_dir_c.data_ptr<float>());
  cross<float>(p.cam_side, p.cam_fwd, cam_up_c.data_ptr<float>());
  normalize<float, 3>(p.cam_side);
  cross<float>(p.cam_up, p.cam_side, p.cam_fwd);
  normalize<float, 3>(p.cam_up); // Not absolutely necessary as both vectors are
                                 // normalized. But just in case...

  copyarr<float, 3>(p.cam_ori, cam_ori_c.data_ptr<float>());

  p.cam_f = cam_f;
  p.cam_c[0] = cam_c[0];
  p.cam_c[1] = cam_c[1];
  p.max_samples = max_samples;

  // The kernel indexes in_voxel through its strides, so the tensor does not
  // have to be contiguous.
  p.voxel_dims[0] = in_voxel.size(0);
  p.voxel_dims[1] = in_voxel.size(1);
  p.voxel_dims[2] = in_voxel.size(2);
  p.voxel_strides[0] = in_voxel.stride(0);
  p.voxel_strides[1] = in_voxel.stride(1);
  p.voxel_strides[2] = in_voxel.stride(2);

  p.img_dims[0] = img_dims[0];
  p.img_dims[1] = img_dims[1];

  // Create output tensors
  // For Minecraft Seg Mask
  torch::Tensor out_voxel_id =
      torch::empty({p.img_dims[0], p.img_dims[1], p.max_samples, 1},
                   torch::TensorOptions().dtype(torch::kInt32).device(device));

  // Produce two sets of localcoords, one for entry point, the other one for
  // exit point. They share the same corner_ids.
  torch::Tensor out_depth = torch::empty(
      {2, p.img_dims[0], p.img_dims[1], p.max_samples, 1},
      torch::TensorOptions().dtype(torch::kFloat32).device(device));

  torch::Tensor out_raydirs = torch::empty({p.img_dims[0], p.img_dims[1], 1, 3},
                                           torch::TensorOptions()
                                               .dtype(torch::kFloat32)
                                               .device(device)
                                               .requires_grad(false));

  // A zero-sized grid is an invalid launch configuration; with an empty image
  // there is nothing to compute, so the (empty) outputs are returned as is.
  if (p.img_dims[0] > 0 && p.img_dims[1] > 0) {
    const int TILE_DIM = 8;
    dim3 dimGrid((p.img_dims[1] + TILE_DIM - 1) / TILE_DIM,
                 (p.img_dims[0] + TILE_DIM - 1) / TILE_DIM, 1);
    dim3 dimBlock(TILE_DIM, TILE_DIM, 1);

    ray_voxel_intersection_perspective_kernel<TILE_DIM>
        <<<dimGrid, dimBlock, 0, stream>>>(
            out_voxel_id.data_ptr<int32_t>(), out_depth.data_ptr<float>(),
            out_raydirs.data_ptr<float>(), in_voxel.data_ptr<int32_t>(), p);
    // Surface launch failures here instead of at some unrelated later call.
    AT_CUDA_CHECK(cudaGetLastError());
  }

  return {out_voxel_id, out_depth, out_raydirs};
}
|
citydreamer/extensions/voxlib/setup.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# This work is made available under the Nvidia Source Code License-NC.
|
| 4 |
+
# To view a copy of this license, check out LICENSE.md
|
| 5 |
+
from setuptools import setup
|
| 6 |
+
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
| 7 |
+
|
| 8 |
+
# Compiler flags for the host (C++) and device (CUDA) translation units.
cxx_args = ["-fopenmp"]
nvcc_args = []

# The compiled module is importable as ``voxlib``.
voxlib_extension = CUDAExtension(
    "voxlib",
    ["voxlib.cpp", "ray_voxel_intersection.cu"],
    extra_compile_args={"cxx": cxx_args, "nvcc": nvcc_args},
)

setup(
    name="voxrender",
    version="1.0.0",
    ext_modules=[voxlib_extension],
    cmdclass={"build_ext": BuildExtension},
)
|
citydreamer/extensions/voxlib/voxlib.cpp
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
//
|
| 3 |
+
// This work is made available under the Nvidia Source Code License-NC.
|
| 4 |
+
// To view a copy of this license, check out LICENSE.md
|
| 5 |
+
#include <pybind11/pybind11.h>
|
| 6 |
+
#include <pybind11/stl.h>
|
| 7 |
+
#include <torch/extension.h>
|
| 8 |
+
#include <vector>
|
| 9 |
+
|
| 10 |
+
// Fast voxel traversal along rays
// Declaration only: no definition appears in this translation unit, so it is
// resolved at link time against the other sources of the extension.
std::vector<torch::Tensor> ray_voxel_intersection_perspective_cuda(
    const torch::Tensor &in_voxel, const torch::Tensor &cam_ori,
    const torch::Tensor &cam_dir, const torch::Tensor &cam_up, float cam_f,
    const std::vector<float> &cam_c, const std::vector<int> &img_dims,
    int max_samples);

// Python bindings: exposes the function above under the name
// "ray_voxel_intersection_perspective".
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("ray_voxel_intersection_perspective",
        &ray_voxel_intersection_perspective_cuda,
        "Ray-voxel intersections given perspective camera parameters (CUDA)");
}
|
citydreamer/extensions/voxlib/voxlib_common.h
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
//
|
| 3 |
+
// This work is made available under the Nvidia Source Code License-NC.
|
| 4 |
+
// To view a copy of this license, check out LICENSE.md
|
| 5 |
+
#ifndef VOXLIB_COMMON_H
|
| 6 |
+
#define VOXLIB_COMMON_H
|
| 7 |
+
|
| 8 |
+
// Tensor argument validation helpers. Each one raises through TORCH_CHECK
// with a message naming the offending argument. The argument is parenthesized
// so that expressions can be passed safely.
#define CHECK_CUDA(x) TORCH_CHECK((x).is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x)                                                    \
  TORCH_CHECK((x).is_contiguous(), #x " must be contiguous")
// Wrapped in do { ... } while (0) so the two checks behave as one statement;
// the bare two-statement expansion silently dropped the second check out of an
// unbraced `if (...) CHECK_INPUT(x);`.
#define CHECK_INPUT(x)                                                         \
  do {                                                                         \
    CHECK_CUDA(x);                                                             \
    CHECK_CONTIGUOUS(x);                                                       \
  } while (0)
#define CHECK_CPU(x)                                                           \
  TORCH_CHECK((x).device().is_cpu(), #x " must be a CPU tensor")
|
| 16 |
+
|
| 17 |
+
#include <cuda.h>
|
| 18 |
+
#include <cuda_runtime.h>
|
| 19 |
+
// CUDA vector math functions
|
| 20 |
+
// Integer division rounding toward negative infinity (a / b in C++ truncates
// toward zero). b must be non-zero.
__host__ __device__ __forceinline__ int floor_div(int a, int b) {
  int c = a / b;

  // Truncation rounded up exactly when the division is inexact and the
  // operands have opposite signs. The previous test (c * b > a) only held for
  // a positive divisor and gave e.g. floor_div(7, -2) == -3 instead of -4.
  if ((a % b != 0) && ((a < 0) != (b < 0))) {
    c--;
  }

  return c;
}
|
| 29 |
+
|
| 30 |
+
// 3-D cross product r = a x b. Host-only (no __device__ qualifier).
// r is written component by component while a and b are still being read, so
// r must not alias a or b.
template <typename scalar_t>
__host__ __forceinline__ void cross(scalar_t *r, const scalar_t *a,
                                    const scalar_t *b) {
  r[0] = a[1] * b[2] - a[2] * b[1];
  r[1] = a[2] * b[0] - a[0] * b[2];
  r[2] = a[0] * b[1] - a[1] * b[0];
}
|
| 37 |
+
|
| 38 |
+
// Dot product of two 3-component float vectors.
__device__ __host__ __forceinline__ float dot(const float *a, const float *b) {
  return a[0] * b[0] + a[1] * b[1] + a[2] * b[2];
}
|
| 41 |
+
|
| 42 |
+
// Copies the first ndim elements of a into r, front to back.
template <typename scalar_t, int ndim>
__device__ __host__ __forceinline__ void copyarr(scalar_t *r,
                                                 const scalar_t *a) {
#pragma unroll
  for (int k = 0; k < ndim; ++k) {
    *(r + k) = *(a + k);
  }
}
|
| 50 |
+
|
| 51 |
+
// TODO: use rsqrt to speed up
|
| 52 |
+
// inplace version
|
| 53 |
+
// Scales the ndim-component vector a to unit Euclidean length, in place.
// There is no guard for a zero-length input: the division is then by zero.
// NOTE(review): sqrtf is used regardless of scalar_t, so a double
// instantiation would take a single-precision square root -- confirm only
// float is intended.
template <typename scalar_t, int ndim>
__device__ __host__ __forceinline__ void normalize(scalar_t *a) {
  // Accumulates the squared length first, then holds the length itself.
  scalar_t vec_len = 0.0f;
#pragma unroll
  for (int i = 0; i < ndim; i++) {
    vec_len += a[i] * a[i];
  }
  vec_len = sqrtf(vec_len);
#pragma unroll
  for (int i = 0; i < ndim; i++) {
    a[i] /= vec_len;
  }
}
|
| 66 |
+
|
| 67 |
+
// normalize + copy
|
| 68 |
+
// Writes the unit-length version of the ndim-component vector a into r; a is
// left untouched.
// There is no guard for a zero-length input: the division is then by zero.
template <typename scalar_t, int ndim>
__device__ __host__ __forceinline__ void normalize(scalar_t *r,
                                                   const scalar_t *a) {
  // Accumulates the squared length first, then holds the length itself.
  scalar_t vec_len = 0.0f;
#pragma unroll
  for (int i = 0; i < ndim; i++) {
    vec_len += a[i] * a[i];
  }
  vec_len = sqrtf(vec_len);
#pragma unroll
  for (int i = 0; i < ndim; i++) {
    r[i] = a[i] / vec_len;
  }
}
|
| 82 |
+
|
| 83 |
+
#endif // VOXLIB_COMMON_H
|
citydreamer/inference.py
ADDED
|
@@ -0,0 +1,537 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
#
|
| 3 |
+
# @File: inference.py
|
| 4 |
+
# @Author: Haozhe Xie
|
| 5 |
+
# @Date: 2024-03-02 16:30:00
|
| 6 |
+
# @Last Modified by: Haozhe Xie
|
| 7 |
+
# @Last Modified at: 2024-03-03 12:10:18
|
| 8 |
+
# @Email: [email protected]
|
| 9 |
+
|
| 10 |
+
import copy
|
| 11 |
+
import cv2
|
| 12 |
+
import logging
|
| 13 |
+
import math
|
| 14 |
+
import numpy as np
|
| 15 |
+
import torch
|
| 16 |
+
import torchvision
|
| 17 |
+
|
| 18 |
+
import citydreamer.extensions.extrude_tensor
|
| 19 |
+
import citydreamer.extensions.voxlib
|
| 20 |
+
|
| 21 |
+
# Global constants
# Per-class height values. NOTE(review): not referenced in the visible part of
# this module -- confirm where they are consumed.
HEIGHTS = {
    "ROAD": 4,
    "GREEN_LANDS": 8,
    "CONSTRUCTION": 10,
    "COAST_ZONES": 0,
    "ROOF": 1,
}
# Semantic class IDs stored in the layout segmentation map.
CLASSES = {
    "NULL": 0,
    "ROAD": 1,
    "BLD_FACADE": 2,
    "GREEN_LANDS": 3,
    "CONSTRUCTION": 4,
    "COAST_ZONES": 5,
    "OTHERS": 6,
    "BLD_ROOF": 7,
}
# NOTE: ID > 10 are reserved for building instances.
# Assume the ID of a facade instance is 2k, the corresponding roof instance is 2k - 1.
CONSTANTS = {
    # Offset added to connected-component labels before they are doubled into
    # building instance IDs.
    "BLD_INS_LABEL_MIN": 10,
    # Number of one-hot channels built from the segmentation map.
    "LAYOUT_N_CLASSES": 7,
    # Its half is used as the center of the camera orbit.
    "LAYOUT_VOL_SIZE": 1536,
    "BUILDING_VOL_SIZE": 672,
    # Side length of the square patch cut from the height field / seg map.
    "EXTENDED_VOL_SIZE": 2880,
    # Divisor that normalizes the height field.
    "LAYOUT_MAX_HEIGHT": 640,
    "GES_VFOV": 20,
    # The rendered image is one fifth of this size in each dimension.
    "GES_IMAGE_HEIGHT": 540,
    "GES_IMAGE_WIDTH": 960,
    "IMAGE_PADDING": 8,
    "N_VOXEL_INTERSECT_SAMPLES": 6,
}
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def generate_city(fgm, bgm, hf, seg, radius, altitude, azimuth):
    """Render one view of a randomly chosen city patch from an orbiting camera.

    Args:
        fgm: Model forwarded to ``render`` together with ``bgm`` (presumably
            the building generator -- confirm against ``render``).
        bgm: Model wrapper exposing ``module.cfg`` and ``output_device``
            (presumably the background generator -- confirm).
        hf: 2D height field array covering the whole layout.
        seg: 2D semantic segmentation map with the same extent as ``hf``.
        radius: Orbit radius of the camera.
        altitude: Camera height (its ``z`` coordinate).
        azimuth: Orbit angle in degrees.

    Returns:
        The rendered image as a ``uint8`` array: the renderer output is
        squeezed, transposed with axes ``(1, 2, 0)`` and mapped through
        ``(x / 2 + 0.5) * 255``.

    Raises:
        ValueError: From ``np.random.randint`` when the layout is smaller than
            ``CONSTANTS["EXTENDED_VOL_SIZE"]`` along either axis.
    """
    patch_size = CONSTANTS["EXTENDED_VOL_SIZE"]
    cam_pos = get_orbit_camera_position(radius, altitude, azimuth)
    seg, building_stats = get_instance_seg_map(seg)
    # Generate latent codes
    logging.info("Generating latent codes ...")
    bg_z, building_zs = get_latent_codes(
        building_stats,
        bgm.module.cfg.NETWORK.GANCRAFT.STYLE_DIM,
        bgm.output_device,
    )
    # Randomly choose the center of the patch. The "+ 1" makes the last valid
    # position reachable and lets a layout of exactly ``patch_size`` work
    # (``randint(0)`` raises ValueError).
    cy = np.random.randint(seg.shape[0] - patch_size + 1) + patch_size // 2
    cx = np.random.randint(seg.shape[1] - patch_size + 1) + patch_size // 2
    # Generate local image patch of the height field and seg map (this was
    # previously computed twice with identical arguments).
    part_hf, part_seg = get_part_hf_seg(hf, seg, cx, cy, patch_size)
    # print(part_hf.shape)    # (2880, 2880)
    # print(part_seg.shape)   # (2880, 2880)
    # Recalculate the building positions based on the current patch
    _building_stats = get_part_building_stats(part_seg, building_stats, cx, cy)
    # Generate the concatenated height field and seg. map tensor
    hf_seg = get_hf_seg_tensor(part_hf, part_seg, bgm.output_device)
    # print(hf_seg.size())    # torch.Size([1, 8, 2880, 2880])
    # Build seg_volume
    logging.info("Generating seg volume ...")
    seg_volume = get_seg_volume(part_hf, part_seg)
    logging.info("Rendering City Image ...")
    img = render(
        (CONSTANTS["GES_IMAGE_HEIGHT"] // 5, CONSTANTS["GES_IMAGE_WIDTH"] // 5),
        seg_volume,
        hf_seg,
        cam_pos,
        bgm,
        fgm,
        _building_stats,
        bg_z,
        building_zs,
    )
    return ((img.cpu().numpy().squeeze().transpose((1, 2, 0)) / 2 + 0.5) * 255).astype(
        np.uint8
    )
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def get_orbit_camera_position(radius, altitude, azimuth):
    """Return the camera position on a circle around the layout center.

    The circle is centered at half of ``CONSTANTS["LAYOUT_VOL_SIZE"]`` on both
    horizontal axes; ``azimuth`` is given in degrees and ``altitude`` is passed
    through unchanged as ``z``.
    """
    center = CONSTANTS["LAYOUT_VOL_SIZE"] // 2
    angle = np.deg2rad(azimuth)
    return {
        "x": center + radius * math.cos(angle),
        "y": center + radius * math.sin(angle),
        "z": altitude,
    }
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def get_instance_seg_map(seg_map):
    """Replace building pixels of a semantic map with per-building instance IDs.

    Construction pixels are treated as building facades. Each 4-connected
    facade region receives the ID ``(label + BLD_INS_LABEL_MIN) * 2`` where
    ``label`` is its connected-component label; all other pixels keep their
    class ID.

    Args:
        seg_map: 2D NumPy array of class IDs from ``CLASSES``. It is not
            modified.

    Returns:
        A tuple ``(instance_map, stats)``: the ``int32`` instance map and the
        first four columns of the statistics returned by
        ``cv2.connectedComponentsWithStats``, one row per component label.
    """
    # Work on a private copy: the in-place edits below used to wipe the
    # buildings out of the caller's array, so a second call on the same layout
    # found none.
    seg_map = seg_map.copy()
    # Mapping constructions to buildings
    seg_map[seg_map == CLASSES["CONSTRUCTION"]] = CLASSES["BLD_FACADE"]
    # Use connected components to get building instances
    _, labels, stats, _ = cv2.connectedComponentsWithStats(
        (seg_map == CLASSES["BLD_FACADE"]).astype(np.uint8), connectivity=4
    )
    # Remove non-building instance masks
    labels[seg_map != CLASSES["BLD_FACADE"]] = 0
    # Building instance mask
    building_mask = labels != 0
    # Make building instance IDs are even numbers and start from 10
    # Assume the ID of a facade instance is 2k, the corresponding roof instance is 2k - 1.
    labels = (labels + CONSTANTS["BLD_INS_LABEL_MIN"]) * 2

    seg_map[seg_map == CLASSES["BLD_FACADE"]] = 0
    seg_map = seg_map * (1 - building_mask) + labels * building_mask
    # The result is stored as int32 below.
    assert np.max(labels) < 2147483648
    return seg_map.astype(np.int32), stats[:, :4]
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def get_latent_codes(building_stats, bg_style_dim, output_device):
    """Sample one background latent code and one latent code per building row.

    Returns:
        ``(bg_z, building_zs)`` where ``building_zs`` maps the instance ID
        ``(row_index + BLD_INS_LABEL_MIN) * 2`` to its latent code.
    """
    # The background code is drawn first, then the buildings in row order.
    bg_z = _get_z(output_device, bg_style_dim)
    building_zs = {}
    for row_index in range(len(building_stats)):
        instance_id = (row_index + CONSTANTS["BLD_INS_LABEL_MIN"]) * 2
        building_zs[instance_id] = _get_z(output_device)
    return bg_z, building_zs
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def _get_z(device, z_dim=256):
|
| 146 |
+
if z_dim is None:
|
| 147 |
+
return None
|
| 148 |
+
|
| 149 |
+
return torch.randn(1, z_dim, dtype=torch.float32, device=device)
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def get_part_hf_seg(hf, seg, cx, cy, patch_size):
    """Cut the same square patch, centered at (cx, cy), from ``hf`` and ``seg``.

    Both patches must come out as ``(patch_size, patch_size)``; an
    ``AssertionError`` carrying the offending shape is raised otherwise.
    """
    expected_shape = (patch_size, patch_size)
    part_hf = _get_image_patch(hf, cx, cy, patch_size)
    part_seg = _get_image_patch(seg, cx, cy, patch_size)
    assert part_hf.shape == expected_shape, part_hf.shape
    assert part_hf.shape == part_seg.shape, part_seg.shape
    return part_hf, part_seg
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def _get_image_patch(image, cx, cy, patch_size):
|
| 164 |
+
sx = cx - patch_size // 2
|
| 165 |
+
sy = cy - patch_size // 2
|
| 166 |
+
ex = sx + patch_size
|
| 167 |
+
ey = sy + patch_size
|
| 168 |
+
return image[sy:ey, sx:ex]
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def get_part_building_stats(part_seg, building_stats, cx, cy):
    """Collect, for each building visible in a patch, an offset to ``(cy, cx)``.

    Args:
        part_seg: Segmentation patch holding building instance IDs.
        building_stats: 2D array indexed as ``[row, col]``; the row for
            instance ID ``b`` is ``b // 2 - CONSTANTS["BLD_INS_LABEL_MIN"]``.
            Presumably columns 0-3 are (left, top, width, height) of the
            instance bounding box -- TODO confirm against the producer.
        cx: Column of the patch center in layout coordinates.
        cy: Row of the patch center in layout coordinates.

    Returns:
        Dict mapping each instance ID found in ``part_seg`` (values greater
        than ``BLD_INS_LABEL_MIN``) to the two-element list
        ``[col1 - cy + col3 / 2, col0 - cx + col2 / 2]``.
    """
    min_label = CONSTANTS["BLD_INS_LABEL_MIN"]

    def _center_offset(row):
        # Same arithmetic, in the same order, for the y-like and x-like terms.
        return [
            building_stats[row, 1] - cy + building_stats[row, 3] / 2,
            building_stats[row, 0] - cx + building_stats[row, 2] / 2,
        ]

    visible_ids = np.unique(part_seg[part_seg > min_label])
    return {
        ins_id: _center_offset(ins_id // 2 - min_label) for ins_id in visible_ids
    }
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def get_hf_seg_tensor(part_hf, part_seg, output_device):
    """Pack a height-field patch and its seg map into one network input tensor.

    Args:
        part_hf: 2D NumPy height-field patch.
        part_seg: 2D NumPy segmentation patch of class indices.
        output_device: Device the resulting tensor lives on.

    Returns:
        Tensor of shape ``(1, 1 + LAYOUT_N_CLASSES, H, W)``: channel 0 is the
        height divided by ``CONSTANTS["LAYOUT_MAX_HEIGHT"]``, the remaining
        channels are the one-hot encoded segmentation classes.
    """
    # (H, W) -> (1, 1, H, W) for the height channel.
    height = torch.from_numpy(part_hf[None, None, ...]).to(output_device)
    height = height / CONSTANTS["LAYOUT_MAX_HEIGHT"]

    # (H, W) -> (1, H, W); the one-hot helper adds the channel axis itself.
    labels = torch.from_numpy(part_seg[None, ...]).to(output_device)
    onehots = _masks_to_onehots(labels, CONSTANTS["LAYOUT_N_CLASSES"])

    return torch.cat([height, onehots], dim=1)
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def _masks_to_onehots(masks, n_class, ignored_classes=[]):
|
| 192 |
+
b, h, w = masks.shape
|
| 193 |
+
n_class_actual = n_class - len(ignored_classes)
|
| 194 |
+
one_hot_masks = torch.zeros(
|
| 195 |
+
(b, n_class_actual, h, w), dtype=torch.float32, device=masks.device
|
| 196 |
+
)
|
| 197 |
+
|
| 198 |
+
n_class_cnt = 0
|
| 199 |
+
for i in range(n_class):
|
| 200 |
+
if i not in ignored_classes:
|
| 201 |
+
one_hot_masks[:, n_class_cnt] = masks == i
|
| 202 |
+
n_class_cnt += 1
|
| 203 |
+
return one_hot_masks
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
def get_seg_volume(part_hf, part_seg):
    """Extrude a layout patch into a 3D label volume and mark building roofs.

    The segmentation map and height field are passed to the ``TensorExtruder``
    extension. Afterwards, the ``HEIGHTS["ROOF"]`` voxels directly above each
    pixel's height are overwritten with the roof instance ID (facade ID - 1)
    for building pixels, and with 0 for every other pixel.

    Args:
        part_hf: 2D NumPy height-field patch. If its shape is
            ``(EXTENDED_VOL_SIZE, EXTENDED_VOL_SIZE)`` it is cropped by
            ``BUILDING_VOL_SIZE`` on every side first.
        part_seg: 2D NumPy segmentation patch, same shape as ``part_hf``.

    Returns:
        The squeezed output of the extruder (a CUDA tensor) with the roof
        voxels relabelled.

    Raises:
        AssertionError: If, after the optional crop, the patches are not
            ``(LAYOUT_VOL_SIZE, LAYOUT_VOL_SIZE)`` or differ in shape.
    """
    tensor_extruder = citydreamer.extensions.extrude_tensor.TensorExtruder(
        CONSTANTS["LAYOUT_MAX_HEIGHT"]
    )

    # Strip the margin of an extended patch so only the layout area remains.
    if part_hf.shape == (
        CONSTANTS["EXTENDED_VOL_SIZE"],
        CONSTANTS["EXTENDED_VOL_SIZE"],
    ):
        part_hf = part_hf[
            CONSTANTS["BUILDING_VOL_SIZE"] : -CONSTANTS["BUILDING_VOL_SIZE"],
            CONSTANTS["BUILDING_VOL_SIZE"] : -CONSTANTS["BUILDING_VOL_SIZE"],
        ]
        part_seg = part_seg[
            CONSTANTS["BUILDING_VOL_SIZE"] : -CONSTANTS["BUILDING_VOL_SIZE"],
            CONSTANTS["BUILDING_VOL_SIZE"] : -CONSTANTS["BUILDING_VOL_SIZE"],
        ]

    assert part_hf.shape == (
        CONSTANTS["LAYOUT_VOL_SIZE"],
        CONSTANTS["LAYOUT_VOL_SIZE"],
    )
    assert part_hf.shape == part_seg.shape, part_seg.shape

    # The extruder receives (1, 1, H, W) CUDA tensors: seg map first, then heights.
    seg_volume = tensor_extruder(
        torch.from_numpy(part_seg[None, None, ...]).cuda(),
        torch.from_numpy(part_hf[None, None, ...]).cuda(),
    ).squeeze()
    logging.debug("The shape of SegVolume: %s" % (seg_volume.size(),))
    # Change the top-level voxel of the "Building Facade" to "Building Roof"
    roof_seg_map = part_seg.copy()
    non_roof_msk = part_seg <= CONSTANTS["BLD_INS_LABEL_MIN"]
    # Assume the ID of a facade instance is 2k, the corresponding roof instance is 2k - 1.
    roof_seg_map = roof_seg_map - 1
    roof_seg_map[non_roof_msk] = 0
    # Write the roof labels into the layers at height + 1 .. height + ROOF along dim 2.
    # NOTE(review): the scatter index is not clamped; assumes height + ROOF
    # stays inside the volume -- confirm against LAYOUT_MAX_HEIGHT.
    for rh in range(1, HEIGHTS["ROOF"] + 1):
        seg_volume = seg_volume.scatter_(
            dim=2,
            index=torch.from_numpy(part_hf[..., None] + rh).long().cuda(),
            src=torch.from_numpy(roof_seg_map[..., None]).cuda(),
        )
    # Author's debug note: seg_volume.size() was torch.Size([1536, 1536, 640]).
    return seg_volume
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
def get_voxel_intersection_perspective(seg_volume, camera_location):
    """Cast perspective camera rays through the label volume.

    The camera sits at ``camera_location`` and looks at the point
    ``(size // 2 - 1)`` on both horizontal axes of the volume at height 0.
    Coordinates are handed to the extension in ``(y, x, z)`` order.

    Args:
        seg_volume: 3D label volume tensor; its device is used for the
            camera tensors created here.
        camera_location: Mapping with the keys ``"x"``, ``"y"`` and ``"z"``.

    Returns:
        Tuple ``(voxel_id, depth2, raydirs, cam_origin)``, each with a
        leading batch axis of size 1 added. ``depth2`` is additionally
        permuted with ``(1, 2, 0, 3, 4)`` before the batch axis is added.
    """
    # NOTE(review): the tangent is taken of the full GES_VFOV rather than half
    # of it, and the focal length is later scaled by 2.06 -- presumably an
    # empirical calibration; confirm before changing.
    CAMERA_FOCAL = (
        CONSTANTS["GES_IMAGE_HEIGHT"] / 2 / np.tan(np.deg2rad(CONSTANTS["GES_VFOV"]))
    )
    # Author's debug note: seg_volume.size() was torch.Size([1536, 1536, 640]).
    camera_target = {
        "x": seg_volume.size(1) // 2 - 1,
        "y": seg_volume.size(0) // 2 - 1,
    }
    cam_origin = torch.tensor(
        [
            camera_location["y"],
            camera_location["x"],
            camera_location["z"],
        ],
        dtype=torch.float32,
        device=seg_volume.device,
    )

    (
        voxel_id,
        depth2,
        raydirs,
    ) = citydreamer.extensions.voxlib.ray_voxel_intersection_perspective(
        seg_volume,
        cam_origin,
        # Viewing direction: from the camera towards the target at height 0.
        torch.tensor(
            [
                camera_target["y"] - camera_location["y"],
                camera_target["x"] - camera_location["x"],
                -camera_location["z"],
            ],
            dtype=torch.float32,
            device=seg_volume.device,
        ),
        # Third axis as "up". NOTE(review): unlike the tensors above, this one
        # is created on the default device -- confirm the extension expects that.
        torch.tensor([0, 0, 1], dtype=torch.float32),
        CAMERA_FOCAL * 2.06,
        # Principal point at the image center, (row, column) order.
        [
            (CONSTANTS["GES_IMAGE_HEIGHT"] - 1) / 2.0,
            (CONSTANTS["GES_IMAGE_WIDTH"] - 1) / 2.0,
        ],
        [CONSTANTS["GES_IMAGE_HEIGHT"], CONSTANTS["GES_IMAGE_WIDTH"]],
        CONSTANTS["N_VOXEL_INTERSECT_SAMPLES"],
    )
    return (
        voxel_id.unsqueeze(dim=0),
        depth2.permute(1, 2, 0, 3, 4).unsqueeze(dim=0),
        raydirs.unsqueeze(dim=0),
        cam_origin.unsqueeze(dim=0),
    )
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
def _get_pad_img_bbox(sx, ex, sy, ey):
    """Grow a patch box by ``CONSTANTS["IMAGE_PADDING"]`` on its inner sides.

    A side that coincides with the image border (start equal to 0, end equal
    to the image width/height) is left where it is; every other side is
    moved outwards by the padding.

    Args:
        sx: First column of the patch.
        ex: End column (exclusive) of the patch.
        sy: First row of the patch.
        ey: End row (exclusive) of the patch.

    Returns:
        The padded box as ``(psx, pex, psy, pey)``.
    """
    pad = CONSTANTS["IMAGE_PADDING"]
    img_w = CONSTANTS["GES_IMAGE_WIDTH"]
    img_h = CONSTANTS["GES_IMAGE_HEIGHT"]

    psx = 0 if sx == 0 else sx - pad
    psy = 0 if sy == 0 else sy - pad
    pex = img_w if ex == img_w else ex + pad
    pey = img_h if ey == img_h else ey + pad
    return psx, pex, psy, pey
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
def _get_img_without_pad(img, sx, ex, sy, ey, psx, pex, psy, pey):
    """Cut the padding off a rendered patch.

    Args:
        img: Rendered patch, indexed as ``img[:, :, rows, cols]``, covering
            the padded box.
        sx, ex, sy, ey: Unpadded patch box.
        psx, pex, psy, pey: Padded patch box.

    Returns:
        ``img`` itself when ``CONSTANTS["IMAGE_PADDING"]`` is 0, otherwise
        the sub-image corresponding to the unpadded box.
    """
    if CONSTANTS["IMAGE_PADDING"] == 0:
        return img

    # On a padded side the stop index is negative (counted from the end).
    # On an unpadded side the absolute coordinate is used as the stop, as in
    # the original expression.
    row_stop = ey if ey == pey else ey - pey
    col_stop = ex if ex == pex else ex - pex
    return img[:, :, sy - psy : row_stop, sx - psx : col_stop]
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
def render_bg(
    patch_size, gancraft_bg, hf_seg, voxel_id, depth2, raydirs, cam_origin, z
):
    """Render the background image patch by patch.

    Args:
        patch_size: Pair ``(height, width)`` of the render patches.
        gancraft_bg: Background generator; called with keyword arguments and
            expected to expose ``output_device``.
        hf_seg: Height-field/seg tensor whose dims 2 and 3 are
            ``EXTENDED_VOL_SIZE``; it is cropped to ``LAYOUT_VOL_SIZE`` here.
        voxel_id: Ray/voxel intersection IDs, sliced as ``[:, rows, cols]``.
        depth2: Intersection depths, sliced like ``voxel_id``.
        raydirs: Ray directions, sliced like ``voxel_id``.
        cam_origin: Camera origin passed through to the generator.
        z: Background style code passed through to the generator.

    Returns:
        Float32 tensor of shape ``(1, 3, GES_IMAGE_HEIGHT, GES_IMAGE_WIDTH)``
        on ``gancraft_bg.output_device``.

    Raises:
        AssertionError: If ``hf_seg`` has an unexpected size or a voxel ID is
            not a valid layout class after collapsing the building instances.
    """
    assert hf_seg.size(2) == CONSTANTS["EXTENDED_VOL_SIZE"]
    assert hf_seg.size(3) == CONSTANTS["EXTENDED_VOL_SIZE"]
    # Drop the margin of BUILDING_VOL_SIZE on every side.
    hf_seg = hf_seg[
        :,
        :,
        CONSTANTS["BUILDING_VOL_SIZE"] : -CONSTANTS["BUILDING_VOL_SIZE"],
        CONSTANTS["BUILDING_VOL_SIZE"] : -CONSTANTS["BUILDING_VOL_SIZE"],
    ]
    assert hf_seg.size(2) == CONSTANTS["LAYOUT_VOL_SIZE"]
    assert hf_seg.size(3) == CONSTANTS["LAYOUT_VOL_SIZE"]

    blurrer = torchvision.transforms.GaussianBlur(kernel_size=3, sigma=(2, 2))
    # Work on a copy so the caller's voxel IDs keep their instance labels;
    # every building instance is collapsed to the generic facade class.
    _voxel_id = copy.deepcopy(voxel_id)
    _voxel_id[voxel_id >= CONSTANTS["BLD_INS_LABEL_MIN"]] = CLASSES["BLD_FACADE"]
    assert (_voxel_id < CONSTANTS["LAYOUT_N_CLASSES"]).all()
    bg_img = torch.zeros(
        1,
        3,
        CONSTANTS["GES_IMAGE_HEIGHT"],
        CONSTANTS["GES_IMAGE_WIDTH"],
        dtype=torch.float32,
        device=gancraft_bg.output_device,
    )
    # Render background patches by patch to avoid OOM
    # NOTE(review): floor division means rows/columns beyond a whole number of
    # patches are never rendered and stay zero.
    for i in range(CONSTANTS["GES_IMAGE_HEIGHT"] // patch_size[0]):
        for j in range(CONSTANTS["GES_IMAGE_WIDTH"] // patch_size[1]):
            sy, sx = i * patch_size[0], j * patch_size[1]
            ey, ex = sy + patch_size[0], sx + patch_size[1]
            # Render a padded box, then keep only the unpadded part below.
            psx, pex, psy, pey = _get_pad_img_bbox(sx, ex, sy, ey)
            output_bg = gancraft_bg(
                hf_seg=hf_seg,
                voxel_id=_voxel_id[:, psy:pey, psx:pex],
                depth2=depth2[:, psy:pey, psx:pex],
                raydirs=raydirs[:, psy:pey, psx:pex],
                cam_origin=cam_origin,
                building_stats=None,
                z=z,
                deterministic=True,
            )
            # Make road blurry
            # The mask is taken from the first intersected voxel of each ray
            # and repeated over the three colour channels.
            road_mask = (
                (_voxel_id[:, None, psy:pey, psx:pex, 0, 0] == CLASSES["ROAD"])
                .repeat(1, 3, 1, 1)
                .float()
            )
            output_bg = blurrer(output_bg) * road_mask + output_bg * (1 - road_mask)
            bg_img[:, :, sy:ey, sx:ex] = _get_img_without_pad(
                output_bg, sx, ex, sy, ey, psx, pex, psy, pey
            )

    return bg_img
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
def render_fg(
    patch_size,
    gancraft_fg,
    building_id,
    hf_seg,
    voxel_id,
    depth2,
    raydirs,
    cam_origin,
    building_stats,
    building_z,
):
    """Render one building instance (facade and roof) patch by patch.

    Args:
        patch_size: Pair ``(height, width)`` of the render patches.
        gancraft_fg: Building generator; expected to expose ``output_device``.
        building_id: Even facade instance ID; ``building_id - 1`` is the
            matching roof ID.
        hf_seg: Height-field/seg tensor, cropped here to a
            ``BUILDING_VOL_SIZE`` window around the building.
        voxel_id: Ray/voxel intersection IDs, sliced as ``[:, rows, cols]``.
        depth2: Intersection depths, sliced like ``voxel_id``.
        raydirs: Ray directions, sliced like ``voxel_id``.
        cam_origin: Camera origin passed through to the generator.
        building_stats: Two-element sequence; element 0 offsets the crop
            center along rows, element 1 along columns.
        building_z: Style code of this building.

    Returns:
        ``(fg_img, fg_mask)``: the rendered colours, shape
        ``(1, 3, GES_IMAGE_HEIGHT, GES_IMAGE_WIDTH)``, and a float mask,
        shape ``(1, 1, GES_IMAGE_HEIGHT, GES_IMAGE_WIDTH)``, that is 1 where
        the first voxel hit by the ray belongs to this building. Both live
        on ``gancraft_fg.output_device``.
    """
    # Keep only the target building in a private copy of the voxel IDs and
    # translate its instance IDs to the generic facade / roof classes.
    _voxel_id = copy.deepcopy(voxel_id)
    _curr_bld = torch.tensor([building_id, building_id - 1], device=voxel_id.device)
    _voxel_id[~torch.isin(_voxel_id, _curr_bld)] = 0
    _voxel_id[voxel_id == building_id] = CLASSES["BLD_FACADE"]
    _voxel_id[voxel_id == building_id - 1] = CLASSES["BLD_ROOF"]

    # Zero the direction of every ray whose first hit is not this building.
    _raydirs = copy.deepcopy(raydirs)
    _raydirs[_voxel_id[..., 0, 0] == 0] = 0

    # Crop the "hf_seg" image using the center of the target building as the reference
    # The previous version first built a masked deep copy of hf_seg and then
    # overwrote it with this crop of the unmasked tensor; that dead copy is
    # gone, and the generator still receives exactly this crop.
    # NOTE(review): the crop is not bounds-checked; a building too close to
    # the border yields a window smaller than BUILDING_VOL_SIZE.
    half_vol = CONSTANTS["BUILDING_VOL_SIZE"] // 2
    crop_cx = CONSTANTS["EXTENDED_VOL_SIZE"] // 2 - int(building_stats[1])
    crop_cy = CONSTANTS["EXTENDED_VOL_SIZE"] // 2 - int(building_stats[0])
    _hf_seg = hf_seg[
        :,
        :,
        crop_cy - half_vol : crop_cy + half_vol,
        crop_cx - half_vol : crop_cx + half_vol,
    ]

    fg_img = torch.zeros(
        1,
        3,
        CONSTANTS["GES_IMAGE_HEIGHT"],
        CONSTANTS["GES_IMAGE_WIDTH"],
        dtype=torch.float32,
        device=gancraft_fg.output_device,
    )
    fg_mask = torch.zeros(
        1,
        1,
        CONSTANTS["GES_IMAGE_HEIGHT"],
        CONSTANTS["GES_IMAGE_WIDTH"],
        dtype=torch.float32,
        device=gancraft_fg.output_device,
    )

    # Identical for every patch, so build it once outside the loops.
    building_stats_tensor = torch.from_numpy(np.array(building_stats)).unsqueeze(
        dim=0
    )

    # Render foreground patches by patch to avoid OOM
    for i in range(CONSTANTS["GES_IMAGE_HEIGHT"] // patch_size[0]):
        for j in range(CONSTANTS["GES_IMAGE_WIDTH"] // patch_size[1]):
            sy, sx = i * patch_size[0], j * patch_size[1]
            ey, ex = sy + patch_size[0], sx + patch_size[1]
            psx, pex, psy, pey = _get_pad_img_bbox(sx, ex, sy, ey)

            # Skip patches in which no ray hits the building first.
            if torch.count_nonzero(_raydirs[:, sy:ey, sx:ex]) > 0:
                output_fg = gancraft_fg(
                    _hf_seg,
                    _voxel_id[:, psy:pey, psx:pex],
                    depth2[:, psy:pey, psx:pex],
                    _raydirs[:, psy:pey, psx:pex],
                    cam_origin,
                    building_stats=building_stats_tensor,
                    z=building_z,
                    deterministic=True,
                )
                # Masks come from the original instance IDs of the first hit.
                facade_mask = (
                    voxel_id[:, sy:ey, sx:ex, 0, 0] == building_id
                ).unsqueeze(dim=1)
                roof_mask = (
                    voxel_id[:, sy:ey, sx:ex, 0, 0] == building_id - 1
                ).unsqueeze(dim=1)
                facade_img = facade_mask * _get_img_without_pad(
                    output_fg, sx, ex, sy, ey, psx, pex, psy, pey
                )
                roof_img = roof_mask * _get_img_without_pad(
                    output_fg, sx, ex, sy, ey, psx, pex, psy, pey
                )
                fg_mask[:, :, sy:ey, sx:ex] = torch.logical_or(facade_mask, roof_mask)
                fg_img[:, :, sy:ey, sx:ex] = (
                    facade_img * facade_mask + roof_img * roof_mask
                )

    return fg_img, fg_mask
|
| 498 |
+
|
| 499 |
+
|
| 500 |
+
def render(
    patch_size,
    seg_volume,
    hf_seg,
    cam_pos,
    gancraft_bg,
    gancraft_fg,
    building_stats,
    bg_z,
    building_zs,
):
    """Render one view: the background first, then every visible building.

    Args:
        patch_size: Pair ``(height, width)`` of the render patches.
        seg_volume: 3D label volume the camera rays are cast through.
        hf_seg: Height-field/seg tensor shared by both generators.
        cam_pos: Camera location mapping with the keys ``"x"``, ``"y"``, ``"z"``.
        gancraft_bg: Background generator.
        gancraft_fg: Building generator.
        building_stats: Mapping from facade instance ID to its statistics.
        bg_z: Background style code.
        building_zs: Mapping from facade instance ID to its style code.

    Returns:
        The composited image tensor; each building is blended over the
        current image using its foreground mask.
    """
    voxel_id, depth2, raydirs, cam_origin = get_voxel_intersection_perspective(
        seg_volume, cam_pos
    )

    visible_ids = torch.unique(voxel_id[voxel_id > CONSTANTS["BLD_INS_LABEL_MIN"]])
    # Odd IDs are reserved for roofs; only even (facade) IDs name a building.
    facade_ids = visible_ids[visible_ids % 2 == 0]

    with torch.no_grad():
        image = render_bg(
            patch_size, gancraft_bg, hf_seg, voxel_id, depth2, raydirs, cam_origin, bg_z
        )
        for facade_id in facade_ids:
            assert facade_id % 2 == 0, "Building Instance ID MUST be an even number."
            ins_id = facade_id.item()
            fg_img, fg_mask = render_fg(
                patch_size,
                gancraft_fg,
                ins_id,
                hf_seg,
                voxel_id,
                depth2,
                raydirs,
                cam_origin,
                building_stats[ins_id],
                building_zs[ins_id],
            )
            image = image * (1 - fg_mask) + fg_img * fg_mask

    return image
|
citydreamer/model.py
ADDED
|
@@ -0,0 +1,1264 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
#
|
| 3 |
+
# @File: gancraft.py
|
| 4 |
+
# @Author: Haozhe Xie
|
| 5 |
+
# @Date: 2023-04-12 19:53:21
|
| 6 |
+
# @Last Modified by: Haozhe Xie
|
| 7 |
+
# @Last Modified at: 2024-03-03 11:15:36
|
| 8 |
+
# @Email: [email protected]
|
| 9 |
+
# @Ref: https://github.com/FrozenBurning/SceneDreamer
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
|
| 15 |
+
import citydreamer.extensions.grid_encoder
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class GanCraftGenerator(torch.nn.Module):
|
| 19 |
+
def __init__(self, cfg):
    """Assemble the generator's sub-networks from the configuration.

    Args:
        cfg: Configuration object; the ``cfg.NETWORK.GANCRAFT`` section
            selects the encoder (``ENCODER``: "GLOBAL", "LOCAL", or anything
            else for no encoder) and the positional embedding (``POS_EMD``:
            "HASH_GRID" or "SIN_COS"; for any other value no
            ``pos_encoder`` attribute is created).

    Raises:
        ValueError: If neither ``POS_EMD_INCUDE_CORDS`` nor
            ``POS_EMD_INCUDE_FEATURES`` is enabled.
    """
    super(GanCraftGenerator, self).__init__()
    self.cfg = cfg
    self.render_net = RenderMLP(cfg)
    self.denoiser = RenderCNN(cfg)

    gancraft_cfg = cfg.NETWORK.GANCRAFT

    encoder_type = gancraft_cfg.ENCODER
    if encoder_type == "GLOBAL":
        self.encoder = GlobalEncoder(cfg)
    elif encoder_type == "LOCAL":
        self.encoder = LocalEncoder(cfg)
    else:
        self.encoder = None

    # The positional embedding needs at least one kind of input.
    if not (
        gancraft_cfg.POS_EMD_INCUDE_CORDS or gancraft_cfg.POS_EMD_INCUDE_FEATURES
    ):
        raise ValueError(
            "Either POS_EMD_INCUDE_CORDS or POS_EMD_INCUDE_FEATURES should be True."
        )

    pos_emd = gancraft_cfg.POS_EMD
    if pos_emd == "HASH_GRID":
        # 3 input channels for the coordinates, plus the encoder features
        # when an encoder exists and its features are requested.
        in_dim = 3 if gancraft_cfg.POS_EMD_INCUDE_CORDS else 0
        if (
            gancraft_cfg.ENCODER in ["GLOBAL", "LOCAL"]
            and gancraft_cfg.POS_EMD_INCUDE_FEATURES
        ):
            in_dim += gancraft_cfg.ENCODER_OUT_DIM

        self.pos_encoder = citydreamer.extensions.grid_encoder.GridEncoder(
            in_channels=in_dim,
            n_levels=gancraft_cfg.HASH_GRID_N_LEVELS,
            lvl_channels=gancraft_cfg.HASH_GRID_LEVEL_DIM,
            desired_resolution=gancraft_cfg.HASH_GRID_RESOLUTION,
        )
    elif pos_emd == "SIN_COS":
        self.pos_encoder = SinCosEncoder(cfg)
|
| 55 |
+
|
| 56 |
+
def forward(
    self,
    hf_seg,
    voxel_id,
    depth2,
    raydirs,
    cam_origin,
    building_stats=None,
    z=None,
    deterministic=False,
):
    r"""Run the GANcraft generator.

    Args:
        hf_seg (N x (1 + M) x H' x W' tensor): Height field + seg map, where M is the number of classes.
        voxel_id (N x H x W x max_samples x 1 tensor): IDs of the voxels intersected along each ray.
        depth2 (N x H x W x 2 x max_samples x 1 tensor): Depths of entrance and exit points for each
            ray-voxel intersection.
        raydirs (N x H x W x 1 x 3 tensor): The direction of each ray.
        cam_origin (N x 3 tensor): Camera origins.
        building_stats (tensor): Statistics of the target building, forwarded unchanged to the
            per-pixel stage. (Only used in building mode)
        z (N x STYLE_DIM tensor): The style vector. When omitted and ``STYLE_DIM`` is configured,
            a standard-normal vector is sampled.
        deterministic (bool): Whether to use equal-distance sampling instead of random stratified sampling.
    Returns:
        fake_images (N x 3 x H x W tensor): fake images
    """
    batch_size = hf_seg.size(0)
    device = hf_seg.device

    if z is None:
        style_dim = self.cfg.NETWORK.GANCRAFT.STYLE_DIM
        if style_dim is not None:
            z = torch.randn(batch_size, style_dim, dtype=torch.float32, device=device)

    # Encoder features are optional; without an encoder they stay None.
    features = self.encoder(hf_seg) if self.encoder is not None else None

    net_out = self._forward_perpix(
        features,
        voxel_id,
        depth2,
        raydirs,
        cam_origin,
        z,
        building_stats,
        deterministic,
    )
    return self._forward_global(net_out, z)
|
| 107 |
+
|
| 108 |
+
def _forward_perpix(
    self,
    features,
    voxel_id,
    depth2,
    raydirs,
    cam_origin,
    z,
    building_stats=None,
    deterministic=False,
):
    r"""Sample points along rays, forwarding the per-point MLP and aggregate pixel features

    Args:
        features (N x C1 tensor): Local features determined by the current pixel.
        voxel_id (N x H x W x M x 1 tensor): Voxel ids from ray-voxel intersection test. M: num intersected voxels
        depth2 (N x H x W x 2 x M x 1 tensor): Depths of entrance and exit points for each ray-voxel intersection.
        raydirs (N x H x W x 1 x 3 tensor): The direction of each ray.
        cam_origin (N x 3 tensor): Camera origins.
        z (N x C3 tensor): Intermediate style vectors.
        building_stats (tensor): Statistics of the target building, forwarded to the coordinate
            sampling step. (Only used in building mode)
        deterministic (bool): Whether to use equal-distance sampling instead of random stratified sampling.
    Returns:
        The per-pixel output: the per-sample colours (clamped to [-1, 1]) blended along each ray
        with the volume-rendering weights.
    """
    # Generate sky_mask; PE transform on ray direction.
    with torch.no_grad():
        # sky_only_mask: when True, ray hits nothing but sky
        sky_only_mask = voxel_id[:, :, :, [0], :] == 0

    # NOTE(review): the source's indentation was lost; the extent of this
    # no_grad block is inferred from the blank-line layout.
    with torch.no_grad():
        normalized_cord, new_dists, new_idx = self._get_sampled_coordinates(
            self.cfg.NETWORK.GANCRAFT.N_SAMPLE_POINTS_PER_RAY,
            depth2,
            raydirs,
            cam_origin,
            building_stats,
            deterministic,
        )
        # Generate per-sample segmentation label
        seg_map_bev = torch.gather(voxel_id, -2, new_idx)
        # In Building Mode, the one more channel is used for building roofs
        n_classes = (
            self.cfg.NETWORK.GANCRAFT.N_CLASSES + 1
            if self.cfg.NETWORK.GANCRAFT.BUILDING_MODE
            else self.cfg.NETWORK.GANCRAFT.N_CLASSES
        )
        # One-hot encode the labels along a new last axis of size n_classes.
        seg_map_bev_onehot = torch.zeros(
            [
                seg_map_bev.size(0),
                seg_map_bev.size(1),
                seg_map_bev.size(2),
                seg_map_bev.size(3),
                n_classes,
            ],
            dtype=torch.float,
            device=voxel_id.device,
        )
        seg_map_bev_onehot.scatter_(-1, seg_map_bev.long(), 1.0)

    net_out_s, net_out_c = self._forward_perpix_sub(
        features, normalized_cord, z, seg_map_bev_onehot
    )

    # Blending
    weights = self._volum_rendering_relu(
        net_out_s, new_dists * self.cfg.NETWORK.GANCRAFT.DIST_SCALE, dim=-2
    )
    # If a ray exclusively hits the sky (no intersection with the voxels), set its weight to zero.
    weights = weights * torch.logical_not(sky_only_mask).float()

    # Shift the colours from [-1, 1] to [0, 2] before the weighted sum and
    # shift back afterwards, so a ray with zero total weight ends up at -1.
    rgbs = torch.clamp(net_out_c, -1, 1) + 1
    net_out = torch.sum(weights * rgbs, dim=-2, keepdim=True)
    net_out = net_out.squeeze(-2)
    net_out = net_out - 1
    return net_out
|
| 185 |
+
|
| 186 |
+
def _get_sampled_coordinates(
    self,
    n_samples,
    depth2,
    raydirs,
    cam_origin,
    building_stats=None,
    deterministic=False,
):
    """Sample points along each ray and return them in normalized coordinates.

    Args:
        n_samples: Number of sample points per ray; the depth sampler is
            asked for ``n_samples + 1`` values.
        depth2: Entrance/exit depths of the ray-voxel intersections.
        raydirs: Ray directions, broadcast against the sampled depths.
        cam_origin: Camera origins, indexed as ``[:, None, None, None, :]``.
        building_stats: Per-batch building offsets; components 0 and 1 are
            subtracted from world coordinates 0 and 1. Required when
            ``BUILDING_MODE`` is enabled.
        deterministic: Forwarded to the depth sampler.

    Returns:
        ``(normalized_cord, new_dists, new_idx)``: the normalized sample
        coordinates plus the distances and voxel indices returned by the
        depth sampler.

    Raises:
        AssertionError: If ``BUILDING_MODE`` is enabled and
            ``building_stats`` is ``None``.
    """
    # Random sample points along the ray
    rand_depth, new_dists, new_idx = self._sample_depth_batched(
        depth2,
        n_samples + 1,
        deterministic=deterministic,
        use_box_boundaries=False,
        sample_depth=3,
    )
    # Replace non-finite depths by 0 so such samples land on the camera origin.
    nan_mask = torch.isnan(rand_depth)
    inf_mask = torch.isinf(rand_depth)
    rand_depth[nan_mask | inf_mask] = 0.0
    world_coord = raydirs * rand_depth + cam_origin[:, None, None, None, :]
    # NOTE(review): the source's indentation was lost; the extent of this
    # branch is inferred from the blank-line layout.
    if self.cfg.NETWORK.GANCRAFT.BUILDING_MODE:
        assert building_stats is not None
        # Make the building object-centric
        building_stats = building_stats[:, None, None, None, :].repeat(
            1, world_coord.size(1), world_coord.size(2), world_coord.size(3), 1
        )
        world_coord[..., 0] -= (
            building_stats[..., 0] + self.cfg.NETWORK.GANCRAFT.CENTER_OFFSET
        )
        world_coord[..., 1] -= (
            building_stats[..., 1] + self.cfg.NETWORK.GANCRAFT.CENTER_OFFSET
        )
        # TODO: Fix non-building rays
        # Coordinates are zeroed wherever the matching ray-direction
        # component is exactly 0.
        zero_rd_mask = raydirs.repeat(1, 1, 1, n_samples, 1)
        world_coord[zero_rd_mask == 0] = 0

    normalized_cord = self._get_normalized_coordinates(world_coord)
    return normalized_cord, new_dists, new_idx
|
| 226 |
+
|
| 227 |
+
def _get_normalized_coordinates(self, world_coord):
|
| 228 |
+
delimeter = torch.tensor(
|
| 229 |
+
self.cfg.NETWORK.GANCRAFT.NORMALIZE_DELIMETER, device=world_coord.device
|
| 230 |
+
)
|
| 231 |
+
normalized_cord = world_coord / delimeter * 2 - 1
|
| 232 |
+
# TODO: Temporary fix
|
| 233 |
+
normalized_cord[normalized_cord > 1] = 1
|
| 234 |
+
normalized_cord[normalized_cord < -1] = -1
|
| 235 |
+
# assert (normalized_cord <= 1).all()
|
| 236 |
+
# assert (normalized_cord >= -1).all()
|
| 237 |
+
# print(delimeter, torch.min(normalized_cord), torch.max(normalized_cord))
|
| 238 |
+
# print(normalized_cord.size()) # torch.Size([1, 192, 192, 24, 3])
|
| 239 |
+
return normalized_cord
|
| 240 |
+
|
| 241 |
+
def _sample_depth_batched(
    self,
    depth2,
    n_samples,
    deterministic=False,
    use_box_boundaries=True,
    sample_depth=3,
):
    r"""Sample depths along every ray, restricted to the boxes the ray crosses.

    Samples are first drawn in "distance travelled inside boxes" space so that,
    as far as possible, every ray is covered over the same distance. They are
    then mapped back to depths along the ray. Rays that do not cross enough
    voxels are the exception.

    Args:
        depth2 (N x H x W x 2 x M x 1 tensor):
            - N: Batch.
            - H, W: Height, Width.
            - 2: Entrance / exit depth for each intersected box.
            - M: Number of intersected boxes along the ray.
            - 1: One extra dim for consistent tensor dims.
            depth2 can include NaNs.
        n_samples (int): Number of samples drawn per ray (before box boundaries are optionally added).
        deterministic (bool): Whether to use equal-distance sampling instead of random stratified sampling.
        use_box_boundaries (bool): Whether to add the entrance / exit points into the sample.
        sample_depth (float): Truncate the ray when it travels further than sample_depth inside voxels.

    Returns:
        rand_depth (N x H x W x S x 1 tensor): Depth along the ray of the midpoint of every pair of
            consecutive sorted samples. S is one less than the number of sorted samples.
        new_dists (N x H x W x S x 1 tensor): Distance between every pair of consecutive sorted samples.
        idx (N x H x W x S x 1 tensor): Index of the intersected box each midpoint falls into.
    """
    n_batch, height, width = depth2.size(0), depth2.size(1), depth2.size(2)
    entry_depth = depth2[:, :, :, 0]
    exit_depth = depth2[:, :, :, 1]

    # Distance travelled inside each box; a NaN length is replaced by zero.
    dists = exit_depth - entry_depth
    dists[torch.isnan(dists)] = 0
    accu_depth = torch.cumsum(dists, dim=-2)  # [N, H, W, M, 1]
    # Total in-box distance per ray, capped at sample_depth. [N, H, W, 1, 1]
    total_depth = torch.clamp(accu_depth[..., [-1], :], None, sample_depth)

    # Ignore out of range box boundaries. Fill with random samples.
    if use_box_boundaries:
        boundary_samples = accu_depth.clone().detach()
        filler = torch.rand_like(boundary_samples) * total_depth
        out_of_range = (accu_depth > sample_depth) | (dists == 0)
        boundary_samples[out_of_range] = filler[out_of_range]

    sample_shape = [n_batch, height, width, n_samples, 1]
    if deterministic:
        # Equally spaced fractions strictly inside (0, 1).
        rand_samples = torch.empty(
            sample_shape, dtype=total_depth.dtype, device=total_depth.device
        )
        rand_samples[..., :, 0] = torch.linspace(0, 1, n_samples + 2)[1:-1]
    else:
        # Stratified sampling as in NeRF: one uniform draw inside each of the
        # n_samples equally sized strata of [0, 1).
        rand_samples = torch.rand(
            sample_shape, dtype=total_depth.dtype, device=total_depth.device
        )
        rand_samples = rand_samples / n_samples
        rand_samples[..., :, 0] += torch.linspace(
            0, 1, n_samples + 1, device=rand_samples.device
        )[:-1]

    rand_samples = rand_samples * total_depth  # [N, H, W, n_samples, 1]

    # Can also include boundaries (plus the ray origin at distance zero).
    if use_box_boundaries:
        origin = torch.zeros(
            [n_batch, height, width, 1, 1],
            dtype=total_depth.dtype,
            device=total_depth.device,
        )
        rand_samples = torch.cat([rand_samples, boundary_samples, origin], dim=-2)
    rand_samples, _ = torch.sort(rand_samples, dim=-2, descending=False)

    upper = rand_samples[..., 1:, :]
    lower = rand_samples[..., :-1, :]
    midpoints = (upper + lower) / 2
    new_dists = upper - lower

    # Scatter the random samples back: count how many boxes end before each
    # midpoint, which is the index of the box that contains the midpoint.
    # [N, H, W, 1, S, 1] > [N, H, W, M, 1, 1], summed over M.
    idx = torch.sum(midpoints.unsqueeze(-3) > accu_depth.unsqueeze(-2), dim=-3)

    # Empty space between consecutive boxes. There might be NaNs!
    gaps = entry_depth[:, :, :, 1:, :] - exit_depth[:, :, :, :-1, :]
    gaps = torch.cumsum(gaps, dim=-2)
    first_entry = entry_depth[:, :, :, [0], :]
    # Offset that turns an in-box distance into a depth along the ray, per box.
    offsets = torch.cat([first_entry, gaps + first_entry], dim=-2)
    heads = torch.gather(offsets, -2, idx)
    rand_depth = heads + midpoints
    return rand_depth, new_dists, idx
|
| 343 |
+
|
| 344 |
+
def _volum_rendering_relu(self, sigma, dists, dim=2):
    r"""Turn densities into the probability that a ray terminates at each sample.

    Args:
        sigma (tensor): Raw densities; negative values are treated as zero.
        dists (tensor): Interval lengths, broadcastable against ``sigma``.
        dim (int): Dimension that enumerates the samples along a ray.

    Returns:
        weights (tensor): Probability of the ray hitting something at each sample.
    """
    free_energy = F.relu(sigma) * dists
    # Probability that the current interval is not empty.
    occupied = 1 - torch.exp(-free_energy.float())
    # Probability that everything before the current interval is empty.
    transmittance = torch.exp(-self._cumsum_exclusive(free_energy, dim=dim))
    return occupied * transmittance
|
| 351 |
+
|
| 352 |
+
def _cumsum_exclusive(self, tensor, dim):
    r"""Exclusive cumulative sum along ``dim``.

    Entry ``i`` of the result holds the sum of entries ``0 .. i - 1`` of the
    input, so the first entry is always zero.

    Args:
        tensor (tensor): Values to accumulate.
        dim (int): Dimension to accumulate over.

    Returns:
        cumsum (tensor): Tensor of the same shape as ``tensor``.
    """
    # Shift the inclusive sum one step forward; the wrapped-around total that
    # ends up in the first slot is then overwritten with zero.
    shifted = torch.roll(torch.cumsum(tensor, dim), 1, dim)
    shifted.narrow(dim, 0, 1).zero_()
    return shifted
|
| 359 |
+
|
| 360 |
+
def _forward_perpix_sub(self, features, normalized_cord, z, seg_map_bev_onehot):
    r"""Assemble the per-sample MLP input and forward the render MLP.

    Args:
        features (N x C1 x ...? tensor): Local features determined by the current pixel.
        normalized_cord (N x H x W x L x 3 tensor): 3D world coordinates of sampled points. L is number of samples; N is batch size, always 1.
        z (N x C3 tensor): Intermediate style vectors.
        seg_map_bev_onehot (N x H x W x L x C4): One-hot segmentation maps.
    Returns:
        net_out_s (N x H x W x L x 1 tensor): Opacities.
        net_out_c (N x H x W x L x C5 tensor): Color embeddings.
    """
    gancraft_cfg = self.cfg.NETWORK.GANCRAFT
    n_batch = normalized_cord.size(0)
    height = normalized_cord.size(1)
    width = normalized_cord.size(2)
    n_points = normalized_cord.size(3)

    if gancraft_cfg.ENCODER == "GLOBAL":
        # features: [N, ENCODER_OUT_DIM], shared by every sampled point.
        feature_in = features[:, None, None, None, :].repeat(
            1, height, width, n_points, 1
        )
    elif gancraft_cfg.ENCODER == "LOCAL":
        # features: [N, ENCODER_OUT_DIM - 1, H, W]
        # NOTE: grid specifies the sampling pixel locations normalized by the input spatial
        # dimensions. Therefore, it should have most values in the range of [-1, 1].
        grid = normalized_cord.permute(0, 3, 1, 2, 4).reshape(-1, height, width, 3)
        # grid: [N * L, H, W, 3]; only the first two coordinates (swapped) are used for lookup.
        sampled = F.grid_sample(
            features.repeat(grid.size(0), 1, 1, 1),
            grid[..., [1, 0]],
            align_corners=False,
        )
        # sampled: [N * L, ENCODER_OUT_DIM - 1, H, W] -> [N, H, W, L, ENCODER_OUT_DIM - 1]
        sampled = sampled.reshape(
            n_batch, n_points, sampled.size(1), sampled.size(2), sampled.size(3)
        ).permute(0, 3, 4, 1, 2)
        # The third coordinate is appended as the last feature channel.
        feature_in = torch.cat([sampled, normalized_cord[..., [2]]], dim=-1)
    else:
        # No encoder: start from a zero-width feature tensor.
        feature_in = torch.empty(
            n_batch, height, width, n_points, 0, device=normalized_cord.device
        )

    with_cords = gancraft_cfg.POS_EMD_INCUDE_CORDS
    with_features = gancraft_cfg.POS_EMD_INCUDE_FEATURES
    if gancraft_cfg.POS_EMD in ["HASH_GRID", "SIN_COS"]:
        if with_cords and with_features:
            feature_in = self.pos_encoder(
                torch.cat([normalized_cord, feature_in], dim=-1)
            )
        elif with_cords:
            feature_in = torch.cat(
                [self.pos_encoder(normalized_cord), feature_in], dim=-1
            )
        elif with_features:
            # Ignore normalized_cord here to make it decoupled with coordinates
            feature_in = self.pos_encoder(feature_in)
    elif with_cords and with_features:
        feature_in = torch.cat([normalized_cord, feature_in], dim=-1)
    elif with_cords:
        feature_in = normalized_cord
    # Otherwise the encoder features are passed to the MLP unchanged.

    return self.render_net(feature_in, z, seg_map_bev_onehot)
|
| 443 |
+
|
| 444 |
+
def _forward_global(self, net_out, z):
    r"""Forward the CNN

    Args:
        net_out (tensor): Intermediate feature maps. The last dimension is moved to
            position 1 before the denoiser is applied, so a channels-last
            N x H x W x C5 layout is presumably expected -- TODO confirm against the caller.
        z (N x C3 tensor): Intermediate style vectors.

    Returns:
        fake_images (N x 3 x H x W tensor): Output image.
    """
    images = net_out.permute(0, 3, 1, 2).contiguous()
    if self.denoiser is None:
        return images
    # NOTE(review): the tanh is taken to belong to the denoiser branch (it is
    # followed by the blank line that closes the block); confirm upstream.
    return torch.tanh(self.denoiser(images, z))
|
| 460 |
+
|
| 461 |
+
|
| 462 |
+
class GlobalEncoder(torch.nn.Module):
    r"""Encode a stacked height-field / segmentation input into one vector per sample."""

    def __init__(self, cfg):
        super(GlobalEncoder, self).__init__()
        gancraft_cfg = cfg.NETWORK.GANCRAFT
        stem_kwargs = dict(kernel_size=3, stride=2, padding=1)
        # Separate stems for the height field (1 channel) and the segmentation map.
        self.hf_conv = torch.nn.Conv2d(1, 8, **stem_kwargs)
        self.seg_conv = torch.nn.Conv2d(gancraft_cfg.N_CLASSES, 8, **stem_kwargs)

        # The two 8-channel stems are concatenated; every block doubles the width.
        blocks = []
        channels = 16
        for _ in range(1, gancraft_cfg.GLOBAL_ENCODER_N_BLOCKS):
            blocks.append(SRTConvBlock(in_channels=channels, out_channels=None))
            channels *= 2

        self.conv_blocks = torch.nn.Sequential(*blocks)
        self.fc1 = torch.nn.Linear(channels, 16)
        self.fc2 = torch.nn.Linear(16, gancraft_cfg.ENCODER_OUT_DIM)
        self.act = torch.nn.LeakyReLU(0.2)

    def forward(self, hf_seg):
        r"""Forward network.

        Args:
            hf_seg (N x (1 + N_CLASSES) x H x W tensor): Channel 0 is fed to the height-field
                stem, the remaining channels to the segmentation stem.

        Returns:
            cond (N x ENCODER_OUT_DIM tensor): Global feature vector squashed by tanh.
        """
        height_feat = self.act(self.hf_conv(hf_seg[:, [0]]))
        seg_feat = self.act(self.seg_conv(hf_seg[:, 1:]))
        out = torch.cat([height_feat, seg_feat], dim=1)
        for block in self.conv_blocks:
            out = self.act(block(out))

        # Average over every spatial position, leaving one vector per sample.
        out = out.permute(0, 2, 3, 1)
        pooled = torch.mean(out.reshape(out.shape[0], -1, out.shape[-1]), dim=1)
        hidden = self.act(self.fc1(pooled))
        return torch.tanh(self.fc2(hidden))
|
| 499 |
+
|
| 500 |
+
|
| 501 |
+
class LocalEncoder(torch.nn.Module):
    r"""Encode a stacked height-field / segmentation input into a per-pixel feature map."""

    def __init__(self, cfg):
        super(LocalEncoder, self).__init__()
        gancraft_cfg = cfg.NETWORK.GANCRAFT
        norm = gancraft_cfg.LOCAL_ENCODER_NORM
        stem_kwargs = dict(kernel_size=7, stride=2, padding=3)
        # Separate stems for the height field (1 channel) and the segmentation map.
        self.hf_conv = torch.nn.Conv2d(1, 32, **stem_kwargs)
        self.seg_conv = torch.nn.Conv2d(gancraft_cfg.N_CLASSES, 32, **stem_kwargs)
        # Normalization applied to the 64 concatenated stem channels.
        if norm == "BATCH_NORM":
            self.bn1 = torch.nn.BatchNorm2d(64)
        elif norm == "GROUP_NORM":
            self.bn1 = torch.nn.GroupNorm(32, 64)
        else:
            raise ValueError("Unknown normalization: %s" % norm)

        self.conv2 = ResConvBlock(64, 128, norm)
        self.conv3 = ResConvBlock(128, 256, norm)
        self.conv4 = ResConvBlock(256, 512, norm)
        # Two stride-2 transposed convolutions, then a 1x1 projection.
        upsample_kwargs = dict(kernel_size=4, stride=2, padding=1)
        self.dconv5 = torch.nn.ConvTranspose2d(512, 128, **upsample_kwargs)
        self.dconv6 = torch.nn.ConvTranspose2d(128, 32, **upsample_kwargs)
        # One channel fewer than ENCODER_OUT_DIM is produced here.
        self.dconv7 = torch.nn.Conv2d(
            32, gancraft_cfg.ENCODER_OUT_DIM - 1, kernel_size=1
        )

    def forward(self, hf_seg):
        r"""Forward network.

        Args:
            hf_seg (N x (1 + N_CLASSES) x H x W tensor): Channel 0 is fed to the height-field
                stem, the remaining channels to the segmentation stem.

        Returns:
            out (N x (ENCODER_OUT_DIM - 1) x H x W tensor): Feature map squashed by tanh.
        """
        stems = torch.cat(
            [self.hf_conv(hf_seg[:, [0]]), self.seg_conv(hf_seg[:, 1:])], dim=1
        )
        out = F.relu(self.bn1(stems), inplace=True)  # [N, 64, H/2, W/2]
        out = F.avg_pool2d(self.conv2(out), 2, stride=2)  # [N, 128, H/4, W/4]
        out = self.conv4(self.conv3(out))  # [N, 512, H/4, W/4]
        out = self.dconv6(self.dconv5(out))  # [N, 32, H, W]
        return torch.tanh(self.dconv7(out))  # [N, OUT_DIM - 1, H, W]
|
| 548 |
+
|
| 549 |
+
|
| 550 |
+
class SinCosEncoder(torch.nn.Module):
    r"""Sine / cosine positional encoding with power-of-two frequencies."""

    def __init__(self, cfg):
        super(SinCosEncoder, self).__init__()
        n_freqs = cfg.NETWORK.GANCRAFT.SIN_COS_FREQ_BENDS
        # Frequencies 2^0, 2^1, ..., 2^(n_freqs - 1).
        self.freq_bands = 2.0 ** torch.linspace(0, n_freqs - 1, steps=n_freqs)

    def forward(self, features):
        r"""Encode every feature channel at every frequency.

        Args:
            features (... x C tensor): Values to encode.

        Returns:
            encoded (... x (2 * n_freqs * C) tensor): All sine terms (grouped by frequency)
                followed by all cosine terms (grouped by frequency).
        """
        scaled = [features * freq for freq in self.freq_bands]
        sines = [torch.sin(value) for value in scaled]
        cosines = [torch.cos(value) for value in scaled]
        return torch.cat(sines + cosines, dim=-1)
|
| 567 |
+
|
| 568 |
+
|
| 569 |
+
class RenderMLP(torch.nn.Module):
    r"""MLP with affine modulation."""

    def __init__(self, cfg):
        super(RenderMLP, self).__init__()
        gancraft_cfg = cfg.NETWORK.GANCRAFT
        hidden_dim = gancraft_cfg.RENDER_HIDDEN_DIM
        # Building mode uses one extra mask channel.
        mask_dim = (
            gancraft_cfg.N_CLASSES + 1
            if gancraft_cfg.BUILDING_MODE
            else gancraft_cfg.N_CLASSES
        )

        # NOTE: the layers are created in a fixed order (mask projection, input
        # layer, common trunk, sigma head, color trunk, color head).
        self.fc_m_a = torch.nn.Linear(mask_dim, hidden_dim, bias=False)
        self.fc_1 = torch.nn.Linear(self._input_dim(gancraft_cfg), hidden_dim)
        self.fc_2 = self._hidden_layer(gancraft_cfg)
        self.fc_3 = self._hidden_layer(gancraft_cfg)
        self.fc_4 = self._hidden_layer(gancraft_cfg)
        # The output heads are plain linear layers with or without style codes.
        self.fc_sigma = torch.nn.Linear(
            hidden_dim, gancraft_cfg.RENDER_OUT_DIM_SIGMA
        )
        self.fc_5 = self._hidden_layer(gancraft_cfg)
        self.fc_6 = self._hidden_layer(gancraft_cfg)
        self.fc_out_c = torch.nn.Linear(
            hidden_dim, gancraft_cfg.RENDER_OUT_DIM_COLOR
        )
        self.act = torch.nn.LeakyReLU(negative_slope=0.2)

    @staticmethod
    def _input_dim(gancraft_cfg):
        r"""Width of the feature vector fed to ``fc_1`` for the given configuration."""
        with_cords = gancraft_cfg.POS_EMD_INCUDE_CORDS
        with_features = gancraft_cfg.POS_EMD_INCUDE_FEATURES
        f_dim = (
            gancraft_cfg.ENCODER_OUT_DIM
            if gancraft_cfg.ENCODER in ["GLOBAL", "LOCAL"]
            else 0
        )
        if gancraft_cfg.POS_EMD == "HASH_GRID":
            in_dim = (
                gancraft_cfg.HASH_GRID_N_LEVELS * gancraft_cfg.HASH_GRID_LEVEL_DIM
            )
            # Raw features are appended only when just the coordinates are embedded.
            if with_cords and not with_features:
                in_dim += f_dim
            return in_dim
        if gancraft_cfg.POS_EMD == "SIN_COS":
            # Every embedded channel yields a sine and a cosine per frequency.
            if with_cords and with_features:
                return (3 + f_dim) * gancraft_cfg.SIN_COS_FREQ_BENDS * 2
            if with_cords:
                return 3 * gancraft_cfg.SIN_COS_FREQ_BENDS * 2 + f_dim
            if with_features:
                return f_dim * gancraft_cfg.SIN_COS_FREQ_BENDS * 2
            return 0
        # No positional embedding.
        if with_cords and with_features:
            return 3 + f_dim
        if with_cords:
            return 3
        if with_features:
            return f_dim
        return 0

    @staticmethod
    def _hidden_layer(gancraft_cfg):
        r"""Build one hidden layer: style-modulated when STYLE_DIM is set, plain otherwise."""
        hidden_dim = gancraft_cfg.RENDER_HIDDEN_DIM
        if gancraft_cfg.STYLE_DIM is None:
            return torch.nn.Linear(hidden_dim, hidden_dim)
        return ModLinear(
            hidden_dim,
            hidden_dim,
            gancraft_cfg.STYLE_DIM,
            bias=False,
            mod_bias=True,
            output_mode=True,
        )

    def forward(self, x, z, m):
        r"""Forward network

        Args:
            x (N x H x W x M x in_channels tensor): Projected features.
            z (N x cfg.NETWORK.GANCRAFT.STYLE_DIM tensor): Style codes.
            m (N x H x W x M x mask_channels tensor): One-hot segmentation maps.

        Returns:
            sigma (N x H x W x M x RENDER_OUT_DIM_SIGMA tensor): Densities. They are passed
                through the activation only when no style code is given.
            c (N x H x W x M x RENDER_OUT_DIM_COLOR tensor): Color embeddings.
        """
        f = self.act(self.fc_1(x) + self.fc_m_a(m))
        if z is None:
            # Common MLP
            for layer in (self.fc_2, self.fc_3, self.fc_4):
                f = self.act(layer(f))
            # Sigma MLP
            sigma = self.act(self.fc_sigma(f))
            # Color MLP
            for layer in (self.fc_5, self.fc_6):
                f = self.act(layer(f))
        else:
            # Let the style code broadcast over the H, W and M dimensions.
            z = z[:, None, None, None, :]
            # Common MLP
            for layer in (self.fc_2, self.fc_3, self.fc_4):
                f = self.act(layer(f, z))
            # Sigma MLP
            sigma = self.fc_sigma(f)
            # Color MLP
            for layer in (self.fc_5, self.fc_6):
                f = self.act(layer(f, z))
        return sigma, self.fc_out_c(f)
|
| 747 |
+
|
| 748 |
+
|
| 749 |
+
class RenderCNN(torch.nn.Module):
    r"""CNN converting intermediate feature map to final image."""

    def __init__(self, cfg):
        super(RenderCNN, self).__init__()
        gancraft_cfg = cfg.NETWORK.GANCRAFT
        hidden_dim = gancraft_cfg.RENDER_HIDDEN_DIM

        def conv(in_channels, out_channels, kernel_size, bias=True):
            # Stride-1 convolution padded so that the spatial size is preserved.
            return torch.nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size,
                stride=1,
                padding=kernel_size // 2,
                bias=bias,
            )

        if gancraft_cfg.STYLE_DIM is not None:
            # Predicts a (scale, shift) pair for each of the two modulated stages.
            self.fc_z_cond = torch.nn.Linear(gancraft_cfg.STYLE_DIM, 2 * 2 * hidden_dim)
        self.conv1 = conv(gancraft_cfg.RENDER_OUT_DIM_COLOR, hidden_dim, 1)
        self.conv2a = conv(hidden_dim, hidden_dim, 3)
        self.conv2b = conv(hidden_dim, hidden_dim, 3, bias=False)
        self.conv3a = conv(hidden_dim, hidden_dim, 3)
        self.conv3b = conv(hidden_dim, hidden_dim, 3, bias=False)
        self.conv4a = conv(hidden_dim, hidden_dim, 1)
        self.conv4b = conv(hidden_dim, hidden_dim, 1)
        self.conv4 = conv(hidden_dim, 3, 1)
        self.act = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def modulate(self, x, w, b):
        r"""Apply a per-channel affine transform ``x * (w + 1) + b``.

        ``w`` and ``b`` get two trailing singleton dimensions so that they
        broadcast over the spatial dimensions of ``x``.
        """
        scale = w[..., None, None] + 1
        shift = b[..., None, None]
        return x * scale + shift

    def forward(self, x, z):
        r"""Forward network.

        Args:
            x (N x in_channels x H x W tensor): Intermediate feature map
            z (N x style_dim tensor): Style codes.

        Returns:
            y (N x 3 x H x W tensor): Output of the last convolution (no final activation).
        """
        styled = z is not None
        if styled:
            scale1, shift1, scale2, shift2 = torch.chunk(
                self.fc_z_cond(z), 2 * 2, dim=-1
            )

        y = self.act(self.conv1(x))

        # First residual stage (3x3), optionally modulated by the style code.
        y = y + self.conv2b(self.act(self.conv2a(y)))
        if styled:
            y = self.modulate(y, scale1, shift1)
        y = self.act(y)

        # Second residual stage (3x3), optionally modulated by the style code.
        y = y + self.conv3b(self.act(self.conv3a(y)))
        if styled:
            y = self.modulate(y, scale2, shift2)
        y = self.act(y)

        # Final residual stage (1x1) and projection to 3 channels.
        y = y + self.conv4b(self.act(self.conv4a(y)))
        return self.conv4(self.act(y))
|
| 849 |
+
|
| 850 |
+
|
| 851 |
+
class SRTConvBlock(torch.nn.Module):
    r"""Two bias-free 3x3 convolutions with ReLU; the second one halves the resolution.

    Args:
        in_channels (int): Number of input channels.
        hidden_channels (int, optional): Channels after the first convolution.
            Defaults to ``in_channels``.
        out_channels (int, optional): Channels after the second convolution.
            Defaults to ``2 * hidden_channels``.
    """

    def __init__(self, in_channels, hidden_channels=None, out_channels=None):
        super(SRTConvBlock, self).__init__()
        hidden_channels = in_channels if hidden_channels is None else hidden_channels
        out_channels = 2 * hidden_channels if out_channels is None else out_channels

        shared_kwargs = dict(kernel_size=3, padding=1, bias=False)
        self.layers = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels, hidden_channels, stride=1, **shared_kwargs),
            torch.nn.ReLU(),
            # The stride-2 convolution performs the downsampling.
            torch.nn.Conv2d(hidden_channels, out_channels, stride=2, **shared_kwargs),
            torch.nn.ReLU(),
        )

    def forward(self, x):
        r"""Apply both convolution stages to ``x`` (N x in_channels x H x W)."""
        return self.layers(x)
|
| 882 |
+
|
| 883 |
+
|
| 884 |
+
class ResConvBlock(torch.nn.Module):
    r"""Residual block whose three 3x3 convolution stages are concatenated.

    The stages produce ``out_channels // 2``, ``out_channels // 4`` and
    ``out_channels // 4`` channels. Their concatenation is added to the
    shortcut, so ``out_channels`` should be divisible by 4.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        norm (str): Either ``"BATCH_NORM"`` or ``"GROUP_NORM"`` (32 groups).
        bias (bool): Whether the 3x3 convolutions use a bias term.

    Raises:
        ValueError: If ``norm`` is not one of the supported normalizations.
    """

    def __init__(self, in_channels, out_channels, norm, bias=False):
        super(ResConvBlock, self).__init__()
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1, bias=bias)
        # conv3x3(in_planes, int(out_planes / 2))
        self.conv1 = torch.nn.Conv2d(in_channels, out_channels // 2, **conv_kwargs)
        # conv3x3(int(out_planes / 2), int(out_planes / 4))
        self.conv2 = torch.nn.Conv2d(
            out_channels // 2, out_channels // 4, **conv_kwargs
        )
        # conv3x3(int(out_planes / 4), int(out_planes / 4))
        self.conv3 = torch.nn.Conv2d(
            out_channels // 4, out_channels // 4, **conv_kwargs
        )

        # Each norm layer precedes the convolution with the matching index;
        # bn4 belongs to the shortcut branch.
        if norm == "BATCH_NORM":
            self.bn1 = torch.nn.BatchNorm2d(in_channels)
            self.bn2 = torch.nn.BatchNorm2d(out_channels // 2)
            self.bn3 = torch.nn.BatchNorm2d(out_channels // 4)
            self.bn4 = torch.nn.BatchNorm2d(in_channels)
        elif norm == "GROUP_NORM":
            self.bn1 = torch.nn.GroupNorm(32, in_channels)
            self.bn2 = torch.nn.GroupNorm(32, out_channels // 2)
            self.bn3 = torch.nn.GroupNorm(32, out_channels // 4)
            self.bn4 = torch.nn.GroupNorm(32, in_channels)
        else:
            # Fail early with a clear message instead of hitting a missing
            # ``bn*`` attribute later on.
            raise ValueError("Unknown normalization: %s" % norm)

        if in_channels != out_channels:
            # Project the shortcut to the output width. bn4 is deliberately the
            # same module instance that is registered as ``self.bn4``.
            self.downsample = torch.nn.Sequential(
                self.bn4,
                torch.nn.ReLU(True),
                torch.nn.Conv2d(
                    in_channels, out_channels, kernel_size=1, stride=1, bias=False
                ),
            )
        else:
            # Identity shortcut; bn4 is still created but is not used in forward.
            self.downsample = None

    def forward(self, x):
        r"""Forward network.

        Args:
            x (N x in_channels x H x W tensor): Input feature map.

        Returns:
            out (N x out_channels x H x W tensor): Concatenated stages plus shortcut.
        """
        # The in-place ReLUs act on the freshly normalized tensors, not on x.
        out1 = self.conv1(F.relu(self.bn1(x), True))
        out2 = self.conv2(F.relu(self.bn2(out1), True))
        out3 = self.conv3(F.relu(self.bn3(out2), True))
        out = torch.cat((out1, out2, out3), dim=1)

        residual = x if self.downsample is None else self.downsample(x)
        out += residual
        return out
|
| 958 |
+
|
| 959 |
+
|
| 960 |
+
class ModLinear(torch.nn.Module):
|
| 961 |
+
r"""Linear layer with affine modulation (Based on StyleGAN2 mod demod).
|
| 962 |
+
Equivalent to affine modulation following linear, but faster when the same modulation parameters are shared across
|
| 963 |
+
multiple inputs.
|
| 964 |
+
Args:
|
| 965 |
+
in_features (int): Number of input features.
|
| 966 |
+
out_features (int): Number of output features.
|
| 967 |
+
style_features (int): Number of style features.
|
| 968 |
+
bias (bool): Apply additive bias before the activation function?
|
| 969 |
+
mod_bias (bool): Whether to modulate bias.
|
| 970 |
+
output_mode (bool): If True, modulate output instead of input.
|
| 971 |
+
weight_gain (float): Initialization gain
|
| 972 |
+
"""
|
| 973 |
+
|
| 974 |
+
def __init__(
|
| 975 |
+
self,
|
| 976 |
+
in_features,
|
| 977 |
+
out_features,
|
| 978 |
+
style_features,
|
| 979 |
+
bias=True,
|
| 980 |
+
mod_bias=True,
|
| 981 |
+
output_mode=False,
|
| 982 |
+
weight_gain=1,
|
| 983 |
+
bias_init=0,
|
| 984 |
+
):
|
| 985 |
+
super(ModLinear, self).__init__()
|
| 986 |
+
weight_gain = weight_gain / np.sqrt(in_features)
|
| 987 |
+
self.weight = torch.nn.Parameter(
|
| 988 |
+
torch.randn([out_features, in_features]) * weight_gain
|
| 989 |
+
)
|
| 990 |
+
self.bias = (
|
| 991 |
+
torch.nn.Parameter(torch.full([out_features], np.float32(bias_init)))
|
| 992 |
+
if bias
|
| 993 |
+
else None
|
| 994 |
+
)
|
| 995 |
+
self.weight_alpha = torch.nn.Parameter(
|
| 996 |
+
torch.randn([in_features, style_features]) / np.sqrt(style_features)
|
| 997 |
+
)
|
| 998 |
+
self.bias_alpha = torch.nn.Parameter(
|
| 999 |
+
torch.full([in_features], 1, dtype=torch.float)
|
| 1000 |
+
) # init to 1
|
| 1001 |
+
self.weight_beta = None
|
| 1002 |
+
self.bias_beta = None
|
| 1003 |
+
self.mod_bias = mod_bias
|
| 1004 |
+
self.output_mode = output_mode
|
| 1005 |
+
if mod_bias:
|
| 1006 |
+
if output_mode:
|
| 1007 |
+
mod_bias_dims = out_features
|
| 1008 |
+
else:
|
| 1009 |
+
mod_bias_dims = in_features
|
| 1010 |
+
self.weight_beta = torch.nn.Parameter(
|
| 1011 |
+
torch.randn([mod_bias_dims, style_features]) / np.sqrt(style_features)
|
| 1012 |
+
)
|
| 1013 |
+
self.bias_beta = torch.nn.Parameter(
|
| 1014 |
+
torch.full([mod_bias_dims], 0, dtype=torch.float)
|
| 1015 |
+
)
|
| 1016 |
+
|
| 1017 |
+
@staticmethod
|
| 1018 |
+
def _linear_f(x, w, b):
|
| 1019 |
+
w = w.to(x.dtype)
|
| 1020 |
+
x_shape = x.shape
|
| 1021 |
+
x = x.reshape(-1, x_shape[-1])
|
| 1022 |
+
if b is not None:
|
| 1023 |
+
b = b.to(x.dtype)
|
| 1024 |
+
x = torch.addmm(b.unsqueeze(0), x, w.t())
|
| 1025 |
+
else:
|
| 1026 |
+
x = x.matmul(w.t())
|
| 1027 |
+
x = x.reshape(*x_shape[:-1], -1)
|
| 1028 |
+
return x
|
| 1029 |
+
|
| 1030 |
+
# x: B, ... , Cin
|
| 1031 |
+
# z: B, 1, 1, , Cz
|
| 1032 |
+
def forward(self, x, z):
    """Run the style-modulated linear layer.

    The style code ``z`` yields a per-sample scale (``alpha``) that
    multiplies the input-channel axis of the shared weight and, when
    ``mod_bias`` is set, a per-sample shift (``beta``) that is added either
    to the input (``output_mode`` false) or to the output bias
    (``output_mode`` true).

    Args:
        x: Features of shape ``(B, ..., C_in)``.
        z: Style code reshaped here to ``(B, 1, C_z)``, so every axis
            between the first and the last must have size 1.

    Returns:
        Tensor of shape ``(B, ..., C_out)``.
    """
    out_prefix = x.shape[:-1]
    feats = x.reshape(x.shape[0], -1, x.shape[-1])
    style = z.reshape(z.shape[0], 1, z.shape[-1])

    # Per-sample scale over input channels: [B, 1, I].
    scale = self._linear_f(style, self.weight_alpha, self.bias_alpha)
    # [1, O, I] * [B, 1, I] -> [B, O, I]
    mod_weight = self.weight.to(feats.dtype).unsqueeze(0) * scale

    shift = None
    if self.mod_bias:
        # [B, 1, I] in input mode, [B, 1, O] in output mode.
        shift = self._linear_f(style, self.weight_beta, self.bias_beta)
        if not self.output_mode:
            feats = feats + shift

    # The static bias is cast after the optional input shift, as before.
    offset = None if self.bias is None else self.bias.to(feats.dtype)[None, None, :]
    if self.mod_bias and self.output_mode:
        offset = shift if offset is None else offset + shift

    # [B, N, I] @ [B, I, O] -> [B, N, O]
    weight_t = mod_weight.transpose(1, 2)
    if offset is None:
        result = feats.bmm(weight_t)
    else:
        result = torch.baddbmm(offset, feats, weight_t)
    return result.reshape(*out_prefix, result.shape[-1])
|
| 1063 |
+
|
| 1064 |
+
|
| 1065 |
+
class GanCraftDiscriminator(torch.nn.Module):
    r"""Feature-pyramid discriminator producing per-pixel class logits.

    A five-stage stride-2 encoder (bottom-up) is merged through 1x1 lateral
    convolutions and 2x bilinear upsampling (top-down). The merged map at
    the second encoder stage is refined and projected to
    ``N_CLASSES + 1`` channels.

    Args:
        cfg: Configuration object; reads
            ``cfg.NETWORK.GANCRAFT.DIS_N_CHANNEL_BASE`` (channel width unit)
            and ``cfg.NETWORK.GANCRAFT.N_CLASSES``.
    """

    def __init__(self, cfg):
        super(GanCraftDiscriminator, self).__init__()
        width = cfg.NETWORK.GANCRAFT.DIS_N_CHANNEL_BASE

        # Bottom-up pathway: 3x3 convs with stride 2; the first takes RGB.
        encoder_channels = [
            (3, width),
            (1 * width, 2 * width),
            (2 * width, 4 * width),
            (4 * width, 8 * width),
            (8 * width, 8 * width),
        ]
        for stage, (c_in, c_out) in enumerate(encoder_channels, start=1):
            setattr(
                self,
                "enc%d" % stage,
                self._sn_conv_block(c_in, c_out, kernel_size=3, stride=2, padding=1),
            )

        # Top-down pathway: 1x1 lateral convs mapping every level to 4 * width.
        lateral_in_channels = [2 * width, 4 * width, 8 * width, 8 * width]
        for level, c_in in enumerate(lateral_in_channels, start=2):
            setattr(
                self,
                "lat%d" % level,
                self._sn_conv_block(c_in, 4 * width, kernel_size=1, stride=1),
            )

        # Upsampling used to bring a coarser level onto the next finer one.
        self.upsample2x = torch.nn.Upsample(
            scale_factor=2, mode="bilinear", align_corners=False
        )

        # Final layers: a 3x3 refinement, then a 1x1 projection. Only the
        # projection is built without spectral normalization.
        self.final2 = self._sn_conv_block(
            4 * width, 2 * width, kernel_size=3, stride=1, padding=1
        )
        self.output = torch.nn.Sequential(
            torch.nn.Conv2d(
                2 * width,
                cfg.NETWORK.GANCRAFT.N_CLASSES + 1,
                stride=1,
                kernel_size=1,
                bias=True,
            ),
            torch.nn.LeakyReLU(0.2),
        )
        self.interpolator = self._smooth_interp

    @staticmethod
    def _sn_conv_block(in_channels, out_channels, kernel_size, stride, padding=0):
        """Build a spectral-normalized ``Conv2d`` followed by ``LeakyReLU(0.2)``."""
        conv = torch.nn.Conv2d(
            in_channels,
            out_channels,
            stride=stride,
            kernel_size=kernel_size,
            padding=padding,
            bias=True,
        )
        return torch.nn.Sequential(
            torch.nn.utils.spectral_norm(conv),
            torch.nn.LeakyReLU(0.2),
        )

    @staticmethod
    def _smooth_interp(x, size):
        r"""Resize segmentation maps and re-binarize them to one-hot.

        Args:
            x (4D tensor): Segmentation maps.
            size (2D list): Target size (H, W).

        Returns:
            4D tensor at the target size holding 1.0 at the arg-max channel
            of each pixel and 0.0 elsewhere.
        """
        resized = F.interpolate(x, size=size, mode="area")
        winners = torch.argmax(resized, dim=-3, keepdim=True)
        # Overwrite the area-averaged scores in place with a hard one-hot map.
        resized.fill_(0.0)
        resized.scatter_(1, winners, 1.0)
        return resized

    def _single_forward(self, images, seg_maps):
        """Run the pyramid on ``images`` and resize ``seg_maps`` to match.

        Returns:
            dict with ``"pred"`` (logits, ``N_CLASSES + 1`` channels) and
            ``"label"`` (one-hot maps at the spatial size of ``"pred"``).
        """
        # bottom-up pathway
        c1 = self.enc1(images)
        c2 = self.enc2(c1)
        c3 = self.enc3(c2)
        c4 = self.enc4(c3)
        c5 = self.enc5(c4)
        # top-down pathway and lateral connections
        # NOTE(review): the additions need each upsampled map to match its
        # lateral map; an input size that is a multiple of 32 is sufficient.
        p5 = self.lat5(c5)
        p4 = self.upsample2x(p5) + self.lat4(c4)
        p3 = self.upsample2x(p4) + self.lat3(c3)
        p2 = self.upsample2x(p3) + self.lat2(c2)
        # final prediction layers
        head = self.final2(p2)

        label_map = self.interpolator(seg_maps, size=head.size()[2:])
        pred = self.output(head)  # N, num_labels + 1, H//4, W//4
        return {"pred": pred, "label": label_map}

    def forward(self, images, seg_maps, masks):
        """Mask both inputs, then score the images.

        Args:
            images: Image batch, ``(N, 3, H, W)``.
            seg_maps: Segmentation maps, e.g. ``(N, 7, H, W)``.
            masks: Multiplicative mask, e.g. ``(N, 1, H, W)``.

        Returns:
            The dict produced by ``_single_forward``.
        """
        masked_seg = seg_maps * masks
        masked_images = images * masks
        return self._single_forward(masked_images, masked_seg)
|
requirements.txt
CHANGED
|
@@ -2,6 +2,9 @@
|
|
| 2 |
torch==1.12.0
|
| 3 |
torchvision
|
| 4 |
|
|
|
|
|
|
|
| 5 |
numpy
|
| 6 |
opencv-python
|
| 7 |
-
|
|
|
|
|
|
| 2 |
torch==1.12.0
|
| 3 |
torchvision
|
| 4 |
|
| 5 |
+
easydict
|
| 6 |
+
gradio
|
| 7 |
numpy
|
| 8 |
opencv-python
|
| 9 |
+
pillow
|
| 10 |
+
|