import ast
import io

import gradio as gr
import torch
from PIL import Image
from transformers import (
    AutoImageProcessor,
    AutoModelForCausalLM,
    AutoTokenizer,
)

import matplotlib

matplotlib.use('Agg')  # non-interactive backend so figures render server-side without a display
import matplotlib.pyplot as plt

model_root = "qihoo360/fg-clip2-base"

# FG-CLIP 2 ships custom modeling code with the checkpoint, so trust_remote_code is required.
model = AutoModelForCausalLM.from_pretrained(model_root, trust_remote_code=True)
device = model.device

tokenizer = AutoTokenizer.from_pretrained(model_root)
image_processor = AutoImageProcessor.from_pretrained(model_root)


def resize_short_edge(image, target_size=2048):
    """Upscale an image so its short edge reaches target_size; larger images are returned unchanged."""
    if isinstance(image, str):
        image = Image.open(image)

    width, height = image.size
    short_edge = min(width, height)

    # Only upscale: keep the original resolution if the short edge already meets the target.
    if short_edge >= target_size:
        return image

    scale = target_size / short_edge
    new_width = int(width * scale)
    new_height = int(height * scale)

    resized_image = image.resize((new_width, new_height))
    return resized_image


def Get_Densefeature(image, candidate_labels):
    """
    Take an image and a string containing a Python list of candidate labels,
    e.g. "['black cat', 'window']", and return a grid of per-label patch
    similarity heatmaps as a PIL image.
    """
    candidate_labels = ast.literal_eval(candidate_labels)
    assert len(candidate_labels) != 0
    print(candidate_labels)

    image = image.convert("RGB")
    # Upscale small inputs so the patch grid is fine-grained enough to visualize.
    image = resize_short_edge(image, target_size=1024)

    image_input = image_processor(images=image, max_num_patches=4096, return_tensors="pt").to(device)

    captions = candidate_labels

    with torch.no_grad():
        # Patch-level (dense) image features from FG-CLIP 2.
        dense_image_feature = model.get_image_dense_feature(**image_input)

        # Keep only the tokens that correspond to real image patches (drop any padding).
        spatial_values = image_input["spatial_shapes"][0]
        real_h = spatial_values[0].item()
        real_w = spatial_values[1].item()
        real_pixel_tokens_num = real_w * real_h
        dense_image_feature = dense_image_feature[0][:real_pixel_tokens_num]

        captions = [caption.lower() for caption in captions]
        caption_input = tokenizer(
            captions, padding="max_length", max_length=64, truncation=True, return_tensors="pt"
        ).to(device)

        text_feature = model.get_text_features(**caption_input, walk_type="box")
        text_feature = text_feature / text_feature.norm(p=2, dim=-1, keepdim=True)
        dense_image_feature = dense_image_feature / dense_image_feature.norm(p=2, dim=-1, keepdim=True)

        # Cosine similarity between every image patch and every label.
        similarity = dense_image_feature @ text_feature.T
        similarity = similarity.cpu()

    # Lay the per-label heatmaps out in a grid, up to 3 per row.
    num_classes = len(captions)
    cols = 3
    rows = (num_classes + cols - 1) // cols

    # Size the figure so each cell keeps the aspect ratio of the patch grid.
    aspect_ratio = real_w / real_h
    fig_width_inch = 3 * cols
    fig_height_inch = fig_width_inch / aspect_ratio * rows / cols

    fig, axes = plt.subplots(rows, cols, figsize=(fig_width_inch, fig_height_inch))
    fig.subplots_adjust(wspace=0.01, hspace=0.01)

    # cols > 1, so plt.subplots always returns an array of Axes; flatten it for 1-D indexing.
    axes = axes.flatten()

    for cls_index in range(num_classes):
        # Reshape this label's per-patch similarities back into the spatial patch grid.
        similarity_map = similarity[:, cls_index].numpy()
        show_image = similarity_map.reshape((real_h, real_w))

        ax = axes[cls_index]
        ax.imshow(show_image, cmap='viridis', aspect='equal')
        ax.set_xticks([])
        ax.set_yticks([])
        ax.axis('off')

    # Hide any unused cells in the grid.
    for idx in range(num_classes, len(axes)):
        axes[idx].axis('off')

    # Render the figure to an in-memory PNG and return it as a PIL image for Gradio.
    buf = io.BytesIO()
    plt.savefig(buf, format='png')
    buf.seek(0)
    plt.close(fig)

    pil_img = Image.open(buf)

    return pil_img
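

# Minimal standalone usage sketch (no Gradio UI). The image path below is the bundled
# example file from this demo and is only illustrative; any RGB image and a non-empty
# label list, passed as a string containing a Python list literal, should work:
#
#   img = Image.open("./cat_dfclor.jpg")
#   heatmap_grid = Get_Densefeature(img, "['white cat', 'window', 'book']")
#   heatmap_grid.save("densefeature_heatmaps.png")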


with gr.Blocks() as demo:
    gr.Markdown("# FG-CLIP 2 Densefeature")
    gr.Markdown(
        "This app uses the FG-CLIP 2 model (qihoo360/fg-clip2-base) to visualize dense-feature "
        "similarity between image patches and a list of text labels. It runs on CPU."
    )

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil")
            text_input = gr.Textbox(label="Enter a list of labels, e.g. ['a', 'b', 'c']")
            dfs_button = gr.Button("Run Densefeature", visible=True)
        with gr.Column():
            dfs_output = gr.Image(label="Similarity Visualization", type="pil")

    # FG-CLIP 2 handles both Chinese and English text, so the example mixes labels from
    # both languages (电脑 = computer, 黑猫 = black cat, 窗户 = window).
    examples = [
        ["./cat_dfclor.jpg", str(["电脑", "黑猫", "窗户", "window", "white cat", "book"])],
    ]
    gr.Examples(
        examples=examples,
        inputs=[image_input, text_input],
    )

    dfs_button.click(fn=Get_Densefeature, inputs=[image_input, text_input], outputs=dfs_output)

demo.launch()
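
# Launch options above are the Gradio defaults (local server on port 7860). Optional
# alternatives, not used by this demo: demo.launch(share=True) creates a temporary
# public link, and demo.launch(server_name="0.0.0.0") binds to all interfaces.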