Commit 455129a · 1 Parent(s): 76e8ee6

Edited comments

Files changed:
- attention.py +2 -10
- phi2_model.py +3 -3
attention.py CHANGED

@@ -19,9 +19,7 @@ except ImportError:
 
 
 class RotaryEmbedding(nn.Module):
-    """Rotary positional embedding (RoPE)
-    See https://www.youtube.com/watch?v=C6rV8BsrrCc
-    """
+    """Rotary positional embedding (RoPE). See https://www.youtube.com/watch?v=C6rV8BsrrCc"""
 
     def __init__(
         self,
@@ -129,8 +127,6 @@ class RotaryEmbedding(nn.Module):
 
 
 class SelfAttention(nn.Module):
-    """Self-attention layer, taken from Phi2 model."""
-
     def __init__(
         self,
         qk_scale: float | None = None,  # will use 1/sqrt(d) if set to None
@@ -174,8 +170,6 @@ class SelfAttention(nn.Module):
 
 
 class CrossAttention(nn.Module):
-    """Cross-attention layer, taken from Phi2 model."""
-
     def __init__(
         self,
         qk_scale: float | None = None,  # will use 1/sqrt(d) if set to None
@@ -225,8 +219,6 @@ class CrossAttention(nn.Module):
 
 
 class MLP(nn.Module):
-    """Taken from Phi2 as well."""
-
     def __init__(
         self,
         d_embedding: int,
@@ -489,7 +481,7 @@ class MHA(nn.Module):
 
 
 class ParallelAttentionBlock(nn.Module):
-    """
+    """Calculates attention and MLP in parallel."""
 
     def __init__(
         self,
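Note: the restored RotaryEmbedding docstring only names RoPE and links a video, so here is a minimal, hypothetical sketch of the idea for readers of this diff. It is not the repo's RotaryEmbedding class: the function name, the interleaved-pair convention, and the (batch, seq_len, n_heads, d_head) layout are all assumptions.

import torch

def apply_rope(x: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    """Rotate channel pairs of x (batch, seq_len, n_heads, d_head) by a
    position-dependent angle; d_head is assumed to be even."""
    _, seq_len, _, d_head = x.shape
    # One frequency per channel pair, decaying geometrically across pairs.
    inv_freq = 1.0 / (base ** (torch.arange(0, d_head, 2, dtype=torch.float32) / d_head))
    angles = torch.outer(torch.arange(seq_len, dtype=torch.float32), inv_freq)
    cos = angles.cos()[None, :, None, :]  # broadcast over batch and heads
    sin = angles.sin()[None, :, None, :]
    x1, x2 = x[..., 0::2], x[..., 1::2]   # even/odd channels form 2D pairs
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin  # standard 2D rotation of each pair
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

Applied to the query and key tensors before the dot product, this makes q·k depend on the relative offset between positions rather than their absolute indices.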
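The new ParallelAttentionBlock docstring says it "Calculates attention and MLP in parallel." As a rough, hypothetical illustration of that layout (not the repo's implementation; the branch modules below are stand-ins for its own SelfAttention and MLP classes), a parallel block feeds one shared pre-norm output to both branches and sums them into the residual, instead of chaining attention and MLP sequentially:

import torch
from torch import nn

class ParallelBlockSketch(nn.Module):
    """Hypothetical parallel-residual block: attention and MLP branches share
    one layer norm and their outputs are both added to the residual stream."""

    def __init__(self, d_embedding: int, n_heads: int, dropout: float = 0.0):
        super().__init__()
        self.ln = nn.LayerNorm(d_embedding)
        # Stand-in branches; the real block would use the repo's SelfAttention / MLP.
        self.attn = nn.MultiheadAttention(d_embedding, n_heads, dropout=dropout, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(d_embedding, 4 * d_embedding),
            nn.GELU(),
            nn.Linear(4 * d_embedding, d_embedding),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.ln(x)                    # one shared pre-norm
        attn_out, _ = self.attn(h, h, h)  # attention branch
        mlp_out = self.mlp(h)             # MLP branch sees the same input
        return x + attn_out + mlp_out     # both branches join the residual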
phi2_model.py CHANGED

@@ -37,7 +37,7 @@ class Phi2PreTrainedModel(PreTrainedModel):
         input_ids: torch.LongTensor,  # dim: (batch_size, seq_len)
         kv_cache: KVCache | None = None,
         key_padding_mask: torch.LongTensor | torch.BoolTensor | None = None,
-        **kwargs,
+        **kwargs,  # has to be here
     ) -> dict[str, Any]:
         if not kv_cache:
             kv_cache = KVCache(
@@ -61,7 +61,7 @@ class Phi2PreTrainedModel(PreTrainedModel):
 
 
 class Embedding(nn.Module):
-    """Token embedding with dropout
+    """Token embedding with dropout."""
 
     def __init__(
         self,
@@ -150,7 +150,7 @@ class Phi2ModelForCausalLM(Phi2PreTrainedModel):
         kv_cache: KVCache | None = None,
         key_padding_mask: torch.BoolTensor | None = None,
         labels: torch.LongTensor | None = None,
-        **kwargs,
+        **kwargs,  # has to be here
     ) -> CausalLMOutputWithPast:
         x = self.model(input_ids, kv_cache=kv_cache, key_padding_mask=key_padding_mask)
         x = self.lm_head_layer_norm(x)
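On the phi2_model.py side: the "# has to be here" comments flag that **kwargs stays in the forward signatures, presumably so extra keyword arguments passed by upstream callers are accepted and ignored rather than raising a TypeError. The Embedding docstring now reads "Token embedding with dropout."; a minimal, hypothetical module matching that description (not necessarily the repo's exact code) would look like:

import torch
from torch import nn

class EmbeddingSketch(nn.Module):
    """Hypothetical 'token embedding with dropout' module."""

    def __init__(self, vocab_size: int, d_embedding: int, dropout: float = 0.1):
        super().__init__()
        self.wte = nn.Embedding(vocab_size, d_embedding)  # token id -> vector lookup
        self.drop = nn.Dropout(dropout)                   # regularization on embeddings

    def forward(self, input_ids: torch.LongTensor) -> torch.Tensor:
        # (batch_size, seq_len) -> (batch_size, seq_len, d_embedding)
        return self.drop(self.wte(input_ids))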
|