File size: 73,802 Bytes
3eac0cc bdca3d4 3eac0cc bdca3d4 3eac0cc bdca3d4 3eac0cc bdca3d4 3eac0cc bdca3d4 3eac0cc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 
1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121 1122 1123 1124 1125 1126 1127 1128 1129 1130 1131 1132 1133 1134 1135 1136 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166 1167 1168 1169 1170 1171 1172 1173 1174 1175 1176 1177 1178 1179 1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193 1194 1195 1196 1197 1198 1199 1200 1201 1202 1203 1204 1205 1206 1207 1208 1209 1210 1211 1212 1213 1214 1215 1216 1217 1218 1219 1220 1221 1222 1223 1224 1225 1226 1227 1228 1229 1230 1231 1232 1233 1234 1235 1236 1237 1238 1239 1240 1241 1242 1243 1244 1245 1246 1247 1248 1249 1250 1251 1252 1253 1254 1255 1256 1257 1258 1259 1260 1261 1262 1263 1264 1265 1266 1267 1268 1269 1270 1271 1272 1273 1274 1275 1276 1277 1278 1279 1280 1281 1282 1283 1284 1285 1286 1287 1288 1289 1290 1291 1292 1293 1294 1295 1296 1297 1298 1299 1300 1301 1302 1303 1304 1305 1306 1307 1308 1309 1310 1311 1312 1313 1314 1315 1316 1317 1318 1319 1320 1321 1322 1323 1324 1325 1326 1327 1328 1329 1330 1331 1332 1333 1334 1335 1336 1337 1338 1339 1340 1341 1342 1343 1344 1345 1346 1347 1348 1349 1350 1351 1352 1353 1354 1355 1356 1357 1358 1359 1360 1361 1362 1363 1364 1365 1366 1367 1368 1369 1370 1371 1372 1373 1374 1375 1376 1377 1378 1379 1380 1381 1382 1383 1384 1385 1386 1387 1388 1389 1390 1391 1392 1393 1394 1395 1396 1397 1398 1399 1400 1401 1402 1403 1404 1405 1406 1407 1408 1409 1410 1411 1412 1413 1414 1415 1416 1417 1418 1419 1420 1421 
1422 1423 1424 1425 1426 1427 1428 1429 1430 1431 1432 1433 1434 1435 1436 1437 1438 1439 1440 1441 1442 1443 1444 1445 1446 1447 1448 1449 1450 1451 1452 1453 1454 1455 1456 1457 1458 1459 1460 1461 1462 1463 1464 1465 1466 1467 1468 1469 1470 1471 1472 1473 1474 1475 1476 1477 1478 1479 1480 1481 1482 1483 1484 1485 1486 1487 1488 1489 1490 1491 1492 1493 1494 1495 1496 1497 1498 1499 1500 1501 1502 1503 1504 1505 1506 1507 1508 1509 1510 1511 1512 1513 1514 1515 1516 1517 1518 1519 1520 1521 1522 1523 1524 1525 1526 1527 1528 1529 1530 1531 1532 1533 1534 1535 1536 1537 1538 1539 1540 1541 1542 1543 1544 1545 1546 1547 1548 1549 1550 1551 1552 1553 1554 1555 1556 1557 1558 1559 1560 1561 1562 1563 1564 1565 1566 1567 1568 1569 1570 1571 1572 |
"""Modeling file for HF compatibility and zero-shot experiments."""
import torch
import math
from torch import Tensor
from torch.nn.attention.flex_attention import create_block_mask, BlockMask, flex_attention
from torch.nn.attention import bias as attn_bias
from dataclasses import dataclass
from typing import Union, Optional, Any
from .raven_config_minimal import RavenConfig
from transformers.cache_utils import Cache, DynamicCache, StaticCache
###################### Huggingface Glue code I ##################################################################
from transformers import PreTrainedModel, GenerationMixin
from transformers.utils import ModelOutput
from transformers.generation.utils import GenerateDecoderOnlyOutput
import torch.nn.functional as F
from transformers import GenerationConfig
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb
# torch.backends.cuda.enable_math_sdp(False)
class RavenPreTrainedModel(PreTrainedModel):
    """Base class wiring Raven models into the Hugging Face ``PreTrainedModel`` API.

    The class attributes below are static declarations (config class, module
    names, capability flags); presumably they are consumed by the
    ``transformers`` loading code -- verify against the installed version.
    """

    config_class = RavenConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["SandwichBlock"]
    _skip_keys_device_placement = ["past_key_values"]
    _tied_weights_keys = ["lm_head.weight"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True
    _supports_quantized_cache = False
    _supports_static_cache = True
    _tp_plan = {}

    def _init_weights(self, module):
        """Placeholder initializer: only prints a notice, never touches ``module``.

        The notice is suppressed when a freshly created tensor lands on the
        meta device.
        """
        probe = torch.rand((1,))
        if probe.is_meta:
            return
        print("Random Initialization not implemented.")
@dataclass
class CausalLMOutputRecurrentLatents(ModelOutput):
    """Output container returned by ``RavenForCausalLM.forward``.

    Every field defaults to ``None``; ``forward`` fills in only the entries
    requested through its ``output_details`` argument.
    """

    # Cross-entropy loss when labels are given; forward() stores a 0.0 tensor otherwise.
    loss: Optional[torch.Tensor] = None
    # NOTE: despite the name, forward() stores exp(loss) here (detached).
    log_ppl: Optional[torch.Tensor] = None
    # Output of the LM head; None unless "return_logits" is set.
    logits: Optional[torch.Tensor] = None
    # The cache object that was passed to (or created by) forward().
    past_key_values: Optional[Cache] = None
    # Detached copy of the recurrent state taken before the coda; None unless "return_latents".
    latent_states: Optional[torch.Tensor] = None
    # Final normalized hidden state fed to the LM head; None unless "return_head".
    hidden_states: Optional[torch.Tensor] = None
    # Not populated by the forward() visible in this file.
    attention_maps: Optional[dict[int, torch.Tensor]] = None
    # Result of self.get_stats(...); None unless "return_stats".
    stats: Optional[dict] = None
###################### Minimal implementation from here ############################################################
class RMSNorm(torch.nn.Module):
    """Root-mean-square layer norm computed in float32 with a learnable scale."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """Divide ``x`` by the RMS of its last dimension (stabilized by ``eps``)."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        """Normalize ``x`` with autocast disabled, cast back to ``x``'s dtype, then scale."""
        # autocast needs a concrete device type, so meta tensors are mapped to "cuda"
        device_type = "cuda" if x.device.type == "meta" else x.device.type
        with torch.autocast(device_type=device_type, enabled=False):
            normalized = self._norm(x.float()).type_as(x)
            return normalized * self.weight

    def reset_parameters(self) -> None:
        """Reset the scale vector to all ones."""
        torch.nn.init.ones_(self.weight)
class HuginnDynamicCache(DynamicCache):
    """Uncoalesced key/value cache addressed as ``cache[step][token_position]``.

    ``step`` is the index of a layer / recurrent step; negative indices are
    used for the head (coda) layers so they never collide with the recurrent
    iteration indices. Entries are stored per token because, with per-token
    adaptive compute, some recurrent steps may be missing for some sequence
    positions. ``lookup_strategy`` decides how such gaps are filled when the
    cache is materialized.
    """

    def __init__(self, lookup_strategy: str = "full") -> None:
        super().__init__()
        self._seen_tokens = 0
        # structure: cache[index_of_layer_or_recurrent_step][index_in_sequence]
        self.key_cache: dict[int, dict[int, torch.Tensor]] = {}
        self.value_cache: dict[int, dict[int, torch.Tensor]] = {}
        self.lookup_strategy = lookup_strategy

    def _latest_matching_step(self, token_pos: int, step_idx: int, candidate_steps) -> int:
        """Return the largest recurrent step (>= 2, same position modulo 4) holding ``token_pos``.

        The modulo-4 rule hard-codes the current block structure. Steps that
        were never cached are treated as not holding the token. Raises
        ``ValueError`` when no candidate qualifies.
        """
        return max(
            s
            for s in candidate_steps
            if s >= 2 and s % 4 == step_idx % 4 and token_pos in self.key_cache.get(s, ())
        )

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        step_idx_tensor: torch.Tensor,
        lookup_strategy: Optional[str] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Store new key/value states for one step and return the materialized cache.

        Args:
            key_states: New keys; the second-to-last dimension enumerates tokens.
            value_states: New values, laid out like ``key_states``.
            step_idx_tensor: Step index (anything convertible with ``int``).
            lookup_strategy: Optional per-call override of ``self.lookup_strategy``.
                The override is honoured for every decision made in this call.

        Returns:
            ``(keys, values)`` stacked along the token dimension (``dim=-2``).

        Raises:
            ValueError: If entries are missing and the strategy is unknown, or
                if an m4 strategy finds no step holding a token.
        """
        step_idx: int = int(step_idx_tensor)  # todo: fix dicts with tensor step_idx, currently the memberships fail
        lookup_strategy = self.lookup_strategy if lookup_strategy is None else lookup_strategy

        # Remap recurrent steps onto a smaller set of cache slots when compressing.
        if "compress-" in lookup_strategy and step_idx > 1:  # hardcode for current model!
            if "compress-s" in lookup_strategy:
                compression_stage = int(lookup_strategy.split("compress-")[1][1:])
                step_idx = (step_idx - 2) % compression_stage + 2
            elif "compress-anchor" in lookup_strategy:
                # anchor onto first 8 recurrence steps, then re-use the next
                # 4 KV states = one recurrence for all future recurrences
                if step_idx - 2 >= 4 * 8:
                    step_idx = 34 + (step_idx - 34) % 4
            else:  # compress-r
                compression_stage = int(lookup_strategy.split("compress-")[1][1:])
                step_idx = (step_idx - 2) // compression_stage + 2

        # Init
        if step_idx not in self.key_cache:
            self.key_cache[step_idx] = {}
            self.value_cache[step_idx] = {}
        # Update the number of seen tokens, we assume that step_idx=0 (first prelude) is always hit
        if step_idx == 0:
            self._seen_tokens += key_states.shape[-2]

        # Add entries to cache
        key_offset = self._seen_tokens - key_states.shape[-2]
        for idx, entry in enumerate(key_states.unbind(dim=-2)):
            if "compress-" not in lookup_strategy:
                # without compression a (step, position) slot is written once; head steps (< 0) may overwrite
                assert step_idx < 0 or key_offset + idx not in self.key_cache[step_idx]
            self.key_cache[step_idx][key_offset + idx] = entry
        value_offset = self._seen_tokens - value_states.shape[-2]
        for idx, entry in enumerate(value_states.unbind(dim=-2)):
            self.value_cache[step_idx][value_offset + idx] = entry

        # Materialize past state based on lookup strategy:
        if len(self.key_cache[step_idx]) == self._seen_tokens or lookup_strategy == "full":
            # All entries are present (or gaps are accepted), materialize cache as normal
            return (
                torch.stack(list(self.key_cache[step_idx].values()), dim=-2),
                torch.stack(list(self.value_cache[step_idx].values()), dim=-2),
            )

        # some entries were not previously computed
        keys: list[torch.Tensor] = []
        values: list[torch.Tensor] = []
        if lookup_strategy.startswith("latest-m4"):
            for token_pos in range(self._seen_tokens):
                if step_idx >= 2:
                    source = self._latest_matching_step(token_pos, step_idx, range(step_idx + 1))
                else:
                    source = step_idx if token_pos in self.key_cache[step_idx] else 0
                keys.append(self.key_cache[source][token_pos])
                values.append(self.value_cache[source][token_pos])
        elif lookup_strategy.startswith("available-m4"):
            for token_pos in range(self._seen_tokens):
                if token_pos in self.key_cache[step_idx]:
                    source = step_idx
                else:
                    source = self._latest_matching_step(token_pos, step_idx, range(step_idx + 1))
                keys.append(self.key_cache[source][token_pos])
                values.append(self.value_cache[source][token_pos])
        elif lookup_strategy.startswith("always-last-m4"):
            for token_pos in range(self._seen_tokens):
                if step_idx >= 2:
                    # unlike latest-m4 this also considers steps beyond the current one
                    source = self._latest_matching_step(token_pos, step_idx, list(self.key_cache))
                else:
                    source = step_idx if token_pos in self.key_cache[step_idx] else 0
                keys.append(self.key_cache[source][token_pos])
                values.append(self.value_cache[source][token_pos])
        elif lookup_strategy.startswith("skip"):
            # only return the positions that exist for this step
            for token_pos in range(self._seen_tokens):
                if token_pos in self.key_cache[step_idx]:
                    keys.append(self.key_cache[step_idx][token_pos])
                    values.append(self.value_cache[step_idx][token_pos])
        elif lookup_strategy.startswith("randomized"):  # sanity check
            for token_pos in range(self._seen_tokens):
                if step_idx < 2:  # For prelude steps
                    source = step_idx if token_pos in self.key_cache[step_idx] else 0
                else:  # Get all steps from same block position
                    curr_modulo = (step_idx - 2) % 4 + 2
                    valid_steps = [
                        s
                        for s in range(2, step_idx + 1)
                        if (s - 2) % 4 + 2 == curr_modulo and token_pos in self.key_cache[s]
                    ]
                    source = valid_steps[torch.randint(len(valid_steps), (1,))]
                keys.append(self.key_cache[source][token_pos])
                values.append(self.value_cache[source][token_pos])
        else:
            raise ValueError(f"Unknown lookup strategy: {lookup_strategy}")
        return torch.stack(keys, dim=-2), torch.stack(values, dim=-2)

    def reset(self) -> None:
        """Reset the cache state."""
        self._seen_tokens = 0
        self.key_cache.clear()
        self.value_cache.clear()

    def clear_last_k_entries(self, k: int = 0):
        """Partially clear cache: drop the last ``k`` token positions from every step."""
        assert self._seen_tokens >= k
        self._seen_tokens = self._seen_tokens - k
        self.key_cache = {
            step: {seq: seq_cache for seq, seq_cache in cache.items() if seq < self._seen_tokens}
            for step, cache in self.key_cache.items()
        }
        self.value_cache = {
            step: {seq: seq_cache for seq, seq_cache in cache.items() if seq < self._seen_tokens}
            for step, cache in self.value_cache.items()
        }

    def get_seq_length(self, step_idx: int = 0) -> int:
        """Return the number of tokens seen so far (``step_idx`` is ignored)."""
        return self._seen_tokens

    def get_memory_usage(self) -> float:
        """Return the cache size in MiB, assuming values take as much memory as keys."""
        total_bytes = 0
        # For each recurrent step/layer index
        for key_seq_cache in self.key_cache.values():
            for key_tensor in key_seq_cache.values():
                total_bytes += key_tensor.nelement() * key_tensor.element_size()
        return total_bytes * 2 / (1024 * 1024)
class HuginnStaticCache(Cache):
    """Static Cache for the recurrent model.

    Keys and values live in two pre-allocated tensors of shape
    ``[steps, batch, heads, max_length, head_dim]``; ``valid_mask[step, pos]``
    records which (step, position) slots have been written.
    """

    is_compileable = False  # this is todo

    def __init__(
        self,
        max_length: int,
        max_num_steps: int,
        num_heads: int,
        hidden_dim: int,
        batch_size: int = 1,
        lookup_strategy: str = "full",
        device: Optional[Union[torch.device, str]] = None,
        dtype: torch.dtype = torch.float32,
    ) -> None:
        """Allocate the cache.

        Args:
            max_length: Maximum number of token positions.
            max_num_steps: Number of step slots (reduced for ``compress-*`` strategies).
            num_heads: Size of the head dimension of the stored states.
            hidden_dim: Size of the last (per-head) dimension of the stored states.
            batch_size: Batch dimension of the stored states.
            lookup_strategy: Default strategy used by ``update``.
            device: Device for the cache tensors (framework default when ``None``).
            dtype: Dtype of the key/value tensors.
        """
        super().__init__()
        self._seen_tokens = 0
        self.max_length = max_length
        self.lookup_strategy = lookup_strategy
        # Adjust max_num_steps based on compression strategy
        # NOTE(review): the constant 4 presumably reserves the non-recurrent
        # prelude/coda slots -- confirm against the model layout.
        if "compress-" in lookup_strategy:
            compression_stage = int(lookup_strategy.split("compress-")[1][1:])
            if "compress-s" in lookup_strategy:
                # For modulo compression (s), we need steps for 0,1 + compressed steps
                self.max_num_steps = 4 + compression_stage
            else:
                # For relative compression, we need steps for 0,1 + compressed steps
                self.max_num_steps = 4 + (max_num_steps - 4 + compression_stage - 1) // compression_stage
        else:
            self.max_num_steps = max_num_steps
        # Pre-allocate cache tensors [steps, batch, heads, seq_len, head_dim]
        device = torch.device(device) if device is not None else None
        cache_shape = (self.max_num_steps, batch_size, num_heads, max_length, hidden_dim)
        self.key_cache = torch.zeros(cache_shape, dtype=dtype, device=device)
        self.value_cache = torch.zeros(cache_shape, dtype=dtype, device=device)
        self.valid_mask = torch.zeros((self.max_num_steps, max_length), dtype=torch.bool, device=device)
        # Mark tensors as static for compile
        torch._dynamo.mark_static_address(self.key_cache)
        torch._dynamo.mark_static_address(self.value_cache)
        torch._dynamo.mark_static_address(self.valid_mask)

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        step_idx: torch.Tensor,
        lookup_strategy: Optional[str] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Write new states for one step and return the cached states to attend over.

        Args:
            key_states: New keys shaped ``[batch, heads, new_tokens, head_dim]``.
            value_states: New values, same layout as ``key_states``.
            step_idx: Step index as a 0-d tensor or plain int; negative values
                address slots from the end of the step dimension.
            lookup_strategy: Optional per-call override of ``self.lookup_strategy``.

        Returns:
            ``(keys, values)`` shaped ``[batch, heads, tokens, head_dim]``.

        Raises:
            ValueError: If the lookup strategy is not recognized.
        """
        step = int(step_idx)  # accept both python ints and 0-d tensors
        if step == 0:
            self._seen_tokens += key_states.shape[-2]
        # Adjust step for compression
        lookup_strategy = self.lookup_strategy if lookup_strategy is None else lookup_strategy
        if "compress-" in lookup_strategy and step > 1:
            compression_stage = int(lookup_strategy.split("compress-")[1][1:])
            if "compress-s" in lookup_strategy:
                step = (step - 2) % compression_stage + 2
            else:
                step = (step - 2) // compression_stage + 2

        num_new = key_states.shape[-2]
        start_idx = self._seen_tokens - num_new
        indices = torch.arange(start_idx, start_idx + num_new, device=key_states.device)
        self.key_cache[step].index_copy_(2, indices, key_states)
        self.value_cache[step].index_copy_(2, indices, value_states)
        self.valid_mask[step, start_idx : start_idx + num_new] = True

        seen = self._seen_tokens
        # Return based on lookup strategy. Compressed slots are overwritten in
        # place, so they are read back exactly like the "full" strategy.
        if lookup_strategy == "full" or "compress-" in lookup_strategy:
            return self.key_cache[step, :, :, :seen], self.value_cache[step, :, :, :seen]
        elif lookup_strategy.startswith("latest-m4"):
            if step >= 2:
                # steps sharing this position inside the 4-layer recurrent block
                first_step = 2 + (step - 2) % 4
                pattern_steps = torch.arange(first_step, step + 1, 4, device=self.valid_mask.device)
                pattern_valid = self.valid_mask[pattern_steps, :seen]  # [num_patterns, seen]
                # argmax on the flipped mask finds the *latest* valid step per token;
                # tokens with no valid step fall back to the current step
                from_end = pattern_valid.flip(0).to(torch.long).argmax(dim=0)
                max_valid_step = pattern_steps[pattern_steps.shape[0] - 1 - from_end]  # [seen]
                positions = torch.arange(seen, device=self.valid_mask.device)
                # paired advanced indices separated by slices put the token axis first
                keys = self.key_cache[max_valid_step, :, :, positions]
                values = self.value_cache[max_valid_step, :, :, positions]
                return keys.permute(1, 2, 0, 3), values.permute(1, 2, 0, 3)
            return self.key_cache[step, :, :, :seen], self.value_cache[step, :, :, :seen]
        elif lookup_strategy == "skip":
            valid_mask = self.valid_mask[step, :seen]
            # the mask selects along the token axis (dim 2 of the per-step slice)
            return (
                self.key_cache[step, :, :, :seen][:, :, valid_mask],
                self.value_cache[step, :, :, :seen][:, :, valid_mask],
            )
        elif lookup_strategy.startswith("randomized"):
            if step < 2:
                source_step = step
            else:
                first_step = 2 + (step - 2) % 4
                valid_steps = torch.arange(first_step, step + 1, 4, device=self.valid_mask.device)
                rand_idx = torch.randint(len(valid_steps), (1,), device=valid_steps.device)
                source_step = int(valid_steps[rand_idx])
            return self.key_cache[source_step, :, :, :seen], self.value_cache[source_step, :, :, :seen]
        else:
            raise ValueError(f"Unknown lookup strategy: {lookup_strategy}")

    def reset(self) -> None:
        """Zero all buffers and forget every seen token."""
        self._seen_tokens = 0
        self.key_cache.zero_()
        self.value_cache.zero_()
        self.valid_mask.zero_()

    def get_seq_length(self, step_idx: int = 0) -> int:
        """Return the number of tokens seen so far (``step_idx`` is ignored)."""
        return self._seen_tokens

    def get_memory_usage(self) -> float:
        """Return the size of the pre-allocated key and value buffers in MiB."""
        return (self.key_cache.nelement() + self.value_cache.nelement()) * self.key_cache.element_size() / (1024 * 1024)
# Either cache implementation; used to annotate the `past_key_values` arguments below.
ValidCache = HuginnDynamicCache | HuginnStaticCache
class CausalSelfAttention(torch.nn.Module):
def __init__(self, config: RavenConfig) -> None:
super().__init__()
self.config = config
self.n_head = config.num_attention_heads
self.n_kv_heads = config.num_key_value_heads
self.head_dim = config.n_embd // self.n_head
shape = (self.n_head + 2 * self.n_kv_heads) * self.head_dim
self.chunks = [config.n_embd, self.n_kv_heads * self.head_dim, self.n_kv_heads * self.head_dim]
self.Wqkv = torch.nn.Linear(config.n_embd, shape, bias=False)
if config.qk_bias:
self.qk_bias = torch.nn.Parameter(torch.zeros(2, 1, self.n_head, self.head_dim))
self.proj = torch.nn.Linear(config.n_embd, config.n_embd, bias=False)
def forward(
self,
x: Tensor,
freqs_cis: Tensor,
block_idx: torch.Tensor,
mask: Optional[BlockMask] = None,
past_key_values: Optional[ValidCache] = None,
) -> Tensor:
B, S, E = x.shape # batch size, sequence length, embedding dimensionality (n_embd)
q, k, v = self.Wqkv(x).split(self.chunks, dim=2)
q = q.view(B, S, self.n_head, self.head_dim)
k = k.view(B, S, self.n_kv_heads, self.head_dim)
v = v.view(B, S, self.n_kv_heads, self.head_dim)
# bias?
if self.config.qk_bias:
q_bias, k_bias = self.qk_bias.split(1, dim=0)
q, k = (q + q_bias).to(q.dtype), (k + k_bias).to(q.dtype)
q = q.transpose(1, 2) # (B, nh, S, hs)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
# apply rotary
cos, sin = freqs_cis
q, k = apply_rotary_pos_emb(q, k, cos, sin)
if past_key_values is not None:
k, v = past_key_values.update(k, v, block_idx)
if mask is not None:
y: torch.Tensor = flex_attention(q, k, v, block_mask=mask) # type: ignore
else:
if q.shape[2] < k.shape[2]:
if q.shape[2] > 1:
bias = attn_bias.causal_lower_right(q.shape[2], k.shape[2])
y = torch.nn.functional.scaled_dot_product_attention(q, k, v, bias, dropout_p=0.0, enable_gqa=True)
else:
y = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False, enable_gqa=True)
else:
y = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True, enable_gqa=True)
y = y.transpose(1, 2).reshape(B, S, E).contiguous() # reshape is a view if possible (it mostly is)
return self.proj(y)
class GatedMLP(torch.nn.Module):
    """SiLU-gated feed-forward block whose gate and value projections share one linear layer."""

    def __init__(self, config: RavenConfig, in_features: int = 0) -> None:
        super().__init__()
        if in_features == 0:
            # 0 means "use the model width"
            in_features = config.n_embd
        hidden = config.intermediate_size
        self.fc = torch.nn.Linear(in_features, hidden * 2, bias=False)
        self.proj = torch.nn.Linear(hidden, config.n_embd, bias=False)
        self.nonlin = torch.nn.SiLU()

    def forward(self, x: Tensor) -> Tensor:
        """Return ``proj(silu(gate) * value)``, gate/value being the two halves of ``fc(x)``."""
        # modified to single FC layer to improve parallelism
        gate, value = torch.chunk(self.fc(x), 2, dim=-1)
        gated = self.nonlin(gate) * value
        return self.proj(gated)
