Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions probeflow/gui/dialogs/image_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,11 @@ def _processing_pixel_sizes_m(self) -> tuple[float | None, float | None]:
return w_m / Nx, h_m / Ny

def _refresh_display_array(self, reset_zoom_if_shape_changed: bool = False):
# Keep the Gaussian-blur σ readout calibrated to the loaded scan.
panel = getattr(self, "_processing_panel", None)
if panel is not None and hasattr(panel, "set_pixel_size_nm"):
psx, _psy = self._processing_pixel_sizes_m()
panel.set_pixel_size_nm(psx * 1e9 if psx else None)
old_shape = self._display_arr.shape if self._display_arr is not None else None
# display array: raw with processing applied (no grain overlay — that's visual only)
if self._raw_arr is not None and self._processing:
Expand Down
87 changes: 78 additions & 9 deletions probeflow/gui/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,38 @@
QSizePolicy, QSlider, QVBoxLayout, QWidget,
)

# The smooth-σ slider is an integer QSlider scaled by this factor so it can
# express sub-pixel σ (e.g. a slider value of 5 → σ = 0.5 px).
_SMOOTH_SIGMA_SCALE = 10
# FWHM of a Gaussian = 2·sqrt(2·ln 2)·σ.
_GAUSSIAN_FWHM_FACTOR = 2.35482004503


def format_gaussian_readout(sigma_px: float, px_nm: float | None) -> str:
"""One-line description of the Gaussian-blur kernel's physical extent.

``sigma_px`` is the standard deviation in pixels. The kernel is truncated at
±4σ to match scipy's ``gaussian_filter(truncate=4.0)`` default, so the
half-width in pixels is ``int(4σ + 0.5)``. When ``px_nm`` (pixel size in nm)
is known the σ, FWHM and kernel extent are also reported in nanometres.
"""
sigma_px = float(sigma_px)
fwhm_px = _GAUSSIAN_FWHM_FACTOR * sigma_px
r_px = int(4.0 * sigma_px + 0.5)
if px_nm and px_nm > 0:
sigma_nm = sigma_px * px_nm
fwhm_nm = fwhm_px * px_nm
r_nm = r_px * px_nm
return (
f"σ {sigma_px:.1f} px · {sigma_nm:.3g} nm "
f"FWHM {fwhm_nm:.3g} nm "
f"kernel ±{r_px} px (±{r_nm:.3g} nm)"
)
return (
f"σ {sigma_px:.1f} px FWHM {fwhm_px:.2f} px kernel ±{r_px} px"
)


class ProcessingControlPanel(QWidget):
"""Internal processing controls shared by Browse and Viewer."""

Expand All @@ -26,6 +58,7 @@ def __init__(self, mode: str, parent=None):
if mode not in ("browse_quick", "viewer_full"):
raise ValueError(f"Unknown processing panel mode: {mode}")
self._mode = mode
self._smooth_px_nm: float | None = None
self._build()

