File size: 558 Bytes
b25d2b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import os

import numpy as np

from src.eval import compute_roc_auc_from_csv
from src.commands import predict_from_csv

data_path = "./tox21/tox21_test_clean.csv"
labels_path = "./tox21/tox21_test.csv"
valid_mask_path = "./tox21/valid_mask_test.npy"
checkpoint_dir = "checkpoints"
output_path = "predictions/test_set_preds_best.csv"


def features_path_for(csv_path):
    """Return the path of the ``.npz`` features file paired with *csv_path*.

    Only the final extension is swapped, so a ``.csv`` appearing elsewhere in
    the path (e.g. in a directory name) is left untouched -- unlike
    ``str.replace``, which would rewrite every occurrence.

    >>> features_path_for("./tox21/tox21_test_clean.csv")
    './tox21/tox21_test_clean.npz'
    """
    return os.path.splitext(csv_path)[0] + ".npz"


features_path = features_path_for(data_path)


def main():
    """Predict on the test CSV with the saved checkpoints and print ROC-AUC.

    Writes predictions to ``output_path``, then prints the array returned by
    ``compute_roc_auc_from_csv`` followed by its mean value.
    """
    # NOTE(review): unclear whether predict_from_csv creates the output
    # directory itself; creating it here is a no-op if it already exists.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    predict_from_csv(data_path, features_path, checkpoint_dir, output_path)

    # Predictions are made from the cleaned CSV but scored against the
    # original test CSV; the valid mask is passed alongside them
    # (presumably it reconciles the two -- confirm in src.eval).
    valid_mask = np.load(valid_mask_path)
    auc_array, mean_auc = compute_roc_auc_from_csv(output_path, labels_path, valid_mask)

    print(auc_array)
    print(mean_auc)


if __name__ == "__main__":
    main()