morphoclass.training.training_loop_cv module
Functions for model training with cross-validation.
morphoclass.training.training_loop_cv.train_model_cv(model_fn, dataset, dataset_prep_fn, n_epochs=500, batch_size=None, n_splits=5, lr=0.0005, wd=0.005, optimizer_cls=torch.optim.Adam, device=None, pin_memory=False, n_workers_train=0, n_workers_val=0, cv_random_state=None)
Train a model with cross-validation and record the training and validation losses and accuracies.
- Parameters
model_fn (callable) – Factory function returning a model instance.
dataset (torch_geometric.data.Data) – The full dataset, which is split into training and validation folds.
dataset_prep_fn (callable) – A function that splits and pre-processes the dataset. It should have the signature (dataset, train_idx, val_idx) and return a tuple (dataset_train, dataset_val).
n_epochs (int) – The number of epochs.
batch_size (int, optional) – The batch size.
n_splits (int) – The number of cross-validation splits.
lr (float) – The learning rate.
wd (float) – The weight decay.
optimizer_cls (callable) – The PyTorch optimizer class to instantiate. Defaults to torch.optim.Adam.
device (str or torch.device, optional) – The device on which to train the model.
pin_memory (bool) – Passed through to the MorphologyDataLoader class for both the training and the validation data loaders.
n_workers_train (int) – The number of workers for training.
n_workers_val (int) – The number of workers for validation.
cv_random_state (int, optional) – The random seed for the stratified K-fold split.
- Returns
history – History of training and validation losses and accuracies.
- Return type
dict
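The contract of dataset_prep_fn can be illustrated with a small, self-contained sketch. The dataset_prep_fn below is hypothetical and not part of morphoclass: it assumes the dataset is a plain list of (feature, label) pairs, subsets it with the given indices, and standardizes the feature using statistics from the training split only. The stratified_kfold helper is likewise a simplified standard-library stand-in for the seeded stratified K-fold split that cv_random_state controls; it is not the splitter train_model_cv uses internally.

```python
import random
from collections import defaultdict


def stratified_kfold(labels, n_splits=5, random_state=None):
    """Yield (train_idx, val_idx) pairs that preserve class proportions.

    Simplified stand-in (not morphoclass or scikit-learn code): the indices of
    each class are shuffled with a seeded RNG and dealt round-robin into folds.
    """
    rng = random.Random(random_state)
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    folds = [[] for _ in range(n_splits)]
    offset = 0
    for label in sorted(by_class):
        indices = by_class[label]
        rng.shuffle(indices)
        for i, idx in enumerate(indices):
            folds[(offset + i) % n_splits].append(idx)
        # Resume dealing where the previous class stopped to keep folds balanced.
        offset += len(indices)

    for k in range(n_splits):
        val_idx = sorted(folds[k])
        val_set = set(val_idx)
        train_idx = [i for i in range(len(labels)) if i not in val_set]
        yield train_idx, val_idx


def dataset_prep_fn(dataset, train_idx, val_idx):
    """Hypothetical prep function with the (dataset, train_idx, val_idx) signature.

    Assumes `dataset` is a list of (feature, label) pairs. The standardization
    statistics come from the training split only and are applied to both splits.
    """
    dataset_train = [dataset[i] for i in train_idx]
    dataset_val = [dataset[i] for i in val_idx]

    xs = [x for x, _ in dataset_train]
    mean = sum(xs) / len(xs)
    std = (sum((x - mean) ** 2 for x in xs) / len(xs)) ** 0.5 or 1.0  # avoid division by zero

    def scale(split):
        return [((x - mean) / std, y) for x, y in split]

    return scale(dataset_train), scale(dataset_val)


# Usage: 10 samples, two balanced classes, 5 folds.
dataset = [(float(i), i % 2) for i in range(10)]
labels = [y for _, y in dataset]
for fold, (train_idx, val_idx) in enumerate(stratified_kfold(labels, n_splits=5, random_state=0)):
    dataset_train, dataset_val = dataset_prep_fn(dataset, train_idx, val_idx)
    # Each fold holds out one sample per class: 8 training and 2 validation samples.
    print(fold, len(dataset_train), len(dataset_val))
```

One reason to hand both index lists to the prep function is that pre-processing statistics can then be fitted on the training fold alone, as in this sketch, so nothing from the validation fold leaks into training.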