diff --git a/src/ldlt.rs b/src/ldlt.rs index 5f6339c..55458c9 100644 --- a/src/ldlt.rs +++ b/src/ldlt.rs @@ -104,32 +104,53 @@ impl Ldlt { return Err(LaError::Singular { pivot_col: j }); } - // Compute L multipliers below the diagonal in column j. - for i in (j + 1)..D { - let l = rows[i][j] / d; - if !l.is_finite() { - cold_path(); - return Err(LaError::non_finite_cell(i, j)); + if D <= 5 { + // Tiny matrices benchmark better when column normalization stays + // separate from the trailing update. + for i in (j + 1)..D { + let l = rows[i][j] / d; + if !l.is_finite() { + cold_path(); + return Err(LaError::non_finite_cell(i, j)); + } + rows[i][j] = l; } - rows[i][j] = l; - } - // Update the trailing submatrix (lower triangle): A := A - (L_col * d) * L_col^T. - for i in (j + 1)..D { - let l_i = rows[i][j]; - let l_i_d = l_i * d; + for i in (j + 1)..D { + let l_i = rows[i][j]; + let l_i_d = l_i * d; + + for k in (j + 1)..=i { + let l_k = rows[k][j]; + let new_val = (-l_i_d).mul_add(l_k, rows[i][k]); + rows[i][k] = new_val; + } + } + } else { + // Larger fixed dimensions avoid an extra column walk by updating + // each lower-triangular row prefix as soon as its multiplier is finite. + for i in (j + 1)..D { + let l_i = rows[i][j] / d; + if !l_i.is_finite() { + cold_path(); + return Err(LaError::non_finite_cell(i, j)); + } + rows[i][j] = l_i; + + let l_i_d = l_i * d; - for k in (j + 1)..=i { - let l_k = rows[k][j]; - let new_val = (-l_i_d).mul_add(l_k, rows[i][k]); - rows[i][k] = new_val; + for k in (j + 1)..=i { + let l_k = rows[k][j]; + let new_val = (-l_i_d).mul_add(l_k, rows[i][k]); + rows[i][k] = new_val; + } } } } } - let f = f.validate_finite()?; - + // Every computed lower-triangular entry is checked when it becomes a + // pivot or multiplier; the untouched upper triangle remains finite input. Ok(Self { factors: LdltFactors::new_unchecked(f), }) @@ -473,6 +494,29 @@ mod tests { ); } + #[test] + fn nonfinite_l_multiplier_overflow_fused_branch_6d() { + // D > 5 uses the fused LDLT update path. Keep the same overflow shape + // as the 2D test while forcing that branch. + let mut rows = [[0.0; 6]; 6]; + for (i, row) in rows.iter_mut().enumerate() { + row[i] = 1.0; + } + rows[0][0] = 1e-11; + rows[0][5] = 1e300; + rows[5][0] = 1e300; + + let a = Matrix::<6>::try_from_rows(rows).unwrap(); + let err = a.ldlt(DEFAULT_SINGULAR_TOL).unwrap_err(); + assert_eq!( + err, + LaError::NonFinite { + row: Some(5), + col: 0 + } + ); + } + #[test] fn nonfinite_trailing_submatrix_overflow() { // L multiplier is finite (1e200), but the rank-1 update @@ -488,6 +532,28 @@ mod tests { ); } + #[test] + fn nonfinite_trailing_submatrix_overflow_fused_branch_6d() { + // D > 5 uses the fused LDLT update path. The overflowing trailing + // diagonal is detected when it later becomes a pivot. + let mut rows = [[0.0; 6]; 6]; + for (i, row) in rows.iter_mut().enumerate() { + row[i] = 1.0; + } + rows[0][5] = 1e200; + rows[5][0] = 1e200; + + let a = Matrix::<6>::try_from_rows(rows).unwrap(); + let err = a.ldlt(DEFAULT_SINGULAR_TOL).unwrap_err(); + assert_eq!( + err, + LaError::NonFinite { + row: Some(5), + col: 5 + } + ); + } + #[test] fn nonfinite_solve_forward_substitution_overflow() { // SPD matrix with large L multiplier: L[1,0] = 1e153. diff --git a/src/matrix.rs b/src/matrix.rs index 10a5c9b..e7c27e8 100644 --- a/src/matrix.rs +++ b/src/matrix.rs @@ -138,7 +138,7 @@ impl Matrix { /// Mutably borrow raw row-major storage without preserving the finite invariant. /// /// This is reserved for internal factorization temporaries whose results are - /// validated before becoming observable API values. + /// validated or otherwise proven finite before becoming observable API values. #[inline] pub(crate) const fn rows_mut_unchecked(&mut self) -> &mut [[f64; D]; D] { &mut self.rows