Inference Mode
Inference mode allows us to use a trained model to make predictions. Because a key use of inference mode is molecular simulation, more efficient schemes for calculating interacting pairs are needed.
Neighborlists
Currently, two neighborlist strategies are implemented in modelforge for inference: the brute force neighborlist and the Verlet neighborlist (both implemented within a single class, NeighborlistForInference). Both neighborlists support periodic and non-periodic orthogonal boxes.
The neighborlist strategy can be selected during potential setup via the inference_neighborlist_strategy parameter passed to the NeuralNetworkPotentialFactory. The default is the Verlet neighborlist (“verlet_nsq”); the brute force neighborlist can be selected via “brute_nsq”. The strategy can also be changed at run time via Potential.set_neighborlist_strategy(strategy, skin).
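To make the periodic case concrete, the following is a minimal pure-Python sketch of a pair distance in an orthogonal box. It is not the modelforge implementation; it assumes the common minimum image convention, and the function names minimum_image and pair_distance are hypothetical.

```python
def minimum_image(delta, box_length, periodic=True):
    """Wrap one displacement component into [-L/2, L/2] for an orthogonal box.

    Assumption: the minimum image convention; not taken from modelforge.
    """
    if not periodic:
        return delta
    return delta - box_length * round(delta / box_length)


def pair_distance(r_i, r_j, box_lengths, periodic=True):
    """Distance between two 3D points, wrapping each axis independently."""
    squared = 0.0
    for a, b, length in zip(r_i, r_j, box_lengths):
        d = minimum_image(b - a, length, periodic)
        squared += d * d
    return squared ** 0.5


box = (1.0, 1.0, 1.0)
d_periodic = pair_distance((0.1, 0.0, 0.0), (0.9, 0.0, 0.0), box, periodic=True)
d_open = pair_distance((0.1, 0.0, 0.0), (0.9, 0.0, 0.0), box, periodic=False)
print(d_periodic, d_open)
```

Because the box is orthogonal, each axis can be wrapped on its own; in the periodic case the two points above interact through the box boundary and are closer than their direct separation.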
Brute force neighborlist
The brute force neighborlist calculates the pairs within the interaction cutoff by considering all possible pairs each time it is called, an order N^2 operation. Given this scaling, the approach should typically only be used for very small systems; furthermore, the N^2 approach used to generate the list consumes a large amount of memory as the system size grows.
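The all-pairs search can be sketched in a few lines of pure Python. This is an illustration of the idea only, for a non-periodic system; the function name brute_force_pairs is hypothetical and the actual modelforge implementation operates on Torch tensors.

```python
from itertools import combinations
from math import dist


def brute_force_pairs(positions, cutoff):
    """Return all index pairs (i, j), i < j, separated by less than `cutoff`.

    Every one of the N * (N - 1) / 2 candidate pairs is examined on each call.
    """
    return [
        (i, j)
        for (i, r_i), (j, r_j) in combinations(enumerate(positions), 2)
        if dist(r_i, r_j) < cutoff
    ]


positions = [(0.0, 0.0, 0.0), (0.3, 0.0, 0.0), (1.0, 0.0, 0.0), (1.2, 0.0, 0.0)]
pairs = brute_force_pairs(positions, cutoff=0.5)
print(pairs)  # [(0, 1), (2, 3)]
```

For these four particles, six candidate pairs are tested to find two interacting ones; the number of tests grows quadratically with the particle count.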
Verlet neighborlist
The Verlet neighborlist operates under the assumption that, over short time windows, the local environment around a given particle does not change significantly. As such, information about this local environment can be reused between subsequent steps, eliminating the need for a costly build at every step.
To do this, the local environment of a given particle is identified and saved in a list (which we will call the Verlet list), using the criterion pair distance < cutoff + skin. The skin is a user-modifiable distance that captures a region of space beyond the interaction cutoff. In the current implementation, the Verlet list is generated using the same order N^2 approach as the brute force scheme. Because positions are correlated in time, we can typically avoid performing another order N^2 calculation for several timesteps. Steps in between rebuilds scale as order N*M, where M is the average number of neighbors (typically much less than N). In our implementation, the Verlet list is automatically regenerated when any particle has moved more than skin/2 since the last build. This ensures that interactions are not missed, since two particles that each move at most skin/2 can reduce their separation by at most skin.
Larger values of skin result in longer time periods between rebuilds, but also typically increase the number of calculations that need to be performed at each timestep (as M will typically be larger). As such, this value can have a significant impact on the performance of this calculation.
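The build and reuse logic described above can be sketched as follows. This is a toy pure-Python version for a non-periodic system, written under the assumptions stated in the text (an all-pairs build with radius cutoff + skin, and a rebuild once any particle has moved more than skin/2); the class name VerletListSketch and its attributes are hypothetical and do not mirror NeighborlistForInference.

```python
from itertools import combinations
from math import dist


class VerletListSketch:
    """Toy Verlet list for a non-periodic system (illustration only)."""

    def __init__(self, cutoff, skin):
        self.cutoff = cutoff
        self.skin = skin
        self.reference = None  # positions recorded at the last build
        self.candidates = []   # pairs within cutoff + skin at the last build
        self.n_builds = 0

    def _build(self, positions):
        # Order N^2 step: all-pairs search with the enlarged radius.
        radius = self.cutoff + self.skin
        self.candidates = [
            (i, j)
            for i, j in combinations(range(len(positions)), 2)
            if dist(positions[i], positions[j]) < radius
        ]
        self.reference = [tuple(r) for r in positions]
        self.n_builds += 1

    def _needs_rebuild(self, positions):
        if self.reference is None or len(self.reference) != len(positions):
            return True
        # Rebuild once any particle has moved more than skin / 2.
        return any(
            dist(r, r0) > 0.5 * self.skin
            for r, r0 in zip(positions, self.reference)
        )

    def pairs(self, positions):
        if self._needs_rebuild(positions):
            self._build(positions)
        # Order N*M step: only the stored candidate pairs are re-checked.
        return [
            (i, j)
            for i, j in self.candidates
            if dist(positions[i], positions[j]) < self.cutoff
        ]


nlist = VerletListSketch(cutoff=0.5, skin=0.2)
step0 = nlist.pairs([(0.0, 0.0, 0.0), (0.45, 0.0, 0.0), (2.0, 0.0, 0.0)])
step1 = nlist.pairs([(0.0, 0.0, 0.0), (0.52, 0.0, 0.0), (2.0, 0.0, 0.0)])
builds_after_small_move = nlist.n_builds
step2 = nlist.pairs([(0.0, 0.0, 0.0), (0.60, 0.0, 0.0), (2.0, 0.0, 0.0)])
print(step0, step1, step2, nlist.n_builds)
```

In this usage, the second call reuses the stored list because the largest displacement (0.07) is below skin/2 = 0.1, while the third call triggers a rebuild.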
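A rough estimate shows how quickly the per-step cost grows with the skin. The sketch below assumes a uniform particle density, so that the number of stored neighbors scales with the volume of a sphere of radius cutoff + skin; the function name list_overhead and the numerical values are illustrative only.

```python
def list_overhead(cutoff, skin):
    """Ratio of stored neighbors (radius cutoff + skin) to true neighbors
    (radius cutoff), assuming a uniform particle density."""
    return ((cutoff + skin) / cutoff) ** 3


cutoff = 0.5  # illustrative value
overheads = {skin: list_overhead(cutoff, skin) for skin in (0.0, 0.1, 0.2)}
for skin, ratio in overheads.items():
    print(f"skin={skin:.1f}: about {ratio:.2f}x pairs checked per step")
```

Under this assumption, a skin of 40% of the cutoff already means nearly three times as many pair distances are evaluated per step, which has to be weighed against the less frequent order N^2 rebuilds.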
Note: Since this utilizes an N^2 computation within Torch, the memory footprint may be problematic as system size grows. A cell list based approach will be implemented in the future.
Load inference potential from training checkpoint
To use the trained model for inference, the checkpoint file generated during training must be loaded. The checkpoint file contains the model’s weights, optimizer state, and other training-related information. The load_inference_model_from_checkpoint function provides a convenient way to load the checkpoint file and generate an inference model.
from modelforge.potential.potential import load_inference_model_from_checkpoint
inference_model = load_inference_model_from_checkpoint(checkpoint_file)
Note: prior to the merge of PR #299, checkpoint files and state_dicts did not save the only_unique_pairs bool parameter, which is needed to properly generate neighbor information. If you are using a checkpoint file generated prior to this PR, you will need to set this parameter manually by passing only_unique_pairs to the load_inference_model_from_checkpoint function. For example, for ANI2x models, where this should be True (other currently implemented potentials require False):
from modelforge.potential.potential import load_inference_model_from_checkpoint
inference_model = load_inference_model_from_checkpoint(checkpoint_file, only_unique_pairs=True)
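To illustrate why this flag matters for neighbor information, here is a small sketch of the two pair conventions. It assumes that "unique" means each pair is listed once with i < j, and that otherwise both (i, j) and (j, i) are listed; the helper pair_indices is hypothetical and is not part of modelforge.

```python
from itertools import combinations, permutations


def pair_indices(n_atoms, only_unique_pairs):
    """Enumerate atom index pairs under the two conventions (illustration only)."""
    if only_unique_pairs:
        # each pair appears once, with i < j
        return list(combinations(range(n_atoms), 2))
    # each pair appears in both directions
    return list(permutations(range(n_atoms), 2))


unique = pair_indices(3, only_unique_pairs=True)
full = pair_indices(3, only_unique_pairs=False)
print(unique)  # [(0, 1), (0, 2), (1, 2)]
print(full)    # [(0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1)]
```

A model trained with one convention would receive half or double the expected number of pairs under the other, which is why the value must match the potential.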
State dictionary files can be modified with the modify_state_dict function in the file modify_state_dict.py in the scripts directory. This generates a new copy of the state dictionary file with the appropriate only_unique_pairs parameter set.
Loading a checkpoint from weights and biases
Checkpoint files can be loaded directly from wandb using the load_from_wandb function of the NeuralNetworkPotentialFactory, by passing the wandb run path and the appropriate version number. Note that this requires authentication with wandb as a user who is part of the project. The following code snippet demonstrates how to load a model from wandb.
from modelforge.potential.potential import NeuralNetworkPotentialFactory
# prep_temp_dir is assumed to be the path of an existing directory defined elsewhere
nn_potential = NeuralNetworkPotentialFactory().load_from_wandb(
    run_path="modelforge_nnps/test_ANI2x_on_dataset/model-qloqn6gk",
    version="v0",
    local_cache_dir=f"{prep_temp_dir}/test_wandb",
)
Using a model for inference in ASE
To use the trained model for inference in ASE, we need to load a modelforge potential and wrap it with the ASE calculator wrapper in modelforge.ase.calculator, which exposes the potential through the ASE calculator interface. The following code snippet demonstrates loading a checkpoint file, generating a potential, setting up the ASE calculator wrapper, and using it to calculate the energy and forces for a given system.
from modelforge.utils.io import get_path_string
from modelforge.ase.tests import data
from modelforge.potential.potential import load_inference_model_from_checkpoint
# checkpoint file is saved in tests/data
checkpoint_file_path = get_path_string(data) + "/model.ckpt"
potential = load_inference_model_from_checkpoint(checkpoint_file_path, jit=False)
# to use the potential with ASE we need to import the ModelForgeCalculator class,
# which wraps the potential in an ASE-compatible calculator interface
from modelforge.ase.calculator import ModelForgeCalculator
# let us use one of ase's built in molecules
from ase.build import molecule
atoms = molecule("H2O")
atoms.calc = ModelForgeCalculator(potential)
# extract the energy and forces
pe = atoms.get_potential_energy()
forces = atoms.get_forces()
print("potential energy: ", pe)
print("forces: ", forces)
Additional examples are included in the examples folder within the modelforge.ase subpackage; they demonstrate how to use the ASE calculator wrapper with modelforge potentials for inference.
Using a model for inference in OpenMM
To use the trained model for inference in OpenMM, we can use an approach similar to that used for ASE; instead of the ASE calculator wrapper, we use a wrapper that interfaces with OpenMM’s PythonForce.
from modelforge.utils.io import get_path_string
from modelforge.openmm.examples import data
from modelforge.potential.potential import load_inference_model_from_checkpoint
# checkpoint file is located via the data module imported above
checkpoint_file_path = get_path_string(data) + "/model.ckpt"
potential = load_inference_model_from_checkpoint(checkpoint_file_path, jit=False)
# helper functions to load up a water topology and positions for use in OpenMM.
from modelforge.openmm.examples.openmm_water_topology import openmm_water_topology
water, positions = openmm_water_topology()
atomic_numbers = [atom.element.atomic_number for atom in water.atoms()]
# initialize the compute object that will be used in the OpenMM PythonForce.
from modelforge.openmm.potential import generate_compute
comp = generate_compute(potential=potential, atomic_numbers=atomic_numbers)
# set up the PythonForce with the compute object. This will allow us to use the potential for inference within OpenMM.
from openmm import PythonForce
system_force = PythonForce(comp)
# OpenMM simulation setup
import openmm
from openmm.unit import (
    kelvin,
    picosecond,
    femtosecond,
    nanometer,
    kilojoules_per_mole,
)
system = openmm.System()
# add one particle per atom, using the mass of its element
for atom in water.atoms():
    system.addParticle(atom.element.mass)
system.addForce(system_force)
import sys
from openmm import LangevinMiddleIntegrator
from openmm.app import Simulation, StateDataReporter
# Create an integrator with a time step of 0.01 fs
temperature = 298.15 * kelvin
frictionCoeff = 1 / picosecond
timeStep = 0.01 * femtosecond
integrator = LangevinMiddleIntegrator(temperature, frictionCoeff, timeStep)
# Create a simulation and set the initial positions
simulation = Simulation(water, system, integrator)
simulation.context.setPositions(positions)
reporter = StateDataReporter(
    file=sys.stdout,
    reportInterval=1,
    step=True,
    time=True,
    potentialEnergy=True,
)
simulation.reporters.append(reporter)
simulation.step(10)