Standard Transformer - Fully Sharded Edition
Happy New Year 2026!
A high-performance GPT-style transformer implementation with full model and activation sharding for TPU/GPU clusters. Built with JAX and Flax for maximum efficiency on distributed hardware.
Features
Core Architecture
- Standard Transformer Decoder - GPT-style autoregressive language model
- 70 Layers - Deep architecture with a 1,024-dimensional hidden state
- Multi-Head Attention - 16 attention heads with RoPE positional embeddings
- SwiGLU FFN - Modern feed-forward networks with gated activations
- RMSNorm - Efficient layer normalization for stability
Distributed Training
- Full Model Sharding - Tensor parallelism across devices
- Activation Sharding - Memory-efficient forward/backward passes
- Mesh Parallelism - Flexible data + model parallel configurations
- Per-Device RAM Monitoring - Real-time memory tracking for each chip
Production Ready
- Safetensors Format - Fast and safe model serialization
- Memory Profiling - Detailed RAM usage analytics and history
- ETA Tracking - Accurate training time estimates
- Test Generation - Visual verification during training
Quick Start
Requirements
pip install jax flax optax safetensors tokenizers pandas numpy
Training
python transformer_sharded.py
The script automatically:
- Detects available TPU/GPU devices
- Creates optimal sharding mesh
- Loads tokenized data from cache
- Trains with memory monitoring
- Saves checkpoints + final model
Model Specifications
| Parameter | Value |
|---|---|
| Vocabulary Size | 50,262 tokens (GPT-2) |
| Hidden Dimension | 1,024 |
| Attention Heads | 16 (64 dim each) |
| Layers | 70 |
| FFN Dimension | 4,096 (4x multiplier) |
| Max Sequence Length | 1,024 tokens |
| Total Parameters | ~700M+ |
| Precision | BFloat16 |
Sharding Strategy
Model Parallelism
- Embeddings: Sharded across model dimension
- Attention QKV: Column-parallel projections
- Attention Output: Row-parallel projection
- FFN Gate/Up: Column-parallel
- FFN Down: Row-parallel
- LM Head: Row-parallel (tied weights optional)
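As a concrete illustration of the column-/row-parallel split above, here is a minimal sketch that places two weight matrices on a 2-D ('data', 'model') mesh with `jax.sharding.NamedSharding`. The mesh shape, dtypes, and variable names are illustrative assumptions, not the exact code in `transformer_sharded.py`.

```python
# Minimal sketch: column- vs. row-parallel weight placement on a ('data', 'model') mesh.
# Mesh shape and parameter names are illustrative, not taken from transformer_sharded.py.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()).reshape(1, -1), axis_names=("data", "model"))
d_model, ffn_dim = 1024, 4096

# Column-parallel: output features split across 'model' (QKV, FFN gate/up).
w_qkv = jax.device_put(
    jnp.zeros((d_model, 3 * d_model), jnp.bfloat16),
    NamedSharding(mesh, P(None, "model")),
)

# Row-parallel: input features split across 'model' (attention output, FFN down).
w_down = jax.device_put(
    jnp.zeros((ffn_dim, d_model), jnp.bfloat16),
    NamedSharding(mesh, P("model", None)),
)

print(w_qkv.sharding, w_down.sharding)
```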
Activation Sharding
Embeddings: (B, L, D) → P('data', None, 'model')
Attention: (B, H, L, D_head) → P('data', 'model', None, None)
FFN Hidden: (B, L, 4*D) → P('data', None, 'model')
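A sketch of how these activation layouts can be pinned inside a jit-compiled forward pass with `jax.lax.with_sharding_constraint`; the FFN block structure and weight names are assumptions for illustration, not the repository's exact code.

```python
# Sketch: pinning the activation layouts listed above inside a jit-ted SwiGLU FFN block.
# The block structure and weight names are assumptions, not the repository's exact code.
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()).reshape(1, -1), ("data", "model"))

def constrain(x, spec):
    # Ask XLA/GSPMD to lay the intermediate out this way, inserting collectives as needed.
    return jax.lax.with_sharding_constraint(x, NamedSharding(mesh, spec))

@jax.jit
def swiglu_ffn(x, w_gate, w_up, w_down):
    x = constrain(x, P("data", None, "model"))     # (B, L, D)
    h = jax.nn.silu(x @ w_gate) * (x @ w_up)       # (B, L, 4*D)
    h = constrain(h, P("data", None, "model"))
    return constrain(h @ w_down, P("data", None, "model"))
```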
Memory Efficiency
- Activation checkpointing via sharding constraints
- Per-layer memory profiling
- Automatic batch sharding across data parallel dimension
Memory Monitoring
Real-time tracking includes:
- Total RAM: Across all devices
- Per-Chip Usage: Individual device breakdowns
- Peak Memory: Maximum allocation reached
- Utilization %: Current memory pressure
- History Export: CSV logs for analysis
Example output:
RAM: 45.23/96.00GB (47.1%) | Peak: 52.34GB
Per-chip: [5.65GB, 5.67GB, 5.64GB, 5.66GB, ...]
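A sketch of how per-chip numbers like these can be collected. `Device.memory_stats()` is backend-dependent (GPU/TPU typically report `bytes_in_use` and `peak_bytes_in_use`; other backends may return nothing), and the formatting here only approximates the script's output.

```python
# Sketch of per-device memory polling. memory_stats() is backend-dependent and may
# return None or omit keys on some platforms, so everything is read defensively.
import jax

def ram_report():
    per_chip, peak = [], 0.0
    for d in jax.local_devices():
        stats = d.memory_stats() or {}
        per_chip.append(stats.get("bytes_in_use", 0) / 1e9)
        peak = max(peak, stats.get("peak_bytes_in_use", 0) / 1e9)
    print(f"RAM: {sum(per_chip):.2f}GB | Peak per chip: {peak:.2f}GB")
    print("Per-chip:", [f"{g:.2f}GB" for g in per_chip])
```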
Training Features
Optimizer
- AdamW with weight decay (0.1)
- Cosine LR Schedule with warmup (200 steps)
- Gradient Clipping (max norm: 1.0)
- Label Smoothing (optional)
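The optimizer stack above maps onto a few optax primitives. A minimal sketch, where the total step count and the decay end value are assumptions for illustration:

```python
# Sketch of the described optimizer: global-norm clipping, AdamW (weight decay 0.1),
# and a warmup + cosine schedule with 200 warmup steps.
# total_steps and end_value are assumptions.
import optax

total_steps, warmup_steps = 10_000, 200

schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=3e-4,        # matches the configured learning rate
    warmup_steps=warmup_steps,
    decay_steps=total_steps,
    end_value=3e-5,
)

optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adamw(learning_rate=schedule, weight_decay=0.1),
)
```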
Loss Functions
- Cross-Entropy with integer labels
- Z-Loss for training stability (1e-4 weight)
- Accuracy Tracking per batch
- Perplexity Metrics
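A sketch of how these pieces combine into one loss function, assuming logits of shape (B, L, V) and integer labels of shape (B, L):

```python
# Sketch: cross-entropy with integer labels plus a small z-loss term, with
# accuracy and perplexity reported as side metrics.
import jax
import jax.numpy as jnp
import optax

def loss_fn(logits, labels, z_loss_weight=1e-4):
    ce = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
    # z-loss penalises a large log-partition function, keeping logits well scaled.
    log_z = jax.nn.logsumexp(logits, axis=-1)
    z_loss = z_loss_weight * jnp.mean(log_z ** 2)
    accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == labels)
    return ce + z_loss, {"loss": ce, "accuracy": accuracy, "perplexity": jnp.exp(ce)}
```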
Data Pipeline
- Auto-splits train/val (90/10)
- Batch creation with shuffling
- Automatic padding/truncation
- Efficient numpy arrays
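A minimal sketch of that pipeline, assuming the cached data is a 2-D integer numpy array of shape (num_sequences, seq_len):

```python
# Sketch of the 90/10 split and shuffled next-token batches over pre-tokenized data.
# `tokens` is assumed to be an int array of shape (num_sequences, seq_len).
import numpy as np

def split_data(tokens, frac=0.9):
    cut = int(frac * len(tokens))
    return tokens[:cut], tokens[cut:]          # train, val

def batches(data, batch_size, rng):
    order = rng.permutation(len(data))         # reshuffle each epoch
    for i in range(len(data) // batch_size):
        batch = data[order[i * batch_size:(i + 1) * batch_size]]
        yield batch[:, :-1], batch[:, 1:]      # inputs, shifted labels

# Usage:
# train, val = split_data(tokens)
# for x, y in batches(train, 16, np.random.default_rng(0)): ...
```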
Output Structure
/kaggle/working/transformer-sharded/
├── checkpoints/
│   ├── model_ep1.safetensors
│   ├── model_ep2.safetensors
│   └── ...
├── model_final.safetensors
├── config.json
├── history.csv
└── memory_history.csv
Files
- model_final.safetensors: Final trained weights
- config.json: Model architecture + training stats
- history.csv: Loss, accuracy, memory per epoch
- memory_history.csv: Detailed per-step RAM usage
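A sketch of how a Flax parameter tree can be flattened into safetensors-friendly string keys and round-tripped to disk; the exact key scheme used in the repository's checkpoints may differ.

```python
# Sketch: flatten the nested Flax param tree into 'block/.../kernel' string keys
# and round-trip it through safetensors. The key scheme is an assumption.
from flax.traverse_util import flatten_dict, unflatten_dict
from safetensors.flax import save_file, load_file

def save_params(params, path):
    save_file(flatten_dict(params, sep="/"), path)

def load_params(path):
    return unflatten_dict(load_file(path), sep="/")
```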
Test Generation
During training, the model generates text from prompts:
> Hello, how are you today? I'm doing great, thanks for asking...
> The meaning of life is to find purpose and happiness in...
> Once upon a time there was a brave knight who...
These samples give a quick visual sanity check that the model is learning!
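A sketch of a greedy sampling loop for these checks; the `model.apply` call signature and the tokenizer API usage are assumptions about how the Flax module and the `tokenizers` objects are wired up.

```python
# Sketch of greedy generation used as a training-time sanity check.
# model.apply's call signature is an assumption about the Flax module.
import jax.numpy as jnp

def generate(model, params, tokenizer, prompt, max_new_tokens=32):
    ids = jnp.array([tokenizer.encode(prompt).ids])
    for _ in range(max_new_tokens):
        logits = model.apply({"params": params}, ids)       # (1, L, V)
        next_id = jnp.argmax(logits[:, -1, :], axis=-1)     # greedy pick
        ids = jnp.concatenate([ids, next_id[:, None]], axis=-1)
    return tokenizer.decode(ids[0].tolist())
```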
Configuration
Key hyperparameters in TransformerConfig:
d_model: int = 1024 # Hidden size
n_heads: int = 16 # Attention heads
n_layers: int = 70 # Transformer blocks
max_len: int = 1024 # Context window
lr: float = 3e-4 # Learning rate
global_batch: int = 16 # Total batch size
shard_activations: bool = True # Enable activation sharding
Mesh configuration auto-adjusts to available devices!
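A sketch of that auto-adjustment, deriving a (data, model) mesh shape from the visible devices; the cap on the model axis is an assumed policy, not necessarily the script's exact rule.

```python
# Sketch: derive MESH_SHAPE = (data, model) from whatever devices are visible.
# The cap of 8 on the model axis is an assumption for illustration.
import jax
import numpy as np
from jax.sharding import Mesh

n_devices = jax.device_count()
# Largest model-axis size that divides the device count, capped at 8.
model_axis = max(d for d in range(1, min(8, n_devices) + 1) if n_devices % d == 0)
MESH_SHAPE = (n_devices // model_axis, model_axis)

mesh = Mesh(np.array(jax.devices()).reshape(MESH_SHAPE), ("data", "model"))
print(f"{n_devices} devices -> mesh {MESH_SHAPE}")
```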
Hardware Requirements
Recommended
- TPU v3-8 or higher (8+ cores)
- GPU: 4x A100 40GB or similar
- RAM: Depends on model size and sharding
Scaling
- Adjust MESH_SHAPE for your hardware
- Works on a single GPU (limited capacity)
- Optimal on 8+ device clusters
Performance Tips
- Increase batch size with more data parallelism
- Tune FFN multiplier (2x/4x/8x) for compute/memory tradeoff
- Enable activation sharding for large models
- Monitor per-device memory to balance load
- Use BF16 for speed (FP32 for debugging only)
Troubleshooting
OOM Errors: Reduce batch size or increase model parallelism
Slow Training: Check device utilization and batch sharding
NaN Loss: Lower learning rate or enable gradient clipping
Uneven Memory: Adjust sharding strategy for your model size
Technical Details
RoPE Implementation
Rotary Position Embeddings applied per attention head:
freqs = 1.0 / (theta ** (arange(0, head_dim, 2) / head_dim))
cos, sin = precompute_rope_freqs(head_dim, max_len)
q, k = apply_rotary_emb(q, k, cos, sin)
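A self-contained version of that snippet using the interleaved (even, odd) pairing convention; the helper names mirror the ones above, but the details are an illustrative reconstruction rather than the repository's exact functions.

```python
# Runnable RoPE sketch; an illustrative reconstruction, not the repo's exact helpers.
import jax.numpy as jnp

def precompute_rope_freqs(head_dim, max_len, theta=10000.0):
    freqs = 1.0 / (theta ** (jnp.arange(0, head_dim, 2) / head_dim))  # (head_dim/2,)
    angles = jnp.outer(jnp.arange(max_len), freqs)                    # (max_len, head_dim/2)
    return jnp.cos(angles), jnp.sin(angles)

def apply_rotary_emb(q, k, cos, sin):
    # q, k: (B, H, L, head_dim); rotate (even, odd) pairs by position-dependent angles.
    def rotate(x):
        x1, x2 = x[..., ::2], x[..., 1::2]
        c = cos[: x.shape[-2]][None, None]
        s = sin[: x.shape[-2]][None, None]
        out = jnp.stack([x1 * c - x2 * s, x1 * s + x2 * c], axis=-1)
        return out.reshape(x.shape)
    return rotate(q), rotate(k)
```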
Sharding Rules
Regex-based parameter sharding:
'embed_tokens/embedding': P(None, 'model')
'(q_proj|k_proj|v_proj)/kernel': P(None, 'model')
'(o_proj|down_proj)/kernel': P('model', None)
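A sketch of how such regex rules can be resolved into a PartitionSpec per parameter; the gate/up rule and the replicate fallback for unmatched parameters are assumptions consistent with the parallelism plan above.

```python
# Sketch: resolve regex rules to a PartitionSpec for every parameter path.
# The gate/up rule and the replicate fallback are assumptions, not the repo's exact table.
import re
from flax.traverse_util import flatten_dict
from jax.sharding import PartitionSpec as P

RULES = [
    (r"embed_tokens/embedding", P(None, "model")),
    (r"(q_proj|k_proj|v_proj|gate_proj|up_proj)/kernel", P(None, "model")),
    (r"(o_proj|down_proj)/kernel", P("model", None)),
]

def spec_for(path: str):
    for pattern, spec in RULES:
        if re.search(pattern, path):
            return spec
    return P()  # replicate anything no rule matches

def param_specs(params):
    return {key: spec_for("/".join(key)) for key in flatten_dict(params)}
```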
Educational Value
Perfect for learning:
- Distributed training with JAX
- Model parallelism strategies
- Activation sharding techniques
- Memory profiling on TPUs/GPUs
- Modern transformer architectures
A great project for students interested in AI and large-scale distributed compute!
Acknowledgments
Built with:
- JAX - High-performance numerical computing
- Flax - Neural network library
- Optax - Gradient processing
- Safetensors - Secure model serialization
License
Open source - feel free to modify and experiment for personal use!
Happy training and Happy New Year 2026!