# Source code for emle.models._emle_aev_computer

#######################################################################
# EMLE-Engine: https://github.com/chemle/emle-engine
#
# Copyright: 2023-2025
#
# Authors: Lester Hedges   <lester.hedges@gmail.com>
#          Kirill Zinovjev <kzinovjev@gmail.com>
#
# EMLE-Engine is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 2 of the License, or
# (at your option) any later version.
#
# EMLE-Engine is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with EMLE-Engine. If not, see <http://www.gnu.org/licenses/>.
#####################################################################

"""EMLE AEVComputer implementation."""

__author__ = "Kirill Zinovjev"
__email__ = "kzinovjev@gmail.com"

import numpy as _np
import torch as _torch
import torchani as _torchani

from torch import Tensor
from typing import Tuple

# Default hyperparameters for the wrapped torchani AEVComputer, taken from
# the ANI-2x model. Scalar entries are plain floats; array entries are kept
# as NumPy arrays here and converted to tensors by get_default_hypers().
_DEFAULT_HYPERS_DICT = {
    # Radial cutoff (Angstrom).
    "Rcr": 5.1000e00,
    # Angular cutoff (Angstrom).
    "Rca": 3.5000e00,
    # Radial Gaussian width.
    "EtaR": _np.array([1.9700000e01]),
    # Radial shift values (16 radial shells).
    "ShfR": _np.array(
        [
            8.0000000e-01,
            1.0687500e00,
            1.3375000e00,
            1.6062500e00,
            1.8750000e00,
            2.1437500e00,
            2.4125000e00,
            2.6812500e00,
            2.9500000e00,
            3.2187500e00,
            3.4875000e00,
            3.7562500e00,
            4.0250000e00,
            4.2937500e00,
            4.5625000e00,
            4.8312500e00,
        ]
    ),
    # Angular sharpness parameter.
    "Zeta": _np.array([1.4100000e01]),
    # Angular shift values (4 angular sections).
    "ShfZ": _np.array([3.9269908e-01, 1.1780972e00, 1.9634954e00, 2.7488936e00]),
    # Angular Gaussian width.
    "EtaA": _np.array([1.2500000e01]),
    # Angular radial shift values (8 radial shells).
    "ShfA": _np.array(
        [
            8.0000000e-01,
            1.1375000e00,
            1.4750000e00,
            1.8125000e00,
            2.1500000e00,
            2.4875000e00,
            2.8250000e00,
            3.1625000e00,
        ]
    ),
}


def get_default_hypers(device, dtype):
    """
    Return the default AEVComputer hyperparameters.

    Array-valued entries of _DEFAULT_HYPERS_DICT are converted to tensors
    on the requested device and dtype; scalar entries are passed through
    unchanged.

    Parameters
    ----------

    device: torch.device
        The device on which to place the tensor-valued hyperparameters.

    dtype: torch.dtype
        The floating point type for the tensor-valued hyperparameters.

    Returns
    -------

    hypers: dict
        The hyperparameter dictionary.
    """
    return {
        key: (
            _torch.tensor(value, device=device, dtype=dtype)
            if isinstance(value, _np.ndarray)
            else value
        )
        for key, value in _DEFAULT_HYPERS_DICT.items()
    }


