Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 51 additions & 9 deletions src/somd2/runner/_repex.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,8 +468,40 @@ def save_openmm_state(self, index):
.getState(getPositions=True, getVelocities=True)
)

# Store the state.
self._openmm_states[index] = state
# Store positions, velocities, and box vectors as compact numpy arrays
# rather than the OpenMM State object, which serialises to XML when
# pickled and is orders of magnitude larger.
self._openmm_states[index] = {
"positions": state.getPositions(asNumpy=True),
"velocities": state.getVelocities(asNumpy=True),
"box": state.getPeriodicBoxVectors(asNumpy=True),
}

@staticmethod
def _apply_openmm_state(context, state):
"""
Apply a saved OpenMM state to a context.

Parameters
----------

context: openmm.Context
The OpenMM context to update.

state: dict or openmm.State
The state to apply. Dicts (new format) contain "positions",
"velocities", and "box" numpy arrays. A bare openmm.State is
accepted for backwards compatibility with old checkpoint files.
"""
if isinstance(state, dict):
context.setPositions(state["positions"])
context.setVelocities(state["velocities"])
if state["box"] is not None:
context.setPeriodicBoxVectors(*state["box"])
else:
# Legacy openmm.State from checkpoint files written before this
# format change.
context.setState(state)

def save_gcmc_state(self, index):
"""
Expand Down Expand Up @@ -520,7 +552,9 @@ def mix_states(self):
# The state has changed.
if i != state:
_logger.debug(f"Replica {i} seeded from state {state}")
self._dynamics[i].context().setState(self._openmm_states[state])
self._apply_openmm_state(
self._dynamics[i].context(), self._openmm_states[state]
)

# Swap the water state in the GCMCSamplers.
if self._gcmc_samplers[i] is not None:
Expand Down Expand Up @@ -821,7 +855,9 @@ def __init__(self, system, config):
# Reset the OpenMM state, applying the last replica exchange
# mixing so the correct post-mix state is restored.
state = self._dynamics_cache._states[i]
dynamics.context().setState(self._dynamics_cache._openmm_states[state])
DynamicsCache._apply_openmm_state(
dynamics.context(), self._dynamics_cache._openmm_states[state]
)

# Reset the GCMC water state and restore statistics.
if gcmc_sampler is not None:
Expand Down Expand Up @@ -1222,9 +1258,11 @@ def run(self):
# Snapshot the pre-run state for crash recovery.
if self._config.auto_fix_minimise:
for i, state in enumerate(self._dynamics_cache.get_states()):
self._dynamics_cache._dynamics[
i
]._d._pre_run_state = self._dynamics_cache._openmm_states[state]
self._dynamics_cache._dynamics[i]._d._pre_run_state = (
self._dynamics_cache._dynamics[i]
.context()
.getState(getPositions=True, getVelocities=True)
)

# This is a checkpoint cycle.
if is_checkpoint:
Expand Down Expand Up @@ -1734,14 +1772,18 @@ def _compute_energies(self, index):
# Loop over the states.
for i in range(self._config.num_lambda):
# Set the state.
dynamics.context().setState(self._dynamics_cache._openmm_states[i])
DynamicsCache._apply_openmm_state(
dynamics.context(), self._dynamics_cache._openmm_states[i]
)
dynamics._d._clear_state()

# Compute and store the energy for this state.
energies[i] = dynamics.current_potential_energy().value()

# Reset the state.
dynamics.context().setState(self._dynamics_cache._openmm_states[index])
DynamicsCache._apply_openmm_state(
dynamics.context(), self._dynamics_cache._openmm_states[index]
)

return index, energies

Expand Down
Loading