Insta360-Research commited on
Commit
f5f5ec2
·
verified ·
1 Parent(s): b355cef

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -34
app.py CHANGED
@@ -8,21 +8,19 @@ import numpy as np
8
  import gradio as gr
9
  from huggingface_hub import hf_hub_download
10
 
11
- # ================== 必须最早 import spaces ==================
12
  try:
13
- import spaces # type: ignore
14
  gpu_decorator = spaces.GPU
15
  except Exception:
16
  gpu_decorator = lambda f: f
17
 
18
- # ================== 工程路径:确保能 import networks ==================
19
- # 适配:无论你从哪里启动 python app.py,都能找到项目根目录
20
  PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
21
  sys.path.append(PROJECT_ROOT)
22
 
23
  from networks.models import make # noqa: E402
24
 
25
- # ================== HF 模型配置 ==================
26
  WEIGHTS_REPO = "Insta360-Research/DAP-weights"
27
  WEIGHTS_FILE = "model.pth"
28
  CONFIG_PATH = os.path.join(PROJECT_ROOT, "config", "infer.yaml")
@@ -30,7 +28,7 @@ CONFIG_PATH = os.path.join(PROJECT_ROOT, "config", "infer.yaml")
30
  model = None
31
  device = "cpu"
32
 
33
- # ================== 固定颜色映射(颜色一致) ==================
34
  import matplotlib
35
 
36
  def colorize_depth_fixed(depth_u8: np.ndarray, cmap: str = "Spectral") -> np.ndarray:
@@ -43,7 +41,7 @@ def colorize_depth_fixed(depth_u8: np.ndarray, cmap: str = "Spectral") -> np.nda
43
  colored = (colored * 255).astype(np.uint8)
44
  return np.ascontiguousarray(colored)
45
 
46
- # ================== 模型加载 ==================
47
  def load_model(config_path: str):
48
  import torch
49
  import torch.nn as nn
@@ -77,17 +75,17 @@ def load_model(config_path: str):
77
  print("✅ Model loaded.")
78
  return m
79
 
80
- # ================== 启动时加载一次模型 ==================
81
  model = load_model(CONFIG_PATH)
82
 
83
- # ================== 加载标度尺图片 ==================
84
  COLORBAR_DIR = os.path.join(PROJECT_ROOT, "colorbars")
85
  colorbar_100m_color = cv2.imread(os.path.join(COLORBAR_DIR, "colorbar_100m_color.png"))
86
  colorbar_100m_gray = cv2.imread(os.path.join(COLORBAR_DIR, "colorbar_100m_gray.png"))
87
  colorbar_10m_color = cv2.imread(os.path.join(COLORBAR_DIR, "colorbar_10m_color.png"))
88
  colorbar_10m_gray = cv2.imread(os.path.join(COLORBAR_DIR, "colorbar_10m_gray.png"))
89
 
90
- # 转换为RGB(Gradio需要RGB格式)
91
  if colorbar_100m_color is not None:
92
  colorbar_100m_color = cv2.cvtColor(colorbar_100m_color, cv2.COLOR_BGR2RGB)
93
  if colorbar_100m_gray is not None:
@@ -97,7 +95,7 @@ if colorbar_10m_color is not None:
97
  if colorbar_10m_gray is not None:
98
  colorbar_10m_gray = cv2.cvtColor(colorbar_10m_gray, cv2.COLOR_BGR2RGB)
99
 
100
- # ================== 推理函数 ==================
101
  @gpu_decorator
102
  def infer_raw(img_rgb: np.ndarray):
103
  if img_rgb is None:
@@ -105,7 +103,6 @@ def infer_raw(img_rgb: np.ndarray):
105
 
106
  import torch
107
 
108
- # 保持你原逻辑:不 resize,直接喂入
109
  img = img_rgb.astype(np.float32) / 255.0
110
  tensor = torch.from_numpy(img.transpose(2, 0, 1)).unsqueeze(0).to(device)
111
 
@@ -119,7 +116,6 @@ def infer_raw(img_rgb: np.ndarray):
119
  outputs["pred_depth"][~mask] = 1
120
  pred = outputs["pred_depth"][0].cpu().squeeze().numpy()
121
  else:
122
- # 保持你原逻辑的 fallback
123
  pred = outputs[0].cpu().squeeze().numpy()
124
 
125
  return pred.astype(np.float32)
@@ -152,11 +148,10 @@ def visualize_10m(pred: np.ndarray):
152
 
153
  @gpu_decorator
154
  def infer_and_vis_100m(img_rgb: np.ndarray):