class SandwichBlock(torch.nn.Module):
    """Transformer block: normalized attention and normalized MLP, each added back residually."""

    expanded = False

    def __init__(self, config: RavenConfig, layer_id: int) -> None:
        super().__init__()
        width, eps = config.n_embd, config.norm_eps
        self.norm_1 = RMSNorm(width, eps=eps)
        self.attn = CausalSelfAttention(config)
        self.norm_2 = RMSNorm(width, eps=eps)
        self.mlp = GatedMLP(config)
        self.layer_id = layer_id

    def forward(
        self,
        x: Tensor,
        freqs_cis: Tensor,
        step_idx: int,
        mask: Optional[BlockMask] = None,
        past_key_values: Optional[ValidCache] = None,
    ) -> Tensor:
        """Apply attention then the MLP to ``x``; ``step_idx`` is forwarded to the attention layer."""
        normed_input = self.norm_1(x)
        x = self.attn(normed_input, freqs_cis, step_idx, mask, past_key_values) + x
        return self.mlp(self.norm_2(x)) + x
class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
def __init__(
    self,
    config: RavenConfig,
) -> None:
    """Build the prelude, adapter, recurrent core block, coda, embeddings, LM head and rotary module."""
    super().__init__(config)
    self.config = config

    n_prelude = config.n_layers_in_prelude
    n_core = config.n_layers_in_recurrent_block

    # Transformer layers
    prelude = torch.nn.ModuleList([SandwichBlock(config, layer_id=idx) for idx in range(n_prelude)])
    # maps the concatenation [state, input embedding] back to the model width
    adapter = torch.nn.Linear(config.n_embd * 2, config.n_embd, bias=config.bias)
    core_block = torch.nn.ModuleList([SandwichBlock(config, layer_id=idx + n_prelude) for idx in range(n_core)])
    # coda layer ids continue after the core block unrolled for `mean_recurrence` iterations
    coda_offset = n_prelude + n_core * config.mean_recurrence
    coda = torch.nn.ModuleList(
        [SandwichBlock(config, layer_id=idx + coda_offset) for idx in range(config.n_layers_in_coda)]
    )

    self.transformer = torch.nn.ModuleDict(
        {
            "wte": torch.nn.Embedding(config.padded_vocab_size, config.n_embd),
            "prelude": prelude,
            "adapter": adapter,
            "core_block": core_block,
            "coda": coda,
            "ln_f": RMSNorm(config.n_embd, eps=config.norm_eps),  # used twice :>
        }
    )
    self.emb_scale = config.init_values["embed_scale"]

    # Head
    self.lm_head = torch.nn.Linear(config.n_embd, config.padded_vocab_size, bias=False)
    if self.config.tie_embeddings:
        self.tie_weights()

    # rope
    self.rotary_emb = LlamaRotaryEmbedding(config=config)
