# owlv2/owlv2_img_embeding_2.py
# (Hugging Face page residue removed: uploader "fcxfcx", "Upload 2446 files",
# commit 1327f34 — kept here as a comment so the file is valid Python.)
import os
import sys
import json
# pip install ott-jax==0.2.0
import jax
import numpy as np
import tensorflow as tf
from scipy.special import expit as sigmoid
import skimage
from skimage import io as skimage_io
from skimage import transform as skimage_transform
import matplotlib as mpl
from matplotlib import pyplot as plt
sys.path.append('/home/netzone22/bohanliu_2025/VisionModels/Scenic_OWLv2/big_vision')
tf.config.experimental.set_visible_devices([], 'GPU')
from scenic.projects.owl_vit import configs
from scenic.projects.owl_vit import models
# from owlv2_helper_functions import prepare_images
from owlv2_helper_functions import read_images, preprocess_images
from owlv2_helper_functions import plot_bbox_on_image, image_based_plot_boxes_on_image, plot_boxes_on_image
from owlv2_helper_functions import top_object_index
from owlv2_helper_functions import rescale_detection_box
"""
Prepare OWLv2 pretrained model
"""
config = configs.owl_v2_clip_l14.get_config(init_mode='canonical_checkpoint')
module = models.TextZeroShotDetectionModule(
body_configs=config.model.body,
objectness_head_configs=config.model.objectness_head,
normalize=config.model.normalize,
box_bias=config.model.box_bias)
variables = module.load_variables(config.init_from.checkpoint_path)
"""
Wrapped model components
"""
import functools
image_embedder = jax.jit(
functools.partial(module.apply, variables, train=False, method=module.image_embedder))
objectness_predictor = jax.jit(
functools.partial(module.apply, variables, method=module.objectness_predictor))
box_predictor = jax.jit(
functools.partial(module.apply, variables, method=module.box_predictor))
class_predictor = jax.jit(
functools.partial(module.apply, variables, method=module.class_predictor))
"""
Detect the main object on instances' images
"""
INSTANCE_DIR = '/home/netzone22/bohanliu_2025/DT_SPR/DT_SPR_instances_0'
INSTANCE_DETECTION = '/home/netzone22/bohanliu_2025/DT_SPR/DT_SPR_instances_detections_0'
model_input_size = config.dataset_configs.input_size
images, source_images_names = read_images(INSTANCE_DIR)
source_images = preprocess_images(images, model_input_size)
feature_map = image_embedder(source_images)
b, h, w, d = feature_map.shape
image_features = feature_map.reshape(b, h * w, d)
objectnesses = objectness_predictor(image_features)['objectness_logits']
bboxes = box_predictor(image_features=image_features, feature_map=feature_map)['pred_boxes']
source_class_embeddings = class_predictor(image_features=image_features)['class_embeddings']
# print(f"Debug: source instance detection")
# print(f" Source images features shape: {image_features.shape}")
# print(f" objectnesses shape: {objectnesses.shape}")
# print(f" bboxes shape: {bboxes.shape}")
# print(f" source_class_embeddings shape: {source_class_embeddings.shape}")
objectnesses = sigmoid(objectnesses)
top_objectnesses = np.max(objectnesses, axis=1)
instances, query_embeddings, indexes = [], [], []
for i in range(len(source_images_names)):
index = top_object_index(objectnesses[i], top_objectnesses[i])
query_embedding = source_class_embeddings[index]
indexes.append(index)
instances.append(source_images_names[i].split('_')[0])
query_embeddings.append(query_embedding)
output_file = os.path.join(INSTANCE_DETECTION, source_images_names[i])
plot_bbox_on_image(source_images[i], bboxes[i], objectnesses[i], top_objectnesses[i], output_file)
"""
Detect the collected instance queries on each target image.

Uses ALL query embeddings gathered above (one per instance) and, for each
target image, picks the best-matching box per query.
"""
IMAGE_DIR = '/home/netzone22/bohanliu_2025/DT_SPR/data_sample'
OUTPUT_DIR = '/home/netzone22/bohanliu_2025/VisionModels/Scenic_OWLv2/bliu75_output/test_output/batch_results'
images, target_images_names = read_images(IMAGE_DIR)
target_images = preprocess_images(images, model_input_size)

# BUG FIX: the original passed `query_embedding` — the leftover loop variable
# from the instance loop, i.e. only the LAST instance's embedding — while
# handing the full `instances` label list (with a length-1 score/box array)
# to the plotting helper. Stack every collected query instead.
query_array = np.stack(query_embeddings)  # [num_queries, d]

for target_image, target_image_name, image in zip(target_images, target_images_names, images):
    feature_map = image_embedder(target_image[None, ...])
    b, h, w, d = feature_map.shape
    image_features = feature_map.reshape(b, h * w, d)  # [1, tokens, d]
    target_boxes = box_predictor(image_features=image_features, feature_map=feature_map)['pred_boxes']
    target_class_predictions = class_predictor(
        image_features=image_features,
        query_embeddings=query_array[None, ...],  # [batch, queries, d]
    )
    logits = np.array(target_class_predictions['pred_logits'][0])  # [tokens, queries]
    raw_boxes = np.array(target_boxes[0])  # [tokens, 4]
    # Best-scoring token per query, and that token's sigmoid score.
    top_inds = np.argmax(logits, axis=0)  # [queries]
    scores = sigmoid(logits[top_inds, np.arange(logits.shape[1])])  # [queries]
    # Rescale boxes to the ORIGINAL (un-preprocessed) image, then keep the
    # per-query winners so labels, scores, and boxes stay aligned.
    boxes = rescale_detection_box(raw_boxes, image)[top_inds]  # [queries, 4]
    image_based_plot_boxes_on_image(image, instances, scores, boxes, target_image_name, OUTPUT_DIR)
    # Typo fixed: "traget" -> "target".
    print(f"Debug: target instance detection")
    print(f"  target_logits: {logits.shape}")
    print(logits)