Spaces:
Sleeping
Sleeping
hiitsmeme
committed on
Commit
·
b25d2b6
1
Parent(s):
9da50ef
initial commit
Browse files- .example.env +1 -0
- .gitignore +8 -0
- LICENSE +400 -0
- LICENSE_GROVER +53 -0
- MODEL_CARD.md +29 -0
- app.py +77 -0
- config/config.json +15 -0
- environment.yaml +229 -0
- evaluate.py +17 -0
- generate_features.py +9 -0
- grover +1 -0
- hp_search.py +101 -0
- main.py +55 -0
- predict.py +89 -0
- prepare_data.py +20 -0
- src/commands.py +54 -0
- src/eval.py +47 -0
- src/hp_search.py +20 -0
- src/preprocess.py +88 -0
- train.py +42 -0
.example.env
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
TOKEN=example_token
|
.gitignore
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
tox21/*
|
| 2 |
+
predictions/*
|
| 3 |
+
hp_search/trials/*
|
| 4 |
+
hp_search/logs/*
|
| 5 |
+
__pycache__
|
| 6 |
+
.env
|
| 7 |
+
pretrained/*
|
| 8 |
+
grover_base.pt
|
LICENSE
ADDED
|
@@ -0,0 +1,400 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Attribution-NonCommercial 4.0 International
|
| 2 |
+
|
| 3 |
+
=======================================================================
|
| 4 |
+
|
| 5 |
+
Creative Commons Corporation ("Creative Commons") is not a law firm and
|
| 6 |
+
does not provide legal services or legal advice. Distribution of
|
| 7 |
+
Creative Commons public licenses does not create a lawyer-client or
|
| 8 |
+
other relationship. Creative Commons makes its licenses and related
|
| 9 |
+
information available on an "as-is" basis. Creative Commons gives no
|
| 10 |
+
warranties regarding its licenses, any material licensed under their
|
| 11 |
+
terms and conditions, or any related information. Creative Commons
|
| 12 |
+
disclaims all liability for damages resulting from their use to the
|
| 13 |
+
fullest extent possible.
|
| 14 |
+
|
| 15 |
+
Using Creative Commons Public Licenses
|
| 16 |
+
|
| 17 |
+
Creative Commons public licenses provide a standard set of terms and
|
| 18 |
+
conditions that creators and other rights holders may use to share
|
| 19 |
+
original works of authorship and other material subject to copyright
|
| 20 |
+
and certain other rights specified in the public license below. The
|
| 21 |
+
following considerations are for informational purposes only, are not
|
| 22 |
+
exhaustive, and do not form part of our licenses.
|
| 23 |
+
|
| 24 |
+
Considerations for licensors: Our public licenses are
|
| 25 |
+
intended for use by those authorized to give the public
|
| 26 |
+
permission to use material in ways otherwise restricted by
|
| 27 |
+
copyright and certain other rights. Our licenses are
|
| 28 |
+
irrevocable. Licensors should read and understand the terms
|
| 29 |
+
and conditions of the license they choose before applying it.
|
| 30 |
+
Licensors should also secure all rights necessary before
|
| 31 |
+
applying our licenses so that the public can reuse the
|
| 32 |
+
material as expected. Licensors should clearly mark any
|
| 33 |
+
material not subject to the license. This includes other CC-
|
| 34 |
+
licensed material, or material used under an exception or
|
| 35 |
+
limitation to copyright. More considerations for licensors:
|
| 36 |
+
wiki.creativecommons.org/Considerations_for_licensors
|
| 37 |
+
|
| 38 |
+
Considerations for the public: By using one of our public
|
| 39 |
+
licenses, a licensor grants the public permission to use the
|
| 40 |
+
licensed material under specified terms and conditions. If
|
| 41 |
+
the licensor's permission is not necessary for any reason--for
|
| 42 |
+
example, because of any applicable exception or limitation to
|
| 43 |
+
copyright--then that use is not regulated by the license. Our
|
| 44 |
+
licenses grant only permissions under copyright and certain
|
| 45 |
+
other rights that a licensor has authority to grant. Use of
|
| 46 |
+
the licensed material may still be restricted for other
|
| 47 |
+
reasons, including because others have copyright or other
|
| 48 |
+
rights in the material. A licensor may make special requests,
|
| 49 |
+
such as asking that all changes be marked or described.
|
| 50 |
+
Although not required by our licenses, you are encouraged to
|
| 51 |
+
respect those requests where reasonable. More considerations
|
| 52 |
+
for the public:
|
| 53 |
+
wiki.creativecommons.org/Considerations_for_licensees
|
| 54 |
+
=======================================================================
|
| 55 |
+
Creative Commons Attribution-NonCommercial 4.0 International Public
|
| 56 |
+
License
|
| 57 |
+
By exercising the Licensed Rights (defined below), You accept and agree
|
| 58 |
+
to be bound by the terms and conditions of this Creative Commons
|
| 59 |
+
Attribution-NonCommercial 4.0 International Public License ("Public
|
| 60 |
+
License"). To the extent this Public License may be interpreted as a
|
| 61 |
+
contract, You are granted the Licensed Rights in consideration of Your
|
| 62 |
+
acceptance of these terms and conditions, and the Licensor grants You
|
| 63 |
+
such rights in consideration of benefits the Licensor receives from
|
| 64 |
+
making the Licensed Material available under these terms and
|
| 65 |
+
conditions.
|
| 66 |
+
Section 1 -- Definitions.
|
| 67 |
+
a. Adapted Material means material subject to Copyright and Similar
|
| 68 |
+
Rights that is derived from or based upon the Licensed Material
|
| 69 |
+
and in which the Licensed Material is translated, altered,
|
| 70 |
+
arranged, transformed, or otherwise modified in a manner requiring
|
| 71 |
+
permission under the Copyright and Similar Rights held by the
|
| 72 |
+
Licensor. For purposes of this Public License, where the Licensed
|
| 73 |
+
Material is a musical work, performance, or sound recording,
|
| 74 |
+
Adapted Material is always produced where the Licensed Material is
|
| 75 |
+
synched in timed relation with a moving image.
|
| 76 |
+
b. Adapter's License means the license You apply to Your Copyright
|
| 77 |
+
and Similar Rights in Your contributions to Adapted Material in
|
| 78 |
+
accordance with the terms and conditions of this Public License.
|
| 79 |
+
|
| 80 |
+
c. Copyright and Similar Rights means copyright and/or similar rights
|
| 81 |
+
closely related to copyright including, without limitation,
|
| 82 |
+
performance, broadcast, sound recording, and Sui Generis Database
|
| 83 |
+
Rights, without regard to how the rights are labeled or
|
| 84 |
+
categorized. For purposes of this Public License, the rights
|
| 85 |
+
specified in Section 2(b)(1)-(2) are not Copyright and Similar
|
| 86 |
+
Rights.
|
| 87 |
+
d. Effective Technological Measures means those measures that, in the
|
| 88 |
+
absence of proper authority, may not be circumvented under laws
|
| 89 |
+
fulfilling obligations under Article 11 of the WIPO Copyright
|
| 90 |
+
Treaty adopted on December 20, 1996, and/or similar international
|
| 91 |
+
agreements.
|
| 92 |
+
|
| 93 |
+
e. Exceptions and Limitations means fair use, fair dealing, and/or
|
| 94 |
+
any other exception or limitation to Copyright and Similar Rights
|
| 95 |
+
that applies to Your use of the Licensed Material.
|
| 96 |
+
|
| 97 |
+
f. Licensed Material means the artistic or literary work, database,
|
| 98 |
+
or other material to which the Licensor applied this Public
|
| 99 |
+
License.
|
| 100 |
+
|
| 101 |
+
g. Licensed Rights means the rights granted to You subject to the
|
| 102 |
+
terms and conditions of this Public License, which are limited to
|
| 103 |
+
all Copyright and Similar Rights that apply to Your use of the
|
| 104 |
+
Licensed Material and that the Licensor has authority to license.
|
| 105 |
+
|
| 106 |
+
h. Licensor means the individual(s) or entity(ies) granting rights
|
| 107 |
+
under this Public License.
|
| 108 |
+
|
| 109 |
+
i. NonCommercial means not primarily intended for or directed towards
|
| 110 |
+
commercial advantage or monetary compensation. For purposes of
|
| 111 |
+
this Public License, the exchange of the Licensed Material for
|
| 112 |
+
other material subject to Copyright and Similar Rights by digital
|
| 113 |
+
file-sharing or similar means is NonCommercial provided there is
|
| 114 |
+
no payment of monetary compensation in connection with the
|
| 115 |
+
exchange.
|
| 116 |
+
|
| 117 |
+
j. Share means to provide material to the public by any means or
|
| 118 |
+
process that requires permission under the Licensed Rights, such
|
| 119 |
+
as reproduction, public display, public performance, distribution,
|
| 120 |
+
dissemination, communication, or importation, and to make material
|
| 121 |
+
available to the public including in ways that members of the
|
| 122 |
+
public may access the material from a place and at a time
|
| 123 |
+
individually chosen by them.
|
| 124 |
+
|
| 125 |
+
k. Sui Generis Database Rights means rights other than copyright
|
| 126 |
+
resulting from Directive 96/9/EC of the European Parliament and of
|
| 127 |
+
the Council of 11 March 1996 on the legal protection of databases,
|
| 128 |
+
as amended and/or succeeded, as well as other essentially
|
| 129 |
+
equivalent rights anywhere in the world.
|
| 130 |
+
|
| 131 |
+
l. You means the individual or entity exercising the Licensed Rights
|
| 132 |
+
under this Public License. Your has a corresponding meaning.
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
Section 2 -- Scope.
|
| 136 |
+
|
| 137 |
+
a. License grant.
|
| 138 |
+
|
| 139 |
+
1. Subject to the terms and conditions of this Public License,
|
| 140 |
+
the Licensor hereby grants You a worldwide, royalty-free,
|
| 141 |
+
non-sublicensable, non-exclusive, irrevocable license to
|
| 142 |
+
exercise the Licensed Rights in the Licensed Material to:
|
| 143 |
+
|
| 144 |
+
a. reproduce and Share the Licensed Material, in whole or
|
| 145 |
+
in part, for NonCommercial purposes only; and
|
| 146 |
+
|
| 147 |
+
b. produce, reproduce, and Share Adapted Material for
|
| 148 |
+
NonCommercial purposes only.
|
| 149 |
+
|
| 150 |
+
2. Exceptions and Limitations. For the avoidance of doubt, where
|
| 151 |
+
Exceptions and Limitations apply to Your use, this Public
|
| 152 |
+
License does not apply, and You do not need to comply with
|
| 153 |
+
its terms and conditions.
|
| 154 |
+
|
| 155 |
+
3. Term. The term of this Public License is specified in Section
|
| 156 |
+
6(a).
|
| 157 |
+
|
| 158 |
+
4. Media and formats; technical modifications allowed. The
|
| 159 |
+
Licensor authorizes You to exercise the Licensed Rights in
|
| 160 |
+
all media and formats whether now known or hereafter created,
|
| 161 |
+
and to make technical modifications necessary to do so. The
|
| 162 |
+
Licensor waives and/or agrees not to assert any right or
|
| 163 |
+
authority to forbid You from making technical modifications
|
| 164 |
+
necessary to exercise the Licensed Rights, including
|
| 165 |
+
technical modifications necessary to circumvent Effective
|
| 166 |
+
Technological Measures. For purposes of this Public License,
|
| 167 |
+
simply making modifications authorized by this Section 2(a)
|
| 168 |
+
(4) never produces Adapted Material.
|
| 169 |
+
|
| 170 |
+
5. Downstream recipients.
|
| 171 |
+
|
| 172 |
+
a. Offer from the Licensor -- Licensed Material. Every
|
| 173 |
+
recipient of the Licensed Material automatically
|
| 174 |
+
receives an offer from the Licensor to exercise the
|
| 175 |
+
Licensed Rights under the terms and conditions of this
|
| 176 |
+
Public License.
|
| 177 |
+
|
| 178 |
+
b. No downstream restrictions. You may not offer or impose
|
| 179 |
+
any additional or different terms or conditions on, or
|
| 180 |
+
apply any Effective Technological Measures to, the
|
| 181 |
+
Licensed Material if doing so restricts exercise of the
|
| 182 |
+
Licensed Rights by any recipient of the Licensed
|
| 183 |
+
Material.
|
| 184 |
+
|
| 185 |
+
6. No endorsement. Nothing in this Public License constitutes or
|
| 186 |
+
may be construed as permission to assert or imply that You
|
| 187 |
+
are, or that Your use of the Licensed Material is, connected
|
| 188 |
+
with, or sponsored, endorsed, or granted official status by,
|
| 189 |
+
the Licensor or others designated to receive attribution as
|
| 190 |
+
provided in Section 3(a)(1)(A)(i).
|
| 191 |
+
|
| 192 |
+
b. Other rights.
|
| 193 |
+
|
| 194 |
+
1. Moral rights, such as the right of integrity, are not
|
| 195 |
+
licensed under this Public License, nor are publicity,
|
| 196 |
+
privacy, and/or other similar personality rights; however, to
|
| 197 |
+
the extent possible, the Licensor waives and/or agrees not to
|
| 198 |
+
assert any such rights held by the Licensor to the limited
|
| 199 |
+
extent necessary to allow You to exercise the Licensed
|
| 200 |
+
Rights, but not otherwise.
|
| 201 |
+
|
| 202 |
+
2. Patent and trademark rights are not licensed under this
|
| 203 |
+
Public License.
|
| 204 |
+
|
| 205 |
+
3. To the extent possible, the Licensor waives any right to
|
| 206 |
+
collect royalties from You for the exercise of the Licensed
|
| 207 |
+
Rights, whether directly or through a collecting society
|
| 208 |
+
under any voluntary or waivable statutory or compulsory
|
| 209 |
+
licensing scheme. In all other cases the Licensor expressly
|
| 210 |
+
reserves any right to collect such royalties, including when
|
| 211 |
+
the Licensed Material is used other than for NonCommercial
|
| 212 |
+
purposes.
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
Section 3 -- License Conditions.
|
| 216 |
+
|
| 217 |
+
Your exercise of the Licensed Rights is expressly made subject to the
|
| 218 |
+
following conditions.
|
| 219 |
+
|
| 220 |
+
a. Attribution.
|
| 221 |
+
|
| 222 |
+
1. If You Share the Licensed Material (including in modified
|
| 223 |
+
form), You must:
|
| 224 |
+
|
| 225 |
+
a. retain the following if it is supplied by the Licensor
|
| 226 |
+
with the Licensed Material:
|
| 227 |
+
|
| 228 |
+
i. identification of the creator(s) of the Licensed
|
| 229 |
+
Material and any others designated to receive
|
| 230 |
+
attribution, in any reasonable manner requested by
|
| 231 |
+
the Licensor (including by pseudonym if
|
| 232 |
+
designated);
|
| 233 |
+
|
| 234 |
+
ii. a copyright notice;
|
| 235 |
+
|
| 236 |
+
iii. a notice that refers to this Public License;
|
| 237 |
+
|
| 238 |
+
iv. a notice that refers to the disclaimer of
|
| 239 |
+
warranties;
|
| 240 |
+
|
| 241 |
+
v. a URI or hyperlink to the Licensed Material to the
|
| 242 |
+
extent reasonably practicable;
|
| 243 |
+
|
| 244 |
+
b. indicate if You modified the Licensed Material and
|
| 245 |
+
retain an indication of any previous modifications; and
|
| 246 |
+
|
| 247 |
+
c. indicate the Licensed Material is licensed under this
|
| 248 |
+
Public License, and include the text of, or the URI or
|
| 249 |
+
hyperlink to, this Public License.
|
| 250 |
+
|
| 251 |
+
2. You may satisfy the conditions in Section 3(a)(1) in any
|
| 252 |
+
reasonable manner based on the medium, means, and context in
|
| 253 |
+
which You Share the Licensed Material. For example, it may be
|
| 254 |
+
reasonable to satisfy the conditions by providing a URI or
|
| 255 |
+
hyperlink to a resource that includes the required
|
| 256 |
+
information.
|
| 257 |
+
|
| 258 |
+
3. If requested by the Licensor, You must remove any of the
|
| 259 |
+
information required by Section 3(a)(1)(A) to the extent
|
| 260 |
+
reasonably practicable.
|
| 261 |
+
|
| 262 |
+
4. If You Share Adapted Material You produce, the Adapter's
|
| 263 |
+
License You apply must not prevent recipients of the Adapted
|
| 264 |
+
Material from complying with this Public License.
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
Section 4 -- Sui Generis Database Rights.
|
| 268 |
+
|
| 269 |
+
Where the Licensed Rights include Sui Generis Database Rights that
|
| 270 |
+
apply to Your use of the Licensed Material:
|
| 271 |
+
|
| 272 |
+
a. for the avoidance of doubt, Section 2(a)(1) grants You the right
|
| 273 |
+
to extract, reuse, reproduce, and Share all or a substantial
|
| 274 |
+
portion of the contents of the database for NonCommercial purposes
|
| 275 |
+
only;
|
| 276 |
+
|
| 277 |
+
b. if You include all or a substantial portion of the database
|
| 278 |
+
contents in a database in which You have Sui Generis Database
|
| 279 |
+
Rights, then the database in which You have Sui Generis Database
|
| 280 |
+
Rights (but not its individual contents) is Adapted Material; and
|
| 281 |
+
|
| 282 |
+
c. You must comply with the conditions in Section 3(a) if You Share
|
| 283 |
+
all or a substantial portion of the contents of the database.
|
| 284 |
+
|
| 285 |
+
For the avoidance of doubt, this Section 4 supplements and does not
|
| 286 |
+
replace Your obligations under this Public License where the Licensed
|
| 287 |
+
Rights include other Copyright and Similar Rights.
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
Section 5 -- Disclaimer of Warranties and Limitation of Liability.
|
| 291 |
+
|
| 292 |
+
a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
|
| 293 |
+
EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
|
| 294 |
+
AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
|
| 295 |
+
ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
|
| 296 |
+
IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
|
| 297 |
+
WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
|
| 298 |
+
PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
|
| 299 |
+
ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
|
| 300 |
+
KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
|
| 301 |
+
ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
|
| 302 |
+
|
| 303 |
+
b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
|
| 304 |
+
TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
|
| 305 |
+
NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
|
| 306 |
+
INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
|
| 307 |
+
COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
|
| 308 |
+
USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
|
| 309 |
+
ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
|
| 310 |
+
DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
|
| 311 |
+
IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
|
| 312 |
+
|
| 313 |
+
c. The disclaimer of warranties and limitation of liability provided
|
| 314 |
+
above shall be interpreted in a manner that, to the extent
|
| 315 |
+
possible, most closely approximates an absolute disclaimer and
|
| 316 |
+
waiver of all liability.
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
Section 6 -- Term and Termination.
|
| 320 |
+
|
| 321 |
+
a. This Public License applies for the term of the Copyright and
|
| 322 |
+
Similar Rights licensed here. However, if You fail to comply with
|
| 323 |
+
this Public License, then Your rights under this Public License
|
| 324 |
+
terminate automatically.
|
| 325 |
+
|
| 326 |
+
b. Where Your right to use the Licensed Material has terminated under
|
| 327 |
+
Section 6(a), it reinstates:
|
| 328 |
+
|
| 329 |
+
1. automatically as of the date the violation is cured, provided
|
| 330 |
+
it is cured within 30 days of Your discovery of the
|
| 331 |
+
violation; or
|
| 332 |
+
|
| 333 |
+
2. upon express reinstatement by the Licensor.
|
| 334 |
+
|
| 335 |
+
For the avoidance of doubt, this Section 6(b) does not affect any
|
| 336 |
+
right the Licensor may have to seek remedies for Your violations
|
| 337 |
+
of this Public License.
|
| 338 |
+
|
| 339 |
+
c. For the avoidance of doubt, the Licensor may also offer the
|
| 340 |
+
Licensed Material under separate terms or conditions or stop
|
| 341 |
+
distributing the Licensed Material at any time; however, doing so
|
| 342 |
+
will not terminate this Public License.
|
| 343 |
+
|
| 344 |
+
d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
|
| 345 |
+
License.
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
Section 7 -- Other Terms and Conditions.
|
| 349 |
+
|
| 350 |
+
a. The Licensor shall not be bound by any additional or different
|
| 351 |
+
terms or conditions communicated by You unless expressly agreed.
|
| 352 |
+
|
| 353 |
+
b. Any arrangements, understandings, or agreements regarding the
|
| 354 |
+
Licensed Material not stated herein are separate from and
|
| 355 |
+
independent of the terms and conditions of this Public License.
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
Section 8 -- Interpretation.
|
| 359 |
+
|
| 360 |
+
a. For the avoidance of doubt, this Public License does not, and
|
| 361 |
+
shall not be interpreted to, reduce, limit, restrict, or impose
|
| 362 |
+
conditions on any use of the Licensed Material that could lawfully
|
| 363 |
+
be made without permission under this Public License.
|
| 364 |
+
|
| 365 |
+
b. To the extent possible, if any provision of this Public License is
|
| 366 |
+
deemed unenforceable, it shall be automatically reformed to the
|
| 367 |
+
minimum extent necessary to make it enforceable. If the provision
|
| 368 |
+
cannot be reformed, it shall be severed from this Public License
|
| 369 |
+
without affecting the enforceability of the remaining terms and
|
| 370 |
+
conditions.
|
| 371 |
+
|
| 372 |
+
c. No term or condition of this Public License will be waived and no
|
| 373 |
+
failure to comply consented to unless expressly agreed to by the
|
| 374 |
+
Licensor.
|
| 375 |
+
|
| 376 |
+
d. Nothing in this Public License constitutes or may be interpreted
|
| 377 |
+
as a limitation upon, or waiver of, any privileges and immunities
|
| 378 |
+
that apply to the Licensor or You, including from the legal
|
| 379 |
+
processes of any jurisdiction or authority.
|
| 380 |
+
|
| 381 |
+
=======================================================================
|
| 382 |
+
|
| 383 |
+
Creative Commons is not a party to its public
|
| 384 |
+
licenses. Notwithstanding, Creative Commons may elect to apply one of
|
| 385 |
+
its public licenses to material it publishes and in those instances
|
| 386 |
+
will be considered the “Licensor.” The text of the Creative Commons
|
| 387 |
+
public licenses is dedicated to the public domain under the CC0 Public
|
| 388 |
+
Domain Dedication. Except for the limited purpose of indicating that
|
| 389 |
+
material is shared under a Creative Commons public license or as
|
| 390 |
+
otherwise permitted by the Creative Commons policies published at
|
| 391 |
+
creativecommons.org/policies, Creative Commons does not authorize the
|
| 392 |
+
use of the trademark "Creative Commons" or any other trademark or logo
|
| 393 |
+
of Creative Commons without its prior written consent including,
|
| 394 |
+
without limitation, in connection with any unauthorized modifications
|
| 395 |
+
to any of its public licenses or any other arrangements,
|
| 396 |
+
understandings, or agreements concerning use of licensed material. For
|
| 397 |
+
the avoidance of doubt, this paragraph does not form part of the
|
| 398 |
+
public licenses.
|
| 399 |
+
|
| 400 |
+
Creative Commons may be contacted at creativecommons.org.
|
LICENSE_GROVER
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2021 Tencent AI Lab. All rights reserved.
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
| 22 |
+
|
| 23 |
+
=============================================================================
|
| 24 |
+
|
| 25 |
+
Other dependencies and licenses:
|
| 26 |
+
|
| 27 |
+
-----------------------------------------------------------------------------
|
| 28 |
+
The MIT License (MIT)
|
| 29 |
+
applies to:
|
| 30 |
+
- chemprop (https://github.com/chemprop/chemprop)
|
| 31 |
+
-----------------------------------------------------------------------------
|
| 32 |
+
|
| 33 |
+
MIT License
|
| 34 |
+
|
| 35 |
+
Copyright (c) 2020 Wengong Jin, Kyle Swanson, Kevin Yang, Regina Barzilay, Tommi Jaakkola
|
| 36 |
+
|
| 37 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 38 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 39 |
+
in the Software without restriction, including without limitation the rights
|
| 40 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 41 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 42 |
+
furnished to do so, subject to the following conditions:
|
| 43 |
+
|
| 44 |
+
The above copyright notice and this permission notice shall be included in all
|
| 45 |
+
copies or substantial portions of the Software.
|
| 46 |
+
|
| 47 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 48 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 49 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 50 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 51 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 52 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 53 |
+
SOFTWARE.
|
MODEL_CARD.md
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Model card - tox21_grover_classifier
|
| 2 |
+
### Model details
|
| 3 |
+
- Model name: Grover Tox21 Baseline
|
| 4 |
+
- Developer: Tencent AI Lab (trained by JKU Linz)
|
| 5 |
+
- Paper URL: https://arxiv.org/pdf/2007.02835
|
| 6 |
+
- Model type / architecture:
|
| 7 |
+
- Grover implemented using the code accompanying the paper.
|
| 8 |
+
- The pretrained grover_base model is used as provided.
|
| 9 |
+
- A multitask network is finetuned for all Tox21 targets.
|
| 10 |
+
- Inference: Access via FastAPI endpoint. Upon a Tox21 prediction request, the model generates and returns predictions for all Tox21 targets simultaneously.
|
| 11 |
+
- Model version: v0
|
| 12 |
+
- Model date: 10.12.2025
|
| 13 |
+
- Reproducibility: Code for full training is available and enables refitting of the GROVER classifier
|
| 14 |
+
from scratch.
|
| 15 |
+
- Reproducibility: Code for full training is available and enables retraining of the model from scratch.
|
| 16 |
+
|
| 17 |
+
### Intended use
|
| 18 |
+
This model serves as a baseline benchmark for evaluating and comparing toxicity prediction
|
| 19 |
+
methods across the 12 pathway assays of the Tox21 dataset. It is not intended for clinical
|
| 20 |
+
decision-making without experimental validation.
|
| 21 |
+
|
| 22 |
+
### Metric
|
| 23 |
+
Each Tox21 task is evaluated using the area under the receiver operating characteristic curve (AUC). Overall performance is reported as the mean AUC across all individual tasks.
|
| 24 |
+
|
| 25 |
+
### Training data
|
| 26 |
+
Tox21 training and validation sets.
|
| 27 |
+
|
| 28 |
+
### Evaluation data
|
| 29 |
+
Tox21 test set.
|
app.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This is the main entry point for the FastAPI application.
|
| 3 |
+
The app handles the request to predict toxicity for a list of SMILES strings.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
# ---------------------------------------------------------------------------------------
|
| 7 |
+
# Dependencies and global variable definition
|
| 8 |
+
import os
|
| 9 |
+
from typing import List, Dict, Optional
|
| 10 |
+
from fastapi import FastAPI, Header, HTTPException
|
| 11 |
+
from pydantic import BaseModel, Field
|
| 12 |
+
|
| 13 |
+
from predict import predict as predict_func
|
| 14 |
+
|
| 15 |
+
API_KEY = os.getenv("API_KEY") # set via Space Secrets
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# ---------------------------------------------------------------------------------------
|
| 19 |
+
class Request(BaseModel):
|
| 20 |
+
smiles: List[str] = Field(min_items=1, max_items=1000)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class Response(BaseModel):
|
| 24 |
+
predictions: dict
|
| 25 |
+
model_info: Dict[str, str] = {}
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
app = FastAPI(title="toxicity-api")
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
@app.get("/")
|
| 32 |
+
def root():
|
| 33 |
+
return {
|
| 34 |
+
"message": "Toxicity Prediction API",
|
| 35 |
+
"endpoints": {
|
| 36 |
+
"/metadata": "GET - API metadata and capabilities",
|
| 37 |
+
"/healthz": "GET - Health check",
|
| 38 |
+
"/predict": "POST - Predict toxicity for SMILES",
|
| 39 |
+
},
|
| 40 |
+
"usage": "Send POST to /predict with {'smiles': ['your_smiles_here']}",
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
@app.get("/metadata")
|
| 45 |
+
def metadata():
|
| 46 |
+
return {
|
| 47 |
+
"name": "Tox21 GROVER Classifier",
|
| 48 |
+
"version": "0.1.0",
|
| 49 |
+
"tox_endpoints": [
|
| 50 |
+
"NR-AR",
|
| 51 |
+
"NR-AR-LBD",
|
| 52 |
+
"NR-AhR",
|
| 53 |
+
"NR-Aromatase",
|
| 54 |
+
"NR-ER",
|
| 55 |
+
"NR-ER-LBD",
|
| 56 |
+
"NR-PPAR-gamma",
|
| 57 |
+
"SR-ARE",
|
| 58 |
+
"SR-ATAD5",
|
| 59 |
+
"SR-HSE",
|
| 60 |
+
"SR-MMP",
|
| 61 |
+
"SR-p53",
|
| 62 |
+
],
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
@app.get("/healthz")
|
| 67 |
+
def healthz():
|
| 68 |
+
return {"ok": True}
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
@app.post("/predict", response_model=Response)
|
| 72 |
+
def predict(request: Request):
|
| 73 |
+
predictions = predict_func(request.smiles)
|
| 74 |
+
return {
|
| 75 |
+
"predictions": predictions,
|
| 76 |
+
"model_info": {"name": "Tox21 GROVER classifier", "version": "0.1.0"},
|
| 77 |
+
}
|
config/config.json
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"batch_size": 32,
|
| 3 |
+
"init_lr": 10,
|
| 4 |
+
"max_lr": 0.0001,
|
| 5 |
+
"final_lr": 9,
|
| 6 |
+
"dropout": 0.2,
|
| 7 |
+
"attn_hidden": 128,
|
| 8 |
+
"attn_out": 8,
|
| 9 |
+
"dist_coff": 0.15,
|
| 10 |
+
"bond_drop_rate": 0.2,
|
| 11 |
+
"ffn_num_layer": 2,
|
| 12 |
+
"ffn_hidden_size": 13,
|
| 13 |
+
"real_init_lr": 1e-05,
|
| 14 |
+
"real_final_lr": 1.1111111111111112e-05
|
| 15 |
+
}
|
environment.yaml
ADDED
|
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: grover
|
| 2 |
+
channels:
|
| 3 |
+
- conda-forge
|
| 4 |
+
- defaults
|
| 5 |
+
- pytorch
|
| 6 |
+
- https://repo.anaconda.com/pkgs/free
|
| 7 |
+
- rmg
|
| 8 |
+
- rdkit
|
| 9 |
+
dependencies:
|
| 10 |
+
- _libgcc_mutex=0.1=conda_forge
|
| 11 |
+
- _openmp_mutex=4.5=2_gnu
|
| 12 |
+
- absl-py=2.3.1=pyhd8ed1ab_0
|
| 13 |
+
- aom=3.5.0=h27087fc_0
|
| 14 |
+
- blas=1.0=mkl
|
| 15 |
+
- boost=1.74.0=py39h5472131_5
|
| 16 |
+
- boost-cpp=1.74.0=h75c5d50_8
|
| 17 |
+
- brotli-python=1.1.0=py39hf88036b_3
|
| 18 |
+
- bzip2=1.0.8=hda65f42_8
|
| 19 |
+
- c-ares=1.34.5=hb9d3cd8_0
|
| 20 |
+
- ca-certificates=2025.11.12=hbd8a1cb_0
|
| 21 |
+
- cairo=1.18.4=h44eff21_0
|
| 22 |
+
- certifi=2025.8.3=pyhd8ed1ab_0
|
| 23 |
+
- cffi=1.17.1=py39h15c3d72_0
|
| 24 |
+
- charset-normalizer=3.4.3=pyhd8ed1ab_0
|
| 25 |
+
- colorama=0.4.6=pyhd8ed1ab_1
|
| 26 |
+
- cpython=3.9.23=py39hd8ed1ab_0
|
| 27 |
+
- cuda-cudart=12.4.127=he02047a_2
|
| 28 |
+
- cuda-cudart_linux-64=12.4.127=h85509e4_2
|
| 29 |
+
- cuda-cupti=12.4.127=he02047a_2
|
| 30 |
+
- cuda-libraries=12.4.1=ha770c72_1
|
| 31 |
+
- cuda-nvrtc=12.4.127=he02047a_2
|
| 32 |
+
- cuda-nvtx=12.4.127=he02047a_2
|
| 33 |
+
- cuda-opencl=12.4.127=he02047a_1
|
| 34 |
+
- cuda-runtime=12.4.1=ha804496_0
|
| 35 |
+
- cuda-version=12.4=h3060b56_3
|
| 36 |
+
- ffmpeg=5.1.2=gpl_h8dda1f0_106
|
| 37 |
+
- filelock=3.19.1=pyhd8ed1ab_0
|
| 38 |
+
- font-ttf-dejavu-sans-mono=2.37=hab24e00_0
|
| 39 |
+
- font-ttf-inconsolata=3.000=h77eed37_0
|
| 40 |
+
- font-ttf-source-code-pro=2.038=h77eed37_0
|
| 41 |
+
- font-ttf-ubuntu=0.83=h77eed37_3
|
| 42 |
+
- fontconfig=2.15.0=h7e30c49_1
|
| 43 |
+
- fonts-conda-ecosystem=1=0
|
| 44 |
+
- fonts-conda-forge=1=hc364b38_1
|
| 45 |
+
- freetype=2.14.1=ha770c72_0
|
| 46 |
+
- gmp=6.3.0=hac33072_2
|
| 47 |
+
- gmpy2=2.2.1=py39h7196dd7_0
|
| 48 |
+
- gnutls=3.7.9=hb077bed_0
|
| 49 |
+
- grpcio=1.74.1=py39h6482d51_0
|
| 50 |
+
- h2=4.2.0=pyhd8ed1ab_0
|
| 51 |
+
- hpack=4.1.0=pyhd8ed1ab_0
|
| 52 |
+
- hyperframe=6.1.0=pyhd8ed1ab_0
|
| 53 |
+
- icu=70.1=h27087fc_0
|
| 54 |
+
- idna=3.10=pyhd8ed1ab_1
|
| 55 |
+
- importlib-metadata=8.7.0=pyhe01879c_1
|
| 56 |
+
- intel-openmp=2023.0.0=h9e868ea_25371
|
| 57 |
+
- jinja2=3.1.6=pyhd8ed1ab_0
|
| 58 |
+
- joblib=1.5.1=pyhd8ed1ab_0
|
| 59 |
+
- lame=3.100=h166bdaf_1003
|
| 60 |
+
- lcms2=2.17=h717163a_0
|
| 61 |
+
- ld_impl_linux-64=2.45=default_hbd61a6d_104
|
| 62 |
+
- lerc=4.0.0=h0aef613_1
|
| 63 |
+
- libabseil=20250512.1=cxx17_hba17884_0
|
| 64 |
+
- libblas=3.9.0=1_h86c2bf4_netlib
|
| 65 |
+
- libcblas=3.9.0=13_h8e06fc2_netlib
|
| 66 |
+
- libcublas=12.4.5.8=he02047a_2
|
| 67 |
+
- libcufft=11.2.1.3=he02047a_2
|
| 68 |
+
- libcufile=1.9.1.3=he02047a_2
|
| 69 |
+
- libcurand=10.3.5.147=he02047a_2
|
| 70 |
+
- libcusolver=11.6.1.9=he02047a_2
|
| 71 |
+
- libcusparse=12.3.1.170=he02047a_2
|
| 72 |
+
- libdeflate=1.25=h17f619e_0
|
| 73 |
+
- libdrm=2.4.125=hb03c661_1
|
| 74 |
+
- libegl=1.7.0=ha4b6fd6_2
|
| 75 |
+
- libexpat=2.7.3=hecca717_0
|
| 76 |
+
- libffi=3.4.6=h2dba641_1
|
| 77 |
+
- libfreetype=2.14.1=ha770c72_0
|
| 78 |
+
- libfreetype6=2.14.1=h73754d4_0
|
| 79 |
+
- libgcc=15.2.0=he0feb66_13
|
| 80 |
+
- libgcc-ng=15.2.0=h69a702a_13
|
| 81 |
+
- libgfortran=15.2.0=h69a702a_13
|
| 82 |
+
- libgfortran-ng=15.2.0=h69a702a_13
|
| 83 |
+
- libgfortran5=15.2.0=h68bc16d_13
|
| 84 |
+
- libgl=1.7.0=ha4b6fd6_2
|
| 85 |
+
- libglib=2.86.0=h1fed272_0
|
| 86 |
+
- libglvnd=1.7.0=ha4b6fd6_2
|
| 87 |
+
- libglx=1.7.0=ha4b6fd6_2
|
| 88 |
+
- libgomp=15.2.0=he0feb66_13
|
| 89 |
+
- libgrpc=1.74.1=hebd82d6_0
|
| 90 |
+
- libhwloc=2.9.1=hd6dc26d_0
|
| 91 |
+
- libiconv=1.18=h3b78370_2
|
| 92 |
+
- libidn2=2.3.8=hfac485b_1
|
| 93 |
+
- libjpeg-turbo=3.1.2=hb03c661_0
|
| 94 |
+
- liblapack=3.9.0=13_h8876d29_netlib
|
| 95 |
+
- liblzma=5.8.1=hb9d3cd8_2
|
| 96 |
+
- liblzma-devel=5.8.1=hb9d3cd8_2
|
| 97 |
+
- libnpp=12.2.5.30=he02047a_2
|
| 98 |
+
- libnsl=2.0.1=hb9d3cd8_1
|
| 99 |
+
- libnvfatbin=12.4.127=he02047a_2
|
| 100 |
+
- libnvjitlink=12.4.127=he02047a_2
|
| 101 |
+
- libnvjpeg=12.3.1.117=he02047a_2
|
| 102 |
+
- libopus=1.5.2=hd0c01bc_0
|
| 103 |
+
- libpciaccess=0.18=hb9d3cd8_0
|
| 104 |
+
- libpng=1.6.51=h421ea60_0
|
| 105 |
+
- libprotobuf=6.31.1=h49aed37_2
|
| 106 |
+
- libre2-11=2025.11.05=h7b12aa8_0
|
| 107 |
+
- libsqlite=3.51.0=h0c1763c_0
|
| 108 |
+
- libstdcxx=15.2.0=h934c35e_13
|
| 109 |
+
- libstdcxx-ng=15.2.0=hdf11a46_13
|
| 110 |
+
- libtasn1=4.20.0=hb03c661_1
|
| 111 |
+
- libtiff=4.7.1=h9d88235_1
|
| 112 |
+
- libunistring=0.9.10=h7f98852_0
|
| 113 |
+
- libuuid=2.41.2=he9a06e4_0
|
| 114 |
+
- libva=2.22.0=h4f16b4b_2
|
| 115 |
+
- libvpx=1.11.0=h9c3ff4c_3
|
| 116 |
+
- libwebp-base=1.6.0=hd42ef1d_0
|
| 117 |
+
- libxcb=1.17.0=h8a09558_0
|
| 118 |
+
- libxcrypt=4.4.36=hd590300_1
|
| 119 |
+
- libxml2=2.10.3=hca2bb57_4
|
| 120 |
+
- libzlib=1.3.1=hb9d3cd8_2
|
| 121 |
+
- llvm-openmp=15.0.7=h0cdce71_0
|
| 122 |
+
- markdown=3.8.2=pyhd8ed1ab_0
|
| 123 |
+
- markupsafe=3.0.2=py39h9399b63_1
|
| 124 |
+
- mkl=2023.1.0=h213fc3f_46344
|
| 125 |
+
- mpc=1.3.1=h24ddda3_1
|
| 126 |
+
- mpfr=4.2.1=h90cbb55_3
|
| 127 |
+
- mpmath=1.3.0=pyhd8ed1ab_1
|
| 128 |
+
- ncurses=6.5=h2d0b736_3
|
| 129 |
+
- nettle=3.9.1=h7ab15ed_0
|
| 130 |
+
- networkx=3.2.1=pyhd8ed1ab_0
|
| 131 |
+
- numpy=1.26.4=py39h474f0d3_0
|
| 132 |
+
- ocl-icd=2.3.3=hb9d3cd8_0
|
| 133 |
+
- opencl-headers=2025.06.13=h5888daf_0
|
| 134 |
+
- openh264=2.3.1=hcb278e6_2
|
| 135 |
+
- openjpeg=2.5.4=h55fea9a_0
|
| 136 |
+
- openssl=3.6.0=h26f9b46_0
|
| 137 |
+
- p11-kit=0.24.1=hc5aa10d_0
|
| 138 |
+
- packaging=25.0=pyh29332c3_1
|
| 139 |
+
- pandas=2.3.1=py39h1b6b32d_0
|
| 140 |
+
- pcre2=10.46=h1321c63_0
|
| 141 |
+
- pillow=11.3.0=py39h15c0740_0
|
| 142 |
+
- pip=25.2=pyh8b19718_0
|
| 143 |
+
- pixman=0.46.4=h54a6638_1
|
| 144 |
+
- protobuf=6.31.1=py39h2f5525a_0
|
| 145 |
+
- pthread-stubs=0.4=hb9d3cd8_1002
|
| 146 |
+
- pycairo=1.28.0=py39ha09e8ab_0
|
| 147 |
+
- pycparser=2.22=pyh29332c3_1
|
| 148 |
+
- pysocks=1.7.1=pyha55dd90_7
|
| 149 |
+
- python=3.9.23=hc30ae73_0_cpython
|
| 150 |
+
- python-dateutil=2.9.0.post0=pyhe01879c_2
|
| 151 |
+
- python-tzdata=2025.2=pyhd8ed1ab_0
|
| 152 |
+
- python_abi=3.9=8_cp39
|
| 153 |
+
- pytorch=2.4.0=py3.9_cuda12.4_cudnn9.1.0_0
|
| 154 |
+
- pytorch-cuda=12.4=hc786d27_7
|
| 155 |
+
- pytorch-mutex=1.0=cuda
|
| 156 |
+
- pytz=2025.2=pyhd8ed1ab_0
|
| 157 |
+
- pyyaml=6.0.2=py39h9399b63_2
|
| 158 |
+
- re2=2025.11.05=h5301d42_0
|
| 159 |
+
- readline=8.2=h8c095d6_2
|
| 160 |
+
- requests=2.32.5=pyhd8ed1ab_0
|
| 161 |
+
- scikit-learn=1.6.1=py39h4b7350c_0
|
| 162 |
+
- setuptools=80.9.0=pyhff2d567_0
|
| 163 |
+
- six=1.17.0=pyhe01879c_1
|
| 164 |
+
- svt-av1=1.4.1=hcb278e6_0
|
| 165 |
+
- sympy=1.14.0=pyh2585a3b_105
|
| 166 |
+
- tbb=2021.9.0=hf52228f_0
|
| 167 |
+
- tensorboard=2.20.0=pyhe01879c_0
|
| 168 |
+
- tensorboard-data-server=0.7.0=py39h7170ec2_2
|
| 169 |
+
- threadpoolctl=3.6.0=pyhecae5ae_0
|
| 170 |
+
- tk=8.6.13=noxft_ha0e22de_103
|
| 171 |
+
- torchtriton=3.0.0=py39
|
| 172 |
+
- torchvision=0.19.0=py39_cu124
|
| 173 |
+
- tqdm=4.67.1=pyhd8ed1ab_1
|
| 174 |
+
- typing_extensions=4.14.1=pyhe01879c_0
|
| 175 |
+
- tzdata=2025b=h78e105d_0
|
| 176 |
+
- urllib3=2.5.0=pyhd8ed1ab_0
|
| 177 |
+
- wayland=1.24.0=h3e06ad9_0
|
| 178 |
+
- wayland-protocols=1.46=hd8ed1ab_0
|
| 179 |
+
- werkzeug=3.1.3=pyhd8ed1ab_1
|
| 180 |
+
- wheel=0.45.1=pyhd8ed1ab_1
|
| 181 |
+
- x264=1!164.3095=h166bdaf_2
|
| 182 |
+
- x265=3.5=h924138e_3
|
| 183 |
+
- xorg-libx11=1.8.12=h4f16b4b_0
|
| 184 |
+
- xorg-libxau=1.0.12=hb03c661_1
|
| 185 |
+
- xorg-libxdmcp=1.1.5=hb03c661_1
|
| 186 |
+
- xorg-libxext=1.3.6=hb9d3cd8_0
|
| 187 |
+
- xorg-libxfixes=6.0.2=hb03c661_0
|
| 188 |
+
- xorg-libxrender=0.9.12=hb9d3cd8_0
|
| 189 |
+
- xz=5.8.1=hbcc6ac9_2
|
| 190 |
+
- xz-gpl-tools=5.8.1=hbcc6ac9_2
|
| 191 |
+
- xz-tools=5.8.1=hb9d3cd8_2
|
| 192 |
+
- yaml=0.2.5=h280c20c_3
|
| 193 |
+
- zipp=3.23.0=pyhd8ed1ab_0
|
| 194 |
+
- zlib=1.3.1=hb9d3cd8_2
|
| 195 |
+
- zstandard=0.23.0=py39hd399759_3
|
| 196 |
+
- zstd=1.5.7=hb8e6e7a_2
|
| 197 |
+
- pip:
|
| 198 |
+
- aiohappyeyeballs==2.6.1
|
| 199 |
+
- aiohttp==3.13.2
|
| 200 |
+
- aiosignal==1.4.0
|
| 201 |
+
- anyio==4.12.0
|
| 202 |
+
- async-timeout==5.0.1
|
| 203 |
+
- attrs==25.4.0
|
| 204 |
+
- click==8.1.8
|
| 205 |
+
- datasets==4.4.1
|
| 206 |
+
- descriptastorus==2.8.0
|
| 207 |
+
- dill==0.4.0
|
| 208 |
+
- exceptiongroup==1.3.1
|
| 209 |
+
- frozenlist==1.8.0
|
| 210 |
+
- fsspec==2025.10.0
|
| 211 |
+
- h11==0.16.0
|
| 212 |
+
- hf-xet==1.2.0
|
| 213 |
+
- httpcore==1.0.9
|
| 214 |
+
- httpx==0.28.1
|
| 215 |
+
- huggingface-hub==1.1.7
|
| 216 |
+
- multidict==6.7.0
|
| 217 |
+
- multiprocess==0.70.18
|
| 218 |
+
- pandas-flavor==0.7.0
|
| 219 |
+
- propcache==0.4.1
|
| 220 |
+
- pyarrow==21.0.0
|
| 221 |
+
- rdkit==2025.9.1
|
| 222 |
+
- rdkit-pypi==2022.9.5
|
| 223 |
+
- scipy==1.10.1
|
| 224 |
+
- shellingham==1.5.4
|
| 225 |
+
- typer-slim==0.20.0
|
| 226 |
+
- xarray==2024.7.0
|
| 227 |
+
- xxhash==3.6.0
|
| 228 |
+
- yarl==1.22.0
|
| 229 |
+
prefix: /system/apps/userenv/stopf/grover
|
evaluate.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
|
| 3 |
+
from src.eval import compute_roc_auc_from_csv
|
| 4 |
+
from src.commands import predict_from_csv
|
| 5 |
+
|
| 6 |
+
data_path = "./tox21/tox21_test_clean.csv"
|
| 7 |
+
features_path = data_path.replace(".csv", ".npz")
|
| 8 |
+
checkpoint_dir = "checkpoints"
|
| 9 |
+
output_path = "predictions/test_set_preds_best.csv"
|
| 10 |
+
|
| 11 |
+
predict_from_csv(data_path, features_path, checkpoint_dir, output_path)
|
| 12 |
+
|
| 13 |
+
valid_mask = np.load("./tox21/valid_mask_test.npy")
|
| 14 |
+
auc_array, mean_auc = compute_roc_auc_from_csv(output_path, "./tox21/tox21_test.csv", valid_mask)
|
| 15 |
+
|
| 16 |
+
print(auc_array)
|
| 17 |
+
print(mean_auc)
|
generate_features.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
from src.commands import generate_features
|
| 4 |
+
|
| 5 |
+
# paths
|
| 6 |
+
TRAIN_CSV = "./tox21/tox21_train_clean.csv"
|
| 7 |
+
VAL_CSV = "./tox21/tox21_validation_clean.csv"
|
| 8 |
+
|
| 9 |
+
# Generate the feature file for every cleaned split defined above. The previous
# call referenced an undefined EXAMPLES_CSV and raised NameError on start-up.
for csv_path in (TRAIN_CSV, VAL_CSV):
    generate_features(csv_path, csv_path.replace('.csv', '.npz'))
|
grover
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
Subproject commit 3f280d7d3419a781d303b1500c7039e37a1d87a2
|
hp_search.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
import numpy as np
|
| 4 |
+
from datetime import datetime
|
| 5 |
+
|
| 6 |
+
from src.hp_search import generate_random_search
|
| 7 |
+
from src.commands import finetune, predict_from_csv
|
| 8 |
+
from src.eval import compute_roc_auc_from_csv
|
| 9 |
+
|
| 10 |
+
HYPERPARAM_GRID = {
|
| 11 |
+
"batch_size": [32],
|
| 12 |
+
"init_lr": [10],
|
| 13 |
+
"max_lr": [0.001, 0.0005, 0.0001],
|
| 14 |
+
"final_lr": [2, 3, 4, 5, 6, 7, 8, 9, 10],
|
| 15 |
+
"dropout": [0.0, 0.05, 0.1, 0.2],
|
| 16 |
+
"attn_hidden": [128],
|
| 17 |
+
"attn_out": [4, 8],
|
| 18 |
+
"dist_coff": [0.05, 0.1, 0.15],
|
| 19 |
+
"bond_drop_rate": [0.0, 0.2, 0.4, 0.6],
|
| 20 |
+
"ffn_num_layer": [2, 3],
|
| 21 |
+
"ffn_hidden_size": [5, 7, 13],
|
| 22 |
+
}
|
| 23 |
+
|
| 24 |
+
hp_grid = generate_random_search(HYPERPARAM_GRID, num_trials=300, seed=42)
|
| 25 |
+
print("Total number of configs:", len(hp_grid))
|
| 26 |
+
|
| 27 |
+
# general vars
|
| 28 |
+
train_path = "tox21/tox21_train_clean.csv"
|
| 29 |
+
val_path = "tox21/tox21_validation_clean.csv"
|
| 30 |
+
train_features_path = train_path.replace(".csv", ".npz")
|
| 31 |
+
val_features_path = val_path.replace(".csv", ".npz")
|
| 32 |
+
checkpoint_path = "pretrained_models/grover_base.pt"
|
| 33 |
+
|
| 34 |
+
# Tracking best model
|
| 35 |
+
best_mean_auc = -1
|
| 36 |
+
best_config = None
|
| 37 |
+
best_model_path = None
|
| 38 |
+
|
| 39 |
+
# Create directory for logs
|
| 40 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 41 |
+
log_dir = f"hp_search/logs/{timestamp}"
|
| 42 |
+
os.makedirs(log_dir, exist_ok=True)
|
| 43 |
+
overall_log_path = f"{log_dir}/hp_search_results.txt"
|
| 44 |
+
best_log_path = f"{log_dir}/best_result.txt"
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
# iterate over configs
|
| 48 |
+
# The validation labels and validity mask are the same for every trial, so they
# are loaded once. This also fails fast (before any training) if the mask file
# is missing.
labels_path = "tox21/tox21_validation.csv"
valid_mask = np.load("./tox21/valid_mask_val.npy")

# 1-based number of the best trial; stays None when no trial yields a usable AUC
# (e.g. empty grid or every mean AUC is NaN).
best_trial_num = None

for trial_num, args in enumerate(hp_grid, start=1):
    save_dir = f"hp_search/trials/Trial_{trial_num}"

    print("\n=========================================")
    print("Training with config:")
    print(args)
    print("Save dir:", save_dir)
    print("=========================================\n")

    # finetune model
    finetune(train_path, val_path, train_features_path, val_features_path,
             save_dir, checkpoint_path, args)

    # predict on val set
    finetuned_model_dir = save_dir + "/fold_0/model_0"
    preds_path = save_dir + "/predictions.csv"
    predict_from_csv(val_path, val_features_path, finetuned_model_dir, preds_path)

    # evaluate model
    auc_per_task, mean_auc = compute_roc_auc_from_csv(preds_path, labels_path, valid_mask)

    # Save all experiment results
    with open(overall_log_path, "a") as f:
        f.write("\n===============================\n")
        f.write(f"Trial Num: {trial_num}\n")
        f.write(f"Mean AUC: {mean_auc}\n")
        f.write(f"Config: {args}\n")
        f.write(f"Save dir: {save_dir}\n")
        f.write(f"AUC per task: {auc_per_task}\n")

    # Check if best model (a NaN mean AUC never compares greater, so it is skipped)
    if mean_auc > best_mean_auc:
        print("New BEST model found!")
        best_mean_auc = mean_auc
        best_config = args
        best_model_path = save_dir
        best_trial_num = trial_num

        with open(best_log_path, "w") as f:
            f.write("==== BEST MODEL SO FAR ====\n")
            f.write(f"Trial Num: {best_trial_num}\n")
            f.write(f"Mean AUC: {best_mean_auc}\n")
            f.write(f"Config: {best_config}\n")
            f.write(f"Saved at: {best_model_path}\n")

print("\n============================")
print("Hyperparameter Search DONE!")
# Report the trial that produced the best model, not the last loop index.
print("Trial Num: ", best_trial_num)
print("Best mean AUC:", best_mean_auc)
print("Best model saved at:", best_model_path)
print("Best config:", best_config)
print("============================\n")
|
main.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
from rdkit import RDLogger
|
| 6 |
+
|
| 7 |
+
from grover.util.parsing import parse_args, get_newest_train_args
|
| 8 |
+
from grover.util.utils import create_logger
|
| 9 |
+
from grover.task.cross_validate import cross_validate
|
| 10 |
+
from grover.task.fingerprint import generate_fingerprints
|
| 11 |
+
from grover.task.predict import make_predictions, write_prediction
|
| 12 |
+
from grover.task.pretrain import pretrain_model
|
| 13 |
+
from grover.data.torchvocab import MolVocab
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def setup(seed):
|
| 17 |
+
# frozen random seed
|
| 18 |
+
torch.manual_seed(seed)
|
| 19 |
+
torch.cuda.manual_seed_all(seed)
|
| 20 |
+
np.random.seed(seed)
|
| 21 |
+
random.seed(seed)
|
| 22 |
+
torch.backends.cudnn.deterministic = True
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
if __name__ == '__main__':
|
| 26 |
+
# setup random seed
|
| 27 |
+
setup(seed=42)
|
| 28 |
+
# Avoid the pylint warning.
|
| 29 |
+
a = MolVocab
|
| 30 |
+
# suppress rdkit logger
|
| 31 |
+
lg = RDLogger.logger()
|
| 32 |
+
lg.setLevel(RDLogger.CRITICAL)
|
| 33 |
+
|
| 34 |
+
# Initialize MolVocab
|
| 35 |
+
mol_vocab = MolVocab
|
| 36 |
+
|
| 37 |
+
args = parse_args()
|
| 38 |
+
if args.parser_name == 'finetune':
|
| 39 |
+
logger = create_logger(name='train', save_dir=args.save_dir, quiet=False)
|
| 40 |
+
cross_validate(args, logger)
|
| 41 |
+
elif args.parser_name == 'pretrain':
|
| 42 |
+
logger = create_logger(name='pretrain', save_dir=args.save_dir)
|
| 43 |
+
pretrain_model(args, logger)
|
| 44 |
+
elif args.parser_name == "eval":
|
| 45 |
+
logger = create_logger(name='eval', save_dir=args.save_dir, quiet=False)
|
| 46 |
+
cross_validate(args, logger)
|
| 47 |
+
elif args.parser_name == 'fingerprint':
|
| 48 |
+
train_args = get_newest_train_args()
|
| 49 |
+
logger = create_logger(name='fingerprint', save_dir=None, quiet=False)
|
| 50 |
+
feas = generate_fingerprints(args, logger)
|
| 51 |
+
np.savez_compressed(args.output_path, fps=feas)
|
| 52 |
+
elif args.parser_name == 'predict':
|
| 53 |
+
train_args = get_newest_train_args()
|
| 54 |
+
avg_preds, test_smiles = make_predictions(args, train_args)
|
| 55 |
+
write_prediction(avg_preds, test_smiles, args)
|
predict.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import csv
|
| 3 |
+
import subprocess
|
| 4 |
+
import os
|
| 5 |
+
|
| 6 |
+
from src.preprocess import create_clean_smiles
|
| 7 |
+
from src.commands import generate_features, predict_from_csv
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def predict(smiles_list):
    """
    Predict toxicity targets for a list of SMILES strings.

    The SMILES are cleaned, written to a CSV, featurized and scored through the
    command helpers; the prediction CSV is then read back into a dictionary.
    Results are keyed by the original SMILES string, so duplicate inputs share
    a single entry. Inputs that fail cleaning receive a placeholder probability
    of 0.5 for every target.

    Args:
        smiles_list (list[str]): SMILES strings

    Returns:
        dict: {smiles: {target_name: prediction_prob}}

    Raises:
        ValueError: If the number of cleaned SMILES does not match the number
            of inputs flagged as valid.
    """
    data_path = "tox21/predict_smiles.csv"
    features_path = data_path.replace(".csv", ".npz")
    checkpoint_dir = "checkpoints"
    output_path = "predictions/smiles_predictions.csv"

    # clean smiles
    clean_smiles, valid_mask = create_clean_smiles(smiles_list)

    # Mapping from cleaned to original for valid ones
    originals_valid = [orig for orig, ok in zip(smiles_list, valid_mask) if ok]

    # sanity check (optional but nice to have)
    if len(originals_valid) != len(clean_smiles):
        raise ValueError(
            f"Mismatch: {len(originals_valid)} valid originals vs {len(clean_smiles)} cleaned SMILES"
        )

    # map cleaned → original
    cleaned_to_original = dict(zip(clean_smiles, originals_valid))

    # tox21 targets
    TARGET_NAMES = [
        "NR-AhR","NR-AR","NR-AR-LBD","NR-Aromatase","NR-ER","NR-ER-LBD","NR-PPAR-gamma","SR-ARE","SR-ATAD5","SR-HSE","SR-MMP","SR-p53"
    ]
    print(f"Received {len(smiles_list)} SMILES strings")

    predictions = {}

    # Only run the feature/prediction pipeline when there is something to score;
    # with no valid SMILES the CSV would contain a header and nothing else.
    if len(clean_smiles) > 0:
        # Both working directories are git-ignored, so they may not exist yet.
        os.makedirs(os.path.dirname(data_path), exist_ok=True)
        os.makedirs(os.path.dirname(output_path), exist_ok=True)

        # Drop any predictions left over from an earlier call so a failed run
        # cannot be mistaken for fresh results.
        if os.path.exists(output_path):
            os.remove(output_path)

        # put smiles into csv
        with open(data_path, "w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow(["smiles"] + TARGET_NAMES)  # header
            for smi in clean_smiles:
                writer.writerow([smi] + [""] * len(TARGET_NAMES))

        # generate features
        generate_features(data_path, features_path)

        # predict
        predict_from_csv(data_path, features_path, checkpoint_dir, output_path)

        # create results dictionary from predictions
        with open(output_path, "r", newline="") as f:
            reader = csv.DictReader(f)
            rows = list(reader)

        # Identify the SMILES column even if it is unnamed
        fieldnames = reader.fieldnames
        smiles_col = fieldnames[0]  # first column, even if empty string

        target_names = fieldnames[1:]  # all columns except first

        for row in rows:
            clean_smi = row[smiles_col]
            original_smi = cleaned_to_original.get(clean_smi, clean_smi)

            pred_dict = {t: float(row[t]) for t in target_names}
            predictions[original_smi] = pred_dict

    # Add placeholder predictions for invalid SMILES
    for smi, is_valid in zip(smiles_list, valid_mask):
        if not is_valid:
            predictions[smi] = {t: 0.5 for t in TARGET_NAMES}

    return predictions
|
| 87 |
+
|
| 88 |
+
if __name__ == "__main__":
    # Manual smoke test. Guarded so that importing this module (as app.py does)
    # no longer triggers a full inference run at import time.
    preds = predict(["Oc1cc(O)cc(C=Cc2ccc(O)c(O)c2)c1"])
    print(preds)
|
prepare_data.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
|
| 3 |
+
from src.preprocess import clean_smiles_in_csv
|
| 4 |
+
|
| 5 |
+
TOX21_TASKS = [
|
| 6 |
+
"NR-AhR","NR-AR","NR-AR-LBD","NR-Aromatase","NR-ER","NR-ER-LBD","NR-PPAR-gamma","SR-ARE","SR-ATAD5","SR-HSE","SR-MMP","SR-p53"
|
| 7 |
+
]
|
| 8 |
+
|
| 9 |
+
def prepare_data(data_path, save_path_clean_data, save_path_valid_mask):
    """
    Clean one Tox21 CSV split and store the value returned by the cleaner.

    Calls ``clean_smiles_in_csv`` with the ``"smiles"`` column name and the
    ``TOX21_TASKS`` list, then saves its return value with ``np.save``.

    Args:
        data_path: Path of the raw CSV to clean.
        save_path_clean_data: Path handed to the cleaner for the cleaned CSV.
        save_path_valid_mask: Destination ``.npy`` path for the returned mask
            (presumably one validity flag per input row - confirm against
            ``clean_smiles_in_csv``).
    """
    mask = clean_smiles_in_csv(data_path, save_path_clean_data, "smiles", TOX21_TASKS)
    np.save(save_path_valid_mask, mask)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
train_path = "./tox21/tox21_train.csv"
|
| 15 |
+
val_path = "./tox21/tox21_validation.csv"
|
| 16 |
+
|
| 17 |
+
train_path_clean = "./tox21/tox21_train_clean.csv"
|
| 18 |
+
val_path_clean = "./tox21/tox21_validation_clean.csv"
|
| 19 |
+
|
| 20 |
+
# The previous call used test_path/test_path_clean, which were never defined
# (NameError), and left the train/validation paths above unused.
test_path = "./tox21/tox21_test.csv"
test_path_clean = "./tox21/tox21_test_clean.csv"

# Clean every split and store one validity mask per split.
for raw_csv, clean_csv, mask_file in (
    (train_path, train_path_clean, "./tox21/valid_mask_train.npy"),
    (val_path, val_path_clean, "./tox21/valid_mask_val.npy"),
    (test_path, test_path_clean, "./tox21/valid_mask_test.npy"),
):
    prepare_data(raw_csv, clean_csv, mask_file)
|
src/commands.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
def run(cmd):
    """
    Echo a shell command, execute it, and fail loudly if it does not succeed.

    Args:
        cmd (str): Command line handed to the system shell via ``os.system``.

    Raises:
        RuntimeError: If ``os.system`` reports a non-zero status. Callers read
            files produced by the command right afterwards, so continuing after
            a failure would either crash later or pick up stale outputs.
    """
    print("\n======================================")
    print("Running command:")
    print(cmd)
    print("======================================\n")
    status = os.system(cmd)
    if status != 0:
        raise RuntimeError(f"Command failed with status {status}: {cmd}")
|
| 9 |
+
|
| 10 |
+
def generate_features(data_path, save_path):
    """
    Run ``scripts/save_features.py`` to write features for a CSV file.

    The script is invoked with the ``rdkit_2d_normalized`` features generator
    and the ``--restart`` flag.

    Args:
        data_path: CSV file passed as ``--data_path``.
        save_path: Output file passed as ``--save_path``.
    """
    # Local import: this module only imports ``os`` at file level.
    import shlex

    # Quote the paths so names containing spaces or shell metacharacters reach
    # the script as a single argument; plain paths are left unchanged.
    run(
        f"python scripts/save_features.py "
        f"--data_path {shlex.quote(str(data_path))} "
        f"--save_path {shlex.quote(str(save_path))} "
        f"--features_generator rdkit_2d_normalized "
        f"--restart"
    )
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def finetune(train_path, val_path, train_features_path, val_features_path,
    save_dir, checkpoint_path, args
):
    """
    Build and run the ``python main.py finetune`` command for one configuration.

    All of ``train_path`` is used for training (``--split_sizes 1 0 0``); the
    validation CSV and its features are passed as both the separate validation
    set and the separate test set. The run uses classification mode, one fold,
    an ensemble size of one, no feature scaling and 100 epochs.

    Args:
        train_path: Training CSV (``--data_path``).
        val_path: Validation CSV (``--separate_val_path`` and
            ``--separate_test_path``).
        train_features_path: Feature file for the training CSV.
        val_features_path: Feature file for the validation CSV.
        save_dir: Directory passed as ``--save_dir``.
        checkpoint_path: Checkpoint passed as ``--checkpoint_path``.
        args (dict): Hyperparameters. Keys read here: ``ffn_hidden_size``,
            ``ffn_num_layer``, ``batch_size``, ``real_init_lr``,
            ``real_final_lr``, ``max_lr``, ``dropout``, ``attn_hidden``,
            ``attn_out``, ``dist_coff``, ``bond_drop_rate``.

    Raises:
        KeyError: If ``args`` lacks one of the keys listed above.
    """
    # Local import: this module only imports ``os`` at file level.
    import shlex

    def q(path):
        # Quote a path for the shell; plain paths come back unchanged.
        return shlex.quote(str(path))

    finetune_cmd = (
        f"python main.py finetune "
        f"--data_path {q(train_path)} "
        f"--split_type random "
        f"--split_sizes 1 0 0 "
        f"--separate_val_path {q(val_path)} "
        f"--separate_test_path {q(val_path)} "
        f"--features_path {q(train_features_path)} "
        f"--separate_val_features_path {q(val_features_path)} "
        f"--separate_test_features_path {q(val_features_path)} "
        f"--save_dir {q(save_dir)} "
        f"--checkpoint_path {q(checkpoint_path)} "
        f"--dataset_type classification "
        f"--num_folds 1 "
        f"--ensemble_size 1 "
        f"--no_features_scaling "
        f"--ffn_hidden_size {args['ffn_hidden_size']} "
        f"--ffn_num_layers {args['ffn_num_layer']} "
        f"--batch_size {args['batch_size']} "
        f"--epochs 100 "
        # The learning-rate flags take the derived absolute values, not the
        # ratio entries stored under 'init_lr' / 'final_lr'.
        f"--init_lr {args['real_init_lr']} "
        f"--final_lr {args['real_final_lr']} "
        f"--max_lr {args['max_lr']} "
        f"--dropout {args['dropout']} "
        f"--attn_hidden {args['attn_hidden']} "
        f"--attn_out {args['attn_out']} "
        f"--dist_coff {args['dist_coff']} "
        f"--bond_drop_rate {args['bond_drop_rate']} "
    )
    run(finetune_cmd)
|
| 53 |
+
|
| 54 |
+
|
src/eval.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
import numpy as np
|
| 3 |
+
from sklearn.metrics import roc_auc_score
|
| 4 |
+
|
| 5 |
+
def compute_roc_auc_from_csv(preds_csv: str, labels_csv: str, valid_mask):
    """
    Compute ROC AUC per task and the mean over tasks from two CSV files.

    Predictions exist only for rows flagged in ``valid_mask``; they are expanded
    back to the full row count, with 0.5 filled in for the unflagged rows.
    Missing labels (NaN) are excluded per task. A task with no labels, or one
    for which ``roc_auc_score`` raises ``ValueError``, gets an AUC of NaN and is
    ignored by the mean.

    Args:
        preds_csv: CSV with one prediction row per flagged input row.
        labels_csv: CSV with the labels for all rows.
        valid_mask: Boolean-like array with one entry per label row; only its
            last ``n_label_rows`` entries are used.

    Returns:
        tuple: ``(auc_array, mean_auc)`` where ``auc_array`` is a float32 array
        with one entry per shared column and ``mean_auc`` is its NaN-ignoring
        mean.

    Raises:
        ValueError: If the mask does not cover the label rows or the number of
            flagged rows differs from the number of prediction rows.
    """
    preds = pd.read_csv(preds_csv)
    labels = pd.read_csv(labels_csv)

    smiles_cols = [c for c in preds.columns if "smiles" in c.lower()]
    if smiles_cols:
        print(f"🧪 Dropping SMILES columns: {smiles_cols}")
        preds = preds.drop(columns=smiles_cols, errors="ignore")
        labels = labels.drop(columns=smiles_cols, errors="ignore")

    # Score only the columns present in both files, in prediction order.
    shared_cols = [c for c in preds.columns if c in labels.columns]
    preds = preds[shared_cols].apply(pd.to_numeric, errors="coerce")
    labels = labels[shared_cols].apply(pd.to_numeric, errors="coerce")

    y_pred_clean = preds.to_numpy(dtype=float)
    y_true = labels.to_numpy(dtype=float)

    # Force a boolean mask: an integer 0/1 array would otherwise be treated as
    # row indices by the assignment below.
    valid_mask = np.asarray(valid_mask, dtype=bool)[-y_true.shape[0]:]
    if len(valid_mask) != y_true.shape[0]:
        raise ValueError(
            f"valid_mask covers {len(valid_mask)} rows but labels have {y_true.shape[0]}"
        )
    n_valid = int(valid_mask.sum())
    if n_valid != y_pred_clean.shape[0]:
        raise ValueError(
            f"valid_mask flags {n_valid} rows but predictions have {y_pred_clean.shape[0]}"
        )

    # Re-expand to original size
    y_pred = np.full((len(valid_mask), y_pred_clean.shape[1]), 0.5, dtype=float)
    y_pred[valid_mask] = y_pred_clean

    y_mask = ~np.isnan(y_true)

    auc_list = []
    for i in range(y_true.shape[1]):
        mask_i = y_mask[:, i]
        if mask_i.sum() > 0:
            try:
                auc = roc_auc_score(y_true[mask_i, i], y_pred[mask_i, i])
            except ValueError:
                auc = np.nan
        else:
            auc = np.nan
        auc_list.append(auc)

    auc_array = np.array(auc_list, dtype=np.float32)
    mean_auc = np.nanmean(auc_array)

    return auc_array, mean_auc
|
src/hp_search.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from itertools import product
|
| 2 |
+
import random
|
| 3 |
+
|
| 4 |
+
def generate_random_search(grid, num_trials=300, seed=42):
    """Draw a reproducible random subset of a hyper-parameter grid.

    Every combination of the grid values is enumerated, two derived
    learning-rate entries are added to each, and ``num_trials`` distinct
    combinations are sampled without replacement.

    Args:
        grid: Mapping of hyper-parameter name to an iterable of candidate
            values. Must contain the keys ``"max_lr"``, ``"init_lr"`` and
            ``"final_lr"``; the latter two are used as divisors of ``max_lr``.
        num_trials: Number of combinations to return. If the grid holds fewer
            combinations than this, all of them are returned (in random
            order) instead of raising.
        seed: Seed applied to the global ``random`` module state.

    Returns:
        A list of dicts, one per sampled combination, each holding the grid
        keys plus ``"real_init_lr"`` and ``"real_final_lr"``.

    Raises:
        KeyError: If one of the required learning-rate keys is missing.
        ValueError: If ``num_trials`` is negative.
    """
    random.seed(seed)

    keys = list(grid)
    all_combinations = []
    for combo in product(*grid.values()):
        params = dict(zip(keys, combo))
        # init_lr / final_lr are ratios relative to max_lr; turn them into
        # the actual learning-rate values.
        params["real_init_lr"] = params["max_lr"] / params["init_lr"]
        params["real_final_lr"] = params["max_lr"] / params["final_lr"]
        all_combinations.append(params)

    # random.sample raises ValueError when asked for more items than the
    # population holds, so cap the request at the size of the grid.
    sample_size = min(num_trials, len(all_combinations))
    indices = random.sample(range(len(all_combinations)), sample_size)
    return [all_combinations[i] for i in indices]
|
src/preprocess.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from rdkit import Chem
|
| 2 |
+
from rdkit.Chem.MolStandardize import rdMolStandardize
|
| 3 |
+
from rdkit import Chem
|
| 4 |
+
import numpy as np
|
| 5 |
+
import pandas as pd
|
| 6 |
+
from datasets import load_dataset
|
| 7 |
+
from typing import List, Optional
|
| 8 |
+
|
| 9 |
+
# Names of the 12 Tox21 tasks, in a fixed order.
TOX21_TASKS = [
    "NR-AhR",
    "NR-AR",
    "NR-AR-LBD",
    "NR-Aromatase",
    "NR-ER",
    "NR-ER-LBD",
    "NR-PPAR-gamma",
    "SR-ARE",
    "SR-ATAD5",
    "SR-HSE",
    "SR-MMP",
    "SR-p53",
]
|
| 12 |
+
|
| 13 |
+
def create_clean_smiles(smiles_list: list[str]) -> tuple[list[str], np.ndarray]:
    """
    Standardize a list of SMILES strings with RDKit.

    Each input is parsed, passed through ``rdMolStandardize.Cleanup``,
    reduced to its canonical tautomer and written back as canonical SMILES.
    An input is rejected when it cannot be parsed, when any atom carries a
    formal charge outside {-1, 0, +1}, or when any step raises.

    Returns:
        ``(clean_smiles, valid_mask)``: ``clean_smiles`` holds one entry per
        accepted input (rejected inputs are omitted, not padded), and
        ``valid_mask`` is a boolean array with one entry per input that marks
        which inputs were accepted.
    """
    cleanup_params = rdMolStandardize.CleanupParameters()
    canonicalizer = rdMolStandardize.TautomerEnumerator()
    # Formal charges outside this set are filtered out (the original notes
    # this prevents a GROVER crash).
    permitted_charges = {-1, 0, 1}

    accepted = []
    flags = []

    for smi in smiles_list:
        try:
            mol = Chem.MolFromSmiles(smi)
            if mol is None:
                charges_ok = False
            else:
                # Cleanup and tautomer canonicalization
                mol = rdMolStandardize.Cleanup(mol, cleanup_params)
                mol = canonicalizer.Canonicalize(mol)
                charges_ok = all(
                    atom.GetFormalCharge() in permitted_charges
                    for atom in mol.GetAtoms()
                )

            if not charges_ok:
                flags.append(False)
                continue

            # Canonical SMILES output
            accepted.append(Chem.MolToSmiles(mol, canonical=True))
            flags.append(True)

        except Exception as e:
            # Best effort: report the failure and mark the entry invalid.
            print(f"Failed to clean {smi}: {e}")
            flags.append(False)

    return accepted, np.array(flags, dtype=bool)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def clean_smiles_in_csv(input_csv: str, output_csv: str, smiles_col: str = "smiles", target_cols: Optional[List[str]] = None):
    """
    Clean the SMILES column of a CSV and write the surviving rows to a new CSV.

    Args:
        input_csv: Path of the CSV to read.
        output_csv: Path the cleaned CSV is written to (without the index).
        smiles_col: Name of the column holding the SMILES strings.
        target_cols: Columns to carry over next to the cleaned SMILES.
            Defaults to every column except ``smiles_col``.

    Returns:
        The boolean mask from ``create_clean_smiles`` with one entry per
        input row, marking the rows that were kept.

    Raises:
        ValueError: If ``smiles_col`` or any requested target column is not
            present in the input CSV.
    """
    frame = pd.read_csv(input_csv)
    if smiles_col not in frame.columns:
        raise ValueError(f"'{smiles_col}' column not found in CSV.")

    # Without an explicit selection, keep every non-SMILES column.
    if target_cols is None:
        target_cols = [name for name in frame.columns if name != smiles_col]

    absent = [name for name in target_cols if name not in frame.columns]
    if absent:
        raise ValueError(f"Missing target columns in CSV: {absent}")

    clean_smis, valid_mask = create_clean_smiles(frame[smiles_col].tolist())

    # Select the accepted rows, then put the cleaned SMILES in front.
    cleaned = frame.loc[valid_mask, target_cols].copy()
    cleaned.insert(0, smiles_col, clean_smis)

    cleaned.to_csv(output_csv, index=False)
    print(f"✅ Cleaned dataset saved to '{output_csv}' ({len(cleaned)} valid molecules).")
    return valid_mask
|
train.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
import os

import numpy as np

from src.commands import finetune, predict_from_csv
from src.eval import compute_roc_auc_from_csv
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def load_config(path="config.json"):
    """Read a JSON configuration file and return the parsed content.

    Args:
        path: Location of the JSON file.

    Returns:
        Whatever ``json.load`` produces for the file.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default text encoding.
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)
|
| 12 |
+
|
| 13 |
+
# Fine-tune on the cleaned train split, predict on the cleaned validation
# split, then score those predictions against the validation labels.
config = load_config()

print(config)


# Paths to custom split
train_path = "tox21/tox21_train_clean.csv"
val_path = "tox21/tox21_validation_clean.csv"

# Feature files sit next to the CSVs with an .npz extension.
train_features_path = train_path.replace(".csv", ".npz")
val_features_path = val_path.replace(".csv", ".npz")
checkpoint_path = "pretrained_models/grover_base.pt"

# Output directory for finetuned model
save_dir = "finetune/"


# The original passed an undefined name `args` here, which raises NameError
# before training starts. The loaded config is the only settings object in
# this script, so it is passed instead.
# NOTE(review): confirm that finetune() expects the parsed config as its
# last argument.
finetune(train_path, val_path, train_features_path, val_features_path,
         save_dir, checkpoint_path, config)

# predict on val set
finetuned_model_dir = os.path.join(save_dir, "fold_0", "model_0")
output_path = os.path.join(save_dir, "predictions.csv")
predict_from_csv(val_path, val_features_path, finetuned_model_dir, output_path)

# evaluate model
# Labels come from the uncleaned validation CSV; the saved mask is what
# compute_roc_auc_from_csv uses to line the predictions up with those rows.
labels_path = "tox21/tox21_validation.csv"
valid_mask = np.load("./tox21/valid_mask_val.npy")
auc_per_task, mean_auc = compute_roc_auc_from_csv(output_path, labels_path, valid_mask)

# Report the scores; previously they were computed and then discarded.
print(f"ROC AUC per task: {auc_per_task}")
print(f"Mean ROC AUC: {mean_auc:.4f}")
|