morphoclass.xai.embedding_extractor module
Implementation of the EmbeddingExtractor class.
class morphoclass.xai.embedding_extractor.EmbeddingExtractor(model, layer_name)
Bases: object

Extract inputs/outputs from a given layer.
An instance of the embedding extractor is created by providing a model and the name of a layer in that model whose inputs and outputs should be extracted.
To obtain the inputs and the outputs, call the extractor’s __call__ method with a batch of data. The batch is forwarded internally through the model, and __call__ then returns the layer’s inputs and outputs as a tuple.
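The description above does not say how the activations are captured. A common way to do this in PyTorch is a forward hook registered on the target layer; the following is a minimal sketch assuming that approach. `SketchEmbeddingExtractor` and `TinyNet` are hypothetical names, and running the forward pass under `torch.no_grad()` is an assumption, not documented morphoclass behaviour.

```python
import torch
from torch import nn


class SketchEmbeddingExtractor:
    """Hypothetical sketch: capture a layer's inputs/outputs with a forward hook."""

    def __init__(self, model: nn.Module, layer_name: str) -> None:
        self.model = model
        # Same lookup as documented: getattr(model, layer_name).
        self.layer = getattr(model, layer_name)
        self._captured = None

    def _hook(self, module, inputs, output):
        # PyTorch passes the layer's positional inputs as a tuple.
        self._captured = (inputs, output)

    def __call__(self, batch):
        self._captured = None
        handle = self.layer.register_forward_hook(self._hook)
        try:
            # Assumption: gradients are not needed for extracted embeddings.
            with torch.no_grad():
                self.model(batch)
        finally:
            # Detach the hook so later forward passes are unaffected.
            handle.remove()
        return self._captured


class TinyNet(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.fc1 = nn.Linear(4, 3)
        self.fc2 = nn.Linear(3, 2)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


extractor = SketchEmbeddingExtractor(TinyNet(), "fc2")
inputs, outputs = extractor(torch.randn(5, 4))
print(inputs[0].shape, outputs.shape)  # torch.Size([5, 3]) torch.Size([5, 2])
```

Removing the hook in a `finally` block leaves the model unchanged between calls, even if the forward pass raises an exception.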
Parameters
model (torch.nn.Module) – A torch model from which to extract intermediate activations.
layer_name (str) – The name of the layer in the given model whose inputs/outputs should be extracted. The layer is looked up with getattr(model, layer_name), which must return an instance of torch.nn.Module.
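Because the layer is obtained with a single `getattr` call, `layer_name` has to name a direct attribute of the model. The hypothetical helper `resolve_layer` below makes this concrete; its `isinstance` check is an assumption about how an invalid name might be reported, not documented morphoclass behaviour. A dotted path such as `"encoder.0"` is not resolved by `getattr`, and a plain attribute such as a float is not a layer.

```python
from torch import nn


def resolve_layer(model: nn.Module, layer_name: str) -> nn.Module:
    """Hypothetical helper mirroring the documented lookup getattr(model, layer_name)."""
    layer = getattr(model, layer_name)
    if not isinstance(layer, nn.Module):
        raise TypeError(
            f"{layer_name!r} resolves to {type(layer).__name__}, not torch.nn.Module"
        )
    return layer


class Net(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(4, 3), nn.ReLU())
        self.head = nn.Linear(3, 2)
        self.scale = 2.0  # plain attribute, not a layer

    def forward(self, x):
        return self.scale * self.head(self.encoder(x))


net = Net()
outcomes = {}
for name in ("head", "scale", "encoder.0"):
    try:
        outcomes[name] = resolve_layer(net, name)
    except (TypeError, AttributeError) as err:
        # "scale" fails the isinstance check; "encoder.0" is not an attribute name.
        outcomes[name] = err
```

If nested submodules are needed, `torch.nn.Module.get_submodule` accepts dotted paths, although that is not the lookup this class documents.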