HeavensHackDev commited on
Commit
6d5a01a
·
verified ·
1 Parent(s): 8992d53

Create train.py

Browse files
Files changed (1) hide show
  1. train.py +21 -0
train.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ def train_model(model, data, epochs=5):
2
+ optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
3
+ criterion = nn.CrossEntropyLoss()
4
+ model.train()
5
+ for epoch in range(epochs):
6
+ total_loss = 0
7
+ for text in data:
8
+ tokens = tokenizer(text)
9
+ indices = [vocab[token] for token in tokens][:50] # Ограничение длины
10
+ if len(indices) < 2:
11
+ continue
12
+ src = torch.tensor(indices[:-1], dtype=torch.long).unsqueeze(0)
13
+ tgt = torch.tensor(indices[1:], dtype=torch.long).unsqueeze(0)
14
+ optimizer.zero_grad()
15
+ output = model(src)
16
+ loss = criterion(output.view(-1, VOCAB_SIZE), tgt.view(-1))
17
+ loss.backward()
18
+ optimizer.step()
19
+ total_loss += loss.item()
20
+ print(f"Epoch {epoch+1}, Loss: {total_loss / len(data)}")
21
+ torch.save(model.state_dict(), "model.pt")