morphoclass.models.hbnet module
Implementation of the hierarchical, bidirectional net (HBNet).
-
class morphoclass.models.hbnet.HBNet(num_features, class_segmentation_mask)
Bases: torch.nn.modules.module.Module
Hierarchical, Bidirectional Net.
- Parameters
num_features – the number of input features
class_segmentation_mask – a mask for the hierarchical labels in which each softmax block is marked by a different integer. The mask can be generated by calling get_total_class_mask() on an instance of the model_utils.HierarchicalLabels class.
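To illustrate what "each softmax block is marked by a different integer" can mean in practice, here is a minimal sketch in plain Python. It assumes that entries sharing a mask value are normalised together; the helper name blockwise_log_softmax and the example mask are hypothetical and not part of morphoclass.

```python
import math


def blockwise_log_softmax(logits, mask):
    """Apply log-softmax separately within each block of equal mask values."""
    out = [0.0] * len(logits)
    for block_id in set(mask):
        idx = [i for i, m in enumerate(mask) if m == block_id]
        # Log-sum-exp with the block maximum subtracted for numerical stability.
        peak = max(logits[i] for i in idx)
        lse = peak + math.log(sum(math.exp(logits[i] - peak) for i in idx))
        for i in idx:
            out[i] = logits[i] - lse
    return out


# Hypothetical hierarchy: two root classes (block 0) and
# three children of the first root (block 1).
mask = [0, 0, 1, 1, 1]
logits = [2.0, 0.5, 1.0, 1.0, -1.0]
log_probs = blockwise_log_softmax(logits, mask)
```

Under this assumption the exponentiated outputs sum to one within each block, so every block can be read as a conditional distribution over sibling classes.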
-
compute_hierarchical_metric(metric_function, data, hl)
Compute metric for each layer in the label hierarchy.
- Parameters
metric_function – A function computing the required metric. It should have the signature metric_function(targets, predictions).
data – A batch of samples.
hl (model_utils.HierarchicalLabels) – Hierarchy structure.
- Returns
results – List of metric evaluations for each layer in the label hierarchy.
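As a rough illustration of the metric_function(targets, predictions) contract, the sketch below applies a metric once per hierarchy layer. The helper metric_per_layer and the per-layer lists are hypothetical stand-ins; how HBNet extracts targets and predictions from data and hl is not shown here.

```python
def accuracy(targets, predictions):
    """Fraction of positions where target and prediction agree."""
    matches = sum(t == p for t, p in zip(targets, predictions))
    return matches / len(targets)


def metric_per_layer(metric_function, targets_by_layer, predictions_by_layer):
    """Evaluate metric_function(targets, predictions) once per layer."""
    return [
        metric_function(targets, predictions)
        for targets, predictions in zip(targets_by_layer, predictions_by_layer)
    ]


# Hypothetical labels for four samples in a two-layer hierarchy.
targets_by_layer = [[0, 1, 0, 1], [2, 3, 2, 4]]
predictions_by_layer = [[0, 1, 1, 1], [2, 3, 2, 2]]
results = metric_per_layer(accuracy, targets_by_layer, predictions_by_layer)
```

Any callable with the same two-argument signature, such as a scikit-learn style metric, could be passed in place of accuracy.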
-
forward(data)
Compute the forward pass.
- Parameters
data – The input data.
- Returns
log_softmax – The log softmax of the predictions.
-
gen_hierarchical_probabilities(probabilities, parent_mask, hl)
Descend the hierarchy layer-wise and compute probabilities.
Starting with the parent nodes specified in parent_mask, compute the probabilities for their children using
P(child) = P(child|parent) * P(parent),
create a new mask containing all children, and recurse.
- Parameters
probabilities – Probabilities for all nodes in the hierarchy tree.
parent_mask – Mask for the nodes from which to start descending the hierarchy. On the first call, parent_mask should select all root nodes.
hl (model_utils.HierarchicalLabels) – Hierarchy structure for labels.
- Yields
mask – Mask for all nodes considered for the current layer in the hierarchy.
probabilities – Probabilities predicted for the nodes specified by mask.
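The layer-wise descent described above can be sketched as a recursive generator. This is only an illustration under stated assumptions: the hierarchy is given as a hypothetical children dictionary rather than a model_utils.HierarchicalLabels instance, masks are lists of booleans, and the root layer itself is yielded first, which the real method may or may not do.

```python
def descend(cond_probs, parent_mask, children, parent_probs=None):
    """Yield (mask, probabilities) for each layer, starting at parent_mask.

    cond_probs[i] is P(node i | parent of i); for root nodes it is P(node i).
    """
    n = len(cond_probs)
    if parent_probs is None:
        # First call: the roots' probabilities are already unconditional.
        parent_probs = [p if m else 0.0 for p, m in zip(cond_probs, parent_mask)]
    yield parent_mask, [parent_probs[i] for i in range(n) if parent_mask[i]]

    child_mask = [False] * n
    child_probs = [0.0] * n
    for parent in range(n):
        if not parent_mask[parent]:
            continue
        for child in children.get(parent, []):
            child_mask[child] = True
            # P(child) = P(child|parent) * P(parent)
            child_probs[child] = cond_probs[child] * parent_probs[parent]

    if any(child_mask):
        # Recurse with the children acting as the new parents.
        yield from descend(cond_probs, child_mask, children, child_probs)


# Hypothetical tree: nodes 0 and 1 are roots; nodes 2, 3, 4 are children of 0.
children = {0: [2, 3, 4]}
cond_probs = [0.6, 0.4, 0.5, 0.3, 0.2]
root_mask = [True, True, False, False, False]
layers = list(descend(cond_probs, root_mask, children))
```

With these numbers the second layer holds roughly 0.30, 0.18 and 0.12, which together add up to the probability 0.6 of their shared parent.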
-
hierarchical_accuracies(data, hl)
Accuracies for each layer in the label hierarchy.
- Parameters
data – A batch of samples.
hl (model_utils.HierarchicalLabels) – Hierarchy structure.
- Returns
accuracies – List of accuracies for each layer in the label hierarchy.
-
loss(data)
Compute the loss.
- Parameters
data – The input data.
- Returns
The loss.
-
precision_recall_f1(data, hl, *, average='micro')
Compute hierarchical precision, recall, F1 score.
- Parameters
data – A batch of samples.
hl (model_utils.HierarchicalLabels) – Hierarchy structure.
average – The type of multi-class average to take; accepted values are None, “micro”, and “macro”.
- Returns
p_h – The hierarchical precision.
r_h – The hierarchical recall.
f1_h – The hierarchical F1 score.
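For orientation, the sketch below computes micro-averaged hierarchical precision, recall and F1 using one common set-based definition, in which each label is represented by the set of hierarchy nodes it activates. Whether HBNet uses exactly this definition is an assumption; the function name and the example sets are hypothetical.

```python
def hierarchical_prf_micro(true_sets, pred_sets):
    """Micro-averaged hierarchical precision, recall and F1 over node sets."""
    # Count shared nodes, predicted nodes and true nodes across all samples.
    overlap = sum(len(t & p) for t, p in zip(true_sets, pred_sets))
    n_pred = sum(len(p) for p in pred_sets)
    n_true = sum(len(t) for t in true_sets)

    p_h = overlap / n_pred if n_pred else 0.0
    r_h = overlap / n_true if n_true else 0.0
    f1_h = 2 * p_h * r_h / (p_h + r_h) if (p_h + r_h) else 0.0
    return p_h, r_h, f1_h


# Hypothetical node sets for two samples: the first sample gets the root
# right but the child wrong, the second sample is fully correct.
true_sets = [{0, 3}, {1}]
pred_sets = [{0, 2}, {1}]
p_h, r_h, f1_h = hierarchical_prf_micro(true_sets, pred_sets)
```

A prediction that gets the parent right but the child wrong still earns partial credit under this definition, which is the main reason to prefer hierarchical metrics over flat ones.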
-
predict(data)
Make hierarchical predictions for given data.
- Parameters
data – A batch of samples.
- Returns
val_max – Probabilities for the predicted nodes in the hierarchy tree.
val_argmax – Indices for the predicted nodes in the hierarchy tree. For example, the hierarchical one-hot prediction [1, 0, 0, 1, 0] would correspond to val_argmax = [0, 3].
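The relation between a hierarchical one-hot prediction and the returned index list can be reproduced in a few lines of plain Python. The probability values below are made up for illustration, and pairing them with the indices to form val_max is an assumption about the return layout.

```python
# Hierarchical one-hot prediction from the example above:
# node 0 is chosen on the first level, node 3 on the second.
one_hot = [1, 0, 0, 1, 0]

# Indices of the predicted nodes.
val_argmax = [i for i, flag in enumerate(one_hot) if flag == 1]

# Hypothetical per-node probabilities; val_max keeps only those
# belonging to the predicted nodes.
node_probabilities = [0.6, 0.4, 0.12, 0.3, 0.18]
val_max = [node_probabilities[i] for i in val_argmax]
```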
-
predict_probabilities(data)
Compute the prediction probabilities.
- Parameters
data – The input data.
- Returns
The prediction probabilities.
-
training: bool