#######################################################################
# EMLE-Engine: https://github.com/chemle/emle-engine
#
# Copyright: 2023-2025
#
# Authors: Lester Hedges <lester.hedges@gmail.com>
# Kirill Zinovjev <kzinovjev@gmail.com>
#
# EMLE-Engine is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 2 of the License, or
# (at your option) any later version.
#
# EMLE-Engine is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with EMLE-Engine. If not, see <http://www.gnu.org/licenses/>.
#####################################################################
"""EMLE base model implementation."""
__author__ = "Kirill Zinovjev"
__email__ = "kzinovjev@gmail.com"
__all__ = ["EMLEBase"]
import numpy as _np
import torch as _torch
from torch import Tensor
from typing import Tuple
import torchani as _torchani
try:
import NNPOps as _NNPOps
_NNPOps.OptimizedTorchANI = _patches.OptimizedTorchANI
_has_nnpops = True
except:
_has_nnpops = False
[docs]
class EMLEBase(_torch.nn.Module):
    """
    Base class for the EMLE model. This is used to compute valence shell
    widths, core charges, valence charges, and the A_thole tensor for a batch
    of QM systems, which in turn can be used to compute static and induced
    electrostatic embedding energies using the EMLE model.
    """

    # Store the list of supported species (atomic numbers: H, C, N, O, S).
    # Used as the default when no 'species' argument is passed to __init__.
    _species = [1, 6, 7, 8, 16]
