OmniSVG-3B / app.py
OmniSVG's picture
Update app.py
43d5733 verified
raw
history blame
57.5 kB
import gradio as gr
import torch
import os
from PIL import Image
import cairosvg
import io
import tempfile
import argparse
import gc
import yaml
import glob
import numpy as np
import time
import threading
import copy
import spaces
from huggingface_hub import hf_hub_download, snapshot_download
from decoder import SketchDecoder
from transformers import AutoTokenizer, AutoProcessor
from qwen_vl_utils import process_vision_info
from tokenizer import SVGTokenizer
# ============================================================
# Configuration Loading with Variant Support
# ============================================================
def load_config(config_path: str, variant: str = None) -> dict:
    """
    Load config file and merge variant-specific settings.

    The top-level keys of the YAML file form the base configuration. The
    mapping ``variants[<variant>]`` is then overlaid on it: when both the base
    value and the variant value are dicts they are merged one level deep
    (variant keys win); any other value is replaced outright.

    Args:
        config_path: Path to the config.yaml file.
        variant: Model variant ("8B" or "4B"). If None, uses
            ``default_variant`` from the config, falling back to "8B".

    Returns:
        Merged configuration dictionary. The ``'variants'`` table is dropped
        and the chosen variant name is stored under ``'active_variant'``.

    Raises:
        ValueError: If the requested variant is not listed under ``variants``.
    """
    with open(config_path, 'r', encoding='utf-8') as f:
        # An empty YAML document parses to None; treat it as an empty mapping so
        # the caller gets the descriptive "Unknown model variant" error below
        # rather than an AttributeError.
        raw_config = yaml.safe_load(f) or {}

    # Determine which variant to use
    if variant is None:
        variant = raw_config.get('default_variant', '8B')

    # `variants:` written with no body also parses to None.
    variants = raw_config.get('variants') or {}
    if variant not in variants:
        available = list(variants.keys())
        raise ValueError(f"Unknown model variant '{variant}'. Available variants: {available}")

    # Start with a copy of the raw config, excluding the 'variants' table.
    merged_config = {k: v for k, v in raw_config.items() if k != 'variants'}

    # Overlay the variant-specific settings (an empty/null variant entry adds nothing).
    for key, value in (variants[variant] or {}).items():
        base = merged_config.get(key)
        if isinstance(value, dict) and isinstance(base, dict):
            # One-level merge for nested sections.
            merged_config[key] = {**base, **value}
        else:
            merged_config[key] = value

    # Store the active variant name
    merged_config['active_variant'] = variant
    return merged_config
def write_variant_config(config: dict, output_path: str):
    """
    Dump a merged config to YAML for SVGTokenizer.

    The bookkeeping keys used only by the variant machinery ('variants',
    'active_variant', 'default_variant') are left out; every other key is
    written in block (non-flow) style.

    Args:
        config: Merged configuration dictionary (as produced by load_config).
        output_path: Destination file; it is created or overwritten.
    """
    excluded = ('variants', 'active_variant', 'default_variant')
    tokenizer_config = {}
    for name, setting in config.items():
        if name not in excluded:
            tokenizer_config[name] = setting

    with open(output_path, 'w') as handle:
        yaml.safe_dump(tokenizer_config, handle, default_flow_style=False)
# ============================================================
# Global Variables
# ============================================================
# Inference device: first CUDA GPU when available, otherwise CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# bfloat16 when the GPU supports it, float16 otherwise. The availability check
# comes first because on some torch releases torch.cuda.is_bf16_supported()
# queries the current CUDA device and raises when no GPU is present.
DTYPE = (
    torch.bfloat16
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    else torch.float16
)

# Global models, assigned by load_models().
tokenizer = None
processor = None
sketch_decoder = None
svg_tokenizer = None

# Global config (set elsewhere once the config file has been loaded)
config = None
MODEL_VARIANT = None

# Serialises calls to sketch_decoder.transformer.generate() across requests.
generation_lock = threading.Lock()

# Fixed prompt and accepted upload formats. The config-driven constants are
# assigned separately by init_config_constants().
SYSTEM_PROMPT = """You are an expert SVG code generator.
Generate precise, valid SVG path commands that accurately represent the described scene or object.
Focus on capturing key shapes, spatial relationships, and visual composition."""
SUPPORTED_FORMATS = ['.png', '.jpg', '.jpeg', '.webp', '.bmp', '.gif']
def init_config_constants(cfg: dict):
    """
    Initialize the module-level constants derived from the merged config.

    Every setting is read from its config section with a built-in fallback,
    so an empty dict yields the defaults below. Sections that are missing or
    set to null in the YAML are treated as empty. Per-task sampling presets
    are overlaid key by key on the built-in defaults, so a preset that
    overrides only some parameters still provides all four ``default_*`` keys.

    Args:
        cfg: Merged configuration dictionary (see load_config).
    """
    global TARGET_IMAGE_SIZE, RENDER_SIZE, BACKGROUND_THRESHOLD
    global EMPTY_THRESHOLD_ILLUSTRATION, EMPTY_THRESHOLD_ICON
    global EDGE_SAMPLE_RATIO, COLOR_SIMILARITY_THRESHOLD, MIN_EDGE_SAMPLES
    global BLACK_COLOR_TOKEN, BOS_TOKEN_ID, EOS_TOKEN_ID, PAD_TOKEN_ID, MAX_LENGTH
    global DEFAULT_QWEN_MODEL, DEFAULT_OMNISVG_MODEL
    global TASK_CONFIGS, DEFAULT_NUM_CANDIDATES, MAX_NUM_CANDIDATES, EXTRA_CANDIDATES_BUFFER
    global MIN_SVG_LENGTH

    def section(name):
        # A YAML section written as `name:` with no body parses to None;
        # treat it like a missing section instead of failing on `.get`.
        return cfg.get(name) or {}

    # Image processing settings
    image_config = section('image')
    TARGET_IMAGE_SIZE = image_config.get('target_size', 448)
    RENDER_SIZE = image_config.get('render_size', 512)
    BACKGROUND_THRESHOLD = image_config.get('background_threshold', 240)
    EMPTY_THRESHOLD_ILLUSTRATION = image_config.get('empty_threshold_illustration', 250)
    EMPTY_THRESHOLD_ICON = image_config.get('empty_threshold_icon', 252)
    EDGE_SAMPLE_RATIO = image_config.get('edge_sample_ratio', 0.1)
    COLOR_SIMILARITY_THRESHOLD = image_config.get('color_similarity_threshold', 30)
    MIN_EDGE_SAMPLES = image_config.get('min_edge_samples', 10)

    # Color settings: unless set explicitly, the black token sits two past
    # color_token_start.
    colors_config = section('colors')
    BLACK_COLOR_TOKEN = colors_config.get(
        'black_color_token', colors_config.get('color_token_start', 40010) + 2
    )

    # Model settings
    model_config = section('model')
    BOS_TOKEN_ID = model_config.get('bos_token_id', 196998)
    EOS_TOKEN_ID = model_config.get('eos_token_id', 196999)
    PAD_TOKEN_ID = model_config.get('pad_token_id', 151643)
    MAX_LENGTH = model_config.get('max_length', 1536)

    # HuggingFace model IDs
    hf_config = section('huggingface')
    DEFAULT_QWEN_MODEL = hf_config.get('qwen_model', "Qwen/Qwen2.5-VL-7B-Instruct")
    DEFAULT_OMNISVG_MODEL = hf_config.get('omnisvg_model', "OmniSVG/OmniSVG1.1_8B")

    # Task configurations: (TASK_CONFIGS key, config key, built-in defaults).
    task_presets = (
        ("text-to-svg-icon", "text_to_svg_icon", {
            "default_temperature": 0.5,
            "default_top_p": 0.88,
            "default_top_k": 50,
            "default_repetition_penalty": 1.05,
        }),
        ("text-to-svg-illustration", "text_to_svg_illustration", {
            "default_temperature": 0.6,
            "default_top_p": 0.90,
            "default_top_k": 60,
            "default_repetition_penalty": 1.03,
        }),
        ("image-to-svg", "image_to_svg", {
            "default_temperature": 0.3,
            "default_top_p": 0.90,
            "default_top_k": 50,
            "default_repetition_penalty": 1.05,
        }),
    )
    task_config = section('task_configs')
    # Overlay configured presets on the defaults so a partial preset (e.g. only
    # default_temperature) still defines every default_* key.
    TASK_CONFIGS = {
        task_name: {**defaults, **(task_config.get(config_key) or {})}
        for task_name, config_key, defaults in task_presets
    }

    # Generation parameters
    gen_config = section('generation')
    DEFAULT_NUM_CANDIDATES = gen_config.get('default_num_candidates', 4)
    MAX_NUM_CANDIDATES = gen_config.get('max_num_candidates', 8)
    EXTRA_CANDIDATES_BUFFER = gen_config.get('extra_candidates_buffer', 4)

    # Validation settings
    validation_config = section('validation')
    MIN_SVG_LENGTH = validation_config.get('min_svg_length', 20)