class EMLEAEVComputer(_torch.nn.Module):
    """
    Wrapper for AEVCalculator from torchani (not a subclass to make sure
    it works with TorchScript).
    """

    def __init__(
        self,
        num_species=7,
        hypers=None,
        mask=None,
        is_external=False,
        zid_map=None,
        device=None,
        dtype=None,
    ):
        """
        Constructor.

        Parameters
        ----------

        num_species: int
            Number of supported species.

        hypers: dict
            Hyperparameters for the wrapped AEVComputer.

        mask: torch.BoolTensor
            Mask for the features returned from wrapped AEVComputer.

        is_external: bool
            Whether the features are calculated externally.

        zid_map: dict or torch.tensor
            Map from zid provided here to the ones passed to AEVComputer.

        device: torch.device
            The device on which to run the model.

        dtype: torch.dtype
            The data type to use for the models floating point tensors.
        """
        super().__init__()

        # Validate the input.

        if device is not None:
            if not isinstance(device, _torch.device):
                raise TypeError("'device' must be of type 'torch.device'")
        else:
            device = _torch.get_default_device()
        self._device = device

        if dtype is not None:
            if not isinstance(dtype, _torch.dtype):
                raise TypeError("'dtype' must be of type 'torch.dtype'")
        else:
            dtype = _torch.get_default_dtype()

        if mask is not None:
            if not isinstance(mask, _torch.Tensor):
                raise TypeError("'mask' must be of type 'torch.Tensor'")
            if len(mask.shape) != 1:
                raise ValueError("'mask' must be a 1D tensor")
            if not mask.dtype == _torch.bool:
                raise ValueError("'mask' must have dtype 'torch.bool'")
        # A None mask means the AEVs are passed through unmasked (see
        # _apply_mask).
        self._mask = mask

        if not isinstance(is_external, bool):
            raise TypeError("'is_external' must be of type 'bool'")
        self._is_external = is_external

        # Initialise an empty AEV tensor to use to store the AEVs in parent
        # models. If AEVs are computed externally, then this tensor will be
        # set by the parent.
        self._aev = _torch.empty(0, dtype=dtype, device=device)

        if not isinstance(num_species, int):
            raise TypeError("'num_species' must be of type 'int'")
        if num_species < 1:
            raise ValueError("'num_species' must be greater than 0")

        if hypers is not None:
            if not isinstance(hypers, dict):
                raise TypeError("'hypers' must be of type 'dict' or None")

        # Create the AEV computer.
        if not self._is_external:
            hypers = hypers or get_default_hypers(device, dtype)
            # Note the positional-argument order required by torchani:
            # Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ.
            self._aev_computer = _torchani.AEVComputer(
                hypers["Rcr"],
                hypers["Rca"],
                hypers["EtaR"],
                hypers["ShfR"],
                hypers["EtaA"],
                hypers["Zeta"],
                hypers["ShfA"],
                hypers["ShfZ"],
                num_species=num_species,
            ).to(device=device, dtype=dtype)
        # Create a dummy function to use in forward.
        else:
            self._aev_computer = self._dummy_aev_computer

        # Build the species-index map as a lookup tensor. The default is the
        # identity map over the supported species.
        if zid_map is None:
            zid_map = {i: i for i in range(num_species)}
        if isinstance(zid_map, dict):
            # Unmapped entries are -1; index num_species accommodates lookups
            # for padding atoms.
            self._zid_map = -_torch.ones(
                num_species + 1, dtype=_torch.int, device=device
            )
            for self_atom_zid, aev_atom_zid in zid_map.items():
                self._zid_map[self_atom_zid] = aev_atom_zid
        elif isinstance(zid_map, _torch.Tensor):
            self._zid_map = zid_map
        elif isinstance(zid_map, (list, tuple, _np.ndarray)):
            self._zid_map = _torch.tensor(zid_map, dtype=_torch.int64, device=device)
        else:
            raise ValueError("zid_map must be a dict, torch.Tensor, list or tuple")

    def forward(self, zid, xyz):
        """
        Evaluate the AEVs.

        Parameters
        ----------

        zid: torch.Tensor (N_BATCH, MAX_N_ATOMS)
            The species indices.

        xyz: torch.Tensor (N_BATCH, MAX_N_ATOMS, 3)
            The atomic coordinates.

        Returns
        -------

        aevs: torch.Tensor (N_BATCH, MAX_N_ATOMS, N_AEV_COMPONENTS)
            The atomic environment vectors.
        """
        if not self._is_external:
            # Remap the species indices to the ones expected by the wrapped
            # AEVComputer, then compute the AEVs (torchani returns a
            # (species, aevs) pair; take the AEVs).
            zid_aev = self._zid_map[zid]
            aev = self._aev_computer((zid_aev, xyz))[1]
        else:
            # AEVs were computed externally and stored here by the parent.
            aev = self._aev

        # Per-atom L2 norm, kept as a trailing singleton dim for broadcasting.
        norm = _torch.linalg.norm(aev, dim=2, keepdim=True)

        # Normalise the AEVs, zeroing rows for padding atoms (zid == -1),
        # then apply the optional feature mask.
        # NOTE(review): for padding atoms with a zero AEV, aev / norm is
        # evaluated before the where() select — presumably padding rows are
        # always discarded downstream; confirm no NaN leakage.
        aev = self._apply_mask(_torch.where(zid[:, :, None] > -1, aev / norm, 0.0))

        return aev

    @staticmethod
    def _dummy_aev_computer(input: Tuple[Tensor, Tensor]) -> Tensor:
        """
        Dummy function to use in forward if AEVs are computed externally.

        Parameters
        ----------

        input: Tuple[Tensor, Tensor]
            The (species indices, atomic coordinates) pair. Ignored.

        Returns
        -------

        aevs: torch.Tensor
            An empty placeholder tensor.
        """
        return _torch.empty(0, dtype=_torch.float32)

    def _apply_mask(self, aev):
        """
        Apply the mask to the AEVs.

        Parameters
        ----------

        aev: torch.Tensor
            The AEVs to mask.

        Returns
        -------

        aev: torch.Tensor
            The masked AEVs (unchanged if no mask was provided).
        """
        return aev[:, :, self._mask] if self._mask is not None else aev

    def to(self, *args, **kwargs):
        """
        Performs Tensor dtype and/or device conversion on the model.
        """
        if not self._is_external:
            self._aev_computer = self._aev_computer.to(*args, **kwargs)
        if self._mask is not None:
            self._mask = self._mask.to(*args, **kwargs)
        self._zid_map = self._zid_map.to(*args, **kwargs)

        # NOTE(review): self._aev is not converted here — presumably it is
        # (re)assigned by the parent model before use; confirm.

        # Check for a device type in args and update the device attribute.
        for arg in args:
            if isinstance(arg, _torch.device):
                self._device = arg
                break

        return self

    def cuda(self, **kwargs):
        """
        Move all model parameters and buffers to CUDA memory.
        """
        if not self._is_external:
            self._aev_computer = self._aev_computer.cuda(**kwargs)
        if self._mask is not None:
            self._mask = self._mask.cuda(**kwargs)
        self._zid_map = self._zid_map.cuda(**kwargs)
        # NOTE(review): self._aev and self._device are not updated here,
        # unlike in to() — verify this is intentional.
        return self

    def cpu(self, **kwargs):
        """
        Move all model parameters and buffers to CPU memory.
        """
        if not self._is_external:
            self._aev_computer = self._aev_computer.cpu(**kwargs)
        if self._mask is not None:
            self._mask = self._mask.cpu(**kwargs)
        self._zid_map = self._zid_map.cpu(**kwargs)
        # NOTE(review): self._aev and self._device are not updated here,
        # unlike in to() — verify this is intentional.
        return self

    def double(self):
        """
        Casts all floating point model parameters and buffers to float64
        precision.
        """
        if not self._is_external:
            self._aev_computer = self._aev_computer.double()
        return self

    def float(self):
        """
        Casts all floating point model parameters and buffers to float32
        precision.
        """
        if not self._is_external:
            self._aev_computer = self._aev_computer.float()
        return self