morphoclass.models.hbnet module

Implementation of the hierarchical, bidirectional net (HBNet).

class morphoclass.models.hbnet.HBNet(num_features, class_segmentation_mask)

Bases: torch.nn.modules.module.Module

Hierarchical, Bidirectional Net.

Parameters
  • num_features – The number of input features.

  • class_segmentation_mask – Mask for the hierarchical labels in which each softmax block is marked by a distinct integer. It can be generated with the get_total_class_mask() method of a model_utils.HierarchicalLabels instance.
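The class segmentation mask groups output nodes into softmax blocks, and each block is normalised on its own. The following plain-Python sketch shows that idea. It assumes the mask is a flat list of integers with one entry per output node. The helper name segment_log_softmax and the example mask layout are hypothetical and not part of morphoclass.

```python
import math


def segment_log_softmax(logits, class_segmentation_mask):
    """Apply a log-softmax independently within each mask block.

    Hypothetical helper: nodes sharing the same integer in
    `class_segmentation_mask` are normalised together.
    """
    if len(logits) != len(class_segmentation_mask):
        raise ValueError("logits and mask must have the same length")
    result = [0.0] * len(logits)
    for block in set(class_segmentation_mask):
        idx = [i for i, b in enumerate(class_segmentation_mask) if b == block]
        # Subtract the block maximum for numerical stability.
        m = max(logits[i] for i in idx)
        log_z = m + math.log(sum(math.exp(logits[i] - m) for i in idx))
        for i in idx:
            result[i] = logits[i] - log_z
    return result


# Assumed layout: two root classes (block 0), each with two children
# (block 1 for the children of node 0, block 2 for those of node 1).
mask = [0, 0, 1, 1, 2, 2]
log_p = segment_log_softmax([2.0, 0.5, 1.0, 1.0, -1.0, 3.0], mask)
probs = [math.exp(v) for v in log_p]
# The probabilities within each block sum to one.
```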

compute_hierarchical_metric(metric_function, data, hl)

Compute a metric for each layer in the label hierarchy.

Parameters
  • metric_function – A function computing the required metric. It must have the signature metric_function(targets, predictions).

  • data – A batch of samples.

  • hl (model_utils.HierarchicalLabels) – Hierarchy structure.

Returns

results – List of metric evaluations for each layer in the label hierarchy.
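The per-layer evaluation can be sketched in plain Python. The sketch assumes hierarchical one-hot targets and predictions with exactly one active node per layer. It replaces hl with a hypothetical list `layers` that holds the node indices of each layer. The names hierarchical_metric and accuracy are illustrative and not part of morphoclass.

```python
def accuracy(targets, predictions):
    """Example metric with the signature metric_function(targets, predictions)."""
    return sum(t == p for t, p in zip(targets, predictions)) / len(targets)


def hierarchical_metric(metric_function, targets, predictions, layers):
    """Evaluate `metric_function` separately for each hierarchy layer.

    Hypothetical sketch: `targets` and `predictions` hold one
    hierarchical one-hot (or probability) vector per sample, and
    `layers` lists the node indices of each layer, standing in for
    the information provided by model_utils.HierarchicalLabels.
    """
    results = []
    for layer_nodes in layers:
        # Select the most active node of this layer for every sample.
        y_true = [max(layer_nodes, key=lambda i: t[i]) for t in targets]
        y_pred = [max(layer_nodes, key=lambda i: p[i]) for p in predictions]
        results.append(metric_function(y_true, y_pred))
    return results


targets = [[1, 0, 1, 0, 0, 0], [0, 1, 0, 0, 0, 1]]
predictions = [[1, 0, 0, 1, 0, 0], [0, 1, 0, 0, 0, 1]]
print(hierarchical_metric(accuracy, targets, predictions, [[0, 1], [2, 3, 4, 5]]))
# [1.0, 0.5]
```

In this example both root classes are predicted correctly, but only one of the two second-layer nodes is.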

forward(data)

Compute the forward pass.

Parameters

data – The input data.

Returns

log_softmax – The log softmax of the predictions.

gen_hierarchical_probabilities(probabilities, parent_mask, hl)

Descend the hierarchy layer-wise and compute probabilities.

Starting with the parent nodes specified in parent_mask, compute the probabilities of their children using

P(child) = P(child|parent) * P(parent),

create a new mask containing all children, and recurse.

Parameters
  • probabilities – Probabilities for all nodes in the hierarchy tree.

  • parent_mask – Mask of the nodes from which to start descending the hierarchy. On the first call, parent_mask should contain all root nodes.

  • hl (model_utils.HierarchicalLabels) – Hierarchy structure for labels.

Yields
  • mask – Mask for all nodes considered for the current layer in the hierarchy.

  • probabilities – Probabilities predicted for the nodes specified by mask.
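The layer-wise descent with P(child) = P(child|parent) * P(parent) can be sketched as a plain-Python generator. This is an illustration under assumptions. The hierarchy is given as a hypothetical `children` dictionary instead of model_utils.HierarchicalLabels. `cond_probs[i]` holds P(node i | parent), or just P(root) for root nodes. Masks are lists of booleans. The name descend_hierarchy is hypothetical.

```python
def descend_hierarchy(cond_probs, parent_mask, children):
    """Yield (mask, probabilities) layer by layer, from the roots down.

    Hypothetical sketch: `cond_probs[i]` is P(node i | parent of i),
    `parent_mask` marks the starting layer, and `children[i]` lists
    the children of node i. Probabilities are yielded as a dict
    mapping node index to absolute probability.
    """
    n = len(cond_probs)
    # Absolute probabilities of the current layer, starting at the roots.
    current = {i: cond_probs[i] for i in range(n) if parent_mask[i]}
    mask = list(parent_mask)
    while current:
        yield mask, current
        nxt = {}
        for parent, p_parent in current.items():
            for child in children.get(parent, []):
                # P(child) = P(child | parent) * P(parent)
                nxt[child] = cond_probs[child] * p_parent
        mask = [i in nxt for i in range(n)]
        current = nxt


layers = list(descend_hierarchy(
    [0.7, 0.3, 0.6, 0.4, 1.0],
    [True, True, False, False, False],
    {0: [2, 3], 1: [4]},
))
# Second layer: approximately {2: 0.42, 3: 0.28, 4: 0.3}
```

Because each block of conditional probabilities sums to one, the absolute probabilities of every complete layer also sum to one.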

hierarchical_accuracies(data, hl)

Compute the accuracy for each layer in the label hierarchy.

Parameters
  • data – A batch of samples.

  • hl (model_utils.HierarchicalLabels) – Hierarchy structure.

Returns

accuracies – List of accuracies for each layer in the label hierarchy.

loss(data)

Compute the loss.

Parameters

data – The input data.

Returns

The loss.

precision_recall_f1(data, hl, *, average='micro')

Compute the hierarchical precision, recall, and F1 score.

Parameters
  • data – A batch of samples.

  • hl (model_utils.HierarchicalLabels) – Hierarchy structure.

  • average – The type of multi-class average to take. Accepted values are None, “micro”, and “macro”.

Returns

  • p_h – The hierarchical precision.

  • r_h – The hierarchical recall.

  • f1_h – The hierarchical F1 score.
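A common set-based definition of hierarchical precision and recall compares the predicted and true label sets, each including all ancestors. It uses hP = Σ|P_i ∩ T_i| / Σ|P_i| and hR = Σ|P_i ∩ T_i| / Σ|T_i|. The sketch below implements only the micro-averaged version of that definition. Whether HBNet computes exactly these quantities, and how its “macro” and None options differ, is an assumption not confirmed here. The function name hierarchical_prf is hypothetical.

```python
def hierarchical_prf(true_sets, pred_sets):
    """Micro-averaged hierarchical precision, recall and F1.

    Each element of `true_sets` / `pred_sets` is the set of hierarchy
    nodes (a label together with its ancestors) for one sample.
    """
    overlap = sum(len(t & p) for t, p in zip(true_sets, pred_sets))
    n_pred = sum(len(p) for p in pred_sets)
    n_true = sum(len(t) for t in true_sets)
    p_h = overlap / n_pred if n_pred else 0.0
    r_h = overlap / n_true if n_true else 0.0
    f1_h = 2 * p_h * r_h / (p_h + r_h) if p_h + r_h else 0.0
    return p_h, r_h, f1_h


# Sample 1: wrong child under the correct root; sample 2: prediction
# stops at the (correct) root.
p_h, r_h, f1_h = hierarchical_prf([{0, 2}, {1, 4}], [{0, 3}, {1}])
# p_h = 2/3, r_h = 1/2, f1_h = 4/7
```

A partially correct prediction still earns credit for the ancestors it gets right, which is the main motivation for hierarchical metrics over flat ones.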

predict(data)

Make hierarchical predictions for given data.

Parameters

data – A batch of samples.

Returns

  • val_max – Probabilities for the predicted nodes in the hierarchy tree.

  • val_argmax – Indices of the predicted nodes in the hierarchy tree. For example, the hierarchical one-hot prediction [1, 0, 0, 1, 0] corresponds to val_argmax = [0, 3].
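One way to obtain such indices is a greedy top-down descent. At each layer, it picks the most probable node among the children of the previously chosen node. The plain-Python sketch below illustrates this. The greedy strategy, the `roots`/`children` representation of the hierarchy, and the names greedy_descent and one_hot_to_indices are assumptions for illustration and may differ from what HBNet.predict does internally.

```python
def greedy_descent(probs, roots, children):
    """Pick the most probable node per layer, following the chosen branch.

    Hypothetical sketch: `probs[i]` is a probability for node i,
    `roots` lists the top-level nodes and `children[i]` the children of i.
    """
    val_max, val_argmax = [], []
    candidates = list(roots)
    while candidates:
        best = max(candidates, key=lambda i: probs[i])
        val_max.append(probs[best])
        val_argmax.append(best)
        # Continue only among the children of the chosen node.
        candidates = children.get(best, [])
    return val_max, val_argmax


def one_hot_to_indices(one_hot):
    """Convert a hierarchical one-hot vector into node indices."""
    return [i for i, v in enumerate(one_hot) if v]


val_max, val_argmax = greedy_descent([0.8, 0.2, 0.3, 0.7, 1.0], [0, 1], {0: [2, 3], 1: [4]})
# val_max = [0.8, 0.7], val_argmax = [0, 3]
print(one_hot_to_indices([1, 0, 0, 1, 0]))  # [0, 3]
```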

predict_probabilities(data)

Compute the prediction probabilities.

Parameters

data – The input data.

Returns

The prediction probabilities.

training: bool