155
- pred = infer_raw(img_rgb) # 跑模型一次(GPU)
156
- color, gray, npy, cbar_color, cbar_gray = visualize_100m(pred) # 默认100m显示(CPU)
157
  return pred, color, gray, npy, cbar_color, cbar_gray
158
 
159
- # ================== Gradio UI ==================
160
  example_paths = [
161
  "hfdemo/01.jpg",
162
  "hfdemo/02.jpg",
@@ -196,12 +191,10 @@ with gr.Blocks() as demo:
196
  )
197
  gr.Markdown("# Official Depth Prediction demo for **[DAP](https://insta360-research-team.github.io/DAP_website/)**")
198
 
199
- raw_depth = gr.State() # 🔑 保存模型输出
200
 
201
  with gr.Row():
202
 
203
- # ========== Left ==========
204
- # 左侧列(Input Image)
205
  with gr.Column(scale=10):
206
  inp = gr.Image(
207
  type="numpy",
@@ -231,29 +224,23 @@ with gr.Blocks() as demo:
231
  elem_id="vis_hint",
232
  )
233
 
234
- # ========== Right ==========
235
- # 右侧整体(包含 中间列 + colorbar 列)
236
  with gr.Column(scale=11):
237
 
238
  # -------- Row 1: Color Depth --------
239
  with gr.Row():
240
- # 中间列(必须和左侧等宽)
241
  with gr.Column(scale=10):
242
  out_color = gr.Image(
243
  label="Depth (Color)",
244
  height=260
245
  )
246
 
247
- # colorbar 列(很窄)
248
  with gr.Column(scale=1, min_width=80):
249
  colorbar_color = gr.Image(
250
  label="Scale",
251
  height=260,
252
- show_label=False,
253
- show_download_button=False
254
  )
255
 
256
- # -------- Row 2: Gray Depth --------
257
  with gr.Row():
258
  with gr.Column(scale=10):
