morphoclass.models.cnnet module
Classification of persistence images using a CNN.
This module includes a convolutional model and a corresponding trainer class.
class morphoclass.models.cnnet.CNNEmbedder
Bases: torch.nn.modules.module.Module
The embedder part of the CNNet classifier.
forward(data)
Do the forward pass.
- Parameters
data (torch_geometric.data.Batch | torch.Tensor) – A batch of samples from the MorphologyDataset dataset.
- Returns
x – The embedding of the persistence image, a tensor of shape (n_batch, 3 * (image_size // 4)**2).
- Return type
torch.Tensor
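The embedding size in the return value can be checked with plain shape arithmetic. The sketch below assumes, hypothetically, that convolutions keep the spatial size, that the resolution is halved twice by 2x2 pooling with floor rounding, and that the last convolution has 3 output channels. Only the resulting formula is documented above; the layer layout and the helper names are illustrative.

```python
def pooled_size(size: int, n_pools: int = 2) -> int:
    """Spatial size after n_pools 2x2 poolings with stride 2 (floor rounding)."""
    for _ in range(n_pools):
        size //= 2
    return size


def embedding_dim(image_size: int, out_channels: int = 3) -> int:
    """Length of the flattened embedding: channels * height * width.

    Assumption: convolutions preserve the spatial size ("same" padding)
    and only the two pooling steps shrink it.
    """
    side = pooled_size(image_size)
    return out_channels * side * side


for size in (100, 64, 30):
    # Two floor halvings agree with the documented image_size // 4.
    assert embedding_dim(size) == 3 * (size // 4) ** 2
    print(size, embedding_dim(size))
```

With the default image_size=100 this gives 3 * 25 * 25 = 1875 features per image.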
training: bool
class morphoclass.models.cnnet.CNNet(n_classes, image_size=100, bn=False)
Bases: torch.nn.modules.module.Module
Convolutional net for classifying persistence images.
The provided persistence images should be square greyscale images. They can be obtained from persistence diagrams by applying Gaussian kernel density estimation (KDE).
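As a rough illustration of the Gaussian KDE step, the sketch below rasterizes a persistence diagram (a list of (birth, death) pairs) onto a square grid. The function name persistence_image, the grid bounds lo/hi, the bandwidth sigma and the scaling to [0, 1] are all hypothetical choices, not the library's own preprocessing. Because every kernel shares the same normalization constant, that constant cancels when the image is scaled by its peak, so it is omitted.

```python
import math


def persistence_image(diagram, image_size=100, sigma=0.1, lo=0.0, hi=1.0):
    """Rasterize a persistence diagram into a square greyscale image.

    diagram: iterable of (birth, death) pairs.
    Each pixel holds the sum of isotropic Gaussian kernels centred on the
    diagram points, evaluated at the pixel centre; the result is scaled
    to [0, 1]. Grid bounds, sigma and the scaling are assumptions.
    """
    diagram = list(diagram)
    step = (hi - lo) / image_size
    centres = [lo + (i + 0.5) * step for i in range(image_size)]
    image = [[0.0] * image_size for _ in range(image_size)]
    for row, y in enumerate(centres):      # rows follow the death axis
        for col, x in enumerate(centres):  # columns follow the birth axis
            image[row][col] = sum(
                math.exp(-((x - b) ** 2 + (y - d) ** 2) / (2 * sigma**2))
                for b, d in diagram
            )
    peak = max(max(r) for r in image)
    if peak > 0:
        image = [[v / peak for v in r] for r in image]
    return image


diagram = [(0.1, 0.8), (0.3, 0.5)]
img = persistence_image(diagram, image_size=20)
print(len(img), len(img[0]))  # 20 20
```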
- Parameters
n_classes (int) – The number of classes.
image_size (int) – The width or height of the input images. The images are assumed to be square.
bn (bool) – If True, 1D batch normalization and a ReLU are applied to the flattened embeddings before the fully connected (FC) layer.
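To make the bn option concrete, here is a plain-Python sketch of 1D batch normalization over a batch of flattened embeddings, followed by a ReLU. It assumes training-mode batch statistics, omits the learnable scale and shift, and applies the normalization before the ReLU; the helper names are hypothetical and this is not the module's actual code.

```python
import math


def batch_norm_1d(batch, eps=1e-5):
    """Normalize each feature over the batch (training-mode statistics).

    batch: list of n_batch rows, each a list of n_features floats.
    Learnable scale/shift are omitted (equivalent to gamma=1, beta=0).
    """
    n = len(batch)
    n_features = len(batch[0])
    out = [row[:] for row in batch]
    for j in range(n_features):
        column = [row[j] for row in batch]
        mean = sum(column) / n
        var = sum((v - mean) ** 2 for v in column) / n  # biased variance
        for i in range(n):
            out[i][j] = (batch[i][j] - mean) / math.sqrt(var + eps)
    return out


def relu(batch):
    """Element-wise max(0, v)."""
    return [[max(0.0, v) for v in row] for row in batch]


# Three hypothetical flattened embeddings with three features each.
embeddings = [[1.0, 2.0, 0.5], [3.0, 2.0, -0.5], [5.0, 2.0, 1.5]]
normalized = batch_norm_1d(embeddings)
activated = relu(normalized)
print(activated)
```

Normalizing per feature across the batch keeps the FC layer's inputs on a comparable scale regardless of how bright individual persistence images are.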
forward(data)
Do the forward pass.
- Parameters
data (torch_geometric.data.Batch) – A batch of samples from the MorphologyDataset dataset.
- Returns
logits – The predicted logits, of shape (n_images, n_classes).
- Return type
torch.Tensor
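The logits are unnormalized class scores. A common way to read them, assumed here rather than documented above, is a softmax for class probabilities and an argmax for the predicted class. The sketch works on nested lists of shape (n_images, n_classes); the helper names are hypothetical.

```python
import math


def softmax(row):
    """Numerically stable softmax of one row of logits."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def predict(logits):
    """Return (probabilities, predicted class index) for each image."""
    probs = [softmax(row) for row in logits]
    preds = [max(range(len(p)), key=p.__getitem__) for p in probs]
    return probs, preds


# Hypothetical logits for n_images=2, n_classes=3.
logits = [[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]]
probs, preds = predict(logits)
print(preds)  # [0, 2]
```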
loss_acc(data)
Compute the loss and accuracy.
- Parameters
data (torch_geometric.data.Batch) – A batch of samples from the MorphologyDataset dataset.
- Returns
loss (float) – The loss value.
acc (float) – The accuracy value.
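A plain-Python sketch of what a loss-and-accuracy helper of this kind typically computes, assuming cross-entropy loss over the logits and accuracy as the fraction of argmax predictions that match the labels. The loss function is not named in this docstring, so cross-entropy is an assumption; the name loss_acc_sketch and the list-based inputs are hypothetical, whereas the documented method takes a torch_geometric Batch.

```python
import math


def loss_acc_sketch(logits, labels):
    """Mean cross-entropy loss and accuracy for a batch.

    logits: n_images rows of n_classes floats; labels: n_images ints.
    Cross-entropy is an assumption; the documented loss is unspecified.
    """
    total_loss = 0.0
    n_correct = 0
    for row, label in zip(logits, labels):
        m = max(row)
        # log-sum-exp with the max subtracted for numerical stability
        log_norm = m + math.log(sum(math.exp(v - m) for v in row))
        total_loss += log_norm - row[label]  # equals -log(softmax(row)[label])
        pred = max(range(len(row)), key=row.__getitem__)
        n_correct += int(pred == label)
    n = len(labels)
    return total_loss / n, n_correct / n


loss, acc = loss_acc_sketch([[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]], [0, 1])
print(round(loss, 3), acc)
```

Here the first image is classified correctly and the second is not, so the accuracy is 0.5 while the loss is dominated by the misclassified second image.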
training: bool