rahul7star committed
Commit e7274da · verified · 1 Parent(s): 7b09632

Update app_quant_latent.py

Files changed (1)
  1. app_quant_latent.py +55 -222
app_quant_latent.py CHANGED
@@ -579,6 +579,8 @@ def upload_latents_to_hf(latent_dict, filename="latents.pt"):
         os.remove(local_path)
         raise e
 
+
+
 @spaces.GPU
 def generate_image(prompt, height, width, steps, seed, guidance_scale=0.0):
     LOGS = []
@@ -589,49 +591,52 @@ def generate_image(prompt, height, width, steps, seed, guidance_scale=0.0):
     latent_gallery = []
     final_gallery = []
 
-    # --- Generate latent previews in a loop ---
+    all_latents = []  # store all preview latents
+
+    # --- Try generating latent previews ---
     try:
         latents = safe_get_latents(pipe, height, width, generator, device, LOGS)
 
-        # Convert latents to float32 if necessary
-        if latents.dtype != torch.float32:
-            latents = latents.float()
-
-        # Loop for multiple previews before final image
-        num_previews = min(10, steps)  # show ~10 previews
-        preview_steps = torch.linspace(0, 1, num_previews)
-
-        for i, alpha in enumerate(preview_steps):
+        # Loop through timesteps for preview generation
+        for i, t in enumerate(pipe.scheduler.timesteps):
             try:
+                # Convert latent tensor to PIL for preview
                 with torch.no_grad():
-                    # Simple noise interpolation for preview (simulate denoising progress)
-                    preview_latent = latents * alpha + torch.randn_like(latents) * (1 - alpha)
-                    # Decode to PIL
-                    latent_img_tensor = pipe.vae.decode(preview_latent).sample  # [1,3,H,W]
+                    # Some pipelines may require the same dtype as the VAE
+                    latent_to_decode = latents.to(pipe.vae.dtype)
+                    latent_img_tensor = pipe.vae.decode(latent_to_decode).sample  # [1,3,H,W]
                     latent_img_tensor = (latent_img_tensor / 2 + 0.5).clamp(0, 1)
                     latent_img_tensor = latent_img_tensor.cpu().permute(0, 2, 3, 1)[0]
-                    latent_img = Image.fromarray((latent_img_tensor.numpy() * 255).astype('uint8'))
-            except Exception as e:
-                LOGS.append(f"⚠️ Latent preview decode failed: {e}")
+                    latent_img = Image.fromarray((latent_img_tensor.numpy() * 255).astype("uint8"))
+            except Exception:
                 latent_img = placeholder
+                LOGS.append("⚠️ Latent preview decode failed.")
 
             latent_gallery.append(latent_img)
-            yield None, latent_gallery, LOGS  # update Gradio with intermediate preview
+            all_latents.append(latents.cpu().clone())  # save current latent
 
-        # Save final latents to HF
-        latent_dict = {"latents": latents.cpu(), "prompt": prompt, "seed": seed}
+            # Yield intermediate preview every few steps
+            if i % max(1, len(pipe.scheduler.timesteps) // 10) == 0:
+                yield None, latent_gallery, LOGS
+
+        # Upload full series of latents
         try:
-            hf_url = upload_latents_to_hf(latent_dict, filename=f"latents_{seed}.pt")
-            LOGS.append(f"🔹 Latents uploaded: {hf_url}")
+            latent_dict = {
+                "latents_series": all_latents,
+                "prompt": prompt,
+                "seed": seed
+            }
+            hf_url = upload_latents_to_hf(latent_dict, filename=f"latents_series_{seed}.pt")
+            LOGS.append(f"🔹 All preview latents uploaded: {hf_url}")
         except Exception as e:
-            LOGS.append(f"⚠️ Failed to upload latents: {e}")
+            LOGS.append(f"⚠️ Failed to upload all preview latents: {e}")
 
     except Exception as e:
         LOGS.append(f"⚠️ Latent generation failed: {e}")
         latent_gallery.append(placeholder)
         yield None, latent_gallery, LOGS
 
-    # --- Final image: untouched standard pipeline ---
+    # --- Final image: completely untouched, uses standard pipeline ---
     try:
         output = pipe(
             prompt=prompt,
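
Note (not part of the diff): the rewritten preview loop above decodes the same latents tensor at every scheduler timestep (nothing updates it between iterations) and converts it to a PIL image via the pipeline VAE. Below is a minimal sketch of that latent-to-PIL conversion as a standalone helper, assuming only the diffusers-style pipe.vae used above; the helper name latent_to_pil is hypothetical and does not exist in the repository.

import torch
from PIL import Image

def latent_to_pil(pipe, latents):
    """Decode a [1, C, h, w] latent tensor into a PIL preview image.

    Mirrors the decode path in the loop above; depending on the pipeline,
    latents may also need scaling by 1 / pipe.vae.config.scaling_factor first.
    """
    with torch.no_grad():
        # Match the VAE dtype (the pipeline may run in fp16/bf16).
        decoded = pipe.vae.decode(latents.to(pipe.vae.dtype)).sample  # [1, 3, H, W]
    decoded = (decoded / 2 + 0.5).clamp(0, 1)                  # [-1, 1] -> [0, 1]
    array = decoded[0].permute(1, 2, 0).float().cpu().numpy()  # HWC, float32
    return Image.fromarray((array * 255).astype("uint8"))
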
@@ -652,6 +657,7 @@ def generate_image(prompt, height, width, steps, seed, guidance_scale=0.0):
         final_gallery.append(placeholder)
         latent_gallery.append(placeholder)
         yield placeholder, latent_gallery, LOGS
+
 # this is a stable version that can generate the final image and a noise-to-latent preview
 @spaces.GPU
 def generate_image0(prompt, height, width, steps, seed, guidance_scale=0.0):
@@ -663,46 +669,36 @@ def generate_image0(prompt, height, width, steps, seed, guidance_scale=0.0):
     latent_gallery = []
     final_gallery = []
 
-    # --- Try generating latent previews ---
+    # --- Generate latent previews in a loop ---
     try:
         latents = safe_get_latents(pipe, height, width, generator, device, LOGS)
 
-        # Decode latent tensor to PIL for preview with robust fallbacks
-        latent_img = placeholder
-        try:
-            with torch.no_grad():
-                # 1️⃣ Try normal VAE decode if available
-                if hasattr(pipe, "vae") and hasattr(pipe.vae, "decode"):
-                    try:
-                        latent_img_tensor = pipe.vae.decode(latents).sample  # [1,3,H,W]
-                        latent_img_tensor = (latent_img_tensor / 2 + 0.5).clamp(0, 1)
-                        latent_img_tensor = latent_img_tensor.cpu().permute(0, 2, 3, 1)[0]
-                        latent_img = Image.fromarray((latent_img_tensor.numpy() * 255).astype('uint8'))
-                    except Exception as e1:
-                        LOGS.append(f"⚠️ VAE decode failed: {e1}")
-
-                # 2️⃣ Collapse first 3 channels if decode failed
-                if latent_img is placeholder and latents.shape[1] >= 3:
-                    ch = latents[0, :3, :, :]
-                    ch = (ch - ch.min()) / (ch.max() - ch.min() + 1e-8)
-                    latent_img = Image.fromarray((ch.permute(1, 2, 0).cpu().numpy() * 255).astype('uint8'))
-
-                # 3️⃣ Collapse all channels to mean -> replicate to RGB
-                if latent_img is placeholder:
-                    mean_ch = latents[0].mean(dim=0, keepdim=True)  # [1,H,W]
-                    mean_ch = (mean_ch - mean_ch.min()) / (mean_ch.max() - mean_ch.min() + 1e-8)
-                    latent_img = Image.fromarray(
-                        torch.cat([mean_ch]*3, dim=0).permute(1,2,0).cpu().numpy().astype('uint8')
-                    )
-
-        except Exception as e:
-            LOGS.append(f"⚠️ Latent to image conversion failed: {e}")
-            latent_img = placeholder
-
-        latent_gallery.append(latent_img)
-        yield None, latent_gallery, LOGS  # show preview immediately
-
-        # Save latents to HF for later testing
+        # Convert latents to float32 if necessary
+        if latents.dtype != torch.float32:
+            latents = latents.float()
+
+        # Loop for multiple previews before final image
+        num_previews = min(10, steps)  # show ~10 previews
+        preview_steps = torch.linspace(0, 1, num_previews)
+
+        for i, alpha in enumerate(preview_steps):
+            try:
+                with torch.no_grad():
+                    # Simple noise interpolation for preview (simulate denoising progress)
+                    preview_latent = latents * alpha + torch.randn_like(latents) * (1 - alpha)
+                    # Decode to PIL
+                    latent_img_tensor = pipe.vae.decode(preview_latent).sample  # [1,3,H,W]
+                    latent_img_tensor = (latent_img_tensor / 2 + 0.5).clamp(0, 1)
+                    latent_img_tensor = latent_img_tensor.cpu().permute(0, 2, 3, 1)[0]
+                    latent_img = Image.fromarray((latent_img_tensor.numpy() * 255).astype('uint8'))
+            except Exception as e:
+                LOGS.append(f"⚠️ Latent preview decode failed: {e}")
+                latent_img = placeholder
+
+            latent_gallery.append(latent_img)
+            yield None, latent_gallery, LOGS  # update Gradio with intermediate preview
+
+        # Save final latents to HF
         latent_dict = {"latents": latents.cpu(), "prompt": prompt, "seed": seed}
         try:
             hf_url = upload_latents_to_hf(latent_dict, filename=f"latents_{seed}.pt")
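
Note (not part of the diff): both functions finish by handing a dictionary to upload_latents_to_hf. Assuming that helper pushes the torch.save'd dict to a Hugging Face repo unchanged, the uploaded file can be pulled back and inspected as sketched below; the repo id is a hypothetical placeholder, and only the filename patterns and the keys "latents" / "latents_series", "prompt", and "seed" come from this diff.

import torch
from huggingface_hub import hf_hub_download

REPO_ID = "your-username/your-latents-repo"  # hypothetical; the real target lives in upload_latents_to_hf
seed = 42                                    # must match the seed used at generation time

local_path = hf_hub_download(repo_id=REPO_ID, filename=f"latents_{seed}.pt")
payload = torch.load(local_path, map_location="cpu")

print(payload["prompt"], payload["seed"])
single = payload.get("latents")               # one tensor (generate_image0)
series = payload.get("latents_series", [])    # list of tensors (generate_image)
print(getattr(single, "shape", None), len(series))
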
@@ -715,7 +711,7 @@ def generate_image0(prompt, height, width, steps, seed, guidance_scale=0.0):
         latent_gallery.append(placeholder)
         yield None, latent_gallery, LOGS
 
-    # --- Final image: completely untouched, uses standard pipeline ---
+    # --- Final image: untouched standard pipeline ---
     try:
         output = pipe(
             prompt=prompt,
@@ -736,7 +732,6 @@ def generate_image0(prompt, height, width, steps, seed, guidance_scale=0.0):
         final_gallery.append(placeholder)
         latent_gallery.append(placeholder)
         yield placeholder, latent_gallery, LOGS
-    # this version generates well for the final image and gives a tensor back for the latent
 
 
 
@@ -744,169 +739,7 @@ def generate_image0(prompt, height, width, steps, seed, guidance_scale=0.0):
 
 
 
-@spaces.GPU
-def generate_image_backup(prompt, height, width, steps, seed, guidance_scale=0.0, return_latents=False):
-    """
-    Robust dual pipeline:
-    - Advanced latent generation first
-    - Fallback to standard pipeline if latent fails
-    - Always returns final image
-    - Returns gallery (latents or final image) and logs
-    """
-
-    LOGS = []
-    image = None
-    latents = None
-    gallery = []
-
-    # Keep a placeholder original image (white) in case everything fails
-    original_image = Image.new("RGB", (width, height), color=(255, 255, 255))
-
-    try:
-        generator = torch.Generator(device).manual_seed(int(seed))
-
-        # -------------------------------
-        # Try advanced latent generation
-        # -------------------------------
-        try:
-            batch_size = 1
-            num_channels_latents = getattr(pipe.unet, "in_channels", None)
-            if num_channels_latents is None:
-                raise AttributeError("pipe.unet.in_channels not found, fallback to standard pipeline")
-
-            latents = pipe.prepare_latents(
-                batch_size=batch_size,
-                num_channels=num_channels_latents,
-                height=height,
-                width=width,
-                dtype=torch.float32,
-                device=device,
-                generator=generator
-            )
-            LOGS.append(f"✅ Latents prepared: {latents.shape}")
-
-            output = pipe(
-                prompt=prompt,
-                height=height,
-                width=width,
-                num_inference_steps=steps,
-                guidance_scale=guidance_scale,
-                generator=generator,
-                latents=latents
-            )
-            image = output.images[0]
-            gallery = [image] if image else []
-
-            LOGS.append("✅ Advanced latent generation succeeded.")
-
-        # -------------------------------
-        # Fallback to standard pipeline
-        # -------------------------------
-        except Exception as e_latent:
-            LOGS.append(f"⚠️ Advanced latent generation failed: {e_latent}")
-            LOGS.append("🔁 Falling back to standard pipeline...")
-
-            try:
-                output = pipe(
-                    prompt=prompt,
-                    height=height,
-                    width=width,
-                    num_inference_steps=steps,
-                    guidance_scale=guidance_scale,
-                    generator=generator
-                )
-                image = output.images[0]
-                gallery = [image] if image else []
-                LOGS.append("✅ Standard pipeline generation succeeded.")
-            except Exception as e_standard:
-                LOGS.append(f"❌ Standard pipeline generation failed: {e_standard}")
-                image = original_image  # Always return some image
-                gallery = [image]
-
-        # -------------------------------
-        # Return all 3 outputs
-        # -------------------------------
-        return image, gallery, LOGS
-
-    except Exception as e:
-        LOGS.append(f"❌ Inference failed entirely: {e}")
-        return original_image, [original_image], LOGS
-
-# ============================================================
-# UI
-# ============================================================
-
-# Utility: scan local HF cache for safetensors in a repo folder name
-def list_loras_from_repo(repo_id):
-    """
-    Attempts to find safetensors inside HF cache directory for repo_id.
-    This only scans local cache; it does NOT download anything.
-
-    Returns:
-        A list of strings suitable for showing in the dropdown. Prefer returning
-        paths relative to the repo root (e.g. "NSFW/doggystyle_pov.safetensors") so that
-        pipe.load_lora_weights(repo_id, weight_name=that_path) works for nested files.
-        If a relative path can't be determined, returns absolute cached file paths.
-    """
-    if not repo_id:
-        return []
-
-    safe_list = []
-
-    # Candidate cache roots
-    hf_cache = os.path.expanduser("~/.cache/huggingface/hub")
-    alt_cache = "/home/user/.cache/huggingface/hub"
-    candidates = [hf_cache, alt_cache]
-
-    # Normalize repo variants to search for in path
-    owner_repo = repo_id.replace("/", "_")
-    owner_repo_dash = repo_id.replace("/", "-")
-    owner_repo_double = repo_id.replace("/", "--")
-
-    # Walk caches and collect safetensors
-    for root_cache in candidates:
-        if not os.path.exists(root_cache):
-            continue
-        for dirpath, dirnames, filenames in os.walk(root_cache):
-            for f in filenames:
-                if not f.endswith(".safetensors"):
-                    continue
-                full_path = os.path.join(dirpath, f)
-
-                # try to find a repo-root-like substring in dirpath
-                chosen_base = None
-                for pattern in (owner_repo_double, owner_repo_dash, owner_repo):
-                    idx = dirpath.find(pattern)
-                    if idx != -1:
-                        chosen_base = dirpath[: idx + len(pattern)]
-                        break
-
-                # fallback: look for the repo folder name (last component) e.g., "ZImageLora"
-                if chosen_base is None:
-                    repo_tail = repo_id.split("/")[-1]
-                    idx2 = dirpath.find(repo_tail)
-                    if idx2 != -1:
-                        chosen_base = dirpath[: idx2 + len(repo_tail)]
-
-                # If we found a base that looks like the cached repo root, compute relative path
-                if chosen_base:
-                    try:
-                        rel = os.path.relpath(full_path, chosen_base)
-                        # If relpath goes up (starts with ..) then prefer full_path
-                        if rel and not rel.startswith(".."):
-                            # Normalize to forward slashes for HF repo weight_name usage
-                            rel_normalized = rel.replace(os.sep, "/")
-                            safe_list.append(rel_normalized)
-                            continue
-                    except Exception:
-                        pass
-
-                # Otherwise append absolute path (last resort)
-                safe_list.append(full_path)
-
-    # remove duplicates and sort
-    safe_list = sorted(list(dict.fromkeys(safe_list)))
-    return safe_list
+
 
 
 with gr.Blocks(title="Z-Image-Turbo") as demo:
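
Note (not part of the diff): the removed list_loras_from_repo walked ~/.cache/huggingface/hub by hand to find cached .safetensors files and turn them into repo-relative weight_name values. If equivalent behaviour is needed again, huggingface_hub.scan_cache_dir() exposes the same cache through a supported API; the sketch below is an untested alternative based on that API, not code from this commit.

from huggingface_hub import scan_cache_dir

def list_cached_safetensors(repo_id):
    """Return snapshot-relative .safetensors paths cached locally for repo_id."""
    if not repo_id:
        return []
    found = set()
    for repo in scan_cache_dir().repos:
        if repo.repo_id != repo_id:
            continue
        for revision in repo.revisions:
            for f in revision.files:
                if f.file_name.endswith(".safetensors"):
                    # Keep the path relative to the snapshot root so it can be
                    # passed as weight_name to pipe.load_lora_weights(repo_id, ...).
                    found.add(f.file_path.relative_to(revision.snapshot_path).as_posix())
    return sorted(found)
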
 