# hiitsmeme
# initial commit
# b25d2b6
import numpy as np
from src.eval import compute_roc_auc_from_csv
from src.commands import predict_from_csv
# Test-set evaluation script: run prediction on the test CSV, then score the
# resulting predictions file and print both values returned by the scorer.

# Input/output locations. The features path is derived from the data path by
# replacing ".csv" with ".npz".
data_path = "./tox21/tox21_test_clean.csv"
features_path = data_path.replace(".csv", ".npz")
checkpoint_dir = "checkpoints"
output_path = "predictions/test_set_preds_best.csv"

# NOTE(review): presumably writes its predictions to output_path, since that
# same path is handed to the scoring call below -- confirm in src.commands.
predict_from_csv(
    data_path,
    features_path,
    checkpoint_dir,
    output_path,
)

# The mask is loaded only after prediction has finished, so a missing mask
# file surfaces after the prediction step rather than before it.
valid_mask = np.load("./tox21/valid_mask_test.npy")

# NOTE(review): scoring reads "./tox21/tox21_test.csv", not the "_clean" file
# used for prediction above -- confirm this is intentional.
scores = compute_roc_auc_from_csv(
    output_path,
    "./tox21/tox21_test.csv",
    valid_mask,
)
auc_array, mean_auc = scores

# Print each returned value on its own line, array first, then the mean.
for metric in (auc_array, mean_auc):
    print(metric)