morphoclass.models.man_res_nets module

Implementation of the ManResNet1, ManResNet2, and ManResNet3 models.

class morphoclass.models.man_res_nets.ManResNet1(n_features=1, n_classes=4)

Bases: torch.nn.modules.module.Module

MultiAdjNet with 1 residual block.

In comparison to MultiAdjNet, the second convolutional layer is replaced by a BidirectionalResBlock with the same input and output dimension.
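The requirement of equal input and output dimension comes from the skip connection: a residual block returns `x + F(x)`, which is only defined when `F` preserves the feature dimension. The sketch below is hypothetical. Dense layers with a ReLU stand in for the two graph convolutions (the internals of BidirectionalResBlock are not described here), and `dense`, `relu` and `residual_block` are illustrative names, not morphoclass functions.

```python
def dense(weights, x):
    """Multiply vector `x` by a weight matrix given as a list of rows."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weights]


def relu(x):
    """Element-wise rectified linear unit."""
    return [max(0.0, v) for v in x]


def residual_block(x, w1, w2):
    """Hypothetical two-layer residual block: x + W2 @ relu(W1 @ x).

    Dense layers stand in for the graph convolutions; only the shape
    constraint imposed by the skip connection is illustrated.
    """
    h = dense(w2, relu(dense(w1, x)))
    if len(h) != len(x):
        raise ValueError(f"skip connection needs {len(x)} outputs, got {len(h)}")
    # The skip connection adds the block input back onto its output.
    return [xi + hi for xi, hi in zip(x, h)]


identity = [[1.0, 0.0], [0.0, 1.0]]
print(residual_block([1.0, -2.0], identity, identity))  # [2.0, -2.0]
```

A second layer that maps to a different dimension (for example `w2 = [[1.0, 0.0]]`) makes the sum undefined, which is why the block keeps the dimension fixed.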

Because a bidirectional residual block contains two convolutional layers, the net has a total of three convolutional layers (the first layer plus the two inside the residual block), each a K=5 ChebConv. Because a K=5 ChebConv takes into account up to the 4th power of the adjacency matrix, this net has a total reach of 3 * 4 = 12 hops.
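The hop arithmetic can be checked by tracking only which nodes can influence each other. A K=5 ChebConv applies Chebyshev polynomials of degree at most 4 in the graph Laplacian, so one layer mixes information within 4 hops, and stacked layers add their reaches. The sketch below is a minimal illustration under stated assumptions: it models each layer purely by its support (the set of nodes within K - 1 hops), ignores the learned weights, and uses a path graph (an assumption chosen so that node index equals distance from node 0). The helper names are hypothetical.

```python
def path_graph(n_nodes):
    """Neighbour lists of an undirected path graph 0 - 1 - ... - (n_nodes - 1)."""
    return {
        i: [j for j in (i - 1, i + 1) if 0 <= j < n_nodes] for i in range(n_nodes)
    }


def expand(active, neighbours, hops):
    """Nodes within `hops` steps of `active`: the support of a degree-`hops` filter."""
    reached = set(active)
    frontier = set(active)
    for _ in range(hops):
        frontier = {j for i in frontier for j in neighbours[i]} - reached
        reached |= frontier
    return reached


def reach(n_conv_layers, K=5, n_nodes=64):
    """Hops covered by stacking `n_conv_layers` Chebyshev filters of size K."""
    neighbours = path_graph(n_nodes)
    active = {0}  # a signal that starts at a single node
    for _ in range(n_conv_layers):
        # A size-K Chebyshev filter is a polynomial of degree K - 1.
        active = expand(active, neighbours, hops=K - 1)
    return max(active)  # on a path graph, node index == distance from node 0


print(reach(3))  # 12, matching the three layers of ManResNet1
```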

Additionally, AttentionGlobalPool is now the default pooling method.

Parameters
  • n_features (int, optional) – The number of input features. Defaults to 1.

  • n_classes (int, optional) – The number of output classes. Defaults to 4.

forward(data)

Run the forward pass.

Parameters

data (torch_geometric.data.Data) – The input batch of data.

Returns

The log of the prediction probabilities. The tensor has shape (n_samples, n_classes).

Return type

torch.Tensor
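Because forward returns log-probabilities, per-class probabilities are recovered with an exponential and the predicted class with an argmax; since the logarithm is monotonic, the argmax of the log-probabilities equals the argmax of the probabilities. On an actual `torch.Tensor` these steps are `out.exp()` and `out.argmax(dim=1)`. The sketch below mimics them on plain Python lists so it runs without torch; the input values are made up for illustration.

```python
import math


def probs_and_predictions(log_probs):
    """Turn rows of log-probabilities into probabilities and predicted classes.

    `log_probs` is a list of rows, one per sample, mimicking the
    (n_samples, n_classes) tensor returned by forward().
    """
    probs = [[math.exp(v) for v in row] for row in log_probs]
    # log is monotonic, so the index of the largest log-probability is also
    # the index of the largest probability.
    preds = [max(range(len(row)), key=row.__getitem__) for row in log_probs]
    return probs, preds


# Made-up output for two samples and four classes.
example = [
    [math.log(0.7), math.log(0.1), math.log(0.1), math.log(0.1)],
    [math.log(0.05), math.log(0.05), math.log(0.6), math.log(0.3)],
]
probs, preds = probs_and_predictions(example)
print(preds)  # [0, 2]
```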

loss_acc(data)

Run the forward pass and compute the loss and accuracy.

Parameters

data (torch_geometric.data.Data) – The input batch of data.

Returns

  • loss (float) – The loss on the given data batch.

  • acc (float) – The accuracy on the given data batch.
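The documentation does not state which loss loss_acc computes. A natural pairing for log-probabilities is the negative log-likelihood (in torch, `torch.nn.functional.nll_loss`), so the sketch below assumes NLL; treat that choice, and the helper name `nll_loss_and_accuracy`, as hypothetical. Accuracy is the fraction of samples whose highest-scoring class equals the target label.

```python
import math


def nll_loss_and_accuracy(log_probs, targets):
    """Mean negative log-likelihood and accuracy for one batch.

    ASSUMPTION: NLL is paired with the log-probabilities from forward();
    the documentation does not say which loss loss_acc actually uses.
    """
    if len(log_probs) != len(targets):
        raise ValueError("expected one target label per sample")
    n = len(targets)
    # NLL picks out the log-probability assigned to the true class.
    loss = -sum(row[t] for row, t in zip(log_probs, targets)) / n
    # A sample counts as correct when its highest-scoring class is the target.
    correct = sum(
        max(range(len(row)), key=row.__getitem__) == t
        for row, t in zip(log_probs, targets)
    )
    return loss, correct / n


log_probs = [
    [math.log(0.5), math.log(0.25), math.log(0.25)],
    [math.log(0.25), math.log(0.25), math.log(0.5)],
]
loss, acc = nll_loss_and_accuracy(log_probs, targets=[0, 1])
print(round(loss, 4), acc)  # 1.0397 0.5
```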

training: bool
class morphoclass.models.man_res_nets.ManResNet2(n_features=1, n_classes=4)

Bases: torch.nn.modules.module.Module

MultiAdjNet with 2 residual blocks.

In comparison to MultiAdjNet, the first convolutional layer is split into a BidirectionalBlock and a BidirectionalResBlock, and the second convolutional layer is replaced by another BidirectionalResBlock.

Because a bidirectional residual block contains two convolutional layers, the net has a total of five convolutional layers (one in the BidirectionalBlock plus two in each of the two BidirectionalResBlocks), each a K=5 ChebConv. Because a K=5 ChebConv takes into account up to the 4th power of the adjacency matrix, this net has a total reach of 5 * 4 = 20 hops.

Additionally, AttentionGlobalPool is now the default pooling method.

Parameters
  • n_features (int, optional) – The number of input features. Defaults to 1.

  • n_classes (int, optional) – The number of output classes. Defaults to 4.

forward(data)

Run the forward pass.

Parameters

data (torch_geometric.data.Data) – The input batch of data.

Returns

The log of the prediction probabilities. The tensor has shape (n_samples, n_classes).

Return type

torch.Tensor

loss_acc(data)

Run the forward pass and compute the loss and accuracy.

Parameters

data (torch_geometric.data.Data) – The input batch of data.

Returns

  • loss (float) – The loss on the given data batch.

  • acc (float) – The accuracy on the given data batch.

training: bool
class morphoclass.models.man_res_nets.ManResNet3(n_features=1, n_classes=4)

Bases: torch.nn.modules.module.Module

MultiAdjNet with 3 residual blocks.

In comparison to MultiAdjNet, the first convolutional layer is split into a BidirectionalBlock and two BidirectionalResBlocks, and the second convolutional layer is replaced by another BidirectionalResBlock.

Because a bidirectional residual block contains two convolutional layers, the net has a total of seven convolutional layers (one in the BidirectionalBlock plus two in each of the three BidirectionalResBlocks), each a K=5 ChebConv. Because a K=5 ChebConv takes into account up to the 4th power of the adjacency matrix, this net has a total reach of 7 * 4 = 28 hops.
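The layer bookkeeping for the three ManResNet variants can be tabulated in a few lines. This is a sketch, not morphoclass code: the strings are labels only, and the layouts for ManResNet2 and ManResNet3 assume that the block replacing the second convolutional layer is a BidirectionalResBlock, which is what the stated counts of residual blocks and convolutional layers require. Each K=5 ChebConv is taken to add K - 1 = 4 hops of reach.

```python
# Hypothetical bookkeeping of the layouts described in the ManResNet docs;
# the strings are labels only, not morphoclass classes being instantiated.
CONVS_PER_BLOCK = {
    "first MultiAdjNet conv layer": 1,
    "BidirectionalBlock": 1,
    "BidirectionalResBlock": 2,  # a residual block holds two conv layers
}

LAYOUTS = {
    "ManResNet1": ["first MultiAdjNet conv layer", "BidirectionalResBlock"],
    "ManResNet2": ["BidirectionalBlock"] + ["BidirectionalResBlock"] * 2,
    "ManResNet3": ["BidirectionalBlock"] + ["BidirectionalResBlock"] * 3,
}


def summarize(layout, K=5):
    """Return (residual blocks, conv layers, reach in hops) for a layout."""
    n_res = layout.count("BidirectionalResBlock")
    n_conv = sum(CONVS_PER_BLOCK[block] for block in layout)
    return n_res, n_conv, n_conv * (K - 1)  # each K=5 ChebConv adds 4 hops


for name, layout in LAYOUTS.items():
    print(name, summarize(layout))
# ManResNet1 (1, 3, 12)
# ManResNet2 (2, 5, 20)
# ManResNet3 (3, 7, 28)
```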

Additionally, AttentionGlobalPool is now the default pooling method.

Parameters
  • n_features (int, optional) – The number of input features. Defaults to 1.

  • n_classes (int, optional) – The number of output classes. Defaults to 4.

forward(data)

Run the forward pass.

Parameters

data (torch_geometric.data.Data) – The input batch of data.

Returns

The log of the prediction probabilities. The tensor has shape (n_samples, n_classes).

Return type

torch.Tensor

loss_acc(data)

Run the forward pass and compute the loss and accuracy.

Parameters

data (torch_geometric.data.Data) – The input batch of data.

Returns

  • loss (float) – The loss on the given data batch.

  • acc (float) – The accuracy on the given data batch.

training: bool