def __init__(
self,
params,
n_ref,
ref_features,
q_core,
emle_aev_computer=None,
species=None,
alpha_mode="species",
device=None,
dtype=None,
):
"""
Constructor.
Parameters
----------
params: dict
EMLE model parameters.
n_ref: torch.Tensor
number of GPR references for each element in species list
ref_features: torch.Tensor
Feature vectors for GPR references.
q_core: torch.Tensor
Core charges for each element in species list.
alpha_mode: str
How atomic polarizabilities are calculated.
"species":
one volume scaling factor is used for each species
"reference":
scaling factors are obtained with GPR using the values learned
for each reference environment
emle_aev_computer: EMLEAEVComputer
EMLE AEV computer instance used to compute AEVs (masked and normalized).
species: List[int], Tuple[int], numpy.ndarray, torch.Tensor
List of species (atomic numbers) supported by the EMLE model.
device: torch.device
The device on which to run the model.
dtype: torch.dtype
The data type to use for the models floating point tensors.
"""
# Call the base class constructor.
super().__init__()
# Validate the parameters.
if not isinstance(params, dict):
raise TypeError("'params' must be of type 'dict'")
if not all(
k in params
for k in ["a_QEq", "a_Thole", "ref_values_s", "ref_values_chi", "k_Z"]
):
raise ValueError(
"'params' must contain keys 'a_QEq', 'a_Thole', 'ref_values_s', 'ref_values_chi', and 'k_Z'"
)
# Validate the number of references.
if not isinstance(n_ref, _torch.Tensor):
raise TypeError("'n_ref' must be of type 'torch.Tensor'")
if len(n_ref.shape) != 1:
raise ValueError("'n_ref' must be a 1D tensor")
if not n_ref.dtype == _torch.int64:
raise ValueError("'n_ref' must have dtype 'torch.int64'")
# Validate the reference features.
if not isinstance(ref_features, _torch.Tensor):
raise TypeError("'ref_features' must be of type 'torch.Tensor'")
if len(ref_features.shape) != 3:
raise ValueError("'ref_features' must be a 3D tensor")
if not ref_features.dtype in (_torch.float64, _torch.float32):
raise ValueError(
"'ref_features' must have dtype 'torch.float64' or 'torch.float32'"
)
# Validate the core charges.
if not isinstance(q_core, _torch.Tensor):
raise TypeError("'q_core' must be of type 'torch.Tensor'")
if len(q_core.shape) != 1:
raise ValueError("'q_core' must be a 1D tensor")
if not q_core.dtype in (_torch.float64, _torch.float32):
raise ValueError(
"'q_core' must have dtype 'torch.float64' or 'torch.float32'"
)
# Validate the alpha mode.
if alpha_mode is None:
alpha_mode = "species"
if not isinstance(alpha_mode, str):
raise TypeError("'alpha_mode' must be of type 'str'")
alpha_mode = alpha_mode.lower().replace(" ", "")
if alpha_mode not in ["species", "reference"]:
raise ValueError("'alpha_mode' must be 'species' or 'reference'")
self._alpha_mode = alpha_mode
# Validate the AEV computer.
if emle_aev_computer is not None:
from ._emle_aev_computer import EMLEAEVComputer
allowed_types = [EMLEAEVComputer, _torchani.AEVComputer]
if _has_nnpops:
allowed_types.append(
_NNPOps.SymmetryFunctions.TorchANISymmetryFunctions
)
if not isinstance(emle_aev_computer, tuple(allowed_types)):
raise TypeError(
"'aev_computer' must be of type 'torchani.AEVComputer' or "
"'NNPOps.SymmetryFunctions.TorchANISymmetryFunctions'"
)
self._emle_aev_computer = emle_aev_computer
if device is not None:
if not isinstance(device, _torch.device):
raise TypeError("'device' must be of type 'torch.device'")
else:
device = _torch.get_default_device()
if dtype is not None:
if not isinstance(dtype, _torch.dtype):
raise TypeError("'dtype' must be of type 'torch.dtype'")
else:
dtype = _torch.get_default_dtype()
self._dtype = dtype
# Store model parameters as tensors.
self.a_QEq = _torch.nn.Parameter(params["a_QEq"])
self.a_Thole = _torch.nn.Parameter(params["a_Thole"])
self.ref_values_s = _torch.nn.Parameter(params["ref_values_s"])
self.ref_values_chi = _torch.nn.Parameter(params["ref_values_chi"])
self.k_Z = _torch.nn.Parameter(params["k_Z"])
if self._alpha_mode == "reference":
try:
self.ref_values_sqrtk = _torch.nn.Parameter(params["sqrtk_ref"])
except:
msg = (
"Missing 'sqrtk_ref' key in params. This is required when "
"using 'reference' alpha mode."
)
raise ValueError(msg)
# Validate the species.
if species is None:
# Use the default species.
species = self._species
if isinstance(species, (_np.ndarray, _torch.Tensor)):
species = species.tolist()
if not isinstance(species, (tuple, list)):
raise TypeError(
"'species' must be of type 'list', 'tuple', or 'numpy.ndarray'"
)
if not all(isinstance(s, int) for s in species):
raise TypeError("All elements of 'species' must be of type 'int'")
if not all(s > 0 for s in species):
raise ValueError("All elements of 'species' must be greater than zero")
# Create a map between species and their indices in the model.
species_map = _np.full(max(species) + 2, fill_value=-1, dtype=_np.int64)
for i, s in enumerate(species):
species_map[s] = i
species_map = _torch.tensor(species_map, dtype=_torch.int64, device=device)
# Compute the inverse of the K matrix.
Kinv = self._get_Kinv(ref_features, 1e-3)
# Calculate GPR coefficients for the valence shell widths (s)
# and electronegativities (chi).
ref_mean_s, c_s = self._get_c(n_ref, self.ref_values_s, Kinv)
ref_mean_chi, c_chi = self._get_c(n_ref, self.ref_values_chi, Kinv)
if self._alpha_mode == "species":
ref_mean_sqrtk = _torch.zeros_like(ref_mean_s, dtype=dtype, device=device)
c_sqrtk = _torch.zeros_like(c_s, dtype=dtype, device=device)
else:
ref_mean_sqrtk, c_sqrtk = self._get_c(n_ref, self.ref_values_sqrtk, Kinv)
# Store the current device.
self._device = device
# Register constants as buffers.
self.register_buffer("_species_map", species_map)
self.register_buffer("_Kinv", Kinv)
self.register_buffer("_q_core", q_core)
self.register_buffer("_ref_features", ref_features)
self.register_buffer("_n_ref", n_ref)
self.register_buffer("_ref_mean_s", ref_mean_s)
self.register_buffer("_ref_mean_chi", ref_mean_chi)
self.register_buffer("_ref_mean_sqrtk", ref_mean_sqrtk)
self.register_buffer("_c_s", c_s)
self.register_buffer("_c_chi", c_chi)
self.register_buffer("_c_sqrtk", c_sqrtk)
[docs]
def to(self, *args, **kwargs):
"""
Performs Tensor dtype and/or device conversion on the model.
"""
self._emle_aev_computer = self._emle_aev_computer.to(*args, **kwargs)
self._species_map = self._species_map.to(*args, **kwargs)
self._Kinv = self._Kinv.to(*args, **kwargs)
self._q_core = self._q_core.to(*args, **kwargs)
self._ref_features = self._ref_features.to(*args, **kwargs)
self._n_ref = self._n_ref.to(*args, **kwargs)
self._ref_mean_s = self._ref_mean_s.to(*args, **kwargs)
self._ref_mean_chi = self._ref_mean_chi.to(*args, **kwargs)
self._ref_mean_sqrtk = self._ref_mean_sqrtk.to(*args, **kwargs)
self._c_s = self._c_s.to(*args, **kwargs)
self._c_chi = self._c_chi.to(*args, **kwargs)
self._c_sqrtk = self._c_sqrtk.to(*args, **kwargs)
self.k_Z = _torch.nn.Parameter(self.k_Z.to(*args, **kwargs))
# Check for a device type in args and update the device attribute.
for arg in args:
if isinstance(arg, _torch.device):
self._device = arg
break
return self
[docs]
def cuda(self, **kwargs):
    """
    Move all model parameters and buffers to CUDA memory.
    """
    # Move every internal tensor attribute to the GPU.
    for attr in (
        "_emle_aev_computer",
        "_species_map",
        "_Kinv",
        "_q_core",
        "_ref_features",
        "_n_ref",
        "_ref_mean_s",
        "_ref_mean_chi",
        "_ref_mean_sqrtk",
        "_c_s",
        "_c_chi",
        "_c_sqrtk",
    ):
        setattr(self, attr, getattr(self, attr).cuda(**kwargs))

    # Re-wrap the learnable parameter on the new device.
    self.k_Z = _torch.nn.Parameter(self.k_Z.cuda(**kwargs))

    # Update the device attribute.
    self._device = self._species_map.device

    return self
