emle.train

The train module contains functionality for training EMLE models. The EMLETrainer class is used internally by the emle-train script but can also be used directly in Python.

class emle.train.EMLETrainer(emle_base=<class 'emle.models._emle_base.EMLEBase'>, qeq_loss=<class 'emle.train._loss.QEqLoss'>, thole_loss=<class 'emle.train._loss.TholeLoss'>, log_level=None, log_file=None)
train(z, xyz, s, q_core, q_val, alpha, train_mask, alpha_mode='reference', sigma=0.001, ivm_thr=0.05, epochs=100, lr_qeq=0.05, lr_thole=0.05, lr_sqrtk=0.05, print_every=10, computer_n_species=None, computer_zid_map=None, model_filename='emle_model.mat', plot_data_filename=None, device=device(type='cuda'), dtype=torch.float64)

Train an EMLE model.

Parameters:
  • z (numpy.array, List[numpy.array], torch.Tensor, List[torch.Tensor] (N_BATCH, N_ATOMS)) – Atomic numbers.

  • xyz (numpy.array, List[numpy.array], torch.Tensor, List[torch.Tensor] (N_BATCH, N_ATOMS, 3)) – Atomic coordinates.

  • s (numpy.array, List[numpy.array], torch.Tensor, List[torch.Tensor] (N_BATCH, N_ATOMS)) – Atomic widths.

  • q_core (numpy.array, List[numpy.array], torch.Tensor, List[torch.Tensor] (N_BATCH, N_ATOMS)) – Atomic core charges.

  • q_val (numpy.array, List[numpy.array], torch.Tensor, List[torch.Tensor] (N_BATCH, N_ATOMS)) – Atomic valence charges.

  • alpha (numpy.array, List[numpy.array], torch.Tensor, List[torch.Tensor] (N_BATCH, 3, 3)) – Molecular polarizability tensors (one 3×3 tensor per sample).

  • train_mask (torch.Tensor(N_BATCH,)) – Mask for training samples.

  • alpha_mode ('species' or 'reference') – Mode for polarizability model.

  • sigma (float) – Sigma value for the Gaussian process regression (GPR).

  • ivm_thr (float) – Threshold for the informative vector machine (IVM) selection.

  • epochs (int) – Number of training epochs.

  • lr_qeq (float) – Learning rate for QEq model.

  • lr_thole (float) – Learning rate for Thole model.

  • lr_sqrtk (float) – Learning rate for sqrtk.

  • print_every (int) – How often to print training progress.

  • computer_n_species (int) – Number of species supported by the calculator (for the ANI2x backend).

  • computer_zid_map (dict ({emle_zid: calculator_zid})) – Map between EMLE and calculator zid values (for the ANI2x backend).

  • model_filename (str or None) – Filename to save the trained model. If None, the model is not saved.

  • plot_data_filename (str or None) – Filename to write plotting data. If None, data is not written.

  • device (torch.device) – Device to use for training.

  • dtype (torch.dtype) – Data type to use for training. Default is torch.float64.

Returns:

Trained EMLE model.

Return type:

dict
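
Example:

The following is a minimal sketch of calling train() directly from Python, not a definitive recipe. The helpers make_train_mask and run_training are hypothetical names introduced here for illustration. The sketch assumes that train_mask is a boolean tensor in which True marks a training sample, and that the input arrays have already been loaded in the shapes listed above. Because the default device in the signature is CUDA, the sketch passes device explicitly and falls back to the CPU when no GPU is available.

```python
import random


def make_train_mask(n_samples, train_fraction=0.8, seed=0):
    """Hypothetical helper: pick a random subset of samples for training.

    Returns a list of booleans of length n_samples, where True marks a
    training sample (assumed convention for train_mask).
    """
    rng = random.Random(seed)
    n_train = int(round(train_fraction * n_samples))
    train_idx = set(rng.sample(range(n_samples), n_train))
    return [i in train_idx for i in range(n_samples)]


def run_training(z, xyz, s, q_core, q_val, alpha, train_fraction=0.8):
    """Hypothetical wrapper around EMLETrainer.train().

    The inputs are expected in the shapes documented above, e.g.
    z: (N_BATCH, N_ATOMS), xyz: (N_BATCH, N_ATOMS, 3), alpha: (N_BATCH, 3, 3).
    """
    # Imported lazily so the mask helper works without torch or emle installed.
    import torch
    from emle.train import EMLETrainer

    # Assumption: the mask is a boolean tensor of shape (N_BATCH,).
    train_mask = torch.tensor(
        make_train_mask(len(z), train_fraction), dtype=torch.bool
    )

    # The signature defaults to CUDA; fall back to the CPU if it is unavailable.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    trainer = EMLETrainer()
    return trainer.train(
        z=z,
        xyz=xyz,
        s=s,
        q_core=q_core,
        q_val=q_val,
        alpha=alpha,
        train_mask=train_mask,
        alpha_mode="reference",
        epochs=100,
        model_filename="emle_model.mat",  # pass None to skip saving
        device=device,
        dtype=torch.float64,
    )
```

With the documented defaults, the trained model is returned as a dict and also written to emle_model.mat. Passing model_filename=None keeps the result in memory only.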