def _build(self):
Expand Down Expand Up @@ -57,7 +90,10 @@ def _combo_row(label: str, items: list[str],
return cb

def _sub_slider(label: str, mn: int, mx: int, init: int,
fmt="{v}") -> tuple[QWidget, QSlider, QLabel]:
fmt="{v}", scale: int = 1) -> tuple[QWidget, QSlider, QLabel]:
# ``scale`` > 1 makes the integer slider represent a fractional value:
# the displayed/logical value is ``slider_value / scale`` (e.g. scale=10
# gives 0.1 steps). scale=1 is the original integer behaviour.
w = QWidget()
rl = QHBoxLayout(w)
rl.setContentsMargins(0, 0, 0, 0)
Expand All @@ -67,11 +103,15 @@ def _sub_slider(label: str, mn: int, mx: int, init: int,
sl = QSlider(Qt.Horizontal)
sl.setRange(mn, mx)
sl.setValue(init)
val_lbl = QLabel(fmt.format(v=init))

def _disp(v: int) -> str:
return fmt.format(v=(v / scale if scale != 1 else v))

val_lbl = QLabel(_disp(init))
val_lbl.setFont(ui_font(8))
val_lbl.setFixedWidth(28)
val_lbl.setFixedWidth(28 if scale == 1 else 44)
sl.valueChanged.connect(
lambda v, vl=val_lbl, f=fmt: vl.setText(f.format(v=v)))
lambda v, vl=val_lbl: vl.setText(_disp(v)))
rl.addWidget(lbl)
rl.addWidget(sl, 1)
rl.addWidget(val_lbl)
Expand Down Expand Up @@ -239,14 +279,27 @@ def _col_lbl(text: str, target):
self._smooth_combo = _combo_row("Smooth:", ["None", "Gaussian"], L, 54)
self._smooth_combo.setToolTip(
"Gaussian blur to suppress noise. Larger sigma (px) smooths more, but "
"also blurs genuine fine features."
"also blurs genuine fine features. The kernel spans ±4σ; σ may be "
"sub-pixel (down to 0.2 px) for gentle denoising."
)
self._smooth_sigma_w, self._smooth_sigma_sl, _ = _sub_slider(
"sigma:", 1, 20, 1, "{v}px")
"sigma:", 2, 200, _SMOOTH_SIGMA_SCALE, "{v:.1f}px",
scale=_SMOOTH_SIGMA_SCALE)
L.addWidget(self._smooth_sigma_w)
self._smooth_sigma_w.setVisible(False)
# Physical readout: σ/FWHM/kernel extent in nm (when calibrated) or px.
self._smooth_readout_lbl = QLabel()
self._smooth_readout_lbl.setFont(ui_font(7))
self._smooth_readout_lbl.setWordWrap(True)
self._smooth_readout_lbl.setVisible(False)
L.addWidget(self._smooth_readout_lbl)
self._smooth_sigma_sl.valueChanged.connect(
lambda _v: self._update_smooth_readout())
self._smooth_combo.currentIndexChanged.connect(
lambda i: self._smooth_sigma_w.setVisible(i != 0))
lambda i: (self._smooth_sigma_w.setVisible(i != 0),
self._smooth_readout_lbl.setVisible(i != 0),
self._update_smooth_readout()))
self._update_smooth_readout()

self._highpass_combo = _combo_row("Hi-pass:", ["None", "Gaussian"], L, 54)
self._highpass_sigma_w, self._highpass_sigma_sl, _ = _sub_slider(
Expand Down Expand Up @@ -314,7 +367,10 @@ def state(self) -> dict:
"remove_bad_lines_max_adjacent_bad_lines": int(
self._bad_line_adjacent_spin.value()
),
"smooth_sigma": self._smooth_sigma_sl.value() if smooth_i != 0 else None,
"smooth_sigma": (
self._smooth_sigma_sl.value() / _SMOOTH_SIGMA_SCALE
if smooth_i != 0 else None
),
"highpass_sigma": self._highpass_sigma_sl.value() if highpass_i != 0 else None,
"edge_method": edge_map[edge_i],
"edge_sigma": self._edge_sigma_sl.value(),
Expand Down Expand Up @@ -348,7 +404,8 @@ def set_state(self, state: dict | None) -> None:
sigma = state.get("smooth_sigma")
if sigma:
self._smooth_combo.setCurrentIndex(1)
self._smooth_sigma_sl.setValue(int(sigma))
self._smooth_sigma_sl.setValue(
int(round(float(sigma) * _SMOOTH_SIGMA_SCALE)))
else:
self._smooth_combo.setCurrentIndex(0)

Expand All @@ -365,6 +422,18 @@ def set_state(self, state: dict | None) -> None:
"sobel": 4, "scharr": 5}.get(edge, 0))
self._edge_sigma_sl.setValue(int(state.get("edge_sigma", 1)))

def set_pixel_size_nm(self, px_nm: float | None) -> None:
"""Tell the panel the loaded scan's pixel size (nm) for the σ readout."""
self._smooth_px_nm = float(px_nm) if px_nm else None
self._update_smooth_readout()

def _update_smooth_readout(self) -> None:
lbl = getattr(self, "_smooth_readout_lbl", None)
if lbl is None: # browse_quick panel has no smooth control
return
sigma_px = self._smooth_sigma_sl.value() / _SMOOTH_SIGMA_SCALE
lbl.setText(format_gaussian_readout(sigma_px, self._smooth_px_nm))

