"""Predict Tox21 toxicity targets for raw SMILES strings using trained checkpoints."""

import csv
import os

from src.preprocess import create_clean_smiles
from src.commands import generate_features, predict_from_csv


def predict(smiles_list):
    """
    Predict toxicity targets for a list of SMILES strings.

    Args:
        smiles_list (list[str]): SMILES strings

    Returns:
        dict: {smiles: {target_name: prediction_prob}}
    """
    data_path = "tox21/predict_smiles.csv"
    features_path = data_path.replace(".csv", ".npz")
    checkpoint_dir = "checkpoints"
    output_path = "predictions/smiles_predictions.csv"

    # Make sure the input/output directories exist before any file is written
    os.makedirs(os.path.dirname(data_path), exist_ok=True)
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    # Clean/canonicalize the input SMILES; valid_mask flags which inputs survived
    clean_smiles, valid_mask = create_clean_smiles(smiles_list)

    # Original strings that correspond to the cleaned SMILES, in the same order
    originals_valid = [orig for orig, ok in zip(smiles_list, valid_mask) if ok]

    # Sanity check: cleaning must preserve the count of valid entries
    if len(originals_valid) != len(clean_smiles):
        raise ValueError(
            f"Mismatch: {len(originals_valid)} valid originals vs {len(clean_smiles)} cleaned SMILES"
        )

    # Map each cleaned SMILES back to the caller's original string
    cleaned_to_original = dict(zip(clean_smiles, originals_valid))

    # The twelve Tox21 assay targets, in the column order used throughout
    TARGET_NAMES = [
        "NR-AhR", "NR-AR", "NR-AR-LBD", "NR-Aromatase",
        "NR-ER", "NR-ER-LBD", "NR-PPAR-gamma",
        "SR-ARE", "SR-ATAD5", "SR-HSE", "SR-MMP", "SR-p53",
    ]
    print(f"Received {len(smiles_list)} SMILES strings, {len(clean_smiles)} valid after cleaning")

    # Write the cleaned SMILES to the input CSV, with empty target columns
    with open(data_path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["smiles"] + TARGET_NAMES)  # header
        for smi in clean_smiles:
            writer.writerow([smi] + [""] * len(TARGET_NAMES))

    # Generate molecular features for the cleaned SMILES
    generate_features(data_path, features_path)

    # Run inference with the saved checkpoints; results are written to output_path
    predict_from_csv(data_path, features_path, checkpoint_dir, output_path)

    # Build the results dictionary from the prediction CSV
    predictions = {}
    with open(output_path, "r", newline="") as f:
        reader = csv.DictReader(f)

        # The first column holds the SMILES (it may be unnamed); the rest are targets
        fieldnames = reader.fieldnames
        smiles_col = fieldnames[0]
        target_names = fieldnames[1:]

        for row in reader:
            clean_smi = row[smiles_col]
            original_smi = cleaned_to_original.get(clean_smi, clean_smi)
            predictions[original_smi] = {t: float(row[t]) for t in target_names}

    # Invalid SMILES get a neutral 0.5 placeholder for every target, so the
    # returned dict covers every input string
    for smi, is_valid in zip(smiles_list, valid_mask):
        if not is_valid:
            predictions[smi] = {t: 0.5 for t in TARGET_NAMES}

    return predictions
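

# ---------------------------------------------------------------------------
# Minimal usage sketch, assuming the src package, trained checkpoints under
# ./checkpoints, and writable ./tox21 and ./predictions directories are all
# present; the example SMILES below are arbitrary illustrations.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    example_smiles = [
        "CCO",           # ethanol
        "c1ccccc1",      # benzene
        "not-a-smiles",  # deliberately invalid: falls back to 0.5 placeholders
    ]
    results = predict(example_smiles)
    for smi, target_probs in results.items():
        print(smi)
        for target, prob in sorted(target_probs.items()):
            print(f"  {target}: {prob:.3f}")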