morphoclass.training.training_loop_loo module
Functions for model training with the leave-one-out strategy.
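In leave-one-out (LOO) cross-validation, each of the n samples is held out once for validation while the model is trained on the remaining n - 1 samples, which gives n folds. The following minimal sketch generates these index splits, assuming samples are addressed by integer indices. The helper name `loo_splits` is hypothetical and not part of morphoclass. The source does not say how `train_model_loo` builds its splits internally. scikit-learn's `sklearn.model_selection.LeaveOneOut` yields the same splits.

```python
from typing import Iterator, List, Tuple


def loo_splits(n_samples: int) -> Iterator[Tuple[List[int], List[int]]]:
    """Yield (train_idx, val_idx) pairs for leave-one-out cross-validation.

    Hypothetical helper for illustration; not part of morphoclass.
    """
    for held_out in range(n_samples):
        # Train on every sample except the held-out one.
        train_idx = [i for i in range(n_samples) if i != held_out]
        # Validate on the single held-out sample.
        val_idx = [held_out]
        yield train_idx, val_idx


for train_idx, val_idx in loo_splits(3):
    print(train_idx, val_idx)
# [1, 2] [0]
# [0, 2] [1]
# [0, 1] [2]
```

Every sample is validated exactly once, so LOO needs n separate training runs. For this reason it is usually reserved for small datasets.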
morphoclass.training.training_loop_loo.train_model_loo(model_fn, dataset, dataset_prep_fn, n_classes, n_epochs=500, batch_size=None, lr=0.0005, wd=0.005, optimizer_cls=<class 'torch.optim.adam.Adam'>, device=None, pin_memory=False, n_workers_train=0, n_workers_val=0)
Train a model with leave-one-out and save the training history.
- Parameters
model_fn (callable) – Factory function returning a model instance.
dataset (pytorch_geometric.data.Data) – The full dataset.
dataset_prep_fn (callable) – A function for splitting and pre-processing the dataset. Should have the signature (dataset, train_idx, val_idx) and return a tuple of the form (dataset_train, dataset_val).
n_classes (int) – The number of classes to predict. Needed to allocate the prediction array for the history cache.
n_epochs (int) – The number of epochs.
batch_size (int) – The batch size.
lr (float) – The value of the learning rate.
wd (float) – The value of the weight decay.
optimizer_cls (callable) – Class for instantiating a PyTorch optimizer.
device (str or torch.device) – The device for training.
pin_memory (bool) – Passed through to the MorphologyDataLoader class for the training and validation data loaders.
n_workers_train (int) – The number of workers for training.
n_workers_val (int) – The number of workers for validation.
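The `dataset_prep_fn` parameter must take `(dataset, train_idx, val_idx)` and return `(dataset_train, dataset_val)`. The sketch below shows a function with that shape. It assumes the dataset can be indexed by integer positions, and a plain list of feature values stands in for the real morphology dataset. The name `split_and_standardize` and its standardization step are hypothetical illustrations. The statistics are fitted on the training split only, so the held-out sample cannot leak into the pre-processing.

```python
from statistics import mean, pstdev
from typing import List, Sequence, Tuple


def split_and_standardize(
    dataset: Sequence[float],
    train_idx: Sequence[int],
    val_idx: Sequence[int],
) -> Tuple[List[float], List[float]]:
    """Hypothetical dataset_prep_fn: split, then standardize features.

    The mean and standard deviation are computed on the training split only
    and then applied to both splits, so the held-out sample does not
    influence the pre-processing.
    """
    train = [dataset[i] for i in train_idx]
    val = [dataset[i] for i in val_idx]
    mu = mean(train)
    sigma = pstdev(train) or 1.0  # avoid division by zero for constant data
    dataset_train = [(x - mu) / sigma for x in train]
    dataset_val = [(x - mu) / sigma for x in val]
    return dataset_train, dataset_val


# Hold out the last sample (index 3) and prepare both splits.
dataset_train, dataset_val = split_and_standardize([1.0, 2.0, 3.0, 10.0], [0, 1, 2], [3])
print(dataset_train, dataset_val)
```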
- Returns
history – History of training and validation losses, accuracies, prediction probabilities, and predictions.
- Return type
dict
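The exact keys of the returned `history` dictionary are not listed here. The sketch below therefore uses plain lists to show how LOO predictions are commonly summarized. Each sample receives a class-probability vector of length `n_classes` while it is held out, which is why `n_classes` is needed to allocate the prediction array in advance. The LOO accuracy is then the fraction of held-out samples whose most probable class is correct. The helper `loo_accuracy`, the list layout, and the probability values are all hypothetical and do not describe the actual contents of `history`.

```python
from typing import Sequence


def loo_accuracy(
    probabilities: Sequence[Sequence[float]],
    labels: Sequence[int],
) -> float:
    """Fraction of held-out samples whose most probable class is correct.

    probabilities[i] is the class-probability vector (length n_classes)
    recorded for sample i while it was held out; labels[i] is its true class.
    """
    correct = 0
    for probs, label in zip(probabilities, labels):
        predicted = max(range(len(probs)), key=probs.__getitem__)  # argmax
        correct += int(predicted == label)
    return correct / len(labels)


n_samples, n_classes = 4, 3
# Preallocate one probability row per sample; this is where n_classes is needed.
probabilities = [[0.0] * n_classes for _ in range(n_samples)]
# Hypothetical values filled in fold by fold (fold i holds out sample i).
probabilities[0] = [0.7, 0.2, 0.1]
probabilities[1] = [0.1, 0.8, 0.1]
probabilities[2] = [0.3, 0.3, 0.4]
probabilities[3] = [0.5, 0.4, 0.1]
labels = [0, 1, 2, 1]
print(loo_accuracy(probabilities, labels))  # 0.75
```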