[docs]
def cpu(self, **kwargs):
"""
Move all model parameters and buffers to CPU memory.
"""
self._emle_aev_computer = self._emle_aev_computer.cpu(**kwargs)
self._species_map = self._species_map.cpu(**kwargs)
self._Kinv = self._Kinv.cpu(**kwargs)
self._q_core = self._q_core.cpu(**kwargs)
self._ref_features = self._ref_features.cpu(**kwargs)
self._n_ref = self._n_ref.cpu(**kwargs)
self._ref_mean_s = self._ref_mean_s.cpu(**kwargs)
self._ref_mean_chi = self._ref_mean_chi.cpu(**kwargs)
self._ref_mean_sqrtk = self._ref_mean_sqrtk.to(**kwargs)
self._c_s = self._c_s.cpu(**kwargs)
self._c_chi = self._c_chi.cpu(**kwargs)
self._c_sqrtk = self._c_sqrtk.cpu(**kwargs)
self.k_Z = _torch.nn.Parameter(self.k_Z.cpu(**kwargs))
# Update the device attribute.
self._device = self._species_map.device
return self
[docs]
def double(self):
"""
Casts all floating point model parameters and buffers to float64 precision.
"""
self._emle_aev_computer = self._emle_aev_computer.double()
self._Kinv = self._Kinv.double()
self._q_core = self._q_core.double()
self._ref_features = self._ref_features.double()
self._ref_mean_s = self._ref_mean_s.double()
self._ref_mean_chi = self._ref_mean_chi.double()
self._ref_mean_sqrtk = self._ref_mean_sqrtk.double()
self._c_s = self._c_s.double()
self._c_chi = self._c_chi.double()
self._c_sqrtk = self._c_sqrtk.double()
self.k_Z = _torch.nn.Parameter(self.k_Z.double())
return self
[docs]
def float(self):
"""
Casts all floating point model parameters and buffers to float32 precision.
"""
self._emle_aev_computer = self._emle_aev_computer.float()
self._Kinv = self._Kinv.float()
self._q_core = self._q_core.float()
self._ref_features = self._ref_features.float()
self._ref_mean_s = self._ref_mean_s.float()
self._ref_mean_chi = self._ref_mean_chi.float()
self._ref_mean_sqrtk = self._ref_mean_sqrtk.float()
self._c_s = self._c_s.float()
self._c_chi = self._c_chi.float()
self._c_sqrtk = self._c_sqrtk.float()
self.k_Z = _torch.nn.Parameter(self.k_Z.float())
return self
[docs]
def forward(self, atomic_numbers, xyz_qm, q_total):
    """
    Compute the valence widths, core charges, valence charges, and
    A_thole tensor for a batch of QM systems.

    Parameters
    ----------

    atomic_numbers: torch.Tensor (N_BATCH, N_QM_ATOMS,)
        Atomic numbers of QM atoms.

    xyz_qm: torch.Tensor (N_BATCH, N_QM_ATOMS, 3)
        Positions of QM atoms in Angstrom.

    q_total: torch.Tensor (N_BATCH,)
        Total charge.

    Returns
    -------

    result: (torch.Tensor (N_BATCH, N_QM_ATOMS,),
             torch.Tensor (N_BATCH, N_QM_ATOMS,),
             torch.Tensor (N_BATCH, N_QM_ATOMS,),
             torch.Tensor (N_BATCH, N_QM_ATOMS * 3, N_QM_ATOMS * 3,))
        Valence widths, core charges, valence charges, A_thole tensor
    """
    # Padded atoms are encoded with atomic number <= 0.
    valid = atomic_numbers > 0

    # Map atomic numbers onto model species indices.
    zid = self._species_map[atomic_numbers]

    # Atomic environment vectors for every atom.
    feats = self._emle_aev_computer(zid, xyz_qm)

    # GPR predictions: valence shell widths and electronegativities.
    s = self._gpr(feats, self._ref_mean_s, self._c_s, zid)
    chi = self._gpr(feats, self._ref_mean_chi, self._c_chi, zid)

    # Distance data in atomic units.
    ANGSTROM_TO_BOHR = 1.8897261258369282
    r_data = self._get_r_data(xyz_qm * ANGSTROM_TO_BOHR, valid)

    # Core charges (zeroed for padding), then valence charges from QEq.
    q_core = self._q_core[zid] * valid
    q = self._get_q(r_data, s, chi, q_total, valid)
    q_val = q - q_core

    # Polarizability scaling factors, optionally refined per reference
    # environment via GPR.
    k = self.k_Z[zid]
    if self._alpha_mode == "reference":
        scale = self._gpr(feats, self._ref_mean_sqrtk, self._c_sqrtk, zid) ** 2
        k = scale * k

    A_thole = self._get_A_thole(r_data, s, q_val, k, self.a_Thole)

    return s, q_core, q_val, A_thole