def bad_line_method(self) -> str | None:
return self.state().get("remove_bad_lines")

Expand Down
166 changes: 137 additions & 29 deletions probeflow/processing/bad_lines.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from dataclasses import dataclass

import numpy as np
from scipy.ndimage import uniform_filter1d

from ._image_utils import _nonnegative_finite

Expand Down Expand Up @@ -38,6 +39,31 @@ class BadLineCorrectionInfo:
max_adjacent_bad_lines: int = 1


# Maximum sub-threshold gap (px) bridged inside a MAD segment so per-pixel
# noise holes don't shatter a coherent defect into fragments. Kept small and
# fixed (a noise-hole scale), independent of the minimum feature length.
_NOISE_GAP_TOL_PX = 4

# Along-row smoothing window (px) applied to the residual before MAD detection.
# Acts as a matched filter for *extended* scan-line defects: a coherent line is
# preserved while per-pixel noise averages down (SNR gain ~ sqrt(window)), so a
# faint-but-long bad line clears the threshold without texture flooding in.
_MAD_SMOOTH_WINDOW_PX = 11


def _smooth_rows_nanaware(arr: np.ndarray, window: int) -> np.ndarray:
"""Moving average along each row (axis 1), ignoring non-finite pixels."""
if window <= 1:
return arr
finite = np.isfinite(arr)
vals = np.where(finite, arr, 0.0)
num = uniform_filter1d(vals, size=window, axis=1, mode="nearest")
den = uniform_filter1d(finite.astype(np.float64), size=window, axis=1, mode="nearest")
out = np.full_like(arr, np.nan, dtype=np.float64)
np.divide(num, den, out=out, where=den > 1e-6)
return out


