import torch
import csv
import subprocess
import os

from src.preprocess import create_clean_smiles
from src.commands import generate_features, predict_from_csv


def predict(smiles_list):
    """
    Predict toxicity targets for a list of SMILES strings.

    Args:
        smiles_list (list[str]): SMILES strings

    Returns:
        dict: {smiles: {target_name: prediction_prob}}
    """
    data_path = "tox21/predict_smiles.csv"
    features_path = data_path.replace(".csv", ".npz")
    checkpoint_dir = "checkpoints"
    output_path = "predictions/smiles_predictions.csv"
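    # Assumption: the "tox21/" and "predictions/" folders may not exist yet in a
    # fresh checkout. Creating them here keeps the CSV writes below from failing;
    # drop these two lines if the repository guarantees the folders exist.
    os.makedirs(os.path.dirname(data_path), exist_ok=True)
    os.makedirs(os.path.dirname(output_path), exist_ok=True)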
    # clean smiles
    clean_smiles, valid_mask = create_clean_smiles(smiles_list)

    # Mapping from cleaned to original for valid ones
    originals_valid = [orig for orig, ok in zip(smiles_list, valid_mask) if ok]

    # sanity check (optional but nice to have)
    if len(originals_valid) != len(clean_smiles):
        raise ValueError(
            f"Mismatch: {len(originals_valid)} valid originals vs {len(clean_smiles)} cleaned SMILES"
        )

    # map cleaned → original
    cleaned_to_original = dict(zip(clean_smiles, originals_valid))

    # tox21 targets
    TARGET_NAMES = [
        "NR-AhR", "NR-AR", "NR-AR-LBD", "NR-Aromatase", "NR-ER", "NR-ER-LBD",
        "NR-PPAR-gamma", "SR-ARE", "SR-ATAD5", "SR-HSE", "SR-MMP", "SR-p53",
    ]

    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print(f"Received {len(smiles_list)} SMILES strings")

    # put smiles into csv
    with open(data_path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["smiles"] + TARGET_NAMES)  # header
        for smi in clean_smiles:
            writer.writerow([smi] + [""] * len(TARGET_NAMES))

    # generate features
    generate_features(data_path, features_path)

    # predict
    predict_from_csv(data_path, features_path, checkpoint_dir, output_path)
    # create results dictionary from predictions
    predictions = {}
    with open(output_path, "r", newline="") as f:
        reader = csv.DictReader(f)
        rows = list(reader)

        # Identify the SMILES column even if it is unnamed
        fieldnames = reader.fieldnames
        smiles_col = fieldnames[0]     # first column, even if empty string
        target_names = fieldnames[1:]  # all columns except first

        for row in rows:
            clean_smi = row[smiles_col]
            original_smi = cleaned_to_original.get(clean_smi, clean_smi)
            pred_dict = {t: float(row[t]) for t in target_names}
            predictions[original_smi] = pred_dict

    # Add placeholder predictions for invalid SMILES
    for smi, is_valid in zip(smiles_list, valid_mask):
        if not is_valid:
            predictions[smi] = {t: 0.5 for t in TARGET_NAMES}

    return predictions
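

# A minimal usage sketch. It assumes the module is run from the repository root,
# with trained models in "checkpoints/" and the directories referenced above in
# place; the SMILES strings below are arbitrary examples, and the third
# (invalid) one should come back with the 0.5 placeholder scores.
if __name__ == "__main__":
    example_smiles = ["CCO", "c1ccccc1O", "not-a-valid-smiles"]
    results = predict(example_smiles)
    for smi, targets in results.items():
        print(smi, targets)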