morphoclass.models package

Module contents

Various models for morphology type classification of neurons.

class morphoclass.models.BidirectionalNet(num_classes, num_nodes_features)

Bases: torch.nn.modules.module.Module

Model for classifying morphologies of pyramidal neurons.

This is the architecture that performed best in the TensorFlow implementation. It consists of two graph convolution layers, each computing two convolutions: one on the directed adjacency matrix and one on the adjacency matrix with the direction reversed. The results of both convolutions are concatenated and passed to the next layer. The two bidirectional graph convolution layers are followed by a global average pooling layer and a fully-connected layer. Finally, a softmax layer is used for prediction.
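The data flow described above can be sketched in plain Python on a toy graph. This is only an illustration, not the morphoclass implementation: the helper names, the hand-picked weights, the identity fully-connected layer, and the absence of activations are all assumptions made to keep the example small.

```python
import math


def matmul(a, b):
    """Multiply two matrices given as nested lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(a):
    """Transpose a matrix given as nested lists."""
    return [list(col) for col in zip(*a)]


def bidirectional_conv(adj, x, w_fwd, w_rev):
    """Convolve along and against the edge direction, then concatenate."""
    fwd = matmul(matmul(adj, x), w_fwd)             # directed adjacency
    rev = matmul(matmul(transpose(adj), x), w_rev)  # reversed direction
    return [f + r for f, r in zip(fwd, rev)]        # feature-wise concatenation


def global_avg_pool(x):
    """Average the node features over all nodes of the graph."""
    return [sum(col) / len(x) for col in zip(*x)]


def log_softmax(z):
    """Numerically stable log softmax of a list of scores."""
    peak = max(z)
    lse = peak + math.log(sum(math.exp(v - peak) for v in z))
    return [v - lse for v in z]


# Toy tree with edges 0 -> 1 and 0 -> 2 and one feature per node.
adj = [[0, 1, 1], [0, 0, 0], [0, 0, 0]]
x = [[1.0], [2.0], [3.0]]

h1 = bidirectional_conv(adj, x, [[1.0]], [[1.0]])               # 1 -> 2 features
h2 = bidirectional_conv(adj, h1, [[1.0], [1.0]], [[1.0], [1.0]])  # 2 -> 2 features
pooled = global_avg_pool(h2)
w_fc = [[1.0, 0.0], [0.0, 1.0]]  # assumed identity fully-connected layer
log_probs = log_softmax(matmul([pooled], w_fc)[0])
```

Each layer doubles the number of output columns because the forward and reversed results are concatenated rather than summed.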

Parameters
  • num_classes (int) – The number of output classes.

  • num_nodes_features (int) – The number of input node features.

accuracy(data)

Run the forward pass and compute the accuracy.

Parameters

data (torch_geometric.data.Data) – The input batch of data.

Returns

acc – The accuracy on the current data batch.

Return type

float

forward(data)

Compute the forward pass.

Parameters

data (torch_geometric.data.data.Data) – A batch of input data.

Returns

The log softmax of the predictions.

Return type

log_softmax

loss_acc(data)

Run the forward pass and compute the loss and accuracy.

Parameters

data (torch_geometric.data.Data) – The input batch of data.

Returns

  • loss (float) – The loss on the given data batch.

  • acc (float) – The accuracy on the current data batch.
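As a rough illustration of what a loss/accuracy pair computed from log-softmax outputs looks like, the sketch below uses the negative log-likelihood. That choice of loss is an assumption (it is the usual partner of a log-softmax output), and the function name is hypothetical; the source does not state which loss the model uses.

```python
import math


def nll_loss_and_accuracy(log_probs, targets):
    """Mean negative log-likelihood and accuracy for a batch.

    log_probs: one list of per-class log probabilities per sample.
    targets: the integer class label of each sample.
    """
    n = len(targets)
    # Assumed loss: negative log probability of the true class, averaged.
    loss = -sum(row[t] for row, t in zip(log_probs, targets)) / n
    # The prediction is the class with the highest log probability.
    preds = [max(range(len(row)), key=row.__getitem__) for row in log_probs]
    acc = sum(p == t for p, t in zip(preds, targets)) / n
    return loss, acc


batch = [
    [math.log(0.75), math.log(0.25)],
    [math.log(0.25), math.log(0.75)],
]
loss, acc = nll_loss_and_accuracy(batch, [0, 0])
```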

training: bool
class morphoclass.models.CNNEmbedder

Bases: torch.nn.modules.module.Module

The embedder part of the CNNet classifier.

forward(data)

Do the forward pass.

Parameters

data (torch_geometric.data.Batch | torch.Tensor) – A batch of the MorphologyDataset dataset.

Returns

x – The embedding of the persistence image. The dimension of the tensor is (n_batch, 3 * (image_size // 4)**2).

Return type

torch.Tensor
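The second dimension of the embedding follows directly from the formula above. The helper below only evaluates that formula; its name is hypothetical, and the source does not say where the factor 3 or the division by 4 come from.

```python
def cnn_embedding_size(image_size):
    """Evaluate 3 * (image_size // 4) ** 2 for a square input image."""
    return 3 * (image_size // 4) ** 2


size_for_default_images = cnn_embedding_size(100)
```

Because of the integer division, image sizes that are not multiples of 4 give the same embedding size as the next lower multiple of 4.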

training: bool
class morphoclass.models.CNNet(n_classes, image_size=100, bn=False)

Bases: torch.nn.modules.module.Module

Convolutional net for classifying persistence images.

The provided persistence images should be square greyscale images and can be obtained from persistence diagrams by applying Gaussian KDE.
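To make the idea of turning a persistence diagram into a square greyscale image concrete, here is a minimal sketch that sums one Gaussian kernel per diagram point on a pixel grid. The function name, the bandwidth, the value range, and the scaling to a peak of 1 are all assumptions; the preprocessing actually used with this model is not described here.

```python
import math


def persistence_image(diagram, image_size, bandwidth, lo, hi):
    """Evaluate a sum of Gaussian kernels on a square grid of pixel centres.

    diagram: list of (birth, death) pairs.
    lo, hi: value range covered by both image axes (assumed shared).
    """
    step = (hi - lo) / image_size
    centers = [lo + (i + 0.5) * step for i in range(image_size)]
    image = []
    for y in centers:        # rows follow the death axis
        row = []
        for x in centers:    # columns follow the birth axis
            row.append(sum(
                math.exp(-((x - b) ** 2 + (y - d) ** 2) / (2 * bandwidth ** 2))
                for b, d in diagram
            ))
        image.append(row)
    peak = max(max(row) for row in image)
    if peak == 0:
        return image  # empty diagram: all-zero image
    # Scale so that the brightest pixel has value 1 (an assumed convention).
    return [[v / peak for v in row] for row in image]


image = persistence_image([(0.375, 0.625)], image_size=4, bandwidth=0.2, lo=0.0, hi=1.0)
```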

Parameters
  • n_classes (int) – The number of classes.

  • image_size (int) – The width or height of the input images. The images are assumed to be square.

  • bn (bool) – If true, 1D batch normalization and a ReLU are applied to the flattened embeddings before the fully-connected layer.

forward(data)

Do the forward pass.

Parameters

data (torch_geometric.data.Batch) – A batch from the MorphologyDataset dataset.

Returns

logits – The predicted logits of shape (n_images, n_classes).

Return type

torch.Tensor

loss_acc(data)

Get loss and accuracy.

Parameters

data (torch_geometric.data.Batch) – A batch from the MorphologyDataset dataset.

Returns

  • loss (float) – The loss value.

  • acc (float) – The accuracy value.

training: bool
class morphoclass.models.ConcateCNNet(n_node_features, n_classes, image_size, bn=False)

Bases: torch.nn.modules.module.Module

A neuron m-type classifier based on graph and image convolutions.

In the feature extraction part of the network, graph convolution layers are applied to the graph node features of the apical dendrites, while the CNN layers are applied to the persistence image representation of the same data. The resulting features are concatenated and passed through a fully-connected layer for classification.

Parameters
  • n_node_features (int) – The number of input node features for the GNN layers.

  • n_classes (int) – The number of output classes.

  • image_size (int) – The width (or height) of the input persistence images. It is assumed that the images are square so that the width and height are equal.

  • bn (bool, default False) – Whether or not to include a batch normalization layer between the feature extractor and the fully-connected classification layer.

forward(data, images)

Compute the forward pass.

Parameters
  • data (torch_geometric.data.data.Data) – A batch of input graph data for the GNN layers.

  • images – A batch of input persistence images for the CNN layers.

Returns

The log softmax of the predictions.

Return type

log_softmax

training: bool
class morphoclass.models.ConcateNet(n_node_features, n_classes, n_features_perslay, bn=False)

Bases: torch.nn.modules.module.Module

A neuron m-type classifier based on graph convolutions and PersLay.

In the feature extraction part of the network, graph convolution layers are applied to the graph node features of the apical dendrites, while the PersLay layer is applied to the persistence diagram representation of the same data. The resulting features are concatenated and passed through a fully-connected layer for classification.

Parameters
  • n_node_features (int) – The number of input node features for the GNN layers.

  • n_classes (int) – The number of output classes.

  • n_features_perslay (int) – The number of features for the PersLay layer.

  • bn (bool, default False) – Whether or not to include a batch normalization layer between the feature extractor and the fully-connected classification layer.

forward(data, diagrams, point_index)

Compute the forward pass.

Parameters
  • data (torch_geometric.data.data.Data) – A batch of input graph data for the GNN layers.

  • diagrams – A batch of input persistence diagrams for the PersLay layer.

  • point_index (torch.Tensor) – A one-dimensional integer tensor holding the segmentation map for samples in the batched data, e.g. tensor([0, 0, 1, 1, 1, 2, …]).
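A segmentation map such as tensor([0, 0, 1, 1, 1, 2, …]) assigns every point of the batched diagrams to the sample it belongs to. The sketch below shows how such a map can be used to reduce per-point values to one value per sample. It uses plain lists instead of tensors, and the function name is hypothetical; how ConcateNet consumes point_index internally is not shown here.

```python
def segment_sum(values, point_index):
    """Sum the values of all points that share the same sample index."""
    n_samples = max(point_index) + 1 if point_index else 0
    sums = [0.0] * n_samples
    for value, sample in zip(values, point_index):
        sums[sample] += value
    return sums


# Six diagram points belonging to three samples (2 + 3 + 1 points).
per_sample = segment_sum([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [0, 0, 1, 1, 1, 2])
```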

Returns

The log softmax of the predictions.

Return type

log_softmax

training: bool
class morphoclass.models.CorianderNet(n_classes=4, n_features=64, dropout=False)

Bases: torch.nn.modules.module.Module

A PersLay-based neural network for neuron m-type classification.

Parameters
  • n_classes (int) – The number of m-type classes to predict.

  • n_features (int) – The number of output feature maps for the PersLay layer.

  • dropout (bool, default False) – If true, a dropout layer is inserted between the two fully-connected layers of the classifier part of the network.

forward(data)

Compute the forward pass.

Parameters

data (torch_geometric.data.Batch | torch.Tensor) – A batch from the MorphologyDataset dataset.

Returns

The log softmax of the predictions.

Return type

log_softmax

loss_acc(data)

Get loss and accuracy.

Parameters

data (torch_geometric.data.Batch) – A batch from the MorphologyDataset dataset.

Returns

  • loss (float) – The loss value.

  • acc (float) – The accuracy value.

training: bool
class morphoclass.models.HBNet(num_features, class_segmentation_mask)

Bases: torch.nn.modules.module.Module

Hierarchical, Bidirectional Net.

Parameters
  • num_features – the number of input features

  • class_segmentation_mask – mask for hierarchical labels where each softmax block is marked by a different integer. The class mask can be generated by the method get_total_class_mask() of instances of the model_utils.HierarchicalLabels class.
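To illustrate what a mask with one integer per softmax block means, the sketch below applies a separate log softmax to each group of outputs sharing the same mask value. The mask and logits are made up for the example, and the function name is hypothetical; a real mask would come from get_total_class_mask().

```python
import math


def blockwise_log_softmax(logits, mask):
    """Apply an independent log softmax to each block marked in mask."""
    out = [0.0] * len(logits)
    for block in set(mask):
        idx = [i for i, m in enumerate(mask) if m == block]
        peak = max(logits[i] for i in idx)
        lse = peak + math.log(sum(math.exp(logits[i] - peak) for i in idx))
        for i in idx:
            out[i] = logits[i] - lse
    return out


# Two top-level classes (block 0) and three subclasses (block 1).
mask = [0, 0, 1, 1, 1]
log_probs = blockwise_log_softmax([0.0, 0.0, 1.0, 1.0, 1.0], mask)
probs = [math.exp(v) for v in log_probs]
```

The probabilities sum to 1 within each block, not over the whole output vector.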

compute_hierarchical_metric(metric_function, data, hl)

Compute metric for each layer in the label hierarchy.

Parameters
  • metric_function – A function computing the required metric. It should have the signature metric_function(targets, predictions).

  • data – A batch of samples.

  • hl (model_utils.HierarchicalLabels) – Hierarchy structure.

Returns

List of metric evaluations for each layer in the label hierarchy.

Return type

results
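A function with the signature metric_function(targets, predictions) could look like the accuracy below. The second helper shows, under the assumption that targets and predictions are already split by hierarchy level, how one value per level would be collected. Both names are hypothetical and do not reproduce the method's internals.

```python
def accuracy(targets, predictions):
    """Fraction of samples whose prediction equals the target."""
    return sum(t == p for t, p in zip(targets, predictions)) / len(targets)


def metric_per_level(metric_function, targets_by_level, predictions_by_level):
    """Evaluate metric_function once per level of the label hierarchy."""
    return [
        metric_function(targets, predictions)
        for targets, predictions in zip(targets_by_level, predictions_by_level)
    ]


results = metric_per_level(
    accuracy,
    targets_by_level=[[0, 1, 1], [2, 3, 4]],
    predictions_by_level=[[0, 1, 0], [2, 3, 3]],
)
```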

forward(data)

Compute the forward pass.

Parameters

data – The input data.

Returns

The log softmax of the predictions.

Return type

log_softmax

gen_hierarchical_probabilities(probabilities, parent_mask, hl)

Descend the hierarchy layer-wise and compute probabilities.

Starting with the parent nodes specified in parent_mask compute the probabilities for their children using

P(child) = P(child|parent) * P(parent),

create a new mask containing all children, and recurse.

Parameters
  • probabilities – Probabilities for all nodes in the hierarchy tree.

  • parent_mask – Mask for the nodes from which to start descending the hierarchy. On the first call, parent_mask should contain all root nodes.

  • hl (model_utils.HierarchicalLabels) – Hierarchy structure for labels.

Yields
  • mask – Mask for all nodes considered for the current layer in the hierarchy.

  • probabilities – Probabilities predicted for the nodes specified by mask.
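The layer-wise descent with P(child) = P(child|parent) * P(parent) can be sketched with dictionaries instead of masks and tensors. The function name and the dictionary-based tree are assumptions made for readability; the method itself works on masks and a HierarchicalLabels instance.

```python
def descend_hierarchy(conditional, children, parents, absolute):
    """Yield (nodes, probabilities) for each layer below the given parents.

    conditional: node -> P(node | parent), as predicted per softmax block.
    children: node -> list of child nodes.
    parents: nodes of the current layer.
    absolute: node -> P(node), filled in for the parents on entry and
        extended in place as the hierarchy is descended.
    """
    parent_of = {c: p for p in parents for c in children.get(p, [])}
    nodes = list(parent_of)
    if not nodes:
        return  # reached the leaves
    for child in nodes:
        absolute[child] = conditional[child] * absolute[parent_of[child]]
    yield nodes, [absolute[c] for c in nodes]
    # Recurse with the children as the new parent layer.
    yield from descend_hierarchy(conditional, children, nodes, absolute)


roots = ["A", "B"]
conditional = {"A": 0.6, "B": 0.4, "A1": 0.5, "A2": 0.5, "B1": 1.0, "A1x": 1.0}
children = {"A": ["A1", "A2"], "B": ["B1"], "A1": ["A1x"]}
absolute = {root: conditional[root] for root in roots}
levels = list(descend_hierarchy(conditional, children, roots, absolute))
```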

hierarchical_accuracies(data, hl)

Accuracies for each layer in the label hierarchy.

Parameters
  • data – A batch of samples.

  • hl (model_utils.HierarchicalLabels) – Hierarchy structure.

Returns

List of accuracies for each layer in the label hierarchy.

Return type

accuracies

loss(data)

Compute the loss.

Parameters

data – The input data.

Returns

The loss.

precision_recall_f1(data, hl, *, average='micro')

Compute hierarchical precision, recall, F1 score.

Parameters
  • data – A batch of samples.

  • hl (model_utils.HierarchicalLabels) – Hierarchy structure.

  • average – The type of multi-class average to take, accepted values are None, “micro”, and “macro”.

Returns

  • p_h – The hierarchical precision.

  • r_h – The hierarchical recall.

  • f1_h – The hierarchical F1 score.
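One common way to define hierarchical precision and recall is to compare, per sample, the set of true nodes with the set of predicted nodes, each including all ancestors. The micro-averaged sketch below follows that definition. Whether this method uses exactly this formulation is an assumption, and the function name is hypothetical.

```python
def hierarchical_prf_micro(true_sets, pred_sets):
    """Micro-averaged hierarchical precision, recall and F1.

    true_sets, pred_sets: one set of hierarchy node indices per sample,
    each set containing the labelled node and all of its ancestors.
    """
    overlap = sum(len(t & p) for t, p in zip(true_sets, pred_sets))
    n_pred = sum(len(p) for p in pred_sets)
    n_true = sum(len(t) for t in true_sets)
    p_h = overlap / n_pred if n_pred else 0.0
    r_h = overlap / n_true if n_true else 0.0
    f1_h = 2 * p_h * r_h / (p_h + r_h) if p_h + r_h else 0.0
    return p_h, r_h, f1_h


# Sample 1 gets the top level right but the second level wrong;
# sample 2 is fully correct.
p_h, r_h, f1_h = hierarchical_prf_micro(
    true_sets=[{0, 3}, {1, 4}],
    pred_sets=[{0, 2}, {1, 4}],
)
```

A partially correct prediction still earns credit for the ancestors it shares with the true label, which is the point of the hierarchical variants.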

predict(data)

Make hierarchical predictions for given data.

Parameters

data – A batch of samples.

Returns

  • val_max – Probabilities for the predicted nodes in the hierarchy tree.

  • val_argmax – Indices for the predicted nodes in the hierarchy tree. For example, hierarchical one-hot prediction [1, 0, 0, 1, 0] would correspond to val_argmax = [0, 3].
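The correspondence between a hierarchical one-hot vector and the index form of val_argmax can be written out as a pair of small helpers. Both function names are hypothetical and operate on plain lists rather than tensors.

```python
def one_hot_to_indices(one_hot):
    """Return the positions of the active nodes in a hierarchical one-hot."""
    return [i for i, value in enumerate(one_hot) if value]


def indices_to_one_hot(indices, n_nodes):
    """Inverse of one_hot_to_indices for a hierarchy with n_nodes nodes."""
    chosen = set(indices)
    return [1 if i in chosen else 0 for i in range(n_nodes)]


val_argmax = one_hot_to_indices([1, 0, 0, 1, 0])
```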

predict_probabilities(data)

Compute the prediction probabilities.

Parameters

data – The input data.

Returns

The prediction probabilities.

training: bool
class morphoclass.models.ManEmbedder(n_features=1, pool_name='avg', lambda_max=3.0, normalization='sym', flow='target_to_source', edge_weight_idx=None)

Bases: torch.nn.modules.module.Module

The embedder for the ManNet network.

The embedder consists of two bidirectional ChebConv blocks followed by a global pooling layer.

Parameters
  • n_features (int) – The number of input features.

  • pool_name ({"avg", "sum", "att"}) – The type of pooling layer to use:
      - “avg”: global average pooling
      - “sum”: global sum pooling
      - “att”: global attention pooling (trainable)

  • lambda_max (float or list of float or None) – Originally the highest eigenvalue(s) of the adjacency matrix. In ChebConvs this value is usually computed from the adjacency matrix directly and used for normalization. This however doesn’t work for non-symmetric matrices and we fix a constant value instead of computing it. Experiments show that there is no impact on performance.

  • normalization ({None, "sym", "rw"}) – The normalization type of the graph Laplacian to use in the ChebConvs. Possible values:
      - None: no normalization
      - “sym”: symmetric normalization
      - “rw”: random walk normalization

  • flow ({"target_to_source", "source_to_target"}) – The message passing flow direction in ChebConvs for directed graphs.

  • edge_weight_idx (int or None) – The index of the edge feature tensor (data.edge_attr) to use as edge weights.
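The role of lambda_max and of the "sym" normalization can be illustrated with the standard Chebyshev construction: the Laplacian is rescaled to 2 L / lambda_max - I and then used in the recurrence T_k = 2 L_hat T_{k-1} - T_{k-2}. The sketch below works on a tiny undirected graph with plain lists. The helper names are hypothetical, and the code is not taken from morphoclass or pytorch-geometric.

```python
def sym_laplacian(adj):
    """Symmetrically normalized Laplacian I - D^(-1/2) A D^(-1/2)."""
    n = len(adj)
    inv_sqrt = [sum(row) ** -0.5 if sum(row) > 0 else 0.0 for row in adj]
    return [
        [(1.0 if i == j else 0.0) - inv_sqrt[i] * adj[i][j] * inv_sqrt[j]
         for j in range(n)]
        for i in range(n)
    ]


def rescale(lap, lambda_max=3.0):
    """Rescaled Laplacian 2 L / lambda_max - I with a fixed lambda_max."""
    n = len(lap)
    return [
        [2.0 * lap[i][j] / lambda_max - (1.0 if i == j else 0.0)
         for j in range(n)]
        for i in range(n)
    ]


def chebyshev_terms(l_hat, x, k):
    """Return [T_0 x, ..., T_{k-1} x] for a single feature vector x."""
    def apply(v):
        return [sum(a * b for a, b in zip(row, v)) for row in l_hat]

    terms = [list(x)]
    if k > 1:
        terms.append(apply(x))
    for _ in range(2, k):
        prev = apply(terms[-1])
        terms.append([2.0 * p - q for p, q in zip(prev, terms[-2])])
    return terms


# Two nodes joined by one undirected edge; lambda_max = 2 is exact here.
l_hat = rescale(sym_laplacian([[0, 1], [1, 0]]), lambda_max=2.0)
terms = chebyshev_terms(l_hat, [1.0, 0.0], k=3)
```

With K terms the highest one involves the (K - 1)-th power of the rescaled Laplacian, so a single K=5 layer mixes information from nodes up to 4 hops away.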

forward(data)

Run the forward pass.

Parameters

data (torch_geometric.data.Data) – The input batch of data.

Returns

The computed graph embeddings of the input morphologies. The shape is (n_samples, 512).

Return type

torch.Tensor

training: bool
class morphoclass.models.ManNet(n_features=1, n_global_features=0, n_classes=4, pool_name='avg', lambda_max=3.0, normalization='sym', flow='target_to_source', edge_weight_idx=None, bn=False)

Bases: morphoclass.models.man_net.ManNetR

The updated version of the MultiAdjNet classifier.

Changes:
  - custom pooling
  - edge_weights
  - ChebConvs from pytorch-geometric
  - customizable normalization
  - customizable lambda_max

Parameters
  • n_features (int) – The number of input features.

  • n_global_features (int) – The number of global features.

  • n_classes (int) – The number of classes. For each sample the output of the model will be an array of real values of length n_classes.

  • pool_name ({"avg", "sum", "att"}) – The type of pooling layer to use:
      - “avg”: global average pooling
      - “sum”: global sum pooling
      - “att”: global attention pooling (trainable)

  • lambda_max (float or list of float or None) – Originally the highest eigenvalue(s) of the adjacency matrix. In ChebConvs this value is usually computed from the adjacency matrix directly and used for normalization. This however doesn’t work for non-symmetric matrices and we fix a constant value instead of computing it. Experiments show that there is no impact on performance.

  • normalization ({None, "sym", "rw"}) – The normalization type of the graph Laplacian to use in the ChebConvs. Possible values:
      - None: no normalization
      - “sym”: symmetric normalization
      - “rw”: random walk normalization

  • flow ({"target_to_source", "source_to_target"}) – The message passing flow direction in ChebConvs for directed graphs.

  • edge_weight_idx (int or None) – The index of the edge feature tensor (data.edge_attr) to use as edge weights.

  • bn (bool) – Whether or not to apply batch normalization between the embedder and the classifier.

forward(data)

Run the forward pass.

Parameters

data (torch_geometric.data.Data) – The input batch of data.

Returns

The log of the prediction probabilities. The tensor has shape (n_samples, n_classes).

Return type

torch.Tensor

loss_acc(data)

Run the forward pass and compute the loss and accuracy.

Parameters

data (torch_geometric.data.Data) – The input batch of data.

Returns

  • loss (float) – The loss on the given data batch.

  • acc (float) – The accuracy on the current data batch.

training: bool
class morphoclass.models.ManNetR(n_features=1, n_global_features=0, n_classes=4, pool_name='avg', lambda_max=3.0, normalization='sym', flow='target_to_source', edge_weight_idx=None, bn=False)

Bases: torch.nn.modules.module.Module

The regression version of the ManNet classifier (no softmax).

It is also possible to train on a number of global features and to apply an optional batch normalization.

Parameters
  • n_features (int) – The number of input features.

  • n_global_features (int) – The number of global features.

  • n_classes (int) – The number of classes. For each sample the output of the model will be an array of real values of length n_classes.

  • pool_name ({"avg", "sum", "att"}) – The type of pooling layer to use:
      - “avg”: global average pooling
      - “sum”: global sum pooling
      - “att”: global attention pooling (trainable)

  • lambda_max (float or list of float or None) – Originally the highest eigenvalue(s) of the adjacency matrix. In ChebConvs this value is usually computed from the adjacency matrix directly and used for normalization. This however doesn’t work for non-symmetric matrices and we fix a constant value instead of computing it. Experiments show that there is no impact on performance.

  • normalization ({None, "sym", "rw"}) – The normalization type of the graph Laplacian to use in the ChebConvs. Possible values:
      - None: no normalization
      - “sym”: symmetric normalization
      - “rw”: random walk normalization

  • flow ({"target_to_source", "source_to_target"}) – The message passing flow direction in ChebConvs for directed graphs.

  • edge_weight_idx (int or None) – The index of the edge feature tensor (data.edge_attr) to use as edge weights.

  • bn (bool) – Whether or not to apply batch normalization between the embedder and the classifier.

forward(data)

Run the forward pass.

Parameters

data (torch_geometric.data.Data) – The input batch of data.

Returns

The raw model outputs (no softmax is applied). The tensor has shape (n_samples, n_classes).

Return type

torch.Tensor

training: bool
class morphoclass.models.ManResNet1(n_features=1, n_classes=4)

Bases: torch.nn.modules.module.Module

MultiAdjNet with 1 residual block.

In comparison to the MultiAdjNet the second convolution layer is replaced by a BidirectionalResBlock with the same input and output dimension.

Because a bidirectional residual block contains two convolutional layers, the net has a total of three convolutional layers with K=5 ChebConvs. Because a K=5 ChebConv takes into account up to the 4th power of the adjacency matrix, this net has a total reach of 3 * 4 = 12 hops.
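The reach arithmetic generalizes to any number of stacked layers: each K-term ChebConv reaches K - 1 hops, and the reaches of consecutive layers add up. The helper name below is hypothetical.

```python
def total_reach_hops(n_conv_layers, k=5):
    """Number of hops covered by stacked ChebConv layers with K = k."""
    return n_conv_layers * (k - 1)


reach = total_reach_hops(3)
```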

Additionally the AttentionGlobalPool is now the default pooling method.

Parameters
  • n_features (int (optional)) – The number of input features.

  • n_classes (int (optional)) – The number of output classes.

forward(data)

Run the forward pass.

Parameters

data (torch_geometric.data.Data) – The input batch of data.

Returns

The log of the prediction probabilities. The tensor has shape (n_samples, n_classes).

Return type

torch.Tensor

loss_acc(data)

Run the forward pass and compute the loss and accuracy.

Parameters

data (torch_geometric.data.Data) – The input batch of data.

Returns

  • loss (float) – The loss on the given data batch.

  • acc (float) – The accuracy on the current data batch.

training: bool
class morphoclass.models.ManResNet2(n_features=1, n_classes=4)

Bases: torch.nn.modules.module.Module

MultiAdjNet with 2 residual blocks.

In comparison to the MultiAdjNet the first convolutional layer is split into a BidirectionalBlock and a BidirectionalResBlock, and the second convolutional layer is replaced by another Bidirectional block.

Because a bidirectional residual block contains two convolutional layers, the net has a total of five convolutional layers with K=5 ChebConvs. Because a K=5 ChebConv takes into account up to the 4th power of the adjacency matrix, this net has a total reach of 5 * 4 = 20 hops.

Additionally the AttentionGlobalPool is now the default pooling method.

Parameters
  • n_features (int (optional)) – The number of input features.

  • n_classes (int (optional)) – The number of output classes.

forward(data)

Run the forward pass.

Parameters

data (torch_geometric.data.Data) – The input batch of data.

Returns

The log of the prediction probabilities. The tensor has shape (n_samples, n_classes).

Return type

torch.Tensor

loss_acc(data)

Run the forward pass and compute the loss and accuracy.

Parameters

data (torch_geometric.data.Data) – The input batch of data.

Returns

  • loss (float) – The loss on the given data batch.

  • acc (float) – The accuracy on the current data batch.

training: bool
class morphoclass.models.ManResNet3(n_features=1, n_classes=4)

Bases: torch.nn.modules.module.Module

MultiAdjNet with 3 residual blocks.

In comparison to the MultiAdjNet the first convolutional layer is split into a BidirectionalBlock and two BidirectionalResBlocks, and the second convolutional layer is replaced by another Bidirectional block.

Because a bidirectional residual block contains two convolutional layers, the net has a total of seven convolutional layers with K=5 ChebConvs. Because a K=5 ChebConv takes into account up to the 4th power of the adjacency matrix, this net has a total reach of 7 * 4 = 28 hops.

Additionally the AttentionGlobalPool is now the default pooling method.

Parameters
  • n_features (int (optional)) – The number of input features.

  • n_classes (int (optional)) – The number of output classes.

forward(data)

Run the forward pass.

Parameters

data (torch_geometric.data.Data) – The input batch of data.

Returns

The log of the prediction probabilities. The tensor has shape (n_samples, n_classes).

Return type

torch.Tensor

loss_acc(data)

Run the forward pass and compute the loss and accuracy.

Parameters

data (torch_geometric.data.Data) – The input batch of data.

Returns

  • loss (float) – The loss on the given data batch.

  • acc (float) – The accuracy on the current data batch.

training: bool
class morphoclass.models.MultiAdjNet(n_features=1, n_classes=4, attention=False, attention_per_feature=False, save_attention=False)

Bases: torch.nn.modules.module.Module

Model for classifying morphologies of pyramidal neurons.

This is the architecture that performed best in the TensorFlow implementation. It consists of two graph convolution layers, each computing two convolutions: one on the directed adjacency matrix and one on the adjacency matrix with the direction reversed. The results of both convolutions are concatenated and passed to the next layer. The two bidirectional graph convolution layers are followed by a global average pooling layer and a fully-connected layer. Finally, a softmax layer is used for prediction.

Parameters
  • n_features (int, default 1) – The number of input features.

  • n_classes (int, default 4) – The number of output classes.

  • attention (bool, default False) – If true, an attention-based global pooling layer is used; if false, a global mean pooling layer is used.

  • attention_per_feature (bool, default False) – If true then the attention will be optimized for each feature separately. This will increase the number of trainable parameters. Only has effect if the parameter attention is true.

  • save_attention (bool, default False) – If true then the attention weights will be cached within the attention layer instance. See the AttentionGlobalPool class for more details.
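The difference between mean pooling and attention-based pooling can be sketched as follows. In this generic version the attention scores are passed in directly, whereas a trainable layer would compute them from the node features. The function names are hypothetical and the code does not reproduce the AttentionGlobalPool class.

```python
import math


def mean_pool(x):
    """Average the node feature rows of one graph."""
    return [sum(col) / len(x) for col in zip(*x)]


def attention_pool(x, scores):
    """Weighted sum of node features with softmax-normalized scores."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]  # one weight per node, summing to 1
    return [sum(w * v for w, v in zip(weights, col)) for col in zip(*x)]


x = [[1.0, 2.0], [3.0, 4.0]]           # two nodes, two features each
uniform = attention_pool(x, [0.0, 0.0])  # equal scores reduce to the mean
focused = attention_pool(x, [0.0, 100.0])  # nearly all weight on node 1
```

Cached attention weights, as enabled by save_attention, would correspond to keeping the per-node weights of the last forward pass for later inspection.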

accuracy(data)

Run the forward pass and compute the accuracy.

Parameters

data (torch_geometric.data.Data) – The input batch of data.

Returns

acc – The accuracy on the current data batch.

Return type

float

forward(data)

Compute the forward pass.

Parameters

data (torch_geometric.data.data.Data) – A batch of input data.

Returns

The log softmax of the predictions.

Return type

log_softmax

loss_acc(data)

Run the forward pass and compute the loss and accuracy.

Parameters

data (torch_geometric.data.Data) – The input batch of data.

Returns

  • loss (float) – The loss on the given data batch.

  • acc (float) – The accuracy on the current data batch.

training: bool