rahul7star commited on
Commit
a6261ec
Β·
verified Β·
1 Parent(s): 0506317

Update app_quant_latent.py

Browse files
Files changed (1) hide show
  1. app_quant_latent.py +39 -35
app_quant_latent.py CHANGED
@@ -560,66 +560,70 @@ def generate_image(prompt, height, width, steps, seed, guidance_scale=0.0):
560
  device = "cuda"
561
  generator = torch.Generator(device).manual_seed(int(seed))
562
 
 
 
563
  latent_gallery = []
564
- final_image = None
565
 
566
  try:
567
- # Try advanced latent mode
568
  try:
569
  latents = safe_get_latents(pipe, height, width, generator, device, LOGS)
570
 
571
  for i, t in enumerate(pipe.scheduler.timesteps):
 
572
  with torch.no_grad():
573
- noise_pred = pipe.unet(
574
- latents,
575
- t,
576
- encoder_hidden_states=pipe.get_text_embeddings(prompt)
577
- )["sample"]
578
-
579
  latents = pipe.scheduler.step(noise_pred, t, latents)["prev_sample"]
580
 
581
- # Convert to preview
582
  try:
583
- latent_img = latent_to_image(latents)
584
  except Exception:
585
- latent_img = Image.new("RGB", (width, height), "white")
586
 
587
  latent_gallery.append(latent_img)
588
 
589
- # ---- LIVE UPDATE ----
590
- yield latent_gallery, None, "\n".join(LOGS)
591
 
592
- # Final decode
593
- final = pipe.decode_latents(latents)[0]
594
- final_image = final
595
  LOGS.append("βœ… Advanced latent pipeline succeeded.")
596
-
597
- # ---- FINAL OUTPUT ----
598
- yield latent_gallery, final_image, "\n".join(LOGS)
599
 
600
  except Exception as e:
601
  LOGS.append(f"⚠️ Advanced latent mode failed: {e}")
602
  LOGS.append("πŸ” Switching to standard pipeline...")
603
 
604
- output = pipe(
605
- prompt=prompt,
606
- height=height,
607
- width=width,
608
- num_inference_steps=steps,
609
- guidance_scale=guidance_scale,
610
- generator=generator,
611
- )
612
-
613
- final_image = output.images[0]
614
- latent_gallery.append(final_image)
615
- LOGS.append("βœ… Standard pipeline succeeded.")
616
-
617
- yield latent_gallery, final_image, "\n".join(LOGS)
 
 
 
 
 
 
 
618
 
619
  except Exception as e:
620
  LOGS.append(f"❌ Total failure: {e}")
621
- placeholder = Image.new("RGB", (width, height), "white")
622
- yield [placeholder], placeholder, "\n".join(LOGS)
 
623
 
624
 
625
 
 
560
  device = "cuda"
561
  generator = torch.Generator(device).manual_seed(int(seed))
562
 
563
+ # placeholders
564
+ placeholder = Image.new("RGB", (width, height), color=(255, 255, 255))
565
  latent_gallery = []
566
+ final_gallery = []
567
 
568
  try:
569
+ # --- Try advanced latent mode ---
570
  try:
571
  latents = safe_get_latents(pipe, height, width, generator, device, LOGS)
572
 
573
  for i, t in enumerate(pipe.scheduler.timesteps):
574
+ # Step-wise denoising
575
  with torch.no_grad():
576
+ noise_pred = pipe.unet(latents, t, encoder_hidden_states=pipe.get_text_embeddings(prompt))["sample"]
 
 
 
 
 
577
  latents = pipe.scheduler.step(noise_pred, t, latents)["prev_sample"]
578
 
579
+ # Convert latent to preview image
580
  try:
581
+ latent_img = latent_to_image(latents) # returns single PIL image
582
  except Exception:
583
+ latent_img = placeholder
584
 
585
  latent_gallery.append(latent_img)
586
 
587
+ # Yield intermediate update: final gallery empty for now
588
+ yield None, latent_gallery, final_gallery, LOGS
589
 
590
+ # decode final image after all timesteps
591
+ final_img = pipe.decode_latents(latents)[0]
592
+ final_gallery.append(final_img)
593
  LOGS.append("βœ… Advanced latent pipeline succeeded.")
594
+ yield final_img, latent_gallery, final_gallery, LOGS
 
 
595
 
596
  except Exception as e:
597
  LOGS.append(f"⚠️ Advanced latent mode failed: {e}")
598
  LOGS.append("πŸ” Switching to standard pipeline...")
599
 
600
+ # Standard pipeline fallback
601
+ try:
602
+ output = pipe(
603
+ prompt=prompt,
604
+ height=height,
605
+ width=width,
606
+ num_inference_steps=steps,
607
+ guidance_scale=guidance_scale,
608
+ generator=generator,
609
+ )
610
+ final_img = output.images[0]
611
+ final_gallery.append(final_img)
612
+ latent_gallery.append(final_img) # optionally show in latent gallery as last step
613
+ LOGS.append("βœ… Standard pipeline succeeded.")
614
+ yield final_img, latent_gallery, final_gallery, LOGS
615
+
616
+ except Exception as e2:
617
+ LOGS.append(f"❌ Standard pipeline failed: {e2}")
618
+ final_gallery.append(placeholder)
619
+ latent_gallery.append(placeholder)
620
+ yield placeholder, latent_gallery, final_gallery, LOGS
621
 
622
  except Exception as e:
623
  LOGS.append(f"❌ Total failure: {e}")
624
+ final_gallery.append(placeholder)
625
+ latent_gallery.append(placeholder)
626
+ yield placeholder, latent_gallery, final_gallery, LOGS
627
 
628
 
629