# Source code for emle.models._mace

#######################################################################
# EMLE-Engine: https://github.com/chemle/emle-engine
#
# Copyright: 2023-2025
#
# Authors: Lester Hedges   <lester.hedges@gmail.com>
#          Kirill Zinovjev <kzinovjev@gmail.com>
#
# EMLE-Engine is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 2 of the License, or
# (at your option) any later version.
#
# EMLE-Engine is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with EMLE-Engine. If not, see <http://www.gnu.org/licenses/>.
#####################################################################

"""MACEEMLE model implementation."""

__author__ = "Joao Morado"
# Contact address of the module author (no surrounding angle brackets).
__email__ = "joaomorado@gmail.com"

# Public API of this module: only the combined MACE + EMLE model is exported.
__all__ = ["MACEEMLE"]

import os as _os
from typing import List

import numpy as _np
import torch as _torch
from torch import Tensor

from ._emle import EMLE as _EMLE
from ._emle import _has_nnpops
from ._utils import _get_neighbor_pairs

# Optional dependencies. Their availability is recorded in module-level flags
# so that MACEEMLE can raise a helpful ImportError at construction time instead
# of failing at import time. Any import-time failure (not only ImportError,
# e.g. version incompatibilities) marks the dependency as unavailable, but
# KeyboardInterrupt/SystemExit are no longer swallowed.
try:
    from mace.calculators.foundations_models import mace_off as _mace_off

    _has_mace = True
except Exception:
    _has_mace = False

try:
    from e3nn.util import jit as _e3nn_jit

    _has_e3nn = True
except Exception:
    _has_e3nn = False


