Skip to content

Commit 2ddcad9

Browse files
committed
feat: custom modelfitargs for linear models
1 parent 8506540 commit 2ddcad9

File tree

3 files changed

+14
-5
lines changed

3 files changed

+14
-5
lines changed

.gitignore

+2-1
Original file line numberDiff line numberDiff line change
@@ -107,4 +107,5 @@ dist
107107

108108
# IDE Files
109109
.vscode/
110-
.idea/
110+
.idea/
111+
.dccache

src/linear_model/LinearRegression.ts

+6-2
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import { SGDRegressor } from './SgdRegressor'
1717
import { getBackend } from '../tf-singleton'
18+
import { ModelFitArgs } from '../types'
1819

1920
/**
2021
* LinearRegression implementation using gradient descent
@@ -39,6 +40,8 @@ export interface LinearRegressionParams {
3940
* **default = true**
4041
*/
4142
fitIntercept?: boolean
43+
modelFitOptions?: Partial<ModelFitArgs>
44+
4245
}
4346

4447
/*
@@ -66,7 +69,7 @@ Next steps:
6669
* ```
6770
*/
6871
export class LinearRegression extends SGDRegressor {
69-
constructor({ fitIntercept = true }: LinearRegressionParams = {}) {
72+
constructor({ fitIntercept = true, modelFitOptions }: LinearRegressionParams = {}) {
7073
let tf = getBackend()
7174
super({
7275
modelCompileArgs: {
@@ -80,7 +83,8 @@ export class LinearRegression extends SGDRegressor {
8083
verbose: 0,
8184
callbacks: [
8285
tf.callbacks.earlyStopping({ monitor: 'mse', patience: 30 })
83-
]
86+
],
87+
...modelFitOptions
8488
},
8589
denseLayerArgs: {
8690
units: 1,

src/linear_model/LogisticRegression.ts

+6-2
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import { SGDClassifier } from './SgdClassifier'
1717
import { getBackend } from '../tf-singleton'
18+
import { ModelFitArgs } from '../types'
1819

1920
// First pass at a LogisticRegression implementation using gradient descent
2021
// Trying to mimic the API of scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html
@@ -35,6 +36,7 @@ export interface LogisticRegressionParams {
3536
C?: number
3637
/** Whether or not the intercept should be estimator not. **default = true** */
3738
fitIntercept?: boolean
39+
modelFitOptions?: Partial<ModelFitArgs>
3840
}
3941

4042
/** Builds a linear classification model with associated penalty and regularization
@@ -63,7 +65,8 @@ export class LogisticRegression extends SGDClassifier {
6365
constructor({
6466
penalty = 'l2',
6567
C = 1,
66-
fitIntercept = true
68+
fitIntercept = true,
69+
modelFitOptions
6770
}: LogisticRegressionParams = {}) {
6871
// Assume Binary classification
6972
// If we call fit, and it isn't binary then update args
@@ -80,7 +83,8 @@ export class LogisticRegression extends SGDClassifier {
8083
verbose: 0,
8184
callbacks: [
8285
tf.callbacks.earlyStopping({ monitor: 'loss', patience: 50 })
83-
]
86+
],
87+
...modelFitOptions
8488
},
8589
denseLayerArgs: {
8690
units: 1,

0 commit comments

Comments
 (0)