crumb committed on
Commit be78a97 · 1 Parent(s): 9c69f58

Upload model

config.json ADDED
@@ -0,0 +1,31 @@
{
  "activation_function": "silu",
  "architectures": [
    "ShrinkModelForCausalLM"
  ],
  "auto_map": {
    "AutoConfig": "configuration_shrink.ShrinkConfig",
    "AutoModelForCausalLM": "modeling_shrink.ShrinkModelForCausalLM"
  },
  "bos_token_id": 1,
  "combined_qkv": true,
  "eos_token_id": 2,
  "hidden_size": 1024,
  "hidden_size_0": 8192,
  "initializer_range": 0.02,
  "intermediate_size": 4096,
  "layer_dropout_prob": 0.0,
  "layer_norm_epsilon": 1e-06,
  "lm_head_bias": false,
  "max_position_embeddings": 2048,
  "model_type": "shrink",
  "num_attention_heads": 16,
  "num_hidden_layers": 14,
  "projection_bias": true,
  "qk_hidden_size": null,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.36.2",
  "use_bias": false,
  "use_cache": true,
  "vocab_size": 32000
}
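
Because config.json registers the custom classes through "auto_map", the checkpoint can only be loaded with remote code enabled. A minimal loading sketch; the repository id below is a placeholder, not part of this commit:

from transformers import AutoConfig, AutoModelForCausalLM

repo_id = "<this-repo-id>"  # placeholder: substitute the actual Hub repository id

# trust_remote_code=True lets transformers import configuration_shrink.py and
# modeling_shrink.py from the repository, as declared in "auto_map".
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    trust_remote_code=True,
    torch_dtype="auto",  # honors "torch_dtype": "bfloat16" from config.json
)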
configuration_shrink.py ADDED
@@ -0,0 +1,67 @@
from collections import OrderedDict
from typing import Any, List, Mapping, Optional

from transformers import PreTrainedTokenizer, TensorType, is_torch_available
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging

logger = logging.get_logger(__name__)


class ShrinkConfig(PretrainedConfig):
    model_type = "shrink"
    keys_to_ignore_at_inference = ["past_key_values"]
    attribute_map = {
        "hidden_size": "hidden_size",
        "max_position_embeddings": "max_position_embeddings",
        "num_attention_heads": "num_attention_heads",
        "num_hidden_layers": "num_hidden_layers",
    }

    def __init__(
        self,
        vocab_size=32000,
        max_position_embeddings=2048,
        hidden_size_0=8192,
        hidden_size=768,
        qk_hidden_size=None,  # in case you want to use cross-attention
        num_hidden_layers=10,
        num_attention_heads=12,
        intermediate_size=None,
        activation_function="silu",
        layer_norm_epsilon=1e-6,
        initializer_range=0.02,
        scale_attn_weights=True,
        use_cache=True,
        bos_token_id=1,
        eos_token_id=2,
        combined_qkv=True,
        use_bias=False,
        projection_bias=True,
        lm_head_bias=False,
        layer_dropout_prob=0.0,  # read by ShrinkModel.forward; also present in config.json
        **kwargs,
    ):
        self.qk_hidden_size = qk_hidden_size
        self.lm_head_bias = lm_head_bias
        self.projection_bias = projection_bias
        self.use_bias = use_bias
        self.hidden_size_0 = hidden_size_0
        self.combined_qkv = combined_qkv
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = (
            intermediate_size if intermediate_size is not None else hidden_size * 4
        )
        self.activation_function = activation_function
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range
        self.layer_dropout_prob = layer_dropout_prob

        self.use_cache = use_cache

        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id

        super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
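
For reference, a hedged sketch of building the config directly in code; the values mirror config.json above, and intermediate_size falls back to 4 * hidden_size when left unset. The local import path is an assumption:

# Assumed import path: configuration_shrink.py sits next to this script (or on sys.path).
from configuration_shrink import ShrinkConfig

config = ShrinkConfig(
    vocab_size=32000,
    hidden_size=1024,
    hidden_size_0=8192,
    num_hidden_layers=14,
    num_attention_heads=16,
    max_position_embeddings=2048,
    intermediate_size=4096,  # if omitted, __init__ falls back to 4 * hidden_size
)
print(config.model_type, config.intermediate_size)  # shrink 4096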
generation_config.json ADDED
@@ -0,0 +1,6 @@
{
  "_from_model_config": true,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "transformers_version": "4.36.2"
}
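
The generation config only pins the BOS/EOS token ids; sampling settings are left to transformers defaults. A hedged end-to-end generation sketch; this commit ships no tokenizer, so where the tokenizer comes from is an assumption:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "<this-repo-id>"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(repo_id)  # assumption: a tokenizer is available here
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)

inputs = tokenizer("The quick brown fox", return_tensors="pt")
with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=32, do_sample=True, temperature=0.8)
print(tokenizer.decode(out[0], skip_special_tokens=True))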
model-00001-of-00002.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:bebeb0c195aa68d37997cb841071b7c75a2e67a4fa0d755263d366c6a601f34d
size 1027759938
model-00002-of-00002.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4c75679b86904686700edbe306217176fa25214dacc0a6a4413a10d751df32b0
size 524288128
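
Both entries are Git LFS pointer files: the shard bytes live in LFS storage, and the oid/size fields identify them. A small sketch for verifying downloaded shards against these pointers:

import hashlib
from pathlib import Path

# oid/size values copied from the two pointer files above.
expected = {
    "model-00001-of-00002.safetensors":
        ("bebeb0c195aa68d37997cb841071b7c75a2e67a4fa0d755263d366c6a601f34d", 1027759938),
    "model-00002-of-00002.safetensors":
        ("4c75679b86904686700edbe306217176fa25214dacc0a6a4413a10d751df32b0", 524288128),
}

for name, (oid, size) in expected.items():
    data = Path(name).read_bytes()  # assumes the shards were downloaded to the working directory
    assert len(data) == size, f"{name}: size mismatch"
    assert hashlib.sha256(data).hexdigest() == oid, f"{name}: checksum mismatch"
    print(f"{name}: OK")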
model.safetensors.index.json ADDED
@@ -0,0 +1,144 @@
{
  "metadata": {
    "total_size": 1552033794
  },
  "weight_map": {
    "lm_head.0.bias": "model-00001-of-00002.safetensors",
    "lm_head.0.weight": "model-00001-of-00002.safetensors",
    "lm_head.1.weight": "model-00002-of-00002.safetensors",
    "transformer.h.0.attn.out.weight": "model-00001-of-00002.safetensors",
    "transformer.h.0.attn.qkv.weight": "model-00001-of-00002.safetensors",
    "transformer.h.0.ffn.down_proj.weight": "model-00001-of-00002.safetensors",
    "transformer.h.0.ffn.gate_proj.weight": "model-00001-of-00002.safetensors",
    "transformer.h.0.ffn.up_proj.weight": "model-00001-of-00002.safetensors",
    "transformer.h.0.ln1.bias": "model-00001-of-00002.safetensors",
    "transformer.h.0.ln1.weight": "model-00001-of-00002.safetensors",
    "transformer.h.0.ln2.bias": "model-00001-of-00002.safetensors",
    "transformer.h.0.ln2.weight": "model-00001-of-00002.safetensors",
    "transformer.h.1.attn.out.weight": "model-00001-of-00002.safetensors",
    "transformer.h.1.attn.qkv.weight": "model-00001-of-00002.safetensors",
    "transformer.h.1.ffn.down_proj.weight": "model-00001-of-00002.safetensors",
    "transformer.h.1.ffn.gate_proj.weight": "model-00001-of-00002.safetensors",
    "transformer.h.1.ffn.up_proj.weight": "model-00001-of-00002.safetensors",
    "transformer.h.1.ln1.bias": "model-00001-of-00002.safetensors",
    "transformer.h.1.ln1.weight": "model-00001-of-00002.safetensors",
    "transformer.h.1.ln2.bias": "model-00001-of-00002.safetensors",
    "transformer.h.1.ln2.weight": "model-00001-of-00002.safetensors",
    "transformer.h.10.attn.out.weight": "model-00001-of-00002.safetensors",
    "transformer.h.10.attn.qkv.weight": "model-00001-of-00002.safetensors",
    "transformer.h.10.ffn.down_proj.weight": "model-00001-of-00002.safetensors",
    "transformer.h.10.ffn.gate_proj.weight": "model-00001-of-00002.safetensors",
    "transformer.h.10.ffn.up_proj.weight": "model-00001-of-00002.safetensors",
    "transformer.h.10.ln1.bias": "model-00001-of-00002.safetensors",
    "transformer.h.10.ln1.weight": "model-00001-of-00002.safetensors",
    "transformer.h.10.ln2.bias": "model-00001-of-00002.safetensors",
    "transformer.h.10.ln2.weight": "model-00001-of-00002.safetensors",
    "transformer.h.11.attn.out.weight": "model-00001-of-00002.safetensors",
    "transformer.h.11.attn.qkv.weight": "model-00001-of-00002.safetensors",
    "transformer.h.11.ffn.down_proj.weight": "model-00001-of-00002.safetensors",
    "transformer.h.11.ffn.gate_proj.weight": "model-00001-of-00002.safetensors",
    "transformer.h.11.ffn.up_proj.weight": "model-00001-of-00002.safetensors",
    "transformer.h.11.ln1.bias": "model-00001-of-00002.safetensors",
    "transformer.h.11.ln1.weight": "model-00001-of-00002.safetensors",
    "transformer.h.11.ln2.bias": "model-00001-of-00002.safetensors",
    "transformer.h.11.ln2.weight": "model-00001-of-00002.safetensors",
    "transformer.h.12.attn.out.weight": "model-00001-of-00002.safetensors",
    "transformer.h.12.attn.qkv.weight": "model-00001-of-00002.safetensors",
    "transformer.h.12.ffn.down_proj.weight": "model-00001-of-00002.safetensors",
    "transformer.h.12.ffn.gate_proj.weight": "model-00001-of-00002.safetensors",
    "transformer.h.12.ffn.up_proj.weight": "model-00001-of-00002.safetensors",
    "transformer.h.12.ln1.bias": "model-00001-of-00002.safetensors",
    "transformer.h.12.ln1.weight": "model-00001-of-00002.safetensors",
    "transformer.h.12.ln2.bias": "model-00001-of-00002.safetensors",
    "transformer.h.12.ln2.weight": "model-00001-of-00002.safetensors",
    "transformer.h.13.attn.out.weight": "model-00001-of-00002.safetensors",
    "transformer.h.13.attn.qkv.weight": "model-00001-of-00002.safetensors",
    "transformer.h.13.ffn.down_proj.weight": "model-00001-of-00002.safetensors",
    "transformer.h.13.ffn.gate_proj.weight": "model-00001-of-00002.safetensors",
    "transformer.h.13.ffn.up_proj.weight": "model-00001-of-00002.safetensors",
    "transformer.h.13.ln1.bias": "model-00001-of-00002.safetensors",
    "transformer.h.13.ln1.weight": "model-00001-of-00002.safetensors",
    "transformer.h.13.ln2.bias": "model-00001-of-00002.safetensors",
    "transformer.h.13.ln2.weight": "model-00001-of-00002.safetensors",
    "transformer.h.2.attn.out.weight": "model-00001-of-00002.safetensors",
    "transformer.h.2.attn.qkv.weight": "model-00001-of-00002.safetensors",
    "transformer.h.2.ffn.down_proj.weight": "model-00001-of-00002.safetensors",
    "transformer.h.2.ffn.gate_proj.weight": "model-00001-of-00002.safetensors",
    "transformer.h.2.ffn.up_proj.weight": "model-00001-of-00002.safetensors",
    "transformer.h.2.ln1.bias": "model-00001-of-00002.safetensors",
    "transformer.h.2.ln1.weight": "model-00001-of-00002.safetensors",
    "transformer.h.2.ln2.bias": "model-00001-of-00002.safetensors",
    "transformer.h.2.ln2.weight": "model-00001-of-00002.safetensors",
    "transformer.h.3.attn.out.weight": "model-00001-of-00002.safetensors",
    "transformer.h.3.attn.qkv.weight": "model-00001-of-00002.safetensors",
    "transformer.h.3.ffn.down_proj.weight": "model-00001-of-00002.safetensors",
    "transformer.h.3.ffn.gate_proj.weight": "model-00001-of-00002.safetensors",
    "transformer.h.3.ffn.up_proj.weight": "model-00001-of-00002.safetensors",
    "transformer.h.3.ln1.bias": "model-00001-of-00002.safetensors",
    "transformer.h.3.ln1.weight": "model-00001-of-00002.safetensors",
    "transformer.h.3.ln2.bias": "model-00001-of-00002.safetensors",
    "transformer.h.3.ln2.weight": "model-00001-of-00002.safetensors",
    "transformer.h.4.attn.out.weight": "model-00001-of-00002.safetensors",
    "transformer.h.4.attn.qkv.weight": "model-00001-of-00002.safetensors",
    "transformer.h.4.ffn.down_proj.weight": "model-00001-of-00002.safetensors",
    "transformer.h.4.ffn.gate_proj.weight": "model-00001-of-00002.safetensors",
    "transformer.h.4.ffn.up_proj.weight": "model-00001-of-00002.safetensors",
    "transformer.h.4.ln1.bias": "model-00001-of-00002.safetensors",
    "transformer.h.4.ln1.weight": "model-00001-of-00002.safetensors",
    "transformer.h.4.ln2.bias": "model-00001-of-00002.safetensors",
    "transformer.h.4.ln2.weight": "model-00001-of-00002.safetensors",
    "transformer.h.5.attn.out.weight": "model-00001-of-00002.safetensors",
    "transformer.h.5.attn.qkv.weight": "model-00001-of-00002.safetensors",
    "transformer.h.5.ffn.down_proj.weight": "model-00001-of-00002.safetensors",
    "transformer.h.5.ffn.gate_proj.weight": "model-00001-of-00002.safetensors",
    "transformer.h.5.ffn.up_proj.weight": "model-00001-of-00002.safetensors",
    "transformer.h.5.ln1.bias": "model-00001-of-00002.safetensors",
    "transformer.h.5.ln1.weight": "model-00001-of-00002.safetensors",
    "transformer.h.5.ln2.bias": "model-00001-of-00002.safetensors",
    "transformer.h.5.ln2.weight": "model-00001-of-00002.safetensors",
    "transformer.h.6.attn.out.weight": "model-00001-of-00002.safetensors",
    "transformer.h.6.attn.qkv.weight": "model-00001-of-00002.safetensors",
    "transformer.h.6.ffn.down_proj.weight": "model-00001-of-00002.safetensors",
    "transformer.h.6.ffn.gate_proj.weight": "model-00001-of-00002.safetensors",
    "transformer.h.6.ffn.up_proj.weight": "model-00001-of-00002.safetensors",
    "transformer.h.6.ln1.bias": "model-00001-of-00002.safetensors",
    "transformer.h.6.ln1.weight": "model-00001-of-00002.safetensors",
    "transformer.h.6.ln2.bias": "model-00001-of-00002.safetensors",
    "transformer.h.6.ln2.weight": "model-00001-of-00002.safetensors",
    "transformer.h.7.attn.out.weight": "model-00001-of-00002.safetensors",
    "transformer.h.7.attn.qkv.weight": "model-00001-of-00002.safetensors",
    "transformer.h.7.ffn.down_proj.weight": "model-00001-of-00002.safetensors",
    "transformer.h.7.ffn.gate_proj.weight": "model-00001-of-00002.safetensors",
    "transformer.h.7.ffn.up_proj.weight": "model-00001-of-00002.safetensors",
    "transformer.h.7.ln1.bias": "model-00001-of-00002.safetensors",
    "transformer.h.7.ln1.weight": "model-00001-of-00002.safetensors",
    "transformer.h.7.ln2.bias": "model-00001-of-00002.safetensors",
    "transformer.h.7.ln2.weight": "model-00001-of-00002.safetensors",
    "transformer.h.8.attn.out.weight": "model-00001-of-00002.safetensors",
    "transformer.h.8.attn.qkv.weight": "model-00001-of-00002.safetensors",
    "transformer.h.8.ffn.down_proj.weight": "model-00001-of-00002.safetensors",
    "transformer.h.8.ffn.gate_proj.weight": "model-00001-of-00002.safetensors",
    "transformer.h.8.ffn.up_proj.weight": "model-00001-of-00002.safetensors",
    "transformer.h.8.ln1.bias": "model-00001-of-00002.safetensors",
    "transformer.h.8.ln1.weight": "model-00001-of-00002.safetensors",
    "transformer.h.8.ln2.bias": "model-00001-of-00002.safetensors",
    "transformer.h.8.ln2.weight": "model-00001-of-00002.safetensors",
    "transformer.h.9.attn.out.weight": "model-00001-of-00002.safetensors",
    "transformer.h.9.attn.qkv.weight": "model-00001-of-00002.safetensors",
    "transformer.h.9.ffn.down_proj.weight": "model-00001-of-00002.safetensors",
    "transformer.h.9.ffn.gate_proj.weight": "model-00001-of-00002.safetensors",
    "transformer.h.9.ffn.up_proj.weight": "model-00001-of-00002.safetensors",
    "transformer.h.9.ln1.bias": "model-00001-of-00002.safetensors",
    "transformer.h.9.ln1.weight": "model-00001-of-00002.safetensors",
    "transformer.h.9.ln2.bias": "model-00001-of-00002.safetensors",
    "transformer.h.9.ln2.weight": "model-00001-of-00002.safetensors",
    "transformer.ln_f.bias": "model-00001-of-00002.safetensors",
    "transformer.ln_f.weight": "model-00001-of-00002.safetensors",
    "transformer.wln.bias": "model-00001-of-00002.safetensors",
    "transformer.wln.weight": "model-00001-of-00002.safetensors",
    "transformer.wpe.scale_factor": "model-00001-of-00002.safetensors",
    "transformer.wte.0.weight": "model-00001-of-00002.safetensors",
    "transformer.wte.1.bias": "model-00001-of-00002.safetensors",
    "transformer.wte.1.weight": "model-00001-of-00002.safetensors"
  }
}
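
The index maps each parameter name to the shard that stores it; only lm_head.1.weight lives in the second shard. A hedged sketch of rebuilding a single state dict from the shards without instantiating the model, assuming the files sit in the working directory:

import json
from safetensors.torch import load_file

with open("model.safetensors.index.json") as f:
    index = json.load(f)

# Load each shard once, then pick every tensor from the shard the index points to.
shards = {name: load_file(name) for name in set(index["weight_map"].values())}
state_dict = {key: shards[shard][key] for key, shard in index["weight_map"].items()}

print(len(state_dict))                                # 137 tensors listed in the weight_map
print(state_dict["transformer.wte.0.weight"].shape)   # (32000, 8192): vocab_size x hidden_size_0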
modeling_shrink.py ADDED
@@ -0,0 +1,519 @@
import math
import os
import random
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from einops import repeat
from torch import nn
from torch.cuda.amp import autocast
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.activations import ACT2FN
from transformers.modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    CausalLMOutputWithCrossAttentions, QuestionAnsweringModelOutput,
    SequenceClassifierOutputWithPast, TokenClassifierOutput)
from transformers.modeling_utils import PreTrainedModel, SequenceSummary
from transformers.utils import (ModelOutput, logging)
from transformers.utils.model_parallel_utils import (assert_device_map,
                                                     get_device_map)

from .configuration_shrink import ShrinkConfig

logger = logging.get_logger(__name__)  # needed below: ShrinkModel.forward calls logger.warning_once

class SinusoidalPositional(torch.nn.Module):
    def __init__(self, embedding_dim, max_seq_length=5000):
        super().__init__()
        pe = torch.zeros(max_seq_length, embedding_dim)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, embedding_dim, 2).float()
            * (-math.log(10000.0) / embedding_dim)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe, persistent=False)

    def forward(self, input_ids):
        return self.pe[:, : input_ids.shape[1], :]


class ScaledSinusoidal(SinusoidalPositional):
    def __init__(self, embedding_dim, max_seq_length):
        super().__init__(embedding_dim, max_seq_length)
        self.scale_factor = torch.nn.Parameter(
            torch.tensor([1.0 / embedding_dim**0.5])
        )

    def forward(self, input_ids):
        return self.scale_factor * self.pe[:, : input_ids.shape[1], :]

class ShrinkAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.head_dim = config.hidden_size // config.num_attention_heads
        assert (
            self.head_dim * config.num_attention_heads == config.hidden_size
        ), "d_model must be divisible by n_head"
        self.use_bias = config.use_bias

        if not config.combined_qkv or config.qk_hidden_size is not None:
            self.query = nn.Linear(
                config.hidden_size, config.hidden_size, bias=self.use_bias
            )
            self.key = nn.Linear(
                config.hidden_size
                if not config.qk_hidden_size
                else config.qk_hidden_size,
                config.hidden_size,
                bias=self.use_bias,
            )
            self.value = nn.Linear(
                config.hidden_size
                if not config.qk_hidden_size
                else config.qk_hidden_size,
                config.hidden_size,
                bias=self.use_bias,
            )
        else:
            self.qkv = nn.Linear(
                config.hidden_size, config.hidden_size * 3, bias=self.use_bias
            )
        self.out = nn.Linear(config.hidden_size, config.hidden_size, bias=self.use_bias)

    def forward(self, x0, x1=None, causal=False, mask=None):
        batch_size = x0.size(0)

        def split_heads(x):
            return x.view(
                batch_size, -1, self.config.num_attention_heads, self.head_dim
            ).transpose(1, 2)

        # Mirror the condition used in __init__: separate q/k/v projections exist
        # whenever combined_qkv is off or a qk_hidden_size is configured.
        if not self.config.combined_qkv or self.config.qk_hidden_size is not None:
            q = split_heads(self.query(x0))
            k = split_heads(self.key(x1) if x1 is not None else self.key(x0))
            v = split_heads(self.value(x1 if x1 is not None else x0))
        else:
            q, k, v = self.qkv(x0).chunk(3, -1)
            q = split_heads(q)
            k = split_heads(k)
            v = split_heads(v)

        attn_output = F.scaled_dot_product_attention(
            q, k, v, attn_mask=None, dropout_p=0.0, is_causal=causal
        )
        attn_output = (
            attn_output.transpose(1, 2)
            .contiguous()
            .view(batch_size, -1, self.config.hidden_size)
        )
        return self.out(attn_output)

class ShrinkGLU(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.gate_proj = nn.Linear(
            config.hidden_size, config.intermediate_size, bias=False
        )
        self.up_proj = nn.Linear(
            config.hidden_size, config.intermediate_size, bias=False
        )
        self.down_proj = nn.Linear(
            config.intermediate_size, config.hidden_size, bias=False
        )
        self.act_fn = ACT2FN[config.activation_function]

    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))


class ShrinkBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.attn = ShrinkAttention(config)
        self.ffn = ShrinkGLU(config)
        self.ln1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        self.ln2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)

    def forward(self, x, mask=None):
        x = x + self.attn(self.ln1(x), causal=True, mask=mask)
        x = x + self.ffn(self.ln2(x))
        return x

class ShrinkPreTrainedModel(PreTrainedModel):
    config_class = ShrinkConfig
    base_model_prefix = "transformer"
    is_parallelizable = False
    supports_gradient_checkpointing = True
    _no_split_modules = ["ShrinkBlock"]
    _skip_keys_device_placement = "past_key_values"

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, (nn.Linear)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, ShrinkModel):
            module.gradient_checkpointing = value

class ShrinkModel(ShrinkPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.wte = nn.Sequential(
            nn.Embedding(config.vocab_size, config.hidden_size_0),
            nn.Linear(config.hidden_size_0, config.hidden_size),
        )
        self.wpe = ScaledSinusoidal(config.hidden_size, config.max_position_embeddings)
        self.wln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        self.h = nn.ModuleList(
            [ShrinkBlock(config) for i in range(config.num_hidden_layers)]
        )
        self.ln_f = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        self.model_parallel = False
        self.device_map = None
        self.gradient_checkpointing = False
        self.post_init()

    def get_input_embeddings(self):
        return self.wte[0]

    def set_input_embeddings(self, new_embeddings):
        self.wte[0] = new_embeddings

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
        # soooo not all of the params are able to be used, since I just copied this framework from modeling_gpt2
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time"
            )
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
            batch_size = input_ids.shape[0]
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            batch_size = inputs_embeds.shape[0]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if token_type_ids is not None:
            token_type_ids = token_type_ids.view(-1, input_shape[-1])
        if position_ids is not None:
            position_ids = position_ids.view(-1, input_shape[-1])

        if past_key_values is None:
            past_length = 0
            past_key_values = tuple([None] * len(self.h))
        else:
            past_length = past_key_values[0][0].size(-2)
        if position_ids is None:
            position_ids = torch.arange(
                past_length,
                input_shape[-1] + past_length,
                dtype=torch.long,
                device=device,
            )
            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])

        if attention_mask is not None:
            if batch_size <= 0:
                raise ValueError("batch_size has to be defined and > 0")
            attention_mask = attention_mask.view(batch_size, -1)
            attention_mask = attention_mask[:, None, None, :]
            attention_mask = attention_mask.to(dtype=self.dtype)  # fp16 compatibility
            attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min

        if self.config.add_cross_attention and encoder_hidden_states is not None:
            (
                encoder_batch_size,
                encoder_sequence_length,
                _,
            ) = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
            encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_attention_mask = None

        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        if inputs_embeds is None:
            inputs_embeds = self.wte(input_ids)
        position_embeds = self.wpe(input_ids)
        hidden_states = inputs_embeds + position_embeds
        hidden_states = self.wln(hidden_states)

        if token_type_ids is not None:
            token_type_embeds = self.wte(token_type_ids)
            hidden_states = hidden_states + token_type_embeds

        output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        presents = () if use_cache else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = (
            () if output_attentions and self.config.add_cross_attention else None
        )
        all_hidden_states = () if output_hidden_states else None
        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
            if random.uniform(0, 1) > self.config.layer_dropout_prob:
                if self.model_parallel:
                    torch.cuda.set_device(hidden_states.device)
                    if layer_past is not None:
                        layer_past = tuple(
                            past_state.to(hidden_states.device)
                            for past_state in layer_past
                        )
                    if attention_mask is not None:
                        attention_mask = attention_mask.to(hidden_states.device)
                    if isinstance(head_mask, torch.Tensor):
                        head_mask = head_mask.to(hidden_states.device)
                if output_hidden_states:
                    all_hidden_states = all_hidden_states + (hidden_states,)
                outputs = block(hidden_states, mask=attention_mask)
                outputs = (outputs,)
                hidden_states = outputs[0]

        hidden_states = self.ln_f(hidden_states)
        hidden_states = hidden_states.view(output_shape)
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, None, all_hidden_states, None, None]
                if v is not None
            )

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=None,
            hidden_states=all_hidden_states,
            attentions=None,
            cross_attentions=None,
        )

class ShrinkModelForCausalLM(ShrinkPreTrainedModel):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.transformer = ShrinkModel(config)
        self.lm_head = nn.Sequential(
            nn.Linear(
                config.hidden_size, config.hidden_size_0, bias=config.projection_bias
            ),
            nn.Linear(
                config.hidden_size_0, config.vocab_size, bias=config.lm_head_bias
            ),
        )
        self.model_parallel = False
        self.device_map = None
        self.post_init()

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
    ):
        token_type_ids = kwargs.get("token_type_ids", None)
        # only last token for inputs_ids if past is defined in kwargs
        if past_key_values:
            input_ids = input_ids[:, -1].unsqueeze(-1)
            if token_type_ids is not None:
                token_type_ids = token_type_ids[:, -1].unsqueeze(-1)

        attention_mask = kwargs.get("attention_mask", None)
        position_ids = kwargs.get("position_ids", None)

        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -1].unsqueeze(-1)
        else:
            position_ids = None

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "position_ids": position_ids,
                "attention_mask": attention_mask,
                "token_type_ids": token_type_ids,
            }
        )
        return model_inputs

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
            are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]

        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.transformer.first_device)
            hidden_states = hidden_states.to(self.lm_head.weight.device)

        lm_logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(lm_logits.device)
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
            )

        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
            cross_attentions=transformer_outputs.cross_attentions,
        )

    @staticmethod
    def _reorder_cache(
        past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
    ) -> Tuple[Tuple[torch.Tensor]]:
        """
        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
        [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
        beam_idx at every generation step.
        """
        return tuple(
            tuple(
                past_state.index_select(0, beam_idx.to(past_state.device))
                for past_state in layer_past
            )
            for layer_past in past_key_values
        )