morphoclass.training.trainers module
A collection of model trainers.
-
class morphoclass.training.trainers.ConcateCNNetTrainer(concatecnnet, dataset, images, labels, optimizer)
Bases: object
Trainer class for the ConcateCNNet network.
- Parameters
concatecnnet (morphoclass.models.ConcateCNNet) – An instance of the concate-cnnet model.
dataset (morphoclass.data.MorphologyDataset) – The dataset with all morphologies (train and validation), the exact splits are specified in the train_split method.
images (iterable) – The dataset with all persistence images (train and validation), the exact splits are specified in the train_split method.
labels – All data labels.
optimizer (torch.optim.Optimizer) – An instance of a torch optimizer.
-
train_split(train_idx, val_idx, batch_size=32, n_epochs=500, verbose=False)
Train the concate-cnnet on a given split.
- Parameters
train_idx (iterable of int) – The train set indices.
val_idx (iterable of int) – The validation set indices.
batch_size (int, default 32) – The batch size.
n_epochs (int, default 500) – The number of epochs.
verbose (bool, default False) – If true, print training progress statistics to stdout.
- Returns
probabilities – The history of predicted probabilities per epoch, with shape (n_epochs, n_val_samples, n_classes).
- Return type
np.ndarray
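Because the returned array keeps one slice per epoch, per-epoch accuracies and the final predictions can be read off directly. The following minimal NumPy sketch uses a randomly generated stand-in for probabilities (not real training output) and a hypothetical val_labels array. It only illustrates how the documented shape can be post-processed.

```python
import numpy as np

rng = np.random.default_rng(0)
n_epochs, n_val_samples, n_classes = 5, 4, 3

# Stand-in for the array returned by train_split: random scores normalised
# so that each (epoch, sample) row sums to 1, like class probabilities.
scores = rng.random((n_epochs, n_val_samples, n_classes))
probabilities = scores / scores.sum(axis=-1, keepdims=True)
val_labels = np.array([0, 2, 1, 0])  # hypothetical validation labels

# Predicted class per epoch and sample: shape (n_epochs, n_val_samples).
predictions = probabilities.argmax(axis=-1)

# Validation accuracy per epoch: shape (n_epochs,).
accuracy_per_epoch = (predictions == val_labels).mean(axis=1)

# Predictions of the final epoch only: shape (n_val_samples,).
final_predictions = predictions[-1]
print(accuracy_per_epoch.shape, final_predictions.shape)  # prints: (5,) (4,)
```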
-
class morphoclass.training.trainers.ConcateNetTrainer(concatenet, dataset, diagrams, labels, optimizer)
Bases: object
Trainer class for the ConcateNet network.
- Parameters
concatenet (morphoclass.models.ConcateNet) – An instance of the concate-net model.
dataset (morphoclass.data.MorphologyDataset) – The dataset with all morphologies (train and validation), the exact splits are specified in the train_split method.
diagrams – All persistence diagrams (train and validation), the exact splits are specified in the train_split method.
labels – All data labels.
optimizer (torch.optim.Optimizer) – An instance of a torch optimizer.
-
static perslay_collate_fn(samples)
Batch together persistence diagrams and their labels.
- Parameters
samples (iterable of tuple) – The samples to batch. Each element of the iterable is a tuple of a persistence diagram tensor and a label.
- Returns
diagram_batch (torch.Tensor) – A batch of persistence diagrams.
y_batch (torch.Tensor) – A batch of labels.
point_index (torch.Tensor) – Segmentation map for samples in the batch.
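The following sketch is a hypothetical NumPy analogue of this collate step. The real function works on torch tensors. It assumes that each diagram is an (n_points, 2) array of (birth, death) pairs and that point_index assigns every point to the position of its sample in the batch, which is one common reading of "segmentation map". The function name collate_diagrams is illustrative only.

```python
import numpy as np


def collate_diagrams(samples):
    """Hypothetical NumPy analogue of perslay_collate_fn.

    Each sample is assumed to be (diagram, label), where diagram has
    shape (n_points, 2) and holds (birth, death) pairs.
    """
    diagrams, labels = zip(*samples)
    # Stack the points of all diagrams into one (total_points, 2) array.
    diagram_batch = np.concatenate(diagrams, axis=0)
    y_batch = np.asarray(labels)
    # Assumption: point_index[k] is the batch position of the sample owning point k.
    sizes = [len(d) for d in diagrams]
    point_index = np.repeat(np.arange(len(diagrams)), sizes)
    return diagram_batch, y_batch, point_index


samples = [
    (np.array([[0.0, 1.0], [0.5, 2.0]]), 1),
    (np.array([[0.2, 0.7]]), 0),
    (np.array([[0.1, 0.4], [0.3, 0.9], [0.0, 3.0]]), 2),
]
diagram_batch, y_batch, point_index = collate_diagrams(samples)
print(point_index)  # prints: [0 0 1 2 2 2]

# The segmentation map lets per-point values be pooled per sample, here summed.
persistence = diagram_batch[:, 1] - diagram_batch[:, 0]
pooled = np.zeros(len(samples))
np.add.at(pooled, point_index, persistence)
```

Diagrams have different numbers of points, so they cannot be stacked into one rectangular tensor. Flattening them and keeping an index map avoids padding.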
-
train_split(train_idx, val_idx, batch_size=32, n_epochs=500, verbose=False)
Train the concate-net on a given split.
- Parameters
train_idx (iterable of int) – The train set indices.
val_idx (iterable of int) – The validation set indices.
batch_size (int, default 32) – The batch size.
n_epochs (int, default 500) – The number of epochs.
verbose (bool, default False) – If true, print training progress statistics to stdout.
- Returns
probabilities – The history of predicted probabilities per epoch, with shape (n_epochs, n_val_samples, n_classes).
- Return type
np.ndarray
-
class morphoclass.training.trainers.Trainer(net: nn.Module, dataset: Dataset, optimizer: Optimizer, loader_class: type[DataLoader])
Bases: object
A trainer for morphology classifiers.
- Parameters
net – A morphoclass classifier instance.
dataset – A morphology dataset.
optimizer – An optimizer instance.
loader_class – The morphology dataset loader class.
-
static acc(labels: torch.Tensor, probas: torch.Tensor) → float
Compute the accuracy score given labels and probabilities.
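The following sketch shows one common way to compute such an accuracy score. It uses NumPy arrays in place of torch tensors. It assumes that probas has shape (n_samples, n_classes) and that accuracy is the fraction of samples whose most probable class equals the label. The actual implementation is not shown in this reference.

```python
import numpy as np


def acc(labels, probas):
    """Sketch: fraction of samples whose most probable class equals the label."""
    predictions = np.argmax(probas, axis=1)
    return float(np.mean(predictions == labels))


labels = np.array([0, 1, 2, 1])
probas = np.array([
    [0.7, 0.2, 0.1],  # predicted 0, correct
    [0.1, 0.8, 0.1],  # predicted 1, correct
    [0.5, 0.3, 0.2],  # predicted 0, wrong
    [0.2, 0.3, 0.5],  # predicted 2, wrong
])
print(acc(labels, probas))  # prints: 0.5
```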
-
data_loader(idx: Sequence[int] | None = None, batch_size: int = 1, shuffle: bool = False) → DataLoader
Construct a data loader for a given data subset.
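The following pure-Python sketch shows what "a data loader for a given data subset" amounts to at the index level. It selects idx, falling back to the whole dataset when idx is None, which is the convention the other methods here document. It then optionally shuffles and cuts the indices into batches of batch_size. The real method returns an instance built from loader_class. The function iter_batches and its seed parameter are hypothetical.

```python
import random


def iter_batches(n_samples, idx=None, batch_size=1, shuffle=False, seed=None):
    """Hypothetical sketch: yield lists of dataset indices, one list per batch."""
    # Assumption: idx=None means "use the whole dataset".
    indices = list(range(n_samples)) if idx is None else list(idx)
    if shuffle:
        random.Random(seed).shuffle(indices)
    for start in range(0, len(indices), batch_size):
        yield indices[start:start + batch_size]


print(list(iter_batches(10, idx=[1, 3, 5, 7, 9], batch_size=2)))
# prints: [[1, 3], [5, 7], [9]]
```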
-
get_latent_features(idx: Sequence[int] | None = None, batch_size: int = 1) → torch.Tensor
Compute the forward pass and collect the latent features.
- Parameters
idx – A sequence of indices specifying a data subset. If None, the whole dataset is used.
batch_size – The batch size to use for the forward pass.
- Returns
A tensor of shape (n_samples, *n_features) that contains the activations of the feature extractor part of the model.
- Return type
torch.Tensor
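The (n_samples, *n_features) shape comes from running the feature extractor batch by batch and concatenating along the sample axis. The following NumPy sketch assumes that pattern. Here feature_extractor is a hypothetical stand-in for the feature-extractor part of the model, and get_latent_features mirrors only the documented parameters.

```python
import numpy as np


def feature_extractor(x):
    """Hypothetical stand-in for the model's feature-extractor part."""
    return np.stack([x.sum(axis=1), x.max(axis=1)], axis=1)


def get_latent_features(data, idx=None, batch_size=1):
    """Sketch: forward pass per batch, then concatenate along the sample axis."""
    indices = np.arange(len(data)) if idx is None else np.asarray(idx)
    chunks = []
    for start in range(0, len(indices), batch_size):
        batch = data[indices[start:start + batch_size]]
        chunks.append(feature_extractor(batch))
    # Result has shape (n_samples, *n_features).
    return np.concatenate(chunks, axis=0)


data = np.arange(12, dtype=float).reshape(6, 2)
features = get_latent_features(data, idx=[0, 2, 4], batch_size=2)
print(features.shape)  # prints: (3, 2)
```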
-
predict(idx: Sequence[int] | None = None, batch_size: int = 1) → tuple[torch.Tensor, torch.Tensor, torch.Tensor]
Run inference on a data subset and get the losses, probabilities, and labels.
- Parameters
idx – An index subset of the dataset. If None, the entire dataset is used.
batch_size – The batch size at which to process data. The results do not depend on the batch size. A larger batch size speeds up inference, but too large a batch size may exhaust memory.
- Returns
losses (torch.Tensor) – The non-reduced loss per sample.
logits (torch.Tensor) – The predicted logits.
labels (torch.Tensor) – All sample labels.
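Two documented properties can be seen concretely. The losses are non-reduced, meaning one value per sample, and the results do not depend on the batch size. The following NumPy sketch assumes a negative log-likelihood loss on log-softmax outputs, although this reference does not name the loss. It computes per-sample losses once on the whole subset and once in batches of 3, and checks that they agree.

```python
import numpy as np


def nll_per_sample(log_probas, labels):
    """Negative log-likelihood for each sample, without averaging."""
    return -log_probas[np.arange(len(labels)), labels]


rng = np.random.default_rng(0)
scores = rng.normal(size=(7, 3))
# Log-softmax over the class axis.
log_probas = scores - np.log(np.exp(scores).sum(axis=1, keepdims=True))
labels = rng.integers(0, 3, size=7)

full = nll_per_sample(log_probas, labels)
# Processing in batches of 3 yields the same per-sample losses.
batched = np.concatenate([
    nll_per_sample(log_probas[s:s + 3], labels[s:s + 3]) for s in range(0, 7, 3)
])
assert np.allclose(full, batched)
mean_loss = full.mean()  # the "reduced" loss, if a single number is needed
```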
-
train(n_epochs: int, batch_size: int, train_idx: Sequence[int], val_idx: Sequence[int] | None = None, load_best: bool = False, progress_bar: Callable[[Iterable], Iterable] = <built-in function iter>) → dict
Train and evaluate on dataset subsets specified by indices.
- Parameters
train_idx (array of int) – The indices of the training samples.
val_idx (array of int) – The indices of the evaluation samples.
batch_size (int) – The batch size.
n_epochs (int) – The number of epochs.
load_best (bool) – If true, the model with the best validation accuracy is restored at the end of training. Only possible if val_idx is not None.
save_latent_features (bool) – If true, the latent features of the feature extractor part of the model are saved under history[“latent_features”].
progress_bar (callable) – A callable that wraps an iterable over the epoch numbers and creates a progress bar.
- Returns
history – Dictionary with predictions, probabilities, training and validation losses and accuracies.
- Return type
dict
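The following minimal sketch shows how load_best and progress_bar could interact in an epoch loop. The names train_step, evaluate, and the dict-based state are hypothetical, and the history keys here are illustrative rather than the ones the real method returns. In PyTorch, the snapshot and restore would typically use state_dict() and load_state_dict(). The default progress_bar=iter mirrors the documented default. A wrapper such as tqdm.tqdm can be passed instead if tqdm is installed.

```python
import copy


def train(n_epochs, train_step, evaluate, state, load_best=False, progress_bar=iter):
    """Hypothetical sketch of an epoch loop with best-model restoring.

    train_step(state) runs one epoch and returns the training loss;
    evaluate(state) returns the validation accuracy. Both are assumptions.
    """
    history = {"train_loss": [], "val_acc": []}
    best_acc, best_state = float("-inf"), None
    for _ in progress_bar(range(n_epochs)):
        history["train_loss"].append(train_step(state))
        val_acc = evaluate(state)
        history["val_acc"].append(val_acc)
        if val_acc > best_acc:
            # Keep a snapshot of the best model seen so far.
            best_acc, best_state = val_acc, copy.deepcopy(state)
    if load_best and best_state is not None:
        state.clear()
        state.update(best_state)
    return history


state = {"w": 0}  # toy "model" with a single parameter


def train_step(s):
    s["w"] += 1
    return 1.0 / s["w"]


def evaluate(s):
    return {1: 0.5, 2: 0.9, 3: 0.7}[s["w"]]  # toy accuracies, best at w=2


history = train(3, train_step, evaluate, state, load_best=True)
print(history["val_acc"], state)  # prints: [0.5, 0.9, 0.7] {'w': 2}
```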