"""
The predict function uses the fine-tuned model to make predictions.
"""
from argparse import Namespace
from typing import List, Tuple
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from grover.data import MolCollator
from grover.data import MoleculeDataset
from grover.data import StandardScaler
from grover.util.utils import get_data, get_data_from_smiles, create_logger, load_args, get_task_names, tqdm, \
load_checkpoint, load_scalars


def predict(model: nn.Module,
            data: MoleculeDataset,
            args: Namespace,
            batch_size: int,
            loss_func,
            logger,
            shared_dict,
            scaler: StandardScaler = None
            ) -> Tuple[List[List[float]], float]:
"""
Makes predictions on a dataset using an ensemble of models.
:param model: A model.
:param data: A MoleculeDataset.
:param batch_size: Batch size.
:param scaler: A StandardScaler object fit on the training targets.
:return: A list of lists of predictions. The outer list is examples
while the inner list is tasks.
"""
# debug = logger.debug if logger is not None else print
model.eval()
args.bond_drop_rate = 0
preds = []
# num_iters, iter_step = len(data), batch_size
loss_sum, iter_count = 0, 0
mol_collator = MolCollator(args=args, shared_dict=shared_dict)
# mol_dataset = MoleculeDataset(data)
num_workers = 4
mol_loader = DataLoader(data, batch_size=batch_size, shuffle=False, num_workers=num_workers,
collate_fn=mol_collator)
    for item in mol_loader:
        _, batch, features_batch, mask, targets = item
class_weights = torch.ones(targets.shape)
if next(model.parameters()).is_cuda:
targets = targets.cuda()
mask = mask.cuda()
class_weights = class_weights.cuda()
with torch.no_grad():
batch_preds = model(batch, features_batch)
iter_count += 1
        # In fingerprint mode, collect the raw model outputs and skip loss/scaling
        if args.fingerprint:
            preds.extend(batch_preds.data.cpu().numpy())
            continue
        if loss_func is not None:
            # loss_func must return per-element losses (e.g. reduction='none')
            # so that masking and class weighting apply element-wise
            loss = loss_func(batch_preds, targets) * class_weights * mask
            loss = loss.sum() / mask.sum()
            loss_sum += loss.item()
# Collect vectors
batch_preds = batch_preds.data.cpu().numpy().tolist()
if scaler is not None:
batch_preds = scaler.inverse_transform(batch_preds)
preds.extend(batch_preds)
loss_avg = loss_sum / iter_count
return preds, loss_avg
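

# --- Illustrative usage sketch (not part of the original module) -------------
# How predict() might be called once a fine-tuned checkpoint is loaded; the
# checkpoint path and batch size below are hypothetical placeholders.
#
#   model = load_checkpoint('ckpts/model.pt', cuda=False, current_args=args)
#   preds, loss_avg = predict(model=model, data=test_data, args=args,
#                             batch_size=32, loss_func=None, logger=None,
#                             shared_dict={}, scaler=None)
#   # preds holds one list of per-task outputs per molecule; loss_avg is 0
#   # here because no loss_func was supplied.
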
def make_predictions(args: Namespace, newest_train_args=None, smiles: List[str] = None):
"""
Makes predictions. If smiles is provided, makes predictions on smiles.
Otherwise makes predictions on args.test_data.
:param args: Arguments.
:param smiles: Smiles to make predictions on.
:return: A list of lists of target predictions.
"""
if args.gpu is not None:
torch.cuda.set_device(args.gpu)
print('Loading training args')
path = args.checkpoint_paths[0]
scaler, features_scaler = load_scalars(path)
train_args = load_args(path)
# Update args with training arguments saved in checkpoint
for key, value in vars(train_args).items():
if not hasattr(args, key):
setattr(args, key, value)
# update args with newest training args
if newest_train_args is not None:
for key, value in vars(newest_train_args).items():
if not hasattr(args, key):
setattr(args, key, value)
# deal with multiprocess problem
args.debug = True
logger = create_logger('predict', quiet=False)
print('Loading data')
args.task_names = get_task_names(args.data_path)
if smiles is not None:
test_data = get_data_from_smiles(smiles=smiles, skip_invalid_smiles=False)
else:
test_data = get_data(path=args.data_path, args=args,
use_compound_names=args.use_compound_names, skip_invalid_smiles=False)
args.num_tasks = test_data.num_tasks()
args.features_size = test_data.features_size()
    print('Validating SMILES')
    # All indices are treated as valid here; invalid SMILES are not filtered out.
    valid_indices = list(range(len(test_data)))
    full_data = test_data
    test_data = MoleculeDataset([test_data[i] for i in valid_indices])
# Edge case if empty list of smiles is provided
if len(test_data) == 0:
return [None] * len(full_data)
print(f'Test size = {len(test_data):,}')
# Normalize features
if hasattr(train_args, 'features_scaling'):
if train_args.features_scaling:
test_data.normalize_features(features_scaler)
# Predict with each model individually and sum predictions
    # args.num_tasks is always set above, so allocate the prediction buffer directly
    sum_preds = np.zeros((len(test_data), args.num_tasks))
    print('Predicting...')
shared_dict = {}
# loss_func = torch.nn.BCEWithLogitsLoss()
for checkpoint_path in tqdm(args.checkpoint_paths, total=len(args.checkpoint_paths)):
# Load model
model = load_checkpoint(checkpoint_path, cuda=args.cuda, current_args=args, logger=logger)
model_preds, _ = predict(
model=model,
data=test_data,
batch_size=args.batch_size,
scaler=scaler,
shared_dict=shared_dict,
args=args,
logger=logger,
loss_func=None
)
if args.fingerprint:
return model_preds
        sum_preds += np.array(model_preds, dtype=float)
# Ensemble predictions
avg_preds = sum_preds / len(args.checkpoint_paths)
    assert len(test_data) == len(avg_preds)
    # Remember which indices were valid so write_prediction can re-insert
    # None rows for invalid SMILES in the original input order
    args.valid_indices = valid_indices
    test_smiles = full_data.smiles()
return avg_preds, test_smiles
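

# --- Illustrative usage sketch (not part of the original module) -------------
# Running an ensemble prediction directly on SMILES strings. The paths are
# hypothetical; note that args.data_path must still point at a data file,
# since get_task_names(args.data_path) is called either way.
#
#   args.checkpoint_paths = ['ckpts/fold_0/model.pt', 'ckpts/fold_1/model.pt']
#   avg_preds, test_smiles = make_predictions(args, smiles=['CCO', 'c1ccccc1'])
#   # avg_preds has shape (num_molecules, num_tasks), averaged over the
#   # checkpoints in args.checkpoint_paths.
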
def write_prediction(avg_preds, test_smiles, args):
"""
write prediction to disk
:param avg_preds: prediction value
:param test_smiles: input smiles
:param args: Arguments
"""
if args.dataset_type == 'multiclass':
avg_preds = np.argmax(avg_preds, -1)
full_preds = [[None]] * len(test_smiles)
for i, si in enumerate(args.valid_indices):
full_preds[si] = avg_preds[i]
result = pd.DataFrame(data=full_preds, index=test_smiles, columns=args.task_names)
    result.to_csv(args.output_path)
    print(f'Saved predictions to {args.output_path}')
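

# --- Illustrative usage sketch (not part of the original module) -------------
# Chaining make_predictions() and write_prediction(); the output path is a
# hypothetical placeholder.
#
#   avg_preds, test_smiles = make_predictions(args)
#   args.output_path = 'predictions.csv'
#   write_prediction(avg_preds, test_smiles, args)
#   # Writes a CSV indexed by SMILES with one column per task name.
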
def evaluate_predictions(preds: List[List[float]],
targets: List[List[float]],
num_tasks: int,
metric_func,
dataset_type: str,
                         logger=None) -> List[float]:
"""
Evaluates predictions using a metric function and filtering out invalid targets.
:param preds: A list of lists of shape (data_size, num_tasks) with model predictions.
:param targets: A list of lists of shape (data_size, num_tasks) with targets.
:param num_tasks: Number of tasks.
:param metric_func: Metric function which takes in a list of targets and a list of predictions.
:param dataset_type: Dataset type.
:param logger: Logger.
:return: A list with the score for each task based on `metric_func`.
"""
if dataset_type == 'multiclass':
results = metric_func(np.argmax(preds, -1), [i[0] for i in targets])
return [results]
# info = logger.info if logger is not None else print
if len(preds) == 0:
return [float('nan')] * num_tasks
# Filter out empty targets
# valid_preds and valid_targets have shape (num_tasks, data_size)
valid_preds = [[] for _ in range(num_tasks)]
valid_targets = [[] for _ in range(num_tasks)]
for i in range(num_tasks):
for j in range(len(preds)):
if targets[j][i] is not None: # Skip those without targets
valid_preds[i].append(preds[j][i])
valid_targets[i].append(targets[j][i])
# Compute metric
results = []
for i in range(num_tasks):
        # Skip tasks where all targets or all preds are identical, since
        # classification metrics crash in that case
if dataset_type == 'classification':
nan = False
if all(target == 0 for target in valid_targets[i]) or all(target == 1 for target in valid_targets[i]):
nan = True
# info('Warning: Found a task with targets all 0s or all 1s')
if all(pred == 0 for pred in valid_preds[i]) or all(pred == 1 for pred in valid_preds[i]):
nan = True
# info('Warning: Found a task with predictions all 0s or all 1s')
if nan:
results.append(float('nan'))
continue
        if len(valid_targets[i]) == 0:
            # No valid targets for this task; record NaN so results stay
            # aligned with task indices
            results.append(float('nan'))
            continue

        results.append(metric_func(valid_targets[i], valid_preds[i]))
return results
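

# --- Worked example sketch (not part of the original module) -----------------
# Two molecules, two classification tasks, one missing target. roc_auc_score
# is an assumed metric with the (targets, predictions) signature used above.
#
#   from sklearn.metrics import roc_auc_score
#   preds = [[0.9, 0.2], [0.1, 0.8]]
#   targets = [[1, None], [0, 1]]
#   scores = evaluate_predictions(preds, targets, num_tasks=2,
#                                 metric_func=roc_auc_score,
#                                 dataset_type='classification')
#   # Task 0 scores a real AUC; task 1 has a single valid target (all 1s),
#   # so it comes back as NaN.
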
def evaluate(model: nn.Module,
data: MoleculeDataset,
num_tasks: int,
metric_func,
loss_func,
batch_size: int,
dataset_type: str,
args: Namespace,
shared_dict,
scaler: StandardScaler = None,
             logger=None) -> Tuple[List[float], float]:
"""
Evaluates an ensemble of models on a dataset.
:param model: A model.
:param data: A MoleculeDataset.
:param num_tasks: Number of tasks.
:param metric_func: Metric function which takes in a list of targets and a list of predictions.
:param batch_size: Batch size.
:param dataset_type: Dataset type.
:param scaler: A StandardScaler object fit on the training targets.
:param logger: Logger.
:return: A list with the score for each task based on `metric_func`.
"""
preds, loss_avg = predict(
model=model,
data=data,
loss_func=loss_func,
batch_size=batch_size,
scaler=scaler,
shared_dict=shared_dict,
logger=logger,
args=args
)
targets = data.targets()
if scaler is not None:
targets = scaler.inverse_transform(targets)
results = evaluate_predictions(
preds=preds,
targets=targets,
num_tasks=num_tasks,
metric_func=metric_func,
dataset_type=dataset_type,
logger=logger
)
return results, loss_avg
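

# --- Illustrative usage sketch (not part of the original module) -------------
# An evaluation-loop style call. loss_func must return per-element losses
# (e.g. reduction='none') so the masking inside predict() works; all concrete
# values here are assumptions.
#
#   results, val_loss = evaluate(
#       model=model, data=val_data, num_tasks=args.num_tasks,
#       metric_func=roc_auc_score,
#       loss_func=nn.BCEWithLogitsLoss(reduction='none'),
#       batch_size=args.batch_size, dataset_type='classification',
#       args=args, shared_dict={}, scaler=None, logger=None)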