Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .Rbuildignore
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,4 @@ vignettes/loo2-non-factorizable_cache/*
^vignettes/online-only$

^CRAN-SUBMISSION$
^release-prep\.R$
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,4 @@ tests/testthat/Rplots.pdf

cran-comments.md
CRAN-RELEASE
release-prep.R
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Package: loo
Type: Package
Title: Efficient Leave-One-Out Cross-Validation and WAIC for Bayesian Models
Version: 2.6.0
Version: 2.6.0.9000
Date: 2023-03-30
Authors@R: c(person("Aki", "Vehtari", email = "Aki.Vehtari@aalto.fi", role = c("aut")),
person("Jonah", "Gabry", email = "jsg2201@columbia.edu", role = c("cre", "aut")),
Expand Down
11 changes: 11 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,14 @@
# loo 2.6.0.9000

### New features

* `E_loo` now allows `type="sd"`.


### Bug fixes

* Fix bug in `E_loo` when `type=variance`.

# loo 2.6.0

### New features
Expand Down
39 changes: 18 additions & 21 deletions R/E_loo.R
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
#' specified we are able to compute [Pareto k][pareto-k-diagnostic]
#' diagnostics specific to `E_loo()`.
#' @param type The type of expectation to compute. The options are
#' `"mean"`, `"variance"`, and `"quantile"`.
#' `"mean"`, `"variance"`, `"sd"`, and `"quantile"`.
#' @param probs For computing quantiles, a vector of probabilities.
#' @param ... Arguments passed to individual methods.
#'
Expand Down Expand Up @@ -75,6 +75,7 @@
#'
#' E_loo(yrep, psis_object, type = "mean")
#' E_loo(yrep, psis_object, type = "var")
#' E_loo(yrep, psis_object, type = "sd")
#' E_loo(yrep, psis_object, type = "quantile", probs = 0.5) # median
#' E_loo(yrep, psis_object, type = "quantile", probs = c(0.1, 0.9))
#'
Expand All @@ -94,7 +95,7 @@ E_loo.default <-
function(x,
psis_object,
...,
type = c("mean", "variance", "quantile"),
type = c("mean", "variance", "sd", "quantile"),
probs = NULL,
log_ratios = NULL) {
stopifnot(
Expand All @@ -105,14 +106,9 @@ E_loo.default <-
)
type <- match.arg(type)
E_fun <- .E_fun(type)
r_eff <- NULL
if (type == "variance") {
r_eff <- relative_eff(psis_object)
}

w <- as.vector(weights(psis_object, log = FALSE))
x <- as.vector(x)
out <- E_fun(x, w, probs, r_eff)
out <- E_fun(x, w, probs)

if (is.null(log_ratios)) {
warning("'log_ratios' not specified. Can't compute k-hat diagnostic.",
Expand All @@ -130,7 +126,7 @@ E_loo.matrix <-
function(x,
psis_object,
...,
type = c("mean", "variance", "quantile"),
type = c("mean", "variance", "sd", "quantile"),
probs = NULL,
log_ratios = NULL) {
stopifnot(
Expand All @@ -142,10 +138,7 @@ E_loo.matrix <-
type <- match.arg(type)
E_fun <- .E_fun(type)
fun_val <- numeric(1)
r_eff <- NULL
if (type == "variance") {
r_eff <- relative_eff(psis_object)
} else if (type == "quantile") {
if (type == "quantile") {
stopifnot(
is.numeric(probs),
length(probs) >= 1,
Expand All @@ -156,7 +149,7 @@ E_loo.matrix <-
w <- weights(psis_object, log = FALSE)

out <- vapply(seq_len(ncol(x)), function(i) {
E_fun(x[, i], w[, i], probs = probs, r_eff = r_eff[i])
E_fun(x[, i], w[, i], probs = probs)
}, FUN.VALUE = fun_val)

if (is.null(log_ratios)) {
Expand All @@ -178,11 +171,12 @@ E_loo.matrix <-
#' @return The function for computing the weighted expectation specified by
#' `type`.
#'
.E_fun <- function(type = c("mean", "variance", "quantile")) {
.E_fun <- function(type = c("mean", "variance", "sd", "quantile")) {
switch(
type,
"mean" = .wmean,
"variance" = .wvar,
"sd" = .wsd,
"quantile" = .wquant
)
}
Expand All @@ -199,12 +193,15 @@ E_loo.matrix <-
.wmean <- function(x, w, ...) {
sum(w * x)
}
.wvar <- function(x, w, r_eff = NULL, ...) {
if (is.null(r_eff)) {
r_eff <- 1
}
r <- (x - .wmean(x, w))^2
sum(w^2 * r) / r_eff
.wvar <- function(x, w, ...) {
# The denominator (1- sum(w^2)) is equal to (ESS-1)/ESS, where effective
# sample size ESS is estimated with the generic target quantity invariant
# estimate 1/sum(w^2), see e.g. "Monte Carlo theory, methods and examples"
# by Owen (2013).
(sum(.wmean(x^2, w)) - sum(.wmean(x, w)^2)) / (1 - sum(w^2))
}
.wsd <- function(x, w, ...) {
sqrt(.wvar(x, w))
}
.wquant <- function(x, w, probs, ...) {
if (all(w == w[1])) {
Expand Down
7 changes: 4 additions & 3 deletions man/E_loo.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Binary file not shown.
Binary file modified tests/testthat/reference-results/E_loo_default_var.rds
Binary file not shown.
Binary file added tests/testthat/reference-results/E_loo_matrix_sd.rds
Binary file not shown.
Binary file modified tests/testthat/reference-results/E_loo_matrix_var.rds
Binary file not shown.
44 changes: 35 additions & 9 deletions tests/testthat/test_E_loo.R
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,19 @@ log_rats <- -LLmat
# matrix method
E_test_mean <- E_loo(x, psis_mat, type = "mean", log_ratios = log_rats)
E_test_var <- E_loo(x, psis_mat, type = "var", log_ratios = log_rats)
E_test_sd <- E_loo(x, psis_mat, type = "sd", log_ratios = log_rats)
E_test_quant <- E_loo(x, psis_mat, type = "quantile", probs = 0.5, log_ratios = log_rats)
E_test_quant2 <- E_loo(x, psis_mat, type = "quantile", probs = c(0.1, 0.9), log_ratios = log_rats)

# vector method
E_test_mean_vec <- E_loo(x[, 1], psis_vec, type = "mean", log_ratios = log_rats[,1])
E_test_var_vec <- E_loo(x[, 1], psis_vec, type = "var", log_ratios = log_rats[,1])
E_test_sd_vec <- E_loo(x[, 1], psis_vec, type = "sd", log_ratios = log_rats[,1])
E_test_quant_vec <- E_loo(x[, 1], psis_vec, type = "quant", probs = 0.5, log_ratios = log_rats[,1])
E_test_quant_vec2 <- E_loo(x[, 1], psis_vec, type = "quant", probs = c(0.1, 0.5, 0.9), log_ratios = log_rats[,1])

# E_loo_khat
khat <- E_loo_khat(x, psis_mat, log_rats)
khat <- loo:::E_loo_khat.matrix(x, psis_mat, log_rats)

test_that("E_loo return types correct for matrix method", {
expect_type(E_test_mean, "list")
Expand All @@ -43,6 +45,12 @@ test_that("E_loo return types correct for matrix method", {
expect_length(E_test_var$value, ncol(x))
expect_length(E_test_var$pareto_k, ncol(x))

expect_type(E_test_sd, "list")
expect_named(E_test_sd, c("value", "pareto_k"))
expect_length(E_test_sd, 2)
expect_length(E_test_sd$value, ncol(x))
expect_length(E_test_sd$pareto_k, ncol(x))

expect_type(E_test_quant, "list")
expect_named(E_test_quant, c("value", "pareto_k"))
expect_length(E_test_quant, 2)
Expand All @@ -69,6 +77,12 @@ test_that("E_loo return types correct for default/vector method", {
expect_length(E_test_var_vec$value, 1)
expect_length(E_test_var_vec$pareto_k, 1)

expect_type(E_test_sd_vec, "list")
expect_named(E_test_sd_vec, c("value", "pareto_k"))
expect_length(E_test_sd_vec, 2)
expect_length(E_test_sd_vec$value, 1)
expect_length(E_test_sd_vec$pareto_k, 1)

expect_type(E_test_quant_vec, "list")
expect_named(E_test_quant_vec, c("value", "pareto_k"))
expect_length(E_test_quant_vec, 2)
Expand All @@ -83,17 +97,19 @@ test_that("E_loo return types correct for default/vector method", {
})

test_that("E_loo.default equal to reference", {
expect_equal_to_reference(E_test_mean_vec, "reference-results/E_loo_default_mean.rds")
expect_equal_to_reference(E_test_var_vec, "reference-results/E_loo_default_var.rds")
expect_equal_to_reference(E_test_quant_vec, "reference-results/E_loo_default_quantile_50.rds")
expect_equal_to_reference(E_test_quant_vec2, "reference-results/E_loo_default_quantile_10_50_90.rds")
expect_equal_to_reference(E_test_mean_vec, test_path("reference-results/E_loo_default_mean.rds"))
expect_equal_to_reference(E_test_var_vec, test_path("reference-results/E_loo_default_var.rds"))
expect_equal_to_reference(E_test_sd_vec, test_path("reference-results/E_loo_default_sd.rds"))
expect_equal_to_reference(E_test_quant_vec, test_path("reference-results/E_loo_default_quantile_50.rds"))
expect_equal_to_reference(E_test_quant_vec2, test_path("reference-results/E_loo_default_quantile_10_50_90.rds"))
})

test_that("E_loo.matrix equal to reference", {
expect_equal_to_reference(E_test_mean, "reference-results/E_loo_matrix_mean.rds")
expect_equal_to_reference(E_test_var, "reference-results/E_loo_matrix_var.rds")
expect_equal_to_reference(E_test_quant, "reference-results/E_loo_matrix_quantile_50.rds")
expect_equal_to_reference(E_test_quant2, "reference-results/E_loo_matrix_quantile_10_90.rds")
expect_equal_to_reference(E_test_mean, test_path("reference-results/E_loo_matrix_mean.rds"))
expect_equal_to_reference(E_test_var, test_path("reference-results/E_loo_matrix_var.rds"))
expect_equal_to_reference(E_test_sd, test_path("reference-results/E_loo_matrix_sd.rds"))
expect_equal_to_reference(E_test_quant, test_path("reference-results/E_loo_matrix_quantile_50.rds"))
expect_equal_to_reference(E_test_quant2, test_path("reference-results/E_loo_matrix_quantile_10_90.rds"))
})

test_that("E_loo throws correct errors and warnings", {
Expand Down Expand Up @@ -166,3 +182,13 @@ test_that("weighted quantiles work", {
)
})

test_that("weighted variance works", {
x <- rnorm(100)
w <- rep(0.01, 100)
expect_equal(.wvar(x, w), var(x))
expect_equal(.wsd(x, w), sqrt(.wvar(x, w)))

w <- c(rep(0.1, 10), rep(0, 90))
expect_equal(.wvar(x, w), var(x[w > 0]))
})