morphoclass.xai package
Submodules
Module contents
XAI tools for morphology GNNs.
-
class morphoclass.xai.EmbeddingExtractor(model, layer_name)
Bases: object
Extract inputs/outputs from a given layer.
To create an embedding extractor, provide a model and the name of the layer in that model whose inputs and outputs should be extracted.
To obtain the inputs and outputs, call the extractor instance (which invokes its __call__ method) with a batch of data. The batch is forwarded to the model internally, and __call__ returns the layer's inputs and outputs as a tuple.
- Parameters
model (torch.nn.Module) – A torch model from which to extract intermediate activations.
layer_name (str) – The layer in the given model for which to extract the inputs/outputs. The layer instance is obtained by getattr(model, layer_name) and should return an instance of torch.nn.Module.
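The extraction mechanism described above can be sketched with a PyTorch forward hook. This is a minimal illustration, not the morphoclass implementation: the class HookedLayerExtractor and the toy model TinyNet are hypothetical, and the sketch assumes the batch can be passed to the model directly (a GNN would receive a batched graph object instead).

```python
import torch
from torch import nn


class HookedLayerExtractor:
    """Hypothetical minimal analogue of EmbeddingExtractor (not the morphoclass code)."""

    def __init__(self, model: nn.Module, layer_name: str):
        self.model = model
        # Same lookup rule as documented above: getattr(model, layer_name)
        self.layer = getattr(model, layer_name)
        if not isinstance(self.layer, nn.Module):
            raise TypeError(f"{layer_name!r} does not refer to a torch.nn.Module")

    def __call__(self, batch):
        captured = {}

        def hook(module, inputs, output):
            # `inputs` is the tuple of positional arguments the layer received
            captured["inputs"] = inputs
            captured["output"] = output

        handle = self.layer.register_forward_hook(hook)
        try:
            with torch.no_grad():  # assumption: no gradients are needed here
                self.model(batch)
        finally:
            handle.remove()  # detach the hook even if the forward pass fails
        return captured["inputs"], captured["output"]


class TinyNet(nn.Module):
    """Hypothetical toy model with two named layers."""

    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(4, 3)
        self.out = nn.Linear(3, 2)

    def forward(self, x):
        return self.out(torch.relu(self.hidden(x)))


extractor = HookedLayerExtractor(TinyNet(), "out")
inputs, output = extractor(torch.randn(5, 4))
print(inputs[0].shape, output.shape)
```

Removing the hook in a finally clause keeps repeated calls from stacking hooks on the same layer.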
-
class morphoclass.xai.GradCAMExplainer(model, cam_layer)
Bases: object
Wrap a GNN model and extract GradCAM data upon forward pass.
- Parameters
model (torch.nn.Module) – A GNN model.
cam_layer (torch.nn.Module) – A reference to a layer in model from which the GradCAM data will be extracted.
-
get_cam(sample, loader_cls, cls_idx=None, relu_weights=True, relu_cam=True)
Run the forward pass on sample and get the GradCAM data.
- Parameters
sample – A morphology data sample.
loader_cls (type[morphoclass.data.MorphologyDataLoader]) – A data loader class.
cls_idx (int (optional)) – The numerical class of the sample. If not provided, the class predicted by the model is used.
relu_weights (bool (optional)) – If true, a ReLU non-linearity is applied to the GradCAM weights. This effectively discards feature channels with negative weights.
relu_cam (bool (optional)) – If true, a ReLU non-linearity is applied to the GradCAM signal. This effectively sets negative GradCAM values to zero.
- Returns
logits – The logits obtained after the forward pass on the given sample.
cam – The GradCAM data.
- Raises
ValueError – If the collected outputs still contain None values, i.e. the registered hook was unable to collect data during the model call.
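Once the activations of cam_layer and the gradients of the target logit with respect to them have been captured (for example with hooks like the ones EmbeddingExtractor relies on), the GradCAM computation itself is short. The sketch below follows the standard GradCAM recipe adapted to graph nodes and assumes, without confirmation from the morphoclass source, that the channel weights are the node-averaged gradients. grad_cam_nodes is a hypothetical helper whose relu_weights and relu_cam flags mirror the parameters of get_cam.

```python
import numpy as np


def grad_cam_nodes(activations, gradients, relu_weights=True, relu_cam=True):
    """GradCAM over graph nodes (hypothetical helper, not morphoclass code).

    activations: (n_nodes, n_channels) output of the CAM layer.
    gradients:   (n_nodes, n_channels) d(target logit) / d(activations).
    Returns one CAM value per node.
    """
    activations = np.asarray(activations, dtype=float)
    gradients = np.asarray(gradients, dtype=float)
    # Channel weights: gradients averaged over all nodes (global pooling step)
    weights = gradients.mean(axis=0)
    if relu_weights:
        weights = np.maximum(weights, 0.0)  # drop channels with negative weight
    cam = activations @ weights  # weighted sum of the channels for each node
    if relu_cam:
        cam = np.maximum(cam, 0.0)  # set negative CAM values to zero
    return cam


A = [[1.0, 2.0], [3.0, 0.0], [0.0, 1.0]]  # 3 nodes, 2 channels
G = [[0.5, -1.0], [1.5, -1.0], [1.0, -1.0]]
print(grad_cam_nodes(A, G))
print(grad_cam_nodes(A, G, relu_weights=False, relu_cam=False))
```

When cls_idx is None, the target logit would typically be the one at logits.argmax(), i.e. the predicted class.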
-
morphoclass.xai.cnn_model_attributions(model, dataset, sample_id, interpretability_method_cls)
Explain CNN model.
Plot the feature maps after each feature extractor layer, starting from the original image and ending with the last feature extractor layer. Only one morphology sample is visualized.
- Parameters
model (morphoclass.models.cnnet.CNNet) – Model that will be explained.
dataset (morphoclass.data.morphology_dataset.MorphologyDataset) – Dataset containing embeddings and morphologies.
sample_id (int) – The id of the embedding in the dataset.
interpretability_method_cls – An interpretability class from captum.attr.
- Returns
fig – Figure with explainable plots.
- Return type
matplotlib.figure.Figure
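The interpretability_method_cls argument expects a captum.attr class such as captum.attr.IntegratedGradients, which is constructed from the model and queried through its attribute method. To show what such a method computes without depending on captum or a trained CNN, the following NumPy sketch approximates Integrated Gradients with a midpoint Riemann sum. The helper integrated_gradients and the toy function f(x) = sum(x**2) are hypothetical stand-ins for the CNN and its input image.

```python
import numpy as np


def integrated_gradients(grad_fn, x, baseline=None, steps=50):
    """Approximate Integrated Gradients with a midpoint Riemann sum.

    grad_fn: returns the gradient of the model output w.r.t. its input.
    """
    x = np.asarray(x, dtype=float)
    if baseline is None:
        baseline = np.zeros_like(x)  # all-zero reference input
    baseline = np.asarray(baseline, dtype=float)
    alphas = (np.arange(steps) + 0.5) / steps  # midpoints of [0, 1]
    # Average gradient along the straight path from baseline to x
    avg_grad = np.mean(
        [grad_fn(baseline + a * (x - baseline)) for a in alphas], axis=0
    )
    return (x - baseline) * avg_grad


# Toy model f(x) = sum(x**2), whose gradient is 2 * x
x = np.array([1.0, -2.0, 3.0])
attributions = integrated_gradients(lambda z: 2 * z, x)
print(attributions, attributions.sum())  # sum is (up to rounding) f(x) - f(0) = 14
```

For this quadratic toy model the midpoint rule is exact, so the attributions add up to f(x) - f(baseline), the completeness property Integrated Gradients is designed to satisfy.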
-
morphoclass.xai.cnn_model_attributions_population(model, dataset)
Generate the SHAP explanation for a population of neurons.
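The docstring does not say how per-neuron explanations are combined. A common population-level summary, used for example by the bar variant of SHAP summary plots, is the mean absolute SHAP value of each feature over all samples. The sketch below assumes the per-sample SHAP values are already available as an array; population_importance is a hypothetical helper, not part of morphoclass.

```python
import numpy as np


def population_importance(shap_values):
    """Rank features by their mean absolute SHAP value over a population.

    shap_values: (n_samples, n_features) per-neuron attributions.
    Returns (importance per feature, feature indices sorted by importance).
    """
    shap_values = np.asarray(shap_values, dtype=float)
    importance = np.abs(shap_values).mean(axis=0)  # sign is ignored on purpose
    ranking = np.argsort(-importance)  # most important feature first
    return importance, ranking


imp, order = population_importance([[1.0, -2.0, 0.5], [-3.0, 0.0, 0.5]])
print(imp, order)
```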
-
morphoclass.xai.get_outlier_detection_app(figures, dataset, description_text=None)
Visualize outlier detection in dash app for interactivity.
- Parameters
figures (list) – List of plotly figures.
dataset (morphoclass.data.morphology_dataset.MorphologyDataset) – Dataset with neuronal morphologies.
description_text (str, optional) – Checkpoint information.
- Returns
app – Application instance.
- Return type
dash.Dash
-
morphoclass.xai.gnn_model_attributions(model, dataset, sample_id, interpretability_method_cls)
Explain GNN model.
Plot with two rows:
Original graph and graph with GradShap values within the nodes.
Heatmap of the original graph (zero-values) and heatmap of the GradShap values on the graph.
Only one morphology sample is visualized.
- Parameters
model (morphoclass.models.man_net.ManNet) – Model that will be explained.
dataset (morphoclass.data.morphology_dataset.MorphologyDataset) – Dataset containing morphologies.
sample_id (int) – The id of the morphology in the dataset.
interpretability_method_cls – An interpretability class from captum.attr.
- Returns
fig – Figure with explainable plots.
- Return type
matplotlib.figure.Figure
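GradShap (captum.attr.GradientShap) estimates SHAP values by averaging gradient times (input minus baseline) over randomly chosen baselines and random points on the path between baseline and input. The NumPy sketch below applies that idea to the node feature matrix of a single graph and sums the result per node, which is one plausible way to obtain the "GradShap values within the nodes". The helper gradient_shap_nodes and the per-node sum are assumptions, and captum's optional Gaussian input noise is omitted.

```python
import numpy as np


def gradient_shap_nodes(grad_fn, x, baselines, n_samples=20, seed=0):
    """Expected-gradients style GradientSHAP estimate for node features.

    x:         (n_nodes, n_features) features of one graph.
    baselines: (n_baselines, n_nodes, n_features) reference inputs.
    grad_fn:   returns d(model output)/d(input), same shape as x.
    """
    rng = np.random.default_rng(seed)
    x = np.asarray(x, dtype=float)
    baselines = np.asarray(baselines, dtype=float)
    total = np.zeros_like(x)
    for _ in range(n_samples):
        ref = baselines[rng.integers(len(baselines))]  # random baseline
        alpha = rng.uniform()  # random point on the path ref -> x
        total += (x - ref) * grad_fn(ref + alpha * (x - ref))
    attributions = total / n_samples
    # Assumption: one value per node, obtained by summing over features
    node_values = attributions.sum(axis=1)
    return attributions, node_values


# Toy linear model f(X) = sum(W * X), so its gradient is W everywhere
W = np.array([[1.0, 2.0], [0.5, -1.0], [3.0, 0.0]])
X = np.array([[1.0, 1.0], [2.0, 2.0], [-1.0, 4.0]])  # 3 nodes, 2 features
baselines = np.zeros((1, 3, 2))  # a single all-zero reference graph
attr, node_values = gradient_shap_nodes(lambda z: W, X, baselines)
print(node_values)
```

For a linear model the estimate is exact, so the node values add up to f(X) - f(baseline).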
-
morphoclass.xai.grad_cam_cnn_model(model, dataset, sample_id)
Explain CNN model.
Plot the feature maps after each feature extractor layer, starting from the original image and ending with the last feature extractor layer. Only one morphology sample is visualized.
- Parameters
model (morphoclass.models.cnnet.CNNet) – Model that will be explained.
dataset (morphoclass.data.morphology_dataset.MorphologyDataset) – Dataset containing embeddings and morphologies.
sample_id (int) – The id of the embedding in the dataset.
- Returns
fig – Figure with explainable plots.
- Return type
matplotlib.figure.Figure
-
morphoclass.xai.grad_cam_gnn_model(model, dataset, sample_id)
Explain GNN model.
Plot with two rows:
Original graph and graph with GradCam values within the nodes.
Heatmap of the original graph (zero-values) and heatmap of the GradCam values on the graph.
Only one morphology sample is visualized.
- Parameters
model (morphoclass.models.man_net.ManNet) – Model that will be explained.
dataset (morphoclass.data.morphology_dataset.MorphologyDataset) – Dataset containing morphologies.
sample_id (int) – The id of the morphology in the dataset.
- Returns
fig – Figure with explainable plots.
- Return type
matplotlib.figure.Figure
-
morphoclass.xai.grad_cam_perslay_model(model, dataset, sample_id)
Explain PersLay model.
Plot with 3 rows:
Barcodes: The original barcode and GradCam weighted barcode (colored bar) after each feature extraction layer.
Persistence diagrams: The original PD and GradCam weighted PD (colored dot) after each feature extraction layer.
Graph: The original graph and GradCam weighted graph (colored edge) after each feature extraction layer.
- Parameters
model (morphoclass.models.coriander_net.CorianderNet) – Model that will be explained.
dataset (morphoclass.data.morphology_dataset.MorphologyDataset) – Dataset containing embeddings and morphologies.
sample_id (int) – The id of the embedding in the dataset.
- Returns
fig – Figure with explainable plots.
- Return type
matplotlib.figure.Figure
-
morphoclass.xai.perslay_model_attributions(model, dataset, sample_id, interpretability_method_cls)
Explain PersLay model.
Plot with 3 rows:
Barcodes: The original barcode and GradShap weighted barcode (colored bar) after each feature extraction layer.
Persistence diagrams: The original PD and GradShap weighted PD (colored dot) after each feature extraction layer.
Graph: The original graph and GradShap weighted graph (colored edge) after each feature extraction layer.
- Parameters
model (morphoclass.models.coriander_net.CorianderNet) – Model that will be explained.
dataset (morphoclass.data.morphology_dataset.MorphologyDataset) – Dataset containing embeddings and morphologies.
sample_id (int) – The id of the embedding in the dataset.
interpretability_method_cls – An interpretability class from captum.attr.
- Returns
fig – A figure with explainable plots.
- Return type
matplotlib.figure.Figure
-
morphoclass.xai.plot_node_saliency(tree, grads, ax=None, rot=0, scale=1.0, edge_color='orange', width=1, center_nodes=True, show_axes=False, show_legend=True, name='grad')
Plot the node saliency for a neuronal tree.
- Parameters
tree (tmd.Tree.Tree.Tree) – The neuronal tree to plot.
grads (array) – The gradients, one entry per node. The node saliency is the absolute value of these gradients, so the length of grads should equal the number of nodes in the tree.
ax (matplotlib.axes.Axes (optional)) – Plotting axes.
rot (float (optional)) – Rotate the neuronal tree by the given angle around the y-axis. The angle is in degrees.
scale (float (optional)) – Change the thickness of the plotted graph edges.
edge_color (str (optional)) – The color of the edges. It is passed through to NetworkX.
width (float (optional)) – Change the thickness of the plotted graph edges.
center_nodes (bool (optional)) – If true then the tree will be shifted so that the first node (usually the root of the tree) is at the coordinate origin.
show_axes (bool (optional)) – If true then show the coordinate axes.
show_legend (bool (optional)) – If true then show the plot legend.
name (str (optional)) – The name of the saliency that will appear in the plot legend.
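Before anything is drawn, the center_nodes and rot options amount to a translation and a rotation of the node coordinates, and the saliency is the absolute gradient per node. The NumPy sketch below shows these preprocessing steps under stated assumptions: prepare_node_saliency is hypothetical, node coordinates are taken as an (n_nodes, 3) array, the rotation uses the right-handed convention about the y-axis, and the final min-max normalization for coloring is our own choice rather than documented behaviour.

```python
import numpy as np


def prepare_node_saliency(coords, grads, rot=0.0, center_nodes=True):
    """Center and rotate node coordinates and turn gradients into saliency.

    coords: (n_nodes, 3) x, y, z positions of the tree nodes.
    grads:  (n_nodes,) one gradient value per node.
    """
    coords = np.asarray(coords, dtype=float)
    grads = np.asarray(grads, dtype=float)
    if len(grads) != len(coords):
        raise ValueError("need exactly one gradient entry per node")
    if center_nodes:
        coords = coords - coords[0]  # first node (usually the root) -> origin
    theta = np.deg2rad(rot)  # `rot` is given in degrees
    c, s = np.cos(theta), np.sin(theta)
    # Right-handed rotation about the y-axis (assumed convention)
    rot_y = np.array([[c, 0.0, s], [0.0, 1.0, 0.0], [-s, 0.0, c]])
    coords = coords @ rot_y.T
    saliency = np.abs(grads)  # node saliency = |gradient|
    # Assumption: min-max scale to [0, 1] so values can index a colormap
    span = saliency.max() - saliency.min()
    if span > 0:
        norm = (saliency - saliency.min()) / span
    else:
        norm = np.zeros_like(saliency)
    return coords, saliency, norm


xyz, sal, norm = prepare_node_saliency(
    [[1, 0, 0], [2, 0, 0], [1, 5, 0]], [-0.5, 2.0, 1.0], rot=90
)
print(np.round(xyz, 6), sal, norm)
```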
-
morphoclass.xai.sklearn_model_attributions_shap(model, dataset, sample_id)
Explain sklearn model.
Plot the feature maps after each feature extractor layer, starting from the original image and ending with the last feature extractor layer. Only one morphology sample is visualized.
- Parameters
model – Model that will be explained.
dataset (morphoclass.data.morphology_dataset.MorphologyDataset) – Dataset containing embeddings and morphologies.
sample_id (int) – The id of the embedding in the dataset.
- Returns
fig – A figure with explainable plots.
- Return type
matplotlib.figure.Figure
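For a model with only a handful of input features, SHAP values can be computed exactly by enumerating all feature coalitions. The stdlib-only sketch below does this, replacing "absent" features by a baseline vector. exact_shapley is a hypothetical helper; its cost grows exponentially with the number of features, which is why libraries such as shap rely on approximations (e.g. KernelExplainer) or model-specific algorithms (e.g. TreeExplainer) instead.

```python
from itertools import combinations
from math import factorial


def exact_shapley(f, x, baseline):
    """Exact Shapley values; absent features take their baseline value."""
    n = len(x)

    def value(subset):
        # Features in `subset` come from x, all others from the baseline
        z = [x[i] if i in subset else baseline[i] for i in range(n)]
        return f(z)

    phi = [0.0] * n
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for size in range(n):
            # Shapley weight |S|! (n - |S| - 1)! / n!
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for subset in combinations(others, size):
                s = set(subset)
                phi[i] += weight * (value(s | {i}) - value(s))
    return phi


# Hypothetical linear model: Shapley values equal w_i * (x_i - baseline_i)
phi = exact_shapley(lambda z: 2 * z[0] - z[1] + 3 * z[2], [1, 2, 3], [0, 0, 0])
print(phi)
```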
-
morphoclass.xai.sklearn_model_attributions_tree(model, dataset)
Explain sklearn tree model.
- Parameters
model – Model that will be explained.
dataset (morphoclass.data.morphology_dataset.MorphologyDataset) – Dataset containing embeddings and morphologies.
- Returns
fig – Figure with explainable plots.
- Return type
matplotlib.figure.Figure