
# 🎨 PMA-VAE: Parallel Mobile Artistic Variational Autoencoder

A novel attention-free architecture for image generation, super-resolution, artifact removal, and artistic style transfer.


πŸ—οΈ Architecture

Image β†’ PixelUnshuffle stem β†’ MobileConv stages β†’ Parallel 2D Mamba blocks
  β†’ Multi-scale latent (z_base H/16, z_detail H/8, z_style global)
  β†’ Light parallel decoder with FiLM style modulation β†’ Reconstructed image
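The stem in the diagram above pairs a lossless `PixelUnshuffle` rearrangement with a mixing convolution. A minimal sketch (the module name and the 32-channel output width are illustrative assumptions, not the repo's actual code):

```python
import torch
import torch.nn as nn

class PixelUnshuffleStem(nn.Module):
    """Lossless 2x downsampling: rearrange pixels into channels, then mix with a conv."""
    def __init__(self, in_ch=3, out_ch=32):
        super().__init__()
        self.unshuffle = nn.PixelUnshuffle(2)               # (B, C, H, W) -> (B, 4C, H/2, W/2)
        self.mix = nn.Conv2d(in_ch * 4, out_ch, 3, padding=1)

    def forward(self, x):
        return self.mix(self.unshuffle(x))

x = torch.randn(1, 3, 64, 64)
y = PixelUnshuffleStem()(x)
print(y.shape)  # torch.Size([1, 32, 32, 32])
```

Because the unshuffle is a pure rearrangement, no pixel information is discarded before the first learned layer.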

### Key Design Principles

| Component | Choice | Why |
|---|---|---|
| Backbone | MobileConv + Parallel 2D Mamba | Fast, efficient, attention-free |
| Downsampling | PixelUnshuffle → stride-2 conv | Lossless initial features |
| Upsampling | PixelShuffle (sub-pixel) | Mobile-friendly, no checkerboard artifacts |
| Latent | Multi-scale (base/detail/style) | Controllable, prevents collapse |
| Style control | FiLM conditioning | Lightweight, multiplicative |
| Global context | 4-dir cross-scan SSM | O(n) complexity, no attention |
| Local context | Depthwise separable conv + SE | Standard mobile building block |
| Training | Progressive resolution + KL warmup | Stable convergence |
| Loss | L1 + VGG + PatchGAN + edge + KL | Comprehensive quality |
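The "local context" row combines a depthwise separable convolution with Squeeze-Excitation. A minimal sketch of such a block (the class name, SiLU activation, and squeeze ratio are assumptions; the repo's actual block may differ):

```python
import torch
import torch.nn as nn

class MobileConvSE(nn.Module):
    """Depthwise separable conv followed by Squeeze-Excitation channel gating."""
    def __init__(self, ch, se_ratio=4):
        super().__init__()
        self.dw = nn.Conv2d(ch, ch, 3, padding=1, groups=ch)  # depthwise: one filter per channel
        self.pw = nn.Conv2d(ch, ch, 1)                        # pointwise: mix channels
        self.se = nn.Sequential(                              # squeeze-excitation gate
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(ch, ch // se_ratio, 1), nn.SiLU(),
            nn.Conv2d(ch // se_ratio, ch, 1), nn.Sigmoid(),
        )
        self.act = nn.SiLU()

    def forward(self, x):
        y = self.act(self.pw(self.dw(x)))
        return x + y * self.se(y)  # residual + channel attention

x = torch.randn(2, 32, 16, 16)
out = MobileConvSE(32)(x)
```

The depthwise + pointwise split is what keeps parameter counts in the mobile-friendly range quoted below.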

## ✨ Features

- **Attention-free**: uses parallel 2D Mamba/SSM blocks instead of self-attention
- **Mobile-deployable**: lightweight decoder (~4-8M params) built from depthwise separable convolutions
- **Multi-scale latent space**: `z_base` (structure), `z_detail` (texture), `z_style` (global style)
- **No sequential pixel loops**: fully parallel training *and* inference via the Blelloch parallel scan
- **Anti-collapse**: KL warmup + free bits + progressive-resolution training
- **FiLM conditioning**: style modulation throughout the decoder for artist style transfer
- **Pure PyTorch**: no custom CUDA kernels needed — works on the Colab free-tier T4
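The 4-direction cross-scan mentioned above flattens the feature map along four scan orders and runs a parallel prefix operation along each. The sketch below uses `torch.cumsum` as a stand-in for the SSM recurrence (the real model would evaluate its recurrence with a Blelloch-style scan in O(log n) depth); the function name and averaging of directions are illustrative assumptions:

```python
import torch

def cross_scan_4dir(x):
    """Scan a feature map in 4 directions (row-major L→R/R→L, column-major
    T→B/B→T); cumsum stands in for the actual SSM recurrence."""
    B, C, H, W = x.shape
    seqs = [
        x.flatten(2),                           # row-major, left-to-right
        x.flatten(2).flip(-1),                  # row-major, right-to-left
        x.transpose(2, 3).flatten(2),           # column-major, top-to-bottom
        x.transpose(2, 3).flatten(2).flip(-1),  # column-major, bottom-to-top
    ]
    outs = [torch.cumsum(s, dim=-1) for s in seqs]  # stand-in parallel scan
    # undo each reordering so all 4 outputs are aligned row-major, then merge
    outs[1] = outs[1].flip(-1)
    outs[3] = outs[3].flip(-1)
    outs[2] = outs[2].reshape(B, C, W, H).transpose(2, 3).flatten(2)
    outs[3] = outs[3].reshape(B, C, W, H).transpose(2, 3).flatten(2)
    return (sum(outs) / 4).reshape(B, C, H, W)

x = torch.randn(1, 8, 4, 4)
out = cross_scan_4dir(x)
```

Each output position then aggregates context from all four sweep directions without any attention matrix.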

## 📊 Model Variants

| Config | Encoder | Decoder | Total | Target |
|---|---|---|---|---|
| `pmavae_tiny` | 0.56M | 1.91M | 2.47M | Testing |
| `pmavae_small` | 2.00M | 4.27M | 6.27M | Free Colab T4 |
| `pmavae_base` | 5.18M | 9.83M | 15.01M | Colab Pro / better GPU |

## 🔧 Multi-Scale Latent Space

```
z_base   : H/16 × W/16 × 24-32  → Structure, composition, objects
z_detail : H/8  × W/8  × 6-8    → Texture, brush strokes, edges
z_style  : 1 × 1 × 96-128       → Global style vector
```
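As a shape check, each latent is sampled with the standard VAE reparameterization trick. The sketch below picks one width from each range above (24 / 6 / 96); those choices and the function name are assumptions:

```python
import torch

def reparameterize(mu, logvar):
    """Sample z = mu + sigma * eps, differentiable through mu and logvar."""
    return mu + torch.exp(0.5 * logvar) * torch.randn_like(mu)

# Illustrative shapes for a 256x256 input:
H = W = 256
z_base   = reparameterize(torch.zeros(1, 24, H // 16, W // 16),
                          torch.zeros(1, 24, H // 16, W // 16))
z_detail = reparameterize(torch.zeros(1, 6, H // 8, W // 8),
                          torch.zeros(1, 6, H // 8, W // 8))
z_style  = reparameterize(torch.zeros(1, 96, 1, 1),
                          torch.zeros(1, 96, 1, 1))
print(z_base.shape, z_detail.shape, z_style.shape)
# torch.Size([1, 24, 16, 16]) torch.Size([1, 6, 32, 32]) torch.Size([1, 96, 1, 1])
```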

This separation enables:

- **Style transfer**: swap `z_style` between images
- **Super-resolution**: enhance `z_detail` while keeping `z_base`
- **Artifact removal**: clean up `z_detail` while preserving structure
- **Image generation**: sample from the learned distributions
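Swapping `z_style` works because the decoder consumes style only through FiLM layers, which scale and shift feature maps from the global style vector. A minimal FiLM sketch (class name and widths are illustrative assumptions):

```python
import torch
import torch.nn as nn

class FiLM(nn.Module):
    """Feature-wise Linear Modulation: per-channel scale and shift
    predicted from the global style vector."""
    def __init__(self, style_dim, ch):
        super().__init__()
        self.to_gamma_beta = nn.Linear(style_dim, 2 * ch)

    def forward(self, x, z_style):
        gamma, beta = self.to_gamma_beta(z_style).chunk(2, dim=-1)
        # broadcast (B, C) modulation over the (B, C, H, W) feature map
        return x * (1 + gamma[..., None, None]) + beta[..., None, None]

x = torch.randn(1, 32, 8, 8)          # decoder feature map
z_style = torch.randn(1, 96)          # flattened global style vector
out = FiLM(96, 32)(x, z_style)
```

Style transfer then amounts to decoding image A's `z_base`/`z_detail` with image B's `z_style`.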

πŸ‹οΈ Training

Loss Function

Loss = L1 + 0.5 Γ— VGG_perceptual + 0.1 Γ— edge_sobel + Ξ² Γ— KL_free_bits + Ξ» Γ— PatchGAN
  • KL warmup: Ξ² linearly increases from 0 β†’ 1e-6 over 5000 steps
  • Discriminator cold start: PatchGAN activates after 10000 steps
  • Adaptive disc weight: Gradient magnitude balancing (taming-transformers trick)
  • Free bits: 0.25 nats per latent dimension prevents posterior collapse
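The KL-related terms above can be sketched directly from the stated numbers (function names are assumptions; the 1e-6 cap, 5,000-step warmup, and 0.25-nat floor come from the list above):

```python
import torch

def beta_schedule(step, warmup_steps=5000, beta_max=1e-6):
    """Linear KL warmup: 0 -> beta_max over warmup_steps, then flat."""
    return beta_max * min(step / warmup_steps, 1.0)

def kl_free_bits(mu, logvar, free_bits=0.25):
    """Per-dimension KL to N(0, I), clamped below at `free_bits` nats so the
    optimizer cannot drive individual latent dimensions to zero KL (collapse)."""
    kl = 0.5 * (mu.pow(2) + logvar.exp() - 1.0 - logvar)  # per-dim KL
    return torch.clamp(kl, min=free_bits).mean()

mu, logvar = torch.zeros(4, 24), torch.zeros(4, 24)
loss_kl = kl_free_bits(mu, logvar)   # all dims sit at the 0.25-nat floor
beta = beta_schedule(2500)           # halfway through warmup -> 5e-7
```

Note the clamp means a perfectly collapsed posterior still pays `free_bits` nats per dimension, which is exactly what removes the incentive to collapse.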

### Progressive Resolution

```
Phase 1: 256×256   → learn structure
Phase 2: 384×384   → refine texture
Phase 3: 512×512   → full detail
Phase 4: FHD tiled → high-resolution fine-tuning
```
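A progressive schedule like this is usually driven by the global step. The sketch below shows one way to look up the active resolution; the step boundaries (20k/40k) are hypothetical placeholders, not values from the repo:

```python
# Hypothetical (start_step, resolution) phase table mirroring the schedule above.
PHASES = [
    (0,     256),  # phase 1: structure
    (20000, 384),  # phase 2: texture
    (40000, 512),  # phase 3: full detail
]

def resolution_for_step(step):
    """Return the training resolution active at a given global step."""
    res = PHASES[0][1]
    for start, r in PHASES:
        if step >= start:
            res = r
    return res
```

The training loop would rebuild its dataloader crop size whenever `resolution_for_step` changes.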

## 📱 Mobile Deployment

The decoder is designed for mobile:

- Depthwise separable convolutions (MobileNet-style)
- PixelShuffle upsampling (no transposed-conv checkerboard artifacts)
- Squeeze-Excitation for channel attention
- FiLM style modulation (a single MLP)
- Exportable to ONNX → Core ML / TFLite / ONNX Runtime Mobile

πŸ“ Files

  • model.py β€” Full PMA-VAE architecture (encoder, decoder, SSM blocks)
  • losses.py β€” Loss functions (VGG perceptual, PatchGAN, KL free bits, edge loss)
  • train.py β€” Training script with progressive resolution, checkpoint management
  • PMA_VAE_Colab_Training.ipynb β€” Complete Colab notebook


## License

MIT
