diff options
author | Naeem Model <me@nmode.ca> | 2024-11-02 18:13:28 +0000 |
---|---|---|
committer | Naeem Model <me@nmode.ca> | 2024-11-02 18:13:28 +0000 |
commit | 94b4dcd37e662eb1e525dc241817c8dd5d4681fc (patch) | |
tree | f5ef5b90bf2307dd28ae946413350e34a159b7fa | |
parent | 9fd931aeeba4ab7bdede1a625f64e7024c2b55aa (diff) |
Add input validation to estimators
-rw-r--r-- | R/id.R | 7 | ||||
-rw-r--r-- | R/idea.R | 7 | ||||
-rw-r--r-- | R/seq_bayes.R | 23 | ||||
-rw-r--r-- | R/util.R | 55 | ||||
-rw-r--r-- | R/wp.R | 35 | ||||
-rw-r--r-- | inst/app/scripts/data.R | 8 | ||||
-rw-r--r-- | inst/app/scripts/estimators.R | 72 |
7 files changed, 168 insertions, 39 deletions
@@ -41,5 +41,12 @@ #' # Obtain R0 when the serial distribution has a mean of three days. #' id(cases, mu = 3 / 7) id <- function(cases, mu) { + validate_cases(cases, min_length = 1, min_count = 1) + if (!is_real(mu) || mu <= 0) { + stop("The serial interval (`mu`) must be a number greater than 0.", + call. = FALSE + ) + } + exp(sum((log(cases) * mu) / seq_along(cases)) / length(cases)) } @@ -42,6 +42,13 @@ #' # Obtain R0 when the serial distribution has a mean of three days. #' idea(cases, mu = 3 / 7) idea <- function(cases, mu) { + validate_cases(cases, min_length = 2, min_count = 1) + if (!is_real(mu) || mu <= 0) { + stop("The serial interval (`mu`) must be a number greater than 0.", + call. = FALSE + ) + } + s <- seq_along(cases) / mu x1 <- sum(s) diff --git a/R/seq_bayes.R b/R/seq_bayes.R index d486d2b..ccc9a41 100644 --- a/R/seq_bayes.R +++ b/R/seq_bayes.R @@ -84,10 +84,31 @@ #' # Note that the following always holds: #' estimate == sum(posterior$supp * posterior$pmf) seq_bayes <- function(cases, mu, kappa = 20, post = FALSE) { + validate_cases(cases, min_length = 2, min_count = 0) + if (!is_real(mu) || mu <= 0) { + stop("The serial interval (`mu`) must be a number greater than 0.", + call. = FALSE + ) + } + if (!is_real(kappa) || kappa < 1) { + stop( + paste("The largest value of the uniform prior (`kappa`)", + "must be a number greater than or equal to 1." + ), call. = FALSE + ) + } + if (!identical(post, TRUE) && !identical(post, FALSE)) { + stop("The posterior flag (`post`) must be set to `TRUE` or `FALSE`.", + call. = FALSE + ) + } + if (any(cases == 0)) { times <- which(cases > 0) if (length(times) < 2) { - stop("Vector of case counts must contain at least two positive integers.") + stop("Case counts must contain at least two positive integers.", + call. = FALSE + ) } cases <- cases[times] } else { diff --git a/R/util.R b/R/util.R new file mode 100644 index 0000000..d8b0b59 --- /dev/null +++ b/R/util.R @@ -0,0 +1,55 @@ +#' Case Counts Validation +#' +#' This is an internal function called by the estimators. It validates the +#' supplied case counts by ensuring it is a vector of integers of length at +#' least `min_length` with entries greater than or equal to `min_count`. +#' Execution is halted if these requirements are not satisfied. +#' +#' @param cases The case counts to be validated. +#' @param min_length The minimum length of the vector of case counts. +#' @param min_count The minimum value of the case count vector's entries. +#' +#' @noRd +validate_cases <- function(cases, min_length, min_count) { + if (!is.vector(cases) || !is.numeric(cases) || any(floor(cases) != cases)) { + stop("Case counts must be a vector of integers.", call. = FALSE) + } + if (length(cases) < min_length) { + stop(paste("Case counts must have at least", min_length, "entries."), + call. = FALSE + ) + } + if (any(cases < min_count)) { + stop(paste0("Case counts cannot be less than ", min_count, "."), + call. = FALSE + ) + } +} + +#' Real Number Testing +#' +#' This is an internal function which checks whether the given argument is a +#' real number. +#' +#' @param x The argument to test for being a real number. +#' +#' @return `TRUE` if the argument is a real number, `FALSE` otherwise. +#' +#' @noRd +is_real <- function(x) { + is.vector(x) && is.numeric(x) && length(x) == 1 +} + +#' Integer Testing +#' +#' This is an internal function which checks whether the given argument is an +#' integer. +#' +#' @param n The argument to test for being an integer. +#' +#' @return `TRUE` if the argument is an integer, `FALSE` otherwise. +#' +#' @noRd +is_integer <- function(n) { + is_real(n) && floor(n) == n +} @@ -111,12 +111,45 @@ #' estimate$pmf wp <- function(cases, mu = NA, serial = FALSE, grid_length = 100, max_shape = 10, max_scale = 10) { - if (is.na(mu)) { + validate_cases(cases, min_length = 2, min_count = 1) + if (!identical(serial, TRUE) && !identical(serial, FALSE)) { + stop( + paste("The serial distribution flag (`serial`) must be set to", + "`TRUE` or `FALSE`." + ), call. = FALSE + ) + } + + if (identical(mu, NA)) { + if (!is_integer(grid_length) || grid_length < 1) { + stop("The grid length must be a positive integer.", call. = FALSE) + } + if (!is_real(max_shape) || max_shape <= 0) { + stop( + paste("The largest value of the shape parameter (`max_shape`)", + "must be a positive number." + ), call. = FALSE + ) + } + if (!is_real(max_scale) || max_scale <= 0) { + stop( + paste("The largest value of the scale parameter (`max_scale`)", + "must be a positive number." + ), call. = FALSE + ) + } + search <- wp_search(cases, grid_length, max_shape, max_scale) r0 <- search$r0 serial_supp <- search$supp serial_pmf <- search$pmf } else { + if (!is_real(mu) || mu <= 0) { + stop("The serial interval (`mu`) must be a positive number or `NA`.", + call. = FALSE + ) + } + max_range <- ceiling(qgamma(0.999, shape = 1, scale = mu)) serial_supp <- seq_len(max_range) serial_pmf <- diff(pgamma(0:max_range, shape = 1, scale = mu)) diff --git a/inst/app/scripts/data.R b/inst/app/scripts/data.R index c85e27b..8f8694c 100644 --- a/inst/app/scripts/data.R +++ b/inst/app/scripts/data.R @@ -141,9 +141,7 @@ validate_data <- function(input, output, react_values, data_source) { # corresponding columns in the estimates table. update_estimates_cols(new_rows, react_values) - showNotification("Datasets added successfully.", - duration = 3, id = "notify-success" - ) + showNotification("Datasets added successfully.", duration = 3) } }, error = function(e) { @@ -195,9 +193,7 @@ load_samples <- function(input, output, react_values) { # corresponding columns in the estimates table. update_estimates_cols(new_rows, react_values) - showNotification("Datasets added successfully.", - duration = 3, id = "notify-success" - ) + showNotification("Datasets added successfully.", duration = 3) } }) } diff --git a/inst/app/scripts/estimators.R b/inst/app/scripts/estimators.R index b61f4d4..a86b1d4 100644 --- a/inst/app/scripts/estimators.R +++ b/inst/app/scripts/estimators.R @@ -29,8 +29,8 @@ add_estimator <- function(method, new_estimator, output, react_values) { # Check whether the new estimator is a duplicate, and warn if so. for (i in seq_len(num_estimators)) { if (identical(new_estimator, react_values$estimators[[i]])) { - showNotification("Error: This estimator has already been added.", - duration = 3, id = "notify-error" + showNotification( + "Error: This estimator has already been added.", duration = 3 ) return() } @@ -39,9 +39,7 @@ add_estimator <- function(method, new_estimator, output, react_values) { # Add the new estimator to the list of estimators. react_values$estimators[[num_estimators + 1]] <- new_estimator - showNotification("Estimator added successfully.", - duration = 3, id = "notify-success" - ) + showNotification("Estimator added successfully.", duration = 3) # Evaluate the new estimator on all existing datasets and create a new row in # the estimates table. @@ -95,9 +93,9 @@ add_seq_bayes <- function(input, output, react_values) { kappa <- trimws(input$kappa) kappa <- if (kappa == "") 20 else suppressWarnings(as.numeric(kappa)) - if (is.na(kappa) || kappa <= 0) { + if (is.na(kappa) || kappa < 1) { output$kappa_warn <- renderText( - "The maximum prior must be a positive number." + "The maximum prior must be a number greater than or equal to 1." ) } else if (!is.null(mu)) { output$kappa_warn <- renderText("") @@ -209,32 +207,44 @@ update_estimates_row <- function(estimator, react_values) { eval_estimator <- function(estimator, dataset) { cases <- as.integer(unlist(strsplit(dataset[, 3], ","))) - if (estimator$method == "id") { - mu <- convert_mu_units(dataset[, 2], estimator$mu_units, estimator$mu) - estimate <- round(Rnaught::id(cases, mu), 2) - } else if (estimator$method == "idea") { - mu <- convert_mu_units(dataset[, 2], estimator$mu_units, estimator$mu) - estimate <- round(Rnaught::idea(cases, mu), 2) - } else if (estimator$method == "seq_bayes") { - mu <- convert_mu_units(dataset[, 2], estimator$mu_units, estimator$mu) - estimate <- round(Rnaught::seq_bayes(cases, mu, estimator$kappa), 2) - } else if (estimator$method == "wp") { - if (is.na(estimator$mu)) { - estimate <- Rnaught::wp(cases, serial = TRUE, - grid_length = estimator$grid_length, - max_shape = estimator$max_shape, max_scale = estimator$max_scale - ) - estimated_mu <- round(sum(estimate$supp * estimate$pmf), 2) - estimate <- paste0(round(estimate$r0, 2), " (SI = ", estimated_mu, - " ", tolower(dataset[, 2]), ")" + tryCatch( + { + if (estimator$method == "id") { + mu <- convert_mu_units(dataset[, 2], estimator$mu_units, estimator$mu) + estimate <- round(Rnaught::id(cases, mu), 2) + } else if (estimator$method == "idea") { + mu <- convert_mu_units(dataset[, 2], estimator$mu_units, estimator$mu) + estimate <- round(Rnaught::idea(cases, mu), 2) + } else if (estimator$method == "seq_bayes") { + mu <- convert_mu_units(dataset[, 2], estimator$mu_units, estimator$mu) + estimate <- round(Rnaught::seq_bayes(cases, mu, estimator$kappa), 2) + } else if (estimator$method == "wp") { + if (is.na(estimator$mu)) { + estimate <- Rnaught::wp(cases, serial = TRUE, + grid_length = estimator$grid_length, + max_shape = estimator$max_shape, max_scale = estimator$max_scale + ) + estimated_mu <- round(sum(estimate$supp * estimate$pmf), 2) + estimate <- paste0(round(estimate$r0, 2), " (SI = ", estimated_mu, + " ", tolower(dataset[, 2]), ")" + ) + } else { + mu <- convert_mu_units(dataset[, 2], estimator$mu_units, estimator$mu) + estimate <- round(Rnaught::wp(cases, mu), 2) + } + } + + return(estimate) + }, error = function(e) { + showNotification( + paste0(toString(e), + " [Estimator: ", sub(" .*", "", estimator_name(estimator)), + ", Dataset: ", dataset[, 1], "]" + ), duration = 6 ) - } else { - mu <- convert_mu_units(dataset[, 2], estimator$mu_units, estimator$mu) - estimate <- round(Rnaught::wp(cases, mu), 2) + return("—") } - } - - return(estimate) + ) } # Create the name of an estimator to be added to the first column of the |