Fix JVPs of select, arctan2, masked_scatter, and bitwise ops #3633
Open
qflen wants to merge 2 commits into
The jvp transform passes tangents packed: tangents[i] is the tangent of input argnums[i]. Several primitives indexed tangents by input position instead, reading out of bounds whenever only a subset of their inputs was traced:

- Select::jvp also called its helper with argnums[i] where an index into argnums was expected, so where(cond, constant, traced) read argnums[argnums[i]] and silently returned a zero tangent (ml-explore#3627).
- ArcTan2::jvp unconditionally read tangents[1].
- MaskedScatter::jvp read tangents[1] for the update argument even when it was the only traced input.

BitwiseBinary::jvp returned one tangent per traced input (vjp semantics) instead of one per output, which made the jvp transform index outputs[1] out of bounds for a single-output primitive and corrupt the tangent map (ml-explore#3629).
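The packed-tangent contract described above can be illustrated with a small, framework-free sketch. This is not the MLX implementation: the function names (arctan2_partials, arctan2_jvp, arctan2_jvp_by_position) and the scalar setup are hypothetical, chosen only to show why indexing tangents by input position fails when a single input is traced. In Python the bad access raises IndexError; in C++ the same pattern would be an unchecked out-of-bounds read.

```python
import math


def arctan2_partials(y, x):
    """Partial derivatives of atan2(y, x) with respect to (y, x)."""
    denom = x * x + y * y
    return (x / denom, -y / denom)


def arctan2_jvp(primals, tangents, argnums):
    """Packed convention: tangents[i] belongs to input argnums[i]."""
    partials = arctan2_partials(*primals)
    # Walk argnums and tangents together so each tangent meets its own partial.
    return sum(partials[arg] * t for arg, t in zip(argnums, tangents))


def arctan2_jvp_by_position(primals, tangents, argnums):
    """Illustrative mistake: index tangents by input position, ignoring argnums."""
    partials = arctan2_partials(*primals)
    # tangents[1] does not exist when only one input is traced.
    return partials[0] * tangents[0] + partials[1] * tangents[1]


y, x = 1.0, 2.0
# Only y traced: one tangent, argnums == [0].
print(arctan2_jvp((y, x), [1.0], [0]))
# Only x traced: still one tangent, but it belongs to input 1.
print(arctan2_jvp((y, x), [1.0], [1]))
try:
    arctan2_jvp_by_position((y, x), [1.0], [0])
except IndexError as err:
    print("positional indexing failed:", err)
```

Iterating over argnums and tangents in lockstep keeps the pairing correct regardless of which subset of inputs is traced.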
e7798b3 to
2625ee0
Fixes #3627 and #3629.
The jvp transform passes tangents packed: tangents[i] is the tangent of input argnums[i], and a jvp returns one tangent per output. Four jvps broke this contract and read out of bounds when only some of their inputs were traced.

- Select: passed argnums[i] instead of i, so it read argnums[argnums[i]] and tangents[1] / tangents[2]. It now reads tangents[i]; the condition tangent gets the output dtype.
- ArcTan2: read tangents[1], which is out of bounds for atan2(traced, constant).
- MaskedScatter: read tangents[1] for the update even when it was the only traced input. It now reads tangents[i].
- BitwiseBinary: returned one tangent per traced input instead of one per output, so the jvp transform indexed outputs[1] out of bounds and corrupted the tangent map.

Before
The repro in #3629 crashed 10/10 runs with segfaults and non-deterministic shape errors.
After
dout = 6.0, jvp matches vjp and finite differences, and the #3629 repro passes 10/10 runs.

Added a regression test per case and a debug assert in the jvp transform that a primitive returns at most one tangent per output. Also fixed the stale argnums.size() == 2 assert in ArcTan2::vjp (one-sided grads aborted debug builds) and added a where test with a constant condition and both branches traced.

Checklist

- I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
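The debug assert mentioned above (a primitive returns at most one tangent per output) can be pictured with a toy check. This is only a sketch under stated assumptions: call_primitive_jvp, bitwise_jvp_per_input, and bitwise_jvp_per_output are hypothetical names, not MLX functions, and the zero tangents are placeholders chosen for illustration.

```python
def call_primitive_jvp(jvp_fn, primals, tangents, argnums, num_outputs):
    """Hypothetical helper: call one primitive's jvp and check its result shape."""
    out_tangents = jvp_fn(primals, tangents, argnums)
    # A jvp yields tangents for outputs, so it can never return more than num_outputs.
    assert len(out_tangents) <= num_outputs, (
        f"jvp returned {len(out_tangents)} tangents for {num_outputs} output(s)"
    )
    return out_tangents


def bitwise_jvp_per_input(primals, tangents, argnums):
    """Mistaken shape: one zero tangent per traced input, as a vjp would return."""
    return [0 for _ in argnums]


def bitwise_jvp_per_output(primals, tangents, argnums):
    """Expected shape: one zero tangent for the single output."""
    return [0]


# Both inputs of a single-output binary op are traced.
primals, tangents, argnums = (5, 3), [1, 1], [0, 1]
print(call_primitive_jvp(bitwise_jvp_per_output, primals, tangents, argnums, 1))
try:
    call_primitive_jvp(bitwise_jvp_per_input, primals, tangents, argnums, 1)
except AssertionError as err:
    print("caught:", err)
```

A check of this kind turns a silent out-of-bounds write into an immediate failure at the call site, at least in builds where assertions are enabled.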