File size: 3,217 Bytes
b25d2b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
from rdkit import Chem
from rdkit.Chem.MolStandardize import rdMolStandardize
from rdkit import Chem
import numpy as np
import pandas as pd
from datasets import load_dataset
from typing import List, Optional

# Names of the twelve task columns, one per line for easier diffing.
# NOTE(review): presumably the Tox21 assay label columns; this constant is
# not referenced by the functions visible in this part of the file.
TOX21_TASKS = [
    "NR-AhR",
    "NR-AR",
    "NR-AR-LBD",
    "NR-Aromatase",
    "NR-ER",
    "NR-ER-LBD",
    "NR-PPAR-gamma",
    "SR-ARE",
    "SR-ATAD5",
    "SR-HSE",
    "SR-MMP",
    "SR-p53",
]

def create_clean_smiles(smiles_list: list[str]) -> tuple[list[str], np.ndarray]:
    """
    Standardize SMILES strings with RDKit and report which inputs survived.

    Each input is parsed, run through ``rdMolStandardize.Cleanup``, replaced
    by its canonical tautomer, and written back out as canonical SMILES.
    An input is rejected when it cannot be parsed, when any atom carries a
    formal charge outside {-1, 0, +1}, or when any processing step raises
    (the failure is printed, not propagated).

    Returns:
        ``(cleaned, mask)`` where ``cleaned`` holds one SMILES per accepted
        input, in input order, and ``mask`` is a boolean array with one
        entry per input (True where the input was accepted). Consequently
        ``len(cleaned) == mask.sum()``, not ``len(smiles_list)``.
    """
    # Loop-invariant helpers, built once up front.
    cleanup_params = rdMolStandardize.CleanupParameters()
    canonicalizer = rdMolStandardize.TautomerEnumerator()
    # The original author notes that this charge filter prevents a GROVER crash.
    permitted_charges = {-1, 0, 1}

    accepted = []
    flags = []

    for smi in smiles_list:
        is_valid = False
        try:
            mol = Chem.MolFromSmiles(smi)
            if mol is not None:
                # Cleanup followed by tautomer canonicalization.
                mol = rdMolStandardize.Cleanup(mol, cleanup_params)
                mol = canonicalizer.Canonicalize(mol)

                charges_ok = all(
                    atom.GetFormalCharge() in permitted_charges
                    for atom in mol.GetAtoms()
                )
                if charges_ok:
                    accepted.append(Chem.MolToSmiles(mol, canonical=True))
                    is_valid = True
        except Exception as e:
            print(f"Failed to clean {smi}: {e}")

        # Exactly one mask entry per input, whatever path was taken above.
        flags.append(is_valid)

    return accepted, np.array(flags, dtype=bool)


def clean_smiles_in_csv(input_csv: str, output_csv: str, smiles_col: str = "smiles", target_cols: Optional[List[str]] = None):
    """
    Read a CSV, clean its SMILES column, and write the valid rows to a new CSV.

    Rows whose SMILES are rejected by ``create_clean_smiles`` are dropped.
    The output has the cleaned SMILES as its first column (named
    ``smiles_col``) followed by the target columns.

    Args:
        input_csv: Path of the CSV file to read.
        output_csv: Path the cleaned CSV is written to (without the index).
        smiles_col: Name of the column holding the SMILES strings.
        target_cols: Columns to carry over next to the SMILES. When None,
            every column other than ``smiles_col`` is kept. If the list
            contains ``smiles_col`` itself, that entry is ignored because
            the cleaned SMILES column is always written.

    Returns:
        The boolean mask returned by ``create_clean_smiles``: one entry per
        input row, True where the row was kept.

    Raises:
        ValueError: If ``smiles_col`` or any requested target column is not
            present in the CSV.
    """
    # Load dataset
    df = pd.read_csv(input_csv)
    if smiles_col not in df.columns:
        raise ValueError(f"'{smiles_col}' column not found in CSV.")

    # Infer target columns if not specified
    if target_cols is None:
        target_cols = [c for c in df.columns if c != smiles_col]

    # Validate target columns before they are used for selection
    missing_targets = [c for c in target_cols if c not in df.columns]
    if missing_targets:
        raise ValueError(f"Missing target columns in CSV: {missing_targets}")

    # The SMILES column is re-inserted below with cleaned values; selecting
    # it here as well would make DataFrame.insert fail with "already exists".
    keep_cols = [c for c in target_cols if c != smiles_col]

    # Clean SMILES
    clean_smis, valid_mask = create_clean_smiles(df[smiles_col].tolist())

    # Keep only valid rows
    df_clean = df.loc[valid_mask, keep_cols].copy()
    df_clean.insert(0, smiles_col, clean_smis)  # smiles first column

    # Save cleaned dataset
    df_clean.to_csv(output_csv, index=False)
    print(f"✅ Cleaned dataset saved to '{output_csv}' ({len(df_clean)} valid molecules).")
    return valid_mask