| | import os |
| | import sys |
| | import json |
| |
|
| | |
| | import jax |
| | import numpy as np |
| | import tensorflow as tf |
| | from scipy.special import expit as sigmoid |
| |
|
| | import skimage |
| | from skimage import io as skimage_io |
| | from skimage import transform as skimage_transform |
| | import matplotlib as mpl |
| | from matplotlib import pyplot as plt |
| |
|
| | sys.path.append('/home/netzone22/bohanliu_2025/VisionModels/Scenic_OWLv2/big_vision') |
| | tf.config.experimental.set_visible_devices([], 'GPU') |
| |
|
| | from scenic.projects.owl_vit import configs |
| | from scenic.projects.owl_vit import models |
| |
|
| | |
| | from owlv2_helper_functions import read_images, preprocess_images |
| | from owlv2_helper_functions import plot_bbox_on_image, image_based_plot_boxes_on_image, plot_boxes_on_image |
| | from owlv2_helper_functions import top_object_index |
| | from owlv2_helper_functions import rescale_detection_box |
| |
|
| |
|
| |
|
| |
|
"""
Prepare the pretrained OWLv2 model.

Builds the CLIP-L/14 OWLv2 zero-shot detection module from its canonical
config and restores the released checkpoint weights.
"""
config = configs.owl_v2_clip_l14.get_config(init_mode='canonical_checkpoint')

# Assemble the detection module from the model sub-config's pieces.
model_cfg = config.model
module = models.TextZeroShotDetectionModule(
    body_configs=model_cfg.body,
    objectness_head_configs=model_cfg.objectness_head,
    normalize=model_cfg.normalize,
    box_bias=model_cfg.box_bias,
)

# Restore the pretrained parameters referenced by the config.
variables = module.load_variables(config.init_from.checkpoint_path)
| |
|
| |
|
| |
|
| |
|
"""
Wrapped model components.

Each wrapper is `module.apply` specialized to a single model method, with the
loaded variables (and any fixed kwargs) baked in, and the whole call
JIT-compiled by JAX.
"""
import functools


def _compiled(method, **fixed_kwargs):
    # Bind variables and fixed kwargs into module.apply, then jit the result.
    bound = functools.partial(module.apply, variables, method=method, **fixed_kwargs)
    return jax.jit(bound)


image_embedder = _compiled(module.image_embedder, train=False)
objectness_predictor = _compiled(module.objectness_predictor)
box_predictor = _compiled(module.box_predictor)
class_predictor = _compiled(module.class_predictor)
| |
|
| |
|
| |
|
| |
|
"""
Detect the main object on each instance image.

Embeds every instance image once, then predicts per-patch objectness logits,
boxes, and class embeddings from the shared feature map.
"""
INSTANCE_DIR = '/home/netzone22/bohanliu_2025/DT_SPR/DT_SPR_instances_0'
INSTANCE_DETECTION = '/home/netzone22/bohanliu_2025/DT_SPR/DT_SPR_instances_detections_0'

# The model's expected input resolution comes from the dataset config.
model_input_size = config.dataset_configs.input_size

images, source_images_names = read_images(INSTANCE_DIR)
source_images = preprocess_images(images, model_input_size)

# Backbone features are (batch, height, width, dim); flatten the spatial grid
# into a sequence of patch tokens for the prediction heads.
feature_map = image_embedder(source_images)
b, h, w, d = feature_map.shape
image_features = feature_map.reshape(b, h * w, d)

objectnesses = objectness_predictor(image_features)['objectness_logits']
bboxes = box_predictor(
    image_features=image_features, feature_map=feature_map)['pred_boxes']
source_class_embeddings = class_predictor(
    image_features=image_features)['class_embeddings']
| |
|
| | |
| | |
| | |
| | |
| | |
| |
|
# Convert objectness logits to probabilities and record each image's best
# per-patch score.
objectnesses = sigmoid(objectnesses)
top_objectnesses = np.max(objectnesses, axis=1)

instances, query_embeddings, indexes = [], [], []
for i, image_name in enumerate(source_images_names):
    # Index of the most object-like detection in image i (project helper).
    index = top_object_index(objectnesses[i], top_objectnesses[i])
    # NOTE(review): this indexes the *batch* axis of source_class_embeddings
    # with what looks like a patch index — confirm it shouldn't be
    # source_class_embeddings[i][index] (per-image, per-patch embedding).
    query_embedding = source_class_embeddings[index]

    indexes.append(index)
    # Instance label is the filename prefix before the first underscore.
    instances.append(image_name.split('_')[0])
    query_embeddings.append(query_embedding)

    # Save a visualization of the top detection for this instance image.
    output_file = os.path.join(INSTANCE_DETECTION, image_name)
    plot_bbox_on_image(source_images[i], bboxes[i], objectnesses[i],
                       top_objectnesses[i], output_file)
| |
|
| |
|
| |
|
| |
|
| |
|
# Directory of target images to search, and where rendered results are written.
IMAGE_DIR = '/home/netzone22/bohanliu_2025/DT_SPR/data_sample'
OUTPUT_DIR = '/home/netzone22/bohanliu_2025/VisionModels/Scenic_OWLv2/bliu75_output/test_output/batch_results'

# Load the target images and resize them to the model's input resolution.
images, target_images_names = read_images(IMAGE_DIR)
target_images = preprocess_images(images, model_input_size)
| |
|
# Run one-shot detection on each target image using a query embedding taken
# from the instance images above.
for target_image, target_image_name, image in zip(target_images, target_images_names, images):

    # Embed the single target image (prepend a batch axis of 1), then flatten
    # the spatial grid into patch tokens for the box head.
    feature_map = image_embedder(target_image[None, ...])
    b, h, w, d = feature_map.shape
    target_boxes = box_predictor(image_features=feature_map.reshape(b, h * w, d), feature_map=feature_map)['pred_boxes']

    # NOTE(review): `query_embedding` is the leftover loop variable from the
    # instance loop above — only the LAST instance's embedding is ever queried,
    # and the collected `query_embeddings` list is never read. This looks like
    # a bug; probably all collected embeddings should be stacked and passed
    # here. Left as-is pending shape confirmation of class_predictor's
    # query_embeddings argument.
    target_class_predictions = class_predictor(
        image_features=feature_map.reshape(b, h * w, d),
        query_embeddings=query_embedding[None, ...],
    )

    # Drop the batch axis; logits are presumably (num_patches, num_queries) —
    # TODO confirm against class_predictor's output spec.
    logits = np.array(target_class_predictions['pred_logits'][0])
    raw_boxes = np.array(target_boxes[0])

    # Best-scoring patch for query column 0 only; other columns are ignored.
    top_ind = np.argmax(logits[:, 0], axis=0)
    score = sigmoid(logits[top_ind, 0])

    # Map normalized boxes back to the original (un-resized) image, then keep
    # only the top detection.
    boxes = rescale_detection_box(raw_boxes, image)
    boxes = boxes[top_ind]

    # Wrap as length-1 arrays for the plotting helper.
    score = np.array([score])
    boxes = np.array([boxes])

    # NOTE(review): `instances` holds one label per instance image, yet only a
    # single score/box is passed — confirm the plotting helper expects this.
    image_based_plot_boxes_on_image(image, instances, score, boxes, target_image_name, OUTPUT_DIR)

    print(f"Debug: traget instance detection")  # NOTE(review): typo "traget" in debug output

    print(f" target_logits: {logits.shape}")
    print(logits)
| | |
| | |
| | |
| |
|
| | |
| |
|