morphoclass.models.man_net module

The ManNet model and trainer classes.

class morphoclass.models.man_net.ManEmbedder(n_features=1, pool_name='avg', lambda_max=3.0, normalization='sym', flow='target_to_source', edge_weight_idx=None)

Bases: torch.nn.modules.module.Module

The embedder for the ManNet network.

The embedder consists of two bidirectional ChebConv blocks followed by a global pooling layer.

Parameters
  • n_features (int) – The number of input features.

  • pool_name ({"avg", "sum", "att"}) – The type of pooling layer to use:
      - “avg”: global average pooling
      - “sum”: global sum pooling
      - “att”: global attention pooling (trainable)

  • lambda_max (float or list of float or None) – Nominally the largest eigenvalue(s) of the graph Laplacian. In ChebConvs this value is usually computed from the adjacency matrix and used to rescale the Laplacian. This does not work for non-symmetric adjacency matrices, however, so a constant value is fixed instead of being computed. Experiments show that this has no impact on performance.

  • normalization ({None, "sym", "rw"}) – The normalization type of the graph Laplacian to use in the ChebConvs. Possible values:
      - None: no normalization
      - “sym”: symmetric normalization
      - “rw”: random walk normalization

  • flow ({"target_to_source", "source_to_target"}) – The message passing flow direction in ChebConvs for directed graphs.

  • edge_weight_idx (int or None) – The index of the edge feature tensor (data.edge_attr) to use as edge weights.
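
The three pool_name options can be illustrated with a minimal pure-Python sketch for a single graph. This is not the morphoclass implementation: the helper global_pool and the linear gate gate_weights used for “att” are hypothetical, and the real “att” pooling is a trainable layer.

```python
import math


def global_pool(node_features, pool_name="avg", gate_weights=None):
    """Reduce the per-node feature vectors of one graph to a single vector.

    node_features: list of equal-length lists of floats, one list per node.
    gate_weights: hypothetical linear gate, only used for "att" pooling.
    """
    n_nodes = len(node_features)
    n_dims = len(node_features[0])
    if pool_name == "sum":
        weights = [1.0] * n_nodes
    elif pool_name == "avg":
        weights = [1.0 / n_nodes] * n_nodes
    elif pool_name == "att":
        if gate_weights is None:
            raise ValueError("'att' pooling needs gate_weights")
        # One score per node, turned into attention weights by a softmax.
        scores = [
            sum(w * x for w, x in zip(gate_weights, row))
            for row in node_features
        ]
        top = max(scores)
        exps = [math.exp(s - top) for s in scores]
        total = sum(exps)
        weights = [e / total for e in exps]
    else:
        raise ValueError(f"unknown pool_name: {pool_name!r}")
    # Weighted sum over nodes, computed independently for each feature.
    return [
        sum(weights[i] * node_features[i][d] for i in range(n_nodes))
        for d in range(n_dims)
    ]


nodes = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
pooled_sum = global_pool(nodes, "sum")
pooled_avg = global_pool(nodes, "avg")
# A zero gate scores every node equally, so "att" reduces to "avg" here.
pooled_att = global_pool(nodes, "att", gate_weights=[0.0, 0.0])
```

In all three cases the output has one value per feature, regardless of the number of nodes in the graph.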

forward(data)

Run the forward pass.

Parameters

data (torch_geometric.data.Data) – The input batch of data.

Returns

The computed graph embeddings of the input morphologies. The shape is (n_samples, 512).

Return type

torch.Tensor

training: bool

class morphoclass.models.man_net.ManNet(n_features=1, n_global_features=0, n_classes=4, pool_name='avg', lambda_max=3.0, normalization='sym', flow='target_to_source', edge_weight_idx=None, bn=False)

Bases: morphoclass.models.man_net.ManNetR

The updated version of the MultiAdjNet classifier.

Changes:
  - custom pooling
  - edge_weights
  - ChebConvs from pytorch-geometric
  - customizable normalization
  - customizable lambda_max

Parameters
  • n_features (int) – The number of input features.

  • n_global_features (int) – The number of global features.

  • n_classes (int) – The number of classes. For each sample the output of the model will be an array of real values of length n_classes.

  • pool_name ({"avg", "sum", "att"}) – The type of pooling layer to use:
      - “avg”: global average pooling
      - “sum”: global sum pooling
      - “att”: global attention pooling (trainable)

  • lambda_max (float or list of float or None) – Nominally the largest eigenvalue(s) of the graph Laplacian. In ChebConvs this value is usually computed from the adjacency matrix and used to rescale the Laplacian. This does not work for non-symmetric adjacency matrices, however, so a constant value is fixed instead of being computed. Experiments show that this has no impact on performance.

  • normalization ({None, "sym", "rw"}) – The normalization type of the graph Laplacian to use in the ChebConvs. Possible values:
      - None: no normalization
      - “sym”: symmetric normalization
      - “rw”: random walk normalization

  • flow ({"target_to_source", "source_to_target"}) – The message passing flow direction in ChebConvs for directed graphs.

  • edge_weight_idx (int or None) – The index of the edge feature tensor (data.edge_attr) to use as edge weights.

  • bn (bool) – Whether or not to apply batch normalization between the embedder and the classifier.
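
The roles of normalization and lambda_max can be made concrete with a small dense-matrix sketch. It assumes the usual definitions (L = D - A, L_sym = I - D^{-1/2} A D^{-1/2}, L_rw = I - D^{-1} A) and the Chebyshev rescaling 2 L / lambda_max - I. The helpers laplacian and rescale are hypothetical, and degrees are taken as row sums, which is an assumption for directed graphs.

```python
def laplacian(adj, normalization=None):
    """Dense graph Laplacian of a square adjacency matrix (list of lists)."""
    n = len(adj)
    deg = [sum(row) for row in adj]  # assumption: degree = row sum
    lap = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            eye = 1.0 if i == j else 0.0
            if normalization is None:
                lap[i][j] = eye * deg[i] - adj[i][j]
            elif normalization == "sym":
                ok = deg[i] > 0 and deg[j] > 0
                scale = (deg[i] * deg[j]) ** -0.5 if ok else 0.0
                lap[i][j] = eye - scale * adj[i][j]
            elif normalization == "rw":
                scale = 1.0 / deg[i] if deg[i] > 0 else 0.0
                lap[i][j] = eye - scale * adj[i][j]
            else:
                raise ValueError(f"unknown normalization: {normalization!r}")
    return lap


def rescale(lap, lambda_max):
    """Apply the Chebyshev rescaling 2 * L / lambda_max - I."""
    n = len(lap)
    return [
        [2.0 * lap[i][j] / lambda_max - (1.0 if i == j else 0.0)
         for j in range(n)]
        for i in range(n)
    ]


# Undirected path graph 0 - 1 - 2.
adj = [[0, 1, 0], [1, 0, 1], [0, 1, 0]]
lap_none = laplacian(adj, None)
lap_sym = laplacian(adj, "sym")
lap_rw = laplacian(adj, "rw")
# Fixed constant in place of a computed eigenvalue (the documented default).
scaled = rescale(lap_sym, lambda_max=3.0)
```

With a fixed lambda_max no eigenvalue computation is needed, which is what makes the approach usable for non-symmetric adjacency matrices.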

forward(data)

Run the forward pass.

Parameters

data (torch_geometric.data.Data) – The input batch of data.

Returns

The log of the prediction probabilities. The tensor has shape (n_samples, n_classes).

Return type

torch.Tensor
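
Because the output holds log-probabilities, it has to be exponentiated to recover probabilities, while the predicted class is simply the index of the largest entry. A pure-Python sketch with made-up scores for n_classes=4 follows; the helper log_softmax stands in for the final layer and is not the morphoclass code.

```python
import math


def log_softmax(scores):
    """Numerically stable log-softmax of one row of real-valued scores."""
    top = max(scores)
    log_norm = top + math.log(sum(math.exp(s - top) for s in scores))
    return [s - log_norm for s in scores]


# Made-up raw scores for two samples and four classes.
raw_scores = [[2.0, 0.5, 0.1, -1.0], [0.0, 0.0, 3.0, 0.0]]
log_probs = [log_softmax(row) for row in raw_scores]

# Exponentiate to get probabilities; each row then sums to one.
probs = [[math.exp(v) for v in row] for row in log_probs]

# The predicted class is the index of the largest log-probability.
predicted = [max(range(len(row)), key=row.__getitem__) for row in log_probs]
```

On an actual torch.Tensor the same two steps would be output.exp() and output.argmax(dim=1).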

loss_acc(data)

Run the forward pass and compute the loss and accuracy.

Parameters

data (torch_geometric.data.Data) – The input batch of data.

Returns

  • loss (float) – The loss on the given data batch.

  • acc (float) – The accuracy on the current data batch.
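
The documentation does not state which loss loss_acc uses. Since forward returns log-probabilities, a negative log-likelihood loss is a natural pairing, and the sketch below assumes it. The helper nll_loss_and_accuracy and the sample values are hypothetical.

```python
import math


def nll_loss_and_accuracy(log_probs, labels):
    """Mean negative log-likelihood and accuracy for one batch.

    log_probs: list of rows of log-probabilities, one row per sample.
    labels: list of integer class labels, one per sample.
    """
    n_samples = len(labels)
    # Loss: minus the log-probability assigned to the true class, averaged.
    loss = -sum(row[y] for row, y in zip(log_probs, labels)) / n_samples
    # Accuracy: fraction of samples whose most likely class is the true one.
    hits = sum(
        1
        for row, y in zip(log_probs, labels)
        if max(range(len(row)), key=row.__getitem__) == y
    )
    return loss, hits / n_samples


batch_log_probs = [
    [math.log(p) for p in (0.7, 0.1, 0.1, 0.1)],  # predicted class 0
    [math.log(p) for p in (0.2, 0.5, 0.2, 0.1)],  # predicted class 1
]
batch_labels = [0, 2]
loss, acc = nll_loss_and_accuracy(batch_log_probs, batch_labels)
```

Here the first sample is classified correctly and the second is not, so the accuracy is one half.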

training: bool

class morphoclass.models.man_net.ManNetR(n_features=1, n_global_features=0, n_classes=4, pool_name='avg', lambda_max=3.0, normalization='sym', flow='target_to_source', edge_weight_idx=None, bn=False)

Bases: torch.nn.modules.module.Module

The regression version of the ManNet classifier (no softmax).

It can also be trained on a number of global features and supports optional batch normalization.

Parameters
  • n_features (int) – The number of input features.

  • n_global_features (int) – The number of global features.

  • n_classes (int) – The number of classes. For each sample the output of the model will be an array of real values of length n_classes.

  • pool_name ({"avg", "sum", "att"}) – The type of pooling layer to use:
      - “avg”: global average pooling
      - “sum”: global sum pooling
      - “att”: global attention pooling (trainable)

  • lambda_max (float or list of float or None) – Nominally the largest eigenvalue(s) of the graph Laplacian. In ChebConvs this value is usually computed from the adjacency matrix and used to rescale the Laplacian. This does not work for non-symmetric adjacency matrices, however, so a constant value is fixed instead of being computed. Experiments show that this has no impact on performance.

  • normalization ({None, "sym", "rw"}) – The normalization type of the graph Laplacian to use in the ChebConvs. Possible values:
      - None: no normalization
      - “sym”: symmetric normalization
      - “rw”: random walk normalization

  • flow ({"target_to_source", "source_to_target"}) – The message passing flow direction in ChebConvs for directed graphs.

  • edge_weight_idx (int or None) – The index of the edge feature tensor (data.edge_attr) to use as edge weights.

  • bn (bool) – Whether or not to apply batch normalization between the embedder and the classifier.
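
The effect of flow and edge_weight_idx can be sketched with a simplified weighted neighbour sum. This ignores the Laplacian normalization and the Chebyshev recursion of a real ChebConv. It assumes the pytorch-geometric convention that edge_index[0] holds source nodes and edge_index[1] holds target nodes, and the helper aggregate is hypothetical.

```python
def aggregate(n_nodes, edge_index, node_values, flow="target_to_source",
              edge_attr=None, edge_weight_idx=None):
    """Sum weighted neighbour values at the receiving end of every edge."""
    out = [0.0] * n_nodes
    sources, targets = edge_index
    for k, (src, dst) in enumerate(zip(sources, targets)):
        # edge_weight_idx selects one column of the edge feature table.
        if edge_weight_idx is None:
            weight = 1.0
        else:
            weight = edge_attr[k][edge_weight_idx]
        if flow == "source_to_target":
            out[dst] += weight * node_values[src]
        elif flow == "target_to_source":
            out[src] += weight * node_values[dst]
        else:
            raise ValueError(f"unknown flow: {flow!r}")
    return out


# A root node 0 with two children, edges directed from root to child.
edge_index = [[0, 0], [1, 2]]
node_values = [1.0, 10.0, 100.0]
edge_attr = [[0.5, 2.0], [0.25, 4.0]]  # two features per edge

down = aggregate(3, edge_index, node_values, flow="source_to_target")
up = aggregate(3, edge_index, node_values, flow="target_to_source")
up_weighted = aggregate(3, edge_index, node_values, flow="target_to_source",
                        edge_attr=edge_attr, edge_weight_idx=0)
```

With "source_to_target" the children receive the root value, while with "target_to_source" the root collects the values of its children.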

forward(data)

Run the forward pass.

Parameters

data (torch_geometric.data.Data) – The input batch of data.

Returns

The raw model outputs, with no softmax applied. The tensor has shape (n_samples, n_classes).

Return type

torch.Tensor

training: bool