def _normalise_bad_segment_method(method: str) -> str:
method = str(method or "step").lower().replace("-", "_")
aliases = {
Expand Down Expand Up @@ -110,30 +136,89 @@ def _contiguous_true_runs(mask: np.ndarray) -> list[tuple[int, int]]:
return runs


def _close_small_gaps(mask: np.ndarray, gap_tol: int) -> np.ndarray:
"""Bridge ``False`` gaps no longer than ``gap_tol`` between two ``True`` runs.

Per-pixel noise punches sub-threshold holes through an otherwise coherent
bright/dark segment; closing those holes lets the segment be recovered as a
single run instead of many short fragments.
"""
if gap_tol <= 0:
return mask
out = np.asarray(mask, dtype=bool).copy()
runs = _contiguous_true_runs(out)
for (_s0, e0), (s1, _e1) in zip(runs, runs[1:]):
if s1 - e0 <= gap_tol:
out[e0:s1] = True
return out


def _split_segments_by_adjacent_limit(
segments: list[BadSegment] | tuple[BadSegment, ...],
max_adjacent_bad_lines: int,
) -> tuple[list[BadSegment], list[BadSegment]]:
"""Return (accepted, skipped) after applying adjacent-line safety limit."""
"""Return (accepted, skipped) after applying the adjacent-line safety limit.

The limit guards against repairing a vertical *block* of bad pixels that has
no clean neighbour row to copy from. Crucially it is **column-aware**: two
segments on adjacent lines only belong to the same stack when their column
ranges overlap. Segments that merely share a line index with unrelated bad
lines elsewhere in the image are independently repairable and are kept. A
stack taller than ``max_adjacent_bad_lines`` consecutive overlapping lines is
skipped; everything else is repaired.
"""
max_adjacent_bad_lines = max(1, int(max_adjacent_bad_lines))
by_line: dict[int, list[BadSegment]] = {}
for seg in segments:
by_line.setdefault(int(seg.line_index), []).append(seg)
segs = list(segments)
n = len(segs)
if n == 0:
return [], []

parent = list(range(n))

def find(x: int) -> int:
root = x
while parent[root] != root:
root = parent[root]
while parent[x] != root:
parent[x], x = root, parent[x]
return root

def union(a: int, b: int) -> None:
ra, rb = find(a), find(b)
if ra != rb:
parent[ra] = rb

by_line: dict[int, list[int]] = {}
for i, seg in enumerate(segs):
by_line.setdefault(int(seg.line_index), []).append(i)

def _overlaps(a: int, b: int) -> bool:
return (int(segs[a].start_col) < int(segs[b].end_col)
and int(segs[b].start_col) < int(segs[a].end_col))

# Link a segment to any column-overlapping segment on the next line; the
# connected components are the vertical stacks of contiguous bad pixels.
for i, seg in enumerate(segs):
for j in by_line.get(int(seg.line_index) + 1, ()):
if _overlaps(i, j):
union(i, j)

components: dict[int, list[int]] = {}
for i in range(n):
components.setdefault(find(i), []).append(i)

accepted: list[BadSegment] = []
skipped: list[BadSegment] = []
lines = sorted(by_line)
i = 0
while i < len(lines):
group = [lines[i]]
i += 1
while i < len(lines) and lines[i] == group[-1] + 1:
group.append(lines[i])
i += 1
group_segments = [seg for line in group for seg in by_line[line]]
if len(group) > max_adjacent_bad_lines:
skipped.extend(group_segments)
else:
accepted.extend(group_segments)
for members in components.values():
lines = {int(segs[i].line_index) for i in members}
# Members are linked only through ±1-line overlaps, so the line set is a
# consecutive run; its span is the stack thickness.
stack_height = max(lines) - min(lines) + 1
target = accepted if stack_height <= max_adjacent_bad_lines else skipped
target.extend(segs[i] for i in members)

accepted.sort(key=lambda s: (int(s.line_index), int(s.start_col)))
skipped.sort(key=lambda s: (int(s.line_index), int(s.start_col)))
return accepted, skipped


Expand Down Expand Up @@ -203,27 +288,50 @@ def detect_bad_scanline_segments(
k += 2
return segments

scale = _robust_scale(residuals)
# The residual is already referenced to the *neighbour rows*, so good pixels
# sit at ~0 and a bright/dark defect is a direct deviation from zero. We do
# NOT subtract each row's own median (the previous behaviour): once a defect
# covers more than ~half the row it becomes the row median, nets its own
# pixels to ~0, and vanishes — so lowering the threshold only surfaced noise.
#
# Scan-line defects are *extended*, so we smooth the residual along each row
# first (a matched filter): this preserves a coherent line's amplitude while
# averaging per-pixel noise down, letting a faint-but-long line clear the
# threshold without texture flooding in. ``step`` remains the detector for
# sharp, short scars.
window = max(1, min(_MAD_SMOOTH_WINDOW_PX, Nx))
smoothed = _smooth_rows_nanaware(residuals, window)
scale = _robust_scale(smoothed)
cutoff = threshold * scale
if not np.isfinite(cutoff):
return []

sign = 1.0 if polarity == "bright" else -1.0
segments = []
for row in range(Ny):
residual = residuals[row]
finite = np.isfinite(residual)
sm = smoothed[row]
finite = np.isfinite(sm)
if not finite.any():
continue
centre = float(np.median(residual[finite]))
if polarity == "bright":
bad = finite & ((residual - centre) > cutoff)
else:
bad = finite & ((residual - centre) < -cutoff)
bad = finite & ((sign * sm) > cutoff)
# Bridge short noise-sized holes (independent of ``min_length`` so
# raising the minimum feature length to reject texture does not start
# merging unrelated noise spikes into spurious long runs).
bad = _close_small_gaps(bad, _NOISE_GAP_TOL_PX) & finite
for start, end in _contiguous_true_runs(bad):
length = end - start
if not (min_length <= length <= max_len):
# No upper-length cap: a long (even full-width) coherent defect is
# exactly the target, and the neighbour-referenced baseline stays
# valid no matter how much of the row is affected.
if end - start < min_length:
continue
seg = sm[start:end]
seg = seg[np.isfinite(seg)]
if seg.size == 0:
continue
local = np.abs(residual[start:end] - centre)
score = float(np.nanmedian(local) / max(scale, 1e-15))
seg_median = float(np.median(seg))
if (sign * seg_median) <= cutoff:
continue # not a defect on the median
score = float(abs(seg_median) / max(scale, 1e-15))
segments.append(BadSegment(row, start, end, score, method))
return segments

Expand Down
Loading
Loading