# Custom CSS for the Gradio page. It also styles the classes used by the
# TIPS_HTML / IMAGE_TIPS_HTML snippets (tips-box, tip-category, example-prompt,
# red-tip, red-box, orange-box, green-box). Presumably passed to the Gradio
# Blocks constructor further down the file — not visible in this view.
CUSTOM_CSS = """
/* Main container centering */
.gradio-container {
max-width: 1400px !important;
margin: 0 auto !important;
padding: 20px !important;
}
/* Header styling */
.header-container {
text-align: center;
margin-bottom: 20px;
padding: 20px;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
border-radius: 16px;
color: white;
}
.header-container h1 {
margin: 0;
font-size: 2.5em;
font-weight: 700;
}
.header-container p {
margin: 10px 0 0 0;
opacity: 0.9;
font-size: 1.1em;
}
/* Model badge styling */
.model-badge {
display: inline-block;
background: rgba(255,255,255,0.2);
padding: 4px 12px;
border-radius: 20px;
font-size: 0.85em;
margin-top: 8px;
}
/* Tips section */
.tips-box {
background: #f8f9fa;
border-radius: 12px;
padding: 20px;
margin-bottom: 20px;
border: 1px solid #e0e0e0;
}
.tips-box h3 {
margin-top: 0;
color: #333;
border-bottom: 2px solid #667eea;
padding-bottom: 10px;
}
.tip-category {
background: white;
border-radius: 8px;
padding: 15px;
margin: 10px 0;
border-left: 4px solid #667eea;
}
.tip-category h4 {
margin: 0 0 10px 0;
color: #667eea;
}
.tip-category code {
background: #f0f0f0;
padding: 2px 6px;
border-radius: 4px;
font-size: 0.9em;
}
.example-prompt {
background: #e8f4fd;
padding: 10px;
border-radius: 6px;
margin: 8px 0;
font-style: italic;
font-size: 0.95em;
color: #333;
}
.red-tip {
color: #dc3545;
font-weight: 600;
}
.red-box {
background: #fff5f5;
border: 1px solid #ffcccc;
border-left: 4px solid #dc3545;
padding: 12px;
border-radius: 8px;
margin: 10px 0;
}
.red-box strong {
color: #dc3545;
}
.orange-box {
background: #fff8e6;
border: 1px solid #ffc107;
border-left: 4px solid #ff9800;
padding: 12px;
border-radius: 8px;
margin: 10px 0;
}
.orange-box strong {
color: #ff9800;
}
.green-box {
background: #e8f5e9;
border: 1px solid #81c784;
border-left: 4px solid #4caf50;
padding: 12px;
border-radius: 8px;
margin: 10px 0;
}
.green-box strong {
color: #4caf50;
}
/* Tab styling */
.tabs {
border-radius: 12px !important;
overflow: hidden;
}
.tabitem {
padding: 20px !important;
}
/* Button styling */
.primary-btn {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
border: none !important;
font-weight: 600 !important;
padding: 12px 24px !important;
font-size: 1.1em !important;
}
.primary-btn:hover {
transform: translateY(-2px);
box-shadow: 0 4px 12px rgba(102, 126, 234, 0.4);
}
/* Settings group */
.settings-group {
background: #f8f9fa;
border-radius: 10px;
padding: 15px;
margin: 10px 0;
}
.advanced-settings {
background: #f0f4f8;
border-radius: 8px;
padding: 12px;
margin-top: 10px;
}
/* Code output */
.code-output textarea {
font-family: 'Monaco', 'Menlo', 'Ubuntu Mono', monospace !important;
font-size: 12px !important;
background: #1e1e1e !important;
color: #d4d4d4 !important;
border-radius: 8px !important;
}
/* Input image area */
.input-image {
border: 2px dashed #ccc;
border-radius: 12px;
transition: border-color 0.3s;
}
.input-image:hover {
border-color: #667eea;
}
/* Footer */
.footer {
text-align: center;
padding: 20px;
color: #666;
font-size: 0.9em;
}
/* Responsive adjustments */
@media (max-width: 768px) {
.gradio-container {
padding: 10px !important;
}
.header-container h1 {
font-size: 1.8em;
}
}
"""
# Static HTML prompting guide contains four parts:
# - critical tips
# - a sampling-parameter table
# - per-category example prompts
# - troubleshooting
# It relies on the tips-box / tip-category / example-prompt / red-tip / red-box /
# orange-box / green-box classes defined in CUSTOM_CSS. The suggested
# parameter ranges are advisory text only; the actual defaults live in
# TASK_CONFIGS.
TIPS_HTML = """
<div class="tips-box">
<h3>Prompting Guide & Best Practices</h3>
<!-- Critical Red Tips Section -->
<div class="red-box">
<strong>CRITICAL: Tips That WILL Improve Your Results</strong>
<ul style="margin: 8px 0 0 0; padding-left: 20px;">
<li style="color: #dc3545; font-weight: 600;">
<strong>Generate 4-8 candidates and pick the best one!</strong> Results vary significantly between generations - this is NORMAL!
</li>
<li style="color: #dc3545; font-weight: 600;">
<strong>Use GEOMETRIC descriptions:</strong> "triangular roof", "circular head", "rectangular body", "curved tail"
</li>
<li style="color: #dc3545; font-weight: 600;">
<strong>ALWAYS specify colors for EACH element:</strong> "black outline", "red roof", "blue shirt", "green grass"
</li>
<li style="color: #dc3545; font-weight: 600;">
<strong>Describe position & orientation:</strong> "centrally positioned", "pointing upward", "facing right", "at the bottom"
</li>
<li style="color: #dc3545; font-weight: 600;">
<strong>Keep it SIMPLE:</strong> Avoid complex sentences. Use short, clear phrases connected by commas.
</li>
</ul>
</div>
<!-- Parameter Tuning Tips -->
<div class="orange-box">
<strong>Parameter Tuning Guide</strong>
<table style="width: 100%; margin-top: 10px; border-collapse: collapse;">
<tr style="background: rgba(255,255,255,0.5);">
<th style="padding: 8px; text-align: left; border-bottom: 1px solid #ddd;">Scenario</th>
<th style="padding: 8px; text-align: center; border-bottom: 1px solid #ddd;">Temperature</th>
<th style="padding: 8px; text-align: center; border-bottom: 1px solid #ddd;">Top-P</th>
<th style="padding: 8px; text-align: center; border-bottom: 1px solid #ddd;">Top-K</th>
<th style="padding: 8px; text-align: center; border-bottom: 1px solid #ddd;">Rep. Penalty</th>
</tr>
<tr>
<td style="padding: 8px;">Simple icons/shapes</td>
<td style="padding: 8px; text-align: center;">0.3 - 0.5</td>
<td style="padding: 8px; text-align: center;">0.85 - 0.90</td>
<td style="padding: 8px; text-align: center;">40 - 50</td>
<td style="padding: 8px; text-align: center;">1.05</td>
</tr>
<tr style="background: rgba(255,255,255,0.3);">
<td style="padding: 8px;">Characters/Avatars</td>
<td style="padding: 8px; text-align: center;">0.5 - 0.7</td>
<td style="padding: 8px; text-align: center;">0.88 - 0.92</td>
<td style="padding: 8px; text-align: center;">50 - 70</td>
<td style="padding: 8px; text-align: center;">1.02 - 1.05</td>
</tr>
<tr>
<td style="padding: 8px;">Landscapes/Scenes</td>
<td style="padding: 8px; text-align: center;">0.5 - 0.7</td>
<td style="padding: 8px; text-align: center;">0.88 - 0.92</td>
<td style="padding: 8px; text-align: center;">50 - 60</td>
<td style="padding: 8px; text-align: center;">1.03</td>
</tr>
<tr style="background: rgba(255,255,255,0.3);">
<td style="padding: 8px;">Image-to-SVG</td>
<td style="padding: 8px; text-align: center;">0.2 - 0.4</td>
<td style="padding: 8px; text-align: center;">0.88 - 0.92</td>
<td style="padding: 8px; text-align: center;">40 - 50</td>
<td style="padding: 8px; text-align: center;">1.05</td>
</tr>
</table>
<p style="margin: 10px 0 0 0; font-size: 0.9em; color: #856404;">
Tip: If results are too chaotic, lower temperature. If too simple/empty, raise it slightly.
</p>
</div>
<div style="display: grid; grid-template-columns: repeat(auto-fit, minmax(300px, 1fr)); gap: 15px; margin-top: 15px;">
<div class="tip-category">
<h4>Icons & Simple Shapes</h4>
<p>Use clear geometric descriptions with explicit colors.</p>
<div class="example-prompt">
"A black triangle pointing downward, centrally positioned."
</div>
<div class="example-prompt">
"A red heart shape with smooth curved edges, centered."
</div>
<p><strong>Keywords:</strong> <code>triangle</code> <code>circle</code> <code>arrow</code> <code>heart</code> <code>star</code> <code>centered</code></p>
</div>
<div class="tip-category">
<h4>Characters & People</h4>
<p>Break down into simple geometric parts. Describe each body part with shape + color.</p>
<div class="example-prompt">
"A simple person: round beige head, rectangular blue shirt body, two dark gray rectangular legs. Standing pose, arms at sides, flat colors."
</div>
<p class="red-tip">Keep poses SIMPLE: standing, sitting, waving. Avoid complex actions!</p>
</div>
<div class="tip-category">
<h4>Landscapes & Scenes</h4>
<p>Layer elements from background to foreground. Specify color for EACH layer.</p>
<div class="example-prompt">
"Layered landscape: light blue sky at top, gray triangular mountains in middle, dark green triangular pine trees at bottom. Flat colors, simple shapes."
</div>
<p class="red-tip">Use geometric shapes for nature: triangular trees, wavy water, semicircle sun!</p>
</div>
<div class="tip-category">
<h4>Animals</h4>
<p>Describe as geometric shapes: oval body, round head, triangular ears, curved tail.</p>
<div class="example-prompt">
"Cute cat: orange round head with two triangular ears, oval orange body, curved tail. Simple cartoon style with black outlines, sitting pose."
</div>
</div>
</div>
<!-- Quick Troubleshooting -->
<div class="green-box" style="margin-top: 15px;">
<strong>Quick Troubleshooting</strong>
<ul style="margin: 8px 0 0 0; padding-left: 20px;">
<li><strong>Messy/chaotic?</strong> Lower temperature to 0.3-0.4, simplify description, reduce top_k</li>
<li><strong>Too simple/empty?</strong> Raise temperature to 0.5-0.6, add more shape details</li>
<li><strong>Wrong colors?</strong> Explicitly name EVERY color: "red roof", "blue shirt", "black outline"</li>
<li><strong>Missing elements?</strong> Add position words: "at top", "in center", "at bottom left"</li>
<li><strong>Repetitive patterns?</strong> Increase repetition_penalty to 1.08-1.15</li>
<li><strong>Inconsistent?</strong> <span class="red-tip">Generate MORE candidates (6-8) and pick the best!</span></li>
</ul>
</div>
</div>
"""
# Static HTML tips for image-to-SVG input, styled by the red-box class in
# CUSTOM_CSS. The "Replace Background" option it mentions presumably maps to
# the replace_background argument of preprocess_image_for_svg() — confirm
# against the UI wiring.
IMAGE_TIPS_HTML = """
<div class="red-box">
<strong>Image-to-SVG Tips</strong>
<ul style="margin: 8px 0 0 0; padding-left: 20px;">
<li><strong>Best input: Simple images with clean background</strong></li>
<li><strong>PNG with transparency (RGBA) works best!</strong> We auto-convert to white background.</li>
<li><strong>For complex backgrounds:</strong> Enable "Replace Background" option below.</li>
<li><strong>Lower temperature (0.2-0.4)</strong> for more accurate reproduction.</li>
<li style="color: #dc3545; font-weight: 600;"><strong>Generate 4-8 candidates!</strong> Pick the one that best matches your input.</li>
</ul>
</div>
"""
def parse_args():
    """
    Parse the SVG generator service's command-line options from sys.argv.

    Returns:
        argparse.Namespace with the attributes listen, port, share, debug,
        model_size, config, weight_path and model_path.
    """
    cli = argparse.ArgumentParser(description='SVG Generator Service')

    # (flag, add_argument keyword options) in registration order.
    option_specs = [
        ('--listen', dict(type=str, default='0.0.0.0')),
        ('--port', dict(type=int, default=7860)),
        ('--share', dict(action='store_true')),
        ('--debug', dict(action='store_true')),
        ('--model_size', dict(
            type=str, default=None, choices=['8B', '4B'],
            help='Model size variant to use (8B or 4B). Overrides config default.',
        )),
        ('--config', dict(
            type=str, default='./config.yaml',
            help='Path to config file (default: ./config.yaml)',
        )),
        ('--weight_path', dict(
            type=str, default=None,
            help='HuggingFace repo ID or local path for OmniSVG weights (overrides config)',
        )),
        ('--model_path', dict(
            type=str, default=None,
            help='HuggingFace repo ID or local path for Qwen model (overrides config)',
        )),
    ]
    for flag, options in option_specs:
        cli.add_argument(flag, **options)

    return cli.parse_args()
