morphoclass.xai.model_attributions module
Explain model layers using GradShap.
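GradShap (captum's `captum.attr.GradientShap`) approximates SHAP values as the expected value of `gradient * (input - baseline)`. The gradient is taken at random points between randomly drawn baselines and the (optionally noise-perturbed) input. The NumPy sketch below shows the idea on a model whose gradient is known in closed form. It is not captum's implementation, and `grad_shap`, `grad_fn`, `baselines`, `n_samples` and `stdev` are names chosen for this example only.

```python
import numpy as np


def grad_shap(grad_fn, x, baselines, n_samples=50, stdev=0.0, seed=0):
    """Approximate GradientSHAP attributions for a single input x (illustrative sketch).

    grad_fn   -- callable returning the gradient of the model output w.r.t. its input
    x         -- input array, shape (n_features,)
    baselines -- reference inputs, shape (n_baselines, n_features)
    """
    rng = np.random.default_rng(seed)
    total = np.zeros_like(x, dtype=float)
    for _ in range(n_samples):
        # Draw a random baseline and optionally perturb the input with Gaussian noise.
        baseline = baselines[rng.integers(len(baselines))]
        x_noisy = x + rng.normal(0.0, stdev, size=x.shape)
        # Evaluate the gradient at a random point on the segment baseline -> input.
        alpha = rng.uniform()
        point = baseline + alpha * (x_noisy - baseline)
        total += grad_fn(point) * (x_noisy - baseline)
    return total / n_samples
```

For a linear model `f(x) = w @ x` with a zero baseline, every sample gives `w * x`. The attributions therefore sum exactly to `f(x) - f(0)`, which is the completeness property SHAP-style methods aim for.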
-
morphoclass.xai.model_attributions.cnn_model_attributions(model, dataset, sample_id, interpretability_method_cls)
Explain CNN model.
Plot the feature maps after each feature extractor layer, from the original image to the last feature extractor layer. Only one morphology sample is visualized.
- Parameters
model (morphoclass.models.cnnet.CNNet) – Model that will be explained.
dataset (morphoclass.data.morphology_dataset.MorphologyDataset) – Dataset containing embeddings and morphologies.
sample_id (int) – The id of the embedding in the dataset.
interpretability_method_cls – An interpretability class from captum.attr.
- Returns
fig – Figure with the explanation plots.
- Return type
matplotlib.figure.Figure
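The figure steps through the network one feature extractor layer at a time, starting from the input image. Here is a minimal NumPy sketch of collecting those intermediate maps. It assumes the layers can be applied as a plain sequence of callables. In PyTorch one would more likely capture them with forward hooks (`register_forward_hook`). `collect_feature_maps`, `conv2d_valid` and `relu` are hypothetical stand-ins, not the layers of `CNNet`.

```python
import numpy as np


def collect_feature_maps(image, layers):
    """Return the input followed by the output of every layer, in order."""
    maps = [image]
    out = image
    for layer in layers:
        out = layer(out)
        maps.append(out)
    return maps


def conv2d_valid(image, kernel):
    """Plain 2D cross-correlation with 'valid' padding (no kernel flip)."""
    kh, kw = kernel.shape
    h, w = image.shape
    out = np.empty((h - kh + 1, w - kw + 1))
    for i in range(out.shape[0]):
        for j in range(out.shape[1]):
            out[i, j] = np.sum(image[i:i + kh, j:j + kw] * kernel)
    return out


def relu(a):
    """Element-wise rectified linear unit."""
    return np.maximum(a, 0.0)
```

With a 4x4 image, a 2x2 convolution followed by a ReLU yields maps of shapes (4, 4), (3, 3) and (3, 3). Each map would become one panel of the plot.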
-
morphoclass.xai.model_attributions.cnn_model_attributions_population(model, dataset)
Generate the SHAP explanation for a population of neurons.
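A common way to summarise SHAP values over a population is to average their absolute values per feature and rank the features by that score. This page does not state how the function aggregates its explanations. The sketch below only illustrates that aggregation, and `global_importance` and `feature_names` are hypothetical names.

```python
import numpy as np


def global_importance(shap_values, feature_names):
    """Rank features by mean absolute SHAP value over a population.

    shap_values   -- array of shape (n_samples, n_features)
    feature_names -- one name per feature column
    """
    importance = np.abs(np.asarray(shap_values, dtype=float)).mean(axis=0)
    # argsort is ascending, so reverse it to put the most important feature first.
    order = np.argsort(importance)[::-1]
    return [(feature_names[i], float(importance[i])) for i in order]
```

Taking the absolute value first prevents positive and negative contributions from different neurons from cancelling out.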
-
morphoclass.xai.model_attributions.get_edges_colors_based_on_barcode_colors(tree, colors)
Collect colors for edges based on barcode colors.
- Parameters
tree (tmd.Tree.Tree) – Morphology tree used to create the barcode.
colors (list_like) – List of barcode colors.
- Returns
color_edges – List of edge colors.
- Return type
list_like
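Each bar of a TMD barcode corresponds to a path of edges in the tree, running from a leaf up to the branch point where that branch "dies". Mapping bar colors to edges therefore amounts to remembering which edges each bar covers. The sketch below assumes the tree is given as a parent array with one filtration value per node. At every branch point it keeps the branch reaching the larger value, which is how TMD-style decompositions work. It does not use the `tmd` API. `decompose_tree` and `edge_colors_from_bars` are hypothetical, and the colors must follow the bar order produced here, which may differ from the order `tmd` uses.

```python
def decompose_tree(parent, value):
    """Split a rooted tree into bars and record the edges each bar covers.

    parent -- parent[i] is the parent of node i, or -1 for the root
    value  -- filtration value of each node (e.g. path distance to the root)
    Returns a list of (birth, death, edges); an edge is named by its child node.
    """
    children = [[] for _ in parent]
    root = None
    for node, p in enumerate(parent):
        if p < 0:
            root = node
        else:
            children[p].append(node)
    # Pre-order DFS; reversing it visits every child before its parent.
    order, stack = [], [root]
    while stack:
        node = stack.pop()
        order.append(node)
        stack.extend(children[node])
    alive, bars = {}, []
    for node in reversed(order):
        if not children[node]:
            alive[node] = (value[node], [])  # a leaf starts a new component
            continue
        comps = []
        for child in children[node]:
            birth, edges = alive.pop(child)
            comps.append((birth, edges + [child]))  # extend through edge child -> node
        # The branch reaching the larger value survives; the others end here.
        comps.sort(key=lambda comp: comp[0], reverse=True)
        bars.extend((birth, value[node], edges) for birth, edges in comps[1:])
        alive[node] = comps[0]
    birth, edges = alive.pop(root)
    bars.append((birth, value[root], edges))
    return bars


def edge_colors_from_bars(parent, bars, colors):
    """Give every edge the color of the bar covering it (None for the root)."""
    edge_colors = [None] * len(parent)
    for (_, _, edges), color in zip(bars, colors):
        for child in edges:
            edge_colors[child] = color
    return edge_colors
```

For a root with one child that forks into two leaves at distances 3 and 2, the shorter leaf edge forms its own bar. The longer branch stretches all the way back to the root.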
-
morphoclass.xai.model_attributions.gnn_model_attributions(model, dataset, sample_id, interpretability_method_cls)
Explain GNN model.
Plot with two rows:
Row 1: the original graph and the graph with GradShap values on the nodes.
Row 2: a heatmap of the original graph (zero values) and a heatmap of the GradShap values on the graph.
Only one morphology sample is visualized.
- Parameters
model (morphoclass.models.man_net.ManNet) – Model that will be explained.
dataset (morphoclass.data.morphology_dataset.MorphologyDataset) – Dataset containing morphologies.
sample_id (int) – The id of the morphology in the dataset.
interpretability_method_cls – An interpretability class from captum.attr.
- Returns
fig – Figure with the explanation plots.
- Return type
matplotlib.figure.Figure
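Attribution methods return one value per node feature, but a node can only be drawn with a single color. The sketch below shows one plausible reduction. It sums each node's feature attributions and scales the result to [-1, 1] by the largest magnitude, so that zero stays at the center of a diverging heatmap. This page does not say how `gnn_model_attributions` aggregates node features; `node_scores` is a hypothetical helper.

```python
import numpy as np


def node_scores(node_attributions):
    """Collapse per-feature attributions to one signed score per node, scaled to [-1, 1].

    node_attributions -- array of shape (n_nodes, n_features)
    """
    scores = np.asarray(node_attributions, dtype=float).sum(axis=1)
    peak = np.abs(scores).max()
    # Dividing by the largest magnitude keeps zero at zero, which suits a diverging colormap.
    return scores / peak if peak > 0 else scores
```

An all-zero attribution matrix, like the zero-valued heatmap of the original graph, maps to all-zero scores.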
-
morphoclass.xai.model_attributions.perslay_model_attributions(model, dataset, sample_id, interpretability_method_cls)
Explain PersLay model.
Plot with three rows:
Barcodes: the original barcode and the GradShap-weighted barcode (colored bars) after each feature extraction layer.
Persistence diagrams: the original persistence diagram (PD) and the GradShap-weighted PD (colored dots) after each feature extraction layer.
Graph: the original graph and the GradShap-weighted graph (colored edges) after each feature extraction layer.
- Parameters
model (morphoclass.models.coriander_net.CorianderNet) – Model that will be explained.
dataset (morphoclass.data.morphology_dataset.MorphologyDataset) – Dataset containing embeddings and morphologies.
sample_id (int) – The id of the embedding in the dataset.
interpretability_method_cls – An interpretability class from captum.attr.
- Returns
fig – A figure with the explanation plots.
- Return type
matplotlib.figure.Figure
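Coloring bars, PD dots and edges by GradShap weight needs a map from a signed attribution to a color. The sketch below assumes each persistence point has one attribution per coordinate (birth, death). It sums these attributions and maps the result onto a blue (negative), white (zero), red (positive) scale. In practice a matplotlib diverging colormap such as `"bwr"` would normally be used instead. `diverging_color` and `color_points` are hypothetical helpers, not part of morphoclass.

```python
def diverging_color(score, limit):
    """Map a signed attribution to RGB: blue (negative), white (zero), red (positive)."""
    t = max(-1.0, min(1.0, score / limit)) if limit > 0 else 0.0
    if t >= 0:
        return (1.0, 1.0 - t, 1.0 - t)  # fade from white towards red
    return (1.0 + t, 1.0 + t, 1.0)  # fade from white towards blue


def color_points(point_attributions):
    """Sum each persistence point's (birth, death) attributions and color it."""
    scores = [sum(attr) for attr in point_attributions]
    # A shared limit keeps colors comparable across all points of one diagram.
    limit = max(abs(s) for s in scores)
    return [diverging_color(s, limit) for s in scores]
```

The same color list could then be reused for the matching bar, PD dot and tree edges, so all three rows stay visually consistent.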
-
morphoclass.xai.model_attributions.sklearn_model_attributions_shap(model, dataset, sample_id)
Explain sklearn model.
Plot the SHAP values of the embedding features for the selected sample. Only one morphology sample is visualized.
- Parameters
model – Model that will be explained.
dataset (morphoclass.data.morphology_dataset.MorphologyDataset) – Dataset containing embeddings and morphologies.
sample_id (int) – The id of the embedding in the dataset.
- Returns
fig – A figure with the explanation plots.
- Return type
matplotlib.figure.Figure
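For a model without gradients, such as most sklearn estimators, SHAP values are Shapley values of the prediction. Each feature is credited with its average marginal contribution over all coalitions of the other features. This page does not say which SHAP explainer the function uses. The sketch below computes exact Shapley values by brute-force enumeration against a single baseline, which is feasible only for a handful of features. `exact_shapley`, `predict` and `baseline` are illustrative names.

```python
from itertools import combinations
from math import factorial


def exact_shapley(predict, x, baseline):
    """Exact Shapley values of predict at x, filling 'absent' features from baseline.

    predict  -- callable taking a list of feature values and returning a number
    x        -- feature values of the sample to explain
    baseline -- reference feature values (e.g. the dataset mean)
    The cost grows as 2 ** n_features.
    """
    n = len(x)

    def value(coalition):
        # Features in the coalition come from x, all others from the baseline.
        return predict([x[i] if i in coalition else baseline[i] for i in range(n)])

    phi = [0.0] * n
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for size in range(n):
            # Shapley weight |S|! (n - |S| - 1)! / n! for coalitions of this size.
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for subset in combinations(others, size):
                s = set(subset)
                phi[i] += weight * (value(s | {i}) - value(s))
    return phi
```

For `f(v) = 2*v0 + 3*v1 + v0*v1` at `x = [1, 2]` with a zero baseline, the interaction term is split evenly between the two features. This gives `[3, 7]`, which sums to `f(x) - f(0) = 10`.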
-
morphoclass.xai.model_attributions.sklearn_model_attributions_tree(model, dataset)
Explain sklearn tree model.
- Parameters
model – Model that will be explained.
dataset (morphoclass.data.morphology_dataset.MorphologyDataset) – Dataset containing embeddings and morphologies.
- Returns
fig – Figure with the explanation plots.
- Return type
matplotlib.figure.Figure
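A fitted decision tree is directly interpretable: a prediction is explained by the splits along the path from the root to the leaf that the sample reaches. This page does not say what the tree figure shows. The sketch below lists those splits using plain lists laid out like scikit-learn's fitted `tree_` arrays (`children_left`, `children_right`, `feature`, `threshold`, with `-1` marking a leaf child), so a fitted estimator's `tree_` arrays could be passed in unchanged. `decision_path_rules` is a hypothetical helper.

```python
def decision_path_rules(children_left, children_right, feature, threshold, sample):
    """Follow one sample from the root to a leaf and describe each split taken.

    Node i is a leaf when children_left[i] == -1; otherwise samples with
    sample[feature[i]] <= threshold[i] go to the left child.
    """
    node, steps = 0, []
    while children_left[node] != -1:
        f, t = feature[node], threshold[node]
        if sample[f] <= t:
            steps.append(f"x[{f}] <= {t}")
            node = children_left[node]
        else:
            steps.append(f"x[{f}] > {t}")
            node = children_right[node]
    return node, steps
```

The returned leaf index can be used to look up the class distribution stored for that leaf. The list of rules reads as a human-readable justification of the prediction.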