hiitsmeme
small change
1710048
import torch
import csv
import subprocess
import os
from src.preprocess import create_clean_smiles
from src.commands import generate_features, predict_from_csv
def predict(smiles_list):
    """
    Predict toxicity targets for a list of SMILES strings.

    The inputs are cleaned with ``create_clean_smiles``; the cleaned strings
    are written to a CSV file that is handed to ``generate_features`` and
    ``predict_from_csv``. The prediction CSV is then read back and keyed by
    the caller's original strings.

    Side effects: writes "tox21/predict_smiles.csv" and reads
    "predictions/smiles_predictions.csv" (both relative to the current
    working directory); the parent directories are created if missing.
    When no input is valid the pipeline is skipped and no file is touched.

    Args:
        smiles_list (list[str]): SMILES strings

    Returns:
        dict: {smiles: {target_name: prediction_prob}}
            Target names for predicted entries are taken from the header of
            the prediction file. Inputs flagged invalid get 0.5 for each of
            the twelve Tox21 targets. Distinct inputs that clean to the same
            SMILES each get their own entry.

    Raises:
        ValueError: if the number of cleaned SMILES differs from the number
            of inputs flagged valid, or if the prediction file has no header.
    """
    data_path = "tox21/predict_smiles.csv"
    features_path = data_path.replace(".csv", ".npz")
    checkpoint_dir = "checkpoints"
    output_path = "predictions/smiles_predictions.csv"

    # tox21 targets
    TARGET_NAMES = [
        "NR-AhR",
        "NR-AR",
        "NR-AR-LBD",
        "NR-Aromatase",
        "NR-ER",
        "NR-ER-LBD",
        "NR-PPAR-gamma",
        "SR-ARE",
        "SR-ATAD5",
        "SR-HSE",
        "SR-MMP",
        "SR-p53",
    ]

    # clean smiles
    clean_smiles, valid_mask = create_clean_smiles(smiles_list)

    # originals that survived cleaning, in the same order as clean_smiles
    originals_valid = [orig for orig, ok in zip(smiles_list, valid_mask) if ok]

    # sanity check: the two sequences are zipped below, so they must align
    if len(originals_valid) != len(clean_smiles):
        raise ValueError(
            f"Mismatch: {len(originals_valid)} valid originals vs {len(clean_smiles)} cleaned SMILES"
        )

    # map cleaned -> every original that produced it; a plain dict(zip(...))
    # would keep only the last original when several inputs clean to the
    # same string, dropping the others from the result
    cleaned_to_originals = {}
    for clean_smi, orig in zip(clean_smiles, originals_valid):
        cleaned_to_originals.setdefault(clean_smi, []).append(orig)

    print(f"Received {len(smiles_list)} SMILES strings")

    predictions = {}

    # Only run the pipeline when there is something to predict; a
    # header-only CSV gives the downstream steps nothing to work on.
    if len(clean_smiles) > 0:
        # make sure the input/output directories exist before touching files
        for path in (data_path, output_path):
            os.makedirs(os.path.dirname(path), exist_ok=True)

        # put smiles into csv (target columns are left blank)
        with open(data_path, "w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow(["smiles"] + TARGET_NAMES)  # header
            for smi in clean_smiles:
                writer.writerow([smi] + [""] * len(TARGET_NAMES))

        # generate features
        generate_features(data_path, features_path)

        # predict
        predict_from_csv(data_path, features_path, checkpoint_dir, output_path)

        # create results dictionary from predictions
        with open(output_path, "r", newline="") as f:
            reader = csv.DictReader(f)
            rows = list(reader)
            fieldnames = reader.fieldnames

        # fieldnames is None when the file is completely empty
        if not fieldnames:
            raise ValueError(f"Prediction file {output_path} has no header row")

        # Identify the SMILES column even if it is unnamed
        smiles_col = fieldnames[0]  # first column, even if empty string
        target_names = fieldnames[1:]  # all columns except first

        for row in rows:
            clean_smi = row[smiles_col]
            pred_dict = {t: float(row[t]) for t in target_names}
            # fall back to the cleaned string itself if it is not in the map
            for original_smi in cleaned_to_originals.get(clean_smi, [clean_smi]):
                # copy so entries for different originals are independent
                predictions[original_smi] = dict(pred_dict)

    # Add placeholder predictions for invalid SMILES
    for smi, is_valid in zip(smiles_list, valid_mask):
        if not is_valid:
            predictions[smi] = {t: 0.5 for t in TARGET_NAMES}

    return predictions