]> nmode's Git Repositories - Rnaught/blobdiff - R/wp.R
Add input validation to estimators
[Rnaught] / R / wp.R
diff --git a/R/wp.R b/R/wp.R
index 16b4bbb3730df97919b75737f4022f85c4ba7ce1..fbb6ad7956d59c426b56dc5058fcdd93c0e17365 100644 (file)
--- a/R/wp.R
+++ b/R/wp.R
 #' estimate$pmf
 wp <- function(cases, mu = NA, serial = FALSE,
                grid_length = 100, max_shape = 10, max_scale = 10) {
-  if (is.na(mu)) {
+  validate_cases(cases, min_length = 2, min_count = 1)
+  if (!identical(serial, TRUE) && !identical(serial, FALSE)) {
+    stop(
+      paste("The serial distribution flag (`serial`) must be set to",
+        "`TRUE` or `FALSE`."
+      ), call. = FALSE
+    )
+  }
+
+  if (identical(mu, NA)) {
+    if (!is_integer(grid_length) || grid_length < 1) {
+      stop("The grid length must be a positive integer.", call. = FALSE)
+    }
+    if (!is_real(max_shape) || max_shape <= 0) {
+      stop(
+        paste("The largest value of the shape parameter (`max_shape`)",
+          "must be a positive number."
+        ), call. = FALSE
+      )
+    }
+    if (!is_real(max_scale) || max_scale <= 0) {
+      stop(
+        paste("The largest value of the scale parameter (`max_scale`)",
+          "must be a positive number."
+        ), call. = FALSE
+      )
+    }
+
     search <- wp_search(cases, grid_length, max_shape, max_scale)
     r0 <- search$r0
     serial_supp <- search$supp
     serial_pmf <- search$pmf
   } else {
+    if (!is_real(mu) || mu <= 0) {
+      stop("The serial interval (`mu`) must be a positive number or `NA`.",
+        call. = FALSE
+      )
+    }
+
     max_range <- ceiling(qgamma(0.999, shape = 1, scale = mu))
     serial_supp <- seq_len(max_range)
     serial_pmf <- diff(pgamma(0:max_range, shape = 1, scale = mu))