def get_input_embeddings(self):
    """Return the token embedding module registered as ``transformer.wte``."""
    token_embedding = self.transformer.wte
    return token_embedding
def get_output_embeddings(self):
    """Return the language-model head stored as ``lm_head``."""
    output_head = self.lm_head
    return output_head
def compile_mask(
self,
input_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[ValidCache] = None,
pad_token_id=65509,
) -> Optional[BlockMask]:
batch_size, seq_len = input_ids.shape[0], input_ids.shape[1]
# If no padding and no attention mask, no need for a mask
if attention_mask is None and (input_ids == pad_token_id).sum() == 0:
return None
if past_key_values is not None and seq_len == 1:
return None
# Get total sequence length including cache
cache_len = past_key_values.get_seq_length() if past_key_values is not None else 0
kv_length = cache_len + seq_len
if attention_mask is None:
def mask_mod(b, h, q_idx, kv_idx):
return q_idx >= kv_idx & (input_ids[b, kv_idx] != pad_token_id)
else:
def mask_mod(b, h, q_idx, kv_idx):
return (q_idx >= kv_idx) & (input_ids[b, kv_idx] != pad_token_id) & attention_mask[b, q_idx, kv_idx]
kv_length = past_key_values.get_seq_length() if past_key_values is not None else seq_len
if kv_length == 0:
kv_length = seq_len # prefill
block_mask = create_block_mask(
mask_mod,
B=batch_size,
H=None,
Q_LEN=seq_len,
KV_LEN=kv_length,
device=input_ids.device,
)
# # Define mask_mod function
# def mask_mod(b, h, q_idx, kv_idx):
# # Always apply causal constraint
# is_causal = q_idx >= kv_idx
# # Handle cache vs current tokens
# is_cache = kv_idx < cache_len
# current_idx = kv_idx - cache_len
# # For cache: always valid; For current: check padding
# not_pad = input_ids[b, current_idx] != pad_token_id
# valid = is_cache | not_pad
# # Apply attention mask if provided
# if attention_mask is not None:
# q_idx_curr = q_idx - cache_len
# attn_valid = attention_mask[b, q_idx_curr, current_idx]
# valid = valid & (is_cache | attn_valid)
# return is_causal & valid
# def mask_mod(b, h, q_idx, kv_idx):
# is_causal = q_idx >= kv_idx
# is_current = (kv_idx >= cache_len) & (kv_idx < kv_length)
# current_idx = kv_idx - cache_len
# is_valid = (~is_current) | (
# (current_idx >= 0) & (current_idx < seq_len) & (input_ids != pad_token_id)[b, current_idx % seq_len]
# )
# return is_causal & is_valid
# # Define mask_mod function
# def mask_mod(b, h, q_idx, kv_idx):
# # Always apply causal constraint
# is_causal = q_idx >= kv_idx
# # Handle cache vs current tokens
# is_cache = kv_idx < cache_len
# current_idx = kv_idx - cache_len
# in_bounds = (current_idx >= 0) & (current_idx < seq_len)
# # For cache: always valid; For current: check padding
# not_pad = (input_ids[b, current_idx % seq_len] != pad_token_id) | ~in_bounds
# valid = is_cache | (not_pad & in_bounds)
# # Apply attention mask if provided
# if attention_mask is not None:
# q_idx_curr = q_idx - cache_len
# q_in_bounds = (q_idx_curr >= 0) & (q_idx_curr < seq_len)
# attn_valid = attention_mask[b, q_idx_curr % seq_len, current_idx % seq_len] | ~(in_bounds & q_in_bounds)
# valid = valid & (is_cache | attn_valid)
# return is_causal & valid
# Create block mask
block_mask = create_block_mask(
mask_mod,
B=batch_size,
H=None,
Q_LEN=seq_len,
KV_LEN=kv_length,
device=input_ids.device,
)
return block_mask
def forward(
    self,
    input_ids: torch.Tensor,
    input_embeds: Optional[torch.Tensor] = None,
    input_states: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,  # binary mask of shape q x kv, True=valid position
    position_ids: Optional[torch.Tensor] = None,
    labels: Optional[torch.Tensor] = None,
    num_steps: Optional[torch.Tensor] = None,
    past_key_values: Optional[ValidCache] = None,
    output_details: dict = {
        "return_logits": True,
        "return_latents": True,
        "return_head": False,
        "return_stats": False,
    },
    use_cache: bool = False,
    cache_position: Optional[torch.Tensor] = None,
    init_scale: float = 1.0,
    **kwargs,
) -> CausalLMOutputRecurrentLatents:
    """Run prelude, recurrent core, coda and LM head.

    Args:
        input_ids: Token ids; also used to size the default position ids.
        input_embeds: Optional precomputed embeddings (skips the embedding lookup).
        input_states: Optional initial recurrent state (otherwise initialized internally).
        attention_mask: Currently unused because mask compilation is disabled below.
        position_ids: Explicit positions; ignored when ``cache_position`` is given.
        labels: If given, cross-entropy against the logits is returned as ``loss``.
        num_steps: Recurrence count(s) forwarded to ``iterate_forward``.
        past_key_values: KV cache; a ``HuginnDynamicCache`` is created when
            ``use_cache`` is set and none is passed.
        output_details: Flags selecting which outputs are populated; missing
            keys fall back to the defaults shown in the signature. The shared
            default dict is only read, never mutated.
        use_cache: Whether to create a cache if none is supplied.
        cache_position: Positions of the new tokens; takes precedence over ``position_ids``.
        init_scale: Scale forwarded to the recurrent state initialization.

    Returns:
        A ``CausalLMOutputRecurrentLatents``.
    """
    details = {
        "return_logits": True,
        "return_latents": True,
        "return_head": False,
        "return_stats": False,
        **(output_details or {}),
    }

    # Support multiple position formats:
    if position_ids is None and cache_position is None:
        # continue counting after tokens already held in the cache so rotary
        # phases of new tokens line up with the cached keys
        past_len = past_key_values.get_seq_length() if past_key_values is not None else 0
        position_ids = torch.arange(past_len, past_len + input_ids.shape[1], device=self.device).unsqueeze(0)
    elif cache_position is not None:
        position_ids = cache_position.unsqueeze(0)

    if input_embeds is None:
        input_embeds = self.transformer.wte(input_ids)  # type: ignore # types broken in 2.6+

    if self.emb_scale != 1:
        input_embeds = input_embeds * self.emb_scale  # type: ignore

    if use_cache and past_key_values is None:
        past_key_values = HuginnDynamicCache()

    prepared_attn_mask = None  # self.compile_mask(input_ids, attention_mask, past_key_values)
    block_idx = torch.tensor(-1, device=torch.device("cpu"), dtype=torch.long)  # count in tensors for compile
    freqs_cis = self.rotary_emb(input_embeds, position_ids)

    # Non-recurrent prelude
    for block in self.transformer.prelude:  # type: ignore # types broken in 2.6+
        block_idx += 1
        input_embeds = block(input_embeds, freqs_cis, block_idx, prepared_attn_mask, past_key_values)

    # Main recurrence
    x, num_steps_no_grad, num_steps_with_grad, xk, block_idx = self.iterate_forward(
        input_embeds,  # type: ignore # mystery typing error
        input_states,
        freqs_cis,
        block_idx,
        prepared_attn_mask,
        past_key_values,
        num_steps,
        init_scale,
    )
    latent_states = x.clone().detach()

    # Coda layers
    block_idx = torch.tensor(0, device=torch.device("cpu"), dtype=torch.long)  # use negative indices for head
    for block in self.transformer.coda:  # type: ignore # types broken in 2.6+
        block_idx -= 1
        x = block(x, freqs_cis, block_idx, prepared_attn_mask, past_key_values)
    x = self.transformer.ln_f(x)  # type: ignore # types broken in 2.6+

    # Prediction head, assuming labels really are labels and not equal to input_ids
    if labels is not None:
        logits = self.lm_head(x).float()
        loss = torch.nn.functional.cross_entropy(
            logits.view(-1, logits.shape[-1]), labels.view(-1), ignore_index=-100
        )
        # NOTE: this is exp(loss), stored under the name `log_ppl`
        log_ppl = loss.clone().detach().exp()
    else:
        logits = self.lm_head(x)  # .float()
        loss, log_ppl = torch.as_tensor(0.0), torch.as_tensor(0.0)

    return CausalLMOutputRecurrentLatents(
        loss=loss,
        log_ppl=log_ppl,
        logits=logits if details["return_logits"] else None,
        past_key_values=past_key_values,
        hidden_states=x if details["return_head"] else None,
        latent_states=latent_states if details["return_latents"] else None,
        stats=self.get_stats(logits, x, latent_states, xk, input_embeds, num_steps_no_grad, num_steps_with_grad)
        if details["return_stats"]
        else None,
    )
