|
1 |
| -library(tidyverse) |
2 |
| -library(httr) |
3 |
| -library(lubridate) |
4 |
| -library(progress) |
5 |
| -library(targets) |
6 |
| -source(here::here("R", "load_all.R")) |
7 |
| - |
8 |
| -options(readr.show_progress = FALSE) |
9 |
| -options(readr.show_col_types = FALSE) |
10 |
| - |
11 |
# Geographies excluded for insufficient data (filtered out of the NHSN target
# below): American Samoa, Northern Mariana Islands, Virgin Islands, Guam.
insufficient_data_geos <- c("as", "mp", "vi", "gu")
12 |
| - |
13 |
# Configuration shared by the download/read/score helpers below.
config <- list(
  # Raw-content root of the FluSight forecast hub's model-output tree on GitHub.
  base_url = "https://raw.githubusercontent.com/cdcepi/FluSight-forecast-hub/main/model-output",
  # Forecaster (model) directories to mirror locally.
  forecasters = c("CMU-TimeSeries", "FluSight-baseline", "FluSight-ensemble", "FluSight-base_seasonal", "UMass-flusion"),
  # Local root for downloaded forecast csvs, one subdirectory per forecaster.
  local_storage = "data/forecasts",
  # csv recording which files have been downloaded and with what status.
  tracking_file = "data/download_tracking.csv"
)
20 |
| - |
21 |
# Ensure the local storage root and one subdirectory per forecaster exist.
#
# @param base_dir Root directory under which per-forecaster folders live.
# @param forecasters Character vector of forecaster names; defaults to the
#   globally configured set so existing callers are unaffected.
# @return Invisibly, the character vector of directories created/verified.
setup_directories <- function(base_dir, forecasters = config$forecasters) {
  dirs <- c(base_dir, file.path(base_dir, forecasters))
  for (d in dirs) {
    # showWarnings = FALSE: creating an already-existing directory is fine.
    dir.create(d, recursive = TRUE, showWarnings = FALSE)
  }
  invisible(dirs)
}
28 |
| - |
29 |
# Load the persistent download-tracking table from disk.
#
# If the tracking csv does not exist yet (first run), return a zero-row
# tibble with the expected columns so downstream code can bind new records
# without special-casing.
load_tracking_data <- function() {
  if (!file.exists(config$tracking_file)) {
    return(tibble(
      forecaster = character(),
      filename = character(),
      download_date = character(),
      status = character()
    ))
  }
  read_csv(config$tracking_file)
}
42 |
| - |
43 |
# Build the expected weekly forecast filenames for a forecaster over a range.
#
# Filenames follow the hub convention "<YYYY-MM-DD>-<forecaster>.csv", one per
# week starting at start_date (inclusive of end_date when it lands on a week).
#
# @param start_date,end_date Date objects or "YYYY-MM-DD" strings.
# @param forecaster Forecaster (model) name embedded in the filename.
# @return Character vector of candidate filenames.
generate_filenames <- function(start_date, end_date, forecaster) {
  # base::as.Date suffices here; no need for the lubridate as_date() dependency.
  dates <- seq(as.Date(start_date), as.Date(end_date), by = "week")
  paste0(format(dates, "%Y-%m-%d"), "-", forecaster, ".csv")
}
54 |
| - |
55 |
# Check whether a forecast file exists in the GitHub model-output tree.
#
# Uses an HTTP HEAD request: only the status code is needed, so avoid
# transferring the file body that GET would download.
#
# @param forecaster Forecaster directory name on the hub.
# @param filename Forecast csv filename.
# @return TRUE when the file exists (HTTP 200), FALSE otherwise.
check_github_file <- function(forecaster, filename) {
  url <- paste0(config$base_url, "/", forecaster, "/", filename)
  response <- HEAD(url)
  status_code(response) == 200
}
61 |
| - |
62 |
# Download one forecast csv from the hub into local storage.
#
# @param forecaster Forecaster directory name (used for both URL and local path).
# @param filename Forecast csv filename.
# @return "success" when the download completed, "failed" on any error.
download_forecast_file <- function(forecaster, filename) {
  remote <- paste0(config$base_url, "/", forecaster, "/", filename)
  destination <- file.path(config$local_storage, forecaster, filename)

  tryCatch(
    {
      # mode = "wb" so the csv bytes are written verbatim on all platforms.
      download.file(remote, destination, mode = "wb", quiet = TRUE)
      "success"
    },
    error = function(e) "failed"
  )
}
77 |
| - |
78 |
# Main function to update forecast files.
#
# Mirrors any not-yet-downloaded weekly forecast csvs from the hub into local
# storage, appending a record per attempted download to the tracking csv.
#
# @param days_back How many days before today to start looking for files.
# @return The updated tracking tibble (also persisted to config$tracking_file
#   when anything new was downloaded).
update_forecast_files <- function(days_back = 30) {
  # Setup
  setup_directories(config$local_storage)
  tracking_data <- load_tracking_data()

  # Generate date range
  end_date <- Sys.Date()
  # get_forecast_reference_date() is a project helper (sourced via load_all.R);
  # presumably it snaps the start to the hub's weekly reference day — confirm.
  start_date <- get_forecast_reference_date(end_date - days_back)

  # Process each forecaster, collecting one tracking record per download attempt.
  new_tracking_records <- list()

  pb_forecasters <- progress_bar$new(
    format = "Downloading forecasts from :forecaster [:bar] :percent :eta",
    total = length(config$forecasters),
    clear = FALSE,
    width = 60
  )

  for (forecaster in config$forecasters) {
    pb_forecasters$tick(tokens = list(forecaster = forecaster))

    # Get potential filenames for every week in the window.
    filenames <- generate_filenames(start_date, end_date, forecaster)

    # Filter out already downloaded files. `!!forecaster` forces the loop
    # variable, disambiguating it from the tibble's `forecaster` column.
    existing_files <- tracking_data %>%
      filter(forecaster == !!forecaster, status == "success") %>%
      pull(filename)

    new_files <- setdiff(filenames, existing_files)

    if (length(new_files) > 0) {
      # Create nested progress bar for files
      pb_files <- progress_bar$new(
        format = "  Downloading files [:bar] :current/:total :filename",
        total = length(new_files)
      )

      for (filename in new_files) {
        pb_files$tick(tokens = list(filename = filename))

        # Only attempt (and record) downloads for files that exist upstream;
        # candidate dates with no published forecast are skipped silently.
        if (check_github_file(forecaster, filename)) {
          status <- download_forecast_file(forecaster, filename)

          new_tracking_records[[length(new_tracking_records) + 1]] <- tibble(
            forecaster = forecaster,
            filename = filename,
            download_date = as.character(Sys.time()),
            status = status
          )
        }
      }
    }
  }

  # Update tracking data on disk only when something changed.
  if (length(new_tracking_records) > 0) {
    new_tracking_data <- bind_rows(new_tracking_records)
    tracking_data <- bind_rows(tracking_data, new_tracking_data)
    write_csv(tracking_data, config$tracking_file)
  }

  return(tracking_data)
}
144 |
| - |
145 |
# Read every successfully downloaded forecast csv and combine into one tibble.
#
# Output is restricted to quantile forecasts of the "wk inc flu hosp" target,
# with hub location codes mapped to state ids and renamed to geo_value.
#
# @return A tibble of quantile forecasts across all forecasters and dates.
read_all_forecasts <- function() {
  tracking_data <- read_csv(config$tracking_file)

  successful_downloads <- tracking_data %>%
    filter(status == "success")

  # seq_len() rather than 1:nrow(): avoids iterating over c(1, 0) when the
  # tracking table has no successful downloads yet.
  forecast_data <- map(seq_len(nrow(successful_downloads)), function(i) {
    row <- successful_downloads[i, ]
    path <- file.path(config$local_storage, row$forecaster, row$filename)
    if (!file.exists(path)) {
      # Tracked but missing on disk: return NULL, which bind_rows() drops.
      return(NULL)
    }
    read_csv(path, col_types = list(
      reference_date = col_date(format = "%Y-%m-%d"),
      target_end_date = col_date(format = "%Y-%m-%d"),
      target = col_character(),
      location = col_character(),
      horizon = col_integer(),
      output_type = col_character(),
      output_type_id = col_character(),
      value = col_double(),
      forecaster = col_character(),
      forecast_date = col_date(format = "%Y-%m-%d")
    )) %>%
      mutate(
        forecaster = row$forecaster,
        # The submission date is encoded at the front of the filename.
        forecast_date = as.Date(str_extract(row$filename, "\\d{4}-\\d{2}-\\d{2}"))
      )
  })

  bind_rows(forecast_data) %>%
    add_state_info(geo_value_col = "location", old_geo_code = "state_code", new_geo_code = "state_id") %>%
    rename(geo_value = state_id) %>%
    select(-location) %>%
    filter(
      target == "wk inc flu hosp",
      output_type == "quantile"
    )
}
184 |
| - |
185 |
# Score forecasts against observed NHSN hospitalization values.
#
# Reshapes the forecasts into prediction cards and the latest NHSN data into
# a truth table, then delegates scoring to evaluate_predictions().
#
# @param all_forecasts Quantile forecasts as produced by read_all_forecasts().
# @param nhsn_latest_data Observed values with time_value, geo_value, value.
# @return Score table with a `forecaster` column identifying each model.
score_forecasts <- function(all_forecasts, nhsn_latest_data) {
  predictions_cards <- all_forecasts %>%
    rename(model = forecaster) %>%
    mutate(
      quantile = as.numeric(output_type_id),
      prediction = value
    ) %>%
    select(model, geo_value, forecast_date, target_end_date, quantile, prediction)

  truth_data <- nhsn_latest_data %>%
    mutate(
      target_end_date = as.Date(time_value),
      true_value = value
    ) %>%
    select(geo_value, target_end_date, true_value)

  scores <- evaluate_predictions(
    predictions_cards = predictions_cards,
    truth_data = truth_data
  )
  rename(scores, forecaster = model)
}
204 |
| - |
205 |
# Convenience wrapper: refresh the local forecast cache, then load it.
#
# Downloads any forecast files from the last 120 days not yet stored locally,
# then returns the combined quantile-forecast tibble.
get_latest_data <- function() {
  update_forecast_files(days_back = 120)
  read_all_forecasts()
}
209 |
| - |
210 |
# Target list: NHSN data pulls, forecast downloads/ingestion, and scoring.
rlang::list2(
  tar_target(
    nhsn_latest_data,
    command = {
      # From Wednesday through Friday (wday 4-6, Sunday-based) the preliminary
      # dataset (mpgq-jmmr) is used; otherwise the main release (ua7e-t2fy).
      # Hoist wday(Sys.Date()) and use scalar `&&` — this is a single-value
      # condition, not a vectorized one.
      dow <- wday(Sys.Date())
      if (dow > 3 && dow < 6) {
        # download from the preliminary data source from Wednesday to Friday
        most_recent_result <- readr::read_csv("https://data.cdc.gov/resource/mpgq-jmmr.csv?$limit=20000&$select=weekendingdate,jurisdiction,totalconfc19newadm,totalconfflunewadm")
      } else {
        most_recent_result <- readr::read_csv("https://data.cdc.gov/resource/ua7e-t2fy.csv?$limit=20000&$select=weekendingdate,jurisdiction,totalconfc19newadm,totalconfflunewadm")
      }
      most_recent_result %>%
        process_nhsn_data() %>%
        filter(disease == "nhsn_flu") %>%
        select(-disease) %>%
        # Territories with too little data to forecast are dropped.
        filter(geo_value %nin% insufficient_data_geos) %>%
        mutate(
          source = "nhsn",
          # The hub uses "us" for the national aggregate.
          geo_value = ifelse(geo_value == "usa", "us", geo_value)
        ) %>%
        # Keep only the most recent data version.
        filter(version == max(version)) %>%
        select(-version) %>%
        data_substitutions(disease = "flu") %>%
        as_epi_df(other_keys = "source", as_of = Sys.Date())
    }
  ),
  tar_target(
    name = nhsn_archive_data,
    command = {
      create_nhsn_data_archive(disease = "nhsn_flu")
    }
  ),
  tar_target(download_forecasts, update_forecast_files(days_back = 120)),
  tar_target(all_forecasts, read_all_forecasts()),
  tar_target(all_scores, score_forecasts(all_forecasts, nhsn_latest_data))
)
0 commit comments