morphoclass.training.training_loop_regression module

Functions for running training of a regression model.

morphoclass.training.training_loop_regression.get_mse_loss(model, loader, device)

Compute the mean squared error loss.

Parameters
  • model – The model used for the forward pass.

  • loader – A data loader yielding all the data on which the loss is computed.

  • device – A torch device.

Returns

The mean squared error loss on all data in the loader.

Return type

float
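As a rough illustration of "the mean squared error loss on all data in the loader", here is a framework-free sketch. It assumes the loader yields (inputs, targets) pairs of equal-length sequences and that the model maps a sequence of inputs to a sequence of predictions. The name mse_over_loader is hypothetical, and the device argument and all PyTorch specifics are left out. One natural reading of "all data" is a per-sample average: squared errors are summed over every sample and divided by the total sample count, so a smaller final batch is not over-weighted.

```python
from typing import Callable, Iterable, Sequence, Tuple

# Hypothetical batch format: (inputs, targets) with one target per input.
Batch = Tuple[Sequence[float], Sequence[float]]


def mse_over_loader(
    model: Callable[[Sequence[float]], Sequence[float]],
    loader: Iterable[Batch],
) -> float:
    """Mean squared error over every sample yielded by ``loader``.

    Squared errors are accumulated across all batches and divided by the
    total number of samples, so batches of different sizes are weighted
    by how many samples they contain.
    """
    total_sq_err = 0.0
    n_samples = 0
    for inputs, targets in loader:
        preds = model(inputs)  # forward pass on one batch
        total_sq_err += sum((p - t) ** 2 for p, t in zip(preds, targets))
        n_samples += len(targets)
    if n_samples == 0:
        raise ValueError("loader yielded no samples")
    return total_sq_err / n_samples
```

For example, with loader = [([1.0, 2.0], [2.0, 5.0]), ([3.0], [7.0])] and a model that doubles its inputs, the sketch returns 2/3, whereas averaging the two per-batch means would give 0.75.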

morphoclass.training.training_loop_regression.train_regression_model(model_fn, dataset_train, dataset_val, n_epochs=500, batch_size=None, lr=0.0005, wd=0.005, optimizer_cls=<class 'torch.optim.adam.Adam'>, device=None, pin_memory=False, n_workers_train=0, n_workers_val=0, silent=False)

Train a regression model on the training set, evaluate it on the validation set, and return the model together with its training history.

Parameters
  • model_fn (callable) – Factory function returning a model instance.

  • dataset_train (torch_geometric.data.Data) – The training set.

  • dataset_val (torch_geometric.data.Data) – The validation set.

  • n_epochs (int) – The number of epochs.

  • batch_size (int) – The batch size.

  • lr (float) – The value of the learning rate.

  • wd (float) – The value of the weight decay.

  • optimizer_cls (callable) – Class for instantiating a PyTorch optimizer.

  • device (str or torch.device) – The device for training.

  • pin_memory (bool) – Passed through to the MorphologyDataLoader class for the training and validation data loaders.

  • n_workers_train (int) – The number of workers for the training data loader.

  • n_workers_val (int) – The number of workers for the validation data loader.

  • silent (bool, default False) – If False, print the training and validation losses to stdout every 50 epochs; if True, print nothing.

Returns

  • model – The trained model.

  • history (dict) – History of training losses.
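The loop below is a minimal, hypothetical sketch of how the parameters of train_regression_model typically interact in an epoch-based training loop; it is not the morphoclass implementation. Assumptions: a toy LinearModel stands in for the model returned by model_fn, plain gradient descent replaces the default Adam optimizer, batch_size=None is treated as full-batch training, and the history keys "train_loss" and "val_loss" are invented for illustration. The device, pin_memory, and worker-count parameters are omitted because they only affect data placement and loading.

```python
from typing import Callable, Dict, List, Optional, Sequence, Tuple

Sample = Tuple[float, float]  # one (x, y) pair


class LinearModel:
    """Hypothetical toy model y_hat = w * x + b, standing in for a real model."""

    def __init__(self) -> None:
        self.w = 0.0
        self.b = 0.0

    def __call__(self, x: float) -> float:
        return self.w * x + self.b


def mse(model: LinearModel, data: Sequence[Sample]) -> float:
    """Mean squared error of ``model`` over all samples in ``data``."""
    return sum((model(x) - y) ** 2 for x, y in data) / len(data)


def train_regression_sketch(
    model_fn: Callable[[], LinearModel],
    dataset_train: Sequence[Sample],
    dataset_val: Sequence[Sample],
    n_epochs: int = 500,
    batch_size: Optional[int] = None,
    lr: float = 0.0005,
    wd: float = 0.005,
    silent: bool = False,
) -> Tuple[LinearModel, Dict[str, List[float]]]:
    model = model_fn()  # fresh model instance from the factory
    if batch_size is None:
        batch_size = len(dataset_train)  # assumption: None means full batch
    history: Dict[str, List[float]] = {"train_loss": [], "val_loss": []}

    for epoch in range(n_epochs):
        for start in range(0, len(dataset_train), batch_size):
            batch = dataset_train[start : start + batch_size]
            n = len(batch)
            # Gradients of the batch MSE with respect to w and b.
            grad_w = sum(2 * (model(x) - y) * x for x, y in batch) / n
            grad_b = sum(2 * (model(x) - y) for x, y in batch) / n
            # L2 weight decay added to the gradient, as weight_decay does in torch.optim.SGD.
            grad_w += wd * model.w
            grad_b += wd * model.b
            # Plain gradient-descent step (the documented default optimizer is Adam).
            model.w -= lr * grad_w
            model.b -= lr * grad_b

        # Record losses on the full training and validation sets once per epoch.
        train_loss = mse(model, dataset_train)
        val_loss = mse(model, dataset_val)
        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss)
        if not silent and (epoch + 1) % 50 == 0:
            print(f"epoch {epoch + 1:4d}: train {train_loss:.4f}, val {val_loss:.4f}")

    return model, history
```

Called as train_regression_sketch(LinearModel, train, val, lr=0.05, wd=0.0, silent=True) on exact points from y = 2x + 1, the parameters approach w = 2 and b = 1; with silent=False the same call prints one line every 50 epochs. Passing the class itself as model_fn works because a class is already a zero-argument factory.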