Commit 364b944

Document batched prediction API
1 parent 78b82da commit 364b944

3 files changed: +38 -15 lines changed

README.md

Lines changed: 24 additions & 14 deletions
@@ -35,17 +35,22 @@ This is the official codebase of the paper

 ## Contents

-- [Installation](#installation)
-- [How to prepare data for FlowDock](#how-to-prepare-data-for-flowdock)
-- [How to train FlowDock](#how-to-train-flowdock)
-- [How to evaluate FlowDock](#how-to-evaluate-flowdock)
-- [How to create comparative plots of evaluation results](#how-to-create-comparative-plots-of-evaluation-results)
-- [How to predict new protein-ligand complex structures and their affinities using FlowDock](#how-to-predict-new-protein-ligand-complex-structures-using-flowdock)
-- [For developers](#for-developers)
-- [Docker](#docker)
-- [Acknowledgements](#acknowledgements)
-- [License](#license)
-- [Citing this work](#citing-this-work)
+- [FlowDock](#flowdock)
+- [Description](#description)
+- [Contents](#contents)
+- [Installation](#installation)
+- [How to prepare data for `FlowDock`](#how-to-prepare-data-for-flowdock)
+- [Generating ESM2 embeddings for each protein (optional, cached input data available on SharePoint)](#generating-esm2-embeddings-for-each-protein-optional-cached-input-data-available-on-sharepoint)
+- [Predicting apo protein structures using ESMFold (optional, cached data available on Zenodo)](#predicting-apo-protein-structures-using-esmfold-optional-cached-data-available-on-zenodo)
+- [How to train `FlowDock`](#how-to-train-flowdock)
+- [How to evaluate `FlowDock`](#how-to-evaluate-flowdock)
+- [How to create comparative plots of benchmarking results](#how-to-create-comparative-plots-of-benchmarking-results)
+- [How to predict new protein-ligand complex structures and their affinities using `FlowDock`](#how-to-predict-new-protein-ligand-complex-structures-and-their-affinities-using-flowdock)
+- [For developers](#for-developers)
+- [Docker](#docker)
+- [Acknowledgements](#acknowledgements)
+- [License](#license)
+- [Citing this work](#citing-this-work)

 ## Installation

@@ -359,6 +364,14 @@ python flowdock/sample.py ckpt_path=checkpoints/esmfold_prior_paper_weights_EMA.

 If you do not already have a template protein structure available for your target of interest, set `input_template=null` to instead have the sampling script predict the ESMFold structure of your provided `input_protein` sequence before running the sampling pipeline. For more information regarding the input arguments available for sampling, please refer to the config at `configs/sample.yaml`.

+**NOTE:** To optimize prediction runtimes, a `csv_path` can be specified instead of the `input_receptor`, `input_ligand`, and `input_template` CLI arguments to perform *batched* prediction for a collection of protein-ligand sequence pairs, each represented as a CSV row containing column values for `id`, `input_receptor`, `input_ligand`, and `input_template`. Additionally, disabling `visualize_sample_trajectories` may reduce storage requirements when predicting a large batch of inputs.
+
+For instance, one can perform batched prediction as follows:
+
+```bash
+python flowdock/sample.py ckpt_path=checkpoints/esmfold_prior_paper_weights_EMA.ckpt model.cfg.prior_type=esmfold sampling_task=batched_structure_sampling csv_path='./data/test_cases/prediction_inputs/flowdock_batched_inputs.csv' out_path='./T1152_batch_sampled_structures/' n_samples=5 chunk_size=5 num_steps=40 sampler=VDODE sampler_eta=1.0 start_time='1.0' use_template=true separate_pdb=true visualize_sample_trajectories=false auxiliary_estimation_only=false esmfold_chunk_size=null trainer=gpu
+```
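
As a rough sketch of how such a CSV might be assembled, the snippet below writes one row per ligand with Python's standard `csv` module. The four column names come from the note above; the helper name `write_batched_inputs`, its parameters, and the `<prefix>_<n>` id scheme are assumptions made for illustration and are not part of FlowDock.

```python
import csv
from pathlib import Path

# Column names as listed in the README note above.
FIELDNAMES = ["id", "input_receptor", "input_ligand", "input_template"]


def write_batched_inputs(csv_path, receptor, template, ligands, prefix):
    """Write one CSV row per ligand SMILES string.

    Hypothetical helper for illustration only; it is not a FlowDock function.
    """
    csv_path = Path(csv_path)
    # Create the parent directory so the file can be written anywhere.
    csv_path.parent.mkdir(parents=True, exist_ok=True)
    with csv_path.open("w", newline="") as handle:
        writer = csv.DictWriter(handle, fieldnames=FIELDNAMES)
        writer.writeheader()
        for index, smiles in enumerate(ligands, start=1):
            writer.writerow(
                {
                    "id": f"{prefix}_{index}",  # assumed id scheme
                    "input_receptor": receptor,
                    "input_ligand": smiles,
                    "input_template": template,
                }
            )
    return csv_path
```

The resulting file path would then be passed as `csv_path=...` on the command line shown above.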
+
 </details>

 ## For developers
@@ -395,8 +408,6 @@ Given that this tool has a number of dependencies, it may be easier to run it in

 Pull from [Docker Hub](https://hub.docker.com/repository/docker/cford38/flowdock): `docker pull cford38/flowdock:latest`

-
-
 Alternatively, build the Docker image locally:

 ```bash
@@ -413,7 +424,6 @@ docker run --gpus all -v ./checkpoints:/software/flowdock/checkpoints --rm --nam

 </details>

-
 ## Acknowledgements

 `FlowDock` builds upon the source code and data from the following projects:
Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+id,input_receptor,input_ligand,input_template
+T1152_1,data/test_cases/predicted_structures/T1152.pdb,CC(C)C1=CC=C(C=C1)C(=O)O,data/test_cases/predicted_structures/T1152.pdb
+T1152_2,data/test_cases/predicted_structures/T1152.pdb,NC(=O)C1=CC=C(C=C1)C(=O)O,data/test_cases/predicted_structures/T1152.pdb
+T1152_3,data/test_cases/predicted_structures/T1152.pdb,CC(C)C1=CC=C(C=C1)C(=O)C,data/test_cases/predicted_structures/T1152.pdb
+T1152_4,data/test_cases/predicted_structures/T1152.pdb,CC(=O)C1=CC=C(C=C1)C(=O)O,data/test_cases/predicted_structures/T1152.pdb
+T1152_5,data/test_cases/predicted_structures/T1152.pdb,NC(C)C1=CC=C(C=C1)C(=O)O,data/test_cases/predicted_structures/T1152.pdb
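
Because the module change in this commit writes each row's outputs into a subdirectory named after its `id`, duplicate or empty ids would presumably cause rows to share an output directory. The following is a minimal pre-flight check under that assumption; `load_batched_inputs` is a hypothetical helper and not part of the FlowDock codebase.

```python
import csv

# Columns present in the example CSV added by this commit.
REQUIRED_COLUMNS = ("id", "input_receptor", "input_ligand", "input_template")


def load_batched_inputs(csv_path):
    """Read a batched-input CSV and reject missing columns or repeated ids.

    Hypothetical helper for illustration; FlowDock's own loader may differ.
    """
    with open(csv_path, newline="") as handle:
        reader = csv.DictReader(handle)
        header = reader.fieldnames or []
        missing = [name for name in REQUIRED_COLUMNS if name not in header]
        if missing:
            raise ValueError(f"missing columns: {', '.join(missing)}")
        rows = list(reader)

    seen = set()
    for row in rows:
        sample_id = (row["id"] or "").strip()
        if not sample_id:
            raise ValueError("every row needs a non-empty id")
        if sample_id in seen:
            raise ValueError(f"duplicate id: {sample_id}")
        seen.add(sample_id)
    return rows
```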

flowdock/models/flowdock_fm_module.py

Lines changed: 8 additions & 1 deletion
@@ -743,6 +743,12 @@ def predict_step(self, batch: MODEL_BATCH, batch_idx: int, dataloader_idx: int =
 sample_id = batch["sample_id"][0] if "sample_id" in batch else "sample"
 input_template = batch["input_template"][0] if "input_template" in batch else None

+out_path = (
+    os.path.join(self.hparams.cfg.out_path, sample_id)
+    if "sample_id" in batch
+    else self.hparams.cfg.out_path
+)
+
 # generate ESM embeddings for the protein
 protein = pdb_filepath_to_protein(rec_path)
 sequences = [
@@ -793,7 +799,7 @@ def predict_step(self, batch: MODEL_BATCH, batch_idx: int, dataloader_idx: int =
 ligand_paths,
 self.hparams.cfg,
 self,
-self.hparams.cfg.out_path,
+out_path,
 separate_pdb=self.hparams.cfg.separate_pdb,
 apo_receptor_path=apo_rec_path,
 sample_id=sample_id,
@@ -842,6 +848,7 @@ def on_predict_epoch_end(self):
 prot_lig_pairs,
 os.path.join(
     self.hparams.cfg.out_path,
+    outputs["name"][batch_index],
     "predict_epoch_outputs",
     f"{outputs['name'][batch_index]}{f'_rank{ranking + 1}' if ranking is not None else ''}_predict_epoch_{self.current_epoch}_global_step_{self.global_step}_output_{i}_batch_{batch_index}.pdb",
 ),
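
The per-sample output layout introduced in the `predict_step` hunk can be isolated as a small pure function, which makes the branching easier to see. This is a sketch and not code from the repository: `resolve_out_path` is a hypothetical name, and `batch` is assumed to be a plain dict whose `sample_id` value is a one-element list, as the `[0]` indexing in the diff suggests.

```python
import os


def resolve_out_path(base_out_path, batch):
    """Return the output directory for one prediction batch.

    Hypothetical helper that mirrors the conditional added in this commit.
    """
    if "sample_id" in batch:
        # The batch carries an id: nest its outputs under <base>/<sample_id>.
        return os.path.join(base_out_path, batch["sample_id"][0])
    # No id in the batch: write directly into the base directory.
    return base_out_path
```

With the example CSV from this commit, a row with id `T1152_1` and `out_path='./T1152_batch_sampled_structures/'` would resolve to a `T1152_1` subdirectory of that folder.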