@classmethod
def _get_Kinv(cls, ref_features, sigma):
"""
Internal function to compute the inverse of the K matrix for GPR.
Parameters
----------
ref_features: torch.Tensor (N_Z, MAX_N_REF, N_FEAT)
The basis feature vectors for each species.
sigma: float
The uncertainty of the observations (regularizer).
Returns
-------
result: torch.Tensor (MAX_N_REF, MAX_N_REF)
The inverse of the K matrix.
"""
n = ref_features.shape[1]
K = (ref_features @ ref_features.swapaxes(1, 2)) ** 2
return _torch.linalg.inv(
K + sigma**2 * _torch.eye(n, dtype=ref_features.dtype, device=K.device)
)
@classmethod
def _get_c(cls, n_ref, ref, Kinv):
"""
Internal method to compute the coefficients of the GPR model.
"""
mask = _torch.arange(ref.shape[1], device=n_ref.device) < n_ref[:, None]
ref_mean = _torch.sum(ref * mask, dim=1) / n_ref
ref_shifted = (ref - ref_mean[:, None]) * mask
return ref_mean, (Kinv @ ref_shifted[:, :, None]).squeeze()
def _gpr(self, mol_features, ref_mean, c, zid):
"""
Internal method to predict a property using Gaussian Process Regression.
Parameters
----------
mol_features: torch.Tensor (N_BATCH, N_ATOMS, N_FEAT)
The feature vectors for each atom.
ref_mean: torch.Tensor (N_Z,)
The mean of the reference values for each species.
c: torch.Tensor (N_Z, MAX_N_REF)
The coefficients of the GPR model.
zid: torch.Tensor (N_BATCH, N_ATOMS,)
The species identity value of each atom.
Returns
-------
result: torch.Tensor (N_BATCH, N_ATOMS)
The values of the predicted property for each atom.
"""
result = _torch.zeros(
zid.shape, dtype=mol_features.dtype, device=mol_features.device
)
for i in range(len(self._n_ref)):
n_ref = self._n_ref[i]
ref_features_z = self._ref_features[i, :n_ref]
mol_features_z = mol_features[zid == i]
K_mol_ref2 = (mol_features_z @ ref_features_z.T) ** 2
result[zid == i] = K_mol_ref2 @ c[i, :n_ref] + ref_mean[i]
return result
@classmethod
def _get_r_data(cls, xyz, mask):
    """
    Internal method to calculate r_data object.

    Parameters
    ----------

    xyz: torch.Tensor (N_BATCH, N_ATOMS, 3)
        Atomic positions.

    mask: torch.Tensor (N_BATCH, N_ATOMS)
        Mask for padded coordinates

    Returns
    -------

    result: r_data object
        Tuple (r_mat, t01, t21, t22): the masked distance matrix, the 1/r
        tensor, and the two rank-2 interaction terms expanded to
        (N_BATCH, N_ATOMS * 3, N_ATOMS * 3).
    """
    n_batch, n_atoms_max = xyz.shape[:2]

    # Pairwise mask: an interaction only counts when both atoms are real.
    mask_mat = mask[:, :, None] * mask[:, None, :]

    # Displacement vectors for every atom pair.
    rr_mat = xyz[:, :, None, :] - xyz[:, None, :, :]
    # Distances, zeroed for padded pairs; 1/r with 0 where r == 0 (the
    # diagonal and padded entries), avoiding division by zero.
    r_mat = _torch.where(mask_mat, _torch.cdist(xyz, xyz), 0.0)
    r_inv = _torch.where(r_mat == 0.0, 0.0, 1.0 / r_mat)

    # Expand 1/r from (N, N) to (3N, 3N) by repeating each entry over the
    # three Cartesian components of both atoms.
    r_inv1 = r_inv.repeat_interleave(3, dim=2)
    r_inv2 = r_inv1.repeat_interleave(3, dim=1)
    # Get a stacked matrix of outer products over the rr_mat tensors.
    outer = _torch.einsum("bnik,bnij->bnjik", rr_mat, rr_mat).reshape(
        (n_batch, n_atoms_max * 3, n_atoms_max * 3)
    )
    # 3x3 identity pattern tiled over all atom pairs (broadcasts over the
    # batch dimension).
    id2 = _torch.tile(
        _torch.eye(3, dtype=xyz.dtype, device=xyz.device).T,
        (1, n_atoms_max, n_atoms_max),
    )

    # Interaction tensors: t01 = 1/r, t21 = -I / r^3, t22 = 3 (r ⊗ r) / r^5.
    t01 = r_inv
    t21 = -id2 * r_inv2**3
    t22 = 3 * outer * r_inv2**5

    return r_mat, t01, t21, t22
