Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 84 additions & 18 deletions src/ldlt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,32 +104,53 @@ impl<const D: usize> Ldlt<D> {
return Err(LaError::Singular { pivot_col: j });
}

// Compute L multipliers below the diagonal in column j.
for i in (j + 1)..D {
let l = rows[i][j] / d;
if !l.is_finite() {
cold_path();
return Err(LaError::non_finite_cell(i, j));
if D <= 5 {
// Tiny matrices benchmark better when column normalization stays
// separate from the trailing update.
for i in (j + 1)..D {
let l = rows[i][j] / d;
if !l.is_finite() {
cold_path();
return Err(LaError::non_finite_cell(i, j));
}
rows[i][j] = l;
}
rows[i][j] = l;
}

// Update the trailing submatrix (lower triangle): A := A - (L_col * d) * L_col^T.
for i in (j + 1)..D {
let l_i = rows[i][j];
let l_i_d = l_i * d;
for i in (j + 1)..D {
let l_i = rows[i][j];
let l_i_d = l_i * d;

for k in (j + 1)..=i {
let l_k = rows[k][j];
let new_val = (-l_i_d).mul_add(l_k, rows[i][k]);
rows[i][k] = new_val;
}
}
} else {
// Larger fixed dimensions avoid an extra column walk by updating
// each lower-triangular row prefix as soon as its multiplier is finite.
for i in (j + 1)..D {
let l_i = rows[i][j] / d;
if !l_i.is_finite() {
cold_path();
return Err(LaError::non_finite_cell(i, j));
}
rows[i][j] = l_i;

let l_i_d = l_i * d;

for k in (j + 1)..=i {
let l_k = rows[k][j];
let new_val = (-l_i_d).mul_add(l_k, rows[i][k]);
rows[i][k] = new_val;
for k in (j + 1)..=i {
let l_k = rows[k][j];
let new_val = (-l_i_d).mul_add(l_k, rows[i][k]);
rows[i][k] = new_val;
}
}
}
}
}

let f = f.validate_finite()?;

// Every computed lower-triangular entry is checked when it becomes a
// pivot or multiplier; the untouched upper triangle remains finite input.
Ok(Self {
factors: LdltFactors::new_unchecked(f),
})
Expand Down Expand Up @@ -473,6 +494,29 @@ mod tests {
);
}

#[test]
fn nonfinite_l_multiplier_overflow_fused_branch_6d() {
// D > 5 uses the fused LDLT update path. Keep the same overflow shape
// as the 2D test while forcing that branch.
let mut rows = [[0.0; 6]; 6];
for (i, row) in rows.iter_mut().enumerate() {
row[i] = 1.0;
}
rows[0][0] = 1e-11;
rows[0][5] = 1e300;
rows[5][0] = 1e300;

let a = Matrix::<6>::try_from_rows(rows).unwrap();
let err = a.ldlt(DEFAULT_SINGULAR_TOL).unwrap_err();
assert_eq!(
err,
LaError::NonFinite {
row: Some(5),
col: 0
}
);
}

#[test]
fn nonfinite_trailing_submatrix_overflow() {
// L multiplier is finite (1e200), but the rank-1 update
Expand All @@ -488,6 +532,28 @@ mod tests {
);
}

#[test]
fn nonfinite_trailing_submatrix_overflow_fused_branch_6d() {
// D > 5 uses the fused LDLT update path. The overflowing trailing
// diagonal is detected when it later becomes a pivot.
let mut rows = [[0.0; 6]; 6];
for (i, row) in rows.iter_mut().enumerate() {
row[i] = 1.0;
}
rows[0][5] = 1e200;
rows[5][0] = 1e200;

let a = Matrix::<6>::try_from_rows(rows).unwrap();
let err = a.ldlt(DEFAULT_SINGULAR_TOL).unwrap_err();
assert_eq!(
err,
LaError::NonFinite {
row: Some(5),
col: 5
}
);
}

#[test]
fn nonfinite_solve_forward_substitution_overflow() {
// SPD matrix with large L multiplier: L[1,0] = 1e153.
Expand Down
2 changes: 1 addition & 1 deletion src/matrix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ impl<const D: usize> Matrix<D> {
/// Mutably borrow raw row-major storage without preserving the finite invariant.
///
/// This is reserved for internal factorization temporaries whose results are
/// validated before becoming observable API values.
/// validated or otherwise proven finite before becoming observable API values.
#[inline]
pub(crate) const fn rows_mut_unchecked(&mut self) -> &mut [[f64; D]; D] {
&mut self.rows
Expand Down
Loading