From da2647d7211aee74c098999cc05439bcf21a32a9 Mon Sep 17 00:00:00 2001 From: VisruthSK Date: Sat, 6 Jun 2026 16:41:39 -0700 Subject: [PATCH 1/2] Stacking weights test results --- tests/testthat/_snaps/model_weighting.md | 5 +++++ tests/testthat/test_model_weighting.R | 22 ++++++++++++++++++++++ 2 files changed, 27 insertions(+) diff --git a/tests/testthat/_snaps/model_weighting.md b/tests/testthat/_snaps/model_weighting.md index 35a8647a..34da7685 100644 --- a/tests/testthat/_snaps/model_weighting.md +++ b/tests/testthat/_snaps/model_weighting.md @@ -42,3 +42,8 @@ model2 1.000 model3 0.000 +# stacking_weights gives expected result + + c(3.60852095316096e-07, 3.2286826572722e-07, 0.999999316279639 + ) + diff --git a/tests/testthat/test_model_weighting.R b/tests/testthat/test_model_weighting.R index 96c075b8..8089d315 100644 --- a/tests/testthat/test_model_weighting.R +++ b/tests/testthat/test_model_weighting.R @@ -111,6 +111,28 @@ test_that("loo_model_weights (stacking and pseudo-BMA) gives expected result", { expect_identical(w3, w3_b) }) +test_that("stacking_weights gives expected result", { + lpd_point <- matrix( + c( + -0.2, -0.8, -1.1, + -1.4, -0.3, -0.5, + -0.6, -0.7, -0.1, + -1.0, -1.2, -0.4, + -0.9, -0.5, -0.8 + ), + ncol = 3, + byrow = TRUE + ) + + set.seed(0) + actual <- stacking_weights(lpd_point) + + expect_s3_class(actual, "stacking_weights") + expect_named(actual, paste0("model", 1:3)) + expect_equal(sum(actual), 1) + expect_snapshot_value(as.numeric(actual), style = "deparse") +}) + test_that("stacking_weights and pseudobma_weights throw correct errors", { xx <- cbind(rnorm(10)) expect_error(stacking_weights(xx), "two models are required") From 2a09e6db8640fc14af9c64de616d5de0993b7066 Mon Sep 17 00:00:00 2001 From: VisruthSK Date: Sat, 6 Jun 2026 16:47:33 -0700 Subject: [PATCH 2/2] Vectorized log score gradient calc --- R/loo_model_weights.R | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/R/loo_model_weights.R b/R/loo_model_weights.R index 946dc7c3..e426ba3f 100644 --- a/R/loo_model_weights.R +++ b/R/loo_model_weights.R @@ -271,14 +271,16 @@ stacking_weights <- # gradient of the objective function stopifnot(length(w) == K - 1) w_full <- c(w, 1 - sum(w)) - grad <- rep(0, K - 1) # avoid over- and underflows using log weights, rowLogSumExps, # and by subtracting the row maximum of lpd_point mlpd <- matrixStats::rowMaxs(lpd_point) - for (k in 1:(K - 1)) { - grad[k] <- sum((exp(lpd_point[, k] - mlpd) - exp(lpd_point[, K] - mlpd)) / exp(matrixStats::rowLogSumExps(sweep(lpd_point, 2, log(w_full), '+')) - mlpd)) - } - return(-grad) + denom <- exp(matrixStats::rowLogSumExps(sweep(lpd_point, 2, log(w_full), '+')) - mlpd) + -colSums( + ( + exp(lpd_point[, 1:(K - 1), drop = FALSE] - mlpd) - + exp(lpd_point[, K] - mlpd) + ) / denom + ) } ui <- rbind(rep(-1, K - 1), diag(K - 1)) # K-1 simplex constraint matrix