diff --git a/NAMESPACE b/NAMESPACE index 01be9360..83719051 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -11,6 +11,7 @@ S3method(dim,waic) S3method(loo,"function") S3method(loo,array) S3method(loo,matrix) +S3method(loo_compare,default) S3method(loo_model_weights,default) S3method(plot,loo) S3method(plot,psis) @@ -24,6 +25,7 @@ S3method(print,psis) S3method(print,psis_loo) S3method(print,stacking_weights) S3method(print,waic) +S3method(print_dims,kfold) S3method(print_dims,psis) S3method(print_dims,psis_loo) S3method(print_dims,waic) @@ -43,7 +45,14 @@ export(compare) export(example_loglik_array) export(example_loglik_matrix) export(extract_log_lik) +export(find_model_names) export(gpdfit) +export(is.kfold) +export(is.loo) +export(is.psis) +export(is.psis_loo) +export(is.waic) +export(kfold) export(kfold_split_grouped) export(kfold_split_random) export(kfold_split_stratified) @@ -51,6 +60,7 @@ export(loo) export(loo.array) export(loo.function) export(loo.matrix) +export(loo_compare) export(loo_i) export(loo_model_weights) export(loo_model_weights.default) diff --git a/NEWS.md b/NEWS.md index b41787fb..8a223357 100644 --- a/NEWS.md +++ b/NEWS.md @@ -2,12 +2,12 @@ * New vignette on LOO for non-factorizable joint Gaussian models. (#75) -* When comparing more than two models there is now also -an `se_diff` column in the results. (#78) +* New `se_diff` column in model comparison results. (#78) -* Fix for `psis()` when `log_ratios` are very small. (#74) +* Improved behavior of `psis()` when `log_ratios` are very small. (#74) -* Allow `r_eff=NA` to suppress warning when specifying `r_eff` is not applicable (i.e., draws not from MCMC). (#72) +* Allow `r_eff=NA` to suppress warning when specifying `r_eff` is not applicable +(i.e., draws not from MCMC). (#72) * Update effective sample size calculations to match RStan's version. 
(#85) diff --git a/R/compare.R b/R/compare.R index 5923b11f..755462bc 100644 --- a/R/compare.R +++ b/R/compare.R @@ -60,6 +60,7 @@ #' } #' compare <- function(..., x = list()) { + # .Deprecated("loo_compare") dots <- list(...) if (length(dots)) { if (length(x)) { @@ -87,6 +88,7 @@ compare <- function(..., x = list()) { loo1 <- dots[[1]] loo2 <- dots[[2]] comp <- compare_two_models(loo1, loo2) + class(comp) <- c(class(comp), "old_compare.loo") return(comp) } else { Ns <- sapply(dots, function(x) nrow(x$pointwise)) @@ -115,29 +117,11 @@ compare <- function(..., x = list()) { se_diff <- apply(diffs, 2, se_elpd_diff) comp <- cbind(elpd_diff = elpd_diff, se_diff = se_diff, comp) rownames(comp) <- rnms - class(comp) <- c("compare.loo", class(comp)) + class(comp) <- c("compare.loo", class(comp), "old_compare.loo") comp } } -#' @rdname compare -#' @export -#' @param digits For the print method only, the number of digits to use when -#' printing. -#' @param simplify For the print method only, should only the essential columns -#' of the summary matrix be printed when comparing more than two models? The -#' entire matrix is always returned, but by default only the most important -#' columns are printed. 
-print.compare.loo <- function(x, ..., digits = 1, simplify = TRUE) { - xcopy <- x - if (NCOL(xcopy) >= 2 && simplify) { - patts <- "^elpd_|^se_diff|^p_|^waic$|^looic$" - xcopy <- xcopy[, grepl(patts, colnames(xcopy))] - } - print(.fr(xcopy, digits), quote = FALSE) - invisible(x) -} - # internal ---------------------------------------------------------------- @@ -154,14 +138,3 @@ compare_two_models <- function(loo_a, loo_b, return = c("elpd_diff", "se"), chec comp <- c(elpd_diff = sum(diffs), se = se_elpd_diff(diffs)) structure(comp, class = "compare.loo") } - -elpd_diffs <- function(loo_a, loo_b) { - pt_a <- loo_a$pointwise - pt_b <- loo_b$pointwise - elpd <- grep("^elpd", colnames(pt_a)) - pt_b[, elpd] - pt_a[, elpd] -} -se_elpd_diff <- function(diffs) { - N <- length(diffs) - sqrt(N) * sd(diffs) -} diff --git a/R/helpers.R b/R/helpers.R index 829e7e72..ff9775ad 100644 --- a/R/helpers.R +++ b/R/helpers.R @@ -37,21 +37,6 @@ table_of_estimates <- function(x) { } -# checking classes -------------------------------------------------------- -is.psis <- function(x) { - inherits(x, "psis") && is.list(x) -} -is.loo <- function(x) { - inherits(x, "loo") -} -is.psis_loo <- function(x) { - inherits(x, "psis_loo") && is.loo(x) -} -is.waic <- function(x) { - inherits(x, "waic") && is.loo(x) -} - - # validating and reshaping arrays/matrices ------------------------------- #' Check for NAs and non-finite values in log-lik (or log-ratios) diff --git a/R/kfold-generic.R b/R/kfold-generic.R new file mode 100644 index 00000000..8e7e2f63 --- /dev/null +++ b/R/kfold-generic.R @@ -0,0 +1,34 @@ +#' Generic function for K-fold cross-validation for developers +#' +#' For developers of modeling packages, \pkg{loo} includes a generic function +#' \code{kfold} so that methods may be defined for K-fold CV without name +#' conflicts between packages. See, e.g., the \code{kfold.stanreg} method in +#' \pkg{rstanarm} and the \code{kfold.brmsfit} method in \pkg{brms}. 
+#' +#' @name kfold-generic +#' @param x A fitted model object. +#' @param ... Arguments to pass to specific methods. +#' +#' @return For developers defining a \code{kfold} method for a class +#' \code{"foo"}, the \code{kfold.foo} function should return a list with class +#' \code{c("kfold", "loo")} with at least the elements +#' \itemize{ +#' \item \code{"estimates"}: a 1x2 matrix with column names "Estimate" and "SE" +#' containing the ELPD estimate and its standard error. +#' \item \code{"pointwise"}: an Nx1 matrix with column name "elpd_kfold" containing +#' the pointwise contributions for each data point. +#' } +#' +NULL + +#' @rdname kfold-generic +#' @export +kfold <- function(x, ...) { + UseMethod("kfold") +} + +#' @rdname kfold-generic +#' @export +is.kfold <- function(x) { + inherits(x, "kfold") && is.loo(x) +} diff --git a/R/kfold-helpers.R b/R/kfold-helpers.R index 03805609..a1cbafbb 100644 --- a/R/kfold-helpers.R +++ b/R/kfold-helpers.R @@ -1,16 +1,15 @@ #' Helper functions for K-fold cross-validation #' -#' These functions can be used to generate indexes for use with K-fold -#' cross-validation. +#' @description These functions can be used to generate indexes for use with +#' K-fold cross-validation. See the \strong{Details} section for explanations. #' #' @name kfold-helpers #' @param K The number of folds to use. #' @param N The number of observations in the data. #' @param x A discrete variable of length \code{N} with at least \code{K} levels #' (unique values). Will be coerced to \code{\link{factor}}. -#' . -#' @return An integer vector of length \code{N} where each element is an index -#' in \code{1:K}. +#' +#' @return An integer vector of length \code{N} where each element is an index in \code{1:K}. 
#' #' @details #' \code{kfold_split_random} splits the data into \code{K} groups diff --git a/R/loo.R b/R/loo.R index 4555c069..e36168a4 100644 --- a/R/loo.R +++ b/R/loo.R @@ -397,6 +397,18 @@ dim.psis_loo <- function(x) { } +#' @rdname loo +#' @export +is.loo <- function(x) { + inherits(x, "loo") +} + +#' @rdname loo +#' @export +is.psis_loo <- function(x) { + inherits(x, "psis_loo") && is.loo(x) +} + # internal ---------------------------------------------------------------- diff --git a/R/loo_compare.R b/R/loo_compare.R new file mode 100644 index 00000000..78e184bc --- /dev/null +++ b/R/loo_compare.R @@ -0,0 +1,178 @@ +#' Model comparison +#' +#' Compare fitted models on LOO or WAIC. +#' +#' @export +#' @param x An object of class \code{"loo"} or a list of such objects. +#' @param ... Additional objects of class \code{"loo"}. +#' +#' @return A matrix with class \code{"compare.loo"} that has its own +#' print method. See the \strong{Details} section for more information. +#' +#' @details +#' When comparing two fitted models, we can estimate the difference in their +#' expected predictive accuracy by the difference in \code{elpd_loo} or +#' \code{elpd_waic} (or multiplied by \eqn{-2}, if desired, to be on the +#' deviance scale). +#' +#' When using \code{loo_compare()}, the returned matrix will have one row per +#' model and several columns of estimates. The values in the \code{elpd_diff} +#' and \code{se_diff} columns of the returned matrix are computed by making +#' pairwise comparisons between each model and the model with the largest ELPD +#' (the model in the first row). For this reason the \code{elpd_diff} column +#' will always have the value \code{0} in the first row (i.e., the difference +#' between the preferred model and itself) and negative values in subsequent +#' rows for the remaining models.
+#' +#' To compute the standard error of the difference in ELPD --- which should +#' not be expected to equal the difference of the standard errors --- we use a +#' paired estimate to take advantage of the fact that the same set of \eqn{N} +#' data points was used to fit both models. These calculations should be most +#' useful when \eqn{N} is large, because then non-normality of the +#' distribution is not such an issue when estimating the uncertainty in these +#' sums. These standard errors, for all their flaws, should give a better +#' sense of uncertainty than what is obtained using the current standard +#' approach of comparing differences of deviances to a Chi-squared +#' distribution, a practice derived for Gaussian linear models or +#' asymptotically, and which only applies to nested models in any case. +#' +#' @template loo-and-psis-references +#' +#' @examples +#' \dontrun{ +#' loo1 <- loo(log_lik1) +#' loo2 <- loo(log_lik2) +#' print(loo_compare(loo1, loo2), digits = 3) +#' print(loo_compare(x = list(loo1, loo2))) +#' +#' waic1 <- waic(log_lik1) +#' waic2 <- waic(log_lik2) +#' loo_compare(waic1, waic2) +#' } +#' +loo_compare <- function(x, ...) { + UseMethod("loo_compare") +} + +#' @rdname loo_compare +#' @export +loo_compare.default <- function(x, ...) { + if (is.loo(x)) { + dots <- list(...) + loos <- c(list(x), dots) + } else { + if (!is.list(x) || !length(x)) { + stop("'x' must be a list if not a 'loo' object.") + } + if (length(list(...))) { + stop("If 'x' is a list then '...' 
should not be specified.") + } + loos <- x + } + + if (!all(sapply(loos, is.loo))) { + stop("All inputs should have class 'loo'.") + } + if (length(loos) <= 1L) { + stop("'loo_compare' requires at least two models.") + } + + Ns <- sapply(loos, function(x) nrow(x$pointwise)) + if (!all(Ns == Ns[1L])) { + stop("Not all models have the same number of data points.") + } + + tmp <- sapply(loos, function(x) { + est <- x$estimates + setNames(c(est), nm = c(rownames(est), paste0("se_", rownames(est))) ) + }) + + colnames(tmp) <- find_model_names(loos) + rnms <- rownames(tmp) + comp <- tmp + ord <- order(tmp[grep("^elpd", rnms), ], decreasing = TRUE) + comp <- t(comp)[ord, ] + patts <- c("elpd", "p_", "^waic$|^looic$", "^se_waic$|^se_looic$") + col_ord <- unlist(sapply(patts, function(p) grep(p, colnames(comp))), + use.names = FALSE) + comp <- comp[, col_ord] + + # compute elpd_diff and se_elpd_diff relative to best model + rnms <- rownames(comp) + diffs <- mapply(FUN = elpd_diffs, loos[ord[1]], loos[ord]) + elpd_diff <- apply(diffs, 2, sum) + se_diff <- apply(diffs, 2, se_elpd_diff) + comp <- cbind(elpd_diff = elpd_diff, se_diff = se_diff, comp) + rownames(comp) <- rnms + + class(comp) <- c("compare.loo", class(comp)) + return(comp) +} + +#' @rdname loo_compare +#' @export +#' @param digits For the print method only, the number of digits to use when +#' printing. +#' @param simplify For the print method only, should only the essential columns +#' of the summary matrix be printed? The entire matrix is always returned, but +#' by default only the most important columns are printed. 
+print.compare.loo <- function(x, ..., digits = 1, simplify = TRUE) { + xcopy <- x + if (inherits(xcopy, "old_compare.loo")) { + if (NCOL(xcopy) >= 2 && simplify) { + patts <- "^elpd_|^se_diff|^p_|^waic$|^looic$" + xcopy <- xcopy[, grepl(patts, colnames(xcopy))] + } + } else if (NCOL(xcopy) >= 2 && simplify) { + xcopy <- xcopy[, c("elpd_diff", "se_diff")] + } + print(.fr(xcopy, digits), quote = FALSE) + invisible(x) +} + + + +# internal ---------------------------------------------------------------- +elpd_diffs <- function(loo_a, loo_b) { + pt_a <- loo_a$pointwise + pt_b <- loo_b$pointwise + elpd <- grep("^elpd", colnames(pt_a)) + pt_b[, elpd] - pt_a[, elpd] +} +se_elpd_diff <- function(diffs) { + N <- length(diffs) + sqrt(N) * sd(diffs) +} + + + +#' Find the model names associated with loo objects +#' +#' @export +#' @keywords internal +#' @param x List of loo objects. +#' @return Character vector of model names the same length as x. +#' +find_model_names <- function(x) { + stopifnot(is.list(x)) + out_names <- character(length(x)) + + names1 <- names(x) + names2 <- lapply(x, "attr", "model_name", exact = TRUE) + names3 <- lapply(x, "[[", "model_name") + names4 <- paste0("model", seq_along(x)) + + for (j in seq_along(x)) { + if (isTRUE(nzchar(names1[j]))) { + out_names[j] <- names1[j] + } else if (length(names2[[j]])) { + out_names[j] <- names2[[j]] + } else if (length(names3[[j]])) { + out_names[j] <- names3[[j]] + } else { + out_names[j] <- names4[j] + } + } + + return(out_names) +} diff --git a/R/print.R b/R/print.R index 044a29da..d55a5299 100644 --- a/R/print.R +++ b/R/print.R @@ -104,6 +104,15 @@ print_dims.waic <- function(x, ...) { ) } +#' @rdname print_dims +#' @export +print_dims.kfold <- function(x, ...) 
{ + K <- attr(x, "K", exact = TRUE) + if (!is.null(K)) { + cat("Based on", paste0(K, "-fold"), "cross-validation\n") + } +} + print_mcse_summary <- function(x, digits) { mcse_val <- mcse_loo(x) diff --git a/R/psis.R b/R/psis.R index b91ac0aa..ef767b6f 100644 --- a/R/psis.R +++ b/R/psis.R @@ -170,6 +170,13 @@ dim.psis <- function(x) { attr(x, "dims") } +#' @rdname psis +#' @export +#' @param x For \code{is.psis}, an object to check. +is.psis <- function(x) { + inherits(x, "psis") && is.list(x) +} + # internal ---------------------------------------------------------------- diff --git a/R/waic.R b/R/waic.R index cf57f6aa..478801d4 100644 --- a/R/waic.R +++ b/R/waic.R @@ -121,6 +121,12 @@ dim.waic <- function(x) { attr(x, "dims") } +#' @rdname waic +#' @export +is.waic <- function(x) { + inherits(x, "waic") && is.loo(x) +} + # internal ---------------------------------------------------------------- diff --git a/man/compare.Rd b/man/compare.Rd index 6acb10f0..cb30a7c2 100644 --- a/man/compare.Rd +++ b/man/compare.Rd @@ -2,12 +2,9 @@ % Please edit documentation in R/compare.R \name{compare} \alias{compare} -\alias{print.compare.loo} \title{Model comparison} \usage{ compare(..., x = list()) - -\method{print}{compare.loo}(x, ..., digits = 1, simplify = TRUE) } \arguments{ \item{...}{At least two objects returned by \code{\link{loo}} (or @@ -16,14 +13,6 @@ compare(..., x = list()) \item{x}{A list of at least two objects returned by \code{\link{loo}} (or \code{\link{waic}}). This argument can be used as an alternative to specifying the objects in \code{...}.} - -\item{digits}{For the print method only, the number of digits to use when -printing.} - -\item{simplify}{For the print method only, should only the essential columns -of the summary matrix be printed when comparing more than two models? 
The -entire matrix is always returned, but by default only the most important -columns are printed.} } \value{ A vector or matrix with class \code{'compare.loo'} that has its own diff --git a/man/find_model_names.Rd b/man/find_model_names.Rd new file mode 100644 index 00000000..575665c7 --- /dev/null +++ b/man/find_model_names.Rd @@ -0,0 +1,18 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/loo_compare.R +\name{find_model_names} +\alias{find_model_names} +\title{Find the model names associated with loo objects} +\usage{ +find_model_names(x) +} +\arguments{ +\item{x}{List of loo objects.} +} +\value{ +Character vector of model names the same length as x. +} +\description{ +Find the model names associated with loo objects +} +\keyword{internal} diff --git a/man/kfold-generic.Rd b/man/kfold-generic.Rd new file mode 100644 index 00000000..ab81e55d --- /dev/null +++ b/man/kfold-generic.Rd @@ -0,0 +1,34 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/kfold-generic.R +\name{kfold-generic} +\alias{kfold-generic} +\alias{kfold} +\alias{is.kfold} +\title{Generic function for K-fold cross-validation for developers} +\usage{ +kfold(x, ...) + +is.kfold(x) +} +\arguments{ +\item{x}{A fitted model object.} + +\item{...}{Arguments to pass to specific methods.} +} +\value{ +For developers defining a \code{kfold} method for a class + \code{"foo"}, the \code{kfold.foo} function should return a list with class + \code{c("kfold", "loo")} with at least the elements + \itemize{ + \item \code{"estimates"}: a 1x2 matrix with column names "Estimate" and "SE" + containing the ELPD estimate and its standard error. + \item \code{"pointwise"}: an Nx1 matrix with column name "elpd_kfold" containing + the pointwise contributions for each data point. 
+ } +} +\description{ +For developers of modeling packages, \pkg{loo} includes a generic function +\code{kfold} so that methods may be defined for K-fold CV without name +conflicts between packages. See, e.g., the \code{kfold.stanreg} method in +\pkg{rstanarm} and the \code{kfold.brmsfit} method in \pkg{brms}. +} diff --git a/man/kfold-helpers.Rd b/man/kfold-helpers.Rd index 8061624c..df214636 100644 --- a/man/kfold-helpers.Rd +++ b/man/kfold-helpers.Rd @@ -19,16 +19,14 @@ kfold_split_grouped(K = 10, x = NULL) \item{N}{The number of observations in the data.} \item{x}{A discrete variable of length \code{N} with at least \code{K} levels -(unique values). Will be coerced to \code{\link{factor}}. -.} +(unique values). Will be coerced to \code{\link{factor}}.} } \value{ -An integer vector of length \code{N} where each element is an index - in \code{1:K}. +An integer vector of length \code{N} where each element is an index in \code{1:K}. } \description{ -These functions can be used to generate indexes for use with K-fold -cross-validation. +These functions can be used to generate indexes for use with + K-fold cross-validation. See the \strong{Details} section for explanations. } \details{ \code{kfold_split_random} splits the data into \code{K} groups diff --git a/man/loo.Rd b/man/loo.Rd index 3593883d..b68b6377 100644 --- a/man/loo.Rd +++ b/man/loo.Rd @@ -6,6 +6,8 @@ \alias{loo.matrix} \alias{loo.function} \alias{loo_i} +\alias{is.loo} +\alias{is.psis_loo} \title{Efficient approximate leave-one-out cross-validation (LOO)} \usage{ loo(x, ...) @@ -20,6 +22,10 @@ loo(x, ...) save_psis = FALSE, cores = getOption("mc.cores", 1)) loo_i(i, llfun, ..., data = NULL, draws = NULL, r_eff = NULL) + +is.loo(x) + +is.psis_loo(x) } \arguments{ \item{x}{A log-likelihood array, matrix, or function. 
See the \strong{Methods diff --git a/man/loo_compare.Rd b/man/loo_compare.Rd new file mode 100644 index 00000000..2528b6bd --- /dev/null +++ b/man/loo_compare.Rd @@ -0,0 +1,84 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/loo_compare.R +\name{loo_compare} +\alias{loo_compare} +\alias{loo_compare.default} +\alias{print.compare.loo} +\title{Model comparison} +\usage{ +loo_compare(x, ...) + +\method{loo_compare}{default}(x, ...) + +\method{print}{compare.loo}(x, ..., digits = 1, simplify = TRUE) +} +\arguments{ +\item{x}{An object of class \code{"loo"} or a list of such objects.} + +\item{...}{Additional objects of class \code{"loo"}.} + +\item{digits}{For the print method only, the number of digits to use when +printing.} + +\item{simplify}{For the print method only, should only the essential columns +of the summary matrix be printed? The entire matrix is always returned, but +by default only the most important columns are printed.} +} +\value{ +A matrix with class \code{"compare.loo"} that has its own + print method. See the \strong{Details} section for more information. +} +\description{ +Compare fitted models on LOO or WAIC. +} +\details{ +When comparing two fitted models, we can estimate the difference in their + expected predictive accuracy by the difference in \code{elpd_loo} or + \code{elpd_waic} (or multiplied by \eqn{-2}, if desired, to be on the + deviance scale). + + When using \code{loo_compare()}, the returned matrix will have one row per + model and several columns of estimates. The values in the \code{elpd_diff} + and \code{se_diff} columns of the returned matrix are computed by making + pairwise comparisons between each model and the model with the largest ELPD + (the model in the first row). For this reason the \code{elpd_diff} column + will always have the value \code{0} in the first row (i.e., the difference + between the preferred model and itself) and negative values in subsequent + rows for the remaining models.
+ + To compute the standard error of the difference in ELPD --- which should + not be expected to equal the difference of the standard errors --- we use a + paired estimate to take advantage of the fact that the same set of \eqn{N} + data points was used to fit both models. These calculations should be most + useful when \eqn{N} is large, because then non-normality of the + distribution is not such an issue when estimating the uncertainty in these + sums. These standard errors, for all their flaws, should give a better + sense of uncertainty than what is obtained using the current standard + approach of comparing differences of deviances to a Chi-squared + distribution, a practice derived for Gaussian linear models or + asymptotically, and which only applies to nested models in any case. +} +\examples{ +\dontrun{ +loo1 <- loo(log_lik1) +loo2 <- loo(log_lik2) +print(loo_compare(loo1, loo2), digits = 3) +print(loo_compare(x = list(loo1, loo2))) + +waic1 <- waic(log_lik1) +waic2 <- waic(log_lik2) +loo_compare(waic1, waic2) +} + +} +\references{ +Vehtari, A., Gelman, A., and Gabry, J. (2017a). Practical + Bayesian model evaluation using leave-one-out cross-validation and WAIC. + \emph{Statistics and Computing}. 27(5), 1413--1432. + doi:10.1007/s11222-016-9696-4. + (\href{http://link.springer.com/article/10.1007\%2Fs11222-016-9696-4}{published + version}, \href{http://arxiv.org/abs/1507.04544}{arXiv preprint}). + +Vehtari, A., Gelman, A., and Gabry, J. (2017b). Pareto smoothed + importance sampling. 
arXiv preprint: \url{http://arxiv.org/abs/1507.02646/} +} diff --git a/man/print_dims.Rd b/man/print_dims.Rd index 3f275498..2bab964e 100644 --- a/man/print_dims.Rd +++ b/man/print_dims.Rd @@ -5,6 +5,7 @@ \alias{print_dims.psis} \alias{print_dims.psis_loo} \alias{print_dims.waic} +\alias{print_dims.kfold} \title{Print dimensions of log-likelihood or log-weights matrix} \usage{ print_dims(x, ...) @@ -14,6 +15,8 @@ print_dims(x, ...) \method{print_dims}{psis_loo}(x, ...) \method{print_dims}{waic}(x, ...) + +\method{print_dims}{kfold}(x, ...) } \arguments{ \item{x}{The object returned by \code{\link{psis}}, \code{\link{loo}}, or diff --git a/man/psis.Rd b/man/psis.Rd index 2755a1c6..1a9d8cfc 100644 --- a/man/psis.Rd +++ b/man/psis.Rd @@ -6,6 +6,7 @@ \alias{psis.matrix} \alias{psis.default} \alias{weights.psis} +\alias{is.psis} \title{Pareto smoothed importance sampling (PSIS)} \usage{ psis(log_ratios, ...) @@ -19,6 +20,8 @@ psis(log_ratios, ...) \method{psis}{default}(log_ratios, ..., r_eff = NULL) \method{weights}{psis}(object, ..., log = TRUE, normalize = TRUE) + +is.psis(x) } \arguments{ \item{log_ratios}{An array, matrix, or vector of importance ratios on the log @@ -56,6 +59,8 @@ the log scale? Defaults to \code{TRUE}.} \item{normalize}{For the \code{weights} method, should the weights be normalized? Defaults to \code{TRUE}.} + +\item{x}{For \code{is.psis}, an object to check.} } \value{ The \code{psis} methods return an object of class \code{"psis"}, diff --git a/man/waic.Rd b/man/waic.Rd index 3cd302ce..09b1b796 100644 --- a/man/waic.Rd +++ b/man/waic.Rd @@ -5,6 +5,7 @@ \alias{waic.array} \alias{waic.matrix} \alias{waic.function} +\alias{is.waic} \title{Widely applicable information criterion (WAIC)} \usage{ waic(x, ...) @@ -14,6 +15,8 @@ waic(x, ...) \method{waic}{matrix}(x, ...) \method{waic}{function}(x, ..., data = NULL, draws = NULL) + +is.waic(x) } \arguments{ \item{x}{A log-likelihood array, matrix, or function. 
See the \strong{Methods diff --git a/tests/testthat/compare_three_models.rds b/tests/testthat/compare_three_models.rds index 76d81dd4..a7abb8a5 100644 Binary files a/tests/testthat/compare_three_models.rds and b/tests/testthat/compare_three_models.rds differ diff --git a/tests/testthat/compare_two_models.rds b/tests/testthat/compare_two_models.rds index 48ddbdbf..7261c8fb 100644 Binary files a/tests/testthat/compare_two_models.rds and b/tests/testthat/compare_two_models.rds differ diff --git a/tests/testthat/loo_compare_three_models.rds b/tests/testthat/loo_compare_three_models.rds new file mode 100644 index 00000000..2f60b185 Binary files /dev/null and b/tests/testthat/loo_compare_three_models.rds differ diff --git a/tests/testthat/loo_compare_two_models.rds b/tests/testthat/loo_compare_two_models.rds new file mode 100644 index 00000000..546024d2 Binary files /dev/null and b/tests/testthat/loo_compare_two_models.rds differ diff --git a/tests/testthat/test_compare.R b/tests/testthat/test_compare.R index 8eade396..0e6efa2d 100644 --- a/tests/testthat/test_compare.R +++ b/tests/testthat/test_compare.R @@ -2,7 +2,7 @@ library(loo) set.seed(123) SW <- suppressWarnings -context("compare") +context("compare models") LLarr <- example_loglik_array() LLarr2 <- array(rnorm(prod(dim(LLarr)), c(LLarr), 0.5), dim = dim(LLarr)) @@ -10,43 +10,96 @@ LLarr3 <- array(rnorm(prod(dim(LLarr)), c(LLarr), 1), dim = dim(LLarr)) w1 <- SW(waic(LLarr)) w2 <- SW(waic(LLarr2)) -test_that("compare throws appropriate errors", { + +test_that("loo_compare throws appropriate errors", { w3 <- SW(waic(LLarr[,, -1])) w4 <- SW(waic(LLarr[,, -(1:2)])) - expect_error(loo::compare(w1, w2, x = list(w1, w2)), - regexp = "If 'x' is specified then '...' should not be specified") - expect_error(loo::compare(w1, list(1,2,3)), + expect_error(loo_compare(w1, w2, x = list(w1, w2)), + regexp = "If 'x' is a list then '...' 
should not be specified") + expect_error(loo_compare(w1, list(1,2,3)), regexp = "class 'loo'") - expect_error(loo::compare(w1), + expect_error(loo_compare(w1), regexp = "requires at least two models") - expect_error(loo::compare(x = list(w1)), + expect_error(loo_compare(x = list(w1)), regexp = "requires at least two models") - expect_error(loo::compare(w1, w3), + expect_error(loo_compare(w1, w3), regexp = "same number of data points") - expect_error(loo::compare(w1, w2, w3), + expect_error(loo_compare(w1, w2, w3), regexp = "same number of data points") - expect_silent(loo::compare(w1, w2)) - expect_silent(loo::compare(w1, w1, w2)) }) + + +comp_colnames <- c( + "elpd_diff", "se_diff", "elpd_waic", "se_elpd_waic", + "p_waic", "se_p_waic", "waic", "se_waic" +) + +test_that("loo_compare returns expected results (2 models)", { + comp1 <- loo_compare(w1, w1) + expect_s3_class(comp1, "compare.loo") + expect_equal(colnames(comp1), comp_colnames) + expect_equal(rownames(comp1), c("model1", "model2")) + expect_output(print(comp1), "elpd_diff") + expect_equivalent(comp1[1:2,1], c(0, 0)) + expect_equivalent(comp1[1:2,2], c(0, 0)) + + comp2 <- loo_compare(w1, w2) + expect_s3_class(comp2, "compare.loo") + expect_equal_to_reference(comp2, "loo_compare_two_models.rds") + expect_equal(colnames(comp2), comp_colnames) + + # specifying objects via ... and via arg x gives equal results + expect_equal(comp2, loo_compare(x = list(w1, w2))) +}) + + +test_that("loo_compare returns expected result (3 models)", { + w3 <- SW(waic(LLarr3)) + comp1 <- loo_compare(w1, w2, w3) + + expect_equal(colnames(comp1), comp_colnames) + expect_equal(rownames(comp1), c("model1", "model2", "model3")) + expect_equal(comp1[1,1], 0) + expect_s3_class(comp1, "compare.loo") + expect_s3_class(comp1, "matrix") + expect_equal_to_reference(comp1, "loo_compare_three_models.rds") + + # specifying objects via '...' 
gives equivalent results (equal + # except rownames) to using 'x' argument + expect_equivalent(comp1, loo_compare(x = list(w1, w2, w3))) +}) + +# Tests for deprecated compare() ------------------------------------------ + +# test_that("compare throws deprecation warnings", { +# expect_warning(loo::compare(w1, w2), "Deprecated") +# expect_warning(loo::compare(w1, w1, w2), "Deprecated") +# }) + test_that("compare returns expected result (2 models)", { - comp1 <- compare(w1, w1) + comp1 <- loo::compare(w1, w1) + # comp1 <- expect_warning(loo::compare(w1, w1), "Deprecated") expect_output(print(comp1), "elpd_diff") expect_equal(comp1[1:2], c(elpd_diff = 0, se = 0)) - comp2 <- compare(w1, w2) + comp2 <- loo::compare(w1, w2) + # comp2 <- expect_warning(loo::compare(w1, w2), "Deprecated") expect_equal_to_reference(comp2, "compare_two_models.rds") expect_named(comp2, c("elpd_diff", "se")) expect_s3_class(comp2, "compare.loo") # specifying objects via ... and via arg x gives equal results - expect_equal(comp2, compare(x = list(w1, w2))) + comp_via_list <- loo::compare(x = list(w1, w2)) + # comp_via_list <- expect_warning(loo::compare(x = list(w1, w2)), "Deprecated") + expect_equal(comp2, comp_via_list) }) test_that("compare returns expected result (3 models)", { w3 <- SW(waic(LLarr3)) - comp1 <- compare(w1, w2, w3) + comp1 <- loo::compare(w1, w2, w3) + # comp1 <- expect_warning(loo::compare(w1, w2, w3), "Deprecated") expect_equal( colnames(comp1), @@ -62,5 +115,7 @@ test_that("compare returns expected result (3 models)", { # specifying objects via '...' gives equivalent results (equal # except rownames) to using 'x' argument - expect_equivalent(comp1, compare(x = list(w1, w2, w3))) + comp_via_list <- loo::compare(x = list(w1, w2, w3)) + # comp_via_list <- expect_warning(loo::compare(x = list(w1, w2, w3)), "Deprecated") + expect_equivalent(comp1, comp_via_list) })