diff --git a/R/loo_model_weights.R b/R/loo_model_weights.R index 946dc7c3..423d41a9 100644 --- a/R/loo_model_weights.R +++ b/R/loo_model_weights.R @@ -329,22 +329,16 @@ pseudobma_weights <- return(wts) } - temp <- matrix(NA, BB_n, K) BB_weighting <- dirichlet_rng(BB_n, rep(alpha, N)) - for (bb in 1:BB_n) { - z_bb <- BB_weighting[bb, ] %*% lpd_point * N - uwts <- exp(z_bb - max(z_bb)) - temp[bb, ] <- uwts / sum(uwts) - } - wts <- structure( - colMeans(temp), + z <- BB_weighting %*% lpd_point * N + uwts <- exp(z - matrixStats::rowMaxs(z)) + structure( + colMeans(uwts / rowSums(uwts)), names = paste0("model", 1:K), class = "pseudobma_bb_weights" ) - return(wts) } - #' Generate dirichlet simulations, rewritten version #' @importFrom stats rgamma #' @noRd diff --git a/tests/testthat/_snaps/model_weighting.md b/tests/testthat/_snaps/model_weighting.md index 35a8647a..ba0fad93 100644 --- a/tests/testthat/_snaps/model_weighting.md +++ b/tests/testthat/_snaps/model_weighting.md @@ -42,3 +42,7 @@ model2 1.000 model3 0.000 +# Bayesian bootstrap gives expected result + + c(0.188359351057998, 0.309962881543203, 0.501677767398798) + diff --git a/tests/testthat/test_model_weighting.R b/tests/testthat/test_model_weighting.R index 96c075b8..4e4f94e2 100644 --- a/tests/testthat/test_model_weighting.R +++ b/tests/testthat/test_model_weighting.R @@ -111,6 +111,43 @@ test_that("loo_model_weights (stacking and pseudo-BMA) gives expected result", { expect_identical(w3, w3_b) }) +test_that("Bayesian bootstrap gives expected result", { + lpd_point <- matrix( + c( + -0.2, + -0.8, + -1.1, + -1.4, + -0.3, + -0.5, + -0.6, + -0.7, + -0.1, + -1.0, + -1.2, + -0.4, + -0.9, + -0.5, + -0.8 + ), + ncol = 3, + byrow = TRUE + ) + BB_n <- 25 + alpha <- 0.7 + + set.seed(0) + weights <- pseudobma_weights(lpd_point, BB = TRUE, BB_n = BB_n, alpha = alpha) + + expect_s3_class(weights, "pseudobma_bb_weights") + expect_named(weights, paste0("model", 1:3)) + expect_equal(sum(weights), 1) + expect_snapshot_value( + unname(as.numeric(weights)), + style = "deparse" + ) +}) + test_that("stacking_weights and pseudobma_weights throw correct errors", { xx <- cbind(rnorm(10)) expect_error(stacking_weights(xx), "two models are required")