morphoclass.xai.embedding_extractor module

Implementation of the EmbeddingExtractor class.

class morphoclass.xai.embedding_extractor.EmbeddingExtractor(model, layer_name)

Bases: object

Extract inputs/outputs from a given layer.

An instance of the embedding extractor is created by providing a model and the name of the layer in that model whose inputs and outputs should be extracted.

To obtain the inputs and outputs, call the extractor (i.e. its __call__ method) with a batch of data. The batch is forwarded through the model internally, and __call__ returns the layer's inputs and outputs as a tuple.
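The following is a minimal sketch of how an extractor with this behaviour can be built on PyTorch forward hooks. It is not the morphoclass implementation. The class name HookEmbeddingExtractor, the TinyNet model, the use of register_forward_hook, and running the forward pass under torch.no_grad() are all assumptions made for illustration.

```python
import torch
from torch import nn


class HookEmbeddingExtractor:
    """Hypothetical sketch: capture a layer's inputs/outputs via a forward hook."""

    def __init__(self, model: nn.Module, layer_name: str) -> None:
        self.model = model
        # Look up the layer by attribute name, as described for `layer_name`.
        self.layer = getattr(model, layer_name)
        if not isinstance(self.layer, nn.Module):
            raise TypeError(f"{layer_name!r} is not a torch.nn.Module")

    def __call__(self, batch):
        captured = {}

        def hook(module, inputs, output):
            # `inputs` is the tuple of positional arguments passed to the layer.
            captured["inputs"] = inputs
            captured["outputs"] = output

        handle = self.layer.register_forward_hook(hook)
        try:
            # Assumption: no gradients are needed for the extracted activations.
            with torch.no_grad():
                self.model(batch)
        finally:
            # Always detach the hook so repeated calls do not stack hooks.
            handle.remove()
        if "outputs" not in captured:
            raise RuntimeError("the layer was not called during the forward pass")
        return captured["inputs"], captured["outputs"]


class TinyNet(nn.Module):
    """Hypothetical toy model used only to demonstrate the extractor."""

    def __init__(self) -> None:
        super().__init__()
        self.fc1 = nn.Linear(4, 8)
        self.act = nn.ReLU()
        self.fc2 = nn.Linear(8, 2)

    def forward(self, x):
        return self.fc2(self.act(self.fc1(x)))


torch.manual_seed(0)
model = TinyNet()
extractor = HookEmbeddingExtractor(model, "fc2")
x = torch.randn(5, 4)
inputs, outputs = extractor(x)
print(inputs[0].shape, outputs.shape)  # torch.Size([5, 8]) torch.Size([5, 2])
```

Removing the hook in a finally block keeps the model unchanged after each call, so the same extractor can be reused on many batches.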

Parameters
  • model (torch.nn.Module) – A torch model from which to extract intermediate activations.

  • layer_name (str) – The name of the layer in the given model for which to extract the inputs/outputs. The layer instance is obtained by getattr(model, layer_name) and should be an instance of torch.nn.Module.
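Assuming the lookup is exactly getattr(model, layer_name) as stated above, only attributes set directly on the model are reachable. The snippet below uses hypothetical Net and Encoder classes to show which names resolve to a torch.nn.Module and that a dotted path to a nested layer is not resolved by a plain getattr.

```python
from torch import nn


class Encoder(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Linear(3, 3)  # stand-in for a nested layer


class Net(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.encoder = Encoder()
        self.head = nn.Linear(3, 1)
        self.temperature = 2.0  # a plain attribute, not a layer


model = Net()

# Direct child modules resolve through getattr and are nn.Module instances.
for name in ("encoder", "head"):
    assert isinstance(getattr(model, name), nn.Module)

# A plain attribute resolves, but it is not a torch.nn.Module.
assert not isinstance(getattr(model, "temperature"), nn.Module)

# A dotted path is treated as a single attribute name and is not found.
try:
    getattr(model, "encoder.conv")
except AttributeError:
    nested_lookup_failed = True
else:
    nested_lookup_failed = False
```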