Skip to content

Commit cd69cb8

Browse files
committed
revamp examples
1 parent 7966999 commit cd69cb8

24 files changed

+7936
-24
lines changed

README.md

+34-12
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,11 @@ Pkg.add("RDatasets") # for the demo below
5656
### Example
5757

5858
Following is an introductory example using a default builder and no
59-
standardization of input features.
59+
standardization of input features ([notebook/script](/examples/iris)).
60+
61+
For a more advanced illustration, see the [MNIST dataset
62+
example](https://github.com/FluxML/MLJFlux.jl/blob/dev/examples/mnist).
63+
6064

6165

6266
#### Loading some data and instantiating a model
@@ -66,7 +70,7 @@ using MLJ
6670
import RDatasets
6771
iris = RDatasets.dataset("datasets", "iris");
6872
y, X = unpack(iris, ==(:Species), colname -> true, rng=123);
69-
@load NeuralNetworkClassifier
73+
NeuralNetworkClassifier = @load NeuralNetworkClassifier
7074

7175
julia> clf = NeuralNetworkClassifier()
7276
NeuralNetworkClassifier(
@@ -94,7 +98,7 @@ fit!(mach)
9498
julia> training_loss = cross_entropy(predict(mach, X), y) |> mean
9599
0.89526004f0
96100

97-
# increase learning rate and add iterations:
101+
# Increasing learning rate and adding iterations:
98102
clf.optimiser.eta = clf.optimiser.eta * 2
99103
clf.epochs = clf.epochs + 5
100104

@@ -135,7 +139,7 @@ plot(curve.parameter_values,
135139

136140
```
137141
138-
![learning_curve.png](learning_curve.png)
142+
![](examples/iris/iris_history.png)
139143
140144
141145
### Models
@@ -164,6 +168,21 @@ model type | prediction type | `scitype(X) <: _` | `scitype(y) <: _`
164168
165169
> Table 1. Input and output types for MLJFlux models
166170
171+
### Training on a GPU
172+
173+
When instantiating a model for training on a GPU, specify
174+
`acceleration=CUDALibs()`, as in
175+
176+
```julia
177+
using MLJ
178+
ImageClassifier = @load ImageClassifier
179+
clf = ImageClassifier(epochs=10, acceleration=CUDALibs())
180+
```
181+
182+
At present, data bound to a MLJ model in an MLJ machine is
183+
automatically moved on and off the GPU under the hood.
184+
185+
167186
#### Non-tabular input
168187
169188
Any `AbstractMatrix{<:AbstractFloat}` object `Xmat` can be forced to
@@ -262,6 +281,9 @@ end
262281
Note here that `n_in` and `n_out` depend on the size of the data (see
263282
Table 1).
264283
284+
For a concrete image classification example, see
285+
[examples/mnist](examples/mnist).
286+
265287
More generally, defining a new builder means defining a new struct
266288
sub-typing `MLJFlux.Builder` and defining a new `MLJFlux.build` method
267289
with one of these signatures:
@@ -279,16 +301,13 @@ following conditions:
279301
- for any `x <: Vector{<:AbstractFloat}` of length `n_in` (for use
280302
with one of the first three model types); or
281303
282-
- for any `x <: Array{<:Float32, 3}` of size
283-
`(W, H, n_channels)`, where `n_in = (W, H)` and `n_channels` is
284-
1 or 3 (for use with `ImageClassifier`)
304+
- for any `x <: Array{<:Float32, 4}` of size `(W, H, n_channels,
305+
batch_size)`, where `(W, H) = n_in`, `n_channels` is 1 or 3, and
306+
`batch_size` is any integer (for use with `ImageClassifier`)
285307
286308
- The object returned by `chain(x)` must be an `AbstractFloat` vector
287309
of length `n_out`.
288310
289-
For an builder example for use with `ImageClassifier` see
290-
[below](an-image-classification-example).
291-
292311
293312
### Loss functions
294313
@@ -311,6 +330,9 @@ you *should* use MLJ loss functions in MLJ meta-algorithms.
311330
312331
### An image classification example
313332
333+
An expanded version of this example, with early stopping, is available
334+
[here](/examples/mnist).
335+
314336
We define a builder that builds a chain with six alternating
315337
convolution and max-pool layers, and a final dense layer, which we
316338
apply to the MNIST image dataset.
@@ -328,7 +350,7 @@ function flatten(x::AbstractArray)
328350
end
329351

330352
import MLJFlux
331-
mutable struct MyConvBuilder <: MLJFlux.Builder
353+
mutable struct MyConvBuilder
332354
filter_size::Int
333355
channels1::Int
334356
channels2::Int
@@ -384,7 +406,7 @@ y = coerce(y, Multiclass);
384406
Instantiating an image classifier model:
385407
386408
```julia
387-
@load ImageClassifier
409+
ImageClassifier = @load ImageClassifier
388410
clf = ImageClassifier(builder=MyConvBuilder(3, 16, 32, 32),
389411
epochs=10,
390412
loss=Flux.crossentropy)

examples/Project.toml

-8
Original file line numberDiff line numberDiff line change
@@ -1,8 +0,0 @@
1-
[deps]
2-
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
3-
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
4-
MLJ = "add582a8-e3ab-11e8-2d5e-e98b27df1bc7"
5-
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
6-
MLJFlux = "094fc8d1-fd35-5302-93ea-dabda2abf845"
7-
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
8-
PyPlot = "d330b81b-6aea-500a-939a-2ce795aea3ee"

0 commit comments

Comments
 (0)