-
-
Notifications
You must be signed in to change notification settings - Fork 405
/
Copy pathtuneThreshold.R
76 lines (70 loc) · 2.79 KB
/
tuneThreshold.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
#' @title Tune prediction threshold.
#'
#' @description
#' Optimizes the threshold of predictions based on probabilities.
#' Works for classification and multilabel tasks.
#' Uses [BBmisc::optimizeSubInts] for normal binary class problems and
#' [GenSA::GenSA] for multiclass and multilabel problems.
#'
#' @template arg_pred
#' @param measure ([Measure])\cr
#'   Performance measure to optimize.
#'   Default is the default measure for the task.
#' @param task ([Task])\cr
#'   Learning task. Rarely needed,
#'   only when required for the performance measure.
#' @param model ([WrappedModel])\cr
#'   Fitted model. Rarely needed,
#'   only when required for the performance measure.
#' @param nsub (`integer(1)`)\cr
#'   Passed to [BBmisc::optimizeSubInts] for 2class problems.
#'   Default is 20.
#' @param control ([list])\cr
#'   Control object for [GenSA::GenSA] when used.
#'   Named entries override the internal defaults
#'   (`smooth = FALSE`, `simple.function = TRUE`, `max.call = 3000L`,
#'   `temperature = 250`, `visiting.param = 2.5`, `acceptance.param = -15`).
#'   Default is empty list, i.e. the internal defaults are used unchanged.
#' @return ([list]). A named list with the following components:
#'   `th` is the optimal threshold, `perf` the performance value
#'   (on the original scale of `measure`).
#'   If the predicted probabilities contain missing values, no optimization
#'   is done and a list with `th = NA`, the unchanged `pred`, an empty
#'   `th.seq` and an empty `perf` is returned instead.
#' @family tune
#' @export
tuneThreshold = function(pred, measure, task, model, nsub = 20L, control = list()) {
  checkPrediction(pred, task.type = c("classif", "multilabel"), predict.type = "prob")
  td = pred$task.desc
  ttype = td$type
  measure = checkMeasures(measure, td)[[1L]]
  if (!missing(task)) {
    assertClass(task, classes = "SupervisedTask")
  }
  if (!missing(model)) {
    assertClass(model, classes = "WrappedModel")
  }
  assertList(control)

  probs = getPredictionProbabilities(pred)

  # brutally return NA if we find any NA in the predicted probs...
  if (anyMissing(probs)) {
    return(list(th = NA, pred = pred, th.seq = numeric(0), perf = numeric(0)))
  }

  cls = td$class.levels
  k = length(cls)
  # one threshold per class (GenSA) vs. a single scalar threshold (optimizeSubInts)
  use.gensa = ttype == "multilabel" || k > 2L
  # both optimizers are run as minimizers, so measures that are maximized get negated
  sign.mult = if (measure$minimize) 1 else -1

  # objective: signed performance of the prediction after applying threshold(s) x
  fitn = function(x) {
    if (use.gensa) {
      names(x) = cls
    }
    sign.mult * performance(setThreshold(pred, x), measure, task, model, simpleaggr = TRUE)
  }

  if (use.gensa) {
    requirePackages("GenSA", why = "tuneThreshold", default.method = "load")
    start = rep(1 / k, k)
    # internal defaults; anything the caller passes in 'control' takes precedence
    ctrl = list(smooth = FALSE, simple.function = TRUE, max.call = 3000L, temperature = 250,
      visiting.param = 2.5, acceptance.param = -15)
    ctrl = utils::modifyList(ctrl, control)
    or = GenSA::GenSA(par = start, fn = fitn, lower = rep(0, k),
      upper = rep(1, k), control = ctrl)
    # NOTE(review): the returned thresholds are rescaled to sum to 1, while fitn was
    # evaluated at the unscaled or$par - confirm setThreshold is scale-invariant for
    # every task type handled here (in particular multilabel).
    th = or$par / sum(or$par)
    names(th) = cls
    # undo the sign flip applied in fitn so perf is on the measure's own scale
    perf = sign.mult * or$value
  } else { # classif with k = 2
    # maximum = FALSE, because fitn already turns the problem into a minimization
    or = optimizeSubInts(f = fitn, lower = 0, upper = 1, maximum = FALSE, nsub = nsub)
    th = or[[1]]
    # undo the sign flip applied in fitn so perf is on the measure's own scale
    perf = sign.mult * or$objective
  }
  return(list(th = th, perf = perf))
}