@@ -19,21 +19,22 @@ import {
19
19
Tensor1D ,
20
20
Tensor2D ,
21
21
tensor1d ,
22
- tensor2d ,
22
+ tensor2d
23
23
} from '@tensorflow/tfjs-core'
24
24
import {
25
25
layers ,
26
26
sequential ,
27
27
Sequential ,
28
28
ModelFitArgs ,
29
- ModelCompileArgs ,
29
+ ModelCompileArgs
30
30
} from '@tensorflow/tfjs-layers'
31
31
import { DenseLayerArgs } from '@tensorflow/tfjs-layers/dist/layers/core'
32
32
import {
33
33
convertToNumericTensor1D_2D ,
34
- convertToNumericTensor2D ,
34
+ convertToNumericTensor2D
35
35
} from '../utils'
36
36
import { Scikit2D , ScikitVecOrMatrix } from '../types'
37
+ import { PredictorMixin } from '../mixins'
37
38
/**
38
39
* SGD is a thin Wrapper around Tensorflow's model api with a single dense layer.
39
40
* With this base class and different error functions / regularizers we can
@@ -71,12 +72,12 @@ export interface SGDParams {
71
72
modelFitArgs : ModelFitArgs
72
73
73
74
/**
74
- * The arguments for a single dense layer in tensorflow. This also defaults to
75
+ * The arguments for a single dense layer in tensorflow. This also defaults to
75
76
* different settings based on the regressor / classifier. An example dense layer
76
77
* might look like.
77
78
* const model = sequential()
78
79
model.add(
79
- layers.dense({ inputShape: [100],
80
+ layers.dense({ inputShape: [100],
80
81
units: 1,
81
82
useBias: true,
82
83
})
@@ -85,13 +86,14 @@ export interface SGDParams {
85
86
denseLayerArgs : DenseLayerArgs
86
87
}
87
88
88
- export class SGD {
89
+ export class SGD extends PredictorMixin {
89
90
model : Sequential
90
91
modelFitArgs : ModelFitArgs
91
92
modelCompileArgs : ModelCompileArgs
92
93
denseLayerArgs : DenseLayerArgs
93
94
94
95
constructor ( params : SGDParams ) {
96
+ super ( )
95
97
this . model = sequential ( )
96
98
this . modelFitArgs = params . modelFitArgs
97
99
this . modelCompileArgs = params . modelCompileArgs
@@ -187,7 +189,7 @@ export class SGD {
187
189
let myIntercept = tensor1d ( [ params . intercept_ ] , 'float32' )
188
190
this . initializeModel ( myCoef . shape , myIntercept . shape , [
189
191
myCoef ,
190
- myIntercept ,
192
+ myIntercept
191
193
] )
192
194
return this
193
195
}
@@ -203,7 +205,7 @@ export class SGD {
203
205
*
204
206
* lr = new LinearRegression()
205
207
* lr.getParams()
206
- * // =>
208
+ * // =>
207
209
{
208
210
modelCompileArgs: {
209
211
optimizer: train.adam(0.1),
@@ -227,7 +229,7 @@ export class SGD {
227
229
return {
228
230
modelFitArgs : this . modelFitArgs ,
229
231
modelCompileArgs : this . modelCompileArgs ,
230
- denseLayerArgs : this . denseLayerArgs ,
232
+ denseLayerArgs : this . denseLayerArgs
231
233
}
232
234
}
233
235
@@ -303,7 +305,7 @@ export class SGD {
303
305
* await lr.fit(X, [1,2,3]);
304
306
* lr.coef_
305
307
* // => tensor1d([[ 1.2, 3.3, 1.1, 0.2 ]])
306
- *
308
+ *
307
309
* await lr.fit(X, [ [1,2], [3,4], [5,6] ]);
308
310
* lr.coef_
309
311
* // => tensor2d([ [1.2, 3.3], [3.4, 5.6], [4.5, 6.7] ])
0 commit comments