def _get_q(
self, r_data: Tuple[Tensor, Tensor, Tensor, Tensor], s, chi, q_total, mask
):
"""
Internal method that predicts MBIS charges
(Eq. 16 in 10.1021/acs.jctc.2c00914)
Parameters
----------
r_data: r_data object (output of self._get_r_data)
s: torch.Tensor (N_BATCH, N_ATOMS,)
MBIS valence shell widths.
chi: torch.Tensor (N_BATCH, N_ATOMS,)
Electronegativities.
q_total: torch.Tensor (N_BATCH,)
Total charge
mask: torch.Tensor (N_BATCH, N_ATOMS)
Mask for padded coordinates
Returns
-------
result: torch.Tensor (N_BATCH, N_ATOMS,)
Predicted MBIS charges.
"""
A = self._get_A_QEq(r_data, s, mask)
b = _torch.hstack([-chi, q_total[:, None]])
return _torch.linalg.solve(A, b)[:, :-1]
def _get_A_QEq(self, r_data: Tuple[Tensor, Tensor, Tensor, Tensor], s, mask):
    """
    Internal method, generates A matrix for charge prediction
    (Eq. 16 in 10.1021/acs.jctc.2c00914)

    Parameters
    ----------

    r_data: r_data object (output of self._get_r_data)

    s: torch.Tensor (N_BATCH, N_ATOMS,)
        MBIS valence shell widths.

    mask: torch.Tensor (N_BATCH, N_ATOMS)
        Mask for padded coordinates

    Returns
    -------

    result: torch.Tensor (N_BATCH, N_ATOMS + 1, N_ATOMS + 1)
        QEq matrix extended by a row/column of ones (zero corner entry)
        that enforces the total-charge constraint via a Lagrange
        multiplier.
    """
    # Gaussian widths: MBIS widths scaled by the learned a_QEq parameter.
    s_gauss = s * self.a_QEq
    s2 = s_gauss**2
    # Pairwise combined widths: sigma_ij^2 = sigma_i^2 + sigma_j^2.
    s2_mat = s2[:, :, None] + s2[:, None, :]
    # 1e-16 keeps the sqrt gradient finite at zero (padded entries).
    s_mat = _torch.where(s2_mat > 0, _torch.sqrt(s2_mat + 1e-16), 0)

    device = r_data[0].device
    dtype = r_data[0].dtype

    # Off-diagonal: Coulomb interaction between Gaussian charge densities.
    A = self._get_T0_gaussian(r_data[1], r_data[0], s_mat)

    # Diagonal: Gaussian self-interaction, 1 / (sigma * sqrt(pi)); zero
    # for padded atoms.
    diag_ones = _torch.ones_like(
        A.diagonal(dim1=-2, dim2=-1), dtype=dtype, device=device
    )
    pi = _torch.sqrt(_torch.tensor([_torch.pi], dtype=dtype, device=device))
    new_diag = diag_ones * _torch.where(mask, 1.0 / ((s_gauss + 1e-16) * pi), 0)
    diag_mask = _torch.diag_embed(diag_ones)
    A = diag_mask * _torch.diag_embed(new_diag) + (1.0 - diag_mask) * A

    # Store the dimensions of A.
    n_batch, x, y = A.shape

    # Create a tensor of ones with one more row and column than A.
    B_diag = _torch.ones((n_batch, x + 1), dtype=dtype, device=device)
    B = _torch.diag_embed(B_diag)

    # Copy A into B; padded rows/columns keep their identity entries so
    # the extended system stays non-singular.
    mask_mat = mask[:, :, None] * mask[:, None, :]
    B[:, :x, :y] = _torch.where(mask_mat, A, B[:, :x, :y])

    # Set last row and column to 1 (masked) — the constraint row/column.
    B[:, -1, :-1] = mask.float()
    B[:, :-1, -1] = mask.float()
    # Set the final entry on the diagonal to zero.
    B[:, -1, -1] = 0.0

    return B
