Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 87 additions & 11 deletions src/somd2/runner/_repex.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,6 @@ def __init__(
self._lambdas = lambdas
self._rest2_scale_factors = rest2_scale_factors
self._states = _np.array(range(len(lambdas)))
self._old_states = _np.array(range(len(lambdas)))
self._openmm_states = [None] * len(lambdas)
self._gcmc_samplers = [None] * len(lambdas)
self._gcmc_states = [None] * len(lambdas)
Expand Down Expand Up @@ -150,7 +149,6 @@ def __getstate__(self):
"_lambdas": self._lambdas,
"_rest2_scale_factors": self._rest2_scale_factors,
"_states": self._states,
"_old_states": self._old_states,
"_openmm_states": self._openmm_states,
# Don't pickle the GCMC samplers since they need to be recreated.
"_gcmc_samplers": len(self._gcmc_samplers) * [None],
Expand Down Expand Up @@ -511,9 +509,14 @@ def set_states(self, states):
"""
self._states = states

def mix_states(self):
def mix_states(self, old_states):
"""
Mix the states of the dynamics objects.

Parameters
----------
old_states : numpy.ndarray
The state indices from before the last replica mix.
"""
# Mix the states.
for i, state in enumerate(self._states):
Expand Down Expand Up @@ -541,11 +544,7 @@ def mix_states(self):
self._gcmc_samplers[i].pop()

# Update the swap matrix.
old_state = self._old_states[i]
self._num_swaps[old_state, state] += 1

# Store the current states.
self._old_states = self._states.copy()
self._num_swaps[old_states[i], state] += 1

def get_proposed(self):
"""
Expand Down Expand Up @@ -716,6 +715,11 @@ def __init__(self, system, config):
# Store the name of the replica exchange swap acceptance matrix.
self._repex_matrix = self._config.output_directory / "repex_matrix.txt"

# Sentinel file written only after a fully successful run (dynamics +
# trajectory consolidation + backup cleanup). Used to distinguish
# "truly complete" from "complete dynamics but killed during cleanup".
self._done_file = self._config.output_directory / "simulation.done"

# Flag that we haven't equilibrated.
self._is_equilibration = False

Expand Down Expand Up @@ -756,6 +760,11 @@ def __init__(self, system, config):
}
)

# On a fresh (non-restart) run, remove any leftover sentinel so that
# a repeated run with --overwrite doesn't immediately exit as complete.
if not self._is_restart and self._done_file.exists():
self._done_file.unlink()

# Create the dynamics cache.
if not self._is_restart:
xml_filenames = (
Expand All @@ -777,10 +786,33 @@ def __init__(self, system, config):
else:
_logger.debug("Restarting from file")

# Check to see if the simulation is already complete.
time = self._system[0].time()

# Check to see if the simulation is already complete.
if self._done_file.exists():
# The runtime may have been extended beyond the previous run.
# If so, clear the sentinel and continue.
if time < self._config.runtime - self._config.timestep:
_logger.info(
"Runtime has been extended. Clearing completion sentinel."
)
self._done_file.unlink()
else:
_logger.success("Simulation already complete. Exiting.")
_sys.exit(0)

if time > self._config.runtime - self._config.timestep:
_logger.success("Simulation already complete. Exiting.")
# Dynamics finished but the process was killed before cleanup
# completed (e.g. during DCD consolidation or backup removal).
# Consolidate any remaining trajectory chunks and tidy up.
_logger.warning(
"Simulation dynamics are complete but post-run cleanup was "
"not finished. Completing cleanup now."
)
self._consolidate_trajectories()
self._cleanup()
self._done_file.touch()
_logger.success("Cleanup complete. Exiting.")
_sys.exit(0)
else:
_logger.info(
Expand Down Expand Up @@ -1209,6 +1241,7 @@ def run(self):

# Mix the replicas.
_logger.info("Mixing replicas")
old_states = self._dynamics_cache.get_states()
self._dynamics_cache.set_states(
self._mix_replicas(
self._config.num_lambda,
Expand All @@ -1217,7 +1250,7 @@ def run(self):
self._dynamics_cache.get_accepted(),
)
)
self._dynamics_cache.mix_states()
self._dynamics_cache.mix_states(old_states)

# Snapshot the pre-run state for crash recovery.
if self._config.auto_fix_minimise:
Expand Down Expand Up @@ -1300,6 +1333,10 @@ def run(self):
# Delete all backup files from the working directory.
self._cleanup()

# Write the sentinel file to signal that the run completed fully,
# including trajectory consolidation and cleanup.
self._done_file.touch()

def _run_block(
self,
index,
Expand Down Expand Up @@ -1872,6 +1909,45 @@ def _checkpoint(self, index, lambdas, block, num_blocks, is_final_block=False):
except Exception as e:
return index, e

def _consolidate_trajectories(self):
    """
    Consolidate any remaining trajectory chunk files into the final DCD.

    Called when a restart detects that dynamics completed but the process
    was killed before post-run cleanup finished. Replicas with no chunk
    files left are skipped without touching their existing final
    trajectory, so the call is a no-op for replicas that are already
    fully consolidated.

    Does nothing when ``save_trajectories`` is disabled in the config.
    """
    from glob import glob as _glob_local
    from pathlib import Path as _Path_local
    from shutil import copyfile as _copyfile_local

    if not self._config.save_trajectories:
        return

    for i in range(len(self._lambda_values)):
        traj_filename = self._filenames[i]["trajectory"]
        chunk_pattern = f"{self._filenames[i]['trajectory_chunk']}*"
        traj_chunks = sorted(_glob_local(chunk_pattern))

        # Nothing left to merge for this replica. This check must come
        # before the existing final DCD is considered: otherwise an already
        # consolidated trajectory would be copied, reloaded and rewritten
        # for no gain, and a kill during that rewrite could damage it.
        if not traj_chunks:
            continue

        # Prepend an existing final DCD as .prev so frames from a previous
        # (possibly partial) consolidation are preserved ahead of the
        # remaining chunks.
        path = _Path_local(traj_filename)
        if path.exists() and path.stat().st_size > 0:
            prev = f"{traj_filename}.prev"
            _copyfile_local(traj_filename, prev)
            traj_chunks = [prev] + traj_chunks

        # Load the topology together with the ordered trajectory pieces and
        # write them back out as a single DCD file.
        topology0 = self._filenames["topology0"]
        mols = _sr.load([topology0] + traj_chunks)
        _sr.save(mols.trajectory(), traj_filename, format=["DCD"])

        # Remove the merged pieces, including the temporary .prev copy.
        for chunk in traj_chunks:
            _Path_local(chunk).unlink()

@staticmethod
@_njit
def _mix_replicas(num_replicas, energy_matrix, proposed, accepted):
Expand Down
Loading