morphoclass.layers.attention_global_pool module
Implementation of the AttentionGlobalPool layer.
-
class morphoclass.layers.attention_global_pool.AttentionGlobalPool(n_features, attention_per_feature=False, save_attention=False)
Bases: torch.nn.modules.module.Module
A graph global pooling layer with attention.
- Parameters
n_features (int) – The number of input features.
attention_per_feature (bool, default False) – If True, separate attention weights are learned for each feature.
save_attention (bool, default False) – If True, the attention values computed during the forward pass are cached in the layer instance. This can be useful for debugging and explainable-AI applications.
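The documentation above does not say how the attention weights are computed. The following pure-Python sketch shows one common form of attention-based global pooling, assuming a linear scoring function followed by a softmax over the nodes of a single graph. The helper name and its `weights` argument are hypothetical, and this is not the morphoclass implementation. It illustrates the distinction `attention_per_feature` is described as making: one score per node shared by all features, versus a separate score per node and feature. The returned attention values are the kind of data `save_attention=True` is described as caching.

```python
import math


def attention_pool_single_graph(nodes, weights, per_feature=False):
    """Softmax-attention pooling of one graph's nodes (hypothetical helper).

    nodes:   list of node feature vectors, each of length n_features
    weights: hypothetical learned scoring weights, length n_features
    per_feature=False: one score per node, dot(weights, node), shared by all features
    per_feature=True:  a separate score weights[j] * node[j] for each feature j
    """
    n_features = len(weights)
    if per_feature:
        scores = [[w * v for w, v in zip(weights, node)] for node in nodes]
    else:
        shared = [sum(w * v for w, v in zip(weights, node)) for node in nodes]
        scores = [[s] * n_features for s in shared]

    pooled, attention = [], []
    for j in range(n_features):
        column = [row[j] for row in scores]
        # Numerically stable softmax over the nodes of this graph.
        top = max(column)
        exps = [math.exp(c - top) for c in column]
        total = sum(exps)
        att = [e / total for e in exps]
        attention.append(att)
        # Attention-weighted sum of feature j over all nodes.
        pooled.append(sum(a * node[j] for a, node in zip(att, nodes)))
    return pooled, attention


nodes = [[1.0, 2.0], [3.0, 4.0]]
# Zero weights give uniform attention, so the result is the per-feature mean.
print(attention_pool_single_graph(nodes, [0.0, 0.0]))
# prints ([2.0, 3.0], [[0.5, 0.5], [0.5, 0.5]])
```

With uniform attention the layer reduces to mean pooling; learned scores let informative nodes dominate the pooled vector.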
-
forward(x, batch_segmentation) Compute the forward pass.
- Parameters
x (torch.Tensor) – A batch of node features.
batch_segmentation (torch.Tensor) – A segmentation map for the node features: a one-dimensional integer tensor with one entry per node. Nodes with the same value in the segmentation map are considered to belong to the same graph.
- Returns
The pooled node features, one feature vector per graph in the batch.
- Return type
torch.Tensor
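To illustrate how `batch_segmentation` groups the rows of `x`, the following pure-Python sketch splits node features by graph id and reduces each group to a single row, giving one output row per graph. A plain mean stands in for the attention weighting, and both helper names are hypothetical. It assumes, as with the `batch` vector used in PyTorch Geometric, that graph ids are integers and that output rows are ordered by graph id.

```python
from collections import defaultdict


def group_by_segment(x, batch_segmentation):
    """Group node feature rows by graph id (hypothetical helper)."""
    groups = defaultdict(list)
    for row, graph_id in zip(x, batch_segmentation):
        groups[graph_id].append(row)
    # Order graphs by id so output row i corresponds to graph i.
    return dict(sorted(groups.items()))


def segment_mean(x, batch_segmentation):
    """Pool each graph's nodes to one row; a mean stands in for attention."""
    pooled = []
    for rows in group_by_segment(x, batch_segmentation).values():
        n = len(rows)
        pooled.append([sum(col) / n for col in zip(*rows)])
    return pooled


x = [[1.0, 0.0], [3.0, 2.0], [5.0, 5.0]]  # three nodes, two features each
seg = [0, 0, 1]  # nodes 0 and 1 form graph 0, node 2 forms graph 1
print(segment_mean(x, seg))
# prints [[2.0, 1.0], [5.0, 5.0]]
```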
-
training: bool