import torch
import csv
import subprocess
import os

from src.preprocess import create_clean_smiles
from src.commands import generate_features, predict_from_csv

# Tox21 targets, in the column order written to the model input CSV.
TARGET_NAMES = [
    "NR-AhR", "NR-AR", "NR-AR-LBD", "NR-Aromatase", "NR-ER", "NR-ER-LBD",
    "NR-PPAR-gamma", "SR-ARE", "SR-ATAD5", "SR-HSE", "SR-MMP", "SR-p53",
]


def _ensure_parent_dir(path):
    """Create the parent directory of *path* if it has one and it is missing."""
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)


def _write_input_csv(path, smiles):
    """Write *smiles* to *path* as a CSV with one empty column per target.

    The header is ``smiles`` followed by ``TARGET_NAMES``. The target cells
    are left blank because there are no labels at prediction time.
    """
    _ensure_parent_dir(path)
    empty_targets = [""] * len(TARGET_NAMES)
    with open(path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["smiles"] + TARGET_NAMES)
        writer.writerows([smi] + empty_targets for smi in smiles)


def _read_predictions(path):
    """Read the prediction CSV at *path*.

    The first column is taken as the SMILES column, even when its header is
    empty. Every remaining column is treated as a target name, and each
    cell is converted with ``float``.

    Returns:
        list[tuple[str, dict[str, float]]]: ``(smiles, {target: prob})`` for
        each data row, in file order. Empty if the file has no header.
    """
    with open(path, "r", newline="", encoding="utf-8") as f:
        reader = csv.DictReader(f)
        rows = list(reader)
        fieldnames = reader.fieldnames

    if not fieldnames:
        # An empty file has no header, so there is nothing to parse.
        return []

    smiles_col = fieldnames[0]
    target_names = fieldnames[1:]
    return [
        (row[smiles_col], {t: float(row[t]) for t in target_names})
        for row in rows
    ]


def predict(smiles_list):
    """
    Predict toxicity targets for a list of SMILES strings.

    The valid SMILES are cleaned and written to a CSV. Features are then
    generated, and the model checkpoints are run over that CSV. Inputs that
    fail cleaning get a placeholder probability of 0.5 for every target.

    Intermediate and output files are written to fixed relative paths, so
    concurrent calls from the same working directory would overwrite each
    other's files.

    Args:
        smiles_list (list[str]): SMILES strings

    Returns:
        dict: {smiles: {target_name: prediction_prob}}, keyed by the
        ORIGINAL (uncleaned) SMILES strings.

    Raises:
        ValueError: if the number of cleaned SMILES does not match the
            number of inputs flagged valid by ``create_clean_smiles``.
    """
    data_path = "tox21/predict_smiles.csv"
    features_path = data_path.replace(".csv", ".npz")
    checkpoint_dir = "checkpoints"
    output_path = "predictions/smiles_predictions.csv"

    # clean smiles
    clean_smiles, valid_mask = create_clean_smiles(smiles_list)

    # Originals that survived cleaning, in the same order as clean_smiles.
    originals_valid = [orig for orig, ok in zip(smiles_list, valid_mask) if ok]
    if len(originals_valid) != len(clean_smiles):
        raise ValueError(
            f"Mismatch: {len(originals_valid)} valid originals vs {len(clean_smiles)} cleaned SMILES"
        )

    # Several originals (e.g. different notations of one molecule) can clean
    # to the same SMILES, so each cleaned SMILES maps to a LIST of originals.
    # A plain dict would keep only the last one and drop the others' results.
    # Dict insertion order also gives a de-duplicated, order-preserving list.
    cleaned_to_originals = {}
    for clean_smi, orig in zip(clean_smiles, originals_valid):
        cleaned_to_originals.setdefault(clean_smi, []).append(orig)

    print(f"Received {len(smiles_list)} SMILES strings")

    predictions = {}

    # With no valid molecules there is nothing to featurize or predict.
    if cleaned_to_originals:
        _write_input_csv(data_path, list(cleaned_to_originals))
        generate_features(data_path, features_path)
        _ensure_parent_dir(output_path)
        predict_from_csv(data_path, features_path, checkpoint_dir, output_path)

        for clean_smi, pred_dict in _read_predictions(output_path):
            # Fall back to the cleaned SMILES as the key if the output
            # contains a string we did not write.
            for original_smi in cleaned_to_originals.get(clean_smi, [clean_smi]):
                # Copy so that originals sharing a molecule do not alias
                # the same dict.
                predictions[original_smi] = dict(pred_dict)

    # Add placeholder predictions for invalid SMILES
    for smi, is_valid in zip(smiles_list, valid_mask):
        if not is_valid:
            predictions[smi] = {t: 0.5 for t in TARGET_NAMES}

    return predictions