
Commit bab57a2

Fix Minibatch alignment in Bayesian Neural Network example + Pre-commit hooks (#719)
* Fix Minibatch alignment in Bayesian Neural Network example
* Run: pre-commit run all-files

Co-authored-by: Deepak CH <[email protected]>
1 parent 05a928a commit bab57a2

2 files changed: +22 -10 lines changed

Diff for: examples/variational_inference/bayesian_neural_network_advi.ipynb (+11 -5)

@@ -190,7 +190,7 @@
 },
 "outputs": [],
 "source": [
-"def construct_nn(ann_input, ann_output):\n",
+"def construct_nn():\n",
 "    n_hidden = 5\n",
 "\n",
 "    # Initialize random weights between each layer\n",
@@ -204,9 +204,14 @@
 "        \"train_cols\": np.arange(X_train.shape[1]),\n",
 "        \"obs_id\": np.arange(X_train.shape[0]),\n",
 "    }\n",
+"\n",
 "    with pm.Model(coords=coords) as neural_network:\n",
-"        ann_input = pm.Data(\"ann_input\", X_train, dims=(\"obs_id\", \"train_cols\"))\n",
-"        ann_output = pm.Data(\"ann_output\", Y_train, dims=\"obs_id\")\n",
+"        # Define minibatch variables\n",
+"        minibatch_x, minibatch_y = pm.Minibatch(X_train, Y_train, batch_size=50)\n",
+"\n",
+"        # Define data variables using minibatches\n",
+"        ann_input = pm.Data(\"ann_input\", minibatch_x, mutable=True, dims=(\"obs_id\", \"train_cols\"))\n",
+"        ann_output = pm.Data(\"ann_output\", minibatch_y, mutable=True, dims=\"obs_id\")\n",
 "\n",
 "        # Weights from input to hidden layer\n",
 "        weights_in_1 = pm.Normal(\n",
@@ -231,13 +236,14 @@
 "            \"out\",\n",
 "            act_out,\n",
 "            observed=ann_output,\n",
-"            total_size=Y_train.shape[0],  # IMPORTANT for minibatches\n",
+"            total_size=X_train.shape[0],  # IMPORTANT for minibatches\n",
 "            dims=\"obs_id\",\n",
 "        )\n",
 "    return neural_network\n",
 "\n",
 "\n",
-"neural_network = construct_nn(X_train, Y_train)"
+"# Create the neural network model\n",
+"neural_network = construct_nn()"
 ]
},
{
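The pattern the change above relies on: a single pm.Minibatch(X_train, Y_train, batch_size=50) call draws the same random row indices for both arrays on every update, so inputs and targets stay aligned, and total_size tells PyMC to rescale each minibatch's log-likelihood up to the full training set. A minimal standalone sketch of that pattern, using illustrative toy data and a simple logistic likelihood rather than the notebook's model:

import numpy as np
import pymc as pm

# Toy data purely for illustration; not the dataset used in the notebook.
rng = np.random.default_rng(0)
X_train = rng.standard_normal((400, 2))
Y_train = (X_train[:, 0] > 0).astype("float64")

with pm.Model() as toy_model:
    # One pm.Minibatch call slices BOTH arrays with the same random
    # indices each iteration, keeping inputs and targets aligned.
    minibatch_x, minibatch_y = pm.Minibatch(X_train, Y_train, batch_size=50)

    w = pm.Normal("w", mu=0, sigma=1, shape=2)
    p = pm.math.sigmoid(pm.math.dot(minibatch_x, w))

    pm.Bernoulli(
        "obs",
        p=p,
        observed=minibatch_y,
        # Rescale the 50-row minibatch likelihood to all 400 training rows
        total_size=X_train.shape[0],
    )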

Diff for: examples/variational_inference/bayesian_neural_network_advi.myst.md (+11 -5)

@@ -114,7 +114,7 @@ A neural network is quite simple. The basic unit is a [perceptron](https://en.wi
 jupyter:
   outputs_hidden: true
 ---
-def construct_nn(ann_input, ann_output):
+def construct_nn():
     n_hidden = 5

     # Initialize random weights between each layer
@@ -128,9 +128,14 @@ def construct_nn(ann_input, ann_output):
         "train_cols": np.arange(X_train.shape[1]),
         "obs_id": np.arange(X_train.shape[0]),
     }
+
     with pm.Model(coords=coords) as neural_network:
-        ann_input = pm.Data("ann_input", X_train, dims=("obs_id", "train_cols"))
-        ann_output = pm.Data("ann_output", Y_train, dims="obs_id")
+        # Define minibatch variables
+        minibatch_x, minibatch_y = pm.Minibatch(X_train, Y_train, batch_size=50)
+
+        # Define data variables using minibatches
+        ann_input = pm.Data("ann_input", minibatch_x, mutable=True, dims=("obs_id", "train_cols"))
+        ann_output = pm.Data("ann_output", minibatch_y, mutable=True, dims="obs_id")

         # Weights from input to hidden layer
         weights_in_1 = pm.Normal(
@@ -155,13 +160,14 @@ def construct_nn(ann_input, ann_output):
             "out",
             act_out,
             observed=ann_output,
-            total_size=Y_train.shape[0],  # IMPORTANT for minibatches
+            total_size=X_train.shape[0],  # IMPORTANT for minibatches
             dims="obs_id",
         )
     return neural_network


-neural_network = construct_nn(X_train, Y_train)
+# Create the neural network model
+neural_network = construct_nn()
 ```

 That's not so bad. The `Normal` priors help regularize the weights. Usually we would add a constant `b` to the inputs but I omitted it here to keep the code cleaner.
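Once construct_nn() returns the model, the embedded minibatch views feed variational inference directly; a rough usage sketch follows, where the iteration and draw counts are placeholders rather than values taken from this commit:

with neural_network:
    # ADVI consumes one 50-row minibatch per gradient step; total_size
    # keeps the ELBO scaled to the full training set.
    approx = pm.fit(n=30_000, method="advi")

# Draw posterior samples from the fitted approximation
trace = approx.sample(1000)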
