Update app.py
app.py CHANGED
@@ -33,10 +33,11 @@ def load_openshape(name, to_cpu=False):
 
 
 def load_tripletmix(name, to_cpu=False):
-    pce = openshape.load_pc_encoder_mix(name)
+    pce, pca = openshape.load_pc_encoder_mix(name)
     if to_cpu:
         pce = pce.cpu()
-
+        pca = pca.cpu()
+    return pce, pca
 
 
 
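Note: the hunk above changes the loader contract so that openshape.load_pc_encoder_mix(name) is unpacked into a (point-cloud encoder, adapter) pair and both modules are returned. A minimal sketch of how a caller consumes that pair, mirroring the later hunks; DummyEncoder, DummyAdapter, and the tensor shapes are illustrative stand-ins, not code from this repo.

# Sketch of the new two-module contract; the classes below are placeholders only.
import torch

class DummyEncoder(torch.nn.Module):
    def forward(self, pc):                        # pc: (1, 6, N) xyz+rgb point cloud
        return torch.zeros(pc.shape[0], 1280)     # placeholder embedding

class DummyAdapter(torch.nn.Module):
    def forward(self, feat):                      # stands in for the TripletMix adapter head
        return feat

def load_tripletmix_sketch(name, to_cpu=False):
    pce, pca = DummyEncoder(), DummyAdapter()     # real code: openshape.load_pc_encoder_mix(name)
    if to_cpu:
        pce, pca = pce.cpu(), pca.cpu()
    return pce, pca

model_g14, pc_adapter = load_tripletmix_sketch('tripletmix-pointbert-shapenet')
enc = pc_adapter(model_g14(torch.zeros(1, 6, 1024)))   # adapter is applied after the encoder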
@@ -81,9 +82,10 @@ def classification_lvis(load_data):
     col2 = utils.render_pc(pc)
     prog.progress(0.5, "Running Classification")
     ref_dev = next(model_g14.parameters()).device
-    enc = model_g14(torch.tensor(pc[:, [0, 2, 1, 3, 4, 5]].T[None], device=ref_dev))
-
-
+    enc = model_g14(torch.tensor(pc[:, [0, 2, 1, 3, 4, 5]].T[None], device=ref_dev))
+    if model_name == "pb-sn-M":
+        enc = pc_adapter(enc)
+    sim = torch.matmul(torch.nn.functional.normalize(lvis.feats, dim=-1), torch.nn.functional.normalize(enc.cpu(), dim=-1).squeeze())
     argsort = torch.argsort(sim, descending=True)
     pred = OrderedDict((lvis.categories[i], sim[i]) for i in argsort if i < len(lvis.categories))
     with col2:
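The classification step in this hunk is a cosine-similarity ranking: the category feature bank and the shape embedding are both L2-normalized, so the matmul yields cosine scores and argsort(descending=True) orders categories from best to worst match. A self-contained sketch with random tensors; the sizes are assumptions for illustration only.

# Cosine-similarity ranking as in the hunk above; data and shapes are made up.
import torch
import torch.nn.functional as F

num_categories, dim = 1156, 1280           # assumed: Objaverse-LVIS categories x embedding dim
feats = torch.randn(num_categories, dim)   # per-category reference features (lvis.feats / feats)
enc = torch.randn(1, dim)                  # point-cloud embedding (after pc_adapter, if used)

sim = torch.matmul(F.normalize(feats, dim=-1),
                   F.normalize(enc, dim=-1).squeeze())   # (num_categories,) cosine similarities
argsort = torch.argsort(sim, descending=True)            # highest-scoring categories first
top5 = [(int(i), float(sim[i])) for i in argsort[:5]]
print(top5)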
@@ -103,8 +105,10 @@ def classification_custom(load_data, cats):
 
     prog.progress(0.5, "Running Classification")
     ref_dev = next(model_g14.parameters()).device
-    enc = model_g14(torch.tensor(pc[:, [0, 2, 1, 3, 4, 5]].T[None], device=ref_dev))
-
+    enc = model_g14(torch.tensor(pc[:, [0, 2, 1, 3, 4, 5]].T[None], device=ref_dev))
+    if model_name == "pb-sn-M":
+        enc = pc_adapter(enc)
+    sim = torch.matmul(torch.nn.functional.normalize(feats, dim=-1), torch.nn.functional.normalize(enc.cpu(), dim=-1).squeeze())
     argsort = torch.argsort(sim, descending=True)
     pred = OrderedDict((cats[i], sim[i]) for i in argsort if i < len(cats))
     with col2:
@@ -197,11 +201,19 @@ try:
     f32 = numpy.float32
     half = torch.float16 if torch.cuda.is_available() else torch.bfloat16
     clip_model, clip_prep = load_openclip()
-    model_g14 = load_openshape('openshape-pointbert-vitg14-rgb')
+    #model_g14 = load_openshape('openshape-pointbert-vitg14-rgb')
     #model_g14 = load_tripletmix('tripletmix-spconv-all')
 
     st.caption("This demo presents three tasks: 3D classification, cross-modal retrieval, and cross-modal generation. Examples are provided for demonstration purposes. You're encouraged to fine-tune task parameters and upload files for customized testing as required.")
     st.sidebar.title("TripletMix Demo Configuration Panel")
+    model_name = st.sidebar.selectbox(
+        'Model Selection',
+        ("pb-sn-M", "pb-sn")
+    )
+    if model_name == "pb-sn-M":
+        model_g14, pc_adapter = load_tripletmix('tripletmix-pointbert-shapenet')
+    elif model_name == "pb-sn":
+        model_g14 = load_openshape('openshape-pointbert-vitg14-rgb')
     task = st.sidebar.selectbox(
         'Task Selection',
         ("3D Classification", "Cross-modal retrieval", "Cross-modal generation")
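The final hunk adds a sidebar model selector and loads either the TripletMix (encoder, adapter) pair or the OpenShape encoder alone; the same model_name is checked again in the classification hunks before applying pc_adapter. A self-contained sketch of that gating pattern; the loader stubs are placeholders for the real load_tripletmix / load_openshape calls.

# Sketch of the sidebar gating added above; run with `streamlit run`.
import streamlit as st
import torch

def load_tripletmix_stub(name):
    return torch.nn.Identity(), torch.nn.Identity()    # (encoder, adapter) stand-ins

def load_openshape_stub(name):
    return torch.nn.Identity()                          # encoder-only stand-in

model_name = st.sidebar.selectbox('Model Selection', ("pb-sn-M", "pb-sn"))
if model_name == "pb-sn-M":
    model_g14, pc_adapter = load_tripletmix_stub('tripletmix-pointbert-shapenet')
else:                                                   # "pb-sn"
    model_g14 = load_openshape_stub('openshape-pointbert-vitg14-rgb')
    pc_adapter = None    # classification only applies the adapter for "pb-sn-M"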