morphoclass.models.concatecnnet module

Implementation of the ConcateCNNet classifier.

class morphoclass.models.concatecnnet.ConcateCNNet(n_node_features, n_classes, image_size, bn=False)

Bases: torch.nn.modules.module.Module

A neuron m-type classifier based on graph and image convolutions.

In the feature extraction part of the network, graph convolution layers are applied to the graph node features of the apical dendrites, while CNN layers are applied to the persistence image representation of the same data. The resulting features are concatenated and passed through a fully-connected layer for classification.
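The concatenation and classification step described above can be sketched without any deep learning framework. This is a minimal illustration of the data flow only, not the morphoclass implementation: the function name, the feature sizes, and the weight values are all hypothetical, and the two input vectors stand in for the outputs of the graph and image feature extractors.

```python
def concatenate_and_classify(graph_features, image_features, weights, bias):
    """Concatenate two feature vectors and apply one fully-connected layer.

    Hypothetical helper for illustration; not part of morphoclass.
    """
    # Feature fusion: the image features are appended to the graph features.
    fused = list(graph_features) + list(image_features)
    # Fully-connected layer: one weight row and one bias term per class.
    logits = [
        sum(w * x for w, x in zip(row, fused)) + b
        for row, b in zip(weights, bias)
    ]
    return fused, logits


# Toy inputs (assumed sizes): 2 graph features, 3 image features, 2 classes.
graph_features = [1.0, 2.0]
image_features = [0.5, 0.0, -1.0]
weights = [
    [1.0, 0.0, 2.0, 0.0, 0.0],
    [0.0, 1.0, 0.0, 0.0, 1.0],
]
bias = [0.0, 0.5]

fused, logits = concatenate_and_classify(
    graph_features, image_features, weights, bias
)
```

The fully-connected layer needs one weight per entry of the fused vector, so its input width is the sum of the two feature sizes.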

Parameters
  • n_node_features (int) – The number of input node features for the GNN layers.

  • n_classes (int) – The number of output classes.

  • image_size (int) – The side length of the input persistence images. The images are assumed to be square, so the width and height are equal.

  • bn (bool, default False) – Whether to include a batch normalization layer between the feature extractor and the fully-connected classification layer.
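To illustrate what the optional batch normalization layer does to the extracted features, here is a minimal sketch in plain Python. It assumes training-mode behavior and leaves out the learnable scale and shift as well as the running statistics that a full batch normalization layer maintains; the function name and the sample values are hypothetical.

```python
import math


def batch_normalize(batch, eps=1e-5):
    """Normalize each feature column of a batch to zero mean, unit variance.

    Hypothetical helper for illustration; not part of morphoclass.
    """
    n_samples = len(batch)
    n_features = len(batch[0])
    normalized = [[0.0] * n_features for _ in range(n_samples)]
    for j in range(n_features):
        column = [row[j] for row in batch]
        mean = sum(column) / n_samples
        # Biased variance over the batch (divide by n, not n - 1).
        var = sum((x - mean) ** 2 for x in column) / n_samples
        for i in range(n_samples):
            # eps keeps the division stable when the variance is near zero.
            normalized[i][j] = (column[i] - mean) / math.sqrt(var + eps)
    return normalized


# Three samples with two features on very different scales.
features = [[1.0, 10.0], [3.0, 30.0], [5.0, 50.0]]
normalized = batch_normalize(features)
```

After normalization both feature columns are on the same scale, so neither the graph features nor the image features dominate the classification layer simply because of their magnitude.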

forward(data, images)

Compute the forward pass.

Parameters
  • data (torch_geometric.data.data.Data) – A batch of input graph data for the GNN layers.

  • images – A batch of input persistence images for the CNN layers.

Returns

The log softmax of the predictions, i.e. the log-probabilities of the output classes.

Return type

torch.Tensor
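Because the forward pass returns log softmax values rather than probabilities, a caller typically exponentiates them to recover probabilities or takes the argmax to get the predicted class. The sketch below shows this in plain Python on made-up class scores; the log_softmax helper is a stand-alone illustration of the formula, not the function the model calls.

```python
import math


def log_softmax(logits):
    """Numerically stable log softmax of a list of class scores."""
    # Subtracting the maximum avoids overflow in exp for large scores.
    shift = max(logits)
    log_sum = shift + math.log(sum(math.exp(x - shift) for x in logits))
    return [x - log_sum for x in logits]


# Made-up scores for three classes.
log_probs = log_softmax([2.0, 1.5, -1.0])

# Exponentiating recovers probabilities that sum to one.
probs = [math.exp(lp) for lp in log_probs]

# The predicted class is the index of the largest log-probability.
predicted_class = max(range(len(log_probs)), key=lambda i: log_probs[i])
```

Log-probabilities are always non-positive, and they pair naturally with a negative log-likelihood loss during training.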

training: bool