🧠 Standard Transformer - Fully Sharded Edition

Happy New Year 2026! 🎉🎊

A high-performance GPT-style transformer implementation with full model and activation sharding for TPU/GPU clusters. Built with JAX and Flax for maximum efficiency on distributed hardware.


✨ Features

Core Architecture

  • Standard Transformer Decoder - GPT-style autoregressive language model
  • 70 Layers - Deep architecture with a 1,024-dimensional hidden state
  • Multi-Head Attention - 16 attention heads with RoPE positional embeddings
  • SwiGLU FFN - Modern feed-forward networks with gated activations
  • RMSNorm - Lightweight root-mean-square normalization for training stability

Distributed Training

  • Full Model Sharding - Tensor parallelism across devices
  • Activation Sharding - Memory-efficient forward/backward passes
  • Mesh Parallelism - Flexible data + model parallel configurations
  • Per-Device RAM Monitoring - Real-time memory tracking for each chip

Production Ready

  • Safetensors Format - Fast and safe model serialization
  • Memory Profiling - Detailed RAM usage analytics and history
  • ETA Tracking - Accurate training time estimates
  • Test Generation - Visual verification during training

🚀 Quick Start

Requirements

pip install jax flax optax safetensors tokenizers pandas numpy

Training

python transformer_sharded.py

The script automatically:

  • Detects available TPU/GPU devices
  • Creates optimal sharding mesh
  • Loads tokenized data from cache
  • Trains with memory monitoring
  • Saves checkpoints + final model
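As a reference point, the device-detection and mesh-creation steps above can be sketched in a few lines of JAX; the data/model axis split shown here is an illustrative assumption, not necessarily the script's exact MESH_SHAPE logic:

# Minimal mesh-setup sketch, assuming a 'data' x 'model' axis layout.
import jax
import numpy as np
from jax.sharding import Mesh

devices = jax.devices()                                # detect TPU/GPU devices
n = len(devices)
model_parallel = 2 if n > 1 and n % 2 == 0 else 1      # hypothetical split rule
MESH_SHAPE = (n // model_parallel, model_parallel)     # (data, model)
mesh = Mesh(np.array(devices).reshape(MESH_SHAPE), axis_names=("data", "model"))
print(f"{n} devices -> mesh {MESH_SHAPE} (data, model)")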

📊 Model Specifications

Parameter             Value
--------------------  --------------------------
Vocabulary Size       50,262 tokens (GPT-2)
Hidden Dimension      1,024
Attention Heads       16 (64 dim each)
Layers                70
FFN Dimension         4,096 (4x multiplier)
Max Sequence Length   1,024 tokens
Total Parameters      ~700M+
Precision             BFloat16

🔀 Sharding Strategy

Model Parallelism

  • Embeddings: Sharded across model dimension
  • Attention QKV: Column-parallel projections
  • Attention Output: Row-parallel projection
  • FFN Gate/Up: Column-parallel
  • FFN Down: Row-parallel
  • LM Head: Row-parallel (tied weights optional)
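In Flax, column- vs. row-parallel projections can be expressed by attaching partition names to the kernel initializers; the module below is a minimal sketch (an illustrative layer, not the repo's exact code):

# Sketch: column-parallel up-projection and row-parallel down-projection.
import flax.linen as nn

class ParallelFFN(nn.Module):
    d_model: int = 1024
    d_ff: int = 4096

    @nn.compact
    def __call__(self, x):
        # Column-parallel: output features sharded along the 'model' axis.
        up = nn.Dense(self.d_ff, kernel_init=nn.with_partitioning(
            nn.initializers.lecun_normal(), (None, "model")))(x)
        # Row-parallel: input features sharded along the 'model' axis.
        return nn.Dense(self.d_model, kernel_init=nn.with_partitioning(
            nn.initializers.lecun_normal(), ("model", None)))(nn.silu(up))

The SwiGLU gate projection follows the same column-parallel pattern as the up-projection.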

Activation Sharding

Embeddings:   (B, L, D)         → P('data', None, 'model')
Attention:    (B, H, L, D_head) → P('data', 'model', None, None)
FFN Hidden:   (B, L, 4*D)       → P('data', None, 'model')
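Inside the forward pass, these layouts are typically pinned with a sharding constraint; a minimal sketch (the helper name is illustrative):

# Sketch: constraining activation layout to (data, None, model).
import jax
from jax.sharding import NamedSharding, PartitionSpec as P

def shard_activations(x, mesh, spec=P("data", None, "model")):
    # x: (B, L, D) - batch on 'data', hidden dimension on 'model'.
    return jax.lax.with_sharding_constraint(x, NamedSharding(mesh, spec))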

Memory Efficiency

  • Activation checkpointing via sharding constraints
  • Per-layer memory profiling
  • Automatic batch sharding across data parallel dimension

💾 Memory Monitoring

Real-time tracking includes:

  • Total RAM: Across all devices
  • Per-Chip Usage: Individual device breakdowns
  • Peak Memory: Maximum allocation reached
  • Utilization %: Current memory pressure
  • History Export: CSV logs for analysis

Example output:

💾 RAM: 45.23/96.00GB (47.1%) | Peak: 52.34GB
Per-chip: [5.65GB, 5.67GB, 5.64GB, 5.66GB, ...]
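Per-chip figures like these can be collected from JAX's device memory statistics; a minimal sketch (memory_stats() is backend-dependent and may return None on some platforms):

# Sketch: per-device memory snapshot via Device.memory_stats().
import jax

def memory_report():
    per_chip, total, peak = [], 0.0, 0.0
    for d in jax.local_devices():
        stats = d.memory_stats() or {}                   # None on some backends
        used = stats.get("bytes_in_use", 0) / 1e9
        peak = max(peak, stats.get("peak_bytes_in_use", 0) / 1e9)
        per_chip.append(round(used, 2))
        total += used
    print(f"RAM: {total:.2f}GB | Peak chip: {peak:.2f}GB | Per-chip: {per_chip}")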

🎯 Training Features

Optimizer

  • AdamW with weight decay (0.1)
  • Cosine LR Schedule with warmup (200 steps)
  • Gradient Clipping (max norm: 1.0)
  • Label Smoothing (optional)
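These settings map onto a short Optax chain; the sketch below assumes the cosine schedule decays over the full run and ends at 10% of the peak rate, which are guesses rather than the script's exact values:

# Sketch: warmup + cosine LR schedule, AdamW (wd=0.1), global-norm clipping.
import optax

def make_optimizer(total_steps: int, peak_lr: float = 3e-4):
    schedule = optax.warmup_cosine_decay_schedule(
        init_value=0.0,
        peak_value=peak_lr,
        warmup_steps=200,
        decay_steps=total_steps,          # assumption: decay over the full run
        end_value=peak_lr * 0.1,          # assumption: final LR at 10% of peak
    )
    return optax.chain(
        optax.clip_by_global_norm(1.0),   # max grad norm 1.0
        optax.adamw(learning_rate=schedule, weight_decay=0.1),
    )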

Loss Functions

  • Cross-Entropy with integer labels
  • Z-Loss for training stability (1e-4 weight)
  • Accuracy Tracking per batch
  • Perplexity Metrics
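Combined, the objective is cross-entropy plus a small penalty on the log-partition term; a minimal sketch of that loss:

# Sketch: cross-entropy with integer labels plus z-loss (weight 1e-4).
import jax
import jax.numpy as jnp
import optax

def loss_fn(logits, labels, z_weight=1e-4):
    ce = optax.softmax_cross_entropy_with_integer_labels(logits, labels)
    log_z = jax.nn.logsumexp(logits, axis=-1)        # log partition function
    loss = jnp.mean(ce + z_weight * jnp.square(log_z))
    accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == labels)
    perplexity = jnp.exp(jnp.mean(ce))
    return loss, (accuracy, perplexity)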

Data Pipeline

  • Auto-splits train/val (90/10)
  • Batch creation with shuffling
  • Automatic padding/truncation
  • Efficient numpy arrays
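The split/shuffle/batch logic described above amounts to a few lines of NumPy; a minimal sketch over a pre-tokenized array:

# Sketch: 90/10 split plus shuffled batching of padded token sequences.
import numpy as np

def split_and_batch(tokens: np.ndarray, batch_size: int, seed: int = 0):
    # tokens: (num_sequences, max_len) int array, already padded/truncated.
    split = int(0.9 * len(tokens))                      # 90/10 train/val
    train, val = tokens[:split], tokens[split:]
    order = np.random.default_rng(seed).permutation(len(train))
    batches = [train[order[i:i + batch_size]]
               for i in range(0, len(order) - batch_size + 1, batch_size)]
    return batches, val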

📁 Output Structure

/kaggle/working/transformer-sharded/
├── checkpoints/
│   ├── model_ep1.safetensors
│   ├── model_ep2.safetensors
│   └── ...
├── model_final.safetensors
├── config.json
├── history.csv
└── memory_history.csv

Files

  • model_final.safetensors: Final trained weights
  • config.json: Model architecture + training stats
  • history.csv: Loss, accuracy, memory per epoch
  • memory_history.csv: Detailed per-step RAM usage
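Safetensors stores a flat mapping from string keys to tensors, so saving a nested Flax parameter tree means flattening it first; a minimal sketch:

# Sketch: serializing Flax params to .safetensors via a flat string-keyed dict.
import jax
from flax.traverse_util import flatten_dict
from safetensors.flax import save_file

def save_checkpoint(params, path: str):
    flat = flatten_dict(params, sep="/")                     # 'block_0/.../kernel' -> array
    flat = {k: jax.device_get(v) for k, v in flat.items()}   # gather shards to host
    save_file(flat, path)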

🧪 Test Generation

During training, the model generates text from prompts:

> Hello, how are you today? I'm doing great, thanks for asking...
> The meaning of life is to find purpose and happiness in...
> Once upon a time there was a brave knight who...

Visual verification ensures the model is learning properly!
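A greedy decoding loop is enough for these quick checks; the sketch below assumes a standard Flax model.apply interface, which may differ from the script's actual generation helper:

# Sketch: greedy autoregressive sampling for visual verification.
import jax.numpy as jnp

def generate(model, params, prompt_ids, max_new_tokens=32):
    tokens = jnp.array(prompt_ids)[None, :]                  # (1, prompt_len)
    for _ in range(max_new_tokens):
        logits = model.apply({"params": params}, tokens)     # (1, L, vocab)
        next_id = jnp.argmax(logits[:, -1, :], axis=-1)      # pick most likely token
        tokens = jnp.concatenate([tokens, next_id[:, None]], axis=-1)
    return tokens[0]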


⚙️ Configuration

Key hyperparameters in TransformerConfig:

d_model: int = 1024          # Hidden size
n_heads: int = 16            # Attention heads
n_layers: int = 70           # Transformer blocks
max_len: int = 1024          # Context window
lr: float = 3e-4             # Learning rate
global_batch: int = 16       # Total batch size
shard_activations: bool = True  # Enable activation sharding

Mesh configuration auto-adjusts to available devices!


🔧 Hardware Requirements

Recommended

  • TPU v3-8 or higher (8+ cores)
  • GPU: 4x A100 40GB or similar
  • RAM: Depends on model size and sharding

Scaling

  • Adjust MESH_SHAPE for your hardware
  • Works on single GPU (limited capacity)
  • Optimal on 8+ device clusters

📈 Performance Tips

  1. Increase batch size with more data parallelism
  2. Tune FFN multiplier (2x/4x/8x) for compute/memory tradeoff
  3. Enable activation sharding for large models
  4. Monitor per-device memory to balance load
  5. Use BF16 for speed (FP32 for debugging only)

🐛 Troubleshooting

OOM Errors: Reduce batch size or increase model parallelism

Slow Training: Check device utilization and batch sharding

NaN Loss: Lower learning rate or enable gradient clipping

Uneven Memory: Adjust sharding strategy for your model size


📚 Technical Details

RoPE Implementation

Rotary Position Embeddings applied per attention head:

freqs = 1.0 / (theta ** (arange(0, head_dim, 2) / head_dim))
cos, sin = precompute_rope_freqs(head_dim, max_len)
q, k = apply_rotary_emb(q, k, cos, sin)
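For completeness, a self-contained JAX version of those two steps (standard interleaved RoPE; function names mirror the snippet above but may not match the script exactly):

# Sketch: precompute rotary frequencies and apply them to q/k per head.
import jax.numpy as jnp

def precompute_rope_freqs(head_dim, max_len, theta=10000.0):
    freqs = 1.0 / (theta ** (jnp.arange(0, head_dim, 2) / head_dim))
    angles = jnp.outer(jnp.arange(max_len), freqs)           # (max_len, head_dim/2)
    return jnp.cos(angles), jnp.sin(angles)

def apply_rotary_emb(q, k, cos, sin):
    # q, k: (B, H, L, head_dim); rotate channel pairs by position-dependent angles.
    def rotate(x):
        x1, x2 = x[..., ::2], x[..., 1::2]
        c, s = cos[None, None, :x.shape[2]], sin[None, None, :x.shape[2]]
        out = jnp.stack([x1 * c - x2 * s, x1 * s + x2 * c], axis=-1)
        return out.reshape(x.shape)
    return rotate(q), rotate(k)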

Sharding Rules

Regex-based parameter sharding:

'embed_tokens/embedding': P(None, 'model')
'(q_proj|k_proj|v_proj)/kernel': P(None, 'model')
'(o_proj|down_proj)/kernel': P('model', None)
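Applying the rules means flattening the parameter tree, matching each path against the patterns, and defaulting unmatched parameters to full replication; a minimal sketch (rule list abbreviated to the entries above):

# Sketch: regex rules mapping parameter paths to PartitionSpecs.
import re
from flax.traverse_util import flatten_dict, unflatten_dict
from jax.sharding import PartitionSpec as P

RULES = [
    (r"embed_tokens/embedding",        P(None, "model")),
    (r"(q_proj|k_proj|v_proj)/kernel", P(None, "model")),
    (r"(o_proj|down_proj)/kernel",     P("model", None)),
]

def make_param_specs(params):
    specs = {}
    for path in flatten_dict(params, sep="/"):
        spec = next((s for pat, s in RULES if re.search(pat, path)), P())  # P() = replicated
        specs[tuple(path.split("/"))] = spec
    return unflatten_dict(specs)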

🎓 Educational Value

Perfect for learning:

  • Distributed training with JAX
  • Model parallelism strategies
  • Activation sharding techniques
  • Memory profiling on TPUs/GPUs
  • Modern transformer architectures

Great project for students interested in AI and large-scale distributed compute!


🙏 Acknowledgments

Built with:

  • JAX - High-performance numerical computing
  • Flax - Neural network library
  • Optax - Gradient processing
  • Safetensors - Secure model serialization

📝 License

Open source - feel free to modify and experiment for personal use!

Happy training and Happy New Year 2026! 🎆🤖

