oweller2 committed
Commit · e1a243a · 1 Parent(s): d831694

same as training code

- modeling_flexbert.py: +16 -38
- padding.py: +1 -1
modeling_flexbert.py
CHANGED
@@ -1529,16 +1529,13 @@ class FlexBertForCausalLM(FlexBertPreTrainedModel):
         self.unpad_embeddings = config.unpad_embeddings
         self.pad_logits = config.pad_logits
         self.compile_model = config.compile_model
-        self.vocab_size = config.vocab_size
         # self.masked_prediction = config.masked_prediction
 
         # Initialize weights and apply final processing
         self._init_weights(reset_params=False)
 
     def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None):
-        # Handle the XOR condition
         assert (module is None) != (reset_params is None), "arg module xor reset_params must be specified"
-
         if module is not None:
             # Add basic initialization for common module types
             if isinstance(module, (nn.Linear, nn.Embedding)):
@@ -1552,7 +1549,7 @@ class FlexBertForCausalLM(FlexBertPreTrainedModel):
         assert isinstance(reset_params, bool)
         self.bert._init_weights(reset_params=reset_params)
         self.lm_head._init_weights(reset_params=reset_params)
-
+
         if not self.config.tie_word_embeddings:
             init_weights(self.config, self.decoder, self.config.hidden_size, type_of_module=ModuleType.final_out)
 
@@ -1640,27 +1637,22 @@ class FlexBertForCausalLM(FlexBertPreTrainedModel):
         #
         # Prediction scores are only computed for masked tokens and the (bs,
         # seqlen) dimensions are flattened
+
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-        if self.unpad_embeddings and (indices is None and cu_seqlens is None and max_seqlen is None):
-            batch_size, seq_len = input_ids.shape[:2]
-            if attention_mask is None:
-                # unpad expects a encoder-like mask where all non-padding are ones
-                attention_mask = torch.ones_like(input_ids)
-                attention_mask[input_ids == 50283] = 0 # zero out pad tokens
+        if self.unpad_embeddings and (indices is None and cu_seqlens is None and max_seqlen is None):
+            batch_size, seq_len = input_ids.shape[:2]
             input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = self.unpad_inputs(
                 input_ids, attention_mask, position_ids, labels
             )
 
-
         hidden_states = self.bert(
             input_ids,
-            attention_mask=None, # let FA
+            attention_mask=None, # let FA do this
             position_ids=position_ids,
             indices=indices,
             cu_seqlens=cu_seqlens,
             max_seqlen=max_seqlen,
         )
-        # print(hidden_states.shape)
 
         if self.compile_model:
             logits = self.compiled_lm_head(hidden_states)
@@ -1673,26 +1665,24 @@ class FlexBertForCausalLM(FlexBertPreTrainedModel):
                 shift_labels = torch.full_like(input_ids, -100)
                 shift_labels[:-1] = input_ids[1:]
 
-                # Mask boundaries
+                # Mask boundaries, so eos doesn't predict bos
                 for i in range(len(cu_seqlens) - 1):
                     boundary_pos = cu_seqlens[i+1] - 1
                     shift_labels[boundary_pos] = -100
-
-                # Mask out PAD tokens
-                mask = (shift_labels == 50283)
-                shift_labels = torch.where(mask, torch.tensor(-100, device=shift_labels.device), shift_labels)
-
 
-
-
-
-
-                # breakpoint() # pkill -u oweller2 -f wandb
+                # NOTE: no padding or mask in there for now
+                assert 50283 not in shift_labels, f"PAD token found in shift_labels: {shift_labels}"
+                assert 50284 not in shift_labels, f"MASK token found in shift_labels: {shift_labels}"
+                assert shift_labels.shape == logits.shape[:-1] # Verify shapes align
 
             else:
                 # Padded case: simple shift
                 shift_labels = input_ids[..., 1:].contiguous()
                 logits = logits[..., :-1, :].contiguous()
+                # mask out PAD tokens in the shift_labels
+                mask = (shift_labels == 50283)
+                shift_labels = torch.where(mask, torch.tensor(-100, device=shift_labels.device), shift_labels)
+                assert shift_labels.shape == logits.shape[:-1] # Verify shapes align
 
             # For both cases, we'll use the shifted input_ids as our labels
             labels = shift_labels
@@ -1703,26 +1693,14 @@ class FlexBertForCausalLM(FlexBertPreTrainedModel):
                 shift_labels.view(-1)
             )
 
-        if self.unpad_embeddings: # revert back to normal logits
-            logits = logits.view(batch_size, -1, self.vocab_size)
-
         if self.pad_logits:
-            # print(f"Padding logits: {logits.shape}")
-            new_logits = self.pad_inputs(logits, indices, batch_size, seq_len-1)[0]
-            # print(f"New logits: {new_logits.shape}")
-            # print(new_logits.shape)
-            # if new_logits.dim() == 2:
-            #     new_logits = new_logits.unsqueeze(0)
             return CausalLMOutput(
                 loss=loss,
-                logits=new_logits,
+                logits=self.pad_inputs(logits, indices, batch_size, seq_len)[0],
                 hidden_states=None,
                 attentions=None,
            )
         else:
-            # print(f"Non-padding logits: {logits.shape}")
-            # if logits.dim() == 2:
-            #     logits = logits.unsqueeze(0)
             return CausalLMOutput(
                 loss=loss,
                 logits=logits,
@@ -1947,4 +1925,4 @@ def init_mlm_model_from_pretrained(
             pretrained_model.decoder, new_model.decoder, linear_type=TileLinear.default, mode=mode, bias_only=True
         )
     else:
-        tile_linear(pretrained_model.decoder, new_model.decoder, linear_type=TileLinear.default, mode=mode)
+        tile_linear(pretrained_model.decoder, new_model.decoder, linear_type=TileLinear.default, mode=mode)
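For readers skimming the diff: the unpadded branch now builds next-token labels directly on the packed (unpadded) token stream and uses `cu_seqlens` so the last token of one document never gets the first token of the next as its target. Below is a minimal, self-contained sketch of that logic, not the repository's code: the PAD id 50283 and the ignore index -100 come from the diff, the vectorized boundary masking replaces the commit's Python loop, and the explicit PAD masking stands in for the commit's assertion that no PAD/MASK ids survive unpadding.

```python
import torch

PAD_ID = 50283   # pad token id used in the diff above
IGNORE = -100    # ignored by torch.nn.CrossEntropyLoss

def shift_labels_unpadded(input_ids: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
    """Next-token labels for a packed (unpadded) 1D token stream.

    input_ids:  shape (total_tokens,) -- all sequences concatenated
    cu_seqlens: shape (num_seqs + 1,) -- cumulative sequence lengths, e.g. [0, 5, 9]
    """
    labels = torch.full_like(input_ids, IGNORE)
    labels[:-1] = input_ids[1:]          # position i predicts token i+1
    # mask sequence boundaries: the last position of each packed sequence
    # must not predict the first token of the following sequence
    labels[cu_seqlens[1:] - 1] = IGNORE
    # pad tokens should never appear as targets either
    labels[labels == PAD_ID] = IGNORE
    return labels

# usage: two packed sequences of lengths 5 and 4
ids = torch.tensor([10, 11, 12, 13, 2, 20, 21, 22, 2])
cu = torch.tensor([0, 5, 9])
print(shift_labels_unpadded(ids, cu))
# tensor([  11,   12,   13,     2, -100,   21,   22,     2, -100])
```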
padding.py
CHANGED
@@ -84,4 +84,4 @@ def pad_input(
     padded_labels[indices] = labels
     padded_labels = padded_labels.view(batch, seqlen)
 
-    return padded_inputs, padded_labels
+    return padded_inputs, padded_labels
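The `pad_logits` path in modeling_flexbert.py returns `self.pad_inputs(logits, indices, batch_size, seq_len)[0]`, which relies on the same scatter that `pad_input` in padding.py performs: unpadded rows are written back into a zero-initialized flat buffer at the positions recorded in `indices` during unpadding, then reshaped to `(batch, seqlen, ...)`. A small sketch of that scatter, with illustrative names rather than the repository's exact signature:

```python
import torch

def pad_back(unpadded: torch.Tensor, indices: torch.Tensor,
             batch: int, seqlen: int) -> torch.Tensor:
    """Scatter unpadded rows (total_tokens, dim) back to (batch, seqlen, dim).

    indices holds the flat positions (row-major over batch * seqlen) that each
    unpadded row originally came from, as recorded when the batch was unpadded.
    """
    dim = unpadded.shape[-1]
    out = torch.zeros(batch * seqlen, dim,
                      dtype=unpadded.dtype, device=unpadded.device)
    out[indices] = unpadded                    # padding positions stay zero
    return out.view(batch, seqlen, dim)

# usage: 7 real tokens from a (2, 4) batch where the last position is padding
logits = torch.randn(7, 3)                     # last dim = 3 for illustration
indices = torch.tensor([0, 1, 2, 3, 4, 5, 6])  # flat positions of the real tokens
padded = pad_back(logits, indices, batch=2, seqlen=4)
print(padded.shape)                            # torch.Size([2, 4, 3])
```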