morphoclass.training.trainers module

A collection of model trainers.

class morphoclass.training.trainers.ConcateCNNetTrainer(concatecnnet, dataset, images, labels, optimizer)

Bases: object

Trainer class for the ConcateCNNet network.

Parameters
  • concatecnnet (morphoclass.models.ConcateCNNet) – An instance of the concate-cnnet model.

  • dataset (morphoclass.data.MorphologyDataset) – The dataset with all morphologies (train and validation). The exact splits are specified in the train_split method.

  • images (iterable) – The collection of all persistence images (train and validation). The exact splits are specified in the train_split method.

  • labels – All data labels.

  • optimizer (torch.optim.Optimizer) – An instance of a torch optimizer.

train_split(train_idx, val_idx, batch_size=32, n_epochs=500, verbose=False)

Train the concate-cnnet on a given split.

Parameters
  • train_idx (iterable of int) – The train set indices.

  • val_idx (iterable of int) – The validation set indices.

  • batch_size (int, default 32) – The batch size.

  • n_epochs (int, default 500) – The number of epochs.

  • verbose (bool, default False) – If true, print the training progress statistics to stdout.

Returns

probabilities – The history of predictions per epoch, with shape (n_epochs, n_val_samples, n_classes).

Return type

np.ndarray
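
The returned history stacks one block of per-class scores per epoch. As a minimal sketch of how such a history might be reduced to class predictions, the snippet below mirrors the documented (n_epochs, n_val_samples, n_classes) layout with plain nested lists. The numbers and the helper name predicted_classes are made up for illustration; with the actual np.ndarray the same reduction would be probabilities[-1].argmax(axis=1).

```python
# Toy history with the documented layout (n_epochs, n_val_samples, n_classes).
# All values are invented for illustration only.
history = [
    [[0.6, 0.4], [0.5, 0.5], [0.3, 0.7]],  # epoch 0
    [[0.8, 0.2], [0.4, 0.6], [0.1, 0.9]],  # epoch 1
]


def predicted_classes(epoch_probs):
    """Return the index of the highest-scoring class for every sample."""
    return [max(range(len(row)), key=row.__getitem__) for row in epoch_probs]


# Predictions of the last epoch, one class index per validation sample.
final_predictions = predicted_classes(history[-1])
print(final_predictions)  # [0, 1, 1]
```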

class morphoclass.training.trainers.ConcateNetTrainer(concatenet, dataset, diagrams, labels, optimizer)

Bases: object

Trainer class for the ConcateNet network.

Parameters
  • concatenet (morphoclass.models.ConcateNet) – An instance of the concate-net model.

  • dataset (morphoclass.data.MorphologyDataset) – The dataset with all morphologies (train and validation). The exact splits are specified in the train_split method.

  • diagrams – All persistence diagrams (train and validation). The exact splits are specified in the train_split method.

  • labels – All data labels.

  • optimizer (torch.optim.Optimizer) – An instance of a torch optimizer.

static perslay_collate_fn(samples)

Batch together persistence diagrams and their labels.

Parameters

samples (iterable of tuple) – The samples to batch. Each element of the iterable is a tuple of a persistence diagram tensor and a label.

Returns

  • diagram_batch (torch.Tensor) – A batch of persistence diagrams.

  • y_batch (torch.Tensor) – A batch of labels.

  • point_index (torch.Tensor) – Segmentation map for samples in the batch.
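
Persistence diagrams generally have different numbers of points, so they cannot simply be stacked. One common way to batch them, sketched below, is to concatenate all points and keep an index that records which sample each point came from. This is an assumption about what "segmentation map" means here, not the library's implementation: collate_diagrams is a hypothetical stand-in that works on plain lists of (birth, death) pairs rather than tensors.

```python
def collate_diagrams(samples):
    """Batch variable-length diagrams (hypothetical stand-in, not library code)."""
    diagram_batch = []  # points of all diagrams, concatenated in order
    y_batch = []        # one label per diagram
    point_index = []    # for every point, the position of its diagram in the batch
    for sample_pos, (diagram, label) in enumerate(samples):
        diagram_batch.extend(diagram)
        y_batch.append(label)
        point_index.extend([sample_pos] * len(diagram))
    return diagram_batch, y_batch, point_index


samples = [
    ([(0.0, 1.0), (0.5, 2.0)], 0),  # diagram with two (birth, death) points
    ([(0.1, 0.9)], 1),              # diagram with a single point
]
diagram_batch, y_batch, point_index = collate_diagrams(samples)
print(point_index)  # [0, 0, 1]
```

With such an index, a per-sample reduction (for example a sum over the points of each diagram) can be computed from the flat batch without padding.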

train_split(train_idx, val_idx, batch_size=32, n_epochs=500, verbose=False)

Train the concate-net on a given split.

Parameters
  • train_idx (iterable of int) – The train set indices.

  • val_idx (iterable of int) – The validation set indices.

  • batch_size (int, default 32) – The batch size.

  • n_epochs (int, default 500) – The number of epochs.

  • verbose (bool, default False) – If true, print the training progress statistics to stdout.

Returns

probabilities – The history of predictions per epoch, with shape (n_epochs, n_val_samples, n_classes).

Return type

np.ndarray
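
Both train_split methods take the train and validation indices as arguments, so the split itself is left to the caller. The following is a minimal sketch of one way to produce such index lists with a shuffled hold-out split. The helper holdout_split, its val_fraction parameter, and the seed are illustrative assumptions and not part of morphoclass.

```python
import random


def holdout_split(n_samples, val_fraction=0.2, seed=0):
    """Split range(n_samples) into disjoint train and validation index lists."""
    rng = random.Random(seed)  # local generator, keeps the split reproducible
    indices = list(range(n_samples))
    rng.shuffle(indices)
    n_val = max(1, int(round(n_samples * val_fraction)))
    # The first n_val shuffled indices form the validation set, the rest the train set.
    return sorted(indices[n_val:]), sorted(indices[:n_val])


train_idx, val_idx = holdout_split(n_samples=10)
print(len(train_idx), len(val_idx))  # 8 2
```

The resulting lists could then be passed as the train_idx and val_idx arguments of train_split.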

class morphoclass.training.trainers.Trainer(net: nn.Module, dataset: Dataset, optimizer: Optimizer, loader_class: type[DataLoader])

Bases: object

A trainer for morphology classifiers.

Parameters
  • net – A morphoclass classifier instance.

  • dataset – A morphology dataset.

  • optimizer – An optimizer instance.

  • loader_class – The morphology dataset loader class.

static acc(labels: torch.Tensor, probas: torch.Tensor) → float

Compute the accuracy score given labels and probabilities.
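
To make the computation concrete, here is a minimal sketch of an accuracy score derived from labels and per-class probabilities. It assumes that the predicted class is the one with the highest probability, which the documentation above does not state explicitly, and it uses plain lists instead of tensors. The function name accuracy is illustrative.

```python
def accuracy(labels, probas):
    """Fraction of samples whose highest-probability class equals the label."""
    if len(labels) != len(probas):
        raise ValueError("labels and probas must have the same length")
    # Assumption: the prediction is the argmax over the class probabilities.
    predicted = [max(range(len(row)), key=row.__getitem__) for row in probas]
    n_correct = sum(int(p == y) for p, y in zip(predicted, labels))
    return n_correct / len(labels)


labels = [0, 1, 1, 0]
probas = [[0.9, 0.1], [0.2, 0.8], [0.7, 0.3], [0.6, 0.4]]
score = accuracy(labels, probas)
print(score)  # 0.75
```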

data_loader(idx: Sequence[int] | None = None, batch_size: int = 1, shuffle: bool = False) → DataLoader

Construct a data loader for a given data subset.

get_latent_features(idx: Sequence[int] | None = None, batch_size: int = 1) → torch.Tensor

Compute the forward pass and collect the latent features.

Parameters
  • idx – A sequence of indices specifying a data subset. If None, the whole dataset is used.

  • batch_size – The batch size to use for the forward pass.

Returns

A tensor of shape (n_samples, *n_features) that contains the activations of the feature extractor part of the model.

Return type

torch.Tensor

predict(idx: Sequence[int] | None = None, batch_size: int = 1) → tuple[torch.Tensor, torch.Tensor, torch.Tensor]

Run inference on a data subset, get losses, probabilities, and labels.

Parameters
  • idx – An index subset of the dataset. If not given, the entire dataset is used.

  • batch_size – The batch size at which to process data. The results don’t depend on the batch size. A bigger batch size means faster inference, but a batch size that is too big might exhaust the memory.

Returns

  • losses (torch.Tensor) – The non-reduced loss per sample.

  • logits (torch.Tensor) – The predicted logits.

  • labels (torch.Tensor) – All sample labels.
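
The Returns list above names logits. Should class probabilities be needed, a softmax over the logits of each sample is the usual conversion. The snippet below is a minimal sketch of that step on a plain list of floats. It is an illustration under that assumption, not a description of what predict does internally.

```python
import math


def softmax(logits):
    """Convert one sample's logits into probabilities that sum to 1."""
    shift = max(logits)  # subtracting the maximum avoids overflow in exp
    exps = [math.exp(value - shift) for value in logits]
    total = sum(exps)
    return [e / total for e in exps]


probs = softmax([2.0, 0.0, -1.0])
print(probs)
```

The ordering of the classes is preserved, so the argmax of the probabilities equals the argmax of the logits.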

train(n_epochs: int, batch_size: int, train_idx: Sequence[int], val_idx: Sequence[int] | None = None, load_best: bool = False, progress_bar: Callable[[Iterable], Iterable] = <built-in function iter>) → dict

Train and evaluate on dataset subsets specified by indices.

Parameters
  • train_idx (array of int) – The indices of the training samples.

  • val_idx (array of int) – The indices of the evaluation samples.

  • batch_size (int) – The batch size.

  • n_epochs (int) – The number of epochs.

  • load_best (bool) – If true, the model with the best validation accuracy will be restored at the end of training. Only possible if val_idx is not None.

  • save_latent_features (bool) – If true, the latent features of the feature extractor part of the model will be saved under history["latent_features"].

  • progress_bar (callable) – A callable that wraps an iterable over the epoch numbers and creates a progress bar.

Returns

history – Dictionary with predictions, probabilities, training and validation losses and accuracies.

Return type

dict
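
The progress_bar parameter is typed as Callable[[Iterable], Iterable] and defaults to the built-in iter, so any callable that takes an iterable of epoch numbers and yields them back unchanged fits the signature. As a minimal sketch, the hypothetical wrapper below reports each epoch on stderr as it is handed out. The name simple_progress and the message format are illustrative.

```python
import sys


def simple_progress(epochs):
    """Yield the epoch numbers unchanged while reporting progress on stderr."""
    epochs = list(epochs)  # materialise once so the total count is known
    total = len(epochs)
    for position, epoch in enumerate(epochs, start=1):
        print(f"starting epoch {position}/{total}", file=sys.stderr)
        yield epoch


# The wrapper behaves like iter(): the same items come out in the same order.
seen = list(simple_progress(range(3)))
print(seen)  # [0, 1, 2]
```

A callable of this shape could be passed as progress_bar=simple_progress. Third-party progress bars that are called on an iterable and return an iterable, such as tqdm.tqdm, follow the same pattern.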