morphoclass.models.concatenet module
Implementation of the ConcateNet classifier.
-
class morphoclass.models.concatenet.ConcateNet(n_node_features, n_classes, n_features_perslay, bn=False)
Bases: torch.nn.modules.module.Module
A neuron m-type classifier based on graph convolutions and PersLay.
In the feature extraction part of the network, graph convolution layers are applied to the graph node features of the apical dendrites, while the PersLay layer is applied to the persistence diagram representation of the same data. The two resulting sets of features are concatenated and passed through a fully-connected layer for classification.
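The concatenation step can be sketched without any framework. This is a minimal illustration, not morphoclass code: it assumes each branch has already reduced one sample to a fixed-length feature vector (the names `gnn_features` and `perslay_features`, the helper functions, and the weights below are hypothetical), joins the two vectors end to end, and applies a single fully-connected layer to get one raw score per class.

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def linear(x: Vector, weight: Matrix, bias: Vector) -> Vector:
    """Fully-connected layer: one output per row of `weight`."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def fuse_and_classify(gnn_features: Vector, perslay_features: Vector,
                      weight: Matrix, bias: Vector) -> Vector:
    # Concatenate the two branch outputs end to end (hypothetical inputs) ...
    fused = gnn_features + perslay_features
    # ... then map the fused vector to one raw score per class.
    return linear(fused, weight, bias)


# Hypothetical sizes: 3 GNN features + 2 PersLay features, 2 classes,
# so the classification layer needs 2 rows of 3 + 2 = 5 weights each.
gnn_out = [0.5, -1.0, 2.0]
perslay_out = [1.5, 0.0]
W = [[0.1, 0.2, 0.3, 0.4, 0.5],
     [-0.1, 0.0, 0.1, 0.0, -0.2]]
b = [0.0, 0.1]
logits = fuse_and_classify(gnn_out, perslay_out, W, b)  # ≈ [1.05, 0.25]
```

The design point is that the classifier's input width is the sum of the two branch widths, so both feature extractors can be sized independently.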
- Parameters
n_node_features (int) – The number of input node features for the GNN layers.
n_classes (int) – The number of output classes.
n_features_perslay (int) – The number of features for the PersLay layer.
bn (bool, default False) – Whether or not to include a batch normalization layer between the feature extractor and the fully-connected classification layer.
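For intuition about the `bn` option, the following sketch shows what a batch normalization step computes on a batch of concatenated feature vectors in training mode. It is a plain-Python illustration under simplifying assumptions: the learnable scale and shift and the running statistics that a real layer such as `torch.nn.BatchNorm1d` keeps are omitted, and `batch_norm` is a hypothetical helper, not part of morphoclass.

```python
import math
from typing import List


def batch_norm(batch: List[List[float]], eps: float = 1e-5) -> List[List[float]]:
    """Normalize each feature (column) to zero mean and unit variance over the batch.

    Training-mode statistics only; learnable scale/shift and running
    averages are deliberately left out of this sketch.
    """
    n = len(batch)
    n_features = len(batch[0])
    out = [row[:] for row in batch]
    for j in range(n_features):
        column = [row[j] for row in batch]
        mean = sum(column) / n
        var = sum((v - mean) ** 2 for v in column) / n  # biased variance
        for i in range(n):
            out[i][j] = (batch[i][j] - mean) / math.sqrt(var + eps)
    return out


# A batch of 3 samples, each with 2 concatenated features on very different scales.
normalized = batch_norm([[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]])
```

After normalization both columns lie on a comparable scale, which is the usual motivation for placing such a layer before a fully-connected classifier.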
-
forward(data, diagrams, point_index)
Compute the forward pass.
- Parameters
data (torch_geometric.data.data.Data) – A batch of input graph data for the GNN layers.
diagrams – A batch of input persistence diagrams for the PersLay layer.
point_index (torch.Tensor) – A one-dimensional integer tensor holding the segmentation map for samples in the batched data, e.g. tensor([0, 0, 1, 1, 1, 2, …]).
- Returns
The log softmax of the predictions.
- Return type
torch.Tensor
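The sketch below illustrates two parts of the `forward` contract with plain Python lists standing in for tensors. `group_by_sample` (a hypothetical helper) uses a segmentation map like the `point_index` example to split a flat batch of per-point values back into samples, and `log_softmax` shows why the returned values can be exponentiated to recover class probabilities, with the predicted class being the largest entry. The values and class scores are made up for illustration.

```python
import math
from collections import defaultdict
from typing import Dict, List


def group_by_sample(values: List[float], point_index: List[int]) -> Dict[int, List[float]]:
    """Split a flat batch of per-point values into per-sample lists."""
    groups: Dict[int, List[float]] = defaultdict(list)
    for value, sample in zip(values, point_index):
        groups[sample].append(value)
    return dict(groups)


def log_softmax(logits: List[float]) -> List[float]:
    """Numerically stable log softmax: z_i - log(sum_j exp(z_j))."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_sum for z in logits]


# Six points from three samples, flattened into one batch.
point_index = [0, 0, 1, 1, 1, 2]
values = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
per_sample = group_by_sample(values, point_index)
# per_sample == {0: [0.1, 0.2], 1: [0.3, 0.4, 0.5], 2: [0.6]}

# Log softmax over hypothetical class scores for one sample.
log_probs = log_softmax([2.0, 1.0, 0.1])
probs = [math.exp(lp) for lp in log_probs]  # these sum to 1
predicted_class = max(range(len(log_probs)), key=log_probs.__getitem__)  # 0
```

Subtracting the maximum before exponentiating keeps the computation stable for large scores without changing the result.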
-
training: bool