morphoclass.models.man_net module
The ManNet model and trainer classes.
-
class morphoclass.models.man_net.ManEmbedder(n_features=1, pool_name='avg', lambda_max=3.0, normalization='sym', flow='target_to_source', edge_weight_idx=None)
Bases: torch.nn.modules.module.Module
The embedder for the ManNet network.
The embedder consists of two bidirectional ChebConv blocks followed by a global pooling layer.
- Parameters
n_features (int) – The number of input features.
pool_name ({"avg", "sum", "att"}) – The type of pooling layer to use:
    - "avg": global average pooling
    - "sum": global sum pooling
    - "att": global attention pooling (trainable)
lambda_max (float or list of float or None) – The largest eigenvalue(s) of the graph Laplacian, which ChebConvs use to rescale the Laplacian. This value is usually computed from the graph itself, but that does not work for non-symmetric matrices, so a fixed constant is used instead. Experiments showed no impact on performance.
normalization ({None, "sym", "rw"}) – The normalization type of the graph Laplacian to use in the ChebConvs. Possible values:
    - None: no normalization
    - "sym": symmetric normalization
    - "rw": random walk normalization
flow ({"target_to_source", "source_to_target"}) – The message passing flow direction in ChebConvs for directed graphs.
edge_weight_idx (int or None) – The index of the edge feature tensor (data.edge_attr) to use as edge weights.
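For reference, a Chebyshev convolution operates on the rescaled Laplacian 2 * L / lambda_max - I. The sketch below, written in plain Python so that it runs without torch, shows how a fixed lambda_max and the "sym" normalization could enter that rescaling. The function name scaled_laplacian, the dense list-of-lists adjacency, and the use of row sums as degrees are illustrative assumptions, not part of morphoclass.

```python
def scaled_laplacian(adj, lambda_max=3.0):
    """Return 2 * L_sym / lambda_max - I for a dense adjacency matrix.

    Hypothetical helper for illustration only; `adj` is a list of lists.
    """
    n = len(adj)
    # Node degrees taken as row sums (an assumption for directed graphs).
    deg = [sum(row) for row in adj]
    # D^{-1/2}, with isolated nodes mapped to zero instead of infinity.
    inv_sqrt = [d ** -0.5 if d > 0 else 0.0 for d in deg]
    # Symmetric normalization: L_sym = I - D^{-1/2} A D^{-1/2}.
    lap = [
        [(1.0 if i == j else 0.0) - inv_sqrt[i] * adj[i][j] * inv_sqrt[j]
         for j in range(n)]
        for i in range(n)
    ]
    # Rescale with the fixed constant instead of a computed eigenvalue.
    return [
        [2.0 * lap[i][j] / lambda_max - (1.0 if i == j else 0.0)
         for j in range(n)]
        for i in range(n)
    ]


# Two nodes joined by one undirected edge, with the default lambda_max=3.0.
two_node_graph = [[0, 1], [1, 0]]
l_hat = scaled_laplacian(two_node_graph)
```

Because lambda_max is a constant here, no eigenvalue computation is needed, which is what makes the approach usable for non-symmetric adjacency matrices.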
-
forward(data)
Run the forward pass.
- Parameters
data (torch_geometric.data.Data) – The input batch of data.
- Returns
The computed graph embeddings of the input morphologies. The shape is (n_samples, 512).
- Return type
torch.Tensor
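The global pooling layer is what turns per-node features into one embedding per morphology. As a rough illustration of the three pool_name options, the plain-Python sketch below pools node features by graph. The function global_pool, the batch vector of graph indices, and the linear gate weights gate_w are hypothetical stand-ins; the actual attention pooling in the model is trainable and may be parameterized differently.

```python
import math


def global_pool(x, batch, pool_name="avg", gate_w=None):
    """Pool node features `x` into one row per graph.

    `batch[i]` is the graph index of node i; every graph is assumed
    to contain at least one node. Illustrative helper only.
    """
    n_graphs = max(batch) + 1
    n_feat = len(x[0])
    out = [[0.0] * n_feat for _ in range(n_graphs)]

    if pool_name in ("avg", "sum"):
        counts = [0] * n_graphs
        for row, g in zip(x, batch):
            counts[g] += 1
            for k, v in enumerate(row):
                out[g][k] += v
        if pool_name == "avg":
            out = [[v / counts[g] for v in out[g]] for g in range(n_graphs)]
        return out

    if pool_name == "att":
        # One scalar score per node from a (hypothetical) linear gate.
        scores = [sum(w * v for w, v in zip(gate_w, row)) for row in x]
        # Softmax over the nodes of each graph.
        denom = [0.0] * n_graphs
        for s, g in zip(scores, batch):
            denom[g] += math.exp(s)
        for row, s, g in zip(x, scores, batch):
            weight = math.exp(s) / denom[g]
            for k, v in enumerate(row):
                out[g][k] += weight * v
        return out

    raise ValueError(f"unknown pool_name: {pool_name!r}")


# Three nodes: the first two belong to graph 0, the last to graph 1.
x = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
batch = [0, 0, 1]
pooled_sum = global_pool(x, batch, "sum")
pooled_avg = global_pool(x, batch, "avg")
pooled_att = global_pool(x, batch, "att", gate_w=[0.0, 0.0])
```

With all-zero gate weights every node receives the same attention weight, so the "att" result coincides with the "avg" result in this example.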
-
training: bool
-
class morphoclass.models.man_net.ManNet(n_features=1, n_global_features=0, n_classes=4, pool_name='avg', lambda_max=3.0, normalization='sym', flow='target_to_source', edge_weight_idx=None, bn=False)
Bases: morphoclass.models.man_net.ManNetR
The updated version of the MultiAdjNet classifier.
Changes:
    - custom pooling
    - edge_weights
    - ChebConvs from pytorch-geometric
    - customizable normalization
    - customizable lambda_max
- Parameters
n_features (int) – The number of input features.
n_global_features (int) – The number of global features.
n_classes (int) – The number of classes. For each sample the output of the model will be an array of real values of length n_classes.
pool_name ({"avg", "sum", "att"}) – The type of pooling layer to use:
    - "avg": global average pooling
    - "sum": global sum pooling
    - "att": global attention pooling (trainable)
lambda_max (float or list of float or None) – The largest eigenvalue(s) of the graph Laplacian, which ChebConvs use to rescale the Laplacian. This value is usually computed from the graph itself, but that does not work for non-symmetric matrices, so a fixed constant is used instead. Experiments showed no impact on performance.
normalization ({None, "sym", "rw"}) – The normalization type of the graph Laplacian to use in the ChebConvs. Possible values:
    - None: no normalization
    - "sym": symmetric normalization
    - "rw": random walk normalization
flow ({"target_to_source", "source_to_target"}) – The message passing flow direction in ChebConvs for directed graphs.
edge_weight_idx (int or None) – The index of the edge feature tensor (data.edge_attr) to use as edge weights.
bn (bool) – Whether or not to apply batch normalization between the embedder and the classifier.
-
forward(data)
Run the forward pass.
- Parameters
data (torch_geometric.data.Data) – The input batch of data.
- Returns
The log of the prediction probabilities. The tensor has shape (n_samples, n_classes).
- Return type
torch.Tensor
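Since the classifier returns log-probabilities rather than probabilities, a caller typically exponentiates them or takes the arg-max to obtain a predicted class. The sketch below illustrates this in plain Python on made-up scores for n_classes=4. It assumes the log-probabilities come from a log-softmax over raw scores, which the documentation above does not state explicitly; log_softmax and the sample values are illustrative only.

```python
import math


def log_softmax(scores):
    """Numerically stable log-softmax of one row of raw scores."""
    m = max(scores)
    log_sum_exp = m + math.log(sum(math.exp(v - m) for v in scores))
    return [v - log_sum_exp for v in scores]


# Made-up raw scores for two samples and four classes.
raw_scores = [[2.0, 0.5, 0.1, -1.0], [0.0, 0.0, 3.0, 0.0]]

# Shape (n_samples, n_classes), analogous to the documented return value.
log_probs = [log_softmax(row) for row in raw_scores]

# Exponentiating recovers probabilities that sum to one per sample.
probs = [[math.exp(v) for v in row] for row in log_probs]

# The predicted class is the index of the largest log-probability.
predictions = [max(range(len(row)), key=row.__getitem__) for row in log_probs]
```

Working in log space avoids underflow for very small probabilities and pairs naturally with a negative log-likelihood loss.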
-
loss_acc(data)
Run the forward pass and compute the loss and accuracy.
- Parameters
data (torch_geometric.data.Data) – The input batch of data.
- Returns
loss (float) – The loss on the given data batch.
acc (float) – The accuracy on the current data batch.
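The documentation does not say which loss loss_acc computes. As a hedged illustration, the plain-Python sketch below assumes a mean negative log-likelihood over the log-probabilities returned by forward, together with arg-max accuracy. The function nll_loss_and_accuracy and the sample values are hypothetical and not taken from morphoclass.

```python
import math


def nll_loss_and_accuracy(log_probs, labels):
    """Mean negative log-likelihood and accuracy for one batch.

    Assumes `log_probs` has shape (n_samples, n_classes) and `labels`
    holds one integer class index per sample.
    """
    n = len(labels)
    # NLL: minus the log-probability assigned to the true class.
    loss = -sum(row[y] for row, y in zip(log_probs, labels)) / n
    # Accuracy: fraction of samples whose arg-max matches the label.
    preds = [max(range(len(row)), key=row.__getitem__) for row in log_probs]
    acc = sum(p == y for p, y in zip(preds, labels)) / n
    return loss, acc


# Two samples, three classes; the second sample is misclassified.
batch_log_probs = [
    [math.log(0.7), math.log(0.2), math.log(0.1)],
    [math.log(0.1), math.log(0.6), math.log(0.3)],
]
batch_labels = [0, 2]
loss, acc = nll_loss_and_accuracy(batch_log_probs, batch_labels)
```

Both values are plain floats here, matching the documented return types.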
-
training: bool
-
class morphoclass.models.man_net.ManNetR(n_features=1, n_global_features=0, n_classes=4, pool_name='avg', lambda_max=3.0, normalization='sym', flow='target_to_source', edge_weight_idx=None, bn=False)
Bases: torch.nn.modules.module.Module
The regression version of the ManNet classifier (no softmax).
It also supports training on a number of global features and an optional batch normalization.
- Parameters
n_features (int) – The number of input features.
n_global_features (int) – The number of global features.
n_classes (int) – The number of classes. For each sample the output of the model will be an array of real values of length n_classes.
pool_name ({"avg", "sum", "att"}) – The type of pooling layer to use:
    - "avg": global average pooling
    - "sum": global sum pooling
    - "att": global attention pooling (trainable)
lambda_max (float or list of float or None) – The largest eigenvalue(s) of the graph Laplacian, which ChebConvs use to rescale the Laplacian. This value is usually computed from the graph itself, but that does not work for non-symmetric matrices, so a fixed constant is used instead. Experiments showed no impact on performance.
normalization ({None, "sym", "rw"}) – The normalization type of the graph Laplacian to use in the ChebConvs. Possible values:
    - None: no normalization
    - "sym": symmetric normalization
    - "rw": random walk normalization
flow ({"target_to_source", "source_to_target"}) – The message passing flow direction in ChebConvs for directed graphs.
edge_weight_idx (int or None) – The index of the edge feature tensor (data.edge_attr) to use as edge weights.
bn (bool) – Whether or not to apply batch normalization between the embedder and the classifier.
-
forward(data)
Run the forward pass.
- Parameters
data (torch_geometric.data.Data) – The input batch of data.
- Returns
The model outputs, without softmax applied. The tensor has shape (n_samples, n_classes).
- Return type
torch.Tensor
-
training: bool