| """ | |
| The GROVER pretrain function. | |
| """ | |
| import os | |
| import time | |
| from argparse import Namespace | |
| from logging import Logger | |
| import torch | |
| from torch.utils.data import DataLoader | |
| from grover.data.dist_sampler import DistributedSampler | |
| from grover.data.groverdataset import get_data, split_data, GroverCollator, BatchMolDataset | |
| from grover.data.torchvocab import MolVocab | |
| from grover.model.models import GROVEREmbedding | |
| from grover.util.multi_gpu_wrapper import MultiGpuWrapper as mgw | |
| from grover.util.nn_utils import param_count | |
| from grover.util.utils import build_optimizer, build_lr_scheduler | |
| from task.grovertrainer import GROVERTrainer | |


def pretrain_model(args: Namespace, logger: Logger = None):
    """
    The entry point of pretraining.

    :param args: the arguments.
    :param logger: the logger.
    """
    # Keep a reference so the import is not removed by PyCharm's
    # "optimize imports" feature.
    a = MolVocab
    s_time = time.time()
    run_training(args=args, logger=logger)
    e_time = time.time()
    print("Total Time: %.3f" % (e_time - s_time))


def pre_load_data(dataset: BatchMolDataset, rank: int, num_replicas: int,
                  sample_per_file: int = None, epoch: int = 0):
    """
    Pre-load data at the beginning of each epoch.

    :param dataset: the training dataset.
    :param rank: the rank of the current worker.
    :param num_replicas: the number of replicas.
    :param sample_per_file: the number of data points in each file. When sample_per_file is
        None, all data will be loaded; this implies the testing phase. (TODO: bad design here.)
    :param epoch: the epoch number.
    """
    mock_sampler = DistributedSampler(dataset, num_replicas=num_replicas, rank=rank, shuffle=False,
                                      sample_per_file=sample_per_file)
    mock_sampler.set_epoch(epoch)
    pre_indices = mock_sampler.get_indices()
    for i in pre_indices:
        dataset.load_data(i)


def run_training(args, logger):
    """
    Run the pretraining task.

    :param args: the arguments.
    :param logger: the logger.
    """
    # Initialize the logger.
    if logger is not None:
        debug, _ = logger.debug, logger.info
    else:
        debug = print

    # Initialize the horovod library.
    if args.enable_multi_gpu:
        mgw.init()

    # Bind training to GPUs.
    master_worker = (mgw.rank() == 0) if args.enable_multi_gpu else True
    # Pin the GPU to the local rank. By default, gpu:0 is used for training.
    local_gpu_idx = mgw.local_rank() if args.enable_multi_gpu else 0
    with_cuda = args.cuda
    if with_cuda:
        torch.cuda.set_device(local_gpu_idx)

    # Get the rank and the number of workers.
    rank = mgw.rank() if args.enable_multi_gpu else 0
    num_replicas = mgw.size() if args.enable_multi_gpu else 1
    # print("Rank: %d Rep: %d" % (rank, num_replicas))

    # Load the file paths of the data.
    if master_worker:
        print(args)
        if args.enable_multi_gpu:
            debug("Total workers: %d" % mgw.size())
        debug('Loading data')
    data, sample_per_file = get_data(data_path=args.data_path)

    # Data splitting.
    if master_worker:
        debug('Splitting data with seed 0.')
    train_data, test_data, _ = split_data(data=data, sizes=(0.9, 0.1, 0.0), seed=0, logger=logger)

    # The effective per-worker train data size is the train_data size divided by the number of GPUs.
    if args.enable_multi_gpu:
        args.train_data_size = len(train_data) // mgw.size()
    else:
        args.train_data_size = len(train_data)

    if master_worker:
        debug(f'Total size = {len(data):,} | '
              f'train size = {len(train_data):,} | val size = {len(test_data):,}')

    # Load the atom and bond vocabularies and the semantic motif labels.
    atom_vocab = MolVocab.load_vocab(args.atom_vocab_path)
    bond_vocab = MolVocab.load_vocab(args.bond_vocab_path)
    atom_vocab_size, bond_vocab_size = len(atom_vocab), len(bond_vocab)
    # Hard-coded here, since no data has been loaded yet.
    fg_size = 85
    shared_dict = {}
    mol_collator = GroverCollator(shared_dict=shared_dict, atom_vocab=atom_vocab, bond_vocab=bond_vocab, args=args)
    if master_worker:
        debug("atom vocab size: %d, bond vocab size: %d, Number of FG tasks: %d" % (atom_vocab_size,
                                                                                    bond_vocab_size, fg_size))
    # Define the distributed samplers. When using a single card, the samplers will be None.
    train_sampler = None
    test_sampler = None
    shuffle = True
    if args.enable_multi_gpu:
        # Without shuffling, performance may degrade.
        train_sampler = DistributedSampler(
            train_data, num_replicas=mgw.size(), rank=mgw.rank(), shuffle=True, sample_per_file=sample_per_file)
        # Here sample_per_file in test_sampler is None, indicating that the test sampler will not
        # divide the test samples by rank. (TODO: bad design here.)
        test_sampler = DistributedSampler(
            test_data, num_replicas=mgw.size(), rank=mgw.rank(), shuffle=False)
        train_sampler.set_epoch(args.epochs)
        test_sampler.set_epoch(1)
        # When multi-GPU training is enabled, shuffling must be disabled.
        shuffle = False

    # Pre-load data. (Maybe unnecessary.)
    pre_load_data(train_data, rank, num_replicas, sample_per_file)
    pre_load_data(test_data, rank, num_replicas)
    if master_worker:
        # print("Pre-loaded training data: %d" % train_data.count_loaded_datapoints())
        print("Pre-loaded test data: %d" % test_data.count_loaded_datapoints())
    # Build the dataloaders.
    train_data_dl = DataLoader(train_data,
                               batch_size=args.batch_size,
                               shuffle=shuffle,
                               num_workers=12,
                               sampler=train_sampler,
                               collate_fn=mol_collator)
    test_data_dl = DataLoader(test_data,
                              batch_size=args.batch_size,
                              shuffle=shuffle,
                              num_workers=10,
                              sampler=test_sampler,
                              collate_fn=mol_collator)
    # Build the embedding model.
    grover_model = GROVEREmbedding(args)

    # Build the trainer. (The "fg_szie" spelling is kept as-is to match the
    # parameter name in GROVERTrainer's signature.)
    trainer = GROVERTrainer(args=args,
                            embedding_model=grover_model,
                            atom_vocab_size=atom_vocab_size,
                            bond_vocab_size=bond_vocab_size,
                            fg_szie=fg_size,
                            train_dataloader=train_data_dl,
                            test_dataloader=test_data_dl,
                            optimizer_builder=build_optimizer,
                            scheduler_builder=build_lr_scheduler,
                            logger=logger,
                            with_cuda=with_cuda,
                            enable_multi_gpu=args.enable_multi_gpu)
    # Restore interrupted training.
    model_dir = os.path.join(args.save_dir, "model")
    resume_from_epoch = 0
    resume_scheduler_step = 0
    if master_worker:
        resume_from_epoch, resume_scheduler_step = trainer.restore(model_dir)
    if args.enable_multi_gpu:
        # Broadcast the restored state from rank 0 so that all workers agree on
        # the resume epoch and the scheduler step.
        resume_from_epoch = mgw.broadcast(torch.tensor(resume_from_epoch),
                                          root_rank=0, name="resume_from_epoch").item()
        resume_scheduler_step = mgw.broadcast(torch.tensor(resume_scheduler_step),
                                              root_rank=0, name="resume_scheduler_step").item()
    trainer.scheduler.current_step = resume_scheduler_step
    print("Restored epoch: %d Restored scheduler step: %d" % (resume_from_epoch, trainer.scheduler.current_step))
    trainer.broadcast_parameters()

    # Print model details.
    if master_worker:
        print(grover_model)
        print("Total parameters: %d" % param_count(trainer.grover))

    # Perform training.
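    # Training resumes at resume_from_epoch + 1 and runs through args.epochs - 1;
    # the final checkpoint is written once after the loop under index args.epochs.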
    for epoch in range(resume_from_epoch + 1, args.epochs):
        s_time = time.time()
        # Per-epoch data pre-loading: drop the previous epoch's cache and load
        # the files belonging to this rank's shard for the new epoch.
        if args.enable_multi_gpu:
            train_sampler.set_epoch(epoch)
            train_data.clean_cache()
            idxs = train_sampler.get_indices()
            for idx in idxs:
                train_data.load_data(idx)
        d_time = time.time() - s_time

        # Perform training and validation.
        s_time = time.time()
        _, train_loss, _ = trainer.train(epoch)
        t_time = time.time() - s_time

        s_time = time.time()
        _, val_loss, detailed_loss_val = trainer.test(epoch)
        val_av_loss, val_bv_loss, val_fg_loss, _, _, _ = detailed_loss_val
        v_time = time.time() - s_time
        # Print information.
        if master_worker:
            print('Epoch: {:04d}'.format(epoch),
                  'loss_train: {:.6f}'.format(train_loss),
                  'loss_val: {:.6f}'.format(val_loss),
                  'loss_val_av: {:.6f}'.format(val_av_loss),
                  'loss_val_bv: {:.6f}'.format(val_bv_loss),
                  'loss_val_fg: {:.6f}'.format(val_fg_loss),
                  'cur_lr: {:.5f}'.format(trainer.scheduler.get_lr()[0]),
                  't_time: {:.4f}s'.format(t_time),
                  'v_time: {:.4f}s'.format(v_time),
                  'd_time: {:.4f}s'.format(d_time), flush=True)
            if epoch % args.save_interval == 0:
                trainer.save(epoch, model_dir)
        # Each rank saves its own temporary checkpoint.
        trainer.save_tmp(epoch, model_dir, rank)

    # Only save the final version.
    if master_worker:
        trainer.save(args.epochs, model_dir, "")
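

# Example usage (a minimal sketch, not part of the original pipeline): only the
# fields read in this file are shown. GROVEREmbedding, GroverCollator, and
# GROVERTrainer consume additional fields (model sizes, loss weights, etc.) that
# must also be present on `args`; the real entry point builds `args` with
# argparse elsewhere in the repository. The paths below are placeholders.
#
#   from argparse import Namespace
#
#   args = Namespace(
#       data_path="<path to pretraining data>",    # placeholder path
#       atom_vocab_path="<path to atom vocab>",    # placeholder path
#       bond_vocab_path="<path to bond vocab>",    # placeholder path
#       save_dir="<output dir>",                   # placeholder path
#       cuda=True,                                 # train on GPU
#       enable_multi_gpu=False,                    # single-card run
#       batch_size=32,
#       epochs=100,
#       save_interval=10,
#   )
#   pretrain_model(args)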