#' @title Task Class
#'
#' @include mlr_reflections.R
#' @include warn_deprecated.R
#'
#' @description
#' This is the abstract base class for [TaskSupervised] and [TaskUnsupervised].
#' [TaskClassif] and [TaskRegr] inherit from [TaskSupervised].
#' More supervised tasks are implemented in \CRANpkg{mlr3proba}, unsupervised cluster tasks
#' in package \CRANpkg{mlr3cluster}.
#'
#' Tasks serve two purposes:
#'
#' 1. Tasks wrap a [DataBackend], an object to transparently interface different data storage types.
#' 2. Tasks store meta-information, such as the role of the individual columns in the [DataBackend].
#' For example, for a classification task a single column must be marked as target column, and others as features.
#'
#' Predefined (toy) tasks are stored in the [dictionary][mlr3misc::Dictionary] [mlr_tasks],
#' e.g. [`penguins`][mlr_tasks_penguins] or [`california_housing`][mlr_tasks_california_housing].
#' More toy tasks can be found in the dictionary after loading \CRANpkg{mlr3data}.
#'
#' @template param_id
#' @template param_backend
#' @template param_task_type
#' @template param_rows
#' @template param_cols
#' @template param_data_format
#' @template param_label
#' @template param_extra_args
#'
#' @section S3 methods:
#' * `as.data.table(t)`\cr
#' [Task] -> [data.table::data.table()]\cr
#' Returns the complete data as [data.table::data.table()].
#' * `head(t)`\cr
#' Calls [head()] on the task's data.
#' * `summary(t)`\cr
#' Calls [summary()] on the task's data.
#'
#'
#' @section Task mutators:
#' The following methods change the task in-place:
#' * Any modification of the lists `$col_roles` or `$row_roles`.
#' This provides a different "view" on the data without altering the data itself.
#' This may affect, e.g., `$data`, `$nrow`, `$ncol`, `$n_features`, `$row_ids`, and `$feature_names`.
#' Altering `$col_roles` may affect, e.g., `$data`, `$ncol`, `$n_features`, and `$feature_names`.
#' Altering `$row_roles` may affect, e.g., `$data`, `$nrow`, and `$row_ids`.
#' * Modification of column or row roles via `$set_col_roles()` or `$set_row_roles()`, respectively.
#' They are an alternative to directly accessing `$col_roles` or `$row_roles`, with the same side effects.
#' * `$select()` and `$filter()` subset the set of active features or rows in `$col_roles` or `$row_roles`, respectively.
#' * `$cbind()` and `$rbind()` change the task in-place by binding new columns or rows to the data.
#' * `$rename()` changes column names.
#' * `$set_levels()` and `$droplevels()` update the field `$col_info` to automatically repair factor levels while querying data with `$data()`.
#'
#' @template seealso_task
#' @concept Task
#' @export
#' @examples
#' # We use the inherited class TaskClassif here,
#' # because the base class `Task` is not intended for direct use
#' task = TaskClassif$new("penguins", palmerpenguins::penguins, target = "species")
#'
#' task$nrow
#' task$ncol
#' task$feature_names
#' task$formula()
#'
#' # de-select "year"
#' task$select(setdiff(task$feature_names, "year"))
#'
#' task$feature_names
#'
#' # Add new column "foo"
#' task$cbind(data.frame(foo = 1:344))
#' head(task)
Task = R6Class("Task",
public = list(
#' @template field_label
label = NA_character_,
#' @template field_task_type
task_type = NULL,
#' @field backend ([DataBackend])\cr
#' Abstract interface to the data of the task.
backend = NULL,
#' @field col_info ([data.table::data.table()])\cr
#' Table with 5 columns, mainly for internal purposes:
#' - `"id"` (`character()`) stores the name of the column.
#' - `"type"` (`character()`) holds the storage type of the variable, e.g. `integer`, `numeric` or `character`.
#' See [mlr_reflections$task_feature_types][mlr_reflections] for a complete list of allowed types.
#' - `"levels"` (`list()`) stores a vector of distinct values (levels) for ordered and unordered factor variables.
#' - `"label"` (`character()`) stores a vector of prettier, formatted column names.
#' - `"fix_factor_levels"` (`logical()`) stores flags which determine if the levels of the respective variable
#' need to be reordered after querying the data from the [DataBackend].
#'
#' Note that all columns of the [DataBackend] are listed in this table, including columns which are not selected
#' or have no role.
col_info = NULL,
#' @template field_man
man = NA_character_,
#' @field extra_args (named `list()`)\cr
#' Additional arguments set during construction.
#' Required for [convert_task()].
extra_args = NULL,
#' @field mlr3_version (`package_version`)\cr
#' Package version of `mlr3` used to create the task.
mlr3_version = NULL,
#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
#'
#' Note that this object is typically constructed via a derived class, e.g. [TaskClassif] or [TaskRegr].
initialize = function(id, task_type, backend, label = NA_character_, extra_args = list()) {
private$.id = assert_string(id, min.chars = 1L)
self$label = assert_string(label, na.ok = TRUE)
self$task_type = assert_choice(task_type, mlr_reflections$task_types$type)
if (!inherits(backend, "DataBackend")) {
self$backend = as_data_backend(backend)
} else {
self$backend = assert_backend(backend)
}
cn = self$backend$colnames
rn = self$backend$rownames
assert_names(cn, "unique", .var.name = "column names")
if (any(grepl("%", cn, fixed = TRUE))) {
stopf("Column names may not contain special character '%%'")
}
self$col_info = col_info(self$backend)
self$col_info$label = NA_character_
self$col_info$fix_factor_levels = FALSE
assert_subset(self$col_info$type, mlr_reflections$task_feature_types, .var.name = "feature types")
pmap(self$col_info,
function(id, levels, ...) {
assert_character(levels, any.missing = FALSE, min.len = 1L, null.ok = TRUE,
.var.name = sprintf("levels of '%s'", id))
}
)
cn = self$col_info$id # note: this sorts the columns!
private$.row_roles = list(use = rn)
private$.col_roles = named_list(mlr_reflections$task_col_roles[[task_type]], character())
private$.col_roles$feature = setdiff(cn, self$backend$primary_key)
self$extra_args = assert_list(extra_args, names = "unique")
self$mlr3_version = mlr_reflections$package_version
},
#' @description
#' Deprecated.
#'
#' @param ratio (`numeric(1)`)\cr
#' The proportion of datapoints to use as validation data.
#' @param ids (`integer()`)\cr
#' The row ids to use as validation data.
#' @param remove (`logical(1)`)\cr
#' If `TRUE` (default), the `row_ids` are removed from the primary task's active `"use"` rows, ensuring a
#' disjoint split between the train and validation data.
#'
#' @return Modified `Self`.
divide = function(ratio = NULL, ids = NULL, remove = TRUE) {
.Deprecated("field $internal_valid_task")
assert_flag(remove)
private$.hash = NULL
if (!xor(is.null(ratio), is.null(ids))) {
stopf("Provide a ratio or ids to create a validation task, but not both (Task '%s').", self$id)
}
valid_ids = if (!is.null(ratio)) {
assert_numeric(ratio, lower = 0, upper = 1, any.missing = FALSE)
partition(self, ratio = 1 - ratio)$test
} else {
assert_row_ids(ids, null.ok = FALSE)
}
prev_internal_valid = private$.internal_valid_task
if (!is.null(prev_internal_valid)) {
lg$debug("Task %s already had an internal validation task that is being overwritten.", self$id)
# in case something goes wrong
on.exit({private$.internal_valid_task = prev_internal_valid}, add = TRUE)
private$.internal_valid_task = NULL
}
private$.internal_valid_task = self$clone(deep = TRUE)
private$.internal_valid_task$row_roles$use = valid_ids
if (remove) {
self$row_roles$use = setdiff(self$row_roles$use, valid_ids)
}
on.exit({}, add = FALSE)
invisible(self)
},
#' @description
#' Opens the corresponding help page referenced by field `$man`.
help = function() {
open_help(self$man)
},
#' @description
#' Helper for print outputs.
#' @param ... (ignored).
format = function(...) {
sprintf("<%s:%s>", class(self)[1L], self$id)
},
#' @description
#' Printer.
#' @param ... (ignored).
print = function(...) {
catf("%s (%i x %i)%s", format(self), self$nrow, self$ncol,
if (is.null(self$label) || is.na(self$label)) "" else paste0(": ", self$label))
roles = private$.col_roles
roles = roles[lengths(roles) > 0L]
# print additional columns as specified in reflections
before = mlr_reflections$task_print_col_roles$before
iwalk(before[before %chin% names(roles)], function(role, str) {
catn(str_indent(sprintf("* %s:", str), roles[[role]]))
})
catf(str_indent("* Target:", self$target_names))
catf(str_indent("* Properties:", self$properties))
types = self$feature_types
if (nrow(types)) {
id = type = NULL
catf("* Features (%i):", nrow(types))
types = types[, list(N = .N, feats = str_collapse(id, n = 100L)), by = "type"][, "type" := translate_types(type)]
setorderv(types, "N", order = -1L)
pmap(types, function(type, N, feats) {
catn(str_indent(sprintf(" - %s (%i):", type, N), feats, exdent = 4L))
})
}
# print additional columns as specified in reflections
after = mlr_reflections$task_print_col_roles$after
iwalk(after[after %chin% names(roles)], function(role, str) {
catn(str_indent(sprintf("* %s:", str), roles[[role]]))
})
if (!is.null(private$.internal_valid_task)) {
catf(str_indent("* Validation Task:", sprintf("(%ix%i)", private$.internal_valid_task$nrow, private$.internal_valid_task$ncol)))
}
if (!is.null(self$characteristics)) {
catf(str_indent("* Characteristics: ", as_short_string(self$characteristics)))
}
},
#' @description
#' Returns a slice of the data from the [DataBackend] as a `data.table`.
#' Rows default to observations with role `"use"`, and columns default to columns with roles `"target"` or `"feature"`.
#' Rows must be a subset of `$row_ids`.
#' If `rows` or `cols` are specified which do not exist in the [DataBackend], an exception is raised.
#'
#' Rows and columns are returned in the order specified via the arguments `rows` and `cols`.
#' If `rows` is `NULL`, rows are returned in the order of `task$row_ids`.
#' If `cols` is `NULL`, the column order defaults to `c(task$target_names, task$feature_names)`.
#' Note that it is recommended to **not** rely on the order of columns, and instead always address columns with their respective column name.
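#'
#' A minimal usage sketch, assuming the predefined `"iris"` task from [mlr_tasks] (retrieved via [tsk()]):
#' ```r
#' task = tsk("iris")
#' # rows come back in the requested order, columns in the order given in `cols`
#' task$data(rows = c(3L, 1L), cols = c("Species", "Petal.Length"))
#' # defaults: all "use" rows, target column first, then the features
#' head(task$data())
#' ```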
#'
#' @param ordered (`logical(1)`)\cr
#' If `TRUE`, data is ordered according to the columns with column role `"order"`.
#'
#' @return Depending on the [DataBackend], but usually a [data.table::data.table()].
data = function(rows = NULL, cols = NULL, data_format, ordered = FALSE) {
assert_has_backend(self)
assert_flag(ordered)
if (!missing(data_format)) warn_deprecated("Task$data argument 'data_format'")
row_roles = private$.row_roles
col_roles = private$.col_roles
if (is.null(rows)) {
rows = row_roles$use
} else {
assert_subset(rows, self$row_roles$use)
if (is.double(rows)) {
rows = as.integer(rows)
}
}
if (is.null(cols)) {
query_cols = cols = c(col_roles$target, col_roles$feature)
} else {
assert_subset(cols, self$col_info$id)
query_cols = cols
}
reorder_rows = length(col_roles$order) > 0L && ordered
if (reorder_rows) {
query_cols = union(query_cols, col_roles$order)
}
data = self$backend$data(rows = rows, cols = query_cols)
if (length(query_cols) && nrow(data) != length(rows)) {
stopf("DataBackend did not return the queried rows correctly: %i requested, %i received.
The resampling was probably instantiated on a different task.", length(rows), nrow(data))
}
if (length(rows) && ncol(data) != length(query_cols)) {
stopf("DataBackend did not return the queried cols correctly: %i requested, %i received", length(query_cols), ncol(data))
}
.__i__ = self$col_info[["fix_factor_levels"]]
if (any(.__i__)) {
fix_factors = self$col_info[.__i__, c("id", "levels"), with = FALSE]
if (nrow(fix_factors)) {
# ordering is slow
if (nrow(fix_factors) > 1L) fix_factors = fix_factors[list(names(data)), on = "id", nomatch = NULL]
data = fix_factor_levels(data, levels = set_names(fix_factors$levels, fix_factors$id))
}
}
if (reorder_rows) {
setorderv(data, col_roles$order)[]
data = remove_named(data, setdiff(col_roles$order, cols))
}
return(data)
},
#' @description
#' Constructs a [formula()], e.g. `[target] ~ [feature_1] + [feature_2] + ... + [feature_k]`,
#' using the features provided in argument `rhs` (defaults to all columns with role `"feature"`, symbolized by `"."`).
#'
#' Note that it is currently not possible to change the formula.
#' However, \CRANpkg{mlr3pipelines} provides a pipe operator interfacing [stats::model.matrix()] for this purpose: `"modelmatrix"`.
#'
#' @param rhs (`character(1)`)\cr
#' Right-hand side of the formula. Defaults to `"."` (all features of the task).
#' @return [formula()].
formula = function(rhs = ".") {
formulate(self$target_names, rhs)
},
#' @description
#' Get the first `n` observations with role `"use"` of all columns with role `"target"` or `"feature"`.
#'
#' @param n (`integer(1)`).
#' @return [data.table::data.table()] with `n` rows.
head = function(n = 6L) {
assert_number(n, na.ok = FALSE)
ids = head(private$.row_roles$use, n)
self$data(rows = ids)
},
#' @description
#' Returns the distinct values for columns referenced in `cols` with storage type "factor" or "ordered".
#' Argument `cols` defaults to all such columns with role `"target"` or `"feature"`.
#'
#' Note that this function ignores the row roles; it returns all levels available in the [DataBackend].
#' To update the stored level information, e.g. after subsetting a task with `$filter()`, call `$droplevels()`.
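#'
#' A minimal sketch, assuming the predefined `"penguins"` task from [mlr_tasks]:
#' ```r
#' task = tsk("penguins")
#' # named list with the levels of all factor columns with role "target" or "feature"
#' task$levels()
#' # levels of a single column
#' task$levels("island")
#' ```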
#'
#' @return named `list()`.
levels = function(cols = NULL) {
if (is.null(cols)) {
cols = unlist(private$.col_roles[c("target", "feature")], use.names = FALSE)
cols = self$col_info[get("id") %chin% cols & get("type") %chin% c("factor", "ordered"), "id", with = FALSE][[1L]]
} else {
assert_subset(cols, self$col_info$id)
}
set_names(
fget(self$col_info, cols, "levels", "id"),
cols
)
},
#' @description
#' Returns the number of missing observations for columns referenced in `cols`.
#' Considers only active rows with row role `"use"`.
#' Argument `cols` defaults to all columns with role "target" or "feature".
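#'
#' A minimal sketch, assuming the predefined `"penguins"` task from [mlr_tasks], which contains missing values:
#' ```r
#' task = tsk("penguins")
#' # one count per target and feature column
#' task$missings()
#' # counts only for the requested columns
#' task$missings(cols = c("sex", "bill_length"))
#' ```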
#'
#' @return Named `integer()`.
missings = function(cols = NULL) {
assert_has_backend(self)
if (is.null(cols)) {
cols = unlist(private$.col_roles[c("target", "feature")], use.names = FALSE)
} else {
assert_subset(cols, self$col_info$id)
}
self$backend$missings(self$row_ids, cols = cols)
},
#' @description
#' Subsets the task, keeping only the rows specified via row ids `rows`.
#'
#' This operation mutates the task in-place.
#' See the section on task mutators for more information.
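#'
#' A minimal sketch, assuming the predefined `"iris"` task (150 rows) from [mlr_tasks]:
#' ```r
#' task = tsk("iris")
#' backup = task$clone()   # keep an unmodified copy
#' task$filter(1:10)       # modifies `task` in-place
#' task$nrow               # 10
#' backup$nrow             # 150
#' ```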
#'
#' @return
#' Returns the object itself, but modified **by reference**.
#' You need to explicitly `$clone()` the object beforehand if you want to keep
#' the object in its previous state.
filter = function(rows) {
assert_has_backend(self)
rows = assert_row_ids(rows)
private$.row_roles$use = assert_subset(rows, self$row_ids_backend)
private$.row_hash = NULL
private$.hash = NULL
invisible(self)
},
#' @description
#' Subsets the task, keeping only the features specified via column names `cols`.
#' Note that the target column cannot be deselected: `cols` must be a subset of the current features (`$feature_names`).
#'
#' This operation mutates the task in-place.
#' See the section on task mutators for more information.
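#'
#' A minimal sketch, assuming the predefined `"iris"` task from [mlr_tasks]:
#' ```r
#' task = tsk("iris")
#' task$select(c("Petal.Length", "Petal.Width"))
#' task$feature_names   # "Petal.Length" "Petal.Width"
#' # task$select("Species") would raise an error: the target is not a feature
#' ```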
#'
#' @return
#' Returns the object itself, but modified **by reference**.
#' You need to explicitly `$clone()` the object beforehand if you want to keep
#' the object in its previous state.
select = function(cols) {
assert_has_backend(self)
assert_character(cols)
assert_subset(cols, private$.col_roles$feature)
private$.hash = NULL
private$.col_hashes = NULL
private$.col_roles$feature = intersect(private$.col_roles$feature, cols)
invisible(self)
},
#' @description
#' Adds additional rows to the [DataBackend] stored in `$backend`.
#' New row ids are automatically created, unless `data` has a column whose name matches
#' the primary key of the [DataBackend] (`task$backend$primary_key`).
#' In case of name clashes of row ids, rows in `data` have higher precedence
#' and virtually overwrite the rows in the [DataBackend].
#'
#' All columns with the roles `"target"`, `"feature"`, `"weight"`, `"group"`, `"stratum"`,
#' `"order"`, and `"offset"` must be present in `data`.
#' Columns only present in `data` but not in the [DataBackend] of `task` will be discarded.
#'
#' This operation mutates the task in-place.
#' See the section on task mutators for more information.
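#'
#' A minimal sketch, assuming the predefined `"iris"` task (row ids `1:150`) from [mlr_tasks]:
#' ```r
#' task = tsk("iris")
#' # `data` has no primary key column, so new row ids are generated
#' task$rbind(task$data(rows = 1:2))
#' task$nrow               # 152
#' tail(task$row_ids, 2)   # 151 152
#' ```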
#'
#' @param data (`data.frame()`).
#'
#' @return
#' Returns the object itself, but modified **by reference**.
#' You need to explicitly `$clone()` the object beforehand if you want to keep
#' the object in its previous state.
rbind = function(data) {
assert_has_backend(self)
pk = self$backend$primary_key
rn = self$backend$rownames
pk_in_backend = TRUE
type_check = TRUE
if (is.data.frame(data)) {
pk_in_backend = pk %chin% names(data)
type_check = FALSE # done by auto-converter
keep_cols = intersect(names(data), self$col_info$id)
if (length(keep_cols) == pk_in_backend || nrow(data) == 0L) {
return(invisible(self))
}
if (!pk_in_backend) {
start = if (length(rn)) max(rn) + 1L else 1L
pk = seq(from = start, to = start + nrow(data) - 1L)
}
ci = self$col_info[list(keep_cols), on = "id"]
data = do.call(data.table, Map(auto_convert,
value = as.list(data)[ci$id],
id = ci$id, type = ci$type, levels = ci$levels))
data = as_data_backend(data, primary_key = pk)
} else {
assert_backend(data)
if (data$ncol <= 1L || data$nrow == 0L) {
return(invisible(self))
}
}
if (pk_in_backend && any(data$rownames %in% self$backend$rownames)) {
stopf("Cannot rbind data to task '%s', duplicated row ids", self$id)
}
# columns with these roles must be present in data
mandatory_roles = c("target", "feature", "weight", "group", "stratum", "order", "offset")
mandatory_cols = unlist(private$.col_roles[mandatory_roles], use.names = FALSE)
missing_cols = setdiff(mandatory_cols, data$colnames)
if (length(missing_cols)) {
stopf("Cannot rbind data to task '%s', missing the following mandatory columns: %s", self$id, str_collapse(missing_cols))
}
# merge col infos
tab = merge(self$col_info, col_info(data), by = "id",
all.x = TRUE, all.y = FALSE, suffixes = c("", "_y"), sort = TRUE)
# type check
if (type_check) {
type = type_y = NULL
ii = head(tab[type != type_y, which = TRUE], 1L)
if (length(ii)) {
stopf("Cannot rbind to task: Types do not match for column: %s (%s != %s)", tab$id[ii], tab$type[ii], tab$type_y[ii])
}
}
# merge factor levels
ii = tab[type %chin% c("factor", "ordered"), which = TRUE]
for (i in ii) {
x = tab[["levels"]][[i]]
y = tab[["levels_y"]][[i]]
if (any(y %nin% x)) {
set(tab, i = i, j = "levels", value = list(union(x, y)))
set(tab, i = i, j = "fix_factor_levels", value = TRUE)
}
}
tab[, c("type_y", "levels_y") := list(NULL, NULL)]
# everything looks good, modify task
private$.hash = NULL
self$backend = DataBackendRbind$new(self$backend, data)
self$col_info = tab[]
private$.row_roles$use = c(private$.row_roles$use, data$rownames)
invisible(self)
},
#' @description
#'
#' Adds additional columns to the [DataBackend] stored in `$backend`.
#'
#' The row ids must be provided as column in `data` (with column name matching the primary key name of the [DataBackend]).
#' If this column is missing, it is assumed that the rows are exactly in the order of `$row_ids`.
#' In case of name clashes of column names in `data` and [DataBackend], columns in `data` have higher precedence
#' and virtually overwrite the columns in the [DataBackend].
#'
#' This operation mutates the task in-place.
#' See the section on task mutators for more information.
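#'
#' A minimal sketch, assuming the predefined `"iris"` task from [mlr_tasks]:
#' ```r
#' task = tsk("iris")
#' # no primary key column in `data`: rows are matched to `task$row_ids` by position
#' task$cbind(data.frame(foo = seq_len(task$nrow)))
#' task$feature_names   # now also contains "foo"
#' ```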
#' @param data (`data.frame()`).
cbind = function(data) {
assert_has_backend(self)
pk = self$backend$primary_key
if (is.data.frame(data)) {
# binding data with 0 rows is explicitly allowed
if (ncol(data) == 0L) {
return(invisible(self))
}
row_ids = if (pk %nin% names(data)) {
data[[pk]] = self$row_ids
}
data = as_data_backend(data, primary_key = pk)
} else {
assert_backend(data)
if (data$ncol <= 1L) {
return(invisible(self))
}
assert_set_equal(self$row_ids, data$rownames)
}
# update col_info for existing columns
ci = col_info(data)
self$col_info = ujoin(self$col_info, ci, key = "id")
# add rows to col_info for new columns
self$col_info = rbindlist(list(
self$col_info,
insert_named(ci[!list(self$col_info), on = "id"], list(label = NA_character_, fix_factor_levels = FALSE))
), use.names = TRUE)
setkeyv(self$col_info, "id")
# add new features
private$.hash = NULL
private$.col_hashes = NULL
col_roles = private$.col_roles
private$.col_roles$feature = union(col_roles$feature, setdiff(data$colnames, c(pk, col_roles$target)))
# update backend
self$backend = DataBackendCbind$new(self$backend, data)
invisible(self)
},
#' @description
#' Renames columns by mapping column names in `old` to new column names in `new` (element-wise).
#'
#' This operation mutates the task in-place.
#' See the section on task mutators for more information.
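#'
#' A minimal sketch, assuming the predefined `"iris"` task from [mlr_tasks]:
#' ```r
#' task = tsk("iris")
#' task$rename(old = c("Petal.Length", "Petal.Width"), new = c("PL", "PW"))
#' task$feature_names   # "PL" "PW" "Sepal.Length" "Sepal.Width"
#' ```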
#'
#' @param old (`character()`)\cr
#' Old names.
#'
#' @param new (`character()`)\cr
#' New names.
#'
#' @return
#' Returns the object itself, but modified **by reference**.
#' You need to explicitly `$clone()` the object beforehand if you want to keep
#' the object in its previous state.
rename = function(old, new) {
assert_has_backend(self)
private$.hash = NULL
private$.col_hashes = NULL
self$backend = DataBackendRename$new(self$backend, old, new)
setkeyv(self$col_info[old, ("id") := new, on = "id"], "id")
private$.col_roles = map(private$.col_roles, map_values, old = old, new = new)
invisible(self)
},
#' @description
#' Modifies the roles in `$row_roles` **in-place**.
#'
#' @param rows (`integer()`)\cr
#' Row ids for which to change the roles.
#' @param roles (`character()`)\cr
#' Exclusively set rows to the specified `roles` (remove from other roles).
#' @param add_to (`character()`)\cr
#' Add rows with row ids `rows` to roles specified in `add_to`.
#' Rows keep their previous roles.
#' @param remove_from (`character()`)\cr
#' Remove rows with row ids `rows` from roles specified in `remove_from`.
#' Other row roles are preserved.
#'
#' @details
#' Roles are first set exclusively (argument `roles`), then added (argument `add_to`) and finally
#' removed (argument `remove_from`) from different roles.
#' Duplicated row ids are explicitly allowed, so you can replicate an observation by repeating its
#' `row_id`.
#'
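#' A minimal sketch, assuming the predefined `"iris"` task (150 rows) from [mlr_tasks]:
#' ```r
#' task = tsk("iris")
#' task$set_row_roles(1:5, remove_from = "use")   # deactivate the first five rows
#' task$nrow                                      # 145
#' task$set_row_roles(1:5, add_to = "use")        # activate them again
#' task$nrow                                      # 150
#' ```
#'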
#' @return
#' Returns the object itself, but modified **by reference**.
#' You need to explicitly `$clone()` the object beforehand if you want to keep
#' the object in its previous state.
set_row_roles = function(rows, roles = NULL, add_to = NULL, remove_from = NULL) {
assert_has_backend(self)
assert_subset(rows, self$backend$rownames)
private$.row_hash = NULL
private$.hash = NULL
private$.row_roles = task_set_roles(private$.row_roles, rows, roles, add_to, remove_from, allow_duplicated = TRUE)
invisible(self)
},
#' @description
#' Modifies the roles in `$col_roles` **in-place**.
#' See `$col_roles` for a list of possible roles.
#'
#' @param cols (`character()`)\cr
#' Column names for which to change the roles.
#' @param roles (`character()`)\cr
#' Exclusively set columns to the specified `roles` (remove from other roles).
#' @param add_to (`character()`)\cr
#' Add columns with column names `cols` to roles specified in `add_to`.
#' Columns keep their previous roles.
#' @param remove_from (`character()`)\cr
#' Remove columns with column names `cols` from roles specified in `remove_from`.
#' Other column roles are preserved.
#'
#' @details
#' Roles are first set exclusively (argument `roles`), then added (argument `add_to`) and finally removed (argument `remove_from`) from different roles.
#' Duplicated columns are removed from the same role.
#' For tasks that only allow one target, the target column cannot be set with `$set_col_roles()`.
#' Use the `$col_roles` field to swap the target column.
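#'
#' A minimal sketch, assuming the predefined `"iris"` task from [mlr_tasks]:
#' ```r
#' task = tsk("iris")
#' # keep "Sepal.Width" in the backend, but drop it from the features
#' task$set_col_roles("Sepal.Width", remove_from = "feature")
#' task$feature_names   # "Petal.Length" "Petal.Width" "Sepal.Length"
#' # turn it back into a feature
#' task$set_col_roles("Sepal.Width", add_to = "feature")
#' ```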
#'
#' @return
#' Returns the object itself, but modified **by reference**.
#' You need to explicitly `$clone()` the object beforehand if you want to keep
#' the object in its previous state.
set_col_roles = function(cols, roles = NULL, add_to = NULL, remove_from = NULL) {
assert_has_backend(self)
assert_subset(cols, self$col_info$id)
private$.hash = NULL
private$.col_hashes = NULL
new_roles = task_set_roles(private$.col_roles, cols, roles, add_to, remove_from)
private$.col_roles = task_check_col_roles(self, new_roles)
invisible(self)
},
#' @description
#' Set levels for columns of type `factor` and `ordered` in field `col_info`.
#' You can add, remove or reorder the levels, affecting the data returned by
#' `$data()` and `$levels()`.
#' If you just want to remove unused levels, use `$droplevels()` instead.
#'
#' Note that factor levels which are present in the data but not listed in the task as
#' valid levels are converted to missing values.
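#'
#' A minimal sketch, assuming the predefined `"penguins"` task from [mlr_tasks] with its factor feature `"island"`:
#' ```r
#' task = tsk("penguins")
#' # reorder the levels; `$data()` repairs the factor levels on the fly
#' task$set_levels(list(island = c("Torgersen", "Dream", "Biscoe")))
#' task$levels("island")
#' levels(task$data(cols = "island")$island)
#' ```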
#'
#' @param levels (named `list()` of `character()`)\cr
#' List of character vectors of new levels, named by column names.
#'
#' @return Modified `self`.
set_levels = function(levels) {
assert_list(levels, types = "character", names = "unique", any.missing = FALSE)
assert_subset(names(levels), self$col_info$id)
tab = enframe(lapply(levels, unname), name = "id", value = "levels")
tab$fix_factor_levels = TRUE
private$.hash = NULL
self$col_info = ujoin(self$col_info, tab, key = "id")
invisible(self)
},
#' @description
#' Updates the cache of stored factor levels, removing all levels not present in the current set of active rows.
#' `cols` defaults to all columns with storage type "factor" or "ordered".
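#'
#' A minimal sketch on a small, hypothetical regression task built with [as_task_regr()]:
#' ```r
#' dat = data.frame(x = factor(c("a", "b", "c")), y = c(1, 2, 3))
#' task = as_task_regr(dat, target = "y")
#' task$filter(1:2)
#' task$levels("x")   # still "a", "b", "c"
#' task$droplevels()
#' task$levels("x")   # now only "a", "b"
#' ```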
#' @return Modified `self`.
droplevels = function(cols = NULL) {
assert_has_backend(self)
tab = self$col_info[get("type") %chin% c("factor", "ordered"), c("id", "levels", "fix_factor_levels"), with = FALSE]
if (!is.null(cols)) {
tab = tab[list(cols), on = "id", nomatch = NULL]
}
# update levels
# note that we assume that new_levels is a subset of levels!
new_levels = NULL
tab$new_levels = self$backend$distinct(rows = self$row_ids, cols = tab$id)
tab = tab[lengths(levels) > lengths(new_levels)]
tab[, c("levels", "fix_factor_levels") := list(Map(intersect, levels, new_levels), TRUE)]
private$.hash = NULL
self$col_info = ujoin(self$col_info, remove_named(tab, "new_levels"), key = "id")
invisible(self)
},
#' @description
#' Cuts numeric variables into new factor columns which are added to the task with role
#' `"stratum"`.
#' This ensures that all training and test splits contain observations from all bins.
#' The columns are named `"..stratum_[col_name]"`.
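#'
#' A minimal sketch, assuming the predefined `"mtcars"` regression task from [mlr_tasks]:
#' ```r
#' task = tsk("mtcars")
#' # bin the numeric target "mpg" into 3 intervals and use the bins as stratum
#' task$add_strata("mpg", bins = 3)
#' task$col_roles$stratum   # "..stratum_mpg"
#' ```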
#'
#' @param cols (`character()`)\cr
#' Names of columns to operate on.
#' @param bins (`integer()`)\cr
#' Number of bins to cut into (passed to [cut()] as `breaks`).
#' Replicated to have the same length as `cols`.
#' @return self (invisibly).
add_strata = function(cols, bins = 3L) {
assert_names(cols, "unique", subset.of = self$backend$colnames)
bins = assert_integerish(bins, any.missing = FALSE, coerce = TRUE)
col_types = fget(self$col_info, i = cols, j = "type", key = "id")
ii = wf(col_types %nin% c("integer", "numeric"))
if (length(ii)) {
stopf("For `add_strata`, all columns must be numeric, but the following are not: %s", str_collapse(cols[ii]))
}
strata = pmap_dtc(list(self$data(cols = cols), bins), cut, include.lowest = TRUE)
setnames(strata, sprintf("..stratum_%s", cols))
self$cbind(strata)
self$set_col_roles(names(strata), roles = "stratum")
}
),
active = list(
#' @template field_id
id = function(rhs) {
if (missing(rhs)) {
return(private$.id)
}
private$.hash = NULL
private$.id = assert_string(rhs, min.chars = 1L)
},
#' @field internal_valid_task (`Task` or `integer()` or `NULL`)\cr
#' Optional validation task that can, e.g., be used for early stopping with learners such as XGBoost.
#' See also the `$validate` field of [`Learner`].
#' If integer row ids are assigned, these rows are removed from the primary task's `"use"` rows, and an internal
#' validation task containing only these rows is created from a clone of the primary task.
#' When assigning a new task, it is always cloned.
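#'
#' A minimal sketch, assuming the predefined `"iris"` task (150 rows) from [mlr_tasks]:
#' ```r
#' task = tsk("iris")
#' task$internal_valid_task = 1:30   # rows 1 to 30 become the validation task
#' task$nrow                         # 120
#' task$internal_valid_task$nrow     # 30
#' ```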
internal_valid_task = function(rhs) {
if (missing(rhs)) {
return(invisible(private$.internal_valid_task))
}
private$.hash = NULL
if (is.null(rhs)) {
private$.internal_valid_task = NULL
return(invisible(private$.internal_valid_task))
}
if (test_integerish(rhs)) {
train_ids = setdiff(self$row_ids, rhs)
rhs = self$clone(deep = TRUE)$filter(rhs)
rhs$internal_valid_task = NULL
self$row_roles$use = train_ids
} else {
# validate the type first so that non-task inputs fail with an informative message
assert_task(rhs, task_type = self$task_type)
if (!is.null(rhs$internal_valid_task)) { # avoid recursive structures
stopf("Trying to assign task '%s' as a validation task; remove its own validation task first.", rhs$id)
}
rhs = rhs$clone(deep = TRUE)
}
ci1 = self$col_info
ci2 = rhs$col_info
# don't do this too strictly, some column roles might just be important during training (weights)
cols = unlist(self$col_roles[c("target", "feature")], use.names = FALSE)
walk(cols, function(.col) {
if (.col %nin% ci2$id) {
stopf("Primary task has column '%s' which is not present in the validation task.", .col)
}
if (ci1[get("id") == .col, "type"]$type != ci2[get("id") == .col, "type"]$type) {
stopf("The type of column '%s' from the validation task differs from the type in the primary task.", .col)
}
})
private$.internal_valid_task = rhs
if (private$.internal_valid_task$nrow == 0) {
warningf("Internal validation task has 0 observations.")
}
invisible(private$.internal_valid_task)
},
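# Usage sketch for `$internal_valid_task` (illustrative only; assumes mlr3 is attached
# and uses the example task "penguins"):
#   task = tsk("penguins")
#   # move rows 1-50 into an internal validation task; they are dropped from task$row_ids
#   task$internal_valid_task = 1:50
#   task$internal_valid_task$nrow  # 50 rows
#   # removing the validation task does NOT add rows 1-50 back to the primary task
#   task$internal_valid_task = NULL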
#' @field hash (`character(1)`)\cr
#' Hash (unique identifier) for this object.
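#' The hash is calculated from the complete task object, including the internal validation task (if set), and `$row_ids`.
#' It is cached and recalculated after modifications such as changing the row roles or assigning an internal validation task.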
hash = function(rhs) {
assert_ro_binding(rhs)
if (is.null(private$.hash)) {
private$.hash = task_hash(self, self$row_ids, ignore_internal_valid_task = FALSE)
}
private$.hash
},
#' @field row_hash (`character(1)`)\cr
#' Hash (unique identifier) calculated based on the row ids.
row_hash = function(rhs) {
assert_ro_binding(rhs)
if (is.null(private$.row_hash)) {
private$.row_hash = calculate_hash(self$row_ids)
}
private$.row_hash
},
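# Usage sketch (illustrative; assumes the example task "penguins"): both hashes are
# cached, and assigning new row roles invalidates the cache, so the recomputed
# values reflect the new row ids.
#   task = tsk("penguins")
#   h = task$hash
#   rh = task$row_hash
#   task$row_roles$use = 1:100
#   identical(h, task$hash)       # expected FALSE
#   identical(rh, task$row_hash)  # expected FALSE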
#' @field row_ids (positive `integer()`)\cr
#' Returns the row ids of the [DataBackend] for observations with role "use".
row_ids = function(rhs) {
assert_ro_binding(rhs)
private$.row_roles$use
},
#' @field row_names ([data.table::data.table()])\cr
#' Returns a table with two columns (or `NULL` if no column has role `"name"`):
#'
#' * `"row_id"` (`integer()`), and
#' * `"row_name"` (`character()`).
row_names = function(rhs) {
assert_ro_binding(rhs)
nn = private$.col_roles$name
if (length(nn) == 0L) {
return(NULL)
}
setnames(self$backend$data(rows = self$row_ids, cols = c(self$backend$primary_key, nn)),
c("row_id", "row_name"))
},
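# Usage sketch (illustrative; the data and the column name "label" are made up):
# assigning the "name" role to a character column makes `$row_names` non-NULL.
#   dt = data.table::data.table(label = c("a", "b", "c"), x = 1:3, y = c(0.5, 1.5, 2.5))
#   task = as_task_regr(dt, target = "y")
#   task$row_names  # NULL, no column has role "name" yet
#   task$set_col_roles("label", roles = "name")
#   task$row_names  # data.table with columns "row_id" and "row_name"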
#' @field feature_names (`character()`)\cr
#' Returns all column names with `role == "feature"`.
#'
#' Note that this vector determines the default order of columns for `task$data(cols = NULL, ...)`.
#' However, it is recommended to **not** rely on the order of columns, but instead always
#' address columns by their name. The default order is not well defined after some
#' operations, e.g. after `task$cbind()` or after processing via \CRANpkg{mlr3pipelines}.
feature_names = function(rhs) {
assert_ro_binding(rhs)
private$.col_roles$feature
},
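# Usage sketch (illustrative; assumes the example task "penguins"): address columns
# by name rather than by their position in `$feature_names`.
#   task = tsk("penguins")
#   task$data(rows = 1:3, cols = c("bill_length", "island"))
#   setdiff(task$feature_names, "year")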
#' @field target_names (`character()`)\cr
#' Returns all column names with role "target".
target_names = function(rhs) {
assert_ro_binding(rhs)
private$.col_roles$target
},
#' @field properties (`character()`)\cr
#' Set of task properties.
#' Possible properties are stored in [mlr_reflections$task_properties][mlr_reflections].
#' The following properties are currently standardized and understood by tasks in \CRANpkg{mlr3}:
#'
#' * `"strata"`: The task is resampled using one or more stratification variables (role `"stratum"`).
#' * `"groups"`: The task comes with grouping/blocking information (role `"group"`).
#' * `"weights"`: The task comes with observation weights (role `"weight"`).
#' * `"offset"`: The task includes one or more offset columns specifying fixed adjustments for model training and possibly for prediction (role `"offset"`).
#' * `"ordered"`: The task has columns which define the row order (role `"order"`).
#'
#' Note that the properties listed above are derived from `$col_roles` and should not be set explicitly.
properties = function(rhs) {
if (missing(rhs)) {
col_roles = private$.col_roles
c(character(),
private$.properties,
if (length(col_roles$group)) "groups" else NULL,
if (length(col_roles$stratum)) "strata" else NULL,
if (length(col_roles$weight)) "weights" else NULL,
if (length(col_roles$offset)) "offset" else NULL,
if (length(col_roles$order)) "ordered" else NULL
)
} else {
private$.properties = assert_set(rhs, .var.name = "properties")
}
},
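# Usage sketch (illustrative; assumes the example task "penguins"): "strata" appears
# in `$properties` once a column has role "stratum"; it is not set directly.
#   task = tsk("penguins")
#   task$col_roles$stratum = "island"
#   task$properties  # now includes "strata" alongside the type-specific properties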
#' @field row_roles (named `list()`)\cr
#' Each row (observation) can have an arbitrary number of roles in the learning task:
#'
#' - `"use"`: Use in train / predict / resampling.
#'
#' `row_roles` is a named list whose elements are named by row role and each element is an `integer()` vector of row ids.
#' To alter the roles, just modify the list, e.g. with \R's set functions ([intersect()], [setdiff()], [union()], \ldots).
row_roles = function(rhs) {
if (missing(rhs)) {
return(private$.row_roles)
}
assert_has_backend(self)
assert_list(rhs, .var.name = "row_roles")
if ("test" %chin% names(rhs) || "holdout" %chin% names(rhs)) {
stopf("Setting row roles 'test'/'holdout' is no longer possible.")
}
assert_names(names(rhs), "unique", permutation.of = mlr_reflections$task_row_roles, .var.name = "names of row_roles")
rhs = map(rhs, assert_row_ids, .var.name = "elements of row_roles")
private$.row_hash = NULL
private$.hash = NULL
private$.row_roles = rhs
},
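# Usage sketch (illustrative; assumes the example task "penguins"): modify the row
# roles with a set operation, here dropping rows 1-10 from role "use".
#   task = tsk("penguins")
#   n = task$nrow
#   task$row_roles$use = setdiff(task$row_roles$use, 1:10)
#   task$nrow == n - 10  # expected TRUE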
#' @field col_roles (named `list()`)\cr
#' Each column can be in one or more of the following groups to fulfill different roles:
#'
#' * `"feature"`: Regular feature used in the model fitting process.
#' * `"target"`: Target variable. Most tasks only accept a single target column.
#' * `"name"`: Row names / observation labels. To be used in plots. Can be queried with `$row_names`.
#' At most one column can be associated with this role.
#' * `"order"`: Data returned by `$data()` is ordered by this column (or these columns).
#' Columns must be sortable with [order()].
#' * `"group"`: During resampling, observations sharing the same value in the column with role `"group"` are treated as belonging together.
#' In each resampling iteration, all observations of a group are assigned either to the training set or to the test set, never split across both.
#' At most one column can be associated with this role.
#' * `"stratum"`: Stratification variables. Multiple discrete columns may have this role.
#' * `"weight"`: Observation weights. At most one numeric column may have this role.
#' * `"offset"`: Numeric columns used to specify fixed adjustments for model training.
#' Some models use offsets to simply shift predictions, while others incorporate them to boost predictions from a baseline model.
#' For learners supporting offsets in multiclass settings, an offset column must be provided for each target class.
#' These columns must follow the naming convention `"offset_{target_class_name}"`.
#' For an example of a learner that supports offsets, see `LearnerClassifXgboost` of \CRANpkg{mlr3learners}.
#'
#' `col_roles` is a named list whose elements are named by column role and each element is a `character()` vector of column names.
#' To alter the roles, just modify the list, e.g. with \R's set functions ([intersect()], [setdiff()], [union()], \ldots).
#' The method `$set_col_roles` provides a convenient alternative to assign columns to roles.
col_roles = function(rhs) {
if (missing(rhs)) {
return(private$.col_roles)
}
assert_has_backend(self)
qassertr(rhs, "S[1,]", .var.name = "col_roles")
assert_names(names(rhs), "unique", permutation.of = mlr_reflections$task_col_roles[[self$task_type]], .var.name = "names of col_roles")
assert_subset(unlist(rhs, use.names = FALSE), setdiff(self$col_info$id, self$backend$primary_key), .var.name = "elements of col_roles")
private$.hash = NULL
private$.col_hashes = NULL
private$.col_roles = task_check_col_roles(self, rhs)
},
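# Usage sketch (illustrative; assumes the example task "penguins"): move column
# "year" from the features to the grouping role by modifying the list directly.
#   task = tsk("penguins")
#   task$col_roles$feature = setdiff(task$col_roles$feature, "year")
#   task$col_roles$group = "year"
#   task$properties  # now includes "groups"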
#' @field nrow (`integer(1)`)\cr
#' Returns the total number of rows with role "use".
nrow = function(rhs) {
assert_ro_binding(rhs)
length(private$.row_roles$use)
},
#' @field ncol (`integer(1)`)\cr
#' Returns the total number of columns with role "target" or "feature".
ncol = function(rhs) {
assert_ro_binding(rhs)