Spaces:
Runtime error
Runtime error
- rembg/__init__.py +1 -6
- rembg/_version.py +2 -2
- rembg/bg.py +32 -1
- rembg/cli.py +43 -3
- rembg/session_base.py +1 -1
- rembg/session_factory.py +32 -24
- rembg/session_simple.py +1 -1
rembg/__init__.py
CHANGED
|
@@ -1,11 +1,6 @@
|
|
| 1 |
-
import sys
|
| 2 |
-
import warnings
|
| 3 |
-
|
| 4 |
-
if not (sys.version_info.major == 3 and sys.version_info.minor == 9):
|
| 5 |
-
warnings.warn("This library is only for Python 3.9", RuntimeWarning)
|
| 6 |
-
|
| 7 |
from . import _version
|
| 8 |
|
| 9 |
__version__ = _version.get_versions()["version"]
|
| 10 |
|
| 11 |
from .bg import remove
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from . import _version
|
| 2 |
|
| 3 |
__version__ = _version.get_versions()["version"]
|
| 4 |
|
| 5 |
from .bg import remove
|
| 6 |
+
from .session_factory import new_session
|
rembg/_version.py
CHANGED
|
@@ -24,8 +24,8 @@ def get_keywords():
|
|
| 24 |
# each be defined on a line of their own. _version.py will just call
|
| 25 |
# get_keywords().
|
| 26 |
git_refnames = " (HEAD -> main)"
|
| 27 |
-
git_full = "
|
| 28 |
-
git_date = "2022-
|
| 29 |
keywords = {"refnames": git_refnames, "full": git_full, "date": git_date}
|
| 30 |
return keywords
|
| 31 |
|
|
|
|
| 24 |
# each be defined on a line of their own. _version.py will just call
|
| 25 |
# get_keywords().
|
| 26 |
git_refnames = " (HEAD -> main)"
|
| 27 |
+
git_full = "edc9fe27dff030cf6c2f29ef9a66c32d6e3f4658"
|
| 28 |
+
git_date = "2022-11-28 08:14:19 -0300"
|
| 29 |
keywords = {"refnames": git_refnames, "full": git_full, "date": git_date}
|
| 30 |
return keywords
|
| 31 |
|
rembg/bg.py
CHANGED
|
@@ -3,16 +3,26 @@ from enum import Enum
|
|
| 3 |
from typing import List, Optional, Union
|
| 4 |
|
| 5 |
import numpy as np
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
from PIL import Image
|
| 7 |
from PIL.Image import Image as PILImage
|
| 8 |
from pymatting.alpha.estimate_alpha_cf import estimate_alpha_cf
|
| 9 |
from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml
|
| 10 |
from pymatting.util.util import stack_images
|
| 11 |
-
from scipy.ndimage
|
| 12 |
|
| 13 |
from .session_base import BaseSession
|
| 14 |
from .session_factory import new_session
|
| 15 |
|
|
|
|
|
|
|
| 16 |
|
| 17 |
class ReturnType(Enum):
|
| 18 |
BYTES = 0
|
|
@@ -27,6 +37,10 @@ def alpha_matting_cutout(
|
|
| 27 |
background_threshold: int,
|
| 28 |
erode_structure_size: int,
|
| 29 |
) -> PILImage:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
img = np.asarray(img)
|
| 31 |
mask = np.asarray(mask)
|
| 32 |
|
|
@@ -79,6 +93,19 @@ def get_concat_v(img1: PILImage, img2: PILImage) -> PILImage:
|
|
| 79 |
return dst
|
| 80 |
|
| 81 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
def remove(
|
| 83 |
data: Union[bytes, PILImage, np.ndarray],
|
| 84 |
alpha_matting: bool = False,
|
|
@@ -87,6 +114,7 @@ def remove(
|
|
| 87 |
alpha_matting_erode_size: int = 10,
|
| 88 |
session: Optional[BaseSession] = None,
|
| 89 |
only_mask: bool = False,
|
|
|
|
| 90 |
) -> Union[bytes, PILImage, np.ndarray]:
|
| 91 |
|
| 92 |
if isinstance(data, PILImage):
|
|
@@ -108,6 +136,9 @@ def remove(
|
|
| 108 |
cutouts = []
|
| 109 |
|
| 110 |
for mask in masks:
|
|
|
|
|
|
|
|
|
|
| 111 |
if only_mask:
|
| 112 |
cutout = mask
|
| 113 |
|
|
|
|
| 3 |
from typing import List, Optional, Union
|
| 4 |
|
| 5 |
import numpy as np
|
| 6 |
+
from cv2 import (
|
| 7 |
+
BORDER_DEFAULT,
|
| 8 |
+
MORPH_ELLIPSE,
|
| 9 |
+
MORPH_OPEN,
|
| 10 |
+
GaussianBlur,
|
| 11 |
+
getStructuringElement,
|
| 12 |
+
morphologyEx,
|
| 13 |
+
)
|
| 14 |
from PIL import Image
|
| 15 |
from PIL.Image import Image as PILImage
|
| 16 |
from pymatting.alpha.estimate_alpha_cf import estimate_alpha_cf
|
| 17 |
from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml
|
| 18 |
from pymatting.util.util import stack_images
|
| 19 |
+
from scipy.ndimage import binary_erosion
|
| 20 |
|
| 21 |
from .session_base import BaseSession
|
| 22 |
from .session_factory import new_session
|
| 23 |
|
| 24 |
+
kernel = getStructuringElement(MORPH_ELLIPSE, (3, 3))
|
| 25 |
+
|
| 26 |
|
| 27 |
class ReturnType(Enum):
|
| 28 |
BYTES = 0
|
|
|
|
| 37 |
background_threshold: int,
|
| 38 |
erode_structure_size: int,
|
| 39 |
) -> PILImage:
|
| 40 |
+
|
| 41 |
+
if img.mode == "RGBA" or img.mode == "CMYK":
|
| 42 |
+
img = img.convert("RGB")
|
| 43 |
+
|
| 44 |
img = np.asarray(img)
|
| 45 |
mask = np.asarray(mask)
|
| 46 |
|
|
|
|
| 93 |
return dst
|
| 94 |
|
| 95 |
|
| 96 |
+
def post_process(mask: np.ndarray) -> np.ndarray:
|
| 97 |
+
"""
|
| 98 |
+
Post Process the mask for a smooth boundary by applying Morphological Operations
|
| 99 |
+
Research based on paper: https://www.sciencedirect.com/science/article/pii/S2352914821000757
|
| 100 |
+
args:
|
| 101 |
+
mask: Binary Numpy Mask
|
| 102 |
+
"""
|
| 103 |
+
mask = morphologyEx(mask, MORPH_OPEN, kernel)
|
| 104 |
+
mask = GaussianBlur(mask, (5, 5), sigmaX=2, sigmaY=2, borderType=BORDER_DEFAULT)
|
| 105 |
+
mask = np.where(mask < 127, 0, 255).astype(np.uint8) # convert again to binary
|
| 106 |
+
return mask
|
| 107 |
+
|
| 108 |
+
|
| 109 |
def remove(
|
| 110 |
data: Union[bytes, PILImage, np.ndarray],
|
| 111 |
alpha_matting: bool = False,
|
|
|
|
| 114 |
alpha_matting_erode_size: int = 10,
|
| 115 |
session: Optional[BaseSession] = None,
|
| 116 |
only_mask: bool = False,
|
| 117 |
+
post_process_mask: bool = False,
|
| 118 |
) -> Union[bytes, PILImage, np.ndarray]:
|
| 119 |
|
| 120 |
if isinstance(data, PILImage):
|
|
|
|
| 136 |
cutouts = []
|
| 137 |
|
| 138 |
for mask in masks:
|
| 139 |
+
if post_process_mask:
|
| 140 |
+
mask = Image.fromarray(post_process(np.array(mask)))
|
| 141 |
+
|
| 142 |
if only_mask:
|
| 143 |
cutout = mask
|
| 144 |
|
rembg/cli.py
CHANGED
|
@@ -33,7 +33,9 @@ def main() -> None:
|
|
| 33 |
"-m",
|
| 34 |
"--model",
|
| 35 |
default="u2net",
|
| 36 |
-
type=click.Choice(
|
|
|
|
|
|
|
| 37 |
show_default=True,
|
| 38 |
show_choices=True,
|
| 39 |
help="model name",
|
|
@@ -76,6 +78,13 @@ def main() -> None:
|
|
| 76 |
show_default=True,
|
| 77 |
help="output only the mask",
|
| 78 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
@click.argument(
|
| 80 |
"input", default=(None if sys.stdin.isatty() else "-"), type=click.File("rb")
|
| 81 |
)
|
|
@@ -93,7 +102,9 @@ def i(model: str, input: IO, output: IO, **kwargs) -> None:
|
|
| 93 |
"-m",
|
| 94 |
"--model",
|
| 95 |
default="u2net",
|
| 96 |
-
type=click.Choice(
|
|
|
|
|
|
|
| 97 |
show_default=True,
|
| 98 |
show_choices=True,
|
| 99 |
help="model name",
|
|
@@ -136,6 +147,13 @@ def i(model: str, input: IO, output: IO, **kwargs) -> None:
|
|
| 136 |
show_default=True,
|
| 137 |
help="output only the mask",
|
| 138 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 139 |
@click.option(
|
| 140 |
"-w",
|
| 141 |
"--watch",
|
|
@@ -243,7 +261,15 @@ def p(
|
|
| 243 |
show_default=True,
|
| 244 |
help="log level",
|
| 245 |
)
|
| 246 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 247 |
sessions: dict[str, BaseSession] = {}
|
| 248 |
tags_metadata = [
|
| 249 |
{
|
|
@@ -284,6 +310,7 @@ def s(port: int, log_level: str) -> None:
|
|
| 284 |
u2netp = "u2netp"
|
| 285 |
u2net_human_seg = "u2net_human_seg"
|
| 286 |
u2net_cloth_seg = "u2net_cloth_seg"
|
|
|
|
| 287 |
|
| 288 |
class CommonQueryParams:
|
| 289 |
def __init__(
|
|
@@ -309,6 +336,7 @@ def s(port: int, log_level: str) -> None:
|
|
| 309 |
default=10, ge=0, description="Alpha Matting (Erode Structure Size)"
|
| 310 |
),
|
| 311 |
om: bool = Query(default=False, description="Only Mask"),
|
|
|
|
| 312 |
):
|
| 313 |
self.model = model
|
| 314 |
self.a = a
|
|
@@ -316,6 +344,7 @@ def s(port: int, log_level: str) -> None:
|
|
| 316 |
self.ab = ab
|
| 317 |
self.ae = ae
|
| 318 |
self.om = om
|
|
|
|
| 319 |
|
| 320 |
class CommonQueryPostParams:
|
| 321 |
def __init__(
|
|
@@ -341,6 +370,7 @@ def s(port: int, log_level: str) -> None:
|
|
| 341 |
default=10, ge=0, description="Alpha Matting (Erode Structure Size)"
|
| 342 |
),
|
| 343 |
om: bool = Form(default=False, description="Only Mask"),
|
|
|
|
| 344 |
):
|
| 345 |
self.model = model
|
| 346 |
self.a = a
|
|
@@ -348,6 +378,7 @@ def s(port: int, log_level: str) -> None:
|
|
| 348 |
self.ab = ab
|
| 349 |
self.ae = ae
|
| 350 |
self.om = om
|
|
|
|
| 351 |
|
| 352 |
def im_without_bg(content: bytes, commons: CommonQueryParams) -> Response:
|
| 353 |
return Response(
|
|
@@ -361,10 +392,19 @@ def s(port: int, log_level: str) -> None:
|
|
| 361 |
alpha_matting_background_threshold=commons.ab,
|
| 362 |
alpha_matting_erode_size=commons.ae,
|
| 363 |
only_mask=commons.om,
|
|
|
|
| 364 |
),
|
| 365 |
media_type="image/png",
|
| 366 |
)
|
| 367 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 368 |
@app.get(
|
| 369 |
path="/",
|
| 370 |
tags=["Background Removal"],
|
|
|
|
| 33 |
"-m",
|
| 34 |
"--model",
|
| 35 |
default="u2net",
|
| 36 |
+
type=click.Choice(
|
| 37 |
+
["u2net", "u2netp", "u2net_human_seg", "u2net_cloth_seg", "silueta"]
|
| 38 |
+
),
|
| 39 |
show_default=True,
|
| 40 |
show_choices=True,
|
| 41 |
help="model name",
|
|
|
|
| 78 |
show_default=True,
|
| 79 |
help="output only the mask",
|
| 80 |
)
|
| 81 |
+
@click.option(
|
| 82 |
+
"-ppm",
|
| 83 |
+
"--post-process-mask",
|
| 84 |
+
is_flag=True,
|
| 85 |
+
show_default=True,
|
| 86 |
+
help="post process the mask",
|
| 87 |
+
)
|
| 88 |
@click.argument(
|
| 89 |
"input", default=(None if sys.stdin.isatty() else "-"), type=click.File("rb")
|
| 90 |
)
|
|
|
|
| 102 |
"-m",
|
| 103 |
"--model",
|
| 104 |
default="u2net",
|
| 105 |
+
type=click.Choice(
|
| 106 |
+
["u2net", "u2netp", "u2net_human_seg", "u2net_cloth_seg", "silueta"]
|
| 107 |
+
),
|
| 108 |
show_default=True,
|
| 109 |
show_choices=True,
|
| 110 |
help="model name",
|
|
|
|
| 147 |
show_default=True,
|
| 148 |
help="output only the mask",
|
| 149 |
)
|
| 150 |
+
@click.option(
|
| 151 |
+
"-ppm",
|
| 152 |
+
"--post-process-mask",
|
| 153 |
+
is_flag=True,
|
| 154 |
+
show_default=True,
|
| 155 |
+
help="post process the mask",
|
| 156 |
+
)
|
| 157 |
@click.option(
|
| 158 |
"-w",
|
| 159 |
"--watch",
|
|
|
|
| 261 |
show_default=True,
|
| 262 |
help="log level",
|
| 263 |
)
|
| 264 |
+
@click.option(
|
| 265 |
+
"-t",
|
| 266 |
+
"--threads",
|
| 267 |
+
default=None,
|
| 268 |
+
type=int,
|
| 269 |
+
show_default=True,
|
| 270 |
+
help="number of worker threads",
|
| 271 |
+
)
|
| 272 |
+
def s(port: int, log_level: str, threads: int) -> None:
|
| 273 |
sessions: dict[str, BaseSession] = {}
|
| 274 |
tags_metadata = [
|
| 275 |
{
|
|
|
|
| 310 |
u2netp = "u2netp"
|
| 311 |
u2net_human_seg = "u2net_human_seg"
|
| 312 |
u2net_cloth_seg = "u2net_cloth_seg"
|
| 313 |
+
silueta = "silueta"
|
| 314 |
|
| 315 |
class CommonQueryParams:
|
| 316 |
def __init__(
|
|
|
|
| 336 |
default=10, ge=0, description="Alpha Matting (Erode Structure Size)"
|
| 337 |
),
|
| 338 |
om: bool = Query(default=False, description="Only Mask"),
|
| 339 |
+
ppm: bool = Query(default=False, description="Post Process Mask"),
|
| 340 |
):
|
| 341 |
self.model = model
|
| 342 |
self.a = a
|
|
|
|
| 344 |
self.ab = ab
|
| 345 |
self.ae = ae
|
| 346 |
self.om = om
|
| 347 |
+
self.ppm = ppm
|
| 348 |
|
| 349 |
class CommonQueryPostParams:
|
| 350 |
def __init__(
|
|
|
|
| 370 |
default=10, ge=0, description="Alpha Matting (Erode Structure Size)"
|
| 371 |
),
|
| 372 |
om: bool = Form(default=False, description="Only Mask"),
|
| 373 |
+
ppm: bool = Form(default=False, description="Post Process Mask"),
|
| 374 |
):
|
| 375 |
self.model = model
|
| 376 |
self.a = a
|
|
|
|
| 378 |
self.ab = ab
|
| 379 |
self.ae = ae
|
| 380 |
self.om = om
|
| 381 |
+
self.ppm = ppm
|
| 382 |
|
| 383 |
def im_without_bg(content: bytes, commons: CommonQueryParams) -> Response:
|
| 384 |
return Response(
|
|
|
|
| 392 |
alpha_matting_background_threshold=commons.ab,
|
| 393 |
alpha_matting_erode_size=commons.ae,
|
| 394 |
only_mask=commons.om,
|
| 395 |
+
post_process_mask=commons.ppm,
|
| 396 |
),
|
| 397 |
media_type="image/png",
|
| 398 |
)
|
| 399 |
|
| 400 |
+
@app.on_event("startup")
|
| 401 |
+
def startup():
|
| 402 |
+
if threads is not None:
|
| 403 |
+
from anyio import CapacityLimiter
|
| 404 |
+
from anyio.lowlevel import RunVar
|
| 405 |
+
|
| 406 |
+
RunVar("_default_thread_limiter").set(CapacityLimiter(threads))
|
| 407 |
+
|
| 408 |
@app.get(
|
| 409 |
path="/",
|
| 410 |
tags=["Background Removal"],
|
rembg/session_base.py
CHANGED
|
@@ -18,7 +18,7 @@ class BaseSession:
|
|
| 18 |
std: Tuple[float, float, float],
|
| 19 |
size: Tuple[int, int],
|
| 20 |
) -> Dict[str, np.ndarray]:
|
| 21 |
-
im = img.convert("RGB").resize(size, Image.LANCZOS)
|
| 22 |
|
| 23 |
im_ary = np.array(im)
|
| 24 |
im_ary = im_ary / np.max(im_ary)
|
|
|
|
| 18 |
std: Tuple[float, float, float],
|
| 19 |
size: Tuple[int, int],
|
| 20 |
) -> Dict[str, np.ndarray]:
|
| 21 |
+
im = img.convert("RGB").resize(size, Image.Resampling.LANCZOS)
|
| 22 |
|
| 23 |
im_ary = np.array(im)
|
| 24 |
im_ary = im_ary / np.max(im_ary)
|
rembg/session_factory.py
CHANGED
|
@@ -5,50 +5,56 @@ from contextlib import redirect_stdout
|
|
| 5 |
from pathlib import Path
|
| 6 |
from typing import Type
|
| 7 |
|
| 8 |
-
import gdown
|
| 9 |
import onnxruntime as ort
|
|
|
|
| 10 |
|
| 11 |
from .session_base import BaseSession
|
| 12 |
from .session_cloth import ClothSession
|
| 13 |
from .session_simple import SimpleSession
|
| 14 |
|
| 15 |
|
| 16 |
-
def new_session(model_name: str) -> BaseSession:
|
| 17 |
session_class: Type[BaseSession]
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
if model_name == "u2netp":
|
| 20 |
md5 = "8e83ca70e441ab06c318d82300c84806"
|
| 21 |
-
url =
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
md5 = "60024c5c889badc19c04ad937298a77b"
|
| 25 |
-
url = "https://drive.google.com/uc?id=1tCU5MM1LhRgGou5OpmpjBQbSrYIUoYab"
|
| 26 |
session_class = SimpleSession
|
| 27 |
elif model_name == "u2net_human_seg":
|
| 28 |
md5 = "c09ddc2e0104f800e3e1bb4652583d1f"
|
| 29 |
-
url = "https://
|
| 30 |
session_class = SimpleSession
|
| 31 |
elif model_name == "u2net_cloth_seg":
|
| 32 |
md5 = "2434d1f3cb744e0e49386c906e5a08bb"
|
| 33 |
-
url = "https://
|
| 34 |
session_class = ClothSession
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
|
|
|
| 38 |
)
|
|
|
|
| 39 |
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
gdown.download(url, str(path), use_cookies=False)
|
| 52 |
|
| 53 |
sess_opts = ort.SessionOptions()
|
| 54 |
|
|
@@ -58,6 +64,8 @@ def new_session(model_name: str) -> BaseSession:
|
|
| 58 |
return session_class(
|
| 59 |
model_name,
|
| 60 |
ort.InferenceSession(
|
| 61 |
-
str(
|
|
|
|
|
|
|
| 62 |
),
|
| 63 |
)
|
|
|
|
| 5 |
from pathlib import Path
|
| 6 |
from typing import Type
|
| 7 |
|
|
|
|
| 8 |
import onnxruntime as ort
|
| 9 |
+
import pooch
|
| 10 |
|
| 11 |
from .session_base import BaseSession
|
| 12 |
from .session_cloth import ClothSession
|
| 13 |
from .session_simple import SimpleSession
|
| 14 |
|
| 15 |
|
| 16 |
+
def new_session(model_name: str = "u2net") -> BaseSession:
|
| 17 |
session_class: Type[BaseSession]
|
| 18 |
+
md5 = "60024c5c889badc19c04ad937298a77b"
|
| 19 |
+
url = "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net.onnx"
|
| 20 |
+
session_class = SimpleSession
|
| 21 |
|
| 22 |
if model_name == "u2netp":
|
| 23 |
md5 = "8e83ca70e441ab06c318d82300c84806"
|
| 24 |
+
url = (
|
| 25 |
+
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2netp.onnx"
|
| 26 |
+
)
|
|
|
|
|
|
|
| 27 |
session_class = SimpleSession
|
| 28 |
elif model_name == "u2net_human_seg":
|
| 29 |
md5 = "c09ddc2e0104f800e3e1bb4652583d1f"
|
| 30 |
+
url = "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_human_seg.onnx"
|
| 31 |
session_class = SimpleSession
|
| 32 |
elif model_name == "u2net_cloth_seg":
|
| 33 |
md5 = "2434d1f3cb744e0e49386c906e5a08bb"
|
| 34 |
+
url = "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_cloth_seg.onnx"
|
| 35 |
session_class = ClothSession
|
| 36 |
+
elif model_name == "silueta":
|
| 37 |
+
md5 = "55e59e0d8062d2f5d013f4725ee84782"
|
| 38 |
+
url = (
|
| 39 |
+
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/silueta.onnx"
|
| 40 |
)
|
| 41 |
+
session_class = SimpleSession
|
| 42 |
|
| 43 |
+
u2net_home = os.getenv(
|
| 44 |
+
"U2NET_HOME", os.path.join(os.getenv("XDG_DATA_HOME", "~"), ".u2net")
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
+
fname = f"{model_name}.onnx"
|
| 48 |
+
path = Path(u2net_home).expanduser()
|
| 49 |
+
full_path = Path(u2net_home).expanduser() / fname
|
| 50 |
|
| 51 |
+
pooch.retrieve(
|
| 52 |
+
url,
|
| 53 |
+
f"md5:{md5}",
|
| 54 |
+
fname=fname,
|
| 55 |
+
path=Path(u2net_home).expanduser(),
|
| 56 |
+
progressbar=True,
|
| 57 |
+
)
|
|
|
|
| 58 |
|
| 59 |
sess_opts = ort.SessionOptions()
|
| 60 |
|
|
|
|
| 64 |
return session_class(
|
| 65 |
model_name,
|
| 66 |
ort.InferenceSession(
|
| 67 |
+
str(full_path),
|
| 68 |
+
providers=ort.get_available_providers(),
|
| 69 |
+
sess_options=sess_opts,
|
| 70 |
),
|
| 71 |
)
|
rembg/session_simple.py
CHANGED
|
@@ -25,6 +25,6 @@ class SimpleSession(BaseSession):
|
|
| 25 |
pred = np.squeeze(pred)
|
| 26 |
|
| 27 |
mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
|
| 28 |
-
mask = mask.resize(img.size, Image.LANCZOS)
|
| 29 |
|
| 30 |
return [mask]
|
|
|
|
| 25 |
pred = np.squeeze(pred)
|
| 26 |
|
| 27 |
mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
|
| 28 |
+
mask = mask.resize(img.size, Image.Resampling.LANCZOS)
|
| 29 |
|
| 30 |
return [mask]
|