# tusharmagar's picture
# Update app.py
# 101618d verified
# app.py
import gradio as gr
import torch
from diffusers import FluxPipeline
import os
# --- 1. Model Configuration ---
# Pick CUDA when PyTorch reports a usable GPU and fall back to the CPU
# otherwise, so the script can still start on a machine without a GPU
# (generation is only expected to work well on a GPU).
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Hub repository ids of the base checkpoint and of the LoRA applied on top.
base_model_id = "black-forest-labs/FLUX.1-Krea-dev"
lora_model_id = "tusharmagar/flux1-krea-dev-lora-solarpunk"

# --- 2. Load the Pipeline ---
# Build the FLUX pipeline in bfloat16. The access token is read from the
# HF_TOKEN environment variable (None when unset) and is used to authenticate
# the download of the gated base model.
pipe = FluxPipeline.from_pretrained(
    base_model_id,
    torch_dtype=torch.bfloat16,
    token=os.environ.get("HF_TOKEN"),
)

# Place the pipeline on the device selected above.
pipe.to(device)

# Attach the LoRA weights, then fuse them into the pipeline.
pipe.load_lora_weights(lora_model_id)
pipe.fuse_lora()
# --- 3. Define the Inference Function ---
def generate_image(
    prompt,
    progress=gr.Progress(track_tqdm=True),
    *,
    num_inference_steps=8,
    guidance_scale=0.0,
    seed=42,
):
    """Generate one image for ``prompt`` with the module-level pipeline.

    Args:
        prompt: Text prompt describing the image to generate.
        progress: Gradio progress tracker created with ``track_tqdm=True``.
            It is not used directly in the body; it is declared so Gradio
            can inject it when calling this function.
        num_inference_steps: Forwarded to the pipeline call. Defaults to 8.
        guidance_scale: Forwarded to the pipeline call. Defaults to 0.0.
        seed: Seed for the random generator. The default of 42 keeps the
            output for a given prompt reproducible between calls.

    Returns:
        The first image of the pipeline output.

    Raises:
        gr.Error: If ``prompt`` is ``None`` or a blank string, so the user
            gets a message instead of an image generated from nothing.
    """
    # Reject empty input before spending any compute on it.
    if prompt is None or (isinstance(prompt, str) and not prompt.strip()):
        raise gr.Error("Please enter a prompt.")

    print(f"Generating image for prompt: {prompt}")

    # The generator must live on the same device the pipeline was moved to.
    generator = torch.Generator(device=device).manual_seed(seed)

    # NOTE(review): few steps with guidance 0.0 are settings usually seen with
    # timestep-distilled checkpoints -- confirm against the base model's card.
    result = pipe(
        prompt=prompt,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=generator,
    )
    return result.images[0]
# --- 4. Create the Gradio Interface ---
# Stylesheet for the main column: centred horizontally, capped at 720px wide.
css = """
#col-container {
margin: 0 auto;
max-width: 720px;
}
"""

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        # Header text. The sun emoji in the title was mis-encoded in the
        # original literal and is restored here.
        gr.Markdown(
            "\n"
            "# FLUX.1 Solarpunk LoRA ☀️\n"
            "An interactive demo for the [flux1-krea-dev-lora-solarpunk](https://huggingface.co/tusharmagar/flux1-krea-dev-lora-solarpunk) model.\n"
            "This LoRA excels at creating dreamy, solarpunk imaginations of real-world cities.\n"
            "**Don't forget to add the trigger word `[SLRPNK]` to your prompt!**\n"
        )

        # Prompt box and button share one row; the textbox gets 5/6 of it.
        with gr.Row():
            prompt_input = gr.Textbox(
                label="Enter your prompt",
                show_label=False,
                max_lines=2,
                placeholder="Enter your prompt...",
                container=False,
                scale=5,
            )
            generate_button = gr.Button("Generate", scale=1, variant="primary")

        image_output = gr.Image(label="Result", show_label=False)

        # Clickable sample prompts, wired to the same function and components
        # as the manual flow.
        gr.Examples(
            examples=[
                "Solarpunk London with hexagonal solar panels, white architecture, keeping Big Ben unchanged, with a double decker bus on the road [SLRPNK]",
                "Aerial view of Solarpunk San Francisco with futuristic townhouses, solar sails, keeping the Golden Gate Bridge unchanged, with a cable car [SLRPNK]",
                "Solarpunk Masai Mara tribe with solar panel dome greenhouses, white mud houses, giraffes and elephants [SLRPNK]",
                "Solarpunk Rio de Janeiro with tropical solar sails shaped like leaves lining the beaches, keeping Christ the Redeemer unchanged, and futuristic white towers [SLRPNK]",
            ],
            inputs=[prompt_input],
            outputs=[image_output],
            fn=generate_image,
        )

    # Clicking the button and submitting the textbox trigger the same call,
    # so register both listeners from one place instead of duplicating it.
    for register_listener in (generate_button.click, prompt_input.submit):
        register_listener(
            fn=generate_image,
            inputs=[prompt_input],
            outputs=[image_output],
        )
# --- 5. Launch the App ---
# Start the Gradio app only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    demo.launch()