morphoclass.training.training_loop_regression module
Functions for running training of a regression model.
morphoclass.training.training_loop_regression.get_mse_loss(model, loader, device)
Compute the mean squared error loss.
- Parameters
model – The model used for the forward pass.
loader – A data loader yielding all the data on which the loss is computed.
device – A torch device.
- Returns
The mean squared error loss on all data in the loader.
- Return type
float
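The phrase "the mean squared error loss on all data in the loader" can be illustrated with the pure-Python sketch below. It averages the squared error over every sample rather than averaging the per-batch means, which gives different results when batches have unequal sizes. The function `mse_over_loader`, the list-based batches and the lambda model are hypothetical stand-ins for the PyTorch model and data loader. Whether `get_mse_loss` weights samples exactly this way is an assumption, not something this page states.

```python
from typing import Callable, Iterable, List, Sequence, Tuple

# Hypothetical batch format: (inputs, targets), both plain lists of floats.
Batch = Tuple[Sequence[float], Sequence[float]]


def mse_over_loader(
    model: Callable[[Sequence[float]], List[float]],
    loader: Iterable[Batch],
) -> float:
    """Mean squared error over every sample yielded by ``loader``.

    Sketch only: assumes a sample-weighted average, not a mean of batch means.
    """
    total_sq_err = 0.0
    n_samples = 0
    for inputs, targets in loader:
        preds = model(inputs)  # forward pass on one batch
        total_sq_err += sum((p - t) ** 2 for p, t in zip(preds, targets))
        n_samples += len(targets)
    if n_samples == 0:
        raise ValueError("loader yielded no samples")
    return total_sq_err / n_samples


# Toy model y_hat = 2x and two batches of unequal size.
toy_model = lambda xs: [2.0 * x for x in xs]
toy_loader = [([1.0, 2.0], [2.0, 5.0]), ([3.0], [6.0])]
print(mse_over_loader(toy_model, toy_loader))  # 0.333...
```

Here the per-sample mean is 1/3, whereas averaging the two batch means (0.5 and 0.0) would give 0.25.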
morphoclass.training.training_loop_regression.train_regression_model(model_fn, dataset_train, dataset_val, n_epochs=500, batch_size=None, lr=0.0005, wd=0.005, optimizer_cls=<class 'torch.optim.adam.Adam'>, device=None, pin_memory=False, n_workers_train=0, n_workers_val=0, silent=False)
Train a regression model on the training set, evaluate it on the validation set, and record the training history.
- Parameters
model_fn (callable) – Factory function returning a model instance.
dataset_train (pytorch_geometric.data.Data) – The training set.
dataset_val (pytorch_geometric.data.Data) – The validation set.
n_epochs (int) – The number of epochs.
batch_size (int) – The batch size.
lr (float) – The learning rate.
wd (float) – The weight decay.
optimizer_cls (callable) – Class for instantiating a PyTorch optimizer.
device (str or torch.device) – The device for training.
pin_memory (bool) – Passed through to the MorphologyDataLoader class for the training and validation data loaders.
n_workers_train (int) – The number of workers for training.
n_workers_val (int) – The number of workers for the validation.
silent (bool, default False) – If False, print the training and validation losses to stdout every 50 epochs; if True, suppress this output.
- Returns
model – The trained model.
history (dict) – History of training losses.
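The overall shape of such a training procedure can be pictured with the pure-Python sketch below. A model factory builds a fresh model. In this sketch each epoch then takes a single full-batch step on the training set and appends the training and validation losses to a history dict. A progress line is printed every 50 epochs unless `silent` is set. Everything here is hypothetical and is not the library's implementation. `train_sketch`, `LinearModel`, the list-of-pairs data and the history keys `"train_loss"`/`"val_loss"` are illustrative names. Plain gradient descent stands in for the default Adam optimizer, and batching, devices and data loader workers are omitted.

```python
from typing import Callable, Dict, List, Tuple

# Hypothetical toy data: (x, y) pairs instead of a PyTorch Geometric dataset.
Dataset = List[Tuple[float, float]]


class LinearModel:
    """Hypothetical stand-in for a regression model: y_hat = w * x + b."""

    def __init__(self) -> None:
        self.w = 0.0
        self.b = 0.0

    def __call__(self, x: float) -> float:
        return self.w * x + self.b


def mse(model: LinearModel, data: Dataset) -> float:
    """Mean squared error over every sample in ``data``."""
    return sum((model(x) - y) ** 2 for x, y in data) / len(data)


def train_sketch(
    model_fn: Callable[[], LinearModel],
    data_train: Dataset,
    data_val: Dataset,
    n_epochs: int = 500,
    lr: float = 0.05,
    wd: float = 0.0,
    silent: bool = False,
) -> Tuple[LinearModel, Dict[str, List[float]]]:
    model = model_fn()  # the factory builds a fresh, untrained model
    history: Dict[str, List[float]] = {"train_loss": [], "val_loss": []}
    n = len(data_train)
    for epoch in range(n_epochs):
        # Full-batch gradients of the training MSE with respect to w and b.
        residuals = [model(x) - y for x, y in data_train]
        grad_w = sum(2 * r * x for r, (x, _) in zip(residuals, data_train)) / n
        grad_b = sum(2 * r for r in residuals) / n
        # Coupled weight decay: add wd * param to each gradient (an L2 penalty).
        grad_w += wd * model.w
        grad_b += wd * model.b
        # Plain gradient-descent step (the documented function defaults to Adam).
        model.w -= lr * grad_w
        model.b -= lr * grad_b
        # Record both losses after the update.
        history["train_loss"].append(mse(model, data_train))
        history["val_loss"].append(mse(model, data_val))
        if not silent and (epoch + 1) % 50 == 0:
            print(f"epoch {epoch + 1}: train {history['train_loss'][-1]:.4f}, "
                  f"val {history['val_loss'][-1]:.4f}")
    return model, history


# Fit y = 2x + 1 on five training points and check on two validation points.
train = [(float(x), 2.0 * x + 1.0) for x in range(5)]
val = [(5.0, 11.0), (6.0, 13.0)]
model, history = train_sketch(LinearModel, train, val, silent=True)
print(round(model.w, 3), round(model.b, 3))  # close to 2.0 and 1.0
```

Adding `wd * param` to the gradient mirrors how the `weight_decay` argument of `torch.optim.Adam` behaves, as opposed to the decoupled decay used by `torch.optim.AdamW`.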