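"""Remove the background of a video with MobileSAM.

The script extracts the frames of a video, tracks the object selected by an
initial bounding box (YOLOv7 or YOLOS re-detects it in every frame), segments
it with MobileSAM, composites each frame over a solid background color, and
reassembles the frames into a video.

Example invocation (a sketch, assuming this file is saved as
segment_video.py; adjust paths to your setup):

    python segment_video.py --video_filename assets/example.mp4 \
        --bbox_file bbox.txt --mobile_sam_weights ./models/mobile_sam.pt
"""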
import argparse
import os
import time

import cv2
import numpy as np
import torch
import wget
import yolov7
from mobile_sam import SamPredictor, sam_model_registry
from PIL import Image
from tqdm import tqdm
from transformers import YolosForObjectDetection, YolosImageProcessor

from images_to_video import VideoCreator
from video_to_images import ImageCreator


def download_mobile_sam_weight(path):
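    """Download the pretrained MobileSAM checkpoint to `path` if it is missing."""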
    if not os.path.exists(path):
        sam_weights = "https://raw.githubusercontent.com/ChaoningZhang/MobileSAM/master/weights/mobile_sam.pt"
        # Create the intermediate directories of `path` if they do not exist.
        for i in range(2, len(path.split("/"))):
            temp = path.split("/")[:i]
            cur_path = "/".join(temp)
            if not os.path.isdir(cur_path):
                os.mkdir(cur_path)
        model_name = path.split("/")[-1]
        if model_name in sam_weights:
            wget.download(sam_weights, path)
        else:
            raise NameError(
                "There is no pretrained weight to download for %s, you need to"
                " provide a path to MobileSAM weights." % model_name
            )


def get_closest_bbox(bbox_list, bbox_target):
    """
    Given a list of bounding boxes, find the one that is closest to the target bounding box.

    Args:
        bbox_list: list of bounding boxes
        bbox_target: target bounding box

    Returns:
        closest bounding box
    """
    min_dist = float("inf")
    min_idx = 0
    for idx, bbox in enumerate(bbox_list):
        # Euclidean distance between the (x1, y1, x2, y2) corner coordinates.
        dist = np.linalg.norm(bbox - bbox_target)
        if dist < min_dist:
            min_dist = dist
            min_idx = idx
    return bbox_list[min_idx]


def get_bboxes(image_file, image, model, image_processor, threshold=0.9):
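    """Run the detector and return the predicted boxes as an (N, 4) array.

    `image_processor` is None for YOLOv7, which is called directly on the
    image file; otherwise the YOLOS model goes through its Hugging Face
    image processor.
    """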
    if image_processor is None:
        # YOLOv7 path: the model takes the image file directly.
        results = model(image_file)
        predictions = results.pred[0]
        boxes = predictions[:, :4].detach().numpy()
        return boxes
    else:
        # YOLOS path: preprocess with the image processor, then map the raw
        # outputs back to box coordinates in the original image size.
        inputs = image_processor(images=image, return_tensors="pt")
        outputs = model(**inputs)
        target_sizes = torch.tensor([image.size[::-1]])
        results = image_processor.post_process_object_detection(
            outputs, threshold=threshold, target_sizes=target_sizes
        )[0]
        return results["boxes"].detach().numpy()


def segment_video(
    video_filename,
    dir_frames,
    image_start,
    image_end,
    bbox_file,
    skip_vid2im,
    mobile_sam_weights,
    auto_detect=False,
    tracker_name="yolov7",
    background_color="#009000",
    output_dir="output_frames",
    output_video="output.mp4",
    pbar=False,
    reverse_mask=False,
):
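    """Segment the tracked object in every frame and write the composited video.

    Args:
        video_filename: path to the input video.
        dir_frames: directory where the extracted frames are stored.
        image_start / image_end: first and last frame to process (0 = last).
        bbox_file: text file with the initial bounding box "x1 y1 x2 y2".
        skip_vid2im: skip the video-to-frames extraction step.
        mobile_sam_weights: path to the MobileSAM checkpoint.
        auto_detect: segment the whole frame instead of tracking the bbox.
        tracker_name: "yolov7", or anything else for YOLOS.
        background_color: hex color used to fill the masked-out region.
        output_dir / output_video: where frames and the final video are written.
        pbar: show a tqdm progress bar instead of periodic prints.
        reverse_mask: keep the background and color the object instead.
    """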
    if not skip_vid2im:
        vid_to_im = ImageCreator(
            video_filename,
            dir_frames,
            image_start=image_start,
            image_end=image_end,
            pbar=pbar,
        )
        vid_to_im.get_images()
    # Get fps of video
    vid = cv2.VideoCapture(video_filename)
    fps = vid.get(cv2.CAP_PROP_FPS)
    vid.release()
    # Parse the "#RRGGBB" hex string into an RGB array in [0, 1].
    background_color = background_color.lstrip("#")
    background_color = (
        np.array([int(background_color[i : i + 2], 16) for i in (0, 2, 4)]) / 255.0
    )
    # Initial bounding box: four space-separated integers (x1 y1 x2 y2).
    with open(bbox_file, "r") as f:
        bbox_orig = [int(coord) for coord in f.read().split(" ")]
    download_mobile_sam_weight(mobile_sam_weights)
    if image_end == 0:
        frames = sorted(os.listdir(dir_frames))[image_start:]
    else:
        frames = sorted(os.listdir(dir_frames))[image_start:image_end]
    model_type = "vit_t"
    # Pick the best available device: Apple Silicon GPU, CUDA GPU, or CPU.
    if torch.backends.mps.is_available():
        device = "mps"
    elif torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    sam = sam_model_registry[model_type](checkpoint=mobile_sam_weights)
    sam.to(device=device)
    sam.eval()
    predictor = SamPredictor(sam)
    if not auto_detect:
        if tracker_name == "yolov7":
            model = yolov7.load("kadirnar/yolov7-tiny-v0.1", hf_model=True)
            model.conf = 0.25  # NMS confidence threshold
            model.iou = 0.45  # NMS IoU threshold
            model.classes = None
            image_processor = None
        else:
            model = YolosForObjectDetection.from_pretrained("hustvl/yolos-tiny")
            image_processor = YolosImageProcessor.from_pretrained("hustvl/yolos-tiny")
    output_frames = []
    if pbar:
        pb = tqdm(frames)
    else:
        pb = frames
    processed_frames = 0
    init_time = time.time()
    for frame in pb:
        processed_frames += 1
        image_file = dir_frames + "/" + frame
        image_pil = Image.open(image_file)
        image_np = np.array(image_pil)
        if not auto_detect:
            # Track the object: detect boxes in this frame and keep the one
            # closest to the initial bounding box.
            bboxes = get_bboxes(image_file, image_pil, model, image_processor)
            closest_bbox = get_closest_bbox(bboxes, bbox_orig)
            input_box = np.array(closest_bbox)
        else:
            # No tracking: prompt MobileSAM with the whole frame.
            input_box = np.array([0, 0, image_np.shape[1], image_np.shape[0]])
        predictor.set_image(image_np)
        masks, _, _ = predictor.predict(
            point_coords=None,
            point_labels=None,
            box=input_box[None, :],
            multimask_output=True,
        )
        # Composite: keep the selected region of the frame and fill the rest
        # with the background color; the roles swap when reverse_mask is set.
        mask = masks[0]
        h, w = mask.shape[-2:]
        keep = (1 - mask) if reverse_mask else mask
        mask_image = (
            (1 - keep).reshape(h, w, 1) * background_color.reshape(1, 1, -1)
        ) * 255
        masked_image = image_np * keep.reshape(h, w, 1) + mask_image
        output_frames.append(masked_image)
        if not pbar and processed_frames % 10 == 0:
            # Rough ETA from the average time per processed frame so far.
            remaining_time = (
                (time.time() - init_time)
                / processed_frames
                * (len(frames) - processed_frames)
            )
            remaining_time = int(remaining_time)
            remaining_time_str = f"{remaining_time//60}m {remaining_time%60}s"
            print(
                f"Processed frame {processed_frames}/{len(frames)} - Remaining time: {remaining_time_str}"
            )
    if not os.path.exists(output_dir):
        os.mkdir(output_dir)
    zfill_max = len(str(len(output_frames)))
    for idx, frame in enumerate(output_frames):
        # Frames are RGB float arrays; convert to BGR uint8 for OpenCV.
        cv2.imwrite(
            f"{output_dir}/frame_{str(idx).zfill(zfill_max)}.png",
            cv2.cvtColor(frame.astype(np.uint8), cv2.COLOR_RGB2BGR),
        )
    vid_creator = VideoCreator(output_dir, output_video, pbar=pbar)
    vid_creator.create_video(fps=int(fps))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--video_filename",
        default="assets/example.mp4",
        type=str,
        help="path to the video",
    )
    parser.add_argument(
        "--dir_frames",
        type=str,
        default="frames",
        help="path to the directory in which all input frames will be stored",
    )
    parser.add_argument(
        "--image_start", type=int, default=0, help="first frame to process"
    )
    parser.add_argument(
        "--image_end",
        type=int,
        default=0,
        help="last frame to process; 0 means process to the end",
    )
    parser.add_argument(
        "--bbox_file",
        type=str,
        default="bbox.txt",
        help="path to the bounding box text file",
    )
    parser.add_argument(
        "--skip_vid2im",
        action="store_true",
        help="skip writing the video frames as images",
    )
    parser.add_argument(
        "--mobile_sam_weights",
        type=str,
        default="./models/mobile_sam.pt",
        help="path to MobileSAM weights",
    )
    parser.add_argument(
        "--tracker_name",
        type=str,
        default="yolov7",
        help="tracker name",
        choices=["yolov7", "yolos"],
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="output_frames",
        help="directory to store the output frames",
    )
    parser.add_argument(
        "--output_video",
        type=str,
        default="output.mp4",
        help="path to store the output video",
    )
    parser.add_argument(
        "--auto_detect",
        action="store_true",
        help="segment the whole frame instead of tracking the initial bounding box",
    )
    parser.add_argument(
        "--background_color",
        type=str,
        default="#009000",
        help="background color for the output (hex)",
    )
    args = parser.parse_args()
    # Pass the optional arguments by keyword: the original positional call
    # routed output_dir/output_video/tracker_name/background_color to the
    # wrong parameters of segment_video.
    segment_video(
        args.video_filename,
        args.dir_frames,
        args.image_start,
        args.image_end,
        args.bbox_file,
        args.skip_vid2im,
        args.mobile_sam_weights,
        auto_detect=args.auto_detect,
        tracker_name=args.tracker_name,
        background_color=args.background_color,
        output_dir=args.output_dir,
        output_video=args.output_video,
    )