@torch._dynamo.disable(recursive=False) # type: ignore
def iterate_forward(
self,
input_embeds: torch.Tensor,
input_states: torch.Tensor,
freqs_cis,
block_idx: torch.Tensor,
mask: Optional[BlockMask],
past_key_values: Optional[ValidCache] = None,
num_steps: Optional[torch.Tensor] = None,
init_scale: float = 1.0,
):
x = xk = self.initialize_state(input_embeds, scale=init_scale) if input_states is None else input_states.clone()
if num_steps is None:
num_steps_no_grad, num_steps_with_grad = self.randomized_iteration_sampler() # type: ignore
elif hasattr(num_steps, "__len__") and len(num_steps) > 1:
num_steps_no_grad, num_steps_with_grad = num_steps
else:
num_steps_no_grad, num_steps_with_grad = num_steps, torch.tensor(0) if not x.is_meta else 0
with torch.no_grad():
# ultra annoying in ddp due to
# https://discuss.pytorch.org/t/does-distributeddataparallel-work-with-torch-no-grad-and-find-unused-parameters-false/122594
# for now running with find_unused_params=True enabled even though the graph structure is (technically) clear
# and all parameters are always used
for no_grad_step in range(num_steps_no_grad):
xk = x
x, block_idx = self.core_block_forward(
xk, input_embeds, freqs_cis, mask, past_key_values, block_idx, no_grad_step
)
for grad_step in range(num_steps_with_grad):
xk = x
x, block_idx = self.core_block_forward(
xk, input_embeds, freqs_cis, mask, past_key_values, block_idx, num_steps_no_grad + grad_step
)
return x, num_steps_no_grad, num_steps_with_grad, xk.detach(), block_idx # type: ignore # types broken in 2.6+
def core_block_forward(
    self,
    x,
    input_embeds,
    freqs_cis,
    mask: Optional[BlockMask],
    past_key_values,
    block_idx: torch.Tensor,
    current_step: int | Tensor,
):
    """Run one pass of the recurrent core over the latent state.

    The state first goes through ``_maybe_inject_noise``, is concatenated with
    ``input_embeds`` along the last dimension and fed to ``transformer.adapter``,
    and is then passed through every layer of ``transformer.core_block``.

    ``block_idx`` is advanced in place by one before each layer is called.

    Returns:
        ``(state, block_idx)`` after the last core layer.
    """
    state = self._maybe_inject_noise(x, current_step)
    # The adapter consumes the latent state and the embeddings side by side.
    merged = torch.cat([state, input_embeds.to(state.device)], dim=-1)
    state = self.transformer.adapter(merged)  # type: ignore # types broken in 2.6+
    for layer in self.transformer.core_block:  # type: ignore # types broken in 2.6+
        block_idx += 1
        state = layer(state, freqs_cis, block_idx, mask, past_key_values)
    return state, block_idx
@torch.no_grad()
def iterate_one_step(
    self,
    input_embeds,
    input_states,
    position_ids: Optional[torch.Tensor] = None,
    cache_position: Optional[torch.Tensor] = None,
    block_idx: Optional[torch.Tensor] = None,
    attention_mask: Optional[BlockMask] = None,
    past_key_values: Optional[ValidCache] = None,
    current_step: int = 0,
):
    """Run a single recurrence step of the core block without gradients.

    Args:
        input_embeds: Embedded inputs passed to ``core_block_forward``.
        input_states: Latent state to advance by one step.
        position_ids: If given, selects the rotary entries via ``index_select``.
        cache_position: Used to index the rotary entries when ``position_ids`` is ``None``.
        block_idx: Running layer counter. Defaults to a fresh zero tensor per call.
        attention_mask: Forwarded to the core block as its mask.
        past_key_values: Optional cache forwarded to the core block.
        current_step: Index of this recurrence step.

    Returns:
        ``(x, block_idx, current_step + 1)``.
    """
    if block_idx is None:
        # core_block_forward advances the counter in place, so a tensor created once
        # in the signature would carry its value over into later calls.
        block_idx = torch.tensor(0, dtype=torch.long)

    if position_ids is not None:
        freqs_cis = self.freqs_cis.index_select(1, position_ids.squeeze())
    elif cache_position is not None:
        freqs_cis = self.freqs_cis[:, cache_position]
    else:
        freqs_cis = self.freqs_cis[:, : input_embeds.shape[1]]

    x, block_idx = self.core_block_forward(
        input_states,
        input_embeds,
        freqs_cis,
        attention_mask,
        past_key_values,
        block_idx,
        current_step=current_step,
    )
    return x, block_idx, current_step + 1
def predict_from_latents(
    self,
    latents,
    attention_mask: Optional[BlockMask] = None,
    position_ids: Optional[torch.Tensor] = None,
    cache_position: Optional[torch.Tensor] = None,
    past_key_values: Optional[ValidCache] = None,
):
    """Decode latent states into logits through the coda layers and the LM head.

    The latents are normalised with ``transformer.ln_f``, passed through every coda
    block, normalised again and projected with ``lm_head``. The rotary entries are
    chosen from ``position_ids`` first, then ``cache_position``, and otherwise from the
    leading ``latents.shape[1]`` positions.

    Returns:
        A ``CausalLMOutputRecurrentLatents`` with zero ``loss``/``log_ppl``, float logits,
        the given cache, and the post-coda normalised state as ``latent_states``.
    """
    if position_ids is not None:
        rope = self.freqs_cis.index_select(1, position_ids.squeeze())
    elif cache_position is not None:
        rope = self.freqs_cis[:, cache_position]
    else:
        rope = self.freqs_cis[:, : latents.shape[1]]

    final_norm = self.transformer.ln_f  # type: ignore # types broken in 2.6+
    hidden = final_norm(latents)

    # Coda layers are addressed with negative indices, counted down from -1.
    head_idx = torch.tensor(0, device=torch.device("cpu"), dtype=torch.long)
    for coda_block in self.transformer.coda:  # type: ignore # types broken in 2.6+
        head_idx -= 1
        hidden = coda_block(hidden, rope, head_idx, attention_mask, past_key_values)
    hidden = final_norm(hidden)

    logits = self.lm_head(hidden).float()
    return CausalLMOutputRecurrentLatents(
        loss=torch.as_tensor(0.0),
        log_ppl=torch.as_tensor(0.0),
        logits=logits,
        past_key_values=past_key_values,
        latent_states=hidden,
    )
