-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathGradientBoostTest.scala
53 lines (45 loc) · 1.89 KB
/
GradientBoostTest.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
// Wei Chen - Gradient Boost Test
// 2022-12-27
import com.scalaml.TestData._
import com.scalaml.general.MatrixFunc._
import com.scalaml.algorithm._
import org.scalatest.funsuite.AnyFunSuite
/** ScalaTest suite exercising [[com.scalaml.algorithm.GradientBoost]].
  *
  * One `gb` instance is shared across all tests. Every test below calls
  * `gb.clear()` before using it and asserts that the call succeeds,
  * presumably so that state left behind by an earlier test is reset.
  */
class GradientBoostSuite extends AnyFunSuite {
  val gb = new GradientBoost()

  test("GradientBoost Test : Clear") {
    assert(gb.clear())
  }

  test("GradientBoost Test : Linear Data") {
    assert(gb.clear())
    assert(gb.config(Map.empty[String, Any]))
    val trainingSet = LABELED_LINEAR_DATA.map { case (label, features) =>
      label.toDouble -> features
    }
    assert(gb.train(trainingSet))
    val rawScores = gb.predict(UNLABELED_LINEAR_DATA)
    // Map raw outputs onto +1.0 / -1.0 before comparing with the labels;
    // a score of exactly 0 maps to -1.0.
    val predictedLabels = rawScores.map { score =>
      if (score > 0) 1.0 else -1.0
    }
    assert(arraysimilar(predictedLabels, LABEL_LINEAR_DATA.map(label => label.toDouble), 0.9))
  }

  // The default (empty) configuration is expected NOT to reproduce the
  // nonlinear labels: the assertion is negated on purpose.
  test("GradientBoost Test : Nonlinear Data, 1 Model - WRONG") {
    assert(gb.clear())
    assert(gb.config(Map.empty[String, Any]))
    val trainingSet = LABELED_NONLINEAR_DATA.map { case (label, features) =>
      label.toDouble -> features
    }
    assert(gb.train(trainingSet))
    val rawScores = gb.predict(UNLABELED_NONLINEAR_DATA)
    assert(!arraysimilar(rawScores, LABEL_NONLINEAR_DATA.map(label => label.toDouble), 0.45))
  }

  // Same expectation as above, this time with four explicitly supplied
  // regressors: predictions should still not match the nonlinear labels.
  test("GradientBoost Test : Nonlinear Data, 4 Models - WRONG") {
    // Typed as Any so the config map below is inferred as Map[String, Any].
    val candidateModels: Any = Array(
      new MultipleLinearRegression,
      new MultivariateLinearRegression,
      new StochasticGradientDecent,
      new RegressionTree
    )
    assert(gb.clear())
    assert(gb.config(Map("regressors" -> candidateModels)))
    val trainingSet = LABELED_NONLINEAR_DATA.map { case (label, features) =>
      label.toDouble -> features
    }
    assert(gb.train(trainingSet))
    val rawScores = gb.predict(UNLABELED_NONLINEAR_DATA)
    assert(!arraysimilar(rawScores, LABEL_NONLINEAR_DATA.map(label => label.toDouble), 0.45))
  }

  // Both config and train are expected to return false for bad input:
  // a non-array "regressors" value, and rows whose feature vectors differ
  // in length (the second row's features are empty).
  test("GradientBoost Test : Invalid Config & Data") {
    assert(gb.clear())
    assert(!gb.config(Map("regressors" -> "test")))
    // Kept inline: the element types are inferred from train's parameter type.
    assert(!gb.train(Array((1, Array(1, 2)), (1, Array()))))
  }
}