Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -30,25 +30,22 @@ def load_openshape(name, to_cpu=False):
|
|
| 30 |
pce = pce.cpu()
|
| 31 |
return pce
|
| 32 |
|
| 33 |
-
def retrieval_filter_expand(
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
anim_n
|
| 45 |
-
face_n
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
and (tag_n or tag in x['tags'])
|
| 50 |
-
)
|
| 51 |
-
return sim_th, filter_fn
|
| 52 |
|
| 53 |
def retrieval_results(results):
|
| 54 |
st.caption("Click the link to view the 3D shape")
|
|
@@ -148,32 +145,125 @@ def demo_retrieval():
|
|
| 148 |
|
| 149 |
prog.progress(1.0, "Idle")
|
| 150 |
|
| 151 |
-
st.title("TripletMix Demo")
|
| 152 |
-
st.caption("For faster inference without waiting in queue, you may clone the space and run it yourself.")
|
| 153 |
-
prog = st.progress(0.0, "Idle")
|
| 154 |
-
tab_cls, tab_pc, tab_img, tab_text, tab_sd, tab_cap = st.tabs([
|
| 155 |
-
"Classification",
|
| 156 |
-
"Retrieval w/ 3D",
|
| 157 |
-
"Retrieval w/ Image",
|
| 158 |
-
"Retrieval w/ Text",
|
| 159 |
-
"Image Generation",
|
| 160 |
-
"Captioning",
|
| 161 |
-
])
|
| 162 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 163 |
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
|
| 169 |
try:
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 177 |
except Exception:
|
| 178 |
import traceback
|
| 179 |
st.error(traceback.format_exc().replace("\n", " \n"))
|
|
|
|
| 30 |
pce = pce.cpu()
|
| 31 |
return pce
|
| 32 |
|
| 33 |
+
def retrieval_filter_expand():
|
| 34 |
+
sim_th = st.sidebar.slider("Similarity Threshold", 0.05, 0.5, 0.1, key='rsimth')
|
| 35 |
+
tag = ""
|
| 36 |
+
face_min = 0
|
| 37 |
+
face_max = 34985808
|
| 38 |
+
anim_min = 0
|
| 39 |
+
anim_max = 563
|
| 40 |
+
tag_n = not bool(tag.strip())
|
| 41 |
+
anim_n = not (anim_min > 0 or anim_max < 563)
|
| 42 |
+
face_n = not (face_min > 0 or face_max < 34985808)
|
| 43 |
+
filter_fn = lambda x: (
|
| 44 |
+
(anim_n or anim_min <= x['anims'] <= anim_max)
|
| 45 |
+
and (face_n or face_min <= x['faces'] <= face_max)
|
| 46 |
+
and (tag_n or tag in x['tags'])
|
| 47 |
+
)
|
| 48 |
+
return sim_th, filter_fn
|
|
|
|
|
|
|
|
|
|
| 49 |
|
| 50 |
def retrieval_results(results):
|
| 51 |
st.caption("Click the link to view the 3D shape")
|
|
|
|
| 145 |
|
| 146 |
prog.progress(1.0, "Idle")
|
| 147 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 148 |
|
| 149 |
+
def retrieval_pc(load_data, k, sim_th, filter_fn):
|
| 150 |
+
pc = load_data(prog)
|
| 151 |
+
prog.progress(0.49, "Computing Embeddings")
|
| 152 |
+
col2 = utils.render_pc(pc)
|
| 153 |
+
ref_dev = next(model_g14.parameters()).device
|
| 154 |
+
enc = model_g14(torch.tensor(pc[:, [0, 2, 1, 3, 4, 5]].T[None], device=ref_dev)).cpu()
|
| 155 |
+
|
| 156 |
+
sim = torch.matmul(torch.nn.functional.normalize(lvis.feats, dim=-1), torch.nn.functional.normalize(enc, dim=-1).squeeze())
|
| 157 |
+
argsort = torch.argsort(sim, descending=True)
|
| 158 |
+
pred = OrderedDict((lvis.categories[i], sim[i]) for i in argsort if i < len(lvis.categories))
|
| 159 |
+
with col2:
|
| 160 |
+
for i, (cat, sim) in zip(range(5), pred.items()):
|
| 161 |
+
st.text(cat)
|
| 162 |
+
st.caption("Similarity %.4f" % sim)
|
| 163 |
+
|
| 164 |
+
prog.progress(0.7, "Running Retrieval")
|
| 165 |
+
retrieval_results(retrieval.retrieve(enc, k, sim_th, filter_fn))
|
| 166 |
+
|
| 167 |
+
prog.progress(1.0, "Idle")
|
| 168 |
+
|
| 169 |
+
def retrieval_img(pic, k, sim_th, filter_fn):
|
| 170 |
+
img = Image.open(pic)
|
| 171 |
+
prog.progress(0.49, "Computing Embeddings")
|
| 172 |
+
st.image(img)
|
| 173 |
+
device = clip_model.device
|
| 174 |
+
tn = clip_prep(images=[img], return_tensors="pt").to(device)
|
| 175 |
+
enc = clip_model.get_image_features(pixel_values=tn['pixel_values'].type(half)).float().cpu()
|
| 176 |
+
|
| 177 |
+
prog.progress(0.7, "Running Retrieval")
|
| 178 |
+
retrieval_results(retrieval.retrieve(enc, k, sim_th, filter_fn))
|
| 179 |
+
|
| 180 |
+
prog.progress(1.0, "Idle")
|
| 181 |
+
|
| 182 |
+
def retrieval_text(text, k, sim_th, filter_fn):
|
| 183 |
+
prog.progress(0.49, "Computing Embeddings")
|
| 184 |
+
device = clip_model.device
|
| 185 |
+
tn = clip_prep(text=[text], return_tensors='pt', truncation=True, max_length=76).to(device)
|
| 186 |
+
enc = clip_model.get_text_features(**tn).float().cpu()
|
| 187 |
|
| 188 |
+
prog.progress(0.7, "Running Retrieval")
|
| 189 |
+
retrieval_results(retrieval.retrieve(enc, k, sim_th, filter_fn))
|
| 190 |
+
|
| 191 |
+
prog.progress(1.0, "Idle")
|
| 192 |
|
| 193 |
try:
|
| 194 |
+
f32 = numpy.float32
|
| 195 |
+
half = torch.float16 if torch.cuda.is_available() else torch.bfloat16
|
| 196 |
+
clip_model, clip_prep = load_openclip()
|
| 197 |
+
model_g14 = load_openshape('openshape-pointbert-vitg14-rgb')
|
| 198 |
+
|
| 199 |
+
st.caption("This demo presents three tasks: 3D classification, cross-modal retrieval, and cross-modal generation. Examples are provided for demonstration purposes. You're encouraged to fine-tune task parameters and upload files for customized testing as required.")
|
| 200 |
+
st.sidebar.title("TripletMix Demo Configuration Panel")
|
| 201 |
+
task = st.sidebar.selectbox(
|
| 202 |
+
'Task Selection',
|
| 203 |
+
("3D Classification", "Cross-modal retrieval", "Cross-modal generation")
|
| 204 |
+
)
|
| 205 |
+
|
| 206 |
+
if task == "3D Classification":
|
| 207 |
+
cls_mode = st.sidebar.selectbox(
|
| 208 |
+
'Choose the source of categories',
|
| 209 |
+
("LVIS Categories", "Custom Categories")
|
| 210 |
+
)
|
| 211 |
+
pc = st.sidebar.text_input("Input pc", key='rtextinput')
|
| 212 |
+
if cls_mode == "LVIS Categories":
|
| 213 |
+
if st.sidebar.button("submit"):
|
| 214 |
+
st.title("Classification with LVIS Categories")
|
| 215 |
+
prog = st.progress(0.0, "Idle")
|
| 216 |
+
|
| 217 |
+
elif cls_mode == "Custom Categories":
|
| 218 |
+
cats = st.sidebar.text_input("Custom Categories (64 max, separated with comma)")
|
| 219 |
+
cats = [a.strip() for a in cats.split(',')]
|
| 220 |
+
if len(cats) > 64:
|
| 221 |
+
st.error('Maximum 64 custom categories supported in the demo')
|
| 222 |
+
if st.sidebar.button("submit"):
|
| 223 |
+
st.title("Classification with Custom Categories")
|
| 224 |
+
prog = st.progress(0.0, "Idle")
|
| 225 |
+
|
| 226 |
+
elif task == "Cross-modal retrieval":
|
| 227 |
+
input_mode = st.sidebar.selectbox(
|
| 228 |
+
'Choose an input modality',
|
| 229 |
+
("Point Cloud", "Image", "Text")
|
| 230 |
+
)
|
| 231 |
+
k = st.sidebar.slider("Number of items to retrieve", 1, 100, 16, key='rnum')
|
| 232 |
+
sim_th, filter_fn = retrieval_filter_expand()
|
| 233 |
+
if input_mode == "Point Cloud":
|
| 234 |
+
load_data = utils.input_3d_shape('rpcinput')
|
| 235 |
+
if st.sidebar.button("submit"):
|
| 236 |
+
st.title("Retrieval with Point Cloud")
|
| 237 |
+
prog = st.progress(0.0, "Idle")
|
| 238 |
+
retrieval_pc(load_data, k, sim_th, filter_fn)
|
| 239 |
+
elif input_mode == "Image":
|
| 240 |
+
pic = st.sidebar.file_uploader("Upload an Image", key='rimageinput')
|
| 241 |
+
if st.sidebar.button("submit"):
|
| 242 |
+
st.title("Retrieval with Image")
|
| 243 |
+
prog = st.progress(0.0, "Idle")
|
| 244 |
+
retrieval_img(pic, k, sim_th, filter_fn)
|
| 245 |
+
elif input_mode == "Text":
|
| 246 |
+
text = st.sidebar.text_input("Input Text", key='rtextinput')
|
| 247 |
+
if st.sidebar.button("submit"):
|
| 248 |
+
st.title("Retrieval with Text")
|
| 249 |
+
prog = st.progress(0.0, "Idle")
|
| 250 |
+
retrieval_text(text, k, sim_th, filter_fn)
|
| 251 |
+
elif task == "Cross-modal generation":
|
| 252 |
+
generation_mode = st.sidebar.selectbox(
|
| 253 |
+
'Choose the mode of generation',
|
| 254 |
+
("PointCloud-to-Image", "PointCloud-to-Text")
|
| 255 |
+
)
|
| 256 |
+
pc = st.sidebar.text_input("Input pc", key='rtextinput')
|
| 257 |
+
if generation_mode == "PointCloud-to-Image":
|
| 258 |
+
if st.sidebar.button("submit"):
|
| 259 |
+
st.title("Image Generation")
|
| 260 |
+
prog = st.progress(0.0, "Idle")
|
| 261 |
+
|
| 262 |
+
elif generation_mode == "PointCloud-to-Text":
|
| 263 |
+
if st.sidebar.button("submit"):
|
| 264 |
+
st.title("Text Generation")
|
| 265 |
+
prog = st.progress(0.0, "Idle")
|
| 266 |
+
|
| 267 |
except Exception:
|
| 268 |
import traceback
|
| 269 |
st.error(traceback.format_exc().replace("\n", " \n"))
|