hiitsmeme commited on
Commit
b25d2b6
·
1 Parent(s): 9da50ef

initial commit

Browse files
Files changed (20) hide show
  1. .example.env +1 -0
  2. .gitignore +8 -0
  3. LICENSE +400 -0
  4. LICENSE_GROVER +53 -0
  5. MODEL_CARD.md +29 -0
  6. app.py +77 -0
  7. config/config.json +15 -0
  8. environment.yaml +229 -0
  9. evaluate.py +17 -0
  10. generate_features.py +9 -0
  11. grover +1 -0
  12. hp_search.py +101 -0
  13. main.py +55 -0
  14. predict.py +89 -0
  15. prepare_data.py +20 -0
  16. src/commands.py +54 -0
  17. src/eval.py +47 -0
  18. src/hp_search.py +20 -0
  19. src/preprocess.py +88 -0
  20. train.py +42 -0
.example.env ADDED
@@ -0,0 +1 @@
 
 
1
+ TOKEN=example_token
.gitignore ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ tox21/*
2
+ predictions/*
3
+ hp_search/trials/*
4
+ hp_search/logs/*
5
+ __pycache__
6
+ .env
7
+ pretrained/*
8
+ grover_base.pt
LICENSE ADDED
@@ -0,0 +1,400 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Attribution-NonCommercial 4.0 International
2
+
3
+ =======================================================================
4
+
5
+ Creative Commons Corporation ("Creative Commons") is not a law firm and
6
+ does not provide legal services or legal advice. Distribution of
7
+ Creative Commons public licenses does not create a lawyer-client or
8
+ other relationship. Creative Commons makes its licenses and related
9
+ information available on an "as-is" basis. Creative Commons gives no
10
+ warranties regarding its licenses, any material licensed under their
11
+ terms and conditions, or any related information. Creative Commons
12
+ disclaims all liability for damages resulting from their use to the
13
+ fullest extent possible.
14
+
15
+ Using Creative Commons Public Licenses
16
+
17
+ Creative Commons public licenses provide a standard set of terms and
18
+ conditions that creators and other rights holders may use to share
19
+ original works of authorship and other material subject to copyright
20
+ and certain other rights specified in the public license below. The
21
+ following considerations are for informational purposes only, are not
22
+ exhaustive, and do not form part of our licenses.
23
+
24
+ Considerations for licensors: Our public licenses are
25
+ intended for use by those authorized to give the public
26
+ permission to use material in ways otherwise restricted by
27
+ copyright and certain other rights. Our licenses are
28
+ irrevocable. Licensors should read and understand the terms
29
+ and conditions of the license they choose before applying it.
30
+ Licensors should also secure all rights necessary before
31
+ applying our licenses so that the public can reuse the
32
+ material as expected. Licensors should clearly mark any
33
+ material not subject to the license. This includes other CC-
34
+ licensed material, or material used under an exception or
35
+ limitation to copyright. More considerations for licensors:
36
+ wiki.creativecommons.org/Considerations_for_licensors
37
+
38
+ Considerations for the public: By using one of our public
39
+ licenses, a licensor grants the public permission to use the
40
+ licensed material under specified terms and conditions. If
41
+ the licensor's permission is not necessary for any reason--for
42
+ example, because of any applicable exception or limitation to
43
+ copyright--then that use is not regulated by the license. Our
44
+ licenses grant only permissions under copyright and certain
45
+ other rights that a licensor has authority to grant. Use of
46
+ the licensed material may still be restricted for other
47
+ reasons, including because others have copyright or other
48
+ rights in the material. A licensor may make special requests,
49
+ such as asking that all changes be marked or described.
50
+ Although not required by our licenses, you are encouraged to
51
+ respect those requests where reasonable. More considerations
52
+ for the public:
53
+ wiki.creativecommons.org/Considerations_for_licensees
54
+ =======================================================================
55
+ Creative Commons Attribution-NonCommercial 4.0 International Public
56
+ License
57
+ By exercising the Licensed Rights (defined below), You accept and agree
58
+ to be bound by the terms and conditions of this Creative Commons
59
+ Attribution-NonCommercial 4.0 International Public License ("Public
60
+ License"). To the extent this Public License may be interpreted as a
61
+ contract, You are granted the Licensed Rights in consideration of Your
62
+ acceptance of these terms and conditions, and the Licensor grants You
63
+ such rights in consideration of benefits the Licensor receives from
64
+ making the Licensed Material available under these terms and
65
+ conditions.
66
+ Section 1 -- Definitions.
67
+ a. Adapted Material means material subject to Copyright and Similar
68
+ Rights that is derived from or based upon the Licensed Material
69
+ and in which the Licensed Material is translated, altered,
70
+ arranged, transformed, or otherwise modified in a manner requiring
71
+ permission under the Copyright and Similar Rights held by the
72
+ Licensor. For purposes of this Public License, where the Licensed
73
+ Material is a musical work, performance, or sound recording,
74
+ Adapted Material is always produced where the Licensed Material is
75
+ synched in timed relation with a moving image.
76
+ b. Adapter's License means the license You apply to Your Copyright
77
+ and Similar Rights in Your contributions to Adapted Material in
78
+ accordance with the terms and conditions of this Public License.
79
+
80
+ c. Copyright and Similar Rights means copyright and/or similar rights
81
+ closely related to copyright including, without limitation,
82
+ performance, broadcast, sound recording, and Sui Generis Database
83
+ Rights, without regard to how the rights are labeled or
84
+ categorized. For purposes of this Public License, the rights
85
+ specified in Section 2(b)(1)-(2) are not Copyright and Similar
86
+ Rights.
87
+ d. Effective Technological Measures means those measures that, in the
88
+ absence of proper authority, may not be circumvented under laws
89
+ fulfilling obligations under Article 11 of the WIPO Copyright
90
+ Treaty adopted on December 20, 1996, and/or similar international
91
+ agreements.
92
+
93
+ e. Exceptions and Limitations means fair use, fair dealing, and/or
94
+ any other exception or limitation to Copyright and Similar Rights
95
+ that applies to Your use of the Licensed Material.
96
+
97
+ f. Licensed Material means the artistic or literary work, database,
98
+ or other material to which the Licensor applied this Public
99
+ License.
100
+
101
+ g. Licensed Rights means the rights granted to You subject to the
102
+ terms and conditions of this Public License, which are limited to
103
+ all Copyright and Similar Rights that apply to Your use of the
104
+ Licensed Material and that the Licensor has authority to license.
105
+
106
+ h. Licensor means the individual(s) or entity(ies) granting rights
107
+ under this Public License.
108
+
109
+ i. NonCommercial means not primarily intended for or directed towards
110
+ commercial advantage or monetary compensation. For purposes of
111
+ this Public License, the exchange of the Licensed Material for
112
+ other material subject to Copyright and Similar Rights by digital
113
+ file-sharing or similar means is NonCommercial provided there is
114
+ no payment of monetary compensation in connection with the
115
+ exchange.
116
+
117
+ j. Share means to provide material to the public by any means or
118
+ process that requires permission under the Licensed Rights, such
119
+ as reproduction, public display, public performance, distribution,
120
+ dissemination, communication, or importation, and to make material
121
+ available to the public including in ways that members of the
122
+ public may access the material from a place and at a time
123
+ individually chosen by them.
124
+
125
+ k. Sui Generis Database Rights means rights other than copyright
126
+ resulting from Directive 96/9/EC of the European Parliament and of
127
+ the Council of 11 March 1996 on the legal protection of databases,
128
+ as amended and/or succeeded, as well as other essentially
129
+ equivalent rights anywhere in the world.
130
+
131
+ l. You means the individual or entity exercising the Licensed Rights
132
+ under this Public License. Your has a corresponding meaning.
133
+
134
+
135
+ Section 2 -- Scope.
136
+
137
+ a. License grant.
138
+
139
+ 1. Subject to the terms and conditions of this Public License,
140
+ the Licensor hereby grants You a worldwide, royalty-free,
141
+ non-sublicensable, non-exclusive, irrevocable license to
142
+ exercise the Licensed Rights in the Licensed Material to:
143
+
144
+ a. reproduce and Share the Licensed Material, in whole or
145
+ in part, for NonCommercial purposes only; and
146
+
147
+ b. produce, reproduce, and Share Adapted Material for
148
+ NonCommercial purposes only.
149
+
150
+ 2. Exceptions and Limitations. For the avoidance of doubt, where
151
+ Exceptions and Limitations apply to Your use, this Public
152
+ License does not apply, and You do not need to comply with
153
+ its terms and conditions.
154
+
155
+ 3. Term. The term of this Public License is specified in Section
156
+ 6(a).
157
+
158
+ 4. Media and formats; technical modifications allowed. The
159
+ Licensor authorizes You to exercise the Licensed Rights in
160
+ all media and formats whether now known or hereafter created,
161
+ and to make technical modifications necessary to do so. The
162
+ Licensor waives and/or agrees not to assert any right or
163
+ authority to forbid You from making technical modifications
164
+ necessary to exercise the Licensed Rights, including
165
+ technical modifications necessary to circumvent Effective
166
+ Technological Measures. For purposes of this Public License,
167
+ simply making modifications authorized by this Section 2(a)
168
+ (4) never produces Adapted Material.
169
+
170
+ 5. Downstream recipients.
171
+
172
+ a. Offer from the Licensor -- Licensed Material. Every
173
+ recipient of the Licensed Material automatically
174
+ receives an offer from the Licensor to exercise the
175
+ Licensed Rights under the terms and conditions of this
176
+ Public License.
177
+
178
+ b. No downstream restrictions. You may not offer or impose
179
+ any additional or different terms or conditions on, or
180
+ apply any Effective Technological Measures to, the
181
+ Licensed Material if doing so restricts exercise of the
182
+ Licensed Rights by any recipient of the Licensed
183
+ Material.
184
+
185
+ 6. No endorsement. Nothing in this Public License constitutes or
186
+ may be construed as permission to assert or imply that You
187
+ are, or that Your use of the Licensed Material is, connected
188
+ with, or sponsored, endorsed, or granted official status by,
189
+ the Licensor or others designated to receive attribution as
190
+ provided in Section 3(a)(1)(A)(i).
191
+
192
+ b. Other rights.
193
+
194
+ 1. Moral rights, such as the right of integrity, are not
195
+ licensed under this Public License, nor are publicity,
196
+ privacy, and/or other similar personality rights; however, to
197
+ the extent possible, the Licensor waives and/or agrees not to
198
+ assert any such rights held by the Licensor to the limited
199
+ extent necessary to allow You to exercise the Licensed
200
+ Rights, but not otherwise.
201
+
202
+ 2. Patent and trademark rights are not licensed under this
203
+ Public License.
204
+
205
+ 3. To the extent possible, the Licensor waives any right to
206
+ collect royalties from You for the exercise of the Licensed
207
+ Rights, whether directly or through a collecting society
208
+ under any voluntary or waivable statutory or compulsory
209
+ licensing scheme. In all other cases the Licensor expressly
210
+ reserves any right to collect such royalties, including when
211
+ the Licensed Material is used other than for NonCommercial
212
+ purposes.
213
+
214
+
215
+ Section 3 -- License Conditions.
216
+
217
+ Your exercise of the Licensed Rights is expressly made subject to the
218
+ following conditions.
219
+
220
+ a. Attribution.
221
+
222
+ 1. If You Share the Licensed Material (including in modified
223
+ form), You must:
224
+
225
+ a. retain the following if it is supplied by the Licensor
226
+ with the Licensed Material:
227
+
228
+ i. identification of the creator(s) of the Licensed
229
+ Material and any others designated to receive
230
+ attribution, in any reasonable manner requested by
231
+ the Licensor (including by pseudonym if
232
+ designated);
233
+
234
+ ii. a copyright notice;
235
+
236
+ iii. a notice that refers to this Public License;
237
+
238
+ iv. a notice that refers to the disclaimer of
239
+ warranties;
240
+
241
+ v. a URI or hyperlink to the Licensed Material to the
242
+ extent reasonably practicable;
243
+
244
+ b. indicate if You modified the Licensed Material and
245
+ retain an indication of any previous modifications; and
246
+
247
+ c. indicate the Licensed Material is licensed under this
248
+ Public License, and include the text of, or the URI or
249
+ hyperlink to, this Public License.
250
+
251
+ 2. You may satisfy the conditions in Section 3(a)(1) in any
252
+ reasonable manner based on the medium, means, and context in
253
+ which You Share the Licensed Material. For example, it may be
254
+ reasonable to satisfy the conditions by providing a URI or
255
+ hyperlink to a resource that includes the required
256
+ information.
257
+
258
+ 3. If requested by the Licensor, You must remove any of the
259
+ information required by Section 3(a)(1)(A) to the extent
260
+ reasonably practicable.
261
+
262
+ 4. If You Share Adapted Material You produce, the Adapter's
263
+ License You apply must not prevent recipients of the Adapted
264
+ Material from complying with this Public License.
265
+
266
+
267
+ Section 4 -- Sui Generis Database Rights.
268
+
269
+ Where the Licensed Rights include Sui Generis Database Rights that
270
+ apply to Your use of the Licensed Material:
271
+
272
+ a. for the avoidance of doubt, Section 2(a)(1) grants You the right
273
+ to extract, reuse, reproduce, and Share all or a substantial
274
+ portion of the contents of the database for NonCommercial purposes
275
+ only;
276
+
277
+ b. if You include all or a substantial portion of the database
278
+ contents in a database in which You have Sui Generis Database
279
+ Rights, then the database in which You have Sui Generis Database
280
+ Rights (but not its individual contents) is Adapted Material; and
281
+
282
+ c. You must comply with the conditions in Section 3(a) if You Share
283
+ all or a substantial portion of the contents of the database.
284
+
285
+ For the avoidance of doubt, this Section 4 supplements and does not
286
+ replace Your obligations under this Public License where the Licensed
287
+ Rights include other Copyright and Similar Rights.
288
+
289
+
290
+ Section 5 -- Disclaimer of Warranties and Limitation of Liability.
291
+
292
+ a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
293
+ EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
294
+ AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
295
+ ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
296
+ IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
297
+ WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
298
+ PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
299
+ ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
300
+ KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
301
+ ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
302
+
303
+ b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
304
+ TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
305
+ NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
306
+ INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
307
+ COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
308
+ USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
309
+ ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
310
+ DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
311
+ IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
312
+
313
+ c. The disclaimer of warranties and limitation of liability provided
314
+ above shall be interpreted in a manner that, to the extent
315
+ possible, most closely approximates an absolute disclaimer and
316
+ waiver of all liability.
317
+
318
+
319
+ Section 6 -- Term and Termination.
320
+
321
+ a. This Public License applies for the term of the Copyright and
322
+ Similar Rights licensed here. However, if You fail to comply with
323
+ this Public License, then Your rights under this Public License
324
+ terminate automatically.
325
+
326
+ b. Where Your right to use the Licensed Material has terminated under
327
+ Section 6(a), it reinstates:
328
+
329
+ 1. automatically as of the date the violation is cured, provided
330
+ it is cured within 30 days of Your discovery of the
331
+ violation; or
332
+
333
+ 2. upon express reinstatement by the Licensor.
334
+
335
+ For the avoidance of doubt, this Section 6(b) does not affect any
336
+ right the Licensor may have to seek remedies for Your violations
337
+ of this Public License.
338
+
339
+ c. For the avoidance of doubt, the Licensor may also offer the
340
+ Licensed Material under separate terms or conditions or stop
341
+ distributing the Licensed Material at any time; however, doing so
342
+ will not terminate this Public License.
343
+
344
+ d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
345
+ License.
346
+
347
+
348
+ Section 7 -- Other Terms and Conditions.
349
+
350
+ a. The Licensor shall not be bound by any additional or different
351
+ terms or conditions communicated by You unless expressly agreed.
352
+
353
+ b. Any arrangements, understandings, or agreements regarding the
354
+ Licensed Material not stated herein are separate from and
355
+ independent of the terms and conditions of this Public License.
356
+
357
+
358
+ Section 8 -- Interpretation.
359
+
360
+ a. For the avoidance of doubt, this Public License does not, and
361
+ shall not be interpreted to, reduce, limit, restrict, or impose
362
+ conditions on any use of the Licensed Material that could lawfully
363
+ be made without permission under this Public License.
364
+
365
+ b. To the extent possible, if any provision of this Public License is
366
+ deemed unenforceable, it shall be automatically reformed to the
367
+ minimum extent necessary to make it enforceable. If the provision
368
+ cannot be reformed, it shall be severed from this Public License
369
+ without affecting the enforceability of the remaining terms and
370
+ conditions.
371
+
372
+ c. No term or condition of this Public License will be waived and no
373
+ failure to comply consented to unless expressly agreed to by the
374
+ Licensor.
375
+
376
+ d. Nothing in this Public License constitutes or may be interpreted
377
+ as a limitation upon, or waiver of, any privileges and immunities
378
+ that apply to the Licensor or You, including from the legal
379
+ processes of any jurisdiction or authority.
380
+
381
+ =======================================================================
382
+
383
+ Creative Commons is not a party to its public
384
+ licenses. Notwithstanding, Creative Commons may elect to apply one of
385
+ its public licenses to material it publishes and in those instances
386
+ will be considered the “Licensor.” The text of the Creative Commons
387
+ public licenses is dedicated to the public domain under the CC0 Public
388
+ Domain Dedication. Except for the limited purpose of indicating that
389
+ material is shared under a Creative Commons public license or as
390
+ otherwise permitted by the Creative Commons policies published at
391
+ creativecommons.org/policies, Creative Commons does not authorize the
392
+ use of the trademark "Creative Commons" or any other trademark or logo
393
+ of Creative Commons without its prior written consent including,
394
+ without limitation, in connection with any unauthorized modifications
395
+ to any of its public licenses or any other arrangements,
396
+ understandings, or agreements concerning use of licensed material. For
397
+ the avoidance of doubt, this paragraph does not form part of the
398
+ public licenses.
399
+
400
+ Creative Commons may be contacted at creativecommons.org.
LICENSE_GROVER ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2021 Tencent AI Lab. All rights reserved.
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
22
+
23
+ =============================================================================
24
+
25
+ Other dependencies and licenses:
26
+
27
+ -----------------------------------------------------------------------------
28
+ The MIT License (MIT)
29
+ applies to:
30
+ - chemprop (https://github.com/chemprop/chemprop)
31
+ -----------------------------------------------------------------------------
32
+
33
+ MIT License
34
+
35
+ Copyright (c) 2020 Wengong Jin, Kyle Swanson, Kevin Yang, Regina Barzilay, Tommi Jaakkola
36
+
37
+ Permission is hereby granted, free of charge, to any person obtaining a copy
38
+ of this software and associated documentation files (the "Software"), to deal
39
+ in the Software without restriction, including without limitation the rights
40
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
41
+ copies of the Software, and to permit persons to whom the Software is
42
+ furnished to do so, subject to the following conditions:
43
+
44
+ The above copyright notice and this permission notice shall be included in all
45
+ copies or substantial portions of the Software.
46
+
47
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
48
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
49
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
50
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
51
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
52
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
53
+ SOFTWARE.
MODEL_CARD.md ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Model card - tox21_grover_classifier
2
+ ### Model details
3
+ - Model name: Grover Tox21 Baseline
4
+ - Developer: Tencent AI Lab (trained by JKU Linz)
5
+ - Paper URL: https://arxiv.org/pdf/2007.02835
6
+ - Model type / architecture:
7
+ - Grover implemented using the code accompanying the paper.
8
+ - The pretrained grover_base model is used as provided.
9
+ - A multitask network is finetuned for all Tox21 targets.
10
+ - Inference: Access via FastAPI endpoint. Upon a Tox21 prediction request, the model generates and returns predictions for all Tox21 targets simultaneously.
11
+ - Model version: v0
12
+ - Model date: 10.12.2025
13
+ - Reproducibility: Code for full training is available and enables refitting of TabPFN
14
+ from scratch.
15
+ - Reproducibility: Code for full training is available and enables retraining of the model from scratch.
16
+
17
+ ### Intended use
18
+ This model serves as a baseline benchmark for evaluating and comparing toxicity prediction
19
+ methods across the 12 pathway assays of the Tox21 dataset. It is not intended for clinical
20
+ decision-making without experimental validation.
21
+
22
+ ### Metric
23
+ Each Tox21 task is evaluated using the area under the receiver operating characteristic curve (AUC). Overall performance is reported as the mean AUC across all individual tasks.
24
+
25
+ ### Training data
26
+ Tox21 training and validation sets.
27
+
28
+ ### Evaluation data
29
+ Tox21 test set.
app.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This is the main entry point for the FastAPI application.
3
+ The app handles the request to predict toxicity for a list of SMILES strings.
4
+ """
5
+
6
+ # ---------------------------------------------------------------------------------------
7
+ # Dependencies and global variable definition
8
+ import os
9
+ from typing import List, Dict, Optional
10
+ from fastapi import FastAPI, Header, HTTPException
11
+ from pydantic import BaseModel, Field
12
+
13
+ from predict import predict as predict_func
14
+
15
+ API_KEY = os.getenv("API_KEY") # set via Space Secrets
16
+
17
+
18
+ # ---------------------------------------------------------------------------------------
19
+ class Request(BaseModel):
20
+ smiles: List[str] = Field(min_items=1, max_items=1000)
21
+
22
+
23
+ class Response(BaseModel):
24
+ predictions: dict
25
+ model_info: Dict[str, str] = {}
26
+
27
+
28
+ app = FastAPI(title="toxicity-api")
29
+
30
+
31
+ @app.get("/")
32
+ def root():
33
+ return {
34
+ "message": "Toxicity Prediction API",
35
+ "endpoints": {
36
+ "/metadata": "GET - API metadata and capabilities",
37
+ "/healthz": "GET - Health check",
38
+ "/predict": "POST - Predict toxicity for SMILES",
39
+ },
40
+ "usage": "Send POST to /predict with {'smiles': ['your_smiles_here']}",
41
+ }
42
+
43
+
44
+ @app.get("/metadata")
45
+ def metadata():
46
+ return {
47
+ "name": "Tox21 GROVER Classifier",
48
+ "version": "0.1.0",
49
+ "tox_endpoints": [
50
+ "NR-AR",
51
+ "NR-AR-LBD",
52
+ "NR-AhR",
53
+ "NR-Aromatase",
54
+ "NR-ER",
55
+ "NR-ER-LBD",
56
+ "NR-PPAR-gamma",
57
+ "SR-ARE",
58
+ "SR-ATAD5",
59
+ "SR-HSE",
60
+ "SR-MMP",
61
+ "SR-p53",
62
+ ],
63
+ }
64
+
65
+
66
+ @app.get("/healthz")
67
+ def healthz():
68
+ return {"ok": True}
69
+
70
+
71
+ @app.post("/predict", response_model=Response)
72
+ def predict(request: Request):
73
+ predictions = predict_func(request.smiles)
74
+ return {
75
+ "predictions": predictions,
76
+ "model_info": {"name": "Tox21 GROVER classifier", "version": "0.1.0"},
77
+ }
config/config.json ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "batch_size": 32,
3
+ "init_lr": 10,
4
+ "max_lr": 0.0001,
5
+ "final_lr": 9,
6
+ "dropout": 0.2,
7
+ "attn_hidden": 128,
8
+ "attn_out": 8,
9
+ "dist_coff": 0.15,
10
+ "bond_drop_rate": 0.2,
11
+ "ffn_num_layer": 2,
12
+ "ffn_hidden_size": 13,
13
+ "real_init_lr": 1e-05,
14
+ "real_final_lr": 1.1111111111111112e-05
15
+ }
environment.yaml ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: grover
2
+ channels:
3
+ - conda-forge
4
+ - defaults
5
+ - pytorch
6
+ - https://repo.anaconda.com/pkgs/free
7
+ - rmg
8
+ - rdkit
9
+ dependencies:
10
+ - _libgcc_mutex=0.1=conda_forge
11
+ - _openmp_mutex=4.5=2_gnu
12
+ - absl-py=2.3.1=pyhd8ed1ab_0
13
+ - aom=3.5.0=h27087fc_0
14
+ - blas=1.0=mkl
15
+ - boost=1.74.0=py39h5472131_5
16
+ - boost-cpp=1.74.0=h75c5d50_8
17
+ - brotli-python=1.1.0=py39hf88036b_3
18
+ - bzip2=1.0.8=hda65f42_8
19
+ - c-ares=1.34.5=hb9d3cd8_0
20
+ - ca-certificates=2025.11.12=hbd8a1cb_0
21
+ - cairo=1.18.4=h44eff21_0
22
+ - certifi=2025.8.3=pyhd8ed1ab_0
23
+ - cffi=1.17.1=py39h15c3d72_0
24
+ - charset-normalizer=3.4.3=pyhd8ed1ab_0
25
+ - colorama=0.4.6=pyhd8ed1ab_1
26
+ - cpython=3.9.23=py39hd8ed1ab_0
27
+ - cuda-cudart=12.4.127=he02047a_2
28
+ - cuda-cudart_linux-64=12.4.127=h85509e4_2
29
+ - cuda-cupti=12.4.127=he02047a_2
30
+ - cuda-libraries=12.4.1=ha770c72_1
31
+ - cuda-nvrtc=12.4.127=he02047a_2
32
+ - cuda-nvtx=12.4.127=he02047a_2
33
+ - cuda-opencl=12.4.127=he02047a_1
34
+ - cuda-runtime=12.4.1=ha804496_0
35
+ - cuda-version=12.4=h3060b56_3
36
+ - ffmpeg=5.1.2=gpl_h8dda1f0_106
37
+ - filelock=3.19.1=pyhd8ed1ab_0
38
+ - font-ttf-dejavu-sans-mono=2.37=hab24e00_0
39
+ - font-ttf-inconsolata=3.000=h77eed37_0
40
+ - font-ttf-source-code-pro=2.038=h77eed37_0
41
+ - font-ttf-ubuntu=0.83=h77eed37_3
42
+ - fontconfig=2.15.0=h7e30c49_1
43
+ - fonts-conda-ecosystem=1=0
44
+ - fonts-conda-forge=1=hc364b38_1
45
+ - freetype=2.14.1=ha770c72_0
46
+ - gmp=6.3.0=hac33072_2
47
+ - gmpy2=2.2.1=py39h7196dd7_0
48
+ - gnutls=3.7.9=hb077bed_0
49
+ - grpcio=1.74.1=py39h6482d51_0
50
+ - h2=4.2.0=pyhd8ed1ab_0
51
+ - hpack=4.1.0=pyhd8ed1ab_0
52
+ - hyperframe=6.1.0=pyhd8ed1ab_0
53
+ - icu=70.1=h27087fc_0
54
+ - idna=3.10=pyhd8ed1ab_1
55
+ - importlib-metadata=8.7.0=pyhe01879c_1
56
+ - intel-openmp=2023.0.0=h9e868ea_25371
57
+ - jinja2=3.1.6=pyhd8ed1ab_0
58
+ - joblib=1.5.1=pyhd8ed1ab_0
59
+ - lame=3.100=h166bdaf_1003
60
+ - lcms2=2.17=h717163a_0
61
+ - ld_impl_linux-64=2.45=default_hbd61a6d_104
62
+ - lerc=4.0.0=h0aef613_1
63
+ - libabseil=20250512.1=cxx17_hba17884_0
64
+ - libblas=3.9.0=1_h86c2bf4_netlib
65
+ - libcblas=3.9.0=13_h8e06fc2_netlib
66
+ - libcublas=12.4.5.8=he02047a_2
67
+ - libcufft=11.2.1.3=he02047a_2
68
+ - libcufile=1.9.1.3=he02047a_2
69
+ - libcurand=10.3.5.147=he02047a_2
70
+ - libcusolver=11.6.1.9=he02047a_2
71
+ - libcusparse=12.3.1.170=he02047a_2
72
+ - libdeflate=1.25=h17f619e_0
73
+ - libdrm=2.4.125=hb03c661_1
74
+ - libegl=1.7.0=ha4b6fd6_2
75
+ - libexpat=2.7.3=hecca717_0
76
+ - libffi=3.4.6=h2dba641_1
77
+ - libfreetype=2.14.1=ha770c72_0
78
+ - libfreetype6=2.14.1=h73754d4_0
79
+ - libgcc=15.2.0=he0feb66_13
80
+ - libgcc-ng=15.2.0=h69a702a_13
81
+ - libgfortran=15.2.0=h69a702a_13
82
+ - libgfortran-ng=15.2.0=h69a702a_13
83
+ - libgfortran5=15.2.0=h68bc16d_13
84
+ - libgl=1.7.0=ha4b6fd6_2
85
+ - libglib=2.86.0=h1fed272_0
86
+ - libglvnd=1.7.0=ha4b6fd6_2
87
+ - libglx=1.7.0=ha4b6fd6_2
88
+ - libgomp=15.2.0=he0feb66_13
89
+ - libgrpc=1.74.1=hebd82d6_0
90
+ - libhwloc=2.9.1=hd6dc26d_0
91
+ - libiconv=1.18=h3b78370_2
92
+ - libidn2=2.3.8=hfac485b_1
93
+ - libjpeg-turbo=3.1.2=hb03c661_0
94
+ - liblapack=3.9.0=13_h8876d29_netlib
95
+ - liblzma=5.8.1=hb9d3cd8_2
96
+ - liblzma-devel=5.8.1=hb9d3cd8_2
97
+ - libnpp=12.2.5.30=he02047a_2
98
+ - libnsl=2.0.1=hb9d3cd8_1
99
+ - libnvfatbin=12.4.127=he02047a_2
100
+ - libnvjitlink=12.4.127=he02047a_2
101
+ - libnvjpeg=12.3.1.117=he02047a_2
102
+ - libopus=1.5.2=hd0c01bc_0
103
+ - libpciaccess=0.18=hb9d3cd8_0
104
+ - libpng=1.6.51=h421ea60_0
105
+ - libprotobuf=6.31.1=h49aed37_2
106
+ - libre2-11=2025.11.05=h7b12aa8_0
107
+ - libsqlite=3.51.0=h0c1763c_0
108
+ - libstdcxx=15.2.0=h934c35e_13
109
+ - libstdcxx-ng=15.2.0=hdf11a46_13
110
+ - libtasn1=4.20.0=hb03c661_1
111
+ - libtiff=4.7.1=h9d88235_1
112
+ - libunistring=0.9.10=h7f98852_0
113
+ - libuuid=2.41.2=he9a06e4_0
114
+ - libva=2.22.0=h4f16b4b_2
115
+ - libvpx=1.11.0=h9c3ff4c_3
116
+ - libwebp-base=1.6.0=hd42ef1d_0
117
+ - libxcb=1.17.0=h8a09558_0
118
+ - libxcrypt=4.4.36=hd590300_1
119
+ - libxml2=2.10.3=hca2bb57_4
120
+ - libzlib=1.3.1=hb9d3cd8_2
121
+ - llvm-openmp=15.0.7=h0cdce71_0
122
+ - markdown=3.8.2=pyhd8ed1ab_0
123
+ - markupsafe=3.0.2=py39h9399b63_1
124
+ - mkl=2023.1.0=h213fc3f_46344
125
+ - mpc=1.3.1=h24ddda3_1
126
+ - mpfr=4.2.1=h90cbb55_3
127
+ - mpmath=1.3.0=pyhd8ed1ab_1
128
+ - ncurses=6.5=h2d0b736_3
129
+ - nettle=3.9.1=h7ab15ed_0
130
+ - networkx=3.2.1=pyhd8ed1ab_0
131
+ - numpy=1.26.4=py39h474f0d3_0
132
+ - ocl-icd=2.3.3=hb9d3cd8_0
133
+ - opencl-headers=2025.06.13=h5888daf_0
134
+ - openh264=2.3.1=hcb278e6_2
135
+ - openjpeg=2.5.4=h55fea9a_0
136
+ - openssl=3.6.0=h26f9b46_0
137
+ - p11-kit=0.24.1=hc5aa10d_0
138
+ - packaging=25.0=pyh29332c3_1
139
+ - pandas=2.3.1=py39h1b6b32d_0
140
+ - pcre2=10.46=h1321c63_0
141
+ - pillow=11.3.0=py39h15c0740_0
142
+ - pip=25.2=pyh8b19718_0
143
+ - pixman=0.46.4=h54a6638_1
144
+ - protobuf=6.31.1=py39h2f5525a_0
145
+ - pthread-stubs=0.4=hb9d3cd8_1002
146
+ - pycairo=1.28.0=py39ha09e8ab_0
147
+ - pycparser=2.22=pyh29332c3_1
148
+ - pysocks=1.7.1=pyha55dd90_7
149
+ - python=3.9.23=hc30ae73_0_cpython
150
+ - python-dateutil=2.9.0.post0=pyhe01879c_2
151
+ - python-tzdata=2025.2=pyhd8ed1ab_0
152
+ - python_abi=3.9=8_cp39
153
+ - pytorch=2.4.0=py3.9_cuda12.4_cudnn9.1.0_0
154
+ - pytorch-cuda=12.4=hc786d27_7
155
+ - pytorch-mutex=1.0=cuda
156
+ - pytz=2025.2=pyhd8ed1ab_0
157
+ - pyyaml=6.0.2=py39h9399b63_2
158
+ - re2=2025.11.05=h5301d42_0
159
+ - readline=8.2=h8c095d6_2
160
+ - requests=2.32.5=pyhd8ed1ab_0
161
+ - scikit-learn=1.6.1=py39h4b7350c_0
162
+ - setuptools=80.9.0=pyhff2d567_0
163
+ - six=1.17.0=pyhe01879c_1
164
+ - svt-av1=1.4.1=hcb278e6_0
165
+ - sympy=1.14.0=pyh2585a3b_105
166
+ - tbb=2021.9.0=hf52228f_0
167
+ - tensorboard=2.20.0=pyhe01879c_0
168
+ - tensorboard-data-server=0.7.0=py39h7170ec2_2
169
+ - threadpoolctl=3.6.0=pyhecae5ae_0
170
+ - tk=8.6.13=noxft_ha0e22de_103
171
+ - torchtriton=3.0.0=py39
172
+ - torchvision=0.19.0=py39_cu124
173
+ - tqdm=4.67.1=pyhd8ed1ab_1
174
+ - typing_extensions=4.14.1=pyhe01879c_0
175
+ - tzdata=2025b=h78e105d_0
176
+ - urllib3=2.5.0=pyhd8ed1ab_0
177
+ - wayland=1.24.0=h3e06ad9_0
178
+ - wayland-protocols=1.46=hd8ed1ab_0
179
+ - werkzeug=3.1.3=pyhd8ed1ab_1
180
+ - wheel=0.45.1=pyhd8ed1ab_1
181
+ - x264=1!164.3095=h166bdaf_2
182
+ - x265=3.5=h924138e_3
183
+ - xorg-libx11=1.8.12=h4f16b4b_0
184
+ - xorg-libxau=1.0.12=hb03c661_1
185
+ - xorg-libxdmcp=1.1.5=hb03c661_1
186
+ - xorg-libxext=1.3.6=hb9d3cd8_0
187
+ - xorg-libxfixes=6.0.2=hb03c661_0
188
+ - xorg-libxrender=0.9.12=hb9d3cd8_0
189
+ - xz=5.8.1=hbcc6ac9_2
190
+ - xz-gpl-tools=5.8.1=hbcc6ac9_2
191
+ - xz-tools=5.8.1=hb9d3cd8_2
192
+ - yaml=0.2.5=h280c20c_3
193
+ - zipp=3.23.0=pyhd8ed1ab_0
194
+ - zlib=1.3.1=hb9d3cd8_2
195
+ - zstandard=0.23.0=py39hd399759_3
196
+ - zstd=1.5.7=hb8e6e7a_2
197
+ - pip:
198
+ - aiohappyeyeballs==2.6.1
199
+ - aiohttp==3.13.2
200
+ - aiosignal==1.4.0
201
+ - anyio==4.12.0
202
+ - async-timeout==5.0.1
203
+ - attrs==25.4.0
204
+ - click==8.1.8
205
+ - datasets==4.4.1
206
+ - descriptastorus==2.8.0
207
+ - dill==0.4.0
208
+ - exceptiongroup==1.3.1
209
+ - frozenlist==1.8.0
210
+ - fsspec==2025.10.0
211
+ - h11==0.16.0
212
+ - hf-xet==1.2.0
213
+ - httpcore==1.0.9
214
+ - httpx==0.28.1
215
+ - huggingface-hub==1.1.7
216
+ - multidict==6.7.0
217
+ - multiprocess==0.70.18
218
+ - pandas-flavor==0.7.0
219
+ - propcache==0.4.1
220
+ - pyarrow==21.0.0
221
+ - rdkit==2025.9.1
222
+ - rdkit-pypi==2022.9.5
223
+ - scipy==1.10.1
224
+ - shellingham==1.5.4
225
+ - typer-slim==0.20.0
226
+ - xarray==2024.7.0
227
+ - xxhash==3.6.0
228
+ - yarl==1.22.0
229
+ prefix: /system/apps/userenv/stopf/grover
evaluate.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+ from src.eval import compute_roc_auc_from_csv
4
+ from src.commands import predict_from_csv
5
+
6
+ data_path = "./tox21/tox21_test_clean.csv"
7
+ features_path = data_path.replace(".csv", ".npz")
8
+ checkpoint_dir = "checkpoints"
9
+ output_path = "predictions/test_set_preds_best.csv"
10
+
11
+ predict_from_csv(data_path, features_path, checkpoint_dir, output_path)
12
+
13
+ valid_mask = np.load("./tox21/valid_mask_test.npy")
14
+ auc_array, mean_auc = compute_roc_auc_from_csv(output_path, "./tox21/tox21_test.csv", valid_mask)
15
+
16
+ print(auc_array)
17
+ print(mean_auc)
generate_features.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ from src.commands import generate_features
4
+
5
+ # paths
6
+ TRAIN_CSV = "./tox21/tox21_train_clean.csv"
7
+ VAL_CSV = "./tox21/tox21_validation_clean.csv"
8
+
9
+ generate_features(EXAMPLES_CSV, EXAMPLES_CSV.replace('.csv', '.npz'))
grover ADDED
@@ -0,0 +1 @@
 
 
1
+ Subproject commit 3f280d7d3419a781d303b1500c7039e37a1d87a2
hp_search.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import numpy as np
4
+ from datetime import datetime
5
+
6
+ from src.hp_search import generate_random_search
7
+ from src.commands import finetune, predict_from_csv
8
+ from src.eval import compute_roc_auc_from_csv
9
+
10
+ HYPERPARAM_GRID = {
11
+ "batch_size": [32],
12
+ "init_lr": [10],
13
+ "max_lr": [0.001, 0.0005, 0.0001],
14
+ "final_lr": [2, 3, 4, 5, 6, 7, 8, 9, 10],
15
+ "dropout": [0.0, 0.05, 0.1, 0.2],
16
+ "attn_hidden": [128],
17
+ "attn_out": [4, 8],
18
+ "dist_coff": [0.05, 0.1, 0.15],
19
+ "bond_drop_rate": [0.0, 0.2, 0.4, 0.6],
20
+ "ffn_num_layer": [2, 3],
21
+ "ffn_hidden_size": [5, 7, 13],
22
+ }
23
+
24
+ hp_grid = generate_random_search(HYPERPARAM_GRID, num_trials=300, seed=42)
25
+ print("Total number of configs:", len(hp_grid))
26
+
27
+ # general vars
28
+ train_path = "tox21/tox21_train_clean.csv"
29
+ val_path = "tox21/tox21_validation_clean.csv"
30
+ train_features_path = train_path.replace(".csv", ".npz")
31
+ val_features_path = val_path.replace(".csv", ".npz")
32
+ checkpoint_path = "pretrained_models/grover_base.pt"
33
+
34
+ # Tracking best model
35
+ best_mean_auc = -1
36
+ best_config = None
37
+ best_model_path = None
38
+
39
+ # Create directory for logs
40
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
41
+ log_dir = f"hp_search/logs/{timestamp}"
42
+ os.makedirs(log_dir, exist_ok=True)
43
+ overall_log_path = f"{log_dir}/hp_search_results.txt"
44
+ best_log_path = f"{log_dir}/best_result.txt"
45
+
46
+
47
+ # iterate over configs
48
+ for i, args in enumerate(hp_grid):
49
+ save_dir = f"hp_search/trials/Trial_{i+1}"
50
+
51
+ print("\n=========================================")
52
+ print("Training with config:")
53
+ print(args)
54
+ print("Save dir:", save_dir)
55
+ print("=========================================\n")
56
+
57
+ # finetune model
58
+ finetune(train_path, val_path, train_features_path, val_features_path,
59
+ save_dir, checkpoint_path, args)
60
+
61
+ # predict on val set
62
+ finetuned_model_dir = save_dir + "/fold_0/model_0"
63
+ output_path = save_dir + "/predictions.csv"
64
+ predict_from_csv(val_path, val_features_path, finetuned_model_dir, output_path)
65
+
66
+ # evaluate model
67
+ preds_path = save_dir + "/predictions.csv"
68
+ labels_path = "tox21/tox21_validation.csv"
69
+ valid_mask = np.load("./tox21/valid_mask_val.npy")
70
+ auc_per_task, mean_auc = compute_roc_auc_from_csv(preds_path, labels_path, valid_mask)
71
+
72
+ # Save all experiment results
73
+ with open(overall_log_path, "a") as f:
74
+ f.write("\n===============================\n")
75
+ f.write(f"Trial Num: {i+1}\n")
76
+ f.write(f"Mean AUC: {mean_auc}\n")
77
+ f.write(f"Config: {args}\n")
78
+ f.write(f"Save dir: {save_dir}\n")
79
+ f.write(f"AUC per task: {auc_per_task}\n")
80
+
81
+ # Check if best model
82
+ if mean_auc > best_mean_auc:
83
+ print("New BEST model found!")
84
+ best_mean_auc = mean_auc
85
+ best_config = args
86
+ best_model_path = save_dir
87
+
88
+ with open(best_log_path, "w") as f:
89
+ f.write("==== BEST MODEL SO FAR ====\n")
90
+ f.write(f"Trial Num: {i+1}\n")
91
+ f.write(f"Mean AUC: {best_mean_auc}\n")
92
+ f.write(f"Config: {best_config}\n")
93
+ f.write(f"Saved at: {best_model_path}\n")
94
+
95
+ print("\n============================")
96
+ print("Hyperparameter Search DONE!")
97
+ print("Trial Num: ", i+1)
98
+ print("Best mean AUC:", best_mean_auc)
99
+ print("Best model saved at:", best_model_path)
100
+ print("Best config:", best_config)
101
+ print("============================\n")
main.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+
3
+ import numpy as np
4
+ import torch
5
+ from rdkit import RDLogger
6
+
7
+ from grover.util.parsing import parse_args, get_newest_train_args
8
+ from grover.util.utils import create_logger
9
+ from grover.task.cross_validate import cross_validate
10
+ from grover.task.fingerprint import generate_fingerprints
11
+ from grover.task.predict import make_predictions, write_prediction
12
+ from grover.task.pretrain import pretrain_model
13
+ from grover.data.torchvocab import MolVocab
14
+
15
+
16
+ def setup(seed):
17
+ # frozen random seed
18
+ torch.manual_seed(seed)
19
+ torch.cuda.manual_seed_all(seed)
20
+ np.random.seed(seed)
21
+ random.seed(seed)
22
+ torch.backends.cudnn.deterministic = True
23
+
24
+
25
+ if __name__ == '__main__':
26
+ # setup random seed
27
+ setup(seed=42)
28
+ # Avoid the pylint warning.
29
+ a = MolVocab
30
+ # supress rdkit logger
31
+ lg = RDLogger.logger()
32
+ lg.setLevel(RDLogger.CRITICAL)
33
+
34
+ # Initialize MolVocab
35
+ mol_vocab = MolVocab
36
+
37
+ args = parse_args()
38
+ if args.parser_name == 'finetune':
39
+ logger = create_logger(name='train', save_dir=args.save_dir, quiet=False)
40
+ cross_validate(args, logger)
41
+ elif args.parser_name == 'pretrain':
42
+ logger = create_logger(name='pretrain', save_dir=args.save_dir)
43
+ pretrain_model(args, logger)
44
+ elif args.parser_name == "eval":
45
+ logger = create_logger(name='eval', save_dir=args.save_dir, quiet=False)
46
+ cross_validate(args, logger)
47
+ elif args.parser_name == 'fingerprint':
48
+ train_args = get_newest_train_args()
49
+ logger = create_logger(name='fingerprint', save_dir=None, quiet=False)
50
+ feas = generate_fingerprints(args, logger)
51
+ np.savez_compressed(args.output_path, fps=feas)
52
+ elif args.parser_name == 'predict':
53
+ train_args = get_newest_train_args()
54
+ avg_preds, test_smiles = make_predictions(args, train_args)
55
+ write_prediction(avg_preds, test_smiles, args)
predict.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import csv
3
+ import subprocess
4
+ import os
5
+
6
+ from src.preprocess import create_clean_smiles
7
+ from src.commands import generate_features, predict_from_csv
8
+
9
+
10
+
11
+ def predict(smiles_list):
12
+ """
13
+ Predict toxicity targets for a list of SMILES strings.
14
+
15
+ Args:
16
+ smiles_list (list[str]): SMILES strings
17
+
18
+ Returns:
19
+ dict: {smiles: {target_name: prediction_prob}}
20
+ """
21
+ data_path = "tox21/predict_smiles.csv"
22
+ features_path = data_path.replace(".csv", ".npz")
23
+ checkpoint_dir = "checkpoints"
24
+ output_path = "predictions/smiles_predictions.csv"
25
+
26
+ # clean smiles
27
+ clean_smiles, valid_mask = create_clean_smiles(smiles_list)
28
+
29
+ # Mapping from cleaned to original for valid ones
30
+ originals_valid = [orig for orig, ok in zip(smiles_list, valid_mask) if ok]
31
+
32
+ # sanity check (optional but nice to have)
33
+ if len(originals_valid) != len(clean_smiles):
34
+ raise ValueError(
35
+ f"Mismatch: {len(originals_valid)} valid originals vs {len(clean_smiles)} cleaned SMILES"
36
+ )
37
+
38
+ # map cleaned → original
39
+ cleaned_to_original = dict(zip(clean_smiles, originals_valid))
40
+
41
+ # tox21 targets
42
+ TARGET_NAMES = [
43
+ "NR-AhR","NR-AR","NR-AR-LBD","NR-Aromatase","NR-ER","NR-ER-LBD","NR-PPAR-gamma","SR-ARE","SR-ATAD5","SR-HSE","SR-MMP","SR-p53"
44
+ ]
45
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
46
+ print(f"Received {len(smiles_list)} SMILES strings")
47
+
48
+ # put smiles into csv
49
+ with open(data_path, "w", newline="") as f:
50
+ writer = csv.writer(f)
51
+ writer.writerow(["smiles"] + TARGET_NAMES) # header
52
+ for smi in clean_smiles:
53
+ writer.writerow([smi] + [""] * len(TARGET_NAMES))
54
+
55
+ # generate features
56
+ generate_features(data_path, features_path)
57
+
58
+ # predict
59
+ predict_from_csv(data_path, features_path, checkpoint_dir, output_path)
60
+
61
+ # create results dictionary from predictions
62
+ predictions = {}
63
+ with open(output_path, "r", newline="") as f:
64
+ reader = csv.DictReader(f)
65
+ rows = list(reader)
66
+
67
+ # Identify the SMILES column even if it is unnamed
68
+ fieldnames = reader.fieldnames
69
+ smiles_col = fieldnames[0] # first column, even if empty string
70
+
71
+ target_names = fieldnames[1:] # all columns except first
72
+
73
+ for row in rows:
74
+ clean_smi = row[smiles_col]
75
+ original_smi = cleaned_to_original.get(clean_smi, clean_smi)
76
+
77
+ pred_dict = {t: float(row[t]) for t in target_names}
78
+ predictions[original_smi] = pred_dict
79
+
80
+ # Add placeholder predictions for invalid SMILES
81
+ for smi, is_valid in zip(smiles_list, valid_mask):
82
+ if not is_valid:
83
+ predictions[smi] = {t: 0.5 for t in TARGET_NAMES}
84
+
85
+
86
+ return predictions
87
+
88
+ preds = predict(["Oc1cc(O)cc(C=Cc2ccc(O)c(O)c2)c1"])
89
+ print(preds)
prepare_data.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+ from src.preprocess import clean_smiles_in_csv
4
+
5
+ TOX21_TASKS = [
6
+ "NR-AhR","NR-AR","NR-AR-LBD","NR-Aromatase","NR-ER","NR-ER-LBD","NR-PPAR-gamma","SR-ARE","SR-ATAD5","SR-HSE","SR-MMP","SR-p53"
7
+ ]
8
+
9
+ def prepare_data(data_path, save_path_clean_data, save_path_valid_mask):
10
+ valid_mask_train = clean_smiles_in_csv(data_path, save_path_clean_data, "smiles", TOX21_TASKS)
11
+ np.save(save_path_valid_mask, valid_mask_train)
12
+
13
+
14
+ train_path = "./tox21/tox21_train.csv"
15
+ val_path = "./tox21/tox21_validation.csv"
16
+
17
+ train_path_clean = "./tox21/tox21_train_clean.csv"
18
+ val_path_clean = "./tox21/tox21_validation_clean.csv"
19
+
20
+ prepare_data(test_path, test_path_clean, "./tox21/valid_mask_test.npy")
src/commands.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ def run(cmd):
4
+ print("\n======================================")
5
+ print("Running command:")
6
+ print(cmd)
7
+ print("======================================\n")
8
+ os.system(cmd)
9
+
10
+ def generate_features(data_path, save_path):
11
+ run(
12
+ f"python scripts/save_features.py "
13
+ f"--data_path {data_path} "
14
+ f"--save_path {save_path} "
15
+ f"--features_generator rdkit_2d_normalized "
16
+ f"--restart"
17
+ )
18
+
19
+
20
+ def finetune(train_path, val_path, train_features_path, val_features_path,
21
+ save_dir, checkpoint_path, args
22
+ ):
23
+ finetune_cmd = (
24
+ f"python main.py finetune "
25
+ f"--data_path {train_path} "
26
+ f"--split_type random "
27
+ f"--split_sizes 1 0 0 "
28
+ f"--separate_val_path {val_path} "
29
+ f"--separate_test_path {val_path} "
30
+ f"--features_path {train_features_path} "
31
+ f"--separate_val_features_path {val_features_path} "
32
+ f"--separate_test_features_path {val_features_path} "
33
+ f"--save_dir {save_dir} "
34
+ f"--checkpoint_path {checkpoint_path} "
35
+ f"--dataset_type classification "
36
+ f"--num_folds 1 "
37
+ f"--ensemble_size 1 "
38
+ f"--no_features_scaling "
39
+ f"--ffn_hidden_size {args['ffn_hidden_size']} "
40
+ f"--ffn_num_layers {args['ffn_num_layer']} "
41
+ f"--batch_size {args['batch_size']} "
42
+ f"--epochs 100 "
43
+ f"--init_lr {args['real_init_lr']} "
44
+ f"--final_lr {args['real_final_lr']} "
45
+ f"--max_lr {args['max_lr']} "
46
+ f"--dropout {args['dropout']} "
47
+ f"--attn_hidden {args['attn_hidden']} "
48
+ f"--attn_out {args['attn_out']} "
49
+ f"--dist_coff {args['dist_coff']} "
50
+ f"--bond_drop_rate {args['bond_drop_rate']} "
51
+ )
52
+ run(finetune_cmd)
53
+
54
+
src/eval.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import numpy as np
3
+ from sklearn.metrics import roc_auc_score
4
+
5
+ def compute_roc_auc_from_csv(preds_csv: str, labels_csv: str, valid_mask):
6
+ """
7
+ Compute ROC AUC per class and overall mean, similar to the PyTorch-style function.
8
+ Handles missing labels (NaN) like y_mask.
9
+ """
10
+ preds = pd.read_csv(preds_csv)
11
+ labels = pd.read_csv(labels_csv)
12
+
13
+ smiles_cols = [c for c in preds.columns if "smiles" in c.lower()]
14
+ if smiles_cols:
15
+ print(f"🧪 Dropping SMILES columns: {smiles_cols}")
16
+ preds = preds.drop(columns=smiles_cols, errors="ignore")
17
+ labels = labels.drop(columns=smiles_cols, errors="ignore")
18
+
19
+ shared_cols = [c for c in preds.columns if c in labels.columns]
20
+ preds = preds[shared_cols].apply(pd.to_numeric, errors="coerce")
21
+ labels = labels[shared_cols].apply(pd.to_numeric, errors="coerce")
22
+
23
+ y_pred_clean = preds.to_numpy(dtype=float)
24
+ y_true = labels.to_numpy(dtype=float)
25
+ valid_mask = valid_mask[-y_true.shape[0]:]
26
+ #Re-expand to original size
27
+ y_pred = np.full((len(valid_mask), y_pred_clean.shape[1]), 0.5, dtype=float)
28
+ y_pred[valid_mask] = y_pred_clean
29
+
30
+ y_mask = ~np.isnan(y_true)
31
+
32
+ auc_list = []
33
+ for i in range(y_true.shape[1]):
34
+ mask_i = y_mask[:, i]
35
+ if mask_i.sum() > 0:
36
+ try:
37
+ auc = roc_auc_score(y_true[mask_i, i], y_pred[mask_i, i])
38
+ except ValueError:
39
+ auc = np.nan
40
+ else:
41
+ auc = np.nan
42
+ auc_list.append(auc)
43
+
44
+ auc_array = np.array(auc_list, dtype=np.float32)
45
+ mean_auc = np.nanmean(auc_array)
46
+
47
+ return auc_array, mean_auc
src/hp_search.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from itertools import product
2
+ import random
3
+
4
+ def generate_random_search(grid, num_trials=300, seed=42):
5
+ random.seed(seed)
6
+
7
+ keys = list(grid.keys())
8
+ values = list(grid.values())
9
+ all_combinations = []
10
+ for combo in product(*values):
11
+ params = dict(zip(keys, combo))
12
+ # Convert ratios to actual LR values
13
+ params["real_init_lr"] = params["max_lr"] / params["init_lr"]
14
+ params["real_final_lr"] = params["max_lr"] / params["final_lr"]
15
+ all_combinations.append(params)
16
+
17
+ # sample num_runs out of this
18
+ indices = random.sample(range(len(all_combinations)), num_trials)
19
+ hp_subset = [all_combinations[i] for i in indices]
20
+ return hp_subset
src/preprocess.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from rdkit import Chem
2
+ from rdkit.Chem.MolStandardize import rdMolStandardize
3
+ from rdkit import Chem
4
+ import numpy as np
5
+ import pandas as pd
6
+ from datasets import load_dataset
7
+ from typing import List, Optional
8
+
9
+ TOX21_TASKS = [
10
+ "NR-AhR","NR-AR","NR-AR-LBD","NR-Aromatase","NR-ER","NR-ER-LBD","NR-PPAR-gamma","SR-ARE","SR-ATAD5","SR-HSE","SR-MMP","SR-p53"
11
+ ]
12
+
13
+ def create_clean_smiles(smiles_list: list[str]) -> tuple[list[str], np.ndarray]:
14
+ """
15
+ Clean and canonicalize SMILES strings while staying in SMILES space.
16
+ Returns (list of cleaned SMILES, mask of valid SMILES).
17
+ """
18
+ clean_smis = []
19
+ valid_mask = []
20
+
21
+ cleaner = rdMolStandardize.CleanupParameters()
22
+ tautomer_enumerator = rdMolStandardize.TautomerEnumerator()
23
+
24
+ for smi in smiles_list:
25
+ try:
26
+ mol = Chem.MolFromSmiles(smi)
27
+ if mol is None:
28
+ valid_mask.append(False)
29
+ continue
30
+
31
+ # Cleanup and tautomer canonicalization
32
+ mol = rdMolStandardize.Cleanup(mol, cleaner)
33
+ mol = tautomer_enumerator.Canonicalize(mol)
34
+
35
+ # -------- Charge filtering (prevents GROVER crash) --------
36
+ allowed_charges = {-1, 0, 1}
37
+ bad_charge = False
38
+ for atom in mol.GetAtoms():
39
+ if atom.GetFormalCharge() not in allowed_charges:
40
+ bad_charge = True
41
+ break
42
+
43
+ if bad_charge:
44
+ valid_mask.append(False)
45
+ continue
46
+ # ----------------------------------------------------------
47
+
48
+ # Canonical SMILES output
49
+ clean_smi = Chem.MolToSmiles(mol, canonical=True)
50
+ clean_smis.append(clean_smi)
51
+ valid_mask.append(True)
52
+
53
+ except Exception as e:
54
+ print(f"Failed to clean {smi}: {e}")
55
+ valid_mask.append(False)
56
+
57
+ return clean_smis, np.array(valid_mask, dtype=bool)
58
+
59
+
60
+ def clean_smiles_in_csv(input_csv: str, output_csv: str, smiles_col: str = "smiles", target_cols: Optional[List[str]] = None):
61
+ """
62
+ Reads a CSV, cleans SMILES, and saves only valid cleaned rows with all target columns to a new CSV.
63
+ """
64
+ # Load dataset
65
+ df = pd.read_csv(input_csv)
66
+ if smiles_col not in df.columns:
67
+ raise ValueError(f"'{smiles_col}' column not found in CSV.")
68
+
69
+ # Infer target columns if not specified
70
+ if target_cols is None:
71
+ target_cols = [c for c in df.columns if c != smiles_col]
72
+ keep_cols = target_cols
73
+ # Validate target columns
74
+ missing_targets = [c for c in target_cols if c not in df.columns]
75
+ if missing_targets:
76
+ raise ValueError(f"Missing target columns in CSV: {missing_targets}")
77
+
78
+ # Clean SMILES
79
+ clean_smis, valid_mask = create_clean_smiles(df[smiles_col].tolist())
80
+
81
+ # Keep only valid rows
82
+ df_clean = df.loc[valid_mask, keep_cols].copy()
83
+ df_clean.insert(0, smiles_col, clean_smis) # smiles first column
84
+
85
+ # Save cleaned dataset
86
+ df_clean.to_csv(output_csv, index=False)
87
+ print(f"✅ Cleaned dataset saved to '{output_csv}' ({len(df_clean)} valid molecules).")
88
+ return valid_mask
train.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+
4
+ from src.commands import finetune, predict_from_csv
5
+ from src.eval import compute_roc_auc_from_csv
6
+
7
+
8
+ def load_config(path="config.json"):
9
+ with open(path, "r") as f:
10
+ config = json.load(f)
11
+ return config
12
+
13
+ config = load_config()
14
+
15
+ print(config)
16
+
17
+
18
+ # Paths to custom split
19
+ train_path = "tox21/tox21_train_clean.csv"
20
+ val_path = "tox21/tox21_validation_clean.csv"
21
+
22
+ train_features_path = train_path.replace(".csv", ".npz")
23
+ val_features_path = val_path.replace(".csv", ".npz")
24
+ checkpoint_path = "pretrained_models/grover_base.pt"
25
+
26
+ # Output directory for finetuned model
27
+ save_dir = "finetune/"
28
+
29
+
30
+ finetune(train_path, val_path, train_features_path, val_features_path,
31
+ save_dir, checkpoint_path, args)
32
+
33
+ # predict on val set
34
+ finetuned_model_dir = save_dir + "/fold_0/model_0"
35
+ output_path = save_dir + "/predictions.csv"
36
+ predict_from_csv(val_path, val_features_path, finetuned_model_dir, output_path)
37
+
38
+ # evaluate model
39
+ preds_path = save_dir + "/predictions.csv"
40
+ labels_path = "tox21/tox21_validation.csv"
41
+ valid_mask = np.load("./tox21/valid_mask_val.npy")
42
+ auc_per_task, mean_auc = compute_roc_auc_from_csv(preds_path, labels_path, valid_mask)