Update README.md
Browse files
README.md
CHANGED
@@ -10,68 +10,86 @@ tags:
- Sam-2
- text-generation
---

# 🧠 Model Card: Sam‑2.0

## 📌 Model Overview

**Sam‑2.0** is a minimal, modular, decoder‑only Transformer architecture designed for chat‑style reasoning tasks.
It emphasizes reproducibility, ablation‑friendly design, and clean benchmarking across input modalities.

- **Architecture**: Decoder‑only Transformer with RMSNorm, SwiGLU feed‑forward, and causal masking
- **Training Objective**: Causal language modeling (CLM) with role‑based label masking
- **Checkpoint**: `sam2-epoch35.safetensors`
- **Final Train Loss**: 1.04
- **Validation Loss**: Not tracked in this run
- **Training Duration**: ~6272 s over 35 epochs
- **Framework**: PyTorch + Hugging Face Transformers (custom model class)
## 🧱 Model Architecture

| Component         | Description                                                  |
|-------------------|--------------------------------------------------------------|
| Backbone          | Decoder‑only Transformer stack                               |
| Normalization     | RMSNorm                                                      |
| Attention         | Multi‑head self‑attention (causal)                           |
| Feed‑Forward      | SwiGLU activation with dropout                               |
| Positional Bias   | Learned absolute positions (no RoPE in this minimal variant) |
| Head              | Tied‑embedding LM head                                       |
| Checkpoint Format | `safetensors` with metadata for reproducibility              |
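To make the table concrete, the sketch below shows one decoder block in this style: pre‑norm RMSNorm, causal multi‑head self‑attention, and a SwiGLU feed‑forward. It is an illustrative approximation only; the class names, default dimensions, and use of `nn.MultiheadAttention` are assumptions, not the released `Sam2` implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class RMSNorm(nn.Module):
    """Root-mean-square layer norm (no mean-centering, no bias)."""
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
        return x * rms * self.weight

class SwiGLU(nn.Module):
    """Gated feed-forward: down(SiLU(gate(x)) * up(x)), with dropout."""
    def __init__(self, dim: int, hidden: int, dropout: float = 0.1):
        super().__init__()
        self.gate = nn.Linear(dim, hidden, bias=False)
        self.up = nn.Linear(dim, hidden, bias=False)
        self.down = nn.Linear(hidden, dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.dropout(self.down(F.silu(self.gate(x)) * self.up(x)))

class DecoderBlock(nn.Module):
    """Pre-norm block: RMSNorm -> causal self-attention -> RMSNorm -> SwiGLU."""
    def __init__(self, dim: int = 512, n_heads: int = 8, dropout: float = 0.1):
        super().__init__()
        self.attn_norm = RMSNorm(dim)
        self.attn = nn.MultiheadAttention(dim, n_heads, dropout=dropout, batch_first=True)
        self.ffn_norm = RMSNorm(dim)
        self.ffn = SwiGLU(dim, hidden=4 * dim, dropout=dropout)

    def forward(self, x):
        # Causal mask: True marks positions a query may NOT attend to (the future).
        T = x.size(1)
        mask = torch.triu(torch.ones(T, T, dtype=torch.bool, device=x.device), diagonal=1)
        h = self.attn_norm(x)
        attn_out, _ = self.attn(h, h, h, attn_mask=mask, need_weights=False)
        x = x + attn_out
        x = x + self.ffn(self.ffn_norm(x))
        return x

block = DecoderBlock()
x = torch.randn(2, 16, 512)   # (batch, sequence, hidden)
y = block(x)                  # same shape: (2, 16, 512)
```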
## 🧪 Training Details

- **Dataset**: [pfb30/multi_woz_v22](https://huggingface.co/datasets/pfb30/multi_woz_v22)
- **Batch Size**: 8
- **Optimizer**: AdamW
- **Learning Rate**: 2 × 10⁻⁴ (constant in this run)
- **Loss Function**: Cross‑entropy over assistant tokens only (see the sketch after this list)
- **Hardware**: Kaggle GPU runtime
- **Logging**: Step‑wise loss tracking, no validation during training
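A minimal illustration of the role‑based label masking described above: only tokens from assistant turns keep their labels, everything else is set to the ignore index so it contributes nothing to the cross‑entropy. The helper names and the `assistant_mask` input are assumptions for illustration, not the exact training code.

```python
import torch
import torch.nn.functional as F

IGNORE_INDEX = -100  # positions with this label are skipped by cross_entropy

def build_labels(input_ids: torch.Tensor, assistant_mask: torch.Tensor) -> torch.Tensor:
    """Copy the input ids as CLM labels, but ignore every non-assistant token.

    assistant_mask is a bool tensor of the same shape marking tokens that belong
    to assistant turns (e.g. everything between <|assistant|> and <|eot|>).
    """
    labels = input_ids.clone()
    labels[~assistant_mask] = IGNORE_INDEX
    return labels

def clm_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Next-token cross-entropy: predict token t+1 from positions <= t."""
    shift_logits = logits[:, :-1, :]   # (B, T-1, V)
    shift_labels = labels[:, 1:]       # (B, T-1)
    return F.cross_entropy(
        shift_logits.reshape(-1, shift_logits.size(-1)),
        shift_labels.reshape(-1),
        ignore_index=IGNORE_INDEX,
    )
```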
## 📊 Evaluation

| Metric           | Value | Notes                          |
|------------------|-------|--------------------------------|
| Final Train Loss | 1.04  | Achieved at Epoch 35/35        |
| Validation Loss  | —     | Not tracked in this run        |
| Inference Speed  | Fast  | Lightweight architecture       |
| Generalisation   | TBD   | To be compared against Sam‑2.5 |
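For context, if the reported loss is mean cross‑entropy in nats per token, a final train loss of 1.04 corresponds to a training perplexity of exp(1.04) ≈ 2.8.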
## 🔧 Intended Use

- **Research**: Benchmarking modular architectures and ablation studies
- **Education**: Reasoning scaffolds and logic quizzes
- **Deployment**: Lightweight agents for chat and dialogue modeling

## 🚫 Limitations

- No validation tracking — generalisation must be inferred via external harnesses
- Trained on MultiWOZ v2.2 only — may not generalize to other domains without fine‑tuning
- Minimal architecture — no RoPE/MQA in this variant

## 📁 Files

- `sam2-epoch35.safetensors` — final checkpoint (see the metadata snippet below)
- `config.json` — architecture and training config
- `tokenizer.json` — tokenizer with special tokens
- `README.md` — training logs and setup instructions
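Since the checkpoint is stored as `safetensors` with metadata (per the architecture table), the snippet below shows one way to inspect it. The actual metadata keys are not documented in this card, so it simply prints whatever was stored at save time.

```python
from safetensors import safe_open

# Open the checkpoint header without loading the tensors into memory.
with safe_open("sam2-epoch35.safetensors", framework="pt", device="cpu") as f:
    print(f.metadata())        # metadata dict (contents depend on how it was saved)
    print(list(f.keys())[:5])  # a few tensor names stored in the checkpoint
```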
## 🧩 How to Load

```python
import json

import torch
from safetensors.torch import load_file
from transformers import AutoTokenizer

from sam2 import Sam2, Sam2Config  # your custom model class

tok = AutoTokenizer.from_pretrained("Smilyai-labs/Sam-2.0")
cfg = Sam2Config(**json.load(open("config.json")))
model = Sam2(cfg)

# safetensors checkpoints are loaded with the safetensors library, not torch.load
state = load_file("sam2-epoch35.safetensors", device="cpu")
model.load_state_dict(state)
model.eval()

prompt = "<|user|> Hello! <|eot|>\n<|assistant|>"
ids = tok.encode(prompt, return_tensors="pt")

# Greedy decoding: append the most likely next token until <eos> or 50 steps.
with torch.no_grad():
    for _ in range(50):
        logits = model(ids)
        next_id = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
        ids = torch.cat([ids, next_id], dim=1)
        if next_id.item() == tok.eos_token_id:
            break

print(tok.decode(ids[0]))
```
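The loading cell above is a sketch: `from sam2 import Sam2, Sam2Config` assumes the repository's custom model class is importable locally (adjust the import path to wherever the class lives), and the generation loop is plain greedy decoding with a 50‑token cap rather than a tuned sampling setup.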