GradientBoost.scala
// Wei Chen - Gradient Boost
// 2022-12-27
package com.scalaml.algorithm

import com.scalaml.general.MatrixFunc._

class GradientBoost() extends Regression {
    val algoname: String = "GradientBoost"
    val version: String = "0.1"

    // Chain of base regressors; each stage is trained against the residual left by the previous one.
    var regressors = Array[Regression]()

    // Reset the model to an untrained state.
    override def clear(): Boolean = {
        regressors = Array[Regression]()
        true
    }

    // Read the base regressors from the parameter map (key "REGRESSORS" or "regressors");
    // default to a single StochasticGradientDecent when neither key is present.
    override def config(paras: Map[String, Any]): Boolean = try {
        regressors = paras.getOrElse("REGRESSORS",
            paras.getOrElse("regressors", Array(new StochasticGradientDecent): Any)
        ).asInstanceOf[Array[Regression]]
        true
    } catch { case e: Exception =>
        Console.err.println(e)
        false
    }

    // Train the stages sequentially: each regressor fits the labels shifted by the
    // residual (actual minus predicted) of the stage before it.
    override def train(data: Array[(Double, Array[Double])]): Boolean = {
        var check = regressors.nonEmpty
        var residue = Array.fill(data.size)(0.0)
        for (regressor <- regressors) {
            val tmpdata = data.zip(residue).map { case (d, r) => (d._1 + r, d._2) }
            check &= regressor.train(tmpdata)
            residue = arrayminus(data.map(_._1), regressor.predict(data.map(_._2)))
        }
        check
    }

    // Combine the stages by averaging their predictions.
    override def predict(data: Array[Array[Double]]): Array[Double] = {
        val results = regressors.map(regressor => regressor.predict(data))
        matrixaccumulate(results).map(_ / regressors.size)
    }
}
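Below is a minimal usage sketch showing how GradientBoost might be configured and trained. It assumes the Regression trait and the StochasticGradientDecent class from this repository are available on the classpath; the toy data, the object name GradientBoostExample, and the choice of two base regressors are illustrative assumptions, and each base regressor may also need its own config call with implementation-specific parameters not shown here.

// Usage sketch (assumption: Regression and StochasticGradientDecent come from
// com.scalaml.algorithm in this repository; the data below is made up).
import com.scalaml.algorithm.{GradientBoost, Regression, StochasticGradientDecent}

object GradientBoostExample extends App {
    // Toy regression data: (label, features), roughly y = 2 * x + 1.
    val data: Array[(Double, Array[Double])] = Array(
        (1.1, Array(0.0)),
        (3.0, Array(1.0)),
        (4.9, Array(2.0)),
        (7.1, Array(3.0)),
        (9.0, Array(4.0))
    )

    val gb = new GradientBoost()

    // Two boosting stages; the second is trained on the labels shifted by the
    // residual of the first, as in GradientBoost.train above.
    gb.config(Map("regressors" -> Array[Regression](
        new StochasticGradientDecent(),
        new StochasticGradientDecent()
    )))

    if (gb.train(data)) {
        // Predict on the training features and print the averaged stage outputs.
        gb.predict(data.map(_._2)).foreach(println)
    }
}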