Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -9,8 +9,8 @@ import torchvision.transforms.functional as F
|
|
| 9 |
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline
|
| 10 |
import gradio as gr
|
| 11 |
|
| 12 |
-
device = "
|
| 13 |
-
weight_type = torch.
|
| 14 |
|
| 15 |
controlnet = ControlNetModel.from_pretrained(
|
| 16 |
"IDKiro/sdxs-512-dreamshaper-sketch", torch_dtype=weight_type
|
|
@@ -88,7 +88,7 @@ def run(
|
|
| 88 |
prompt_template,
|
| 89 |
style_name,
|
| 90 |
controlnet_conditioning_scale,
|
| 91 |
-
device_type="
|
| 92 |
param_dtype='torch.float16',
|
| 93 |
):
|
| 94 |
if device_type == "CPU":
|
|
|
|
| 9 |
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline
|
| 10 |
import gradio as gr
|
| 11 |
|
| 12 |
+
device = "cuda"
|
| 13 |
+
weight_type = torch.float16
|
| 14 |
|
| 15 |
controlnet = ControlNetModel.from_pretrained(
|
| 16 |
"IDKiro/sdxs-512-dreamshaper-sketch", torch_dtype=weight_type
|
|
|
|
| 88 |
prompt_template,
|
| 89 |
style_name,
|
| 90 |
controlnet_conditioning_scale,
|
| 91 |
+
device_type="GPU",
|
| 92 |
param_dtype='torch.float16',
|
| 93 |
):
|
| 94 |
if device_type == "CPU":
|