diff --git a/openequivariance/openequivariance/_torch/NPDoubleBackwardMixin.py b/openequivariance/openequivariance/_torch/NPDoubleBackwardMixin.py index caf9426..5411377 100644 --- a/openequivariance/openequivariance/_torch/NPDoubleBackwardMixin.py +++ b/openequivariance/openequivariance/_torch/NPDoubleBackwardMixin.py @@ -1,6 +1,12 @@ import torch +def _none_to_zeros(values, refs): + return tuple( + ref * 0 if value is None else value for value, ref in zip(values, refs) + ) + + class NumpyDoubleBackwardMixin: """ Adds a Numpy double backward method to any TensorProduct @@ -13,13 +19,15 @@ def double_backward_cpu( ): assert self.torch_op - in1_torch = torch.tensor(in1).to("cuda").requires_grad_(True) - in2_torch = torch.tensor(in2).to("cuda").requires_grad_(True) - weights_torch = torch.tensor(weights).to("cuda").requires_grad_(True) - out_grad_torch = torch.tensor(out_grad).to("cuda").requires_grad_(True) - in1_dgrad_torch = torch.tensor(in1_dgrad).to("cuda") - in2_dgrad_torch = torch.tensor(in2_dgrad).to("cuda") - weights_dgrad_torch = torch.tensor(weights_dgrad).to("cuda") + in1_torch = torch.tensor(in1, device="cuda", requires_grad=True) + in2_torch = torch.tensor(in2, device="cuda", requires_grad=True) + weights_torch = torch.tensor(weights, device="cuda", requires_grad=True) + out_grad_torch = torch.tensor(out_grad, device="cuda", requires_grad=True) + in1_dgrad_torch = torch.tensor(in1_dgrad, device="cuda", requires_grad=False) + in2_dgrad_torch = torch.tensor(in2_dgrad, device="cuda", requires_grad=False) + weights_dgrad_torch = torch.tensor( + weights_dgrad, device="cuda", requires_grad=False + ) out_torch = self.forward(in1_torch, in2_torch, weights_torch) in1_grad, in2_grad, weights_grad = torch.autograd.grad( @@ -43,6 +51,91 @@ def double_backward_cpu( d.detach().cpu().numpy(), ) + def triple_backward_cpu( + self, + in1, + in2, + out_grad, + weights, + weights_dgrad, + in1_dgrad, + in2_dgrad, + out_tgrad, + weights_tgrad, + in1_tgrad, + in2_tgrad, + ): + assert self.torch_op + + in1_torch = torch.tensor(in1, device="cuda", requires_grad=True) + in2_torch = torch.tensor(in2, device="cuda", requires_grad=True) + weights_torch = torch.tensor(weights, device="cuda", requires_grad=True) + out_grad_torch = torch.tensor(out_grad, device="cuda", requires_grad=True) + in1_dgrad_torch = torch.tensor(in1_dgrad, device="cuda", requires_grad=True) + in2_dgrad_torch = torch.tensor(in2_dgrad, device="cuda", requires_grad=True) + weights_dgrad_torch = torch.tensor( + weights_dgrad, device="cuda", requires_grad=True + ) + out_tgrad_torch = torch.tensor(out_tgrad, device="cuda", requires_grad=False) + in1_tgrad_torch = torch.tensor(in1_tgrad, device="cuda", requires_grad=False) + in2_tgrad_torch = torch.tensor(in2_tgrad, device="cuda", requires_grad=False) + weights_tgrad_torch = torch.tensor( + weights_tgrad, device="cuda", requires_grad=False + ) + + out_torch = self.forward(in1_torch, in2_torch, weights_torch) + in1_grad, in2_grad, weights_grad = torch.autograd.grad( + outputs=out_torch, + inputs=[in1_torch, in2_torch, weights_torch], + grad_outputs=out_grad_torch, + create_graph=True, + retain_graph=True, + ) + double_grads = torch.autograd.grad( + outputs=[in1_grad, in2_grad, weights_grad], + inputs=[in1_torch, in2_torch, weights_torch, out_grad_torch], + grad_outputs=[in1_dgrad_torch, in2_dgrad_torch, weights_dgrad_torch], + create_graph=True, + retain_graph=True, + allow_unused=True, + ) + double_grads = _none_to_zeros( + double_grads, (in1_torch, in2_torch, weights_torch, out_grad_torch) + ) + triple_grads = torch.autograd.grad( + outputs=double_grads, + inputs=[ + in1_torch, + in2_torch, + weights_torch, + out_grad_torch, + in1_dgrad_torch, + in2_dgrad_torch, + weights_dgrad_torch, + ], + grad_outputs=[ + in1_tgrad_torch, + in2_tgrad_torch, + weights_tgrad_torch, + out_tgrad_torch, + ], + allow_unused=True, + ) + triple_grads = _none_to_zeros( + triple_grads, + ( + in1_torch, + in2_torch, + weights_torch, + out_grad_torch, + in1_dgrad_torch, + in2_dgrad_torch, + weights_dgrad_torch, + ), + ) + + return tuple(grad.detach().cpu().numpy() for grad in triple_grads) + class NumpyDoubleBackwardMixinConv: """ @@ -54,13 +147,15 @@ def double_backward_cpu( ): assert self.torch_op - in1_torch = torch.tensor(in1).to("cuda").requires_grad_(True) - in2_torch = torch.tensor(in2).to("cuda").requires_grad_(True) - weights_torch = torch.tensor(weights).to("cuda").requires_grad_(True) - out_grad_torch = torch.tensor(out_grad).to("cuda").requires_grad_(True) - in1_dgrad_torch = torch.tensor(in1_dgrad).to("cuda") - in2_dgrad_torch = torch.tensor(in2_dgrad).to("cuda") - weights_dgrad_torch = torch.tensor(weights_dgrad).to("cuda") + in1_torch = torch.tensor(in1, device="cuda", requires_grad=True) + in2_torch = torch.tensor(in2, device="cuda", requires_grad=True) + weights_torch = torch.tensor(weights, device="cuda", requires_grad=True) + out_grad_torch = torch.tensor(out_grad, device="cuda", requires_grad=True) + in1_dgrad_torch = torch.tensor(in1_dgrad, device="cuda", requires_grad=False) + in2_dgrad_torch = torch.tensor(in2_dgrad, device="cuda", requires_grad=False) + weights_dgrad_torch = torch.tensor( + weights_dgrad, device="cuda", requires_grad=False + ) torch_rows = torch.tensor(graph.rows, device="cuda") torch_cols = torch.tensor(graph.cols, device="cuda") @@ -95,3 +190,100 @@ def double_backward_cpu( c.detach().cpu().numpy(), d.detach().cpu().numpy(), ) + + def triple_backward_cpu( + self, + in1, + in2, + out_grad, + weights, + weights_dgrad, + in1_dgrad, + in2_dgrad, + out_tgrad, + weights_tgrad, + in1_tgrad, + in2_tgrad, + graph, + ): + assert self.torch_op + + in1_torch = torch.tensor(in1, device="cuda", requires_grad=True) + in2_torch = torch.tensor(in2, device="cuda", requires_grad=True) + weights_torch = torch.tensor(weights, device="cuda", requires_grad=True) + out_grad_torch = torch.tensor(out_grad, device="cuda", requires_grad=True) + in1_dgrad_torch = torch.tensor(in1_dgrad, device="cuda", requires_grad=True) + in2_dgrad_torch = torch.tensor(in2_dgrad, device="cuda", requires_grad=True) + weights_dgrad_torch = torch.tensor( + weights_dgrad, device="cuda", requires_grad=True + ) + out_tgrad_torch = torch.tensor(out_tgrad, device="cuda", requires_grad=False) + in1_tgrad_torch = torch.tensor(in1_tgrad, device="cuda", requires_grad=False) + in2_tgrad_torch = torch.tensor(in2_tgrad, device="cuda", requires_grad=False) + weights_tgrad_torch = torch.tensor( + weights_tgrad, device="cuda", requires_grad=False + ) + + torch_rows = torch.tensor(graph.rows, device="cuda") + torch_cols = torch.tensor(graph.cols, device="cuda") + torch_transpose_perm = torch.tensor(graph.transpose_perm, device="cuda") + + out_torch = self.forward( + in1_torch, + in2_torch, + weights_torch, + torch_rows, + torch_cols, + torch_transpose_perm, + ) + in1_grad, in2_grad, weights_grad = torch.autograd.grad( + outputs=out_torch, + inputs=[in1_torch, in2_torch, weights_torch], + grad_outputs=out_grad_torch, + create_graph=True, + retain_graph=True, + ) + double_grads = torch.autograd.grad( + outputs=[in1_grad, in2_grad, weights_grad], + inputs=[in1_torch, in2_torch, weights_torch, out_grad_torch], + grad_outputs=[in1_dgrad_torch, in2_dgrad_torch, weights_dgrad_torch], + create_graph=True, + retain_graph=True, + allow_unused=True, + ) + double_grads = _none_to_zeros( + double_grads, (in1_torch, in2_torch, weights_torch, out_grad_torch) + ) + triple_grads = torch.autograd.grad( + outputs=double_grads, + inputs=[ + in1_torch, + in2_torch, + weights_torch, + out_grad_torch, + in1_dgrad_torch, + in2_dgrad_torch, + weights_dgrad_torch, + ], + grad_outputs=[ + in1_tgrad_torch, + in2_tgrad_torch, + weights_tgrad_torch, + out_tgrad_torch, + ], + allow_unused=True, + ) + triple_grads = _none_to_zeros( + triple_grads, + ( + in1_torch, + in2_torch, + weights_torch, + out_grad_torch, + in1_dgrad_torch, + in2_dgrad_torch, + weights_dgrad_torch, + ), + ) + + return tuple(grad.detach().cpu().numpy() for grad in triple_grads) diff --git a/openequivariance/openequivariance/_torch/TensorProduct.py b/openequivariance/openequivariance/_torch/TensorProduct.py index c18b123..de587b7 100644 --- a/openequivariance/openequivariance/_torch/TensorProduct.py +++ b/openequivariance/openequivariance/_torch/TensorProduct.py @@ -200,6 +200,12 @@ def fake_double_backward(kernel, hash, L1_in, L2_in, W, L3_grad, E, F, G): def register_autograd(): backward_op = torch.ops.libtorch_tp_jit.jit_tp_backward + double_backward_op = torch.ops.libtorch_tp_jit.jit_tp_double_backward + + def zero_if_none(grad_output, like): + if grad_output is None: + return torch.zeros_like(like) + return grad_output def setup_context(ctx, inputs, output): ctx.kernel, ctx.hash, ctx.L1_in, ctx.L2_in, ctx.weights, ctx.L3_dim = inputs @@ -218,7 +224,7 @@ def setup_context_double_backward(ctx, inputs, output): ctx.kernel, ctx.hash, ctx.L1_in, ctx.L2_in, ctx.weights, ctx.L3_grad = inputs def double_backward(ctx, E, F, G): - result = torch.ops.libtorch_tp_jit.jit_tp_double_backward( + result = double_backward_op( ctx.kernel, ctx.hash, ctx.L1_in, @@ -237,6 +243,121 @@ def double_backward(ctx, E, F, G): setup_context=setup_context_double_backward, ) + def setup_context_triple_backward(ctx, inputs, output): + ( + ctx.kernel, + ctx.hash, + ctx.L1_in, + ctx.L2_in, + ctx.weights, + ctx.L3_grad, + ctx.L1_dgrad, + ctx.L2_dgrad, + ctx.W_dgrad, + ) = inputs + + def triple_backward(ctx, t_L1_grad, t_L2_grad, t_W_grad, t_L3_dgrad): + t_L1_grad = zero_if_none(t_L1_grad, ctx.L1_in) + t_L2_grad = zero_if_none(t_L2_grad, ctx.L2_in) + t_W_grad = zero_if_none(t_W_grad, ctx.weights) + t_L3_dgrad = zero_if_none(t_L3_dgrad, ctx.L3_grad) + + g1_L1_dgrad, g1_L2_dgrad, g1_W, g1_L3_grad = double_backward_op( + ctx.kernel, + ctx.hash, + ctx.L1_dgrad, + ctx.L2_dgrad, + ctx.weights, + ctx.L3_grad, + t_L1_grad, + t_L2_grad, + torch.zeros_like(ctx.weights), + ) + g2_L1, g2_L2, g2_W_dgrad, g2_L3_grad = double_backward_op( + ctx.kernel, + ctx.hash, + ctx.L1_in, + ctx.L2_in, + ctx.W_dgrad, + ctx.L3_grad, + t_L1_grad, + t_L2_grad, + torch.zeros_like(ctx.W_dgrad), + ) + g3_L1_dgrad, g3_L2, g3_W, g3_L3_grad = double_backward_op( + ctx.kernel, + ctx.hash, + ctx.L1_dgrad, + ctx.L2_in, + ctx.weights, + ctx.L3_grad, + torch.zeros_like(ctx.L1_dgrad), + torch.zeros_like(ctx.L2_in), + t_W_grad, + ) + g4_L1, g4_L2_dgrad, g4_W, g4_L3_grad = double_backward_op( + ctx.kernel, + ctx.hash, + ctx.L1_in, + ctx.L2_dgrad, + ctx.weights, + ctx.L3_grad, + torch.zeros_like(ctx.L1_in), + torch.zeros_like(ctx.L2_dgrad), + t_W_grad, + ) + + g5_L1_dgrad, g5_L2, g5_W = backward_op( + ctx.kernel, + ctx.hash, + ctx.L1_dgrad, + ctx.L2_in, + ctx.weights, + t_L3_dgrad, + ) + g6_L1, g6_L2_dgrad, g6_W = backward_op( + ctx.kernel, + ctx.hash, + ctx.L1_in, + ctx.L2_dgrad, + ctx.weights, + t_L3_dgrad, + ) + g7_L1, g7_L2, g7_W_dgrad = backward_op( + ctx.kernel, + ctx.hash, + ctx.L1_in, + ctx.L2_in, + ctx.W_dgrad, + t_L3_dgrad, + ) + + grad_L1 = g2_L1 + g4_L1 + g6_L1 + g7_L1 + grad_L2 = g2_L2 + g3_L2 + g5_L2 + g7_L2 + grad_W = g1_W + g3_W + g4_W + g5_W + g6_W + grad_L3_grad = g1_L3_grad + g2_L3_grad + g3_L3_grad + g4_L3_grad + grad_L1_dgrad = g1_L1_dgrad + g3_L1_dgrad + g5_L1_dgrad + grad_L2_dgrad = g1_L2_dgrad + g4_L2_dgrad + g6_L2_dgrad + grad_W_dgrad = g2_W_dgrad + g7_W_dgrad + + return ( + None, + None, + grad_L1, + grad_L2, + grad_W, + grad_L3_grad, + grad_L1_dgrad, + grad_L2_dgrad, + grad_W_dgrad, + ) + + torch.library.register_autograd( + "libtorch_tp_jit::jit_tp_double_backward", + triple_backward, + setup_context=setup_context_triple_backward, + ) + def register_autocast(): global torch diff --git a/openequivariance/openequivariance/_torch/TensorProductConv.py b/openequivariance/openequivariance/_torch/TensorProductConv.py index c1087a6..c75b739 100644 --- a/openequivariance/openequivariance/_torch/TensorProductConv.py +++ b/openequivariance/openequivariance/_torch/TensorProductConv.py @@ -330,6 +330,11 @@ def register_autograd(): backward_op = torch.ops.libtorch_tp_jit.jit_conv_backward double_backward_op = torch.ops.libtorch_tp_jit.jit_conv_double_backward + def zero_if_none(grad_output, like): + if grad_output is None: + return torch.zeros_like(like) + return grad_output + def setup_context(ctx, inputs, output): ( ctx.kernel, @@ -413,6 +418,143 @@ def double_backward(ctx, E, F, G): setup_context=setup_context_double_backward, ) + def setup_context_triple_backward(ctx, inputs, output): + ( + ctx.kernel, + ctx.hash, + ctx.L1_in, + ctx.L2_in, + ctx.W, + ctx.grad_output, + ctx.L1_dgrad, + ctx.L2_dgrad, + ctx.W_dgrad, + ctx.rows, + ctx.cols, + ctx.workspace_buffer, + ctx.sender_perm, + ) = inputs + + def triple_backward(ctx, t_L1_grad, t_L2_grad, t_W_grad, t_L3_dgrad): + t_L1_grad = zero_if_none(t_L1_grad, ctx.L1_in) + t_L2_grad = zero_if_none(t_L2_grad, ctx.L2_in) + t_W_grad = zero_if_none(t_W_grad, ctx.W) + t_L3_dgrad = zero_if_none(t_L3_dgrad, ctx.grad_output) + + common_args = ( + ctx.rows, + ctx.cols, + ctx.workspace_buffer, + ctx.sender_perm, + ) + + g1_L1_dgrad, g1_L2_dgrad, g1_W, g1_L3_grad = double_backward_op( + ctx.kernel, + ctx.hash, + ctx.L1_dgrad, + ctx.L2_dgrad, + ctx.W, + ctx.grad_output, + t_L1_grad, + t_L2_grad, + torch.zeros_like(ctx.W), + *common_args, + ) + g2_L1, g2_L2, g2_W_dgrad, g2_L3_grad = double_backward_op( + ctx.kernel, + ctx.hash, + ctx.L1_in, + ctx.L2_in, + ctx.W_dgrad, + ctx.grad_output, + t_L1_grad, + t_L2_grad, + torch.zeros_like(ctx.W_dgrad), + *common_args, + ) + g3_L1_dgrad, g3_L2, g3_W, g3_L3_grad = double_backward_op( + ctx.kernel, + ctx.hash, + ctx.L1_dgrad, + ctx.L2_in, + ctx.W, + ctx.grad_output, + torch.zeros_like(ctx.L1_dgrad), + torch.zeros_like(ctx.L2_in), + t_W_grad, + *common_args, + ) + g4_L1, g4_L2_dgrad, g4_W, g4_L3_grad = double_backward_op( + ctx.kernel, + ctx.hash, + ctx.L1_in, + ctx.L2_dgrad, + ctx.W, + ctx.grad_output, + torch.zeros_like(ctx.L1_in), + torch.zeros_like(ctx.L2_dgrad), + t_W_grad, + *common_args, + ) + + g5_L1_dgrad, g5_L2, g5_W = backward_op( + ctx.kernel, + ctx.hash, + ctx.L1_dgrad, + ctx.L2_in, + ctx.W, + t_L3_dgrad, + *common_args, + ) + g6_L1, g6_L2_dgrad, g6_W = backward_op( + ctx.kernel, + ctx.hash, + ctx.L1_in, + ctx.L2_dgrad, + ctx.W, + t_L3_dgrad, + *common_args, + ) + g7_L1, g7_L2, g7_W_dgrad = backward_op( + ctx.kernel, + ctx.hash, + ctx.L1_in, + ctx.L2_in, + ctx.W_dgrad, + t_L3_dgrad, + *common_args, + ) + + grad_L1 = g2_L1 + g4_L1 + g6_L1 + g7_L1 + grad_L2 = g2_L2 + g3_L2 + g5_L2 + g7_L2 + grad_W = g1_W + g3_W + g4_W + g5_W + g6_W + grad_L3_grad = g1_L3_grad + g2_L3_grad + g3_L3_grad + g4_L3_grad + grad_L1_dgrad = g1_L1_dgrad + g3_L1_dgrad + g5_L1_dgrad + grad_L2_dgrad = g1_L2_dgrad + g4_L2_dgrad + g6_L2_dgrad + grad_W_dgrad = g2_W_dgrad + g7_W_dgrad + + return ( + None, + None, + grad_L1, + grad_L2, + grad_W, + grad_L3_grad, + grad_L1_dgrad, + grad_L2_dgrad, + grad_W_dgrad, + None, + None, + None, + None, + ) + + torch.library.register_autograd( + "libtorch_tp_jit::jit_conv_double_backward", + triple_backward, + setup_context=setup_context_triple_backward, + ) + def register_autocast(): torch.library.register_autocast( diff --git a/openequivariance/openequivariance/benchmark/correctness.py b/openequivariance/openequivariance/benchmark/correctness.py index 45c45c4..fa95f82 100644 --- a/openequivariance/openequivariance/benchmark/correctness.py +++ b/openequivariance/openequivariance/benchmark/correctness.py @@ -13,6 +13,8 @@ get_random_buffers_double_backward, get_random_buffers_forward_conv, get_random_buffers_forward, + get_random_buffers_triple_backward, + get_random_buffers_triple_backward_conv, ) from openequivariance.core.e3nn_lite import TPProblem from openequivariance.core.TensorProductBase import TensorProductBase @@ -316,6 +318,175 @@ def correctness_double_backward( return result +def correctness_triple_backward( + problem: TPProblem, + test_implementation: Union[type[TensorProductBase], TensorProductBase], + reference_implementation: Optional[type[TensorProductBase]], + batch_size: int, + correctness_threshold: float, + prng_seed: int, +): + buffers = get_random_buffers_triple_backward( + problem, batch_size=batch_size, prng_seed=prng_seed + ) + + if reference_implementation is None: + from openequivariance._torch.E3NNTensorProduct import E3NNTensorProduct + + reference_implementation = E3NNTensorProduct + + result = {"thresh": correctness_threshold, "batch_size": batch_size} + tensors = [] + for i, impl in enumerate([test_implementation, reference_implementation]): + is_test_impl = i == 0 + tp = instantiate_implementation(impl, problem) + buffers_copy = [buf.copy() for buf in buffers] + ( + in1, + in2, + out_grad, + weights, + weights_dgrad, + in1_dgrad, + in2_dgrad, + out_tgrad, + weights_tgrad, + in1_tgrad, + in2_tgrad, + ) = buffers_copy + + weights_reordered = tp.reorder_weights_from_e3nn( + weights, has_batch_dim=not problem.shared_weights + ) + weights_dgrad_reordered = tp.reorder_weights_from_e3nn( + weights_dgrad, has_batch_dim=not problem.shared_weights + ) + weights_tgrad_reordered = tp.reorder_weights_from_e3nn( + weights_tgrad, has_batch_dim=not problem.shared_weights + ) + + if impl == CUETensorProduct and problem.shared_weights: + weights_reordered = weights_reordered[np.newaxis, :] + weights_dgrad_reordered = weights_dgrad_reordered[np.newaxis, :] + weights_tgrad_reordered = weights_tgrad_reordered[np.newaxis, :] + + if is_test_impl: + ( + in1, + in2, + out_grad, + in1_dgrad, + in2_dgrad, + out_tgrad, + in1_tgrad, + in2_tgrad, + ) = [ + transpose_irrep_layout(arr, irreps, "mul_ir", tp.config.layout) + for arr, irreps in zip( + ( + in1, + in2, + out_grad, + in1_dgrad, + in2_dgrad, + out_tgrad, + in1_tgrad, + in2_tgrad, + ), + ( + problem.irreps_in1, + problem.irreps_in2, + problem.irreps_out, + problem.irreps_in1, + problem.irreps_in2, + problem.irreps_out, + problem.irreps_in1, + problem.irreps_in2, + ), + ) + ] + + ( + in1_grad, + in2_grad, + weights_grad, + out_grad_grad, + in1_dgrad_grad, + in2_dgrad_grad, + weights_dgrad_grad, + ) = tp.triple_backward_cpu( + in1, + in2, + out_grad, + weights_reordered, + weights_dgrad_reordered, + in1_dgrad, + in2_dgrad, + out_tgrad, + weights_tgrad_reordered, + in1_tgrad, + in2_tgrad, + ) + + if is_test_impl: + ( + in1_grad, + in2_grad, + out_grad_grad, + in1_dgrad_grad, + in2_dgrad_grad, + ) = [ + transpose_irrep_layout(arr, irreps, tp.config.layout, "mul_ir") + for arr, irreps in zip( + ( + in1_grad, + in2_grad, + out_grad_grad, + in1_dgrad_grad, + in2_dgrad_grad, + ), + ( + problem.irreps_in1, + problem.irreps_in2, + problem.irreps_out, + problem.irreps_in1, + problem.irreps_in2, + ), + ) + ] + + tensors.append( + ( + in1_grad, + in2_grad, + tp.reorder_weights_to_e3nn( + weights_grad, has_batch_dim=not problem.shared_weights + ), + out_grad_grad, + in1_dgrad_grad, + in2_dgrad_grad, + tp.reorder_weights_to_e3nn( + weights_dgrad_grad, has_batch_dim=not problem.shared_weights + ), + ) + ) + + for name, to_check, ground_truth in [ + ("in1_grad", tensors[0][0], tensors[1][0]), + ("in2_grad", tensors[0][1], tensors[1][1]), + ("weights_grad", tensors[0][2], tensors[1][2]), + ("output_grad", tensors[0][3], tensors[1][3]), + ("in1_double_grad", tensors[0][4], tensors[1][4]), + ("in2_double_grad", tensors[0][5], tensors[1][5]), + ("weights_double_grad", tensors[0][6], tensors[1][6]), + ]: + result[name] = check_similiarity( + name, to_check, ground_truth, correctness_threshold + ) + + return result + + def correctness_forward_conv( conv, graph, @@ -636,3 +807,177 @@ def correctness_double_backward_conv( result[name] = check_similiarity(name, to_check, ground_truth, thresh) return result + + +def correctness_triple_backward_conv( + conv, + graph, + thresh, + prng_seed, + reference_implementation=None, + high_precision_ref=False, +): + buffers = get_random_buffers_triple_backward_conv( + conv.config, graph.node_count, graph.nnz, prng_seed + ) + + if reference_implementation is None: + from openequivariance._torch.E3NNConv import E3NNConv + + reference_implementation = E3NNConv + + reference_problem = conv.config + if high_precision_ref: + reference_problem = copy.deepcopy(conv.config) + reference_problem.irrep_dtype = np.float64 + reference_problem.weight_dtype = np.float64 + + reference_tp = reference_implementation(reference_problem, torch_op=True) + + result = {"thresh": thresh} + tensors = [] + for i, tp in enumerate([conv, reference_tp]): + is_test_impl = i == 0 + buffers_copy = [buf.copy() for buf in buffers] + + if i == 1 and high_precision_ref: + buffers_copy = [np.array(el, dtype=np.float64) for el in buffers] + + ( + in1, + in2, + out_grad, + weights, + weights_dgrad, + in1_dgrad, + in2_dgrad, + out_tgrad, + weights_tgrad, + in1_tgrad, + in2_tgrad, + ) = buffers_copy + + weights_reordered = tp.reorder_weights_from_e3nn( + weights, has_batch_dim=not conv.config.shared_weights + ) + weights_dgrad_reordered = tp.reorder_weights_from_e3nn( + weights_dgrad, has_batch_dim=not conv.config.shared_weights + ) + weights_tgrad_reordered = tp.reorder_weights_from_e3nn( + weights_tgrad, has_batch_dim=not conv.config.shared_weights + ) + + if is_test_impl: + ( + in1, + in2, + out_grad, + in1_dgrad, + in2_dgrad, + out_tgrad, + in1_tgrad, + in2_tgrad, + ) = [ + transpose_irrep_layout(arr, irreps, "mul_ir", tp.config.layout) + for arr, irreps in zip( + ( + in1, + in2, + out_grad, + in1_dgrad, + in2_dgrad, + out_tgrad, + in1_tgrad, + in2_tgrad, + ), + ( + tp.config.irreps_in1, + tp.config.irreps_in2, + tp.config.irreps_out, + tp.config.irreps_in1, + tp.config.irreps_in2, + tp.config.irreps_out, + tp.config.irreps_in1, + tp.config.irreps_in2, + ), + ) + ] + + ( + in1_grad, + in2_grad, + weights_grad, + out_grad_grad, + in1_dgrad_grad, + in2_dgrad_grad, + weights_dgrad_grad, + ) = tp.triple_backward_cpu( + in1, + in2, + out_grad, + weights_reordered, + weights_dgrad_reordered, + in1_dgrad, + in2_dgrad, + out_tgrad, + weights_tgrad_reordered, + in1_tgrad, + in2_tgrad, + graph, + ) + + if is_test_impl: + ( + in1_grad, + in2_grad, + out_grad_grad, + in1_dgrad_grad, + in2_dgrad_grad, + ) = [ + transpose_irrep_layout(arr, irreps, tp.config.layout, "mul_ir") + for arr, irreps in zip( + ( + in1_grad, + in2_grad, + out_grad_grad, + in1_dgrad_grad, + in2_dgrad_grad, + ), + ( + tp.config.irreps_in1, + tp.config.irreps_in2, + tp.config.irreps_out, + tp.config.irreps_in1, + tp.config.irreps_in2, + ), + ) + ] + + tensors.append( + ( + in1_grad, + in2_grad, + tp.reorder_weights_to_e3nn( + weights_grad, has_batch_dim=not conv.config.shared_weights + ), + out_grad_grad, + in1_dgrad_grad, + in2_dgrad_grad, + tp.reorder_weights_to_e3nn( + weights_dgrad_grad, has_batch_dim=not conv.config.shared_weights + ), + ) + ) + + for name, to_check, ground_truth in [ + ("in1_grad", tensors[0][0], tensors[1][0]), + ("in2_grad", tensors[0][1], tensors[1][1]), + ("weights_grad", tensors[0][2], tensors[1][2]), + ("output_grad", tensors[0][3], tensors[1][3]), + ("in1_double_grad", tensors[0][4], tensors[1][4]), + ("in2_double_grad", tensors[0][5], tensors[1][5]), + ("weights_double_grad", tensors[0][6], tensors[1][6]), + ]: + result[name] = check_similiarity(name, to_check, ground_truth, thresh) + + return result diff --git a/openequivariance/openequivariance/benchmark/test_buffers.py b/openequivariance/openequivariance/benchmark/test_buffers.py index c657d5b..252c206 100644 --- a/openequivariance/openequivariance/benchmark/test_buffers.py +++ b/openequivariance/openequivariance/benchmark/test_buffers.py @@ -127,6 +127,59 @@ def get_random_buffers_double_backward( ) +def get_random_buffers_triple_backward(tpp: TPProblem, batch_size: int, prng_seed: int): + rng = np.random.default_rng(prng_seed) + + in1 = np.array( + rng.uniform(size=(batch_size, tpp.irreps_in1.dim)), dtype=tpp.irrep_dtype + ) + in2 = np.array( + rng.uniform(size=(batch_size, tpp.irreps_in2.dim)), dtype=tpp.irrep_dtype + ) + out_grad = np.array( + rng.uniform(size=(batch_size, tpp.irreps_out.dim)), dtype=tpp.irrep_dtype + ) + + weights_size = ( + tuple([tpp.weight_numel]) + if tpp.shared_weights + else tuple([batch_size, tpp.weight_numel]) + ) + weights = np.array(rng.uniform(size=weights_size), dtype=tpp.irrep_dtype) + weights_dgrad = np.array(rng.uniform(size=weights_size), dtype=tpp.irrep_dtype) + weights_tgrad = np.array(rng.uniform(size=weights_size), dtype=tpp.irrep_dtype) + + in1_dgrad = np.array( + rng.uniform(size=(batch_size, tpp.irreps_in1.dim)), dtype=tpp.irrep_dtype + ) + in2_dgrad = np.array( + rng.uniform(size=(batch_size, tpp.irreps_in2.dim)), dtype=tpp.irrep_dtype + ) + out_tgrad = np.array( + rng.uniform(size=(batch_size, tpp.irreps_out.dim)), dtype=tpp.irrep_dtype + ) + in1_tgrad = np.array( + rng.uniform(size=(batch_size, tpp.irreps_in1.dim)), dtype=tpp.irrep_dtype + ) + in2_tgrad = np.array( + rng.uniform(size=(batch_size, tpp.irreps_in2.dim)), dtype=tpp.irrep_dtype + ) + + return ( + in1, + in2, + out_grad, + weights, + weights_dgrad, + in1_dgrad, + in2_dgrad, + out_tgrad, + weights_tgrad, + in1_tgrad, + in2_tgrad, + ) + + def get_random_buffers_forward_conv( tpp: TPProblem, node_count: int, edge_count: int, prng_seed: int ): @@ -225,3 +278,58 @@ def get_random_buffers_double_backward_conv( in2_grad, out_double_grad, ) + + +def get_random_buffers_triple_backward_conv( + tpp: TPProblem, node_count: int, edge_count: int, prng_seed: int +): + rng = np.random.default_rng(prng_seed) + + in1 = np.array( + rng.uniform(size=(node_count, tpp.irreps_in1.dim)), dtype=tpp.irrep_dtype + ) + in2 = np.array( + rng.uniform(size=(edge_count, tpp.irreps_in2.dim)), dtype=tpp.irrep_dtype + ) + out_grad = np.array( + rng.uniform(size=(node_count, tpp.irreps_out.dim)), dtype=tpp.irrep_dtype + ) + + weights_size = ( + tuple([tpp.weight_numel]) + if tpp.shared_weights + else tuple([edge_count, tpp.weight_numel]) + ) + weights = np.array(rng.uniform(size=weights_size), dtype=tpp.irrep_dtype) + weights_dgrad = np.array(rng.uniform(size=weights_size), dtype=tpp.irrep_dtype) + weights_tgrad = np.array(rng.uniform(size=weights_size), dtype=tpp.irrep_dtype) + + in1_dgrad = np.array( + rng.uniform(size=(node_count, tpp.irreps_in1.dim)), dtype=tpp.irrep_dtype + ) + in2_dgrad = np.array( + rng.uniform(size=(edge_count, tpp.irreps_in2.dim)), dtype=tpp.irrep_dtype + ) + out_tgrad = np.array( + rng.uniform(size=(node_count, tpp.irreps_out.dim)), dtype=tpp.irrep_dtype + ) + in1_tgrad = np.array( + rng.uniform(size=(node_count, tpp.irreps_in1.dim)), dtype=tpp.irrep_dtype + ) + in2_tgrad = np.array( + rng.uniform(size=(edge_count, tpp.irreps_in2.dim)), dtype=tpp.irrep_dtype + ) + + return ( + in1, + in2, + out_grad, + weights, + weights_dgrad, + in1_dgrad, + in2_dgrad, + out_tgrad, + weights_tgrad, + in1_tgrad, + in2_tgrad, + ) diff --git a/tests/batch_test.py b/tests/batch_test.py index ff1cd1c..7ec6333 100644 --- a/tests/batch_test.py +++ b/tests/batch_test.py @@ -7,6 +7,7 @@ correctness_backward, correctness_double_backward, correctness_forward, + correctness_triple_backward, ) from openequivariance.benchmark.test_buffers import get_random_buffers_forward from openequivariance.benchmark.problems import ( @@ -27,7 +28,12 @@ def dtype(request): class TPCorrectness: def thresh(self, direction): - return {"fwd": 1e-5, "bwd": 3e-4, "double_bwd": 3e-4}[direction] + return { + "fwd": 1e-5, + "bwd": 3e-4, + "double_bwd": 3e-4, + "triple_bwd": 5e-4, + }[direction] def check_result(self, result, fieldname): with check: @@ -95,6 +101,31 @@ def test_tp_double_bwd(self, tp_and_problem): self.check_result(result, "in2_grad") self.check_result(result, "weights_grad") + def test_tp_triple_bwd(self, tp_and_problem, with_jax): + if with_jax: + pytest.skip("N/A for JAX") + + tp, problem = tp_and_problem + result = correctness_triple_backward( + problem=problem, + test_implementation=tp, + reference_implementation=None, + batch_size=4, + correctness_threshold=self.thresh("triple_bwd"), + prng_seed=12345, + ) + + for fieldname in [ + "in1_grad", + "in2_grad", + "weights_grad", + "output_grad", + "in1_double_grad", + "in2_double_grad", + "weights_double_grad", + ]: + self.check_result(result, fieldname) + class TestProductionModels(TPCorrectness): production_model_tpps = ( @@ -241,6 +272,7 @@ def thresh(self, direction): "fwd": 1e-5, "bwd": 5e-4, # Expect higher errors for shared weights "double_bwd": 5e-4, + "triple_bwd": 5e-4, }[direction] @pytest.fixture(params=problems, ids=lambda x: x.label, scope="class") diff --git a/tests/conv_test.py b/tests/conv_test.py index 8471e59..446d0f3 100644 --- a/tests/conv_test.py +++ b/tests/conv_test.py @@ -10,6 +10,7 @@ correctness_backward_conv, correctness_double_backward_conv, correctness_forward_conv, + correctness_triple_backward_conv, ) from itertools import product import torch @@ -47,7 +48,12 @@ def with_jax(request): class ConvCorrectness: def thresh(self, direction): - return {"fwd": 3e-4, "bwd": 3e-4, "double_bwd": 3e-4}[direction] + return { + "fwd": 3e-4, + "bwd": 3e-4, + "double_bwd": 3e-4, + "triple_bwd": 5e-4, + }[direction] def check_result(self, result, fieldname): with check: @@ -137,6 +143,32 @@ def test_tp_double_bwd(self, conv_object, graph): self.check_result(result, "in2_grad") self.check_result(result, "weights_grad") + def test_tp_triple_bwd(self, conv_object, graph, with_jax): + if with_jax: + pytest.skip("N/A for JAX") + + if conv_object is None: + pytest.skip("'conv_object' fixture returned None, skipping") + + result = correctness_triple_backward_conv( + conv_object, + graph, + thresh=self.thresh("triple_bwd"), + prng_seed=12345, + reference_implementation=None, + ) + + for fieldname in [ + "in1_grad", + "in2_grad", + "weights_grad", + "output_grad", + "in1_double_grad", + "in2_double_grad", + "weights_double_grad", + ]: + self.check_result(result, fieldname) + class TestProductionModels(ConvCorrectness): production_model_tpps = ( @@ -244,6 +276,7 @@ def thresh(self, direction): "fwd": 1e-5, "bwd": 7.5e-2, # Expect higher errors for shared weights "double_bwd": 5e-1, + "triple_bwd": 5e-1, }[direction] @pytest.fixture(params=problems, ids=lambda x: x.label, scope="class")