morphoclass.training package
Subpackages
Submodules
- morphoclass.training.cli module
- morphoclass.training.rd_dataset_prep module
- morphoclass.training.tns_utils module
- morphoclass.training.trainers module
- morphoclass.training.training_config module
- morphoclass.training.training_log module
- morphoclass.training.training_loop_cv module
- morphoclass.training.training_loop_loo module
- morphoclass.training.training_loop_regression module
- morphoclass.training.training_loop_simple module
- morphoclass.training.transfer_learning module
Module contents
Utilities for setting up and running model training.
morphoclass.training.create_k_folds(n_splits, labels, seed)
Generate indices for a stratified k-fold split.
For a given seed, the generated splits are always the same.
While the validation indices may be sorted, the training indices are randomly permuted, so there is no need to shuffle the samples again at training time.
- Parameters
n_splits (int) – The number of folds (k) in the k-fold split.
labels (list_like) – The labels upon which to generate splits.
seed (int) – A random seed to ensure determinism.
- Returns
folds – A list of length n_splits containing tuples of the form (train_idx, val_idx).
- Return type
list
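A minimal sketch of the behaviour described above, written in plain Python so that it runs without morphoclass. It assumes a round-robin stratification scheme (shuffle the indices of each class, then deal them into the folds); the actual implementation may differ, for example by delegating to scikit-learn's StratifiedKFold. The name create_k_folds_sketch is hypothetical.

```python
import random
from collections import defaultdict


def create_k_folds_sketch(n_splits, labels, seed):
    """Hypothetical stand-in for create_k_folds (not the morphoclass code).

    Returns a list of (train_idx, val_idx) tuples. Validation indices are
    sorted; training indices are randomly permuted.
    """
    rng = random.Random(seed)  # local RNG: same seed -> same folds

    # Group sample indices by label so every fold gets a share of each class.
    by_label = defaultdict(list)
    for idx, label in enumerate(labels):
        by_label[label].append(idx)

    # Shuffle each class and deal its members round-robin into the folds.
    fold_members = [[] for _ in range(n_splits)]
    offset = 0
    for label in sorted(by_label):
        members = by_label[label]
        rng.shuffle(members)
        for i, idx in enumerate(members):
            fold_members[(offset + i) % n_splits].append(idx)
        offset += len(members)  # continue dealing where the last class stopped

    folds = []
    for k in range(n_splits):
        val_idx = sorted(fold_members[k])
        val_set = set(val_idx)
        train_idx = [i for i in range(len(labels)) if i not in val_set]
        rng.shuffle(train_idx)  # permuted, so no extra shuffling is needed later
        folds.append((train_idx, val_idx))
    return folds
```

Carrying the offset across classes keeps the fold sizes balanced even when class sizes are not multiples of n_splits.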
morphoclass.training.make_transform(dataset, feature_extractor=None, n_features=None, fitted_scaler=None, scaler_cls='FeatureRobustScaler', edge_weight=None)
Create a transform for a given dataset.
The created transform applies the following steps in order:
MakeCopy('y', 'y_str', 'edge_index', 'tmd_neurites') → FeatureExtractor → Scaler → EdgeWeight (optional)
- Parameters
dataset (morphoclass.data.MorphologyDataset) – A dataset.
feature_extractor – A node feature extractor from morphoclass.transforms.
n_features (int or list or tuple) – The number of features feature_extractor extracts. Must be provided if feature_extractor is not None.
fitted_scaler – The feature scaler. If None, a new scaler is fitted. Typical use case: pass None for the training set to fit a new scaler, then re-use this scaler for the validation set.
scaler_cls (str) – Only used if fitted_scaler is None, in which case a new scaler of this class is fitted.
edge_weight (int or None) – If not None then an additional edge weight transform will be included in the overall transform.
- Returns
transform – The overall transform.
fitted_scaler – The fitted scaler that is part of the overall transform. Useful when it needs to be re-used, for example for a validation set.
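The fitted_scaler return value supports the usual fit-on-train, re-use-on-validation pattern. The sketch below shows only the scaler part of that pattern, using a hypothetical RobustScalerSketch that centres on the median and divides by the interquartile range (the common meaning of a robust scaler). It is an assumption that FeatureRobustScaler behaves this way, and none of the names below belong to morphoclass; the real make_transform works on a MorphologyDataset and composes the other steps as well.

```python
import statistics


class RobustScalerSketch:
    """Hypothetical robust scaler: (x - median) / IQR. Not the morphoclass class."""

    def __init__(self):
        self.center = None
        self.scale = None

    def fit(self, values):
        # Quartiles of the training values; the middle cut point is the median.
        q1, median, q3 = statistics.quantiles(values, n=4)
        self.center = median
        self.scale = (q3 - q1) or 1.0  # guard against a zero IQR
        return self

    def transform(self, values):
        return [(v - self.center) / self.scale for v in values]


def make_transform_sketch(values, fitted_scaler=None):
    """Return (transform, scaler); fit a new scaler only if none is given."""
    if fitted_scaler is None:
        fitted_scaler = RobustScalerSketch().fit(values)
    return fitted_scaler.transform, fitted_scaler


# Fit on the training values, then re-use the same scaler for validation.
train_values = [1.0, 2.0, 3.0, 4.0, 5.0]
transform_train, scaler = make_transform_sketch(train_values)
transform_val, _ = make_transform_sketch([3.0, 6.0], fitted_scaler=scaler)
print(transform_val([3.0, 6.0]))  # [0.0, 1.0]
```

Fitting only on the training values keeps validation statistics out of the preprocessing, which is the reason the scaler is returned for re-use.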
morphoclass.training.prepare_rd_split(dataset, train_idx, val_idx, feature_extractor=None, n_features=None, scaler_cls='FeatureRobustScaler', edge_weight=400)
Split a dataset into train/val sets and set up radial distance features.
- Parameters
dataset (morphoclass.data.MorphologyDataset) – The full dataset with train and validation data.
train_idx (iterable) – The indices of the training-subset.
val_idx (iterable) – The indices of the validation-subset.
feature_extractor – A node feature extractor from morphoclass.transforms.
n_features (int or list or tuple) – The number of features feature_extractor extracts. Must be provided if feature_extractor is not None.
scaler_cls (str) – The name of the scaler class to use. Must be a scaler class available in morphoclass.transforms.scalers.
edge_weight (int or None) – If not None then an additional edge weight transform will be included in the overall transform.
- Returns
dataset_train (morphoclass.data.MorphologyDataset) – The training-subset.
dataset_val (morphoclass.data.MorphologyDataset) – The validation-subset.
morphoclass.training.prepare_rd_transforms(dataset_train, dataset_val=None, feature_extractor=None, n_features=None, scaler_cls='FeatureRobustScaler', edge_weight=None)
Prepare radial distance transforms for the given datasets.
- Parameters
dataset_train (morphoclass.data.MorphologyDataset) – The training set.
dataset_val (morphoclass.data.MorphologyDataset, optional) – The validation set.
feature_extractor – A node feature extractor from morphoclass.transforms.
n_features (int or list or tuple) – The number of features feature_extractor extracts. Must be provided if feature_extractor is not None.
scaler_cls (str) – The name of the scaler class to use. Must be a scaler class available in morphoclass.transforms.scalers.
edge_weight (int or None) – If not None then an additional edge weight transform will be included in the overall transform.
- Returns
datasets – If dataset_val is None, only dataset_train is returned, with the appropriate feature extractor attached to it. Otherwise the tuple (dataset_train, dataset_val) is returned, where dataset_val uses a feature scaler fitted on the training data.
- Return type
morphoclass.data.MorphologyDataset or tuple
morphoclass.training.prepare_smart_split(dataset, train_idx, val_idx, scaler_cls='FeatureRobustScaler', edge_weight=None)
Prepare a dataset split given the indices.
The split is prepared by constructing subsets using the train and validation indices, and by setting up transforms that extract the correct features.
Datasets containing IPCs/HPCs must be treated differently from other datasets: the former use projections of the coordinates onto the y-axis as features, the latter the usual radial distances.
In other words, layers L2 and L6 use projections, while layers L3, L4, and L5 use radial distances.
Note that the apicals are assumed to be correctly oriented already; otherwise the projection features won’t work correctly. The best way to ensure this is to include transforms.OrientApicals in the pre_transform of the MorphologyDataset.
- Parameters
dataset (morphoclass.data.MorphologyDataset) – The dataset to draw the samples from.
train_idx (list_like) – The indices for the training set.
val_idx (list_like) – The indices for the validation set.
scaler_cls (str, optional) – The name of the scaler class to be used for scaling features.
edge_weight (int, optional) – The scale for the edge weight feature to be extracted.
- Returns
dataset_train (MorphologyDataset) – The training subset of the dataset
dataset_val (MorphologyDataset) – The validation subset of the dataset
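The difference between projection and radial-distance features can be illustrated on a single node. This sketch assumes that node coordinates are relative to the soma and that the apical dendrite points along +y (which is what transforms.OrientApicals is meant to ensure). The function name and the layer-string check are hypothetical; they only mirror the rule stated in the description above.

```python
import math

# Layers whose datasets contain IPCs/HPCs and therefore use y-projections
# (rule taken from the description above; the string format is an assumption).
PROJECTION_LAYERS = {"L2", "L6"}


def node_feature(coords, layer):
    """Return one scalar feature for a node at soma-relative coords (x, y, z)."""
    x, y, z = coords
    if layer in PROJECTION_LAYERS:
        # Projection onto the y-axis: keeps the sign, so a neurite growing
        # downwards gives negative values.
        return y
    # Radial distance from the soma: identical for upward and downward nodes.
    return math.sqrt(x * x + y * y + z * z)
```

Because the radial distance discards direction, it cannot tell an inverted neurite from an upright one; the signed projection can, which is also why the orientation has to be fixed beforehand.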
morphoclass.training.reset_seeds(numpy_seed=0, torch_seed=0)
Reset random seeds for numpy and torch.
- Parameters
numpy_seed (int or None) – The random seed for numpy. If None then the seed won’t be set.
torch_seed (int or None) – The random seed for torch. If None then the seed won’t be set.
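A minimal sketch of what such a function typically does, assuming it calls numpy.random.seed and torch.manual_seed; the real implementation may seed additional generators. The name reset_seeds_sketch is hypothetical, and the torch import is optional here only so that the sketch runs without PyTorch installed.

```python
import numpy as np

try:
    import torch
except ImportError:  # keep the sketch runnable without PyTorch
    torch = None


def reset_seeds_sketch(numpy_seed=0, torch_seed=0):
    """Hypothetical equivalent of reset_seeds; a seed of None is skipped."""
    if numpy_seed is not None:
        np.random.seed(numpy_seed)  # reseeds the global NumPy RNG
    if torch_seed is not None and torch is not None:
        torch.manual_seed(torch_seed)  # seeds PyTorch's generators on all devices
```

Passing None leaves the corresponding generator in its current state, so two consecutive draws after reset_seeds_sketch(None, None) still differ.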
morphoclass.training.train_model(model_fn, dataset_train, dataset_val, n_epochs=500, batch_size=None, acc_threshold=0.0, lr=0.0005, wd=0.005, optimizer_cls=<class 'torch.optim.adam.Adam'>, device=None, pin_memory=False, n_workers_train=0, n_workers_val=0, silent=False)
Train a model on a single train/validation split and save loss and accuracy.
- Parameters
model_fn (callable) – Factory function returning a model instance.
dataset_train (pytorch_geometric.data.Data) – The training set.
dataset_val (pytorch_geometric.data.Data) – The validation set.
n_epochs (int) – The number of epochs.
batch_size (int) – The batch size.
acc_threshold (float) – Start checkpointing the best-performing model once the accuracy surpasses this value. This avoids excessive checkpointing early in training, when the accuracy improves at almost every step.
lr (float) – The value of the learning rate.
wd (float) – The value of the weight decay.
optimizer_cls (callable) – Class for instantiating a PyTorch optimizer.
device (str or torch.device) – The device for training.
pin_memory (bool) – Passed through to the MorphologyDataLoader class for the train and validation data loader.
n_workers_train (int) – The number of workers for training.
n_workers_val (int) – The number of workers for the validation.
silent (bool) – If False, the training progress is logged to stdout.
- Returns
history – History of training and validation losses and accuracies.
- Return type
dict
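The acc_threshold rule can be sketched as a filter over per-epoch validation accuracies. This assumes that a checkpoint is written whenever the accuracy exceeds both the threshold and the best value seen so far; whether morphoclass compares with > or >=, and what exactly it saves, is not stated here. checkpoint_epochs is a hypothetical helper.

```python
def checkpoint_epochs(val_accuracies, acc_threshold=0.0):
    """Return the epochs at which a new best model would be checkpointed."""
    best_acc = acc_threshold
    saved = []
    for epoch, acc in enumerate(val_accuracies):
        # Only models strictly better than both the threshold and every
        # previous checkpoint are saved.
        if acc > best_acc:
            best_acc = acc
            saved.append(epoch)
    return saved


accs = [0.2, 0.3, 0.4, 0.5, 0.55, 0.5, 0.7]
print(checkpoint_epochs(accs))                     # [0, 1, 2, 3, 4, 6]
print(checkpoint_epochs(accs, acc_threshold=0.5))  # [4, 6]
```

With the default threshold of 0.0 nearly every early epoch triggers a checkpoint; raising the threshold skips that phase.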
morphoclass.training.train_model_cv(model_fn, dataset, dataset_prep_fn, n_epochs=500, batch_size=None, n_splits=5, lr=0.0005, wd=0.005, optimizer_cls=<class 'torch.optim.adam.Adam'>, device=None, pin_memory=False, n_workers_train=0, n_workers_val=0, cv_random_state=None)
Train a model with cross-validation and save loss and accuracy.
- Parameters
model_fn (callable) – Factory function returning a model instance.
dataset (pytorch_geometric.data.Data) – The full dataset.
dataset_prep_fn (callable) – A function for splitting and pre-processing the dataset. Should have the signature (dataset, train_idx, val_idx) and return a tuple of the form (dataset_train, dataset_val).
n_epochs (int) – The number of epochs.
batch_size (int) – The batch size.
n_splits (int) – The number of cross-validation splits.
lr (float) – The value of the learning rate.
wd (float) – The value of the weight decay.
optimizer_cls (callable) – Class for instantiating a PyTorch optimizer.
device (str or torch.device) – The device for training.
pin_memory (bool) – Passed through to the MorphologyDataLoader class for the train and validation data loader.
n_workers_train (int) – The number of workers for training.
n_workers_val (int) – The number of workers for the validation.
cv_random_state (int) – A random seed for the stratified K-fold split.
- Returns
history – History of training and validation losses and accuracies.
- Return type
dict
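The interplay of model_fn, dataset_prep_fn and the folds can be sketched as a generic cross-validation driver. Everything here is a hypothetical stand-in: train_one_split plays the role of a single train/validation run such as train_model, and returning a list of per-fold histories is an assumption rather than the format morphoclass uses.

```python
def cross_validate_sketch(model_fn, dataset, dataset_prep_fn, folds, train_one_split):
    """Run one training per fold and collect the per-fold histories.

    folds: list of (train_idx, val_idx), e.g. from a stratified k-fold split.
    dataset_prep_fn(dataset, train_idx, val_idx) -> (dataset_train, dataset_val)
    train_one_split(model, dataset_train, dataset_val) -> history dict
    """
    histories = []
    for train_idx, val_idx in folds:
        # Preparation (e.g. fitting a scaler) sees only this fold's training part.
        dataset_train, dataset_val = dataset_prep_fn(dataset, train_idx, val_idx)
        model = model_fn()  # fresh, untrained model for every fold
        histories.append(train_one_split(model, dataset_train, dataset_val))
    return histories
```

Calling model_fn() inside the loop, rather than re-using one instance, is what keeps the folds independent of each other.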
morphoclass.training.train_model_loo(model_fn, dataset, dataset_prep_fn, n_classes, n_epochs=500, batch_size=None, lr=0.0005, wd=0.005, optimizer_cls=<class 'torch.optim.adam.Adam'>, device=None, pin_memory=False, n_workers_train=0, n_workers_val=0)
Train a model with leave-one-out and save the training history.
- Parameters
model_fn (callable) – Factory function returning a model instance.
dataset (pytorch_geometric.data.Data) – The full dataset.
dataset_prep_fn (callable) – A function for splitting and pre-processing the dataset. Should have the signature (dataset, train_idx, val_idx) and return a tuple of the form (dataset_train, dataset_val).
n_classes (int) – The number of classes to predict. Needed to allocate the prediction array for the history cache.
n_epochs (int) – The number of epochs.
batch_size (int) – The batch size.
lr (float) – The value of the learning rate.
wd (float) – The value of the weight decay.
optimizer_cls (callable) – Class for instantiating a PyTorch optimizer.
device (str or torch.device) – The device for training.
pin_memory (bool) – Passed through to the MorphologyDataLoader class for the train and validation data loader.
n_workers_train (int) – The number of workers for training.
n_workers_val (int) – The number of workers for the validation.
- Returns
history – History of training and validation losses, accuracies, prediction probabilities, and predictions.
- Return type
dict
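Leave-one-out is the limiting case of k-fold with exactly one validation sample per split, which is why n_classes is needed up front: one row of class probabilities can be stored per sample as each split finishes. The helpers below are hypothetical and only illustrate the split generation and the allocation; the real history layout may differ.

```python
def loo_splits(n_samples):
    """Yield (train_idx, val_idx) pairs with exactly one validation sample."""
    for i in range(n_samples):
        train_idx = [j for j in range(n_samples) if j != i]
        yield train_idx, [i]


def allocate_loo_predictions(n_samples, n_classes):
    """One row of class probabilities per sample, filled in split by split."""
    # A list comprehension creates independent rows; [[0.0] * n_classes] * n_samples
    # would alias the same row n_samples times.
    return [[0.0] * n_classes for _ in range(n_samples)]
```

Since every sample is the validation set exactly once, the filled array contains one out-of-sample prediction per sample when the loop ends.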
morphoclass.training.train_regression_model(model_fn, dataset_train, dataset_val, n_epochs=500, batch_size=None, lr=0.0005, wd=0.005, optimizer_cls=<class 'torch.optim.adam.Adam'>, device=None, pin_memory=False, n_workers_train=0, n_workers_val=0, silent=False)
Train a regression model on a single train/validation split and save the training history.
- Parameters
model_fn (callable) – Factory function returning a model instance.
dataset_train (pytorch_geometric.data.Data) – The training set.
dataset_val (pytorch_geometric.data.Data) – The validation set.
n_epochs (int) – The number of epochs.
batch_size (int) – The batch size.
lr (float) – The value of the learning rate.
wd (float) – The value of the weight decay.
optimizer_cls (callable) – Class for instantiating a PyTorch optimizer.
device (str or torch.device) – The device for training.
pin_memory (bool) – Passed through to the MorphologyDataLoader class for the train and validation data loader.
n_workers_train (int) – The number of workers for training.
n_workers_val (int) – The number of workers for the validation.
silent (bool, default False) – If False, print training and validation losses to stdout every 50 epochs.
- Returns
model – The trained model.
history (dict) – History of training losses.
morphoclass.training.transfer_learning_curves(input_csv, features_dir, image_path, checkpoints_directory, model_class, model_params, dataset_name, feature_extractor_name, optimizer_class, optimizer_params, n_epochs, batch_size, seed, checkpoint_path_pretrained, frozen_backbone=False)
Generate transfer-learning (TL) curves by iterating through splits of different sizes.
morphoclass.training.transfer_learning_report(results_file, checkpoint_paths_pretrained, input_csv, features_dir, dataset_name, feature_extractor_name, model_class, model_params, optimizer_class, optimizer_params, n_epochs, batch_size, seed)
Generate the transfer-learning (TL) report.