morphoclass.models.concatecnnet module
Implementation of the ConcateCNNet classifier.
class morphoclass.models.concatecnnet.ConcateCNNet(n_node_features, n_classes, image_size, bn=False)
Bases: torch.nn.modules.module.Module
A neuron m-type classifier based on graph and image convolutions.
In the feature-extraction part of the network, graph convolution layers are applied to the graph node features of the apical dendrites, while CNN layers are applied to the persistence image representation of the same data. The resulting features are concatenated and passed through a fully-connected layer for classification.
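The fusion step described above can be illustrated with a minimal plain-Python sketch. It is not the morphoclass implementation: the feature vectors, the weights, and the helper names `concat_features` and `linear` are hypothetical, chosen only to show how concatenating the GNN and CNN outputs determines the input width of the fully-connected layer.

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def concat_features(graph_features: Vector, image_features: Vector) -> Vector:
    """Join the GNN and CNN feature vectors of one sample end to end."""
    return list(graph_features) + list(image_features)


def linear(x: Vector, weight: Matrix, bias: Vector) -> Vector:
    """Fully-connected layer: one output score (logit) per row of `weight`."""
    if any(len(row) != len(x) for row in weight):
        raise ValueError("each weight row must match the input length")
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


# Hypothetical per-sample outputs of the two feature extractors.
graph_out = [0.5, -1.0]       # e.g. pooled GNN features of the apical dendrite
image_out = [2.0, 0.0, 1.0]   # e.g. flattened CNN features of the persistence image

fused = concat_features(graph_out, image_out)  # length 2 + 3 = 5

# Hypothetical weights for n_classes = 2, shape (n_classes, len(fused)).
weight = [
    [1.0, 0.0, 0.5, 0.0, -1.0],
    [0.0, 1.0, 0.0, 1.0, 1.0],
]
bias = [0.1, -0.1]

logits = linear(fused, weight, bias)  # approximately [0.6, -0.1]
```

The only constraint the concatenation imposes is that each row of the classification layer's weight matrix has as many entries as the two feature vectors combined; the layer then emits one score per class.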
- Parameters
n_node_features (int) – The number of input node features for the GNN layers.
n_classes (int) – The number of output classes.
image_size (int) – The width (or height) of the input persistence images. The images are assumed to be square, so their width and height are equal.
bn (bool, default False) – Whether or not to include a batch normalization layer between the feature extractor and the fully-connected classification layer.
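When `bn=True`, the concatenated features are batch-normalized before classification. The plain-Python sketch below shows the training-mode arithmetic of batch normalization on a hypothetical batch: each feature is shifted and scaled using the mean and (biased) variance computed across the samples in the batch. The helper `batch_norm` is hypothetical; a PyTorch layer such as `torch.nn.BatchNorm1d` additionally learns a per-feature scale and shift and tracks running statistics for evaluation mode, and whether ConcateCNNet uses exactly that layer is an assumption.

```python
import math
from typing import List


def batch_norm(batch: List[List[float]], eps: float = 1e-5) -> List[List[float]]:
    """Training-mode batch normalization without learnable scale/shift.

    Each feature (column) is normalized with the mean and biased variance
    computed over the samples (rows) in the batch.
    """
    n_samples = len(batch)
    n_features = len(batch[0])
    out = [[0.0] * n_features for _ in range(n_samples)]
    for j in range(n_features):
        column = [row[j] for row in batch]
        mean = sum(column) / n_samples
        var = sum((v - mean) ** 2 for v in column) / n_samples
        std = math.sqrt(var + eps)  # eps keeps the division finite for constant features
        for i in range(n_samples):
            out[i][j] = (batch[i][j] - mean) / std
    return out


# Hypothetical concatenated features for a batch of two samples.
features = [
    [1.0, 10.0],
    [3.0, 10.0],
]
normed = batch_norm(features)
# Column 0 becomes roughly [-1, 1]; the constant column 1 becomes [0.0, 0.0].
```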
forward(data, images)
Compute the forward pass.
- Parameters
data (torch_geometric.data.data.Data) – A batch of input graph data for the GNN layers.
images – A batch of input persistence images for the CNN layers.
- Returns
The log-softmax of the predictions, i.e. the per-class log-probabilities.
- Return type
torch.Tensor
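Because `forward` returns log-probabilities rather than raw scores, downstream code typically exponentiates them to get probabilities or takes the argmax to get the predicted class. The plain-Python sketch below shows a numerically stable log-softmax (the log-sum-exp trick) and those post-processing steps on a single hypothetical vector of scores; in PyTorch the same operations would act on the output tensor, presumably with one row per sample (an assumption about the output shape).

```python
import math
from typing import List


def log_softmax(scores: List[float]) -> List[float]:
    """Numerically stable log-softmax using the log-sum-exp trick."""
    m = max(scores)  # subtract the max so exp() cannot overflow
    log_sum_exp = m + math.log(sum(math.exp(s - m) for s in scores))
    return [s - log_sum_exp for s in scores]


# Hypothetical raw scores for one sample and n_classes = 3.
log_probs = log_softmax([2.0, 1.0, 0.1])

# Exponentiating recovers class probabilities that sum to 1.
probs = [math.exp(lp) for lp in log_probs]

# The predicted class is the index of the largest log-probability.
predicted_class = max(range(len(log_probs)), key=lambda k: log_probs[k])

# Negative log-likelihood of a hypothetical true label for this single sample.
target = 0
nll = -log_probs[target]

print(predicted_class)  # 0
```

Log-probabilities are the input that `torch.nn.NLLLoss` expects, which is why the per-sample loss in the sketch is simply the negated log-probability of the true class.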
training: bool