morphoclass.models.multi_adj_net module
Implementation of the MultiAdjNet classifier.
-
class morphoclass.models.multi_adj_net.MultiAdjNet(n_features=1, n_classes=4, attention=False, attention_per_feature=False, save_attention=False)
Bases: torch.nn.modules.module.Module
Model for classifying morphologies of pyramidal neurons.
This is the architecture that performed best in the TensorFlow implementation. It consists of two graph convolution layers, each of which computes two convolutions: one using the directed adjacency matrix and one using the adjacency matrix with all edge directions reversed (i.e. its transpose). The outputs of both convolutions are concatenated and passed to the next layer. The two parallel graph convolution layers are followed by a global pooling layer (global mean pooling by default, attention-based pooling if attention is true) and a fully connected layer. Finally, a log-softmax layer produces the class predictions.
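The two-branch architecture described above can be sketched in plain PyTorch. This is a simplified illustration under stated assumptions, not the morphoclass implementation: it processes a single graph with a dense adjacency matrix (an edge i -> j is stored as adj[i, j] = 1), adds self-loops and row-normalizes in place of whatever normalization the real layers use, and picks a hidden size of 64 and ReLU activations arbitrarily. The names normalize, DenseBidirectionalConv and MultiAdjNetSketch are hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def normalize(adj):
    # Assumption: add self-loops and row-normalize (mean aggregation); the real layers may differ.
    a = adj + torch.eye(adj.size(0))
    return a / a.sum(dim=1, keepdim=True)


class DenseBidirectionalConv(nn.Module):
    """Hypothetical layer: one convolution on A, one on A^T, outputs concatenated."""

    def __init__(self, in_features, out_features):
        super().__init__()
        self.lin_fwd = nn.Linear(in_features, out_features)
        self.lin_rev = nn.Linear(in_features, out_features)

    def forward(self, x, adj):
        h_fwd = normalize(adj) @ self.lin_fwd(x)      # directed adjacency matrix
        h_rev = normalize(adj.t()) @ self.lin_rev(x)  # reversed edge directions (transpose)
        return torch.cat([h_fwd, h_rev], dim=-1)      # width doubles: 2 * out_features


class MultiAdjNetSketch(nn.Module):
    """Hypothetical single-graph version of the described architecture."""

    def __init__(self, n_features=1, n_classes=4, hidden=64):  # hidden=64 is an arbitrary choice
        super().__init__()
        self.conv1 = DenseBidirectionalConv(n_features, hidden)
        self.conv2 = DenseBidirectionalConv(2 * hidden, hidden)
        self.fc = nn.Linear(2 * hidden, n_classes)

    def forward(self, x, adj):
        x = F.relu(self.conv1(x, adj))           # ReLU is an assumed activation
        x = F.relu(self.conv2(x, adj))
        x = x.mean(dim=0)                         # global mean pooling over all nodes
        return F.log_softmax(self.fc(x), dim=-1)  # log-probabilities over n_classes


# Toy directed tree with 5 nodes: edges point from parent to child.
torch.manual_seed(0)
adj = torch.zeros(5, 5)
for parent, child in [(0, 1), (1, 2), (1, 3), (3, 4)]:
    adj[parent, child] = 1.0
x = torch.randn(5, 1)  # one input feature per node
model = MultiAdjNetSketch(n_features=1, n_classes=4)
log_probs = model(x, adj)  # shape: (4,)
```

In this sketch, the transpose is what lets each node aggregate from both its successors and its predecessors; with only the directed matrix, information would travel along the edges in one direction only.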
- Parameters
n_features (int, default 1) – The number of input features.
n_classes (int, default 4) – The number of output classes.
attention (bool, default False) – If true, an attention-based global pooling layer is used; if false, a global mean pooling layer is used.
attention_per_feature (bool, default False) – If true, the attention weights are optimized separately for each feature. This increases the number of trainable parameters. Only has an effect if attention is true.
save_attention (bool, default False) – If true, the attention weights are cached within the attention layer instance. See the AttentionGlobalPool class for more details.
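A minimal sketch of how attention-based global pooling can differ from mean pooling, and why attention_per_feature adds trainable parameters. It assumes a single graph, a linear scoring layer and a softmax over the nodes; AttentionPoolSketch and its last_attention attribute are hypothetical names, and the actual AttentionGlobalPool class may compute its scores differently.

```python
import torch
import torch.nn as nn


class AttentionPoolSketch(nn.Module):
    """Hypothetical attention-based global pooling over the nodes of a single graph."""

    def __init__(self, n_features, per_feature=False, save_attention=False):
        super().__init__()
        # One score per node (shared by all features), or one score per node and feature.
        self.score = nn.Linear(n_features, n_features if per_feature else 1)
        self.save_attention = save_attention
        self.last_attention = None  # hypothetical cache, filled only if save_attention=True

    def forward(self, x):
        # x: (num_nodes, n_features)
        weights = torch.softmax(self.score(x), dim=0)  # normalize over the nodes
        if self.save_attention:
            self.last_attention = weights.detach()
        return (weights * x).sum(dim=0)  # weighted sum -> (n_features,)


torch.manual_seed(0)
x = torch.randn(6, 8)  # 6 nodes, 8 features
shared = AttentionPoolSketch(8, per_feature=False, save_attention=True)
per_feature = AttentionPoolSketch(8, per_feature=True)
n_shared = sum(p.numel() for p in shared.parameters())            # 8 * 1 + 1 = 9
n_per_feature = sum(p.numel() for p in per_feature.parameters())  # 8 * 8 + 8 = 72

# With all scores zeroed, the weights are uniform and attention pooling reduces to mean pooling.
with torch.no_grad():
    shared.score.weight.zero_()
    shared.score.bias.zero_()
pooled = shared(x)
```

Here the per-feature variant learns a full n_features x n_features scoring matrix instead of a single scoring vector, which is where the extra parameters come from.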
-
accuracy(data)
Run the forward pass and compute the accuracy.
- Parameters
data (torch_geometric.data.Data) – The input batch of data.
- Returns
acc – The accuracy on the current data batch.
- Return type
float
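The accuracy can be computed from the log-softmax output of the forward pass by comparing the highest-scoring class of each sample with its label. A minimal sketch, assuming a (batch_size, n_classes) tensor of log-probabilities and integer class labels (torch_geometric conventionally stores labels in data.y); accuracy_from_log_probs is a hypothetical helper, not part of morphoclass.

```python
import torch


def accuracy_from_log_probs(log_probs, y):
    """Hypothetical helper: fraction of samples whose top-scoring class equals the label."""
    pred = log_probs.argmax(dim=1)              # predicted class per sample
    return (pred == y).float().mean().item()    # Python float, as documented


# Three samples, four classes (toy values).
log_probs = torch.log_softmax(torch.tensor([[2.0, 0.1, 0.3, 0.0],
                                            [0.2, 1.5, 0.1, 0.0],
                                            [0.0, 0.1, 0.2, 3.0]]), dim=1)
y = torch.tensor([0, 2, 3])
acc = accuracy_from_log_probs(log_probs, y)  # 2 of 3 predictions match
```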
-
forward(data)
Compute the forward pass.
- Parameters
data (torch_geometric.data.data.Data) – A batch of input data.
- Returns
The log softmax of the predictions.
- Return type
torch.Tensor
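Because forward returns log-probabilities rather than probabilities, a few standard conversions apply. The sketch below uses random logits as a stand-in for the model's final-layer output (an assumption for illustration) and shows that exponentiating recovers probabilities, that the predicted class is unchanged, and that nll_loss on log-softmax outputs matches cross_entropy on the raw logits.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(3, 4)                # stand-in for final-layer scores: 3 samples, 4 classes
log_probs = F.log_softmax(logits, dim=1)  # the kind of output forward() is documented to return

probs = log_probs.exp()                   # back to probabilities; each row sums to 1
pred = log_probs.argmax(dim=1)            # predicted class per sample (same as argmax of logits)

# Negative log-likelihood on log-softmax outputs equals cross-entropy on the raw logits.
y = torch.tensor([0, 3, 1])
nll = F.nll_loss(log_probs, y)
ce = F.cross_entropy(logits, y)
```

This equivalence is a common reason to end a classifier with a log-softmax: the output pairs directly with nll_loss, and log_softmax is numerically more stable than taking the log of a softmax.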
-
loss_acc(data)
Run the forward pass and compute the loss and accuracy.
- Parameters
data (torch_geometric.data.Data) – The input batch of data.
- Returns
loss (float) – The loss on the given data batch.
acc (float) – The accuracy on the current data batch.
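A hypothetical helper that mirrors the documented behaviour of loss_acc: a single forward pass whose output is reused for both the loss and the accuracy, each returned as a Python float. It assumes a negative log-likelihood loss on the log-softmax output and uses a tiny linear model as a stand-in for MultiAdjNet; neither detail is confirmed by this page, and loss_acc_sketch is not a morphoclass function.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def loss_acc_sketch(model, x, y):
    """Hypothetical helper: one forward pass, then NLL loss and accuracy as Python floats."""
    log_probs = model(x)                            # assumed to return log-softmax outputs
    loss = F.nll_loss(log_probs, y)                 # assumed loss function
    acc = (log_probs.argmax(dim=1) == y).float().mean()
    return loss.item(), acc.item()


torch.manual_seed(0)
toy_model = nn.Sequential(nn.Linear(5, 4), nn.LogSoftmax(dim=1))  # stand-in for the classifier
x = torch.randn(10, 5)                  # 10 samples, 5 features each
y = torch.randint(0, 4, (10,))          # integer labels for 4 classes
with torch.no_grad():                   # evaluation only: no gradients needed
    loss, acc = loss_acc_sketch(toy_model, x, y)
```

Reusing one forward pass avoids computing the predictions twice. Since .item() converts the values to plain floats, a loss returned this way suits logging and evaluation; a training step would keep the loss tensor in order to call backward().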
-
training: bool