Skip to content

Commit 624b06b

Browse files
committed
Add MinMax and Standard Scalers with test
1 parent 97f2ac5 commit 624b06b

File tree

5 files changed

+424
-0
lines changed

5 files changed

+424
-0
lines changed
+159
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
/**
2+
* @license
3+
* Copyright 2021, JsData. All rights reserved.
4+
*
5+
* This source code is licensed under the MIT license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
8+
* Unless required by applicable law or agreed to in writing, software
9+
* distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and
12+
* limitations under the License.
13+
* ==========================================================================
14+
*/
15+
16+
import { Tensor, tensor1d, tensor2d } from "@tensorflow/tfjs-node"
17+
import { DataFrame, Series } from "danfojs-node"
18+
import { is1DArray } from "../../utils"
19+
20+
/**
21+
* Transform features by scaling each feature to a given range.
22+
* This estimator scales and translates each feature individually such
23+
* that it is in the given range on the training set, e.g. between the maximum and minimum value.
24+
*/
25+
export default class MinMaxScaler {
26+
private $max: Tensor
27+
private $min: Tensor
28+
29+
constructor() {
30+
this.$max = tensor1d([])
31+
this.$min = tensor1d([])
32+
}
33+
34+
private $getTensor(data: number[] | number[][] | Tensor | DataFrame | Series) {
35+
let $tensorArray;
36+
37+
if (data instanceof Array) {
38+
if (is1DArray(data)) {
39+
$tensorArray = tensor1d(data as number[])
40+
} else {
41+
$tensorArray = tensor2d(data)
42+
}
43+
} else if (data instanceof DataFrame) {
44+
$tensorArray = tensor2d(data.values as number[][])
45+
} else if (data instanceof Series) {
46+
$tensorArray = tensor1d(data.values as number[])
47+
} else if (data instanceof Tensor) {
48+
$tensorArray = data
49+
} else {
50+
throw new Error("ParamError: data must be one of Array, DataFrame or Series")
51+
}
52+
return $tensorArray
53+
}
54+
55+
/**
56+
* Fits a MinMaxScaler to the data
57+
* @param data Array, Tensor, DataFrame or Series object
58+
* @returns MinMaxScaler
59+
* @example
60+
* const scaler = new MinMaxScaler()
61+
* scaler.fit([1, 2, 3, 4, 5])
62+
* // MinMaxScaler {
63+
* // $max: [5],
64+
* // $min: [1]
65+
* // }
66+
*
67+
*/
68+
public fit(data: number[] | number[][] | Tensor | DataFrame | Series) {
69+
const tensorArray = this.$getTensor(data)
70+
this.$max = tensorArray.max(0)
71+
this.$min = tensorArray.min(0)
72+
return this
73+
}
74+
75+
/**
76+
* Transform the data using the fitted scaler
77+
* @param data Array, Tensor, DataFrame or Series object
78+
* @returns Array, Tensor, DataFrame or Series object
79+
* @example
80+
* const scaler = new MinMaxScaler()
81+
* scaler.fit([1, 2, 3, 4, 5])
82+
* scaler.transform([1, 2, 3, 4, 5])
83+
* // [0, 0.25, 0.5, 0.75, 1]
84+
* */
85+
public transform(data: number[] | number[][] | Tensor | DataFrame | Series) {
86+
const tensorArray = this.$getTensor(data)
87+
const outputData = tensorArray
88+
.sub(this.$min)
89+
.div(this.$max.sub(this.$min))
90+
91+
if (Array.isArray(data)) {
92+
return outputData.arraySync()
93+
94+
} else if (data instanceof Series) {
95+
return new Series(outputData, {
96+
index: data.index,
97+
});
98+
99+
} else if (data instanceof DataFrame) {
100+
return new DataFrame(outputData, {
101+
index: data.index,
102+
columns: data.columns,
103+
});
104+
} else {
105+
return outputData
106+
}
107+
}
108+
109+
/**
110+
* Fit the data and transform it
111+
* @param data Array, Tensor, DataFrame or Series object
112+
* @returns Array, Tensor, DataFrame or Series object
113+
* @example
114+
* const scaler = new MinMaxScaler()
115+
* scaler.fitTransform([1, 2, 3, 4, 5])
116+
* // [0, 0.25, 0.5, 0.75, 1]
117+
* */
118+
public fitTransform(data: number[] | number[][] | Tensor | DataFrame | Series) {
119+
this.fit(data)
120+
return this.transform(data)
121+
}
122+
123+
/**
124+
* Inverse transform the data using the fitted scaler
125+
* @param data Array, Tensor, DataFrame or Series object
126+
* @returns Array, Tensor, DataFrame or Series object
127+
* @example
128+
* const scaler = new MinMaxScaler()
129+
* scaler.fit([1, 2, 3, 4, 5])
130+
* scaler.inverseTransform([0, 0.25, 0.5, 0.75, 1])
131+
* // [1, 2, 3, 4, 5]
132+
* */
133+
public inverseTransform(data: number[] | number[][] | Tensor | DataFrame | Series) {
134+
const tensorArray = this.$getTensor(data)
135+
const outputData = tensorArray
136+
.mul(this.$max.sub(this.$min))
137+
.add(this.$min)
138+
139+
if (Array.isArray(data)) {
140+
return outputData.arraySync()
141+
142+
} else if (data instanceof Series) {
143+
return new Series(outputData, {
144+
index: data.index,
145+
});
146+
147+
} else if (data instanceof DataFrame) {
148+
return new DataFrame(outputData, {
149+
index: data.index,
150+
columns: data.columns,
151+
});
152+
} else {
153+
return outputData
154+
}
155+
}
156+
157+
}
158+
159+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
/**
2+
* @license
3+
* Copyright 2021, JsData. All rights reserved.
4+
*
5+
* This source code is licensed under the MIT license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
8+
* Unless required by applicable law or agreed to in writing, software
9+
* distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and
12+
* limitations under the License.
13+
* ==========================================================================
14+
*/
15+
16+
import { tensor1d, Tensor, tensor2d, moments } from "@tensorflow/tfjs-node"
17+
import { DataFrame, Series } from "danfojs-node"
18+
import { is1DArray } from "../../utils"
19+
20+
/**
21+
* Standardize features by removing the mean and scaling to unit variance.
22+
* The standard score of a sample x is calculated as: `z = (x - u) / s`,
23+
* where `u` is the mean of the training samples, and `s` is the standard deviation of the training samples.
24+
*/
25+
export default class StandardScaler {
26+
private $std: Tensor
27+
private $mean: Tensor
28+
29+
constructor() {
30+
this.$std = tensor1d([])
31+
this.$mean = tensor1d([])
32+
}
33+
34+
private $getTensor(data: number[] | number[][] | Tensor | DataFrame | Series) {
35+
let $tensorArray;
36+
37+
if (data instanceof Array) {
38+
if (is1DArray(data)) {
39+
$tensorArray = tensor1d(data as number[])
40+
} else {
41+
$tensorArray = tensor2d(data)
42+
}
43+
} else if (data instanceof DataFrame) {
44+
$tensorArray = tensor2d(data.values as number[][])
45+
} else if (data instanceof Series) {
46+
$tensorArray = tensor1d(data.values as number[])
47+
} else if (data instanceof Tensor) {
48+
$tensorArray = data
49+
} else {
50+
throw new Error("ParamError: data must be one of Array, DataFrame or Series")
51+
}
52+
return $tensorArray
53+
}
54+
/**
55+
* Fit a StandardScaler to the data.
56+
* @param data Array, Tensor, DataFrame or Series object
57+
* @returns StandardScaler
58+
* @example
59+
* const scaler = new StandardScaler()
60+
* scaler.fit([1, 2, 3, 4, 5])
61+
*/
62+
public fit(data: number[] | number[][] | Tensor | DataFrame | Series) {
63+
const tensorArray = this.$getTensor(data)
64+
this.$std = moments(tensorArray, 0).variance.sqrt();
65+
this.$mean = tensorArray.mean(0);
66+
return this
67+
}
68+
69+
/**
70+
* Transform the data using the fitted scaler
71+
* @param data Array, Tensor, DataFrame or Series object
72+
* @returns Array, Tensor, DataFrame or Series object
73+
* @example
74+
* const scaler = new StandardScaler()
75+
* scaler.fit([1, 2, 3, 4, 5])
76+
* scaler.transform([1, 2, 3, 4, 5])
77+
* // [0.0, 0.0, 0.0, 0.0, 0.0]
78+
* */
79+
public transform(data: number[] | number[][] | Tensor | DataFrame | Series) {
80+
const tensorArray = this.$getTensor(data)
81+
const outputData = tensorArray.sub(this.$mean).div(this.$std)
82+
83+
if (Array.isArray(data)) {
84+
return outputData.arraySync()
85+
86+
} else if (data instanceof Series) {
87+
return new Series(outputData, {
88+
index: data.index,
89+
});
90+
91+
} else if (data instanceof DataFrame) {
92+
return new DataFrame(outputData, {
93+
index: data.index,
94+
columns: data.columns,
95+
});
96+
} else {
97+
return outputData
98+
}
99+
}
100+
101+
/**
102+
* Fit and transform the data using the fitted scaler
103+
* @param data Array, Tensor, DataFrame or Series object
104+
* @returns Array, Tensor, DataFrame or Series object
105+
* @example
106+
* const scaler = new StandardScaler()
107+
* scaler.fit([1, 2, 3, 4, 5])
108+
* scaler.fitTransform([1, 2, 3, 4, 5])
109+
* // [0.0, 0.0, 0.0, 0.0, 0.0]
110+
* */
111+
public fitTransform(data: number[] | number[][] | Tensor | DataFrame | Series) {
112+
this.fit(data)
113+
return this.transform(data)
114+
}
115+
116+
/**
117+
* Inverse transform the data using the fitted scaler
118+
* @param data Array, Tensor, DataFrame or Series object
119+
* @returns Array, Tensor, DataFrame or Series object
120+
* @example
121+
* const scaler = new StandardScaler()
122+
* scaler.fit([1, 2, 3, 4, 5])
123+
* scaler.transform([1, 2, 3, 4, 5])
124+
* // [0.0, 0.0, 0.0, 0.0, 0.0]
125+
* scaler.inverseTransform([0.0, 0.0, 0.0, 0.0, 0.0])
126+
* // [1, 2, 3, 4, 5]
127+
* */
128+
public inverseTransform(data: number[] | number[][] | Tensor | DataFrame | Series) {
129+
const tensorArray = this.$getTensor(data)
130+
const outputData = tensorArray.mul(this.$std).add(this.$mean)
131+
132+
if (Array.isArray(data)) {
133+
return outputData.arraySync()
134+
135+
} else if (data instanceof Series) {
136+
return new Series(outputData, {
137+
index: data.index,
138+
});
139+
140+
} else if (data instanceof DataFrame) {
141+
return new DataFrame(outputData, {
142+
index: data.index,
143+
columns: data.columns,
144+
});
145+
} else {
146+
return outputData
147+
}
148+
}
149+
}
150+
151+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import { assert } from "chai";
2+
import { MinMaxScaler } from "../../../dist";
3+
import { Series, DataFrame } from "danfojs-node"
4+
5+
describe("MinMaxscaler", function () {
6+
7+
it("Standardize values in a DataFrame using a MinMaxScaler", function () {
8+
const data = [[-1, 2], [-0.5, 6], [0, 10], [1, 18]];
9+
const scaler = new MinMaxScaler();
10+
11+
const expected = [[0, 0], [0.25, 0.25], [0.5, 0.5], [1, 1]];
12+
const transformedData = [[1.5, 0.]];
13+
14+
scaler.fit(new DataFrame(data));
15+
const resultDf = scaler.transform(new DataFrame(data)) as DataFrame;
16+
assert.deepEqual(resultDf.values, expected);
17+
assert.deepEqual(scaler.transform([[2, 2]]) as any, transformedData);
18+
});
19+
it("fitTransform using a MinMaxScaler", function () {
20+
const data = [[-1, 2], [-0.5, 6], [0, 10], [1, 18]];
21+
const scaler = new MinMaxScaler();
22+
const resultDf = scaler.fitTransform(new DataFrame(data)) as DataFrame;
23+
24+
const expected = [[0, 0], [0.25, 0.25], [0.5, 0.5], [1, 1]];
25+
assert.deepEqual(resultDf.values, expected);
26+
});
27+
it("InverseTransform with MinMaxScaler", function () {
28+
const scaler = new MinMaxScaler();
29+
scaler.fit([1, 2, 3, 4, 5])
30+
const resultTransform = scaler.transform([1, 2, 3, 4, 5])
31+
const resultInverse = scaler.inverseTransform([0, 0.25, 0.5, 0.75, 1])
32+
33+
assert.deepEqual(resultTransform, [0, 0.25, 0.5, 0.75, 1]);
34+
assert.deepEqual([1, 2, 3, 4, 5], resultInverse);
35+
});
36+
it("Index and columns are kept after transformation", function () {
37+
const data = [[-1, 2], [-0.5, 6], [0, 10], [1, 18]];
38+
const df = new DataFrame(data, { index: [1, 2, 3, 4], columns: ["a", "b"] });
39+
40+
const scaler = new MinMaxScaler();
41+
scaler.fit(df);
42+
const resultDf = scaler.transform(df) as DataFrame
43+
44+
assert.deepEqual(resultDf.index, [1, 2, 3, 4]);
45+
assert.deepEqual(resultDf.columns, ["a", "b"]);
46+
});
47+
it("Standardize values in a Series using a MinMaxScaler", function () {
48+
const data = [-1, 2, -0.5, 60, 101, 18];
49+
const scaler = new MinMaxScaler();
50+
const result = [0, 0.029411764815449715, 0.0049019609577953815, 0.5980392098426819, 1, 0.18627451360225677];
51+
const transformedData = [0.029411764815449715, 0.029411764815449715];
52+
scaler.fit(new Series(data))
53+
assert.deepEqual((scaler.transform(new Series(data)) as Series).values, result);
54+
assert.deepEqual(scaler.transform([2, 2]), transformedData);
55+
});
56+
});

0 commit comments

Comments
 (0)