"""
The cross validation function for finetuning.
This implementation is adapted from
https://github.com/chemprop/chemprop/blob/master/chemprop/train/cross_validate.py
"""
import os
import time
from argparse import Namespace
from logging import Logger
from typing import Optional, Tuple

import numpy as np

from grover.util.utils import get_task_names
from grover.util.utils import makedirs
from task.run_evaluation import run_evaluation
from task.train import run_training


def cross_validate(args: Namespace, logger: Optional[Logger] = None) -> Tuple[float, float]:
    """
    k-fold cross validation.

    For every fold the seed is set to ``args.seed + fold_num`` and the save
    directory to ``<args.save_dir>/fold_<fold_num>``; the fold is then run with
    ``run_training`` when ``args.parser_name == "finetune"`` and with
    ``run_evaluation`` otherwise. Per-fold and overall scores are reported
    through ``logger.info`` (or ``print`` when no logger is given).

    ``args.seed`` and ``args.save_dir`` are temporarily overwritten while the
    folds run and are restored to their initial values before returning, even
    if a fold raises, so the same ``args`` object can be reused safely.

    :param args: Parsed arguments. Reads ``seed``, ``save_dir``, ``data_path``,
        ``num_folds``, ``parser_name``, ``metric``, ``split_type`` and
        ``show_individual_scores``.
    :param logger: Optional logger; falls back to ``print`` when ``None``.
    :return: A tuple of mean_score and std_score.
    :raises ValueError: If ``args.num_folds`` is smaller than 1.
    """
    if args.num_folds < 1:
        # Without at least one fold there is nothing to aggregate; fail early
        # with a clear message instead of an axis error from numpy below.
        raise ValueError(f'num_folds must be at least 1, got {args.num_folds}')

    info = logger.info if logger is not None else print

    # Initialize relevant variables
    init_seed = args.seed
    save_dir = args.save_dir
    task_names = get_task_names(args.data_path)

    # Run training with different random seeds for each fold
    all_scores = []
    # A single timestamp is shared by all folds of this run.
    time_start = time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime())
    try:
        for fold_num in range(args.num_folds):
            info(f'Fold {fold_num}')
            args.seed = init_seed + fold_num
            args.save_dir = os.path.join(save_dir, f'fold_{fold_num}')
            makedirs(args.save_dir)
            if args.parser_name == "finetune":
                model_scores = run_training(args, time_start, logger)
            else:
                model_scores = run_evaluation(args, logger)
            all_scores.append(model_scores)
    finally:
        # Undo the per-fold overrides so the caller's namespace is not left
        # pointing at the last fold's seed and directory.
        args.seed = init_seed
        args.save_dir = save_dir

    # Indexed below as [fold, task]; assumes every fold returns one score per
    # task in the same order as task_names.
    all_scores = np.array(all_scores)

    # Report scores for each fold
    info(f'{args.num_folds}-fold cross validation')

    for fold_num, scores in enumerate(all_scores):
        info(f'Seed {init_seed + fold_num} ==> test {args.metric} = {np.nanmean(scores):.6f}')

        if args.show_individual_scores:
            for task_name, score in zip(task_names, scores):
                info(f'Seed {init_seed + fold_num} ==> test {task_name} {args.metric} = {score:.6f}')

    # Report scores across models
    avg_scores = np.nanmean(all_scores, axis=1)  # average score for each model across tasks
    mean_score, std_score = np.nanmean(avg_scores), np.nanstd(avg_scores)
    info(f'overall_{args.split_type}_test_{args.metric}={mean_score:.6f}')
    info(f'std={std_score:.6f}')

    if args.show_individual_scores:
        for task_num, task_name in enumerate(task_names):
            info(f'Overall test {task_name} {args.metric} = '
                 f'{np.nanmean(all_scores[:, task_num]):.6f} +/- {np.nanstd(all_scores[:, task_num]):.6f}')

    return mean_score, std_score