morphoclass.layers.attention_global_pool module

Implementation of the AttentionGlobalPool layer.

class morphoclass.layers.attention_global_pool.AttentionGlobalPool(n_features, attention_per_feature=False, save_attention=False)

Bases: torch.nn.modules.module.Module

A graph global pooling layer with attention.

Parameters
  • n_features (int) – The number of input features.

  • attention_per_feature (bool, default False) – If True, separate attention weights are learned for each feature.

  • save_attention (bool, default False) – If True, the attention values computed during the forward pass are cached in the layer instance. This can be useful for debugging and for explainable-AI applications.
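To illustrate what attention-based global pooling does, here is a minimal, hypothetical sketch. It is not the morphoclass implementation. It assumes that attention logits come from a single linear layer, that they are normalised with a softmax over the nodes of each graph, and that each graph's output is the attention-weighted sum of its node features. The class name AttentionPoolSketch and the attribute last_attention are invented for this example.

```python
import torch
from torch import nn


class AttentionPoolSketch(nn.Module):
    """Hypothetical attention global pooling (illustration only, not morphoclass code)."""

    def __init__(self, n_features, attention_per_feature=False, save_attention=False):
        super().__init__()
        # Assumption: one logit per node, or one logit per (node, feature) pair.
        n_out = n_features if attention_per_feature else 1
        self.score = nn.Linear(n_features, n_out)
        self.save_attention = save_attention
        self.last_attention = None  # hypothetical cache for the attention values

    def forward(self, x, batch_segmentation):
        n_graphs = int(batch_segmentation.max()) + 1
        logits = self.score(x)  # shape (n_nodes, n_out)
        n_out = logits.size(1)

        # Per-graph maximum of the logits, subtracted for a numerically stable softmax.
        idx = batch_segmentation.unsqueeze(-1).expand_as(logits)
        max_per_graph = logits.new_full((n_graphs, n_out), float("-inf"))
        max_per_graph = max_per_graph.scatter_reduce(0, idx, logits.detach(), reduce="amax")
        exp = torch.exp(logits - max_per_graph[batch_segmentation])

        # Softmax over the nodes of each graph: weights sum to 1 within a graph.
        denom = exp.new_zeros(n_graphs, n_out).index_add(0, batch_segmentation, exp)
        attention = exp / denom[batch_segmentation]
        if self.save_attention:
            self.last_attention = attention.detach()

        # Attention-weighted sum of node features per graph (broadcasts when n_out == 1).
        weighted = attention * x
        return x.new_zeros(n_graphs, x.size(1)).index_add(0, batch_segmentation, weighted)
```

With attention_per_feature=False, each node receives a single weight that scales all of its features. With attention_per_feature=True, every feature of every node is weighted independently, which gives the layer more flexibility at the cost of more parameters.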

forward(x, batch_segmentation)

Compute the forward pass.

Parameters
  • x (torch.Tensor) – A batch of node features.

  • batch_segmentation (torch.Tensor) – A segmentation map for the node features: a one-dimensional tensor of integers. Nodes with the same value in the segmentation map are considered to belong to the same graph.
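A segmentation map of this kind can be built from the number of nodes in each graph of a mini-batch. The sketch below assumes the nodes of each graph are stored contiguously and uses illustrative node counts. The batch vector that PyTorch Geometric stores in the batch attribute of a Batch object has the same form.

```python
import torch

# Number of nodes in each graph of the mini-batch (illustrative values).
nodes_per_graph = torch.tensor([3, 2, 4])

# Graph i contributes nodes_per_graph[i] consecutive entries equal to i.
batch_segmentation = torch.repeat_interleave(
    torch.arange(len(nodes_per_graph)), nodes_per_graph
)
print(batch_segmentation)  # tensor([0, 0, 0, 1, 1, 2, 2, 2, 2])
```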

Returns

The pooled node features.

Return type

torch.Tensor
