| """ | |
| The basic building blocks in model. | |
| """ | |
| import math | |
| from argparse import Namespace | |
| from typing import Union | |
| import numpy | |
| import scipy.stats as stats | |
| import torch | |
| from torch import nn as nn | |
| from torch.nn import LayerNorm, functional as F | |
| from grover.util.nn_utils import get_activation_function, select_neighbor_and_aggregate | |


class SelfAttention(nn.Module):
    """
    Self-attention layer.

    Given $X \in \mathbb{R}^{n \times \mathrm{in\_feature}}$, the attention is calculated by
    $a = \mathrm{softmax}(W_2 \tanh(W_1 X^\top))$, where
    $W_1 \in \mathbb{R}^{\mathrm{hidden} \times \mathrm{in\_feature}}$ and
    $W_2 \in \mathbb{R}^{\mathrm{out\_feature} \times \mathrm{hidden}}$.
    The final output is $\mathrm{out} = aX$, whose size is independent of the number of inputs $n$.
    """

    def __init__(self, *, hidden, in_feature, out_feature):
        """
        The init function.

        :param hidden: the hidden dimension, can be viewed as the number of experts.
        :param in_feature: the input feature dimension.
        :param out_feature: the output feature dimension.
        """
        super(SelfAttention, self).__init__()
        self.w1 = torch.nn.Parameter(torch.FloatTensor(hidden, in_feature))
        self.w2 = torch.nn.Parameter(torch.FloatTensor(out_feature, hidden))
        self.reset_parameters()

    def reset_parameters(self):
        """
        Use the xavier_normal method to initialize parameters.
        """
        nn.init.xavier_normal_(self.w1)
        nn.init.xavier_normal_(self.w2)

    def forward(self, X):
        """
        The forward function.

        :param X: the input feature map, $X \in \mathbb{R}^{n \times \mathrm{in\_feature}}$.
        :return: the final embeddings and the attention matrix.
        """
        x = torch.tanh(torch.matmul(self.w1, X.transpose(1, 0)))
        x = torch.matmul(self.w2, x)
        attn = torch.nn.functional.softmax(x, dim=-1)
        x = torch.matmul(attn, X)
        return x, attn
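# Illustrative usage sketch (editorial addition, not part of the original module); the sizes below are
# arbitrary and only show that the output size does not depend on the number of inputs n:
#   attn_layer = SelfAttention(hidden=4, in_feature=8, out_feature=2)
#   nodes = torch.randn(5, 8)          # n = 5 node embeddings of dimension in_feature = 8
#   out, attn = attn_layer(nodes)      # out: 2 x 8 (out_feature x in_feature), attn: 2 x 5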


class Readout(nn.Module):
    """The readout function. Converts the node embeddings to the graph embedding."""

    def __init__(self,
                 rtype: str = "none",
                 hidden_size: int = 0,
                 attn_hidden: int = None,
                 attn_out: int = None,
                 ):
        """
        The readout function.

        :param rtype: the readout type, can be "mean" or "self_attention".
        :param hidden_size: the input hidden size.
        :param attn_hidden: only valid if rtype == "self_attention". The attention hidden size.
        :param attn_out: only valid if rtype == "self_attention". The attention output size.
        """
        super(Readout, self).__init__()
        # Cached zeros
        self.cached_zero_vector = nn.Parameter(torch.zeros(hidden_size), requires_grad=False)
        self.rtype = "mean"

        if rtype == "self_attention":
            self.attn = SelfAttention(hidden=attn_hidden,
                                      in_feature=hidden_size,
                                      out_feature=attn_out)
            self.rtype = "self_attention"

    def forward(self, embeddings, scope):
        """
        The forward function. Given a batch of node/edge embeddings and a scope list,
        produce the graph-level embedding for each scope.

        :param embeddings: the embedding matrix, num_atoms (or num_bonds) x hidden_size.
        :param scope: a list whose elements are pairs [start, size]: `start` is the index of the first
                      node/edge of a molecule in `embeddings` and `size` is the number of its nodes/edges.
        :return: the graph-level embeddings.
        """
        # Readout
        mol_vecs = []
        self.attns = []
        for _, (a_start, a_size) in enumerate(scope):
            if a_size == 0:
                mol_vecs.append(self.cached_zero_vector)
            else:
                cur_hiddens = embeddings.narrow(0, a_start, a_size)
                if self.rtype == "self_attention":
                    cur_hiddens, attn = self.attn(cur_hiddens)
                    cur_hiddens = cur_hiddens.flatten()
                    # Temporarily disabled. Enable it if you want to save attentions.
                    # self.attns.append(attn.cpu().detach().numpy())
                else:
                    cur_hiddens = cur_hiddens.sum(dim=0) / a_size
                mol_vecs.append(cur_hiddens)

        mol_vecs = torch.stack(mol_vecs, dim=0)  # (num_molecules, hidden_size)
        return mol_vecs
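# Illustrative usage sketch (editorial addition, not part of the original module); the scope below
# describes a batch of two molecules, covering rows 0-2 and rows 3-4 of the embedding matrix:
#   readout = Readout(rtype="mean", hidden_size=16)
#   node_emb = torch.randn(5, 16)
#   mol_emb = readout(node_emb, scope=[(0, 3), (3, 2)])   # shape (2, 16)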


class MPNEncoder(nn.Module):
    """A message passing neural network for encoding a molecule."""

    def __init__(self, args: Namespace,
                 atom_messages: bool,
                 init_message_dim: int,
                 attached_fea_fdim: int,
                 hidden_size: int,
                 bias: bool,
                 depth: int,
                 dropout: float,
                 undirected: bool,
                 dense: bool,
                 aggregate_to_atom: bool,
                 attach_fea: bool,
                 input_layer="fc",
                 dynamic_depth='none'
                 ):
        """
        Initializes the MPNEncoder.

        :param args: the arguments.
        :param atom_messages: enables atom messages or not.
        :param init_message_dim: the initial input message dimension.
        :param attached_fea_fdim: the attached feature dimension.
        :param hidden_size: the output message dimension during message passing.
        :param bias: enables the bias in the message passing.
        :param depth: the message passing depth.
        :param dropout: the dropout rate.
        :param undirected: whether the message passing is undirected or not.
        :param dense: enables the dense connections.
        :param aggregate_to_atom: enables the aggregation to atoms.
        :param attach_fea: enables the feature attachment during the message passing process.
        :param input_layer: the input layer type, "fc" or "none".
        :param dynamic_depth: enables the dynamic depth. Possible choices: "none", "uniform" and "truncnorm".
        """
        super(MPNEncoder, self).__init__()
        self.init_message_dim = init_message_dim
        self.attached_fea_fdim = attached_fea_fdim
        self.hidden_size = hidden_size
        self.bias = bias
        self.depth = depth
        self.dropout = dropout
        self.input_layer = input_layer
        self.layers_per_message = 1
        self.undirected = undirected
        self.atom_messages = atom_messages
        self.dense = dense
        self.aggreate_to_atom = aggregate_to_atom
        self.attached_fea = attach_fea
        self.dynamic_depth = dynamic_depth

        # Dropout
        self.dropout_layer = nn.Dropout(p=self.dropout)

        # Activation
        self.act_func = get_activation_function(args.activation)

        # Input
        if self.input_layer == "fc":
            input_dim = self.init_message_dim
            self.W_i = nn.Linear(input_dim, self.hidden_size, bias=self.bias)

        if self.attached_fea:
            w_h_input_size = self.hidden_size + self.attached_fea_fdim
        else:
            w_h_input_size = self.hidden_size

        # Shared weight matrix across depths (default)
        self.W_h = nn.Linear(w_h_input_size, self.hidden_size, bias=self.bias)

    def forward(self,
                init_messages,
                init_attached_features,
                a2nei,
                a2attached,
                b2a=None,
                b2revb=None,
                adjs=None
                ) -> torch.FloatTensor:
        """
        The forward function.

        :param init_messages: the initial messages, can be atom features or bond features.
        :param init_attached_features: the initial attached features.
        :param a2nei: the mapping from an item to its neighbors. For atom message passing, a2nei = a2a;
                      for bond message passing, a2nei = a2b.
        :param a2attached: the mapping from an item to the attached features during message passing. For atom
                           message passing, a2attached = a2b; for bond message passing, a2attached = a2a.
        :param b2a: used to remove the reversed bond in bond message passing.
        :param b2revb: used to remove the reversed bond message in bond message passing.
        :param adjs: unused, kept for interface compatibility.
        :return: if aggreate_to_atom or self.atom_messages, return num_atoms x hidden;
                 otherwise, return num_bonds x hidden.
        """
        # Input
        if self.input_layer == 'fc':
            input = self.W_i(init_messages)  # num_bonds x hidden_size # f_bond
            message = self.act_func(input)   # num_bonds x hidden_size
        elif self.input_layer == 'none':
            input = init_messages
            message = input

        attached_fea = init_attached_features  # f_atom / f_bond

        # Dynamic depth: sample the number of message passing steps around self.depth.
        # Only active during training.
        if self.training and self.dynamic_depth != "none":
            if self.dynamic_depth == "uniform":
                # uniform sampling from [depth - 3, depth + 3)
                ndepth = numpy.random.randint(self.depth - 3, self.depth + 3)
            else:
                # truncated normal centered at depth, truncated at +/- 3 sigma
                mu = self.depth
                sigma = 1
                lower = mu - 3 * sigma
                upper = mu + 3 * sigma
                X = stats.truncnorm((lower - mu) / sigma, (upper - mu) / sigma, loc=mu, scale=sigma)
                ndepth = int(X.rvs(1))
        else:
            ndepth = self.depth

        # Message passing
        for _ in range(ndepth - 1):
            if self.undirected:
                # the two directions should carry the same message
                message = (message + message[b2revb]) / 2

            nei_message = select_neighbor_and_aggregate(message, a2nei)
            a_message = nei_message
            if self.attached_fea:
                attached_nei_fea = select_neighbor_and_aggregate(attached_fea, a2attached)
                a_message = torch.cat((nei_message, attached_nei_fea), dim=1)

            if not self.atom_messages:
                rev_message = message[b2revb]
                if self.attached_fea:
                    atom_rev_message = attached_fea[b2a[b2revb]]
                    rev_message = torch.cat((rev_message, atom_rev_message), dim=1)
                # Exclude the reverse bond itself (w): \sum_{k \in N(u) \setminus \{w\}}
                message = a_message[b2a] - rev_message  # num_bonds x hidden
            else:
                message = a_message

            message = self.W_h(message)

            # BUG here: by default, MPNEncoder uses the dense connection in the message passing step.
            # The correct condition should be `if not self.dense`.
            if self.dense:
                message = self.act_func(message)  # num_bonds x hidden_size
            else:
                message = self.act_func(input + message)

            message = self.dropout_layer(message)  # num_bonds x hidden

        output = message

        return output  # num_atoms x hidden
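# Illustrative sketch (editorial addition, not part of the original module): how the "truncnorm" dynamic
# depth samples an integer depth around self.depth during training; the depth value 6 is arbitrary.
#   mu, sigma = 6, 1
#   lower, upper = mu - 3 * sigma, mu + 3 * sigma
#   X = stats.truncnorm((lower - mu) / sigma, (upper - mu) / sigma, loc=mu, scale=sigma)
#   int(X.rvs(1))   # an integer near 6; the distribution is truncated to [mu - 3*sigma, mu + 3*sigma]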


class PositionwiseFeedForward(nn.Module):
    """Implements the FFN equation."""

    def __init__(self, d_model, d_ff, activation="PReLU", dropout=0.1, d_out=None):
        """
        Initialization.

        :param d_model: the input dimension.
        :param d_ff: the hidden dimension.
        :param activation: the activation function.
        :param dropout: the dropout rate.
        :param d_out: the output dimension; the default value is equal to d_model.
        """
        super(PositionwiseFeedForward, self).__init__()
        if d_out is None:
            d_out = d_model
        # By default, bias is on.
        self.W_1 = nn.Linear(d_model, d_ff)
        self.W_2 = nn.Linear(d_ff, d_out)
        self.dropout = nn.Dropout(dropout)
        self.act_func = get_activation_function(activation)

    def forward(self, x):
        """
        The forward function.

        :param x: the input tensor.
        :return: the transformed tensor, with the last dimension equal to d_out.
        """
        return self.W_2(self.dropout(self.act_func(self.W_1(x))))
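# Illustrative usage sketch (editorial addition, not part of the original module); dimensions are arbitrary:
#   ffn = PositionwiseFeedForward(d_model=128, d_ff=512, dropout=0.1, d_out=64)
#   y = ffn(torch.randn(10, 128))   # y has shape (10, 64)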


class SublayerConnection(nn.Module):
    """
    A residual connection followed by a layer norm.
    Note: for code simplicity, the norm is applied first as opposed to last.
    """

    def __init__(self, size, dropout):
        """
        Initialization.

        :param size: the input dimension.
        :param dropout: the dropout ratio.
        """
        super(SublayerConnection, self).__init__()
        self.norm = LayerNorm(size, elementwise_affine=True)
        self.dropout = nn.Dropout(dropout)

    def forward(self, inputs, outputs):
        """Apply a residual connection to any sublayer with the same size."""
        # return x + self.dropout(self.norm(x))
        if inputs is None:
            return self.dropout(self.norm(outputs))
        return inputs + self.dropout(self.norm(outputs))
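# Illustrative usage sketch (editorial addition, not part of the original module): the residual input may
# be None, in which case only the normalized (and dropped-out) sublayer output is returned:
#   sublayer = SublayerConnection(size=128, dropout=0.1)
#   x_in, x_out = torch.randn(10, 128), torch.randn(10, 128)
#   y = sublayer(x_in, x_out)    # x_in + dropout(layernorm(x_out))
#   y = sublayer(None, x_out)    # dropout(layernorm(x_out))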


class Attention(nn.Module):
    """
    Compute scaled dot-product attention.
    """

    def forward(self, query, key, value, mask=None, dropout=None):
        """
        :param query: the query tensor, (..., seq_len, d_k).
        :param key: the key tensor, (..., seq_len, d_k).
        :param value: the value tensor, (..., seq_len, d_v).
        :param mask: an optional mask; positions where mask == 0 are ignored.
        :param dropout: an optional dropout module applied to the attention weights.
        :return: the attended values and the attention weights.
        """
        scores = torch.matmul(query, key.transpose(-2, -1)) \
                 / math.sqrt(query.size(-1))

        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        p_attn = F.softmax(scores, dim=-1)

        if dropout is not None:
            p_attn = dropout(p_attn)

        return torch.matmul(p_attn, value), p_attn
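# Illustrative usage sketch (editorial addition, not part of the original module); batch and head
# dimensions are arbitrary:
#   attn = Attention()
#   q = k = v = torch.randn(2, 4, 9, 16)        # (batch, heads, seq_len, d_k)
#   out, weights = attn(q, k, v)                # out: (2, 4, 9, 16), weights: (2, 4, 9, 9)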


class MultiHeadedAttention(nn.Module):
    """
    The multi-head attention module. Takes in the model size and the number of heads.
    """

    def __init__(self, h, d_model, dropout=0.1, bias=False):
        """
        :param h: the number of attention heads.
        :param d_model: the model (hidden) dimension; must be divisible by h.
        :param dropout: the dropout ratio.
        :param bias: enables the bias in the output linear layer.
        """
        super().__init__()
        assert d_model % h == 0

        # We assume d_v always equals d_k
        self.d_k = d_model // h
        self.h = h  # number of heads

        self.linear_layers = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(3)])  # 3 projections: query, key, value
        self.output_linear = nn.Linear(d_model, d_model, bias)
        self.attention = Attention()

        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        """
        :param query: the query tensor, (batch_size, seq_len, d_model).
        :param key: the key tensor, (batch_size, seq_len, d_model).
        :param value: the value tensor, (batch_size, seq_len, d_model).
        :param mask: an optional attention mask.
        :return: the attended output, (batch_size, seq_len, d_model).
        """
        batch_size = query.size(0)

        # 1) Do all the linear projections in batch from d_model => h x d_k.
        query, key, value = [l(x).view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
                             for l, x in zip(self.linear_layers, (query, key, value))]

        # 2) Apply attention on all the projected vectors in batch.
        x, _ = self.attention(query, key, value, mask=mask, dropout=self.dropout)

        # 3) "Concat" using a view and apply a final linear.
        x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.h * self.d_k)

        return self.output_linear(x)
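# Illustrative usage sketch (editorial addition, not part of the original module); sizes are arbitrary:
#   mha = MultiHeadedAttention(h=4, d_model=128)
#   x = torch.randn(8, 10, 128)        # (batch_size, seq_len, d_model)
#   y = mha(x, x, x)                   # self-attention; y has shape (8, 10, 128)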


class Head(nn.Module):
    """
    One head for the multi-headed attention.
    The forward function returns a (query, key, value) triple.
    """

    def __init__(self, args, hidden_size, atom_messages=False):
        """
        Initialization.

        :param args: the arguments.
        :param hidden_size: the dimension of the hidden layer in Head.
        :param atom_messages: the MPNEncoder type.
        """
        super(Head, self).__init__()
        atom_fdim = hidden_size
        bond_fdim = hidden_size
        self.atom_messages = atom_messages
        if self.atom_messages:
            init_message_dim = atom_fdim
            attached_fea_dim = bond_fdim
        else:
            init_message_dim = bond_fdim
            attached_fea_dim = atom_fdim

        # Here we use the message passing network as query, key and value.
        self.mpn_q = MPNEncoder(args=args,
                                atom_messages=atom_messages,
                                init_message_dim=init_message_dim,
                                attached_fea_fdim=attached_fea_dim,
                                hidden_size=hidden_size,
                                bias=args.bias,
                                depth=args.depth,
                                dropout=args.dropout,
                                undirected=args.undirected,
                                dense=args.dense,
                                aggregate_to_atom=False,
                                attach_fea=False,
                                input_layer="none",
                                dynamic_depth="truncnorm")
        self.mpn_k = MPNEncoder(args=args,
                                atom_messages=atom_messages,
                                init_message_dim=init_message_dim,
                                attached_fea_fdim=attached_fea_dim,
                                hidden_size=hidden_size,
                                bias=args.bias,
                                depth=args.depth,
                                dropout=args.dropout,
                                undirected=args.undirected,
                                dense=args.dense,
                                aggregate_to_atom=False,
                                attach_fea=False,
                                input_layer="none",
                                dynamic_depth="truncnorm")
        self.mpn_v = MPNEncoder(args=args,
                                atom_messages=atom_messages,
                                init_message_dim=init_message_dim,
                                attached_fea_fdim=attached_fea_dim,
                                hidden_size=hidden_size,
                                bias=args.bias,
                                depth=args.depth,
                                dropout=args.dropout,
                                undirected=args.undirected,
                                dense=args.dense,
                                aggregate_to_atom=False,
                                attach_fea=False,
                                input_layer="none",
                                dynamic_depth="truncnorm")

    def forward(self, f_atoms, f_bonds, a2b, a2a, b2a, b2revb):
        """
        The forward function.

        :param f_atoms: the atom features, num_atoms x atom_dim.
        :param f_bonds: the bond features, num_bonds x bond_dim.
        :param a2b: mapping from atom index to incoming bond indices.
        :param a2a: mapping from atom index to its neighbors, num_atoms x max_num_bonds.
        :param b2a: mapping from bond index to the index of the atom the bond comes from.
        :param b2revb: mapping from bond index to the index of the reverse bond.
        :return: the query, key and value produced by the three message passing networks.
        """
        if self.atom_messages:
            init_messages = f_atoms
            init_attached_features = f_bonds
            a2nei = a2a
            a2attached = a2b
        else:
            init_messages = f_bonds
            init_attached_features = f_atoms
            a2nei = a2b
            a2attached = a2a
        q = self.mpn_q(init_messages=init_messages,
                       init_attached_features=init_attached_features,
                       a2nei=a2nei,
                       a2attached=a2attached,
                       b2a=b2a,
                       b2revb=b2revb)
        k = self.mpn_k(init_messages=init_messages,
                       init_attached_features=init_attached_features,
                       a2nei=a2nei,
                       a2attached=a2attached,
                       b2a=b2a,
                       b2revb=b2revb)
        v = self.mpn_v(init_messages=init_messages,
                       init_attached_features=init_attached_features,
                       a2nei=a2nei,
                       a2attached=a2attached,
                       b2a=b2a,
                       b2revb=b2revb)
        return q, k, v


class MTBlock(nn.Module):
    """
    The multi-headed attention block.
    """

    def __init__(self,
                 args,
                 num_attn_head,
                 input_dim,
                 hidden_size,
                 activation="ReLU",
                 dropout=0.0,
                 bias=True,
                 atom_messages=False,
                 cuda=True,
                 res_connection=False):
        """
        :param args: the arguments.
        :param num_attn_head: the number of attention heads.
        :param input_dim: the input dimension.
        :param hidden_size: the hidden size of the model.
        :param activation: the activation function.
        :param dropout: the dropout ratio.
        :param bias: if true, all linear layers contain a bias term.
        :param atom_messages: the MPNEncoder type.
        :param cuda: if true, the model runs on the GPU.
        :param res_connection: enables the skip connection in MTBlock.
        """
        super(MTBlock, self).__init__()
        # self.args = args
        self.atom_messages = atom_messages
        self.hidden_size = hidden_size
        self.heads = nn.ModuleList()
        self.input_dim = input_dim
        self.cuda = cuda
        self.res_connection = res_connection
        self.act_func = get_activation_function(activation)
        self.dropout_layer = nn.Dropout(p=dropout)
        # Note: elementwise_affine has to be consistent with the pre-training phase.
        self.layernorm = nn.LayerNorm(self.hidden_size, elementwise_affine=True)

        self.W_i = nn.Linear(self.input_dim, self.hidden_size, bias=bias)
        self.attn = MultiHeadedAttention(h=num_attn_head,
                                         d_model=self.hidden_size,
                                         bias=bias,
                                         dropout=dropout)
        self.W_o = nn.Linear(self.hidden_size * num_attn_head, self.hidden_size, bias=bias)
        self.sublayer = SublayerConnection(self.hidden_size, dropout)
        for _ in range(num_attn_head):
            self.heads.append(Head(args, hidden_size=hidden_size, atom_messages=atom_messages))

    def forward(self, batch, features_batch=None):
        """
        :param batch: the graph batch generated by GroverCollator.
        :param features_batch: the additional features of molecules. (deprecated)
        :return: the updated batch and the (unchanged) features_batch.
        """
        f_atoms, f_bonds, a2b, b2a, b2revb, a_scope, b_scope, a2a = batch

        if self.atom_messages:
            # Only add a linear transformation to the input features.
            if f_atoms.shape[1] != self.hidden_size:
                f_atoms = self.W_i(f_atoms)
                f_atoms = self.dropout_layer(self.layernorm(self.act_func(f_atoms)))
        else:  # bond messages
            if f_bonds.shape[1] != self.hidden_size:
                f_bonds = self.W_i(f_bonds)
                f_bonds = self.dropout_layer(self.layernorm(self.act_func(f_bonds)))

        queries = []
        keys = []
        values = []
        for head in self.heads:
            q, k, v = head(f_atoms, f_bonds, a2b, a2a, b2a, b2revb)
            queries.append(q.unsqueeze(1))
            keys.append(k.unsqueeze(1))
            values.append(v.unsqueeze(1))

        queries = torch.cat(queries, dim=1)
        keys = torch.cat(keys, dim=1)
        values = torch.cat(values, dim=1)

        x_out = self.attn(queries, keys, values)  # multi-headed attention

        x_out = x_out.view(x_out.shape[0], -1)
        x_out = self.W_o(x_out)

        x_in = None
        # Support no residual connection in MTBlock.
        if self.res_connection:
            if self.atom_messages:
                x_in = f_atoms
            else:
                x_in = f_bonds

        if self.atom_messages:
            f_atoms = self.sublayer(x_in, x_out)
        else:
            f_bonds = self.sublayer(x_in, x_out)

        batch = f_atoms, f_bonds, a2b, b2a, b2revb, a_scope, b_scope, a2a
        return batch, features_batch
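# Shape-flow sketch (editorial addition, not part of the original module) for the bond-message case with
# H attention heads: each Head produces q, k, v of shape (num_bonds, hidden_size); after unsqueeze/cat
# they become (num_bonds, H, hidden_size), MultiHeadedAttention returns the same shape, the view flattens
# it to (num_bonds, H * hidden_size), and W_o maps it back to (num_bonds, hidden_size).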


class GTransEncoder(nn.Module):
    """The graph transformer encoder built from MTBlocks."""

    def __init__(self,
                 args,
                 hidden_size,
                 edge_fdim,
                 node_fdim,
                 dropout=0.0,
                 activation="ReLU",
                 num_mt_block=1,
                 num_attn_head=4,
                 atom_emb_output: Union[bool, str] = False,  # options: True, False, None, "atom", "bond", "both"
                 bias=False,
                 cuda=True,
                 res_connection=False):
        """
        :param args: the arguments.
        :param hidden_size: the hidden size of the model.
        :param edge_fdim: the dimension of the additional features for edges/bonds.
        :param node_fdim: the dimension of the additional features for nodes/atoms.
        :param dropout: the dropout ratio.
        :param activation: the activation function.
        :param num_mt_block: the number of MT blocks.
        :param num_attn_head: the number of attention heads.
        :param atom_emb_output: enables the output aggregation after message passing.
            The two output sizes below correspond to atom_messages=True / atom_messages=False:
            - False:  no aggregation to atoms.      (num_atoms, hidden_size) / (num_bonds, hidden_size)
            - True:   aggregation to atoms.         (num_atoms, hidden_size) / (num_atoms, hidden_size)
            - None:   same as False
            - "atom": same as True
            - "bond": aggregation to bonds.         (num_bonds, hidden_size) / (num_bonds, hidden_size)
            - "both": aggregation to atoms & bonds. (num_atoms, hidden_size), (num_bonds, hidden_size) /
                                                    (num_bonds, hidden_size), (num_atoms, hidden_size)
        :param bias: enables the bias term in all linear layers.
        :param cuda: run with cuda.
        :param res_connection: enables the skip connection in MTBlock.
        """
        super(GTransEncoder, self).__init__()

        # For compatibility.
        if atom_emb_output is False:
            atom_emb_output = None
        if atom_emb_output is True:
            atom_emb_output = 'atom'

        self.hidden_size = hidden_size
        self.dropout = dropout
        self.activation = activation
        self.cuda = cuda
        self.bias = bias
        self.res_connection = res_connection
        self.edge_blocks = nn.ModuleList()
        self.node_blocks = nn.ModuleList()

        edge_input_dim = edge_fdim
        node_input_dim = node_fdim
        edge_input_dim_i = edge_input_dim
        node_input_dim_i = node_input_dim

        for i in range(num_mt_block):
            if i != 0:
                edge_input_dim_i = self.hidden_size
                node_input_dim_i = self.hidden_size
            self.edge_blocks.append(MTBlock(args=args,
                                            num_attn_head=num_attn_head,
                                            input_dim=edge_input_dim_i,
                                            hidden_size=self.hidden_size,
                                            activation=activation,
                                            dropout=dropout,
                                            bias=self.bias,
                                            atom_messages=False,
                                            cuda=cuda))
            self.node_blocks.append(MTBlock(args=args,
                                            num_attn_head=num_attn_head,
                                            input_dim=node_input_dim_i,
                                            hidden_size=self.hidden_size,
                                            activation=activation,
                                            dropout=dropout,
                                            bias=self.bias,
                                            atom_messages=True,
                                            cuda=cuda))

        self.atom_emb_output = atom_emb_output

        self.ffn_atom_from_atom = PositionwiseFeedForward(self.hidden_size + node_fdim,
                                                          self.hidden_size * 4,
                                                          activation=self.activation,
                                                          dropout=self.dropout,
                                                          d_out=self.hidden_size)

        self.ffn_atom_from_bond = PositionwiseFeedForward(self.hidden_size + node_fdim,
                                                          self.hidden_size * 4,
                                                          activation=self.activation,
                                                          dropout=self.dropout,
                                                          d_out=self.hidden_size)

        self.ffn_bond_from_atom = PositionwiseFeedForward(self.hidden_size + edge_fdim,
                                                          self.hidden_size * 4,
                                                          activation=self.activation,
                                                          dropout=self.dropout,
                                                          d_out=self.hidden_size)

        self.ffn_bond_from_bond = PositionwiseFeedForward(self.hidden_size + edge_fdim,
                                                          self.hidden_size * 4,
                                                          activation=self.activation,
                                                          dropout=self.dropout,
                                                          d_out=self.hidden_size)

        self.atom_from_atom_sublayer = SublayerConnection(size=self.hidden_size, dropout=self.dropout)
        self.atom_from_bond_sublayer = SublayerConnection(size=self.hidden_size, dropout=self.dropout)
        self.bond_from_atom_sublayer = SublayerConnection(size=self.hidden_size, dropout=self.dropout)
        self.bond_from_bond_sublayer = SublayerConnection(size=self.hidden_size, dropout=self.dropout)

        self.act_func_node = get_activation_function(self.activation)
        self.act_func_edge = get_activation_function(self.activation)

        self.dropout_layer = nn.Dropout(p=args.dropout)

    def pointwise_feed_forward_to_atom_embedding(self, emb_output, atom_fea, index, ffn_layer):
        """
        The point-wise feed forward and long-range residual connection for the atom view:
        aggregate to atoms.

        :param emb_output: the output embedding from the previous multi-head attention.
        :param atom_fea: the atom/node feature embedding.
        :param index: the index of the neighborhood relations.
        :param ffn_layer: the feed forward layer.
        :return: the feed-forward output and the aggregated neighbor embedding.
        """
        aggr_output = select_neighbor_and_aggregate(emb_output, index)
        aggr_outputx = torch.cat([atom_fea, aggr_output], dim=1)
        return ffn_layer(aggr_outputx), aggr_output

    def pointwise_feed_forward_to_bond_embedding(self, emb_output, bond_fea, a2nei, b2revb, ffn_layer):
        """
        The point-wise feed forward and long-range residual connection for the bond view:
        aggregate to bonds.

        :param emb_output: the output embedding from the previous multi-head attention.
        :param bond_fea: the bond/edge feature embedding.
        :param a2nei: the index of the neighborhood relations.
        :param b2revb: mapping from bond index to the index of the reverse bond.
        :param ffn_layer: the feed forward layer.
        :return: the feed-forward output and the aggregated neighbor embedding.
        """
        aggr_output = select_neighbor_and_aggregate(emb_output, a2nei)
        # Remove the reverse bond/atom message -- needed for the bond view.
        aggr_output = self.remove_rev_bond_message(emb_output, aggr_output, b2revb)
        aggr_outputx = torch.cat([bond_fea, aggr_output], dim=1)
        return ffn_layer(aggr_outputx), aggr_output

    @staticmethod
    def remove_rev_bond_message(original_message, aggr_message, b2revb):
        """
        Remove the reverse-bond contribution from the aggregated message.

        :param original_message: the per-bond messages before aggregation.
        :param aggr_message: the aggregated neighbor messages.
        :param b2revb: mapping from bond index to the index of the reverse bond.
        :return: the aggregated messages without the reverse-bond contribution.
        """
        rev_message = original_message[b2revb]
        return aggr_message - rev_message

    def atom_bond_transform(self,
                            to_atom=True,  # False: to bond
                            atomwise_input=None,
                            bondwise_input=None,
                            original_f_atoms=None,
                            original_f_bonds=None,
                            a2a=None,
                            a2b=None,
                            b2a=None,
                            b2revb=None
                            ):
        """
        Transform the outputs of the atom/bond multi-head attention into the final atom/bond outputs.

        :param to_atom: if true, the output is the atom embedding; otherwise, the output is the bond embedding.
        :param atomwise_input: the input embedding of atoms/nodes.
        :param bondwise_input: the input embedding of bonds/edges.
        :param original_f_atoms: the initial atom features.
        :param original_f_bonds: the initial bond features.
        :param a2a: mapping from atom index to its neighbors, num_atoms x max_num_bonds.
        :param a2b: mapping from atom index to incoming bond indices.
        :param b2a: mapping from bond index to the index of the atom the bond comes from.
        :param b2revb: mapping from bond index to the index of the reverse bond.
        :return: a pair of embeddings in the requested view.
        """
        if to_atom:
            # atom input to atom output
            atomwise_input, _ = self.pointwise_feed_forward_to_atom_embedding(atomwise_input, original_f_atoms, a2a,
                                                                              self.ffn_atom_from_atom)
            atom_in_atom_out = self.atom_from_atom_sublayer(None, atomwise_input)
            # bond input to atom output
            bondwise_input, _ = self.pointwise_feed_forward_to_atom_embedding(bondwise_input, original_f_atoms, a2b,
                                                                              self.ffn_atom_from_bond)
            bond_in_atom_out = self.atom_from_bond_sublayer(None, bondwise_input)
            return atom_in_atom_out, bond_in_atom_out
        else:  # to bond embeddings
            # atom input to bond output
            atom_list_for_bond = torch.cat([b2a.unsqueeze(dim=1), a2a[b2a]], dim=1)
            atomwise_input, _ = self.pointwise_feed_forward_to_bond_embedding(atomwise_input, original_f_bonds,
                                                                              atom_list_for_bond,
                                                                              b2a[b2revb], self.ffn_bond_from_atom)
            atom_in_bond_out = self.bond_from_atom_sublayer(None, atomwise_input)
            # bond input to bond output
            bond_list_for_bond = a2b[b2a]
            bondwise_input, _ = self.pointwise_feed_forward_to_bond_embedding(bondwise_input, original_f_bonds,
                                                                              bond_list_for_bond,
                                                                              b2revb, self.ffn_bond_from_bond)
            bond_in_bond_out = self.bond_from_bond_sublayer(None, bondwise_input)
            return atom_in_bond_out, bond_in_bond_out

    def forward(self, batch, features_batch=None):
        """
        The forward function.

        :param batch: the graph batch generated by GroverCollator.
        :param features_batch: the additional features of molecules. (not used here)
        :return: the atom/bond embeddings, in the format selected by atom_emb_output.
        """
        f_atoms, f_bonds, a2b, b2a, b2revb, a_scope, b_scope, a2a = batch
        if self.cuda or next(self.parameters()).is_cuda:
            f_atoms, f_bonds, a2b, b2a, b2revb = f_atoms.cuda(), f_bonds.cuda(), a2b.cuda(), b2a.cuda(), b2revb.cuda()
            a2a = a2a.cuda()

        node_batch = f_atoms, f_bonds, a2b, b2a, b2revb, a_scope, b_scope, a2a
        edge_batch = f_atoms, f_bonds, a2b, b2a, b2revb, a_scope, b_scope, a2a

        # opt pointwise_feed_forward
        original_f_atoms, original_f_bonds = f_atoms, f_bonds

        # Note: features_batch is not used here.
        for nb in self.node_blocks:  # atom messages. Multi-headed attention
            node_batch, features_batch = nb(node_batch, features_batch)
        for eb in self.edge_blocks:  # bond messages. Multi-headed attention
            edge_batch, features_batch = eb(edge_batch, features_batch)

        atom_output, _, _, _, _, _, _, _ = node_batch  # atom hidden states
        _, bond_output, _, _, _, _, _, _ = edge_batch  # bond hidden states

        if self.atom_emb_output is None:
            # Output the embeddings from the multi-head attention directly.
            return atom_output, bond_output

        if self.atom_emb_output == 'atom':
            return self.atom_bond_transform(to_atom=True,  # False: to bond
                                            atomwise_input=atom_output,
                                            bondwise_input=bond_output,
                                            original_f_atoms=original_f_atoms,
                                            original_f_bonds=original_f_bonds,
                                            a2a=a2a,
                                            a2b=a2b,
                                            b2a=b2a,
                                            b2revb=b2revb)
        elif self.atom_emb_output == 'bond':
            return self.atom_bond_transform(to_atom=False,  # False: to bond
                                            atomwise_input=atom_output,
                                            bondwise_input=bond_output,
                                            original_f_atoms=original_f_atoms,
                                            original_f_bonds=original_f_bonds,
                                            a2a=a2a,
                                            a2b=a2b,
                                            b2a=b2a,
                                            b2revb=b2revb)
        else:  # 'both'
            atom_embeddings = self.atom_bond_transform(to_atom=True,  # False: to bond
                                                       atomwise_input=atom_output,
                                                       bondwise_input=bond_output,
                                                       original_f_atoms=original_f_atoms,
                                                       original_f_bonds=original_f_bonds,
                                                       a2a=a2a,
                                                       a2b=a2b,
                                                       b2a=b2a,
                                                       b2revb=b2revb)

            bond_embeddings = self.atom_bond_transform(to_atom=False,  # False: to bond
                                                       atomwise_input=atom_output,
                                                       bondwise_input=bond_output,
                                                       original_f_atoms=original_f_atoms,
                                                       original_f_bonds=original_f_bonds,
                                                       a2a=a2a,
                                                       a2b=a2b,
                                                       b2a=b2a,
                                                       b2revb=b2revb)
            # Notice: needs to be consistent with the output format of the DualMPNN encoder.
            return ((atom_embeddings[0], bond_embeddings[0]),
                    (atom_embeddings[1], bond_embeddings[1]))
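# Output-format note (editorial addition, not part of the original module): depending on atom_emb_output,
# GTransEncoder.forward returns
#   None / False -> (atom_output, bond_output)
#   "atom"/True  -> (atom_in_atom_out, bond_in_atom_out)
#   "bond"       -> (atom_in_bond_out, bond_in_bond_out)
#   "both"       -> ((atom_in_atom_out, atom_in_bond_out), (bond_in_atom_out, bond_in_bond_out))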