AubreeL committed on
Commit
e763fc3
·
verified ·
1 Parent(s): f58422d

Upload model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. model.py +194 -0
model.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import chess
6
+
7
# Width of each per-square band in the flattened policy vector
# (73 move planes per origin square, AlphaZero-style layout).
MOVES_PER_SQUARE = 73
# Total policy-head output size: one slot per (origin square, move plane).
POLICY_SIZE = 64 * MOVES_PER_SQUARE
9
+
10
+
11
class ResidualBlock(nn.Module):
    """3x3 conv -> ReLU -> 3x3 conv with an identity skip connection."""

    def __init__(self, channels: int) -> None:
        super().__init__()
        # Both convs keep spatial size (kernel 3, padding 1) and channel count,
        # so the skip connection can be a plain addition.
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        skip = x
        y = self.conv2(F.relu(self.conv1(x)))
        # Add the identity path before the final activation.
        return F.relu(y + skip)
23
+
24
+
25
class TinyPCN(nn.Module):
    """Tiny policy/value network: a small conv trunk feeding two heads.

    The policy head emits one logit per (square, move-plane) slot and the
    value head emits a scalar squashed into [-1, 1] by tanh.
    """

    def __init__(self, board_channels: int = 18, policy_size: int = POLICY_SIZE) -> None:
        super().__init__()

        # Shared trunk: two 3x3 convs plus one residual block, all 32 channels.
        self.conv1 = nn.Conv2d(board_channels, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 32, kernel_size=3, padding=1)
        self.res_block = ResidualBlock(32)

        # Policy head: 1x1 conv, then a dense projection to the move space.
        self.policy_conv = nn.Conv2d(32, 32, kernel_size=1)
        self.policy_fc = nn.Linear(32 * 8 * 8, policy_size)

        # Value head: collapse to a single plane, then a small MLP to a scalar.
        self.value_conv = nn.Conv2d(32, 1, kernel_size=1)
        self.value_fc1 = nn.Linear(8 * 8, 64)
        self.value_fc2 = nn.Linear(64, 1)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        trunk = self.res_block(F.relu(self.conv2(F.relu(self.conv1(x)))))

        pol = F.relu(self.policy_conv(trunk)).flatten(start_dim=1)
        policy_logits = self.policy_fc(pol)

        val = F.relu(self.value_conv(trunk)).flatten(start_dim=1)
        val = F.relu(self.value_fc1(val))
        value = torch.tanh(self.value_fc2(val))

        return policy_logits, value
56
+
57
+
58
def board_to_18_planes(board: chess.Board) -> torch.FloatTensor:
    """Encode *board* as 18 binary 8x8 planes (AlphaZero-style).

    Planes 0-5: white P,N,B,R,Q,K; 6-11: the black pieces; 12-15: castling
    rights (white K-side, white Q-side, black K-side, black Q-side);
    16: side to move (all ones when white); 17: en-passant target square.
    """
    planes = np.zeros((18, 8, 8), dtype=np.float32)

    def to_rc(sq: int) -> tuple[int, int]:
        # Rank 8 maps to row 0: the grid is drawn from white's perspective.
        return 7 - sq // 8, sq % 8

    for square, piece in board.piece_map().items():
        r, c = to_rc(square)
        base = 0 if piece.color == chess.WHITE else 6
        planes[base + piece.piece_type - 1, r, c] = 1.0

    # Castling-rights planes are constant over the whole 8x8 grid.
    rights = (
        board.has_kingside_castling_rights(chess.WHITE),
        board.has_queenside_castling_rights(chess.WHITE),
        board.has_kingside_castling_rights(chess.BLACK),
        board.has_queenside_castling_rights(chess.BLACK),
    )
    for offset, flag in enumerate(rights):
        if flag:
            planes[12 + offset, :, :] = 1.0

    if board.turn == chess.WHITE:
        planes[16, :, :] = 1.0

    # NOTE(review): planes stay white-oriented for both sides; only the move
    # encoding is mirrored for black — confirm the training pipeline expects this.
    if board.ep_square is not None:
        r, c = to_rc(board.ep_square)
        planes[17, r, c] = 1.0

    return torch.from_numpy(planes)
82
+
83
+
84
def board_to_20_planes(board: chess.Board) -> torch.FloatTensor:
    """Encode *board* as the 18 standard planes plus two extras.

    Plane 18: all ones when the position is a repetition; plane 19: the
    fullmove number scaled into [0, 1] (clamped at move 100).
    """
    base = board_to_18_planes(board).numpy()
    extra = np.zeros((2, 8, 8), dtype=np.float32)

    # Best-effort repetition flag: treat any failure to answer (e.g. missing
    # move history) as "no repetition" rather than crashing the encoder.
    try:
        repeated = board.is_repetition()
    except Exception:
        repeated = False
    if repeated:
        extra[0, :, :] = 1.0

    extra[1, :, :] = float(min(board.fullmove_number / 100.0, 1.0))

    return torch.from_numpy(np.concatenate([base, extra], axis=0))
