morphoclass.xai package

Module contents

XAI tools for morphology GNNs.

class morphoclass.xai.EmbeddingExtractor(model, layer_name)

Bases: object

Extract inputs/outputs from a given layer.

An instance of the embedding extractor is created by providing a model and a layer in that model for which the inputs and the outputs should be extracted.

To obtain the inputs and the outputs, call the extractor with a batch of data (i.e. invoke its __call__ method). The batch is forwarded to the model internally, and __call__ returns the inputs and outputs of the layer as a tuple.

Parameters
  • model (torch.nn.Module) – A torch model from which to extract intermediate activations.

  • layer_name (str) – The name of the layer in the given model whose inputs/outputs should be extracted. The layer instance is obtained by getattr(model, layer_name) and should be an instance of torch.nn.Module.
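A minimal sketch of how such an extractor can be built on PyTorch forward hooks. The class name SimpleEmbeddingExtractor and the toy TinyNet model are hypothetical, a plain tensor stands in for a morphology batch, and running the forward pass under torch.no_grad() is a choice of this sketch rather than documented morphoclass behaviour.

```python
import torch
from torch import nn


class SimpleEmbeddingExtractor:
    """Hypothetical re-implementation of the extractor idea using a forward hook."""

    def __init__(self, model: nn.Module, layer_name: str) -> None:
        self.model = model
        self.layer = getattr(model, layer_name)  # must be a torch.nn.Module
        if not isinstance(self.layer, nn.Module):
            raise TypeError(f"{layer_name!r} is not a torch.nn.Module")

    def __call__(self, batch):
        captured = {}

        def hook(module, inputs, output):
            # `inputs` is the tuple of positional arguments passed to the layer
            captured["inputs"] = inputs
            captured["output"] = output

        handle = self.layer.register_forward_hook(hook)
        try:
            with torch.no_grad():  # assumption: no gradients needed for extraction
                self.model(batch)
        finally:
            handle.remove()  # always detach the hook, even if the forward pass fails
        return captured["inputs"], captured["output"]


class TinyNet(nn.Module):
    """Toy model standing in for a morphology GNN (hypothetical)."""

    def __init__(self):
        super().__init__()
        self.embed = nn.Linear(4, 3)
        self.head = nn.Linear(3, 2)

    def forward(self, x):
        return self.head(torch.relu(self.embed(x)))


extractor = SimpleEmbeddingExtractor(TinyNet(), "embed")
inputs, outputs = extractor(torch.randn(5, 4))
print(inputs[0].shape, outputs.shape)  # torch.Size([5, 4]) torch.Size([5, 3])
```

Removing the hook in a finally block keeps repeated calls from stacking up hooks on the same layer.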

class morphoclass.xai.GradCAMExplainer(model, cam_layer)

Bases: object

Wrap a GNN model and extract GradCAM data upon forward pass.

Parameters
  • model (torch.nn.Module) – A GNN model.

  • cam_layer (torch.nn.Module) – A reference to a layer in model from which the GradCAM data will be extracted.

get_cam(sample, loader_cls, cls_idx=None, relu_weights=True, relu_cam=True)

Run the forward pass on sample and get the GradCAM data.

Parameters
  • sample – A morphology data sample.

  • loader_cls (type[morphoclass.data.MorphologyDataLoader]) – A data loader class.

  • cls_idx (int (optional)) – The numerical class of the sample. If not provided then the model prediction for the class will be used.

  • relu_weights (bool (optional)) – If true then a ReLU non-linearity will be applied to the GradCAM weights. This effectively discards channels whose weights are negative.

  • relu_cam (bool (optional)) – If true then a ReLU non-linearity will be applied to the GradCAM signal. This effectively sets negative GradCAM data to zero.

Returns

  • logits – The logits obtained after the forward pass on the given sample.

  • cam – The GradCAM data.

Raises

ValueError – If the collected outputs still contain None values, i.e. the registered hook was not able to collect data during the model call.
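To illustrate what relu_weights and relu_cam control, here is a framework-free sketch of node-level GradCAM for a single graph. It assumes that the activations of cam_layer and the gradients of the class score with respect to them are available as n_nodes by n_channels nested lists, and that each channel weight is the gradient averaged over nodes, following the original GradCAM recipe. node_grad_cam is a hypothetical name, not part of morphoclass.

```python
def node_grad_cam(activations, gradients, relu_weights=True, relu_cam=True):
    """Hypothetical helper: GradCAM value for each node of a graph.

    activations, gradients: nested lists of shape (n_nodes, n_channels) holding
    the cam_layer output and the gradient of the class score w.r.t. that output.
    """
    n_nodes = len(activations)
    n_channels = len(activations[0])
    # Channel weight = gradient averaged over all nodes (assumed pooling)
    weights = [
        sum(gradients[i][k] for i in range(n_nodes)) / n_nodes
        for k in range(n_channels)
    ]
    if relu_weights:
        # Zero out channels with a negative weight
        weights = [max(w, 0.0) for w in weights]
    # Per-node CAM = weighted sum of that node's channel activations
    cam = [sum(w * a for w, a in zip(weights, row)) for row in activations]
    if relu_cam:
        # Set negative CAM values to zero
        cam = [max(c, 0.0) for c in cam]
    return cam


acts = [[1.0, 2.0], [3.0, -4.0]]
grads = [[1.0, -1.0], [1.0, -3.0]]
print(node_grad_cam(acts, grads))                      # [1.0, 3.0]
print(node_grad_cam(acts, grads, relu_weights=False))  # [0.0, 11.0]
```

With both flags disabled the second call would return [-3.0, 11.0]: the negative weight of the second channel then pulls the first node below zero.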

morphoclass.xai.cnn_model_attributions(model, dataset, sample_id, interpretability_method_cls)

Explain CNN model.

Plot the feature maps after each feature extractor layer, from the original image up to the last feature extractor layer. Only one morphology sample is visualized.

Returns

fig – Figure with explainable plots.

Return type

matplotlib.figure.Figure

morphoclass.xai.cnn_model_attributions_population(model, dataset)

Generate the SHAP explanation for a population of neurons.

morphoclass.xai.get_outlier_detection_app(figures, dataset, description_text=None)

Visualize outlier detection interactively in a Dash app.

Parameters
  • figures (list) – List of plotly figures.

  • dataset (morphoclass.morphology_dataset.MorphologyDataset) – Dataset with neuronal morphologies.

  • description_text (str, optional) – Checkpoint information.

Returns

app – Application instance.

Return type

dash.Dash

morphoclass.xai.gnn_model_attributions(model, dataset, sample_id, interpretability_method_cls)

Explain GNN model.

Plot with two rows:

  • Original graph and graph with GradShap values within the nodes.

  • Heatmap of the original graph (zero-values) and heatmap of the GradShap values on the graph.

Only one morphology sample is visualized.

Returns

fig – Figure with explainable plots.

Return type

matplotlib.figure.Figure
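The GradShap values are presumably produced by the class passed as interpretability_method_cls, for example Captum's GradientShap (an assumption; the source does not name it). To show what such values are, the sketch below estimates GradientSHAP attributions in plain PyTorch by averaging gradient times (input - baseline) at random points between a randomly drawn baseline and the input. MeanPoolNet and gradient_shap are hypothetical, the graph is reduced to a node-feature matrix, and summing the feature attributions into one value per node is a choice of this sketch.

```python
import torch
from torch import nn


class MeanPoolNet(nn.Module):
    """Toy graph-level model: average node features, then a linear layer (hypothetical)."""

    def __init__(self, n_features=3, n_classes=2):
        super().__init__()
        self.lin = nn.Linear(n_features, n_classes)

    def forward(self, x):  # x: (n_nodes, n_features)
        return self.lin(x.mean(dim=0))


def gradient_shap(model, x, baselines, target, n_samples=20, stdev=0.0, seed=0):
    """Hypothetical helper: Monte Carlo GradientSHAP estimate for node features."""
    gen = torch.Generator().manual_seed(seed)
    total = torch.zeros_like(x)
    for _ in range(n_samples):
        # Random baseline and random point on the straight path baseline -> input
        b = baselines[int(torch.randint(len(baselines), (1,), generator=gen))]
        alpha = torch.rand(1, generator=gen)
        noisy_x = x + stdev * torch.randn(x.shape, generator=gen)
        point = (b + alpha * (noisy_x - b)).requires_grad_(True)
        (grad,) = torch.autograd.grad(model(point)[target], point)
        # Gradient times (input - baseline), averaged over samples
        total += grad * (noisy_x - b)
    return total / n_samples


torch.manual_seed(0)
model = MeanPoolNet()
x = torch.randn(6, 3)             # 6 nodes with 3 features each
baselines = torch.zeros(1, 6, 3)  # a single all-zero reference graph
attr = gradient_shap(model, x, baselines, target=1)
node_values = attr.sum(dim=1)     # one attribution value per node
print(node_values.shape)  # torch.Size([6])
```

For this linear toy model the attributions add up, up to floating-point error, to the change in the class score between the baseline and the input, which is a handy sanity check.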

morphoclass.xai.grad_cam_cnn_model(model, dataset, sample_id)

Explain CNN model.

Plot the feature maps after each feature extractor layer, from the original image up to the last feature extractor layer. Only one morphology sample is visualized.

Returns

fig – Figure with explainable plots.

Return type

matplotlib.figure.Figure
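GradCAM for a CNN is commonly computed by averaging the gradients of the class score over each feature map to obtain channel weights, then applying a ReLU to the weighted sum of the feature maps. The sketch below does this in plain PyTorch for one image. TinyCNN and grad_cam_2d are hypothetical, the image is random noise rather than a morphology image, and the bilinear upsampling to the input size is a presentation choice of this sketch.

```python
import torch
import torch.nn.functional as F
from torch import nn


class TinyCNN(nn.Module):
    """Toy stand-in for a morphology CNN (hypothetical)."""

    def __init__(self, n_classes=3):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 4, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(4, 8, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        self.head = nn.Linear(8, n_classes)

    def forward(self, x):
        # Global average pooling over the spatial dimensions, then a linear head
        return self.head(self.features(x).mean(dim=(2, 3)))


def grad_cam_2d(model, layer, image, target=None):
    """Hypothetical helper: GradCAM heatmap for a single image of shape (1, C, H, W)."""
    store = {}
    handle = layer.register_forward_hook(lambda module, inp, out: store.update(act=out))
    try:
        logits = model(image)
    finally:
        handle.remove()
    if target is None:
        target = int(logits.argmax(dim=1))  # fall back to the predicted class
    act = store["act"]  # (1, K, h, w) feature maps of `layer`
    (grad,) = torch.autograd.grad(logits[0, target], act)
    weights = grad.mean(dim=(2, 3), keepdim=True)  # one weight per channel
    cam = F.relu((weights * act).sum(dim=1, keepdim=True))  # (1, 1, h, w)
    # Upsample to the input resolution so the map can be overlaid on the image
    cam = F.interpolate(cam, size=image.shape[-2:], mode="bilinear", align_corners=False)
    return logits.detach(), cam[0, 0].detach()


torch.manual_seed(0)
model = TinyCNN()
image = torch.randn(1, 1, 16, 16)
logits, cam = grad_cam_2d(model, model.features, image)
print(cam.shape)  # torch.Size([16, 16])
```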

morphoclass.xai.grad_cam_gnn_model(model, dataset, sample_id)

Explain GNN model.

Plot with two rows:

  • Original graph and graph with GradCam values within the nodes.

  • Heatmap of the original graph (zero-values) and heatmap of the GradCam values on the graph.

Only one morphology sample is visualized.

Returns

fig – Figure with explainable plots.

Return type

matplotlib.figure.Figure
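The heatmap row (the original graph with zero values next to the graph coloured by GradCAM) can be sketched with plain matplotlib. The node positions, edge list and GradCAM values below are made up, plot_graph_heatmaps is a hypothetical name, and sharing one colour scale between both panels is a design choice of this sketch.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend, so the sketch runs without a display
import matplotlib.pyplot as plt


def plot_graph_heatmaps(pos, edges, cam):
    """Hypothetical helper: draw a graph twice, coloured by zeros and by GradCAM."""
    vmin, vmax = min(0.0, min(cam)), max(cam)  # shared colour scale for both panels
    fig, axes = plt.subplots(1, 2, figsize=(8, 4))
    panels = (("original", [0.0] * len(cam)), ("GradCAM", cam))
    for ax, (title, values) in zip(axes, panels):
        for u, v in edges:  # edges drawn as grey lines underneath the nodes
            ax.plot([pos[u][0], pos[v][0]], [pos[u][1], pos[v][1]],
                    color="lightgray", zorder=1)
        sc = ax.scatter([p[0] for p in pos], [p[1] for p in pos], c=values,
                        cmap="viridis", vmin=vmin, vmax=vmax, zorder=2)
        ax.set_title(title)
        ax.set_aspect("equal")
    fig.colorbar(sc, ax=list(axes), label="GradCAM value")
    return fig


pos = [(0.0, 0.0), (0.0, 1.0), (-1.0, 2.0), (1.0, 2.0)]  # made-up node coordinates
edges = [(0, 1), (1, 2), (1, 3)]
cam = [0.1, 0.4, 0.9, 0.0]
fig = plot_graph_heatmaps(pos, edges, cam)
```

Using the same vmin and vmax in both panels keeps the zero-valued graph directly comparable with the GradCAM graph.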

morphoclass.xai.grad_cam_perslay_model(model, dataset, sample_id)

Explain PersLay model.

Plot with 3 rows:

  • Barcodes: The original barcode and GradCam weighted barcode (colored bar) after each feature extraction layer.

  • Persistence diagrams: The original PD and GradCam weighted PD (colored dot) after each feature extraction layer.

  • Graph: The original graph and GradCam weighted graph (colored edge) after each feature extraction layer.

Returns

fig – Figure with explainable plots.

Return type

matplotlib.figure.Figure
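The barcode and persistence-diagram rows show the same data in two ways: each bar (birth, death) is drawn as a horizontal segment in the barcode and as the point (birth, death) in the persistence diagram. A matplotlib sketch of both panels, coloured by one GradCAM weight per bar, is shown below. The bars and weights are made up, plot_barcode_and_diagram is hypothetical, the graph row is omitted, and no particular birth/death convention of the TMD library is assumed.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend, so the sketch runs without a display
import matplotlib.pyplot as plt


def plot_barcode_and_diagram(bars, weights):
    """Hypothetical helper: barcode and persistence diagram coloured by bar weights."""
    fig, (ax_bc, ax_pd) = plt.subplots(1, 2, figsize=(9, 4))
    cmap = plt.get_cmap("viridis")
    # Assumes non-negative weights (e.g. after a ReLU); avoid dividing by zero
    w_max = max(weights) or 1.0
    for row, ((birth, death), w) in enumerate(zip(bars, weights)):
        # Barcode: one horizontal segment per (birth, death) interval
        ax_bc.plot([birth, death], [row, row], color=cmap(w / w_max), linewidth=3)
    ax_bc.set_xlabel("filtration value")
    ax_bc.set_ylabel("bar index")
    # Persistence diagram: each bar becomes the point (birth, death)
    births = [b for b, _ in bars]
    deaths = [d for _, d in bars]
    ax_pd.scatter(births, deaths, c=weights, cmap="viridis")
    lo, hi = min(births + deaths), max(births + deaths)
    ax_pd.plot([lo, hi], [lo, hi], "k--", linewidth=1)  # diagonal: birth == death
    ax_pd.set_xlabel("birth")
    ax_pd.set_ylabel("death")
    return fig


bars = [(0.0, 5.0), (1.0, 3.0), (2.0, 2.5)]
gradcam_weights = [0.9, 0.2, 0.0]
fig = plot_barcode_and_diagram(bars, gradcam_weights)
```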

morphoclass.xai.perslay_model_attributions(model, dataset, sample_id, interpretability_method_cls)

Explain PersLay model.

Plot with 3 rows:

  • Barcodes: The original barcode and GradShap weighted barcode (colored bar) after each feature extraction layer.

  • Persistence diagrams: The original PD and GradShap weighted PD (colored dot) after each feature extraction layer.

  • Graph: The original graph and GradShap weighted graph (colored edge) after each feature extraction layer.

Returns

fig – A figure with explainable plots.

Return type

matplotlib.figure.Figure

morphoclass.xai.plot_node_saliency(tree, grads, ax=None, rot=0, scale=1.0, edge_color='orange', width=1, center_nodes=True, show_axes=False, show_legend=True, name='grad')

Plot the node saliency for a neuronal tree.

Parameters
  • tree (tmd.Tree.Tree.Tree) – The neuronal tree to plot.

  • grads (array) – The gradients at the nodes of the tree; the node saliency is their absolute value. There must be one gradient per node, so the length of grads should equal the number of nodes in the tree.

  • ax (matplotlib.axes.Axes (optional)) – Plotting axes.

  • rot (float (optional)) – Rotate the neuronal tree by the given angle around the y-axis. The angle is in degrees.

  • scale (float (optional)) – Change the thickness of the plotted graph edges.

  • edge_color (str (optional)) – The color of the edges. It is passed through to NetworkX.

  • width (float (optional)) – Change the thickness of the plotted graph edges.

  • center_nodes (bool (optional)) – If true then the tree will be shifted so that the first node (usually the root of the tree) is at the coordinate origin.

  • show_axes (bool (optional)) – If true then show the coordinate axes.

  • show_legend (bool (optional)) – If true then show the plot legend.

  • name (str (optional)) – The name of the saliency that will appear in the plot legend.
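The preprocessing implied by grads, rot and center_nodes can be sketched without any plotting library. prepare_saliency_coords is a hypothetical helper, nodes are given as plain (x, y, z) tuples rather than a tmd Tree, and the rotation uses the standard right-handed rotation matrix about the y-axis; the sign convention used by morphoclass is not stated in the source.

```python
import math


def prepare_saliency_coords(points, grads, rot=0.0, center_nodes=True):
    """Hypothetical helper: node coordinates and saliency for plotting.

    points: list of (x, y, z) node coordinates; grads: one gradient per node.
    """
    if len(grads) != len(points):
        raise ValueError(f"expected {len(points)} gradients, got {len(grads)}")
    saliency = [abs(g) for g in grads]  # node saliency = |gradient|
    if center_nodes:
        # Shift so that the first node (usually the root) sits at the origin
        x0, y0, z0 = points[0]
        points = [(x - x0, y - y0, z - z0) for x, y, z in points]
    theta = math.radians(rot)  # `rot` is given in degrees
    c, s = math.cos(theta), math.sin(theta)
    # Right-handed rotation about the y-axis: y is left unchanged
    rotated = [(c * x + s * z, y, -s * x + c * z) for x, y, z in points]
    return rotated, saliency


points = [(1.0, 2.0, 3.0), (2.0, 2.0, 3.0), (1.0, 5.0, 3.0)]
coords, saliency = prepare_saliency_coords(points, [-0.5, 0.2, 0.0], rot=90)
print(saliency)  # [0.5, 0.2, 0.0]
```

Validating the length of grads up front gives a clear error instead of a silent mismatch between nodes and colours.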

morphoclass.xai.sklearn_model_attributions_shap(model, dataset, sample_id)

Explain sklearn model.

Plot the feature maps after each feature extractor layer, from the original image up to the last feature extractor layer. Only one morphology sample is visualized.

Returns

fig – A figure with explainable plots.

Return type

matplotlib.figure.Figure

morphoclass.xai.sklearn_model_attributions_tree(model, dataset)

Explain sklearn tree model.

Returns

fig – Figure with explainable plots.

Return type

matplotlib.figure.Figure
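As one possible reading of "explain sklearn tree model", the sketch below draws a fitted decision tree next to its impurity-based feature importances using standard scikit-learn and matplotlib calls. The iris dataset is only a stand-in for morphology features, and which plots sklearn_model_attributions_tree actually produces is not specified in the source.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend, so the sketch runs without a display
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, plot_tree

# Stand-in data: the iris dataset replaces morphology feature vectors here
X, y = load_iris(return_X_y=True)
clf = DecisionTreeClassifier(max_depth=2, random_state=0).fit(X, y)

fig, (ax_tree, ax_imp) = plt.subplots(1, 2, figsize=(12, 4))
plot_tree(clf, filled=True, ax=ax_tree)  # draw the fitted decision rules
# Impurity-based importances: total impurity reduction attributed to each feature
ax_imp.barh(range(X.shape[1]), clf.feature_importances_)
ax_imp.set_xlabel("feature importance")
ax_imp.set_ylabel("feature index")
fig.tight_layout()
```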