hiitsmeme commited on
Commit
62a699b
·
1 Parent(s): 1710048
Files changed (2) hide show
  1. main.py +4 -4
  2. src/commands.py +1 -0
main.py CHANGED
@@ -6,10 +6,10 @@ from rdkit import RDLogger
6
 
7
  from grover.util.parsing import parse_args, get_newest_train_args
8
  from grover.util.utils import create_logger
9
- from grover.task.cross_validate import cross_validate
10
- from grover.task.fingerprint import generate_fingerprints
11
- from grover.task.predict import make_predictions, write_prediction
12
- from grover.task.pretrain import pretrain_model
13
  from grover.data.torchvocab import MolVocab
14
 
15
 
 
6
 
7
  from grover.util.parsing import parse_args, get_newest_train_args
8
  from grover.util.utils import create_logger
9
+ from task.cross_validate import cross_validate
10
+ from task.fingerprint import generate_fingerprints
11
+ from task.predict import make_predictions, write_prediction
12
+ from task.pretrain import pretrain_model
13
  from grover.data.torchvocab import MolVocab
14
 
15
 
src/commands.py CHANGED
@@ -25,6 +25,7 @@ def predict_from_csv(data_path, features_path, checkpoint_dir, output_path):
25
  f"--no_features_scaling "
26
  f"--output {output_path}"
27
  )
 
28
 
29
 
30
  def finetune(train_path, val_path, train_features_path, val_features_path,
 
25
  f"--no_features_scaling "
26
  f"--output {output_path}"
27
  )
28
+ run(predict_cmd)
29
 
30
 
31
  def finetune(train_path, val_path, train_features_path, val_features_path,