@staticmethod
def _get_T0_gaussian(t01, r, s_mat):
"""
Internal method, calculates T0 tensor for Gaussian densities (for QEq).
Parameters
----------
t01: torch.Tensor (N_BATCH, N_ATOMS, N_ATOMS)
T0 tensor for QM atoms.
r: torch.Tensor (N_BATCH, N_ATOMS, N_ATOMS)
Distance matrix for QM atoms.
s_mat: torch.Tensor (N_BATCH, N_ATOMS, N_ATOMS)
Matrix of Gaussian sigmas for QM atoms.
Returns
-------
results: torch.Tensor (N_BATCH, N_ATOMS, N_ATOMS)
"""
sqrt2 = _torch.sqrt(_torch.tensor([2.0], dtype=r.dtype, device=r.device))
return t01 * _torch.where(
s_mat > 0, _torch.erf(r / ((s_mat + 1e-16) * sqrt2)), 0.0
)
@classmethod
def _get_A_thole(
    cls, r_data: Tuple[Tensor, Tensor, Tensor, Tensor], s, q_val, k, a_Thole
):
    """
    Internal method, generates A matrix for induced dipoles prediction
    (Eq. 20 in 10.1021/acs.jctc.2c00914)

    Parameters
    ----------

    r_data: r_data object (output of self._get_r_data)

    s: torch.Tensor (N_BATCH, N_ATOMS,)
        MBIS valence shell widths.

    q_val: torch.Tensor (N_BATCH, N_ATOMS,)
        MBIS charges.

    k: torch.Tensor (N_BATCH, N_ATOMS,)
        Scaling factors for polarizabilities.

    a_Thole: float
        Thole damping factor

    Returns
    -------

    result: torch.Tensor (N_BATCH, N_ATOMS * 3, N_ATOMS * 3)
        The A matrix for induced dipoles prediction.
    """
    # v = -60 * q_val * s^3. NOTE(review): this presumably yields positive
    # volumes because valence charges are negative — confirm against the
    # reference paper (Eq. 20 in 10.1021/acs.jctc.2c00914).
    v = -60 * q_val * s**3
    # Atomic polarizabilities after per-species/per-reference scaling.
    alpha = v * k

    # Thole-damped polarizability products for each atom pair.
    alphap = alpha * a_Thole
    alphap_mat = alphap[:, :, None] * alphap[:, None, :]

    # Scaled distances u^3 = r^3 / sqrt(alpha'_i alpha'_j); zero where the
    # product is non-positive (padded atoms). 1e-16 guards the sqrt
    # gradient at zero.
    au3 = _torch.where(
        alphap_mat > 0, r_data[0] ** 3 / _torch.sqrt(alphap_mat + 1e-16), 0
    )
    # Expand from (N, N) to (3N, 3N) Cartesian blocks.
    au31 = au3.repeat_interleave(3, dim=2)
    au32 = au31.repeat_interleave(3, dim=1)

    # Off-diagonal: negated Thole-damped dipole-dipole interaction tensor.
    A = -cls._get_T2_thole(r_data[2], r_data[3], au32)

    # Diagonal: inverse polarizabilities; 1.0 for padded atoms so the
    # matrix stays invertible.
    alpha3 = alpha.repeat_interleave(3, dim=1)
    new_diag = _torch.where(alpha3 > 0, 1.0 / (alpha3 + 1e-16), 1.0)
    diag_ones = _torch.ones_like(new_diag, dtype=A.dtype, device=A.device)
    mask = _torch.diag_embed(diag_ones)
    A = mask * _torch.diag_embed(new_diag) + (1.0 - mask) * A

    return A
