Skip to content

Commit 0e7d90f

Browse files
dshemetov and dajmcdon committed
fix: update for compatibility with epiprocess==0.9.0
Co-authored-by: Daniel McDonald <[email protected]>
1 parent cd12775 commit 0e7d90f

27 files changed

+198
-191
lines changed

.Rbuildignore

+2-1
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,5 @@
1919
^DEVELOPMENT\.md$
2020
^doc$
2121
^Meta$
22-
^.lintr$
22+
^.lintr$
23+
^.venv$

DESCRIPTION

+2-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Package: epipredict
22
Title: Basic epidemiology forecasting methods
3-
Version: 0.0.20
3+
Version: 0.0.21
44
Authors@R: c(
55
person("Daniel", "McDonald", , "[email protected]", role = c("aut", "cre")),
66
person("Ryan", "Tibshirani", , "[email protected]", role = "aut"),
@@ -23,8 +23,7 @@ URL: https://github.com/cmu-delphi/epipredict/,
2323
https://cmu-delphi.github.io/epipredict
2424
BugReports: https://github.com/cmu-delphi/epipredict/issues/
2525
Depends:
26-
epiprocess (>= 0.8.0),
27-
epiprocess (< 0.9.0),
26+
epiprocess (>= 0.9.0),
2827
parsnip (>= 1.0.0),
2928
R (>= 3.5.0)
3029
Imports:

R/autoplot.R

+1-1
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ autoplot.epi_workflow <- function(
131131
if (length(extra_keys) == 0L) extra_keys <- NULL
132132
edf <- as_epi_df(edf,
133133
as_of = object$fit$meta$as_of,
134-
additional_metadata = list(other_keys = extra_keys)
134+
other_keys = extra_keys %||% character()
135135
)
136136
if (is.null(predictions)) {
137137
return(autoplot(

R/cdc_baseline_forecaster.R

+3-3
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,11 @@
2929
#' mutate(deaths = pmax(death_rate / 1e5 * pop * 7, 0)) %>%
3030
#' select(-pop, -death_rate) %>%
3131
#' group_by(geo_value) %>%
32-
#' epi_slide(~ sum(.$deaths), before = 6, new_col_name = "deaths") %>%
32+
#' epi_slide(~ sum(.$deaths), .window_size = 7, .new_col_name = "deaths_7dsum") %>%
3333
#' ungroup() %>%
3434
#' filter(weekdays(time_value) == "Saturday")
3535
#'
36-
#' cdc <- cdc_baseline_forecaster(weekly_deaths, "deaths")
36+
#' cdc <- cdc_baseline_forecaster(weekly_deaths, "deaths_7dsum")
3737
#' preds <- pivot_quantiles_wider(cdc$predictions, .pred_distn)
3838
#'
3939
#' if (require(ggplot2)) {
@@ -47,7 +47,7 @@
4747
#' geom_line(aes(y = .pred), color = "orange") +
4848
#' geom_line(
4949
#' data = weekly_deaths %>% filter(geo_value %in% four_states),
50-
#' aes(x = time_value, y = deaths)
50+
#' aes(x = time_value, y = deaths_7dsum)
5151
#' ) +
5252
#' scale_x_date(limits = c(forecast_date - 90, forecast_date + 30)) +
5353
#' labs(x = "Date", y = "Weekly deaths") +

R/epi_recipe.R

+7-5
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ epi_recipe.epi_df <-
9595
keys <- key_colnames(x) # we know x is an epi_df
9696

9797
var_info <- tibble(variable = vars)
98-
key_roles <- c("geo_value", "time_value", rep("key", length(keys) - 2))
98+
key_roles <- c("geo_value", rep("key", length(keys) - 2), "time_value")
9999

100100
## Check and add roles when available
101101
if (!is.null(roles)) {
@@ -499,8 +499,11 @@ prep.epi_recipe <- function(
499499
if (!is_epi_df(training)) {
500500
# tidymodels killed our class
501501
# for now, we only allow step_epi_* to alter the metadata
502-
training <- dplyr::dplyr_reconstruct(
503-
as_epi_df(training), before_template
502+
metadata <- attr(before_template, "metadata")
503+
training <- as_epi_df(
504+
training,
505+
as_of = metadata$as_of,
506+
other_keys = metadata$other_keys %||% character()
504507
)
505508
}
506509
training <- dplyr::relocate(training, all_of(key_colnames(training)))
@@ -579,8 +582,7 @@ bake.epi_recipe <- function(object, new_data, ..., composition = "epi_df") {
579582
new_data <- as_epi_df(
580583
new_data,
581584
as_of = meta$as_of,
582-
# avoid NULL if meta is from saved older epi_df:
583-
additional_metadata = meta$additional_metadata %||% list()
585+
other_keys = meta$other_keys %||% character()
584586
)
585587
}
586588
new_data

R/epi_workflow.R

+2-1
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,8 @@ is_epi_workflow <- function(x) {
9898
fit.epi_workflow <- function(object, data, ..., control = workflows::control_workflow()) {
9999
object$fit$meta <- list(
100100
max_time_value = max(data$time_value),
101-
as_of = attributes(data)$metadata$as_of
101+
as_of = attr(data, "metadata")$as_of,
102+
other_keys = attr(data, "metadata")$other_keys
102103
)
103104
object$original_data <- data
104105

R/flusight_hub_formatter.R

+2-2
Original file line numberDiff line numberDiff line change
@@ -67,11 +67,11 @@ abbr_to_location <- function(abbr) {
6767
#' mutate(deaths = pmax(death_rate / 1e5 * pop * 7, 0)) %>%
6868
#' select(-pop, -death_rate) %>%
6969
#' group_by(geo_value) %>%
70-
#' epi_slide(~ sum(.$deaths), before = 6, new_col_name = "deaths") %>%
70+
#' epi_slide(~ sum(.$deaths), .window_size = 7, .new_col_name = "deaths_7dsum") %>%
7171
#' ungroup() %>%
7272
#' filter(weekdays(time_value) == "Saturday")
7373
#'
74-
#' cdc <- cdc_baseline_forecaster(weekly_deaths, "deaths")
74+
#' cdc <- cdc_baseline_forecaster(weekly_deaths, "deaths_7dsum")
7575
#' flusight_hub_formatter(cdc)
7676
#' flusight_hub_formatter(cdc, target = "wk inc covid deaths")
7777
#' flusight_hub_formatter(cdc, target = paste(horizon, "wk inc covid deaths"))

R/key_colnames.R

+8-7
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,20 @@
11
#' @export
22
key_colnames.recipe <- function(x, ...) {
3-
possible_keys <- c("geo_value", "time_value", "key")
4-
keys <- x$var_info$variable[x$var_info$role %in% possible_keys]
5-
keys[order(match(keys, possible_keys))] %||% character(0L)
3+
geo_key <- x$var_info$variable[x$var_info$role %in% "geo_value"]
4+
time_key <- x$var_info$variable[x$var_info$role %in% "time_value"]
5+
keys <- x$var_info$variable[x$var_info$role %in% "key"]
6+
c(geo_key, keys, time_key) %||% character(0L)
67
}
78

89
#' @export
910
key_colnames.epi_workflow <- function(x, ...) {
1011
# safer to look at the mold than the preprocessor
1112
mold <- hardhat::extract_mold(x)
12-
possible_keys <- c("geo_value", "time_value", "key")
1313
molded_names <- names(mold$extras$roles)
14-
keys <- map(mold$extras$roles[molded_names %in% possible_keys], names)
15-
keys <- unname(unlist(keys))
16-
keys[order(match(keys, possible_keys))] %||% character(0L)
14+
geo_key <- names(mold$extras$roles[molded_names %in% "geo_value"]$geo_value)
15+
time_key <- names(mold$extras$roles[molded_names %in% "time_value"]$time_value)
16+
keys <- names(mold$extras$roles[molded_names %in% "key"]$key)
17+
c(geo_key, keys, time_key) %||% character(0L)
1718
}
1819

1920
kill_time_value <- function(v) {

R/step_epi_slide.R

+64-57
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,18 @@
1919
#' argument must be named `.x`. A common, though very difficult to debug
2020
#' error is using something like `function(x) mean`. This will not work
2121
#' because it returns the function mean, rather than `mean(x)`
22-
#' @param before,after the size of the sliding window on the left and the right
23-
#' of the center. Usually non-negative integers for data indexed by date, but
24-
#' more restrictive in other cases (see [epiprocess::epi_slide()] for details).
25-
#' @param f_name a character string of at most 20 characters that describes
26-
#' the function. This will be combined with `prefix` and the columns in `...`
27-
#' to name the result using `{prefix}{f_name}_{column}`. By default it will be determined
28-
#' automatically using `clean_f_name()`.
22+
#' @param .window_size the size of the sliding window, required. Usually a
23+
#' non-negative integer will suffice (e.g. for data indexed by date, but more
24+
#' restrictive in other time_type cases (see [epiprocess::epi_slide()] for
25+
#' details). For example, set to 7 for a 7-day window.
26+
#' @param .align a character string indicating how the window should be aligned.
27+
#' By default, this is "right", meaning the slide_window will be anchored with
28+
#' its right end point on the reference date. (see [epiprocess::epi_slide()]
29+
#' for details).
30+
#' @param f_name a character string of at most 20 characters that describes the
31+
#' function. This will be combined with `prefix` and the columns in `...` to
32+
#' name the result using `{prefix}{f_name}_{column}`. By default it will be
33+
#' determined automatically using `clean_f_name()`.
2934
#'
3035
#' @template step-return
3136
#'
@@ -37,53 +42,55 @@
3742
#' rec <- epi_recipe(jhu) %>%
3843
#' step_epi_slide(case_rate, death_rate,
3944
#' .f = \(x) mean(x, na.rm = TRUE),
40-
#' before = 6L
45+
#' .window_size = 7L
4146
#' )
4247
#' bake(prep(rec, jhu), new_data = NULL)
43-
step_epi_slide <-
44-
function(recipe,
45-
...,
46-
.f,
47-
before = 0L,
48-
after = 0L,
49-
role = "predictor",
50-
prefix = "epi_slide_",
51-
f_name = clean_f_name(.f),
52-
skip = FALSE,
53-
id = rand_id("epi_slide")) {
54-
if (!is_epi_recipe(recipe)) {
55-
cli_abort("This recipe step can only operate on an {.cls epi_recipe}.")
56-
}
57-
.f <- validate_slide_fun(.f)
58-
epiprocess:::validate_slide_window_arg(before, attributes(recipe$template)$metadata$time_type)
59-
epiprocess:::validate_slide_window_arg(after, attributes(recipe$template)$metadata$time_type)
60-
arg_is_chr_scalar(role, prefix, id)
61-
arg_is_lgl_scalar(skip)
48+
step_epi_slide <- function(recipe,
49+
...,
50+
.f,
51+
.window_size = NULL,
52+
.align = c("right", "center", "left"),
53+
role = "predictor",
54+
prefix = "epi_slide_",
55+
f_name = clean_f_name(.f),
56+
skip = FALSE,
57+
id = rand_id("epi_slide")) {
58+
if (!is_epi_recipe(recipe)) {
59+
cli_abort("This recipe step can only operate on an {.cls epi_recipe}.")
60+
}
61+
.f <- validate_slide_fun(.f)
62+
if (is.null(.window_size)) {
63+
cli_abort("step_epi_slide: `.window_size` must be specified.")
64+
}
65+
epiprocess:::validate_slide_window_arg(.window_size, attributes(recipe$template)$metadata$time_type)
66+
.align <- rlang::arg_match(.align)
67+
arg_is_chr_scalar(role, prefix, id)
68+
arg_is_lgl_scalar(skip)
6269

63-
recipes::add_step(
64-
recipe,
65-
step_epi_slide_new(
66-
terms = enquos(...),
67-
before = before,
68-
after = after,
69-
.f = .f,
70-
f_name = f_name,
71-
role = role,
72-
trained = FALSE,
73-
prefix = prefix,
74-
keys = key_colnames(recipe),
75-
columns = NULL,
76-
skip = skip,
77-
id = id
78-
)
70+
recipes::add_step(
71+
recipe,
72+
step_epi_slide_new(
73+
terms = enquos(...),
74+
.window_size = .window_size,
75+
.align = .align,
76+
.f = .f,
77+
f_name = f_name,
78+
role = role,
79+
trained = FALSE,
80+
prefix = prefix,
81+
keys = key_colnames(recipe),
82+
columns = NULL,
83+
skip = skip,
84+
id = id
7985
)
80-
}
86+
)
87+
}
8188

8289

8390
step_epi_slide_new <-
8491
function(terms,
85-
before,
86-
after,
92+
.window_size,
93+
.align,
8794
.f,
8895
f_name,
8996
role,
@@ -96,8 +103,8 @@ step_epi_slide_new <-
96103
recipes::step(
97104
subclass = "epi_slide",
98105
terms = terms,
99-
before = before,
100-
after = after,
106+
.window_size = .window_size,
107+
.align = .align,
101108
.f = .f,
102109
f_name = f_name,
103110
role = role,
@@ -119,8 +126,8 @@ prep.step_epi_slide <- function(x, training, info = NULL, ...) {
119126

120127
step_epi_slide_new(
121128
terms = x$terms,
122-
before = x$before,
123-
after = x$after,
129+
.window_size = x$.window_size,
130+
.align = x$.align,
124131
.f = x$.f,
125132
f_name = x$f_name,
126133
role = x$role,
@@ -165,8 +172,8 @@ bake.step_epi_slide <- function(object, new_data, ...) {
165172
# }
166173
epi_slide_wrapper(
167174
new_data,
168-
object$before,
169-
object$after,
175+
object$.window_size,
176+
object$.align,
170177
object$columns,
171178
c(object$.f),
172179
object$f_name,
@@ -190,7 +197,7 @@ bake.step_epi_slide <- function(object, new_data, ...) {
190197
#' @importFrom dplyr bind_cols group_by ungroup
191198
#' @importFrom epiprocess epi_slide
192199
#' @keywords internal
193-
epi_slide_wrapper <- function(new_data, before, after, columns, fns, fn_names, group_keys, name_prefix) {
200+
epi_slide_wrapper <- function(new_data, .window_size, .align, columns, fns, fn_names, group_keys, name_prefix) {
194201
cols_fns <- tidyr::crossing(col_name = columns, fn_name = fn_names, fn = fns)
195202
# Iterate over the rows of cols_fns. For each row number, we will output a
196203
# transformed column. The first result returns all the original columns along
@@ -204,10 +211,10 @@ epi_slide_wrapper <- function(new_data, before, after, columns, fns, fn_names, g
204211
result <- new_data %>%
205212
group_by(across(all_of(group_keys))) %>%
206213
epi_slide(
207-
before = before,
208-
after = after,
209-
new_col_name = result_name,
210-
f = function(slice, geo_key, ref_time_value) {
214+
.window_size = .window_size,
215+
.align = .align,
216+
.new_col_name = result_name,
217+
.f = function(slice, geo_key, ref_time_value) {
211218
fn(slice[[col_name]])
212219
}
213220
) %>%

R/utils-misc.R

+7-9
Original file line numberDiff line numberDiff line change
@@ -33,28 +33,26 @@ check_pname <- function(res, preds, object, newname = NULL) {
3333

3434

3535
grab_forged_keys <- function(forged, workflow, new_data) {
36-
keys <- c("geo_value", "time_value", "key")
3736
forged_roles <- names(forged$extras$roles)
38-
extras <- dplyr::bind_cols(forged$extras$roles[forged_roles %in% keys])
37+
extras <- dplyr::bind_cols(forged$extras$roles[forged_roles %in% c("geo_value", "time_value", "key")])
3938
# 1. these are the keys in the test data after prep/bake
4039
new_keys <- names(extras)
4140
# 2. these are the keys in the training data
4241
old_keys <- key_colnames(workflow)
4342
# 3. these are the keys in the test data as input
44-
new_df_keys <- key_colnames(new_data, extra_keys = setdiff(new_keys, keys[1:2]))
43+
new_df_keys <- key_colnames(new_data, extra_keys = setdiff(new_keys, c("geo_value", "time_value")))
4544
if (!(setequal(old_keys, new_df_keys) && setequal(new_keys, new_df_keys))) {
4645
cli::cli_warn(c(
4746
"Not all epi keys that were present in the training data are available",
4847
"in `new_data`. Predictions will have only the available keys."
4948
))
5049
}
5150
if (is_epi_df(new_data)) {
52-
extras <- as_epi_df(extras)
53-
attr(extras, "metadata") <- attr(new_data, "metadata")
54-
} else if (all(keys[1:2] %in% new_keys)) {
55-
l <- list()
56-
if (length(new_keys) > 2) l <- list(other_keys = new_keys[-c(1:2)])
57-
extras <- as_epi_df(extras, additional_metadata = l)
51+
meta <- attr(new_data, "metadata")
52+
extras <- as_epi_df(extras, as_of = meta$as_of, other_keys = meta$other_keys %||% character())
53+
} else if (all(c("geo_value", "time_value") %in% new_keys)) {
54+
if (length(new_keys) > 2) other_keys <- new_keys[!new_keys %in% c("geo_value", "time_value")]
55+
extras <- as_epi_df(extras, other_keys = other_keys %||% character())
5856
}
5957
extras
6058
}

data-raw/grad_employ_subset.R

+1-1
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,6 @@ ncol(gemploy)
101101
grad_employ_subset <- gemploy %>%
102102
as_epi_df(
103103
as_of = "2022-07-19",
104-
additional_metadata = list(other_keys = c("age_group", "edu_qual"))
104+
other_keys = c("age_group", "edu_qual")
105105
)
106106
usethis::use_data(grad_employ_subset, overwrite = TRUE)

data/grad_employ_subset.rda

1.03 KB
Binary file not shown.

man/autoplot-epipred.Rd

+2
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/cdc_baseline_forecaster.Rd

+3-3
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)