morphoclass.xai.grad_cam_explainer module

Implementation of the GradCAMExplainer class.

class morphoclass.xai.grad_cam_explainer.GradCAMExplainer(model, cam_layer)

Bases: object

Wrap a GNN model and extract GradCAM data during the forward pass.

Parameters
  • model (torch.nn.Module) – A GNN model.

  • cam_layer (torch.nn.Module) – A reference to a layer in model from which the GradCAM data will be extracted.

get_cam(sample, loader_cls, cls_idx=None, relu_weights=True, relu_cam=True)

Run the forward pass on sample and get the GradCAM data.

Parameters
  • sample – A morphology data sample.

  • loader_cls (type[morphoclass.data.MorphologyDataLoader]) – A data loader class.

  • cls_idx (int (optional)) – The index of the target class for which GradCAM data is computed. If not provided, the class predicted by the model is used.

  • relu_weights (bool (optional)) – If true, a ReLU non-linearity is applied to the GradCAM weights. This effectively discards channels whose weights are negative.

  • relu_cam (bool (optional)) – If true, a ReLU non-linearity is applied to the GradCAM signal. This effectively sets negative GradCAM values to zero.
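The effect of cls_idx, relu_weights and relu_cam can be illustrated with a minimal NumPy sketch of the standard GradCAM computation on node features. This is not the morphoclass implementation: it assumes the CAM layer output and its gradient with respect to the target logit are already available as (n_nodes, n_channels) arrays (the real class collects torch tensors through hooks), and that each channel weight is the gradient averaged over nodes, as in the original GradCAM formulation. The helper names resolve_class and grad_cam are hypothetical.

```python
import numpy as np


def resolve_class(logits, cls_idx=None):
    """Hypothetical helper: return cls_idx, or the argmax class if it is None."""
    if cls_idx is None:
        return int(np.argmax(logits))
    return int(cls_idx)


def grad_cam(activations, gradients, relu_weights=True, relu_cam=True):
    """Hypothetical helper computing one GradCAM value per node.

    activations: (n_nodes, n_channels) output of the CAM layer (assumed).
    gradients:   (n_nodes, n_channels) gradient of the target logit with
                 respect to those activations (assumed).
    """
    # One weight per channel: the gradient averaged over all nodes.
    weights = gradients.mean(axis=0)
    if relu_weights:
        # Discard channels whose averaged gradient is negative.
        weights = np.maximum(weights, 0.0)
    # Weighted sum over channels gives one value per node.
    cam = activations @ weights
    if relu_cam:
        # Set negative GradCAM values to zero.
        cam = np.maximum(cam, 0.0)
    return cam


# Toy graph with 3 nodes and 2 channels.
acts = np.array([[1.0, 2.0], [3.0, -1.0], [0.0, 1.0]])
grads = np.array([[1.0, -2.0], [1.0, 0.0], [1.0, -1.0]])
cam = grad_cam(acts, grads)  # per-node values [1., 3., 0.]
```

Here the channel weights are [1, -1]; with relu_weights the second channel is dropped entirely, whereas with relu_weights=False it contributes and relu_cam then clips the resulting negative node values.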

Returns

  • logits – The logits obtained after the forward pass on the given sample.

  • cam – The GradCAM data.

Raises

ValueError – If the registered hooks fail to collect data during the model call, so that the collected outputs still contain None values.
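The condition behind this ValueError can be sketched as a simple check on whatever the hooks stored. This is an illustration under assumptions, not the library code: the collected data is modelled as a plain dict whose entries start as None and are filled in by hooks, and the helper name check_collected and the keys "activations" and "gradients" are hypothetical.

```python
def check_collected(collected):
    """Hypothetical helper: raise ValueError if any hook left a None entry."""
    missing = [name for name, value in collected.items() if value is None]
    if missing:
        raise ValueError(
            f"Hooks did not collect data during the model call: {missing}"
        )
    return collected


# Both entries filled: the data is passed through unchanged.
ok = check_collected({"activations": [1.0], "gradients": [0.5]})
```

A check of this kind typically fails when cam_layer is not actually part of the computation graph used in the forward pass, so its hooks never fire.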