@classmethod
def _get_T2_thole(cls, tr21, tr22, au3):
"""
Internal method, calculates T2 tensor with Thole damping.
Parameters
----------
tr21: torch.Tensor (N_ATOMS * 3, N_ATOMS * 3)
r_data[2]
tr21: torch.Tensor (N_ATOMS * 3, N_ATOMS * 3)
r_data[3]
au3: torch.Tensor (N_ATOMS * 3, N_ATOMS * 3)
Scaled distance matrix (see _get_A_thole).
Returns
-------
result: torch.Tensor (N_ATOMS * 3, N_ATOMS * 3)
"""
return cls._lambda3(au3) * tr21 + cls._lambda5(au3) * tr22
@staticmethod
def _lambda3(au3):
"""
Internal method, calculates r^3 component of T2 tensor with Thole
damping.
Parameters
----------
au3: torch.Tensor (N_ATOMS * 3, N_ATOMS * 3)
Scaled distance matrix (see _get_A_thole).
Returns
-------
result: torch.Tensor (N_ATOMS * 3, N_ATOMS * 3)
"""
return 1 - _torch.exp(-au3)
@staticmethod
def _lambda5(au3):
"""
Internal method, calculates r^5 component of T2 tensor with Thole
damping.
Parameters
----------
au3: torch.Tensor (N_ATOMS * 3, N_ATOMS * 3)
Scaled distance matrix (see _get_A_thole).
Returns
-------
result: torch.Tensor (N_ATOMS * 3, N_ATOMS * 3)
"""
return 1 - (1 + au3) * _torch.exp(-au3)
[docs]
@staticmethod
def get_static_energy(
q_core: Tensor,
q_val: Tensor,
charges_mm: Tensor,
mesh_data: Tuple[Tensor, Tensor, Tensor],
) -> Tensor:
"""
Calculate the static electrostatic energy.
Parameters
----------
q_core: torch.Tensor (N_BATCH, N_QM_ATOMS,)
QM core charges.
q_val: torch.Tensor (N_BATCH, N_QM_ATOMS,)
QM valence charges.
charges_mm: torch.Tensor (N_BATCH, N_MM_ATOMS,)
MM charges.
mesh_data: mesh_data object (output of self._get_mesh_data)
Mesh data object.
Returns
-------
result: torch.Tensor (N_BATCH,)
Static electrostatic energy.
"""
vpot_q_core = EMLEBase._get_vpot_q(q_core, mesh_data[0])
vpot_q_val = EMLEBase._get_vpot_q(q_val, mesh_data[1])
vpot_static = vpot_q_core + vpot_q_val
return _torch.sum(vpot_static * charges_mm, dim=1)
[docs]
@staticmethod
def get_induced_energy(
A_thole: Tensor,
charges_mm: Tensor,
s: Tensor,
mesh_data: Tuple[Tensor, Tensor, Tensor],
mask: Tensor,
) -> Tensor:
"""
Calculate the induced electrostatic energy.
Parameters
----------
A_thole: torch.Tensor (N_BATCH, MAX_QM_ATOMS * 3, MAX_QM_ATOMS * 3)
The A matrix for induced dipoles prediction.
charges_mm: torch.Tensor (N_BATCH, MAX_MM_ATOMS,)
MM charges.
s: torch.Tensor (N_BATCH, MAX_QM_ATOMS,)
MBIS valence shell widths.
mesh_data: mesh_data object (output of self._get_mesh_data)
Mesh data object.
mask: torch.Tensor (N_BATCH, MAX_QM_ATOMS)
Mask for padded coordinates.
Returns
-------
result: torch.Tensor (N_BATCH,)
Induced electrostatic energy.
"""
mu_ind = EMLEBase._get_mu_ind(A_thole, mesh_data, charges_mm, s, mask)
vpot_ind = EMLEBase._get_vpot_mu(mu_ind, mesh_data[2])
return _torch.sum(vpot_ind * charges_mm, dim=1) * 0.5
@staticmethod
def _get_mu_ind(
    A: Tensor,
    mesh_data: Tuple[Tensor, Tensor, Tensor],
    q: Tensor,
    s: Tensor,
    mask: Tensor,
) -> Tensor:
    """
    Internal method, calculates induced atomic dipoles
    (Eq. 20 in 10.1021/acs.jctc.2c00914)

    Parameters
    ----------

    A: torch.Tensor (N_BATCH, MAX_QM_ATOMS * 3, MAX_QM_ATOMS * 3)
        The A matrix for induced dipoles prediction.

    mesh_data: mesh_data object (output of self._get_mesh_data)

    q: torch.Tensor (N_BATCH, MAX_MM_ATOMS,)
        MM point charges.

    s: torch.Tensor (N_BATCH, N_QM_ATOMS,)
        MBIS valence shell widths.

    mask: torch.Tensor (N_BATCH, N_QM_ATOMS)
        Mask for padded coordinates.
        NOTE(review): 'mask' is broadcast against (N_BATCH, N_QM_ATOMS,
        N_MM_ATOMS) below, which requires a trailing singleton dimension —
        confirm the shape callers actually pass.

    Returns
    -------

    result: torch.Tensor (N_BATCH, MAX_QM_ATOMS, 3)
        Array of induced dipoles
    """
    # QM-MM distances recovered from the inverse distances in mesh_data.
    r = 1.0 / mesh_data[0]
    # Slater-damped field factors. NOTE(review): the width is doubled here
    # (s * 2.0) — presumably the Slater width convention; confirm against
    # the reference implementation.
    f1 = _torch.where(mask, EMLEBase._get_f1_slater(r, s[:, :, None] * 2.0), 0.0)
    # Electric field at each QM atom from all MM charges, flattened to
    # (N_BATCH, MAX_QM_ATOMS * 3) to match the shape of A.
    fields = _torch.sum(
        mesh_data[2] * f1[..., None] * q[:, None, :, None], dim=2
    ).reshape(len(s), -1)

    # Solve A @ mu = E for the induced dipoles, then restore (atom, xyz).
    mu_ind = _torch.linalg.solve(A, fields)
    return mu_ind.reshape((mu_ind.shape[0], -1, 3))
