""" The predict function using the finetuned model to make the prediction. . """ from argparse import Namespace from typing import List import numpy as np import pandas as pd import torch import torch.nn as nn from torch.utils.data import DataLoader from grover.data import MolCollator from grover.data import MoleculeDataset from grover.data import StandardScaler from grover.util.utils import get_data, get_data_from_smiles, create_logger, load_args, get_task_names, tqdm, \ load_checkpoint, load_scalars def predict(model: nn.Module, data: MoleculeDataset, args: Namespace, batch_size: int, loss_func, logger, shared_dict, scaler: StandardScaler = None ) -> List[List[float]]: """ Makes predictions on a dataset using an ensemble of models. :param model: A model. :param data: A MoleculeDataset. :param batch_size: Batch size. :param scaler: A StandardScaler object fit on the training targets. :return: A list of lists of predictions. The outer list is examples while the inner list is tasks. """ # debug = logger.debug if logger is not None else print model.eval() args.bond_drop_rate = 0 preds = [] # num_iters, iter_step = len(data), batch_size loss_sum, iter_count = 0, 0 mol_collator = MolCollator(args=args, shared_dict=shared_dict) # mol_dataset = MoleculeDataset(data) num_workers = 4 mol_loader = DataLoader(data, batch_size=batch_size, shuffle=False, num_workers=num_workers, collate_fn=mol_collator) for _, item in enumerate(mol_loader): _, batch, features_batch, mask, targets = item class_weights = torch.ones(targets.shape) if next(model.parameters()).is_cuda: targets = targets.cuda() mask = mask.cuda() class_weights = class_weights.cuda() with torch.no_grad(): batch_preds = model(batch, features_batch) iter_count += 1 if args.fingerprint: preds.extend(batch_preds.data.cpu().numpy()) continue if loss_func is not None: loss = loss_func(batch_preds, targets) * class_weights * mask loss = loss.sum() / mask.sum() loss_sum += loss.item() # Collect vectors batch_preds = batch_preds.data.cpu().numpy().tolist() if scaler is not None: batch_preds = scaler.inverse_transform(batch_preds) preds.extend(batch_preds) loss_avg = loss_sum / iter_count return preds, loss_avg def make_predictions(args: Namespace, newest_train_args=None, smiles: List[str] = None): """ Makes predictions. If smiles is provided, makes predictions on smiles. Otherwise makes predictions on args.test_data. :param args: Arguments. :param smiles: Smiles to make predictions on. :return: A list of lists of target predictions. 
""" if args.gpu is not None: torch.cuda.set_device(args.gpu) print('Loading training args') path = args.checkpoint_paths[0] scaler, features_scaler = load_scalars(path) train_args = load_args(path) # Update args with training arguments saved in checkpoint for key, value in vars(train_args).items(): if not hasattr(args, key): setattr(args, key, value) # update args with newest training args if newest_train_args is not None: for key, value in vars(newest_train_args).items(): if not hasattr(args, key): setattr(args, key, value) # deal with multiprocess problem args.debug = True logger = create_logger('predict', quiet=False) print('Loading data') args.task_names = get_task_names(args.data_path) if smiles is not None: test_data = get_data_from_smiles(smiles=smiles, skip_invalid_smiles=False) else: test_data = get_data(path=args.data_path, args=args, use_compound_names=args.use_compound_names, skip_invalid_smiles=False) args.num_tasks = test_data.num_tasks() args.features_size = test_data.features_size() print('Validating SMILES') valid_indices = [i for i in range(len(test_data))] full_data = test_data # test_data = MoleculeDataset([test_data[i] for i in valid_indices]) test_data_list = [] for i in valid_indices: test_data_list.append(test_data[i]) test_data = MoleculeDataset(test_data_list) # Edge case if empty list of smiles is provided if len(test_data) == 0: return [None] * len(full_data) print(f'Test size = {len(test_data):,}') # Normalize features if hasattr(train_args, 'features_scaling'): if train_args.features_scaling: test_data.normalize_features(features_scaler) # Predict with each model individually and sum predictions if hasattr(args, 'num_tasks'): sum_preds = np.zeros((len(test_data), args.num_tasks)) print(f'Predicting...') shared_dict = {} # loss_func = torch.nn.BCEWithLogitsLoss() count = 0 for checkpoint_path in tqdm(args.checkpoint_paths, total=len(args.checkpoint_paths)): # Load model model = load_checkpoint(checkpoint_path, cuda=args.cuda, current_args=args, logger=logger) model_preds, _ = predict( model=model, data=test_data, batch_size=args.batch_size, scaler=scaler, shared_dict=shared_dict, args=args, logger=logger, loss_func=None ) if args.fingerprint: return model_preds sum_preds += np.array(model_preds, dtype=float) count += 1 # Ensemble predictions avg_preds = sum_preds / len(args.checkpoint_paths) # Save predictions assert len(test_data) == len(avg_preds) # Put Nones for invalid smiles args.valid_indices = valid_indices avg_preds = np.array(avg_preds) test_smiles = full_data.smiles() return avg_preds, test_smiles def write_prediction(avg_preds, test_smiles, args): """ write prediction to disk :param avg_preds: prediction value :param test_smiles: input smiles :param args: Arguments """ if args.dataset_type == 'multiclass': avg_preds = np.argmax(avg_preds, -1) full_preds = [[None]] * len(test_smiles) for i, si in enumerate(args.valid_indices): full_preds[si] = avg_preds[i] result = pd.DataFrame(data=full_preds, index=test_smiles, columns=args.task_names) result.to_csv(args.output_path) print(f'Saving predictions to {args.output_path}') def evaluate_predictions(preds: List[List[float]], targets: List[List[float]], num_tasks: int, metric_func, dataset_type: str, logger = None) -> List[float]: """ Evaluates predictions using a metric function and filtering out invalid targets. :param preds: A list of lists of shape (data_size, num_tasks) with model predictions. :param targets: A list of lists of shape (data_size, num_tasks) with targets. 
def evaluate_predictions(preds: List[List[float]],
                         targets: List[List[float]],
                         num_tasks: int,
                         metric_func,
                         dataset_type: str,
                         logger=None) -> List[float]:
    """
    Evaluates predictions using a metric function and filtering out invalid targets.

    :param preds: A list of lists of shape (data_size, num_tasks) with model predictions.
    :param targets: A list of lists of shape (data_size, num_tasks) with targets.
    :param num_tasks: Number of tasks.
    :param metric_func: Metric function which takes in a list of targets and a list of predictions.
    :param dataset_type: Dataset type.
    :param logger: Logger.
    :return: A list with the score for each task based on `metric_func`.
    """
    if dataset_type == 'multiclass':
        results = metric_func(np.argmax(preds, -1), [i[0] for i in targets])
        return [results]

    if len(preds) == 0:
        return [float('nan')] * num_tasks

    # Filter out empty targets.
    # valid_preds and valid_targets have shape (num_tasks, data_size).
    valid_preds = [[] for _ in range(num_tasks)]
    valid_targets = [[] for _ in range(num_tasks)]
    for i in range(num_tasks):
        for j in range(len(preds)):
            if targets[j][i] is not None:  # Skip those without targets.
                valid_preds[i].append(preds[j][i])
                valid_targets[i].append(targets[j][i])

    # Compute metric.
    results = []
    for i in range(num_tasks):
        # Skip if all targets or preds are identical, otherwise we'll crash during classification.
        if dataset_type == 'classification':
            nan = False
            if all(target == 0 for target in valid_targets[i]) or all(target == 1 for target in valid_targets[i]):
                nan = True  # Found a task with targets all 0s or all 1s.
            if all(pred == 0 for pred in valid_preds[i]) or all(pred == 1 for pred in valid_preds[i]):
                nan = True  # Found a task with predictions all 0s or all 1s.
            if nan:
                results.append(float('nan'))
                continue

        if len(valid_targets[i]) == 0:
            continue

        results.append(metric_func(valid_targets[i], valid_preds[i]))

    return results


def evaluate(model: nn.Module,
             data: MoleculeDataset,
             num_tasks: int,
             metric_func,
             loss_func,
             batch_size: int,
             dataset_type: str,
             args: Namespace,
             shared_dict,
             scaler: StandardScaler = None,
             logger=None) -> Tuple[List[float], float]:
    """
    Evaluates a model on a dataset.

    :param model: A model.
    :param data: A MoleculeDataset.
    :param num_tasks: Number of tasks.
    :param metric_func: Metric function which takes in a list of targets and a list of predictions.
    :param loss_func: Loss function.
    :param batch_size: Batch size.
    :param dataset_type: Dataset type.
    :param args: Arguments.
    :param shared_dict: A dict shared across dataloader workers by MolCollator.
    :param scaler: A StandardScaler object fit on the training targets.
    :param logger: Logger.
    :return: A tuple of (per-task scores based on `metric_func`, average loss).
    """
    preds, loss_avg = predict(
        model=model,
        data=data,
        loss_func=loss_func,
        batch_size=batch_size,
        scaler=scaler,
        shared_dict=shared_dict,
        logger=logger,
        args=args
    )

    targets = data.targets()
    if scaler is not None:
        targets = scaler.inverse_transform(targets)

    results = evaluate_predictions(
        preds=preds,
        targets=targets,
        num_tasks=num_tasks,
        metric_func=metric_func,
        dataset_type=dataset_type,
        logger=logger
    )

    return results, loss_avg
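

# Illustration (not part of the original module): `metric_func` is any callable
# taking (targets, predictions) and returning a score, so binary classification
# tasks can, assuming scikit-learn is installed, be scored with roc_auc_score.
def _evaluate_with_auc(preds: List[List[float]],
                       targets: List[List[float]],
                       num_tasks: int) -> List[float]:
    from sklearn.metrics import roc_auc_score
    return evaluate_predictions(preds=preds,
                                targets=targets,
                                num_tasks=num_tasks,
                                metric_func=roc_auc_score,
                                dataset_type='classification')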