morphoclass.models.cnnet module

Classification of persistence images using a CNN.

This module includes a convolutional model and a corresponding trainer class.

class morphoclass.models.cnnet.CNNEmbedder

Bases: torch.nn.modules.module.Module

The embedder part of the CNNet classifier.

forward(data)

Do the forward pass.

Parameters

data (torch_geometric.data.Batch | torch.Tensor) – A batch from the MorphologyDataset dataset.

Returns

x – The embeddings of the persistence images, a tensor of shape (n_batch, 3 * (image_size // 4)**2).

Return type

torch.Tensor
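The embedding size in the formula above can be checked with integer arithmetic. The sketch below assumes, without confirmation from this reference, that the factor 4 comes from two 2×2 down-sampling steps (each flooring the side length by half) and the factor 3 from three output channels. `embedding_dim` is a hypothetical helper, not part of morphoclass.

```python
def embedding_dim(image_size: int, n_channels: int = 3, n_pools: int = 2) -> int:
    """Flattened embedding size after `n_pools` stride-2 poolings (hypothetical helper).

    Assumption: the spatial side is halved (with flooring) by each pooling step
    and the last feature map has `n_channels` channels.
    """
    side = image_size
    for _ in range(n_pools):
        side //= 2  # a 2x2 pooling with stride 2 maps a side of n to n // 2
    return n_channels * side * side


# With the default image_size=100: 100 -> 50 -> 25, so 3 * 25**2 = 1875.
print(embedding_dim(100))  # 1875
```

Flooring twice by 2 is the same as flooring once by 4, which is why the documented shape can be written as 3 * (image_size // 4)**2.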

class morphoclass.models.cnnet.CNNet(n_classes, image_size=100, bn=False)

Bases: torch.nn.modules.module.Module

Convolutional net for classifying persistence images.

The input persistence images should be square greyscale images; they can be obtained from persistence diagrams by applying Gaussian kernel density estimation (KDE).
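The rasterization step can be illustrated with a minimal, framework-free sketch. It assumes that each (birth, death) pair contributes an isotropic Gaussian kernel, that the grid covers [0, 1]^2, and that the result is scaled so the brightest pixel equals 1. The bandwidth `sigma`, the coordinate range, and the normalization are illustrative choices; the actual morphoclass preprocessing may use different coordinates, weighting, or bandwidth. `persistence_image` is a hypothetical helper.

```python
import math


def persistence_image(diagram, image_size=100, sigma=0.05, lo=0.0, hi=1.0):
    """Rasterize (birth, death) pairs into a square greyscale image (sketch).

    Each pair contributes an isotropic Gaussian kernel of width `sigma`. The
    kernel sum is sampled at the pixel centres of an image_size x image_size
    grid covering [lo, hi]^2 and then scaled so that the brightest pixel is 1.
    """
    step = (hi - lo) / image_size
    centres = [lo + (i + 0.5) * step for i in range(image_size)]
    image = [[0.0] * image_size for _ in range(image_size)]
    for birth, death in diagram:
        for row, y in enumerate(centres):      # rows follow the death axis
            for col, x in enumerate(centres):  # columns follow the birth axis
                dist2 = (x - birth) ** 2 + (y - death) ** 2
                image[row][col] += math.exp(-dist2 / (2 * sigma**2))
    peak = max(max(r) for r in image)
    if peak > 0:  # an empty diagram yields an all-zero image
        image = [[v / peak for v in r] for r in image]
    return image


img = persistence_image([(0.2, 0.6), (0.5, 0.9)], image_size=20)
print(len(img), len(img[0]))  # 20 20
```

The output is square by construction, matching the requirement that the model's input images have equal width and height.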

Parameters
  • n_classes (int) – The number of classes.

  • image_size (int) – The side length (width and height) of the input images, which are assumed to be square.

  • bn (bool) – If True, 1D batch normalization followed by a ReLU is applied to the flattened embeddings before the fully connected (FC) layer.
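What the `bn` option adds can be sketched without a deep learning framework: 1D batch normalization standardizes each embedding feature across the batch, and the ReLU then zeroes negative values. This sketch assumes training-mode batch statistics (biased variance) and omits the learnable scale and shift and the running statistics used at evaluation time, so it only illustrates the technique and is not the model's actual layer. `batch_norm_relu` is a hypothetical helper.

```python
import math


def batch_norm_relu(batch, eps=1e-5):
    """Normalize each feature over the batch, then apply ReLU (sketch).

    `batch` is a list of equal-length flattened embeddings (n_batch x n_features).
    Uses the biased batch variance and no learnable scale/shift (gamma=1, beta=0).
    """
    n = len(batch)
    n_features = len(batch[0])
    means = [sum(row[j] for row in batch) / n for j in range(n_features)]
    variances = [
        sum((row[j] - means[j]) ** 2 for row in batch) / n for j in range(n_features)
    ]
    return [
        [
            max(0.0, (row[j] - means[j]) / math.sqrt(variances[j] + eps))  # ReLU
            for j in range(n_features)
        ]
        for row in batch
    ]


out = batch_norm_relu([[1.0, 2.0], [3.0, 2.0]])
# Feature 0 is standardized to about -1 and +1, and the ReLU clips the -1 to 0;
# feature 1 is constant across the batch, so it normalizes to 0.
```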

forward(data)

Do the forward pass.

Parameters

data (torch_geometric.data.Batch) – A batch from the MorphologyDataset dataset.

Returns

logits – The predicted logits, of shape (n_images, n_classes).

Return type

torch.Tensor
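Since `forward` returns raw logits of shape (n_images, n_classes), class probabilities follow from a softmax over each row and the predicted label from an argmax. The sketch below uses plain Python lists in place of tensors; `softmax` and `predict` are hypothetical helpers for illustration, not part of morphoclass.

```python
import math


def softmax(logits):
    """Convert one row of logits into class probabilities (numerically stable)."""
    m = max(logits)  # subtracting the max avoids overflow in exp
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predict(batch_logits):
    """Return the index of the largest logit for each image in the batch."""
    return [max(range(len(row)), key=row.__getitem__) for row in batch_logits]


logits = [[2.0, 0.5, -1.0], [0.1, 0.3, 3.0]]  # 2 images, 3 classes
print(predict(logits))  # [0, 2]
```

The argmax alone suffices for prediction, because softmax preserves the ordering of the logits.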

loss_acc(data)

Get loss and accuracy.

Parameters

data (torch_geometric.data.Batch) – A batch from the MorphologyDataset dataset.

Returns

  • loss (float) – The loss value.

  • acc (float) – The accuracy value.
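This reference does not state which loss `loss_acc` uses. As an illustration only, the sketch below assumes a mean cross-entropy loss over the batch, with accuracy defined as the fraction of images whose largest logit matches the true label. `loss_and_accuracy` is a hypothetical stand-in that operates on plain lists; it is not the method itself.

```python
import math


def loss_and_accuracy(batch_logits, labels):
    """Mean cross-entropy loss and accuracy for a batch (hypothetical sketch)."""
    total_loss = 0.0
    n_correct = 0
    for row, label in zip(batch_logits, labels):
        m = max(row)
        log_sum_exp = m + math.log(sum(math.exp(z - m) for z in row))
        total_loss += log_sum_exp - row[label]  # equals -log(softmax(row)[label])
        predicted = max(range(len(row)), key=row.__getitem__)
        n_correct += int(predicted == label)
    n = len(labels)
    return total_loss / n, n_correct / n


loss, acc = loss_and_accuracy([[2.0, 0.0], [0.0, 0.0]], [0, 1])
print(acc)  # 0.5
```

Both values are plain floats, consistent with the documented return types.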
