Skip to content

Commit 480c724

Browse files
committed
Add 0.0.2 notebooks and result plots
1 parent b9e4c25 commit 480c724

6 files changed

+263
-43
lines changed
Loading
Loading

notebooks/dockgen_structure_prediction_results_plotting.ipynb

Lines changed: 126 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -59,13 +59,16 @@
5959
"source": [
6060
"# General variables\n",
6161
"baseline_methods = [\n",
62-
" # \"vina_p2rank\",\n",
62+
" \"vina_p2rank\",\n",
6363
" \"diffdock\",\n",
6464
" \"dynamicbind\",\n",
65-
" # \"rfaa\",\n",
65+
" \"rfaa\",\n",
66+
" \"alphafold3\",\n",
6667
" \"chai-lab\",\n",
6768
" \"neuralplexer\",\n",
6869
" \"flowdock_hp\",\n",
70+
" \"flowdock_aft\",\n",
71+
" \"flowdock_esmfold\",\n",
6972
" \"flowdock\",\n",
7073
"]\n",
7174
"max_num_repeats_per_method = 3\n",
@@ -77,9 +80,14 @@
7780
" \"..\", \"forks\", \"DynamicBind\", \"inference\", \"outputs\", \"results\"\n",
7881
")\n",
7982
"globals()[\"rfaa_output_dir\"] = os.path.join(\"..\", \"forks\", \"RoseTTAFold-All-Atom\", \"inference\")\n",
83+
"globals()[\"alphafold3_output_dir\"] = os.path.join(\"..\", \"forks\", \"alphafold3\", \"inference\")\n",
8084
"globals()[\"chai-lab_output_dir\"] = os.path.join(\"..\", \"forks\", \"chai-lab\", \"inference\")\n",
8185
"globals()[\"neuralplexer_output_dir\"] = os.path.join(\"..\", \"forks\", \"NeuralPLexer\", \"inference\")\n",
8286
"globals()[\"flowdock_hp_output_dir\"] = os.path.join(\"..\", \"forks\", \"FlowDock\", \"hp_inference\")\n",
87+
"globals()[\"flowdock_aft_output_dir\"] = os.path.join(\"..\", \"forks\", \"FlowDock\", \"aft_inference\")\n",
88+
"globals()[\"flowdock_esmfold_output_dir\"] = os.path.join(\n",
89+
" \"..\", \"forks\", \"FlowDock\", \"esmfold_inference\"\n",
90+
")\n",
8391
"globals()[\"flowdock_output_dir\"] = os.path.join(\"..\", \"forks\", \"FlowDock\", \"inference\")\n",
8492
"\n",
8593
"for repeat_index in range(1, max_num_repeats_per_method + 1):\n",
@@ -133,15 +141,29 @@
133141
" \"bust_results.csv\",\n",
134142
" )\n",
135143
"\n",
136-
" # Chai-1 results\n",
144+
" # AlphaFold 3 (Single-Seq) results\n",
145+
" globals()[f\"alphafold3_dockgen_bust_results_csv_filepath_{repeat_index}\"] = os.path.join(\n",
146+
" globals()[\"alphafold3_output_dir\"],\n",
147+
" f\"alphafold3_ss_dockgen_outputs_{repeat_index}\",\n",
148+
" \"bust_results.csv\",\n",
149+
" )\n",
150+
" globals()[f\"alphafold3_dockgen_relaxed_bust_results_csv_filepath_{repeat_index}\"] = (\n",
151+
" os.path.join(\n",
152+
" globals()[\"alphafold3_output_dir\"],\n",
153+
" f\"alphafold3_ss_dockgen_outputs_{repeat_index}_relaxed\",\n",
154+
" \"bust_results.csv\",\n",
155+
" )\n",
156+
" )\n",
157+
"\n",
158+
" # Chai-1 (Single-Seq) results\n",
137159
" globals()[f\"chai-lab_dockgen_bust_results_csv_filepath_{repeat_index}\"] = os.path.join(\n",
138160
" globals()[\"chai-lab_output_dir\"],\n",
139-
" f\"chai-lab_dockgen_outputs_{repeat_index}\",\n",
161+
" f\"chai-lab_ss_dockgen_outputs_{repeat_index}\",\n",
140162
" \"bust_results.csv\",\n",
141163
" )\n",
142164
" globals()[f\"chai-lab_dockgen_relaxed_bust_results_csv_filepath_{repeat_index}\"] = os.path.join(\n",
143165
" globals()[\"chai-lab_output_dir\"],\n",
144-
" f\"chai-lab_dockgen_outputs_{repeat_index}_relaxed\",\n",
166+
" f\"chai-lab_ss_dockgen_outputs_{repeat_index}_relaxed\",\n",
145167
" \"bust_results.csv\",\n",
146168
" )\n",
147169
"\n",
@@ -173,6 +195,34 @@
173195
" )\n",
174196
" )\n",
175197
"\n",
198+
" # FlowDock-AFT results\n",
199+
" globals()[f\"flowdock_aft_dockgen_bust_results_csv_filepath_{repeat_index}\"] = os.path.join(\n",
200+
" globals()[\"flowdock_aft_output_dir\"],\n",
201+
" f\"flowdock_dockgen_outputs_{repeat_index}\",\n",
202+
" \"bust_results.csv\",\n",
203+
" )\n",
204+
" globals()[f\"flowdock_aft_dockgen_relaxed_bust_results_csv_filepath_{repeat_index}\"] = (\n",
205+
" os.path.join(\n",
206+
" globals()[\"flowdock_aft_output_dir\"],\n",
207+
" f\"flowdock_dockgen_outputs_{repeat_index}_relaxed\",\n",
208+
" \"bust_results.csv\",\n",
209+
" )\n",
210+
" )\n",
211+
"\n",
212+
" # FlowDock-ESMFold results\n",
213+
" globals()[f\"flowdock_esmfold_dockgen_bust_results_csv_filepath_{repeat_index}\"] = os.path.join(\n",
214+
" globals()[\"flowdock_esmfold_output_dir\"],\n",
215+
" f\"flowdock_dockgen_outputs_{repeat_index}\",\n",
216+
" \"bust_results.csv\",\n",
217+
" )\n",
218+
" globals()[f\"flowdock_esmfold_dockgen_relaxed_bust_results_csv_filepath_{repeat_index}\"] = (\n",
219+
" os.path.join(\n",
220+
" globals()[\"flowdock_esmfold_output_dir\"],\n",
221+
" f\"flowdock_dockgen_outputs_{repeat_index}_relaxed\",\n",
222+
" \"bust_results.csv\",\n",
223+
" )\n",
224+
" )\n",
225+
"\n",
176226
" # FlowDock results\n",
177227
" globals()[f\"flowdock_dockgen_bust_results_csv_filepath_{repeat_index}\"] = os.path.join(\n",
178228
" globals()[\"flowdock_output_dir\"],\n",
@@ -191,20 +241,26 @@
191241
" \"diffdock\": \"DiffDock-L\",\n",
192242
" \"dynamicbind\": \"DynamicBind\",\n",
193243
" \"rfaa\": \"RoseTTAFold-AA\",\n",
194-
" \"chai-lab\": \"Chai-1\",\n",
244+
" \"alphafold3\": \"AF3-Single-Seq\",\n",
245+
" \"chai-lab\": \"Chai-1-Single-Seq\",\n",
195246
" \"neuralplexer\": \"NeuralPLexer\",\n",
196247
" \"flowdock_hp\": \"FlowDock-HP\",\n",
197-
" \"flowdock\": \"FlowDock\",\n",
248+
" \"flowdock_aft\": \"FlowDock-AFT\",\n",
249+
" \"flowdock_esmfold\": \"FlowDock-ESMFold\",\n",
250+
" \"flowdock\": \"FlowDock-AF3\",\n",
198251
"}\n",
199252
"\n",
200253
"method_category_mapping = {\n",
201254
" \"vina_p2rank\": \"Conventional blind\",\n",
202255
" \"diffdock\": \"DL-based blind\",\n",
203256
" \"dynamicbind\": \"DL-based blind\",\n",
204257
" \"rfaa\": \"DL-based blind\",\n",
258+
" \"alphafold3\": \"DL-based blind\",\n",
205259
" \"chai-lab\": \"DL-based blind\",\n",
206260
" \"neuralplexer\": \"DL-based blind\",\n",
207261
" \"flowdock_hp\": \"DL-based blind\",\n",
262+
" \"flowdock_aft\": \"DL-based blind\",\n",
263+
" \"flowdock_esmfold\": \"DL-based blind\",\n",
208264
" \"flowdock\": \"DL-based blind\",\n",
209265
"}\n",
210266
"\n",
@@ -476,7 +532,7 @@
476532
"colors = [\"#FB8072\", \"#BEBADA\"]\n",
477533
"\n",
478534
"bar_width = 0.5\n",
479-
"r1 = [item - 0.25 for item in range(2, 14, 2)]\n",
535+
"r1 = [item - 0.25 for item in range(2, 24, 2)]\n",
480536
"r2 = [x + bar_width for x in r1]\n",
481537
"\n",
482538
"(\n",
@@ -714,15 +770,61 @@
714770
"\n",
715771
"# add labels, titles, ticks, etc.\n",
716772
"axis.set_ylabel(\"Percentage of predictions\")\n",
717-
"axis.set_xlim(1, 13 + 0.1)\n",
773+
"axis.set_xlim(1, 23 + 0.1)\n",
718774
"axis.set_ylim(0, 125)\n",
719775
"\n",
720-
"axis.bar_label(dockgen_rmsd_lt_2_bar, fmt=\"{:,.1f}%\", label_type=\"edge\")\n",
721-
"axis.bar_label(dockgen_rmsd_lt_2_and_pb_valid_bar, fmt=\"{:,.1f}%\", label_type=\"center\", padding=5)\n",
722-
"axis.bar_label(dockgen_relaxed_rmsd_lt_2_bar, fmt=\"{:,.1f}%\", label_type=\"edge\")\n",
723-
"axis.bar_label(\n",
724-
" dockgen_relaxed_rmsd_lt_2_and_pb_valid_bar, fmt=\"{:,.1f}%\", label_type=\"center\", padding=5\n",
776+
"assert len(dockgen_rmsd_lt_2_bar) == len(dockgen_rmsd_lt_2_and_pb_valid_bar), (\n",
777+
" f\"Length of dockgen_rmsd_lt_2_bar ({len(dockgen_rmsd_lt_2_bar)}) \"\n",
778+
" f\"and dockgen_rmsd_lt_2_and_pb_valid_bar ({len(dockgen_rmsd_lt_2_and_pb_valid_bar)}) \"\n",
779+
" \"do not match.\"\n",
780+
")\n",
781+
"assert len(dockgen_relaxed_rmsd_lt_2_bar) == len(dockgen_relaxed_rmsd_lt_2_and_pb_valid_bar), (\n",
782+
" f\"Length of dockgen_relaxed_rmsd_lt_2_bar ({len(dockgen_relaxed_rmsd_lt_2_bar)}) \"\n",
783+
" f\"and dockgen_relaxed_rmsd_lt_2_and_pb_valid_bar ({len(dockgen_relaxed_rmsd_lt_2_and_pb_valid_bar)}) \"\n",
784+
" \"do not match.\"\n",
725785
")\n",
786+
"for bar, pb_valid_bar in zip(dockgen_rmsd_lt_2_bar, dockgen_rmsd_lt_2_and_pb_valid_bar):\n",
787+
" height = bar.get_height()\n",
788+
" pb_valid_height = pb_valid_bar.get_height()\n",
789+
" axis.annotate(\n",
790+
" f\"{height:.1f}\",\n",
791+
" (\n",
792+
" bar.get_x() + bar.get_width() / 2.5,\n",
793+
" max(height + 5, pb_valid_height) + 2,\n",
794+
" ), # Offset to prevent overlap\n",
795+
" ha=\"center\",\n",
796+
" va=\"bottom\",\n",
797+
" fontsize=24,\n",
798+
" )\n",
799+
" axis.annotate(\n",
800+
" f\"{pb_valid_height:.1f}\",\n",
801+
" (pb_valid_bar.get_x() + pb_valid_bar.get_width() / 2.5, max(height, pb_valid_height) + 2),\n",
802+
" ha=\"center\",\n",
803+
" va=\"bottom\",\n",
804+
" fontsize=24,\n",
805+
" )\n",
806+
"for bar, pb_valid_bar in zip(\n",
807+
" dockgen_relaxed_rmsd_lt_2_bar, dockgen_relaxed_rmsd_lt_2_and_pb_valid_bar\n",
808+
"):\n",
809+
" height = bar.get_height()\n",
810+
" pb_valid_height = pb_valid_bar.get_height()\n",
811+
" axis.annotate(\n",
812+
" f\"{height:.1f}\",\n",
813+
" (\n",
814+
" bar.get_x() + bar.get_width() / 1.75,\n",
815+
" max(height + 5, pb_valid_height) + 2,\n",
816+
" ), # Offset to prevent overlap\n",
817+
" ha=\"center\",\n",
818+
" va=\"bottom\",\n",
819+
" fontsize=24,\n",
820+
" )\n",
821+
" axis.annotate(\n",
822+
" f\"{pb_valid_height:.1f}\",\n",
823+
" (pb_valid_bar.get_x() + pb_valid_bar.get_width() / 1.75, max(height, pb_valid_height) + 2),\n",
824+
" ha=\"center\",\n",
825+
" va=\"bottom\",\n",
826+
" fontsize=24,\n",
827+
" )\n",
726828
"\n",
727829
"axis.yaxis.set_major_formatter(mtick.PercentFormatter())\n",
728830
"\n",
@@ -731,20 +833,23 @@
731833
"axis.grid(axis=\"y\", color=\"#EAEFF8\")\n",
732834
"axis.set_axisbelow(True)\n",
733835
"\n",
734-
"axis.set_xticks([2, 4, 6, 7, 8, 10, 12])\n",
836+
"axis.set_xticks([2, 2 + 1e-3, 4, 6, 8, 10, 12, 13, 14, 16, 18, 20, 22])\n",
735837
"axis.set_xticks([1 + 0.1], minor=True)\n",
736838
"axis.set_xticklabels(\n",
737839
" [\n",
738-
" # \"P2Rank-Vina\",\n",
739-
" # \"Conventional blind\",\n",
840+
" \"P2Rank-Vina\",\n",
841+
" \"Conventional blind\",\n",
740842
" \"DiffDock-L\",\n",
741843
" \"DynamicBind\",\n",
742-
" # \"RoseTTAFold-AA\",\n",
743-
" \"Chai-1\",\n",
844+
" \"RoseTTAFold-AA\",\n",
845+
" \"AF3-Single-Seq\",\n",
846+
" \"Chai-1-Single-Seq\",\n",
744847
" \"DL-based blind\",\n",
745848
" \"NeuralPLexer\",\n",
746849
" \"FlowDock-HP\",\n",
747-
" \"FlowDock\",\n",
850+
" \"FlowDock-AFT\",\n",
851+
" \"FlowDock-ESMFold\",\n",
852+
" \"FlowDock-AF3\",\n",
748853
" ]\n",
749854
")\n",
750855
"\n",
@@ -756,7 +861,7 @@
756861
"axis.tick_params(axis=\"y\", which=\"major\", left=\"off\", right=\"on\", color=\"#EAEFF8\")\n",
757862
"\n",
758863
"# vertical alignment of xtick labels\n",
759-
"vert_alignments = [0.0, 0.0, 0.0, -0.1, 0.0, 0.0, 0.0]\n",
864+
"vert_alignments = [0.0, -0.1, 0.0, 0.0, 0.0, 0.0, 0.0, -0.1, 0.0, 0.0, 0.0, 0.0, 0.0]\n",
760865
"for tick, y in zip(axis.get_xticklabels(), vert_alignments):\n",
761866
" tick.set_y(y)\n",
762867
"\n",
Loading
Loading

0 commit comments

Comments
 (0)