morphoclass.models.man_res_nets module
Implementation of the ManResNet1, ManResNet2, and ManResNet3 models.
class morphoclass.models.man_res_nets.ManResNet1(n_features=1, n_classes=4)
Bases: torch.nn.modules.module.Module
MultiAdjNet with 1 residual block.
In comparison to the MultiAdjNet, the second convolutional layer is replaced by a BidirectionalResBlock with the same input and output dimension.
Because a bidirectional residual block contains two convolutional layers, the net has a total of three convolutional layers, all of them K=5 ChebConvs. Since a K=5 ChebConv takes into account up to the 4th power of the adjacency matrix, each layer extends the reach by 4 hops, so this net has a total reach of 3 * 4 = 12 hops.
Additionally, the default pooling method is now AttentionGlobalPool.
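The "4 hops per K=5 ChebConv" rule can be checked numerically. The toy below propagates a one-hot signal along a path graph using the Chebyshev recurrence T_0(L̃)x = x, T_1(L̃)x = L̃x, T_j(L̃)x = 2L̃ T_{j-1}(L̃)x - T_{j-2}(L̃)x, stacks several such layers, and reports the farthest node that receives a nonzero value. It is a sketch, not morphoclass code: it assumes the symmetric normalized Laplacian with λ_max = 2 (so L̃ = L - I), a single feature channel, all filter coefficients equal to 1, and no biases or nonlinearities. Biases and nonlinearities do not let information travel farther, so the result is an upper bound on how far information propagates. The helper names are hypothetical.

```python
import math


def path_graph_scaled_laplacian(n):
    """Scaled Laplacian L~ = L - I of an n-node path graph (n >= 2).

    Assumes the symmetric normalized Laplacian L = I - D^-1/2 A D^-1/2 and
    lambda_max = 2, so L~ = 2 L / lambda_max - I = -D^-1/2 A D^-1/2.
    """
    deg = [1 if i in (0, n - 1) else 2 for i in range(n)]
    lt = [[0.0] * n for _ in range(n)]
    for i in range(n - 1):
        w = -1.0 / math.sqrt(deg[i] * deg[i + 1])
        lt[i][i + 1] = lt[i + 1][i] = w
    return lt


def matvec(m, x):
    return [sum(m_ij * x_j for m_ij, x_j in zip(row, x)) for row in m]


def toy_cheb_layer(lt, x, k=5):
    """Return sum_{j=0}^{k-1} T_j(L~) x, i.e. a ChebConv with all weights 1 (k >= 2)."""
    t_prev, t_curr = x, matvec(lt, x)  # T_0 x and T_1 x
    out = [a + b for a, b in zip(t_prev, t_curr)]
    for _ in range(2, k):  # T_j x = 2 L~ T_{j-1} x - T_{j-2} x
        t_next = [2 * a - b for a, b in zip(matvec(lt, t_curr), t_prev)]
        out = [o + t for o, t in zip(out, t_next)]
        t_prev, t_curr = t_curr, t_next
    return out


def reach(n_layers, k=5, n_nodes=40):
    """Farthest node (in hops from node 0) reached by a one-hot signal.

    n_nodes must exceed n_layers * (k - 1) so the graph end is not hit.
    """
    lt = path_graph_scaled_laplacian(n_nodes)
    x = [1.0] + [0.0] * (n_nodes - 1)  # one-hot signal on node 0
    for _ in range(n_layers):
        x = toy_cheb_layer(lt, x, k)
    return max(i for i, v in enumerate(x) if abs(v) > 1e-12)


for name, n_layers in [("ManResNet1", 3), ("ManResNet2", 5), ("ManResNet3", 7)]:
    print(name, reach(n_layers))  # 12, 20 and 28 hops respectively
```

Each layer contributes exactly K - 1 = 4 hops because the highest Chebyshev term, T_4, contains L̃ to the 4th power, and on a path graph only that term reaches the node 4 edges away.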
- Parameters
n_features (int, optional) – The number of input features. Defaults to 1.
n_classes (int, optional) – The number of output classes. Defaults to 4.
forward(data)
Run the forward pass.
- Parameters
data (torch_geometric.data.Data) – The input batch of data.
- Returns
The logarithms of the predicted class probabilities, as a tensor of shape (n_samples, n_classes).
- Return type
torch.Tensor
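Because forward returns log-probabilities rather than probabilities, turning its output into class probabilities and predicted labels takes an exponential and a row-wise argmax. The sketch below does this in plain Python on a hand-made (n_samples, n_classes) table; with the real tensor one would likely use torch.exp and Tensor.argmax(dim=1) instead. The log-softmax used to build the illustrative input is an assumption about how such log-probabilities are typically produced, and the helper names and scores are made up.

```python
import math


def log_softmax(scores):
    """Numerically stable log-softmax of one row of raw scores."""
    m = max(scores)
    log_z = m + math.log(sum(math.exp(v - m) for v in scores))
    return [v - log_z for v in scores]


def predictions_from_log_probs(log_probs):
    """Turn rows of log-probabilities into (probabilities, predicted labels)."""
    probs = [[math.exp(v) for v in row] for row in log_probs]
    labels = [max(range(len(row)), key=row.__getitem__) for row in log_probs]
    return probs, labels


# Illustrative output for n_samples=2, n_classes=4 (made-up scores).
log_probs = [log_softmax([2.0, 0.5, -1.0, 0.0]), log_softmax([0.1, 0.2, 3.0, 0.3])]
probs, labels = predictions_from_log_probs(log_probs)
print(labels)  # [0, 2]
print([round(sum(row), 6) for row in probs])  # [1.0, 1.0]
```

Since the logarithm is monotonic, the argmax of the log-probabilities equals the argmax of the probabilities, so the exponential is only needed when the probabilities themselves are of interest.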
loss_acc(data)
Run the forward pass and compute the loss and accuracy.
- Parameters
data (torch_geometric.data.Data) – The input batch of data.
- Returns
loss (float) – The loss on the given data batch.
acc (float) – The accuracy on the given data batch.
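The docstring does not say which loss loss_acc uses. Since forward returns log-probabilities, a natural assumption is the mean negative log-likelihood (what torch.nn.functional.nll_loss computes with its default mean reduction), with accuracy taken as the fraction of samples whose argmax prediction matches the label. The plain-Python sketch below illustrates that assumption; loss_acc_sketch is a hypothetical helper, not the morphoclass implementation.

```python
import math


def loss_acc_sketch(log_probs, targets):
    """Mean negative log-likelihood and accuracy for one batch.

    Assumption: mirrors nll_loss(log_probs, targets) with mean reduction;
    this is not taken from the morphoclass source.
    """
    n = len(targets)
    loss = -sum(row[t] for row, t in zip(log_probs, targets)) / n
    preds = [max(range(len(row)), key=row.__getitem__) for row in log_probs]
    acc = sum(p == t for p, t in zip(preds, targets)) / n
    return loss, acc


log_probs = [
    [math.log(0.7), math.log(0.1), math.log(0.1), math.log(0.1)],
    [math.log(0.2), math.log(0.5), math.log(0.2), math.log(0.1)],
]
loss, acc = loss_acc_sketch(log_probs, targets=[0, 2])
print(round(loss, 4), acc)  # 0.9831 0.5
```

The second sample is misclassified (class 1 has the highest probability, the label is 2), which halves the accuracy, and its low log-probability for the true class dominates the loss.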
training: bool
class morphoclass.models.man_res_nets.ManResNet2(n_features=1, n_classes=4)
Bases: torch.nn.modules.module.Module
MultiAdjNet with 2 residual blocks.
In comparison to the MultiAdjNet, the first convolutional layer is split into a BidirectionalBlock and a BidirectionalResBlock, and the second convolutional layer is replaced by another BidirectionalResBlock.
Because each bidirectional residual block contains two convolutional layers, the net has a total of 1 + 2 + 2 = 5 convolutional layers, all of them K=5 ChebConvs. Since a K=5 ChebConv takes into account up to the 4th power of the adjacency matrix, each layer extends the reach by 4 hops, so this net has a total reach of 5 * 4 = 20 hops.
Additionally, the default pooling method is now AttentionGlobalPool.
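The layer and hop counts for all three variants follow mechanically from their block layouts. The sketch below encodes each layout described in these docstrings as a list of block names and derives the number of K=5 ChebConv layers and the resulting reach. It assumes that a BidirectionalBlock (and MultiAdjNet's plain first convolutional layer, kept in ManResNet1) holds one ChebConv and that a BidirectionalResBlock holds two, which is what the stated totals imply; the dictionaries and function names are hypothetical.

```python
# Number of K=5 ChebConv layers per block type (assumed from the stated totals).
CONVS_PER_BLOCK = {
    "conv": 1,  # plain convolutional layer kept from MultiAdjNet
    "BidirectionalBlock": 1,
    "BidirectionalResBlock": 2,
}

# Hypothetical encoding of the layouts described in the docstrings.
LAYOUTS = {
    "ManResNet1": ["conv", "BidirectionalResBlock"],
    "ManResNet2": ["BidirectionalBlock"] + ["BidirectionalResBlock"] * 2,
    "ManResNet3": ["BidirectionalBlock"] + ["BidirectionalResBlock"] * 3,
}


def conv_layers(layout):
    return sum(CONVS_PER_BLOCK[block] for block in layout)


def reach_in_hops(layout, k=5):
    # A ChebConv of order K uses powers of the Laplacian up to K - 1,
    # so each layer lets information travel at most K - 1 hops.
    return conv_layers(layout) * (k - 1)


for name, layout in LAYOUTS.items():
    n_res = layout.count("BidirectionalResBlock")
    print(f"{name}: {n_res} residual block(s), "
          f"{conv_layers(layout)} convs, {reach_in_hops(layout)} hops")
```

Running it prints 3, 5 and 7 convolutional layers with reaches of 12, 20 and 28 hops, matching the figures given for each class.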
- Parameters
n_features (int, optional) – The number of input features. Defaults to 1.
n_classes (int, optional) – The number of output classes. Defaults to 4.
forward(data)
Run the forward pass.
- Parameters
data (torch_geometric.data.Data) – The input batch of data.
- Returns
The logarithms of the predicted class probabilities, as a tensor of shape (n_samples, n_classes).
- Return type
torch.Tensor
loss_acc(data)
Run the forward pass and compute the loss and accuracy.
- Parameters
data (torch_geometric.data.Data) – The input batch of data.
- Returns
loss (float) – The loss on the given data batch.
acc (float) – The accuracy on the given data batch.
training: bool
class morphoclass.models.man_res_nets.ManResNet3(n_features=1, n_classes=4)
Bases: torch.nn.modules.module.Module
MultiAdjNet with 3 residual blocks.
In comparison to the MultiAdjNet, the first convolutional layer is split into a BidirectionalBlock and two BidirectionalResBlocks, and the second convolutional layer is replaced by another BidirectionalResBlock.
Because each bidirectional residual block contains two convolutional layers, the net has a total of 1 + 3 * 2 = 7 convolutional layers, all of them K=5 ChebConvs. Since a K=5 ChebConv takes into account up to the 4th power of the adjacency matrix, each layer extends the reach by 4 hops, so this net has a total reach of 7 * 4 = 28 hops.
Additionally, the default pooling method is now AttentionGlobalPool.
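All three variants rely on residual blocks whose input and output dimensions match. That requirement comes from the skip connection: the block adds its input to the output of its inner layers, so both must have the same number of features. The sketch below shows the idea with dense layers in plain Python, using one common post-activation variant, y = relu(x + W2 relu(W1 x)). It is a generic residual block, not morphoclass's BidirectionalResBlock, whose inner layers are ChebConvs acting on the graph and whose internals (activations, normalization, how the "bidirectional" part is wired) are not described here; all names are hypothetical.

```python
def relu(v):
    return [max(0.0, x) for x in v]


def linear(weight, v):
    """Dense layer without bias; weight has shape (out_dim, in_dim)."""
    if any(len(row) != len(v) for row in weight):
        raise ValueError("weight rows must have one entry per input feature")
    return [sum(w * x for w, x in zip(row, v)) for row in weight]


def residual_block(x, w1, w2):
    """Generic residual block: y = relu(x + W2 relu(W1 x)).

    The skip connection adds x to the inner branch output, so W2 must map
    back to len(x) features, i.e. input and output dimensions must match.
    """
    inner = linear(w2, relu(linear(w1, x)))
    if len(inner) != len(x):
        raise ValueError(f"inner branch returns {len(inner)} features, expected {len(x)}")
    return relu([a + b for a, b in zip(x, inner)])


x = [1.0, -2.0, 0.5]
w1 = [[1.0, 0.0, 0.0], [0.0, 1.0, 1.0]]  # 3 -> 2 features
w2 = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # 2 -> 3 features
print(residual_block(x, w1, w2))  # [2.0, 0.0, 1.5]
```

The inner branch may change the width internally (here 3 to 2 and back), but it must return to the input width before the addition; a branch ending in 2 features raises a ValueError.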
- Parameters
n_features (int, optional) – The number of input features. Defaults to 1.
n_classes (int, optional) – The number of output classes. Defaults to 4.
forward(data)
Run the forward pass.
- Parameters
data (torch_geometric.data.Data) – The input batch of data.
- Returns
The logarithms of the predicted class probabilities, as a tensor of shape (n_samples, n_classes).
- Return type
torch.Tensor
loss_acc(data)
Run the forward pass and compute the loss and accuracy.
- Parameters
data (torch_geometric.data.Data) – The input batch of data.
- Returns
loss (float) – The loss on the given data batch.
acc (float) – The accuracy on the given data batch.
training: bool