# Enhanced ECG-Mamba2: Bidirectional State Space Model for ECG Classification

This repository contains an enhanced version of the ECG-Mamba model for 12-lead ECG arrhythmia classification. Enhanced ECG-Mamba2 builds upon the original `models_mamba_ecg.py` implementation with significant architectural improvements.
## Overview
Enhanced ECG-Mamba2 is a deep learning model for ECG classification that combines:
- Original CNN feature extraction from `models_mamba_ecg.py`
- Mamba-2 (State Space Duality): 2-8x faster than the original Mamba/VisionMamba
- Bidirectional scanning for better temporal context
- Multi-branch architecture for lead-specific processing
- Transformer attention for capturing short-term anomalies
## Key Improvements over Original Implementation

| Feature | Original (`models_mamba_ecg.py`) | Enhanced (`Enhanced_ECG_Mamba_Test.ipynb`) |
|---|---|---|
| State Space Model | VisionMamba (Mamba-1 based) | Mamba-2 (State Space Duality) |
| Scanning Direction | Unidirectional | Bidirectional (Forward + Backward) |
| Lead Processing | Single pathway | Multi-branch (4 lead groups) |
| Attention | None | Transformer attention layer |
| Training | Standard | Adversarial + Frequency Masking |
| Explainability | None | MambaLRP |
## Architecture

### 1. CNN Feature Extraction (Original)
The CNN layers from the original implementation are preserved:
```
Input: (batch, 12, 8192) -> Conv1d layers -> Output: (batch, 729, 384)
```
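For orientation, here is a minimal sketch of such a front-end. The widths, kernels, and strides are illustrative assumptions; the actual configuration in `models_mamba_ecg.py` is what produces the 729-step output.

```python
import torch
import torch.nn as nn

class CNNFrontEnd(nn.Module):
    """Illustrative Conv1d stack mapping (batch, 12, 8192) to (batch, seq_len, 384)."""

    def __init__(self, in_ch=12, embed_dim=384):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv1d(in_ch, 128, kernel_size=15, stride=2, padding=7),
            nn.BatchNorm1d(128), nn.ReLU(),
            nn.Conv1d(128, 256, kernel_size=9, stride=2, padding=4),
            nn.BatchNorm1d(256), nn.ReLU(),
            nn.Conv1d(256, embed_dim, kernel_size=7, stride=3, padding=3),
            nn.BatchNorm1d(embed_dim), nn.ReLU(),
        )

    def forward(self, x):             # x: (batch, 12, 8192)
        feats = self.conv(x)          # (batch, 384, 683) with these example strides
        return feats.transpose(1, 2)  # (batch, seq_len, 384) for the sequence model
```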
### 2. Multi-Branch Lead Encoder (New)

Four specialized branches process different ECG lead groups (see the sketch after this list):
- Limb leads (I, II, III): Standard cardiac views
- Augmented leads (aVR, aVL, aVF): Enhanced limb perspectives
- Precordial anterior (V1-V3): Septal/anterior views
- Precordial lateral (V4-V6): Lateral views
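A minimal sketch of this grouping, assuming the standard 12-lead channel order (I, II, III, aVR, aVL, aVF, V1-V6); the branch width and concatenation fusion are illustrative:

```python
import torch
import torch.nn as nn

# Assumed channel indices for the standard 12-lead ordering.
LEAD_GROUPS = {
    "limb":      [0, 1, 2],    # I, II, III
    "augmented": [3, 4, 5],    # aVR, aVL, aVF
    "anterior":  [6, 7, 8],    # V1-V3
    "lateral":   [9, 10, 11],  # V4-V6
}

class MultiBranchLeadEncoder(nn.Module):
    """One small Conv1d branch per lead group; branch outputs are concatenated."""

    def __init__(self, branch_dim=96):
        super().__init__()
        self.branches = nn.ModuleDict({
            name: nn.Conv1d(len(idx), branch_dim, kernel_size=7, padding=3)
            for name, idx in LEAD_GROUPS.items()
        })

    def forward(self, x):  # x: (batch, 12, seq_len)
        outs = [branch(x[:, LEAD_GROUPS[name]]) for name, branch in self.branches.items()]
        return torch.cat(outs, dim=1)  # (batch, 4 * branch_dim, seq_len)
```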
### 3. Bidirectional Mamba-2 (New)
- Forward Mamba-2 processes the sequence left-to-right
- Backward Mamba-2 processes the sequence right-to-left
- The two outputs are fused so each position carries both past and future context (see the sketch below)
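A minimal sketch of one such block, using the `Mamba2` module from `mamba-ssm`; fusion by residual addition is an assumption, and the notebook may fuse differently:

```python
import torch
import torch.nn as nn
from mamba_ssm import Mamba2  # requires mamba-ssm>=2.0 with its CUDA kernels

class BiMamba2Block(nn.Module):
    """Forward and backward Mamba-2 scans over the same sequence, fused by addition."""

    def __init__(self, d_model=384):
        super().__init__()
        self.fwd = Mamba2(d_model=d_model)
        self.bwd = Mamba2(d_model=d_model)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):                    # x: (batch, seq_len, d_model)
        y_fwd = self.fwd(x)                  # left-to-right scan
        y_bwd = self.bwd(torch.flip(x, dims=[1]))
        y_bwd = torch.flip(y_bwd, dims=[1])  # re-align to forward time order
        return self.norm(x + y_fwd + y_bwd)  # residual + fused bidirectional context
```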
### 4. Transformer Attention (New)
Multi-head self-attention layer captures short-term dependencies that complement Mamba-2's long-range modeling.
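A sketch of this stage with PyTorch's built-in multi-head attention, using the 4 heads and 384-dim tokens listed under Model Parameters:

```python
import torch
import torch.nn as nn

attn = nn.MultiheadAttention(embed_dim=384, num_heads=4, batch_first=True)
x = torch.randn(2, 729, 384)  # (batch, seq_len, embed_dim) from the Mamba-2 stack
y, _ = attn(x, x, x)          # self-attention; output keeps the same shape
```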
### 5. Classification Head
Global average pooling followed by a linear classifier.
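In code this is just a mean over the time axis followed by a linear layer (`n_classes=5` per the Dataset section):

```python
import torch
import torch.nn as nn

features = torch.randn(2, 729, 384)  # (batch, seq_len, embed_dim) from the backbone
head = nn.Linear(384, 5)             # embed_dim -> n_classes
logits = head(features.mean(dim=1))  # global average pool over time -> (batch, 5)
```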
## Training Features

### Adversarial Training
FGSM-style perturbations are applied during training to improve model robustness.
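A sketch of one FGSM step; the budget `epsilon` and the cross-entropy loss are assumptions, and the notebook's exact loss and schedule may differ:

```python
import torch
import torch.nn.functional as F

def fgsm_example(model, x, y, epsilon=0.01):
    """Return an adversarially perturbed copy of the batch `x`."""
    x = x.clone().detach().requires_grad_(True)
    loss = F.cross_entropy(model(x), y)
    loss.backward()  # zero model grads afterwards, before the real optimizer step
    # Step in the direction that increases the loss, bounded by +/- epsilon.
    return (x + epsilon * x.grad.sign()).detach()

# Typical use: train on clean and perturbed inputs together, e.g.
# loss = F.cross_entropy(model(x), y) + F.cross_entropy(model(fgsm_example(model, x, y)), y)
```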
### Frequency Masking Augmentation
Random frequency bands are masked in the FFT domain to make the model robust to noise and artifacts.
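A sketch of this augmentation with `torch.fft`; the maximum band width is a hypothetical choice:

```python
import torch

def frequency_mask(x, max_band_frac=0.1):  # x: (batch, 12, seq_len)
    """Zero a random contiguous band of FFT bins, then return to the time domain."""
    spec = torch.fft.rfft(x, dim=-1)
    n_bins = spec.shape[-1]
    band = int(n_bins * max_band_frac * torch.rand(1).item())
    start = torch.randint(0, max(n_bins - band, 1), (1,)).item()
    spec[..., start:start + band] = 0
    return torch.fft.irfft(spec, n=x.shape[-1], dim=-1)
```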
### Explainability: MambaLRP
MambaLRP (Layer-wise Relevance Propagation) provides interpretability by highlighting which parts of the ECG signal contribute most to the model's predictions.
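MambaLRP requires the relevance-propagation rules from its reference implementation; as a lightweight stand-in for a first look, gradient-times-input saliency (not LRP) gives a rough per-timestep relevance map:

```python
import torch

# Assumes `model` and the input convention from the Usage section below.
x = torch.randn(1, 8192, 12, requires_grad=True)  # one ECG record
logits = model(x)
logits[0, logits[0].argmax()].backward()          # attribute the top class
relevance = (x.grad * x).sum(dim=-1).abs()        # (1, seq_len) per-timestep scores
```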
## Model Parameters
- Total Parameters: ~29.3M
- Embedding Dimension: 384
- Number of Mamba-2 Layers: 4
- Number of Attention Heads: 4
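A quick check of the total, assuming `model` is built as in the Usage section below:

```python
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"{n_params / 1e6:.1f}M parameters")  # expected: ~29.3M with the defaults
```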
## Dataset
The model is designed for the PhysioNet Challenge 2021 dataset with 5 arrhythmia classes:
- SR (Sinus Rhythm)
- AF (Atrial Fibrillation)
- AFL (Atrial Flutter)
- PAC (Premature Atrial Contraction)
- PVC (Premature Ventricular Contraction)
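A minimal loading sketch with `wfdb` (the record path is a placeholder; the layout depends on where the Challenge data is downloaded):

```python
import wfdb

# Read one 12-lead record; `signal` has shape (n_samples, 12).
signal, fields = wfdb.rdsamp("path/to/challenge2021/record_name")

# Class abbreviations used by this repository (label order assumed):
CLASSES = ["SR", "AF", "AFL", "PAC", "PVC"]
```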
## Requirements
```
torch>=2.0
mamba-ssm>=2.0
causal-conv1d
einops
wfdb
numpy
scikit-learn
matplotlib
```
## Files

- `models_mamba_ecg.py` - Original VisionMamba implementation for ECG
- `Enhanced_ECG_Mamba_Test.ipynb` - Enhanced ECG-Mamba2 implementation with all improvements
- `README.md` - This file
- `LICENSE` - MIT License
## Usage
```python
import torch

from enhanced_ecg_mamba2 import EnhancedECGMamba2

# Create model
model = EnhancedECGMamba2(
    n_classes=5,
    embed_dim=384,
    n_layers=4,
    use_multi_branch=True,
    use_attention=True,
)

# Forward pass on a dummy batch
# Input: (batch, seq_len=8192, channels=12)
x = torch.randn(2, 8192, 12)
output = model(x)  # logits: (batch, n_classes)
```
## Citation
If you use this code, please cite:
```bibtex
@software{enhanced_ecg_mamba2,
  title={Enhanced ECG-Mamba2: Bidirectional State Space Model for ECG Classification},
  year={2024},
  note={Improvements over models_mamba_ecg.py with Mamba-2, bidirectional scanning, multi-branch architecture, and attention}
}
```
## License
This project is licensed under the MIT License - see the LICENSE file for details.
## Acknowledgments
- Original VisionMamba implementation
- Mamba and Mamba-2 from state-spaces/mamba
- PhysioNet Challenge 2021 for the ECG dataset