
Commit 85ca918

vector gating + atom3d
1 parent 5263b49 commit 85ca918

7 files changed: +1057 -23 lines changed


README.md

+94 -6
@@ -2,9 +2,11 @@
 
 Implementation of equivariant GVP-GNNs as described in [Learning from Protein Structure with Geometric Vector Perceptrons](https://openreview.net/forum?id=1YLJDvSx6J4) by B Jing, S Eismann, P Suriana, RJL Townshend, and RO Dror.
 
-This repository serves two purposes. If you would like to use the GVP architecture for structural biology tasks, we provide building blocks for models and data pipelines. If you are specifically interested in protein design as described in the paper, we provide scripts for training and testing models.
+**UPDATE:** Also includes equivariant GNNs with vector gating as described in [Equivariant Graph Neural Networks for 3D Macromolecular Structure](https://arxiv.org/abs/2106.03843) by B Jing, S Eismann, P Soni, and RO Dror.
 
-**Note:** This repository is an implementation in PyTorch Geometric emphasizing usability and flexibility. The original code for the paper, in TensorFlow, can be found [here](https://github.com/drorlab/gvp). We thank Pratham Soni for his contributions to the implementation in PyTorch.
+Scripts for training/testing/sampling on protein design and training/testing on all [ATOM3D](https://arxiv.org/abs/2012.04035) tasks are provided.
+
+**Note:** This implementation is in PyTorch Geometric. The original TensorFlow code, which is not maintained, can be found [here](https://github.com/drorlab/gvp).
 
 <p align="center"><img src="schematic.png" width="500"></p>
 
@@ -20,15 +22,17 @@ This repository serves two purposes. If you would like to use the GVP architectu
 * tqdm==4.38.0
 * numpy==1.19.4
 * sklearn==0.24.1
+* atom3d==0.2.1
 
 While we have not tested with other versions, any reasonably recent versions of these requirements should work.
 
 ## General usage
 
 We provide classes in three modules:
 * `gvp`: core GVP modules and GVP-GNN layers
-* `gvp.data`: data pipeline functionality for both general use and protein design
-* `gvp.models`: implementations of MQA and CPD models as described in the paper
+* `gvp.data`: data pipelines for both general use and protein design
+* `gvp.models`: implementations of MQA and CPD models
+* `gvp.atom3d`: models and data pipelines for ATOM3D
 
 The core modules in `gvp` are meant to be as general as possible, but you will likely have to modify `gvp.data` and `gvp.models` for your specific application, with the existing classes serving as examples.
 
@@ -52,6 +56,11 @@ in_dims = scalars_in, vectors_in
 out_dims = scalars_out, vectors_out
 gvp_ = gvp.GVP(in_dims, out_dims)
 ```
+To use vector gating, pass in `vector_gate=True` and the appropriate activations.
+```
+gvp_ = gvp.GVP(in_dims, out_dims,
+               activations=(F.relu, None), vector_gate=True)
+```
 The classes `gvp.Dropout` and `gvp.LayerNorm` implement vector-channel dropout and layer norm, while using normal dropout and layer norm for scalar channels. Both expect inputs and return outputs of form `(s, V)`, but will also behave like their scalar-valued counterparts if passed a single tensor.
 ```
 dropout = gvp.Dropout(drop_rate=0.1)
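The diff truncates the original usage example at this point. As a purely illustrative sketch of the tuple-in/tuple-out behavior described above (assuming `s`, `V` are a scalar/vector feature pair with `out_dims` channels, as in the surrounding README examples):
```
layernorm = gvp.LayerNorm(out_dims)
x = dropout((s, V))      # (s, V) tuple in, (s, V) tuple out
x = layernorm(x)
s_only = dropout(s)      # a single tensor behaves like ordinary dropout
```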
@@ -86,7 +95,7 @@ edge_index = torch.randint(0, 5, (2, 10), device=device)
 conv = gvp.GVPConv(in_dims, out_dims, edge_dims)
 out = conv(nodes, edge_index, edges)
 ```
-The class GVPConvLayer is a `nn.Module` that forms messages using a `GVPConv` and updates the node embeddings as described in the paper. Because the updates are residual, the dimensionality of the embeddings are not changed.
+The class `GVPConvLayer` is a `nn.Module` that forms messages using a `GVPConv` and updates the node embeddings as described in the paper. Because the updates are residual, the dimensionality of the embeddings is not changed.
 ```
 layer = gvp.GVPConvLayer(node_dims, edge_dims)
 nodes = layer(nodes, edge_index, edges)
@@ -97,6 +106,8 @@ nodes_static = gvp.randn(n=5, in_dims)
 layer = gvp.GVPConvLayer(node_dims, edge_dims, autoregressive=True)
 nodes = layer(nodes, edge_index, edges, autoregressive_x=nodes_static)
 ```
+Both `GVPConv` and `GVPConvLayer` accept arguments `activations` and `vector_gate` to use vector gating.
+
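To illustrate those arguments, a minimal sketch reusing `node_dims`, `edge_dims`, and the graph tensors from the examples above (the activation choice mirrors the earlier vector-gated GVP example; adjust as needed):
```
layer = gvp.GVPConvLayer(node_dims, edge_dims,
                         activations=(F.relu, None), vector_gate=True)
nodes = layer(nodes, edge_index, edges)
```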
 ### Loading data
 
 The class `gvp.data.ProteinGraphDataset` transforms protein backbone structures into featurized graphs. Following [Ingraham, et al, NeurIPS 2019](https://github.com/jingraham/neurips19-graph-protein-design), we use a JSON/dictionary format to specify backbone structures:
@@ -204,6 +215,76 @@ sample = model.sample(nodes, protein.edge_index, # shape = (n_samples, n_nodes)
 ```
 The output will be an int tensor, with mappings corresponding to those used when training the model.
 
+## ATOM3D
+We provide models and dataloaders for all ATOM3D tasks in `gvp.atom3d`, as well as a training and testing script in `run_atom3d.py`. This also supports loading pretrained weights for transfer learning experiments.
+
+### Models / data loaders
+The GVP-GNNs for ATOM3D are supplied in `gvp.atom3d` and are named after each task: `gvp.atom3d.MSPModel`, `gvp.atom3d.PPIModel`, etc. All of these extend the base class `gvp.atom3d.BaseModel`. These classes take no arguments at initialization, take in a `torch_geometric.data.Batch` representation of a batch of structures, and return an output corresponding to the task. Details vary based on the exact task---see the docstrings.
+```
+psr_model = gvp.atom3d.PSRModel()
+```
+`gvp.atom3d` also includes data loaders to produce `torch_geometric.data.Batch` objects from an underlying `atom3d.datasets.LMDBDataset`. For all tasks except PPI and RES, these are callable transform objects---`gvp.atom3d.SMPTransform`, `gvp.atom3d.RSRTransform`, etc---which should be passed into the constructor of an `atom3d.datasets.LMDBDataset`:
+```
+psr_dataset = atom3d.datasets.LMDBDataset(path_to_dataset,
+                transform=gvp.atom3d.PSRTransform())
+```
+On the other hand, `gvp.atom3d.PPIDataset` and `gvp.atom3d.RESDataset` are wrappers around, and take the place of, the `atom3d.datasets.LMDBDataset`:
+```
+ppi_dataset = gvp.atom3d.PPIDataset(path_to_dataset)
+res_dataset = gvp.atom3d.RESDataset(path_to_dataset, path_to_split) # see docstring
+```
+All datasets must then be wrapped in a `torch_geometric.data.DataLoader`:
+```
+psr_dataloader = torch_geometric.data.DataLoader(psr_dataset, batch_size=batch_size)
+```
+The dataloaders can be directly iterated over to yield `torch_geometric.data.Batch` objects, which can then be passed into the models.
+```
+for batch in psr_dataloader:
+    pred = psr_model(batch) # pred.shape = (batch_size,)
+```
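For orientation, a minimal training-step sketch built on the loop above. The target attribute name (`batch.label`) and the MSE objective are illustrative assumptions; the real labels and objectives for each task live in the dataset docstrings and `run_atom3d.py`:
```
import torch
import torch.nn.functional as F

optimizer = torch.optim.Adam(psr_model.parameters(), lr=1e-4)

for batch in psr_dataloader:
    pred = psr_model(batch)               # shape = (batch_size,)
    loss = F.mse_loss(pred, batch.label)  # hypothetical target attribute
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```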
+
+### Training / testing
+
+To run training / testing on ATOM3D, download the datasets as described [here](https://www.atom3d.ai/). Modify the function `get_datasets` in `run_atom3d.py` with the paths to the datasets. Then run:
+```
+$ python run_atom3d.py -h
+
+usage: run_atom3d.py [-h] [--num-workers N] [--smp-idx IDX]
+                     [--lba-split SPLIT] [--batch SIZE] [--train-time MINUTES]
+                     [--val-time MINUTES] [--epochs N] [--test PATH]
+                     [--lr RATE] [--load PATH]
+                     TASK
+
+positional arguments:
+  TASK                  {PSR, RSR, PPI, RES, MSP, SMP, LBA, LEP}
+
+optional arguments:
+  -h, --help            show this help message and exit
+  --num-workers N       number of threads for loading data, default=4
+  --smp-idx IDX         label index for SMP, in range 0-19
+  --lba-split SPLIT     identity cutoff for LBA, 30 (default) or 60
+  --batch SIZE          batch size, default=8
+  --train-time MINUTES  maximum time between evaluations on valset,
+                        default=120 minutes
+  --val-time MINUTES    maximum time per evaluation on valset, default=20
+                        minutes
+  --epochs N            training epochs, default=50
+  --test PATH           evaluate a trained model
+  --lr RATE             learning rate
+  --load PATH           initialize first 2 GNN layers with pretrained weights
+```
+For example:
+```
+# train a model
+python run_atom3d.py PSR
+
+# train a model with pretrained weights
+python run_atom3d.py PSR --load PATH
+
+# evaluate a model
+python run_atom3d.py PSR --test PATH
+```
+
 ## Acknowledgements
 Portions of the input data pipeline were adapted from [Ingraham, et al, NeurIPS 2019](https://github.com/jingraham/neurips19-graph-protein-design). We thank Pratham Soni for portions of the implementation in PyTorch.
 
@@ -217,4 +298,11 @@ Portions of the input data pipeline were adapted from [Ingraham, et al, NeurIPS
 year={2021},
 url={https://openreview.net/forum?id=1YLJDvSx6J4}
 }
-```
+
+@article{jing2021equivariant,
+    title={Equivariant Graph Neural Networks for 3D Macromolecular Structure},
+    author={Jing, Bowen and Eismann, Stephan and Soni, Pratham N and Dror, Ron O},
+    journal={arXiv preprint arXiv:2106.03843},
+    year={2021}
+}
+```

gvp/__init__.py

+39 -15
@@ -1,4 +1,4 @@
-import torch
+import torch, functools
 from torch import nn
 import torch.nn.functional as F
 from torch_geometric.nn import MessagePassing
@@ -85,18 +85,22 @@ class GVP(nn.Module):
     :param out_dims: tuple (n_scalar, n_vector)
     :param h_dim: intermediate number of vector channels, optional
     :param activations: tuple of functions (scalar_act, vector_act)
+    :param vector_gate: whether to use vector gating.
+                        (vector_act will be used as sigma^+ in vector gating if `True`)
     '''
     def __init__(self, in_dims, out_dims, h_dim=None,
-                 activations=(F.relu, torch.sigmoid)):
+                 activations=(F.relu, torch.sigmoid), vector_gate=False):
         super(GVP, self).__init__()
         self.si, self.vi = in_dims
         self.so, self.vo = out_dims
+        self.vector_gate = vector_gate
         if self.vi:
             self.h_dim = h_dim or max(self.vi, self.vo)
             self.wh = nn.Linear(self.vi, self.h_dim, bias=False)
             self.ws = nn.Linear(self.h_dim + self.si, self.so)
             if self.vo:
                 self.wv = nn.Linear(self.h_dim, self.vo, bias=False)
+                if self.vector_gate: self.wsv = nn.Linear(self.so, self.vo)
         else:
             self.ws = nn.Linear(self.si, self.so)
 
@@ -119,7 +123,13 @@ def forward(self, x):
         if self.vo:
             v = self.wv(vh)
             v = torch.transpose(v, -1, -2)
-            if self.vector_act:
+            if self.vector_gate:
+                if self.vector_act:
+                    gate = self.wsv(self.vector_act(s))
+                else:
+                    gate = self.wsv(s)
+                v = v * torch.sigmoid(gate).unsqueeze(-1)
+            elif self.vector_act:
                 v = v * self.vector_act(
                     _norm_no_nan(v, axis=-1, keepdims=True))
         else:
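In equation form, the gating introduced in this hunk rescales each output vector channel by a sigmoid gate computed from the scalar output; a sketch of what the code above computes, with `sigma^+` denoting `vector_act` (applied only when one is provided) and `W_{sv}` corresponding to `self.wsv`:
```
V' \leftarrow V' \odot \sigma\left( W_{sv}\, \sigma^{+}(s') \right)
```
The sigmoid gate is broadcast across the spatial dimension of each vector channel, so the operation remains rotation-equivariant.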
@@ -214,28 +224,35 @@ class GVPConv(MessagePassing):
     :param n_layers: number of GVPs in the message function
     :param module_list: preconstructed message function, overrides n_layers
     :param aggr: should be "add" if some incoming edges are masked, as in
-                 a masked autoregressive decoder architecture
+                 a masked autoregressive decoder architecture, otherwise "mean"
+    :param activations: tuple of functions (scalar_act, vector_act) to use in GVPs
+    :param vector_gate: whether to use vector gating.
+                        (vector_act will be used as sigma^+ in vector gating if `True`)
     '''
     def __init__(self, in_dims, out_dims, edge_dims,
-                 n_layers=3, module_list=None, aggr="mean"):
+                 n_layers=3, module_list=None, aggr="mean",
+                 activations=(F.relu, torch.sigmoid), vector_gate=False):
         super(GVPConv, self).__init__(aggr=aggr)
         self.si, self.vi = in_dims
         self.so, self.vo = out_dims
         self.se, self.ve = edge_dims
 
+        GVP_ = functools.partial(GVP,
+                activations=activations, vector_gate=vector_gate)
+
         module_list = module_list or []
         if not module_list:
             if n_layers == 1:
                 module_list.append(
-                    GVP((2*self.si + self.se, 2*self.vi + self.ve),
+                    GVP_((2*self.si + self.se, 2*self.vi + self.ve),
                         (self.so, self.vo), activations=(None, None)))
             else:
                 module_list.append(
-                    GVP((2*self.si + self.se, 2*self.vi + self.ve), out_dims)
+                    GVP_((2*self.si + self.se, 2*self.vi + self.ve), out_dims)
                 )
                 for i in range(n_layers - 2):
-                    module_list.append(GVP(out_dims, out_dims))
-                module_list.append(GVP(out_dims, out_dims,
+                    module_list.append(GVP_(out_dims, out_dims))
+                module_list.append(GVP_(out_dims, out_dims,
                                        activations=(None, None)))
         self.message_func = nn.Sequential(*module_list)
 
@@ -276,26 +293,33 @@ class GVPConvLayer(nn.Module):
     :param autoregressive: if `True`, this `GVPConvLayer` will be used
                            with a different set of input node embeddings for messages
                            where src >= dst
+    :param activations: tuple of functions (scalar_act, vector_act) to use in GVPs
+    :param vector_gate: whether to use vector gating.
+                        (vector_act will be used as sigma^+ in vector gating if `True`)
     '''
     def __init__(self, node_dims, edge_dims,
                  n_message=3, n_feedforward=2, drop_rate=.1,
-                 autoregressive=False):
+                 autoregressive=False,
+                 activations=(F.relu, torch.sigmoid), vector_gate=False):
 
         super(GVPConvLayer, self).__init__()
         self.conv = GVPConv(node_dims, node_dims, edge_dims, n_message,
-                            aggr="add" if autoregressive else "mean")
+                            aggr="add" if autoregressive else "mean",
+                            activations=activations, vector_gate=vector_gate)
+        GVP_ = functools.partial(GVP,
+                activations=activations, vector_gate=vector_gate)
         self.norm = nn.ModuleList([LayerNorm(node_dims) for _ in range(2)])
         self.dropout = nn.ModuleList([Dropout(drop_rate) for _ in range(2)])
 
         ff_func = []
         if n_feedforward == 1:
-            ff_func.append(GVP(node_dims, node_dims, activations=(None, None)))
+            ff_func.append(GVP_(node_dims, node_dims, activations=(None, None)))
         else:
             hid_dims = 4*node_dims[0], 2*node_dims[1]
-            ff_func.append(GVP(node_dims, hid_dims))
+            ff_func.append(GVP_(node_dims, hid_dims))
             for i in range(n_feedforward-2):
-                ff_func.append(GVP(hid_dims, hid_dims))
-            ff_func.append(GVP(hid_dims, node_dims, activations=(None, None)))
+                ff_func.append(GVP_(hid_dims, hid_dims))
+            ff_func.append(GVP_(hid_dims, node_dims, activations=(None, None)))
         self.ff_func = nn.Sequential(*ff_func)
 
     def forward(self, x, edge_index, edge_attr,