class MACEEMLE(_torch.nn.Module):
    """
    Combined MACE and EMLE model. Predicts the in vacuo MACE energy along with
    static and induced EMLE energy components.
    """

    # Class attributes.

    # A flag for type inference. TorchScript doesn't support inheritance, so
    # we need to check for an object of type torch.nn.Module, and that it has
    # the required _is_emle attribute.
    _is_emle = True

    def __init__(
        self,
        emle_model=None,
        emle_method="electrostatic",
        alpha_mode="species",
        mm_charges=None,
        qm_charge=0,
        mace_model=None,
        atomic_numbers=None,
        device=None,
        dtype=None,
    ):
        """
        Constructor.

        Parameters
        ----------

        emle_model: str
            Path to a custom EMLE model parameter file. If None, then the
            default model for the specified 'alpha_mode' will be used.

        emle_method: str
            The desired embedding method. Options are:
                "electrostatic":
                    Full ML electrostatic embedding.
                "mechanical":
                    ML predicted charges for the core, but zero valence charge.
                "nonpol":
                    Non-polarisable ML embedding. Here the induced component of
                    the potential is zeroed.
                "mm":
                    MM charges are used for the core charge and valence charges
                    are set to zero.

        alpha_mode: str
            How atomic polarizabilities are calculated.
                "species":
                    one volume scaling factor is used for each species
                "reference":
                    scaling factors are obtained with GPR using the values
                    learned for each reference environment

        mm_charges: List[float], Tuple[float], numpy.ndarray, torch.Tensor
            List of MM charges for atoms in the QM region in units of mod
            electron charge. This is required if the 'mm' method is specified.

        qm_charge: int
            The charge on the QM region. This can also be passed when calling
            the forward method. The non-default value will take precedence.

        mace_model: str
            Name of the MACE-OFF23 model to use. Available models are
            'mace-off23-small', 'mace-off23-medium', 'mace-off23-large'.
            To use a locally trained MACE model, provide the path to the
            model file. If None, the MACE-OFF23(S) model will be used by
            default.

        atomic_numbers: List[int], Tuple[int], numpy.ndarray, torch.Tensor (N_ATOMS,)
            List of atomic numbers to use in the MACE model.

        device: torch.device
            The device on which to run the model. Defaults to the torch
            default device.

        dtype: torch.dtype
            The data type to use for the models floating point tensors.
            Defaults to the torch default dtype.

        Raises
        ------

        ImportError
            If mace, e3nn or NNPOps are not available.

        TypeError
            If 'device', 'dtype', 'atomic_numbers' or 'mace_model' have the
            wrong type.

        ValueError
            If 'atomic_numbers' are not int64 integers, or an unsupported
            MACE-OFF23 model size is requested.

        FileNotFoundError
            If 'mace_model' is neither a MACE-OFF23 name nor an existing file.
        """

        # Call the base class constructor.
        super().__init__()

        if not _has_mace:
            raise ImportError(
                'mace is required to use the MACEEMLE model. Install it with "pip install mace-torch"'
            )
        if not _has_e3nn:
            raise ImportError("e3nn is required to compile the MACE model.")
        if not _has_nnpops:
            raise ImportError("NNPOps is required to use the MACEEMLE model.")

        if device is not None:
            if not isinstance(device, _torch.device):
                raise TypeError("'device' must be of type 'torch.device'")
        else:
            device = _torch.get_default_device()

        # Always record the floating point dtype: it is used below to compile
        # the MACE model and to create the floating point buffers.
        if dtype is not None:
            if not isinstance(dtype, _torch.dtype):
                raise TypeError("'dtype' must be of type 'torch.dtype'")
            self._dtype = dtype
        else:
            self._dtype = _torch.get_default_dtype()

        if atomic_numbers is not None:
            if isinstance(atomic_numbers, _np.ndarray):
                atomic_numbers = atomic_numbers.tolist()
            if isinstance(atomic_numbers, (list, tuple)):
                if not all(isinstance(i, int) for i in atomic_numbers):
                    raise ValueError("'atomic_numbers' must be a list of integers")
                atomic_numbers = _torch.tensor(
                    atomic_numbers, dtype=_torch.int64, device=device
                )
            if not isinstance(atomic_numbers, _torch.Tensor):
                raise TypeError("'atomic_numbers' must be of type 'torch.Tensor'")
            # Check that they are integers.
            if atomic_numbers.dtype != _torch.int64:
                raise ValueError("'atomic_numbers' must be of dtype 'torch.int64'")
            self.register_buffer("_atomic_numbers", atomic_numbers.to(device))
        else:
            # An empty buffer means "unknown"; forward() fills it on first use.
            self.register_buffer(
                "_atomic_numbers",
                _torch.tensor([], dtype=_torch.int64, device=device),
            )

        # Create an instance of the EMLE model.
        self._emle = _EMLE(
            model=emle_model,
            method=emle_method,
            alpha_mode=alpha_mode,
            atomic_numbers=atomic_numbers,
            mm_charges=mm_charges,
            qm_charge=qm_charge,
            device=device,
            dtype=dtype,
            create_aev_calculator=True,
        )

        # Load the MACE model.
        if mace_model is None:
            # If no MACE model is provided, use the default MACE-OFF23(S) model.
            source_model = _mace_off(
                model="small", device=device, return_raw_model=True
            )
        else:
            if not isinstance(mace_model, str):
                raise TypeError("'mace_model' must be of type 'str'")

            # Convert to lower case and remove whitespace.
            formatted_mace_model = mace_model.lower().replace(" ", "")
            if formatted_mace_model.startswith("mace-off23"):
                size = formatted_mace_model.split("-")[-1]
                if size not in ("small", "medium", "large"):
                    raise ValueError(
                        f"Unsupported MACE model: '{mace_model}'. Available MACE-OFF23 models are "
                        "'mace-off23-small', 'mace-off23-medium', 'mace-off23-large'"
                    )
                source_model = _mace_off(
                    model=size, device=device, return_raw_model=True
                )
            elif _os.path.exists(mace_model):
                # Assuming that the model is a local model. A saved MACE model
                # is a pickled torch.nn.Module, so it cannot be loaded with
                # weights_only=True (the torch>=2.6 default). NOTE: unpickling
                # can execute arbitrary code - only load trusted model files.
                source_model = _torch.load(
                    mace_model, map_location=device, weights_only=False
                )
            else:
                raise FileNotFoundError(f"MACE model file not found: {mace_model}")

        from mace.tools.scripts_utils import extract_config_mace_model

        # Rebuild a fresh model from the source model's config and weights,
        # then compile it with e3nn's JIT helper.
        config = extract_config_mace_model(source_model)
        target_model = source_model.__class__(**config).to(device)
        target_model.load_state_dict(source_model.state_dict())
        self._mace = _e3nn_jit.compile(target_model).to(self._dtype)

        # Create the z_table of the MACE model.
        self._z_table = [int(z.item()) for z in self._mace.atomic_numbers]

        # All buffers are created directly on the target device so that they
        # match the positions passed to forward().
        if self._atomic_numbers.numel() > 0:
            node_attrs = self._get_node_attrs(self._atomic_numbers)
            num_atoms = node_attrs.shape[0]
            self.register_buffer(
                "_node_attrs", node_attrs.to(device=device, dtype=self._dtype)
            )
            self.register_buffer(
                "_ptr",
                _torch.tensor([0, num_atoms], dtype=_torch.long, device=device),
            )
            self.register_buffer(
                "_batch", _torch.zeros(num_atoms, dtype=_torch.long, device=device)
            )
        else:
            # Initialise the node attributes.
            self.register_buffer(
                "_node_attrs", _torch.tensor([], dtype=self._dtype, device=device)
            )
            self.register_buffer(
                "_ptr", _torch.tensor([], dtype=_torch.long, device=device)
            )
            self.register_buffer(
                "_batch", _torch.tensor([], dtype=_torch.long, device=device)
            )

        # No PBCs for now.
        self.register_buffer(
            "_pbc",
            _torch.tensor([False, False, False], dtype=_torch.bool, device=device),
        )
        self.register_buffer(
            "_cell", _torch.zeros((3, 3), dtype=self._dtype, device=device)
        )

        # Set the _get_neighbor_pairs method on the instance.
        self._get_neighbor_pairs = _get_neighbor_pairs

    @staticmethod
    def _to_one_hot(indices: _torch.Tensor, num_classes: int) -> _torch.Tensor:
        """
        Convert a tensor of indices to one-hot encoding.

        Parameters
        ----------

        indices: torch.Tensor (..., 1)
            Tensor of indices; the last dimension holds the index.

        num_classes: int
            Number of classes of atomic numbers.

        Returns
        -------

        oh: torch.Tensor (..., num_classes)
            One-hot encoding of the indices, in the default floating point
            dtype and on the same device as 'indices'.
        """
        shape = list(indices.shape[:-1]) + [num_classes]
        oh = _torch.zeros(shape, device=indices.device)
        return oh.scatter_(dim=-1, index=indices, value=1)

    @staticmethod
    def _atomic_numbers_to_indices(
        atomic_numbers: _torch.Tensor, z_table: List[int]
    ) -> _torch.Tensor:
        """
        Get the indices of the atomic numbers in the z_table.

        Parameters
        ----------

        atomic_numbers: torch.Tensor (N_ATOMS,)
            Atomic numbers of QM atoms.

        z_table: List[int]
            List of atomic numbers in the MACE model.

        Returns
        -------

        indices: torch.Tensor (N_ATOMS, 1)
            Indices of the atomic numbers in the z_table, on the same device
            as 'atomic_numbers'.

        Raises
        ------

        ValueError
            If an atomic number is not present in 'z_table'.
        """
        # Compare plain ints rather than relying on 0-d tensor truthiness.
        return _torch.tensor(
            [z_table.index(int(z)) for z in atomic_numbers],
            dtype=_torch.long,
            device=atomic_numbers.device,
        ).unsqueeze(-1)

    def _get_node_attrs(self, atomic_numbers: _torch.Tensor) -> _torch.Tensor:
        """
        Internal method to get the node attributes for the MACE model.

        Parameters
        ----------

        atomic_numbers: torch.Tensor (N_ATOMS,)
            Atomic numbers of QM atoms.

        Returns
        -------

        node_attrs: torch.Tensor (N_ATOMS, N_FEATURES)
            One-hot node attributes for the MACE model, where N_FEATURES is
            the size of the model's z_table.
        """
        ids = self._atomic_numbers_to_indices(atomic_numbers, z_table=self._z_table)
        return self._to_one_hot(ids, num_classes=len(self._z_table))

    def to(self, *args, **kwargs):
        """
        Performs Tensor dtype and/or device conversion on the model.

        The EMLE and MACE sub-models are converted through their own 'to'
        methods. This module's own buffers (atomic numbers, node attributes,
        ptr, batch, pbc and cell) are then converted as well. As with
        torch.nn.Module.to, only floating point buffers change dtype.
        """
        self._emle = self._emle.to(*args, **kwargs)
        self._mace = self._mace.to(*args, **kwargs)
        # Convert this module's own buffers, which the calls above don't touch.
        super().to(*args, **kwargs)
        # Keep the cached dtype in sync with the floating point buffers.
        self._dtype = self._cell.dtype
        return self

    def cpu(self, **kwargs):
        """
        Move all model parameters and buffers to CPU memory.
        """
        self._emle = self._emle.cpu(**kwargs)
        self._mace = self._mace.cpu(**kwargs)
        super().cpu()
        return self

    def cuda(self, **kwargs):
        """
        Move all model parameters and buffers to CUDA memory.
        """
        self._emle = self._emle.cuda(**kwargs)
        self._mace = self._mace.cuda(**kwargs)
        super().cuda(**kwargs)
        return self

    def double(self):
        """
        Cast all floating point model parameters and buffers to float64 precision.
        """
        self._emle = self._emle.double()
        self._mace = self._mace.double()
        super().double()
        self._dtype = _torch.float64
        return self

    def float(self):
        """
        Cast all floating point model parameters and buffers to float32 precision.
        """
        self._emle = self._emle.float()
        self._mace = self._mace.float()
        super().float()
        self._dtype = _torch.float32
        return self

    def forward(
        self,
        atomic_numbers: Tensor,
        charges_mm: Tensor,
        xyz_qm: Tensor,
        xyz_mm: Tensor,
        qm_charge: int = 0,
    ) -> Tensor:
        """
        Compute the MACE and static and induced EMLE energy components.

        Parameters
        ----------

        atomic_numbers: torch.Tensor (N_QM_ATOMS,) or (BATCH, N_QM_ATOMS)
            Atomic numbers of QM atoms.

        charges_mm: torch.Tensor (max_mm_atoms,) or (BATCH, max_mm_atoms)
            MM point charges in atomic units.

        xyz_qm: torch.Tensor (N_QM_ATOMS, 3), or (BATCH, N_QM_ATOMS, 3)
            Positions of QM atoms in Angstrom.

        xyz_mm: torch.Tensor (N_MM_ATOMS, 3) or (BATCH, N_MM_ATOMS, 3)
            Positions of MM atoms in Angstrom.

        qm_charge: int
            The charge on the QM region.

        Returns
        -------

        result: torch.Tensor (3, BATCH)
            The MACE and static and induced EMLE energy components in Hartree.
            BATCH is 1 for unbatched input.
        """

        # Get the device.
        device = xyz_qm.device

        # Batch the inputs if necessary.
        if atomic_numbers.ndim == 1:
            atomic_numbers = atomic_numbers.unsqueeze(0)
            xyz_qm = xyz_qm.unsqueeze(0)
            xyz_mm = xyz_mm.unsqueeze(0)
            charges_mm = charges_mm.unsqueeze(0)

        # Store the number of batches.
        num_batches = atomic_numbers.shape[0]

        # Pre-allocate the per-configuration results.
        results_E_vac = _torch.empty(num_batches, dtype=self._dtype, device=device)
        results_E_emle_static = _torch.empty(
            num_batches, dtype=self._dtype, device=device
        )
        results_E_emle_induced = _torch.empty(
            num_batches, dtype=self._dtype, device=device
        )

        # Conversion factor from eV (MACE energy units) to Hartree.
        EV_TO_HARTREE = 0.0367492929

        # Loop over the batches.
        for i in range(num_batches):
            # Get the edge index and shifts for this configuration, using the
            # MACE cutoff radius.
            edge_index, shifts = self._get_neighbor_pairs(
                xyz_qm[i], None, self._mace.r_max, self._dtype, device
            )

            if not _torch.equal(atomic_numbers[i], self._atomic_numbers):
                # Update the cached node attributes if the atomic numbers
                # have changed.
                self._node_attrs = self._get_node_attrs(atomic_numbers[i]).to(
                    device=device, dtype=self._dtype
                )
                num_atoms = self._node_attrs.shape[0]
                self._ptr = _torch.tensor(
                    [0, num_atoms], dtype=_torch.long, device=device
                )
                self._batch = _torch.zeros(num_atoms, dtype=_torch.long, device=device)
                self._atomic_numbers = atomic_numbers[i]

            # Create the input dictionary
            input_dict = {
                "ptr": self._ptr,
                "node_attrs": self._node_attrs,
                "batch": self._batch,
                "pbc": self._pbc,
                "positions": xyz_qm[i].to(self._dtype),
                "edge_index": edge_index,
                "shifts": shifts,
                "cell": self._cell,
            }

            # Get the in vacuo energy.
            E_vac = self._mace(input_dict, compute_force=False)["interaction_energy"]

            assert (
                E_vac is not None
            ), "The model did not return any energy. Please check the input."

            results_E_vac[i] = E_vac[0] * EV_TO_HARTREE

            # If there are no point charges, store zeros for the static and
            # induced terms.
            if len(xyz_mm[i]) == 0:
                zero = _torch.tensor(0.0, dtype=xyz_qm.dtype, device=device)
                results_E_emle_static[i] = zero
                results_E_emle_induced[i] = zero
            else:
                # Get the EMLE energy components for THIS configuration only.
                # Slicing with i:i+1 keeps a batch dimension of size 1, so the
                # output is indexed as [component][batch] with batch index 0.
                E_emle = self._emle(
                    atomic_numbers[i : i + 1],
                    charges_mm[i : i + 1],
                    xyz_qm[i : i + 1],
                    xyz_mm[i : i + 1],
                    qm_charge,
                )
                results_E_emle_static[i] = E_emle[0][0]
                results_E_emle_induced[i] = E_emle[1][0]

        # Return the MACE and EMLE energy components.
        return _torch.stack(
            [results_E_vac, results_E_emle_static, results_E_emle_induced]
        )