morphoclass.models package
Submodules
- morphoclass.models.bidirectional_net module
- morphoclass.models.cnnet module
- morphoclass.models.concatecnnet module
- morphoclass.models.concatenet module
- morphoclass.models.coriander_net module
- morphoclass.models.hbnet module
- morphoclass.models.man_net module
- morphoclass.models.man_res_nets module
- morphoclass.models.multi_adj_net module
Module contents
Various models for morphology type classification of neurons.
-
class morphoclass.models.BidirectionalNet(num_classes, num_nodes_features)
Bases: torch.nn.modules.module.Module
Model for classifying morphologies of pyramidal neurons.
This is the architecture that performed best in the TensorFlow implementation. It consists of two graph convolution layers, each computing two convolutions: one on the directed adjacency matrix and one on the adjacency matrix with the direction reversed. The results of both convolutions are concatenated and passed to the next layer. The two bidirectional graph convolution layers are followed by a global average pooling layer and a fully-connected layer. Finally, a softmax layer is used for prediction.
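The bidirectional idea described above can be illustrated with a small plain-Python sketch. It is not the morphoclass implementation: it has no learned weights, and the helper names (propagate, bidirectional_step, global_average_pool) and the 3-node example graph are assumptions made for illustration only.

```python
# Illustrative sketch only: one bidirectional propagation step on a tiny
# directed graph, followed by global average pooling. No learned weights.


def propagate(adj, features):
    """Multiply an adjacency matrix (n x n) with a node-feature matrix (n x f)."""
    n_nodes = len(adj)
    n_feats = len(features[0])
    return [
        [sum(adj[i][k] * features[k][j] for k in range(n_nodes)) for j in range(n_feats)]
        for i in range(n_nodes)
    ]


def transpose(matrix):
    """Reverse the edge direction by transposing the adjacency matrix."""
    return [list(row) for row in zip(*matrix)]


def bidirectional_step(adj, features):
    """Propagate along both edge directions and concatenate the results per node."""
    forward = propagate(adj, features)
    backward = propagate(transpose(adj), features)
    return [f + b for f, b in zip(forward, backward)]


def global_average_pool(features):
    """Average every feature column over all nodes of the graph."""
    n_nodes = len(features)
    return [sum(col) / n_nodes for col in zip(*features)]


# Hypothetical 3-node chain 0 -> 1 -> 2 with one feature per node.
adj = [[0, 1, 0], [0, 0, 1], [0, 0, 0]]
x = [[1.0], [2.0], [3.0]]

hidden = bidirectional_step(adj, x)   # two features per node after concatenation
pooled = global_average_pool(hidden)  # one vector for the whole graph
```

Concatenating the two directions doubles the feature dimension per node, which is why the next layer sees both the upstream and the downstream context of each node.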
- Parameters
num_classes (int) – The number of output classes.
num_nodes_features (int) – The number of input node features.
-
accuracy(data)
Run the forward pass and compute the accuracy.
- Parameters
data (torch_geometric.data.Data) – The input batch of data.
- Returns
acc – The accuracy on the current data batch.
- Return type
float
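As a rough illustration of what such an accuracy value means, the sketch below computes the fraction of samples whose highest log-probability matches the label. That definition is an assumption, since this page does not spell out how accuracy is computed. The function name and the example values are hypothetical.

```python
import math


def accuracy_from_log_probs(log_probs, labels):
    """Fraction of samples whose arg-max class equals the target label."""
    predictions = [max(range(len(row)), key=row.__getitem__) for row in log_probs]
    n_correct = sum(int(p == y) for p, y in zip(predictions, labels))
    return n_correct / len(labels)


# Hypothetical log-softmax outputs for three samples and two classes.
log_probs = [
    [math.log(0.7), math.log(0.3)],
    [math.log(0.4), math.log(0.6)],
    [math.log(0.9), math.log(0.1)],
]
labels = [0, 1, 1]

acc = accuracy_from_log_probs(log_probs, labels)  # two of three samples are correct
```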
-
forward(data)
Compute the forward pass.
- Parameters
data (torch_geometric.data.data.Data) – A batch of input data.
- Returns
The log softmax of the predictions.
- Return type
log_softmax
-
loss_acc(data)
Run the forward pass and compute the loss and accuracy.
- Parameters
data (torch_geometric.data.Data) – The input batch of data.
- Returns
loss (float) – The loss on the given data batch.
acc (float) – The accuracy on the current data batch.
-
training: bool
-
class morphoclass.models.CNNEmbedder
Bases: torch.nn.modules.module.Module
The embedder part of the CNNet classifier.
-
forward(data)
Do the forward pass.
- Parameters
data (torch_geometric.data.Batch | torch.Tensor) – A batch from the MorphologyDataset dataset.
- Returns
x – The embedding of the persistence image. The dimension of the tensor is (n_batch, 3 * (image_size // 4)**2).
- Return type
torch.Tensor
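The embedding size quoted above can be evaluated for a few image sizes. The helper name cnn_embedding_dim is hypothetical. Only the formula 3 * (image_size // 4)**2 comes from this page.

```python
def cnn_embedding_dim(image_size):
    """Second dimension of the embedding, per the documented formula."""
    return 3 * (image_size // 4) ** 2


# Embedding widths for a few square image sizes (100 is the CNNet default).
dims = {size: cnn_embedding_dim(size) for size in (64, 100, 128)}
```

Because of the integer division, image sizes that are not multiples of 4 round down before squaring.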
-
training: bool
-
class morphoclass.models.CNNet(n_classes, image_size=100, bn=False)
Bases: torch.nn.modules.module.Module
Convolutional net for classifying persistence images.
The provided persistence images should be square greyscale images and can be obtained from persistence diagrams by applying Gaussian KDE.
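A minimal sketch of turning a persistence diagram into a square greyscale image with a Gaussian kernel is given below. It is not the preprocessing used by morphoclass. The function name, the bandwidth, the value bounds, and the peak normalization are all assumptions chosen for illustration.

```python
import math


def persistence_image(diagram, image_size=8, bandwidth=0.1, bounds=(0.0, 1.0)):
    """Sum one Gaussian bump per (birth, death) point on a square pixel grid."""
    lo, hi = bounds
    step = (hi - lo) / image_size
    # Pixel centres along one axis; the same grid is used for both axes.
    centers = [lo + (i + 0.5) * step for i in range(image_size)]
    image = [[0.0] * image_size for _ in range(image_size)]
    for birth, death in diagram:
        for row, y in enumerate(centers):
            for col, x in enumerate(centers):
                sq_dist = (x - birth) ** 2 + (y - death) ** 2
                image[row][col] += math.exp(-sq_dist / (2 * bandwidth**2))
    # Scale to [0, 1] so the result can be treated as a greyscale image.
    peak = max(max(r) for r in image)
    if peak > 0:
        image = [[v / peak for v in r] for r in image]
    return image


# Hypothetical diagram with two (birth, death) pairs.
diagram = [(0.2, 0.8), (0.5, 0.6)]
img = persistence_image(diagram, image_size=8)
```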
- Parameters
n_classes (int) – The number of classes.
image_size (int) – The width or height of the input images. The images are assumed to be square.
bn (bool) – If true, 1D batch normalization and a ReLU are applied to the flattened embeddings before the fully-connected layer.
-
forward(data)
Do the forward pass.
- Parameters
data (torch_geometric.data.Batch) – A batch from the MorphologyDataset dataset.
- Returns
logits – The predicted logits of shape (n_images, n_classes)
- Return type
torch.Tensor
-
loss_acc(data)
Get loss and accuracy.
- Parameters
data (torch_geometric.data.Batch) – A batch from the MorphologyDataset dataset.
- Returns
loss (float) – The loss value.
acc (float) – The accuracy value.
-
training: bool
-
class morphoclass.models.ConcateCNNet(n_node_features, n_classes, image_size, bn=False)
Bases: torch.nn.modules.module.Module
A neuron m-type classifier based on graph and image convolutions.
In the feature extraction part of the network, graph convolution layers are applied to the graph node features of the apical dendrites, while the CNN layers are applied to the persistence image representation of the same data. The resulting features are concatenated and passed through a fully-connected layer for classification.
- Parameters
n_node_features (int) – The number of input node features for the GNN layers.
n_classes (int) – The number of output classes.
image_size (int) – The width (or height) of the input persistence images. It is assumed that the images are square so that the width and height are equal.
bn (bool, default False) – Whether or not to include a batch normalization layer between the feature extractor and the fully-connected classification layer.
-
forward(data, images)
Compute the forward pass.
- Parameters
data (torch_geometric.data.data.Data) – A batch of input graph data for the GNN layers.
images – A batch of input persistence images for the CNN layers.
- Returns
The log softmax of the predictions.
- Return type
log_softmax
-
training: bool
-
class morphoclass.models.ConcateNet(n_node_features, n_classes, n_features_perslay, bn=False)
Bases: torch.nn.modules.module.Module
A neuron m-type classifier based on graph convolutions and PersLay.
In the feature extraction part of the network, graph convolution layers are applied to the graph node features of the apical dendrites, while the PersLay layer is applied to the persistence diagram representation of the same data. The resulting features are concatenated and passed through a fully-connected layer for classification.
- Parameters
n_node_features (int) – The number of input node features for the GNN layers.
n_classes (int) – The number of output classes.
n_features_perslay (int) – The number of features for the PersLay layer.
bn (bool, default False) – Whether or not to include a batch normalization layer between the feature extractor and the fully-connected classification layer.
-
forward(data, diagrams, point_index)
Compute the forward pass.
- Parameters
data (torch_geometric.data.data.Data) – A batch of input graph data for the GNN layers.
diagrams – A batch of input persistence diagrams for the PersLay layer.
point_index (torch.Tensor) – A one-dimensional integer tensor holding the segmentation map for samples in the batched data, e.g. tensor([0, 0, 1, 1, 1, 2, …]).
- Returns
The log softmax of the predictions.
- Return type
log_softmax
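The point_index segmentation map described above can be built from per-sample persistence diagrams as sketched below. The helper name build_point_index and the example diagrams are hypothetical. In practice the resulting list would be converted to an integer tensor, for example with torch.tensor(point_index, dtype=torch.long).

```python
def build_point_index(diagrams):
    """Return one sample id per diagram point, in batch order."""
    point_index = []
    for sample_id, diagram in enumerate(diagrams):
        point_index.extend([sample_id] * len(diagram))
    return point_index


# Hypothetical batch of three diagrams with 2, 3 and 1 points.
diagrams = [
    [(0.1, 0.9), (0.2, 0.4)],
    [(0.0, 0.5), (0.3, 0.7), (0.6, 0.8)],
    [(0.2, 0.3)],
]

point_index = build_point_index(diagrams)
# All points stacked into one list, aligned with point_index entry by entry.
flat_points = [point for diagram in diagrams for point in diagram]
```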
-
training: bool
-
class morphoclass.models.CorianderNet(n_classes=4, n_features=64, dropout=False)
Bases: torch.nn.modules.module.Module
A PersLay-based neural network for neuron m-type classification.
- Parameters
n_classes (int) – The number of m-type classes to predict.
n_features (int) – The number of output feature maps for the PersLay layer.
dropout (bool, default False) – If true, a dropout layer is inserted between the two fully-connected layers of the classifier part of the network.
-
forward(data)
Compute the forward pass.
- Parameters
data (torch_geometric.data.Batch | torch.Tensor) – A batch from the MorphologyDataset dataset.
- Returns
The log softmax of the predictions.
- Return type
log_softmax
-
loss_acc(data)
Get loss and accuracy.
- Parameters
data (torch_geometric.data.Batch) – A batch from the MorphologyDataset dataset.
- Returns
loss (float) – The loss value.
acc (float) – The accuracy value.
-
training: bool
-
class morphoclass.models.HBNet(num_features, class_segmentation_mask)
Bases: torch.nn.modules.module.Module
Hierarchical, Bidirectional Net.
- Parameters
num_features – The number of input features.
class_segmentation_mask – Mask for hierarchical labels in which each softmax block is marked by a different integer. The class mask can be generated by the get_total_class_mask() method of model_utils.HierarchicalLabels instances.
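One way to read the class segmentation mask is that output positions sharing the same integer are normalized together by their own softmax. The sketch below follows that reading, which is an assumption and not confirmed by this page. The function name, mask, and logits are hypothetical.

```python
import math


def blockwise_softmax(logits, mask):
    """Apply a separate softmax to each group of positions sharing a mask value."""
    probs = [0.0] * len(logits)
    for block in sorted(set(mask)):
        idx = [i for i, m in enumerate(mask) if m == block]
        # Subtract the block maximum for numerical stability.
        peak = max(logits[i] for i in idx)
        exps = {i: math.exp(logits[i] - peak) for i in idx}
        total = sum(exps.values())
        for i in idx:
            probs[i] = exps[i] / total
    return probs


# Hypothetical mask: two root classes (block 0) and three child classes (block 1).
mask = [0, 0, 1, 1, 1]
logits = [1.0, 1.0, 0.0, 0.0, 0.0]

probs = blockwise_softmax(logits, mask)  # each block sums to 1 on its own
```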
-
compute_hierarchical_metric(metric_function, data, hl)
Compute metric for each layer in the label hierarchy.
- Parameters
metric_function – A function computing the required metric. It should have the signature metric_function(targets, predictions).
data – A batch of samples.
hl (model_utils.HierarchicalLabels) – Hierarchy structure.
- Returns
List of metric evaluations for each layer in the label hierarchy.
- Return type
results
-
forward(data)
Compute the forward pass.
- Parameters
data – The input data.
- Returns
The log softmax of the predictions.
- Return type
log_softmax
-
gen_hierarchical_probabilities(probabilities, parent_mask, hl)
Descend the hierarchy layer-wise and compute probabilities.
Starting with the parent nodes specified in parent_mask, compute the probabilities for their children using
P(child) = P(child|parent) * P(parent),
create a new mask containing all children, and recurse.
- Parameters
probabilities – Probabilities for all nodes in the hierarchy tree.
parent_mask – Starting mask for the nodes from which to start descending the hierarchy. On the first call, parent_mask should contain all root nodes.
hl (model_utils.HierarchicalLabels) – Hierarchy structure for labels.
- Yields
mask – Mask for all nodes considered for the current layer in the hierarchy.
probabilities – Probabilities predicted for the nodes specified by mask.
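The layer-wise descent with P(child) = P(child|parent) * P(parent) can be sketched in plain Python as follows. This is an iterative generator over dictionaries rather than the recursive, mask-based method documented here. The function name descend, the tree, and the conditional probabilities are hypothetical.

```python
def descend(conditional, children, roots):
    """Yield (layer_nodes, joint_probabilities) for each layer of the hierarchy."""
    # Root nodes have no parent, so their joint probability is the given value.
    joint = {node: conditional[node] for node in roots}
    layer = list(roots)
    while layer:
        yield layer, {node: joint[node] for node in layer}
        next_layer = []
        for parent in layer:
            for child in children.get(parent, []):
                # P(child) = P(child | parent) * P(parent)
                joint[child] = conditional[child] * joint[parent]
                next_layer.append(child)
        layer = next_layer


# Hypothetical two-level hierarchy with conditional probabilities per node.
children = {"A": ["A1", "A2"], "B": ["B1"]}
conditional = {"A": 0.6, "B": 0.4, "A1": 0.25, "A2": 0.75, "B1": 1.0}

layers = list(descend(conditional, children, roots=["A", "B"]))
```

Within each layer the joint probabilities sum to one as long as the conditional probabilities of every sibling group do.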
-
hierarchical_accuracies(data, hl)
Accuracies for each layer in the label hierarchy.
- Parameters
data – A batch of samples.
hl (model_utils.HierarchicalLabels) – Hierarchy structure.
- Returns
List of accuracies for each layer in the label hierarchy.
- Return type
accuracies
-
loss(data)
Compute the loss.
- Parameters
data – The input data.
- Returns
The loss.
-
precision_recall_f1(data, hl, *, average='micro')
Compute hierarchical precision, recall, F1 score.
- Parameters
data – A batch of samples.
hl (model_utils.HierarchicalLabels) – Hierarchy structure.
average – The type of multi-class average to take. Accepted values are None, “micro”, and “macro”.
- Returns
p_h – The hierarchical precision.
r_h – The hierarchical recall.
f1_h – The hierarchical F1 score.
-
predict(data)
Make hierarchical predictions for given data.
- Parameters
data – A batch of samples.
- Returns
val_max – Probabilities for the predicted nodes in the hierarchy tree.
val_argmax – Indices for the predicted nodes in the hierarchy tree. For example, hierarchical one-hot prediction [1, 0, 0, 1, 0] would correspond to val_argmax = [0, 3].
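The relation between a hierarchical one-hot prediction and val_argmax can be illustrated with the sketch below. It picks the most probable node in each block of a hypothetical mask and ignores the parent-child descent that the real method may perform. The function name, probabilities, and mask are assumptions.

```python
def blockwise_argmax(probabilities, mask):
    """Pick the most probable node in every block of the mask."""
    val_max, val_argmax = [], []
    for block in sorted(set(mask)):
        idx = [i for i, m in enumerate(mask) if m == block]
        best = max(idx, key=lambda i: probabilities[i])
        val_argmax.append(best)
        val_max.append(probabilities[best])
    return val_max, val_argmax


# Hypothetical probabilities for five hierarchy nodes in two blocks.
probabilities = [0.7, 0.2, 0.1, 0.9, 0.1]
mask = [0, 0, 0, 1, 1]

val_max, val_argmax = blockwise_argmax(probabilities, mask)
# Hierarchical one-hot encoding of the predicted nodes.
one_hot = [1 if i in val_argmax else 0 for i in range(len(probabilities))]
```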
-
predict_probabilities(data)
Compute the prediction probabilities.
- Parameters
data – The input data.
- Returns
The prediction probabilities.
-
training: bool
-
class morphoclass.models.ManEmbedder(n_features=1, pool_name='avg', lambda_max=3.0, normalization='sym', flow='target_to_source', edge_weight_idx=None)
Bases: torch.nn.modules.module.Module
The embedder for the ManNet network.
The embedder consists of two bidirectional ChebConv blocks followed by a global pooling layer.
- Parameters
n_features (int) – The number of input features.
pool_name ({"avg", "sum", "att"}) – The type of pooling layer to use:
- “avg”: global average pooling
- “sum”: global sum pooling
- “att”: global attention pooling (trainable)
lambda_max (float or list of float or None) – Originally the highest eigenvalue(s) of the adjacency matrix. In ChebConvs this value is usually computed from the adjacency matrix directly and used for normalization. This, however, doesn’t work for non-symmetric matrices, so a constant value is fixed instead of being computed. Experiments show that this has no impact on performance.
normalization ({None, "sym", "rw"}) – The normalization type of the graph Laplacian to use in the ChebConvs. Possible values:
- None: no normalization
- “sym”: symmetric normalization
- “rw”: random walk normalization
flow ({"target_to_source", "source_to_target"}) – The message passing flow direction in ChebConvs for directed graphs.
edge_weight_idx (int or None) – The index of the edge feature tensor (data.edge_attr) to use as edge weights.
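The difference between the "avg" and "sum" pooling options can be sketched in plain Python as below. This is an illustration, not the morphoclass pooling code. The function name global_pool and the example batch are hypothetical, and the trainable "att" option is deliberately left out. PyTorch Geometric offers global_mean_pool and global_add_pool for the same two operations, although this page does not say whether morphoclass relies on them.

```python
def global_pool(node_features, batch, pool_name="avg"):
    """Reduce node features to one vector per graph using a batch index vector."""
    n_graphs = max(batch) + 1
    n_feats = len(node_features[0])
    sums = [[0.0] * n_feats for _ in range(n_graphs)]
    counts = [0] * n_graphs
    for feats, graph_id in zip(node_features, batch):
        counts[graph_id] += 1
        for j, value in enumerate(feats):
            sums[graph_id][j] += value
    if pool_name == "sum":
        return sums
    if pool_name == "avg":
        return [[v / counts[g] for v in row] for g, row in enumerate(sums)]
    # Attention pooling needs trainable weights and is not covered by this sketch.
    raise ValueError(f"unsupported pool_name in this sketch: {pool_name!r}")


# Hypothetical batch: nodes 0 and 1 belong to graph 0, node 2 to graph 1.
node_features = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
batch = [0, 0, 1]

avg = global_pool(node_features, batch, "avg")
total = global_pool(node_features, batch, "sum")
```

Sum pooling keeps information about graph size, while average pooling is invariant to it.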
-
forward(data)
Run the forward pass.
- Parameters
data (torch_geometric.data.Data) – The input batch of data.
- Returns
The computed graph embeddings of the input morphologies. The shape is (n_samples, 512).
- Return type
torch.Tensor
-
training: bool
-
class morphoclass.models.ManNet(n_features=1, n_global_features=0, n_classes=4, pool_name='avg', lambda_max=3.0, normalization='sym', flow='target_to_source', edge_weight_idx=None, bn=False)
Bases: morphoclass.models.man_net.ManNetR
The updated version of the MultiAdjNet classifier.
Changes:
- custom pooling
- edge_weights
- ChebConvs from pytorch-geometric
- customizable normalization
- customizable lambda_max
- Parameters
n_features (int) – The number of input features.
n_global_features (int) – The number of global features.
n_classes (int) – The number of classes. For each sample the output of the model will be an array of real values of length n_classes.
pool_name ({"avg", "sum", "att"}) – The type of pooling layer to use:
- “avg”: global average pooling
- “sum”: global sum pooling
- “att”: global attention pooling (trainable)
lambda_max (float or list of float or None) – Originally the highest eigenvalue(s) of the adjacency matrix. In ChebConvs this value is usually computed from the adjacency matrix directly and used for normalization. This, however, doesn’t work for non-symmetric matrices, so a constant value is fixed instead of being computed. Experiments show that this has no impact on performance.
normalization ({None, "sym", "rw"}) – The normalization type of the graph Laplacian to use in the ChebConvs. Possible values:
- None: no normalization
- “sym”: symmetric normalization
- “rw”: random walk normalization
flow ({"target_to_source", "source_to_target"}) – The message passing flow direction in ChebConvs for directed graphs.
edge_weight_idx (int or None) – The index of the edge feature tensor (data.edge_attr) to use as edge weights.
bn (bool) – Whether or not to apply batch normalization between the embedder and the classifier.
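For orientation, the roles of lambda_max and normalization can be written out using the ChebConv definition from the PyTorch Geometric documentation. The symbols A (adjacency matrix), D (degree matrix), X (node features) and \Theta^{(k)} (learned weights) are notation introduced here. How morphoclass applies this to directed adjacency matrices is not described on this page.

```latex
\hat{L} = \frac{2L}{\lambda_{\max}} - I,
\qquad
L =
\begin{cases}
D - A & \text{normalization} = \text{None},\\
I - D^{-1/2} A D^{-1/2} & \text{normalization} = \text{``sym''},\\
I - D^{-1} A & \text{normalization} = \text{``rw''},
\end{cases}

X' = \sum_{k=1}^{K} Z^{(k)} \Theta^{(k)},
\qquad
Z^{(1)} = X,\quad
Z^{(2)} = \hat{L} X,\quad
Z^{(k)} = 2 \hat{L} Z^{(k-1)} - Z^{(k-2)}.
```

With K = 5 the highest term Z^{(5)} is a polynomial of degree 4 in \hat{L}, so one layer reaches nodes up to four hops away.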
-
forward(data)
Run the forward pass.
- Parameters
data (torch_geometric.data.Data) – The input batch of data.
- Returns
The log of the prediction probabilities. The tensor has shape (n_samples, n_classes).
- Return type
torch.Tensor
-
loss_acc(data)
Run the forward pass and compute the loss and accuracy.
- Parameters
data (torch_geometric.data.Data) – The input batch of data.
- Returns
loss (float) – The loss on the given data batch.
acc (float) – The accuracy on the current data batch.
-
training: bool
-
class morphoclass.models.ManNetR(n_features=1, n_global_features=0, n_classes=4, pool_name='avg', lambda_max=3.0, normalization='sym', flow='target_to_source', edge_weight_idx=None, bn=False)
Bases: torch.nn.modules.module.Module
The regression version of the ManNet classifier (no softmax).
It is also possible to train on a number of global features and to enable an optional batch normalization.
- Parameters
n_features (int) – The number of input features.
n_global_features (int) – The number of global features.
n_classes (int) – The number of classes. For each sample the output of the model will be an array of real values of length n_classes.
pool_name ({"avg", "sum", "att"}) – The type of pooling layer to use:
- “avg”: global average pooling
- “sum”: global sum pooling
- “att”: global attention pooling (trainable)
lambda_max (float or list of float or None) – Originally the highest eigenvalue(s) of the adjacency matrix. In ChebConvs this value is usually computed from the adjacency matrix directly and used for normalization. This, however, doesn’t work for non-symmetric matrices, so a constant value is fixed instead of being computed. Experiments show that this has no impact on performance.
normalization ({None, "sym", "rw"}) – The normalization type of the graph Laplacian to use in the ChebConvs. Possible values:
- None: no normalization
- “sym”: symmetric normalization
- “rw”: random walk normalization
flow ({"target_to_source", "source_to_target"}) – The message passing flow direction in ChebConvs for directed graphs.
edge_weight_idx (int or None) – The index of the edge feature tensor (data.edge_attr) to use as edge weights.
bn (bool) – Whether or not to apply batch normalization between the embedder and the classifier.
-
forward(data)
Run the forward pass.
- Parameters
data (torch_geometric.data.Data) – The input batch of data.
- Returns
The log of the prediction probabilities. The tensor has shape (n_samples, n_classes).
- Return type
torch.Tensor
-
training: bool
-
class morphoclass.models.ManResNet1(n_features=1, n_classes=4)
Bases: torch.nn.modules.module.Module
MultiAdjNet with 1 residual block.
In comparison to the MultiAdjNet, the second convolution layer is replaced by a BidirectionalResBlock with the same input and output dimension.
Because a bidirectional residual block contains two convolutional layers, the net has a total of three convolutional layers with K=5 ChebConvs. Because a K=5 ChebConv takes into account up to the 4th power of the adjacency matrix, this net has a total reach of 3 * 4 = 12 hops.
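The reach arithmetic above generalizes to the other residual variants documented on this page. The helper name receptive_field_hops is hypothetical. The layer counts (3, 5, 7) and K=5 are the ones stated in the class descriptions.

```python
def receptive_field_hops(n_conv_layers, cheb_k=5):
    """Total reach in hops: each ChebConv of order K covers K - 1 hops."""
    return n_conv_layers * (cheb_k - 1)


# Number of stacked convolutional layers per documented architecture.
conv_layers = {"ManResNet1": 3, "ManResNet2": 5, "ManResNet3": 7}
reach = {name: receptive_field_hops(n) for name, n in conv_layers.items()}
```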
Additionally the AttentionGlobalPool is now the default pooling method.
- Parameters
n_features (int, optional) – The number of input features.
n_classes (int, optional) – The number of output classes.
-
forward(data)
Run the forward pass.
- Parameters
data (torch_geometric.data.Data) – The input batch of data.
- Returns
The log of the prediction probabilities. The tensor has shape (n_samples, n_classes).
- Return type
torch.Tensor
-
loss_acc(data)
Run the forward pass and compute the loss and accuracy.
- Parameters
data (torch_geometric.data.Data) – The input batch of data.
- Returns
loss (float) – The loss on the given data batch.
acc (float) – The accuracy on the current data batch.
-
training: bool
-
class morphoclass.models.ManResNet2(n_features=1, n_classes=4)
Bases: torch.nn.modules.module.Module
MultiAdjNet with 2 residual blocks.
In comparison to the MultiAdjNet, the first convolutional layer is split into a BidirectionalBlock and a BidirectionalResBlock, and the second convolutional layer is replaced by another BidirectionalResBlock.
Because a bidirectional residual block contains two convolutional layers, the net has a total of five convolutional layers with K=5 ChebConvs. Because a K=5 ChebConv takes into account up to the 4th power of the adjacency matrix, this net has a total reach of 5 * 4 = 20 hops.
Additionally the AttentionGlobalPool is now the default pooling method.
- Parameters
n_features (int, optional) – The number of input features.
n_classes (int, optional) – The number of output classes.
-
forward(data)
Run the forward pass.
- Parameters
data (torch_geometric.data.Data) – The input batch of data.
- Returns
The log of the prediction probabilities. The tensor has shape (n_samples, n_classes).
- Return type
torch.Tensor
-
loss_acc(data)
Run the forward pass and compute the loss and accuracy.
- Parameters
data (torch_geometric.data.Data) – The input batch of data.
- Returns
loss (float) – The loss on the given data batch.
acc (float) – The accuracy on the current data batch.
-
training: bool
-
class morphoclass.models.ManResNet3(n_features=1, n_classes=4)
Bases: torch.nn.modules.module.Module
MultiAdjNet with 3 residual blocks.
In comparison to the MultiAdjNet, the first convolutional layer is split into a BidirectionalBlock and two BidirectionalResBlocks, and the second convolutional layer is replaced by another BidirectionalResBlock.
Because a bidirectional residual block contains two convolutional layers, the net has a total of seven convolutional layers with K=5 ChebConvs. Because a K=5 ChebConv takes into account up to the 4th power of the adjacency matrix, this net has a total reach of 7 * 4 = 28 hops.
Additionally the AttentionGlobalPool is now the default pooling method.
- Parameters
n_features (int, optional) – The number of input features.
n_classes (int, optional) – The number of output classes.
-
forward(data)
Run the forward pass.
- Parameters
data (torch_geometric.data.Data) – The input batch of data.
- Returns
The log of the prediction probabilities. The tensor has shape (n_samples, n_classes).
- Return type
torch.Tensor
-
loss_acc(data)
Run the forward pass and compute the loss and accuracy.
- Parameters
data (torch_geometric.data.Data) – The input batch of data.
- Returns
loss (float) – The loss on the given data batch.
acc (float) – The accuracy on the current data batch.
-
training: bool
-
class morphoclass.models.MultiAdjNet(n_features=1, n_classes=4, attention=False, attention_per_feature=False, save_attention=False)
Bases: torch.nn.modules.module.Module
Model for classifying morphologies of pyramidal neurons.
This is the architecture that performed best in the TensorFlow implementation. It consists of two graph convolution layers, each computing two convolutions: one on the directed adjacency matrix and one on the adjacency matrix with the direction reversed. The results of both convolutions are concatenated and passed to the next layer. The two bidirectional graph convolution layers are followed by a global average pooling layer and a fully-connected layer. Finally, a softmax layer is used for prediction.
- Parameters
n_features (int, default 1) – The number of input features.
n_classes (int, default 4) – The number of output classes.
attention (bool, default False) – If true, an attention-based global pooling layer is used; if false, a global mean pooling layer is used.
attention_per_feature (bool, default False) – If true, the attention is optimized for each feature separately, which increases the number of trainable parameters. Only has an effect if the parameter attention is true.
save_attention (bool, default False) – If true, the attention weights are cached within the attention layer instance. See the AttentionGlobalPool class for more details.
-
accuracy(data)
Run the forward pass and compute the accuracy.
- Parameters
data (torch_geometric.data.Data) – The input batch of data.
- Returns
acc – The accuracy on the current data batch.
- Return type
float
-
forward(data)
Compute the forward pass.
- Parameters
data (torch_geometric.data.data.Data) – A batch of input data.
- Returns
The log softmax of the predictions.
- Return type
log_softmax
-
loss_acc(data)
Run the forward pass and compute the loss and accuracy.
- Parameters
data (torch_geometric.data.Data) – The input batch of data.
- Returns
loss (float) – The loss on the given data batch.
acc (float) – The accuracy on the current data batch.
-
training: bool