@staticmethod
def _get_vpot_q(q, T0):
"""
Internal method to calculate the electrostatic potential.
Parameters
----------
q: torch.Tensor (N_BATCH, MAX_QM_ATOMS,)
QM charges (q_core or q_val).
T0: torch.Tensor (N_BATCH, MAX_QM_ATOMS, MAX_MM_ATOMS)
T0 tensor for QM atoms over MM atom positions.
Returns
-------
result: torch.Tensor (N_BATCH, MAX_MM_ATOMS)
Electrostatic potential over MM atoms.
"""
return _torch.sum(T0 * q[:, :, None], dim=1)
@staticmethod
def _get_vpot_mu(mu: Tensor, T1: Tensor) -> Tensor:
"""
Internal method to calculate the electrostatic potential generated
by atomic dipoles.
Parameters
----------
mu: torch.Tensor (N_BATCH, MAX_QM_ATOMS, 3)
Atomic dipoles.
T1: torch.Tensor (N_BATCH, MAX_QM_ATOMS, MAX_MM_ATOMS, 3)
T1 tensor for QM atoms over MM atom positions.
Returns
-------
result: torch.Tensor (N_BATCH, MAX_MM_ATOMS)
Electrostatic potential over MM atoms.
"""
return -_torch.einsum("ijkl,ijl->ik", T1, mu)
@staticmethod
def _get_mesh_data(
xyz: Tensor, xyz_mesh: Tensor, s: Tensor, mask: Tensor
) -> Tuple[Tensor, Tensor, Tensor]:
"""
Internal method, calculates mesh_data object.
Parameters
----------
xyz: torch.Tensor (N_BATCH, MAX_QM_ATOMS, 3)
Atomic positions.
xyz_mesh: torch.Tensor (N_BATCH, MAX_MM_ATOMS, 3)
MM positions.
s: torch.Tensor (N_BATCH, MAX_QM_ATOMS,)
MBIS valence widths.
mask: torch.Tensor (N_BATCH, MAX_QM_ATOMS)
Mask for padded coordinates.
Returns
-------
result: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
Tuple of mesh data objects.
"""
rr = xyz_mesh[:, None, :, :] - xyz[:, :, None, :]
r = _torch.linalg.norm(rr, ord=2, dim=3)
# Mask for padded coordinates.
r_inv = _torch.where(mask, 1.0 / r, 0.0)
T0_slater = _torch.where(mask, EMLEBase._get_T0_slater(r, s[:, :, None]), 0.0)
return (
r_inv,
T0_slater,
-rr * r_inv[..., None] ** 3,
)
@staticmethod
def _get_f1_slater(r: Tensor, s: Tensor) -> Tensor:
"""
Internal method, calculates damping factors for Slater densities.
Parameters
----------
r: torch.Tensor (N_BATCH, MAX_QM_ATOMS, MAX_MM_ATOMS)
Distances from QM to MM atoms.
s: torch.Tensor (N_BATCH, MAX_QM_ATOMS,)
MBIS valence widths.
Returns
-------
result: torch.Tensor (N_BATCH, MAX_QM_ATOMS, MAX_MM_ATOMS)
"""
return (
EMLEBase._get_T0_slater(r, s) * r
- _torch.exp(-r / s) / s * (0.5 + r / (s * 2)) * r
)
@staticmethod
def _get_T0_slater(r: Tensor, s: Tensor) -> Tensor:
"""
Internal method, calculates T0 tensor for Slater densities.
Parameters
----------
r: torch.Tensor (N_BATCH, MAX_QM_ATOMS, MAX_MM_ATOMS)
Distances from QM to MM atoms.
s: torch.Tensor (N_BATCH, MAX_QM_ATOMS,)
MBIS valence widths.
Returns
-------
results: torch.Tensor (N_BATCH, MAX_QM_ATOMS, MAX_MM_ATOMS)
"""
return (1 - (1 + r / (s * 2)) * _torch.exp(-r / s)) / r