diff --git a/.Rbuildignore b/.Rbuildignore index ec53fb2..19c1c66 100644 --- a/.Rbuildignore +++ b/.Rbuildignore @@ -32,3 +32,5 @@ ^vignettes/.+\.pdf$ ^vignettes/.+\.sty$ ^vignettes/.+\.tex$ +^.*\.Rproj$ +^\.Rproj\.user$ diff --git a/DESCRIPTION b/DESCRIPTION index fb3c4f3..cb39104 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -3,8 +3,8 @@ Type: Package Title: Tidy, Type-Safe 'prediction()' Methods Description: A one-function package containing 'prediction()', a type-safe alternative to 'predict()' that always returns a data frame. The 'summary()' method provides a data frame with average predictions, possibly over counterfactual versions of the data (a la the 'margins' command in 'Stata'). Marginal effect estimation is provided by the related package, 'margins' . The package currently supports common model types (e.g., "lm", "glm") from the 'stats' package, as well as numerous other model classes from other add-on packages. See the README or main package documentation page for a complete listing. License: MIT + file LICENSE -Version: 0.3.14 -Date: 2019-06-16 +Version: 0.3.15 +Date: 2019-08-08 Authors@R: c(person("Thomas J.", "Leeper", role = c("aut", "cre"), email = "thosjleeper@gmail.com", @@ -13,7 +13,10 @@ Authors@R: c(person("Thomas J.", "Leeper", email = "carlganz@ucla.edu"), person("Vincent", "Arel-Bundock", role = "ctb", email = "vincent.arel-bundock@umontreal.ca", - comment = c(ORCID = "0000-0003-2042-7063")) + comment = c(ORCID = "0000-0003-2042-7063")), + person("Tomasz", "\u017b\u00F3\u0142tak", role = "ctb", + email = "tomek@zozlak.org", + comment = c(ORCID = "0000-0003-1354-4472")) ) URL: https://github.com/leeper/prediction BugReports: https://github.com/leeper/prediction/issues diff --git a/NAMESPACE b/NAMESPACE index 3daf9dc..21a792a 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -9,7 +9,9 @@ S3method(find_data,hxlr) S3method(find_data,lm) S3method(find_data,mca) S3method(find_data,merMod) +S3method(find_data,survey.design) S3method(find_data,svyglm) +S3method(find_data,svyrep.design) S3method(find_data,train) S3method(find_data,vgam) S3method(find_data,vglm) diff --git a/NEWS.md b/NEWS.md index 2bc7259..eb2881f 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,3 +1,9 @@ +# prediction 0.3.15 + +* `prediction.svyglm` handles survey design objects as `data` argument. +* `prediction.svyglm` handles `data` with NAs. +* `build_datalist` preserves levels of factors that are mentioned in `at` argument. + # prediction 0.3.13 * Fixed a bug in `prediction_glm` with the `data` argument (Issue #32). diff --git a/R/build_datalist.R b/R/build_datalist.R index 93198dc..7ee6282 100644 --- a/R/build_datalist.R +++ b/R/build_datalist.R @@ -19,26 +19,25 @@ #' @seealso \code{\link{find_data}}, \code{\link{mean_or_mode}}, \code{\link{seq_range}} #' @importFrom data.table rbindlist #' @export -build_datalist <- +build_datalist <- function(data, - at = NULL, + at = NULL, as.data.frame = FALSE, ...){ - + # check for `at` specification and `as.data.frame` arguments if (!is.null(at) && length(at) > 0) { # check `at` specification against data check_at(data, at) - + # setup list of data.frames based on at data_out <- set_data_to_at(data, at = at) at_specification <- cbind(index = seq_len(nrow(data_out[["at"]])), data_out[["at"]]) data_out <- data_out[["data"]] - if (isTRUE(as.data.frame)) { data_out <- data.table::rbindlist(data_out) } - + } else if (isTRUE(as.data.frame)) { # if `at` empty and `as.data.frame = TRUE`, simply return original data data_out <- data @@ -54,10 +53,10 @@ function(data, check_at <- function(data, at) { # check names of `at` check_at_names(names(data), at) - + # check factor levels specified in `at` check_factor_levels(data, at) - + # check values of numeric values are interpolations check_values(data, at) } @@ -71,7 +70,7 @@ check_factor_levels <- function(data, at) { levels(factor(v)) } else { NULL - } + } }) levels <- levels[!sapply(levels, is.null)] at <- at[names(at) %in% names(levels)] @@ -79,8 +78,8 @@ check_factor_levels <- function(data, at) { atvals <- as.character(at[[i]]) x <- atvals %in% levels[[names(at)[i]]] if (!all(x)) { - stop(paste0("Illegal factor levels for variable '", names(at)[i], "': ", - paste0(shQuote(atvals[!x]), collapse = ", ")), + stop(paste0("Illegal factor levels for variable '", names(at)[i], "': ", + paste0(shQuote(atvals[!x]), collapse = ", ")), call. = FALSE) } } @@ -90,7 +89,7 @@ check_factor_levels <- function(data, at) { check_values <- function(data, at) { # drop variables not in `at` dat <- data[, names(at), drop = FALSE] - + # drop non-numeric variables from `dat` and `at` not_numeric <- !sapply(dat, class) %in% c("character", "factor", "ordered", "logical") at <- at[names(at) %in% names(dat)[not_numeric]] @@ -100,7 +99,7 @@ check_values <- function(data, at) { # calculate variable ranges limits <- do.call(rbind, lapply(dat, range, na.rm = TRUE)) rownames(limits) <- names(dat) - + # check ranges for (i in seq_along(at)) { out <- (at[[i]] < limits[names(at)[i],1]) | (at[[i]] > limits[names(at)[i],2]) @@ -136,9 +135,13 @@ set_data_to_at <- function(data, at = NULL) { } else { expanded <- expand.grid(at, KEEP.OUT.ATTRS = FALSE) } - e <- split(expanded, unique(expanded)) + for (i in intersect(names(data)[sapply(data, is.factor)], names(expanded))) { + expanded[, i] <- factor(expanded[[i]], levels(data[[i]])) + } + e <- split(expanded, unique(expanded), drop = TRUE) data_out <- lapply(e, function(atvals) { dat <- data + dat <- `[<-`(dat, , names(atvals), value = atvals) structure(dat, at = as.list(atvals)) }) diff --git a/R/find_data.R b/R/find_data.R index a55da6c..528f0ff 100644 --- a/R/find_data.R +++ b/R/find_data.R @@ -10,7 +10,7 @@ #' require("datasets") #' x <- lm(mpg ~ cyl * hp + wt, data = head(mtcars)) #' find_data(x) -#' +#' #' @seealso \code{\link{prediction}}, \code{\link{build_datalist}}, \code{\link{mean_or_mode}}, \code{\link{seq_range}} #' @export find_data <- function(model, ...) { @@ -107,9 +107,34 @@ find_data.merMod <- function(model, env = parent.frame(), ...) { #' @export find_data.svyglm <- function(model, ...) { data <- model[["data"]] + # handle subset + if (!is.null(model[["call"]][["subset"]])) { + subs <- try(eval(model[["call"]][["subset"]], data), silent = TRUE) + if (inherits(subs, "try-error")) { + subs <- TRUE + warning("'find_data()' cannot locate variable(s) used in 'subset'") + } + data <- data[subs, , drop = FALSE] + } + # handle na.action + if (!is.null(model[["na.action"]])) { + data <- data[-model[["na.action"]], , drop = FALSE] + } data } +#' @rdname find_data +#' @export +find_data.survey.design <- function(model, ...) { + model[["variables"]] +} + +#' @rdname find_data +#' @export +find_data.svyrep.design <- function(model, ...) { + model[["variables"]] +} + #' @rdname find_data #' @export find_data.train <- function(model, ...) { diff --git a/R/prediction_svyglm.R b/R/prediction_svyglm.R index 3e017e1..8a4e601 100644 --- a/R/prediction_svyglm.R +++ b/R/prediction_svyglm.R @@ -1,22 +1,25 @@ #' @rdname prediction #' @export -prediction.svyglm <- -function(model, - data = find_data(model, parent.frame()), - at = NULL, - type = c("response", "link"), +prediction.svyglm <- +function(model, + data = find_data(model, parent.frame()), + at = NULL, + type = c("response", "link"), calculate_se = TRUE, ...) { - + type <- match.arg(type) - + # extract predicted values data <- data if (missing(data) || is.null(data)) { pred <- predict(model, type = type, se.fit = TRUE, ...) - pred <- data.frame(fitted = unclass(pred), + pred <- data.frame(fitted = unclass(pred), se.fitted = sqrt(unname(attributes(pred)[["var"]]))) } else { + if (inherits(data, c("survey.design", "svyrep.design"))) { + data <- find_data(data) + } # setup data if (is.null(at)) { out <- data @@ -26,14 +29,18 @@ function(model, } # calculate predictions tmp <- predict(model, newdata = out, type = type, se.fit = TRUE, ...) - pred <- make_data_frame(out, fitted = unclass(tmp), se.fitted = sqrt(unname(attributes(tmp)[["var"]]))) + se.fitted <- fitted <- rep(NA_real_, nrow(out)) + noNAs <- rownames(out) %in% names(tmp) + se.fitted[noNAs] <- sqrt(unname(attributes(tmp)[["var"]])) + fitted[noNAs] <- unclass(tmp) + pred <- make_data_frame(out, fitted = fitted, se.fitted = se.fitted) } - + # variance(s) of average predictions vc <- NA_real_ - + # output - structure(pred, + structure(pred, class = c("prediction", "data.frame"), at = if (is.null(at)) at else at_specification, type = type, diff --git a/man/find_data.Rd b/man/find_data.Rd index cc6cb50..3496b90 100644 --- a/man/find_data.Rd +++ b/man/find_data.Rd @@ -12,6 +12,8 @@ \alias{find_data.mca} \alias{find_data.merMod} \alias{find_data.svyglm} +\alias{find_data.survey.design} +\alias{find_data.svyrep.design} \alias{find_data.train} \alias{find_data.vgam} \alias{find_data.vglm} @@ -39,6 +41,10 @@ find_data(model, ...) \method{find_data}{svyglm}(model, ...) +\method{find_data}{survey.design}(model, ...) + +\method{find_data}{svyrep.design}(model, ...) + \method{find_data}{train}(model, ...) \method{find_data}{vgam}(model, env = parent.frame(), ...) diff --git a/tests/testthat/tests-methods.R b/tests/testthat/tests-methods.R index 446f3b5..8e5d3fa 100644 --- a/tests/testthat/tests-methods.R +++ b/tests/testthat/tests-methods.R @@ -481,7 +481,7 @@ if (require("sampleSelection", quietly = TRUE)) { test_that("Test prediction() for 'selection'", { data("Mroz87", package = "sampleSelection") Mroz87$kids <- (Mroz87$kids5 + Mroz87$kids618 > 0) - m <- sampleSelection::heckit(lfp ~ age + I( age^2 ) + faminc + kids + educ, + m <- sampleSelection::heckit(lfp ~ age + I( age^2 ) + faminc + kids + educ, wage ~ exper + I( exper^2 ) + educ + city, Mroz87) p <- prediction(m) expect_true(inherits(p, "prediction"), label = "'prediction' class is correct") @@ -570,16 +570,22 @@ if (require("survey", quietly = TRUE)) { p <- prediction(m) expect_true(inherits(p, "prediction"), label = "'prediction' class is correct") expect_true(all(c("fitted", "se.fitted") %in% names(p)), label = "'fitted' and 'se.fitted' columns returned") + dstrat2 <- subset(dstrat, yr.rnd == "No") + dstrat2$variables$enroll[10] = NA + p2 <- prediction(m, dstrat2) + expect_true(inherits(p2, "prediction"), label = "'prediction' class is correct") + expect_true(all(c("fitted", "se.fitted") %in% names(p2)), label = "'fitted' and 'se.fitted' columns returned") + expect_true(is.na(p2$fitted[10]) & is.na(p2$se.fitted[10]), label = "NAs in data handled correctly") }) } if (require("survival", quietly = TRUE)) { test_that("Test prediction() for 'coxph'", { - test1 <- list(time=c(4,3,1,1,2,2,3), - status=c(1,1,1,0,1,1,0), - x=c(0,2,1,1,1,0,0), - sex=c(0,0,0,0,1,1,1)) - m <- survival::coxph(survival::Surv(time, status) ~ x + survival::strata(sex), test1) + test1 <- list(time=c(4,3,1,1,2,2,3), + status=c(1,1,1,0,1,1,0), + x=c(0,2,1,1,1,0,0), + sex=c(0,0,0,0,1,1,1)) + m <- survival::coxph(survival::Surv(time, status) ~ x + survival::strata(sex), test1) p <- prediction(m) expect_true(inherits(p, "prediction"), label = "'prediction' class is correct") expect_true(all(c("fitted", "se.fitted") %in% names(p)), label = "'fitted' and 'se.fitted' columns returned")