morphoclass.training.training_loop_simple module

Functions for simple model training without cross-validation.

morphoclass.training.training_loop_simple.train_model(model_fn, dataset_train, dataset_val, n_epochs=500, batch_size=None, acc_threshold=0.0, lr=0.0005, wd=0.005, optimizer_cls=<class 'torch.optim.adam.Adam'>, device=None, pin_memory=False, n_workers_train=0, n_workers_val=0, silent=False)

Train a model without cross-validation and record the training and validation losses and accuracies.

Parameters
  • model_fn (callable) – Factory function returning a model instance.

  • dataset_train (torch_geometric.data.Dataset) – The training set.

  • dataset_val (torch_geometric.data.Dataset) – The validation set.

  • n_epochs (int) – The number of epochs.

  • batch_size (int) – The batch size.

  • acc_threshold (float) – Start checkpointing the best-performing model once the accuracy surpasses this value. This avoids writing a checkpoint at almost every epoch early in training, when the accuracy improves nearly every time.

  • lr (float) – The learning rate.

  • wd (float) – The weight decay.

  • optimizer_cls (callable) – Class for instantiating a PyTorch optimizer.

  • device (str or torch.device) – The device for training.

  • pin_memory (bool) – Passed through to the MorphologyDataLoader class for both the training and the validation data loaders.

  • n_workers_train (int) – The number of workers for the training data loader.

  • n_workers_val (int) – The number of workers for the validation data loader.

  • silent (bool) – If true, the training progress will not be logged to stdout.

Returns

history – History of training and validation losses and accuracies.

Return type

dict
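
The acc_threshold rule from the parameter list can be illustrated with a small, self-contained sketch in plain Python. This is not morphoclass's implementation: it assumes that a checkpoint is written whenever an epoch's accuracy exceeds both acc_threshold and the best accuracy seen so far. The helper name checkpoint_epochs and the accuracy values are hypothetical.

```python
def checkpoint_epochs(accuracies, acc_threshold=0.0):
    """Return the epochs at which a new best model would be checkpointed.

    Hypothetical helper illustrating the assumed rule: checkpoint only when
    the accuracy exceeds acc_threshold AND improves on the best so far.
    """
    # Starting from the threshold means best_acc is always
    # max(acc_threshold, best accuracy seen so far).
    best_acc = acc_threshold
    epochs = []
    for epoch, acc in enumerate(accuracies):
        if acc > best_acc:
            best_acc = acc
            epochs.append(epoch)  # a real loop would save the model state here
    return epochs


# Made-up per-epoch accuracies that improve quickly at first.
accs = [0.30, 0.45, 0.52, 0.61, 0.58, 0.70, 0.69]
print(checkpoint_epochs(accs))                     # [0, 1, 2, 3, 5]
print(checkpoint_epochs(accs, acc_threshold=0.6))  # [3, 5]
```

With the default acc_threshold=0.0, every early improvement triggers a checkpoint; raising the threshold to 0.6 skips the early, rapidly improving epochs and keeps only the later, more meaningful ones.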