return past hidden states when `output_hidden_states` provided
#59
by
noahtren
- opened
- modeling_phi.py +2 -1
modeling_phi.py
CHANGED
|
@@ -947,6 +947,7 @@ class PhiForCausalLM(PhiPreTrainedModel):
|
|
| 947 |
input_ids: torch.LongTensor,
|
| 948 |
past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
|
| 949 |
attention_mask: Optional[torch.BoolTensor] = None,
|
|
|
|
| 950 |
labels: Optional[torch.LongTensor] = None,
|
| 951 |
**kwargs,
|
| 952 |
) -> CausalLMOutputWithPast:
|
|
@@ -957,4 +958,4 @@ class PhiForCausalLM(PhiPreTrainedModel):
|
|
| 957 |
if labels is not None:
|
| 958 |
loss = self.loss(lm_logits, labels)
|
| 959 |
|
| 960 |
-
return CausalLMOutputWithPast(loss=loss, logits=lm_logits, past_key_values=past_key_values)
|
|
|
|
| 947 |
input_ids: torch.LongTensor,
|
| 948 |
past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
|
| 949 |
attention_mask: Optional[torch.BoolTensor] = None,
|
| 950 |
+
output_hidden_states: Optional[bool] = None,
|
| 951 |
labels: Optional[torch.LongTensor] = None,
|
| 952 |
**kwargs,
|
| 953 |
) -> CausalLMOutputWithPast:
|
|
|
|
| 958 |
if labels is not None:
|
| 959 |
loss = self.loss(lm_logits, labels)
|
| 960 |
|
| 961 |
+
return CausalLMOutputWithPast(loss=loss, logits=lm_logits, past_key_values=past_key_values, hidden_states=hidden_states if output_hidden_states else None)
|