def embed_inputs(
    self,
    input_ids: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    past_key_values: Optional[ValidCache] = None,
    use_cache: bool = False,
    cache_position: Optional[torch.Tensor] = None,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Embed token ids and run them through the non-recurrent prelude.

    Rotary entries are chosen from ``position_ids`` first, then ``cache_position``, and
    otherwise from the leading ``input_ids.shape[1]`` positions. Embeddings are scaled by
    ``emb_scale`` when it differs from 1. If ``use_cache`` is set and no cache is given,
    a ``HuginnDynamicCache`` is created for the prelude blocks.

    Returns:
        ``(input_embeds, block_idx)`` where ``block_idx`` is the index of the last
        prelude block (``-1`` if the prelude is empty).
    """
    if position_ids is not None:
        rope = self.freqs_cis.index_select(1, position_ids.squeeze())
    elif cache_position is not None:
        rope = self.freqs_cis[:, cache_position]
    else:
        rope = self.freqs_cis[:, : input_ids.shape[1]]

    hidden = self.transformer.wte(input_ids)  # type: ignore # types broken in 2.6+
    prelude_mask = self.compile_mask(input_ids, attention_mask)
    if self.emb_scale != 1:
        hidden = hidden * self.emb_scale  # type: ignore

    if past_key_values is None and use_cache:
        past_key_values = HuginnDynamicCache()

    # Layer counter kept as a tensor (count in tensors for compile); starts below zero
    # so the first prelude block sees index 0.
    layer_idx = torch.tensor(-1, device=torch.device("cpu"), dtype=torch.long)
    for prelude_block in self.transformer.prelude:  # type: ignore # types broken in 2.6+
        layer_idx += 1
        hidden = prelude_block(hidden, rope, layer_idx, prelude_mask, past_key_values)
    return hidden, layer_idx
@torch._dynamo.disable(recursive=False) # type: ignore
def randomized_iteration_sampler(self) -> tuple[torch.Tensor, torch.Tensor]:
"""Outputs are long tensors so that they can be passed through compiled functions"""
t = max(self.config.mean_recurrence - self.config.mean_backprop_depth, 0)
s = self.config.mean_backprop_depth
if torch.rand((1,)).is_meta: # annoying clause to make meta-tensor-based flop counting work
# these values are only the mean TFLOPs of the randomized sampler
# Note that this clause also breaks the contract, and returns ints in meta tensor mode
return t, s # type: ignore
if self.training:
sigma = 0.5
mu = math.log(t + s) - (sigma**2 / 2)
rate = torch.zeros((1,)).log_normal_(mean=mu, std=sigma)
p = torch.poisson(torch.tensor([rate], dtype=torch.float)) + 1
n = torch.clamp(p - s, min=0)
k = torch.as_tensor(torch.minimum(torch.as_tensor(s), p))
else:
n, k = torch.as_tensor(self.config.mean_recurrence), torch.as_tensor(0)
return n.to(dtype=torch.long), k.to(dtype=torch.long)
def initialize_state(self, input_embeds, scale: float = 1.0):
    """Create a starting latent state shaped like ``input_embeds``.

    With a positive standard deviation (``config.init_values["std"] * scale``) the state
    is filled from a normal distribution truncated at three standard deviations and then
    multiplied by ``emb_scale`` when that differs from 1. Otherwise the state is all zeros.
    """
    state = torch.randn_like(input_embeds)
    std = self.config.init_values["std"] * scale
    if not std > 0:
        state.zero_()
        return state

    torch.nn.init.trunc_normal_(state, mean=0.0, std=std, a=-3 * std, b=3 * std)
    if self.emb_scale != 1:
        state = state * self.emb_scale
    return state
def _maybe_inject_noise(self, x, current_step, renorm=False):
    """Blend Gaussian noise into ``x`` when ``config.test_time_noise`` is positive.

    The base noise level is ``test_time_noise * init_values["std"] * emb_scale``. How it
    is applied depends on ``config.test_time_noise_type``:

    * ``"geom"``: level divided by ``current_step + 1``.
    * ``"sqrt"``: level divided by ``sqrt(current_step + 1)``.
    * ``"line"``: the larger of the level and the remaining fraction of ``mean_recurrence``.
    * ``"chi"``: level scaled by a uniform random factor in ``[0, 2)``.
    * ``"fixed"``: the level itself.

    Args:
        x: Latent state to perturb.
        current_step: Index of the current recurrence step (int or tensor).
        renorm: If true, pass the noised state through ``norm_4`` of the last core block.

    Returns:
        The (possibly) noised state; ``x`` itself when noise is disabled.

    Raises:
        ValueError: If noise is enabled and the configured noise type is not listed above.
    """
    if self.config.test_time_noise > 0:
        n = self.config.test_time_noise * self.config.init_values["std"] * self.emb_scale
        noise_type = self.config.test_time_noise_type
        if noise_type == "geom":
            step1 = torch.as_tensor(current_step + 1, device=x.device)  # need to cast for compile
            x = x * (1 - n / step1) + torch.randn_like(x) * n / step1
        elif noise_type == "sqrt":
            step1sqrt = torch.as_tensor(current_step + 1, device=x.device).sqrt()  # need to cast for compile
            x = x * (1 - n / step1sqrt) + torch.randn_like(x) * n / step1sqrt
        elif noise_type == "line":
            noise = max(n, (self.config.mean_recurrence - current_step) / self.config.mean_recurrence)  # type: ignore
            x = x * (1 - noise) + torch.randn_like(x) * noise
        elif noise_type == "chi":
            noise = 2 * torch.rand(1, device=x.device, dtype=x.dtype) * n
            x = x * (1 - noise) + torch.randn_like(x) * noise
        elif noise_type == "fixed":
            x = x * (1 - n) + torch.randn_like(x) * n
        else:
            raise ValueError(
                f"Unknown test_time_noise_type {noise_type!r}; "
                "expected one of 'geom', 'sqrt', 'line', 'chi', 'fixed'."
            )

        if renorm:
            x = self.transformer.core_block[-1].norm_4(x)  # type: ignore moduledict types still broken in pytorch
    return x
def prepare_inputs_for_generation(
    self,
    input_ids: torch.Tensor,
    past_key_values: Optional[Cache] = None,
    attention_mask: Optional[torch.Tensor] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    cache_position: Optional[torch.Tensor] = None,
    cache_lookup_strategy: str = "full",
    **kwargs,
):
    """Assemble the keyword arguments for one forward call during generation.

    Args:
        input_ids: Full token sequence generated so far.
        past_key_values: Cache to use. A cache that is not one of the Huginn cache classes
            must be empty and is replaced by the matching Huginn cache.
        attention_mask: Accepted for interface compatibility; not placed in the result.
        inputs_embeds: Accepted for interface compatibility; not placed in the result.
        cache_position: Column indices of ``input_ids`` to feed when a cache is in use.
        cache_lookup_strategy: Lookup strategy handed to a newly created Huginn cache.
        **kwargs: Forwarded into the result unless a key is already set. ``use_cache``
            (default ``True``) decides whether the cache is passed on.

    Returns:
        Dict with ``cache_position``, ``input_ids``, optionally ``past_key_values`` and
        ``position_ids``, plus every forwarded keyword argument.
    """
    model_inputs = {"cache_position": cache_position}
    current_input_length = input_ids.shape[1]

    if past_key_values is not None:
        if not isinstance(past_key_values, (HuginnDynamicCache, HuginnStaticCache)):
            # only replace empty caches
            assert past_key_values.get_seq_length() == 0, "cannot replace a non-empty cache with a Huginn cache"
            # Need to use custom cache, detect and replace HF cache if generate injects it
            if isinstance(past_key_values, StaticCache):
                past_key_values = HuginnStaticCache(
                    max_length=getattr(self.generation_config, "max_length", self.config.block_size),
                    max_num_steps=4 + kwargs.get("num_steps", self.config.mean_recurrence) * 4,
                    num_heads=self.config.num_key_value_heads,
                    hidden_dim=self.config.n_embd // self.config.num_attention_heads,
                    dtype=torch.bfloat16,
                    device=input_ids.device,
                    lookup_strategy=cache_lookup_strategy,
                )
            else:
                past_key_values = HuginnDynamicCache(lookup_strategy=cache_lookup_strategy)
        # A missing use_cache flag used to raise KeyError; a supplied cache implies caching.
        model_inputs["past_key_values"] = past_key_values if kwargs.get("use_cache", True) else None
        if cache_position is not None:
            # Indexing with None would insert a new axis instead of selecting columns.
            input_ids = input_ids[:, cache_position]  # type: ignore

    model_inputs["input_ids"] = input_ids.clone(memory_format=torch.contiguous_format)

    if cache_position is None:
        # some form of position_ids is a critical argument for the model to correctly apply rope!
        model_inputs["position_ids"] = torch.arange(current_input_length, device=input_ids.device).unsqueeze(0)

    # forward all other entries
    for key, value in kwargs.items():
        if key not in model_inputs:
            model_inputs[key] = value
    return model_inputs
@torch.no_grad()
def generate(self, *args, **kwargs):
    """Dispatcher - use HF generate in all normal cases.

    A second positional argument, when present, is stored as ``self.generation_config``.
    Calls carrying ``criterion`` or ``exit_threshold`` go to
    ``generate_with_adaptive_compute``; calls carrying ``continuous_compute`` go to
    ``generate_minimal``; everything else is handled by the parent class.
    """
    chosen_config = args[1] if len(args) > 1 else self.generation_config
    self.generation_config = chosen_config

    wants_adaptive = "criterion" in kwargs or "exit_threshold" in kwargs
    if wants_adaptive:
        return self.generate_with_adaptive_compute(*args, **kwargs)
    if "continuous_compute" in kwargs:
        return self.generate_minimal(*args, **kwargs)
    return super().generate(*args, **kwargs)
@torch.no_grad()
def _prep_generate_args(
    self,
    input_ids: torch.Tensor,
    generation_config: Optional[GenerationConfig] = None,  # type: ignore
    cache_lookup_strategy: str = "full",
    model_kwargs: Optional[dict] = None,
):
    """Resolve the generation config, the token budget and the KV cache for a generate call.

    Args:
        input_ids: Prompt tokens; their length is subtracted from any ``max_length``.
        generation_config: Config to use; falls back to ``self.generation_config``.
        cache_lookup_strategy: Lookup strategy handed to the created cache.
        model_kwargs: Caller keyword arguments. The dict is updated in place with
            ``past_key_values`` and ``use_cache``. ``None`` means "no extra arguments".

    Returns:
        ``(model_kwargs, generation_config, max_new_tokens)``, where ``model_kwargs`` is the
        result of ``_get_initial_cache_position``.
    """
    # A dict literal in the signature would be shared (and mutated) across calls.
    if model_kwargs is None:
        model_kwargs = {}
    if generation_config is None:
        generation_config = self.generation_config  # type: ignore

    prompt_length = input_ids.shape[1]
    if "max_new_tokens" in model_kwargs:
        max_new_tokens = model_kwargs["max_new_tokens"]
        if "max_length" in model_kwargs:
            max_new_tokens = min(max_new_tokens, model_kwargs["max_length"] - prompt_length)
        # Needed below to size a static cache; previously unbound on this path.
        max_length = prompt_length + max_new_tokens
    else:
        max_length = model_kwargs.get("max_length", generation_config.max_length)
        max_new_tokens = max_length - prompt_length

    if "cache_implementation" not in model_kwargs or model_kwargs["cache_implementation"] == "dynamic":
        model_kwargs["past_key_values"] = HuginnDynamicCache(lookup_strategy=cache_lookup_strategy)
    else:
        model_kwargs["past_key_values"] = HuginnStaticCache(
            max_length=max_length,
            max_num_steps=4 + model_kwargs.get("num_steps", self.config.mean_recurrence) * 4,
            num_heads=self.config.num_key_value_heads,
            hidden_dim=self.config.n_embd // self.config.num_attention_heads,
            batch_size=input_ids.shape[0],
            dtype=torch.bfloat16,
            device=input_ids.device,
            lookup_strategy=cache_lookup_strategy,
        )
    model_kwargs["use_cache"] = True
    model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
    return model_kwargs, generation_config, max_new_tokens
@torch.no_grad()
def generate_minimal(
    self,
    input_ids: torch.Tensor,
    generation_config: Optional[GenerationConfig] = None,  # type: ignore
    tokenizer=None,
    streamer=None,
    continuous_compute=False,  # warm-start state / continuous CoT
    init_scale: float = 1.0,
    cache_lookup_strategy: str = "full",
    **model_kwargs,
) -> Union[torch.Tensor, dict[str, Any]]:
    """Minimal single-sequence generation. Template for more complicated generate tasks.

    Args:
        input_ids: Prompt tokens.
        generation_config: Sampling / stopping configuration; resolved by ``_prep_generate_args``.
        tokenizer: Optional tokenizer handed to ``_get_stops``.
        streamer: Optional object receiving each new token via ``put`` and a final ``end``.
        continuous_compute: If true, the last latent state of each step is fed back as
            ``input_states`` for the next step.
        init_scale: Scale for the initial latent state.
        cache_lookup_strategy: Lookup strategy for the KV cache.
        **model_kwargs: Extra arguments, forwarded to ``_prep_generate_args`` and from there
            into every forward call.

    Returns:
        The extended ``input_ids``, or a ``GenerateDecoderOnlyOutput`` when
        ``generation_config.return_dict_in_generate`` is set.
    """
    # Forward the caller's kwargs; without this they were dropped and the
    # "stopping_criteria" check below could never trigger.
    model_kwargs, generation_config, max_new_tokens = self._prep_generate_args(
        input_ids, generation_config, cache_lookup_strategy, model_kwargs
    )
    stop_tokens = self._get_stops(generation_config, tokenizer, model_kwargs).to(input_ids.device)
    unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)

    # Set up continuous compute if enabled
    if continuous_compute:
        embedded_inputs, _ = self.embed_inputs(input_ids)
        model_kwargs["input_states"] = self.initialize_state(embedded_inputs, scale=init_scale)

    # Generate tokens
    for _ in range(max_new_tokens):
        # Forward pass
        model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
        outputs = self(**model_inputs, init_scale=init_scale)

        # Get next token
        next_token_logits = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=input_ids.device)
        next_token = self._sample_next_token(next_token_logits, generation_config)

        # Append token to sequence
        input_ids = torch.cat([input_ids, next_token], dim=-1)
        if streamer:
            streamer.put(next_token.cpu())

        # Update model kwargs
        model_kwargs = self._update_model_kwargs_for_generation(outputs, model_kwargs)
        if continuous_compute:
            model_kwargs["input_states"] = outputs.latent_states[:, -1:, :]

        # Mark every sequence that just produced a stop token as finished.
        hit_stop = torch.isin(next_token[:, 0], stop_tokens)
        unfinished_sequences[hit_stop] = 0
        if "stopping_criteria" in model_kwargs:
            unfinished_sequences = unfinished_sequences & ~model_kwargs["stopping_criteria"](input_ids, None)
        if unfinished_sequences.max() == 0:
            break

    if streamer:
        streamer.end()

    if generation_config.return_dict_in_generate:
        return GenerateDecoderOnlyOutput(
            sequences=input_ids,  # type: ignore
            scores=None,
            logits=None,
            attentions=None,
            hidden_states=None,
            past_key_values=model_kwargs.get("past_key_values"),
        )
    return input_ids
@torch.no_grad()
def generate_with_adaptive_compute(
    self,
    input_ids: torch.Tensor,
    generation_config: Optional[GenerationConfig] = None,  # type: ignore
    tokenizer=None,
    streamer=None,
    continuous_compute=False,  # warm-start state / continuous CoT
    criterion="none",  # off by default, turn on by choosing an exit criterion
    exit_threshold: Union[str, float, int] = "auto",
    init_scale: float = 1.0,
    cache_lookup_strategy: str = "full",
    **model_kwargs,
) -> Union[torch.Tensor, GenerateDecoderOnlyOutput]:
    """
    Generate tokens with adaptive compute. This is NOT the most efficient implementation.
    For batches, on each token, we iterate until the entire batch finishes.

    Args:
        input_ids: Prompt tokens.
        generation_config: Sampling / stopping configuration; resolved by ``_prep_generate_args``.
        tokenizer: Optional tokenizer handed to ``_get_stops``.
        streamer: Optional object receiving each new token via ``put`` and a final ``end``.
        continuous_compute: If true, the last latent state is reused as the starting state
            for the next token.
        criterion: One of ``"entropy-diff"``, ``"latent-diff"``, a name containing ``"kl"``
            (``"minp-kl"`` uses a min-p filtered distribution), ``"argmax-stability"`` or ``"none"``.
        exit_threshold: Threshold for the criterion, or ``"auto"`` for the per-criterion default.
        init_scale: Scale for the initial latent state.
        cache_lookup_strategy: Lookup strategy for the KV cache.
        **model_kwargs: Extra arguments; ``num_steps`` caps the recurrence steps per token.

    Returns:
        The extended ``input_ids``, or a ``GenerateDecoderOnlyOutput`` whose ``scores`` holds,
        per generated token, ``[steps_per_sequence, exit_values_per_sequence]``.

    Raises:
        ValueError: If ``criterion`` is not one of the supported strategies.
    """
    model_kwargs, generation_config, max_new_tokens = self._prep_generate_args(
        input_ids, generation_config, cache_lookup_strategy, model_kwargs
    )
    max_steps = model_kwargs.get("num_steps", self.config.mean_recurrence)
    stop_tokens = self._get_stops(generation_config, tokenizer, model_kwargs).to(input_ids.device)
    logit_type = dict(copy=True, dtype=torch.float32, device=input_ids.device)
    batch_size = input_ids.shape[0]
    compute_steps = []

    # Set up continuous compute if enabled
    if continuous_compute:
        embedded_inputs, _ = self.embed_inputs(input_ids)
        model_kwargs["input_states"] = self.initialize_state(embedded_inputs, scale=init_scale)

    # Track which sequences have finished (using unfinished_sequences to match generate_minimal)
    unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)

    # Generate tokens. The index is named because the prefill check below depends on it.
    for token_idx in range(max_new_tokens):
        # Adaptive compute forward
        model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
        aux_inputs = {
            k: model_inputs[k] for k in ["cache_position", "past_key_values", "attention_mask"] if k in model_inputs
        }
        embedded_inputs, block_idx = self.embed_inputs(model_inputs["input_ids"], **aux_inputs)
        current_latents = (
            self.initialize_state(embedded_inputs, scale=init_scale)
            if not continuous_compute
            else model_kwargs["input_states"]
        )

        # Initialize criterion tracking for each sequence in batch
        exit_values_per_seq = [[] for _ in range(batch_size)]
        compute_steps_per_seq = [0] * batch_size
        exit_reached = torch.zeros(batch_size, dtype=torch.bool, device=input_ids.device)

        # Set up criterions based on selected strategy
        if criterion == "entropy-diff":
            entropy = torch.ones(batch_size, device=input_ids.device) * 100.0
            exit_threshold = 1e-3 if exit_threshold == "auto" else float(exit_threshold)
        elif criterion == "latent-diff":
            exit_threshold = 0.03 if exit_threshold == "auto" else float(exit_threshold)
        elif "kl" in criterion:
            V = self.config.padded_vocab_size
            log_probs = ((1 / V) * torch.ones(batch_size, V, dtype=torch.float, device=input_ids.device)).log()
            if criterion == "minp-kl":
                exit_threshold = 1e-6 if exit_threshold == "auto" else float(exit_threshold)
            else:
                exit_threshold = 5e-4 if exit_threshold == "auto" else float(exit_threshold)
        elif criterion == "argmax-stability":
            stable_for_n_steps = torch.zeros(batch_size, dtype=torch.long, device=input_ids.device)
            current_argmax = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) * -1
            exit_threshold = 5 if exit_threshold == "auto" else int(exit_threshold)
        elif criterion == "none":
            exit_threshold = 1.0 if exit_threshold == "auto" else float(exit_threshold)
        else:
            raise ValueError("Invalid adaptive compute strategy.")

        next_token_logits = None

        # Iterate through compute steps
        for compute_step in range(max_steps):
            prev_latents = current_latents.clone()
            # The third return value (next step index) is unused; it must not share a
            # name with the token index checked below.
            current_latents, block_idx, _next_step = self.iterate_one_step(
                embedded_inputs,
                current_latents,
                block_idx=block_idx,
                **aux_inputs,
                current_step=compute_step,
            )

            if token_idx > 0:  # do not exit in prefill
                # Check exit condition for each sequence in batch
                if criterion == "entropy-diff":
                    prev_entropy = entropy
                    outputs = self.predict_from_latents(current_latents, **aux_inputs)
                    logits: torch.Tensor = outputs.logits  # type: ignore
                    probs = F.softmax(logits[:, -1, :], dim=-1)
                    entropy = -torch.sum(probs * torch.log(probs + 1e-10), dim=-1)
                    exit_values = (entropy - prev_entropy).abs()
                elif criterion == "latent-diff":
                    norm_diff = (prev_latents - current_latents).norm(dim=-1) / current_latents.norm(dim=-1)
                    exit_values = norm_diff.mean(dim=-1)
                elif "kl" in criterion:
                    outputs = self.predict_from_latents(current_latents, **aux_inputs)
                    logits: torch.Tensor = outputs.logits  # type: ignore
                    prev_log_probs = log_probs
                    if criterion == "minp-kl":
                        probs = F.softmax(logits[:, -1, :].float(), dim=-1)
                        max_probs = probs.max(dim=-1, keepdim=True)[0]
                        probs_mask = probs < (0.1 * max_probs)
                        masked_probs = probs.clone()
                        masked_probs[probs_mask] = 1 / V
                        probs = masked_probs / masked_probs.sum(dim=-1, keepdim=True)
                        log_probs = probs.log()
                    else:
                        log_probs = F.log_softmax(logits[:, -1, :].float(), dim=-1)
                    exit_values = F.kl_div(log_probs, prev_log_probs, reduction="none", log_target=True).sum(dim=-1)
                elif criterion == "argmax-stability":
                    prev_argmax = current_argmax
                    outputs = self.predict_from_latents(current_latents, **aux_inputs)
                    logits: torch.Tensor = outputs.logits  # type: ignore
                    current_argmax = logits[:, -1, :].argmax(dim=-1)
                    stable_for_n_steps = torch.where(
                        current_argmax == prev_argmax, stable_for_n_steps + 1, torch.zeros_like(stable_for_n_steps)
                    )
                    exit_values = stable_for_n_steps
                elif criterion == "none":
                    exit_values = torch.ones(batch_size, device=input_ids.device) * 2.0 * exit_threshold

                # Record values and check exits for each sequence
                for i in range(batch_size):
                    if not exit_reached[i] and unfinished_sequences[i].bool():
                        exit_values_per_seq[i].append(exit_values[i].item())

                # Check for new exits, respecting unfinished_sequences
                new_exits = (
                    exit_values < exit_threshold
                    if criterion != "argmax-stability"
                    else exit_values >= exit_threshold
                )
                new_exits = new_exits & ~exit_reached & unfinished_sequences.bool()

                if new_exits.any():
                    exit_reached = exit_reached | new_exits
                    if criterion == "latent-diff":
                        # Normally we don't compute the output for latent-diff, but when there is an exit,
                        # we need to compute and save the output
                        outputs = self.predict_from_latents(current_latents, **aux_inputs)
                        logits: torch.Tensor = outputs.logits  # type: ignore
                    if next_token_logits is None:
                        next_token_logits = logits[:, -1, :].to(**logit_type)  # type: ignore
                    else:
                        for i in range(batch_size):
                            if new_exits[i]:
                                next_token_logits[i] = logits[i, -1, :].to(**logit_type)  # type: ignore
                    for i in range(batch_size):
                        if new_exits[i]:
                            compute_steps_per_seq[i] = compute_step + 1

                # If all sequences have exited or finished, break early
                if (exit_reached | ~unfinished_sequences.bool()).all():
                    break
        # This else is if the for loop finished without breaking
        else:
            outputs = self.predict_from_latents(current_latents, **aux_inputs)
            final_logits = outputs.logits[:, -1, :].to(**logit_type)  # type: ignore
            if next_token_logits is None:
                next_token_logits = final_logits
            # For sequences that didn't exit early, use the final logits and record that
            # they consumed the full step budget (also when no sequence exited at all).
            for i in range(batch_size):
                if not exit_reached[i] and unfinished_sequences[i].bool():
                    next_token_logits[i] = final_logits[i]
                    compute_steps_per_seq[i] = max_steps

        # Save latent states for continuous compute if enabled
        if continuous_compute:
            model_kwargs["input_states"] = current_latents[:, -1:, :]

        # Record compute steps for this token generation
        compute_steps.append([compute_steps_per_seq, exit_values_per_seq])

        # Sample or select next token based on generation config
        next_token = self._sample_next_token(next_token_logits, generation_config)

        # Append token to sequence
        input_ids = torch.cat([input_ids, next_token], dim=-1)
        if streamer:
            streamer.put(next_token.cpu())

        # Update model kwargs for next iteration
        model_kwargs = self._update_model_kwargs_for_generation(outputs, model_kwargs)

        # Check for stop tokens and update unfinished sequences
        for i in range(batch_size):
            if (
                unfinished_sequences[i].bool()
                and stop_tokens is not None
                and next_token[i, 0].item() in stop_tokens
            ):
                unfinished_sequences[i] = 0

        # Apply any custom stopping criteria
        if "stopping_criteria" in model_kwargs:
            unfinished_sequences = unfinished_sequences & ~model_kwargs["stopping_criteria"](input_ids, None)

        # Break if all sequences are finished
        if unfinished_sequences.max() == 0:
            break

    if streamer:
        streamer.end()

    if generation_config.return_dict_in_generate:
        return GenerateDecoderOnlyOutput(
            sequences=input_ids,  # type: ignore
            scores=compute_steps,  # type: ignore
            logits=None,
            attentions=None,
            hidden_states=None,
            past_key_values=model_kwargs.get("past_key_values"),
        )
    return input_ids
def _get_stops(self, generation_config, tokenizer, model_kwargs):
    """Collect the token ids that end generation.

    Starts from the built-in ids (begin_text, end_text, end_turn), adds the configured
    ``eos_token_id`` (a single id, a list of ids, or a tensor of ids), and, when a
    tokenizer is available, the first token of every configured stop string.

    Args:
        generation_config: Provides ``eos_token_id`` and optionally ``stop_strings``.
        tokenizer: Callable tokenizer; if ``None`` it is taken from the first entry of
            ``model_kwargs["stopping_criteria"]`` when that key is present.
        model_kwargs: Generation keyword arguments.

    Returns:
        1-D long tensor of stop token ids in ascending order.
    """
    stop_tokens = {65504, 65505, 65508}  # begin_text, end_text, end_turn

    eos_token_id = generation_config.eos_token_id
    if eos_token_id is not None:
        # eos_token_id may be an int, a list of ints, or a tensor; flatten all of them.
        stop_tokens.update(int(tok) for tok in torch.as_tensor(eos_token_id).reshape(-1).tolist())

    if "stopping_criteria" in model_kwargs and tokenizer is None:
        tokenizer = model_kwargs["stopping_criteria"][0].tokenizer

    stop_strings = getattr(generation_config, "stop_strings", None)
    if tokenizer and stop_strings:
        if isinstance(stop_strings, str):
            # A single string must not be iterated character by character.
            stop_strings = [stop_strings]
        for s in stop_strings:
            token_ids = tokenizer(s, add_special_tokens=False)["input_ids"]
            if len(token_ids) > 0:  # a string may tokenize to nothing
                stop_tokens.add(int(token_ids[0]))
    return torch.tensor(sorted(stop_tokens))
def _sample_next_token(self, next_token_logits, generation_config):
    """Helper function to sample the next token.

    Without ``do_sample`` the arg-max token is returned. Otherwise the logits are
    temperature-scaled (when a temperature is set), turned into probabilities, filtered
    by top-k, top-p and min-p in that order (each only when configured), renormalised,
    and one token per row is drawn.

    Returns:
        Tensor of shape ``(batch, 1)`` with the chosen token ids.
    """
    if not generation_config.do_sample:
        return torch.argmax(next_token_logits, dim=-1, keepdim=True)

    if generation_config.temperature:
        next_token_logits = next_token_logits.float() / generation_config.temperature
    probs = F.softmax(next_token_logits, dim=-1)

    # Apply top_k: zero everything below the k-th largest probability of each row
    if generation_config.top_k:
        kth_largest = torch.topk(probs, generation_config.top_k, dim=-1).values[:, -1:]
        probs = torch.where(probs < kth_largest, torch.zeros_like(probs), probs)

    # Apply top_p (nucleus sampling)
    if generation_config.top_p:
        ranked_probs, ranked_indices = torch.sort(probs, descending=True, dim=-1)
        drop_ranked = torch.cumsum(ranked_probs, dim=-1) > generation_config.top_p
        drop_ranked[:, 0] = False  # Keep at least the top probability
        # Map the drop flags from sorted order back to vocabulary order
        drop = torch.zeros_like(probs, dtype=torch.bool).scatter(-1, ranked_indices, drop_ranked)
        probs = torch.where(drop, torch.zeros_like(probs), probs)

    # Apply min_p
    if generation_config.min_p:
        row_max = probs.max(dim=-1, keepdim=True)[0]
        floor = generation_config.min_p * row_max
        probs = torch.where(probs < floor, torch.zeros_like(probs), probs)

    # Renormalize probabilities, then sample from the distribution
    probs = probs / probs.sum(dim=-1, keepdim=True).clamp(min=1e-10)
    return torch.multinomial(probs, num_samples=1)
@torch.no_grad()
def generate_speculative(
    self,
    input_ids: torch.Tensor,
    generation_config: Optional[GenerationConfig] = None,  # type: ignore
    tokenizer=None,
    streamer=None,
    continuous_compute=False,  # warm-start state / continuous CoT
    init_scale: float = 1.0,
    cache_lookup_strategy: str = "full",
    draft_steps=32,
    lookahead_for_draft=8,
    verification_threshold=1,
    num_steps: int = 32,  # intercept deliberately
    **model_kwargs,
) -> Union[torch.Tensor, dict[str, Any]]:
    """Batched speculative decoding with per-sequence acceptance.

    Each round drafts ``lookahead_for_draft`` tokens one at a time with ``draft_steps``
    recurrence steps, removes the draft entries from the cache, and then re-runs the
    drafted span in one forward call with ``num_steps`` recurrence steps to verify it.

    Args:
        input_ids: Prompt tokens.
        generation_config: Sampling / stopping configuration; resolved by ``_prep_generate_args``.
        tokenizer: Optional tokenizer handed to ``_get_stops``.
        streamer: Optional object receiving new tokens via ``put`` and a final ``end``.
            Only the tokens of the first sequence are streamed in the main loop.
        continuous_compute: If true, the last latent state is fed back as ``input_states``.
        init_scale: Scale for the initial latent state.
        cache_lookup_strategy: Lookup strategy for the KV cache.
        draft_steps: ``num_steps`` used for the drafting forward calls.
        lookahead_for_draft: Number of tokens drafted per round; must be positive.
        verification_threshold: With a value ``>= 1`` a drafted token is accepted only if it
            equals the verifier's arg-max. Otherwise it is accepted when its verifier
            probability is at least ``verification_threshold`` times the maximum probability.
        num_steps: ``num_steps`` used for the prefill and the verification forward calls.
        **model_kwargs: Extra arguments forwarded to ``_prep_generate_args``.

    Returns:
        The generated ``input_ids`` (trailing all-pad columns removed), or a
        ``GenerateDecoderOnlyOutput`` whose ``scores`` lists the per-round accepted counts.
    """
    assert lookahead_for_draft > 0
    pad_id = 65509  # NOTE(review): presumably the tokenizer's pad token id — confirm
    model_kwargs, generation_config, max_new_tokens = self._prep_generate_args(
        input_ids, generation_config, cache_lookup_strategy, model_kwargs
    )
    stop_tokens = self._get_stops(generation_config, tokenizer, model_kwargs).to(input_ids.device)
    unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)

    # Set up continuous compute if enabled
    if continuous_compute:
        embedded_inputs, _ = self.embed_inputs(input_ids)
        model_kwargs["input_states"] = self.initialize_state(embedded_inputs, scale=init_scale)

    tokens_generated = 0
    # Prefill cache with full num_steps
    if model_kwargs["past_key_values"].get_seq_length() == 0:
        model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
        outputs = self(**model_inputs, num_steps=num_steps, init_scale=init_scale)
        next_token = self._sample_next_token(
            outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32), generation_config
        )
        input_ids = torch.cat([input_ids, next_token], dim=-1)
        tokens_generated += 1
        if streamer:
            streamer.put(next_token.cpu())
        # From here on only the single next position is fed to the model.
        model_kwargs["cache_position"] = torch.as_tensor(
            [model_inputs["past_key_values"].get_seq_length()], device=input_ids.device
        )
        if continuous_compute:
            model_kwargs["input_states"] = outputs.latent_states[:, -1:, :]

    # Generate tokens
    # prefix_seq_len includes the prefill token; it marks where stop-token trimming starts.
    batch_size, prefix_seq_len = input_ids.shape[0], input_ids.shape[1]
    accepted_tokens = []

    while tokens_generated < max_new_tokens:
        ### Run the next draft ####
        drafted_inputs = input_ids.clone()
        current_len = input_ids.shape[1]

        for _ in range(lookahead_for_draft):
            model_inputs = self.prepare_inputs_for_generation(drafted_inputs, **model_kwargs)
            outputs = self(**model_inputs, num_steps=draft_steps, init_scale=init_scale)
            next_token_logits = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32)
            next_token = self._sample_next_token(next_token_logits, generation_config)
            drafted_inputs = torch.cat([drafted_inputs, next_token], dim=-1)
            # Advance to the next single position for the following draft step.
            model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1
            if continuous_compute:
                model_kwargs["input_states"] = outputs.latent_states[:, -1:, :]

        # Drop the cache entries written while drafting; verification rewrites them.
        model_kwargs["past_key_values"].clear_last_k_entries(lookahead_for_draft)

        ## Verify drafted tokens ###
        # Feed the last accepted token plus all but the last drafted token, so that the
        # j-th output position predicts the j-th drafted token.
        model_kwargs["cache_position"] = torch.arange(
            current_len - 1, current_len + lookahead_for_draft - 1, device=input_ids.device
        )
        model_inputs = self.prepare_inputs_for_generation(drafted_inputs, **model_kwargs)
        outputs = self(**model_inputs, num_steps=num_steps, init_scale=init_scale)
        verified_next_token_preds = outputs.logits.argmax(dim=-1)

        if verification_threshold >= 1:
            mismatched_tokens = (
                verified_next_token_preds[:, -lookahead_for_draft:] != drafted_inputs[:, current_len:]
            )
            # max over a bool row yields (any mismatch, index of the first mismatch)
            not_all_matched, first_mismatch = torch.max(mismatched_tokens, dim=1)
        else:
            verified_logits = outputs.logits[:, -lookahead_for_draft:, :]
            verified_probs = F.softmax(verified_logits, dim=-1)
            drafted_token_probs = torch.gather(
                verified_probs, -1, drafted_inputs[:, current_len:].unsqueeze(-1)
            ).squeeze(-1)
            max_probs = verified_probs.max(dim=-1)[0]
            verification_passed = drafted_token_probs >= verification_threshold * max_probs
            not_all_matched, first_mismatch = torch.max(~verification_passed, dim=1)

        # Per-sequence acceptance handling
        acceptance_lengths = torch.where(not_all_matched, first_mismatch, lookahead_for_draft)

        # Build next_tokens for each sequence
        next_tokens_batch = []
        for i in range(batch_size):
            seq_acceptance = acceptance_lengths[i].item()
            if not_all_matched[i] and seq_acceptance < lookahead_for_draft:
                # Accept up to mismatch + sample final token
                accepted_part = drafted_inputs[i : i + 1, current_len : current_len + seq_acceptance]
                final_token_logits = outputs.logits[i : i + 1, seq_acceptance, :].to(copy=True, dtype=torch.float32)
                final_token = self._sample_next_token(final_token_logits, generation_config)
                seq_tokens = torch.cat([accepted_part, final_token], dim=-1) if seq_acceptance > 0 else final_token
            else:
                # Accept all drafted tokens
                seq_tokens = drafted_inputs[i : i + 1, current_len : current_len + seq_acceptance]
            next_tokens_batch.append(seq_tokens)

        # Clean up KV cache - only if any sequence had mismatches
        if not_all_matched.any():
            # The cache is shared by the batch, so it is rolled back to the earliest mismatch.
            min_first_mismatch = first_mismatch.min().item()
            model_inputs["past_key_values"].clear_last_k_entries(lookahead_for_draft - min_first_mismatch - 1)

        # Concatenate accepted tokens to input_ids
        batch_accepted_counts = [tokens.shape[1] for tokens in next_tokens_batch]
        max_len = max(batch_accepted_counts)
        # Right-pad shorter rows with pad_id so the batch stays rectangular.
        padded_tokens = [
            torch.cat(
                [
                    tokens,
                    pad_id * torch.ones((1, max_len - tokens.shape[1]), dtype=tokens.dtype, device=tokens.device),
                ],
                dim=-1,
            )
            if tokens.shape[1] < max_len
            else tokens
            for tokens in next_tokens_batch
        ]
        next_tokens = torch.cat(padded_tokens, dim=0)
        input_ids = torch.cat([input_ids, next_tokens], dim=-1)

        accepted_tokens.append(batch_accepted_counts)
        tokens_generated += max(batch_accepted_counts)

        if streamer:
            streamer.put(next_tokens_batch[0].cpu())

        model_kwargs["cache_position"] = torch.as_tensor(
            [model_inputs["past_key_values"].get_seq_length()], device=input_ids.device
        )
        if continuous_compute:
            model_kwargs["input_states"] = outputs.latent_states[:, -1:, :]

        # Check stopping conditions
        if stop_tokens is not None:
            for i in range(batch_size):
                if unfinished_sequences[i] and torch.isin(next_tokens_batch[i], stop_tokens).any():
                    unfinished_sequences[i] = 0
        if "stopping_criteria" in model_kwargs:
            unfinished_sequences = unfinished_sequences & ~model_kwargs["stopping_criteria"](input_ids, None)
        if unfinished_sequences.max() == 0:
            break

    if streamer:
        streamer.end()

    # Cut off extraneous parts of the sequence per batch element
    if stop_tokens is not None:
        for i in range(batch_size):
            stop_positions = torch.isin(input_ids[i, prefix_seq_len:], stop_tokens).nonzero()
            if len(stop_positions) > 0:
                # Keep the first stop token, overwrite everything after it with padding.
                input_ids[i, prefix_seq_len + stop_positions[0].item() + 1 :] = pad_id
    # Trim tensor to remove columns that are pad_id across all sequences
    non_pad_mask = input_ids != pad_id
    last_real_token = non_pad_mask.any(dim=0).nonzero()
    if len(last_real_token) > 0:
        input_ids = input_ids[:, : last_real_token[-1].item() + 1]

    if generation_config.return_dict_in_generate:
        return GenerateDecoderOnlyOutput(
            sequences=input_ids,  # type: ignore
            scores=accepted_tokens,  # type: ignore
            logits=None,
            attentions=None,
            hidden_states=None,
            past_key_values=model_kwargs.get("past_key_values"),
        )
    return input_ids
def get_stats(self, logits, x, latent_states, xk, input_embeds, num_steps_no_grad, num_steps_with_grad):
    """Summarise a forward pass as a dict of diagnostics.

    Returns a dict with the entropy of the softmax over ``logits`` (last dimension), the
    norm of ``x - latent_states`` over the last dimension, that norm relative to the norm
    of ``latent_states``, and the two step counts passed in. ``xk`` and ``input_embeds``
    are accepted but not used.
    """
    token_probs = torch.softmax(logits.float(), dim=-1)
    # Zero-probability entries contribute nothing instead of producing NaN from 0 * log(0).
    entropy_terms = torch.where(token_probs > 0, -token_probs * token_probs.log(), 0)
    delta_norm = (x - latent_states).norm(dim=-1)
    return {
        "entropy": entropy_terms.sum(dim=-1),
        "residual_diff": delta_norm,
        "rel_residual": delta_norm / latent_states.norm(dim=-1),
        "num_steps_no_grad": num_steps_no_grad,
        "num_steps_with_grad": num_steps_with_grad,
    }
#################################### HF registration ############################################################
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM

# New: mark the config and model classes for the Auto* classes via register_for_auto_class
RavenConfig.register_for_auto_class()
RavenForCausalLM.register_for_auto_class("AutoModel")
RavenForCausalLM.register_for_auto_class("AutoModelForCausalLM")

# Old?: explicit registration of the "huginn_raven" model type with the Auto* factories
AutoConfig.register("huginn_raven", RavenConfig)
AutoModel.register(RavenConfig, RavenForCausalLM)
AutoModelForCausalLM.register(RavenConfig, RavenForCausalLM)