File tree 3 files changed +14
-5
lines changed
3 files changed +14
-5
lines changed Original file line number Diff line number Diff line change 107
107
108
108
# IDE Files
109
109
.vscode /
110
- .idea /
110
+ .idea /
111
+ .dccache
Original file line number Diff line number Diff line change 15
15
16
16
import { SGDRegressor } from './SgdRegressor'
17
17
import { getBackend } from '../tf-singleton'
18
+ import { ModelFitArgs } from '../types'
18
19
19
20
/**
20
21
* LinearRegression implementation using gradient descent
@@ -39,6 +40,8 @@ export interface LinearRegressionParams {
39
40
* **default = true**
40
41
*/
41
42
fitIntercept ?: boolean
43
+ modelFitOptions ?: Partial < ModelFitArgs >
44
+
42
45
}
43
46
44
47
/*
@@ -66,7 +69,7 @@ Next steps:
66
69
* ```
67
70
*/
68
71
export class LinearRegression extends SGDRegressor {
69
- constructor ( { fitIntercept = true } : LinearRegressionParams = { } ) {
72
+ constructor ( { fitIntercept = true , modelFitOptions } : LinearRegressionParams = { } ) {
70
73
let tf = getBackend ( )
71
74
super ( {
72
75
modelCompileArgs : {
@@ -80,7 +83,8 @@ export class LinearRegression extends SGDRegressor {
80
83
verbose : 0 ,
81
84
callbacks : [
82
85
tf . callbacks . earlyStopping ( { monitor : 'mse' , patience : 30 } )
83
- ]
86
+ ] ,
87
+ ...modelFitOptions
84
88
} ,
85
89
denseLayerArgs : {
86
90
units : 1 ,
Original file line number Diff line number Diff line change 15
15
16
16
import { SGDClassifier } from './SgdClassifier'
17
17
import { getBackend } from '../tf-singleton'
18
+ import { ModelFitArgs } from '../types'
18
19
19
20
// First pass at a LogisticRegression implementation using gradient descent
20
21
// Trying to mimic the API of scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html
@@ -35,6 +36,7 @@ export interface LogisticRegressionParams {
35
36
C ?: number
36
37
/** Whether or not the intercept should be estimator not. **default = true** */
37
38
fitIntercept ?: boolean
39
+ modelFitOptions ?: Partial < ModelFitArgs >
38
40
}
39
41
40
42
/** Builds a linear classification model with associated penalty and regularization
@@ -63,7 +65,8 @@ export class LogisticRegression extends SGDClassifier {
63
65
constructor ( {
64
66
penalty = 'l2' ,
65
67
C = 1 ,
66
- fitIntercept = true
68
+ fitIntercept = true ,
69
+ modelFitOptions
67
70
} : LogisticRegressionParams = { } ) {
68
71
// Assume Binary classification
69
72
// If we call fit, and it isn't binary then update args
@@ -80,7 +83,8 @@ export class LogisticRegression extends SGDClassifier {
80
83
verbose : 0 ,
81
84
callbacks : [
82
85
tf . callbacks . earlyStopping ( { monitor : 'loss' , patience : 50 } )
83
- ]
86
+ ] ,
87
+ ...modelFitOptions
84
88
} ,
85
89
denseLayerArgs : {
86
90
units : 1 ,
You can’t perform that action at this time.
0 commit comments