Skip to content

Commit f53b581

Browse files
wiwaDeepakchowdavarapu
authored andcommitted
Fix Minibatch alignment in Bayesian Neural Network example + Pre-commit hooks (pymc-devs#719)
* Fix Minibatch alignment in Bayesian Neural Network example * Run: pre-commit run all-files --------- Co-authored-by: Deepak CH <[email protected]>
1 parent 92fc463 commit f53b581

File tree

2 files changed

+16
-6
lines changed

2 files changed

+16
-6
lines changed

Diff for: examples/variational_inference/bayesian_neural_network_advi.ipynb

+8-3
Original file line numberDiff line numberDiff line change
@@ -186,8 +186,12 @@
186186
" }\n",
187187
"\n",
188188
" with pm.Model(coords=coords) as neural_network:\n",
189-
" ann_input = pm.Data(\"ann_input\", X_train, mutable=True)\n",
190-
" ann_output = pm.Data(\"ann_output\", Y_train, mutable=True)\n",
189+
" # Define minibatch variables\n",
190+
" minibatch_x, minibatch_y = pm.Minibatch(X_train, Y_train, batch_size=50)\n",
191+
"\n",
192+
" # Define data variables using minibatches\n",
193+
" ann_input = pm.Data(\"ann_input\", minibatch_x, mutable=True, dims=(\"obs_id\", \"train_cols\"))\n",
194+
" ann_output = pm.Data(\"ann_output\", minibatch_y, mutable=True, dims=\"obs_id\")\n",
191195
"\n",
192196
" # Weights from input to hidden layer\n",
193197
" weights_in_1 = pm.Normal(\n",
@@ -212,7 +216,8 @@
212216
" \"out\",\n",
213217
" act_out,\n",
214218
" observed=ann_output,\n",
215-
" total_size=Y_train.shape[0], # IMPORTANT for minibatches\n",
219+
" total_size=X_train.shape[0], # IMPORTANT for minibatches\n",
220+
" dims=\"obs_id\",\n",
216221
" )\n",
217222
" return neural_network\n",
218223
"\n",

Diff for: examples/variational_inference/bayesian_neural_network_advi.myst.md

+8-3
Original file line numberDiff line numberDiff line change
@@ -131,8 +131,12 @@ def construct_nn():
131131
}
132132
133133
with pm.Model(coords=coords) as neural_network:
134-
ann_input = pm.Data("ann_input", X_train, mutable=True)
135-
ann_output = pm.Data("ann_output", Y_train, mutable=True)
134+
# Define minibatch variables
135+
minibatch_x, minibatch_y = pm.Minibatch(X_train, Y_train, batch_size=50)
136+
137+
# Define data variables using minibatches
138+
ann_input = pm.Data("ann_input", minibatch_x, mutable=True, dims=("obs_id", "train_cols"))
139+
ann_output = pm.Data("ann_output", minibatch_y, mutable=True, dims="obs_id")
136140
137141
# Weights from input to hidden layer
138142
weights_in_1 = pm.Normal(
@@ -157,7 +161,8 @@ def construct_nn():
157161
"out",
158162
act_out,
159163
observed=ann_output,
160-
total_size=Y_train.shape[0], # IMPORTANT for minibatches
164+
total_size=X_train.shape[0], # IMPORTANT for minibatches
165+
dims="obs_id",
161166
)
162167
return neural_network
163168

0 commit comments

Comments
 (0)