100
+
101
+
102
def encode_board(board: chess.Board, variant: str = "18") -> torch.FloatTensor:
    """Dispatch to the 18- or 20-plane encoder selected by *variant*."""
    encoders = {"18": board_to_18_planes, "20": board_to_20_planes}
    if variant not in encoders:
        raise ValueError("variant must be '18' or '20'")
    return encoders[variant](board)
108
+
109
+
110
# Relative (row, col) step directions, expressed in white's frame of reference.
_RAY_OFFSETS = ((1, 0), (-1, 0), (0, 1), (0, -1), (1, 1), (1, -1), (-1, 1), (-1, -1))
_KNIGHT_OFFSETS = ((2, 1), (1, 2), (-1, 2), (-2, 1), (-2, -1), (-1, -2), (1, -2), (2, -1))
# Pawn-promotion displacements. NOTE(review): (2, 0) cannot be a legal
# promotion move; its index slots are kept so the existing layout stays valid.
_PROMOTION_OFFSETS = ((1, 0), (1, 1), (1, -1), (2, 0))
_PROMOTION_PIECES = ("q", "r", "b", "n")

# (from_square, to_square, promotion-symbol-or-None) <-> flat policy index.
# Padding indices map to None in _index_to_move so decoding can reject them
# (the value annotation includes None for exactly that reason).
_move_to_index: dict[tuple[int, int, str | None], int] = {}
_index_to_move: dict[int, tuple[int, int, str | None] | None] = {}


def _init_move_tables() -> None:
    """Populate the move <-> index lookup tables.

    Each origin square owns a contiguous band of MOVES_PER_SQUARE indices:
    in-bounds ray moves first, then knight moves, then promotion moves
    (four target pieces per offset; out-of-bounds offsets skip their four
    slots without recording entries), then None padding up to the band
    boundary so every square's band is exactly MOVES_PER_SQUARE wide.
    """
    idx = 0
    for sq in range(64):
        row0, col0 = divmod(sq, 8)

        # Sliding moves: walk each ray until it leaves the board.
        for dx, dy in _RAY_OFFSETS:
            for step in range(1, 8):
                row = row0 + dx * step
                col = col0 + dy * step
                if 0 <= row < 8 and 0 <= col < 8:
                    target = row * 8 + col
                    _move_to_index[(sq, target, None)] = idx
                    _index_to_move[idx] = (sq, target, None)
                    idx += 1

        # Knight jumps: only the in-bounds ones consume an index.
        for dx, dy in _KNIGHT_OFFSETS:
            row = row0 + dx
            col = col0 + dy
            if 0 <= row < 8 and 0 <= col < 8:
                target = row * 8 + col
                _move_to_index[(sq, target, None)] = idx
                _index_to_move[idx] = (sq, target, None)
                idx += 1

        # Promotions: out-of-bounds offsets still reserve their four slots so
        # every square keeps a stable internal layout.
        for dx, dy in _PROMOTION_OFFSETS:
            row = row0 + dx
            col = col0 + dy
            if 0 <= row < 8 and 0 <= col < 8:
                target = row * 8 + col
                for promo in _PROMOTION_PIECES:
                    _move_to_index[(sq, target, promo)] = idx
                    _index_to_move[idx] = (sq, target, promo)
                    idx += 1
            else:
                idx += len(_PROMOTION_PIECES)

        # Pad to the next MOVES_PER_SQUARE boundary; padding decodes to None.
        while idx % MOVES_PER_SQUARE != 0:
            _index_to_move[idx] = None
            idx += 1


_init_move_tables()
161
+
162
+
163
+ def _promotion_symbol(piece_type: int | None) -> str | None:
164
+ if piece_type is None:
165
+ return None
166
+ return chess.Piece(piece_type, chess.WHITE).symbol().lower()
167
+
168
+
169
def encode_move(move: chess.Move, board: chess.Board) -> int:
    """Return the flat policy index for *move*, or -1 when it has no slot.

    Moves are encoded from the mover's point of view: when the moving piece
    is black, both squares are mirrored so one table serves both colors.
    """
    promo = _promotion_symbol(move.promotion)
    src, dst = move.from_square, move.to_square

    if board.color_at(move.from_square) == chess.BLACK:
        src, dst = chess.square_mirror(src), chess.square_mirror(dst)

    return _move_to_index.get((src, dst, promo), -1)
180
+
181
+
182
def decode_move(index: int, board: chess.Board | None = None) -> chess.Move | None:
    """Invert encode_move: map a flat policy index back to a chess.Move.

    Returns None for padding or unknown indices. When *board* is given and
    black is to move, squares are mirrored back into board coordinates.
    """
    entry = _index_to_move.get(index)
    if entry is None:
        return None

    src, dst, promo = entry

    if board is not None and board.turn == chess.BLACK:
        src = chess.square_mirror(src)
        dst = chess.square_mirror(dst)

    if promo:
        promotion = chess.Piece.from_symbol(promo.upper()).piece_type
    else:
        promotion = None
    return chess.Move(src, dst, promotion=promotion)