Skip to content

Commit 4fa3bd3

Browse files
committed
UpperConfidenceBound
1 parent 92b681a commit 4fa3bd3

File tree

5 files changed

+98
-4
lines changed

5 files changed

+98
-4
lines changed

README.md

+2
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,8 @@ A very lightweight Scala machine learning library that provides some basic ML al
8888

8989
- [x] Epsilon Greedy Search [[Code]](src/main/scala/algorithm/optimization/EpsilonGreedy.scala) [[Usage]](src/test/scala/algorithm/optimization/EpsilonGreedyTest.scala)
9090

91+
- [x] Upper Confidence Bound [[Code]](src/main/scala/algorithm/optimization/UpperConfidenceBound.scala) [[Usage]](src/test/scala/algorithm/optimization/UpperConfidenceBoundTest.scala)
92+
9193
### Reinforcement Learning :
9294

9395
- [x] Naive Feedback [[Code]](src/main/scala/algorithm/reinforcement/NaiveFeedback.scala) [[Usage]](src/test/scala/algorithm/reinforcement/NaiveFeedbackTest.scala)

src/main/scala/algorithm/optimization/EpsilonGreedy.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ class EpsilonGreedy {
1616
if (scores != null)
1717
currentScores = scores
1818
if (currentScores == null)
19-
currentScores = new Array[Double](size)
19+
currentScores = Array.fill[Double](size)(Double.MinValue)
2020
if (math.random < epsilon) {
2121
val randSelect = (math.random * size).toInt
2222
val value = evaluation(choices(randSelect))
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
// Wei Chen - Upper Confidence Bound
2+
// 2020-03-08
3+
4+
package com.scalaml.algorithm
5+
6+
/** Upper Confidence Bound search over a fixed list of choices.
  *
  * The running state lives in `currentStats`: one `(mean value, evaluation count)`
  * pair per choice, indexed like the `choices` array passed to `search`.
  */
class UpperConfidenceBound {
  /** Per-choice `(running mean, number of evaluations)`; stays `null` until
    * `search` is called (or the caller assigns it directly).
    */
  var currentStats: Array[(Double, Int)] = null

  /** Picks the index with the highest upper confidence bound.
    *
    * Each choice is scored as `mean + c * sqrt(ln(t) / pulls)`, where `t` is the
    * total number of evaluations recorded so far plus one.
    *
    * @param c exploration weight; larger values favour rarely evaluated choices
    * @return index of the first choice with the highest score
    */
  def select(c: Double): Int = {
    // The log term must use the total number of evaluations over all choices,
    // not the number of distinct choices tried, so the bonus keeps growing.
    val totalPulls = currentStats.map(_._2).sum
    val logTerm = math.log(totalPulls + 1)
    val bounds = currentStats.map { case (mean, pulls) =>
      // 1e-12 avoids division by zero and gives unevaluated choices a huge bonus
      // (except on the very first call, where logTerm is 0).
      mean + c * math.sqrt(logTerm / (pulls + 1e-12))
    }
    bounds.indexOf(bounds.max)
  }

  /** Folds one observed value into the running mean of choice `i`.
    *
    * @param i     index of the evaluated choice
    * @param value value returned by the evaluation
    */
  def add(i: Int, value: Double): Unit = {
    val (mean, pulls) = currentStats(i)
    val updatedMean = (mean * pulls + value) / (pulls + 1)
    currentStats(i) = (updatedMean, pulls + 1)
  }

  /** Performs one selection/evaluation step and returns the best choice so far.
    *
    * @param evaluation scoring function applied to the selected choice
    * @param choices    candidate inputs; must not be empty
    * @param scores     optional statistics to continue from; when non-null it
    *                   replaces `currentStats` and is updated in place
    * @param c          exploration weight forwarded to `select`
    * @return the choice whose running mean is currently the highest
    * @throws IllegalArgumentException if `choices` is empty
    */
  def search(
    evaluation: Array[Double] => Double,
    choices: Array[Array[Double]],
    scores: Array[(Double, Int)] = null,
    c: Double = 1
  ): Array[Double] = {
    require(choices.nonEmpty, "choices must not be empty")
    if (scores != null)
      currentStats = scores
    if (currentStats == null)
      currentStats = Array.fill[(Double, Int)](choices.length)((0.0, 0))
    val picked = select(c)
    add(picked, evaluation(choices(picked)))
    // First index holding the highest running mean.
    val bestMean = currentStats.map(_._1).max
    choices(currentStats.indexWhere(_._1 == bestMean))
  }
}

src/test/scala/algorithm/optimization/EpsilonGreedyTest.scala

+3-3
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,11 @@ class EpsilonGreedySuite extends AnyFunSuite {
2020
)
2121
val epsilon: Double = 0.1
2222

23-
test("GeneAlgorithm Test : Initial") {
23+
test("EpsilonGreedy Test : Initial") {
2424
assert(eg.currentScores == null)
2525
}
2626

27-
test("GeneAlgorithm Test : Search - Start") {
27+
test("EpsilonGreedy Test : Search - Start") {
2828
for (i <- 0 until 1000)
2929
eg.search(evaluation, choices, null, epsilon)
3030
assert(eg.currentScores.size == choices.size)
@@ -33,7 +33,7 @@ class EpsilonGreedySuite extends AnyFunSuite {
3333
assert((best.head - 0.7).abs < 0.05)
3434
}
3535

36-
test("GeneAlgorithm Test : Search - Continue") {
36+
test("EpsilonGreedy Test : Search - Continue") {
3737
var scores: Array[Double] = Array(0, 0, 1 / 1.3, 0)
3838
for (i <- 0 until 1000) {
3939
eg.search(evaluation, choices, scores, epsilon)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
// Wei Chen - Upper Confidence Bound Test
2+
// 2020-03-08
3+
4+
import com.scalaml.general.MatrixFunc._
5+
import com.scalaml.algorithm.UpperConfidenceBound
6+
import org.scalatest.funsuite.AnyFunSuite
7+
8+
/** Test suite for UpperConfidenceBound.
  *
  * All tests use the single `ucb` instance declared below.
  */
class UpperConfidenceBoundSuite extends AnyFunSuite {

  val ucb = new UpperConfidenceBound()

  // Scores a choice by closeness of its first element to 0.7 (maximum 1.0 at 0.7).
  def evaluation(arr: Array[Double]): Double =
    1.0 / (math.abs(arr.head - 0.7) + 1.0)

  // Four single-element candidates; 0.7 is the one the evaluation favours.
  val choices: Array[Array[Double]] =
    Array(0.7, 0.8, 1.0, 0.5).map(v => Array(v))
  val c: Double = 1

  test("UpperConfidenceBound Test : Initial") {
    // No statistics exist before the first search.
    assert(ucb.currentStats == null)
  }

  test("UpperConfidenceBound Test : Search - Start") {
    (0 until 100).foreach { _ =>
      ucb.search(evaluation, choices, null, c)
    }
    assert(ucb.currentStats.size == choices.size)

    val best = ucb.search(evaluation, choices, null, c)
    assert((best.head - 0.7).abs < 0.05)
  }

  test("UpperConfidenceBound Test : Search - Continue") {
    // Start from statistics where only the third choice has been evaluated once.
    var stats: Array[(Double, Int)] =
      Array((0.0, 0), (0.0, 0), (1 / 1.3, 1), (0.0, 0))
    (0 until 100).foreach { _ =>
      ucb.search(evaluation, choices, stats, c)
      stats = ucb.currentStats
    }
    assert(ucb.currentStats.size == stats.size)

    val best = ucb.search(evaluation, choices, stats, c)
    assert((best.head - 0.7).abs < 0.05)
  }

}

0 commit comments

Comments
 (0)