oweller2 committed · Commit f66abc1 · 1 Parent(s): 6d1817e
Commit message: udpdate

Files changed:
- attention.py +1 -1
- modeling_flexbert.py +19 -11
attention.py CHANGED

@@ -863,7 +863,7 @@ class FlexBertUnpadRopeAttention(FlexBertAttentionBase):
         qkv = self.Wqkv(hidden_states)

         # only needed for inference when we have KV cache
-        seqlen_offset = 0
+        seqlen_offset = max_seqlen * (len(cu_seqlens) - 2) if len(cu_seqlens) > 1 else 0

         # (total_seqlen, 3, nheads, headdim)
         qkv = qkv.view(-1, 3, self.num_attention_heads, self.attn_head_size)
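For intuition, below is a minimal standalone sketch of the new offset expression. It assumes the usual flash-attention varlen convention that `cu_seqlens` holds `batch_size + 1` cumulative sequence boundaries and that `max_seqlen` is the padded per-sequence length; the helper name and sample values are made up for illustration and are not part of this repo.

```python
import torch

def seqlen_offset_sketch(cu_seqlens: torch.Tensor, max_seqlen: int) -> int:
    # Mirrors the new line above: with batch_size + 1 boundaries in cu_seqlens,
    # len(cu_seqlens) - 2 == batch_size - 1, so the offset grows with batch size.
    return max_seqlen * (len(cu_seqlens) - 2) if len(cu_seqlens) > 1 else 0

# Hypothetical decode step: two packed sequences with boundaries [0, 5, 8], max_seqlen = 5.
print(seqlen_offset_sketch(torch.tensor([0, 5, 8], dtype=torch.int32), max_seqlen=5))  # -> 5
```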
modeling_flexbert.py CHANGED

@@ -1715,20 +1715,28 @@ class FlexBertForCausalLM(FlexBertPreTrainedModel):
     def prepare_inputs_for_generation(
         self,
         input_ids: torch.Tensor,
-        past_key_values: Optional[torch.FloatTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
         **kwargs
     ) -> dict:
-
-
-
-
-
-
-
-
-
-
+        if attention_mask is None:
+            attention_mask = torch.ones_like(input_ids)
+
+        batch_size, seq_len = input_ids.shape[:2]
+        input_ids, indices, cu_seqlens, max_seqlen, position_ids, _ = self.unpad_inputs(
+            input_ids, attention_mask, position_ids, None
+        )
+        breakpoint()
+        return {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "position_ids": position_ids,
+            "indices": indices,
+            "cu_seqlens": cu_seqlens,
+            "max_seqlen": max_seqlen,
+            "batch_size": batch_size,
+            "seq_len": seq_len
+        }

     def get_number_parameters(self, count_embeddings: bool = True, trainable: bool = True) -> int:
         """Returns the number of parameters in the model.
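For context, here is a rough sketch of the unpadded (varlen) format that the new `prepare_inputs_for_generation` returns. This is not the repo's actual `self.unpad_inputs`; it is an assumption-laden illustration of what such an unpadding step typically computes (the trailing `None` stands in for the sixth return value, which the diff discards as `_`).

```python
import torch

def unpad_inputs_sketch(input_ids, attention_mask, position_ids=None):
    # Illustration only: flatten the batch to a single stream of real tokens,
    # keeping the bookkeeping tensors that flash-attention-style kernels expect.
    seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)               # (batch,)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    cu_seqlens = torch.nn.functional.pad(
        torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)
    )                                                                     # (batch + 1,)
    max_seqlen = int(seqlens.max())
    unpadded_ids = input_ids.flatten()[indices]                           # (total_nnz,)
    if position_ids is None:
        position_ids = torch.cat(
            [torch.arange(n, dtype=torch.long) for n in seqlens.tolist()]
        )
    return unpadded_ids, indices, cu_seqlens, max_seqlen, position_ids, None

# Hypothetical usage with right-padded inputs:
ids = torch.tensor([[5, 6, 7, 0], [8, 9, 0, 0]])
mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
flat_ids, indices, cu_seqlens, max_seqlen, pos_ids, _ = unpad_inputs_sketch(ids, mask)
print(flat_ids.tolist(), cu_seqlens.tolist(), max_seqlen)  # [5, 6, 7, 8, 9] [0, 3, 5] 3
```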