|
|
from aura_sr import AuraSR |
|
|
import gradio as gr |
|
|
import spaces |
|
|
|
|
|
|
|
|
class ZeroGPUAuraSR(AuraSR):
    """AuraSR subclass with its own checkpoint loader.

    The loader constructs the model as ``cls(config)``, with no extra
    constructor arguments. Presumably this lets the ZeroGPU runtime manage
    device placement; TODO confirm against the ``AuraSR`` constructor.
    """

    @classmethod
    def from_pretrained(cls, model_id: str = "fal-ai/AuraSR", use_safetensors: bool = True):
        """Build a model from a Hugging Face Hub repo id or a local checkpoint file.

        Args:
            model_id: Either a path to a local ``.safetensors`` / ``.ckpt`` file,
                which must have a ``config.json`` in the same directory, or a
                Hub repo id that is fetched with ``snapshot_download``.
            use_safetensors: For Hub repos, load ``model.safetensors`` when True
                and ``model.ckpt`` otherwise. For local files this flag is
                ignored and the format is taken from the file suffix.

        Returns:
            A new instance whose ``upsampler`` has the checkpoint loaded with
            ``strict=True``.

        Raises:
            ValueError: The local file's suffix is neither ``.safetensors`` nor ``.ckpt``.
            FileNotFoundError: A local file has no sibling ``config.json``.
            ImportError: Safetensors loading was requested but the library is missing.
        """
        import json
        from pathlib import Path

        import torch
        from huggingface_hub import snapshot_download

        local_file = Path(model_id)
        # Check the filesystem exactly once. Re-checking later could disagree
        # with this decision and load weights from a different source than the config.
        if local_file.is_file():
            if local_file.suffix == '.safetensors':
                use_safetensors = True
            elif local_file.suffix == '.ckpt':
                use_safetensors = False
            else:
                raise ValueError(f"Unsupported file format: {local_file.suffix}. Please use .safetensors or .ckpt files.")

            # A bare weights file carries no architecture info, so the config must sit next to it.
            config_path = local_file.with_name('config.json')
            if not config_path.exists():
                raise FileNotFoundError(
                    f"Config file not found: {config_path}. "
                    "When loading from a local file, ensure that 'config.json' "
                    f"is present in the same directory as '{local_file.name}'. "
                    "If you're trying to load a model from Hugging Face, "
                    "please provide the model ID instead of a file path."
                )
            weights_path = local_file
        else:
            model_dir = Path(snapshot_download(model_id))
            config_path = model_dir / "config.json"
            weights_path = model_dir / ("model.safetensors" if use_safetensors else "model.ckpt")

        # JSON is UTF-8 by specification; do not depend on the platform locale.
        config = json.loads(config_path.read_text(encoding="utf-8"))
        model = cls(config)

        if use_safetensors:
            # Guard only the import, so that an ImportError raised while loading
            # is not misreported as "library not installed".
            try:
                from safetensors.torch import load_file
            except ImportError as err:
                raise ImportError(
                    "The safetensors library is not installed. "
                    "Please install it with `pip install safetensors` "
                    "or use `use_safetensors=False` to load the model with PyTorch."
                ) from err
            checkpoint = load_file(weights_path)
        else:
            # NOTE(review): torch.load unpickles the file. Only load .ckpt
            # checkpoints from trusted sources; consider weights_only=True.
            checkpoint = torch.load(weights_path)

        model.upsampler.load_state_dict(checkpoint, strict=True)
        return model
|
|
|
|
|
|
|
|
|
|
|
# Load both checkpoints once at import time so every request reuses them.
# Bound in order: aura_sr <- "fal/AuraSR-v2", aura_sr_v1 <- "fal-ai/AuraSR".
aura_sr, aura_sr_v1 = (
    ZeroGPUAuraSR.from_pretrained(repo_id)
    for repo_id in ("fal/AuraSR-v2", "fal-ai/AuraSR")
)
|
|
|
|
|
|
|
|
@spaces.GPU()
def predict(img, model_selection):
    """Upscale ``img`` with the AuraSR model picked in the dropdown.

    Args:
        img: Value of the ``gr.Image`` input. It is presumably ``None`` when
            the user submits without an image; verify against the Gradio version in use.
        model_selection: Dropdown value; ``'v1'`` or ``'v2'`` are the known keys.

    Returns:
        The result of the selected model's ``upscale_4x(img)``.

    Raises:
        gr.Error: No image was provided, or the selection is not a known
            model key. Gradio shows the message to the user instead of a traceback.
    """
    if img is None:
        raise gr.Error("Please provide an image to upscale.")

    models = {'v1': aura_sr_v1, 'v2': aura_sr}
    model = models.get(model_selection)
    # Previously an unknown or cleared selection surfaced as
    # "'NoneType' object has no attribute 'upscale_4x'".
    if model is None:
        raise gr.Error(
            f"Unknown model selection {model_selection!r}; choose one of: {', '.join(models)}."
        )
    return model.upscale_4x(img)
|
|
|
|
|
|
|
|
# Single-page UI: input image + model version in, upscaled image out.
demo = gr.Interface(
    predict,
    inputs=[
        gr.Image(label="Input image"),
        gr.Dropdown(value='v2', choices=['v1', 'v2'], label="AuraSR version"),
    ],
    outputs=gr.Image(label="Upscaled image"),
)

# Start the server only when run as a script (`python app.py`). Importing this
# module then just defines `demo`, e.g. for tooling that launches it itself.
if __name__ == "__main__":
    demo.launch()