Source code for emle.models._ani

#######################################################################
# EMLE-Engine: https://github.com/chemle/emle-engine
#
# Copyright: 2023-2025
#
# Authors: Lester Hedges   <lester.hedges@gmail.com>
#          Kirill Zinovjev <kzinovjev@gmail.com>
#
# EMLE-Engine is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 2 of the License, or
# (at your option) any later version.
#
# EMLE-Engine is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with EMLE-Engine. If not, see <http://www.gnu.org/licenses/>.
#####################################################################

"""ANI2xEMLE model implementation."""

__author__ = "Lester Hedges"
__email__ = "lester.hedges@gmail.com"

__all__ = ["ANI2xEMLE"]

import numpy as _np
import torch as _torch
import torchani as _torchani

from torch import Tensor
from typing import Optional, Tuple

from ._emle import EMLE as _EMLE

try:
    import NNPOps as _NNPOps

    _has_nnpops = True
except ImportError:
    _has_nnpops = False


class ANI2xEMLE(_torch.nn.Module):
    """
    Combined ANI2x and EMLE model. Predicts the in vacuo ANI2x energy
    along with static and induced EMLE energy components.
    """

    # Class attributes.

    # A flag for type inference. TorchScript doesn't support inheritance, so
    # we need to check for an object of type torch.nn.Module, and that it has
    # the required _is_emle attribute.
    _is_emle = True

    def __init__(
        self,
        emle_model=None,
        emle_method="electrostatic",
        alpha_mode="species",
        mm_charges=None,
        qm_charge=0,
        model_index=None,
        ani2x_model=None,
        atomic_numbers=None,
        device=None,
        dtype=None,
    ):
        """
        Constructor.

        Parameters
        ----------

        emle_model: str
            Path to a custom EMLE model parameter file. If None, then the
            default model for the specified 'alpha_mode' will be used.

        emle_method: str
            The desired embedding method. Options are:
                "electrostatic":
                    Full ML electrostatic embedding.
                "mechanical":
                    ML predicted charges for the core, but zero valence
                    charge.
                "nonpol":
                    Non-polarisable ML embedding. Here the induced component
                    of the potential is zeroed.
                "mm":
                    MM charges are used for the core charge and valence
                    charges are set to zero.

        alpha_mode: str
            How atomic polarizabilities are calculated.
                "species":
                    One volume scaling factor is used for each species.
                "reference":
                    Scaling factors are obtained with GPR using the values
                    learned for each reference environment.

        mm_charges: List[float], Tuple[float], numpy.ndarray, torch.Tensor
            List of MM charges for atoms in the QM region in units of mod
            electron charge. This is required if the 'mm' method is specified.

        qm_charge: int
            The charge on the QM region. This can also be passed when calling
            the forward method. The non-default value will take precedence.

        model_index: int
            The index of the ANI2x model to use. If None, then the full
            8 model ensemble will be used.

        ani2x_model: torchani.models.ANI2x, NNPOps.OptimizedTorchANI
            An existing ANI2x model to use. If None, a new ANI2x model will
            be created. If using an OptimizedTorchANI model, please ensure
            that the ANI2x model from which it derived was created using
            periodic_table_index=True.

        atomic_numbers: List[int], Tuple[int], numpy.ndarray, torch.Tensor (N_ATOMS,)
            Atomic numbers for the QM region. This allows use of optimised
            AEV symmetry functions from the NNPOps package. Only use this
            option if you are using a fixed QM region, i.e. the same QM
            region for each evaluation of the module.

        device: torch.device
            The device on which to run the model.

        dtype: torch.dtype
            The data type to use for the model's floating point tensors.
        """

        # Call the base class constructor.
        super().__init__()

        if model_index is not None:
            if not isinstance(model_index, int):
                raise TypeError("'model_index' must be of type 'int'")
            if model_index < 0 or model_index > 7:
                raise ValueError("'model_index' must be in the range [0, 7]")
        self._model_index = model_index

        if device is not None:
            if not isinstance(device, _torch.device):
                raise TypeError("'device' must be of type 'torch.device'")
        else:
            device = _torch.get_default_device()
        self._device = device

        if dtype is not None:
            if not isinstance(dtype, _torch.dtype):
                raise TypeError("'dtype' must be of type 'torch.dtype'")
        else:
            dtype = _torch.get_default_dtype()

        if atomic_numbers is not None:
            if isinstance(atomic_numbers, _np.ndarray):
                atomic_numbers = atomic_numbers.tolist()
            if isinstance(atomic_numbers, (list, tuple)):
                if not all(isinstance(i, int) for i in atomic_numbers):
                    raise ValueError("'atomic_numbers' must be a list of integers")
                else:
                    atomic_numbers = _torch.tensor(atomic_numbers, dtype=_torch.int64)
            if not isinstance(atomic_numbers, _torch.Tensor):
                raise TypeError("'atomic_numbers' must be of type 'torch.Tensor'")
            # Check that they are integers.
            if atomic_numbers.dtype != _torch.int64:
                raise ValueError("'atomic_numbers' must be of dtype 'torch.int64'")
            self._atomic_numbers = atomic_numbers.to(device)
        else:
            self._atomic_numbers = None

        # Create an instance of the EMLE model.
        self._emle = _EMLE(
            model=emle_model,
            method=emle_method,
            alpha_mode=alpha_mode,
            atomic_numbers=(atomic_numbers if atomic_numbers is not None else None),
            mm_charges=mm_charges,
            qm_charge=qm_charge,
            device=device,
            dtype=dtype,
            create_aev_calculator=False,
        )

        # Initialise the NNPOps flag.
        self._is_nnpops = False

        if ani2x_model is not None:
            # The base ANI2x model and ensemble types are always allowed.
            allowed_types = [
                _torchani.models.BuiltinModel,
                _torchani.models.BuiltinEnsemble,
            ]

            # Add the optimised model type if NNPOps is available.
            if _has_nnpops:
                allowed_types.append(_NNPOps.OptimizedTorchANI)

            if not isinstance(ani2x_model, tuple(allowed_types)):
                raise TypeError(f"'ani2x_model' must be of type {allowed_types}")

            if (
                isinstance(
                    ani2x_model,
                    (_torchani.models.BuiltinModel, _torchani.models.BuiltinEnsemble),
                )
                and not ani2x_model.periodic_table_index
            ):
                raise ValueError(
                    "The ANI2x model must be created with 'periodic_table_index=True'"
                )

            self._ani2x = ani2x_model.to(device)
            if dtype == _torch.float64:
                self._ani2x = self._ani2x.double()
        else:
            # Create the ANI2x model.
            self._ani2x = _torchani.models.ANI2x(
                periodic_table_index=True, model_index=model_index
            ).to(device)
            if dtype == _torch.float64:
                self._ani2x = self._ani2x.double()

            # Optimise the ANI2x model if atomic_numbers are specified.
            if _has_nnpops and atomic_numbers is not None:
                try:
                    atomic_numbers = atomic_numbers.reshape(
                        1, *atomic_numbers.shape
                    ).to(self._device)
                    self._ani2x = _NNPOps.OptimizedTorchANI(
                        self._ani2x, atomic_numbers
                    ).to(device)

                    # Flag that the model has been optimised with NNPOps.
                    self._is_nnpops = True
                except Exception as e:
                    raise RuntimeError(
                        "Failed to optimise the ANI2x model with NNPOps."
                    ) from e

        # Add a hook to the ANI2x model to capture the AEV features.
        self._add_hook()

    def _add_hook(self):
        """
        Add a hook to the ANI2x model to capture the AEV features.
        """
        # Assign a tensor attribute that can be used for assigning the AEVs.
        self._ani2x.aev_computer._aev = _torch.empty(0, device=self._device)

        # Hook the forward pass of the ANI2x model to get the AEV features.
        # Note that this currently requires patched versions of TorchANI and
        # NNPOps.
        if _has_nnpops and isinstance(self._ani2x, _NNPOps.OptimizedTorchANI):

            def hook(
                module,
                input: Tuple[
                    Tuple[Tensor, Tensor], Optional[Tensor], Optional[Tensor]
                ],
                output: Tuple[Tensor, Tensor],
            ):
                module._aev = output[1]

        else:

            def hook(
                module,
                input: Tuple[
                    Tuple[Tensor, Tensor], Optional[Tensor], Optional[Tensor]
                ],
                output: _torchani.aev.SpeciesAEV,
            ):
                module._aev = output[1]

        # Register the hook.
        self._aev_hook = self._ani2x.aev_computer.register_forward_hook(hook)
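
    # A minimal sketch of the forward-hook pattern used by _add_hook, assuming
    # only plain PyTorch (the names below are illustrative, not part of this
    # module):
    #
    #     import torch
    #
    #     layer = torch.nn.Linear(4, 2)
    #
    #     def stash(module, input, output):
    #         # Cache part of the forward output on the module itself.
    #         module._cached = output
    #
    #     handle = layer.register_forward_hook(stash)
    #     _ = layer(torch.randn(1, 4))
    #     assert layer._cached.shape == (1, 2)
    #
    # This is how the AEVs computed during the ANI2x forward pass are reused
    # by the EMLE model without a second AEV evaluation.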

    def to(self, *args, **kwargs):
        """
        Performs Tensor dtype and/or device conversion on the model.
        """
        self._emle = self._emle.to(*args, **kwargs)
        self._ani2x = self._ani2x.to(*args, **kwargs)

        # Check for a device type in args and update the device attribute.
        for arg in args:
            if isinstance(arg, _torch.device):
                self._device = arg
                break

        return self

    def cpu(self, **kwargs):
        """
        Move all model parameters and buffers to CPU memory.
        """
        self._emle = self._emle.cpu(**kwargs)
        self._ani2x = self._ani2x.cpu(**kwargs)
        self._device = _torch.device("cpu")
        return self

    def cuda(self, **kwargs):
        """
        Move all model parameters and buffers to CUDA memory.
        """
        self._emle = self._emle.cuda(**kwargs)
        self._ani2x = self._ani2x.cuda(**kwargs)
        self._device = _torch.device("cuda")
        return self

    def double(self):
        """
        Casts all model parameters and buffers to float64 precision.
        """
        self._emle = self._emle.double()
        self._ani2x = self._ani2x.double()
        return self

    def float(self):
        """
        Casts all model parameters and buffers to float32 precision.
        """
        self._emle = self._emle.float()

        # Using .float() or .to(torch.float32) is broken for ANI2x models, so
        # we instead create a fresh float32 model.
        self._ani2x = _torchani.models.ANI2x(
            periodic_table_index=True, model_index=self._model_index
        ).to(self._device)

        # Optimise the ANI2x model if atomic_numbers were specified.
        if self._atomic_numbers is not None:
            try:
                from NNPOps import OptimizedTorchANI as _OptimizedTorchANI

                species = self._atomic_numbers.reshape(
                    1, *self._atomic_numbers.shape
                )
                self._ani2x = _OptimizedTorchANI(self._ani2x, species).to(
                    self._device
                )
            except ImportError:
                pass

        # Re-attach the hook.
        self._add_hook()

        return self

    def forward(
        self,
        atomic_numbers: Tensor,
        charges_mm: Tensor,
        xyz_qm: Tensor,
        xyz_mm: Tensor,
        qm_charge: int = 0,
    ) -> Tensor:
        """
        Compute the ANI2x and static and induced EMLE energy components.

        Parameters
        ----------

        atomic_numbers: torch.Tensor (N_QM_ATOMS,) or (BATCH, N_QM_ATOMS)
            Atomic numbers of QM atoms. A non-existent atom is represented
            by -1.

        charges_mm: torch.Tensor (max_mm_atoms,) or (BATCH, max_mm_atoms)
            MM point charges in atomic units.

        xyz_qm: torch.Tensor (N_QM_ATOMS, 3) or (BATCH, N_QM_ATOMS, 3)
            Positions of QM atoms in Angstrom.

        xyz_mm: torch.Tensor (N_MM_ATOMS, 3) or (BATCH, N_MM_ATOMS, 3)
            Positions of MM atoms in Angstrom.

        qm_charge: int or torch.Tensor (BATCH,)
            The charge on the QM region.

        Returns
        -------

        result: torch.Tensor (3,) or (3, BATCH)
            The ANI2x and static and induced EMLE energy components in
            Hartree.
        """

        # Batch the inputs if necessary.
        if atomic_numbers.ndim == 1:
            atomic_numbers = atomic_numbers.unsqueeze(0)
            xyz_qm = xyz_qm.unsqueeze(0)
            xyz_mm = xyz_mm.unsqueeze(0)
            charges_mm = charges_mm.unsqueeze(0)
        elif self._is_nnpops:
            raise RuntimeError(
                "Batched inputs are not supported when using NNPOps optimised models."
            )

        # Get the in vacuo energy.
        E_vac = self._ani2x((atomic_numbers, xyz_qm)).energies

        # If there are no point charges, return the in vacuo energy and zeros
        # for the static and induced terms.
        if xyz_mm.shape[1] == 0:
            zero = _torch.zeros(
                atomic_numbers.shape[0], dtype=xyz_qm.dtype, device=xyz_qm.device
            )
            return _torch.stack((E_vac, zero, zero))

        # Set the AEVs captured by the forward hook as an attribute of the
        # EMLE AEVComputer instance.
        self._emle._emle_base._emle_aev_computer._aev = self._ani2x.aev_computer._aev

        # Get the EMLE energy components.
        E_emle = self._emle(atomic_numbers, charges_mm, xyz_qm, xyz_mm, qm_charge)

        # Return the ANI2x and EMLE energy components.
        return _torch.stack((E_vac, E_emle[0], E_emle[1]))
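
# A minimal usage sketch, assuming a CPU device and float64 precision (the
# coordinates and charges below are hypothetical placeholders, not reference
# values):
#
#     import torch
#
#     model = ANI2xEMLE(device=torch.device("cpu"), dtype=torch.float64)
#
#     atomic_numbers = torch.tensor([8, 1, 1])               # QM water
#     charges_mm = torch.tensor([-0.834, 0.417, 0.417])      # one MM water
#     xyz_qm = torch.randn(3, 3, dtype=torch.float64)        # Angstrom
#     xyz_mm = torch.randn(3, 3, dtype=torch.float64) + 5.0  # Angstrom
#
#     E_vac, E_static, E_induced = model(
#         atomic_numbers, charges_mm, xyz_qm, xyz_mm
#     )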