# tox21_grover_classifier / src /preprocess.py
# hiitsmeme
# initial commit
# b25d2b6
from rdkit import Chem
from rdkit.Chem.MolStandardize import rdMolStandardize
from rdkit import Chem
import numpy as np
import pandas as pd
from datasets import load_dataset
from typing import List, Optional
# Identifiers of the twelve Tox21 tasks. Not referenced elsewhere in the
# visible part of this module -- presumably the label column names of the
# Tox21 CSV; confirm against callers.
TOX21_TASKS = [
    "NR-AhR",
    "NR-AR",
    "NR-AR-LBD",
    "NR-Aromatase",
    "NR-ER",
    "NR-ER-LBD",
    "NR-PPAR-gamma",
    "SR-ARE",
    "SR-ATAD5",
    "SR-HSE",
    "SR-MMP",
    "SR-p53",
]
def create_clean_smiles(smiles_list: list[str]) -> tuple[list[str], np.ndarray]:
    """
    Standardize a batch of SMILES strings and report which inputs survived.

    Each input is parsed with RDKit, passed through ``rdMolStandardize.Cleanup``
    and tautomer canonicalization, checked for out-of-range formal charges, and
    finally written back out as canonical SMILES.

    Args:
        smiles_list: Raw SMILES strings, one per molecule.

    Returns:
        A pair ``(cleaned, mask)``. ``cleaned`` holds the canonical SMILES of
        the molecules that passed every step, in input order. ``mask`` is a
        boolean array with one entry per *input*; ``True`` marks the inputs
        that contributed an entry to ``cleaned``, so
        ``len(cleaned) == mask.sum()``.

    Inputs that fail to parse, carry an atom with a formal charge outside
    {-1, 0, +1}, or raise during processing are dropped (mask ``False``).
    Exceptions are reported on stdout rather than propagated.
    """
    params = rdMolStandardize.CleanupParameters()
    canonicalizer = rdMolStandardize.TautomerEnumerator()
    # Per the original author's note, atoms outside this charge range make
    # GROVER crash downstream, so such molecules are rejected here.
    permitted_charges = {-1, 0, 1}

    cleaned = []
    keep_flags = []

    for smi in smiles_list:
        accepted = False
        try:
            mol = Chem.MolFromSmiles(smi)
            if mol is not None:
                # Cleanup followed by tautomer canonicalization.
                mol = rdMolStandardize.Cleanup(mol, params)
                mol = canonicalizer.Canonicalize(mol)

                charges_ok = all(
                    atom.GetFormalCharge() in permitted_charges
                    for atom in mol.GetAtoms()
                )
                if charges_ok:
                    cleaned.append(Chem.MolToSmiles(mol, canonical=True))
                    accepted = True
        except Exception as e:
            print(f"Failed to clean {smi}: {e}")

        # Exactly one mask entry per input, whatever path was taken above.
        keep_flags.append(accepted)

    return cleaned, np.array(keep_flags, dtype=bool)
def clean_smiles_in_csv(input_csv: str, output_csv: str, smiles_col: str = "smiles", target_cols: Optional[List[str]] = None) -> np.ndarray:
    """
    Read a CSV, clean its SMILES column, and write the surviving rows to a new CSV.

    Args:
        input_csv: Path of the CSV file to read.
        output_csv: Path the cleaned CSV is written to (without the index).
        smiles_col: Name of the column holding the SMILES strings.
        target_cols: Columns to carry over next to the cleaned SMILES. When
            ``None``, every column except ``smiles_col`` is kept. Listing
            ``smiles_col`` here is tolerated; it is ignored because the
            cleaned SMILES column is always written as the first column.

    Returns:
        Boolean mask with one entry per row of the input CSV; ``True`` marks
        rows whose SMILES were cleaned successfully and written out.

    Raises:
        ValueError: If ``smiles_col`` or any of ``target_cols`` is not a
            column of the input CSV.
    """
    df = pd.read_csv(input_csv)
    if smiles_col not in df.columns:
        raise ValueError(f"'{smiles_col}' column not found in CSV.")

    # Infer target columns if not specified.
    if target_cols is None:
        target_cols = [c for c in df.columns if c != smiles_col]

    # Validate before the cleaning pass so bad arguments fail fast.
    missing_targets = [c for c in target_cols if c not in df.columns]
    if missing_targets:
        raise ValueError(f"Missing target columns in CSV: {missing_targets}")

    # The cleaned SMILES column is inserted below under ``smiles_col``; if the
    # raw column were carried over as well, DataFrame.insert would raise
    # because the column already exists. Build a new list so the caller's
    # ``target_cols`` is never mutated.
    keep_cols = [c for c in target_cols if c != smiles_col]

    clean_smis, valid_mask = create_clean_smiles(df[smiles_col].tolist())

    # Keep only valid rows; ``clean_smis`` has exactly one entry per True in
    # ``valid_mask``, so it lines up with the filtered frame.
    df_clean = df.loc[valid_mask, keep_cols].copy()
    df_clean.insert(0, smiles_col, clean_smis)  # smiles first column

    df_clean.to_csv(output_csv, index=False)
    print(f"✅ Cleaned dataset saved to '{output_csv}' ({len(df_clean)} valid molecules).")
    return valid_mask