emle.train
The train module contains functionality for training EMLE models. The EMLETrainer class is used internally by the emle-train script but can also be used directly in Python.
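Here is a rough sketch of direct use from Python. It assumes NumPy, PyTorch, and the emle package are installed. The snippet builds placeholder inputs with the shapes documented for EMLETrainer.train and passes them to the trainer. The helpers make_dummy_inputs and run_training are hypothetical, and the random values are not physically meaningful, so replace them with real reference data before training.

```python
import numpy as np


def make_dummy_inputs(n_batch, n_atoms, train_fraction=0.8, seed=0):
    """Build placeholder arrays with the shapes documented for EMLETrainer.train.

    Hypothetical helper: the values are random and only the shapes matter.
    Returns (inputs, mask).
    """
    rng = np.random.default_rng(seed)
    inputs = {
        # Atomic numbers (N_BATCH, N_ATOMS), drawn at random from H, C and O.
        "z": rng.choice([1, 6, 8], size=(n_batch, n_atoms)),
        # Atomic coordinates (N_BATCH, N_ATOMS, 3).
        "xyz": rng.normal(size=(n_batch, n_atoms, 3)),
        # Per-atom widths and core/valence charges (N_BATCH, N_ATOMS).
        "s": rng.uniform(0.3, 0.6, size=(n_batch, n_atoms)),
        "q_core": rng.uniform(0.5, 1.5, size=(n_batch, n_atoms)),
        "q_val": rng.normal(size=(n_batch, n_atoms)),
        # One 3x3 polarizability tensor per sample (N_BATCH, 3, 3).
        "alpha": rng.normal(size=(n_batch, 3, 3)),
    }
    # Boolean mask (N_BATCH,): True marks a sample used for training.
    n_train = int(round(train_fraction * n_batch))
    mask = np.zeros(n_batch, dtype=bool)
    mask[rng.permutation(n_batch)[:n_train]] = True
    return inputs, mask


def run_training(inputs, mask, use_cpu=True):
    """Pass the arrays to EMLETrainer.train (needs the emle package and torch)."""
    import torch
    from emle.train import EMLETrainer

    trainer = EMLETrainer()
    # The documented default device is CUDA, so request the CPU explicitly
    # on machines without a GPU.
    device = torch.device("cpu") if use_cpu else torch.device("cuda")
    return trainer.train(
        **inputs,
        train_mask=torch.from_numpy(mask),  # train_mask is documented as a tensor
        device=device,
        model_filename=None,  # None means the model is not saved
    )
```

Because train_mask is documented as a torch.Tensor, the sketch converts the NumPy mask with torch.from_numpy. The other array arguments are passed as NumPy arrays, which the parameter list permits.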
- class emle.train.EMLETrainer(emle_base=<class 'emle.models._emle_base.EMLEBase'>, qeq_loss=<class 'emle.train._loss.QEqLoss'>, thole_loss=<class 'emle.train._loss.TholeLoss'>, log_level=None, log_file=None)
- train(z, xyz, s, q_core, q_val, alpha, train_mask, alpha_mode='reference', sigma=0.001, ivm_thr=0.05, epochs=100, lr_qeq=0.05, lr_thole=0.05, lr_sqrtk=0.05, print_every=10, computer_n_species=None, computer_zid_map=None, model_filename='emle_model.mat', plot_data_filename=None, device=device(type='cuda'), dtype=torch.float64)
Train an EMLE model.
- Parameters:
z (numpy.array, List[numpy.array], torch.Tensor, List[torch.Tensor] (N_BATCH, N_ATOMS)) – Atomic numbers.
xyz (numpy.array, List[numpy.array], torch.Tensor, List[torch.Tensor] (N_BATCH, N_ATOMS, 3)) – Atomic coordinates.
s (numpy.array, List[numpy.array], torch.Tensor, List[torch.Tensor] (N_BATCH, N_ATOMS)) – Atomic widths.
q_core (numpy.array, List[numpy.array], torch.Tensor, List[torch.Tensor] (N_BATCH, N_ATOMS)) – Atomic core charges.
q_val (numpy.array, List[numpy.array], torch.Tensor, List[torch.Tensor] (N_BATCH, N_ATOMS)) – Atomic valence charges.
alpha (numpy.array, List[numpy.array], torch.Tensor, List[torch.Tensor] (N_BATCH, 3, 3)) – Molecular polarizability tensors, one 3×3 tensor per sample.
train_mask (torch.Tensor (N_BATCH,)) – Mask selecting the samples used for training.
alpha_mode ('species' or 'reference') – Mode for the polarizability model.
sigma (float) – GPR sigma value.
ivm_thr (float) – IVM threshold.
epochs (int) – Number of training epochs.
lr_qeq (float) – Learning rate for the QEq model.
lr_thole (float) – Learning rate for the Thole model.
lr_sqrtk (float) – Learning rate for sqrtk.
print_every (int) – How often, in epochs, to print training progress.
computer_n_species (int) – Number of species supported by the calculator (for the ANI2x backend).
computer_zid_map (dict ({emle_zid: calculator_zid})) – Map between EMLE and calculator zid values (for the ANI2x backend).
model_filename (str or None) – Filename to save the trained model. If None, the model is not saved.
plot_data_filename (str or None) – Filename to write plotting data. If None, data is not written.
device (torch.device) – Device to use for training.
dtype (torch.dtype) – Data type to use for training. Default is torch.float64.
- Returns:
Trained EMLE model.
- Return type:
dict
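The parameter shapes above can be checked before training with a small helper. This is a hypothetical sketch and not part of emle. It assumes the single-array form of each input (not the List[...] form) and values that np.asarray can convert, and it raises ValueError on the first mismatch.

```python
import numpy as np


def check_training_shapes(z, xyz, s, q_core, q_val, alpha, train_mask):
    """Check array inputs against the shapes documented for EMLETrainer.train.

    Hypothetical helper covering only the single-array form of each input.
    Returns (n_batch, n_atoms) on success and raises ValueError otherwise.
    """
    z = np.asarray(z)
    if z.ndim != 2:
        raise ValueError(f"z must have shape (N_BATCH, N_ATOMS), got {z.shape}")
    n_batch, n_atoms = z.shape
    # Expected shape of every other argument, derived from z.
    expected = {
        "xyz": (np.asarray(xyz), (n_batch, n_atoms, 3)),
        "s": (np.asarray(s), (n_batch, n_atoms)),
        "q_core": (np.asarray(q_core), (n_batch, n_atoms)),
        "q_val": (np.asarray(q_val), (n_batch, n_atoms)),
        "alpha": (np.asarray(alpha), (n_batch, 3, 3)),
        "train_mask": (np.asarray(train_mask), (n_batch,)),
    }
    for name, (array, shape) in expected.items():
        if array.shape != shape:
            raise ValueError(f"{name} must have shape {shape}, got {array.shape}")
    return n_batch, n_atoms
```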