qingshan777 committed on
Commit 2ee4fc8 · verified · 1 Parent(s): 176a32d

Create app.py

Files changed (1)
app.py +158 -0
app.py ADDED
@@ -0,0 +1,158 @@
+ import ast
+ import io
+
+ import gradio as gr
+ import matplotlib
+ matplotlib.use('Agg')  # headless backend; figures are rendered to an in-memory buffer
+ import matplotlib.pyplot as plt
+ import torch
+ from PIL import Image
+ from transformers import (
+     AutoImageProcessor,
+     AutoTokenizer,
+     AutoModelForCausalLM,
+ )
+
+ model_root = "qihoo360/fg-clip2-base"
+ model = AutoModelForCausalLM.from_pretrained(model_root, trust_remote_code=True)
+
+ device = model.device
+
+ tokenizer = AutoTokenizer.from_pretrained(model_root)
+ image_processor = AutoImageProcessor.from_pretrained(model_root)
+
+
+ def resize_short_edge(image, target_size=2048):
+     """Upscale the image so its short edge is at least target_size pixels; never downscale."""
+     if isinstance(image, str):
+         image = Image.open(image)
+
+     width, height = image.size
+     short_edge = min(width, height)
+
+     if short_edge >= target_size:
+         return image
+
+     scale = target_size / short_edge
+     new_width = int(width * scale)
+     new_height = int(height * scale)
+
+     resized_image = image.resize((new_width, new_height))
+
+     return resized_image
+
+ def Get_Densefeature(image, candidate_labels):
+     """
+     Takes an image and a string containing a Python list of candidate labels
+     (e.g. "['black cat','window']"), and returns a grid of per-patch
+     similarity heatmaps, one per label, as a single PIL image.
+     """
+     candidate_labels = ast.literal_eval(candidate_labels)
+     assert len(candidate_labels) != 0
+     print(candidate_labels)
+
+     image = image.convert("RGB")
+     image = resize_short_edge(image, target_size=2048)
+
+     image_input = image_processor(images=image, max_num_patches=16384, return_tensors="pt").to(device)
+     # e.g. ["电脑","黑猫","窗户","window","white cat","book"]
+     captions = candidate_labels
+
+     with torch.no_grad():
+         dense_image_feature = model.get_image_dense_feature(**image_input)
+
+     # Keep only the tokens that correspond to real image patches.
+     spatial_values = image_input["spatial_shapes"][0]
+     real_h = spatial_values[0].item()
+     real_w = spatial_values[1].item()
+     real_pixel_tokens_num = real_w * real_h
+     dense_image_feature = dense_image_feature[0][:real_pixel_tokens_num]
+
+     captions = [caption.lower() for caption in captions]
+     caption_input = tokenizer(captions, padding="max_length", max_length=64, truncation=True, return_tensors="pt").to(device)
+
+     text_feature = model.get_text_features(**caption_input, walk_type="box")
+     text_feature = text_feature / text_feature.norm(p=2, dim=-1, keepdim=True)
+     dense_image_feature = dense_image_feature / dense_image_feature.norm(p=2, dim=-1, keepdim=True)
+
+     # Cosine similarity between every patch token and every label: (real_h*real_w, num_labels).
+     similarity = dense_image_feature @ text_feature.T
+     similarity = similarity.cpu()
+
+     num_classes = len(captions)
+     cols = 3
+     rows = (num_classes + cols - 1) // cols
+
+     aspect_ratio = real_w / real_h
+     fig_width_inch = 3 * cols
+     fig_height_inch = fig_width_inch / aspect_ratio * rows / cols
+
+     fig, axes = plt.subplots(rows, cols, figsize=(fig_width_inch, fig_height_inch))
+     fig.subplots_adjust(wspace=0.01, hspace=0.01)
+     axes = axes.flatten()  # cols is fixed at 3, so subplots always returns an array
+
+     for cls_index in range(num_classes):
+         similarity_map = similarity[:, cls_index].numpy()
+         show_image = similarity_map.reshape((real_h, real_w))
+
+         ax = axes[cls_index]
+         ax.imshow(show_image, cmap='viridis', aspect='equal')  # keep the original aspect ratio
+         ax.set_xticks([])
+         ax.set_yticks([])
+         ax.axis('off')
+
+     # Hide any unused grid cells.
+     for idx in range(num_classes, len(axes)):
+         axes[idx].axis('off')
+
+     buf = io.BytesIO()
+     plt.savefig(buf, format='png')
+     buf.seek(0)
+     plt.close(fig)
+
+     pil_img = Image.open(buf)
+     # buf.close()
+     return pil_img
+
+
+ with gr.Blocks() as demo:
+     gr.Markdown("# FG-CLIP 2 Densefeature")
+     gr.Markdown(
+         "This app runs the FG-CLIP 2 model (qihoo360/fg-clip2-base) on CPU and visualizes its dense features:"
+     )
+
+     with gr.Row():
+         with gr.Column():
+             image_input = gr.Image(type="pil")
+             text_input = gr.Textbox(label="Input a list of labels, example: ['a','b','c']")
+             dfs_button = gr.Button("Run Densefeature", visible=True)
+         with gr.Column():
+             dfs_output = gr.Image(label="Similarity Visualization", type="pil")
+
+     # Mixed Chinese/English example labels (电脑 = computer, 黑猫 = black cat, 窗户 = window).
+     examples = [
+         ["./cat_dfclor.jpg", str(["电脑","黑猫","窗户","window","white cat","book"])],
+     ]
+     gr.Examples(
+         examples=examples,
+         inputs=[image_input, text_input],
+     )
+     dfs_button.click(fn=Get_Densefeature, inputs=[image_input, text_input], outputs=dfs_output)
+
+ demo.launch()
+
+ # demo.launch(server_name="0.0.0.0", server_port=7862, share=True)
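For quick experiments outside the UI, the same dense-feature comparison can be run as a plain script. The sketch below reuses only the remote-code calls that app.py itself makes (`get_image_dense_feature`, `get_text_features(..., walk_type="box")`); the image path and label list are placeholder inputs.

```python
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoTokenizer, AutoModelForCausalLM

model_root = "qihoo360/fg-clip2-base"
model = AutoModelForCausalLM.from_pretrained(model_root, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_root)
image_processor = AutoImageProcessor.from_pretrained(model_root)

image = Image.open("cat_dfclor.jpg").convert("RGB")  # placeholder image path
labels = ["black cat", "window", "book"]             # placeholder labels

inputs = image_processor(images=image, max_num_patches=16384, return_tensors="pt")
with torch.no_grad():
    dense = model.get_image_dense_feature(**inputs)
    text = model.get_text_features(
        **tokenizer([l.lower() for l in labels], padding="max_length",
                    max_length=64, truncation=True, return_tensors="pt"),
        walk_type="box",
    )

# Keep only the real patch tokens, then compare every patch with every label.
h = inputs["spatial_shapes"][0][0].item()
w = inputs["spatial_shapes"][0][1].item()
dense = dense[0][: h * w]
dense = dense / dense.norm(p=2, dim=-1, keepdim=True)
text = text / text.norm(p=2, dim=-1, keepdim=True)

best_label = (dense @ text.T).argmax(dim=-1).reshape(h, w)  # label index per patch
print(best_label)
```

Taking the argmax over the label axis gives a coarse per-patch label map; the Gradio app instead renders one viridis heatmap per label.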