@@ -56,7 +56,11 @@ Pkg.add("RDatasets") # for the demo below
 ### Example

 Following is an introductory example using a default builder and no
-standardization of input features.
+standardization of input features ([notebook/script](/examples/iris)).
+
+For a more advanced illustration, see the [MNIST dataset
+example](https://github.com/FluxML/MLJFlux.jl/blob/dev/examples/mnist).
+


 #### Loading some data and instantiating a model
@@ -66,7 +70,7 @@ using MLJ
 import RDatasets
 iris = RDatasets.dataset("datasets", "iris");
 y, X = unpack(iris, ==(:Species), colname -> true, rng=123);
-@load NeuralNetworkClassifier
+NeuralNetworkClassifier = @load NeuralNetworkClassifier

 julia> clf = NeuralNetworkClassifier()
 NeuralNetworkClassifier(
@@ -94,7 +98,7 @@ fit!(mach)
 julia> training_loss = cross_entropy(predict(mach, X), y) |> mean
 0.89526004f0

-# increase learning rate and add iterations:
+# Increasing learning rate and adding iterations:
 clf.optimiser.eta = clf.optimiser.eta * 2
 clf.epochs = clf.epochs + 5

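The warm-restart pattern shown in this hunk can be sketched end-to-end. This is a minimal sketch, not part of the diff: it assumes MLJ, MLJFlux, and RDatasets are installed, and that (as in the surrounding example) mutating `clf.epochs` and refitting adds iterations rather than retraining from scratch:

```julia
# Sketch (hypothetical continuation, not from the README diff):
# mutate the model's epochs and refit to add training iterations.
using MLJ
import RDatasets

iris = RDatasets.dataset("datasets", "iris")
y, X = unpack(iris, ==(:Species), colname -> true, rng=123)

NeuralNetworkClassifier = @load NeuralNetworkClassifier
clf = NeuralNetworkClassifier(epochs=5)

mach = machine(clf, X, y)
fit!(mach)                    # first 5 epochs

clf.optimiser.eta = clf.optimiser.eta * 2
clf.epochs = clf.epochs + 5
fit!(mach)                    # typically adds only the 5 new epochs
```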
@@ -135,7 +139,7 @@ plot(curve.parameter_values,

 ```

-
+

 ### Models
@@ -164,6 +168,21 @@ model type | prediction type | `scitype(X) <: _` | `scitype(y) <: _`

 > Table 1. Input and output types for MLJFlux models

+### Training on a GPU
+
+When instantiating a model for training on a GPU, specify
+`acceleration=CUDALibs()`, as in
+
+```julia
+using MLJ
+ImageClassifier = @load ImageClassifier
+clf = ImageClassifier(epochs=10, acceleration=CUDALibs())
+```
+
+At present, data bound to an MLJ model in an MLJ machine is
+automatically moved on and off the GPU under the hood.
+
+
 #### Non-tabular input

 Any `AbstractMatrix{<:AbstractFloat}` object `Xmat` can be forced to
@@ -262,6 +281,9 @@
 Note here that `n_in` and `n_out` depend on the size of the data (see
 Table 1).

+For a concrete image classification example, see
+[examples/mnist](examples/mnist).
+
 More generally, defining a new builder means defining a new struct
 sub-typing `MLJFlux.Builder` and defining a new `MLJFlux.build` method
 with one of these signatures:
@@ -279,16 +301,13 @@ following conditions:
 - for any `x <: Vector{<:AbstractFloat}` of length `n_in` (for use
   with one of the first three model types); or

-- for any `x <: Array{<:Float32, 3}` of size
-  `(W, H, n_channels)`, where `n_in = (W, H)` and `n_channels` is
-  1 or 3 (for use with `ImageClassifier`)
+- for any `x <: Array{<:Float32, 4}` of size `(W, H, n_channels,
+  batch_size)`, where `(W, H) = n_in`, `n_channels` is 1 or 3, and
+  `batch_size` is any integer (for use with `ImageClassifier`)

 - The object returned by `chain(x)` must be an `AbstractFloat` vector
   of length `n_out`.

-For an builder example for use with `ImageClassifier` see
-[below](an-image-classification-example).
-

 ### Loss functions

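The conditions above can be illustrated with a minimal builder sketch. This is not from the diff: `MySimpleBuilder` is a hypothetical name, and the two-argument `MLJFlux.build(builder, n_in, n_out)` signature is assumed (the excerpt references the signatures without listing them):

```julia
using Flux
import MLJFlux

# Hypothetical builder with a single tunable hidden layer.
mutable struct MySimpleBuilder <: MLJFlux.Builder
    n_hidden::Int
end

# Assumed signature; see the signatures listed in the README.
function MLJFlux.build(builder::MySimpleBuilder, n_in, n_out)
    # The returned chain maps a length-`n_in` vector to a
    # length-`n_out` vector, as the conditions above require:
    Chain(Dense(n_in, builder.n_hidden, relu),
          Dense(builder.n_hidden, n_out))
end
```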
@@ -311,6 +330,9 @@ you *should* use MLJ loss functions in MLJ meta-algorithms.

 ### An image classification example

+An expanded version of this example, with early stopping, is available
+[here](/examples/mnist).
+
 We define a builder that builds a chain with six alternating
 convolution and max-pool layers, and a final dense layer, which we
 apply to the MNIST image dataset.
@@ -328,7 +350,7 @@ function flatten(x::AbstractArray)
 end

 import MLJFlux
-mutable struct MyConvBuilder <: MLJFlux.Builder
+mutable struct MyConvBuilder
     filter_size::Int
     channels1::Int
     channels2::Int
@@ -384,7 +406,7 @@ y = coerce(y, Multiclass);
 Instantiating an image classifier model:

 ```julia
-@load ImageClassifier
+ImageClassifier = @load ImageClassifier
 clf = ImageClassifier(builder=MyConvBuilder(3, 16, 32, 32),
                       epochs=10,
                       loss=Flux.crossentropy)