Fix JVPs of select, arctan2, masked_scatter, and bitwise ops #3633

Open
qflen wants to merge 2 commits into ml-explore:main from qflen:fix/jvp-packed-tangent-indexing

qflen (Contributor) commented Jun 5, 2026

Fixes #3627 and #3629

The jvp transform passes tangents packed: tangents[i] is the tangent of input argnums[i], and a jvp returns one tangent per output. Four jvps broke this contract and read out of bounds when only some of their inputs were traced.

| jvp | bug | fix |
| --- | --- | --- |
| Select | helper called with argnums[i] instead of i, so it read argnums[argnums[i]] and tangents[1]/tangents[2] | index by position, use tangents[i]; condition tangent gets the output dtype |
| ArcTan2 | always read tangents[1], OOB for atan2(traced, constant) | accumulate one term per traced input |
| MaskedScatter | read tangents[1] for the update even when it was the only traced input | use tangents[i] |
| BitwiseBinary | returned one tangent per traced input (vjp semantics), so the transform indexed outputs[1] OOB and corrupted the tangent map | return one zero tangent per output; vjp written out instead of delegating to jvp |
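For readers unfamiliar with the packed convention, here is a minimal pure-Python sketch of it, using the ArcTan2 case as the example. This is a scalar toy model, not the MLX C++ code: `arctan2_jvp` and its signature are hypothetical names chosen for illustration. It only shows the idea of indexing `tangents` by the loop index `i` and accumulating one term per traced input.

```python
def arctan2_jvp(primals, tangents, argnums):
    """Toy scalar jvp for atan2(y, x) with packed tangents (hypothetical helper).

    tangents[i] is the tangent of input argnums[i], so
    len(tangents) == len(argnums), which may be smaller than the input count.
    """
    y, x = primals
    denom = x * x + y * y
    # d/dy atan2(y, x) = x / (x^2 + y^2), d/dx atan2(y, x) = -y / (x^2 + y^2)
    partials = (x / denom, -y / denom)
    out = 0.0
    for i, arg in enumerate(argnums):
        # Index the partial derivative by input position (arg),
        # but the tangent by its packed position (i).
        out += partials[arg] * tangents[i]
    return [out]  # one tangent per output


# Only y is traced: tangents has a single entry, so tangents[1] does not exist.
only_y = arctan2_jvp((1.0, 2.0), [1.0], [0])
# Only x is traced: its tangent is still tangents[0].
only_x = arctan2_jvp((1.0, 2.0), [1.0], [1])
# Both traced.
both = arctan2_jvp((1.0, 2.0), [1.0, 1.0], [0, 1])
```

In Python, reading `tangents[1]` in the one-sided case raises an `IndexError`; in C++ the same read is out of bounds, which matches the crashes described above.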

Before

import mlx.core as mx

_, (dout,) = mx.jvp(lambda t: mx.where(t < -1.0, 999.0, t * t), [mx.array(3.0)], [mx.array(1.0)])
# dout = 0.0, silently wrong (vjp and finite differences give 6.0)

The repro in #3629 crashed 10/10 runs with segfaults and non-deterministic shape errors.

After

dout = 6.0, jvp matches vjp and finite differences, and the #3629 repro passes 10/10 runs.
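The expected value of 6.0 can be cross-checked without MLX. The sketch below is a plain-Python central difference on a scalar stand-in for the `mx.where` expression; `f` and `central_difference` are illustrative names, and the step size `h` is an arbitrary choice.

```python
def f(t):
    # Scalar stand-in for mx.where(t < -1.0, 999.0, t * t)
    return 999.0 if t < -1.0 else t * t


def central_difference(fn, t, h=1e-4):
    # Symmetric finite-difference estimate of d fn / dt at t
    return (fn(t + h) - fn(t - h)) / (2.0 * h)


fd = central_difference(f, 3.0)
print(fd)  # close to 6.0, since d/dt (t * t) = 2t at t = 3
```

At t = 3.0 the condition `t < -1.0` is false on both sides of the step, so only the `t * t` branch contributes to the derivative.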

Added a regression test per case and a debug assert in the jvp transform that a primitive returns at most one tangent per output. Also fixed the stale argnums.size() == 2 assert in ArcTan2::vjp (one-sided grads aborted debug builds) and added a where test with a constant condition and both branches traced.
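As a rough illustration of what the new debug assert guards against, here is a toy Python model. It is not the MLX transform: `checked_jvp`, `bitwise_and_jvp`, and `buggy_bitwise_and_jvp` are hypothetical names, and the real check lives in the C++ jvp transform.

```python
def bitwise_and_jvp(primals, tangents, argnums):
    # Integer bitwise ops have a zero derivative. The primitive has one
    # output, so it returns exactly one zero tangent.
    return [0]


def buggy_bitwise_and_jvp(primals, tangents, argnums):
    # Pre-fix behaviour in this toy model: one tangent per traced input
    # (vjp semantics), which is too many when both inputs are traced.
    return [0 for _ in argnums]


def checked_jvp(jvp_fn, num_outputs, primals, tangents, argnums):
    out = jvp_fn(primals, tangents, argnums)
    # Sketch of the debug assert: at most one tangent per output.
    assert len(out) <= num_outputs, (
        f"jvp returned {len(out)} tangents for {num_outputs} output(s)"
    )
    return out


ok = checked_jvp(bitwise_and_jvp, 1, (6, 3), [1, 1], [0, 1])

try:
    checked_jvp(buggy_bitwise_and_jvp, 1, (6, 3), [1, 1], [0, 1])
    caught = False
except AssertionError:
    caught = True
```

With both inputs traced, the buggy variant returns two tangents for a single output and trips the assert, instead of letting the caller index past the end of the outputs.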

Checklist

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

qflen marked this pull request as draft June 6, 2026 00:12
The jvp transform passes tangents packed: tangents[i] is the tangent of
input argnums[i]. Several primitives indexed tangents by input position
instead, reading out of bounds whenever only a subset of their inputs
was traced:

- Select::jvp also called its helper with argnums[i] where an index
  into argnums was expected, so where(cond, constant, traced) read
  argnums[argnums[i]] and silently returned a zero tangent (ml-explore#3627).
- ArcTan2::jvp unconditionally read tangents[1].
- MaskedScatter::jvp read tangents[1] for the update argument even
  when it was the only traced input.

BitwiseBinary::jvp returned one tangent per traced input (vjp
semantics) instead of one per output, which made the jvp transform
index outputs[1] out of bounds for a single-output primitive and
corrupt the tangent map (ml-explore#3629).
qflen force-pushed the fix/jvp-packed-tangent-indexing branch from e7798b3 to 2625ee0 June 6, 2026 00:12

Development

Successfully merging this pull request may close these issues.

[BUG] JVP silently returns a zero tangent when differentiating through mx.where