def download_model_weights(repo_id: str, filename: str = "pytorch_model.bin") -> str:
    """
    Download a single file from a Hugging Face Hub model repository.

    Args:
        repo_id: Hub repository ID, e.g. "OmniSVG/OmniSVG1.1_8B".
        filename: Name of the file inside the repo to fetch.

    Returns:
        Local filesystem path of the downloaded file.

    Raises:
        Whatever hf_hub_download raises; the error is printed before being
        re-raised unchanged.
    """
    print(f"Downloading {filename} from {repo_id}...")
    try:
        # resume_download is intentionally not passed: huggingface_hub has
        # deprecated it (interrupted downloads always resume) and emits a
        # FutureWarning when it is supplied.
        local_path = hf_hub_download(repo_id=repo_id, filename=filename)
    except Exception as e:
        print(f"Error downloading from {repo_id}: {e}")
        raise
    print(f"Successfully downloaded to: {local_path}")
    return local_path
def is_local_path(path: str) -> bool:
    """
    Heuristically decide whether *path* names a local filesystem location
    rather than a Hugging Face Hub repo ID such as "org/model".

    A path counts as local when any of the following holds:
    - it exists (after ``~`` expansion);
    - it is absolute or explicitly relative (starts with ``/``, ``./``,
      ``../`` or ``~``);
    - it contains the OS path separator and its parent directory exists;
    - it has a Windows drive prefix such as ``C:``.

    Args:
        path: User-supplied weight/model location. Empty or None is treated
            as not local.

    Returns:
        True for local paths, False for anything that looks like a repo ID.
    """
    if not path:
        return False
    if os.path.exists(os.path.expanduser(path)):
        return True
    # "~" is not a legal character in a Hub repo ID, so "~/models/..." is local
    # even when the directory does not exist yet.
    if path.startswith(('/', './', '../', '~')):
        return True
    if os.path.sep in path and os.path.exists(os.path.dirname(path)):
        return True
    # Windows drive letter, e.g. "C:\\models" or "D:/weights".
    return len(path) > 1 and path[1] == ':'
def _resolve_weight_file(weight_path: str) -> str:
    """Return a local path to OmniSVG's pytorch_model.bin, downloading it from the Hub when weight_path is a repo ID."""
    if not is_local_path(weight_path):
        print(f"Downloading weights from HuggingFace: {weight_path}")
        return download_model_weights(weight_path, "pytorch_model.bin")

    # A local weight_path is either a directory containing pytorch_model.bin
    # or the path of a .bin file itself.
    bin_path = os.path.join(weight_path, "pytorch_model.bin")
    if not os.path.exists(bin_path):
        if os.path.exists(weight_path) and weight_path.endswith('.bin'):
            bin_path = weight_path
        else:
            raise FileNotFoundError(
                f"Could not find pytorch_model.bin at {weight_path}. "
                f"Please provide a valid local path or HuggingFace repo ID."
            )
    print(f"Loading weights from local path: {bin_path}")
    return bin_path