259
  out_gray = gr.Image(
@@ -265,28 +252,24 @@ with gr.Blocks() as demo:
265
  colorbar_gray = gr.Image(
266
  label="Scale",
267
  height=260,
268
- show_label=False,
269
- show_download_button=False
270
  )
271
 
272
  out_npy = gr.File(label="Depth (.npy)")
273
 
274
 
275
- # 1️⃣ 跑模型
276
  btn_infer.click(
277
  fn=infer_and_vis_100m,
278
  inputs=inp,
279
  outputs=[raw_depth, out_color, out_gray, out_npy, colorbar_color, colorbar_gray],
280
  )
281
 
282
- # 2️⃣ 100m
283
  btn_100m.click(
284
  fn=visualize_100m,
285
  inputs=raw_depth,
286
  outputs=[out_color, out_gray, out_npy, colorbar_color, colorbar_gray],
287
  )
288
 
289
- # 3️⃣ 10m
290
  btn_10m.click(
291
  fn=visualize_10m,
292
  inputs=raw_depth,
@@ -298,9 +281,9 @@ if __name__ == "__main__":
298
  host = os.environ.get("HOST", "0.0.0.0")
299
  port = int(os.environ.get("PORT", "7860"))
300
 
301
- demo.queue( # ✅ 开启排队
302
- max_size=32, # 队列最大长度(可选)
303
- default_concurrency_limit=1, # 每个函数默认并发(可选,GPU任务通常设 1)
304
  ).launch(
305
  server_name=host,
306
  server_port=port,
 
8
  import gradio as gr
9
  from huggingface_hub import hf_hub_download
10
 
 
11
  try:
12
+ import spaces
13
  gpu_decorator = spaces.GPU
14
  except Exception:
15
  gpu_decorator = lambda f: f
16
 
17
+
 
18
  PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
19
  sys.path.append(PROJECT_ROOT)
20
 
21
  from networks.models import make # noqa: E402
22
 
23
+
24
  WEIGHTS_REPO = "Insta360-Research/DAP-weights"
25
  WEIGHTS_FILE = "model.pth"
26
  CONFIG_PATH = os.path.join(PROJECT_ROOT, "config", "infer.yaml")
 
28
  model = None
29
  device = "cpu"
30
 
31
+
32
  import matplotlib
33
 
34
  def colorize_depth_fixed(depth_u8: np.ndarray, cmap: str = "Spectral") -> np.ndarray:
 
41
  colored = (colored * 255).astype(np.uint8)
42
  return np.ascontiguousarray(colored)
43
 
44
+
45
  def load_model(config_path: str):
46
  import torch
47
  import torch.nn as nn
 
75
  print("✅ Model loaded.")
76
  return m
77
 
78
+
79
  model = load_model(CONFIG_PATH)
80
 
81
+
82
  COLORBAR_DIR = os.path.join(PROJECT_ROOT, "colorbars")
83
  colorbar_100m_color = cv2.imread(os.path.join(COLORBAR_DIR, "colorbar_100m_color.png"))
84
  colorbar_100m_gray = cv2.imread(os.path.join(COLORBAR_DIR, "colorbar_100m_gray.png"))
85
  colorbar_10m_color = cv2.imread(os.path.join(COLORBAR_DIR, "colorbar_10m_color.png"))
86
  colorbar_10m_gray = cv2.imread(os.path.join(COLORBAR_DIR, "colorbar_10m_gray.png"))
87
 
88
+
89
  if colorbar_100m_color is not None:
90
  colorbar_100m_color = cv2.cvtColor(colorbar_100m_color, cv2.COLOR_BGR2RGB)
91
  if colorbar_100m_gray is not None:
 
95
  if colorbar_10m_gray is not None:
96
  colorbar_10m_gray = cv2.cvtColor(colorbar_10m_gray, cv2.COLOR_BGR2RGB)
97
 
98
+
99
  @gpu_decorator
100
  def infer_raw(img_rgb: np.ndarray):
101
  if img_rgb is None:
 
103
 
104
  import torch
105
 
 
106
  img = img_rgb.astype(np.float32) / 255.0
107
  tensor = torch.from_numpy(img.transpose(2, 0, 1)).unsqueeze(0).to(device)
108
 
 
116
  outputs["pred_depth"][~mask] = 1
117
  pred = outputs["pred_depth"][0].cpu().squeeze().numpy()
118
  else:
 
119
  pred = outputs[0].cpu().squeeze().numpy()
120
 
121
  return pred.astype(np.float32)
 
148
 
149
  @gpu_decorator
150
  def infer_and_vis_100m(img_rgb: np.ndarray):
151
+ pred = infer_raw(img_rgb)
152
+ color, gray, npy, cbar_color, cbar_gray = visualize_100m(pred)
153
  return pred, color, gray, npy, cbar_color, cbar_gray
154
 
 
155
  example_paths = [
156
  "hfdemo/01.jpg",
157
  "hfdemo/02.jpg",
 
191
  )
192
  gr.Markdown("# Official Depth Prediction demo for **[DAP](https://insta360-research-team.github.io/DAP_website/)**")
193
 
194
+ raw_depth = gr.State()
195
 
196
  with gr.Row():
197
 
 
 
198
  with gr.Column(scale=10):
199
  inp = gr.Image(
200
  type="numpy",
 
224
  elem_id="vis_hint",
225
  )
226
 
 
 
227
  with gr.Column(scale=11):
228
 
229
  # -------- Row 1: Color Depth --------
230
  with gr.Row():
 
231
  with gr.Column(scale=10):
232
  out_color = gr.Image(
233
  label="Depth (Color)",
234
  height=260
235
  )
236
 
 
237
  with gr.Column(scale=1, min_width=80):
238
  colorbar_color = gr.Image(
239
  label="Scale",
240
  height=260,
241
+ show_label=False
 
242
  )
243
 
 
244
  with gr.Row():
245
  with gr.Column(scale=10):
246
  out_gray = gr.Image(
 
252
  colorbar_gray = gr.Image(
253
  label="Scale",
254
  height=260,
255
+ show_label=False
 
256
  )
257
 
258
  out_npy = gr.File(label="Depth (.npy)")
259
 
260
 
 
261
  btn_infer.click(
262
  fn=infer_and_vis_100m,
263
  inputs=inp,
264
  outputs=[raw_depth, out_color, out_gray, out_npy, colorbar_color, colorbar_gray],
265
  )
266
 
 
267
  btn_100m.click(
268
  fn=visualize_100m,
269
  inputs=raw_depth,
270
  outputs=[out_color, out_gray, out_npy, colorbar_color, colorbar_gray],
271
  )
272
 
 
273
  btn_10m.click(
274
  fn=visualize_10m,
275
  inputs=raw_depth,
 
281
  host = os.environ.get("HOST", "0.0.0.0")
282
  port = int(os.environ.get("PORT", "7860"))
283
 
284
+ demo.queue(
285
+ max_size=32,
286
+ default_concurrency_limit=1,
287
  ).launch(
288
  server_name=host,
289
  server_port=port,