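"""Gradio demo for FG-CLIP 2 (qihoo360/fg-clip2-base).

Visualizes dense image-feature similarity heatmaps for a user-supplied list of
text labels.
"""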
import ast
import io

import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import (
    AutoImageProcessor,
    AutoTokenizer,
    AutoModelForCausalLM,
)

import matplotlib
matplotlib.use("Agg")  # headless backend; figures are rendered off-screen
import matplotlib.pyplot as plt

# Load FG-CLIP 2 (base) plus its tokenizer and image processor once at startup.
model_root = "qihoo360/fg-clip2-base"
model = AutoModelForCausalLM.from_pretrained(model_root, trust_remote_code=True)
device = model.device
tokenizer = AutoTokenizer.from_pretrained(model_root)
image_processor = AutoImageProcessor.from_pretrained(model_root)
def resize_short_edge(image, target_size=2048):
    """Upscale `image` so its short edge is at least `target_size` pixels.

    Images whose short edge already meets the target are returned unchanged.
    """
    if isinstance(image, str):
        image = Image.open(image)
    width, height = image.size
    short_edge = min(width, height)
    if short_edge >= target_size:
        return image
    scale = target_size / short_edge
    new_width = int(width * scale)
    new_height = int(height * scale)
    resized_image = image.resize((new_width, new_height))
    return resized_image
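
# Example: with target_size=1024 (as this demo uses), a 640x480 image has short
# edge 480, scale = 1024 / 480 ≈ 2.13, and is upscaled to roughly 1365x1024.
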
def Get_Densefeature(image, candidate_labels):
    """Take an image and a Python-style list of candidate labels (as a string),
    and return a grid of per-label dense-feature similarity heatmaps.
    """
    # The textbox supplies the labels as a literal list, e.g. "['a', 'b', 'c']".
    candidate_labels = ast.literal_eval(candidate_labels)
    assert len(candidate_labels) != 0
    print(candidate_labels)

    image = image.convert("RGB")
    image = resize_short_edge(image, target_size=1024)
    image_input = image_processor(images=image, max_num_patches=4096, return_tensors="pt").to(device)

    # e.g. mixed Chinese/English labels: ["电脑", "黑猫", "窗户", "window", "white cat", "book"]
    captions = candidate_labels
    with torch.no_grad():
        dense_image_feature = model.get_image_dense_feature(**image_input)
        # spatial_shapes gives the real patch-grid size (h, w) for this image.
        spatial_values = image_input["spatial_shapes"][0]
        real_h = spatial_values[0].item()
        real_w = spatial_values[1].item()
        real_pixel_tokens_num = real_w * real_h
        # Keep only the tokens for real image patches (the rest is padding).
        dense_image_feature = dense_image_feature[0][:real_pixel_tokens_num]

        captions = [caption.lower() for caption in captions]
        caption_input = tokenizer(captions, padding="max_length", max_length=64, truncation=True, return_tensors="pt").to(device)
        text_feature = model.get_text_features(**caption_input, walk_type="box")

        # L2-normalize both sides so the dot product is a cosine similarity.
        text_feature = text_feature / text_feature.norm(p=2, dim=-1, keepdim=True)
        dense_image_feature = dense_image_feature / dense_image_feature.norm(p=2, dim=-1, keepdim=True)
        similarity = dense_image_feature @ text_feature.T
        similarity = similarity.cpu()
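
    # Shapes at this point:
    #   dense_image_feature: (real_h * real_w, dim) -- one embedding per patch
    #   text_feature:        (num_labels, dim)      -- one embedding per label
    #   similarity:          (real_h * real_w, num_labels); each column is
    #   reshaped to (real_h, real_w) below and rendered as a heatmap.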
    # Lay the per-label heatmaps out on a grid with up to 3 columns.
    num_classes = len(captions)
    cols = 3
    rows = (num_classes + cols - 1) // cols
    aspect_ratio = real_w / real_h
    fig_width_inch = 3 * cols
    fig_height_inch = fig_width_inch / aspect_ratio * rows / cols
    fig, axes = plt.subplots(rows, cols, figsize=(fig_width_inch, fig_height_inch))
    fig.subplots_adjust(wspace=0.01, hspace=0.01)
    # plt.subplots may return a single Axes or an array; normalize to a flat array.
    axes = np.atleast_1d(axes).flatten()

    for cls_index in range(num_classes):
        similarity_map = similarity[:, cls_index].numpy()
        show_image = similarity_map.reshape((real_h, real_w))
        ax = axes[cls_index]
        ax.imshow(show_image, cmap='viridis', aspect='equal')  # keep the original aspect ratio
        ax.set_xticks([])
        ax.set_yticks([])
        ax.axis('off')

    # Hide any unused cells in the grid.
    for idx in range(num_classes, len(axes)):
        axes[idx].axis('off')
    # Render the figure into an in-memory PNG and return it as a PIL image.
    buf = io.BytesIO()
    plt.savefig(buf, format='png')
    buf.seek(0)
    plt.close(fig)
    pil_img = Image.open(buf)
    pil_img.load()  # force decoding so the BytesIO buffer is no longer needed
    return pil_img
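
# Rough local usage sketch (hypothetical image path, not part of the Space UI):
#   img = Image.open("cat.jpg")
#   Get_Densefeature(img, "['black cat', 'window']").save("similarity_grid.png")
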
with gr.Blocks() as demo:
    gr.Markdown("# FG-CLIP 2 Densefeature")
    gr.Markdown(
        "This app uses the FG-CLIP 2 model (qihoo360/fg-clip2-base) to visualize dense-feature similarity maps, running on CPU."
    )
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil")
            text_input = gr.Textbox(label="Enter a list of labels, e.g. ['a', 'b', 'c']")
            dfs_button = gr.Button("Run Densefeature", visible=True)
        with gr.Column():
            dfs_output = gr.Image(label="Similarity Visualization", type="pil")

    examples = [
        ["./cat_dfclor.jpg", str(["电脑", "黑猫", "窗户", "window", "white cat", "book"])],
    ]
    gr.Examples(
        examples=examples,
        inputs=[image_input, text_input],
    )

    dfs_button.click(fn=Get_Densefeature, inputs=[image_input, text_input], outputs=dfs_output)

demo.launch()
# demo.launch(server_name="0.0.0.0", server_port=7862, share=True)