morphoclass.training.training_loop_cv module

Functions for model training with cross-validation.

morphoclass.training.training_loop_cv.train_model_cv(model_fn, dataset, dataset_prep_fn, n_epochs=500, batch_size=None, n_splits=5, lr=0.0005, wd=0.005, optimizer_cls=<class 'torch.optim.adam.Adam'>, device=None, pin_memory=False, n_workers_train=0, n_workers_val=0, cv_random_state=None)

Train a model with cross-validation and record the training and validation losses and accuracies.

Parameters
  • model_fn (callable) – Factory function returning a model instance.

  • dataset (torch_geometric.data.Data) – The full dataset.

  • dataset_prep_fn (callable) – A function that splits and pre-processes the dataset. It should accept the arguments (dataset, train_idx, val_idx) and return a tuple of the form (dataset_train, dataset_val).

  • n_epochs (int) – The number of epochs.

  • batch_size (int) – The batch size.

  • n_splits (int) – The number of cross-validation splits.

  • lr (float) – The learning rate.

  • wd (float) – The weight decay.

  • optimizer_cls (callable) – Class for instantiating a PyTorch optimizer.

  • device (str or torch.device) – The device for training.

  • pin_memory (bool) – Passed through to the MorphologyDataLoader class when creating the training and validation data loaders.

  • n_workers_train (int) – The number of workers for the training data loader.

  • n_workers_val (int) – The number of workers for the validation data loader.

  • cv_random_state (int) – A random seed for the stratified K-fold split.
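The parameters n_splits, cv_random_state and dataset_prep_fn work together: a seeded stratified k-fold split produces (train_idx, val_idx) pairs, and each pair is handed to dataset_prep_fn. The following is a minimal stdlib sketch of that contract, not morphoclass's own implementation. The names stratified_kfold_indices and example_dataset_prep_fn are hypothetical, a plain list of (feature, label) pairs stands in for the torch_geometric dataset, and the standardization step is an assumed example of "pre-processing". Libraries such as scikit-learn provide a production-grade StratifiedKFold.

```python
import random
from collections import defaultdict


def stratified_kfold_indices(labels, n_splits=5, random_state=None):
    """Hypothetical helper: yield (train_idx, val_idx) pairs whose class
    proportions roughly match those of the full label list."""
    rng = random.Random(random_state)  # same seed -> same folds
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)
    folds = [[] for _ in range(n_splits)]
    pos = 0
    for label in sorted(by_class):
        members = by_class[label]
        rng.shuffle(members)
        # Deal each class round-robin across the folds so every fold
        # receives a near-equal share of every class.
        for idx in members:
            folds[pos % n_splits].append(idx)
            pos += 1
    for k in range(n_splits):
        val_idx = sorted(folds[k])
        train_idx = sorted(i for j, fold in enumerate(folds) if j != k for i in fold)
        yield train_idx, val_idx


def example_dataset_prep_fn(dataset, train_idx, val_idx):
    """Hypothetical prep function with the documented signature."""
    train = [dataset[i] for i in train_idx]
    val = [dataset[i] for i in val_idx]
    # Fit the scaling statistics on the training fold only, so no
    # information from the validation fold leaks into pre-processing.
    values = [x for x, _ in train]
    mean = sum(values) / len(values)
    std = (sum((x - mean) ** 2 for x in values) / len(values)) ** 0.5 or 1.0

    def scale(samples):
        return [((x - mean) / std, y) for x, y in samples]

    return scale(train), scale(val)


# Toy usage: 10 samples in three classes, split into 3 folds.
labels = [0, 0, 0, 1, 1, 1, 2, 2, 2, 2]
dataset = [(float(i), label) for i, label in enumerate(labels)]
splits = list(stratified_kfold_indices(labels, n_splits=3, random_state=42))
train_set, val_set = example_dataset_prep_fn(dataset, *splits[0])
```

Fitting the pre-processing statistics inside dataset_prep_fn, rather than once on the full dataset, is what keeps each validation fold an honest held-out estimate.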

Returns

history – History of training and validation losses and accuracies.

Return type

dict
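Because model_fn is a factory rather than a model instance, a fresh model can be built for every fold. The sketch below illustrates that per-fold structure under stated assumptions; it is not the body of train_model_cv. The names cross_validate, MajorityClassModel and split_only are hypothetical, the history keys "train_acc" and "val_acc" are assumed for illustration only, and the toy model replaces the PyTorch epoch loop that, presumably, builds an optimizer from optimizer_cls with lr and wd.

```python
from collections import Counter


class MajorityClassModel:
    """Toy stand-in for a real model: predicts the most common training label."""

    def __init__(self):
        self.prediction = None

    def fit(self, samples):
        self.prediction = Counter(y for _, y in samples).most_common(1)[0][0]

    def accuracy(self, samples):
        return sum(y == self.prediction for _, y in samples) / len(samples)


def cross_validate(model_fn, dataset, dataset_prep_fn, splits):
    """Hypothetical sketch: train a fresh model per fold, collect accuracies."""
    history = {"train_acc": [], "val_acc": []}  # assumed keys, for illustration
    for train_idx, val_idx in splits:
        ds_train, ds_val = dataset_prep_fn(dataset, train_idx, val_idx)
        # Calling the factory per fold means every fold starts from a newly
        # initialized model; reusing one instance would carry what it learned
        # on earlier folds into later ones.
        model = model_fn()
        # Assumption: with PyTorch this step would be an epoch loop using an
        # optimizer such as optimizer_cls(model.parameters(), lr=lr, weight_decay=wd).
        model.fit(ds_train)
        history["train_acc"].append(model.accuracy(ds_train))
        history["val_acc"].append(model.accuracy(ds_val))
    return history


def split_only(dataset, train_idx, val_idx):
    """Minimal prep function: index into the dataset, no pre-processing."""
    return [dataset[i] for i in train_idx], [dataset[i] for i in val_idx]


dataset = [(i, 0) for i in range(6)] + [(i, 1) for i in range(6, 8)]
splits = [([0, 1, 2, 6], [3, 4, 5, 7]), ([3, 4, 5, 7], [0, 1, 2, 6])]
history = cross_validate(MajorityClassModel, dataset, split_only, splits)
```

Per-fold metrics can then be averaged to summarize cross-validated performance.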