def load_models(weight_path: str, model_path: str, variant_config_path: str):
    """
    Load all models with support for both local paths and HuggingFace Hub.

    Assigns the module globals tokenizer, processor, sketch_decoder and
    svg_tokenizer.

    Args:
        weight_path: Local path or HuggingFace repo ID for OmniSVG weights
        model_path: Local path or HuggingFace repo ID for Qwen model
        variant_config_path: Path to the variant-specific config file for SVGTokenizer

    Raises:
        FileNotFoundError: If a local weight_path has no pytorch_model.bin.
    """
    global tokenizer, processor, sketch_decoder, svg_tokenizer

    print(f"Loading Qwen model from: {model_path}")
    print(f"Loading OmniSVG weights from: {weight_path}")
    print(f"Using precision: {DTYPE}")

    # Load Qwen tokenizer and processor (left padding for decoder-only generation).
    print("\n[1/3] Loading tokenizer and processor...")
    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        padding_side="left",
        trust_remote_code=True
    )
    processor = AutoProcessor.from_pretrained(
        model_path,
        padding_side="left",
        trust_remote_code=True
    )
    processor.tokenizer.padding_side = "left"
    print("Tokenizer and processor loaded successfully!")

    # Initialize sketch decoder
    print("\n[2/3] Initializing SketchDecoder...")
    sketch_decoder = SketchDecoder(
        pix_len=MAX_LENGTH,
        text_len=config.get('text', {}).get('max_length', 200),
        model_path=model_path,
        torch_dtype=DTYPE
    )

    # Load OmniSVG weights
    print("\n[3/3] Loading OmniSVG weights...")
    bin_path = _resolve_weight_file(weight_path)
    # weights_only=True: the checkpoint may come from the Hub, so refuse to
    # unpickle arbitrary objects. A plain tensor state dict loads unchanged.
    state_dict = torch.load(bin_path, map_location='cpu', weights_only=True)
    sketch_decoder.load_state_dict(state_dict)
    # load_state_dict copies the tensors into the module's parameters; drop
    # the CPU copy before moving the model so it is not held twice.
    del state_dict
    gc.collect()
    print("OmniSVG weights loaded successfully!")

    sketch_decoder = sketch_decoder.to(device).eval()

    # Initialize SVG tokenizer with variant-specific config
    svg_tokenizer = SVGTokenizer(variant_config_path)

    print("\n" + "="*60)
    print("All models loaded successfully!")
    print("="*60 + "\n")
def detect_text_subtype(text_prompt):
    """
    Classify a text prompt as "icon" or "illustration".

    Keywords are matched against whole words, optionally followed by a plural
    "s"/"es". For example, "sparrow" no longer triggers the icon keyword
    "arrow", and "location" no longer triggers the illustration keyword "cat".

    Rules, applied in order:
      1. any icon keyword present -> "icon"
      2. any illustration keyword present, or prompt longer than 50
         characters -> "illustration"
      3. otherwise -> "icon"

    Args:
        text_prompt: The user's text description.

    Returns:
        "icon" or "illustration".
    """
    import re

    words = set(re.findall(r"[a-z]+", text_prompt.lower()))

    def mentions_any(keywords):
        # Whole-word match, tolerating a regular plural suffix.
        return any(kw in words or kw + 's' in words or kw + 'es' in words for kw in keywords)

    icon_keywords = ['icon', 'logo', 'symbol', 'badge', 'button', 'emoji', 'glyph', 'simple',
                     'arrow', 'triangle', 'circle', 'square', 'heart', 'star', 'checkmark']
    if mentions_any(icon_keywords):
        return "icon"

    illustration_keywords = [
        'illustration', 'scene', 'person', 'people', 'character', 'man', 'woman', 'boy', 'girl',
        'avatar', 'portrait', 'face', 'head', 'body',
        'cat', 'dog', 'bird', 'animal', 'pet', 'fox', 'rabbit',
        'sitting', 'standing', 'walking', 'running', 'sleeping', 'holding', 'playing',
        'house', 'building', 'tree', 'garden', 'landscape', 'mountain', 'forest', 'city',
        'ocean', 'beach', 'sunset', 'sunrise', 'sky'
    ]
    if mentions_any(illustration_keywords) or len(text_prompt) > 50:
        return "illustration"
    return "icon"
