morphoclass.xai.grad_cam_explainer module
Implementation of the GradCAMExplainer class.
class morphoclass.xai.grad_cam_explainer.GradCAMExplainer(model, cam_layer)
Bases: object
Wrap a GNN model and extract GradCAM data during the forward pass.
- Parameters
model (torch.nn.Module) – A GNN model.
cam_layer (torch.nn.Module) – A reference to a layer in model from which the GradCAM data will be extracted.
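The class wraps a model and reads data from `cam_layer` while the model runs. In PyTorch this kind of capture is commonly done with `torch.nn.Module.register_forward_hook`. The sketch below imitates that hook mechanism in plain Python so it runs without PyTorch. `ToyLayer` and `ActivationRecorder` are hypothetical names, not part of morphoclass, and the real class may collect its data differently.

```python
from typing import Any, Callable, List, Optional, Tuple

Hook = Callable[[Any, Tuple[Any, ...], Any], None]


class ToyLayer:
    """Hypothetical stand-in for a torch.nn.Module that supports forward hooks."""

    def __init__(self, fn: Callable[[Any], Any]) -> None:
        self.fn = fn
        self._hooks: List[Hook] = []

    def register_forward_hook(self, hook: Hook) -> None:
        self._hooks.append(hook)

    def __call__(self, x: Any) -> Any:
        out = self.fn(x)
        # Mirror torch's hook signature: (module, tuple of inputs, output).
        for hook in self._hooks:
            hook(self, (x,), out)
        return out


class ActivationRecorder:
    """Hypothetical wrapper that records cam_layer's output on every forward pass."""

    def __init__(self, model: Callable[[Any], Any], cam_layer: ToyLayer) -> None:
        self.model = model
        self.activations: Optional[Any] = None
        cam_layer.register_forward_hook(self._save)

    def _save(self, module: Any, inputs: Tuple[Any, ...], output: Any) -> None:
        self.activations = output

    def __call__(self, x: Any) -> Tuple[Any, Any]:
        self.activations = None  # reset so data from a previous call is never reused
        out = self.model(x)
        if self.activations is None:
            # The hook never fired, so cam_layer is not used by this model.
            raise ValueError("cam_layer produced no output during the forward pass")
        return out, self.activations


# Usage: a two-layer toy model; record the output of the hidden layer.
hidden = ToyLayer(lambda xs: [2 * v for v in xs])
head = ToyLayer(sum)


def model(xs):
    return head(hidden(xs))


recorder = ActivationRecorder(model, hidden)
out, activations = recorder([1, 2, 3])  # out == 12, activations == [2, 4, 6]
```

Resetting the stored activations before each call keeps a stale capture from hiding a layer that was never executed.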
get_cam(sample, loader_cls, cls_idx=None, relu_weights=True, relu_cam=True)
Run the forward pass on sample and return the GradCAM data.
- Parameters
sample – A morphology data sample.
loader_cls (type[morphoclass.data.MorphologyDataLoader]) – A data loader class.
cls_idx (int (optional)) – The numerical class of the sample. If not provided, the class predicted by the model is used.
relu_weights (bool (optional)) – If true, a ReLU non-linearity is applied to the GradCAM weights. This effectively discards the contributions of channels with negative weights.
relu_cam (bool (optional)) – If true, a ReLU non-linearity is applied to the GradCAM signal. This effectively sets negative GradCAM values to zero.
- Returns
logits – The logits obtained after the forward pass on the given sample.
cam – The GradCAM data.
- Raises
ValueError – If the registered hooks fail to collect data during the model call, so that the collected outputs still contain None values.
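To illustrate how cls_idx, relu_weights and relu_cam enter the computation, here is a minimal sketch of the standard GradCAM recipe applied to node features. It assumes a hypothetical toy head (mean pooling over nodes followed by a linear layer) so that the gradients can be written down by hand instead of coming from autograd. The function name `grad_cam_toy` and the head are assumptions, not morphoclass code. The gradient of the target logit is averaged over nodes to give one weight per channel, and each node's score is the weighted sum of its channel activations.

```python
from typing import List, Optional, Tuple

Matrix = List[List[float]]


def grad_cam_toy(
    activations: Matrix,  # cam-layer output, shape (num_nodes, num_channels)
    readout: Matrix,  # hypothetical linear head, shape (num_classes, num_channels)
    cls_idx: Optional[int] = None,
    relu_weights: bool = True,
    relu_cam: bool = True,
) -> Tuple[List[float], List[float]]:
    n = len(activations)
    k = len(activations[0])
    # Forward pass of the toy head: mean-pool node features, then a linear layer.
    pooled = [sum(row[j] for row in activations) / n for j in range(k)]
    logits = [sum(w * p for w, p in zip(w_row, pooled)) for w_row in readout]
    if cls_idx is None:
        # No class given: use the class predicted by the model.
        cls_idx = max(range(len(logits)), key=logits.__getitem__)
    # For this head, d logits[c] / d activations[i][j] = readout[c][j] / n at every node i.
    grads = [[readout[cls_idx][j] / n for j in range(k)] for _ in range(n)]
    # GradCAM weight of channel j: gradient averaged over all nodes.
    weights = [sum(g[j] for g in grads) / n for j in range(k)]
    if relu_weights:
        weights = [max(w, 0.0) for w in weights]
    # Per-node CAM: channel-weighted sum of the activations.
    cam = [sum(w * a for w, a in zip(weights, row)) for row in activations]
    if relu_cam:
        cam = [max(c, 0.0) for c in cam]
    return logits, cam


logits, cam = grad_cam_toy([[1.0, 0.0], [0.0, 2.0], [-1.0, 1.0]], [[1.0, -1.0], [0.0, 3.0]])
# logits == [-1.0, 3.0] (class 1 is predicted), cam == [0.0, 2.0, 1.0]
```

With relu_weights=False and relu_cam=False, negative channel weights and negative node scores pass through unchanged, which can be useful for seeing which nodes argue against the chosen class.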