---
license: apache-2.0
tags:
- sam2
- segment-anything
- onnx
- webgpu
- computer-vision
- image-segmentation
library_name: onnxruntime
---

# SAM2-HIERA-SMALL - ONNX Format for WebGPU

**Powered by [Segment Anything 2 (SAM2)](https://github.com/facebookresearch/segment-anything-2) from Meta Research**

This repository contains ONNX-converted models from [facebook/sam2-hiera-small](https://huggingface.co/facebook/sam2-hiera-small), optimized for WebGPU deployment in browsers.

## Model Information

- **Original Model**: [facebook/sam2-hiera-small](https://huggingface.co/facebook/sam2-hiera-small)
- **Version**: SAM 2.0
- **Size**: 46M parameters
- **Description**: Small variant - balanced speed and quality
- **Format**: ONNX (encoder + decoder)
- **Optimization**: Encoder optimized to .ort format for WebGPU

## Files

- `encoder.onnx` - Image encoder (ONNX format)
- `encoder.with_runtime_opt.ort` - Image encoder (optimized for WebGPU)
- `decoder.onnx` - Mask decoder (ONNX format)
- `config.json` - Model configuration

## Usage

### In Browser with ONNX Runtime Web

```javascript
import * as ort from 'onnxruntime-web/webgpu';

// Load encoder (use the optimized .ort version for WebGPU)
const encoderURL = 'https://huggingface.co/SharpAI/sam2-hiera-small-onnx/resolve/main/encoder.with_runtime_opt.ort';
const encoderSession = await ort.InferenceSession.create(encoderURL, {
  executionProviders: ['webgpu'],
  graphOptimizationLevel: 'disabled'
});

// Load decoder
const decoderURL = 'https://huggingface.co/SharpAI/sam2-hiera-small-onnx/resolve/main/decoder.onnx';
const decoderSession = await ort.InferenceSession.create(decoderURL, {
  executionProviders: ['webgpu']
});

// Run encoder
const imageData = preprocessImage(image); // Your preprocessing
const encoderOutputs = await encoderSession.run({ image: imageData });

// Run decoder with a single point prompt at (x, y)
const point_coords = new ort.Tensor('float32', [x, y, 0, 0], [1, 2, 2]);
const point_labels = new ort.Tensor('float32', [1, -1], [1, 2]);
const mask_input = new ort.Tensor('float32', new Float32Array(256 * 256).fill(0), [1, 1, 256, 256]);
const has_mask_input = new ort.Tensor('float32', [0], [1]);

const decoderOutputs = await decoderSession.run({
  image_embed: encoderOutputs.image_embed,
  high_res_feats_0: encoderOutputs.high_res_feats_0,
  high_res_feats_1: encoderOutputs.high_res_feats_1,
  point_coords: point_coords,
  point_labels: point_labels,
  mask_input: mask_input,
  has_mask_input: has_mask_input
});

// Get masks
const masks = decoderOutputs.masks; // Shape: [1, num_masks, 256, 256]
```

### In Python with ONNX Runtime

```python
import onnxruntime as ort
import numpy as np

# Load models
encoder_session = ort.InferenceSession("encoder.onnx")
decoder_session = ort.InferenceSession("decoder.onnx")

# Run encoder (image_tensor: Float32[1, 3, 1024, 1024], see the preprocessing sketch below)
encoder_outputs = encoder_session.run(None, {"image": image_tensor})

# Prompt inputs for a single foreground point at (x, y)
point_coords = np.array([[[x, y], [0, 0]]], dtype=np.float32)
point_labels = np.array([[1, -1]], dtype=np.float32)
mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
has_mask_input = np.zeros(1, dtype=np.float32)

# Run decoder
decoder_outputs = decoder_session.run(None, {
    "image_embed": encoder_outputs[0],
    "high_res_feats_0": encoder_outputs[1],
    "high_res_feats_1": encoder_outputs[2],
    "point_coords": point_coords,
    "point_labels": point_labels,
    "mask_input": mask_input,
    "has_mask_input": has_mask_input
})

masks = decoder_outputs[0]
```

## Input/Output Specifications

### Encoder

**Input:**

- `image`: Float32[1, 3, 1024, 1024] - Normalized RGB image

**Outputs:**

- `image_embed`: Float32[1, 256, 64, 64] - Image embeddings
- `high_res_feats_0`: Float32[1, 32, 256, 256] - High-res features (level 0)
- `high_res_feats_1`: Float32[1, 64, 128, 128] - High-res features (level 1)

### Decoder

**Inputs:**

- `image_embed`: Float32[1, 256, 64, 64] - From encoder
- `high_res_feats_0`: Float32[1, 32, 256, 256] - From encoder
- `high_res_feats_1`: Float32[1, 64, 128, 128] - From encoder
- `point_coords`: Float32[1, 2, 2] - Point coordinates [[x, y], [0, 0]]
- `point_labels`: Float32[1, 2] - Point labels [1, -1] (1 = foreground, -1 = padding)
- `mask_input`: Float32[1, 1, 256, 256] - Previous mask (zeros if none)
- `has_mask_input`: Float32[1] - Flag [0] or [1]

**Outputs:**

- `masks`: Float32[1, 3, 256, 256] - Generated masks (3 candidates)
- `iou_predictions`: Float32[1, 3] - IoU scores for each mask
- `low_res_masks`: Float32[1, 3, 256, 256] - Low-resolution masks
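Both usage snippets above leave the preprocessing step abstract (`preprocessImage` / `image_tensor`). A minimal Python sketch of that step is shown below. It assumes the exported encoder expects the original SAM2 preprocessing (resize to 1024x1024, scale to [0, 1], then ImageNet mean/std normalization); verify this against the converter you used, since some exports bake normalization into the graph.

```python
import numpy as np
from PIL import Image

def preprocess_image(path: str) -> np.ndarray:
    """Resize to 1024x1024 and normalize; returns Float32[1, 3, 1024, 1024].

    Assumption: the exported encoder expects ImageNet mean/std normalization,
    matching the original SAM2 preprocessing. Check your export if results look off.
    """
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)

    img = Image.open(path).convert("RGB").resize((1024, 1024), Image.BILINEAR)
    arr = np.asarray(img, dtype=np.float32) / 255.0   # HWC, values in [0, 1]
    arr = (arr - mean) / std                          # per-channel normalization
    arr = arr.transpose(2, 0, 1)[None, ...]           # HWC -> NCHW [1, 3, 1024, 1024]
    return np.ascontiguousarray(arr, dtype=np.float32)

image_tensor = preprocess_image("example.jpg")
```

Note that SAM2 ONNX exports typically expect `point_coords` in the same 1024x1024 input space, so prompts taken from the original image generally need to be scaled by `1024 / original_width` and `1024 / original_height` before building the prompt tensors.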
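The decoder returns three candidate masks plus an IoU score for each (see the outputs listed above). The sketch below picks the highest-scoring candidate and converts it to a binary mask at the original image resolution; the 0.0 logit threshold and the nearest-neighbor upscaling are illustrative assumptions, not part of the export.

```python
import numpy as np

def select_best_mask(masks: np.ndarray, iou_predictions: np.ndarray,
                     out_size: tuple) -> np.ndarray:
    """Pick the highest-scoring of the 3 candidate masks and upscale it.

    masks: Float32[1, 3, 256, 256] decoder output (mask logits)
    iou_predictions: Float32[1, 3] decoder output
    out_size: (height, width) of the original image
    Returns a boolean mask of shape (height, width).
    """
    best = int(np.argmax(iou_predictions[0]))   # index of the best candidate
    logits = masks[0, best]                     # [256, 256]
    binary = logits > 0.0                       # assumed logit threshold

    # Nearest-neighbor upscale to the original image size (assumption: good
    # enough for display; use proper interpolation for more precise edges).
    h, w = out_size
    rows = np.arange(h) * binary.shape[0] // h
    cols = np.arange(w) * binary.shape[1] // w
    return binary[rows[:, None], cols[None, :]]

# Example, assuming the output order listed above (masks, iou_predictions, low_res_masks):
# best_mask = select_best_mask(decoder_outputs[0], decoder_outputs[1], (orig_h, orig_w))
```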
## Browser Requirements

- Chrome 113+ with WebGPU enabled (`chrome://flags/#enable-unsafe-webgpu`)
- Firefox Nightly with WebGPU enabled
- Safari Technology Preview with WebGPU enabled

## Performance

Typical inference times on Chrome with WebGPU:

- **Encoder**: ~3-5s per image (small variant)
- **Decoder**: 0.1-0.5s per point

## License

This model is released under the Apache 2.0 license, following the original SAM2 model.

## Citation

```bibtex
@article{ravi2024sam2,
  title={SAM 2: Segment Anything in Images and Videos},
  author={Ravi, Nikhila and Gabeur, Valentin and Hu, Yuan-Ting and Hu, Ronghang and Ryali, Chaitanya and Ma, Tengyu and Khedr, Haitham and R{\"a}dle, Roman and Rolland, Chloe and Gustafson, Laura and Mintun, Eric and Pan, Junting and Alwala, Kalyan Vasudev and Carion, Nicolas and Wu, Chao-Yuan and Girshick, Ross and Doll{\'a}r, Piotr and Feichtenhofer, Christoph},
  journal={arXiv preprint arXiv:2408.00714},
  year={2024}
}
```

## Related Resources

- **Original SAM2**: [facebookresearch/segment-anything-2](https://github.com/facebookresearch/segment-anything-2)
- **WebGPU Demo**: [Aegis AI SAM2 WebGPU Demo](https://github.com/yourusername/Aegis-AI/tree/main/tools/sam2-webgpu)
- **Conversion Tool**: [SAM2 ONNX Converter](https://github.com/yourusername/Aegis-AI/tree/main/tools/sam2-converter)

## Acknowledgments

- **Meta Research** for the original SAM2 model
- **Microsoft** for ONNX Runtime
- **SamExporter** for conversion tools

---

*Converted and optimized by [Aegis AI](https://github.com/yourusername/Aegis-AI)*