morphoclass.models.coriander_net module
Implementation of the PersLay-based CorianderNet network.
-
class morphoclass.models.coriander_net.CorianderNet(n_classes=4, n_features=64, dropout=False)
Bases: torch.nn.modules.module.Module
A PersLay-based neural network for neuron m-type classification.
- Parameters
n_classes (int) – The number of m-type classes to predict.
n_features (int) – The number of output feature maps for the PersLay layer.
dropout (bool, default False) – If True, a dropout layer is inserted between the two fully connected layers of the classifier part of the network.
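The n_features parameter sets how many features the PersLay layer outputs. The following is a minimal, hypothetical sketch of how a PersLay-style layer can turn a persistence diagram (a set of (birth, death) points) into a fixed number of order-independent features. The function name `gaussian_perslay`, the Gaussian point transformation, the persistence weighting, the fixed centres and the sum aggregation are all assumptions made for illustration; CorianderNet may use different choices and learns its parameters during training.

```python
import math
from typing import List, Sequence, Tuple

# A persistence diagram is an unordered collection of (birth, death) points.
Diagram = Sequence[Tuple[float, float]]


def gaussian_perslay(
    diagram: Diagram,
    centers: Sequence[Tuple[float, float]],
    sigma: float = 1.0,
) -> List[float]:
    """Hypothetical PersLay-style layer: one output feature per centre.

    ASSUMPTIONS: Gaussian point transformation, persistence weight
    |death - birth|, and sum aggregation. Summing over points makes the
    result independent of the order of the points in the diagram.
    """
    features = []
    for cx, cy in centers:
        total = 0.0
        for birth, death in diagram:
            weight = abs(death - birth)  # assumed weight: the point's persistence
            dist2 = (birth - cx) ** 2 + (death - cy) ** 2
            total += weight * math.exp(-dist2 / (2 * sigma ** 2))
        features.append(total)
    return features


n_features = 4  # plays the role of CorianderNet's n_features (64 by default)
centers = [(i / n_features, (i + 1) / n_features) for i in range(n_features)]
diagram = [(0.0, 1.0), (0.2, 0.5)]
features = gaussian_perslay(diagram, centers)
print(len(features), features)
```

In a trainable layer the centres (and possibly sigma) would be learnable parameters, so the network can place its n_features probes where the diagrams carry discriminative information.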
-
forward(data)
Compute the forward pass.
- Parameters
data (torch_geometric.data.Batch | torch.Tensor) – A batch of samples from a MorphologyDataset.
- Returns
The log-softmax of the predicted class scores.
- Return type
torch.Tensor
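Because forward returns log-probabilities rather than probabilities, callers typically exponentiate them or take an argmax to obtain a predicted m-type. The pure-Python sketch below shows both steps for a single sample; the logits are made up, and treating each row of the returned tensor as one sample is an assumption. With a PyTorch tensor the equivalent operations are `torch.exp(out)` and `out.argmax(dim=1)`.

```python
import math
from typing import List


def log_softmax(logits: List[float]) -> List[float]:
    # Subtract the maximum before exponentiating for numerical stability.
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_sum for x in logits]


def predict(log_probs: List[float]) -> int:
    # The predicted class is the index of the largest log-probability.
    return max(range(len(log_probs)), key=lambda i: log_probs[i])


logits = [2.0, 0.5, -1.0, 0.0]  # made-up scores for n_classes=4
log_probs = log_softmax(logits)
probs = [math.exp(lp) for lp in log_probs]  # recover probabilities
print(predict(log_probs))  # 0
```

Since log-softmax is monotonic, the argmax of the log-probabilities equals the argmax of the raw scores, so no exponentiation is needed just to pick a class.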
-
loss_acc(data)
Compute the loss and accuracy.
- Parameters
data (torch_geometric.data.Batch) – A batch of samples from a MorphologyDataset.
- Returns
loss (float) – The loss value.
acc (float) – The accuracy value.
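This reference does not state which loss loss_acc uses. A log-softmax output is usually paired with the negative log-likelihood loss (as in `torch.nn.functional.nll_loss`), so the sketch below assumes that choice. It computes the mean negative log-likelihood and the accuracy for a small batch in pure Python; the function name `nll_loss_and_accuracy` and the example values are hypothetical.

```python
import math
from typing import Sequence, Tuple


def nll_loss_and_accuracy(
    batch_log_probs: Sequence[Sequence[float]],
    labels: Sequence[int],
) -> Tuple[float, float]:
    """Hypothetical helper: mean NLL loss and accuracy over a batch.

    ASSUMPTION: the loss is the negative log-likelihood of the true class,
    averaged over the batch.
    """
    n = len(labels)
    # Loss: average of -log p(true class) over all samples.
    loss = -sum(lp[y] for lp, y in zip(batch_log_probs, labels)) / n
    # Accuracy: fraction of samples whose argmax matches the true label.
    correct = sum(
        1
        for lp, y in zip(batch_log_probs, labels)
        if max(range(len(lp)), key=lambda i: lp[i]) == y
    )
    return loss, correct / n


batch_log_probs = [
    [math.log(0.7), math.log(0.2), math.log(0.1)],  # predicts class 0
    [math.log(0.1), math.log(0.3), math.log(0.6)],  # predicts class 2
]
labels = [0, 1]
loss, acc = nll_loss_and_accuracy(batch_log_probs, labels)
print(acc)  # 0.5
```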
-
training: bool
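The training attribute is inherited from torch.nn.Module: `model.train()` sets it to True and `model.eval()` sets it to False. Layers such as `torch.nn.Dropout` consult this flag, so when CorianderNet is built with dropout=True the dropout layer is only active in training mode. The pure-Python sketch below illustrates that behaviour with inverted dropout, which zeroes each value with probability p and rescales the survivors by 1 / (1 - p) during training. The function and its parameters are hypothetical; the dropout probability CorianderNet actually uses is not stated here.

```python
import random
from typing import List


def dropout(x: List[float], p: float, training: bool, rng: random.Random) -> List[float]:
    """Hypothetical inverted dropout honouring a training flag.

    In training mode each value is zeroed with probability p and the
    survivors are scaled by 1 / (1 - p); in evaluation mode the input
    is returned unchanged.
    """
    if not training or p == 0.0:
        return list(x)
    scale = 1.0 / (1.0 - p)
    return [0.0 if rng.random() < p else v * scale for v in x]


rng = random.Random(0)
x = [1.0, 2.0, 3.0, 4.0]
eval_out = dropout(x, p=0.5, training=False, rng=rng)
train_out = dropout(x, p=0.5, training=True, rng=rng)
print(eval_out)  # [1.0, 2.0, 3.0, 4.0]
```

Rescaling during training keeps the expected activation equal in both modes, which is why no correction is needed at evaluation time.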