Source code for emle.train._trainer

#######################################################################
# EMLE-Engine: https://github.com/chemle/emle-engine
#
# Copyright: 2023-2025
#
# Authors: Lester Hedges   <lester.hedges@gmail.com>
#          Kirill Zinovjev <kzinovjev@gmail.com>
#
# EMLE-Engine is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 2 of the License, or
# (at your option) any later version.
#
# EMLE-Engine is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with EMLE-Engine. If not, see <http://www.gnu.org/licenses/>.
#####################################################################

from loguru import logger as _logger
import os as _os
import sys as _sys
import torch as _torch

from ..models import EMLEAEVComputer as _EMLEAEVComputer
from ..models import EMLEBase as _EMLEBase
from ._gpr import GPR as _GPR
from ._ivm import IVM as _IVM
from ._loss import QEqLoss as _QEqLoss
from ._loss import TholeLoss as _TholeLoss
from ._utils import pad_to_max as _pad_to_max
from ._utils import mean_by_z as _mean_by_z


class EMLETrainer:
    def __init__(
        self,
        emle_base=_EMLEBase,
        qeq_loss=_QEqLoss,
        thole_loss=_TholeLoss,
        log_level=None,
        log_file=None,
    ):
        if emle_base is not _EMLEBase:
            raise TypeError("emle_base must be a reference to EMLEBase")
        self._emle_base = emle_base

        if qeq_loss is not _QEqLoss:
            raise TypeError("qeq_loss must be a reference to QEqLoss")
        self._qeq_loss = qeq_loss

        if thole_loss is not _TholeLoss:
            raise TypeError("thole_loss must be a reference to TholeLoss")
        self._thole_loss = thole_loss

        # First handle the logger.
        if log_level is None:
            log_level = "INFO"
        else:
            if not isinstance(log_level, str):
                raise TypeError("'log_level' must be of type 'str'")

            # Delete whitespace and convert to upper case.
            log_level = log_level.upper().replace(" ", "")

            # Validate the log level.
            if log_level not in _logger._core.levels.keys():
                raise ValueError(
                    f"Unsupported logging level '{log_level}'. Options are: "
                    f"{', '.join(_logger._core.levels.keys())}"
                )
        self._log_level = log_level

        # Validate the log file.
        if log_file is not None:
            if not isinstance(log_file, str):
                raise TypeError("'log_file' must be of type 'str'")

            # Try to create the directory.
            dirname = _os.path.dirname(log_file)
            if dirname != "":
                try:
                    _os.makedirs(dirname, exist_ok=True)
                except:
                    raise IOError(
                        f"Unable to create directory for log file: {log_file}"
                    )
            self._log_file = _os.path.abspath(log_file)
        else:
            self._log_file = _sys.stdout

        # Update the logger.
        _logger.remove()
        _logger.add(self._log_file, level=self._log_level)

    @staticmethod
    def _get_zid_mapping(species):
        """
        Generate the species ID mapping.

        Parameters
        ----------

        species: torch.Tensor(N_SPECIES)
            Species IDs.

        Returns
        -------

        mapping: torch.Tensor(max(species) + 1)
            Mapping from atomic number to species ID. Entries for atomic
            numbers not present in species are set to -1.
        """
        zid_mapping = (
            _torch.ones(max(species) + 1, dtype=_torch.int, device=species.device)
            * -1
        )
        for i, z in enumerate(species):
            zid_mapping[z] = i
        return zid_mapping
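
    # Example (a sketch, not part of the library): for species [1, 6, 8] the
    # mapping has length max(species) + 1 = 9, with entries for unmapped
    # atomic numbers left at -1:
    #
    #     >>> EMLETrainer._get_zid_mapping(_torch.tensor([1, 6, 8]))
    #     tensor([-1,  0, -1, -1, -1, -1,  1, -1,  2], dtype=torch.int32)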

    @staticmethod
    def _write_model_to_file(emle_model, model_filename):
        """
        Write the trained model to a file.

        Parameters
        ----------

        emle_model: dict
            Trained EMLE model.

        model_filename: str
            Filename to save the trained model.
        """
        import scipy.io

        # Detach the tensors, convert to NumPy arrays, and save the model.
        emle_model = {
            k: v.cpu().detach().numpy() if isinstance(v, _torch.Tensor) else v
            for k, v in emle_model.items()
            if v is not None
        }
        scipy.io.savemat(model_filename, emle_model)

    @staticmethod
    def _train_s(s, zid, aev_mols, aev_ivm_allz, sigma):
        """
        Train the s model.

        Parameters
        ----------

        s: torch.Tensor(N_BATCH, N_ATOMS)
            Atomic widths.

        zid: torch.Tensor(N_BATCH, N_ATOMS)
            Species IDs.

        aev_mols: torch.Tensor(N_BATCH, N_ATOMS, N_AEV)
            Atomic environment vectors.

        aev_ivm_allz: List[torch.Tensor(N_REF, N_AEV)]
            Atomic environment vectors of the IVM-selected reference
            environments, one tensor per species.

        sigma: float
            GPR sigma value.

        Returns
        -------

        torch.Tensor(N_SPECIES, MAX_N_REF)
            Fitted reference values of the atomic widths, padded to the
            maximum number of reference environments.
        """
        n_ref = _torch.tensor([_.shape[0] for _ in aev_ivm_allz], device=s.device)
        K_ref_ref_padded, K_mols_ref = _GPR.get_gpr_kernels(
            aev_mols, zid, aev_ivm_allz, n_ref
        )

        ref_values_s = _GPR.fit_atomic_sparse_gpr(
            s, K_mols_ref, K_ref_ref_padded, zid, sigma, n_ref
        )

        return _pad_to_max(ref_values_s)

    @staticmethod
    def _train_model(
        loss_class,
        opt_param_names,
        lr,
        epochs,
        emle_base,
        print_every=10,
        *args,
        **kwargs,
    ):
        """
        Train a model.

        Parameters
        ----------

        loss_class: class
            Loss class.

        opt_param_names: list of str
            List of parameter names to optimize.

        lr: float
            Learning rate.

        epochs: int
            Number of training epochs.

        emle_base: EMLEBase
            EMLEBase instance.

        print_every: int
            How often to print the training progress.

        Returns
        -------

        model
            Trained model.
        """

        def _train_loop(
            loss_instance, optimizer, epochs, print_every=10, *args, **kwargs
        ):
            """
            Perform the training loop.

            Parameters
            ----------

            loss_instance: nn.Module
                Loss instance.

            optimizer: torch.optim.Optimizer
                Optimizer.

            epochs: int
                Number of training epochs.

            print_every: int
                How often to print the training progress.

            args: list
                Positional arguments to pass to the forward method.

            kwargs: dict
                Keyword arguments to pass to the forward method.

            Returns
            -------

            loss
                Forward loss.
            """
            for epoch in range(epochs):
                loss_instance.train()
                optimizer.zero_grad()
                loss, rmse, max_error = loss_instance(*args, **kwargs)
                loss.backward(retain_graph=True)
                optimizer.step()
                if (epoch + 1) % print_every == 0:
                    _logger.info(
                        f"Epoch {epoch+1}: Loss ={loss.item():9.4f} "
                        f"RMSE ={rmse.item():9.4f} "
                        f"Max Error ={max_error.item():9.4f}"
                    )
            return loss

        model = loss_class(emle_base)

        # Loss parameters are registered as "emle_base.<name>", so match the
        # requested names against the attribute name after the first dot.
        opt_parameters = [
            param
            for name, param in model.named_parameters()
            if name.split(".")[1] in opt_param_names
        ]
        optimizer = _torch.optim.Adam(opt_parameters, lr=lr)
        _train_loop(model, optimizer, epochs, print_every, *args, **kwargs)
        return model
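
    # The train() method below proceeds in stages: select reference atomic
    # environments with IVM, fit the atomic widths s by GPR, fit a_QEq and the
    # chi values with the QEq loss, fit a_Thole and k_Z with the Thole loss,
    # and, for alpha_mode="reference", fit the per-reference sqrtk values.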
    def train(
        self,
        z,
        xyz,
        s,
        q_core,
        q_val,
        alpha,
        train_mask,
        alpha_mode="reference",
        sigma=1e-3,
        ivm_thr=0.05,
        epochs=100,
        lr_qeq=0.05,
        lr_thole=0.05,
        lr_sqrtk=0.05,
        print_every=10,
        computer_n_species=None,
        computer_zid_map=None,
        model_filename="emle_model.mat",
        plot_data_filename=None,
        device=_torch.device("cuda"),
        dtype=_torch.float64,
    ):
        """
        Train an EMLE model.

        Parameters
        ----------

        z: numpy.array, List[numpy.array], torch.Tensor, List[torch.Tensor] (N_BATCH, N_ATOMS)
            Atomic numbers.

        xyz: numpy.array, List[numpy.array], torch.Tensor, List[torch.Tensor] (N_BATCH, N_ATOMS, 3)
            Atomic coordinates.

        s: numpy.array, List[numpy.array], torch.Tensor, List[torch.Tensor] (N_BATCH, N_ATOMS)
            Atomic widths.

        q_core: numpy.array, List[numpy.array], torch.Tensor, List[torch.Tensor] (N_BATCH, N_ATOMS)
            Atomic core charges.

        q_val: numpy.array, List[numpy.array], torch.Tensor, List[torch.Tensor] (N_BATCH, N_ATOMS)
            Atomic valence charges.

        alpha: numpy.array, List[numpy.array], torch.Tensor, List[torch.Tensor] (N_BATCH, 3, 3)
            Molecular polarizability tensors.

        train_mask: torch.Tensor(N_BATCH,)
            Mask for training samples. If None, all samples are used.

        alpha_mode: 'species' or 'reference'
            Mode for the polarizability model.

        sigma: float
            GPR sigma value.

        ivm_thr: float
            IVM threshold.

        epochs: int
            Number of training epochs.

        lr_qeq: float
            Learning rate for the QEq model.

        lr_thole: float
            Learning rate for the Thole model.

        lr_sqrtk: float
            Learning rate for sqrtk.

        print_every: int
            How often to print the training progress.

        computer_n_species: int
            Number of species supported by the calculator (for the ANI2x backend).

        computer_zid_map: dict ({emle_zid: calculator_zid})
            Map between EMLE and calculator zid values (for the ANI2x backend).

        model_filename: str or None
            Filename to save the trained model. If None, the model is not saved.

        plot_data_filename: str or None
            Filename to write the plotting data. If None, the data is not written.

        device: torch.device
            Device to use for training.

        dtype: torch.dtype
            Data type to use for training. Default is torch.float64.

        Returns
        -------

        EMLEBase
            Trained EMLE base instance. The trained model parameters are also
            written to model_filename when it is not None.
        """
        # Check the input data.
        assert (
            len(z) == len(xyz) == len(s) == len(q_core) == len(q_val) == len(alpha)
        ), "z, xyz, s, q_core, q_val, and alpha must have the same number of samples"

        # Checks for alpha_mode.
        if not isinstance(alpha_mode, str):
            raise TypeError("'alpha_mode' must be of type 'str'")
        alpha_mode = alpha_mode.lower().replace(" ", "")
        if alpha_mode not in ["species", "reference"]:
            raise ValueError("'alpha_mode' must be 'species' or 'reference'")

        if train_mask is None:
            train_mask = _torch.ones(len(z), dtype=_torch.bool)

        # Prepare the batch data.
        q_val = _pad_to_max(q_val)
        q_core = _pad_to_max(q_core)
        q = q_core + q_val
        q_mol = _torch.sum(q, dim=1)
        z = _pad_to_max(z)
        xyz = _pad_to_max(xyz)

        q_core_train = q_core[train_mask]
        q_mol_train = q_mol[train_mask]
        q_train = q[train_mask]
        z_train = z[train_mask]
        xyz_train = xyz[train_mask]
        s_train = _pad_to_max(s)[train_mask]
        alpha_train = _pad_to_max(alpha)[train_mask]

        species = _torch.unique(z_train[z_train > 0]).to(device)

        # Place on the correct device and set the data type.
        q_mol = q_mol.to(device=device, dtype=dtype)
        q_mol_train = q_mol_train.to(device=device, dtype=dtype)
        z_train = z_train.to(device=device, dtype=_torch.int64)
        xyz_train = xyz_train.to(device=device, dtype=dtype)
        s_train = s_train.to(device=device, dtype=dtype)
        q_core_train = q_core_train.to(device=device, dtype=dtype)
        q_train = q_train.to(device=device, dtype=dtype)
        alpha_train = alpha_train.to(device=device, dtype=dtype)
        species = species.to(device=device, dtype=_torch.int64)
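
        # The unique atomic numbers in the training set define the species
        # supported by the trained model; everything below is indexed by
        # species ID rather than by atomic number.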
        # Get the zid mapping.
        zid_mapping = self._get_zid_mapping(species)
        zid_train = zid_mapping[z_train]

        if computer_n_species is None:
            computer_n_species = len(species)

        # Calculate the AEVs.
        emle_aev_computer = _EMLEAEVComputer(
            num_species=computer_n_species,
            zid_map=computer_zid_map,
            dtype=dtype,
            device=device,
        )
        aev_mols = emle_aev_computer(zid_train, xyz_train)

        # Mask out AEV elements that are zero across the entire training set.
        aev_mask = _torch.sum(aev_mols.reshape(-1, aev_mols.shape[-1]) ** 2, dim=0) > 0
        aev_mols = aev_mols[:, :, aev_mask]
        emle_aev_computer = _EMLEAEVComputer(
            num_species=computer_n_species,
            zid_map=computer_zid_map,
            mask=aev_mask,
            dtype=dtype,
            device=device,
        )

        # "Fit" q_core (just take averages over the entire training set).
        q_core_z = _mean_by_z(q_core_train, zid_train)

        _logger.info("Performing IVM...")

        # Create an array of (molecule_id, atom_id) pairs (as in the full
        # dataset) for the training set. This is needed to be able to locate
        # atoms/molecules in the original dataset that were picked by IVM.
        n_mols, max_atoms = q_train.shape
        atom_ids = _torch.stack(
            _torch.meshgrid(_torch.arange(n_mols), _torch.arange(max_atoms)), dim=-1
        ).to(device)

        # Perform IVM.
        ivm_mol_atom_ids_padded, aev_ivm_allz = _IVM.perform_ivm(
            aev_mols, z_train, atom_ids, species, ivm_thr, sigma
        )
        ref_features = _pad_to_max(aev_ivm_allz)

        ref_mask = ivm_mol_atom_ids_padded[:, :, 0] > -1
        n_ref = _torch.sum(ref_mask, dim=1)

        _logger.info("IVM done. Number of reference environments selected:")
        for atom_z, n in zip(species, n_ref):
            _logger.info(f"{atom_z:2d}: {n:5d}")

        # Fit s (pure GPR, no fancy optimization needed).
        ref_values_s = self._train_s(s_train, zid_train, aev_mols, aev_ivm_allz, sigma)

        # Good for debugging:
        # _torch.autograd.set_detect_anomaly(True)

        # Initial guess for the model parameters.
        params = {
            "a_QEq": _torch.tensor([1.0]).to(device=device, dtype=dtype),
            "a_Thole": _torch.tensor([2.0]).to(device=device, dtype=dtype),
            "ref_values_s": ref_values_s.to(device=device, dtype=dtype),
            "ref_values_chi": _torch.zeros(
                *ref_values_s.shape,
                dtype=ref_values_s.dtype,
                device=device,
            ),
            "k_Z": 0.5 * _torch.ones(len(species), dtype=dtype, device=device),
            "sqrtk_ref": (
                _torch.ones(
                    *ref_values_s.shape,
                    dtype=ref_values_s.dtype,
                    device=device,
                )
                if alpha_mode == "reference"
                else None
            ),
        }

        # Create the EMLE base instance.
        emle_base = self._emle_base(
            params=params,
            n_ref=n_ref,
            ref_features=ref_features,
            q_core=q_core_z,
            emle_aev_computer=emle_aev_computer,
            species=species,
            alpha_mode=alpha_mode,
            device=_torch.device(device),
            dtype=dtype,
        )

        # Fit chi and a_QEq (QEq over chi predicted with GPR).
        _logger.info("Fitting a_QEq and chi values...")
        self._train_model(
            loss_class=self._qeq_loss,
            opt_param_names=["a_QEq", "ref_values_chi"],
            lr=lr_qeq,
            epochs=epochs,
            print_every=print_every,
            emle_base=emle_base,
            atomic_numbers=z_train,
            xyz=xyz_train,
            q_mol=q_mol_train,
            q_target=q_train,
        )

        # Update the GPR constants for chi (inconsistent at this point, since
        # they are not updated after the final epoch).
        self._qeq_loss._update_chi_gpr(emle_base)

        _logger.debug(f"Optimized a_QEq: {emle_base.a_QEq.data.item()}")

        # Fit a_Thole and k_Z (uses volumes predicted by the QEq model).
        _logger.info("Fitting a_Thole and k_Z values...")
        self._train_model(
            loss_class=self._thole_loss,
            opt_param_names=["a_Thole", "k_Z"],
            lr=lr_thole,
            epochs=epochs,
            print_every=print_every,
            emle_base=emle_base,
            atomic_numbers=z_train,
            xyz=xyz_train,
            q_mol=q_mol_train,
            alpha_mol_target=alpha_train,
        )

        _logger.debug(f"Optimized a_Thole: {emle_base.a_Thole.data.item()}")
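
        # In "species" mode a single k_Z scaling is used per element. In
        # "reference" mode the per-reference sqrtk values fitted below refine
        # the polarizabilities per environment (alpha = sqrtk**2 * k_Z * v).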
        # Fit sqrtk_ref (alpha = sqrtk**2 * k_Z * v).
        if alpha_mode == "reference":
            _logger.info("Fitting ref_values_sqrtk values...")
            self._train_model(
                loss_class=self._thole_loss,
                opt_param_names=["ref_values_sqrtk"],
                lr=lr_sqrtk,
                epochs=epochs,
                print_every=print_every,
                emle_base=emle_base,
                atomic_numbers=z_train,
                xyz=xyz_train,
                q_mol=q_mol_train,
                alpha_mol_target=alpha_train,
                opt_sqrtk=True,
                l2_reg=20.0,
            )

            # Update the GPR constants for sqrtk (inconsistent at this point,
            # since they are not updated after the final epoch).
            self._thole_loss._update_sqrtk_gpr(emle_base)

        # Create the final model.
        emle_model = {
            "q_core": q_core_z,
            "a_QEq": emle_base.a_QEq,
            "a_Thole": emle_base.a_Thole,
            "s_ref": emle_base.ref_values_s,
            "chi_ref": emle_base.ref_values_chi,
            "k_Z": emle_base.k_Z,
            "sqrtk_ref": (
                emle_base.ref_values_sqrtk if alpha_mode == "reference" else None
            ),
            "species": species,
            "alpha_mode": alpha_mode,
            "n_ref": n_ref,
            "ref_aev": ref_features,
            "aev_mask": aev_mask,
            "zid_map": emle_aev_computer._zid_map,
            "computer_n_species": computer_n_species,
        }

        if model_filename is not None:
            self._write_model_to_file(emle_model, model_filename)

        if plot_data_filename is None:
            return emle_base

        # Generate the plotting data. Predictions are first made with the
        # "species" polarizability model.
        emle_base._alpha_mode = "species"
        s_pred, q_core_pred, q_val_pred, A_thole = emle_base(
            z.to(device=device, dtype=_torch.int64),
            xyz.to(device=device, dtype=dtype),
            q_mol,
        )
        z_mask = (z > 0).to(device)

        plot_data = {
            "s_emle": s_pred,
            "q_core_emle": q_core_pred,
            "q_val_emle": q_val_pred,
            "alpha_species": self._thole_loss._get_alpha_mol(A_thole, z_mask),
            "z": z,
            "s_qm": s,
            "q_core_qm": q_core,
            "q_val_qm": q_val,
            "alpha_qm": alpha,
        }

        if alpha_mode == "reference":
            emle_base._alpha_mode = "reference"
            *_, A_thole = emle_base(
                z.to(device=device, dtype=_torch.int64),
                xyz.to(device=device, dtype=dtype),
                q_mol,
            )
            plot_data["alpha_reference"] = self._thole_loss._get_alpha_mol(
                A_thole, z_mask
            )

        self._write_model_to_file(plot_data, plot_data_filename)

        return emle_base
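

# Example usage (a sketch, not part of the library API): assuming z, xyz, s,
# q_core, q_val, and alpha are lists of per-molecule arrays prepared
# elsewhere, a model can be trained and saved with:
#
#     trainer = EMLETrainer()
#     emle_base = trainer.train(
#         z, xyz, s, q_core, q_val, alpha,
#         train_mask=None,
#         alpha_mode="reference",
#         model_filename="emle_model.mat",
#         device=_torch.device("cuda"),
#     )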