def detect_and_replace_background(image, threshold=None, edge_sample_ratio=None):
    """
    Whiten an image's background when its border is not already near-white.

    RGBA images are simply composited onto opaque white. For any other mode:

    1. A sparse sample of border pixels is averaged. If every channel of that
       mean exceeds ``threshold``, the image is returned untouched.
    2. Otherwise, when the array has at least three channels, the most frequent
       border colour is taken as the background.
    3. Every pixel whose RGB Euclidean distance from that colour is below
       COLOR_SIMILARITY_THRESHOLD is painted white.

    Args:
        image: PIL image.
        threshold: Per-channel brightness the sampled border mean must exceed
            to count as white; defaults to BACKGROUND_THRESHOLD.
        edge_sample_ratio: Fraction of the shorter side used as the target
            number of border samples (at least MIN_EDGE_SAMPLES); defaults to
            EDGE_SAMPLE_RATIO.

    Returns:
        (image, replaced): replaced is True when an RGB image was produced by
        compositing or recolouring, False when the input is returned as is.
    """
    if threshold is None:
        threshold = BACKGROUND_THRESHOLD
    if edge_sample_ratio is None:
        edge_sample_ratio = EDGE_SAMPLE_RATIO

    pixels = np.array(image)

    if image.mode == 'RGBA':
        white = Image.new('RGBA', image.size, (255, 255, 255, 255))
        return Image.alpha_composite(white, image).convert('RGB'), True

    height, width = pixels.shape[:2]
    top, bottom = 0, height - 1
    left, right = 0, width - 1

    # Cheap check on a sparse border sample.
    n_samples = max(MIN_EDGE_SAMPLES, int(min(height, width) * edge_sample_ratio))
    sampled = []
    for col in range(0, width, max(1, width // n_samples)):
        sampled.extend((pixels[top, col], pixels[bottom, col]))
    for row in range(0, height, max(1, height // n_samples)):
        sampled.extend((pixels[row, left], pixels[row, right]))
    sampled = np.array(sampled)
    if len(sampled) > 0 and np.all(sampled.mean(axis=0) > threshold):
        return image, False

    # Colour replacement only applies to arrays with at least three channels.
    if not (pixels.ndim == 3 and pixels.shape[2] >= 3):
        return image, False

    from collections import Counter

    # Full border scan; the most common border colour is the background.
    border = []
    for col in range(width):
        border.extend((tuple(pixels[top, col, :3]), tuple(pixels[bottom, col, :3])))
    for row in range(height):
        border.extend((tuple(pixels[row, left, :3]), tuple(pixels[row, right, :3])))
    background = Counter(border).most_common(1)[0][0]

    rgb = pixels[:, :, :3].astype(float)
    distance = np.sqrt(((rgb - np.array(background)) ** 2).sum(axis=2))
    is_background = distance < COLOR_SIMILARITY_THRESHOLD

    cleaned = pixels.copy()
    fill = [255, 255, 255, 255] if cleaned.shape[2] == 4 else [255, 255, 255]
    cleaned[is_background] = fill
    return Image.fromarray(cleaned).convert('RGB'), True
def preprocess_image_for_svg(image, replace_background=True, target_size=None):
    """
    Prepare an input image for image-to-SVG generation.

    Steps:
    1. Flatten any transparency onto white.
    2. Convert to RGB.
    3. Optionally run detect_and_replace_background().
    4. Resize to a ``target_size`` x ``target_size`` square with LANCZOS
       (the aspect ratio is not preserved).

    Args:
        image: PIL image, or a filesystem path PIL can open.
        replace_background: Whether to detect and whiten a non-white background.
        target_size: Output side length in pixels; defaults to TARGET_IMAGE_SIZE.

    Returns:
        (resized RGB image, was_modified) — was_modified is True when
        transparency was flattened or the background was replaced.
    """
    if target_size is None:
        target_size = TARGET_IMAGE_SIZE

    if isinstance(image, str):
        # Load eagerly and release the file handle instead of leaving the
        # lazily-opened file open.
        with Image.open(image) as opened:
            raw_img = opened.copy()
    else:
        raw_img = image

    # Transparency comes either as an alpha band (RGBA, LA, PA, ...) or, for
    # palette/greyscale/RGB images such as GIFs and palette PNGs, as a
    # 'transparency' entry in .info. Converting the latter straight to RGB
    # would expose the hidden colour (often black) instead of white.
    has_transparency = 'A' in raw_img.getbands() or 'transparency' in raw_img.info

    if has_transparency:
        rgba = raw_img if raw_img.mode == 'RGBA' else raw_img.convert('RGBA')
        white = Image.new('RGBA', rgba.size, (255, 255, 255, 255))
        img_with_bg = Image.alpha_composite(white, rgba).convert('RGB')
        was_modified = True
    elif raw_img.mode != 'RGB':
        img_with_bg = raw_img.convert('RGB')
        was_modified = False
    else:
        img_with_bg = raw_img
        was_modified = False

    if replace_background:
        img_with_bg, bg_replaced = detect_and_replace_background(img_with_bg)
        was_modified = was_modified or bg_replaced

    img_resized = img_with_bg.resize((target_size, target_size), Image.Resampling.LANCZOS)
    return img_resized, was_modified
def prepare_inputs(task_type, content):
    """
    Build the processor inputs for a single generation request.

    Args:
        task_type: "text-to-svg" for a text prompt; any other value is handled
            as image-to-svg.
        content: The text prompt, or for image-to-svg the image reference
            passed to qwen_vl_utils.process_vision_info (this app passes a
            temporary PNG file path).

    Returns:
        The processor output (``return_tensors="pt"``, padded, truncated).
    """
    is_text_task = task_type == "text-to-svg"

    if is_text_task:
        prompt_text = str(content).strip()
        instruction = f"""Generate an SVG illustration for: {prompt_text}
Requirements:
- Create complete SVG path commands
- Include proper coordinates and colors
- Maintain visual clarity and composition"""
        user_content = [{"type": "text", "text": instruction}]
    else:
        # image-to-svg
        user_content = [
            {"type": "text", "text": "Generate SVG code that accurately represents this image:"},
            {"type": "image", "image": content},
        ]

    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_content},
    ]
    text_input = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    processor_kwargs = {
        "text": [text_input],
        "padding": True,
        "truncation": True,
        "return_tensors": "pt",
    }
    if not is_text_task:
        image_inputs, _ = process_vision_info(messages)
        processor_kwargs["images"] = image_inputs

    return processor(**processor_kwargs)
def render_svg_to_image(svg_str, size=None):
    """
    Rasterize an SVG string with cairosvg and flatten it onto white.

    Args:
        svg_str: SVG markup.
        size: Requested output width and height in pixels; defaults to
            RENDER_SIZE.

    Returns:
        An RGB PIL image, or None if rendering failed. Errors are printed,
        not raised.
    """
    size = RENDER_SIZE if size is None else size
    try:
        png_bytes = cairosvg.svg2png(
            bytestring=svg_str.encode('utf-8'),
            output_width=size,
            output_height=size,
        )
        rendered = Image.open(io.BytesIO(png_bytes)).convert("RGBA")
        alpha = rendered.split()[3]
        canvas = Image.new("RGB", rendered.size, (255, 255, 255))
        # Use the alpha band as the paste mask so transparent areas stay white.
        canvas.paste(rendered, mask=alpha)
    except Exception as e:
        print(f"Render error: {e}")
        return None
    return canvas
def create_gallery_html(candidates, cols=4):
    """
    Render SVG candidates as an HTML grid of inline SVG cards.

    Each card shows the candidate's inline SVG in a 180x180 box, labelled
    with its 1-based position and path count. An SVG without a viewBox gets
    one sized TARGET_IMAGE_SIZE so it scales into the box.

    Args:
        candidates: Dicts with at least 'svg' (markup) and 'path_count'.
        cols: Number of grid columns.

    Returns:
        HTML string, or a placeholder message when there are no candidates.
    """
    if not candidates:
        return '<div style="text-align:center;color:#999;padding:50px;">No candidates generated</div>'

    def render_card(position, candidate):
        svg_markup = candidate['svg']
        if 'viewBox' not in svg_markup:
            svg_markup = svg_markup.replace('<svg', f'<svg viewBox="0 0 {TARGET_IMAGE_SIZE} {TARGET_IMAGE_SIZE}"', 1)
        return f'''
<div style="
background: white;
border: 1px solid #ddd;
border-radius: 8px;
padding: 10px;
text-align: center;
transition: transform 0.2s, box-shadow 0.2s;
cursor: pointer;
" onmouseover="this.style.transform='scale(1.02)';this.style.boxShadow='0 4px 12px rgba(0,0,0,0.15)';"
onmouseout="this.style.transform='scale(1)';this.style.boxShadow='none';">
<div style="width: 180px; height: 180px; margin: 0 auto; display: flex; justify-content: center; align-items: center; overflow: hidden;">
{svg_markup}
</div>
<div style="margin-top: 8px; font-size: 12px; color: #666;">
#{position} | {candidate['path_count']} paths
</div>
</div>
'''

    cards = ''.join(
        render_card(position, candidate)
        for position, candidate in enumerate(candidates, start=1)
    )
    return f'''
<div style="
display: grid;
grid-template-columns: repeat({cols}, 1fr);
gap: 15px;
padding: 15px;
background: #fafafa;
border-radius: 12px;
">
{cards}
</div>
'''
def is_valid_candidate(svg_str, img, subtype="illustration"):
    """
    Decide whether a generated SVG is worth showing.

    Args:
        svg_str: Generated SVG markup.
        img: Its rendered raster (None if rendering failed).
        subtype: "illustration" selects EMPTY_THRESHOLD_ILLUSTRATION; any
            other value selects EMPTY_THRESHOLD_ICON.

    Returns:
        (is_valid, reason), where reason is one of "too_short", "no_svg_tag",
        "render_failed", "empty_image" (mean pixel value above the
        near-white threshold) or "ok".
    """
    if not svg_str or len(svg_str) < MIN_SVG_LENGTH:
        reason = "too_short"
    elif '<svg' not in svg_str:
        reason = "no_svg_tag"
    elif img is None:
        reason = "render_failed"
    else:
        brightness = np.array(img).mean()
        limit = EMPTY_THRESHOLD_ILLUSTRATION if subtype == "illustration" else EMPTY_THRESHOLD_ICON
        reason = "empty_image" if brightness > limit else "ok"
    return reason == "ok", reason
def _decode_candidate(current_ids, subtype):
    """Decode one generated token row (a 1-row slice of the batch) into (svg_str, png_image, num_paths), or None if it is not a valid SVG."""
    # Wrap the generated ids in BOS/EOS before handing them to the SVG tokenizer.
    fake_wrapper = torch.cat([
        torch.full((1, 1), BOS_TOKEN_ID, device=device),
        current_ids,
        torch.full((1, 1), EOS_TOKEN_ID, device=device)
    ], dim=1)

    generated_xy = svg_tokenizer.process_generated_tokens(fake_wrapper)
    if len(generated_xy) == 0:
        return None

    svg_tensors, color_tensors = svg_tokenizer.raster_svg(generated_xy)
    if not svg_tensors or not svg_tensors[0]:
        return None

    num_paths = len(svg_tensors[0])
    # Paths without a predicted colour fall back to black.
    while len(color_tensors) < num_paths:
        color_tensors.append(BLACK_COLOR_TOKEN)

    svg = svg_tokenizer.apply_colors_to_svg(svg_tensors[0], color_tensors)
    svg_str = svg.to_str()
    if 'width=' not in svg_str:
        svg_str = svg_str.replace('<svg', f'<svg width="{TARGET_IMAGE_SIZE}" height="{TARGET_IMAGE_SIZE}"', 1)

    png_image = render_svg_to_image(svg_str, size=RENDER_SIZE)
    is_valid, _reason = is_valid_candidate(svg_str, png_image, subtype)
    if not is_valid:
        return None
    return svg_str, png_image, num_paths


def generate_candidates(inputs, task_type, subtype, temperature, top_p, top_k, repetition_penalty,
                        max_length, num_samples, progress_callback=None):
    """
    Sample candidate SVGs from the model and keep the valid ones.

    One ``generate`` call (serialised by generation_lock) produces
    ``num_samples + EXTRA_CANDIDATES_BUFFER`` sequences. They are decoded in
    order until ``num_samples`` valid candidates have been collected.

    Args:
        inputs: Processor output from prepare_inputs().
        task_type: Task label; not used by this function, kept for interface
            compatibility.
        subtype: Forwarded to is_valid_candidate() to choose the empty-image
            threshold.
        temperature, top_p, top_k, repetition_penalty: Sampling parameters.
        max_length: Maximum number of new tokens per sequence.
        num_samples: Number of valid candidates wanted.
        progress_callback: Optional ``callback(fraction, message)``.

    Returns:
        List of dicts with keys 'svg', 'img', 'path_count' and 'index'
        (1-based); may be shorter than num_samples. Errors are printed, not
        raised: a generation failure yields whatever was collected (usually [])
        and a failing candidate is skipped.
    """
    def report(fraction, message):
        if progress_callback:
            progress_callback(fraction, message)

    input_ids = inputs['input_ids'].to(device)
    model_inputs = {
        "input_ids": input_ids,
        "attention_mask": inputs['attention_mask'].to(device),
    }
    if 'pixel_values' in inputs:
        # Image tensors are cast to the model's compute dtype.
        model_inputs["pixel_values"] = inputs['pixel_values'].to(device, dtype=DTYPE)
    if 'image_grid_thw' in inputs:
        model_inputs["image_grid_thw"] = inputs['image_grid_thw'].to(device)

    # Pure sampling settings. early_stopping is deliberately not passed: it
    # only applies to beam search, and transformers warns when it is set with
    # num_beams=1. no_repeat_ngram_size=0 was the default and is omitted too.
    gen_kwargs = {
        'do_sample': True,
        'temperature': temperature,
        'top_p': top_p,
        'top_k': int(top_k),
        'repetition_penalty': repetition_penalty,
        'eos_token_id': EOS_TOKEN_ID,
        'pad_token_id': PAD_TOKEN_ID,
        'bos_token_id': BOS_TOKEN_ID,
    }

    # Over-generate so that invalid samples can be discarded.
    actual_samples = num_samples + EXTRA_CANDIDATES_BUFFER
    all_candidates = []

    try:
        report(0.1, "Waiting for model access...")
        with generation_lock:
            report(0.15, "Generating SVG tokens...")
            with torch.no_grad():
                results = sketch_decoder.transformer.generate(
                    **model_inputs,
                    max_new_tokens=max_length,
                    num_return_sequences=actual_samples,
                    use_cache=True,
                    **gen_kwargs
                )

        # Keep only the newly generated tokens (drop the prompt prefix).
        generated_ids_batch = results[:, input_ids.shape[1]:]
        report(0.5, "Processing generated tokens...")

        total = min(actual_samples, generated_ids_batch.shape[0])
        for i in range(total):
            try:
                decoded = _decode_candidate(generated_ids_batch[i:i + 1], subtype)
            except Exception as e:
                print(f" Candidate {i} error: {e}")
                decoded = None

            if decoded is not None:
                svg_str, png_image, num_paths = decoded
                all_candidates.append({
                    'svg': svg_str,
                    'img': png_image,
                    'path_count': num_paths,
                    'index': len(all_candidates) + 1
                })

            # Advance the bar for every processed sample, valid or not.
            report(0.5 + 0.4 * ((i + 1) / total),
                   f"Found {len(all_candidates)} valid candidates...")
            if len(all_candidates) >= num_samples:
                break

    except Exception as e:
        print(f"Generation Error: {e}")
        import traceback
        traceback.print_exc()

    report(0.95, f"Generated {len(all_candidates)} valid candidates")
    return all_candidates
@spaces.GPU
def gradio_text_to_svg(text_description, num_candidates, temperature, top_p, top_k, repetition_penalty,
                       progress=gr.Progress()):
    """
    Gradio handler for text-to-SVG generation.

    Args:
        text_description: The user's prompt.
        num_candidates: Number of valid candidates wanted.
        temperature, top_p, top_k, repetition_penalty: Sampling parameters.
        progress: Gradio progress tracker.

    Returns:
        (gallery_html, combined_svg_code). On empty input, or when no valid
        candidate is produced, the first element is a placeholder message.
    """
    if not (text_description or "").strip():
        return '<div style="text-align:center;color:#999;padding:50px;">Please enter a description</div>', ""

    rule = "=" * 60
    preview = text_description[:100] + ("..." if len(text_description) > 100 else "")
    print("\n" + rule)
    print(f"[TASK] text-to-svg ({MODEL_VARIANT})")
    print(f"[INPUT] {preview}")
    print(f"[PARAMS] candidates={num_candidates}, temp={temperature}, top_p={top_p}, top_k={top_k}, rep_penalty={repetition_penalty}")
    print(rule)

    progress(0, "Starting generation...")

    # Free cached memory from previous requests before generating.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    started = time.time()

    subtype = detect_text_subtype(text_description)
    print(f"[SUBTYPE] Detected: {subtype}")
    progress(0.05, f"Detected: {subtype}")

    inputs = prepare_inputs("text-to-svg", text_description.strip())

    candidates = generate_candidates(
        inputs, "text-to-svg", subtype,
        temperature, top_p, int(top_k), repetition_penalty,
        MAX_LENGTH, int(num_candidates),
        progress_callback=lambda val, msg: progress(val, msg)
    )

    elapsed = time.time() - started
    print(f"[RESULT] Generated {len(candidates)} valid candidates in {elapsed:.2f}s")

    if not candidates:
        print("[WARNING] No valid SVG generated")
        return (
            '<div style="text-align:center;color:#999;padding:50px;">No valid SVG generated. Try different parameters or rephrase your prompt.</div>',
            f"<!-- No valid SVG (took {elapsed:.1f}s) -->"
        )

    combined_svg = "\n\n".join(
        f"<!-- ====== Candidate {number} | Paths: {cand['path_count']} ====== -->\n{cand['svg']}"
        for number, cand in enumerate(candidates, start=1)
    )
    gallery_html = create_gallery_html(candidates)

    progress(1.0, f"Done! {len(candidates)} candidates in {elapsed:.1f}s")
    print("[COMPLETE] text-to-svg finished\n")

    return gallery_html, combined_svg
@spaces.GPU
def gradio_image_to_svg(image, num_candidates, temperature, top_p, top_k, repetition_penalty,
                        replace_background, progress=gr.Progress()):
    """Gradio interface - image-to-svg with background handling.

    Args:
        image: Uploaded image (the input component is created with
            ``type="pil"``), or ``None`` when nothing was uploaded.
        num_candidates: Number of candidates to request (cast to ``int``).
        temperature: Sampling temperature forwarded to ``generate_candidates``.
        top_p: Nucleus-sampling threshold forwarded to ``generate_candidates``.
        top_k: Top-k cutoff, cast to ``int`` and forwarded to ``generate_candidates``.
        repetition_penalty: Forwarded to ``generate_candidates``.
        replace_background: Forwarded to ``preprocess_image_for_svg``.
        progress: Gradio progress tracker; also relayed to
            ``generate_candidates`` as its progress callback.

    Returns:
        A ``(gallery_html, svg_code, processed_image)`` tuple, where
        ``processed_image`` is the preprocessed image that was fed to the
        model (``None`` when no image was supplied).
    """
    if image is None:
        return (
            '<div style="text-align:center;color:#999;padding:50px;">Please upload an image</div>',
            "",
            None
        )

    print("\n" + "="*60)
    print(f"[TASK] image-to-svg ({MODEL_VARIANT})")
    print(f"[INPUT] Image size: {image.size if hasattr(image, 'size') else 'unknown'}, mode: {image.mode if hasattr(image, 'mode') else 'unknown'}")
    print(f"[PARAMS] candidates={num_candidates}, temp={temperature}, top_p={top_p}, top_k={top_k}, rep_penalty={repetition_penalty}, replace_bg={replace_background}")
    print("="*60)

    progress(0, "Processing input image...")
    # Release cached memory left over from a previous request before generating.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    start_time = time.time()
    img_processed, was_modified = preprocess_image_for_svg(
        image,
        replace_background=replace_background,
        target_size=TARGET_IMAGE_SIZE
    )
    if was_modified:
        print("[PREPROCESS] Background processed/replaced")
        progress(0.05, "Background processed")

    tmp_path = None
    try:
        # The temp file is created inside the try so it is removed even if
        # saving fails. Writing through the open handle (instead of reopening
        # tmp_file.name) also works on Windows, where a NamedTemporaryFile
        # cannot be reopened by name while it is still open. PNG is lossless,
        # so no quality option is passed.
        with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp_file:
            tmp_path = tmp_file.name
            img_processed.save(tmp_file, format='PNG')

        progress(0.1, "Preparing model inputs...")
        inputs = prepare_inputs("image-to-svg", tmp_path)

        # Positional (value, message) adapter around the Gradio tracker.
        def update_progress(val, msg):
            progress(val, msg)

        all_candidates = generate_candidates(
            inputs, "image-to-svg", "image",
            temperature, top_p, int(top_k), repetition_penalty,
            MAX_LENGTH, int(num_candidates),
            progress_callback=update_progress
        )
        elapsed = time.time() - start_time
        print(f"[RESULT] Generated {len(all_candidates)} valid candidates in {elapsed:.2f}s")

        if not all_candidates:
            print("[WARNING] No valid SVG generated")
            return (
                '<div style="text-align:center;color:#999;padding:50px;">No valid SVG generated. Try adjusting parameters.</div>',
                f"<!-- No valid SVG (took {elapsed:.1f}s) -->",
                img_processed
            )

        combined_svg = "\n\n".join(
            f"<!-- ====== Candidate {i+1} | Paths: {cand['path_count']} ====== -->\n{cand['svg']}"
            for i, cand in enumerate(all_candidates)
        )
        gallery_html = create_gallery_html(all_candidates)

        progress(1.0, f"Done! {len(all_candidates)} candidates in {elapsed:.1f}s")
        print("[COMPLETE] image-to-svg finished\n")
        return gallery_html, combined_svg, img_processed
    finally:
        if tmp_path is not None and os.path.exists(tmp_path):
            os.unlink(tmp_path)
def get_example_images():
    """Collect the bundled example images shown under the Image-to-SVG tab.

    Looks in ``./examples`` (relative to the current working directory) for
    files whose names end with one of the extensions in ``SUPPORTED_FORMATS``.

    Returns:
        A sorted list of matching paths, or an empty list when the
        directory does not exist.
    """
    examples_root = "./examples"
    if not os.path.exists(examples_root):
        return []
    found = [
        match
        for suffix in SUPPORTED_FORMATS
        for match in glob.glob(os.path.join(examples_root, f"*{suffix}"))
    ]
    return sorted(found)
def create_interface():
    """Create the Gradio Blocks UI with the Image-to-SVG and Text-to-SVG tabs.

    Slider defaults for the sampling parameters come from ``TASK_CONFIGS``
    (image tab: ``"image-to-svg"``, text tab: ``"text-to-svg-icon"``), with
    hard-coded fallbacks when a key is missing. The generate buttons are wired
    to ``gradio_image_to_svg`` and ``gradio_text_to_svg``.

    Returns:
        The constructed ``gr.Blocks`` app (not yet queued or launched).
    """
    # Example prompts
    example_texts = [
        "A black triangle pointing downward, centrally positioned.",
        "A red heart shape with smooth curved edges, centered.",
        "A yellow star with five sharp points, simple geometric design, flat color.",
        "A blue arrow pointing to the right, thick solid shape, centered.",
        "A green circle with a white checkmark inside, centered.",
        "A black plus sign with equal length arms, thick lines, centered.",
        "A simple person standing: round beige head, rectangular blue shirt body, two dark gray rectangular legs, arms at sides. Flat colors.",
        "A girl with long black hair, wearing pink dress with triangular skirt, small circular face with dot eyes and curved smile. Simple cartoon style.",
        "Circular avatar: person with short black hair, round face with two dot eyes and small curved smile, wearing blue collar shirt. Minimal style, centered in circle.",
        "Layered mountain landscape: light blue sky at top, gray triangular snow-capped mountains in middle, dark green triangular pine trees at bottom. Flat colors.",
        "Sunset beach scene: orange gradient sky at top, yellow semicircle sun on horizon, dark blue wavy ocean, tan beach strip at bottom. Simple shapes.",
        "Cute orange cat sitting: round head with two triangular ears, oval body, curved tail. Black outline cartoon style, facing forward.",
        "Simple house icon: red triangular roof, beige rectangular walls, brown door in center, two blue square windows, green ground at bottom.",
        "Coffee mug: brown cylindrical cup with curved handle on right, three wavy steam lines rising from top. Flat style.",
        "Red fox logo: triangular orange face with pointed ears, white chest marking, bushy tail. Minimalist style, facing right, centered.",
    ]
    example_images = get_example_images()

    # MAX_NUM_CANDIDATES comes from config and may be smaller than a default,
    # so clamp each slider's initial value into its [1, maximum] range.
    image_default_candidates = min(DEFAULT_NUM_CANDIDATES, MAX_NUM_CANDIDATES)
    text_default_candidates = min(6, MAX_NUM_CANDIDATES)

    image_cfg = TASK_CONFIGS["image-to-svg"]
    text_cfg = TASK_CONFIGS["text-to-svg-icon"]

    # Shared empty-state markup for both candidate galleries.
    gallery_placeholder = '<div style="text-align:center;color:#999;padding:50px;background:#fafafa;border-radius:12px;">Generated SVGs will appear here</div>'

    # Dynamic header with model info
    header_html = f"""
<div class="header-container">
<h1>OmniSVG Generator</h1>
<p>Transform images and text descriptions into scalable vector graphics</p>
<div class="model-badge">Model: OmniSVG {MODEL_VARIANT} | Qwen: {DEFAULT_QWEN_MODEL.split('/')[-1]}</div>
</div>
"""
    with gr.Blocks(title=f"OmniSVG Generator ({MODEL_VARIANT})") as demo:
        # Header
        gr.HTML(header_html)
        # Queue status
        gr.HTML("""
<div style="background: #e7f3ff; border: 1px solid #b3d7ff; border-radius: 8px; padding: 12px 15px; margin-bottom: 15px;">
<span style="font-size: 1.5em;">ℹ️</span>
<strong>Queue System Active</strong> - Requests processed one at a time. Please wait patiently if busy.
</div>
""")
        # Tips section
        gr.HTML(TIPS_HTML)

        with gr.Tabs():
            # ==================== Image-to-SVG Tab ====================
            with gr.TabItem("Image-to-SVG", id="image-tab"):
                gr.HTML(IMAGE_TIPS_HTML)
                with gr.Row(equal_height=False):
                    with gr.Column(scale=1, min_width=300):
                        gr.Markdown("### Upload Image")
                        image_input = gr.Image(
                            label="Drag, upload, or Ctrl+V to paste",
                            type="pil",
                            image_mode="RGBA",
                            height=250,
                            sources=["upload", "clipboard"],
                            elem_classes=["input-image"]
                        )
                        with gr.Group(elem_classes=["settings-group"]):
                            gr.Markdown("### Settings")
                            img_num_candidates = gr.Slider(
                                minimum=1, maximum=MAX_NUM_CANDIDATES, value=image_default_candidates, step=1,
                                label="Number of Candidates"
                            )
                            img_replace_bg = gr.Checkbox(
                                label="Replace non-white background",
                                value=True,
                                info="Enable for images with colored backgrounds"
                            )
                            with gr.Accordion("Advanced Parameters", open=False):
                                img_temperature = gr.Slider(
                                    minimum=0.1, maximum=1.0,
                                    value=image_cfg.get("default_temperature", 0.3),
                                    step=0.05,
                                    label="Temperature (Lower=accurate)",
                                    info="0.2-0.4 recommended"
                                )
                                img_top_p = gr.Slider(
                                    minimum=0.5, maximum=1.0,
                                    value=image_cfg.get("default_top_p", 0.90),
                                    step=0.02,
                                    label="Top-P"
                                )
                                img_top_k = gr.Slider(
                                    minimum=10, maximum=100,
                                    value=image_cfg.get("default_top_k", 50),
                                    step=5,
                                    label="Top-K"
                                )
                                img_rep_penalty = gr.Slider(
                                    minimum=1.0, maximum=1.3,
                                    value=image_cfg.get("default_repetition_penalty", 1.05),
                                    step=0.01,
                                    label="Repetition Penalty"
                                )
                        image_generate_btn = gr.Button(
                            "Generate SVG",
                            variant="primary",
                            size="lg",
                            elem_classes=["primary-btn"]
                        )
                        if example_images:
                            gr.Markdown("### Examples")
                            gr.Examples(examples=example_images, inputs=[image_input], label="")
                    with gr.Column(scale=2, min_width=500):
                        gr.Markdown("### Processed Input")
                        image_processed = gr.Image(label="", type="pil", height=120)
                        gr.Markdown("### Generated SVG Candidates")
                        image_gallery = gr.HTML(value=gallery_placeholder)
                        gr.Markdown("### SVG Code")
                        image_svg_output = gr.Code(label="", language="html", lines=10, elem_classes=["code-output"])
                image_generate_btn.click(
                    fn=gradio_image_to_svg,
                    inputs=[image_input, img_num_candidates, img_temperature, img_top_p,
                            img_top_k, img_rep_penalty, img_replace_bg],
                    outputs=[image_gallery, image_svg_output, image_processed],
                    queue=True
                )

            # ==================== Text-to-SVG Tab ====================
            with gr.TabItem("Text-to-SVG", id="text-tab"):
                with gr.Row(equal_height=False):
                    with gr.Column(scale=1, min_width=300):
                        gr.Markdown("### Description")
                        gr.HTML("""
<div style="background: #fff5f5; padding: 10px; border-radius: 8px; border-left: 4px solid #dc3545; margin-bottom: 10px;">
<strong style="color: #dc3545;">Generate 4-8 candidates and pick the best!</strong>
</div>
""")
                        text_input = gr.Textbox(
                            label="",
                            placeholder="Describe your SVG with geometric shapes and colors...\n\nExample: A black triangle pointing downward, centrally positioned.",
                            lines=5
                        )
                        with gr.Group(elem_classes=["settings-group"]):
                            gr.Markdown("### Settings")
                            text_num_candidates = gr.Slider(
                                minimum=1, maximum=MAX_NUM_CANDIDATES, value=text_default_candidates, step=1,
                                label="Number of Candidates",
                                info="More = better chances!"
                            )
                            with gr.Accordion("Advanced Parameters", open=False):
                                text_temperature = gr.Slider(
                                    minimum=0.1, maximum=1.0,
                                    value=text_cfg.get("default_temperature", 0.5),
                                    step=0.05,
                                    label="Temperature",
                                    info="Icons: 0.3-0.5 | Complex: 0.5-0.7"
                                )
                                text_top_p = gr.Slider(
                                    minimum=0.5, maximum=1.0,
                                    value=text_cfg.get("default_top_p", 0.90),
                                    step=0.02,
                                    label="Top-P"
                                )
                                text_top_k = gr.Slider(
                                    minimum=10, maximum=100,
                                    value=text_cfg.get("default_top_k", 60),
                                    step=5,
                                    label="Top-K"
                                )
                                text_rep_penalty = gr.Slider(
                                    minimum=1.0, maximum=1.3,
                                    value=text_cfg.get("default_repetition_penalty", 1.03),
                                    step=0.01,
                                    label="Repetition Penalty",
                                    info="Increase if you see repetitive patterns"
                                )
                        text_generate_btn = gr.Button(
                            "Generate SVG",
                            variant="primary",
                            size="lg",
                            elem_classes=["primary-btn"]
                        )
                        gr.Markdown("### Example Prompts")
                        gr.Examples(
                            examples=[[text] for text in example_texts],
                            inputs=[text_input],
                            label=""
                        )
                    with gr.Column(scale=2, min_width=500):
                        gr.Markdown("### Generated SVG Candidates")
                        gr.HTML("""
<div style="background: #d4edda; padding: 10px; border-radius: 8px; margin-bottom: 10px;">
<strong>Pick the best from multiple candidates!</strong>
</div>
""")
                        text_gallery = gr.HTML(value=gallery_placeholder)
                        gr.Markdown("### SVG Code")
                        text_svg_output = gr.Code(label="", language="html", lines=12, elem_classes=["code-output"])
                text_generate_btn.click(
                    fn=gradio_text_to_svg,
                    inputs=[text_input, text_num_candidates, text_temperature, text_top_p,
                            text_top_k, text_rep_penalty],
                    outputs=[text_gallery, text_svg_output],
                    queue=True
                )

        # Footer
        gr.HTML(f"""
<div class="footer">
<p>Built with OmniSVG {MODEL_VARIANT}</p>
<p style="color: #dc3545; font-weight: 600;">Remember: Generate 4-8 candidates and pick the best!</p>
</div>
""")
    return demo
if __name__ == "__main__":
    # Turn off HuggingFace `tokenizers` parallelism for this process.
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    args = parse_args()

    _rule = "=" * 60
    print(_rule)
    print("OmniSVG Demo Page - Gradio App")
    print(_rule)

    # Merge the selected variant into a flat config. load_config falls back to
    # the file's `default_variant` when args.model_size is None.
    print(f"\nLoading config from: {args.config}")
    config = load_config(args.config, variant=args.model_size)
    MODEL_VARIANT = config['active_variant']

    # Presumably populates the module-level constants printed below
    # (DEFAULT_*_MODEL, TARGET_IMAGE_SIZE, token ids, ...) — see init_config_constants.
    init_config_constants(config)

    # Paths passed on the command line take precedence over config defaults.
    weight_path = args.weight_path or DEFAULT_OMNISVG_MODEL
    model_path = args.model_path or DEFAULT_QWEN_MODEL

    print(f"\n[CONFIG] Active variant: {MODEL_VARIANT}")
    print(f"[CONFIG] Qwen model: {model_path}")
    print(f"[CONFIG] OmniSVG weights: {weight_path}")
    print(f"[CONFIG] Device: {device}")
    print(f"[CONFIG] Precision: {DTYPE}")
    print(_rule)

    print("\n[CONFIG] Loaded settings:")
    for _setting_name, _setting_value in (
        ("TARGET_IMAGE_SIZE", TARGET_IMAGE_SIZE),
        ("RENDER_SIZE", RENDER_SIZE),
        ("BLACK_COLOR_TOKEN", BLACK_COLOR_TOKEN),
        ("MAX_LENGTH", MAX_LENGTH),
        ("BOS_TOKEN_ID", BOS_TOKEN_ID),
        ("EOS_TOKEN_ID", EOS_TOKEN_ID),
        ("PAD_TOKEN_ID", PAD_TOKEN_ID),
    ):
        print(f" - {_setting_name}: {_setting_value}")

    # Token/color offsets are among the settings a variant can override.
    _token_section = config.get('tokens', {})
    _color_section = config.get('colors', {})
    print(f"\n[CONFIG] Variant-specific ({MODEL_VARIANT}):")
    print(f" - base_offset: {_token_section.get('base_offset', 'N/A')}")
    print(f" - color_start_offset: {_color_section.get('color_start_offset', 'N/A')}")
    print(f" - color_end_offset: {_color_section.get('color_end_offset', 'N/A')}")
    print(_rule)

    # Persist the merged variant config for SVGTokenizer; load_models receives its path.
    variant_config_path = f'./config_{MODEL_VARIANT.lower()}_runtime.yaml'
    write_variant_config(config, variant_config_path)
    print(f"\n[CONFIG] Written variant config to: {variant_config_path}")

    print("\nLoading models (may download from HuggingFace Hub if needed)...")
    load_models(weight_path, model_path, variant_config_path)
    print("Models loaded successfully!\n")

    demo = create_interface()
    # Handle one request per event at a time; allow at most 20 queued requests.
    demo.queue(default_concurrency_limit=1, max_size=20)
    demo.launch(
        server_name=args.listen,
        server_port=args.port,
        share=args.share,
        debug=args.debug,
    )