morphoclass.models.coriander_net module

Implementation of the PersLay-based CorianderNet network.

class morphoclass.models.coriander_net.CorianderNet(n_classes=4, n_features=64, dropout=False)

Bases: torch.nn.modules.module.Module

A PersLay-based neural network for neuron m-type classification.

Parameters
  • n_classes (int) – The number of m-type classes to predict.

  • n_features (int) – The number of output feature maps for the PersLay layer.

  • dropout (bool, default False) – If true, a dropout layer is inserted between the two fully connected layers of the classifier part of the network.

forward(data)

Compute the forward pass.

Parameters

data (torch_geometric.data.Batch | torch.Tensor) – A batch of samples from a MorphologyDataset.

Returns

The log softmax of the predictions.

Return type

torch.Tensor
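
To make the return value concrete, the following dependency-free sketch shows what a log softmax is and how a class prediction can be read from it. It is an illustration only, not the CorianderNet implementation: the helper `log_softmax` and the example scores are hypothetical, and the real method operates on PyTorch tensors.

```python
import math


def log_softmax(logits):
    """Numerically stable log softmax of one row of raw class scores."""
    shift = max(logits)  # subtract the maximum to avoid overflow in exp
    log_sum_exp = shift + math.log(sum(math.exp(x - shift) for x in logits))
    return [x - log_sum_exp for x in logits]


# Hypothetical raw scores for one sample and four m-type classes.
log_probs = log_softmax([2.0, 0.5, -1.0, 0.1])

# Exponentiating the output recovers probabilities that sum to 1.
probs = [math.exp(lp) for lp in log_probs]

# The predicted class is the index with the largest log-probability.
predicted_class = max(range(len(log_probs)), key=log_probs.__getitem__)
```

Every entry of the output is non-positive, and the largest entry marks the predicted class.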

loss_acc(data)

Get loss and accuracy.

Parameters

data (torch_geometric.data.Batch) – A batch of samples from a MorphologyDataset.

Returns

  • loss (float) – The loss value.

  • acc (float) – The accuracy value.
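
The documentation does not state which loss function is used. As a rough sketch, assuming the negative log-likelihood loss that is commonly paired with a log softmax output, the two returned values could be computed as shown below. The function `nll_loss_and_accuracy`, the example batch, and the labels are all hypothetical and written in plain Python rather than PyTorch.

```python
import math


def nll_loss_and_accuracy(log_probs, labels):
    """Mean negative log-likelihood and accuracy for a batch.

    log_probs: list of per-sample rows of log-probabilities.
    labels: list of integer class indices, one per sample.
    """
    n = len(labels)
    # Assumed loss: average of -log p(true class) over the batch.
    loss = -sum(row[y] for row, y in zip(log_probs, labels)) / n
    # Accuracy: fraction of samples whose arg-max equals the label.
    correct = sum(
        max(range(len(row)), key=row.__getitem__) == y
        for row, y in zip(log_probs, labels)
    )
    return loss, correct / n


# Hypothetical batch of two samples over four classes.
batch_log_probs = [
    [math.log(0.7), math.log(0.1), math.log(0.1), math.log(0.1)],
    [math.log(0.2), math.log(0.5), math.log(0.2), math.log(0.1)],
]
labels = [0, 2]

loss, acc = nll_loss_and_accuracy(batch_log_probs, labels)
```

In this example the first sample is classified correctly and the second is not, so the accuracy is one half.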

training: bool