From dd4e45ad14565609e15c6bd8897a424294a3d6d2 Mon Sep 17 00:00:00 2001 From: domfournier Date: Fri, 19 Jun 2026 08:44:37 -0700 Subject: [PATCH 1/2] Revert "Add checks agains empty or singular blocks" This reverts commit 32dcf7cd5131ab0bbf9a1d9b0f73b19ae4f68272. --- .../dask/electromagnetics/time_domain/simulation.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/simpeg/dask/electromagnetics/time_domain/simulation.py b/simpeg/dask/electromagnetics/time_domain/simulation.py index 3574a5de0d..ee114837ea 100644 --- a/simpeg/dask/electromagnetics/time_domain/simulation.py +++ b/simpeg/dask/electromagnetics/time_domain/simulation.py @@ -157,6 +157,7 @@ def compute_J(self, m, f=None): AdiagTinv, ATinv_df_duT_v[ind], time_mask, + client, ) if client: @@ -166,7 +167,7 @@ def compute_J(self, m, f=None): for block_ind in range(len(blocks)): - if len(blocks[block_ind]) == 0: + if len(block) == 0: continue if client: @@ -337,13 +338,11 @@ def get_field_deriv_block( AdiagTinv, ATinv_df_duT_v, time_mask, + client, ): """ Stack the blocks of field derivatives for a given timestep and call the direct solver. """ - if len(block) == 0: - return None - Asubdiag = None if tInd < self.nT - 1: Asubdiag = self.getAsubdiag(tInd + 1) @@ -376,10 +375,8 @@ def get_field_deriv_block( if len(ATinv_df_duT_v) == 0: ATinv_df_duT_v = np.zeros((field_deriv.shape[0], colm_count), dtype=np.float32) - if len(time_blocks) > 0: - solve = AdiagTinv * np.hstack(time_blocks).reshape( - (ATinv_df_duT_v.shape[0], -1) - ) + if len(time_blocks) > 1: + solve = AdiagTinv * np.hstack(time_blocks) ATinv_df_duT_v[:, np.hstack(colm_indices)] = solve return ATinv_df_duT_v From b0b9318c77a004e9889b45594ddc6f41ca2c1e4e Mon Sep 17 00:00:00 2001 From: domfournier Date: Fri, 19 Jun 2026 08:46:30 -0700 Subject: [PATCH 2/2] Revert "Compute field derivs as single block. Reduce calls to direct solver" This reverts commit 724c5ff2d33410987308fe37323fe87fd0278990. # Conflicts: # simpeg/dask/electromagnetics/time_domain/simulation.py --- .../time_domain/simulation.py | 62 ++++++++----------- 1 file changed, 25 insertions(+), 37 deletions(-) diff --git a/simpeg/dask/electromagnetics/time_domain/simulation.py b/simpeg/dask/electromagnetics/time_domain/simulation.py index ee114837ea..5986f57ffa 100644 --- a/simpeg/dask/electromagnetics/time_domain/simulation.py +++ b/simpeg/dask/electromagnetics/time_domain/simulation.py @@ -99,7 +99,6 @@ def compute_J(self, m, f=None): self.survey.source_list, compute_row_size, thread_count=self.n_threads(client=client, worker=worker), - optimize=False, ) fields_array = f[:, ftype, :] @@ -343,43 +342,43 @@ def get_field_deriv_block( """ Stack the blocks of field derivatives for a given timestep and call the direct solver. """ + if len(ATinv_df_duT_v) == 0: + ATinv_df_duT_v = [[] for _ in block] + Asubdiag = None if tInd < self.nT - 1: Asubdiag = self.getAsubdiag(tInd + 1) - time_blocks = [] - colm_indices = [] - colm_count = 0 - for (_, (rx_ind, _, shape)), field_deriv in zip(block, field_derivs): + updated_ATinv_df_duT_v = [] + + for (_, (rx_ind, _, shape)), field_deriv, ATinv_chunk in zip( + block, field_derivs, ATinv_df_duT_v + ): # Cut out early data time_check = np.kron(time_mask, np.ones(shape, dtype=bool))[rx_ind] local_ind = np.arange(rx_ind.shape[0])[time_check] - if len(ATinv_df_duT_v) == 0: + if len(ATinv_chunk) == 0: # last timestep (first to be solved) time_block = field_deriv.toarray()[:, local_ind] + shape = ( + field_deriv.shape[0], + len(rx_ind), + ) + ATinv_chunk = np.zeros(shape, dtype=np.float32) else: time_block = np.asarray( - field_deriv[:, local_ind] - - Asubdiag.T - * ATinv_df_duT_v[:, colm_count : colm_count + rx_ind.shape[0]][ - :, local_ind - ] + field_deriv[:, local_ind] - Asubdiag.T * ATinv_chunk[:, local_ind] ) - time_blocks.append(time_block) - colm_indices.append(local_ind + colm_count) - colm_count += rx_ind.shape[0] - - if len(ATinv_df_duT_v) == 0: - ATinv_df_duT_v = np.zeros((field_deriv.shape[0], colm_count), dtype=np.float32) + if time_block.ndim == 2 and time_block.shape[1] > 0: + solve = (AdiagTinv * time_block).reshape(time_block.shape) + ATinv_chunk[:, local_ind] = solve - if len(time_blocks) > 1: - solve = AdiagTinv * np.hstack(time_blocks) - ATinv_df_duT_v[:, np.hstack(colm_indices)] = solve + updated_ATinv_df_duT_v.append(ATinv_chunk) - return ATinv_df_duT_v + return updated_ATinv_df_duT_v def block_deriv( @@ -463,14 +462,11 @@ def compute_rows( Compute the rows of the sensitivity matrix for a given source and receiver. """ rows = [] - colm_count = 0 - for address, ind_array in blocks[block_ind]: + for ind, (address, ind_array) in enumerate(blocks[block_ind]): # for (address, ind_array), field_derivs in zip(chunks, ATinv_df_duT_v): src = simulation.survey.source_list[address[0]] time_check = np.kron(time_mask, np.ones(ind_array[2], dtype=bool))[ind_array[0]] - - n_rec = len(ind_array[0]) - local_ind = np.arange(n_rec)[time_check] + local_ind = np.arange(len(ind_array[0]))[time_check] if len(local_ind) < 1: row_block = np.zeros( @@ -482,24 +478,18 @@ def compute_rows( dAsubdiagT_dm_v = simulation.getAsubdiagDeriv( tInd, fields[:, address[0], tInd], - field_derivs[block_ind][:, colm_count : colm_count + n_rec][:, local_ind], + field_derivs[block_ind][ind][:, local_ind], adjoint=True, ) dRHST_dm_v = simulation.getRHSDeriv( - tInd + 1, - src, - field_derivs[block_ind][:, colm_count : colm_count + n_rec][:, local_ind], - adjoint=True, + tInd + 1, src, field_derivs[block_ind][ind][:, local_ind], adjoint=True ) # on nodes of time mesh un_src = fields[:, address[0], tInd + 1] # cell centered on time mesh dAT_dm_v = simulation.getAdiagDeriv( - tInd, - un_src, - field_derivs[block_ind][:, colm_count : colm_count + n_rec][:, local_ind], - adjoint=True, + tInd, un_src, field_derivs[block_ind][ind][:, local_ind], adjoint=True ) row_block = np.zeros( (len(ind_array[1]), simulation.model.size), dtype=np.float32 @@ -516,8 +506,6 @@ def compute_rows( else: Jmatrix[ind_array[1], :] += row_block - colm_count += n_rec - def evaluate_dpred_block(indices, sources, mesh, time_mesh, fields): """