Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ S3method(dim,kfold)
S3method(dim,loo)
S3method(dim,psis_loo)
S3method(dim,waic)
S3method(elpd,array)
S3method(elpd,matrix)
S3method(importance_sampling,array)
S3method(importance_sampling,default)
S3method(importance_sampling,matrix)
Expand Down Expand Up @@ -73,6 +75,7 @@ export(.ndraws)
export(.thin_draws)
export(E_loo)
export(compare)
export(elpd)
export(example_loglik_array)
export(example_loglik_matrix)
export(extract_log_lik)
Expand Down
70 changes: 70 additions & 0 deletions R/elpd.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
#' Generic (expected) log-predictive density
#'
#' The `elpd()` methods for arrays, and matrices can compute the expected log pointwise predictive density for a new dataset or the log pointwise predictive density of the observed data (an overestimate of the elpd).
#'
#' @export
#' @inheritParams loo
#'
#' @details The `elpd()` function is an S3 generic and methods are provided for
#' 3-D pointwise log-likelihood arrays, pointwise log-likelihood matrices.
#'
#'
#' @examples
#' ### Array methods for calculating the lpd of the observed data (using example objects included with loo package)
#' LLarr <- example_loglik_array()
#' elpd(LLarr)
#'
#'
#'
elpd <- function(x, ...) {
UseMethod("elpd")
}

#' @export
#' @templateVar fn elpd
#' @template array
#'
elpd.array <- function(x, ...) {

ll <- llarray_to_matrix(x)
elpd.matrix(ll)
}

#' @export
#' @templateVar fn elpd
#' @template matrix
#'
elpd.matrix <-
function(x, ...) {
pointwise <- pointwise_elpd_calcs(x)
elpd_object(pointwise, dim(x))
}


pointwise_elpd_calcs <- function(ll){
elpd <- colLogSumExps(ll) - log(nrow(ll))
ic <- -2 * elpd
cbind(elpd, ic)
}

elpd_object <- function(pointwise, dims) {
if (!is.matrix(pointwise)) stop("Internal error ('pointwise' must be a matrix)")

cols_to_summarize <- colnames(pointwise)
estimates <- table_of_estimates(pointwise[, cols_to_summarize, drop=FALSE])

out <- nlist(estimates, pointwise)

structure(
out,
dims = dims,
class = c("elpd_generic", "loo")
)
}
print_dims.elpd_generic <- function(x, ...) {
cat(
"Computed from",
paste(dim(x), collapse = " by "),
"log-likelihood matrix using the generic elpd function\n"
)
}
50 changes: 50 additions & 0 deletions man/elpd.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Binary file added tests/testthat/reference-results/elpd.rds
Binary file not shown.
32 changes: 29 additions & 3 deletions tests/testthat/test_loo_and_waic.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ library(loo)
options(mc.cores = 1)
set.seed(123)

context("loo and waic")
context("loo, waic and elpd")

LLarr <- example_loglik_array()
LLmat <- example_loglik_matrix()
Expand All @@ -13,6 +13,7 @@ r_eff_mat <- relative_eff(exp(LLmat), chain_id = chain_id)

loo1 <- suppressWarnings(loo(LLarr, r_eff = r_eff_arr))
waic1 <- suppressWarnings(waic(LLarr))
elpd1 <- suppressWarnings(elpd(LLarr))

test_that("using loo.cores is deprecated", {
options(mc.cores = NULL)
Expand All @@ -22,9 +23,10 @@ test_that("using loo.cores is deprecated", {
options(mc.cores = 1)
})

test_that("loo and waic results haven't changed", {
test_that("loo, waic and elpd results haven't changed", {
expect_equal_to_reference(loo1, "reference-results/loo.rds")
expect_equal_to_reference(waic1, "reference-results/waic.rds")
expect_equal_to_reference(elpd1, "reference-results/elpd.rds")
})

test_that("loo with cores=1 and cores=2 gives same results", {
Expand Down Expand Up @@ -86,6 +88,24 @@ test_that("loo returns object with correct structure", {
expect_equal(dim(loo1), dim(LLmat))
})


test_that("elpd returns object with correct structure", {
expect_true(is.loo(elpd1))
expect_named(
elpd1,
c(
"estimates",
"pointwise"
)
)
est_names <- dimnames(elpd1$estimates)
expect_equal(est_names[[1]], c("elpd", "ic"))
expect_equal(est_names[[2]], c("Estimate", "SE"))
expect_equal(colnames(elpd1$pointwise), est_names[[1]])
expect_equal(dim(elpd1), dim(LLmat))
})


test_that("two pareto k values are equal", {
expect_identical(loo1$pointwise[,"influence_pareto_k"], loo1$diagnostics$pareto_k)
})
Expand All @@ -111,9 +131,15 @@ test_that("waic.array and waic.matrix give same result", {
expect_identical(waic1, waic2)
})

test_that("loo and waic error with vector input", {
test_that("elpd.array and elpd.matrix give same result", {
elpd2 <- suppressWarnings(elpd(LLmat))
expect_identical(elpd1, elpd2)
})

test_that("loo, waic, and elpd error with vector input", {
expect_error(loo(LLvec), regexp = "no applicable method")
expect_error(waic(LLvec), regexp = "no applicable method")
expect_error(elpd(LLvec), regexp = "no applicable method")
})


Expand Down
Loading